Add support for IDPs that do not send ID tokens in the reply when using a refresh grant. Add tests for the aforementioned functionality.
Signed-off-by: Anthony Brandelli <abrandel@cisco.com>
This commit is contained in:
		| @@ -226,6 +226,13 @@ func (e *oauth2Error) Error() string { | |||||||
| 	return e.error + ": " + e.errorDescription | 	return e.error + ": " + e.errorDescription | ||||||
| } | } | ||||||
|  |  | ||||||
|  | type caller uint | ||||||
|  |  | ||||||
|  | const ( | ||||||
|  | 	createCaller caller = iota | ||||||
|  | 	refreshCaller | ||||||
|  | ) | ||||||
|  |  | ||||||
| func (c *oidcConnector) HandleCallback(s connector.Scopes, r *http.Request) (identity connector.Identity, err error) { | func (c *oidcConnector) HandleCallback(s connector.Scopes, r *http.Request) (identity connector.Identity, err error) { | ||||||
| 	q := r.URL.Query() | 	q := r.URL.Query() | ||||||
| 	if errType := q.Get("error"); errType != "" { | 	if errType := q.Get("error"); errType != "" { | ||||||
| @@ -235,8 +242,7 @@ func (c *oidcConnector) HandleCallback(s connector.Scopes, r *http.Request) (ide | |||||||
| 	if err != nil { | 	if err != nil { | ||||||
| 		return identity, fmt.Errorf("oidc: failed to get token: %v", err) | 		return identity, fmt.Errorf("oidc: failed to get token: %v", err) | ||||||
| 	} | 	} | ||||||
|  | 	return c.createIdentity(r.Context(), identity, token, createCaller) | ||||||
| 	return c.createIdentity(r.Context(), identity, token) |  | ||||||
| } | } | ||||||
|  |  | ||||||
| // Refresh is used to refresh a session with the refresh token provided by the IdP | // Refresh is used to refresh a session with the refresh token provided by the IdP | ||||||
| @@ -255,23 +261,25 @@ func (c *oidcConnector) Refresh(ctx context.Context, s connector.Scopes, identit | |||||||
| 	if err != nil { | 	if err != nil { | ||||||
| 		return identity, fmt.Errorf("oidc: failed to get refresh token: %v", err) | 		return identity, fmt.Errorf("oidc: failed to get refresh token: %v", err) | ||||||
| 	} | 	} | ||||||
|  | 	return c.createIdentity(ctx, identity, token, refreshCaller) | ||||||
| 	return c.createIdentity(ctx, identity, token) |  | ||||||
| } | } | ||||||
|  |  | ||||||
| func (c *oidcConnector) createIdentity(ctx context.Context, identity connector.Identity, token *oauth2.Token) (connector.Identity, error) { | func (c *oidcConnector) createIdentity(ctx context.Context, identity connector.Identity, token *oauth2.Token, caller caller) (connector.Identity, error) { | ||||||
| 	rawIDToken, ok := token.Extra("id_token").(string) |  | ||||||
| 	if !ok { |  | ||||||
| 		return identity, errors.New("oidc: no id_token in token response") |  | ||||||
| 	} |  | ||||||
| 	idToken, err := c.verifier.Verify(ctx, rawIDToken) |  | ||||||
| 	if err != nil { |  | ||||||
| 		return identity, fmt.Errorf("oidc: failed to verify ID Token: %v", err) |  | ||||||
| 	} |  | ||||||
|  |  | ||||||
| 	var claims map[string]interface{} | 	var claims map[string]interface{} | ||||||
| 	if err := idToken.Claims(&claims); err != nil { |  | ||||||
| 		return identity, fmt.Errorf("oidc: failed to decode claims: %v", err) | 	rawIDToken, ok := token.Extra("id_token").(string) | ||||||
|  | 	if ok { | ||||||
|  | 		idToken, err := c.verifier.Verify(ctx, rawIDToken) | ||||||
|  | 		if err != nil { | ||||||
|  | 			return identity, fmt.Errorf("oidc: failed to verify ID Token: %v", err) | ||||||
|  | 		} | ||||||
|  |  | ||||||
|  | 		if err := idToken.Claims(&claims); err != nil { | ||||||
|  | 			return identity, fmt.Errorf("oidc: failed to decode claims: %v", err) | ||||||
|  | 		} | ||||||
|  | 	} else if caller != refreshCaller { | ||||||
|  | 		// ID tokens aren't mandatory in the reply when using a refresh_token grant | ||||||
|  | 		return identity, errors.New("oidc: no id_token in token response") | ||||||
| 	} | 	} | ||||||
|  |  | ||||||
| 	// We immediately want to run getUserInfo if configured before we validate the claims | 	// We immediately want to run getUserInfo if configured before we validate the claims | ||||||
| @@ -285,6 +293,12 @@ func (c *oidcConnector) createIdentity(ctx context.Context, identity connector.I | |||||||
| 		} | 		} | ||||||
| 	} | 	} | ||||||
|  |  | ||||||
|  | 	const subjectClaimKey = "sub" | ||||||
|  | 	subject, found := claims[subjectClaimKey].(string) | ||||||
|  | 	if !found { | ||||||
|  | 	    return identity, fmt.Errorf("missing \"%s\" claim", subjectClaimKey) | ||||||
|  | 	} | ||||||
|  |  | ||||||
| 	userNameKey := "name" | 	userNameKey := "name" | ||||||
| 	if c.userNameKey != "" { | 	if c.userNameKey != "" { | ||||||
| 		userNameKey = c.userNameKey | 		userNameKey = c.userNameKey | ||||||
| @@ -358,7 +372,7 @@ func (c *oidcConnector) createIdentity(ctx context.Context, identity connector.I | |||||||
| 	} | 	} | ||||||
|  |  | ||||||
| 	identity = connector.Identity{ | 	identity = connector.Identity{ | ||||||
| 		UserID:            idToken.Subject, | 		UserID:            subject, | ||||||
| 		Username:          name, | 		Username:          name, | ||||||
| 		PreferredUsername: preferredUsername, | 		PreferredUsername: preferredUsername, | ||||||
| 		Email:             email, | 		Email:             email, | ||||||
|   | |||||||
| @@ -275,7 +275,8 @@ func TestHandleCallback(t *testing.T) { | |||||||
|  |  | ||||||
| 	for _, tc := range tests { | 	for _, tc := range tests { | ||||||
| 		t.Run(tc.name, func(t *testing.T) { | 		t.Run(tc.name, func(t *testing.T) { | ||||||
| 			testServer, err := setupServer(tc.token) | 			idTokenDesired := true | ||||||
|  | 			testServer, err := setupServer(tc.token, idTokenDesired) | ||||||
| 			if err != nil { | 			if err != nil { | ||||||
| 				t.Fatal("failed to setup test server", err) | 				t.Fatal("failed to setup test server", err) | ||||||
| 			} | 			} | ||||||
| @@ -331,7 +332,87 @@ func TestHandleCallback(t *testing.T) { | |||||||
| 	} | 	} | ||||||
| } | } | ||||||
|  |  | ||||||
| func setupServer(tok map[string]interface{}) (*httptest.Server, error) { | func TestRefresh(t *testing.T) { | ||||||
|  | 	t.Helper() | ||||||
|  |  | ||||||
|  | 	tests := []struct { | ||||||
|  | 		name           string | ||||||
|  | 		expectUserID   string | ||||||
|  | 		expectUserName string | ||||||
|  | 		idTokenDesired bool | ||||||
|  | 		token          map[string]interface{} | ||||||
|  | 	}{ | ||||||
|  | 		{ | ||||||
|  | 			name:           "IDTokenOnRefresh", | ||||||
|  | 			expectUserID:   "subvalue", | ||||||
|  | 			expectUserName: "namevalue", | ||||||
|  | 			idTokenDesired: true, | ||||||
|  | 			token: map[string]interface{}{ | ||||||
|  | 				"sub":  "subvalue", | ||||||
|  | 				"name": "namevalue", | ||||||
|  | 			}, | ||||||
|  | 		}, | ||||||
|  | 		{ | ||||||
|  | 			name:           "NoIDTokenOnRefresh", | ||||||
|  | 			expectUserID:   "subvalue", | ||||||
|  | 			expectUserName: "namevalue", | ||||||
|  | 			idTokenDesired: false, | ||||||
|  | 			token: map[string]interface{}{ | ||||||
|  | 				"sub":  "subvalue", | ||||||
|  | 				"name": "namevalue", | ||||||
|  | 			}, | ||||||
|  | 		}, | ||||||
|  | 	} | ||||||
|  | 	for _, tc := range tests { | ||||||
|  | 		t.Run(tc.name, func(t *testing.T) { | ||||||
|  | 			testServer, err := setupServer(tc.token, tc.idTokenDesired) | ||||||
|  | 			if err != nil { | ||||||
|  | 				t.Fatal("failed to setup test server", err) | ||||||
|  | 			} | ||||||
|  | 			defer testServer.Close() | ||||||
|  |  | ||||||
|  | 			scopes := []string{"openid", "offline_access"} | ||||||
|  | 			serverURL := testServer.URL | ||||||
|  | 			config := Config{ | ||||||
|  | 				Issuer:       serverURL, | ||||||
|  | 				ClientID:     "clientID", | ||||||
|  | 				ClientSecret: "clientSecret", | ||||||
|  | 				Scopes:       scopes, | ||||||
|  | 				RedirectURI:  fmt.Sprintf("%s/callback", serverURL), | ||||||
|  | 				GetUserInfo:  true, | ||||||
|  | 			} | ||||||
|  |  | ||||||
|  | 			conn, err := newConnector(config) | ||||||
|  | 			if err != nil { | ||||||
|  | 				t.Fatal("failed to create new connector", err) | ||||||
|  | 			} | ||||||
|  |  | ||||||
|  | 			req, err := newRequestWithAuthCode(testServer.URL, "someCode") | ||||||
|  | 			if err != nil { | ||||||
|  | 				t.Fatal("failed to create request", err) | ||||||
|  | 			} | ||||||
|  |  | ||||||
|  | 			refreshTokenStr := "{\"RefreshToken\":\"asdf\"}" | ||||||
|  | 			refreshToken := []byte(refreshTokenStr) | ||||||
|  |  | ||||||
|  | 			identity := connector.Identity{ | ||||||
|  | 				UserID:        tc.expectUserID, | ||||||
|  | 				Username:      tc.expectUserName, | ||||||
|  | 				ConnectorData: refreshToken, | ||||||
|  | 			} | ||||||
|  |  | ||||||
|  | 			refreshIdentity, err := conn.Refresh(req.Context(), connector.Scopes{OfflineAccess: true}, identity) | ||||||
|  | 			if err != nil { | ||||||
|  | 				t.Fatal("Refresh failed", err) | ||||||
|  | 			} | ||||||
|  |  | ||||||
|  | 			expectEquals(t, refreshIdentity.UserID, tc.expectUserID) | ||||||
|  | 			expectEquals(t, refreshIdentity.Username, tc.expectUserName) | ||||||
|  | 		}) | ||||||
|  | 	} | ||||||
|  | } | ||||||
|  |  | ||||||
|  | func setupServer(tok map[string]interface{}, idTokenDesired bool) (*httptest.Server, error) { | ||||||
| 	key, err := rsa.GenerateKey(rand.Reader, 1024) | 	key, err := rsa.GenerateKey(rand.Reader, 1024) | ||||||
| 	if err != nil { | 	if err != nil { | ||||||
| 		return nil, fmt.Errorf("failed to generate rsa key: %v", err) | 		return nil, fmt.Errorf("failed to generate rsa key: %v", err) | ||||||
| @@ -368,11 +449,21 @@ func setupServer(tok map[string]interface{}) (*httptest.Server, error) { | |||||||
| 		} | 		} | ||||||
|  |  | ||||||
| 		w.Header().Add("Content-Type", "application/json") | 		w.Header().Add("Content-Type", "application/json") | ||||||
| 		json.NewEncoder(w).Encode(&map[string]string{ | 		if idTokenDesired { | ||||||
| 			"access_token": token, | 			json.NewEncoder(w).Encode(&map[string]string{ | ||||||
| 			"id_token":     token, | 				"access_token": token, | ||||||
| 			"token_type":   "Bearer", | 				"id_token":     token, | ||||||
| 		}) | 				"token_type":   "Bearer"}) | ||||||
|  | 		} else { | ||||||
|  | 			json.NewEncoder(w).Encode(&map[string]string{ | ||||||
|  | 				"access_token": token, | ||||||
|  | 				"token_type":   "Bearer"}) | ||||||
|  | 		} | ||||||
|  | 	}) | ||||||
|  |  | ||||||
|  | 	mux.HandleFunc("/userinfo", func(w http.ResponseWriter, r *http.Request) { | ||||||
|  | 		w.Header().Add("Content-Type", "application/json") | ||||||
|  | 	        json.NewEncoder(w).Encode(tok) | ||||||
| 	}) | 	}) | ||||||
|  |  | ||||||
| 	mux.HandleFunc("/.well-known/openid-configuration", func(w http.ResponseWriter, r *http.Request) { | 	mux.HandleFunc("/.well-known/openid-configuration", func(w http.ResponseWriter, r *http.Request) { | ||||||
|   | |||||||
		Reference in New Issue
	
	Block a user