storage: Add OfflineSession object to backend storage.
This commit is contained in:
@@ -624,6 +624,75 @@ func scanPassword(s scanner) (p storage.Password, err error) {
|
||||
return p, nil
|
||||
}
|
||||
|
||||
func (c *conn) CreateOfflineSessions(s storage.OfflineSessions) error {
|
||||
_, err := c.Exec(`
|
||||
insert into offline_session (
|
||||
user_id, conn_id, refresh
|
||||
)
|
||||
values (
|
||||
$1, $2, $3
|
||||
);
|
||||
`,
|
||||
s.UserID, s.ConnID, encoder(s.Refresh),
|
||||
)
|
||||
if err != nil {
|
||||
return fmt.Errorf("insert offline session: %v", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *conn) UpdateOfflineSessions(userID string, connID string, updater func(s storage.OfflineSessions) (storage.OfflineSessions, error)) error {
|
||||
return c.ExecTx(func(tx *trans) error {
|
||||
s, err := getOfflineSessions(tx, userID, connID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
newSession, err := updater(s)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
_, err = tx.Exec(`
|
||||
update offline_session
|
||||
set
|
||||
refresh = $1
|
||||
where user_id = $2 AND conn_id = $3;
|
||||
`,
|
||||
encoder(newSession.Refresh), s.UserID, s.ConnID,
|
||||
)
|
||||
if err != nil {
|
||||
return fmt.Errorf("update offline session: %v", err)
|
||||
}
|
||||
return nil
|
||||
})
|
||||
}
|
||||
|
||||
func (c *conn) GetOfflineSessions(userID string, connID string) (storage.OfflineSessions, error) {
|
||||
return getOfflineSessions(c, userID, connID)
|
||||
}
|
||||
|
||||
func getOfflineSessions(q querier, userID string, connID string) (storage.OfflineSessions, error) {
|
||||
return scanOfflineSessions(q.QueryRow(`
|
||||
select
|
||||
user_id, conn_id, refresh
|
||||
from offline_session
|
||||
where user_id = $1 AND conn_id = $2;
|
||||
`, userID, connID))
|
||||
}
|
||||
|
||||
func scanOfflineSessions(s scanner) (o storage.OfflineSessions, err error) {
|
||||
err = s.Scan(
|
||||
&o.UserID, &o.ConnID, decoder(&o.Refresh),
|
||||
)
|
||||
if err != nil {
|
||||
if err == sql.ErrNoRows {
|
||||
return o, storage.ErrNotFound
|
||||
}
|
||||
return o, fmt.Errorf("select offline session: %v", err)
|
||||
}
|
||||
return o, nil
|
||||
}
|
||||
|
||||
func (c *conn) DeleteAuthRequest(id string) error { return c.delete("auth_request", "id", id) }
|
||||
func (c *conn) DeleteAuthCode(id string) error { return c.delete("auth_code", "id", id) }
|
||||
func (c *conn) DeleteClient(id string) error { return c.delete("client", "id", id) }
|
||||
@@ -632,6 +701,24 @@ func (c *conn) DeletePassword(email string) error {
|
||||
return c.delete("password", "email", strings.ToLower(email))
|
||||
}
|
||||
|
||||
func (c *conn) DeleteOfflineSessions(userID string, connID string) error {
|
||||
result, err := c.Exec(`delete from offline_session where user_id = $1 AND conn_id = $2`, userID, connID)
|
||||
if err != nil {
|
||||
return fmt.Errorf("delete offline_session: user_id = %s, conn_id = %s", userID, connID)
|
||||
}
|
||||
|
||||
// For now mandate that the driver implements RowsAffected. If we ever need to support
|
||||
// a driver that doesn't implement this, we can run this in a transaction with a get beforehand.
|
||||
n, err := result.RowsAffected()
|
||||
if err != nil {
|
||||
return fmt.Errorf("rows affected: %v", err)
|
||||
}
|
||||
if n < 1 {
|
||||
return storage.ErrNotFound
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// Do NOT call directly. Does not escape table.
|
||||
func (c *conn) delete(table, field, id string) error {
|
||||
result, err := c.Exec(`delete from `+table+` where `+field+` = $1`, id)
|
||||
|
Reference in New Issue
Block a user