Merge pull request #1708 from tkleczek/fix-overwriting-connector-in-authreq
abort connector login if connector was already set #1707
This commit is contained in:
		@@ -283,7 +283,7 @@ func (s *Server) handleConnectorLogin(w http.ResponseWriter, r *http.Request) {
 | 
				
			|||||||
	connID := mux.Vars(r)["connector"]
 | 
						connID := mux.Vars(r)["connector"]
 | 
				
			||||||
	conn, err := s.getConnector(connID)
 | 
						conn, err := s.getConnector(connID)
 | 
				
			||||||
	if err != nil {
 | 
						if err != nil {
 | 
				
			||||||
		s.logger.Errorf("Failed to create authorization request: %v", err)
 | 
							s.logger.Errorf("Failed to get connector: %v", err)
 | 
				
			||||||
		s.renderError(r, w, http.StatusBadRequest, "Requested resource does not exist")
 | 
							s.renderError(r, w, http.StatusBadRequest, "Requested resource does not exist")
 | 
				
			||||||
		return
 | 
							return
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
@@ -304,6 +304,9 @@ func (s *Server) handleConnectorLogin(w http.ResponseWriter, r *http.Request) {
 | 
				
			|||||||
	// Set the connector being used for the login.
 | 
						// Set the connector being used for the login.
 | 
				
			||||||
	if authReq.ConnectorID != connID {
 | 
						if authReq.ConnectorID != connID {
 | 
				
			||||||
		updater := func(a storage.AuthRequest) (storage.AuthRequest, error) {
 | 
							updater := func(a storage.AuthRequest) (storage.AuthRequest, error) {
 | 
				
			||||||
 | 
								if a.ConnectorID != "" {
 | 
				
			||||||
 | 
									return a, fmt.Errorf("connector is already set for this auth request")
 | 
				
			||||||
 | 
								}
 | 
				
			||||||
			a.ConnectorID = connID
 | 
								a.ConnectorID = connID
 | 
				
			||||||
			return a, nil
 | 
								return a, nil
 | 
				
			||||||
		}
 | 
							}
 | 
				
			||||||
 
 | 
				
			|||||||
@@ -8,8 +8,12 @@ import (
 | 
				
			|||||||
	"net/http"
 | 
						"net/http"
 | 
				
			||||||
	"net/http/httptest"
 | 
						"net/http/httptest"
 | 
				
			||||||
	"testing"
 | 
						"testing"
 | 
				
			||||||
 | 
						"time"
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						"github.com/gorilla/mux"
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	"github.com/dexidp/dex/storage"
 | 
						"github.com/dexidp/dex/storage"
 | 
				
			||||||
 | 
						"github.com/dexidp/dex/storage/memory"
 | 
				
			||||||
)
 | 
					)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
func TestHandleHealth(t *testing.T) {
 | 
					func TestHandleHealth(t *testing.T) {
 | 
				
			||||||
@@ -119,3 +123,84 @@ func TestHandleInvalidSAMLCallbacks(t *testing.T) {
 | 
				
			|||||||
		}
 | 
							}
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					func TestConnectorLoginDoesNotAllowToChangeConnectorForAuthRequest(t *testing.T) {
 | 
				
			||||||
 | 
						memStorage := memory.New(logger)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						templates, err := loadTemplates(webConfig{}, "../web/templates")
 | 
				
			||||||
 | 
						if err != nil {
 | 
				
			||||||
 | 
							t.Fatal("failed to load tempalates")
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						s := &Server{
 | 
				
			||||||
 | 
							storage:                memStorage,
 | 
				
			||||||
 | 
							logger:                 logger,
 | 
				
			||||||
 | 
							templates:              templates,
 | 
				
			||||||
 | 
							supportedResponseTypes: map[string]bool{"code": true},
 | 
				
			||||||
 | 
							now:                    time.Now,
 | 
				
			||||||
 | 
							connectors:             make(map[string]Connector),
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						r := mux.NewRouter()
 | 
				
			||||||
 | 
						r.HandleFunc("/auth/{connector}", s.handleConnectorLogin)
 | 
				
			||||||
 | 
						s.mux = r
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						clientID := "clientID"
 | 
				
			||||||
 | 
						clientSecret := "secret"
 | 
				
			||||||
 | 
						redirectURL := "localhost:5555" + "/callback"
 | 
				
			||||||
 | 
						client := storage.Client{
 | 
				
			||||||
 | 
							ID:           clientID,
 | 
				
			||||||
 | 
							Secret:       clientSecret,
 | 
				
			||||||
 | 
							RedirectURIs: []string{redirectURL},
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
						if err := memStorage.CreateClient(client); err != nil {
 | 
				
			||||||
 | 
							t.Fatal("failed to create client")
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						createConnector := func(t *testing.T, id string) storage.Connector {
 | 
				
			||||||
 | 
							connector := storage.Connector{
 | 
				
			||||||
 | 
								ID:              id,
 | 
				
			||||||
 | 
								Type:            "mockCallback",
 | 
				
			||||||
 | 
								Name:            "Mock",
 | 
				
			||||||
 | 
								ResourceVersion: "1",
 | 
				
			||||||
 | 
							}
 | 
				
			||||||
 | 
							if err := memStorage.CreateConnector(connector); err != nil {
 | 
				
			||||||
 | 
								t.Fatalf("failed to create connector %v", id)
 | 
				
			||||||
 | 
							}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
							return connector
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						connector1 := createConnector(t, "mock1")
 | 
				
			||||||
 | 
						connector2 := createConnector(t, "mock2")
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						authReq := storage.AuthRequest{
 | 
				
			||||||
 | 
							ID: storage.NewID(),
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
						if err := memStorage.CreateAuthRequest(authReq); err != nil {
 | 
				
			||||||
 | 
							t.Fatal("failed to create auth request")
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						createConnectorLoginRequest := func(connID string) *http.Request {
 | 
				
			||||||
 | 
							req := httptest.NewRequest("GET", "/auth/"+connID, nil)
 | 
				
			||||||
 | 
							q := req.URL.Query()
 | 
				
			||||||
 | 
							q.Add("req", authReq.ID)
 | 
				
			||||||
 | 
							q.Add("redirect_uri", redirectURL)
 | 
				
			||||||
 | 
							q.Add("scope", "openid")
 | 
				
			||||||
 | 
							q.Add("response_type", "code")
 | 
				
			||||||
 | 
							req.URL.RawQuery = q.Encode()
 | 
				
			||||||
 | 
							return req
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						recorder := httptest.NewRecorder()
 | 
				
			||||||
 | 
						s.ServeHTTP(recorder, createConnectorLoginRequest(connector1.ID))
 | 
				
			||||||
 | 
						if recorder.Code != 302 {
 | 
				
			||||||
 | 
							t.Fatal("failed to process request")
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						recorder2 := httptest.NewRecorder()
 | 
				
			||||||
 | 
						s.ServeHTTP(recorder2, createConnectorLoginRequest(connector2.ID))
 | 
				
			||||||
 | 
						if recorder2.Code != 500 {
 | 
				
			||||||
 | 
							t.Error("attempt to overwrite connector on auth request should fail")
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 
 | 
				
			|||||||
		Reference in New Issue
	
	Block a user