Merge pull request #2692 from flant/refresh-once
fix: refresh token only once for all concurrent requests
This commit is contained in:
commit
e4bceef9f3
@ -28,6 +28,10 @@ type refreshError struct {
|
|||||||
desc string
|
desc string
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (r *refreshError) Error() string {
|
||||||
|
return fmt.Sprintf("refresh token error: status %d, %q %s", r.code, r.msg, r.desc)
|
||||||
|
}
|
||||||
|
|
||||||
func newInternalServerError() *refreshError {
|
func newInternalServerError() *refreshError {
|
||||||
return &refreshError{msg: errInvalidRequest, desc: "", code: http.StatusInternalServerError}
|
return &refreshError{msg: errInvalidRequest, desc: "", code: http.StatusInternalServerError}
|
||||||
}
|
}
|
||||||
@ -60,10 +64,23 @@ func (s *Server) extractRefreshTokenFromRequest(r *http.Request) (*internal.Refr
|
|||||||
return token, nil
|
return token, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
type refreshContext struct {
|
||||||
|
storageToken *storage.RefreshToken
|
||||||
|
requestToken *internal.RefreshToken
|
||||||
|
|
||||||
|
connector Connector
|
||||||
|
connectorData []byte
|
||||||
|
|
||||||
|
scopes []string
|
||||||
|
}
|
||||||
|
|
||||||
// getRefreshTokenFromStorage checks that refresh token is valid and exists in the storage and gets its info
|
// getRefreshTokenFromStorage checks that refresh token is valid and exists in the storage and gets its info
|
||||||
func (s *Server) getRefreshTokenFromStorage(clientID string, token *internal.RefreshToken) (*storage.RefreshToken, *refreshError) {
|
func (s *Server) getRefreshTokenFromStorage(clientID string, token *internal.RefreshToken) (*refreshContext, *refreshError) {
|
||||||
|
refreshCtx := refreshContext{requestToken: token}
|
||||||
|
|
||||||
invalidErr := newBadRequestError("Refresh token is invalid or has already been claimed by another client.")
|
invalidErr := newBadRequestError("Refresh token is invalid or has already been claimed by another client.")
|
||||||
|
|
||||||
|
// Get RefreshToken
|
||||||
refresh, err := s.storage.GetRefresh(token.RefreshId)
|
refresh, err := s.storage.GetRefresh(token.RefreshId)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
if err != storage.ErrNotFound {
|
if err != storage.ErrNotFound {
|
||||||
@ -103,7 +120,31 @@ func (s *Server) getRefreshTokenFromStorage(clientID string, token *internal.Ref
|
|||||||
return nil, expiredErr
|
return nil, expiredErr
|
||||||
}
|
}
|
||||||
|
|
||||||
return &refresh, nil
|
refreshCtx.storageToken = &refresh
|
||||||
|
|
||||||
|
// Get Connector
|
||||||
|
refreshCtx.connector, err = s.getConnector(refresh.ConnectorID)
|
||||||
|
if err != nil {
|
||||||
|
s.logger.Errorf("connector with ID %q not found: %v", refresh.ConnectorID, err)
|
||||||
|
return nil, newInternalServerError()
|
||||||
|
}
|
||||||
|
|
||||||
|
// Get Connector Data
|
||||||
|
session, err := s.storage.GetOfflineSessions(refresh.Claims.UserID, refresh.ConnectorID)
|
||||||
|
switch {
|
||||||
|
case err != nil:
|
||||||
|
if err != storage.ErrNotFound {
|
||||||
|
s.logger.Errorf("failed to get offline session: %v", err)
|
||||||
|
return nil, newInternalServerError()
|
||||||
|
}
|
||||||
|
case len(refresh.ConnectorData) > 0:
|
||||||
|
// Use the old connector data if it exists, should be deleted once used
|
||||||
|
refreshCtx.connectorData = refresh.ConnectorData
|
||||||
|
default:
|
||||||
|
refreshCtx.connectorData = session.ConnectorData
|
||||||
|
}
|
||||||
|
|
||||||
|
return &refreshCtx, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *Server) getRefreshScopes(r *http.Request, refresh *storage.RefreshToken) ([]string, *refreshError) {
|
func (s *Server) getRefreshScopes(r *http.Request, refresh *storage.RefreshToken) ([]string, *refreshError) {
|
||||||
@ -138,59 +179,23 @@ func (s *Server) getRefreshScopes(r *http.Request, refresh *storage.RefreshToken
|
|||||||
return requestedScopes, nil
|
return requestedScopes, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *Server) refreshWithConnector(ctx context.Context, token *internal.RefreshToken, refresh *storage.RefreshToken, scopes []string) (connector.Identity, *refreshError) {
|
func (s *Server) refreshWithConnector(ctx context.Context, rCtx *refreshContext, ident connector.Identity) (connector.Identity, *refreshError) {
|
||||||
var connectorData []byte
|
|
||||||
|
|
||||||
session, err := s.storage.GetOfflineSessions(refresh.Claims.UserID, refresh.ConnectorID)
|
|
||||||
switch {
|
|
||||||
case err != nil:
|
|
||||||
if err != storage.ErrNotFound {
|
|
||||||
s.logger.Errorf("failed to get offline session: %v", err)
|
|
||||||
return connector.Identity{}, newInternalServerError()
|
|
||||||
}
|
|
||||||
case len(refresh.ConnectorData) > 0:
|
|
||||||
// Use the old connector data if it exists, should be deleted once used
|
|
||||||
connectorData = refresh.ConnectorData
|
|
||||||
default:
|
|
||||||
connectorData = session.ConnectorData
|
|
||||||
}
|
|
||||||
|
|
||||||
conn, err := s.getConnector(refresh.ConnectorID)
|
|
||||||
if err != nil {
|
|
||||||
s.logger.Errorf("connector with ID %q not found: %v", refresh.ConnectorID, err)
|
|
||||||
return connector.Identity{}, newInternalServerError()
|
|
||||||
}
|
|
||||||
|
|
||||||
ident := connector.Identity{
|
|
||||||
UserID: refresh.Claims.UserID,
|
|
||||||
Username: refresh.Claims.Username,
|
|
||||||
PreferredUsername: refresh.Claims.PreferredUsername,
|
|
||||||
Email: refresh.Claims.Email,
|
|
||||||
EmailVerified: refresh.Claims.EmailVerified,
|
|
||||||
Groups: refresh.Claims.Groups,
|
|
||||||
ConnectorData: connectorData,
|
|
||||||
}
|
|
||||||
|
|
||||||
// user's token was previously updated by a connector and is allowed to reuse
|
|
||||||
// it is excessive to refresh identity in upstream
|
|
||||||
if s.refreshTokenPolicy.AllowedToReuse(refresh.LastUsed) && token.Token == refresh.ObsoleteToken {
|
|
||||||
return ident, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// Can the connector refresh the identity? If so, attempt to refresh the data
|
// Can the connector refresh the identity? If so, attempt to refresh the data
|
||||||
// in the connector.
|
// in the connector.
|
||||||
//
|
//
|
||||||
// TODO(ericchiang): We may want a strict mode where connectors that don't implement
|
// TODO(ericchiang): We may want a strict mode where connectors that don't implement
|
||||||
// this interface can't perform refreshing.
|
// this interface can't perform refreshing.
|
||||||
if refreshConn, ok := conn.Connector.(connector.RefreshConnector); ok {
|
if refreshConn, ok := rCtx.connector.Connector.(connector.RefreshConnector); ok {
|
||||||
newIdent, err := refreshConn.Refresh(ctx, parseScopes(scopes), ident)
|
s.logger.Debugf("connector data before refresh: %s", ident.ConnectorData)
|
||||||
|
|
||||||
|
newIdent, err := refreshConn.Refresh(ctx, parseScopes(rCtx.scopes), ident)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
s.logger.Errorf("failed to refresh identity: %v", err)
|
s.logger.Errorf("failed to refresh identity: %v", err)
|
||||||
return connector.Identity{}, newInternalServerError()
|
return ident, newInternalServerError()
|
||||||
}
|
|
||||||
ident = newIdent
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
return newIdent, nil
|
||||||
|
}
|
||||||
return ident, nil
|
return ident, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -200,8 +205,14 @@ func (s *Server) updateOfflineSession(refresh *storage.RefreshToken, ident conne
|
|||||||
if old.Refresh[refresh.ClientID].ID != refresh.ID {
|
if old.Refresh[refresh.ClientID].ID != refresh.ID {
|
||||||
return old, errors.New("refresh token invalid")
|
return old, errors.New("refresh token invalid")
|
||||||
}
|
}
|
||||||
|
|
||||||
old.Refresh[refresh.ClientID].LastUsed = lastUsed
|
old.Refresh[refresh.ClientID].LastUsed = lastUsed
|
||||||
|
if len(ident.ConnectorData) > 0 {
|
||||||
old.ConnectorData = ident.ConnectorData
|
old.ConnectorData = ident.ConnectorData
|
||||||
|
}
|
||||||
|
|
||||||
|
s.logger.Debugf("saved connector data: %s %s", ident.UserID, ident.ConnectorData)
|
||||||
|
|
||||||
return old, nil
|
return old, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -217,33 +228,74 @@ func (s *Server) updateOfflineSession(refresh *storage.RefreshToken, ident conne
|
|||||||
}
|
}
|
||||||
|
|
||||||
// updateRefreshToken updates refresh token and offline session in the storage
|
// updateRefreshToken updates refresh token and offline session in the storage
|
||||||
func (s *Server) updateRefreshToken(token *internal.RefreshToken, refresh *storage.RefreshToken, ident connector.Identity) (*internal.RefreshToken, *refreshError) {
|
func (s *Server) updateRefreshToken(ctx context.Context, rCtx *refreshContext) (*internal.RefreshToken, connector.Identity, *refreshError) {
|
||||||
newToken := token
|
var rerr *refreshError
|
||||||
if s.refreshTokenPolicy.RotationEnabled() {
|
|
||||||
newToken = &internal.RefreshToken{
|
newToken := &internal.RefreshToken{
|
||||||
RefreshId: refresh.ID,
|
Token: rCtx.requestToken.Token,
|
||||||
Token: storage.NewID(),
|
RefreshId: rCtx.requestToken.RefreshId,
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
lastUsed := s.now()
|
lastUsed := s.now()
|
||||||
|
|
||||||
refreshTokenUpdater := func(old storage.RefreshToken) (storage.RefreshToken, error) {
|
ident := connector.Identity{
|
||||||
if s.refreshTokenPolicy.RotationEnabled() {
|
UserID: rCtx.storageToken.Claims.UserID,
|
||||||
if old.Token != token.Token {
|
Username: rCtx.storageToken.Claims.Username,
|
||||||
if s.refreshTokenPolicy.AllowedToReuse(old.LastUsed) && old.ObsoleteToken == token.Token {
|
PreferredUsername: rCtx.storageToken.Claims.PreferredUsername,
|
||||||
newToken.Token = old.Token
|
Email: rCtx.storageToken.Claims.Email,
|
||||||
// Do not update last used time for offline session if token is allowed to be reused
|
EmailVerified: rCtx.storageToken.Claims.EmailVerified,
|
||||||
lastUsed = old.LastUsed
|
Groups: rCtx.storageToken.Claims.Groups,
|
||||||
return old, nil
|
ConnectorData: rCtx.connectorData,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
refreshTokenUpdater := func(old storage.RefreshToken) (storage.RefreshToken, error) {
|
||||||
|
rotationEnabled := s.refreshTokenPolicy.RotationEnabled()
|
||||||
|
reusingAllowed := s.refreshTokenPolicy.AllowedToReuse(old.LastUsed)
|
||||||
|
|
||||||
|
switch {
|
||||||
|
case !rotationEnabled && reusingAllowed:
|
||||||
|
// If rotation is disabled and the offline session was updated not so long ago - skip further actions.
|
||||||
|
return old, nil
|
||||||
|
|
||||||
|
case rotationEnabled && reusingAllowed:
|
||||||
|
if old.Token != rCtx.requestToken.Token && old.ObsoleteToken != rCtx.requestToken.Token {
|
||||||
return old, errors.New("refresh token claimed twice")
|
return old, errors.New("refresh token claimed twice")
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Return previously generated token for all requests with an obsolete tokens
|
||||||
|
if old.ObsoleteToken == rCtx.requestToken.Token {
|
||||||
|
newToken.Token = old.Token
|
||||||
|
}
|
||||||
|
|
||||||
|
// Do not update last used time for offline session if token is allowed to be reused
|
||||||
|
lastUsed = old.LastUsed
|
||||||
|
ident.ConnectorData = nil
|
||||||
|
return old, nil
|
||||||
|
|
||||||
|
case rotationEnabled && !reusingAllowed:
|
||||||
|
if old.Token != rCtx.requestToken.Token {
|
||||||
|
return old, errors.New("refresh token claimed twice")
|
||||||
|
}
|
||||||
|
|
||||||
|
// Issue new refresh token
|
||||||
old.ObsoleteToken = old.Token
|
old.ObsoleteToken = old.Token
|
||||||
|
newToken.Token = storage.NewID()
|
||||||
}
|
}
|
||||||
|
|
||||||
old.Token = newToken.Token
|
old.Token = newToken.Token
|
||||||
|
old.LastUsed = lastUsed
|
||||||
|
|
||||||
|
// ConnectorData has been moved to OfflineSession
|
||||||
|
old.ConnectorData = []byte{}
|
||||||
|
|
||||||
|
// Call only once if there is a request which is not in the reuse interval.
|
||||||
|
// This is required to avoid multiple calls to the external IdP for concurrent requests.
|
||||||
|
// Dex will call the connector's Refresh method only once if request is not in reuse interval.
|
||||||
|
ident, rerr = s.refreshWithConnector(ctx, rCtx, ident)
|
||||||
|
if rerr != nil {
|
||||||
|
return old, rerr
|
||||||
|
}
|
||||||
|
|
||||||
// Update the claims of the refresh token.
|
// Update the claims of the refresh token.
|
||||||
//
|
//
|
||||||
// UserID intentionally ignored for now.
|
// UserID intentionally ignored for now.
|
||||||
@ -252,26 +304,23 @@ func (s *Server) updateRefreshToken(token *internal.RefreshToken, refresh *stora
|
|||||||
old.Claims.Email = ident.Email
|
old.Claims.Email = ident.Email
|
||||||
old.Claims.EmailVerified = ident.EmailVerified
|
old.Claims.EmailVerified = ident.EmailVerified
|
||||||
old.Claims.Groups = ident.Groups
|
old.Claims.Groups = ident.Groups
|
||||||
old.LastUsed = lastUsed
|
|
||||||
|
|
||||||
// ConnectorData has been moved to OfflineSession
|
|
||||||
old.ConnectorData = []byte{}
|
|
||||||
return old, nil
|
return old, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// Update refresh token in the storage.
|
// Update refresh token in the storage.
|
||||||
err := s.storage.UpdateRefreshToken(refresh.ID, refreshTokenUpdater)
|
err := s.storage.UpdateRefreshToken(rCtx.storageToken.ID, refreshTokenUpdater)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
s.logger.Errorf("failed to update refresh token: %v", err)
|
s.logger.Errorf("failed to update refresh token: %v", err)
|
||||||
return nil, newInternalServerError()
|
return nil, ident, newInternalServerError()
|
||||||
}
|
}
|
||||||
|
|
||||||
rerr := s.updateOfflineSession(refresh, ident, lastUsed)
|
rerr = s.updateOfflineSession(rCtx.storageToken, ident, lastUsed)
|
||||||
if rerr != nil {
|
if rerr != nil {
|
||||||
return nil, rerr
|
return nil, ident, rerr
|
||||||
}
|
}
|
||||||
|
|
||||||
return newToken, nil
|
return newToken, ident, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// handleRefreshToken handles a refresh token request https://tools.ietf.org/html/rfc6749#section-6
|
// handleRefreshToken handles a refresh token request https://tools.ietf.org/html/rfc6749#section-6
|
||||||
@ -283,19 +332,19 @@ func (s *Server) handleRefreshToken(w http.ResponseWriter, r *http.Request, clie
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
refresh, rerr := s.getRefreshTokenFromStorage(client.ID, token)
|
rCtx, rerr := s.getRefreshTokenFromStorage(client.ID, token)
|
||||||
if rerr != nil {
|
if rerr != nil {
|
||||||
s.refreshTokenErrHelper(w, rerr)
|
s.refreshTokenErrHelper(w, rerr)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
scopes, rerr := s.getRefreshScopes(r, refresh)
|
rCtx.scopes, rerr = s.getRefreshScopes(r, rCtx.storageToken)
|
||||||
if rerr != nil {
|
if rerr != nil {
|
||||||
s.refreshTokenErrHelper(w, rerr)
|
s.refreshTokenErrHelper(w, rerr)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
ident, rerr := s.refreshWithConnector(r.Context(), token, refresh, scopes)
|
newToken, ident, rerr := s.updateRefreshToken(r.Context(), rCtx)
|
||||||
if rerr != nil {
|
if rerr != nil {
|
||||||
s.refreshTokenErrHelper(w, rerr)
|
s.refreshTokenErrHelper(w, rerr)
|
||||||
return
|
return
|
||||||
@ -310,26 +359,20 @@ func (s *Server) handleRefreshToken(w http.ResponseWriter, r *http.Request, clie
|
|||||||
Groups: ident.Groups,
|
Groups: ident.Groups,
|
||||||
}
|
}
|
||||||
|
|
||||||
accessToken, err := s.newAccessToken(client.ID, claims, scopes, refresh.Nonce, refresh.ConnectorID)
|
accessToken, err := s.newAccessToken(client.ID, claims, rCtx.scopes, rCtx.storageToken.Nonce, rCtx.storageToken.ConnectorID)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
s.logger.Errorf("failed to create new access token: %v", err)
|
s.logger.Errorf("failed to create new access token: %v", err)
|
||||||
s.refreshTokenErrHelper(w, newInternalServerError())
|
s.refreshTokenErrHelper(w, newInternalServerError())
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
idToken, expiry, err := s.newIDToken(client.ID, claims, scopes, refresh.Nonce, accessToken, "", refresh.ConnectorID)
|
idToken, expiry, err := s.newIDToken(client.ID, claims, rCtx.scopes, rCtx.storageToken.Nonce, accessToken, "", rCtx.storageToken.ConnectorID)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
s.logger.Errorf("failed to create ID token: %v", err)
|
s.logger.Errorf("failed to create ID token: %v", err)
|
||||||
s.refreshTokenErrHelper(w, newInternalServerError())
|
s.refreshTokenErrHelper(w, newInternalServerError())
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
newToken, rerr := s.updateRefreshToken(token, refresh, ident)
|
|
||||||
if rerr != nil {
|
|
||||||
s.refreshTokenErrHelper(w, rerr)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
rawNewToken, err := internal.Marshal(newToken)
|
rawNewToken, err := internal.Marshal(newToken)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
s.logger.Errorf("failed to marshal refresh token: %v", err)
|
s.logger.Errorf("failed to marshal refresh token: %v", err)
|
||||||
|
124
storage/kubernetes/lock.go
Normal file
124
storage/kubernetes/lock.go
Normal file
@ -0,0 +1,124 @@
|
|||||||
|
package kubernetes
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
"time"
|
||||||
|
)
|
||||||
|
|
||||||
|
const (
|
||||||
|
lockAnnotation = "dexidp.com/resource-lock"
|
||||||
|
lockTimeFormat = time.RFC3339
|
||||||
|
)
|
||||||
|
|
||||||
|
var (
|
||||||
|
lockTimeout = 10 * time.Second
|
||||||
|
lockCheckPeriod = 100 * time.Millisecond
|
||||||
|
)
|
||||||
|
|
||||||
|
// refreshTokenLock is an implementation of annotation-based optimistic locking.
|
||||||
|
//
|
||||||
|
// Refresh token contains data to refresh identity in external authentication system.
|
||||||
|
// There is a requirement that refresh should be called only once because of several reasons:
|
||||||
|
// * Some of OIDC providers could use the refresh token rotation feature which requires calling refresh only once.
|
||||||
|
// * Providers can limit the rate of requests to the token endpoint, which will lead to the error
|
||||||
|
// in case of many concurrent requests.
|
||||||
|
type refreshTokenLock struct {
|
||||||
|
cli *client
|
||||||
|
waitingState bool
|
||||||
|
}
|
||||||
|
|
||||||
|
func newRefreshTokenLock(cli *client) *refreshTokenLock {
|
||||||
|
return &refreshTokenLock{cli: cli}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (l *refreshTokenLock) Lock(id string) error {
|
||||||
|
for i := 0; i <= 60; i++ {
|
||||||
|
ok, err := l.setLockAnnotation(id)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
if !ok {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
time.Sleep(lockCheckPeriod)
|
||||||
|
}
|
||||||
|
return fmt.Errorf("timeout waiting for refresh token %s lock", id)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (l *refreshTokenLock) Unlock(id string) {
|
||||||
|
if l.waitingState {
|
||||||
|
// Do not need to unlock for waiting goroutines, because the have not set it.
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
r, err := l.cli.getRefreshToken(id)
|
||||||
|
if err != nil {
|
||||||
|
l.cli.logger.Debugf("failed to get resource to release lock for refresh token %s: %v", id, err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
r.Annotations = nil
|
||||||
|
err = l.cli.put(resourceRefreshToken, r.ObjectMeta.Name, r)
|
||||||
|
if err != nil {
|
||||||
|
l.cli.logger.Debugf("failed to release lock for refresh token %s: %v", id, err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (l *refreshTokenLock) setLockAnnotation(id string) (bool, error) {
|
||||||
|
r, err := l.cli.getRefreshToken(id)
|
||||||
|
if err != nil {
|
||||||
|
return false, err
|
||||||
|
}
|
||||||
|
|
||||||
|
currentTime := time.Now()
|
||||||
|
lockData := map[string]string{
|
||||||
|
lockAnnotation: currentTime.Add(lockTimeout).Format(lockTimeFormat),
|
||||||
|
}
|
||||||
|
|
||||||
|
val, ok := r.Annotations[lockAnnotation]
|
||||||
|
if !ok {
|
||||||
|
if l.waitingState {
|
||||||
|
return false, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
r.Annotations = lockData
|
||||||
|
err := l.cli.put(resourceRefreshToken, r.ObjectMeta.Name, r)
|
||||||
|
if err == nil {
|
||||||
|
return false, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
if isKubernetesAPIConflictError(err) {
|
||||||
|
l.waitingState = true
|
||||||
|
return true, nil
|
||||||
|
}
|
||||||
|
return false, err
|
||||||
|
}
|
||||||
|
|
||||||
|
until, err := time.Parse(lockTimeFormat, val)
|
||||||
|
if err != nil {
|
||||||
|
return false, fmt.Errorf("lock annotation value is malformed: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if !currentTime.After(until) {
|
||||||
|
// waiting for the lock to be released
|
||||||
|
l.waitingState = true
|
||||||
|
return true, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Lock time is out, lets break the lock and take the advantage
|
||||||
|
r.Annotations = lockData
|
||||||
|
|
||||||
|
err = l.cli.put(resourceRefreshToken, r.ObjectMeta.Name, r)
|
||||||
|
if err == nil {
|
||||||
|
// break lock annotation
|
||||||
|
return false, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
l.cli.logger.Debugf("break lock annotation error: %v", err)
|
||||||
|
if isKubernetesAPIConflictError(err) {
|
||||||
|
l.waitingState = true
|
||||||
|
// after breaking error waiting for the lock to be released
|
||||||
|
return true, nil
|
||||||
|
}
|
||||||
|
return false, err
|
||||||
|
}
|
@ -451,11 +451,19 @@ func (cli *client) DeleteConnector(id string) error {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (cli *client) UpdateRefreshToken(id string, updater func(old storage.RefreshToken) (storage.RefreshToken, error)) error {
|
func (cli *client) UpdateRefreshToken(id string, updater func(old storage.RefreshToken) (storage.RefreshToken, error)) error {
|
||||||
|
lock := newRefreshTokenLock(cli)
|
||||||
|
|
||||||
|
if err := lock.Lock(id); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
defer lock.Unlock(id)
|
||||||
|
|
||||||
return retryOnConflict(context.TODO(), func() error {
|
return retryOnConflict(context.TODO(), func() error {
|
||||||
r, err := cli.getRefreshToken(id)
|
r, err := cli.getRefreshToken(id)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
updated, err := updater(toStorageRefreshToken(r))
|
updated, err := updater(toStorageRefreshToken(r))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
@ -464,6 +472,7 @@ func (cli *client) UpdateRefreshToken(id string, updater func(old storage.Refres
|
|||||||
|
|
||||||
newToken := cli.fromStorageRefreshToken(updated)
|
newToken := cli.fromStorageRefreshToken(updated)
|
||||||
newToken.ObjectMeta = r.ObjectMeta
|
newToken.ObjectMeta = r.ObjectMeta
|
||||||
|
|
||||||
return cli.put(resourceRefreshToken, r.ObjectMeta.Name, newToken)
|
return cli.put(resourceRefreshToken, r.ObjectMeta.Name, newToken)
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
@ -11,6 +11,7 @@ import (
|
|||||||
"path/filepath"
|
"path/filepath"
|
||||||
"strings"
|
"strings"
|
||||||
"testing"
|
"testing"
|
||||||
|
"time"
|
||||||
|
|
||||||
"github.com/sirupsen/logrus"
|
"github.com/sirupsen/logrus"
|
||||||
"github.com/stretchr/testify/require"
|
"github.com/stretchr/testify/require"
|
||||||
@ -35,19 +36,22 @@ type StorageTestSuite struct {
|
|||||||
client *client
|
client *client
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *StorageTestSuite) expandDir(dir string) string {
|
func expandDir(dir string) (string, error) {
|
||||||
dir = strings.Trim(dir, `"`)
|
dir = strings.Trim(dir, `"`)
|
||||||
if strings.HasPrefix(dir, "~/") {
|
if strings.HasPrefix(dir, "~/") {
|
||||||
homedir, err := os.UserHomeDir()
|
homedir, err := os.UserHomeDir()
|
||||||
s.Require().NoError(err)
|
if err != nil {
|
||||||
|
return "", err
|
||||||
|
}
|
||||||
|
|
||||||
dir = filepath.Join(homedir, strings.TrimPrefix(dir, "~/"))
|
dir = filepath.Join(homedir, strings.TrimPrefix(dir, "~/"))
|
||||||
}
|
}
|
||||||
return dir
|
return dir, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *StorageTestSuite) SetupTest() {
|
func (s *StorageTestSuite) SetupTest() {
|
||||||
kubeconfigPath := s.expandDir(os.Getenv(kubeconfigPathVariableName))
|
kubeconfigPath, err := expandDir(os.Getenv(kubeconfigPathVariableName))
|
||||||
|
s.Require().NoError(err)
|
||||||
|
|
||||||
config := Config{
|
config := Config{
|
||||||
KubeConfigFile: kubeconfigPath,
|
KubeConfigFile: kubeconfigPath,
|
||||||
@ -292,3 +296,95 @@ func TestRetryOnConflict(t *testing.T) {
|
|||||||
})
|
})
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestRefreshTokenLock(t *testing.T) {
|
||||||
|
if os.Getenv(kubeconfigPathVariableName) == "" {
|
||||||
|
t.Skipf("variable %q not set, skipping kubernetes storage tests\n", kubeconfigPathVariableName)
|
||||||
|
}
|
||||||
|
|
||||||
|
kubeconfigPath, err := expandDir(os.Getenv(kubeconfigPathVariableName))
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
config := Config{
|
||||||
|
KubeConfigFile: kubeconfigPath,
|
||||||
|
}
|
||||||
|
|
||||||
|
logger := &logrus.Logger{
|
||||||
|
Out: os.Stderr,
|
||||||
|
Formatter: &logrus.TextFormatter{DisableColors: true},
|
||||||
|
Level: logrus.DebugLevel,
|
||||||
|
}
|
||||||
|
|
||||||
|
kubeClient, err := config.open(logger, true)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
lockCheckPeriod = time.Nanosecond
|
||||||
|
|
||||||
|
// Creating a storage with an existing refresh token and offline session for the user.
|
||||||
|
id := storage.NewID()
|
||||||
|
r := storage.RefreshToken{
|
||||||
|
ID: id,
|
||||||
|
Token: "bar",
|
||||||
|
Nonce: "foo",
|
||||||
|
ClientID: "client_id",
|
||||||
|
ConnectorID: "client_secret",
|
||||||
|
Scopes: []string{"openid", "email", "profile"},
|
||||||
|
CreatedAt: time.Now().UTC().Round(time.Millisecond),
|
||||||
|
LastUsed: time.Now().UTC().Round(time.Millisecond),
|
||||||
|
Claims: storage.Claims{
|
||||||
|
UserID: "1",
|
||||||
|
Username: "jane",
|
||||||
|
Email: "jane.doe@example.com",
|
||||||
|
EmailVerified: true,
|
||||||
|
Groups: []string{"a", "b"},
|
||||||
|
},
|
||||||
|
ConnectorData: []byte(`{"some":"data"}`),
|
||||||
|
}
|
||||||
|
|
||||||
|
err = kubeClient.CreateRefresh(r)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
t.Run("Timeout lock error", func(t *testing.T) {
|
||||||
|
err = kubeClient.UpdateRefreshToken(r.ID, func(r storage.RefreshToken) (storage.RefreshToken, error) {
|
||||||
|
r.Token = "update-result-1"
|
||||||
|
err := kubeClient.UpdateRefreshToken(r.ID, func(r storage.RefreshToken) (storage.RefreshToken, error) {
|
||||||
|
r.Token = "timeout-err"
|
||||||
|
return r, nil
|
||||||
|
})
|
||||||
|
require.Equal(t, fmt.Errorf("timeout waiting for refresh token %s lock", r.ID), err)
|
||||||
|
return r, nil
|
||||||
|
})
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
token, err := kubeClient.GetRefresh(r.ID)
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.Equal(t, "update-result-1", token.Token)
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("Break the lock", func(t *testing.T) {
|
||||||
|
var lockBroken bool
|
||||||
|
lockTimeout = -time.Hour
|
||||||
|
|
||||||
|
err = kubeClient.UpdateRefreshToken(r.ID, func(r storage.RefreshToken) (storage.RefreshToken, error) {
|
||||||
|
r.Token = "update-result-2"
|
||||||
|
if lockBroken {
|
||||||
|
return r, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
err := kubeClient.UpdateRefreshToken(r.ID, func(r storage.RefreshToken) (storage.RefreshToken, error) {
|
||||||
|
r.Token = "should-break-the-lock-and-finish-updating"
|
||||||
|
return r, nil
|
||||||
|
})
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
lockBroken = true
|
||||||
|
return r, nil
|
||||||
|
})
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
token, err := kubeClient.GetRefresh(r.ID)
|
||||||
|
require.NoError(t, err)
|
||||||
|
// Because concurrent update breaks the lock, the final result will be the value of the first update
|
||||||
|
require.Equal(t, "update-result-2", token.Token)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
Reference in New Issue
Block a user