diff --git a/server/handlers.go b/server/handlers.go index 573542c0..a931d23a 100644 --- a/server/handlers.go +++ b/server/handlers.go @@ -2,7 +2,6 @@ package server import ( "context" - "encoding/base64" "encoding/json" "errors" "fmt" @@ -15,6 +14,7 @@ import ( "sync" "time" + oidc "github.com/coreos/go-oidc" "github.com/gorilla/mux" jose "gopkg.in/square/go-jose.v2" @@ -23,10 +23,6 @@ import ( "github.com/dexidp/dex/storage" ) -var ( - errTokenExpired = errors.New("token has expired") -) - // newHealthChecker returns the healthz handler. The handler runs until the // provided context is canceled. func (s *Server) newHealthChecker(ctx context.Context) http.Handler { @@ -1055,84 +1051,31 @@ func (s *Server) handleRefreshToken(w http.ResponseWriter, r *http.Request, clie } func (s *Server) handleUserInfo(w http.ResponseWriter, r *http.Request) { - authorization := r.Header.Get("Authorization") - parts := strings.Fields(authorization) + const prefix = "Bearer " - if len(parts) != 2 || !strings.EqualFold(parts[0], "bearer") { - msg := "invalid authorization header" - w.Header().Set("WWW-Authenticate", fmt.Sprintf(`Bearer realm="dex", error="%s", error_description="%s"`, errInvalidRequest, msg)) - s.tokenErrHelper(w, errInvalidRequest, msg, http.StatusBadRequest) + auth := r.Header.Get("authorization") + if len(auth) < len(prefix) || !strings.EqualFold(prefix, auth[:len(prefix)]) { + w.Header().Set("WWW-Authenticate", "Bearer") + s.tokenErrHelper(w, errAccessDenied, "Invalid bearer token.", http.StatusUnauthorized) + return + } + rawIDToken := auth[len(prefix):] + + verifier := oidc.NewVerifier(s.issuerURL.String(), &storageKeySet{s.storage}, &oidc.Config{SkipClientIDCheck: true}) + idToken, err := verifier.Verify(r.Context(), rawIDToken) + if err != nil { + s.tokenErrHelper(w, errAccessDenied, err.Error(), http.StatusForbidden) return } - token := parts[1] - - verified, err := s.verify(token) - if err != nil { - if err == errTokenExpired { - s.tokenErrHelper(w, errAccessDenied, err.Error(), http.StatusUnauthorized) - return - } - s.tokenErrHelper(w, errInvalidRequest, err.Error(), http.StatusBadRequest) + var claims json.RawMessage + if err := idToken.Claims(&claims); err != nil { + s.tokenErrHelper(w, errServerError, err.Error(), http.StatusInternalServerError) return } w.Header().Set("Content-Type", "application/json") - w.Write(verified) -} - -func (s *Server) verify(token string) ([]byte, error) { - keys, err := s.storage.GetKeys() - if err != nil { - return nil, fmt.Errorf("failed to get keys: %v", err) - } - - if keys.SigningKey == nil { - return nil, fmt.Errorf("no private keys found") - } - - object, err := jose.ParseSigned(token) - if err != nil { - return nil, fmt.Errorf("unable to parse signed message") - } - - parts := strings.Split(token, ".") - if len(parts) != 3 { - return nil, fmt.Errorf("compact JWS format must have three parts") - } - - payload, err := base64.RawURLEncoding.DecodeString(parts[1]) - if err != nil { - return nil, err - } - - // TODO: check other claims - var tokenInfo struct { - Expiry int64 `json:"exp"` - } - - if err := json.Unmarshal(payload, &tokenInfo); err != nil { - return nil, err - } - - if tokenInfo.Expiry < s.now().Unix() { - return nil, errTokenExpired - } - - var allKeys []*jose.JSONWebKey - - allKeys = append(allKeys, keys.SigningKeyPub) - for _, key := range keys.VerificationKeys { - allKeys = append(allKeys, key.PublicKey) - } - - for _, pubKey := range allKeys { - verified, err := object.Verify(pubKey) - if err == nil { - return verified, nil - } - } - return nil, errors.New("unable to verify jwt") + w.Write(claims) } func (s *Server) writeAccessToken(w http.ResponseWriter, idToken, accessToken, refreshToken string, expiry time.Time) { diff --git a/server/oauth2.go b/server/oauth2.go index 26d152f4..9effc174 100644 --- a/server/oauth2.go +++ b/server/oauth2.go @@ -1,6 +1,7 @@ package server import ( + "context" "crypto/ecdsa" "crypto/elliptic" "crypto/rsa" @@ -566,3 +567,41 @@ func validateRedirectURI(client storage.Client, redirectURI string) bool { host, _, err := net.SplitHostPort(u.Host) return err == nil && host == "localhost" } + +// storageKeySet implements the oidc.KeySet interface backed by Dex storage +type storageKeySet struct { + storage.Storage +} + +func (s *storageKeySet) VerifySignature(ctx context.Context, jwt string) (payload []byte, err error) { + jws, err := jose.ParseSigned(jwt) + if err != nil { + return nil, err + } + + keyID := "" + for _, sig := range jws.Signatures { + keyID = sig.Header.KeyID + break + } + + skeys, err := s.Storage.GetKeys() + if err != nil { + return nil, err + } + + keys := []*jose.JSONWebKey{skeys.SigningKeyPub} + for _, vk := range skeys.VerificationKeys { + keys = append(keys, vk.PublicKey) + } + + for _, key := range keys { + if keyID == "" || key.KeyID == keyID { + if payload, err := jws.Verify(key); err == nil { + return payload, nil + } + } + } + + return nil, errors.New("failed to verify id token signature") +} diff --git a/server/oauth2_test.go b/server/oauth2_test.go index 8cad77a8..bb8d2723 100644 --- a/server/oauth2_test.go +++ b/server/oauth2_test.go @@ -2,6 +2,8 @@ package server import ( "context" + "crypto/rand" + "crypto/rsa" "net/http" "net/http/httptest" "net/url" @@ -11,6 +13,7 @@ import ( jose "gopkg.in/square/go-jose.v2" "github.com/dexidp/dex/storage" + "github.com/dexidp/dex/storage/memory" ) func TestParseAuthorizationRequest(t *testing.T) { @@ -259,3 +262,87 @@ func TestValidRedirectURI(t *testing.T) { } } } + +func TestStorageKeySet(t *testing.T) { + s := memory.New(logger) + if err := s.UpdateKeys(func(keys storage.Keys) (storage.Keys, error) { + keys.SigningKey = &jose.JSONWebKey{ + Key: testKey, + KeyID: "testkey", + Algorithm: "RS256", + Use: "sig", + } + keys.SigningKeyPub = &jose.JSONWebKey{ + Key: testKey.Public(), + KeyID: "testkey", + Algorithm: "RS256", + Use: "sig", + } + return keys, nil + }); err != nil { + t.Fatal(err) + } + + tests := []struct { + name string + tokenGenerator func() (jwt string, err error) + wantErr bool + }{ + { + name: "valid token", + tokenGenerator: func() (string, error) { + signer, err := jose.NewSigner(jose.SigningKey{Algorithm: jose.RS256, Key: testKey}, nil) + if err != nil { + return "", err + } + + jws, err := signer.Sign([]byte("payload")) + if err != nil { + return "", err + } + + return jws.CompactSerialize() + }, + wantErr: false, + }, + { + name: "token signed by different key", + tokenGenerator: func() (string, error) { + key, err := rsa.GenerateKey(rand.Reader, 2048) + if err != nil { + return "", err + } + + signer, err := jose.NewSigner(jose.SigningKey{Algorithm: jose.RS256, Key: key}, nil) + if err != nil { + return "", err + } + + jws, err := signer.Sign([]byte("payload")) + if err != nil { + return "", err + } + + return jws.CompactSerialize() + }, + wantErr: true, + }, + } + + for _, tc := range tests { + tc := tc + t.Run(tc.name, func(t *testing.T) { + jwt, err := tc.tokenGenerator() + if err != nil { + t.Fatal(err) + } + + keySet := &storageKeySet{s} + + _, err = keySet.VerifySignature(context.Background(), jwt) + if (err != nil && !tc.wantErr) || (err == nil && tc.wantErr) { + t.Fatalf("wantErr = %v, but got err = %v", tc.wantErr, err) + } + }) + } +} diff --git a/server/server_test.go b/server/server_test.go index 87c80fbb..12f29340 100644 --- a/server/server_test.go +++ b/server/server_test.go @@ -200,6 +200,16 @@ func TestOAuth2CodeFlow(t *testing.T) { return nil }, }, + { + name: "fetch userinfo", + handleToken: func(ctx context.Context, p *oidc.Provider, config *oauth2.Config, token *oauth2.Token) error { + _, err := p.UserInfo(ctx, config.TokenSource(ctx, token)) + if err != nil { + return fmt.Errorf("failed to fetch userinfo: %v", err) + } + return nil + }, + }, { name: "verify id token and oauth2 token expiry", handleToken: func(ctx context.Context, p *oidc.Provider, config *oauth2.Config, token *oauth2.Token) error {