From 87a7d093b27c42b992d6a49c44abfebf37c90c30 Mon Sep 17 00:00:00 2001 From: Eric Chiang Date: Wed, 14 Sep 2016 18:11:57 -0700 Subject: [PATCH] storage/sql: add a SQL storage implementation This change adds support for SQLite3, and Postgres. --- storage/sql/config.go | 113 +++++++++ storage/sql/config_test.go | 1 + storage/sql/crud.go | 487 ++++++++++++++++++++++++++++++++++++ storage/sql/crud_test.go | 55 ++++ storage/sql/gc.go | 24 ++ storage/sql/gc_test.go | 53 ++++ storage/sql/migrate.go | 151 +++++++++++ storage/sql/migrate_test.go | 25 ++ storage/sql/sql.go | 152 +++++++++++ storage/sql/sql_test.go | 55 ++++ 10 files changed, 1116 insertions(+) create mode 100644 storage/sql/config.go create mode 100644 storage/sql/config_test.go create mode 100644 storage/sql/crud.go create mode 100644 storage/sql/crud_test.go create mode 100644 storage/sql/gc.go create mode 100644 storage/sql/gc_test.go create mode 100644 storage/sql/migrate.go create mode 100644 storage/sql/migrate_test.go create mode 100644 storage/sql/sql.go create mode 100644 storage/sql/sql_test.go diff --git a/storage/sql/config.go b/storage/sql/config.go new file mode 100644 index 00000000..bf56d017 --- /dev/null +++ b/storage/sql/config.go @@ -0,0 +1,113 @@ +package sql + +import ( + "database/sql" + "fmt" + "net/url" + "strconv" + + "github.com/coreos/dex/storage" +) + +// SQLite3 options for creating an SQL db. +type SQLite3 struct { + // File to + File string `yaml:"file"` +} + +// Open creates a new storage implementation backed by SQLite3 +func (s *SQLite3) Open() (storage.Storage, error) { + return s.open() +} + +func (s *SQLite3) open() (*conn, error) { + db, err := sql.Open("sqlite3", s.File) + if err != nil { + return nil, err + } + if s.File == ":memory:" { + // sqlite3 uses file locks to coordinate concurrent access. In memory + // doesn't support this, so limit the number of connections to 1. + db.SetMaxOpenConns(1) + } + c := &conn{db, flavorSQLite3} + if _, err := c.migrate(); err != nil { + return nil, fmt.Errorf("failed to perform migrations: %v", err) + } + return c, nil +} + +const ( + sslDisable = "disable" + sslRequire = "require" + sslVerifyCA = "verify-ca" + sslVerifyFull = "verify-full" +) + +// PostgresSSL represents SSL options for Postgres databases. +type PostgresSSL struct { + Mode string + CAFile string + // Files for client auth. + KeyFile string + CertFile string +} + +// Postgres options for creating an SQL db. +type Postgres struct { + Database string + User string + Password string + Host string + + SSL PostgresSSL `json:"ssl" yaml:"ssl"` + + ConnectionTimeout int // Seconds +} + +// Open creates a new storage implementation backed by Postgres. +func (p *Postgres) Open() (storage.Storage, error) { + return p.open() +} + +func (p *Postgres) open() (*conn, error) { + v := url.Values{} + set := func(key, val string) { + if val != "" { + v.Set(key, val) + } + } + set("connect_timeout", strconv.Itoa(p.ConnectionTimeout)) + set("sslkey", p.SSL.KeyFile) + set("sslcert", p.SSL.CertFile) + set("sslrootcert", p.SSL.CAFile) + if p.SSL.Mode == "" { + // Assume the strictest mode if unspecified. 
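+		// In libpq terms, "verify-full" requires TLS, checks the server
+		// certificate against the configured sslrootcert, and verifies that
+		// the host name matches the certificate. The weaker modes ("disable",
+		// "require", "verify-ca") skip some or all of those checks, so they
+		// must be chosen explicitly in the config rather than defaulted to.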
+ p.SSL.Mode = sslVerifyFull + } + set("sslmode", p.SSL.Mode) + + u := url.URL{ + Scheme: "postgres", + Host: p.Host, + Path: "/" + p.Database, + RawQuery: v.Encode(), + } + + if p.User != "" { + if p.Password != "" { + u.User = url.UserPassword(p.User, p.Password) + } else { + u.User = url.User(p.User) + } + } + db, err := sql.Open("postgres", u.String()) + if err != nil { + return nil, err + } + c := &conn{db, flavorPostgres} + if _, err := c.migrate(); err != nil { + return nil, fmt.Errorf("failed to perform migrations: %v", err) + } + return c, nil +} diff --git a/storage/sql/config_test.go b/storage/sql/config_test.go new file mode 100644 index 00000000..e4b317b4 --- /dev/null +++ b/storage/sql/config_test.go @@ -0,0 +1 @@ +package sql diff --git a/storage/sql/crud.go b/storage/sql/crud.go new file mode 100644 index 00000000..532b8648 --- /dev/null +++ b/storage/sql/crud.go @@ -0,0 +1,487 @@ +package sql + +import ( + "database/sql" + "database/sql/driver" + "encoding/json" + "errors" + "fmt" + + "github.com/coreos/dex/storage" +) + +// TODO(ericchiang): The update, insert, and select methods queries are all +// very repetivite. Consider creating them programatically. + +// keysRowID is the ID of the only row we expect to populate the "keys" table. +const keysRowID = "keys" + +// encoder wraps the underlying value in a JSON marshaler which is automatically +// called by the database/sql package. +// +// s := []string{"planes", "bears"} +// err := db.Exec(`insert into t1 (id, things) values (1, $1)`, encoder(s)) +// if err != nil { +// // handle error +// } +// +// var r []byte +// err = db.QueryRow(`select things from t1 where id = 1;`).Scan(&r) +// if err != nil { +// // handle error +// } +// fmt.Printf("%s\n", r) // ["planes","bears"] +// +func encoder(i interface{}) driver.Valuer { + return jsonEncoder{i} +} + +// decoder wraps the underlying value in a JSON unmarshaler which can then be passed +// to a database Scan() method. +func decoder(i interface{}) sql.Scanner { + return jsonDecoder{i} +} + +type jsonEncoder struct { + i interface{} +} + +func (j jsonEncoder) Value() (driver.Value, error) { + b, err := json.Marshal(j.i) + if err != nil { + return nil, fmt.Errorf("marshal: %v", err) + } + return b, nil +} + +type jsonDecoder struct { + i interface{} +} + +func (j jsonDecoder) Scan(dest interface{}) error { + if dest == nil { + return errors.New("nil value") + } + b, ok := dest.([]byte) + if !ok { + return fmt.Errorf("expected []byte got %T", dest) + } + if err := json.Unmarshal(b, &j.i); err != nil { + return fmt.Errorf("unmarshal: %v", err) + } + return nil +} + +// Abstract conn vs trans. +type querier interface { + QueryRow(query string, args ...interface{}) *sql.Row +} + +// Abstract row vs rows. 
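+// Both *sql.Row and *sql.Rows implement Scan(dest ...interface{}) error, so
+// helpers such as scanClient and scanRefresh can decode a single QueryRow
+// result or each row of a Query iteration through the same scanner interface.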
+type scanner interface { + Scan(dest ...interface{}) error +} + +func (c *conn) CreateAuthRequest(a storage.AuthRequest) error { + _, err := c.Exec(` + insert into auth_request ( + id, client_id, response_types, scopes, redirect_uri, nonce, state, + force_approval_prompt, logged_in, + claims_user_id, claims_username, claims_email, claims_email_verified, + claims_groups, + connector_id, connector_data, + expiry + ) + values ( + $1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12, $13, $14, $15, $16, $17 + ); + `, + a.ID, a.ClientID, encoder(a.ResponseTypes), encoder(a.Scopes), a.RedirectURI, a.Nonce, a.State, + a.ForceApprovalPrompt, a.LoggedIn, + a.Claims.UserID, a.Claims.Username, a.Claims.Email, a.Claims.EmailVerified, + encoder(a.Claims.Groups), + a.ConnectorID, a.ConnectorData, + a.Expiry, + ) + if err != nil { + return fmt.Errorf("insert auth request: %v", err) + } + return nil +} + +func (c *conn) UpdateAuthRequest(id string, updater func(a storage.AuthRequest) (storage.AuthRequest, error)) error { + return c.ExecTx(func(tx *trans) error { + r, err := getAuthRequest(tx, id) + if err != nil { + return err + } + + a, err := updater(r) + if err != nil { + return err + } + _, err = tx.Exec(` + update auth_request + set + client_id = $1, response_types = $2, scopes = $3, redirect_uri = $4, + nonce = $5, state = $6, force_approval_prompt = $7, logged_in = $8, + claims_user_id = $9, claims_username = $10, claims_email = $11, + claims_email_verified = $12, + claims_groups = $13, + connector_id = $14, connector_data = $15, + expiry = $16 + where id = $17; + `, + a.ClientID, encoder(a.ResponseTypes), encoder(a.Scopes), a.RedirectURI, a.Nonce, a.State, + a.ForceApprovalPrompt, a.LoggedIn, + a.Claims.UserID, a.Claims.Username, a.Claims.Email, a.Claims.EmailVerified, + encoder(a.Claims.Groups), + a.ConnectorID, a.ConnectorData, + a.Expiry, a.ID, + ) + if err != nil { + return fmt.Errorf("update auth request: %v", err) + } + return nil + }) + +} + +func (c *conn) GetAuthRequest(id string) (storage.AuthRequest, error) { + return getAuthRequest(c, id) +} + +func getAuthRequest(q querier, id string) (a storage.AuthRequest, err error) { + err = q.QueryRow(` + select + id, client_id, response_types, scopes, redirect_uri, nonce, state, + force_approval_prompt, logged_in, + claims_user_id, claims_username, claims_email, claims_email_verified, + claims_groups, + connector_id, connector_data, expiry + from auth_request where id = $1; + `, id).Scan( + &a.ID, &a.ClientID, decoder(&a.ResponseTypes), decoder(&a.Scopes), &a.RedirectURI, &a.Nonce, &a.State, + &a.ForceApprovalPrompt, &a.LoggedIn, + &a.Claims.UserID, &a.Claims.Username, &a.Claims.Email, &a.Claims.EmailVerified, + decoder(&a.Claims.Groups), + &a.ConnectorID, &a.ConnectorData, &a.Expiry, + ) + if err != nil { + if err == sql.ErrNoRows { + return a, storage.ErrNotFound + } + return a, fmt.Errorf("select auth request: %v", err) + } + return a, nil +} + +func (c *conn) CreateAuthCode(a storage.AuthCode) error { + _, err := c.Exec(` + insert into auth_code ( + id, client_id, scopes, nonce, redirect_uri, + claims_user_id, claims_username, + claims_email, claims_email_verified, claims_groups, + connector_id, connector_data, + expiry + ) + values ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12, $13); + `, + a.ID, a.ClientID, encoder(a.Scopes), a.Nonce, a.RedirectURI, a.Claims.UserID, + a.Claims.Username, a.Claims.Email, a.Claims.EmailVerified, encoder(a.Claims.Groups), + a.ConnectorID, a.ConnectorData, a.Expiry, + ) + return err +} + +func (c *conn) 
GetAuthCode(id string) (a storage.AuthCode, err error) { + err = c.QueryRow(` + select + id, client_id, scopes, nonce, redirect_uri, + claims_user_id, claims_username, + claims_email, claims_email_verified, claims_groups, + connector_id, connector_data, + expiry + from auth_code where id = $1; + `, id).Scan( + &a.ID, &a.ClientID, decoder(&a.Scopes), &a.Nonce, &a.RedirectURI, &a.Claims.UserID, + &a.Claims.Username, &a.Claims.Email, &a.Claims.EmailVerified, decoder(&a.Claims.Groups), + &a.ConnectorID, &a.ConnectorData, &a.Expiry, + ) + if err != nil { + if err == sql.ErrNoRows { + return a, storage.ErrNotFound + } + return a, fmt.Errorf("select auth code: %v", err) + } + return a, nil +} + +func (c *conn) CreateRefresh(r storage.RefreshToken) error { + _, err := c.Exec(` + insert into refresh_token ( + id, client_id, scopes, nonce, + claims_user_id, claims_username, claims_email, claims_email_verified, + claims_groups, + connector_id, connector_data + ) + values ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11); + `, + r.RefreshToken, r.ClientID, encoder(r.Scopes), r.Nonce, + r.Claims.UserID, r.Claims.Username, r.Claims.Email, r.Claims.EmailVerified, + encoder(r.Claims.Groups), + r.ConnectorID, r.ConnectorData, + ) + if err != nil { + return fmt.Errorf("insert refresh_token: %v", err) + } + return nil +} + +func (c *conn) GetRefresh(id string) (storage.RefreshToken, error) { + return scanRefresh(c.QueryRow(` + select + id, client_id, scopes, nonce, + claims_user_id, claims_username, claims_email, claims_email_verified, + claims_groups, + connector_id, connector_data + from refresh_token where id = $1; + `, id)) +} + +func (c *conn) ListRefreshTokens() ([]storage.RefreshToken, error) { + rows, err := c.Query(` + select + id, client_id, scopes, nonce, + claims_user_id, claims_username, claims_email, claims_email_verified, + claims_groups, + connector_id, connector_data + from refresh_token; + `) + if err != nil { + return nil, fmt.Errorf("query: %v", err) + } + var tokens []storage.RefreshToken + for rows.Next() { + r, err := scanRefresh(rows) + if err != nil { + return nil, err + } + tokens = append(tokens, r) + } + if err := rows.Err(); err != nil { + return nil, fmt.Errorf("scan: %v", err) + } + return tokens, nil +} + +func scanRefresh(s scanner) (r storage.RefreshToken, err error) { + err = s.Scan( + &r.RefreshToken, &r.ClientID, decoder(&r.Scopes), &r.Nonce, + &r.Claims.UserID, &r.Claims.Username, &r.Claims.Email, &r.Claims.EmailVerified, + decoder(&r.Claims.Groups), + &r.ConnectorID, &r.ConnectorData, + ) + if err != nil { + if err == sql.ErrNoRows { + return r, storage.ErrNotFound + } + return r, fmt.Errorf("scan refresh_token: %v", err) + } + return r, nil +} + +func (c *conn) UpdateKeys(updater func(old storage.Keys) (storage.Keys, error)) error { + return c.ExecTx(func(tx *trans) error { + firstUpdate := false + // TODO(ericchiang): errors may cause a transaction be rolled back by the SQL + // server. Test this, and consider adding a COUNT() command beforehand. 
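+		// What follows is a read-modify-write inside one transaction: load the
+		// current keys (tolerating "not found" before the first rotation),
+		// pass them to the updater, then either insert the single "keys" row
+		// or update it in place.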
+		old, err := getKeys(tx)
+		if err != nil {
+			if err != storage.ErrNotFound {
+				return fmt.Errorf("get keys: %v", err)
+			}
+			firstUpdate = true
+			old = storage.Keys{}
+		}
+
+		nk, err := updater(old)
+		if err != nil {
+			return err
+		}
+
+		if firstUpdate {
+			_, err = tx.Exec(`
+				insert into keys (
+					id, verification_keys, signing_key, signing_key_pub, next_rotation
+				)
+				values ($1, $2, $3, $4, $5);
+			`,
+				keysRowID, encoder(nk.VerificationKeys), encoder(nk.SigningKey),
+				encoder(nk.SigningKeyPub), nk.NextRotation,
+			)
+			if err != nil {
+				return fmt.Errorf("insert: %v", err)
+			}
+		} else {
+			_, err = tx.Exec(`
+				update keys
+				set
+					verification_keys = $1,
+					signing_key = $2,
+					signing_key_pub = $3,
+					next_rotation = $4
+				where id = $5;
+			`,
+				encoder(nk.VerificationKeys), encoder(nk.SigningKey),
+				encoder(nk.SigningKeyPub), nk.NextRotation, keysRowID,
+			)
+			if err != nil {
+				return fmt.Errorf("update: %v", err)
+			}
+		}
+		return nil
+	})
+}
+
+func (c *conn) GetKeys() (keys storage.Keys, err error) {
+	return getKeys(c)
+}
+
+func getKeys(q querier) (keys storage.Keys, err error) {
+	err = q.QueryRow(`
+		select
+			verification_keys, signing_key, signing_key_pub, next_rotation
+		from keys
+		where id = $1
+	`, keysRowID).Scan(
+		decoder(&keys.VerificationKeys), decoder(&keys.SigningKey),
+		decoder(&keys.SigningKeyPub), &keys.NextRotation,
+	)
+	if err != nil {
+		if err == sql.ErrNoRows {
+			return keys, storage.ErrNotFound
+		}
+		return keys, fmt.Errorf("query keys: %v", err)
+	}
+	return keys, nil
+}
+
+func (c *conn) UpdateClient(id string, updater func(old storage.Client) (storage.Client, error)) error {
+	return c.ExecTx(func(tx *trans) error {
+		cli, err := getClient(tx, id)
+		if err != nil {
+			return err
+		}
+		nc, err := updater(cli)
+		if err != nil {
+			return err
+		}
+
+		_, err = tx.Exec(`
+			update client
+			set
+				secret = $1,
+				redirect_uris = $2,
+				trusted_peers = $3,
+				public = $4,
+				name = $5,
+				logo_url = $6
+			where id = $7;
+		`, nc.Secret, encoder(nc.RedirectURIs), encoder(nc.TrustedPeers), nc.Public, nc.Name, nc.LogoURL, id,
+		)
+		if err != nil {
+			return fmt.Errorf("update client: %v", err)
+		}
+		return nil
+	})
+}
+
+func (c *conn) CreateClient(cli storage.Client) error {
+	_, err := c.Exec(`
+		insert into client (
+			id, secret, redirect_uris, trusted_peers, public, name, logo_url
+		)
+		values ($1, $2, $3, $4, $5, $6, $7);
+	`,
+		cli.ID, cli.Secret, encoder(cli.RedirectURIs), encoder(cli.TrustedPeers),
+		cli.Public, cli.Name, cli.LogoURL,
+	)
+	if err != nil {
+		return fmt.Errorf("insert client: %v", err)
+	}
+	return nil
+}
+
+func getClient(q querier, id string) (storage.Client, error) {
+	return scanClient(q.QueryRow(`
+		select
+			id, secret, redirect_uris, trusted_peers, public, name, logo_url
+		from client where id = $1;
+	`, id))
+}
+
+func (c *conn) GetClient(id string) (storage.Client, error) {
+	return getClient(c, id)
+}
+
+func (c *conn) ListClients() ([]storage.Client, error) {
+	rows, err := c.Query(`
+		select
+			id, secret, redirect_uris, trusted_peers, public, name, logo_url
+		from client;
+	`)
+	if err != nil {
+		return nil, err
+	}
+	var clients []storage.Client
+	for rows.Next() {
+		cli, err := scanClient(rows)
+		if err != nil {
+			return nil, err
+		}
+		clients = append(clients, cli)
+	}
+	if err := rows.Err(); err != nil {
+		return nil, err
+	}
+	return clients, nil
+}
+
+func scanClient(s scanner) (cli storage.Client, err error) {
+	err = s.Scan(
+		&cli.ID, &cli.Secret, decoder(&cli.RedirectURIs), decoder(&cli.TrustedPeers),
+		&cli.Public, &cli.Name, &cli.LogoURL,
+	)
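+	// QueryRow defers query errors until Scan, so a missing row surfaces here
+	// as sql.ErrNoRows and is mapped to storage.ErrNotFound below.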
+ if err != nil { + if err == sql.ErrNoRows { + return cli, storage.ErrNotFound + } + return cli, fmt.Errorf("get client: %v", err) + } + return cli, nil +} + +func (c *conn) DeleteAuthRequest(id string) error { return c.delete("auth_request", id) } +func (c *conn) DeleteAuthCode(id string) error { return c.delete("auth_code", id) } +func (c *conn) DeleteClient(id string) error { return c.delete("client", id) } +func (c *conn) DeleteRefresh(id string) error { return c.delete("refresh_token", id) } + +// Do NOT call directly. Does not escape table. +func (c *conn) delete(table, id string) error { + result, err := c.Exec(`delete from `+table+` where id = $1`, id) + if err != nil { + return fmt.Errorf("delete %s: %v", table, id) + } + + // For now mandate that the driver implements RowsAffected. If we ever need to support + // a driver that doesn't implement this, we can run this in a transaction with a get beforehand. + n, err := result.RowsAffected() + if err != nil { + return fmt.Errorf("rows affected: %v", err) + } + if n < 1 { + return storage.ErrNotFound + } + return nil +} diff --git a/storage/sql/crud_test.go b/storage/sql/crud_test.go new file mode 100644 index 00000000..d6682e17 --- /dev/null +++ b/storage/sql/crud_test.go @@ -0,0 +1,55 @@ +package sql + +import ( + "database/sql" + "reflect" + "testing" +) + +func TestDecoder(t *testing.T) { + db, err := sql.Open("sqlite3", ":memory:") + if err != nil { + t.Fatal(err) + } + defer db.Close() + + if _, err := db.Exec(`create table foo ( id integer primary key, bar blob );`); err != nil { + t.Fatal(err) + } + if _, err := db.Exec(`insert into foo ( id, bar ) values (1, ?);`, []byte(`["a", "b"]`)); err != nil { + t.Fatal(err) + } + var got []string + if err := db.QueryRow(`select bar from foo where id = 1;`).Scan(decoder(&got)); err != nil { + t.Fatal(err) + } + want := []string{"a", "b"} + if !reflect.DeepEqual(got, want) { + t.Errorf("wanted %q got %q", want, got) + } +} + +func TestEncoder(t *testing.T) { + db, err := sql.Open("sqlite3", ":memory:") + if err != nil { + t.Fatal(err) + } + defer db.Close() + + if _, err := db.Exec(`create table foo ( id integer primary key, bar blob );`); err != nil { + t.Fatal(err) + } + put := []string{"a", "b"} + if _, err := db.Exec(`insert into foo ( id, bar ) values (1, ?)`, encoder(put)); err != nil { + t.Fatal(err) + } + + var got []byte + if err := db.QueryRow(`select bar from foo where id = 1;`).Scan(&got); err != nil { + t.Fatal(err) + } + want := []byte(`["a","b"]`) + if !reflect.DeepEqual(got, want) { + t.Errorf("wanted %q got %q", want, got) + } +} diff --git a/storage/sql/gc.go b/storage/sql/gc.go new file mode 100644 index 00000000..11e70f95 --- /dev/null +++ b/storage/sql/gc.go @@ -0,0 +1,24 @@ +package sql + +import ( + "fmt" + "time" +) + +type gc struct { + now func() time.Time + conn *conn +} + +var tablesWithGC = []string{"auth_request", "auth_code"} + +func (gc gc) run() error { + for _, table := range tablesWithGC { + _, err := gc.conn.Exec(`delete from `+table+` where expiry < $1`, gc.now()) + if err != nil { + return fmt.Errorf("gc %s: %v", table, err) + } + // TODO(ericchiang): when we have levelled logging print how many rows were gc'd + } + return nil +} diff --git a/storage/sql/gc_test.go b/storage/sql/gc_test.go new file mode 100644 index 00000000..ad6097e0 --- /dev/null +++ b/storage/sql/gc_test.go @@ -0,0 +1,53 @@ +package sql + +import ( + "testing" + "time" + + "github.com/coreos/dex/storage" +) + +func TestGC(t *testing.T) { + // TODO(ericchiang): Add a 
GarbageCollect method to the storage interface so + // we can write conformance tests instead of directly testing each implementation. + s := &SQLite3{":memory:"} + conn, err := s.open() + if err != nil { + t.Fatal(err) + } + defer conn.Close() + + clock := time.Now() + now := func() time.Time { return clock } + + runGC := (gc{now, conn}).run + + a := storage.AuthRequest{ + ID: storage.NewID(), + Expiry: now().Add(time.Second), + } + + if err := conn.CreateAuthRequest(a); err != nil { + t.Fatal(err) + } + + if err := runGC(); err != nil { + t.Errorf("gc failed: %v", err) + } + + if _, err := conn.GetAuthRequest(a.ID); err != nil { + t.Errorf("failed to get auth request after gc: %v", err) + } + + clock = clock.Add(time.Minute) + + if err := runGC(); err != nil { + t.Errorf("gc failed: %v", err) + } + + if _, err := conn.GetAuthRequest(a.ID); err == nil { + t.Errorf("expected error after gc'ing auth request: %v", err) + } else if err != storage.ErrNotFound { + t.Errorf("expected error storage.NotFound got: %v", err) + } +} diff --git a/storage/sql/migrate.go b/storage/sql/migrate.go new file mode 100644 index 00000000..8754caf5 --- /dev/null +++ b/storage/sql/migrate.go @@ -0,0 +1,151 @@ +package sql + +import ( + "database/sql" + "fmt" +) + +func (c *conn) migrate() (int, error) { + _, err := c.Exec(` + create table if not exists migrations ( + num integer not null, + at timestamp not null + ); + `) + if err != nil { + return 0, fmt.Errorf("creating migration table: %v", err) + } + + i := 0 + done := false + for { + err := c.ExecTx(func(tx *trans) error { + // Within a transaction, perform a single migration. + var ( + num sql.NullInt64 + n int + ) + if err := tx.QueryRow(`select max(num) from migrations;`).Scan(&num); err != nil { + return fmt.Errorf("select max migration: %v", err) + } + if num.Valid { + n = int(num.Int64) + } + if n >= len(migrations) { + done = true + return nil + } + + migrationNum := n + 1 + m := migrations[n] + if _, err := tx.Exec(m.stmt); err != nil { + return fmt.Errorf("migration %d failed: %v", migrationNum, err) + } + + q := `insert into migrations (num, at) values ($1, now());` + if _, err := tx.Exec(q, migrationNum); err != nil { + return fmt.Errorf("update migration table: %v", err) + } + return nil + }) + if err != nil { + return i, err + } + if done { + break + } + i++ + } + + return i, nil +} + +type migration struct { + stmt string + // TODO(ericchiang): consider adding additional fields like "forDrivers" +} + +// All SQL flavors share migration strategies. 
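+//
+// migrate() records the 1-based index of every applied entry in the
+// "migrations" table and only executes entries whose index is greater than
+// max(num), each in its own transaction. New schema changes should therefore
+// be appended as additional entries; released entries must never be edited or
+// reordered. A hypothetical second migration would be appended as:
+//
+//	{
+//		stmt: `alter table client add column extra text;`,
+//	},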
+var migrations = []migration{ + { + stmt: ` + create table client ( + id text not null primary key, + secret text not null, + redirect_uris bytea not null, -- JSON array of strings + trusted_peers bytea not null, -- JSON array of strings + public boolean not null, + name text not null, + logo_url text not null + ); + + create table auth_request ( + id text not null primary key, + client_id text not null, + response_types bytea not null, -- JSON array of strings + scopes bytea not null, -- JSON array of strings + redirect_uri text not null, + nonce text not null, + state text not null, + force_approval_prompt boolean not null, + + logged_in boolean not null, + + claims_user_id text not null, + claims_username text not null, + claims_email text not null, + claims_email_verified boolean not null, + claims_groups bytea not null, -- JSON array of strings + + connector_id text not null, + connector_data bytea, + + expiry timestamp not null + ); + + create table auth_code ( + id text not null primary key, + client_id text not null, + scopes bytea not null, -- JSON array of strings + nonce text not null, + redirect_uri text not null, + + claims_user_id text not null, + claims_username text not null, + claims_email text not null, + claims_email_verified boolean not null, + claims_groups bytea not null, -- JSON array of strings + + connector_id text not null, + connector_data bytea, + + expiry timestamp not null + ); + + create table refresh_token ( + id text not null primary key, + client_id text not null, + scopes bytea not null, -- JSON array of strings + nonce text not null, + + claims_user_id text not null, + claims_username text not null, + claims_email text not null, + claims_email_verified boolean not null, + claims_groups bytea not null, -- JSON array of strings + + connector_id text not null, + connector_data bytea + ); + + -- keys is a weird table because we only ever expect there to be a single row + create table keys ( + id text not null primary key, + verification_keys bytea not null, -- JSON array + signing_key bytea not null, -- JSON object + signing_key_pub bytea not null, -- JSON object + next_rotation timestamp not null + ); + `, + }, +} diff --git a/storage/sql/migrate_test.go b/storage/sql/migrate_test.go new file mode 100644 index 00000000..6c0765a8 --- /dev/null +++ b/storage/sql/migrate_test.go @@ -0,0 +1,25 @@ +package sql + +import ( + "database/sql" + "testing" +) + +func TestMigrate(t *testing.T) { + db, err := sql.Open("sqlite3", ":memory:") + if err != nil { + t.Fatal(err) + } + defer db.Close() + + c := &conn{db, flavorSQLite3} + for _, want := range []int{len(migrations), 0} { + got, err := c.migrate() + if err != nil { + t.Fatal(err) + } + if got != want { + t.Errorf("expected %d migrations, got %d", want, got) + } + } +} diff --git a/storage/sql/sql.go b/storage/sql/sql.go new file mode 100644 index 00000000..113fa972 --- /dev/null +++ b/storage/sql/sql.go @@ -0,0 +1,152 @@ +// Package sql provides SQL implementations of the storage interface. +package sql + +import ( + "database/sql" + "regexp" + + "github.com/cockroachdb/cockroach-go/crdb" + + // import third party drivers + _ "github.com/go-sql-driver/mysql" + _ "github.com/lib/pq" + _ "github.com/mattn/go-sqlite3" +) + +// flavor represents a specific SQL implementation, and is used to translate query strings +// between different drivers. Flavors shouldn't aim to translate all possible SQL statements, +// only the specific queries used by the SQL storages. 
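+//
+// For example, given the replacers below, flavorSQLite3 turns the
+// Postgres-style query
+//
+//	select id from client where id = $1 and public = true;
+//
+// into
+//
+//	select id from client where id = ? and public = 1;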
+type flavor struct { + queryReplacers []replacer + + // Optional function to create and finish a transaction. This is mainly for + // cockroachdb support which requires special retry logic provided by their + // client package. + // + // This will be nil for most flavors. + // + // See: https://github.com/cockroachdb/docs/blob/63761c2e/_includes/app/txn-sample.go#L41-L44 + executeTx func(db *sql.DB, fn func(*sql.Tx) error) error +} + +// A regexp with a replacement string. +type replacer struct { + re *regexp.Regexp + with string +} + +// Match a postgres query binds. E.g. "$1", "$12", etc. +var bindRegexp = regexp.MustCompile(`\$\d+`) + +func matchLiteral(s string) *regexp.Regexp { + return regexp.MustCompile(`\b` + regexp.QuoteMeta(s) + `\b`) +} + +var ( + // The "github.com/lib/pq" driver is the default flavor. All others are + // translations of this. + flavorPostgres = flavor{} + + flavorSQLite3 = flavor{ + queryReplacers: []replacer{ + {bindRegexp, "?"}, + // Translate for booleans to integers. + {matchLiteral("true"), "1"}, + {matchLiteral("false"), "0"}, + {matchLiteral("boolean"), "integer"}, + // Translate other types. + {matchLiteral("bytea"), "blob"}, + // {matchLiteral("timestamp"), "integer"}, + // SQLite doesn't have a "now()" method, replace with "date('now')" + {regexp.MustCompile(`\bnow\(\)`), "date('now')"}, + }, + } + + // Incomplete. + flavorMySQL = flavor{ + queryReplacers: []replacer{ + {bindRegexp, "?"}, + }, + } + + // Not tested. + flavorCockroach = flavor{ + executeTx: crdb.ExecuteTx, + } +) + +func (f flavor) translate(query string) string { + // TODO(ericchiang): Heavy cashing. + for _, r := range f.queryReplacers { + query = r.re.ReplaceAllString(query, r.with) + } + return query +} + +// conn is the main database connection. +type conn struct { + db *sql.DB + flavor flavor +} + +func (c *conn) Close() error { + return c.db.Close() +} + +// conn implements the same method signatures as encoding/sql.DB. + +func (c *conn) Exec(query string, args ...interface{}) (sql.Result, error) { + query = c.flavor.translate(query) + return c.db.Exec(query, args...) +} + +func (c *conn) Query(query string, args ...interface{}) (*sql.Rows, error) { + query = c.flavor.translate(query) + return c.db.Query(query, args...) +} + +func (c *conn) QueryRow(query string, args ...interface{}) *sql.Row { + query = c.flavor.translate(query) + return c.db.QueryRow(query, args...) +} + +// ExecTx runs a method which operates on a transaction. +func (c *conn) ExecTx(fn func(tx *trans) error) error { + if c.flavor.executeTx != nil { + return c.flavor.executeTx(c.db, func(sqlTx *sql.Tx) error { + return fn(&trans{sqlTx, c}) + }) + } + + sqlTx, err := c.db.Begin() + if err != nil { + return err + } + if err := fn(&trans{sqlTx, c}); err != nil { + sqlTx.Rollback() + return err + } + return sqlTx.Commit() +} + +type trans struct { + tx *sql.Tx + c *conn +} + +// trans implements the same method signatures as encoding/sql.Tx. + +func (t *trans) Exec(query string, args ...interface{}) (sql.Result, error) { + query = t.c.flavor.translate(query) + return t.tx.Exec(query, args...) +} + +func (t *trans) Query(query string, args ...interface{}) (*sql.Rows, error) { + query = t.c.flavor.translate(query) + return t.tx.Query(query, args...) +} + +func (t *trans) QueryRow(query string, args ...interface{}) *sql.Row { + query = t.c.flavor.translate(query) + return t.tx.QueryRow(query, args...) 
+} diff --git a/storage/sql/sql_test.go b/storage/sql/sql_test.go new file mode 100644 index 00000000..402d9586 --- /dev/null +++ b/storage/sql/sql_test.go @@ -0,0 +1,55 @@ +package sql + +import "testing" + +func TestTranslate(t *testing.T) { + tests := []struct { + testCase string + flavor flavor + query string + exp string + }{ + { + "sqlite3 query bind replacement", + flavorSQLite3, + `select foo from bar where foo.zam = $1;`, + `select foo from bar where foo.zam = ?;`, + }, + { + "sqlite3 query bind replacement at newline", + flavorSQLite3, + `select foo from bar where foo.zam = $1`, + `select foo from bar where foo.zam = ?`, + }, + { + "sqlite3 query true", + flavorSQLite3, + `select foo from bar where foo.zam = true`, + `select foo from bar where foo.zam = 1`, + }, + { + "sqlite3 query false", + flavorSQLite3, + `select foo from bar where foo.zam = false`, + `select foo from bar where foo.zam = 0`, + }, + { + "sqlite3 bytea", + flavorSQLite3, + `"connector_data" bytea not null,`, + `"connector_data" blob not null,`, + }, + { + "sqlite3 now", + flavorSQLite3, + `now(),`, + `date('now'),`, + }, + } + + for _, tc := range tests { + if got := tc.flavor.translate(tc.query); got != tc.exp { + t.Errorf("%s: want=%q, got=%q", tc.testCase, tc.exp, got) + } + } +}
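The sketch below is not part of the patch; it shows how the SQLite3-backed storage might be exercised end to end, reusing the in-memory setup from gc_test.go. The test name and client values are illustrative, and the storage.Client fields are assumed from their use in crud.go.

package sql

import (
	"testing"

	"github.com/coreos/dex/storage"
)

// TestSQLite3ClientCRUD opens an in-memory SQLite3 database (open runs the
// migrations), creates a client, and reads it back.
func TestSQLite3ClientCRUD(t *testing.T) {
	s := &SQLite3{":memory:"}
	conn, err := s.open()
	if err != nil {
		t.Fatal(err)
	}
	defer conn.Close()

	cli := storage.Client{
		ID:           "example-app", // illustrative values, not from the patch
		Secret:       "example-secret",
		RedirectURIs: []string{"http://127.0.0.1:5555/callback"},
		Name:         "Example App",
	}
	if err := conn.CreateClient(cli); err != nil {
		t.Fatalf("create client: %v", err)
	}
	got, err := conn.GetClient(cli.ID)
	if err != nil {
		t.Fatalf("get client: %v", err)
	}
	if got.Name != cli.Name {
		t.Errorf("expected client name %q, got %q", cli.Name, got.Name)
	}
}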