Device flow token code exchange (#2)

* Added /device/token handler with associated business logic and storage tests.

Perform user code exchange, flag the device code as complete.

Moved device handler code into its own file for cleanliness.  Cleanup

* Removed PKCE code

* Rate limiting for /device/token endpoint based on ietf standards

* Configurable Device expiry

Signed-off-by: justin-slowik <justin.slowik@thermofisher.com>
This commit is contained in:
Justin Slowik
2020-01-28 14:14:30 -05:00
committed by justin-slowik
parent 0d1a0e4129
commit 9bbdc721d5
20 changed files with 777 additions and 274 deletions

View File

@@ -843,12 +843,11 @@ func testGC(t *testing.T, s storage.Storage) {
}
d := storage.DeviceRequest{
UserCode: userCode,
DeviceCode: storage.NewID(),
ClientID: "client1",
Scopes: []string{"openid", "email"},
PkceVerifier: storage.NewID(),
Expiry: expiry,
UserCode: userCode,
DeviceCode: storage.NewID(),
ClientID: "client1",
Scopes: []string{"openid", "email"},
Expiry: expiry,
}
if err := s.CreateDeviceRequest(d); err != nil {
@@ -970,12 +969,11 @@ func testDeviceRequestCRUD(t *testing.T, s storage.Storage) {
panic(err)
}
d1 := storage.DeviceRequest{
UserCode: userCode,
DeviceCode: storage.NewID(),
ClientID: "client1",
Scopes: []string{"openid", "email"},
PkceVerifier: storage.NewID(),
Expiry: neverExpire,
UserCode: userCode,
DeviceCode: storage.NewID(),
ClientID: "client1",
Scopes: []string{"openid", "email"},
Expiry: neverExpire,
}
if err := s.CreateDeviceRequest(d1); err != nil {
@@ -991,20 +989,44 @@ func testDeviceRequestCRUD(t *testing.T, s storage.Storage) {
}
func testDeviceTokenCRUD(t *testing.T, s storage.Storage) {
//Create a Token
d1 := storage.DeviceToken{
DeviceCode: storage.NewID(),
Status: "pending",
Token: storage.NewID(),
Expiry: neverExpire,
DeviceCode: storage.NewID(),
Status: "pending",
Token: storage.NewID(),
Expiry: neverExpire,
LastRequestTime: time.Now(),
PollIntervalSeconds: 0,
}
if err := s.CreateDeviceToken(d1); err != nil {
t.Fatalf("failed creating device token: %v", err)
}
// Attempt to create same DeviceRequest twice.
// Attempt to create same Device Token twice.
err := s.CreateDeviceToken(d1)
mustBeErrAlreadyExists(t, "device token", err)
//TODO Add update / delete tests as functionality is put into main code
//Update the device token, simulate a redemption
if err := s.UpdateDeviceToken(d1.DeviceCode, func(old storage.DeviceToken) (storage.DeviceToken, error) {
old.Token = "token data"
old.Status = "complete"
return old, nil
}); err != nil {
t.Fatalf("failed to update device token: %v", err)
}
//Retrieve the device token
got, err := s.GetDeviceToken(d1.DeviceCode)
if err != nil {
t.Fatalf("failed to get device token: %v", err)
}
//Validate expected result set
if got.Status != "complete" {
t.Fatalf("update failed, wanted token status=%#v got %#v", "complete", got.Status)
}
if got.Token != "token data" {
t.Fatalf("update failed, wanted token =%#v got %#v", "token data", got.Token)
}
}

View File

@@ -570,6 +570,13 @@ func (c *conn) CreateDeviceRequest(d storage.DeviceRequest) error {
return c.txnCreate(ctx, keyID(deviceRequestPrefix, d.UserCode), fromStorageDeviceRequest(d))
}
func (c *conn) GetDeviceRequest(userCode string) (r storage.DeviceRequest, err error) {
ctx, cancel := context.WithTimeout(context.Background(), defaultStorageTimeout)
defer cancel()
err = c.getKey(ctx, keyID(deviceRequestPrefix, userCode), &r)
return r, err
}
func (c *conn) listDeviceRequests(ctx context.Context) (requests []DeviceRequest, err error) {
res, err := c.db.Get(ctx, deviceRequestPrefix, clientv3.WithPrefix())
if err != nil {
@@ -612,3 +619,21 @@ func (c *conn) listDeviceTokens(ctx context.Context) (deviceTokens []DeviceToken
}
return deviceTokens, nil
}
func (c *conn) UpdateDeviceToken(deviceCode string, updater func(old storage.DeviceToken) (storage.DeviceToken, error)) error {
ctx, cancel := context.WithTimeout(context.Background(), defaultStorageTimeout)
defer cancel()
return c.txnUpdate(ctx, keyID(deviceTokenPrefix, deviceCode), func(currentValue []byte) ([]byte, error) {
var current DeviceToken
if len(currentValue) > 0 {
if err := json.Unmarshal(currentValue, &current); err != nil {
return nil, err
}
}
updated, err := updater(toStorageDeviceToken(current))
if err != nil {
return nil, err
}
return json.Marshal(fromStorageDeviceToken(updated))
})
}

View File

@@ -219,38 +219,51 @@ func toStorageOfflineSessions(o OfflineSessions) storage.OfflineSessions {
// DeviceRequest is a mirrored struct from storage with JSON struct tags
type DeviceRequest struct {
UserCode string `json:"user_code"`
DeviceCode string `json:"device_code"`
ClientID string `json:"client_id"`
Scopes []string `json:"scopes"`
PkceVerifier string `json:"pkce_verifier"`
Expiry time.Time `json:"expiry"`
UserCode string `json:"user_code"`
DeviceCode string `json:"device_code"`
ClientID string `json:"client_id"`
Scopes []string `json:"scopes"`
Expiry time.Time `json:"expiry"`
}
func fromStorageDeviceRequest(d storage.DeviceRequest) DeviceRequest {
return DeviceRequest{
UserCode: d.UserCode,
DeviceCode: d.DeviceCode,
ClientID: d.ClientID,
Scopes: d.Scopes,
PkceVerifier: d.PkceVerifier,
Expiry: d.Expiry,
UserCode: d.UserCode,
DeviceCode: d.DeviceCode,
ClientID: d.ClientID,
Scopes: d.Scopes,
Expiry: d.Expiry,
}
}
// DeviceToken is a mirrored struct from storage with JSON struct tags
type DeviceToken struct {
DeviceCode string `json:"device_code"`
Status string `json:"status"`
Token string `json:"token"`
Expiry time.Time `json:"expiry"`
DeviceCode string `json:"device_code"`
Status string `json:"status"`
Token string `json:"token"`
Expiry time.Time `json:"expiry"`
LastRequestTime time.Time `json:"last_request"`
PollIntervalSeconds int `json:"poll_interval"`
}
func fromStorageDeviceToken(t storage.DeviceToken) DeviceToken {
return DeviceToken{
DeviceCode: t.DeviceCode,
Status: t.Status,
Token: t.Token,
Expiry: t.Expiry,
DeviceCode: t.DeviceCode,
Status: t.Status,
Token: t.Token,
Expiry: t.Expiry,
LastRequestTime: t.LastRequestTime,
PollIntervalSeconds: t.PollIntervalSeconds,
}
}
func toStorageDeviceToken(t DeviceToken) storage.DeviceToken {
return storage.DeviceToken{
DeviceCode: t.DeviceCode,
Status: t.Status,
Token: t.Token,
Expiry: t.Expiry,
LastRequestTime: t.LastRequestTime,
PollIntervalSeconds: t.PollIntervalSeconds,
}
}

View File

@@ -638,6 +638,14 @@ func (cli *client) CreateDeviceRequest(d storage.DeviceRequest) error {
return cli.post(resourceDeviceRequest, cli.fromStorageDeviceRequest(d))
}
func (cli *client) GetDeviceRequest(userCode string) (storage.DeviceRequest, error) {
var req DeviceRequest
if err := cli.get(resourceDeviceRequest, strings.ToLower(userCode), &req); err != nil {
return storage.DeviceRequest{}, err
}
return toStorageDeviceRequest(req), nil
}
func (cli *client) CreateDeviceToken(t storage.DeviceToken) error {
return cli.post(resourceDeviceToken, cli.fromStorageDeviceToken(t))
}
@@ -649,3 +657,24 @@ func (cli *client) GetDeviceToken(deviceCode string) (storage.DeviceToken, error
}
return toStorageDeviceToken(token), nil
}
func (cli *client) getDeviceToken(deviceCode string) (t DeviceToken, err error) {
err = cli.get(resourceDeviceToken, deviceCode, &t)
return
}
func (cli *client) UpdateDeviceToken(deviceCode string, updater func(old storage.DeviceToken) (storage.DeviceToken, error)) error {
r, err := cli.getDeviceToken(deviceCode)
if err != nil {
return err
}
updated, err := updater(toStorageDeviceToken(r))
if err != nil {
return err
}
updated.DeviceCode = deviceCode
newToken := cli.fromStorageDeviceToken(updated)
newToken.ObjectMeta = r.ObjectMeta
return cli.put(resourceDeviceToken, r.ObjectMeta.Name, newToken)
}

View File

@@ -672,11 +672,10 @@ type DeviceRequest struct {
k8sapi.TypeMeta `json:",inline"`
k8sapi.ObjectMeta `json:"metadata,omitempty"`
DeviceCode string `json:"device_code,omitempty"`
CLientID string `json:"client_id,omitempty"`
Scopes []string `json:"scopes,omitempty"`
PkceVerifier string `json:"pkce_verifier,omitempty"`
Expiry time.Time `json:"expiry"`
DeviceCode string `json:"device_code,omitempty"`
CLientID string `json:"client_id,omitempty"`
Scopes []string `json:"scopes,omitempty"`
Expiry time.Time `json:"expiry"`
}
// AuthRequestList is a list of AuthRequests.
@@ -696,24 +695,35 @@ func (cli *client) fromStorageDeviceRequest(a storage.DeviceRequest) DeviceReque
Name: strings.ToLower(a.UserCode),
Namespace: cli.namespace,
},
DeviceCode: a.DeviceCode,
CLientID: a.ClientID,
Scopes: a.Scopes,
PkceVerifier: a.PkceVerifier,
Expiry: a.Expiry,
DeviceCode: a.DeviceCode,
CLientID: a.ClientID,
Scopes: a.Scopes,
Expiry: a.Expiry,
}
return req
}
func toStorageDeviceRequest(req DeviceRequest) storage.DeviceRequest {
return storage.DeviceRequest{
UserCode: strings.ToUpper(req.ObjectMeta.Name),
DeviceCode: req.DeviceCode,
ClientID: req.CLientID,
Scopes: req.Scopes,
Expiry: req.Expiry,
}
}
// DeviceToken is a mirrored struct from storage with JSON struct tags and
// Kubernetes type metadata.
type DeviceToken struct {
k8sapi.TypeMeta `json:",inline"`
k8sapi.ObjectMeta `json:"metadata,omitempty"`
Status string `json:"status,omitempty"`
Token string `json:"token,omitempty"`
Expiry time.Time `json:"expiry"`
Status string `json:"status,omitempty"`
Token string `json:"token,omitempty"`
Expiry time.Time `json:"expiry"`
LastRequestTime time.Time `json:"last_request"`
PollIntervalSeconds int `json:"poll_interval"`
}
// DeviceTokenList is a list of DeviceTokens.
@@ -733,18 +743,22 @@ func (cli *client) fromStorageDeviceToken(t storage.DeviceToken) DeviceToken {
Name: t.DeviceCode,
Namespace: cli.namespace,
},
Status: t.Status,
Token: t.Token,
Expiry: t.Expiry,
Status: t.Status,
Token: t.Token,
Expiry: t.Expiry,
LastRequestTime: t.LastRequestTime,
PollIntervalSeconds: t.PollIntervalSeconds,
}
return req
}
func toStorageDeviceToken(t DeviceToken) storage.DeviceToken {
return storage.DeviceToken{
DeviceCode: t.ObjectMeta.Name,
Status: t.Status,
Token: t.Token,
Expiry: t.Expiry,
DeviceCode: t.ObjectMeta.Name,
Status: t.Status,
Token: t.Token,
Expiry: t.Expiry,
LastRequestTime: t.LastRequestTime,
PollIntervalSeconds: t.PollIntervalSeconds,
}
}

View File

@@ -493,6 +493,17 @@ func (s *memStorage) CreateDeviceRequest(d storage.DeviceRequest) (err error) {
return
}
func (s *memStorage) GetDeviceRequest(userCode string) (req storage.DeviceRequest, err error) {
s.tx(func() {
var ok bool
if req, ok = s.deviceRequests[userCode]; !ok {
err = storage.ErrNotFound
return
}
})
return
}
func (s *memStorage) CreateDeviceToken(t storage.DeviceToken) (err error) {
s.tx(func() {
if _, ok := s.deviceTokens[t.DeviceCode]; ok {
@@ -514,3 +525,17 @@ func (s *memStorage) GetDeviceToken(deviceCode string) (t storage.DeviceToken, e
})
return
}
func (s *memStorage) UpdateDeviceToken(deviceCode string, updater func(p storage.DeviceToken) (storage.DeviceToken, error)) (err error) {
s.tx(func() {
r, ok := s.deviceTokens[deviceCode]
if !ok {
err = storage.ErrNotFound
return
}
if r, err = updater(r); err == nil {
s.deviceTokens[deviceCode] = r
}
})
return
}

View File

@@ -888,12 +888,12 @@ func (c *conn) delete(table, field, id string) error {
func (c *conn) CreateDeviceRequest(d storage.DeviceRequest) error {
_, err := c.Exec(`
insert into device_request (
user_code, device_code, client_id, scopes, pkce_verifier, expiry
user_code, device_code, client_id, scopes, expiry
)
values (
$1, $2, $3, $4, $5, $6
$1, $2, $3, $4, $5
);`,
d.UserCode, d.DeviceCode, d.ClientID, encoder(d.Scopes), d.PkceVerifier, d.Expiry,
d.UserCode, d.DeviceCode, d.ClientID, encoder(d.Scopes), d.Expiry,
)
if err != nil {
if c.alreadyExistsCheck(err) {
@@ -907,12 +907,12 @@ func (c *conn) CreateDeviceRequest(d storage.DeviceRequest) error {
func (c *conn) CreateDeviceToken(t storage.DeviceToken) error {
_, err := c.Exec(`
insert into device_token (
device_code, status, token, expiry
device_code, status, token, expiry, last_request, poll_interval
)
values (
$1, $2, $3, $4
$1, $2, $3, $4, $5, $6
);`,
t.DeviceCode, t.Status, t.Token, t.Expiry,
t.DeviceCode, t.Status, t.Token, t.Expiry, t.LastRequestTime, t.PollIntervalSeconds,
)
if err != nil {
if c.alreadyExistsCheck(err) {
@@ -923,6 +923,28 @@ func (c *conn) CreateDeviceToken(t storage.DeviceToken) error {
return nil
}
func (c *conn) GetDeviceRequest(userCode string) (storage.DeviceRequest, error) {
return getDeviceRequest(c, userCode)
}
func getDeviceRequest(q querier, userCode string) (d storage.DeviceRequest, err error) {
err = q.QueryRow(`
select
device_code, client_id, scopes, expiry
from device_request where user_code = $1;
`, userCode).Scan(
&d.DeviceCode, &d.ClientID, decoder(&d.Scopes), &d.Expiry,
)
if err != nil {
if err == sql.ErrNoRows {
return d, storage.ErrNotFound
}
return d, fmt.Errorf("select device token: %v", err)
}
d.UserCode = userCode
return d, nil
}
func (c *conn) GetDeviceToken(deviceCode string) (storage.DeviceToken, error) {
return getDeviceToken(c, deviceCode)
}
@@ -930,10 +952,10 @@ func (c *conn) GetDeviceToken(deviceCode string) (storage.DeviceToken, error) {
func getDeviceToken(q querier, deviceCode string) (a storage.DeviceToken, err error) {
err = q.QueryRow(`
select
status, token, expiry
status, token, expiry, last_request, poll_interval
from device_token where device_code = $1;
`, deviceCode).Scan(
&a.Status, &a.Token, &a.Expiry,
&a.Status, &a.Token, &a.Expiry, &a.LastRequestTime, &a.PollIntervalSeconds,
)
if err != nil {
if err == sql.ErrNoRows {
@@ -944,3 +966,31 @@ func getDeviceToken(q querier, deviceCode string) (a storage.DeviceToken, err er
a.DeviceCode = deviceCode
return a, nil
}
func (c *conn) UpdateDeviceToken(deviceCode string, updater func(old storage.DeviceToken) (storage.DeviceToken, error)) error {
return c.ExecTx(func(tx *trans) error {
r, err := getDeviceToken(tx, deviceCode)
if err != nil {
return err
}
if r, err = updater(r); err != nil {
return err
}
_, err = tx.Exec(`
update device_token
set
status = $1,
token = $2,
last_request = $3,
poll_interval = $4
where
device_code = $5
`,
r.Status, r.Token, r.LastRequestTime, r.PollIntervalSeconds, r.DeviceCode,
)
if err != nil {
return fmt.Errorf("update device token: %v", err)
}
return nil
})
}

View File

@@ -236,7 +236,6 @@ var migrations = []migration{
device_code text not null,
client_id text not null,
scopes bytea not null, -- JSON array of strings
pkce_verifier text not null,
expiry timestamptz not null
);`,
`
@@ -244,7 +243,9 @@ var migrations = []migration{
device_code text not null primary key,
status text not null,
token text,
expiry timestamptz not null
expiry timestamptz not null,
last_request timestamptz not null,
poll_interval integer not null
);`,
},
},

View File

@@ -82,6 +82,7 @@ type Storage interface {
GetPassword(email string) (Password, error)
GetOfflineSessions(userID string, connID string) (OfflineSessions, error)
GetConnector(id string) (Connector, error)
GetDeviceRequest(userCode string) (DeviceRequest, error)
GetDeviceToken(deviceCode string) (DeviceToken, error)
ListClients() ([]Client, error)
@@ -119,6 +120,7 @@ type Storage interface {
UpdatePassword(email string, updater func(p Password) (Password, error)) error
UpdateOfflineSessions(userID string, connID string, updater func(s OfflineSessions) (OfflineSessions, error)) error
UpdateConnector(id string, updater func(c Connector) (Connector, error)) error
UpdateDeviceToken(deviceCode string, updater func(t DeviceToken) (DeviceToken, error)) error
// GarbageCollect deletes all expired AuthCodes,AuthRequests, DeviceRequests, and DeviceTokens.
GarbageCollect(now time.Time) (GCResult, error)
@@ -392,15 +394,15 @@ type DeviceRequest struct {
ClientID string
//The scopes the device requests
Scopes []string
//PKCE Verification
PkceVerifier string
//The expire time
Expiry time.Time
}
type DeviceToken struct {
DeviceCode string
Status string
Token string
Expiry time.Time
DeviceCode string
Status string
Token string
Expiry time.Time
LastRequestTime time.Time
PollIntervalSeconds int
}