fix: add an extra endpoint to avoid refresh generating AuthRequests.
By adding an extra endpoint and a redirect, we can avoid a situation where it's trivially easy to generate a large number of AuthRequests by hitting F5/refresh in the browser. Signed-off-by: Alastair Houghton <alastair@alastairs-place.net>
This commit is contained in:
		@@ -128,7 +128,7 @@ func (s *Server) handleAuthorization(w http.ResponseWriter, r *http.Request) {
 | 
				
			|||||||
	if err := r.ParseForm(); err != nil {
 | 
						if err := r.ParseForm(); err != nil {
 | 
				
			||||||
		s.logger.Errorf("Failed to parse arguments: %v", err)
 | 
							s.logger.Errorf("Failed to parse arguments: %v", err)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
		s.renderError(r, w, http.StatusBadRequest, "Bad query/form arguments")
 | 
							s.renderError(r, w, http.StatusBadRequest, err.Error())
 | 
				
			||||||
		return
 | 
							return
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
@@ -141,6 +141,9 @@ func (s *Server) handleAuthorization(w http.ResponseWriter, r *http.Request) {
 | 
				
			|||||||
		return
 | 
							return
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						// We don't need connector_id any more
 | 
				
			||||||
 | 
						r.Form.Del("connector_id")
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	// Construct a URL with all of the arguments in its query
 | 
						// Construct a URL with all of the arguments in its query
 | 
				
			||||||
	connURL := url.URL{
 | 
						connURL := url.URL{
 | 
				
			||||||
		RawQuery: r.Form.Encode(),
 | 
							RawQuery: r.Form.Encode(),
 | 
				
			||||||
@@ -160,11 +163,8 @@ func (s *Server) handleAuthorization(w http.ResponseWriter, r *http.Request) {
 | 
				
			|||||||
	}
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	if len(connectors) == 1 && !s.alwaysShowLogin {
 | 
						if len(connectors) == 1 && !s.alwaysShowLogin {
 | 
				
			||||||
		for _, c := range connectors {
 | 
							connURL.Path = s.absPath("/auth", connectors[0].ID)
 | 
				
			||||||
			connURL.Path = s.absPath("/auth", c.ID)
 | 
							http.Redirect(w, r, connURL.String(), http.StatusFound)
 | 
				
			||||||
			http.Redirect(w, r, connURL.String(), http.StatusFound)
 | 
					 | 
				
			||||||
			return
 | 
					 | 
				
			||||||
		}
 | 
					 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	connectorInfos := make([]connectorInfo, len(connectors))
 | 
						connectorInfos := make([]connectorInfo, len(connectors))
 | 
				
			||||||
@@ -258,9 +258,15 @@ func (s *Server) handleConnectorLogin(w http.ResponseWriter, r *http.Request) {
 | 
				
			|||||||
			}
 | 
								}
 | 
				
			||||||
			http.Redirect(w, r, callbackURL, http.StatusFound)
 | 
								http.Redirect(w, r, callbackURL, http.StatusFound)
 | 
				
			||||||
		case connector.PasswordConnector:
 | 
							case connector.PasswordConnector:
 | 
				
			||||||
			if err := s.templates.password(r, w, r.URL.String(), "", usernamePrompt(conn), false, backLink); err != nil {
 | 
								loginURL := url.URL{
 | 
				
			||||||
				s.logger.Errorf("Server template error: %v", err)
 | 
									Path: s.absPath("/auth", connID, "login"),
 | 
				
			||||||
			}
 | 
								}
 | 
				
			||||||
 | 
								q := loginURL.Query()
 | 
				
			||||||
 | 
								q.Set("state", authReq.ID)
 | 
				
			||||||
 | 
								q.Set("back", backLink)
 | 
				
			||||||
 | 
								loginURL.RawQuery = q.Encode()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
								http.Redirect(w, r, loginURL.String(), http.StatusFound)
 | 
				
			||||||
		case connector.SAMLConnector:
 | 
							case connector.SAMLConnector:
 | 
				
			||||||
			action, value, err := conn.POSTData(scopes, authReq.ID)
 | 
								action, value, err := conn.POSTData(scopes, authReq.ID)
 | 
				
			||||||
			if err != nil {
 | 
								if err != nil {
 | 
				
			||||||
@@ -289,29 +295,75 @@ func (s *Server) handleConnectorLogin(w http.ResponseWriter, r *http.Request) {
 | 
				
			|||||||
		default:
 | 
							default:
 | 
				
			||||||
			s.renderError(r, w, http.StatusBadRequest, "Requested resource does not exist.")
 | 
								s.renderError(r, w, http.StatusBadRequest, "Requested resource does not exist.")
 | 
				
			||||||
		}
 | 
							}
 | 
				
			||||||
	case http.MethodPost:
 | 
						default:
 | 
				
			||||||
		passwordConnector, ok := conn.Connector.(connector.PasswordConnector)
 | 
							s.renderError(r, w, http.StatusBadRequest, "Unsupported request method.")
 | 
				
			||||||
		if !ok {
 | 
						}
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					func (s *Server) handlePasswordLogin(w http.ResponseWriter, r *http.Request) {
 | 
				
			||||||
 | 
						authID := r.URL.Query().Get("state")
 | 
				
			||||||
 | 
						if authID == "" {
 | 
				
			||||||
 | 
							s.renderError(r, w, http.StatusBadRequest, "User session error.")
 | 
				
			||||||
 | 
							return
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						backLink := r.URL.Query().Get("back")
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						authReq, err := s.storage.GetAuthRequest(authID)
 | 
				
			||||||
 | 
						if err != nil {
 | 
				
			||||||
 | 
							if err == storage.ErrNotFound {
 | 
				
			||||||
 | 
								s.logger.Errorf("Invalid 'state' parameter provided: %v", err)
 | 
				
			||||||
			s.renderError(r, w, http.StatusBadRequest, "Requested resource does not exist.")
 | 
								s.renderError(r, w, http.StatusBadRequest, "Requested resource does not exist.")
 | 
				
			||||||
			return
 | 
								return
 | 
				
			||||||
		}
 | 
							}
 | 
				
			||||||
 | 
							s.logger.Errorf("Failed to get auth request: %v", err)
 | 
				
			||||||
 | 
							s.renderError(r, w, http.StatusInternalServerError, "Database error.")
 | 
				
			||||||
 | 
							return
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						if connID := mux.Vars(r)["connector"]; connID != "" && connID != authReq.ConnectorID {
 | 
				
			||||||
 | 
							s.logger.Errorf("Connector mismatch: authentication started with id %q, but password login for id %q was triggered", authReq.ConnectorID, connID)
 | 
				
			||||||
 | 
							s.renderError(r, w, http.StatusInternalServerError, "Requested resource does not exist.")
 | 
				
			||||||
 | 
							return
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						conn, err := s.getConnector(authReq.ConnectorID)
 | 
				
			||||||
 | 
						if err != nil {
 | 
				
			||||||
 | 
							s.logger.Errorf("Failed to get connector with id %q : %v", authReq.ConnectorID, err)
 | 
				
			||||||
 | 
							s.renderError(r, w, http.StatusInternalServerError, "Requested resource does not exist.")
 | 
				
			||||||
 | 
							return
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						pwConn, ok := conn.Connector.(connector.PasswordConnector)
 | 
				
			||||||
 | 
						if !ok {
 | 
				
			||||||
 | 
							s.logger.Errorf("Expected password connector in handlePasswordLogin(), but got %v", pwConn)
 | 
				
			||||||
 | 
							s.renderError(r, w, http.StatusInternalServerError, "Requested resource does not exist.")
 | 
				
			||||||
 | 
							return
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						switch r.Method {
 | 
				
			||||||
 | 
						case http.MethodGet:
 | 
				
			||||||
 | 
							if err := s.templates.password(r, w, r.URL.String(), "", usernamePrompt(pwConn), false, backLink); err != nil {
 | 
				
			||||||
 | 
								s.logger.Errorf("Server template error: %v", err)
 | 
				
			||||||
 | 
							}
 | 
				
			||||||
 | 
						case http.MethodPost:
 | 
				
			||||||
		username := r.FormValue("login")
 | 
							username := r.FormValue("login")
 | 
				
			||||||
		password := r.FormValue("password")
 | 
							password := r.FormValue("password")
 | 
				
			||||||
 | 
							scopes := parseScopes(authReq.Scopes)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
		identity, ok, err := passwordConnector.Login(r.Context(), scopes, username, password)
 | 
							identity, ok, err := pwConn.Login(r.Context(), scopes, username, password)
 | 
				
			||||||
		if err != nil {
 | 
							if err != nil {
 | 
				
			||||||
			s.logger.Errorf("Failed to login user: %v", err)
 | 
								s.logger.Errorf("Failed to login user: %v", err)
 | 
				
			||||||
			s.renderError(r, w, http.StatusInternalServerError, fmt.Sprintf("Login error: %v", err))
 | 
								s.renderError(r, w, http.StatusInternalServerError, fmt.Sprintf("Login error: %v", err))
 | 
				
			||||||
			return
 | 
								return
 | 
				
			||||||
		}
 | 
							}
 | 
				
			||||||
		if !ok {
 | 
							if !ok {
 | 
				
			||||||
			if err := s.templates.password(r, w, r.URL.String(), username, usernamePrompt(passwordConnector), true, backLink); err != nil {
 | 
								if err := s.templates.password(r, w, r.URL.String(), username, usernamePrompt(pwConn), true, backLink); err != nil {
 | 
				
			||||||
				s.logger.Errorf("Server template error: %v", err)
 | 
									s.logger.Errorf("Server template error: %v", err)
 | 
				
			||||||
			}
 | 
								}
 | 
				
			||||||
			return
 | 
								return
 | 
				
			||||||
		}
 | 
							}
 | 
				
			||||||
		redirectURL, err := s.finalizeLogin(identity, *authReq, conn.Connector)
 | 
							redirectURL, err := s.finalizeLogin(identity, authReq, conn.Connector)
 | 
				
			||||||
		if err != nil {
 | 
							if err != nil {
 | 
				
			||||||
			s.logger.Errorf("Failed to finalize login: %v", err)
 | 
								s.logger.Errorf("Failed to finalize login: %v", err)
 | 
				
			||||||
			s.renderError(r, w, http.StatusInternalServerError, "Login error.")
 | 
								s.renderError(r, w, http.StatusInternalServerError, "Login error.")
 | 
				
			||||||
 
 | 
				
			|||||||
@@ -7,7 +7,6 @@ import (
 | 
				
			|||||||
	"errors"
 | 
						"errors"
 | 
				
			||||||
	"net/http"
 | 
						"net/http"
 | 
				
			||||||
	"net/http/httptest"
 | 
						"net/http/httptest"
 | 
				
			||||||
	"os"
 | 
					 | 
				
			||||||
	"testing"
 | 
						"testing"
 | 
				
			||||||
	"time"
 | 
						"time"
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 
 | 
				
			|||||||
@@ -341,6 +341,7 @@ func newServer(ctx context.Context, c Config, rotationStrategy rotationStrategy)
 | 
				
			|||||||
	handleWithCORS("/userinfo", s.handleUserInfo)
 | 
						handleWithCORS("/userinfo", s.handleUserInfo)
 | 
				
			||||||
	handleFunc("/auth", s.handleAuthorization)
 | 
						handleFunc("/auth", s.handleAuthorization)
 | 
				
			||||||
	handleFunc("/auth/{connector}", s.handleConnectorLogin)
 | 
						handleFunc("/auth/{connector}", s.handleConnectorLogin)
 | 
				
			||||||
 | 
						handleFunc("/auth/{connector}/login", s.handlePasswordLogin)
 | 
				
			||||||
	handleFunc("/device", s.handleDeviceExchange)
 | 
						handleFunc("/device", s.handleDeviceExchange)
 | 
				
			||||||
	handleFunc("/device/auth/verify_code", s.verifyUserCode)
 | 
						handleFunc("/device/auth/verify_code", s.verifyUserCode)
 | 
				
			||||||
	handleFunc("/device/code", s.handleDeviceCode)
 | 
						handleFunc("/device/code", s.handleDeviceCode)
 | 
				
			||||||
 
 | 
				
			|||||||
		Reference in New Issue
	
	Block a user