// Source snapshot: commit 49adc4e5bb (Go, 192 lines, 4.3 KiB).
// Signed-off-by: m.nabokikh <maksim.nabokikh@flant.com>

package ent
|
|
|
|
import (
|
|
"os"
|
|
"strconv"
|
|
"testing"
|
|
|
|
"github.com/sirupsen/logrus"
|
|
"github.com/stretchr/testify/require"
|
|
|
|
"github.com/dexidp/dex/storage"
|
|
"github.com/dexidp/dex/storage/conformance"
|
|
)
|
|
|
|
// Environment variables that point the ent Postgres integration tests at a
// live server. The integration tests are skipped when PostgresEntHostEnv is
// unset or empty.
const (
	// PostgresEntHostEnv names the server host; required for the live tests.
	PostgresEntHostEnv = "DEX_POSTGRES_ENT_HOST"
	// PostgresEntPortEnv optionally overrides the server port (default 5432).
	PostgresEntPortEnv = "DEX_POSTGRES_ENT_PORT"
	// PostgresEntDatabaseEnv optionally overrides the database name (default "postgres").
	PostgresEntDatabaseEnv = "DEX_POSTGRES_ENT_DATABASE"
	// PostgresEntUserEnv optionally overrides the login user (default "postgres").
	PostgresEntUserEnv = "DEX_POSTGRES_ENT_USER"
	// PostgresEntPasswordEnv optionally overrides the login password (default "postgres").
	PostgresEntPasswordEnv = "DEX_POSTGRES_ENT_PASSWORD"
)
|
|
|
|
// getenv returns the value of the environment variable key, or defaultVal
// when the variable is unset or set to the empty string.
func getenv(key, defaultVal string) string {
	value := os.Getenv(key)
	if value == "" {
		return defaultVal
	}
	return value
}
|
|
|
|
func postgresTestConfig(host string, port uint64) *Postgres {
|
|
return &Postgres{
|
|
NetworkDB: NetworkDB{
|
|
Database: getenv(PostgresEntDatabaseEnv, "postgres"),
|
|
User: getenv(PostgresEntUserEnv, "postgres"),
|
|
Password: getenv(PostgresEntPasswordEnv, "postgres"),
|
|
Host: host,
|
|
Port: uint16(port),
|
|
},
|
|
SSL: SSL{
|
|
Mode: pgSSLDisable, // Postgres container doesn't support SSL.
|
|
},
|
|
}
|
|
}
|
|
|
|
func newPostgresStorage(host string, port uint64) storage.Storage {
|
|
logger := &logrus.Logger{
|
|
Out: os.Stderr,
|
|
Formatter: &logrus.TextFormatter{DisableColors: true},
|
|
Level: logrus.DebugLevel,
|
|
}
|
|
|
|
cfg := postgresTestConfig(host, port)
|
|
s, err := cfg.Open(logger)
|
|
if err != nil {
|
|
panic(err)
|
|
}
|
|
return s
|
|
}
|
|
|
|
func TestPostgres(t *testing.T) {
|
|
host := os.Getenv(PostgresEntHostEnv)
|
|
if host == "" {
|
|
t.Skipf("test environment variable %s not set, skipping", PostgresEntHostEnv)
|
|
}
|
|
|
|
port := uint64(5432)
|
|
if rawPort := os.Getenv(PostgresEntPortEnv); rawPort != "" {
|
|
var err error
|
|
|
|
port, err = strconv.ParseUint(rawPort, 10, 32)
|
|
require.NoError(t, err, "invalid postgres port %q: %s", rawPort, err)
|
|
}
|
|
|
|
newStorage := func() storage.Storage {
|
|
return newPostgresStorage(host, port)
|
|
}
|
|
conformance.RunTests(t, newStorage)
|
|
conformance.RunTransactionTests(t, newStorage)
|
|
}
|
|
|
|
func TestPostgresDSN(t *testing.T) {
|
|
tests := []struct {
|
|
name string
|
|
cfg *Postgres
|
|
desiredDSN string
|
|
}{
|
|
{
|
|
name: "Host port",
|
|
cfg: &Postgres{
|
|
NetworkDB: NetworkDB{
|
|
Host: "localhost",
|
|
Port: uint16(5432),
|
|
},
|
|
},
|
|
desiredDSN: "connect_timeout=0 host='localhost' port=5432 sslmode='verify-full'",
|
|
},
|
|
{
|
|
name: "Host with port",
|
|
cfg: &Postgres{
|
|
NetworkDB: NetworkDB{
|
|
Host: "localhost:5432",
|
|
},
|
|
},
|
|
desiredDSN: "connect_timeout=0 host='localhost' port=5432 sslmode='verify-full'",
|
|
},
|
|
{
|
|
name: "Host ipv6 with port",
|
|
cfg: &Postgres{
|
|
NetworkDB: NetworkDB{
|
|
Host: "[a:b:c:d]:5432",
|
|
},
|
|
},
|
|
desiredDSN: "connect_timeout=0 host='a:b:c:d' port=5432 sslmode='verify-full'",
|
|
},
|
|
{
|
|
name: "Credentials and timeout",
|
|
cfg: &Postgres{
|
|
NetworkDB: NetworkDB{
|
|
Database: "test",
|
|
User: "test",
|
|
Password: "test",
|
|
ConnectionTimeout: 5,
|
|
},
|
|
},
|
|
desiredDSN: "connect_timeout=5 user='test' password='test' dbname='test' sslmode='verify-full'",
|
|
},
|
|
{
|
|
name: "SSL",
|
|
cfg: &Postgres{
|
|
SSL: SSL{
|
|
Mode: pgSSLRequire,
|
|
CAFile: "/ca.crt",
|
|
KeyFile: "/cert.crt",
|
|
CertFile: "/cert.key",
|
|
},
|
|
},
|
|
desiredDSN: "connect_timeout=0 sslmode='require' sslrootcert='/ca.crt' sslcert='/cert.key' sslkey='/cert.crt'",
|
|
},
|
|
}
|
|
|
|
for _, tt := range tests {
|
|
t.Run(tt.name, func(t *testing.T) {
|
|
require.Equal(t, tt.desiredDSN, tt.cfg.dsn())
|
|
})
|
|
}
|
|
}
|
|
|
|
func TestPostgresDriver(t *testing.T) {
|
|
host := os.Getenv(PostgresEntHostEnv)
|
|
if host == "" {
|
|
t.Skipf("test environment variable %s not set, skipping", PostgresEntHostEnv)
|
|
}
|
|
|
|
port := uint64(5432)
|
|
if rawPort := os.Getenv(PostgresEntPortEnv); rawPort != "" {
|
|
var err error
|
|
|
|
port, err = strconv.ParseUint(rawPort, 10, 32)
|
|
require.NoError(t, err, "invalid postgres port %q: %s", rawPort, err)
|
|
}
|
|
|
|
tests := []struct {
|
|
name string
|
|
cfg func() *Postgres
|
|
desiredConns int
|
|
}{
|
|
{
|
|
name: "Defaults",
|
|
cfg: func() *Postgres { return postgresTestConfig(host, port) },
|
|
desiredConns: 5,
|
|
},
|
|
{
|
|
name: "Tune",
|
|
cfg: func() *Postgres {
|
|
cfg := postgresTestConfig(host, port)
|
|
cfg.MaxOpenConns = 101
|
|
return cfg
|
|
},
|
|
desiredConns: 101,
|
|
},
|
|
}
|
|
|
|
for _, tt := range tests {
|
|
t.Run(tt.name, func(t *testing.T) {
|
|
drv, err := tt.cfg().driver()
|
|
require.NoError(t, err)
|
|
|
|
require.Equal(t, tt.desiredConns, drv.DB().Stats().MaxOpenConnections)
|
|
})
|
|
}
|
|
}
|