Device flow token code exchange (#2)
* Added /device/token handler with associated business logic and storage tests. Perform user code exchange, flag the device code as complete. Moved device handler code into its own file for cleanliness. Cleanup * Removed PKCE code * Rate limiting for /device/token endpoint based on ietf standards * Configurable Device expiry Signed-off-by: justin-slowik <justin.slowik@thermofisher.com>
This commit is contained in:
committed by
justin-slowik
parent
0d1a0e4129
commit
9bbdc721d5
359
server/deviceHandlers.go
Normal file
359
server/deviceHandlers.go
Normal file
@@ -0,0 +1,359 @@
|
||||
package server
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"path"
|
||||
"strconv"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/dexidp/dex/storage"
|
||||
)
|
||||
|
||||
type deviceCodeResponse struct {
|
||||
//The unique device code for device authentication
|
||||
DeviceCode string `json:"device_code"`
|
||||
//The code the user will exchange via a browser and log in
|
||||
UserCode string `json:"user_code"`
|
||||
//The url to verify the user code.
|
||||
VerificationURI string `json:"verification_uri"`
|
||||
//The verification uri with the user code appended for pre-filling form
|
||||
VerificationURIComplete string `json:"verification_uri_complete"`
|
||||
//The lifetime of the device code
|
||||
ExpireTime int `json:"expires_in"`
|
||||
//How often the device is allowed to poll to verify that the user login occurred
|
||||
PollInterval int `json:"interval"`
|
||||
}
|
||||
|
||||
func (s *Server) getDeviceAuthURI() string {
|
||||
return path.Join(s.issuerURL.Path, "/device/auth/verify_code")
|
||||
}
|
||||
|
||||
func (s *Server) handleDeviceExchange(w http.ResponseWriter, r *http.Request) {
|
||||
switch r.Method {
|
||||
case http.MethodGet:
|
||||
userCode := r.URL.Query().Get("user_code")
|
||||
invalidAttempt, err := strconv.ParseBool(r.URL.Query().Get("invalid"))
|
||||
if err != nil {
|
||||
invalidAttempt = false
|
||||
}
|
||||
if err := s.templates.device(r, w, s.getDeviceAuthURI(), userCode, invalidAttempt); err != nil {
|
||||
s.logger.Errorf("Server template error: %v", err)
|
||||
}
|
||||
default:
|
||||
s.renderError(r, w, http.StatusBadRequest, "Requested resource does not exist.")
|
||||
}
|
||||
}
|
||||
|
||||
func (s *Server) handleDeviceCode(w http.ResponseWriter, r *http.Request) {
|
||||
pollIntervalSeconds := 5
|
||||
|
||||
switch r.Method {
|
||||
case http.MethodPost:
|
||||
err := r.ParseForm()
|
||||
if err != nil {
|
||||
s.logger.Errorf("Could not parse Device Request body: %v", err)
|
||||
s.tokenErrHelper(w, errInvalidRequest, "", http.StatusNotFound)
|
||||
return
|
||||
}
|
||||
|
||||
//Get the client id and scopes from the post
|
||||
clientID := r.Form.Get("client_id")
|
||||
scopes := r.Form["scope"]
|
||||
|
||||
s.logger.Infof("Received device request for client %v with scopes %v", clientID, scopes)
|
||||
|
||||
//Make device code
|
||||
deviceCode := storage.NewDeviceCode()
|
||||
|
||||
//make user code
|
||||
userCode, err := storage.NewUserCode()
|
||||
if err != nil {
|
||||
s.logger.Errorf("Error generating user code: %v", err)
|
||||
s.tokenErrHelper(w, errInvalidRequest, "", http.StatusInternalServerError)
|
||||
}
|
||||
|
||||
//Generate the expire time
|
||||
expireTime := time.Now().Add(s.deviceRequestsValidFor)
|
||||
|
||||
//Store the Device Request
|
||||
deviceReq := storage.DeviceRequest{
|
||||
UserCode: userCode,
|
||||
DeviceCode: deviceCode,
|
||||
ClientID: clientID,
|
||||
Scopes: scopes,
|
||||
Expiry: expireTime,
|
||||
}
|
||||
|
||||
if err := s.storage.CreateDeviceRequest(deviceReq); err != nil {
|
||||
s.logger.Errorf("Failed to store device request; %v", err)
|
||||
s.tokenErrHelper(w, errInvalidRequest, "", http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
|
||||
//Store the device token
|
||||
deviceToken := storage.DeviceToken{
|
||||
DeviceCode: deviceCode,
|
||||
Status: deviceTokenPending,
|
||||
Expiry: expireTime,
|
||||
LastRequestTime: time.Now(),
|
||||
PollIntervalSeconds: 5,
|
||||
}
|
||||
|
||||
if err := s.storage.CreateDeviceToken(deviceToken); err != nil {
|
||||
s.logger.Errorf("Failed to store device token %v", err)
|
||||
s.tokenErrHelper(w, errInvalidRequest, "", http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
|
||||
u, err := url.Parse(s.issuerURL.String())
|
||||
if err != nil {
|
||||
s.logger.Errorf("Could not parse issuer URL %v", err)
|
||||
s.renderError(r, w, http.StatusInternalServerError, "")
|
||||
return
|
||||
}
|
||||
u.Path = path.Join(u.Path, "device")
|
||||
vURI := u.String()
|
||||
|
||||
q := u.Query()
|
||||
q.Set("user_code", userCode)
|
||||
u.RawQuery = q.Encode()
|
||||
vURIComplete := u.String()
|
||||
|
||||
code := deviceCodeResponse{
|
||||
DeviceCode: deviceCode,
|
||||
UserCode: userCode,
|
||||
VerificationURI: vURI,
|
||||
VerificationURIComplete: vURIComplete,
|
||||
ExpireTime: int(s.deviceRequestsValidFor.Seconds()),
|
||||
PollInterval: pollIntervalSeconds,
|
||||
}
|
||||
|
||||
enc := json.NewEncoder(w)
|
||||
enc.SetIndent("", " ")
|
||||
enc.Encode(code)
|
||||
|
||||
default:
|
||||
s.renderError(r, w, http.StatusBadRequest, "Invalid device code request type")
|
||||
s.tokenErrHelper(w, errInvalidRequest, "", http.StatusBadRequest)
|
||||
}
|
||||
}
|
||||
|
||||
func (s *Server) handleDeviceToken(w http.ResponseWriter, r *http.Request) {
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
switch r.Method {
|
||||
case http.MethodPost:
|
||||
err := r.ParseForm()
|
||||
if err != nil {
|
||||
s.logger.Warnf("Could not parse Device Token Request body: %v", err)
|
||||
s.tokenErrHelper(w, errInvalidRequest, "", http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
|
||||
deviceCode := r.Form.Get("device_code")
|
||||
if deviceCode == "" {
|
||||
s.tokenErrHelper(w, errInvalidRequest, "No device code received", http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
|
||||
grantType := r.PostFormValue("grant_type")
|
||||
if grantType != grantTypeDeviceCode {
|
||||
s.tokenErrHelper(w, errInvalidGrant, "", http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
|
||||
now := s.now()
|
||||
|
||||
//Grab the device token
|
||||
deviceToken, err := s.storage.GetDeviceToken(deviceCode)
|
||||
if err != nil || now.After(deviceToken.Expiry) {
|
||||
if err != storage.ErrNotFound {
|
||||
s.logger.Errorf("failed to get device code: %v", err)
|
||||
}
|
||||
s.tokenErrHelper(w, errInvalidRequest, "Invalid or expired device code.", http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
|
||||
//Rate Limiting check
|
||||
pollInterval := deviceToken.PollIntervalSeconds
|
||||
minRequestTime := deviceToken.LastRequestTime.Add(time.Second * time.Duration(pollInterval))
|
||||
if now.Before(minRequestTime) {
|
||||
s.tokenErrHelper(w, deviceTokenSlowDown, "", http.StatusBadRequest)
|
||||
//Continually increase the poll interval until the user waits the proper time
|
||||
pollInterval += 5
|
||||
} else {
|
||||
pollInterval = 5
|
||||
}
|
||||
|
||||
switch deviceToken.Status {
|
||||
case deviceTokenPending:
|
||||
updater := func(old storage.DeviceToken) (storage.DeviceToken, error) {
|
||||
old.PollIntervalSeconds = pollInterval
|
||||
old.LastRequestTime = now
|
||||
return old, nil
|
||||
}
|
||||
// Update device token last request time in storage
|
||||
if err := s.storage.UpdateDeviceToken(deviceCode, updater); err != nil {
|
||||
s.logger.Errorf("failed to update device token: %v", err)
|
||||
s.renderError(r, w, http.StatusInternalServerError, "")
|
||||
return
|
||||
}
|
||||
s.tokenErrHelper(w, deviceTokenPending, "", http.StatusUnauthorized)
|
||||
case deviceTokenComplete:
|
||||
w.Write([]byte(deviceToken.Token))
|
||||
}
|
||||
default:
|
||||
s.renderError(r, w, http.StatusBadRequest, "Requested resource does not exist.")
|
||||
}
|
||||
}
|
||||
|
||||
func (s *Server) handleDeviceCallback(w http.ResponseWriter, r *http.Request) {
|
||||
switch r.Method {
|
||||
case http.MethodGet:
|
||||
userCode := r.FormValue("state")
|
||||
code := r.FormValue("code")
|
||||
|
||||
if userCode == "" || code == "" {
|
||||
s.renderError(r, w, http.StatusBadRequest, "Request was missing parameters")
|
||||
return
|
||||
}
|
||||
|
||||
// Authorization redirect callback from OAuth2 auth flow.
|
||||
if errMsg := r.FormValue("error"); errMsg != "" {
|
||||
http.Error(w, errMsg+": "+r.FormValue("error_description"), http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
|
||||
authCode, err := s.storage.GetAuthCode(code)
|
||||
if err != nil || s.now().After(authCode.Expiry) {
|
||||
if err != storage.ErrNotFound {
|
||||
s.logger.Errorf("failed to get auth code: %v", err)
|
||||
}
|
||||
s.renderError(r, w, http.StatusBadRequest, "Invalid or expired auth code.")
|
||||
return
|
||||
}
|
||||
|
||||
//Grab the device request from storage
|
||||
deviceReq, err := s.storage.GetDeviceRequest(userCode)
|
||||
if err != nil || s.now().After(deviceReq.Expiry) {
|
||||
if err != storage.ErrNotFound {
|
||||
s.logger.Errorf("failed to get device code: %v", err)
|
||||
}
|
||||
s.renderError(r, w, http.StatusInternalServerError, "Invalid or expired device code.")
|
||||
return
|
||||
}
|
||||
|
||||
reqClient, err := s.storage.GetClient(deviceReq.ClientID)
|
||||
if err != nil {
|
||||
s.logger.Errorf("Failed to get reqClient %q: %v", deviceReq.ClientID, err)
|
||||
s.renderError(r, w, http.StatusInternalServerError, "Failed to retrieve device client.")
|
||||
return
|
||||
}
|
||||
|
||||
resp, err := s.exchangeAuthCode(w, authCode, reqClient)
|
||||
if err != nil {
|
||||
s.logger.Errorf("Could not exchange auth code for client %q: %v", deviceReq.ClientID, err)
|
||||
s.renderError(r, w, http.StatusInternalServerError, "Failed to exchange auth code.")
|
||||
return
|
||||
}
|
||||
|
||||
//Grab the device request from storage
|
||||
old, err := s.storage.GetDeviceToken(deviceReq.DeviceCode)
|
||||
if err != nil || s.now().After(old.Expiry) {
|
||||
if err != storage.ErrNotFound {
|
||||
s.logger.Errorf("failed to get device token: %v", err)
|
||||
}
|
||||
s.renderError(r, w, http.StatusInternalServerError, "Invalid or expired device code.")
|
||||
return
|
||||
}
|
||||
|
||||
updater := func(old storage.DeviceToken) (storage.DeviceToken, error) {
|
||||
if old.Status == deviceTokenComplete {
|
||||
return old, errors.New("device token already complete")
|
||||
}
|
||||
respStr, err := json.MarshalIndent(resp, "", " ")
|
||||
if err != nil {
|
||||
s.logger.Errorf("failed to marshal device token response: %v", err)
|
||||
s.renderError(r, w, http.StatusInternalServerError, "")
|
||||
return old, err
|
||||
}
|
||||
|
||||
old.Token = string(respStr)
|
||||
old.Status = deviceTokenComplete
|
||||
return old, nil
|
||||
}
|
||||
|
||||
// Update refresh token in the storage, store the token and mark as complete
|
||||
if err := s.storage.UpdateDeviceToken(deviceReq.DeviceCode, updater); err != nil {
|
||||
s.logger.Errorf("failed to update device token: %v", err)
|
||||
s.renderError(r, w, http.StatusInternalServerError, "")
|
||||
return
|
||||
}
|
||||
|
||||
if err := s.templates.deviceSuccess(r, w, reqClient.Name); err != nil {
|
||||
s.logger.Errorf("Server template error: %v", err)
|
||||
}
|
||||
|
||||
default:
|
||||
http.Error(w, fmt.Sprintf("method not implemented: %s", r.Method), http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
func (s *Server) verifyUserCode(w http.ResponseWriter, r *http.Request) {
|
||||
switch r.Method {
|
||||
case http.MethodPost:
|
||||
err := r.ParseForm()
|
||||
if err != nil {
|
||||
message := "Could not parse user code verification Request body"
|
||||
s.logger.Warnf("%s : %v", message, err)
|
||||
s.tokenErrHelper(w, errInvalidRequest, message, http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
|
||||
userCode := r.Form.Get("user_code")
|
||||
if userCode == "" {
|
||||
s.renderError(r, w, http.StatusBadRequest, "No user code received")
|
||||
return
|
||||
}
|
||||
|
||||
userCode = strings.ToUpper(userCode)
|
||||
|
||||
//Find the user code in the available requests
|
||||
deviceRequest, err := s.storage.GetDeviceRequest(userCode)
|
||||
if err != nil || s.now().After(deviceRequest.Expiry) {
|
||||
if err != storage.ErrNotFound {
|
||||
s.logger.Errorf("failed to get device request: %v", err)
|
||||
s.tokenErrHelper(w, errServerError, "", http.StatusInternalServerError)
|
||||
}
|
||||
if err := s.templates.device(r, w, s.getDeviceAuthURI(), userCode, true); err != nil {
|
||||
s.logger.Errorf("Server template error: %v", err)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
//Redirect to Dex Auth Endpoint
|
||||
authURL := path.Join(s.issuerURL.Path, "/auth")
|
||||
u, err := url.Parse(authURL)
|
||||
if err != nil {
|
||||
s.renderError(r, w, http.StatusInternalServerError, "Invalid auth URI.")
|
||||
return
|
||||
}
|
||||
q := u.Query()
|
||||
q.Set("client_id", deviceRequest.ClientID)
|
||||
q.Set("state", deviceRequest.UserCode)
|
||||
q.Set("response_type", "code")
|
||||
q.Set("redirect_uri", path.Join(s.issuerURL.Path, "/device/callback"))
|
||||
q.Set("scope", strings.Join(deviceRequest.Scopes, " "))
|
||||
u.RawQuery = q.Encode()
|
||||
|
||||
http.Redirect(w, r, u.String(), http.StatusFound)
|
||||
|
||||
default:
|
||||
s.renderError(r, w, http.StatusBadRequest, "Requested resource does not exist.")
|
||||
}
|
||||
}
|
@@ -147,30 +147,34 @@ func (s *Server) handlePublicKeys(w http.ResponseWriter, r *http.Request) {
|
||||
}
|
||||
|
||||
type discovery struct {
|
||||
Issuer string `json:"issuer"`
|
||||
Auth string `json:"authorization_endpoint"`
|
||||
Token string `json:"token_endpoint"`
|
||||
Keys string `json:"jwks_uri"`
|
||||
UserInfo string `json:"userinfo_endpoint"`
|
||||
ResponseTypes []string `json:"response_types_supported"`
|
||||
Subjects []string `json:"subject_types_supported"`
|
||||
IDTokenAlgs []string `json:"id_token_signing_alg_values_supported"`
|
||||
Scopes []string `json:"scopes_supported"`
|
||||
AuthMethods []string `json:"token_endpoint_auth_methods_supported"`
|
||||
Claims []string `json:"claims_supported"`
|
||||
Issuer string `json:"issuer"`
|
||||
Auth string `json:"authorization_endpoint"`
|
||||
Token string `json:"token_endpoint"`
|
||||
Keys string `json:"jwks_uri"`
|
||||
UserInfo string `json:"userinfo_endpoint"`
|
||||
DeviceEndpoint string `json:"device_authorization_endpoint"`
|
||||
GrantTypes []string `json:"grant_types_supported"'`
|
||||
ResponseTypes []string `json:"response_types_supported"`
|
||||
Subjects []string `json:"subject_types_supported"`
|
||||
IDTokenAlgs []string `json:"id_token_signing_alg_values_supported"`
|
||||
Scopes []string `json:"scopes_supported"`
|
||||
AuthMethods []string `json:"token_endpoint_auth_methods_supported"`
|
||||
Claims []string `json:"claims_supported"`
|
||||
}
|
||||
|
||||
func (s *Server) discoveryHandler() (http.HandlerFunc, error) {
|
||||
d := discovery{
|
||||
Issuer: s.issuerURL.String(),
|
||||
Auth: s.absURL("/auth"),
|
||||
Token: s.absURL("/token"),
|
||||
Keys: s.absURL("/keys"),
|
||||
UserInfo: s.absURL("/userinfo"),
|
||||
Subjects: []string{"public"},
|
||||
IDTokenAlgs: []string{string(jose.RS256)},
|
||||
Scopes: []string{"openid", "email", "groups", "profile", "offline_access"},
|
||||
AuthMethods: []string{"client_secret_basic"},
|
||||
Issuer: s.issuerURL.String(),
|
||||
Auth: s.absURL("/auth"),
|
||||
Token: s.absURL("/token"),
|
||||
Keys: s.absURL("/keys"),
|
||||
UserInfo: s.absURL("/userinfo"),
|
||||
DeviceEndpoint: s.absURL("/device/code"),
|
||||
Subjects: []string{"public"},
|
||||
GrantTypes: []string{grantTypeAuthorizationCode, grantTypeRefreshToken, grantTypeDeviceCode},
|
||||
IDTokenAlgs: []string{string(jose.RS256)},
|
||||
Scopes: []string{"openid", "email", "groups", "profile", "offline_access"},
|
||||
AuthMethods: []string{"client_secret_basic"},
|
||||
Claims: []string{
|
||||
"aud", "email", "email_verified", "exp",
|
||||
"iat", "iss", "locale", "name", "sub",
|
||||
@@ -783,24 +787,33 @@ func (s *Server) handleAuthCode(w http.ResponseWriter, r *http.Request, client s
|
||||
return
|
||||
}
|
||||
|
||||
tokenResponse, err := s.exchangeAuthCode(w, authCode, client)
|
||||
if err != nil {
|
||||
s.tokenErrHelper(w, errServerError, "", http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
s.writeAccessToken(w, tokenResponse)
|
||||
}
|
||||
|
||||
func (s *Server) exchangeAuthCode(w http.ResponseWriter, authCode storage.AuthCode, client storage.Client) (*accessTokenReponse, error) {
|
||||
accessToken, err := s.newAccessToken(client.ID, authCode.Claims, authCode.Scopes, authCode.Nonce, authCode.ConnectorID)
|
||||
if err != nil {
|
||||
s.logger.Errorf("failed to create new access token: %v", err)
|
||||
s.tokenErrHelper(w, errServerError, "", http.StatusInternalServerError)
|
||||
return
|
||||
return nil, err
|
||||
}
|
||||
|
||||
idToken, expiry, err := s.newIDToken(client.ID, authCode.Claims, authCode.Scopes, authCode.Nonce, accessToken, authCode.ConnectorID)
|
||||
if err != nil {
|
||||
s.logger.Errorf("failed to create ID token: %v", err)
|
||||
s.tokenErrHelper(w, errServerError, "", http.StatusInternalServerError)
|
||||
return
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if err := s.storage.DeleteAuthCode(code); err != nil {
|
||||
if err := s.storage.DeleteAuthCode(authCode.ID); err != nil {
|
||||
s.logger.Errorf("failed to delete auth code: %v", err)
|
||||
s.tokenErrHelper(w, errServerError, "", http.StatusInternalServerError)
|
||||
return
|
||||
return nil, err
|
||||
}
|
||||
|
||||
reqRefresh := func() bool {
|
||||
@@ -847,13 +860,13 @@ func (s *Server) handleAuthCode(w http.ResponseWriter, r *http.Request, client s
|
||||
if refreshToken, err = internal.Marshal(token); err != nil {
|
||||
s.logger.Errorf("failed to marshal refresh token: %v", err)
|
||||
s.tokenErrHelper(w, errServerError, "", http.StatusInternalServerError)
|
||||
return
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if err := s.storage.CreateRefresh(refresh); err != nil {
|
||||
s.logger.Errorf("failed to create refresh token: %v", err)
|
||||
s.tokenErrHelper(w, errServerError, "", http.StatusInternalServerError)
|
||||
return
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// deleteToken determines if we need to delete the newly created refresh token
|
||||
@@ -884,7 +897,7 @@ func (s *Server) handleAuthCode(w http.ResponseWriter, r *http.Request, client s
|
||||
s.logger.Errorf("failed to get offline session: %v", err)
|
||||
s.tokenErrHelper(w, errServerError, "", http.StatusInternalServerError)
|
||||
deleteToken = true
|
||||
return
|
||||
return nil, err
|
||||
}
|
||||
offlineSessions := storage.OfflineSessions{
|
||||
UserID: refresh.Claims.UserID,
|
||||
@@ -899,7 +912,7 @@ func (s *Server) handleAuthCode(w http.ResponseWriter, r *http.Request, client s
|
||||
s.logger.Errorf("failed to create offline session: %v", err)
|
||||
s.tokenErrHelper(w, errServerError, "", http.StatusInternalServerError)
|
||||
deleteToken = true
|
||||
return
|
||||
return nil, err
|
||||
}
|
||||
} else {
|
||||
if oldTokenRef, ok := session.Refresh[tokenRef.ClientID]; ok {
|
||||
@@ -908,7 +921,7 @@ func (s *Server) handleAuthCode(w http.ResponseWriter, r *http.Request, client s
|
||||
s.logger.Errorf("failed to delete refresh token: %v", err)
|
||||
s.tokenErrHelper(w, errServerError, "", http.StatusInternalServerError)
|
||||
deleteToken = true
|
||||
return
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
|
||||
@@ -920,11 +933,11 @@ func (s *Server) handleAuthCode(w http.ResponseWriter, r *http.Request, client s
|
||||
s.logger.Errorf("failed to update offline session: %v", err)
|
||||
s.tokenErrHelper(w, errServerError, "", http.StatusInternalServerError)
|
||||
deleteToken = true
|
||||
return
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
}
|
||||
s.writeAccessToken(w, idToken, accessToken, refreshToken, expiry)
|
||||
return s.toAccessTokenResponse(idToken, accessToken, refreshToken, expiry), nil
|
||||
}
|
||||
|
||||
// handle a refresh token request https://tools.ietf.org/html/rfc6749#section-6
|
||||
@@ -1120,7 +1133,8 @@ func (s *Server) handleRefreshToken(w http.ResponseWriter, r *http.Request, clie
|
||||
return
|
||||
}
|
||||
|
||||
s.writeAccessToken(w, idToken, accessToken, rawNewToken, expiry)
|
||||
resp := s.toAccessTokenResponse(idToken, accessToken, rawNewToken, expiry)
|
||||
s.writeAccessToken(w, resp)
|
||||
}
|
||||
|
||||
func (s *Server) handleUserInfo(w http.ResponseWriter, r *http.Request) {
|
||||
@@ -1378,12 +1392,26 @@ func (s *Server) writeAccessToken(w http.ResponseWriter, idToken, accessToken, r
|
||||
RefreshToken string `json:"refresh_token,omitempty"`
|
||||
IDToken string `json:"id_token"`
|
||||
}{
|
||||
|
||||
type accessTokenReponse struct {
|
||||
AccessToken string `json:"access_token"`
|
||||
TokenType string `json:"token_type"`
|
||||
ExpiresIn int `json:"expires_in"`
|
||||
RefreshToken string `json:"refresh_token,omitempty"`
|
||||
IDToken string `json:"id_token"`
|
||||
}
|
||||
|
||||
func (s *Server) toAccessTokenResponse(idToken, accessToken, refreshToken string, expiry time.Time) *accessTokenReponse {
|
||||
return &accessTokenReponse{
|
||||
accessToken,
|
||||
"bearer",
|
||||
int(expiry.Sub(s.now()).Seconds()),
|
||||
refreshToken,
|
||||
idToken,
|
||||
}
|
||||
}
|
||||
|
||||
func (s *Server) writeAccessToken(w http.ResponseWriter, resp *accessTokenReponse) {
|
||||
data, err := json.Marshal(resp)
|
||||
if err != nil {
|
||||
s.logger.Errorf("failed to marshal access token response: %v", err)
|
||||
@@ -1414,145 +1442,3 @@ func usernamePrompt(conn connector.PasswordConnector) string {
|
||||
}
|
||||
return "Username"
|
||||
}
|
||||
|
||||
type deviceCodeResponse struct {
|
||||
//The unique device code for device authentication
|
||||
DeviceCode string `json:"device_code"`
|
||||
//The code the user will exchange via a browser and log in
|
||||
UserCode string `json:"user_code"`
|
||||
//The url to verify the user code.
|
||||
VerificationURI string `json:"verification_uri"`
|
||||
//The lifetime of the device code
|
||||
ExpireTime int `json:"expires_in"`
|
||||
//How often the device is allowed to poll to verify that the user login occurred
|
||||
PollInterval int `json:"interval"`
|
||||
}
|
||||
|
||||
func (s *Server) handleDeviceCode(w http.ResponseWriter, r *http.Request) {
|
||||
//TODO replace with configurable values
|
||||
expireIntervalSeconds := 300
|
||||
requestsPerMinute := 5
|
||||
|
||||
switch r.Method {
|
||||
case http.MethodPost:
|
||||
err := r.ParseForm()
|
||||
if err != nil {
|
||||
s.logger.Errorf("Could not parse Device Request body: %v", err)
|
||||
s.tokenErrHelper(w, errInvalidRequest, "", http.StatusNotFound)
|
||||
return
|
||||
}
|
||||
|
||||
//Get the client id and scopes from the post
|
||||
clientID := r.Form.Get("client_id")
|
||||
scopes := r.Form["scope"]
|
||||
|
||||
s.logger.Infof("Received device request for client %v with scopes %v", clientID, scopes)
|
||||
|
||||
//Make device code
|
||||
deviceCode := storage.NewDeviceCode()
|
||||
|
||||
//make user code
|
||||
userCode, err := storage.NewUserCode()
|
||||
if err != nil {
|
||||
s.logger.Errorf("Error generating user code: %v", err)
|
||||
s.tokenErrHelper(w, errInvalidRequest, "", http.StatusInternalServerError)
|
||||
}
|
||||
|
||||
//make a pkce verification code
|
||||
pkceCode := storage.NewID()
|
||||
|
||||
//Generate the expire time
|
||||
expireTime := time.Now().Add(time.Second * time.Duration(expireIntervalSeconds))
|
||||
|
||||
//Store the Device Request
|
||||
deviceReq := storage.DeviceRequest{
|
||||
UserCode: userCode,
|
||||
DeviceCode: deviceCode,
|
||||
ClientID: clientID,
|
||||
Scopes: scopes,
|
||||
PkceVerifier: pkceCode,
|
||||
Expiry: expireTime,
|
||||
}
|
||||
|
||||
if err := s.storage.CreateDeviceRequest(deviceReq); err != nil {
|
||||
s.logger.Errorf("Failed to store device request; %v", err)
|
||||
s.tokenErrHelper(w, errInvalidRequest, "", http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
|
||||
//Store the device token
|
||||
deviceToken := storage.DeviceToken{
|
||||
DeviceCode: deviceCode,
|
||||
Status: deviceTokenPending,
|
||||
Expiry: expireTime,
|
||||
}
|
||||
|
||||
if err := s.storage.CreateDeviceToken(deviceToken); err != nil {
|
||||
s.logger.Errorf("Failed to store device token %v", err)
|
||||
s.tokenErrHelper(w, errInvalidRequest, "", http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
|
||||
code := deviceCodeResponse{
|
||||
DeviceCode: deviceCode,
|
||||
UserCode: userCode,
|
||||
VerificationURI: path.Join(s.issuerURL.String(), "/device"),
|
||||
ExpireTime: expireIntervalSeconds,
|
||||
PollInterval: requestsPerMinute,
|
||||
}
|
||||
|
||||
enc := json.NewEncoder(w)
|
||||
enc.SetIndent("", " ")
|
||||
enc.Encode(code)
|
||||
|
||||
default:
|
||||
s.renderError(r, w, http.StatusBadRequest, "Invalid device code request type")
|
||||
s.tokenErrHelper(w, errInvalidRequest, "", http.StatusBadRequest)
|
||||
}
|
||||
}
|
||||
|
||||
func (s *Server) handleDeviceToken(w http.ResponseWriter, r *http.Request) {
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
switch r.Method {
|
||||
case http.MethodPost:
|
||||
err := r.ParseForm()
|
||||
if err != nil {
|
||||
message := "Could not parse Device Token Request body"
|
||||
s.logger.Warnf("%s : %v", message, err)
|
||||
s.tokenErrHelper(w, errInvalidRequest, "", http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
|
||||
deviceCode := r.Form.Get("device_code")
|
||||
if deviceCode == "" {
|
||||
message := "No device code received"
|
||||
s.tokenErrHelper(w, errInvalidRequest, message, http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
|
||||
grantType := r.PostFormValue("grant_type")
|
||||
if grantType != grantTypeDeviceCode {
|
||||
s.tokenErrHelper(w, errInvalidGrant, "", http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
|
||||
//Grab the device token from the db
|
||||
deviceToken, err := s.storage.GetDeviceToken(deviceCode)
|
||||
if err != nil || s.now().After(deviceToken.Expiry) {
|
||||
if err != storage.ErrNotFound {
|
||||
s.logger.Errorf("failed to get device code: %v", err)
|
||||
}
|
||||
s.tokenErrHelper(w, errInvalidRequest, "Invalid or expired device code.", http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
|
||||
switch deviceToken.Status {
|
||||
case deviceTokenPending:
|
||||
s.tokenErrHelper(w, deviceTokenPending, "", http.StatusUnauthorized)
|
||||
case deviceTokenComplete:
|
||||
w.Write([]byte(deviceToken.Token))
|
||||
}
|
||||
default:
|
||||
s.renderError(r, w, http.StatusBadRequest, "Requested resource does not exist.")
|
||||
}
|
||||
}
|
||||
|
@@ -122,7 +122,7 @@ const (
|
||||
grantTypeAuthorizationCode = "authorization_code"
|
||||
grantTypeRefreshToken = "refresh_token"
|
||||
grantTypePassword = "password"
|
||||
grantTypeDeviceCode = "device_code"
|
||||
grantTypeDeviceCode = "urn:ietf:params:oauth:grant-type:device_code"
|
||||
)
|
||||
|
||||
const (
|
||||
@@ -134,6 +134,8 @@ const (
|
||||
const (
|
||||
deviceTokenPending = "authorization_pending"
|
||||
deviceTokenComplete = "complete"
|
||||
deviceTokenSlowDown = "slow_down"
|
||||
deviceTokenExpired = "expired_token"
|
||||
)
|
||||
|
||||
func parseScopes(scopes []string) connector.Scopes {
|
||||
|
@@ -75,12 +75,13 @@ type Config struct {
|
||||
// If enabled, the connectors selection page will always be shown even if there's only one
|
||||
AlwaysShowLoginScreen bool
|
||||
|
||||
RotateKeysAfter time.Duration // Defaults to 6 hours.
|
||||
IDTokensValidFor time.Duration // Defaults to 24 hours
|
||||
AuthRequestsValidFor time.Duration // Defaults to 24 hours
|
||||
RotateKeysAfter time.Duration // Defaults to 6 hours.
|
||||
IDTokensValidFor time.Duration // Defaults to 24 hours
|
||||
AuthRequestsValidFor time.Duration // Defaults to 24 hours
|
||||
DeviceRequestsValidFor time.Duration // Defaults to 5 minutes
|
||||
// If set, the server will use this connector to handle password grants
|
||||
PasswordConnector string
|
||||
|
||||
|
||||
GCFrequency time.Duration // Defaults to 5 minutes
|
||||
|
||||
// If specified, the server will use this function for determining time.
|
||||
@@ -156,8 +157,9 @@ type Server struct {
|
||||
|
||||
now func() time.Time
|
||||
|
||||
idTokensValidFor time.Duration
|
||||
authRequestsValidFor time.Duration
|
||||
idTokensValidFor time.Duration
|
||||
authRequestsValidFor time.Duration
|
||||
deviceRequestsValidFor time.Duration
|
||||
|
||||
logger log.Logger
|
||||
}
|
||||
@@ -219,6 +221,7 @@ func newServer(ctx context.Context, c Config, rotationStrategy rotationStrategy)
|
||||
supportedResponseTypes: supported,
|
||||
idTokensValidFor: value(c.IDTokensValidFor, 24*time.Hour),
|
||||
authRequestsValidFor: value(c.AuthRequestsValidFor, 24*time.Hour),
|
||||
deviceRequestsValidFor: value(c.DeviceRequestsValidFor, 5*time.Minute),
|
||||
skipApproval: c.SkipApprovalScreen,
|
||||
alwaysShowLogin: c.AlwaysShowLoginScreen,
|
||||
now: now,
|
||||
@@ -302,8 +305,11 @@ func newServer(ctx context.Context, c Config, rotationStrategy rotationStrategy)
|
||||
handleWithCORS("/userinfo", s.handleUserInfo)
|
||||
handleFunc("/auth", s.handleAuthorization)
|
||||
handleFunc("/auth/{connector}", s.handleConnectorLogin)
|
||||
handleFunc("/device", s.handleDeviceExchange)
|
||||
handleFunc("/device/auth/verify_code", s.verifyUserCode)
|
||||
handleFunc("/device/code", s.handleDeviceCode)
|
||||
handleFunc("/device/token", s.handleDeviceToken)
|
||||
handleFunc("/device/callback", s.handleDeviceCallback)
|
||||
r.HandleFunc(path.Join(issuerURL.Path, "/callback"), func(w http.ResponseWriter, r *http.Request) {
|
||||
// Strip the X-Remote-* headers to prevent security issues on
|
||||
// misconfigured authproxy connector setups.
|
||||
|
@@ -15,11 +15,13 @@ import (
|
||||
)
|
||||
|
||||
const (
|
||||
tmplApproval = "approval.html"
|
||||
tmplLogin = "login.html"
|
||||
tmplPassword = "password.html"
|
||||
tmplOOB = "oob.html"
|
||||
tmplError = "error.html"
|
||||
tmplApproval = "approval.html"
|
||||
tmplLogin = "login.html"
|
||||
tmplPassword = "password.html"
|
||||
tmplOOB = "oob.html"
|
||||
tmplError = "error.html"
|
||||
tmplDevice = "device.html"
|
||||
tmplDeviceSuccess = "device_success.html"
|
||||
)
|
||||
|
||||
var requiredTmpls = []string{
|
||||
@@ -28,14 +30,17 @@ var requiredTmpls = []string{
|
||||
tmplPassword,
|
||||
tmplOOB,
|
||||
tmplError,
|
||||
tmplDevice,
|
||||
}
|
||||
|
||||
type templates struct {
|
||||
loginTmpl *template.Template
|
||||
approvalTmpl *template.Template
|
||||
passwordTmpl *template.Template
|
||||
oobTmpl *template.Template
|
||||
errorTmpl *template.Template
|
||||
loginTmpl *template.Template
|
||||
approvalTmpl *template.Template
|
||||
passwordTmpl *template.Template
|
||||
oobTmpl *template.Template
|
||||
errorTmpl *template.Template
|
||||
deviceTmpl *template.Template
|
||||
deviceSuccessTmpl *template.Template
|
||||
}
|
||||
|
||||
type webConfig struct {
|
||||
@@ -152,11 +157,13 @@ func loadTemplates(c webConfig, templatesDir string) (*templates, error) {
|
||||
return nil, fmt.Errorf("missing template(s): %s", missingTmpls)
|
||||
}
|
||||
return &templates{
|
||||
loginTmpl: tmpls.Lookup(tmplLogin),
|
||||
approvalTmpl: tmpls.Lookup(tmplApproval),
|
||||
passwordTmpl: tmpls.Lookup(tmplPassword),
|
||||
oobTmpl: tmpls.Lookup(tmplOOB),
|
||||
errorTmpl: tmpls.Lookup(tmplError),
|
||||
loginTmpl: tmpls.Lookup(tmplLogin),
|
||||
approvalTmpl: tmpls.Lookup(tmplApproval),
|
||||
passwordTmpl: tmpls.Lookup(tmplPassword),
|
||||
oobTmpl: tmpls.Lookup(tmplOOB),
|
||||
errorTmpl: tmpls.Lookup(tmplError),
|
||||
deviceTmpl: tmpls.Lookup(tmplDevice),
|
||||
deviceSuccessTmpl: tmpls.Lookup(tmplDeviceSuccess),
|
||||
}, nil
|
||||
}
|
||||
|
||||
@@ -242,6 +249,24 @@ func (n byName) Len() int { return len(n) }
|
||||
func (n byName) Less(i, j int) bool { return n[i].Name < n[j].Name }
|
||||
func (n byName) Swap(i, j int) { n[i], n[j] = n[j], n[i] }
|
||||
|
||||
func (t *templates) device(r *http.Request, w http.ResponseWriter, postURL string, userCode string, lastWasInvalid bool) error {
|
||||
data := struct {
|
||||
PostURL string
|
||||
UserCode string
|
||||
Invalid bool
|
||||
ReqPath string
|
||||
}{postURL, userCode, lastWasInvalid, r.URL.Path}
|
||||
return renderTemplate(w, t.deviceTmpl, data)
|
||||
}
|
||||
|
||||
func (t *templates) deviceSuccess(r *http.Request, w http.ResponseWriter, clientName string) error {
|
||||
data := struct {
|
||||
ClientName string
|
||||
ReqPath string
|
||||
}{clientName, r.URL.Path}
|
||||
return renderTemplate(w, t.deviceSuccessTmpl, data)
|
||||
}
|
||||
|
||||
func (t *templates) login(r *http.Request, w http.ResponseWriter, connectors []connectorInfo, reqPath string) error {
|
||||
sort.Sort(byName(connectors))
|
||||
data := struct {
|
||||
|
Reference in New Issue
Block a user