fix: refresh token only once for all concurrent requests
Signed-off-by: m.nabokikh <maksim.nabokikh@flant.com>
This commit is contained in:
		
							
								
								
									
										124
									
								
								storage/kubernetes/lock.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										124
									
								
								storage/kubernetes/lock.go
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,124 @@ | ||||
| package kubernetes | ||||
|  | ||||
| import ( | ||||
| 	"fmt" | ||||
| 	"time" | ||||
| ) | ||||
|  | ||||
const (
	// lockAnnotation is the annotation key written onto the refresh token
	// resource to mark it as locked.
	lockAnnotation = "dexidp.com/resource-lock"
	// lockTimeFormat is the layout used to serialize the lock expiry
	// timestamp stored in the annotation value.
	lockTimeFormat = time.RFC3339
)

var (
	// lockTimeout is how long a held lock is honored before another writer
	// may break it. lockCheckPeriod is the polling interval while waiting
	// for a lock to be released. Declared as vars (not consts) so tests can
	// tune them.
	lockTimeout     = 10 * time.Second
	lockCheckPeriod = 100 * time.Millisecond
)
|  | ||||
// refreshTokenLock is an implementation of annotation-based optimistic locking.
//
// Refresh token contains data to refresh identity in external authentication system.
// There is a requirement that refresh should be called only once because of several reasons:
// * Some of OIDC providers could use the refresh token rotation feature which requires calling refresh only once.
// * Providers can limit the rate of requests to the token endpoint, which will lead to the error
//   in case of many concurrent requests.
type refreshTokenLock struct {
	// cli is the Kubernetes storage client used to read and write the
	// refresh token resource carrying the lock annotation.
	cli *client
	// waitingState records that this instance lost the race for the lock
	// and is only waiting for it to be released; a waiting instance never
	// set the annotation, so Unlock must not remove it.
	waitingState bool
}
|  | ||||
| func newRefreshTokenLock(cli *client) *refreshTokenLock { | ||||
| 	return &refreshTokenLock{cli: cli} | ||||
| } | ||||
|  | ||||
| func (l *refreshTokenLock) Lock(id string) error { | ||||
| 	for i := 0; i <= 60; i++ { | ||||
| 		ok, err := l.setLockAnnotation(id) | ||||
| 		if err != nil { | ||||
| 			return err | ||||
| 		} | ||||
| 		if !ok { | ||||
| 			return nil | ||||
| 		} | ||||
| 		time.Sleep(lockCheckPeriod) | ||||
| 	} | ||||
| 	return fmt.Errorf("timeout waiting for refresh token %s lock", id) | ||||
| } | ||||
|  | ||||
| func (l *refreshTokenLock) Unlock(id string) { | ||||
| 	if l.waitingState { | ||||
| 		// Do not need to unlock for waiting goroutines, because the have not set it. | ||||
| 		return | ||||
| 	} | ||||
|  | ||||
| 	r, err := l.cli.getRefreshToken(id) | ||||
| 	if err != nil { | ||||
| 		l.cli.logger.Debugf("failed to get resource to release lock for refresh token %s: %v", id, err) | ||||
| 		return | ||||
| 	} | ||||
|  | ||||
| 	r.Annotations = nil | ||||
| 	err = l.cli.put(resourceRefreshToken, r.ObjectMeta.Name, r) | ||||
| 	if err != nil { | ||||
| 		l.cli.logger.Debugf("failed to release lock for refresh token %s: %v", id, err) | ||||
| 	} | ||||
| } | ||||
|  | ||||
// setLockAnnotation attempts to acquire the lock for the refresh token with
// the given id by writing an expiry timestamp into its lock annotation,
// relying on the Kubernetes API's optimistic concurrency (resource version
// conflicts) to serialize competing writers.
//
// Return values:
//   - (false, nil): the caller may proceed — it either holds the lock, broke
//     an expired one, or (in waiting state) observed the lock released.
//   - (true, nil): another writer holds the lock; the caller should sleep
//     and retry.
//   - (false, err): an unrecoverable error occurred.
func (l *refreshTokenLock) setLockAnnotation(id string) (bool, error) {
	r, err := l.cli.getRefreshToken(id)
	if err != nil {
		return false, err
	}

	currentTime := time.Now()
	lockData := map[string]string{
		lockAnnotation: currentTime.Add(lockTimeout).Format(lockTimeFormat),
	}

	val, ok := r.Annotations[lockAnnotation]
	if !ok {
		if l.waitingState {
			// The annotation disappeared while we were waiting: the holder
			// released the lock, so the update we waited for has finished.
			return false, nil
		}

		// No lock present — try to claim it.
		r.Annotations = lockData
		err := l.cli.put(resourceRefreshToken, r.ObjectMeta.Name, r)
		if err == nil {
			return false, nil
		}

		if isKubernetesAPIConflictError(err) {
			// Another writer won the optimistic-concurrency race; wait for it.
			l.waitingState = true
			return true, nil
		}
		return false, err
	}

	until, err := time.Parse(lockTimeFormat, val)
	if err != nil {
		return false, fmt.Errorf("lock annotation value is malformed: %v", err)
	}

	if !currentTime.After(until) {
		// Lock has not expired yet: waiting for the lock to be released.
		l.waitingState = true
		return true, nil
	}

	// Lock time is out, let's break the lock and take the advantage.
	r.Annotations = lockData

	err = l.cli.put(resourceRefreshToken, r.ObjectMeta.Name, r)
	if err == nil {
		// Successfully broke the stale lock; the caller now holds it.
		return false, nil
	}

	l.cli.logger.Debugf("break lock annotation error: %v", err)
	if isKubernetesAPIConflictError(err) {
		// Someone else broke or refreshed the lock first — after the breaking
		// conflict, wait for that lock to be released.
		l.waitingState = true
		return true, nil
	}
	return false, err
}
| @@ -451,11 +451,19 @@ func (cli *client) DeleteConnector(id string) error { | ||||
| } | ||||
|  | ||||
| func (cli *client) UpdateRefreshToken(id string, updater func(old storage.RefreshToken) (storage.RefreshToken, error)) error { | ||||
| 	lock := newRefreshTokenLock(cli) | ||||
|  | ||||
| 	if err := lock.Lock(id); err != nil { | ||||
| 		return err | ||||
| 	} | ||||
| 	defer lock.Unlock(id) | ||||
|  | ||||
| 	return retryOnConflict(context.TODO(), func() error { | ||||
| 		r, err := cli.getRefreshToken(id) | ||||
| 		if err != nil { | ||||
| 			return err | ||||
| 		} | ||||
|  | ||||
| 		updated, err := updater(toStorageRefreshToken(r)) | ||||
| 		if err != nil { | ||||
| 			return err | ||||
| @@ -464,6 +472,7 @@ func (cli *client) UpdateRefreshToken(id string, updater func(old storage.Refres | ||||
|  | ||||
| 		newToken := cli.fromStorageRefreshToken(updated) | ||||
| 		newToken.ObjectMeta = r.ObjectMeta | ||||
|  | ||||
| 		return cli.put(resourceRefreshToken, r.ObjectMeta.Name, newToken) | ||||
| 	}) | ||||
| } | ||||
|   | ||||
| @@ -11,6 +11,7 @@ import ( | ||||
| 	"path/filepath" | ||||
| 	"strings" | ||||
| 	"testing" | ||||
| 	"time" | ||||
|  | ||||
| 	"github.com/sirupsen/logrus" | ||||
| 	"github.com/stretchr/testify/require" | ||||
| @@ -35,19 +36,22 @@ type StorageTestSuite struct { | ||||
| 	client *client | ||||
| } | ||||
|  | ||||
| func (s *StorageTestSuite) expandDir(dir string) string { | ||||
| func expandDir(dir string) (string, error) { | ||||
| 	dir = strings.Trim(dir, `"`) | ||||
| 	if strings.HasPrefix(dir, "~/") { | ||||
| 		homedir, err := os.UserHomeDir() | ||||
| 		s.Require().NoError(err) | ||||
| 		if err != nil { | ||||
| 			return "", err | ||||
| 		} | ||||
|  | ||||
| 		dir = filepath.Join(homedir, strings.TrimPrefix(dir, "~/")) | ||||
| 	} | ||||
| 	return dir | ||||
| 	return dir, nil | ||||
| } | ||||
|  | ||||
| func (s *StorageTestSuite) SetupTest() { | ||||
| 	kubeconfigPath := s.expandDir(os.Getenv(kubeconfigPathVariableName)) | ||||
| 	kubeconfigPath, err := expandDir(os.Getenv(kubeconfigPathVariableName)) | ||||
| 	s.Require().NoError(err) | ||||
|  | ||||
| 	config := Config{ | ||||
| 		KubeConfigFile: kubeconfigPath, | ||||
| @@ -292,3 +296,95 @@ func TestRetryOnConflict(t *testing.T) { | ||||
| 		}) | ||||
| 	} | ||||
| } | ||||
|  | ||||
| func TestRefreshTokenLock(t *testing.T) { | ||||
| 	if os.Getenv(kubeconfigPathVariableName) == "" { | ||||
| 		t.Skipf("variable %q not set, skipping kubernetes storage tests\n", kubeconfigPathVariableName) | ||||
| 	} | ||||
|  | ||||
| 	kubeconfigPath, err := expandDir(os.Getenv(kubeconfigPathVariableName)) | ||||
| 	require.NoError(t, err) | ||||
|  | ||||
| 	config := Config{ | ||||
| 		KubeConfigFile: kubeconfigPath, | ||||
| 	} | ||||
|  | ||||
| 	logger := &logrus.Logger{ | ||||
| 		Out:       os.Stderr, | ||||
| 		Formatter: &logrus.TextFormatter{DisableColors: true}, | ||||
| 		Level:     logrus.DebugLevel, | ||||
| 	} | ||||
|  | ||||
| 	kubeClient, err := config.open(logger, true) | ||||
| 	require.NoError(t, err) | ||||
|  | ||||
| 	lockCheckPeriod = time.Nanosecond | ||||
|  | ||||
| 	// Creating a storage with an existing refresh token and offline session for the user. | ||||
| 	id := storage.NewID() | ||||
| 	r := storage.RefreshToken{ | ||||
| 		ID:          id, | ||||
| 		Token:       "bar", | ||||
| 		Nonce:       "foo", | ||||
| 		ClientID:    "client_id", | ||||
| 		ConnectorID: "client_secret", | ||||
| 		Scopes:      []string{"openid", "email", "profile"}, | ||||
| 		CreatedAt:   time.Now().UTC().Round(time.Millisecond), | ||||
| 		LastUsed:    time.Now().UTC().Round(time.Millisecond), | ||||
| 		Claims: storage.Claims{ | ||||
| 			UserID:        "1", | ||||
| 			Username:      "jane", | ||||
| 			Email:         "jane.doe@example.com", | ||||
| 			EmailVerified: true, | ||||
| 			Groups:        []string{"a", "b"}, | ||||
| 		}, | ||||
| 		ConnectorData: []byte(`{"some":"data"}`), | ||||
| 	} | ||||
|  | ||||
| 	err = kubeClient.CreateRefresh(r) | ||||
| 	require.NoError(t, err) | ||||
|  | ||||
| 	t.Run("Timeout lock error", func(t *testing.T) { | ||||
| 		err = kubeClient.UpdateRefreshToken(r.ID, func(r storage.RefreshToken) (storage.RefreshToken, error) { | ||||
| 			r.Token = "update-result-1" | ||||
| 			err := kubeClient.UpdateRefreshToken(r.ID, func(r storage.RefreshToken) (storage.RefreshToken, error) { | ||||
| 				r.Token = "timeout-err" | ||||
| 				return r, nil | ||||
| 			}) | ||||
| 			require.Equal(t, fmt.Errorf("timeout waiting for refresh token %s lock", r.ID), err) | ||||
| 			return r, nil | ||||
| 		}) | ||||
| 		require.NoError(t, err) | ||||
|  | ||||
| 		token, err := kubeClient.GetRefresh(r.ID) | ||||
| 		require.NoError(t, err) | ||||
| 		require.Equal(t, "update-result-1", token.Token) | ||||
| 	}) | ||||
|  | ||||
| 	t.Run("Break the lock", func(t *testing.T) { | ||||
| 		var lockBroken bool | ||||
| 		lockTimeout = -time.Hour | ||||
|  | ||||
| 		err = kubeClient.UpdateRefreshToken(r.ID, func(r storage.RefreshToken) (storage.RefreshToken, error) { | ||||
| 			r.Token = "update-result-2" | ||||
| 			if lockBroken { | ||||
| 				return r, nil | ||||
| 			} | ||||
|  | ||||
| 			err := kubeClient.UpdateRefreshToken(r.ID, func(r storage.RefreshToken) (storage.RefreshToken, error) { | ||||
| 				r.Token = "should-break-the-lock-and-finish-updating" | ||||
| 				return r, nil | ||||
| 			}) | ||||
| 			require.NoError(t, err) | ||||
|  | ||||
| 			lockBroken = true | ||||
| 			return r, nil | ||||
| 		}) | ||||
| 		require.NoError(t, err) | ||||
|  | ||||
| 		token, err := kubeClient.GetRefresh(r.ID) | ||||
| 		require.NoError(t, err) | ||||
| 		// Because concurrent update breaks the lock, the final result will be the value of the first update | ||||
| 		require.Equal(t, "update-result-2", token.Token) | ||||
| 	}) | ||||
| } | ||||
|   | ||||
		Reference in New Issue
	
	Block a user