storage: make static storages query real storages for some actions
If dex is configured with static passwords or clients, let the API still add or modify objects in the backing storage, so long as their IDs don't conflict with the static ones. List options now aggregate resources from the static list and backing storage.
This commit is contained in:
@@ -1,55 +0,0 @@
|
||||
package memory
|
||||
|
||||
import (
|
||||
"os"
|
||||
"reflect"
|
||||
"testing"
|
||||
|
||||
"github.com/Sirupsen/logrus"
|
||||
"github.com/coreos/dex/storage"
|
||||
)
|
||||
|
||||
func TestStaticClients(t *testing.T) {
|
||||
logger := &logrus.Logger{
|
||||
Out: os.Stderr,
|
||||
Formatter: &logrus.TextFormatter{DisableColors: true},
|
||||
Level: logrus.DebugLevel,
|
||||
}
|
||||
s := New(logger)
|
||||
|
||||
c1 := storage.Client{ID: "foo", Secret: "foo_secret"}
|
||||
c2 := storage.Client{ID: "bar", Secret: "bar_secret"}
|
||||
s.CreateClient(c1)
|
||||
s2 := storage.WithStaticClients(s, []storage.Client{c2})
|
||||
|
||||
tests := []struct {
|
||||
id string
|
||||
s storage.Storage
|
||||
wantErr bool
|
||||
wantClient storage.Client
|
||||
}{
|
||||
{"foo", s, false, c1},
|
||||
{"bar", s, true, storage.Client{}},
|
||||
{"foo", s2, true, storage.Client{}},
|
||||
{"bar", s2, false, c2},
|
||||
}
|
||||
|
||||
for i, tc := range tests {
|
||||
gotClient, err := tc.s.GetClient(tc.id)
|
||||
if err != nil {
|
||||
if !tc.wantErr {
|
||||
t.Errorf("case %d: GetClient(%q) %v", i, tc.id, err)
|
||||
}
|
||||
continue
|
||||
}
|
||||
|
||||
if tc.wantErr {
|
||||
t.Errorf("case %d: GetClient(%q) expected error", i, tc.id)
|
||||
continue
|
||||
}
|
||||
|
||||
if !reflect.DeepEqual(tc.wantClient, gotClient) {
|
||||
t.Errorf("case %d: expected=%#v got=%#v", i, tc.wantClient, gotClient)
|
||||
}
|
||||
}
|
||||
}
|
192
storage/memory/static_test.go
Normal file
192
storage/memory/static_test.go
Normal file
@@ -0,0 +1,192 @@
|
||||
package memory
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"os"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"github.com/Sirupsen/logrus"
|
||||
"github.com/coreos/dex/storage"
|
||||
)
|
||||
|
||||
func TestStaticClients(t *testing.T) {
|
||||
logger := &logrus.Logger{
|
||||
Out: os.Stderr,
|
||||
Formatter: &logrus.TextFormatter{DisableColors: true},
|
||||
Level: logrus.DebugLevel,
|
||||
}
|
||||
backing := New(logger)
|
||||
|
||||
c1 := storage.Client{ID: "foo", Secret: "foo_secret"}
|
||||
c2 := storage.Client{ID: "bar", Secret: "bar_secret"}
|
||||
c3 := storage.Client{ID: "spam", Secret: "spam_secret"}
|
||||
|
||||
backing.CreateClient(c1)
|
||||
s := storage.WithStaticClients(backing, []storage.Client{c2})
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
action func() error
|
||||
wantErr bool
|
||||
}{
|
||||
{
|
||||
name: "get client from static storage",
|
||||
action: func() error {
|
||||
_, err := s.GetClient(c2.ID)
|
||||
return err
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "get client from backing storage",
|
||||
action: func() error {
|
||||
_, err := s.GetClient(c1.ID)
|
||||
return err
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "update static client",
|
||||
action: func() error {
|
||||
updater := func(c storage.Client) (storage.Client, error) {
|
||||
c.Secret = "new_" + c.Secret
|
||||
return c, nil
|
||||
}
|
||||
return s.UpdateClient(c2.ID, updater)
|
||||
},
|
||||
wantErr: true,
|
||||
},
|
||||
{
|
||||
name: "update non-static client",
|
||||
action: func() error {
|
||||
updater := func(c storage.Client) (storage.Client, error) {
|
||||
c.Secret = "new_" + c.Secret
|
||||
return c, nil
|
||||
}
|
||||
return s.UpdateClient(c1.ID, updater)
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "list clients",
|
||||
action: func() error {
|
||||
clients, err := s.ListClients()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if n := len(clients); n != 2 {
|
||||
return fmt.Errorf("expected 2 clients got %d", n)
|
||||
}
|
||||
return nil
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "create client",
|
||||
action: func() error {
|
||||
return s.CreateClient(c3)
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range tests {
|
||||
err := tc.action()
|
||||
if err != nil && !tc.wantErr {
|
||||
t.Errorf("%s: %v", tc.name, err)
|
||||
}
|
||||
if err == nil && tc.wantErr {
|
||||
t.Errorf("%s: expected error, didn't get one", tc.name)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestStaticPasswords(t *testing.T) {
|
||||
logger := &logrus.Logger{
|
||||
Out: os.Stderr,
|
||||
Formatter: &logrus.TextFormatter{DisableColors: true},
|
||||
Level: logrus.DebugLevel,
|
||||
}
|
||||
backing := New(logger)
|
||||
|
||||
p1 := storage.Password{Email: "foo@example.com", Username: "foo_secret"}
|
||||
p2 := storage.Password{Email: "bar@example.com", Username: "bar_secret"}
|
||||
p3 := storage.Password{Email: "spam@example.com", Username: "spam_secret"}
|
||||
|
||||
backing.CreatePassword(p1)
|
||||
s := storage.WithStaticPasswords(backing, []storage.Password{p2})
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
action func() error
|
||||
wantErr bool
|
||||
}{
|
||||
{
|
||||
name: "get password from static storage",
|
||||
action: func() error {
|
||||
_, err := s.GetPassword(p2.Email)
|
||||
return err
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "get password from backing storage",
|
||||
action: func() error {
|
||||
_, err := s.GetPassword(p1.Email)
|
||||
return err
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "get password from static storage with casing",
|
||||
action: func() error {
|
||||
_, err := s.GetPassword(strings.ToUpper(p2.Email))
|
||||
return err
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "update static password",
|
||||
action: func() error {
|
||||
updater := func(p storage.Password) (storage.Password, error) {
|
||||
p.Username = "new_" + p.Username
|
||||
return p, nil
|
||||
}
|
||||
return s.UpdatePassword(p2.Email, updater)
|
||||
},
|
||||
wantErr: true,
|
||||
},
|
||||
{
|
||||
name: "update non-static password",
|
||||
action: func() error {
|
||||
updater := func(p storage.Password) (storage.Password, error) {
|
||||
p.Username = "new_" + p.Username
|
||||
return p, nil
|
||||
}
|
||||
return s.UpdatePassword(p1.Email, updater)
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "list passwords",
|
||||
action: func() error {
|
||||
passwords, err := s.ListPasswords()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if n := len(passwords); n != 2 {
|
||||
return fmt.Errorf("expected 2 passwords got %d", n)
|
||||
}
|
||||
return nil
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "create password",
|
||||
action: func() error {
|
||||
return s.CreatePassword(p3)
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range tests {
|
||||
err := tc.action()
|
||||
if err != nil && !tc.wantErr {
|
||||
t.Errorf("%s: %v", tc.name, err)
|
||||
}
|
||||
if err == nil && tc.wantErr {
|
||||
t.Errorf("%s: expected error, didn't get one", tc.name)
|
||||
}
|
||||
}
|
||||
}
|
Reference in New Issue
Block a user