Revert "retry on serialization errors"
This commit is contained in:
@@ -134,7 +134,7 @@ func (c *conn) CreateAuthRequest(a storage.AuthRequest) error {
|
||||
}
|
||||
|
||||
func (c *conn) UpdateAuthRequest(id string, updater func(a storage.AuthRequest) (storage.AuthRequest, error)) error {
|
||||
err := c.ExecTx(func(tx *trans) error {
|
||||
return c.ExecTx(func(tx *trans) error {
|
||||
r, err := getAuthRequest(tx, id)
|
||||
if err != nil {
|
||||
return err
|
||||
@@ -144,7 +144,6 @@ func (c *conn) UpdateAuthRequest(id string, updater func(a storage.AuthRequest)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
_, err = tx.Exec(`
|
||||
update auth_request
|
||||
set
|
||||
@@ -164,31 +163,21 @@ func (c *conn) UpdateAuthRequest(id string, updater func(a storage.AuthRequest)
|
||||
a.ConnectorID, a.ConnectorData,
|
||||
a.Expiry, r.ID,
|
||||
)
|
||||
return err
|
||||
if err != nil {
|
||||
return fmt.Errorf("update auth request: %v", err)
|
||||
}
|
||||
return nil
|
||||
})
|
||||
if err != nil {
|
||||
return fmt.Errorf("update auth request: %v", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *conn) GetAuthRequest(id string) (storage.AuthRequest, error) {
|
||||
req, err := getAuthRequest(c, id)
|
||||
if err != nil {
|
||||
if err == sql.ErrNoRows {
|
||||
return storage.AuthRequest{}, storage.ErrNotFound
|
||||
}
|
||||
|
||||
return storage.AuthRequest{}, fmt.Errorf("select auth request: %v", err)
|
||||
}
|
||||
|
||||
return req, nil
|
||||
return getAuthRequest(c, id)
|
||||
}
|
||||
|
||||
func getAuthRequest(q querier, id string) (a storage.AuthRequest, err error) {
|
||||
err = q.QueryRow(`
|
||||
select
|
||||
select
|
||||
id, client_id, response_types, scopes, redirect_uri, nonce, state,
|
||||
force_approval_prompt, logged_in,
|
||||
claims_user_id, claims_username, claims_email, claims_email_verified,
|
||||
@@ -203,7 +192,10 @@ func getAuthRequest(q querier, id string) (a storage.AuthRequest, err error) {
|
||||
&a.ConnectorID, &a.ConnectorData, &a.Expiry,
|
||||
)
|
||||
if err != nil {
|
||||
return a, err
|
||||
if err == sql.ErrNoRows {
|
||||
return a, storage.ErrNotFound
|
||||
}
|
||||
return a, fmt.Errorf("select auth request: %v", err)
|
||||
}
|
||||
return a, nil
|
||||
}
|
||||
@@ -277,22 +269,20 @@ func (c *conn) CreateRefresh(r storage.RefreshToken) error {
|
||||
if c.alreadyExistsCheck(err) {
|
||||
return storage.ErrAlreadyExists
|
||||
}
|
||||
return fmt.Errorf("insert refresh token: %v", err)
|
||||
return fmt.Errorf("insert refresh_token: %v", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *conn) UpdateRefreshToken(id string, updater func(old storage.RefreshToken) (storage.RefreshToken, error)) error {
|
||||
err := c.ExecTx(func(tx *trans) error {
|
||||
return c.ExecTx(func(tx *trans) error {
|
||||
r, err := getRefresh(tx, id)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if r, err = updater(r); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
_, err = tx.Exec(`
|
||||
update refresh_token
|
||||
set
|
||||
@@ -318,25 +308,15 @@ func (c *conn) UpdateRefreshToken(id string, updater func(old storage.RefreshTok
|
||||
r.ConnectorID, r.ConnectorData,
|
||||
r.Token, r.CreatedAt, r.LastUsed, id,
|
||||
)
|
||||
return err
|
||||
if err != nil {
|
||||
return fmt.Errorf("update refresh token: %v", err)
|
||||
}
|
||||
return nil
|
||||
})
|
||||
if err != nil {
|
||||
return fmt.Errorf("update refresh token: %v", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *conn) GetRefresh(id string) (storage.RefreshToken, error) {
|
||||
req, err := getRefresh(c, id)
|
||||
if err != nil {
|
||||
if err == sql.ErrNoRows {
|
||||
return storage.RefreshToken{}, storage.ErrNotFound
|
||||
}
|
||||
|
||||
return storage.RefreshToken{}, fmt.Errorf("get refresh token: %v", err)
|
||||
}
|
||||
|
||||
return req, nil
|
||||
return getRefresh(c, id)
|
||||
}
|
||||
|
||||
func getRefresh(q querier, id string) (storage.RefreshToken, error) {
|
||||
@@ -362,15 +342,14 @@ func (c *conn) ListRefreshTokens() ([]storage.RefreshToken, error) {
|
||||
from refresh_token;
|
||||
`)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("select refresh tokens: %v", err)
|
||||
return nil, fmt.Errorf("query: %v", err)
|
||||
}
|
||||
var tokens []storage.RefreshToken
|
||||
for rows.Next() {
|
||||
r, err := scanRefresh(rows)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("scan refresh token: %s", err)
|
||||
return nil, err
|
||||
}
|
||||
|
||||
tokens = append(tokens, r)
|
||||
}
|
||||
if err := rows.Err(); err != nil {
|
||||
@@ -388,7 +367,10 @@ func scanRefresh(s scanner) (r storage.RefreshToken, err error) {
|
||||
&r.Token, &r.CreatedAt, &r.LastUsed,
|
||||
)
|
||||
if err != nil {
|
||||
return r, err
|
||||
if err == sql.ErrNoRows {
|
||||
return r, storage.ErrNotFound
|
||||
}
|
||||
return r, fmt.Errorf("scan refresh_token: %v", err)
|
||||
}
|
||||
return r, nil
|
||||
}
|
||||
@@ -399,11 +381,12 @@ func (c *conn) UpdateKeys(updater func(old storage.Keys) (storage.Keys, error))
|
||||
// TODO(ericchiang): errors may cause a transaction be rolled back by the SQL
|
||||
// server. Test this, and consider adding a COUNT() command beforehand.
|
||||
old, err := getKeys(tx)
|
||||
if err == sql.ErrNoRows {
|
||||
if err != nil {
|
||||
if err != storage.ErrNotFound {
|
||||
return fmt.Errorf("get keys: %v", err)
|
||||
}
|
||||
firstUpdate = true
|
||||
old = storage.Keys{}
|
||||
} else if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
nk, err := updater(old)
|
||||
@@ -422,12 +405,12 @@ func (c *conn) UpdateKeys(updater func(old storage.Keys) (storage.Keys, error))
|
||||
encoder(nk.SigningKeyPub), nk.NextRotation,
|
||||
)
|
||||
if err != nil {
|
||||
return err
|
||||
return fmt.Errorf("insert: %v", err)
|
||||
}
|
||||
} else {
|
||||
_, err = tx.Exec(`
|
||||
update keys
|
||||
set
|
||||
set
|
||||
verification_keys = $1,
|
||||
signing_key = $2,
|
||||
signing_key_pub = $3,
|
||||
@@ -438,24 +421,15 @@ func (c *conn) UpdateKeys(updater func(old storage.Keys) (storage.Keys, error))
|
||||
encoder(nk.SigningKeyPub), nk.NextRotation, keysRowID,
|
||||
)
|
||||
if err != nil {
|
||||
return err
|
||||
return fmt.Errorf("update: %v", err)
|
||||
}
|
||||
}
|
||||
return nil
|
||||
})
|
||||
}
|
||||
|
||||
func (c *conn) GetKeys() (storage.Keys, error) {
|
||||
keys, err := getKeys(c)
|
||||
if err != nil {
|
||||
if err == sql.ErrNoRows {
|
||||
return storage.Keys{}, storage.ErrNotFound
|
||||
}
|
||||
|
||||
return storage.Keys{}, fmt.Errorf("select keys: %s", err)
|
||||
}
|
||||
|
||||
return keys, nil
|
||||
func (c *conn) GetKeys() (keys storage.Keys, err error) {
|
||||
return getKeys(c)
|
||||
}
|
||||
|
||||
func getKeys(q querier) (keys storage.Keys, err error) {
|
||||
@@ -469,18 +443,20 @@ func getKeys(q querier) (keys storage.Keys, err error) {
|
||||
decoder(&keys.SigningKeyPub), &keys.NextRotation,
|
||||
)
|
||||
if err != nil {
|
||||
return keys, err
|
||||
if err == sql.ErrNoRows {
|
||||
return keys, storage.ErrNotFound
|
||||
}
|
||||
return keys, fmt.Errorf("query keys: %v", err)
|
||||
}
|
||||
return keys, nil
|
||||
}
|
||||
|
||||
func (c *conn) UpdateClient(id string, updater func(old storage.Client) (storage.Client, error)) error {
|
||||
err := c.ExecTx(func(tx *trans) error {
|
||||
return c.ExecTx(func(tx *trans) error {
|
||||
cli, err := getClient(tx, id)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
nc, err := updater(cli)
|
||||
if err != nil {
|
||||
return err
|
||||
@@ -498,13 +474,11 @@ func (c *conn) UpdateClient(id string, updater func(old storage.Client) (storage
|
||||
where id = $7;
|
||||
`, nc.Secret, encoder(nc.RedirectURIs), encoder(nc.TrustedPeers), nc.Public, nc.Name, nc.LogoURL, id,
|
||||
)
|
||||
return err
|
||||
if err != nil {
|
||||
return fmt.Errorf("update client: %v", err)
|
||||
}
|
||||
return nil
|
||||
})
|
||||
if err != nil {
|
||||
return fmt.Errorf("update client: %v", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *conn) CreateClient(cli storage.Client) error {
|
||||
@@ -535,16 +509,7 @@ func getClient(q querier, id string) (storage.Client, error) {
|
||||
}
|
||||
|
||||
func (c *conn) GetClient(id string) (storage.Client, error) {
|
||||
client, err := getClient(c, id)
|
||||
if err != nil {
|
||||
if err == sql.ErrNoRows {
|
||||
return storage.Client{}, storage.ErrNotFound
|
||||
}
|
||||
|
||||
return storage.Client{}, fmt.Errorf("select client: %v", err)
|
||||
}
|
||||
|
||||
return client, nil
|
||||
return getClient(c, id)
|
||||
}
|
||||
|
||||
func (c *conn) ListClients() ([]storage.Client, error) {
|
||||
@@ -560,12 +525,12 @@ func (c *conn) ListClients() ([]storage.Client, error) {
|
||||
for rows.Next() {
|
||||
cli, err := scanClient(rows)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("scan client: %s", err)
|
||||
return nil, err
|
||||
}
|
||||
clients = append(clients, cli)
|
||||
}
|
||||
if err := rows.Err(); err != nil {
|
||||
return nil, fmt.Errorf("scan: %s", err)
|
||||
return nil, err
|
||||
}
|
||||
return clients, nil
|
||||
}
|
||||
@@ -576,7 +541,10 @@ func scanClient(s scanner) (cli storage.Client, err error) {
|
||||
&cli.Public, &cli.Name, &cli.LogoURL,
|
||||
)
|
||||
if err != nil {
|
||||
return cli, err
|
||||
if err == sql.ErrNoRows {
|
||||
return cli, storage.ErrNotFound
|
||||
}
|
||||
return cli, fmt.Errorf("get client: %v", err)
|
||||
}
|
||||
return cli, nil
|
||||
}
|
||||
@@ -603,7 +571,7 @@ func (c *conn) CreatePassword(p storage.Password) error {
|
||||
}
|
||||
|
||||
func (c *conn) UpdatePassword(email string, updater func(p storage.Password) (storage.Password, error)) error {
|
||||
err := c.ExecTx(func(tx *trans) error {
|
||||
return c.ExecTx(func(tx *trans) error {
|
||||
p, err := getPassword(tx, email)
|
||||
if err != nil {
|
||||
return err
|
||||
@@ -613,7 +581,6 @@ func (c *conn) UpdatePassword(email string, updater func(p storage.Password) (st
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
_, err = tx.Exec(`
|
||||
update password
|
||||
set
|
||||
@@ -622,25 +589,15 @@ func (c *conn) UpdatePassword(email string, updater func(p storage.Password) (st
|
||||
`,
|
||||
np.Hash, np.Username, np.UserID, p.Email,
|
||||
)
|
||||
return err
|
||||
if err != nil {
|
||||
return fmt.Errorf("update password: %v", err)
|
||||
}
|
||||
return nil
|
||||
})
|
||||
if err != nil {
|
||||
return fmt.Errorf("update password: %v", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *conn) GetPassword(email string) (storage.Password, error) {
|
||||
pass, err := getPassword(c, email)
|
||||
if err != nil {
|
||||
if err == sql.ErrNoRows {
|
||||
return storage.Password{}, storage.ErrNotFound
|
||||
}
|
||||
|
||||
return storage.Password{}, fmt.Errorf("get password: %s", err)
|
||||
}
|
||||
|
||||
return pass, nil
|
||||
return getPassword(c, email)
|
||||
}
|
||||
|
||||
func getPassword(q querier, email string) (p storage.Password, err error) {
|
||||
@@ -665,12 +622,12 @@ func (c *conn) ListPasswords() ([]storage.Password, error) {
|
||||
for rows.Next() {
|
||||
p, err := scanPassword(rows)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("scan password: %s", err)
|
||||
return nil, err
|
||||
}
|
||||
passwords = append(passwords, p)
|
||||
}
|
||||
if err := rows.Err(); err != nil {
|
||||
return nil, fmt.Errorf("scan: %s", err)
|
||||
return nil, err
|
||||
}
|
||||
return passwords, nil
|
||||
}
|
||||
@@ -680,7 +637,10 @@ func scanPassword(s scanner) (p storage.Password, err error) {
|
||||
&p.Email, &p.Hash, &p.Username, &p.UserID,
|
||||
)
|
||||
if err != nil {
|
||||
return p, err
|
||||
if err == sql.ErrNoRows {
|
||||
return p, storage.ErrNotFound
|
||||
}
|
||||
return p, fmt.Errorf("select password: %v", err)
|
||||
}
|
||||
return p, nil
|
||||
}
|
||||
@@ -706,7 +666,7 @@ func (c *conn) CreateOfflineSessions(s storage.OfflineSessions) error {
|
||||
}
|
||||
|
||||
func (c *conn) UpdateOfflineSessions(userID string, connID string, updater func(s storage.OfflineSessions) (storage.OfflineSessions, error)) error {
|
||||
err := c.ExecTx(func(tx *trans) error {
|
||||
return c.ExecTx(func(tx *trans) error {
|
||||
s, err := getOfflineSessions(tx, userID, connID)
|
||||
if err != nil {
|
||||
return err
|
||||
@@ -716,7 +676,6 @@ func (c *conn) UpdateOfflineSessions(userID string, connID string, updater func(
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
_, err = tx.Exec(`
|
||||
update offline_session
|
||||
set
|
||||
@@ -725,26 +684,15 @@ func (c *conn) UpdateOfflineSessions(userID string, connID string, updater func(
|
||||
`,
|
||||
encoder(newSession.Refresh), s.UserID, s.ConnID,
|
||||
)
|
||||
return err
|
||||
if err != nil {
|
||||
return fmt.Errorf("update offline session: %v", err)
|
||||
}
|
||||
return nil
|
||||
})
|
||||
if err != nil {
|
||||
return fmt.Errorf("update offline session: %v", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *conn) GetOfflineSessions(userID string, connID string) (storage.OfflineSessions, error) {
|
||||
sessions, err := getOfflineSessions(c, userID, connID)
|
||||
if err != nil {
|
||||
if err == sql.ErrNoRows {
|
||||
return storage.OfflineSessions{}, storage.ErrNotFound
|
||||
}
|
||||
|
||||
return storage.OfflineSessions{}, fmt.Errorf("get offline sessions: %s", err)
|
||||
}
|
||||
|
||||
return sessions, nil
|
||||
return getOfflineSessions(c, userID, connID)
|
||||
}
|
||||
|
||||
func getOfflineSessions(q querier, userID string, connID string) (storage.OfflineSessions, error) {
|
||||
@@ -761,7 +709,10 @@ func scanOfflineSessions(s scanner) (o storage.OfflineSessions, err error) {
|
||||
&o.UserID, &o.ConnID, decoder(&o.Refresh),
|
||||
)
|
||||
if err != nil {
|
||||
return o, err
|
||||
if err == sql.ErrNoRows {
|
||||
return o, storage.ErrNotFound
|
||||
}
|
||||
return o, fmt.Errorf("select offline session: %v", err)
|
||||
}
|
||||
return o, nil
|
||||
}
|
||||
@@ -787,7 +738,7 @@ func (c *conn) CreateConnector(connector storage.Connector) error {
|
||||
}
|
||||
|
||||
func (c *conn) UpdateConnector(id string, updater func(s storage.Connector) (storage.Connector, error)) error {
|
||||
err := c.ExecTx(func(tx *trans) error {
|
||||
return c.ExecTx(func(tx *trans) error {
|
||||
connector, err := getConnector(tx, id)
|
||||
if err != nil {
|
||||
return err
|
||||
@@ -797,10 +748,9 @@ func (c *conn) UpdateConnector(id string, updater func(s storage.Connector) (sto
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
_, err = tx.Exec(`
|
||||
update connector
|
||||
set
|
||||
set
|
||||
type = $1,
|
||||
name = $2,
|
||||
resource_version = $3,
|
||||
@@ -809,26 +759,15 @@ func (c *conn) UpdateConnector(id string, updater func(s storage.Connector) (sto
|
||||
`,
|
||||
newConn.Type, newConn.Name, newConn.ResourceVersion, newConn.Config, connector.ID,
|
||||
)
|
||||
return err
|
||||
if err != nil {
|
||||
return fmt.Errorf("update connector: %v", err)
|
||||
}
|
||||
return nil
|
||||
})
|
||||
if err != nil {
|
||||
return fmt.Errorf("update connector: %v", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *conn) GetConnector(id string) (storage.Connector, error) {
|
||||
connector, err := getConnector(c, id)
|
||||
if err != nil {
|
||||
if err == sql.ErrNoRows {
|
||||
return storage.Connector{}, storage.ErrNotFound
|
||||
}
|
||||
|
||||
return storage.Connector{}, fmt.Errorf("get connector: %s", err)
|
||||
}
|
||||
|
||||
return connector, nil
|
||||
return getConnector(c, id)
|
||||
}
|
||||
|
||||
func getConnector(q querier, id string) (storage.Connector, error) {
|
||||
@@ -845,7 +784,10 @@ func scanConnector(s scanner) (c storage.Connector, err error) {
|
||||
&c.ID, &c.Type, &c.Name, &c.ResourceVersion, &c.Config,
|
||||
)
|
||||
if err != nil {
|
||||
return c, err
|
||||
if err == sql.ErrNoRows {
|
||||
return c, storage.ErrNotFound
|
||||
}
|
||||
return c, fmt.Errorf("select connector: %v", err)
|
||||
}
|
||||
return c, nil
|
||||
}
|
||||
@@ -863,12 +805,12 @@ func (c *conn) ListConnectors() ([]storage.Connector, error) {
|
||||
for rows.Next() {
|
||||
conn, err := scanConnector(rows)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("scan connector: %s", err)
|
||||
return nil, err
|
||||
}
|
||||
connectors = append(connectors, conn)
|
||||
}
|
||||
if err := rows.Err(); err != nil {
|
||||
return nil, fmt.Errorf("scan: %s", err)
|
||||
return nil, err
|
||||
}
|
||||
return connectors, nil
|
||||
}
|
||||
|
Reference in New Issue
Block a user