storage/sql: allow specifying sql flavor specific migrations

Signed-off-by: Nandor Kracser <bonifaido@gmail.com>
This commit is contained in:
Nandor Kracser 2020-02-21 12:13:38 +01:00
parent 1160649c31
commit 80749ffd3f
No known key found for this signature in database
GPG Key ID: 7A4C93C7D6B80413
4 changed files with 28 additions and 9 deletions

View File

@ -66,7 +66,7 @@ func (s *SQLite3) open(logger log.Logger) (*conn, error) {
return sqlErr.ExtendedCode == sqlite3.ErrConstraintPrimaryKey return sqlErr.ExtendedCode == sqlite3.ErrConstraintPrimaryKey
} }
c := &conn{db, flavorSQLite3, logger, errCheck} c := &conn{db, &flavorSQLite3, logger, errCheck}
if _, err := c.migrate(); err != nil { if _, err := c.migrate(); err != nil {
return nil, fmt.Errorf("failed to perform migrations: %v", err) return nil, fmt.Errorf("failed to perform migrations: %v", err)
} }
@ -239,7 +239,7 @@ func (p *Postgres) open(logger log.Logger) (*conn, error) {
return sqlErr.Code == pgErrUniqueViolation return sqlErr.Code == pgErrUniqueViolation
} }
c := &conn{db, flavorPostgres, logger, errCheck} c := &conn{db, &flavorPostgres, logger, errCheck}
if _, err := c.migrate(); err != nil { if _, err := c.migrate(); err != nil {
return nil, fmt.Errorf("failed to perform migrations: %v", err) return nil, fmt.Errorf("failed to perform migrations: %v", err)
} }
@ -344,7 +344,7 @@ func (s *MySQL) open(logger log.Logger) (*conn, error) {
sqlErr.Number == mysqlErrDupEntryWithKeyName sqlErr.Number == mysqlErrDupEntryWithKeyName
} }
c := &conn{db, flavorMySQL, logger, errCheck} c := &conn{db, &flavorMySQL, logger, errCheck}
if _, err := c.migrate(); err != nil { if _, err := c.migrate(); err != nil {
return nil, fmt.Errorf("failed to perform migrations: %v", err) return nil, fmt.Errorf("failed to perform migrations: %v", err)
} }

View File

@ -18,6 +18,14 @@ func (c *conn) migrate() (int, error) {
i := 0 i := 0
done := false done := false
var flavorMigrations []migration
for _, m := range migrations {
if m.flavor == nil || m.flavor == c.flavor {
flavorMigrations = append(flavorMigrations, m)
}
}
for { for {
err := c.ExecTx(func(tx *trans) error { err := c.ExecTx(func(tx *trans) error {
// Within a transaction, perform a single migration. // Within a transaction, perform a single migration.
@ -31,13 +39,13 @@ func (c *conn) migrate() (int, error) {
if num.Valid { if num.Valid {
n = int(num.Int64) n = int(num.Int64)
} }
if n >= len(migrations) { if n >= len(flavorMigrations) {
done = true done = true
return nil return nil
} }
migrationNum := n + 1 migrationNum := n + 1
m := migrations[n] m := flavorMigrations[n]
for i := range m.stmts { for i := range m.stmts {
if _, err := tx.Exec(m.stmts[i]); err != nil { if _, err := tx.Exec(m.stmts[i]); err != nil {
return fmt.Errorf("migration %d statement %d failed: %v", migrationNum, i+1, err) return fmt.Errorf("migration %d statement %d failed: %v", migrationNum, i+1, err)
@ -64,7 +72,11 @@ func (c *conn) migrate() (int, error) {
type migration struct { type migration struct {
stmts []string stmts []string
// TODO(ericchiang): consider adding additional fields like "forDrivers"
// If flavor is nil the migration will take place for all database backend flavors.
// If specified, only for that corresponding flavor, in that case stmts can be written
// in the specific SQL dialect.
flavor *flavor
} }
// All SQL flavors share migration strategies. // All SQL flavors share migration strategies.

View File

@ -30,8 +30,15 @@ func TestMigrate(t *testing.T) {
return sqlErr.ExtendedCode == sqlite3.ErrConstraintUnique return sqlErr.ExtendedCode == sqlite3.ErrConstraintUnique
} }
c := &conn{db, flavorSQLite3, logger, errCheck} var sqliteMigrations []migration
for _, want := range []int{len(migrations), 0} { for _, m := range migrations {
if m.flavor == nil || m.flavor == &flavorSQLite3 {
sqliteMigrations = append(sqliteMigrations, m)
}
}
c := &conn{db, &flavorSQLite3, logger, errCheck}
for _, want := range []int{len(sqliteMigrations), 0} {
got, err := c.migrate() got, err := c.migrate()
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)

View File

@ -130,7 +130,7 @@ func (c *conn) translateArgs(args []interface{}) []interface{} {
// conn is the main database connection. // conn is the main database connection.
type conn struct { type conn struct {
db *sql.DB db *sql.DB
flavor flavor flavor *flavor
logger log.Logger logger log.Logger
alreadyExistsCheck func(err error) bool alreadyExistsCheck func(err error) bool
} }