diff --git a/storage/sql/config.go b/storage/sql/config.go index f3dede4a..170dfd00 100644 --- a/storage/sql/config.go +++ b/storage/sql/config.go @@ -66,7 +66,7 @@ func (s *SQLite3) open(logger log.Logger) (*conn, error) { return sqlErr.ExtendedCode == sqlite3.ErrConstraintPrimaryKey } - c := &conn{db, flavorSQLite3, logger, errCheck} + c := &conn{db, &flavorSQLite3, logger, errCheck} if _, err := c.migrate(); err != nil { return nil, fmt.Errorf("failed to perform migrations: %v", err) } @@ -239,7 +239,7 @@ func (p *Postgres) open(logger log.Logger) (*conn, error) { return sqlErr.Code == pgErrUniqueViolation } - c := &conn{db, flavorPostgres, logger, errCheck} + c := &conn{db, &flavorPostgres, logger, errCheck} if _, err := c.migrate(); err != nil { return nil, fmt.Errorf("failed to perform migrations: %v", err) } @@ -344,7 +344,7 @@ func (s *MySQL) open(logger log.Logger) (*conn, error) { sqlErr.Number == mysqlErrDupEntryWithKeyName } - c := &conn{db, flavorMySQL, logger, errCheck} + c := &conn{db, &flavorMySQL, logger, errCheck} if _, err := c.migrate(); err != nil { return nil, fmt.Errorf("failed to perform migrations: %v", err) } diff --git a/storage/sql/migrate.go b/storage/sql/migrate.go index 5b86bc78..5e42d05f 100644 --- a/storage/sql/migrate.go +++ b/storage/sql/migrate.go @@ -18,6 +18,14 @@ func (c *conn) migrate() (int, error) { i := 0 done := false + + var flavorMigrations []migration + for _, m := range migrations { + if m.flavor == nil || m.flavor == c.flavor { + flavorMigrations = append(flavorMigrations, m) + } + } + for { err := c.ExecTx(func(tx *trans) error { // Within a transaction, perform a single migration. @@ -31,13 +39,13 @@ func (c *conn) migrate() (int, error) { if num.Valid { n = int(num.Int64) } - if n >= len(migrations) { + if n >= len(flavorMigrations) { done = true return nil } migrationNum := n + 1 - m := migrations[n] + m := flavorMigrations[n] for i := range m.stmts { if _, err := tx.Exec(m.stmts[i]); err != nil { return fmt.Errorf("migration %d statement %d failed: %v", migrationNum, i+1, err) @@ -64,7 +72,11 @@ func (c *conn) migrate() (int, error) { type migration struct { stmts []string - // TODO(ericchiang): consider adding additional fields like "forDrivers" + + // If flavor is nil the migration will take place for all database backend flavors. + // If specified, only for that corresponding flavor, in that case stmts can be written + // in the specific SQL dialect. + flavor *flavor } // All SQL flavors share migration strategies. diff --git a/storage/sql/migrate_test.go b/storage/sql/migrate_test.go index e94e819f..a528aa7e 100644 --- a/storage/sql/migrate_test.go +++ b/storage/sql/migrate_test.go @@ -30,8 +30,15 @@ func TestMigrate(t *testing.T) { return sqlErr.ExtendedCode == sqlite3.ErrConstraintUnique } - c := &conn{db, flavorSQLite3, logger, errCheck} - for _, want := range []int{len(migrations), 0} { + var sqliteMigrations []migration + for _, m := range migrations { + if m.flavor == nil || m.flavor == &flavorSQLite3 { + sqliteMigrations = append(sqliteMigrations, m) + } + } + + c := &conn{db, &flavorSQLite3, logger, errCheck} + for _, want := range []int{len(sqliteMigrations), 0} { got, err := c.migrate() if err != nil { t.Fatal(err) diff --git a/storage/sql/sql.go b/storage/sql/sql.go index 45ecdd79..4f1ed6c9 100644 --- a/storage/sql/sql.go +++ b/storage/sql/sql.go @@ -130,7 +130,7 @@ func (c *conn) translateArgs(args []interface{}) []interface{} { // conn is the main database connection. type conn struct { db *sql.DB - flavor flavor + flavor *flavor logger log.Logger alreadyExistsCheck func(err error) bool }