Add support for IDPs that do not send ID tokens in the reply when using a refresh grant. Add tests for the aforementioned functionality.
Signed-off-by: Anthony Brandelli <abrandel@cisco.com>
This commit is contained in:
		| @@ -226,6 +226,13 @@ func (e *oauth2Error) Error() string { | ||||
| 	return e.error + ": " + e.errorDescription | ||||
| } | ||||
|  | ||||
| type caller uint | ||||
|  | ||||
| const ( | ||||
| 	createCaller caller = iota | ||||
| 	refreshCaller | ||||
| ) | ||||
|  | ||||
| func (c *oidcConnector) HandleCallback(s connector.Scopes, r *http.Request) (identity connector.Identity, err error) { | ||||
| 	q := r.URL.Query() | ||||
| 	if errType := q.Get("error"); errType != "" { | ||||
| @@ -235,8 +242,7 @@ func (c *oidcConnector) HandleCallback(s connector.Scopes, r *http.Request) (ide | ||||
| 	if err != nil { | ||||
| 		return identity, fmt.Errorf("oidc: failed to get token: %v", err) | ||||
| 	} | ||||
|  | ||||
| 	return c.createIdentity(r.Context(), identity, token) | ||||
| 	return c.createIdentity(r.Context(), identity, token, createCaller) | ||||
| } | ||||
|  | ||||
| // Refresh is used to refresh a session with the refresh token provided by the IdP | ||||
| @@ -255,24 +261,26 @@ func (c *oidcConnector) Refresh(ctx context.Context, s connector.Scopes, identit | ||||
| 	if err != nil { | ||||
| 		return identity, fmt.Errorf("oidc: failed to get refresh token: %v", err) | ||||
| 	} | ||||
|  | ||||
| 	return c.createIdentity(ctx, identity, token) | ||||
| 	return c.createIdentity(ctx, identity, token, refreshCaller) | ||||
| } | ||||
|  | ||||
| func (c *oidcConnector) createIdentity(ctx context.Context, identity connector.Identity, token *oauth2.Token) (connector.Identity, error) { | ||||
| func (c *oidcConnector) createIdentity(ctx context.Context, identity connector.Identity, token *oauth2.Token, caller caller) (connector.Identity, error) { | ||||
| 	var claims map[string]interface{} | ||||
|  | ||||
| 	rawIDToken, ok := token.Extra("id_token").(string) | ||||
| 	if !ok { | ||||
| 		return identity, errors.New("oidc: no id_token in token response") | ||||
| 	} | ||||
| 	if ok { | ||||
| 		idToken, err := c.verifier.Verify(ctx, rawIDToken) | ||||
| 		if err != nil { | ||||
| 			return identity, fmt.Errorf("oidc: failed to verify ID Token: %v", err) | ||||
| 		} | ||||
|  | ||||
| 	var claims map[string]interface{} | ||||
| 		if err := idToken.Claims(&claims); err != nil { | ||||
| 			return identity, fmt.Errorf("oidc: failed to decode claims: %v", err) | ||||
| 		} | ||||
| 	} else if caller != refreshCaller { | ||||
| 		// ID tokens aren't mandatory in the reply when using a refresh_token grant | ||||
| 		return identity, errors.New("oidc: no id_token in token response") | ||||
| 	} | ||||
|  | ||||
| 	// We immediately want to run getUserInfo if configured before we validate the claims | ||||
| 	if c.getUserInfo { | ||||
| @@ -285,6 +293,12 @@ func (c *oidcConnector) createIdentity(ctx context.Context, identity connector.I | ||||
| 		} | ||||
| 	} | ||||
|  | ||||
| 	const subjectClaimKey = "sub" | ||||
| 	subject, found := claims[subjectClaimKey].(string) | ||||
| 	if !found { | ||||
| 	    return identity, fmt.Errorf("missing \"%s\" claim", subjectClaimKey) | ||||
| 	} | ||||
|  | ||||
| 	userNameKey := "name" | ||||
| 	if c.userNameKey != "" { | ||||
| 		userNameKey = c.userNameKey | ||||
| @@ -358,7 +372,7 @@ func (c *oidcConnector) createIdentity(ctx context.Context, identity connector.I | ||||
| 	} | ||||
|  | ||||
| 	identity = connector.Identity{ | ||||
| 		UserID:            idToken.Subject, | ||||
| 		UserID:            subject, | ||||
| 		Username:          name, | ||||
| 		PreferredUsername: preferredUsername, | ||||
| 		Email:             email, | ||||
|   | ||||
| @@ -275,7 +275,8 @@ func TestHandleCallback(t *testing.T) { | ||||
|  | ||||
| 	for _, tc := range tests { | ||||
| 		t.Run(tc.name, func(t *testing.T) { | ||||
| 			testServer, err := setupServer(tc.token) | ||||
| 			idTokenDesired := true | ||||
| 			testServer, err := setupServer(tc.token, idTokenDesired) | ||||
| 			if err != nil { | ||||
| 				t.Fatal("failed to setup test server", err) | ||||
| 			} | ||||
| @@ -331,7 +332,87 @@ func TestHandleCallback(t *testing.T) { | ||||
| 	} | ||||
| } | ||||
|  | ||||
| func setupServer(tok map[string]interface{}) (*httptest.Server, error) { | ||||
| func TestRefresh(t *testing.T) { | ||||
| 	t.Helper() | ||||
|  | ||||
| 	tests := []struct { | ||||
| 		name           string | ||||
| 		expectUserID   string | ||||
| 		expectUserName string | ||||
| 		idTokenDesired bool | ||||
| 		token          map[string]interface{} | ||||
| 	}{ | ||||
| 		{ | ||||
| 			name:           "IDTokenOnRefresh", | ||||
| 			expectUserID:   "subvalue", | ||||
| 			expectUserName: "namevalue", | ||||
| 			idTokenDesired: true, | ||||
| 			token: map[string]interface{}{ | ||||
| 				"sub":  "subvalue", | ||||
| 				"name": "namevalue", | ||||
| 			}, | ||||
| 		}, | ||||
| 		{ | ||||
| 			name:           "NoIDTokenOnRefresh", | ||||
| 			expectUserID:   "subvalue", | ||||
| 			expectUserName: "namevalue", | ||||
| 			idTokenDesired: false, | ||||
| 			token: map[string]interface{}{ | ||||
| 				"sub":  "subvalue", | ||||
| 				"name": "namevalue", | ||||
| 			}, | ||||
| 		}, | ||||
| 	} | ||||
| 	for _, tc := range tests { | ||||
| 		t.Run(tc.name, func(t *testing.T) { | ||||
| 			testServer, err := setupServer(tc.token, tc.idTokenDesired) | ||||
| 			if err != nil { | ||||
| 				t.Fatal("failed to setup test server", err) | ||||
| 			} | ||||
| 			defer testServer.Close() | ||||
|  | ||||
| 			scopes := []string{"openid", "offline_access"} | ||||
| 			serverURL := testServer.URL | ||||
| 			config := Config{ | ||||
| 				Issuer:       serverURL, | ||||
| 				ClientID:     "clientID", | ||||
| 				ClientSecret: "clientSecret", | ||||
| 				Scopes:       scopes, | ||||
| 				RedirectURI:  fmt.Sprintf("%s/callback", serverURL), | ||||
| 				GetUserInfo:  true, | ||||
| 			} | ||||
|  | ||||
| 			conn, err := newConnector(config) | ||||
| 			if err != nil { | ||||
| 				t.Fatal("failed to create new connector", err) | ||||
| 			} | ||||
|  | ||||
| 			req, err := newRequestWithAuthCode(testServer.URL, "someCode") | ||||
| 			if err != nil { | ||||
| 				t.Fatal("failed to create request", err) | ||||
| 			} | ||||
|  | ||||
| 			refreshTokenStr := "{\"RefreshToken\":\"asdf\"}" | ||||
| 			refreshToken := []byte(refreshTokenStr) | ||||
|  | ||||
| 			identity := connector.Identity{ | ||||
| 				UserID:        tc.expectUserID, | ||||
| 				Username:      tc.expectUserName, | ||||
| 				ConnectorData: refreshToken, | ||||
| 			} | ||||
|  | ||||
| 			refreshIdentity, err := conn.Refresh(req.Context(), connector.Scopes{OfflineAccess: true}, identity) | ||||
| 			if err != nil { | ||||
| 				t.Fatal("Refresh failed", err) | ||||
| 			} | ||||
|  | ||||
| 			expectEquals(t, refreshIdentity.UserID, tc.expectUserID) | ||||
| 			expectEquals(t, refreshIdentity.Username, tc.expectUserName) | ||||
| 		}) | ||||
| 	} | ||||
| } | ||||
|  | ||||
| func setupServer(tok map[string]interface{}, idTokenDesired bool) (*httptest.Server, error) { | ||||
| 	key, err := rsa.GenerateKey(rand.Reader, 1024) | ||||
| 	if err != nil { | ||||
| 		return nil, fmt.Errorf("failed to generate rsa key: %v", err) | ||||
| @@ -368,11 +449,21 @@ func setupServer(tok map[string]interface{}) (*httptest.Server, error) { | ||||
| 		} | ||||
|  | ||||
| 		w.Header().Add("Content-Type", "application/json") | ||||
| 		if idTokenDesired { | ||||
| 			json.NewEncoder(w).Encode(&map[string]string{ | ||||
| 				"access_token": token, | ||||
| 				"id_token":     token, | ||||
| 			"token_type":   "Bearer", | ||||
| 				"token_type":   "Bearer"}) | ||||
| 		} else { | ||||
| 			json.NewEncoder(w).Encode(&map[string]string{ | ||||
| 				"access_token": token, | ||||
| 				"token_type":   "Bearer"}) | ||||
| 		} | ||||
| 	}) | ||||
|  | ||||
| 	mux.HandleFunc("/userinfo", func(w http.ResponseWriter, r *http.Request) { | ||||
| 		w.Header().Add("Content-Type", "application/json") | ||||
| 	        json.NewEncoder(w).Encode(tok) | ||||
| 	}) | ||||
|  | ||||
| 	mux.HandleFunc("/.well-known/openid-configuration", func(w http.ResponseWriter, r *http.Request) { | ||||
|   | ||||
		Reference in New Issue
	
	Block a user