add PKCE support to device code flow (#2575)

Signed-off-by: Bob Callaway <bobcallaway@users.noreply.github.com>
This commit is contained in:
Bob Callaway
2022-07-27 09:02:18 -07:00
committed by GitHub
parent 454122ca22
commit 83e2df821e
20 changed files with 790 additions and 32 deletions

View File

@@ -73,6 +73,17 @@ func (s *Server) handleDeviceCode(w http.ResponseWriter, r *http.Request) {
clientID := r.Form.Get("client_id")
clientSecret := r.Form.Get("client_secret")
scopes := strings.Fields(r.Form.Get("scope"))
codeChallenge := r.Form.Get("code_challenge")
codeChallengeMethod := r.Form.Get("code_challenge_method")
if codeChallengeMethod == "" {
codeChallengeMethod = codeChallengeMethodPlain
}
if codeChallengeMethod != codeChallengeMethodS256 && codeChallengeMethod != codeChallengeMethodPlain {
description := fmt.Sprintf("Unsupported PKCE challenge method (%q).", codeChallengeMethod)
s.tokenErrHelper(w, errInvalidRequest, description, http.StatusBadRequest)
return
}
s.logger.Infof("Received device request for client %v with scopes %v", clientID, scopes)
@@ -108,6 +119,10 @@ func (s *Server) handleDeviceCode(w http.ResponseWriter, r *http.Request) {
Expiry: expireTime,
LastRequestTime: s.now(),
PollIntervalSeconds: 0,
PKCE: storage.PKCE{
CodeChallenge: codeChallenge,
CodeChallengeMethod: codeChallengeMethod,
},
}
if err := s.storage.CreateDeviceToken(deviceToken); err != nil {
@@ -236,6 +251,30 @@ func (s *Server) handleDeviceToken(w http.ResponseWriter, r *http.Request) {
s.tokenErrHelper(w, deviceTokenPending, "", http.StatusUnauthorized)
}
case deviceTokenComplete:
codeChallengeFromStorage := deviceToken.PKCE.CodeChallenge
providedCodeVerifier := r.Form.Get("code_verifier")
switch {
case providedCodeVerifier != "" && codeChallengeFromStorage != "":
calculatedCodeChallenge, err := s.calculateCodeChallenge(providedCodeVerifier, deviceToken.PKCE.CodeChallengeMethod)
if err != nil {
s.logger.Error(err)
s.tokenErrHelper(w, errServerError, "", http.StatusInternalServerError)
return
}
if codeChallengeFromStorage != calculatedCodeChallenge {
s.tokenErrHelper(w, errInvalidGrant, "Invalid code_verifier.", http.StatusBadRequest)
return
}
case providedCodeVerifier != "":
// Received no code_challenge on /auth, but a code_verifier on /token
s.tokenErrHelper(w, errInvalidRequest, "No PKCE flow started. Cannot check code_verifier.", http.StatusBadRequest)
return
case codeChallengeFromStorage != "":
// Received PKCE request on /auth, but no code_verifier on /token
s.tokenErrHelper(w, errInvalidGrant, "Expecting parameter code_verifier in PKCE flow.", http.StatusBadRequest)
return
}
w.Write([]byte(deviceToken.Token))
}
}

View File

@@ -49,6 +49,7 @@ func TestHandleDeviceCode(t *testing.T) {
tests := []struct {
testName string
clientID string
codeChallengeMethod string
requestType string
scopes []string
expectedResponseCode int
@@ -71,6 +72,24 @@ func TestHandleDeviceCode(t *testing.T) {
expectedResponseCode: http.StatusBadRequest,
expectedContentType: "application/json",
},
{
testName: "New Code with valid PKCE",
clientID: "test",
requestType: "POST",
scopes: []string{"openid", "profile", "email"},
codeChallengeMethod: "S256",
expectedResponseCode: http.StatusOK,
expectedContentType: "application/json",
},
{
testName: "Invalid code challenge method",
clientID: "test",
requestType: "POST",
codeChallengeMethod: "invalid",
scopes: []string{"openid", "profile", "email"},
expectedResponseCode: http.StatusBadRequest,
expectedContentType: "application/json",
},
}
for _, tc := range tests {
t.Run(tc.testName, func(t *testing.T) {
@@ -92,6 +111,7 @@ func TestHandleDeviceCode(t *testing.T) {
data := url.Values{}
data.Set("client_id", tc.clientID)
data.Set("code_challenge_method", tc.codeChallengeMethod)
for _, scope := range tc.scopes {
data.Add("scope", scope)
}
@@ -401,6 +421,13 @@ func TestDeviceTokenResponse(t *testing.T) {
now := func() time.Time { return t0 }
// Base PKCE values
// base64-urlencoded, sha256 digest of code_verifier
codeChallenge := "L7ZqsT_zNwvrH8E7J0CqPHx1wgBaFiaE-fAZcKUUAbc"
codeChallengeMethod := "S256"
// "random" string between 43 & 128 ASCII characters
codeVerifier := "66114650f56cc45dee7ee03c49f048ddf9aa53cbf5b09985832fa4f790ff2604"
baseDeviceRequest := storage.DeviceRequest{
UserCode: "ABCD-WXYZ",
DeviceCode: "foo",
@@ -415,6 +442,7 @@ func TestDeviceTokenResponse(t *testing.T) {
testDeviceToken storage.DeviceToken
testGrantType string
testDeviceCode string
testCodeVerifier string
expectedServerResponse string
expectedResponseCode int
}{
@@ -524,6 +552,101 @@ func TestDeviceTokenResponse(t *testing.T) {
expectedServerResponse: "{\"access_token\": \"foobar\"}",
expectedResponseCode: http.StatusOK,
},
{
testName: "Successful Exchange with PKCE",
testDeviceToken: storage.DeviceToken{
DeviceCode: "foo",
Status: deviceTokenComplete,
Token: "{\"access_token\": \"foobar\"}",
Expiry: now().Add(5 * time.Minute),
LastRequestTime: time.Time{},
PollIntervalSeconds: 0,
PKCE: storage.PKCE{
CodeChallenge: codeChallenge,
CodeChallengeMethod: codeChallengeMethod,
},
},
testDeviceCode: "foo",
testCodeVerifier: codeVerifier,
testDeviceRequest: baseDeviceRequest,
expectedServerResponse: "{\"access_token\": \"foobar\"}",
expectedResponseCode: http.StatusOK,
},
{
testName: "Test Exchange started with PKCE but without verifier provided",
testDeviceToken: storage.DeviceToken{
DeviceCode: "foo",
Status: deviceTokenComplete,
Token: "{\"access_token\": \"foobar\"}",
Expiry: now().Add(5 * time.Minute),
LastRequestTime: time.Time{},
PollIntervalSeconds: 0,
PKCE: storage.PKCE{
CodeChallenge: codeChallenge,
CodeChallengeMethod: codeChallengeMethod,
},
},
testDeviceCode: "foo",
testDeviceRequest: baseDeviceRequest,
expectedServerResponse: errInvalidGrant,
expectedResponseCode: http.StatusBadRequest,
},
{
testName: "Test Exchange not started with PKCE but verifier provided",
testDeviceToken: storage.DeviceToken{
DeviceCode: "foo",
Status: deviceTokenComplete,
Token: "{\"access_token\": \"foobar\"}",
Expiry: now().Add(5 * time.Minute),
LastRequestTime: time.Time{},
PollIntervalSeconds: 0,
},
testDeviceCode: "foo",
testCodeVerifier: codeVerifier,
testDeviceRequest: baseDeviceRequest,
expectedServerResponse: errInvalidRequest,
expectedResponseCode: http.StatusBadRequest,
},
{
testName: "Test with PKCE but incorrect verifier provided",
testDeviceToken: storage.DeviceToken{
DeviceCode: "foo",
Status: deviceTokenComplete,
Token: "{\"access_token\": \"foobar\"}",
Expiry: now().Add(5 * time.Minute),
LastRequestTime: time.Time{},
PollIntervalSeconds: 0,
PKCE: storage.PKCE{
CodeChallenge: codeChallenge,
CodeChallengeMethod: codeChallengeMethod,
},
},
testDeviceCode: "foo",
testCodeVerifier: "invalid",
testDeviceRequest: baseDeviceRequest,
expectedServerResponse: errInvalidGrant,
expectedResponseCode: http.StatusBadRequest,
},
{
testName: "Test with PKCE but incorrect challenge provided",
testDeviceToken: storage.DeviceToken{
DeviceCode: "foo",
Status: deviceTokenComplete,
Token: "{\"access_token\": \"foobar\"}",
Expiry: now().Add(5 * time.Minute),
LastRequestTime: time.Time{},
PollIntervalSeconds: 0,
PKCE: storage.PKCE{
CodeChallenge: "invalid",
CodeChallengeMethod: codeChallengeMethod,
},
},
testDeviceCode: "foo",
testCodeVerifier: codeVerifier,
testDeviceRequest: baseDeviceRequest,
expectedServerResponse: errInvalidGrant,
expectedResponseCode: http.StatusBadRequest,
},
}
for _, tc := range tests {
t.Run(tc.testName, func(t *testing.T) {
@@ -558,6 +681,9 @@ func TestDeviceTokenResponse(t *testing.T) {
}
data.Set("grant_type", grantType)
data.Set("device_code", tc.testDeviceCode)
if tc.testCodeVerifier != "" {
data.Set("code_verifier", tc.testCodeVerifier)
}
req, _ := http.NewRequest("POST", u.String(), bytes.NewBufferString(data.Encode()))
req.Header.Set("Content-Type", "application/x-www-form-urlencoded; param=value")