server/{handler,oauth2}: cleanup error returns

Now, we'll return a standard error, and have the caller act upon this
being an instance of authErr.

Also changes the storage.AuthRequest return to a pointer, and returns
nil in error cases.

Signed-off-by: Stephan Renatus <srenatus@chef.io>
This commit is contained in:
Stephan Renatus 2019-07-24 12:45:50 +02:00
parent d7c7d42466
commit 8561a66365
No known key found for this signature in database
GPG Key ID: 811376EBA81C2C59
2 changed files with 36 additions and 32 deletions

View File

@ -200,17 +200,21 @@ func (s *Server) handleAuthorization(w http.ResponseWriter, r *http.Request) {
authReq, err := s.parseAuthorizationRequest(r) authReq, err := s.parseAuthorizationRequest(r)
if err != nil { if err != nil {
s.logger.Errorf("Failed to parse authorization request: %v", err) s.logger.Errorf("Failed to parse authorization request: %v", err)
if handler, ok := err.Handle(); ok { status := http.StatusInternalServerError
// client_id and redirect_uri checked out and we can redirect back to
// the client with the error. // If this is an authErr, let's let it handle the error, or update the HTTP
handler.ServeHTTP(w, r) // status code
return if err, ok := err.(*authErr); ok {
if handler, ok := err.Handle(); ok {
// client_id and redirect_uri checked out and we can redirect back to
// the client with the error.
handler.ServeHTTP(w, r)
return
}
status = err.Status()
} }
// Otherwise render the error to the user. s.renderError(w, status, err.Error())
//
// TODO(ericchiang): Should we just always render the error?
s.renderError(w, err.Status(), err.Error())
return return
} }
@ -220,15 +224,15 @@ func (s *Server) handleAuthorization(w http.ResponseWriter, r *http.Request) {
// //
// See: https://github.com/dexidp/dex/issues/646 // See: https://github.com/dexidp/dex/issues/646
authReq.Expiry = s.now().Add(s.authRequestsValidFor) authReq.Expiry = s.now().Add(s.authRequestsValidFor)
if err := s.storage.CreateAuthRequest(authReq); err != nil { if err := s.storage.CreateAuthRequest(*authReq); err != nil {
s.logger.Errorf("Failed to create authorization request: %v", err) s.logger.Errorf("Failed to create authorization request: %v", err)
s.renderError(w, http.StatusInternalServerError, "Failed to connect to the database.") s.renderError(w, http.StatusInternalServerError, "Failed to connect to the database.")
return return
} }
connectors, e := s.storage.ListConnectors() connectors, err := s.storage.ListConnectors()
if e != nil { if err != nil {
s.logger.Errorf("Failed to get list of connectors: %v", e) s.logger.Errorf("Failed to get list of connectors: %v", err)
s.renderError(w, http.StatusInternalServerError, "Failed to retrieve connector list.") s.renderError(w, http.StatusInternalServerError, "Failed to retrieve connector list.")
return return
} }

View File

@ -379,14 +379,14 @@ func (s *Server) newIDToken(clientID string, claims storage.Claims, scopes []str
} }
// parse the initial request from the OAuth2 client. // parse the initial request from the OAuth2 client.
func (s *Server) parseAuthorizationRequest(r *http.Request) (req storage.AuthRequest, oauth2Err *authErr) { func (s *Server) parseAuthorizationRequest(r *http.Request) (*storage.AuthRequest, error) {
if err := r.ParseForm(); err != nil { if err := r.ParseForm(); err != nil {
return req, &authErr{"", "", errInvalidRequest, "Failed to parse request body."} return nil, &authErr{"", "", errInvalidRequest, "Failed to parse request body."}
} }
q := r.Form q := r.Form
redirectURI, err := url.QueryUnescape(q.Get("redirect_uri")) redirectURI, err := url.QueryUnescape(q.Get("redirect_uri"))
if err != nil { if err != nil {
return req, &authErr{"", "", errInvalidRequest, "No redirect_uri provided."} return nil, &authErr{"", "", errInvalidRequest, "No redirect_uri provided."}
} }
clientID := q.Get("client_id") clientID := q.Get("client_id")
@ -401,25 +401,25 @@ func (s *Server) parseAuthorizationRequest(r *http.Request) (req storage.AuthReq
if err != nil { if err != nil {
if err == storage.ErrNotFound { if err == storage.ErrNotFound {
description := fmt.Sprintf("Invalid client_id (%q).", clientID) description := fmt.Sprintf("Invalid client_id (%q).", clientID)
return req, &authErr{"", "", errUnauthorizedClient, description} return nil, &authErr{"", "", errUnauthorizedClient, description}
} }
s.logger.Errorf("Failed to get client: %v", err) s.logger.Errorf("Failed to get client: %v", err)
return req, &authErr{"", "", errServerError, ""} return nil, &authErr{"", "", errServerError, ""}
} }
if connectorID != "" { if connectorID != "" {
connectors, err := s.storage.ListConnectors() connectors, err := s.storage.ListConnectors()
if err != nil { if err != nil {
return req, &authErr{"", "", errServerError, "Unable to retrieve connectors"} return nil, &authErr{"", "", errServerError, "Unable to retrieve connectors"}
} }
if !validateConnectorID(connectors, connectorID) { if !validateConnectorID(connectors, connectorID) {
return req, &authErr{"", "", errInvalidRequest, "Invalid ConnectorID"} return nil, &authErr{"", "", errInvalidRequest, "Invalid ConnectorID"}
} }
} }
if !validateRedirectURI(client, redirectURI) { if !validateRedirectURI(client, redirectURI) {
description := fmt.Sprintf("Unregistered redirect_uri (%q).", redirectURI) description := fmt.Sprintf("Unregistered redirect_uri (%q).", redirectURI)
return req, &authErr{"", "", errInvalidRequest, description} return nil, &authErr{"", "", errInvalidRequest, description}
} }
// From here on out, we want to redirect back to the client with an error. // From here on out, we want to redirect back to the client with an error.
@ -446,7 +446,7 @@ func (s *Server) parseAuthorizationRequest(r *http.Request) (req storage.AuthReq
isTrusted, err := s.validateCrossClientTrust(clientID, peerID) isTrusted, err := s.validateCrossClientTrust(clientID, peerID)
if err != nil { if err != nil {
return req, newErr(errServerError, "Internal server error.") return nil, newErr(errServerError, "Internal server error.")
} }
if !isTrusted { if !isTrusted {
invalidScopes = append(invalidScopes, scope) invalidScopes = append(invalidScopes, scope)
@ -454,13 +454,13 @@ func (s *Server) parseAuthorizationRequest(r *http.Request) (req storage.AuthReq
} }
} }
if !hasOpenIDScope { if !hasOpenIDScope {
return req, newErr("invalid_scope", `Missing required scope(s) ["openid"].`) return nil, newErr("invalid_scope", `Missing required scope(s) ["openid"].`)
} }
if len(unrecognized) > 0 { if len(unrecognized) > 0 {
return req, newErr("invalid_scope", "Unrecognized scope(s) %q", unrecognized) return nil, newErr("invalid_scope", "Unrecognized scope(s) %q", unrecognized)
} }
if len(invalidScopes) > 0 { if len(invalidScopes) > 0 {
return req, newErr("invalid_scope", "Client can't request scope(s) %q", invalidScopes) return nil, newErr("invalid_scope", "Client can't request scope(s) %q", invalidScopes)
} }
var rt struct { var rt struct {
@ -478,23 +478,23 @@ func (s *Server) parseAuthorizationRequest(r *http.Request) (req storage.AuthReq
case responseTypeToken: case responseTypeToken:
rt.token = true rt.token = true
default: default:
return req, newErr("invalid_request", "Invalid response type %q", responseType) return nil, newErr("invalid_request", "Invalid response type %q", responseType)
} }
if !s.supportedResponseTypes[responseType] { if !s.supportedResponseTypes[responseType] {
return req, newErr(errUnsupportedResponseType, "Unsupported response type %q", responseType) return nil, newErr(errUnsupportedResponseType, "Unsupported response type %q", responseType)
} }
} }
if len(responseTypes) == 0 { if len(responseTypes) == 0 {
return req, newErr("invalid_requests", "No response_type provided") return nil, newErr("invalid_requests", "No response_type provided")
} }
if rt.token && !rt.code && !rt.idToken { if rt.token && !rt.code && !rt.idToken {
// "token" can't be provided by its own. // "token" can't be provided by its own.
// //
// https://openid.net/specs/openid-connect-core-1_0.html#Authentication // https://openid.net/specs/openid-connect-core-1_0.html#Authentication
return req, newErr("invalid_request", "Response type 'token' must be provided with type 'id_token' and/or 'code'") return nil, newErr("invalid_request", "Response type 'token' must be provided with type 'id_token' and/or 'code'")
} }
if !rt.code { if !rt.code {
// Either "id_token code" or "id_token" has been provided which implies the // Either "id_token code" or "id_token" has been provided which implies the
@ -502,17 +502,17 @@ func (s *Server) parseAuthorizationRequest(r *http.Request) (req storage.AuthReq
// //
// https://openid.net/specs/openid-connect-core-1_0.html#ImplicitAuthRequest // https://openid.net/specs/openid-connect-core-1_0.html#ImplicitAuthRequest
if nonce == "" { if nonce == "" {
return req, newErr("invalid_request", "Response type 'token' requires a 'nonce' value.") return nil, newErr("invalid_request", "Response type 'token' requires a 'nonce' value.")
} }
} }
if rt.token { if rt.token {
if redirectURI == redirectURIOOB { if redirectURI == redirectURIOOB {
err := fmt.Sprintf("Cannot use response type 'token' with redirect_uri '%s'.", redirectURIOOB) err := fmt.Sprintf("Cannot use response type 'token' with redirect_uri '%s'.", redirectURIOOB)
return req, newErr("invalid_request", err) return nil, newErr("invalid_request", err)
} }
} }
return storage.AuthRequest{ return &storage.AuthRequest{
ID: storage.NewID(), ID: storage.NewID(),
ClientID: client.ID, ClientID: client.ID,
State: state, State: state,