diff --git a/storage/sql/crud.go b/storage/sql/crud.go index 532b8648..ca941f7c 100644 --- a/storage/sql/crud.go +++ b/storage/sql/crud.go @@ -6,6 +6,7 @@ import ( "encoding/json" "errors" "fmt" + "strings" "github.com/coreos/dex/storage" ) @@ -137,7 +138,7 @@ func (c *conn) UpdateAuthRequest(id string, updater func(a storage.AuthRequest) a.Claims.UserID, a.Claims.Username, a.Claims.Email, a.Claims.EmailVerified, encoder(a.Claims.Groups), a.ConnectorID, a.ConnectorData, - a.Expiry, a.ID, + a.Expiry, r.ID, ) if err != nil { return fmt.Errorf("update auth request: %v", err) @@ -462,14 +463,83 @@ func scanClient(s scanner) (cli storage.Client, err error) { return cli, nil } -func (c *conn) DeleteAuthRequest(id string) error { return c.delete("auth_request", id) } -func (c *conn) DeleteAuthCode(id string) error { return c.delete("auth_code", id) } -func (c *conn) DeleteClient(id string) error { return c.delete("client", id) } -func (c *conn) DeleteRefresh(id string) error { return c.delete("refresh_token", id) } +func (c *conn) CreatePassword(p storage.Password) error { + p.Email = strings.ToLower(p.Email) + _, err := c.Exec(` + insert into password ( + email, hash, username, user_id + ) + values ( + $1, $2, $3, $4 + ); + `, + p.Email, p.Hash, p.Username, p.UserID, + ) + if err != nil { + return fmt.Errorf("insert password: %v", err) + } + return nil +} + +func (c *conn) UpdatePassword(email string, updater func(p storage.Password) (storage.Password, error)) error { + return c.ExecTx(func(tx *trans) error { + p, err := getPassword(tx, email) + if err != nil { + return err + } + + np, err := updater(p) + if err != nil { + return err + } + _, err = tx.Exec(` + update password + set + hash = $1, username = $2, user_id = $3 + where email = $4; + `, + np.Hash, np.Username, np.UserID, p.Email, + ) + if err != nil { + return fmt.Errorf("update password: %v", err) + } + return nil + }) +} + +func (c *conn) GetPassword(email string) (storage.Password, error) { + return getPassword(c, email) +} + +func getPassword(q querier, email string) (p storage.Password, err error) { + email = strings.ToLower(email) + err = q.QueryRow(` + select + email, hash, username, user_id + from password where email = $1; + `, email).Scan( + &p.Email, &p.Hash, &p.Username, &p.UserID, + ) + if err != nil { + if err == sql.ErrNoRows { + return p, storage.ErrNotFound + } + return p, fmt.Errorf("select password: %v", err) + } + return p, nil +} + +func (c *conn) DeleteAuthRequest(id string) error { return c.delete("auth_request", "id", id) } +func (c *conn) DeleteAuthCode(id string) error { return c.delete("auth_code", "id", id) } +func (c *conn) DeleteClient(id string) error { return c.delete("client", "id", id) } +func (c *conn) DeleteRefresh(id string) error { return c.delete("refresh_token", "id", id) } +func (c *conn) DeletePassword(email string) error { + return c.delete("password", "email", strings.ToLower(email)) +} // Do NOT call directly. Does not escape table. -func (c *conn) delete(table, id string) error { - result, err := c.Exec(`delete from `+table+` where id = $1`, id) +func (c *conn) delete(table, field, id string) error { + result, err := c.Exec(`delete from `+table+` where `+field+` = $1`, id) if err != nil { return fmt.Errorf("delete %s: %v", table, id) } diff --git a/storage/sql/migrate.go b/storage/sql/migrate.go index 8754caf5..d9c254d3 100644 --- a/storage/sql/migrate.go +++ b/storage/sql/migrate.go @@ -137,6 +137,13 @@ var migrations = []migration{ connector_id text not null, connector_data bytea ); + + create table password ( + email text not null primary key, + hash bytea not null, + username text not null, + user_id text not null + ); -- keys is a weird table because we only ever expect there to be a single row create table keys (