Use oidc.Verifier to verify tokens

This commit is contained in:
Andy Lindeman 2019-06-20 13:15:59 -04:00
parent 157c359f3e
commit 46f5726d11
4 changed files with 154 additions and 75 deletions

View File

@ -2,7 +2,6 @@ package server
import (
"context"
"encoding/base64"
"encoding/json"
"errors"
"fmt"
@ -15,6 +14,7 @@ import (
"sync"
"time"
oidc "github.com/coreos/go-oidc"
"github.com/gorilla/mux"
jose "gopkg.in/square/go-jose.v2"
@ -23,10 +23,6 @@ import (
"github.com/dexidp/dex/storage"
)
var (
errTokenExpired = errors.New("token has expired")
)
// newHealthChecker returns the healthz handler. The handler runs until the
// provided context is canceled.
func (s *Server) newHealthChecker(ctx context.Context) http.Handler {
@ -1055,84 +1051,31 @@ func (s *Server) handleRefreshToken(w http.ResponseWriter, r *http.Request, clie
}
func (s *Server) handleUserInfo(w http.ResponseWriter, r *http.Request) {
authorization := r.Header.Get("Authorization")
parts := strings.Fields(authorization)
const prefix = "Bearer "
if len(parts) != 2 || !strings.EqualFold(parts[0], "bearer") {
msg := "invalid authorization header"
w.Header().Set("WWW-Authenticate", fmt.Sprintf(`Bearer realm="dex", error="%s", error_description="%s"`, errInvalidRequest, msg))
s.tokenErrHelper(w, errInvalidRequest, msg, http.StatusBadRequest)
auth := r.Header.Get("authorization")
if len(auth) < len(prefix) || !strings.EqualFold(prefix, auth[:len(prefix)]) {
w.Header().Set("WWW-Authenticate", "Bearer")
s.tokenErrHelper(w, errAccessDenied, "Invalid bearer token.", http.StatusUnauthorized)
return
}
rawIDToken := auth[len(prefix):]
verifier := oidc.NewVerifier(s.issuerURL.String(), &storageKeySet{s.storage}, &oidc.Config{SkipClientIDCheck: true})
idToken, err := verifier.Verify(r.Context(), rawIDToken)
if err != nil {
s.tokenErrHelper(w, errAccessDenied, err.Error(), http.StatusForbidden)
return
}
token := parts[1]
verified, err := s.verify(token)
if err != nil {
if err == errTokenExpired {
s.tokenErrHelper(w, errAccessDenied, err.Error(), http.StatusUnauthorized)
return
}
s.tokenErrHelper(w, errInvalidRequest, err.Error(), http.StatusBadRequest)
var claims json.RawMessage
if err := idToken.Claims(&claims); err != nil {
s.tokenErrHelper(w, errServerError, err.Error(), http.StatusInternalServerError)
return
}
w.Header().Set("Content-Type", "application/json")
w.Write(verified)
}
func (s *Server) verify(token string) ([]byte, error) {
keys, err := s.storage.GetKeys()
if err != nil {
return nil, fmt.Errorf("failed to get keys: %v", err)
}
if keys.SigningKey == nil {
return nil, fmt.Errorf("no private keys found")
}
object, err := jose.ParseSigned(token)
if err != nil {
return nil, fmt.Errorf("unable to parse signed message")
}
parts := strings.Split(token, ".")
if len(parts) != 3 {
return nil, fmt.Errorf("compact JWS format must have three parts")
}
payload, err := base64.RawURLEncoding.DecodeString(parts[1])
if err != nil {
return nil, err
}
// TODO: check other claims
var tokenInfo struct {
Expiry int64 `json:"exp"`
}
if err := json.Unmarshal(payload, &tokenInfo); err != nil {
return nil, err
}
if tokenInfo.Expiry < s.now().Unix() {
return nil, errTokenExpired
}
var allKeys []*jose.JSONWebKey
allKeys = append(allKeys, keys.SigningKeyPub)
for _, key := range keys.VerificationKeys {
allKeys = append(allKeys, key.PublicKey)
}
for _, pubKey := range allKeys {
verified, err := object.Verify(pubKey)
if err == nil {
return verified, nil
}
}
return nil, errors.New("unable to verify jwt")
w.Write(claims)
}
func (s *Server) writeAccessToken(w http.ResponseWriter, idToken, accessToken, refreshToken string, expiry time.Time) {

View File

@ -1,6 +1,7 @@
package server
import (
"context"
"crypto/ecdsa"
"crypto/elliptic"
"crypto/rsa"
@ -566,3 +567,41 @@ func validateRedirectURI(client storage.Client, redirectURI string) bool {
host, _, err := net.SplitHostPort(u.Host)
return err == nil && host == "localhost"
}
// storageKeySet implements the oidc.KeySet interface backed by Dex storage
type storageKeySet struct {
storage.Storage
}
func (s *storageKeySet) VerifySignature(ctx context.Context, jwt string) (payload []byte, err error) {
jws, err := jose.ParseSigned(jwt)
if err != nil {
return nil, err
}
keyID := ""
for _, sig := range jws.Signatures {
keyID = sig.Header.KeyID
break
}
skeys, err := s.Storage.GetKeys()
if err != nil {
return nil, err
}
keys := []*jose.JSONWebKey{skeys.SigningKeyPub}
for _, vk := range skeys.VerificationKeys {
keys = append(keys, vk.PublicKey)
}
for _, key := range keys {
if keyID == "" || key.KeyID == keyID {
if payload, err := jws.Verify(key); err == nil {
return payload, nil
}
}
}
return nil, errors.New("failed to verify id token signature")
}

View File

@ -2,6 +2,8 @@ package server
import (
"context"
"crypto/rand"
"crypto/rsa"
"net/http"
"net/http/httptest"
"net/url"
@ -11,6 +13,7 @@ import (
jose "gopkg.in/square/go-jose.v2"
"github.com/dexidp/dex/storage"
"github.com/dexidp/dex/storage/memory"
)
func TestParseAuthorizationRequest(t *testing.T) {
@ -259,3 +262,87 @@ func TestValidRedirectURI(t *testing.T) {
}
}
}
func TestStorageKeySet(t *testing.T) {
s := memory.New(logger)
if err := s.UpdateKeys(func(keys storage.Keys) (storage.Keys, error) {
keys.SigningKey = &jose.JSONWebKey{
Key: testKey,
KeyID: "testkey",
Algorithm: "RS256",
Use: "sig",
}
keys.SigningKeyPub = &jose.JSONWebKey{
Key: testKey.Public(),
KeyID: "testkey",
Algorithm: "RS256",
Use: "sig",
}
return keys, nil
}); err != nil {
t.Fatal(err)
}
tests := []struct {
name string
tokenGenerator func() (jwt string, err error)
wantErr bool
}{
{
name: "valid token",
tokenGenerator: func() (string, error) {
signer, err := jose.NewSigner(jose.SigningKey{Algorithm: jose.RS256, Key: testKey}, nil)
if err != nil {
return "", err
}
jws, err := signer.Sign([]byte("payload"))
if err != nil {
return "", err
}
return jws.CompactSerialize()
},
wantErr: false,
},
{
name: "token signed by different key",
tokenGenerator: func() (string, error) {
key, err := rsa.GenerateKey(rand.Reader, 2048)
if err != nil {
return "", err
}
signer, err := jose.NewSigner(jose.SigningKey{Algorithm: jose.RS256, Key: key}, nil)
if err != nil {
return "", err
}
jws, err := signer.Sign([]byte("payload"))
if err != nil {
return "", err
}
return jws.CompactSerialize()
},
wantErr: true,
},
}
for _, tc := range tests {
tc := tc
t.Run(tc.name, func(t *testing.T) {
jwt, err := tc.tokenGenerator()
if err != nil {
t.Fatal(err)
}
keySet := &storageKeySet{s}
_, err = keySet.VerifySignature(context.Background(), jwt)
if (err != nil && !tc.wantErr) || (err == nil && tc.wantErr) {
t.Fatalf("wantErr = %v, but got err = %v", tc.wantErr, err)
}
})
}
}

View File

@ -200,6 +200,16 @@ func TestOAuth2CodeFlow(t *testing.T) {
return nil
},
},
{
name: "fetch userinfo",
handleToken: func(ctx context.Context, p *oidc.Provider, config *oauth2.Config, token *oauth2.Token) error {
_, err := p.UserInfo(ctx, config.TokenSource(ctx, token))
if err != nil {
return fmt.Errorf("failed to fetch userinfo: %v", err)
}
return nil
},
},
{
name: "verify id token and oauth2 token expiry",
handleToken: func(ctx context.Context, p *oidc.Provider, config *oauth2.Config, token *oauth2.Token) error {