From 4b5f1d52894625db361d4208ea4303c27787a7ae Mon Sep 17 00:00:00 2001 From: "m.nabokikh" Date: Sun, 2 Oct 2022 22:40:22 +0200 Subject: [PATCH] fix: refresh token only once for all concurrent requests Signed-off-by: m.nabokikh --- server/refreshhandlers.go | 205 +++++++++++++++++------------ storage/kubernetes/lock.go | 124 +++++++++++++++++ storage/kubernetes/storage.go | 9 ++ storage/kubernetes/storage_test.go | 104 ++++++++++++++- 4 files changed, 357 insertions(+), 85 deletions(-) create mode 100644 storage/kubernetes/lock.go diff --git a/server/refreshhandlers.go b/server/refreshhandlers.go index 0c9b1393..ecfda137 100644 --- a/server/refreshhandlers.go +++ b/server/refreshhandlers.go @@ -28,6 +28,10 @@ type refreshError struct { desc string } +func (r *refreshError) Error() string { + return fmt.Sprintf("refresh token error: status %d, %q %s", r.code, r.msg, r.desc) +} + func newInternalServerError() *refreshError { return &refreshError{msg: errInvalidRequest, desc: "", code: http.StatusInternalServerError} } @@ -60,10 +64,23 @@ func (s *Server) extractRefreshTokenFromRequest(r *http.Request) (*internal.Refr return token, nil } +type refreshContext struct { + storageToken *storage.RefreshToken + requestToken *internal.RefreshToken + + connector Connector + connectorData []byte + + scopes []string +} + // getRefreshTokenFromStorage checks that refresh token is valid and exists in the storage and gets its info -func (s *Server) getRefreshTokenFromStorage(clientID string, token *internal.RefreshToken) (*storage.RefreshToken, *refreshError) { +func (s *Server) getRefreshTokenFromStorage(clientID string, token *internal.RefreshToken) (*refreshContext, *refreshError) { + refreshCtx := refreshContext{requestToken: token} + invalidErr := newBadRequestError("Refresh token is invalid or has already been claimed by another client.") + // Get RefreshToken refresh, err := s.storage.GetRefresh(token.RefreshId) if err != nil { if err != storage.ErrNotFound { @@ -103,7 +120,31 @@ func (s *Server) getRefreshTokenFromStorage(clientID string, token *internal.Ref return nil, expiredErr } - return &refresh, nil + refreshCtx.storageToken = &refresh + + // Get Connector + refreshCtx.connector, err = s.getConnector(refresh.ConnectorID) + if err != nil { + s.logger.Errorf("connector with ID %q not found: %v", refresh.ConnectorID, err) + return nil, newInternalServerError() + } + + // Get Connector Data + session, err := s.storage.GetOfflineSessions(refresh.Claims.UserID, refresh.ConnectorID) + switch { + case err != nil: + if err != storage.ErrNotFound { + s.logger.Errorf("failed to get offline session: %v", err) + return nil, newInternalServerError() + } + case len(refresh.ConnectorData) > 0: + // Use the old connector data if it exists, should be deleted once used + refreshCtx.connectorData = refresh.ConnectorData + default: + refreshCtx.connectorData = session.ConnectorData + } + + return &refreshCtx, nil } func (s *Server) getRefreshScopes(r *http.Request, refresh *storage.RefreshToken) ([]string, *refreshError) { @@ -138,59 +179,23 @@ func (s *Server) getRefreshScopes(r *http.Request, refresh *storage.RefreshToken return requestedScopes, nil } -func (s *Server) refreshWithConnector(ctx context.Context, token *internal.RefreshToken, refresh *storage.RefreshToken, scopes []string) (connector.Identity, *refreshError) { - var connectorData []byte - - session, err := s.storage.GetOfflineSessions(refresh.Claims.UserID, refresh.ConnectorID) - switch { - case err != nil: - if err != storage.ErrNotFound { - s.logger.Errorf("failed to get offline session: %v", err) - return connector.Identity{}, newInternalServerError() - } - case len(refresh.ConnectorData) > 0: - // Use the old connector data if it exists, should be deleted once used - connectorData = refresh.ConnectorData - default: - connectorData = session.ConnectorData - } - - conn, err := s.getConnector(refresh.ConnectorID) - if err != nil { - s.logger.Errorf("connector with ID %q not found: %v", refresh.ConnectorID, err) - return connector.Identity{}, newInternalServerError() - } - - ident := connector.Identity{ - UserID: refresh.Claims.UserID, - Username: refresh.Claims.Username, - PreferredUsername: refresh.Claims.PreferredUsername, - Email: refresh.Claims.Email, - EmailVerified: refresh.Claims.EmailVerified, - Groups: refresh.Claims.Groups, - ConnectorData: connectorData, - } - - // user's token was previously updated by a connector and is allowed to reuse - // it is excessive to refresh identity in upstream - if s.refreshTokenPolicy.AllowedToReuse(refresh.LastUsed) && token.Token == refresh.ObsoleteToken { - return ident, nil - } - +func (s *Server) refreshWithConnector(ctx context.Context, rCtx *refreshContext, ident connector.Identity) (connector.Identity, *refreshError) { // Can the connector refresh the identity? If so, attempt to refresh the data // in the connector. // // TODO(ericchiang): We may want a strict mode where connectors that don't implement // this interface can't perform refreshing. - if refreshConn, ok := conn.Connector.(connector.RefreshConnector); ok { - newIdent, err := refreshConn.Refresh(ctx, parseScopes(scopes), ident) + if refreshConn, ok := rCtx.connector.Connector.(connector.RefreshConnector); ok { + s.logger.Debugf("connector data before refresh: %s", ident.ConnectorData) + + newIdent, err := refreshConn.Refresh(ctx, parseScopes(rCtx.scopes), ident) if err != nil { s.logger.Errorf("failed to refresh identity: %v", err) - return connector.Identity{}, newInternalServerError() + return ident, newInternalServerError() } - ident = newIdent - } + return newIdent, nil + } return ident, nil } @@ -200,8 +205,14 @@ func (s *Server) updateOfflineSession(refresh *storage.RefreshToken, ident conne if old.Refresh[refresh.ClientID].ID != refresh.ID { return old, errors.New("refresh token invalid") } + old.Refresh[refresh.ClientID].LastUsed = lastUsed - old.ConnectorData = ident.ConnectorData + if len(ident.ConnectorData) > 0 { + old.ConnectorData = ident.ConnectorData + } + + s.logger.Debugf("saved connector data: %s %s", ident.UserID, ident.ConnectorData) + return old, nil } @@ -217,33 +228,74 @@ func (s *Server) updateOfflineSession(refresh *storage.RefreshToken, ident conne } // updateRefreshToken updates refresh token and offline session in the storage -func (s *Server) updateRefreshToken(token *internal.RefreshToken, refresh *storage.RefreshToken, ident connector.Identity) (*internal.RefreshToken, *refreshError) { - newToken := token - if s.refreshTokenPolicy.RotationEnabled() { - newToken = &internal.RefreshToken{ - RefreshId: refresh.ID, - Token: storage.NewID(), - } +func (s *Server) updateRefreshToken(ctx context.Context, rCtx *refreshContext) (*internal.RefreshToken, connector.Identity, *refreshError) { + var rerr *refreshError + + newToken := &internal.RefreshToken{ + Token: rCtx.requestToken.Token, + RefreshId: rCtx.requestToken.RefreshId, } lastUsed := s.now() + ident := connector.Identity{ + UserID: rCtx.storageToken.Claims.UserID, + Username: rCtx.storageToken.Claims.Username, + PreferredUsername: rCtx.storageToken.Claims.PreferredUsername, + Email: rCtx.storageToken.Claims.Email, + EmailVerified: rCtx.storageToken.Claims.EmailVerified, + Groups: rCtx.storageToken.Claims.Groups, + ConnectorData: rCtx.connectorData, + } + refreshTokenUpdater := func(old storage.RefreshToken) (storage.RefreshToken, error) { - if s.refreshTokenPolicy.RotationEnabled() { - if old.Token != token.Token { - if s.refreshTokenPolicy.AllowedToReuse(old.LastUsed) && old.ObsoleteToken == token.Token { - newToken.Token = old.Token - // Do not update last used time for offline session if token is allowed to be reused - lastUsed = old.LastUsed - return old, nil - } + rotationEnabled := s.refreshTokenPolicy.RotationEnabled() + reusingAllowed := s.refreshTokenPolicy.AllowedToReuse(old.LastUsed) + + switch { + case !rotationEnabled && reusingAllowed: + // If rotation is disabled and the offline session was updated not so long ago - skip further actions. + return old, nil + + case rotationEnabled && reusingAllowed: + if old.Token != rCtx.requestToken.Token && old.ObsoleteToken != rCtx.requestToken.Token { return old, errors.New("refresh token claimed twice") } + // Return previously generated token for all requests with an obsolete tokens + if old.ObsoleteToken == rCtx.requestToken.Token { + newToken.Token = old.Token + } + + // Do not update last used time for offline session if token is allowed to be reused + lastUsed = old.LastUsed + ident.ConnectorData = nil + return old, nil + + case rotationEnabled && !reusingAllowed: + if old.Token != rCtx.requestToken.Token { + return old, errors.New("refresh token claimed twice") + } + + // Issue new refresh token old.ObsoleteToken = old.Token + newToken.Token = storage.NewID() } old.Token = newToken.Token + old.LastUsed = lastUsed + + // ConnectorData has been moved to OfflineSession + old.ConnectorData = []byte{} + + // Call only once if there is a request which is not in the reuse interval. + // This is required to avoid multiple calls to the external IdP for concurrent requests. + // Dex will call the connector's Refresh method only once if request is not in reuse interval. + ident, rerr = s.refreshWithConnector(ctx, rCtx, ident) + if rerr != nil { + return old, rerr + } + // Update the claims of the refresh token. // // UserID intentionally ignored for now. @@ -252,26 +304,23 @@ func (s *Server) updateRefreshToken(token *internal.RefreshToken, refresh *stora old.Claims.Email = ident.Email old.Claims.EmailVerified = ident.EmailVerified old.Claims.Groups = ident.Groups - old.LastUsed = lastUsed - // ConnectorData has been moved to OfflineSession - old.ConnectorData = []byte{} return old, nil } // Update refresh token in the storage. - err := s.storage.UpdateRefreshToken(refresh.ID, refreshTokenUpdater) + err := s.storage.UpdateRefreshToken(rCtx.storageToken.ID, refreshTokenUpdater) if err != nil { s.logger.Errorf("failed to update refresh token: %v", err) - return nil, newInternalServerError() + return nil, ident, newInternalServerError() } - rerr := s.updateOfflineSession(refresh, ident, lastUsed) + rerr = s.updateOfflineSession(rCtx.storageToken, ident, lastUsed) if rerr != nil { - return nil, rerr + return nil, ident, rerr } - return newToken, nil + return newToken, ident, nil } // handleRefreshToken handles a refresh token request https://tools.ietf.org/html/rfc6749#section-6 @@ -283,19 +332,19 @@ func (s *Server) handleRefreshToken(w http.ResponseWriter, r *http.Request, clie return } - refresh, rerr := s.getRefreshTokenFromStorage(client.ID, token) + rCtx, rerr := s.getRefreshTokenFromStorage(client.ID, token) if rerr != nil { s.refreshTokenErrHelper(w, rerr) return } - scopes, rerr := s.getRefreshScopes(r, refresh) + rCtx.scopes, rerr = s.getRefreshScopes(r, rCtx.storageToken) if rerr != nil { s.refreshTokenErrHelper(w, rerr) return } - ident, rerr := s.refreshWithConnector(r.Context(), token, refresh, scopes) + newToken, ident, rerr := s.updateRefreshToken(r.Context(), rCtx) if rerr != nil { s.refreshTokenErrHelper(w, rerr) return @@ -310,26 +359,20 @@ func (s *Server) handleRefreshToken(w http.ResponseWriter, r *http.Request, clie Groups: ident.Groups, } - accessToken, err := s.newAccessToken(client.ID, claims, scopes, refresh.Nonce, refresh.ConnectorID) + accessToken, err := s.newAccessToken(client.ID, claims, rCtx.scopes, rCtx.storageToken.Nonce, rCtx.storageToken.ConnectorID) if err != nil { s.logger.Errorf("failed to create new access token: %v", err) s.refreshTokenErrHelper(w, newInternalServerError()) return } - idToken, expiry, err := s.newIDToken(client.ID, claims, scopes, refresh.Nonce, accessToken, "", refresh.ConnectorID) + idToken, expiry, err := s.newIDToken(client.ID, claims, rCtx.scopes, rCtx.storageToken.Nonce, accessToken, "", rCtx.storageToken.ConnectorID) if err != nil { s.logger.Errorf("failed to create ID token: %v", err) s.refreshTokenErrHelper(w, newInternalServerError()) return } - newToken, rerr := s.updateRefreshToken(token, refresh, ident) - if rerr != nil { - s.refreshTokenErrHelper(w, rerr) - return - } - rawNewToken, err := internal.Marshal(newToken) if err != nil { s.logger.Errorf("failed to marshal refresh token: %v", err) diff --git a/storage/kubernetes/lock.go b/storage/kubernetes/lock.go new file mode 100644 index 00000000..798c1f16 --- /dev/null +++ b/storage/kubernetes/lock.go @@ -0,0 +1,124 @@ +package kubernetes + +import ( + "fmt" + "time" +) + +const ( + lockAnnotation = "dexidp.com/resource-lock" + lockTimeFormat = time.RFC3339 +) + +var ( + lockTimeout = 10 * time.Second + lockCheckPeriod = 100 * time.Millisecond +) + +// refreshTokenLock is an implementation of annotation-based optimistic locking. +// +// Refresh token contains data to refresh identity in external authentication system. +// There is a requirement that refresh should be called only once because of several reasons: +// * Some of OIDC providers could use the refresh token rotation feature which requires calling refresh only once. +// * Providers can limit the rate of requests to the token endpoint, which will lead to the error +// in case of many concurrent requests. +type refreshTokenLock struct { + cli *client + waitingState bool +} + +func newRefreshTokenLock(cli *client) *refreshTokenLock { + return &refreshTokenLock{cli: cli} +} + +func (l *refreshTokenLock) Lock(id string) error { + for i := 0; i <= 60; i++ { + ok, err := l.setLockAnnotation(id) + if err != nil { + return err + } + if !ok { + return nil + } + time.Sleep(lockCheckPeriod) + } + return fmt.Errorf("timeout waiting for refresh token %s lock", id) +} + +func (l *refreshTokenLock) Unlock(id string) { + if l.waitingState { + // Do not need to unlock for waiting goroutines, because the have not set it. + return + } + + r, err := l.cli.getRefreshToken(id) + if err != nil { + l.cli.logger.Debugf("failed to get resource to release lock for refresh token %s: %v", id, err) + return + } + + r.Annotations = nil + err = l.cli.put(resourceRefreshToken, r.ObjectMeta.Name, r) + if err != nil { + l.cli.logger.Debugf("failed to release lock for refresh token %s: %v", id, err) + } +} + +func (l *refreshTokenLock) setLockAnnotation(id string) (bool, error) { + r, err := l.cli.getRefreshToken(id) + if err != nil { + return false, err + } + + currentTime := time.Now() + lockData := map[string]string{ + lockAnnotation: currentTime.Add(lockTimeout).Format(lockTimeFormat), + } + + val, ok := r.Annotations[lockAnnotation] + if !ok { + if l.waitingState { + return false, nil + } + + r.Annotations = lockData + err := l.cli.put(resourceRefreshToken, r.ObjectMeta.Name, r) + if err == nil { + return false, nil + } + + if isKubernetesAPIConflictError(err) { + l.waitingState = true + return true, nil + } + return false, err + } + + until, err := time.Parse(lockTimeFormat, val) + if err != nil { + return false, fmt.Errorf("lock annotation value is malformed: %v", err) + } + + if !currentTime.After(until) { + // waiting for the lock to be released + l.waitingState = true + return true, nil + } + + // Lock time is out, lets break the lock and take the advantage + r.Annotations = lockData + + err = l.cli.put(resourceRefreshToken, r.ObjectMeta.Name, r) + if err == nil { + // break lock annotation + return false, nil + } + + l.cli.logger.Debugf("break lock annotation error: %v", err) + if isKubernetesAPIConflictError(err) { + l.waitingState = true + // after breaking error waiting for the lock to be released + return true, nil + } + return false, err +} diff --git a/storage/kubernetes/storage.go b/storage/kubernetes/storage.go index 033d3e23..0979f14a 100644 --- a/storage/kubernetes/storage.go +++ b/storage/kubernetes/storage.go @@ -451,11 +451,19 @@ func (cli *client) DeleteConnector(id string) error { } func (cli *client) UpdateRefreshToken(id string, updater func(old storage.RefreshToken) (storage.RefreshToken, error)) error { + lock := newRefreshTokenLock(cli) + + if err := lock.Lock(id); err != nil { + return err + } + defer lock.Unlock(id) + return retryOnConflict(context.TODO(), func() error { r, err := cli.getRefreshToken(id) if err != nil { return err } + updated, err := updater(toStorageRefreshToken(r)) if err != nil { return err @@ -464,6 +472,7 @@ func (cli *client) UpdateRefreshToken(id string, updater func(old storage.Refres newToken := cli.fromStorageRefreshToken(updated) newToken.ObjectMeta = r.ObjectMeta + return cli.put(resourceRefreshToken, r.ObjectMeta.Name, newToken) }) } diff --git a/storage/kubernetes/storage_test.go b/storage/kubernetes/storage_test.go index 828ff472..d0048506 100644 --- a/storage/kubernetes/storage_test.go +++ b/storage/kubernetes/storage_test.go @@ -11,6 +11,7 @@ import ( "path/filepath" "strings" "testing" + "time" "github.com/sirupsen/logrus" "github.com/stretchr/testify/require" @@ -35,19 +36,22 @@ type StorageTestSuite struct { client *client } -func (s *StorageTestSuite) expandDir(dir string) string { +func expandDir(dir string) (string, error) { dir = strings.Trim(dir, `"`) if strings.HasPrefix(dir, "~/") { homedir, err := os.UserHomeDir() - s.Require().NoError(err) + if err != nil { + return "", err + } dir = filepath.Join(homedir, strings.TrimPrefix(dir, "~/")) } - return dir + return dir, nil } func (s *StorageTestSuite) SetupTest() { - kubeconfigPath := s.expandDir(os.Getenv(kubeconfigPathVariableName)) + kubeconfigPath, err := expandDir(os.Getenv(kubeconfigPathVariableName)) + s.Require().NoError(err) config := Config{ KubeConfigFile: kubeconfigPath, @@ -292,3 +296,95 @@ func TestRetryOnConflict(t *testing.T) { }) } } + +func TestRefreshTokenLock(t *testing.T) { + if os.Getenv(kubeconfigPathVariableName) == "" { + t.Skipf("variable %q not set, skipping kubernetes storage tests\n", kubeconfigPathVariableName) + } + + kubeconfigPath, err := expandDir(os.Getenv(kubeconfigPathVariableName)) + require.NoError(t, err) + + config := Config{ + KubeConfigFile: kubeconfigPath, + } + + logger := &logrus.Logger{ + Out: os.Stderr, + Formatter: &logrus.TextFormatter{DisableColors: true}, + Level: logrus.DebugLevel, + } + + kubeClient, err := config.open(logger, true) + require.NoError(t, err) + + lockCheckPeriod = time.Nanosecond + + // Creating a storage with an existing refresh token and offline session for the user. + id := storage.NewID() + r := storage.RefreshToken{ + ID: id, + Token: "bar", + Nonce: "foo", + ClientID: "client_id", + ConnectorID: "client_secret", + Scopes: []string{"openid", "email", "profile"}, + CreatedAt: time.Now().UTC().Round(time.Millisecond), + LastUsed: time.Now().UTC().Round(time.Millisecond), + Claims: storage.Claims{ + UserID: "1", + Username: "jane", + Email: "jane.doe@example.com", + EmailVerified: true, + Groups: []string{"a", "b"}, + }, + ConnectorData: []byte(`{"some":"data"}`), + } + + err = kubeClient.CreateRefresh(r) + require.NoError(t, err) + + t.Run("Timeout lock error", func(t *testing.T) { + err = kubeClient.UpdateRefreshToken(r.ID, func(r storage.RefreshToken) (storage.RefreshToken, error) { + r.Token = "update-result-1" + err := kubeClient.UpdateRefreshToken(r.ID, func(r storage.RefreshToken) (storage.RefreshToken, error) { + r.Token = "timeout-err" + return r, nil + }) + require.Equal(t, fmt.Errorf("timeout waiting for refresh token %s lock", r.ID), err) + return r, nil + }) + require.NoError(t, err) + + token, err := kubeClient.GetRefresh(r.ID) + require.NoError(t, err) + require.Equal(t, "update-result-1", token.Token) + }) + + t.Run("Break the lock", func(t *testing.T) { + var lockBroken bool + lockTimeout = -time.Hour + + err = kubeClient.UpdateRefreshToken(r.ID, func(r storage.RefreshToken) (storage.RefreshToken, error) { + r.Token = "update-result-2" + if lockBroken { + return r, nil + } + + err := kubeClient.UpdateRefreshToken(r.ID, func(r storage.RefreshToken) (storage.RefreshToken, error) { + r.Token = "should-break-the-lock-and-finish-updating" + return r, nil + }) + require.NoError(t, err) + + lockBroken = true + return r, nil + }) + require.NoError(t, err) + + token, err := kubeClient.GetRefresh(r.ID) + require.NoError(t, err) + // Because concurrent update breaks the lock, the final result will be the value of the first update + require.Equal(t, "update-result-2", token.Token) + }) +}