diff --git a/cmd/dex/config.go b/cmd/dex/config.go index 88dc98e7..f218879d 100644 --- a/cmd/dex/config.go +++ b/cmd/dex/config.go @@ -304,6 +304,9 @@ type Expiry struct { // DeviceRequests defines the duration of time for which the DeviceRequests will be valid. DeviceRequests string `json:"deviceRequests"` + + // RefreshTokens defines refresh tokens expiry policy + RefreshTokens RefreshToken `json:"refreshTokens"` } // Logger holds configuration required to customize logging for dex. @@ -314,3 +317,10 @@ type Logger struct { // Format specifies the format to be used for logging. Format string `json:"format"` } + +type RefreshToken struct { + DisableRotation bool `json:"disableRotation"` + ReuseInterval string `json:"reuseInterval"` + AbsoluteLifetime string `json:"absoluteLifetime"` + ValidIfNotUsedFor string `json:"validIfNotUsedFor"` +} diff --git a/cmd/dex/serve.go b/cmd/dex/serve.go index 1960d101..3f2df3ed 100644 --- a/cmd/dex/serve.go +++ b/cmd/dex/serve.go @@ -304,6 +304,18 @@ func runServe(options serveOptions) error { logger.Infof("config device requests valid for: %v", deviceRequests) serverConfig.DeviceRequestsValidFor = deviceRequests } + refreshTokenPolicy, err := server.NewRefreshTokenPolicy( + logger, + c.Expiry.RefreshTokens.DisableRotation, + c.Expiry.RefreshTokens.ValidIfNotUsedFor, + c.Expiry.RefreshTokens.AbsoluteLifetime, + c.Expiry.RefreshTokens.ReuseInterval, + ) + if err != nil { + return fmt.Errorf("invalid refresh token expiration policy config: %v", err) + } + + serverConfig.RefreshTokenPolicy = refreshTokenPolicy serv, err := server.NewServer(context.Background(), serverConfig) if err != nil { return fmt.Errorf("failed to initialize server: %v", err) diff --git a/examples/config-dev.yaml b/examples/config-dev.yaml index 1ca7aa66..b40ea582 100644 --- a/examples/config-dev.yaml +++ b/examples/config-dev.yaml @@ -73,10 +73,15 @@ telemetry: # tlsClientCA: examples/grpc-client/ca.crt # Uncomment this block to enable configuration for the expiration time durations. +# Is possible to specify units using only s, m and h suffixes. # expiry: # deviceRequests: "5m" # signingKeys: "6h" # idTokens: "24h" +# refreshTokens: +# reuseInterval: "3s" +# validIfNotUsedFor: "2160h" # 90 days +# absoluteLifetime: "3960h" # 165 days # Options for controlling the logger. # logger: diff --git a/server/handlers.go b/server/handlers.go index eb65f490..8e925a60 100644 --- a/server/handlers.go +++ b/server/handlers.go @@ -4,7 +4,6 @@ import ( "crypto/sha256" "encoding/base64" "encoding/json" - "errors" "fmt" "net/http" "net/url" @@ -919,206 +918,6 @@ func (s *Server) exchangeAuthCode(w http.ResponseWriter, authCode storage.AuthCo return s.toAccessTokenResponse(idToken, accessToken, refreshToken, expiry), nil } -// handle a refresh token request https://tools.ietf.org/html/rfc6749#section-6 -func (s *Server) handleRefreshToken(w http.ResponseWriter, r *http.Request, client storage.Client) { - code := r.PostFormValue("refresh_token") - scope := r.PostFormValue("scope") - if code == "" { - s.tokenErrHelper(w, errInvalidRequest, "No refresh token in request.", http.StatusBadRequest) - return - } - - token := new(internal.RefreshToken) - if err := internal.Unmarshal(code, token); err != nil { - // For backward compatibility, assume the refresh_token is a raw refresh token ID - // if it fails to decode. - // - // Because refresh_token values that aren't unmarshable were generated by servers - // that don't have a Token value, we'll still reject any attempts to claim a - // refresh_token twice. - token = &internal.RefreshToken{RefreshId: code, Token: ""} - } - - refresh, err := s.storage.GetRefresh(token.RefreshId) - if err != nil { - s.logger.Errorf("failed to get refresh token: %v", err) - if err == storage.ErrNotFound { - s.tokenErrHelper(w, errInvalidRequest, "Refresh token is invalid or has already been claimed by another client.", http.StatusBadRequest) - } else { - s.tokenErrHelper(w, errServerError, "", http.StatusInternalServerError) - } - return - } - if refresh.ClientID != client.ID { - s.logger.Errorf("client %s trying to claim token for client %s", client.ID, refresh.ClientID) - s.tokenErrHelper(w, errInvalidRequest, "Refresh token is invalid or has already been claimed by another client.", http.StatusBadRequest) - return - } - if refresh.Token != token.Token { - s.logger.Errorf("refresh token with id %s claimed twice", refresh.ID) - s.tokenErrHelper(w, errInvalidRequest, "Refresh token is invalid or has already been claimed by another client.", http.StatusBadRequest) - return - } - - // Per the OAuth2 spec, if the client has omitted the scopes, default to the original - // authorized scopes. - // - // https://tools.ietf.org/html/rfc6749#section-6 - scopes := refresh.Scopes - if scope != "" { - requestedScopes := strings.Fields(scope) - var unauthorizedScopes []string - - for _, s := range requestedScopes { - contains := func() bool { - for _, scope := range refresh.Scopes { - if s == scope { - return true - } - } - return false - }() - if !contains { - unauthorizedScopes = append(unauthorizedScopes, s) - } - } - - if len(unauthorizedScopes) > 0 { - msg := fmt.Sprintf("Requested scopes contain unauthorized scope(s): %q.", unauthorizedScopes) - s.tokenErrHelper(w, errInvalidRequest, msg, http.StatusBadRequest) - return - } - scopes = requestedScopes - } - - var connectorData []byte - - session, err := s.storage.GetOfflineSessions(refresh.Claims.UserID, refresh.ConnectorID) - switch { - case err != nil: - if err != storage.ErrNotFound { - s.logger.Errorf("failed to get offline session: %v", err) - return - } - case len(refresh.ConnectorData) > 0: - // Use the old connector data if it exists, should be deleted once used - connectorData = refresh.ConnectorData - default: - connectorData = session.ConnectorData - } - - conn, err := s.getConnector(refresh.ConnectorID) - if err != nil { - s.logger.Errorf("connector with ID %q not found: %v", refresh.ConnectorID, err) - s.tokenErrHelper(w, errServerError, "", http.StatusInternalServerError) - return - } - ident := connector.Identity{ - UserID: refresh.Claims.UserID, - Username: refresh.Claims.Username, - PreferredUsername: refresh.Claims.PreferredUsername, - Email: refresh.Claims.Email, - EmailVerified: refresh.Claims.EmailVerified, - Groups: refresh.Claims.Groups, - ConnectorData: connectorData, - } - - // Can the connector refresh the identity? If so, attempt to refresh the data - // in the connector. - // - // TODO(ericchiang): We may want a strict mode where connectors that don't implement - // this interface can't perform refreshing. - if refreshConn, ok := conn.Connector.(connector.RefreshConnector); ok { - newIdent, err := refreshConn.Refresh(r.Context(), parseScopes(scopes), ident) - if err != nil { - s.logger.Errorf("failed to refresh identity: %v", err) - s.tokenErrHelper(w, errServerError, "", http.StatusInternalServerError) - return - } - ident = newIdent - } - - claims := storage.Claims{ - UserID: ident.UserID, - Username: ident.Username, - PreferredUsername: ident.PreferredUsername, - Email: ident.Email, - EmailVerified: ident.EmailVerified, - Groups: ident.Groups, - } - - accessToken, err := s.newAccessToken(client.ID, claims, scopes, refresh.Nonce, refresh.ConnectorID) - if err != nil { - s.logger.Errorf("failed to create new access token: %v", err) - s.tokenErrHelper(w, errServerError, "", http.StatusInternalServerError) - return - } - - idToken, expiry, err := s.newIDToken(client.ID, claims, scopes, refresh.Nonce, accessToken, "", refresh.ConnectorID) - if err != nil { - s.logger.Errorf("failed to create ID token: %v", err) - s.tokenErrHelper(w, errServerError, "", http.StatusInternalServerError) - return - } - - newToken := &internal.RefreshToken{ - RefreshId: refresh.ID, - Token: storage.NewID(), - } - rawNewToken, err := internal.Marshal(newToken) - if err != nil { - s.logger.Errorf("failed to marshal refresh token: %v", err) - s.tokenErrHelper(w, errServerError, "", http.StatusInternalServerError) - return - } - - lastUsed := s.now() - updater := func(old storage.RefreshToken) (storage.RefreshToken, error) { - if old.Token != refresh.Token { - return old, errors.New("refresh token claimed twice") - } - old.Token = newToken.Token - // Update the claims of the refresh token. - // - // UserID intentionally ignored for now. - old.Claims.Username = ident.Username - old.Claims.PreferredUsername = ident.PreferredUsername - old.Claims.Email = ident.Email - old.Claims.EmailVerified = ident.EmailVerified - old.Claims.Groups = ident.Groups - old.LastUsed = lastUsed - - // ConnectorData has been moved to OfflineSession - old.ConnectorData = []byte{} - return old, nil - } - - // Update LastUsed time stamp in refresh token reference object - // in offline session for the user. - if err := s.storage.UpdateOfflineSessions(refresh.Claims.UserID, refresh.ConnectorID, func(old storage.OfflineSessions) (storage.OfflineSessions, error) { - if old.Refresh[refresh.ClientID].ID != refresh.ID { - return old, errors.New("refresh token invalid") - } - old.Refresh[refresh.ClientID].LastUsed = lastUsed - old.ConnectorData = ident.ConnectorData - return old, nil - }); err != nil { - s.logger.Errorf("failed to update offline session: %v", err) - s.tokenErrHelper(w, errServerError, "", http.StatusInternalServerError) - return - } - - // Update refresh token in the storage. - if err := s.storage.UpdateRefreshToken(refresh.ID, updater); err != nil { - s.logger.Errorf("failed to update refresh token: %v", err) - s.tokenErrHelper(w, errServerError, "", http.StatusInternalServerError) - return - } - - resp := s.toAccessTokenResponse(idToken, accessToken, rawNewToken, expiry) - s.writeAccessToken(w, resp) -} - func (s *Server) handleUserInfo(w http.ResponseWriter, r *http.Request) { const prefix = "Bearer " diff --git a/server/refreshhandlers.go b/server/refreshhandlers.go new file mode 100644 index 00000000..8ea7ea9e --- /dev/null +++ b/server/refreshhandlers.go @@ -0,0 +1,339 @@ +package server + +import ( + "context" + "errors" + "fmt" + "net/http" + "strings" + "time" + + "github.com/dexidp/dex/connector" + "github.com/dexidp/dex/server/internal" + "github.com/dexidp/dex/storage" +) + +func contains(arr []string, item string) bool { + for _, itemFromArray := range arr { + if itemFromArray == item { + return true + } + } + return false +} + +type refreshError struct { + msg string + code int + desc string +} + +func newInternalServerError() *refreshError { + return &refreshError{msg: errInvalidRequest, desc: "", code: http.StatusInternalServerError} +} + +func newBadRequestError(desc string) *refreshError { + return &refreshError{msg: errInvalidRequest, desc: desc, code: http.StatusBadRequest} +} + +func (s *Server) refreshTokenErrHelper(w http.ResponseWriter, err *refreshError) { + s.tokenErrHelper(w, err.msg, err.desc, err.code) +} + +func (s *Server) extractRefreshTokenFromRequest(r *http.Request) (*internal.RefreshToken, *refreshError) { + code := r.PostFormValue("refresh_token") + if code == "" { + return nil, newBadRequestError("No refresh token is found in request.") + } + + token := new(internal.RefreshToken) + if err := internal.Unmarshal(code, token); err != nil { + // For backward compatibility, assume the refresh_token is a raw refresh token ID + // if it fails to decode. + // + // Because refresh_token values that aren't unmarshable were generated by servers + // that don't have a Token value, we'll still reject any attempts to claim a + // refresh_token twice. + token = &internal.RefreshToken{RefreshId: code, Token: ""} + } + + return token, nil +} + +// getRefreshTokenFromStorage checks that refresh token is valid and exists in the storage and gets its info +func (s *Server) getRefreshTokenFromStorage(clientID string, token *internal.RefreshToken) (*storage.RefreshToken, *refreshError) { + invalidErr := newBadRequestError("Refresh token is invalid or has already been claimed by another client.") + + refresh, err := s.storage.GetRefresh(token.RefreshId) + if err != nil { + s.logger.Errorf("failed to get refresh token: %v", err) + if err != storage.ErrNotFound { + return nil, newInternalServerError() + } + + return nil, invalidErr + } + + if refresh.ClientID != clientID { + s.logger.Errorf("client %s trying to claim token for client %s", clientID, refresh.ClientID) + return nil, invalidErr + } + + if refresh.Token != token.Token { + switch { + case !s.refreshTokenPolicy.AllowedToReuse(refresh.LastUsed): + fallthrough + case refresh.ObsoleteToken != token.Token: + fallthrough + case refresh.ObsoleteToken == "": + s.logger.Errorf("refresh token with id %s claimed twice", refresh.ID) + return nil, invalidErr + } + } + + expiredErr := newBadRequestError("Refresh token expired.") + if s.refreshTokenPolicy.CompletelyExpired(refresh.CreatedAt) { + s.logger.Errorf("refresh token with id %s expired", refresh.ID) + return nil, expiredErr + } + + if s.refreshTokenPolicy.ExpiredBecauseUnused(refresh.LastUsed) { + s.logger.Errorf("refresh token with id %s expired due to inactivity", refresh.ID) + return nil, expiredErr + } + + return &refresh, nil +} + +func (s *Server) getRefreshScopes(r *http.Request, refresh *storage.RefreshToken) ([]string, *refreshError) { + // Per the OAuth2 spec, if the client has omitted the scopes, default to the original + // authorized scopes. + // + // https://tools.ietf.org/html/rfc6749#section-6 + scope := r.PostFormValue("scope") + + if scope == "" { + return refresh.Scopes, nil + } + + requestedScopes := strings.Fields(scope) + var unauthorizedScopes []string + + // Per the OAuth2 spec, if the client has omitted the scopes, default to the original + // authorized scopes. + // + // https://tools.ietf.org/html/rfc6749#section-6 + for _, requestScope := range requestedScopes { + if !contains(refresh.Scopes, requestScope) { + unauthorizedScopes = append(unauthorizedScopes, requestScope) + } + } + + if len(unauthorizedScopes) > 0 { + desc := fmt.Sprintf("Requested scopes contain unauthorized scope(s): %q.", unauthorizedScopes) + return nil, newBadRequestError(desc) + } + + return requestedScopes, nil +} + +func (s *Server) refreshWithConnector(ctx context.Context, token *internal.RefreshToken, refresh *storage.RefreshToken, scopes []string) (connector.Identity, *refreshError) { + var connectorData []byte + + session, err := s.storage.GetOfflineSessions(refresh.Claims.UserID, refresh.ConnectorID) + switch { + case err != nil: + if err != storage.ErrNotFound { + s.logger.Errorf("failed to get offline session: %v", err) + return connector.Identity{}, newInternalServerError() + } + case len(refresh.ConnectorData) > 0: + // Use the old connector data if it exists, should be deleted once used + connectorData = refresh.ConnectorData + default: + connectorData = session.ConnectorData + } + + conn, err := s.getConnector(refresh.ConnectorID) + if err != nil { + s.logger.Errorf("connector with ID %q not found: %v", refresh.ConnectorID, err) + return connector.Identity{}, newInternalServerError() + } + + ident := connector.Identity{ + UserID: refresh.Claims.UserID, + Username: refresh.Claims.Username, + PreferredUsername: refresh.Claims.PreferredUsername, + Email: refresh.Claims.Email, + EmailVerified: refresh.Claims.EmailVerified, + Groups: refresh.Claims.Groups, + ConnectorData: connectorData, + } + + // user's token was previously updated by a connector and is allowed to reuse + // it is excessive to refresh identity in upstream + if s.refreshTokenPolicy.AllowedToReuse(refresh.LastUsed) && token.Token == refresh.ObsoleteToken { + return ident, nil + } + + // Can the connector refresh the identity? If so, attempt to refresh the data + // in the connector. + // + // TODO(ericchiang): We may want a strict mode where connectors that don't implement + // this interface can't perform refreshing. + if refreshConn, ok := conn.Connector.(connector.RefreshConnector); ok { + newIdent, err := refreshConn.Refresh(ctx, parseScopes(scopes), ident) + if err != nil { + s.logger.Errorf("failed to refresh identity: %v", err) + return connector.Identity{}, newInternalServerError() + } + ident = newIdent + } + + return ident, nil +} + +// updateOfflineSession updates offline session in the storage +func (s *Server) updateOfflineSession(refresh *storage.RefreshToken, ident connector.Identity, lastUsed time.Time) *refreshError { + offlineSessionUpdater := func(old storage.OfflineSessions) (storage.OfflineSessions, error) { + if old.Refresh[refresh.ClientID].ID != refresh.ID { + return old, errors.New("refresh token invalid") + } + old.Refresh[refresh.ClientID].LastUsed = lastUsed + old.ConnectorData = ident.ConnectorData + return old, nil + } + + // Update LastUsed time stamp in refresh token reference object + // in offline session for the user. + err := s.storage.UpdateOfflineSessions(refresh.Claims.UserID, refresh.ConnectorID, offlineSessionUpdater) + if err != nil { + s.logger.Errorf("failed to update offline session: %v", err) + return newInternalServerError() + } + + return nil +} + +// updateRefreshToken updates refresh token and offline session in the storage +func (s *Server) updateRefreshToken(token *internal.RefreshToken, refresh *storage.RefreshToken, ident connector.Identity) (*internal.RefreshToken, *refreshError) { + newToken := token + if s.refreshTokenPolicy.RotationEnabled() { + newToken = &internal.RefreshToken{ + RefreshId: refresh.ID, + Token: storage.NewID(), + } + } + + lastUsed := s.now() + + rerr := s.updateOfflineSession(refresh, ident, lastUsed) + if rerr != nil { + return nil, rerr + } + + refreshTokenUpdater := func(old storage.RefreshToken) (storage.RefreshToken, error) { + if s.refreshTokenPolicy.RotationEnabled() { + if old.Token != token.Token { + if s.refreshTokenPolicy.AllowedToReuse(old.LastUsed) && old.ObsoleteToken == token.Token { + newToken.Token = old.Token + return old, nil + } + return old, errors.New("refresh token claimed twice") + } + + old.ObsoleteToken = old.Token + } + + old.Token = newToken.Token + // Update the claims of the refresh token. + // + // UserID intentionally ignored for now. + old.Claims.Username = ident.Username + old.Claims.PreferredUsername = ident.PreferredUsername + old.Claims.Email = ident.Email + old.Claims.EmailVerified = ident.EmailVerified + old.Claims.Groups = ident.Groups + old.LastUsed = lastUsed + + // ConnectorData has been moved to OfflineSession + old.ConnectorData = []byte{} + return old, nil + } + + // Update refresh token in the storage. + err := s.storage.UpdateRefreshToken(refresh.ID, refreshTokenUpdater) + if err != nil { + s.logger.Errorf("failed to update refresh token: %v", err) + return nil, newInternalServerError() + } + + return newToken, nil +} + +// handleRefreshToken handles a refresh token request https://tools.ietf.org/html/rfc6749#section-6 +// this method is the entrypoint for refresh tokens handling +func (s *Server) handleRefreshToken(w http.ResponseWriter, r *http.Request, client storage.Client) { + token, rerr := s.extractRefreshTokenFromRequest(r) + if rerr != nil { + s.refreshTokenErrHelper(w, rerr) + return + } + + refresh, rerr := s.getRefreshTokenFromStorage(client.ID, token) + if rerr != nil { + s.refreshTokenErrHelper(w, rerr) + return + } + + scopes, rerr := s.getRefreshScopes(r, refresh) + if rerr != nil { + s.refreshTokenErrHelper(w, rerr) + return + } + + ident, rerr := s.refreshWithConnector(r.Context(), token, refresh, scopes) + if rerr != nil { + s.refreshTokenErrHelper(w, rerr) + return + } + + claims := storage.Claims{ + UserID: ident.UserID, + Username: ident.Username, + PreferredUsername: ident.PreferredUsername, + Email: ident.Email, + EmailVerified: ident.EmailVerified, + Groups: ident.Groups, + } + + accessToken, err := s.newAccessToken(client.ID, claims, scopes, refresh.Nonce, refresh.ConnectorID) + if err != nil { + s.logger.Errorf("failed to create new access token: %v", err) + s.refreshTokenErrHelper(w, newInternalServerError()) + return + } + + idToken, expiry, err := s.newIDToken(client.ID, claims, scopes, refresh.Nonce, accessToken, "", refresh.ConnectorID) + if err != nil { + s.logger.Errorf("failed to create ID token: %v", err) + s.refreshTokenErrHelper(w, newInternalServerError()) + return + } + + newToken, rerr := s.updateRefreshToken(token, refresh, ident) + if rerr != nil { + s.refreshTokenErrHelper(w, rerr) + return + } + + rawNewToken, err := internal.Marshal(newToken) + if err != nil { + s.logger.Errorf("failed to marshal refresh token: %v", err) + s.refreshTokenErrHelper(w, newInternalServerError()) + return + } + + resp := s.toAccessTokenResponse(idToken, accessToken, rawNewToken, expiry) + s.writeAccessToken(w, resp) +} diff --git a/server/refreshhandlers_test.go b/server/refreshhandlers_test.go new file mode 100644 index 00000000..c64c50b3 --- /dev/null +++ b/server/refreshhandlers_test.go @@ -0,0 +1,212 @@ +package server + +import ( + "bytes" + "context" + "encoding/json" + "net/http" + "net/http/httptest" + "net/url" + "path" + "testing" + "time" + + "github.com/stretchr/testify/require" + + "github.com/dexidp/dex/server/internal" + "github.com/dexidp/dex/storage" +) + +func mockRefreshTokenTestStorage(t *testing.T, s storage.Storage, useObsolete bool) { + c := storage.Client{ + ID: "test", + Secret: "barfoo", + RedirectURIs: []string{"foo://bar.com/", "https://auth.example.com"}, + Name: "dex client", + LogoURL: "https://goo.gl/JIyzIC", + } + + err := s.CreateClient(c) + require.NoError(t, err) + + c1 := storage.Connector{ + ID: "test", + Type: "mockCallback", + Name: "mockCallback", + Config: nil, + } + + err = s.CreateConnector(c1) + require.NoError(t, err) + + refresh := storage.RefreshToken{ + ID: "test", + Token: "bar", + ObsoleteToken: "", + Nonce: "foo", + ClientID: "test", + ConnectorID: "test", + Scopes: []string{"openid", "email", "profile"}, + CreatedAt: time.Now().UTC().Round(time.Millisecond), + LastUsed: time.Now().UTC().Round(time.Millisecond), + Claims: storage.Claims{ + UserID: "1", + Username: "jane", + Email: "jane.doe@example.com", + EmailVerified: true, + Groups: []string{"a", "b"}, + }, + ConnectorData: []byte(`{"some":"data"}`), + } + + if useObsolete { + refresh.Token = "testtest" + refresh.ObsoleteToken = "bar" + } + + err = s.CreateRefresh(refresh) + require.NoError(t, err) + + offlineSessions := storage.OfflineSessions{ + UserID: "1", + ConnID: "test", + Refresh: map[string]*storage.RefreshTokenRef{"test": {ID: "test", ClientID: "test"}}, + ConnectorData: nil, + } + + err = s.CreateOfflineSessions(offlineSessions) + require.NoError(t, err) +} + +func TestRefreshTokenExpirationScenarios(t *testing.T) { + t0 := time.Now() + tests := []struct { + name string + policy *RefreshTokenPolicy + useObsolete bool + error string + }{ + { + name: "Normal", + policy: &RefreshTokenPolicy{rotateRefreshTokens: true}, + error: ``, + }, + { + name: "Not expired because used", + policy: &RefreshTokenPolicy{ + rotateRefreshTokens: false, + validIfNotUsedFor: time.Second * 60, + now: func() time.Time { return t0.Add(time.Second * 25) }, + }, + error: ``, + }, + { + name: "Expired because not used", + policy: &RefreshTokenPolicy{ + rotateRefreshTokens: false, + validIfNotUsedFor: time.Second * 60, + now: func() time.Time { return t0.Add(time.Hour) }, + }, + error: `{"error":"invalid_request","error_description":"Refresh token expired."}`, + }, + { + name: "Absolutely expired", + policy: &RefreshTokenPolicy{ + rotateRefreshTokens: true, + absoluteLifetime: time.Second * 60, + now: func() time.Time { return t0.Add(time.Hour) }, + }, + error: `{"error":"invalid_request","error_description":"Refresh token expired."}`, + }, + { + name: "Obsolete tokens are allowed", + useObsolete: true, + policy: &RefreshTokenPolicy{ + rotateRefreshTokens: true, + reuseInterval: time.Second * 30, + now: func() time.Time { return t0.Add(time.Second * 25) }, + }, + error: ``, + }, + { + name: "Obsolete tokens are not allowed", + useObsolete: true, + policy: &RefreshTokenPolicy{ + rotateRefreshTokens: true, + now: func() time.Time { return t0.Add(time.Second * 25) }, + }, + error: `{"error":"invalid_request","error_description":"Refresh token is invalid or has already been claimed by another client."}`, + }, + { + name: "Obsolete tokens are allowed but token is expired globally", + useObsolete: true, + policy: &RefreshTokenPolicy{ + rotateRefreshTokens: true, + reuseInterval: time.Second * 30, + absoluteLifetime: time.Second * 20, + now: func() time.Time { return t0.Add(time.Second * 25) }, + }, + error: `{"error":"invalid_request","error_description":"Refresh token expired."}`, + }, + } + + for _, tc := range tests { + t.Run(tc.name, func(*testing.T) { + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + // Setup a dex server. + httpServer, s := newTestServer(ctx, t, func(c *Config) { + c.RefreshTokenPolicy = tc.policy + c.Now = func() time.Time { return t0 } + }) + defer httpServer.Close() + + mockRefreshTokenTestStorage(t, s.storage, tc.useObsolete) + + u, err := url.Parse(s.issuerURL.String()) + require.NoError(t, err) + + tokenData, err := internal.Marshal(&internal.RefreshToken{RefreshId: "test", Token: "bar"}) + require.NoError(t, err) + + u.Path = path.Join(u.Path, "/token") + v := url.Values{} + v.Add("grant_type", "refresh_token") + v.Add("refresh_token", tokenData) + + req, _ := http.NewRequest("POST", u.String(), bytes.NewBufferString(v.Encode())) + req.Header.Set("Content-Type", "application/x-www-form-urlencoded; param=value") + req.SetBasicAuth("test", "barfoo") + + rr := httptest.NewRecorder() + s.ServeHTTP(rr, req) + + if tc.error == "" { + require.Equal(t, 200, rr.Code) + } else { + require.Equal(t, rr.Body.String(), tc.error) + return + } + + // Check that we received expected refresh token + var ref struct { + Token string `json:"refresh_token"` + } + err = json.Unmarshal(rr.Body.Bytes(), &ref) + require.NoError(t, err) + + if tc.policy.rotateRefreshTokens == false { + require.Equal(t, tokenData, ref.Token) + } else { + require.NotEqual(t, tokenData, ref.Token) + } + + if tc.useObsolete { + updatedTokenData, err := internal.Marshal(&internal.RefreshToken{RefreshId: "test", Token: "testtest"}) + require.NoError(t, err) + require.Equal(t, updatedTokenData, ref.Token) + } + }) + } +} diff --git a/server/rotation.go b/server/rotation.go index b7dd8116..98489767 100644 --- a/server/rotation.go +++ b/server/rotation.go @@ -177,3 +177,73 @@ func (k keyRotator) rotate() error { k.logger.Infof("keys rotated, next rotation: %s", nextRotation) return nil } + +type RefreshTokenPolicy struct { + rotateRefreshTokens bool // enable rotation + + absoluteLifetime time.Duration // interval from token creation to the end of its life + validIfNotUsedFor time.Duration // interval from last token update to the end of its life + reuseInterval time.Duration // interval within which old refresh token is allowed to be reused + + now func() time.Time + + logger log.Logger +} + +func NewRefreshTokenPolicy(logger log.Logger, rotation bool, validIfNotUsedFor, absoluteLifetime, reuseInterval string) (*RefreshTokenPolicy, error) { + r := RefreshTokenPolicy{now: time.Now, logger: logger} + var err error + + if validIfNotUsedFor != "" { + r.validIfNotUsedFor, err = time.ParseDuration(validIfNotUsedFor) + if err != nil { + return nil, fmt.Errorf("invalid config value %q for refresh token valid if not used for: %v", validIfNotUsedFor, err) + } + logger.Infof("config refresh tokens valid if not used for: %v", validIfNotUsedFor) + } + + if absoluteLifetime != "" { + r.absoluteLifetime, err = time.ParseDuration(absoluteLifetime) + if err != nil { + return nil, fmt.Errorf("invalid config value %q for refresh tokens absolute lifetime: %v", absoluteLifetime, err) + } + logger.Infof("config refresh tokens absolute lifetime: %v", absoluteLifetime) + } + + if reuseInterval != "" { + r.reuseInterval, err = time.ParseDuration(reuseInterval) + if err != nil { + return nil, fmt.Errorf("invalid config value %q for refresh tokens reuse interval: %v", reuseInterval, err) + } + logger.Infof("config refresh tokens reuse interval: %v", reuseInterval) + } + + r.rotateRefreshTokens = !rotation + logger.Infof("config refresh tokens rotation enabled: %v", r.rotateRefreshTokens) + return &r, nil +} + +func (r *RefreshTokenPolicy) RotationEnabled() bool { + return r.rotateRefreshTokens +} + +func (r *RefreshTokenPolicy) CompletelyExpired(lastUsed time.Time) bool { + if r.absoluteLifetime == 0 { + return false // expiration disabled + } + return r.now().After(lastUsed.Add(r.absoluteLifetime)) +} + +func (r *RefreshTokenPolicy) ExpiredBecauseUnused(lastUsed time.Time) bool { + if r.validIfNotUsedFor == 0 { + return false // expiration disabled + } + return r.now().After(lastUsed.Add(r.validIfNotUsedFor)) +} + +func (r *RefreshTokenPolicy) AllowedToReuse(lastUsed time.Time) bool { + if r.reuseInterval == 0 { + return false // expiration disabled + } + return !r.now().After(lastUsed.Add(r.reuseInterval)) +} diff --git a/server/rotation_test.go b/server/rotation_test.go index 6f9b2ecb..e279bf54 100644 --- a/server/rotation_test.go +++ b/server/rotation_test.go @@ -7,6 +7,7 @@ import ( "time" "github.com/sirupsen/logrus" + "github.com/stretchr/testify/require" "github.com/dexidp/dex/storage" "github.com/dexidp/dex/storage/memory" @@ -100,3 +101,29 @@ func TestKeyRotator(t *testing.T) { } } } + +func TestRefreshTokenPolicy(t *testing.T) { + lastTime := time.Now() + l := &logrus.Logger{ + Out: os.Stderr, + Formatter: &logrus.TextFormatter{DisableColors: true}, + Level: logrus.DebugLevel, + } + + r, err := NewRefreshTokenPolicy(l, true, "1m", "1m", "1m") + require.NoError(t, err) + + t.Run("Allowed", func(t *testing.T) { + r.now = func() time.Time { return lastTime } + require.Equal(t, true, r.AllowedToReuse(lastTime)) + require.Equal(t, false, r.ExpiredBecauseUnused(lastTime)) + require.Equal(t, false, r.CompletelyExpired(lastTime)) + }) + + t.Run("Expired", func(t *testing.T) { + r.now = func() time.Time { return lastTime.Add(2 * time.Minute) } + require.Equal(t, false, r.AllowedToReuse(lastTime)) + require.Equal(t, true, r.ExpiredBecauseUnused(lastTime)) + require.Equal(t, true, r.CompletelyExpired(lastTime)) + }) +} diff --git a/server/server.go b/server/server.go index 93ab9f16..73b7d9da 100644 --- a/server/server.go +++ b/server/server.go @@ -84,6 +84,10 @@ type Config struct { IDTokensValidFor time.Duration // Defaults to 24 hours AuthRequestsValidFor time.Duration // Defaults to 24 hours DeviceRequestsValidFor time.Duration // Defaults to 5 minutes + + // Refresh token expiration settings + RefreshTokenPolicy *RefreshTokenPolicy + // If set, the server will use this connector to handle password grants PasswordConnector string @@ -171,6 +175,8 @@ type Server struct { authRequestsValidFor time.Duration deviceRequestsValidFor time.Duration + refreshTokenPolicy *RefreshTokenPolicy + logger log.Logger } @@ -246,6 +252,7 @@ func newServer(ctx context.Context, c Config, rotationStrategy rotationStrategy) idTokensValidFor: value(c.IDTokensValidFor, 24*time.Hour), authRequestsValidFor: value(c.AuthRequestsValidFor, 24*time.Hour), deviceRequestsValidFor: value(c.DeviceRequestsValidFor, 5*time.Minute), + refreshTokenPolicy: c.RefreshTokenPolicy, skipApproval: c.SkipApprovalScreen, alwaysShowLogin: c.AlwaysShowLoginScreen, now: now, diff --git a/server/server_test.go b/server/server_test.go index 87ca6c17..9c739885 100644 --- a/server/server_test.go +++ b/server/server_test.go @@ -119,6 +119,16 @@ func newTestServer(ctx context.Context, t *testing.T, updateConfig func(c *Confi t.Fatal(err) } server.skipApproval = true // Don't prompt for approval, just immediately redirect with code. + + // Default rotation policy + if server.refreshTokenPolicy == nil { + server.refreshTokenPolicy, err = NewRefreshTokenPolicy(logger, false, "", "", "") + if err != nil { + t.Fatalf("failed to prepare rotation policy: %v", err) + } + server.refreshTokenPolicy.now = config.Now + } + return s, server } diff --git a/storage/conformance/conformance.go b/storage/conformance/conformance.go index 3f5e2aa1..dde369c4 100644 --- a/storage/conformance/conformance.go +++ b/storage/conformance/conformance.go @@ -324,14 +324,15 @@ func testClientCRUD(t *testing.T, s storage.Storage) { func testRefreshTokenCRUD(t *testing.T, s storage.Storage) { id := storage.NewID() refresh := storage.RefreshToken{ - ID: id, - Token: "bar", - Nonce: "foo", - ClientID: "client_id", - ConnectorID: "client_secret", - Scopes: []string{"openid", "email", "profile"}, - CreatedAt: time.Now().UTC().Round(time.Millisecond), - LastUsed: time.Now().UTC().Round(time.Millisecond), + ID: id, + Token: "bar", + ObsoleteToken: "", + Nonce: "foo", + ClientID: "client_id", + ConnectorID: "client_secret", + Scopes: []string{"openid", "email", "profile"}, + CreatedAt: time.Now().UTC().Round(time.Millisecond), + LastUsed: time.Now().UTC().Round(time.Millisecond), Claims: storage.Claims{ UserID: "1", Username: "jane", @@ -378,14 +379,15 @@ func testRefreshTokenCRUD(t *testing.T, s storage.Storage) { id2 := storage.NewID() refresh2 := storage.RefreshToken{ - ID: id2, - Token: "bar_2", - Nonce: "foo_2", - ClientID: "client_id_2", - ConnectorID: "client_secret", - Scopes: []string{"openid", "email", "profile"}, - CreatedAt: time.Now().UTC().Round(time.Millisecond), - LastUsed: time.Now().UTC().Round(time.Millisecond), + ID: id2, + Token: "bar_2", + ObsoleteToken: refresh.Token, + Nonce: "foo_2", + ClientID: "client_id_2", + ConnectorID: "client_secret", + Scopes: []string{"openid", "email", "profile"}, + CreatedAt: time.Now().UTC().Round(time.Millisecond), + LastUsed: time.Now().UTC().Round(time.Millisecond), Claims: storage.Claims{ UserID: "2", Username: "john", diff --git a/storage/etcd/types.go b/storage/etcd/types.go index f2ffd9f7..9390608a 100644 --- a/storage/etcd/types.go +++ b/storage/etcd/types.go @@ -132,7 +132,8 @@ func toStorageAuthRequest(a AuthRequest) storage.AuthRequest { type RefreshToken struct { ID string `json:"id"` - Token string `json:"token"` + Token string `json:"token"` + ObsoleteToken string `json:"obsolete_token"` CreatedAt time.Time `json:"created_at"` LastUsed time.Time `json:"last_used"` @@ -152,6 +153,7 @@ func toStorageRefreshToken(r RefreshToken) storage.RefreshToken { return storage.RefreshToken{ ID: r.ID, Token: r.Token, + ObsoleteToken: r.ObsoleteToken, CreatedAt: r.CreatedAt, LastUsed: r.LastUsed, ClientID: r.ClientID, @@ -167,6 +169,7 @@ func fromStorageRefreshToken(r storage.RefreshToken) RefreshToken { return RefreshToken{ ID: r.ID, Token: r.Token, + ObsoleteToken: r.ObsoleteToken, CreatedAt: r.CreatedAt, LastUsed: r.LastUsed, ClientID: r.ClientID, diff --git a/storage/kubernetes/types.go b/storage/kubernetes/types.go index 07e25084..bed52736 100644 --- a/storage/kubernetes/types.go +++ b/storage/kubernetes/types.go @@ -496,7 +496,8 @@ type RefreshToken struct { ClientID string `json:"clientID"` Scopes []string `json:"scopes,omitempty"` - Token string `json:"token,omitempty"` + Token string `json:"token,omitempty"` + ObsoleteToken string `json:"obsoleteToken,omitempty"` Nonce string `json:"nonce,omitempty"` @@ -516,6 +517,7 @@ func toStorageRefreshToken(r RefreshToken) storage.RefreshToken { return storage.RefreshToken{ ID: r.ObjectMeta.Name, Token: r.Token, + ObsoleteToken: r.ObsoleteToken, CreatedAt: r.CreatedAt, LastUsed: r.LastUsed, ClientID: r.ClientID, @@ -538,6 +540,7 @@ func (cli *client) fromStorageRefreshToken(r storage.RefreshToken) RefreshToken Namespace: cli.namespace, }, Token: r.Token, + ObsoleteToken: r.ObsoleteToken, CreatedAt: r.CreatedAt, LastUsed: r.LastUsed, ClientID: r.ClientID, diff --git a/storage/sql/crud.go b/storage/sql/crud.go index 4451e5c5..5a234f9d 100644 --- a/storage/sql/crud.go +++ b/storage/sql/crud.go @@ -285,16 +285,16 @@ func (c *conn) CreateRefresh(r storage.RefreshToken) error { claims_user_id, claims_username, claims_preferred_username, claims_email, claims_email_verified, claims_groups, connector_id, connector_data, - token, created_at, last_used + token, obsolete_token, created_at, last_used ) - values ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12, $13, $14, $15); + values ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12, $13, $14, $15, $16); `, r.ID, r.ClientID, encoder(r.Scopes), r.Nonce, r.Claims.UserID, r.Claims.Username, r.Claims.PreferredUsername, r.Claims.Email, r.Claims.EmailVerified, encoder(r.Claims.Groups), r.ConnectorID, r.ConnectorData, - r.Token, r.CreatedAt, r.LastUsed, + r.Token, r.ObsoleteToken, r.CreatedAt, r.LastUsed, ) if err != nil { if c.alreadyExistsCheck(err) { @@ -329,17 +329,18 @@ func (c *conn) UpdateRefreshToken(id string, updater func(old storage.RefreshTok connector_id = $10, connector_data = $11, token = $12, - created_at = $13, - last_used = $14 + obsolete_token = $13, + created_at = $14, + last_used = $15 where - id = $15 + id = $16 `, r.ClientID, encoder(r.Scopes), r.Nonce, r.Claims.UserID, r.Claims.Username, r.Claims.PreferredUsername, r.Claims.Email, r.Claims.EmailVerified, encoder(r.Claims.Groups), r.ConnectorID, r.ConnectorData, - r.Token, r.CreatedAt, r.LastUsed, id, + r.Token, r.ObsoleteToken, r.CreatedAt, r.LastUsed, id, ) if err != nil { return fmt.Errorf("update refresh token: %v", err) @@ -360,7 +361,7 @@ func getRefresh(q querier, id string) (storage.RefreshToken, error) { claims_email, claims_email_verified, claims_groups, connector_id, connector_data, - token, created_at, last_used + token, obsolete_token, created_at, last_used from refresh_token where id = $1; `, id)) } @@ -372,7 +373,7 @@ func (c *conn) ListRefreshTokens() ([]storage.RefreshToken, error) { claims_user_id, claims_username, claims_preferred_username, claims_email, claims_email_verified, claims_groups, connector_id, connector_data, - token, created_at, last_used + token, obsolete_token, created_at, last_used from refresh_token; `) if err != nil { @@ -401,7 +402,7 @@ func scanRefresh(s scanner) (r storage.RefreshToken, err error) { &r.Claims.Email, &r.Claims.EmailVerified, decoder(&r.Claims.Groups), &r.ConnectorID, &r.ConnectorData, - &r.Token, &r.CreatedAt, &r.LastUsed, + &r.Token, &r.ObsoleteToken, &r.CreatedAt, &r.LastUsed, ) if err != nil { if err == sql.ErrNoRows { diff --git a/storage/sql/migrate.go b/storage/sql/migrate.go index 460658c2..498db252 100644 --- a/storage/sql/migrate.go +++ b/storage/sql/migrate.go @@ -274,4 +274,11 @@ var migrations = []migration{ add column code_challenge_method text not null default '';`, }, }, + { + stmts: []string{ + ` + alter table refresh_token + add column obsolete_token text default '';`, + }, + }, } diff --git a/storage/storage.go b/storage/storage.go index c308ac46..855eb09f 100644 --- a/storage/storage.go +++ b/storage/storage.go @@ -271,7 +271,8 @@ type RefreshToken struct { // A single token that's rotated every time the refresh token is refreshed. // // May be empty. - Token string + Token string + ObsoleteToken string CreatedAt time.Time LastUsed time.Time