add PKCE support to device code flow (#2575)
Signed-off-by: Bob Callaway <bobcallaway@users.noreply.github.com>
This commit is contained in:
@@ -73,6 +73,17 @@ func (s *Server) handleDeviceCode(w http.ResponseWriter, r *http.Request) {
|
||||
clientID := r.Form.Get("client_id")
|
||||
clientSecret := r.Form.Get("client_secret")
|
||||
scopes := strings.Fields(r.Form.Get("scope"))
|
||||
codeChallenge := r.Form.Get("code_challenge")
|
||||
codeChallengeMethod := r.Form.Get("code_challenge_method")
|
||||
|
||||
if codeChallengeMethod == "" {
|
||||
codeChallengeMethod = codeChallengeMethodPlain
|
||||
}
|
||||
if codeChallengeMethod != codeChallengeMethodS256 && codeChallengeMethod != codeChallengeMethodPlain {
|
||||
description := fmt.Sprintf("Unsupported PKCE challenge method (%q).", codeChallengeMethod)
|
||||
s.tokenErrHelper(w, errInvalidRequest, description, http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
|
||||
s.logger.Infof("Received device request for client %v with scopes %v", clientID, scopes)
|
||||
|
||||
@@ -108,6 +119,10 @@ func (s *Server) handleDeviceCode(w http.ResponseWriter, r *http.Request) {
|
||||
Expiry: expireTime,
|
||||
LastRequestTime: s.now(),
|
||||
PollIntervalSeconds: 0,
|
||||
PKCE: storage.PKCE{
|
||||
CodeChallenge: codeChallenge,
|
||||
CodeChallengeMethod: codeChallengeMethod,
|
||||
},
|
||||
}
|
||||
|
||||
if err := s.storage.CreateDeviceToken(deviceToken); err != nil {
|
||||
@@ -236,6 +251,30 @@ func (s *Server) handleDeviceToken(w http.ResponseWriter, r *http.Request) {
|
||||
s.tokenErrHelper(w, deviceTokenPending, "", http.StatusUnauthorized)
|
||||
}
|
||||
case deviceTokenComplete:
|
||||
codeChallengeFromStorage := deviceToken.PKCE.CodeChallenge
|
||||
providedCodeVerifier := r.Form.Get("code_verifier")
|
||||
|
||||
switch {
|
||||
case providedCodeVerifier != "" && codeChallengeFromStorage != "":
|
||||
calculatedCodeChallenge, err := s.calculateCodeChallenge(providedCodeVerifier, deviceToken.PKCE.CodeChallengeMethod)
|
||||
if err != nil {
|
||||
s.logger.Error(err)
|
||||
s.tokenErrHelper(w, errServerError, "", http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
if codeChallengeFromStorage != calculatedCodeChallenge {
|
||||
s.tokenErrHelper(w, errInvalidGrant, "Invalid code_verifier.", http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
case providedCodeVerifier != "":
|
||||
// Received no code_challenge on /auth, but a code_verifier on /token
|
||||
s.tokenErrHelper(w, errInvalidRequest, "No PKCE flow started. Cannot check code_verifier.", http.StatusBadRequest)
|
||||
return
|
||||
case codeChallengeFromStorage != "":
|
||||
// Received PKCE request on /auth, but no code_verifier on /token
|
||||
s.tokenErrHelper(w, errInvalidGrant, "Expecting parameter code_verifier in PKCE flow.", http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
w.Write([]byte(deviceToken.Token))
|
||||
}
|
||||
}
|
||||
|
@@ -49,6 +49,7 @@ func TestHandleDeviceCode(t *testing.T) {
|
||||
tests := []struct {
|
||||
testName string
|
||||
clientID string
|
||||
codeChallengeMethod string
|
||||
requestType string
|
||||
scopes []string
|
||||
expectedResponseCode int
|
||||
@@ -71,6 +72,24 @@ func TestHandleDeviceCode(t *testing.T) {
|
||||
expectedResponseCode: http.StatusBadRequest,
|
||||
expectedContentType: "application/json",
|
||||
},
|
||||
{
|
||||
testName: "New Code with valid PKCE",
|
||||
clientID: "test",
|
||||
requestType: "POST",
|
||||
scopes: []string{"openid", "profile", "email"},
|
||||
codeChallengeMethod: "S256",
|
||||
expectedResponseCode: http.StatusOK,
|
||||
expectedContentType: "application/json",
|
||||
},
|
||||
{
|
||||
testName: "Invalid code challenge method",
|
||||
clientID: "test",
|
||||
requestType: "POST",
|
||||
codeChallengeMethod: "invalid",
|
||||
scopes: []string{"openid", "profile", "email"},
|
||||
expectedResponseCode: http.StatusBadRequest,
|
||||
expectedContentType: "application/json",
|
||||
},
|
||||
}
|
||||
for _, tc := range tests {
|
||||
t.Run(tc.testName, func(t *testing.T) {
|
||||
@@ -92,6 +111,7 @@ func TestHandleDeviceCode(t *testing.T) {
|
||||
|
||||
data := url.Values{}
|
||||
data.Set("client_id", tc.clientID)
|
||||
data.Set("code_challenge_method", tc.codeChallengeMethod)
|
||||
for _, scope := range tc.scopes {
|
||||
data.Add("scope", scope)
|
||||
}
|
||||
@@ -401,6 +421,13 @@ func TestDeviceTokenResponse(t *testing.T) {
|
||||
|
||||
now := func() time.Time { return t0 }
|
||||
|
||||
// Base PKCE values
|
||||
// base64-urlencoded, sha256 digest of code_verifier
|
||||
codeChallenge := "L7ZqsT_zNwvrH8E7J0CqPHx1wgBaFiaE-fAZcKUUAbc"
|
||||
codeChallengeMethod := "S256"
|
||||
// "random" string between 43 & 128 ASCII characters
|
||||
codeVerifier := "66114650f56cc45dee7ee03c49f048ddf9aa53cbf5b09985832fa4f790ff2604"
|
||||
|
||||
baseDeviceRequest := storage.DeviceRequest{
|
||||
UserCode: "ABCD-WXYZ",
|
||||
DeviceCode: "foo",
|
||||
@@ -415,6 +442,7 @@ func TestDeviceTokenResponse(t *testing.T) {
|
||||
testDeviceToken storage.DeviceToken
|
||||
testGrantType string
|
||||
testDeviceCode string
|
||||
testCodeVerifier string
|
||||
expectedServerResponse string
|
||||
expectedResponseCode int
|
||||
}{
|
||||
@@ -524,6 +552,101 @@ func TestDeviceTokenResponse(t *testing.T) {
|
||||
expectedServerResponse: "{\"access_token\": \"foobar\"}",
|
||||
expectedResponseCode: http.StatusOK,
|
||||
},
|
||||
{
|
||||
testName: "Successful Exchange with PKCE",
|
||||
testDeviceToken: storage.DeviceToken{
|
||||
DeviceCode: "foo",
|
||||
Status: deviceTokenComplete,
|
||||
Token: "{\"access_token\": \"foobar\"}",
|
||||
Expiry: now().Add(5 * time.Minute),
|
||||
LastRequestTime: time.Time{},
|
||||
PollIntervalSeconds: 0,
|
||||
PKCE: storage.PKCE{
|
||||
CodeChallenge: codeChallenge,
|
||||
CodeChallengeMethod: codeChallengeMethod,
|
||||
},
|
||||
},
|
||||
testDeviceCode: "foo",
|
||||
testCodeVerifier: codeVerifier,
|
||||
testDeviceRequest: baseDeviceRequest,
|
||||
expectedServerResponse: "{\"access_token\": \"foobar\"}",
|
||||
expectedResponseCode: http.StatusOK,
|
||||
},
|
||||
{
|
||||
testName: "Test Exchange started with PKCE but without verifier provided",
|
||||
testDeviceToken: storage.DeviceToken{
|
||||
DeviceCode: "foo",
|
||||
Status: deviceTokenComplete,
|
||||
Token: "{\"access_token\": \"foobar\"}",
|
||||
Expiry: now().Add(5 * time.Minute),
|
||||
LastRequestTime: time.Time{},
|
||||
PollIntervalSeconds: 0,
|
||||
PKCE: storage.PKCE{
|
||||
CodeChallenge: codeChallenge,
|
||||
CodeChallengeMethod: codeChallengeMethod,
|
||||
},
|
||||
},
|
||||
testDeviceCode: "foo",
|
||||
testDeviceRequest: baseDeviceRequest,
|
||||
expectedServerResponse: errInvalidGrant,
|
||||
expectedResponseCode: http.StatusBadRequest,
|
||||
},
|
||||
{
|
||||
testName: "Test Exchange not started with PKCE but verifier provided",
|
||||
testDeviceToken: storage.DeviceToken{
|
||||
DeviceCode: "foo",
|
||||
Status: deviceTokenComplete,
|
||||
Token: "{\"access_token\": \"foobar\"}",
|
||||
Expiry: now().Add(5 * time.Minute),
|
||||
LastRequestTime: time.Time{},
|
||||
PollIntervalSeconds: 0,
|
||||
},
|
||||
testDeviceCode: "foo",
|
||||
testCodeVerifier: codeVerifier,
|
||||
testDeviceRequest: baseDeviceRequest,
|
||||
expectedServerResponse: errInvalidRequest,
|
||||
expectedResponseCode: http.StatusBadRequest,
|
||||
},
|
||||
{
|
||||
testName: "Test with PKCE but incorrect verifier provided",
|
||||
testDeviceToken: storage.DeviceToken{
|
||||
DeviceCode: "foo",
|
||||
Status: deviceTokenComplete,
|
||||
Token: "{\"access_token\": \"foobar\"}",
|
||||
Expiry: now().Add(5 * time.Minute),
|
||||
LastRequestTime: time.Time{},
|
||||
PollIntervalSeconds: 0,
|
||||
PKCE: storage.PKCE{
|
||||
CodeChallenge: codeChallenge,
|
||||
CodeChallengeMethod: codeChallengeMethod,
|
||||
},
|
||||
},
|
||||
testDeviceCode: "foo",
|
||||
testCodeVerifier: "invalid",
|
||||
testDeviceRequest: baseDeviceRequest,
|
||||
expectedServerResponse: errInvalidGrant,
|
||||
expectedResponseCode: http.StatusBadRequest,
|
||||
},
|
||||
{
|
||||
testName: "Test with PKCE but incorrect challenge provided",
|
||||
testDeviceToken: storage.DeviceToken{
|
||||
DeviceCode: "foo",
|
||||
Status: deviceTokenComplete,
|
||||
Token: "{\"access_token\": \"foobar\"}",
|
||||
Expiry: now().Add(5 * time.Minute),
|
||||
LastRequestTime: time.Time{},
|
||||
PollIntervalSeconds: 0,
|
||||
PKCE: storage.PKCE{
|
||||
CodeChallenge: "invalid",
|
||||
CodeChallengeMethod: codeChallengeMethod,
|
||||
},
|
||||
},
|
||||
testDeviceCode: "foo",
|
||||
testCodeVerifier: codeVerifier,
|
||||
testDeviceRequest: baseDeviceRequest,
|
||||
expectedServerResponse: errInvalidGrant,
|
||||
expectedResponseCode: http.StatusBadRequest,
|
||||
},
|
||||
}
|
||||
for _, tc := range tests {
|
||||
t.Run(tc.testName, func(t *testing.T) {
|
||||
@@ -558,6 +681,9 @@ func TestDeviceTokenResponse(t *testing.T) {
|
||||
}
|
||||
data.Set("grant_type", grantType)
|
||||
data.Set("device_code", tc.testDeviceCode)
|
||||
if tc.testCodeVerifier != "" {
|
||||
data.Set("code_verifier", tc.testCodeVerifier)
|
||||
}
|
||||
req, _ := http.NewRequest("POST", u.String(), bytes.NewBufferString(data.Encode()))
|
||||
req.Header.Set("Content-Type", "application/x-www-form-urlencoded; param=value")
|
||||
|
||||
|
Reference in New Issue
Block a user