Merge pull request #1481 from LanceH/master
Added "connector_id" to skip straight to a connector (similar to when len(connector) is 1.
This commit is contained in:
		@@ -233,6 +233,18 @@ func (s *Server) handleAuthorization(w http.ResponseWriter, r *http.Request) {
 | 
			
		||||
		return
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	// Redirect if a client chooses a specific connector_id
 | 
			
		||||
	if authReq.ConnectorID != "" {
 | 
			
		||||
		for _, c := range connectors {
 | 
			
		||||
			if c.ID == authReq.ConnectorID {
 | 
			
		||||
				http.Redirect(w, r, s.absPath("/auth", c.ID)+"?req="+authReq.ID, http.StatusFound)
 | 
			
		||||
				return
 | 
			
		||||
			}
 | 
			
		||||
		}
 | 
			
		||||
		s.tokenErrHelper(w, errInvalidConnectorID, "Connector ID does not match a valid Connector", http.StatusNotFound)
 | 
			
		||||
		return
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	if len(connectors) == 1 {
 | 
			
		||||
		for _, c := range connectors {
 | 
			
		||||
			// TODO(ericchiang): Make this pass on r.URL.RawQuery and let something latter
 | 
			
		||||
 
 | 
			
		||||
@@ -100,6 +100,7 @@ const (
 | 
			
		||||
	errUnsupportedGrantType    = "unsupported_grant_type"
 | 
			
		||||
	errInvalidGrant            = "invalid_grant"
 | 
			
		||||
	errInvalidClient           = "invalid_client"
 | 
			
		||||
	errInvalidConnectorID      = "invalid_connector_id"
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
const (
 | 
			
		||||
@@ -391,6 +392,7 @@ func (s *Server) parseAuthorizationRequest(r *http.Request) (req storage.AuthReq
 | 
			
		||||
	clientID := q.Get("client_id")
 | 
			
		||||
	state := q.Get("state")
 | 
			
		||||
	nonce := q.Get("nonce")
 | 
			
		||||
	connectorID := q.Get("connector_id")
 | 
			
		||||
	// Some clients, like the old go-oidc, provide extra whitespace. Tolerate this.
 | 
			
		||||
	scopes := strings.Fields(q.Get("scope"))
 | 
			
		||||
	responseTypes := strings.Fields(q.Get("response_type"))
 | 
			
		||||
@@ -405,6 +407,16 @@ func (s *Server) parseAuthorizationRequest(r *http.Request) (req storage.AuthReq
 | 
			
		||||
		return req, &authErr{"", "", errServerError, ""}
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	if connectorID != "" {
 | 
			
		||||
		connectors, err := s.storage.ListConnectors()
 | 
			
		||||
		if err != nil {
 | 
			
		||||
			return req, &authErr{"", "", errServerError, "Unable to retrieve connectors"}
 | 
			
		||||
		}
 | 
			
		||||
		if !validateConnectorID(connectors, connectorID) {
 | 
			
		||||
			return req, &authErr{"", "", errInvalidRequest, "Invalid ConnectorID"}
 | 
			
		||||
		}
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	if !validateRedirectURI(client, redirectURI) {
 | 
			
		||||
		description := fmt.Sprintf("Unregistered redirect_uri (%q).", redirectURI)
 | 
			
		||||
		return req, &authErr{"", "", errInvalidRequest, description}
 | 
			
		||||
@@ -509,6 +521,7 @@ func (s *Server) parseAuthorizationRequest(r *http.Request) (req storage.AuthReq
 | 
			
		||||
		Scopes:              scopes,
 | 
			
		||||
		RedirectURI:         redirectURI,
 | 
			
		||||
		ResponseTypes:       responseTypes,
 | 
			
		||||
		ConnectorID:         connectorID,
 | 
			
		||||
	}, nil
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
@@ -568,6 +581,15 @@ func validateRedirectURI(client storage.Client, redirectURI string) bool {
 | 
			
		||||
	return err == nil && host == "localhost"
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func validateConnectorID(connectors []storage.Connector, connectorID string) bool {
 | 
			
		||||
	for _, c := range connectors {
 | 
			
		||||
		if c.ID == connectorID {
 | 
			
		||||
			return true
 | 
			
		||||
		}
 | 
			
		||||
	}
 | 
			
		||||
	return false
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// storageKeySet implements the oidc.KeySet interface backed by Dex storage
 | 
			
		||||
type storageKeySet struct {
 | 
			
		||||
	storage.Storage
 | 
			
		||||
 
 | 
			
		||||
@@ -10,7 +10,7 @@ import (
 | 
			
		||||
	"strings"
 | 
			
		||||
	"testing"
 | 
			
		||||
 | 
			
		||||
	jose "gopkg.in/square/go-jose.v2"
 | 
			
		||||
	"gopkg.in/square/go-jose.v2"
 | 
			
		||||
 | 
			
		||||
	"github.com/dexidp/dex/storage"
 | 
			
		||||
	"github.com/dexidp/dex/storage/memory"
 | 
			
		||||
@@ -145,6 +145,58 @@ func TestParseAuthorizationRequest(t *testing.T) {
 | 
			
		||||
			},
 | 
			
		||||
			wantErr: true,
 | 
			
		||||
		},
 | 
			
		||||
		{
 | 
			
		||||
			name: "choose connector_id",
 | 
			
		||||
			clients: []storage.Client{
 | 
			
		||||
				{
 | 
			
		||||
					ID:           "bar",
 | 
			
		||||
					RedirectURIs: []string{"https://example.com/bar"},
 | 
			
		||||
				},
 | 
			
		||||
			},
 | 
			
		||||
			supportedResponseTypes: []string{"code", "id_token", "token"},
 | 
			
		||||
			queryParams: map[string]string{
 | 
			
		||||
				"connector_id":  "mock",
 | 
			
		||||
				"client_id":     "bar",
 | 
			
		||||
				"redirect_uri":  "https://example.com/bar",
 | 
			
		||||
				"response_type": "code id_token",
 | 
			
		||||
				"scope":         "openid email profile",
 | 
			
		||||
			},
 | 
			
		||||
		},
 | 
			
		||||
		{
 | 
			
		||||
			name: "choose second connector_id",
 | 
			
		||||
			clients: []storage.Client{
 | 
			
		||||
				{
 | 
			
		||||
					ID:           "bar",
 | 
			
		||||
					RedirectURIs: []string{"https://example.com/bar"},
 | 
			
		||||
				},
 | 
			
		||||
			},
 | 
			
		||||
			supportedResponseTypes: []string{"code", "id_token", "token"},
 | 
			
		||||
			queryParams: map[string]string{
 | 
			
		||||
				"connector_id":  "mock2",
 | 
			
		||||
				"client_id":     "bar",
 | 
			
		||||
				"redirect_uri":  "https://example.com/bar",
 | 
			
		||||
				"response_type": "code id_token",
 | 
			
		||||
				"scope":         "openid email profile",
 | 
			
		||||
			},
 | 
			
		||||
		},
 | 
			
		||||
		{
 | 
			
		||||
			name: "choose invalid connector_id",
 | 
			
		||||
			clients: []storage.Client{
 | 
			
		||||
				{
 | 
			
		||||
					ID:           "bar",
 | 
			
		||||
					RedirectURIs: []string{"https://example.com/bar"},
 | 
			
		||||
				},
 | 
			
		||||
			},
 | 
			
		||||
			supportedResponseTypes: []string{"code", "id_token", "token"},
 | 
			
		||||
			queryParams: map[string]string{
 | 
			
		||||
				"connector_id":  "bogus",
 | 
			
		||||
				"client_id":     "bar",
 | 
			
		||||
				"redirect_uri":  "https://example.com/bar",
 | 
			
		||||
				"response_type": "code id_token",
 | 
			
		||||
				"scope":         "openid email profile",
 | 
			
		||||
			},
 | 
			
		||||
			wantErr: true,
 | 
			
		||||
		},
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	for _, tc := range tests {
 | 
			
		||||
@@ -152,7 +204,7 @@ func TestParseAuthorizationRequest(t *testing.T) {
 | 
			
		||||
			ctx, cancel := context.WithCancel(context.Background())
 | 
			
		||||
			defer cancel()
 | 
			
		||||
 | 
			
		||||
			httpServer, server := newTestServer(ctx, t, func(c *Config) {
 | 
			
		||||
			httpServer, server := newTestServerMultipleConnectors(ctx, t, func(c *Config) {
 | 
			
		||||
				c.SupportedResponseTypes = tc.supportedResponseTypes
 | 
			
		||||
				c.Storage = storage.WithStaticClients(c.Storage, tc.clients)
 | 
			
		||||
			})
 | 
			
		||||
@@ -162,7 +214,6 @@ func TestParseAuthorizationRequest(t *testing.T) {
 | 
			
		||||
			for k, v := range tc.queryParams {
 | 
			
		||||
				params.Set(k, v)
 | 
			
		||||
			}
 | 
			
		||||
 | 
			
		||||
			var req *http.Request
 | 
			
		||||
			if tc.usePOST {
 | 
			
		||||
				body := strings.NewReader(params.Encode())
 | 
			
		||||
 
 | 
			
		||||
@@ -117,6 +117,53 @@ func newTestServer(ctx context.Context, t *testing.T, updateConfig func(c *Confi
 | 
			
		||||
	return s, server
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func newTestServerMultipleConnectors(ctx context.Context, t *testing.T, updateConfig func(c *Config)) (*httptest.Server, *Server) {
 | 
			
		||||
	var server *Server
 | 
			
		||||
	s := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
 | 
			
		||||
		server.ServeHTTP(w, r)
 | 
			
		||||
	}))
 | 
			
		||||
 | 
			
		||||
	config := Config{
 | 
			
		||||
		Issuer:  s.URL,
 | 
			
		||||
		Storage: memory.New(logger),
 | 
			
		||||
		Web: WebConfig{
 | 
			
		||||
			Dir: "../web",
 | 
			
		||||
		},
 | 
			
		||||
		Logger:             logger,
 | 
			
		||||
		PrometheusRegistry: prometheus.NewRegistry(),
 | 
			
		||||
	}
 | 
			
		||||
	if updateConfig != nil {
 | 
			
		||||
		updateConfig(&config)
 | 
			
		||||
	}
 | 
			
		||||
	s.URL = config.Issuer
 | 
			
		||||
 | 
			
		||||
	connector := storage.Connector{
 | 
			
		||||
		ID:              "mock",
 | 
			
		||||
		Type:            "mockCallback",
 | 
			
		||||
		Name:            "Mock",
 | 
			
		||||
		ResourceVersion: "1",
 | 
			
		||||
	}
 | 
			
		||||
	connector2 := storage.Connector{
 | 
			
		||||
		ID:              "mock2",
 | 
			
		||||
		Type:            "mockCallback",
 | 
			
		||||
		Name:            "Mock",
 | 
			
		||||
		ResourceVersion: "1",
 | 
			
		||||
	}
 | 
			
		||||
	if err := config.Storage.CreateConnector(connector); err != nil {
 | 
			
		||||
		t.Fatalf("create connector: %v", err)
 | 
			
		||||
	}
 | 
			
		||||
	if err := config.Storage.CreateConnector(connector2); err != nil {
 | 
			
		||||
		t.Fatalf("create connector: %v", err)
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	var err error
 | 
			
		||||
	if server, err = newServer(ctx, config, staticRotationStrategy(testKey)); err != nil {
 | 
			
		||||
		t.Fatal(err)
 | 
			
		||||
	}
 | 
			
		||||
	server.skipApproval = true // Don't prompt for approval, just immediately redirect with code.
 | 
			
		||||
	return s, server
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func TestNewTestServer(t *testing.T) {
 | 
			
		||||
	ctx, cancel := context.WithCancel(context.Background())
 | 
			
		||||
	defer cancel()
 | 
			
		||||
 
 | 
			
		||||
		Reference in New Issue
	
	Block a user