Merge pull request #1481 from LanceH/master
Added "connector_id" to skip straight to a connector (similar to when len(connector) is 1.
This commit is contained in:
		| @@ -233,6 +233,18 @@ func (s *Server) handleAuthorization(w http.ResponseWriter, r *http.Request) { | |||||||
| 		return | 		return | ||||||
| 	} | 	} | ||||||
|  |  | ||||||
|  | 	// Redirect if a client chooses a specific connector_id | ||||||
|  | 	if authReq.ConnectorID != "" { | ||||||
|  | 		for _, c := range connectors { | ||||||
|  | 			if c.ID == authReq.ConnectorID { | ||||||
|  | 				http.Redirect(w, r, s.absPath("/auth", c.ID)+"?req="+authReq.ID, http.StatusFound) | ||||||
|  | 				return | ||||||
|  | 			} | ||||||
|  | 		} | ||||||
|  | 		s.tokenErrHelper(w, errInvalidConnectorID, "Connector ID does not match a valid Connector", http.StatusNotFound) | ||||||
|  | 		return | ||||||
|  | 	} | ||||||
|  |  | ||||||
| 	if len(connectors) == 1 { | 	if len(connectors) == 1 { | ||||||
| 		for _, c := range connectors { | 		for _, c := range connectors { | ||||||
| 			// TODO(ericchiang): Make this pass on r.URL.RawQuery and let something latter | 			// TODO(ericchiang): Make this pass on r.URL.RawQuery and let something latter | ||||||
|   | |||||||
| @@ -100,6 +100,7 @@ const ( | |||||||
| 	errUnsupportedGrantType    = "unsupported_grant_type" | 	errUnsupportedGrantType    = "unsupported_grant_type" | ||||||
| 	errInvalidGrant            = "invalid_grant" | 	errInvalidGrant            = "invalid_grant" | ||||||
| 	errInvalidClient           = "invalid_client" | 	errInvalidClient           = "invalid_client" | ||||||
|  | 	errInvalidConnectorID      = "invalid_connector_id" | ||||||
| ) | ) | ||||||
|  |  | ||||||
| const ( | const ( | ||||||
| @@ -391,6 +392,7 @@ func (s *Server) parseAuthorizationRequest(r *http.Request) (req storage.AuthReq | |||||||
| 	clientID := q.Get("client_id") | 	clientID := q.Get("client_id") | ||||||
| 	state := q.Get("state") | 	state := q.Get("state") | ||||||
| 	nonce := q.Get("nonce") | 	nonce := q.Get("nonce") | ||||||
|  | 	connectorID := q.Get("connector_id") | ||||||
| 	// Some clients, like the old go-oidc, provide extra whitespace. Tolerate this. | 	// Some clients, like the old go-oidc, provide extra whitespace. Tolerate this. | ||||||
| 	scopes := strings.Fields(q.Get("scope")) | 	scopes := strings.Fields(q.Get("scope")) | ||||||
| 	responseTypes := strings.Fields(q.Get("response_type")) | 	responseTypes := strings.Fields(q.Get("response_type")) | ||||||
| @@ -405,6 +407,16 @@ func (s *Server) parseAuthorizationRequest(r *http.Request) (req storage.AuthReq | |||||||
| 		return req, &authErr{"", "", errServerError, ""} | 		return req, &authErr{"", "", errServerError, ""} | ||||||
| 	} | 	} | ||||||
|  |  | ||||||
|  | 	if connectorID != "" { | ||||||
|  | 		connectors, err := s.storage.ListConnectors() | ||||||
|  | 		if err != nil { | ||||||
|  | 			return req, &authErr{"", "", errServerError, "Unable to retrieve connectors"} | ||||||
|  | 		} | ||||||
|  | 		if !validateConnectorID(connectors, connectorID) { | ||||||
|  | 			return req, &authErr{"", "", errInvalidRequest, "Invalid ConnectorID"} | ||||||
|  | 		} | ||||||
|  | 	} | ||||||
|  |  | ||||||
| 	if !validateRedirectURI(client, redirectURI) { | 	if !validateRedirectURI(client, redirectURI) { | ||||||
| 		description := fmt.Sprintf("Unregistered redirect_uri (%q).", redirectURI) | 		description := fmt.Sprintf("Unregistered redirect_uri (%q).", redirectURI) | ||||||
| 		return req, &authErr{"", "", errInvalidRequest, description} | 		return req, &authErr{"", "", errInvalidRequest, description} | ||||||
| @@ -509,6 +521,7 @@ func (s *Server) parseAuthorizationRequest(r *http.Request) (req storage.AuthReq | |||||||
| 		Scopes:              scopes, | 		Scopes:              scopes, | ||||||
| 		RedirectURI:         redirectURI, | 		RedirectURI:         redirectURI, | ||||||
| 		ResponseTypes:       responseTypes, | 		ResponseTypes:       responseTypes, | ||||||
|  | 		ConnectorID:         connectorID, | ||||||
| 	}, nil | 	}, nil | ||||||
| } | } | ||||||
|  |  | ||||||
| @@ -568,6 +581,15 @@ func validateRedirectURI(client storage.Client, redirectURI string) bool { | |||||||
| 	return err == nil && host == "localhost" | 	return err == nil && host == "localhost" | ||||||
| } | } | ||||||
|  |  | ||||||
|  | func validateConnectorID(connectors []storage.Connector, connectorID string) bool { | ||||||
|  | 	for _, c := range connectors { | ||||||
|  | 		if c.ID == connectorID { | ||||||
|  | 			return true | ||||||
|  | 		} | ||||||
|  | 	} | ||||||
|  | 	return false | ||||||
|  | } | ||||||
|  |  | ||||||
| // storageKeySet implements the oidc.KeySet interface backed by Dex storage | // storageKeySet implements the oidc.KeySet interface backed by Dex storage | ||||||
| type storageKeySet struct { | type storageKeySet struct { | ||||||
| 	storage.Storage | 	storage.Storage | ||||||
|   | |||||||
| @@ -10,7 +10,7 @@ import ( | |||||||
| 	"strings" | 	"strings" | ||||||
| 	"testing" | 	"testing" | ||||||
|  |  | ||||||
| 	jose "gopkg.in/square/go-jose.v2" | 	"gopkg.in/square/go-jose.v2" | ||||||
|  |  | ||||||
| 	"github.com/dexidp/dex/storage" | 	"github.com/dexidp/dex/storage" | ||||||
| 	"github.com/dexidp/dex/storage/memory" | 	"github.com/dexidp/dex/storage/memory" | ||||||
| @@ -145,6 +145,58 @@ func TestParseAuthorizationRequest(t *testing.T) { | |||||||
| 			}, | 			}, | ||||||
| 			wantErr: true, | 			wantErr: true, | ||||||
| 		}, | 		}, | ||||||
|  | 		{ | ||||||
|  | 			name: "choose connector_id", | ||||||
|  | 			clients: []storage.Client{ | ||||||
|  | 				{ | ||||||
|  | 					ID:           "bar", | ||||||
|  | 					RedirectURIs: []string{"https://example.com/bar"}, | ||||||
|  | 				}, | ||||||
|  | 			}, | ||||||
|  | 			supportedResponseTypes: []string{"code", "id_token", "token"}, | ||||||
|  | 			queryParams: map[string]string{ | ||||||
|  | 				"connector_id":  "mock", | ||||||
|  | 				"client_id":     "bar", | ||||||
|  | 				"redirect_uri":  "https://example.com/bar", | ||||||
|  | 				"response_type": "code id_token", | ||||||
|  | 				"scope":         "openid email profile", | ||||||
|  | 			}, | ||||||
|  | 		}, | ||||||
|  | 		{ | ||||||
|  | 			name: "choose second connector_id", | ||||||
|  | 			clients: []storage.Client{ | ||||||
|  | 				{ | ||||||
|  | 					ID:           "bar", | ||||||
|  | 					RedirectURIs: []string{"https://example.com/bar"}, | ||||||
|  | 				}, | ||||||
|  | 			}, | ||||||
|  | 			supportedResponseTypes: []string{"code", "id_token", "token"}, | ||||||
|  | 			queryParams: map[string]string{ | ||||||
|  | 				"connector_id":  "mock2", | ||||||
|  | 				"client_id":     "bar", | ||||||
|  | 				"redirect_uri":  "https://example.com/bar", | ||||||
|  | 				"response_type": "code id_token", | ||||||
|  | 				"scope":         "openid email profile", | ||||||
|  | 			}, | ||||||
|  | 		}, | ||||||
|  | 		{ | ||||||
|  | 			name: "choose invalid connector_id", | ||||||
|  | 			clients: []storage.Client{ | ||||||
|  | 				{ | ||||||
|  | 					ID:           "bar", | ||||||
|  | 					RedirectURIs: []string{"https://example.com/bar"}, | ||||||
|  | 				}, | ||||||
|  | 			}, | ||||||
|  | 			supportedResponseTypes: []string{"code", "id_token", "token"}, | ||||||
|  | 			queryParams: map[string]string{ | ||||||
|  | 				"connector_id":  "bogus", | ||||||
|  | 				"client_id":     "bar", | ||||||
|  | 				"redirect_uri":  "https://example.com/bar", | ||||||
|  | 				"response_type": "code id_token", | ||||||
|  | 				"scope":         "openid email profile", | ||||||
|  | 			}, | ||||||
|  | 			wantErr: true, | ||||||
|  | 		}, | ||||||
| 	} | 	} | ||||||
|  |  | ||||||
| 	for _, tc := range tests { | 	for _, tc := range tests { | ||||||
| @@ -152,7 +204,7 @@ func TestParseAuthorizationRequest(t *testing.T) { | |||||||
| 			ctx, cancel := context.WithCancel(context.Background()) | 			ctx, cancel := context.WithCancel(context.Background()) | ||||||
| 			defer cancel() | 			defer cancel() | ||||||
|  |  | ||||||
| 			httpServer, server := newTestServer(ctx, t, func(c *Config) { | 			httpServer, server := newTestServerMultipleConnectors(ctx, t, func(c *Config) { | ||||||
| 				c.SupportedResponseTypes = tc.supportedResponseTypes | 				c.SupportedResponseTypes = tc.supportedResponseTypes | ||||||
| 				c.Storage = storage.WithStaticClients(c.Storage, tc.clients) | 				c.Storage = storage.WithStaticClients(c.Storage, tc.clients) | ||||||
| 			}) | 			}) | ||||||
| @@ -162,7 +214,6 @@ func TestParseAuthorizationRequest(t *testing.T) { | |||||||
| 			for k, v := range tc.queryParams { | 			for k, v := range tc.queryParams { | ||||||
| 				params.Set(k, v) | 				params.Set(k, v) | ||||||
| 			} | 			} | ||||||
|  |  | ||||||
| 			var req *http.Request | 			var req *http.Request | ||||||
| 			if tc.usePOST { | 			if tc.usePOST { | ||||||
| 				body := strings.NewReader(params.Encode()) | 				body := strings.NewReader(params.Encode()) | ||||||
|   | |||||||
| @@ -117,6 +117,53 @@ func newTestServer(ctx context.Context, t *testing.T, updateConfig func(c *Confi | |||||||
| 	return s, server | 	return s, server | ||||||
| } | } | ||||||
|  |  | ||||||
|  | func newTestServerMultipleConnectors(ctx context.Context, t *testing.T, updateConfig func(c *Config)) (*httptest.Server, *Server) { | ||||||
|  | 	var server *Server | ||||||
|  | 	s := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { | ||||||
|  | 		server.ServeHTTP(w, r) | ||||||
|  | 	})) | ||||||
|  |  | ||||||
|  | 	config := Config{ | ||||||
|  | 		Issuer:  s.URL, | ||||||
|  | 		Storage: memory.New(logger), | ||||||
|  | 		Web: WebConfig{ | ||||||
|  | 			Dir: "../web", | ||||||
|  | 		}, | ||||||
|  | 		Logger:             logger, | ||||||
|  | 		PrometheusRegistry: prometheus.NewRegistry(), | ||||||
|  | 	} | ||||||
|  | 	if updateConfig != nil { | ||||||
|  | 		updateConfig(&config) | ||||||
|  | 	} | ||||||
|  | 	s.URL = config.Issuer | ||||||
|  |  | ||||||
|  | 	connector := storage.Connector{ | ||||||
|  | 		ID:              "mock", | ||||||
|  | 		Type:            "mockCallback", | ||||||
|  | 		Name:            "Mock", | ||||||
|  | 		ResourceVersion: "1", | ||||||
|  | 	} | ||||||
|  | 	connector2 := storage.Connector{ | ||||||
|  | 		ID:              "mock2", | ||||||
|  | 		Type:            "mockCallback", | ||||||
|  | 		Name:            "Mock", | ||||||
|  | 		ResourceVersion: "1", | ||||||
|  | 	} | ||||||
|  | 	if err := config.Storage.CreateConnector(connector); err != nil { | ||||||
|  | 		t.Fatalf("create connector: %v", err) | ||||||
|  | 	} | ||||||
|  | 	if err := config.Storage.CreateConnector(connector2); err != nil { | ||||||
|  | 		t.Fatalf("create connector: %v", err) | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	var err error | ||||||
|  | 	if server, err = newServer(ctx, config, staticRotationStrategy(testKey)); err != nil { | ||||||
|  | 		t.Fatal(err) | ||||||
|  | 	} | ||||||
|  | 	server.skipApproval = true // Don't prompt for approval, just immediately redirect with code. | ||||||
|  | 	return s, server | ||||||
|  | } | ||||||
|  |  | ||||||
| func TestNewTestServer(t *testing.T) { | func TestNewTestServer(t *testing.T) { | ||||||
| 	ctx, cancel := context.WithCancel(context.Background()) | 	ctx, cancel := context.WithCancel(context.Background()) | ||||||
| 	defer cancel() | 	defer cancel() | ||||||
|   | |||||||
		Reference in New Issue
	
	Block a user