fix: defer creation of auth request.
Rather than creating the auth request when the user hits /auth, pass
the arguments through to /auth/{connector} and have the auth request
created there.  This prevents a database error when using the "Select
another login method" link, and also avoids a few other error cases.
Fixes #1849, #646.
Signed-off-by: Alastair Houghton <alastair@alastairs-place.net>
			
			
This commit is contained in:
		@@ -124,6 +124,66 @@ func (s *Server) discoveryHandler() (http.HandlerFunc, error) {
 | 
			
		||||
 | 
			
		||||
// handleAuthorization handles the OAuth2 auth endpoint.
 | 
			
		||||
func (s *Server) handleAuthorization(w http.ResponseWriter, r *http.Request) {
 | 
			
		||||
	// Extract the arguments
 | 
			
		||||
	if err := r.ParseForm(); err != nil {
 | 
			
		||||
		s.logger.Errorf("Failed to parse arguments: %v", err)
 | 
			
		||||
 | 
			
		||||
		s.renderError(r, w, http.StatusBadRequest, "Bad query/form arguments")
 | 
			
		||||
		return
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	connectorID := r.Form.Get("connector_id")
 | 
			
		||||
 | 
			
		||||
	connectors, err := s.storage.ListConnectors()
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		s.logger.Errorf("Failed to get list of connectors: %v", err)
 | 
			
		||||
		s.renderError(r, w, http.StatusInternalServerError, "Failed to retrieve connector list.")
 | 
			
		||||
		return
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	// Construct a URL with all of the arguments in its query
 | 
			
		||||
	connURL := url.URL{
 | 
			
		||||
		RawQuery: r.Form.Encode(),
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	// Redirect if a client chooses a specific connector_id
 | 
			
		||||
	if connectorID != "" {
 | 
			
		||||
		for _, c := range connectors {
 | 
			
		||||
			if c.ID == connectorID {
 | 
			
		||||
				connURL.Path = s.absPath("/auth", c.ID)
 | 
			
		||||
				http.Redirect(w, r, connURL.String(), http.StatusFound)
 | 
			
		||||
				return
 | 
			
		||||
			}
 | 
			
		||||
		}
 | 
			
		||||
		s.tokenErrHelper(w, errInvalidConnectorID, "Connector ID does not match a valid Connector", http.StatusNotFound)
 | 
			
		||||
		return
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	if len(connectors) == 1 && !s.alwaysShowLogin {
 | 
			
		||||
		for _, c := range connectors {
 | 
			
		||||
			connURL.Path = s.absPath("/auth", c.ID)
 | 
			
		||||
			http.Redirect(w, r, connURL.String(), http.StatusFound)
 | 
			
		||||
			return
 | 
			
		||||
		}
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	connectorInfos := make([]connectorInfo, len(connectors))
 | 
			
		||||
	for index, conn := range connectors {
 | 
			
		||||
		connURL.Path = s.absPath("/auth", conn.ID)
 | 
			
		||||
		connectorInfos[index] = connectorInfo{
 | 
			
		||||
			ID:   conn.ID,
 | 
			
		||||
			Name: conn.Name,
 | 
			
		||||
			Type: conn.Type,
 | 
			
		||||
			URL:  connURL.String(),
 | 
			
		||||
		}
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	if err := s.templates.login(r, w, connectorInfos); err != nil {
 | 
			
		||||
		s.logger.Errorf("Server template error: %v", err)
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (s *Server) handleConnectorLogin(w http.ResponseWriter, r *http.Request) {
 | 
			
		||||
	authReq, err := s.parseAuthorizationRequest(r)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		s.logger.Errorf("Failed to parse authorization request: %v", err)
 | 
			
		||||
@@ -145,64 +205,6 @@ func (s *Server) handleAuthorization(w http.ResponseWriter, r *http.Request) {
 | 
			
		||||
		return
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	// TODO(ericchiang): Create this authorization request later in the login flow
 | 
			
		||||
	// so users don't hit "not found" database errors if they wait at the login
 | 
			
		||||
	// screen too long.
 | 
			
		||||
	//
 | 
			
		||||
	// See: https://github.com/dexidp/dex/issues/646
 | 
			
		||||
	authReq.Expiry = s.now().Add(s.authRequestsValidFor)
 | 
			
		||||
	if err := s.storage.CreateAuthRequest(*authReq); err != nil {
 | 
			
		||||
		s.logger.Errorf("Failed to create authorization request: %v", err)
 | 
			
		||||
		s.renderError(r, w, http.StatusInternalServerError, "Failed to connect to the database.")
 | 
			
		||||
		return
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	connectors, err := s.storage.ListConnectors()
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		s.logger.Errorf("Failed to get list of connectors: %v", err)
 | 
			
		||||
		s.renderError(r, w, http.StatusInternalServerError, "Failed to retrieve connector list.")
 | 
			
		||||
		return
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	// Redirect if a client chooses a specific connector_id
 | 
			
		||||
	if authReq.ConnectorID != "" {
 | 
			
		||||
		for _, c := range connectors {
 | 
			
		||||
			if c.ID == authReq.ConnectorID {
 | 
			
		||||
				http.Redirect(w, r, s.absPath("/auth", c.ID)+"?req="+authReq.ID, http.StatusFound)
 | 
			
		||||
				return
 | 
			
		||||
			}
 | 
			
		||||
		}
 | 
			
		||||
		s.tokenErrHelper(w, errInvalidConnectorID, "Connector ID does not match a valid Connector", http.StatusNotFound)
 | 
			
		||||
		return
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	if len(connectors) == 1 && !s.alwaysShowLogin {
 | 
			
		||||
		for _, c := range connectors {
 | 
			
		||||
			// TODO(ericchiang): Make this pass on r.URL.RawQuery and let something latter
 | 
			
		||||
			// on create the auth request.
 | 
			
		||||
			http.Redirect(w, r, s.absPath("/auth", c.ID)+"?req="+authReq.ID, http.StatusFound)
 | 
			
		||||
			return
 | 
			
		||||
		}
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	connectorInfos := make([]connectorInfo, len(connectors))
 | 
			
		||||
	for index, conn := range connectors {
 | 
			
		||||
		connectorInfos[index] = connectorInfo{
 | 
			
		||||
			ID:   conn.ID,
 | 
			
		||||
			Name: conn.Name,
 | 
			
		||||
			Type: conn.Type,
 | 
			
		||||
			// TODO(ericchiang): Make this pass on r.URL.RawQuery and let something latter
 | 
			
		||||
			// on create the auth request.
 | 
			
		||||
			URL: s.absPath("/auth", conn.ID) + "?req=" + authReq.ID,
 | 
			
		||||
		}
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	if err := s.templates.login(r, w, connectorInfos); err != nil {
 | 
			
		||||
		s.logger.Errorf("Server template error: %v", err)
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (s *Server) handleConnectorLogin(w http.ResponseWriter, r *http.Request) {
 | 
			
		||||
	connID := mux.Vars(r)["connector"]
 | 
			
		||||
	conn, err := s.getConnector(connID)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
@@ -211,33 +213,22 @@ func (s *Server) handleConnectorLogin(w http.ResponseWriter, r *http.Request) {
 | 
			
		||||
		return
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	authReqID := r.FormValue("req")
 | 
			
		||||
 | 
			
		||||
	authReq, err := s.storage.GetAuthRequest(authReqID)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		s.logger.Errorf("Failed to get auth request: %v", err)
 | 
			
		||||
		if err == storage.ErrNotFound {
 | 
			
		||||
			s.renderError(r, w, http.StatusBadRequest, "Login session expired.")
 | 
			
		||||
		} else {
 | 
			
		||||
			s.renderError(r, w, http.StatusInternalServerError, "Database error.")
 | 
			
		||||
		}
 | 
			
		||||
		return
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	// Set the connector being used for the login.
 | 
			
		||||
	if authReq.ConnectorID != connID {
 | 
			
		||||
		updater := func(a storage.AuthRequest) (storage.AuthRequest, error) {
 | 
			
		||||
			if a.ConnectorID != "" {
 | 
			
		||||
				return a, fmt.Errorf("connector is already set for this auth request")
 | 
			
		||||
			}
 | 
			
		||||
			a.ConnectorID = connID
 | 
			
		||||
			return a, nil
 | 
			
		||||
		}
 | 
			
		||||
		if err := s.storage.UpdateAuthRequest(authReqID, updater); err != nil {
 | 
			
		||||
			s.logger.Errorf("Failed to set connector ID on auth request: %v", err)
 | 
			
		||||
			s.renderError(r, w, http.StatusInternalServerError, "Database error.")
 | 
			
		||||
	if authReq.ConnectorID != "" && authReq.ConnectorID != connID {
 | 
			
		||||
		s.logger.Errorf("Mismatched connector ID in auth request: %s vs %s",
 | 
			
		||||
			authReq.ConnectorID, connID)
 | 
			
		||||
		s.renderError(r, w, http.StatusBadRequest, "Bad connector ID")
 | 
			
		||||
		return
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	authReq.ConnectorID = connID
 | 
			
		||||
 | 
			
		||||
	// Actually create the auth request
 | 
			
		||||
	authReq.Expiry = s.now().Add(s.authRequestsValidFor)
 | 
			
		||||
	if err := s.storage.CreateAuthRequest(*authReq); err != nil {
 | 
			
		||||
		s.logger.Errorf("Failed to create authorization request: %v", err)
 | 
			
		||||
		s.renderError(r, w, http.StatusInternalServerError, "Failed to connect to the database.")
 | 
			
		||||
		return
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	scopes := parseScopes(authReq.Scopes)
 | 
			
		||||
@@ -250,7 +241,7 @@ func (s *Server) handleConnectorLogin(w http.ResponseWriter, r *http.Request) {
 | 
			
		||||
			// Use the auth request ID as the "state" token.
 | 
			
		||||
			//
 | 
			
		||||
			// TODO(ericchiang): Is this appropriate or should we also be using a nonce?
 | 
			
		||||
			callbackURL, err := conn.LoginURL(scopes, s.absURL("/callback"), authReqID)
 | 
			
		||||
			callbackURL, err := conn.LoginURL(scopes, s.absURL("/callback"), authReq.ID)
 | 
			
		||||
			if err != nil {
 | 
			
		||||
				s.logger.Errorf("Connector %q returned error when creating callback: %v", connID, err)
 | 
			
		||||
				s.renderError(r, w, http.StatusInternalServerError, "Login error.")
 | 
			
		||||
@@ -262,7 +253,7 @@ func (s *Server) handleConnectorLogin(w http.ResponseWriter, r *http.Request) {
 | 
			
		||||
				s.logger.Errorf("Server template error: %v", err)
 | 
			
		||||
			}
 | 
			
		||||
		case connector.SAMLConnector:
 | 
			
		||||
			action, value, err := conn.POSTData(scopes, authReqID)
 | 
			
		||||
			action, value, err := conn.POSTData(scopes, authReq.ID)
 | 
			
		||||
			if err != nil {
 | 
			
		||||
				s.logger.Errorf("Creating SAML data: %v", err)
 | 
			
		||||
				s.renderError(r, w, http.StatusInternalServerError, "Connector Login Error")
 | 
			
		||||
@@ -285,7 +276,7 @@ func (s *Server) handleConnectorLogin(w http.ResponseWriter, r *http.Request) {
 | 
			
		||||
				    document.forms[0].submit();
 | 
			
		||||
				</script>
 | 
			
		||||
			  </body>
 | 
			
		||||
			  </html>`, action, value, authReqID)
 | 
			
		||||
			  </html>`, action, value, authReq.ID)
 | 
			
		||||
		default:
 | 
			
		||||
			s.renderError(r, w, http.StatusBadRequest, "Requested resource does not exist.")
 | 
			
		||||
		}
 | 
			
		||||
@@ -311,7 +302,7 @@ func (s *Server) handleConnectorLogin(w http.ResponseWriter, r *http.Request) {
 | 
			
		||||
			}
 | 
			
		||||
			return
 | 
			
		||||
		}
 | 
			
		||||
		redirectURL, err := s.finalizeLogin(identity, authReq, conn.Connector)
 | 
			
		||||
		redirectURL, err := s.finalizeLogin(identity, *authReq, conn.Connector)
 | 
			
		||||
		if err != nil {
 | 
			
		||||
			s.logger.Errorf("Failed to finalize login: %v", err)
 | 
			
		||||
			s.renderError(r, w, http.StatusInternalServerError, "Login error.")
 | 
			
		||||
 
 | 
			
		||||
@@ -19,7 +19,6 @@ import (
 | 
			
		||||
	"golang.org/x/oauth2"
 | 
			
		||||
 | 
			
		||||
	"github.com/dexidp/dex/storage"
 | 
			
		||||
	"github.com/dexidp/dex/storage/memory"
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
func TestHandleHealth(t *testing.T) {
 | 
			
		||||
@@ -133,87 +132,6 @@ func TestHandleInvalidSAMLCallbacks(t *testing.T) {
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func TestConnectorLoginDoesNotAllowToChangeConnectorForAuthRequest(t *testing.T) {
 | 
			
		||||
	memStorage := memory.New(logger)
 | 
			
		||||
 | 
			
		||||
	templates, err := loadTemplates(webConfig{webFS: os.DirFS("../web")}, "templates")
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		t.Fatal("failed to load templates")
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	s := &Server{
 | 
			
		||||
		storage:                memStorage,
 | 
			
		||||
		logger:                 logger,
 | 
			
		||||
		templates:              templates,
 | 
			
		||||
		supportedResponseTypes: map[string]bool{"code": true},
 | 
			
		||||
		now:                    time.Now,
 | 
			
		||||
		connectors:             make(map[string]Connector),
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	r := mux.NewRouter()
 | 
			
		||||
	r.HandleFunc("/auth/{connector}", s.handleConnectorLogin)
 | 
			
		||||
	s.mux = r
 | 
			
		||||
 | 
			
		||||
	clientID := "clientID"
 | 
			
		||||
	clientSecret := "secret"
 | 
			
		||||
	redirectURL := "localhost:5555" + "/callback"
 | 
			
		||||
	client := storage.Client{
 | 
			
		||||
		ID:           clientID,
 | 
			
		||||
		Secret:       clientSecret,
 | 
			
		||||
		RedirectURIs: []string{redirectURL},
 | 
			
		||||
	}
 | 
			
		||||
	if err := memStorage.CreateClient(client); err != nil {
 | 
			
		||||
		t.Fatal("failed to create client")
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	createConnector := func(t *testing.T, id string) storage.Connector {
 | 
			
		||||
		connector := storage.Connector{
 | 
			
		||||
			ID:              id,
 | 
			
		||||
			Type:            "mockCallback",
 | 
			
		||||
			Name:            "Mock",
 | 
			
		||||
			ResourceVersion: "1",
 | 
			
		||||
		}
 | 
			
		||||
		if err := memStorage.CreateConnector(connector); err != nil {
 | 
			
		||||
			t.Fatalf("failed to create connector %v", id)
 | 
			
		||||
		}
 | 
			
		||||
 | 
			
		||||
		return connector
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	connector1 := createConnector(t, "mock1")
 | 
			
		||||
	connector2 := createConnector(t, "mock2")
 | 
			
		||||
 | 
			
		||||
	authReq := storage.AuthRequest{
 | 
			
		||||
		ID: storage.NewID(),
 | 
			
		||||
	}
 | 
			
		||||
	if err := memStorage.CreateAuthRequest(authReq); err != nil {
 | 
			
		||||
		t.Fatal("failed to create auth request")
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	createConnectorLoginRequest := func(connID string) *http.Request {
 | 
			
		||||
		req := httptest.NewRequest("GET", "/auth/"+connID, nil)
 | 
			
		||||
		q := req.URL.Query()
 | 
			
		||||
		q.Add("req", authReq.ID)
 | 
			
		||||
		q.Add("redirect_uri", redirectURL)
 | 
			
		||||
		q.Add("scope", "openid")
 | 
			
		||||
		q.Add("response_type", "code")
 | 
			
		||||
		req.URL.RawQuery = q.Encode()
 | 
			
		||||
		return req
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	recorder := httptest.NewRecorder()
 | 
			
		||||
	s.ServeHTTP(recorder, createConnectorLoginRequest(connector1.ID))
 | 
			
		||||
	if recorder.Code != 302 {
 | 
			
		||||
		t.Fatal("failed to process request")
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	recorder2 := httptest.NewRecorder()
 | 
			
		||||
	s.ServeHTTP(recorder2, createConnectorLoginRequest(connector2.ID))
 | 
			
		||||
	if recorder2.Code != 500 {
 | 
			
		||||
		t.Error("attempt to overwrite connector on auth request should fail")
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// TestHandleAuthCode checks that it is forbidden to use same code twice
 | 
			
		||||
func TestHandleAuthCode(t *testing.T) {
 | 
			
		||||
	tests := []struct {
 | 
			
		||||
 
 | 
			
		||||
		Reference in New Issue
	
	Block a user