From ca114f7812dbc4f3de157161c722180e467b7b46 Mon Sep 17 00:00:00 2001 From: Daniel Dao Date: Tue, 3 Oct 2017 14:33:58 +0100 Subject: [PATCH] storage: add etcd storage This patch adds etcd storage implementation. This should be useful in environments where - we dont want to depends on a separate, hard to maintain SQL cluster - we dont want to incur the overhead of talking to kubernetes apiservers - kubernetes is not available yet, or if kubernetes depends on dex to perform authentication and the operator would like to remove any circular dependency if possible. --- cmd/dex/config.go | 2 + storage/etcd/config.go | 92 +++++++ storage/etcd/etcd.go | 532 ++++++++++++++++++++++++++++++++++++++ storage/etcd/etcd_test.go | 94 +++++++ storage/etcd/standup.sh | 109 ++++++++ storage/etcd/types.go | 229 ++++++++++++++++ 6 files changed, 1058 insertions(+) create mode 100644 storage/etcd/config.go create mode 100644 storage/etcd/etcd.go create mode 100644 storage/etcd/etcd_test.go create mode 100755 storage/etcd/standup.sh create mode 100644 storage/etcd/types.go diff --git a/cmd/dex/config.go b/cmd/dex/config.go index 59733bc4..a4e1338c 100644 --- a/cmd/dex/config.go +++ b/cmd/dex/config.go @@ -11,6 +11,7 @@ import ( "github.com/coreos/dex/server" "github.com/coreos/dex/storage" + "github.com/coreos/dex/storage/etcd" "github.com/coreos/dex/storage/kubernetes" "github.com/coreos/dex/storage/memory" "github.com/coreos/dex/storage/sql" @@ -124,6 +125,7 @@ type StorageConfig interface { } var storages = map[string]func() StorageConfig{ + "etcd": func() StorageConfig { return new(etcd.Etcd) }, "kubernetes": func() StorageConfig { return new(kubernetes.Config) }, "memory": func() StorageConfig { return new(memory.Config) }, "sqlite3": func() StorageConfig { return new(sql.SQLite3) }, diff --git a/storage/etcd/config.go b/storage/etcd/config.go new file mode 100644 index 00000000..0c52227a --- /dev/null +++ b/storage/etcd/config.go @@ -0,0 +1,92 @@ +package etcd + +import ( + "time" + + "github.com/coreos/dex/storage" + "github.com/coreos/etcd/clientv3" + "github.com/coreos/etcd/clientv3/namespace" + "github.com/coreos/etcd/pkg/transport" + "github.com/sirupsen/logrus" +) + +var ( + defaultDialTimeout = 2 * time.Second +) + +// SSL represents SSL options for etcd databases. +type SSL struct { + ServerName string + CAFile string + KeyFile string + CertFile string +} + +// Etcd options for connecting to etcd databases. +// If you are using a shared etcd cluster for storage, it might be useful to +// configure an etcd namespace either via Namespace field or using `etcd grpc-proxy +// --namespace=` +type Etcd struct { + Endpoints []string + Namespace string + Username string + Password string + SSL SSL +} + +// Open creates a new storage implementation backed by Etcd +func (p *Etcd) Open(logger logrus.FieldLogger) (storage.Storage, error) { + return p.open(logger) +} + +func (p *Etcd) open(logger logrus.FieldLogger) (*conn, error) { + cfg := clientv3.Config{ + Endpoints: p.Endpoints, + DialTimeout: defaultDialTimeout * time.Second, + Username: p.Username, + Password: p.Password, + } + + var cfgtls *transport.TLSInfo + tlsinfo := transport.TLSInfo{} + if p.SSL.CertFile != "" { + tlsinfo.CertFile = p.SSL.CertFile + cfgtls = &tlsinfo + } + + if p.SSL.KeyFile != "" { + tlsinfo.KeyFile = p.SSL.KeyFile + cfgtls = &tlsinfo + } + + if p.SSL.CAFile != "" { + tlsinfo.CAFile = p.SSL.CAFile + cfgtls = &tlsinfo + } + + if p.SSL.ServerName != "" { + tlsinfo.ServerName = p.SSL.ServerName + cfgtls = &tlsinfo + } + + if cfgtls != nil { + clientTLS, err := cfgtls.ClientConfig() + if err != nil { + return nil, err + } + cfg.TLS = clientTLS + } + + db, err := clientv3.New(cfg) + if err != nil { + return nil, err + } + if len(p.Namespace) > 0 { + db.KV = namespace.NewKV(db.KV, p.Namespace) + } + c := &conn{ + db: db, + logger: logger, + } + return c, nil +} diff --git a/storage/etcd/etcd.go b/storage/etcd/etcd.go new file mode 100644 index 00000000..7ad35ea5 --- /dev/null +++ b/storage/etcd/etcd.go @@ -0,0 +1,532 @@ +package etcd + +import ( + "context" + "encoding/json" + "fmt" + "strings" + "time" + + "github.com/coreos/dex/storage" + "github.com/coreos/etcd/clientv3" + "github.com/sirupsen/logrus" +) + +const ( + clientPrefix = "client/" + authCodePrefix = "auth_code/" + refreshTokenPrefix = "refresh_token/" + authRequestPrefix = "auth_req/" + passwordPrefix = "password/" + offlineSessionPrefix = "offline_session/" + connectorPrefix = "connector/" + keysName = "openid-connect-keys" + + // defaultStorageTimeout will be applied to all storage's operations. + defaultStorageTimeout = 5 * time.Second +) + +type conn struct { + db *clientv3.Client + logger logrus.FieldLogger +} + +func (c *conn) Close() error { + return c.db.Close() +} + +func (c *conn) GarbageCollect(now time.Time) (result storage.GCResult, err error) { + ctx, cancel := context.WithTimeout(context.Background(), defaultStorageTimeout) + defer cancel() + authRequests, err := c.listAuthRequests(ctx) + if err != nil { + return result, err + } + + var delErr error + for _, authRequest := range authRequests { + if now.After(authRequest.Expiry) { + if err := c.deleteKey(ctx, keyID(authRequestPrefix, authRequest.ID)); err != nil { + c.logger.Errorf("failed to delete auth request: %v", err) + delErr = fmt.Errorf("failed to delete auth request: %v", err) + } + result.AuthRequests++ + } + } + if delErr != nil { + return result, delErr + } + + authCodes, err := c.listAuthCodes(ctx) + if err != nil { + return result, err + } + + for _, authCode := range authCodes { + if now.After(authCode.Expiry) { + if err := c.deleteKey(ctx, keyID(authCodePrefix, authCode.ID)); err != nil { + c.logger.Errorf("failed to delete auth code %v", err) + delErr = fmt.Errorf("failed to delete auth code: %v", err) + } + result.AuthCodes++ + } + } + return result, delErr +} + +func (c *conn) CreateAuthRequest(a storage.AuthRequest) error { + ctx, cancel := context.WithTimeout(context.Background(), defaultStorageTimeout) + defer cancel() + return c.txnCreate(ctx, keyID(authRequestPrefix, a.ID), fromStorageAuthRequest(a)) +} + +func (c *conn) GetAuthRequest(id string) (a storage.AuthRequest, err error) { + ctx, cancel := context.WithTimeout(context.Background(), defaultStorageTimeout) + defer cancel() + var req AuthRequest + if err = c.getKey(ctx, keyID(authRequestPrefix, id), &req); err != nil { + return + } + return toStorageAuthRequest(req), nil +} + +func (c *conn) UpdateAuthRequest(id string, updater func(a storage.AuthRequest) (storage.AuthRequest, error)) error { + ctx, cancel := context.WithTimeout(context.Background(), defaultStorageTimeout) + defer cancel() + return c.txnUpdate(ctx, keyID(authRequestPrefix, id), func(currentValue []byte) ([]byte, error) { + var current AuthRequest + if len(currentValue) > 0 { + if err := json.Unmarshal(currentValue, ¤t); err != nil { + return nil, err + } + } + updated, err := updater(toStorageAuthRequest(current)) + if err != nil { + return nil, err + } + return json.Marshal(fromStorageAuthRequest(updated)) + }) +} + +func (c *conn) DeleteAuthRequest(id string) error { + ctx, cancel := context.WithTimeout(context.Background(), defaultStorageTimeout) + defer cancel() + return c.deleteKey(ctx, keyID(authRequestPrefix, id)) +} + +func (c *conn) CreateAuthCode(a storage.AuthCode) error { + ctx, cancel := context.WithTimeout(context.Background(), defaultStorageTimeout) + defer cancel() + return c.txnCreate(ctx, keyID(authCodePrefix, a.ID), fromStorageAuthCode(a)) +} + +func (c *conn) GetAuthCode(id string) (a storage.AuthCode, err error) { + ctx, cancel := context.WithTimeout(context.Background(), defaultStorageTimeout) + defer cancel() + err = c.getKey(ctx, keyID(authCodePrefix, id), &a) + return a, err +} + +func (c *conn) DeleteAuthCode(id string) error { + ctx, cancel := context.WithTimeout(context.Background(), defaultStorageTimeout) + defer cancel() + return c.deleteKey(ctx, keyID(authCodePrefix, id)) +} + +func (c *conn) CreateRefresh(r storage.RefreshToken) error { + ctx, cancel := context.WithTimeout(context.Background(), defaultStorageTimeout) + defer cancel() + return c.txnCreate(ctx, keyID(refreshTokenPrefix, r.ID), fromStorageRefreshToken(r)) +} + +func (c *conn) GetRefresh(id string) (r storage.RefreshToken, err error) { + ctx, cancel := context.WithTimeout(context.Background(), defaultStorageTimeout) + defer cancel() + var token RefreshToken + if err = c.getKey(ctx, keyID(refreshTokenPrefix, id), &token); err != nil { + return + } + return toStorageRefreshToken(token), nil +} + +func (c *conn) UpdateRefreshToken(id string, updater func(old storage.RefreshToken) (storage.RefreshToken, error)) error { + ctx, cancel := context.WithTimeout(context.Background(), defaultStorageTimeout) + defer cancel() + return c.txnUpdate(ctx, keyID(refreshTokenPrefix, id), func(currentValue []byte) ([]byte, error) { + var current RefreshToken + if len(currentValue) > 0 { + if err := json.Unmarshal([]byte(currentValue), ¤t); err != nil { + return nil, err + } + } + updated, err := updater(toStorageRefreshToken(current)) + if err != nil { + return nil, err + } + return json.Marshal(fromStorageRefreshToken(updated)) + }) +} + +func (c *conn) DeleteRefresh(id string) error { + ctx, cancel := context.WithTimeout(context.Background(), defaultStorageTimeout) + defer cancel() + return c.deleteKey(ctx, keyID(refreshTokenPrefix, id)) +} + +func (c *conn) ListRefreshTokens() (tokens []storage.RefreshToken, err error) { + ctx, cancel := context.WithTimeout(context.Background(), defaultStorageTimeout) + defer cancel() + res, err := c.db.Get(ctx, refreshTokenPrefix, clientv3.WithPrefix()) + if err != nil { + return tokens, err + } + for _, v := range res.Kvs { + var token RefreshToken + if err = json.Unmarshal(v.Value, &token); err != nil { + return tokens, err + } + tokens = append(tokens, toStorageRefreshToken(token)) + } + return tokens, nil +} + +func (c *conn) CreateClient(cli storage.Client) error { + ctx, cancel := context.WithTimeout(context.Background(), defaultStorageTimeout) + defer cancel() + return c.txnCreate(ctx, keyID(clientPrefix, cli.ID), cli) +} + +func (c *conn) GetClient(id string) (cli storage.Client, err error) { + ctx, cancel := context.WithTimeout(context.Background(), defaultStorageTimeout) + defer cancel() + err = c.getKey(ctx, keyID(clientPrefix, id), &cli) + return cli, err +} + +func (c *conn) UpdateClient(id string, updater func(old storage.Client) (storage.Client, error)) error { + ctx, cancel := context.WithTimeout(context.Background(), defaultStorageTimeout) + defer cancel() + return c.txnUpdate(ctx, keyID(clientPrefix, id), func(currentValue []byte) ([]byte, error) { + var current storage.Client + if len(currentValue) > 0 { + if err := json.Unmarshal(currentValue, ¤t); err != nil { + return nil, err + } + } + updated, err := updater(current) + if err != nil { + return nil, err + } + return json.Marshal(updated) + }) +} + +func (c *conn) DeleteClient(id string) error { + ctx, cancel := context.WithTimeout(context.Background(), defaultStorageTimeout) + defer cancel() + return c.deleteKey(ctx, keyID(clientPrefix, id)) +} + +func (c *conn) ListClients() (clients []storage.Client, err error) { + ctx, cancel := context.WithTimeout(context.Background(), defaultStorageTimeout) + defer cancel() + res, err := c.db.Get(ctx, clientPrefix, clientv3.WithPrefix()) + if err != nil { + return clients, err + } + for _, v := range res.Kvs { + var cli storage.Client + if err = json.Unmarshal(v.Value, &cli); err != nil { + return clients, err + } + clients = append(clients, cli) + } + return clients, nil +} + +func (c *conn) CreatePassword(p storage.Password) error { + ctx, cancel := context.WithTimeout(context.Background(), defaultStorageTimeout) + defer cancel() + return c.txnCreate(ctx, passwordPrefix+strings.ToLower(p.Email), p) +} + +func (c *conn) GetPassword(email string) (p storage.Password, err error) { + ctx, cancel := context.WithTimeout(context.Background(), defaultStorageTimeout) + defer cancel() + err = c.getKey(ctx, keyEmail(passwordPrefix, email), &p) + return p, err +} + +func (c *conn) UpdatePassword(email string, updater func(p storage.Password) (storage.Password, error)) error { + ctx, cancel := context.WithTimeout(context.Background(), defaultStorageTimeout) + defer cancel() + return c.txnUpdate(ctx, keyEmail(passwordPrefix, email), func(currentValue []byte) ([]byte, error) { + var current storage.Password + if len(currentValue) > 0 { + if err := json.Unmarshal(currentValue, ¤t); err != nil { + return nil, err + } + } + updated, err := updater(current) + if err != nil { + return nil, err + } + return json.Marshal(updated) + }) +} + +func (c *conn) DeletePassword(email string) error { + ctx, cancel := context.WithTimeout(context.Background(), defaultStorageTimeout) + defer cancel() + return c.deleteKey(ctx, keyEmail(passwordPrefix, email)) +} + +func (c *conn) ListPasswords() (passwords []storage.Password, err error) { + ctx, cancel := context.WithTimeout(context.Background(), defaultStorageTimeout) + defer cancel() + res, err := c.db.Get(ctx, passwordPrefix, clientv3.WithPrefix()) + if err != nil { + return passwords, err + } + for _, v := range res.Kvs { + var p storage.Password + if err = json.Unmarshal(v.Value, &p); err != nil { + return passwords, err + } + passwords = append(passwords, p) + } + return passwords, nil +} + +func (c *conn) CreateOfflineSessions(s storage.OfflineSessions) error { + ctx, cancel := context.WithTimeout(context.Background(), defaultStorageTimeout) + defer cancel() + return c.txnCreate(ctx, keySession(offlineSessionPrefix, s.UserID, s.ConnID), fromStorageOfflineSessions(s)) +} + +func (c *conn) UpdateOfflineSessions(userID string, connID string, updater func(s storage.OfflineSessions) (storage.OfflineSessions, error)) error { + ctx, cancel := context.WithTimeout(context.Background(), defaultStorageTimeout) + defer cancel() + return c.txnUpdate(ctx, keySession(offlineSessionPrefix, userID, connID), func(currentValue []byte) ([]byte, error) { + var current OfflineSessions + if len(currentValue) > 0 { + if err := json.Unmarshal(currentValue, ¤t); err != nil { + return nil, err + } + } + updated, err := updater(toStorageOfflineSessions(current)) + if err != nil { + return nil, err + } + return json.Marshal(fromStorageOfflineSessions(updated)) + }) +} + +func (c *conn) GetOfflineSessions(userID string, connID string) (s storage.OfflineSessions, err error) { + ctx, cancel := context.WithTimeout(context.Background(), defaultStorageTimeout) + defer cancel() + var os OfflineSessions + if err = c.getKey(ctx, keySession(offlineSessionPrefix, userID, connID), &os); err != nil { + return + } + return toStorageOfflineSessions(os), nil +} + +func (c *conn) DeleteOfflineSessions(userID string, connID string) error { + ctx, cancel := context.WithTimeout(context.Background(), defaultStorageTimeout) + defer cancel() + return c.deleteKey(ctx, keySession(offlineSessionPrefix, userID, connID)) +} + +func (c *conn) CreateConnector(connector storage.Connector) error { + ctx, cancel := context.WithTimeout(context.Background(), defaultStorageTimeout) + defer cancel() + return c.txnCreate(ctx, keyID(connectorPrefix, connector.ID), connector) +} + +func (c *conn) GetConnector(id string) (conn storage.Connector, err error) { + ctx, cancel := context.WithTimeout(context.Background(), defaultStorageTimeout) + defer cancel() + err = c.getKey(ctx, keyID(connectorPrefix, id), &conn) + return conn, err +} + +func (c *conn) UpdateConnector(id string, updater func(s storage.Connector) (storage.Connector, error)) error { + ctx, cancel := context.WithTimeout(context.Background(), defaultStorageTimeout) + defer cancel() + return c.txnUpdate(ctx, keyID(connectorPrefix, id), func(currentValue []byte) ([]byte, error) { + var current storage.Connector + if len(currentValue) > 0 { + if err := json.Unmarshal(currentValue, ¤t); err != nil { + return nil, err + } + } + updated, err := updater(current) + if err != nil { + return nil, err + } + return json.Marshal(updated) + }) +} + +func (c *conn) DeleteConnector(id string) error { + ctx, cancel := context.WithTimeout(context.Background(), defaultStorageTimeout) + defer cancel() + return c.deleteKey(ctx, keyID(connectorPrefix, id)) +} + +func (c *conn) ListConnectors() (connectors []storage.Connector, err error) { + ctx, cancel := context.WithTimeout(context.Background(), defaultStorageTimeout) + defer cancel() + res, err := c.db.Get(ctx, connectorPrefix, clientv3.WithPrefix()) + if err != nil { + return nil, err + } + for _, v := range res.Kvs { + var c storage.Connector + if err = json.Unmarshal(v.Value, &c); err != nil { + return nil, err + } + connectors = append(connectors, c) + } + return connectors, nil +} + +func (c *conn) GetKeys() (keys storage.Keys, err error) { + ctx, cancel := context.WithTimeout(context.Background(), defaultStorageTimeout) + defer cancel() + res, err := c.db.Get(ctx, keysName) + if err != nil { + return keys, err + } + if res.Count > 0 && len(res.Kvs) > 0 { + err = json.Unmarshal(res.Kvs[0].Value, &keys) + } + return keys, err +} + +func (c *conn) UpdateKeys(updater func(old storage.Keys) (storage.Keys, error)) error { + ctx, cancel := context.WithTimeout(context.Background(), defaultStorageTimeout) + defer cancel() + return c.txnUpdate(ctx, keysName, func(currentValue []byte) ([]byte, error) { + var current storage.Keys + if len(currentValue) > 0 { + if err := json.Unmarshal(currentValue, ¤t); err != nil { + return nil, err + } + } + updated, err := updater(current) + if err != nil { + return nil, err + } + return json.Marshal(updated) + }) +} + +func (c *conn) deleteKey(ctx context.Context, key string) error { + res, err := c.db.Delete(ctx, key) + if err != nil { + return err + } + if res.Deleted == 0 { + return storage.ErrNotFound + } + return nil +} + +func (c *conn) getKey(ctx context.Context, key string, value interface{}) error { + r, err := c.db.Get(ctx, key) + if err != nil { + return err + } + if r.Count == 0 { + return storage.ErrNotFound + } + return json.Unmarshal(r.Kvs[0].Value, value) +} + +func (c *conn) listAuthRequests(ctx context.Context) (reqs []AuthRequest, err error) { + res, err := c.db.Get(ctx, authRequestPrefix, clientv3.WithPrefix()) + if err != nil { + return reqs, err + } + for _, v := range res.Kvs { + var r AuthRequest + if err = json.Unmarshal(v.Value, &r); err != nil { + return reqs, err + } + reqs = append(reqs, r) + } + return reqs, nil +} + +func (c *conn) listAuthCodes(ctx context.Context) (codes []AuthCode, err error) { + res, err := c.db.Get(ctx, authCodePrefix, clientv3.WithPrefix()) + if err != nil { + return codes, err + } + for _, v := range res.Kvs { + var c AuthCode + if err = json.Unmarshal(v.Value, &c); err != nil { + return codes, err + } + codes = append(codes, c) + } + return codes, nil +} + +func (c *conn) txnCreate(ctx context.Context, key string, value interface{}) error { + b, err := json.Marshal(value) + if err != nil { + return err + } + txn := c.db.Txn(ctx) + res, err := txn. + If(clientv3.Compare(clientv3.CreateRevision(key), "=", 0)). + Then(clientv3.OpPut(key, string(b))). + Commit() + if err != nil { + return err + } + if !res.Succeeded { + return storage.ErrAlreadyExists + } + return nil +} + +func (c *conn) txnUpdate(ctx context.Context, key string, update func(current []byte) ([]byte, error)) error { + getResp, err := c.db.Get(ctx, key) + if err != nil { + return err + } + var currentValue []byte + var modRev int64 + if len(getResp.Kvs) > 0 { + currentValue = getResp.Kvs[0].Value + modRev = getResp.Kvs[0].ModRevision + } + + updatedValue, err := update(currentValue) + if err != nil { + return err + } + + txn := c.db.Txn(ctx) + updateResp, err := txn. + If(clientv3.Compare(clientv3.ModRevision(key), "=", modRev)). + Then(clientv3.OpPut(key, string(updatedValue))). + Commit() + if err != nil { + return err + } + if !updateResp.Succeeded { + return fmt.Errorf("failed to update key=%q: concurrent conflicting update happened", key) + } + return nil +} + +func keyID(prefix, id string) string { return prefix + id } +func keyEmail(prefix, email string) string { return prefix + strings.ToLower(email) } +func keySession(prefix, userID, connID string) string { + return prefix + strings.ToLower(userID+"|"+connID) +} diff --git a/storage/etcd/etcd_test.go b/storage/etcd/etcd_test.go new file mode 100644 index 00000000..7029e4ca --- /dev/null +++ b/storage/etcd/etcd_test.go @@ -0,0 +1,94 @@ +package etcd + +import ( + "context" + "fmt" + "os" + "runtime" + "strings" + "testing" + "time" + + "github.com/coreos/dex/storage" + "github.com/coreos/dex/storage/conformance" + "github.com/coreos/etcd/clientv3" + "github.com/sirupsen/logrus" +) + +func withTimeout(t time.Duration, f func()) { + c := make(chan struct{}) + defer close(c) + + go func() { + select { + case <-c: + case <-time.After(t): + // Dump a stack trace of the program. Useful for debugging deadlocks. + buf := make([]byte, 2<<20) + fmt.Fprintf(os.Stderr, "%s\n", buf[:runtime.Stack(buf, true)]) + panic("test took too long") + } + }() + + f() +} + +func cleanDB(c *conn) error { + ctx := context.TODO() + for _, prefix := range []string{ + clientPrefix, + authCodePrefix, + refreshTokenPrefix, + authRequestPrefix, + passwordPrefix, + offlineSessionPrefix, + connectorPrefix, + } { + _, err := c.db.Delete(ctx, prefix, clientv3.WithPrefix()) + if err != nil { + return err + } + } + return nil +} + +var logger = &logrus.Logger{ + Out: os.Stderr, + Formatter: &logrus.TextFormatter{DisableColors: true}, + Level: logrus.DebugLevel, +} + +func TestEtcd(t *testing.T) { + testEtcdEnv := "DEX_ETCD_ENDPOINTS" + endpointsStr := os.Getenv(testEtcdEnv) + if endpointsStr == "" { + t.Skipf("test environment variable %q not set, skipping", testEtcdEnv) + return + } + endpoints := strings.Split(endpointsStr, ",") + + newStorage := func() storage.Storage { + s := &Etcd{ + Endpoints: endpoints, + } + conn, err := s.open(logger) + if err != nil { + fmt.Fprintln(os.Stdout, err) + t.Fatal(err) + } + + if err := cleanDB(conn); err != nil { + fmt.Fprintln(os.Stdout, err) + t.Fatal(err) + } + return conn + } + + withTimeout(time.Second*10, func() { + conformance.RunTests(t, newStorage) + }) + + withTimeout(time.Minute*1, func() { + conformance.RunTransactionTests(t, newStorage) + }) +} diff --git a/storage/etcd/standup.sh b/storage/etcd/standup.sh new file mode 100755 index 00000000..1944b111 --- /dev/null +++ b/storage/etcd/standup.sh @@ -0,0 +1,109 @@ +#!/bin/bash + +if [ "$EUID" -ne 0 ] + then echo "Please run as root" + exit +fi + +function usage { + cat << EOF >> /dev/stderr +Usage: sudo ./standup.sh [create|destroy] [etcd] + +This is a script for standing up test databases. It uses systemd to daemonize +rkt containers running on a local loopback IP. + +The general workflow is to create a daemonized container, use the output to set +the test environment variables, run the tests, then destroy the container. + + sudo ./standup.sh create etcd + # Copy environment variables and run tests. + go test -v -i # always install test dependencies + go test -v + sudo ./standup.sh destroy etcd + +EOF + exit 2 +} + +function main { + if [ "$#" -ne 2 ]; then + usage + exit 2 + fi + + case "$1" in + "create") + case "$2" in + "etcd") + create_etcd;; + *) + usage + exit 2 + ;; + esac + ;; + "destroy") + case "$2" in + "etcd") + destroy_etcd;; + *) + usage + exit 2 + ;; + esac + ;; + *) + usage + exit 2 + ;; + esac +} + +function wait_for_file { + while [ ! -f $1 ]; do + sleep 1 + done +} + +function wait_for_container { + while [ -z "$( rkt list --full | grep $1 | grep running )" ]; do + sleep 1 + done +} + +function create_etcd { + UUID_FILE=/tmp/dex-etcd-uuid + if [ -f $UUID_FILE ]; then + echo "etcd database already exists, try ./standup.sh destroy etcd" + exit 2 + fi + + echo "Starting etcd . To view progress run:" + echo "" + echo " journalctl -fu dex-etcd" + echo "" + UNIFIED_CGROUP_HIERARCHY=no \ + systemd-run --unit=dex-etcd \ + rkt run --uuid-file-save=$UUID_FILE --insecure-options=image \ + --net=host \ + docker://quay.io/coreos/etcd:v3.2.9 + + wait_for_file $UUID_FILE + + UUID=$( cat $UUID_FILE ) + wait_for_container $UUID + echo "To run tests export the following environment variables:" + echo "" + echo " export DEX_ETCD_ENDPOINTS=http://localhost:2379" + echo "" +} + +function destroy_etcd { + UUID_FILE=/tmp/dex-etcd-uuid + systemctl stop dex-etcd + rkt rm --uuid-file=$UUID_FILE + rm $UUID_FILE +} + + +main $@ diff --git a/storage/etcd/types.go b/storage/etcd/types.go new file mode 100644 index 00000000..acca7f37 --- /dev/null +++ b/storage/etcd/types.go @@ -0,0 +1,229 @@ +package etcd + +import ( + "time" + + "github.com/coreos/dex/storage" + jose "gopkg.in/square/go-jose.v2" +) + +// AuthCode is a mirrored struct from storage with JSON struct tags +type AuthCode struct { + ID string `json:"ID"` + ClientID string `json:"clientID"` + RedirectURI string `json:"redirectURI"` + Nonce string `json:"nonce,omitempty"` + Scopes []string `json:"scopes,omitempty"` + + ConnectorID string `json:"connectorID,omitempty"` + ConnectorData []byte `json:"connectorData,omitempty"` + Claims Claims `json:"claims,omitempty"` + + Expiry time.Time `json:"expiry"` +} + +func fromStorageAuthCode(a storage.AuthCode) AuthCode { + return AuthCode{ + ID: a.ID, + ClientID: a.ClientID, + RedirectURI: a.RedirectURI, + ConnectorID: a.ConnectorID, + ConnectorData: a.ConnectorData, + Nonce: a.Nonce, + Scopes: a.Scopes, + Claims: fromStorageClaims(a.Claims), + Expiry: a.Expiry, + } +} + +// AuthRequest is a mirrored struct from storage with JSON struct tags +type AuthRequest struct { + ID string `json:"id"` + ClientID string `json:"client_id"` + + ResponseTypes []string `json:"response_types"` + Scopes []string `json:"scopes"` + RedirectURI string `json:"redirect_uri"` + Nonce string `json:"nonce"` + State string `json:"state"` + + ForceApprovalPrompt bool `json:"force_approval_prompt"` + + Expiry time.Time `json:"expiry"` + + LoggedIn bool `json:"logged_in"` + + Claims Claims `json:"claims"` + + ConnectorID string `json:"connector_id"` + ConnectorData []byte `json:"connector_data"` +} + +func fromStorageAuthRequest(a storage.AuthRequest) AuthRequest { + return AuthRequest{ + ID: a.ID, + ClientID: a.ClientID, + ResponseTypes: a.ResponseTypes, + Scopes: a.Scopes, + RedirectURI: a.RedirectURI, + Nonce: a.Nonce, + State: a.State, + ForceApprovalPrompt: a.ForceApprovalPrompt, + Expiry: a.Expiry, + LoggedIn: a.LoggedIn, + Claims: fromStorageClaims(a.Claims), + ConnectorID: a.ConnectorID, + ConnectorData: a.ConnectorData, + } +} + +func toStorageAuthRequest(a AuthRequest) storage.AuthRequest { + return storage.AuthRequest{ + ID: a.ID, + ClientID: a.ClientID, + ResponseTypes: a.ResponseTypes, + Scopes: a.Scopes, + RedirectURI: a.RedirectURI, + Nonce: a.Nonce, + State: a.State, + ForceApprovalPrompt: a.ForceApprovalPrompt, + LoggedIn: a.LoggedIn, + ConnectorID: a.ConnectorID, + ConnectorData: a.ConnectorData, + Expiry: a.Expiry, + Claims: toStorageClaims(a.Claims), + } +} + +// RefreshToken is a mirrored struct from storage with JSON struct tags +type RefreshToken struct { + ID string `json:"id"` + + Token string `json:"token"` + + CreatedAt time.Time `json:"created_at"` + LastUsed time.Time `json:"last_used"` + + ClientID string `json:"client_id"` + + ConnectorID string `json:"connector_id"` + ConnectorData []byte `json:"connector_data"` + Claims Claims `json:"claims"` + + Scopes []string `json:"scopes"` + + Nonce string `json:"nonce"` +} + +func toStorageRefreshToken(r RefreshToken) storage.RefreshToken { + return storage.RefreshToken{ + ID: r.ID, + Token: r.Token, + CreatedAt: r.CreatedAt, + LastUsed: r.LastUsed, + ClientID: r.ClientID, + ConnectorID: r.ConnectorID, + ConnectorData: r.ConnectorData, + Scopes: r.Scopes, + Nonce: r.Nonce, + Claims: toStorageClaims(r.Claims), + } +} + +func fromStorageRefreshToken(r storage.RefreshToken) RefreshToken { + return RefreshToken{ + ID: r.ID, + Token: r.Token, + CreatedAt: r.CreatedAt, + LastUsed: r.LastUsed, + ClientID: r.ClientID, + ConnectorID: r.ConnectorID, + ConnectorData: r.ConnectorData, + Scopes: r.Scopes, + Nonce: r.Nonce, + Claims: fromStorageClaims(r.Claims), + } +} + +// Claims is a mirrored struct from storage with JSON struct tags. +type Claims struct { + UserID string `json:"userID"` + Username string `json:"username"` + Email string `json:"email"` + EmailVerified bool `json:"emailVerified"` + Groups []string `json:"groups,omitempty"` +} + +func fromStorageClaims(i storage.Claims) Claims { + return Claims{ + UserID: i.UserID, + Username: i.Username, + Email: i.Email, + EmailVerified: i.EmailVerified, + Groups: i.Groups, + } +} + +func toStorageClaims(i Claims) storage.Claims { + return storage.Claims{ + UserID: i.UserID, + Username: i.Username, + Email: i.Email, + EmailVerified: i.EmailVerified, + Groups: i.Groups, + } +} + +// Keys is a mirrored struct from storage with JSON struct tags +type Keys struct { + SigningKey *jose.JSONWebKey `json:"signing_key,omitempty"` + SigningKeyPub *jose.JSONWebKey `json:"signing_key_pub,omitempty"` + VerificationKeys []storage.VerificationKey `json:"verification_keys"` + NextRotation time.Time `json:"next_rotation"` +} + +func fromStorageKeys(keys storage.Keys) Keys { + return Keys{ + SigningKey: keys.SigningKey, + SigningKeyPub: keys.SigningKeyPub, + VerificationKeys: keys.VerificationKeys, + NextRotation: keys.NextRotation, + } +} + +func toStorageKeys(keys Keys) storage.Keys { + return storage.Keys{ + SigningKey: keys.SigningKey, + SigningKeyPub: keys.SigningKeyPub, + VerificationKeys: keys.VerificationKeys, + NextRotation: keys.NextRotation, + } +} + +// OfflineSessions is a mirrored struct from storage with JSON struct tags +type OfflineSessions struct { + UserID string `json:"user_id,omitempty"` + ConnID string `json:"conn_id,omitempty"` + Refresh map[string]*storage.RefreshTokenRef `json:"refresh,omitempty"` +} + +func fromStorageOfflineSessions(o storage.OfflineSessions) OfflineSessions { + return OfflineSessions{ + UserID: o.UserID, + ConnID: o.ConnID, + Refresh: o.Refresh, + } +} + +func toStorageOfflineSessions(o OfflineSessions) storage.OfflineSessions { + s := storage.OfflineSessions{ + UserID: o.UserID, + ConnID: o.ConnID, + Refresh: o.Refresh, + } + if s.Refresh == nil { + // Server code assumes this will be non-nil. + s.Refresh = make(map[string]*storage.RefreshTokenRef) + } + return s +}