storage/sql: add a SQL storage implementation
This change adds support for SQLite3, and Postgres.
This commit is contained in:
parent
82a55cf785
commit
87a7d093b2
113
storage/sql/config.go
Normal file
113
storage/sql/config.go
Normal file
@ -0,0 +1,113 @@
|
||||
package sql
|
||||
|
||||
import (
|
||||
"database/sql"
|
||||
"fmt"
|
||||
"net/url"
|
||||
"strconv"
|
||||
|
||||
"github.com/coreos/dex/storage"
|
||||
)
|
||||
|
||||
// SQLite3 options for creating an SQL db.
|
||||
type SQLite3 struct {
|
||||
// File to
|
||||
File string `yaml:"file"`
|
||||
}
|
||||
|
||||
// Open creates a new storage implementation backed by SQLite3
|
||||
func (s *SQLite3) Open() (storage.Storage, error) {
|
||||
return s.open()
|
||||
}
|
||||
|
||||
func (s *SQLite3) open() (*conn, error) {
|
||||
db, err := sql.Open("sqlite3", s.File)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if s.File == ":memory:" {
|
||||
// sqlite3 uses file locks to coordinate concurrent access. In memory
|
||||
// doesn't support this, so limit the number of connections to 1.
|
||||
db.SetMaxOpenConns(1)
|
||||
}
|
||||
c := &conn{db, flavorSQLite3}
|
||||
if _, err := c.migrate(); err != nil {
|
||||
return nil, fmt.Errorf("failed to perform migrations: %v", err)
|
||||
}
|
||||
return c, nil
|
||||
}
|
||||
|
||||
const (
|
||||
sslDisable = "disable"
|
||||
sslRequire = "require"
|
||||
sslVerifyCA = "verify-ca"
|
||||
sslVerifyFull = "verify-full"
|
||||
)
|
||||
|
||||
// PostgresSSL represents SSL options for Postgres databases.
|
||||
type PostgresSSL struct {
|
||||
Mode string
|
||||
CAFile string
|
||||
// Files for client auth.
|
||||
KeyFile string
|
||||
CertFile string
|
||||
}
|
||||
|
||||
// Postgres options for creating an SQL db.
|
||||
type Postgres struct {
|
||||
Database string
|
||||
User string
|
||||
Password string
|
||||
Host string
|
||||
|
||||
SSL PostgresSSL `json:"ssl" yaml:"ssl"`
|
||||
|
||||
ConnectionTimeout int // Seconds
|
||||
}
|
||||
|
||||
// Open creates a new storage implementation backed by Postgres.
|
||||
func (p *Postgres) Open() (storage.Storage, error) {
|
||||
return p.open()
|
||||
}
|
||||
|
||||
func (p *Postgres) open() (*conn, error) {
|
||||
v := url.Values{}
|
||||
set := func(key, val string) {
|
||||
if val != "" {
|
||||
v.Set(key, val)
|
||||
}
|
||||
}
|
||||
set("connect_timeout", strconv.Itoa(p.ConnectionTimeout))
|
||||
set("sslkey", p.SSL.KeyFile)
|
||||
set("sslcert", p.SSL.CertFile)
|
||||
set("sslrootcert", p.SSL.CAFile)
|
||||
if p.SSL.Mode == "" {
|
||||
// Assume the strictest mode if unspecified.
|
||||
p.SSL.Mode = sslVerifyFull
|
||||
}
|
||||
set("sslmode", p.SSL.Mode)
|
||||
|
||||
u := url.URL{
|
||||
Scheme: "postgres",
|
||||
Host: p.Host,
|
||||
Path: "/" + p.Database,
|
||||
RawQuery: v.Encode(),
|
||||
}
|
||||
|
||||
if p.User != "" {
|
||||
if p.Password != "" {
|
||||
u.User = url.UserPassword(p.User, p.Password)
|
||||
} else {
|
||||
u.User = url.User(p.User)
|
||||
}
|
||||
}
|
||||
db, err := sql.Open("postgres", u.String())
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
c := &conn{db, flavorPostgres}
|
||||
if _, err := c.migrate(); err != nil {
|
||||
return nil, fmt.Errorf("failed to perform migrations: %v", err)
|
||||
}
|
||||
return c, nil
|
||||
}
|
1
storage/sql/config_test.go
Normal file
1
storage/sql/config_test.go
Normal file
@ -0,0 +1 @@
|
||||
package sql
|
487
storage/sql/crud.go
Normal file
487
storage/sql/crud.go
Normal file
@ -0,0 +1,487 @@
|
||||
package sql
|
||||
|
||||
import (
|
||||
"database/sql"
|
||||
"database/sql/driver"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
|
||||
"github.com/coreos/dex/storage"
|
||||
)
|
||||
|
||||
// TODO(ericchiang): The update, insert, and select methods queries are all
|
||||
// very repetivite. Consider creating them programatically.
|
||||
|
||||
// keysRowID is the ID of the only row we expect to populate the "keys" table.
|
||||
const keysRowID = "keys"
|
||||
|
||||
// encoder wraps the underlying value in a JSON marshaler which is automatically
|
||||
// called by the database/sql package.
|
||||
//
|
||||
// s := []string{"planes", "bears"}
|
||||
// err := db.Exec(`insert into t1 (id, things) values (1, $1)`, encoder(s))
|
||||
// if err != nil {
|
||||
// // handle error
|
||||
// }
|
||||
//
|
||||
// var r []byte
|
||||
// err = db.QueryRow(`select things from t1 where id = 1;`).Scan(&r)
|
||||
// if err != nil {
|
||||
// // handle error
|
||||
// }
|
||||
// fmt.Printf("%s\n", r) // ["planes","bears"]
|
||||
//
|
||||
func encoder(i interface{}) driver.Valuer {
|
||||
return jsonEncoder{i}
|
||||
}
|
||||
|
||||
// decoder wraps the underlying value in a JSON unmarshaler which can then be passed
|
||||
// to a database Scan() method.
|
||||
func decoder(i interface{}) sql.Scanner {
|
||||
return jsonDecoder{i}
|
||||
}
|
||||
|
||||
type jsonEncoder struct {
|
||||
i interface{}
|
||||
}
|
||||
|
||||
func (j jsonEncoder) Value() (driver.Value, error) {
|
||||
b, err := json.Marshal(j.i)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("marshal: %v", err)
|
||||
}
|
||||
return b, nil
|
||||
}
|
||||
|
||||
type jsonDecoder struct {
|
||||
i interface{}
|
||||
}
|
||||
|
||||
func (j jsonDecoder) Scan(dest interface{}) error {
|
||||
if dest == nil {
|
||||
return errors.New("nil value")
|
||||
}
|
||||
b, ok := dest.([]byte)
|
||||
if !ok {
|
||||
return fmt.Errorf("expected []byte got %T", dest)
|
||||
}
|
||||
if err := json.Unmarshal(b, &j.i); err != nil {
|
||||
return fmt.Errorf("unmarshal: %v", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// Abstract conn vs trans.
|
||||
type querier interface {
|
||||
QueryRow(query string, args ...interface{}) *sql.Row
|
||||
}
|
||||
|
||||
// Abstract row vs rows.
|
||||
type scanner interface {
|
||||
Scan(dest ...interface{}) error
|
||||
}
|
||||
|
||||
func (c *conn) CreateAuthRequest(a storage.AuthRequest) error {
|
||||
_, err := c.Exec(`
|
||||
insert into auth_request (
|
||||
id, client_id, response_types, scopes, redirect_uri, nonce, state,
|
||||
force_approval_prompt, logged_in,
|
||||
claims_user_id, claims_username, claims_email, claims_email_verified,
|
||||
claims_groups,
|
||||
connector_id, connector_data,
|
||||
expiry
|
||||
)
|
||||
values (
|
||||
$1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12, $13, $14, $15, $16, $17
|
||||
);
|
||||
`,
|
||||
a.ID, a.ClientID, encoder(a.ResponseTypes), encoder(a.Scopes), a.RedirectURI, a.Nonce, a.State,
|
||||
a.ForceApprovalPrompt, a.LoggedIn,
|
||||
a.Claims.UserID, a.Claims.Username, a.Claims.Email, a.Claims.EmailVerified,
|
||||
encoder(a.Claims.Groups),
|
||||
a.ConnectorID, a.ConnectorData,
|
||||
a.Expiry,
|
||||
)
|
||||
if err != nil {
|
||||
return fmt.Errorf("insert auth request: %v", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *conn) UpdateAuthRequest(id string, updater func(a storage.AuthRequest) (storage.AuthRequest, error)) error {
|
||||
return c.ExecTx(func(tx *trans) error {
|
||||
r, err := getAuthRequest(tx, id)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
a, err := updater(r)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
_, err = tx.Exec(`
|
||||
update auth_request
|
||||
set
|
||||
client_id = $1, response_types = $2, scopes = $3, redirect_uri = $4,
|
||||
nonce = $5, state = $6, force_approval_prompt = $7, logged_in = $8,
|
||||
claims_user_id = $9, claims_username = $10, claims_email = $11,
|
||||
claims_email_verified = $12,
|
||||
claims_groups = $13,
|
||||
connector_id = $14, connector_data = $15,
|
||||
expiry = $16
|
||||
where id = $17;
|
||||
`,
|
||||
a.ClientID, encoder(a.ResponseTypes), encoder(a.Scopes), a.RedirectURI, a.Nonce, a.State,
|
||||
a.ForceApprovalPrompt, a.LoggedIn,
|
||||
a.Claims.UserID, a.Claims.Username, a.Claims.Email, a.Claims.EmailVerified,
|
||||
encoder(a.Claims.Groups),
|
||||
a.ConnectorID, a.ConnectorData,
|
||||
a.Expiry, a.ID,
|
||||
)
|
||||
if err != nil {
|
||||
return fmt.Errorf("update auth request: %v", err)
|
||||
}
|
||||
return nil
|
||||
})
|
||||
|
||||
}
|
||||
|
||||
func (c *conn) GetAuthRequest(id string) (storage.AuthRequest, error) {
|
||||
return getAuthRequest(c, id)
|
||||
}
|
||||
|
||||
func getAuthRequest(q querier, id string) (a storage.AuthRequest, err error) {
|
||||
err = q.QueryRow(`
|
||||
select
|
||||
id, client_id, response_types, scopes, redirect_uri, nonce, state,
|
||||
force_approval_prompt, logged_in,
|
||||
claims_user_id, claims_username, claims_email, claims_email_verified,
|
||||
claims_groups,
|
||||
connector_id, connector_data, expiry
|
||||
from auth_request where id = $1;
|
||||
`, id).Scan(
|
||||
&a.ID, &a.ClientID, decoder(&a.ResponseTypes), decoder(&a.Scopes), &a.RedirectURI, &a.Nonce, &a.State,
|
||||
&a.ForceApprovalPrompt, &a.LoggedIn,
|
||||
&a.Claims.UserID, &a.Claims.Username, &a.Claims.Email, &a.Claims.EmailVerified,
|
||||
decoder(&a.Claims.Groups),
|
||||
&a.ConnectorID, &a.ConnectorData, &a.Expiry,
|
||||
)
|
||||
if err != nil {
|
||||
if err == sql.ErrNoRows {
|
||||
return a, storage.ErrNotFound
|
||||
}
|
||||
return a, fmt.Errorf("select auth request: %v", err)
|
||||
}
|
||||
return a, nil
|
||||
}
|
||||
|
||||
func (c *conn) CreateAuthCode(a storage.AuthCode) error {
|
||||
_, err := c.Exec(`
|
||||
insert into auth_code (
|
||||
id, client_id, scopes, nonce, redirect_uri,
|
||||
claims_user_id, claims_username,
|
||||
claims_email, claims_email_verified, claims_groups,
|
||||
connector_id, connector_data,
|
||||
expiry
|
||||
)
|
||||
values ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12, $13);
|
||||
`,
|
||||
a.ID, a.ClientID, encoder(a.Scopes), a.Nonce, a.RedirectURI, a.Claims.UserID,
|
||||
a.Claims.Username, a.Claims.Email, a.Claims.EmailVerified, encoder(a.Claims.Groups),
|
||||
a.ConnectorID, a.ConnectorData, a.Expiry,
|
||||
)
|
||||
return err
|
||||
}
|
||||
|
||||
func (c *conn) GetAuthCode(id string) (a storage.AuthCode, err error) {
|
||||
err = c.QueryRow(`
|
||||
select
|
||||
id, client_id, scopes, nonce, redirect_uri,
|
||||
claims_user_id, claims_username,
|
||||
claims_email, claims_email_verified, claims_groups,
|
||||
connector_id, connector_data,
|
||||
expiry
|
||||
from auth_code where id = $1;
|
||||
`, id).Scan(
|
||||
&a.ID, &a.ClientID, decoder(&a.Scopes), &a.Nonce, &a.RedirectURI, &a.Claims.UserID,
|
||||
&a.Claims.Username, &a.Claims.Email, &a.Claims.EmailVerified, decoder(&a.Claims.Groups),
|
||||
&a.ConnectorID, &a.ConnectorData, &a.Expiry,
|
||||
)
|
||||
if err != nil {
|
||||
if err == sql.ErrNoRows {
|
||||
return a, storage.ErrNotFound
|
||||
}
|
||||
return a, fmt.Errorf("select auth code: %v", err)
|
||||
}
|
||||
return a, nil
|
||||
}
|
||||
|
||||
func (c *conn) CreateRefresh(r storage.RefreshToken) error {
|
||||
_, err := c.Exec(`
|
||||
insert into refresh_token (
|
||||
id, client_id, scopes, nonce,
|
||||
claims_user_id, claims_username, claims_email, claims_email_verified,
|
||||
claims_groups,
|
||||
connector_id, connector_data
|
||||
)
|
||||
values ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11);
|
||||
`,
|
||||
r.RefreshToken, r.ClientID, encoder(r.Scopes), r.Nonce,
|
||||
r.Claims.UserID, r.Claims.Username, r.Claims.Email, r.Claims.EmailVerified,
|
||||
encoder(r.Claims.Groups),
|
||||
r.ConnectorID, r.ConnectorData,
|
||||
)
|
||||
if err != nil {
|
||||
return fmt.Errorf("insert refresh_token: %v", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *conn) GetRefresh(id string) (storage.RefreshToken, error) {
|
||||
return scanRefresh(c.QueryRow(`
|
||||
select
|
||||
id, client_id, scopes, nonce,
|
||||
claims_user_id, claims_username, claims_email, claims_email_verified,
|
||||
claims_groups,
|
||||
connector_id, connector_data
|
||||
from refresh_token where id = $1;
|
||||
`, id))
|
||||
}
|
||||
|
||||
func (c *conn) ListRefreshTokens() ([]storage.RefreshToken, error) {
|
||||
rows, err := c.Query(`
|
||||
select
|
||||
id, client_id, scopes, nonce,
|
||||
claims_user_id, claims_username, claims_email, claims_email_verified,
|
||||
claims_groups,
|
||||
connector_id, connector_data
|
||||
from refresh_token;
|
||||
`)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("query: %v", err)
|
||||
}
|
||||
var tokens []storage.RefreshToken
|
||||
for rows.Next() {
|
||||
r, err := scanRefresh(rows)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
tokens = append(tokens, r)
|
||||
}
|
||||
if err := rows.Err(); err != nil {
|
||||
return nil, fmt.Errorf("scan: %v", err)
|
||||
}
|
||||
return tokens, nil
|
||||
}
|
||||
|
||||
func scanRefresh(s scanner) (r storage.RefreshToken, err error) {
|
||||
err = s.Scan(
|
||||
&r.RefreshToken, &r.ClientID, decoder(&r.Scopes), &r.Nonce,
|
||||
&r.Claims.UserID, &r.Claims.Username, &r.Claims.Email, &r.Claims.EmailVerified,
|
||||
decoder(&r.Claims.Groups),
|
||||
&r.ConnectorID, &r.ConnectorData,
|
||||
)
|
||||
if err != nil {
|
||||
if err == sql.ErrNoRows {
|
||||
return r, storage.ErrNotFound
|
||||
}
|
||||
return r, fmt.Errorf("scan refresh_token: %v", err)
|
||||
}
|
||||
return r, nil
|
||||
}
|
||||
|
||||
func (c *conn) UpdateKeys(updater func(old storage.Keys) (storage.Keys, error)) error {
|
||||
return c.ExecTx(func(tx *trans) error {
|
||||
firstUpdate := false
|
||||
// TODO(ericchiang): errors may cause a transaction be rolled back by the SQL
|
||||
// server. Test this, and consider adding a COUNT() command beforehand.
|
||||
old, err := getKeys(tx)
|
||||
if err != nil {
|
||||
if err != storage.ErrNotFound {
|
||||
return fmt.Errorf("get keys: %v", err)
|
||||
}
|
||||
firstUpdate = true
|
||||
old = storage.Keys{}
|
||||
}
|
||||
|
||||
nk, err := updater(old)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if firstUpdate {
|
||||
_, err = tx.Exec(`
|
||||
insert into keys (
|
||||
id, verification_keys, signing_key, signing_key_pub, next_rotation
|
||||
)
|
||||
values ($1, $2, $3, $4, $5);
|
||||
`,
|
||||
keysRowID, encoder(nk.VerificationKeys), encoder(nk.SigningKey),
|
||||
encoder(nk.SigningKeyPub), nk.NextRotation,
|
||||
)
|
||||
if err != nil {
|
||||
return fmt.Errorf("insert: %v", err)
|
||||
}
|
||||
} else {
|
||||
_, err = tx.Exec(`
|
||||
update keys
|
||||
set
|
||||
verification_keys = $1,
|
||||
signing_key = $2,
|
||||
singing_key_pub = $3,
|
||||
next_rotation = $4
|
||||
where id = $5;
|
||||
`,
|
||||
encoder(nk.VerificationKeys), encoder(nk.SigningKey),
|
||||
encoder(nk.SigningKeyPub), nk.NextRotation, keysRowID,
|
||||
)
|
||||
if err != nil {
|
||||
return fmt.Errorf("update: %v", err)
|
||||
}
|
||||
}
|
||||
return nil
|
||||
})
|
||||
}
|
||||
|
||||
func (c *conn) GetKeys() (keys storage.Keys, err error) {
|
||||
return getKeys(c)
|
||||
}
|
||||
|
||||
func getKeys(q querier) (keys storage.Keys, err error) {
|
||||
err = q.QueryRow(`
|
||||
select
|
||||
verification_keys, signing_key, signing_key_pub, next_rotation
|
||||
from keys
|
||||
where id=$q
|
||||
`, keysRowID).Scan(
|
||||
decoder(&keys.VerificationKeys), decoder(&keys.SigningKey),
|
||||
decoder(&keys.SigningKeyPub), &keys.NextRotation,
|
||||
)
|
||||
if err != nil {
|
||||
if err == sql.ErrNoRows {
|
||||
return keys, storage.ErrNotFound
|
||||
}
|
||||
return keys, fmt.Errorf("query keys: %v", err)
|
||||
}
|
||||
return keys, nil
|
||||
}
|
||||
|
||||
func (c *conn) UpdateClient(id string, updater func(old storage.Client) (storage.Client, error)) error {
|
||||
return c.ExecTx(func(tx *trans) error {
|
||||
cli, err := getClient(tx, id)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
nc, err := updater(cli)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
_, err = tx.Exec(`
|
||||
update client
|
||||
set
|
||||
secret = $1,
|
||||
redirect_uris = $2,
|
||||
trusted_peers = $3,
|
||||
public = $4,
|
||||
name = $5,
|
||||
logo_url = $6
|
||||
where id = $7;
|
||||
`, nc.Secret, encoder(nc.RedirectURIs), encoder(nc.TrustedPeers), nc.Public, nc.Name, nc.LogoURL, id,
|
||||
)
|
||||
if err != nil {
|
||||
return fmt.Errorf("update client: %v", err)
|
||||
}
|
||||
return nil
|
||||
})
|
||||
}
|
||||
|
||||
func (c *conn) CreateClient(cli storage.Client) error {
|
||||
_, err := c.Exec(`
|
||||
insert into client (
|
||||
id, secret, redirect_uris, trusted_peers, public, name, logo_url
|
||||
)
|
||||
values ($1, $2, $3, $4, $5, $6, $7);
|
||||
`,
|
||||
cli.ID, cli.Secret, encoder(cli.RedirectURIs), encoder(cli.TrustedPeers),
|
||||
cli.Public, cli.Name, cli.LogoURL,
|
||||
)
|
||||
if err != nil {
|
||||
return fmt.Errorf("insert client: %v", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func getClient(q querier, id string) (storage.Client, error) {
|
||||
return scanClient(q.QueryRow(`
|
||||
select
|
||||
id, secret, redirect_uris, trusted_peers, public, name, logo_url
|
||||
from client where id = $1;
|
||||
`, id))
|
||||
}
|
||||
|
||||
func (c *conn) GetClient(id string) (storage.Client, error) {
|
||||
return getClient(c, id)
|
||||
}
|
||||
|
||||
func (c *conn) ListClients() ([]storage.Client, error) {
|
||||
rows, err := c.Query(`
|
||||
select
|
||||
id, secret, redirect_uris, trusted_peers, public, name, logo_url
|
||||
from client;
|
||||
`)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
var clients []storage.Client
|
||||
for rows.Next() {
|
||||
cli, err := scanClient(rows)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
clients = append(clients, cli)
|
||||
}
|
||||
if err := rows.Err(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return clients, nil
|
||||
}
|
||||
|
||||
func scanClient(s scanner) (cli storage.Client, err error) {
|
||||
err = s.Scan(
|
||||
&cli.ID, &cli.Secret, decoder(&cli.RedirectURIs), decoder(&cli.TrustedPeers),
|
||||
&cli.Public, &cli.Name, &cli.LogoURL,
|
||||
)
|
||||
if err != nil {
|
||||
if err == sql.ErrNoRows {
|
||||
return cli, storage.ErrNotFound
|
||||
}
|
||||
return cli, fmt.Errorf("get client: %v", err)
|
||||
}
|
||||
return cli, nil
|
||||
}
|
||||
|
||||
func (c *conn) DeleteAuthRequest(id string) error { return c.delete("auth_request", id) }
|
||||
func (c *conn) DeleteAuthCode(id string) error { return c.delete("auth_code", id) }
|
||||
func (c *conn) DeleteClient(id string) error { return c.delete("client", id) }
|
||||
func (c *conn) DeleteRefresh(id string) error { return c.delete("refresh_token", id) }
|
||||
|
||||
// Do NOT call directly. Does not escape table.
|
||||
func (c *conn) delete(table, id string) error {
|
||||
result, err := c.Exec(`delete from `+table+` where id = $1`, id)
|
||||
if err != nil {
|
||||
return fmt.Errorf("delete %s: %v", table, id)
|
||||
}
|
||||
|
||||
// For now mandate that the driver implements RowsAffected. If we ever need to support
|
||||
// a driver that doesn't implement this, we can run this in a transaction with a get beforehand.
|
||||
n, err := result.RowsAffected()
|
||||
if err != nil {
|
||||
return fmt.Errorf("rows affected: %v", err)
|
||||
}
|
||||
if n < 1 {
|
||||
return storage.ErrNotFound
|
||||
}
|
||||
return nil
|
||||
}
|
55
storage/sql/crud_test.go
Normal file
55
storage/sql/crud_test.go
Normal file
@ -0,0 +1,55 @@
|
||||
package sql
|
||||
|
||||
import (
|
||||
"database/sql"
|
||||
"reflect"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestDecoder(t *testing.T) {
|
||||
db, err := sql.Open("sqlite3", ":memory:")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
defer db.Close()
|
||||
|
||||
if _, err := db.Exec(`create table foo ( id integer primary key, bar blob );`); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if _, err := db.Exec(`insert into foo ( id, bar ) values (1, ?);`, []byte(`["a", "b"]`)); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
var got []string
|
||||
if err := db.QueryRow(`select bar from foo where id = 1;`).Scan(decoder(&got)); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
want := []string{"a", "b"}
|
||||
if !reflect.DeepEqual(got, want) {
|
||||
t.Errorf("wanted %q got %q", want, got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestEncoder(t *testing.T) {
|
||||
db, err := sql.Open("sqlite3", ":memory:")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
defer db.Close()
|
||||
|
||||
if _, err := db.Exec(`create table foo ( id integer primary key, bar blob );`); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
put := []string{"a", "b"}
|
||||
if _, err := db.Exec(`insert into foo ( id, bar ) values (1, ?)`, encoder(put)); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
var got []byte
|
||||
if err := db.QueryRow(`select bar from foo where id = 1;`).Scan(&got); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
want := []byte(`["a","b"]`)
|
||||
if !reflect.DeepEqual(got, want) {
|
||||
t.Errorf("wanted %q got %q", want, got)
|
||||
}
|
||||
}
|
24
storage/sql/gc.go
Normal file
24
storage/sql/gc.go
Normal file
@ -0,0 +1,24 @@
|
||||
package sql
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"time"
|
||||
)
|
||||
|
||||
type gc struct {
|
||||
now func() time.Time
|
||||
conn *conn
|
||||
}
|
||||
|
||||
var tablesWithGC = []string{"auth_request", "auth_code"}
|
||||
|
||||
func (gc gc) run() error {
|
||||
for _, table := range tablesWithGC {
|
||||
_, err := gc.conn.Exec(`delete from `+table+` where expiry < $1`, gc.now())
|
||||
if err != nil {
|
||||
return fmt.Errorf("gc %s: %v", table, err)
|
||||
}
|
||||
// TODO(ericchiang): when we have levelled logging print how many rows were gc'd
|
||||
}
|
||||
return nil
|
||||
}
|
53
storage/sql/gc_test.go
Normal file
53
storage/sql/gc_test.go
Normal file
@ -0,0 +1,53 @@
|
||||
package sql
|
||||
|
||||
import (
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/coreos/dex/storage"
|
||||
)
|
||||
|
||||
func TestGC(t *testing.T) {
|
||||
// TODO(ericchiang): Add a GarbageCollect method to the storage interface so
|
||||
// we can write conformance tests instead of directly testing each implementation.
|
||||
s := &SQLite3{":memory:"}
|
||||
conn, err := s.open()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
defer conn.Close()
|
||||
|
||||
clock := time.Now()
|
||||
now := func() time.Time { return clock }
|
||||
|
||||
runGC := (gc{now, conn}).run
|
||||
|
||||
a := storage.AuthRequest{
|
||||
ID: storage.NewID(),
|
||||
Expiry: now().Add(time.Second),
|
||||
}
|
||||
|
||||
if err := conn.CreateAuthRequest(a); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
if err := runGC(); err != nil {
|
||||
t.Errorf("gc failed: %v", err)
|
||||
}
|
||||
|
||||
if _, err := conn.GetAuthRequest(a.ID); err != nil {
|
||||
t.Errorf("failed to get auth request after gc: %v", err)
|
||||
}
|
||||
|
||||
clock = clock.Add(time.Minute)
|
||||
|
||||
if err := runGC(); err != nil {
|
||||
t.Errorf("gc failed: %v", err)
|
||||
}
|
||||
|
||||
if _, err := conn.GetAuthRequest(a.ID); err == nil {
|
||||
t.Errorf("expected error after gc'ing auth request: %v", err)
|
||||
} else if err != storage.ErrNotFound {
|
||||
t.Errorf("expected error storage.NotFound got: %v", err)
|
||||
}
|
||||
}
|
151
storage/sql/migrate.go
Normal file
151
storage/sql/migrate.go
Normal file
@ -0,0 +1,151 @@
|
||||
package sql
|
||||
|
||||
import (
|
||||
"database/sql"
|
||||
"fmt"
|
||||
)
|
||||
|
||||
func (c *conn) migrate() (int, error) {
|
||||
_, err := c.Exec(`
|
||||
create table if not exists migrations (
|
||||
num integer not null,
|
||||
at timestamp not null
|
||||
);
|
||||
`)
|
||||
if err != nil {
|
||||
return 0, fmt.Errorf("creating migration table: %v", err)
|
||||
}
|
||||
|
||||
i := 0
|
||||
done := false
|
||||
for {
|
||||
err := c.ExecTx(func(tx *trans) error {
|
||||
// Within a transaction, perform a single migration.
|
||||
var (
|
||||
num sql.NullInt64
|
||||
n int
|
||||
)
|
||||
if err := tx.QueryRow(`select max(num) from migrations;`).Scan(&num); err != nil {
|
||||
return fmt.Errorf("select max migration: %v", err)
|
||||
}
|
||||
if num.Valid {
|
||||
n = int(num.Int64)
|
||||
}
|
||||
if n >= len(migrations) {
|
||||
done = true
|
||||
return nil
|
||||
}
|
||||
|
||||
migrationNum := n + 1
|
||||
m := migrations[n]
|
||||
if _, err := tx.Exec(m.stmt); err != nil {
|
||||
return fmt.Errorf("migration %d failed: %v", migrationNum, err)
|
||||
}
|
||||
|
||||
q := `insert into migrations (num, at) values ($1, now());`
|
||||
if _, err := tx.Exec(q, migrationNum); err != nil {
|
||||
return fmt.Errorf("update migration table: %v", err)
|
||||
}
|
||||
return nil
|
||||
})
|
||||
if err != nil {
|
||||
return i, err
|
||||
}
|
||||
if done {
|
||||
break
|
||||
}
|
||||
i++
|
||||
}
|
||||
|
||||
return i, nil
|
||||
}
|
||||
|
||||
type migration struct {
|
||||
stmt string
|
||||
// TODO(ericchiang): consider adding additional fields like "forDrivers"
|
||||
}
|
||||
|
||||
// All SQL flavors share migration strategies.
|
||||
var migrations = []migration{
|
||||
{
|
||||
stmt: `
|
||||
create table client (
|
||||
id text not null primary key,
|
||||
secret text not null,
|
||||
redirect_uris bytea not null, -- JSON array of strings
|
||||
trusted_peers bytea not null, -- JSON array of strings
|
||||
public boolean not null,
|
||||
name text not null,
|
||||
logo_url text not null
|
||||
);
|
||||
|
||||
create table auth_request (
|
||||
id text not null primary key,
|
||||
client_id text not null,
|
||||
response_types bytea not null, -- JSON array of strings
|
||||
scopes bytea not null, -- JSON array of strings
|
||||
redirect_uri text not null,
|
||||
nonce text not null,
|
||||
state text not null,
|
||||
force_approval_prompt boolean not null,
|
||||
|
||||
logged_in boolean not null,
|
||||
|
||||
claims_user_id text not null,
|
||||
claims_username text not null,
|
||||
claims_email text not null,
|
||||
claims_email_verified boolean not null,
|
||||
claims_groups bytea not null, -- JSON array of strings
|
||||
|
||||
connector_id text not null,
|
||||
connector_data bytea,
|
||||
|
||||
expiry timestamp not null
|
||||
);
|
||||
|
||||
create table auth_code (
|
||||
id text not null primary key,
|
||||
client_id text not null,
|
||||
scopes bytea not null, -- JSON array of strings
|
||||
nonce text not null,
|
||||
redirect_uri text not null,
|
||||
|
||||
claims_user_id text not null,
|
||||
claims_username text not null,
|
||||
claims_email text not null,
|
||||
claims_email_verified boolean not null,
|
||||
claims_groups bytea not null, -- JSON array of strings
|
||||
|
||||
connector_id text not null,
|
||||
connector_data bytea,
|
||||
|
||||
expiry timestamp not null
|
||||
);
|
||||
|
||||
create table refresh_token (
|
||||
id text not null primary key,
|
||||
client_id text not null,
|
||||
scopes bytea not null, -- JSON array of strings
|
||||
nonce text not null,
|
||||
|
||||
claims_user_id text not null,
|
||||
claims_username text not null,
|
||||
claims_email text not null,
|
||||
claims_email_verified boolean not null,
|
||||
claims_groups bytea not null, -- JSON array of strings
|
||||
|
||||
connector_id text not null,
|
||||
connector_data bytea
|
||||
);
|
||||
|
||||
-- keys is a weird table because we only ever expect there to be a single row
|
||||
create table keys (
|
||||
id text not null primary key,
|
||||
verification_keys bytea not null, -- JSON array
|
||||
signing_key bytea not null, -- JSON object
|
||||
signing_key_pub bytea not null, -- JSON object
|
||||
next_rotation timestamp not null
|
||||
);
|
||||
`,
|
||||
},
|
||||
}
|
25
storage/sql/migrate_test.go
Normal file
25
storage/sql/migrate_test.go
Normal file
@ -0,0 +1,25 @@
|
||||
package sql
|
||||
|
||||
import (
|
||||
"database/sql"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestMigrate(t *testing.T) {
|
||||
db, err := sql.Open("sqlite3", ":memory:")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
defer db.Close()
|
||||
|
||||
c := &conn{db, flavorSQLite3}
|
||||
for _, want := range []int{len(migrations), 0} {
|
||||
got, err := c.migrate()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if got != want {
|
||||
t.Errorf("expected %d migrations, got %d", want, got)
|
||||
}
|
||||
}
|
||||
}
|
152
storage/sql/sql.go
Normal file
152
storage/sql/sql.go
Normal file
@ -0,0 +1,152 @@
|
||||
// Package sql provides SQL implementations of the storage interface.
|
||||
package sql
|
||||
|
||||
import (
|
||||
"database/sql"
|
||||
"regexp"
|
||||
|
||||
"github.com/cockroachdb/cockroach-go/crdb"
|
||||
|
||||
// import third party drivers
|
||||
_ "github.com/go-sql-driver/mysql"
|
||||
_ "github.com/lib/pq"
|
||||
_ "github.com/mattn/go-sqlite3"
|
||||
)
|
||||
|
||||
// flavor represents a specific SQL implementation, and is used to translate query strings
|
||||
// between different drivers. Flavors shouldn't aim to translate all possible SQL statements,
|
||||
// only the specific queries used by the SQL storages.
|
||||
type flavor struct {
|
||||
queryReplacers []replacer
|
||||
|
||||
// Optional function to create and finish a transaction. This is mainly for
|
||||
// cockroachdb support which requires special retry logic provided by their
|
||||
// client package.
|
||||
//
|
||||
// This will be nil for most flavors.
|
||||
//
|
||||
// See: https://github.com/cockroachdb/docs/blob/63761c2e/_includes/app/txn-sample.go#L41-L44
|
||||
executeTx func(db *sql.DB, fn func(*sql.Tx) error) error
|
||||
}
|
||||
|
||||
// A regexp with a replacement string.
|
||||
type replacer struct {
|
||||
re *regexp.Regexp
|
||||
with string
|
||||
}
|
||||
|
||||
// Match a postgres query binds. E.g. "$1", "$12", etc.
|
||||
var bindRegexp = regexp.MustCompile(`\$\d+`)
|
||||
|
||||
func matchLiteral(s string) *regexp.Regexp {
|
||||
return regexp.MustCompile(`\b` + regexp.QuoteMeta(s) + `\b`)
|
||||
}
|
||||
|
||||
var (
|
||||
// The "github.com/lib/pq" driver is the default flavor. All others are
|
||||
// translations of this.
|
||||
flavorPostgres = flavor{}
|
||||
|
||||
flavorSQLite3 = flavor{
|
||||
queryReplacers: []replacer{
|
||||
{bindRegexp, "?"},
|
||||
// Translate for booleans to integers.
|
||||
{matchLiteral("true"), "1"},
|
||||
{matchLiteral("false"), "0"},
|
||||
{matchLiteral("boolean"), "integer"},
|
||||
// Translate other types.
|
||||
{matchLiteral("bytea"), "blob"},
|
||||
// {matchLiteral("timestamp"), "integer"},
|
||||
// SQLite doesn't have a "now()" method, replace with "date('now')"
|
||||
{regexp.MustCompile(`\bnow\(\)`), "date('now')"},
|
||||
},
|
||||
}
|
||||
|
||||
// Incomplete.
|
||||
flavorMySQL = flavor{
|
||||
queryReplacers: []replacer{
|
||||
{bindRegexp, "?"},
|
||||
},
|
||||
}
|
||||
|
||||
// Not tested.
|
||||
flavorCockroach = flavor{
|
||||
executeTx: crdb.ExecuteTx,
|
||||
}
|
||||
)
|
||||
|
||||
func (f flavor) translate(query string) string {
|
||||
// TODO(ericchiang): Heavy cashing.
|
||||
for _, r := range f.queryReplacers {
|
||||
query = r.re.ReplaceAllString(query, r.with)
|
||||
}
|
||||
return query
|
||||
}
|
||||
|
||||
// conn is the main database connection.
|
||||
type conn struct {
|
||||
db *sql.DB
|
||||
flavor flavor
|
||||
}
|
||||
|
||||
func (c *conn) Close() error {
|
||||
return c.db.Close()
|
||||
}
|
||||
|
||||
// conn implements the same method signatures as encoding/sql.DB.
|
||||
|
||||
func (c *conn) Exec(query string, args ...interface{}) (sql.Result, error) {
|
||||
query = c.flavor.translate(query)
|
||||
return c.db.Exec(query, args...)
|
||||
}
|
||||
|
||||
func (c *conn) Query(query string, args ...interface{}) (*sql.Rows, error) {
|
||||
query = c.flavor.translate(query)
|
||||
return c.db.Query(query, args...)
|
||||
}
|
||||
|
||||
func (c *conn) QueryRow(query string, args ...interface{}) *sql.Row {
|
||||
query = c.flavor.translate(query)
|
||||
return c.db.QueryRow(query, args...)
|
||||
}
|
||||
|
||||
// ExecTx runs a method which operates on a transaction.
|
||||
func (c *conn) ExecTx(fn func(tx *trans) error) error {
|
||||
if c.flavor.executeTx != nil {
|
||||
return c.flavor.executeTx(c.db, func(sqlTx *sql.Tx) error {
|
||||
return fn(&trans{sqlTx, c})
|
||||
})
|
||||
}
|
||||
|
||||
sqlTx, err := c.db.Begin()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if err := fn(&trans{sqlTx, c}); err != nil {
|
||||
sqlTx.Rollback()
|
||||
return err
|
||||
}
|
||||
return sqlTx.Commit()
|
||||
}
|
||||
|
||||
type trans struct {
|
||||
tx *sql.Tx
|
||||
c *conn
|
||||
}
|
||||
|
||||
// trans implements the same method signatures as encoding/sql.Tx.
|
||||
|
||||
func (t *trans) Exec(query string, args ...interface{}) (sql.Result, error) {
|
||||
query = t.c.flavor.translate(query)
|
||||
return t.tx.Exec(query, args...)
|
||||
}
|
||||
|
||||
func (t *trans) Query(query string, args ...interface{}) (*sql.Rows, error) {
|
||||
query = t.c.flavor.translate(query)
|
||||
return t.tx.Query(query, args...)
|
||||
}
|
||||
|
||||
func (t *trans) QueryRow(query string, args ...interface{}) *sql.Row {
|
||||
query = t.c.flavor.translate(query)
|
||||
return t.tx.QueryRow(query, args...)
|
||||
}
|
55
storage/sql/sql_test.go
Normal file
55
storage/sql/sql_test.go
Normal file
@ -0,0 +1,55 @@
|
||||
package sql
|
||||
|
||||
import "testing"
|
||||
|
||||
func TestTranslate(t *testing.T) {
|
||||
tests := []struct {
|
||||
testCase string
|
||||
flavor flavor
|
||||
query string
|
||||
exp string
|
||||
}{
|
||||
{
|
||||
"sqlite3 query bind replacement",
|
||||
flavorSQLite3,
|
||||
`select foo from bar where foo.zam = $1;`,
|
||||
`select foo from bar where foo.zam = ?;`,
|
||||
},
|
||||
{
|
||||
"sqlite3 query bind replacement at newline",
|
||||
flavorSQLite3,
|
||||
`select foo from bar where foo.zam = $1`,
|
||||
`select foo from bar where foo.zam = ?`,
|
||||
},
|
||||
{
|
||||
"sqlite3 query true",
|
||||
flavorSQLite3,
|
||||
`select foo from bar where foo.zam = true`,
|
||||
`select foo from bar where foo.zam = 1`,
|
||||
},
|
||||
{
|
||||
"sqlite3 query false",
|
||||
flavorSQLite3,
|
||||
`select foo from bar where foo.zam = false`,
|
||||
`select foo from bar where foo.zam = 0`,
|
||||
},
|
||||
{
|
||||
"sqlite3 bytea",
|
||||
flavorSQLite3,
|
||||
`"connector_data" bytea not null,`,
|
||||
`"connector_data" blob not null,`,
|
||||
},
|
||||
{
|
||||
"sqlite3 now",
|
||||
flavorSQLite3,
|
||||
`now(),`,
|
||||
`date('now'),`,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range tests {
|
||||
if got := tc.flavor.translate(tc.query); got != tc.exp {
|
||||
t.Errorf("%s: want=%q, got=%q", tc.testCase, tc.exp, got)
|
||||
}
|
||||
}
|
||||
}
|
Reference in New Issue
Block a user