PKCE implementation (#1784)

* Basic implementation of PKCE

Signed-off-by: Tadeusz Magura-Witkowski <tadeuszmw@gmail.com>

* @mfmarche on 24 Feb: when code_verifier is set, don't check client_secret

In PKCE flow, no client_secret is used, so the check for a valid client_secret
would always fail.

Signed-off-by: Bernd Eckstein <Bernd.Eckstein@faro.com>

* @deric on 16 Jun: return invalid_grant when wrong code_verifier

Signed-off-by: Bernd Eckstein <Bernd.Eckstein@faro.com>

* Enforce PKCE flow on /token when PKCE flow was started on /auth
Also dissallow PKCE on /token, when PKCE flow was not started on /auth

Signed-off-by: Bernd Eckstein <Bernd.Eckstein@faro.com>

* fixed error messages when mixed PKCE/no PKCE flow.

Signed-off-by: Bernd Eckstein <Bernd.Eckstein@faro.com>

* server_test.go: Added PKCE error cases on /token endpoint

* Added test for invalid_grant, when wrong code_verifier is sent
* Added test for mixed PKCE / no PKCE auth flows.

Signed-off-by: Bernd Eckstein <Bernd.Eckstein@faro.com>

* cleanup: extracted method checkErrorResponse and type TestDefinition

* fixed connector being overwritten

Signed-off-by: Bernd Eckstein <Bernd.Eckstein@faro.com>

* /token endpoint: skip client_secret verification only for grand type authorization_code with PKCE extension

Signed-off-by: Bernd Eckstein <Bernd.Eckstein@faro.com>

* Allow "Authorization" header in CORS handlers

* Adds "Authorization" to the default CORS headers{"Accept", "Accept-Language", "Content-Language", "Origin"}

Signed-off-by: Bernd Eckstein <Bernd.Eckstein@faro.com>

* Add "code_challenge_methods_supported" to discovery endpoint

discovery endpoint /dex/.well-known/openid-configuration
now has the following entry:

"code_challenge_methods_supported": [
  "S256",
  "plain"
]

Signed-off-by: Bernd Eckstein <Bernd.Eckstein@faro.com>

* Updated tests (mixed-up comments), added a PKCE test

* @asoorm added test that checks if downgrade to "plain" on /token endpoint

Signed-off-by: Bernd Eckstein <Bernd.Eckstein@faro.com>

* remove redefinition of providedCodeVerifier, fixed spelling (#6)

Signed-off-by: Bernd Eckstein <Bernd.Eckstein@faro.com>
Signed-off-by: Bernd Eckstein <HEllRZA@users.noreply.github.com>

* Rename struct CodeChallenge to PKCE

Signed-off-by: Bernd Eckstein <Bernd.Eckstein@faro.com>

* PKCE: Check clientSecret when available

In authorization_code flow with PKCE, allow empty client_secret on /auth and /token endpoints. But check the client_secret when it is given.

Signed-off-by: Bernd Eckstein <Bernd.Eckstein@faro.com>

* Enable PKCE with public: true

dex configuration public on staticClients now enables the following behavior in PKCE:
- Public: false, PKCE will always check client_secret. This means PKCE in it's natural form is disabled.
- Public: true, PKCE is enabled. It will only check client_secret if the client has sent one. But it allows the code flow if the client didn't sent one.

Signed-off-by: Bernd Eckstein <Bernd.Eckstein@faro.com>

* Redirect error on unsupported code_challenge_method

- Check for unsupported code_challenge_method after redirect uri is validated, and use newErr() to return the error.
- Add PKCE tests to oauth2_test.go

Signed-off-by: Bernd Eckstein <Bernd.Eckstein@faro.com>

* Reverted go.mod and go.sum to the state of master

Signed-off-by: Bernd Eckstein <Bernd.Eckstein@faro.com>

* Don't omit client secret check for PKCE

Signed-off-by: Bernd Eckstein <Bernd.Eckstein@faro.com>

* Allow public clients (e.g. with PKCE) to have redirect URIs configured

Signed-off-by: Martin Heide <martin.heide@faro.com>

* Remove "Authorization" as Accepted Headers on CORS, small fixes

Signed-off-by: Bernd Eckstein <Bernd.Eckstein@faro.com>

* Revert "Allow public clients (e.g. with PKCE) to have redirect URIs configured"

This reverts commit b6e297b78537dc44cd3e1374f0b4d34bf89404ac.

Signed-off-by: Martin Heide <martin.heide@faro.com>

* PKCE on client_secret client error message

* When connecting to the token endpoint with PKCE without client_secret, but the client is configured with a client_secret, generate a special error message.

Signed-off-by: Bernd Eckstein <Bernd.Eckstein@faro.com>

* Output info message when PKCE without client_secret used on confidential client

* removes the special error message

Signed-off-by: Bernd Eckstein <Bernd.Eckstein@faro.com>

* General missing/invalid client_secret message on token endpoint

Signed-off-by: Bernd Eckstein <Bernd.Eckstein@faro.com>

Co-authored-by: Tadeusz Magura-Witkowski <tadeuszmw@gmail.com>
Co-authored-by: Martin Heide <martin.heide@faro.com>
Co-authored-by: M. Heide <66078329+heidemn-faro@users.noreply.github.com>
This commit is contained in:
Bernd Eckstein 2020-10-26 11:33:40 +01:00 committed by GitHub
parent 2a282860fa
commit b5519695a6
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
10 changed files with 439 additions and 53 deletions

View File

@ -2,6 +2,8 @@ package server
import ( import (
"context" "context"
"crypto/sha256"
"encoding/base64"
"encoding/json" "encoding/json"
"errors" "errors"
"fmt" "fmt"
@ -23,6 +25,11 @@ import (
"github.com/dexidp/dex/storage" "github.com/dexidp/dex/storage"
) )
const (
CodeChallengeMethodPlain = "plain"
CodeChallengeMethodS256 = "S256"
)
// newHealthChecker returns the healthz handler. The handler runs until the // newHealthChecker returns the healthz handler. The handler runs until the
// provided context is canceled. // provided context is canceled.
func (s *Server) newHealthChecker(ctx context.Context) http.Handler { func (s *Server) newHealthChecker(ctx context.Context) http.Handler {
@ -158,6 +165,7 @@ type discovery struct {
ResponseTypes []string `json:"response_types_supported"` ResponseTypes []string `json:"response_types_supported"`
Subjects []string `json:"subject_types_supported"` Subjects []string `json:"subject_types_supported"`
IDTokenAlgs []string `json:"id_token_signing_alg_values_supported"` IDTokenAlgs []string `json:"id_token_signing_alg_values_supported"`
CodeChallengeAlgs []string `json:"code_challenge_methods_supported"`
Scopes []string `json:"scopes_supported"` Scopes []string `json:"scopes_supported"`
AuthMethods []string `json:"token_endpoint_auth_methods_supported"` AuthMethods []string `json:"token_endpoint_auth_methods_supported"`
Claims []string `json:"claims_supported"` Claims []string `json:"claims_supported"`
@ -174,6 +182,7 @@ func (s *Server) discoveryHandler() (http.HandlerFunc, error) {
Subjects: []string{"public"}, Subjects: []string{"public"},
GrantTypes: []string{grantTypeAuthorizationCode, grantTypeRefreshToken, grantTypeDeviceCode}, GrantTypes: []string{grantTypeAuthorizationCode, grantTypeRefreshToken, grantTypeDeviceCode},
IDTokenAlgs: []string{string(jose.RS256)}, IDTokenAlgs: []string{string(jose.RS256)},
CodeChallengeAlgs: []string{CodeChallengeMethodS256, CodeChallengeMethodPlain},
Scopes: []string{"openid", "email", "groups", "profile", "offline_access"}, Scopes: []string{"openid", "email", "groups", "profile", "offline_access"},
AuthMethods: []string{"client_secret_basic"}, AuthMethods: []string{"client_secret_basic"},
Claims: []string{ Claims: []string{
@ -643,6 +652,7 @@ func (s *Server) sendCodeResponse(w http.ResponseWriter, r *http.Request, authRe
Expiry: s.now().Add(time.Minute * 30), Expiry: s.now().Add(time.Minute * 30),
RedirectURI: authReq.RedirectURI, RedirectURI: authReq.RedirectURI,
ConnectorData: authReq.ConnectorData, ConnectorData: authReq.ConnectorData,
PKCE: authReq.PKCE,
} }
if err := s.storage.CreateAuthCode(code); err != nil { if err := s.storage.CreateAuthCode(code); err != nil {
s.logger.Errorf("Failed to create auth code: %v", err) s.logger.Errorf("Failed to create auth code: %v", err)
@ -756,6 +766,11 @@ func (s *Server) handleToken(w http.ResponseWriter, r *http.Request) {
return return
} }
if client.Secret != clientSecret { if client.Secret != clientSecret {
if clientSecret == "" {
s.logger.Infof("missing client_secret on token request for client: %s", client.ID)
} else {
s.logger.Infof("invalid client_secret on token request for client: %s", client.ID)
}
s.tokenErrHelper(w, errInvalidClient, "Invalid client credentials.", http.StatusUnauthorized) s.tokenErrHelper(w, errInvalidClient, "Invalid client credentials.", http.StatusUnauthorized)
return return
} }
@ -773,6 +788,18 @@ func (s *Server) handleToken(w http.ResponseWriter, r *http.Request) {
} }
} }
func (s *Server) calculateCodeChallenge(codeVerifier, codeChallengeMethod string) (string, error) {
switch codeChallengeMethod {
case CodeChallengeMethodPlain:
return codeVerifier, nil
case CodeChallengeMethodS256:
shaSum := sha256.Sum256([]byte(codeVerifier))
return base64.RawURLEncoding.EncodeToString(shaSum[:]), nil
default:
return "", fmt.Errorf("unknown challenge method (%v)", codeChallengeMethod)
}
}
// handle an access token request https://tools.ietf.org/html/rfc6749#section-4.1.3 // handle an access token request https://tools.ietf.org/html/rfc6749#section-4.1.3
func (s *Server) handleAuthCode(w http.ResponseWriter, r *http.Request, client storage.Client) { func (s *Server) handleAuthCode(w http.ResponseWriter, r *http.Request, client storage.Client) {
code := r.PostFormValue("code") code := r.PostFormValue("code")
@ -789,6 +816,31 @@ func (s *Server) handleAuthCode(w http.ResponseWriter, r *http.Request, client s
return return
} }
// RFC 7636 (PKCE)
codeChallengeFromStorage := authCode.PKCE.CodeChallenge
providedCodeVerifier := r.PostFormValue("code_verifier")
if providedCodeVerifier != "" && codeChallengeFromStorage != "" {
calculatedCodeChallenge, err := s.calculateCodeChallenge(providedCodeVerifier, authCode.PKCE.CodeChallengeMethod)
if err != nil {
s.logger.Error(err)
s.tokenErrHelper(w, errServerError, "", http.StatusInternalServerError)
return
}
if codeChallengeFromStorage != calculatedCodeChallenge {
s.tokenErrHelper(w, errInvalidGrant, "Invalid code_verifier.", http.StatusBadRequest)
return
}
} else if providedCodeVerifier != "" {
// Received no code_challenge on /auth, but a code_verifier on /token
s.tokenErrHelper(w, errInvalidRequest, "No PKCE flow started. Cannot check code_verifier.", http.StatusBadRequest)
return
} else if codeChallengeFromStorage != "" {
// Received PKCE request on /auth, but no code_verifier on /token
s.tokenErrHelper(w, errInvalidGrant, "Expecting parameter code_verifier in PKCE flow.", http.StatusBadRequest)
return
}
if authCode.RedirectURI != redirectURI { if authCode.RedirectURI != redirectURI {
s.tokenErrHelper(w, errInvalidRequest, "redirect_uri did not match URI from initial request.", http.StatusBadRequest) s.tokenErrHelper(w, errInvalidRequest, "redirect_uri did not match URI from initial request.", http.StatusBadRequest)
return return

View File

@ -413,6 +413,13 @@ func (s *Server) parseAuthorizationRequest(r *http.Request) (*storage.AuthReques
scopes := strings.Fields(q.Get("scope")) scopes := strings.Fields(q.Get("scope"))
responseTypes := strings.Fields(q.Get("response_type")) responseTypes := strings.Fields(q.Get("response_type"))
codeChallenge := q.Get("code_challenge")
codeChallengeMethod := q.Get("code_challenge_method")
if codeChallengeMethod == "" {
codeChallengeMethod = CodeChallengeMethodPlain
}
client, err := s.storage.GetClient(clientID) client, err := s.storage.GetClient(clientID)
if err != nil { if err != nil {
if err == storage.ErrNotFound { if err == storage.ErrNotFound {
@ -446,6 +453,11 @@ func (s *Server) parseAuthorizationRequest(r *http.Request) (*storage.AuthReques
return &authErr{state, redirectURI, typ, fmt.Sprintf(format, a...)} return &authErr{state, redirectURI, typ, fmt.Sprintf(format, a...)}
} }
if codeChallengeMethod != CodeChallengeMethodS256 && codeChallengeMethod != CodeChallengeMethodPlain {
description := fmt.Sprintf("Unsupported PKCE challenge method (%q).", codeChallengeMethod)
return nil, newErr(errInvalidRequest, description)
}
var ( var (
unrecognized []string unrecognized []string
invalidScopes []string invalidScopes []string
@ -541,6 +553,10 @@ func (s *Server) parseAuthorizationRequest(r *http.Request) (*storage.AuthReques
RedirectURI: redirectURI, RedirectURI: redirectURI,
ResponseTypes: responseTypes, ResponseTypes: responseTypes,
ConnectorID: connectorID, ConnectorID: connectorID,
PKCE: storage.PKCE{
CodeChallenge: codeChallenge,
CodeChallengeMethod: codeChallengeMethod,
},
}, nil }, nil
} }

View File

@ -197,6 +197,78 @@ func TestParseAuthorizationRequest(t *testing.T) {
}, },
wantErr: true, wantErr: true,
}, },
{
name: "PKCE code_challenge_method plain",
clients: []storage.Client{
{
ID: "bar",
RedirectURIs: []string{"https://example.com/bar"},
},
},
supportedResponseTypes: []string{"code"},
queryParams: map[string]string{
"client_id": "bar",
"redirect_uri": "https://example.com/bar",
"response_type": "code",
"code_challenge": "123",
"code_challenge_method": "plain",
"scope": "openid email profile",
},
},
{
name: "PKCE code_challenge_method default plain",
clients: []storage.Client{
{
ID: "bar",
RedirectURIs: []string{"https://example.com/bar"},
},
},
supportedResponseTypes: []string{"code"},
queryParams: map[string]string{
"client_id": "bar",
"redirect_uri": "https://example.com/bar",
"response_type": "code",
"code_challenge": "123",
"scope": "openid email profile",
},
},
{
name: "PKCE code_challenge_method S256",
clients: []storage.Client{
{
ID: "bar",
RedirectURIs: []string{"https://example.com/bar"},
},
},
supportedResponseTypes: []string{"code"},
queryParams: map[string]string{
"client_id": "bar",
"redirect_uri": "https://example.com/bar",
"response_type": "code",
"code_challenge": "123",
"code_challenge_method": "S256",
"scope": "openid email profile",
},
},
{
name: "PKCE invalid code_challenge_method",
clients: []storage.Client{
{
ID: "bar",
RedirectURIs: []string{"https://example.com/bar"},
},
},
supportedResponseTypes: []string{"code"},
queryParams: map[string]string{
"client_id": "bar",
"redirect_uri": "https://example.com/bar",
"response_type": "code",
"code_challenge": "123",
"code_challenge_method": "invalid_method",
"scope": "openid email profile",
},
wantErr: true,
},
} }
for _, tc := range tests { for _, tc := range tests {

View File

@ -216,6 +216,29 @@ type test struct {
scopes []string scopes []string
// handleToken provides the OAuth2 token response for the integration test. // handleToken provides the OAuth2 token response for the integration test.
handleToken func(context.Context, *oidc.Provider, *oauth2.Config, *oauth2.Token, *mock.Callback) error handleToken func(context.Context, *oidc.Provider, *oauth2.Config, *oauth2.Token, *mock.Callback) error
// extra parameters to pass when requesting auth_code
authCodeOptions []oauth2.AuthCodeOption
// extra parameters to pass when retrieving id token
retrieveTokenOptions []oauth2.AuthCodeOption
// define an error response, when the test expects an error on the token endpoint
tokenError ErrorResponse
}
// Defines an expected error by HTTP Status Code and
// the OAuth2 error int the response json
type ErrorResponse struct {
Error string
StatusCode int
}
// https://tools.ietf.org/html/rfc6749#section-5.2
type OAuth2ErrorResponse struct {
Error string `json:"error"`
ErrorDescription string `json:"error_description"`
ErrorURI string `json:"error_uri"`
} }
func makeOAuth2Tests(clientID string, clientSecret string, now func() time.Time) oauth2Tests { func makeOAuth2Tests(clientID string, clientSecret string, now func() time.Time) oauth2Tests {
@ -229,6 +252,17 @@ func makeOAuth2Tests(clientID string, clientSecret string, now func() time.Time)
oidcConfig := &oidc.Config{SkipClientIDCheck: true} oidcConfig := &oidc.Config{SkipClientIDCheck: true}
basicIDTokenVerify := func(ctx context.Context, p *oidc.Provider, config *oauth2.Config, token *oauth2.Token, conn *mock.Callback) error {
idToken, ok := token.Extra("id_token").(string)
if !ok {
return fmt.Errorf("no id token found")
}
if _, err := p.Verifier(oidcConfig).Verify(ctx, idToken); err != nil {
return fmt.Errorf("failed to verify id token: %v", err)
}
return nil
}
return oauth2Tests{ return oauth2Tests{
clientID: clientID, clientID: clientID,
tests: []test{ tests: []test{
@ -469,6 +503,110 @@ func makeOAuth2Tests(clientID string, clientSecret string, now func() time.Time)
return nil return nil
}, },
}, },
{
// This test ensures that PKCE work in "plain" mode (no code_challenge_method specified)
name: "PKCE with plain",
authCodeOptions: []oauth2.AuthCodeOption{
oauth2.SetAuthURLParam("code_challenge", "challenge123"),
},
retrieveTokenOptions: []oauth2.AuthCodeOption{
oauth2.SetAuthURLParam("code_verifier", "challenge123"),
},
handleToken: basicIDTokenVerify,
},
{
// This test ensures that PKCE works in "S256" mode
name: "PKCE with S256",
authCodeOptions: []oauth2.AuthCodeOption{
oauth2.SetAuthURLParam("code_challenge", "lyyl-X4a69qrqgEfUL8wodWic3Be9ZZ5eovBgIKKi-w"),
oauth2.SetAuthURLParam("code_challenge_method", "S256"),
},
retrieveTokenOptions: []oauth2.AuthCodeOption{
oauth2.SetAuthURLParam("code_verifier", "challenge123"),
},
handleToken: basicIDTokenVerify,
},
{
// This test ensures that PKCE does fail with wrong code_verifier in "plain" mode
name: "PKCE with plain and wrong code_verifier",
authCodeOptions: []oauth2.AuthCodeOption{
oauth2.SetAuthURLParam("code_challenge", "challenge123"),
},
retrieveTokenOptions: []oauth2.AuthCodeOption{
oauth2.SetAuthURLParam("code_verifier", "challenge124"),
},
handleToken: basicIDTokenVerify,
tokenError: ErrorResponse{
Error: errInvalidGrant,
StatusCode: http.StatusBadRequest,
},
},
{
// This test ensures that PKCE fail with wrong code_verifier in "S256" mode
name: "PKCE with S256 and wrong code_verifier",
authCodeOptions: []oauth2.AuthCodeOption{
oauth2.SetAuthURLParam("code_challenge", "lyyl-X4a69qrqgEfUL8wodWic3Be9ZZ5eovBgIKKi-w"),
oauth2.SetAuthURLParam("code_challenge_method", "S256"),
},
retrieveTokenOptions: []oauth2.AuthCodeOption{
oauth2.SetAuthURLParam("code_verifier", "challenge124"),
},
handleToken: basicIDTokenVerify,
tokenError: ErrorResponse{
Error: errInvalidGrant,
StatusCode: http.StatusBadRequest,
},
},
{
// Ensure that, when PKCE flow started on /auth
// we stay in PKCE flow on /token
name: "PKCE flow expected on /token",
authCodeOptions: []oauth2.AuthCodeOption{
oauth2.SetAuthURLParam("code_challenge", "lyyl-X4a69qrqgEfUL8wodWic3Be9ZZ5eovBgIKKi-w"),
oauth2.SetAuthURLParam("code_challenge_method", "S256"),
},
retrieveTokenOptions: []oauth2.AuthCodeOption{
// No PKCE call on /token
},
handleToken: basicIDTokenVerify,
tokenError: ErrorResponse{
Error: errInvalidGrant,
StatusCode: http.StatusBadRequest,
},
},
{
// Ensure that when no PKCE flow was started on /auth
// we cannot switch to PKCE on /token
name: "No PKCE flow started on /auth",
authCodeOptions: []oauth2.AuthCodeOption{
// No PKCE call on /auth
},
retrieveTokenOptions: []oauth2.AuthCodeOption{
oauth2.SetAuthURLParam("code_verifier", "challenge123"),
},
handleToken: basicIDTokenVerify,
tokenError: ErrorResponse{
Error: errInvalidRequest,
StatusCode: http.StatusBadRequest,
},
},
{
// Make sure that, when we start with "S256" on /auth, we cannot downgrade to "plain" on /token
name: "PKCE with S256 and try to downgrade to plain",
authCodeOptions: []oauth2.AuthCodeOption{
oauth2.SetAuthURLParam("code_challenge", "lyyl-X4a69qrqgEfUL8wodWic3Be9ZZ5eovBgIKKi-w"),
oauth2.SetAuthURLParam("code_challenge_method", "S256"),
},
retrieveTokenOptions: []oauth2.AuthCodeOption{
oauth2.SetAuthURLParam("code_verifier", "lyyl-X4a69qrqgEfUL8wodWic3Be9ZZ5eovBgIKKi-w"),
oauth2.SetAuthURLParam("code_challenge_method", "plain"),
},
handleToken: basicIDTokenVerify,
tokenError: ErrorResponse{
Error: errInvalidGrant,
StatusCode: http.StatusBadRequest,
},
},
}, },
} }
} }
@ -537,7 +675,7 @@ func TestOAuth2CodeFlow(t *testing.T) {
oauth2Client := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { oauth2Client := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if r.URL.Path != "/callback" { if r.URL.Path != "/callback" {
// User is visiting app first time. Redirect to dex. // User is visiting app first time. Redirect to dex.
http.Redirect(w, r, oauth2Config.AuthCodeURL(state), http.StatusSeeOther) http.Redirect(w, r, oauth2Config.AuthCodeURL(state, tc.authCodeOptions...), http.StatusSeeOther)
return return
} }
@ -558,7 +696,11 @@ func TestOAuth2CodeFlow(t *testing.T) {
// Grab code, exchange for token. // Grab code, exchange for token.
if code := q.Get("code"); code != "" { if code := q.Get("code"); code != "" {
gotCode = true gotCode = true
token, err := oauth2Config.Exchange(ctx, code) token, err := oauth2Config.Exchange(ctx, code, tc.retrieveTokenOptions...)
if tc.tokenError.StatusCode != 0 {
checkErrorResponse(err, t, tc)
return
}
if err != nil { if err != nil {
t.Errorf("failed to exchange code for token: %v", err) t.Errorf("failed to exchange code for token: %v", err)
return return
@ -1170,6 +1312,30 @@ func TestKeyCacher(t *testing.T) {
} }
} }
func checkErrorResponse(err error, t *testing.T, tc test) {
if err == nil {
t.Errorf("%s: DANGEROUS! got a token when we should not get one!", tc.name)
return
}
if rErr, ok := err.(*oauth2.RetrieveError); ok {
if rErr.Response.StatusCode != tc.tokenError.StatusCode {
t.Errorf("%s: got wrong StatusCode from server %d. expected %d",
tc.name, rErr.Response.StatusCode, tc.tokenError.StatusCode)
}
details := new(OAuth2ErrorResponse)
if err := json.Unmarshal(rErr.Body, details); err != nil {
t.Errorf("%s: could not parse return json: %s", tc.name, err)
return
}
if tc.tokenError.Error != "" && details.Error != tc.tokenError.Error {
t.Errorf("%s: got wrong Error in response: %s (%s). expected %s",
tc.name, details.Error, details.ErrorDescription, tc.tokenError.Error)
}
} else {
t.Errorf("%s: unexpected error type: %s. expected *oauth2.RetrieveError", tc.name, reflect.TypeOf(err))
}
}
type oauth2Client struct { type oauth2Client struct {
config *oauth2.Config config *oauth2.Config
token *oauth2.Token token *oauth2.Token

View File

@ -81,6 +81,11 @@ func mustBeErrAlreadyExists(t *testing.T, kind string, err error) {
} }
func testAuthRequestCRUD(t *testing.T, s storage.Storage) { func testAuthRequestCRUD(t *testing.T, s storage.Storage) {
codeChallenge := storage.PKCE{
CodeChallenge: "code_challenge_test",
CodeChallengeMethod: "plain",
}
a1 := storage.AuthRequest{ a1 := storage.AuthRequest{
ID: storage.NewID(), ID: storage.NewID(),
ClientID: "client1", ClientID: "client1",
@ -101,6 +106,7 @@ func testAuthRequestCRUD(t *testing.T, s storage.Storage) {
EmailVerified: true, EmailVerified: true,
Groups: []string{"a", "b"}, Groups: []string{"a", "b"},
}, },
PKCE: codeChallenge,
} }
identity := storage.Claims{Email: "foobar"} identity := storage.Claims{Email: "foobar"}
@ -155,6 +161,10 @@ func testAuthRequestCRUD(t *testing.T, s storage.Storage) {
t.Fatalf("update failed, wanted identity=%#v got %#v", identity, got.Claims) t.Fatalf("update failed, wanted identity=%#v got %#v", identity, got.Claims)
} }
if !reflect.DeepEqual(got.PKCE, codeChallenge) {
t.Fatalf("storage does not support PKCE, wanted challenge=%#v got %#v", codeChallenge, got.PKCE)
}
if err := s.DeleteAuthRequest(a1.ID); err != nil { if err := s.DeleteAuthRequest(a1.ID); err != nil {
t.Fatalf("failed to delete auth request: %v", err) t.Fatalf("failed to delete auth request: %v", err)
} }

View File

@ -21,6 +21,9 @@ type AuthCode struct {
Claims Claims `json:"claims,omitempty"` Claims Claims `json:"claims,omitempty"`
Expiry time.Time `json:"expiry"` Expiry time.Time `json:"expiry"`
CodeChallenge string `json:"code_challenge,omitempty"`
CodeChallengeMethod string `json:"code_challenge_method,omitempty"`
} }
func fromStorageAuthCode(a storage.AuthCode) AuthCode { func fromStorageAuthCode(a storage.AuthCode) AuthCode {
@ -34,6 +37,8 @@ func fromStorageAuthCode(a storage.AuthCode) AuthCode {
Scopes: a.Scopes, Scopes: a.Scopes,
Claims: fromStorageClaims(a.Claims), Claims: fromStorageClaims(a.Claims),
Expiry: a.Expiry, Expiry: a.Expiry,
CodeChallenge: a.PKCE.CodeChallenge,
CodeChallengeMethod: a.PKCE.CodeChallengeMethod,
} }
} }
@ -58,6 +63,9 @@ type AuthRequest struct {
ConnectorID string `json:"connector_id"` ConnectorID string `json:"connector_id"`
ConnectorData []byte `json:"connector_data"` ConnectorData []byte `json:"connector_data"`
CodeChallenge string `json:"code_challenge,omitempty"`
CodeChallengeMethod string `json:"code_challenge_method,omitempty"`
} }
func fromStorageAuthRequest(a storage.AuthRequest) AuthRequest { func fromStorageAuthRequest(a storage.AuthRequest) AuthRequest {
@ -75,6 +83,8 @@ func fromStorageAuthRequest(a storage.AuthRequest) AuthRequest {
Claims: fromStorageClaims(a.Claims), Claims: fromStorageClaims(a.Claims),
ConnectorID: a.ConnectorID, ConnectorID: a.ConnectorID,
ConnectorData: a.ConnectorData, ConnectorData: a.ConnectorData,
CodeChallenge: a.PKCE.CodeChallenge,
CodeChallengeMethod: a.PKCE.CodeChallengeMethod,
} }
} }
@ -93,6 +103,10 @@ func toStorageAuthRequest(a AuthRequest) storage.AuthRequest {
ConnectorData: a.ConnectorData, ConnectorData: a.ConnectorData,
Expiry: a.Expiry, Expiry: a.Expiry,
Claims: toStorageClaims(a.Claims), Claims: toStorageClaims(a.Claims),
PKCE: storage.PKCE{
CodeChallenge: a.CodeChallenge,
CodeChallengeMethod: a.CodeChallengeMethod,
},
} }
} }

View File

@ -299,6 +299,9 @@ type AuthRequest struct {
ConnectorData []byte `json:"connectorData,omitempty"` ConnectorData []byte `json:"connectorData,omitempty"`
Expiry time.Time `json:"expiry"` Expiry time.Time `json:"expiry"`
CodeChallenge string `json:"code_challenge,omitempty"`
CodeChallengeMethod string `json:"code_challenge_method,omitempty"`
} }
// AuthRequestList is a list of AuthRequests. // AuthRequestList is a list of AuthRequests.
@ -323,6 +326,10 @@ func toStorageAuthRequest(req AuthRequest) storage.AuthRequest {
ConnectorData: req.ConnectorData, ConnectorData: req.ConnectorData,
Expiry: req.Expiry, Expiry: req.Expiry,
Claims: toStorageClaims(req.Claims), Claims: toStorageClaims(req.Claims),
PKCE: storage.PKCE{
CodeChallenge: req.CodeChallenge,
CodeChallengeMethod: req.CodeChallengeMethod,
},
} }
return a return a
} }
@ -349,6 +356,8 @@ func (cli *client) fromStorageAuthRequest(a storage.AuthRequest) AuthRequest {
ConnectorData: a.ConnectorData, ConnectorData: a.ConnectorData,
Expiry: a.Expiry, Expiry: a.Expiry,
Claims: fromStorageClaims(a.Claims), Claims: fromStorageClaims(a.Claims),
CodeChallenge: a.PKCE.CodeChallenge,
CodeChallengeMethod: a.PKCE.CodeChallengeMethod,
} }
return req return req
} }
@ -422,6 +431,9 @@ type AuthCode struct {
ConnectorData []byte `json:"connectorData,omitempty"` ConnectorData []byte `json:"connectorData,omitempty"`
Expiry time.Time `json:"expiry"` Expiry time.Time `json:"expiry"`
CodeChallenge string `json:"code_challenge,omitempty"`
CodeChallengeMethod string `json:"code_challenge_method,omitempty"`
} }
// AuthCodeList is a list of AuthCodes. // AuthCodeList is a list of AuthCodes.
@ -449,6 +461,8 @@ func (cli *client) fromStorageAuthCode(a storage.AuthCode) AuthCode {
Scopes: a.Scopes, Scopes: a.Scopes,
Claims: fromStorageClaims(a.Claims), Claims: fromStorageClaims(a.Claims),
Expiry: a.Expiry, Expiry: a.Expiry,
CodeChallenge: a.PKCE.CodeChallenge,
CodeChallengeMethod: a.PKCE.CodeChallengeMethod,
} }
} }
@ -463,6 +477,10 @@ func toStorageAuthCode(a AuthCode) storage.AuthCode {
Scopes: a.Scopes, Scopes: a.Scopes,
Claims: toStorageClaims(a.Claims), Claims: toStorageClaims(a.Claims),
Expiry: a.Expiry, Expiry: a.Expiry,
PKCE: storage.PKCE{
CodeChallenge: a.CodeChallenge,
CodeChallengeMethod: a.CodeChallengeMethod,
},
} }
} }

View File

@ -130,10 +130,11 @@ func (c *conn) CreateAuthRequest(a storage.AuthRequest) error {
claims_user_id, claims_username, claims_preferred_username, claims_user_id, claims_username, claims_preferred_username,
claims_email, claims_email_verified, claims_groups, claims_email, claims_email_verified, claims_groups,
connector_id, connector_data, connector_id, connector_data,
expiry expiry,
code_challenge, code_challenge_method
) )
values ( values (
$1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12, $13, $14, $15, $16, $17, $18 $1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12, $13, $14, $15, $16, $17, $18, $19, $20
); );
`, `,
a.ID, a.ClientID, encoder(a.ResponseTypes), encoder(a.Scopes), a.RedirectURI, a.Nonce, a.State, a.ID, a.ClientID, encoder(a.ResponseTypes), encoder(a.Scopes), a.RedirectURI, a.Nonce, a.State,
@ -142,6 +143,7 @@ func (c *conn) CreateAuthRequest(a storage.AuthRequest) error {
a.Claims.Email, a.Claims.EmailVerified, encoder(a.Claims.Groups), a.Claims.Email, a.Claims.EmailVerified, encoder(a.Claims.Groups),
a.ConnectorID, a.ConnectorData, a.ConnectorID, a.ConnectorData,
a.Expiry, a.Expiry,
a.PKCE.CodeChallenge, a.PKCE.CodeChallengeMethod,
) )
if err != nil { if err != nil {
if c.alreadyExistsCheck(err) { if c.alreadyExistsCheck(err) {
@ -172,8 +174,9 @@ func (c *conn) UpdateAuthRequest(id string, updater func(a storage.AuthRequest)
claims_email = $12, claims_email_verified = $13, claims_email = $12, claims_email_verified = $13,
claims_groups = $14, claims_groups = $14,
connector_id = $15, connector_data = $16, connector_id = $15, connector_data = $16,
expiry = $17 expiry = $17,
where id = $18; code_challenge = $18, code_challenge_method = $19
where id = $20;
`, `,
a.ClientID, encoder(a.ResponseTypes), encoder(a.Scopes), a.RedirectURI, a.Nonce, a.State, a.ClientID, encoder(a.ResponseTypes), encoder(a.Scopes), a.RedirectURI, a.Nonce, a.State,
a.ForceApprovalPrompt, a.LoggedIn, a.ForceApprovalPrompt, a.LoggedIn,
@ -181,7 +184,9 @@ func (c *conn) UpdateAuthRequest(id string, updater func(a storage.AuthRequest)
a.Claims.Email, a.Claims.EmailVerified, a.Claims.Email, a.Claims.EmailVerified,
encoder(a.Claims.Groups), encoder(a.Claims.Groups),
a.ConnectorID, a.ConnectorData, a.ConnectorID, a.ConnectorData,
a.Expiry, r.ID, a.Expiry,
a.PKCE.CodeChallenge, a.PKCE.CodeChallengeMethod,
r.ID,
) )
if err != nil { if err != nil {
return fmt.Errorf("update auth request: %v", err) return fmt.Errorf("update auth request: %v", err)
@ -201,7 +206,8 @@ func getAuthRequest(q querier, id string) (a storage.AuthRequest, err error) {
force_approval_prompt, logged_in, force_approval_prompt, logged_in,
claims_user_id, claims_username, claims_preferred_username, claims_user_id, claims_username, claims_preferred_username,
claims_email, claims_email_verified, claims_groups, claims_email, claims_email_verified, claims_groups,
connector_id, connector_data, expiry connector_id, connector_data, expiry,
code_challenge, code_challenge_method
from auth_request where id = $1; from auth_request where id = $1;
`, id).Scan( `, id).Scan(
&a.ID, &a.ClientID, decoder(&a.ResponseTypes), decoder(&a.Scopes), &a.RedirectURI, &a.Nonce, &a.State, &a.ID, &a.ClientID, decoder(&a.ResponseTypes), decoder(&a.Scopes), &a.RedirectURI, &a.Nonce, &a.State,
@ -210,6 +216,7 @@ func getAuthRequest(q querier, id string) (a storage.AuthRequest, err error) {
&a.Claims.Email, &a.Claims.EmailVerified, &a.Claims.Email, &a.Claims.EmailVerified,
decoder(&a.Claims.Groups), decoder(&a.Claims.Groups),
&a.ConnectorID, &a.ConnectorData, &a.Expiry, &a.ConnectorID, &a.ConnectorData, &a.Expiry,
&a.PKCE.CodeChallenge, &a.PKCE.CodeChallengeMethod,
) )
if err != nil { if err != nil {
if err == sql.ErrNoRows { if err == sql.ErrNoRows {
@ -227,13 +234,15 @@ func (c *conn) CreateAuthCode(a storage.AuthCode) error {
claims_user_id, claims_username, claims_preferred_username, claims_user_id, claims_username, claims_preferred_username,
claims_email, claims_email_verified, claims_groups, claims_email, claims_email_verified, claims_groups,
connector_id, connector_data, connector_id, connector_data,
expiry expiry,
code_challenge, code_challenge_method
) )
values ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12, $13, $14); values ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12, $13, $14, $15, $16);
`, `,
a.ID, a.ClientID, encoder(a.Scopes), a.Nonce, a.RedirectURI, a.Claims.UserID, a.ID, a.ClientID, encoder(a.Scopes), a.Nonce, a.RedirectURI, a.Claims.UserID,
a.Claims.Username, a.Claims.PreferredUsername, a.Claims.Email, a.Claims.EmailVerified, a.Claims.Username, a.Claims.PreferredUsername, a.Claims.Email, a.Claims.EmailVerified,
encoder(a.Claims.Groups), a.ConnectorID, a.ConnectorData, a.Expiry, encoder(a.Claims.Groups), a.ConnectorID, a.ConnectorData, a.Expiry,
a.PKCE.CodeChallenge, a.PKCE.CodeChallengeMethod,
) )
if err != nil { if err != nil {
@ -252,12 +261,14 @@ func (c *conn) GetAuthCode(id string) (a storage.AuthCode, err error) {
claims_user_id, claims_username, claims_preferred_username, claims_user_id, claims_username, claims_preferred_username,
claims_email, claims_email_verified, claims_groups, claims_email, claims_email_verified, claims_groups,
connector_id, connector_data, connector_id, connector_data,
expiry expiry,
code_challenge, code_challenge_method
from auth_code where id = $1; from auth_code where id = $1;
`, id).Scan( `, id).Scan(
&a.ID, &a.ClientID, decoder(&a.Scopes), &a.Nonce, &a.RedirectURI, &a.Claims.UserID, &a.ID, &a.ClientID, decoder(&a.Scopes), &a.Nonce, &a.RedirectURI, &a.Claims.UserID,
&a.Claims.Username, &a.Claims.PreferredUsername, &a.Claims.Email, &a.Claims.EmailVerified, &a.Claims.Username, &a.Claims.PreferredUsername, &a.Claims.Email, &a.Claims.EmailVerified,
decoder(&a.Claims.Groups), &a.ConnectorID, &a.ConnectorData, &a.Expiry, decoder(&a.Claims.Groups), &a.ConnectorID, &a.ConnectorData, &a.Expiry,
&a.PKCE.CodeChallenge, &a.PKCE.CodeChallengeMethod,
) )
if err != nil { if err != nil {
if err == sql.ErrNoRows { if err == sql.ErrNoRows {

View File

@ -250,4 +250,19 @@ var migrations = []migration{
);`, );`,
}, },
}, },
{
stmts: []string{`
alter table auth_request
add column code_challenge text not null default '';`,
`
alter table auth_request
add column code_challenge_method text not null default '';`,
`
alter table auth_code
add column code_challenge text not null default '';`,
`
alter table auth_code
add column code_challenge_method text not null default '';`,
},
},
} }

View File

@ -169,6 +169,12 @@ type Claims struct {
Groups []string Groups []string
} }
// Data needed for PKCE (RFC 7636)
type PKCE struct {
CodeChallenge string
CodeChallengeMethod string
}
// AuthRequest represents a OAuth2 client authorization request. It holds the state // AuthRequest represents a OAuth2 client authorization request. It holds the state
// of a single auth flow up to the point that the user authorizes the client. // of a single auth flow up to the point that the user authorizes the client.
type AuthRequest struct { type AuthRequest struct {
@ -206,6 +212,9 @@ type AuthRequest struct {
// Set when the user authenticates. // Set when the user authenticates.
ConnectorID string ConnectorID string
ConnectorData []byte ConnectorData []byte
// PKCE CodeChallenge and CodeChallengeMethod
PKCE PKCE
} }
// AuthCode represents a code which can be exchanged for an OAuth2 token response. // AuthCode represents a code which can be exchanged for an OAuth2 token response.
@ -241,6 +250,9 @@ type AuthCode struct {
Claims Claims Claims Claims
Expiry time.Time Expiry time.Time
// PKCE CodeChallenge and CodeChallengeMethod
PKCE PKCE
} }
// RefreshToken is an OAuth2 refresh token which allows a client to request new // RefreshToken is an OAuth2 refresh token which allows a client to request new