diff --git a/Documentation/connectors/oidc.md b/Documentation/connectors/oidc.md index 0f110332..b7ce5666 100644 --- a/Documentation/connectors/oidc.md +++ b/Documentation/connectors/oidc.md @@ -66,6 +66,12 @@ connectors: # all the claims requested. # https://openid.net/specs/openid-connect-core-1_0.html#UserInfo # getUserInfo: true + + # The set claim is used as user id. + # Default: sub + # Claims list at https://openid.net/specs/openid-connect-core-1_0.html#Claims + # + # userIdKey: nickname ``` [oidc-doc]: openid-connect.md diff --git a/connector/oidc/oidc.go b/connector/oidc/oidc.go index 24d8103e..628bd3e0 100644 --- a/connector/oidc/oidc.go +++ b/connector/oidc/oidc.go @@ -44,6 +44,9 @@ type Config struct { // the token. This is especially useful where upstreams return "thin" // id tokens GetUserInfo bool `json:"getUserInfo"` + + // Configurable key which contains the user id claim + UserIDKey string `json:"userIDKey"` } // Domains that don't support basic auth. golang.org/x/oauth2 has an internal @@ -127,6 +130,7 @@ func (c *Config) Open(id string, logger log.Logger) (conn connector.Connector, e hostedDomains: c.HostedDomains, insecureSkipEmailVerified: c.InsecureSkipEmailVerified, getUserInfo: c.GetUserInfo, + userIDKey: c.UserIDKey, }, nil } @@ -146,6 +150,7 @@ type oidcConnector struct { hostedDomains []string insecureSkipEmailVerified bool getUserInfo bool + userIDKey string } func (c *oidcConnector) Close() error { @@ -199,33 +204,41 @@ func (c *oidcConnector) HandleCallback(s connector.Scopes, r *http.Request) (ide return identity, fmt.Errorf("oidc: failed to verify ID Token: %v", err) } - var claims struct { - Username string `json:"name"` - Email string `json:"email"` - EmailVerified bool `json:"email_verified"` - HostedDomain string `json:"hd"` - } + var claims map[string]interface{} if err := idToken.Claims(&claims); err != nil { return identity, fmt.Errorf("oidc: failed to decode claims: %v", err) } + name, found := claims["name"].(string) + if !found { + return identity, errors.New("missing \"name\" claim") + } + email, found := claims["email"].(string) + if !found { + return identity, errors.New("missing \"email\" claim") + } + emailVerified, found := claims["email_verified"].(bool) + if !found { + return identity, errors.New("missing \"email_verified\" claim") + } + hostedDomain, _ := claims["hd"].(string) + if len(c.hostedDomains) > 0 { found := false for _, domain := range c.hostedDomains { - if claims.HostedDomain == domain { + if hostedDomain == domain { found = true break } } if !found { - return identity, fmt.Errorf("oidc: unexpected hd claim %v", claims.HostedDomain) + return identity, fmt.Errorf("oidc: unexpected hd claim %v", hostedDomain) } } if c.insecureSkipEmailVerified { - claims.EmailVerified = true - + emailVerified = true } if c.getUserInfo { @@ -240,10 +253,19 @@ func (c *oidcConnector) HandleCallback(s connector.Scopes, r *http.Request) (ide identity = connector.Identity{ UserID: idToken.Subject, - Username: claims.Username, - Email: claims.Email, - EmailVerified: claims.EmailVerified, + Username: name, + Email: email, + EmailVerified: emailVerified, } + + if c.userIDKey != "" { + userID, found := claims[c.userIDKey].(string) + if !found { + return identity, fmt.Errorf("oidc: not found %v claim", c.userIDKey) + } + identity.UserID = userID + } + return identity, nil } diff --git a/connector/oidc/oidc_test.go b/connector/oidc/oidc_test.go index 76a2b5e2..e9d0889a 100644 --- a/connector/oidc/oidc_test.go +++ b/connector/oidc/oidc_test.go @@ -1,7 +1,24 @@ package oidc import ( + "bytes" + "crypto/rand" + "crypto/rsa" + "encoding/base64" + "encoding/binary" + "encoding/json" + "errors" + "fmt" + "net/http" + "net/http/httptest" + "reflect" + "strings" "testing" + "time" + + "github.com/dexidp/dex/connector" + "github.com/sirupsen/logrus" + "gopkg.in/square/go-jose.v2" ) func TestKnownBrokenAuthHeaderProvider(t *testing.T) { @@ -23,3 +40,192 @@ func TestKnownBrokenAuthHeaderProvider(t *testing.T) { } } } + +func TestHandleCallback(t *testing.T) { + t.Helper() + + tests := []struct { + name string + userIDKey string + expectUserID string + }{ + {"simpleCase", "", "sub"}, + {"withUserIDKey", "name", "name"}, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + testServer, err := setupServer() + if err != nil { + t.Fatal("failed to setup test server", err) + } + defer testServer.Close() + serverURL := testServer.URL + config := Config{ + Issuer: serverURL, + ClientID: "clientID", + ClientSecret: "clientSecret", + Scopes: []string{"groups"}, + RedirectURI: fmt.Sprintf("%s/callback", serverURL), + UserIDKey: tc.userIDKey, + } + + conn, err := newConnector(config) + if err != nil { + t.Fatal("failed to create new connector", err) + } + + req, err := newRequestWithAuthCode(testServer.URL, "someCode") + if err != nil { + t.Fatal("failed to create request", err) + } + + identity, err := conn.HandleCallback(connector.Scopes{Groups: true}, req) + if err != nil { + t.Fatal("handle callback failed", err) + } + + expectEquals(t, identity.UserID, tc.expectUserID) + expectEquals(t, identity.Username, "name") + expectEquals(t, identity.Email, "email") + expectEquals(t, identity.EmailVerified, true) + }) + } +} + +func setupServer() (*httptest.Server, error) { + key, err := rsa.GenerateKey(rand.Reader, 1024) + if err != nil { + return nil, fmt.Errorf("failed to generate rsa key: %v", err) + } + + jwk := jose.JSONWebKey{ + Key: key, + KeyID: "keyId", + Algorithm: "RSA", + } + + mux := http.NewServeMux() + + mux.HandleFunc("/keys", func(w http.ResponseWriter, r *http.Request) { + json.NewEncoder(w).Encode(&map[string]interface{}{ + "keys": []map[string]interface{}{{ + "alg": jwk.Algorithm, + "kty": jwk.Algorithm, + "kid": jwk.KeyID, + "n": n(&key.PublicKey), + "e": e(&key.PublicKey), + }}, + }) + }) + + mux.HandleFunc("/token", func(w http.ResponseWriter, r *http.Request) { + url := fmt.Sprintf("http://%s", r.Host) + + token, err := newToken(&jwk, map[string]interface{}{ + "iss": url, + "aud": "clientID", + "exp": time.Now().Add(time.Hour).Unix(), + "sub": "sub", + "name": "name", + "email": "email", + "email_verified": true, + }) + if err != nil { + w.WriteHeader(http.StatusInternalServerError) + } + + w.Header().Add("Content-Type", "application/json") + json.NewEncoder(w).Encode(&map[string]string{ + "access_token": token, + "id_token": token, + "token_type": "Bearer", + }) + }) + + mux.HandleFunc("/.well-known/openid-configuration", func(w http.ResponseWriter, r *http.Request) { + url := fmt.Sprintf("http://%s", r.Host) + + json.NewEncoder(w).Encode(&map[string]string{ + "issuer": url, + "token_endpoint": fmt.Sprintf("%s/token", url), + "authorization_endpoint": fmt.Sprintf("%s/authorize", url), + "userinfo_endpoint": fmt.Sprintf("%s/userinfo", url), + "jwks_uri": fmt.Sprintf("%s/keys", url), + }) + }) + + return httptest.NewServer(mux), nil +} + +func newToken(key *jose.JSONWebKey, claims map[string]interface{}) (string, error) { + signingKey := jose.SigningKey{ + Key: key, + Algorithm: jose.RS256, + } + + signer, err := jose.NewSigner(signingKey, &jose.SignerOptions{}) + if err != nil { + return "", fmt.Errorf("failed to create new signer: %v", err) + } + + payload, err := json.Marshal(claims) + if err != nil { + return "", fmt.Errorf("failed to marshal claims: %v", err) + } + + signature, err := signer.Sign(payload) + if err != nil { + return "", fmt.Errorf("failed to sign: %v", err) + } + return signature.CompactSerialize() +} + +func newConnector(config Config) (*oidcConnector, error) { + logger := logrus.New() + conn, err := config.Open("id", logger) + if err != nil { + return nil, fmt.Errorf("unable to open: %v", err) + } + + oidcConn, ok := conn.(*oidcConnector) + if !ok { + return nil, errors.New("failed to convert to oidcConnector") + } + + return oidcConn, nil +} + +func newRequestWithAuthCode(serverURL string, code string) (*http.Request, error) { + req, err := http.NewRequest("GET", serverURL, nil) + if err != nil { + return nil, fmt.Errorf("failed to create request: %v", err) + } + + values := req.URL.Query() + values.Add("code", code) + req.URL.RawQuery = values.Encode() + + return req, nil +} + +func n(pub *rsa.PublicKey) string { + return encode(pub.N.Bytes()) +} + +func e(pub *rsa.PublicKey) string { + data := make([]byte, 8) + binary.BigEndian.PutUint64(data, uint64(pub.E)) + return encode(bytes.TrimLeft(data, "\x00")) +} + +func encode(payload []byte) string { + result := base64.URLEncoding.EncodeToString(payload) + return strings.TrimRight(result, "=") +} + +func expectEquals(t *testing.T, a interface{}, b interface{}) { + if !reflect.DeepEqual(a, b) { + t.Errorf("Expected %+v to equal %+v", a, b) + } +}