storage/sql: rework of the original MySQL PR
This commit is contained in:
@@ -6,12 +6,12 @@ import (
|
||||
"database/sql"
|
||||
"fmt"
|
||||
"io/ioutil"
|
||||
"net/url"
|
||||
"net"
|
||||
"regexp"
|
||||
"strconv"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/Sirupsen/logrus"
|
||||
"github.com/coreos/dex/storage"
|
||||
"github.com/go-sql-driver/mysql"
|
||||
"github.com/lib/pq"
|
||||
sqlite3 "github.com/mattn/go-sqlite3"
|
||||
@@ -124,7 +124,7 @@ type Postgres struct {
|
||||
|
||||
// Open creates a new storage implementation backed by Postgres.
|
||||
func (p *Postgres) Open(logger log.Logger) (storage.Storage, error) {
|
||||
conn, err := p.open(logger, p.createDataSourceName())
|
||||
conn, err := p.open(logger)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -183,7 +183,7 @@ func (p *Postgres) createDataSourceName() string {
|
||||
|
||||
if p.SSL.Mode == "" {
|
||||
// Assume the strictest mode if unspecified.
|
||||
addParam("sslmode", dataSourceStr(sslVerifyFull))
|
||||
addParam("sslmode", dataSourceStr(pgSSLVerifyFull))
|
||||
} else {
|
||||
addParam("sslmode", dataSourceStr(p.SSL.Mode))
|
||||
}
|
||||
@@ -203,7 +203,9 @@ func (p *Postgres) createDataSourceName() string {
|
||||
return strings.Join(parameters, " ")
|
||||
}
|
||||
|
||||
func (p *Postgres) open(logger log.Logger, dataSourceName string) (*conn, error) {
|
||||
func (p *Postgres) open(logger log.Logger) (*conn, error) {
|
||||
dataSourceName := p.createDataSourceName()
|
||||
|
||||
db, err := sql.Open("postgres", dataSourceName)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
@@ -253,7 +255,7 @@ type MySQL struct {
|
||||
}
|
||||
|
||||
// Open creates a new storage implementation backed by MySQL.
|
||||
func (s *MySQL) Open(logger logrus.FieldLogger) (storage.Storage, error) {
|
||||
func (s *MySQL) Open(logger log.Logger) (storage.Storage, error) {
|
||||
conn, err := s.open(logger)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
@@ -261,17 +263,18 @@ func (s *MySQL) Open(logger logrus.FieldLogger) (storage.Storage, error) {
|
||||
return conn, nil
|
||||
}
|
||||
|
||||
func (s *MySQL) open(logger logrus.FieldLogger) (*conn, error) {
|
||||
func (s *MySQL) open(logger log.Logger) (*conn, error) {
|
||||
cfg := mysql.Config{
|
||||
User: s.User,
|
||||
Passwd: s.Password,
|
||||
DBName: s.Database,
|
||||
User: s.User,
|
||||
Passwd: s.Password,
|
||||
DBName: s.Database,
|
||||
AllowNativePasswords: true,
|
||||
|
||||
Timeout: time.Second * time.Duration(s.ConnectionTimeout),
|
||||
|
||||
ParseTime: true,
|
||||
Params: map[string]string{
|
||||
"tx_isolation": "'SERIALIZABLE'",
|
||||
"transaction_isolation": "'SERIALIZABLE'",
|
||||
},
|
||||
}
|
||||
if s.Host != "" {
|
||||
@@ -288,6 +291,8 @@ func (s *MySQL) open(logger logrus.FieldLogger) (*conn, error) {
|
||||
return nil, fmt.Errorf("failed to make TLS config: %v", err)
|
||||
}
|
||||
cfg.TLSConfig = mysqlSSLCustom
|
||||
} else if s.SSL.Mode == "" {
|
||||
cfg.TLSConfig = mysqlSSLTrue
|
||||
} else {
|
||||
cfg.TLSConfig = s.SSL.Mode
|
||||
}
|
||||
|
@@ -9,6 +9,7 @@ import (
|
||||
|
||||
"github.com/sirupsen/logrus"
|
||||
|
||||
"github.com/dexidp/dex/pkg/log"
|
||||
"github.com/dexidp/dex/storage"
|
||||
"github.com/dexidp/dex/storage/conformance"
|
||||
)
|
||||
@@ -51,7 +52,7 @@ var logger = &logrus.Logger{
|
||||
}
|
||||
|
||||
type opener interface {
|
||||
open(logrus.FieldLogger) (*conn, error)
|
||||
open(logger log.Logger) (*conn, error)
|
||||
}
|
||||
|
||||
func testDB(t *testing.T, o opener, withTransactions bool) {
|
||||
@@ -108,19 +109,23 @@ func TestCreateDataSourceName(t *testing.T) {
|
||||
{
|
||||
description: "with typical configuration",
|
||||
input: &Postgres{
|
||||
Host: "1.2.3.4",
|
||||
Port: 6543,
|
||||
User: "some-user",
|
||||
Password: "some-password",
|
||||
Database: "some-db",
|
||||
NetworkDB: NetworkDB{
|
||||
Host: "1.2.3.4",
|
||||
Port: 6543,
|
||||
User: "some-user",
|
||||
Password: "some-password",
|
||||
Database: "some-db",
|
||||
},
|
||||
},
|
||||
expected: "connect_timeout=0 host='1.2.3.4' port=6543 user='some-user' password='some-password' dbname='some-db' sslmode='verify-full'",
|
||||
},
|
||||
{
|
||||
description: "with unix socket host",
|
||||
input: &Postgres{
|
||||
Host: "/var/run/postgres",
|
||||
SSL: PostgresSSL{
|
||||
NetworkDB: NetworkDB{
|
||||
Host: "/var/run/postgres",
|
||||
},
|
||||
SSL: SSL{
|
||||
Mode: "disable",
|
||||
},
|
||||
},
|
||||
@@ -129,8 +134,10 @@ func TestCreateDataSourceName(t *testing.T) {
|
||||
{
|
||||
description: "with tcp host",
|
||||
input: &Postgres{
|
||||
Host: "coreos.com",
|
||||
SSL: PostgresSSL{
|
||||
NetworkDB: NetworkDB{
|
||||
Host: "coreos.com",
|
||||
},
|
||||
SSL: SSL{
|
||||
Mode: "disable",
|
||||
},
|
||||
},
|
||||
@@ -139,23 +146,29 @@ func TestCreateDataSourceName(t *testing.T) {
|
||||
{
|
||||
description: "with tcp host:port",
|
||||
input: &Postgres{
|
||||
Host: "coreos.com:6543",
|
||||
NetworkDB: NetworkDB{
|
||||
Host: "coreos.com:6543",
|
||||
},
|
||||
},
|
||||
expected: "connect_timeout=0 host='coreos.com' port=6543 sslmode='verify-full'",
|
||||
},
|
||||
{
|
||||
description: "with tcp host and port",
|
||||
input: &Postgres{
|
||||
Host: "coreos.com",
|
||||
Port: 6543,
|
||||
NetworkDB: NetworkDB{
|
||||
Host: "coreos.com",
|
||||
Port: 6543,
|
||||
},
|
||||
},
|
||||
expected: "connect_timeout=0 host='coreos.com' port=6543 sslmode='verify-full'",
|
||||
},
|
||||
{
|
||||
description: "with ssl ca cert",
|
||||
input: &Postgres{
|
||||
Host: "coreos.com",
|
||||
SSL: PostgresSSL{
|
||||
NetworkDB: NetworkDB{
|
||||
Host: "coreos.com",
|
||||
},
|
||||
SSL: SSL{
|
||||
Mode: "verify-ca",
|
||||
CAFile: "/some/file/path",
|
||||
},
|
||||
@@ -165,8 +178,10 @@ func TestCreateDataSourceName(t *testing.T) {
|
||||
{
|
||||
description: "with ssl client cert",
|
||||
input: &Postgres{
|
||||
Host: "coreos.com",
|
||||
SSL: PostgresSSL{
|
||||
NetworkDB: NetworkDB{
|
||||
Host: "coreos.com",
|
||||
},
|
||||
SSL: SSL{
|
||||
Mode: "verify-ca",
|
||||
CAFile: "/some/ca/path",
|
||||
CertFile: "/some/cert/path",
|
||||
@@ -178,9 +193,11 @@ func TestCreateDataSourceName(t *testing.T) {
|
||||
{
|
||||
description: "with funny characters in credentials",
|
||||
input: &Postgres{
|
||||
Host: "coreos.com",
|
||||
User: `some'user\slashed`,
|
||||
Password: "some'password!",
|
||||
NetworkDB: NetworkDB{
|
||||
Host: "coreos.com",
|
||||
User: `some'user\slashed`,
|
||||
Password: "some'password!",
|
||||
},
|
||||
},
|
||||
expected: `connect_timeout=0 host='coreos.com' user='some\'user\\slashed' password='some\'password!' sslmode='verify-full'`,
|
||||
},
|
||||
|
@@ -13,17 +13,19 @@ func TestPostgresTunables(t *testing.T) {
|
||||
t.Skipf("test environment variable %q not set, skipping", testPostgresEnv)
|
||||
}
|
||||
baseCfg := &Postgres{
|
||||
Database: getenv("DEX_POSTGRES_DATABASE", "postgres"),
|
||||
User: getenv("DEX_POSTGRES_USER", "postgres"),
|
||||
Password: getenv("DEX_POSTGRES_PASSWORD", "postgres"),
|
||||
Host: host,
|
||||
SSL: PostgresSSL{
|
||||
Mode: sslDisable, // Postgres container doesn't support SSL.
|
||||
NetworkDB: NetworkDB{
|
||||
Database: getenv("DEX_POSTGRES_DATABASE", "postgres"),
|
||||
User: getenv("DEX_POSTGRES_USER", "postgres"),
|
||||
Password: getenv("DEX_POSTGRES_PASSWORD", "postgres"),
|
||||
Host: host,
|
||||
},
|
||||
SSL: SSL{
|
||||
Mode: pgSSLDisable, // Postgres container doesn't support SSL.
|
||||
}}
|
||||
|
||||
t.Run("with nothing set, uses defaults", func(t *testing.T) {
|
||||
cfg := *baseCfg
|
||||
c, err := cfg.open(logger, cfg.createDataSourceName())
|
||||
c, err := cfg.open(logger)
|
||||
if err != nil {
|
||||
t.Fatalf("error opening connector: %s", err.Error())
|
||||
}
|
||||
@@ -36,7 +38,7 @@ func TestPostgresTunables(t *testing.T) {
|
||||
t.Run("with something set, uses that", func(t *testing.T) {
|
||||
cfg := *baseCfg
|
||||
cfg.MaxOpenConns = 101
|
||||
c, err := cfg.open(logger, cfg.createDataSourceName())
|
||||
c, err := cfg.open(logger)
|
||||
if err != nil {
|
||||
t.Fatalf("error opening connector: %s", err.Error())
|
||||
}
|
||||
|
@@ -91,20 +91,16 @@ var (
|
||||
{matchLiteral("bytea"), "blob"},
|
||||
{matchLiteral("timestamptz"), "datetime(3)"},
|
||||
// MySQL doesn't support indicies on text fields w/o
|
||||
// specifying key length. Use varchar instead (768 is
|
||||
// the max key length for InnoDB with 4k pages).
|
||||
{matchLiteral("text"), "varchar(768)"},
|
||||
// specifying key length. Use varchar instead (767 byte
|
||||
// is the max key length for InnoDB with 4k pages).
|
||||
// For compound indexes (with two keys) even less.
|
||||
{matchLiteral("text"), "varchar(384)"},
|
||||
// Quote keywords and reserved words used as identifiers.
|
||||
{regexp.MustCompile(`\b(keys)\b`), "`$1`"},
|
||||
// Change default timestamp to fit datetime.
|
||||
{regexp.MustCompile(`0001-01-01 00:00:00 UTC`), "1000-01-01 00:00:00"},
|
||||
},
|
||||
}
|
||||
|
||||
// Not tested.
|
||||
flavorCockroach = flavor{
|
||||
executeTx: crdb.ExecuteTx,
|
||||
}
|
||||
)
|
||||
|
||||
func (f flavor) translate(query string) string {
|
||||
|
Reference in New Issue
Block a user