Add userinfo endpoint
Co-authored-by: Yuxing Li <360983+jackielii@users.noreply.github.com> Co-authored-by: Francisco Santiago <1737357+fjbsantiago@users.noreply.github.com>
This commit is contained in:
		
				
					committed by
					
						
						mdbraber
					
				
			
			
				
	
			
			
			
						parent
						
							49e59fb54f
						
					
				
				
					commit
					a8d059a237
				
			@@ -3,6 +3,7 @@ package server
 | 
			
		||||
import (
 | 
			
		||||
	"context"
 | 
			
		||||
	"encoding/json"
 | 
			
		||||
	"encoding/base64"
 | 
			
		||||
	"errors"
 | 
			
		||||
	"fmt"
 | 
			
		||||
	"net/http"
 | 
			
		||||
@@ -22,6 +23,10 @@ import (
 | 
			
		||||
	"github.com/dexidp/dex/storage"
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
var (
 | 
			
		||||
	errTokenExpired = errors.New("token has expired")
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
// newHealthChecker returns the healthz handler. The handler runs until the
 | 
			
		||||
// provided context is canceled.
 | 
			
		||||
func (s *Server) newHealthChecker(ctx context.Context) http.Handler {
 | 
			
		||||
@@ -151,6 +156,7 @@ type discovery struct {
 | 
			
		||||
	Auth          string   `json:"authorization_endpoint"`
 | 
			
		||||
	Token         string   `json:"token_endpoint"`
 | 
			
		||||
	Keys          string   `json:"jwks_uri"`
 | 
			
		||||
	UserInfo      string   `json:"userinfo_endpoint"`
 | 
			
		||||
	ResponseTypes []string `json:"response_types_supported"`
 | 
			
		||||
	Subjects      []string `json:"subject_types_supported"`
 | 
			
		||||
	IDTokenAlgs   []string `json:"id_token_signing_alg_values_supported"`
 | 
			
		||||
@@ -165,6 +171,7 @@ func (s *Server) discoveryHandler() (http.HandlerFunc, error) {
 | 
			
		||||
		Auth:        s.absURL("/auth"),
 | 
			
		||||
		Token:       s.absURL("/token"),
 | 
			
		||||
		Keys:        s.absURL("/keys"),
 | 
			
		||||
		Keys:        s.absURL("/userinfo"),
 | 
			
		||||
		Subjects:    []string{"public"},
 | 
			
		||||
		IDTokenAlgs: []string{string(jose.RS256)},
 | 
			
		||||
		Scopes:      []string{"openid", "email", "groups", "profile", "offline_access"},
 | 
			
		||||
@@ -559,7 +566,12 @@ func (s *Server) sendCodeResponse(w http.ResponseWriter, r *http.Request, authRe
 | 
			
		||||
		idToken       string
 | 
			
		||||
		idTokenExpiry time.Time
 | 
			
		||||
 | 
			
		||||
		accessToken = storage.NewID()
 | 
			
		||||
i		accessToken, err := s.newAccessToken(client.ID, authCode.Claims, authCode.Scopes, authCode.Nonce, authCode.ConnectorID)
 | 
			
		||||
		if err != nil {
 | 
			
		||||
			s.logger.Errorf("failed to create new access token: %v", err)
 | 
			
		||||
			s.tokenErrHelper(w, errServerError, "", http.StatusInternalServerError)
 | 
			
		||||
			return
 | 
			
		||||
		}
 | 
			
		||||
	)
 | 
			
		||||
 | 
			
		||||
	for _, responseType := range authReq.ResponseTypes {
 | 
			
		||||
@@ -965,7 +977,13 @@ func (s *Server) handleRefreshToken(w http.ResponseWriter, r *http.Request, clie
 | 
			
		||||
		Groups:        ident.Groups,
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	accessToken := storage.NewID()
 | 
			
		||||
	accessToken, err := s.newAccessToken(client.ID, claims, scopes, refresh.Nonce, refresh.ConnectorID)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		s.logger.Errorf("failed to create new access token: %v", err)
 | 
			
		||||
		s.tokenErrHelper(w, errServerError, "", http.StatusInternalServerError)
 | 
			
		||||
		return
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	idToken, expiry, err := s.newIDToken(client.ID, claims, scopes, refresh.Nonce, accessToken, refresh.ConnectorID)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		s.logger.Errorf("failed to create ID token: %v", err)
 | 
			
		||||
@@ -1026,6 +1044,88 @@ func (s *Server) handleRefreshToken(w http.ResponseWriter, r *http.Request, clie
 | 
			
		||||
	s.writeAccessToken(w, idToken, accessToken, rawNewToken, expiry)
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (s *Server) handleUserInfo(w http.ResponseWriter, r *http.Request) {
 | 
			
		||||
	authorization := r.Header.Get("Authorization")
 | 
			
		||||
	parts := strings.Fields(authorization)
 | 
			
		||||
 | 
			
		||||
	if len(parts) != 2 || !strings.EqualFold(parts[0], "bearer") {
 | 
			
		||||
		msg := "invalid authorization header"
 | 
			
		||||
		w.Header().Set("WWW-Authenticate", fmt.Sprintf(`Bearer realm="dex", error="%s", error_description="%s"`, errInvalidRequest, msg))
 | 
			
		||||
		s.tokenErrHelper(w, errInvalidRequest, msg, http.StatusBadRequest)
 | 
			
		||||
		return
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	token := parts[1]
 | 
			
		||||
 | 
			
		||||
	verified, err := s.verify(token)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		if err == errTokenExpired {
 | 
			
		||||
			s.tokenErrHelper(w, errAccessDenied, err.Error(), http.StatusUnauthorized)
 | 
			
		||||
			return
 | 
			
		||||
		}
 | 
			
		||||
		s.tokenErrHelper(w, errInvalidRequest, err.Error(), http.StatusBadRequest)
 | 
			
		||||
		return
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	w.Header().Set("Content-Type", "application/json")
 | 
			
		||||
	w.Write(verified)
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
func (s *Server) verify(token string) ([]byte, error) {
 | 
			
		||||
	keys, err := s.storage.GetKeys()
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return nil, fmt.Errorf("failed to get keys: %v", err)
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	if keys.SigningKey == nil {
 | 
			
		||||
		return nil, fmt.Errorf("no private keys found")
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	object, err := jose.ParseSigned(token)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return nil, fmt.Errorf("unable to parse signed message")
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	// Parse the message to check expiry, as it jose doesn't distinguish expiry error from others
 | 
			
		||||
	parts := strings.Split(token, ".")
 | 
			
		||||
	if len(parts) != 3 {
 | 
			
		||||
		return nil, fmt.Errorf("square/go-jose: compact JWS format must have three parts")
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	payload, err := base64.RawURLEncoding.DecodeString(parts[1])
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return nil, err
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	// TODO: check other claims
 | 
			
		||||
	var tokenInfo struct {
 | 
			
		||||
		Expiry int64 `json:"exp"`
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	if err := json.Unmarshal(payload, &tokenInfo); err != nil {
 | 
			
		||||
		return nil, err
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	if tokenInfo.Expiry < s.now().Unix() {
 | 
			
		||||
		return nil, errTokenExpired
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	var allKeys []*jose.JSONWebKey
 | 
			
		||||
 | 
			
		||||
	allKeys = append(allKeys, keys.SigningKeyPub)
 | 
			
		||||
	for _, key := range keys.VerificationKeys {
 | 
			
		||||
		allKeys = append(allKeys, key.PublicKey)
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	for _, pubKey := range allKeys {
 | 
			
		||||
		verified, err := object.Verify(pubKey)
 | 
			
		||||
		if err == nil {
 | 
			
		||||
			return verified, nil
 | 
			
		||||
		}
 | 
			
		||||
	}
 | 
			
		||||
	return nil, errors.New("unable to verify jwt")
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (s *Server) writeAccessToken(w http.ResponseWriter, idToken, accessToken, refreshToken string, expiry time.Time) {
 | 
			
		||||
	// TODO(ericchiang): figure out an access token story and support the user info
 | 
			
		||||
	// endpoint. For now use a random value so no one depends on the access_token
 | 
			
		||||
 
 | 
			
		||||
@@ -265,6 +265,11 @@ type federatedIDClaims struct {
 | 
			
		||||
	UserID      string `json:"user_id,omitempty"`
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (s *Server) newAccessToken(clientID string, claims storage.Claims, scopes []string, nonce, connID string) (accessToken string, err error) {
 | 
			
		||||
	idToken, _, err := s.newIDToken(clientID, claims, scopes, nonce, storage.NewID(), connID)
 | 
			
		||||
	return idToken, err
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (s *Server) newIDToken(clientID string, claims storage.Claims, scopes []string, nonce, accessToken, connID string) (idToken string, expiry time.Time, err error) {
 | 
			
		||||
	keys, err := s.storage.GetKeys()
 | 
			
		||||
	if err != nil {
 | 
			
		||||
 
 | 
			
		||||
@@ -270,6 +270,7 @@ func newServer(ctx context.Context, c Config, rotationStrategy rotationStrategy)
 | 
			
		||||
	// TODO(ericchiang): rate limit certain paths based on IP.
 | 
			
		||||
	handleWithCORS("/token", s.handleToken)
 | 
			
		||||
	handleWithCORS("/keys", s.handlePublicKeys)
 | 
			
		||||
	handleWithCORS("/userinfo", s.handleUserInfo)
 | 
			
		||||
	handleFunc("/auth", s.handleAuthorization)
 | 
			
		||||
	handleFunc("/auth/{connector}", s.handleConnectorLogin)
 | 
			
		||||
	r.HandleFunc(path.Join(issuerURL.Path, "/callback"), func(w http.ResponseWriter, r *http.Request) {
 | 
			
		||||
 
 | 
			
		||||
		Reference in New Issue
	
	Block a user