feat: Add refresh token expiration and rotation settings
Signed-off-by: m.nabokikh <maksim.nabokikh@flant.com>
This commit is contained in:
@@ -1035,14 +1035,27 @@ func (s *Server) handleRefreshToken(w http.ResponseWriter, r *http.Request, clie
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
if refresh.ClientID != client.ID {
|
||||
s.logger.Errorf("client %s trying to claim token for client %s", client.ID, refresh.ClientID)
|
||||
s.tokenErrHelper(w, errInvalidRequest, "Refresh token is invalid or has already been claimed by another client.", http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
if refresh.Token != token.Token {
|
||||
s.logger.Errorf("refresh token with id %s claimed twice", refresh.ID)
|
||||
s.tokenErrHelper(w, errInvalidRequest, "Refresh token is invalid or has already been claimed by another client.", http.StatusBadRequest)
|
||||
if !s.refreshTokenPolicy.AllowedToReuse(refresh.LastUsed) || refresh.ObsoleteToken != token.Token {
|
||||
s.logger.Errorf("refresh token with id %s claimed twice", refresh.ID)
|
||||
s.tokenErrHelper(w, errInvalidRequest, "Refresh token is invalid or has already been claimed by another client.", http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
}
|
||||
if s.refreshTokenPolicy.CompletelyExpired(refresh.CreatedAt) {
|
||||
s.logger.Errorf("refresh token with id %s expired", refresh.ID)
|
||||
s.tokenErrHelper(w, errInvalidRequest, "Refresh token expired.", http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
if s.refreshTokenPolicy.ExpiredBecauseUnused(refresh.LastUsed) {
|
||||
s.logger.Errorf("refresh token with id %s expired because being unused", refresh.ID)
|
||||
s.tokenErrHelper(w, errInvalidRequest, "Refresh token expired.", http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
|
||||
@@ -1147,22 +1160,28 @@ func (s *Server) handleRefreshToken(w http.ResponseWriter, r *http.Request, clie
|
||||
return
|
||||
}
|
||||
|
||||
newToken := &internal.RefreshToken{
|
||||
RefreshId: refresh.ID,
|
||||
Token: storage.NewID(),
|
||||
}
|
||||
rawNewToken, err := internal.Marshal(newToken)
|
||||
if err != nil {
|
||||
s.logger.Errorf("failed to marshal refresh token: %v", err)
|
||||
s.tokenErrHelper(w, errServerError, "", http.StatusInternalServerError)
|
||||
return
|
||||
newToken := token
|
||||
if s.refreshTokenPolicy.RotationEnabled() {
|
||||
newToken = &internal.RefreshToken{
|
||||
RefreshId: refresh.ID,
|
||||
Token: storage.NewID(),
|
||||
}
|
||||
}
|
||||
|
||||
lastUsed := s.now()
|
||||
updater := func(old storage.RefreshToken) (storage.RefreshToken, error) {
|
||||
if old.Token != refresh.Token {
|
||||
return old, errors.New("refresh token claimed twice")
|
||||
if s.refreshTokenPolicy.RotationEnabled() {
|
||||
if old.Token != refresh.Token {
|
||||
if s.refreshTokenPolicy.AllowedToReuse(old.LastUsed) && old.ObsoleteToken == refresh.Token {
|
||||
newToken.Token = old.Token
|
||||
return old, nil
|
||||
}
|
||||
return old, errors.New("refresh token claimed twice")
|
||||
}
|
||||
|
||||
old.ObsoleteToken = old.Token
|
||||
}
|
||||
|
||||
old.Token = newToken.Token
|
||||
// Update the claims of the refresh token.
|
||||
//
|
||||
@@ -1201,6 +1220,13 @@ func (s *Server) handleRefreshToken(w http.ResponseWriter, r *http.Request, clie
|
||||
return
|
||||
}
|
||||
|
||||
rawNewToken, err := internal.Marshal(newToken)
|
||||
if err != nil {
|
||||
s.logger.Errorf("failed to marshal refresh token: %v", err)
|
||||
s.tokenErrHelper(w, errServerError, "", http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
|
||||
resp := s.toAccessTokenResponse(idToken, accessToken, rawNewToken, expiry)
|
||||
s.writeAccessToken(w, resp)
|
||||
}
|
||||
|
@@ -177,3 +177,73 @@ func (k keyRotator) rotate() error {
|
||||
k.logger.Infof("keys rotated, next rotation: %s", nextRotation)
|
||||
return nil
|
||||
}
|
||||
|
||||
type RefreshTokenPolicy struct {
|
||||
rotateRefreshTokens bool // enable rotation
|
||||
|
||||
absoluteLifetime time.Duration // interval from token creation to the end of its life
|
||||
validIfNotUsedFor time.Duration // interval from last token update to the end of its life
|
||||
reuseInterval time.Duration // interval within which old refresh token is allowed to be reused
|
||||
|
||||
Clock func() time.Time
|
||||
|
||||
logger log.Logger
|
||||
}
|
||||
|
||||
func NewRefreshTokenPolicyFromConfig(logger log.Logger, rotation bool, validIfNotUsedFor, absoluteLifetime, reuseInterval string) (*RefreshTokenPolicy, error) {
|
||||
r := RefreshTokenPolicy{Clock: time.Now, logger: logger}
|
||||
var err error
|
||||
|
||||
if validIfNotUsedFor != "" {
|
||||
r.validIfNotUsedFor, err = time.ParseDuration(validIfNotUsedFor)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("invalid config value %q for refresh token valid if not used for: %v", validIfNotUsedFor, err)
|
||||
}
|
||||
logger.Infof("config refresh tokens valid if not used for: %v", validIfNotUsedFor)
|
||||
}
|
||||
|
||||
if absoluteLifetime != "" {
|
||||
r.absoluteLifetime, err = time.ParseDuration(absoluteLifetime)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("invalid config value %q for refresh tokens absolute lifetime: %v", absoluteLifetime, err)
|
||||
}
|
||||
logger.Infof("config refresh tokens absolute lifetime: %v", absoluteLifetime)
|
||||
}
|
||||
|
||||
if reuseInterval != "" {
|
||||
r.reuseInterval, err = time.ParseDuration(reuseInterval)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("invalid config value %q for refresh tokens reuse interval: %v", reuseInterval, err)
|
||||
}
|
||||
logger.Infof("config refresh tokens reuse interval: %v", reuseInterval)
|
||||
}
|
||||
|
||||
r.rotateRefreshTokens = !rotation
|
||||
logger.Infof("config refresh tokens rotation enabled: %v", r.rotateRefreshTokens)
|
||||
return &r, nil
|
||||
}
|
||||
|
||||
func (r *RefreshTokenPolicy) RotationEnabled() bool {
|
||||
return r.rotateRefreshTokens
|
||||
}
|
||||
|
||||
func (r *RefreshTokenPolicy) CompletelyExpired(lastUsed time.Time) bool {
|
||||
if r.absoluteLifetime == 0 {
|
||||
return false // expiration disabled
|
||||
}
|
||||
return r.Clock().After(lastUsed.Add(r.absoluteLifetime))
|
||||
}
|
||||
|
||||
func (r *RefreshTokenPolicy) ExpiredBecauseUnused(lastUsed time.Time) bool {
|
||||
if r.validIfNotUsedFor == 0 {
|
||||
return false // expiration disabled
|
||||
}
|
||||
return r.Clock().After(lastUsed.Add(r.validIfNotUsedFor))
|
||||
}
|
||||
|
||||
func (r *RefreshTokenPolicy) AllowedToReuse(lastUsed time.Time) bool {
|
||||
if r.reuseInterval == 0 {
|
||||
return false // expiration disabled
|
||||
}
|
||||
return !r.Clock().After(lastUsed.Add(r.reuseInterval))
|
||||
}
|
||||
|
@@ -7,6 +7,7 @@ import (
|
||||
"time"
|
||||
|
||||
"github.com/sirupsen/logrus"
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
"github.com/dexidp/dex/storage"
|
||||
"github.com/dexidp/dex/storage/memory"
|
||||
@@ -100,3 +101,30 @@ func TestKeyRotator(t *testing.T) {
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestRefreshTokenPolicy(t *testing.T) {
|
||||
lastTime := time.Now()
|
||||
l := &logrus.Logger{
|
||||
Out: os.Stderr,
|
||||
Formatter: &logrus.TextFormatter{DisableColors: true},
|
||||
Level: logrus.DebugLevel,
|
||||
}
|
||||
|
||||
r, err := NewRefreshTokenPolicyFromConfig(l, true, "1m", "1m", "1m")
|
||||
require.NoError(t, err)
|
||||
|
||||
t.Run("Allowed", func(t *testing.T) {
|
||||
r.Clock = func() time.Time { return lastTime }
|
||||
require.Equal(t, true, r.AllowedToReuse(lastTime))
|
||||
require.Equal(t, false, r.ExpiredBecauseUnused(lastTime))
|
||||
require.Equal(t, false, r.CompletelyExpired(lastTime))
|
||||
})
|
||||
|
||||
t.Run("Expired", func(t *testing.T) {
|
||||
r.Clock = func() time.Time { return lastTime.Add(2 * time.Minute) }
|
||||
time.Sleep(1 * time.Second)
|
||||
require.Equal(t, false, r.AllowedToReuse(lastTime))
|
||||
require.Equal(t, true, r.ExpiredBecauseUnused(lastTime))
|
||||
require.Equal(t, true, r.CompletelyExpired(lastTime))
|
||||
})
|
||||
}
|
||||
|
@@ -80,6 +80,10 @@ type Config struct {
|
||||
IDTokensValidFor time.Duration // Defaults to 24 hours
|
||||
AuthRequestsValidFor time.Duration // Defaults to 24 hours
|
||||
DeviceRequestsValidFor time.Duration // Defaults to 5 minutes
|
||||
|
||||
// Refresh token expiration settings
|
||||
RefreshTokenPolicy *RefreshTokenPolicy
|
||||
|
||||
// If set, the server will use this connector to handle password grants
|
||||
PasswordConnector string
|
||||
|
||||
@@ -159,6 +163,8 @@ type Server struct {
|
||||
authRequestsValidFor time.Duration
|
||||
deviceRequestsValidFor time.Duration
|
||||
|
||||
refreshTokenPolicy *RefreshTokenPolicy
|
||||
|
||||
logger log.Logger
|
||||
}
|
||||
|
||||
@@ -227,6 +233,7 @@ func newServer(ctx context.Context, c Config, rotationStrategy rotationStrategy)
|
||||
idTokensValidFor: value(c.IDTokensValidFor, 24*time.Hour),
|
||||
authRequestsValidFor: value(c.AuthRequestsValidFor, 24*time.Hour),
|
||||
deviceRequestsValidFor: value(c.DeviceRequestsValidFor, 5*time.Minute),
|
||||
refreshTokenPolicy: c.RefreshTokenPolicy,
|
||||
skipApproval: c.SkipApprovalScreen,
|
||||
alwaysShowLogin: c.AlwaysShowLoginScreen,
|
||||
now: now,
|
||||
|
@@ -677,6 +677,13 @@ func TestOAuth2CodeFlow(t *testing.T) {
|
||||
})
|
||||
defer httpServer.Close()
|
||||
|
||||
policy, err := NewRefreshTokenPolicyFromConfig(s.logger, false, "", "", "")
|
||||
if err != nil {
|
||||
t.Fatalf("failed to prepare rotation policy: %v", err)
|
||||
}
|
||||
policy.Clock = now
|
||||
s.refreshTokenPolicy = policy
|
||||
|
||||
mockConn := s.connectors["mock"]
|
||||
conn = mockConn.Connector.(*mock.Callback)
|
||||
|
||||
@@ -1508,6 +1515,13 @@ func TestOAuth2DeviceFlow(t *testing.T) {
|
||||
})
|
||||
defer httpServer.Close()
|
||||
|
||||
policy, err := NewRefreshTokenPolicyFromConfig(s.logger, false, "", "", "")
|
||||
if err != nil {
|
||||
t.Fatalf("failed to prepare rotation policy: %v", err)
|
||||
}
|
||||
policy.Clock = now
|
||||
s.refreshTokenPolicy = policy
|
||||
|
||||
mockConn := s.connectors["mock"]
|
||||
conn = mockConn.Connector.(*mock.Callback)
|
||||
|
||||
|
Reference in New Issue
Block a user