Generates/Stores the device request and returns the device and user codes.

Signed-off-by: justin-slowik <justin.slowik@thermofisher.com>
This commit is contained in:
Justin Slowik 2020-01-16 10:55:07 -05:00 committed by justin-slowik
parent 11fc8568cb
commit 6d343e059b
14 changed files with 690 additions and 8 deletions

View File

@ -0,0 +1,12 @@
apiVersion: apiextensions.k8s.io/v1beta1
kind: CustomResourceDefinition
metadata:
name: devicerequests.dex.coreos.com
spec:
group: dex.coreos.com
names:
kind: DeviceRequest
listKind: DeviceRequestList
plural: devicerequests
singular: devicerequest
version: v1

View File

@ -0,0 +1,12 @@
apiVersion: apiextensions.k8s.io/v1beta1
kind: CustomResourceDefinition
metadata:
name: devicetokens.dex.coreos.com
spec:
group: dex.coreos.com
names:
kind: DeviceToken
listKind: DeviceTokenList
plural: devicetokens
singular: devicetoken
version: v1

View File

@ -5,6 +5,7 @@ import (
"encoding/json"
"errors"
"fmt"
"io"
"net/http"
"net/url"
"path"
@ -15,12 +16,11 @@ import (
"time"
oidc "github.com/coreos/go-oidc"
"github.com/gorilla/mux"
jose "gopkg.in/square/go-jose.v2"
"github.com/dexidp/dex/connector"
"github.com/dexidp/dex/server/internal"
"github.com/dexidp/dex/storage"
"github.com/gorilla/mux"
jose "gopkg.in/square/go-jose.v2"
)
// newHealthChecker returns the healthz handler. The handler runs until the
@ -1415,3 +1415,112 @@ func usernamePrompt(conn connector.PasswordConnector) string {
}
return "Username"
}
type deviceCodeResponse struct {
//The unique device code for device authentication
DeviceCode string `json:"device_code"`
//The code the user will exchange via a browser and log in
UserCode string `json:"user_code"`
//The url to verify the user code.
VerificationURI string `json:"verification_uri"`
//The lifetime of the device code
ExpireTime int `json:"expires_in"`
//How often the device is allowed to poll to verify that the user login occurred
PollInterval int `json:"interval"`
}
func (s *Server) handleDeviceCode(w http.ResponseWriter, r *http.Request) {
//TODO replace with configurable values
expireIntervalSeconds := 300
requestsPerMinute := 5
switch r.Method {
case http.MethodPost:
err := r.ParseForm()
if err != nil {
message := "Could not parse Device Request body"
s.logger.Errorf("%s : %v", message, err)
respondWithError(w, message, err)
return
}
//Get the client id and scopes from the post
clientID := r.Form.Get("client_id")
scopes := r.Form["scope"]
s.logger.Infof("Received device request for client %v with scopes %v", clientID, scopes)
//Make device code
deviceCode := storage.NewDeviceCode()
//make user code
userCode := storage.NewUserCode()
//make a pkce verification code
pkceCode := storage.NewID()
//Generate the expire time
expireTime := time.Now().Add(time.Second * time.Duration(expireIntervalSeconds))
//Store the Device Request
deviceReq := storage.DeviceRequest{
UserCode: userCode,
DeviceCode: deviceCode,
ClientID: clientID,
Scopes: scopes,
PkceVerifier: pkceCode,
Expiry: expireTime,
}
if err := s.storage.CreateDeviceRequest(deviceReq); err != nil {
message := fmt.Sprintf("Failed to store device request %v", err)
s.logger.Errorf(message)
respondWithError(w, message, err)
return
}
//Store the device token
deviceToken := storage.DeviceToken{
DeviceCode: deviceCode,
Status: "pending",
Token: "",
Expiry: expireTime,
}
if err := s.storage.CreateDeviceToken(deviceToken); err != nil {
message := fmt.Sprintf("Failed to store device token %v", err)
s.logger.Errorf(message)
respondWithError(w, message, err)
return
}
code := deviceCodeResponse{
DeviceCode: deviceCode,
UserCode: userCode,
VerificationURI: path.Join(s.issuerURL.String(), "/device"),
ExpireTime: expireIntervalSeconds,
PollInterval: requestsPerMinute,
}
enc := json.NewEncoder(w)
enc.SetIndent("", " ")
enc.Encode(code)
default:
s.renderError(r, w, http.StatusBadRequest, "Requested resource does not exist.")
}
}
func respondWithError(w io.Writer, errorMessage string, err error) {
resp := struct {
Error string `json:"error"`
ErrorMessage string `json:"message"`
}{
Error: err.Error(),
ErrorMessage: errorMessage,
}
enc := json.NewEncoder(w)
enc.SetIndent("", " ")
enc.Encode(resp)
}

View File

@ -302,6 +302,7 @@ func newServer(ctx context.Context, c Config, rotationStrategy rotationStrategy)
handleWithCORS("/userinfo", s.handleUserInfo)
handleFunc("/auth", s.handleAuthorization)
handleFunc("/auth/{connector}", s.handleConnectorLogin)
handleFunc("/device/code", s.handleDeviceCode)
r.HandleFunc(path.Join(issuerURL.Path, "/callback"), func(w http.ResponseWriter, r *http.Request) {
// Strip the X-Remote-* headers to prevent security issues on
// misconfigured authproxy connector setups.
@ -450,7 +451,8 @@ func (s *Server) startGarbageCollection(ctx context.Context, frequency time.Dura
if r, err := s.storage.GarbageCollect(now()); err != nil {
s.logger.Errorf("garbage collection failed: %v", err)
} else if r.AuthRequests > 0 || r.AuthCodes > 0 {
s.logger.Infof("garbage collection run, delete auth requests=%d, auth codes=%d", r.AuthRequests, r.AuthCodes)
s.logger.Infof("garbage collection run, delete auth requests=%d, auth codes=%d, device requests =%d, device tokens=%d",
r.AuthRequests, r.AuthCodes, r.DeviceRequests, r.DeviceTokens)
}
}
}

View File

@ -49,6 +49,8 @@ func RunTests(t *testing.T, newStorage func() storage.Storage) {
{"ConnectorCRUD", testConnectorCRUD},
{"GarbageCollection", testGC},
{"TimezoneSupport", testTimezones},
{"DeviceRequestCRUD", testDeviceRequestCRUD},
{"DeviceTokenCRUD", testDeviceTokenCRUD},
})
}
@ -834,6 +836,82 @@ func testGC(t *testing.T, s storage.Storage) {
} else if err != storage.ErrNotFound {
t.Errorf("expected storage.ErrNotFound, got %v", err)
}
d := storage.DeviceRequest{
UserCode: storage.NewUserCode(),
DeviceCode: storage.NewID(),
ClientID: "client1",
Scopes: []string{"openid", "email"},
PkceVerifier: storage.NewID(),
Expiry: expiry,
}
if err := s.CreateDeviceRequest(d); err != nil {
t.Fatalf("failed creating device request: %v", err)
}
for _, tz := range []*time.Location{time.UTC, est, pst} {
result, err := s.GarbageCollect(expiry.Add(-time.Hour).In(tz))
if err != nil {
t.Errorf("garbage collection failed: %v", err)
} else {
if result.DeviceRequests != 0 {
t.Errorf("expected no device garbage collection results, got %#v", result)
}
}
//if _, err := s.GetDeviceRequest(d.UserCode); err != nil {
// t.Errorf("expected to be able to get auth request after GC: %v", err)
//}
}
if r, err := s.GarbageCollect(expiry.Add(time.Hour)); err != nil {
t.Errorf("garbage collection failed: %v", err)
} else if r.DeviceRequests != 1 {
t.Errorf("expected to garbage collect 1 device request, got %d", r.DeviceRequests)
}
//TODO add this code back once Getters are written for device requests
//if _, err := s.GetDeviceRequest(d.UserCode); err == nil {
// t.Errorf("expected device request to be GC'd")
//} else if err != storage.ErrNotFound {
// t.Errorf("expected storage.ErrNotFound, got %v", err)
//}
dt := storage.DeviceToken{
DeviceCode: storage.NewID(),
Status: "pending",
Token: "foo",
Expiry: expiry,
}
if err := s.CreateDeviceToken(dt); err != nil {
t.Fatalf("failed creating device token: %v", err)
}
for _, tz := range []*time.Location{time.UTC, est, pst} {
result, err := s.GarbageCollect(expiry.Add(-time.Hour).In(tz))
if err != nil {
t.Errorf("garbage collection failed: %v", err)
} else {
if result.DeviceTokens != 0 {
t.Errorf("expected no device token garbage collection results, got %#v", result)
}
}
//if _, err := s.GetDeviceRequest(d.UserCode); err != nil {
// t.Errorf("expected to be able to get auth request after GC: %v", err)
//}
}
if r, err := s.GarbageCollect(expiry.Add(time.Hour)); err != nil {
t.Errorf("garbage collection failed: %v", err)
} else if r.DeviceTokens != 1 {
t.Errorf("expected to garbage collect 1 device token, got %d", r.DeviceTokens)
}
//TODO add this code back once Getters are written for device tokens
//if _, err := s.GetDeviceRequest(d.UserCode); err == nil {
// t.Errorf("expected device request to be GC'd")
//} else if err != storage.ErrNotFound {
// t.Errorf("expected storage.ErrNotFound, got %v", err)
//}
}
// testTimezones tests that backends either fully support timezones or
@ -881,3 +959,44 @@ func testTimezones(t *testing.T, s storage.Storage) {
t.Fatalf("expected expiry %v got %v", wantTime, gotTime)
}
}
func testDeviceRequestCRUD(t *testing.T, s storage.Storage) {
d1 := storage.DeviceRequest{
UserCode: storage.NewUserCode(),
DeviceCode: storage.NewID(),
ClientID: "client1",
Scopes: []string{"openid", "email"},
PkceVerifier: storage.NewID(),
Expiry: neverExpire,
}
if err := s.CreateDeviceRequest(d1); err != nil {
t.Fatalf("failed creating device request: %v", err)
}
// Attempt to create same DeviceRequest twice.
err := s.CreateDeviceRequest(d1)
mustBeErrAlreadyExists(t, "device request", err)
//No manual deletes for device requests, will be handled by garbage collection routines
//see testGC
}
func testDeviceTokenCRUD(t *testing.T, s storage.Storage) {
d1 := storage.DeviceToken{
DeviceCode: storage.NewID(),
Status: "pending",
Token: storage.NewID(),
Expiry: neverExpire,
}
if err := s.CreateDeviceToken(d1); err != nil {
t.Fatalf("failed creating device token: %v", err)
}
// Attempt to create same DeviceRequest twice.
err := s.CreateDeviceToken(d1)
mustBeErrAlreadyExists(t, "device token", err)
//TODO Add update / delete tests as functionality is put into main code
}

View File

@ -22,6 +22,8 @@ const (
offlineSessionPrefix = "offline_session/"
connectorPrefix = "connector/"
keysName = "openid-connect-keys"
deviceRequestPrefix = "device_req/"
deviceTokenPrefix = "device_token/"
// defaultStorageTimeout will be applied to all storage's operations.
defaultStorageTimeout = 5 * time.Second
@ -72,6 +74,36 @@ func (c *conn) GarbageCollect(now time.Time) (result storage.GCResult, err error
result.AuthCodes++
}
}
deviceRequests, err := c.listDeviceRequests(ctx)
if err != nil {
return result, err
}
for _, deviceRequest := range deviceRequests {
if now.After(deviceRequest.Expiry) {
if err := c.deleteKey(ctx, keyID(deviceRequestPrefix, deviceRequest.UserCode)); err != nil {
c.logger.Errorf("failed to delete device request %v", err)
delErr = fmt.Errorf("failed to delete device request: %v", err)
}
result.DeviceRequests++
}
}
deviceTokens, err := c.listDeviceTokens(ctx)
if err != nil {
return result, err
}
for _, deviceToken := range deviceTokens {
if now.After(deviceToken.Expiry) {
if err := c.deleteKey(ctx, keyID(deviceTokenPrefix, deviceToken.DeviceCode)); err != nil {
c.logger.Errorf("failed to delete device token %v", err)
delErr = fmt.Errorf("failed to delete device token: %v", err)
}
result.DeviceTokens++
}
}
return result, delErr
}
@ -531,3 +563,45 @@ func keyEmail(prefix, email string) string { return prefix + strings.ToLower(ema
func keySession(prefix, userID, connID string) string {
return prefix + strings.ToLower(userID+"|"+connID)
}
func (c *conn) CreateDeviceRequest(d storage.DeviceRequest) error {
ctx, cancel := context.WithTimeout(context.Background(), defaultStorageTimeout)
defer cancel()
return c.txnCreate(ctx, keyID(deviceRequestPrefix, d.UserCode), fromStorageDeviceRequest(d))
}
func (c *conn) listDeviceRequests(ctx context.Context) (requests []DeviceRequest, err error) {
res, err := c.db.Get(ctx, deviceRequestPrefix, clientv3.WithPrefix())
if err != nil {
return requests, err
}
for _, v := range res.Kvs {
var r DeviceRequest
if err = json.Unmarshal(v.Value, &r); err != nil {
return requests, err
}
requests = append(requests, r)
}
return requests, nil
}
func (c *conn) CreateDeviceToken(t storage.DeviceToken) error {
ctx, cancel := context.WithTimeout(context.Background(), defaultStorageTimeout)
defer cancel()
return c.txnCreate(ctx, keyID(deviceRequestPrefix, t.DeviceCode), fromStorageDeviceToken(t))
}
func (c *conn) listDeviceTokens(ctx context.Context) (deviceTokens []DeviceToken, err error) {
res, err := c.db.Get(ctx, deviceTokenPrefix, clientv3.WithPrefix())
if err != nil {
return deviceTokens, err
}
for _, v := range res.Kvs {
var dt DeviceToken
if err = json.Unmarshal(v.Value, &dt); err != nil {
return deviceTokens, err
}
deviceTokens = append(deviceTokens, dt)
}
return deviceTokens, nil
}

View File

@ -216,3 +216,41 @@ func toStorageOfflineSessions(o OfflineSessions) storage.OfflineSessions {
}
return s
}
// DeviceRequest is a mirrored struct from storage with JSON struct tags
type DeviceRequest struct {
UserCode string `json:"user_code"`
DeviceCode string `json:"device_code"`
ClientID string `json:"client_id"`
Scopes []string `json:"scopes"`
PkceVerifier string `json:"pkce_verifier"`
Expiry time.Time `json:"expiry"`
}
func fromStorageDeviceRequest(d storage.DeviceRequest) DeviceRequest {
return DeviceRequest{
UserCode: d.UserCode,
DeviceCode: d.DeviceCode,
ClientID: d.ClientID,
Scopes: d.Scopes,
PkceVerifier: d.PkceVerifier,
Expiry: d.Expiry,
}
}
// DeviceToken is a mirrored struct from storage with JSON struct tags
type DeviceToken struct {
DeviceCode string `json:"device_code"`
Status string `json:"status"`
Token string `json:"token"`
Expiry time.Time `json:"expiry"`
}
func fromStorageDeviceToken(t storage.DeviceToken) DeviceToken {
return DeviceToken{
DeviceCode: t.DeviceCode,
Status: t.Status,
Token: t.Token,
Expiry: t.Expiry,
}
}

View File

@ -21,6 +21,8 @@ const (
kindPassword = "Password"
kindOfflineSessions = "OfflineSessions"
kindConnector = "Connector"
kindDeviceRequest = "DeviceRequest"
kindDeviceToken = "DeviceToken"
)
const (
@ -32,6 +34,8 @@ const (
resourcePassword = "passwords"
resourceOfflineSessions = "offlinesessionses" // Again attempts to pluralize.
resourceConnector = "connectors"
resourceDeviceRequest = "devicerequests"
resourceDeviceToken = "devicetokens"
)
// Config values for the Kubernetes storage type.
@ -593,5 +597,47 @@ func (cli *client) GarbageCollect(now time.Time) (result storage.GCResult, err e
result.AuthCodes++
}
}
var deviceRequests DeviceRequestList
if err := cli.list(resourceDeviceRequest, &deviceRequests); err != nil {
return result, fmt.Errorf("failed to list device requests: %v", err)
}
for _, deviceRequest := range deviceRequests.DeviceRequests {
if now.After(deviceRequest.Expiry) {
if err := cli.delete(resourceDeviceRequest, deviceRequest.ObjectMeta.Name); err != nil {
cli.logger.Errorf("failed to delete device request: %v", err)
delErr = fmt.Errorf("failed to delete device request: %v", err)
}
result.DeviceRequests++
}
}
var deviceTokens DeviceTokenList
if err := cli.list(resourceDeviceToken, &deviceTokens); err != nil {
return result, fmt.Errorf("failed to list device tokens: %v", err)
}
for _, deviceToken := range deviceTokens.DeviceTokens {
if now.After(deviceToken.Expiry) {
if err := cli.delete(resourceDeviceToken, deviceToken.ObjectMeta.Name); err != nil {
cli.logger.Errorf("failed to delete device token: %v", err)
delErr = fmt.Errorf("failed to delete device token: %v", err)
}
result.DeviceTokens++
}
}
if delErr != nil {
return result, delErr
}
return result, delErr
}
func (cli *client) CreateDeviceRequest(d storage.DeviceRequest) error {
return cli.post(resourceDeviceRequest, cli.fromStorageDeviceRequest(d))
}
func (cli *client) CreateDeviceToken(t storage.DeviceToken) error {
return cli.post(resourceDeviceToken, cli.fromStorageDeviceToken(t))
}

View File

@ -85,6 +85,8 @@ func (s *StorageTestSuite) TestStorage() {
for _, resource := range []string{
resourceAuthCode,
resourceAuthRequest,
resourceDeviceRequest,
resourceDeviceToken,
resourceClient,
resourceRefreshToken,
resourceKeys,

View File

@ -143,6 +143,36 @@ var customResourceDefinitions = []k8sapi.CustomResourceDefinition{
},
},
},
{
ObjectMeta: k8sapi.ObjectMeta{
Name: "devicerequests.dex.coreos.com",
},
TypeMeta: crdMeta,
Spec: k8sapi.CustomResourceDefinitionSpec{
Group: apiGroup,
Version: "v1",
Names: k8sapi.CustomResourceDefinitionNames{
Plural: "devicerequests",
Singular: "devicerequest",
Kind: "DeviceRequest",
},
},
},
{
ObjectMeta: k8sapi.ObjectMeta{
Name: "devicetokens.dex.coreos.com",
},
TypeMeta: crdMeta,
Spec: k8sapi.CustomResourceDefinitionSpec{
Group: apiGroup,
Version: "v1",
Names: k8sapi.CustomResourceDefinitionNames{
Plural: "devicetokens",
Singular: "devicetoken",
Kind: "DeviceToken",
},
},
},
}
// There will only ever be a single keys resource. Maintain this by setting a
@ -635,3 +665,77 @@ type ConnectorList struct {
k8sapi.ListMeta `json:"metadata,omitempty"`
Connectors []Connector `json:"items"`
}
// DeviceRequest is a mirrored struct from storage with JSON struct tags and
// Kubernetes type metadata.
type DeviceRequest struct {
k8sapi.TypeMeta `json:",inline"`
k8sapi.ObjectMeta `json:"metadata,omitempty"`
DeviceCode string `json:"device_code,omitempty"`
CLientID string `json:"client_id,omitempty"`
Scopes []string `json:"scopes,omitempty"`
PkceVerifier string `json:"pkce_verifier,omitempty"`
Expiry time.Time `json:"expiry"`
}
// AuthRequestList is a list of AuthRequests.
type DeviceRequestList struct {
k8sapi.TypeMeta `json:",inline"`
k8sapi.ListMeta `json:"metadata,omitempty"`
DeviceRequests []DeviceRequest `json:"items"`
}
func (cli *client) fromStorageDeviceRequest(a storage.DeviceRequest) DeviceRequest {
req := DeviceRequest{
TypeMeta: k8sapi.TypeMeta{
Kind: kindDeviceRequest,
APIVersion: cli.apiVersion,
},
ObjectMeta: k8sapi.ObjectMeta{
Name: strings.ToLower(a.UserCode),
Namespace: cli.namespace,
},
DeviceCode: a.DeviceCode,
CLientID: a.ClientID,
Scopes: a.Scopes,
PkceVerifier: a.PkceVerifier,
Expiry: a.Expiry,
}
return req
}
// DeviceToken is a mirrored struct from storage with JSON struct tags and
// Kubernetes type metadata.
type DeviceToken struct {
k8sapi.TypeMeta `json:",inline"`
k8sapi.ObjectMeta `json:"metadata,omitempty"`
Status string `json:"status,omitempty"`
Token string `json:"token,omitempty"`
Expiry time.Time `json:"expiry"`
}
// DeviceTokenList is a list of DeviceTokens.
type DeviceTokenList struct {
k8sapi.TypeMeta `json:",inline"`
k8sapi.ListMeta `json:"metadata,omitempty"`
DeviceTokens []DeviceToken `json:"items"`
}
func (cli *client) fromStorageDeviceToken(t storage.DeviceToken) DeviceToken {
req := DeviceToken{
TypeMeta: k8sapi.TypeMeta{
Kind: kindDeviceToken,
APIVersion: cli.apiVersion,
},
ObjectMeta: k8sapi.ObjectMeta{
Name: t.DeviceCode,
Namespace: cli.namespace,
},
Status: t.Status,
Token: t.Token,
Expiry: t.Expiry,
}
return req
}

View File

@ -20,6 +20,8 @@ func New(logger log.Logger) storage.Storage {
passwords: make(map[string]storage.Password),
offlineSessions: make(map[offlineSessionID]storage.OfflineSessions),
connectors: make(map[string]storage.Connector),
deviceRequests: make(map[string]storage.DeviceRequest),
deviceTokens: make(map[string]storage.DeviceToken),
logger: logger,
}
}
@ -46,6 +48,8 @@ type memStorage struct {
passwords map[string]storage.Password
offlineSessions map[offlineSessionID]storage.OfflineSessions
connectors map[string]storage.Connector
deviceRequests map[string]storage.DeviceRequest
deviceTokens map[string]storage.DeviceToken
keys storage.Keys
@ -79,6 +83,18 @@ func (s *memStorage) GarbageCollect(now time.Time) (result storage.GCResult, err
result.AuthRequests++
}
}
for id, a := range s.deviceRequests {
if now.After(a.Expiry) {
delete(s.deviceRequests, id)
result.DeviceRequests++
}
}
for id, a := range s.deviceTokens {
if now.After(a.Expiry) {
delete(s.deviceTokens, id)
result.DeviceTokens++
}
}
})
return result, nil
}
@ -465,3 +481,25 @@ func (s *memStorage) UpdateConnector(id string, updater func(c storage.Connector
})
return
}
func (s *memStorage) CreateDeviceRequest(d storage.DeviceRequest) (err error) {
s.tx(func() {
if _, ok := s.deviceRequests[d.UserCode]; ok {
err = storage.ErrAlreadyExists
} else {
s.deviceRequests[d.UserCode] = d
}
})
return
}
func (s *memStorage) CreateDeviceToken(t storage.DeviceToken) (err error) {
s.tx(func() {
if _, ok := s.deviceTokens[t.DeviceCode]; ok {
err = storage.ErrAlreadyExists
} else {
s.deviceTokens[t.DeviceCode] = t
}
})
return
}

View File

@ -100,6 +100,23 @@ func (c *conn) GarbageCollect(now time.Time) (result storage.GCResult, err error
if n, err := r.RowsAffected(); err == nil {
result.AuthCodes = n
}
r, err = c.Exec(`delete from device_request where expiry < $1`, now)
if err != nil {
return result, fmt.Errorf("gc device_request: %v", err)
}
if n, err := r.RowsAffected(); err == nil {
result.DeviceRequests = n
}
r, err = c.Exec(`delete from device_token where expiry < $1`, now)
if err != nil {
return result, fmt.Errorf("gc device_token: %v", err)
}
if n, err := r.RowsAffected(); err == nil {
result.DeviceTokens = n
}
return
}
@ -867,3 +884,41 @@ func (c *conn) delete(table, field, id string) error {
}
return nil
}
func (c *conn) CreateDeviceRequest(d storage.DeviceRequest) error {
_, err := c.Exec(`
insert into device_request (
user_code, device_code, client_id, scopes, pkce_verifier, expiry
)
values (
$1, $2, $3, $4, $5, $6
);`,
d.UserCode, d.DeviceCode, d.ClientID, encoder(d.Scopes), d.PkceVerifier, d.Expiry,
)
if err != nil {
if c.alreadyExistsCheck(err) {
return storage.ErrAlreadyExists
}
return fmt.Errorf("insert device request: %v", err)
}
return nil
}
func (c *conn) CreateDeviceToken(t storage.DeviceToken) error {
_, err := c.Exec(`
insert into device_token (
device_code, status, token, expiry
)
values (
$1, $2, $3, $4
);`,
t.DeviceCode, t.Status, t.Token, t.Expiry,
)
if err != nil {
if c.alreadyExistsCheck(err) {
return storage.ErrAlreadyExists
}
return fmt.Errorf("insert device token: %v", err)
}
return nil
}

View File

@ -229,4 +229,23 @@ var migrations = []migration{
},
flavor: &flavorMySQL,
},
{
stmts: []string{`
create table device_request (
user_code text not null primary key,
device_code text not null,
client_id text not null,
scopes bytea not null, -- JSON array of strings
pkce_verifier text not null,
expiry timestamptz not null
);`,
`
create table device_token (
device_code text not null primary key,
status text not null,
token text,
expiry timestamptz not null
);`,
},
},
}

View File

@ -5,6 +5,7 @@ import (
"encoding/base32"
"errors"
"io"
mrand "math/rand"
"strings"
"time"
@ -24,9 +25,18 @@ var (
// TODO(ericchiang): refactor ID creation onto the storage.
var encoding = base32.NewEncoding("abcdefghijklmnopqrstuvwxyz234567")
// NewDeviceCode returns a 32 char alphanumeric cryptographically secure string
func NewDeviceCode() string {
return newSecureID(32)
}
// NewID returns a random string which can be used as an ID for objects.
func NewID() string {
buff := make([]byte, 16) // 128 bit random ID.
return newSecureID(16)
}
func newSecureID(len int) string {
buff := make([]byte, len) // 128 bit random ID.
if _, err := io.ReadFull(rand.Reader, buff); err != nil {
panic(err)
}
@ -36,8 +46,10 @@ func NewID() string {
// GCResult returns the number of objects deleted by garbage collection.
type GCResult struct {
AuthRequests int64
AuthCodes int64
AuthRequests int64
AuthCodes int64
DeviceRequests int64
DeviceTokens int64
}
// Storage is the storage interface used by the server. Implementations are
@ -54,6 +66,8 @@ type Storage interface {
CreatePassword(p Password) error
CreateOfflineSessions(s OfflineSessions) error
CreateConnector(c Connector) error
CreateDeviceRequest(d DeviceRequest) error
CreateDeviceToken(d DeviceToken) error
// TODO(ericchiang): return (T, bool, error) so we can indicate not found
// requests that way instead of using ErrNotFound.
@ -102,7 +116,7 @@ type Storage interface {
UpdateOfflineSessions(userID string, connID string, updater func(s OfflineSessions) (OfflineSessions, error)) error
UpdateConnector(id string, updater func(c Connector) (Connector, error)) error
// GarbageCollect deletes all expired AuthCodes and AuthRequests.
// GarbageCollect deletes all expired AuthCodes,AuthRequests, DeviceRequests, and DeviceTokens.
GarbageCollect(now time.Time) (GCResult, error)
}
@ -342,3 +356,41 @@ type Keys struct {
// For caching purposes, implementations MUST NOT update keys before this time.
NextRotation time.Time
}
func NewUserCode() string {
mrand.Seed(time.Now().UnixNano())
return randomString(4) + "-" + randomString(4)
}
func randomString(n int) string {
var letter = []rune("ABCDEFGHIJKLMNOPQRSTUVWXYZ")
b := make([]rune, n)
for i := range b {
b[i] = letter[mrand.Intn(len(letter))]
}
return string(b)
}
//DeviceRequest represents an OIDC device authorization request. It holds the state of a device request until the user
//authenticates using their user code or the expiry time passes.
type DeviceRequest struct {
//The code the user will enter in a browser
UserCode string
//The unique device code for device authentication
DeviceCode string
//The client ID the code is for
ClientID string
//The scopes the device requests
Scopes []string
//PKCE Verification
PkceVerifier string
//The expire time
Expiry time.Time
}
type DeviceToken struct {
DeviceCode string
Status string
Token string
Expiry time.Time
}