storage: make static storages query real storages for some actions
If dex is configured with static passwords or clients, let the API still add or modify objects in the backing storage, so long as their IDs don't conflict with the static ones. List options now aggregate resources from the static list and backing storage.
This commit is contained in:
parent
d31bb1c8d5
commit
4c39bc20ae
@ -1,55 +0,0 @@
|
|||||||
package memory
|
|
||||||
|
|
||||||
import (
|
|
||||||
"os"
|
|
||||||
"reflect"
|
|
||||||
"testing"
|
|
||||||
|
|
||||||
"github.com/Sirupsen/logrus"
|
|
||||||
"github.com/coreos/dex/storage"
|
|
||||||
)
|
|
||||||
|
|
||||||
func TestStaticClients(t *testing.T) {
|
|
||||||
logger := &logrus.Logger{
|
|
||||||
Out: os.Stderr,
|
|
||||||
Formatter: &logrus.TextFormatter{DisableColors: true},
|
|
||||||
Level: logrus.DebugLevel,
|
|
||||||
}
|
|
||||||
s := New(logger)
|
|
||||||
|
|
||||||
c1 := storage.Client{ID: "foo", Secret: "foo_secret"}
|
|
||||||
c2 := storage.Client{ID: "bar", Secret: "bar_secret"}
|
|
||||||
s.CreateClient(c1)
|
|
||||||
s2 := storage.WithStaticClients(s, []storage.Client{c2})
|
|
||||||
|
|
||||||
tests := []struct {
|
|
||||||
id string
|
|
||||||
s storage.Storage
|
|
||||||
wantErr bool
|
|
||||||
wantClient storage.Client
|
|
||||||
}{
|
|
||||||
{"foo", s, false, c1},
|
|
||||||
{"bar", s, true, storage.Client{}},
|
|
||||||
{"foo", s2, true, storage.Client{}},
|
|
||||||
{"bar", s2, false, c2},
|
|
||||||
}
|
|
||||||
|
|
||||||
for i, tc := range tests {
|
|
||||||
gotClient, err := tc.s.GetClient(tc.id)
|
|
||||||
if err != nil {
|
|
||||||
if !tc.wantErr {
|
|
||||||
t.Errorf("case %d: GetClient(%q) %v", i, tc.id, err)
|
|
||||||
}
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
|
|
||||||
if tc.wantErr {
|
|
||||||
t.Errorf("case %d: GetClient(%q) expected error", i, tc.id)
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
|
|
||||||
if !reflect.DeepEqual(tc.wantClient, gotClient) {
|
|
||||||
t.Errorf("case %d: expected=%#v got=%#v", i, tc.wantClient, gotClient)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
192
storage/memory/static_test.go
Normal file
192
storage/memory/static_test.go
Normal file
@ -0,0 +1,192 @@
|
|||||||
|
package memory
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
"os"
|
||||||
|
"strings"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/Sirupsen/logrus"
|
||||||
|
"github.com/coreos/dex/storage"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestStaticClients(t *testing.T) {
|
||||||
|
logger := &logrus.Logger{
|
||||||
|
Out: os.Stderr,
|
||||||
|
Formatter: &logrus.TextFormatter{DisableColors: true},
|
||||||
|
Level: logrus.DebugLevel,
|
||||||
|
}
|
||||||
|
backing := New(logger)
|
||||||
|
|
||||||
|
c1 := storage.Client{ID: "foo", Secret: "foo_secret"}
|
||||||
|
c2 := storage.Client{ID: "bar", Secret: "bar_secret"}
|
||||||
|
c3 := storage.Client{ID: "spam", Secret: "spam_secret"}
|
||||||
|
|
||||||
|
backing.CreateClient(c1)
|
||||||
|
s := storage.WithStaticClients(backing, []storage.Client{c2})
|
||||||
|
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
action func() error
|
||||||
|
wantErr bool
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "get client from static storage",
|
||||||
|
action: func() error {
|
||||||
|
_, err := s.GetClient(c2.ID)
|
||||||
|
return err
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "get client from backing storage",
|
||||||
|
action: func() error {
|
||||||
|
_, err := s.GetClient(c1.ID)
|
||||||
|
return err
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "update static client",
|
||||||
|
action: func() error {
|
||||||
|
updater := func(c storage.Client) (storage.Client, error) {
|
||||||
|
c.Secret = "new_" + c.Secret
|
||||||
|
return c, nil
|
||||||
|
}
|
||||||
|
return s.UpdateClient(c2.ID, updater)
|
||||||
|
},
|
||||||
|
wantErr: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "update non-static client",
|
||||||
|
action: func() error {
|
||||||
|
updater := func(c storage.Client) (storage.Client, error) {
|
||||||
|
c.Secret = "new_" + c.Secret
|
||||||
|
return c, nil
|
||||||
|
}
|
||||||
|
return s.UpdateClient(c1.ID, updater)
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "list clients",
|
||||||
|
action: func() error {
|
||||||
|
clients, err := s.ListClients()
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
if n := len(clients); n != 2 {
|
||||||
|
return fmt.Errorf("expected 2 clients got %d", n)
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "create client",
|
||||||
|
action: func() error {
|
||||||
|
return s.CreateClient(c3)
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tc := range tests {
|
||||||
|
err := tc.action()
|
||||||
|
if err != nil && !tc.wantErr {
|
||||||
|
t.Errorf("%s: %v", tc.name, err)
|
||||||
|
}
|
||||||
|
if err == nil && tc.wantErr {
|
||||||
|
t.Errorf("%s: expected error, didn't get one", tc.name)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestStaticPasswords(t *testing.T) {
|
||||||
|
logger := &logrus.Logger{
|
||||||
|
Out: os.Stderr,
|
||||||
|
Formatter: &logrus.TextFormatter{DisableColors: true},
|
||||||
|
Level: logrus.DebugLevel,
|
||||||
|
}
|
||||||
|
backing := New(logger)
|
||||||
|
|
||||||
|
p1 := storage.Password{Email: "foo@example.com", Username: "foo_secret"}
|
||||||
|
p2 := storage.Password{Email: "bar@example.com", Username: "bar_secret"}
|
||||||
|
p3 := storage.Password{Email: "spam@example.com", Username: "spam_secret"}
|
||||||
|
|
||||||
|
backing.CreatePassword(p1)
|
||||||
|
s := storage.WithStaticPasswords(backing, []storage.Password{p2})
|
||||||
|
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
action func() error
|
||||||
|
wantErr bool
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "get password from static storage",
|
||||||
|
action: func() error {
|
||||||
|
_, err := s.GetPassword(p2.Email)
|
||||||
|
return err
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "get password from backing storage",
|
||||||
|
action: func() error {
|
||||||
|
_, err := s.GetPassword(p1.Email)
|
||||||
|
return err
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "get password from static storage with casing",
|
||||||
|
action: func() error {
|
||||||
|
_, err := s.GetPassword(strings.ToUpper(p2.Email))
|
||||||
|
return err
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "update static password",
|
||||||
|
action: func() error {
|
||||||
|
updater := func(p storage.Password) (storage.Password, error) {
|
||||||
|
p.Username = "new_" + p.Username
|
||||||
|
return p, nil
|
||||||
|
}
|
||||||
|
return s.UpdatePassword(p2.Email, updater)
|
||||||
|
},
|
||||||
|
wantErr: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "update non-static password",
|
||||||
|
action: func() error {
|
||||||
|
updater := func(p storage.Password) (storage.Password, error) {
|
||||||
|
p.Username = "new_" + p.Username
|
||||||
|
return p, nil
|
||||||
|
}
|
||||||
|
return s.UpdatePassword(p1.Email, updater)
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "list passwords",
|
||||||
|
action: func() error {
|
||||||
|
passwords, err := s.ListPasswords()
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
if n := len(passwords); n != 2 {
|
||||||
|
return fmt.Errorf("expected 2 passwords got %d", n)
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "create password",
|
||||||
|
action: func() error {
|
||||||
|
return s.CreatePassword(p3)
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tc := range tests {
|
||||||
|
err := tc.action()
|
||||||
|
if err != nil && !tc.wantErr {
|
||||||
|
t.Errorf("%s: %v", tc.name, err)
|
||||||
|
}
|
||||||
|
if err == nil && tc.wantErr {
|
||||||
|
t.Errorf("%s: expected error, didn't get one", tc.name)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
@ -19,11 +19,7 @@ type staticClientsStorage struct {
|
|||||||
clientsByID map[string]Client
|
clientsByID map[string]Client
|
||||||
}
|
}
|
||||||
|
|
||||||
// WithStaticClients returns a storage with a read-only set of clients. Write actions,
|
// WithStaticClients adds a read-only set of clients to the underlying storages.
|
||||||
// such as creating other clients, will fail.
|
|
||||||
//
|
|
||||||
// In the future the returned storage may allow creating and storing additional clients
|
|
||||||
// in the underlying storage.
|
|
||||||
func WithStaticClients(s Storage, staticClients []Client) Storage {
|
func WithStaticClients(s Storage, staticClients []Client) Storage {
|
||||||
clientsByID := make(map[string]Client, len(staticClients))
|
clientsByID := make(map[string]Client, len(staticClients))
|
||||||
for _, client := range staticClients {
|
for _, client := range staticClients {
|
||||||
@ -36,25 +32,50 @@ func (s staticClientsStorage) GetClient(id string) (Client, error) {
|
|||||||
if client, ok := s.clientsByID[id]; ok {
|
if client, ok := s.clientsByID[id]; ok {
|
||||||
return client, nil
|
return client, nil
|
||||||
}
|
}
|
||||||
return Client{}, ErrNotFound
|
return s.Storage.GetClient(id)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s staticClientsStorage) isStatic(id string) bool {
|
||||||
|
_, ok := s.clientsByID[id]
|
||||||
|
return ok
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s staticClientsStorage) ListClients() ([]Client, error) {
|
func (s staticClientsStorage) ListClients() ([]Client, error) {
|
||||||
clients := make([]Client, len(s.clients))
|
clients, err := s.Storage.ListClients()
|
||||||
copy(clients, s.clients)
|
if err != nil {
|
||||||
return clients, nil
|
return nil, err
|
||||||
|
}
|
||||||
|
n := 0
|
||||||
|
for _, client := range clients {
|
||||||
|
// If a client in the backing storage has the same ID as a static client
|
||||||
|
// prefer the static client.
|
||||||
|
if !s.isStatic(client.ID) {
|
||||||
|
clients[n] = client
|
||||||
|
n++
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return append(clients[:n], s.clients...), nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s staticClientsStorage) CreateClient(c Client) error {
|
func (s staticClientsStorage) CreateClient(c Client) error {
|
||||||
return errors.New("static clients: read-only cannot create client")
|
if s.isStatic(c.ID) {
|
||||||
|
return errors.New("static clients: read-only cannot create client")
|
||||||
|
}
|
||||||
|
return s.Storage.CreateClient(c)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s staticClientsStorage) DeleteClient(id string) error {
|
func (s staticClientsStorage) DeleteClient(id string) error {
|
||||||
return errors.New("static clients: read-only cannot delete client")
|
if s.isStatic(id) {
|
||||||
|
return errors.New("static clients: read-only cannot delete client")
|
||||||
|
}
|
||||||
|
return s.Storage.DeleteClient(id)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s staticClientsStorage) UpdateClient(id string, updater func(old Client) (Client, error)) error {
|
func (s staticClientsStorage) UpdateClient(id string, updater func(old Client) (Client, error)) error {
|
||||||
return errors.New("static clients: read-only cannot update client")
|
if s.isStatic(id) {
|
||||||
|
return errors.New("static clients: read-only cannot update client")
|
||||||
|
}
|
||||||
|
return s.Storage.UpdateClient(id, updater)
|
||||||
}
|
}
|
||||||
|
|
||||||
type staticPasswordsStorage struct {
|
type staticPasswordsStorage struct {
|
||||||
@ -76,27 +97,56 @@ func WithStaticPasswords(s Storage, staticPasswords []Password) Storage {
|
|||||||
return staticPasswordsStorage{s, staticPasswords, passwordsByEmail}
|
return staticPasswordsStorage{s, staticPasswords, passwordsByEmail}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (s staticPasswordsStorage) isStatic(email string) bool {
|
||||||
|
_, ok := s.passwordsByEmail[strings.ToLower(email)]
|
||||||
|
return ok
|
||||||
|
}
|
||||||
|
|
||||||
func (s staticPasswordsStorage) GetPassword(email string) (Password, error) {
|
func (s staticPasswordsStorage) GetPassword(email string) (Password, error) {
|
||||||
if password, ok := s.passwordsByEmail[strings.ToLower(email)]; ok {
|
// TODO(ericchiang): BLAH. We really need to figure out how to handle
|
||||||
|
// lower cased emails better.
|
||||||
|
email = strings.ToLower(email)
|
||||||
|
if password, ok := s.passwordsByEmail[email]; ok {
|
||||||
return password, nil
|
return password, nil
|
||||||
}
|
}
|
||||||
return Password{}, ErrNotFound
|
return s.Storage.GetPassword(email)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s staticPasswordsStorage) ListPasswords() ([]Password, error) {
|
func (s staticPasswordsStorage) ListPasswords() ([]Password, error) {
|
||||||
passwords := make([]Password, len(s.passwords))
|
passwords, err := s.Storage.ListPasswords()
|
||||||
copy(passwords, s.passwords)
|
if err != nil {
|
||||||
return passwords, nil
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
n := 0
|
||||||
|
for _, password := range passwords {
|
||||||
|
// If an entry has the same email as those provided in the static
|
||||||
|
// values, prefer the static value.
|
||||||
|
if !s.isStatic(password.Email) {
|
||||||
|
passwords[n] = password
|
||||||
|
n++
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return append(passwords[:n], s.passwords...), nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s staticPasswordsStorage) CreatePassword(p Password) error {
|
func (s staticPasswordsStorage) CreatePassword(p Password) error {
|
||||||
return errors.New("static passwords: read-only cannot create password")
|
if s.isStatic(p.Email) {
|
||||||
|
return errors.New("static passwords: read-only cannot create password")
|
||||||
|
}
|
||||||
|
return s.Storage.CreatePassword(p)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s staticPasswordsStorage) DeletePassword(id string) error {
|
func (s staticPasswordsStorage) DeletePassword(email string) error {
|
||||||
return errors.New("static passwords: read-only cannot create password")
|
if s.isStatic(email) {
|
||||||
|
return errors.New("static passwords: read-only cannot create password")
|
||||||
|
}
|
||||||
|
return s.Storage.DeletePassword(email)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s staticPasswordsStorage) UpdatePassword(id string, updater func(old Password) (Password, error)) error {
|
func (s staticPasswordsStorage) UpdatePassword(email string, updater func(old Password) (Password, error)) error {
|
||||||
return errors.New("static passwords: read-only cannot update password")
|
if s.isStatic(email) {
|
||||||
|
return errors.New("static passwords: read-only cannot update password")
|
||||||
|
}
|
||||||
|
return s.Storage.UpdatePassword(email, updater)
|
||||||
}
|
}
|
||||||
|
Reference in New Issue
Block a user