Use oidc.Verifier to verify tokens
This commit is contained in:
		@@ -2,7 +2,6 @@ package server
 | 
			
		||||
 | 
			
		||||
import (
 | 
			
		||||
	"context"
 | 
			
		||||
	"encoding/base64"
 | 
			
		||||
	"encoding/json"
 | 
			
		||||
	"errors"
 | 
			
		||||
	"fmt"
 | 
			
		||||
@@ -15,6 +14,7 @@ import (
 | 
			
		||||
	"sync"
 | 
			
		||||
	"time"
 | 
			
		||||
 | 
			
		||||
	oidc "github.com/coreos/go-oidc"
 | 
			
		||||
	"github.com/gorilla/mux"
 | 
			
		||||
	jose "gopkg.in/square/go-jose.v2"
 | 
			
		||||
 | 
			
		||||
@@ -23,10 +23,6 @@ import (
 | 
			
		||||
	"github.com/dexidp/dex/storage"
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
var (
 | 
			
		||||
	errTokenExpired = errors.New("token has expired")
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
// newHealthChecker returns the healthz handler. The handler runs until the
 | 
			
		||||
// provided context is canceled.
 | 
			
		||||
func (s *Server) newHealthChecker(ctx context.Context) http.Handler {
 | 
			
		||||
@@ -1055,84 +1051,31 @@ func (s *Server) handleRefreshToken(w http.ResponseWriter, r *http.Request, clie
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (s *Server) handleUserInfo(w http.ResponseWriter, r *http.Request) {
 | 
			
		||||
	authorization := r.Header.Get("Authorization")
 | 
			
		||||
	parts := strings.Fields(authorization)
 | 
			
		||||
	const prefix = "Bearer "
 | 
			
		||||
 | 
			
		||||
	if len(parts) != 2 || !strings.EqualFold(parts[0], "bearer") {
 | 
			
		||||
		msg := "invalid authorization header"
 | 
			
		||||
		w.Header().Set("WWW-Authenticate", fmt.Sprintf(`Bearer realm="dex", error="%s", error_description="%s"`, errInvalidRequest, msg))
 | 
			
		||||
		s.tokenErrHelper(w, errInvalidRequest, msg, http.StatusBadRequest)
 | 
			
		||||
	auth := r.Header.Get("authorization")
 | 
			
		||||
	if len(auth) < len(prefix) || !strings.EqualFold(prefix, auth[:len(prefix)]) {
 | 
			
		||||
		w.Header().Set("WWW-Authenticate", "Bearer")
 | 
			
		||||
		s.tokenErrHelper(w, errAccessDenied, "Invalid bearer token.", http.StatusUnauthorized)
 | 
			
		||||
		return
 | 
			
		||||
	}
 | 
			
		||||
	rawIDToken := auth[len(prefix):]
 | 
			
		||||
 | 
			
		||||
	verifier := oidc.NewVerifier(s.issuerURL.String(), &storageKeySet{s.storage}, &oidc.Config{SkipClientIDCheck: true})
 | 
			
		||||
	idToken, err := verifier.Verify(r.Context(), rawIDToken)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		s.tokenErrHelper(w, errAccessDenied, err.Error(), http.StatusForbidden)
 | 
			
		||||
		return
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	token := parts[1]
 | 
			
		||||
 | 
			
		||||
	verified, err := s.verify(token)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		if err == errTokenExpired {
 | 
			
		||||
			s.tokenErrHelper(w, errAccessDenied, err.Error(), http.StatusUnauthorized)
 | 
			
		||||
			return
 | 
			
		||||
		}
 | 
			
		||||
		s.tokenErrHelper(w, errInvalidRequest, err.Error(), http.StatusBadRequest)
 | 
			
		||||
	var claims json.RawMessage
 | 
			
		||||
	if err := idToken.Claims(&claims); err != nil {
 | 
			
		||||
		s.tokenErrHelper(w, errServerError, err.Error(), http.StatusInternalServerError)
 | 
			
		||||
		return
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	w.Header().Set("Content-Type", "application/json")
 | 
			
		||||
	w.Write(verified)
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (s *Server) verify(token string) ([]byte, error) {
 | 
			
		||||
	keys, err := s.storage.GetKeys()
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return nil, fmt.Errorf("failed to get keys: %v", err)
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	if keys.SigningKey == nil {
 | 
			
		||||
		return nil, fmt.Errorf("no private keys found")
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	object, err := jose.ParseSigned(token)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return nil, fmt.Errorf("unable to parse signed message")
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	parts := strings.Split(token, ".")
 | 
			
		||||
	if len(parts) != 3 {
 | 
			
		||||
		return nil, fmt.Errorf("compact JWS format must have three parts")
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	payload, err := base64.RawURLEncoding.DecodeString(parts[1])
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return nil, err
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	// TODO: check other claims
 | 
			
		||||
	var tokenInfo struct {
 | 
			
		||||
		Expiry int64 `json:"exp"`
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	if err := json.Unmarshal(payload, &tokenInfo); err != nil {
 | 
			
		||||
		return nil, err
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	if tokenInfo.Expiry < s.now().Unix() {
 | 
			
		||||
		return nil, errTokenExpired
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	var allKeys []*jose.JSONWebKey
 | 
			
		||||
 | 
			
		||||
	allKeys = append(allKeys, keys.SigningKeyPub)
 | 
			
		||||
	for _, key := range keys.VerificationKeys {
 | 
			
		||||
		allKeys = append(allKeys, key.PublicKey)
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	for _, pubKey := range allKeys {
 | 
			
		||||
		verified, err := object.Verify(pubKey)
 | 
			
		||||
		if err == nil {
 | 
			
		||||
			return verified, nil
 | 
			
		||||
		}
 | 
			
		||||
	}
 | 
			
		||||
	return nil, errors.New("unable to verify jwt")
 | 
			
		||||
	w.Write(claims)
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (s *Server) writeAccessToken(w http.ResponseWriter, idToken, accessToken, refreshToken string, expiry time.Time) {
 | 
			
		||||
 
 | 
			
		||||
@@ -1,6 +1,7 @@
 | 
			
		||||
package server
 | 
			
		||||
 | 
			
		||||
import (
 | 
			
		||||
	"context"
 | 
			
		||||
	"crypto/ecdsa"
 | 
			
		||||
	"crypto/elliptic"
 | 
			
		||||
	"crypto/rsa"
 | 
			
		||||
@@ -566,3 +567,41 @@ func validateRedirectURI(client storage.Client, redirectURI string) bool {
 | 
			
		||||
	host, _, err := net.SplitHostPort(u.Host)
 | 
			
		||||
	return err == nil && host == "localhost"
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// storageKeySet implements the oidc.KeySet interface backed by Dex storage
 | 
			
		||||
type storageKeySet struct {
 | 
			
		||||
	storage.Storage
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (s *storageKeySet) VerifySignature(ctx context.Context, jwt string) (payload []byte, err error) {
 | 
			
		||||
	jws, err := jose.ParseSigned(jwt)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return nil, err
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	keyID := ""
 | 
			
		||||
	for _, sig := range jws.Signatures {
 | 
			
		||||
		keyID = sig.Header.KeyID
 | 
			
		||||
		break
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	skeys, err := s.Storage.GetKeys()
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return nil, err
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	keys := []*jose.JSONWebKey{skeys.SigningKeyPub}
 | 
			
		||||
	for _, vk := range skeys.VerificationKeys {
 | 
			
		||||
		keys = append(keys, vk.PublicKey)
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	for _, key := range keys {
 | 
			
		||||
		if keyID == "" || key.KeyID == keyID {
 | 
			
		||||
			if payload, err := jws.Verify(key); err == nil {
 | 
			
		||||
				return payload, nil
 | 
			
		||||
			}
 | 
			
		||||
		}
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	return nil, errors.New("failed to verify id token signature")
 | 
			
		||||
}
 | 
			
		||||
 
 | 
			
		||||
@@ -2,6 +2,8 @@ package server
 | 
			
		||||
 | 
			
		||||
import (
 | 
			
		||||
	"context"
 | 
			
		||||
	"crypto/rand"
 | 
			
		||||
	"crypto/rsa"
 | 
			
		||||
	"net/http"
 | 
			
		||||
	"net/http/httptest"
 | 
			
		||||
	"net/url"
 | 
			
		||||
@@ -11,6 +13,7 @@ import (
 | 
			
		||||
	jose "gopkg.in/square/go-jose.v2"
 | 
			
		||||
 | 
			
		||||
	"github.com/dexidp/dex/storage"
 | 
			
		||||
	"github.com/dexidp/dex/storage/memory"
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
func TestParseAuthorizationRequest(t *testing.T) {
 | 
			
		||||
@@ -259,3 +262,87 @@ func TestValidRedirectURI(t *testing.T) {
 | 
			
		||||
		}
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func TestStorageKeySet(t *testing.T) {
 | 
			
		||||
	s := memory.New(logger)
 | 
			
		||||
	if err := s.UpdateKeys(func(keys storage.Keys) (storage.Keys, error) {
 | 
			
		||||
		keys.SigningKey = &jose.JSONWebKey{
 | 
			
		||||
			Key:       testKey,
 | 
			
		||||
			KeyID:     "testkey",
 | 
			
		||||
			Algorithm: "RS256",
 | 
			
		||||
			Use:       "sig",
 | 
			
		||||
		}
 | 
			
		||||
		keys.SigningKeyPub = &jose.JSONWebKey{
 | 
			
		||||
			Key:       testKey.Public(),
 | 
			
		||||
			KeyID:     "testkey",
 | 
			
		||||
			Algorithm: "RS256",
 | 
			
		||||
			Use:       "sig",
 | 
			
		||||
		}
 | 
			
		||||
		return keys, nil
 | 
			
		||||
	}); err != nil {
 | 
			
		||||
		t.Fatal(err)
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	tests := []struct {
 | 
			
		||||
		name           string
 | 
			
		||||
		tokenGenerator func() (jwt string, err error)
 | 
			
		||||
		wantErr        bool
 | 
			
		||||
	}{
 | 
			
		||||
		{
 | 
			
		||||
			name: "valid token",
 | 
			
		||||
			tokenGenerator: func() (string, error) {
 | 
			
		||||
				signer, err := jose.NewSigner(jose.SigningKey{Algorithm: jose.RS256, Key: testKey}, nil)
 | 
			
		||||
				if err != nil {
 | 
			
		||||
					return "", err
 | 
			
		||||
				}
 | 
			
		||||
 | 
			
		||||
				jws, err := signer.Sign([]byte("payload"))
 | 
			
		||||
				if err != nil {
 | 
			
		||||
					return "", err
 | 
			
		||||
				}
 | 
			
		||||
 | 
			
		||||
				return jws.CompactSerialize()
 | 
			
		||||
			},
 | 
			
		||||
			wantErr: false,
 | 
			
		||||
		},
 | 
			
		||||
		{
 | 
			
		||||
			name: "token signed by different key",
 | 
			
		||||
			tokenGenerator: func() (string, error) {
 | 
			
		||||
				key, err := rsa.GenerateKey(rand.Reader, 2048)
 | 
			
		||||
				if err != nil {
 | 
			
		||||
					return "", err
 | 
			
		||||
				}
 | 
			
		||||
 | 
			
		||||
				signer, err := jose.NewSigner(jose.SigningKey{Algorithm: jose.RS256, Key: key}, nil)
 | 
			
		||||
				if err != nil {
 | 
			
		||||
					return "", err
 | 
			
		||||
				}
 | 
			
		||||
 | 
			
		||||
				jws, err := signer.Sign([]byte("payload"))
 | 
			
		||||
				if err != nil {
 | 
			
		||||
					return "", err
 | 
			
		||||
				}
 | 
			
		||||
 | 
			
		||||
				return jws.CompactSerialize()
 | 
			
		||||
			},
 | 
			
		||||
			wantErr: true,
 | 
			
		||||
		},
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	for _, tc := range tests {
 | 
			
		||||
		tc := tc
 | 
			
		||||
		t.Run(tc.name, func(t *testing.T) {
 | 
			
		||||
			jwt, err := tc.tokenGenerator()
 | 
			
		||||
			if err != nil {
 | 
			
		||||
				t.Fatal(err)
 | 
			
		||||
			}
 | 
			
		||||
 | 
			
		||||
			keySet := &storageKeySet{s}
 | 
			
		||||
 | 
			
		||||
			_, err = keySet.VerifySignature(context.Background(), jwt)
 | 
			
		||||
			if (err != nil && !tc.wantErr) || (err == nil && tc.wantErr) {
 | 
			
		||||
				t.Fatalf("wantErr = %v, but got err = %v", tc.wantErr, err)
 | 
			
		||||
			}
 | 
			
		||||
		})
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
 
 | 
			
		||||
@@ -200,6 +200,16 @@ func TestOAuth2CodeFlow(t *testing.T) {
 | 
			
		||||
				return nil
 | 
			
		||||
			},
 | 
			
		||||
		},
 | 
			
		||||
		{
 | 
			
		||||
			name: "fetch userinfo",
 | 
			
		||||
			handleToken: func(ctx context.Context, p *oidc.Provider, config *oauth2.Config, token *oauth2.Token) error {
 | 
			
		||||
				_, err := p.UserInfo(ctx, config.TokenSource(ctx, token))
 | 
			
		||||
				if err != nil {
 | 
			
		||||
					return fmt.Errorf("failed to fetch userinfo: %v", err)
 | 
			
		||||
				}
 | 
			
		||||
				return nil
 | 
			
		||||
			},
 | 
			
		||||
		},
 | 
			
		||||
		{
 | 
			
		||||
			name: "verify id token and oauth2 token expiry",
 | 
			
		||||
			handleToken: func(ctx context.Context, p *oidc.Provider, config *oauth2.Config, token *oauth2.Token) error {
 | 
			
		||||
 
 | 
			
		||||
		Reference in New Issue
	
	Block a user