feat: Add ent-based postgres storage
Signed-off-by: m.nabokikh <maksim.nabokikh@flant.com>
This commit is contained in:
183
storage/ent/postgres_test.go
Normal file
183
storage/ent/postgres_test.go
Normal file
@@ -0,0 +1,183 @@
|
||||
package ent
|
||||
|
||||
import (
|
||||
"os"
|
||||
"strconv"
|
||||
"testing"
|
||||
|
||||
"github.com/sirupsen/logrus"
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
"github.com/dexidp/dex/storage"
|
||||
"github.com/dexidp/dex/storage/conformance"
|
||||
)
|
||||
|
||||
func getenv(key, defaultVal string) string {
|
||||
if val := os.Getenv(key); val != "" {
|
||||
return val
|
||||
}
|
||||
return defaultVal
|
||||
}
|
||||
|
||||
func postgresTestConfig(host string, port uint64) *Postgres {
|
||||
return &Postgres{
|
||||
NetworkDB: NetworkDB{
|
||||
Database: getenv("DEX_POSTGRES_DATABASE", "postgres"),
|
||||
User: getenv("DEX_POSTGRES_USER", "postgres"),
|
||||
Password: getenv("DEX_POSTGRES_PASSWORD", "postgres"),
|
||||
Host: host,
|
||||
Port: uint16(port),
|
||||
},
|
||||
SSL: SSL{
|
||||
Mode: pgSSLDisable, // Postgres container doesn't support SSL.
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
func newPostgresStorage(host string, port uint64) storage.Storage {
|
||||
logger := &logrus.Logger{
|
||||
Out: os.Stderr,
|
||||
Formatter: &logrus.TextFormatter{DisableColors: true},
|
||||
Level: logrus.DebugLevel,
|
||||
}
|
||||
|
||||
cfg := postgresTestConfig(host, port)
|
||||
s, err := cfg.Open(logger)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
return s
|
||||
}
|
||||
|
||||
func TestPostgres(t *testing.T) {
|
||||
host := os.Getenv("DEX_POSTGRES_HOST")
|
||||
if host == "" {
|
||||
t.Skipf("test environment variable DEX_POSTGRES_HOST not set, skipping")
|
||||
}
|
||||
|
||||
port := uint64(5432)
|
||||
if rawPort := os.Getenv("DEX_POSTGRES_PORT"); rawPort != "" {
|
||||
var err error
|
||||
|
||||
port, err = strconv.ParseUint(rawPort, 10, 32)
|
||||
require.NoError(t, err, "invalid postgres port %q: %s", rawPort, err)
|
||||
}
|
||||
|
||||
newStorage := func() storage.Storage {
|
||||
return newPostgresStorage(host, port)
|
||||
}
|
||||
conformance.RunTests(t, newStorage)
|
||||
conformance.RunTransactionTests(t, newStorage)
|
||||
}
|
||||
|
||||
func TestPostgresDSN(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
cfg *Postgres
|
||||
desiredDSN string
|
||||
}{
|
||||
{
|
||||
name: "Host port",
|
||||
cfg: &Postgres{
|
||||
NetworkDB: NetworkDB{
|
||||
Host: "localhost",
|
||||
Port: uint16(5432),
|
||||
},
|
||||
},
|
||||
desiredDSN: "connect_timeout=0 host='localhost' port=5432 sslmode='verify-full'",
|
||||
},
|
||||
{
|
||||
name: "Host with port",
|
||||
cfg: &Postgres{
|
||||
NetworkDB: NetworkDB{
|
||||
Host: "localhost:5432",
|
||||
},
|
||||
},
|
||||
desiredDSN: "connect_timeout=0 host='localhost' port=5432 sslmode='verify-full'",
|
||||
},
|
||||
{
|
||||
name: "Host ipv6 with port",
|
||||
cfg: &Postgres{
|
||||
NetworkDB: NetworkDB{
|
||||
Host: "[a:b:c:d]:5432",
|
||||
},
|
||||
},
|
||||
desiredDSN: "connect_timeout=0 host='a:b:c:d' port=5432 sslmode='verify-full'",
|
||||
},
|
||||
{
|
||||
name: "Credentials and timeout",
|
||||
cfg: &Postgres{
|
||||
NetworkDB: NetworkDB{
|
||||
Database: "test",
|
||||
User: "test",
|
||||
Password: "test",
|
||||
ConnectionTimeout: 5,
|
||||
},
|
||||
},
|
||||
desiredDSN: "connect_timeout=5 user='test' password='test' dbname='test' sslmode='verify-full'",
|
||||
},
|
||||
{
|
||||
name: "SSL",
|
||||
cfg: &Postgres{
|
||||
SSL: SSL{
|
||||
Mode: pgSSLRequire,
|
||||
CAFile: "/ca.crt",
|
||||
KeyFile: "/cert.crt",
|
||||
CertFile: "/cert.key",
|
||||
},
|
||||
},
|
||||
desiredDSN: "connect_timeout=0 sslmode='require' sslrootcert='/ca.crt' sslcert='/cert.key' sslkey='/cert.crt'",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
require.Equal(t, tt.desiredDSN, tt.cfg.dsn())
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestPostgresDriver(t *testing.T) {
|
||||
host := os.Getenv("DEX_POSTGRES_HOST")
|
||||
if host == "" {
|
||||
t.Skipf("test environment variable DEX_POSTGRES_HOST not set, skipping")
|
||||
}
|
||||
|
||||
port := uint64(5432)
|
||||
if rawPort := os.Getenv("DEX_POSTGRES_PORT"); rawPort != "" {
|
||||
var err error
|
||||
|
||||
port, err = strconv.ParseUint(rawPort, 10, 32)
|
||||
require.NoError(t, err, "invalid postgres port %q: %s", rawPort, err)
|
||||
}
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
cfg func() *Postgres
|
||||
desiredConns int
|
||||
}{
|
||||
{
|
||||
name: "Defaults",
|
||||
cfg: func() *Postgres { return postgresTestConfig(host, port) },
|
||||
desiredConns: 5,
|
||||
},
|
||||
{
|
||||
name: "Tune",
|
||||
cfg: func() *Postgres {
|
||||
cfg := postgresTestConfig(host, port)
|
||||
cfg.MaxOpenConns = 101
|
||||
return cfg
|
||||
},
|
||||
desiredConns: 101,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
drv, err := tt.cfg().driver()
|
||||
require.NoError(t, err)
|
||||
|
||||
require.Equal(t, tt.desiredConns, drv.DB().Stats().MaxOpenConnections)
|
||||
})
|
||||
}
|
||||
}
|
Reference in New Issue
Block a user