server/{handler,oauth2}: cleanup error returns
Now, we'll return a standard error, and have the caller act upon this being an instance of authErr. Also changes the storage.AuthRequest return to a pointer, and returns nil in error cases. Signed-off-by: Stephan Renatus <srenatus@chef.io>
This commit is contained in:
		@@ -200,17 +200,21 @@ func (s *Server) handleAuthorization(w http.ResponseWriter, r *http.Request) {
 | 
				
			|||||||
	authReq, err := s.parseAuthorizationRequest(r)
 | 
						authReq, err := s.parseAuthorizationRequest(r)
 | 
				
			||||||
	if err != nil {
 | 
						if err != nil {
 | 
				
			||||||
		s.logger.Errorf("Failed to parse authorization request: %v", err)
 | 
							s.logger.Errorf("Failed to parse authorization request: %v", err)
 | 
				
			||||||
		if handler, ok := err.Handle(); ok {
 | 
							status := http.StatusInternalServerError
 | 
				
			||||||
			// client_id and redirect_uri checked out and we can redirect back to
 | 
					
 | 
				
			||||||
			// the client with the error.
 | 
							// If this is an authErr, let's let it handle the error, or update the HTTP
 | 
				
			||||||
			handler.ServeHTTP(w, r)
 | 
							// status code
 | 
				
			||||||
			return
 | 
							if err, ok := err.(*authErr); ok {
 | 
				
			||||||
 | 
								if handler, ok := err.Handle(); ok {
 | 
				
			||||||
 | 
									// client_id and redirect_uri checked out and we can redirect back to
 | 
				
			||||||
 | 
									// the client with the error.
 | 
				
			||||||
 | 
									handler.ServeHTTP(w, r)
 | 
				
			||||||
 | 
									return
 | 
				
			||||||
 | 
								}
 | 
				
			||||||
 | 
								status = err.Status()
 | 
				
			||||||
		}
 | 
							}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
		// Otherwise render the error to the user.
 | 
							s.renderError(w, status, err.Error())
 | 
				
			||||||
		//
 | 
					 | 
				
			||||||
		// TODO(ericchiang): Should we just always render the error?
 | 
					 | 
				
			||||||
		s.renderError(w, err.Status(), err.Error())
 | 
					 | 
				
			||||||
		return
 | 
							return
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
@@ -220,15 +224,15 @@ func (s *Server) handleAuthorization(w http.ResponseWriter, r *http.Request) {
 | 
				
			|||||||
	//
 | 
						//
 | 
				
			||||||
	// See: https://github.com/dexidp/dex/issues/646
 | 
						// See: https://github.com/dexidp/dex/issues/646
 | 
				
			||||||
	authReq.Expiry = s.now().Add(s.authRequestsValidFor)
 | 
						authReq.Expiry = s.now().Add(s.authRequestsValidFor)
 | 
				
			||||||
	if err := s.storage.CreateAuthRequest(authReq); err != nil {
 | 
						if err := s.storage.CreateAuthRequest(*authReq); err != nil {
 | 
				
			||||||
		s.logger.Errorf("Failed to create authorization request: %v", err)
 | 
							s.logger.Errorf("Failed to create authorization request: %v", err)
 | 
				
			||||||
		s.renderError(w, http.StatusInternalServerError, "Failed to connect to the database.")
 | 
							s.renderError(w, http.StatusInternalServerError, "Failed to connect to the database.")
 | 
				
			||||||
		return
 | 
							return
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	connectors, e := s.storage.ListConnectors()
 | 
						connectors, err := s.storage.ListConnectors()
 | 
				
			||||||
	if e != nil {
 | 
						if err != nil {
 | 
				
			||||||
		s.logger.Errorf("Failed to get list of connectors: %v", e)
 | 
							s.logger.Errorf("Failed to get list of connectors: %v", err)
 | 
				
			||||||
		s.renderError(w, http.StatusInternalServerError, "Failed to retrieve connector list.")
 | 
							s.renderError(w, http.StatusInternalServerError, "Failed to retrieve connector list.")
 | 
				
			||||||
		return
 | 
							return
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
 
 | 
				
			|||||||
@@ -379,14 +379,14 @@ func (s *Server) newIDToken(clientID string, claims storage.Claims, scopes []str
 | 
				
			|||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
// parse the initial request from the OAuth2 client.
 | 
					// parse the initial request from the OAuth2 client.
 | 
				
			||||||
func (s *Server) parseAuthorizationRequest(r *http.Request) (req storage.AuthRequest, oauth2Err *authErr) {
 | 
					func (s *Server) parseAuthorizationRequest(r *http.Request) (*storage.AuthRequest, error) {
 | 
				
			||||||
	if err := r.ParseForm(); err != nil {
 | 
						if err := r.ParseForm(); err != nil {
 | 
				
			||||||
		return req, &authErr{"", "", errInvalidRequest, "Failed to parse request body."}
 | 
							return nil, &authErr{"", "", errInvalidRequest, "Failed to parse request body."}
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
	q := r.Form
 | 
						q := r.Form
 | 
				
			||||||
	redirectURI, err := url.QueryUnescape(q.Get("redirect_uri"))
 | 
						redirectURI, err := url.QueryUnescape(q.Get("redirect_uri"))
 | 
				
			||||||
	if err != nil {
 | 
						if err != nil {
 | 
				
			||||||
		return req, &authErr{"", "", errInvalidRequest, "No redirect_uri provided."}
 | 
							return nil, &authErr{"", "", errInvalidRequest, "No redirect_uri provided."}
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	clientID := q.Get("client_id")
 | 
						clientID := q.Get("client_id")
 | 
				
			||||||
@@ -401,25 +401,25 @@ func (s *Server) parseAuthorizationRequest(r *http.Request) (req storage.AuthReq
 | 
				
			|||||||
	if err != nil {
 | 
						if err != nil {
 | 
				
			||||||
		if err == storage.ErrNotFound {
 | 
							if err == storage.ErrNotFound {
 | 
				
			||||||
			description := fmt.Sprintf("Invalid client_id (%q).", clientID)
 | 
								description := fmt.Sprintf("Invalid client_id (%q).", clientID)
 | 
				
			||||||
			return req, &authErr{"", "", errUnauthorizedClient, description}
 | 
								return nil, &authErr{"", "", errUnauthorizedClient, description}
 | 
				
			||||||
		}
 | 
							}
 | 
				
			||||||
		s.logger.Errorf("Failed to get client: %v", err)
 | 
							s.logger.Errorf("Failed to get client: %v", err)
 | 
				
			||||||
		return req, &authErr{"", "", errServerError, ""}
 | 
							return nil, &authErr{"", "", errServerError, ""}
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	if connectorID != "" {
 | 
						if connectorID != "" {
 | 
				
			||||||
		connectors, err := s.storage.ListConnectors()
 | 
							connectors, err := s.storage.ListConnectors()
 | 
				
			||||||
		if err != nil {
 | 
							if err != nil {
 | 
				
			||||||
			return req, &authErr{"", "", errServerError, "Unable to retrieve connectors"}
 | 
								return nil, &authErr{"", "", errServerError, "Unable to retrieve connectors"}
 | 
				
			||||||
		}
 | 
							}
 | 
				
			||||||
		if !validateConnectorID(connectors, connectorID) {
 | 
							if !validateConnectorID(connectors, connectorID) {
 | 
				
			||||||
			return req, &authErr{"", "", errInvalidRequest, "Invalid ConnectorID"}
 | 
								return nil, &authErr{"", "", errInvalidRequest, "Invalid ConnectorID"}
 | 
				
			||||||
		}
 | 
							}
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	if !validateRedirectURI(client, redirectURI) {
 | 
						if !validateRedirectURI(client, redirectURI) {
 | 
				
			||||||
		description := fmt.Sprintf("Unregistered redirect_uri (%q).", redirectURI)
 | 
							description := fmt.Sprintf("Unregistered redirect_uri (%q).", redirectURI)
 | 
				
			||||||
		return req, &authErr{"", "", errInvalidRequest, description}
 | 
							return nil, &authErr{"", "", errInvalidRequest, description}
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	// From here on out, we want to redirect back to the client with an error.
 | 
						// From here on out, we want to redirect back to the client with an error.
 | 
				
			||||||
@@ -446,7 +446,7 @@ func (s *Server) parseAuthorizationRequest(r *http.Request) (req storage.AuthReq
 | 
				
			|||||||
 | 
					
 | 
				
			||||||
			isTrusted, err := s.validateCrossClientTrust(clientID, peerID)
 | 
								isTrusted, err := s.validateCrossClientTrust(clientID, peerID)
 | 
				
			||||||
			if err != nil {
 | 
								if err != nil {
 | 
				
			||||||
				return req, newErr(errServerError, "Internal server error.")
 | 
									return nil, newErr(errServerError, "Internal server error.")
 | 
				
			||||||
			}
 | 
								}
 | 
				
			||||||
			if !isTrusted {
 | 
								if !isTrusted {
 | 
				
			||||||
				invalidScopes = append(invalidScopes, scope)
 | 
									invalidScopes = append(invalidScopes, scope)
 | 
				
			||||||
@@ -454,13 +454,13 @@ func (s *Server) parseAuthorizationRequest(r *http.Request) (req storage.AuthReq
 | 
				
			|||||||
		}
 | 
							}
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
	if !hasOpenIDScope {
 | 
						if !hasOpenIDScope {
 | 
				
			||||||
		return req, newErr("invalid_scope", `Missing required scope(s) ["openid"].`)
 | 
							return nil, newErr("invalid_scope", `Missing required scope(s) ["openid"].`)
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
	if len(unrecognized) > 0 {
 | 
						if len(unrecognized) > 0 {
 | 
				
			||||||
		return req, newErr("invalid_scope", "Unrecognized scope(s) %q", unrecognized)
 | 
							return nil, newErr("invalid_scope", "Unrecognized scope(s) %q", unrecognized)
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
	if len(invalidScopes) > 0 {
 | 
						if len(invalidScopes) > 0 {
 | 
				
			||||||
		return req, newErr("invalid_scope", "Client can't request scope(s) %q", invalidScopes)
 | 
							return nil, newErr("invalid_scope", "Client can't request scope(s) %q", invalidScopes)
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	var rt struct {
 | 
						var rt struct {
 | 
				
			||||||
@@ -478,23 +478,23 @@ func (s *Server) parseAuthorizationRequest(r *http.Request) (req storage.AuthReq
 | 
				
			|||||||
		case responseTypeToken:
 | 
							case responseTypeToken:
 | 
				
			||||||
			rt.token = true
 | 
								rt.token = true
 | 
				
			||||||
		default:
 | 
							default:
 | 
				
			||||||
			return req, newErr("invalid_request", "Invalid response type %q", responseType)
 | 
								return nil, newErr("invalid_request", "Invalid response type %q", responseType)
 | 
				
			||||||
		}
 | 
							}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
		if !s.supportedResponseTypes[responseType] {
 | 
							if !s.supportedResponseTypes[responseType] {
 | 
				
			||||||
			return req, newErr(errUnsupportedResponseType, "Unsupported response type %q", responseType)
 | 
								return nil, newErr(errUnsupportedResponseType, "Unsupported response type %q", responseType)
 | 
				
			||||||
		}
 | 
							}
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	if len(responseTypes) == 0 {
 | 
						if len(responseTypes) == 0 {
 | 
				
			||||||
		return req, newErr("invalid_requests", "No response_type provided")
 | 
							return nil, newErr("invalid_requests", "No response_type provided")
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	if rt.token && !rt.code && !rt.idToken {
 | 
						if rt.token && !rt.code && !rt.idToken {
 | 
				
			||||||
		// "token" can't be provided by its own.
 | 
							// "token" can't be provided by its own.
 | 
				
			||||||
		//
 | 
							//
 | 
				
			||||||
		// https://openid.net/specs/openid-connect-core-1_0.html#Authentication
 | 
							// https://openid.net/specs/openid-connect-core-1_0.html#Authentication
 | 
				
			||||||
		return req, newErr("invalid_request", "Response type 'token' must be provided with type 'id_token' and/or 'code'")
 | 
							return nil, newErr("invalid_request", "Response type 'token' must be provided with type 'id_token' and/or 'code'")
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
	if !rt.code {
 | 
						if !rt.code {
 | 
				
			||||||
		// Either "id_token code" or "id_token" has been provided which implies the
 | 
							// Either "id_token code" or "id_token" has been provided which implies the
 | 
				
			||||||
@@ -502,17 +502,17 @@ func (s *Server) parseAuthorizationRequest(r *http.Request) (req storage.AuthReq
 | 
				
			|||||||
		//
 | 
							//
 | 
				
			||||||
		// https://openid.net/specs/openid-connect-core-1_0.html#ImplicitAuthRequest
 | 
							// https://openid.net/specs/openid-connect-core-1_0.html#ImplicitAuthRequest
 | 
				
			||||||
		if nonce == "" {
 | 
							if nonce == "" {
 | 
				
			||||||
			return req, newErr("invalid_request", "Response type 'token' requires a 'nonce' value.")
 | 
								return nil, newErr("invalid_request", "Response type 'token' requires a 'nonce' value.")
 | 
				
			||||||
		}
 | 
							}
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
	if rt.token {
 | 
						if rt.token {
 | 
				
			||||||
		if redirectURI == redirectURIOOB {
 | 
							if redirectURI == redirectURIOOB {
 | 
				
			||||||
			err := fmt.Sprintf("Cannot use response type 'token' with redirect_uri '%s'.", redirectURIOOB)
 | 
								err := fmt.Sprintf("Cannot use response type 'token' with redirect_uri '%s'.", redirectURIOOB)
 | 
				
			||||||
			return req, newErr("invalid_request", err)
 | 
								return nil, newErr("invalid_request", err)
 | 
				
			||||||
		}
 | 
							}
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	return storage.AuthRequest{
 | 
						return &storage.AuthRequest{
 | 
				
			||||||
		ID:                  storage.NewID(),
 | 
							ID:                  storage.NewID(),
 | 
				
			||||||
		ClientID:            client.ID,
 | 
							ClientID:            client.ID,
 | 
				
			||||||
		State:               state,
 | 
							State:               state,
 | 
				
			||||||
 
 | 
				
			|||||||
		Reference in New Issue
	
	Block a user