server: add tests for refreshing with explicit scopes
This commit is contained in:
		| @@ -538,20 +538,25 @@ func (s *Server) handleRefreshToken(w http.ResponseWriter, r *http.Request, clie | |||||||
| 	scopes := refresh.Scopes | 	scopes := refresh.Scopes | ||||||
| 	if scope != "" { | 	if scope != "" { | ||||||
| 		requestedScopes := strings.Split(scope, " ") | 		requestedScopes := strings.Split(scope, " ") | ||||||
| 		contains := func() bool { | 		var unauthorizedScopes []string | ||||||
| 		Loop: |  | ||||||
| 		for _, s := range requestedScopes { | 		for _, s := range requestedScopes { | ||||||
|  | 			contains := func() bool { | ||||||
| 				for _, scope := range refresh.Scopes { | 				for _, scope := range refresh.Scopes { | ||||||
| 					if s == scope { | 					if s == scope { | ||||||
| 						continue Loop | 						return true | ||||||
| 					} | 					} | ||||||
| 				} | 				} | ||||||
| 				return false | 				return false | ||||||
| 			} |  | ||||||
| 			return true |  | ||||||
| 			}() | 			}() | ||||||
| 			if !contains { | 			if !contains { | ||||||
| 			tokenErr(w, errInvalidRequest, "Requested scopes did not contain authorized scopes.", http.StatusBadRequest) | 				unauthorizedScopes = append(unauthorizedScopes, s) | ||||||
|  | 			} | ||||||
|  | 		} | ||||||
|  |  | ||||||
|  | 		if len(unauthorizedScopes) > 0 { | ||||||
|  | 			msg := fmt.Sprintf("Requested scopes contain unauthorized scope(s): %q.", unauthorizedScopes) | ||||||
|  | 			tokenErr(w, errInvalidRequest, msg, http.StatusBadRequest) | ||||||
| 			return | 			return | ||||||
| 		} | 		} | ||||||
| 		scopes = requestedScopes | 		scopes = requestedScopes | ||||||
|   | |||||||
| @@ -52,6 +52,7 @@ func tokenErr(w http.ResponseWriter, typ, description string, statusCode int) { | |||||||
| 	} | 	} | ||||||
| 	w.Header().Set("Content-Type", "application/json") | 	w.Header().Set("Content-Type", "application/json") | ||||||
| 	w.Header().Set("Content-Length", strconv.Itoa(len(body))) | 	w.Header().Set("Content-Length", strconv.Itoa(len(body))) | ||||||
|  | 	w.WriteHeader(statusCode) | ||||||
| 	w.Write(body) | 	w.Write(body) | ||||||
| } | } | ||||||
|  |  | ||||||
|   | |||||||
| @@ -131,6 +131,99 @@ func TestDiscovery(t *testing.T) { | |||||||
| } | } | ||||||
|  |  | ||||||
| func TestOAuth2CodeFlow(t *testing.T) { | func TestOAuth2CodeFlow(t *testing.T) { | ||||||
|  | 	clientID := "testclient" | ||||||
|  | 	clientSecret := "testclientsecret" | ||||||
|  | 	requestedScopes := []string{oidc.ScopeOpenID, "email", "offline_access"} | ||||||
|  |  | ||||||
|  | 	tests := []struct { | ||||||
|  | 		name        string | ||||||
|  | 		handleToken func(context.Context, *oidc.Provider, *oauth2.Config, *oauth2.Token) error | ||||||
|  | 	}{ | ||||||
|  | 		{ | ||||||
|  | 			name: "verify ID Token", | ||||||
|  | 			handleToken: func(ctx context.Context, p *oidc.Provider, config *oauth2.Config, token *oauth2.Token) error { | ||||||
|  | 				idToken, ok := token.Extra("id_token").(string) | ||||||
|  | 				if !ok { | ||||||
|  | 					return fmt.Errorf("no id token found") | ||||||
|  | 				} | ||||||
|  | 				if _, err := p.NewVerifier(ctx).Verify(idToken); err != nil { | ||||||
|  | 					return fmt.Errorf("failed to verify id token: %v", err) | ||||||
|  | 				} | ||||||
|  | 				return nil | ||||||
|  | 			}, | ||||||
|  | 		}, | ||||||
|  | 		{ | ||||||
|  | 			name: "refresh token", | ||||||
|  | 			handleToken: func(ctx context.Context, p *oidc.Provider, config *oauth2.Config, token *oauth2.Token) error { | ||||||
|  | 				// have to use time.Now because the OAuth2 package uses it. | ||||||
|  | 				token.Expiry = time.Now().Add(time.Second * -10) | ||||||
|  | 				if token.Valid() { | ||||||
|  | 					return errors.New("token shouldn't be valid") | ||||||
|  | 				} | ||||||
|  |  | ||||||
|  | 				newToken, err := config.TokenSource(ctx, token).Token() | ||||||
|  | 				if err != nil { | ||||||
|  | 					return fmt.Errorf("failed to refresh token: %v", err) | ||||||
|  | 				} | ||||||
|  | 				if token.RefreshToken == newToken.RefreshToken { | ||||||
|  | 					return fmt.Errorf("old refresh token was the same as the new token %q", token.RefreshToken) | ||||||
|  | 				} | ||||||
|  | 				return nil | ||||||
|  | 			}, | ||||||
|  | 		}, | ||||||
|  | 		{ | ||||||
|  | 			name: "refresh with explicit scopes", | ||||||
|  | 			handleToken: func(ctx context.Context, p *oidc.Provider, config *oauth2.Config, token *oauth2.Token) error { | ||||||
|  | 				v := url.Values{} | ||||||
|  | 				v.Add("client_id", clientID) | ||||||
|  | 				v.Add("client_secret", clientSecret) | ||||||
|  | 				v.Add("grant_type", "refresh_token") | ||||||
|  | 				v.Add("refresh_token", token.RefreshToken) | ||||||
|  | 				v.Add("scope", strings.Join(requestedScopes, " ")) | ||||||
|  | 				resp, err := http.PostForm(p.TokenURL, v) | ||||||
|  | 				if err != nil { | ||||||
|  | 					return err | ||||||
|  | 				} | ||||||
|  | 				defer resp.Body.Close() | ||||||
|  | 				if resp.StatusCode != http.StatusOK { | ||||||
|  | 					dump, err := httputil.DumpResponse(resp, true) | ||||||
|  | 					if err != nil { | ||||||
|  | 						panic(err) | ||||||
|  | 					} | ||||||
|  | 					return fmt.Errorf("unexpected response: %s", dump) | ||||||
|  | 				} | ||||||
|  | 				return nil | ||||||
|  | 			}, | ||||||
|  | 		}, | ||||||
|  | 		{ | ||||||
|  | 			name: "refresh with unauthorized scopes", | ||||||
|  | 			handleToken: func(ctx context.Context, p *oidc.Provider, config *oauth2.Config, token *oauth2.Token) error { | ||||||
|  | 				v := url.Values{} | ||||||
|  | 				v.Add("client_id", clientID) | ||||||
|  | 				v.Add("client_secret", clientSecret) | ||||||
|  | 				v.Add("grant_type", "refresh_token") | ||||||
|  | 				v.Add("refresh_token", token.RefreshToken) | ||||||
|  | 				// Request a scope that wasn't requestd initially. | ||||||
|  | 				v.Add("scope", strings.Join(append(requestedScopes, "profile"), " ")) | ||||||
|  | 				resp, err := http.PostForm(p.TokenURL, v) | ||||||
|  | 				if err != nil { | ||||||
|  | 					return err | ||||||
|  | 				} | ||||||
|  | 				defer resp.Body.Close() | ||||||
|  | 				if resp.StatusCode == http.StatusOK { | ||||||
|  | 					dump, err := httputil.DumpResponse(resp, true) | ||||||
|  | 					if err != nil { | ||||||
|  | 						panic(err) | ||||||
|  | 					} | ||||||
|  | 					return fmt.Errorf("unexpected response: %s", dump) | ||||||
|  | 				} | ||||||
|  | 				return nil | ||||||
|  | 			}, | ||||||
|  | 		}, | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	for _, tc := range tests { | ||||||
|  | 		func() { | ||||||
| 			ctx, cancel := context.WithCancel(context.Background()) | 			ctx, cancel := context.WithCancel(context.Background()) | ||||||
| 			defer cancel() | 			defer cancel() | ||||||
|  |  | ||||||
| @@ -176,27 +269,12 @@ func TestOAuth2CodeFlow(t *testing.T) { | |||||||
| 							t.Errorf("failed to exchange code for token: %v", err) | 							t.Errorf("failed to exchange code for token: %v", err) | ||||||
| 							return | 							return | ||||||
| 						} | 						} | ||||||
| 				idToken, ok := token.Extra("id_token").(string) | 						err = tc.handleToken(ctx, p, oauth2Config, token) | ||||||
| 				if !ok { |  | ||||||
| 					t.Errorf("no id token found: %v", err) |  | ||||||
| 					return |  | ||||||
| 				} |  | ||||||
| 				// TODO(ericchiang): validate id token |  | ||||||
| 				_ = idToken |  | ||||||
|  |  | ||||||
| 				token.Expiry = time.Now().Add(time.Second * -10) |  | ||||||
| 				if token.Valid() { |  | ||||||
| 					t.Errorf("token shouldn't be valid") |  | ||||||
| 				} |  | ||||||
|  |  | ||||||
| 				newToken, err := oauth2Config.TokenSource(ctx, token).Token() |  | ||||||
| 						if err != nil { | 						if err != nil { | ||||||
| 					t.Errorf("failed to refresh token: %v", err) | 							t.Errorf("%s: %v", tc.name, err) | ||||||
|  | 						} | ||||||
| 						return | 						return | ||||||
| 				} |  | ||||||
| 				if token.RefreshToken == newToken.RefreshToken { |  | ||||||
| 					t.Errorf("old refresh token was the same as the new token %q", token.RefreshToken) |  | ||||||
| 				} |  | ||||||
| 					} | 					} | ||||||
| 					if gotState := q.Get("state"); gotState != state { | 					if gotState := q.Get("state"); gotState != state { | ||||||
| 						t.Errorf("state did not match, want=%q got=%q", state, gotState) | 						t.Errorf("state did not match, want=%q got=%q", state, gotState) | ||||||
| @@ -211,8 +289,8 @@ func TestOAuth2CodeFlow(t *testing.T) { | |||||||
|  |  | ||||||
| 			redirectURL := oauth2Server.URL + "/callback" | 			redirectURL := oauth2Server.URL + "/callback" | ||||||
| 			client := storage.Client{ | 			client := storage.Client{ | ||||||
| 		ID:           "testclient", | 				ID:           clientID, | ||||||
| 		Secret:       "testclientsecret", | 				Secret:       clientSecret, | ||||||
| 				RedirectURIs: []string{redirectURL}, | 				RedirectURIs: []string{redirectURL}, | ||||||
| 			} | 			} | ||||||
| 			if err := s.storage.CreateClient(client); err != nil { | 			if err := s.storage.CreateClient(client); err != nil { | ||||||
| @@ -223,7 +301,7 @@ func TestOAuth2CodeFlow(t *testing.T) { | |||||||
| 				ClientID:     client.ID, | 				ClientID:     client.ID, | ||||||
| 				ClientSecret: client.Secret, | 				ClientSecret: client.Secret, | ||||||
| 				Endpoint:     p.Endpoint(), | 				Endpoint:     p.Endpoint(), | ||||||
| 		Scopes:       []string{oidc.ScopeOpenID, "profile", "email", "offline_access"}, | 				Scopes:       requestedScopes, | ||||||
| 				RedirectURL:  redirectURL, | 				RedirectURL:  redirectURL, | ||||||
| 			} | 			} | ||||||
|  |  | ||||||
| @@ -237,6 +315,8 @@ func TestOAuth2CodeFlow(t *testing.T) { | |||||||
| 			if respDump, err = httputil.DumpResponse(resp, true); err != nil { | 			if respDump, err = httputil.DumpResponse(resp, true); err != nil { | ||||||
| 				t.Fatal(err) | 				t.Fatal(err) | ||||||
| 			} | 			} | ||||||
|  | 		}() | ||||||
|  | 	} | ||||||
| } | } | ||||||
|  |  | ||||||
| type nonceSource struct { | type nonceSource struct { | ||||||
|   | |||||||
		Reference in New Issue
	
	Block a user