From 91de99d57e57f90285d5259651e36e99dea94021 Mon Sep 17 00:00:00 2001 From: "m.nabokikh" Date: Wed, 28 Oct 2020 10:26:34 +0400 Subject: [PATCH 1/8] feat: Add refresh token expiration and rotation settings Signed-off-by: m.nabokikh --- cmd/dex/config.go | 10 +++++ cmd/dex/serve.go | 12 +++++ examples/config-dev.yaml | 4 ++ server/handlers.go | 52 ++++++++++++++++------ server/rotation.go | 70 ++++++++++++++++++++++++++++++ server/rotation_test.go | 28 ++++++++++++ server/server.go | 7 +++ server/server_test.go | 14 ++++++ storage/conformance/conformance.go | 34 ++++++++------- storage/etcd/types.go | 5 ++- storage/kubernetes/types.go | 5 ++- storage/sql/crud.go | 21 ++++----- storage/sql/migrate.go | 3 ++ storage/storage.go | 3 +- 14 files changed, 226 insertions(+), 42 deletions(-) diff --git a/cmd/dex/config.go b/cmd/dex/config.go index 88dc98e7..6683c39e 100644 --- a/cmd/dex/config.go +++ b/cmd/dex/config.go @@ -304,6 +304,9 @@ type Expiry struct { // DeviceRequests defines the duration of time for which the DeviceRequests will be valid. DeviceRequests string `json:"deviceRequests"` + + // RefreshToken defines refresh tokens expiry policy + RefreshToken RefreshTokenExpiry `json:"refreshTokens"` } // Logger holds configuration required to customize logging for dex. @@ -314,3 +317,10 @@ type Logger struct { // Format specifies the format to be used for logging. Format string `json:"format"` } + +type RefreshTokenExpiry struct { + DisableRotation bool `json:"disableRotation"` + ReuseInterval string `json:"reuseInterval"` + AbsoluteLifetime string `json:"absoluteLifetime"` + ValidIfNotUsedFor string `json:"validIfNotUsedFor"` +} diff --git a/cmd/dex/serve.go b/cmd/dex/serve.go index 7afb8851..aa8f5ad6 100644 --- a/cmd/dex/serve.go +++ b/cmd/dex/serve.go @@ -317,6 +317,18 @@ func runServe(options serveOptions) error { logger.Infof("config device requests valid for: %v", deviceRequests) serverConfig.DeviceRequestsValidFor = deviceRequests } + refreshTokenPolicy, err := server.NewRefreshTokenPolicyFromConfig( + logger, + c.Expiry.RefreshToken.DisableRotation, + c.Expiry.RefreshToken.ValidIfNotUsedFor, + c.Expiry.RefreshToken.AbsoluteLifetime, + c.Expiry.RefreshToken.ReuseInterval, + ) + if err != nil { + return fmt.Errorf("invalid refresh token expiration policy config: %v", err) + } + + serverConfig.RefreshTokenPolicy = refreshTokenPolicy serv, err := server.NewServer(context.Background(), serverConfig) if err != nil { return fmt.Errorf("failed to initialize server: %v", err) diff --git a/examples/config-dev.yaml b/examples/config-dev.yaml index 1ca7aa66..344d72dc 100644 --- a/examples/config-dev.yaml +++ b/examples/config-dev.yaml @@ -77,6 +77,10 @@ telemetry: # deviceRequests: "5m" # signingKeys: "6h" # idTokens: "24h" +# refreshTokens: +# reuseInterval: "3s" +# validIfNotUsedFor: "2190h" +# absoluteLifetime: "5000h" # Options for controlling the logger. # logger: diff --git a/server/handlers.go b/server/handlers.go index 348700df..0f76c410 100644 --- a/server/handlers.go +++ b/server/handlers.go @@ -1035,14 +1035,27 @@ func (s *Server) handleRefreshToken(w http.ResponseWriter, r *http.Request, clie } return } + if refresh.ClientID != client.ID { s.logger.Errorf("client %s trying to claim token for client %s", client.ID, refresh.ClientID) s.tokenErrHelper(w, errInvalidRequest, "Refresh token is invalid or has already been claimed by another client.", http.StatusBadRequest) return } if refresh.Token != token.Token { - s.logger.Errorf("refresh token with id %s claimed twice", refresh.ID) - s.tokenErrHelper(w, errInvalidRequest, "Refresh token is invalid or has already been claimed by another client.", http.StatusBadRequest) + if !s.refreshTokenPolicy.AllowedToReuse(refresh.LastUsed) || refresh.ObsoleteToken != token.Token { + s.logger.Errorf("refresh token with id %s claimed twice", refresh.ID) + s.tokenErrHelper(w, errInvalidRequest, "Refresh token is invalid or has already been claimed by another client.", http.StatusBadRequest) + return + } + } + if s.refreshTokenPolicy.CompletelyExpired(refresh.CreatedAt) { + s.logger.Errorf("refresh token with id %s expired", refresh.ID) + s.tokenErrHelper(w, errInvalidRequest, "Refresh token expired.", http.StatusBadRequest) + return + } + if s.refreshTokenPolicy.ExpiredBecauseUnused(refresh.LastUsed) { + s.logger.Errorf("refresh token with id %s expired because being unused", refresh.ID) + s.tokenErrHelper(w, errInvalidRequest, "Refresh token expired.", http.StatusBadRequest) return } @@ -1147,22 +1160,28 @@ func (s *Server) handleRefreshToken(w http.ResponseWriter, r *http.Request, clie return } - newToken := &internal.RefreshToken{ - RefreshId: refresh.ID, - Token: storage.NewID(), - } - rawNewToken, err := internal.Marshal(newToken) - if err != nil { - s.logger.Errorf("failed to marshal refresh token: %v", err) - s.tokenErrHelper(w, errServerError, "", http.StatusInternalServerError) - return + newToken := token + if s.refreshTokenPolicy.RotationEnabled() { + newToken = &internal.RefreshToken{ + RefreshId: refresh.ID, + Token: storage.NewID(), + } } lastUsed := s.now() updater := func(old storage.RefreshToken) (storage.RefreshToken, error) { - if old.Token != refresh.Token { - return old, errors.New("refresh token claimed twice") + if s.refreshTokenPolicy.RotationEnabled() { + if old.Token != refresh.Token { + if s.refreshTokenPolicy.AllowedToReuse(old.LastUsed) && old.ObsoleteToken == refresh.Token { + newToken.Token = old.Token + return old, nil + } + return old, errors.New("refresh token claimed twice") + } + + old.ObsoleteToken = old.Token } + old.Token = newToken.Token // Update the claims of the refresh token. // @@ -1201,6 +1220,13 @@ func (s *Server) handleRefreshToken(w http.ResponseWriter, r *http.Request, clie return } + rawNewToken, err := internal.Marshal(newToken) + if err != nil { + s.logger.Errorf("failed to marshal refresh token: %v", err) + s.tokenErrHelper(w, errServerError, "", http.StatusInternalServerError) + return + } + resp := s.toAccessTokenResponse(idToken, accessToken, rawNewToken, expiry) s.writeAccessToken(w, resp) } diff --git a/server/rotation.go b/server/rotation.go index b7dd8116..48593ede 100644 --- a/server/rotation.go +++ b/server/rotation.go @@ -177,3 +177,73 @@ func (k keyRotator) rotate() error { k.logger.Infof("keys rotated, next rotation: %s", nextRotation) return nil } + +type RefreshTokenPolicy struct { + rotateRefreshTokens bool // enable rotation + + absoluteLifetime time.Duration // interval from token creation to the end of its life + validIfNotUsedFor time.Duration // interval from last token update to the end of its life + reuseInterval time.Duration // interval within which old refresh token is allowed to be reused + + Clock func() time.Time + + logger log.Logger +} + +func NewRefreshTokenPolicyFromConfig(logger log.Logger, rotation bool, validIfNotUsedFor, absoluteLifetime, reuseInterval string) (*RefreshTokenPolicy, error) { + r := RefreshTokenPolicy{Clock: time.Now, logger: logger} + var err error + + if validIfNotUsedFor != "" { + r.validIfNotUsedFor, err = time.ParseDuration(validIfNotUsedFor) + if err != nil { + return nil, fmt.Errorf("invalid config value %q for refresh token valid if not used for: %v", validIfNotUsedFor, err) + } + logger.Infof("config refresh tokens valid if not used for: %v", validIfNotUsedFor) + } + + if absoluteLifetime != "" { + r.absoluteLifetime, err = time.ParseDuration(absoluteLifetime) + if err != nil { + return nil, fmt.Errorf("invalid config value %q for refresh tokens absolute lifetime: %v", absoluteLifetime, err) + } + logger.Infof("config refresh tokens absolute lifetime: %v", absoluteLifetime) + } + + if reuseInterval != "" { + r.reuseInterval, err = time.ParseDuration(reuseInterval) + if err != nil { + return nil, fmt.Errorf("invalid config value %q for refresh tokens reuse interval: %v", reuseInterval, err) + } + logger.Infof("config refresh tokens reuse interval: %v", reuseInterval) + } + + r.rotateRefreshTokens = !rotation + logger.Infof("config refresh tokens rotation enabled: %v", r.rotateRefreshTokens) + return &r, nil +} + +func (r *RefreshTokenPolicy) RotationEnabled() bool { + return r.rotateRefreshTokens +} + +func (r *RefreshTokenPolicy) CompletelyExpired(lastUsed time.Time) bool { + if r.absoluteLifetime == 0 { + return false // expiration disabled + } + return r.Clock().After(lastUsed.Add(r.absoluteLifetime)) +} + +func (r *RefreshTokenPolicy) ExpiredBecauseUnused(lastUsed time.Time) bool { + if r.validIfNotUsedFor == 0 { + return false // expiration disabled + } + return r.Clock().After(lastUsed.Add(r.validIfNotUsedFor)) +} + +func (r *RefreshTokenPolicy) AllowedToReuse(lastUsed time.Time) bool { + if r.reuseInterval == 0 { + return false // expiration disabled + } + return !r.Clock().After(lastUsed.Add(r.reuseInterval)) +} diff --git a/server/rotation_test.go b/server/rotation_test.go index 6f9b2ecb..a75a2435 100644 --- a/server/rotation_test.go +++ b/server/rotation_test.go @@ -7,6 +7,7 @@ import ( "time" "github.com/sirupsen/logrus" + "github.com/stretchr/testify/require" "github.com/dexidp/dex/storage" "github.com/dexidp/dex/storage/memory" @@ -100,3 +101,30 @@ func TestKeyRotator(t *testing.T) { } } } + +func TestRefreshTokenPolicy(t *testing.T) { + lastTime := time.Now() + l := &logrus.Logger{ + Out: os.Stderr, + Formatter: &logrus.TextFormatter{DisableColors: true}, + Level: logrus.DebugLevel, + } + + r, err := NewRefreshTokenPolicyFromConfig(l, true, "1m", "1m", "1m") + require.NoError(t, err) + + t.Run("Allowed", func(t *testing.T) { + r.Clock = func() time.Time { return lastTime } + require.Equal(t, true, r.AllowedToReuse(lastTime)) + require.Equal(t, false, r.ExpiredBecauseUnused(lastTime)) + require.Equal(t, false, r.CompletelyExpired(lastTime)) + }) + + t.Run("Expired", func(t *testing.T) { + r.Clock = func() time.Time { return lastTime.Add(2 * time.Minute) } + time.Sleep(1 * time.Second) + require.Equal(t, false, r.AllowedToReuse(lastTime)) + require.Equal(t, true, r.ExpiredBecauseUnused(lastTime)) + require.Equal(t, true, r.CompletelyExpired(lastTime)) + }) +} diff --git a/server/server.go b/server/server.go index 6fd4d8b7..4d6c69b4 100644 --- a/server/server.go +++ b/server/server.go @@ -80,6 +80,10 @@ type Config struct { IDTokensValidFor time.Duration // Defaults to 24 hours AuthRequestsValidFor time.Duration // Defaults to 24 hours DeviceRequestsValidFor time.Duration // Defaults to 5 minutes + + // Refresh token expiration settings + RefreshTokenPolicy *RefreshTokenPolicy + // If set, the server will use this connector to handle password grants PasswordConnector string @@ -159,6 +163,8 @@ type Server struct { authRequestsValidFor time.Duration deviceRequestsValidFor time.Duration + refreshTokenPolicy *RefreshTokenPolicy + logger log.Logger } @@ -227,6 +233,7 @@ func newServer(ctx context.Context, c Config, rotationStrategy rotationStrategy) idTokensValidFor: value(c.IDTokensValidFor, 24*time.Hour), authRequestsValidFor: value(c.AuthRequestsValidFor, 24*time.Hour), deviceRequestsValidFor: value(c.DeviceRequestsValidFor, 5*time.Minute), + refreshTokenPolicy: c.RefreshTokenPolicy, skipApproval: c.SkipApprovalScreen, alwaysShowLogin: c.AlwaysShowLoginScreen, now: now, diff --git a/server/server_test.go b/server/server_test.go index 3a918434..2ef4f613 100644 --- a/server/server_test.go +++ b/server/server_test.go @@ -677,6 +677,13 @@ func TestOAuth2CodeFlow(t *testing.T) { }) defer httpServer.Close() + policy, err := NewRefreshTokenPolicyFromConfig(s.logger, false, "", "", "") + if err != nil { + t.Fatalf("failed to prepare rotation policy: %v", err) + } + policy.Clock = now + s.refreshTokenPolicy = policy + mockConn := s.connectors["mock"] conn = mockConn.Connector.(*mock.Callback) @@ -1508,6 +1515,13 @@ func TestOAuth2DeviceFlow(t *testing.T) { }) defer httpServer.Close() + policy, err := NewRefreshTokenPolicyFromConfig(s.logger, false, "", "", "") + if err != nil { + t.Fatalf("failed to prepare rotation policy: %v", err) + } + policy.Clock = now + s.refreshTokenPolicy = policy + mockConn := s.connectors["mock"] conn = mockConn.Connector.(*mock.Callback) diff --git a/storage/conformance/conformance.go b/storage/conformance/conformance.go index 3f5e2aa1..0bae52cb 100644 --- a/storage/conformance/conformance.go +++ b/storage/conformance/conformance.go @@ -324,14 +324,15 @@ func testClientCRUD(t *testing.T, s storage.Storage) { func testRefreshTokenCRUD(t *testing.T, s storage.Storage) { id := storage.NewID() refresh := storage.RefreshToken{ - ID: id, - Token: "bar", - Nonce: "foo", - ClientID: "client_id", - ConnectorID: "client_secret", - Scopes: []string{"openid", "email", "profile"}, - CreatedAt: time.Now().UTC().Round(time.Millisecond), - LastUsed: time.Now().UTC().Round(time.Millisecond), + ID: id, + Token: "bar", + ObsoleteToken: "", + Nonce: "foo", + ClientID: "client_id", + ConnectorID: "client_secret", + Scopes: []string{"openid", "email", "profile"}, + CreatedAt: time.Now().UTC().Round(time.Millisecond), + LastUsed: time.Now().UTC().Round(time.Millisecond), Claims: storage.Claims{ UserID: "1", Username: "jane", @@ -378,14 +379,15 @@ func testRefreshTokenCRUD(t *testing.T, s storage.Storage) { id2 := storage.NewID() refresh2 := storage.RefreshToken{ - ID: id2, - Token: "bar_2", - Nonce: "foo_2", - ClientID: "client_id_2", - ConnectorID: "client_secret", - Scopes: []string{"openid", "email", "profile"}, - CreatedAt: time.Now().UTC().Round(time.Millisecond), - LastUsed: time.Now().UTC().Round(time.Millisecond), + ID: id2, + Token: "bar_2", + ObsoleteToken: "bar", + Nonce: "foo_2", + ClientID: "client_id_2", + ConnectorID: "client_secret", + Scopes: []string{"openid", "email", "profile"}, + CreatedAt: time.Now().UTC().Round(time.Millisecond), + LastUsed: time.Now().UTC().Round(time.Millisecond), Claims: storage.Claims{ UserID: "2", Username: "john", diff --git a/storage/etcd/types.go b/storage/etcd/types.go index f2ffd9f7..9390608a 100644 --- a/storage/etcd/types.go +++ b/storage/etcd/types.go @@ -132,7 +132,8 @@ func toStorageAuthRequest(a AuthRequest) storage.AuthRequest { type RefreshToken struct { ID string `json:"id"` - Token string `json:"token"` + Token string `json:"token"` + ObsoleteToken string `json:"obsolete_token"` CreatedAt time.Time `json:"created_at"` LastUsed time.Time `json:"last_used"` @@ -152,6 +153,7 @@ func toStorageRefreshToken(r RefreshToken) storage.RefreshToken { return storage.RefreshToken{ ID: r.ID, Token: r.Token, + ObsoleteToken: r.ObsoleteToken, CreatedAt: r.CreatedAt, LastUsed: r.LastUsed, ClientID: r.ClientID, @@ -167,6 +169,7 @@ func fromStorageRefreshToken(r storage.RefreshToken) RefreshToken { return RefreshToken{ ID: r.ID, Token: r.Token, + ObsoleteToken: r.ObsoleteToken, CreatedAt: r.CreatedAt, LastUsed: r.LastUsed, ClientID: r.ClientID, diff --git a/storage/kubernetes/types.go b/storage/kubernetes/types.go index 07e25084..bed52736 100644 --- a/storage/kubernetes/types.go +++ b/storage/kubernetes/types.go @@ -496,7 +496,8 @@ type RefreshToken struct { ClientID string `json:"clientID"` Scopes []string `json:"scopes,omitempty"` - Token string `json:"token,omitempty"` + Token string `json:"token,omitempty"` + ObsoleteToken string `json:"obsoleteToken,omitempty"` Nonce string `json:"nonce,omitempty"` @@ -516,6 +517,7 @@ func toStorageRefreshToken(r RefreshToken) storage.RefreshToken { return storage.RefreshToken{ ID: r.ObjectMeta.Name, Token: r.Token, + ObsoleteToken: r.ObsoleteToken, CreatedAt: r.CreatedAt, LastUsed: r.LastUsed, ClientID: r.ClientID, @@ -538,6 +540,7 @@ func (cli *client) fromStorageRefreshToken(r storage.RefreshToken) RefreshToken Namespace: cli.namespace, }, Token: r.Token, + ObsoleteToken: r.ObsoleteToken, CreatedAt: r.CreatedAt, LastUsed: r.LastUsed, ClientID: r.ClientID, diff --git a/storage/sql/crud.go b/storage/sql/crud.go index 4451e5c5..5a234f9d 100644 --- a/storage/sql/crud.go +++ b/storage/sql/crud.go @@ -285,16 +285,16 @@ func (c *conn) CreateRefresh(r storage.RefreshToken) error { claims_user_id, claims_username, claims_preferred_username, claims_email, claims_email_verified, claims_groups, connector_id, connector_data, - token, created_at, last_used + token, obsolete_token, created_at, last_used ) - values ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12, $13, $14, $15); + values ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12, $13, $14, $15, $16); `, r.ID, r.ClientID, encoder(r.Scopes), r.Nonce, r.Claims.UserID, r.Claims.Username, r.Claims.PreferredUsername, r.Claims.Email, r.Claims.EmailVerified, encoder(r.Claims.Groups), r.ConnectorID, r.ConnectorData, - r.Token, r.CreatedAt, r.LastUsed, + r.Token, r.ObsoleteToken, r.CreatedAt, r.LastUsed, ) if err != nil { if c.alreadyExistsCheck(err) { @@ -329,17 +329,18 @@ func (c *conn) UpdateRefreshToken(id string, updater func(old storage.RefreshTok connector_id = $10, connector_data = $11, token = $12, - created_at = $13, - last_used = $14 + obsolete_token = $13, + created_at = $14, + last_used = $15 where - id = $15 + id = $16 `, r.ClientID, encoder(r.Scopes), r.Nonce, r.Claims.UserID, r.Claims.Username, r.Claims.PreferredUsername, r.Claims.Email, r.Claims.EmailVerified, encoder(r.Claims.Groups), r.ConnectorID, r.ConnectorData, - r.Token, r.CreatedAt, r.LastUsed, id, + r.Token, r.ObsoleteToken, r.CreatedAt, r.LastUsed, id, ) if err != nil { return fmt.Errorf("update refresh token: %v", err) @@ -360,7 +361,7 @@ func getRefresh(q querier, id string) (storage.RefreshToken, error) { claims_email, claims_email_verified, claims_groups, connector_id, connector_data, - token, created_at, last_used + token, obsolete_token, created_at, last_used from refresh_token where id = $1; `, id)) } @@ -372,7 +373,7 @@ func (c *conn) ListRefreshTokens() ([]storage.RefreshToken, error) { claims_user_id, claims_username, claims_preferred_username, claims_email, claims_email_verified, claims_groups, connector_id, connector_data, - token, created_at, last_used + token, obsolete_token, created_at, last_used from refresh_token; `) if err != nil { @@ -401,7 +402,7 @@ func scanRefresh(s scanner) (r storage.RefreshToken, err error) { &r.Claims.Email, &r.Claims.EmailVerified, decoder(&r.Claims.Groups), &r.ConnectorID, &r.ConnectorData, - &r.Token, &r.CreatedAt, &r.LastUsed, + &r.Token, &r.ObsoleteToken, &r.CreatedAt, &r.LastUsed, ) if err != nil { if err == sql.ErrNoRows { diff --git a/storage/sql/migrate.go b/storage/sql/migrate.go index 460658c2..0f2666bf 100644 --- a/storage/sql/migrate.go +++ b/storage/sql/migrate.go @@ -176,6 +176,9 @@ var migrations = []migration{ alter table refresh_token add column token text not null default '';`, ` + alter table refresh_token + add column obsolete_token text default '';`, + ` alter table refresh_token add column created_at timestamptz not null default '0001-01-01 00:00:00 UTC';`, ` diff --git a/storage/storage.go b/storage/storage.go index c308ac46..855eb09f 100644 --- a/storage/storage.go +++ b/storage/storage.go @@ -271,7 +271,8 @@ type RefreshToken struct { // A single token that's rotated every time the refresh token is refreshed. // // May be empty. - Token string + Token string + ObsoleteToken string CreatedAt time.Time LastUsed time.Time From 06c8ab5aa71e9df22a6e093ad060fb1a6cb25d99 Mon Sep 17 00:00:00 2001 From: "m.nabokikh" Date: Sun, 15 Nov 2020 22:26:34 +0400 Subject: [PATCH 2/8] Fixes of naming and code style Signed-off-by: m.nabokikh --- cmd/dex/config.go | 4 ++-- cmd/dex/serve.go | 10 +++++----- server/handlers.go | 7 ++++++- server/rotation.go | 12 ++++++------ server/rotation_test.go | 7 +++---- server/server_test.go | 22 ++++++++-------------- storage/sql/migrate.go | 10 +++++++--- 7 files changed, 37 insertions(+), 35 deletions(-) diff --git a/cmd/dex/config.go b/cmd/dex/config.go index 6683c39e..a75ddaee 100644 --- a/cmd/dex/config.go +++ b/cmd/dex/config.go @@ -305,8 +305,8 @@ type Expiry struct { // DeviceRequests defines the duration of time for which the DeviceRequests will be valid. DeviceRequests string `json:"deviceRequests"` - // RefreshToken defines refresh tokens expiry policy - RefreshToken RefreshTokenExpiry `json:"refreshTokens"` + // RefreshTokens defines refresh tokens expiry policy + RefreshTokens RefreshTokenExpiry `json:"refreshTokens"` } // Logger holds configuration required to customize logging for dex. diff --git a/cmd/dex/serve.go b/cmd/dex/serve.go index aa8f5ad6..9e6f8b9a 100644 --- a/cmd/dex/serve.go +++ b/cmd/dex/serve.go @@ -317,12 +317,12 @@ func runServe(options serveOptions) error { logger.Infof("config device requests valid for: %v", deviceRequests) serverConfig.DeviceRequestsValidFor = deviceRequests } - refreshTokenPolicy, err := server.NewRefreshTokenPolicyFromConfig( + refreshTokenPolicy, err := server.NewRefreshTokenPolicy( logger, - c.Expiry.RefreshToken.DisableRotation, - c.Expiry.RefreshToken.ValidIfNotUsedFor, - c.Expiry.RefreshToken.AbsoluteLifetime, - c.Expiry.RefreshToken.ReuseInterval, + c.Expiry.RefreshTokens.DisableRotation, + c.Expiry.RefreshTokens.ValidIfNotUsedFor, + c.Expiry.RefreshTokens.AbsoluteLifetime, + c.Expiry.RefreshTokens.ReuseInterval, ) if err != nil { return fmt.Errorf("invalid refresh token expiration policy config: %v", err) diff --git a/server/handlers.go b/server/handlers.go index 0f76c410..aa08b7a5 100644 --- a/server/handlers.go +++ b/server/handlers.go @@ -1042,7 +1042,12 @@ func (s *Server) handleRefreshToken(w http.ResponseWriter, r *http.Request, clie return } if refresh.Token != token.Token { - if !s.refreshTokenPolicy.AllowedToReuse(refresh.LastUsed) || refresh.ObsoleteToken != token.Token { + switch { + case !s.refreshTokenPolicy.AllowedToReuse(refresh.LastUsed): + fallthrough + case refresh.ObsoleteToken != token.Token: + fallthrough + case refresh.ObsoleteToken == "": s.logger.Errorf("refresh token with id %s claimed twice", refresh.ID) s.tokenErrHelper(w, errInvalidRequest, "Refresh token is invalid or has already been claimed by another client.", http.StatusBadRequest) return diff --git a/server/rotation.go b/server/rotation.go index 48593ede..98489767 100644 --- a/server/rotation.go +++ b/server/rotation.go @@ -185,13 +185,13 @@ type RefreshTokenPolicy struct { validIfNotUsedFor time.Duration // interval from last token update to the end of its life reuseInterval time.Duration // interval within which old refresh token is allowed to be reused - Clock func() time.Time + now func() time.Time logger log.Logger } -func NewRefreshTokenPolicyFromConfig(logger log.Logger, rotation bool, validIfNotUsedFor, absoluteLifetime, reuseInterval string) (*RefreshTokenPolicy, error) { - r := RefreshTokenPolicy{Clock: time.Now, logger: logger} +func NewRefreshTokenPolicy(logger log.Logger, rotation bool, validIfNotUsedFor, absoluteLifetime, reuseInterval string) (*RefreshTokenPolicy, error) { + r := RefreshTokenPolicy{now: time.Now, logger: logger} var err error if validIfNotUsedFor != "" { @@ -231,19 +231,19 @@ func (r *RefreshTokenPolicy) CompletelyExpired(lastUsed time.Time) bool { if r.absoluteLifetime == 0 { return false // expiration disabled } - return r.Clock().After(lastUsed.Add(r.absoluteLifetime)) + return r.now().After(lastUsed.Add(r.absoluteLifetime)) } func (r *RefreshTokenPolicy) ExpiredBecauseUnused(lastUsed time.Time) bool { if r.validIfNotUsedFor == 0 { return false // expiration disabled } - return r.Clock().After(lastUsed.Add(r.validIfNotUsedFor)) + return r.now().After(lastUsed.Add(r.validIfNotUsedFor)) } func (r *RefreshTokenPolicy) AllowedToReuse(lastUsed time.Time) bool { if r.reuseInterval == 0 { return false // expiration disabled } - return !r.Clock().After(lastUsed.Add(r.reuseInterval)) + return !r.now().After(lastUsed.Add(r.reuseInterval)) } diff --git a/server/rotation_test.go b/server/rotation_test.go index a75a2435..e279bf54 100644 --- a/server/rotation_test.go +++ b/server/rotation_test.go @@ -110,19 +110,18 @@ func TestRefreshTokenPolicy(t *testing.T) { Level: logrus.DebugLevel, } - r, err := NewRefreshTokenPolicyFromConfig(l, true, "1m", "1m", "1m") + r, err := NewRefreshTokenPolicy(l, true, "1m", "1m", "1m") require.NoError(t, err) t.Run("Allowed", func(t *testing.T) { - r.Clock = func() time.Time { return lastTime } + r.now = func() time.Time { return lastTime } require.Equal(t, true, r.AllowedToReuse(lastTime)) require.Equal(t, false, r.ExpiredBecauseUnused(lastTime)) require.Equal(t, false, r.CompletelyExpired(lastTime)) }) t.Run("Expired", func(t *testing.T) { - r.Clock = func() time.Time { return lastTime.Add(2 * time.Minute) } - time.Sleep(1 * time.Second) + r.now = func() time.Time { return lastTime.Add(2 * time.Minute) } require.Equal(t, false, r.AllowedToReuse(lastTime)) require.Equal(t, true, r.ExpiredBecauseUnused(lastTime)) require.Equal(t, true, r.CompletelyExpired(lastTime)) diff --git a/server/server_test.go b/server/server_test.go index 2ef4f613..d8b40991 100644 --- a/server/server_test.go +++ b/server/server_test.go @@ -117,6 +117,14 @@ func newTestServer(ctx context.Context, t *testing.T, updateConfig func(c *Confi t.Fatal(err) } server.skipApproval = true // Don't prompt for approval, just immediately redirect with code. + + // Default rotation policy + server.refreshTokenPolicy, err = NewRefreshTokenPolicy(logger, false, "", "", "") + if err != nil { + t.Fatalf("failed to prepare rotation policy: %v", err) + } + server.refreshTokenPolicy.now = config.Now + return s, server } @@ -677,13 +685,6 @@ func TestOAuth2CodeFlow(t *testing.T) { }) defer httpServer.Close() - policy, err := NewRefreshTokenPolicyFromConfig(s.logger, false, "", "", "") - if err != nil { - t.Fatalf("failed to prepare rotation policy: %v", err) - } - policy.Clock = now - s.refreshTokenPolicy = policy - mockConn := s.connectors["mock"] conn = mockConn.Connector.(*mock.Callback) @@ -1515,13 +1516,6 @@ func TestOAuth2DeviceFlow(t *testing.T) { }) defer httpServer.Close() - policy, err := NewRefreshTokenPolicyFromConfig(s.logger, false, "", "", "") - if err != nil { - t.Fatalf("failed to prepare rotation policy: %v", err) - } - policy.Clock = now - s.refreshTokenPolicy = policy - mockConn := s.connectors["mock"] conn = mockConn.Connector.(*mock.Callback) diff --git a/storage/sql/migrate.go b/storage/sql/migrate.go index 0f2666bf..498db252 100644 --- a/storage/sql/migrate.go +++ b/storage/sql/migrate.go @@ -176,9 +176,6 @@ var migrations = []migration{ alter table refresh_token add column token text not null default '';`, ` - alter table refresh_token - add column obsolete_token text default '';`, - ` alter table refresh_token add column created_at timestamptz not null default '0001-01-01 00:00:00 UTC';`, ` @@ -277,4 +274,11 @@ var migrations = []migration{ add column code_challenge_method text not null default '';`, }, }, + { + stmts: []string{ + ` + alter table refresh_token + add column obsolete_token text default '';`, + }, + }, } From 0c75ed12e2feac0ecb7f95afe821420b0e598572 Mon Sep 17 00:00:00 2001 From: "m.nabokikh" Date: Wed, 6 Jan 2021 04:22:38 +0400 Subject: [PATCH 3/8] Add refresh token expiration tests and some refactoring Signed-off-by: m.nabokikh --- server/handlers.go | 232 --------------------- server/refreshhandlers.go | 319 +++++++++++++++++++++++++++++ server/refreshhandlers_test.go | 187 +++++++++++++++++ server/server_test.go | 10 +- storage/conformance/conformance.go | 2 +- 5 files changed, 513 insertions(+), 237 deletions(-) create mode 100644 server/refreshhandlers.go create mode 100644 server/refreshhandlers_test.go diff --git a/server/handlers.go b/server/handlers.go index aa08b7a5..1db1f68f 100644 --- a/server/handlers.go +++ b/server/handlers.go @@ -5,7 +5,6 @@ import ( "crypto/sha256" "encoding/base64" "encoding/json" - "errors" "fmt" "net/http" "net/url" @@ -1005,237 +1004,6 @@ func (s *Server) exchangeAuthCode(w http.ResponseWriter, authCode storage.AuthCo return s.toAccessTokenResponse(idToken, accessToken, refreshToken, expiry), nil } -// handle a refresh token request https://tools.ietf.org/html/rfc6749#section-6 -func (s *Server) handleRefreshToken(w http.ResponseWriter, r *http.Request, client storage.Client) { - code := r.PostFormValue("refresh_token") - scope := r.PostFormValue("scope") - if code == "" { - s.tokenErrHelper(w, errInvalidRequest, "No refresh token in request.", http.StatusBadRequest) - return - } - - token := new(internal.RefreshToken) - if err := internal.Unmarshal(code, token); err != nil { - // For backward compatibility, assume the refresh_token is a raw refresh token ID - // if it fails to decode. - // - // Because refresh_token values that aren't unmarshable were generated by servers - // that don't have a Token value, we'll still reject any attempts to claim a - // refresh_token twice. - token = &internal.RefreshToken{RefreshId: code, Token: ""} - } - - refresh, err := s.storage.GetRefresh(token.RefreshId) - if err != nil { - s.logger.Errorf("failed to get refresh token: %v", err) - if err == storage.ErrNotFound { - s.tokenErrHelper(w, errInvalidRequest, "Refresh token is invalid or has already been claimed by another client.", http.StatusBadRequest) - } else { - s.tokenErrHelper(w, errServerError, "", http.StatusInternalServerError) - } - return - } - - if refresh.ClientID != client.ID { - s.logger.Errorf("client %s trying to claim token for client %s", client.ID, refresh.ClientID) - s.tokenErrHelper(w, errInvalidRequest, "Refresh token is invalid or has already been claimed by another client.", http.StatusBadRequest) - return - } - if refresh.Token != token.Token { - switch { - case !s.refreshTokenPolicy.AllowedToReuse(refresh.LastUsed): - fallthrough - case refresh.ObsoleteToken != token.Token: - fallthrough - case refresh.ObsoleteToken == "": - s.logger.Errorf("refresh token with id %s claimed twice", refresh.ID) - s.tokenErrHelper(w, errInvalidRequest, "Refresh token is invalid or has already been claimed by another client.", http.StatusBadRequest) - return - } - } - if s.refreshTokenPolicy.CompletelyExpired(refresh.CreatedAt) { - s.logger.Errorf("refresh token with id %s expired", refresh.ID) - s.tokenErrHelper(w, errInvalidRequest, "Refresh token expired.", http.StatusBadRequest) - return - } - if s.refreshTokenPolicy.ExpiredBecauseUnused(refresh.LastUsed) { - s.logger.Errorf("refresh token with id %s expired because being unused", refresh.ID) - s.tokenErrHelper(w, errInvalidRequest, "Refresh token expired.", http.StatusBadRequest) - return - } - - // Per the OAuth2 spec, if the client has omitted the scopes, default to the original - // authorized scopes. - // - // https://tools.ietf.org/html/rfc6749#section-6 - scopes := refresh.Scopes - if scope != "" { - requestedScopes := strings.Fields(scope) - var unauthorizedScopes []string - - for _, s := range requestedScopes { - contains := func() bool { - for _, scope := range refresh.Scopes { - if s == scope { - return true - } - } - return false - }() - if !contains { - unauthorizedScopes = append(unauthorizedScopes, s) - } - } - - if len(unauthorizedScopes) > 0 { - msg := fmt.Sprintf("Requested scopes contain unauthorized scope(s): %q.", unauthorizedScopes) - s.tokenErrHelper(w, errInvalidRequest, msg, http.StatusBadRequest) - return - } - scopes = requestedScopes - } - - var connectorData []byte - - session, err := s.storage.GetOfflineSessions(refresh.Claims.UserID, refresh.ConnectorID) - switch { - case err != nil: - if err != storage.ErrNotFound { - s.logger.Errorf("failed to get offline session: %v", err) - return - } - case len(refresh.ConnectorData) > 0: - // Use the old connector data if it exists, should be deleted once used - connectorData = refresh.ConnectorData - default: - connectorData = session.ConnectorData - } - - conn, err := s.getConnector(refresh.ConnectorID) - if err != nil { - s.logger.Errorf("connector with ID %q not found: %v", refresh.ConnectorID, err) - s.tokenErrHelper(w, errServerError, "", http.StatusInternalServerError) - return - } - ident := connector.Identity{ - UserID: refresh.Claims.UserID, - Username: refresh.Claims.Username, - PreferredUsername: refresh.Claims.PreferredUsername, - Email: refresh.Claims.Email, - EmailVerified: refresh.Claims.EmailVerified, - Groups: refresh.Claims.Groups, - ConnectorData: connectorData, - } - - // Can the connector refresh the identity? If so, attempt to refresh the data - // in the connector. - // - // TODO(ericchiang): We may want a strict mode where connectors that don't implement - // this interface can't perform refreshing. - if refreshConn, ok := conn.Connector.(connector.RefreshConnector); ok { - newIdent, err := refreshConn.Refresh(r.Context(), parseScopes(scopes), ident) - if err != nil { - s.logger.Errorf("failed to refresh identity: %v", err) - s.tokenErrHelper(w, errServerError, "", http.StatusInternalServerError) - return - } - ident = newIdent - } - - claims := storage.Claims{ - UserID: ident.UserID, - Username: ident.Username, - PreferredUsername: ident.PreferredUsername, - Email: ident.Email, - EmailVerified: ident.EmailVerified, - Groups: ident.Groups, - } - - accessToken, err := s.newAccessToken(client.ID, claims, scopes, refresh.Nonce, refresh.ConnectorID) - if err != nil { - s.logger.Errorf("failed to create new access token: %v", err) - s.tokenErrHelper(w, errServerError, "", http.StatusInternalServerError) - return - } - - idToken, expiry, err := s.newIDToken(client.ID, claims, scopes, refresh.Nonce, accessToken, "", refresh.ConnectorID) - if err != nil { - s.logger.Errorf("failed to create ID token: %v", err) - s.tokenErrHelper(w, errServerError, "", http.StatusInternalServerError) - return - } - - newToken := token - if s.refreshTokenPolicy.RotationEnabled() { - newToken = &internal.RefreshToken{ - RefreshId: refresh.ID, - Token: storage.NewID(), - } - } - - lastUsed := s.now() - updater := func(old storage.RefreshToken) (storage.RefreshToken, error) { - if s.refreshTokenPolicy.RotationEnabled() { - if old.Token != refresh.Token { - if s.refreshTokenPolicy.AllowedToReuse(old.LastUsed) && old.ObsoleteToken == refresh.Token { - newToken.Token = old.Token - return old, nil - } - return old, errors.New("refresh token claimed twice") - } - - old.ObsoleteToken = old.Token - } - - old.Token = newToken.Token - // Update the claims of the refresh token. - // - // UserID intentionally ignored for now. - old.Claims.Username = ident.Username - old.Claims.PreferredUsername = ident.PreferredUsername - old.Claims.Email = ident.Email - old.Claims.EmailVerified = ident.EmailVerified - old.Claims.Groups = ident.Groups - old.LastUsed = lastUsed - - // ConnectorData has been moved to OfflineSession - old.ConnectorData = []byte{} - return old, nil - } - - // Update LastUsed time stamp in refresh token reference object - // in offline session for the user. - if err := s.storage.UpdateOfflineSessions(refresh.Claims.UserID, refresh.ConnectorID, func(old storage.OfflineSessions) (storage.OfflineSessions, error) { - if old.Refresh[refresh.ClientID].ID != refresh.ID { - return old, errors.New("refresh token invalid") - } - old.Refresh[refresh.ClientID].LastUsed = lastUsed - old.ConnectorData = ident.ConnectorData - return old, nil - }); err != nil { - s.logger.Errorf("failed to update offline session: %v", err) - s.tokenErrHelper(w, errServerError, "", http.StatusInternalServerError) - return - } - - // Update refresh token in the storage. - if err := s.storage.UpdateRefreshToken(refresh.ID, updater); err != nil { - s.logger.Errorf("failed to update refresh token: %v", err) - s.tokenErrHelper(w, errServerError, "", http.StatusInternalServerError) - return - } - - rawNewToken, err := internal.Marshal(newToken) - if err != nil { - s.logger.Errorf("failed to marshal refresh token: %v", err) - s.tokenErrHelper(w, errServerError, "", http.StatusInternalServerError) - return - } - - resp := s.toAccessTokenResponse(idToken, accessToken, rawNewToken, expiry) - s.writeAccessToken(w, resp) -} - func (s *Server) handleUserInfo(w http.ResponseWriter, r *http.Request) { const prefix = "Bearer " diff --git a/server/refreshhandlers.go b/server/refreshhandlers.go new file mode 100644 index 00000000..31709ad5 --- /dev/null +++ b/server/refreshhandlers.go @@ -0,0 +1,319 @@ +package server + +import ( + "context" + "errors" + "fmt" + "net/http" + "strings" + + "github.com/dexidp/dex/connector" + "github.com/dexidp/dex/server/internal" + "github.com/dexidp/dex/storage" +) + +func contains(arr []string, item string) bool { + for _, itemFromArray := range arr { + if itemFromArray == item { + return true + } + } + return false +} + +type refreshError struct { + msg string + code int + desc string +} + +func (s *Server) refreshTokenErrHelper(w http.ResponseWriter, err *refreshError) { + s.tokenErrHelper(w, err.msg, err.desc, err.code) +} + +func (s *Server) extractRefreshTokenFromRequest(r *http.Request) (*internal.RefreshToken, *refreshError) { + code := r.PostFormValue("refresh_token") + if code == "" { + return nil, &refreshError{msg: errInvalidRequest, desc: "No refresh token in request.", code: http.StatusBadRequest} + } + + token := new(internal.RefreshToken) + if err := internal.Unmarshal(code, token); err != nil { + // For backward compatibility, assume the refresh_token is a raw refresh token ID + // if it fails to decode. + // + // Because refresh_token values that aren't unmarshable were generated by servers + // that don't have a Token value, we'll still reject any attempts to claim a + // refresh_token twice. + token = &internal.RefreshToken{RefreshId: code, Token: ""} + } + + return token, nil +} + +// getRefreshTokenFromStorage checks that refresh token is valid and exists in the storage and gets its info +func (s *Server) getRefreshTokenFromStorage(clientID string, token *internal.RefreshToken) (storage.RefreshToken, *refreshError) { + refresh, err := s.storage.GetRefresh(token.RefreshId) + rerr := refreshError{ + msg: errInvalidRequest, + desc: "Refresh token is invalid or has already been claimed by another client.", + code: http.StatusBadRequest, + } + + if err != nil { + s.logger.Errorf("failed to get refresh token: %v", err) + if err != storage.ErrNotFound { + return storage.RefreshToken{}, &refreshError{msg: errInvalidRequest, desc: "", code: http.StatusInternalServerError} + } + + return storage.RefreshToken{}, &rerr + } + + if refresh.ClientID != clientID { + s.logger.Errorf("client %s trying to claim token for client %s", clientID, refresh.ClientID) + return storage.RefreshToken{}, &rerr + } + + if refresh.Token != token.Token { + switch { + case !s.refreshTokenPolicy.AllowedToReuse(refresh.LastUsed): + fallthrough + case refresh.ObsoleteToken != token.Token: + fallthrough + case refresh.ObsoleteToken == "": + s.logger.Errorf("refresh token with id %s claimed twice", refresh.ID) + return storage.RefreshToken{}, &rerr + } + } + + rerr.desc = "Refresh token expired." + if s.refreshTokenPolicy.CompletelyExpired(refresh.CreatedAt) { + s.logger.Errorf("refresh token with id %s expired", refresh.ID) + return storage.RefreshToken{}, &rerr + } + if s.refreshTokenPolicy.ExpiredBecauseUnused(refresh.LastUsed) { + s.logger.Errorf("refresh token with id %s expired because being unused", refresh.ID) + return storage.RefreshToken{}, &rerr + } + + return refresh, nil +} + +func (s *Server) getRefreshScopes(r *http.Request, refresh *storage.RefreshToken) ([]string, *refreshError) { + // Per the OAuth2 spec, if the client has omitted the scopes, default to the original + // authorized scopes. + // + // https://tools.ietf.org/html/rfc6749#section-6 + scope := r.PostFormValue("scope") + + if scope == "" { + return refresh.Scopes, nil + } + + requestedScopes := strings.Fields(scope) + var unauthorizedScopes []string + + // Per the OAuth2 spec, if the client has omitted the scopes, default to the original + // authorized scopes. + // + // https://tools.ietf.org/html/rfc6749#section-6 + for _, requestScope := range requestedScopes { + if !contains(refresh.Scopes, requestScope) { + unauthorizedScopes = append(unauthorizedScopes, requestScope) + } + } + + if len(unauthorizedScopes) > 0 { + desc := fmt.Sprintf("Requested scopes contain unauthorized scope(s): %q.", unauthorizedScopes) + return nil, &refreshError{msg: errInvalidRequest, desc: desc, code: http.StatusBadRequest} + } + + return requestedScopes, nil +} + +func (s *Server) refreshWithConnector(ctx context.Context, refresh *storage.RefreshToken, scopes []string) (connector.Identity, *refreshError) { + var connectorData []byte + rerr := refreshError{msg: errInvalidRequest, desc: "", code: http.StatusInternalServerError} + + session, err := s.storage.GetOfflineSessions(refresh.Claims.UserID, refresh.ConnectorID) + switch { + case err != nil: + if err != storage.ErrNotFound { + s.logger.Errorf("failed to get offline session: %v", err) + // TODO: previously there was a naked return without writing anything in response, need to figure it out + return connector.Identity{}, &rerr + } + case len(refresh.ConnectorData) > 0: + // Use the old connector data if it exists, should be deleted once used + connectorData = refresh.ConnectorData + default: + connectorData = session.ConnectorData + } + + conn, err := s.getConnector(refresh.ConnectorID) + if err != nil { + s.logger.Errorf("connector with ID %q not found: %v", refresh.ConnectorID, err) + return connector.Identity{}, &rerr + } + + ident := connector.Identity{ + UserID: refresh.Claims.UserID, + Username: refresh.Claims.Username, + PreferredUsername: refresh.Claims.PreferredUsername, + Email: refresh.Claims.Email, + EmailVerified: refresh.Claims.EmailVerified, + Groups: refresh.Claims.Groups, + ConnectorData: connectorData, + } + + // Can the connector refresh the identity? If so, attempt to refresh the data + // in the connector. + // + // TODO(ericchiang): We may want a strict mode where connectors that don't implement + // this interface can't perform refreshing. + if refreshConn, ok := conn.Connector.(connector.RefreshConnector); ok { + newIdent, err := refreshConn.Refresh(ctx, parseScopes(scopes), ident) + if err != nil { + s.logger.Errorf("failed to refresh identity: %v", err) + return connector.Identity{}, &rerr + } + ident = newIdent + } + + return ident, nil +} + +// updateRefreshToken updates refresh token and offline session in the storage +func (s *Server) updateRefreshToken(token *internal.RefreshToken, refresh *storage.RefreshToken, ident connector.Identity) (*internal.RefreshToken, *refreshError) { + newToken := token + if s.refreshTokenPolicy.RotationEnabled() { + newToken = &internal.RefreshToken{ + RefreshId: refresh.ID, + Token: storage.NewID(), + } + } + + lastUsed := s.now() + refreshTokenUpdater := func(old storage.RefreshToken) (storage.RefreshToken, error) { + if s.refreshTokenPolicy.RotationEnabled() { + if old.Token != refresh.Token { + if s.refreshTokenPolicy.AllowedToReuse(old.LastUsed) && old.ObsoleteToken == refresh.Token { + newToken.Token = old.Token + return old, nil + } + return old, errors.New("refresh token claimed twice") + } + + old.ObsoleteToken = old.Token + } + + old.Token = newToken.Token + // Update the claims of the refresh token. + // + // UserID intentionally ignored for now. + old.Claims.Username = ident.Username + old.Claims.PreferredUsername = ident.PreferredUsername + old.Claims.Email = ident.Email + old.Claims.EmailVerified = ident.EmailVerified + old.Claims.Groups = ident.Groups + old.LastUsed = lastUsed + + // ConnectorData has been moved to OfflineSession + old.ConnectorData = []byte{} + return old, nil + } + + offlineSessionUpdater := func(old storage.OfflineSessions) (storage.OfflineSessions, error) { + if old.Refresh[refresh.ClientID].ID != refresh.ID { + return old, errors.New("refresh token invalid") + } + old.Refresh[refresh.ClientID].LastUsed = lastUsed + old.ConnectorData = ident.ConnectorData + return old, nil + } + + rerr := refreshError{msg: errInvalidRequest, desc: "", code: http.StatusInternalServerError} + + // Update LastUsed time stamp in refresh token reference object + // in offline session for the user. + err := s.storage.UpdateOfflineSessions(refresh.Claims.UserID, refresh.ConnectorID, offlineSessionUpdater) + if err != nil { + s.logger.Errorf("failed to update offline session: %v", err) + return newToken, &rerr + } + + // Update refresh token in the storage. + err = s.storage.UpdateRefreshToken(refresh.ID, refreshTokenUpdater) + if err != nil { + s.logger.Errorf("failed to update refresh token: %v", err) + return newToken, &rerr + } + + return newToken, nil +} + +// handleRefreshToken handles a refresh token request https://tools.ietf.org/html/rfc6749#section-6 +func (s *Server) handleRefreshToken(w http.ResponseWriter, r *http.Request, client storage.Client) { + token, rerr := s.extractRefreshTokenFromRequest(r) + if rerr != nil { + s.refreshTokenErrHelper(w, rerr) + return + } + + refresh, rerr := s.getRefreshTokenFromStorage(client.ID, token) + if rerr != nil { + s.refreshTokenErrHelper(w, rerr) + return + } + + scopes, rerr := s.getRefreshScopes(r, &refresh) + if rerr != nil { + s.refreshTokenErrHelper(w, rerr) + return + } + + ident, rerr := s.refreshWithConnector(r.Context(), &refresh, scopes) + if rerr != nil { + s.refreshTokenErrHelper(w, rerr) + return + } + + claims := storage.Claims{ + UserID: ident.UserID, + Username: ident.Username, + PreferredUsername: ident.PreferredUsername, + Email: ident.Email, + EmailVerified: ident.EmailVerified, + Groups: ident.Groups, + } + + accessToken, err := s.newAccessToken(client.ID, claims, scopes, refresh.Nonce, refresh.ConnectorID) + if err != nil { + s.logger.Errorf("failed to create new access token: %v", err) + s.tokenErrHelper(w, errServerError, "", http.StatusInternalServerError) + return + } + + idToken, expiry, err := s.newIDToken(client.ID, claims, scopes, refresh.Nonce, accessToken, refresh.ConnectorID) + if err != nil { + s.logger.Errorf("failed to create ID token: %v", err) + s.tokenErrHelper(w, errServerError, "", http.StatusInternalServerError) + return + } + + newToken, rerr := s.updateRefreshToken(token, &refresh, ident) + if rerr != nil { + s.refreshTokenErrHelper(w, rerr) + return + } + + rawNewToken, err := internal.Marshal(newToken) + if err != nil { + s.logger.Errorf("failed to marshal refresh token: %v", err) + s.tokenErrHelper(w, errServerError, "", http.StatusInternalServerError) + return + } + + resp := s.toAccessTokenResponse(idToken, accessToken, rawNewToken, expiry) + s.writeAccessToken(w, resp) +} diff --git a/server/refreshhandlers_test.go b/server/refreshhandlers_test.go new file mode 100644 index 00000000..40e81435 --- /dev/null +++ b/server/refreshhandlers_test.go @@ -0,0 +1,187 @@ +package server + +import ( + "bytes" + "context" + "net/http" + "net/http/httptest" + "net/url" + "path" + "testing" + "time" + + "github.com/stretchr/testify/require" + + "github.com/dexidp/dex/server/internal" + "github.com/dexidp/dex/storage" +) + +func TestRefreshTokenExpirationScenarios(t *testing.T) { + t0 := time.Now() + tests := []struct { + name string + policy *RefreshTokenPolicy + useObsolete bool + error string + }{ + { + name: "Normal", + policy: &RefreshTokenPolicy{rotateRefreshTokens: true}, + error: ``, + }, + { + name: "Not expired because used", + policy: &RefreshTokenPolicy{ + rotateRefreshTokens: false, + validIfNotUsedFor: time.Second * 60, + now: func() time.Time { return t0.Add(time.Second * 25) }, + }, + error: ``, + }, + { + name: "Expired because not used", + policy: &RefreshTokenPolicy{ + rotateRefreshTokens: false, + validIfNotUsedFor: time.Second * 60, + now: func() time.Time { return t0.Add(time.Hour) }, + }, + error: `{"error":"invalid_request","error_description":"Refresh token expired."}`, + }, + { + name: "Absolutely expired", + policy: &RefreshTokenPolicy{ + rotateRefreshTokens: true, + absoluteLifetime: time.Second * 60, + now: func() time.Time { return t0.Add(time.Hour) }, + }, + error: `{"error":"invalid_request","error_description":"Refresh token expired."}`, + }, + { + name: "Obsolete tokens are not allowed", + useObsolete: true, + policy: &RefreshTokenPolicy{ + rotateRefreshTokens: true, + now: func() time.Time { return t0.Add(time.Second * 25) }, + }, + error: `{"error":"invalid_request","error_description":"Refresh token is invalid or has already been claimed by another client."}`, + }, + { + name: "Obsolete tokens are allowed", + useObsolete: true, + policy: &RefreshTokenPolicy{ + rotateRefreshTokens: true, + reuseInterval: time.Second * 30, + now: func() time.Time { return t0.Add(time.Second * 25) }, + }, + error: ``, + }, + { + name: "Obsolete tokens are allowed but token is expired globally", + useObsolete: true, + policy: &RefreshTokenPolicy{ + rotateRefreshTokens: true, + reuseInterval: time.Second * 30, + absoluteLifetime: time.Second * 20, + now: func() time.Time { return t0.Add(time.Second * 25) }, + }, + error: `{"error":"invalid_request","error_description":"Refresh token expired."}`, + }, + } + + for _, tc := range tests { + t.Run(tc.name, func(*testing.T) { + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + // Setup a dex server. + httpServer, s := newTestServer(ctx, t, func(c *Config) { + c.RefreshTokenPolicy = tc.policy + c.Now = func() time.Time { return t0 } + }) + defer httpServer.Close() + + c := storage.Client{ + ID: "test", + Secret: "barfoo", + RedirectURIs: []string{"foo://bar.com/", "https://auth.example.com"}, + Name: "dex client", + LogoURL: "https://goo.gl/JIyzIC", + } + + err := s.storage.CreateClient(c) + require.NoError(t, err) + + c1 := storage.Connector{ + ID: "test", + Type: "mockCallback", + Name: "mockCallback", + Config: nil, + } + + err = s.storage.CreateConnector(c1) + require.NoError(t, err) + + refresh := storage.RefreshToken{ + ID: "test", + Token: "bar", + ObsoleteToken: "", + Nonce: "foo", + ClientID: "test", + ConnectorID: "test", + Scopes: []string{"openid", "email", "profile"}, + CreatedAt: time.Now().UTC().Round(time.Millisecond), + LastUsed: time.Now().UTC().Round(time.Millisecond), + Claims: storage.Claims{ + UserID: "1", + Username: "jane", + Email: "jane.doe@example.com", + EmailVerified: true, + Groups: []string{"a", "b"}, + }, + ConnectorData: []byte(`{"some":"data"}`), + } + + if tc.useObsolete { + refresh.Token = "testtest" + refresh.ObsoleteToken = "bar" + } + + err = s.storage.CreateRefresh(refresh) + require.NoError(t, err) + + offlineSessions := storage.OfflineSessions{ + UserID: "1", + ConnID: "test", + Refresh: map[string]*storage.RefreshTokenRef{"test": {ID: "test", ClientID: "test"}}, + ConnectorData: nil, + } + + err = s.storage.CreateOfflineSessions(offlineSessions) + require.NoError(t, err) + + u, err := url.Parse(s.issuerURL.String()) + require.NoError(t, err) + + tokenData, err := internal.Marshal(&internal.RefreshToken{RefreshId: "test", Token: "bar"}) + require.NoError(t, err) + + u.Path = path.Join(u.Path, "/token") + v := url.Values{} + v.Add("grant_type", "refresh_token") + v.Add("refresh_token", tokenData) + + req, _ := http.NewRequest("POST", u.String(), bytes.NewBufferString(v.Encode())) + req.Header.Set("Content-Type", "application/x-www-form-urlencoded; param=value") + req.SetBasicAuth("test", "barfoo") + + rr := httptest.NewRecorder() + s.ServeHTTP(rr, req) + + if tc.error == "" { + require.Equal(t, 200, rr.Code) + } else { + require.Equal(t, rr.Body.String(), tc.error) + } + }) + } +} diff --git a/server/server_test.go b/server/server_test.go index d8b40991..62ba40c9 100644 --- a/server/server_test.go +++ b/server/server_test.go @@ -119,11 +119,13 @@ func newTestServer(ctx context.Context, t *testing.T, updateConfig func(c *Confi server.skipApproval = true // Don't prompt for approval, just immediately redirect with code. // Default rotation policy - server.refreshTokenPolicy, err = NewRefreshTokenPolicy(logger, false, "", "", "") - if err != nil { - t.Fatalf("failed to prepare rotation policy: %v", err) + if server.refreshTokenPolicy == nil { + server.refreshTokenPolicy, err = NewRefreshTokenPolicy(logger, false, "", "", "") + if err != nil { + t.Fatalf("failed to prepare rotation policy: %v", err) + } + server.refreshTokenPolicy.now = config.Now } - server.refreshTokenPolicy.now = config.Now return s, server } diff --git a/storage/conformance/conformance.go b/storage/conformance/conformance.go index 0bae52cb..dde369c4 100644 --- a/storage/conformance/conformance.go +++ b/storage/conformance/conformance.go @@ -381,7 +381,7 @@ func testRefreshTokenCRUD(t *testing.T, s storage.Storage) { refresh2 := storage.RefreshToken{ ID: id2, Token: "bar_2", - ObsoleteToken: "bar", + ObsoleteToken: refresh.Token, Nonce: "foo_2", ClientID: "client_id_2", ConnectorID: "client_secret", From 4e73f39f57e3b169cc45a961442f2f7f5624ea1e Mon Sep 17 00:00:00 2001 From: "m.nabokikh" Date: Mon, 11 Jan 2021 14:06:31 +0400 Subject: [PATCH 4/8] Do not refresh id token claims if refresh token is allowed to reuse Signed-off-by: m.nabokikh --- server/refreshhandlers.go | 11 +++++++++-- 1 file changed, 9 insertions(+), 2 deletions(-) diff --git a/server/refreshhandlers.go b/server/refreshhandlers.go index 31709ad5..311eb30a 100644 --- a/server/refreshhandlers.go +++ b/server/refreshhandlers.go @@ -91,6 +91,7 @@ func (s *Server) getRefreshTokenFromStorage(clientID string, token *internal.Ref s.logger.Errorf("refresh token with id %s expired", refresh.ID) return storage.RefreshToken{}, &rerr } + if s.refreshTokenPolicy.ExpiredBecauseUnused(refresh.LastUsed) { s.logger.Errorf("refresh token with id %s expired because being unused", refresh.ID) return storage.RefreshToken{}, &rerr @@ -131,7 +132,7 @@ func (s *Server) getRefreshScopes(r *http.Request, refresh *storage.RefreshToken return requestedScopes, nil } -func (s *Server) refreshWithConnector(ctx context.Context, refresh *storage.RefreshToken, scopes []string) (connector.Identity, *refreshError) { +func (s *Server) refreshWithConnector(ctx context.Context, token *internal.RefreshToken, refresh *storage.RefreshToken, scopes []string) (connector.Identity, *refreshError) { var connectorData []byte rerr := refreshError{msg: errInvalidRequest, desc: "", code: http.StatusInternalServerError} @@ -166,6 +167,12 @@ func (s *Server) refreshWithConnector(ctx context.Context, refresh *storage.Refr ConnectorData: connectorData, } + // user's token was previously updated by a connector and is allowed to reuse + // it is excessive to refresh identity in upstream + if s.refreshTokenPolicy.AllowedToReuse(refresh.LastUsed) && token.Token == refresh.ObsoleteToken { + return ident, nil + } + // Can the connector refresh the identity? If so, attempt to refresh the data // in the connector. // @@ -272,7 +279,7 @@ func (s *Server) handleRefreshToken(w http.ResponseWriter, r *http.Request, clie return } - ident, rerr := s.refreshWithConnector(r.Context(), &refresh, scopes) + ident, rerr := s.refreshWithConnector(r.Context(), token, &refresh, scopes) if rerr != nil { s.refreshTokenErrHelper(w, rerr) return From 89295a5b4ad96aa15d978926e99a9c3183329eac Mon Sep 17 00:00:00 2001 From: "m.nabokikh" Date: Fri, 15 Jan 2021 01:15:56 +0400 Subject: [PATCH 5/8] More refresh token handler refactoring, more tests Signed-off-by: m.nabokikh --- server/refreshhandlers.go | 115 +++++++++++++----------- server/refreshhandlers_test.go | 159 +++++++++++++++++++-------------- 2 files changed, 156 insertions(+), 118 deletions(-) diff --git a/server/refreshhandlers.go b/server/refreshhandlers.go index 311eb30a..588b91f1 100644 --- a/server/refreshhandlers.go +++ b/server/refreshhandlers.go @@ -6,6 +6,7 @@ import ( "fmt" "net/http" "strings" + "time" "github.com/dexidp/dex/connector" "github.com/dexidp/dex/server/internal" @@ -27,6 +28,12 @@ type refreshError struct { desc string } +var internalErr = &refreshError{msg: errInvalidRequest, desc: "", code: http.StatusInternalServerError} + +func newBadRequestError(desc string) *refreshError { + return &refreshError{msg: errInvalidRequest, desc: desc, code: http.StatusBadRequest} +} + func (s *Server) refreshTokenErrHelper(w http.ResponseWriter, err *refreshError) { s.tokenErrHelper(w, err.msg, err.desc, err.code) } @@ -34,7 +41,7 @@ func (s *Server) refreshTokenErrHelper(w http.ResponseWriter, err *refreshError) func (s *Server) extractRefreshTokenFromRequest(r *http.Request) (*internal.RefreshToken, *refreshError) { code := r.PostFormValue("refresh_token") if code == "" { - return nil, &refreshError{msg: errInvalidRequest, desc: "No refresh token in request.", code: http.StatusBadRequest} + return nil, newBadRequestError("No refresh token is found in request.") } token := new(internal.RefreshToken) @@ -52,26 +59,22 @@ func (s *Server) extractRefreshTokenFromRequest(r *http.Request) (*internal.Refr } // getRefreshTokenFromStorage checks that refresh token is valid and exists in the storage and gets its info -func (s *Server) getRefreshTokenFromStorage(clientID string, token *internal.RefreshToken) (storage.RefreshToken, *refreshError) { - refresh, err := s.storage.GetRefresh(token.RefreshId) - rerr := refreshError{ - msg: errInvalidRequest, - desc: "Refresh token is invalid or has already been claimed by another client.", - code: http.StatusBadRequest, - } +func (s *Server) getRefreshTokenFromStorage(clientID string, token *internal.RefreshToken) (*storage.RefreshToken, *refreshError) { + invalidErr := newBadRequestError("Refresh token is invalid or has already been claimed by another client.") + refresh, err := s.storage.GetRefresh(token.RefreshId) if err != nil { s.logger.Errorf("failed to get refresh token: %v", err) if err != storage.ErrNotFound { - return storage.RefreshToken{}, &refreshError{msg: errInvalidRequest, desc: "", code: http.StatusInternalServerError} + return nil, internalErr } - return storage.RefreshToken{}, &rerr + return nil, invalidErr } if refresh.ClientID != clientID { s.logger.Errorf("client %s trying to claim token for client %s", clientID, refresh.ClientID) - return storage.RefreshToken{}, &rerr + return nil, invalidErr } if refresh.Token != token.Token { @@ -82,22 +85,22 @@ func (s *Server) getRefreshTokenFromStorage(clientID string, token *internal.Ref fallthrough case refresh.ObsoleteToken == "": s.logger.Errorf("refresh token with id %s claimed twice", refresh.ID) - return storage.RefreshToken{}, &rerr + return nil, invalidErr } } - rerr.desc = "Refresh token expired." + expiredErr := newBadRequestError("Refresh token expired.") if s.refreshTokenPolicy.CompletelyExpired(refresh.CreatedAt) { s.logger.Errorf("refresh token with id %s expired", refresh.ID) - return storage.RefreshToken{}, &rerr + return nil, expiredErr } if s.refreshTokenPolicy.ExpiredBecauseUnused(refresh.LastUsed) { s.logger.Errorf("refresh token with id %s expired because being unused", refresh.ID) - return storage.RefreshToken{}, &rerr + return nil, expiredErr } - return refresh, nil + return &refresh, nil } func (s *Server) getRefreshScopes(r *http.Request, refresh *storage.RefreshToken) ([]string, *refreshError) { @@ -126,7 +129,7 @@ func (s *Server) getRefreshScopes(r *http.Request, refresh *storage.RefreshToken if len(unauthorizedScopes) > 0 { desc := fmt.Sprintf("Requested scopes contain unauthorized scope(s): %q.", unauthorizedScopes) - return nil, &refreshError{msg: errInvalidRequest, desc: desc, code: http.StatusBadRequest} + return nil, newBadRequestError(desc) } return requestedScopes, nil @@ -134,15 +137,15 @@ func (s *Server) getRefreshScopes(r *http.Request, refresh *storage.RefreshToken func (s *Server) refreshWithConnector(ctx context.Context, token *internal.RefreshToken, refresh *storage.RefreshToken, scopes []string) (connector.Identity, *refreshError) { var connectorData []byte - rerr := refreshError{msg: errInvalidRequest, desc: "", code: http.StatusInternalServerError} session, err := s.storage.GetOfflineSessions(refresh.Claims.UserID, refresh.ConnectorID) switch { case err != nil: if err != storage.ErrNotFound { s.logger.Errorf("failed to get offline session: %v", err) - // TODO: previously there was a naked return without writing anything in response, need to figure it out - return connector.Identity{}, &rerr + // TODO: previously there was a naked return without writing anything in response + // Need to ensure that everything works as expected. + return connector.Identity{}, internalErr } case len(refresh.ConnectorData) > 0: // Use the old connector data if it exists, should be deleted once used @@ -154,7 +157,7 @@ func (s *Server) refreshWithConnector(ctx context.Context, token *internal.Refre conn, err := s.getConnector(refresh.ConnectorID) if err != nil { s.logger.Errorf("connector with ID %q not found: %v", refresh.ConnectorID, err) - return connector.Identity{}, &rerr + return connector.Identity{}, internalErr } ident := connector.Identity{ @@ -182,7 +185,7 @@ func (s *Server) refreshWithConnector(ctx context.Context, token *internal.Refre newIdent, err := refreshConn.Refresh(ctx, parseScopes(scopes), ident) if err != nil { s.logger.Errorf("failed to refresh identity: %v", err) - return connector.Identity{}, &rerr + return connector.Identity{}, internalErr } ident = newIdent } @@ -190,6 +193,28 @@ func (s *Server) refreshWithConnector(ctx context.Context, token *internal.Refre return ident, nil } +// updateOfflineSession updates offline session in the storage +func (s *Server) updateOfflineSession(refresh *storage.RefreshToken, ident connector.Identity, lastUsed time.Time) *refreshError { + offlineSessionUpdater := func(old storage.OfflineSessions) (storage.OfflineSessions, error) { + if old.Refresh[refresh.ClientID].ID != refresh.ID { + return old, errors.New("refresh token invalid") + } + old.Refresh[refresh.ClientID].LastUsed = lastUsed + old.ConnectorData = ident.ConnectorData + return old, nil + } + + // Update LastUsed time stamp in refresh token reference object + // in offline session for the user. + err := s.storage.UpdateOfflineSessions(refresh.Claims.UserID, refresh.ConnectorID, offlineSessionUpdater) + if err != nil { + s.logger.Errorf("failed to update offline session: %v", err) + return internalErr + } + + return nil +} + // updateRefreshToken updates refresh token and offline session in the storage func (s *Server) updateRefreshToken(token *internal.RefreshToken, refresh *storage.RefreshToken, ident connector.Identity) (*internal.RefreshToken, *refreshError) { newToken := token @@ -201,10 +226,16 @@ func (s *Server) updateRefreshToken(token *internal.RefreshToken, refresh *stora } lastUsed := s.now() + + rerr := s.updateOfflineSession(refresh, ident, lastUsed) + if rerr != nil { + return nil, rerr + } + refreshTokenUpdater := func(old storage.RefreshToken) (storage.RefreshToken, error) { if s.refreshTokenPolicy.RotationEnabled() { - if old.Token != refresh.Token { - if s.refreshTokenPolicy.AllowedToReuse(old.LastUsed) && old.ObsoleteToken == refresh.Token { + if old.Token != token.Token { + if s.refreshTokenPolicy.AllowedToReuse(old.LastUsed) && old.ObsoleteToken == token.Token { newToken.Token = old.Token return old, nil } @@ -230,36 +261,18 @@ func (s *Server) updateRefreshToken(token *internal.RefreshToken, refresh *stora return old, nil } - offlineSessionUpdater := func(old storage.OfflineSessions) (storage.OfflineSessions, error) { - if old.Refresh[refresh.ClientID].ID != refresh.ID { - return old, errors.New("refresh token invalid") - } - old.Refresh[refresh.ClientID].LastUsed = lastUsed - old.ConnectorData = ident.ConnectorData - return old, nil - } - - rerr := refreshError{msg: errInvalidRequest, desc: "", code: http.StatusInternalServerError} - - // Update LastUsed time stamp in refresh token reference object - // in offline session for the user. - err := s.storage.UpdateOfflineSessions(refresh.Claims.UserID, refresh.ConnectorID, offlineSessionUpdater) - if err != nil { - s.logger.Errorf("failed to update offline session: %v", err) - return newToken, &rerr - } - // Update refresh token in the storage. - err = s.storage.UpdateRefreshToken(refresh.ID, refreshTokenUpdater) + err := s.storage.UpdateRefreshToken(refresh.ID, refreshTokenUpdater) if err != nil { s.logger.Errorf("failed to update refresh token: %v", err) - return newToken, &rerr + return nil, internalErr } return newToken, nil } // handleRefreshToken handles a refresh token request https://tools.ietf.org/html/rfc6749#section-6 +// this method is the entrypoint for refresh tokens handling func (s *Server) handleRefreshToken(w http.ResponseWriter, r *http.Request, client storage.Client) { token, rerr := s.extractRefreshTokenFromRequest(r) if rerr != nil { @@ -273,13 +286,13 @@ func (s *Server) handleRefreshToken(w http.ResponseWriter, r *http.Request, clie return } - scopes, rerr := s.getRefreshScopes(r, &refresh) + scopes, rerr := s.getRefreshScopes(r, refresh) if rerr != nil { s.refreshTokenErrHelper(w, rerr) return } - ident, rerr := s.refreshWithConnector(r.Context(), token, &refresh, scopes) + ident, rerr := s.refreshWithConnector(r.Context(), token, refresh, scopes) if rerr != nil { s.refreshTokenErrHelper(w, rerr) return @@ -297,18 +310,18 @@ func (s *Server) handleRefreshToken(w http.ResponseWriter, r *http.Request, clie accessToken, err := s.newAccessToken(client.ID, claims, scopes, refresh.Nonce, refresh.ConnectorID) if err != nil { s.logger.Errorf("failed to create new access token: %v", err) - s.tokenErrHelper(w, errServerError, "", http.StatusInternalServerError) + s.refreshTokenErrHelper(w, internalErr) return } idToken, expiry, err := s.newIDToken(client.ID, claims, scopes, refresh.Nonce, accessToken, refresh.ConnectorID) if err != nil { s.logger.Errorf("failed to create ID token: %v", err) - s.tokenErrHelper(w, errServerError, "", http.StatusInternalServerError) + s.refreshTokenErrHelper(w, internalErr) return } - newToken, rerr := s.updateRefreshToken(token, &refresh, ident) + newToken, rerr := s.updateRefreshToken(token, refresh, ident) if rerr != nil { s.refreshTokenErrHelper(w, rerr) return @@ -317,7 +330,7 @@ func (s *Server) handleRefreshToken(w http.ResponseWriter, r *http.Request, clie rawNewToken, err := internal.Marshal(newToken) if err != nil { s.logger.Errorf("failed to marshal refresh token: %v", err) - s.tokenErrHelper(w, errServerError, "", http.StatusInternalServerError) + s.refreshTokenErrHelper(w, internalErr) return } diff --git a/server/refreshhandlers_test.go b/server/refreshhandlers_test.go index 40e81435..c64c50b3 100644 --- a/server/refreshhandlers_test.go +++ b/server/refreshhandlers_test.go @@ -3,6 +3,7 @@ package server import ( "bytes" "context" + "encoding/json" "net/http" "net/http/httptest" "net/url" @@ -16,6 +17,67 @@ import ( "github.com/dexidp/dex/storage" ) +func mockRefreshTokenTestStorage(t *testing.T, s storage.Storage, useObsolete bool) { + c := storage.Client{ + ID: "test", + Secret: "barfoo", + RedirectURIs: []string{"foo://bar.com/", "https://auth.example.com"}, + Name: "dex client", + LogoURL: "https://goo.gl/JIyzIC", + } + + err := s.CreateClient(c) + require.NoError(t, err) + + c1 := storage.Connector{ + ID: "test", + Type: "mockCallback", + Name: "mockCallback", + Config: nil, + } + + err = s.CreateConnector(c1) + require.NoError(t, err) + + refresh := storage.RefreshToken{ + ID: "test", + Token: "bar", + ObsoleteToken: "", + Nonce: "foo", + ClientID: "test", + ConnectorID: "test", + Scopes: []string{"openid", "email", "profile"}, + CreatedAt: time.Now().UTC().Round(time.Millisecond), + LastUsed: time.Now().UTC().Round(time.Millisecond), + Claims: storage.Claims{ + UserID: "1", + Username: "jane", + Email: "jane.doe@example.com", + EmailVerified: true, + Groups: []string{"a", "b"}, + }, + ConnectorData: []byte(`{"some":"data"}`), + } + + if useObsolete { + refresh.Token = "testtest" + refresh.ObsoleteToken = "bar" + } + + err = s.CreateRefresh(refresh) + require.NoError(t, err) + + offlineSessions := storage.OfflineSessions{ + UserID: "1", + ConnID: "test", + Refresh: map[string]*storage.RefreshTokenRef{"test": {ID: "test", ClientID: "test"}}, + ConnectorData: nil, + } + + err = s.CreateOfflineSessions(offlineSessions) + require.NoError(t, err) +} + func TestRefreshTokenExpirationScenarios(t *testing.T) { t0 := time.Now() tests := []struct { @@ -56,15 +118,6 @@ func TestRefreshTokenExpirationScenarios(t *testing.T) { }, error: `{"error":"invalid_request","error_description":"Refresh token expired."}`, }, - { - name: "Obsolete tokens are not allowed", - useObsolete: true, - policy: &RefreshTokenPolicy{ - rotateRefreshTokens: true, - now: func() time.Time { return t0.Add(time.Second * 25) }, - }, - error: `{"error":"invalid_request","error_description":"Refresh token is invalid or has already been claimed by another client."}`, - }, { name: "Obsolete tokens are allowed", useObsolete: true, @@ -75,6 +128,15 @@ func TestRefreshTokenExpirationScenarios(t *testing.T) { }, error: ``, }, + { + name: "Obsolete tokens are not allowed", + useObsolete: true, + policy: &RefreshTokenPolicy{ + rotateRefreshTokens: true, + now: func() time.Time { return t0.Add(time.Second * 25) }, + }, + error: `{"error":"invalid_request","error_description":"Refresh token is invalid or has already been claimed by another client."}`, + }, { name: "Obsolete tokens are allowed but token is expired globally", useObsolete: true, @@ -100,64 +162,7 @@ func TestRefreshTokenExpirationScenarios(t *testing.T) { }) defer httpServer.Close() - c := storage.Client{ - ID: "test", - Secret: "barfoo", - RedirectURIs: []string{"foo://bar.com/", "https://auth.example.com"}, - Name: "dex client", - LogoURL: "https://goo.gl/JIyzIC", - } - - err := s.storage.CreateClient(c) - require.NoError(t, err) - - c1 := storage.Connector{ - ID: "test", - Type: "mockCallback", - Name: "mockCallback", - Config: nil, - } - - err = s.storage.CreateConnector(c1) - require.NoError(t, err) - - refresh := storage.RefreshToken{ - ID: "test", - Token: "bar", - ObsoleteToken: "", - Nonce: "foo", - ClientID: "test", - ConnectorID: "test", - Scopes: []string{"openid", "email", "profile"}, - CreatedAt: time.Now().UTC().Round(time.Millisecond), - LastUsed: time.Now().UTC().Round(time.Millisecond), - Claims: storage.Claims{ - UserID: "1", - Username: "jane", - Email: "jane.doe@example.com", - EmailVerified: true, - Groups: []string{"a", "b"}, - }, - ConnectorData: []byte(`{"some":"data"}`), - } - - if tc.useObsolete { - refresh.Token = "testtest" - refresh.ObsoleteToken = "bar" - } - - err = s.storage.CreateRefresh(refresh) - require.NoError(t, err) - - offlineSessions := storage.OfflineSessions{ - UserID: "1", - ConnID: "test", - Refresh: map[string]*storage.RefreshTokenRef{"test": {ID: "test", ClientID: "test"}}, - ConnectorData: nil, - } - - err = s.storage.CreateOfflineSessions(offlineSessions) - require.NoError(t, err) + mockRefreshTokenTestStorage(t, s.storage, tc.useObsolete) u, err := url.Parse(s.issuerURL.String()) require.NoError(t, err) @@ -181,6 +186,26 @@ func TestRefreshTokenExpirationScenarios(t *testing.T) { require.Equal(t, 200, rr.Code) } else { require.Equal(t, rr.Body.String(), tc.error) + return + } + + // Check that we received expected refresh token + var ref struct { + Token string `json:"refresh_token"` + } + err = json.Unmarshal(rr.Body.Bytes(), &ref) + require.NoError(t, err) + + if tc.policy.rotateRefreshTokens == false { + require.Equal(t, tokenData, ref.Token) + } else { + require.NotEqual(t, tokenData, ref.Token) + } + + if tc.useObsolete { + updatedTokenData, err := internal.Marshal(&internal.RefreshToken{RefreshId: "test", Token: "testtest"}) + require.NoError(t, err) + require.Equal(t, updatedTokenData, ref.Token) } }) } From 9340fee011d9691f7ef19c1b6309ca7f836d3b18 Mon Sep 17 00:00:00 2001 From: "m.nabokikh" Date: Wed, 10 Feb 2021 23:46:17 +0400 Subject: [PATCH 6/8] Fixes after rebasing to the actual main branch Signed-off-by: m.nabokikh --- server/refreshhandlers.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/server/refreshhandlers.go b/server/refreshhandlers.go index 588b91f1..296c994b 100644 --- a/server/refreshhandlers.go +++ b/server/refreshhandlers.go @@ -314,7 +314,7 @@ func (s *Server) handleRefreshToken(w http.ResponseWriter, r *http.Request, clie return } - idToken, expiry, err := s.newIDToken(client.ID, claims, scopes, refresh.Nonce, accessToken, refresh.ConnectorID) + idToken, expiry, err := s.newIDToken(client.ID, claims, scopes, refresh.Nonce, accessToken, "", refresh.ConnectorID) if err != nil { s.logger.Errorf("failed to create ID token: %v", err) s.refreshTokenErrHelper(w, internalErr) From 568fc065208b7eb2186ac49d8730dcf1c7998191 Mon Sep 17 00:00:00 2001 From: Maksim Nabokikh <32434187+nabokihms@users.noreply.github.com> Date: Tue, 9 Mar 2021 09:13:54 +0400 Subject: [PATCH 7/8] Update server/refreshhandlers.go Co-authored-by: Joel Speed Signed-off-by: m.nabokikh --- cmd/dex/config.go | 4 ++-- server/refreshhandlers.go | 26 +++++++++++++------------- 2 files changed, 15 insertions(+), 15 deletions(-) diff --git a/cmd/dex/config.go b/cmd/dex/config.go index a75ddaee..f218879d 100644 --- a/cmd/dex/config.go +++ b/cmd/dex/config.go @@ -306,7 +306,7 @@ type Expiry struct { DeviceRequests string `json:"deviceRequests"` // RefreshTokens defines refresh tokens expiry policy - RefreshTokens RefreshTokenExpiry `json:"refreshTokens"` + RefreshTokens RefreshToken `json:"refreshTokens"` } // Logger holds configuration required to customize logging for dex. @@ -318,7 +318,7 @@ type Logger struct { Format string `json:"format"` } -type RefreshTokenExpiry struct { +type RefreshToken struct { DisableRotation bool `json:"disableRotation"` ReuseInterval string `json:"reuseInterval"` AbsoluteLifetime string `json:"absoluteLifetime"` diff --git a/server/refreshhandlers.go b/server/refreshhandlers.go index 296c994b..8ea7ea9e 100644 --- a/server/refreshhandlers.go +++ b/server/refreshhandlers.go @@ -28,7 +28,9 @@ type refreshError struct { desc string } -var internalErr = &refreshError{msg: errInvalidRequest, desc: "", code: http.StatusInternalServerError} +func newInternalServerError() *refreshError { + return &refreshError{msg: errInvalidRequest, desc: "", code: http.StatusInternalServerError} +} func newBadRequestError(desc string) *refreshError { return &refreshError{msg: errInvalidRequest, desc: desc, code: http.StatusBadRequest} @@ -66,7 +68,7 @@ func (s *Server) getRefreshTokenFromStorage(clientID string, token *internal.Ref if err != nil { s.logger.Errorf("failed to get refresh token: %v", err) if err != storage.ErrNotFound { - return nil, internalErr + return nil, newInternalServerError() } return nil, invalidErr @@ -96,7 +98,7 @@ func (s *Server) getRefreshTokenFromStorage(clientID string, token *internal.Ref } if s.refreshTokenPolicy.ExpiredBecauseUnused(refresh.LastUsed) { - s.logger.Errorf("refresh token with id %s expired because being unused", refresh.ID) + s.logger.Errorf("refresh token with id %s expired due to inactivity", refresh.ID) return nil, expiredErr } @@ -143,9 +145,7 @@ func (s *Server) refreshWithConnector(ctx context.Context, token *internal.Refre case err != nil: if err != storage.ErrNotFound { s.logger.Errorf("failed to get offline session: %v", err) - // TODO: previously there was a naked return without writing anything in response - // Need to ensure that everything works as expected. - return connector.Identity{}, internalErr + return connector.Identity{}, newInternalServerError() } case len(refresh.ConnectorData) > 0: // Use the old connector data if it exists, should be deleted once used @@ -157,7 +157,7 @@ func (s *Server) refreshWithConnector(ctx context.Context, token *internal.Refre conn, err := s.getConnector(refresh.ConnectorID) if err != nil { s.logger.Errorf("connector with ID %q not found: %v", refresh.ConnectorID, err) - return connector.Identity{}, internalErr + return connector.Identity{}, newInternalServerError() } ident := connector.Identity{ @@ -185,7 +185,7 @@ func (s *Server) refreshWithConnector(ctx context.Context, token *internal.Refre newIdent, err := refreshConn.Refresh(ctx, parseScopes(scopes), ident) if err != nil { s.logger.Errorf("failed to refresh identity: %v", err) - return connector.Identity{}, internalErr + return connector.Identity{}, newInternalServerError() } ident = newIdent } @@ -209,7 +209,7 @@ func (s *Server) updateOfflineSession(refresh *storage.RefreshToken, ident conne err := s.storage.UpdateOfflineSessions(refresh.Claims.UserID, refresh.ConnectorID, offlineSessionUpdater) if err != nil { s.logger.Errorf("failed to update offline session: %v", err) - return internalErr + return newInternalServerError() } return nil @@ -265,7 +265,7 @@ func (s *Server) updateRefreshToken(token *internal.RefreshToken, refresh *stora err := s.storage.UpdateRefreshToken(refresh.ID, refreshTokenUpdater) if err != nil { s.logger.Errorf("failed to update refresh token: %v", err) - return nil, internalErr + return nil, newInternalServerError() } return newToken, nil @@ -310,14 +310,14 @@ func (s *Server) handleRefreshToken(w http.ResponseWriter, r *http.Request, clie accessToken, err := s.newAccessToken(client.ID, claims, scopes, refresh.Nonce, refresh.ConnectorID) if err != nil { s.logger.Errorf("failed to create new access token: %v", err) - s.refreshTokenErrHelper(w, internalErr) + s.refreshTokenErrHelper(w, newInternalServerError()) return } idToken, expiry, err := s.newIDToken(client.ID, claims, scopes, refresh.Nonce, accessToken, "", refresh.ConnectorID) if err != nil { s.logger.Errorf("failed to create ID token: %v", err) - s.refreshTokenErrHelper(w, internalErr) + s.refreshTokenErrHelper(w, newInternalServerError()) return } @@ -330,7 +330,7 @@ func (s *Server) handleRefreshToken(w http.ResponseWriter, r *http.Request, clie rawNewToken, err := internal.Marshal(newToken) if err != nil { s.logger.Errorf("failed to marshal refresh token: %v", err) - s.refreshTokenErrHelper(w, internalErr) + s.refreshTokenErrHelper(w, newInternalServerError()) return } From beb8911cf7658b6b95d05bffaabf1a4183c3a059 Mon Sep 17 00:00:00 2001 From: "m.nabokikh" Date: Fri, 2 Apr 2021 16:12:43 +0400 Subject: [PATCH 8/8] chore: add note about units to expire config Signed-off-by: m.nabokikh --- examples/config-dev.yaml | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/examples/config-dev.yaml b/examples/config-dev.yaml index 344d72dc..b40ea582 100644 --- a/examples/config-dev.yaml +++ b/examples/config-dev.yaml @@ -73,14 +73,15 @@ telemetry: # tlsClientCA: examples/grpc-client/ca.crt # Uncomment this block to enable configuration for the expiration time durations. +# Is possible to specify units using only s, m and h suffixes. # expiry: # deviceRequests: "5m" # signingKeys: "6h" # idTokens: "24h" # refreshTokens: # reuseInterval: "3s" -# validIfNotUsedFor: "2190h" -# absoluteLifetime: "5000h" +# validIfNotUsedFor: "2160h" # 90 days +# absoluteLifetime: "3960h" # 165 days # Options for controlling the logger. # logger: