From cd0c24ec4db226d3146d8e30164372fd5f9a62f4 Mon Sep 17 00:00:00 2001 From: Alastair Houghton Date: Fri, 21 May 2021 11:03:22 +0100 Subject: [PATCH] fix: add an extra endpoint to avoid refresh generating AuthRequests. By adding an extra endpoint and a redirect, we can avoid a situation where it's trivially easy to generate a large number of AuthRequests by hitting F5/refresh in the browser. Signed-off-by: Alastair Houghton --- server/handlers.go | 80 +++++++++++++++++++++++++++++++++-------- server/handlers_test.go | 1 - server/server.go | 1 + 3 files changed, 67 insertions(+), 15 deletions(-) diff --git a/server/handlers.go b/server/handlers.go index ff81460e..57ea2dca 100644 --- a/server/handlers.go +++ b/server/handlers.go @@ -128,7 +128,7 @@ func (s *Server) handleAuthorization(w http.ResponseWriter, r *http.Request) { if err := r.ParseForm(); err != nil { s.logger.Errorf("Failed to parse arguments: %v", err) - s.renderError(r, w, http.StatusBadRequest, "Bad query/form arguments") + s.renderError(r, w, http.StatusBadRequest, err.Error()) return } @@ -141,6 +141,9 @@ func (s *Server) handleAuthorization(w http.ResponseWriter, r *http.Request) { return } + // We don't need connector_id any more + r.Form.Del("connector_id") + // Construct a URL with all of the arguments in its query connURL := url.URL{ RawQuery: r.Form.Encode(), @@ -160,11 +163,8 @@ func (s *Server) handleAuthorization(w http.ResponseWriter, r *http.Request) { } if len(connectors) == 1 && !s.alwaysShowLogin { - for _, c := range connectors { - connURL.Path = s.absPath("/auth", c.ID) - http.Redirect(w, r, connURL.String(), http.StatusFound) - return - } + connURL.Path = s.absPath("/auth", connectors[0].ID) + http.Redirect(w, r, connURL.String(), http.StatusFound) } connectorInfos := make([]connectorInfo, len(connectors)) @@ -258,9 +258,15 @@ func (s *Server) handleConnectorLogin(w http.ResponseWriter, r *http.Request) { } http.Redirect(w, r, callbackURL, http.StatusFound) case connector.PasswordConnector: - if err := s.templates.password(r, w, r.URL.String(), "", usernamePrompt(conn), false, backLink); err != nil { - s.logger.Errorf("Server template error: %v", err) + loginURL := url.URL{ + Path: s.absPath("/auth", connID, "login"), } + q := loginURL.Query() + q.Set("state", authReq.ID) + q.Set("back", backLink) + loginURL.RawQuery = q.Encode() + + http.Redirect(w, r, loginURL.String(), http.StatusFound) case connector.SAMLConnector: action, value, err := conn.POSTData(scopes, authReq.ID) if err != nil { @@ -289,29 +295,75 @@ func (s *Server) handleConnectorLogin(w http.ResponseWriter, r *http.Request) { default: s.renderError(r, w, http.StatusBadRequest, "Requested resource does not exist.") } - case http.MethodPost: - passwordConnector, ok := conn.Connector.(connector.PasswordConnector) - if !ok { + default: + s.renderError(r, w, http.StatusBadRequest, "Unsupported request method.") + } +} + +func (s *Server) handlePasswordLogin(w http.ResponseWriter, r *http.Request) { + authID := r.URL.Query().Get("state") + if authID == "" { + s.renderError(r, w, http.StatusBadRequest, "User session error.") + return + } + + backLink := r.URL.Query().Get("back") + + authReq, err := s.storage.GetAuthRequest(authID) + if err != nil { + if err == storage.ErrNotFound { + s.logger.Errorf("Invalid 'state' parameter provided: %v", err) s.renderError(r, w, http.StatusBadRequest, "Requested resource does not exist.") return } + s.logger.Errorf("Failed to get auth request: %v", err) + s.renderError(r, w, http.StatusInternalServerError, "Database error.") + return + } + if connID := mux.Vars(r)["connector"]; connID != "" && connID != authReq.ConnectorID { + s.logger.Errorf("Connector mismatch: authentication started with id %q, but password login for id %q was triggered", authReq.ConnectorID, connID) + s.renderError(r, w, http.StatusInternalServerError, "Requested resource does not exist.") + return + } + + conn, err := s.getConnector(authReq.ConnectorID) + if err != nil { + s.logger.Errorf("Failed to get connector with id %q : %v", authReq.ConnectorID, err) + s.renderError(r, w, http.StatusInternalServerError, "Requested resource does not exist.") + return + } + + pwConn, ok := conn.Connector.(connector.PasswordConnector) + if !ok { + s.logger.Errorf("Expected password connector in handlePasswordLogin(), but got %v", pwConn) + s.renderError(r, w, http.StatusInternalServerError, "Requested resource does not exist.") + return + } + + switch r.Method { + case http.MethodGet: + if err := s.templates.password(r, w, r.URL.String(), "", usernamePrompt(pwConn), false, backLink); err != nil { + s.logger.Errorf("Server template error: %v", err) + } + case http.MethodPost: username := r.FormValue("login") password := r.FormValue("password") + scopes := parseScopes(authReq.Scopes) - identity, ok, err := passwordConnector.Login(r.Context(), scopes, username, password) + identity, ok, err := pwConn.Login(r.Context(), scopes, username, password) if err != nil { s.logger.Errorf("Failed to login user: %v", err) s.renderError(r, w, http.StatusInternalServerError, fmt.Sprintf("Login error: %v", err)) return } if !ok { - if err := s.templates.password(r, w, r.URL.String(), username, usernamePrompt(passwordConnector), true, backLink); err != nil { + if err := s.templates.password(r, w, r.URL.String(), username, usernamePrompt(pwConn), true, backLink); err != nil { s.logger.Errorf("Server template error: %v", err) } return } - redirectURL, err := s.finalizeLogin(identity, *authReq, conn.Connector) + redirectURL, err := s.finalizeLogin(identity, authReq, conn.Connector) if err != nil { s.logger.Errorf("Failed to finalize login: %v", err) s.renderError(r, w, http.StatusInternalServerError, "Login error.") diff --git a/server/handlers_test.go b/server/handlers_test.go index 83249ec8..60ea195a 100644 --- a/server/handlers_test.go +++ b/server/handlers_test.go @@ -7,7 +7,6 @@ import ( "errors" "net/http" "net/http/httptest" - "os" "testing" "time" diff --git a/server/server.go b/server/server.go index 84c3a82f..957b62dc 100644 --- a/server/server.go +++ b/server/server.go @@ -341,6 +341,7 @@ func newServer(ctx context.Context, c Config, rotationStrategy rotationStrategy) handleWithCORS("/userinfo", s.handleUserInfo) handleFunc("/auth", s.handleAuthorization) handleFunc("/auth/{connector}", s.handleConnectorLogin) + handleFunc("/auth/{connector}/login", s.handlePasswordLogin) handleFunc("/device", s.handleDeviceExchange) handleFunc("/device/auth/verify_code", s.verifyUserCode) handleFunc("/device/code", s.handleDeviceCode)