Merge pull request #595 from ericchiang/dev-example-app-fix-refreshing-with-google
dev branch: check if a provider supports a refresh token scope
This commit is contained in:
		| @@ -30,6 +30,10 @@ type app struct { | |||||||
| 	verifier *oidc.IDTokenVerifier | 	verifier *oidc.IDTokenVerifier | ||||||
| 	provider *oidc.Provider | 	provider *oidc.Provider | ||||||
|  |  | ||||||
|  | 	// Does the provider use "offline_access" scope to request a refresh token | ||||||
|  | 	// or does it use "access_type=offline" (e.g. Google)? | ||||||
|  | 	offlineAsScope bool | ||||||
|  |  | ||||||
| 	ctx    context.Context | 	ctx    context.Context | ||||||
| 	cancel context.CancelFunc | 	cancel context.CancelFunc | ||||||
| } | } | ||||||
| @@ -102,6 +106,34 @@ func cmd() *cobra.Command { | |||||||
| 			if err != nil { | 			if err != nil { | ||||||
| 				return fmt.Errorf("Failed to query provider %q: %v", issuerURL, err) | 				return fmt.Errorf("Failed to query provider %q: %v", issuerURL, err) | ||||||
| 			} | 			} | ||||||
|  |  | ||||||
|  | 			var s struct { | ||||||
|  | 				// What scopes does a provider support? | ||||||
|  | 				// | ||||||
|  | 				// See: https://openid.net/specs/openid-connect-discovery-1_0.html#ProviderMetadata | ||||||
|  | 				ScopesSupported []string `json:"scopes_supported"` | ||||||
|  | 			} | ||||||
|  | 			if err := provider.Claims(&s); err != nil { | ||||||
|  | 				return fmt.Errorf("Failed to parse provider scopes_supported: %v", err) | ||||||
|  | 			} | ||||||
|  |  | ||||||
|  | 			if len(s.ScopesSupported) == 0 { | ||||||
|  | 				// scopes_supported is a "RECOMMENDED" discovery claim, not a required | ||||||
|  | 				// one. If missing, assume that the provider follows the spec and has | ||||||
|  | 				// an "offline_access" scope. | ||||||
|  | 				a.offlineAsScope = true | ||||||
|  | 			} else { | ||||||
|  | 				// See if scopes_supported has the "offline_access" scope. | ||||||
|  | 				a.offlineAsScope = func() bool { | ||||||
|  | 					for _, scope := range s.ScopesSupported { | ||||||
|  | 						if scope == oidc.ScopeOfflineAccess { | ||||||
|  | 							return true | ||||||
|  | 						} | ||||||
|  | 					} | ||||||
|  | 					return false | ||||||
|  | 				}() | ||||||
|  | 			} | ||||||
|  |  | ||||||
| 			a.provider = provider | 			a.provider = provider | ||||||
| 			a.verifier = provider.NewVerifier(a.ctx, oidc.VerifyAudience(a.clientID)) | 			a.verifier = provider.NewVerifier(a.ctx, oidc.VerifyAudience(a.clientID)) | ||||||
|  |  | ||||||
| @@ -166,10 +198,15 @@ func (a *app) handleLogin(w http.ResponseWriter, r *http.Request) { | |||||||
| 		scopes = append(scopes, "audience:server:client_id:"+client) | 		scopes = append(scopes, "audience:server:client_id:"+client) | ||||||
| 	} | 	} | ||||||
|  |  | ||||||
| 	// TODO(ericchiang): Determine if provider does not support "offline_access" or has | 	authCodeURL := "" | ||||||
| 	// some other mechanism for requesting refresh tokens. | 	scopes = append(scopes, "openid", "profile", "email") | ||||||
| 	scopes = append(scopes, "openid", "profile", "email", "offline_access") | 	if a.offlineAsScope { | ||||||
| 	http.Redirect(w, r, a.oauth2Config(scopes).AuthCodeURL(""), http.StatusSeeOther) | 		scopes = append(scopes, "offline_access") | ||||||
|  | 		authCodeURL = a.oauth2Config(scopes).AuthCodeURL("") | ||||||
|  | 	} else { | ||||||
|  | 		authCodeURL = a.oauth2Config(scopes).AuthCodeURL("", oauth2.AccessTypeOffline) | ||||||
|  | 	} | ||||||
|  | 	http.Redirect(w, r, authCodeURL, http.StatusSeeOther) | ||||||
| } | } | ||||||
|  |  | ||||||
| func (a *app) handleCallback(w http.ResponseWriter, r *http.Request) { | func (a *app) handleCallback(w http.ResponseWriter, r *http.Request) { | ||||||
| @@ -195,7 +232,7 @@ func (a *app) handleCallback(w http.ResponseWriter, r *http.Request) { | |||||||
| 		} | 		} | ||||||
| 		token, err = oauth2Config.TokenSource(a.ctx, t).Token() | 		token, err = oauth2Config.TokenSource(a.ctx, t).Token() | ||||||
| 	default: | 	default: | ||||||
| 		http.Error(w, "no code in request", http.StatusBadRequest) | 		http.Error(w, fmt.Sprintf("no code in request: %q", r.Form), http.StatusBadRequest) | ||||||
| 		return | 		return | ||||||
| 	} | 	} | ||||||
|  |  | ||||||
|   | |||||||
		Reference in New Issue
	
	Block a user