Revert "retry on serialization errors"

This commit is contained in:
Stephan Renatus
2018-11-29 08:24:13 +01:00
committed by GitHub
parent f3acec0b1b
commit 8f3cca7ba4
20 changed files with 463 additions and 1355 deletions

View File

@@ -134,7 +134,7 @@ func (c *conn) CreateAuthRequest(a storage.AuthRequest) error {
}
func (c *conn) UpdateAuthRequest(id string, updater func(a storage.AuthRequest) (storage.AuthRequest, error)) error {
err := c.ExecTx(func(tx *trans) error {
return c.ExecTx(func(tx *trans) error {
r, err := getAuthRequest(tx, id)
if err != nil {
return err
@@ -144,7 +144,6 @@ func (c *conn) UpdateAuthRequest(id string, updater func(a storage.AuthRequest)
if err != nil {
return err
}
_, err = tx.Exec(`
update auth_request
set
@@ -164,31 +163,21 @@ func (c *conn) UpdateAuthRequest(id string, updater func(a storage.AuthRequest)
a.ConnectorID, a.ConnectorData,
a.Expiry, r.ID,
)
return err
if err != nil {
return fmt.Errorf("update auth request: %v", err)
}
return nil
})
if err != nil {
return fmt.Errorf("update auth request: %v", err)
}
return nil
}
func (c *conn) GetAuthRequest(id string) (storage.AuthRequest, error) {
req, err := getAuthRequest(c, id)
if err != nil {
if err == sql.ErrNoRows {
return storage.AuthRequest{}, storage.ErrNotFound
}
return storage.AuthRequest{}, fmt.Errorf("select auth request: %v", err)
}
return req, nil
return getAuthRequest(c, id)
}
func getAuthRequest(q querier, id string) (a storage.AuthRequest, err error) {
err = q.QueryRow(`
select
select
id, client_id, response_types, scopes, redirect_uri, nonce, state,
force_approval_prompt, logged_in,
claims_user_id, claims_username, claims_email, claims_email_verified,
@@ -203,7 +192,10 @@ func getAuthRequest(q querier, id string) (a storage.AuthRequest, err error) {
&a.ConnectorID, &a.ConnectorData, &a.Expiry,
)
if err != nil {
return a, err
if err == sql.ErrNoRows {
return a, storage.ErrNotFound
}
return a, fmt.Errorf("select auth request: %v", err)
}
return a, nil
}
@@ -277,22 +269,20 @@ func (c *conn) CreateRefresh(r storage.RefreshToken) error {
if c.alreadyExistsCheck(err) {
return storage.ErrAlreadyExists
}
return fmt.Errorf("insert refresh token: %v", err)
return fmt.Errorf("insert refresh_token: %v", err)
}
return nil
}
func (c *conn) UpdateRefreshToken(id string, updater func(old storage.RefreshToken) (storage.RefreshToken, error)) error {
err := c.ExecTx(func(tx *trans) error {
return c.ExecTx(func(tx *trans) error {
r, err := getRefresh(tx, id)
if err != nil {
return err
}
if r, err = updater(r); err != nil {
return err
}
_, err = tx.Exec(`
update refresh_token
set
@@ -318,25 +308,15 @@ func (c *conn) UpdateRefreshToken(id string, updater func(old storage.RefreshTok
r.ConnectorID, r.ConnectorData,
r.Token, r.CreatedAt, r.LastUsed, id,
)
return err
if err != nil {
return fmt.Errorf("update refresh token: %v", err)
}
return nil
})
if err != nil {
return fmt.Errorf("update refresh token: %v", err)
}
return nil
}
func (c *conn) GetRefresh(id string) (storage.RefreshToken, error) {
req, err := getRefresh(c, id)
if err != nil {
if err == sql.ErrNoRows {
return storage.RefreshToken{}, storage.ErrNotFound
}
return storage.RefreshToken{}, fmt.Errorf("get refresh token: %v", err)
}
return req, nil
return getRefresh(c, id)
}
func getRefresh(q querier, id string) (storage.RefreshToken, error) {
@@ -362,15 +342,14 @@ func (c *conn) ListRefreshTokens() ([]storage.RefreshToken, error) {
from refresh_token;
`)
if err != nil {
return nil, fmt.Errorf("select refresh tokens: %v", err)
return nil, fmt.Errorf("query: %v", err)
}
var tokens []storage.RefreshToken
for rows.Next() {
r, err := scanRefresh(rows)
if err != nil {
return nil, fmt.Errorf("scan refresh token: %s", err)
return nil, err
}
tokens = append(tokens, r)
}
if err := rows.Err(); err != nil {
@@ -388,7 +367,10 @@ func scanRefresh(s scanner) (r storage.RefreshToken, err error) {
&r.Token, &r.CreatedAt, &r.LastUsed,
)
if err != nil {
return r, err
if err == sql.ErrNoRows {
return r, storage.ErrNotFound
}
return r, fmt.Errorf("scan refresh_token: %v", err)
}
return r, nil
}
@@ -399,11 +381,12 @@ func (c *conn) UpdateKeys(updater func(old storage.Keys) (storage.Keys, error))
// TODO(ericchiang): errors may cause a transaction be rolled back by the SQL
// server. Test this, and consider adding a COUNT() command beforehand.
old, err := getKeys(tx)
if err == sql.ErrNoRows {
if err != nil {
if err != storage.ErrNotFound {
return fmt.Errorf("get keys: %v", err)
}
firstUpdate = true
old = storage.Keys{}
} else if err != nil {
return err
}
nk, err := updater(old)
@@ -422,12 +405,12 @@ func (c *conn) UpdateKeys(updater func(old storage.Keys) (storage.Keys, error))
encoder(nk.SigningKeyPub), nk.NextRotation,
)
if err != nil {
return err
return fmt.Errorf("insert: %v", err)
}
} else {
_, err = tx.Exec(`
update keys
set
set
verification_keys = $1,
signing_key = $2,
signing_key_pub = $3,
@@ -438,24 +421,15 @@ func (c *conn) UpdateKeys(updater func(old storage.Keys) (storage.Keys, error))
encoder(nk.SigningKeyPub), nk.NextRotation, keysRowID,
)
if err != nil {
return err
return fmt.Errorf("update: %v", err)
}
}
return nil
})
}
func (c *conn) GetKeys() (storage.Keys, error) {
keys, err := getKeys(c)
if err != nil {
if err == sql.ErrNoRows {
return storage.Keys{}, storage.ErrNotFound
}
return storage.Keys{}, fmt.Errorf("select keys: %s", err)
}
return keys, nil
func (c *conn) GetKeys() (keys storage.Keys, err error) {
return getKeys(c)
}
func getKeys(q querier) (keys storage.Keys, err error) {
@@ -469,18 +443,20 @@ func getKeys(q querier) (keys storage.Keys, err error) {
decoder(&keys.SigningKeyPub), &keys.NextRotation,
)
if err != nil {
return keys, err
if err == sql.ErrNoRows {
return keys, storage.ErrNotFound
}
return keys, fmt.Errorf("query keys: %v", err)
}
return keys, nil
}
func (c *conn) UpdateClient(id string, updater func(old storage.Client) (storage.Client, error)) error {
err := c.ExecTx(func(tx *trans) error {
return c.ExecTx(func(tx *trans) error {
cli, err := getClient(tx, id)
if err != nil {
return err
}
nc, err := updater(cli)
if err != nil {
return err
@@ -498,13 +474,11 @@ func (c *conn) UpdateClient(id string, updater func(old storage.Client) (storage
where id = $7;
`, nc.Secret, encoder(nc.RedirectURIs), encoder(nc.TrustedPeers), nc.Public, nc.Name, nc.LogoURL, id,
)
return err
if err != nil {
return fmt.Errorf("update client: %v", err)
}
return nil
})
if err != nil {
return fmt.Errorf("update client: %v", err)
}
return nil
}
func (c *conn) CreateClient(cli storage.Client) error {
@@ -535,16 +509,7 @@ func getClient(q querier, id string) (storage.Client, error) {
}
func (c *conn) GetClient(id string) (storage.Client, error) {
client, err := getClient(c, id)
if err != nil {
if err == sql.ErrNoRows {
return storage.Client{}, storage.ErrNotFound
}
return storage.Client{}, fmt.Errorf("select client: %v", err)
}
return client, nil
return getClient(c, id)
}
func (c *conn) ListClients() ([]storage.Client, error) {
@@ -560,12 +525,12 @@ func (c *conn) ListClients() ([]storage.Client, error) {
for rows.Next() {
cli, err := scanClient(rows)
if err != nil {
return nil, fmt.Errorf("scan client: %s", err)
return nil, err
}
clients = append(clients, cli)
}
if err := rows.Err(); err != nil {
return nil, fmt.Errorf("scan: %s", err)
return nil, err
}
return clients, nil
}
@@ -576,7 +541,10 @@ func scanClient(s scanner) (cli storage.Client, err error) {
&cli.Public, &cli.Name, &cli.LogoURL,
)
if err != nil {
return cli, err
if err == sql.ErrNoRows {
return cli, storage.ErrNotFound
}
return cli, fmt.Errorf("get client: %v", err)
}
return cli, nil
}
@@ -603,7 +571,7 @@ func (c *conn) CreatePassword(p storage.Password) error {
}
func (c *conn) UpdatePassword(email string, updater func(p storage.Password) (storage.Password, error)) error {
err := c.ExecTx(func(tx *trans) error {
return c.ExecTx(func(tx *trans) error {
p, err := getPassword(tx, email)
if err != nil {
return err
@@ -613,7 +581,6 @@ func (c *conn) UpdatePassword(email string, updater func(p storage.Password) (st
if err != nil {
return err
}
_, err = tx.Exec(`
update password
set
@@ -622,25 +589,15 @@ func (c *conn) UpdatePassword(email string, updater func(p storage.Password) (st
`,
np.Hash, np.Username, np.UserID, p.Email,
)
return err
if err != nil {
return fmt.Errorf("update password: %v", err)
}
return nil
})
if err != nil {
return fmt.Errorf("update password: %v", err)
}
return nil
}
func (c *conn) GetPassword(email string) (storage.Password, error) {
pass, err := getPassword(c, email)
if err != nil {
if err == sql.ErrNoRows {
return storage.Password{}, storage.ErrNotFound
}
return storage.Password{}, fmt.Errorf("get password: %s", err)
}
return pass, nil
return getPassword(c, email)
}
func getPassword(q querier, email string) (p storage.Password, err error) {
@@ -665,12 +622,12 @@ func (c *conn) ListPasswords() ([]storage.Password, error) {
for rows.Next() {
p, err := scanPassword(rows)
if err != nil {
return nil, fmt.Errorf("scan password: %s", err)
return nil, err
}
passwords = append(passwords, p)
}
if err := rows.Err(); err != nil {
return nil, fmt.Errorf("scan: %s", err)
return nil, err
}
return passwords, nil
}
@@ -680,7 +637,10 @@ func scanPassword(s scanner) (p storage.Password, err error) {
&p.Email, &p.Hash, &p.Username, &p.UserID,
)
if err != nil {
return p, err
if err == sql.ErrNoRows {
return p, storage.ErrNotFound
}
return p, fmt.Errorf("select password: %v", err)
}
return p, nil
}
@@ -706,7 +666,7 @@ func (c *conn) CreateOfflineSessions(s storage.OfflineSessions) error {
}
func (c *conn) UpdateOfflineSessions(userID string, connID string, updater func(s storage.OfflineSessions) (storage.OfflineSessions, error)) error {
err := c.ExecTx(func(tx *trans) error {
return c.ExecTx(func(tx *trans) error {
s, err := getOfflineSessions(tx, userID, connID)
if err != nil {
return err
@@ -716,7 +676,6 @@ func (c *conn) UpdateOfflineSessions(userID string, connID string, updater func(
if err != nil {
return err
}
_, err = tx.Exec(`
update offline_session
set
@@ -725,26 +684,15 @@ func (c *conn) UpdateOfflineSessions(userID string, connID string, updater func(
`,
encoder(newSession.Refresh), s.UserID, s.ConnID,
)
return err
if err != nil {
return fmt.Errorf("update offline session: %v", err)
}
return nil
})
if err != nil {
return fmt.Errorf("update offline session: %v", err)
}
return nil
}
func (c *conn) GetOfflineSessions(userID string, connID string) (storage.OfflineSessions, error) {
sessions, err := getOfflineSessions(c, userID, connID)
if err != nil {
if err == sql.ErrNoRows {
return storage.OfflineSessions{}, storage.ErrNotFound
}
return storage.OfflineSessions{}, fmt.Errorf("get offline sessions: %s", err)
}
return sessions, nil
return getOfflineSessions(c, userID, connID)
}
func getOfflineSessions(q querier, userID string, connID string) (storage.OfflineSessions, error) {
@@ -761,7 +709,10 @@ func scanOfflineSessions(s scanner) (o storage.OfflineSessions, err error) {
&o.UserID, &o.ConnID, decoder(&o.Refresh),
)
if err != nil {
return o, err
if err == sql.ErrNoRows {
return o, storage.ErrNotFound
}
return o, fmt.Errorf("select offline session: %v", err)
}
return o, nil
}
@@ -787,7 +738,7 @@ func (c *conn) CreateConnector(connector storage.Connector) error {
}
func (c *conn) UpdateConnector(id string, updater func(s storage.Connector) (storage.Connector, error)) error {
err := c.ExecTx(func(tx *trans) error {
return c.ExecTx(func(tx *trans) error {
connector, err := getConnector(tx, id)
if err != nil {
return err
@@ -797,10 +748,9 @@ func (c *conn) UpdateConnector(id string, updater func(s storage.Connector) (sto
if err != nil {
return err
}
_, err = tx.Exec(`
update connector
set
set
type = $1,
name = $2,
resource_version = $3,
@@ -809,26 +759,15 @@ func (c *conn) UpdateConnector(id string, updater func(s storage.Connector) (sto
`,
newConn.Type, newConn.Name, newConn.ResourceVersion, newConn.Config, connector.ID,
)
return err
if err != nil {
return fmt.Errorf("update connector: %v", err)
}
return nil
})
if err != nil {
return fmt.Errorf("update connector: %v", err)
}
return nil
}
func (c *conn) GetConnector(id string) (storage.Connector, error) {
connector, err := getConnector(c, id)
if err != nil {
if err == sql.ErrNoRows {
return storage.Connector{}, storage.ErrNotFound
}
return storage.Connector{}, fmt.Errorf("get connector: %s", err)
}
return connector, nil
return getConnector(c, id)
}
func getConnector(q querier, id string) (storage.Connector, error) {
@@ -845,7 +784,10 @@ func scanConnector(s scanner) (c storage.Connector, err error) {
&c.ID, &c.Type, &c.Name, &c.ResourceVersion, &c.Config,
)
if err != nil {
return c, err
if err == sql.ErrNoRows {
return c, storage.ErrNotFound
}
return c, fmt.Errorf("select connector: %v", err)
}
return c, nil
}
@@ -863,12 +805,12 @@ func (c *conn) ListConnectors() ([]storage.Connector, error) {
for rows.Next() {
conn, err := scanConnector(rows)
if err != nil {
return nil, fmt.Errorf("scan connector: %s", err)
return nil, err
}
connectors = append(connectors, conn)
}
if err := rows.Err(); err != nil {
return nil, fmt.Errorf("scan: %s", err)
return nil, err
}
return connectors, nil
}