From 91de99d57e57f90285d5259651e36e99dea94021 Mon Sep 17 00:00:00 2001 From: "m.nabokikh" Date: Wed, 28 Oct 2020 10:26:34 +0400 Subject: [PATCH] feat: Add refresh token expiration and rotation settings Signed-off-by: m.nabokikh --- cmd/dex/config.go | 10 +++++ cmd/dex/serve.go | 12 +++++ examples/config-dev.yaml | 4 ++ server/handlers.go | 52 ++++++++++++++++------ server/rotation.go | 70 ++++++++++++++++++++++++++++++ server/rotation_test.go | 28 ++++++++++++ server/server.go | 7 +++ server/server_test.go | 14 ++++++ storage/conformance/conformance.go | 34 ++++++++------- storage/etcd/types.go | 5 ++- storage/kubernetes/types.go | 5 ++- storage/sql/crud.go | 21 ++++----- storage/sql/migrate.go | 3 ++ storage/storage.go | 3 +- 14 files changed, 226 insertions(+), 42 deletions(-) diff --git a/cmd/dex/config.go b/cmd/dex/config.go index 88dc98e7..6683c39e 100644 --- a/cmd/dex/config.go +++ b/cmd/dex/config.go @@ -304,6 +304,9 @@ type Expiry struct { // DeviceRequests defines the duration of time for which the DeviceRequests will be valid. DeviceRequests string `json:"deviceRequests"` + + // RefreshToken defines refresh tokens expiry policy + RefreshToken RefreshTokenExpiry `json:"refreshTokens"` } // Logger holds configuration required to customize logging for dex. @@ -314,3 +317,10 @@ type Logger struct { // Format specifies the format to be used for logging. Format string `json:"format"` } + +type RefreshTokenExpiry struct { + DisableRotation bool `json:"disableRotation"` + ReuseInterval string `json:"reuseInterval"` + AbsoluteLifetime string `json:"absoluteLifetime"` + ValidIfNotUsedFor string `json:"validIfNotUsedFor"` +} diff --git a/cmd/dex/serve.go b/cmd/dex/serve.go index 7afb8851..aa8f5ad6 100644 --- a/cmd/dex/serve.go +++ b/cmd/dex/serve.go @@ -317,6 +317,18 @@ func runServe(options serveOptions) error { logger.Infof("config device requests valid for: %v", deviceRequests) serverConfig.DeviceRequestsValidFor = deviceRequests } + refreshTokenPolicy, err := server.NewRefreshTokenPolicyFromConfig( + logger, + c.Expiry.RefreshToken.DisableRotation, + c.Expiry.RefreshToken.ValidIfNotUsedFor, + c.Expiry.RefreshToken.AbsoluteLifetime, + c.Expiry.RefreshToken.ReuseInterval, + ) + if err != nil { + return fmt.Errorf("invalid refresh token expiration policy config: %v", err) + } + + serverConfig.RefreshTokenPolicy = refreshTokenPolicy serv, err := server.NewServer(context.Background(), serverConfig) if err != nil { return fmt.Errorf("failed to initialize server: %v", err) diff --git a/examples/config-dev.yaml b/examples/config-dev.yaml index 1ca7aa66..344d72dc 100644 --- a/examples/config-dev.yaml +++ b/examples/config-dev.yaml @@ -77,6 +77,10 @@ telemetry: # deviceRequests: "5m" # signingKeys: "6h" # idTokens: "24h" +# refreshTokens: +# reuseInterval: "3s" +# validIfNotUsedFor: "2190h" +# absoluteLifetime: "5000h" # Options for controlling the logger. # logger: diff --git a/server/handlers.go b/server/handlers.go index 348700df..0f76c410 100644 --- a/server/handlers.go +++ b/server/handlers.go @@ -1035,14 +1035,27 @@ func (s *Server) handleRefreshToken(w http.ResponseWriter, r *http.Request, clie } return } + if refresh.ClientID != client.ID { s.logger.Errorf("client %s trying to claim token for client %s", client.ID, refresh.ClientID) s.tokenErrHelper(w, errInvalidRequest, "Refresh token is invalid or has already been claimed by another client.", http.StatusBadRequest) return } if refresh.Token != token.Token { - s.logger.Errorf("refresh token with id %s claimed twice", refresh.ID) - s.tokenErrHelper(w, errInvalidRequest, "Refresh token is invalid or has already been claimed by another client.", http.StatusBadRequest) + if !s.refreshTokenPolicy.AllowedToReuse(refresh.LastUsed) || refresh.ObsoleteToken != token.Token { + s.logger.Errorf("refresh token with id %s claimed twice", refresh.ID) + s.tokenErrHelper(w, errInvalidRequest, "Refresh token is invalid or has already been claimed by another client.", http.StatusBadRequest) + return + } + } + if s.refreshTokenPolicy.CompletelyExpired(refresh.CreatedAt) { + s.logger.Errorf("refresh token with id %s expired", refresh.ID) + s.tokenErrHelper(w, errInvalidRequest, "Refresh token expired.", http.StatusBadRequest) + return + } + if s.refreshTokenPolicy.ExpiredBecauseUnused(refresh.LastUsed) { + s.logger.Errorf("refresh token with id %s expired because being unused", refresh.ID) + s.tokenErrHelper(w, errInvalidRequest, "Refresh token expired.", http.StatusBadRequest) return } @@ -1147,22 +1160,28 @@ func (s *Server) handleRefreshToken(w http.ResponseWriter, r *http.Request, clie return } - newToken := &internal.RefreshToken{ - RefreshId: refresh.ID, - Token: storage.NewID(), - } - rawNewToken, err := internal.Marshal(newToken) - if err != nil { - s.logger.Errorf("failed to marshal refresh token: %v", err) - s.tokenErrHelper(w, errServerError, "", http.StatusInternalServerError) - return + newToken := token + if s.refreshTokenPolicy.RotationEnabled() { + newToken = &internal.RefreshToken{ + RefreshId: refresh.ID, + Token: storage.NewID(), + } } lastUsed := s.now() updater := func(old storage.RefreshToken) (storage.RefreshToken, error) { - if old.Token != refresh.Token { - return old, errors.New("refresh token claimed twice") + if s.refreshTokenPolicy.RotationEnabled() { + if old.Token != refresh.Token { + if s.refreshTokenPolicy.AllowedToReuse(old.LastUsed) && old.ObsoleteToken == refresh.Token { + newToken.Token = old.Token + return old, nil + } + return old, errors.New("refresh token claimed twice") + } + + old.ObsoleteToken = old.Token } + old.Token = newToken.Token // Update the claims of the refresh token. // @@ -1201,6 +1220,13 @@ func (s *Server) handleRefreshToken(w http.ResponseWriter, r *http.Request, clie return } + rawNewToken, err := internal.Marshal(newToken) + if err != nil { + s.logger.Errorf("failed to marshal refresh token: %v", err) + s.tokenErrHelper(w, errServerError, "", http.StatusInternalServerError) + return + } + resp := s.toAccessTokenResponse(idToken, accessToken, rawNewToken, expiry) s.writeAccessToken(w, resp) } diff --git a/server/rotation.go b/server/rotation.go index b7dd8116..48593ede 100644 --- a/server/rotation.go +++ b/server/rotation.go @@ -177,3 +177,73 @@ func (k keyRotator) rotate() error { k.logger.Infof("keys rotated, next rotation: %s", nextRotation) return nil } + +type RefreshTokenPolicy struct { + rotateRefreshTokens bool // enable rotation + + absoluteLifetime time.Duration // interval from token creation to the end of its life + validIfNotUsedFor time.Duration // interval from last token update to the end of its life + reuseInterval time.Duration // interval within which old refresh token is allowed to be reused + + Clock func() time.Time + + logger log.Logger +} + +func NewRefreshTokenPolicyFromConfig(logger log.Logger, rotation bool, validIfNotUsedFor, absoluteLifetime, reuseInterval string) (*RefreshTokenPolicy, error) { + r := RefreshTokenPolicy{Clock: time.Now, logger: logger} + var err error + + if validIfNotUsedFor != "" { + r.validIfNotUsedFor, err = time.ParseDuration(validIfNotUsedFor) + if err != nil { + return nil, fmt.Errorf("invalid config value %q for refresh token valid if not used for: %v", validIfNotUsedFor, err) + } + logger.Infof("config refresh tokens valid if not used for: %v", validIfNotUsedFor) + } + + if absoluteLifetime != "" { + r.absoluteLifetime, err = time.ParseDuration(absoluteLifetime) + if err != nil { + return nil, fmt.Errorf("invalid config value %q for refresh tokens absolute lifetime: %v", absoluteLifetime, err) + } + logger.Infof("config refresh tokens absolute lifetime: %v", absoluteLifetime) + } + + if reuseInterval != "" { + r.reuseInterval, err = time.ParseDuration(reuseInterval) + if err != nil { + return nil, fmt.Errorf("invalid config value %q for refresh tokens reuse interval: %v", reuseInterval, err) + } + logger.Infof("config refresh tokens reuse interval: %v", reuseInterval) + } + + r.rotateRefreshTokens = !rotation + logger.Infof("config refresh tokens rotation enabled: %v", r.rotateRefreshTokens) + return &r, nil +} + +func (r *RefreshTokenPolicy) RotationEnabled() bool { + return r.rotateRefreshTokens +} + +func (r *RefreshTokenPolicy) CompletelyExpired(lastUsed time.Time) bool { + if r.absoluteLifetime == 0 { + return false // expiration disabled + } + return r.Clock().After(lastUsed.Add(r.absoluteLifetime)) +} + +func (r *RefreshTokenPolicy) ExpiredBecauseUnused(lastUsed time.Time) bool { + if r.validIfNotUsedFor == 0 { + return false // expiration disabled + } + return r.Clock().After(lastUsed.Add(r.validIfNotUsedFor)) +} + +func (r *RefreshTokenPolicy) AllowedToReuse(lastUsed time.Time) bool { + if r.reuseInterval == 0 { + return false // expiration disabled + } + return !r.Clock().After(lastUsed.Add(r.reuseInterval)) +} diff --git a/server/rotation_test.go b/server/rotation_test.go index 6f9b2ecb..a75a2435 100644 --- a/server/rotation_test.go +++ b/server/rotation_test.go @@ -7,6 +7,7 @@ import ( "time" "github.com/sirupsen/logrus" + "github.com/stretchr/testify/require" "github.com/dexidp/dex/storage" "github.com/dexidp/dex/storage/memory" @@ -100,3 +101,30 @@ func TestKeyRotator(t *testing.T) { } } } + +func TestRefreshTokenPolicy(t *testing.T) { + lastTime := time.Now() + l := &logrus.Logger{ + Out: os.Stderr, + Formatter: &logrus.TextFormatter{DisableColors: true}, + Level: logrus.DebugLevel, + } + + r, err := NewRefreshTokenPolicyFromConfig(l, true, "1m", "1m", "1m") + require.NoError(t, err) + + t.Run("Allowed", func(t *testing.T) { + r.Clock = func() time.Time { return lastTime } + require.Equal(t, true, r.AllowedToReuse(lastTime)) + require.Equal(t, false, r.ExpiredBecauseUnused(lastTime)) + require.Equal(t, false, r.CompletelyExpired(lastTime)) + }) + + t.Run("Expired", func(t *testing.T) { + r.Clock = func() time.Time { return lastTime.Add(2 * time.Minute) } + time.Sleep(1 * time.Second) + require.Equal(t, false, r.AllowedToReuse(lastTime)) + require.Equal(t, true, r.ExpiredBecauseUnused(lastTime)) + require.Equal(t, true, r.CompletelyExpired(lastTime)) + }) +} diff --git a/server/server.go b/server/server.go index 6fd4d8b7..4d6c69b4 100644 --- a/server/server.go +++ b/server/server.go @@ -80,6 +80,10 @@ type Config struct { IDTokensValidFor time.Duration // Defaults to 24 hours AuthRequestsValidFor time.Duration // Defaults to 24 hours DeviceRequestsValidFor time.Duration // Defaults to 5 minutes + + // Refresh token expiration settings + RefreshTokenPolicy *RefreshTokenPolicy + // If set, the server will use this connector to handle password grants PasswordConnector string @@ -159,6 +163,8 @@ type Server struct { authRequestsValidFor time.Duration deviceRequestsValidFor time.Duration + refreshTokenPolicy *RefreshTokenPolicy + logger log.Logger } @@ -227,6 +233,7 @@ func newServer(ctx context.Context, c Config, rotationStrategy rotationStrategy) idTokensValidFor: value(c.IDTokensValidFor, 24*time.Hour), authRequestsValidFor: value(c.AuthRequestsValidFor, 24*time.Hour), deviceRequestsValidFor: value(c.DeviceRequestsValidFor, 5*time.Minute), + refreshTokenPolicy: c.RefreshTokenPolicy, skipApproval: c.SkipApprovalScreen, alwaysShowLogin: c.AlwaysShowLoginScreen, now: now, diff --git a/server/server_test.go b/server/server_test.go index 3a918434..2ef4f613 100644 --- a/server/server_test.go +++ b/server/server_test.go @@ -677,6 +677,13 @@ func TestOAuth2CodeFlow(t *testing.T) { }) defer httpServer.Close() + policy, err := NewRefreshTokenPolicyFromConfig(s.logger, false, "", "", "") + if err != nil { + t.Fatalf("failed to prepare rotation policy: %v", err) + } + policy.Clock = now + s.refreshTokenPolicy = policy + mockConn := s.connectors["mock"] conn = mockConn.Connector.(*mock.Callback) @@ -1508,6 +1515,13 @@ func TestOAuth2DeviceFlow(t *testing.T) { }) defer httpServer.Close() + policy, err := NewRefreshTokenPolicyFromConfig(s.logger, false, "", "", "") + if err != nil { + t.Fatalf("failed to prepare rotation policy: %v", err) + } + policy.Clock = now + s.refreshTokenPolicy = policy + mockConn := s.connectors["mock"] conn = mockConn.Connector.(*mock.Callback) diff --git a/storage/conformance/conformance.go b/storage/conformance/conformance.go index 3f5e2aa1..0bae52cb 100644 --- a/storage/conformance/conformance.go +++ b/storage/conformance/conformance.go @@ -324,14 +324,15 @@ func testClientCRUD(t *testing.T, s storage.Storage) { func testRefreshTokenCRUD(t *testing.T, s storage.Storage) { id := storage.NewID() refresh := storage.RefreshToken{ - ID: id, - Token: "bar", - Nonce: "foo", - ClientID: "client_id", - ConnectorID: "client_secret", - Scopes: []string{"openid", "email", "profile"}, - CreatedAt: time.Now().UTC().Round(time.Millisecond), - LastUsed: time.Now().UTC().Round(time.Millisecond), + ID: id, + Token: "bar", + ObsoleteToken: "", + Nonce: "foo", + ClientID: "client_id", + ConnectorID: "client_secret", + Scopes: []string{"openid", "email", "profile"}, + CreatedAt: time.Now().UTC().Round(time.Millisecond), + LastUsed: time.Now().UTC().Round(time.Millisecond), Claims: storage.Claims{ UserID: "1", Username: "jane", @@ -378,14 +379,15 @@ func testRefreshTokenCRUD(t *testing.T, s storage.Storage) { id2 := storage.NewID() refresh2 := storage.RefreshToken{ - ID: id2, - Token: "bar_2", - Nonce: "foo_2", - ClientID: "client_id_2", - ConnectorID: "client_secret", - Scopes: []string{"openid", "email", "profile"}, - CreatedAt: time.Now().UTC().Round(time.Millisecond), - LastUsed: time.Now().UTC().Round(time.Millisecond), + ID: id2, + Token: "bar_2", + ObsoleteToken: "bar", + Nonce: "foo_2", + ClientID: "client_id_2", + ConnectorID: "client_secret", + Scopes: []string{"openid", "email", "profile"}, + CreatedAt: time.Now().UTC().Round(time.Millisecond), + LastUsed: time.Now().UTC().Round(time.Millisecond), Claims: storage.Claims{ UserID: "2", Username: "john", diff --git a/storage/etcd/types.go b/storage/etcd/types.go index f2ffd9f7..9390608a 100644 --- a/storage/etcd/types.go +++ b/storage/etcd/types.go @@ -132,7 +132,8 @@ func toStorageAuthRequest(a AuthRequest) storage.AuthRequest { type RefreshToken struct { ID string `json:"id"` - Token string `json:"token"` + Token string `json:"token"` + ObsoleteToken string `json:"obsolete_token"` CreatedAt time.Time `json:"created_at"` LastUsed time.Time `json:"last_used"` @@ -152,6 +153,7 @@ func toStorageRefreshToken(r RefreshToken) storage.RefreshToken { return storage.RefreshToken{ ID: r.ID, Token: r.Token, + ObsoleteToken: r.ObsoleteToken, CreatedAt: r.CreatedAt, LastUsed: r.LastUsed, ClientID: r.ClientID, @@ -167,6 +169,7 @@ func fromStorageRefreshToken(r storage.RefreshToken) RefreshToken { return RefreshToken{ ID: r.ID, Token: r.Token, + ObsoleteToken: r.ObsoleteToken, CreatedAt: r.CreatedAt, LastUsed: r.LastUsed, ClientID: r.ClientID, diff --git a/storage/kubernetes/types.go b/storage/kubernetes/types.go index 07e25084..bed52736 100644 --- a/storage/kubernetes/types.go +++ b/storage/kubernetes/types.go @@ -496,7 +496,8 @@ type RefreshToken struct { ClientID string `json:"clientID"` Scopes []string `json:"scopes,omitempty"` - Token string `json:"token,omitempty"` + Token string `json:"token,omitempty"` + ObsoleteToken string `json:"obsoleteToken,omitempty"` Nonce string `json:"nonce,omitempty"` @@ -516,6 +517,7 @@ func toStorageRefreshToken(r RefreshToken) storage.RefreshToken { return storage.RefreshToken{ ID: r.ObjectMeta.Name, Token: r.Token, + ObsoleteToken: r.ObsoleteToken, CreatedAt: r.CreatedAt, LastUsed: r.LastUsed, ClientID: r.ClientID, @@ -538,6 +540,7 @@ func (cli *client) fromStorageRefreshToken(r storage.RefreshToken) RefreshToken Namespace: cli.namespace, }, Token: r.Token, + ObsoleteToken: r.ObsoleteToken, CreatedAt: r.CreatedAt, LastUsed: r.LastUsed, ClientID: r.ClientID, diff --git a/storage/sql/crud.go b/storage/sql/crud.go index 4451e5c5..5a234f9d 100644 --- a/storage/sql/crud.go +++ b/storage/sql/crud.go @@ -285,16 +285,16 @@ func (c *conn) CreateRefresh(r storage.RefreshToken) error { claims_user_id, claims_username, claims_preferred_username, claims_email, claims_email_verified, claims_groups, connector_id, connector_data, - token, created_at, last_used + token, obsolete_token, created_at, last_used ) - values ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12, $13, $14, $15); + values ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12, $13, $14, $15, $16); `, r.ID, r.ClientID, encoder(r.Scopes), r.Nonce, r.Claims.UserID, r.Claims.Username, r.Claims.PreferredUsername, r.Claims.Email, r.Claims.EmailVerified, encoder(r.Claims.Groups), r.ConnectorID, r.ConnectorData, - r.Token, r.CreatedAt, r.LastUsed, + r.Token, r.ObsoleteToken, r.CreatedAt, r.LastUsed, ) if err != nil { if c.alreadyExistsCheck(err) { @@ -329,17 +329,18 @@ func (c *conn) UpdateRefreshToken(id string, updater func(old storage.RefreshTok connector_id = $10, connector_data = $11, token = $12, - created_at = $13, - last_used = $14 + obsolete_token = $13, + created_at = $14, + last_used = $15 where - id = $15 + id = $16 `, r.ClientID, encoder(r.Scopes), r.Nonce, r.Claims.UserID, r.Claims.Username, r.Claims.PreferredUsername, r.Claims.Email, r.Claims.EmailVerified, encoder(r.Claims.Groups), r.ConnectorID, r.ConnectorData, - r.Token, r.CreatedAt, r.LastUsed, id, + r.Token, r.ObsoleteToken, r.CreatedAt, r.LastUsed, id, ) if err != nil { return fmt.Errorf("update refresh token: %v", err) @@ -360,7 +361,7 @@ func getRefresh(q querier, id string) (storage.RefreshToken, error) { claims_email, claims_email_verified, claims_groups, connector_id, connector_data, - token, created_at, last_used + token, obsolete_token, created_at, last_used from refresh_token where id = $1; `, id)) } @@ -372,7 +373,7 @@ func (c *conn) ListRefreshTokens() ([]storage.RefreshToken, error) { claims_user_id, claims_username, claims_preferred_username, claims_email, claims_email_verified, claims_groups, connector_id, connector_data, - token, created_at, last_used + token, obsolete_token, created_at, last_used from refresh_token; `) if err != nil { @@ -401,7 +402,7 @@ func scanRefresh(s scanner) (r storage.RefreshToken, err error) { &r.Claims.Email, &r.Claims.EmailVerified, decoder(&r.Claims.Groups), &r.ConnectorID, &r.ConnectorData, - &r.Token, &r.CreatedAt, &r.LastUsed, + &r.Token, &r.ObsoleteToken, &r.CreatedAt, &r.LastUsed, ) if err != nil { if err == sql.ErrNoRows { diff --git a/storage/sql/migrate.go b/storage/sql/migrate.go index 460658c2..0f2666bf 100644 --- a/storage/sql/migrate.go +++ b/storage/sql/migrate.go @@ -176,6 +176,9 @@ var migrations = []migration{ alter table refresh_token add column token text not null default '';`, ` + alter table refresh_token + add column obsolete_token text default '';`, + ` alter table refresh_token add column created_at timestamptz not null default '0001-01-01 00:00:00 UTC';`, ` diff --git a/storage/storage.go b/storage/storage.go index c308ac46..855eb09f 100644 --- a/storage/storage.go +++ b/storage/storage.go @@ -271,7 +271,8 @@ type RefreshToken struct { // A single token that's rotated every time the refresh token is refreshed. // // May be empty. - Token string + Token string + ObsoleteToken string CreatedAt time.Time LastUsed time.Time