Use connector_id param to skip directly to a specific connector
This commit is contained in:
		@@ -233,6 +233,18 @@ func (s *Server) handleAuthorization(w http.ResponseWriter, r *http.Request) {
 | 
				
			|||||||
		return
 | 
							return
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						// Redirect if a client chooses a specific connector_id
 | 
				
			||||||
 | 
						if authReq.ConnectorID != "" {
 | 
				
			||||||
 | 
							for _, c := range connectors {
 | 
				
			||||||
 | 
								if c.ID == authReq.ConnectorID {
 | 
				
			||||||
 | 
									http.Redirect(w, r, s.absPath("/auth", c.ID)+"?req="+authReq.ID, http.StatusFound)
 | 
				
			||||||
 | 
									return
 | 
				
			||||||
 | 
								}
 | 
				
			||||||
 | 
							}
 | 
				
			||||||
 | 
							s.tokenErrHelper(w, errInvalidConnectorID, "Connector ID does not match a valid Connector", http.StatusNotFound)
 | 
				
			||||||
 | 
							return
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	if len(connectors) == 1 {
 | 
						if len(connectors) == 1 {
 | 
				
			||||||
		for _, c := range connectors {
 | 
							for _, c := range connectors {
 | 
				
			||||||
			// TODO(ericchiang): Make this pass on r.URL.RawQuery and let something latter
 | 
								// TODO(ericchiang): Make this pass on r.URL.RawQuery and let something latter
 | 
				
			||||||
 
 | 
				
			|||||||
@@ -100,6 +100,7 @@ const (
 | 
				
			|||||||
	errUnsupportedGrantType    = "unsupported_grant_type"
 | 
						errUnsupportedGrantType    = "unsupported_grant_type"
 | 
				
			||||||
	errInvalidGrant            = "invalid_grant"
 | 
						errInvalidGrant            = "invalid_grant"
 | 
				
			||||||
	errInvalidClient           = "invalid_client"
 | 
						errInvalidClient           = "invalid_client"
 | 
				
			||||||
 | 
						errInvalidConnectorID      = "invalid_connector_id"
 | 
				
			||||||
)
 | 
					)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
const (
 | 
					const (
 | 
				
			||||||
@@ -391,6 +392,7 @@ func (s *Server) parseAuthorizationRequest(r *http.Request) (req storage.AuthReq
 | 
				
			|||||||
	clientID := q.Get("client_id")
 | 
						clientID := q.Get("client_id")
 | 
				
			||||||
	state := q.Get("state")
 | 
						state := q.Get("state")
 | 
				
			||||||
	nonce := q.Get("nonce")
 | 
						nonce := q.Get("nonce")
 | 
				
			||||||
 | 
						connectorID := q.Get("connector_id")
 | 
				
			||||||
	// Some clients, like the old go-oidc, provide extra whitespace. Tolerate this.
 | 
						// Some clients, like the old go-oidc, provide extra whitespace. Tolerate this.
 | 
				
			||||||
	scopes := strings.Fields(q.Get("scope"))
 | 
						scopes := strings.Fields(q.Get("scope"))
 | 
				
			||||||
	responseTypes := strings.Fields(q.Get("response_type"))
 | 
						responseTypes := strings.Fields(q.Get("response_type"))
 | 
				
			||||||
@@ -405,6 +407,16 @@ func (s *Server) parseAuthorizationRequest(r *http.Request) (req storage.AuthReq
 | 
				
			|||||||
		return req, &authErr{"", "", errServerError, ""}
 | 
							return req, &authErr{"", "", errServerError, ""}
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						if connectorID != "" {
 | 
				
			||||||
 | 
							connectors, err := s.storage.ListConnectors()
 | 
				
			||||||
 | 
							if err != nil {
 | 
				
			||||||
 | 
								return req, &authErr{"", "", errServerError, "Unable to retrieve connectors"}
 | 
				
			||||||
 | 
							}
 | 
				
			||||||
 | 
							if !validateConnectorID(connectors, connectorID) {
 | 
				
			||||||
 | 
								return req, &authErr{"", "", errInvalidRequest, "Invalid ConnectorID"}
 | 
				
			||||||
 | 
							}
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	if !validateRedirectURI(client, redirectURI) {
 | 
						if !validateRedirectURI(client, redirectURI) {
 | 
				
			||||||
		description := fmt.Sprintf("Unregistered redirect_uri (%q).", redirectURI)
 | 
							description := fmt.Sprintf("Unregistered redirect_uri (%q).", redirectURI)
 | 
				
			||||||
		return req, &authErr{"", "", errInvalidRequest, description}
 | 
							return req, &authErr{"", "", errInvalidRequest, description}
 | 
				
			||||||
@@ -509,6 +521,7 @@ func (s *Server) parseAuthorizationRequest(r *http.Request) (req storage.AuthReq
 | 
				
			|||||||
		Scopes:              scopes,
 | 
							Scopes:              scopes,
 | 
				
			||||||
		RedirectURI:         redirectURI,
 | 
							RedirectURI:         redirectURI,
 | 
				
			||||||
		ResponseTypes:       responseTypes,
 | 
							ResponseTypes:       responseTypes,
 | 
				
			||||||
 | 
							ConnectorID:         connectorID,
 | 
				
			||||||
	}, nil
 | 
						}, nil
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
@@ -568,6 +581,15 @@ func validateRedirectURI(client storage.Client, redirectURI string) bool {
 | 
				
			|||||||
	return err == nil && host == "localhost"
 | 
						return err == nil && host == "localhost"
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					func validateConnectorID(connectors []storage.Connector, connectorID string) bool {
 | 
				
			||||||
 | 
						for _, c := range connectors {
 | 
				
			||||||
 | 
							if c.ID == connectorID {
 | 
				
			||||||
 | 
								return true
 | 
				
			||||||
 | 
							}
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
						return false
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
// storageKeySet implements the oidc.KeySet interface backed by Dex storage
 | 
					// storageKeySet implements the oidc.KeySet interface backed by Dex storage
 | 
				
			||||||
type storageKeySet struct {
 | 
					type storageKeySet struct {
 | 
				
			||||||
	storage.Storage
 | 
						storage.Storage
 | 
				
			||||||
 
 | 
				
			|||||||
@@ -10,7 +10,7 @@ import (
 | 
				
			|||||||
	"strings"
 | 
						"strings"
 | 
				
			||||||
	"testing"
 | 
						"testing"
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	jose "gopkg.in/square/go-jose.v2"
 | 
						"gopkg.in/square/go-jose.v2"
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	"github.com/dexidp/dex/storage"
 | 
						"github.com/dexidp/dex/storage"
 | 
				
			||||||
	"github.com/dexidp/dex/storage/memory"
 | 
						"github.com/dexidp/dex/storage/memory"
 | 
				
			||||||
@@ -145,6 +145,58 @@ func TestParseAuthorizationRequest(t *testing.T) {
 | 
				
			|||||||
			},
 | 
								},
 | 
				
			||||||
			wantErr: true,
 | 
								wantErr: true,
 | 
				
			||||||
		},
 | 
							},
 | 
				
			||||||
 | 
							{
 | 
				
			||||||
 | 
								name: "choose connector_id",
 | 
				
			||||||
 | 
								clients: []storage.Client{
 | 
				
			||||||
 | 
									{
 | 
				
			||||||
 | 
										ID:           "bar",
 | 
				
			||||||
 | 
										RedirectURIs: []string{"https://example.com/bar"},
 | 
				
			||||||
 | 
									},
 | 
				
			||||||
 | 
								},
 | 
				
			||||||
 | 
								supportedResponseTypes: []string{"code", "id_token", "token"},
 | 
				
			||||||
 | 
								queryParams: map[string]string{
 | 
				
			||||||
 | 
									"connector_id":  "mock",
 | 
				
			||||||
 | 
									"client_id":     "bar",
 | 
				
			||||||
 | 
									"redirect_uri":  "https://example.com/bar",
 | 
				
			||||||
 | 
									"response_type": "code id_token",
 | 
				
			||||||
 | 
									"scope":         "openid email profile",
 | 
				
			||||||
 | 
								},
 | 
				
			||||||
 | 
							},
 | 
				
			||||||
 | 
							{
 | 
				
			||||||
 | 
								name: "choose second connector_id",
 | 
				
			||||||
 | 
								clients: []storage.Client{
 | 
				
			||||||
 | 
									{
 | 
				
			||||||
 | 
										ID:           "bar",
 | 
				
			||||||
 | 
										RedirectURIs: []string{"https://example.com/bar"},
 | 
				
			||||||
 | 
									},
 | 
				
			||||||
 | 
								},
 | 
				
			||||||
 | 
								supportedResponseTypes: []string{"code", "id_token", "token"},
 | 
				
			||||||
 | 
								queryParams: map[string]string{
 | 
				
			||||||
 | 
									"connector_id":  "mock2",
 | 
				
			||||||
 | 
									"client_id":     "bar",
 | 
				
			||||||
 | 
									"redirect_uri":  "https://example.com/bar",
 | 
				
			||||||
 | 
									"response_type": "code id_token",
 | 
				
			||||||
 | 
									"scope":         "openid email profile",
 | 
				
			||||||
 | 
								},
 | 
				
			||||||
 | 
							},
 | 
				
			||||||
 | 
							{
 | 
				
			||||||
 | 
								name: "choose invalid connector_id",
 | 
				
			||||||
 | 
								clients: []storage.Client{
 | 
				
			||||||
 | 
									{
 | 
				
			||||||
 | 
										ID:           "bar",
 | 
				
			||||||
 | 
										RedirectURIs: []string{"https://example.com/bar"},
 | 
				
			||||||
 | 
									},
 | 
				
			||||||
 | 
								},
 | 
				
			||||||
 | 
								supportedResponseTypes: []string{"code", "id_token", "token"},
 | 
				
			||||||
 | 
								queryParams: map[string]string{
 | 
				
			||||||
 | 
									"connector_id":  "bogus",
 | 
				
			||||||
 | 
									"client_id":     "bar",
 | 
				
			||||||
 | 
									"redirect_uri":  "https://example.com/bar",
 | 
				
			||||||
 | 
									"response_type": "code id_token",
 | 
				
			||||||
 | 
									"scope":         "openid email profile",
 | 
				
			||||||
 | 
								},
 | 
				
			||||||
 | 
								wantErr: true,
 | 
				
			||||||
 | 
							},
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	for _, tc := range tests {
 | 
						for _, tc := range tests {
 | 
				
			||||||
@@ -152,7 +204,7 @@ func TestParseAuthorizationRequest(t *testing.T) {
 | 
				
			|||||||
			ctx, cancel := context.WithCancel(context.Background())
 | 
								ctx, cancel := context.WithCancel(context.Background())
 | 
				
			||||||
			defer cancel()
 | 
								defer cancel()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
			httpServer, server := newTestServer(ctx, t, func(c *Config) {
 | 
								httpServer, server := newTestServerMultipleConnectors(ctx, t, func(c *Config) {
 | 
				
			||||||
				c.SupportedResponseTypes = tc.supportedResponseTypes
 | 
									c.SupportedResponseTypes = tc.supportedResponseTypes
 | 
				
			||||||
				c.Storage = storage.WithStaticClients(c.Storage, tc.clients)
 | 
									c.Storage = storage.WithStaticClients(c.Storage, tc.clients)
 | 
				
			||||||
			})
 | 
								})
 | 
				
			||||||
@@ -162,7 +214,6 @@ func TestParseAuthorizationRequest(t *testing.T) {
 | 
				
			|||||||
			for k, v := range tc.queryParams {
 | 
								for k, v := range tc.queryParams {
 | 
				
			||||||
				params.Set(k, v)
 | 
									params.Set(k, v)
 | 
				
			||||||
			}
 | 
								}
 | 
				
			||||||
 | 
					 | 
				
			||||||
			var req *http.Request
 | 
								var req *http.Request
 | 
				
			||||||
			if tc.usePOST {
 | 
								if tc.usePOST {
 | 
				
			||||||
				body := strings.NewReader(params.Encode())
 | 
									body := strings.NewReader(params.Encode())
 | 
				
			||||||
 
 | 
				
			|||||||
@@ -117,6 +117,53 @@ func newTestServer(ctx context.Context, t *testing.T, updateConfig func(c *Confi
 | 
				
			|||||||
	return s, server
 | 
						return s, server
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					func newTestServerMultipleConnectors(ctx context.Context, t *testing.T, updateConfig func(c *Config)) (*httptest.Server, *Server) {
 | 
				
			||||||
 | 
						var server *Server
 | 
				
			||||||
 | 
						s := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
 | 
				
			||||||
 | 
							server.ServeHTTP(w, r)
 | 
				
			||||||
 | 
						}))
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						config := Config{
 | 
				
			||||||
 | 
							Issuer:  s.URL,
 | 
				
			||||||
 | 
							Storage: memory.New(logger),
 | 
				
			||||||
 | 
							Web: WebConfig{
 | 
				
			||||||
 | 
								Dir: "../web",
 | 
				
			||||||
 | 
							},
 | 
				
			||||||
 | 
							Logger:             logger,
 | 
				
			||||||
 | 
							PrometheusRegistry: prometheus.NewRegistry(),
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
						if updateConfig != nil {
 | 
				
			||||||
 | 
							updateConfig(&config)
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
						s.URL = config.Issuer
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						connector := storage.Connector{
 | 
				
			||||||
 | 
							ID:              "mock",
 | 
				
			||||||
 | 
							Type:            "mockCallback",
 | 
				
			||||||
 | 
							Name:            "Mock",
 | 
				
			||||||
 | 
							ResourceVersion: "1",
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
						connector2 := storage.Connector{
 | 
				
			||||||
 | 
							ID:              "mock2",
 | 
				
			||||||
 | 
							Type:            "mockCallback",
 | 
				
			||||||
 | 
							Name:            "Mock",
 | 
				
			||||||
 | 
							ResourceVersion: "1",
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
						if err := config.Storage.CreateConnector(connector); err != nil {
 | 
				
			||||||
 | 
							t.Fatalf("create connector: %v", err)
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
						if err := config.Storage.CreateConnector(connector2); err != nil {
 | 
				
			||||||
 | 
							t.Fatalf("create connector: %v", err)
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						var err error
 | 
				
			||||||
 | 
						if server, err = newServer(ctx, config, staticRotationStrategy(testKey)); err != nil {
 | 
				
			||||||
 | 
							t.Fatal(err)
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
						server.skipApproval = true // Don't prompt for approval, just immediately redirect with code.
 | 
				
			||||||
 | 
						return s, server
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
func TestNewTestServer(t *testing.T) {
 | 
					func TestNewTestServer(t *testing.T) {
 | 
				
			||||||
	ctx, cancel := context.WithCancel(context.Background())
 | 
						ctx, cancel := context.WithCancel(context.Background())
 | 
				
			||||||
	defer cancel()
 | 
						defer cancel()
 | 
				
			||||||
 
 | 
				
			|||||||
		Reference in New Issue
	
	Block a user