add documentation and tests

This commit is contained in:
Ben Navetta
2017-06-21 22:56:02 -07:00
parent 687bc9ca5c
commit cbb007663f
3 changed files with 139 additions and 18 deletions

View File

@@ -33,7 +33,9 @@ type Config struct {
Scopes []string `json:"scopes"` // defaults to "profile" and "email"
HostedDomain string `json:"hostedDomain"`
// Optional list of whitelisted domains when using Google
// If this field is nonempty, only users from a listed domain will be allowed to log in
HostedDomains []string `json:"hostedDomain"`
}
// Domains that don't support basic auth. golang.org/x/oauth2 has an internal
@@ -111,9 +113,9 @@ func (c *Config) Open(logger logrus.FieldLogger) (conn connector.Connector, err
verifier: provider.Verifier(
&oidc.Config{ClientID: clientID},
),
logger: logger,
cancel: cancel,
hostedDomain: c.HostedDomain,
logger: logger,
cancel: cancel,
hostedDomains: c.HostedDomains,
}, nil
}
@@ -123,13 +125,13 @@ var (
)
type oidcConnector struct {
redirectURI string
oauth2Config *oauth2.Config
verifier *oidc.IDTokenVerifier
ctx context.Context
cancel context.CancelFunc
logger logrus.FieldLogger
hostedDomain string
redirectURI string
oauth2Config *oauth2.Config
verifier *oidc.IDTokenVerifier
ctx context.Context
cancel context.CancelFunc
logger logrus.FieldLogger
hostedDomains []string
}
func (c *oidcConnector) Close() error {
@@ -142,11 +144,14 @@ func (c *oidcConnector) LoginURL(s connector.Scopes, callbackURL, state string)
return "", fmt.Errorf("expected callback URL %q did not match the URL in the config %q", callbackURL, c.redirectURI)
}
if c.hostedDomain != "" {
return c.oauth2Config.AuthCodeURL(state, oauth2.SetAuthURLParam("hd", c.hostedDomain)), nil
} else {
return c.oauth2Config.AuthCodeURL(state), nil
if len(c.hostedDomains) > 0 {
preferredDomain := c.hostedDomains[0]
if len(c.hostedDomains) > 1 {
preferredDomain = "*"
}
return c.oauth2Config.AuthCodeURL(state, oauth2.SetAuthURLParam("hd", preferredDomain)), nil
}
return c.oauth2Config.AuthCodeURL(state), nil
}
type oauth2Error struct {
@@ -190,8 +195,18 @@ func (c *oidcConnector) HandleCallback(s connector.Scopes, r *http.Request) (ide
return identity, fmt.Errorf("oidc: failed to decode claims: %v", err)
}
if claims.HostedDomain != c.hostedDomain {
return identity, fmt.Errorf("oidc: unexpected hd claim %v", claims.HostedDomain)
if len(c.hostedDomains) > 0 {
found := false
for _, domain := range c.hostedDomains {
if claims.HostedDomain != domain {
found = true
break
}
}
if !found {
return identity, fmt.Errorf("oidc: unexpected hd claim %v", claims.HostedDomain)
}
}
identity = connector.Identity{