storage/conformance: add more conformance tests

This commit is contained in:
Eric Chiang 2016-09-19 09:41:54 -06:00 committed by Eric Chiang
parent 63f56b4269
commit 36d67574c5

View File

@ -9,9 +9,12 @@ import (
"time" "time"
"github.com/coreos/dex/storage" "github.com/coreos/dex/storage"
"github.com/kylelemons/godebug/pretty"
) )
var neverExpire = time.Now().Add(time.Hour * 24 * 365 * 100) // ensure that values being tested on never expire.
var neverExpire = time.Now().UTC().Add(time.Hour * 24 * 365 * 100)
// StorageFactory is a method for creating a new storage. The returned storage sould be initialized // StorageFactory is a method for creating a new storage. The returned storage sould be initialized
// but shouldn't have any existing data in it. // but shouldn't have any existing data in it.
@ -23,8 +26,10 @@ func RunTestSuite(t *testing.T, sf StorageFactory) {
name string name string
run func(t *testing.T, s storage.Storage) run func(t *testing.T, s storage.Storage)
}{ }{
{"UpdateAuthRequest", testUpdateAuthRequest}, {"AuthCodeCRUD", testAuthCodeCRUD},
{"CreateRefresh", testCreateRefresh}, {"AuthRequestCRUD", testAuthRequestCRUD},
{"ClientCRUD", testClientCRUD},
{"RefreshTokenCRUD", testRefreshTokenCRUD},
} }
for _, test := range tests { for _, test := range tests {
t.Run(test.name, func(t *testing.T) { t.Run(test.name, func(t *testing.T) {
@ -33,14 +38,36 @@ func RunTestSuite(t *testing.T, sf StorageFactory) {
} }
} }
func testUpdateAuthRequest(t *testing.T, s storage.Storage) { func mustBeErrNotFound(t *testing.T, kind string, err error) {
switch {
case err == nil:
t.Errorf("deleting non-existant %s should return an error", kind)
case err != storage.ErrNotFound:
t.Errorf("deleting %s expected storage.ErrNotFound, got %v", kind, err)
}
}
func testAuthRequestCRUD(t *testing.T, s storage.Storage) {
a := storage.AuthRequest{ a := storage.AuthRequest{
ID: storage.NewID(), ID: storage.NewID(),
ClientID: "foobar", ClientID: "foobar",
ResponseTypes: []string{"code"}, ResponseTypes: []string{"code"},
Scopes: []string{"openid", "email"}, Scopes: []string{"openid", "email"},
RedirectURI: "https://localhost:80/callback", RedirectURI: "https://localhost:80/callback",
Nonce: "foo",
State: "bar",
ForceApprovalPrompt: true,
LoggedIn: true,
Expiry: neverExpire, Expiry: neverExpire,
ConnectorID: "ldap",
ConnectorData: []byte(`{"some":"data"}`),
Claims: storage.Claims{
UserID: "1",
Username: "jane",
Email: "jane.doe@example.com",
EmailVerified: true,
Groups: []string{"a", "b"},
},
} }
identity := storage.Claims{Email: "foobar"} identity := storage.Claims{Email: "foobar"}
@ -65,24 +92,128 @@ func testUpdateAuthRequest(t *testing.T, s storage.Storage) {
} }
} }
func testCreateRefresh(t *testing.T, s storage.Storage) { func testAuthCodeCRUD(t *testing.T, s storage.Storage) {
a := storage.AuthCode{
ID: storage.NewID(),
ClientID: "foobar",
RedirectURI: "https://localhost:80/callback",
Nonce: "foobar",
Scopes: []string{"openid", "email"},
Expiry: neverExpire,
ConnectorID: "ldap",
ConnectorData: []byte(`{"some":"data"}`),
Claims: storage.Claims{
UserID: "1",
Username: "jane",
Email: "jane.doe@example.com",
EmailVerified: true,
Groups: []string{"a", "b"},
},
}
if err := s.CreateAuthCode(a); err != nil {
t.Fatalf("failed creating auth code: %v", err)
}
got, err := s.GetAuthCode(a.ID)
if err != nil {
t.Fatalf("failed to get auth req: %v", err)
}
if a.Expiry.Unix() != got.Expiry.Unix() {
t.Errorf("auth code expiry did not match want=%s vs got=%s", a.Expiry, got.Expiry)
}
got.Expiry = a.Expiry // time fields do not compare well
if diff := pretty.Compare(a, got); diff != "" {
t.Errorf("auth code retrieved from storage did not match: %s", diff)
}
if err := s.DeleteAuthCode(a.ID); err != nil {
t.Fatalf("delete auth code: %v", err)
}
_, err = s.GetAuthCode(a.ID)
mustBeErrNotFound(t, "auth code", err)
}
func testClientCRUD(t *testing.T, s storage.Storage) {
id := storage.NewID()
c := storage.Client{
ID: id,
Secret: "foobar",
RedirectURIs: []string{"foo://bar.com/", "https://auth.example.com"},
Name: "dex client",
LogoURL: "https://goo.gl/JIyzIC",
}
err := s.DeleteClient(id)
mustBeErrNotFound(t, "client", err)
if err := s.CreateClient(c); err != nil {
t.Fatalf("create client: %v", err)
}
getAndCompare := func(id string, want storage.Client) {
gc, err := s.GetClient(id)
if err != nil {
t.Errorf("get client: %v", err)
return
}
if diff := pretty.Compare(want, gc); diff != "" {
t.Errorf("client retrieved from storage did not match: %s", diff)
}
}
getAndCompare(id, c)
newSecret := "barfoo"
err = s.UpdateClient(id, func(old storage.Client) (storage.Client, error) {
old.Secret = newSecret
return old, nil
})
if err != nil {
t.Errorf("update client: %v", err)
}
c.Secret = newSecret
getAndCompare(id, c)
if err := s.DeleteClient(id); err != nil {
t.Fatalf("delete client: %v", err)
}
_, err = s.GetClient(id)
mustBeErrNotFound(t, "client", err)
}
func testRefreshTokenCRUD(t *testing.T, s storage.Storage) {
id := storage.NewID() id := storage.NewID()
refresh := storage.RefreshToken{ refresh := storage.RefreshToken{
RefreshToken: id, RefreshToken: id,
ClientID: "client_id", ClientID: "client_id",
ConnectorID: "client_secret", ConnectorID: "client_secret",
Scopes: []string{"openid", "email", "profile"}, Scopes: []string{"openid", "email", "profile"},
Claims: storage.Claims{
UserID: "1",
Username: "jane",
Email: "jane.doe@example.com",
EmailVerified: true,
Groups: []string{"a", "b"},
},
} }
if err := s.CreateRefresh(refresh); err != nil { if err := s.CreateRefresh(refresh); err != nil {
t.Fatalf("create refresh token: %v", err) t.Fatalf("create refresh token: %v", err)
} }
gotRefresh, err := s.GetRefresh(id)
getAndCompare := func(id string, want storage.RefreshToken) {
gr, err := s.GetRefresh(id)
if err != nil { if err != nil {
t.Fatalf("get refresh: %v", err) t.Errorf("get refresh: %v", err)
return
} }
if !reflect.DeepEqual(gotRefresh, refresh) { if diff := pretty.Compare(want, gr); diff != "" {
t.Errorf("refresh returned did not match expected") t.Errorf("refresh token retrieved from storage did not match: %s", diff)
} }
}
getAndCompare(id, refresh)
if err := s.DeleteRefresh(id); err != nil { if err := s.DeleteRefresh(id); err != nil {
t.Fatalf("failed to delete refresh request: %v", err) t.Fatalf("failed to delete refresh request: %v", err)