postgres: refactor error handling to fix retrying

prior to this change, many of the functions in the ExecTx callback would
wrap the error before returning it. this made it impossible to check
for the error code.

instead, the error wrapping has been moved to be external to the
`ExecTx` callback, so that the error code can be checked and
serialization failures can be retried.
This commit is contained in:
Alex Suraci 2018-11-19 11:34:45 -05:00
parent 5d67da1472
commit 587081a643
2 changed files with 152 additions and 88 deletions

View File

@ -134,7 +134,7 @@ func (c *conn) CreateAuthRequest(a storage.AuthRequest) error {
} }
func (c *conn) UpdateAuthRequest(id string, updater func(a storage.AuthRequest) (storage.AuthRequest, error)) error { func (c *conn) UpdateAuthRequest(id string, updater func(a storage.AuthRequest) (storage.AuthRequest, error)) error {
return c.ExecTx(func(tx *trans) error { err := c.ExecTx(func(tx *trans) error {
r, err := getAuthRequest(tx, id) r, err := getAuthRequest(tx, id)
if err != nil { if err != nil {
return err return err
@ -144,6 +144,7 @@ func (c *conn) UpdateAuthRequest(id string, updater func(a storage.AuthRequest)
if err != nil { if err != nil {
return err return err
} }
_, err = tx.Exec(` _, err = tx.Exec(`
update auth_request update auth_request
set set
@ -163,16 +164,26 @@ func (c *conn) UpdateAuthRequest(id string, updater func(a storage.AuthRequest)
a.ConnectorID, a.ConnectorData, a.ConnectorID, a.ConnectorData,
a.Expiry, r.ID, a.Expiry, r.ID,
) )
return err
})
if err != nil { if err != nil {
return fmt.Errorf("update auth request: %v", err) return fmt.Errorf("update auth request: %v", err)
} }
return nil
})
return nil
} }
func (c *conn) GetAuthRequest(id string) (storage.AuthRequest, error) { func (c *conn) GetAuthRequest(id string) (storage.AuthRequest, error) {
return getAuthRequest(c, id) req, err := getAuthRequest(c, id)
if err != nil {
if err == sql.ErrNoRows {
return storage.AuthRequest{}, storage.ErrNotFound
}
return storage.AuthRequest{}, fmt.Errorf("select auth request: %v", err)
}
return req, nil
} }
func getAuthRequest(q querier, id string) (a storage.AuthRequest, err error) { func getAuthRequest(q querier, id string) (a storage.AuthRequest, err error) {
@ -192,10 +203,7 @@ func getAuthRequest(q querier, id string) (a storage.AuthRequest, err error) {
&a.ConnectorID, &a.ConnectorData, &a.Expiry, &a.ConnectorID, &a.ConnectorData, &a.Expiry,
) )
if err != nil { if err != nil {
if err == sql.ErrNoRows { return a, err
return a, storage.ErrNotFound
}
return a, fmt.Errorf("select auth request: %v", err)
} }
return a, nil return a, nil
} }
@ -269,20 +277,22 @@ func (c *conn) CreateRefresh(r storage.RefreshToken) error {
if c.alreadyExistsCheck(err) { if c.alreadyExistsCheck(err) {
return storage.ErrAlreadyExists return storage.ErrAlreadyExists
} }
return fmt.Errorf("insert refresh_token: %v", err) return fmt.Errorf("insert refresh token: %v", err)
} }
return nil return nil
} }
func (c *conn) UpdateRefreshToken(id string, updater func(old storage.RefreshToken) (storage.RefreshToken, error)) error { func (c *conn) UpdateRefreshToken(id string, updater func(old storage.RefreshToken) (storage.RefreshToken, error)) error {
return c.ExecTx(func(tx *trans) error { err := c.ExecTx(func(tx *trans) error {
r, err := getRefresh(tx, id) r, err := getRefresh(tx, id)
if err != nil { if err != nil {
return err return err
} }
if r, err = updater(r); err != nil { if r, err = updater(r); err != nil {
return err return err
} }
_, err = tx.Exec(` _, err = tx.Exec(`
update refresh_token update refresh_token
set set
@ -308,15 +318,25 @@ func (c *conn) UpdateRefreshToken(id string, updater func(old storage.RefreshTok
r.ConnectorID, r.ConnectorData, r.ConnectorID, r.ConnectorData,
r.Token, r.CreatedAt, r.LastUsed, id, r.Token, r.CreatedAt, r.LastUsed, id,
) )
return err
})
if err != nil { if err != nil {
return fmt.Errorf("update refresh token: %v", err) return fmt.Errorf("update refresh token: %v", err)
} }
return nil return nil
})
} }
func (c *conn) GetRefresh(id string) (storage.RefreshToken, error) { func (c *conn) GetRefresh(id string) (storage.RefreshToken, error) {
return getRefresh(c, id) req, err := getRefresh(c, id)
if err != nil {
if err == sql.ErrNoRows {
return storage.RefreshToken{}, storage.ErrNotFound
}
return storage.RefreshToken{}, fmt.Errorf("get refresh token: %v", err)
}
return req, nil
} }
func getRefresh(q querier, id string) (storage.RefreshToken, error) { func getRefresh(q querier, id string) (storage.RefreshToken, error) {
@ -342,14 +362,15 @@ func (c *conn) ListRefreshTokens() ([]storage.RefreshToken, error) {
from refresh_token; from refresh_token;
`) `)
if err != nil { if err != nil {
return nil, fmt.Errorf("query: %v", err) return nil, fmt.Errorf("select refresh tokens: %v", err)
} }
var tokens []storage.RefreshToken var tokens []storage.RefreshToken
for rows.Next() { for rows.Next() {
r, err := scanRefresh(rows) r, err := scanRefresh(rows)
if err != nil { if err != nil {
return nil, err return nil, fmt.Errorf("scan refresh token: %s", err)
} }
tokens = append(tokens, r) tokens = append(tokens, r)
} }
if err := rows.Err(); err != nil { if err := rows.Err(); err != nil {
@ -367,10 +388,7 @@ func scanRefresh(s scanner) (r storage.RefreshToken, err error) {
&r.Token, &r.CreatedAt, &r.LastUsed, &r.Token, &r.CreatedAt, &r.LastUsed,
) )
if err != nil { if err != nil {
if err == sql.ErrNoRows { return r, err
return r, storage.ErrNotFound
}
return r, fmt.Errorf("scan refresh_token: %v", err)
} }
return r, nil return r, nil
} }
@ -381,12 +399,11 @@ func (c *conn) UpdateKeys(updater func(old storage.Keys) (storage.Keys, error))
// TODO(ericchiang): errors may cause a transaction be rolled back by the SQL // TODO(ericchiang): errors may cause a transaction be rolled back by the SQL
// server. Test this, and consider adding a COUNT() command beforehand. // server. Test this, and consider adding a COUNT() command beforehand.
old, err := getKeys(tx) old, err := getKeys(tx)
if err != nil { if err == sql.ErrNoRows {
if err != storage.ErrNotFound {
return fmt.Errorf("get keys: %v", err)
}
firstUpdate = true firstUpdate = true
old = storage.Keys{} old = storage.Keys{}
} else if err != nil {
return err
} }
nk, err := updater(old) nk, err := updater(old)
@ -405,7 +422,7 @@ func (c *conn) UpdateKeys(updater func(old storage.Keys) (storage.Keys, error))
encoder(nk.SigningKeyPub), nk.NextRotation, encoder(nk.SigningKeyPub), nk.NextRotation,
) )
if err != nil { if err != nil {
return fmt.Errorf("insert: %v", err) return err
} }
} else { } else {
_, err = tx.Exec(` _, err = tx.Exec(`
@ -421,15 +438,24 @@ func (c *conn) UpdateKeys(updater func(old storage.Keys) (storage.Keys, error))
encoder(nk.SigningKeyPub), nk.NextRotation, keysRowID, encoder(nk.SigningKeyPub), nk.NextRotation, keysRowID,
) )
if err != nil { if err != nil {
return fmt.Errorf("update: %v", err) return err
} }
} }
return nil return nil
}) })
} }
func (c *conn) GetKeys() (keys storage.Keys, err error) { func (c *conn) GetKeys() (storage.Keys, error) {
return getKeys(c) keys, err := getKeys(c)
if err != nil {
if err == sql.ErrNoRows {
return storage.Keys{}, storage.ErrNotFound
}
return storage.Keys{}, fmt.Errorf("select keys: %s", err)
}
return keys, nil
} }
func getKeys(q querier) (keys storage.Keys, err error) { func getKeys(q querier) (keys storage.Keys, err error) {
@ -443,20 +469,18 @@ func getKeys(q querier) (keys storage.Keys, err error) {
decoder(&keys.SigningKeyPub), &keys.NextRotation, decoder(&keys.SigningKeyPub), &keys.NextRotation,
) )
if err != nil { if err != nil {
if err == sql.ErrNoRows { return keys, err
return keys, storage.ErrNotFound
}
return keys, fmt.Errorf("query keys: %v", err)
} }
return keys, nil return keys, nil
} }
func (c *conn) UpdateClient(id string, updater func(old storage.Client) (storage.Client, error)) error { func (c *conn) UpdateClient(id string, updater func(old storage.Client) (storage.Client, error)) error {
return c.ExecTx(func(tx *trans) error { err := c.ExecTx(func(tx *trans) error {
cli, err := getClient(tx, id) cli, err := getClient(tx, id)
if err != nil { if err != nil {
return err return err
} }
nc, err := updater(cli) nc, err := updater(cli)
if err != nil { if err != nil {
return err return err
@ -474,11 +498,13 @@ func (c *conn) UpdateClient(id string, updater func(old storage.Client) (storage
where id = $7; where id = $7;
`, nc.Secret, encoder(nc.RedirectURIs), encoder(nc.TrustedPeers), nc.Public, nc.Name, nc.LogoURL, id, `, nc.Secret, encoder(nc.RedirectURIs), encoder(nc.TrustedPeers), nc.Public, nc.Name, nc.LogoURL, id,
) )
return err
})
if err != nil { if err != nil {
return fmt.Errorf("update client: %v", err) return fmt.Errorf("update client: %v", err)
} }
return nil return nil
})
} }
func (c *conn) CreateClient(cli storage.Client) error { func (c *conn) CreateClient(cli storage.Client) error {
@ -509,7 +535,16 @@ func getClient(q querier, id string) (storage.Client, error) {
} }
func (c *conn) GetClient(id string) (storage.Client, error) { func (c *conn) GetClient(id string) (storage.Client, error) {
return getClient(c, id) client, err := getClient(c, id)
if err != nil {
if err == sql.ErrNoRows {
return storage.Client{}, storage.ErrNotFound
}
return storage.Client{}, fmt.Errorf("select client: %v", err)
}
return client, nil
} }
func (c *conn) ListClients() ([]storage.Client, error) { func (c *conn) ListClients() ([]storage.Client, error) {
@ -525,12 +560,12 @@ func (c *conn) ListClients() ([]storage.Client, error) {
for rows.Next() { for rows.Next() {
cli, err := scanClient(rows) cli, err := scanClient(rows)
if err != nil { if err != nil {
return nil, err return nil, fmt.Errorf("scan client: %s", err)
} }
clients = append(clients, cli) clients = append(clients, cli)
} }
if err := rows.Err(); err != nil { if err := rows.Err(); err != nil {
return nil, err return nil, fmt.Errorf("scan: %s", err)
} }
return clients, nil return clients, nil
} }
@ -541,10 +576,7 @@ func scanClient(s scanner) (cli storage.Client, err error) {
&cli.Public, &cli.Name, &cli.LogoURL, &cli.Public, &cli.Name, &cli.LogoURL,
) )
if err != nil { if err != nil {
if err == sql.ErrNoRows { return cli, err
return cli, storage.ErrNotFound
}
return cli, fmt.Errorf("get client: %v", err)
} }
return cli, nil return cli, nil
} }
@ -571,7 +603,7 @@ func (c *conn) CreatePassword(p storage.Password) error {
} }
func (c *conn) UpdatePassword(email string, updater func(p storage.Password) (storage.Password, error)) error { func (c *conn) UpdatePassword(email string, updater func(p storage.Password) (storage.Password, error)) error {
return c.ExecTx(func(tx *trans) error { err := c.ExecTx(func(tx *trans) error {
p, err := getPassword(tx, email) p, err := getPassword(tx, email)
if err != nil { if err != nil {
return err return err
@ -581,6 +613,7 @@ func (c *conn) UpdatePassword(email string, updater func(p storage.Password) (st
if err != nil { if err != nil {
return err return err
} }
_, err = tx.Exec(` _, err = tx.Exec(`
update password update password
set set
@ -589,15 +622,25 @@ func (c *conn) UpdatePassword(email string, updater func(p storage.Password) (st
`, `,
np.Hash, np.Username, np.UserID, p.Email, np.Hash, np.Username, np.UserID, p.Email,
) )
return err
})
if err != nil { if err != nil {
return fmt.Errorf("update password: %v", err) return fmt.Errorf("update password: %v", err)
} }
return nil return nil
})
} }
func (c *conn) GetPassword(email string) (storage.Password, error) { func (c *conn) GetPassword(email string) (storage.Password, error) {
return getPassword(c, email) pass, err := getPassword(c, email)
if err != nil {
if err == sql.ErrNoRows {
return storage.Password{}, storage.ErrNotFound
}
return storage.Password{}, fmt.Errorf("get password: %s", err)
}
return pass, nil
} }
func getPassword(q querier, email string) (p storage.Password, err error) { func getPassword(q querier, email string) (p storage.Password, err error) {
@ -622,12 +665,12 @@ func (c *conn) ListPasswords() ([]storage.Password, error) {
for rows.Next() { for rows.Next() {
p, err := scanPassword(rows) p, err := scanPassword(rows)
if err != nil { if err != nil {
return nil, err return nil, fmt.Errorf("scan password: %s", err)
} }
passwords = append(passwords, p) passwords = append(passwords, p)
} }
if err := rows.Err(); err != nil { if err := rows.Err(); err != nil {
return nil, err return nil, fmt.Errorf("scan: %s", err)
} }
return passwords, nil return passwords, nil
} }
@ -637,10 +680,7 @@ func scanPassword(s scanner) (p storage.Password, err error) {
&p.Email, &p.Hash, &p.Username, &p.UserID, &p.Email, &p.Hash, &p.Username, &p.UserID,
) )
if err != nil { if err != nil {
if err == sql.ErrNoRows { return p, err
return p, storage.ErrNotFound
}
return p, fmt.Errorf("select password: %v", err)
} }
return p, nil return p, nil
} }
@ -666,7 +706,7 @@ func (c *conn) CreateOfflineSessions(s storage.OfflineSessions) error {
} }
func (c *conn) UpdateOfflineSessions(userID string, connID string, updater func(s storage.OfflineSessions) (storage.OfflineSessions, error)) error { func (c *conn) UpdateOfflineSessions(userID string, connID string, updater func(s storage.OfflineSessions) (storage.OfflineSessions, error)) error {
return c.ExecTx(func(tx *trans) error { err := c.ExecTx(func(tx *trans) error {
s, err := getOfflineSessions(tx, userID, connID) s, err := getOfflineSessions(tx, userID, connID)
if err != nil { if err != nil {
return err return err
@ -676,6 +716,7 @@ func (c *conn) UpdateOfflineSessions(userID string, connID string, updater func(
if err != nil { if err != nil {
return err return err
} }
_, err = tx.Exec(` _, err = tx.Exec(`
update offline_session update offline_session
set set
@ -684,15 +725,26 @@ func (c *conn) UpdateOfflineSessions(userID string, connID string, updater func(
`, `,
encoder(newSession.Refresh), s.UserID, s.ConnID, encoder(newSession.Refresh), s.UserID, s.ConnID,
) )
return err
})
if err != nil { if err != nil {
return fmt.Errorf("update offline session: %v", err) return fmt.Errorf("update offline session: %v", err)
} }
return nil return nil
})
} }
func (c *conn) GetOfflineSessions(userID string, connID string) (storage.OfflineSessions, error) { func (c *conn) GetOfflineSessions(userID string, connID string) (storage.OfflineSessions, error) {
return getOfflineSessions(c, userID, connID) sessions, err := getOfflineSessions(c, userID, connID)
if err != nil {
if err == sql.ErrNoRows {
return storage.OfflineSessions{}, storage.ErrNotFound
}
return storage.OfflineSessions{}, fmt.Errorf("get offline sessions: %s", err)
}
return sessions, nil
} }
func getOfflineSessions(q querier, userID string, connID string) (storage.OfflineSessions, error) { func getOfflineSessions(q querier, userID string, connID string) (storage.OfflineSessions, error) {
@ -709,10 +761,7 @@ func scanOfflineSessions(s scanner) (o storage.OfflineSessions, err error) {
&o.UserID, &o.ConnID, decoder(&o.Refresh), &o.UserID, &o.ConnID, decoder(&o.Refresh),
) )
if err != nil { if err != nil {
if err == sql.ErrNoRows { return o, err
return o, storage.ErrNotFound
}
return o, fmt.Errorf("select offline session: %v", err)
} }
return o, nil return o, nil
} }
@ -738,7 +787,7 @@ func (c *conn) CreateConnector(connector storage.Connector) error {
} }
func (c *conn) UpdateConnector(id string, updater func(s storage.Connector) (storage.Connector, error)) error { func (c *conn) UpdateConnector(id string, updater func(s storage.Connector) (storage.Connector, error)) error {
return c.ExecTx(func(tx *trans) error { err := c.ExecTx(func(tx *trans) error {
connector, err := getConnector(tx, id) connector, err := getConnector(tx, id)
if err != nil { if err != nil {
return err return err
@ -748,6 +797,7 @@ func (c *conn) UpdateConnector(id string, updater func(s storage.Connector) (sto
if err != nil { if err != nil {
return err return err
} }
_, err = tx.Exec(` _, err = tx.Exec(`
update connector update connector
set set
@ -759,15 +809,26 @@ func (c *conn) UpdateConnector(id string, updater func(s storage.Connector) (sto
`, `,
newConn.Type, newConn.Name, newConn.ResourceVersion, newConn.Config, connector.ID, newConn.Type, newConn.Name, newConn.ResourceVersion, newConn.Config, connector.ID,
) )
return err
})
if err != nil { if err != nil {
return fmt.Errorf("update connector: %v", err) return fmt.Errorf("update connector: %v", err)
} }
return nil return nil
})
} }
func (c *conn) GetConnector(id string) (storage.Connector, error) { func (c *conn) GetConnector(id string) (storage.Connector, error) {
return getConnector(c, id) connector, err := getConnector(c, id)
if err != nil {
if err == sql.ErrNoRows {
return storage.Connector{}, storage.ErrNotFound
}
return storage.Connector{}, fmt.Errorf("get connector: %s", err)
}
return connector, nil
} }
func getConnector(q querier, id string) (storage.Connector, error) { func getConnector(q querier, id string) (storage.Connector, error) {
@ -784,10 +845,7 @@ func scanConnector(s scanner) (c storage.Connector, err error) {
&c.ID, &c.Type, &c.Name, &c.ResourceVersion, &c.Config, &c.ID, &c.Type, &c.Name, &c.ResourceVersion, &c.Config,
) )
if err != nil { if err != nil {
if err == sql.ErrNoRows { return c, err
return c, storage.ErrNotFound
}
return c, fmt.Errorf("select connector: %v", err)
} }
return c, nil return c, nil
} }
@ -805,12 +863,12 @@ func (c *conn) ListConnectors() ([]storage.Connector, error) {
for rows.Next() { for rows.Next() {
conn, err := scanConnector(rows) conn, err := scanConnector(rows)
if err != nil { if err != nil {
return nil, err return nil, fmt.Errorf("scan connector: %s", err)
} }
connectors = append(connectors, conn) connectors = append(connectors, conn)
} }
if err := rows.Err(); err != nil { if err := rows.Err(); err != nil {
return nil, err return nil, fmt.Errorf("scan: %s", err)
} }
return connectors, nil return connectors, nil
} }

View File

@ -44,13 +44,14 @@ var (
// The "github.com/lib/pq" driver is the default flavor. All others are // The "github.com/lib/pq" driver is the default flavor. All others are
// translations of this. // translations of this.
flavorPostgres = flavor{ flavorPostgres = flavor{
// The default behavior for Postgres transactions is consistent reads, not consistent writes. // The default behavior for Postgres transactions is consistent reads, not
// For each transaction opened, ensure it has the correct isolation level. // consistent writes. For each transaction opened, ensure it has the
// correct isolation level.
// //
// See: https://www.postgresql.org/docs/9.3/static/sql-set-transaction.html // See: https://www.postgresql.org/docs/9.3/static/sql-set-transaction.html
// //
// NOTE(ericchiang): For some reason using `SET SESSION CHARACTERISTICS AS TRANSACTION` at a // Be careful not to wrap sql errors in the callback 'fn', otherwise
// session level didn't work for some edge cases. Might be something worth exploring. // serialization failures will not be detected and retried.
executeTx: func(db *sql.DB, fn func(sqlTx *sql.Tx) error) error { executeTx: func(db *sql.DB, fn func(sqlTx *sql.Tx) error) error {
ctx, cancel := context.WithCancel(context.TODO()) ctx, cancel := context.WithCancel(context.TODO())
defer cancel() defer cancel()
@ -66,6 +67,11 @@ var (
} }
if err := fn(tx); err != nil { if err := fn(tx); err != nil {
if pqErr, ok := err.(*pq.Error); ok && pqErr.Code.Name() == "serialization_failure" {
// serialization error; retry
continue
}
return err return err
} }