Add userinfo endpoint

Co-authored-by: Yuxing Li <360983+jackielii@users.noreply.github.com>
Co-authored-by: Francisco Santiago <1737357+fjbsantiago@users.noreply.github.com>
This commit is contained in:
Maarten den Braber 2019-05-27 09:17:39 +02:00 committed by mdbraber
parent 49e59fb54f
commit a8d059a237
3 changed files with 108 additions and 2 deletions

View File

@ -3,6 +3,7 @@ package server
import ( import (
"context" "context"
"encoding/json" "encoding/json"
"encoding/base64"
"errors" "errors"
"fmt" "fmt"
"net/http" "net/http"
@ -22,6 +23,10 @@ import (
"github.com/dexidp/dex/storage" "github.com/dexidp/dex/storage"
) )
var (
errTokenExpired = errors.New("token has expired")
)
// newHealthChecker returns the healthz handler. The handler runs until the // newHealthChecker returns the healthz handler. The handler runs until the
// provided context is canceled. // provided context is canceled.
func (s *Server) newHealthChecker(ctx context.Context) http.Handler { func (s *Server) newHealthChecker(ctx context.Context) http.Handler {
@ -151,6 +156,7 @@ type discovery struct {
Auth string `json:"authorization_endpoint"` Auth string `json:"authorization_endpoint"`
Token string `json:"token_endpoint"` Token string `json:"token_endpoint"`
Keys string `json:"jwks_uri"` Keys string `json:"jwks_uri"`
UserInfo string `json:"userinfo_endpoint"`
ResponseTypes []string `json:"response_types_supported"` ResponseTypes []string `json:"response_types_supported"`
Subjects []string `json:"subject_types_supported"` Subjects []string `json:"subject_types_supported"`
IDTokenAlgs []string `json:"id_token_signing_alg_values_supported"` IDTokenAlgs []string `json:"id_token_signing_alg_values_supported"`
@ -165,6 +171,7 @@ func (s *Server) discoveryHandler() (http.HandlerFunc, error) {
Auth: s.absURL("/auth"), Auth: s.absURL("/auth"),
Token: s.absURL("/token"), Token: s.absURL("/token"),
Keys: s.absURL("/keys"), Keys: s.absURL("/keys"),
Keys: s.absURL("/userinfo"),
Subjects: []string{"public"}, Subjects: []string{"public"},
IDTokenAlgs: []string{string(jose.RS256)}, IDTokenAlgs: []string{string(jose.RS256)},
Scopes: []string{"openid", "email", "groups", "profile", "offline_access"}, Scopes: []string{"openid", "email", "groups", "profile", "offline_access"},
@ -559,7 +566,12 @@ func (s *Server) sendCodeResponse(w http.ResponseWriter, r *http.Request, authRe
idToken string idToken string
idTokenExpiry time.Time idTokenExpiry time.Time
accessToken = storage.NewID() i accessToken, err := s.newAccessToken(client.ID, authCode.Claims, authCode.Scopes, authCode.Nonce, authCode.ConnectorID)
if err != nil {
s.logger.Errorf("failed to create new access token: %v", err)
s.tokenErrHelper(w, errServerError, "", http.StatusInternalServerError)
return
}
) )
for _, responseType := range authReq.ResponseTypes { for _, responseType := range authReq.ResponseTypes {
@ -965,7 +977,13 @@ func (s *Server) handleRefreshToken(w http.ResponseWriter, r *http.Request, clie
Groups: ident.Groups, Groups: ident.Groups,
} }
accessToken := storage.NewID() accessToken, err := s.newAccessToken(client.ID, claims, scopes, refresh.Nonce, refresh.ConnectorID)
if err != nil {
s.logger.Errorf("failed to create new access token: %v", err)
s.tokenErrHelper(w, errServerError, "", http.StatusInternalServerError)
return
}
idToken, expiry, err := s.newIDToken(client.ID, claims, scopes, refresh.Nonce, accessToken, refresh.ConnectorID) idToken, expiry, err := s.newIDToken(client.ID, claims, scopes, refresh.Nonce, accessToken, refresh.ConnectorID)
if err != nil { if err != nil {
s.logger.Errorf("failed to create ID token: %v", err) s.logger.Errorf("failed to create ID token: %v", err)
@ -1026,6 +1044,88 @@ func (s *Server) handleRefreshToken(w http.ResponseWriter, r *http.Request, clie
s.writeAccessToken(w, idToken, accessToken, rawNewToken, expiry) s.writeAccessToken(w, idToken, accessToken, rawNewToken, expiry)
} }
func (s *Server) handleUserInfo(w http.ResponseWriter, r *http.Request) {
authorization := r.Header.Get("Authorization")
parts := strings.Fields(authorization)
if len(parts) != 2 || !strings.EqualFold(parts[0], "bearer") {
msg := "invalid authorization header"
w.Header().Set("WWW-Authenticate", fmt.Sprintf(`Bearer realm="dex", error="%s", error_description="%s"`, errInvalidRequest, msg))
s.tokenErrHelper(w, errInvalidRequest, msg, http.StatusBadRequest)
return
}
token := parts[1]
verified, err := s.verify(token)
if err != nil {
if err == errTokenExpired {
s.tokenErrHelper(w, errAccessDenied, err.Error(), http.StatusUnauthorized)
return
}
s.tokenErrHelper(w, errInvalidRequest, err.Error(), http.StatusBadRequest)
return
}
w.Header().Set("Content-Type", "application/json")
w.Write(verified)
}
func (s *Server) verify(token string) ([]byte, error) {
keys, err := s.storage.GetKeys()
if err != nil {
return nil, fmt.Errorf("failed to get keys: %v", err)
}
if keys.SigningKey == nil {
return nil, fmt.Errorf("no private keys found")
}
object, err := jose.ParseSigned(token)
if err != nil {
return nil, fmt.Errorf("unable to parse signed message")
}
// Parse the message to check expiry, as it jose doesn't distinguish expiry error from others
parts := strings.Split(token, ".")
if len(parts) != 3 {
return nil, fmt.Errorf("square/go-jose: compact JWS format must have three parts")
}
payload, err := base64.RawURLEncoding.DecodeString(parts[1])
if err != nil {
return nil, err
}
// TODO: check other claims
var tokenInfo struct {
Expiry int64 `json:"exp"`
}
if err := json.Unmarshal(payload, &tokenInfo); err != nil {
return nil, err
}
if tokenInfo.Expiry < s.now().Unix() {
return nil, errTokenExpired
}
var allKeys []*jose.JSONWebKey
allKeys = append(allKeys, keys.SigningKeyPub)
for _, key := range keys.VerificationKeys {
allKeys = append(allKeys, key.PublicKey)
}
for _, pubKey := range allKeys {
verified, err := object.Verify(pubKey)
if err == nil {
return verified, nil
}
}
return nil, errors.New("unable to verify jwt")
}
func (s *Server) writeAccessToken(w http.ResponseWriter, idToken, accessToken, refreshToken string, expiry time.Time) { func (s *Server) writeAccessToken(w http.ResponseWriter, idToken, accessToken, refreshToken string, expiry time.Time) {
// TODO(ericchiang): figure out an access token story and support the user info // TODO(ericchiang): figure out an access token story and support the user info
// endpoint. For now use a random value so no one depends on the access_token // endpoint. For now use a random value so no one depends on the access_token

View File

@ -265,6 +265,11 @@ type federatedIDClaims struct {
UserID string `json:"user_id,omitempty"` UserID string `json:"user_id,omitempty"`
} }
func (s *Server) newAccessToken(clientID string, claims storage.Claims, scopes []string, nonce, connID string) (accessToken string, err error) {
idToken, _, err := s.newIDToken(clientID, claims, scopes, nonce, storage.NewID(), connID)
return idToken, err
}
func (s *Server) newIDToken(clientID string, claims storage.Claims, scopes []string, nonce, accessToken, connID string) (idToken string, expiry time.Time, err error) { func (s *Server) newIDToken(clientID string, claims storage.Claims, scopes []string, nonce, accessToken, connID string) (idToken string, expiry time.Time, err error) {
keys, err := s.storage.GetKeys() keys, err := s.storage.GetKeys()
if err != nil { if err != nil {

View File

@ -270,6 +270,7 @@ func newServer(ctx context.Context, c Config, rotationStrategy rotationStrategy)
// TODO(ericchiang): rate limit certain paths based on IP. // TODO(ericchiang): rate limit certain paths based on IP.
handleWithCORS("/token", s.handleToken) handleWithCORS("/token", s.handleToken)
handleWithCORS("/keys", s.handlePublicKeys) handleWithCORS("/keys", s.handlePublicKeys)
handleWithCORS("/userinfo", s.handleUserInfo)
handleFunc("/auth", s.handleAuthorization) handleFunc("/auth", s.handleAuthorization)
handleFunc("/auth/{connector}", s.handleConnectorLogin) handleFunc("/auth/{connector}", s.handleConnectorLogin)
r.HandleFunc(path.Join(issuerURL.Path, "/callback"), func(w http.ResponseWriter, r *http.Request) { r.HandleFunc(path.Join(issuerURL.Path, "/callback"), func(w http.ResponseWriter, r *http.Request) {