Add userinfo endpoint
Co-authored-by: Yuxing Li <360983+jackielii@users.noreply.github.com> Co-authored-by: Francisco Santiago <1737357+fjbsantiago@users.noreply.github.com>
This commit is contained in:
parent
49e59fb54f
commit
a8d059a237
@ -3,6 +3,7 @@ package server
|
|||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
|
"encoding/base64"
|
||||||
"errors"
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
"net/http"
|
"net/http"
|
||||||
@ -22,6 +23,10 @@ import (
|
|||||||
"github.com/dexidp/dex/storage"
|
"github.com/dexidp/dex/storage"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
var (
|
||||||
|
errTokenExpired = errors.New("token has expired")
|
||||||
|
)
|
||||||
|
|
||||||
// newHealthChecker returns the healthz handler. The handler runs until the
|
// newHealthChecker returns the healthz handler. The handler runs until the
|
||||||
// provided context is canceled.
|
// provided context is canceled.
|
||||||
func (s *Server) newHealthChecker(ctx context.Context) http.Handler {
|
func (s *Server) newHealthChecker(ctx context.Context) http.Handler {
|
||||||
@ -151,6 +156,7 @@ type discovery struct {
|
|||||||
Auth string `json:"authorization_endpoint"`
|
Auth string `json:"authorization_endpoint"`
|
||||||
Token string `json:"token_endpoint"`
|
Token string `json:"token_endpoint"`
|
||||||
Keys string `json:"jwks_uri"`
|
Keys string `json:"jwks_uri"`
|
||||||
|
UserInfo string `json:"userinfo_endpoint"`
|
||||||
ResponseTypes []string `json:"response_types_supported"`
|
ResponseTypes []string `json:"response_types_supported"`
|
||||||
Subjects []string `json:"subject_types_supported"`
|
Subjects []string `json:"subject_types_supported"`
|
||||||
IDTokenAlgs []string `json:"id_token_signing_alg_values_supported"`
|
IDTokenAlgs []string `json:"id_token_signing_alg_values_supported"`
|
||||||
@ -165,6 +171,7 @@ func (s *Server) discoveryHandler() (http.HandlerFunc, error) {
|
|||||||
Auth: s.absURL("/auth"),
|
Auth: s.absURL("/auth"),
|
||||||
Token: s.absURL("/token"),
|
Token: s.absURL("/token"),
|
||||||
Keys: s.absURL("/keys"),
|
Keys: s.absURL("/keys"),
|
||||||
|
Keys: s.absURL("/userinfo"),
|
||||||
Subjects: []string{"public"},
|
Subjects: []string{"public"},
|
||||||
IDTokenAlgs: []string{string(jose.RS256)},
|
IDTokenAlgs: []string{string(jose.RS256)},
|
||||||
Scopes: []string{"openid", "email", "groups", "profile", "offline_access"},
|
Scopes: []string{"openid", "email", "groups", "profile", "offline_access"},
|
||||||
@ -559,7 +566,12 @@ func (s *Server) sendCodeResponse(w http.ResponseWriter, r *http.Request, authRe
|
|||||||
idToken string
|
idToken string
|
||||||
idTokenExpiry time.Time
|
idTokenExpiry time.Time
|
||||||
|
|
||||||
accessToken = storage.NewID()
|
i accessToken, err := s.newAccessToken(client.ID, authCode.Claims, authCode.Scopes, authCode.Nonce, authCode.ConnectorID)
|
||||||
|
if err != nil {
|
||||||
|
s.logger.Errorf("failed to create new access token: %v", err)
|
||||||
|
s.tokenErrHelper(w, errServerError, "", http.StatusInternalServerError)
|
||||||
|
return
|
||||||
|
}
|
||||||
)
|
)
|
||||||
|
|
||||||
for _, responseType := range authReq.ResponseTypes {
|
for _, responseType := range authReq.ResponseTypes {
|
||||||
@ -965,7 +977,13 @@ func (s *Server) handleRefreshToken(w http.ResponseWriter, r *http.Request, clie
|
|||||||
Groups: ident.Groups,
|
Groups: ident.Groups,
|
||||||
}
|
}
|
||||||
|
|
||||||
accessToken := storage.NewID()
|
accessToken, err := s.newAccessToken(client.ID, claims, scopes, refresh.Nonce, refresh.ConnectorID)
|
||||||
|
if err != nil {
|
||||||
|
s.logger.Errorf("failed to create new access token: %v", err)
|
||||||
|
s.tokenErrHelper(w, errServerError, "", http.StatusInternalServerError)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
idToken, expiry, err := s.newIDToken(client.ID, claims, scopes, refresh.Nonce, accessToken, refresh.ConnectorID)
|
idToken, expiry, err := s.newIDToken(client.ID, claims, scopes, refresh.Nonce, accessToken, refresh.ConnectorID)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
s.logger.Errorf("failed to create ID token: %v", err)
|
s.logger.Errorf("failed to create ID token: %v", err)
|
||||||
@ -1026,6 +1044,88 @@ func (s *Server) handleRefreshToken(w http.ResponseWriter, r *http.Request, clie
|
|||||||
s.writeAccessToken(w, idToken, accessToken, rawNewToken, expiry)
|
s.writeAccessToken(w, idToken, accessToken, rawNewToken, expiry)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (s *Server) handleUserInfo(w http.ResponseWriter, r *http.Request) {
|
||||||
|
authorization := r.Header.Get("Authorization")
|
||||||
|
parts := strings.Fields(authorization)
|
||||||
|
|
||||||
|
if len(parts) != 2 || !strings.EqualFold(parts[0], "bearer") {
|
||||||
|
msg := "invalid authorization header"
|
||||||
|
w.Header().Set("WWW-Authenticate", fmt.Sprintf(`Bearer realm="dex", error="%s", error_description="%s"`, errInvalidRequest, msg))
|
||||||
|
s.tokenErrHelper(w, errInvalidRequest, msg, http.StatusBadRequest)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
token := parts[1]
|
||||||
|
|
||||||
|
verified, err := s.verify(token)
|
||||||
|
if err != nil {
|
||||||
|
if err == errTokenExpired {
|
||||||
|
s.tokenErrHelper(w, errAccessDenied, err.Error(), http.StatusUnauthorized)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
s.tokenErrHelper(w, errInvalidRequest, err.Error(), http.StatusBadRequest)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
w.Header().Set("Content-Type", "application/json")
|
||||||
|
w.Write(verified)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *Server) verify(token string) ([]byte, error) {
|
||||||
|
keys, err := s.storage.GetKeys()
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("failed to get keys: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if keys.SigningKey == nil {
|
||||||
|
return nil, fmt.Errorf("no private keys found")
|
||||||
|
}
|
||||||
|
|
||||||
|
object, err := jose.ParseSigned(token)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("unable to parse signed message")
|
||||||
|
}
|
||||||
|
|
||||||
|
// Parse the message to check expiry, as it jose doesn't distinguish expiry error from others
|
||||||
|
parts := strings.Split(token, ".")
|
||||||
|
if len(parts) != 3 {
|
||||||
|
return nil, fmt.Errorf("square/go-jose: compact JWS format must have three parts")
|
||||||
|
}
|
||||||
|
|
||||||
|
payload, err := base64.RawURLEncoding.DecodeString(parts[1])
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
// TODO: check other claims
|
||||||
|
var tokenInfo struct {
|
||||||
|
Expiry int64 `json:"exp"`
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := json.Unmarshal(payload, &tokenInfo); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
if tokenInfo.Expiry < s.now().Unix() {
|
||||||
|
return nil, errTokenExpired
|
||||||
|
}
|
||||||
|
|
||||||
|
var allKeys []*jose.JSONWebKey
|
||||||
|
|
||||||
|
allKeys = append(allKeys, keys.SigningKeyPub)
|
||||||
|
for _, key := range keys.VerificationKeys {
|
||||||
|
allKeys = append(allKeys, key.PublicKey)
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, pubKey := range allKeys {
|
||||||
|
verified, err := object.Verify(pubKey)
|
||||||
|
if err == nil {
|
||||||
|
return verified, nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return nil, errors.New("unable to verify jwt")
|
||||||
|
}
|
||||||
|
|
||||||
func (s *Server) writeAccessToken(w http.ResponseWriter, idToken, accessToken, refreshToken string, expiry time.Time) {
|
func (s *Server) writeAccessToken(w http.ResponseWriter, idToken, accessToken, refreshToken string, expiry time.Time) {
|
||||||
// TODO(ericchiang): figure out an access token story and support the user info
|
// TODO(ericchiang): figure out an access token story and support the user info
|
||||||
// endpoint. For now use a random value so no one depends on the access_token
|
// endpoint. For now use a random value so no one depends on the access_token
|
||||||
|
@ -265,6 +265,11 @@ type federatedIDClaims struct {
|
|||||||
UserID string `json:"user_id,omitempty"`
|
UserID string `json:"user_id,omitempty"`
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (s *Server) newAccessToken(clientID string, claims storage.Claims, scopes []string, nonce, connID string) (accessToken string, err error) {
|
||||||
|
idToken, _, err := s.newIDToken(clientID, claims, scopes, nonce, storage.NewID(), connID)
|
||||||
|
return idToken, err
|
||||||
|
}
|
||||||
|
|
||||||
func (s *Server) newIDToken(clientID string, claims storage.Claims, scopes []string, nonce, accessToken, connID string) (idToken string, expiry time.Time, err error) {
|
func (s *Server) newIDToken(clientID string, claims storage.Claims, scopes []string, nonce, accessToken, connID string) (idToken string, expiry time.Time, err error) {
|
||||||
keys, err := s.storage.GetKeys()
|
keys, err := s.storage.GetKeys()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
@ -270,6 +270,7 @@ func newServer(ctx context.Context, c Config, rotationStrategy rotationStrategy)
|
|||||||
// TODO(ericchiang): rate limit certain paths based on IP.
|
// TODO(ericchiang): rate limit certain paths based on IP.
|
||||||
handleWithCORS("/token", s.handleToken)
|
handleWithCORS("/token", s.handleToken)
|
||||||
handleWithCORS("/keys", s.handlePublicKeys)
|
handleWithCORS("/keys", s.handlePublicKeys)
|
||||||
|
handleWithCORS("/userinfo", s.handleUserInfo)
|
||||||
handleFunc("/auth", s.handleAuthorization)
|
handleFunc("/auth", s.handleAuthorization)
|
||||||
handleFunc("/auth/{connector}", s.handleConnectorLogin)
|
handleFunc("/auth/{connector}", s.handleConnectorLogin)
|
||||||
r.HandleFunc(path.Join(issuerURL.Path, "/callback"), func(w http.ResponseWriter, r *http.Request) {
|
r.HandleFunc(path.Join(issuerURL.Path, "/callback"), func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
Reference in New Issue
Block a user