Merge pull request #1706 from justin-slowik/device_flow
Implementing the OAuth2 Device Authorization Grant
This commit is contained in:
commit
336c73c0a2
@ -279,6 +279,9 @@ type Expiry struct {
|
|||||||
|
|
||||||
// AuthRequests defines the duration of time for which the AuthRequests will be valid.
|
// AuthRequests defines the duration of time for which the AuthRequests will be valid.
|
||||||
AuthRequests string `json:"authRequests"`
|
AuthRequests string `json:"authRequests"`
|
||||||
|
|
||||||
|
// DeviceRequests defines the duration of time for which the DeviceRequests will be valid.
|
||||||
|
DeviceRequests string `json:"deviceRequests"`
|
||||||
}
|
}
|
||||||
|
|
||||||
// Logger holds configuration required to customize logging for dex.
|
// Logger holds configuration required to customize logging for dex.
|
||||||
|
@ -119,6 +119,7 @@ expiry:
|
|||||||
signingKeys: "7h"
|
signingKeys: "7h"
|
||||||
idTokens: "25h"
|
idTokens: "25h"
|
||||||
authRequests: "25h"
|
authRequests: "25h"
|
||||||
|
deviceRequests: "10m"
|
||||||
|
|
||||||
logger:
|
logger:
|
||||||
level: "debug"
|
level: "debug"
|
||||||
@ -200,6 +201,7 @@ logger:
|
|||||||
SigningKeys: "7h",
|
SigningKeys: "7h",
|
||||||
IDTokens: "25h",
|
IDTokens: "25h",
|
||||||
AuthRequests: "25h",
|
AuthRequests: "25h",
|
||||||
|
DeviceRequests: "10m",
|
||||||
},
|
},
|
||||||
Logger: Logger{
|
Logger: Logger{
|
||||||
Level: "debug",
|
Level: "debug",
|
||||||
|
@ -269,7 +269,14 @@ func serve(cmd *cobra.Command, args []string) error {
|
|||||||
logger.Infof("config auth requests valid for: %v", authRequests)
|
logger.Infof("config auth requests valid for: %v", authRequests)
|
||||||
serverConfig.AuthRequestsValidFor = authRequests
|
serverConfig.AuthRequestsValidFor = authRequests
|
||||||
}
|
}
|
||||||
|
if c.Expiry.DeviceRequests != "" {
|
||||||
|
deviceRequests, err := time.ParseDuration(c.Expiry.DeviceRequests)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("invalid config value %q for device request expiry: %v", c.Expiry.AuthRequests, err)
|
||||||
|
}
|
||||||
|
logger.Infof("config device requests valid for: %v", deviceRequests)
|
||||||
|
serverConfig.DeviceRequestsValidFor = deviceRequests
|
||||||
|
}
|
||||||
serv, err := server.NewServer(context.Background(), serverConfig)
|
serv, err := server.NewServer(context.Background(), serverConfig)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("failed to initialize server: %v", err)
|
return fmt.Errorf("failed to initialize server: %v", err)
|
||||||
|
@ -64,6 +64,7 @@ telemetry:
|
|||||||
|
|
||||||
# Uncomment this block to enable configuration for the expiration time durations.
|
# Uncomment this block to enable configuration for the expiration time durations.
|
||||||
# expiry:
|
# expiry:
|
||||||
|
# deviceRequests: "5m"
|
||||||
# signingKeys: "6h"
|
# signingKeys: "6h"
|
||||||
# idTokens: "24h"
|
# idTokens: "24h"
|
||||||
|
|
||||||
@ -95,7 +96,11 @@ staticClients:
|
|||||||
- 'http://127.0.0.1:5555/callback'
|
- 'http://127.0.0.1:5555/callback'
|
||||||
name: 'Example App'
|
name: 'Example App'
|
||||||
secret: ZXhhbXBsZS1hcHAtc2VjcmV0
|
secret: ZXhhbXBsZS1hcHAtc2VjcmV0
|
||||||
|
# - id: example-device-client
|
||||||
|
# redirectURIs:
|
||||||
|
# - /device/callback
|
||||||
|
# name: 'Static Client for Device Flow'
|
||||||
|
# public: true
|
||||||
connectors:
|
connectors:
|
||||||
- type: mockCallback
|
- type: mockCallback
|
||||||
id: mock
|
id: mock
|
||||||
|
12
scripts/manifests/crds/devicerequests.yaml
Normal file
12
scripts/manifests/crds/devicerequests.yaml
Normal file
@ -0,0 +1,12 @@
|
|||||||
|
apiVersion: apiextensions.k8s.io/v1beta1
|
||||||
|
kind: CustomResourceDefinition
|
||||||
|
metadata:
|
||||||
|
name: devicerequests.dex.coreos.com
|
||||||
|
spec:
|
||||||
|
group: dex.coreos.com
|
||||||
|
names:
|
||||||
|
kind: DeviceRequest
|
||||||
|
listKind: DeviceRequestList
|
||||||
|
plural: devicerequests
|
||||||
|
singular: devicerequest
|
||||||
|
version: v1
|
12
scripts/manifests/crds/devicetokens.yaml
Normal file
12
scripts/manifests/crds/devicetokens.yaml
Normal file
@ -0,0 +1,12 @@
|
|||||||
|
apiVersion: apiextensions.k8s.io/v1beta1
|
||||||
|
kind: CustomResourceDefinition
|
||||||
|
metadata:
|
||||||
|
name: devicetokens.dex.coreos.com
|
||||||
|
spec:
|
||||||
|
group: dex.coreos.com
|
||||||
|
names:
|
||||||
|
kind: DeviceToken
|
||||||
|
listKind: DeviceTokenList
|
||||||
|
plural: devicetokens
|
||||||
|
singular: devicetoken
|
||||||
|
version: v1
|
390
server/deviceflowhandlers.go
Normal file
390
server/deviceflowhandlers.go
Normal file
@ -0,0 +1,390 @@
|
|||||||
|
package server
|
||||||
|
|
||||||
|
import (
|
||||||
|
"encoding/json"
|
||||||
|
"errors"
|
||||||
|
"fmt"
|
||||||
|
"net/http"
|
||||||
|
"net/url"
|
||||||
|
"path"
|
||||||
|
"strconv"
|
||||||
|
"strings"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/dexidp/dex/storage"
|
||||||
|
)
|
||||||
|
|
||||||
|
type deviceCodeResponse struct {
|
||||||
|
//The unique device code for device authentication
|
||||||
|
DeviceCode string `json:"device_code"`
|
||||||
|
//The code the user will exchange via a browser and log in
|
||||||
|
UserCode string `json:"user_code"`
|
||||||
|
//The url to verify the user code.
|
||||||
|
VerificationURI string `json:"verification_uri"`
|
||||||
|
//The verification uri with the user code appended for pre-filling form
|
||||||
|
VerificationURIComplete string `json:"verification_uri_complete"`
|
||||||
|
//The lifetime of the device code
|
||||||
|
ExpireTime int `json:"expires_in"`
|
||||||
|
//How often the device is allowed to poll to verify that the user login occurred
|
||||||
|
PollInterval int `json:"interval"`
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *Server) getDeviceVerificationURI() string {
|
||||||
|
return path.Join(s.issuerURL.Path, "/device/auth/verify_code")
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *Server) handleDeviceExchange(w http.ResponseWriter, r *http.Request) {
|
||||||
|
switch r.Method {
|
||||||
|
case http.MethodGet:
|
||||||
|
// Grab the parameter(s) from the query.
|
||||||
|
// If "user_code" is set, pre-populate the user code text field.
|
||||||
|
// If "invalid" is set, set the invalidAttempt boolean, which will display a message to the user that they
|
||||||
|
// attempted to redeem an invalid or expired user code.
|
||||||
|
userCode := r.URL.Query().Get("user_code")
|
||||||
|
invalidAttempt, err := strconv.ParseBool(r.URL.Query().Get("invalid"))
|
||||||
|
if err != nil {
|
||||||
|
invalidAttempt = false
|
||||||
|
}
|
||||||
|
if err := s.templates.device(r, w, s.getDeviceVerificationURI(), userCode, invalidAttempt); err != nil {
|
||||||
|
s.logger.Errorf("Server template error: %v", err)
|
||||||
|
s.renderError(r, w, http.StatusNotFound, "Page not found")
|
||||||
|
}
|
||||||
|
default:
|
||||||
|
s.renderError(r, w, http.StatusBadRequest, "Requested resource does not exist.")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *Server) handleDeviceCode(w http.ResponseWriter, r *http.Request) {
|
||||||
|
pollIntervalSeconds := 5
|
||||||
|
|
||||||
|
switch r.Method {
|
||||||
|
case http.MethodPost:
|
||||||
|
err := r.ParseForm()
|
||||||
|
if err != nil {
|
||||||
|
s.logger.Errorf("Could not parse Device Request body: %v", err)
|
||||||
|
s.tokenErrHelper(w, errInvalidRequest, "", http.StatusNotFound)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
//Get the client id and scopes from the post
|
||||||
|
clientID := r.Form.Get("client_id")
|
||||||
|
clientSecret := r.Form.Get("client_secret")
|
||||||
|
scopes := strings.Fields(r.Form.Get("scope"))
|
||||||
|
|
||||||
|
s.logger.Infof("Received device request for client %v with scopes %v", clientID, scopes)
|
||||||
|
|
||||||
|
//Make device code
|
||||||
|
deviceCode := storage.NewDeviceCode()
|
||||||
|
|
||||||
|
//make user code
|
||||||
|
userCode, err := storage.NewUserCode()
|
||||||
|
if err != nil {
|
||||||
|
s.logger.Errorf("Error generating user code: %v", err)
|
||||||
|
s.tokenErrHelper(w, errInvalidRequest, "", http.StatusInternalServerError)
|
||||||
|
}
|
||||||
|
|
||||||
|
//Generate the expire time
|
||||||
|
expireTime := time.Now().Add(s.deviceRequestsValidFor)
|
||||||
|
|
||||||
|
//Store the Device Request
|
||||||
|
deviceReq := storage.DeviceRequest{
|
||||||
|
UserCode: userCode,
|
||||||
|
DeviceCode: deviceCode,
|
||||||
|
ClientID: clientID,
|
||||||
|
ClientSecret: clientSecret,
|
||||||
|
Scopes: scopes,
|
||||||
|
Expiry: expireTime,
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := s.storage.CreateDeviceRequest(deviceReq); err != nil {
|
||||||
|
s.logger.Errorf("Failed to store device request; %v", err)
|
||||||
|
s.tokenErrHelper(w, errInvalidRequest, "", http.StatusInternalServerError)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
//Store the device token
|
||||||
|
deviceToken := storage.DeviceToken{
|
||||||
|
DeviceCode: deviceCode,
|
||||||
|
Status: deviceTokenPending,
|
||||||
|
Expiry: expireTime,
|
||||||
|
LastRequestTime: s.now(),
|
||||||
|
PollIntervalSeconds: 0,
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := s.storage.CreateDeviceToken(deviceToken); err != nil {
|
||||||
|
s.logger.Errorf("Failed to store device token %v", err)
|
||||||
|
s.tokenErrHelper(w, errInvalidRequest, "", http.StatusInternalServerError)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
u, err := url.Parse(s.issuerURL.String())
|
||||||
|
if err != nil {
|
||||||
|
s.logger.Errorf("Could not parse issuer URL %v", err)
|
||||||
|
s.tokenErrHelper(w, errInvalidRequest, "", http.StatusInternalServerError)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
u.Path = path.Join(u.Path, "device")
|
||||||
|
vURI := u.String()
|
||||||
|
|
||||||
|
q := u.Query()
|
||||||
|
q.Set("user_code", userCode)
|
||||||
|
u.RawQuery = q.Encode()
|
||||||
|
vURIComplete := u.String()
|
||||||
|
|
||||||
|
code := deviceCodeResponse{
|
||||||
|
DeviceCode: deviceCode,
|
||||||
|
UserCode: userCode,
|
||||||
|
VerificationURI: vURI,
|
||||||
|
VerificationURIComplete: vURIComplete,
|
||||||
|
ExpireTime: int(s.deviceRequestsValidFor.Seconds()),
|
||||||
|
PollInterval: pollIntervalSeconds,
|
||||||
|
}
|
||||||
|
|
||||||
|
enc := json.NewEncoder(w)
|
||||||
|
enc.SetEscapeHTML(false)
|
||||||
|
enc.SetIndent("", " ")
|
||||||
|
enc.Encode(code)
|
||||||
|
|
||||||
|
default:
|
||||||
|
s.renderError(r, w, http.StatusBadRequest, "Invalid device code request type")
|
||||||
|
s.tokenErrHelper(w, errInvalidRequest, "", http.StatusBadRequest)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *Server) handleDeviceToken(w http.ResponseWriter, r *http.Request) {
|
||||||
|
w.Header().Set("Content-Type", "application/json")
|
||||||
|
switch r.Method {
|
||||||
|
case http.MethodPost:
|
||||||
|
err := r.ParseForm()
|
||||||
|
if err != nil {
|
||||||
|
s.logger.Warnf("Could not parse Device Token Request body: %v", err)
|
||||||
|
s.tokenErrHelper(w, errInvalidRequest, "", http.StatusBadRequest)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
deviceCode := r.Form.Get("device_code")
|
||||||
|
if deviceCode == "" {
|
||||||
|
s.tokenErrHelper(w, errInvalidRequest, "No device code received", http.StatusBadRequest)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
grantType := r.PostFormValue("grant_type")
|
||||||
|
if grantType != grantTypeDeviceCode {
|
||||||
|
s.tokenErrHelper(w, errInvalidGrant, "", http.StatusBadRequest)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
now := s.now()
|
||||||
|
|
||||||
|
//Grab the device token, check validity
|
||||||
|
deviceToken, err := s.storage.GetDeviceToken(deviceCode)
|
||||||
|
if err != nil {
|
||||||
|
if err != storage.ErrNotFound {
|
||||||
|
s.logger.Errorf("failed to get device code: %v", err)
|
||||||
|
}
|
||||||
|
s.tokenErrHelper(w, errInvalidRequest, "Invalid Device code.", http.StatusBadRequest)
|
||||||
|
return
|
||||||
|
} else if now.After(deviceToken.Expiry) {
|
||||||
|
s.tokenErrHelper(w, deviceTokenExpired, "", http.StatusBadRequest)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
//Rate Limiting check
|
||||||
|
slowDown := false
|
||||||
|
pollInterval := deviceToken.PollIntervalSeconds
|
||||||
|
minRequestTime := deviceToken.LastRequestTime.Add(time.Second * time.Duration(pollInterval))
|
||||||
|
if now.Before(minRequestTime) {
|
||||||
|
slowDown = true
|
||||||
|
//Continually increase the poll interval until the user waits the proper time
|
||||||
|
pollInterval += 5
|
||||||
|
} else {
|
||||||
|
pollInterval = 5
|
||||||
|
}
|
||||||
|
|
||||||
|
switch deviceToken.Status {
|
||||||
|
case deviceTokenPending:
|
||||||
|
updater := func(old storage.DeviceToken) (storage.DeviceToken, error) {
|
||||||
|
old.PollIntervalSeconds = pollInterval
|
||||||
|
old.LastRequestTime = now
|
||||||
|
return old, nil
|
||||||
|
}
|
||||||
|
// Update device token last request time in storage
|
||||||
|
if err := s.storage.UpdateDeviceToken(deviceCode, updater); err != nil {
|
||||||
|
s.logger.Errorf("failed to update device token: %v", err)
|
||||||
|
s.renderError(r, w, http.StatusInternalServerError, "")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if slowDown {
|
||||||
|
s.tokenErrHelper(w, deviceTokenSlowDown, "", http.StatusBadRequest)
|
||||||
|
} else {
|
||||||
|
s.tokenErrHelper(w, deviceTokenPending, "", http.StatusUnauthorized)
|
||||||
|
}
|
||||||
|
case deviceTokenComplete:
|
||||||
|
w.Write([]byte(deviceToken.Token))
|
||||||
|
}
|
||||||
|
default:
|
||||||
|
s.renderError(r, w, http.StatusBadRequest, "Requested resource does not exist.")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *Server) handleDeviceCallback(w http.ResponseWriter, r *http.Request) {
|
||||||
|
switch r.Method {
|
||||||
|
case http.MethodGet:
|
||||||
|
userCode := r.FormValue("state")
|
||||||
|
code := r.FormValue("code")
|
||||||
|
|
||||||
|
if userCode == "" || code == "" {
|
||||||
|
s.renderError(r, w, http.StatusBadRequest, "Request was missing parameters")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// Authorization redirect callback from OAuth2 auth flow.
|
||||||
|
if errMsg := r.FormValue("error"); errMsg != "" {
|
||||||
|
http.Error(w, errMsg+": "+r.FormValue("error_description"), http.StatusBadRequest)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
authCode, err := s.storage.GetAuthCode(code)
|
||||||
|
if err != nil || s.now().After(authCode.Expiry) {
|
||||||
|
errCode := http.StatusBadRequest
|
||||||
|
if err != nil && err != storage.ErrNotFound {
|
||||||
|
s.logger.Errorf("failed to get auth code: %v", err)
|
||||||
|
errCode = http.StatusInternalServerError
|
||||||
|
}
|
||||||
|
s.renderError(r, w, errCode, "Invalid or expired auth code.")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
//Grab the device request from storage
|
||||||
|
deviceReq, err := s.storage.GetDeviceRequest(userCode)
|
||||||
|
if err != nil || s.now().After(deviceReq.Expiry) {
|
||||||
|
errCode := http.StatusBadRequest
|
||||||
|
if err != nil && err != storage.ErrNotFound {
|
||||||
|
s.logger.Errorf("failed to get device code: %v", err)
|
||||||
|
errCode = http.StatusInternalServerError
|
||||||
|
}
|
||||||
|
s.renderError(r, w, errCode, "Invalid or expired user code.")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
client, err := s.storage.GetClient(deviceReq.ClientID)
|
||||||
|
if err != nil {
|
||||||
|
if err != storage.ErrNotFound {
|
||||||
|
s.logger.Errorf("failed to get client: %v", err)
|
||||||
|
s.tokenErrHelper(w, errServerError, "", http.StatusInternalServerError)
|
||||||
|
} else {
|
||||||
|
s.tokenErrHelper(w, errInvalidClient, "Invalid client credentials.", http.StatusUnauthorized)
|
||||||
|
}
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if client.Secret != deviceReq.ClientSecret {
|
||||||
|
s.tokenErrHelper(w, errInvalidClient, "Invalid client credentials.", http.StatusUnauthorized)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
resp, err := s.exchangeAuthCode(w, authCode, client)
|
||||||
|
if err != nil {
|
||||||
|
s.logger.Errorf("Could not exchange auth code for client %q: %v", deviceReq.ClientID, err)
|
||||||
|
s.renderError(r, w, http.StatusInternalServerError, "Failed to exchange auth code.")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
//Grab the device token from storage
|
||||||
|
old, err := s.storage.GetDeviceToken(deviceReq.DeviceCode)
|
||||||
|
if err != nil || s.now().After(old.Expiry) {
|
||||||
|
errCode := http.StatusBadRequest
|
||||||
|
if err != nil && err != storage.ErrNotFound {
|
||||||
|
s.logger.Errorf("failed to get device token: %v", err)
|
||||||
|
errCode = http.StatusInternalServerError
|
||||||
|
}
|
||||||
|
s.renderError(r, w, errCode, "Invalid or expired device code.")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
updater := func(old storage.DeviceToken) (storage.DeviceToken, error) {
|
||||||
|
if old.Status == deviceTokenComplete {
|
||||||
|
return old, errors.New("device token already complete")
|
||||||
|
}
|
||||||
|
respStr, err := json.MarshalIndent(resp, "", " ")
|
||||||
|
if err != nil {
|
||||||
|
s.logger.Errorf("failed to marshal device token response: %v", err)
|
||||||
|
s.renderError(r, w, http.StatusInternalServerError, "")
|
||||||
|
return old, err
|
||||||
|
}
|
||||||
|
|
||||||
|
old.Token = string(respStr)
|
||||||
|
old.Status = deviceTokenComplete
|
||||||
|
return old, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Update refresh token in the storage, store the token and mark as complete
|
||||||
|
if err := s.storage.UpdateDeviceToken(deviceReq.DeviceCode, updater); err != nil {
|
||||||
|
s.logger.Errorf("failed to update device token: %v", err)
|
||||||
|
s.renderError(r, w, http.StatusBadRequest, "")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := s.templates.deviceSuccess(r, w, client.Name); err != nil {
|
||||||
|
s.logger.Errorf("Server template error: %v", err)
|
||||||
|
s.renderError(r, w, http.StatusNotFound, "Page not found")
|
||||||
|
}
|
||||||
|
|
||||||
|
default:
|
||||||
|
http.Error(w, fmt.Sprintf("method not implemented: %s", r.Method), http.StatusBadRequest)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *Server) verifyUserCode(w http.ResponseWriter, r *http.Request) {
|
||||||
|
switch r.Method {
|
||||||
|
case http.MethodPost:
|
||||||
|
err := r.ParseForm()
|
||||||
|
if err != nil {
|
||||||
|
s.logger.Warnf("Could not parse user code verification request body : %v", err)
|
||||||
|
s.renderError(r, w, http.StatusBadRequest, "")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
userCode := r.Form.Get("user_code")
|
||||||
|
if userCode == "" {
|
||||||
|
s.renderError(r, w, http.StatusBadRequest, "No user code received")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
userCode = strings.ToUpper(userCode)
|
||||||
|
|
||||||
|
//Find the user code in the available requests
|
||||||
|
deviceRequest, err := s.storage.GetDeviceRequest(userCode)
|
||||||
|
if err != nil || s.now().After(deviceRequest.Expiry) {
|
||||||
|
if err != nil && err != storage.ErrNotFound {
|
||||||
|
s.logger.Errorf("failed to get device request: %v", err)
|
||||||
|
}
|
||||||
|
if err := s.templates.device(r, w, s.getDeviceVerificationURI(), userCode, true); err != nil {
|
||||||
|
s.logger.Errorf("Server template error: %v", err)
|
||||||
|
s.renderError(r, w, http.StatusNotFound, "Page not found")
|
||||||
|
}
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
//Redirect to Dex Auth Endpoint
|
||||||
|
authURL := path.Join(s.issuerURL.Path, "/auth")
|
||||||
|
u, err := url.Parse(authURL)
|
||||||
|
if err != nil {
|
||||||
|
s.renderError(r, w, http.StatusInternalServerError, "Invalid auth URI.")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
q := u.Query()
|
||||||
|
q.Set("client_id", deviceRequest.ClientID)
|
||||||
|
q.Set("client_secret", deviceRequest.ClientSecret)
|
||||||
|
q.Set("state", deviceRequest.UserCode)
|
||||||
|
q.Set("response_type", "code")
|
||||||
|
q.Set("redirect_uri", "/device/callback")
|
||||||
|
q.Set("scope", strings.Join(deviceRequest.Scopes, " "))
|
||||||
|
u.RawQuery = q.Encode()
|
||||||
|
|
||||||
|
http.Redirect(w, r, u.String(), http.StatusFound)
|
||||||
|
|
||||||
|
default:
|
||||||
|
s.renderError(r, w, http.StatusBadRequest, "Requested resource does not exist.")
|
||||||
|
}
|
||||||
|
}
|
678
server/deviceflowhandlers_test.go
Normal file
678
server/deviceflowhandlers_test.go
Normal file
@ -0,0 +1,678 @@
|
|||||||
|
package server
|
||||||
|
|
||||||
|
import (
|
||||||
|
"bytes"
|
||||||
|
"context"
|
||||||
|
"encoding/json"
|
||||||
|
"io/ioutil"
|
||||||
|
"net/http"
|
||||||
|
"net/http/httptest"
|
||||||
|
"net/url"
|
||||||
|
"path"
|
||||||
|
"strings"
|
||||||
|
"testing"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/dexidp/dex/storage"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestDeviceVerificationURI(t *testing.T) {
|
||||||
|
t0 := time.Now()
|
||||||
|
|
||||||
|
now := func() time.Time { return t0 }
|
||||||
|
ctx, cancel := context.WithCancel(context.Background())
|
||||||
|
defer cancel()
|
||||||
|
// Setup a dex server.
|
||||||
|
httpServer, s := newTestServer(ctx, t, func(c *Config) {
|
||||||
|
c.Issuer = c.Issuer + "/non-root-path"
|
||||||
|
c.Now = now
|
||||||
|
})
|
||||||
|
defer httpServer.Close()
|
||||||
|
|
||||||
|
u, err := url.Parse(s.issuerURL.String())
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Could not parse issuer URL %v", err)
|
||||||
|
}
|
||||||
|
u.Path = path.Join(u.Path, "/device/auth/verify_code")
|
||||||
|
|
||||||
|
uri := s.getDeviceVerificationURI()
|
||||||
|
if uri != u.Path {
|
||||||
|
t.Errorf("Invalid verification URI. Expected %v got %v", u.Path, uri)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestHandleDeviceCode(t *testing.T) {
|
||||||
|
t0 := time.Now()
|
||||||
|
|
||||||
|
now := func() time.Time { return t0 }
|
||||||
|
|
||||||
|
tests := []struct {
|
||||||
|
testName string
|
||||||
|
clientID string
|
||||||
|
requestType string
|
||||||
|
scopes []string
|
||||||
|
expectedResponseCode int
|
||||||
|
expectedServerResponse string
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
testName: "New Code",
|
||||||
|
clientID: "test",
|
||||||
|
requestType: "POST",
|
||||||
|
scopes: []string{"openid", "profile", "email"},
|
||||||
|
expectedResponseCode: http.StatusOK,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
testName: "Invalid request Type (GET)",
|
||||||
|
clientID: "test",
|
||||||
|
requestType: "GET",
|
||||||
|
scopes: []string{"openid", "profile", "email"},
|
||||||
|
expectedResponseCode: http.StatusBadRequest,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
for _, tc := range tests {
|
||||||
|
t.Run(tc.testName, func(t *testing.T) {
|
||||||
|
ctx, cancel := context.WithCancel(context.Background())
|
||||||
|
defer cancel()
|
||||||
|
|
||||||
|
// Setup a dex server.
|
||||||
|
httpServer, s := newTestServer(ctx, t, func(c *Config) {
|
||||||
|
c.Issuer = c.Issuer + "/non-root-path"
|
||||||
|
c.Now = now
|
||||||
|
})
|
||||||
|
defer httpServer.Close()
|
||||||
|
|
||||||
|
u, err := url.Parse(s.issuerURL.String())
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Could not parse issuer URL %v", err)
|
||||||
|
}
|
||||||
|
u.Path = path.Join(u.Path, "device/code")
|
||||||
|
|
||||||
|
data := url.Values{}
|
||||||
|
data.Set("client_id", tc.clientID)
|
||||||
|
for _, scope := range tc.scopes {
|
||||||
|
data.Add("scope", scope)
|
||||||
|
}
|
||||||
|
req, _ := http.NewRequest(tc.requestType, u.String(), bytes.NewBufferString(data.Encode()))
|
||||||
|
req.Header.Set("Content-Type", "application/x-www-form-urlencoded; param=value")
|
||||||
|
|
||||||
|
rr := httptest.NewRecorder()
|
||||||
|
s.ServeHTTP(rr, req)
|
||||||
|
if rr.Code != tc.expectedResponseCode {
|
||||||
|
t.Errorf("Unexpected Response Type. Expected %v got %v", tc.expectedResponseCode, rr.Code)
|
||||||
|
}
|
||||||
|
|
||||||
|
body, err := ioutil.ReadAll(rr.Body)
|
||||||
|
if err != nil {
|
||||||
|
t.Errorf("Could read token response %v", err)
|
||||||
|
}
|
||||||
|
if tc.expectedResponseCode == http.StatusOK {
|
||||||
|
var resp deviceCodeResponse
|
||||||
|
if err := json.Unmarshal(body, &resp); err != nil {
|
||||||
|
t.Errorf("Unexpected Device Code Response Format %v", string(body))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestDeviceCallback(t *testing.T) {
|
||||||
|
t0 := time.Now()
|
||||||
|
|
||||||
|
now := func() time.Time { return t0 }
|
||||||
|
|
||||||
|
type formValues struct {
|
||||||
|
state string
|
||||||
|
code string
|
||||||
|
error string
|
||||||
|
}
|
||||||
|
|
||||||
|
// Base "Control" test values
|
||||||
|
baseFormValues := formValues{
|
||||||
|
state: "XXXX-XXXX",
|
||||||
|
code: "somecode",
|
||||||
|
}
|
||||||
|
baseAuthCode := storage.AuthCode{
|
||||||
|
ID: "somecode",
|
||||||
|
ClientID: "testclient",
|
||||||
|
RedirectURI: deviceCallbackURI,
|
||||||
|
Nonce: "",
|
||||||
|
Scopes: []string{"openid", "profile", "email"},
|
||||||
|
ConnectorID: "mock",
|
||||||
|
ConnectorData: nil,
|
||||||
|
Claims: storage.Claims{},
|
||||||
|
Expiry: now().Add(5 * time.Minute),
|
||||||
|
}
|
||||||
|
baseDeviceRequest := storage.DeviceRequest{
|
||||||
|
UserCode: "XXXX-XXXX",
|
||||||
|
DeviceCode: "devicecode",
|
||||||
|
ClientID: "testclient",
|
||||||
|
ClientSecret: "",
|
||||||
|
Scopes: []string{"openid", "profile", "email"},
|
||||||
|
Expiry: now().Add(5 * time.Minute),
|
||||||
|
}
|
||||||
|
baseDeviceToken := storage.DeviceToken{
|
||||||
|
DeviceCode: "devicecode",
|
||||||
|
Status: deviceTokenPending,
|
||||||
|
Token: "",
|
||||||
|
Expiry: now().Add(5 * time.Minute),
|
||||||
|
LastRequestTime: time.Time{},
|
||||||
|
PollIntervalSeconds: 0,
|
||||||
|
}
|
||||||
|
|
||||||
|
tests := []struct {
|
||||||
|
testName string
|
||||||
|
expectedResponseCode int
|
||||||
|
values formValues
|
||||||
|
testAuthCode storage.AuthCode
|
||||||
|
testDeviceRequest storage.DeviceRequest
|
||||||
|
testDeviceToken storage.DeviceToken
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
testName: "Missing State",
|
||||||
|
values: formValues{
|
||||||
|
state: "",
|
||||||
|
code: "somecode",
|
||||||
|
error: "",
|
||||||
|
},
|
||||||
|
expectedResponseCode: http.StatusBadRequest,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
testName: "Missing Code",
|
||||||
|
values: formValues{
|
||||||
|
state: "XXXX-XXXX",
|
||||||
|
code: "",
|
||||||
|
error: "",
|
||||||
|
},
|
||||||
|
expectedResponseCode: http.StatusBadRequest,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
testName: "Error During Authorization",
|
||||||
|
values: formValues{
|
||||||
|
state: "XXXX-XXXX",
|
||||||
|
code: "somecode",
|
||||||
|
error: "Error Condition",
|
||||||
|
},
|
||||||
|
expectedResponseCode: http.StatusBadRequest,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
testName: "Expired Auth Code",
|
||||||
|
values: baseFormValues,
|
||||||
|
testAuthCode: storage.AuthCode{
|
||||||
|
ID: "somecode",
|
||||||
|
ClientID: "testclient",
|
||||||
|
RedirectURI: deviceCallbackURI,
|
||||||
|
Nonce: "",
|
||||||
|
Scopes: []string{"openid", "profile", "email"},
|
||||||
|
ConnectorID: "pic",
|
||||||
|
ConnectorData: nil,
|
||||||
|
Claims: storage.Claims{},
|
||||||
|
Expiry: now().Add(-5 * time.Minute),
|
||||||
|
},
|
||||||
|
expectedResponseCode: http.StatusBadRequest,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
testName: "Invalid Auth Code",
|
||||||
|
values: baseFormValues,
|
||||||
|
testAuthCode: storage.AuthCode{
|
||||||
|
ID: "somecode",
|
||||||
|
ClientID: "testclient",
|
||||||
|
RedirectURI: deviceCallbackURI,
|
||||||
|
Nonce: "",
|
||||||
|
Scopes: []string{"openid", "profile", "email"},
|
||||||
|
ConnectorID: "pic",
|
||||||
|
ConnectorData: nil,
|
||||||
|
Claims: storage.Claims{},
|
||||||
|
Expiry: now().Add(5 * time.Minute),
|
||||||
|
},
|
||||||
|
expectedResponseCode: http.StatusBadRequest,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
testName: "Expired Device Request",
|
||||||
|
values: baseFormValues,
|
||||||
|
testAuthCode: baseAuthCode,
|
||||||
|
testDeviceRequest: storage.DeviceRequest{
|
||||||
|
UserCode: "XXXX-XXXX",
|
||||||
|
DeviceCode: "devicecode",
|
||||||
|
ClientID: "testclient",
|
||||||
|
Scopes: []string{"openid", "profile", "email"},
|
||||||
|
Expiry: now().Add(-5 * time.Minute),
|
||||||
|
},
|
||||||
|
expectedResponseCode: http.StatusBadRequest,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
testName: "Non-Existent User Code",
|
||||||
|
values: baseFormValues,
|
||||||
|
testAuthCode: baseAuthCode,
|
||||||
|
testDeviceRequest: storage.DeviceRequest{
|
||||||
|
UserCode: "ZZZZ-ZZZZ",
|
||||||
|
DeviceCode: "devicecode",
|
||||||
|
Scopes: []string{"openid", "profile", "email"},
|
||||||
|
Expiry: now().Add(5 * time.Minute),
|
||||||
|
},
|
||||||
|
expectedResponseCode: http.StatusBadRequest,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
testName: "Bad Device Request Client",
|
||||||
|
values: baseFormValues,
|
||||||
|
testAuthCode: baseAuthCode,
|
||||||
|
testDeviceRequest: storage.DeviceRequest{
|
||||||
|
UserCode: "XXXX-XXXX",
|
||||||
|
DeviceCode: "devicecode",
|
||||||
|
Scopes: []string{"openid", "profile", "email"},
|
||||||
|
Expiry: now().Add(5 * time.Minute),
|
||||||
|
},
|
||||||
|
expectedResponseCode: http.StatusUnauthorized,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
testName: "Bad Device Request Secret",
|
||||||
|
values: baseFormValues,
|
||||||
|
testAuthCode: baseAuthCode,
|
||||||
|
testDeviceRequest: storage.DeviceRequest{
|
||||||
|
UserCode: "XXXX-XXXX",
|
||||||
|
DeviceCode: "devicecode",
|
||||||
|
ClientSecret: "foobar",
|
||||||
|
Scopes: []string{"openid", "profile", "email"},
|
||||||
|
Expiry: now().Add(5 * time.Minute),
|
||||||
|
},
|
||||||
|
expectedResponseCode: http.StatusUnauthorized,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
testName: "Expired Device Token",
|
||||||
|
values: baseFormValues,
|
||||||
|
testAuthCode: baseAuthCode,
|
||||||
|
testDeviceRequest: baseDeviceRequest,
|
||||||
|
testDeviceToken: storage.DeviceToken{
|
||||||
|
DeviceCode: "devicecode",
|
||||||
|
Status: deviceTokenPending,
|
||||||
|
Token: "",
|
||||||
|
Expiry: now().Add(-5 * time.Minute),
|
||||||
|
LastRequestTime: time.Time{},
|
||||||
|
PollIntervalSeconds: 0,
|
||||||
|
},
|
||||||
|
expectedResponseCode: http.StatusBadRequest,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
testName: "Device Code Already Redeemed",
|
||||||
|
values: baseFormValues,
|
||||||
|
testAuthCode: baseAuthCode,
|
||||||
|
testDeviceRequest: baseDeviceRequest,
|
||||||
|
testDeviceToken: storage.DeviceToken{
|
||||||
|
DeviceCode: "devicecode",
|
||||||
|
Status: deviceTokenComplete,
|
||||||
|
Token: "",
|
||||||
|
Expiry: now().Add(5 * time.Minute),
|
||||||
|
LastRequestTime: time.Time{},
|
||||||
|
PollIntervalSeconds: 0,
|
||||||
|
},
|
||||||
|
expectedResponseCode: http.StatusBadRequest,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
testName: "Successful Exchange",
|
||||||
|
values: baseFormValues,
|
||||||
|
testAuthCode: baseAuthCode,
|
||||||
|
testDeviceRequest: baseDeviceRequest,
|
||||||
|
testDeviceToken: baseDeviceToken,
|
||||||
|
expectedResponseCode: http.StatusOK,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
for _, tc := range tests {
|
||||||
|
t.Run(tc.testName, func(t *testing.T) {
|
||||||
|
ctx, cancel := context.WithCancel(context.Background())
|
||||||
|
defer cancel()
|
||||||
|
|
||||||
|
// Setup a dex server.
|
||||||
|
httpServer, s := newTestServer(ctx, t, func(c *Config) {
|
||||||
|
//c.Issuer = c.Issuer + "/non-root-path"
|
||||||
|
c.Now = now
|
||||||
|
})
|
||||||
|
defer httpServer.Close()
|
||||||
|
|
||||||
|
if err := s.storage.CreateAuthCode(tc.testAuthCode); err != nil {
|
||||||
|
t.Fatalf("failed to create auth code: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := s.storage.CreateDeviceRequest(tc.testDeviceRequest); err != nil {
|
||||||
|
t.Fatalf("failed to create device request: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := s.storage.CreateDeviceToken(tc.testDeviceToken); err != nil {
|
||||||
|
t.Fatalf("failed to create device token: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
client := storage.Client{
|
||||||
|
ID: "testclient",
|
||||||
|
Secret: "",
|
||||||
|
RedirectURIs: []string{deviceCallbackURI},
|
||||||
|
}
|
||||||
|
if err := s.storage.CreateClient(client); err != nil {
|
||||||
|
t.Fatalf("failed to create client: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
u, err := url.Parse(s.issuerURL.String())
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Could not parse issuer URL %v", err)
|
||||||
|
}
|
||||||
|
u.Path = path.Join(u.Path, "device/callback")
|
||||||
|
q := u.Query()
|
||||||
|
q.Set("state", tc.values.state)
|
||||||
|
q.Set("code", tc.values.code)
|
||||||
|
q.Set("error", tc.values.error)
|
||||||
|
u.RawQuery = q.Encode()
|
||||||
|
req, _ := http.NewRequest("GET", u.String(), nil)
|
||||||
|
req.Header.Set("Content-Type", "application/x-www-form-urlencoded; param=value")
|
||||||
|
|
||||||
|
rr := httptest.NewRecorder()
|
||||||
|
s.ServeHTTP(rr, req)
|
||||||
|
if rr.Code != tc.expectedResponseCode {
|
||||||
|
t.Errorf("%s: Unexpected Response Type. Expected %v got %v", tc.testName, tc.expectedResponseCode, rr.Code)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestDeviceTokenResponse(t *testing.T) {
|
||||||
|
t0 := time.Now()
|
||||||
|
|
||||||
|
now := func() time.Time { return t0 }
|
||||||
|
|
||||||
|
baseDeviceRequest := storage.DeviceRequest{
|
||||||
|
UserCode: "ABCD-WXYZ",
|
||||||
|
DeviceCode: "foo",
|
||||||
|
ClientID: "testclient",
|
||||||
|
Scopes: []string{"openid", "profile", "offline_access"},
|
||||||
|
Expiry: now().Add(5 * time.Minute),
|
||||||
|
}
|
||||||
|
|
||||||
|
tests := []struct {
|
||||||
|
testName string
|
||||||
|
testDeviceRequest storage.DeviceRequest
|
||||||
|
testDeviceToken storage.DeviceToken
|
||||||
|
testGrantType string
|
||||||
|
testDeviceCode string
|
||||||
|
expectedServerResponse string
|
||||||
|
expectedResponseCode int
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
testName: "Valid but pending token",
|
||||||
|
testDeviceRequest: baseDeviceRequest,
|
||||||
|
testDeviceToken: storage.DeviceToken{
|
||||||
|
DeviceCode: "f00bar",
|
||||||
|
Status: deviceTokenPending,
|
||||||
|
Token: "",
|
||||||
|
Expiry: now().Add(5 * time.Minute),
|
||||||
|
LastRequestTime: time.Time{},
|
||||||
|
PollIntervalSeconds: 0,
|
||||||
|
},
|
||||||
|
testDeviceCode: "f00bar",
|
||||||
|
expectedServerResponse: deviceTokenPending,
|
||||||
|
expectedResponseCode: http.StatusUnauthorized,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
testName: "Invalid Grant Type",
|
||||||
|
testDeviceRequest: baseDeviceRequest,
|
||||||
|
testDeviceToken: storage.DeviceToken{
|
||||||
|
DeviceCode: "f00bar",
|
||||||
|
Status: deviceTokenPending,
|
||||||
|
Token: "",
|
||||||
|
Expiry: now().Add(5 * time.Minute),
|
||||||
|
LastRequestTime: time.Time{},
|
||||||
|
PollIntervalSeconds: 0,
|
||||||
|
},
|
||||||
|
testDeviceCode: "f00bar",
|
||||||
|
testGrantType: grantTypeAuthorizationCode,
|
||||||
|
expectedServerResponse: errInvalidGrant,
|
||||||
|
expectedResponseCode: http.StatusBadRequest,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
testName: "Test Slow Down State",
|
||||||
|
testDeviceRequest: baseDeviceRequest,
|
||||||
|
testDeviceToken: storage.DeviceToken{
|
||||||
|
DeviceCode: "f00bar",
|
||||||
|
Status: deviceTokenPending,
|
||||||
|
Token: "",
|
||||||
|
Expiry: now().Add(5 * time.Minute),
|
||||||
|
LastRequestTime: now(),
|
||||||
|
PollIntervalSeconds: 10,
|
||||||
|
},
|
||||||
|
testDeviceCode: "f00bar",
|
||||||
|
expectedServerResponse: deviceTokenSlowDown,
|
||||||
|
expectedResponseCode: http.StatusBadRequest,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
testName: "Test Expired Device Token",
|
||||||
|
testDeviceRequest: baseDeviceRequest,
|
||||||
|
testDeviceToken: storage.DeviceToken{
|
||||||
|
DeviceCode: "f00bar",
|
||||||
|
Status: deviceTokenPending,
|
||||||
|
Token: "",
|
||||||
|
Expiry: now().Add(-5 * time.Minute),
|
||||||
|
LastRequestTime: time.Time{},
|
||||||
|
PollIntervalSeconds: 0,
|
||||||
|
},
|
||||||
|
testDeviceCode: "f00bar",
|
||||||
|
expectedServerResponse: deviceTokenExpired,
|
||||||
|
expectedResponseCode: http.StatusBadRequest,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
testName: "Test Non-existent Device Code",
|
||||||
|
testDeviceRequest: baseDeviceRequest,
|
||||||
|
testDeviceToken: storage.DeviceToken{
|
||||||
|
DeviceCode: "foo",
|
||||||
|
Status: deviceTokenPending,
|
||||||
|
Token: "",
|
||||||
|
Expiry: now().Add(-5 * time.Minute),
|
||||||
|
LastRequestTime: time.Time{},
|
||||||
|
PollIntervalSeconds: 0,
|
||||||
|
},
|
||||||
|
testDeviceCode: "bar",
|
||||||
|
expectedServerResponse: errInvalidRequest,
|
||||||
|
expectedResponseCode: http.StatusBadRequest,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
testName: "Empty Device Code in Request",
|
||||||
|
testDeviceRequest: baseDeviceRequest,
|
||||||
|
testDeviceToken: storage.DeviceToken{
|
||||||
|
DeviceCode: "bar",
|
||||||
|
Status: deviceTokenPending,
|
||||||
|
Token: "",
|
||||||
|
Expiry: now().Add(-5 * time.Minute),
|
||||||
|
LastRequestTime: time.Time{},
|
||||||
|
PollIntervalSeconds: 0,
|
||||||
|
},
|
||||||
|
testDeviceCode: "",
|
||||||
|
expectedServerResponse: errInvalidRequest,
|
||||||
|
expectedResponseCode: http.StatusBadRequest,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
testName: "Claim validated token from Device Code",
|
||||||
|
testDeviceRequest: baseDeviceRequest,
|
||||||
|
testDeviceToken: storage.DeviceToken{
|
||||||
|
DeviceCode: "foo",
|
||||||
|
Status: deviceTokenComplete,
|
||||||
|
Token: "{\"access_token\": \"foobar\"}",
|
||||||
|
Expiry: now().Add(5 * time.Minute),
|
||||||
|
LastRequestTime: time.Time{},
|
||||||
|
PollIntervalSeconds: 0,
|
||||||
|
},
|
||||||
|
testDeviceCode: "foo",
|
||||||
|
expectedServerResponse: "{\"access_token\": \"foobar\"}",
|
||||||
|
expectedResponseCode: http.StatusOK,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
for _, tc := range tests {
|
||||||
|
t.Run(tc.testName, func(t *testing.T) {
|
||||||
|
ctx, cancel := context.WithCancel(context.Background())
|
||||||
|
defer cancel()
|
||||||
|
|
||||||
|
// Setup a dex server.
|
||||||
|
httpServer, s := newTestServer(ctx, t, func(c *Config) {
|
||||||
|
c.Issuer = c.Issuer + "/non-root-path"
|
||||||
|
c.Now = now
|
||||||
|
})
|
||||||
|
defer httpServer.Close()
|
||||||
|
|
||||||
|
if err := s.storage.CreateDeviceRequest(tc.testDeviceRequest); err != nil {
|
||||||
|
t.Fatalf("Failed to store device token %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := s.storage.CreateDeviceToken(tc.testDeviceToken); err != nil {
|
||||||
|
t.Fatalf("Failed to store device token %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
u, err := url.Parse(s.issuerURL.String())
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Could not parse issuer URL %v", err)
|
||||||
|
}
|
||||||
|
u.Path = path.Join(u.Path, "device/token")
|
||||||
|
|
||||||
|
data := url.Values{}
|
||||||
|
grantType := grantTypeDeviceCode
|
||||||
|
if tc.testGrantType != "" {
|
||||||
|
grantType = tc.testGrantType
|
||||||
|
}
|
||||||
|
data.Set("grant_type", grantType)
|
||||||
|
data.Set("device_code", tc.testDeviceCode)
|
||||||
|
req, _ := http.NewRequest("POST", u.String(), bytes.NewBufferString(data.Encode()))
|
||||||
|
req.Header.Set("Content-Type", "application/x-www-form-urlencoded; param=value")
|
||||||
|
|
||||||
|
rr := httptest.NewRecorder()
|
||||||
|
s.ServeHTTP(rr, req)
|
||||||
|
if rr.Code != tc.expectedResponseCode {
|
||||||
|
t.Errorf("Unexpected Response Type. Expected %v got %v", tc.expectedResponseCode, rr.Code)
|
||||||
|
}
|
||||||
|
|
||||||
|
body, err := ioutil.ReadAll(rr.Body)
|
||||||
|
if err != nil {
|
||||||
|
t.Errorf("Could read token response %v", err)
|
||||||
|
}
|
||||||
|
if tc.expectedResponseCode == http.StatusBadRequest || tc.expectedResponseCode == http.StatusUnauthorized {
|
||||||
|
expectJsonErrorResponse(tc.testName, body, tc.expectedServerResponse, t)
|
||||||
|
} else if string(body) != tc.expectedServerResponse {
|
||||||
|
t.Errorf("Unexpected Server Response. Expected %v got %v", tc.expectedServerResponse, string(body))
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func expectJsonErrorResponse(testCase string, body []byte, expectedError string, t *testing.T) {
|
||||||
|
jsonMap := make(map[string]interface{})
|
||||||
|
err := json.Unmarshal(body, &jsonMap)
|
||||||
|
if err != nil {
|
||||||
|
t.Errorf("Unexpected error unmarshalling response: %v", err)
|
||||||
|
}
|
||||||
|
if jsonMap["error"] != expectedError {
|
||||||
|
t.Errorf("Test Case %s expected error %v, received %v", testCase, expectedError, jsonMap["error"])
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestVerifyCodeResponse(t *testing.T) {
|
||||||
|
t0 := time.Now()
|
||||||
|
|
||||||
|
now := func() time.Time { return t0 }
|
||||||
|
|
||||||
|
tests := []struct {
|
||||||
|
testName string
|
||||||
|
testDeviceRequest storage.DeviceRequest
|
||||||
|
userCode string
|
||||||
|
expectedResponseCode int
|
||||||
|
expectedRedirectPath string
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
testName: "Unknown user code",
|
||||||
|
testDeviceRequest: storage.DeviceRequest{
|
||||||
|
UserCode: "ABCD-WXYZ",
|
||||||
|
DeviceCode: "f00bar",
|
||||||
|
ClientID: "testclient",
|
||||||
|
Scopes: []string{"openid", "profile", "offline_access"},
|
||||||
|
Expiry: now().Add(5 * time.Minute),
|
||||||
|
},
|
||||||
|
userCode: "CODE-TEST",
|
||||||
|
expectedResponseCode: http.StatusBadRequest,
|
||||||
|
expectedRedirectPath: "",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
testName: "Expired user code",
|
||||||
|
testDeviceRequest: storage.DeviceRequest{
|
||||||
|
UserCode: "ABCD-WXYZ",
|
||||||
|
DeviceCode: "f00bar",
|
||||||
|
ClientID: "testclient",
|
||||||
|
Scopes: []string{"openid", "profile", "offline_access"},
|
||||||
|
Expiry: now().Add(-5 * time.Minute),
|
||||||
|
},
|
||||||
|
userCode: "ABCD-WXYZ",
|
||||||
|
expectedResponseCode: http.StatusBadRequest,
|
||||||
|
expectedRedirectPath: "",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
testName: "No user code",
|
||||||
|
testDeviceRequest: storage.DeviceRequest{
|
||||||
|
UserCode: "ABCD-WXYZ",
|
||||||
|
DeviceCode: "f00bar",
|
||||||
|
ClientID: "testclient",
|
||||||
|
Scopes: []string{"openid", "profile", "offline_access"},
|
||||||
|
Expiry: now().Add(-5 * time.Minute),
|
||||||
|
},
|
||||||
|
userCode: "",
|
||||||
|
expectedResponseCode: http.StatusBadRequest,
|
||||||
|
expectedRedirectPath: "",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
testName: "Valid user code, expect redirect to auth endpoint",
|
||||||
|
testDeviceRequest: storage.DeviceRequest{
|
||||||
|
UserCode: "ABCD-WXYZ",
|
||||||
|
DeviceCode: "f00bar",
|
||||||
|
ClientID: "testclient",
|
||||||
|
Scopes: []string{"openid", "profile", "offline_access"},
|
||||||
|
Expiry: now().Add(5 * time.Minute),
|
||||||
|
},
|
||||||
|
userCode: "ABCD-WXYZ",
|
||||||
|
expectedResponseCode: http.StatusFound,
|
||||||
|
expectedRedirectPath: "/auth",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
for _, tc := range tests {
|
||||||
|
t.Run(tc.testName, func(t *testing.T) {
|
||||||
|
ctx, cancel := context.WithCancel(context.Background())
|
||||||
|
defer cancel()
|
||||||
|
|
||||||
|
// Setup a dex server.
|
||||||
|
httpServer, s := newTestServer(ctx, t, func(c *Config) {
|
||||||
|
c.Issuer = c.Issuer + "/non-root-path"
|
||||||
|
c.Now = now
|
||||||
|
})
|
||||||
|
defer httpServer.Close()
|
||||||
|
|
||||||
|
if err := s.storage.CreateDeviceRequest(tc.testDeviceRequest); err != nil {
|
||||||
|
t.Fatalf("Failed to store device token %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
u, err := url.Parse(s.issuerURL.String())
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Could not parse issuer URL %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
u.Path = path.Join(u.Path, "device/auth/verify_code")
|
||||||
|
data := url.Values{}
|
||||||
|
data.Set("user_code", tc.userCode)
|
||||||
|
req, _ := http.NewRequest("POST", u.String(), bytes.NewBufferString(data.Encode()))
|
||||||
|
req.Header.Set("Content-Type", "application/x-www-form-urlencoded; param=value")
|
||||||
|
|
||||||
|
rr := httptest.NewRecorder()
|
||||||
|
s.ServeHTTP(rr, req)
|
||||||
|
if rr.Code != tc.expectedResponseCode {
|
||||||
|
t.Errorf("Unexpected Response Type. Expected %v got %v", tc.expectedResponseCode, rr.Code)
|
||||||
|
}
|
||||||
|
|
||||||
|
u, err = url.Parse(s.issuerURL.String())
|
||||||
|
if err != nil {
|
||||||
|
t.Errorf("Could not parse issuer URL %v", err)
|
||||||
|
}
|
||||||
|
u.Path = path.Join(u.Path, tc.expectedRedirectPath)
|
||||||
|
|
||||||
|
location := rr.Header().Get("Location")
|
||||||
|
if rr.Code == http.StatusFound && !strings.HasPrefix(location, u.Path) {
|
||||||
|
t.Errorf("Invalid Redirect. Expected %v got %v", u.Path, location)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
@ -153,6 +153,8 @@ type discovery struct {
|
|||||||
Token string `json:"token_endpoint"`
|
Token string `json:"token_endpoint"`
|
||||||
Keys string `json:"jwks_uri"`
|
Keys string `json:"jwks_uri"`
|
||||||
UserInfo string `json:"userinfo_endpoint"`
|
UserInfo string `json:"userinfo_endpoint"`
|
||||||
|
DeviceEndpoint string `json:"device_authorization_endpoint"`
|
||||||
|
GrantTypes []string `json:"grant_types_supported"`
|
||||||
ResponseTypes []string `json:"response_types_supported"`
|
ResponseTypes []string `json:"response_types_supported"`
|
||||||
Subjects []string `json:"subject_types_supported"`
|
Subjects []string `json:"subject_types_supported"`
|
||||||
IDTokenAlgs []string `json:"id_token_signing_alg_values_supported"`
|
IDTokenAlgs []string `json:"id_token_signing_alg_values_supported"`
|
||||||
@ -168,7 +170,9 @@ func (s *Server) discoveryHandler() (http.HandlerFunc, error) {
|
|||||||
Token: s.absURL("/token"),
|
Token: s.absURL("/token"),
|
||||||
Keys: s.absURL("/keys"),
|
Keys: s.absURL("/keys"),
|
||||||
UserInfo: s.absURL("/userinfo"),
|
UserInfo: s.absURL("/userinfo"),
|
||||||
|
DeviceEndpoint: s.absURL("/device/code"),
|
||||||
Subjects: []string{"public"},
|
Subjects: []string{"public"},
|
||||||
|
GrantTypes: []string{grantTypeAuthorizationCode, grantTypeRefreshToken, grantTypeDeviceCode},
|
||||||
IDTokenAlgs: []string{string(jose.RS256)},
|
IDTokenAlgs: []string{string(jose.RS256)},
|
||||||
Scopes: []string{"openid", "email", "groups", "profile", "offline_access"},
|
Scopes: []string{"openid", "email", "groups", "profile", "offline_access"},
|
||||||
AuthMethods: []string{"client_secret_basic"},
|
AuthMethods: []string{"client_secret_basic"},
|
||||||
@ -784,24 +788,33 @@ func (s *Server) handleAuthCode(w http.ResponseWriter, r *http.Request, client s
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
tokenResponse, err := s.exchangeAuthCode(w, authCode, client)
|
||||||
|
if err != nil {
|
||||||
|
s.tokenErrHelper(w, errServerError, "", http.StatusInternalServerError)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
s.writeAccessToken(w, tokenResponse)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *Server) exchangeAuthCode(w http.ResponseWriter, authCode storage.AuthCode, client storage.Client) (*accessTokenReponse, error) {
|
||||||
accessToken, err := s.newAccessToken(client.ID, authCode.Claims, authCode.Scopes, authCode.Nonce, authCode.ConnectorID)
|
accessToken, err := s.newAccessToken(client.ID, authCode.Claims, authCode.Scopes, authCode.Nonce, authCode.ConnectorID)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
s.logger.Errorf("failed to create new access token: %v", err)
|
s.logger.Errorf("failed to create new access token: %v", err)
|
||||||
s.tokenErrHelper(w, errServerError, "", http.StatusInternalServerError)
|
s.tokenErrHelper(w, errServerError, "", http.StatusInternalServerError)
|
||||||
return
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
idToken, expiry, err := s.newIDToken(client.ID, authCode.Claims, authCode.Scopes, authCode.Nonce, accessToken, authCode.ConnectorID)
|
idToken, expiry, err := s.newIDToken(client.ID, authCode.Claims, authCode.Scopes, authCode.Nonce, accessToken, authCode.ConnectorID)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
s.logger.Errorf("failed to create ID token: %v", err)
|
s.logger.Errorf("failed to create ID token: %v", err)
|
||||||
s.tokenErrHelper(w, errServerError, "", http.StatusInternalServerError)
|
s.tokenErrHelper(w, errServerError, "", http.StatusInternalServerError)
|
||||||
return
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
if err := s.storage.DeleteAuthCode(code); err != nil {
|
if err := s.storage.DeleteAuthCode(authCode.ID); err != nil {
|
||||||
s.logger.Errorf("failed to delete auth code: %v", err)
|
s.logger.Errorf("failed to delete auth code: %v", err)
|
||||||
s.tokenErrHelper(w, errServerError, "", http.StatusInternalServerError)
|
s.tokenErrHelper(w, errServerError, "", http.StatusInternalServerError)
|
||||||
return
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
reqRefresh := func() bool {
|
reqRefresh := func() bool {
|
||||||
@ -848,13 +861,13 @@ func (s *Server) handleAuthCode(w http.ResponseWriter, r *http.Request, client s
|
|||||||
if refreshToken, err = internal.Marshal(token); err != nil {
|
if refreshToken, err = internal.Marshal(token); err != nil {
|
||||||
s.logger.Errorf("failed to marshal refresh token: %v", err)
|
s.logger.Errorf("failed to marshal refresh token: %v", err)
|
||||||
s.tokenErrHelper(w, errServerError, "", http.StatusInternalServerError)
|
s.tokenErrHelper(w, errServerError, "", http.StatusInternalServerError)
|
||||||
return
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
if err := s.storage.CreateRefresh(refresh); err != nil {
|
if err := s.storage.CreateRefresh(refresh); err != nil {
|
||||||
s.logger.Errorf("failed to create refresh token: %v", err)
|
s.logger.Errorf("failed to create refresh token: %v", err)
|
||||||
s.tokenErrHelper(w, errServerError, "", http.StatusInternalServerError)
|
s.tokenErrHelper(w, errServerError, "", http.StatusInternalServerError)
|
||||||
return
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
// deleteToken determines if we need to delete the newly created refresh token
|
// deleteToken determines if we need to delete the newly created refresh token
|
||||||
@ -885,7 +898,7 @@ func (s *Server) handleAuthCode(w http.ResponseWriter, r *http.Request, client s
|
|||||||
s.logger.Errorf("failed to get offline session: %v", err)
|
s.logger.Errorf("failed to get offline session: %v", err)
|
||||||
s.tokenErrHelper(w, errServerError, "", http.StatusInternalServerError)
|
s.tokenErrHelper(w, errServerError, "", http.StatusInternalServerError)
|
||||||
deleteToken = true
|
deleteToken = true
|
||||||
return
|
return nil, err
|
||||||
}
|
}
|
||||||
offlineSessions := storage.OfflineSessions{
|
offlineSessions := storage.OfflineSessions{
|
||||||
UserID: refresh.Claims.UserID,
|
UserID: refresh.Claims.UserID,
|
||||||
@ -900,7 +913,7 @@ func (s *Server) handleAuthCode(w http.ResponseWriter, r *http.Request, client s
|
|||||||
s.logger.Errorf("failed to create offline session: %v", err)
|
s.logger.Errorf("failed to create offline session: %v", err)
|
||||||
s.tokenErrHelper(w, errServerError, "", http.StatusInternalServerError)
|
s.tokenErrHelper(w, errServerError, "", http.StatusInternalServerError)
|
||||||
deleteToken = true
|
deleteToken = true
|
||||||
return
|
return nil, err
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
if oldTokenRef, ok := session.Refresh[tokenRef.ClientID]; ok {
|
if oldTokenRef, ok := session.Refresh[tokenRef.ClientID]; ok {
|
||||||
@ -909,7 +922,7 @@ func (s *Server) handleAuthCode(w http.ResponseWriter, r *http.Request, client s
|
|||||||
s.logger.Errorf("failed to delete refresh token: %v", err)
|
s.logger.Errorf("failed to delete refresh token: %v", err)
|
||||||
s.tokenErrHelper(w, errServerError, "", http.StatusInternalServerError)
|
s.tokenErrHelper(w, errServerError, "", http.StatusInternalServerError)
|
||||||
deleteToken = true
|
deleteToken = true
|
||||||
return
|
return nil, err
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -921,11 +934,11 @@ func (s *Server) handleAuthCode(w http.ResponseWriter, r *http.Request, client s
|
|||||||
s.logger.Errorf("failed to update offline session: %v", err)
|
s.logger.Errorf("failed to update offline session: %v", err)
|
||||||
s.tokenErrHelper(w, errServerError, "", http.StatusInternalServerError)
|
s.tokenErrHelper(w, errServerError, "", http.StatusInternalServerError)
|
||||||
deleteToken = true
|
deleteToken = true
|
||||||
return
|
return nil, err
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
s.writeAccessToken(w, idToken, accessToken, refreshToken, expiry)
|
return s.toAccessTokenResponse(idToken, accessToken, refreshToken, expiry), nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// handle a refresh token request https://tools.ietf.org/html/rfc6749#section-6
|
// handle a refresh token request https://tools.ietf.org/html/rfc6749#section-6
|
||||||
@ -1121,7 +1134,8 @@ func (s *Server) handleRefreshToken(w http.ResponseWriter, r *http.Request, clie
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
s.writeAccessToken(w, idToken, accessToken, rawNewToken, expiry)
|
resp := s.toAccessTokenResponse(idToken, accessToken, rawNewToken, expiry)
|
||||||
|
s.writeAccessToken(w, resp)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *Server) handleUserInfo(w http.ResponseWriter, r *http.Request) {
|
func (s *Server) handleUserInfo(w http.ResponseWriter, r *http.Request) {
|
||||||
@ -1368,23 +1382,29 @@ func (s *Server) handlePasswordGrant(w http.ResponseWriter, r *http.Request, cli
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
s.writeAccessToken(w, idToken, accessToken, refreshToken, expiry)
|
resp := s.toAccessTokenResponse(idToken, accessToken, refreshToken, expiry)
|
||||||
|
s.writeAccessToken(w, resp)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *Server) writeAccessToken(w http.ResponseWriter, idToken, accessToken, refreshToken string, expiry time.Time) {
|
type accessTokenReponse struct {
|
||||||
resp := struct {
|
|
||||||
AccessToken string `json:"access_token"`
|
AccessToken string `json:"access_token"`
|
||||||
TokenType string `json:"token_type"`
|
TokenType string `json:"token_type"`
|
||||||
ExpiresIn int `json:"expires_in"`
|
ExpiresIn int `json:"expires_in"`
|
||||||
RefreshToken string `json:"refresh_token,omitempty"`
|
RefreshToken string `json:"refresh_token,omitempty"`
|
||||||
IDToken string `json:"id_token"`
|
IDToken string `json:"id_token"`
|
||||||
}{
|
}
|
||||||
|
|
||||||
|
func (s *Server) toAccessTokenResponse(idToken, accessToken, refreshToken string, expiry time.Time) *accessTokenReponse {
|
||||||
|
return &accessTokenReponse{
|
||||||
accessToken,
|
accessToken,
|
||||||
"bearer",
|
"bearer",
|
||||||
int(expiry.Sub(s.now()).Seconds()),
|
int(expiry.Sub(s.now()).Seconds()),
|
||||||
refreshToken,
|
refreshToken,
|
||||||
idToken,
|
idToken,
|
||||||
}
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *Server) writeAccessToken(w http.ResponseWriter, resp *accessTokenReponse) {
|
||||||
data, err := json.Marshal(resp)
|
data, err := json.Marshal(resp)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
s.logger.Errorf("failed to marshal access token response: %v", err)
|
s.logger.Errorf("failed to marshal access token response: %v", err)
|
||||||
|
@ -114,6 +114,10 @@ const (
|
|||||||
scopeCrossClientPrefix = "audience:server:client_id:"
|
scopeCrossClientPrefix = "audience:server:client_id:"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
const (
|
||||||
|
deviceCallbackURI = "/device/callback"
|
||||||
|
)
|
||||||
|
|
||||||
const (
|
const (
|
||||||
redirectURIOOB = "urn:ietf:wg:oauth:2.0:oob"
|
redirectURIOOB = "urn:ietf:wg:oauth:2.0:oob"
|
||||||
)
|
)
|
||||||
@ -122,6 +126,7 @@ const (
|
|||||||
grantTypeAuthorizationCode = "authorization_code"
|
grantTypeAuthorizationCode = "authorization_code"
|
||||||
grantTypeRefreshToken = "refresh_token"
|
grantTypeRefreshToken = "refresh_token"
|
||||||
grantTypePassword = "password"
|
grantTypePassword = "password"
|
||||||
|
grantTypeDeviceCode = "urn:ietf:params:oauth:grant-type:device_code"
|
||||||
)
|
)
|
||||||
|
|
||||||
const (
|
const (
|
||||||
@ -130,6 +135,13 @@ const (
|
|||||||
responseTypeIDToken = "id_token" // ID Token in url fragment
|
responseTypeIDToken = "id_token" // ID Token in url fragment
|
||||||
)
|
)
|
||||||
|
|
||||||
|
const (
|
||||||
|
deviceTokenPending = "authorization_pending"
|
||||||
|
deviceTokenComplete = "complete"
|
||||||
|
deviceTokenSlowDown = "slow_down"
|
||||||
|
deviceTokenExpired = "expired_token"
|
||||||
|
)
|
||||||
|
|
||||||
func parseScopes(scopes []string) connector.Scopes {
|
func parseScopes(scopes []string) connector.Scopes {
|
||||||
var s connector.Scopes
|
var s connector.Scopes
|
||||||
for _, scope := range scopes {
|
for _, scope := range scopes {
|
||||||
@ -425,6 +437,9 @@ func (s *Server) parseAuthorizationRequest(r *http.Request) (*storage.AuthReques
|
|||||||
description := fmt.Sprintf("Unregistered redirect_uri (%q).", redirectURI)
|
description := fmt.Sprintf("Unregistered redirect_uri (%q).", redirectURI)
|
||||||
return nil, &authErr{"", "", errInvalidRequest, description}
|
return nil, &authErr{"", "", errInvalidRequest, description}
|
||||||
}
|
}
|
||||||
|
if redirectURI == deviceCallbackURI && client.Public {
|
||||||
|
redirectURI = s.issuerURL.Path + deviceCallbackURI
|
||||||
|
}
|
||||||
|
|
||||||
// From here on out, we want to redirect back to the client with an error.
|
// From here on out, we want to redirect back to the client with an error.
|
||||||
newErr := func(typ, format string, a ...interface{}) *authErr {
|
newErr := func(typ, format string, a ...interface{}) *authErr {
|
||||||
@ -566,7 +581,7 @@ func validateRedirectURI(client storage.Client, redirectURI string) bool {
|
|||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
|
|
||||||
if redirectURI == redirectURIOOB {
|
if redirectURI == redirectURIOOB || redirectURI == deviceCallbackURI {
|
||||||
return true
|
return true
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -78,6 +78,7 @@ type Config struct {
|
|||||||
RotateKeysAfter time.Duration // Defaults to 6 hours.
|
RotateKeysAfter time.Duration // Defaults to 6 hours.
|
||||||
IDTokensValidFor time.Duration // Defaults to 24 hours
|
IDTokensValidFor time.Duration // Defaults to 24 hours
|
||||||
AuthRequestsValidFor time.Duration // Defaults to 24 hours
|
AuthRequestsValidFor time.Duration // Defaults to 24 hours
|
||||||
|
DeviceRequestsValidFor time.Duration // Defaults to 5 minutes
|
||||||
// If set, the server will use this connector to handle password grants
|
// If set, the server will use this connector to handle password grants
|
||||||
PasswordConnector string
|
PasswordConnector string
|
||||||
|
|
||||||
@ -158,6 +159,7 @@ type Server struct {
|
|||||||
|
|
||||||
idTokensValidFor time.Duration
|
idTokensValidFor time.Duration
|
||||||
authRequestsValidFor time.Duration
|
authRequestsValidFor time.Duration
|
||||||
|
deviceRequestsValidFor time.Duration
|
||||||
|
|
||||||
logger log.Logger
|
logger log.Logger
|
||||||
}
|
}
|
||||||
@ -219,6 +221,7 @@ func newServer(ctx context.Context, c Config, rotationStrategy rotationStrategy)
|
|||||||
supportedResponseTypes: supported,
|
supportedResponseTypes: supported,
|
||||||
idTokensValidFor: value(c.IDTokensValidFor, 24*time.Hour),
|
idTokensValidFor: value(c.IDTokensValidFor, 24*time.Hour),
|
||||||
authRequestsValidFor: value(c.AuthRequestsValidFor, 24*time.Hour),
|
authRequestsValidFor: value(c.AuthRequestsValidFor, 24*time.Hour),
|
||||||
|
deviceRequestsValidFor: value(c.DeviceRequestsValidFor, 5*time.Minute),
|
||||||
skipApproval: c.SkipApprovalScreen,
|
skipApproval: c.SkipApprovalScreen,
|
||||||
alwaysShowLogin: c.AlwaysShowLoginScreen,
|
alwaysShowLogin: c.AlwaysShowLoginScreen,
|
||||||
now: now,
|
now: now,
|
||||||
@ -302,6 +305,11 @@ func newServer(ctx context.Context, c Config, rotationStrategy rotationStrategy)
|
|||||||
handleWithCORS("/userinfo", s.handleUserInfo)
|
handleWithCORS("/userinfo", s.handleUserInfo)
|
||||||
handleFunc("/auth", s.handleAuthorization)
|
handleFunc("/auth", s.handleAuthorization)
|
||||||
handleFunc("/auth/{connector}", s.handleConnectorLogin)
|
handleFunc("/auth/{connector}", s.handleConnectorLogin)
|
||||||
|
handleFunc("/device", s.handleDeviceExchange)
|
||||||
|
handleFunc("/device/auth/verify_code", s.verifyUserCode)
|
||||||
|
handleFunc("/device/code", s.handleDeviceCode)
|
||||||
|
handleFunc("/device/token", s.handleDeviceToken)
|
||||||
|
handleFunc(deviceCallbackURI, s.handleDeviceCallback)
|
||||||
r.HandleFunc(path.Join(issuerURL.Path, "/callback"), func(w http.ResponseWriter, r *http.Request) {
|
r.HandleFunc(path.Join(issuerURL.Path, "/callback"), func(w http.ResponseWriter, r *http.Request) {
|
||||||
// Strip the X-Remote-* headers to prevent security issues on
|
// Strip the X-Remote-* headers to prevent security issues on
|
||||||
// misconfigured authproxy connector setups.
|
// misconfigured authproxy connector setups.
|
||||||
@ -450,7 +458,8 @@ func (s *Server) startGarbageCollection(ctx context.Context, frequency time.Dura
|
|||||||
if r, err := s.storage.GarbageCollect(now()); err != nil {
|
if r, err := s.storage.GarbageCollect(now()); err != nil {
|
||||||
s.logger.Errorf("garbage collection failed: %v", err)
|
s.logger.Errorf("garbage collection failed: %v", err)
|
||||||
} else if r.AuthRequests > 0 || r.AuthCodes > 0 {
|
} else if r.AuthRequests > 0 || r.AuthCodes > 0 {
|
||||||
s.logger.Infof("garbage collection run, delete auth requests=%d, auth codes=%d", r.AuthRequests, r.AuthCodes)
|
s.logger.Infof("garbage collection run, delete auth requests=%d, auth codes=%d, device requests =%d, device tokens=%d",
|
||||||
|
r.AuthRequests, r.AuthCodes, r.DeviceRequests, r.DeviceTokens)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -8,11 +8,13 @@ import (
|
|||||||
"encoding/pem"
|
"encoding/pem"
|
||||||
"errors"
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
|
"io/ioutil"
|
||||||
"net/http"
|
"net/http"
|
||||||
"net/http/httptest"
|
"net/http/httptest"
|
||||||
"net/http/httputil"
|
"net/http/httputil"
|
||||||
"net/url"
|
"net/url"
|
||||||
"os"
|
"os"
|
||||||
|
"path"
|
||||||
"reflect"
|
"reflect"
|
||||||
"sort"
|
"sort"
|
||||||
"strings"
|
"strings"
|
||||||
@ -203,41 +205,36 @@ func TestDiscovery(t *testing.T) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// TestOAuth2CodeFlow runs integration tests against a test server. The tests stand up a server
|
type oauth2Tests struct {
|
||||||
// which requires no interaction to login, logs in through a test client, then passes the client
|
clientID string
|
||||||
// and returned token to the test.
|
tests []test
|
||||||
func TestOAuth2CodeFlow(t *testing.T) {
|
}
|
||||||
clientID := "testclient"
|
|
||||||
clientSecret := "testclientsecret"
|
type test struct {
|
||||||
|
name string
|
||||||
|
// If specified these set of scopes will be used during the test case.
|
||||||
|
scopes []string
|
||||||
|
// handleToken provides the OAuth2 token response for the integration test.
|
||||||
|
handleToken func(context.Context, *oidc.Provider, *oauth2.Config, *oauth2.Token, *mock.Callback) error
|
||||||
|
}
|
||||||
|
|
||||||
|
func makeOAuth2Tests(clientID string, clientSecret string, now func() time.Time) oauth2Tests {
|
||||||
requestedScopes := []string{oidc.ScopeOpenID, "email", "profile", "groups", "offline_access"}
|
requestedScopes := []string{oidc.ScopeOpenID, "email", "profile", "groups", "offline_access"}
|
||||||
|
|
||||||
t0 := time.Now()
|
|
||||||
|
|
||||||
// Always have the time function used by the server return the same time so
|
|
||||||
// we can predict expected values of "expires_in" fields exactly.
|
|
||||||
now := func() time.Time { return t0 }
|
|
||||||
|
|
||||||
// Used later when configuring test servers to set how long id_tokens will be valid for.
|
// Used later when configuring test servers to set how long id_tokens will be valid for.
|
||||||
//
|
//
|
||||||
// The actual value of 30s is completely arbitrary. We just need to set a value
|
// The actual value of 30s is completely arbitrary. We just need to set a value
|
||||||
// so tests can compute the expected "expires_in" field.
|
// so tests can compute the expected "expires_in" field.
|
||||||
idTokensValidFor := time.Second * 30
|
idTokensValidFor := time.Second * 30
|
||||||
|
|
||||||
// Connector used by the tests.
|
|
||||||
var conn *mock.Callback
|
|
||||||
|
|
||||||
oidcConfig := &oidc.Config{SkipClientIDCheck: true}
|
oidcConfig := &oidc.Config{SkipClientIDCheck: true}
|
||||||
|
|
||||||
tests := []struct {
|
return oauth2Tests{
|
||||||
name string
|
clientID: clientID,
|
||||||
// If specified these set of scopes will be used during the test case.
|
tests: []test{
|
||||||
scopes []string
|
|
||||||
// handleToken provides the OAuth2 token response for the integration test.
|
|
||||||
handleToken func(context.Context, *oidc.Provider, *oauth2.Config, *oauth2.Token) error
|
|
||||||
}{
|
|
||||||
{
|
{
|
||||||
name: "verify ID Token",
|
name: "verify ID Token",
|
||||||
handleToken: func(ctx context.Context, p *oidc.Provider, config *oauth2.Config, token *oauth2.Token) error {
|
handleToken: func(ctx context.Context, p *oidc.Provider, config *oauth2.Config, token *oauth2.Token, conn *mock.Callback) error {
|
||||||
idToken, ok := token.Extra("id_token").(string)
|
idToken, ok := token.Extra("id_token").(string)
|
||||||
if !ok {
|
if !ok {
|
||||||
return fmt.Errorf("no id token found")
|
return fmt.Errorf("no id token found")
|
||||||
@ -250,7 +247,7 @@ func TestOAuth2CodeFlow(t *testing.T) {
|
|||||||
},
|
},
|
||||||
{
|
{
|
||||||
name: "fetch userinfo",
|
name: "fetch userinfo",
|
||||||
handleToken: func(ctx context.Context, p *oidc.Provider, config *oauth2.Config, token *oauth2.Token) error {
|
handleToken: func(ctx context.Context, p *oidc.Provider, config *oauth2.Config, token *oauth2.Token, conn *mock.Callback) error {
|
||||||
ui, err := p.UserInfo(ctx, config.TokenSource(ctx, token))
|
ui, err := p.UserInfo(ctx, config.TokenSource(ctx, token))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("failed to fetch userinfo: %v", err)
|
return fmt.Errorf("failed to fetch userinfo: %v", err)
|
||||||
@ -263,7 +260,7 @@ func TestOAuth2CodeFlow(t *testing.T) {
|
|||||||
},
|
},
|
||||||
{
|
{
|
||||||
name: "verify id token and oauth2 token expiry",
|
name: "verify id token and oauth2 token expiry",
|
||||||
handleToken: func(ctx context.Context, p *oidc.Provider, config *oauth2.Config, token *oauth2.Token) error {
|
handleToken: func(ctx context.Context, p *oidc.Provider, config *oauth2.Config, token *oauth2.Token, conn *mock.Callback) error {
|
||||||
expectedExpiry := now().Add(idTokensValidFor)
|
expectedExpiry := now().Add(idTokensValidFor)
|
||||||
|
|
||||||
timeEq := func(t1, t2 time.Time, within time.Duration) bool {
|
timeEq := func(t1, t2 time.Time, within time.Duration) bool {
|
||||||
@ -290,7 +287,7 @@ func TestOAuth2CodeFlow(t *testing.T) {
|
|||||||
},
|
},
|
||||||
{
|
{
|
||||||
name: "verify at_hash",
|
name: "verify at_hash",
|
||||||
handleToken: func(ctx context.Context, p *oidc.Provider, config *oauth2.Config, token *oauth2.Token) error {
|
handleToken: func(ctx context.Context, p *oidc.Provider, config *oauth2.Config, token *oauth2.Token, conn *mock.Callback) error {
|
||||||
rawIDToken, ok := token.Extra("id_token").(string)
|
rawIDToken, ok := token.Extra("id_token").(string)
|
||||||
if !ok {
|
if !ok {
|
||||||
return fmt.Errorf("no id token found")
|
return fmt.Errorf("no id token found")
|
||||||
@ -322,7 +319,7 @@ func TestOAuth2CodeFlow(t *testing.T) {
|
|||||||
},
|
},
|
||||||
{
|
{
|
||||||
name: "refresh token",
|
name: "refresh token",
|
||||||
handleToken: func(ctx context.Context, p *oidc.Provider, config *oauth2.Config, token *oauth2.Token) error {
|
handleToken: func(ctx context.Context, p *oidc.Provider, config *oauth2.Config, token *oauth2.Token, conn *mock.Callback) error {
|
||||||
// have to use time.Now because the OAuth2 package uses it.
|
// have to use time.Now because the OAuth2 package uses it.
|
||||||
token.Expiry = time.Now().Add(time.Second * -10)
|
token.Expiry = time.Now().Add(time.Second * -10)
|
||||||
if token.Valid() {
|
if token.Valid() {
|
||||||
@ -345,7 +342,7 @@ func TestOAuth2CodeFlow(t *testing.T) {
|
|||||||
},
|
},
|
||||||
{
|
{
|
||||||
name: "refresh with explicit scopes",
|
name: "refresh with explicit scopes",
|
||||||
handleToken: func(ctx context.Context, p *oidc.Provider, config *oauth2.Config, token *oauth2.Token) error {
|
handleToken: func(ctx context.Context, p *oidc.Provider, config *oauth2.Config, token *oauth2.Token, conn *mock.Callback) error {
|
||||||
v := url.Values{}
|
v := url.Values{}
|
||||||
v.Add("client_id", clientID)
|
v.Add("client_id", clientID)
|
||||||
v.Add("client_secret", clientSecret)
|
v.Add("client_secret", clientSecret)
|
||||||
@ -369,7 +366,7 @@ func TestOAuth2CodeFlow(t *testing.T) {
|
|||||||
},
|
},
|
||||||
{
|
{
|
||||||
name: "refresh with extra spaces",
|
name: "refresh with extra spaces",
|
||||||
handleToken: func(ctx context.Context, p *oidc.Provider, config *oauth2.Config, token *oauth2.Token) error {
|
handleToken: func(ctx context.Context, p *oidc.Provider, config *oauth2.Config, token *oauth2.Token, conn *mock.Callback) error {
|
||||||
v := url.Values{}
|
v := url.Values{}
|
||||||
v.Add("client_id", clientID)
|
v.Add("client_id", clientID)
|
||||||
v.Add("client_secret", clientSecret)
|
v.Add("client_secret", clientSecret)
|
||||||
@ -398,7 +395,7 @@ func TestOAuth2CodeFlow(t *testing.T) {
|
|||||||
{
|
{
|
||||||
name: "refresh with unauthorized scopes",
|
name: "refresh with unauthorized scopes",
|
||||||
scopes: []string{"openid", "email"},
|
scopes: []string{"openid", "email"},
|
||||||
handleToken: func(ctx context.Context, p *oidc.Provider, config *oauth2.Config, token *oauth2.Token) error {
|
handleToken: func(ctx context.Context, p *oidc.Provider, config *oauth2.Config, token *oauth2.Token, conn *mock.Callback) error {
|
||||||
v := url.Values{}
|
v := url.Values{}
|
||||||
v.Add("client_id", clientID)
|
v.Add("client_id", clientID)
|
||||||
v.Add("client_secret", clientSecret)
|
v.Add("client_secret", clientSecret)
|
||||||
@ -425,7 +422,7 @@ func TestOAuth2CodeFlow(t *testing.T) {
|
|||||||
// This test ensures that the connector.RefreshConnector interface is being
|
// This test ensures that the connector.RefreshConnector interface is being
|
||||||
// used when clients request a refresh token.
|
// used when clients request a refresh token.
|
||||||
name: "refresh with identity changes",
|
name: "refresh with identity changes",
|
||||||
handleToken: func(ctx context.Context, p *oidc.Provider, config *oauth2.Config, token *oauth2.Token) error {
|
handleToken: func(ctx context.Context, p *oidc.Provider, config *oauth2.Config, token *oauth2.Token, conn *mock.Callback) error {
|
||||||
// have to use time.Now because the OAuth2 package uses it.
|
// have to use time.Now because the OAuth2 package uses it.
|
||||||
token.Expiry = time.Now().Add(time.Second * -10)
|
token.Expiry = time.Now().Add(time.Second * -10)
|
||||||
if token.Valid() {
|
if token.Valid() {
|
||||||
@ -472,9 +469,35 @@ func TestOAuth2CodeFlow(t *testing.T) {
|
|||||||
return nil
|
return nil
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
for _, tc := range tests {
|
// TestOAuth2CodeFlow runs integration tests against a test server. The tests stand up a server
|
||||||
|
// which requires no interaction to login, logs in through a test client, then passes the client
|
||||||
|
// and returned token to the test.
|
||||||
|
func TestOAuth2CodeFlow(t *testing.T) {
|
||||||
|
clientID := "testclient"
|
||||||
|
clientSecret := "testclientsecret"
|
||||||
|
requestedScopes := []string{oidc.ScopeOpenID, "email", "profile", "groups", "offline_access"}
|
||||||
|
|
||||||
|
t0 := time.Now()
|
||||||
|
|
||||||
|
// Always have the time function used by the server return the same time so
|
||||||
|
// we can predict expected values of "expires_in" fields exactly.
|
||||||
|
now := func() time.Time { return t0 }
|
||||||
|
|
||||||
|
// Used later when configuring test servers to set how long id_tokens will be valid for.
|
||||||
|
//
|
||||||
|
// The actual value of 30s is completely arbitrary. We just need to set a value
|
||||||
|
// so tests can compute the expected "expires_in" field.
|
||||||
|
idTokensValidFor := time.Second * 30
|
||||||
|
|
||||||
|
// Connector used by the tests.
|
||||||
|
var conn *mock.Callback
|
||||||
|
|
||||||
|
tests := makeOAuth2Tests(clientID, clientSecret, now)
|
||||||
|
for _, tc := range tests.tests {
|
||||||
func() {
|
func() {
|
||||||
ctx, cancel := context.WithCancel(context.Background())
|
ctx, cancel := context.WithCancel(context.Background())
|
||||||
defer cancel()
|
defer cancel()
|
||||||
@ -540,7 +563,7 @@ func TestOAuth2CodeFlow(t *testing.T) {
|
|||||||
t.Errorf("failed to exchange code for token: %v", err)
|
t.Errorf("failed to exchange code for token: %v", err)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
err = tc.handleToken(ctx, p, oauth2Config, token)
|
err = tc.handleToken(ctx, p, oauth2Config, token, conn)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Errorf("%s: %v", tc.name, err)
|
t.Errorf("%s: %v", tc.name, err)
|
||||||
}
|
}
|
||||||
@ -1253,3 +1276,157 @@ func TestRefreshTokenFlow(t *testing.T) {
|
|||||||
t.Errorf("Token refreshed with invalid refresh token, error expected.")
|
t.Errorf("Token refreshed with invalid refresh token, error expected.")
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// TestOAuth2DeviceFlow runs device flow integration tests against a test server
|
||||||
|
func TestOAuth2DeviceFlow(t *testing.T) {
|
||||||
|
clientID := "testclient"
|
||||||
|
clientSecret := ""
|
||||||
|
requestedScopes := []string{oidc.ScopeOpenID, "email", "profile", "groups", "offline_access"}
|
||||||
|
|
||||||
|
t0 := time.Now()
|
||||||
|
|
||||||
|
// Always have the time function used by the server return the same time so
|
||||||
|
// we can predict expected values of "expires_in" fields exactly.
|
||||||
|
now := func() time.Time { return t0 }
|
||||||
|
|
||||||
|
// Connector used by the tests.
|
||||||
|
var conn *mock.Callback
|
||||||
|
idTokensValidFor := time.Second * 30
|
||||||
|
|
||||||
|
for _, tc := range makeOAuth2Tests(clientID, clientSecret, now).tests {
|
||||||
|
func() {
|
||||||
|
ctx, cancel := context.WithCancel(context.Background())
|
||||||
|
defer cancel()
|
||||||
|
|
||||||
|
// Setup a dex server.
|
||||||
|
httpServer, s := newTestServer(ctx, t, func(c *Config) {
|
||||||
|
c.Issuer = c.Issuer + "/non-root-path"
|
||||||
|
c.Now = now
|
||||||
|
c.IDTokensValidFor = idTokensValidFor
|
||||||
|
})
|
||||||
|
defer httpServer.Close()
|
||||||
|
|
||||||
|
mockConn := s.connectors["mock"]
|
||||||
|
conn = mockConn.Connector.(*mock.Callback)
|
||||||
|
|
||||||
|
p, err := oidc.NewProvider(ctx, httpServer.URL)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("failed to get provider: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
//Add the Clients to the test server
|
||||||
|
client := storage.Client{
|
||||||
|
ID: clientID,
|
||||||
|
RedirectURIs: []string{deviceCallbackURI},
|
||||||
|
Public: true,
|
||||||
|
}
|
||||||
|
if err := s.storage.CreateClient(client); err != nil {
|
||||||
|
t.Fatalf("failed to create client: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
//Grab the issuer that we'll reuse for the different endpoints to hit
|
||||||
|
issuer, err := url.Parse(s.issuerURL.String())
|
||||||
|
if err != nil {
|
||||||
|
t.Errorf("Could not parse issuer URL %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
//Send a new Device Request
|
||||||
|
codeURL, _ := url.Parse(issuer.String())
|
||||||
|
codeURL.Path = path.Join(codeURL.Path, "device/code")
|
||||||
|
|
||||||
|
data := url.Values{}
|
||||||
|
data.Set("client_id", clientID)
|
||||||
|
data.Add("scope", strings.Join(requestedScopes, " "))
|
||||||
|
resp, err := http.PostForm(codeURL.String(), data)
|
||||||
|
if err != nil {
|
||||||
|
t.Errorf("Could not request device code: %v", err)
|
||||||
|
}
|
||||||
|
defer resp.Body.Close()
|
||||||
|
responseBody, err := ioutil.ReadAll(resp.Body)
|
||||||
|
if err != nil {
|
||||||
|
t.Errorf("Could read device code response %v", err)
|
||||||
|
}
|
||||||
|
if resp.StatusCode != http.StatusOK {
|
||||||
|
t.Errorf("%v - Unexpected Response Type. Expected 200 got %v. Response: %v", tc.name, resp.StatusCode, string(responseBody))
|
||||||
|
}
|
||||||
|
|
||||||
|
//Parse the code response
|
||||||
|
var deviceCode deviceCodeResponse
|
||||||
|
if err := json.Unmarshal(responseBody, &deviceCode); err != nil {
|
||||||
|
t.Errorf("Unexpected Device Code Response Format %v", string(responseBody))
|
||||||
|
}
|
||||||
|
|
||||||
|
//Mock the user hitting the verification URI and posting the form
|
||||||
|
verifyURL, _ := url.Parse(issuer.String())
|
||||||
|
verifyURL.Path = path.Join(verifyURL.Path, "/device/auth/verify_code")
|
||||||
|
urlData := url.Values{}
|
||||||
|
urlData.Set("user_code", deviceCode.UserCode)
|
||||||
|
resp, err = http.PostForm(verifyURL.String(), urlData)
|
||||||
|
if err != nil {
|
||||||
|
t.Errorf("Error Posting Form: %v", err)
|
||||||
|
}
|
||||||
|
defer resp.Body.Close()
|
||||||
|
responseBody, err = ioutil.ReadAll(resp.Body)
|
||||||
|
if err != nil {
|
||||||
|
t.Errorf("Could read verification response %v", err)
|
||||||
|
}
|
||||||
|
if resp.StatusCode != http.StatusOK {
|
||||||
|
t.Errorf("%v - Unexpected Response Type. Expected 200 got %v. Response: %v", tc.name, resp.StatusCode, string(responseBody))
|
||||||
|
}
|
||||||
|
|
||||||
|
//Hit the Token Endpoint, and try and get an access token
|
||||||
|
tokenURL, _ := url.Parse(issuer.String())
|
||||||
|
tokenURL.Path = path.Join(tokenURL.Path, "/device/token")
|
||||||
|
v := url.Values{}
|
||||||
|
v.Add("grant_type", grantTypeDeviceCode)
|
||||||
|
v.Add("device_code", deviceCode.DeviceCode)
|
||||||
|
resp, err = http.PostForm(tokenURL.String(), v)
|
||||||
|
if err != nil {
|
||||||
|
t.Errorf("Could not request device token: %v", err)
|
||||||
|
}
|
||||||
|
defer resp.Body.Close()
|
||||||
|
responseBody, err = ioutil.ReadAll(resp.Body)
|
||||||
|
if err != nil {
|
||||||
|
t.Errorf("Could read device token response %v", err)
|
||||||
|
}
|
||||||
|
if resp.StatusCode != http.StatusOK {
|
||||||
|
t.Errorf("%v - Unexpected Token Response Type. Expected 200 got %v. Response: %v", tc.name, resp.StatusCode, string(responseBody))
|
||||||
|
}
|
||||||
|
|
||||||
|
//Parse the response
|
||||||
|
var tokenRes accessTokenReponse
|
||||||
|
if err := json.Unmarshal(responseBody, &tokenRes); err != nil {
|
||||||
|
t.Errorf("Unexpected Device Access Token Response Format %v", string(responseBody))
|
||||||
|
}
|
||||||
|
|
||||||
|
token := &oauth2.Token{
|
||||||
|
AccessToken: tokenRes.AccessToken,
|
||||||
|
TokenType: tokenRes.TokenType,
|
||||||
|
RefreshToken: tokenRes.RefreshToken,
|
||||||
|
}
|
||||||
|
raw := make(map[string]interface{})
|
||||||
|
json.Unmarshal(responseBody, &raw) // no error checks for optional fields
|
||||||
|
token = token.WithExtra(raw)
|
||||||
|
if secs := tokenRes.ExpiresIn; secs > 0 {
|
||||||
|
token.Expiry = time.Now().Add(time.Duration(secs) * time.Second)
|
||||||
|
}
|
||||||
|
|
||||||
|
//Run token tests to validate info is correct
|
||||||
|
// Create the OAuth2 config.
|
||||||
|
oauth2Config := &oauth2.Config{
|
||||||
|
ClientID: client.ID,
|
||||||
|
ClientSecret: client.Secret,
|
||||||
|
Endpoint: p.Endpoint(),
|
||||||
|
Scopes: requestedScopes,
|
||||||
|
RedirectURL: deviceCallbackURI,
|
||||||
|
}
|
||||||
|
if len(tc.scopes) != 0 {
|
||||||
|
oauth2Config.Scopes = tc.scopes
|
||||||
|
}
|
||||||
|
err = tc.handleToken(ctx, p, oauth2Config, token, conn)
|
||||||
|
if err != nil {
|
||||||
|
t.Errorf("%s: %v", tc.name, err)
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
@ -20,6 +20,8 @@ const (
|
|||||||
tmplPassword = "password.html"
|
tmplPassword = "password.html"
|
||||||
tmplOOB = "oob.html"
|
tmplOOB = "oob.html"
|
||||||
tmplError = "error.html"
|
tmplError = "error.html"
|
||||||
|
tmplDevice = "device.html"
|
||||||
|
tmplDeviceSuccess = "device_success.html"
|
||||||
)
|
)
|
||||||
|
|
||||||
var requiredTmpls = []string{
|
var requiredTmpls = []string{
|
||||||
@ -28,6 +30,8 @@ var requiredTmpls = []string{
|
|||||||
tmplPassword,
|
tmplPassword,
|
||||||
tmplOOB,
|
tmplOOB,
|
||||||
tmplError,
|
tmplError,
|
||||||
|
tmplDevice,
|
||||||
|
tmplDeviceSuccess,
|
||||||
}
|
}
|
||||||
|
|
||||||
type templates struct {
|
type templates struct {
|
||||||
@ -36,6 +40,8 @@ type templates struct {
|
|||||||
passwordTmpl *template.Template
|
passwordTmpl *template.Template
|
||||||
oobTmpl *template.Template
|
oobTmpl *template.Template
|
||||||
errorTmpl *template.Template
|
errorTmpl *template.Template
|
||||||
|
deviceTmpl *template.Template
|
||||||
|
deviceSuccessTmpl *template.Template
|
||||||
}
|
}
|
||||||
|
|
||||||
type webConfig struct {
|
type webConfig struct {
|
||||||
@ -157,6 +163,8 @@ func loadTemplates(c webConfig, templatesDir string) (*templates, error) {
|
|||||||
passwordTmpl: tmpls.Lookup(tmplPassword),
|
passwordTmpl: tmpls.Lookup(tmplPassword),
|
||||||
oobTmpl: tmpls.Lookup(tmplOOB),
|
oobTmpl: tmpls.Lookup(tmplOOB),
|
||||||
errorTmpl: tmpls.Lookup(tmplError),
|
errorTmpl: tmpls.Lookup(tmplError),
|
||||||
|
deviceTmpl: tmpls.Lookup(tmplDevice),
|
||||||
|
deviceSuccessTmpl: tmpls.Lookup(tmplDeviceSuccess),
|
||||||
}, nil
|
}, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -242,6 +250,27 @@ func (n byName) Len() int { return len(n) }
|
|||||||
func (n byName) Less(i, j int) bool { return n[i].Name < n[j].Name }
|
func (n byName) Less(i, j int) bool { return n[i].Name < n[j].Name }
|
||||||
func (n byName) Swap(i, j int) { n[i], n[j] = n[j], n[i] }
|
func (n byName) Swap(i, j int) { n[i], n[j] = n[j], n[i] }
|
||||||
|
|
||||||
|
func (t *templates) device(r *http.Request, w http.ResponseWriter, postURL string, userCode string, lastWasInvalid bool) error {
|
||||||
|
if lastWasInvalid {
|
||||||
|
w.WriteHeader(http.StatusBadRequest)
|
||||||
|
}
|
||||||
|
data := struct {
|
||||||
|
PostURL string
|
||||||
|
UserCode string
|
||||||
|
Invalid bool
|
||||||
|
ReqPath string
|
||||||
|
}{postURL, userCode, lastWasInvalid, r.URL.Path}
|
||||||
|
return renderTemplate(w, t.deviceTmpl, data)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (t *templates) deviceSuccess(r *http.Request, w http.ResponseWriter, clientName string) error {
|
||||||
|
data := struct {
|
||||||
|
ClientName string
|
||||||
|
ReqPath string
|
||||||
|
}{clientName, r.URL.Path}
|
||||||
|
return renderTemplate(w, t.deviceSuccessTmpl, data)
|
||||||
|
}
|
||||||
|
|
||||||
func (t *templates) login(r *http.Request, w http.ResponseWriter, connectors []connectorInfo, reqPath string) error {
|
func (t *templates) login(r *http.Request, w http.ResponseWriter, connectors []connectorInfo, reqPath string) error {
|
||||||
sort.Sort(byName(connectors))
|
sort.Sort(byName(connectors))
|
||||||
data := struct {
|
data := struct {
|
||||||
|
@ -49,6 +49,8 @@ func RunTests(t *testing.T, newStorage func() storage.Storage) {
|
|||||||
{"ConnectorCRUD", testConnectorCRUD},
|
{"ConnectorCRUD", testConnectorCRUD},
|
||||||
{"GarbageCollection", testGC},
|
{"GarbageCollection", testGC},
|
||||||
{"TimezoneSupport", testTimezones},
|
{"TimezoneSupport", testTimezones},
|
||||||
|
{"DeviceRequestCRUD", testDeviceRequestCRUD},
|
||||||
|
{"DeviceTokenCRUD", testDeviceTokenCRUD},
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -834,6 +836,87 @@ func testGC(t *testing.T, s storage.Storage) {
|
|||||||
} else if err != storage.ErrNotFound {
|
} else if err != storage.ErrNotFound {
|
||||||
t.Errorf("expected storage.ErrNotFound, got %v", err)
|
t.Errorf("expected storage.ErrNotFound, got %v", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
userCode, err := storage.NewUserCode()
|
||||||
|
if err != nil {
|
||||||
|
t.Errorf("Unexpected Error: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
d := storage.DeviceRequest{
|
||||||
|
UserCode: userCode,
|
||||||
|
DeviceCode: storage.NewID(),
|
||||||
|
ClientID: "client1",
|
||||||
|
ClientSecret: "secret1",
|
||||||
|
Scopes: []string{"openid", "email"},
|
||||||
|
Expiry: expiry,
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := s.CreateDeviceRequest(d); err != nil {
|
||||||
|
t.Fatalf("failed creating device request: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tz := range []*time.Location{time.UTC, est, pst} {
|
||||||
|
result, err := s.GarbageCollect(expiry.Add(-time.Hour).In(tz))
|
||||||
|
if err != nil {
|
||||||
|
t.Errorf("garbage collection failed: %v", err)
|
||||||
|
} else {
|
||||||
|
if result.DeviceRequests != 0 {
|
||||||
|
t.Errorf("expected no device garbage collection results, got %#v", result)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if _, err := s.GetDeviceRequest(d.UserCode); err != nil {
|
||||||
|
t.Errorf("expected to be able to get auth request after GC: %v", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if r, err := s.GarbageCollect(expiry.Add(time.Hour)); err != nil {
|
||||||
|
t.Errorf("garbage collection failed: %v", err)
|
||||||
|
} else if r.DeviceRequests != 1 {
|
||||||
|
t.Errorf("expected to garbage collect 1 device request, got %d", r.DeviceRequests)
|
||||||
|
}
|
||||||
|
|
||||||
|
if _, err := s.GetDeviceRequest(d.UserCode); err == nil {
|
||||||
|
t.Errorf("expected device request to be GC'd")
|
||||||
|
} else if err != storage.ErrNotFound {
|
||||||
|
t.Errorf("expected storage.ErrNotFound, got %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
dt := storage.DeviceToken{
|
||||||
|
DeviceCode: storage.NewID(),
|
||||||
|
Status: "pending",
|
||||||
|
Token: "foo",
|
||||||
|
Expiry: expiry,
|
||||||
|
LastRequestTime: time.Now(),
|
||||||
|
PollIntervalSeconds: 0,
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := s.CreateDeviceToken(dt); err != nil {
|
||||||
|
t.Fatalf("failed creating device token: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tz := range []*time.Location{time.UTC, est, pst} {
|
||||||
|
result, err := s.GarbageCollect(expiry.Add(-time.Hour).In(tz))
|
||||||
|
if err != nil {
|
||||||
|
t.Errorf("garbage collection failed: %v", err)
|
||||||
|
} else {
|
||||||
|
if result.DeviceTokens != 0 {
|
||||||
|
t.Errorf("expected no device token garbage collection results, got %#v", result)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if _, err := s.GetDeviceToken(dt.DeviceCode); err != nil {
|
||||||
|
t.Errorf("expected to be able to get device token after GC: %v", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if r, err := s.GarbageCollect(expiry.Add(time.Hour)); err != nil {
|
||||||
|
t.Errorf("garbage collection failed: %v", err)
|
||||||
|
} else if r.DeviceTokens != 1 {
|
||||||
|
t.Errorf("expected to garbage collect 1 device token, got %d", r.DeviceTokens)
|
||||||
|
}
|
||||||
|
|
||||||
|
if _, err := s.GetDeviceToken(dt.DeviceCode); err == nil {
|
||||||
|
t.Errorf("expected device token to be GC'd")
|
||||||
|
} else if err != storage.ErrNotFound {
|
||||||
|
t.Errorf("expected storage.ErrNotFound, got %v", err)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// testTimezones tests that backends either fully support timezones or
|
// testTimezones tests that backends either fully support timezones or
|
||||||
@ -881,3 +964,72 @@ func testTimezones(t *testing.T, s storage.Storage) {
|
|||||||
t.Fatalf("expected expiry %v got %v", wantTime, gotTime)
|
t.Fatalf("expected expiry %v got %v", wantTime, gotTime)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func testDeviceRequestCRUD(t *testing.T, s storage.Storage) {
|
||||||
|
userCode, err := storage.NewUserCode()
|
||||||
|
if err != nil {
|
||||||
|
panic(err)
|
||||||
|
}
|
||||||
|
d1 := storage.DeviceRequest{
|
||||||
|
UserCode: userCode,
|
||||||
|
DeviceCode: storage.NewID(),
|
||||||
|
ClientID: "client1",
|
||||||
|
ClientSecret: "secret1",
|
||||||
|
Scopes: []string{"openid", "email"},
|
||||||
|
Expiry: neverExpire,
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := s.CreateDeviceRequest(d1); err != nil {
|
||||||
|
t.Fatalf("failed creating device request: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Attempt to create same DeviceRequest twice.
|
||||||
|
err = s.CreateDeviceRequest(d1)
|
||||||
|
mustBeErrAlreadyExists(t, "device request", err)
|
||||||
|
|
||||||
|
//No manual deletes for device requests, will be handled by garbage collection routines
|
||||||
|
//see testGC
|
||||||
|
}
|
||||||
|
|
||||||
|
func testDeviceTokenCRUD(t *testing.T, s storage.Storage) {
|
||||||
|
//Create a Token
|
||||||
|
d1 := storage.DeviceToken{
|
||||||
|
DeviceCode: storage.NewID(),
|
||||||
|
Status: "pending",
|
||||||
|
Token: storage.NewID(),
|
||||||
|
Expiry: neverExpire,
|
||||||
|
LastRequestTime: time.Now(),
|
||||||
|
PollIntervalSeconds: 0,
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := s.CreateDeviceToken(d1); err != nil {
|
||||||
|
t.Fatalf("failed creating device token: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Attempt to create same Device Token twice.
|
||||||
|
err := s.CreateDeviceToken(d1)
|
||||||
|
mustBeErrAlreadyExists(t, "device token", err)
|
||||||
|
|
||||||
|
//Update the device token, simulate a redemption
|
||||||
|
if err := s.UpdateDeviceToken(d1.DeviceCode, func(old storage.DeviceToken) (storage.DeviceToken, error) {
|
||||||
|
old.Token = "token data"
|
||||||
|
old.Status = "complete"
|
||||||
|
return old, nil
|
||||||
|
}); err != nil {
|
||||||
|
t.Fatalf("failed to update device token: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
//Retrieve the device token
|
||||||
|
got, err := s.GetDeviceToken(d1.DeviceCode)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("failed to get device token: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
//Validate expected result set
|
||||||
|
if got.Status != "complete" {
|
||||||
|
t.Fatalf("update failed, wanted token status=%v got %v", "complete", got.Status)
|
||||||
|
}
|
||||||
|
if got.Token != "token data" {
|
||||||
|
t.Fatalf("update failed, wanted token %v got %v", "token data", got.Token)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
@ -22,6 +22,8 @@ const (
|
|||||||
offlineSessionPrefix = "offline_session/"
|
offlineSessionPrefix = "offline_session/"
|
||||||
connectorPrefix = "connector/"
|
connectorPrefix = "connector/"
|
||||||
keysName = "openid-connect-keys"
|
keysName = "openid-connect-keys"
|
||||||
|
deviceRequestPrefix = "device_req/"
|
||||||
|
deviceTokenPrefix = "device_token/"
|
||||||
|
|
||||||
// defaultStorageTimeout will be applied to all storage's operations.
|
// defaultStorageTimeout will be applied to all storage's operations.
|
||||||
defaultStorageTimeout = 5 * time.Second
|
defaultStorageTimeout = 5 * time.Second
|
||||||
@ -72,6 +74,36 @@ func (c *conn) GarbageCollect(now time.Time) (result storage.GCResult, err error
|
|||||||
result.AuthCodes++
|
result.AuthCodes++
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
deviceRequests, err := c.listDeviceRequests(ctx)
|
||||||
|
if err != nil {
|
||||||
|
return result, err
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, deviceRequest := range deviceRequests {
|
||||||
|
if now.After(deviceRequest.Expiry) {
|
||||||
|
if err := c.deleteKey(ctx, keyID(deviceRequestPrefix, deviceRequest.UserCode)); err != nil {
|
||||||
|
c.logger.Errorf("failed to delete device request %v", err)
|
||||||
|
delErr = fmt.Errorf("failed to delete device request: %v", err)
|
||||||
|
}
|
||||||
|
result.DeviceRequests++
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
deviceTokens, err := c.listDeviceTokens(ctx)
|
||||||
|
if err != nil {
|
||||||
|
return result, err
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, deviceToken := range deviceTokens {
|
||||||
|
if now.After(deviceToken.Expiry) {
|
||||||
|
if err := c.deleteKey(ctx, keyID(deviceTokenPrefix, deviceToken.DeviceCode)); err != nil {
|
||||||
|
c.logger.Errorf("failed to delete device token %v", err)
|
||||||
|
delErr = fmt.Errorf("failed to delete device token: %v", err)
|
||||||
|
}
|
||||||
|
result.DeviceTokens++
|
||||||
|
}
|
||||||
|
}
|
||||||
return result, delErr
|
return result, delErr
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -531,3 +563,77 @@ func keyEmail(prefix, email string) string { return prefix + strings.ToLower(ema
|
|||||||
func keySession(prefix, userID, connID string) string {
|
func keySession(prefix, userID, connID string) string {
|
||||||
return prefix + strings.ToLower(userID+"|"+connID)
|
return prefix + strings.ToLower(userID+"|"+connID)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (c *conn) CreateDeviceRequest(d storage.DeviceRequest) error {
|
||||||
|
ctx, cancel := context.WithTimeout(context.Background(), defaultStorageTimeout)
|
||||||
|
defer cancel()
|
||||||
|
return c.txnCreate(ctx, keyID(deviceRequestPrefix, d.UserCode), fromStorageDeviceRequest(d))
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *conn) GetDeviceRequest(userCode string) (r storage.DeviceRequest, err error) {
|
||||||
|
ctx, cancel := context.WithTimeout(context.Background(), defaultStorageTimeout)
|
||||||
|
defer cancel()
|
||||||
|
err = c.getKey(ctx, keyID(deviceRequestPrefix, userCode), &r)
|
||||||
|
return r, err
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *conn) listDeviceRequests(ctx context.Context) (requests []DeviceRequest, err error) {
|
||||||
|
res, err := c.db.Get(ctx, deviceRequestPrefix, clientv3.WithPrefix())
|
||||||
|
if err != nil {
|
||||||
|
return requests, err
|
||||||
|
}
|
||||||
|
for _, v := range res.Kvs {
|
||||||
|
var r DeviceRequest
|
||||||
|
if err = json.Unmarshal(v.Value, &r); err != nil {
|
||||||
|
return requests, err
|
||||||
|
}
|
||||||
|
requests = append(requests, r)
|
||||||
|
}
|
||||||
|
return requests, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *conn) CreateDeviceToken(t storage.DeviceToken) error {
|
||||||
|
ctx, cancel := context.WithTimeout(context.Background(), defaultStorageTimeout)
|
||||||
|
defer cancel()
|
||||||
|
return c.txnCreate(ctx, keyID(deviceTokenPrefix, t.DeviceCode), fromStorageDeviceToken(t))
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *conn) GetDeviceToken(deviceCode string) (t storage.DeviceToken, err error) {
|
||||||
|
ctx, cancel := context.WithTimeout(context.Background(), defaultStorageTimeout)
|
||||||
|
defer cancel()
|
||||||
|
err = c.getKey(ctx, keyID(deviceTokenPrefix, deviceCode), &t)
|
||||||
|
return t, err
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *conn) listDeviceTokens(ctx context.Context) (deviceTokens []DeviceToken, err error) {
|
||||||
|
res, err := c.db.Get(ctx, deviceTokenPrefix, clientv3.WithPrefix())
|
||||||
|
if err != nil {
|
||||||
|
return deviceTokens, err
|
||||||
|
}
|
||||||
|
for _, v := range res.Kvs {
|
||||||
|
var dt DeviceToken
|
||||||
|
if err = json.Unmarshal(v.Value, &dt); err != nil {
|
||||||
|
return deviceTokens, err
|
||||||
|
}
|
||||||
|
deviceTokens = append(deviceTokens, dt)
|
||||||
|
}
|
||||||
|
return deviceTokens, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *conn) UpdateDeviceToken(deviceCode string, updater func(old storage.DeviceToken) (storage.DeviceToken, error)) error {
|
||||||
|
ctx, cancel := context.WithTimeout(context.Background(), defaultStorageTimeout)
|
||||||
|
defer cancel()
|
||||||
|
return c.txnUpdate(ctx, keyID(deviceTokenPrefix, deviceCode), func(currentValue []byte) ([]byte, error) {
|
||||||
|
var current DeviceToken
|
||||||
|
if len(currentValue) > 0 {
|
||||||
|
if err := json.Unmarshal(currentValue, ¤t); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
updated, err := updater(toStorageDeviceToken(current))
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
return json.Marshal(fromStorageDeviceToken(updated))
|
||||||
|
})
|
||||||
|
}
|
||||||
|
@ -44,6 +44,8 @@ func cleanDB(c *conn) error {
|
|||||||
passwordPrefix,
|
passwordPrefix,
|
||||||
offlineSessionPrefix,
|
offlineSessionPrefix,
|
||||||
connectorPrefix,
|
connectorPrefix,
|
||||||
|
deviceRequestPrefix,
|
||||||
|
deviceTokenPrefix,
|
||||||
} {
|
} {
|
||||||
_, err := c.db.Delete(ctx, prefix, clientv3.WithPrefix())
|
_, err := c.db.Delete(ctx, prefix, clientv3.WithPrefix())
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
@ -216,3 +216,56 @@ func toStorageOfflineSessions(o OfflineSessions) storage.OfflineSessions {
|
|||||||
}
|
}
|
||||||
return s
|
return s
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// DeviceRequest is a mirrored struct from storage with JSON struct tags
|
||||||
|
type DeviceRequest struct {
|
||||||
|
UserCode string `json:"user_code"`
|
||||||
|
DeviceCode string `json:"device_code"`
|
||||||
|
ClientID string `json:"client_id"`
|
||||||
|
ClientSecret string `json:"client_secret"`
|
||||||
|
Scopes []string `json:"scopes"`
|
||||||
|
Expiry time.Time `json:"expiry"`
|
||||||
|
}
|
||||||
|
|
||||||
|
func fromStorageDeviceRequest(d storage.DeviceRequest) DeviceRequest {
|
||||||
|
return DeviceRequest{
|
||||||
|
UserCode: d.UserCode,
|
||||||
|
DeviceCode: d.DeviceCode,
|
||||||
|
ClientID: d.ClientID,
|
||||||
|
ClientSecret: d.ClientSecret,
|
||||||
|
Scopes: d.Scopes,
|
||||||
|
Expiry: d.Expiry,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// DeviceToken is a mirrored struct from storage with JSON struct tags
|
||||||
|
type DeviceToken struct {
|
||||||
|
DeviceCode string `json:"device_code"`
|
||||||
|
Status string `json:"status"`
|
||||||
|
Token string `json:"token"`
|
||||||
|
Expiry time.Time `json:"expiry"`
|
||||||
|
LastRequestTime time.Time `json:"last_request"`
|
||||||
|
PollIntervalSeconds int `json:"poll_interval"`
|
||||||
|
}
|
||||||
|
|
||||||
|
func fromStorageDeviceToken(t storage.DeviceToken) DeviceToken {
|
||||||
|
return DeviceToken{
|
||||||
|
DeviceCode: t.DeviceCode,
|
||||||
|
Status: t.Status,
|
||||||
|
Token: t.Token,
|
||||||
|
Expiry: t.Expiry,
|
||||||
|
LastRequestTime: t.LastRequestTime,
|
||||||
|
PollIntervalSeconds: t.PollIntervalSeconds,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func toStorageDeviceToken(t DeviceToken) storage.DeviceToken {
|
||||||
|
return storage.DeviceToken{
|
||||||
|
DeviceCode: t.DeviceCode,
|
||||||
|
Status: t.Status,
|
||||||
|
Token: t.Token,
|
||||||
|
Expiry: t.Expiry,
|
||||||
|
LastRequestTime: t.LastRequestTime,
|
||||||
|
PollIntervalSeconds: t.PollIntervalSeconds,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
@ -21,6 +21,8 @@ const (
|
|||||||
kindPassword = "Password"
|
kindPassword = "Password"
|
||||||
kindOfflineSessions = "OfflineSessions"
|
kindOfflineSessions = "OfflineSessions"
|
||||||
kindConnector = "Connector"
|
kindConnector = "Connector"
|
||||||
|
kindDeviceRequest = "DeviceRequest"
|
||||||
|
kindDeviceToken = "DeviceToken"
|
||||||
)
|
)
|
||||||
|
|
||||||
const (
|
const (
|
||||||
@ -32,6 +34,8 @@ const (
|
|||||||
resourcePassword = "passwords"
|
resourcePassword = "passwords"
|
||||||
resourceOfflineSessions = "offlinesessionses" // Again attempts to pluralize.
|
resourceOfflineSessions = "offlinesessionses" // Again attempts to pluralize.
|
||||||
resourceConnector = "connectors"
|
resourceConnector = "connectors"
|
||||||
|
resourceDeviceRequest = "devicerequests"
|
||||||
|
resourceDeviceToken = "devicetokens"
|
||||||
)
|
)
|
||||||
|
|
||||||
// Config values for the Kubernetes storage type.
|
// Config values for the Kubernetes storage type.
|
||||||
@ -593,5 +597,84 @@ func (cli *client) GarbageCollect(now time.Time) (result storage.GCResult, err e
|
|||||||
result.AuthCodes++
|
result.AuthCodes++
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
var deviceRequests DeviceRequestList
|
||||||
|
if err := cli.list(resourceDeviceRequest, &deviceRequests); err != nil {
|
||||||
|
return result, fmt.Errorf("failed to list device requests: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, deviceRequest := range deviceRequests.DeviceRequests {
|
||||||
|
if now.After(deviceRequest.Expiry) {
|
||||||
|
if err := cli.delete(resourceDeviceRequest, deviceRequest.ObjectMeta.Name); err != nil {
|
||||||
|
cli.logger.Errorf("failed to delete device request: %v", err)
|
||||||
|
delErr = fmt.Errorf("failed to delete device request: %v", err)
|
||||||
|
}
|
||||||
|
result.DeviceRequests++
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
var deviceTokens DeviceTokenList
|
||||||
|
if err := cli.list(resourceDeviceToken, &deviceTokens); err != nil {
|
||||||
|
return result, fmt.Errorf("failed to list device tokens: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, deviceToken := range deviceTokens.DeviceTokens {
|
||||||
|
if now.After(deviceToken.Expiry) {
|
||||||
|
if err := cli.delete(resourceDeviceToken, deviceToken.ObjectMeta.Name); err != nil {
|
||||||
|
cli.logger.Errorf("failed to delete device token: %v", err)
|
||||||
|
delErr = fmt.Errorf("failed to delete device token: %v", err)
|
||||||
|
}
|
||||||
|
result.DeviceTokens++
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if delErr != nil {
|
||||||
return result, delErr
|
return result, delErr
|
||||||
}
|
}
|
||||||
|
return result, delErr
|
||||||
|
}
|
||||||
|
|
||||||
|
func (cli *client) CreateDeviceRequest(d storage.DeviceRequest) error {
|
||||||
|
return cli.post(resourceDeviceRequest, cli.fromStorageDeviceRequest(d))
|
||||||
|
}
|
||||||
|
|
||||||
|
func (cli *client) GetDeviceRequest(userCode string) (storage.DeviceRequest, error) {
|
||||||
|
var req DeviceRequest
|
||||||
|
if err := cli.get(resourceDeviceRequest, strings.ToLower(userCode), &req); err != nil {
|
||||||
|
return storage.DeviceRequest{}, err
|
||||||
|
}
|
||||||
|
return toStorageDeviceRequest(req), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (cli *client) CreateDeviceToken(t storage.DeviceToken) error {
|
||||||
|
return cli.post(resourceDeviceToken, cli.fromStorageDeviceToken(t))
|
||||||
|
}
|
||||||
|
|
||||||
|
func (cli *client) GetDeviceToken(deviceCode string) (storage.DeviceToken, error) {
|
||||||
|
var token DeviceToken
|
||||||
|
if err := cli.get(resourceDeviceToken, deviceCode, &token); err != nil {
|
||||||
|
return storage.DeviceToken{}, err
|
||||||
|
}
|
||||||
|
return toStorageDeviceToken(token), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (cli *client) getDeviceToken(deviceCode string) (t DeviceToken, err error) {
|
||||||
|
err = cli.get(resourceDeviceToken, deviceCode, &t)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
func (cli *client) UpdateDeviceToken(deviceCode string, updater func(old storage.DeviceToken) (storage.DeviceToken, error)) error {
|
||||||
|
r, err := cli.getDeviceToken(deviceCode)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
updated, err := updater(toStorageDeviceToken(r))
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
updated.DeviceCode = deviceCode
|
||||||
|
|
||||||
|
newToken := cli.fromStorageDeviceToken(updated)
|
||||||
|
newToken.ObjectMeta = r.ObjectMeta
|
||||||
|
return cli.put(resourceDeviceToken, r.ObjectMeta.Name, newToken)
|
||||||
|
}
|
||||||
|
@ -85,6 +85,8 @@ func (s *StorageTestSuite) TestStorage() {
|
|||||||
for _, resource := range []string{
|
for _, resource := range []string{
|
||||||
resourceAuthCode,
|
resourceAuthCode,
|
||||||
resourceAuthRequest,
|
resourceAuthRequest,
|
||||||
|
resourceDeviceRequest,
|
||||||
|
resourceDeviceToken,
|
||||||
resourceClient,
|
resourceClient,
|
||||||
resourceRefreshToken,
|
resourceRefreshToken,
|
||||||
resourceKeys,
|
resourceKeys,
|
||||||
|
@ -143,6 +143,36 @@ var customResourceDefinitions = []k8sapi.CustomResourceDefinition{
|
|||||||
},
|
},
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
|
{
|
||||||
|
ObjectMeta: k8sapi.ObjectMeta{
|
||||||
|
Name: "devicerequests.dex.coreos.com",
|
||||||
|
},
|
||||||
|
TypeMeta: crdMeta,
|
||||||
|
Spec: k8sapi.CustomResourceDefinitionSpec{
|
||||||
|
Group: apiGroup,
|
||||||
|
Version: "v1",
|
||||||
|
Names: k8sapi.CustomResourceDefinitionNames{
|
||||||
|
Plural: "devicerequests",
|
||||||
|
Singular: "devicerequest",
|
||||||
|
Kind: "DeviceRequest",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
ObjectMeta: k8sapi.ObjectMeta{
|
||||||
|
Name: "devicetokens.dex.coreos.com",
|
||||||
|
},
|
||||||
|
TypeMeta: crdMeta,
|
||||||
|
Spec: k8sapi.CustomResourceDefinitionSpec{
|
||||||
|
Group: apiGroup,
|
||||||
|
Version: "v1",
|
||||||
|
Names: k8sapi.CustomResourceDefinitionNames{
|
||||||
|
Plural: "devicetokens",
|
||||||
|
Singular: "devicetoken",
|
||||||
|
Kind: "DeviceToken",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
// There will only ever be a single keys resource. Maintain this by setting a
|
// There will only ever be a single keys resource. Maintain this by setting a
|
||||||
@ -635,3 +665,103 @@ type ConnectorList struct {
|
|||||||
k8sapi.ListMeta `json:"metadata,omitempty"`
|
k8sapi.ListMeta `json:"metadata,omitempty"`
|
||||||
Connectors []Connector `json:"items"`
|
Connectors []Connector `json:"items"`
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// DeviceRequest is a mirrored struct from storage with JSON struct tags and
|
||||||
|
// Kubernetes type metadata.
|
||||||
|
type DeviceRequest struct {
|
||||||
|
k8sapi.TypeMeta `json:",inline"`
|
||||||
|
k8sapi.ObjectMeta `json:"metadata,omitempty"`
|
||||||
|
|
||||||
|
DeviceCode string `json:"device_code,omitempty"`
|
||||||
|
ClientID string `json:"client_id,omitempty"`
|
||||||
|
ClientSecret string `json:"client_secret,omitempty"`
|
||||||
|
Scopes []string `json:"scopes,omitempty"`
|
||||||
|
Expiry time.Time `json:"expiry"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// AuthRequestList is a list of AuthRequests.
|
||||||
|
type DeviceRequestList struct {
|
||||||
|
k8sapi.TypeMeta `json:",inline"`
|
||||||
|
k8sapi.ListMeta `json:"metadata,omitempty"`
|
||||||
|
DeviceRequests []DeviceRequest `json:"items"`
|
||||||
|
}
|
||||||
|
|
||||||
|
func (cli *client) fromStorageDeviceRequest(a storage.DeviceRequest) DeviceRequest {
|
||||||
|
req := DeviceRequest{
|
||||||
|
TypeMeta: k8sapi.TypeMeta{
|
||||||
|
Kind: kindDeviceRequest,
|
||||||
|
APIVersion: cli.apiVersion,
|
||||||
|
},
|
||||||
|
ObjectMeta: k8sapi.ObjectMeta{
|
||||||
|
Name: strings.ToLower(a.UserCode),
|
||||||
|
Namespace: cli.namespace,
|
||||||
|
},
|
||||||
|
DeviceCode: a.DeviceCode,
|
||||||
|
ClientID: a.ClientID,
|
||||||
|
ClientSecret: a.ClientSecret,
|
||||||
|
Scopes: a.Scopes,
|
||||||
|
Expiry: a.Expiry,
|
||||||
|
}
|
||||||
|
return req
|
||||||
|
}
|
||||||
|
|
||||||
|
func toStorageDeviceRequest(req DeviceRequest) storage.DeviceRequest {
|
||||||
|
return storage.DeviceRequest{
|
||||||
|
UserCode: strings.ToUpper(req.ObjectMeta.Name),
|
||||||
|
DeviceCode: req.DeviceCode,
|
||||||
|
ClientID: req.ClientID,
|
||||||
|
ClientSecret: req.ClientSecret,
|
||||||
|
Scopes: req.Scopes,
|
||||||
|
Expiry: req.Expiry,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// DeviceToken is a mirrored struct from storage with JSON struct tags and
|
||||||
|
// Kubernetes type metadata.
|
||||||
|
type DeviceToken struct {
|
||||||
|
k8sapi.TypeMeta `json:",inline"`
|
||||||
|
k8sapi.ObjectMeta `json:"metadata,omitempty"`
|
||||||
|
|
||||||
|
Status string `json:"status,omitempty"`
|
||||||
|
Token string `json:"token,omitempty"`
|
||||||
|
Expiry time.Time `json:"expiry"`
|
||||||
|
LastRequestTime time.Time `json:"last_request"`
|
||||||
|
PollIntervalSeconds int `json:"poll_interval"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// DeviceTokenList is a list of DeviceTokens.
|
||||||
|
type DeviceTokenList struct {
|
||||||
|
k8sapi.TypeMeta `json:",inline"`
|
||||||
|
k8sapi.ListMeta `json:"metadata,omitempty"`
|
||||||
|
DeviceTokens []DeviceToken `json:"items"`
|
||||||
|
}
|
||||||
|
|
||||||
|
func (cli *client) fromStorageDeviceToken(t storage.DeviceToken) DeviceToken {
|
||||||
|
req := DeviceToken{
|
||||||
|
TypeMeta: k8sapi.TypeMeta{
|
||||||
|
Kind: kindDeviceToken,
|
||||||
|
APIVersion: cli.apiVersion,
|
||||||
|
},
|
||||||
|
ObjectMeta: k8sapi.ObjectMeta{
|
||||||
|
Name: t.DeviceCode,
|
||||||
|
Namespace: cli.namespace,
|
||||||
|
},
|
||||||
|
Status: t.Status,
|
||||||
|
Token: t.Token,
|
||||||
|
Expiry: t.Expiry,
|
||||||
|
LastRequestTime: t.LastRequestTime,
|
||||||
|
PollIntervalSeconds: t.PollIntervalSeconds,
|
||||||
|
}
|
||||||
|
return req
|
||||||
|
}
|
||||||
|
|
||||||
|
func toStorageDeviceToken(t DeviceToken) storage.DeviceToken {
|
||||||
|
return storage.DeviceToken{
|
||||||
|
DeviceCode: t.ObjectMeta.Name,
|
||||||
|
Status: t.Status,
|
||||||
|
Token: t.Token,
|
||||||
|
Expiry: t.Expiry,
|
||||||
|
LastRequestTime: t.LastRequestTime,
|
||||||
|
PollIntervalSeconds: t.PollIntervalSeconds,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
@ -20,6 +20,8 @@ func New(logger log.Logger) storage.Storage {
|
|||||||
passwords: make(map[string]storage.Password),
|
passwords: make(map[string]storage.Password),
|
||||||
offlineSessions: make(map[offlineSessionID]storage.OfflineSessions),
|
offlineSessions: make(map[offlineSessionID]storage.OfflineSessions),
|
||||||
connectors: make(map[string]storage.Connector),
|
connectors: make(map[string]storage.Connector),
|
||||||
|
deviceRequests: make(map[string]storage.DeviceRequest),
|
||||||
|
deviceTokens: make(map[string]storage.DeviceToken),
|
||||||
logger: logger,
|
logger: logger,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -46,6 +48,8 @@ type memStorage struct {
|
|||||||
passwords map[string]storage.Password
|
passwords map[string]storage.Password
|
||||||
offlineSessions map[offlineSessionID]storage.OfflineSessions
|
offlineSessions map[offlineSessionID]storage.OfflineSessions
|
||||||
connectors map[string]storage.Connector
|
connectors map[string]storage.Connector
|
||||||
|
deviceRequests map[string]storage.DeviceRequest
|
||||||
|
deviceTokens map[string]storage.DeviceToken
|
||||||
|
|
||||||
keys storage.Keys
|
keys storage.Keys
|
||||||
|
|
||||||
@ -79,6 +83,18 @@ func (s *memStorage) GarbageCollect(now time.Time) (result storage.GCResult, err
|
|||||||
result.AuthRequests++
|
result.AuthRequests++
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
for id, a := range s.deviceRequests {
|
||||||
|
if now.After(a.Expiry) {
|
||||||
|
delete(s.deviceRequests, id)
|
||||||
|
result.DeviceRequests++
|
||||||
|
}
|
||||||
|
}
|
||||||
|
for id, a := range s.deviceTokens {
|
||||||
|
if now.After(a.Expiry) {
|
||||||
|
delete(s.deviceTokens, id)
|
||||||
|
result.DeviceTokens++
|
||||||
|
}
|
||||||
|
}
|
||||||
})
|
})
|
||||||
return result, nil
|
return result, nil
|
||||||
}
|
}
|
||||||
@ -465,3 +481,61 @@ func (s *memStorage) UpdateConnector(id string, updater func(c storage.Connector
|
|||||||
})
|
})
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (s *memStorage) CreateDeviceRequest(d storage.DeviceRequest) (err error) {
|
||||||
|
s.tx(func() {
|
||||||
|
if _, ok := s.deviceRequests[d.UserCode]; ok {
|
||||||
|
err = storage.ErrAlreadyExists
|
||||||
|
} else {
|
||||||
|
s.deviceRequests[d.UserCode] = d
|
||||||
|
}
|
||||||
|
})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *memStorage) GetDeviceRequest(userCode string) (req storage.DeviceRequest, err error) {
|
||||||
|
s.tx(func() {
|
||||||
|
var ok bool
|
||||||
|
if req, ok = s.deviceRequests[userCode]; !ok {
|
||||||
|
err = storage.ErrNotFound
|
||||||
|
return
|
||||||
|
}
|
||||||
|
})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *memStorage) CreateDeviceToken(t storage.DeviceToken) (err error) {
|
||||||
|
s.tx(func() {
|
||||||
|
if _, ok := s.deviceTokens[t.DeviceCode]; ok {
|
||||||
|
err = storage.ErrAlreadyExists
|
||||||
|
} else {
|
||||||
|
s.deviceTokens[t.DeviceCode] = t
|
||||||
|
}
|
||||||
|
})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *memStorage) GetDeviceToken(deviceCode string) (t storage.DeviceToken, err error) {
|
||||||
|
s.tx(func() {
|
||||||
|
var ok bool
|
||||||
|
if t, ok = s.deviceTokens[deviceCode]; !ok {
|
||||||
|
err = storage.ErrNotFound
|
||||||
|
return
|
||||||
|
}
|
||||||
|
})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *memStorage) UpdateDeviceToken(deviceCode string, updater func(p storage.DeviceToken) (storage.DeviceToken, error)) (err error) {
|
||||||
|
s.tx(func() {
|
||||||
|
r, ok := s.deviceTokens[deviceCode]
|
||||||
|
if !ok {
|
||||||
|
err = storage.ErrNotFound
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if r, err = updater(r); err == nil {
|
||||||
|
s.deviceTokens[deviceCode] = r
|
||||||
|
}
|
||||||
|
})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
@ -100,6 +100,23 @@ func (c *conn) GarbageCollect(now time.Time) (result storage.GCResult, err error
|
|||||||
if n, err := r.RowsAffected(); err == nil {
|
if n, err := r.RowsAffected(); err == nil {
|
||||||
result.AuthCodes = n
|
result.AuthCodes = n
|
||||||
}
|
}
|
||||||
|
|
||||||
|
r, err = c.Exec(`delete from device_request where expiry < $1`, now)
|
||||||
|
if err != nil {
|
||||||
|
return result, fmt.Errorf("gc device_request: %v", err)
|
||||||
|
}
|
||||||
|
if n, err := r.RowsAffected(); err == nil {
|
||||||
|
result.DeviceRequests = n
|
||||||
|
}
|
||||||
|
|
||||||
|
r, err = c.Exec(`delete from device_token where expiry < $1`, now)
|
||||||
|
if err != nil {
|
||||||
|
return result, fmt.Errorf("gc device_token: %v", err)
|
||||||
|
}
|
||||||
|
if n, err := r.RowsAffected(); err == nil {
|
||||||
|
result.DeviceTokens = n
|
||||||
|
}
|
||||||
|
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -867,3 +884,113 @@ func (c *conn) delete(table, field, id string) error {
|
|||||||
}
|
}
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (c *conn) CreateDeviceRequest(d storage.DeviceRequest) error {
|
||||||
|
_, err := c.Exec(`
|
||||||
|
insert into device_request (
|
||||||
|
user_code, device_code, client_id, client_secret, scopes, expiry
|
||||||
|
)
|
||||||
|
values (
|
||||||
|
$1, $2, $3, $4, $5, $6
|
||||||
|
);`,
|
||||||
|
d.UserCode, d.DeviceCode, d.ClientID, d.ClientSecret, encoder(d.Scopes), d.Expiry,
|
||||||
|
)
|
||||||
|
if err != nil {
|
||||||
|
if c.alreadyExistsCheck(err) {
|
||||||
|
return storage.ErrAlreadyExists
|
||||||
|
}
|
||||||
|
return fmt.Errorf("insert device request: %v", err)
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *conn) CreateDeviceToken(t storage.DeviceToken) error {
|
||||||
|
_, err := c.Exec(`
|
||||||
|
insert into device_token (
|
||||||
|
device_code, status, token, expiry, last_request, poll_interval
|
||||||
|
)
|
||||||
|
values (
|
||||||
|
$1, $2, $3, $4, $5, $6
|
||||||
|
);`,
|
||||||
|
t.DeviceCode, t.Status, t.Token, t.Expiry, t.LastRequestTime, t.PollIntervalSeconds,
|
||||||
|
)
|
||||||
|
if err != nil {
|
||||||
|
if c.alreadyExistsCheck(err) {
|
||||||
|
return storage.ErrAlreadyExists
|
||||||
|
}
|
||||||
|
return fmt.Errorf("insert device token: %v", err)
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *conn) GetDeviceRequest(userCode string) (storage.DeviceRequest, error) {
|
||||||
|
return getDeviceRequest(c, userCode)
|
||||||
|
}
|
||||||
|
|
||||||
|
func getDeviceRequest(q querier, userCode string) (d storage.DeviceRequest, err error) {
|
||||||
|
err = q.QueryRow(`
|
||||||
|
select
|
||||||
|
device_code, client_id, client_secret, scopes, expiry
|
||||||
|
from device_request where user_code = $1;
|
||||||
|
`, userCode).Scan(
|
||||||
|
&d.DeviceCode, &d.ClientID, &d.ClientSecret, decoder(&d.Scopes), &d.Expiry,
|
||||||
|
)
|
||||||
|
if err != nil {
|
||||||
|
if err == sql.ErrNoRows {
|
||||||
|
return d, storage.ErrNotFound
|
||||||
|
}
|
||||||
|
return d, fmt.Errorf("select device token: %v", err)
|
||||||
|
}
|
||||||
|
d.UserCode = userCode
|
||||||
|
return d, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *conn) GetDeviceToken(deviceCode string) (storage.DeviceToken, error) {
|
||||||
|
return getDeviceToken(c, deviceCode)
|
||||||
|
}
|
||||||
|
|
||||||
|
func getDeviceToken(q querier, deviceCode string) (a storage.DeviceToken, err error) {
|
||||||
|
err = q.QueryRow(`
|
||||||
|
select
|
||||||
|
status, token, expiry, last_request, poll_interval
|
||||||
|
from device_token where device_code = $1;
|
||||||
|
`, deviceCode).Scan(
|
||||||
|
&a.Status, &a.Token, &a.Expiry, &a.LastRequestTime, &a.PollIntervalSeconds,
|
||||||
|
)
|
||||||
|
if err != nil {
|
||||||
|
if err == sql.ErrNoRows {
|
||||||
|
return a, storage.ErrNotFound
|
||||||
|
}
|
||||||
|
return a, fmt.Errorf("select device token: %v", err)
|
||||||
|
}
|
||||||
|
a.DeviceCode = deviceCode
|
||||||
|
return a, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *conn) UpdateDeviceToken(deviceCode string, updater func(old storage.DeviceToken) (storage.DeviceToken, error)) error {
|
||||||
|
return c.ExecTx(func(tx *trans) error {
|
||||||
|
r, err := getDeviceToken(tx, deviceCode)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
if r, err = updater(r); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
_, err = tx.Exec(`
|
||||||
|
update device_token
|
||||||
|
set
|
||||||
|
status = $1,
|
||||||
|
token = $2,
|
||||||
|
last_request = $3,
|
||||||
|
poll_interval = $4
|
||||||
|
where
|
||||||
|
device_code = $5
|
||||||
|
`,
|
||||||
|
r.Status, r.Token, r.LastRequestTime, r.PollIntervalSeconds, r.DeviceCode,
|
||||||
|
)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("update device token: %v", err)
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
})
|
||||||
|
}
|
||||||
|
@ -229,4 +229,25 @@ var migrations = []migration{
|
|||||||
},
|
},
|
||||||
flavor: &flavorMySQL,
|
flavor: &flavorMySQL,
|
||||||
},
|
},
|
||||||
|
{
|
||||||
|
stmts: []string{`
|
||||||
|
create table device_request (
|
||||||
|
user_code text not null primary key,
|
||||||
|
device_code text not null,
|
||||||
|
client_id text not null,
|
||||||
|
client_secret text ,
|
||||||
|
scopes bytea not null, -- JSON array of strings
|
||||||
|
expiry timestamptz not null
|
||||||
|
);`,
|
||||||
|
`
|
||||||
|
create table device_token (
|
||||||
|
device_code text not null primary key,
|
||||||
|
status text not null,
|
||||||
|
token bytea,
|
||||||
|
expiry timestamptz not null,
|
||||||
|
last_request timestamptz not null,
|
||||||
|
poll_interval integer not null
|
||||||
|
);`,
|
||||||
|
},
|
||||||
|
},
|
||||||
}
|
}
|
||||||
|
@ -5,6 +5,7 @@ import (
|
|||||||
"encoding/base32"
|
"encoding/base32"
|
||||||
"errors"
|
"errors"
|
||||||
"io"
|
"io"
|
||||||
|
"math/big"
|
||||||
"strings"
|
"strings"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
@ -24,9 +25,21 @@ var (
|
|||||||
// TODO(ericchiang): refactor ID creation onto the storage.
|
// TODO(ericchiang): refactor ID creation onto the storage.
|
||||||
var encoding = base32.NewEncoding("abcdefghijklmnopqrstuvwxyz234567")
|
var encoding = base32.NewEncoding("abcdefghijklmnopqrstuvwxyz234567")
|
||||||
|
|
||||||
|
//Valid characters for user codes
|
||||||
|
const validUserCharacters = "BCDFGHJKLMNPQRSTVWXZ"
|
||||||
|
|
||||||
|
// NewDeviceCode returns a 32 char alphanumeric cryptographically secure string
|
||||||
|
func NewDeviceCode() string {
|
||||||
|
return newSecureID(32)
|
||||||
|
}
|
||||||
|
|
||||||
// NewID returns a random string which can be used as an ID for objects.
|
// NewID returns a random string which can be used as an ID for objects.
|
||||||
func NewID() string {
|
func NewID() string {
|
||||||
buff := make([]byte, 16) // 128 bit random ID.
|
return newSecureID(16)
|
||||||
|
}
|
||||||
|
|
||||||
|
func newSecureID(len int) string {
|
||||||
|
buff := make([]byte, len) // random ID.
|
||||||
if _, err := io.ReadFull(rand.Reader, buff); err != nil {
|
if _, err := io.ReadFull(rand.Reader, buff); err != nil {
|
||||||
panic(err)
|
panic(err)
|
||||||
}
|
}
|
||||||
@ -38,6 +51,8 @@ func NewID() string {
|
|||||||
type GCResult struct {
|
type GCResult struct {
|
||||||
AuthRequests int64
|
AuthRequests int64
|
||||||
AuthCodes int64
|
AuthCodes int64
|
||||||
|
DeviceRequests int64
|
||||||
|
DeviceTokens int64
|
||||||
}
|
}
|
||||||
|
|
||||||
// Storage is the storage interface used by the server. Implementations are
|
// Storage is the storage interface used by the server. Implementations are
|
||||||
@ -54,6 +69,8 @@ type Storage interface {
|
|||||||
CreatePassword(p Password) error
|
CreatePassword(p Password) error
|
||||||
CreateOfflineSessions(s OfflineSessions) error
|
CreateOfflineSessions(s OfflineSessions) error
|
||||||
CreateConnector(c Connector) error
|
CreateConnector(c Connector) error
|
||||||
|
CreateDeviceRequest(d DeviceRequest) error
|
||||||
|
CreateDeviceToken(d DeviceToken) error
|
||||||
|
|
||||||
// TODO(ericchiang): return (T, bool, error) so we can indicate not found
|
// TODO(ericchiang): return (T, bool, error) so we can indicate not found
|
||||||
// requests that way instead of using ErrNotFound.
|
// requests that way instead of using ErrNotFound.
|
||||||
@ -65,6 +82,8 @@ type Storage interface {
|
|||||||
GetPassword(email string) (Password, error)
|
GetPassword(email string) (Password, error)
|
||||||
GetOfflineSessions(userID string, connID string) (OfflineSessions, error)
|
GetOfflineSessions(userID string, connID string) (OfflineSessions, error)
|
||||||
GetConnector(id string) (Connector, error)
|
GetConnector(id string) (Connector, error)
|
||||||
|
GetDeviceRequest(userCode string) (DeviceRequest, error)
|
||||||
|
GetDeviceToken(deviceCode string) (DeviceToken, error)
|
||||||
|
|
||||||
ListClients() ([]Client, error)
|
ListClients() ([]Client, error)
|
||||||
ListRefreshTokens() ([]RefreshToken, error)
|
ListRefreshTokens() ([]RefreshToken, error)
|
||||||
@ -101,8 +120,10 @@ type Storage interface {
|
|||||||
UpdatePassword(email string, updater func(p Password) (Password, error)) error
|
UpdatePassword(email string, updater func(p Password) (Password, error)) error
|
||||||
UpdateOfflineSessions(userID string, connID string, updater func(s OfflineSessions) (OfflineSessions, error)) error
|
UpdateOfflineSessions(userID string, connID string, updater func(s OfflineSessions) (OfflineSessions, error)) error
|
||||||
UpdateConnector(id string, updater func(c Connector) (Connector, error)) error
|
UpdateConnector(id string, updater func(c Connector) (Connector, error)) error
|
||||||
|
UpdateDeviceToken(deviceCode string, updater func(t DeviceToken) (DeviceToken, error)) error
|
||||||
|
|
||||||
// GarbageCollect deletes all expired AuthCodes and AuthRequests.
|
// GarbageCollect deletes all expired AuthCodes,
|
||||||
|
// AuthRequests, DeviceRequests, and DeviceTokens.
|
||||||
GarbageCollect(now time.Time) (GCResult, error)
|
GarbageCollect(now time.Time) (GCResult, error)
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -342,3 +363,49 @@ type Keys struct {
|
|||||||
// For caching purposes, implementations MUST NOT update keys before this time.
|
// For caching purposes, implementations MUST NOT update keys before this time.
|
||||||
NextRotation time.Time
|
NextRotation time.Time
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// NewUserCode returns a randomized 8 character user code for the device flow.
|
||||||
|
// No vowels are included to prevent accidental generation of words
|
||||||
|
func NewUserCode() (string, error) {
|
||||||
|
code, err := randomString(8)
|
||||||
|
if err != nil {
|
||||||
|
return "", err
|
||||||
|
}
|
||||||
|
return code[:4] + "-" + code[4:], nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func randomString(n int) (string, error) {
|
||||||
|
v := big.NewInt(int64(len(validUserCharacters)))
|
||||||
|
bytes := make([]byte, n)
|
||||||
|
for i := 0; i < n; i++ {
|
||||||
|
c, _ := rand.Int(rand.Reader, v)
|
||||||
|
bytes[i] = validUserCharacters[c.Int64()]
|
||||||
|
}
|
||||||
|
return string(bytes), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
//DeviceRequest represents an OIDC device authorization request. It holds the state of a device request until the user
|
||||||
|
//authenticates using their user code or the expiry time passes.
|
||||||
|
type DeviceRequest struct {
|
||||||
|
//The code the user will enter in a browser
|
||||||
|
UserCode string
|
||||||
|
//The unique device code for device authentication
|
||||||
|
DeviceCode string
|
||||||
|
//The client ID the code is for
|
||||||
|
ClientID string
|
||||||
|
//The Client Secret
|
||||||
|
ClientSecret string
|
||||||
|
//The scopes the device requests
|
||||||
|
Scopes []string
|
||||||
|
//The expire time
|
||||||
|
Expiry time.Time
|
||||||
|
}
|
||||||
|
|
||||||
|
type DeviceToken struct {
|
||||||
|
DeviceCode string
|
||||||
|
Status string
|
||||||
|
Token string
|
||||||
|
Expiry time.Time
|
||||||
|
LastRequestTime time.Time
|
||||||
|
PollIntervalSeconds int
|
||||||
|
}
|
||||||
|
23
web/templates/device.html
Normal file
23
web/templates/device.html
Normal file
@ -0,0 +1,23 @@
|
|||||||
|
{{ template "header.html" . }}
|
||||||
|
|
||||||
|
<div class="theme-panel">
|
||||||
|
<h2 class="theme-heading">Enter User Code</h2>
|
||||||
|
<form method="post" action="{{ .PostURL }}" method="get">
|
||||||
|
<div class="theme-form-row">
|
||||||
|
{{ if( .UserCode )}}
|
||||||
|
<input tabindex="2" required id="user_code" name="user_code" type="text" class="theme-form-input" autocomplete="off" value="{{.UserCode}}" {{ if .Invalid }} autofocus {{ end }}/>
|
||||||
|
{{ else }}
|
||||||
|
<input tabindex="2" required id="user_code" name="user_code" type="text" class="theme-form-input" placeholder="XXXX-XXXX" autocomplete="off" {{ if .Invalid }} autofocus {{ end }}/>
|
||||||
|
{{ end }}
|
||||||
|
</div>
|
||||||
|
|
||||||
|
{{ if .Invalid }}
|
||||||
|
<div id="login-error" class="dex-error-box">
|
||||||
|
Invalid or Expired User Code
|
||||||
|
</div>
|
||||||
|
{{ end }}
|
||||||
|
<button tabindex="3" id="submit-login" type="submit" class="dex-btn theme-btn--primary">Submit</button>
|
||||||
|
</form>
|
||||||
|
</div>
|
||||||
|
|
||||||
|
{{ template "footer.html" . }}
|
8
web/templates/device_success.html
Normal file
8
web/templates/device_success.html
Normal file
@ -0,0 +1,8 @@
|
|||||||
|
{{ template "header.html" . }}
|
||||||
|
|
||||||
|
<div class="theme-panel">
|
||||||
|
<h2 class="theme-heading">Login Successful for {{ .ClientName }}</h2>
|
||||||
|
<p>Return to your device to continue</p>
|
||||||
|
</div>
|
||||||
|
|
||||||
|
{{ template "footer.html" . }}
|
Reference in New Issue
Block a user