storage/sql: initial MySQL storage implementation

It will be shared by both Postgres and MySQL configs.

Signed-off-by: Pavel Borzenkov <pavel.borzenkov@gmail.com>
This commit is contained in:
Pavel Borzenkov
2017-04-21 18:51:55 +03:00
committed by Nandor Kracser
parent 92920c86ea
commit e53bdfabb9
6 changed files with 292 additions and 97 deletions

View File

@@ -1,14 +1,18 @@
package sql
import (
"crypto/tls"
"crypto/x509"
"database/sql"
"fmt"
"net"
"regexp"
"io/ioutil"
"net/url"
"strconv"
"strings"
"time"
"github.com/Sirupsen/logrus"
"github.com/coreos/dex/storage"
"github.com/go-sql-driver/mysql"
"github.com/lib/pq"
sqlite3 "github.com/mattn/go-sqlite3"
@@ -21,6 +25,12 @@ const (
pgErrUniqueViolation = "23505" // unique_violation
)
const (
// MySQL error codes
mysqlErrDupEntry = 1062
mysqlErrDupEntryWithKeyName = 1586
)
// SQLite3 options for creating an SQL db.
type SQLite3 struct {
// File to
@@ -63,31 +73,29 @@ func (s *SQLite3) open(logger log.Logger) (*conn, error) {
}
const (
sslDisable = "disable"
sslRequire = "require"
sslVerifyCA = "verify-ca"
sslVerifyFull = "verify-full"
// postgres SSL modes
pgSSLDisable = "disable"
pgSSLRequire = "require"
pgSSLVerifyCA = "verify-ca"
pgSSLVerifyFull = "verify-full"
)
// PostgresSSL represents SSL options for Postgres databases.
type PostgresSSL struct {
Mode string
CAFile string
// Files for client auth.
KeyFile string
CertFile string
}
const (
// MySQL SSL modes
mysqlSSLTrue = "true"
mysqlSSLFalse = "false"
mysqlSSLSkipVerify = "skip-verify"
mysqlSSLCustom = "custom"
)
// Postgres options for creating an SQL db.
type Postgres struct {
// NetworkDB contains options common to SQL databases accessed over network.
type NetworkDB struct {
Database string
User string
Password string
Host string
Port uint16
SSL PostgresSSL `json:"ssl" yaml:"ssl"`
ConnectionTimeout int // Seconds
// database/sql tunables, see
@@ -98,6 +106,22 @@ type Postgres struct {
ConnMaxLifetime int // Seconds, default: not set
}
// SSL represents SSL options for network databases.
type SSL struct {
Mode string
CAFile string
// Files for client auth.
KeyFile string
CertFile string
}
// Postgres options for creating an SQL db.
type Postgres struct {
NetworkDB
SSL SSL `json:"ssl" yaml:"ssl"`
}
// Open creates a new storage implementation backed by Postgres.
func (p *Postgres) Open(logger log.Logger) (storage.Storage, error) {
conn, err := p.open(logger, p.createDataSourceName())
@@ -216,3 +240,105 @@ func (p *Postgres) open(logger log.Logger, dataSourceName string) (*conn, error)
}
return c, nil
}
// MySQL options for creating a MySQL db.
type MySQL struct {
NetworkDB
SSL SSL `json:"ssl" yaml:"ssl"`
// TODO(pborzenkov): used by tests to reduce lock wait timeout. Should
// we make it exported and allow users to provide arbitrary params?
params map[string]string
}
// Open creates a new storage implementation backed by MySQL.
func (s *MySQL) Open(logger logrus.FieldLogger) (storage.Storage, error) {
conn, err := s.open(logger)
if err != nil {
return nil, err
}
return conn, nil
}
func (s *MySQL) open(logger logrus.FieldLogger) (*conn, error) {
cfg := mysql.Config{
User: s.User,
Passwd: s.Password,
DBName: s.Database,
Timeout: time.Second * time.Duration(s.ConnectionTimeout),
ParseTime: true,
Params: map[string]string{
"tx_isolation": "'SERIALIZABLE'",
},
}
if s.Host != "" {
if s.Host[0] != '/' {
cfg.Net = "tcp"
cfg.Addr = s.Host
} else {
cfg.Net = "unix"
cfg.Addr = s.Host
}
}
if s.SSL.CAFile != "" || s.SSL.CertFile != "" || s.SSL.KeyFile != "" {
if err := s.makeTLSConfig(); err != nil {
return nil, fmt.Errorf("failed to make TLS config: %v", err)
}
cfg.TLSConfig = mysqlSSLCustom
} else {
cfg.TLSConfig = s.SSL.Mode
}
for k, v := range s.params {
cfg.Params[k] = v
}
db, err := sql.Open("mysql", cfg.FormatDSN())
if err != nil {
return nil, err
}
errCheck := func(err error) bool {
sqlErr, ok := err.(*mysql.MySQLError)
if !ok {
return false
}
return sqlErr.Number == mysqlErrDupEntry ||
sqlErr.Number == mysqlErrDupEntryWithKeyName
}
c := &conn{db, flavorMySQL, logger, errCheck}
if _, err := c.migrate(); err != nil {
return nil, fmt.Errorf("failed to perform migrations: %v", err)
}
return c, nil
}
func (s *MySQL) makeTLSConfig() error {
cfg := &tls.Config{}
if s.SSL.CAFile != "" {
rootCertPool := x509.NewCertPool()
pem, err := ioutil.ReadFile(s.SSL.CAFile)
if err != nil {
return err
}
if ok := rootCertPool.AppendCertsFromPEM(pem); !ok {
return fmt.Errorf("failed to append PEM")
}
cfg.RootCAs = rootCertPool
}
if s.SSL.CertFile != "" && s.SSL.KeyFile != "" {
clientCert := make([]tls.Certificate, 0, 1)
certs, err := tls.LoadX509KeyPair(s.SSL.CertFile, s.SSL.KeyFile)
if err != nil {
return err
}
clientCert = append(clientCert, certs)
cfg.Certificates = clientCert
}
mysql.RegisterTLSConfig(mysqlSSLCustom, cfg)
return nil
}

View File

@@ -32,15 +32,16 @@ func withTimeout(t time.Duration, f func()) {
}
func cleanDB(c *conn) error {
_, err := c.Exec(`
delete from client;
delete from auth_request;
delete from auth_code;
delete from refresh_token;
delete from keys;
delete from password;
`)
return err
tables := []string{"client", "auth_request", "auth_code",
"refresh_token", "keys", "password"}
for _, tbl := range tables {
_, err := c.Exec("delete from " + tbl)
if err != nil {
return err
}
}
return nil
}
var logger = &logrus.Logger{
@@ -49,23 +50,39 @@ var logger = &logrus.Logger{
Level: logrus.DebugLevel,
}
func TestSQLite3(t *testing.T) {
type opener interface {
open(logrus.FieldLogger) (*conn, error)
}
func testDB(t *testing.T, o opener, withTransactions bool) {
// t.Fatal has a bad habbit of not actually printing the error
fatal := func(i interface{}) {
fmt.Fprintln(os.Stdout, i)
t.Fatal(i)
}
newStorage := func() storage.Storage {
// NOTE(ericchiang): In memory means we only get one connection at a time. If we
// ever write tests that require using multiple connections, for instance to test
// transactions, we need to move to a file based system.
s := &SQLite3{":memory:"}
conn, err := s.open(logger)
conn, err := o.open(logger)
if err != nil {
fmt.Fprintln(os.Stdout, err)
t.Fatal(err)
fatal(err)
}
if err := cleanDB(conn); err != nil {
fatal(err)
}
return conn
}
withTimeout(time.Second*10, func() {
withTimeout(time.Minute*1, func() {
conformance.RunTests(t, newStorage)
})
if withTransactions {
withTimeout(time.Minute*1, func() {
conformance.RunTransactionTests(t, newStorage)
})
}
}
func TestSQLite3(t *testing.T) {
testDB(t, &SQLite3{":memory:"}, false)
}
func getenv(key, defaultVal string) string {
@@ -186,37 +203,42 @@ func TestPostgres(t *testing.T) {
if host == "" {
t.Skipf("test environment variable %q not set, skipping", testPostgresEnv)
}
p := Postgres{
Database: getenv("DEX_POSTGRES_DATABASE", "postgres"),
User: getenv("DEX_POSTGRES_USER", "postgres"),
Password: getenv("DEX_POSTGRES_PASSWORD", "postgres"),
Host: host,
SSL: PostgresSSL{
Mode: sslDisable, // Postgres container doesn't support SSL.
p := &Postgres{
NetworkDB: NetworkDB{
Database: getenv("DEX_POSTGRES_DATABASE", "postgres"),
User: getenv("DEX_POSTGRES_USER", "postgres"),
Password: getenv("DEX_POSTGRES_PASSWORD", "postgres"),
Host: host,
ConnectionTimeout: 5,
},
SSL: SSL{
Mode: pgSSLDisable, // Postgres container doesn't support SSL.
},
ConnectionTimeout: 5,
}
// t.Fatal has a bad habbit of not actually printing the error
fatal := func(i interface{}) {
fmt.Fprintln(os.Stdout, i)
t.Fatal(i)
}
newStorage := func() storage.Storage {
conn, err := p.open(logger, p.createDataSourceName())
if err != nil {
fatal(err)
}
if err := cleanDB(conn); err != nil {
fatal(err)
}
return conn
}
withTimeout(time.Minute*1, func() {
conformance.RunTests(t, newStorage)
})
withTimeout(time.Minute*1, func() {
conformance.RunTransactionTests(t, newStorage)
})
testDB(t, p, true)
}
const testMySQLEnv = "DEX_MYSQL_HOST"
func TestMySQL(t *testing.T) {
host := os.Getenv(testMySQLEnv)
if host == "" {
t.Skipf("test environment variable %q not set, skipping", testMySQLEnv)
}
s := &MySQL{
NetworkDB: NetworkDB{
Database: getenv("DEX_MYSQL_DATABASE", "mysql"),
User: getenv("DEX_MYSQL_USER", "mysql"),
Password: getenv("DEX_MYSQL_PASSWORD", ""),
Host: host,
ConnectionTimeout: 5,
},
SSL: SSL{
Mode: mysqlSSLFalse,
},
params: map[string]string{
"innodb_lock_wait_timeout": "3",
},
}
testDB(t, s, true)
}

View File

@@ -38,8 +38,10 @@ func (c *conn) migrate() (int, error) {
migrationNum := n + 1
m := migrations[n]
if _, err := tx.Exec(m.stmt); err != nil {
return fmt.Errorf("migration %d failed: %v", migrationNum, err)
for i := range m.stmts {
if _, err := tx.Exec(m.stmts[i]); err != nil {
return fmt.Errorf("migration %d statement %d failed: %v", migrationNum, i+1, err)
}
}
q := `insert into migrations (num, at) values ($1, now());`
@@ -61,14 +63,14 @@ func (c *conn) migrate() (int, error) {
}
type migration struct {
stmt string
stmts []string
// TODO(ericchiang): consider adding additional fields like "forDrivers"
}
// All SQL flavors share migration strategies.
var migrations = []migration{
{
stmt: `
stmts: []string{`
create table client (
id text not null primary key,
secret text not null,
@@ -77,8 +79,8 @@ var migrations = []migration{
public boolean not null,
name text not null,
logo_url text not null
);
);`,
`
create table auth_request (
id text not null primary key,
client_id text not null,
@@ -101,8 +103,8 @@ var migrations = []migration{
connector_data bytea,
expiry timestamptz not null
);
);`,
`
create table auth_code (
id text not null primary key,
client_id text not null,
@@ -120,8 +122,8 @@ var migrations = []migration{
connector_data bytea,
expiry timestamptz not null
);
);`,
`
create table refresh_token (
id text not null primary key,
client_id text not null,
@@ -136,15 +138,15 @@ var migrations = []migration{
connector_id text not null,
connector_data bytea
);
);`,
`
create table password (
email text not null primary key,
hash bytea not null,
username text not null,
user_id text not null
);
);`,
`
-- keys is a weird table because we only ever expect there to be a single row
create table keys (
id text not null primary key,
@@ -152,39 +154,40 @@ var migrations = []migration{
signing_key bytea not null, -- JSON object
signing_key_pub bytea not null, -- JSON object
next_rotation timestamptz not null
);
`,
);`,
},
},
{
stmt: `
stmts: []string{`
alter table refresh_token
add column token text not null default '';
add column token text not null default '';`,
`
alter table refresh_token
add column created_at timestamptz not null default '0001-01-01 00:00:00 UTC';
add column created_at timestamptz not null default '0001-01-01 00:00:00 UTC';`,
`
alter table refresh_token
add column last_used timestamptz not null default '0001-01-01 00:00:00 UTC';
`,
add column last_used timestamptz not null default '0001-01-01 00:00:00 UTC';`,
},
},
{
stmt: `
stmts: []string{`
create table offline_session (
user_id text not null,
conn_id text not null,
refresh bytea not null,
PRIMARY KEY (user_id, conn_id)
);
`,
);`,
},
},
{
stmt: `
stmts: []string{`
create table connector (
id text not null primary key,
type text not null,
name text not null,
resource_version text not null,
config bytea
);
`,
);`,
},
},
}

View File

@@ -83,6 +83,28 @@ var (
{regexp.MustCompile(`\bnow\(\)`), "date('now')"},
},
}
flavorMySQL = flavor{
queryReplacers: []replacer{
{bindRegexp, "?"},
// Translate types.
{matchLiteral("bytea"), "blob"},
{matchLiteral("timestamptz"), "datetime(3)"},
// MySQL doesn't support indicies on text fields w/o
// specifying key length. Use varchar instead (768 is
// the max key length for InnoDB with 4k pages).
{matchLiteral("text"), "varchar(768)"},
// Quote keywords and reserved words used as identifiers.
{regexp.MustCompile(`\b(keys)\b`), "`$1`"},
// Change default timestamp to fit datetime.
{regexp.MustCompile(`0001-01-01 00:00:00 UTC`), "1000-01-01 00:00:00"},
},
}
// Not tested.
flavorCockroach = flavor{
executeTx: crdb.ExecuteTx,
}
)
func (f flavor) translate(query string) string {