Device flow token code exchange (#2)
* Added /device/token handler with associated business logic and storage tests. Perform user code exchange, flag the device code as complete. Moved device handler code into its own file for cleanliness. Cleanup * Removed PKCE code * Rate limiting for /device/token endpoint based on ietf standards * Configurable Device expiry Signed-off-by: justin-slowik <justin.slowik@thermofisher.com>
This commit is contained in:
parent
0d1a0e4129
commit
9bbdc721d5
@ -283,6 +283,9 @@ type Expiry struct {
|
||||
|
||||
// AuthRequests defines the duration of time for which the AuthRequests will be valid.
|
||||
AuthRequests string `json:"authRequests"`
|
||||
|
||||
// DeviceRequests defines the duration of time for which the DeviceRequests will be valid.
|
||||
DeviceRequests string `json:"deviceRequests"`
|
||||
}
|
||||
|
||||
// Logger holds configuration required to customize logging for dex.
|
||||
|
@ -119,6 +119,7 @@ expiry:
|
||||
signingKeys: "7h"
|
||||
idTokens: "25h"
|
||||
authRequests: "25h"
|
||||
deviceRequests: "10m"
|
||||
|
||||
logger:
|
||||
level: "debug"
|
||||
@ -200,6 +201,7 @@ logger:
|
||||
SigningKeys: "7h",
|
||||
IDTokens: "25h",
|
||||
AuthRequests: "25h",
|
||||
DeviceRequests: "10m",
|
||||
},
|
||||
Logger: Logger{
|
||||
Level: "debug",
|
||||
|
@ -269,7 +269,14 @@ func serve(cmd *cobra.Command, args []string) error {
|
||||
logger.Infof("config auth requests valid for: %v", authRequests)
|
||||
serverConfig.AuthRequestsValidFor = authRequests
|
||||
}
|
||||
|
||||
if c.Expiry.DeviceRequests != "" {
|
||||
deviceRequests, err := time.ParseDuration(c.Expiry.DeviceRequests)
|
||||
if err != nil {
|
||||
return fmt.Errorf("invalid config value %q for device request expiry: %v", c.Expiry.AuthRequests, err)
|
||||
}
|
||||
logger.Infof("config device requests valid for: %v", deviceRequests)
|
||||
serverConfig.DeviceRequestsValidFor = deviceRequests
|
||||
}
|
||||
serv, err := server.NewServer(context.Background(), serverConfig)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to initialize server: %v", err)
|
||||
|
@ -64,6 +64,7 @@ telemetry:
|
||||
|
||||
# Uncomment this block to enable configuration for the expiration time durations.
|
||||
# expiry:
|
||||
# deviceRequests: "5m"
|
||||
# signingKeys: "6h"
|
||||
# idTokens: "24h"
|
||||
|
||||
|
359
server/deviceHandlers.go
Normal file
359
server/deviceHandlers.go
Normal file
@ -0,0 +1,359 @@
|
||||
package server
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"path"
|
||||
"strconv"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/dexidp/dex/storage"
|
||||
)
|
||||
|
||||
type deviceCodeResponse struct {
|
||||
//The unique device code for device authentication
|
||||
DeviceCode string `json:"device_code"`
|
||||
//The code the user will exchange via a browser and log in
|
||||
UserCode string `json:"user_code"`
|
||||
//The url to verify the user code.
|
||||
VerificationURI string `json:"verification_uri"`
|
||||
//The verification uri with the user code appended for pre-filling form
|
||||
VerificationURIComplete string `json:"verification_uri_complete"`
|
||||
//The lifetime of the device code
|
||||
ExpireTime int `json:"expires_in"`
|
||||
//How often the device is allowed to poll to verify that the user login occurred
|
||||
PollInterval int `json:"interval"`
|
||||
}
|
||||
|
||||
func (s *Server) getDeviceAuthURI() string {
|
||||
return path.Join(s.issuerURL.Path, "/device/auth/verify_code")
|
||||
}
|
||||
|
||||
func (s *Server) handleDeviceExchange(w http.ResponseWriter, r *http.Request) {
|
||||
switch r.Method {
|
||||
case http.MethodGet:
|
||||
userCode := r.URL.Query().Get("user_code")
|
||||
invalidAttempt, err := strconv.ParseBool(r.URL.Query().Get("invalid"))
|
||||
if err != nil {
|
||||
invalidAttempt = false
|
||||
}
|
||||
if err := s.templates.device(r, w, s.getDeviceAuthURI(), userCode, invalidAttempt); err != nil {
|
||||
s.logger.Errorf("Server template error: %v", err)
|
||||
}
|
||||
default:
|
||||
s.renderError(r, w, http.StatusBadRequest, "Requested resource does not exist.")
|
||||
}
|
||||
}
|
||||
|
||||
func (s *Server) handleDeviceCode(w http.ResponseWriter, r *http.Request) {
|
||||
pollIntervalSeconds := 5
|
||||
|
||||
switch r.Method {
|
||||
case http.MethodPost:
|
||||
err := r.ParseForm()
|
||||
if err != nil {
|
||||
s.logger.Errorf("Could not parse Device Request body: %v", err)
|
||||
s.tokenErrHelper(w, errInvalidRequest, "", http.StatusNotFound)
|
||||
return
|
||||
}
|
||||
|
||||
//Get the client id and scopes from the post
|
||||
clientID := r.Form.Get("client_id")
|
||||
scopes := r.Form["scope"]
|
||||
|
||||
s.logger.Infof("Received device request for client %v with scopes %v", clientID, scopes)
|
||||
|
||||
//Make device code
|
||||
deviceCode := storage.NewDeviceCode()
|
||||
|
||||
//make user code
|
||||
userCode, err := storage.NewUserCode()
|
||||
if err != nil {
|
||||
s.logger.Errorf("Error generating user code: %v", err)
|
||||
s.tokenErrHelper(w, errInvalidRequest, "", http.StatusInternalServerError)
|
||||
}
|
||||
|
||||
//Generate the expire time
|
||||
expireTime := time.Now().Add(s.deviceRequestsValidFor)
|
||||
|
||||
//Store the Device Request
|
||||
deviceReq := storage.DeviceRequest{
|
||||
UserCode: userCode,
|
||||
DeviceCode: deviceCode,
|
||||
ClientID: clientID,
|
||||
Scopes: scopes,
|
||||
Expiry: expireTime,
|
||||
}
|
||||
|
||||
if err := s.storage.CreateDeviceRequest(deviceReq); err != nil {
|
||||
s.logger.Errorf("Failed to store device request; %v", err)
|
||||
s.tokenErrHelper(w, errInvalidRequest, "", http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
|
||||
//Store the device token
|
||||
deviceToken := storage.DeviceToken{
|
||||
DeviceCode: deviceCode,
|
||||
Status: deviceTokenPending,
|
||||
Expiry: expireTime,
|
||||
LastRequestTime: time.Now(),
|
||||
PollIntervalSeconds: 5,
|
||||
}
|
||||
|
||||
if err := s.storage.CreateDeviceToken(deviceToken); err != nil {
|
||||
s.logger.Errorf("Failed to store device token %v", err)
|
||||
s.tokenErrHelper(w, errInvalidRequest, "", http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
|
||||
u, err := url.Parse(s.issuerURL.String())
|
||||
if err != nil {
|
||||
s.logger.Errorf("Could not parse issuer URL %v", err)
|
||||
s.renderError(r, w, http.StatusInternalServerError, "")
|
||||
return
|
||||
}
|
||||
u.Path = path.Join(u.Path, "device")
|
||||
vURI := u.String()
|
||||
|
||||
q := u.Query()
|
||||
q.Set("user_code", userCode)
|
||||
u.RawQuery = q.Encode()
|
||||
vURIComplete := u.String()
|
||||
|
||||
code := deviceCodeResponse{
|
||||
DeviceCode: deviceCode,
|
||||
UserCode: userCode,
|
||||
VerificationURI: vURI,
|
||||
VerificationURIComplete: vURIComplete,
|
||||
ExpireTime: int(s.deviceRequestsValidFor.Seconds()),
|
||||
PollInterval: pollIntervalSeconds,
|
||||
}
|
||||
|
||||
enc := json.NewEncoder(w)
|
||||
enc.SetIndent("", " ")
|
||||
enc.Encode(code)
|
||||
|
||||
default:
|
||||
s.renderError(r, w, http.StatusBadRequest, "Invalid device code request type")
|
||||
s.tokenErrHelper(w, errInvalidRequest, "", http.StatusBadRequest)
|
||||
}
|
||||
}
|
||||
|
||||
func (s *Server) handleDeviceToken(w http.ResponseWriter, r *http.Request) {
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
switch r.Method {
|
||||
case http.MethodPost:
|
||||
err := r.ParseForm()
|
||||
if err != nil {
|
||||
s.logger.Warnf("Could not parse Device Token Request body: %v", err)
|
||||
s.tokenErrHelper(w, errInvalidRequest, "", http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
|
||||
deviceCode := r.Form.Get("device_code")
|
||||
if deviceCode == "" {
|
||||
s.tokenErrHelper(w, errInvalidRequest, "No device code received", http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
|
||||
grantType := r.PostFormValue("grant_type")
|
||||
if grantType != grantTypeDeviceCode {
|
||||
s.tokenErrHelper(w, errInvalidGrant, "", http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
|
||||
now := s.now()
|
||||
|
||||
//Grab the device token
|
||||
deviceToken, err := s.storage.GetDeviceToken(deviceCode)
|
||||
if err != nil || now.After(deviceToken.Expiry) {
|
||||
if err != storage.ErrNotFound {
|
||||
s.logger.Errorf("failed to get device code: %v", err)
|
||||
}
|
||||
s.tokenErrHelper(w, errInvalidRequest, "Invalid or expired device code.", http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
|
||||
//Rate Limiting check
|
||||
pollInterval := deviceToken.PollIntervalSeconds
|
||||
minRequestTime := deviceToken.LastRequestTime.Add(time.Second * time.Duration(pollInterval))
|
||||
if now.Before(minRequestTime) {
|
||||
s.tokenErrHelper(w, deviceTokenSlowDown, "", http.StatusBadRequest)
|
||||
//Continually increase the poll interval until the user waits the proper time
|
||||
pollInterval += 5
|
||||
} else {
|
||||
pollInterval = 5
|
||||
}
|
||||
|
||||
switch deviceToken.Status {
|
||||
case deviceTokenPending:
|
||||
updater := func(old storage.DeviceToken) (storage.DeviceToken, error) {
|
||||
old.PollIntervalSeconds = pollInterval
|
||||
old.LastRequestTime = now
|
||||
return old, nil
|
||||
}
|
||||
// Update device token last request time in storage
|
||||
if err := s.storage.UpdateDeviceToken(deviceCode, updater); err != nil {
|
||||
s.logger.Errorf("failed to update device token: %v", err)
|
||||
s.renderError(r, w, http.StatusInternalServerError, "")
|
||||
return
|
||||
}
|
||||
s.tokenErrHelper(w, deviceTokenPending, "", http.StatusUnauthorized)
|
||||
case deviceTokenComplete:
|
||||
w.Write([]byte(deviceToken.Token))
|
||||
}
|
||||
default:
|
||||
s.renderError(r, w, http.StatusBadRequest, "Requested resource does not exist.")
|
||||
}
|
||||
}
|
||||
|
||||
func (s *Server) handleDeviceCallback(w http.ResponseWriter, r *http.Request) {
|
||||
switch r.Method {
|
||||
case http.MethodGet:
|
||||
userCode := r.FormValue("state")
|
||||
code := r.FormValue("code")
|
||||
|
||||
if userCode == "" || code == "" {
|
||||
s.renderError(r, w, http.StatusBadRequest, "Request was missing parameters")
|
||||
return
|
||||
}
|
||||
|
||||
// Authorization redirect callback from OAuth2 auth flow.
|
||||
if errMsg := r.FormValue("error"); errMsg != "" {
|
||||
http.Error(w, errMsg+": "+r.FormValue("error_description"), http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
|
||||
authCode, err := s.storage.GetAuthCode(code)
|
||||
if err != nil || s.now().After(authCode.Expiry) {
|
||||
if err != storage.ErrNotFound {
|
||||
s.logger.Errorf("failed to get auth code: %v", err)
|
||||
}
|
||||
s.renderError(r, w, http.StatusBadRequest, "Invalid or expired auth code.")
|
||||
return
|
||||
}
|
||||
|
||||
//Grab the device request from storage
|
||||
deviceReq, err := s.storage.GetDeviceRequest(userCode)
|
||||
if err != nil || s.now().After(deviceReq.Expiry) {
|
||||
if err != storage.ErrNotFound {
|
||||
s.logger.Errorf("failed to get device code: %v", err)
|
||||
}
|
||||
s.renderError(r, w, http.StatusInternalServerError, "Invalid or expired device code.")
|
||||
return
|
||||
}
|
||||
|
||||
reqClient, err := s.storage.GetClient(deviceReq.ClientID)
|
||||
if err != nil {
|
||||
s.logger.Errorf("Failed to get reqClient %q: %v", deviceReq.ClientID, err)
|
||||
s.renderError(r, w, http.StatusInternalServerError, "Failed to retrieve device client.")
|
||||
return
|
||||
}
|
||||
|
||||
resp, err := s.exchangeAuthCode(w, authCode, reqClient)
|
||||
if err != nil {
|
||||
s.logger.Errorf("Could not exchange auth code for client %q: %v", deviceReq.ClientID, err)
|
||||
s.renderError(r, w, http.StatusInternalServerError, "Failed to exchange auth code.")
|
||||
return
|
||||
}
|
||||
|
||||
//Grab the device request from storage
|
||||
old, err := s.storage.GetDeviceToken(deviceReq.DeviceCode)
|
||||
if err != nil || s.now().After(old.Expiry) {
|
||||
if err != storage.ErrNotFound {
|
||||
s.logger.Errorf("failed to get device token: %v", err)
|
||||
}
|
||||
s.renderError(r, w, http.StatusInternalServerError, "Invalid or expired device code.")
|
||||
return
|
||||
}
|
||||
|
||||
updater := func(old storage.DeviceToken) (storage.DeviceToken, error) {
|
||||
if old.Status == deviceTokenComplete {
|
||||
return old, errors.New("device token already complete")
|
||||
}
|
||||
respStr, err := json.MarshalIndent(resp, "", " ")
|
||||
if err != nil {
|
||||
s.logger.Errorf("failed to marshal device token response: %v", err)
|
||||
s.renderError(r, w, http.StatusInternalServerError, "")
|
||||
return old, err
|
||||
}
|
||||
|
||||
old.Token = string(respStr)
|
||||
old.Status = deviceTokenComplete
|
||||
return old, nil
|
||||
}
|
||||
|
||||
// Update refresh token in the storage, store the token and mark as complete
|
||||
if err := s.storage.UpdateDeviceToken(deviceReq.DeviceCode, updater); err != nil {
|
||||
s.logger.Errorf("failed to update device token: %v", err)
|
||||
s.renderError(r, w, http.StatusInternalServerError, "")
|
||||
return
|
||||
}
|
||||
|
||||
if err := s.templates.deviceSuccess(r, w, reqClient.Name); err != nil {
|
||||
s.logger.Errorf("Server template error: %v", err)
|
||||
}
|
||||
|
||||
default:
|
||||
http.Error(w, fmt.Sprintf("method not implemented: %s", r.Method), http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
func (s *Server) verifyUserCode(w http.ResponseWriter, r *http.Request) {
|
||||
switch r.Method {
|
||||
case http.MethodPost:
|
||||
err := r.ParseForm()
|
||||
if err != nil {
|
||||
message := "Could not parse user code verification Request body"
|
||||
s.logger.Warnf("%s : %v", message, err)
|
||||
s.tokenErrHelper(w, errInvalidRequest, message, http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
|
||||
userCode := r.Form.Get("user_code")
|
||||
if userCode == "" {
|
||||
s.renderError(r, w, http.StatusBadRequest, "No user code received")
|
||||
return
|
||||
}
|
||||
|
||||
userCode = strings.ToUpper(userCode)
|
||||
|
||||
//Find the user code in the available requests
|
||||
deviceRequest, err := s.storage.GetDeviceRequest(userCode)
|
||||
if err != nil || s.now().After(deviceRequest.Expiry) {
|
||||
if err != storage.ErrNotFound {
|
||||
s.logger.Errorf("failed to get device request: %v", err)
|
||||
s.tokenErrHelper(w, errServerError, "", http.StatusInternalServerError)
|
||||
}
|
||||
if err := s.templates.device(r, w, s.getDeviceAuthURI(), userCode, true); err != nil {
|
||||
s.logger.Errorf("Server template error: %v", err)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
//Redirect to Dex Auth Endpoint
|
||||
authURL := path.Join(s.issuerURL.Path, "/auth")
|
||||
u, err := url.Parse(authURL)
|
||||
if err != nil {
|
||||
s.renderError(r, w, http.StatusInternalServerError, "Invalid auth URI.")
|
||||
return
|
||||
}
|
||||
q := u.Query()
|
||||
q.Set("client_id", deviceRequest.ClientID)
|
||||
q.Set("state", deviceRequest.UserCode)
|
||||
q.Set("response_type", "code")
|
||||
q.Set("redirect_uri", path.Join(s.issuerURL.Path, "/device/callback"))
|
||||
q.Set("scope", strings.Join(deviceRequest.Scopes, " "))
|
||||
u.RawQuery = q.Encode()
|
||||
|
||||
http.Redirect(w, r, u.String(), http.StatusFound)
|
||||
|
||||
default:
|
||||
s.renderError(r, w, http.StatusBadRequest, "Requested resource does not exist.")
|
||||
}
|
||||
}
|
@ -152,6 +152,8 @@ type discovery struct {
|
||||
Token string `json:"token_endpoint"`
|
||||
Keys string `json:"jwks_uri"`
|
||||
UserInfo string `json:"userinfo_endpoint"`
|
||||
DeviceEndpoint string `json:"device_authorization_endpoint"`
|
||||
GrantTypes []string `json:"grant_types_supported"'`
|
||||
ResponseTypes []string `json:"response_types_supported"`
|
||||
Subjects []string `json:"subject_types_supported"`
|
||||
IDTokenAlgs []string `json:"id_token_signing_alg_values_supported"`
|
||||
@ -167,7 +169,9 @@ func (s *Server) discoveryHandler() (http.HandlerFunc, error) {
|
||||
Token: s.absURL("/token"),
|
||||
Keys: s.absURL("/keys"),
|
||||
UserInfo: s.absURL("/userinfo"),
|
||||
DeviceEndpoint: s.absURL("/device/code"),
|
||||
Subjects: []string{"public"},
|
||||
GrantTypes: []string{grantTypeAuthorizationCode, grantTypeRefreshToken, grantTypeDeviceCode},
|
||||
IDTokenAlgs: []string{string(jose.RS256)},
|
||||
Scopes: []string{"openid", "email", "groups", "profile", "offline_access"},
|
||||
AuthMethods: []string{"client_secret_basic"},
|
||||
@ -783,24 +787,33 @@ func (s *Server) handleAuthCode(w http.ResponseWriter, r *http.Request, client s
|
||||
return
|
||||
}
|
||||
|
||||
tokenResponse, err := s.exchangeAuthCode(w, authCode, client)
|
||||
if err != nil {
|
||||
s.tokenErrHelper(w, errServerError, "", http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
s.writeAccessToken(w, tokenResponse)
|
||||
}
|
||||
|
||||
func (s *Server) exchangeAuthCode(w http.ResponseWriter, authCode storage.AuthCode, client storage.Client) (*accessTokenReponse, error) {
|
||||
accessToken, err := s.newAccessToken(client.ID, authCode.Claims, authCode.Scopes, authCode.Nonce, authCode.ConnectorID)
|
||||
if err != nil {
|
||||
s.logger.Errorf("failed to create new access token: %v", err)
|
||||
s.tokenErrHelper(w, errServerError, "", http.StatusInternalServerError)
|
||||
return
|
||||
return nil, err
|
||||
}
|
||||
|
||||
idToken, expiry, err := s.newIDToken(client.ID, authCode.Claims, authCode.Scopes, authCode.Nonce, accessToken, authCode.ConnectorID)
|
||||
if err != nil {
|
||||
s.logger.Errorf("failed to create ID token: %v", err)
|
||||
s.tokenErrHelper(w, errServerError, "", http.StatusInternalServerError)
|
||||
return
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if err := s.storage.DeleteAuthCode(code); err != nil {
|
||||
if err := s.storage.DeleteAuthCode(authCode.ID); err != nil {
|
||||
s.logger.Errorf("failed to delete auth code: %v", err)
|
||||
s.tokenErrHelper(w, errServerError, "", http.StatusInternalServerError)
|
||||
return
|
||||
return nil, err
|
||||
}
|
||||
|
||||
reqRefresh := func() bool {
|
||||
@ -847,13 +860,13 @@ func (s *Server) handleAuthCode(w http.ResponseWriter, r *http.Request, client s
|
||||
if refreshToken, err = internal.Marshal(token); err != nil {
|
||||
s.logger.Errorf("failed to marshal refresh token: %v", err)
|
||||
s.tokenErrHelper(w, errServerError, "", http.StatusInternalServerError)
|
||||
return
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if err := s.storage.CreateRefresh(refresh); err != nil {
|
||||
s.logger.Errorf("failed to create refresh token: %v", err)
|
||||
s.tokenErrHelper(w, errServerError, "", http.StatusInternalServerError)
|
||||
return
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// deleteToken determines if we need to delete the newly created refresh token
|
||||
@ -884,7 +897,7 @@ func (s *Server) handleAuthCode(w http.ResponseWriter, r *http.Request, client s
|
||||
s.logger.Errorf("failed to get offline session: %v", err)
|
||||
s.tokenErrHelper(w, errServerError, "", http.StatusInternalServerError)
|
||||
deleteToken = true
|
||||
return
|
||||
return nil, err
|
||||
}
|
||||
offlineSessions := storage.OfflineSessions{
|
||||
UserID: refresh.Claims.UserID,
|
||||
@ -899,7 +912,7 @@ func (s *Server) handleAuthCode(w http.ResponseWriter, r *http.Request, client s
|
||||
s.logger.Errorf("failed to create offline session: %v", err)
|
||||
s.tokenErrHelper(w, errServerError, "", http.StatusInternalServerError)
|
||||
deleteToken = true
|
||||
return
|
||||
return nil, err
|
||||
}
|
||||
} else {
|
||||
if oldTokenRef, ok := session.Refresh[tokenRef.ClientID]; ok {
|
||||
@ -908,7 +921,7 @@ func (s *Server) handleAuthCode(w http.ResponseWriter, r *http.Request, client s
|
||||
s.logger.Errorf("failed to delete refresh token: %v", err)
|
||||
s.tokenErrHelper(w, errServerError, "", http.StatusInternalServerError)
|
||||
deleteToken = true
|
||||
return
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
|
||||
@ -920,11 +933,11 @@ func (s *Server) handleAuthCode(w http.ResponseWriter, r *http.Request, client s
|
||||
s.logger.Errorf("failed to update offline session: %v", err)
|
||||
s.tokenErrHelper(w, errServerError, "", http.StatusInternalServerError)
|
||||
deleteToken = true
|
||||
return
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
}
|
||||
s.writeAccessToken(w, idToken, accessToken, refreshToken, expiry)
|
||||
return s.toAccessTokenResponse(idToken, accessToken, refreshToken, expiry), nil
|
||||
}
|
||||
|
||||
// handle a refresh token request https://tools.ietf.org/html/rfc6749#section-6
|
||||
@ -1120,7 +1133,8 @@ func (s *Server) handleRefreshToken(w http.ResponseWriter, r *http.Request, clie
|
||||
return
|
||||
}
|
||||
|
||||
s.writeAccessToken(w, idToken, accessToken, rawNewToken, expiry)
|
||||
resp := s.toAccessTokenResponse(idToken, accessToken, rawNewToken, expiry)
|
||||
s.writeAccessToken(w, resp)
|
||||
}
|
||||
|
||||
func (s *Server) handleUserInfo(w http.ResponseWriter, r *http.Request) {
|
||||
@ -1378,12 +1392,26 @@ func (s *Server) writeAccessToken(w http.ResponseWriter, idToken, accessToken, r
|
||||
RefreshToken string `json:"refresh_token,omitempty"`
|
||||
IDToken string `json:"id_token"`
|
||||
}{
|
||||
|
||||
type accessTokenReponse struct {
|
||||
AccessToken string `json:"access_token"`
|
||||
TokenType string `json:"token_type"`
|
||||
ExpiresIn int `json:"expires_in"`
|
||||
RefreshToken string `json:"refresh_token,omitempty"`
|
||||
IDToken string `json:"id_token"`
|
||||
}
|
||||
|
||||
func (s *Server) toAccessTokenResponse(idToken, accessToken, refreshToken string, expiry time.Time) *accessTokenReponse {
|
||||
return &accessTokenReponse{
|
||||
accessToken,
|
||||
"bearer",
|
||||
int(expiry.Sub(s.now()).Seconds()),
|
||||
refreshToken,
|
||||
idToken,
|
||||
}
|
||||
}
|
||||
|
||||
func (s *Server) writeAccessToken(w http.ResponseWriter, resp *accessTokenReponse) {
|
||||
data, err := json.Marshal(resp)
|
||||
if err != nil {
|
||||
s.logger.Errorf("failed to marshal access token response: %v", err)
|
||||
@ -1414,145 +1442,3 @@ func usernamePrompt(conn connector.PasswordConnector) string {
|
||||
}
|
||||
return "Username"
|
||||
}
|
||||
|
||||
type deviceCodeResponse struct {
|
||||
//The unique device code for device authentication
|
||||
DeviceCode string `json:"device_code"`
|
||||
//The code the user will exchange via a browser and log in
|
||||
UserCode string `json:"user_code"`
|
||||
//The url to verify the user code.
|
||||
VerificationURI string `json:"verification_uri"`
|
||||
//The lifetime of the device code
|
||||
ExpireTime int `json:"expires_in"`
|
||||
//How often the device is allowed to poll to verify that the user login occurred
|
||||
PollInterval int `json:"interval"`
|
||||
}
|
||||
|
||||
func (s *Server) handleDeviceCode(w http.ResponseWriter, r *http.Request) {
|
||||
//TODO replace with configurable values
|
||||
expireIntervalSeconds := 300
|
||||
requestsPerMinute := 5
|
||||
|
||||
switch r.Method {
|
||||
case http.MethodPost:
|
||||
err := r.ParseForm()
|
||||
if err != nil {
|
||||
s.logger.Errorf("Could not parse Device Request body: %v", err)
|
||||
s.tokenErrHelper(w, errInvalidRequest, "", http.StatusNotFound)
|
||||
return
|
||||
}
|
||||
|
||||
//Get the client id and scopes from the post
|
||||
clientID := r.Form.Get("client_id")
|
||||
scopes := r.Form["scope"]
|
||||
|
||||
s.logger.Infof("Received device request for client %v with scopes %v", clientID, scopes)
|
||||
|
||||
//Make device code
|
||||
deviceCode := storage.NewDeviceCode()
|
||||
|
||||
//make user code
|
||||
userCode, err := storage.NewUserCode()
|
||||
if err != nil {
|
||||
s.logger.Errorf("Error generating user code: %v", err)
|
||||
s.tokenErrHelper(w, errInvalidRequest, "", http.StatusInternalServerError)
|
||||
}
|
||||
|
||||
//make a pkce verification code
|
||||
pkceCode := storage.NewID()
|
||||
|
||||
//Generate the expire time
|
||||
expireTime := time.Now().Add(time.Second * time.Duration(expireIntervalSeconds))
|
||||
|
||||
//Store the Device Request
|
||||
deviceReq := storage.DeviceRequest{
|
||||
UserCode: userCode,
|
||||
DeviceCode: deviceCode,
|
||||
ClientID: clientID,
|
||||
Scopes: scopes,
|
||||
PkceVerifier: pkceCode,
|
||||
Expiry: expireTime,
|
||||
}
|
||||
|
||||
if err := s.storage.CreateDeviceRequest(deviceReq); err != nil {
|
||||
s.logger.Errorf("Failed to store device request; %v", err)
|
||||
s.tokenErrHelper(w, errInvalidRequest, "", http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
|
||||
//Store the device token
|
||||
deviceToken := storage.DeviceToken{
|
||||
DeviceCode: deviceCode,
|
||||
Status: deviceTokenPending,
|
||||
Expiry: expireTime,
|
||||
}
|
||||
|
||||
if err := s.storage.CreateDeviceToken(deviceToken); err != nil {
|
||||
s.logger.Errorf("Failed to store device token %v", err)
|
||||
s.tokenErrHelper(w, errInvalidRequest, "", http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
|
||||
code := deviceCodeResponse{
|
||||
DeviceCode: deviceCode,
|
||||
UserCode: userCode,
|
||||
VerificationURI: path.Join(s.issuerURL.String(), "/device"),
|
||||
ExpireTime: expireIntervalSeconds,
|
||||
PollInterval: requestsPerMinute,
|
||||
}
|
||||
|
||||
enc := json.NewEncoder(w)
|
||||
enc.SetIndent("", " ")
|
||||
enc.Encode(code)
|
||||
|
||||
default:
|
||||
s.renderError(r, w, http.StatusBadRequest, "Invalid device code request type")
|
||||
s.tokenErrHelper(w, errInvalidRequest, "", http.StatusBadRequest)
|
||||
}
|
||||
}
|
||||
|
||||
func (s *Server) handleDeviceToken(w http.ResponseWriter, r *http.Request) {
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
switch r.Method {
|
||||
case http.MethodPost:
|
||||
err := r.ParseForm()
|
||||
if err != nil {
|
||||
message := "Could not parse Device Token Request body"
|
||||
s.logger.Warnf("%s : %v", message, err)
|
||||
s.tokenErrHelper(w, errInvalidRequest, "", http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
|
||||
deviceCode := r.Form.Get("device_code")
|
||||
if deviceCode == "" {
|
||||
message := "No device code received"
|
||||
s.tokenErrHelper(w, errInvalidRequest, message, http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
|
||||
grantType := r.PostFormValue("grant_type")
|
||||
if grantType != grantTypeDeviceCode {
|
||||
s.tokenErrHelper(w, errInvalidGrant, "", http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
|
||||
//Grab the device token from the db
|
||||
deviceToken, err := s.storage.GetDeviceToken(deviceCode)
|
||||
if err != nil || s.now().After(deviceToken.Expiry) {
|
||||
if err != storage.ErrNotFound {
|
||||
s.logger.Errorf("failed to get device code: %v", err)
|
||||
}
|
||||
s.tokenErrHelper(w, errInvalidRequest, "Invalid or expired device code.", http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
|
||||
switch deviceToken.Status {
|
||||
case deviceTokenPending:
|
||||
s.tokenErrHelper(w, deviceTokenPending, "", http.StatusUnauthorized)
|
||||
case deviceTokenComplete:
|
||||
w.Write([]byte(deviceToken.Token))
|
||||
}
|
||||
default:
|
||||
s.renderError(r, w, http.StatusBadRequest, "Requested resource does not exist.")
|
||||
}
|
||||
}
|
||||
|
@ -122,7 +122,7 @@ const (
|
||||
grantTypeAuthorizationCode = "authorization_code"
|
||||
grantTypeRefreshToken = "refresh_token"
|
||||
grantTypePassword = "password"
|
||||
grantTypeDeviceCode = "device_code"
|
||||
grantTypeDeviceCode = "urn:ietf:params:oauth:grant-type:device_code"
|
||||
)
|
||||
|
||||
const (
|
||||
@ -134,6 +134,8 @@ const (
|
||||
const (
|
||||
deviceTokenPending = "authorization_pending"
|
||||
deviceTokenComplete = "complete"
|
||||
deviceTokenSlowDown = "slow_down"
|
||||
deviceTokenExpired = "expired_token"
|
||||
)
|
||||
|
||||
func parseScopes(scopes []string) connector.Scopes {
|
||||
|
@ -78,6 +78,7 @@ type Config struct {
|
||||
RotateKeysAfter time.Duration // Defaults to 6 hours.
|
||||
IDTokensValidFor time.Duration // Defaults to 24 hours
|
||||
AuthRequestsValidFor time.Duration // Defaults to 24 hours
|
||||
DeviceRequestsValidFor time.Duration // Defaults to 5 minutes
|
||||
// If set, the server will use this connector to handle password grants
|
||||
PasswordConnector string
|
||||
|
||||
@ -158,6 +159,7 @@ type Server struct {
|
||||
|
||||
idTokensValidFor time.Duration
|
||||
authRequestsValidFor time.Duration
|
||||
deviceRequestsValidFor time.Duration
|
||||
|
||||
logger log.Logger
|
||||
}
|
||||
@ -219,6 +221,7 @@ func newServer(ctx context.Context, c Config, rotationStrategy rotationStrategy)
|
||||
supportedResponseTypes: supported,
|
||||
idTokensValidFor: value(c.IDTokensValidFor, 24*time.Hour),
|
||||
authRequestsValidFor: value(c.AuthRequestsValidFor, 24*time.Hour),
|
||||
deviceRequestsValidFor: value(c.DeviceRequestsValidFor, 5*time.Minute),
|
||||
skipApproval: c.SkipApprovalScreen,
|
||||
alwaysShowLogin: c.AlwaysShowLoginScreen,
|
||||
now: now,
|
||||
@ -302,8 +305,11 @@ func newServer(ctx context.Context, c Config, rotationStrategy rotationStrategy)
|
||||
handleWithCORS("/userinfo", s.handleUserInfo)
|
||||
handleFunc("/auth", s.handleAuthorization)
|
||||
handleFunc("/auth/{connector}", s.handleConnectorLogin)
|
||||
handleFunc("/device", s.handleDeviceExchange)
|
||||
handleFunc("/device/auth/verify_code", s.verifyUserCode)
|
||||
handleFunc("/device/code", s.handleDeviceCode)
|
||||
handleFunc("/device/token", s.handleDeviceToken)
|
||||
handleFunc("/device/callback", s.handleDeviceCallback)
|
||||
r.HandleFunc(path.Join(issuerURL.Path, "/callback"), func(w http.ResponseWriter, r *http.Request) {
|
||||
// Strip the X-Remote-* headers to prevent security issues on
|
||||
// misconfigured authproxy connector setups.
|
||||
|
@ -20,6 +20,8 @@ const (
|
||||
tmplPassword = "password.html"
|
||||
tmplOOB = "oob.html"
|
||||
tmplError = "error.html"
|
||||
tmplDevice = "device.html"
|
||||
tmplDeviceSuccess = "device_success.html"
|
||||
)
|
||||
|
||||
var requiredTmpls = []string{
|
||||
@ -28,6 +30,7 @@ var requiredTmpls = []string{
|
||||
tmplPassword,
|
||||
tmplOOB,
|
||||
tmplError,
|
||||
tmplDevice,
|
||||
}
|
||||
|
||||
type templates struct {
|
||||
@ -36,6 +39,8 @@ type templates struct {
|
||||
passwordTmpl *template.Template
|
||||
oobTmpl *template.Template
|
||||
errorTmpl *template.Template
|
||||
deviceTmpl *template.Template
|
||||
deviceSuccessTmpl *template.Template
|
||||
}
|
||||
|
||||
type webConfig struct {
|
||||
@ -157,6 +162,8 @@ func loadTemplates(c webConfig, templatesDir string) (*templates, error) {
|
||||
passwordTmpl: tmpls.Lookup(tmplPassword),
|
||||
oobTmpl: tmpls.Lookup(tmplOOB),
|
||||
errorTmpl: tmpls.Lookup(tmplError),
|
||||
deviceTmpl: tmpls.Lookup(tmplDevice),
|
||||
deviceSuccessTmpl: tmpls.Lookup(tmplDeviceSuccess),
|
||||
}, nil
|
||||
}
|
||||
|
||||
@ -242,6 +249,24 @@ func (n byName) Len() int { return len(n) }
|
||||
func (n byName) Less(i, j int) bool { return n[i].Name < n[j].Name }
|
||||
func (n byName) Swap(i, j int) { n[i], n[j] = n[j], n[i] }
|
||||
|
||||
func (t *templates) device(r *http.Request, w http.ResponseWriter, postURL string, userCode string, lastWasInvalid bool) error {
|
||||
data := struct {
|
||||
PostURL string
|
||||
UserCode string
|
||||
Invalid bool
|
||||
ReqPath string
|
||||
}{postURL, userCode, lastWasInvalid, r.URL.Path}
|
||||
return renderTemplate(w, t.deviceTmpl, data)
|
||||
}
|
||||
|
||||
func (t *templates) deviceSuccess(r *http.Request, w http.ResponseWriter, clientName string) error {
|
||||
data := struct {
|
||||
ClientName string
|
||||
ReqPath string
|
||||
}{clientName, r.URL.Path}
|
||||
return renderTemplate(w, t.deviceSuccessTmpl, data)
|
||||
}
|
||||
|
||||
func (t *templates) login(r *http.Request, w http.ResponseWriter, connectors []connectorInfo, reqPath string) error {
|
||||
sort.Sort(byName(connectors))
|
||||
data := struct {
|
||||
|
@ -847,7 +847,6 @@ func testGC(t *testing.T, s storage.Storage) {
|
||||
DeviceCode: storage.NewID(),
|
||||
ClientID: "client1",
|
||||
Scopes: []string{"openid", "email"},
|
||||
PkceVerifier: storage.NewID(),
|
||||
Expiry: expiry,
|
||||
}
|
||||
|
||||
@ -974,7 +973,6 @@ func testDeviceRequestCRUD(t *testing.T, s storage.Storage) {
|
||||
DeviceCode: storage.NewID(),
|
||||
ClientID: "client1",
|
||||
Scopes: []string{"openid", "email"},
|
||||
PkceVerifier: storage.NewID(),
|
||||
Expiry: neverExpire,
|
||||
}
|
||||
|
||||
@ -991,20 +989,44 @@ func testDeviceRequestCRUD(t *testing.T, s storage.Storage) {
|
||||
}
|
||||
|
||||
func testDeviceTokenCRUD(t *testing.T, s storage.Storage) {
|
||||
//Create a Token
|
||||
d1 := storage.DeviceToken{
|
||||
DeviceCode: storage.NewID(),
|
||||
Status: "pending",
|
||||
Token: storage.NewID(),
|
||||
Expiry: neverExpire,
|
||||
LastRequestTime: time.Now(),
|
||||
PollIntervalSeconds: 0,
|
||||
}
|
||||
|
||||
if err := s.CreateDeviceToken(d1); err != nil {
|
||||
t.Fatalf("failed creating device token: %v", err)
|
||||
}
|
||||
|
||||
// Attempt to create same DeviceRequest twice.
|
||||
// Attempt to create same Device Token twice.
|
||||
err := s.CreateDeviceToken(d1)
|
||||
mustBeErrAlreadyExists(t, "device token", err)
|
||||
|
||||
//TODO Add update / delete tests as functionality is put into main code
|
||||
//Update the device token, simulate a redemption
|
||||
if err := s.UpdateDeviceToken(d1.DeviceCode, func(old storage.DeviceToken) (storage.DeviceToken, error) {
|
||||
old.Token = "token data"
|
||||
old.Status = "complete"
|
||||
return old, nil
|
||||
}); err != nil {
|
||||
t.Fatalf("failed to update device token: %v", err)
|
||||
}
|
||||
|
||||
//Retrieve the device token
|
||||
got, err := s.GetDeviceToken(d1.DeviceCode)
|
||||
if err != nil {
|
||||
t.Fatalf("failed to get device token: %v", err)
|
||||
}
|
||||
|
||||
//Validate expected result set
|
||||
if got.Status != "complete" {
|
||||
t.Fatalf("update failed, wanted token status=%#v got %#v", "complete", got.Status)
|
||||
}
|
||||
if got.Token != "token data" {
|
||||
t.Fatalf("update failed, wanted token =%#v got %#v", "token data", got.Token)
|
||||
}
|
||||
}
|
||||
|
@ -570,6 +570,13 @@ func (c *conn) CreateDeviceRequest(d storage.DeviceRequest) error {
|
||||
return c.txnCreate(ctx, keyID(deviceRequestPrefix, d.UserCode), fromStorageDeviceRequest(d))
|
||||
}
|
||||
|
||||
func (c *conn) GetDeviceRequest(userCode string) (r storage.DeviceRequest, err error) {
|
||||
ctx, cancel := context.WithTimeout(context.Background(), defaultStorageTimeout)
|
||||
defer cancel()
|
||||
err = c.getKey(ctx, keyID(deviceRequestPrefix, userCode), &r)
|
||||
return r, err
|
||||
}
|
||||
|
||||
func (c *conn) listDeviceRequests(ctx context.Context) (requests []DeviceRequest, err error) {
|
||||
res, err := c.db.Get(ctx, deviceRequestPrefix, clientv3.WithPrefix())
|
||||
if err != nil {
|
||||
@ -612,3 +619,21 @@ func (c *conn) listDeviceTokens(ctx context.Context) (deviceTokens []DeviceToken
|
||||
}
|
||||
return deviceTokens, nil
|
||||
}
|
||||
|
||||
func (c *conn) UpdateDeviceToken(deviceCode string, updater func(old storage.DeviceToken) (storage.DeviceToken, error)) error {
|
||||
ctx, cancel := context.WithTimeout(context.Background(), defaultStorageTimeout)
|
||||
defer cancel()
|
||||
return c.txnUpdate(ctx, keyID(deviceTokenPrefix, deviceCode), func(currentValue []byte) ([]byte, error) {
|
||||
var current DeviceToken
|
||||
if len(currentValue) > 0 {
|
||||
if err := json.Unmarshal(currentValue, ¤t); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
updated, err := updater(toStorageDeviceToken(current))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return json.Marshal(fromStorageDeviceToken(updated))
|
||||
})
|
||||
}
|
||||
|
@ -223,7 +223,6 @@ type DeviceRequest struct {
|
||||
DeviceCode string `json:"device_code"`
|
||||
ClientID string `json:"client_id"`
|
||||
Scopes []string `json:"scopes"`
|
||||
PkceVerifier string `json:"pkce_verifier"`
|
||||
Expiry time.Time `json:"expiry"`
|
||||
}
|
||||
|
||||
@ -233,7 +232,6 @@ func fromStorageDeviceRequest(d storage.DeviceRequest) DeviceRequest {
|
||||
DeviceCode: d.DeviceCode,
|
||||
ClientID: d.ClientID,
|
||||
Scopes: d.Scopes,
|
||||
PkceVerifier: d.PkceVerifier,
|
||||
Expiry: d.Expiry,
|
||||
}
|
||||
}
|
||||
@ -244,6 +242,8 @@ type DeviceToken struct {
|
||||
Status string `json:"status"`
|
||||
Token string `json:"token"`
|
||||
Expiry time.Time `json:"expiry"`
|
||||
LastRequestTime time.Time `json:"last_request"`
|
||||
PollIntervalSeconds int `json:"poll_interval"`
|
||||
}
|
||||
|
||||
func fromStorageDeviceToken(t storage.DeviceToken) DeviceToken {
|
||||
@ -252,5 +252,18 @@ func fromStorageDeviceToken(t storage.DeviceToken) DeviceToken {
|
||||
Status: t.Status,
|
||||
Token: t.Token,
|
||||
Expiry: t.Expiry,
|
||||
LastRequestTime: t.LastRequestTime,
|
||||
PollIntervalSeconds: t.PollIntervalSeconds,
|
||||
}
|
||||
}
|
||||
|
||||
func toStorageDeviceToken(t DeviceToken) storage.DeviceToken {
|
||||
return storage.DeviceToken{
|
||||
DeviceCode: t.DeviceCode,
|
||||
Status: t.Status,
|
||||
Token: t.Token,
|
||||
Expiry: t.Expiry,
|
||||
LastRequestTime: t.LastRequestTime,
|
||||
PollIntervalSeconds: t.PollIntervalSeconds,
|
||||
}
|
||||
}
|
||||
|
@ -638,6 +638,14 @@ func (cli *client) CreateDeviceRequest(d storage.DeviceRequest) error {
|
||||
return cli.post(resourceDeviceRequest, cli.fromStorageDeviceRequest(d))
|
||||
}
|
||||
|
||||
func (cli *client) GetDeviceRequest(userCode string) (storage.DeviceRequest, error) {
|
||||
var req DeviceRequest
|
||||
if err := cli.get(resourceDeviceRequest, strings.ToLower(userCode), &req); err != nil {
|
||||
return storage.DeviceRequest{}, err
|
||||
}
|
||||
return toStorageDeviceRequest(req), nil
|
||||
}
|
||||
|
||||
func (cli *client) CreateDeviceToken(t storage.DeviceToken) error {
|
||||
return cli.post(resourceDeviceToken, cli.fromStorageDeviceToken(t))
|
||||
}
|
||||
@ -649,3 +657,24 @@ func (cli *client) GetDeviceToken(deviceCode string) (storage.DeviceToken, error
|
||||
}
|
||||
return toStorageDeviceToken(token), nil
|
||||
}
|
||||
|
||||
func (cli *client) getDeviceToken(deviceCode string) (t DeviceToken, err error) {
|
||||
err = cli.get(resourceDeviceToken, deviceCode, &t)
|
||||
return
|
||||
}
|
||||
|
||||
func (cli *client) UpdateDeviceToken(deviceCode string, updater func(old storage.DeviceToken) (storage.DeviceToken, error)) error {
|
||||
r, err := cli.getDeviceToken(deviceCode)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
updated, err := updater(toStorageDeviceToken(r))
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
updated.DeviceCode = deviceCode
|
||||
|
||||
newToken := cli.fromStorageDeviceToken(updated)
|
||||
newToken.ObjectMeta = r.ObjectMeta
|
||||
return cli.put(resourceDeviceToken, r.ObjectMeta.Name, newToken)
|
||||
}
|
||||
|
@ -675,7 +675,6 @@ type DeviceRequest struct {
|
||||
DeviceCode string `json:"device_code,omitempty"`
|
||||
CLientID string `json:"client_id,omitempty"`
|
||||
Scopes []string `json:"scopes,omitempty"`
|
||||
PkceVerifier string `json:"pkce_verifier,omitempty"`
|
||||
Expiry time.Time `json:"expiry"`
|
||||
}
|
||||
|
||||
@ -699,12 +698,21 @@ func (cli *client) fromStorageDeviceRequest(a storage.DeviceRequest) DeviceReque
|
||||
DeviceCode: a.DeviceCode,
|
||||
CLientID: a.ClientID,
|
||||
Scopes: a.Scopes,
|
||||
PkceVerifier: a.PkceVerifier,
|
||||
Expiry: a.Expiry,
|
||||
}
|
||||
return req
|
||||
}
|
||||
|
||||
func toStorageDeviceRequest(req DeviceRequest) storage.DeviceRequest {
|
||||
return storage.DeviceRequest{
|
||||
UserCode: strings.ToUpper(req.ObjectMeta.Name),
|
||||
DeviceCode: req.DeviceCode,
|
||||
ClientID: req.CLientID,
|
||||
Scopes: req.Scopes,
|
||||
Expiry: req.Expiry,
|
||||
}
|
||||
}
|
||||
|
||||
// DeviceToken is a mirrored struct from storage with JSON struct tags and
|
||||
// Kubernetes type metadata.
|
||||
type DeviceToken struct {
|
||||
@ -714,6 +722,8 @@ type DeviceToken struct {
|
||||
Status string `json:"status,omitempty"`
|
||||
Token string `json:"token,omitempty"`
|
||||
Expiry time.Time `json:"expiry"`
|
||||
LastRequestTime time.Time `json:"last_request"`
|
||||
PollIntervalSeconds int `json:"poll_interval"`
|
||||
}
|
||||
|
||||
// DeviceTokenList is a list of DeviceTokens.
|
||||
@ -736,6 +746,8 @@ func (cli *client) fromStorageDeviceToken(t storage.DeviceToken) DeviceToken {
|
||||
Status: t.Status,
|
||||
Token: t.Token,
|
||||
Expiry: t.Expiry,
|
||||
LastRequestTime: t.LastRequestTime,
|
||||
PollIntervalSeconds: t.PollIntervalSeconds,
|
||||
}
|
||||
return req
|
||||
}
|
||||
@ -746,5 +758,7 @@ func toStorageDeviceToken(t DeviceToken) storage.DeviceToken {
|
||||
Status: t.Status,
|
||||
Token: t.Token,
|
||||
Expiry: t.Expiry,
|
||||
LastRequestTime: t.LastRequestTime,
|
||||
PollIntervalSeconds: t.PollIntervalSeconds,
|
||||
}
|
||||
}
|
||||
|
@ -493,6 +493,17 @@ func (s *memStorage) CreateDeviceRequest(d storage.DeviceRequest) (err error) {
|
||||
return
|
||||
}
|
||||
|
||||
func (s *memStorage) GetDeviceRequest(userCode string) (req storage.DeviceRequest, err error) {
|
||||
s.tx(func() {
|
||||
var ok bool
|
||||
if req, ok = s.deviceRequests[userCode]; !ok {
|
||||
err = storage.ErrNotFound
|
||||
return
|
||||
}
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
func (s *memStorage) CreateDeviceToken(t storage.DeviceToken) (err error) {
|
||||
s.tx(func() {
|
||||
if _, ok := s.deviceTokens[t.DeviceCode]; ok {
|
||||
@ -514,3 +525,17 @@ func (s *memStorage) GetDeviceToken(deviceCode string) (t storage.DeviceToken, e
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
func (s *memStorage) UpdateDeviceToken(deviceCode string, updater func(p storage.DeviceToken) (storage.DeviceToken, error)) (err error) {
|
||||
s.tx(func() {
|
||||
r, ok := s.deviceTokens[deviceCode]
|
||||
if !ok {
|
||||
err = storage.ErrNotFound
|
||||
return
|
||||
}
|
||||
if r, err = updater(r); err == nil {
|
||||
s.deviceTokens[deviceCode] = r
|
||||
}
|
||||
})
|
||||
return
|
||||
}
|
||||
|
@ -888,12 +888,12 @@ func (c *conn) delete(table, field, id string) error {
|
||||
func (c *conn) CreateDeviceRequest(d storage.DeviceRequest) error {
|
||||
_, err := c.Exec(`
|
||||
insert into device_request (
|
||||
user_code, device_code, client_id, scopes, pkce_verifier, expiry
|
||||
user_code, device_code, client_id, scopes, expiry
|
||||
)
|
||||
values (
|
||||
$1, $2, $3, $4, $5, $6
|
||||
$1, $2, $3, $4, $5
|
||||
);`,
|
||||
d.UserCode, d.DeviceCode, d.ClientID, encoder(d.Scopes), d.PkceVerifier, d.Expiry,
|
||||
d.UserCode, d.DeviceCode, d.ClientID, encoder(d.Scopes), d.Expiry,
|
||||
)
|
||||
if err != nil {
|
||||
if c.alreadyExistsCheck(err) {
|
||||
@ -907,12 +907,12 @@ func (c *conn) CreateDeviceRequest(d storage.DeviceRequest) error {
|
||||
func (c *conn) CreateDeviceToken(t storage.DeviceToken) error {
|
||||
_, err := c.Exec(`
|
||||
insert into device_token (
|
||||
device_code, status, token, expiry
|
||||
device_code, status, token, expiry, last_request, poll_interval
|
||||
)
|
||||
values (
|
||||
$1, $2, $3, $4
|
||||
$1, $2, $3, $4, $5, $6
|
||||
);`,
|
||||
t.DeviceCode, t.Status, t.Token, t.Expiry,
|
||||
t.DeviceCode, t.Status, t.Token, t.Expiry, t.LastRequestTime, t.PollIntervalSeconds,
|
||||
)
|
||||
if err != nil {
|
||||
if c.alreadyExistsCheck(err) {
|
||||
@ -923,6 +923,28 @@ func (c *conn) CreateDeviceToken(t storage.DeviceToken) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *conn) GetDeviceRequest(userCode string) (storage.DeviceRequest, error) {
|
||||
return getDeviceRequest(c, userCode)
|
||||
}
|
||||
|
||||
func getDeviceRequest(q querier, userCode string) (d storage.DeviceRequest, err error) {
|
||||
err = q.QueryRow(`
|
||||
select
|
||||
device_code, client_id, scopes, expiry
|
||||
from device_request where user_code = $1;
|
||||
`, userCode).Scan(
|
||||
&d.DeviceCode, &d.ClientID, decoder(&d.Scopes), &d.Expiry,
|
||||
)
|
||||
if err != nil {
|
||||
if err == sql.ErrNoRows {
|
||||
return d, storage.ErrNotFound
|
||||
}
|
||||
return d, fmt.Errorf("select device token: %v", err)
|
||||
}
|
||||
d.UserCode = userCode
|
||||
return d, nil
|
||||
}
|
||||
|
||||
func (c *conn) GetDeviceToken(deviceCode string) (storage.DeviceToken, error) {
|
||||
return getDeviceToken(c, deviceCode)
|
||||
}
|
||||
@ -930,10 +952,10 @@ func (c *conn) GetDeviceToken(deviceCode string) (storage.DeviceToken, error) {
|
||||
func getDeviceToken(q querier, deviceCode string) (a storage.DeviceToken, err error) {
|
||||
err = q.QueryRow(`
|
||||
select
|
||||
status, token, expiry
|
||||
status, token, expiry, last_request, poll_interval
|
||||
from device_token where device_code = $1;
|
||||
`, deviceCode).Scan(
|
||||
&a.Status, &a.Token, &a.Expiry,
|
||||
&a.Status, &a.Token, &a.Expiry, &a.LastRequestTime, &a.PollIntervalSeconds,
|
||||
)
|
||||
if err != nil {
|
||||
if err == sql.ErrNoRows {
|
||||
@ -944,3 +966,31 @@ func getDeviceToken(q querier, deviceCode string) (a storage.DeviceToken, err er
|
||||
a.DeviceCode = deviceCode
|
||||
return a, nil
|
||||
}
|
||||
|
||||
func (c *conn) UpdateDeviceToken(deviceCode string, updater func(old storage.DeviceToken) (storage.DeviceToken, error)) error {
|
||||
return c.ExecTx(func(tx *trans) error {
|
||||
r, err := getDeviceToken(tx, deviceCode)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if r, err = updater(r); err != nil {
|
||||
return err
|
||||
}
|
||||
_, err = tx.Exec(`
|
||||
update device_token
|
||||
set
|
||||
status = $1,
|
||||
token = $2,
|
||||
last_request = $3,
|
||||
poll_interval = $4
|
||||
where
|
||||
device_code = $5
|
||||
`,
|
||||
r.Status, r.Token, r.LastRequestTime, r.PollIntervalSeconds, r.DeviceCode,
|
||||
)
|
||||
if err != nil {
|
||||
return fmt.Errorf("update device token: %v", err)
|
||||
}
|
||||
return nil
|
||||
})
|
||||
}
|
||||
|
@ -236,7 +236,6 @@ var migrations = []migration{
|
||||
device_code text not null,
|
||||
client_id text not null,
|
||||
scopes bytea not null, -- JSON array of strings
|
||||
pkce_verifier text not null,
|
||||
expiry timestamptz not null
|
||||
);`,
|
||||
`
|
||||
@ -244,7 +243,9 @@ var migrations = []migration{
|
||||
device_code text not null primary key,
|
||||
status text not null,
|
||||
token text,
|
||||
expiry timestamptz not null
|
||||
expiry timestamptz not null,
|
||||
last_request timestamptz not null,
|
||||
poll_interval integer not null
|
||||
);`,
|
||||
},
|
||||
},
|
||||
|
@ -82,6 +82,7 @@ type Storage interface {
|
||||
GetPassword(email string) (Password, error)
|
||||
GetOfflineSessions(userID string, connID string) (OfflineSessions, error)
|
||||
GetConnector(id string) (Connector, error)
|
||||
GetDeviceRequest(userCode string) (DeviceRequest, error)
|
||||
GetDeviceToken(deviceCode string) (DeviceToken, error)
|
||||
|
||||
ListClients() ([]Client, error)
|
||||
@ -119,6 +120,7 @@ type Storage interface {
|
||||
UpdatePassword(email string, updater func(p Password) (Password, error)) error
|
||||
UpdateOfflineSessions(userID string, connID string, updater func(s OfflineSessions) (OfflineSessions, error)) error
|
||||
UpdateConnector(id string, updater func(c Connector) (Connector, error)) error
|
||||
UpdateDeviceToken(deviceCode string, updater func(t DeviceToken) (DeviceToken, error)) error
|
||||
|
||||
// GarbageCollect deletes all expired AuthCodes,AuthRequests, DeviceRequests, and DeviceTokens.
|
||||
GarbageCollect(now time.Time) (GCResult, error)
|
||||
@ -392,8 +394,6 @@ type DeviceRequest struct {
|
||||
ClientID string
|
||||
//The scopes the device requests
|
||||
Scopes []string
|
||||
//PKCE Verification
|
||||
PkceVerifier string
|
||||
//The expire time
|
||||
Expiry time.Time
|
||||
}
|
||||
@ -403,4 +403,6 @@ type DeviceToken struct {
|
||||
Status string
|
||||
Token string
|
||||
Expiry time.Time
|
||||
LastRequestTime time.Time
|
||||
PollIntervalSeconds int
|
||||
}
|
||||
|
23
web/templates/device.html
Normal file
23
web/templates/device.html
Normal file
@ -0,0 +1,23 @@
|
||||
{{ template "header.html" . }}
|
||||
|
||||
<div class="theme-panel">
|
||||
<h2 class="theme-heading">Enter User Code</h2>
|
||||
<form method="post" action="{{ .PostURL }}" method="get">
|
||||
<div class="theme-form-row">
|
||||
{{ if( .UserCode )}}
|
||||
<input tabindex="2" required id="user_code" name="user_code" type="text" class="theme-form-input" autocomplete="off" value="{{.UserCode}}" {{ if .Invalid }} autofocus {{ end }}/>
|
||||
{{ else }}
|
||||
<input tabindex="2" required id="user_code" name="user_code" type="text" class="theme-form-input" placeholder="XXXX-XXXX" autocomplete="off" {{ if .Invalid }} autofocus {{ end }}/>
|
||||
{{ end }}
|
||||
</div>
|
||||
|
||||
{{ if .Invalid }}
|
||||
<div id="login-error" class="dex-error-box">
|
||||
Invalid or Expired User Code
|
||||
</div>
|
||||
{{ end }}
|
||||
<button tabindex="3" id="submit-login" type="submit" class="dex-btn theme-btn--primary">Submit</button>
|
||||
</form>
|
||||
</div>
|
||||
|
||||
{{ template "footer.html" . }}
|
8
web/templates/device_success.html
Normal file
8
web/templates/device_success.html
Normal file
@ -0,0 +1,8 @@
|
||||
{{ template "header.html" . }}
|
||||
|
||||
<div class="theme-panel">
|
||||
<h2 class="theme-heading">Login Successful for {{ .ClientName }}</h2>
|
||||
<p>Return to your device to continue</p>
|
||||
</div>
|
||||
|
||||
{{ template "footer.html" . }}
|
Reference in New Issue
Block a user