More refresh token handler refactoring, more tests
Signed-off-by: m.nabokikh <maksim.nabokikh@flant.com>
This commit is contained in:
		| @@ -6,6 +6,7 @@ import ( | |||||||
| 	"fmt" | 	"fmt" | ||||||
| 	"net/http" | 	"net/http" | ||||||
| 	"strings" | 	"strings" | ||||||
|  | 	"time" | ||||||
|  |  | ||||||
| 	"github.com/dexidp/dex/connector" | 	"github.com/dexidp/dex/connector" | ||||||
| 	"github.com/dexidp/dex/server/internal" | 	"github.com/dexidp/dex/server/internal" | ||||||
| @@ -27,6 +28,12 @@ type refreshError struct { | |||||||
| 	desc string | 	desc string | ||||||
| } | } | ||||||
|  |  | ||||||
|  | var internalErr = &refreshError{msg: errInvalidRequest, desc: "", code: http.StatusInternalServerError} | ||||||
|  |  | ||||||
|  | func newBadRequestError(desc string) *refreshError { | ||||||
|  | 	return &refreshError{msg: errInvalidRequest, desc: desc, code: http.StatusBadRequest} | ||||||
|  | } | ||||||
|  |  | ||||||
| func (s *Server) refreshTokenErrHelper(w http.ResponseWriter, err *refreshError) { | func (s *Server) refreshTokenErrHelper(w http.ResponseWriter, err *refreshError) { | ||||||
| 	s.tokenErrHelper(w, err.msg, err.desc, err.code) | 	s.tokenErrHelper(w, err.msg, err.desc, err.code) | ||||||
| } | } | ||||||
| @@ -34,7 +41,7 @@ func (s *Server) refreshTokenErrHelper(w http.ResponseWriter, err *refreshError) | |||||||
| func (s *Server) extractRefreshTokenFromRequest(r *http.Request) (*internal.RefreshToken, *refreshError) { | func (s *Server) extractRefreshTokenFromRequest(r *http.Request) (*internal.RefreshToken, *refreshError) { | ||||||
| 	code := r.PostFormValue("refresh_token") | 	code := r.PostFormValue("refresh_token") | ||||||
| 	if code == "" { | 	if code == "" { | ||||||
| 		return nil, &refreshError{msg: errInvalidRequest, desc: "No refresh token in request.", code: http.StatusBadRequest} | 		return nil, newBadRequestError("No refresh token is found in request.") | ||||||
| 	} | 	} | ||||||
|  |  | ||||||
| 	token := new(internal.RefreshToken) | 	token := new(internal.RefreshToken) | ||||||
| @@ -52,26 +59,22 @@ func (s *Server) extractRefreshTokenFromRequest(r *http.Request) (*internal.Refr | |||||||
| } | } | ||||||
|  |  | ||||||
| // getRefreshTokenFromStorage checks that refresh token is valid and exists in the storage and gets its info | // getRefreshTokenFromStorage checks that refresh token is valid and exists in the storage and gets its info | ||||||
| func (s *Server) getRefreshTokenFromStorage(clientID string, token *internal.RefreshToken) (storage.RefreshToken, *refreshError) { | func (s *Server) getRefreshTokenFromStorage(clientID string, token *internal.RefreshToken) (*storage.RefreshToken, *refreshError) { | ||||||
| 	refresh, err := s.storage.GetRefresh(token.RefreshId) | 	invalidErr := newBadRequestError("Refresh token is invalid or has already been claimed by another client.") | ||||||
| 	rerr := refreshError{ |  | ||||||
| 		msg:  errInvalidRequest, |  | ||||||
| 		desc: "Refresh token is invalid or has already been claimed by another client.", |  | ||||||
| 		code: http.StatusBadRequest, |  | ||||||
| 	} |  | ||||||
|  |  | ||||||
|  | 	refresh, err := s.storage.GetRefresh(token.RefreshId) | ||||||
| 	if err != nil { | 	if err != nil { | ||||||
| 		s.logger.Errorf("failed to get refresh token: %v", err) | 		s.logger.Errorf("failed to get refresh token: %v", err) | ||||||
| 		if err != storage.ErrNotFound { | 		if err != storage.ErrNotFound { | ||||||
| 			return storage.RefreshToken{}, &refreshError{msg: errInvalidRequest, desc: "", code: http.StatusInternalServerError} | 			return nil, internalErr | ||||||
| 		} | 		} | ||||||
|  |  | ||||||
| 		return storage.RefreshToken{}, &rerr | 		return nil, invalidErr | ||||||
| 	} | 	} | ||||||
|  |  | ||||||
| 	if refresh.ClientID != clientID { | 	if refresh.ClientID != clientID { | ||||||
| 		s.logger.Errorf("client %s trying to claim token for client %s", clientID, refresh.ClientID) | 		s.logger.Errorf("client %s trying to claim token for client %s", clientID, refresh.ClientID) | ||||||
| 		return storage.RefreshToken{}, &rerr | 		return nil, invalidErr | ||||||
| 	} | 	} | ||||||
|  |  | ||||||
| 	if refresh.Token != token.Token { | 	if refresh.Token != token.Token { | ||||||
| @@ -82,22 +85,22 @@ func (s *Server) getRefreshTokenFromStorage(clientID string, token *internal.Ref | |||||||
| 			fallthrough | 			fallthrough | ||||||
| 		case refresh.ObsoleteToken == "": | 		case refresh.ObsoleteToken == "": | ||||||
| 			s.logger.Errorf("refresh token with id %s claimed twice", refresh.ID) | 			s.logger.Errorf("refresh token with id %s claimed twice", refresh.ID) | ||||||
| 			return storage.RefreshToken{}, &rerr | 			return nil, invalidErr | ||||||
| 		} | 		} | ||||||
| 	} | 	} | ||||||
|  |  | ||||||
| 	rerr.desc = "Refresh token expired." | 	expiredErr := newBadRequestError("Refresh token expired.") | ||||||
| 	if s.refreshTokenPolicy.CompletelyExpired(refresh.CreatedAt) { | 	if s.refreshTokenPolicy.CompletelyExpired(refresh.CreatedAt) { | ||||||
| 		s.logger.Errorf("refresh token with id %s expired", refresh.ID) | 		s.logger.Errorf("refresh token with id %s expired", refresh.ID) | ||||||
| 		return storage.RefreshToken{}, &rerr | 		return nil, expiredErr | ||||||
| 	} | 	} | ||||||
|  |  | ||||||
| 	if s.refreshTokenPolicy.ExpiredBecauseUnused(refresh.LastUsed) { | 	if s.refreshTokenPolicy.ExpiredBecauseUnused(refresh.LastUsed) { | ||||||
| 		s.logger.Errorf("refresh token with id %s expired because being unused", refresh.ID) | 		s.logger.Errorf("refresh token with id %s expired because being unused", refresh.ID) | ||||||
| 		return storage.RefreshToken{}, &rerr | 		return nil, expiredErr | ||||||
| 	} | 	} | ||||||
|  |  | ||||||
| 	return refresh, nil | 	return &refresh, nil | ||||||
| } | } | ||||||
|  |  | ||||||
| func (s *Server) getRefreshScopes(r *http.Request, refresh *storage.RefreshToken) ([]string, *refreshError) { | func (s *Server) getRefreshScopes(r *http.Request, refresh *storage.RefreshToken) ([]string, *refreshError) { | ||||||
| @@ -126,7 +129,7 @@ func (s *Server) getRefreshScopes(r *http.Request, refresh *storage.RefreshToken | |||||||
|  |  | ||||||
| 	if len(unauthorizedScopes) > 0 { | 	if len(unauthorizedScopes) > 0 { | ||||||
| 		desc := fmt.Sprintf("Requested scopes contain unauthorized scope(s): %q.", unauthorizedScopes) | 		desc := fmt.Sprintf("Requested scopes contain unauthorized scope(s): %q.", unauthorizedScopes) | ||||||
| 		return nil, &refreshError{msg: errInvalidRequest, desc: desc, code: http.StatusBadRequest} | 		return nil, newBadRequestError(desc) | ||||||
| 	} | 	} | ||||||
|  |  | ||||||
| 	return requestedScopes, nil | 	return requestedScopes, nil | ||||||
| @@ -134,15 +137,15 @@ func (s *Server) getRefreshScopes(r *http.Request, refresh *storage.RefreshToken | |||||||
|  |  | ||||||
| func (s *Server) refreshWithConnector(ctx context.Context, token *internal.RefreshToken, refresh *storage.RefreshToken, scopes []string) (connector.Identity, *refreshError) { | func (s *Server) refreshWithConnector(ctx context.Context, token *internal.RefreshToken, refresh *storage.RefreshToken, scopes []string) (connector.Identity, *refreshError) { | ||||||
| 	var connectorData []byte | 	var connectorData []byte | ||||||
| 	rerr := refreshError{msg: errInvalidRequest, desc: "", code: http.StatusInternalServerError} |  | ||||||
|  |  | ||||||
| 	session, err := s.storage.GetOfflineSessions(refresh.Claims.UserID, refresh.ConnectorID) | 	session, err := s.storage.GetOfflineSessions(refresh.Claims.UserID, refresh.ConnectorID) | ||||||
| 	switch { | 	switch { | ||||||
| 	case err != nil: | 	case err != nil: | ||||||
| 		if err != storage.ErrNotFound { | 		if err != storage.ErrNotFound { | ||||||
| 			s.logger.Errorf("failed to get offline session: %v", err) | 			s.logger.Errorf("failed to get offline session: %v", err) | ||||||
| 			// TODO: previously there was a naked return without writing anything in response, need to figure it out | 			// TODO: previously there was a naked return without writing anything in response | ||||||
| 			return connector.Identity{}, &rerr | 			//   Need to ensure that everything works as expected. | ||||||
|  | 			return connector.Identity{}, internalErr | ||||||
| 		} | 		} | ||||||
| 	case len(refresh.ConnectorData) > 0: | 	case len(refresh.ConnectorData) > 0: | ||||||
| 		// Use the old connector data if it exists, should be deleted once used | 		// Use the old connector data if it exists, should be deleted once used | ||||||
| @@ -154,7 +157,7 @@ func (s *Server) refreshWithConnector(ctx context.Context, token *internal.Refre | |||||||
| 	conn, err := s.getConnector(refresh.ConnectorID) | 	conn, err := s.getConnector(refresh.ConnectorID) | ||||||
| 	if err != nil { | 	if err != nil { | ||||||
| 		s.logger.Errorf("connector with ID %q not found: %v", refresh.ConnectorID, err) | 		s.logger.Errorf("connector with ID %q not found: %v", refresh.ConnectorID, err) | ||||||
| 		return connector.Identity{}, &rerr | 		return connector.Identity{}, internalErr | ||||||
| 	} | 	} | ||||||
|  |  | ||||||
| 	ident := connector.Identity{ | 	ident := connector.Identity{ | ||||||
| @@ -182,7 +185,7 @@ func (s *Server) refreshWithConnector(ctx context.Context, token *internal.Refre | |||||||
| 		newIdent, err := refreshConn.Refresh(ctx, parseScopes(scopes), ident) | 		newIdent, err := refreshConn.Refresh(ctx, parseScopes(scopes), ident) | ||||||
| 		if err != nil { | 		if err != nil { | ||||||
| 			s.logger.Errorf("failed to refresh identity: %v", err) | 			s.logger.Errorf("failed to refresh identity: %v", err) | ||||||
| 			return connector.Identity{}, &rerr | 			return connector.Identity{}, internalErr | ||||||
| 		} | 		} | ||||||
| 		ident = newIdent | 		ident = newIdent | ||||||
| 	} | 	} | ||||||
| @@ -190,6 +193,28 @@ func (s *Server) refreshWithConnector(ctx context.Context, token *internal.Refre | |||||||
| 	return ident, nil | 	return ident, nil | ||||||
| } | } | ||||||
|  |  | ||||||
|  | // updateOfflineSession updates offline session in the storage | ||||||
|  | func (s *Server) updateOfflineSession(refresh *storage.RefreshToken, ident connector.Identity, lastUsed time.Time) *refreshError { | ||||||
|  | 	offlineSessionUpdater := func(old storage.OfflineSessions) (storage.OfflineSessions, error) { | ||||||
|  | 		if old.Refresh[refresh.ClientID].ID != refresh.ID { | ||||||
|  | 			return old, errors.New("refresh token invalid") | ||||||
|  | 		} | ||||||
|  | 		old.Refresh[refresh.ClientID].LastUsed = lastUsed | ||||||
|  | 		old.ConnectorData = ident.ConnectorData | ||||||
|  | 		return old, nil | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	// Update LastUsed time stamp in refresh token reference object | ||||||
|  | 	// in offline session for the user. | ||||||
|  | 	err := s.storage.UpdateOfflineSessions(refresh.Claims.UserID, refresh.ConnectorID, offlineSessionUpdater) | ||||||
|  | 	if err != nil { | ||||||
|  | 		s.logger.Errorf("failed to update offline session: %v", err) | ||||||
|  | 		return internalErr | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	return nil | ||||||
|  | } | ||||||
|  |  | ||||||
| // updateRefreshToken updates refresh token and offline session in the storage | // updateRefreshToken updates refresh token and offline session in the storage | ||||||
| func (s *Server) updateRefreshToken(token *internal.RefreshToken, refresh *storage.RefreshToken, ident connector.Identity) (*internal.RefreshToken, *refreshError) { | func (s *Server) updateRefreshToken(token *internal.RefreshToken, refresh *storage.RefreshToken, ident connector.Identity) (*internal.RefreshToken, *refreshError) { | ||||||
| 	newToken := token | 	newToken := token | ||||||
| @@ -201,10 +226,16 @@ func (s *Server) updateRefreshToken(token *internal.RefreshToken, refresh *stora | |||||||
| 	} | 	} | ||||||
|  |  | ||||||
| 	lastUsed := s.now() | 	lastUsed := s.now() | ||||||
|  |  | ||||||
|  | 	rerr := s.updateOfflineSession(refresh, ident, lastUsed) | ||||||
|  | 	if rerr != nil { | ||||||
|  | 		return nil, rerr | ||||||
|  | 	} | ||||||
|  |  | ||||||
| 	refreshTokenUpdater := func(old storage.RefreshToken) (storage.RefreshToken, error) { | 	refreshTokenUpdater := func(old storage.RefreshToken) (storage.RefreshToken, error) { | ||||||
| 		if s.refreshTokenPolicy.RotationEnabled() { | 		if s.refreshTokenPolicy.RotationEnabled() { | ||||||
| 			if old.Token != refresh.Token { | 			if old.Token != token.Token { | ||||||
| 				if s.refreshTokenPolicy.AllowedToReuse(old.LastUsed) && old.ObsoleteToken == refresh.Token { | 				if s.refreshTokenPolicy.AllowedToReuse(old.LastUsed) && old.ObsoleteToken == token.Token { | ||||||
| 					newToken.Token = old.Token | 					newToken.Token = old.Token | ||||||
| 					return old, nil | 					return old, nil | ||||||
| 				} | 				} | ||||||
| @@ -230,36 +261,18 @@ func (s *Server) updateRefreshToken(token *internal.RefreshToken, refresh *stora | |||||||
| 		return old, nil | 		return old, nil | ||||||
| 	} | 	} | ||||||
|  |  | ||||||
| 	offlineSessionUpdater := func(old storage.OfflineSessions) (storage.OfflineSessions, error) { |  | ||||||
| 		if old.Refresh[refresh.ClientID].ID != refresh.ID { |  | ||||||
| 			return old, errors.New("refresh token invalid") |  | ||||||
| 		} |  | ||||||
| 		old.Refresh[refresh.ClientID].LastUsed = lastUsed |  | ||||||
| 		old.ConnectorData = ident.ConnectorData |  | ||||||
| 		return old, nil |  | ||||||
| 	} |  | ||||||
|  |  | ||||||
| 	rerr := refreshError{msg: errInvalidRequest, desc: "", code: http.StatusInternalServerError} |  | ||||||
|  |  | ||||||
| 	// Update LastUsed time stamp in refresh token reference object |  | ||||||
| 	// in offline session for the user. |  | ||||||
| 	err := s.storage.UpdateOfflineSessions(refresh.Claims.UserID, refresh.ConnectorID, offlineSessionUpdater) |  | ||||||
| 	if err != nil { |  | ||||||
| 		s.logger.Errorf("failed to update offline session: %v", err) |  | ||||||
| 		return newToken, &rerr |  | ||||||
| 	} |  | ||||||
|  |  | ||||||
| 	// Update refresh token in the storage. | 	// Update refresh token in the storage. | ||||||
| 	err = s.storage.UpdateRefreshToken(refresh.ID, refreshTokenUpdater) | 	err := s.storage.UpdateRefreshToken(refresh.ID, refreshTokenUpdater) | ||||||
| 	if err != nil { | 	if err != nil { | ||||||
| 		s.logger.Errorf("failed to update refresh token: %v", err) | 		s.logger.Errorf("failed to update refresh token: %v", err) | ||||||
| 		return newToken, &rerr | 		return nil, internalErr | ||||||
| 	} | 	} | ||||||
|  |  | ||||||
| 	return newToken, nil | 	return newToken, nil | ||||||
| } | } | ||||||
|  |  | ||||||
| // handleRefreshToken handles a refresh token request https://tools.ietf.org/html/rfc6749#section-6 | // handleRefreshToken handles a refresh token request https://tools.ietf.org/html/rfc6749#section-6 | ||||||
|  | // this method is the entrypoint for refresh tokens handling | ||||||
| func (s *Server) handleRefreshToken(w http.ResponseWriter, r *http.Request, client storage.Client) { | func (s *Server) handleRefreshToken(w http.ResponseWriter, r *http.Request, client storage.Client) { | ||||||
| 	token, rerr := s.extractRefreshTokenFromRequest(r) | 	token, rerr := s.extractRefreshTokenFromRequest(r) | ||||||
| 	if rerr != nil { | 	if rerr != nil { | ||||||
| @@ -273,13 +286,13 @@ func (s *Server) handleRefreshToken(w http.ResponseWriter, r *http.Request, clie | |||||||
| 		return | 		return | ||||||
| 	} | 	} | ||||||
|  |  | ||||||
| 	scopes, rerr := s.getRefreshScopes(r, &refresh) | 	scopes, rerr := s.getRefreshScopes(r, refresh) | ||||||
| 	if rerr != nil { | 	if rerr != nil { | ||||||
| 		s.refreshTokenErrHelper(w, rerr) | 		s.refreshTokenErrHelper(w, rerr) | ||||||
| 		return | 		return | ||||||
| 	} | 	} | ||||||
|  |  | ||||||
| 	ident, rerr := s.refreshWithConnector(r.Context(), token, &refresh, scopes) | 	ident, rerr := s.refreshWithConnector(r.Context(), token, refresh, scopes) | ||||||
| 	if rerr != nil { | 	if rerr != nil { | ||||||
| 		s.refreshTokenErrHelper(w, rerr) | 		s.refreshTokenErrHelper(w, rerr) | ||||||
| 		return | 		return | ||||||
| @@ -297,18 +310,18 @@ func (s *Server) handleRefreshToken(w http.ResponseWriter, r *http.Request, clie | |||||||
| 	accessToken, err := s.newAccessToken(client.ID, claims, scopes, refresh.Nonce, refresh.ConnectorID) | 	accessToken, err := s.newAccessToken(client.ID, claims, scopes, refresh.Nonce, refresh.ConnectorID) | ||||||
| 	if err != nil { | 	if err != nil { | ||||||
| 		s.logger.Errorf("failed to create new access token: %v", err) | 		s.logger.Errorf("failed to create new access token: %v", err) | ||||||
| 		s.tokenErrHelper(w, errServerError, "", http.StatusInternalServerError) | 		s.refreshTokenErrHelper(w, internalErr) | ||||||
| 		return | 		return | ||||||
| 	} | 	} | ||||||
|  |  | ||||||
| 	idToken, expiry, err := s.newIDToken(client.ID, claims, scopes, refresh.Nonce, accessToken, refresh.ConnectorID) | 	idToken, expiry, err := s.newIDToken(client.ID, claims, scopes, refresh.Nonce, accessToken, refresh.ConnectorID) | ||||||
| 	if err != nil { | 	if err != nil { | ||||||
| 		s.logger.Errorf("failed to create ID token: %v", err) | 		s.logger.Errorf("failed to create ID token: %v", err) | ||||||
| 		s.tokenErrHelper(w, errServerError, "", http.StatusInternalServerError) | 		s.refreshTokenErrHelper(w, internalErr) | ||||||
| 		return | 		return | ||||||
| 	} | 	} | ||||||
|  |  | ||||||
| 	newToken, rerr := s.updateRefreshToken(token, &refresh, ident) | 	newToken, rerr := s.updateRefreshToken(token, refresh, ident) | ||||||
| 	if rerr != nil { | 	if rerr != nil { | ||||||
| 		s.refreshTokenErrHelper(w, rerr) | 		s.refreshTokenErrHelper(w, rerr) | ||||||
| 		return | 		return | ||||||
| @@ -317,7 +330,7 @@ func (s *Server) handleRefreshToken(w http.ResponseWriter, r *http.Request, clie | |||||||
| 	rawNewToken, err := internal.Marshal(newToken) | 	rawNewToken, err := internal.Marshal(newToken) | ||||||
| 	if err != nil { | 	if err != nil { | ||||||
| 		s.logger.Errorf("failed to marshal refresh token: %v", err) | 		s.logger.Errorf("failed to marshal refresh token: %v", err) | ||||||
| 		s.tokenErrHelper(w, errServerError, "", http.StatusInternalServerError) | 		s.refreshTokenErrHelper(w, internalErr) | ||||||
| 		return | 		return | ||||||
| 	} | 	} | ||||||
|  |  | ||||||
|   | |||||||
| @@ -3,6 +3,7 @@ package server | |||||||
| import ( | import ( | ||||||
| 	"bytes" | 	"bytes" | ||||||
| 	"context" | 	"context" | ||||||
|  | 	"encoding/json" | ||||||
| 	"net/http" | 	"net/http" | ||||||
| 	"net/http/httptest" | 	"net/http/httptest" | ||||||
| 	"net/url" | 	"net/url" | ||||||
| @@ -16,6 +17,67 @@ import ( | |||||||
| 	"github.com/dexidp/dex/storage" | 	"github.com/dexidp/dex/storage" | ||||||
| ) | ) | ||||||
|  |  | ||||||
|  | func mockRefreshTokenTestStorage(t *testing.T, s storage.Storage, useObsolete bool) { | ||||||
|  | 	c := storage.Client{ | ||||||
|  | 		ID:           "test", | ||||||
|  | 		Secret:       "barfoo", | ||||||
|  | 		RedirectURIs: []string{"foo://bar.com/", "https://auth.example.com"}, | ||||||
|  | 		Name:         "dex client", | ||||||
|  | 		LogoURL:      "https://goo.gl/JIyzIC", | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	err := s.CreateClient(c) | ||||||
|  | 	require.NoError(t, err) | ||||||
|  |  | ||||||
|  | 	c1 := storage.Connector{ | ||||||
|  | 		ID:     "test", | ||||||
|  | 		Type:   "mockCallback", | ||||||
|  | 		Name:   "mockCallback", | ||||||
|  | 		Config: nil, | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	err = s.CreateConnector(c1) | ||||||
|  | 	require.NoError(t, err) | ||||||
|  |  | ||||||
|  | 	refresh := storage.RefreshToken{ | ||||||
|  | 		ID:            "test", | ||||||
|  | 		Token:         "bar", | ||||||
|  | 		ObsoleteToken: "", | ||||||
|  | 		Nonce:         "foo", | ||||||
|  | 		ClientID:      "test", | ||||||
|  | 		ConnectorID:   "test", | ||||||
|  | 		Scopes:        []string{"openid", "email", "profile"}, | ||||||
|  | 		CreatedAt:     time.Now().UTC().Round(time.Millisecond), | ||||||
|  | 		LastUsed:      time.Now().UTC().Round(time.Millisecond), | ||||||
|  | 		Claims: storage.Claims{ | ||||||
|  | 			UserID:        "1", | ||||||
|  | 			Username:      "jane", | ||||||
|  | 			Email:         "jane.doe@example.com", | ||||||
|  | 			EmailVerified: true, | ||||||
|  | 			Groups:        []string{"a", "b"}, | ||||||
|  | 		}, | ||||||
|  | 		ConnectorData: []byte(`{"some":"data"}`), | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	if useObsolete { | ||||||
|  | 		refresh.Token = "testtest" | ||||||
|  | 		refresh.ObsoleteToken = "bar" | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	err = s.CreateRefresh(refresh) | ||||||
|  | 	require.NoError(t, err) | ||||||
|  |  | ||||||
|  | 	offlineSessions := storage.OfflineSessions{ | ||||||
|  | 		UserID:        "1", | ||||||
|  | 		ConnID:        "test", | ||||||
|  | 		Refresh:       map[string]*storage.RefreshTokenRef{"test": {ID: "test", ClientID: "test"}}, | ||||||
|  | 		ConnectorData: nil, | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	err = s.CreateOfflineSessions(offlineSessions) | ||||||
|  | 	require.NoError(t, err) | ||||||
|  | } | ||||||
|  |  | ||||||
| func TestRefreshTokenExpirationScenarios(t *testing.T) { | func TestRefreshTokenExpirationScenarios(t *testing.T) { | ||||||
| 	t0 := time.Now() | 	t0 := time.Now() | ||||||
| 	tests := []struct { | 	tests := []struct { | ||||||
| @@ -56,15 +118,6 @@ func TestRefreshTokenExpirationScenarios(t *testing.T) { | |||||||
| 			}, | 			}, | ||||||
| 			error: `{"error":"invalid_request","error_description":"Refresh token expired."}`, | 			error: `{"error":"invalid_request","error_description":"Refresh token expired."}`, | ||||||
| 		}, | 		}, | ||||||
| 		{ |  | ||||||
| 			name:        "Obsolete tokens are not allowed", |  | ||||||
| 			useObsolete: true, |  | ||||||
| 			policy: &RefreshTokenPolicy{ |  | ||||||
| 				rotateRefreshTokens: true, |  | ||||||
| 				now:                 func() time.Time { return t0.Add(time.Second * 25) }, |  | ||||||
| 			}, |  | ||||||
| 			error: `{"error":"invalid_request","error_description":"Refresh token is invalid or has already been claimed by another client."}`, |  | ||||||
| 		}, |  | ||||||
| 		{ | 		{ | ||||||
| 			name:        "Obsolete tokens are allowed", | 			name:        "Obsolete tokens are allowed", | ||||||
| 			useObsolete: true, | 			useObsolete: true, | ||||||
| @@ -75,6 +128,15 @@ func TestRefreshTokenExpirationScenarios(t *testing.T) { | |||||||
| 			}, | 			}, | ||||||
| 			error: ``, | 			error: ``, | ||||||
| 		}, | 		}, | ||||||
|  | 		{ | ||||||
|  | 			name:        "Obsolete tokens are not allowed", | ||||||
|  | 			useObsolete: true, | ||||||
|  | 			policy: &RefreshTokenPolicy{ | ||||||
|  | 				rotateRefreshTokens: true, | ||||||
|  | 				now:                 func() time.Time { return t0.Add(time.Second * 25) }, | ||||||
|  | 			}, | ||||||
|  | 			error: `{"error":"invalid_request","error_description":"Refresh token is invalid or has already been claimed by another client."}`, | ||||||
|  | 		}, | ||||||
| 		{ | 		{ | ||||||
| 			name:        "Obsolete tokens are allowed but token is expired globally", | 			name:        "Obsolete tokens are allowed but token is expired globally", | ||||||
| 			useObsolete: true, | 			useObsolete: true, | ||||||
| @@ -100,64 +162,7 @@ func TestRefreshTokenExpirationScenarios(t *testing.T) { | |||||||
| 			}) | 			}) | ||||||
| 			defer httpServer.Close() | 			defer httpServer.Close() | ||||||
|  |  | ||||||
| 			c := storage.Client{ | 			mockRefreshTokenTestStorage(t, s.storage, tc.useObsolete) | ||||||
| 				ID:           "test", |  | ||||||
| 				Secret:       "barfoo", |  | ||||||
| 				RedirectURIs: []string{"foo://bar.com/", "https://auth.example.com"}, |  | ||||||
| 				Name:         "dex client", |  | ||||||
| 				LogoURL:      "https://goo.gl/JIyzIC", |  | ||||||
| 			} |  | ||||||
|  |  | ||||||
| 			err := s.storage.CreateClient(c) |  | ||||||
| 			require.NoError(t, err) |  | ||||||
|  |  | ||||||
| 			c1 := storage.Connector{ |  | ||||||
| 				ID:     "test", |  | ||||||
| 				Type:   "mockCallback", |  | ||||||
| 				Name:   "mockCallback", |  | ||||||
| 				Config: nil, |  | ||||||
| 			} |  | ||||||
|  |  | ||||||
| 			err = s.storage.CreateConnector(c1) |  | ||||||
| 			require.NoError(t, err) |  | ||||||
|  |  | ||||||
| 			refresh := storage.RefreshToken{ |  | ||||||
| 				ID:            "test", |  | ||||||
| 				Token:         "bar", |  | ||||||
| 				ObsoleteToken: "", |  | ||||||
| 				Nonce:         "foo", |  | ||||||
| 				ClientID:      "test", |  | ||||||
| 				ConnectorID:   "test", |  | ||||||
| 				Scopes:        []string{"openid", "email", "profile"}, |  | ||||||
| 				CreatedAt:     time.Now().UTC().Round(time.Millisecond), |  | ||||||
| 				LastUsed:      time.Now().UTC().Round(time.Millisecond), |  | ||||||
| 				Claims: storage.Claims{ |  | ||||||
| 					UserID:        "1", |  | ||||||
| 					Username:      "jane", |  | ||||||
| 					Email:         "jane.doe@example.com", |  | ||||||
| 					EmailVerified: true, |  | ||||||
| 					Groups:        []string{"a", "b"}, |  | ||||||
| 				}, |  | ||||||
| 				ConnectorData: []byte(`{"some":"data"}`), |  | ||||||
| 			} |  | ||||||
|  |  | ||||||
| 			if tc.useObsolete { |  | ||||||
| 				refresh.Token = "testtest" |  | ||||||
| 				refresh.ObsoleteToken = "bar" |  | ||||||
| 			} |  | ||||||
|  |  | ||||||
| 			err = s.storage.CreateRefresh(refresh) |  | ||||||
| 			require.NoError(t, err) |  | ||||||
|  |  | ||||||
| 			offlineSessions := storage.OfflineSessions{ |  | ||||||
| 				UserID:        "1", |  | ||||||
| 				ConnID:        "test", |  | ||||||
| 				Refresh:       map[string]*storage.RefreshTokenRef{"test": {ID: "test", ClientID: "test"}}, |  | ||||||
| 				ConnectorData: nil, |  | ||||||
| 			} |  | ||||||
|  |  | ||||||
| 			err = s.storage.CreateOfflineSessions(offlineSessions) |  | ||||||
| 			require.NoError(t, err) |  | ||||||
|  |  | ||||||
| 			u, err := url.Parse(s.issuerURL.String()) | 			u, err := url.Parse(s.issuerURL.String()) | ||||||
| 			require.NoError(t, err) | 			require.NoError(t, err) | ||||||
| @@ -181,6 +186,26 @@ func TestRefreshTokenExpirationScenarios(t *testing.T) { | |||||||
| 				require.Equal(t, 200, rr.Code) | 				require.Equal(t, 200, rr.Code) | ||||||
| 			} else { | 			} else { | ||||||
| 				require.Equal(t, rr.Body.String(), tc.error) | 				require.Equal(t, rr.Body.String(), tc.error) | ||||||
|  | 				return | ||||||
|  | 			} | ||||||
|  |  | ||||||
|  | 			// Check that we received expected refresh token | ||||||
|  | 			var ref struct { | ||||||
|  | 				Token string `json:"refresh_token"` | ||||||
|  | 			} | ||||||
|  | 			err = json.Unmarshal(rr.Body.Bytes(), &ref) | ||||||
|  | 			require.NoError(t, err) | ||||||
|  |  | ||||||
|  | 			if tc.policy.rotateRefreshTokens == false { | ||||||
|  | 				require.Equal(t, tokenData, ref.Token) | ||||||
|  | 			} else { | ||||||
|  | 				require.NotEqual(t, tokenData, ref.Token) | ||||||
|  | 			} | ||||||
|  |  | ||||||
|  | 			if tc.useObsolete { | ||||||
|  | 				updatedTokenData, err := internal.Marshal(&internal.RefreshToken{RefreshId: "test", Token: "testtest"}) | ||||||
|  | 				require.NoError(t, err) | ||||||
|  | 				require.Equal(t, updatedTokenData, ref.Token) | ||||||
| 			} | 			} | ||||||
| 		}) | 		}) | ||||||
| 	} | 	} | ||||||
|   | |||||||
		Reference in New Issue
	
	Block a user