Merge pull request #1706 from justin-slowik/device_flow

Implementing the OAuth2 Device Authorization Grant
This commit is contained in:
Joel Speed 2020-08-28 11:35:46 +01:00 committed by GitHub
commit 336c73c0a2
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
26 changed files with 2529 additions and 322 deletions

View File

@ -279,6 +279,9 @@ type Expiry struct {
// AuthRequests defines the duration of time for which the AuthRequests will be valid. // AuthRequests defines the duration of time for which the AuthRequests will be valid.
AuthRequests string `json:"authRequests"` AuthRequests string `json:"authRequests"`
// DeviceRequests defines the duration of time for which the DeviceRequests will be valid.
DeviceRequests string `json:"deviceRequests"`
} }
// Logger holds configuration required to customize logging for dex. // Logger holds configuration required to customize logging for dex.

View File

@ -119,6 +119,7 @@ expiry:
signingKeys: "7h" signingKeys: "7h"
idTokens: "25h" idTokens: "25h"
authRequests: "25h" authRequests: "25h"
deviceRequests: "10m"
logger: logger:
level: "debug" level: "debug"
@ -200,6 +201,7 @@ logger:
SigningKeys: "7h", SigningKeys: "7h",
IDTokens: "25h", IDTokens: "25h",
AuthRequests: "25h", AuthRequests: "25h",
DeviceRequests: "10m",
}, },
Logger: Logger{ Logger: Logger{
Level: "debug", Level: "debug",

View File

@ -269,7 +269,14 @@ func serve(cmd *cobra.Command, args []string) error {
logger.Infof("config auth requests valid for: %v", authRequests) logger.Infof("config auth requests valid for: %v", authRequests)
serverConfig.AuthRequestsValidFor = authRequests serverConfig.AuthRequestsValidFor = authRequests
} }
if c.Expiry.DeviceRequests != "" {
deviceRequests, err := time.ParseDuration(c.Expiry.DeviceRequests)
if err != nil {
return fmt.Errorf("invalid config value %q for device request expiry: %v", c.Expiry.AuthRequests, err)
}
logger.Infof("config device requests valid for: %v", deviceRequests)
serverConfig.DeviceRequestsValidFor = deviceRequests
}
serv, err := server.NewServer(context.Background(), serverConfig) serv, err := server.NewServer(context.Background(), serverConfig)
if err != nil { if err != nil {
return fmt.Errorf("failed to initialize server: %v", err) return fmt.Errorf("failed to initialize server: %v", err)

View File

@ -64,6 +64,7 @@ telemetry:
# Uncomment this block to enable configuration for the expiration time durations. # Uncomment this block to enable configuration for the expiration time durations.
# expiry: # expiry:
# deviceRequests: "5m"
# signingKeys: "6h" # signingKeys: "6h"
# idTokens: "24h" # idTokens: "24h"
@ -95,7 +96,11 @@ staticClients:
- 'http://127.0.0.1:5555/callback' - 'http://127.0.0.1:5555/callback'
name: 'Example App' name: 'Example App'
secret: ZXhhbXBsZS1hcHAtc2VjcmV0 secret: ZXhhbXBsZS1hcHAtc2VjcmV0
# - id: example-device-client
# redirectURIs:
# - /device/callback
# name: 'Static Client for Device Flow'
# public: true
connectors: connectors:
- type: mockCallback - type: mockCallback
id: mock id: mock

View File

@ -0,0 +1,12 @@
apiVersion: apiextensions.k8s.io/v1beta1
kind: CustomResourceDefinition
metadata:
name: devicerequests.dex.coreos.com
spec:
group: dex.coreos.com
names:
kind: DeviceRequest
listKind: DeviceRequestList
plural: devicerequests
singular: devicerequest
version: v1

View File

@ -0,0 +1,12 @@
apiVersion: apiextensions.k8s.io/v1beta1
kind: CustomResourceDefinition
metadata:
name: devicetokens.dex.coreos.com
spec:
group: dex.coreos.com
names:
kind: DeviceToken
listKind: DeviceTokenList
plural: devicetokens
singular: devicetoken
version: v1

View File

@ -0,0 +1,390 @@
package server
import (
"encoding/json"
"errors"
"fmt"
"net/http"
"net/url"
"path"
"strconv"
"strings"
"time"
"github.com/dexidp/dex/storage"
)
type deviceCodeResponse struct {
//The unique device code for device authentication
DeviceCode string `json:"device_code"`
//The code the user will exchange via a browser and log in
UserCode string `json:"user_code"`
//The url to verify the user code.
VerificationURI string `json:"verification_uri"`
//The verification uri with the user code appended for pre-filling form
VerificationURIComplete string `json:"verification_uri_complete"`
//The lifetime of the device code
ExpireTime int `json:"expires_in"`
//How often the device is allowed to poll to verify that the user login occurred
PollInterval int `json:"interval"`
}
func (s *Server) getDeviceVerificationURI() string {
return path.Join(s.issuerURL.Path, "/device/auth/verify_code")
}
func (s *Server) handleDeviceExchange(w http.ResponseWriter, r *http.Request) {
switch r.Method {
case http.MethodGet:
// Grab the parameter(s) from the query.
// If "user_code" is set, pre-populate the user code text field.
// If "invalid" is set, set the invalidAttempt boolean, which will display a message to the user that they
// attempted to redeem an invalid or expired user code.
userCode := r.URL.Query().Get("user_code")
invalidAttempt, err := strconv.ParseBool(r.URL.Query().Get("invalid"))
if err != nil {
invalidAttempt = false
}
if err := s.templates.device(r, w, s.getDeviceVerificationURI(), userCode, invalidAttempt); err != nil {
s.logger.Errorf("Server template error: %v", err)
s.renderError(r, w, http.StatusNotFound, "Page not found")
}
default:
s.renderError(r, w, http.StatusBadRequest, "Requested resource does not exist.")
}
}
func (s *Server) handleDeviceCode(w http.ResponseWriter, r *http.Request) {
pollIntervalSeconds := 5
switch r.Method {
case http.MethodPost:
err := r.ParseForm()
if err != nil {
s.logger.Errorf("Could not parse Device Request body: %v", err)
s.tokenErrHelper(w, errInvalidRequest, "", http.StatusNotFound)
return
}
//Get the client id and scopes from the post
clientID := r.Form.Get("client_id")
clientSecret := r.Form.Get("client_secret")
scopes := strings.Fields(r.Form.Get("scope"))
s.logger.Infof("Received device request for client %v with scopes %v", clientID, scopes)
//Make device code
deviceCode := storage.NewDeviceCode()
//make user code
userCode, err := storage.NewUserCode()
if err != nil {
s.logger.Errorf("Error generating user code: %v", err)
s.tokenErrHelper(w, errInvalidRequest, "", http.StatusInternalServerError)
}
//Generate the expire time
expireTime := time.Now().Add(s.deviceRequestsValidFor)
//Store the Device Request
deviceReq := storage.DeviceRequest{
UserCode: userCode,
DeviceCode: deviceCode,
ClientID: clientID,
ClientSecret: clientSecret,
Scopes: scopes,
Expiry: expireTime,
}
if err := s.storage.CreateDeviceRequest(deviceReq); err != nil {
s.logger.Errorf("Failed to store device request; %v", err)
s.tokenErrHelper(w, errInvalidRequest, "", http.StatusInternalServerError)
return
}
//Store the device token
deviceToken := storage.DeviceToken{
DeviceCode: deviceCode,
Status: deviceTokenPending,
Expiry: expireTime,
LastRequestTime: s.now(),
PollIntervalSeconds: 0,
}
if err := s.storage.CreateDeviceToken(deviceToken); err != nil {
s.logger.Errorf("Failed to store device token %v", err)
s.tokenErrHelper(w, errInvalidRequest, "", http.StatusInternalServerError)
return
}
u, err := url.Parse(s.issuerURL.String())
if err != nil {
s.logger.Errorf("Could not parse issuer URL %v", err)
s.tokenErrHelper(w, errInvalidRequest, "", http.StatusInternalServerError)
return
}
u.Path = path.Join(u.Path, "device")
vURI := u.String()
q := u.Query()
q.Set("user_code", userCode)
u.RawQuery = q.Encode()
vURIComplete := u.String()
code := deviceCodeResponse{
DeviceCode: deviceCode,
UserCode: userCode,
VerificationURI: vURI,
VerificationURIComplete: vURIComplete,
ExpireTime: int(s.deviceRequestsValidFor.Seconds()),
PollInterval: pollIntervalSeconds,
}
enc := json.NewEncoder(w)
enc.SetEscapeHTML(false)
enc.SetIndent("", " ")
enc.Encode(code)
default:
s.renderError(r, w, http.StatusBadRequest, "Invalid device code request type")
s.tokenErrHelper(w, errInvalidRequest, "", http.StatusBadRequest)
}
}
func (s *Server) handleDeviceToken(w http.ResponseWriter, r *http.Request) {
w.Header().Set("Content-Type", "application/json")
switch r.Method {
case http.MethodPost:
err := r.ParseForm()
if err != nil {
s.logger.Warnf("Could not parse Device Token Request body: %v", err)
s.tokenErrHelper(w, errInvalidRequest, "", http.StatusBadRequest)
return
}
deviceCode := r.Form.Get("device_code")
if deviceCode == "" {
s.tokenErrHelper(w, errInvalidRequest, "No device code received", http.StatusBadRequest)
return
}
grantType := r.PostFormValue("grant_type")
if grantType != grantTypeDeviceCode {
s.tokenErrHelper(w, errInvalidGrant, "", http.StatusBadRequest)
return
}
now := s.now()
//Grab the device token, check validity
deviceToken, err := s.storage.GetDeviceToken(deviceCode)
if err != nil {
if err != storage.ErrNotFound {
s.logger.Errorf("failed to get device code: %v", err)
}
s.tokenErrHelper(w, errInvalidRequest, "Invalid Device code.", http.StatusBadRequest)
return
} else if now.After(deviceToken.Expiry) {
s.tokenErrHelper(w, deviceTokenExpired, "", http.StatusBadRequest)
return
}
//Rate Limiting check
slowDown := false
pollInterval := deviceToken.PollIntervalSeconds
minRequestTime := deviceToken.LastRequestTime.Add(time.Second * time.Duration(pollInterval))
if now.Before(minRequestTime) {
slowDown = true
//Continually increase the poll interval until the user waits the proper time
pollInterval += 5
} else {
pollInterval = 5
}
switch deviceToken.Status {
case deviceTokenPending:
updater := func(old storage.DeviceToken) (storage.DeviceToken, error) {
old.PollIntervalSeconds = pollInterval
old.LastRequestTime = now
return old, nil
}
// Update device token last request time in storage
if err := s.storage.UpdateDeviceToken(deviceCode, updater); err != nil {
s.logger.Errorf("failed to update device token: %v", err)
s.renderError(r, w, http.StatusInternalServerError, "")
return
}
if slowDown {
s.tokenErrHelper(w, deviceTokenSlowDown, "", http.StatusBadRequest)
} else {
s.tokenErrHelper(w, deviceTokenPending, "", http.StatusUnauthorized)
}
case deviceTokenComplete:
w.Write([]byte(deviceToken.Token))
}
default:
s.renderError(r, w, http.StatusBadRequest, "Requested resource does not exist.")
}
}
func (s *Server) handleDeviceCallback(w http.ResponseWriter, r *http.Request) {
switch r.Method {
case http.MethodGet:
userCode := r.FormValue("state")
code := r.FormValue("code")
if userCode == "" || code == "" {
s.renderError(r, w, http.StatusBadRequest, "Request was missing parameters")
return
}
// Authorization redirect callback from OAuth2 auth flow.
if errMsg := r.FormValue("error"); errMsg != "" {
http.Error(w, errMsg+": "+r.FormValue("error_description"), http.StatusBadRequest)
return
}
authCode, err := s.storage.GetAuthCode(code)
if err != nil || s.now().After(authCode.Expiry) {
errCode := http.StatusBadRequest
if err != nil && err != storage.ErrNotFound {
s.logger.Errorf("failed to get auth code: %v", err)
errCode = http.StatusInternalServerError
}
s.renderError(r, w, errCode, "Invalid or expired auth code.")
return
}
//Grab the device request from storage
deviceReq, err := s.storage.GetDeviceRequest(userCode)
if err != nil || s.now().After(deviceReq.Expiry) {
errCode := http.StatusBadRequest
if err != nil && err != storage.ErrNotFound {
s.logger.Errorf("failed to get device code: %v", err)
errCode = http.StatusInternalServerError
}
s.renderError(r, w, errCode, "Invalid or expired user code.")
return
}
client, err := s.storage.GetClient(deviceReq.ClientID)
if err != nil {
if err != storage.ErrNotFound {
s.logger.Errorf("failed to get client: %v", err)
s.tokenErrHelper(w, errServerError, "", http.StatusInternalServerError)
} else {
s.tokenErrHelper(w, errInvalidClient, "Invalid client credentials.", http.StatusUnauthorized)
}
return
}
if client.Secret != deviceReq.ClientSecret {
s.tokenErrHelper(w, errInvalidClient, "Invalid client credentials.", http.StatusUnauthorized)
return
}
resp, err := s.exchangeAuthCode(w, authCode, client)
if err != nil {
s.logger.Errorf("Could not exchange auth code for client %q: %v", deviceReq.ClientID, err)
s.renderError(r, w, http.StatusInternalServerError, "Failed to exchange auth code.")
return
}
//Grab the device token from storage
old, err := s.storage.GetDeviceToken(deviceReq.DeviceCode)
if err != nil || s.now().After(old.Expiry) {
errCode := http.StatusBadRequest
if err != nil && err != storage.ErrNotFound {
s.logger.Errorf("failed to get device token: %v", err)
errCode = http.StatusInternalServerError
}
s.renderError(r, w, errCode, "Invalid or expired device code.")
return
}
updater := func(old storage.DeviceToken) (storage.DeviceToken, error) {
if old.Status == deviceTokenComplete {
return old, errors.New("device token already complete")
}
respStr, err := json.MarshalIndent(resp, "", " ")
if err != nil {
s.logger.Errorf("failed to marshal device token response: %v", err)
s.renderError(r, w, http.StatusInternalServerError, "")
return old, err
}
old.Token = string(respStr)
old.Status = deviceTokenComplete
return old, nil
}
// Update refresh token in the storage, store the token and mark as complete
if err := s.storage.UpdateDeviceToken(deviceReq.DeviceCode, updater); err != nil {
s.logger.Errorf("failed to update device token: %v", err)
s.renderError(r, w, http.StatusBadRequest, "")
return
}
if err := s.templates.deviceSuccess(r, w, client.Name); err != nil {
s.logger.Errorf("Server template error: %v", err)
s.renderError(r, w, http.StatusNotFound, "Page not found")
}
default:
http.Error(w, fmt.Sprintf("method not implemented: %s", r.Method), http.StatusBadRequest)
return
}
}
func (s *Server) verifyUserCode(w http.ResponseWriter, r *http.Request) {
switch r.Method {
case http.MethodPost:
err := r.ParseForm()
if err != nil {
s.logger.Warnf("Could not parse user code verification request body : %v", err)
s.renderError(r, w, http.StatusBadRequest, "")
return
}
userCode := r.Form.Get("user_code")
if userCode == "" {
s.renderError(r, w, http.StatusBadRequest, "No user code received")
return
}
userCode = strings.ToUpper(userCode)
//Find the user code in the available requests
deviceRequest, err := s.storage.GetDeviceRequest(userCode)
if err != nil || s.now().After(deviceRequest.Expiry) {
if err != nil && err != storage.ErrNotFound {
s.logger.Errorf("failed to get device request: %v", err)
}
if err := s.templates.device(r, w, s.getDeviceVerificationURI(), userCode, true); err != nil {
s.logger.Errorf("Server template error: %v", err)
s.renderError(r, w, http.StatusNotFound, "Page not found")
}
return
}
//Redirect to Dex Auth Endpoint
authURL := path.Join(s.issuerURL.Path, "/auth")
u, err := url.Parse(authURL)
if err != nil {
s.renderError(r, w, http.StatusInternalServerError, "Invalid auth URI.")
return
}
q := u.Query()
q.Set("client_id", deviceRequest.ClientID)
q.Set("client_secret", deviceRequest.ClientSecret)
q.Set("state", deviceRequest.UserCode)
q.Set("response_type", "code")
q.Set("redirect_uri", "/device/callback")
q.Set("scope", strings.Join(deviceRequest.Scopes, " "))
u.RawQuery = q.Encode()
http.Redirect(w, r, u.String(), http.StatusFound)
default:
s.renderError(r, w, http.StatusBadRequest, "Requested resource does not exist.")
}
}

View File

@ -0,0 +1,678 @@
package server
import (
"bytes"
"context"
"encoding/json"
"io/ioutil"
"net/http"
"net/http/httptest"
"net/url"
"path"
"strings"
"testing"
"time"
"github.com/dexidp/dex/storage"
)
func TestDeviceVerificationURI(t *testing.T) {
t0 := time.Now()
now := func() time.Time { return t0 }
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
// Setup a dex server.
httpServer, s := newTestServer(ctx, t, func(c *Config) {
c.Issuer = c.Issuer + "/non-root-path"
c.Now = now
})
defer httpServer.Close()
u, err := url.Parse(s.issuerURL.String())
if err != nil {
t.Fatalf("Could not parse issuer URL %v", err)
}
u.Path = path.Join(u.Path, "/device/auth/verify_code")
uri := s.getDeviceVerificationURI()
if uri != u.Path {
t.Errorf("Invalid verification URI. Expected %v got %v", u.Path, uri)
}
}
func TestHandleDeviceCode(t *testing.T) {
t0 := time.Now()
now := func() time.Time { return t0 }
tests := []struct {
testName string
clientID string
requestType string
scopes []string
expectedResponseCode int
expectedServerResponse string
}{
{
testName: "New Code",
clientID: "test",
requestType: "POST",
scopes: []string{"openid", "profile", "email"},
expectedResponseCode: http.StatusOK,
},
{
testName: "Invalid request Type (GET)",
clientID: "test",
requestType: "GET",
scopes: []string{"openid", "profile", "email"},
expectedResponseCode: http.StatusBadRequest,
},
}
for _, tc := range tests {
t.Run(tc.testName, func(t *testing.T) {
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
// Setup a dex server.
httpServer, s := newTestServer(ctx, t, func(c *Config) {
c.Issuer = c.Issuer + "/non-root-path"
c.Now = now
})
defer httpServer.Close()
u, err := url.Parse(s.issuerURL.String())
if err != nil {
t.Fatalf("Could not parse issuer URL %v", err)
}
u.Path = path.Join(u.Path, "device/code")
data := url.Values{}
data.Set("client_id", tc.clientID)
for _, scope := range tc.scopes {
data.Add("scope", scope)
}
req, _ := http.NewRequest(tc.requestType, u.String(), bytes.NewBufferString(data.Encode()))
req.Header.Set("Content-Type", "application/x-www-form-urlencoded; param=value")
rr := httptest.NewRecorder()
s.ServeHTTP(rr, req)
if rr.Code != tc.expectedResponseCode {
t.Errorf("Unexpected Response Type. Expected %v got %v", tc.expectedResponseCode, rr.Code)
}
body, err := ioutil.ReadAll(rr.Body)
if err != nil {
t.Errorf("Could read token response %v", err)
}
if tc.expectedResponseCode == http.StatusOK {
var resp deviceCodeResponse
if err := json.Unmarshal(body, &resp); err != nil {
t.Errorf("Unexpected Device Code Response Format %v", string(body))
}
}
})
}
}
func TestDeviceCallback(t *testing.T) {
t0 := time.Now()
now := func() time.Time { return t0 }
type formValues struct {
state string
code string
error string
}
// Base "Control" test values
baseFormValues := formValues{
state: "XXXX-XXXX",
code: "somecode",
}
baseAuthCode := storage.AuthCode{
ID: "somecode",
ClientID: "testclient",
RedirectURI: deviceCallbackURI,
Nonce: "",
Scopes: []string{"openid", "profile", "email"},
ConnectorID: "mock",
ConnectorData: nil,
Claims: storage.Claims{},
Expiry: now().Add(5 * time.Minute),
}
baseDeviceRequest := storage.DeviceRequest{
UserCode: "XXXX-XXXX",
DeviceCode: "devicecode",
ClientID: "testclient",
ClientSecret: "",
Scopes: []string{"openid", "profile", "email"},
Expiry: now().Add(5 * time.Minute),
}
baseDeviceToken := storage.DeviceToken{
DeviceCode: "devicecode",
Status: deviceTokenPending,
Token: "",
Expiry: now().Add(5 * time.Minute),
LastRequestTime: time.Time{},
PollIntervalSeconds: 0,
}
tests := []struct {
testName string
expectedResponseCode int
values formValues
testAuthCode storage.AuthCode
testDeviceRequest storage.DeviceRequest
testDeviceToken storage.DeviceToken
}{
{
testName: "Missing State",
values: formValues{
state: "",
code: "somecode",
error: "",
},
expectedResponseCode: http.StatusBadRequest,
},
{
testName: "Missing Code",
values: formValues{
state: "XXXX-XXXX",
code: "",
error: "",
},
expectedResponseCode: http.StatusBadRequest,
},
{
testName: "Error During Authorization",
values: formValues{
state: "XXXX-XXXX",
code: "somecode",
error: "Error Condition",
},
expectedResponseCode: http.StatusBadRequest,
},
{
testName: "Expired Auth Code",
values: baseFormValues,
testAuthCode: storage.AuthCode{
ID: "somecode",
ClientID: "testclient",
RedirectURI: deviceCallbackURI,
Nonce: "",
Scopes: []string{"openid", "profile", "email"},
ConnectorID: "pic",
ConnectorData: nil,
Claims: storage.Claims{},
Expiry: now().Add(-5 * time.Minute),
},
expectedResponseCode: http.StatusBadRequest,
},
{
testName: "Invalid Auth Code",
values: baseFormValues,
testAuthCode: storage.AuthCode{
ID: "somecode",
ClientID: "testclient",
RedirectURI: deviceCallbackURI,
Nonce: "",
Scopes: []string{"openid", "profile", "email"},
ConnectorID: "pic",
ConnectorData: nil,
Claims: storage.Claims{},
Expiry: now().Add(5 * time.Minute),
},
expectedResponseCode: http.StatusBadRequest,
},
{
testName: "Expired Device Request",
values: baseFormValues,
testAuthCode: baseAuthCode,
testDeviceRequest: storage.DeviceRequest{
UserCode: "XXXX-XXXX",
DeviceCode: "devicecode",
ClientID: "testclient",
Scopes: []string{"openid", "profile", "email"},
Expiry: now().Add(-5 * time.Minute),
},
expectedResponseCode: http.StatusBadRequest,
},
{
testName: "Non-Existent User Code",
values: baseFormValues,
testAuthCode: baseAuthCode,
testDeviceRequest: storage.DeviceRequest{
UserCode: "ZZZZ-ZZZZ",
DeviceCode: "devicecode",
Scopes: []string{"openid", "profile", "email"},
Expiry: now().Add(5 * time.Minute),
},
expectedResponseCode: http.StatusBadRequest,
},
{
testName: "Bad Device Request Client",
values: baseFormValues,
testAuthCode: baseAuthCode,
testDeviceRequest: storage.DeviceRequest{
UserCode: "XXXX-XXXX",
DeviceCode: "devicecode",
Scopes: []string{"openid", "profile", "email"},
Expiry: now().Add(5 * time.Minute),
},
expectedResponseCode: http.StatusUnauthorized,
},
{
testName: "Bad Device Request Secret",
values: baseFormValues,
testAuthCode: baseAuthCode,
testDeviceRequest: storage.DeviceRequest{
UserCode: "XXXX-XXXX",
DeviceCode: "devicecode",
ClientSecret: "foobar",
Scopes: []string{"openid", "profile", "email"},
Expiry: now().Add(5 * time.Minute),
},
expectedResponseCode: http.StatusUnauthorized,
},
{
testName: "Expired Device Token",
values: baseFormValues,
testAuthCode: baseAuthCode,
testDeviceRequest: baseDeviceRequest,
testDeviceToken: storage.DeviceToken{
DeviceCode: "devicecode",
Status: deviceTokenPending,
Token: "",
Expiry: now().Add(-5 * time.Minute),
LastRequestTime: time.Time{},
PollIntervalSeconds: 0,
},
expectedResponseCode: http.StatusBadRequest,
},
{
testName: "Device Code Already Redeemed",
values: baseFormValues,
testAuthCode: baseAuthCode,
testDeviceRequest: baseDeviceRequest,
testDeviceToken: storage.DeviceToken{
DeviceCode: "devicecode",
Status: deviceTokenComplete,
Token: "",
Expiry: now().Add(5 * time.Minute),
LastRequestTime: time.Time{},
PollIntervalSeconds: 0,
},
expectedResponseCode: http.StatusBadRequest,
},
{
testName: "Successful Exchange",
values: baseFormValues,
testAuthCode: baseAuthCode,
testDeviceRequest: baseDeviceRequest,
testDeviceToken: baseDeviceToken,
expectedResponseCode: http.StatusOK,
},
}
for _, tc := range tests {
t.Run(tc.testName, func(t *testing.T) {
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
// Setup a dex server.
httpServer, s := newTestServer(ctx, t, func(c *Config) {
//c.Issuer = c.Issuer + "/non-root-path"
c.Now = now
})
defer httpServer.Close()
if err := s.storage.CreateAuthCode(tc.testAuthCode); err != nil {
t.Fatalf("failed to create auth code: %v", err)
}
if err := s.storage.CreateDeviceRequest(tc.testDeviceRequest); err != nil {
t.Fatalf("failed to create device request: %v", err)
}
if err := s.storage.CreateDeviceToken(tc.testDeviceToken); err != nil {
t.Fatalf("failed to create device token: %v", err)
}
client := storage.Client{
ID: "testclient",
Secret: "",
RedirectURIs: []string{deviceCallbackURI},
}
if err := s.storage.CreateClient(client); err != nil {
t.Fatalf("failed to create client: %v", err)
}
u, err := url.Parse(s.issuerURL.String())
if err != nil {
t.Fatalf("Could not parse issuer URL %v", err)
}
u.Path = path.Join(u.Path, "device/callback")
q := u.Query()
q.Set("state", tc.values.state)
q.Set("code", tc.values.code)
q.Set("error", tc.values.error)
u.RawQuery = q.Encode()
req, _ := http.NewRequest("GET", u.String(), nil)
req.Header.Set("Content-Type", "application/x-www-form-urlencoded; param=value")
rr := httptest.NewRecorder()
s.ServeHTTP(rr, req)
if rr.Code != tc.expectedResponseCode {
t.Errorf("%s: Unexpected Response Type. Expected %v got %v", tc.testName, tc.expectedResponseCode, rr.Code)
}
})
}
}
func TestDeviceTokenResponse(t *testing.T) {
t0 := time.Now()
now := func() time.Time { return t0 }
baseDeviceRequest := storage.DeviceRequest{
UserCode: "ABCD-WXYZ",
DeviceCode: "foo",
ClientID: "testclient",
Scopes: []string{"openid", "profile", "offline_access"},
Expiry: now().Add(5 * time.Minute),
}
tests := []struct {
testName string
testDeviceRequest storage.DeviceRequest
testDeviceToken storage.DeviceToken
testGrantType string
testDeviceCode string
expectedServerResponse string
expectedResponseCode int
}{
{
testName: "Valid but pending token",
testDeviceRequest: baseDeviceRequest,
testDeviceToken: storage.DeviceToken{
DeviceCode: "f00bar",
Status: deviceTokenPending,
Token: "",
Expiry: now().Add(5 * time.Minute),
LastRequestTime: time.Time{},
PollIntervalSeconds: 0,
},
testDeviceCode: "f00bar",
expectedServerResponse: deviceTokenPending,
expectedResponseCode: http.StatusUnauthorized,
},
{
testName: "Invalid Grant Type",
testDeviceRequest: baseDeviceRequest,
testDeviceToken: storage.DeviceToken{
DeviceCode: "f00bar",
Status: deviceTokenPending,
Token: "",
Expiry: now().Add(5 * time.Minute),
LastRequestTime: time.Time{},
PollIntervalSeconds: 0,
},
testDeviceCode: "f00bar",
testGrantType: grantTypeAuthorizationCode,
expectedServerResponse: errInvalidGrant,
expectedResponseCode: http.StatusBadRequest,
},
{
testName: "Test Slow Down State",
testDeviceRequest: baseDeviceRequest,
testDeviceToken: storage.DeviceToken{
DeviceCode: "f00bar",
Status: deviceTokenPending,
Token: "",
Expiry: now().Add(5 * time.Minute),
LastRequestTime: now(),
PollIntervalSeconds: 10,
},
testDeviceCode: "f00bar",
expectedServerResponse: deviceTokenSlowDown,
expectedResponseCode: http.StatusBadRequest,
},
{
testName: "Test Expired Device Token",
testDeviceRequest: baseDeviceRequest,
testDeviceToken: storage.DeviceToken{
DeviceCode: "f00bar",
Status: deviceTokenPending,
Token: "",
Expiry: now().Add(-5 * time.Minute),
LastRequestTime: time.Time{},
PollIntervalSeconds: 0,
},
testDeviceCode: "f00bar",
expectedServerResponse: deviceTokenExpired,
expectedResponseCode: http.StatusBadRequest,
},
{
testName: "Test Non-existent Device Code",
testDeviceRequest: baseDeviceRequest,
testDeviceToken: storage.DeviceToken{
DeviceCode: "foo",
Status: deviceTokenPending,
Token: "",
Expiry: now().Add(-5 * time.Minute),
LastRequestTime: time.Time{},
PollIntervalSeconds: 0,
},
testDeviceCode: "bar",
expectedServerResponse: errInvalidRequest,
expectedResponseCode: http.StatusBadRequest,
},
{
testName: "Empty Device Code in Request",
testDeviceRequest: baseDeviceRequest,
testDeviceToken: storage.DeviceToken{
DeviceCode: "bar",
Status: deviceTokenPending,
Token: "",
Expiry: now().Add(-5 * time.Minute),
LastRequestTime: time.Time{},
PollIntervalSeconds: 0,
},
testDeviceCode: "",
expectedServerResponse: errInvalidRequest,
expectedResponseCode: http.StatusBadRequest,
},
{
testName: "Claim validated token from Device Code",
testDeviceRequest: baseDeviceRequest,
testDeviceToken: storage.DeviceToken{
DeviceCode: "foo",
Status: deviceTokenComplete,
Token: "{\"access_token\": \"foobar\"}",
Expiry: now().Add(5 * time.Minute),
LastRequestTime: time.Time{},
PollIntervalSeconds: 0,
},
testDeviceCode: "foo",
expectedServerResponse: "{\"access_token\": \"foobar\"}",
expectedResponseCode: http.StatusOK,
},
}
for _, tc := range tests {
t.Run(tc.testName, func(t *testing.T) {
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
// Setup a dex server.
httpServer, s := newTestServer(ctx, t, func(c *Config) {
c.Issuer = c.Issuer + "/non-root-path"
c.Now = now
})
defer httpServer.Close()
if err := s.storage.CreateDeviceRequest(tc.testDeviceRequest); err != nil {
t.Fatalf("Failed to store device token %v", err)
}
if err := s.storage.CreateDeviceToken(tc.testDeviceToken); err != nil {
t.Fatalf("Failed to store device token %v", err)
}
u, err := url.Parse(s.issuerURL.String())
if err != nil {
t.Fatalf("Could not parse issuer URL %v", err)
}
u.Path = path.Join(u.Path, "device/token")
data := url.Values{}
grantType := grantTypeDeviceCode
if tc.testGrantType != "" {
grantType = tc.testGrantType
}
data.Set("grant_type", grantType)
data.Set("device_code", tc.testDeviceCode)
req, _ := http.NewRequest("POST", u.String(), bytes.NewBufferString(data.Encode()))
req.Header.Set("Content-Type", "application/x-www-form-urlencoded; param=value")
rr := httptest.NewRecorder()
s.ServeHTTP(rr, req)
if rr.Code != tc.expectedResponseCode {
t.Errorf("Unexpected Response Type. Expected %v got %v", tc.expectedResponseCode, rr.Code)
}
body, err := ioutil.ReadAll(rr.Body)
if err != nil {
t.Errorf("Could read token response %v", err)
}
if tc.expectedResponseCode == http.StatusBadRequest || tc.expectedResponseCode == http.StatusUnauthorized {
expectJsonErrorResponse(tc.testName, body, tc.expectedServerResponse, t)
} else if string(body) != tc.expectedServerResponse {
t.Errorf("Unexpected Server Response. Expected %v got %v", tc.expectedServerResponse, string(body))
}
})
}
}
func expectJsonErrorResponse(testCase string, body []byte, expectedError string, t *testing.T) {
jsonMap := make(map[string]interface{})
err := json.Unmarshal(body, &jsonMap)
if err != nil {
t.Errorf("Unexpected error unmarshalling response: %v", err)
}
if jsonMap["error"] != expectedError {
t.Errorf("Test Case %s expected error %v, received %v", testCase, expectedError, jsonMap["error"])
}
}
func TestVerifyCodeResponse(t *testing.T) {
t0 := time.Now()
now := func() time.Time { return t0 }
tests := []struct {
testName string
testDeviceRequest storage.DeviceRequest
userCode string
expectedResponseCode int
expectedRedirectPath string
}{
{
testName: "Unknown user code",
testDeviceRequest: storage.DeviceRequest{
UserCode: "ABCD-WXYZ",
DeviceCode: "f00bar",
ClientID: "testclient",
Scopes: []string{"openid", "profile", "offline_access"},
Expiry: now().Add(5 * time.Minute),
},
userCode: "CODE-TEST",
expectedResponseCode: http.StatusBadRequest,
expectedRedirectPath: "",
},
{
testName: "Expired user code",
testDeviceRequest: storage.DeviceRequest{
UserCode: "ABCD-WXYZ",
DeviceCode: "f00bar",
ClientID: "testclient",
Scopes: []string{"openid", "profile", "offline_access"},
Expiry: now().Add(-5 * time.Minute),
},
userCode: "ABCD-WXYZ",
expectedResponseCode: http.StatusBadRequest,
expectedRedirectPath: "",
},
{
testName: "No user code",
testDeviceRequest: storage.DeviceRequest{
UserCode: "ABCD-WXYZ",
DeviceCode: "f00bar",
ClientID: "testclient",
Scopes: []string{"openid", "profile", "offline_access"},
Expiry: now().Add(-5 * time.Minute),
},
userCode: "",
expectedResponseCode: http.StatusBadRequest,
expectedRedirectPath: "",
},
{
testName: "Valid user code, expect redirect to auth endpoint",
testDeviceRequest: storage.DeviceRequest{
UserCode: "ABCD-WXYZ",
DeviceCode: "f00bar",
ClientID: "testclient",
Scopes: []string{"openid", "profile", "offline_access"},
Expiry: now().Add(5 * time.Minute),
},
userCode: "ABCD-WXYZ",
expectedResponseCode: http.StatusFound,
expectedRedirectPath: "/auth",
},
}
for _, tc := range tests {
t.Run(tc.testName, func(t *testing.T) {
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
// Setup a dex server.
httpServer, s := newTestServer(ctx, t, func(c *Config) {
c.Issuer = c.Issuer + "/non-root-path"
c.Now = now
})
defer httpServer.Close()
if err := s.storage.CreateDeviceRequest(tc.testDeviceRequest); err != nil {
t.Fatalf("Failed to store device token %v", err)
}
u, err := url.Parse(s.issuerURL.String())
if err != nil {
t.Fatalf("Could not parse issuer URL %v", err)
}
u.Path = path.Join(u.Path, "device/auth/verify_code")
data := url.Values{}
data.Set("user_code", tc.userCode)
req, _ := http.NewRequest("POST", u.String(), bytes.NewBufferString(data.Encode()))
req.Header.Set("Content-Type", "application/x-www-form-urlencoded; param=value")
rr := httptest.NewRecorder()
s.ServeHTTP(rr, req)
if rr.Code != tc.expectedResponseCode {
t.Errorf("Unexpected Response Type. Expected %v got %v", tc.expectedResponseCode, rr.Code)
}
u, err = url.Parse(s.issuerURL.String())
if err != nil {
t.Errorf("Could not parse issuer URL %v", err)
}
u.Path = path.Join(u.Path, tc.expectedRedirectPath)
location := rr.Header().Get("Location")
if rr.Code == http.StatusFound && !strings.HasPrefix(location, u.Path) {
t.Errorf("Invalid Redirect. Expected %v got %v", u.Path, location)
}
})
}
}

View File

@ -153,6 +153,8 @@ type discovery struct {
Token string `json:"token_endpoint"` Token string `json:"token_endpoint"`
Keys string `json:"jwks_uri"` Keys string `json:"jwks_uri"`
UserInfo string `json:"userinfo_endpoint"` UserInfo string `json:"userinfo_endpoint"`
DeviceEndpoint string `json:"device_authorization_endpoint"`
GrantTypes []string `json:"grant_types_supported"`
ResponseTypes []string `json:"response_types_supported"` ResponseTypes []string `json:"response_types_supported"`
Subjects []string `json:"subject_types_supported"` Subjects []string `json:"subject_types_supported"`
IDTokenAlgs []string `json:"id_token_signing_alg_values_supported"` IDTokenAlgs []string `json:"id_token_signing_alg_values_supported"`
@ -168,7 +170,9 @@ func (s *Server) discoveryHandler() (http.HandlerFunc, error) {
Token: s.absURL("/token"), Token: s.absURL("/token"),
Keys: s.absURL("/keys"), Keys: s.absURL("/keys"),
UserInfo: s.absURL("/userinfo"), UserInfo: s.absURL("/userinfo"),
DeviceEndpoint: s.absURL("/device/code"),
Subjects: []string{"public"}, Subjects: []string{"public"},
GrantTypes: []string{grantTypeAuthorizationCode, grantTypeRefreshToken, grantTypeDeviceCode},
IDTokenAlgs: []string{string(jose.RS256)}, IDTokenAlgs: []string{string(jose.RS256)},
Scopes: []string{"openid", "email", "groups", "profile", "offline_access"}, Scopes: []string{"openid", "email", "groups", "profile", "offline_access"},
AuthMethods: []string{"client_secret_basic"}, AuthMethods: []string{"client_secret_basic"},
@ -784,24 +788,33 @@ func (s *Server) handleAuthCode(w http.ResponseWriter, r *http.Request, client s
return return
} }
tokenResponse, err := s.exchangeAuthCode(w, authCode, client)
if err != nil {
s.tokenErrHelper(w, errServerError, "", http.StatusInternalServerError)
return
}
s.writeAccessToken(w, tokenResponse)
}
func (s *Server) exchangeAuthCode(w http.ResponseWriter, authCode storage.AuthCode, client storage.Client) (*accessTokenReponse, error) {
accessToken, err := s.newAccessToken(client.ID, authCode.Claims, authCode.Scopes, authCode.Nonce, authCode.ConnectorID) accessToken, err := s.newAccessToken(client.ID, authCode.Claims, authCode.Scopes, authCode.Nonce, authCode.ConnectorID)
if err != nil { if err != nil {
s.logger.Errorf("failed to create new access token: %v", err) s.logger.Errorf("failed to create new access token: %v", err)
s.tokenErrHelper(w, errServerError, "", http.StatusInternalServerError) s.tokenErrHelper(w, errServerError, "", http.StatusInternalServerError)
return return nil, err
} }
idToken, expiry, err := s.newIDToken(client.ID, authCode.Claims, authCode.Scopes, authCode.Nonce, accessToken, authCode.ConnectorID) idToken, expiry, err := s.newIDToken(client.ID, authCode.Claims, authCode.Scopes, authCode.Nonce, accessToken, authCode.ConnectorID)
if err != nil { if err != nil {
s.logger.Errorf("failed to create ID token: %v", err) s.logger.Errorf("failed to create ID token: %v", err)
s.tokenErrHelper(w, errServerError, "", http.StatusInternalServerError) s.tokenErrHelper(w, errServerError, "", http.StatusInternalServerError)
return return nil, err
} }
if err := s.storage.DeleteAuthCode(code); err != nil { if err := s.storage.DeleteAuthCode(authCode.ID); err != nil {
s.logger.Errorf("failed to delete auth code: %v", err) s.logger.Errorf("failed to delete auth code: %v", err)
s.tokenErrHelper(w, errServerError, "", http.StatusInternalServerError) s.tokenErrHelper(w, errServerError, "", http.StatusInternalServerError)
return return nil, err
} }
reqRefresh := func() bool { reqRefresh := func() bool {
@ -848,13 +861,13 @@ func (s *Server) handleAuthCode(w http.ResponseWriter, r *http.Request, client s
if refreshToken, err = internal.Marshal(token); err != nil { if refreshToken, err = internal.Marshal(token); err != nil {
s.logger.Errorf("failed to marshal refresh token: %v", err) s.logger.Errorf("failed to marshal refresh token: %v", err)
s.tokenErrHelper(w, errServerError, "", http.StatusInternalServerError) s.tokenErrHelper(w, errServerError, "", http.StatusInternalServerError)
return return nil, err
} }
if err := s.storage.CreateRefresh(refresh); err != nil { if err := s.storage.CreateRefresh(refresh); err != nil {
s.logger.Errorf("failed to create refresh token: %v", err) s.logger.Errorf("failed to create refresh token: %v", err)
s.tokenErrHelper(w, errServerError, "", http.StatusInternalServerError) s.tokenErrHelper(w, errServerError, "", http.StatusInternalServerError)
return return nil, err
} }
// deleteToken determines if we need to delete the newly created refresh token // deleteToken determines if we need to delete the newly created refresh token
@ -885,7 +898,7 @@ func (s *Server) handleAuthCode(w http.ResponseWriter, r *http.Request, client s
s.logger.Errorf("failed to get offline session: %v", err) s.logger.Errorf("failed to get offline session: %v", err)
s.tokenErrHelper(w, errServerError, "", http.StatusInternalServerError) s.tokenErrHelper(w, errServerError, "", http.StatusInternalServerError)
deleteToken = true deleteToken = true
return return nil, err
} }
offlineSessions := storage.OfflineSessions{ offlineSessions := storage.OfflineSessions{
UserID: refresh.Claims.UserID, UserID: refresh.Claims.UserID,
@ -900,7 +913,7 @@ func (s *Server) handleAuthCode(w http.ResponseWriter, r *http.Request, client s
s.logger.Errorf("failed to create offline session: %v", err) s.logger.Errorf("failed to create offline session: %v", err)
s.tokenErrHelper(w, errServerError, "", http.StatusInternalServerError) s.tokenErrHelper(w, errServerError, "", http.StatusInternalServerError)
deleteToken = true deleteToken = true
return return nil, err
} }
} else { } else {
if oldTokenRef, ok := session.Refresh[tokenRef.ClientID]; ok { if oldTokenRef, ok := session.Refresh[tokenRef.ClientID]; ok {
@ -909,7 +922,7 @@ func (s *Server) handleAuthCode(w http.ResponseWriter, r *http.Request, client s
s.logger.Errorf("failed to delete refresh token: %v", err) s.logger.Errorf("failed to delete refresh token: %v", err)
s.tokenErrHelper(w, errServerError, "", http.StatusInternalServerError) s.tokenErrHelper(w, errServerError, "", http.StatusInternalServerError)
deleteToken = true deleteToken = true
return return nil, err
} }
} }
@ -921,11 +934,11 @@ func (s *Server) handleAuthCode(w http.ResponseWriter, r *http.Request, client s
s.logger.Errorf("failed to update offline session: %v", err) s.logger.Errorf("failed to update offline session: %v", err)
s.tokenErrHelper(w, errServerError, "", http.StatusInternalServerError) s.tokenErrHelper(w, errServerError, "", http.StatusInternalServerError)
deleteToken = true deleteToken = true
return return nil, err
} }
} }
} }
s.writeAccessToken(w, idToken, accessToken, refreshToken, expiry) return s.toAccessTokenResponse(idToken, accessToken, refreshToken, expiry), nil
} }
// handle a refresh token request https://tools.ietf.org/html/rfc6749#section-6 // handle a refresh token request https://tools.ietf.org/html/rfc6749#section-6
@ -1121,7 +1134,8 @@ func (s *Server) handleRefreshToken(w http.ResponseWriter, r *http.Request, clie
return return
} }
s.writeAccessToken(w, idToken, accessToken, rawNewToken, expiry) resp := s.toAccessTokenResponse(idToken, accessToken, rawNewToken, expiry)
s.writeAccessToken(w, resp)
} }
func (s *Server) handleUserInfo(w http.ResponseWriter, r *http.Request) { func (s *Server) handleUserInfo(w http.ResponseWriter, r *http.Request) {
@ -1368,23 +1382,29 @@ func (s *Server) handlePasswordGrant(w http.ResponseWriter, r *http.Request, cli
} }
} }
s.writeAccessToken(w, idToken, accessToken, refreshToken, expiry) resp := s.toAccessTokenResponse(idToken, accessToken, refreshToken, expiry)
s.writeAccessToken(w, resp)
} }
func (s *Server) writeAccessToken(w http.ResponseWriter, idToken, accessToken, refreshToken string, expiry time.Time) { type accessTokenReponse struct {
resp := struct {
AccessToken string `json:"access_token"` AccessToken string `json:"access_token"`
TokenType string `json:"token_type"` TokenType string `json:"token_type"`
ExpiresIn int `json:"expires_in"` ExpiresIn int `json:"expires_in"`
RefreshToken string `json:"refresh_token,omitempty"` RefreshToken string `json:"refresh_token,omitempty"`
IDToken string `json:"id_token"` IDToken string `json:"id_token"`
}{ }
func (s *Server) toAccessTokenResponse(idToken, accessToken, refreshToken string, expiry time.Time) *accessTokenReponse {
return &accessTokenReponse{
accessToken, accessToken,
"bearer", "bearer",
int(expiry.Sub(s.now()).Seconds()), int(expiry.Sub(s.now()).Seconds()),
refreshToken, refreshToken,
idToken, idToken,
} }
}
func (s *Server) writeAccessToken(w http.ResponseWriter, resp *accessTokenReponse) {
data, err := json.Marshal(resp) data, err := json.Marshal(resp)
if err != nil { if err != nil {
s.logger.Errorf("failed to marshal access token response: %v", err) s.logger.Errorf("failed to marshal access token response: %v", err)

View File

@ -114,6 +114,10 @@ const (
scopeCrossClientPrefix = "audience:server:client_id:" scopeCrossClientPrefix = "audience:server:client_id:"
) )
const (
deviceCallbackURI = "/device/callback"
)
const ( const (
redirectURIOOB = "urn:ietf:wg:oauth:2.0:oob" redirectURIOOB = "urn:ietf:wg:oauth:2.0:oob"
) )
@ -122,6 +126,7 @@ const (
grantTypeAuthorizationCode = "authorization_code" grantTypeAuthorizationCode = "authorization_code"
grantTypeRefreshToken = "refresh_token" grantTypeRefreshToken = "refresh_token"
grantTypePassword = "password" grantTypePassword = "password"
grantTypeDeviceCode = "urn:ietf:params:oauth:grant-type:device_code"
) )
const ( const (
@ -130,6 +135,13 @@ const (
responseTypeIDToken = "id_token" // ID Token in url fragment responseTypeIDToken = "id_token" // ID Token in url fragment
) )
const (
deviceTokenPending = "authorization_pending"
deviceTokenComplete = "complete"
deviceTokenSlowDown = "slow_down"
deviceTokenExpired = "expired_token"
)
func parseScopes(scopes []string) connector.Scopes { func parseScopes(scopes []string) connector.Scopes {
var s connector.Scopes var s connector.Scopes
for _, scope := range scopes { for _, scope := range scopes {
@ -425,6 +437,9 @@ func (s *Server) parseAuthorizationRequest(r *http.Request) (*storage.AuthReques
description := fmt.Sprintf("Unregistered redirect_uri (%q).", redirectURI) description := fmt.Sprintf("Unregistered redirect_uri (%q).", redirectURI)
return nil, &authErr{"", "", errInvalidRequest, description} return nil, &authErr{"", "", errInvalidRequest, description}
} }
if redirectURI == deviceCallbackURI && client.Public {
redirectURI = s.issuerURL.Path + deviceCallbackURI
}
// From here on out, we want to redirect back to the client with an error. // From here on out, we want to redirect back to the client with an error.
newErr := func(typ, format string, a ...interface{}) *authErr { newErr := func(typ, format string, a ...interface{}) *authErr {
@ -566,7 +581,7 @@ func validateRedirectURI(client storage.Client, redirectURI string) bool {
return false return false
} }
if redirectURI == redirectURIOOB { if redirectURI == redirectURIOOB || redirectURI == deviceCallbackURI {
return true return true
} }

View File

@ -78,6 +78,7 @@ type Config struct {
RotateKeysAfter time.Duration // Defaults to 6 hours. RotateKeysAfter time.Duration // Defaults to 6 hours.
IDTokensValidFor time.Duration // Defaults to 24 hours IDTokensValidFor time.Duration // Defaults to 24 hours
AuthRequestsValidFor time.Duration // Defaults to 24 hours AuthRequestsValidFor time.Duration // Defaults to 24 hours
DeviceRequestsValidFor time.Duration // Defaults to 5 minutes
// If set, the server will use this connector to handle password grants // If set, the server will use this connector to handle password grants
PasswordConnector string PasswordConnector string
@ -158,6 +159,7 @@ type Server struct {
idTokensValidFor time.Duration idTokensValidFor time.Duration
authRequestsValidFor time.Duration authRequestsValidFor time.Duration
deviceRequestsValidFor time.Duration
logger log.Logger logger log.Logger
} }
@ -219,6 +221,7 @@ func newServer(ctx context.Context, c Config, rotationStrategy rotationStrategy)
supportedResponseTypes: supported, supportedResponseTypes: supported,
idTokensValidFor: value(c.IDTokensValidFor, 24*time.Hour), idTokensValidFor: value(c.IDTokensValidFor, 24*time.Hour),
authRequestsValidFor: value(c.AuthRequestsValidFor, 24*time.Hour), authRequestsValidFor: value(c.AuthRequestsValidFor, 24*time.Hour),
deviceRequestsValidFor: value(c.DeviceRequestsValidFor, 5*time.Minute),
skipApproval: c.SkipApprovalScreen, skipApproval: c.SkipApprovalScreen,
alwaysShowLogin: c.AlwaysShowLoginScreen, alwaysShowLogin: c.AlwaysShowLoginScreen,
now: now, now: now,
@ -302,6 +305,11 @@ func newServer(ctx context.Context, c Config, rotationStrategy rotationStrategy)
handleWithCORS("/userinfo", s.handleUserInfo) handleWithCORS("/userinfo", s.handleUserInfo)
handleFunc("/auth", s.handleAuthorization) handleFunc("/auth", s.handleAuthorization)
handleFunc("/auth/{connector}", s.handleConnectorLogin) handleFunc("/auth/{connector}", s.handleConnectorLogin)
handleFunc("/device", s.handleDeviceExchange)
handleFunc("/device/auth/verify_code", s.verifyUserCode)
handleFunc("/device/code", s.handleDeviceCode)
handleFunc("/device/token", s.handleDeviceToken)
handleFunc(deviceCallbackURI, s.handleDeviceCallback)
r.HandleFunc(path.Join(issuerURL.Path, "/callback"), func(w http.ResponseWriter, r *http.Request) { r.HandleFunc(path.Join(issuerURL.Path, "/callback"), func(w http.ResponseWriter, r *http.Request) {
// Strip the X-Remote-* headers to prevent security issues on // Strip the X-Remote-* headers to prevent security issues on
// misconfigured authproxy connector setups. // misconfigured authproxy connector setups.
@ -450,7 +458,8 @@ func (s *Server) startGarbageCollection(ctx context.Context, frequency time.Dura
if r, err := s.storage.GarbageCollect(now()); err != nil { if r, err := s.storage.GarbageCollect(now()); err != nil {
s.logger.Errorf("garbage collection failed: %v", err) s.logger.Errorf("garbage collection failed: %v", err)
} else if r.AuthRequests > 0 || r.AuthCodes > 0 { } else if r.AuthRequests > 0 || r.AuthCodes > 0 {
s.logger.Infof("garbage collection run, delete auth requests=%d, auth codes=%d", r.AuthRequests, r.AuthCodes) s.logger.Infof("garbage collection run, delete auth requests=%d, auth codes=%d, device requests =%d, device tokens=%d",
r.AuthRequests, r.AuthCodes, r.DeviceRequests, r.DeviceTokens)
} }
} }
} }

View File

@ -8,11 +8,13 @@ import (
"encoding/pem" "encoding/pem"
"errors" "errors"
"fmt" "fmt"
"io/ioutil"
"net/http" "net/http"
"net/http/httptest" "net/http/httptest"
"net/http/httputil" "net/http/httputil"
"net/url" "net/url"
"os" "os"
"path"
"reflect" "reflect"
"sort" "sort"
"strings" "strings"
@ -203,41 +205,36 @@ func TestDiscovery(t *testing.T) {
} }
} }
// TestOAuth2CodeFlow runs integration tests against a test server. The tests stand up a server type oauth2Tests struct {
// which requires no interaction to login, logs in through a test client, then passes the client clientID string
// and returned token to the test. tests []test
func TestOAuth2CodeFlow(t *testing.T) { }
clientID := "testclient"
clientSecret := "testclientsecret" type test struct {
name string
// If specified these set of scopes will be used during the test case.
scopes []string
// handleToken provides the OAuth2 token response for the integration test.
handleToken func(context.Context, *oidc.Provider, *oauth2.Config, *oauth2.Token, *mock.Callback) error
}
func makeOAuth2Tests(clientID string, clientSecret string, now func() time.Time) oauth2Tests {
requestedScopes := []string{oidc.ScopeOpenID, "email", "profile", "groups", "offline_access"} requestedScopes := []string{oidc.ScopeOpenID, "email", "profile", "groups", "offline_access"}
t0 := time.Now()
// Always have the time function used by the server return the same time so
// we can predict expected values of "expires_in" fields exactly.
now := func() time.Time { return t0 }
// Used later when configuring test servers to set how long id_tokens will be valid for. // Used later when configuring test servers to set how long id_tokens will be valid for.
// //
// The actual value of 30s is completely arbitrary. We just need to set a value // The actual value of 30s is completely arbitrary. We just need to set a value
// so tests can compute the expected "expires_in" field. // so tests can compute the expected "expires_in" field.
idTokensValidFor := time.Second * 30 idTokensValidFor := time.Second * 30
// Connector used by the tests.
var conn *mock.Callback
oidcConfig := &oidc.Config{SkipClientIDCheck: true} oidcConfig := &oidc.Config{SkipClientIDCheck: true}
tests := []struct { return oauth2Tests{
name string clientID: clientID,
// If specified these set of scopes will be used during the test case. tests: []test{
scopes []string
// handleToken provides the OAuth2 token response for the integration test.
handleToken func(context.Context, *oidc.Provider, *oauth2.Config, *oauth2.Token) error
}{
{ {
name: "verify ID Token", name: "verify ID Token",
handleToken: func(ctx context.Context, p *oidc.Provider, config *oauth2.Config, token *oauth2.Token) error { handleToken: func(ctx context.Context, p *oidc.Provider, config *oauth2.Config, token *oauth2.Token, conn *mock.Callback) error {
idToken, ok := token.Extra("id_token").(string) idToken, ok := token.Extra("id_token").(string)
if !ok { if !ok {
return fmt.Errorf("no id token found") return fmt.Errorf("no id token found")
@ -250,7 +247,7 @@ func TestOAuth2CodeFlow(t *testing.T) {
}, },
{ {
name: "fetch userinfo", name: "fetch userinfo",
handleToken: func(ctx context.Context, p *oidc.Provider, config *oauth2.Config, token *oauth2.Token) error { handleToken: func(ctx context.Context, p *oidc.Provider, config *oauth2.Config, token *oauth2.Token, conn *mock.Callback) error {
ui, err := p.UserInfo(ctx, config.TokenSource(ctx, token)) ui, err := p.UserInfo(ctx, config.TokenSource(ctx, token))
if err != nil { if err != nil {
return fmt.Errorf("failed to fetch userinfo: %v", err) return fmt.Errorf("failed to fetch userinfo: %v", err)
@ -263,7 +260,7 @@ func TestOAuth2CodeFlow(t *testing.T) {
}, },
{ {
name: "verify id token and oauth2 token expiry", name: "verify id token and oauth2 token expiry",
handleToken: func(ctx context.Context, p *oidc.Provider, config *oauth2.Config, token *oauth2.Token) error { handleToken: func(ctx context.Context, p *oidc.Provider, config *oauth2.Config, token *oauth2.Token, conn *mock.Callback) error {
expectedExpiry := now().Add(idTokensValidFor) expectedExpiry := now().Add(idTokensValidFor)
timeEq := func(t1, t2 time.Time, within time.Duration) bool { timeEq := func(t1, t2 time.Time, within time.Duration) bool {
@ -290,7 +287,7 @@ func TestOAuth2CodeFlow(t *testing.T) {
}, },
{ {
name: "verify at_hash", name: "verify at_hash",
handleToken: func(ctx context.Context, p *oidc.Provider, config *oauth2.Config, token *oauth2.Token) error { handleToken: func(ctx context.Context, p *oidc.Provider, config *oauth2.Config, token *oauth2.Token, conn *mock.Callback) error {
rawIDToken, ok := token.Extra("id_token").(string) rawIDToken, ok := token.Extra("id_token").(string)
if !ok { if !ok {
return fmt.Errorf("no id token found") return fmt.Errorf("no id token found")
@ -322,7 +319,7 @@ func TestOAuth2CodeFlow(t *testing.T) {
}, },
{ {
name: "refresh token", name: "refresh token",
handleToken: func(ctx context.Context, p *oidc.Provider, config *oauth2.Config, token *oauth2.Token) error { handleToken: func(ctx context.Context, p *oidc.Provider, config *oauth2.Config, token *oauth2.Token, conn *mock.Callback) error {
// have to use time.Now because the OAuth2 package uses it. // have to use time.Now because the OAuth2 package uses it.
token.Expiry = time.Now().Add(time.Second * -10) token.Expiry = time.Now().Add(time.Second * -10)
if token.Valid() { if token.Valid() {
@ -345,7 +342,7 @@ func TestOAuth2CodeFlow(t *testing.T) {
}, },
{ {
name: "refresh with explicit scopes", name: "refresh with explicit scopes",
handleToken: func(ctx context.Context, p *oidc.Provider, config *oauth2.Config, token *oauth2.Token) error { handleToken: func(ctx context.Context, p *oidc.Provider, config *oauth2.Config, token *oauth2.Token, conn *mock.Callback) error {
v := url.Values{} v := url.Values{}
v.Add("client_id", clientID) v.Add("client_id", clientID)
v.Add("client_secret", clientSecret) v.Add("client_secret", clientSecret)
@ -369,7 +366,7 @@ func TestOAuth2CodeFlow(t *testing.T) {
}, },
{ {
name: "refresh with extra spaces", name: "refresh with extra spaces",
handleToken: func(ctx context.Context, p *oidc.Provider, config *oauth2.Config, token *oauth2.Token) error { handleToken: func(ctx context.Context, p *oidc.Provider, config *oauth2.Config, token *oauth2.Token, conn *mock.Callback) error {
v := url.Values{} v := url.Values{}
v.Add("client_id", clientID) v.Add("client_id", clientID)
v.Add("client_secret", clientSecret) v.Add("client_secret", clientSecret)
@ -398,7 +395,7 @@ func TestOAuth2CodeFlow(t *testing.T) {
{ {
name: "refresh with unauthorized scopes", name: "refresh with unauthorized scopes",
scopes: []string{"openid", "email"}, scopes: []string{"openid", "email"},
handleToken: func(ctx context.Context, p *oidc.Provider, config *oauth2.Config, token *oauth2.Token) error { handleToken: func(ctx context.Context, p *oidc.Provider, config *oauth2.Config, token *oauth2.Token, conn *mock.Callback) error {
v := url.Values{} v := url.Values{}
v.Add("client_id", clientID) v.Add("client_id", clientID)
v.Add("client_secret", clientSecret) v.Add("client_secret", clientSecret)
@ -425,7 +422,7 @@ func TestOAuth2CodeFlow(t *testing.T) {
// This test ensures that the connector.RefreshConnector interface is being // This test ensures that the connector.RefreshConnector interface is being
// used when clients request a refresh token. // used when clients request a refresh token.
name: "refresh with identity changes", name: "refresh with identity changes",
handleToken: func(ctx context.Context, p *oidc.Provider, config *oauth2.Config, token *oauth2.Token) error { handleToken: func(ctx context.Context, p *oidc.Provider, config *oauth2.Config, token *oauth2.Token, conn *mock.Callback) error {
// have to use time.Now because the OAuth2 package uses it. // have to use time.Now because the OAuth2 package uses it.
token.Expiry = time.Now().Add(time.Second * -10) token.Expiry = time.Now().Add(time.Second * -10)
if token.Valid() { if token.Valid() {
@ -472,9 +469,35 @@ func TestOAuth2CodeFlow(t *testing.T) {
return nil return nil
}, },
}, },
},
} }
}
for _, tc := range tests { // TestOAuth2CodeFlow runs integration tests against a test server. The tests stand up a server
// which requires no interaction to login, logs in through a test client, then passes the client
// and returned token to the test.
func TestOAuth2CodeFlow(t *testing.T) {
clientID := "testclient"
clientSecret := "testclientsecret"
requestedScopes := []string{oidc.ScopeOpenID, "email", "profile", "groups", "offline_access"}
t0 := time.Now()
// Always have the time function used by the server return the same time so
// we can predict expected values of "expires_in" fields exactly.
now := func() time.Time { return t0 }
// Used later when configuring test servers to set how long id_tokens will be valid for.
//
// The actual value of 30s is completely arbitrary. We just need to set a value
// so tests can compute the expected "expires_in" field.
idTokensValidFor := time.Second * 30
// Connector used by the tests.
var conn *mock.Callback
tests := makeOAuth2Tests(clientID, clientSecret, now)
for _, tc := range tests.tests {
func() { func() {
ctx, cancel := context.WithCancel(context.Background()) ctx, cancel := context.WithCancel(context.Background())
defer cancel() defer cancel()
@ -540,7 +563,7 @@ func TestOAuth2CodeFlow(t *testing.T) {
t.Errorf("failed to exchange code for token: %v", err) t.Errorf("failed to exchange code for token: %v", err)
return return
} }
err = tc.handleToken(ctx, p, oauth2Config, token) err = tc.handleToken(ctx, p, oauth2Config, token, conn)
if err != nil { if err != nil {
t.Errorf("%s: %v", tc.name, err) t.Errorf("%s: %v", tc.name, err)
} }
@ -1253,3 +1276,157 @@ func TestRefreshTokenFlow(t *testing.T) {
t.Errorf("Token refreshed with invalid refresh token, error expected.") t.Errorf("Token refreshed with invalid refresh token, error expected.")
} }
} }
// TestOAuth2DeviceFlow runs device flow integration tests against a test server
func TestOAuth2DeviceFlow(t *testing.T) {
clientID := "testclient"
clientSecret := ""
requestedScopes := []string{oidc.ScopeOpenID, "email", "profile", "groups", "offline_access"}
t0 := time.Now()
// Always have the time function used by the server return the same time so
// we can predict expected values of "expires_in" fields exactly.
now := func() time.Time { return t0 }
// Connector used by the tests.
var conn *mock.Callback
idTokensValidFor := time.Second * 30
for _, tc := range makeOAuth2Tests(clientID, clientSecret, now).tests {
func() {
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
// Setup a dex server.
httpServer, s := newTestServer(ctx, t, func(c *Config) {
c.Issuer = c.Issuer + "/non-root-path"
c.Now = now
c.IDTokensValidFor = idTokensValidFor
})
defer httpServer.Close()
mockConn := s.connectors["mock"]
conn = mockConn.Connector.(*mock.Callback)
p, err := oidc.NewProvider(ctx, httpServer.URL)
if err != nil {
t.Fatalf("failed to get provider: %v", err)
}
//Add the Clients to the test server
client := storage.Client{
ID: clientID,
RedirectURIs: []string{deviceCallbackURI},
Public: true,
}
if err := s.storage.CreateClient(client); err != nil {
t.Fatalf("failed to create client: %v", err)
}
//Grab the issuer that we'll reuse for the different endpoints to hit
issuer, err := url.Parse(s.issuerURL.String())
if err != nil {
t.Errorf("Could not parse issuer URL %v", err)
}
//Send a new Device Request
codeURL, _ := url.Parse(issuer.String())
codeURL.Path = path.Join(codeURL.Path, "device/code")
data := url.Values{}
data.Set("client_id", clientID)
data.Add("scope", strings.Join(requestedScopes, " "))
resp, err := http.PostForm(codeURL.String(), data)
if err != nil {
t.Errorf("Could not request device code: %v", err)
}
defer resp.Body.Close()
responseBody, err := ioutil.ReadAll(resp.Body)
if err != nil {
t.Errorf("Could read device code response %v", err)
}
if resp.StatusCode != http.StatusOK {
t.Errorf("%v - Unexpected Response Type. Expected 200 got %v. Response: %v", tc.name, resp.StatusCode, string(responseBody))
}
//Parse the code response
var deviceCode deviceCodeResponse
if err := json.Unmarshal(responseBody, &deviceCode); err != nil {
t.Errorf("Unexpected Device Code Response Format %v", string(responseBody))
}
//Mock the user hitting the verification URI and posting the form
verifyURL, _ := url.Parse(issuer.String())
verifyURL.Path = path.Join(verifyURL.Path, "/device/auth/verify_code")
urlData := url.Values{}
urlData.Set("user_code", deviceCode.UserCode)
resp, err = http.PostForm(verifyURL.String(), urlData)
if err != nil {
t.Errorf("Error Posting Form: %v", err)
}
defer resp.Body.Close()
responseBody, err = ioutil.ReadAll(resp.Body)
if err != nil {
t.Errorf("Could read verification response %v", err)
}
if resp.StatusCode != http.StatusOK {
t.Errorf("%v - Unexpected Response Type. Expected 200 got %v. Response: %v", tc.name, resp.StatusCode, string(responseBody))
}
//Hit the Token Endpoint, and try and get an access token
tokenURL, _ := url.Parse(issuer.String())
tokenURL.Path = path.Join(tokenURL.Path, "/device/token")
v := url.Values{}
v.Add("grant_type", grantTypeDeviceCode)
v.Add("device_code", deviceCode.DeviceCode)
resp, err = http.PostForm(tokenURL.String(), v)
if err != nil {
t.Errorf("Could not request device token: %v", err)
}
defer resp.Body.Close()
responseBody, err = ioutil.ReadAll(resp.Body)
if err != nil {
t.Errorf("Could read device token response %v", err)
}
if resp.StatusCode != http.StatusOK {
t.Errorf("%v - Unexpected Token Response Type. Expected 200 got %v. Response: %v", tc.name, resp.StatusCode, string(responseBody))
}
//Parse the response
var tokenRes accessTokenReponse
if err := json.Unmarshal(responseBody, &tokenRes); err != nil {
t.Errorf("Unexpected Device Access Token Response Format %v", string(responseBody))
}
token := &oauth2.Token{
AccessToken: tokenRes.AccessToken,
TokenType: tokenRes.TokenType,
RefreshToken: tokenRes.RefreshToken,
}
raw := make(map[string]interface{})
json.Unmarshal(responseBody, &raw) // no error checks for optional fields
token = token.WithExtra(raw)
if secs := tokenRes.ExpiresIn; secs > 0 {
token.Expiry = time.Now().Add(time.Duration(secs) * time.Second)
}
//Run token tests to validate info is correct
// Create the OAuth2 config.
oauth2Config := &oauth2.Config{
ClientID: client.ID,
ClientSecret: client.Secret,
Endpoint: p.Endpoint(),
Scopes: requestedScopes,
RedirectURL: deviceCallbackURI,
}
if len(tc.scopes) != 0 {
oauth2Config.Scopes = tc.scopes
}
err = tc.handleToken(ctx, p, oauth2Config, token, conn)
if err != nil {
t.Errorf("%s: %v", tc.name, err)
}
}()
}
}

View File

@ -20,6 +20,8 @@ const (
tmplPassword = "password.html" tmplPassword = "password.html"
tmplOOB = "oob.html" tmplOOB = "oob.html"
tmplError = "error.html" tmplError = "error.html"
tmplDevice = "device.html"
tmplDeviceSuccess = "device_success.html"
) )
var requiredTmpls = []string{ var requiredTmpls = []string{
@ -28,6 +30,8 @@ var requiredTmpls = []string{
tmplPassword, tmplPassword,
tmplOOB, tmplOOB,
tmplError, tmplError,
tmplDevice,
tmplDeviceSuccess,
} }
type templates struct { type templates struct {
@ -36,6 +40,8 @@ type templates struct {
passwordTmpl *template.Template passwordTmpl *template.Template
oobTmpl *template.Template oobTmpl *template.Template
errorTmpl *template.Template errorTmpl *template.Template
deviceTmpl *template.Template
deviceSuccessTmpl *template.Template
} }
type webConfig struct { type webConfig struct {
@ -157,6 +163,8 @@ func loadTemplates(c webConfig, templatesDir string) (*templates, error) {
passwordTmpl: tmpls.Lookup(tmplPassword), passwordTmpl: tmpls.Lookup(tmplPassword),
oobTmpl: tmpls.Lookup(tmplOOB), oobTmpl: tmpls.Lookup(tmplOOB),
errorTmpl: tmpls.Lookup(tmplError), errorTmpl: tmpls.Lookup(tmplError),
deviceTmpl: tmpls.Lookup(tmplDevice),
deviceSuccessTmpl: tmpls.Lookup(tmplDeviceSuccess),
}, nil }, nil
} }
@ -242,6 +250,27 @@ func (n byName) Len() int { return len(n) }
func (n byName) Less(i, j int) bool { return n[i].Name < n[j].Name } func (n byName) Less(i, j int) bool { return n[i].Name < n[j].Name }
func (n byName) Swap(i, j int) { n[i], n[j] = n[j], n[i] } func (n byName) Swap(i, j int) { n[i], n[j] = n[j], n[i] }
func (t *templates) device(r *http.Request, w http.ResponseWriter, postURL string, userCode string, lastWasInvalid bool) error {
if lastWasInvalid {
w.WriteHeader(http.StatusBadRequest)
}
data := struct {
PostURL string
UserCode string
Invalid bool
ReqPath string
}{postURL, userCode, lastWasInvalid, r.URL.Path}
return renderTemplate(w, t.deviceTmpl, data)
}
func (t *templates) deviceSuccess(r *http.Request, w http.ResponseWriter, clientName string) error {
data := struct {
ClientName string
ReqPath string
}{clientName, r.URL.Path}
return renderTemplate(w, t.deviceSuccessTmpl, data)
}
func (t *templates) login(r *http.Request, w http.ResponseWriter, connectors []connectorInfo, reqPath string) error { func (t *templates) login(r *http.Request, w http.ResponseWriter, connectors []connectorInfo, reqPath string) error {
sort.Sort(byName(connectors)) sort.Sort(byName(connectors))
data := struct { data := struct {

View File

@ -49,6 +49,8 @@ func RunTests(t *testing.T, newStorage func() storage.Storage) {
{"ConnectorCRUD", testConnectorCRUD}, {"ConnectorCRUD", testConnectorCRUD},
{"GarbageCollection", testGC}, {"GarbageCollection", testGC},
{"TimezoneSupport", testTimezones}, {"TimezoneSupport", testTimezones},
{"DeviceRequestCRUD", testDeviceRequestCRUD},
{"DeviceTokenCRUD", testDeviceTokenCRUD},
}) })
} }
@ -834,6 +836,87 @@ func testGC(t *testing.T, s storage.Storage) {
} else if err != storage.ErrNotFound { } else if err != storage.ErrNotFound {
t.Errorf("expected storage.ErrNotFound, got %v", err) t.Errorf("expected storage.ErrNotFound, got %v", err)
} }
userCode, err := storage.NewUserCode()
if err != nil {
t.Errorf("Unexpected Error: %v", err)
}
d := storage.DeviceRequest{
UserCode: userCode,
DeviceCode: storage.NewID(),
ClientID: "client1",
ClientSecret: "secret1",
Scopes: []string{"openid", "email"},
Expiry: expiry,
}
if err := s.CreateDeviceRequest(d); err != nil {
t.Fatalf("failed creating device request: %v", err)
}
for _, tz := range []*time.Location{time.UTC, est, pst} {
result, err := s.GarbageCollect(expiry.Add(-time.Hour).In(tz))
if err != nil {
t.Errorf("garbage collection failed: %v", err)
} else {
if result.DeviceRequests != 0 {
t.Errorf("expected no device garbage collection results, got %#v", result)
}
}
if _, err := s.GetDeviceRequest(d.UserCode); err != nil {
t.Errorf("expected to be able to get auth request after GC: %v", err)
}
}
if r, err := s.GarbageCollect(expiry.Add(time.Hour)); err != nil {
t.Errorf("garbage collection failed: %v", err)
} else if r.DeviceRequests != 1 {
t.Errorf("expected to garbage collect 1 device request, got %d", r.DeviceRequests)
}
if _, err := s.GetDeviceRequest(d.UserCode); err == nil {
t.Errorf("expected device request to be GC'd")
} else if err != storage.ErrNotFound {
t.Errorf("expected storage.ErrNotFound, got %v", err)
}
dt := storage.DeviceToken{
DeviceCode: storage.NewID(),
Status: "pending",
Token: "foo",
Expiry: expiry,
LastRequestTime: time.Now(),
PollIntervalSeconds: 0,
}
if err := s.CreateDeviceToken(dt); err != nil {
t.Fatalf("failed creating device token: %v", err)
}
for _, tz := range []*time.Location{time.UTC, est, pst} {
result, err := s.GarbageCollect(expiry.Add(-time.Hour).In(tz))
if err != nil {
t.Errorf("garbage collection failed: %v", err)
} else {
if result.DeviceTokens != 0 {
t.Errorf("expected no device token garbage collection results, got %#v", result)
}
}
if _, err := s.GetDeviceToken(dt.DeviceCode); err != nil {
t.Errorf("expected to be able to get device token after GC: %v", err)
}
}
if r, err := s.GarbageCollect(expiry.Add(time.Hour)); err != nil {
t.Errorf("garbage collection failed: %v", err)
} else if r.DeviceTokens != 1 {
t.Errorf("expected to garbage collect 1 device token, got %d", r.DeviceTokens)
}
if _, err := s.GetDeviceToken(dt.DeviceCode); err == nil {
t.Errorf("expected device token to be GC'd")
} else if err != storage.ErrNotFound {
t.Errorf("expected storage.ErrNotFound, got %v", err)
}
} }
// testTimezones tests that backends either fully support timezones or // testTimezones tests that backends either fully support timezones or
@ -881,3 +964,72 @@ func testTimezones(t *testing.T, s storage.Storage) {
t.Fatalf("expected expiry %v got %v", wantTime, gotTime) t.Fatalf("expected expiry %v got %v", wantTime, gotTime)
} }
} }
func testDeviceRequestCRUD(t *testing.T, s storage.Storage) {
userCode, err := storage.NewUserCode()
if err != nil {
panic(err)
}
d1 := storage.DeviceRequest{
UserCode: userCode,
DeviceCode: storage.NewID(),
ClientID: "client1",
ClientSecret: "secret1",
Scopes: []string{"openid", "email"},
Expiry: neverExpire,
}
if err := s.CreateDeviceRequest(d1); err != nil {
t.Fatalf("failed creating device request: %v", err)
}
// Attempt to create same DeviceRequest twice.
err = s.CreateDeviceRequest(d1)
mustBeErrAlreadyExists(t, "device request", err)
//No manual deletes for device requests, will be handled by garbage collection routines
//see testGC
}
func testDeviceTokenCRUD(t *testing.T, s storage.Storage) {
//Create a Token
d1 := storage.DeviceToken{
DeviceCode: storage.NewID(),
Status: "pending",
Token: storage.NewID(),
Expiry: neverExpire,
LastRequestTime: time.Now(),
PollIntervalSeconds: 0,
}
if err := s.CreateDeviceToken(d1); err != nil {
t.Fatalf("failed creating device token: %v", err)
}
// Attempt to create same Device Token twice.
err := s.CreateDeviceToken(d1)
mustBeErrAlreadyExists(t, "device token", err)
//Update the device token, simulate a redemption
if err := s.UpdateDeviceToken(d1.DeviceCode, func(old storage.DeviceToken) (storage.DeviceToken, error) {
old.Token = "token data"
old.Status = "complete"
return old, nil
}); err != nil {
t.Fatalf("failed to update device token: %v", err)
}
//Retrieve the device token
got, err := s.GetDeviceToken(d1.DeviceCode)
if err != nil {
t.Fatalf("failed to get device token: %v", err)
}
//Validate expected result set
if got.Status != "complete" {
t.Fatalf("update failed, wanted token status=%v got %v", "complete", got.Status)
}
if got.Token != "token data" {
t.Fatalf("update failed, wanted token %v got %v", "token data", got.Token)
}
}

View File

@ -22,6 +22,8 @@ const (
offlineSessionPrefix = "offline_session/" offlineSessionPrefix = "offline_session/"
connectorPrefix = "connector/" connectorPrefix = "connector/"
keysName = "openid-connect-keys" keysName = "openid-connect-keys"
deviceRequestPrefix = "device_req/"
deviceTokenPrefix = "device_token/"
// defaultStorageTimeout will be applied to all storage's operations. // defaultStorageTimeout will be applied to all storage's operations.
defaultStorageTimeout = 5 * time.Second defaultStorageTimeout = 5 * time.Second
@ -72,6 +74,36 @@ func (c *conn) GarbageCollect(now time.Time) (result storage.GCResult, err error
result.AuthCodes++ result.AuthCodes++
} }
} }
deviceRequests, err := c.listDeviceRequests(ctx)
if err != nil {
return result, err
}
for _, deviceRequest := range deviceRequests {
if now.After(deviceRequest.Expiry) {
if err := c.deleteKey(ctx, keyID(deviceRequestPrefix, deviceRequest.UserCode)); err != nil {
c.logger.Errorf("failed to delete device request %v", err)
delErr = fmt.Errorf("failed to delete device request: %v", err)
}
result.DeviceRequests++
}
}
deviceTokens, err := c.listDeviceTokens(ctx)
if err != nil {
return result, err
}
for _, deviceToken := range deviceTokens {
if now.After(deviceToken.Expiry) {
if err := c.deleteKey(ctx, keyID(deviceTokenPrefix, deviceToken.DeviceCode)); err != nil {
c.logger.Errorf("failed to delete device token %v", err)
delErr = fmt.Errorf("failed to delete device token: %v", err)
}
result.DeviceTokens++
}
}
return result, delErr return result, delErr
} }
@ -531,3 +563,77 @@ func keyEmail(prefix, email string) string { return prefix + strings.ToLower(ema
func keySession(prefix, userID, connID string) string { func keySession(prefix, userID, connID string) string {
return prefix + strings.ToLower(userID+"|"+connID) return prefix + strings.ToLower(userID+"|"+connID)
} }
func (c *conn) CreateDeviceRequest(d storage.DeviceRequest) error {
ctx, cancel := context.WithTimeout(context.Background(), defaultStorageTimeout)
defer cancel()
return c.txnCreate(ctx, keyID(deviceRequestPrefix, d.UserCode), fromStorageDeviceRequest(d))
}
func (c *conn) GetDeviceRequest(userCode string) (r storage.DeviceRequest, err error) {
ctx, cancel := context.WithTimeout(context.Background(), defaultStorageTimeout)
defer cancel()
err = c.getKey(ctx, keyID(deviceRequestPrefix, userCode), &r)
return r, err
}
func (c *conn) listDeviceRequests(ctx context.Context) (requests []DeviceRequest, err error) {
res, err := c.db.Get(ctx, deviceRequestPrefix, clientv3.WithPrefix())
if err != nil {
return requests, err
}
for _, v := range res.Kvs {
var r DeviceRequest
if err = json.Unmarshal(v.Value, &r); err != nil {
return requests, err
}
requests = append(requests, r)
}
return requests, nil
}
func (c *conn) CreateDeviceToken(t storage.DeviceToken) error {
ctx, cancel := context.WithTimeout(context.Background(), defaultStorageTimeout)
defer cancel()
return c.txnCreate(ctx, keyID(deviceTokenPrefix, t.DeviceCode), fromStorageDeviceToken(t))
}
func (c *conn) GetDeviceToken(deviceCode string) (t storage.DeviceToken, err error) {
ctx, cancel := context.WithTimeout(context.Background(), defaultStorageTimeout)
defer cancel()
err = c.getKey(ctx, keyID(deviceTokenPrefix, deviceCode), &t)
return t, err
}
func (c *conn) listDeviceTokens(ctx context.Context) (deviceTokens []DeviceToken, err error) {
res, err := c.db.Get(ctx, deviceTokenPrefix, clientv3.WithPrefix())
if err != nil {
return deviceTokens, err
}
for _, v := range res.Kvs {
var dt DeviceToken
if err = json.Unmarshal(v.Value, &dt); err != nil {
return deviceTokens, err
}
deviceTokens = append(deviceTokens, dt)
}
return deviceTokens, nil
}
func (c *conn) UpdateDeviceToken(deviceCode string, updater func(old storage.DeviceToken) (storage.DeviceToken, error)) error {
ctx, cancel := context.WithTimeout(context.Background(), defaultStorageTimeout)
defer cancel()
return c.txnUpdate(ctx, keyID(deviceTokenPrefix, deviceCode), func(currentValue []byte) ([]byte, error) {
var current DeviceToken
if len(currentValue) > 0 {
if err := json.Unmarshal(currentValue, &current); err != nil {
return nil, err
}
}
updated, err := updater(toStorageDeviceToken(current))
if err != nil {
return nil, err
}
return json.Marshal(fromStorageDeviceToken(updated))
})
}

View File

@ -44,6 +44,8 @@ func cleanDB(c *conn) error {
passwordPrefix, passwordPrefix,
offlineSessionPrefix, offlineSessionPrefix,
connectorPrefix, connectorPrefix,
deviceRequestPrefix,
deviceTokenPrefix,
} { } {
_, err := c.db.Delete(ctx, prefix, clientv3.WithPrefix()) _, err := c.db.Delete(ctx, prefix, clientv3.WithPrefix())
if err != nil { if err != nil {

View File

@ -216,3 +216,56 @@ func toStorageOfflineSessions(o OfflineSessions) storage.OfflineSessions {
} }
return s return s
} }
// DeviceRequest is a mirrored struct from storage with JSON struct tags
type DeviceRequest struct {
UserCode string `json:"user_code"`
DeviceCode string `json:"device_code"`
ClientID string `json:"client_id"`
ClientSecret string `json:"client_secret"`
Scopes []string `json:"scopes"`
Expiry time.Time `json:"expiry"`
}
func fromStorageDeviceRequest(d storage.DeviceRequest) DeviceRequest {
return DeviceRequest{
UserCode: d.UserCode,
DeviceCode: d.DeviceCode,
ClientID: d.ClientID,
ClientSecret: d.ClientSecret,
Scopes: d.Scopes,
Expiry: d.Expiry,
}
}
// DeviceToken is a mirrored struct from storage with JSON struct tags
type DeviceToken struct {
DeviceCode string `json:"device_code"`
Status string `json:"status"`
Token string `json:"token"`
Expiry time.Time `json:"expiry"`
LastRequestTime time.Time `json:"last_request"`
PollIntervalSeconds int `json:"poll_interval"`
}
func fromStorageDeviceToken(t storage.DeviceToken) DeviceToken {
return DeviceToken{
DeviceCode: t.DeviceCode,
Status: t.Status,
Token: t.Token,
Expiry: t.Expiry,
LastRequestTime: t.LastRequestTime,
PollIntervalSeconds: t.PollIntervalSeconds,
}
}
func toStorageDeviceToken(t DeviceToken) storage.DeviceToken {
return storage.DeviceToken{
DeviceCode: t.DeviceCode,
Status: t.Status,
Token: t.Token,
Expiry: t.Expiry,
LastRequestTime: t.LastRequestTime,
PollIntervalSeconds: t.PollIntervalSeconds,
}
}

View File

@ -21,6 +21,8 @@ const (
kindPassword = "Password" kindPassword = "Password"
kindOfflineSessions = "OfflineSessions" kindOfflineSessions = "OfflineSessions"
kindConnector = "Connector" kindConnector = "Connector"
kindDeviceRequest = "DeviceRequest"
kindDeviceToken = "DeviceToken"
) )
const ( const (
@ -32,6 +34,8 @@ const (
resourcePassword = "passwords" resourcePassword = "passwords"
resourceOfflineSessions = "offlinesessionses" // Again attempts to pluralize. resourceOfflineSessions = "offlinesessionses" // Again attempts to pluralize.
resourceConnector = "connectors" resourceConnector = "connectors"
resourceDeviceRequest = "devicerequests"
resourceDeviceToken = "devicetokens"
) )
// Config values for the Kubernetes storage type. // Config values for the Kubernetes storage type.
@ -593,5 +597,84 @@ func (cli *client) GarbageCollect(now time.Time) (result storage.GCResult, err e
result.AuthCodes++ result.AuthCodes++
} }
} }
var deviceRequests DeviceRequestList
if err := cli.list(resourceDeviceRequest, &deviceRequests); err != nil {
return result, fmt.Errorf("failed to list device requests: %v", err)
}
for _, deviceRequest := range deviceRequests.DeviceRequests {
if now.After(deviceRequest.Expiry) {
if err := cli.delete(resourceDeviceRequest, deviceRequest.ObjectMeta.Name); err != nil {
cli.logger.Errorf("failed to delete device request: %v", err)
delErr = fmt.Errorf("failed to delete device request: %v", err)
}
result.DeviceRequests++
}
}
var deviceTokens DeviceTokenList
if err := cli.list(resourceDeviceToken, &deviceTokens); err != nil {
return result, fmt.Errorf("failed to list device tokens: %v", err)
}
for _, deviceToken := range deviceTokens.DeviceTokens {
if now.After(deviceToken.Expiry) {
if err := cli.delete(resourceDeviceToken, deviceToken.ObjectMeta.Name); err != nil {
cli.logger.Errorf("failed to delete device token: %v", err)
delErr = fmt.Errorf("failed to delete device token: %v", err)
}
result.DeviceTokens++
}
}
if delErr != nil {
return result, delErr
}
return result, delErr return result, delErr
} }
func (cli *client) CreateDeviceRequest(d storage.DeviceRequest) error {
return cli.post(resourceDeviceRequest, cli.fromStorageDeviceRequest(d))
}
func (cli *client) GetDeviceRequest(userCode string) (storage.DeviceRequest, error) {
var req DeviceRequest
if err := cli.get(resourceDeviceRequest, strings.ToLower(userCode), &req); err != nil {
return storage.DeviceRequest{}, err
}
return toStorageDeviceRequest(req), nil
}
func (cli *client) CreateDeviceToken(t storage.DeviceToken) error {
return cli.post(resourceDeviceToken, cli.fromStorageDeviceToken(t))
}
func (cli *client) GetDeviceToken(deviceCode string) (storage.DeviceToken, error) {
var token DeviceToken
if err := cli.get(resourceDeviceToken, deviceCode, &token); err != nil {
return storage.DeviceToken{}, err
}
return toStorageDeviceToken(token), nil
}
func (cli *client) getDeviceToken(deviceCode string) (t DeviceToken, err error) {
err = cli.get(resourceDeviceToken, deviceCode, &t)
return
}
func (cli *client) UpdateDeviceToken(deviceCode string, updater func(old storage.DeviceToken) (storage.DeviceToken, error)) error {
r, err := cli.getDeviceToken(deviceCode)
if err != nil {
return err
}
updated, err := updater(toStorageDeviceToken(r))
if err != nil {
return err
}
updated.DeviceCode = deviceCode
newToken := cli.fromStorageDeviceToken(updated)
newToken.ObjectMeta = r.ObjectMeta
return cli.put(resourceDeviceToken, r.ObjectMeta.Name, newToken)
}

View File

@ -85,6 +85,8 @@ func (s *StorageTestSuite) TestStorage() {
for _, resource := range []string{ for _, resource := range []string{
resourceAuthCode, resourceAuthCode,
resourceAuthRequest, resourceAuthRequest,
resourceDeviceRequest,
resourceDeviceToken,
resourceClient, resourceClient,
resourceRefreshToken, resourceRefreshToken,
resourceKeys, resourceKeys,

View File

@ -143,6 +143,36 @@ var customResourceDefinitions = []k8sapi.CustomResourceDefinition{
}, },
}, },
}, },
{
ObjectMeta: k8sapi.ObjectMeta{
Name: "devicerequests.dex.coreos.com",
},
TypeMeta: crdMeta,
Spec: k8sapi.CustomResourceDefinitionSpec{
Group: apiGroup,
Version: "v1",
Names: k8sapi.CustomResourceDefinitionNames{
Plural: "devicerequests",
Singular: "devicerequest",
Kind: "DeviceRequest",
},
},
},
{
ObjectMeta: k8sapi.ObjectMeta{
Name: "devicetokens.dex.coreos.com",
},
TypeMeta: crdMeta,
Spec: k8sapi.CustomResourceDefinitionSpec{
Group: apiGroup,
Version: "v1",
Names: k8sapi.CustomResourceDefinitionNames{
Plural: "devicetokens",
Singular: "devicetoken",
Kind: "DeviceToken",
},
},
},
} }
// There will only ever be a single keys resource. Maintain this by setting a // There will only ever be a single keys resource. Maintain this by setting a
@ -635,3 +665,103 @@ type ConnectorList struct {
k8sapi.ListMeta `json:"metadata,omitempty"` k8sapi.ListMeta `json:"metadata,omitempty"`
Connectors []Connector `json:"items"` Connectors []Connector `json:"items"`
} }
// DeviceRequest is a mirrored struct from storage with JSON struct tags and
// Kubernetes type metadata.
type DeviceRequest struct {
k8sapi.TypeMeta `json:",inline"`
k8sapi.ObjectMeta `json:"metadata,omitempty"`
DeviceCode string `json:"device_code,omitempty"`
ClientID string `json:"client_id,omitempty"`
ClientSecret string `json:"client_secret,omitempty"`
Scopes []string `json:"scopes,omitempty"`
Expiry time.Time `json:"expiry"`
}
// AuthRequestList is a list of AuthRequests.
type DeviceRequestList struct {
k8sapi.TypeMeta `json:",inline"`
k8sapi.ListMeta `json:"metadata,omitempty"`
DeviceRequests []DeviceRequest `json:"items"`
}
func (cli *client) fromStorageDeviceRequest(a storage.DeviceRequest) DeviceRequest {
req := DeviceRequest{
TypeMeta: k8sapi.TypeMeta{
Kind: kindDeviceRequest,
APIVersion: cli.apiVersion,
},
ObjectMeta: k8sapi.ObjectMeta{
Name: strings.ToLower(a.UserCode),
Namespace: cli.namespace,
},
DeviceCode: a.DeviceCode,
ClientID: a.ClientID,
ClientSecret: a.ClientSecret,
Scopes: a.Scopes,
Expiry: a.Expiry,
}
return req
}
func toStorageDeviceRequest(req DeviceRequest) storage.DeviceRequest {
return storage.DeviceRequest{
UserCode: strings.ToUpper(req.ObjectMeta.Name),
DeviceCode: req.DeviceCode,
ClientID: req.ClientID,
ClientSecret: req.ClientSecret,
Scopes: req.Scopes,
Expiry: req.Expiry,
}
}
// DeviceToken is a mirrored struct from storage with JSON struct tags and
// Kubernetes type metadata.
type DeviceToken struct {
k8sapi.TypeMeta `json:",inline"`
k8sapi.ObjectMeta `json:"metadata,omitempty"`
Status string `json:"status,omitempty"`
Token string `json:"token,omitempty"`
Expiry time.Time `json:"expiry"`
LastRequestTime time.Time `json:"last_request"`
PollIntervalSeconds int `json:"poll_interval"`
}
// DeviceTokenList is a list of DeviceTokens.
type DeviceTokenList struct {
k8sapi.TypeMeta `json:",inline"`
k8sapi.ListMeta `json:"metadata,omitempty"`
DeviceTokens []DeviceToken `json:"items"`
}
func (cli *client) fromStorageDeviceToken(t storage.DeviceToken) DeviceToken {
req := DeviceToken{
TypeMeta: k8sapi.TypeMeta{
Kind: kindDeviceToken,
APIVersion: cli.apiVersion,
},
ObjectMeta: k8sapi.ObjectMeta{
Name: t.DeviceCode,
Namespace: cli.namespace,
},
Status: t.Status,
Token: t.Token,
Expiry: t.Expiry,
LastRequestTime: t.LastRequestTime,
PollIntervalSeconds: t.PollIntervalSeconds,
}
return req
}
func toStorageDeviceToken(t DeviceToken) storage.DeviceToken {
return storage.DeviceToken{
DeviceCode: t.ObjectMeta.Name,
Status: t.Status,
Token: t.Token,
Expiry: t.Expiry,
LastRequestTime: t.LastRequestTime,
PollIntervalSeconds: t.PollIntervalSeconds,
}
}

View File

@ -20,6 +20,8 @@ func New(logger log.Logger) storage.Storage {
passwords: make(map[string]storage.Password), passwords: make(map[string]storage.Password),
offlineSessions: make(map[offlineSessionID]storage.OfflineSessions), offlineSessions: make(map[offlineSessionID]storage.OfflineSessions),
connectors: make(map[string]storage.Connector), connectors: make(map[string]storage.Connector),
deviceRequests: make(map[string]storage.DeviceRequest),
deviceTokens: make(map[string]storage.DeviceToken),
logger: logger, logger: logger,
} }
} }
@ -46,6 +48,8 @@ type memStorage struct {
passwords map[string]storage.Password passwords map[string]storage.Password
offlineSessions map[offlineSessionID]storage.OfflineSessions offlineSessions map[offlineSessionID]storage.OfflineSessions
connectors map[string]storage.Connector connectors map[string]storage.Connector
deviceRequests map[string]storage.DeviceRequest
deviceTokens map[string]storage.DeviceToken
keys storage.Keys keys storage.Keys
@ -79,6 +83,18 @@ func (s *memStorage) GarbageCollect(now time.Time) (result storage.GCResult, err
result.AuthRequests++ result.AuthRequests++
} }
} }
for id, a := range s.deviceRequests {
if now.After(a.Expiry) {
delete(s.deviceRequests, id)
result.DeviceRequests++
}
}
for id, a := range s.deviceTokens {
if now.After(a.Expiry) {
delete(s.deviceTokens, id)
result.DeviceTokens++
}
}
}) })
return result, nil return result, nil
} }
@ -465,3 +481,61 @@ func (s *memStorage) UpdateConnector(id string, updater func(c storage.Connector
}) })
return return
} }
func (s *memStorage) CreateDeviceRequest(d storage.DeviceRequest) (err error) {
s.tx(func() {
if _, ok := s.deviceRequests[d.UserCode]; ok {
err = storage.ErrAlreadyExists
} else {
s.deviceRequests[d.UserCode] = d
}
})
return
}
func (s *memStorage) GetDeviceRequest(userCode string) (req storage.DeviceRequest, err error) {
s.tx(func() {
var ok bool
if req, ok = s.deviceRequests[userCode]; !ok {
err = storage.ErrNotFound
return
}
})
return
}
func (s *memStorage) CreateDeviceToken(t storage.DeviceToken) (err error) {
s.tx(func() {
if _, ok := s.deviceTokens[t.DeviceCode]; ok {
err = storage.ErrAlreadyExists
} else {
s.deviceTokens[t.DeviceCode] = t
}
})
return
}
func (s *memStorage) GetDeviceToken(deviceCode string) (t storage.DeviceToken, err error) {
s.tx(func() {
var ok bool
if t, ok = s.deviceTokens[deviceCode]; !ok {
err = storage.ErrNotFound
return
}
})
return
}
func (s *memStorage) UpdateDeviceToken(deviceCode string, updater func(p storage.DeviceToken) (storage.DeviceToken, error)) (err error) {
s.tx(func() {
r, ok := s.deviceTokens[deviceCode]
if !ok {
err = storage.ErrNotFound
return
}
if r, err = updater(r); err == nil {
s.deviceTokens[deviceCode] = r
}
})
return
}

View File

@ -100,6 +100,23 @@ func (c *conn) GarbageCollect(now time.Time) (result storage.GCResult, err error
if n, err := r.RowsAffected(); err == nil { if n, err := r.RowsAffected(); err == nil {
result.AuthCodes = n result.AuthCodes = n
} }
r, err = c.Exec(`delete from device_request where expiry < $1`, now)
if err != nil {
return result, fmt.Errorf("gc device_request: %v", err)
}
if n, err := r.RowsAffected(); err == nil {
result.DeviceRequests = n
}
r, err = c.Exec(`delete from device_token where expiry < $1`, now)
if err != nil {
return result, fmt.Errorf("gc device_token: %v", err)
}
if n, err := r.RowsAffected(); err == nil {
result.DeviceTokens = n
}
return return
} }
@ -867,3 +884,113 @@ func (c *conn) delete(table, field, id string) error {
} }
return nil return nil
} }
func (c *conn) CreateDeviceRequest(d storage.DeviceRequest) error {
_, err := c.Exec(`
insert into device_request (
user_code, device_code, client_id, client_secret, scopes, expiry
)
values (
$1, $2, $3, $4, $5, $6
);`,
d.UserCode, d.DeviceCode, d.ClientID, d.ClientSecret, encoder(d.Scopes), d.Expiry,
)
if err != nil {
if c.alreadyExistsCheck(err) {
return storage.ErrAlreadyExists
}
return fmt.Errorf("insert device request: %v", err)
}
return nil
}
func (c *conn) CreateDeviceToken(t storage.DeviceToken) error {
_, err := c.Exec(`
insert into device_token (
device_code, status, token, expiry, last_request, poll_interval
)
values (
$1, $2, $3, $4, $5, $6
);`,
t.DeviceCode, t.Status, t.Token, t.Expiry, t.LastRequestTime, t.PollIntervalSeconds,
)
if err != nil {
if c.alreadyExistsCheck(err) {
return storage.ErrAlreadyExists
}
return fmt.Errorf("insert device token: %v", err)
}
return nil
}
func (c *conn) GetDeviceRequest(userCode string) (storage.DeviceRequest, error) {
return getDeviceRequest(c, userCode)
}
func getDeviceRequest(q querier, userCode string) (d storage.DeviceRequest, err error) {
err = q.QueryRow(`
select
device_code, client_id, client_secret, scopes, expiry
from device_request where user_code = $1;
`, userCode).Scan(
&d.DeviceCode, &d.ClientID, &d.ClientSecret, decoder(&d.Scopes), &d.Expiry,
)
if err != nil {
if err == sql.ErrNoRows {
return d, storage.ErrNotFound
}
return d, fmt.Errorf("select device token: %v", err)
}
d.UserCode = userCode
return d, nil
}
func (c *conn) GetDeviceToken(deviceCode string) (storage.DeviceToken, error) {
return getDeviceToken(c, deviceCode)
}
func getDeviceToken(q querier, deviceCode string) (a storage.DeviceToken, err error) {
err = q.QueryRow(`
select
status, token, expiry, last_request, poll_interval
from device_token where device_code = $1;
`, deviceCode).Scan(
&a.Status, &a.Token, &a.Expiry, &a.LastRequestTime, &a.PollIntervalSeconds,
)
if err != nil {
if err == sql.ErrNoRows {
return a, storage.ErrNotFound
}
return a, fmt.Errorf("select device token: %v", err)
}
a.DeviceCode = deviceCode
return a, nil
}
func (c *conn) UpdateDeviceToken(deviceCode string, updater func(old storage.DeviceToken) (storage.DeviceToken, error)) error {
return c.ExecTx(func(tx *trans) error {
r, err := getDeviceToken(tx, deviceCode)
if err != nil {
return err
}
if r, err = updater(r); err != nil {
return err
}
_, err = tx.Exec(`
update device_token
set
status = $1,
token = $2,
last_request = $3,
poll_interval = $4
where
device_code = $5
`,
r.Status, r.Token, r.LastRequestTime, r.PollIntervalSeconds, r.DeviceCode,
)
if err != nil {
return fmt.Errorf("update device token: %v", err)
}
return nil
})
}

View File

@ -229,4 +229,25 @@ var migrations = []migration{
}, },
flavor: &flavorMySQL, flavor: &flavorMySQL,
}, },
{
stmts: []string{`
create table device_request (
user_code text not null primary key,
device_code text not null,
client_id text not null,
client_secret text ,
scopes bytea not null, -- JSON array of strings
expiry timestamptz not null
);`,
`
create table device_token (
device_code text not null primary key,
status text not null,
token bytea,
expiry timestamptz not null,
last_request timestamptz not null,
poll_interval integer not null
);`,
},
},
} }

View File

@ -5,6 +5,7 @@ import (
"encoding/base32" "encoding/base32"
"errors" "errors"
"io" "io"
"math/big"
"strings" "strings"
"time" "time"
@ -24,9 +25,21 @@ var (
// TODO(ericchiang): refactor ID creation onto the storage. // TODO(ericchiang): refactor ID creation onto the storage.
var encoding = base32.NewEncoding("abcdefghijklmnopqrstuvwxyz234567") var encoding = base32.NewEncoding("abcdefghijklmnopqrstuvwxyz234567")
//Valid characters for user codes
const validUserCharacters = "BCDFGHJKLMNPQRSTVWXZ"
// NewDeviceCode returns a 32 char alphanumeric cryptographically secure string
func NewDeviceCode() string {
return newSecureID(32)
}
// NewID returns a random string which can be used as an ID for objects. // NewID returns a random string which can be used as an ID for objects.
func NewID() string { func NewID() string {
buff := make([]byte, 16) // 128 bit random ID. return newSecureID(16)
}
func newSecureID(len int) string {
buff := make([]byte, len) // random ID.
if _, err := io.ReadFull(rand.Reader, buff); err != nil { if _, err := io.ReadFull(rand.Reader, buff); err != nil {
panic(err) panic(err)
} }
@ -38,6 +51,8 @@ func NewID() string {
type GCResult struct { type GCResult struct {
AuthRequests int64 AuthRequests int64
AuthCodes int64 AuthCodes int64
DeviceRequests int64
DeviceTokens int64
} }
// Storage is the storage interface used by the server. Implementations are // Storage is the storage interface used by the server. Implementations are
@ -54,6 +69,8 @@ type Storage interface {
CreatePassword(p Password) error CreatePassword(p Password) error
CreateOfflineSessions(s OfflineSessions) error CreateOfflineSessions(s OfflineSessions) error
CreateConnector(c Connector) error CreateConnector(c Connector) error
CreateDeviceRequest(d DeviceRequest) error
CreateDeviceToken(d DeviceToken) error
// TODO(ericchiang): return (T, bool, error) so we can indicate not found // TODO(ericchiang): return (T, bool, error) so we can indicate not found
// requests that way instead of using ErrNotFound. // requests that way instead of using ErrNotFound.
@ -65,6 +82,8 @@ type Storage interface {
GetPassword(email string) (Password, error) GetPassword(email string) (Password, error)
GetOfflineSessions(userID string, connID string) (OfflineSessions, error) GetOfflineSessions(userID string, connID string) (OfflineSessions, error)
GetConnector(id string) (Connector, error) GetConnector(id string) (Connector, error)
GetDeviceRequest(userCode string) (DeviceRequest, error)
GetDeviceToken(deviceCode string) (DeviceToken, error)
ListClients() ([]Client, error) ListClients() ([]Client, error)
ListRefreshTokens() ([]RefreshToken, error) ListRefreshTokens() ([]RefreshToken, error)
@ -101,8 +120,10 @@ type Storage interface {
UpdatePassword(email string, updater func(p Password) (Password, error)) error UpdatePassword(email string, updater func(p Password) (Password, error)) error
UpdateOfflineSessions(userID string, connID string, updater func(s OfflineSessions) (OfflineSessions, error)) error UpdateOfflineSessions(userID string, connID string, updater func(s OfflineSessions) (OfflineSessions, error)) error
UpdateConnector(id string, updater func(c Connector) (Connector, error)) error UpdateConnector(id string, updater func(c Connector) (Connector, error)) error
UpdateDeviceToken(deviceCode string, updater func(t DeviceToken) (DeviceToken, error)) error
// GarbageCollect deletes all expired AuthCodes and AuthRequests. // GarbageCollect deletes all expired AuthCodes,
// AuthRequests, DeviceRequests, and DeviceTokens.
GarbageCollect(now time.Time) (GCResult, error) GarbageCollect(now time.Time) (GCResult, error)
} }
@ -342,3 +363,49 @@ type Keys struct {
// For caching purposes, implementations MUST NOT update keys before this time. // For caching purposes, implementations MUST NOT update keys before this time.
NextRotation time.Time NextRotation time.Time
} }
// NewUserCode returns a randomized 8 character user code for the device flow.
// No vowels are included to prevent accidental generation of words
func NewUserCode() (string, error) {
code, err := randomString(8)
if err != nil {
return "", err
}
return code[:4] + "-" + code[4:], nil
}
func randomString(n int) (string, error) {
v := big.NewInt(int64(len(validUserCharacters)))
bytes := make([]byte, n)
for i := 0; i < n; i++ {
c, _ := rand.Int(rand.Reader, v)
bytes[i] = validUserCharacters[c.Int64()]
}
return string(bytes), nil
}
//DeviceRequest represents an OIDC device authorization request. It holds the state of a device request until the user
//authenticates using their user code or the expiry time passes.
type DeviceRequest struct {
//The code the user will enter in a browser
UserCode string
//The unique device code for device authentication
DeviceCode string
//The client ID the code is for
ClientID string
//The Client Secret
ClientSecret string
//The scopes the device requests
Scopes []string
//The expire time
Expiry time.Time
}
type DeviceToken struct {
DeviceCode string
Status string
Token string
Expiry time.Time
LastRequestTime time.Time
PollIntervalSeconds int
}

23
web/templates/device.html Normal file
View File

@ -0,0 +1,23 @@
{{ template "header.html" . }}
<div class="theme-panel">
<h2 class="theme-heading">Enter User Code</h2>
<form method="post" action="{{ .PostURL }}" method="get">
<div class="theme-form-row">
{{ if( .UserCode )}}
<input tabindex="2" required id="user_code" name="user_code" type="text" class="theme-form-input" autocomplete="off" value="{{.UserCode}}" {{ if .Invalid }} autofocus {{ end }}/>
{{ else }}
<input tabindex="2" required id="user_code" name="user_code" type="text" class="theme-form-input" placeholder="XXXX-XXXX" autocomplete="off" {{ if .Invalid }} autofocus {{ end }}/>
{{ end }}
</div>
{{ if .Invalid }}
<div id="login-error" class="dex-error-box">
Invalid or Expired User Code
</div>
{{ end }}
<button tabindex="3" id="submit-login" type="submit" class="dex-btn theme-btn--primary">Submit</button>
</form>
</div>
{{ template "footer.html" . }}

View File

@ -0,0 +1,8 @@
{{ template "header.html" . }}
<div class="theme-panel">
<h2 class="theme-heading">Login Successful for {{ .ClientName }}</h2>
<p>Return to your device to continue</p>
</div>
{{ template "footer.html" . }}