299 lines
		
	
	
		
			8.3 KiB
		
	
	
	
		
			Go
		
	
	
	
	
	
			
		
		
	
	
			299 lines
		
	
	
		
			8.3 KiB
		
	
	
	
		
			Go
		
	
	
	
	
	
| package oauth
 | |
| 
 | |
| import (
 | |
| 	"context"
 | |
| 	"crypto/tls"
 | |
| 	"crypto/x509"
 | |
| 	"encoding/base64"
 | |
| 	"encoding/json"
 | |
| 	"errors"
 | |
| 	"fmt"
 | |
| 	"net"
 | |
| 	"net/http"
 | |
| 	"os"
 | |
| 	"strings"
 | |
| 	"time"
 | |
| 
 | |
| 	"golang.org/x/oauth2"
 | |
| 
 | |
| 	"github.com/dexidp/dex/connector"
 | |
| 	"github.com/dexidp/dex/pkg/log"
 | |
| )
 | |
| 
 | |
| type oauthConnector struct {
 | |
| 	clientID             string
 | |
| 	clientSecret         string
 | |
| 	redirectURI          string
 | |
| 	tokenURL             string
 | |
| 	authorizationURL     string
 | |
| 	userInfoURL          string
 | |
| 	scopes               []string
 | |
| 	userIDKey            string
 | |
| 	userNameKey          string
 | |
| 	preferredUsernameKey string
 | |
| 	emailKey             string
 | |
| 	emailVerifiedKey     string
 | |
| 	groupsKey            string
 | |
| 	httpClient           *http.Client
 | |
| 	logger               log.Logger
 | |
| }
 | |
| 
 | |
// connectorData is the JSON payload stored in Identity.ConnectorData when
// offline access is requested.
type connectorData struct {
	// AccessToken is the provider access token obtained during login. The
	// field deliberately has no JSON tag: the serialized key is "AccessToken".
	AccessToken string
}
 | |
| 
 | |
// Config holds the settings for a generic OAuth 2.0 connector. Open turns it
// into a connector; claim keys left empty fall back to the defaults noted below.
type Config struct {
	// Client credentials registered with the provider.
	ClientID     string `json:"clientID"`
	ClientSecret string `json:"clientSecret"`

	// RedirectURI must equal the callback URL used when building the login URL.
	RedirectURI string `json:"redirectURI"`

	// Provider endpoints: token exchange, authorization, and the JSON userinfo
	// document that identity claims are read from.
	TokenURL         string `json:"tokenURL"`
	AuthorizationURL string `json:"authorizationURL"`
	UserInfoURL      string `json:"userInfoURL"`

	// Scopes are sent with every authorization request.
	Scopes []string `json:"scopes"`

	// RootCAs lists PEM files trusted in addition to the system roots.
	// InsecureSkipVerify disables TLS certificate verification; avoid it
	// outside of testing.
	RootCAs            []string `json:"rootCAs"`
	InsecureSkipVerify bool     `json:"insecureSkipVerify"`

	// UserIDKey names the userinfo claim used as the user ID (defaults to "id").
	UserIDKey string `json:"userIDKey"`

	// ClaimMapping names the userinfo claims mapped onto identity fields.
	ClaimMapping struct {
		UserNameKey          string `json:"userNameKey"`          // defaults to "user_name"
		PreferredUsernameKey string `json:"preferredUsernameKey"` // defaults to "preferred_username"
		GroupsKey            string `json:"groupsKey"`            // defaults to "groups"
		EmailKey             string `json:"emailKey"`             // defaults to "email"
		EmailVerifiedKey     string `json:"emailVerifiedKey"`     // defaults to "email_verified"
	} `json:"claimMapping"`
}
 | |
| 
 | |
| func (c *Config) Open(id string, logger log.Logger) (connector.Connector, error) {
 | |
| 	var err error
 | |
| 
 | |
| 	userIDKey := c.UserIDKey
 | |
| 	if userIDKey == "" {
 | |
| 		userIDKey = "id"
 | |
| 	}
 | |
| 
 | |
| 	userNameKey := c.ClaimMapping.UserNameKey
 | |
| 	if userNameKey == "" {
 | |
| 		userNameKey = "user_name"
 | |
| 	}
 | |
| 
 | |
| 	preferredUsernameKey := c.ClaimMapping.PreferredUsernameKey
 | |
| 	if preferredUsernameKey == "" {
 | |
| 		preferredUsernameKey = "preferred_username"
 | |
| 	}
 | |
| 
 | |
| 	groupsKey := c.ClaimMapping.GroupsKey
 | |
| 	if groupsKey == "" {
 | |
| 		groupsKey = "groups"
 | |
| 	}
 | |
| 
 | |
| 	emailKey := c.ClaimMapping.EmailKey
 | |
| 	if emailKey == "" {
 | |
| 		emailKey = "email"
 | |
| 	}
 | |
| 
 | |
| 	emailVerifiedKey := c.ClaimMapping.EmailVerifiedKey
 | |
| 	if emailVerifiedKey == "" {
 | |
| 		emailVerifiedKey = "email_verified"
 | |
| 	}
 | |
| 
 | |
| 	oauthConn := &oauthConnector{
 | |
| 		clientID:             c.ClientID,
 | |
| 		clientSecret:         c.ClientSecret,
 | |
| 		tokenURL:             c.TokenURL,
 | |
| 		authorizationURL:     c.AuthorizationURL,
 | |
| 		userInfoURL:          c.UserInfoURL,
 | |
| 		scopes:               c.Scopes,
 | |
| 		redirectURI:          c.RedirectURI,
 | |
| 		logger:               logger,
 | |
| 		userIDKey:            userIDKey,
 | |
| 		userNameKey:          userNameKey,
 | |
| 		preferredUsernameKey: preferredUsernameKey,
 | |
| 		groupsKey:            groupsKey,
 | |
| 		emailKey:             emailKey,
 | |
| 		emailVerifiedKey:     emailVerifiedKey,
 | |
| 	}
 | |
| 
 | |
| 	oauthConn.httpClient, err = newHTTPClient(c.RootCAs, c.InsecureSkipVerify)
 | |
| 	if err != nil {
 | |
| 		return nil, err
 | |
| 	}
 | |
| 
 | |
| 	return oauthConn, err
 | |
| }
 | |
| 
 | |
// newHTTPClient builds the HTTP client used for all requests to the OAuth
// provider. The client trusts the system root CAs plus every PEM certificate
// found in the files listed in rootCAs. insecureSkipVerify disables server
// certificate verification entirely.
//
// It returns an error if the system pool cannot be loaded, a root CA file
// cannot be read (the underlying error is wrapped, so errors.Is works), or a
// file contains no PEM certificates.
func newHTTPClient(rootCAs []string, insecureSkipVerify bool) (*http.Client, error) {
	// SystemCertPool returns a copy, so appending custom CAs below does not
	// modify any pool shared with the rest of the process.
	pool, err := x509.SystemCertPool()
	if err != nil {
		return nil, err
	}

	for _, rootCA := range rootCAs {
		rootCABytes, err := os.ReadFile(rootCA)
		if err != nil {
			return nil, fmt.Errorf("failed to read root-ca %q: %w", rootCA, err)
		}
		if !pool.AppendCertsFromPEM(rootCABytes) {
			return nil, fmt.Errorf("no certs found in root CA file %q", rootCA)
		}
	}

	tlsConfig := &tls.Config{RootCAs: pool, InsecureSkipVerify: insecureSkipVerify}

	return &http.Client{
		Transport: &http.Transport{
			TLSClientConfig: tlsConfig,
			Proxy:           http.ProxyFromEnvironment,
			// Dialer.DualStack is deprecated and ignored: RFC 6555 fast
			// fallback is enabled by default, so it is not set here.
			DialContext: (&net.Dialer{
				Timeout:   30 * time.Second,
				KeepAlive: 30 * time.Second,
			}).DialContext,
			MaxIdleConns:          100,
			IdleConnTimeout:       90 * time.Second,
			TLSHandshakeTimeout:   10 * time.Second,
			ExpectContinueTimeout: 1 * time.Second,
		},
	}, nil
}
 | |
| 
 | |
// LoginURL returns the provider authorization URL carrying the given state.
// callbackURL must equal the configured redirect URI. The requested connector
// scopes are not consulted; the configured scopes are always sent.
func (c *oauthConnector) LoginURL(scopes connector.Scopes, callbackURL, state string) (string, error) {
	if callbackURL != c.redirectURI {
		return "", fmt.Errorf("expected callback URL %q did not match the URL in the config %q", callbackURL, c.redirectURI)
	}

	cfg := oauth2.Config{
		ClientID:     c.clientID,
		ClientSecret: c.clientSecret,
		RedirectURL:  c.redirectURI,
		Scopes:       c.scopes,
		Endpoint: oauth2.Endpoint{
			AuthURL:  c.authorizationURL,
			TokenURL: c.tokenURL,
		},
	}
	return cfg.AuthCodeURL(state), nil
}
 | |
| 
 | |
// HandleCallback completes the authorization-code flow started by LoginURL.
//
// It exchanges the "code" query parameter for a token, requests the configured
// userinfo endpoint with that token, and maps the returned JSON claims onto a
// connector.Identity using the configured claim keys. When s.Groups is set,
// groups are collected best-effort from the userinfo response and, if the
// access token is a JWT, from its payload. When s.OfflineAccess is set, the
// access token is stored in Identity.ConnectorData.
func (c *oauthConnector) HandleCallback(s connector.Scopes, r *http.Request) (identity connector.Identity, err error) {
	q := r.URL.Query()
	if errType := q.Get("error"); errType != "" {
		// error_description is optional in a provider error response; fall
		// back to the error code so the returned error is never empty.
		if desc := q.Get("error_description"); desc != "" {
			return identity, errors.New(desc)
		}
		return identity, errors.New(errType)
	}

	oauth2Config := &oauth2.Config{
		ClientID:     c.clientID,
		ClientSecret: c.clientSecret,
		Endpoint:     oauth2.Endpoint{TokenURL: c.tokenURL, AuthURL: c.authorizationURL},
		RedirectURL:  c.redirectURI,
		Scopes:       c.scopes,
	}

	// Make the oauth2 package use the connector's HTTP client (custom root
	// CAs, TLS settings) for the token exchange and the userinfo request.
	ctx := context.WithValue(r.Context(), oauth2.HTTPClient, c.httpClient)

	token, err := oauth2Config.Exchange(ctx, q.Get("code"))
	if err != nil {
		return identity, fmt.Errorf("OAuth connector: failed to get token: %w", err)
	}

	client := oauth2.NewClient(ctx, oauth2.StaticTokenSource(token))

	// Bind the userinfo request to ctx so it is cancelled together with the
	// incoming callback request (client.Get would not carry a context).
	req, err := http.NewRequestWithContext(ctx, http.MethodGet, c.userInfoURL, nil)
	if err != nil {
		return identity, fmt.Errorf("OAuth Connector: failed to create userinfo request: %w", err)
	}
	userInfoResp, err := client.Do(req)
	if err != nil {
		return identity, fmt.Errorf("OAuth Connector: failed to execute request to userinfo: %w", err)
	}
	defer userInfoResp.Body.Close()

	if userInfoResp.StatusCode != http.StatusOK {
		return identity, fmt.Errorf("OAuth Connector: failed to execute request to userinfo: status %d", userInfoResp.StatusCode)
	}

	// UseNumber keeps numeric claims as their exact JSON text. Decoding into
	// float64 loses precision above 2^53, and formatting it with %v renders
	// IDs of 1e6 or more in exponent form (12345678 -> "1.2345678e+07").
	var userInfoResult map[string]interface{}
	decoder := json.NewDecoder(userInfoResp.Body)
	decoder.UseNumber()
	if err := decoder.Decode(&userInfoResult); err != nil {
		return identity, fmt.Errorf("OAuth Connector: failed to parse userinfo: %w", err)
	}

	userID, found := userInfoResult[c.userIDKey]
	if !found {
		return identity, fmt.Errorf("OAuth Connector: not found %v claim", c.userIDKey)
	}

	switch id := userID.(type) {
	case string:
		identity.UserID = id
	case json.Number:
		identity.UserID = id.String()
	default:
		return identity, fmt.Errorf("OAuth Connector: %v claim should be string or number, got %T", c.userIDKey, userID)
	}

	// Optional claims: a missing or mistyped claim leaves the field zero.
	identity.Username, _ = userInfoResult[c.userNameKey].(string)
	identity.PreferredUsername, _ = userInfoResult[c.preferredUsernameKey].(string)
	identity.Email, _ = userInfoResult[c.emailKey].(string)
	identity.EmailVerified, _ = userInfoResult[c.emailVerifiedKey].(bool)

	if s.Groups {
		groups := map[string]struct{}{}

		// Both sources are optional: the groups claim may be absent and the
		// access token may be opaque rather than a JWT, so their errors are
		// intentionally ignored.
		_ = c.addGroupsFromMap(groups, userInfoResult)
		_ = c.addGroupsFromToken(groups, token.AccessToken)

		for groupName := range groups {
			identity.Groups = append(identity.Groups, groupName)
		}
	}

	if s.OfflineAccess {
		connData, err := json.Marshal(connectorData{AccessToken: token.AccessToken})
		if err != nil {
			return identity, fmt.Errorf("OAuth Connector: failed to encode connector data for offline access: %w", err)
		}
		identity.ConnectorData = connData
	}

	return identity, nil
}
 | |
| 
 | |
| func (c *oauthConnector) addGroupsFromMap(groups map[string]struct{}, result map[string]interface{}) error {
 | |
| 	groupsClaim, ok := result[c.groupsKey].([]interface{})
 | |
| 	if !ok {
 | |
| 		return errors.New("cannot convert to slice")
 | |
| 	}
 | |
| 
 | |
| 	for _, group := range groupsClaim {
 | |
| 		if groupString, ok := group.(string); ok {
 | |
| 			groups[groupString] = struct{}{}
 | |
| 		}
 | |
| 		if groupMap, ok := group.(map[string]interface{}); ok {
 | |
| 			if groupName, ok := groupMap["name"].(string); ok {
 | |
| 				groups[groupName] = struct{}{}
 | |
| 			}
 | |
| 		}
 | |
| 	}
 | |
| 
 | |
| 	return nil
 | |
| }
 | |
| 
 | |
// addGroupsFromToken treats token as a JWT, decodes its (unverified) payload
// segment as a JSON object, and adds the groups found under the configured
// groups claim via addGroupsFromMap. Tokens without a "." separator, or whose
// payload is not valid base64url JSON, yield an error.
func (c *oauthConnector) addGroupsFromToken(groups map[string]struct{}, token string) error {
	// Only the payload (second segment) is needed; SplitN stops after it.
	segments := strings.SplitN(token, ".", 3)
	if len(segments) < 2 {
		return errors.New("invalid token")
	}

	payload, err := decode(segments[1])
	if err != nil {
		return err
	}

	var claims map[string]interface{}
	if err := json.Unmarshal(payload, &claims); err != nil {
		return err
	}
	return c.addGroupsFromMap(groups, claims)
}
 | |
| 
 | |
// decode base64url-decodes a JWT segment. JWT segments are unpadded, so the
// missing "=" padding is restored first; already padded input is accepted
// unchanged.
func decode(seg string) ([]byte, error) {
	// (4 - len%4) % 4 is the number of '=' needed to reach a multiple of 4.
	missing := (4 - len(seg)%4) % 4
	return base64.URLEncoding.DecodeString(seg + strings.Repeat("=", missing))
}
 |