Merge pull request #1108 from dqminh/etcd-storage
Add etcd backed storage
This commit is contained in:
@@ -628,9 +628,21 @@ func testConnectorCRUD(t *testing.T, s storage.Storage) {
|
||||
c1.Type = "oidc"
|
||||
getAndCompare(id1, c1)
|
||||
|
||||
if _, err := s.ListConnectors(); err != nil {
|
||||
t.Fatalf("failed to list connectors: %v", err)
|
||||
connectorList := []storage.Connector{c1, c2}
|
||||
listAndCompare := func(want []storage.Connector) {
|
||||
connectors, err := s.ListConnectors()
|
||||
if err != nil {
|
||||
t.Errorf("list connectors: %v", err)
|
||||
return
|
||||
}
|
||||
sort.Slice(connectors, func(i, j int) bool {
|
||||
return connectors[i].Name < connectors[j].Name
|
||||
})
|
||||
if diff := pretty.Compare(want, connectors); diff != "" {
|
||||
t.Errorf("password list retrieved from storage did not match: %s", diff)
|
||||
}
|
||||
}
|
||||
listAndCompare(connectorList)
|
||||
|
||||
if err := s.DeleteConnector(c1.ID); err != nil {
|
||||
t.Fatalf("failed to delete connector: %v", err)
|
||||
|
92
storage/etcd/config.go
Normal file
92
storage/etcd/config.go
Normal file
@@ -0,0 +1,92 @@
|
||||
package etcd
|
||||
|
||||
import (
|
||||
"time"
|
||||
|
||||
"github.com/coreos/dex/storage"
|
||||
"github.com/coreos/etcd/clientv3"
|
||||
"github.com/coreos/etcd/clientv3/namespace"
|
||||
"github.com/coreos/etcd/pkg/transport"
|
||||
"github.com/sirupsen/logrus"
|
||||
)
|
||||
|
||||
var (
|
||||
defaultDialTimeout = 2 * time.Second
|
||||
)
|
||||
|
||||
// SSL represents SSL options for etcd databases.
|
||||
type SSL struct {
|
||||
ServerName string `json:"serverName" yaml:"serverName"`
|
||||
CAFile string `json:"caFile" yaml:"caFile"`
|
||||
KeyFile string `json:"keyFile" yaml:"keyFile"`
|
||||
CertFile string `json:"certFile" yaml:"certFile"`
|
||||
}
|
||||
|
||||
// Etcd options for connecting to etcd databases.
|
||||
// If you are using a shared etcd cluster for storage, it might be useful to
|
||||
// configure an etcd namespace either via Namespace field or using `etcd grpc-proxy
|
||||
// --namespace=<prefix>`
|
||||
type Etcd struct {
|
||||
Endpoints []string `json:"endpoints" yaml:"endpoints"`
|
||||
Namespace string `json:"namespace" yaml:"namespace"`
|
||||
Username string `json:"username" yaml:"username"`
|
||||
Password string `json:"password" yaml:"password"`
|
||||
SSL SSL `json:"ssl" yaml:"ssl"`
|
||||
}
|
||||
|
||||
// Open creates a new storage implementation backed by Etcd
|
||||
func (p *Etcd) Open(logger logrus.FieldLogger) (storage.Storage, error) {
|
||||
return p.open(logger)
|
||||
}
|
||||
|
||||
func (p *Etcd) open(logger logrus.FieldLogger) (*conn, error) {
|
||||
cfg := clientv3.Config{
|
||||
Endpoints: p.Endpoints,
|
||||
DialTimeout: defaultDialTimeout * time.Second,
|
||||
Username: p.Username,
|
||||
Password: p.Password,
|
||||
}
|
||||
|
||||
var cfgtls *transport.TLSInfo
|
||||
tlsinfo := transport.TLSInfo{}
|
||||
if p.SSL.CertFile != "" {
|
||||
tlsinfo.CertFile = p.SSL.CertFile
|
||||
cfgtls = &tlsinfo
|
||||
}
|
||||
|
||||
if p.SSL.KeyFile != "" {
|
||||
tlsinfo.KeyFile = p.SSL.KeyFile
|
||||
cfgtls = &tlsinfo
|
||||
}
|
||||
|
||||
if p.SSL.CAFile != "" {
|
||||
tlsinfo.CAFile = p.SSL.CAFile
|
||||
cfgtls = &tlsinfo
|
||||
}
|
||||
|
||||
if p.SSL.ServerName != "" {
|
||||
tlsinfo.ServerName = p.SSL.ServerName
|
||||
cfgtls = &tlsinfo
|
||||
}
|
||||
|
||||
if cfgtls != nil {
|
||||
clientTLS, err := cfgtls.ClientConfig()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
cfg.TLS = clientTLS
|
||||
}
|
||||
|
||||
db, err := clientv3.New(cfg)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if len(p.Namespace) > 0 {
|
||||
db.KV = namespace.NewKV(db.KV, p.Namespace)
|
||||
}
|
||||
c := &conn{
|
||||
db: db,
|
||||
logger: logger,
|
||||
}
|
||||
return c, nil
|
||||
}
|
532
storage/etcd/etcd.go
Normal file
532
storage/etcd/etcd.go
Normal file
@@ -0,0 +1,532 @@
|
||||
package etcd
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/coreos/dex/storage"
|
||||
"github.com/coreos/etcd/clientv3"
|
||||
"github.com/sirupsen/logrus"
|
||||
)
|
||||
|
||||
const (
|
||||
clientPrefix = "client/"
|
||||
authCodePrefix = "auth_code/"
|
||||
refreshTokenPrefix = "refresh_token/"
|
||||
authRequestPrefix = "auth_req/"
|
||||
passwordPrefix = "password/"
|
||||
offlineSessionPrefix = "offline_session/"
|
||||
connectorPrefix = "connector/"
|
||||
keysName = "openid-connect-keys"
|
||||
|
||||
// defaultStorageTimeout will be applied to all storage's operations.
|
||||
defaultStorageTimeout = 5 * time.Second
|
||||
)
|
||||
|
||||
type conn struct {
|
||||
db *clientv3.Client
|
||||
logger logrus.FieldLogger
|
||||
}
|
||||
|
||||
func (c *conn) Close() error {
|
||||
return c.db.Close()
|
||||
}
|
||||
|
||||
func (c *conn) GarbageCollect(now time.Time) (result storage.GCResult, err error) {
|
||||
ctx, cancel := context.WithTimeout(context.Background(), defaultStorageTimeout)
|
||||
defer cancel()
|
||||
authRequests, err := c.listAuthRequests(ctx)
|
||||
if err != nil {
|
||||
return result, err
|
||||
}
|
||||
|
||||
var delErr error
|
||||
for _, authRequest := range authRequests {
|
||||
if now.After(authRequest.Expiry) {
|
||||
if err := c.deleteKey(ctx, keyID(authRequestPrefix, authRequest.ID)); err != nil {
|
||||
c.logger.Errorf("failed to delete auth request: %v", err)
|
||||
delErr = fmt.Errorf("failed to delete auth request: %v", err)
|
||||
}
|
||||
result.AuthRequests++
|
||||
}
|
||||
}
|
||||
if delErr != nil {
|
||||
return result, delErr
|
||||
}
|
||||
|
||||
authCodes, err := c.listAuthCodes(ctx)
|
||||
if err != nil {
|
||||
return result, err
|
||||
}
|
||||
|
||||
for _, authCode := range authCodes {
|
||||
if now.After(authCode.Expiry) {
|
||||
if err := c.deleteKey(ctx, keyID(authCodePrefix, authCode.ID)); err != nil {
|
||||
c.logger.Errorf("failed to delete auth code %v", err)
|
||||
delErr = fmt.Errorf("failed to delete auth code: %v", err)
|
||||
}
|
||||
result.AuthCodes++
|
||||
}
|
||||
}
|
||||
return result, delErr
|
||||
}
|
||||
|
||||
func (c *conn) CreateAuthRequest(a storage.AuthRequest) error {
|
||||
ctx, cancel := context.WithTimeout(context.Background(), defaultStorageTimeout)
|
||||
defer cancel()
|
||||
return c.txnCreate(ctx, keyID(authRequestPrefix, a.ID), fromStorageAuthRequest(a))
|
||||
}
|
||||
|
||||
func (c *conn) GetAuthRequest(id string) (a storage.AuthRequest, err error) {
|
||||
ctx, cancel := context.WithTimeout(context.Background(), defaultStorageTimeout)
|
||||
defer cancel()
|
||||
var req AuthRequest
|
||||
if err = c.getKey(ctx, keyID(authRequestPrefix, id), &req); err != nil {
|
||||
return
|
||||
}
|
||||
return toStorageAuthRequest(req), nil
|
||||
}
|
||||
|
||||
func (c *conn) UpdateAuthRequest(id string, updater func(a storage.AuthRequest) (storage.AuthRequest, error)) error {
|
||||
ctx, cancel := context.WithTimeout(context.Background(), defaultStorageTimeout)
|
||||
defer cancel()
|
||||
return c.txnUpdate(ctx, keyID(authRequestPrefix, id), func(currentValue []byte) ([]byte, error) {
|
||||
var current AuthRequest
|
||||
if len(currentValue) > 0 {
|
||||
if err := json.Unmarshal(currentValue, ¤t); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
updated, err := updater(toStorageAuthRequest(current))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return json.Marshal(fromStorageAuthRequest(updated))
|
||||
})
|
||||
}
|
||||
|
||||
func (c *conn) DeleteAuthRequest(id string) error {
|
||||
ctx, cancel := context.WithTimeout(context.Background(), defaultStorageTimeout)
|
||||
defer cancel()
|
||||
return c.deleteKey(ctx, keyID(authRequestPrefix, id))
|
||||
}
|
||||
|
||||
func (c *conn) CreateAuthCode(a storage.AuthCode) error {
|
||||
ctx, cancel := context.WithTimeout(context.Background(), defaultStorageTimeout)
|
||||
defer cancel()
|
||||
return c.txnCreate(ctx, keyID(authCodePrefix, a.ID), fromStorageAuthCode(a))
|
||||
}
|
||||
|
||||
func (c *conn) GetAuthCode(id string) (a storage.AuthCode, err error) {
|
||||
ctx, cancel := context.WithTimeout(context.Background(), defaultStorageTimeout)
|
||||
defer cancel()
|
||||
err = c.getKey(ctx, keyID(authCodePrefix, id), &a)
|
||||
return a, err
|
||||
}
|
||||
|
||||
func (c *conn) DeleteAuthCode(id string) error {
|
||||
ctx, cancel := context.WithTimeout(context.Background(), defaultStorageTimeout)
|
||||
defer cancel()
|
||||
return c.deleteKey(ctx, keyID(authCodePrefix, id))
|
||||
}
|
||||
|
||||
func (c *conn) CreateRefresh(r storage.RefreshToken) error {
|
||||
ctx, cancel := context.WithTimeout(context.Background(), defaultStorageTimeout)
|
||||
defer cancel()
|
||||
return c.txnCreate(ctx, keyID(refreshTokenPrefix, r.ID), fromStorageRefreshToken(r))
|
||||
}
|
||||
|
||||
func (c *conn) GetRefresh(id string) (r storage.RefreshToken, err error) {
|
||||
ctx, cancel := context.WithTimeout(context.Background(), defaultStorageTimeout)
|
||||
defer cancel()
|
||||
var token RefreshToken
|
||||
if err = c.getKey(ctx, keyID(refreshTokenPrefix, id), &token); err != nil {
|
||||
return
|
||||
}
|
||||
return toStorageRefreshToken(token), nil
|
||||
}
|
||||
|
||||
func (c *conn) UpdateRefreshToken(id string, updater func(old storage.RefreshToken) (storage.RefreshToken, error)) error {
|
||||
ctx, cancel := context.WithTimeout(context.Background(), defaultStorageTimeout)
|
||||
defer cancel()
|
||||
return c.txnUpdate(ctx, keyID(refreshTokenPrefix, id), func(currentValue []byte) ([]byte, error) {
|
||||
var current RefreshToken
|
||||
if len(currentValue) > 0 {
|
||||
if err := json.Unmarshal([]byte(currentValue), ¤t); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
updated, err := updater(toStorageRefreshToken(current))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return json.Marshal(fromStorageRefreshToken(updated))
|
||||
})
|
||||
}
|
||||
|
||||
func (c *conn) DeleteRefresh(id string) error {
|
||||
ctx, cancel := context.WithTimeout(context.Background(), defaultStorageTimeout)
|
||||
defer cancel()
|
||||
return c.deleteKey(ctx, keyID(refreshTokenPrefix, id))
|
||||
}
|
||||
|
||||
func (c *conn) ListRefreshTokens() (tokens []storage.RefreshToken, err error) {
|
||||
ctx, cancel := context.WithTimeout(context.Background(), defaultStorageTimeout)
|
||||
defer cancel()
|
||||
res, err := c.db.Get(ctx, refreshTokenPrefix, clientv3.WithPrefix())
|
||||
if err != nil {
|
||||
return tokens, err
|
||||
}
|
||||
for _, v := range res.Kvs {
|
||||
var token RefreshToken
|
||||
if err = json.Unmarshal(v.Value, &token); err != nil {
|
||||
return tokens, err
|
||||
}
|
||||
tokens = append(tokens, toStorageRefreshToken(token))
|
||||
}
|
||||
return tokens, nil
|
||||
}
|
||||
|
||||
func (c *conn) CreateClient(cli storage.Client) error {
|
||||
ctx, cancel := context.WithTimeout(context.Background(), defaultStorageTimeout)
|
||||
defer cancel()
|
||||
return c.txnCreate(ctx, keyID(clientPrefix, cli.ID), cli)
|
||||
}
|
||||
|
||||
func (c *conn) GetClient(id string) (cli storage.Client, err error) {
|
||||
ctx, cancel := context.WithTimeout(context.Background(), defaultStorageTimeout)
|
||||
defer cancel()
|
||||
err = c.getKey(ctx, keyID(clientPrefix, id), &cli)
|
||||
return cli, err
|
||||
}
|
||||
|
||||
func (c *conn) UpdateClient(id string, updater func(old storage.Client) (storage.Client, error)) error {
|
||||
ctx, cancel := context.WithTimeout(context.Background(), defaultStorageTimeout)
|
||||
defer cancel()
|
||||
return c.txnUpdate(ctx, keyID(clientPrefix, id), func(currentValue []byte) ([]byte, error) {
|
||||
var current storage.Client
|
||||
if len(currentValue) > 0 {
|
||||
if err := json.Unmarshal(currentValue, ¤t); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
updated, err := updater(current)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return json.Marshal(updated)
|
||||
})
|
||||
}
|
||||
|
||||
func (c *conn) DeleteClient(id string) error {
|
||||
ctx, cancel := context.WithTimeout(context.Background(), defaultStorageTimeout)
|
||||
defer cancel()
|
||||
return c.deleteKey(ctx, keyID(clientPrefix, id))
|
||||
}
|
||||
|
||||
func (c *conn) ListClients() (clients []storage.Client, err error) {
|
||||
ctx, cancel := context.WithTimeout(context.Background(), defaultStorageTimeout)
|
||||
defer cancel()
|
||||
res, err := c.db.Get(ctx, clientPrefix, clientv3.WithPrefix())
|
||||
if err != nil {
|
||||
return clients, err
|
||||
}
|
||||
for _, v := range res.Kvs {
|
||||
var cli storage.Client
|
||||
if err = json.Unmarshal(v.Value, &cli); err != nil {
|
||||
return clients, err
|
||||
}
|
||||
clients = append(clients, cli)
|
||||
}
|
||||
return clients, nil
|
||||
}
|
||||
|
||||
func (c *conn) CreatePassword(p storage.Password) error {
|
||||
ctx, cancel := context.WithTimeout(context.Background(), defaultStorageTimeout)
|
||||
defer cancel()
|
||||
return c.txnCreate(ctx, passwordPrefix+strings.ToLower(p.Email), p)
|
||||
}
|
||||
|
||||
func (c *conn) GetPassword(email string) (p storage.Password, err error) {
|
||||
ctx, cancel := context.WithTimeout(context.Background(), defaultStorageTimeout)
|
||||
defer cancel()
|
||||
err = c.getKey(ctx, keyEmail(passwordPrefix, email), &p)
|
||||
return p, err
|
||||
}
|
||||
|
||||
func (c *conn) UpdatePassword(email string, updater func(p storage.Password) (storage.Password, error)) error {
|
||||
ctx, cancel := context.WithTimeout(context.Background(), defaultStorageTimeout)
|
||||
defer cancel()
|
||||
return c.txnUpdate(ctx, keyEmail(passwordPrefix, email), func(currentValue []byte) ([]byte, error) {
|
||||
var current storage.Password
|
||||
if len(currentValue) > 0 {
|
||||
if err := json.Unmarshal(currentValue, ¤t); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
updated, err := updater(current)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return json.Marshal(updated)
|
||||
})
|
||||
}
|
||||
|
||||
func (c *conn) DeletePassword(email string) error {
|
||||
ctx, cancel := context.WithTimeout(context.Background(), defaultStorageTimeout)
|
||||
defer cancel()
|
||||
return c.deleteKey(ctx, keyEmail(passwordPrefix, email))
|
||||
}
|
||||
|
||||
func (c *conn) ListPasswords() (passwords []storage.Password, err error) {
|
||||
ctx, cancel := context.WithTimeout(context.Background(), defaultStorageTimeout)
|
||||
defer cancel()
|
||||
res, err := c.db.Get(ctx, passwordPrefix, clientv3.WithPrefix())
|
||||
if err != nil {
|
||||
return passwords, err
|
||||
}
|
||||
for _, v := range res.Kvs {
|
||||
var p storage.Password
|
||||
if err = json.Unmarshal(v.Value, &p); err != nil {
|
||||
return passwords, err
|
||||
}
|
||||
passwords = append(passwords, p)
|
||||
}
|
||||
return passwords, nil
|
||||
}
|
||||
|
||||
func (c *conn) CreateOfflineSessions(s storage.OfflineSessions) error {
|
||||
ctx, cancel := context.WithTimeout(context.Background(), defaultStorageTimeout)
|
||||
defer cancel()
|
||||
return c.txnCreate(ctx, keySession(offlineSessionPrefix, s.UserID, s.ConnID), fromStorageOfflineSessions(s))
|
||||
}
|
||||
|
||||
func (c *conn) UpdateOfflineSessions(userID string, connID string, updater func(s storage.OfflineSessions) (storage.OfflineSessions, error)) error {
|
||||
ctx, cancel := context.WithTimeout(context.Background(), defaultStorageTimeout)
|
||||
defer cancel()
|
||||
return c.txnUpdate(ctx, keySession(offlineSessionPrefix, userID, connID), func(currentValue []byte) ([]byte, error) {
|
||||
var current OfflineSessions
|
||||
if len(currentValue) > 0 {
|
||||
if err := json.Unmarshal(currentValue, ¤t); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
updated, err := updater(toStorageOfflineSessions(current))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return json.Marshal(fromStorageOfflineSessions(updated))
|
||||
})
|
||||
}
|
||||
|
||||
func (c *conn) GetOfflineSessions(userID string, connID string) (s storage.OfflineSessions, err error) {
|
||||
ctx, cancel := context.WithTimeout(context.Background(), defaultStorageTimeout)
|
||||
defer cancel()
|
||||
var os OfflineSessions
|
||||
if err = c.getKey(ctx, keySession(offlineSessionPrefix, userID, connID), &os); err != nil {
|
||||
return
|
||||
}
|
||||
return toStorageOfflineSessions(os), nil
|
||||
}
|
||||
|
||||
func (c *conn) DeleteOfflineSessions(userID string, connID string) error {
|
||||
ctx, cancel := context.WithTimeout(context.Background(), defaultStorageTimeout)
|
||||
defer cancel()
|
||||
return c.deleteKey(ctx, keySession(offlineSessionPrefix, userID, connID))
|
||||
}
|
||||
|
||||
func (c *conn) CreateConnector(connector storage.Connector) error {
|
||||
ctx, cancel := context.WithTimeout(context.Background(), defaultStorageTimeout)
|
||||
defer cancel()
|
||||
return c.txnCreate(ctx, keyID(connectorPrefix, connector.ID), connector)
|
||||
}
|
||||
|
||||
func (c *conn) GetConnector(id string) (conn storage.Connector, err error) {
|
||||
ctx, cancel := context.WithTimeout(context.Background(), defaultStorageTimeout)
|
||||
defer cancel()
|
||||
err = c.getKey(ctx, keyID(connectorPrefix, id), &conn)
|
||||
return conn, err
|
||||
}
|
||||
|
||||
func (c *conn) UpdateConnector(id string, updater func(s storage.Connector) (storage.Connector, error)) error {
|
||||
ctx, cancel := context.WithTimeout(context.Background(), defaultStorageTimeout)
|
||||
defer cancel()
|
||||
return c.txnUpdate(ctx, keyID(connectorPrefix, id), func(currentValue []byte) ([]byte, error) {
|
||||
var current storage.Connector
|
||||
if len(currentValue) > 0 {
|
||||
if err := json.Unmarshal(currentValue, ¤t); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
updated, err := updater(current)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return json.Marshal(updated)
|
||||
})
|
||||
}
|
||||
|
||||
func (c *conn) DeleteConnector(id string) error {
|
||||
ctx, cancel := context.WithTimeout(context.Background(), defaultStorageTimeout)
|
||||
defer cancel()
|
||||
return c.deleteKey(ctx, keyID(connectorPrefix, id))
|
||||
}
|
||||
|
||||
func (c *conn) ListConnectors() (connectors []storage.Connector, err error) {
|
||||
ctx, cancel := context.WithTimeout(context.Background(), defaultStorageTimeout)
|
||||
defer cancel()
|
||||
res, err := c.db.Get(ctx, connectorPrefix, clientv3.WithPrefix())
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
for _, v := range res.Kvs {
|
||||
var c storage.Connector
|
||||
if err = json.Unmarshal(v.Value, &c); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
connectors = append(connectors, c)
|
||||
}
|
||||
return connectors, nil
|
||||
}
|
||||
|
||||
func (c *conn) GetKeys() (keys storage.Keys, err error) {
|
||||
ctx, cancel := context.WithTimeout(context.Background(), defaultStorageTimeout)
|
||||
defer cancel()
|
||||
res, err := c.db.Get(ctx, keysName)
|
||||
if err != nil {
|
||||
return keys, err
|
||||
}
|
||||
if res.Count > 0 && len(res.Kvs) > 0 {
|
||||
err = json.Unmarshal(res.Kvs[0].Value, &keys)
|
||||
}
|
||||
return keys, err
|
||||
}
|
||||
|
||||
func (c *conn) UpdateKeys(updater func(old storage.Keys) (storage.Keys, error)) error {
|
||||
ctx, cancel := context.WithTimeout(context.Background(), defaultStorageTimeout)
|
||||
defer cancel()
|
||||
return c.txnUpdate(ctx, keysName, func(currentValue []byte) ([]byte, error) {
|
||||
var current storage.Keys
|
||||
if len(currentValue) > 0 {
|
||||
if err := json.Unmarshal(currentValue, ¤t); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
updated, err := updater(current)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return json.Marshal(updated)
|
||||
})
|
||||
}
|
||||
|
||||
func (c *conn) deleteKey(ctx context.Context, key string) error {
|
||||
res, err := c.db.Delete(ctx, key)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if res.Deleted == 0 {
|
||||
return storage.ErrNotFound
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *conn) getKey(ctx context.Context, key string, value interface{}) error {
|
||||
r, err := c.db.Get(ctx, key)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if r.Count == 0 {
|
||||
return storage.ErrNotFound
|
||||
}
|
||||
return json.Unmarshal(r.Kvs[0].Value, value)
|
||||
}
|
||||
|
||||
func (c *conn) listAuthRequests(ctx context.Context) (reqs []AuthRequest, err error) {
|
||||
res, err := c.db.Get(ctx, authRequestPrefix, clientv3.WithPrefix())
|
||||
if err != nil {
|
||||
return reqs, err
|
||||
}
|
||||
for _, v := range res.Kvs {
|
||||
var r AuthRequest
|
||||
if err = json.Unmarshal(v.Value, &r); err != nil {
|
||||
return reqs, err
|
||||
}
|
||||
reqs = append(reqs, r)
|
||||
}
|
||||
return reqs, nil
|
||||
}
|
||||
|
||||
func (c *conn) listAuthCodes(ctx context.Context) (codes []AuthCode, err error) {
|
||||
res, err := c.db.Get(ctx, authCodePrefix, clientv3.WithPrefix())
|
||||
if err != nil {
|
||||
return codes, err
|
||||
}
|
||||
for _, v := range res.Kvs {
|
||||
var c AuthCode
|
||||
if err = json.Unmarshal(v.Value, &c); err != nil {
|
||||
return codes, err
|
||||
}
|
||||
codes = append(codes, c)
|
||||
}
|
||||
return codes, nil
|
||||
}
|
||||
|
||||
func (c *conn) txnCreate(ctx context.Context, key string, value interface{}) error {
|
||||
b, err := json.Marshal(value)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
txn := c.db.Txn(ctx)
|
||||
res, err := txn.
|
||||
If(clientv3.Compare(clientv3.CreateRevision(key), "=", 0)).
|
||||
Then(clientv3.OpPut(key, string(b))).
|
||||
Commit()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if !res.Succeeded {
|
||||
return storage.ErrAlreadyExists
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *conn) txnUpdate(ctx context.Context, key string, update func(current []byte) ([]byte, error)) error {
|
||||
getResp, err := c.db.Get(ctx, key)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
var currentValue []byte
|
||||
var modRev int64
|
||||
if len(getResp.Kvs) > 0 {
|
||||
currentValue = getResp.Kvs[0].Value
|
||||
modRev = getResp.Kvs[0].ModRevision
|
||||
}
|
||||
|
||||
updatedValue, err := update(currentValue)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
txn := c.db.Txn(ctx)
|
||||
updateResp, err := txn.
|
||||
If(clientv3.Compare(clientv3.ModRevision(key), "=", modRev)).
|
||||
Then(clientv3.OpPut(key, string(updatedValue))).
|
||||
Commit()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if !updateResp.Succeeded {
|
||||
return fmt.Errorf("failed to update key=%q: concurrent conflicting update happened", key)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func keyID(prefix, id string) string { return prefix + id }
|
||||
func keyEmail(prefix, email string) string { return prefix + strings.ToLower(email) }
|
||||
func keySession(prefix, userID, connID string) string {
|
||||
return prefix + strings.ToLower(userID+"|"+connID)
|
||||
}
|
94
storage/etcd/etcd_test.go
Normal file
94
storage/etcd/etcd_test.go
Normal file
@@ -0,0 +1,94 @@
|
||||
package etcd
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"os"
|
||||
"runtime"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/coreos/dex/storage"
|
||||
"github.com/coreos/dex/storage/conformance"
|
||||
"github.com/coreos/etcd/clientv3"
|
||||
"github.com/sirupsen/logrus"
|
||||
)
|
||||
|
||||
func withTimeout(t time.Duration, f func()) {
|
||||
c := make(chan struct{})
|
||||
defer close(c)
|
||||
|
||||
go func() {
|
||||
select {
|
||||
case <-c:
|
||||
case <-time.After(t):
|
||||
// Dump a stack trace of the program. Useful for debugging deadlocks.
|
||||
buf := make([]byte, 2<<20)
|
||||
fmt.Fprintf(os.Stderr, "%s\n", buf[:runtime.Stack(buf, true)])
|
||||
panic("test took too long")
|
||||
}
|
||||
}()
|
||||
|
||||
f()
|
||||
}
|
||||
|
||||
func cleanDB(c *conn) error {
|
||||
ctx := context.TODO()
|
||||
for _, prefix := range []string{
|
||||
clientPrefix,
|
||||
authCodePrefix,
|
||||
refreshTokenPrefix,
|
||||
authRequestPrefix,
|
||||
passwordPrefix,
|
||||
offlineSessionPrefix,
|
||||
connectorPrefix,
|
||||
} {
|
||||
_, err := c.db.Delete(ctx, prefix, clientv3.WithPrefix())
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
var logger = &logrus.Logger{
|
||||
Out: os.Stderr,
|
||||
Formatter: &logrus.TextFormatter{DisableColors: true},
|
||||
Level: logrus.DebugLevel,
|
||||
}
|
||||
|
||||
func TestEtcd(t *testing.T) {
|
||||
testEtcdEnv := "DEX_ETCD_ENDPOINTS"
|
||||
endpointsStr := os.Getenv(testEtcdEnv)
|
||||
if endpointsStr == "" {
|
||||
t.Skipf("test environment variable %q not set, skipping", testEtcdEnv)
|
||||
return
|
||||
}
|
||||
endpoints := strings.Split(endpointsStr, ",")
|
||||
|
||||
newStorage := func() storage.Storage {
|
||||
s := &Etcd{
|
||||
Endpoints: endpoints,
|
||||
}
|
||||
conn, err := s.open(logger)
|
||||
if err != nil {
|
||||
fmt.Fprintln(os.Stdout, err)
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
if err := cleanDB(conn); err != nil {
|
||||
fmt.Fprintln(os.Stdout, err)
|
||||
t.Fatal(err)
|
||||
}
|
||||
return conn
|
||||
}
|
||||
|
||||
withTimeout(time.Second*10, func() {
|
||||
conformance.RunTests(t, newStorage)
|
||||
})
|
||||
|
||||
withTimeout(time.Minute*1, func() {
|
||||
conformance.RunTransactionTests(t, newStorage)
|
||||
})
|
||||
}
|
109
storage/etcd/standup.sh
Executable file
109
storage/etcd/standup.sh
Executable file
@@ -0,0 +1,109 @@
|
||||
#!/bin/bash
|
||||
|
||||
if [ "$EUID" -ne 0 ]
|
||||
then echo "Please run as root"
|
||||
exit
|
||||
fi
|
||||
|
||||
function usage {
|
||||
cat << EOF >> /dev/stderr
|
||||
Usage: sudo ./standup.sh [create|destroy] [etcd]
|
||||
|
||||
This is a script for standing up test databases. It uses systemd to daemonize
|
||||
rkt containers running on a local loopback IP.
|
||||
|
||||
The general workflow is to create a daemonized container, use the output to set
|
||||
the test environment variables, run the tests, then destroy the container.
|
||||
|
||||
sudo ./standup.sh create etcd
|
||||
# Copy environment variables and run tests.
|
||||
go test -v -i # always install test dependencies
|
||||
go test -v
|
||||
sudo ./standup.sh destroy etcd
|
||||
|
||||
EOF
|
||||
exit 2
|
||||
}
|
||||
|
||||
function main {
|
||||
if [ "$#" -ne 2 ]; then
|
||||
usage
|
||||
exit 2
|
||||
fi
|
||||
|
||||
case "$1" in
|
||||
"create")
|
||||
case "$2" in
|
||||
"etcd")
|
||||
create_etcd;;
|
||||
*)
|
||||
usage
|
||||
exit 2
|
||||
;;
|
||||
esac
|
||||
;;
|
||||
"destroy")
|
||||
case "$2" in
|
||||
"etcd")
|
||||
destroy_etcd;;
|
||||
*)
|
||||
usage
|
||||
exit 2
|
||||
;;
|
||||
esac
|
||||
;;
|
||||
*)
|
||||
usage
|
||||
exit 2
|
||||
;;
|
||||
esac
|
||||
}
|
||||
|
||||
function wait_for_file {
|
||||
while [ ! -f $1 ]; do
|
||||
sleep 1
|
||||
done
|
||||
}
|
||||
|
||||
function wait_for_container {
|
||||
while [ -z "$( rkt list --full | grep $1 | grep running )" ]; do
|
||||
sleep 1
|
||||
done
|
||||
}
|
||||
|
||||
function create_etcd {
|
||||
UUID_FILE=/tmp/dex-etcd-uuid
|
||||
if [ -f $UUID_FILE ]; then
|
||||
echo "etcd database already exists, try ./standup.sh destroy etcd"
|
||||
exit 2
|
||||
fi
|
||||
|
||||
echo "Starting etcd . To view progress run:"
|
||||
echo ""
|
||||
echo " journalctl -fu dex-etcd"
|
||||
echo ""
|
||||
UNIFIED_CGROUP_HIERARCHY=no \
|
||||
systemd-run --unit=dex-etcd \
|
||||
rkt run --uuid-file-save=$UUID_FILE --insecure-options=image \
|
||||
--net=host \
|
||||
docker://quay.io/coreos/etcd:v3.2.9
|
||||
|
||||
wait_for_file $UUID_FILE
|
||||
|
||||
UUID=$( cat $UUID_FILE )
|
||||
wait_for_container $UUID
|
||||
echo "To run tests export the following environment variables:"
|
||||
echo ""
|
||||
echo " export DEX_ETCD_ENDPOINTS=http://localhost:2379"
|
||||
echo ""
|
||||
}
|
||||
|
||||
function destroy_etcd {
|
||||
UUID_FILE=/tmp/dex-etcd-uuid
|
||||
systemctl stop dex-etcd
|
||||
rkt rm --uuid-file=$UUID_FILE
|
||||
rm $UUID_FILE
|
||||
}
|
||||
|
||||
|
||||
main $@
|
229
storage/etcd/types.go
Normal file
229
storage/etcd/types.go
Normal file
@@ -0,0 +1,229 @@
|
||||
package etcd
|
||||
|
||||
import (
|
||||
"time"
|
||||
|
||||
"github.com/coreos/dex/storage"
|
||||
jose "gopkg.in/square/go-jose.v2"
|
||||
)
|
||||
|
||||
// AuthCode is a mirrored struct from storage with JSON struct tags
|
||||
type AuthCode struct {
|
||||
ID string `json:"ID"`
|
||||
ClientID string `json:"clientID"`
|
||||
RedirectURI string `json:"redirectURI"`
|
||||
Nonce string `json:"nonce,omitempty"`
|
||||
Scopes []string `json:"scopes,omitempty"`
|
||||
|
||||
ConnectorID string `json:"connectorID,omitempty"`
|
||||
ConnectorData []byte `json:"connectorData,omitempty"`
|
||||
Claims Claims `json:"claims,omitempty"`
|
||||
|
||||
Expiry time.Time `json:"expiry"`
|
||||
}
|
||||
|
||||
func fromStorageAuthCode(a storage.AuthCode) AuthCode {
|
||||
return AuthCode{
|
||||
ID: a.ID,
|
||||
ClientID: a.ClientID,
|
||||
RedirectURI: a.RedirectURI,
|
||||
ConnectorID: a.ConnectorID,
|
||||
ConnectorData: a.ConnectorData,
|
||||
Nonce: a.Nonce,
|
||||
Scopes: a.Scopes,
|
||||
Claims: fromStorageClaims(a.Claims),
|
||||
Expiry: a.Expiry,
|
||||
}
|
||||
}
|
||||
|
||||
// AuthRequest is a mirrored struct from storage with JSON struct tags
|
||||
type AuthRequest struct {
|
||||
ID string `json:"id"`
|
||||
ClientID string `json:"client_id"`
|
||||
|
||||
ResponseTypes []string `json:"response_types"`
|
||||
Scopes []string `json:"scopes"`
|
||||
RedirectURI string `json:"redirect_uri"`
|
||||
Nonce string `json:"nonce"`
|
||||
State string `json:"state"`
|
||||
|
||||
ForceApprovalPrompt bool `json:"force_approval_prompt"`
|
||||
|
||||
Expiry time.Time `json:"expiry"`
|
||||
|
||||
LoggedIn bool `json:"logged_in"`
|
||||
|
||||
Claims Claims `json:"claims"`
|
||||
|
||||
ConnectorID string `json:"connector_id"`
|
||||
ConnectorData []byte `json:"connector_data"`
|
||||
}
|
||||
|
||||
func fromStorageAuthRequest(a storage.AuthRequest) AuthRequest {
|
||||
return AuthRequest{
|
||||
ID: a.ID,
|
||||
ClientID: a.ClientID,
|
||||
ResponseTypes: a.ResponseTypes,
|
||||
Scopes: a.Scopes,
|
||||
RedirectURI: a.RedirectURI,
|
||||
Nonce: a.Nonce,
|
||||
State: a.State,
|
||||
ForceApprovalPrompt: a.ForceApprovalPrompt,
|
||||
Expiry: a.Expiry,
|
||||
LoggedIn: a.LoggedIn,
|
||||
Claims: fromStorageClaims(a.Claims),
|
||||
ConnectorID: a.ConnectorID,
|
||||
ConnectorData: a.ConnectorData,
|
||||
}
|
||||
}
|
||||
|
||||
func toStorageAuthRequest(a AuthRequest) storage.AuthRequest {
|
||||
return storage.AuthRequest{
|
||||
ID: a.ID,
|
||||
ClientID: a.ClientID,
|
||||
ResponseTypes: a.ResponseTypes,
|
||||
Scopes: a.Scopes,
|
||||
RedirectURI: a.RedirectURI,
|
||||
Nonce: a.Nonce,
|
||||
State: a.State,
|
||||
ForceApprovalPrompt: a.ForceApprovalPrompt,
|
||||
LoggedIn: a.LoggedIn,
|
||||
ConnectorID: a.ConnectorID,
|
||||
ConnectorData: a.ConnectorData,
|
||||
Expiry: a.Expiry,
|
||||
Claims: toStorageClaims(a.Claims),
|
||||
}
|
||||
}
|
||||
|
||||
// RefreshToken is a mirrored struct from storage with JSON struct tags
|
||||
type RefreshToken struct {
|
||||
ID string `json:"id"`
|
||||
|
||||
Token string `json:"token"`
|
||||
|
||||
CreatedAt time.Time `json:"created_at"`
|
||||
LastUsed time.Time `json:"last_used"`
|
||||
|
||||
ClientID string `json:"client_id"`
|
||||
|
||||
ConnectorID string `json:"connector_id"`
|
||||
ConnectorData []byte `json:"connector_data"`
|
||||
Claims Claims `json:"claims"`
|
||||
|
||||
Scopes []string `json:"scopes"`
|
||||
|
||||
Nonce string `json:"nonce"`
|
||||
}
|
||||
|
||||
func toStorageRefreshToken(r RefreshToken) storage.RefreshToken {
|
||||
return storage.RefreshToken{
|
||||
ID: r.ID,
|
||||
Token: r.Token,
|
||||
CreatedAt: r.CreatedAt,
|
||||
LastUsed: r.LastUsed,
|
||||
ClientID: r.ClientID,
|
||||
ConnectorID: r.ConnectorID,
|
||||
ConnectorData: r.ConnectorData,
|
||||
Scopes: r.Scopes,
|
||||
Nonce: r.Nonce,
|
||||
Claims: toStorageClaims(r.Claims),
|
||||
}
|
||||
}
|
||||
|
||||
func fromStorageRefreshToken(r storage.RefreshToken) RefreshToken {
|
||||
return RefreshToken{
|
||||
ID: r.ID,
|
||||
Token: r.Token,
|
||||
CreatedAt: r.CreatedAt,
|
||||
LastUsed: r.LastUsed,
|
||||
ClientID: r.ClientID,
|
||||
ConnectorID: r.ConnectorID,
|
||||
ConnectorData: r.ConnectorData,
|
||||
Scopes: r.Scopes,
|
||||
Nonce: r.Nonce,
|
||||
Claims: fromStorageClaims(r.Claims),
|
||||
}
|
||||
}
|
||||
|
||||
// Claims is a mirrored struct from storage with JSON struct tags.
|
||||
type Claims struct {
|
||||
UserID string `json:"userID"`
|
||||
Username string `json:"username"`
|
||||
Email string `json:"email"`
|
||||
EmailVerified bool `json:"emailVerified"`
|
||||
Groups []string `json:"groups,omitempty"`
|
||||
}
|
||||
|
||||
func fromStorageClaims(i storage.Claims) Claims {
|
||||
return Claims{
|
||||
UserID: i.UserID,
|
||||
Username: i.Username,
|
||||
Email: i.Email,
|
||||
EmailVerified: i.EmailVerified,
|
||||
Groups: i.Groups,
|
||||
}
|
||||
}
|
||||
|
||||
func toStorageClaims(i Claims) storage.Claims {
|
||||
return storage.Claims{
|
||||
UserID: i.UserID,
|
||||
Username: i.Username,
|
||||
Email: i.Email,
|
||||
EmailVerified: i.EmailVerified,
|
||||
Groups: i.Groups,
|
||||
}
|
||||
}
|
||||
|
||||
// Keys is a mirrored struct from storage with JSON struct tags
|
||||
type Keys struct {
|
||||
SigningKey *jose.JSONWebKey `json:"signing_key,omitempty"`
|
||||
SigningKeyPub *jose.JSONWebKey `json:"signing_key_pub,omitempty"`
|
||||
VerificationKeys []storage.VerificationKey `json:"verification_keys"`
|
||||
NextRotation time.Time `json:"next_rotation"`
|
||||
}
|
||||
|
||||
func fromStorageKeys(keys storage.Keys) Keys {
|
||||
return Keys{
|
||||
SigningKey: keys.SigningKey,
|
||||
SigningKeyPub: keys.SigningKeyPub,
|
||||
VerificationKeys: keys.VerificationKeys,
|
||||
NextRotation: keys.NextRotation,
|
||||
}
|
||||
}
|
||||
|
||||
func toStorageKeys(keys Keys) storage.Keys {
|
||||
return storage.Keys{
|
||||
SigningKey: keys.SigningKey,
|
||||
SigningKeyPub: keys.SigningKeyPub,
|
||||
VerificationKeys: keys.VerificationKeys,
|
||||
NextRotation: keys.NextRotation,
|
||||
}
|
||||
}
|
||||
|
||||
// OfflineSessions is a mirrored struct from storage with JSON struct tags
|
||||
type OfflineSessions struct {
|
||||
UserID string `json:"user_id,omitempty"`
|
||||
ConnID string `json:"conn_id,omitempty"`
|
||||
Refresh map[string]*storage.RefreshTokenRef `json:"refresh,omitempty"`
|
||||
}
|
||||
|
||||
func fromStorageOfflineSessions(o storage.OfflineSessions) OfflineSessions {
|
||||
return OfflineSessions{
|
||||
UserID: o.UserID,
|
||||
ConnID: o.ConnID,
|
||||
Refresh: o.Refresh,
|
||||
}
|
||||
}
|
||||
|
||||
func toStorageOfflineSessions(o OfflineSessions) storage.OfflineSessions {
|
||||
s := storage.OfflineSessions{
|
||||
UserID: o.UserID,
|
||||
ConnID: o.ConnID,
|
||||
Refresh: o.Refresh,
|
||||
}
|
||||
if s.Refresh == nil {
|
||||
// Server code assumes this will be non-nil.
|
||||
s.Refresh = make(map[string]*storage.RefreshTokenRef)
|
||||
}
|
||||
return s
|
||||
}
|
Reference in New Issue
Block a user