diff --git a/cmd/dex/config.go b/cmd/dex/config.go index a2a653a7..40a4fe61 100644 --- a/cmd/dex/config.go +++ b/cmd/dex/config.go @@ -98,10 +98,11 @@ type OAuth2 struct { // Web is the config format for the HTTP server. type Web struct { - HTTP string `json:"http"` - HTTPS string `json:"https"` - TLSCert string `json:"tlsCert"` - TLSKey string `json:"tlsKey"` + HTTP string `json:"http"` + HTTPS string `json:"https"` + TLSCert string `json:"tlsCert"` + TLSKey string `json:"tlsKey"` + DiscoveryAllowedOrigins []string `json:"discoveryAllowedOrigins"` } // GRPC is the config for the gRPC API. diff --git a/cmd/dex/serve.go b/cmd/dex/serve.go index 640b122a..2dbc7fe0 100644 --- a/cmd/dex/serve.go +++ b/cmd/dex/serve.go @@ -179,20 +179,24 @@ func serve(cmd *cobra.Command, args []string) error { if c.OAuth2.SkipApprovalScreen { logger.Infof("config skipping approval screen") } + if len(c.Web.DiscoveryAllowedOrigins) > 0 { + logger.Infof("config discovery allowed origins: %s", c.Web.DiscoveryAllowedOrigins) + } // explicitly convert to UTC. now := func() time.Time { return time.Now().UTC() } serverConfig := server.Config{ - SupportedResponseTypes: c.OAuth2.ResponseTypes, - SkipApprovalScreen: c.OAuth2.SkipApprovalScreen, - Issuer: c.Issuer, - Connectors: connectors, - Storage: s, - Web: c.Frontend, - EnablePasswordDB: c.EnablePasswordDB, - Logger: logger, - Now: now, + SupportedResponseTypes: c.OAuth2.ResponseTypes, + SkipApprovalScreen: c.OAuth2.SkipApprovalScreen, + DiscoveryAllowedOrigins: c.Web.DiscoveryAllowedOrigins, + Issuer: c.Issuer, + Connectors: connectors, + Storage: s, + Web: c.Frontend, + EnablePasswordDB: c.EnablePasswordDB, + Logger: logger, + Now: now, } if c.Expiry.SigningKeys != "" { signingKeys, err := time.ParseDuration(c.Expiry.SigningKeys) diff --git a/glide.lock b/glide.lock index a4741bda..95a9ab21 100644 --- a/glide.lock +++ b/glide.lock @@ -18,6 +18,8 @@ imports: - protoc-gen-go - name: github.com/gorilla/context version: aed02d124ae4a0e94fea4541c8effd05bf0c8296 +- name: github.com/gorilla/handlers + version: 3a5767ca75ece5f7f1440b1d16975247f8d8b221 - name: github.com/gorilla/mux version: 9fa818a44c2bf1396a17f9d5a3c0f6dd39d2ff8e - name: github.com/gtank/cryptopasta diff --git a/glide.yaml b/glide.yaml index 650bd8bf..05d6ec42 100644 --- a/glide.yaml +++ b/glide.yaml @@ -53,6 +53,8 @@ import: version: 9fa818a44c2bf1396a17f9d5a3c0f6dd39d2ff8e - package: github.com/gorilla/context version: aed02d124ae4a0e94fea4541c8effd05bf0c8296 +- package: github.com/gorilla/handlers + version: 3a5767ca75ece5f7f1440b1d16975247f8d8b221 # Package with a bunch of sane crypto defaults. Consider just # copy the code (as recommended by the repo itself) instead diff --git a/server/handlers.go b/server/handlers.go index 832c262b..c962265f 100644 --- a/server/handlers.go +++ b/server/handlers.go @@ -11,6 +11,7 @@ import ( "strings" "time" + "github.com/gorilla/handlers" "github.com/gorilla/mux" jose "gopkg.in/square/go-jose.v2" @@ -101,7 +102,7 @@ type discovery struct { Claims []string `json:"claims_supported"` } -func (s *Server) discoveryHandler() (http.HandlerFunc, error) { +func (s *Server) discoveryHandler() (http.Handler, error) { d := discovery{ Issuer: s.issuerURL.String(), Auth: s.absURL("/auth"), @@ -127,11 +128,18 @@ func (s *Server) discoveryHandler() (http.HandlerFunc, error) { return nil, fmt.Errorf("failed to marshal discovery data: %v", err) } - return func(w http.ResponseWriter, r *http.Request) { + var discoveryHandler http.Handler + discoveryHandler = http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { w.Header().Set("Content-Type", "application/json") w.Header().Set("Content-Length", strconv.Itoa(len(data))) w.Write(data) - }, nil + }) + if len(s.discoveryAllowedOrigins) > 0 { + corsOption := handlers.AllowedOrigins(s.discoveryAllowedOrigins) + discoveryHandler = handlers.CORS(corsOption)(discoveryHandler) + } + + return discoveryHandler, nil } // handleAuthorization handles the OAuth2 auth endpoint. diff --git a/server/handlers_test.go b/server/handlers_test.go index 233af279..6470a54c 100644 --- a/server/handlers_test.go +++ b/server/handlers_test.go @@ -22,3 +22,61 @@ func TestHandleHealth(t *testing.T) { } } + +var discoveryHandlerCORSTests = []struct { + DiscoveryAllowedOrigins []string + Origin string + ResponseAllowOrigin string //The expected response: same as Origin in case of valid CORS flow +}{ + {nil, "http://foo.example", ""}, //Default behavior: cross origin requests not allowed + {[]string{}, "http://foo.example", ""}, + {[]string{"http://foo.example"}, "http://foo.example", "http://foo.example"}, + {[]string{"http://bar.example", "http://foo.example"}, "http://foo.example", "http://foo.example"}, + {[]string{"*"}, "http://foo.example", "http://foo.example"}, + {[]string{"http://bar.example"}, "http://foo.example", ""}, +} + +func TestDiscoveryHandlerCORS(t *testing.T) { + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + for _, testcase := range discoveryHandlerCORSTests { + + httpServer, server := newTestServer(ctx, t, func(c *Config) { + c.DiscoveryAllowedOrigins = testcase.DiscoveryAllowedOrigins + }) + defer httpServer.Close() + + discoveryHandler, err := server.discoveryHandler() + if err != nil { + t.Fatalf("failed to get discovery handler: %v", err) + } + + //Perform preflight request + rrPreflight := httptest.NewRecorder() + reqPreflight := httptest.NewRequest("OPTIONS", "/.well-kown/openid-configuration", nil) + reqPreflight.Header.Set("Origin", testcase.Origin) + reqPreflight.Header.Set("Access-Control-Request-Method", "GET") + discoveryHandler.ServeHTTP(rrPreflight, reqPreflight) + if rrPreflight.Code != http.StatusOK { + t.Errorf("expected 200 got %d", rrPreflight.Code) + } + headerAccessControlPreflight := rrPreflight.HeaderMap.Get("Access-Control-Allow-Origin") + if headerAccessControlPreflight != testcase.ResponseAllowOrigin { + t.Errorf("expected '%s' got '%s'", testcase.ResponseAllowOrigin, headerAccessControlPreflight) + } + + //Perform request + rr := httptest.NewRecorder() + req := httptest.NewRequest("GET", "/.well-kown/openid-configuration", nil) + req.Header.Set("Origin", testcase.Origin) + discoveryHandler.ServeHTTP(rr, req) + if rr.Code != http.StatusOK { + t.Errorf("expected 200 got %d", rr.Code) + } + headerAccessControl := rr.HeaderMap.Get("Access-Control-Allow-Origin") + if headerAccessControl != testcase.ResponseAllowOrigin { + t.Errorf("expected '%s' got '%s'", testcase.ResponseAllowOrigin, headerAccessControl) + } + } +} diff --git a/server/server.go b/server/server.go index 0c292d13..535e21be 100644 --- a/server/server.go +++ b/server/server.go @@ -42,6 +42,11 @@ type Config struct { // flow. If no response types are supplied this value defaults to "code". SupportedResponseTypes []string + // List of allowed origins for CORS requests on discovery endpoint. + // If none are indicated, CORS requests are disabled. Passing in "*" will allow any + // domain. + DiscoveryAllowedOrigins []string + // If enabled, the server won't prompt the user to approve authorization requests. // Logging in implies approval. SkipApprovalScreen bool @@ -111,6 +116,8 @@ type Server struct { supportedResponseTypes map[string]bool + discoveryAllowedOrigins []string + now func() time.Time idTokensValidFor time.Duration @@ -178,15 +185,16 @@ func newServer(ctx context.Context, c Config, rotationStrategy rotationStrategy) } s := &Server{ - issuerURL: *issuerURL, - connectors: make(map[string]Connector), - storage: newKeyCacher(c.Storage, now), - supportedResponseTypes: supported, - idTokensValidFor: value(c.IDTokensValidFor, 24*time.Hour), - skipApproval: c.SkipApprovalScreen, - now: now, - templates: tmpls, - logger: c.Logger, + issuerURL: *issuerURL, + connectors: make(map[string]Connector), + storage: newKeyCacher(c.Storage, now), + supportedResponseTypes: supported, + discoveryAllowedOrigins: c.DiscoveryAllowedOrigins, + idTokensValidFor: value(c.IDTokensValidFor, 24*time.Hour), + skipApproval: c.SkipApprovalScreen, + now: now, + templates: tmpls, + logger: c.Logger, } for _, conn := range c.Connectors { @@ -197,6 +205,9 @@ func newServer(ctx context.Context, c Config, rotationStrategy rotationStrategy) handleFunc := func(p string, h http.HandlerFunc) { r.HandleFunc(path.Join(issuerURL.Path, p), h) } + handle := func(p string, h http.Handler) { + r.Handle(path.Join(issuerURL.Path, p), h) + } handlePrefix := func(p string, h http.Handler) { prefix := path.Join(issuerURL.Path, p) r.PathPrefix(prefix).Handler(http.StripPrefix(prefix, h)) @@ -207,7 +218,7 @@ func newServer(ctx context.Context, c Config, rotationStrategy rotationStrategy) if err != nil { return nil, err } - handleFunc("/.well-known/openid-configuration", discoveryHandler) + handle("/.well-known/openid-configuration", discoveryHandler) // TODO(ericchiang): rate limit certain paths based on IP. handleFunc("/token", s.handleToken)