From d928ac0677f214df047b23bd67f54de74acaa47c Mon Sep 17 00:00:00 2001 From: rithu john Date: Tue, 31 Jan 2017 16:11:59 -0800 Subject: [PATCH] storage: Add OfflineSession object to backend storage. --- server/handlers.go | 89 +++++++++++++++++++++++- server/server_test.go | 105 ++++++++++++++++++++++++++++ storage/conformance/conformance.go | 55 +++++++++++++++ storage/kubernetes/client.go | 12 ++++ storage/kubernetes/storage.go | 74 ++++++++++++++++---- storage/kubernetes/types.go | 43 ++++++++++++ storage/memory/memory.go | 108 ++++++++++++++++++++++++----- storage/sql/crud.go | 87 +++++++++++++++++++++++ storage/sql/migrate.go | 11 +++ storage/storage.go | 28 ++++++++ 10 files changed, 580 insertions(+), 32 deletions(-) diff --git a/server/handlers.go b/server/handlers.go index b8430ebc..a994b989 100644 --- a/server/handlers.go +++ b/server/handlers.go @@ -682,6 +682,75 @@ func (s *Server) handleAuthCode(w http.ResponseWriter, r *http.Request, client s s.tokenErrHelper(w, errServerError, "", http.StatusInternalServerError) return } + + // deleteToken determines if we need to delete the newly created refresh token + // due to a failure in updating/creating the OfflineSession object for the + // corresponding user. + var deleteToken bool + defer func() { + if deleteToken { + // Delete newly created refresh token from storage. + if err := s.storage.DeleteRefresh(refresh.ID); err != nil { + s.logger.Errorf("failed to delete refresh token: %v", err) + s.tokenErrHelper(w, errServerError, "", http.StatusInternalServerError) + return + } + } + }() + + tokenRef := storage.RefreshTokenRef{ + ID: refresh.ID, + ClientID: refresh.ClientID, + CreatedAt: refresh.CreatedAt, + LastUsed: refresh.LastUsed, + } + + // Try to retrieve an existing OfflineSession object for the corresponding user. + if session, err := s.storage.GetOfflineSessions(refresh.Claims.UserID, refresh.ConnectorID); err != nil { + if err != storage.ErrNotFound { + s.logger.Errorf("failed to get offline session: %v", err) + s.tokenErrHelper(w, errServerError, "", http.StatusInternalServerError) + deleteToken = true + return + } + offlineSessions := storage.OfflineSessions{ + UserID: refresh.Claims.UserID, + ConnID: refresh.ConnectorID, + Refresh: make(map[string]*storage.RefreshTokenRef), + } + offlineSessions.Refresh[tokenRef.ClientID] = &tokenRef + + // Create a new OfflineSession object for the user and add a reference object for + // the newly recieved refreshtoken. + if err := s.storage.CreateOfflineSessions(offlineSessions); err != nil { + s.logger.Errorf("failed to create offline session: %v", err) + s.tokenErrHelper(w, errServerError, "", http.StatusInternalServerError) + deleteToken = true + return + } + } else { + if oldTokenRef, ok := session.Refresh[tokenRef.ClientID]; ok { + // Delete old refresh token from storage. + if err := s.storage.DeleteRefresh(oldTokenRef.ID); err != nil { + s.logger.Errorf("failed to delete refresh token: %v", err) + s.tokenErrHelper(w, errServerError, "", http.StatusInternalServerError) + deleteToken = true + return + } + } + + // Update existing OfflineSession obj with new RefreshTokenRef. + if err := s.storage.UpdateOfflineSessions(session.UserID, session.ConnID, func(old storage.OfflineSessions) (storage.OfflineSessions, error) { + old.Refresh[tokenRef.ClientID] = &tokenRef + return old, nil + }); err != nil { + s.logger.Errorf("failed to update offline session: %v", err) + s.tokenErrHelper(w, errServerError, "", http.StatusInternalServerError) + deleteToken = true + return + } + + } } s.writeAccessToken(w, idToken, accessToken, refreshToken, expiry) } @@ -815,6 +884,7 @@ func (s *Server) handleRefreshToken(w http.ResponseWriter, r *http.Request, clie return } + lastUsed := s.now() updater := func(old storage.RefreshToken) (storage.RefreshToken, error) { if old.Token != refresh.Token { return old, errors.New("refresh token claimed twice") @@ -828,14 +898,31 @@ func (s *Server) handleRefreshToken(w http.ResponseWriter, r *http.Request, clie old.Claims.EmailVerified = ident.EmailVerified old.Claims.Groups = ident.Groups old.ConnectorData = ident.ConnectorData - old.LastUsed = s.now() + old.LastUsed = lastUsed return old, nil } + + // Update LastUsed time stamp in refresh token reference object + // in offline session for the user. + if err := s.storage.UpdateOfflineSessions(refresh.Claims.UserID, refresh.ConnectorID, func(old storage.OfflineSessions) (storage.OfflineSessions, error) { + if old.Refresh[refresh.ClientID].ID != refresh.ID { + return old, errors.New("refresh token invalid") + } + old.Refresh[refresh.ClientID].LastUsed = lastUsed + return old, nil + }); err != nil { + s.logger.Errorf("failed to update offline session: %v", err) + s.tokenErrHelper(w, errServerError, "", http.StatusInternalServerError) + return + } + + // Update refresh token in the storage. if err := s.storage.UpdateRefreshToken(refresh.ID, updater); err != nil { s.logger.Errorf("failed to update refresh token: %v", err) s.tokenErrHelper(w, errServerError, "", http.StatusInternalServerError) return } + s.writeAccessToken(w, idToken, accessToken, rawNewToken, expiry) } diff --git a/server/server_test.go b/server/server_test.go index 7a4eb0c1..688c606e 100644 --- a/server/server_test.go +++ b/server/server_test.go @@ -971,3 +971,108 @@ func TestKeyCacher(t *testing.T) { } } } + +type oauth2Client struct { + config *oauth2.Config + token *oauth2.Token + server *httptest.Server +} + +// TestRefreshTokenFlow tests the refresh token code flow for oauth2. The test verifies +// that only valid refresh tokens can be used to refresh an expired token. +func TestRefreshTokenFlow(t *testing.T) { + state := "state" + now := func() time.Time { return time.Now() } + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + httpServer, s := newTestServer(ctx, t, func(c *Config) { + c.Now = now + }) + defer httpServer.Close() + + p, err := oidc.NewProvider(ctx, httpServer.URL) + if err != nil { + t.Fatalf("failed to get provider: %v", err) + } + + var oauth2Client oauth2Client + + oauth2Client.server = httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if r.URL.Path != "/callback" { + // User is visiting app first time. Redirect to dex. + http.Redirect(w, r, oauth2Client.config.AuthCodeURL(state), http.StatusSeeOther) + return + } + + // User is at '/callback' so they were just redirected _from_ dex. + q := r.URL.Query() + + if errType := q.Get("error"); errType != "" { + if desc := q.Get("error_description"); desc != "" { + t.Errorf("got error from server %s: %s", errType, desc) + } else { + t.Errorf("got error from server %s", errType) + } + w.WriteHeader(http.StatusInternalServerError) + return + } + + // Grab code, exchange for token. + if code := q.Get("code"); code != "" { + token, err := oauth2Client.config.Exchange(ctx, code) + if err != nil { + t.Errorf("failed to exchange code for token: %v", err) + return + } + oauth2Client.token = token + } + + // Ensure state matches. + if gotState := q.Get("state"); gotState != state { + t.Errorf("state did not match, want=%q got=%q", state, gotState) + } + w.WriteHeader(http.StatusOK) + return + })) + defer oauth2Client.server.Close() + + // Register the client above with dex. + redirectURL := oauth2Client.server.URL + "/callback" + client := storage.Client{ + ID: "testclient", + Secret: "testclientsecret", + RedirectURIs: []string{redirectURL}, + } + if err := s.storage.CreateClient(client); err != nil { + t.Fatalf("failed to create client: %v", err) + } + + oauth2Client.config = &oauth2.Config{ + ClientID: client.ID, + ClientSecret: client.Secret, + Endpoint: p.Endpoint(), + Scopes: []string{oidc.ScopeOpenID, "email", "offline_access"}, + RedirectURL: redirectURL, + } + + if _, err = http.Get(oauth2Client.server.URL + "/login"); err != nil { + t.Fatalf("get failed: %v", err) + } + + tok := &oauth2.Token{ + RefreshToken: oauth2Client.token.RefreshToken, + Expiry: time.Now().Add(-time.Hour), + } + + // Login in again to recieve a new token. + if _, err = http.Get(oauth2Client.server.URL + "/login"); err != nil { + t.Fatalf("get failed: %v", err) + } + + // try to refresh expired token with old refresh token. + newToken, err := oauth2Client.config.TokenSource(ctx, tok).Token() + if newToken != nil { + t.Errorf("Token refreshed with invalid refresh token.") + } +} diff --git a/storage/conformance/conformance.go b/storage/conformance/conformance.go index 0a6fe1c9..01c62865 100644 --- a/storage/conformance/conformance.go +++ b/storage/conformance/conformance.go @@ -47,6 +47,7 @@ func RunTests(t *testing.T, newStorage func() storage.Storage) { {"RefreshTokenCRUD", testRefreshTokenCRUD}, {"PasswordCRUD", testPasswordCRUD}, {"KeysCRUD", testKeysCRUD}, + {"OfflineSessionCRUD", testOfflineSessionCRUD}, {"GarbageCollection", testGC}, {"TimezoneSupport", testTimezones}, }) @@ -340,6 +341,60 @@ func testPasswordCRUD(t *testing.T, s storage.Storage) { } +func testOfflineSessionCRUD(t *testing.T, s storage.Storage) { + session := storage.OfflineSessions{ + UserID: "User", + ConnID: "Conn", + Refresh: make(map[string]*storage.RefreshTokenRef), + } + + // Creating an OfflineSession with an empty Refresh list to ensure that + // an empty map is translated as expected by the storage. + if err := s.CreateOfflineSessions(session); err != nil { + t.Fatalf("create offline session: %v", err) + } + + getAndCompare := func(userID string, connID string, want storage.OfflineSessions) { + gr, err := s.GetOfflineSessions(userID, connID) + if err != nil { + t.Errorf("get offline session: %v", err) + return + } + if diff := pretty.Compare(want, gr); diff != "" { + t.Errorf("offline session retrieved from storage did not match: %s", diff) + } + } + + getAndCompare("User", "Conn", session) + + id := storage.NewID() + tokenRef := storage.RefreshTokenRef{ + ID: id, + ClientID: "client_id", + CreatedAt: time.Now().UTC().Round(time.Millisecond), + LastUsed: time.Now().UTC().Round(time.Millisecond), + } + session.Refresh[tokenRef.ClientID] = &tokenRef + + if err := s.UpdateOfflineSessions(session.UserID, session.ConnID, func(old storage.OfflineSessions) (storage.OfflineSessions, error) { + old.Refresh[tokenRef.ClientID] = &tokenRef + return old, nil + }); err != nil { + t.Fatalf("failed to update offline session: %v", err) + } + + getAndCompare("User", "Conn", session) + + if err := s.DeleteOfflineSessions(session.UserID, session.ConnID); err != nil { + t.Fatalf("failed to delete offline session: %v", err) + } + + if _, err := s.GetOfflineSessions(session.UserID, session.ConnID); err != storage.ErrNotFound { + t.Errorf("after deleting offline session expected storage.ErrNotFound, got %v", err) + } + +} + func testKeysCRUD(t *testing.T, s storage.Storage) { updateAndCompare := func(k storage.Keys) { err := s.UpdateKeys(func(oldKeys storage.Keys) (storage.Keys, error) { diff --git a/storage/kubernetes/client.go b/storage/kubernetes/client.go index b21fb63d..dc238a9c 100644 --- a/storage/kubernetes/client.go +++ b/storage/kubernetes/client.go @@ -58,6 +58,12 @@ func (c *client) idToName(s string) string { return idToName(s, c.hash) } +// offlineTokenName maps two arbitrary IDs, to a single Kubernetes object name. +// This is used when more than one field is used to uniquely identify the object. +func (c *client) offlineTokenName(userID string, connID string) string { + return offlineTokenName(userID, connID, c.hash) +} + // Kubernetes names must match the regexp '[a-z0-9]([-a-z0-9]*[a-z0-9])?'. var encoding = base32.NewEncoding("abcdefghijklmnopqrstuvwxyz234567") @@ -65,6 +71,12 @@ func idToName(s string, h func() hash.Hash) string { return strings.TrimRight(encoding.EncodeToString(h().Sum([]byte(s))), "=") } +func offlineTokenName(userID string, connID string, h func() hash.Hash) string { + h().Write([]byte(userID)) + h().Write([]byte(connID)) + return strings.TrimRight(encoding.EncodeToString(h().Sum(nil)), "=") +} + func (c *client) urlFor(apiVersion, namespace, resource, name string) string { basePath := "apis/" if apiVersion == "v1" { diff --git a/storage/kubernetes/storage.go b/storage/kubernetes/storage.go index 102a7494..421b74c9 100644 --- a/storage/kubernetes/storage.go +++ b/storage/kubernetes/storage.go @@ -15,21 +15,23 @@ import ( ) const ( - kindAuthCode = "AuthCode" - kindAuthRequest = "AuthRequest" - kindClient = "OAuth2Client" - kindRefreshToken = "RefreshToken" - kindKeys = "SigningKey" - kindPassword = "Password" + kindAuthCode = "AuthCode" + kindAuthRequest = "AuthRequest" + kindClient = "OAuth2Client" + kindRefreshToken = "RefreshToken" + kindKeys = "SigningKey" + kindPassword = "Password" + kindOfflineSessions = "OfflineSessions" ) const ( - resourceAuthCode = "authcodes" - resourceAuthRequest = "authrequests" - resourceClient = "oauth2clients" - resourceRefreshToken = "refreshtokens" - resourceKeys = "signingkeies" // Kubernetes attempts to pluralize. - resourcePassword = "passwords" + resourceAuthCode = "authcodes" + resourceAuthRequest = "authrequests" + resourceClient = "oauth2clients" + resourceRefreshToken = "refreshtokens" + resourceKeys = "signingkeies" // Kubernetes attempts to pluralize. + resourcePassword = "passwords" + resourceOfflineSessions = "offlinesessions" ) // Config values for the Kubernetes storage type. @@ -156,6 +158,10 @@ func (cli *client) CreateRefresh(r storage.RefreshToken) error { return cli.post(resourceRefreshToken, cli.fromStorageRefreshToken(r)) } +func (cli *client) CreateOfflineSessions(o storage.OfflineSessions) error { + return cli.post(resourceOfflineSessions, cli.fromStorageOfflineSessions(o)) +} + func (cli *client) GetAuthRequest(id string) (storage.AuthRequest, error) { var req AuthRequest if err := cli.get(resourceAuthRequest, id, &req); err != nil { @@ -235,6 +241,25 @@ func (cli *client) getRefreshToken(id string) (r RefreshToken, err error) { return } +func (cli *client) GetOfflineSessions(userID string, connID string) (storage.OfflineSessions, error) { + o, err := cli.getOfflineSessions(userID, connID) + if err != nil { + return storage.OfflineSessions{}, err + } + return toStorageOfflineSessions(o), nil +} + +func (cli *client) getOfflineSessions(userID string, connID string) (o OfflineSessions, err error) { + name := cli.offlineTokenName(userID, connID) + if err = cli.get(resourceOfflineSessions, name, &o); err != nil { + return OfflineSessions{}, err + } + if userID != o.UserID || connID != o.ConnID { + return OfflineSessions{}, fmt.Errorf("get offline session: wrong session retrieved") + } + return o, nil +} + func (cli *client) ListClients() ([]storage.Client, error) { return nil, errors.New("not implemented") } @@ -292,6 +317,15 @@ func (cli *client) DeletePassword(email string) error { return cli.delete(resourcePassword, p.ObjectMeta.Name) } +func (cli *client) DeleteOfflineSessions(userID string, connID string) error { + // Check for hash collition. + o, err := cli.getOfflineSessions(userID, connID) + if err != nil { + return err + } + return cli.delete(resourceOfflineSessions, o.ObjectMeta.Name) +} + func (cli *client) UpdateRefreshToken(id string, updater func(old storage.RefreshToken) (storage.RefreshToken, error)) error { r, err := cli.getRefreshToken(id) if err != nil { @@ -342,6 +376,22 @@ func (cli *client) UpdatePassword(email string, updater func(old storage.Passwor return cli.put(resourcePassword, p.ObjectMeta.Name, newPassword) } +func (cli *client) UpdateOfflineSessions(userID string, connID string, updater func(old storage.OfflineSessions) (storage.OfflineSessions, error)) error { + o, err := cli.getOfflineSessions(userID, connID) + if err != nil { + return err + } + + updated, err := updater(toStorageOfflineSessions(o)) + if err != nil { + return err + } + + newOfflineSessions := cli.fromStorageOfflineSessions(updated) + newOfflineSessions.ObjectMeta = o.ObjectMeta + return cli.put(resourceOfflineSessions, o.ObjectMeta.Name, newOfflineSessions) +} + func (cli *client) UpdateKeys(updater func(old storage.Keys) (storage.Keys, error)) error { firstUpdate := false var keys Keys diff --git a/storage/kubernetes/types.go b/storage/kubernetes/types.go index 660f86d8..81dc01c3 100644 --- a/storage/kubernetes/types.go +++ b/storage/kubernetes/types.go @@ -66,6 +66,14 @@ var thirdPartyResources = []k8sapi.ThirdPartyResource{ Description: "Passwords managed by the OIDC server.", Versions: []k8sapi.APIVersion{{Name: "v1"}}, }, + { + ObjectMeta: k8sapi.ObjectMeta{ + Name: "offline-sessions.oidc.coreos.com", + }, + TypeMeta: tprMeta, + Description: "User sessions with an active refresh token.", + Versions: []k8sapi.APIVersion{{Name: "v1"}}, + }, } // There will only ever be a single keys resource. Maintain this by setting a @@ -465,3 +473,38 @@ func toStorageKeys(keys Keys) storage.Keys { NextRotation: keys.NextRotation, } } + +// OfflineSessions is a mirrored struct from storage with JSON struct tags and Kubernetes +// type metadata. +type OfflineSessions struct { + k8sapi.TypeMeta `json:",inline"` + k8sapi.ObjectMeta `json:"metadata,omitempty"` + + UserID string `json:"userID,omitempty"` + ConnID string `json:"connID,omitempty"` + Refresh map[string]*storage.RefreshTokenRef `json:"refresh,omitempty"` +} + +func (cli *client) fromStorageOfflineSessions(o storage.OfflineSessions) OfflineSessions { + return OfflineSessions{ + TypeMeta: k8sapi.TypeMeta{ + Kind: kindOfflineSessions, + APIVersion: cli.apiVersion, + }, + ObjectMeta: k8sapi.ObjectMeta{ + Name: cli.offlineTokenName(o.UserID, o.ConnID), + Namespace: cli.namespace, + }, + UserID: o.UserID, + ConnID: o.ConnID, + Refresh: o.Refresh, + } +} + +func toStorageOfflineSessions(o OfflineSessions) storage.OfflineSessions { + return storage.OfflineSessions{ + UserID: o.UserID, + ConnID: o.ConnID, + Refresh: o.Refresh, + } +} diff --git a/storage/memory/memory.go b/storage/memory/memory.go index 8bfbdce2..ac0b1d4e 100644 --- a/storage/memory/memory.go +++ b/storage/memory/memory.go @@ -13,12 +13,13 @@ import ( // New returns an in memory storage. func New(logger logrus.FieldLogger) storage.Storage { return &memStorage{ - clients: make(map[string]storage.Client), - authCodes: make(map[string]storage.AuthCode), - refreshTokens: make(map[string]storage.RefreshToken), - authReqs: make(map[string]storage.AuthRequest), - passwords: make(map[string]storage.Password), - logger: logger, + clients: make(map[string]storage.Client), + authCodes: make(map[string]storage.AuthCode), + refreshTokens: make(map[string]storage.RefreshToken), + authReqs: make(map[string]storage.AuthRequest), + passwords: make(map[string]storage.Password), + offlineSessions: make(map[offlineSessionID]storage.OfflineSessions), + logger: logger, } } @@ -37,17 +38,23 @@ func (c *Config) Open(logger logrus.FieldLogger) (storage.Storage, error) { type memStorage struct { mu sync.Mutex - clients map[string]storage.Client - authCodes map[string]storage.AuthCode - refreshTokens map[string]storage.RefreshToken - authReqs map[string]storage.AuthRequest - passwords map[string]storage.Password + clients map[string]storage.Client + authCodes map[string]storage.AuthCode + refreshTokens map[string]storage.RefreshToken + authReqs map[string]storage.AuthRequest + passwords map[string]storage.Password + offlineSessions map[offlineSessionID]storage.OfflineSessions keys storage.Keys logger logrus.FieldLogger } +type offlineSessionID struct { + userID string + connID string +} + func (s *memStorage) tx(f func()) { s.mu.Lock() defer s.mu.Unlock() @@ -130,6 +137,32 @@ func (s *memStorage) CreatePassword(p storage.Password) (err error) { return } +func (s *memStorage) CreateOfflineSessions(o storage.OfflineSessions) (err error) { + id := offlineSessionID{ + userID: o.UserID, + connID: o.ConnID, + } + s.tx(func() { + if _, ok := s.offlineSessions[id]; ok { + err = storage.ErrAlreadyExists + } else { + s.offlineSessions[id] = o + } + }) + return +} + +func (s *memStorage) GetAuthCode(id string) (c storage.AuthCode, err error) { + s.tx(func() { + var ok bool + if c, ok = s.authCodes[id]; !ok { + err = storage.ErrNotFound + return + } + }) + return +} + func (s *memStorage) GetPassword(email string) (p storage.Password, err error) { email = strings.ToLower(email) s.tx(func() { @@ -156,10 +189,10 @@ func (s *memStorage) GetKeys() (keys storage.Keys, err error) { return } -func (s *memStorage) GetRefresh(token string) (tok storage.RefreshToken, err error) { +func (s *memStorage) GetRefresh(id string) (tok storage.RefreshToken, err error) { s.tx(func() { var ok bool - if tok, ok = s.refreshTokens[token]; !ok { + if tok, ok = s.refreshTokens[id]; !ok { err = storage.ErrNotFound return } @@ -178,6 +211,21 @@ func (s *memStorage) GetAuthRequest(id string) (req storage.AuthRequest, err err return } +func (s *memStorage) GetOfflineSessions(userID string, connID string) (o storage.OfflineSessions, err error) { + id := offlineSessionID{ + userID: userID, + connID: connID, + } + s.tx(func() { + var ok bool + if o, ok = s.offlineSessions[id]; !ok { + err = storage.ErrNotFound + return + } + }) + return +} + func (s *memStorage) ListClients() (clients []storage.Client, err error) { s.tx(func() { for _, client := range s.clients { @@ -228,13 +276,13 @@ func (s *memStorage) DeleteClient(id string) (err error) { return } -func (s *memStorage) DeleteRefresh(token string) (err error) { +func (s *memStorage) DeleteRefresh(id string) (err error) { s.tx(func() { - if _, ok := s.refreshTokens[token]; !ok { + if _, ok := s.refreshTokens[id]; !ok { err = storage.ErrNotFound return } - delete(s.refreshTokens, token) + delete(s.refreshTokens, id) }) return } @@ -261,13 +309,17 @@ func (s *memStorage) DeleteAuthRequest(id string) (err error) { return } -func (s *memStorage) GetAuthCode(id string) (c storage.AuthCode, err error) { +func (s *memStorage) DeleteOfflineSessions(userID string, connID string) (err error) { + id := offlineSessionID{ + userID: userID, + connID: connID, + } s.tx(func() { - var ok bool - if c, ok = s.authCodes[id]; !ok { + if _, ok := s.offlineSessions[id]; !ok { err = storage.ErrNotFound return } + delete(s.offlineSessions, id) }) return } @@ -338,3 +390,21 @@ func (s *memStorage) UpdateRefreshToken(id string, updater func(p storage.Refres }) return } + +func (s *memStorage) UpdateOfflineSessions(userID string, connID string, updater func(o storage.OfflineSessions) (storage.OfflineSessions, error)) (err error) { + id := offlineSessionID{ + userID: userID, + connID: connID, + } + s.tx(func() { + r, ok := s.offlineSessions[id] + if !ok { + err = storage.ErrNotFound + return + } + if r, err = updater(r); err == nil { + s.offlineSessions[id] = r + } + }) + return +} diff --git a/storage/sql/crud.go b/storage/sql/crud.go index 494f1c20..ef1a8fbd 100644 --- a/storage/sql/crud.go +++ b/storage/sql/crud.go @@ -624,6 +624,75 @@ func scanPassword(s scanner) (p storage.Password, err error) { return p, nil } +func (c *conn) CreateOfflineSessions(s storage.OfflineSessions) error { + _, err := c.Exec(` + insert into offline_session ( + user_id, conn_id, refresh + ) + values ( + $1, $2, $3 + ); + `, + s.UserID, s.ConnID, encoder(s.Refresh), + ) + if err != nil { + return fmt.Errorf("insert offline session: %v", err) + } + return nil +} + +func (c *conn) UpdateOfflineSessions(userID string, connID string, updater func(s storage.OfflineSessions) (storage.OfflineSessions, error)) error { + return c.ExecTx(func(tx *trans) error { + s, err := getOfflineSessions(tx, userID, connID) + if err != nil { + return err + } + + newSession, err := updater(s) + if err != nil { + return err + } + _, err = tx.Exec(` + update offline_session + set + refresh = $1 + where user_id = $2 AND conn_id = $3; + `, + encoder(newSession.Refresh), s.UserID, s.ConnID, + ) + if err != nil { + return fmt.Errorf("update offline session: %v", err) + } + return nil + }) +} + +func (c *conn) GetOfflineSessions(userID string, connID string) (storage.OfflineSessions, error) { + return getOfflineSessions(c, userID, connID) +} + +func getOfflineSessions(q querier, userID string, connID string) (storage.OfflineSessions, error) { + return scanOfflineSessions(q.QueryRow(` + select + user_id, conn_id, refresh + from offline_session + where user_id = $1 AND conn_id = $2; + `, userID, connID)) +} + +func scanOfflineSessions(s scanner) (o storage.OfflineSessions, err error) { + err = s.Scan( + &o.UserID, &o.ConnID, decoder(&o.Refresh), + ) + if err != nil { + if err == sql.ErrNoRows { + return o, storage.ErrNotFound + } + return o, fmt.Errorf("select offline session: %v", err) + } + return o, nil +} + func (c *conn) DeleteAuthRequest(id string) error { return c.delete("auth_request", "id", id) } func (c *conn) DeleteAuthCode(id string) error { return c.delete("auth_code", "id", id) } func (c *conn) DeleteClient(id string) error { return c.delete("client", "id", id) } @@ -632,6 +701,24 @@ func (c *conn) DeletePassword(email string) error { return c.delete("password", "email", strings.ToLower(email)) } +func (c *conn) DeleteOfflineSessions(userID string, connID string) error { + result, err := c.Exec(`delete from offline_session where user_id = $1 AND conn_id = $2`, userID, connID) + if err != nil { + return fmt.Errorf("delete offline_session: user_id = %s, conn_id = %s", userID, connID) + } + + // For now mandate that the driver implements RowsAffected. If we ever need to support + // a driver that doesn't implement this, we can run this in a transaction with a get beforehand. + n, err := result.RowsAffected() + if err != nil { + return fmt.Errorf("rows affected: %v", err) + } + if n < 1 { + return storage.ErrNotFound + } + return nil +} + // Do NOT call directly. Does not escape table. func (c *conn) delete(table, field, id string) error { result, err := c.Exec(`delete from `+table+` where `+field+` = $1`, id) diff --git a/storage/sql/migrate.go b/storage/sql/migrate.go index b2b66d39..07ba4a22 100644 --- a/storage/sql/migrate.go +++ b/storage/sql/migrate.go @@ -153,6 +153,7 @@ var migrations = []migration{ signing_key_pub bytea not null, -- JSON object next_rotation timestamptz not null ); + `, }, { @@ -165,4 +166,14 @@ var migrations = []migration{ add column last_used timestamptz not null default '0001-01-01 00:00:00 UTC'; `, }, + { + stmt: ` + create table offline_session ( + user_id text not null, + conn_id text not null, + refresh bytea not null, + PRIMARY KEY (user_id, conn_id) + ); + `, + }, } diff --git a/storage/storage.go b/storage/storage.go index 3d27e6f7..869b7066 100644 --- a/storage/storage.go +++ b/storage/storage.go @@ -52,6 +52,7 @@ type Storage interface { CreateAuthCode(c AuthCode) error CreateRefresh(r RefreshToken) error CreatePassword(p Password) error + CreateOfflineSessions(s OfflineSessions) error // TODO(ericchiang): return (T, bool, error) so we can indicate not found // requests that way instead of using ErrNotFound. @@ -61,6 +62,7 @@ type Storage interface { GetKeys() (Keys, error) GetRefresh(id string) (RefreshToken, error) GetPassword(email string) (Password, error) + GetOfflineSessions(userID string, connID string) (OfflineSessions, error) ListClients() ([]Client, error) ListRefreshTokens() ([]RefreshToken, error) @@ -72,6 +74,7 @@ type Storage interface { DeleteClient(id string) error DeleteRefresh(id string) error DeletePassword(email string) error + DeleteOfflineSessions(userID string, connID string) error // Update methods take a function for updating an object then performs that update within // a transaction. "updater" functions may be called multiple times by a single update call. @@ -92,6 +95,7 @@ type Storage interface { UpdateAuthRequest(id string, updater func(a AuthRequest) (AuthRequest, error)) error UpdateRefreshToken(id string, updater func(r RefreshToken) (RefreshToken, error)) error UpdatePassword(email string, updater func(p Password) (Password, error)) error + UpdateOfflineSessions(userID string, connID string, updater func(s OfflineSessions) (OfflineSessions, error)) error // GarbageCollect deletes all expired AuthCodes and AuthRequests. GarbageCollect(now time.Time) (GCResult, error) @@ -241,6 +245,30 @@ type RefreshToken struct { Nonce string } +// RefreshTokenRef is a reference object that contains metadata about refresh tokens. +type RefreshTokenRef struct { + ID string + + // Client the refresh token is valid for. + ClientID string + + CreatedAt time.Time + LastUsed time.Time +} + +// OfflineSessions objects are sessions pertaining to users with refresh tokens. +type OfflineSessions struct { + // UserID of an end user who has logged in to the server. + UserID string + + // The ID of the connector used to login the user. + ConnID string + + // Refresh is a hash table of refresh token reference objects + // indexed by the ClientID of the refresh token. + Refresh map[string]*RefreshTokenRef +} + // Password is an email to password mapping managed by the storage. type Password struct { // Email and identifying name of the password. Emails are assumed to be valid and