diff --git a/cmd/example-app/main.go b/cmd/example-app/main.go index 21b025d2..a21b2e86 100644 --- a/cmd/example-app/main.go +++ b/cmd/example-app/main.go @@ -237,6 +237,10 @@ func (a *app) handleLogin(w http.ResponseWriter, r *http.Request) { for _, client := range clients { scopes = append(scopes, "audience:server:client_id:"+client) } + connectorID := "" + if id := r.FormValue("connector_id"); id != "" { + connectorID = id + } authCodeURL := "" scopes = append(scopes, "openid", "profile", "email") @@ -248,6 +252,9 @@ func (a *app) handleLogin(w http.ResponseWriter, r *http.Request) { } else { authCodeURL = a.oauth2Config(scopes).AuthCodeURL(exampleAppState, oauth2.AccessTypeOffline) } + if connectorID != "" { + authCodeURL = authCodeURL + "&connector_id=" + connectorID + } http.Redirect(w, r, authCodeURL, http.StatusSeeOther) } @@ -307,7 +314,7 @@ func (a *app) handleCallback(w http.ResponseWriter, r *http.Request) { idToken, err := a.verifier.Verify(r.Context(), rawIDToken) if err != nil { - http.Error(w, fmt.Sprintf("Failed to verify ID token: %v", err), http.StatusInternalServerError) + http.Error(w, fmt.Sprintf("failed to verify ID token: %v", err), http.StatusInternalServerError) return } @@ -318,10 +325,16 @@ func (a *app) handleCallback(w http.ResponseWriter, r *http.Request) { } var claims json.RawMessage - idToken.Claims(&claims) + if err := idToken.Claims(&claims); err != nil { + http.Error(w, fmt.Sprintf("error decoding ID token claims: %v", err), http.StatusInternalServerError) + return + } buff := new(bytes.Buffer) - json.Indent(buff, []byte(claims), "", " ") + if err := json.Indent(buff, []byte(claims), "", " "); err != nil { + http.Error(w, fmt.Sprintf("error indenting ID token claims: %v", err), http.StatusInternalServerError) + return + } - renderToken(w, a.redirectURI, rawIDToken, accessToken, token.RefreshToken, buff.Bytes()) + renderToken(w, a.redirectURI, rawIDToken, accessToken, token.RefreshToken, buff.String()) } diff --git a/cmd/example-app/templates.go b/cmd/example-app/templates.go index 3a4b5bb2..497eb8a7 100644 --- a/cmd/example-app/templates.go +++ b/cmd/example-app/templates.go @@ -15,8 +15,11 @@ var indexTmpl = template.Must(template.New("index.html").Parse(`
Extra scopes:
-- Request offline access: +
+ Connector ID: +
++ Request offline access:
@@ -63,13 +66,13 @@ pre { `)) -func renderToken(w http.ResponseWriter, redirectURL, idToken, accessToken, refreshToken string, claims []byte) { +func renderToken(w http.ResponseWriter, redirectURL, idToken, accessToken, refreshToken, claims string) { renderTemplate(w, tokenTmpl, tokenTmplData{ IDToken: idToken, AccessToken: accessToken, RefreshToken: refreshToken, RedirectURL: redirectURL, - Claims: string(claims), + Claims: claims, }) } diff --git a/server/handlers.go b/server/handlers.go index 70ef1321..39b98423 100644 --- a/server/handlers.go +++ b/server/handlers.go @@ -200,17 +200,21 @@ func (s *Server) handleAuthorization(w http.ResponseWriter, r *http.Request) { authReq, err := s.parseAuthorizationRequest(r) if err != nil { s.logger.Errorf("Failed to parse authorization request: %v", err) - if handler, ok := err.Handle(); ok { - // client_id and redirect_uri checked out and we can redirect back to - // the client with the error. - handler.ServeHTTP(w, r) - return + status := http.StatusInternalServerError + + // If this is an authErr, let's let it handle the error, or update the HTTP + // status code + if err, ok := err.(*authErr); ok { + if handler, ok := err.Handle(); ok { + // client_id and redirect_uri checked out and we can redirect back to + // the client with the error. + handler.ServeHTTP(w, r) + return + } + status = err.Status() } - // Otherwise render the error to the user. - // - // TODO(ericchiang): Should we just always render the error? - s.renderError(w, err.Status(), err.Error()) + s.renderError(w, status, err.Error()) return } @@ -220,15 +224,15 @@ func (s *Server) handleAuthorization(w http.ResponseWriter, r *http.Request) { // // See: https://github.com/dexidp/dex/issues/646 authReq.Expiry = s.now().Add(s.authRequestsValidFor) - if err := s.storage.CreateAuthRequest(authReq); err != nil { + if err := s.storage.CreateAuthRequest(*authReq); err != nil { s.logger.Errorf("Failed to create authorization request: %v", err) s.renderError(w, http.StatusInternalServerError, "Failed to connect to the database.") return } - connectors, e := s.storage.ListConnectors() - if e != nil { - s.logger.Errorf("Failed to get list of connectors: %v", e) + connectors, err := s.storage.ListConnectors() + if err != nil { + s.logger.Errorf("Failed to get list of connectors: %v", err) s.renderError(w, http.StatusInternalServerError, "Failed to retrieve connector list.") return } diff --git a/server/oauth2.go b/server/oauth2.go index 5b7a421c..79c4bf1a 100644 --- a/server/oauth2.go +++ b/server/oauth2.go @@ -379,14 +379,14 @@ func (s *Server) newIDToken(clientID string, claims storage.Claims, scopes []str } // parse the initial request from the OAuth2 client. -func (s *Server) parseAuthorizationRequest(r *http.Request) (req storage.AuthRequest, oauth2Err *authErr) { +func (s *Server) parseAuthorizationRequest(r *http.Request) (*storage.AuthRequest, error) { if err := r.ParseForm(); err != nil { - return req, &authErr{"", "", errInvalidRequest, "Failed to parse request body."} + return nil, &authErr{"", "", errInvalidRequest, "Failed to parse request body."} } q := r.Form redirectURI, err := url.QueryUnescape(q.Get("redirect_uri")) if err != nil { - return req, &authErr{"", "", errInvalidRequest, "No redirect_uri provided."} + return nil, &authErr{"", "", errInvalidRequest, "No redirect_uri provided."} } clientID := q.Get("client_id") @@ -401,25 +401,25 @@ func (s *Server) parseAuthorizationRequest(r *http.Request) (req storage.AuthReq if err != nil { if err == storage.ErrNotFound { description := fmt.Sprintf("Invalid client_id (%q).", clientID) - return req, &authErr{"", "", errUnauthorizedClient, description} + return nil, &authErr{"", "", errUnauthorizedClient, description} } s.logger.Errorf("Failed to get client: %v", err) - return req, &authErr{"", "", errServerError, ""} + return nil, &authErr{"", "", errServerError, ""} } if connectorID != "" { connectors, err := s.storage.ListConnectors() if err != nil { - return req, &authErr{"", "", errServerError, "Unable to retrieve connectors"} + return nil, &authErr{"", "", errServerError, "Unable to retrieve connectors"} } if !validateConnectorID(connectors, connectorID) { - return req, &authErr{"", "", errInvalidRequest, "Invalid ConnectorID"} + return nil, &authErr{"", "", errInvalidRequest, "Invalid ConnectorID"} } } if !validateRedirectURI(client, redirectURI) { description := fmt.Sprintf("Unregistered redirect_uri (%q).", redirectURI) - return req, &authErr{"", "", errInvalidRequest, description} + return nil, &authErr{"", "", errInvalidRequest, description} } // From here on out, we want to redirect back to the client with an error. @@ -446,7 +446,7 @@ func (s *Server) parseAuthorizationRequest(r *http.Request) (req storage.AuthReq isTrusted, err := s.validateCrossClientTrust(clientID, peerID) if err != nil { - return req, newErr(errServerError, "Internal server error.") + return nil, newErr(errServerError, "Internal server error.") } if !isTrusted { invalidScopes = append(invalidScopes, scope) @@ -454,13 +454,13 @@ func (s *Server) parseAuthorizationRequest(r *http.Request) (req storage.AuthReq } } if !hasOpenIDScope { - return req, newErr("invalid_scope", `Missing required scope(s) ["openid"].`) + return nil, newErr("invalid_scope", `Missing required scope(s) ["openid"].`) } if len(unrecognized) > 0 { - return req, newErr("invalid_scope", "Unrecognized scope(s) %q", unrecognized) + return nil, newErr("invalid_scope", "Unrecognized scope(s) %q", unrecognized) } if len(invalidScopes) > 0 { - return req, newErr("invalid_scope", "Client can't request scope(s) %q", invalidScopes) + return nil, newErr("invalid_scope", "Client can't request scope(s) %q", invalidScopes) } var rt struct { @@ -478,23 +478,23 @@ func (s *Server) parseAuthorizationRequest(r *http.Request) (req storage.AuthReq case responseTypeToken: rt.token = true default: - return req, newErr("invalid_request", "Invalid response type %q", responseType) + return nil, newErr("invalid_request", "Invalid response type %q", responseType) } if !s.supportedResponseTypes[responseType] { - return req, newErr(errUnsupportedResponseType, "Unsupported response type %q", responseType) + return nil, newErr(errUnsupportedResponseType, "Unsupported response type %q", responseType) } } if len(responseTypes) == 0 { - return req, newErr("invalid_requests", "No response_type provided") + return nil, newErr("invalid_requests", "No response_type provided") } if rt.token && !rt.code && !rt.idToken { // "token" can't be provided by its own. // // https://openid.net/specs/openid-connect-core-1_0.html#Authentication - return req, newErr("invalid_request", "Response type 'token' must be provided with type 'id_token' and/or 'code'") + return nil, newErr("invalid_request", "Response type 'token' must be provided with type 'id_token' and/or 'code'") } if !rt.code { // Either "id_token code" or "id_token" has been provided which implies the @@ -502,17 +502,17 @@ func (s *Server) parseAuthorizationRequest(r *http.Request) (req storage.AuthReq // // https://openid.net/specs/openid-connect-core-1_0.html#ImplicitAuthRequest if nonce == "" { - return req, newErr("invalid_request", "Response type 'token' requires a 'nonce' value.") + return nil, newErr("invalid_request", "Response type 'token' requires a 'nonce' value.") } } if rt.token { if redirectURI == redirectURIOOB { err := fmt.Sprintf("Cannot use response type 'token' with redirect_uri '%s'.", redirectURIOOB) - return req, newErr("invalid_request", err) + return nil, newErr("invalid_request", err) } } - return storage.AuthRequest{ + return &storage.AuthRequest{ ID: storage.NewID(), ClientID: client.ID, State: state,