connector/oidc: expose oauth2.RegisterBrokenAuthHeaderProvider
This commit is contained in:
		| @@ -6,6 +6,9 @@ import ( | |||||||
| 	"errors" | 	"errors" | ||||||
| 	"fmt" | 	"fmt" | ||||||
| 	"net/http" | 	"net/http" | ||||||
|  | 	"net/url" | ||||||
|  | 	"strings" | ||||||
|  | 	"sync" | ||||||
|  |  | ||||||
| 	"github.com/Sirupsen/logrus" | 	"github.com/Sirupsen/logrus" | ||||||
| 	"github.com/coreos/go-oidc" | 	"github.com/coreos/go-oidc" | ||||||
| @@ -21,7 +24,50 @@ type Config struct { | |||||||
| 	ClientSecret string `json:"clientSecret"` | 	ClientSecret string `json:"clientSecret"` | ||||||
| 	RedirectURI  string `json:"redirectURI"` | 	RedirectURI  string `json:"redirectURI"` | ||||||
|  |  | ||||||
|  | 	// Causes client_secret to be passed as POST parameters instead of basic | ||||||
|  | 	// auth. This is specifically "NOT RECOMMENDED" by the OAuth2 RFC, but some | ||||||
|  | 	// providers require it. | ||||||
|  | 	// | ||||||
|  | 	// https://tools.ietf.org/html/rfc6749#section-2.3.1 | ||||||
|  | 	BasicAuthUnsupported *bool `json:"basicAuthUnsupported"` | ||||||
|  |  | ||||||
| 	Scopes []string `json:"scopes"` // defaults to "profile" and "email" | 	Scopes []string `json:"scopes"` // defaults to "profile" and "email" | ||||||
|  |  | ||||||
|  | } | ||||||
|  |  | ||||||
|  | // Domains that don't support basic auth. golang.org/x/oauth2 has an internal | ||||||
|  | // list, but it only matches specific URLs, not top level domains. | ||||||
|  | var brokenAuthHeaderDomains = []string{ | ||||||
|  | 	// See: https://github.com/coreos/dex/issues/859 | ||||||
|  | 	"okta.com", | ||||||
|  | 	"oktapreview.com", | ||||||
|  | } | ||||||
|  |  | ||||||
|  | // Detect auth header provider issues for known providers. This lets users | ||||||
|  | // avoid having to explicitly set "basicAuthUnsupported" in their config. | ||||||
|  | // | ||||||
|  | // Setting the config field always overrides values returned by this function. | ||||||
|  | func knownBrokenAuthHeaderProvider(issuerURL string) bool { | ||||||
|  | 	if u, err := url.Parse(issuerURL); err == nil { | ||||||
|  | 		for _, host := range brokenAuthHeaderDomains { | ||||||
|  | 			if u.Host == host || strings.HasSuffix(u.Host, "."+host) { | ||||||
|  | 				return true | ||||||
|  | 			} | ||||||
|  | 		} | ||||||
|  | 	} | ||||||
|  | 	return false | ||||||
|  | } | ||||||
|  |  | ||||||
|  | // golang.org/x/oauth2 doesn't do internal locking. Need to do it in this | ||||||
|  | // package ourselves and hope that other packages aren't calling it at the | ||||||
|  | // same time. | ||||||
|  | var registerMu = new(sync.Mutex) | ||||||
|  |  | ||||||
|  | func registerBrokenAuthHeaderProvider(url string) { | ||||||
|  | 	registerMu.Lock() | ||||||
|  | 	defer registerMu.Unlock() | ||||||
|  |  | ||||||
|  | 	oauth2.RegisterBrokenAuthHeaderProvider(url) | ||||||
| } | } | ||||||
|  |  | ||||||
| // Open returns a connector which can be used to login users through an upstream | // Open returns a connector which can be used to login users through an upstream | ||||||
| @@ -35,6 +81,15 @@ func (c *Config) Open(logger logrus.FieldLogger) (conn connector.Connector, err | |||||||
| 		return nil, fmt.Errorf("failed to get provider: %v", err) | 		return nil, fmt.Errorf("failed to get provider: %v", err) | ||||||
| 	} | 	} | ||||||
|  |  | ||||||
|  | 	if c.BasicAuthUnsupported != nil { | ||||||
|  | 		// Setting "basicAuthUnsupported" always overrides our detection. | ||||||
|  | 		if *c.BasicAuthUnsupported { | ||||||
|  | 			registerBrokenAuthHeaderProvider(provider.Endpoint().TokenURL) | ||||||
|  | 		} | ||||||
|  | 	} else if knownBrokenAuthHeaderProvider(c.Issuer) { | ||||||
|  | 		registerBrokenAuthHeaderProvider(provider.Endpoint().TokenURL) | ||||||
|  | 	} | ||||||
|  |  | ||||||
| 	scopes := []string{oidc.ScopeOpenID} | 	scopes := []string{oidc.ScopeOpenID} | ||||||
| 	if len(c.Scopes) > 0 { | 	if len(c.Scopes) > 0 { | ||||||
| 		scopes = append(scopes, c.Scopes...) | 		scopes = append(scopes, c.Scopes...) | ||||||
|   | |||||||
							
								
								
									
										23
									
								
								connector/oidc/oidc_test.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										23
									
								
								connector/oidc/oidc_test.go
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,23 @@ | |||||||
|  | package oidc | ||||||
|  |  | ||||||
|  | import "testing" | ||||||
|  |  | ||||||
|  | func TestKnownBrokenAuthHeaderProvider(t *testing.T) { | ||||||
|  | 	tests := []struct { | ||||||
|  | 		issuerURL string | ||||||
|  | 		expect    bool | ||||||
|  | 	}{ | ||||||
|  | 		{"https://dev.oktapreview.com", true}, | ||||||
|  | 		{"https://dev.okta.com", true}, | ||||||
|  | 		{"https://okta.com", true}, | ||||||
|  | 		{"https://dev.oktaaccounts.com", false}, | ||||||
|  | 		{"https://accounts.google.com", false}, | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	for _, tc := range tests { | ||||||
|  | 		got := knownBrokenAuthHeaderProvider(tc.issuerURL) | ||||||
|  | 		if got != tc.expect { | ||||||
|  | 			t.Errorf("knownBrokenAuthHeaderProvider(%q), want=%t, got=%t", tc.issuerURL, tc.expect, got) | ||||||
|  | 		} | ||||||
|  | 	} | ||||||
|  | } | ||||||
		Reference in New Issue
	
	Block a user