storage: add extra fields to refresh token and update method
This commit is contained in:
		| @@ -208,10 +208,14 @@ func testClientCRUD(t *testing.T, s storage.Storage) { | ||||
| func testRefreshTokenCRUD(t *testing.T, s storage.Storage) { | ||||
| 	id := storage.NewID() | ||||
| 	refresh := storage.RefreshToken{ | ||||
| 		RefreshToken: id, | ||||
| 		ClientID:     "client_id", | ||||
| 		ConnectorID:  "client_secret", | ||||
| 		Scopes:       []string{"openid", "email", "profile"}, | ||||
| 		ID:          id, | ||||
| 		Token:       "bar", | ||||
| 		Nonce:       "foo", | ||||
| 		ClientID:    "client_id", | ||||
| 		ConnectorID: "client_secret", | ||||
| 		Scopes:      []string{"openid", "email", "profile"}, | ||||
| 		CreatedAt:   time.Now().UTC().Round(time.Millisecond), | ||||
| 		LastUsed:    time.Now().UTC().Round(time.Millisecond), | ||||
| 		Claims: storage.Claims{ | ||||
| 			UserID:        "1", | ||||
| 			Username:      "jane", | ||||
| @@ -238,6 +242,20 @@ func testRefreshTokenCRUD(t *testing.T, s storage.Storage) { | ||||
|  | ||||
| 	getAndCompare(id, refresh) | ||||
|  | ||||
| 	updatedAt := time.Now().UTC().Round(time.Millisecond) | ||||
|  | ||||
| 	updater := func(r storage.RefreshToken) (storage.RefreshToken, error) { | ||||
| 		r.Token = "spam" | ||||
| 		r.LastUsed = updatedAt | ||||
| 		return r, nil | ||||
| 	} | ||||
| 	if err := s.UpdateRefreshToken(id, updater); err != nil { | ||||
| 		t.Errorf("failed to udpate refresh token: %v", err) | ||||
| 	} | ||||
| 	refresh.Token = "spam" | ||||
| 	refresh.LastUsed = updatedAt | ||||
| 	getAndCompare(id, refresh) | ||||
|  | ||||
| 	if err := s.DeleteRefresh(id); err != nil { | ||||
| 		t.Fatalf("failed to delete refresh request: %v", err) | ||||
| 	} | ||||
|   | ||||
| @@ -153,23 +153,7 @@ func (cli *client) CreatePassword(p storage.Password) error { | ||||
| } | ||||
|  | ||||
| func (cli *client) CreateRefresh(r storage.RefreshToken) error { | ||||
| 	refresh := RefreshToken{ | ||||
| 		TypeMeta: k8sapi.TypeMeta{ | ||||
| 			Kind:       kindRefreshToken, | ||||
| 			APIVersion: cli.apiVersion, | ||||
| 		}, | ||||
| 		ObjectMeta: k8sapi.ObjectMeta{ | ||||
| 			Name:      r.RefreshToken, | ||||
| 			Namespace: cli.namespace, | ||||
| 		}, | ||||
| 		ClientID:      r.ClientID, | ||||
| 		ConnectorID:   r.ConnectorID, | ||||
| 		Scopes:        r.Scopes, | ||||
| 		Nonce:         r.Nonce, | ||||
| 		Claims:        fromStorageClaims(r.Claims), | ||||
| 		ConnectorData: r.ConnectorData, | ||||
| 	} | ||||
| 	return cli.post(resourceRefreshToken, refresh) | ||||
| 	return cli.post(resourceRefreshToken, cli.fromStorageRefreshToken(r)) | ||||
| } | ||||
|  | ||||
| func (cli *client) GetAuthRequest(id string) (storage.AuthRequest, error) { | ||||
| @@ -239,19 +223,16 @@ func (cli *client) GetKeys() (storage.Keys, error) { | ||||
| } | ||||
|  | ||||
| func (cli *client) GetRefresh(id string) (storage.RefreshToken, error) { | ||||
| 	var r RefreshToken | ||||
| 	if err := cli.get(resourceRefreshToken, id, &r); err != nil { | ||||
| 	r, err := cli.getRefreshToken(id) | ||||
| 	if err != nil { | ||||
| 		return storage.RefreshToken{}, err | ||||
| 	} | ||||
| 	return storage.RefreshToken{ | ||||
| 		RefreshToken:  r.ObjectMeta.Name, | ||||
| 		ClientID:      r.ClientID, | ||||
| 		ConnectorID:   r.ConnectorID, | ||||
| 		Scopes:        r.Scopes, | ||||
| 		Nonce:         r.Nonce, | ||||
| 		Claims:        toStorageClaims(r.Claims), | ||||
| 		ConnectorData: r.ConnectorData, | ||||
| 	}, nil | ||||
| 	return toStorageRefreshToken(r), nil | ||||
| } | ||||
|  | ||||
| func (cli *client) getRefreshToken(id string) (r RefreshToken, err error) { | ||||
| 	err = cli.get(resourceRefreshToken, id, &r) | ||||
| 	return | ||||
| } | ||||
|  | ||||
| func (cli *client) ListClients() ([]storage.Client, error) { | ||||
| @@ -311,6 +292,22 @@ func (cli *client) DeletePassword(email string) error { | ||||
| 	return cli.delete(resourcePassword, p.ObjectMeta.Name) | ||||
| } | ||||
|  | ||||
| func (cli *client) UpdateRefreshToken(id string, updater func(old storage.RefreshToken) (storage.RefreshToken, error)) error { | ||||
| 	r, err := cli.getRefreshToken(id) | ||||
| 	if err != nil { | ||||
| 		return err | ||||
| 	} | ||||
| 	updated, err := updater(toStorageRefreshToken(r)) | ||||
| 	if err != nil { | ||||
| 		return err | ||||
| 	} | ||||
| 	updated.ID = id | ||||
|  | ||||
| 	newToken := cli.fromStorageRefreshToken(updated) | ||||
| 	newToken.ObjectMeta = r.ObjectMeta | ||||
| 	return cli.put(resourceRefreshToken, r.ObjectMeta.Name, newToken) | ||||
| } | ||||
|  | ||||
| func (cli *client) UpdateClient(id string, updater func(old storage.Client) (storage.Client, error)) error { | ||||
| 	c, err := cli.getClient(id) | ||||
| 	if err != nil { | ||||
|   | ||||
| @@ -362,9 +362,14 @@ type RefreshToken struct { | ||||
| 	k8sapi.TypeMeta   `json:",inline"` | ||||
| 	k8sapi.ObjectMeta `json:"metadata,omitempty"` | ||||
|  | ||||
| 	CreatedAt time.Time | ||||
| 	LastUsed  time.Time | ||||
|  | ||||
| 	ClientID string   `json:"clientID"` | ||||
| 	Scopes   []string `json:"scopes,omitempty"` | ||||
|  | ||||
| 	Token string `json:"token,omitempty"` | ||||
|  | ||||
| 	Nonce string `json:"nonce,omitempty"` | ||||
|  | ||||
| 	Claims        Claims `json:"claims,omitempty"` | ||||
| @@ -379,6 +384,43 @@ type RefreshList struct { | ||||
| 	RefreshTokens   []RefreshToken `json:"items"` | ||||
| } | ||||
|  | ||||
| func toStorageRefreshToken(r RefreshToken) storage.RefreshToken { | ||||
| 	return storage.RefreshToken{ | ||||
| 		ID:            r.ObjectMeta.Name, | ||||
| 		Token:         r.Token, | ||||
| 		CreatedAt:     r.CreatedAt, | ||||
| 		LastUsed:      r.LastUsed, | ||||
| 		ClientID:      r.ClientID, | ||||
| 		ConnectorID:   r.ConnectorID, | ||||
| 		ConnectorData: r.ConnectorData, | ||||
| 		Scopes:        r.Scopes, | ||||
| 		Nonce:         r.Nonce, | ||||
| 		Claims:        toStorageClaims(r.Claims), | ||||
| 	} | ||||
| } | ||||
|  | ||||
| func (cli *client) fromStorageRefreshToken(r storage.RefreshToken) RefreshToken { | ||||
| 	return RefreshToken{ | ||||
| 		TypeMeta: k8sapi.TypeMeta{ | ||||
| 			Kind:       kindRefreshToken, | ||||
| 			APIVersion: cli.apiVersion, | ||||
| 		}, | ||||
| 		ObjectMeta: k8sapi.ObjectMeta{ | ||||
| 			Name:      r.ID, | ||||
| 			Namespace: cli.namespace, | ||||
| 		}, | ||||
| 		Token:         r.Token, | ||||
| 		CreatedAt:     r.CreatedAt, | ||||
| 		LastUsed:      r.LastUsed, | ||||
| 		ClientID:      r.ClientID, | ||||
| 		ConnectorID:   r.ConnectorID, | ||||
| 		ConnectorData: r.ConnectorData, | ||||
| 		Scopes:        r.Scopes, | ||||
| 		Nonce:         r.Nonce, | ||||
| 		Claims:        fromStorageClaims(r.Claims), | ||||
| 	} | ||||
| } | ||||
|  | ||||
| // Keys is a mirrored struct from storage with JSON struct tags and Kubernetes | ||||
| // type metadata. | ||||
| type Keys struct { | ||||
|   | ||||
| @@ -98,10 +98,10 @@ func (s *memStorage) CreateAuthCode(c storage.AuthCode) (err error) { | ||||
|  | ||||
| func (s *memStorage) CreateRefresh(r storage.RefreshToken) (err error) { | ||||
| 	s.tx(func() { | ||||
| 		if _, ok := s.refreshTokens[r.RefreshToken]; ok { | ||||
| 		if _, ok := s.refreshTokens[r.ID]; ok { | ||||
| 			err = storage.ErrAlreadyExists | ||||
| 		} else { | ||||
| 			s.refreshTokens[r.RefreshToken] = r | ||||
| 			s.refreshTokens[r.ID] = r | ||||
| 		} | ||||
| 	}) | ||||
| 	return | ||||
| @@ -324,3 +324,17 @@ func (s *memStorage) UpdatePassword(email string, updater func(p storage.Passwor | ||||
| 	}) | ||||
| 	return | ||||
| } | ||||
|  | ||||
| func (s *memStorage) UpdateRefreshToken(id string, updater func(p storage.RefreshToken) (storage.RefreshToken, error)) (err error) { | ||||
| 	s.tx(func() { | ||||
| 		r, ok := s.refreshTokens[id] | ||||
| 		if !ok { | ||||
| 			err = storage.ErrNotFound | ||||
| 			return | ||||
| 		} | ||||
| 		if r, err = updater(r); err == nil { | ||||
| 			s.refreshTokens[id] = r | ||||
| 		} | ||||
| 	}) | ||||
| 	return | ||||
| } | ||||
|   | ||||
| @@ -244,14 +244,16 @@ func (c *conn) CreateRefresh(r storage.RefreshToken) error { | ||||
| 			id, client_id, scopes, nonce, | ||||
| 			claims_user_id, claims_username, claims_email, claims_email_verified, | ||||
| 			claims_groups, | ||||
| 			connector_id, connector_data | ||||
| 			connector_id, connector_data, | ||||
| 			token, created_at, last_used | ||||
| 		) | ||||
| 		values ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11); | ||||
| 		values ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12, $13, $14); | ||||
| 	`, | ||||
| 		r.RefreshToken, r.ClientID, encoder(r.Scopes), r.Nonce, | ||||
| 		r.ID, r.ClientID, encoder(r.Scopes), r.Nonce, | ||||
| 		r.Claims.UserID, r.Claims.Username, r.Claims.Email, r.Claims.EmailVerified, | ||||
| 		encoder(r.Claims.Groups), | ||||
| 		r.ConnectorID, r.ConnectorData, | ||||
| 		r.Token, r.CreatedAt, r.LastUsed, | ||||
| 	) | ||||
| 	if err != nil { | ||||
| 		return fmt.Errorf("insert refresh_token: %v", err) | ||||
| @@ -259,13 +261,57 @@ func (c *conn) CreateRefresh(r storage.RefreshToken) error { | ||||
| 	return nil | ||||
| } | ||||
|  | ||||
| func (c *conn) UpdateRefreshToken(id string, updater func(old storage.RefreshToken) (storage.RefreshToken, error)) error { | ||||
| 	return c.ExecTx(func(tx *trans) error { | ||||
| 		r, err := getRefresh(tx, id) | ||||
| 		if err != nil { | ||||
| 			return err | ||||
| 		} | ||||
| 		if r, err = updater(r); err != nil { | ||||
| 			return err | ||||
| 		} | ||||
| 		_, err = tx.Exec(` | ||||
| 			update refresh_token | ||||
| 			set | ||||
| 				client_id = $1, | ||||
| 				scopes = $2, | ||||
| 				nonce = $3, | ||||
| 				claims_user_id = $4, | ||||
| 				claims_username = $5, | ||||
| 				claims_email = $6, | ||||
| 				claims_email_verified = $7, | ||||
| 				claims_groups = $8, | ||||
| 				connector_id = $9, | ||||
| 				connector_data = $10, | ||||
| 				token = $11, | ||||
| 				created_at = $12, | ||||
| 				last_used = $13 | ||||
| 		`, | ||||
| 			r.ClientID, encoder(r.Scopes), r.Nonce, | ||||
| 			r.Claims.UserID, r.Claims.Username, r.Claims.Email, r.Claims.EmailVerified, | ||||
| 			encoder(r.Claims.Groups), | ||||
| 			r.ConnectorID, r.ConnectorData, | ||||
| 			r.Token, r.CreatedAt, r.LastUsed, | ||||
| 		) | ||||
| 		if err != nil { | ||||
| 			return fmt.Errorf("update refresh token: %v", err) | ||||
| 		} | ||||
| 		return nil | ||||
| 	}) | ||||
| } | ||||
|  | ||||
| func (c *conn) GetRefresh(id string) (storage.RefreshToken, error) { | ||||
| 	return scanRefresh(c.QueryRow(` | ||||
| 	return getRefresh(c, id) | ||||
| } | ||||
|  | ||||
| func getRefresh(q querier, id string) (storage.RefreshToken, error) { | ||||
| 	return scanRefresh(q.QueryRow(` | ||||
| 		select | ||||
| 			id, client_id, scopes, nonce, | ||||
| 			claims_user_id, claims_username, claims_email, claims_email_verified, | ||||
| 			claims_groups, | ||||
| 			connector_id, connector_data | ||||
| 			connector_id, connector_data, | ||||
| 			token, created_at, last_used | ||||
| 		from refresh_token where id = $1; | ||||
| 	`, id)) | ||||
| } | ||||
| @@ -276,7 +322,8 @@ func (c *conn) ListRefreshTokens() ([]storage.RefreshToken, error) { | ||||
| 			id, client_id, scopes, nonce, | ||||
| 			claims_user_id, claims_username, claims_email, claims_email_verified, | ||||
| 			claims_groups, | ||||
| 			connector_id, connector_data | ||||
| 			connector_id, connector_data, | ||||
| 			token, created_at, last_used | ||||
| 		from refresh_token; | ||||
| 	`) | ||||
| 	if err != nil { | ||||
| @@ -298,10 +345,11 @@ func (c *conn) ListRefreshTokens() ([]storage.RefreshToken, error) { | ||||
|  | ||||
| func scanRefresh(s scanner) (r storage.RefreshToken, err error) { | ||||
| 	err = s.Scan( | ||||
| 		&r.RefreshToken, &r.ClientID, decoder(&r.Scopes), &r.Nonce, | ||||
| 		&r.ID, &r.ClientID, decoder(&r.Scopes), &r.Nonce, | ||||
| 		&r.Claims.UserID, &r.Claims.Username, &r.Claims.Email, &r.Claims.EmailVerified, | ||||
| 		decoder(&r.Claims.Groups), | ||||
| 		&r.ConnectorID, &r.ConnectorData, | ||||
| 		&r.Token, &r.CreatedAt, &r.LastUsed, | ||||
| 	) | ||||
| 	if err != nil { | ||||
| 		if err == sql.ErrNoRows { | ||||
|   | ||||
| @@ -155,4 +155,14 @@ var migrations = []migration{ | ||||
| 			); | ||||
| 		`, | ||||
| 	}, | ||||
| 	{ | ||||
| 		stmt: ` | ||||
| 			alter table refresh_token | ||||
| 				add column token text not null default ''; | ||||
| 			alter table refresh_token | ||||
| 				add column created_at timestamptz not null default '0001-01-01 00:00:00 UTC'; | ||||
| 			alter table refresh_token | ||||
| 				add column last_used timestamptz not null default '0001-01-01 00:00:00 UTC'; | ||||
| 		`, | ||||
| 	}, | ||||
| } | ||||
|   | ||||
| @@ -94,6 +94,7 @@ type Storage interface { | ||||
| 	UpdateClient(id string, updater func(old Client) (Client, error)) error | ||||
| 	UpdateKeys(updater func(old Keys) (Keys, error)) error | ||||
| 	UpdateAuthRequest(id string, updater func(a AuthRequest) (AuthRequest, error)) error | ||||
| 	UpdateRefreshToken(id string, updater func(r RefreshToken) (RefreshToken, error)) error | ||||
| 	UpdatePassword(email string, updater func(p Password) (Password, error)) error | ||||
|  | ||||
| 	// GarbageCollect deletes all expired AuthCodes and AuthRequests. | ||||
| @@ -216,8 +217,15 @@ type AuthCode struct { | ||||
| // RefreshToken is an OAuth2 refresh token which allows a client to request new | ||||
| // tokens on the end user's behalf. | ||||
| type RefreshToken struct { | ||||
| 	// The actual refresh token. | ||||
| 	RefreshToken string | ||||
| 	ID string | ||||
|  | ||||
| 	// A single token that's rotated every time the refresh token is refreshed. | ||||
| 	// | ||||
| 	// May be empty. | ||||
| 	Token string | ||||
|  | ||||
| 	CreatedAt time.Time | ||||
| 	LastUsed  time.Time | ||||
|  | ||||
| 	// Client this refresh token is valid for. | ||||
| 	ClientID string | ||||
|   | ||||
		Reference in New Issue
	
	Block a user