Merge pull request #1846 from flant/refresh-token-expiration-policy
feat: Add refresh token expiration and rotation settings
This commit is contained in:
		@@ -304,6 +304,9 @@ type Expiry struct {
 | 
			
		||||
 | 
			
		||||
	// DeviceRequests defines the duration of time for which the DeviceRequests will be valid.
 | 
			
		||||
	DeviceRequests string `json:"deviceRequests"`
 | 
			
		||||
 | 
			
		||||
	// RefreshTokens defines refresh tokens expiry policy
 | 
			
		||||
	RefreshTokens RefreshToken `json:"refreshTokens"`
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// Logger holds configuration required to customize logging for dex.
 | 
			
		||||
@@ -314,3 +317,10 @@ type Logger struct {
 | 
			
		||||
	// Format specifies the format to be used for logging.
 | 
			
		||||
	Format string `json:"format"`
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
type RefreshToken struct {
 | 
			
		||||
	DisableRotation   bool   `json:"disableRotation"`
 | 
			
		||||
	ReuseInterval     string `json:"reuseInterval"`
 | 
			
		||||
	AbsoluteLifetime  string `json:"absoluteLifetime"`
 | 
			
		||||
	ValidIfNotUsedFor string `json:"validIfNotUsedFor"`
 | 
			
		||||
}
 | 
			
		||||
 
 | 
			
		||||
@@ -304,6 +304,18 @@ func runServe(options serveOptions) error {
 | 
			
		||||
		logger.Infof("config device requests valid for: %v", deviceRequests)
 | 
			
		||||
		serverConfig.DeviceRequestsValidFor = deviceRequests
 | 
			
		||||
	}
 | 
			
		||||
	refreshTokenPolicy, err := server.NewRefreshTokenPolicy(
 | 
			
		||||
		logger,
 | 
			
		||||
		c.Expiry.RefreshTokens.DisableRotation,
 | 
			
		||||
		c.Expiry.RefreshTokens.ValidIfNotUsedFor,
 | 
			
		||||
		c.Expiry.RefreshTokens.AbsoluteLifetime,
 | 
			
		||||
		c.Expiry.RefreshTokens.ReuseInterval,
 | 
			
		||||
	)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return fmt.Errorf("invalid refresh token expiration policy config: %v", err)
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	serverConfig.RefreshTokenPolicy = refreshTokenPolicy
 | 
			
		||||
	serv, err := server.NewServer(context.Background(), serverConfig)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return fmt.Errorf("failed to initialize server: %v", err)
 | 
			
		||||
 
 | 
			
		||||
@@ -73,10 +73,15 @@ telemetry:
 | 
			
		||||
#   tlsClientCA: examples/grpc-client/ca.crt
 | 
			
		||||
 | 
			
		||||
# Uncomment this block to enable configuration for the expiration time durations.
 | 
			
		||||
# Is possible to specify units using only s, m and h suffixes.
 | 
			
		||||
# expiry:
 | 
			
		||||
#   deviceRequests: "5m"
 | 
			
		||||
#   signingKeys: "6h"
 | 
			
		||||
#   idTokens: "24h"
 | 
			
		||||
#   refreshTokens:
 | 
			
		||||
#     reuseInterval: "3s"
 | 
			
		||||
#     validIfNotUsedFor: "2160h" # 90 days
 | 
			
		||||
#     absoluteLifetime: "3960h" # 165 days
 | 
			
		||||
 | 
			
		||||
# Options for controlling the logger.
 | 
			
		||||
# logger:
 | 
			
		||||
 
 | 
			
		||||
@@ -4,7 +4,6 @@ import (
 | 
			
		||||
	"crypto/sha256"
 | 
			
		||||
	"encoding/base64"
 | 
			
		||||
	"encoding/json"
 | 
			
		||||
	"errors"
 | 
			
		||||
	"fmt"
 | 
			
		||||
	"net/http"
 | 
			
		||||
	"net/url"
 | 
			
		||||
@@ -919,206 +918,6 @@ func (s *Server) exchangeAuthCode(w http.ResponseWriter, authCode storage.AuthCo
 | 
			
		||||
	return s.toAccessTokenResponse(idToken, accessToken, refreshToken, expiry), nil
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// handle a refresh token request https://tools.ietf.org/html/rfc6749#section-6
 | 
			
		||||
func (s *Server) handleRefreshToken(w http.ResponseWriter, r *http.Request, client storage.Client) {
 | 
			
		||||
	code := r.PostFormValue("refresh_token")
 | 
			
		||||
	scope := r.PostFormValue("scope")
 | 
			
		||||
	if code == "" {
 | 
			
		||||
		s.tokenErrHelper(w, errInvalidRequest, "No refresh token in request.", http.StatusBadRequest)
 | 
			
		||||
		return
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	token := new(internal.RefreshToken)
 | 
			
		||||
	if err := internal.Unmarshal(code, token); err != nil {
 | 
			
		||||
		// For backward compatibility, assume the refresh_token is a raw refresh token ID
 | 
			
		||||
		// if it fails to decode.
 | 
			
		||||
		//
 | 
			
		||||
		// Because refresh_token values that aren't unmarshable were generated by servers
 | 
			
		||||
		// that don't have a Token value, we'll still reject any attempts to claim a
 | 
			
		||||
		// refresh_token twice.
 | 
			
		||||
		token = &internal.RefreshToken{RefreshId: code, Token: ""}
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	refresh, err := s.storage.GetRefresh(token.RefreshId)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		s.logger.Errorf("failed to get refresh token: %v", err)
 | 
			
		||||
		if err == storage.ErrNotFound {
 | 
			
		||||
			s.tokenErrHelper(w, errInvalidRequest, "Refresh token is invalid or has already been claimed by another client.", http.StatusBadRequest)
 | 
			
		||||
		} else {
 | 
			
		||||
			s.tokenErrHelper(w, errServerError, "", http.StatusInternalServerError)
 | 
			
		||||
		}
 | 
			
		||||
		return
 | 
			
		||||
	}
 | 
			
		||||
	if refresh.ClientID != client.ID {
 | 
			
		||||
		s.logger.Errorf("client %s trying to claim token for client %s", client.ID, refresh.ClientID)
 | 
			
		||||
		s.tokenErrHelper(w, errInvalidRequest, "Refresh token is invalid or has already been claimed by another client.", http.StatusBadRequest)
 | 
			
		||||
		return
 | 
			
		||||
	}
 | 
			
		||||
	if refresh.Token != token.Token {
 | 
			
		||||
		s.logger.Errorf("refresh token with id %s claimed twice", refresh.ID)
 | 
			
		||||
		s.tokenErrHelper(w, errInvalidRequest, "Refresh token is invalid or has already been claimed by another client.", http.StatusBadRequest)
 | 
			
		||||
		return
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	// Per the OAuth2 spec, if the client has omitted the scopes, default to the original
 | 
			
		||||
	// authorized scopes.
 | 
			
		||||
	//
 | 
			
		||||
	// https://tools.ietf.org/html/rfc6749#section-6
 | 
			
		||||
	scopes := refresh.Scopes
 | 
			
		||||
	if scope != "" {
 | 
			
		||||
		requestedScopes := strings.Fields(scope)
 | 
			
		||||
		var unauthorizedScopes []string
 | 
			
		||||
 | 
			
		||||
		for _, s := range requestedScopes {
 | 
			
		||||
			contains := func() bool {
 | 
			
		||||
				for _, scope := range refresh.Scopes {
 | 
			
		||||
					if s == scope {
 | 
			
		||||
						return true
 | 
			
		||||
					}
 | 
			
		||||
				}
 | 
			
		||||
				return false
 | 
			
		||||
			}()
 | 
			
		||||
			if !contains {
 | 
			
		||||
				unauthorizedScopes = append(unauthorizedScopes, s)
 | 
			
		||||
			}
 | 
			
		||||
		}
 | 
			
		||||
 | 
			
		||||
		if len(unauthorizedScopes) > 0 {
 | 
			
		||||
			msg := fmt.Sprintf("Requested scopes contain unauthorized scope(s): %q.", unauthorizedScopes)
 | 
			
		||||
			s.tokenErrHelper(w, errInvalidRequest, msg, http.StatusBadRequest)
 | 
			
		||||
			return
 | 
			
		||||
		}
 | 
			
		||||
		scopes = requestedScopes
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	var connectorData []byte
 | 
			
		||||
 | 
			
		||||
	session, err := s.storage.GetOfflineSessions(refresh.Claims.UserID, refresh.ConnectorID)
 | 
			
		||||
	switch {
 | 
			
		||||
	case err != nil:
 | 
			
		||||
		if err != storage.ErrNotFound {
 | 
			
		||||
			s.logger.Errorf("failed to get offline session: %v", err)
 | 
			
		||||
			return
 | 
			
		||||
		}
 | 
			
		||||
	case len(refresh.ConnectorData) > 0:
 | 
			
		||||
		// Use the old connector data if it exists, should be deleted once used
 | 
			
		||||
		connectorData = refresh.ConnectorData
 | 
			
		||||
	default:
 | 
			
		||||
		connectorData = session.ConnectorData
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	conn, err := s.getConnector(refresh.ConnectorID)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		s.logger.Errorf("connector with ID %q not found: %v", refresh.ConnectorID, err)
 | 
			
		||||
		s.tokenErrHelper(w, errServerError, "", http.StatusInternalServerError)
 | 
			
		||||
		return
 | 
			
		||||
	}
 | 
			
		||||
	ident := connector.Identity{
 | 
			
		||||
		UserID:            refresh.Claims.UserID,
 | 
			
		||||
		Username:          refresh.Claims.Username,
 | 
			
		||||
		PreferredUsername: refresh.Claims.PreferredUsername,
 | 
			
		||||
		Email:             refresh.Claims.Email,
 | 
			
		||||
		EmailVerified:     refresh.Claims.EmailVerified,
 | 
			
		||||
		Groups:            refresh.Claims.Groups,
 | 
			
		||||
		ConnectorData:     connectorData,
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	// Can the connector refresh the identity? If so, attempt to refresh the data
 | 
			
		||||
	// in the connector.
 | 
			
		||||
	//
 | 
			
		||||
	// TODO(ericchiang): We may want a strict mode where connectors that don't implement
 | 
			
		||||
	// this interface can't perform refreshing.
 | 
			
		||||
	if refreshConn, ok := conn.Connector.(connector.RefreshConnector); ok {
 | 
			
		||||
		newIdent, err := refreshConn.Refresh(r.Context(), parseScopes(scopes), ident)
 | 
			
		||||
		if err != nil {
 | 
			
		||||
			s.logger.Errorf("failed to refresh identity: %v", err)
 | 
			
		||||
			s.tokenErrHelper(w, errServerError, "", http.StatusInternalServerError)
 | 
			
		||||
			return
 | 
			
		||||
		}
 | 
			
		||||
		ident = newIdent
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	claims := storage.Claims{
 | 
			
		||||
		UserID:            ident.UserID,
 | 
			
		||||
		Username:          ident.Username,
 | 
			
		||||
		PreferredUsername: ident.PreferredUsername,
 | 
			
		||||
		Email:             ident.Email,
 | 
			
		||||
		EmailVerified:     ident.EmailVerified,
 | 
			
		||||
		Groups:            ident.Groups,
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	accessToken, err := s.newAccessToken(client.ID, claims, scopes, refresh.Nonce, refresh.ConnectorID)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		s.logger.Errorf("failed to create new access token: %v", err)
 | 
			
		||||
		s.tokenErrHelper(w, errServerError, "", http.StatusInternalServerError)
 | 
			
		||||
		return
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	idToken, expiry, err := s.newIDToken(client.ID, claims, scopes, refresh.Nonce, accessToken, "", refresh.ConnectorID)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		s.logger.Errorf("failed to create ID token: %v", err)
 | 
			
		||||
		s.tokenErrHelper(w, errServerError, "", http.StatusInternalServerError)
 | 
			
		||||
		return
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	newToken := &internal.RefreshToken{
 | 
			
		||||
		RefreshId: refresh.ID,
 | 
			
		||||
		Token:     storage.NewID(),
 | 
			
		||||
	}
 | 
			
		||||
	rawNewToken, err := internal.Marshal(newToken)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		s.logger.Errorf("failed to marshal refresh token: %v", err)
 | 
			
		||||
		s.tokenErrHelper(w, errServerError, "", http.StatusInternalServerError)
 | 
			
		||||
		return
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	lastUsed := s.now()
 | 
			
		||||
	updater := func(old storage.RefreshToken) (storage.RefreshToken, error) {
 | 
			
		||||
		if old.Token != refresh.Token {
 | 
			
		||||
			return old, errors.New("refresh token claimed twice")
 | 
			
		||||
		}
 | 
			
		||||
		old.Token = newToken.Token
 | 
			
		||||
		// Update the claims of the refresh token.
 | 
			
		||||
		//
 | 
			
		||||
		// UserID intentionally ignored for now.
 | 
			
		||||
		old.Claims.Username = ident.Username
 | 
			
		||||
		old.Claims.PreferredUsername = ident.PreferredUsername
 | 
			
		||||
		old.Claims.Email = ident.Email
 | 
			
		||||
		old.Claims.EmailVerified = ident.EmailVerified
 | 
			
		||||
		old.Claims.Groups = ident.Groups
 | 
			
		||||
		old.LastUsed = lastUsed
 | 
			
		||||
 | 
			
		||||
		// ConnectorData has been moved to OfflineSession
 | 
			
		||||
		old.ConnectorData = []byte{}
 | 
			
		||||
		return old, nil
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	// Update LastUsed time stamp in refresh token reference object
 | 
			
		||||
	// in offline session for the user.
 | 
			
		||||
	if err := s.storage.UpdateOfflineSessions(refresh.Claims.UserID, refresh.ConnectorID, func(old storage.OfflineSessions) (storage.OfflineSessions, error) {
 | 
			
		||||
		if old.Refresh[refresh.ClientID].ID != refresh.ID {
 | 
			
		||||
			return old, errors.New("refresh token invalid")
 | 
			
		||||
		}
 | 
			
		||||
		old.Refresh[refresh.ClientID].LastUsed = lastUsed
 | 
			
		||||
		old.ConnectorData = ident.ConnectorData
 | 
			
		||||
		return old, nil
 | 
			
		||||
	}); err != nil {
 | 
			
		||||
		s.logger.Errorf("failed to update offline session: %v", err)
 | 
			
		||||
		s.tokenErrHelper(w, errServerError, "", http.StatusInternalServerError)
 | 
			
		||||
		return
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	// Update refresh token in the storage.
 | 
			
		||||
	if err := s.storage.UpdateRefreshToken(refresh.ID, updater); err != nil {
 | 
			
		||||
		s.logger.Errorf("failed to update refresh token: %v", err)
 | 
			
		||||
		s.tokenErrHelper(w, errServerError, "", http.StatusInternalServerError)
 | 
			
		||||
		return
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	resp := s.toAccessTokenResponse(idToken, accessToken, rawNewToken, expiry)
 | 
			
		||||
	s.writeAccessToken(w, resp)
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (s *Server) handleUserInfo(w http.ResponseWriter, r *http.Request) {
 | 
			
		||||
	const prefix = "Bearer "
 | 
			
		||||
 | 
			
		||||
 
 | 
			
		||||
							
								
								
									
										339
									
								
								server/refreshhandlers.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										339
									
								
								server/refreshhandlers.go
									
									
									
									
									
										Normal file
									
								
							@@ -0,0 +1,339 @@
 | 
			
		||||
package server
 | 
			
		||||
 | 
			
		||||
import (
 | 
			
		||||
	"context"
 | 
			
		||||
	"errors"
 | 
			
		||||
	"fmt"
 | 
			
		||||
	"net/http"
 | 
			
		||||
	"strings"
 | 
			
		||||
	"time"
 | 
			
		||||
 | 
			
		||||
	"github.com/dexidp/dex/connector"
 | 
			
		||||
	"github.com/dexidp/dex/server/internal"
 | 
			
		||||
	"github.com/dexidp/dex/storage"
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
func contains(arr []string, item string) bool {
 | 
			
		||||
	for _, itemFromArray := range arr {
 | 
			
		||||
		if itemFromArray == item {
 | 
			
		||||
			return true
 | 
			
		||||
		}
 | 
			
		||||
	}
 | 
			
		||||
	return false
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
type refreshError struct {
 | 
			
		||||
	msg  string
 | 
			
		||||
	code int
 | 
			
		||||
	desc string
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func newInternalServerError() *refreshError {
 | 
			
		||||
	return &refreshError{msg: errInvalidRequest, desc: "", code: http.StatusInternalServerError}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func newBadRequestError(desc string) *refreshError {
 | 
			
		||||
	return &refreshError{msg: errInvalidRequest, desc: desc, code: http.StatusBadRequest}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (s *Server) refreshTokenErrHelper(w http.ResponseWriter, err *refreshError) {
 | 
			
		||||
	s.tokenErrHelper(w, err.msg, err.desc, err.code)
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (s *Server) extractRefreshTokenFromRequest(r *http.Request) (*internal.RefreshToken, *refreshError) {
 | 
			
		||||
	code := r.PostFormValue("refresh_token")
 | 
			
		||||
	if code == "" {
 | 
			
		||||
		return nil, newBadRequestError("No refresh token is found in request.")
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	token := new(internal.RefreshToken)
 | 
			
		||||
	if err := internal.Unmarshal(code, token); err != nil {
 | 
			
		||||
		// For backward compatibility, assume the refresh_token is a raw refresh token ID
 | 
			
		||||
		// if it fails to decode.
 | 
			
		||||
		//
 | 
			
		||||
		// Because refresh_token values that aren't unmarshable were generated by servers
 | 
			
		||||
		// that don't have a Token value, we'll still reject any attempts to claim a
 | 
			
		||||
		// refresh_token twice.
 | 
			
		||||
		token = &internal.RefreshToken{RefreshId: code, Token: ""}
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	return token, nil
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// getRefreshTokenFromStorage checks that refresh token is valid and exists in the storage and gets its info
 | 
			
		||||
func (s *Server) getRefreshTokenFromStorage(clientID string, token *internal.RefreshToken) (*storage.RefreshToken, *refreshError) {
 | 
			
		||||
	invalidErr := newBadRequestError("Refresh token is invalid or has already been claimed by another client.")
 | 
			
		||||
 | 
			
		||||
	refresh, err := s.storage.GetRefresh(token.RefreshId)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		s.logger.Errorf("failed to get refresh token: %v", err)
 | 
			
		||||
		if err != storage.ErrNotFound {
 | 
			
		||||
			return nil, newInternalServerError()
 | 
			
		||||
		}
 | 
			
		||||
 | 
			
		||||
		return nil, invalidErr
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	if refresh.ClientID != clientID {
 | 
			
		||||
		s.logger.Errorf("client %s trying to claim token for client %s", clientID, refresh.ClientID)
 | 
			
		||||
		return nil, invalidErr
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	if refresh.Token != token.Token {
 | 
			
		||||
		switch {
 | 
			
		||||
		case !s.refreshTokenPolicy.AllowedToReuse(refresh.LastUsed):
 | 
			
		||||
			fallthrough
 | 
			
		||||
		case refresh.ObsoleteToken != token.Token:
 | 
			
		||||
			fallthrough
 | 
			
		||||
		case refresh.ObsoleteToken == "":
 | 
			
		||||
			s.logger.Errorf("refresh token with id %s claimed twice", refresh.ID)
 | 
			
		||||
			return nil, invalidErr
 | 
			
		||||
		}
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	expiredErr := newBadRequestError("Refresh token expired.")
 | 
			
		||||
	if s.refreshTokenPolicy.CompletelyExpired(refresh.CreatedAt) {
 | 
			
		||||
		s.logger.Errorf("refresh token with id %s expired", refresh.ID)
 | 
			
		||||
		return nil, expiredErr
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	if s.refreshTokenPolicy.ExpiredBecauseUnused(refresh.LastUsed) {
 | 
			
		||||
		s.logger.Errorf("refresh token with id %s expired due to inactivity", refresh.ID)
 | 
			
		||||
		return nil, expiredErr
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	return &refresh, nil
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (s *Server) getRefreshScopes(r *http.Request, refresh *storage.RefreshToken) ([]string, *refreshError) {
 | 
			
		||||
	// Per the OAuth2 spec, if the client has omitted the scopes, default to the original
 | 
			
		||||
	// authorized scopes.
 | 
			
		||||
	//
 | 
			
		||||
	// https://tools.ietf.org/html/rfc6749#section-6
 | 
			
		||||
	scope := r.PostFormValue("scope")
 | 
			
		||||
 | 
			
		||||
	if scope == "" {
 | 
			
		||||
		return refresh.Scopes, nil
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	requestedScopes := strings.Fields(scope)
 | 
			
		||||
	var unauthorizedScopes []string
 | 
			
		||||
 | 
			
		||||
	// Per the OAuth2 spec, if the client has omitted the scopes, default to the original
 | 
			
		||||
	// authorized scopes.
 | 
			
		||||
	//
 | 
			
		||||
	// https://tools.ietf.org/html/rfc6749#section-6
 | 
			
		||||
	for _, requestScope := range requestedScopes {
 | 
			
		||||
		if !contains(refresh.Scopes, requestScope) {
 | 
			
		||||
			unauthorizedScopes = append(unauthorizedScopes, requestScope)
 | 
			
		||||
		}
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	if len(unauthorizedScopes) > 0 {
 | 
			
		||||
		desc := fmt.Sprintf("Requested scopes contain unauthorized scope(s): %q.", unauthorizedScopes)
 | 
			
		||||
		return nil, newBadRequestError(desc)
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	return requestedScopes, nil
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (s *Server) refreshWithConnector(ctx context.Context, token *internal.RefreshToken, refresh *storage.RefreshToken, scopes []string) (connector.Identity, *refreshError) {
 | 
			
		||||
	var connectorData []byte
 | 
			
		||||
 | 
			
		||||
	session, err := s.storage.GetOfflineSessions(refresh.Claims.UserID, refresh.ConnectorID)
 | 
			
		||||
	switch {
 | 
			
		||||
	case err != nil:
 | 
			
		||||
		if err != storage.ErrNotFound {
 | 
			
		||||
			s.logger.Errorf("failed to get offline session: %v", err)
 | 
			
		||||
			return connector.Identity{}, newInternalServerError()
 | 
			
		||||
		}
 | 
			
		||||
	case len(refresh.ConnectorData) > 0:
 | 
			
		||||
		// Use the old connector data if it exists, should be deleted once used
 | 
			
		||||
		connectorData = refresh.ConnectorData
 | 
			
		||||
	default:
 | 
			
		||||
		connectorData = session.ConnectorData
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	conn, err := s.getConnector(refresh.ConnectorID)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		s.logger.Errorf("connector with ID %q not found: %v", refresh.ConnectorID, err)
 | 
			
		||||
		return connector.Identity{}, newInternalServerError()
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	ident := connector.Identity{
 | 
			
		||||
		UserID:            refresh.Claims.UserID,
 | 
			
		||||
		Username:          refresh.Claims.Username,
 | 
			
		||||
		PreferredUsername: refresh.Claims.PreferredUsername,
 | 
			
		||||
		Email:             refresh.Claims.Email,
 | 
			
		||||
		EmailVerified:     refresh.Claims.EmailVerified,
 | 
			
		||||
		Groups:            refresh.Claims.Groups,
 | 
			
		||||
		ConnectorData:     connectorData,
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	// user's token was previously updated by a connector and is allowed to reuse
 | 
			
		||||
	// it is excessive to refresh identity in upstream
 | 
			
		||||
	if s.refreshTokenPolicy.AllowedToReuse(refresh.LastUsed) && token.Token == refresh.ObsoleteToken {
 | 
			
		||||
		return ident, nil
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	// Can the connector refresh the identity? If so, attempt to refresh the data
 | 
			
		||||
	// in the connector.
 | 
			
		||||
	//
 | 
			
		||||
	// TODO(ericchiang): We may want a strict mode where connectors that don't implement
 | 
			
		||||
	// this interface can't perform refreshing.
 | 
			
		||||
	if refreshConn, ok := conn.Connector.(connector.RefreshConnector); ok {
 | 
			
		||||
		newIdent, err := refreshConn.Refresh(ctx, parseScopes(scopes), ident)
 | 
			
		||||
		if err != nil {
 | 
			
		||||
			s.logger.Errorf("failed to refresh identity: %v", err)
 | 
			
		||||
			return connector.Identity{}, newInternalServerError()
 | 
			
		||||
		}
 | 
			
		||||
		ident = newIdent
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	return ident, nil
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// updateOfflineSession updates offline session in the storage
 | 
			
		||||
func (s *Server) updateOfflineSession(refresh *storage.RefreshToken, ident connector.Identity, lastUsed time.Time) *refreshError {
 | 
			
		||||
	offlineSessionUpdater := func(old storage.OfflineSessions) (storage.OfflineSessions, error) {
 | 
			
		||||
		if old.Refresh[refresh.ClientID].ID != refresh.ID {
 | 
			
		||||
			return old, errors.New("refresh token invalid")
 | 
			
		||||
		}
 | 
			
		||||
		old.Refresh[refresh.ClientID].LastUsed = lastUsed
 | 
			
		||||
		old.ConnectorData = ident.ConnectorData
 | 
			
		||||
		return old, nil
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	// Update LastUsed time stamp in refresh token reference object
 | 
			
		||||
	// in offline session for the user.
 | 
			
		||||
	err := s.storage.UpdateOfflineSessions(refresh.Claims.UserID, refresh.ConnectorID, offlineSessionUpdater)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		s.logger.Errorf("failed to update offline session: %v", err)
 | 
			
		||||
		return newInternalServerError()
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	return nil
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// updateRefreshToken updates refresh token and offline session in the storage
 | 
			
		||||
func (s *Server) updateRefreshToken(token *internal.RefreshToken, refresh *storage.RefreshToken, ident connector.Identity) (*internal.RefreshToken, *refreshError) {
 | 
			
		||||
	newToken := token
 | 
			
		||||
	if s.refreshTokenPolicy.RotationEnabled() {
 | 
			
		||||
		newToken = &internal.RefreshToken{
 | 
			
		||||
			RefreshId: refresh.ID,
 | 
			
		||||
			Token:     storage.NewID(),
 | 
			
		||||
		}
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	lastUsed := s.now()
 | 
			
		||||
 | 
			
		||||
	rerr := s.updateOfflineSession(refresh, ident, lastUsed)
 | 
			
		||||
	if rerr != nil {
 | 
			
		||||
		return nil, rerr
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	refreshTokenUpdater := func(old storage.RefreshToken) (storage.RefreshToken, error) {
 | 
			
		||||
		if s.refreshTokenPolicy.RotationEnabled() {
 | 
			
		||||
			if old.Token != token.Token {
 | 
			
		||||
				if s.refreshTokenPolicy.AllowedToReuse(old.LastUsed) && old.ObsoleteToken == token.Token {
 | 
			
		||||
					newToken.Token = old.Token
 | 
			
		||||
					return old, nil
 | 
			
		||||
				}
 | 
			
		||||
				return old, errors.New("refresh token claimed twice")
 | 
			
		||||
			}
 | 
			
		||||
 | 
			
		||||
			old.ObsoleteToken = old.Token
 | 
			
		||||
		}
 | 
			
		||||
 | 
			
		||||
		old.Token = newToken.Token
 | 
			
		||||
		// Update the claims of the refresh token.
 | 
			
		||||
		//
 | 
			
		||||
		// UserID intentionally ignored for now.
 | 
			
		||||
		old.Claims.Username = ident.Username
 | 
			
		||||
		old.Claims.PreferredUsername = ident.PreferredUsername
 | 
			
		||||
		old.Claims.Email = ident.Email
 | 
			
		||||
		old.Claims.EmailVerified = ident.EmailVerified
 | 
			
		||||
		old.Claims.Groups = ident.Groups
 | 
			
		||||
		old.LastUsed = lastUsed
 | 
			
		||||
 | 
			
		||||
		// ConnectorData has been moved to OfflineSession
 | 
			
		||||
		old.ConnectorData = []byte{}
 | 
			
		||||
		return old, nil
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	// Update refresh token in the storage.
 | 
			
		||||
	err := s.storage.UpdateRefreshToken(refresh.ID, refreshTokenUpdater)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		s.logger.Errorf("failed to update refresh token: %v", err)
 | 
			
		||||
		return nil, newInternalServerError()
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	return newToken, nil
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// handleRefreshToken handles a refresh token request https://tools.ietf.org/html/rfc6749#section-6
 | 
			
		||||
// this method is the entrypoint for refresh tokens handling
 | 
			
		||||
func (s *Server) handleRefreshToken(w http.ResponseWriter, r *http.Request, client storage.Client) {
 | 
			
		||||
	token, rerr := s.extractRefreshTokenFromRequest(r)
 | 
			
		||||
	if rerr != nil {
 | 
			
		||||
		s.refreshTokenErrHelper(w, rerr)
 | 
			
		||||
		return
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	refresh, rerr := s.getRefreshTokenFromStorage(client.ID, token)
 | 
			
		||||
	if rerr != nil {
 | 
			
		||||
		s.refreshTokenErrHelper(w, rerr)
 | 
			
		||||
		return
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	scopes, rerr := s.getRefreshScopes(r, refresh)
 | 
			
		||||
	if rerr != nil {
 | 
			
		||||
		s.refreshTokenErrHelper(w, rerr)
 | 
			
		||||
		return
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	ident, rerr := s.refreshWithConnector(r.Context(), token, refresh, scopes)
 | 
			
		||||
	if rerr != nil {
 | 
			
		||||
		s.refreshTokenErrHelper(w, rerr)
 | 
			
		||||
		return
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	claims := storage.Claims{
 | 
			
		||||
		UserID:            ident.UserID,
 | 
			
		||||
		Username:          ident.Username,
 | 
			
		||||
		PreferredUsername: ident.PreferredUsername,
 | 
			
		||||
		Email:             ident.Email,
 | 
			
		||||
		EmailVerified:     ident.EmailVerified,
 | 
			
		||||
		Groups:            ident.Groups,
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	accessToken, err := s.newAccessToken(client.ID, claims, scopes, refresh.Nonce, refresh.ConnectorID)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		s.logger.Errorf("failed to create new access token: %v", err)
 | 
			
		||||
		s.refreshTokenErrHelper(w, newInternalServerError())
 | 
			
		||||
		return
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	idToken, expiry, err := s.newIDToken(client.ID, claims, scopes, refresh.Nonce, accessToken, "", refresh.ConnectorID)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		s.logger.Errorf("failed to create ID token: %v", err)
 | 
			
		||||
		s.refreshTokenErrHelper(w, newInternalServerError())
 | 
			
		||||
		return
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	newToken, rerr := s.updateRefreshToken(token, refresh, ident)
 | 
			
		||||
	if rerr != nil {
 | 
			
		||||
		s.refreshTokenErrHelper(w, rerr)
 | 
			
		||||
		return
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	rawNewToken, err := internal.Marshal(newToken)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		s.logger.Errorf("failed to marshal refresh token: %v", err)
 | 
			
		||||
		s.refreshTokenErrHelper(w, newInternalServerError())
 | 
			
		||||
		return
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	resp := s.toAccessTokenResponse(idToken, accessToken, rawNewToken, expiry)
 | 
			
		||||
	s.writeAccessToken(w, resp)
 | 
			
		||||
}
 | 
			
		||||
							
								
								
									
										212
									
								
								server/refreshhandlers_test.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										212
									
								
								server/refreshhandlers_test.go
									
									
									
									
									
										Normal file
									
								
							@@ -0,0 +1,212 @@
 | 
			
		||||
package server
 | 
			
		||||
 | 
			
		||||
import (
 | 
			
		||||
	"bytes"
 | 
			
		||||
	"context"
 | 
			
		||||
	"encoding/json"
 | 
			
		||||
	"net/http"
 | 
			
		||||
	"net/http/httptest"
 | 
			
		||||
	"net/url"
 | 
			
		||||
	"path"
 | 
			
		||||
	"testing"
 | 
			
		||||
	"time"
 | 
			
		||||
 | 
			
		||||
	"github.com/stretchr/testify/require"
 | 
			
		||||
 | 
			
		||||
	"github.com/dexidp/dex/server/internal"
 | 
			
		||||
	"github.com/dexidp/dex/storage"
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
func mockRefreshTokenTestStorage(t *testing.T, s storage.Storage, useObsolete bool) {
 | 
			
		||||
	c := storage.Client{
 | 
			
		||||
		ID:           "test",
 | 
			
		||||
		Secret:       "barfoo",
 | 
			
		||||
		RedirectURIs: []string{"foo://bar.com/", "https://auth.example.com"},
 | 
			
		||||
		Name:         "dex client",
 | 
			
		||||
		LogoURL:      "https://goo.gl/JIyzIC",
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	err := s.CreateClient(c)
 | 
			
		||||
	require.NoError(t, err)
 | 
			
		||||
 | 
			
		||||
	c1 := storage.Connector{
 | 
			
		||||
		ID:     "test",
 | 
			
		||||
		Type:   "mockCallback",
 | 
			
		||||
		Name:   "mockCallback",
 | 
			
		||||
		Config: nil,
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	err = s.CreateConnector(c1)
 | 
			
		||||
	require.NoError(t, err)
 | 
			
		||||
 | 
			
		||||
	refresh := storage.RefreshToken{
 | 
			
		||||
		ID:            "test",
 | 
			
		||||
		Token:         "bar",
 | 
			
		||||
		ObsoleteToken: "",
 | 
			
		||||
		Nonce:         "foo",
 | 
			
		||||
		ClientID:      "test",
 | 
			
		||||
		ConnectorID:   "test",
 | 
			
		||||
		Scopes:        []string{"openid", "email", "profile"},
 | 
			
		||||
		CreatedAt:     time.Now().UTC().Round(time.Millisecond),
 | 
			
		||||
		LastUsed:      time.Now().UTC().Round(time.Millisecond),
 | 
			
		||||
		Claims: storage.Claims{
 | 
			
		||||
			UserID:        "1",
 | 
			
		||||
			Username:      "jane",
 | 
			
		||||
			Email:         "jane.doe@example.com",
 | 
			
		||||
			EmailVerified: true,
 | 
			
		||||
			Groups:        []string{"a", "b"},
 | 
			
		||||
		},
 | 
			
		||||
		ConnectorData: []byte(`{"some":"data"}`),
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	if useObsolete {
 | 
			
		||||
		refresh.Token = "testtest"
 | 
			
		||||
		refresh.ObsoleteToken = "bar"
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	err = s.CreateRefresh(refresh)
 | 
			
		||||
	require.NoError(t, err)
 | 
			
		||||
 | 
			
		||||
	offlineSessions := storage.OfflineSessions{
 | 
			
		||||
		UserID:        "1",
 | 
			
		||||
		ConnID:        "test",
 | 
			
		||||
		Refresh:       map[string]*storage.RefreshTokenRef{"test": {ID: "test", ClientID: "test"}},
 | 
			
		||||
		ConnectorData: nil,
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	err = s.CreateOfflineSessions(offlineSessions)
 | 
			
		||||
	require.NoError(t, err)
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func TestRefreshTokenExpirationScenarios(t *testing.T) {
 | 
			
		||||
	t0 := time.Now()
 | 
			
		||||
	tests := []struct {
 | 
			
		||||
		name        string
 | 
			
		||||
		policy      *RefreshTokenPolicy
 | 
			
		||||
		useObsolete bool
 | 
			
		||||
		error       string
 | 
			
		||||
	}{
 | 
			
		||||
		{
 | 
			
		||||
			name:   "Normal",
 | 
			
		||||
			policy: &RefreshTokenPolicy{rotateRefreshTokens: true},
 | 
			
		||||
			error:  ``,
 | 
			
		||||
		},
 | 
			
		||||
		{
 | 
			
		||||
			name: "Not expired because used",
 | 
			
		||||
			policy: &RefreshTokenPolicy{
 | 
			
		||||
				rotateRefreshTokens: false,
 | 
			
		||||
				validIfNotUsedFor:   time.Second * 60,
 | 
			
		||||
				now:                 func() time.Time { return t0.Add(time.Second * 25) },
 | 
			
		||||
			},
 | 
			
		||||
			error: ``,
 | 
			
		||||
		},
 | 
			
		||||
		{
 | 
			
		||||
			name: "Expired because not used",
 | 
			
		||||
			policy: &RefreshTokenPolicy{
 | 
			
		||||
				rotateRefreshTokens: false,
 | 
			
		||||
				validIfNotUsedFor:   time.Second * 60,
 | 
			
		||||
				now:                 func() time.Time { return t0.Add(time.Hour) },
 | 
			
		||||
			},
 | 
			
		||||
			error: `{"error":"invalid_request","error_description":"Refresh token expired."}`,
 | 
			
		||||
		},
 | 
			
		||||
		{
 | 
			
		||||
			name: "Absolutely expired",
 | 
			
		||||
			policy: &RefreshTokenPolicy{
 | 
			
		||||
				rotateRefreshTokens: true,
 | 
			
		||||
				absoluteLifetime:    time.Second * 60,
 | 
			
		||||
				now:                 func() time.Time { return t0.Add(time.Hour) },
 | 
			
		||||
			},
 | 
			
		||||
			error: `{"error":"invalid_request","error_description":"Refresh token expired."}`,
 | 
			
		||||
		},
 | 
			
		||||
		{
 | 
			
		||||
			name:        "Obsolete tokens are allowed",
 | 
			
		||||
			useObsolete: true,
 | 
			
		||||
			policy: &RefreshTokenPolicy{
 | 
			
		||||
				rotateRefreshTokens: true,
 | 
			
		||||
				reuseInterval:       time.Second * 30,
 | 
			
		||||
				now:                 func() time.Time { return t0.Add(time.Second * 25) },
 | 
			
		||||
			},
 | 
			
		||||
			error: ``,
 | 
			
		||||
		},
 | 
			
		||||
		{
 | 
			
		||||
			name:        "Obsolete tokens are not allowed",
 | 
			
		||||
			useObsolete: true,
 | 
			
		||||
			policy: &RefreshTokenPolicy{
 | 
			
		||||
				rotateRefreshTokens: true,
 | 
			
		||||
				now:                 func() time.Time { return t0.Add(time.Second * 25) },
 | 
			
		||||
			},
 | 
			
		||||
			error: `{"error":"invalid_request","error_description":"Refresh token is invalid or has already been claimed by another client."}`,
 | 
			
		||||
		},
 | 
			
		||||
		{
 | 
			
		||||
			name:        "Obsolete tokens are allowed but token is expired globally",
 | 
			
		||||
			useObsolete: true,
 | 
			
		||||
			policy: &RefreshTokenPolicy{
 | 
			
		||||
				rotateRefreshTokens: true,
 | 
			
		||||
				reuseInterval:       time.Second * 30,
 | 
			
		||||
				absoluteLifetime:    time.Second * 20,
 | 
			
		||||
				now:                 func() time.Time { return t0.Add(time.Second * 25) },
 | 
			
		||||
			},
 | 
			
		||||
			error: `{"error":"invalid_request","error_description":"Refresh token expired."}`,
 | 
			
		||||
		},
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	for _, tc := range tests {
 | 
			
		||||
		t.Run(tc.name, func(*testing.T) {
 | 
			
		||||
			ctx, cancel := context.WithCancel(context.Background())
 | 
			
		||||
			defer cancel()
 | 
			
		||||
 | 
			
		||||
			// Setup a dex server.
 | 
			
		||||
			httpServer, s := newTestServer(ctx, t, func(c *Config) {
 | 
			
		||||
				c.RefreshTokenPolicy = tc.policy
 | 
			
		||||
				c.Now = func() time.Time { return t0 }
 | 
			
		||||
			})
 | 
			
		||||
			defer httpServer.Close()
 | 
			
		||||
 | 
			
		||||
			mockRefreshTokenTestStorage(t, s.storage, tc.useObsolete)
 | 
			
		||||
 | 
			
		||||
			u, err := url.Parse(s.issuerURL.String())
 | 
			
		||||
			require.NoError(t, err)
 | 
			
		||||
 | 
			
		||||
			tokenData, err := internal.Marshal(&internal.RefreshToken{RefreshId: "test", Token: "bar"})
 | 
			
		||||
			require.NoError(t, err)
 | 
			
		||||
 | 
			
		||||
			u.Path = path.Join(u.Path, "/token")
 | 
			
		||||
			v := url.Values{}
 | 
			
		||||
			v.Add("grant_type", "refresh_token")
 | 
			
		||||
			v.Add("refresh_token", tokenData)
 | 
			
		||||
 | 
			
		||||
			req, _ := http.NewRequest("POST", u.String(), bytes.NewBufferString(v.Encode()))
 | 
			
		||||
			req.Header.Set("Content-Type", "application/x-www-form-urlencoded; param=value")
 | 
			
		||||
			req.SetBasicAuth("test", "barfoo")
 | 
			
		||||
 | 
			
		||||
			rr := httptest.NewRecorder()
 | 
			
		||||
			s.ServeHTTP(rr, req)
 | 
			
		||||
 | 
			
		||||
			if tc.error == "" {
 | 
			
		||||
				require.Equal(t, 200, rr.Code)
 | 
			
		||||
			} else {
 | 
			
		||||
				require.Equal(t, rr.Body.String(), tc.error)
 | 
			
		||||
				return
 | 
			
		||||
			}
 | 
			
		||||
 | 
			
		||||
			// Check that we received expected refresh token
 | 
			
		||||
			var ref struct {
 | 
			
		||||
				Token string `json:"refresh_token"`
 | 
			
		||||
			}
 | 
			
		||||
			err = json.Unmarshal(rr.Body.Bytes(), &ref)
 | 
			
		||||
			require.NoError(t, err)
 | 
			
		||||
 | 
			
		||||
			if tc.policy.rotateRefreshTokens == false {
 | 
			
		||||
				require.Equal(t, tokenData, ref.Token)
 | 
			
		||||
			} else {
 | 
			
		||||
				require.NotEqual(t, tokenData, ref.Token)
 | 
			
		||||
			}
 | 
			
		||||
 | 
			
		||||
			if tc.useObsolete {
 | 
			
		||||
				updatedTokenData, err := internal.Marshal(&internal.RefreshToken{RefreshId: "test", Token: "testtest"})
 | 
			
		||||
				require.NoError(t, err)
 | 
			
		||||
				require.Equal(t, updatedTokenData, ref.Token)
 | 
			
		||||
			}
 | 
			
		||||
		})
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
@@ -177,3 +177,73 @@ func (k keyRotator) rotate() error {
 | 
			
		||||
	k.logger.Infof("keys rotated, next rotation: %s", nextRotation)
 | 
			
		||||
	return nil
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
type RefreshTokenPolicy struct {
 | 
			
		||||
	rotateRefreshTokens bool // enable rotation
 | 
			
		||||
 | 
			
		||||
	absoluteLifetime  time.Duration // interval from token creation to the end of its life
 | 
			
		||||
	validIfNotUsedFor time.Duration // interval from last token update to the end of its life
 | 
			
		||||
	reuseInterval     time.Duration // interval within which old refresh token is allowed to be reused
 | 
			
		||||
 | 
			
		||||
	now func() time.Time
 | 
			
		||||
 | 
			
		||||
	logger log.Logger
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func NewRefreshTokenPolicy(logger log.Logger, rotation bool, validIfNotUsedFor, absoluteLifetime, reuseInterval string) (*RefreshTokenPolicy, error) {
 | 
			
		||||
	r := RefreshTokenPolicy{now: time.Now, logger: logger}
 | 
			
		||||
	var err error
 | 
			
		||||
 | 
			
		||||
	if validIfNotUsedFor != "" {
 | 
			
		||||
		r.validIfNotUsedFor, err = time.ParseDuration(validIfNotUsedFor)
 | 
			
		||||
		if err != nil {
 | 
			
		||||
			return nil, fmt.Errorf("invalid config value %q for refresh token valid if not used for: %v", validIfNotUsedFor, err)
 | 
			
		||||
		}
 | 
			
		||||
		logger.Infof("config refresh tokens valid if not used for: %v", validIfNotUsedFor)
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	if absoluteLifetime != "" {
 | 
			
		||||
		r.absoluteLifetime, err = time.ParseDuration(absoluteLifetime)
 | 
			
		||||
		if err != nil {
 | 
			
		||||
			return nil, fmt.Errorf("invalid config value %q for refresh tokens absolute lifetime: %v", absoluteLifetime, err)
 | 
			
		||||
		}
 | 
			
		||||
		logger.Infof("config refresh tokens absolute lifetime: %v", absoluteLifetime)
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	if reuseInterval != "" {
 | 
			
		||||
		r.reuseInterval, err = time.ParseDuration(reuseInterval)
 | 
			
		||||
		if err != nil {
 | 
			
		||||
			return nil, fmt.Errorf("invalid config value %q for refresh tokens reuse interval: %v", reuseInterval, err)
 | 
			
		||||
		}
 | 
			
		||||
		logger.Infof("config refresh tokens reuse interval: %v", reuseInterval)
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	r.rotateRefreshTokens = !rotation
 | 
			
		||||
	logger.Infof("config refresh tokens rotation enabled: %v", r.rotateRefreshTokens)
 | 
			
		||||
	return &r, nil
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (r *RefreshTokenPolicy) RotationEnabled() bool {
 | 
			
		||||
	return r.rotateRefreshTokens
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (r *RefreshTokenPolicy) CompletelyExpired(lastUsed time.Time) bool {
 | 
			
		||||
	if r.absoluteLifetime == 0 {
 | 
			
		||||
		return false // expiration disabled
 | 
			
		||||
	}
 | 
			
		||||
	return r.now().After(lastUsed.Add(r.absoluteLifetime))
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (r *RefreshTokenPolicy) ExpiredBecauseUnused(lastUsed time.Time) bool {
 | 
			
		||||
	if r.validIfNotUsedFor == 0 {
 | 
			
		||||
		return false // expiration disabled
 | 
			
		||||
	}
 | 
			
		||||
	return r.now().After(lastUsed.Add(r.validIfNotUsedFor))
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (r *RefreshTokenPolicy) AllowedToReuse(lastUsed time.Time) bool {
 | 
			
		||||
	if r.reuseInterval == 0 {
 | 
			
		||||
		return false // expiration disabled
 | 
			
		||||
	}
 | 
			
		||||
	return !r.now().After(lastUsed.Add(r.reuseInterval))
 | 
			
		||||
}
 | 
			
		||||
 
 | 
			
		||||
@@ -7,6 +7,7 @@ import (
 | 
			
		||||
	"time"
 | 
			
		||||
 | 
			
		||||
	"github.com/sirupsen/logrus"
 | 
			
		||||
	"github.com/stretchr/testify/require"
 | 
			
		||||
 | 
			
		||||
	"github.com/dexidp/dex/storage"
 | 
			
		||||
	"github.com/dexidp/dex/storage/memory"
 | 
			
		||||
@@ -100,3 +101,29 @@ func TestKeyRotator(t *testing.T) {
 | 
			
		||||
		}
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func TestRefreshTokenPolicy(t *testing.T) {
 | 
			
		||||
	lastTime := time.Now()
 | 
			
		||||
	l := &logrus.Logger{
 | 
			
		||||
		Out:       os.Stderr,
 | 
			
		||||
		Formatter: &logrus.TextFormatter{DisableColors: true},
 | 
			
		||||
		Level:     logrus.DebugLevel,
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	r, err := NewRefreshTokenPolicy(l, true, "1m", "1m", "1m")
 | 
			
		||||
	require.NoError(t, err)
 | 
			
		||||
 | 
			
		||||
	t.Run("Allowed", func(t *testing.T) {
 | 
			
		||||
		r.now = func() time.Time { return lastTime }
 | 
			
		||||
		require.Equal(t, true, r.AllowedToReuse(lastTime))
 | 
			
		||||
		require.Equal(t, false, r.ExpiredBecauseUnused(lastTime))
 | 
			
		||||
		require.Equal(t, false, r.CompletelyExpired(lastTime))
 | 
			
		||||
	})
 | 
			
		||||
 | 
			
		||||
	t.Run("Expired", func(t *testing.T) {
 | 
			
		||||
		r.now = func() time.Time { return lastTime.Add(2 * time.Minute) }
 | 
			
		||||
		require.Equal(t, false, r.AllowedToReuse(lastTime))
 | 
			
		||||
		require.Equal(t, true, r.ExpiredBecauseUnused(lastTime))
 | 
			
		||||
		require.Equal(t, true, r.CompletelyExpired(lastTime))
 | 
			
		||||
	})
 | 
			
		||||
}
 | 
			
		||||
 
 | 
			
		||||
@@ -84,6 +84,10 @@ type Config struct {
 | 
			
		||||
	IDTokensValidFor       time.Duration // Defaults to 24 hours
 | 
			
		||||
	AuthRequestsValidFor   time.Duration // Defaults to 24 hours
 | 
			
		||||
	DeviceRequestsValidFor time.Duration // Defaults to 5 minutes
 | 
			
		||||
 | 
			
		||||
	// Refresh token expiration settings
 | 
			
		||||
	RefreshTokenPolicy *RefreshTokenPolicy
 | 
			
		||||
 | 
			
		||||
	// If set, the server will use this connector to handle password grants
 | 
			
		||||
	PasswordConnector string
 | 
			
		||||
 | 
			
		||||
@@ -171,6 +175,8 @@ type Server struct {
 | 
			
		||||
	authRequestsValidFor   time.Duration
 | 
			
		||||
	deviceRequestsValidFor time.Duration
 | 
			
		||||
 | 
			
		||||
	refreshTokenPolicy *RefreshTokenPolicy
 | 
			
		||||
 | 
			
		||||
	logger log.Logger
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
@@ -246,6 +252,7 @@ func newServer(ctx context.Context, c Config, rotationStrategy rotationStrategy)
 | 
			
		||||
		idTokensValidFor:       value(c.IDTokensValidFor, 24*time.Hour),
 | 
			
		||||
		authRequestsValidFor:   value(c.AuthRequestsValidFor, 24*time.Hour),
 | 
			
		||||
		deviceRequestsValidFor: value(c.DeviceRequestsValidFor, 5*time.Minute),
 | 
			
		||||
		refreshTokenPolicy:     c.RefreshTokenPolicy,
 | 
			
		||||
		skipApproval:           c.SkipApprovalScreen,
 | 
			
		||||
		alwaysShowLogin:        c.AlwaysShowLoginScreen,
 | 
			
		||||
		now:                    now,
 | 
			
		||||
 
 | 
			
		||||
@@ -119,6 +119,16 @@ func newTestServer(ctx context.Context, t *testing.T, updateConfig func(c *Confi
 | 
			
		||||
		t.Fatal(err)
 | 
			
		||||
	}
 | 
			
		||||
	server.skipApproval = true // Don't prompt for approval, just immediately redirect with code.
 | 
			
		||||
 | 
			
		||||
	// Default rotation policy
 | 
			
		||||
	if server.refreshTokenPolicy == nil {
 | 
			
		||||
		server.refreshTokenPolicy, err = NewRefreshTokenPolicy(logger, false, "", "", "")
 | 
			
		||||
		if err != nil {
 | 
			
		||||
			t.Fatalf("failed to prepare rotation policy: %v", err)
 | 
			
		||||
		}
 | 
			
		||||
		server.refreshTokenPolicy.now = config.Now
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	return s, server
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
 
 | 
			
		||||
@@ -324,14 +324,15 @@ func testClientCRUD(t *testing.T, s storage.Storage) {
 | 
			
		||||
func testRefreshTokenCRUD(t *testing.T, s storage.Storage) {
 | 
			
		||||
	id := storage.NewID()
 | 
			
		||||
	refresh := storage.RefreshToken{
 | 
			
		||||
		ID:          id,
 | 
			
		||||
		Token:       "bar",
 | 
			
		||||
		Nonce:       "foo",
 | 
			
		||||
		ClientID:    "client_id",
 | 
			
		||||
		ConnectorID: "client_secret",
 | 
			
		||||
		Scopes:      []string{"openid", "email", "profile"},
 | 
			
		||||
		CreatedAt:   time.Now().UTC().Round(time.Millisecond),
 | 
			
		||||
		LastUsed:    time.Now().UTC().Round(time.Millisecond),
 | 
			
		||||
		ID:            id,
 | 
			
		||||
		Token:         "bar",
 | 
			
		||||
		ObsoleteToken: "",
 | 
			
		||||
		Nonce:         "foo",
 | 
			
		||||
		ClientID:      "client_id",
 | 
			
		||||
		ConnectorID:   "client_secret",
 | 
			
		||||
		Scopes:        []string{"openid", "email", "profile"},
 | 
			
		||||
		CreatedAt:     time.Now().UTC().Round(time.Millisecond),
 | 
			
		||||
		LastUsed:      time.Now().UTC().Round(time.Millisecond),
 | 
			
		||||
		Claims: storage.Claims{
 | 
			
		||||
			UserID:        "1",
 | 
			
		||||
			Username:      "jane",
 | 
			
		||||
@@ -378,14 +379,15 @@ func testRefreshTokenCRUD(t *testing.T, s storage.Storage) {
 | 
			
		||||
 | 
			
		||||
	id2 := storage.NewID()
 | 
			
		||||
	refresh2 := storage.RefreshToken{
 | 
			
		||||
		ID:          id2,
 | 
			
		||||
		Token:       "bar_2",
 | 
			
		||||
		Nonce:       "foo_2",
 | 
			
		||||
		ClientID:    "client_id_2",
 | 
			
		||||
		ConnectorID: "client_secret",
 | 
			
		||||
		Scopes:      []string{"openid", "email", "profile"},
 | 
			
		||||
		CreatedAt:   time.Now().UTC().Round(time.Millisecond),
 | 
			
		||||
		LastUsed:    time.Now().UTC().Round(time.Millisecond),
 | 
			
		||||
		ID:            id2,
 | 
			
		||||
		Token:         "bar_2",
 | 
			
		||||
		ObsoleteToken: refresh.Token,
 | 
			
		||||
		Nonce:         "foo_2",
 | 
			
		||||
		ClientID:      "client_id_2",
 | 
			
		||||
		ConnectorID:   "client_secret",
 | 
			
		||||
		Scopes:        []string{"openid", "email", "profile"},
 | 
			
		||||
		CreatedAt:     time.Now().UTC().Round(time.Millisecond),
 | 
			
		||||
		LastUsed:      time.Now().UTC().Round(time.Millisecond),
 | 
			
		||||
		Claims: storage.Claims{
 | 
			
		||||
			UserID:        "2",
 | 
			
		||||
			Username:      "john",
 | 
			
		||||
 
 | 
			
		||||
@@ -132,7 +132,8 @@ func toStorageAuthRequest(a AuthRequest) storage.AuthRequest {
 | 
			
		||||
type RefreshToken struct {
 | 
			
		||||
	ID string `json:"id"`
 | 
			
		||||
 | 
			
		||||
	Token string `json:"token"`
 | 
			
		||||
	Token         string `json:"token"`
 | 
			
		||||
	ObsoleteToken string `json:"obsolete_token"`
 | 
			
		||||
 | 
			
		||||
	CreatedAt time.Time `json:"created_at"`
 | 
			
		||||
	LastUsed  time.Time `json:"last_used"`
 | 
			
		||||
@@ -152,6 +153,7 @@ func toStorageRefreshToken(r RefreshToken) storage.RefreshToken {
 | 
			
		||||
	return storage.RefreshToken{
 | 
			
		||||
		ID:            r.ID,
 | 
			
		||||
		Token:         r.Token,
 | 
			
		||||
		ObsoleteToken: r.ObsoleteToken,
 | 
			
		||||
		CreatedAt:     r.CreatedAt,
 | 
			
		||||
		LastUsed:      r.LastUsed,
 | 
			
		||||
		ClientID:      r.ClientID,
 | 
			
		||||
@@ -167,6 +169,7 @@ func fromStorageRefreshToken(r storage.RefreshToken) RefreshToken {
 | 
			
		||||
	return RefreshToken{
 | 
			
		||||
		ID:            r.ID,
 | 
			
		||||
		Token:         r.Token,
 | 
			
		||||
		ObsoleteToken: r.ObsoleteToken,
 | 
			
		||||
		CreatedAt:     r.CreatedAt,
 | 
			
		||||
		LastUsed:      r.LastUsed,
 | 
			
		||||
		ClientID:      r.ClientID,
 | 
			
		||||
 
 | 
			
		||||
@@ -496,7 +496,8 @@ type RefreshToken struct {
 | 
			
		||||
	ClientID string   `json:"clientID"`
 | 
			
		||||
	Scopes   []string `json:"scopes,omitempty"`
 | 
			
		||||
 | 
			
		||||
	Token string `json:"token,omitempty"`
 | 
			
		||||
	Token         string `json:"token,omitempty"`
 | 
			
		||||
	ObsoleteToken string `json:"obsoleteToken,omitempty"`
 | 
			
		||||
 | 
			
		||||
	Nonce string `json:"nonce,omitempty"`
 | 
			
		||||
 | 
			
		||||
@@ -516,6 +517,7 @@ func toStorageRefreshToken(r RefreshToken) storage.RefreshToken {
 | 
			
		||||
	return storage.RefreshToken{
 | 
			
		||||
		ID:            r.ObjectMeta.Name,
 | 
			
		||||
		Token:         r.Token,
 | 
			
		||||
		ObsoleteToken: r.ObsoleteToken,
 | 
			
		||||
		CreatedAt:     r.CreatedAt,
 | 
			
		||||
		LastUsed:      r.LastUsed,
 | 
			
		||||
		ClientID:      r.ClientID,
 | 
			
		||||
@@ -538,6 +540,7 @@ func (cli *client) fromStorageRefreshToken(r storage.RefreshToken) RefreshToken
 | 
			
		||||
			Namespace: cli.namespace,
 | 
			
		||||
		},
 | 
			
		||||
		Token:         r.Token,
 | 
			
		||||
		ObsoleteToken: r.ObsoleteToken,
 | 
			
		||||
		CreatedAt:     r.CreatedAt,
 | 
			
		||||
		LastUsed:      r.LastUsed,
 | 
			
		||||
		ClientID:      r.ClientID,
 | 
			
		||||
 
 | 
			
		||||
@@ -285,16 +285,16 @@ func (c *conn) CreateRefresh(r storage.RefreshToken) error {
 | 
			
		||||
			claims_user_id, claims_username, claims_preferred_username,
 | 
			
		||||
			claims_email, claims_email_verified, claims_groups,
 | 
			
		||||
			connector_id, connector_data,
 | 
			
		||||
			token, created_at, last_used
 | 
			
		||||
			token, obsolete_token, created_at, last_used
 | 
			
		||||
		)
 | 
			
		||||
		values ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12, $13, $14, $15);
 | 
			
		||||
		values ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12, $13, $14, $15, $16);
 | 
			
		||||
	`,
 | 
			
		||||
		r.ID, r.ClientID, encoder(r.Scopes), r.Nonce,
 | 
			
		||||
		r.Claims.UserID, r.Claims.Username, r.Claims.PreferredUsername,
 | 
			
		||||
		r.Claims.Email, r.Claims.EmailVerified,
 | 
			
		||||
		encoder(r.Claims.Groups),
 | 
			
		||||
		r.ConnectorID, r.ConnectorData,
 | 
			
		||||
		r.Token, r.CreatedAt, r.LastUsed,
 | 
			
		||||
		r.Token, r.ObsoleteToken, r.CreatedAt, r.LastUsed,
 | 
			
		||||
	)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		if c.alreadyExistsCheck(err) {
 | 
			
		||||
@@ -329,17 +329,18 @@ func (c *conn) UpdateRefreshToken(id string, updater func(old storage.RefreshTok
 | 
			
		||||
				connector_id = $10,
 | 
			
		||||
				connector_data = $11,
 | 
			
		||||
				token = $12,
 | 
			
		||||
				created_at = $13,
 | 
			
		||||
				last_used = $14
 | 
			
		||||
                obsolete_token = $13,
 | 
			
		||||
				created_at = $14,
 | 
			
		||||
				last_used = $15
 | 
			
		||||
			where
 | 
			
		||||
				id = $15
 | 
			
		||||
				id = $16
 | 
			
		||||
		`,
 | 
			
		||||
			r.ClientID, encoder(r.Scopes), r.Nonce,
 | 
			
		||||
			r.Claims.UserID, r.Claims.Username, r.Claims.PreferredUsername,
 | 
			
		||||
			r.Claims.Email, r.Claims.EmailVerified,
 | 
			
		||||
			encoder(r.Claims.Groups),
 | 
			
		||||
			r.ConnectorID, r.ConnectorData,
 | 
			
		||||
			r.Token, r.CreatedAt, r.LastUsed, id,
 | 
			
		||||
			r.Token, r.ObsoleteToken, r.CreatedAt, r.LastUsed, id,
 | 
			
		||||
		)
 | 
			
		||||
		if err != nil {
 | 
			
		||||
			return fmt.Errorf("update refresh token: %v", err)
 | 
			
		||||
@@ -360,7 +361,7 @@ func getRefresh(q querier, id string) (storage.RefreshToken, error) {
 | 
			
		||||
			claims_email, claims_email_verified,
 | 
			
		||||
			claims_groups,
 | 
			
		||||
			connector_id, connector_data,
 | 
			
		||||
			token, created_at, last_used
 | 
			
		||||
			token, obsolete_token, created_at, last_used
 | 
			
		||||
		from refresh_token where id = $1;
 | 
			
		||||
	`, id))
 | 
			
		||||
}
 | 
			
		||||
@@ -372,7 +373,7 @@ func (c *conn) ListRefreshTokens() ([]storage.RefreshToken, error) {
 | 
			
		||||
			claims_user_id, claims_username, claims_preferred_username,
 | 
			
		||||
			claims_email, claims_email_verified, claims_groups,
 | 
			
		||||
			connector_id, connector_data,
 | 
			
		||||
			token, created_at, last_used
 | 
			
		||||
			token, obsolete_token, created_at, last_used
 | 
			
		||||
		from refresh_token;
 | 
			
		||||
	`)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
@@ -401,7 +402,7 @@ func scanRefresh(s scanner) (r storage.RefreshToken, err error) {
 | 
			
		||||
		&r.Claims.Email, &r.Claims.EmailVerified,
 | 
			
		||||
		decoder(&r.Claims.Groups),
 | 
			
		||||
		&r.ConnectorID, &r.ConnectorData,
 | 
			
		||||
		&r.Token, &r.CreatedAt, &r.LastUsed,
 | 
			
		||||
		&r.Token, &r.ObsoleteToken, &r.CreatedAt, &r.LastUsed,
 | 
			
		||||
	)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		if err == sql.ErrNoRows {
 | 
			
		||||
 
 | 
			
		||||
@@ -274,4 +274,11 @@ var migrations = []migration{
 | 
			
		||||
				add column code_challenge_method text not null default '';`,
 | 
			
		||||
		},
 | 
			
		||||
	},
 | 
			
		||||
	{
 | 
			
		||||
		stmts: []string{
 | 
			
		||||
			`
 | 
			
		||||
			alter table refresh_token
 | 
			
		||||
				add column obsolete_token text default '';`,
 | 
			
		||||
		},
 | 
			
		||||
	},
 | 
			
		||||
}
 | 
			
		||||
 
 | 
			
		||||
@@ -271,7 +271,8 @@ type RefreshToken struct {
 | 
			
		||||
	// A single token that's rotated every time the refresh token is refreshed.
 | 
			
		||||
	//
 | 
			
		||||
	// May be empty.
 | 
			
		||||
	Token string
 | 
			
		||||
	Token         string
 | 
			
		||||
	ObsoleteToken string
 | 
			
		||||
 | 
			
		||||
	CreatedAt time.Time
 | 
			
		||||
	LastUsed  time.Time
 | 
			
		||||
 
 | 
			
		||||
		Reference in New Issue
	
	Block a user