fix: refresh token only once for all concurrent requests
Signed-off-by: m.nabokikh <maksim.nabokikh@flant.com>
storage/kubernetes/lock.go | 124 (new file)
@@ -0,0 +1,124 @@
package kubernetes

import (
	"fmt"
	"time"
)

const (
	lockAnnotation = "dexidp.com/resource-lock"
	lockTimeFormat = time.RFC3339
)

var (
	lockTimeout     = 10 * time.Second
	lockCheckPeriod = 100 * time.Millisecond
)

// refreshTokenLock is an implementation of annotation-based optimistic locking.
//
// A refresh token contains the data needed to refresh an identity in an external
// authentication system. Refresh must be called only once, for several reasons:
//   * Some OIDC providers use the refresh token rotation feature, which requires
//     calling refresh only once per token.
//   * Providers can rate-limit requests to the token endpoint, which leads to
//     errors when many requests arrive concurrently.
type refreshTokenLock struct {
	cli          *client
	waitingState bool
}

func newRefreshTokenLock(cli *client) *refreshTokenLock {
	return &refreshTokenLock{cli: cli}
}

// Lock blocks until the lock annotation is claimed or released, polling every
// lockCheckPeriod, and gives up after roughly 60 check periods.
func (l *refreshTokenLock) Lock(id string) error {
	for i := 0; i <= 60; i++ {
		// setLockAnnotation reports true while another writer still holds the lock.
		ok, err := l.setLockAnnotation(id)
		if err != nil {
			return err
		}
		if !ok {
			return nil
		}
		time.Sleep(lockCheckPeriod)
	}
	return fmt.Errorf("timeout waiting for refresh token %s lock", id)
}

func (l *refreshTokenLock) Unlock(id string) {
	if l.waitingState {
		// No need to unlock for waiting goroutines, because they never set the annotation.
		return
	}

	r, err := l.cli.getRefreshToken(id)
	if err != nil {
		l.cli.logger.Debugf("failed to get resource to release lock for refresh token %s: %v", id, err)
		return
	}

	r.Annotations = nil
	err = l.cli.put(resourceRefreshToken, r.ObjectMeta.Name, r)
	if err != nil {
		l.cli.logger.Debugf("failed to release lock for refresh token %s: %v", id, err)
	}
}

// setLockAnnotation attempts to claim the lock annotation on the refresh token.
// It returns true when the caller has to keep waiting, and false when the
// caller may proceed with the update.
func (l *refreshTokenLock) setLockAnnotation(id string) (bool, error) {
	r, err := l.cli.getRefreshToken(id)
	if err != nil {
		return false, err
	}

	currentTime := time.Now()
	lockData := map[string]string{
		lockAnnotation: currentTime.Add(lockTimeout).Format(lockTimeFormat),
	}

	val, ok := r.Annotations[lockAnnotation]
	if !ok {
		if l.waitingState {
			// The lock was released while we were waiting: proceed with the update.
			return false, nil
		}

		r.Annotations = lockData
		err := l.cli.put(resourceRefreshToken, r.ObjectMeta.Name, r)
		if err == nil {
			return false, nil
		}

		if isKubernetesAPIConflictError(err) {
			l.waitingState = true
			return true, nil
		}
		return false, err
	}

	until, err := time.Parse(lockTimeFormat, val)
	if err != nil {
		return false, fmt.Errorf("lock annotation value is malformed: %v", err)
	}

	if !currentTime.After(until) {
		// Waiting for the lock to be released.
		l.waitingState = true
		return true, nil
	}

	// The lock has timed out: break it and take its place.
	r.Annotations = lockData

	err = l.cli.put(resourceRefreshToken, r.ObjectMeta.Name, r)
	if err == nil {
		// Successfully broke the lock annotation.
		return false, nil
	}

	l.cli.logger.Debugf("break lock annotation error: %v", err)
	if isKubernetesAPIConflictError(err) {
		l.waitingState = true
		// The break attempt hit a conflict; keep waiting for the lock to be released.
		return true, nil
	}
	return false, err
}
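The lock's only assumption about conflict detection is the isKubernetesAPIConflictError helper, which is defined elsewhere in this package. A minimal sketch of the contract the lock relies on, with a hypothetical error shape (the real helper inspects the package's own HTTP error type):

package kubernetes

import (
	"errors"
	"net/http"
)

// conflictErrorSketch is a hypothetical stand-in for isKubernetesAPIConflictError:
// all the lock needs is a predicate that reports whether the API server rejected
// a write with HTTP 409 Conflict, i.e. another writer won the optimistic race.
func conflictErrorSketch(err error) bool {
	var httpErr interface{ StatusCode() int } // assumed error shape, not the package's real type
	if errors.As(err, &httpErr) {
		return httpErr.StatusCode() == http.StatusConflict // 409
	}
	return false
}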
storage/kubernetes/storage.go
@@ -451,11 +451,19 @@ func (cli *client) DeleteConnector(id string) error {
 }
 
 func (cli *client) UpdateRefreshToken(id string, updater func(old storage.RefreshToken) (storage.RefreshToken, error)) error {
+	lock := newRefreshTokenLock(cli)
+
+	if err := lock.Lock(id); err != nil {
+		return err
+	}
+	defer lock.Unlock(id)
+
 	return retryOnConflict(context.TODO(), func() error {
 		r, err := cli.getRefreshToken(id)
 		if err != nil {
 			return err
 		}
 
 		updated, err := updater(toStorageRefreshToken(r))
 		if err != nil {
 			return err
@@ -464,6 +472,7 @@ func (cli *client) UpdateRefreshToken(id string, updater func(old storage.RefreshToken) (storage.RefreshToken, error)) error {
 
 		newToken := cli.fromStorageRefreshToken(updated)
+		newToken.ObjectMeta = r.ObjectMeta
 
 		return cli.put(resourceRefreshToken, r.ObjectMeta.Name, newToken)
 	})
 }
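For orientation, a caller never touches the lock directly; it only supplies the mutation. A hedged sketch of the resulting call pattern (touchRefreshToken is a hypothetical helper, not part of this commit):

package kubernetes

import (
	"time"

	"github.com/dexidp/dex/storage"
)

// touchRefreshToken bumps LastUsed on a stored refresh token. All locking
// and conflict retries happen inside UpdateRefreshToken, so concurrent
// callers serialize on the annotation lock.
func touchRefreshToken(cli *client, id string) error {
	return cli.UpdateRefreshToken(id, func(old storage.RefreshToken) (storage.RefreshToken, error) {
		old.LastUsed = time.Now().UTC()
		return old, nil
	})
}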
storage/kubernetes/storage_test.go
@@ -11,6 +11,7 @@ import (
 	"path/filepath"
 	"strings"
 	"testing"
 	"time"
 
 	"github.com/sirupsen/logrus"
+	"github.com/stretchr/testify/require"
@@ -35,19 +36,22 @@ type StorageTestSuite struct {
 	client *client
 }
 
-func (s *StorageTestSuite) expandDir(dir string) string {
+func expandDir(dir string) (string, error) {
 	dir = strings.Trim(dir, `"`)
 	if strings.HasPrefix(dir, "~/") {
 		homedir, err := os.UserHomeDir()
-		s.Require().NoError(err)
+		if err != nil {
+			return "", err
+		}
 
 		dir = filepath.Join(homedir, strings.TrimPrefix(dir, "~/"))
 	}
-	return dir
+	return dir, nil
 }
 
 func (s *StorageTestSuite) SetupTest() {
-	kubeconfigPath := s.expandDir(os.Getenv(kubeconfigPathVariableName))
+	kubeconfigPath, err := expandDir(os.Getenv(kubeconfigPathVariableName))
+	s.Require().NoError(err)
 
 	config := Config{
 		KubeConfigFile: kubeconfigPath,
@@ -292,3 +296,95 @@ func TestRetryOnConflict(t *testing.T) {
		})
	}
}

func TestRefreshTokenLock(t *testing.T) {
	if os.Getenv(kubeconfigPathVariableName) == "" {
		t.Skipf("variable %q not set, skipping kubernetes storage tests\n", kubeconfigPathVariableName)
	}

	kubeconfigPath, err := expandDir(os.Getenv(kubeconfigPathVariableName))
	require.NoError(t, err)

	config := Config{
		KubeConfigFile: kubeconfigPath,
	}

	logger := &logrus.Logger{
		Out:       os.Stderr,
		Formatter: &logrus.TextFormatter{DisableColors: true},
		Level:     logrus.DebugLevel,
	}

	kubeClient, err := config.open(logger, true)
	require.NoError(t, err)

	// Poll quickly so the test does not spend real time sleeping.
	lockCheckPeriod = time.Nanosecond

	// Seed the storage with an existing refresh token for the user.
	id := storage.NewID()
	r := storage.RefreshToken{
		ID:          id,
		Token:       "bar",
		Nonce:       "foo",
		ClientID:    "client_id",
		ConnectorID: "client_secret",
		Scopes:      []string{"openid", "email", "profile"},
		CreatedAt:   time.Now().UTC().Round(time.Millisecond),
		LastUsed:    time.Now().UTC().Round(time.Millisecond),
		Claims: storage.Claims{
			UserID:        "1",
			Username:      "jane",
			Email:         "jane.doe@example.com",
			EmailVerified: true,
			Groups:        []string{"a", "b"},
		},
		ConnectorData: []byte(`{"some":"data"}`),
	}

	err = kubeClient.CreateRefresh(r)
	require.NoError(t, err)

	t.Run("Timeout lock error", func(t *testing.T) {
		err = kubeClient.UpdateRefreshToken(r.ID, func(r storage.RefreshToken) (storage.RefreshToken, error) {
			r.Token = "update-result-1"
			err := kubeClient.UpdateRefreshToken(r.ID, func(r storage.RefreshToken) (storage.RefreshToken, error) {
				r.Token = "timeout-err"
				return r, nil
			})
			require.Equal(t, fmt.Errorf("timeout waiting for refresh token %s lock", r.ID), err)
			return r, nil
		})
		require.NoError(t, err)

		token, err := kubeClient.GetRefresh(r.ID)
		require.NoError(t, err)
		require.Equal(t, "update-result-1", token.Token)
	})

	t.Run("Break the lock", func(t *testing.T) {
		var lockBroken bool
		lockTimeout = -time.Hour

		err = kubeClient.UpdateRefreshToken(r.ID, func(r storage.RefreshToken) (storage.RefreshToken, error) {
			r.Token = "update-result-2"
			if lockBroken {
				return r, nil
			}

			err := kubeClient.UpdateRefreshToken(r.ID, func(r storage.RefreshToken) (storage.RefreshToken, error) {
				r.Token = "should-break-the-lock-and-finish-updating"
				return r, nil
			})
			require.NoError(t, err)

			lockBroken = true
			return r, nil
		})
		require.NoError(t, err)

		token, err := kubeClient.GetRefresh(r.ID)
		require.NoError(t, err)
		// Because the concurrent update broke the lock, the final result is the value of the first update.
		require.Equal(t, "update-result-2", token.Token)
	})
}
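Taken together, the tests pin down the behavior named in the commit title. As a final illustration (hypothetical helper, not part of this commit), this is the concurrent scenario the annotation lock serializes:

package kubernetes

import (
	"sync"
	"time"

	"github.com/dexidp/dex/storage"
)

// raceUpdate fires n concurrent updates against one refresh token. With the
// annotation lock in place, one writer proceeds while the others poll on
// lockCheckPeriod (or break a stale lock), so the upstream provider sees a
// single refresh per rotation instead of n concurrent ones.
func raceUpdate(cli *client, id string, n int) {
	var wg sync.WaitGroup
	for i := 0; i < n; i++ {
		wg.Add(1)
		go func() {
			defer wg.Done()
			_ = cli.UpdateRefreshToken(id, func(t storage.RefreshToken) (storage.RefreshToken, error) {
				t.LastUsed = time.Now().UTC()
				return t, nil
			})
		}()
	}
	wg.Wait()
}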