diff --git a/storage/kubernetes/storage.go b/storage/kubernetes/storage.go
index c0d6eb91..04c59171 100644
--- a/storage/kubernetes/storage.go
+++ b/storage/kubernetes/storage.go
@@ -4,6 +4,7 @@ import (
 	"context"
 	"errors"
 	"fmt"
+	"math/rand"
 	"net/http"
 	"strings"
 	"time"
@@ -439,19 +440,21 @@ func (cli *client) DeleteConnector(id string) error {
 }
 
 func (cli *client) UpdateRefreshToken(id string, updater func(old storage.RefreshToken) (storage.RefreshToken, error)) error {
-	r, err := cli.getRefreshToken(id)
-	if err != nil {
-		return err
-	}
-	updated, err := updater(toStorageRefreshToken(r))
-	if err != nil {
-		return err
-	}
-	updated.ID = id
+	return retryOnConflict(context.TODO(), func() error {
+		r, err := cli.getRefreshToken(id)
+		if err != nil {
+			return err
+		}
+		updated, err := updater(toStorageRefreshToken(r))
+		if err != nil {
+			return err
+		}
+		updated.ID = id
 
-	newToken := cli.fromStorageRefreshToken(updated)
-	newToken.ObjectMeta = r.ObjectMeta
-	return cli.put(resourceRefreshToken, r.ObjectMeta.Name, newToken)
+		newToken := cli.fromStorageRefreshToken(updated)
+		newToken.ObjectMeta = r.ObjectMeta
+		return cli.put(resourceRefreshToken, r.ObjectMeta.Name, newToken)
+	})
 }
 
 func (cli *client) UpdateClient(id string, updater func(old storage.Client) (storage.Client, error)) error {
@@ -489,19 +492,21 @@ func (cli *client) UpdatePassword(email string, updater func(old storage.Passwor
 }
 
 func (cli *client) UpdateOfflineSessions(userID string, connID string, updater func(old storage.OfflineSessions) (storage.OfflineSessions, error)) error {
-	o, err := cli.getOfflineSessions(userID, connID)
-	if err != nil {
-		return err
-	}
+	return retryOnConflict(context.TODO(), func() error {
+		o, err := cli.getOfflineSessions(userID, connID)
+		if err != nil {
+			return err
+		}
 
-	updated, err := updater(toStorageOfflineSessions(o))
-	if err != nil {
-		return err
-	}
+		updated, err := updater(toStorageOfflineSessions(o))
+		if err != nil {
+			return err
+		}
 
-	newOfflineSessions := cli.fromStorageOfflineSessions(updated)
-	newOfflineSessions.ObjectMeta = o.ObjectMeta
-	return cli.put(resourceOfflineSessions, o.ObjectMeta.Name, newOfflineSessions)
+		newOfflineSessions := cli.fromStorageOfflineSessions(updated)
+		newOfflineSessions.ObjectMeta = o.ObjectMeta
+		return cli.put(resourceOfflineSessions, o.ObjectMeta.Name, newOfflineSessions)
+	})
 }
 
 func (cli *client) UpdateKeys(updater func(old storage.Keys) (storage.Keys, error)) error {
@@ -539,13 +544,11 @@ func (cli *client) UpdateKeys(updater func(old storage.Keys) (storage.Keys, erro
 	newKeys.ObjectMeta = keys.ObjectMeta
 
 	err = cli.put(resourceKeys, keysName, newKeys)
-	if httpErr, ok := err.(httpError); ok {
+	if isKubernetesAPIConflictError(err) {
 		// We need to tolerate conflicts here in case of HA mode.
 		// Dex instances run keys rotation at the same time because they use SigningKey.nextRotation CR field as a trigger.
-		if httpErr.StatusCode() == http.StatusConflict {
-			cli.logger.Debugf("Keys rotation failed: %v. It is possible that keys have already been rotated by another dex instance.", err)
-			return errors.New("keys already rotated by another server instance")
-		}
+		cli.logger.Debugf("Keys rotation failed: %v. It is possible that keys have already been rotated by another dex instance.", err)
+		return errors.New("keys already rotated by another server instance")
 	}
 
 	return err
@@ -569,20 +572,22 @@ func (cli *client) UpdateAuthRequest(id string, updater func(a storage.AuthReque
 }
 
 func (cli *client) UpdateConnector(id string, updater func(a storage.Connector) (storage.Connector, error)) error {
-	var c Connector
-	err := cli.get(resourceConnector, id, &c)
-	if err != nil {
-		return err
-	}
+	return retryOnConflict(context.TODO(), func() error {
+		var c Connector
+		err := cli.get(resourceConnector, id, &c)
+		if err != nil {
+			return err
+		}
 
-	updated, err := updater(toStorageConnector(c))
-	if err != nil {
-		return err
-	}
+		updated, err := updater(toStorageConnector(c))
+		if err != nil {
+			return err
+		}
 
-	newConn := cli.fromStorageConnector(updated)
-	newConn.ObjectMeta = c.ObjectMeta
-	return cli.put(resourceConnector, id, newConn)
+		newConn := cli.fromStorageConnector(updated)
+		newConn.ObjectMeta = c.ObjectMeta
+		return cli.put(resourceConnector, id, newConn)
+	})
 }
 
 func (cli *client) GarbageCollect(now time.Time) (result storage.GCResult, err error) {
@@ -686,17 +691,64 @@ func (cli *client) getDeviceToken(deviceCode string) (t DeviceToken, err error)
 }
 
 func (cli *client) UpdateDeviceToken(deviceCode string, updater func(old storage.DeviceToken) (storage.DeviceToken, error)) error {
-	r, err := cli.getDeviceToken(deviceCode)
-	if err != nil {
-		return err
-	}
-	updated, err := updater(toStorageDeviceToken(r))
-	if err != nil {
-		return err
-	}
-	updated.DeviceCode = deviceCode
+	return retryOnConflict(context.TODO(), func() error {
+		r, err := cli.getDeviceToken(deviceCode)
+		if err != nil {
+			return err
+		}
+		updated, err := updater(toStorageDeviceToken(r))
+		if err != nil {
+			return err
+		}
+		updated.DeviceCode = deviceCode
 
-	newToken := cli.fromStorageDeviceToken(updated)
-	newToken.ObjectMeta = r.ObjectMeta
-	return cli.put(resourceDeviceToken, r.ObjectMeta.Name, newToken)
+		newToken := cli.fromStorageDeviceToken(updated)
+		newToken.ObjectMeta = r.ObjectMeta
+		return cli.put(resourceDeviceToken, r.ObjectMeta.Name, newToken)
+	})
+}
+
+// isKubernetesAPIConflictError reports whether err is, or wraps, an
+// httpError whose status code is HTTP 409 Conflict.
+func isKubernetesAPIConflictError(err error) bool {
+	var httpErr httpError
+	return errors.As(err, &httpErr) && httpErr.StatusCode() == http.StatusConflict
+}
+
+// retryOnConflict runs action and, while it keeps failing with a Kubernetes
+// API conflict (HTTP 409), runs it again after a short randomized backoff
+// taken from a fixed policy. Success or any non-conflict error is returned
+// immediately. Once every backoff step has been used an error is returned;
+// if ctx is done while waiting, ctx.Err() is returned.
+func retryOnConflict(ctx context.Context, action func() error) error {
+	policy := []int{10, 20, 100, 300, 600}
+
+	attempts := 0
+	getNextStep := func() time.Duration {
+		step := policy[attempts]
+		// Base delay of 5*step plus up to step of jitter, in microseconds.
+		return time.Duration(step*5+rand.Intn(step)) * time.Microsecond
+	}
+
+	if err := action(); err == nil || !isKubernetesAPIConflictError(err) {
+		return err
+	}
+
+	for {
+		select {
+		case <-time.After(getNextStep()):
+			if err := action(); err == nil || !isKubernetesAPIConflictError(err) {
+				return err
+			}
+
+			attempts++
+			// Bound by the policy length so every step is reachable and
+			// policy[attempts] can never be indexed out of range.
+			if attempts >= len(policy) {
+				return errors.New("maximum timeout reached while retrying a conflicted request")
+			}
+		case <-ctx.Done():
+			return ctx.Err()
+		}
+	}
 }
diff --git a/storage/kubernetes/storage_test.go b/storage/kubernetes/storage_test.go
index e2c77a62..42ba19a4 100644
--- a/storage/kubernetes/storage_test.go
+++ b/storage/kubernetes/storage_test.go
@@ -1,6 +1,7 @@
 package kubernetes
 
 import (
+	"context"
 	"crypto/tls"
 	"errors"
 	"fmt"
@@ -12,6 +13,7 @@ import (
 	"testing"
 
 	"github.com/sirupsen/logrus"
+	"github.com/stretchr/testify/require"
 	"github.com/stretchr/testify/suite"
 	"sigs.k8s.io/testing_frameworks/integration"
 
@@ -272,3 +274,43 @@ func newStatusCodesResponseTestClient(getResponseCode, actionResponseCode int) *
 		},
 	}
 }
+
+func TestRetryOnConflict(t *testing.T) {
+	tests := []struct {
+		name     string
+		action   func() error
+		exactErr string
+	}{
+		{
+			"Timeout reached",
+			func() error { err := httpErr{status: 409}; return error(&err) },
+			"maximum timeout reached while retrying a conflicted request",
+		},
+		{
+			"HTTP Error",
+			func() error { err := httpErr{status: 500}; return error(&err) },
+			" Internal Server Error: response from server \"\"",
+		},
+		{
+			"Error",
+			func() error { return errors.New("test") },
+			"test",
+		},
+		{
+			"OK",
+			func() error { return nil },
+			"",
+		},
+	}
+
+	for _, testCase := range tests {
+		t.Run(testCase.name, func(t *testing.T) {
+			err := retryOnConflict(context.TODO(), testCase.action)
+			if testCase.exactErr != "" {
+				require.EqualError(t, err, testCase.exactErr)
+			} else {
+				require.NoError(t, err)
+			}
+		})
+	}
+}