Merge pull request #1342 from concourse/pr/retry-on-pg-serialization-error

retry on serialization errors
This commit is contained in:
Stephan Renatus
2018-11-21 10:29:46 +01:00
committed by GitHub
20 changed files with 1355 additions and 463 deletions

View File

@@ -134,7 +134,7 @@ func (c *conn) CreateAuthRequest(a storage.AuthRequest) error {
}
func (c *conn) UpdateAuthRequest(id string, updater func(a storage.AuthRequest) (storage.AuthRequest, error)) error {
return c.ExecTx(func(tx *trans) error {
err := c.ExecTx(func(tx *trans) error {
r, err := getAuthRequest(tx, id)
if err != nil {
return err
@@ -144,6 +144,7 @@ func (c *conn) UpdateAuthRequest(id string, updater func(a storage.AuthRequest)
if err != nil {
return err
}
_, err = tx.Exec(`
update auth_request
set
@@ -163,21 +164,31 @@ func (c *conn) UpdateAuthRequest(id string, updater func(a storage.AuthRequest)
a.ConnectorID, a.ConnectorData,
a.Expiry, r.ID,
)
if err != nil {
return fmt.Errorf("update auth request: %v", err)
}
return nil
return err
})
if err != nil {
return fmt.Errorf("update auth request: %v", err)
}
return nil
}
func (c *conn) GetAuthRequest(id string) (storage.AuthRequest, error) {
return getAuthRequest(c, id)
req, err := getAuthRequest(c, id)
if err != nil {
if err == sql.ErrNoRows {
return storage.AuthRequest{}, storage.ErrNotFound
}
return storage.AuthRequest{}, fmt.Errorf("select auth request: %v", err)
}
return req, nil
}
func getAuthRequest(q querier, id string) (a storage.AuthRequest, err error) {
err = q.QueryRow(`
select
select
id, client_id, response_types, scopes, redirect_uri, nonce, state,
force_approval_prompt, logged_in,
claims_user_id, claims_username, claims_email, claims_email_verified,
@@ -192,10 +203,7 @@ func getAuthRequest(q querier, id string) (a storage.AuthRequest, err error) {
&a.ConnectorID, &a.ConnectorData, &a.Expiry,
)
if err != nil {
if err == sql.ErrNoRows {
return a, storage.ErrNotFound
}
return a, fmt.Errorf("select auth request: %v", err)
return a, err
}
return a, nil
}
@@ -269,20 +277,22 @@ func (c *conn) CreateRefresh(r storage.RefreshToken) error {
if c.alreadyExistsCheck(err) {
return storage.ErrAlreadyExists
}
return fmt.Errorf("insert refresh_token: %v", err)
return fmt.Errorf("insert refresh token: %v", err)
}
return nil
}
func (c *conn) UpdateRefreshToken(id string, updater func(old storage.RefreshToken) (storage.RefreshToken, error)) error {
return c.ExecTx(func(tx *trans) error {
err := c.ExecTx(func(tx *trans) error {
r, err := getRefresh(tx, id)
if err != nil {
return err
}
if r, err = updater(r); err != nil {
return err
}
_, err = tx.Exec(`
update refresh_token
set
@@ -308,15 +318,25 @@ func (c *conn) UpdateRefreshToken(id string, updater func(old storage.RefreshTok
r.ConnectorID, r.ConnectorData,
r.Token, r.CreatedAt, r.LastUsed, id,
)
if err != nil {
return fmt.Errorf("update refresh token: %v", err)
}
return nil
return err
})
if err != nil {
return fmt.Errorf("update refresh token: %v", err)
}
return nil
}
func (c *conn) GetRefresh(id string) (storage.RefreshToken, error) {
return getRefresh(c, id)
req, err := getRefresh(c, id)
if err != nil {
if err == sql.ErrNoRows {
return storage.RefreshToken{}, storage.ErrNotFound
}
return storage.RefreshToken{}, fmt.Errorf("get refresh token: %v", err)
}
return req, nil
}
func getRefresh(q querier, id string) (storage.RefreshToken, error) {
@@ -342,14 +362,15 @@ func (c *conn) ListRefreshTokens() ([]storage.RefreshToken, error) {
from refresh_token;
`)
if err != nil {
return nil, fmt.Errorf("query: %v", err)
return nil, fmt.Errorf("select refresh tokens: %v", err)
}
var tokens []storage.RefreshToken
for rows.Next() {
r, err := scanRefresh(rows)
if err != nil {
return nil, err
return nil, fmt.Errorf("scan refresh token: %s", err)
}
tokens = append(tokens, r)
}
if err := rows.Err(); err != nil {
@@ -367,10 +388,7 @@ func scanRefresh(s scanner) (r storage.RefreshToken, err error) {
&r.Token, &r.CreatedAt, &r.LastUsed,
)
if err != nil {
if err == sql.ErrNoRows {
return r, storage.ErrNotFound
}
return r, fmt.Errorf("scan refresh_token: %v", err)
return r, err
}
return r, nil
}
@@ -381,12 +399,11 @@ func (c *conn) UpdateKeys(updater func(old storage.Keys) (storage.Keys, error))
// TODO(ericchiang): errors may cause a transaction be rolled back by the SQL
// server. Test this, and consider adding a COUNT() command beforehand.
old, err := getKeys(tx)
if err != nil {
if err != storage.ErrNotFound {
return fmt.Errorf("get keys: %v", err)
}
if err == sql.ErrNoRows {
firstUpdate = true
old = storage.Keys{}
} else if err != nil {
return err
}
nk, err := updater(old)
@@ -405,12 +422,12 @@ func (c *conn) UpdateKeys(updater func(old storage.Keys) (storage.Keys, error))
encoder(nk.SigningKeyPub), nk.NextRotation,
)
if err != nil {
return fmt.Errorf("insert: %v", err)
return err
}
} else {
_, err = tx.Exec(`
update keys
set
set
verification_keys = $1,
signing_key = $2,
signing_key_pub = $3,
@@ -421,15 +438,24 @@ func (c *conn) UpdateKeys(updater func(old storage.Keys) (storage.Keys, error))
encoder(nk.SigningKeyPub), nk.NextRotation, keysRowID,
)
if err != nil {
return fmt.Errorf("update: %v", err)
return err
}
}
return nil
})
}
func (c *conn) GetKeys() (keys storage.Keys, err error) {
return getKeys(c)
func (c *conn) GetKeys() (storage.Keys, error) {
keys, err := getKeys(c)
if err != nil {
if err == sql.ErrNoRows {
return storage.Keys{}, storage.ErrNotFound
}
return storage.Keys{}, fmt.Errorf("select keys: %s", err)
}
return keys, nil
}
func getKeys(q querier) (keys storage.Keys, err error) {
@@ -443,20 +469,18 @@ func getKeys(q querier) (keys storage.Keys, err error) {
decoder(&keys.SigningKeyPub), &keys.NextRotation,
)
if err != nil {
if err == sql.ErrNoRows {
return keys, storage.ErrNotFound
}
return keys, fmt.Errorf("query keys: %v", err)
return keys, err
}
return keys, nil
}
func (c *conn) UpdateClient(id string, updater func(old storage.Client) (storage.Client, error)) error {
return c.ExecTx(func(tx *trans) error {
err := c.ExecTx(func(tx *trans) error {
cli, err := getClient(tx, id)
if err != nil {
return err
}
nc, err := updater(cli)
if err != nil {
return err
@@ -474,11 +498,13 @@ func (c *conn) UpdateClient(id string, updater func(old storage.Client) (storage
where id = $7;
`, nc.Secret, encoder(nc.RedirectURIs), encoder(nc.TrustedPeers), nc.Public, nc.Name, nc.LogoURL, id,
)
if err != nil {
return fmt.Errorf("update client: %v", err)
}
return nil
return err
})
if err != nil {
return fmt.Errorf("update client: %v", err)
}
return nil
}
func (c *conn) CreateClient(cli storage.Client) error {
@@ -509,7 +535,16 @@ func getClient(q querier, id string) (storage.Client, error) {
}
func (c *conn) GetClient(id string) (storage.Client, error) {
return getClient(c, id)
client, err := getClient(c, id)
if err != nil {
if err == sql.ErrNoRows {
return storage.Client{}, storage.ErrNotFound
}
return storage.Client{}, fmt.Errorf("select client: %v", err)
}
return client, nil
}
func (c *conn) ListClients() ([]storage.Client, error) {
@@ -525,12 +560,12 @@ func (c *conn) ListClients() ([]storage.Client, error) {
for rows.Next() {
cli, err := scanClient(rows)
if err != nil {
return nil, err
return nil, fmt.Errorf("scan client: %s", err)
}
clients = append(clients, cli)
}
if err := rows.Err(); err != nil {
return nil, err
return nil, fmt.Errorf("scan: %s", err)
}
return clients, nil
}
@@ -541,10 +576,7 @@ func scanClient(s scanner) (cli storage.Client, err error) {
&cli.Public, &cli.Name, &cli.LogoURL,
)
if err != nil {
if err == sql.ErrNoRows {
return cli, storage.ErrNotFound
}
return cli, fmt.Errorf("get client: %v", err)
return cli, err
}
return cli, nil
}
@@ -571,7 +603,7 @@ func (c *conn) CreatePassword(p storage.Password) error {
}
func (c *conn) UpdatePassword(email string, updater func(p storage.Password) (storage.Password, error)) error {
return c.ExecTx(func(tx *trans) error {
err := c.ExecTx(func(tx *trans) error {
p, err := getPassword(tx, email)
if err != nil {
return err
@@ -581,6 +613,7 @@ func (c *conn) UpdatePassword(email string, updater func(p storage.Password) (st
if err != nil {
return err
}
_, err = tx.Exec(`
update password
set
@@ -589,15 +622,25 @@ func (c *conn) UpdatePassword(email string, updater func(p storage.Password) (st
`,
np.Hash, np.Username, np.UserID, p.Email,
)
if err != nil {
return fmt.Errorf("update password: %v", err)
}
return nil
return err
})
if err != nil {
return fmt.Errorf("update password: %v", err)
}
return nil
}
func (c *conn) GetPassword(email string) (storage.Password, error) {
return getPassword(c, email)
pass, err := getPassword(c, email)
if err != nil {
if err == sql.ErrNoRows {
return storage.Password{}, storage.ErrNotFound
}
return storage.Password{}, fmt.Errorf("get password: %s", err)
}
return pass, nil
}
func getPassword(q querier, email string) (p storage.Password, err error) {
@@ -622,12 +665,12 @@ func (c *conn) ListPasswords() ([]storage.Password, error) {
for rows.Next() {
p, err := scanPassword(rows)
if err != nil {
return nil, err
return nil, fmt.Errorf("scan password: %s", err)
}
passwords = append(passwords, p)
}
if err := rows.Err(); err != nil {
return nil, err
return nil, fmt.Errorf("scan: %s", err)
}
return passwords, nil
}
@@ -637,10 +680,7 @@ func scanPassword(s scanner) (p storage.Password, err error) {
&p.Email, &p.Hash, &p.Username, &p.UserID,
)
if err != nil {
if err == sql.ErrNoRows {
return p, storage.ErrNotFound
}
return p, fmt.Errorf("select password: %v", err)
return p, err
}
return p, nil
}
@@ -666,7 +706,7 @@ func (c *conn) CreateOfflineSessions(s storage.OfflineSessions) error {
}
func (c *conn) UpdateOfflineSessions(userID string, connID string, updater func(s storage.OfflineSessions) (storage.OfflineSessions, error)) error {
return c.ExecTx(func(tx *trans) error {
err := c.ExecTx(func(tx *trans) error {
s, err := getOfflineSessions(tx, userID, connID)
if err != nil {
return err
@@ -676,6 +716,7 @@ func (c *conn) UpdateOfflineSessions(userID string, connID string, updater func(
if err != nil {
return err
}
_, err = tx.Exec(`
update offline_session
set
@@ -684,15 +725,26 @@ func (c *conn) UpdateOfflineSessions(userID string, connID string, updater func(
`,
encoder(newSession.Refresh), s.UserID, s.ConnID,
)
if err != nil {
return fmt.Errorf("update offline session: %v", err)
}
return nil
return err
})
if err != nil {
return fmt.Errorf("update offline session: %v", err)
}
return nil
}
func (c *conn) GetOfflineSessions(userID string, connID string) (storage.OfflineSessions, error) {
return getOfflineSessions(c, userID, connID)
sessions, err := getOfflineSessions(c, userID, connID)
if err != nil {
if err == sql.ErrNoRows {
return storage.OfflineSessions{}, storage.ErrNotFound
}
return storage.OfflineSessions{}, fmt.Errorf("get offline sessions: %s", err)
}
return sessions, nil
}
func getOfflineSessions(q querier, userID string, connID string) (storage.OfflineSessions, error) {
@@ -709,10 +761,7 @@ func scanOfflineSessions(s scanner) (o storage.OfflineSessions, err error) {
&o.UserID, &o.ConnID, decoder(&o.Refresh),
)
if err != nil {
if err == sql.ErrNoRows {
return o, storage.ErrNotFound
}
return o, fmt.Errorf("select offline session: %v", err)
return o, err
}
return o, nil
}
@@ -738,7 +787,7 @@ func (c *conn) CreateConnector(connector storage.Connector) error {
}
func (c *conn) UpdateConnector(id string, updater func(s storage.Connector) (storage.Connector, error)) error {
return c.ExecTx(func(tx *trans) error {
err := c.ExecTx(func(tx *trans) error {
connector, err := getConnector(tx, id)
if err != nil {
return err
@@ -748,9 +797,10 @@ func (c *conn) UpdateConnector(id string, updater func(s storage.Connector) (sto
if err != nil {
return err
}
_, err = tx.Exec(`
update connector
set
set
type = $1,
name = $2,
resource_version = $3,
@@ -759,15 +809,26 @@ func (c *conn) UpdateConnector(id string, updater func(s storage.Connector) (sto
`,
newConn.Type, newConn.Name, newConn.ResourceVersion, newConn.Config, connector.ID,
)
if err != nil {
return fmt.Errorf("update connector: %v", err)
}
return nil
return err
})
if err != nil {
return fmt.Errorf("update connector: %v", err)
}
return nil
}
func (c *conn) GetConnector(id string) (storage.Connector, error) {
return getConnector(c, id)
connector, err := getConnector(c, id)
if err != nil {
if err == sql.ErrNoRows {
return storage.Connector{}, storage.ErrNotFound
}
return storage.Connector{}, fmt.Errorf("get connector: %s", err)
}
return connector, nil
}
func getConnector(q querier, id string) (storage.Connector, error) {
@@ -784,10 +845,7 @@ func scanConnector(s scanner) (c storage.Connector, err error) {
&c.ID, &c.Type, &c.Name, &c.ResourceVersion, &c.Config,
)
if err != nil {
if err == sql.ErrNoRows {
return c, storage.ErrNotFound
}
return c, fmt.Errorf("select connector: %v", err)
return c, err
}
return c, nil
}
@@ -805,12 +863,12 @@ func (c *conn) ListConnectors() ([]storage.Connector, error) {
for rows.Next() {
conn, err := scanConnector(rows)
if err != nil {
return nil, err
return nil, fmt.Errorf("scan connector: %s", err)
}
connectors = append(connectors, conn)
}
if err := rows.Err(); err != nil {
return nil, err
return nil, fmt.Errorf("scan: %s", err)
}
return connectors, nil
}

View File

@@ -2,14 +2,15 @@
package sql
import (
"context"
"database/sql"
"regexp"
"time"
"github.com/lib/pq"
"github.com/sirupsen/logrus"
// import third party drivers
_ "github.com/lib/pq"
_ "github.com/mattn/go-sqlite3"
)
@@ -39,31 +40,66 @@ func matchLiteral(s string) *regexp.Regexp {
return regexp.MustCompile(`\b` + regexp.QuoteMeta(s) + `\b`)
}
// Detect a serialization failure, which should trigger retrying the
// transaction according to PostgreSQL docs:
//
// https://www.postgresql.org/docs/current/transaction-iso.html#XACT-SERIALIZABLE
//
// "applications using this level must be prepared to retry transactions due to
// serialization failures"
func isRetryableSerializationFailure(err error) bool {
if pqErr, ok := err.(*pq.Error); ok {
return pqErr.Code.Name() == "serialization_failure"
}
return false
}
var (
// The "github.com/lib/pq" driver is the default flavor. All others are
// translations of this.
flavorPostgres = flavor{
// The default behavior for Postgres transactions is consistent reads, not consistent writes.
// For each transaction opened, ensure it has the correct isolation level.
// The default behavior for Postgres transactions is consistent reads, not
// consistent writes. For each transaction opened, ensure it has the
// correct isolation level.
//
// See: https://www.postgresql.org/docs/9.3/static/sql-set-transaction.html
//
// NOTE(ericchiang): For some reason using `SET SESSION CHARACTERISTICS AS TRANSACTION` at a
// session level didn't work for some edge cases. Might be something worth exploring.
// Be careful not to wrap sql errors in the callback 'fn', otherwise
// serialization failures will not be detected and retried.
executeTx: func(db *sql.DB, fn func(sqlTx *sql.Tx) error) error {
tx, err := db.Begin()
if err != nil {
return err
}
defer tx.Rollback()
ctx, cancel := context.WithCancel(context.TODO())
defer cancel()
if _, err := tx.Exec(`SET TRANSACTION ISOLATION LEVEL SERIALIZABLE;`); err != nil {
return err
opts := &sql.TxOptions{
Isolation: sql.LevelSerializable,
}
if err := fn(tx); err != nil {
return err
for {
tx, err := db.BeginTx(ctx, opts)
if err != nil {
return err
}
if err := fn(tx); err != nil {
if isRetryableSerializationFailure(err) {
continue
}
return err
}
err = tx.Commit()
if err != nil {
if isRetryableSerializationFailure(err) {
continue
}
return err
}
return nil
}
return tx.Commit()
},
supportsTimezones: true,