Merge pull request #1448 from cappyzawa/user-id-key
oidc: Make userID configurable
This commit is contained in:
		@@ -66,6 +66,12 @@ connectors:
 | 
				
			|||||||
    # all the claims requested.
 | 
					    # all the claims requested.
 | 
				
			||||||
    # https://openid.net/specs/openid-connect-core-1_0.html#UserInfo
 | 
					    # https://openid.net/specs/openid-connect-core-1_0.html#UserInfo
 | 
				
			||||||
    # getUserInfo: true
 | 
					    # getUserInfo: true
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    # The set claim is used as user id.
 | 
				
			||||||
 | 
					    # Default: sub
 | 
				
			||||||
 | 
					    # Claims list at https://openid.net/specs/openid-connect-core-1_0.html#Claims
 | 
				
			||||||
 | 
					    #
 | 
				
			||||||
 | 
					    # userIdKey: nickname
 | 
				
			||||||
```
 | 
					```
 | 
				
			||||||
 | 
					
 | 
				
			||||||
[oidc-doc]: openid-connect.md
 | 
					[oidc-doc]: openid-connect.md
 | 
				
			||||||
 
 | 
				
			|||||||
@@ -44,6 +44,9 @@ type Config struct {
 | 
				
			|||||||
	// the token. This is especially useful where upstreams return "thin"
 | 
						// the token. This is especially useful where upstreams return "thin"
 | 
				
			||||||
	// id tokens
 | 
						// id tokens
 | 
				
			||||||
	GetUserInfo bool `json:"getUserInfo"`
 | 
						GetUserInfo bool `json:"getUserInfo"`
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						// Configurable key which contains the user id claim
 | 
				
			||||||
 | 
						UserIDKey string `json:"userIDKey"`
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
// Domains that don't support basic auth. golang.org/x/oauth2 has an internal
 | 
					// Domains that don't support basic auth. golang.org/x/oauth2 has an internal
 | 
				
			||||||
@@ -127,6 +130,7 @@ func (c *Config) Open(id string, logger log.Logger) (conn connector.Connector, e
 | 
				
			|||||||
		hostedDomains:             c.HostedDomains,
 | 
							hostedDomains:             c.HostedDomains,
 | 
				
			||||||
		insecureSkipEmailVerified: c.InsecureSkipEmailVerified,
 | 
							insecureSkipEmailVerified: c.InsecureSkipEmailVerified,
 | 
				
			||||||
		getUserInfo:               c.GetUserInfo,
 | 
							getUserInfo:               c.GetUserInfo,
 | 
				
			||||||
 | 
							userIDKey:                 c.UserIDKey,
 | 
				
			||||||
	}, nil
 | 
						}, nil
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
@@ -146,6 +150,7 @@ type oidcConnector struct {
 | 
				
			|||||||
	hostedDomains             []string
 | 
						hostedDomains             []string
 | 
				
			||||||
	insecureSkipEmailVerified bool
 | 
						insecureSkipEmailVerified bool
 | 
				
			||||||
	getUserInfo               bool
 | 
						getUserInfo               bool
 | 
				
			||||||
 | 
						userIDKey                 string
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
func (c *oidcConnector) Close() error {
 | 
					func (c *oidcConnector) Close() error {
 | 
				
			||||||
@@ -199,33 +204,41 @@ func (c *oidcConnector) HandleCallback(s connector.Scopes, r *http.Request) (ide
 | 
				
			|||||||
		return identity, fmt.Errorf("oidc: failed to verify ID Token: %v", err)
 | 
							return identity, fmt.Errorf("oidc: failed to verify ID Token: %v", err)
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	var claims struct {
 | 
						var claims map[string]interface{}
 | 
				
			||||||
		Username      string `json:"name"`
 | 
					 | 
				
			||||||
		Email         string `json:"email"`
 | 
					 | 
				
			||||||
		EmailVerified bool   `json:"email_verified"`
 | 
					 | 
				
			||||||
		HostedDomain  string `json:"hd"`
 | 
					 | 
				
			||||||
	}
 | 
					 | 
				
			||||||
	if err := idToken.Claims(&claims); err != nil {
 | 
						if err := idToken.Claims(&claims); err != nil {
 | 
				
			||||||
		return identity, fmt.Errorf("oidc: failed to decode claims: %v", err)
 | 
							return identity, fmt.Errorf("oidc: failed to decode claims: %v", err)
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						name, found := claims["name"].(string)
 | 
				
			||||||
 | 
						if !found {
 | 
				
			||||||
 | 
							return identity, errors.New("missing \"name\" claim")
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
						email, found := claims["email"].(string)
 | 
				
			||||||
 | 
						if !found {
 | 
				
			||||||
 | 
							return identity, errors.New("missing \"email\" claim")
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
						emailVerified, found := claims["email_verified"].(bool)
 | 
				
			||||||
 | 
						if !found {
 | 
				
			||||||
 | 
							return identity, errors.New("missing \"email_verified\" claim")
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
						hostedDomain, _ := claims["hd"].(string)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	if len(c.hostedDomains) > 0 {
 | 
						if len(c.hostedDomains) > 0 {
 | 
				
			||||||
		found := false
 | 
							found := false
 | 
				
			||||||
		for _, domain := range c.hostedDomains {
 | 
							for _, domain := range c.hostedDomains {
 | 
				
			||||||
			if claims.HostedDomain == domain {
 | 
								if hostedDomain == domain {
 | 
				
			||||||
				found = true
 | 
									found = true
 | 
				
			||||||
				break
 | 
									break
 | 
				
			||||||
			}
 | 
								}
 | 
				
			||||||
		}
 | 
							}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
		if !found {
 | 
							if !found {
 | 
				
			||||||
			return identity, fmt.Errorf("oidc: unexpected hd claim %v", claims.HostedDomain)
 | 
								return identity, fmt.Errorf("oidc: unexpected hd claim %v", hostedDomain)
 | 
				
			||||||
		}
 | 
							}
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	if c.insecureSkipEmailVerified {
 | 
						if c.insecureSkipEmailVerified {
 | 
				
			||||||
		claims.EmailVerified = true
 | 
							emailVerified = true
 | 
				
			||||||
 | 
					 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	if c.getUserInfo {
 | 
						if c.getUserInfo {
 | 
				
			||||||
@@ -240,10 +253,19 @@ func (c *oidcConnector) HandleCallback(s connector.Scopes, r *http.Request) (ide
 | 
				
			|||||||
 | 
					
 | 
				
			||||||
	identity = connector.Identity{
 | 
						identity = connector.Identity{
 | 
				
			||||||
		UserID:        idToken.Subject,
 | 
							UserID:        idToken.Subject,
 | 
				
			||||||
		Username:      claims.Username,
 | 
							Username:      name,
 | 
				
			||||||
		Email:         claims.Email,
 | 
							Email:         email,
 | 
				
			||||||
		EmailVerified: claims.EmailVerified,
 | 
							EmailVerified: emailVerified,
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						if c.userIDKey != "" {
 | 
				
			||||||
 | 
							userID, found := claims[c.userIDKey].(string)
 | 
				
			||||||
 | 
							if !found {
 | 
				
			||||||
 | 
								return identity, fmt.Errorf("oidc: not found %v claim", c.userIDKey)
 | 
				
			||||||
 | 
							}
 | 
				
			||||||
 | 
							identity.UserID = userID
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	return identity, nil
 | 
						return identity, nil
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 
 | 
				
			|||||||
@@ -1,7 +1,24 @@
 | 
				
			|||||||
package oidc
 | 
					package oidc
 | 
				
			||||||
 | 
					
 | 
				
			||||||
import (
 | 
					import (
 | 
				
			||||||
 | 
						"bytes"
 | 
				
			||||||
 | 
						"crypto/rand"
 | 
				
			||||||
 | 
						"crypto/rsa"
 | 
				
			||||||
 | 
						"encoding/base64"
 | 
				
			||||||
 | 
						"encoding/binary"
 | 
				
			||||||
 | 
						"encoding/json"
 | 
				
			||||||
 | 
						"errors"
 | 
				
			||||||
 | 
						"fmt"
 | 
				
			||||||
 | 
						"net/http"
 | 
				
			||||||
 | 
						"net/http/httptest"
 | 
				
			||||||
 | 
						"reflect"
 | 
				
			||||||
 | 
						"strings"
 | 
				
			||||||
	"testing"
 | 
						"testing"
 | 
				
			||||||
 | 
						"time"
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						"github.com/dexidp/dex/connector"
 | 
				
			||||||
 | 
						"github.com/sirupsen/logrus"
 | 
				
			||||||
 | 
						"gopkg.in/square/go-jose.v2"
 | 
				
			||||||
)
 | 
					)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
func TestKnownBrokenAuthHeaderProvider(t *testing.T) {
 | 
					func TestKnownBrokenAuthHeaderProvider(t *testing.T) {
 | 
				
			||||||
@@ -23,3 +40,192 @@ func TestKnownBrokenAuthHeaderProvider(t *testing.T) {
 | 
				
			|||||||
		}
 | 
							}
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					func TestHandleCallback(t *testing.T) {
 | 
				
			||||||
 | 
						t.Helper()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						tests := []struct {
 | 
				
			||||||
 | 
							name         string
 | 
				
			||||||
 | 
							userIDKey    string
 | 
				
			||||||
 | 
							expectUserID string
 | 
				
			||||||
 | 
						}{
 | 
				
			||||||
 | 
							{"simpleCase", "", "sub"},
 | 
				
			||||||
 | 
							{"withUserIDKey", "name", "name"},
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						for _, tc := range tests {
 | 
				
			||||||
 | 
							t.Run(tc.name, func(t *testing.T) {
 | 
				
			||||||
 | 
								testServer, err := setupServer()
 | 
				
			||||||
 | 
								if err != nil {
 | 
				
			||||||
 | 
									t.Fatal("failed to setup test server", err)
 | 
				
			||||||
 | 
								}
 | 
				
			||||||
 | 
								defer testServer.Close()
 | 
				
			||||||
 | 
								serverURL := testServer.URL
 | 
				
			||||||
 | 
								config := Config{
 | 
				
			||||||
 | 
									Issuer:       serverURL,
 | 
				
			||||||
 | 
									ClientID:     "clientID",
 | 
				
			||||||
 | 
									ClientSecret: "clientSecret",
 | 
				
			||||||
 | 
									Scopes:       []string{"groups"},
 | 
				
			||||||
 | 
									RedirectURI:  fmt.Sprintf("%s/callback", serverURL),
 | 
				
			||||||
 | 
									UserIDKey:    tc.userIDKey,
 | 
				
			||||||
 | 
								}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
								conn, err := newConnector(config)
 | 
				
			||||||
 | 
								if err != nil {
 | 
				
			||||||
 | 
									t.Fatal("failed to create new connector", err)
 | 
				
			||||||
 | 
								}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
								req, err := newRequestWithAuthCode(testServer.URL, "someCode")
 | 
				
			||||||
 | 
								if err != nil {
 | 
				
			||||||
 | 
									t.Fatal("failed to create request", err)
 | 
				
			||||||
 | 
								}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
								identity, err := conn.HandleCallback(connector.Scopes{Groups: true}, req)
 | 
				
			||||||
 | 
								if err != nil {
 | 
				
			||||||
 | 
									t.Fatal("handle callback failed", err)
 | 
				
			||||||
 | 
								}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
								expectEquals(t, identity.UserID, tc.expectUserID)
 | 
				
			||||||
 | 
								expectEquals(t, identity.Username, "name")
 | 
				
			||||||
 | 
								expectEquals(t, identity.Email, "email")
 | 
				
			||||||
 | 
								expectEquals(t, identity.EmailVerified, true)
 | 
				
			||||||
 | 
							})
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					func setupServer() (*httptest.Server, error) {
 | 
				
			||||||
 | 
						key, err := rsa.GenerateKey(rand.Reader, 1024)
 | 
				
			||||||
 | 
						if err != nil {
 | 
				
			||||||
 | 
							return nil, fmt.Errorf("failed to generate rsa key: %v", err)
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						jwk := jose.JSONWebKey{
 | 
				
			||||||
 | 
							Key:       key,
 | 
				
			||||||
 | 
							KeyID:     "keyId",
 | 
				
			||||||
 | 
							Algorithm: "RSA",
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						mux := http.NewServeMux()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						mux.HandleFunc("/keys", func(w http.ResponseWriter, r *http.Request) {
 | 
				
			||||||
 | 
							json.NewEncoder(w).Encode(&map[string]interface{}{
 | 
				
			||||||
 | 
								"keys": []map[string]interface{}{{
 | 
				
			||||||
 | 
									"alg": jwk.Algorithm,
 | 
				
			||||||
 | 
									"kty": jwk.Algorithm,
 | 
				
			||||||
 | 
									"kid": jwk.KeyID,
 | 
				
			||||||
 | 
									"n":   n(&key.PublicKey),
 | 
				
			||||||
 | 
									"e":   e(&key.PublicKey),
 | 
				
			||||||
 | 
								}},
 | 
				
			||||||
 | 
							})
 | 
				
			||||||
 | 
						})
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						mux.HandleFunc("/token", func(w http.ResponseWriter, r *http.Request) {
 | 
				
			||||||
 | 
							url := fmt.Sprintf("http://%s", r.Host)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
							token, err := newToken(&jwk, map[string]interface{}{
 | 
				
			||||||
 | 
								"iss":            url,
 | 
				
			||||||
 | 
								"aud":            "clientID",
 | 
				
			||||||
 | 
								"exp":            time.Now().Add(time.Hour).Unix(),
 | 
				
			||||||
 | 
								"sub":            "sub",
 | 
				
			||||||
 | 
								"name":           "name",
 | 
				
			||||||
 | 
								"email":          "email",
 | 
				
			||||||
 | 
								"email_verified": true,
 | 
				
			||||||
 | 
							})
 | 
				
			||||||
 | 
							if err != nil {
 | 
				
			||||||
 | 
								w.WriteHeader(http.StatusInternalServerError)
 | 
				
			||||||
 | 
							}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
							w.Header().Add("Content-Type", "application/json")
 | 
				
			||||||
 | 
							json.NewEncoder(w).Encode(&map[string]string{
 | 
				
			||||||
 | 
								"access_token": token,
 | 
				
			||||||
 | 
								"id_token":     token,
 | 
				
			||||||
 | 
								"token_type":   "Bearer",
 | 
				
			||||||
 | 
							})
 | 
				
			||||||
 | 
						})
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						mux.HandleFunc("/.well-known/openid-configuration", func(w http.ResponseWriter, r *http.Request) {
 | 
				
			||||||
 | 
							url := fmt.Sprintf("http://%s", r.Host)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
							json.NewEncoder(w).Encode(&map[string]string{
 | 
				
			||||||
 | 
								"issuer":                 url,
 | 
				
			||||||
 | 
								"token_endpoint":         fmt.Sprintf("%s/token", url),
 | 
				
			||||||
 | 
								"authorization_endpoint": fmt.Sprintf("%s/authorize", url),
 | 
				
			||||||
 | 
								"userinfo_endpoint":      fmt.Sprintf("%s/userinfo", url),
 | 
				
			||||||
 | 
								"jwks_uri":               fmt.Sprintf("%s/keys", url),
 | 
				
			||||||
 | 
							})
 | 
				
			||||||
 | 
						})
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						return httptest.NewServer(mux), nil
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					func newToken(key *jose.JSONWebKey, claims map[string]interface{}) (string, error) {
 | 
				
			||||||
 | 
						signingKey := jose.SigningKey{
 | 
				
			||||||
 | 
							Key:       key,
 | 
				
			||||||
 | 
							Algorithm: jose.RS256,
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						signer, err := jose.NewSigner(signingKey, &jose.SignerOptions{})
 | 
				
			||||||
 | 
						if err != nil {
 | 
				
			||||||
 | 
							return "", fmt.Errorf("failed to create new signer: %v", err)
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						payload, err := json.Marshal(claims)
 | 
				
			||||||
 | 
						if err != nil {
 | 
				
			||||||
 | 
							return "", fmt.Errorf("failed to marshal claims: %v", err)
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						signature, err := signer.Sign(payload)
 | 
				
			||||||
 | 
						if err != nil {
 | 
				
			||||||
 | 
							return "", fmt.Errorf("failed to sign: %v", err)
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
						return signature.CompactSerialize()
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					func newConnector(config Config) (*oidcConnector, error) {
 | 
				
			||||||
 | 
						logger := logrus.New()
 | 
				
			||||||
 | 
						conn, err := config.Open("id", logger)
 | 
				
			||||||
 | 
						if err != nil {
 | 
				
			||||||
 | 
							return nil, fmt.Errorf("unable to open: %v", err)
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						oidcConn, ok := conn.(*oidcConnector)
 | 
				
			||||||
 | 
						if !ok {
 | 
				
			||||||
 | 
							return nil, errors.New("failed to convert to oidcConnector")
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						return oidcConn, nil
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					func newRequestWithAuthCode(serverURL string, code string) (*http.Request, error) {
 | 
				
			||||||
 | 
						req, err := http.NewRequest("GET", serverURL, nil)
 | 
				
			||||||
 | 
						if err != nil {
 | 
				
			||||||
 | 
							return nil, fmt.Errorf("failed to create request: %v", err)
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						values := req.URL.Query()
 | 
				
			||||||
 | 
						values.Add("code", code)
 | 
				
			||||||
 | 
						req.URL.RawQuery = values.Encode()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						return req, nil
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					func n(pub *rsa.PublicKey) string {
 | 
				
			||||||
 | 
						return encode(pub.N.Bytes())
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					func e(pub *rsa.PublicKey) string {
 | 
				
			||||||
 | 
						data := make([]byte, 8)
 | 
				
			||||||
 | 
						binary.BigEndian.PutUint64(data, uint64(pub.E))
 | 
				
			||||||
 | 
						return encode(bytes.TrimLeft(data, "\x00"))
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					func encode(payload []byte) string {
 | 
				
			||||||
 | 
						result := base64.URLEncoding.EncodeToString(payload)
 | 
				
			||||||
 | 
						return strings.TrimRight(result, "=")
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					func expectEquals(t *testing.T, a interface{}, b interface{}) {
 | 
				
			||||||
 | 
						if !reflect.DeepEqual(a, b) {
 | 
				
			||||||
 | 
							t.Errorf("Expected %+v to equal %+v", a, b)
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 
 | 
				
			|||||||
		Reference in New Issue
	
	Block a user