package sql

import (
	"fmt"
	"os"
	"runtime"
	"strconv"
	"testing"
	"time"

	"github.com/sirupsen/logrus"

	"github.com/dexidp/dex/pkg/log"
	"github.com/dexidp/dex/storage"
	"github.com/dexidp/dex/storage/conformance"
)

// withTimeout runs f synchronously on the calling goroutine. If f has not
// returned within t, a watchdog goroutine dumps the stacks of all goroutines
// to stderr and panics, turning a hung test into a visible failure.
func withTimeout(t time.Duration, f func()) {
	c := make(chan struct{})
	// Closing c once f returns releases the watchdog goroutine.
	defer close(c)

	go func() {
		select {
		case <-c:
		case <-time.After(t):
			// Dump a stack trace of the program. Useful for debugging deadlocks.
			buf := make([]byte, 2<<20)
			fmt.Fprintf(os.Stderr, "%s\n", buf[:runtime.Stack(buf, true)])
			panic("test took too long")
		}
	}()

	f()
}

// cleanDB deletes every row from each of the storage tables listed below.
// It stops at the first failing delete and returns that error, annotated
// with the table name.
func cleanDB(c *conn) error {
	tables := []string{
		"client",
		"auth_request",
		"auth_code",
		"refresh_token",
		"keys",
		"password",
	}
	for _, tbl := range tables {
		// Table names come from the fixed list above, so concatenating them
		// into the statement is safe.
		if _, err := c.Exec("delete from " + tbl); err != nil {
			return fmt.Errorf("clean table %q: %w", tbl, err)
		}
	}
	return nil
}

// logger is passed to every opener in these tests: debug level, uncolored
// text output written to stderr.
var logger = &logrus.Logger{
	Out:       os.Stderr,
	Formatter: &logrus.TextFormatter{DisableColors: true},
	Level:     logrus.DebugLevel,
}

// opener is the seam testDB uses to obtain connections; in this file it is
// satisfied by *Postgres and *MySQL. open returns a live connection that
// logs through the given logger.
type opener interface {
	open(logger log.Logger) (*conn, error)
}

// testDB runs the storage conformance suite against the database reachable
// through o. Every storage handed to the suite is a freshly opened connection
// whose tables have been emptied by cleanDB. When withTransactions is true,
// the transaction conformance tests are run as well. Each suite is wrapped in
// withTimeout with a one-minute limit.
func testDB(t *testing.T, o opener, withTransactions bool) {
	// t.Fatal has a bad habit of not actually printing the error
	fatal := func(i interface{}) {
		fmt.Fprintln(os.Stdout, i)
		t.Fatal(i)
	}

	newStorage := func() storage.Storage {
		conn, err := o.open(logger)
		if err != nil {
			fatal(err)
		}
		if err := cleanDB(conn); err != nil {
			fatal(err)
		}
		return conn
	}

	withTimeout(time.Minute*1, func() {
		conformance.RunTests(t, newStorage)
	})
	if withTransactions {
		withTimeout(time.Minute*1, func() {
			conformance.RunTransactionTests(t, newStorage)
		})
	}
}

// getenv returns the value of the environment variable key, or defaultVal
// when the variable is unset or empty.
func getenv(key, defaultVal string) string {
	if val := os.Getenv(key); val != "" {
		return val
	}
	return defaultVal
}

// testPostgresEnv names the environment variable holding the Postgres host;
// TestPostgres is skipped when it is unset.
const testPostgresEnv = "DEX_POSTGRES_HOST"

// TestCreateDataSourceName checks the connection string produced by
// Postgres.createDataSourceName for a range of configurations, including
// host:port splitting, SSL file options and escaping of quotes and
// backslashes in credentials.
func TestCreateDataSourceName(t *testing.T) {
	testCases := []struct {
		description string
		input       *Postgres
		expected    string
	}{
		{
			description: "with no configuration",
			input:       &Postgres{},
			expected:    "connect_timeout=0 sslmode='verify-full'",
		},
		{
			description: "with typical configuration",
			input: &Postgres{
				NetworkDB: NetworkDB{
					Host:     "1.2.3.4",
					Port:     6543,
					User:     "some-user",
					Password: "some-password",
					Database: "some-db",
				},
			},
			expected: "connect_timeout=0 host='1.2.3.4' port=6543 user='some-user' password='some-password' dbname='some-db' sslmode='verify-full'",
		},
		{
			description: "with unix socket host",
			input: &Postgres{
				NetworkDB: NetworkDB{
					Host: "/var/run/postgres",
				},
				SSL: SSL{
					Mode: "disable",
				},
			},
			expected: "connect_timeout=0 host='/var/run/postgres' sslmode='disable'",
		},
		{
			description: "with tcp host",
			input: &Postgres{
				NetworkDB: NetworkDB{
					Host: "coreos.com",
				},
				SSL: SSL{
					Mode: "disable",
				},
			},
			expected: "connect_timeout=0 host='coreos.com' sslmode='disable'",
		},
		{
			description: "with tcp host:port",
			input: &Postgres{
				NetworkDB: NetworkDB{
					Host: "coreos.com:6543",
				},
			},
			expected: "connect_timeout=0 host='coreos.com' port=6543 sslmode='verify-full'",
		},
		{
			description: "with tcp host and port",
			input: &Postgres{
				NetworkDB: NetworkDB{
					Host: "coreos.com",
					Port: 6543,
				},
			},
			expected: "connect_timeout=0 host='coreos.com' port=6543 sslmode='verify-full'",
		},
		{
			description: "with ssl ca cert",
			input: &Postgres{
				NetworkDB: NetworkDB{
					Host: "coreos.com",
				},
				SSL: SSL{
					Mode:   "verify-ca",
					CAFile: "/some/file/path",
				},
			},
			expected: "connect_timeout=0 host='coreos.com' sslmode='verify-ca' sslrootcert='/some/file/path'",
		},
		{
			description: "with ssl client cert",
			input: &Postgres{
				NetworkDB: NetworkDB{
					Host: "coreos.com",
				},
				SSL: SSL{
					Mode:     "verify-ca",
					CAFile:   "/some/ca/path",
					CertFile: "/some/cert/path",
					KeyFile:  "/some/key/path",
				},
			},
			expected: "connect_timeout=0 host='coreos.com' sslmode='verify-ca' sslrootcert='/some/ca/path' sslcert='/some/cert/path' sslkey='/some/key/path'",
		},
		{
			description: "with funny characters in credentials",
			input: &Postgres{
				NetworkDB: NetworkDB{
					Host:     "coreos.com",
					User:     `some'user\slashed`,
					Password: "some'password!",
				},
			},
			expected: `connect_timeout=0 host='coreos.com' user='some\'user\\slashed' password='some\'password!' sslmode='verify-full'`,
		},
	}

	for _, testCase := range testCases {
		t.Run(testCase.description, func(t *testing.T) {
			// Keep the result local to the subtest instead of sharing one
			// variable across all iterations.
			actual := testCase.input.createDataSourceName()
			if actual != testCase.expected {
				t.Fatalf("got %q, want %q", actual, testCase.expected)
			}
		})
	}
}

// TestPostgres runs the conformance suite, including the transaction tests,
// against a Postgres server. It is skipped unless DEX_POSTGRES_HOST is set.
// DEX_POSTGRES_PORT (default 5432), DEX_POSTGRES_DATABASE, DEX_POSTGRES_USER
// and DEX_POSTGRES_PASSWORD (each defaulting to "postgres") override the
// remaining connection settings. SSL is disabled.
func TestPostgres(t *testing.T) {
	host := os.Getenv(testPostgresEnv)
	if host == "" {
		t.Skipf("test environment variable %q not set, skipping", testPostgresEnv)
	}

	port := uint64(5432)
	if rawPort := os.Getenv("DEX_POSTGRES_PORT"); rawPort != "" {
		var err error
		// Limit parsing to 16 bits: the port is stored as a uint16, and a wider
		// bit size would let out-of-range values silently wrap to another port.
		port, err = strconv.ParseUint(rawPort, 10, 16)
		if err != nil {
			t.Fatalf("invalid postgres port %q: %s", rawPort, err)
		}
	}

	p := &Postgres{
		NetworkDB: NetworkDB{
			Database:          getenv("DEX_POSTGRES_DATABASE", "postgres"),
			User:              getenv("DEX_POSTGRES_USER", "postgres"),
			Password:          getenv("DEX_POSTGRES_PASSWORD", "postgres"),
			Host:              host,
			Port:              uint16(port),
			ConnectionTimeout: 5,
		},
		SSL: SSL{
			Mode: pgSSLDisable, // Postgres container doesn't support SSL.
		},
	}
	testDB(t, p, true)
}

// testMySQLEnv names the environment variable holding the MySQL host;
// TestMySQL is skipped when it is unset.
const testMySQLEnv = "DEX_MYSQL_HOST"

// TestMySQL runs the conformance suite, including the transaction tests,
// against a MySQL server. It is skipped unless DEX_MYSQL_HOST is set.
// DEX_MYSQL_PORT (default 3306), DEX_MYSQL_DATABASE, DEX_MYSQL_USER and
// DEX_MYSQL_PASSWORD (each defaulting to "mysql") override the remaining
// connection settings. The connection also sets innodb_lock_wait_timeout to 3.
func TestMySQL(t *testing.T) {
	host := os.Getenv(testMySQLEnv)
	if host == "" {
		t.Skipf("test environment variable %q not set, skipping", testMySQLEnv)
	}

	port := uint64(3306)
	if rawPort := os.Getenv("DEX_MYSQL_PORT"); rawPort != "" {
		var err error
		// Limit parsing to 16 bits: the port is stored as a uint16, and a wider
		// bit size would let out-of-range values silently wrap to another port.
		port, err = strconv.ParseUint(rawPort, 10, 16)
		if err != nil {
			t.Fatalf("invalid mysql port %q: %s", rawPort, err)
		}
	}

	s := &MySQL{
		NetworkDB: NetworkDB{
			Database:          getenv("DEX_MYSQL_DATABASE", "mysql"),
			User:              getenv("DEX_MYSQL_USER", "mysql"),
			Password:          getenv("DEX_MYSQL_PASSWORD", "mysql"),
			Host:              host,
			Port:              uint16(port),
			ConnectionTimeout: 5,
		},
		SSL: SSL{
			Mode: mysqlSSLFalse,
		},
		params: map[string]string{
			"innodb_lock_wait_timeout": "3",
		},
	}
	testDB(t, s, true)
}