diff --git a/storage/conformance/conformance.go b/storage/conformance/conformance.go index 76d39780..9ba2ec68 100644 --- a/storage/conformance/conformance.go +++ b/storage/conformance/conformance.go @@ -9,9 +9,12 @@ import ( "time" "github.com/coreos/dex/storage" + + "github.com/kylelemons/godebug/pretty" ) -var neverExpire = time.Now().Add(time.Hour * 24 * 365 * 100) +// ensure that values being tested on never expire. +var neverExpire = time.Now().UTC().Add(time.Hour * 24 * 365 * 100) // StorageFactory is a method for creating a new storage. The returned storage sould be initialized // but shouldn't have any existing data in it. @@ -23,8 +26,10 @@ func RunTestSuite(t *testing.T, sf StorageFactory) { name string run func(t *testing.T, s storage.Storage) }{ - {"UpdateAuthRequest", testUpdateAuthRequest}, - {"CreateRefresh", testCreateRefresh}, + {"AuthCodeCRUD", testAuthCodeCRUD}, + {"AuthRequestCRUD", testAuthRequestCRUD}, + {"ClientCRUD", testClientCRUD}, + {"RefreshTokenCRUD", testRefreshTokenCRUD}, } for _, test := range tests { t.Run(test.name, func(t *testing.T) { @@ -33,14 +38,36 @@ func RunTestSuite(t *testing.T, sf StorageFactory) { } } -func testUpdateAuthRequest(t *testing.T, s storage.Storage) { +func mustBeErrNotFound(t *testing.T, kind string, err error) { + switch { + case err == nil: + t.Errorf("deleting non-existant %s should return an error", kind) + case err != storage.ErrNotFound: + t.Errorf("deleting %s expected storage.ErrNotFound, got %v", kind, err) + } +} + +func testAuthRequestCRUD(t *testing.T, s storage.Storage) { a := storage.AuthRequest{ - ID: storage.NewID(), - ClientID: "foobar", - ResponseTypes: []string{"code"}, - Scopes: []string{"openid", "email"}, - RedirectURI: "https://localhost:80/callback", - Expiry: neverExpire, + ID: storage.NewID(), + ClientID: "foobar", + ResponseTypes: []string{"code"}, + Scopes: []string{"openid", "email"}, + RedirectURI: "https://localhost:80/callback", + Nonce: "foo", + State: "bar", + ForceApprovalPrompt: true, + LoggedIn: true, + Expiry: neverExpire, + ConnectorID: "ldap", + ConnectorData: []byte(`{"some":"data"}`), + Claims: storage.Claims{ + UserID: "1", + Username: "jane", + Email: "jane.doe@example.com", + EmailVerified: true, + Groups: []string{"a", "b"}, + }, } identity := storage.Claims{Email: "foobar"} @@ -65,25 +92,129 @@ func testUpdateAuthRequest(t *testing.T, s storage.Storage) { } } -func testCreateRefresh(t *testing.T, s storage.Storage) { +func testAuthCodeCRUD(t *testing.T, s storage.Storage) { + a := storage.AuthCode{ + ID: storage.NewID(), + ClientID: "foobar", + RedirectURI: "https://localhost:80/callback", + Nonce: "foobar", + Scopes: []string{"openid", "email"}, + Expiry: neverExpire, + ConnectorID: "ldap", + ConnectorData: []byte(`{"some":"data"}`), + Claims: storage.Claims{ + UserID: "1", + Username: "jane", + Email: "jane.doe@example.com", + EmailVerified: true, + Groups: []string{"a", "b"}, + }, + } + + if err := s.CreateAuthCode(a); err != nil { + t.Fatalf("failed creating auth code: %v", err) + } + + got, err := s.GetAuthCode(a.ID) + if err != nil { + t.Fatalf("failed to get auth req: %v", err) + } + if a.Expiry.Unix() != got.Expiry.Unix() { + t.Errorf("auth code expiry did not match want=%s vs got=%s", a.Expiry, got.Expiry) + } + got.Expiry = a.Expiry // time fields do not compare well + if diff := pretty.Compare(a, got); diff != "" { + t.Errorf("auth code retrieved from storage did not match: %s", diff) + } + + if err := s.DeleteAuthCode(a.ID); err != nil { + t.Fatalf("delete auth code: %v", err) + } + + _, err = s.GetAuthCode(a.ID) + mustBeErrNotFound(t, "auth code", err) +} + +func testClientCRUD(t *testing.T, s storage.Storage) { + id := storage.NewID() + c := storage.Client{ + ID: id, + Secret: "foobar", + RedirectURIs: []string{"foo://bar.com/", "https://auth.example.com"}, + Name: "dex client", + LogoURL: "https://goo.gl/JIyzIC", + } + err := s.DeleteClient(id) + mustBeErrNotFound(t, "client", err) + + if err := s.CreateClient(c); err != nil { + t.Fatalf("create client: %v", err) + } + + getAndCompare := func(id string, want storage.Client) { + gc, err := s.GetClient(id) + if err != nil { + t.Errorf("get client: %v", err) + return + } + if diff := pretty.Compare(want, gc); diff != "" { + t.Errorf("client retrieved from storage did not match: %s", diff) + } + } + + getAndCompare(id, c) + + newSecret := "barfoo" + err = s.UpdateClient(id, func(old storage.Client) (storage.Client, error) { + old.Secret = newSecret + return old, nil + }) + if err != nil { + t.Errorf("update client: %v", err) + } + c.Secret = newSecret + getAndCompare(id, c) + + if err := s.DeleteClient(id); err != nil { + t.Fatalf("delete client: %v", err) + } + + _, err = s.GetClient(id) + mustBeErrNotFound(t, "client", err) +} + +func testRefreshTokenCRUD(t *testing.T, s storage.Storage) { id := storage.NewID() refresh := storage.RefreshToken{ RefreshToken: id, ClientID: "client_id", ConnectorID: "client_secret", Scopes: []string{"openid", "email", "profile"}, + Claims: storage.Claims{ + UserID: "1", + Username: "jane", + Email: "jane.doe@example.com", + EmailVerified: true, + Groups: []string{"a", "b"}, + }, } if err := s.CreateRefresh(refresh); err != nil { t.Fatalf("create refresh token: %v", err) } - gotRefresh, err := s.GetRefresh(id) - if err != nil { - t.Fatalf("get refresh: %v", err) - } - if !reflect.DeepEqual(gotRefresh, refresh) { - t.Errorf("refresh returned did not match expected") + + getAndCompare := func(id string, want storage.RefreshToken) { + gr, err := s.GetRefresh(id) + if err != nil { + t.Errorf("get refresh: %v", err) + return + } + if diff := pretty.Compare(want, gr); diff != "" { + t.Errorf("refresh token retrieved from storage did not match: %s", diff) + } } + getAndCompare(id, refresh) + if err := s.DeleteRefresh(id); err != nil { t.Fatalf("failed to delete refresh request: %v", err) }