diff --git a/connector/oidc/oidc.go b/connector/oidc/oidc.go index 1a9462da..cdb3ff55 100644 --- a/connector/oidc/oidc.go +++ b/connector/oidc/oidc.go @@ -205,11 +205,29 @@ func (c *oidcConnector) HandleCallback(s connector.Scopes, r *http.Request) (ide return identity, fmt.Errorf("oidc: failed to get token: %v", err) } + return c.createIdentity(r.Context(), identity, token) +} + +// Refresh is implemented for backwards compatibility, even though it's a no-op. +func (c *oidcConnector) Refresh(ctx context.Context, s connector.Scopes, identity connector.Identity) (connector.Identity, error) { + t := &oauth2.Token{ + RefreshToken: string(identity.ConnectorData), + Expiry: time.Now().Add(-time.Hour), + } + token, err := c.oauth2Config.TokenSource(ctx, t).Token() + if err != nil { + return identity, fmt.Errorf("oidc: failed to get token: %v", err) + } + + return c.createIdentity(ctx, identity, token) +} + +func (c *oidcConnector) createIdentity(ctx context.Context, identity connector.Identity, token *oauth2.Token) (connector.Identity, error) { rawIDToken, ok := token.Extra("id_token").(string) if !ok { return identity, errors.New("oidc: no id_token in token response") } - idToken, err := c.verifier.Verify(r.Context(), rawIDToken) + idToken, err := c.verifier.Verify(ctx, rawIDToken) if err != nil { return identity, fmt.Errorf("oidc: failed to verify ID Token: %v", err) } @@ -221,7 +239,7 @@ func (c *oidcConnector) HandleCallback(s connector.Scopes, r *http.Request) (ide // We immediately want to run getUserInfo if configured before we validate the claims if c.getUserInfo { - userInfo, err := c.provider.UserInfo(r.Context(), oauth2.StaticTokenSource(token)) + userInfo, err := c.provider.UserInfo(ctx, oauth2.StaticTokenSource(token)) if err != nil { return identity, fmt.Errorf("oidc: error loading userinfo: %v", err) } @@ -284,57 +302,3 @@ func (c *oidcConnector) HandleCallback(s connector.Scopes, r *http.Request) (ide return identity, nil } - -// Refresh is implemented for backwards compatibility, even though it's a no-op. -func (c *oidcConnector) Refresh(ctx context.Context, s connector.Scopes, identity connector.Identity) (connector.Identity, error) { - t := &oauth2.Token{ - RefreshToken: string(identity.ConnectorData), - Expiry: time.Now().Add(-time.Hour), - } - token, err := c.oauth2Config.TokenSource(ctx, t).Token() - if err != nil { - return identity, fmt.Errorf("oidc: failed to get token: %v", err) - } - - rawIDToken, ok := token.Extra("id_token").(string) - if !ok { - return identity, errors.New("oidc: no id_token in token response") - } - idToken, err := c.verifier.Verify(ctx, rawIDToken) - if err != nil { - return identity, fmt.Errorf("oidc: failed to verify ID Token: %v", err) - } - - var claims struct { - Username string `json:"name"` - Email string `json:"email"` - EmailVerified bool `json:"email_verified"` - HostedDomain string `json:"hd"` - } - if err := idToken.Claims(&claims); err != nil { - return identity, fmt.Errorf("oidc: failed to decode claims: %v", err) - } - - if len(c.hostedDomains) > 0 { - found := false - for _, domain := range c.hostedDomains { - if claims.HostedDomain == domain { - found = true - break - } - } - - if !found { - return identity, fmt.Errorf("oidc: unexpected hd claim %v", claims.HostedDomain) - } - } - - identity = connector.Identity{ - UserID: idToken.Subject, - Username: claims.Username, - Email: claims.Email, - EmailVerified: claims.EmailVerified, - ConnectorData: []byte(token.RefreshToken), - } - return identity, nil -}