9c699b1028
Extracted test cases from OAuth2Code flow tests to reuse in device flow deviceHandler unit tests to test specific device endpoints Include client secret as an optional parameter for standards compliance Signed-off-by: justin-slowik <justin.slowik@thermofisher.com>
387 lines
12 KiB
Go
387 lines
12 KiB
Go
package server
|
|
|
|
import (
|
|
"encoding/json"
|
|
"errors"
|
|
"fmt"
|
|
"net/http"
|
|
"net/url"
|
|
"path"
|
|
"strconv"
|
|
"strings"
|
|
"time"
|
|
|
|
"github.com/dexidp/dex/storage"
|
|
)
|
|
|
|
// deviceCodeResponse is the JSON document sent back to a device that has
// requested a new device code.
type deviceCodeResponse struct {
	// DeviceCode is the unique device code for device authentication.
	DeviceCode string `json:"device_code"`
	// UserCode is the code the user will exchange via a browser and log in.
	UserCode string `json:"user_code"`
	// VerificationURI is the URL at which the user code is verified.
	VerificationURI string `json:"verification_uri"`
	// VerificationURIComplete is the verification URI with the user code
	// appended, for pre-filling the form.
	VerificationURIComplete string `json:"verification_uri_complete"`
	// ExpireTime is the lifetime of the device code.
	ExpireTime int `json:"expires_in"`
	// PollInterval is how often the device is allowed to poll to verify
	// that the user login occurred.
	PollInterval int `json:"interval"`
}
|
|
|
|
func (s *Server) getDeviceVerificationURI() string {
|
|
return path.Join(s.issuerURL.Path, "/device/auth/verify_code")
|
|
}
|
|
|
|
func (s *Server) handleDeviceExchange(w http.ResponseWriter, r *http.Request) {
|
|
switch r.Method {
|
|
case http.MethodGet:
|
|
userCode := r.URL.Query().Get("user_code")
|
|
invalidAttempt, err := strconv.ParseBool(r.URL.Query().Get("invalid"))
|
|
if err != nil {
|
|
invalidAttempt = false
|
|
}
|
|
if err := s.templates.device(r, w, s.getDeviceVerificationURI(), userCode, invalidAttempt); err != nil {
|
|
s.logger.Errorf("Server template error: %v", err)
|
|
s.renderError(r, w, http.StatusNotFound, "Page not found")
|
|
}
|
|
default:
|
|
s.renderError(r, w, http.StatusBadRequest, "Requested resource does not exist.")
|
|
}
|
|
}
|
|
|
|
func (s *Server) handleDeviceCode(w http.ResponseWriter, r *http.Request) {
|
|
pollIntervalSeconds := 5
|
|
|
|
switch r.Method {
|
|
case http.MethodPost:
|
|
err := r.ParseForm()
|
|
if err != nil {
|
|
s.logger.Errorf("Could not parse Device Request body: %v", err)
|
|
s.tokenErrHelper(w, errInvalidRequest, "", http.StatusNotFound)
|
|
return
|
|
}
|
|
|
|
//Get the client id and scopes from the post
|
|
clientID := r.Form.Get("client_id")
|
|
clientSecret := r.Form.Get("client_secret")
|
|
scopes := strings.Fields(r.Form.Get("scope"))
|
|
|
|
s.logger.Infof("Received device request for client %v with scopes %v", clientID, scopes)
|
|
|
|
//Make device code
|
|
deviceCode := storage.NewDeviceCode()
|
|
|
|
//make user code
|
|
userCode, err := storage.NewUserCode()
|
|
if err != nil {
|
|
s.logger.Errorf("Error generating user code: %v", err)
|
|
s.tokenErrHelper(w, errInvalidRequest, "", http.StatusInternalServerError)
|
|
}
|
|
|
|
//Generate the expire time
|
|
expireTime := time.Now().Add(s.deviceRequestsValidFor)
|
|
|
|
//Store the Device Request
|
|
deviceReq := storage.DeviceRequest{
|
|
UserCode: userCode,
|
|
DeviceCode: deviceCode,
|
|
ClientID: clientID,
|
|
ClientSecret: clientSecret,
|
|
Scopes: scopes,
|
|
Expiry: expireTime,
|
|
}
|
|
|
|
if err := s.storage.CreateDeviceRequest(deviceReq); err != nil {
|
|
s.logger.Errorf("Failed to store device request; %v", err)
|
|
s.tokenErrHelper(w, errInvalidRequest, "", http.StatusInternalServerError)
|
|
return
|
|
}
|
|
|
|
//Store the device token
|
|
deviceToken := storage.DeviceToken{
|
|
DeviceCode: deviceCode,
|
|
Status: deviceTokenPending,
|
|
Expiry: expireTime,
|
|
LastRequestTime: s.now(),
|
|
PollIntervalSeconds: 0,
|
|
}
|
|
|
|
if err := s.storage.CreateDeviceToken(deviceToken); err != nil {
|
|
s.logger.Errorf("Failed to store device token %v", err)
|
|
s.tokenErrHelper(w, errInvalidRequest, "", http.StatusInternalServerError)
|
|
return
|
|
}
|
|
|
|
u, err := url.Parse(s.issuerURL.String())
|
|
if err != nil {
|
|
s.logger.Errorf("Could not parse issuer URL %v", err)
|
|
s.tokenErrHelper(w, errInvalidRequest, "", http.StatusInternalServerError)
|
|
return
|
|
}
|
|
u.Path = path.Join(u.Path, "device")
|
|
vURI := u.String()
|
|
|
|
q := u.Query()
|
|
q.Set("user_code", userCode)
|
|
u.RawQuery = q.Encode()
|
|
vURIComplete := u.String()
|
|
|
|
code := deviceCodeResponse{
|
|
DeviceCode: deviceCode,
|
|
UserCode: userCode,
|
|
VerificationURI: vURI,
|
|
VerificationURIComplete: vURIComplete,
|
|
ExpireTime: int(s.deviceRequestsValidFor.Seconds()),
|
|
PollInterval: pollIntervalSeconds,
|
|
}
|
|
|
|
enc := json.NewEncoder(w)
|
|
enc.SetEscapeHTML(false)
|
|
enc.SetIndent("", " ")
|
|
enc.Encode(code)
|
|
|
|
default:
|
|
s.renderError(r, w, http.StatusBadRequest, "Invalid device code request type")
|
|
s.tokenErrHelper(w, errInvalidRequest, "", http.StatusBadRequest)
|
|
}
|
|
}
|
|
|
|
func (s *Server) handleDeviceToken(w http.ResponseWriter, r *http.Request) {
|
|
w.Header().Set("Content-Type", "application/json")
|
|
switch r.Method {
|
|
case http.MethodPost:
|
|
err := r.ParseForm()
|
|
if err != nil {
|
|
s.logger.Warnf("Could not parse Device Token Request body: %v", err)
|
|
s.tokenErrHelper(w, errInvalidRequest, "", http.StatusBadRequest)
|
|
return
|
|
}
|
|
|
|
deviceCode := r.Form.Get("device_code")
|
|
if deviceCode == "" {
|
|
s.tokenErrHelper(w, errInvalidRequest, "No device code received", http.StatusBadRequest)
|
|
return
|
|
}
|
|
|
|
grantType := r.PostFormValue("grant_type")
|
|
if grantType != grantTypeDeviceCode {
|
|
s.tokenErrHelper(w, errInvalidGrant, "", http.StatusBadRequest)
|
|
return
|
|
}
|
|
|
|
now := s.now()
|
|
|
|
//Grab the device token, check validity
|
|
deviceToken, err := s.storage.GetDeviceToken(deviceCode)
|
|
if err != nil {
|
|
if err != storage.ErrNotFound {
|
|
s.logger.Errorf("failed to get device code: %v", err)
|
|
}
|
|
s.tokenErrHelper(w, errInvalidRequest, "Invalid Device code.", http.StatusBadRequest)
|
|
return
|
|
} else if now.After(deviceToken.Expiry) {
|
|
s.tokenErrHelper(w, deviceTokenExpired, "", http.StatusBadRequest)
|
|
return
|
|
}
|
|
|
|
//Rate Limiting check
|
|
slowDown := false
|
|
pollInterval := deviceToken.PollIntervalSeconds
|
|
minRequestTime := deviceToken.LastRequestTime.Add(time.Second * time.Duration(pollInterval))
|
|
if now.Before(minRequestTime) {
|
|
slowDown = true
|
|
//Continually increase the poll interval until the user waits the proper time
|
|
pollInterval += 5
|
|
} else {
|
|
pollInterval = 5
|
|
}
|
|
|
|
switch deviceToken.Status {
|
|
case deviceTokenPending:
|
|
updater := func(old storage.DeviceToken) (storage.DeviceToken, error) {
|
|
old.PollIntervalSeconds = pollInterval
|
|
old.LastRequestTime = now
|
|
return old, nil
|
|
}
|
|
// Update device token last request time in storage
|
|
if err := s.storage.UpdateDeviceToken(deviceCode, updater); err != nil {
|
|
s.logger.Errorf("failed to update device token: %v", err)
|
|
s.renderError(r, w, http.StatusInternalServerError, "")
|
|
return
|
|
}
|
|
if slowDown {
|
|
s.tokenErrHelper(w, deviceTokenSlowDown, "", http.StatusBadRequest)
|
|
} else {
|
|
s.tokenErrHelper(w, deviceTokenPending, "", http.StatusUnauthorized)
|
|
}
|
|
case deviceTokenComplete:
|
|
w.Write([]byte(deviceToken.Token))
|
|
}
|
|
default:
|
|
s.renderError(r, w, http.StatusBadRequest, "Requested resource does not exist.")
|
|
}
|
|
}
|
|
|
|
func (s *Server) handleDeviceCallback(w http.ResponseWriter, r *http.Request) {
|
|
switch r.Method {
|
|
case http.MethodGet:
|
|
userCode := r.FormValue("state")
|
|
code := r.FormValue("code")
|
|
|
|
if userCode == "" || code == "" {
|
|
s.renderError(r, w, http.StatusBadRequest, "Request was missing parameters")
|
|
return
|
|
}
|
|
|
|
// Authorization redirect callback from OAuth2 auth flow.
|
|
if errMsg := r.FormValue("error"); errMsg != "" {
|
|
http.Error(w, errMsg+": "+r.FormValue("error_description"), http.StatusBadRequest)
|
|
return
|
|
}
|
|
|
|
authCode, err := s.storage.GetAuthCode(code)
|
|
if err != nil || s.now().After(authCode.Expiry) {
|
|
errCode := http.StatusBadRequest
|
|
if err != nil && err != storage.ErrNotFound {
|
|
s.logger.Errorf("failed to get auth code: %v", err)
|
|
errCode = http.StatusInternalServerError
|
|
}
|
|
s.renderError(r, w, errCode, "Invalid or expired auth code.")
|
|
return
|
|
}
|
|
|
|
//Grab the device request from storage
|
|
deviceReq, err := s.storage.GetDeviceRequest(userCode)
|
|
if err != nil || s.now().After(deviceReq.Expiry) {
|
|
errCode := http.StatusBadRequest
|
|
if err != nil && err != storage.ErrNotFound {
|
|
s.logger.Errorf("failed to get device code: %v", err)
|
|
errCode = http.StatusInternalServerError
|
|
}
|
|
s.renderError(r, w, errCode, "Invalid or expired user code.")
|
|
return
|
|
}
|
|
|
|
client, err := s.storage.GetClient(deviceReq.ClientID)
|
|
if err != nil {
|
|
if err != storage.ErrNotFound {
|
|
s.logger.Errorf("failed to get client: %v", err)
|
|
s.tokenErrHelper(w, errServerError, "", http.StatusInternalServerError)
|
|
} else {
|
|
s.tokenErrHelper(w, errInvalidClient, "Invalid client credentials.", http.StatusUnauthorized)
|
|
}
|
|
return
|
|
}
|
|
if client.Secret != deviceReq.ClientSecret {
|
|
s.tokenErrHelper(w, errInvalidClient, "Invalid client credentials.", http.StatusUnauthorized)
|
|
return
|
|
}
|
|
|
|
resp, err := s.exchangeAuthCode(w, authCode, client)
|
|
if err != nil {
|
|
s.logger.Errorf("Could not exchange auth code for client %q: %v", deviceReq.ClientID, err)
|
|
s.renderError(r, w, http.StatusInternalServerError, "Failed to exchange auth code.")
|
|
return
|
|
}
|
|
|
|
//Grab the device token from storage
|
|
old, err := s.storage.GetDeviceToken(deviceReq.DeviceCode)
|
|
if err != nil || s.now().After(old.Expiry) {
|
|
errCode := http.StatusBadRequest
|
|
if err != nil && err != storage.ErrNotFound {
|
|
s.logger.Errorf("failed to get device token: %v", err)
|
|
errCode = http.StatusInternalServerError
|
|
}
|
|
s.renderError(r, w, errCode, "Invalid or expired device code.")
|
|
return
|
|
}
|
|
|
|
updater := func(old storage.DeviceToken) (storage.DeviceToken, error) {
|
|
if old.Status == deviceTokenComplete {
|
|
return old, errors.New("device token already complete")
|
|
}
|
|
respStr, err := json.MarshalIndent(resp, "", " ")
|
|
if err != nil {
|
|
s.logger.Errorf("failed to marshal device token response: %v", err)
|
|
s.renderError(r, w, http.StatusInternalServerError, "")
|
|
return old, err
|
|
}
|
|
|
|
old.Token = string(respStr)
|
|
old.Status = deviceTokenComplete
|
|
return old, nil
|
|
}
|
|
|
|
// Update refresh token in the storage, store the token and mark as complete
|
|
if err := s.storage.UpdateDeviceToken(deviceReq.DeviceCode, updater); err != nil {
|
|
s.logger.Errorf("failed to update device token: %v", err)
|
|
s.renderError(r, w, http.StatusBadRequest, "")
|
|
return
|
|
}
|
|
|
|
if err := s.templates.deviceSuccess(r, w, client.Name); err != nil {
|
|
s.logger.Errorf("Server template error: %v", err)
|
|
s.renderError(r, w, http.StatusNotFound, "Page not found")
|
|
}
|
|
|
|
default:
|
|
http.Error(w, fmt.Sprintf("method not implemented: %s", r.Method), http.StatusBadRequest)
|
|
return
|
|
}
|
|
}
|
|
|
|
func (s *Server) verifyUserCode(w http.ResponseWriter, r *http.Request) {
|
|
switch r.Method {
|
|
case http.MethodPost:
|
|
err := r.ParseForm()
|
|
if err != nil {
|
|
s.logger.Warnf("Could not parse user code verification request body : %v", err)
|
|
s.renderError(r, w, http.StatusBadRequest, "")
|
|
return
|
|
}
|
|
|
|
userCode := r.Form.Get("user_code")
|
|
if userCode == "" {
|
|
s.renderError(r, w, http.StatusBadRequest, "No user code received")
|
|
return
|
|
}
|
|
|
|
userCode = strings.ToUpper(userCode)
|
|
|
|
//Find the user code in the available requests
|
|
deviceRequest, err := s.storage.GetDeviceRequest(userCode)
|
|
if err != nil || s.now().After(deviceRequest.Expiry) {
|
|
if err != nil && err != storage.ErrNotFound {
|
|
s.logger.Errorf("failed to get device request: %v", err)
|
|
}
|
|
if err := s.templates.device(r, w, s.getDeviceVerificationURI(), userCode, true); err != nil {
|
|
s.logger.Errorf("Server template error: %v", err)
|
|
s.renderError(r, w, http.StatusNotFound, "Page not found")
|
|
}
|
|
return
|
|
}
|
|
|
|
//Redirect to Dex Auth Endpoint
|
|
authURL := path.Join(s.issuerURL.Path, "/auth")
|
|
u, err := url.Parse(authURL)
|
|
if err != nil {
|
|
s.renderError(r, w, http.StatusInternalServerError, "Invalid auth URI.")
|
|
return
|
|
}
|
|
q := u.Query()
|
|
q.Set("client_id", deviceRequest.ClientID)
|
|
q.Set("client_secret", deviceRequest.ClientSecret)
|
|
q.Set("state", deviceRequest.UserCode)
|
|
q.Set("response_type", "code")
|
|
q.Set("redirect_uri", path.Join(s.issuerURL.Path, "/device/callback"))
|
|
q.Set("scope", strings.Join(deviceRequest.Scopes, " "))
|
|
u.RawQuery = q.Encode()
|
|
|
|
http.Redirect(w, r, u.String(), http.StatusFound)
|
|
|
|
default:
|
|
s.renderError(r, w, http.StatusBadRequest, "Requested resource does not exist.")
|
|
}
|
|
}
|