Merge pull request #875 from ericchiang/fix-example-app-custom-ca
cmd/example-app: fix custom CA behavior
This commit is contained in:
		@@ -37,8 +37,7 @@ type app struct {
 | 
				
			|||||||
	// or does it use "access_type=offline" (e.g. Google)?
 | 
						// or does it use "access_type=offline" (e.g. Google)?
 | 
				
			||||||
	offlineAsScope bool
 | 
						offlineAsScope bool
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	ctx    context.Context
 | 
						client *http.Client
 | 
				
			||||||
	cancel context.CancelFunc
 | 
					 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
// return an HTTP client which trusts the provided root CAs.
 | 
					// return an HTTP client which trusts the provided root CAs.
 | 
				
			||||||
@@ -118,31 +117,31 @@ func cmd() *cobra.Command {
 | 
				
			|||||||
				return fmt.Errorf("parse listen address: %v", err)
 | 
									return fmt.Errorf("parse listen address: %v", err)
 | 
				
			||||||
			}
 | 
								}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
			a.ctx, a.cancel = context.WithCancel(context.Background())
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
			if rootCAs != "" {
 | 
								if rootCAs != "" {
 | 
				
			||||||
				client, err := httpClientForRootCAs(rootCAs)
 | 
									client, err := httpClientForRootCAs(rootCAs)
 | 
				
			||||||
				if err != nil {
 | 
									if err != nil {
 | 
				
			||||||
					return err
 | 
										return err
 | 
				
			||||||
				}
 | 
									}
 | 
				
			||||||
 | 
									a.client = client
 | 
				
			||||||
				// This sets the OAuth2 client and oidc client.
 | 
					 | 
				
			||||||
				a.ctx = context.WithValue(a.ctx, oauth2.HTTPClient, client)
 | 
					 | 
				
			||||||
			}
 | 
								}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
			if debug {
 | 
								if debug {
 | 
				
			||||||
				client, ok := a.ctx.Value(oauth2.HTTPClient).(*http.Client)
 | 
									if a.client == nil {
 | 
				
			||||||
				if ok {
 | 
										a.client = &http.Client{
 | 
				
			||||||
					client.Transport = debugTransport{client.Transport}
 | 
					 | 
				
			||||||
				} else {
 | 
					 | 
				
			||||||
					a.ctx = context.WithValue(a.ctx, oauth2.HTTPClient, &http.Client{
 | 
					 | 
				
			||||||
						Transport: debugTransport{http.DefaultTransport},
 | 
											Transport: debugTransport{http.DefaultTransport},
 | 
				
			||||||
					})
 | 
										}
 | 
				
			||||||
 | 
									} else {
 | 
				
			||||||
 | 
										a.client.Transport = debugTransport{a.client.Transport}
 | 
				
			||||||
				}
 | 
									}
 | 
				
			||||||
			}
 | 
								}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
								if a.client == nil {
 | 
				
			||||||
 | 
									a.client = http.DefaultClient
 | 
				
			||||||
 | 
								}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
			// TODO(ericchiang): Retry with backoff
 | 
								// TODO(ericchiang): Retry with backoff
 | 
				
			||||||
			provider, err := oidc.NewProvider(a.ctx, issuerURL)
 | 
								ctx := oidc.ClientContext(context.Background(), a.client)
 | 
				
			||||||
 | 
								provider, err := oidc.NewProvider(ctx, issuerURL)
 | 
				
			||||||
			if err != nil {
 | 
								if err != nil {
 | 
				
			||||||
				return fmt.Errorf("Failed to query provider %q: %v", issuerURL, err)
 | 
									return fmt.Errorf("Failed to query provider %q: %v", issuerURL, err)
 | 
				
			||||||
			}
 | 
								}
 | 
				
			||||||
@@ -258,6 +257,8 @@ func (a *app) handleCallback(w http.ResponseWriter, r *http.Request) {
 | 
				
			|||||||
		err   error
 | 
							err   error
 | 
				
			||||||
		token *oauth2.Token
 | 
							token *oauth2.Token
 | 
				
			||||||
	)
 | 
						)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						ctx := oidc.ClientContext(r.Context(), a.client)
 | 
				
			||||||
	oauth2Config := a.oauth2Config(nil)
 | 
						oauth2Config := a.oauth2Config(nil)
 | 
				
			||||||
	switch r.Method {
 | 
						switch r.Method {
 | 
				
			||||||
	case "GET":
 | 
						case "GET":
 | 
				
			||||||
@@ -275,7 +276,7 @@ func (a *app) handleCallback(w http.ResponseWriter, r *http.Request) {
 | 
				
			|||||||
			http.Error(w, fmt.Sprintf("expected state %q got %q", exampleAppState, state), http.StatusBadRequest)
 | 
								http.Error(w, fmt.Sprintf("expected state %q got %q", exampleAppState, state), http.StatusBadRequest)
 | 
				
			||||||
			return
 | 
								return
 | 
				
			||||||
		}
 | 
							}
 | 
				
			||||||
		token, err = oauth2Config.Exchange(a.ctx, code)
 | 
							token, err = oauth2Config.Exchange(ctx, code)
 | 
				
			||||||
	case "POST":
 | 
						case "POST":
 | 
				
			||||||
		// Form request from frontend to refresh a token.
 | 
							// Form request from frontend to refresh a token.
 | 
				
			||||||
		refresh := r.FormValue("refresh_token")
 | 
							refresh := r.FormValue("refresh_token")
 | 
				
			||||||
@@ -287,7 +288,7 @@ func (a *app) handleCallback(w http.ResponseWriter, r *http.Request) {
 | 
				
			|||||||
			RefreshToken: refresh,
 | 
								RefreshToken: refresh,
 | 
				
			||||||
			Expiry:       time.Now().Add(-time.Hour),
 | 
								Expiry:       time.Now().Add(-time.Hour),
 | 
				
			||||||
		}
 | 
							}
 | 
				
			||||||
		token, err = oauth2Config.TokenSource(r.Context(), t).Token()
 | 
							token, err = oauth2Config.TokenSource(ctx, t).Token()
 | 
				
			||||||
	default:
 | 
						default:
 | 
				
			||||||
		http.Error(w, fmt.Sprintf("method not implemented: %s", r.Method), http.StatusBadRequest)
 | 
							http.Error(w, fmt.Sprintf("method not implemented: %s", r.Method), http.StatusBadRequest)
 | 
				
			||||||
		return
 | 
							return
 | 
				
			||||||
 
 | 
				
			|||||||
		Reference in New Issue
	
	Block a user