Allow CORS on discovery endpoint
This commit is contained in:
		| @@ -98,10 +98,11 @@ type OAuth2 struct { | ||||
|  | ||||
| // Web is the config format for the HTTP server. | ||||
| type Web struct { | ||||
| 	HTTP    string `json:"http"` | ||||
| 	HTTPS   string `json:"https"` | ||||
| 	TLSCert string `json:"tlsCert"` | ||||
| 	TLSKey  string `json:"tlsKey"` | ||||
| 	HTTP                    string   `json:"http"` | ||||
| 	HTTPS                   string   `json:"https"` | ||||
| 	TLSCert                 string   `json:"tlsCert"` | ||||
| 	TLSKey                  string   `json:"tlsKey"` | ||||
| 	DiscoveryAllowedOrigins []string `json:"discoveryAllowedOrigins"` | ||||
| } | ||||
|  | ||||
| // GRPC is the config for the gRPC API. | ||||
|   | ||||
| @@ -179,20 +179,24 @@ func serve(cmd *cobra.Command, args []string) error { | ||||
| 	if c.OAuth2.SkipApprovalScreen { | ||||
| 		logger.Infof("config skipping approval screen") | ||||
| 	} | ||||
| 	if len(c.Web.DiscoveryAllowedOrigins) > 0 { | ||||
| 		logger.Infof("config discovery allowed origins: %s", c.Web.DiscoveryAllowedOrigins) | ||||
| 	} | ||||
|  | ||||
| 	// explicitly convert to UTC. | ||||
| 	now := func() time.Time { return time.Now().UTC() } | ||||
|  | ||||
| 	serverConfig := server.Config{ | ||||
| 		SupportedResponseTypes: c.OAuth2.ResponseTypes, | ||||
| 		SkipApprovalScreen:     c.OAuth2.SkipApprovalScreen, | ||||
| 		Issuer:                 c.Issuer, | ||||
| 		Connectors:             connectors, | ||||
| 		Storage:                s, | ||||
| 		Web:                    c.Frontend, | ||||
| 		EnablePasswordDB:       c.EnablePasswordDB, | ||||
| 		Logger:                 logger, | ||||
| 		Now:                    now, | ||||
| 		SupportedResponseTypes:  c.OAuth2.ResponseTypes, | ||||
| 		SkipApprovalScreen:      c.OAuth2.SkipApprovalScreen, | ||||
| 		DiscoveryAllowedOrigins: c.Web.DiscoveryAllowedOrigins, | ||||
| 		Issuer:                  c.Issuer, | ||||
| 		Connectors:              connectors, | ||||
| 		Storage:                 s, | ||||
| 		Web:                     c.Frontend, | ||||
| 		EnablePasswordDB:        c.EnablePasswordDB, | ||||
| 		Logger:                  logger, | ||||
| 		Now:                     now, | ||||
| 	} | ||||
| 	if c.Expiry.SigningKeys != "" { | ||||
| 		signingKeys, err := time.ParseDuration(c.Expiry.SigningKeys) | ||||
|   | ||||
							
								
								
									
										2
									
								
								glide.lock
									
									
									
										generated
									
									
									
								
							
							
						
						
									
										2
									
								
								glide.lock
									
									
									
										generated
									
									
									
								
							| @@ -18,6 +18,8 @@ imports: | ||||
|   - protoc-gen-go | ||||
| - name: github.com/gorilla/context | ||||
|   version: aed02d124ae4a0e94fea4541c8effd05bf0c8296 | ||||
| - name: github.com/gorilla/handlers | ||||
|   version: 3a5767ca75ece5f7f1440b1d16975247f8d8b221 | ||||
| - name: github.com/gorilla/mux | ||||
|   version: 9fa818a44c2bf1396a17f9d5a3c0f6dd39d2ff8e | ||||
| - name: github.com/gtank/cryptopasta | ||||
|   | ||||
| @@ -53,6 +53,8 @@ import: | ||||
|   version: 9fa818a44c2bf1396a17f9d5a3c0f6dd39d2ff8e | ||||
| - package: github.com/gorilla/context | ||||
|   version: aed02d124ae4a0e94fea4541c8effd05bf0c8296 | ||||
| - package: github.com/gorilla/handlers | ||||
|   version: 3a5767ca75ece5f7f1440b1d16975247f8d8b221 | ||||
|  | ||||
| # Package with a bunch of sane crypto defaults. Consider just | ||||
| # copy the code (as recommended by the repo itself) instead | ||||
|   | ||||
| @@ -11,6 +11,7 @@ import ( | ||||
| 	"strings" | ||||
| 	"time" | ||||
|  | ||||
| 	"github.com/gorilla/handlers" | ||||
| 	"github.com/gorilla/mux" | ||||
| 	jose "gopkg.in/square/go-jose.v2" | ||||
|  | ||||
| @@ -101,7 +102,7 @@ type discovery struct { | ||||
| 	Claims        []string `json:"claims_supported"` | ||||
| } | ||||
|  | ||||
| func (s *Server) discoveryHandler() (http.HandlerFunc, error) { | ||||
| func (s *Server) discoveryHandler() (http.Handler, error) { | ||||
| 	d := discovery{ | ||||
| 		Issuer:      s.issuerURL.String(), | ||||
| 		Auth:        s.absURL("/auth"), | ||||
| @@ -127,11 +128,18 @@ func (s *Server) discoveryHandler() (http.HandlerFunc, error) { | ||||
| 		return nil, fmt.Errorf("failed to marshal discovery data: %v", err) | ||||
| 	} | ||||
|  | ||||
| 	return func(w http.ResponseWriter, r *http.Request) { | ||||
| 	var discoveryHandler http.Handler | ||||
| 	discoveryHandler = http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { | ||||
| 		w.Header().Set("Content-Type", "application/json") | ||||
| 		w.Header().Set("Content-Length", strconv.Itoa(len(data))) | ||||
| 		w.Write(data) | ||||
| 	}, nil | ||||
| 	}) | ||||
| 	if len(s.discoveryAllowedOrigins) > 0 { | ||||
| 		corsOption := handlers.AllowedOrigins(s.discoveryAllowedOrigins) | ||||
| 		discoveryHandler = handlers.CORS(corsOption)(discoveryHandler) | ||||
| 	} | ||||
|  | ||||
| 	return discoveryHandler, nil | ||||
| } | ||||
|  | ||||
| // handleAuthorization handles the OAuth2 auth endpoint. | ||||
|   | ||||
| @@ -22,3 +22,61 @@ func TestHandleHealth(t *testing.T) { | ||||
| 	} | ||||
|  | ||||
| } | ||||
|  | ||||
| var discoveryHandlerCORSTests = []struct { | ||||
| 	DiscoveryAllowedOrigins []string | ||||
| 	Origin                  string | ||||
| 	ResponseAllowOrigin     string //The expected response: same as Origin in case of valid CORS flow | ||||
| }{ | ||||
| 	{nil, "http://foo.example", ""}, //Default behavior: cross origin requests not allowed | ||||
| 	{[]string{}, "http://foo.example", ""}, | ||||
| 	{[]string{"http://foo.example"}, "http://foo.example", "http://foo.example"}, | ||||
| 	{[]string{"http://bar.example", "http://foo.example"}, "http://foo.example", "http://foo.example"}, | ||||
| 	{[]string{"*"}, "http://foo.example", "http://foo.example"}, | ||||
| 	{[]string{"http://bar.example"}, "http://foo.example", ""}, | ||||
| } | ||||
|  | ||||
| func TestDiscoveryHandlerCORS(t *testing.T) { | ||||
| 	ctx, cancel := context.WithCancel(context.Background()) | ||||
| 	defer cancel() | ||||
|  | ||||
| 	for _, testcase := range discoveryHandlerCORSTests { | ||||
|  | ||||
| 		httpServer, server := newTestServer(ctx, t, func(c *Config) { | ||||
| 			c.DiscoveryAllowedOrigins = testcase.DiscoveryAllowedOrigins | ||||
| 		}) | ||||
| 		defer httpServer.Close() | ||||
|  | ||||
| 		discoveryHandler, err := server.discoveryHandler() | ||||
| 		if err != nil { | ||||
| 			t.Fatalf("failed to get discovery handler: %v", err) | ||||
| 		} | ||||
|  | ||||
| 		//Perform preflight request | ||||
| 		rrPreflight := httptest.NewRecorder() | ||||
| 		reqPreflight := httptest.NewRequest("OPTIONS", "/.well-kown/openid-configuration", nil) | ||||
| 		reqPreflight.Header.Set("Origin", testcase.Origin) | ||||
| 		reqPreflight.Header.Set("Access-Control-Request-Method", "GET") | ||||
| 		discoveryHandler.ServeHTTP(rrPreflight, reqPreflight) | ||||
| 		if rrPreflight.Code != http.StatusOK { | ||||
| 			t.Errorf("expected 200 got %d", rrPreflight.Code) | ||||
| 		} | ||||
| 		headerAccessControlPreflight := rrPreflight.HeaderMap.Get("Access-Control-Allow-Origin") | ||||
| 		if headerAccessControlPreflight != testcase.ResponseAllowOrigin { | ||||
| 			t.Errorf("expected '%s' got '%s'", testcase.ResponseAllowOrigin, headerAccessControlPreflight) | ||||
| 		} | ||||
|  | ||||
| 		//Perform request | ||||
| 		rr := httptest.NewRecorder() | ||||
| 		req := httptest.NewRequest("GET", "/.well-kown/openid-configuration", nil) | ||||
| 		req.Header.Set("Origin", testcase.Origin) | ||||
| 		discoveryHandler.ServeHTTP(rr, req) | ||||
| 		if rr.Code != http.StatusOK { | ||||
| 			t.Errorf("expected 200 got %d", rr.Code) | ||||
| 		} | ||||
| 		headerAccessControl := rr.HeaderMap.Get("Access-Control-Allow-Origin") | ||||
| 		if headerAccessControl != testcase.ResponseAllowOrigin { | ||||
| 			t.Errorf("expected '%s' got '%s'", testcase.ResponseAllowOrigin, headerAccessControl) | ||||
| 		} | ||||
| 	} | ||||
| } | ||||
|   | ||||
| @@ -42,6 +42,11 @@ type Config struct { | ||||
| 	// flow. If no response types are supplied this value defaults to "code". | ||||
| 	SupportedResponseTypes []string | ||||
|  | ||||
| 	// List of allowed origins for CORS requests on discovery endpoint. | ||||
| 	// If none are indicated, CORS requests are disabled. Passing in "*" will allow any | ||||
| 	// domain. | ||||
| 	DiscoveryAllowedOrigins []string | ||||
|  | ||||
| 	// If enabled, the server won't prompt the user to approve authorization requests. | ||||
| 	// Logging in implies approval. | ||||
| 	SkipApprovalScreen bool | ||||
| @@ -111,6 +116,8 @@ type Server struct { | ||||
|  | ||||
| 	supportedResponseTypes map[string]bool | ||||
|  | ||||
| 	discoveryAllowedOrigins []string | ||||
|  | ||||
| 	now func() time.Time | ||||
|  | ||||
| 	idTokensValidFor time.Duration | ||||
| @@ -178,15 +185,16 @@ func newServer(ctx context.Context, c Config, rotationStrategy rotationStrategy) | ||||
| 	} | ||||
|  | ||||
| 	s := &Server{ | ||||
| 		issuerURL:              *issuerURL, | ||||
| 		connectors:             make(map[string]Connector), | ||||
| 		storage:                newKeyCacher(c.Storage, now), | ||||
| 		supportedResponseTypes: supported, | ||||
| 		idTokensValidFor:       value(c.IDTokensValidFor, 24*time.Hour), | ||||
| 		skipApproval:           c.SkipApprovalScreen, | ||||
| 		now:                    now, | ||||
| 		templates:              tmpls, | ||||
| 		logger:                 c.Logger, | ||||
| 		issuerURL:               *issuerURL, | ||||
| 		connectors:              make(map[string]Connector), | ||||
| 		storage:                 newKeyCacher(c.Storage, now), | ||||
| 		supportedResponseTypes:  supported, | ||||
| 		discoveryAllowedOrigins: c.DiscoveryAllowedOrigins, | ||||
| 		idTokensValidFor:        value(c.IDTokensValidFor, 24*time.Hour), | ||||
| 		skipApproval:            c.SkipApprovalScreen, | ||||
| 		now:                     now, | ||||
| 		templates:               tmpls, | ||||
| 		logger:                  c.Logger, | ||||
| 	} | ||||
|  | ||||
| 	for _, conn := range c.Connectors { | ||||
| @@ -197,6 +205,9 @@ func newServer(ctx context.Context, c Config, rotationStrategy rotationStrategy) | ||||
| 	handleFunc := func(p string, h http.HandlerFunc) { | ||||
| 		r.HandleFunc(path.Join(issuerURL.Path, p), h) | ||||
| 	} | ||||
| 	handle := func(p string, h http.Handler) { | ||||
| 		r.Handle(path.Join(issuerURL.Path, p), h) | ||||
| 	} | ||||
| 	handlePrefix := func(p string, h http.Handler) { | ||||
| 		prefix := path.Join(issuerURL.Path, p) | ||||
| 		r.PathPrefix(prefix).Handler(http.StripPrefix(prefix, h)) | ||||
| @@ -207,7 +218,7 @@ func newServer(ctx context.Context, c Config, rotationStrategy rotationStrategy) | ||||
| 	if err != nil { | ||||
| 		return nil, err | ||||
| 	} | ||||
| 	handleFunc("/.well-known/openid-configuration", discoveryHandler) | ||||
| 	handle("/.well-known/openid-configuration", discoveryHandler) | ||||
|  | ||||
| 	// TODO(ericchiang): rate limit certain paths based on IP. | ||||
| 	handleFunc("/token", s.handleToken) | ||||
|   | ||||
		Reference in New Issue
	
	Block a user