storage/sql: initial MySQL storage implementation
The network-related connection options are pulled out into a shared NetworkDB struct so they can be shared by both the Postgres and MySQL configs.

Signed-off-by: Pavel Borzenkov <pavel.borzenkov@gmail.com>
committed by Nandor Kracser
parent 92920c86ea
commit e53bdfabb9
@@ -1,14 +1,18 @@
package sql

import (
    "crypto/tls"
    "crypto/x509"
    "database/sql"
    "fmt"
    "net"
    "regexp"
    "io/ioutil"
    "net/url"
    "strconv"
    "strings"
    "time"

    "github.com/Sirupsen/logrus"
    "github.com/coreos/dex/storage"
    "github.com/go-sql-driver/mysql"
    "github.com/lib/pq"
    sqlite3 "github.com/mattn/go-sqlite3"
@@ -21,6 +25,12 @@ const (
    pgErrUniqueViolation = "23505" // unique_violation
)

const (
    // MySQL error codes
    mysqlErrDupEntry            = 1062
    mysqlErrDupEntryWithKeyName = 1586
)

// SQLite3 options for creating an SQL db.
type SQLite3 struct {
    // File to
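The two constants are the MySQL server's duplicate-key error codes (1062 is ER_DUP_ENTRY, 1586 is ER_DUP_ENTRY_WITH_KEY_NAME); the errCheck closure further down in this diff uses them to recognise unique-constraint violations the same way the Postgres code checks pgErrUniqueViolation. A minimal standalone sketch of that check, assuming the github.com/go-sql-driver/mysql package (the alreadyExists helper is illustrative, not part of the patch):

```go
package main

import (
    "fmt"

    "github.com/go-sql-driver/mysql"
)

// alreadyExists reports whether err is a MySQL duplicate-key violation,
// mirroring the errCheck closure introduced by this patch.
func alreadyExists(err error) bool {
    sqlErr, ok := err.(*mysql.MySQLError)
    if !ok {
        return false
    }
    return sqlErr.Number == 1062 || sqlErr.Number == 1586
}

func main() {
    err := &mysql.MySQLError{Number: 1062, Message: "Duplicate entry 'foo' for key 'PRIMARY'"}
    fmt.Println(alreadyExists(err)) // true
}
```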
@@ -63,31 +73,29 @@ func (s *SQLite3) open(logger log.Logger) (*conn, error) {
}

const (
    sslDisable    = "disable"
    sslRequire    = "require"
    sslVerifyCA   = "verify-ca"
    sslVerifyFull = "verify-full"
    // postgres SSL modes
    pgSSLDisable    = "disable"
    pgSSLRequire    = "require"
    pgSSLVerifyCA   = "verify-ca"
    pgSSLVerifyFull = "verify-full"
)

// PostgresSSL represents SSL options for Postgres databases.
type PostgresSSL struct {
    Mode   string
    CAFile string
    // Files for client auth.
    KeyFile  string
    CertFile string
}
const (
    // MySQL SSL modes
    mysqlSSLTrue       = "true"
    mysqlSSLFalse      = "false"
    mysqlSSLSkipVerify = "skip-verify"
    mysqlSSLCustom     = "custom"
)

// Postgres options for creating an SQL db.
type Postgres struct {
// NetworkDB contains options common to SQL databases accessed over network.
type NetworkDB struct {
    Database string
    User     string
    Password string
    Host     string
    Port     uint16

    SSL PostgresSSL `json:"ssl" yaml:"ssl"`

    ConnectionTimeout int // Seconds

    // database/sql tunables, see
@@ -98,6 +106,22 @@ type Postgres struct {
    ConnMaxLifetime int // Seconds, default: not set
}

// SSL represents SSL options for network databases.
type SSL struct {
    Mode   string
    CAFile string
    // Files for client auth.
    KeyFile  string
    CertFile string
}

// Postgres options for creating an SQL db.
type Postgres struct {
    NetworkDB

    SSL SSL `json:"ssl" yaml:"ssl"`
}

// Open creates a new storage implementation backed by Postgres.
func (p *Postgres) Open(logger log.Logger) (storage.Storage, error) {
    conn, err := p.open(logger, p.createDataSourceName())
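Embedding NetworkDB means its fields are promoted onto Postgres (and, further down, MySQL), so code that reads p.Host or p.ConnectionTimeout keeps compiling while both backends share a single definition of the connection options. A small sketch of that field promotion, with the types re-declared locally for illustration:

```go
package main

import "fmt"

// Local re-declarations for illustration only; the real types live in dex's storage/sql package.
type NetworkDB struct {
    Database          string
    User              string
    Password          string
    Host              string
    Port              uint16
    ConnectionTimeout int // seconds
}

type SSL struct {
    Mode, CAFile, KeyFile, CertFile string
}

type Postgres struct {
    NetworkDB
    SSL SSL
}

type MySQL struct {
    NetworkDB
    SSL SSL
}

func main() {
    p := Postgres{NetworkDB: NetworkDB{Host: "db.example.com", ConnectionTimeout: 5}}
    // Promoted fields: p.Host is shorthand for p.NetworkDB.Host.
    fmt.Println(p.Host, p.ConnectionTimeout)
}
```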
@@ -216,3 +240,105 @@ func (p *Postgres) open(logger log.Logger, dataSourceName string) (*conn, error)
    }
    return c, nil
}

// MySQL options for creating a MySQL db.
type MySQL struct {
    NetworkDB

    SSL SSL `json:"ssl" yaml:"ssl"`

    // TODO(pborzenkov): used by tests to reduce lock wait timeout. Should
    // we make it exported and allow users to provide arbitrary params?
    params map[string]string
}

// Open creates a new storage implementation backed by MySQL.
func (s *MySQL) Open(logger logrus.FieldLogger) (storage.Storage, error) {
    conn, err := s.open(logger)
    if err != nil {
        return nil, err
    }
    return conn, nil
}

func (s *MySQL) open(logger logrus.FieldLogger) (*conn, error) {
    cfg := mysql.Config{
        User:   s.User,
        Passwd: s.Password,
        DBName: s.Database,

        Timeout: time.Second * time.Duration(s.ConnectionTimeout),

        ParseTime: true,
        Params: map[string]string{
            "tx_isolation": "'SERIALIZABLE'",
        },
    }
    if s.Host != "" {
        if s.Host[0] != '/' {
            cfg.Net = "tcp"
            cfg.Addr = s.Host
        } else {
            cfg.Net = "unix"
            cfg.Addr = s.Host
        }
    }
    if s.SSL.CAFile != "" || s.SSL.CertFile != "" || s.SSL.KeyFile != "" {
        if err := s.makeTLSConfig(); err != nil {
            return nil, fmt.Errorf("failed to make TLS config: %v", err)
        }
        cfg.TLSConfig = mysqlSSLCustom
    } else {
        cfg.TLSConfig = s.SSL.Mode
    }
    for k, v := range s.params {
        cfg.Params[k] = v
    }

    db, err := sql.Open("mysql", cfg.FormatDSN())
    if err != nil {
        return nil, err
    }

    errCheck := func(err error) bool {
        sqlErr, ok := err.(*mysql.MySQLError)
        if !ok {
            return false
        }
        return sqlErr.Number == mysqlErrDupEntry ||
            sqlErr.Number == mysqlErrDupEntryWithKeyName
    }

    c := &conn{db, flavorMySQL, logger, errCheck}
    if _, err := c.migrate(); err != nil {
        return nil, fmt.Errorf("failed to perform migrations: %v", err)
    }
    return c, nil
}

func (s *MySQL) makeTLSConfig() error {
    cfg := &tls.Config{}
    if s.SSL.CAFile != "" {
        rootCertPool := x509.NewCertPool()
        pem, err := ioutil.ReadFile(s.SSL.CAFile)
        if err != nil {
            return err
        }
        if ok := rootCertPool.AppendCertsFromPEM(pem); !ok {
            return fmt.Errorf("failed to append PEM")
        }
        cfg.RootCAs = rootCertPool
    }
    if s.SSL.CertFile != "" && s.SSL.KeyFile != "" {
        clientCert := make([]tls.Certificate, 0, 1)
        certs, err := tls.LoadX509KeyPair(s.SSL.CertFile, s.SSL.KeyFile)
        if err != nil {
            return err
        }
        clientCert = append(clientCert, certs)
        cfg.Certificates = clientCert
    }

    mysql.RegisterTLSConfig(mysqlSSLCustom, cfg)
    return nil
}
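Rather than concatenating a DSN by hand, open fills in the driver's mysql.Config and lets FormatDSN do the encoding; when any SSL file is set, makeTLSConfig registers a *tls.Config under the "custom" name and the DSN carries tls=custom, otherwise tls reflects SSL.Mode directly. A rough sketch of what this produces for a TCP setup (host, credentials and database name are invented for the example, and the exact parameter order and escaping may differ by driver version):

```go
package main

import (
    "fmt"
    "time"

    "github.com/go-sql-driver/mysql"
)

func main() {
    cfg := mysql.Config{
        User:   "dex",
        Passwd: "secret",
        Net:    "tcp",
        Addr:   "db.example.com:3306",
        DBName: "dex",

        Timeout:   5 * time.Second,
        ParseTime: true,
        Params: map[string]string{
            "tx_isolation": "'SERIALIZABLE'",
        },
        TLSConfig: "false",
    }
    // Prints a DSN along the lines of:
    // dex:secret@tcp(db.example.com:3306)/dex?parseTime=true&timeout=5s&tls=false&tx_isolation=%27SERIALIZABLE%27
    fmt.Println(cfg.FormatDSN())
}
```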
@@ -32,15 +32,16 @@ func withTimeout(t time.Duration, f func()) {
}

func cleanDB(c *conn) error {
    _, err := c.Exec(`
        delete from client;
        delete from auth_request;
        delete from auth_code;
        delete from refresh_token;
        delete from keys;
        delete from password;
    `)
    return err
    tables := []string{"client", "auth_request", "auth_code",
        "refresh_token", "keys", "password"}

    for _, tbl := range tables {
        _, err := c.Exec("delete from " + tbl)
        if err != nil {
            return err
        }
    }
    return nil
}

var logger = &logrus.Logger{
@@ -49,23 +50,39 @@ var logger = &logrus.Logger{
    Level: logrus.DebugLevel,
}

func TestSQLite3(t *testing.T) {
type opener interface {
    open(logrus.FieldLogger) (*conn, error)
}

func testDB(t *testing.T, o opener, withTransactions bool) {
    // t.Fatal has a bad habbit of not actually printing the error
    fatal := func(i interface{}) {
        fmt.Fprintln(os.Stdout, i)
        t.Fatal(i)
    }

    newStorage := func() storage.Storage {
        // NOTE(ericchiang): In memory means we only get one connection at a time. If we
        // ever write tests that require using multiple connections, for instance to test
        // transactions, we need to move to a file based system.
        s := &SQLite3{":memory:"}
        conn, err := s.open(logger)
        conn, err := o.open(logger)
        if err != nil {
            fmt.Fprintln(os.Stdout, err)
            t.Fatal(err)
            fatal(err)
        }
        if err := cleanDB(conn); err != nil {
            fatal(err)
        }
        return conn
    }

    withTimeout(time.Second*10, func() {
    withTimeout(time.Minute*1, func() {
        conformance.RunTests(t, newStorage)
    })
    if withTransactions {
        withTimeout(time.Minute*1, func() {
            conformance.RunTransactionTests(t, newStorage)
        })
    }
}

func TestSQLite3(t *testing.T) {
    testDB(t, &SQLite3{":memory:"}, false)
}

func getenv(key, defaultVal string) string {
@@ -186,37 +203,42 @@ func TestPostgres(t *testing.T) {
    if host == "" {
        t.Skipf("test environment variable %q not set, skipping", testPostgresEnv)
    }
    p := Postgres{
        Database: getenv("DEX_POSTGRES_DATABASE", "postgres"),
        User:     getenv("DEX_POSTGRES_USER", "postgres"),
        Password: getenv("DEX_POSTGRES_PASSWORD", "postgres"),
        Host:     host,
        SSL: PostgresSSL{
            Mode: sslDisable, // Postgres container doesn't support SSL.
    p := &Postgres{
        NetworkDB: NetworkDB{
            Database:          getenv("DEX_POSTGRES_DATABASE", "postgres"),
            User:              getenv("DEX_POSTGRES_USER", "postgres"),
            Password:          getenv("DEX_POSTGRES_PASSWORD", "postgres"),
            Host:              host,
            ConnectionTimeout: 5,
        },
        SSL: SSL{
            Mode: pgSSLDisable, // Postgres container doesn't support SSL.
        },
        ConnectionTimeout: 5,
    }

    // t.Fatal has a bad habbit of not actually printing the error
    fatal := func(i interface{}) {
        fmt.Fprintln(os.Stdout, i)
        t.Fatal(i)
    }

    newStorage := func() storage.Storage {
        conn, err := p.open(logger, p.createDataSourceName())
        if err != nil {
            fatal(err)
        }
        if err := cleanDB(conn); err != nil {
            fatal(err)
        }
        return conn
    }
    withTimeout(time.Minute*1, func() {
        conformance.RunTests(t, newStorage)
    })
    withTimeout(time.Minute*1, func() {
        conformance.RunTransactionTests(t, newStorage)
    })
    testDB(t, p, true)
}

const testMySQLEnv = "DEX_MYSQL_HOST"

func TestMySQL(t *testing.T) {
    host := os.Getenv(testMySQLEnv)
    if host == "" {
        t.Skipf("test environment variable %q not set, skipping", testMySQLEnv)
    }
    s := &MySQL{
        NetworkDB: NetworkDB{
            Database:          getenv("DEX_MYSQL_DATABASE", "mysql"),
            User:              getenv("DEX_MYSQL_USER", "mysql"),
            Password:          getenv("DEX_MYSQL_PASSWORD", ""),
            Host:              host,
            ConnectionTimeout: 5,
        },
        SSL: SSL{
            Mode: mysqlSSLFalse,
        },
        params: map[string]string{
            "innodb_lock_wait_timeout": "3",
        },
    }
    testDB(t, s, true)
}
@@ -38,8 +38,10 @@ func (c *conn) migrate() (int, error) {

    migrationNum := n + 1
    m := migrations[n]
    if _, err := tx.Exec(m.stmt); err != nil {
        return fmt.Errorf("migration %d failed: %v", migrationNum, err)
    for i := range m.stmts {
        if _, err := tx.Exec(m.stmts[i]); err != nil {
            return fmt.Errorf("migration %d statement %d failed: %v", migrationNum, i+1, err)
        }
    }

    q := `insert into migrations (num, at) values ($1, now());`
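Splitting each migration into a list of single statements is what makes the schema portable: go-sql-driver/mysql executes one statement per Exec call by default, while the SQLite3 and Postgres drivers happened to accept the old multi-statement blobs. A standalone sketch of the same one-statement-per-Exec pattern, run here against an in-memory SQLite database with invented table names:

```go
package main

import (
    "database/sql"
    "fmt"
    "log"

    _ "github.com/mattn/go-sqlite3"
)

// Each entry is a single statement; every one gets its own Exec call, so a
// driver that rejects semicolon-separated batches (as MySQL's does by
// default) can still apply the migration.
var exampleStmts = []string{
    `create table example_client (id text not null primary key, name text not null);`,
    `create table example_keys (id text not null primary key, signing_key blob not null);`,
}

func main() {
    db, err := sql.Open("sqlite3", ":memory:")
    if err != nil {
        log.Fatal(err)
    }
    defer db.Close()

    for i, stmt := range exampleStmts {
        if _, err := db.Exec(stmt); err != nil {
            log.Fatalf("statement %d failed: %v", i+1, err)
        }
    }
    fmt.Println("applied", len(exampleStmts), "statements")
}
```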
@@ -61,14 +63,14 @@ func (c *conn) migrate() (int, error) {
}

type migration struct {
    stmt string
    stmts []string
    // TODO(ericchiang): consider adding additional fields like "forDrivers"
}

// All SQL flavors share migration strategies.
var migrations = []migration{
    {
        stmt: `
        stmts: []string{`
            create table client (
                id text not null primary key,
                secret text not null,

@@ -77,8 +79,8 @@ var migrations = []migration{
                public boolean not null,
                name text not null,
                logo_url text not null
            );

            );`,
            `
            create table auth_request (
                id text not null primary key,
                client_id text not null,

@@ -101,8 +103,8 @@ var migrations = []migration{
                connector_data bytea,

                expiry timestamptz not null
            );

            );`,
            `
            create table auth_code (
                id text not null primary key,
                client_id text not null,

@@ -120,8 +122,8 @@ var migrations = []migration{
                connector_data bytea,

                expiry timestamptz not null
            );

            );`,
            `
            create table refresh_token (
                id text not null primary key,
                client_id text not null,

@@ -136,15 +138,15 @@ var migrations = []migration{

                connector_id text not null,
                connector_data bytea
            );

            );`,
            `
            create table password (
                email text not null primary key,
                hash bytea not null,
                username text not null,
                user_id text not null
            );

            );`,
            `
            -- keys is a weird table because we only ever expect there to be a single row
            create table keys (
                id text not null primary key,

@@ -152,39 +154,40 @@ var migrations = []migration{
                signing_key bytea not null, -- JSON object
                signing_key_pub bytea not null, -- JSON object
                next_rotation timestamptz not null
            );

            `,
            );`,
        },
    },
    {
        stmt: `
        stmts: []string{`
            alter table refresh_token
                add column token text not null default '';
                add column token text not null default '';`,
            `
            alter table refresh_token
                add column created_at timestamptz not null default '0001-01-01 00:00:00 UTC';
                add column created_at timestamptz not null default '0001-01-01 00:00:00 UTC';`,
            `
            alter table refresh_token
                add column last_used timestamptz not null default '0001-01-01 00:00:00 UTC';
            `,
                add column last_used timestamptz not null default '0001-01-01 00:00:00 UTC';`,
        },
    },
    {
        stmt: `
        stmts: []string{`
            create table offline_session (
                user_id text not null,
                conn_id text not null,
                refresh bytea not null,
                PRIMARY KEY (user_id, conn_id)
            );
            `,
            );`,
        },
    },
    {
        stmt: `
        stmts: []string{`
            create table connector (
                id text not null primary key,
                type text not null,
                name text not null,
                resource_version text not null,
                config bytea
            );
            `,
            );`,
        },
    },
}
@@ -83,6 +83,28 @@ var (
        {regexp.MustCompile(`\bnow\(\)`), "date('now')"},
    },
}

flavorMySQL = flavor{
    queryReplacers: []replacer{
        {bindRegexp, "?"},
        // Translate types.
        {matchLiteral("bytea"), "blob"},
        {matchLiteral("timestamptz"), "datetime(3)"},
        // MySQL doesn't support indicies on text fields w/o
        // specifying key length. Use varchar instead (768 is
        // the max key length for InnoDB with 4k pages).
        {matchLiteral("text"), "varchar(768)"},
        // Quote keywords and reserved words used as identifiers.
        {regexp.MustCompile(`\b(keys)\b`), "`$1`"},
        // Change default timestamp to fit datetime.
        {regexp.MustCompile(`0001-01-01 00:00:00 UTC`), "1000-01-01 00:00:00"},
    },
}

// Not tested.
flavorCockroach = flavor{
    executeTx: crdb.ExecuteTx,
}
)

func (f flavor) translate(query string) string {
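flavorMySQL never changes what the storage code writes, only how the shared Postgres-dialect SQL is spelled for MySQL: positional binds become ?, the bytea/timestamptz/text types become MySQL equivalents, the reserved word keys gets backquoted, and the zero timestamp is moved inside MySQL's datetime range. A standalone sketch of the same replacer idea (bindRegexp and matchLiteral are re-created here by assumption; the real ones live in this package):

```go
package main

import (
    "fmt"
    "regexp"
)

// Re-created helpers for illustration; the real bindRegexp and matchLiteral
// are defined in dex's storage/sql package.
var bindRegexp = regexp.MustCompile(`\$\d+`)

func matchLiteral(s string) *regexp.Regexp {
    return regexp.MustCompile(`\b` + regexp.QuoteMeta(s) + `\b`)
}

type replacer struct {
    re   *regexp.Regexp
    with string
}

var mysqlReplacers = []replacer{
    {bindRegexp, "?"},
    {matchLiteral("bytea"), "blob"},
    {matchLiteral("timestamptz"), "datetime(3)"},
    {matchLiteral("text"), "varchar(768)"},
    {regexp.MustCompile(`\b(keys)\b`), "`$1`"},
    {regexp.MustCompile(`0001-01-01 00:00:00 UTC`), "1000-01-01 00:00:00"},
}

// translate applies each replacer in order, mirroring flavor.translate.
func translate(query string) string {
    for _, r := range mysqlReplacers {
        query = r.re.ReplaceAllString(query, r.with)
    }
    return query
}

func main() {
    q := `insert into keys (id, signing_key, next_rotation) values ($1, $2, $3);`
    fmt.Println(translate(q))
    // Output (roughly): insert into `keys` (id, signing_key, next_rotation) values (?, ?, ?);
}
```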