Use oidc.Verifier to verify tokens
This commit is contained in:
		@@ -2,7 +2,6 @@ package server
 | 
				
			|||||||
 | 
					
 | 
				
			||||||
import (
 | 
					import (
 | 
				
			||||||
	"context"
 | 
						"context"
 | 
				
			||||||
	"encoding/base64"
 | 
					 | 
				
			||||||
	"encoding/json"
 | 
						"encoding/json"
 | 
				
			||||||
	"errors"
 | 
						"errors"
 | 
				
			||||||
	"fmt"
 | 
						"fmt"
 | 
				
			||||||
@@ -15,6 +14,7 @@ import (
 | 
				
			|||||||
	"sync"
 | 
						"sync"
 | 
				
			||||||
	"time"
 | 
						"time"
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						oidc "github.com/coreos/go-oidc"
 | 
				
			||||||
	"github.com/gorilla/mux"
 | 
						"github.com/gorilla/mux"
 | 
				
			||||||
	jose "gopkg.in/square/go-jose.v2"
 | 
						jose "gopkg.in/square/go-jose.v2"
 | 
				
			||||||
 | 
					
 | 
				
			||||||
@@ -23,10 +23,6 @@ import (
 | 
				
			|||||||
	"github.com/dexidp/dex/storage"
 | 
						"github.com/dexidp/dex/storage"
 | 
				
			||||||
)
 | 
					)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
var (
 | 
					 | 
				
			||||||
	errTokenExpired = errors.New("token has expired")
 | 
					 | 
				
			||||||
)
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
// newHealthChecker returns the healthz handler. The handler runs until the
 | 
					// newHealthChecker returns the healthz handler. The handler runs until the
 | 
				
			||||||
// provided context is canceled.
 | 
					// provided context is canceled.
 | 
				
			||||||
func (s *Server) newHealthChecker(ctx context.Context) http.Handler {
 | 
					func (s *Server) newHealthChecker(ctx context.Context) http.Handler {
 | 
				
			||||||
@@ -1055,84 +1051,31 @@ func (s *Server) handleRefreshToken(w http.ResponseWriter, r *http.Request, clie
 | 
				
			|||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
func (s *Server) handleUserInfo(w http.ResponseWriter, r *http.Request) {
 | 
					func (s *Server) handleUserInfo(w http.ResponseWriter, r *http.Request) {
 | 
				
			||||||
	authorization := r.Header.Get("Authorization")
 | 
						const prefix = "Bearer "
 | 
				
			||||||
	parts := strings.Fields(authorization)
 | 
					 | 
				
			||||||
 | 
					
 | 
				
			||||||
	if len(parts) != 2 || !strings.EqualFold(parts[0], "bearer") {
 | 
						auth := r.Header.Get("authorization")
 | 
				
			||||||
		msg := "invalid authorization header"
 | 
						if len(auth) < len(prefix) || !strings.EqualFold(prefix, auth[:len(prefix)]) {
 | 
				
			||||||
		w.Header().Set("WWW-Authenticate", fmt.Sprintf(`Bearer realm="dex", error="%s", error_description="%s"`, errInvalidRequest, msg))
 | 
							w.Header().Set("WWW-Authenticate", "Bearer")
 | 
				
			||||||
		s.tokenErrHelper(w, errInvalidRequest, msg, http.StatusBadRequest)
 | 
							s.tokenErrHelper(w, errAccessDenied, "Invalid bearer token.", http.StatusUnauthorized)
 | 
				
			||||||
		return
 | 
							return
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
 | 
						rawIDToken := auth[len(prefix):]
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	token := parts[1]
 | 
						verifier := oidc.NewVerifier(s.issuerURL.String(), &storageKeySet{s.storage}, &oidc.Config{SkipClientIDCheck: true})
 | 
				
			||||||
 | 
						idToken, err := verifier.Verify(r.Context(), rawIDToken)
 | 
				
			||||||
	verified, err := s.verify(token)
 | 
					 | 
				
			||||||
	if err != nil {
 | 
						if err != nil {
 | 
				
			||||||
		if err == errTokenExpired {
 | 
							s.tokenErrHelper(w, errAccessDenied, err.Error(), http.StatusForbidden)
 | 
				
			||||||
			s.tokenErrHelper(w, errAccessDenied, err.Error(), http.StatusUnauthorized)
 | 
					 | 
				
			||||||
		return
 | 
							return
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
		s.tokenErrHelper(w, errInvalidRequest, err.Error(), http.StatusBadRequest)
 | 
					
 | 
				
			||||||
 | 
						var claims json.RawMessage
 | 
				
			||||||
 | 
						if err := idToken.Claims(&claims); err != nil {
 | 
				
			||||||
 | 
							s.tokenErrHelper(w, errServerError, err.Error(), http.StatusInternalServerError)
 | 
				
			||||||
		return
 | 
							return
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	w.Header().Set("Content-Type", "application/json")
 | 
						w.Header().Set("Content-Type", "application/json")
 | 
				
			||||||
	w.Write(verified)
 | 
						w.Write(claims)
 | 
				
			||||||
}
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
func (s *Server) verify(token string) ([]byte, error) {
 | 
					 | 
				
			||||||
	keys, err := s.storage.GetKeys()
 | 
					 | 
				
			||||||
	if err != nil {
 | 
					 | 
				
			||||||
		return nil, fmt.Errorf("failed to get keys: %v", err)
 | 
					 | 
				
			||||||
	}
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
	if keys.SigningKey == nil {
 | 
					 | 
				
			||||||
		return nil, fmt.Errorf("no private keys found")
 | 
					 | 
				
			||||||
	}
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
	object, err := jose.ParseSigned(token)
 | 
					 | 
				
			||||||
	if err != nil {
 | 
					 | 
				
			||||||
		return nil, fmt.Errorf("unable to parse signed message")
 | 
					 | 
				
			||||||
	}
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
	parts := strings.Split(token, ".")
 | 
					 | 
				
			||||||
	if len(parts) != 3 {
 | 
					 | 
				
			||||||
		return nil, fmt.Errorf("compact JWS format must have three parts")
 | 
					 | 
				
			||||||
	}
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
	payload, err := base64.RawURLEncoding.DecodeString(parts[1])
 | 
					 | 
				
			||||||
	if err != nil {
 | 
					 | 
				
			||||||
		return nil, err
 | 
					 | 
				
			||||||
	}
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
	// TODO: check other claims
 | 
					 | 
				
			||||||
	var tokenInfo struct {
 | 
					 | 
				
			||||||
		Expiry int64 `json:"exp"`
 | 
					 | 
				
			||||||
	}
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
	if err := json.Unmarshal(payload, &tokenInfo); err != nil {
 | 
					 | 
				
			||||||
		return nil, err
 | 
					 | 
				
			||||||
	}
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
	if tokenInfo.Expiry < s.now().Unix() {
 | 
					 | 
				
			||||||
		return nil, errTokenExpired
 | 
					 | 
				
			||||||
	}
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
	var allKeys []*jose.JSONWebKey
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
	allKeys = append(allKeys, keys.SigningKeyPub)
 | 
					 | 
				
			||||||
	for _, key := range keys.VerificationKeys {
 | 
					 | 
				
			||||||
		allKeys = append(allKeys, key.PublicKey)
 | 
					 | 
				
			||||||
	}
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
	for _, pubKey := range allKeys {
 | 
					 | 
				
			||||||
		verified, err := object.Verify(pubKey)
 | 
					 | 
				
			||||||
		if err == nil {
 | 
					 | 
				
			||||||
			return verified, nil
 | 
					 | 
				
			||||||
		}
 | 
					 | 
				
			||||||
	}
 | 
					 | 
				
			||||||
	return nil, errors.New("unable to verify jwt")
 | 
					 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
func (s *Server) writeAccessToken(w http.ResponseWriter, idToken, accessToken, refreshToken string, expiry time.Time) {
 | 
					func (s *Server) writeAccessToken(w http.ResponseWriter, idToken, accessToken, refreshToken string, expiry time.Time) {
 | 
				
			||||||
 
 | 
				
			|||||||
@@ -1,6 +1,7 @@
 | 
				
			|||||||
package server
 | 
					package server
 | 
				
			||||||
 | 
					
 | 
				
			||||||
import (
 | 
					import (
 | 
				
			||||||
 | 
						"context"
 | 
				
			||||||
	"crypto/ecdsa"
 | 
						"crypto/ecdsa"
 | 
				
			||||||
	"crypto/elliptic"
 | 
						"crypto/elliptic"
 | 
				
			||||||
	"crypto/rsa"
 | 
						"crypto/rsa"
 | 
				
			||||||
@@ -566,3 +567,41 @@ func validateRedirectURI(client storage.Client, redirectURI string) bool {
 | 
				
			|||||||
	host, _, err := net.SplitHostPort(u.Host)
 | 
						host, _, err := net.SplitHostPort(u.Host)
 | 
				
			||||||
	return err == nil && host == "localhost"
 | 
						return err == nil && host == "localhost"
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					// storageKeySet implements the oidc.KeySet interface backed by Dex storage
 | 
				
			||||||
 | 
					type storageKeySet struct {
 | 
				
			||||||
 | 
						storage.Storage
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					func (s *storageKeySet) VerifySignature(ctx context.Context, jwt string) (payload []byte, err error) {
 | 
				
			||||||
 | 
						jws, err := jose.ParseSigned(jwt)
 | 
				
			||||||
 | 
						if err != nil {
 | 
				
			||||||
 | 
							return nil, err
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						keyID := ""
 | 
				
			||||||
 | 
						for _, sig := range jws.Signatures {
 | 
				
			||||||
 | 
							keyID = sig.Header.KeyID
 | 
				
			||||||
 | 
							break
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						skeys, err := s.Storage.GetKeys()
 | 
				
			||||||
 | 
						if err != nil {
 | 
				
			||||||
 | 
							return nil, err
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						keys := []*jose.JSONWebKey{skeys.SigningKeyPub}
 | 
				
			||||||
 | 
						for _, vk := range skeys.VerificationKeys {
 | 
				
			||||||
 | 
							keys = append(keys, vk.PublicKey)
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						for _, key := range keys {
 | 
				
			||||||
 | 
							if keyID == "" || key.KeyID == keyID {
 | 
				
			||||||
 | 
								if payload, err := jws.Verify(key); err == nil {
 | 
				
			||||||
 | 
									return payload, nil
 | 
				
			||||||
 | 
								}
 | 
				
			||||||
 | 
							}
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						return nil, errors.New("failed to verify id token signature")
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 
 | 
				
			|||||||
@@ -2,6 +2,8 @@ package server
 | 
				
			|||||||
 | 
					
 | 
				
			||||||
import (
 | 
					import (
 | 
				
			||||||
	"context"
 | 
						"context"
 | 
				
			||||||
 | 
						"crypto/rand"
 | 
				
			||||||
 | 
						"crypto/rsa"
 | 
				
			||||||
	"net/http"
 | 
						"net/http"
 | 
				
			||||||
	"net/http/httptest"
 | 
						"net/http/httptest"
 | 
				
			||||||
	"net/url"
 | 
						"net/url"
 | 
				
			||||||
@@ -11,6 +13,7 @@ import (
 | 
				
			|||||||
	jose "gopkg.in/square/go-jose.v2"
 | 
						jose "gopkg.in/square/go-jose.v2"
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	"github.com/dexidp/dex/storage"
 | 
						"github.com/dexidp/dex/storage"
 | 
				
			||||||
 | 
						"github.com/dexidp/dex/storage/memory"
 | 
				
			||||||
)
 | 
					)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
func TestParseAuthorizationRequest(t *testing.T) {
 | 
					func TestParseAuthorizationRequest(t *testing.T) {
 | 
				
			||||||
@@ -259,3 +262,87 @@ func TestValidRedirectURI(t *testing.T) {
 | 
				
			|||||||
		}
 | 
							}
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					func TestStorageKeySet(t *testing.T) {
 | 
				
			||||||
 | 
						s := memory.New(logger)
 | 
				
			||||||
 | 
						if err := s.UpdateKeys(func(keys storage.Keys) (storage.Keys, error) {
 | 
				
			||||||
 | 
							keys.SigningKey = &jose.JSONWebKey{
 | 
				
			||||||
 | 
								Key:       testKey,
 | 
				
			||||||
 | 
								KeyID:     "testkey",
 | 
				
			||||||
 | 
								Algorithm: "RS256",
 | 
				
			||||||
 | 
								Use:       "sig",
 | 
				
			||||||
 | 
							}
 | 
				
			||||||
 | 
							keys.SigningKeyPub = &jose.JSONWebKey{
 | 
				
			||||||
 | 
								Key:       testKey.Public(),
 | 
				
			||||||
 | 
								KeyID:     "testkey",
 | 
				
			||||||
 | 
								Algorithm: "RS256",
 | 
				
			||||||
 | 
								Use:       "sig",
 | 
				
			||||||
 | 
							}
 | 
				
			||||||
 | 
							return keys, nil
 | 
				
			||||||
 | 
						}); err != nil {
 | 
				
			||||||
 | 
							t.Fatal(err)
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						tests := []struct {
 | 
				
			||||||
 | 
							name           string
 | 
				
			||||||
 | 
							tokenGenerator func() (jwt string, err error)
 | 
				
			||||||
 | 
							wantErr        bool
 | 
				
			||||||
 | 
						}{
 | 
				
			||||||
 | 
							{
 | 
				
			||||||
 | 
								name: "valid token",
 | 
				
			||||||
 | 
								tokenGenerator: func() (string, error) {
 | 
				
			||||||
 | 
									signer, err := jose.NewSigner(jose.SigningKey{Algorithm: jose.RS256, Key: testKey}, nil)
 | 
				
			||||||
 | 
									if err != nil {
 | 
				
			||||||
 | 
										return "", err
 | 
				
			||||||
 | 
									}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
									jws, err := signer.Sign([]byte("payload"))
 | 
				
			||||||
 | 
									if err != nil {
 | 
				
			||||||
 | 
										return "", err
 | 
				
			||||||
 | 
									}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
									return jws.CompactSerialize()
 | 
				
			||||||
 | 
								},
 | 
				
			||||||
 | 
								wantErr: false,
 | 
				
			||||||
 | 
							},
 | 
				
			||||||
 | 
							{
 | 
				
			||||||
 | 
								name: "token signed by different key",
 | 
				
			||||||
 | 
								tokenGenerator: func() (string, error) {
 | 
				
			||||||
 | 
									key, err := rsa.GenerateKey(rand.Reader, 2048)
 | 
				
			||||||
 | 
									if err != nil {
 | 
				
			||||||
 | 
										return "", err
 | 
				
			||||||
 | 
									}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
									signer, err := jose.NewSigner(jose.SigningKey{Algorithm: jose.RS256, Key: key}, nil)
 | 
				
			||||||
 | 
									if err != nil {
 | 
				
			||||||
 | 
										return "", err
 | 
				
			||||||
 | 
									}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
									jws, err := signer.Sign([]byte("payload"))
 | 
				
			||||||
 | 
									if err != nil {
 | 
				
			||||||
 | 
										return "", err
 | 
				
			||||||
 | 
									}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
									return jws.CompactSerialize()
 | 
				
			||||||
 | 
								},
 | 
				
			||||||
 | 
								wantErr: true,
 | 
				
			||||||
 | 
							},
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						for _, tc := range tests {
 | 
				
			||||||
 | 
							tc := tc
 | 
				
			||||||
 | 
							t.Run(tc.name, func(t *testing.T) {
 | 
				
			||||||
 | 
								jwt, err := tc.tokenGenerator()
 | 
				
			||||||
 | 
								if err != nil {
 | 
				
			||||||
 | 
									t.Fatal(err)
 | 
				
			||||||
 | 
								}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
								keySet := &storageKeySet{s}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
								_, err = keySet.VerifySignature(context.Background(), jwt)
 | 
				
			||||||
 | 
								if (err != nil && !tc.wantErr) || (err == nil && tc.wantErr) {
 | 
				
			||||||
 | 
									t.Fatalf("wantErr = %v, but got err = %v", tc.wantErr, err)
 | 
				
			||||||
 | 
								}
 | 
				
			||||||
 | 
							})
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 
 | 
				
			|||||||
@@ -200,6 +200,16 @@ func TestOAuth2CodeFlow(t *testing.T) {
 | 
				
			|||||||
				return nil
 | 
									return nil
 | 
				
			||||||
			},
 | 
								},
 | 
				
			||||||
		},
 | 
							},
 | 
				
			||||||
 | 
							{
 | 
				
			||||||
 | 
								name: "fetch userinfo",
 | 
				
			||||||
 | 
								handleToken: func(ctx context.Context, p *oidc.Provider, config *oauth2.Config, token *oauth2.Token) error {
 | 
				
			||||||
 | 
									_, err := p.UserInfo(ctx, config.TokenSource(ctx, token))
 | 
				
			||||||
 | 
									if err != nil {
 | 
				
			||||||
 | 
										return fmt.Errorf("failed to fetch userinfo: %v", err)
 | 
				
			||||||
 | 
									}
 | 
				
			||||||
 | 
									return nil
 | 
				
			||||||
 | 
								},
 | 
				
			||||||
 | 
							},
 | 
				
			||||||
		{
 | 
							{
 | 
				
			||||||
			name: "verify id token and oauth2 token expiry",
 | 
								name: "verify id token and oauth2 token expiry",
 | 
				
			||||||
			handleToken: func(ctx context.Context, p *oidc.Provider, config *oauth2.Config, token *oauth2.Token) error {
 | 
								handleToken: func(ctx context.Context, p *oidc.Provider, config *oauth2.Config, token *oauth2.Token) error {
 | 
				
			||||||
 
 | 
				
			|||||||
		Reference in New Issue
	
	Block a user