This repository has been archived on 2023-08-14. You can view files and clone it, but cannot push or open issues or pull requests.
dex/storage/sql/crud.go
Alex Suraci 587081a643 postgres: refactor error handling to fix retrying
prior to this change, many of the functions in the ExecTx callback would
wrap the error before returning it. this made it impossible to check
for the error code.

instead, the error wrapping has been moved to be external to the
`ExecTx` callback, so that the error code can be checked and
serialization failures can be retried.
2018-11-20 10:50:55 -05:00

921 lines
22 KiB
Go

package sql
import (
"database/sql"
"database/sql/driver"
"encoding/json"
"errors"
"fmt"
"strings"
"time"
"github.com/dexidp/dex/storage"
)
// TODO(ericchiang): The queries in the update, insert, and select methods are
// all very repetitive. Consider creating them programmatically.
// keysRowID is the ID of the only row we expect to populate the "keys" table.
// It is the id value used by the select in getKeys and by the insert and
// update statements in UpdateKeys.
const keysRowID = "keys"
// encoder adapts an arbitrary value into a driver.Valuer whose Value method
// serializes it as JSON. The database/sql package calls Value automatically
// when the result is passed as a query argument.
//
//	s := []string{"planes", "bears"}
//	_, err := db.Exec(`insert into t1 (id, things) values (1, $1)`, encoder(s))
//	if err != nil {
//		// handle error
//	}
//
//	var r []byte
//	err = db.QueryRow(`select things from t1 where id = 1;`).Scan(&r)
//	if err != nil {
//		// handle error
//	}
//	fmt.Printf("%s\n", r) // ["planes","bears"]
func encoder(i interface{}) driver.Valuer {
	return jsonEncoder{i: i}
}
// decoder adapts a destination value into a sql.Scanner that JSON-decodes the
// scanned column into it. Pass the result to a database Scan() call in place
// of the destination itself.
func decoder(i interface{}) sql.Scanner {
	return jsonDecoder{i: i}
}
// jsonEncoder is the driver.Valuer handed out by encoder; it carries the value
// that should be stored as JSON.
type jsonEncoder struct {
	i interface{}
}

// Value marshals the wrapped value to JSON and returns the encoded bytes.
// A marshalling failure is returned prefixed with "marshal: ".
func (j jsonEncoder) Value() (driver.Value, error) {
	encoded, err := json.Marshal(j.i)
	if err == nil {
		return encoded, nil
	}
	return nil, fmt.Errorf("marshal: %v", err)
}
// jsonDecoder is the sql.Scanner handed out by decoder; i holds the
// destination the scanned JSON is decoded into.
type jsonDecoder struct {
	i interface{}
}

// Scan decodes the scanned column value into the wrapped destination.
//
// Despite its name, dest is the value supplied by the database. It must be a
// []byte containing JSON; a nil value or any other type is rejected with an
// error, and a decoding failure is returned prefixed with "unmarshal: ".
func (j jsonDecoder) Scan(dest interface{}) error {
	switch raw := dest.(type) {
	case nil:
		return errors.New("nil value")
	case []byte:
		if err := json.Unmarshal(raw, &j.i); err != nil {
			return fmt.Errorf("unmarshal: %v", err)
		}
		return nil
	default:
		return fmt.Errorf("expected []byte got %T", dest)
	}
}
// querier is the single-row query capability shared by a plain connection and
// a transaction. It lets the get* helpers in this file run against either one:
// they are called with a *conn outside a transaction and with a *trans inside
// ExecTx callbacks.
type querier interface {
QueryRow(query string, args ...interface{}) *sql.Row
}
// scanner is the Scan method common to a single-row result and a multi-row
// cursor, so the scan* helpers in this file can decode either. The get*
// helpers pass the value returned by QueryRow, while the List* methods pass
// the rows they are iterating over.
type scanner interface {
Scan(dest ...interface{}) error
}
// GarbageCollect deletes every auth request and auth code whose expiry is
// earlier than now, and reports how many rows of each kind were removed.
//
// Row counts are best effort: when RowsAffected fails, the corresponding
// counter stays at zero and the collection still succeeds. If the second
// delete fails, the returned result still carries the auth request count.
func (c *conn) GarbageCollect(now time.Time) (result storage.GCResult, err error) {
	reqRes, err := c.Exec(`delete from auth_request where expiry < $1`, now)
	if err != nil {
		return result, fmt.Errorf("gc auth_request: %v", err)
	}
	if deleted, countErr := reqRes.RowsAffected(); countErr == nil {
		result.AuthRequests = deleted
	}

	codeRes, err := c.Exec(`delete from auth_code where expiry < $1`, now)
	if err != nil {
		return result, fmt.Errorf("gc auth_code: %v", err)
	}
	if deleted, countErr := codeRes.RowsAffected(); countErr == nil {
		result.AuthCodes = deleted
	}
	return result, nil
}
// CreateAuthRequest stores a as a new row in the auth_request table.
//
// ResponseTypes, Scopes and Claims.Groups are written through encoder, i.e.
// as JSON. It returns storage.ErrAlreadyExists when c.alreadyExistsCheck
// recognizes the failure; any other failure is returned prefixed with
// "insert auth request: ".
func (c *conn) CreateAuthRequest(a storage.AuthRequest) error {
	const insertAuthRequest = `
insert into auth_request (
id, client_id, response_types, scopes, redirect_uri, nonce, state,
force_approval_prompt, logged_in,
claims_user_id, claims_username, claims_email, claims_email_verified,
claims_groups,
connector_id, connector_data,
expiry
)
values (
$1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12, $13, $14, $15, $16, $17
);
`
	claims := a.Claims
	_, err := c.Exec(insertAuthRequest,
		a.ID, a.ClientID, encoder(a.ResponseTypes), encoder(a.Scopes), a.RedirectURI, a.Nonce, a.State,
		a.ForceApprovalPrompt, a.LoggedIn,
		claims.UserID, claims.Username, claims.Email, claims.EmailVerified,
		encoder(claims.Groups),
		a.ConnectorID, a.ConnectorData,
		a.Expiry,
	)
	switch {
	case err == nil:
		return nil
	case c.alreadyExistsCheck(err):
		return storage.ErrAlreadyExists
	}
	return fmt.Errorf("insert auth request: %v", err)
}
// UpdateAuthRequest loads the auth request with the given id, passes it to
// updater, and writes the returned value back, all inside one c.ExecTx call.
//
// Errors are returned from the callback unwrapped and are only given the
// "update auth request: " prefix after ExecTx returns, so ExecTx sees the
// original error value. The row is matched by the ID that was loaded, so an
// ID changed by updater is ignored.
func (c *conn) UpdateAuthRequest(id string, updater func(a storage.AuthRequest) (storage.AuthRequest, error)) error {
	const updateAuthRequest = `
update auth_request
set
client_id = $1, response_types = $2, scopes = $3, redirect_uri = $4,
nonce = $5, state = $6, force_approval_prompt = $7, logged_in = $8,
claims_user_id = $9, claims_username = $10, claims_email = $11,
claims_email_verified = $12,
claims_groups = $13,
connector_id = $14, connector_data = $15,
expiry = $16
where id = $17;
`
	apply := func(tx *trans) error {
		current, err := getAuthRequest(tx, id)
		if err != nil {
			return err
		}
		next, err := updater(current)
		if err != nil {
			return err
		}
		_, err = tx.Exec(updateAuthRequest,
			next.ClientID, encoder(next.ResponseTypes), encoder(next.Scopes), next.RedirectURI, next.Nonce, next.State,
			next.ForceApprovalPrompt, next.LoggedIn,
			next.Claims.UserID, next.Claims.Username, next.Claims.Email, next.Claims.EmailVerified,
			encoder(next.Claims.Groups),
			next.ConnectorID, next.ConnectorData,
			next.Expiry, current.ID,
		)
		return err
	}
	if err := c.ExecTx(apply); err != nil {
		return fmt.Errorf("update auth request: %v", err)
	}
	return nil
}
// GetAuthRequest returns the auth request stored under id.
//
// A missing row (sql.ErrNoRows) is reported as storage.ErrNotFound; any other
// failure is returned prefixed with "select auth request: ". On error the
// returned request is the zero value.
func (c *conn) GetAuthRequest(id string) (storage.AuthRequest, error) {
	found, err := getAuthRequest(c, id)
	switch err {
	case nil:
		return found, nil
	case sql.ErrNoRows:
		return storage.AuthRequest{}, storage.ErrNotFound
	default:
		return storage.AuthRequest{}, fmt.Errorf("select auth request: %v", err)
	}
}
// getAuthRequest selects the auth_request row with the given id through q and
// scans it into a. The JSON columns (response_types, scopes, claims_groups)
// are decoded via decoder. The Scan error, including sql.ErrNoRows, is
// returned unchanged so callers can inspect it.
func getAuthRequest(q querier, id string) (a storage.AuthRequest, err error) {
	const selectAuthRequest = `
select
id, client_id, response_types, scopes, redirect_uri, nonce, state,
force_approval_prompt, logged_in,
claims_user_id, claims_username, claims_email, claims_email_verified,
claims_groups,
connector_id, connector_data, expiry
from auth_request where id = $1;
`
	row := q.QueryRow(selectAuthRequest, id)
	err = row.Scan(
		&a.ID, &a.ClientID, decoder(&a.ResponseTypes), decoder(&a.Scopes), &a.RedirectURI, &a.Nonce, &a.State,
		&a.ForceApprovalPrompt, &a.LoggedIn,
		&a.Claims.UserID, &a.Claims.Username, &a.Claims.Email, &a.Claims.EmailVerified,
		decoder(&a.Claims.Groups),
		&a.ConnectorID, &a.ConnectorData, &a.Expiry,
	)
	return a, err
}
// CreateAuthCode stores a as a new row in the auth_code table, writing Scopes
// and Claims.Groups through encoder (JSON).
//
// It returns storage.ErrAlreadyExists when c.alreadyExistsCheck recognizes
// the failure; any other failure is returned prefixed with
// "insert auth code: ".
func (c *conn) CreateAuthCode(a storage.AuthCode) error {
	const insertAuthCode = `
insert into auth_code (
id, client_id, scopes, nonce, redirect_uri,
claims_user_id, claims_username,
claims_email, claims_email_verified, claims_groups,
connector_id, connector_data,
expiry
)
values ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12, $13);
`
	claims := a.Claims
	_, err := c.Exec(insertAuthCode,
		a.ID, a.ClientID, encoder(a.Scopes), a.Nonce, a.RedirectURI,
		claims.UserID, claims.Username,
		claims.Email, claims.EmailVerified, encoder(claims.Groups),
		a.ConnectorID, a.ConnectorData,
		a.Expiry,
	)
	switch {
	case err == nil:
		return nil
	case c.alreadyExistsCheck(err):
		return storage.ErrAlreadyExists
	}
	return fmt.Errorf("insert auth code: %v", err)
}
// GetAuthCode returns the auth code stored under id, decoding the JSON
// columns scopes and claims_groups via decoder.
//
// A missing row (sql.ErrNoRows) is reported as storage.ErrNotFound; any other
// failure is returned prefixed with "select auth code: ".
func (c *conn) GetAuthCode(id string) (a storage.AuthCode, err error) {
	const selectAuthCode = `
select
id, client_id, scopes, nonce, redirect_uri,
claims_user_id, claims_username,
claims_email, claims_email_verified, claims_groups,
connector_id, connector_data,
expiry
from auth_code where id = $1;
`
	row := c.QueryRow(selectAuthCode, id)
	err = row.Scan(
		&a.ID, &a.ClientID, decoder(&a.Scopes), &a.Nonce, &a.RedirectURI,
		&a.Claims.UserID, &a.Claims.Username,
		&a.Claims.Email, &a.Claims.EmailVerified, decoder(&a.Claims.Groups),
		&a.ConnectorID, &a.ConnectorData,
		&a.Expiry,
	)
	switch err {
	case nil:
		return a, nil
	case sql.ErrNoRows:
		return a, storage.ErrNotFound
	default:
		return a, fmt.Errorf("select auth code: %v", err)
	}
}
// CreateRefresh stores r as a new row in the refresh_token table, writing
// Scopes and Claims.Groups through encoder (JSON).
//
// It returns storage.ErrAlreadyExists when c.alreadyExistsCheck recognizes
// the failure; any other failure is returned prefixed with
// "insert refresh token: ".
func (c *conn) CreateRefresh(r storage.RefreshToken) error {
	const insertRefresh = `
insert into refresh_token (
id, client_id, scopes, nonce,
claims_user_id, claims_username, claims_email, claims_email_verified,
claims_groups,
connector_id, connector_data,
token, created_at, last_used
)
values ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12, $13, $14);
`
	claims := r.Claims
	_, err := c.Exec(insertRefresh,
		r.ID, r.ClientID, encoder(r.Scopes), r.Nonce,
		claims.UserID, claims.Username, claims.Email, claims.EmailVerified,
		encoder(claims.Groups),
		r.ConnectorID, r.ConnectorData,
		r.Token, r.CreatedAt, r.LastUsed,
	)
	switch {
	case err == nil:
		return nil
	case c.alreadyExistsCheck(err):
		return storage.ErrAlreadyExists
	}
	return fmt.Errorf("insert refresh token: %v", err)
}
// UpdateRefreshToken loads the refresh token with the given id, passes it to
// updater, and writes the returned value back, all inside one c.ExecTx call.
//
// Errors are returned from the callback unwrapped and are only given the
// "update refresh token: " prefix after ExecTx returns. The row is matched by
// the id argument, so an ID changed by updater is ignored.
func (c *conn) UpdateRefreshToken(id string, updater func(old storage.RefreshToken) (storage.RefreshToken, error)) error {
	const updateRefresh = `
update refresh_token
set
client_id = $1,
scopes = $2,
nonce = $3,
claims_user_id = $4,
claims_username = $5,
claims_email = $6,
claims_email_verified = $7,
claims_groups = $8,
connector_id = $9,
connector_data = $10,
token = $11,
created_at = $12,
last_used = $13
where
id = $14
`
	apply := func(tx *trans) error {
		stored, err := getRefresh(tx, id)
		if err != nil {
			return err
		}
		next, err := updater(stored)
		if err != nil {
			return err
		}
		_, err = tx.Exec(updateRefresh,
			next.ClientID, encoder(next.Scopes), next.Nonce,
			next.Claims.UserID, next.Claims.Username, next.Claims.Email, next.Claims.EmailVerified,
			encoder(next.Claims.Groups),
			next.ConnectorID, next.ConnectorData,
			next.Token, next.CreatedAt, next.LastUsed, id,
		)
		return err
	}
	if err := c.ExecTx(apply); err != nil {
		return fmt.Errorf("update refresh token: %v", err)
	}
	return nil
}
// GetRefresh returns the refresh token stored under id.
//
// A missing row (sql.ErrNoRows) is reported as storage.ErrNotFound; any other
// failure is returned prefixed with "get refresh token: ". On error the
// returned token is the zero value.
func (c *conn) GetRefresh(id string) (storage.RefreshToken, error) {
	token, err := getRefresh(c, id)
	switch err {
	case nil:
		return token, nil
	case sql.ErrNoRows:
		return storage.RefreshToken{}, storage.ErrNotFound
	default:
		return storage.RefreshToken{}, fmt.Errorf("get refresh token: %v", err)
	}
}
// getRefresh selects the refresh_token row with the given id through q and
// decodes it with scanRefresh. The scan error is returned unchanged.
func getRefresh(q querier, id string) (storage.RefreshToken, error) {
	const selectRefresh = `
select
id, client_id, scopes, nonce,
claims_user_id, claims_username, claims_email, claims_email_verified,
claims_groups,
connector_id, connector_data,
token, created_at, last_used
from refresh_token where id = $1;
`
	row := q.QueryRow(selectRefresh, id)
	return scanRefresh(row)
}
// ListRefreshTokens returns every row of the refresh_token table.
//
// A query failure is returned prefixed with "select refresh tokens: ", a
// failure decoding a row with "scan refresh token: ", and an iteration error
// with "scan: ". With no rows, the returned slice is nil.
func (c *conn) ListRefreshTokens() ([]storage.RefreshToken, error) {
	rows, err := c.Query(`
select
id, client_id, scopes, nonce,
claims_user_id, claims_username, claims_email, claims_email_verified,
claims_groups,
connector_id, connector_data,
token, created_at, last_used
from refresh_token;
`)
	if err != nil {
		return nil, fmt.Errorf("select refresh tokens: %v", err)
	}
	// Rows are only closed automatically once Next returns false. Close them
	// explicitly so the early return on a scan error does not leak the
	// underlying connection.
	defer rows.Close()

	var tokens []storage.RefreshToken
	for rows.Next() {
		r, err := scanRefresh(rows)
		if err != nil {
			return nil, fmt.Errorf("scan refresh token: %s", err)
		}
		tokens = append(tokens, r)
	}
	if err := rows.Err(); err != nil {
		return nil, fmt.Errorf("scan: %v", err)
	}
	return tokens, nil
}
// scanRefresh reads one refresh_token row from s into r. Columns must arrive
// in the order used by getRefresh and ListRefreshTokens; scopes and
// claims_groups are JSON-decoded via decoder. The Scan error is returned
// unchanged.
func scanRefresh(s scanner) (r storage.RefreshToken, err error) {
	claims := &r.Claims
	err = s.Scan(
		&r.ID, &r.ClientID, decoder(&r.Scopes), &r.Nonce,
		&claims.UserID, &claims.Username, &claims.Email, &claims.EmailVerified,
		decoder(&claims.Groups),
		&r.ConnectorID, &r.ConnectorData,
		&r.Token, &r.CreatedAt, &r.LastUsed,
	)
	return r, err
}
// UpdateKeys loads the single keys row, passes it to updater, and stores the
// result, all inside one c.ExecTx call.
//
// When no row exists yet (sql.ErrNoRows), updater receives a zero
// storage.Keys and the result is inserted under keysRowID; otherwise the
// existing row is updated. Errors, including the one from ExecTx itself, are
// returned unwrapped.
func (c *conn) UpdateKeys(updater func(old storage.Keys) (storage.Keys, error)) error {
	const insertKeys = `
insert into keys (
id, verification_keys, signing_key, signing_key_pub, next_rotation
)
values ($1, $2, $3, $4, $5);
`
	const updateKeys = `
update keys
set
verification_keys = $1,
signing_key = $2,
signing_key_pub = $3,
next_rotation = $4
where id = $5;
`
	return c.ExecTx(func(tx *trans) error {
		// TODO(ericchiang): errors may cause a transaction be rolled back by the SQL
		// server. Test this, and consider adding a COUNT() command beforehand.
		current, err := getKeys(tx)
		rowMissing := err == sql.ErrNoRows
		if err != nil && !rowMissing {
			return err
		}
		if rowMissing {
			current = storage.Keys{}
		}

		nk, err := updater(current)
		if err != nil {
			return err
		}

		if rowMissing {
			_, err = tx.Exec(insertKeys,
				keysRowID, encoder(nk.VerificationKeys), encoder(nk.SigningKey),
				encoder(nk.SigningKeyPub), nk.NextRotation,
			)
			return err
		}
		_, err = tx.Exec(updateKeys,
			encoder(nk.VerificationKeys), encoder(nk.SigningKey),
			encoder(nk.SigningKeyPub), nk.NextRotation, keysRowID,
		)
		return err
	})
}
// GetKeys returns the stored keys.
//
// A missing row (sql.ErrNoRows) is reported as storage.ErrNotFound; any other
// failure is returned prefixed with "select keys: ". On error the returned
// keys are the zero value.
func (c *conn) GetKeys() (storage.Keys, error) {
	stored, err := getKeys(c)
	switch err {
	case nil:
		return stored, nil
	case sql.ErrNoRows:
		return storage.Keys{}, storage.ErrNotFound
	default:
		return storage.Keys{}, fmt.Errorf("select keys: %s", err)
	}
}
// getKeys selects the keys row identified by keysRowID through q.
// verification_keys, signing_key and signing_key_pub are JSON-decoded via
// decoder. The Scan error, including sql.ErrNoRows, is returned unchanged.
func getKeys(q querier) (keys storage.Keys, err error) {
	const selectKeys = `
select
verification_keys, signing_key, signing_key_pub, next_rotation
from keys
where id=$1
`
	row := q.QueryRow(selectKeys, keysRowID)
	err = row.Scan(
		decoder(&keys.VerificationKeys), decoder(&keys.SigningKey),
		decoder(&keys.SigningKeyPub), &keys.NextRotation,
	)
	return keys, err
}
// UpdateClient loads the client with the given id, passes it to updater, and
// writes the returned value back, all inside one c.ExecTx call.
//
// Errors are returned from the callback unwrapped and are only given the
// "update client: " prefix after ExecTx returns. The row is matched by the id
// argument, so an ID changed by updater is ignored.
func (c *conn) UpdateClient(id string, updater func(old storage.Client) (storage.Client, error)) error {
	const updateClient = `
update client
set
secret = $1,
redirect_uris = $2,
trusted_peers = $3,
public = $4,
name = $5,
logo_url = $6
where id = $7;
`
	apply := func(tx *trans) error {
		stored, err := getClient(tx, id)
		if err != nil {
			return err
		}
		next, err := updater(stored)
		if err != nil {
			return err
		}
		_, err = tx.Exec(updateClient,
			next.Secret, encoder(next.RedirectURIs), encoder(next.TrustedPeers),
			next.Public, next.Name, next.LogoURL, id,
		)
		return err
	}
	if err := c.ExecTx(apply); err != nil {
		return fmt.Errorf("update client: %v", err)
	}
	return nil
}
// CreateClient stores cli as a new row in the client table, writing
// RedirectURIs and TrustedPeers through encoder (JSON).
//
// It returns storage.ErrAlreadyExists when c.alreadyExistsCheck recognizes
// the failure; any other failure is returned prefixed with "insert client: ".
func (c *conn) CreateClient(cli storage.Client) error {
	const insertClient = `
insert into client (
id, secret, redirect_uris, trusted_peers, public, name, logo_url
)
values ($1, $2, $3, $4, $5, $6, $7);
`
	_, err := c.Exec(insertClient,
		cli.ID, cli.Secret,
		encoder(cli.RedirectURIs), encoder(cli.TrustedPeers),
		cli.Public, cli.Name, cli.LogoURL,
	)
	switch {
	case err == nil:
		return nil
	case c.alreadyExistsCheck(err):
		return storage.ErrAlreadyExists
	}
	return fmt.Errorf("insert client: %v", err)
}
// getClient selects the client row with the given id through q and decodes it
// with scanClient. The scan error is returned unchanged.
func getClient(q querier, id string) (storage.Client, error) {
	const selectClient = `
select
id, secret, redirect_uris, trusted_peers, public, name, logo_url
from client where id = $1;
`
	row := q.QueryRow(selectClient, id)
	return scanClient(row)
}
// GetClient returns the client stored under id.
//
// A missing row (sql.ErrNoRows) is reported as storage.ErrNotFound; any other
// failure is returned prefixed with "select client: ". On error the returned
// client is the zero value.
func (c *conn) GetClient(id string) (storage.Client, error) {
	found, err := getClient(c, id)
	switch err {
	case nil:
		return found, nil
	case sql.ErrNoRows:
		return storage.Client{}, storage.ErrNotFound
	default:
		return storage.Client{}, fmt.Errorf("select client: %v", err)
	}
}
// ListClients returns every row of the client table.
//
// A query failure is returned as is, a failure decoding a row is prefixed
// with "scan client: ", and an iteration error with "scan: ". With no rows,
// the returned slice is nil.
func (c *conn) ListClients() ([]storage.Client, error) {
	rows, err := c.Query(`
select
id, secret, redirect_uris, trusted_peers, public, name, logo_url
from client;
`)
	if err != nil {
		return nil, err
	}
	// Rows are only closed automatically once Next returns false. Close them
	// explicitly so the early return on a scan error does not leak the
	// underlying connection.
	defer rows.Close()

	var clients []storage.Client
	for rows.Next() {
		cli, err := scanClient(rows)
		if err != nil {
			return nil, fmt.Errorf("scan client: %s", err)
		}
		clients = append(clients, cli)
	}
	if err := rows.Err(); err != nil {
		return nil, fmt.Errorf("scan: %s", err)
	}
	return clients, nil
}
// scanClient reads one client row from s into cli. Columns must arrive in the
// order used by getClient and ListClients; redirect_uris and trusted_peers
// are JSON-decoded via decoder. The Scan error is returned unchanged.
func scanClient(s scanner) (cli storage.Client, err error) {
	targets := []interface{}{
		&cli.ID, &cli.Secret,
		decoder(&cli.RedirectURIs), decoder(&cli.TrustedPeers),
		&cli.Public, &cli.Name, &cli.LogoURL,
	}
	err = s.Scan(targets...)
	return cli, err
}
// CreatePassword stores p as a new row in the password table. The email is
// lower-cased before it is written.
//
// It returns storage.ErrAlreadyExists when c.alreadyExistsCheck recognizes
// the failure; any other failure is returned prefixed with
// "insert password: ".
func (c *conn) CreatePassword(p storage.Password) error {
	const insertPassword = `
insert into password (
email, hash, username, user_id
)
values (
$1, $2, $3, $4
);
`
	email := strings.ToLower(p.Email)
	_, err := c.Exec(insertPassword, email, p.Hash, p.Username, p.UserID)
	switch {
	case err == nil:
		return nil
	case c.alreadyExistsCheck(err):
		return storage.ErrAlreadyExists
	}
	return fmt.Errorf("insert password: %v", err)
}
// UpdatePassword loads the password for email, passes it to updater, and
// writes the returned hash, username and user ID back, all inside one
// c.ExecTx call.
//
// Errors are returned from the callback unwrapped and are only given the
// "update password: " prefix after ExecTx returns. The row is matched by the
// email that was loaded, so an Email changed by updater is ignored.
func (c *conn) UpdatePassword(email string, updater func(p storage.Password) (storage.Password, error)) error {
	const updatePassword = `
update password
set
hash = $1, username = $2, user_id = $3
where email = $4;
`
	apply := func(tx *trans) error {
		stored, err := getPassword(tx, email)
		if err != nil {
			return err
		}
		next, err := updater(stored)
		if err != nil {
			return err
		}
		_, err = tx.Exec(updatePassword, next.Hash, next.Username, next.UserID, stored.Email)
		return err
	}
	if err := c.ExecTx(apply); err != nil {
		return fmt.Errorf("update password: %v", err)
	}
	return nil
}
// GetPassword returns the password stored for email.
//
// A missing row (sql.ErrNoRows) is reported as storage.ErrNotFound; any other
// failure is returned prefixed with "get password: ". On error the returned
// password is the zero value.
func (c *conn) GetPassword(email string) (storage.Password, error) {
	found, err := getPassword(c, email)
	switch err {
	case nil:
		return found, nil
	case sql.ErrNoRows:
		return storage.Password{}, storage.ErrNotFound
	default:
		return storage.Password{}, fmt.Errorf("get password: %s", err)
	}
}
// getPassword selects the password row for the lower-cased email through q
// and decodes it with scanPassword. The scan error is returned unchanged.
func getPassword(q querier, email string) (p storage.Password, err error) {
	const selectPassword = `
select
email, hash, username, user_id
from password where email = $1;
`
	row := q.QueryRow(selectPassword, strings.ToLower(email))
	return scanPassword(row)
}
// ListPasswords returns every row of the password table.
//
// A query failure is returned as is, a failure decoding a row is prefixed
// with "scan password: ", and an iteration error with "scan: ". With no rows,
// the returned slice is nil.
func (c *conn) ListPasswords() ([]storage.Password, error) {
	rows, err := c.Query(`
select
email, hash, username, user_id
from password;
`)
	if err != nil {
		return nil, err
	}
	// Rows are only closed automatically once Next returns false. Close them
	// explicitly so the early return on a scan error does not leak the
	// underlying connection.
	defer rows.Close()

	var passwords []storage.Password
	for rows.Next() {
		p, err := scanPassword(rows)
		if err != nil {
			return nil, fmt.Errorf("scan password: %s", err)
		}
		passwords = append(passwords, p)
	}
	if err := rows.Err(); err != nil {
		return nil, fmt.Errorf("scan: %s", err)
	}
	return passwords, nil
}
// scanPassword reads one password row from s into p. Columns must arrive in
// the order email, hash, username, user_id. The Scan error is returned
// unchanged.
func scanPassword(s scanner) (p storage.Password, err error) {
	err = s.Scan(&p.Email, &p.Hash, &p.Username, &p.UserID)
	return p, err
}
// CreateOfflineSessions stores s as a new row in the offline_session table,
// writing Refresh through encoder (JSON).
//
// It returns storage.ErrAlreadyExists when c.alreadyExistsCheck recognizes
// the failure; any other failure is returned prefixed with
// "insert offline session: ".
func (c *conn) CreateOfflineSessions(s storage.OfflineSessions) error {
	const insertOfflineSession = `
insert into offline_session (
user_id, conn_id, refresh
)
values (
$1, $2, $3
);
`
	_, err := c.Exec(insertOfflineSession, s.UserID, s.ConnID, encoder(s.Refresh))
	switch {
	case err == nil:
		return nil
	case c.alreadyExistsCheck(err):
		return storage.ErrAlreadyExists
	}
	return fmt.Errorf("insert offline session: %v", err)
}
// UpdateOfflineSessions loads the offline session for (userID, connID),
// passes it to updater, and writes the returned Refresh value back, all
// inside one c.ExecTx call.
//
// Errors are returned from the callback unwrapped and are only given the
// "update offline session: " prefix after ExecTx returns. Only the refresh
// column is updated, and the row is matched by the keys that were loaded.
func (c *conn) UpdateOfflineSessions(userID string, connID string, updater func(s storage.OfflineSessions) (storage.OfflineSessions, error)) error {
	const updateOfflineSession = `
update offline_session
set
refresh = $1
where user_id = $2 AND conn_id = $3;
`
	apply := func(tx *trans) error {
		stored, err := getOfflineSessions(tx, userID, connID)
		if err != nil {
			return err
		}
		next, err := updater(stored)
		if err != nil {
			return err
		}
		_, err = tx.Exec(updateOfflineSession, encoder(next.Refresh), stored.UserID, stored.ConnID)
		return err
	}
	if err := c.ExecTx(apply); err != nil {
		return fmt.Errorf("update offline session: %v", err)
	}
	return nil
}
// GetOfflineSessions returns the offline session stored for (userID, connID).
//
// A missing row (sql.ErrNoRows) is reported as storage.ErrNotFound; any other
// failure is returned prefixed with "get offline sessions: ". On error the
// returned value is the zero value.
func (c *conn) GetOfflineSessions(userID string, connID string) (storage.OfflineSessions, error) {
	found, err := getOfflineSessions(c, userID, connID)
	switch err {
	case nil:
		return found, nil
	case sql.ErrNoRows:
		return storage.OfflineSessions{}, storage.ErrNotFound
	default:
		return storage.OfflineSessions{}, fmt.Errorf("get offline sessions: %s", err)
	}
}
// getOfflineSessions selects the offline_session row for (userID, connID)
// through q and decodes it with scanOfflineSessions. The scan error is
// returned unchanged.
func getOfflineSessions(q querier, userID string, connID string) (storage.OfflineSessions, error) {
	const selectOfflineSession = `
select
user_id, conn_id, refresh
from offline_session
where user_id = $1 AND conn_id = $2;
`
	row := q.QueryRow(selectOfflineSession, userID, connID)
	return scanOfflineSessions(row)
}
// scanOfflineSessions reads one offline_session row from s into o. Columns
// must arrive in the order user_id, conn_id, refresh; refresh is JSON-decoded
// via decoder. The Scan error is returned unchanged.
func scanOfflineSessions(s scanner) (o storage.OfflineSessions, err error) {
	err = s.Scan(&o.UserID, &o.ConnID, decoder(&o.Refresh))
	return o, err
}
// CreateConnector stores connector as a new row in the connector table.
//
// It returns storage.ErrAlreadyExists when c.alreadyExistsCheck recognizes
// the failure; any other failure is returned prefixed with
// "insert connector: ".
func (c *conn) CreateConnector(connector storage.Connector) error {
	const insertConnector = `
insert into connector (
id, type, name, resource_version, config
)
values (
$1, $2, $3, $4, $5
);
`
	_, err := c.Exec(insertConnector,
		connector.ID, connector.Type, connector.Name,
		connector.ResourceVersion, connector.Config,
	)
	switch {
	case err == nil:
		return nil
	case c.alreadyExistsCheck(err):
		return storage.ErrAlreadyExists
	}
	return fmt.Errorf("insert connector: %v", err)
}
// UpdateConnector loads the connector with the given id, passes it to
// updater, and writes the returned value back, all inside one c.ExecTx call.
//
// Errors are returned from the callback unwrapped and are only given the
// "update connector: " prefix after ExecTx returns. The row is matched by the
// ID that was loaded, so an ID changed by updater is ignored.
func (c *conn) UpdateConnector(id string, updater func(s storage.Connector) (storage.Connector, error)) error {
	const updateConnector = `
update connector
set
type = $1,
name = $2,
resource_version = $3,
config = $4
where id = $5;
`
	apply := func(tx *trans) error {
		stored, err := getConnector(tx, id)
		if err != nil {
			return err
		}
		next, err := updater(stored)
		if err != nil {
			return err
		}
		_, err = tx.Exec(updateConnector,
			next.Type, next.Name, next.ResourceVersion, next.Config, stored.ID,
		)
		return err
	}
	if err := c.ExecTx(apply); err != nil {
		return fmt.Errorf("update connector: %v", err)
	}
	return nil
}
// GetConnector returns the connector stored under id.
//
// A missing row (sql.ErrNoRows) is reported as storage.ErrNotFound; any other
// failure is returned prefixed with "get connector: ". On error the returned
// connector is the zero value.
func (c *conn) GetConnector(id string) (storage.Connector, error) {
	found, err := getConnector(c, id)
	switch err {
	case nil:
		return found, nil
	case sql.ErrNoRows:
		return storage.Connector{}, storage.ErrNotFound
	default:
		return storage.Connector{}, fmt.Errorf("get connector: %s", err)
	}
}
// getConnector selects the connector row with the given id through q and
// decodes it with scanConnector. The scan error is returned unchanged.
func getConnector(q querier, id string) (storage.Connector, error) {
	const selectConnector = `
select
id, type, name, resource_version, config
from connector
where id = $1;
`
	row := q.QueryRow(selectConnector, id)
	return scanConnector(row)
}
// scanConnector reads one connector row from s into c. Columns must arrive in
// the order id, type, name, resource_version, config. The Scan error is
// returned unchanged.
func scanConnector(s scanner) (c storage.Connector, err error) {
	err = s.Scan(&c.ID, &c.Type, &c.Name, &c.ResourceVersion, &c.Config)
	return c, err
}
// ListConnectors returns every row of the connector table.
//
// A query failure is returned as is, a failure decoding a row is prefixed
// with "scan connector: ", and an iteration error with "scan: ". With no
// rows, the returned slice is nil.
func (c *conn) ListConnectors() ([]storage.Connector, error) {
	rows, err := c.Query(`
select
id, type, name, resource_version, config
from connector;
`)
	if err != nil {
		return nil, err
	}
	// Rows are only closed automatically once Next returns false. Close them
	// explicitly so the early return on a scan error does not leak the
	// underlying connection.
	defer rows.Close()

	var connectors []storage.Connector
	for rows.Next() {
		// Named connector rather than conn to avoid shadowing the conn type.
		connector, err := scanConnector(rows)
		if err != nil {
			return nil, fmt.Errorf("scan connector: %s", err)
		}
		connectors = append(connectors, connector)
	}
	if err := rows.Err(); err != nil {
		return nil, fmt.Errorf("scan: %s", err)
	}
	return connectors, nil
}
// DeleteAuthRequest removes the auth_request row whose id column equals id.
// It returns storage.ErrNotFound when c.delete finds no matching row.
func (c *conn) DeleteAuthRequest(id string) error {
	return c.delete("auth_request", "id", id)
}
// DeleteAuthCode removes the auth_code row whose id column equals id.
// It returns storage.ErrNotFound when c.delete finds no matching row.
func (c *conn) DeleteAuthCode(id string) error {
	return c.delete("auth_code", "id", id)
}
// DeleteClient removes the client row whose id column equals id.
// It returns storage.ErrNotFound when c.delete finds no matching row.
func (c *conn) DeleteClient(id string) error {
	return c.delete("client", "id", id)
}
// DeleteRefresh removes the refresh_token row whose id column equals id.
// It returns storage.ErrNotFound when c.delete finds no matching row.
func (c *conn) DeleteRefresh(id string) error {
	return c.delete("refresh_token", "id", id)
}
// DeletePassword removes the password row for email. The address is
// lower-cased first, matching how CreatePassword stores it. It returns
// storage.ErrNotFound when c.delete finds no matching row.
func (c *conn) DeletePassword(email string) error {
	normalized := strings.ToLower(email)
	return c.delete("password", "email", normalized)
}
// DeleteConnector removes the connector row whose id column equals id.
// It returns storage.ErrNotFound when c.delete finds no matching row.
func (c *conn) DeleteConnector(id string) error {
	return c.delete("connector", "id", id)
}
// DeleteOfflineSessions removes the offline_session row for (userID, connID).
//
// It returns storage.ErrNotFound when no row was deleted. A failed delete is
// returned with the keys and the underlying error in the message, and a
// RowsAffected failure is returned prefixed with "rows affected: ".
func (c *conn) DeleteOfflineSessions(userID string, connID string) error {
	result, err := c.Exec(`delete from offline_session where user_id = $1 AND conn_id = $2`, userID, connID)
	if err != nil {
		// Include err itself; previously the cause of the failure was dropped.
		return fmt.Errorf("delete offline_session: user_id = %s, conn_id = %s: %v", userID, connID, err)
	}

	// For now mandate that the driver implements RowsAffected. If we ever need to support
	// a driver that doesn't implement this, we can run this in a transaction with a get beforehand.
	n, err := result.RowsAffected()
	if err != nil {
		return fmt.Errorf("rows affected: %v", err)
	}
	if n < 1 {
		return storage.ErrNotFound
	}
	return nil
}
// delete removes the rows of table whose field column equals id.
//
// Do NOT call directly. Does not escape table or field: both are concatenated
// into the statement, so they must be trusted constants. Only id is passed as
// a bound parameter.
//
// It returns storage.ErrNotFound when no row was deleted. A failed delete is
// returned with the table, key and underlying error in the message, and a
// RowsAffected failure is returned prefixed with "rows affected: ".
func (c *conn) delete(table, field, id string) error {
	result, err := c.Exec(`delete from `+table+` where `+field+` = $1`, id)
	if err != nil {
		// Include err itself; previously id was formatted in its place and the
		// cause of the failure was lost.
		return fmt.Errorf("delete %s: %s = %s: %v", table, field, id, err)
	}

	// For now mandate that the driver implements RowsAffected. If we ever need to support
	// a driver that doesn't implement this, we can run this in a transaction with a get beforehand.
	n, err := result.RowsAffected()
	if err != nil {
		return fmt.Errorf("rows affected: %v", err)
	}
	if n < 1 {
		return storage.ErrNotFound
	}
	return nil
}