Allow CORS on discovery endpoint
This commit is contained in:
parent
4ddc5eb061
commit
b4c47910e4
@ -98,10 +98,11 @@ type OAuth2 struct {
|
|||||||
|
|
||||||
// Web is the config format for the HTTP server.
|
// Web is the config format for the HTTP server.
|
||||||
type Web struct {
|
type Web struct {
|
||||||
HTTP string `json:"http"`
|
HTTP string `json:"http"`
|
||||||
HTTPS string `json:"https"`
|
HTTPS string `json:"https"`
|
||||||
TLSCert string `json:"tlsCert"`
|
TLSCert string `json:"tlsCert"`
|
||||||
TLSKey string `json:"tlsKey"`
|
TLSKey string `json:"tlsKey"`
|
||||||
|
DiscoveryAllowedOrigins []string `json:"discoveryAllowedOrigins"`
|
||||||
}
|
}
|
||||||
|
|
||||||
// GRPC is the config for the gRPC API.
|
// GRPC is the config for the gRPC API.
|
||||||
|
@ -179,20 +179,24 @@ func serve(cmd *cobra.Command, args []string) error {
|
|||||||
if c.OAuth2.SkipApprovalScreen {
|
if c.OAuth2.SkipApprovalScreen {
|
||||||
logger.Infof("config skipping approval screen")
|
logger.Infof("config skipping approval screen")
|
||||||
}
|
}
|
||||||
|
if len(c.Web.DiscoveryAllowedOrigins) > 0 {
|
||||||
|
logger.Infof("config discovery allowed origins: %s", c.Web.DiscoveryAllowedOrigins)
|
||||||
|
}
|
||||||
|
|
||||||
// explicitly convert to UTC.
|
// explicitly convert to UTC.
|
||||||
now := func() time.Time { return time.Now().UTC() }
|
now := func() time.Time { return time.Now().UTC() }
|
||||||
|
|
||||||
serverConfig := server.Config{
|
serverConfig := server.Config{
|
||||||
SupportedResponseTypes: c.OAuth2.ResponseTypes,
|
SupportedResponseTypes: c.OAuth2.ResponseTypes,
|
||||||
SkipApprovalScreen: c.OAuth2.SkipApprovalScreen,
|
SkipApprovalScreen: c.OAuth2.SkipApprovalScreen,
|
||||||
Issuer: c.Issuer,
|
DiscoveryAllowedOrigins: c.Web.DiscoveryAllowedOrigins,
|
||||||
Connectors: connectors,
|
Issuer: c.Issuer,
|
||||||
Storage: s,
|
Connectors: connectors,
|
||||||
Web: c.Frontend,
|
Storage: s,
|
||||||
EnablePasswordDB: c.EnablePasswordDB,
|
Web: c.Frontend,
|
||||||
Logger: logger,
|
EnablePasswordDB: c.EnablePasswordDB,
|
||||||
Now: now,
|
Logger: logger,
|
||||||
|
Now: now,
|
||||||
}
|
}
|
||||||
if c.Expiry.SigningKeys != "" {
|
if c.Expiry.SigningKeys != "" {
|
||||||
signingKeys, err := time.ParseDuration(c.Expiry.SigningKeys)
|
signingKeys, err := time.ParseDuration(c.Expiry.SigningKeys)
|
||||||
|
2
glide.lock
generated
2
glide.lock
generated
@ -18,6 +18,8 @@ imports:
|
|||||||
- protoc-gen-go
|
- protoc-gen-go
|
||||||
- name: github.com/gorilla/context
|
- name: github.com/gorilla/context
|
||||||
version: aed02d124ae4a0e94fea4541c8effd05bf0c8296
|
version: aed02d124ae4a0e94fea4541c8effd05bf0c8296
|
||||||
|
- name: github.com/gorilla/handlers
|
||||||
|
version: 3a5767ca75ece5f7f1440b1d16975247f8d8b221
|
||||||
- name: github.com/gorilla/mux
|
- name: github.com/gorilla/mux
|
||||||
version: 9fa818a44c2bf1396a17f9d5a3c0f6dd39d2ff8e
|
version: 9fa818a44c2bf1396a17f9d5a3c0f6dd39d2ff8e
|
||||||
- name: github.com/gtank/cryptopasta
|
- name: github.com/gtank/cryptopasta
|
||||||
|
@ -53,6 +53,8 @@ import:
|
|||||||
version: 9fa818a44c2bf1396a17f9d5a3c0f6dd39d2ff8e
|
version: 9fa818a44c2bf1396a17f9d5a3c0f6dd39d2ff8e
|
||||||
- package: github.com/gorilla/context
|
- package: github.com/gorilla/context
|
||||||
version: aed02d124ae4a0e94fea4541c8effd05bf0c8296
|
version: aed02d124ae4a0e94fea4541c8effd05bf0c8296
|
||||||
|
- package: github.com/gorilla/handlers
|
||||||
|
version: 3a5767ca75ece5f7f1440b1d16975247f8d8b221
|
||||||
|
|
||||||
# Package with a bunch of sane crypto defaults. Consider just
|
# Package with a bunch of sane crypto defaults. Consider just
|
||||||
# copy the code (as recommended by the repo itself) instead
|
# copy the code (as recommended by the repo itself) instead
|
||||||
|
@ -11,6 +11,7 @@ import (
|
|||||||
"strings"
|
"strings"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
|
"github.com/gorilla/handlers"
|
||||||
"github.com/gorilla/mux"
|
"github.com/gorilla/mux"
|
||||||
jose "gopkg.in/square/go-jose.v2"
|
jose "gopkg.in/square/go-jose.v2"
|
||||||
|
|
||||||
@ -101,7 +102,7 @@ type discovery struct {
|
|||||||
Claims []string `json:"claims_supported"`
|
Claims []string `json:"claims_supported"`
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *Server) discoveryHandler() (http.HandlerFunc, error) {
|
func (s *Server) discoveryHandler() (http.Handler, error) {
|
||||||
d := discovery{
|
d := discovery{
|
||||||
Issuer: s.issuerURL.String(),
|
Issuer: s.issuerURL.String(),
|
||||||
Auth: s.absURL("/auth"),
|
Auth: s.absURL("/auth"),
|
||||||
@ -127,11 +128,18 @@ func (s *Server) discoveryHandler() (http.HandlerFunc, error) {
|
|||||||
return nil, fmt.Errorf("failed to marshal discovery data: %v", err)
|
return nil, fmt.Errorf("failed to marshal discovery data: %v", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
return func(w http.ResponseWriter, r *http.Request) {
|
var discoveryHandler http.Handler
|
||||||
|
discoveryHandler = http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
w.Header().Set("Content-Type", "application/json")
|
w.Header().Set("Content-Type", "application/json")
|
||||||
w.Header().Set("Content-Length", strconv.Itoa(len(data)))
|
w.Header().Set("Content-Length", strconv.Itoa(len(data)))
|
||||||
w.Write(data)
|
w.Write(data)
|
||||||
}, nil
|
})
|
||||||
|
if len(s.discoveryAllowedOrigins) > 0 {
|
||||||
|
corsOption := handlers.AllowedOrigins(s.discoveryAllowedOrigins)
|
||||||
|
discoveryHandler = handlers.CORS(corsOption)(discoveryHandler)
|
||||||
|
}
|
||||||
|
|
||||||
|
return discoveryHandler, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// handleAuthorization handles the OAuth2 auth endpoint.
|
// handleAuthorization handles the OAuth2 auth endpoint.
|
||||||
|
@ -22,3 +22,61 @@ func TestHandleHealth(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
var discoveryHandlerCORSTests = []struct {
|
||||||
|
DiscoveryAllowedOrigins []string
|
||||||
|
Origin string
|
||||||
|
ResponseAllowOrigin string //The expected response: same as Origin in case of valid CORS flow
|
||||||
|
}{
|
||||||
|
{nil, "http://foo.example", ""}, //Default behavior: cross origin requests not allowed
|
||||||
|
{[]string{}, "http://foo.example", ""},
|
||||||
|
{[]string{"http://foo.example"}, "http://foo.example", "http://foo.example"},
|
||||||
|
{[]string{"http://bar.example", "http://foo.example"}, "http://foo.example", "http://foo.example"},
|
||||||
|
{[]string{"*"}, "http://foo.example", "http://foo.example"},
|
||||||
|
{[]string{"http://bar.example"}, "http://foo.example", ""},
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestDiscoveryHandlerCORS(t *testing.T) {
|
||||||
|
ctx, cancel := context.WithCancel(context.Background())
|
||||||
|
defer cancel()
|
||||||
|
|
||||||
|
for _, testcase := range discoveryHandlerCORSTests {
|
||||||
|
|
||||||
|
httpServer, server := newTestServer(ctx, t, func(c *Config) {
|
||||||
|
c.DiscoveryAllowedOrigins = testcase.DiscoveryAllowedOrigins
|
||||||
|
})
|
||||||
|
defer httpServer.Close()
|
||||||
|
|
||||||
|
discoveryHandler, err := server.discoveryHandler()
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("failed to get discovery handler: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
//Perform preflight request
|
||||||
|
rrPreflight := httptest.NewRecorder()
|
||||||
|
reqPreflight := httptest.NewRequest("OPTIONS", "/.well-kown/openid-configuration", nil)
|
||||||
|
reqPreflight.Header.Set("Origin", testcase.Origin)
|
||||||
|
reqPreflight.Header.Set("Access-Control-Request-Method", "GET")
|
||||||
|
discoveryHandler.ServeHTTP(rrPreflight, reqPreflight)
|
||||||
|
if rrPreflight.Code != http.StatusOK {
|
||||||
|
t.Errorf("expected 200 got %d", rrPreflight.Code)
|
||||||
|
}
|
||||||
|
headerAccessControlPreflight := rrPreflight.HeaderMap.Get("Access-Control-Allow-Origin")
|
||||||
|
if headerAccessControlPreflight != testcase.ResponseAllowOrigin {
|
||||||
|
t.Errorf("expected '%s' got '%s'", testcase.ResponseAllowOrigin, headerAccessControlPreflight)
|
||||||
|
}
|
||||||
|
|
||||||
|
//Perform request
|
||||||
|
rr := httptest.NewRecorder()
|
||||||
|
req := httptest.NewRequest("GET", "/.well-kown/openid-configuration", nil)
|
||||||
|
req.Header.Set("Origin", testcase.Origin)
|
||||||
|
discoveryHandler.ServeHTTP(rr, req)
|
||||||
|
if rr.Code != http.StatusOK {
|
||||||
|
t.Errorf("expected 200 got %d", rr.Code)
|
||||||
|
}
|
||||||
|
headerAccessControl := rr.HeaderMap.Get("Access-Control-Allow-Origin")
|
||||||
|
if headerAccessControl != testcase.ResponseAllowOrigin {
|
||||||
|
t.Errorf("expected '%s' got '%s'", testcase.ResponseAllowOrigin, headerAccessControl)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
@ -42,6 +42,11 @@ type Config struct {
|
|||||||
// flow. If no response types are supplied this value defaults to "code".
|
// flow. If no response types are supplied this value defaults to "code".
|
||||||
SupportedResponseTypes []string
|
SupportedResponseTypes []string
|
||||||
|
|
||||||
|
// List of allowed origins for CORS requests on discovery endpoint.
|
||||||
|
// If none are indicated, CORS requests are disabled. Passing in "*" will allow any
|
||||||
|
// domain.
|
||||||
|
DiscoveryAllowedOrigins []string
|
||||||
|
|
||||||
// If enabled, the server won't prompt the user to approve authorization requests.
|
// If enabled, the server won't prompt the user to approve authorization requests.
|
||||||
// Logging in implies approval.
|
// Logging in implies approval.
|
||||||
SkipApprovalScreen bool
|
SkipApprovalScreen bool
|
||||||
@ -111,6 +116,8 @@ type Server struct {
|
|||||||
|
|
||||||
supportedResponseTypes map[string]bool
|
supportedResponseTypes map[string]bool
|
||||||
|
|
||||||
|
discoveryAllowedOrigins []string
|
||||||
|
|
||||||
now func() time.Time
|
now func() time.Time
|
||||||
|
|
||||||
idTokensValidFor time.Duration
|
idTokensValidFor time.Duration
|
||||||
@ -178,15 +185,16 @@ func newServer(ctx context.Context, c Config, rotationStrategy rotationStrategy)
|
|||||||
}
|
}
|
||||||
|
|
||||||
s := &Server{
|
s := &Server{
|
||||||
issuerURL: *issuerURL,
|
issuerURL: *issuerURL,
|
||||||
connectors: make(map[string]Connector),
|
connectors: make(map[string]Connector),
|
||||||
storage: newKeyCacher(c.Storage, now),
|
storage: newKeyCacher(c.Storage, now),
|
||||||
supportedResponseTypes: supported,
|
supportedResponseTypes: supported,
|
||||||
idTokensValidFor: value(c.IDTokensValidFor, 24*time.Hour),
|
discoveryAllowedOrigins: c.DiscoveryAllowedOrigins,
|
||||||
skipApproval: c.SkipApprovalScreen,
|
idTokensValidFor: value(c.IDTokensValidFor, 24*time.Hour),
|
||||||
now: now,
|
skipApproval: c.SkipApprovalScreen,
|
||||||
templates: tmpls,
|
now: now,
|
||||||
logger: c.Logger,
|
templates: tmpls,
|
||||||
|
logger: c.Logger,
|
||||||
}
|
}
|
||||||
|
|
||||||
for _, conn := range c.Connectors {
|
for _, conn := range c.Connectors {
|
||||||
@ -197,6 +205,9 @@ func newServer(ctx context.Context, c Config, rotationStrategy rotationStrategy)
|
|||||||
handleFunc := func(p string, h http.HandlerFunc) {
|
handleFunc := func(p string, h http.HandlerFunc) {
|
||||||
r.HandleFunc(path.Join(issuerURL.Path, p), h)
|
r.HandleFunc(path.Join(issuerURL.Path, p), h)
|
||||||
}
|
}
|
||||||
|
handle := func(p string, h http.Handler) {
|
||||||
|
r.Handle(path.Join(issuerURL.Path, p), h)
|
||||||
|
}
|
||||||
handlePrefix := func(p string, h http.Handler) {
|
handlePrefix := func(p string, h http.Handler) {
|
||||||
prefix := path.Join(issuerURL.Path, p)
|
prefix := path.Join(issuerURL.Path, p)
|
||||||
r.PathPrefix(prefix).Handler(http.StripPrefix(prefix, h))
|
r.PathPrefix(prefix).Handler(http.StripPrefix(prefix, h))
|
||||||
@ -207,7 +218,7 @@ func newServer(ctx context.Context, c Config, rotationStrategy rotationStrategy)
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
handleFunc("/.well-known/openid-configuration", discoveryHandler)
|
handle("/.well-known/openid-configuration", discoveryHandler)
|
||||||
|
|
||||||
// TODO(ericchiang): rate limit certain paths based on IP.
|
// TODO(ericchiang): rate limit certain paths based on IP.
|
||||||
handleFunc("/token", s.handleToken)
|
handleFunc("/token", s.handleToken)
|
||||||
|
Reference in New Issue
Block a user