storage/sql: allow specifying sql flavor specific migrations
Signed-off-by: Nandor Kracser <bonifaido@gmail.com>
This commit is contained in:
		| @@ -66,7 +66,7 @@ func (s *SQLite3) open(logger log.Logger) (*conn, error) { | |||||||
| 		return sqlErr.ExtendedCode == sqlite3.ErrConstraintPrimaryKey | 		return sqlErr.ExtendedCode == sqlite3.ErrConstraintPrimaryKey | ||||||
| 	} | 	} | ||||||
|  |  | ||||||
| 	c := &conn{db, flavorSQLite3, logger, errCheck} | 	c := &conn{db, &flavorSQLite3, logger, errCheck} | ||||||
| 	if _, err := c.migrate(); err != nil { | 	if _, err := c.migrate(); err != nil { | ||||||
| 		return nil, fmt.Errorf("failed to perform migrations: %v", err) | 		return nil, fmt.Errorf("failed to perform migrations: %v", err) | ||||||
| 	} | 	} | ||||||
| @@ -239,7 +239,7 @@ func (p *Postgres) open(logger log.Logger) (*conn, error) { | |||||||
| 		return sqlErr.Code == pgErrUniqueViolation | 		return sqlErr.Code == pgErrUniqueViolation | ||||||
| 	} | 	} | ||||||
|  |  | ||||||
| 	c := &conn{db, flavorPostgres, logger, errCheck} | 	c := &conn{db, &flavorPostgres, logger, errCheck} | ||||||
| 	if _, err := c.migrate(); err != nil { | 	if _, err := c.migrate(); err != nil { | ||||||
| 		return nil, fmt.Errorf("failed to perform migrations: %v", err) | 		return nil, fmt.Errorf("failed to perform migrations: %v", err) | ||||||
| 	} | 	} | ||||||
| @@ -344,7 +344,7 @@ func (s *MySQL) open(logger log.Logger) (*conn, error) { | |||||||
| 			sqlErr.Number == mysqlErrDupEntryWithKeyName | 			sqlErr.Number == mysqlErrDupEntryWithKeyName | ||||||
| 	} | 	} | ||||||
|  |  | ||||||
| 	c := &conn{db, flavorMySQL, logger, errCheck} | 	c := &conn{db, &flavorMySQL, logger, errCheck} | ||||||
| 	if _, err := c.migrate(); err != nil { | 	if _, err := c.migrate(); err != nil { | ||||||
| 		return nil, fmt.Errorf("failed to perform migrations: %v", err) | 		return nil, fmt.Errorf("failed to perform migrations: %v", err) | ||||||
| 	} | 	} | ||||||
|   | |||||||
| @@ -18,6 +18,14 @@ func (c *conn) migrate() (int, error) { | |||||||
|  |  | ||||||
| 	i := 0 | 	i := 0 | ||||||
| 	done := false | 	done := false | ||||||
|  |  | ||||||
|  | 	var flavorMigrations []migration | ||||||
|  | 	for _, m := range migrations { | ||||||
|  | 		if m.flavor == nil || m.flavor == c.flavor { | ||||||
|  | 			flavorMigrations = append(flavorMigrations, m) | ||||||
|  | 		} | ||||||
|  | 	} | ||||||
|  |  | ||||||
| 	for { | 	for { | ||||||
| 		err := c.ExecTx(func(tx *trans) error { | 		err := c.ExecTx(func(tx *trans) error { | ||||||
| 			// Within a transaction, perform a single migration. | 			// Within a transaction, perform a single migration. | ||||||
| @@ -31,13 +39,13 @@ func (c *conn) migrate() (int, error) { | |||||||
| 			if num.Valid { | 			if num.Valid { | ||||||
| 				n = int(num.Int64) | 				n = int(num.Int64) | ||||||
| 			} | 			} | ||||||
| 			if n >= len(migrations) { | 			if n >= len(flavorMigrations) { | ||||||
| 				done = true | 				done = true | ||||||
| 				return nil | 				return nil | ||||||
| 			} | 			} | ||||||
|  |  | ||||||
| 			migrationNum := n + 1 | 			migrationNum := n + 1 | ||||||
| 			m := migrations[n] | 			m := flavorMigrations[n] | ||||||
| 			for i := range m.stmts { | 			for i := range m.stmts { | ||||||
| 				if _, err := tx.Exec(m.stmts[i]); err != nil { | 				if _, err := tx.Exec(m.stmts[i]); err != nil { | ||||||
| 					return fmt.Errorf("migration %d statement %d failed: %v", migrationNum, i+1, err) | 					return fmt.Errorf("migration %d statement %d failed: %v", migrationNum, i+1, err) | ||||||
| @@ -64,7 +72,11 @@ func (c *conn) migrate() (int, error) { | |||||||
|  |  | ||||||
| type migration struct { | type migration struct { | ||||||
| 	stmts []string | 	stmts []string | ||||||
| 	// TODO(ericchiang): consider adding additional fields like "forDrivers" |  | ||||||
|  | 	// If flavor is nil the migration will take place for all database backend flavors. | ||||||
|  | 	// If specified, only for that corresponding flavor, in that case stmts can be written | ||||||
|  | 	// in the specific SQL dialect. | ||||||
|  | 	flavor *flavor | ||||||
| } | } | ||||||
|  |  | ||||||
| // All SQL flavors share migration strategies. | // All SQL flavors share migration strategies. | ||||||
|   | |||||||
| @@ -30,8 +30,15 @@ func TestMigrate(t *testing.T) { | |||||||
| 		return sqlErr.ExtendedCode == sqlite3.ErrConstraintUnique | 		return sqlErr.ExtendedCode == sqlite3.ErrConstraintUnique | ||||||
| 	} | 	} | ||||||
|  |  | ||||||
| 	c := &conn{db, flavorSQLite3, logger, errCheck} | 	var sqliteMigrations []migration | ||||||
| 	for _, want := range []int{len(migrations), 0} { | 	for _, m := range migrations { | ||||||
|  | 		if m.flavor == nil || m.flavor == &flavorSQLite3 { | ||||||
|  | 			sqliteMigrations = append(sqliteMigrations, m) | ||||||
|  | 		} | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	c := &conn{db, &flavorSQLite3, logger, errCheck} | ||||||
|  | 	for _, want := range []int{len(sqliteMigrations), 0} { | ||||||
| 		got, err := c.migrate() | 		got, err := c.migrate() | ||||||
| 		if err != nil { | 		if err != nil { | ||||||
| 			t.Fatal(err) | 			t.Fatal(err) | ||||||
|   | |||||||
| @@ -130,7 +130,7 @@ func (c *conn) translateArgs(args []interface{}) []interface{} { | |||||||
| // conn is the main database connection. | // conn is the main database connection. | ||||||
| type conn struct { | type conn struct { | ||||||
| 	db                 *sql.DB | 	db                 *sql.DB | ||||||
| 	flavor             flavor | 	flavor             *flavor | ||||||
| 	logger             log.Logger | 	logger             log.Logger | ||||||
| 	alreadyExistsCheck func(err error) bool | 	alreadyExistsCheck func(err error) bool | ||||||
| } | } | ||||||
|   | |||||||
		Reference in New Issue
	
	Block a user