fix: refresh token only once for all concurrent requests

Signed-off-by: m.nabokikh <maksim.nabokikh@flant.com>
This commit is contained in:
m.nabokikh
2022-10-02 22:40:22 +02:00
parent ffeb4d5e16
commit 4b5f1d5289
4 changed files with 357 additions and 85 deletions

124
storage/kubernetes/lock.go Normal file
View File

@@ -0,0 +1,124 @@
package kubernetes
import (
"fmt"
"time"
)
const (
lockAnnotation = "dexidp.com/resource-lock"
lockTimeFormat = time.RFC3339
)
var (
lockTimeout = 10 * time.Second
lockCheckPeriod = 100 * time.Millisecond
)
// refreshTokenLock is an implementation of annotation-based optimistic locking.
//
// Refresh token contains data to refresh identity in external authentication system.
// There is a requirement that refresh should be called only once because of several reasons:
// * Some of OIDC providers could use the refresh token rotation feature which requires calling refresh only once.
// * Providers can limit the rate of requests to the token endpoint, which will lead to the error
// in case of many concurrent requests.
type refreshTokenLock struct {
cli *client
waitingState bool
}
func newRefreshTokenLock(cli *client) *refreshTokenLock {
return &refreshTokenLock{cli: cli}
}
func (l *refreshTokenLock) Lock(id string) error {
for i := 0; i <= 60; i++ {
ok, err := l.setLockAnnotation(id)
if err != nil {
return err
}
if !ok {
return nil
}
time.Sleep(lockCheckPeriod)
}
return fmt.Errorf("timeout waiting for refresh token %s lock", id)
}
func (l *refreshTokenLock) Unlock(id string) {
if l.waitingState {
// Do not need to unlock for waiting goroutines, because the have not set it.
return
}
r, err := l.cli.getRefreshToken(id)
if err != nil {
l.cli.logger.Debugf("failed to get resource to release lock for refresh token %s: %v", id, err)
return
}
r.Annotations = nil
err = l.cli.put(resourceRefreshToken, r.ObjectMeta.Name, r)
if err != nil {
l.cli.logger.Debugf("failed to release lock for refresh token %s: %v", id, err)
}
}
func (l *refreshTokenLock) setLockAnnotation(id string) (bool, error) {
r, err := l.cli.getRefreshToken(id)
if err != nil {
return false, err
}
currentTime := time.Now()
lockData := map[string]string{
lockAnnotation: currentTime.Add(lockTimeout).Format(lockTimeFormat),
}
val, ok := r.Annotations[lockAnnotation]
if !ok {
if l.waitingState {
return false, nil
}
r.Annotations = lockData
err := l.cli.put(resourceRefreshToken, r.ObjectMeta.Name, r)
if err == nil {
return false, nil
}
if isKubernetesAPIConflictError(err) {
l.waitingState = true
return true, nil
}
return false, err
}
until, err := time.Parse(lockTimeFormat, val)
if err != nil {
return false, fmt.Errorf("lock annotation value is malformed: %v", err)
}
if !currentTime.After(until) {
// waiting for the lock to be released
l.waitingState = true
return true, nil
}
// Lock time is out, lets break the lock and take the advantage
r.Annotations = lockData
err = l.cli.put(resourceRefreshToken, r.ObjectMeta.Name, r)
if err == nil {
// break lock annotation
return false, nil
}
l.cli.logger.Debugf("break lock annotation error: %v", err)
if isKubernetesAPIConflictError(err) {
l.waitingState = true
// after breaking error waiting for the lock to be released
return true, nil
}
return false, err
}

View File

@@ -451,11 +451,19 @@ func (cli *client) DeleteConnector(id string) error {
}
func (cli *client) UpdateRefreshToken(id string, updater func(old storage.RefreshToken) (storage.RefreshToken, error)) error {
lock := newRefreshTokenLock(cli)
if err := lock.Lock(id); err != nil {
return err
}
defer lock.Unlock(id)
return retryOnConflict(context.TODO(), func() error {
r, err := cli.getRefreshToken(id)
if err != nil {
return err
}
updated, err := updater(toStorageRefreshToken(r))
if err != nil {
return err
@@ -464,6 +472,7 @@ func (cli *client) UpdateRefreshToken(id string, updater func(old storage.Refres
newToken := cli.fromStorageRefreshToken(updated)
newToken.ObjectMeta = r.ObjectMeta
return cli.put(resourceRefreshToken, r.ObjectMeta.Name, newToken)
})
}

View File

@@ -11,6 +11,7 @@ import (
"path/filepath"
"strings"
"testing"
"time"
"github.com/sirupsen/logrus"
"github.com/stretchr/testify/require"
@@ -35,19 +36,22 @@ type StorageTestSuite struct {
client *client
}
func (s *StorageTestSuite) expandDir(dir string) string {
func expandDir(dir string) (string, error) {
dir = strings.Trim(dir, `"`)
if strings.HasPrefix(dir, "~/") {
homedir, err := os.UserHomeDir()
s.Require().NoError(err)
if err != nil {
return "", err
}
dir = filepath.Join(homedir, strings.TrimPrefix(dir, "~/"))
}
return dir
return dir, nil
}
func (s *StorageTestSuite) SetupTest() {
kubeconfigPath := s.expandDir(os.Getenv(kubeconfigPathVariableName))
kubeconfigPath, err := expandDir(os.Getenv(kubeconfigPathVariableName))
s.Require().NoError(err)
config := Config{
KubeConfigFile: kubeconfigPath,
@@ -292,3 +296,95 @@ func TestRetryOnConflict(t *testing.T) {
})
}
}
func TestRefreshTokenLock(t *testing.T) {
if os.Getenv(kubeconfigPathVariableName) == "" {
t.Skipf("variable %q not set, skipping kubernetes storage tests\n", kubeconfigPathVariableName)
}
kubeconfigPath, err := expandDir(os.Getenv(kubeconfigPathVariableName))
require.NoError(t, err)
config := Config{
KubeConfigFile: kubeconfigPath,
}
logger := &logrus.Logger{
Out: os.Stderr,
Formatter: &logrus.TextFormatter{DisableColors: true},
Level: logrus.DebugLevel,
}
kubeClient, err := config.open(logger, true)
require.NoError(t, err)
lockCheckPeriod = time.Nanosecond
// Creating a storage with an existing refresh token and offline session for the user.
id := storage.NewID()
r := storage.RefreshToken{
ID: id,
Token: "bar",
Nonce: "foo",
ClientID: "client_id",
ConnectorID: "client_secret",
Scopes: []string{"openid", "email", "profile"},
CreatedAt: time.Now().UTC().Round(time.Millisecond),
LastUsed: time.Now().UTC().Round(time.Millisecond),
Claims: storage.Claims{
UserID: "1",
Username: "jane",
Email: "jane.doe@example.com",
EmailVerified: true,
Groups: []string{"a", "b"},
},
ConnectorData: []byte(`{"some":"data"}`),
}
err = kubeClient.CreateRefresh(r)
require.NoError(t, err)
t.Run("Timeout lock error", func(t *testing.T) {
err = kubeClient.UpdateRefreshToken(r.ID, func(r storage.RefreshToken) (storage.RefreshToken, error) {
r.Token = "update-result-1"
err := kubeClient.UpdateRefreshToken(r.ID, func(r storage.RefreshToken) (storage.RefreshToken, error) {
r.Token = "timeout-err"
return r, nil
})
require.Equal(t, fmt.Errorf("timeout waiting for refresh token %s lock", r.ID), err)
return r, nil
})
require.NoError(t, err)
token, err := kubeClient.GetRefresh(r.ID)
require.NoError(t, err)
require.Equal(t, "update-result-1", token.Token)
})
t.Run("Break the lock", func(t *testing.T) {
var lockBroken bool
lockTimeout = -time.Hour
err = kubeClient.UpdateRefreshToken(r.ID, func(r storage.RefreshToken) (storage.RefreshToken, error) {
r.Token = "update-result-2"
if lockBroken {
return r, nil
}
err := kubeClient.UpdateRefreshToken(r.ID, func(r storage.RefreshToken) (storage.RefreshToken, error) {
r.Token = "should-break-the-lock-and-finish-updating"
return r, nil
})
require.NoError(t, err)
lockBroken = true
return r, nil
})
require.NoError(t, err)
token, err := kubeClient.GetRefresh(r.ID)
require.NoError(t, err)
// Because concurrent update breaks the lock, the final result will be the value of the first update
require.Equal(t, "update-result-2", token.Token)
})
}