cmd/example-app: fix custom CA behavior
This commit is contained in:
		| @@ -37,8 +37,7 @@ type app struct { | |||||||
| 	// or does it use "access_type=offline" (e.g. Google)? | 	// or does it use "access_type=offline" (e.g. Google)? | ||||||
| 	offlineAsScope bool | 	offlineAsScope bool | ||||||
|  |  | ||||||
| 	ctx    context.Context | 	client *http.Client | ||||||
| 	cancel context.CancelFunc |  | ||||||
| } | } | ||||||
|  |  | ||||||
| // return an HTTP client which trusts the provided root CAs. | // return an HTTP client which trusts the provided root CAs. | ||||||
| @@ -118,31 +117,31 @@ func cmd() *cobra.Command { | |||||||
| 				return fmt.Errorf("parse listen address: %v", err) | 				return fmt.Errorf("parse listen address: %v", err) | ||||||
| 			} | 			} | ||||||
|  |  | ||||||
| 			a.ctx, a.cancel = context.WithCancel(context.Background()) |  | ||||||
|  |  | ||||||
| 			if rootCAs != "" { | 			if rootCAs != "" { | ||||||
| 				client, err := httpClientForRootCAs(rootCAs) | 				client, err := httpClientForRootCAs(rootCAs) | ||||||
| 				if err != nil { | 				if err != nil { | ||||||
| 					return err | 					return err | ||||||
| 				} | 				} | ||||||
|  | 				a.client = client | ||||||
| 				// This sets the OAuth2 client and oidc client. |  | ||||||
| 				a.ctx = context.WithValue(a.ctx, oauth2.HTTPClient, client) |  | ||||||
| 			} | 			} | ||||||
|  |  | ||||||
| 			if debug { | 			if debug { | ||||||
| 				client, ok := a.ctx.Value(oauth2.HTTPClient).(*http.Client) | 				if a.client == nil { | ||||||
| 				if ok { | 					a.client = &http.Client{ | ||||||
| 					client.Transport = debugTransport{client.Transport} |  | ||||||
| 				} else { |  | ||||||
| 					a.ctx = context.WithValue(a.ctx, oauth2.HTTPClient, &http.Client{ |  | ||||||
| 						Transport: debugTransport{http.DefaultTransport}, | 						Transport: debugTransport{http.DefaultTransport}, | ||||||
| 					}) | 					} | ||||||
|  | 				} else { | ||||||
|  | 					a.client.Transport = debugTransport{a.client.Transport} | ||||||
| 				} | 				} | ||||||
| 			} | 			} | ||||||
|  |  | ||||||
|  | 			if a.client == nil { | ||||||
|  | 				a.client = http.DefaultClient | ||||||
|  | 			} | ||||||
|  |  | ||||||
| 			// TODO(ericchiang): Retry with backoff | 			// TODO(ericchiang): Retry with backoff | ||||||
| 			provider, err := oidc.NewProvider(a.ctx, issuerURL) | 			ctx := oidc.ClientContext(context.Background(), a.client) | ||||||
|  | 			provider, err := oidc.NewProvider(ctx, issuerURL) | ||||||
| 			if err != nil { | 			if err != nil { | ||||||
| 				return fmt.Errorf("Failed to query provider %q: %v", issuerURL, err) | 				return fmt.Errorf("Failed to query provider %q: %v", issuerURL, err) | ||||||
| 			} | 			} | ||||||
| @@ -258,6 +257,8 @@ func (a *app) handleCallback(w http.ResponseWriter, r *http.Request) { | |||||||
| 		err   error | 		err   error | ||||||
| 		token *oauth2.Token | 		token *oauth2.Token | ||||||
| 	) | 	) | ||||||
|  |  | ||||||
|  | 	ctx := oidc.ClientContext(r.Context(), a.client) | ||||||
| 	oauth2Config := a.oauth2Config(nil) | 	oauth2Config := a.oauth2Config(nil) | ||||||
| 	switch r.Method { | 	switch r.Method { | ||||||
| 	case "GET": | 	case "GET": | ||||||
| @@ -275,7 +276,7 @@ func (a *app) handleCallback(w http.ResponseWriter, r *http.Request) { | |||||||
| 			http.Error(w, fmt.Sprintf("expected state %q got %q", exampleAppState, state), http.StatusBadRequest) | 			http.Error(w, fmt.Sprintf("expected state %q got %q", exampleAppState, state), http.StatusBadRequest) | ||||||
| 			return | 			return | ||||||
| 		} | 		} | ||||||
| 		token, err = oauth2Config.Exchange(a.ctx, code) | 		token, err = oauth2Config.Exchange(ctx, code) | ||||||
| 	case "POST": | 	case "POST": | ||||||
| 		// Form request from frontend to refresh a token. | 		// Form request from frontend to refresh a token. | ||||||
| 		refresh := r.FormValue("refresh_token") | 		refresh := r.FormValue("refresh_token") | ||||||
| @@ -287,7 +288,7 @@ func (a *app) handleCallback(w http.ResponseWriter, r *http.Request) { | |||||||
| 			RefreshToken: refresh, | 			RefreshToken: refresh, | ||||||
| 			Expiry:       time.Now().Add(-time.Hour), | 			Expiry:       time.Now().Add(-time.Hour), | ||||||
| 		} | 		} | ||||||
| 		token, err = oauth2Config.TokenSource(r.Context(), t).Token() | 		token, err = oauth2Config.TokenSource(ctx, t).Token() | ||||||
| 	default: | 	default: | ||||||
| 		http.Error(w, fmt.Sprintf("method not implemented: %s", r.Method), http.StatusBadRequest) | 		http.Error(w, fmt.Sprintf("method not implemented: %s", r.Method), http.StatusBadRequest) | ||||||
| 		return | 		return | ||||||
|   | |||||||
		Reference in New Issue
	
	Block a user