storage/sql: rework of the original MySQL PR

This commit is contained in:
Nandor Kracser
2019-07-12 16:29:46 +02:00
parent e53bdfabb9
commit a572ad8fec
40 changed files with 6983 additions and 74 deletions

View File

@@ -6,12 +6,12 @@ import (
"database/sql"
"fmt"
"io/ioutil"
"net/url"
"net"
"regexp"
"strconv"
"strings"
"time"
"github.com/Sirupsen/logrus"
"github.com/coreos/dex/storage"
"github.com/go-sql-driver/mysql"
"github.com/lib/pq"
sqlite3 "github.com/mattn/go-sqlite3"
@@ -124,7 +124,7 @@ type Postgres struct {
// Open creates a new storage implementation backed by Postgres.
func (p *Postgres) Open(logger log.Logger) (storage.Storage, error) {
conn, err := p.open(logger, p.createDataSourceName())
conn, err := p.open(logger)
if err != nil {
return nil, err
}
@@ -183,7 +183,7 @@ func (p *Postgres) createDataSourceName() string {
if p.SSL.Mode == "" {
// Assume the strictest mode if unspecified.
addParam("sslmode", dataSourceStr(sslVerifyFull))
addParam("sslmode", dataSourceStr(pgSSLVerifyFull))
} else {
addParam("sslmode", dataSourceStr(p.SSL.Mode))
}
@@ -203,7 +203,9 @@ func (p *Postgres) createDataSourceName() string {
return strings.Join(parameters, " ")
}
func (p *Postgres) open(logger log.Logger, dataSourceName string) (*conn, error) {
func (p *Postgres) open(logger log.Logger) (*conn, error) {
dataSourceName := p.createDataSourceName()
db, err := sql.Open("postgres", dataSourceName)
if err != nil {
return nil, err
@@ -253,7 +255,7 @@ type MySQL struct {
}
// Open creates a new storage implementation backed by MySQL.
func (s *MySQL) Open(logger logrus.FieldLogger) (storage.Storage, error) {
func (s *MySQL) Open(logger log.Logger) (storage.Storage, error) {
conn, err := s.open(logger)
if err != nil {
return nil, err
@@ -261,17 +263,18 @@ func (s *MySQL) Open(logger logrus.FieldLogger) (storage.Storage, error) {
return conn, nil
}
func (s *MySQL) open(logger logrus.FieldLogger) (*conn, error) {
func (s *MySQL) open(logger log.Logger) (*conn, error) {
cfg := mysql.Config{
User: s.User,
Passwd: s.Password,
DBName: s.Database,
User: s.User,
Passwd: s.Password,
DBName: s.Database,
AllowNativePasswords: true,
Timeout: time.Second * time.Duration(s.ConnectionTimeout),
ParseTime: true,
Params: map[string]string{
"tx_isolation": "'SERIALIZABLE'",
"transaction_isolation": "'SERIALIZABLE'",
},
}
if s.Host != "" {
@@ -288,6 +291,8 @@ func (s *MySQL) open(logger logrus.FieldLogger) (*conn, error) {
return nil, fmt.Errorf("failed to make TLS config: %v", err)
}
cfg.TLSConfig = mysqlSSLCustom
} else if s.SSL.Mode == "" {
cfg.TLSConfig = mysqlSSLTrue
} else {
cfg.TLSConfig = s.SSL.Mode
}

View File

@@ -9,6 +9,7 @@ import (
"github.com/sirupsen/logrus"
"github.com/dexidp/dex/pkg/log"
"github.com/dexidp/dex/storage"
"github.com/dexidp/dex/storage/conformance"
)
@@ -51,7 +52,7 @@ var logger = &logrus.Logger{
}
type opener interface {
open(logrus.FieldLogger) (*conn, error)
open(logger log.Logger) (*conn, error)
}
func testDB(t *testing.T, o opener, withTransactions bool) {
@@ -108,19 +109,23 @@ func TestCreateDataSourceName(t *testing.T) {
{
description: "with typical configuration",
input: &Postgres{
Host: "1.2.3.4",
Port: 6543,
User: "some-user",
Password: "some-password",
Database: "some-db",
NetworkDB: NetworkDB{
Host: "1.2.3.4",
Port: 6543,
User: "some-user",
Password: "some-password",
Database: "some-db",
},
},
expected: "connect_timeout=0 host='1.2.3.4' port=6543 user='some-user' password='some-password' dbname='some-db' sslmode='verify-full'",
},
{
description: "with unix socket host",
input: &Postgres{
Host: "/var/run/postgres",
SSL: PostgresSSL{
NetworkDB: NetworkDB{
Host: "/var/run/postgres",
},
SSL: SSL{
Mode: "disable",
},
},
@@ -129,8 +134,10 @@ func TestCreateDataSourceName(t *testing.T) {
{
description: "with tcp host",
input: &Postgres{
Host: "coreos.com",
SSL: PostgresSSL{
NetworkDB: NetworkDB{
Host: "coreos.com",
},
SSL: SSL{
Mode: "disable",
},
},
@@ -139,23 +146,29 @@ func TestCreateDataSourceName(t *testing.T) {
{
description: "with tcp host:port",
input: &Postgres{
Host: "coreos.com:6543",
NetworkDB: NetworkDB{
Host: "coreos.com:6543",
},
},
expected: "connect_timeout=0 host='coreos.com' port=6543 sslmode='verify-full'",
},
{
description: "with tcp host and port",
input: &Postgres{
Host: "coreos.com",
Port: 6543,
NetworkDB: NetworkDB{
Host: "coreos.com",
Port: 6543,
},
},
expected: "connect_timeout=0 host='coreos.com' port=6543 sslmode='verify-full'",
},
{
description: "with ssl ca cert",
input: &Postgres{
Host: "coreos.com",
SSL: PostgresSSL{
NetworkDB: NetworkDB{
Host: "coreos.com",
},
SSL: SSL{
Mode: "verify-ca",
CAFile: "/some/file/path",
},
@@ -165,8 +178,10 @@ func TestCreateDataSourceName(t *testing.T) {
{
description: "with ssl client cert",
input: &Postgres{
Host: "coreos.com",
SSL: PostgresSSL{
NetworkDB: NetworkDB{
Host: "coreos.com",
},
SSL: SSL{
Mode: "verify-ca",
CAFile: "/some/ca/path",
CertFile: "/some/cert/path",
@@ -178,9 +193,11 @@ func TestCreateDataSourceName(t *testing.T) {
{
description: "with funny characters in credentials",
input: &Postgres{
Host: "coreos.com",
User: `some'user\slashed`,
Password: "some'password!",
NetworkDB: NetworkDB{
Host: "coreos.com",
User: `some'user\slashed`,
Password: "some'password!",
},
},
expected: `connect_timeout=0 host='coreos.com' user='some\'user\\slashed' password='some\'password!' sslmode='verify-full'`,
},

View File

@@ -13,17 +13,19 @@ func TestPostgresTunables(t *testing.T) {
t.Skipf("test environment variable %q not set, skipping", testPostgresEnv)
}
baseCfg := &Postgres{
Database: getenv("DEX_POSTGRES_DATABASE", "postgres"),
User: getenv("DEX_POSTGRES_USER", "postgres"),
Password: getenv("DEX_POSTGRES_PASSWORD", "postgres"),
Host: host,
SSL: PostgresSSL{
Mode: sslDisable, // Postgres container doesn't support SSL.
NetworkDB: NetworkDB{
Database: getenv("DEX_POSTGRES_DATABASE", "postgres"),
User: getenv("DEX_POSTGRES_USER", "postgres"),
Password: getenv("DEX_POSTGRES_PASSWORD", "postgres"),
Host: host,
},
SSL: SSL{
Mode: pgSSLDisable, // Postgres container doesn't support SSL.
}}
t.Run("with nothing set, uses defaults", func(t *testing.T) {
cfg := *baseCfg
c, err := cfg.open(logger, cfg.createDataSourceName())
c, err := cfg.open(logger)
if err != nil {
t.Fatalf("error opening connector: %s", err.Error())
}
@@ -36,7 +38,7 @@ func TestPostgresTunables(t *testing.T) {
t.Run("with something set, uses that", func(t *testing.T) {
cfg := *baseCfg
cfg.MaxOpenConns = 101
c, err := cfg.open(logger, cfg.createDataSourceName())
c, err := cfg.open(logger)
if err != nil {
t.Fatalf("error opening connector: %s", err.Error())
}

View File

@@ -91,20 +91,16 @@ var (
{matchLiteral("bytea"), "blob"},
{matchLiteral("timestamptz"), "datetime(3)"},
// MySQL doesn't support indicies on text fields w/o
// specifying key length. Use varchar instead (768 is
// the max key length for InnoDB with 4k pages).
{matchLiteral("text"), "varchar(768)"},
// specifying key length. Use varchar instead (767 byte
// is the max key length for InnoDB with 4k pages).
// For compound indexes (with two keys) even less.
{matchLiteral("text"), "varchar(384)"},
// Quote keywords and reserved words used as identifiers.
{regexp.MustCompile(`\b(keys)\b`), "`$1`"},
// Change default timestamp to fit datetime.
{regexp.MustCompile(`0001-01-01 00:00:00 UTC`), "1000-01-01 00:00:00"},
},
}
// Not tested.
flavorCockroach = flavor{
executeTx: crdb.ExecuteTx,
}
)
func (f flavor) translate(query string) string {