Merge pull request #760 from xeonx/master
Allow CORS on discovery endpoint
This commit is contained in:
		@@ -98,10 +98,11 @@ type OAuth2 struct {
 | 
			
		||||
 | 
			
		||||
// Web is the config format for the HTTP server.
 | 
			
		||||
type Web struct {
 | 
			
		||||
	HTTP    string `json:"http"`
 | 
			
		||||
	HTTPS   string `json:"https"`
 | 
			
		||||
	TLSCert string `json:"tlsCert"`
 | 
			
		||||
	TLSKey  string `json:"tlsKey"`
 | 
			
		||||
	HTTP                    string   `json:"http"`
 | 
			
		||||
	HTTPS                   string   `json:"https"`
 | 
			
		||||
	TLSCert                 string   `json:"tlsCert"`
 | 
			
		||||
	TLSKey                  string   `json:"tlsKey"`
 | 
			
		||||
	DiscoveryAllowedOrigins []string `json:"discoveryAllowedOrigins"`
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// GRPC is the config for the gRPC API.
 | 
			
		||||
 
 | 
			
		||||
@@ -179,20 +179,24 @@ func serve(cmd *cobra.Command, args []string) error {
 | 
			
		||||
	if c.OAuth2.SkipApprovalScreen {
 | 
			
		||||
		logger.Infof("config skipping approval screen")
 | 
			
		||||
	}
 | 
			
		||||
	if len(c.Web.DiscoveryAllowedOrigins) > 0 {
 | 
			
		||||
		logger.Infof("config discovery allowed origins: %s", c.Web.DiscoveryAllowedOrigins)
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	// explicitly convert to UTC.
 | 
			
		||||
	now := func() time.Time { return time.Now().UTC() }
 | 
			
		||||
 | 
			
		||||
	serverConfig := server.Config{
 | 
			
		||||
		SupportedResponseTypes: c.OAuth2.ResponseTypes,
 | 
			
		||||
		SkipApprovalScreen:     c.OAuth2.SkipApprovalScreen,
 | 
			
		||||
		Issuer:                 c.Issuer,
 | 
			
		||||
		Connectors:             connectors,
 | 
			
		||||
		Storage:                s,
 | 
			
		||||
		Web:                    c.Frontend,
 | 
			
		||||
		EnablePasswordDB:       c.EnablePasswordDB,
 | 
			
		||||
		Logger:                 logger,
 | 
			
		||||
		Now:                    now,
 | 
			
		||||
		SupportedResponseTypes:  c.OAuth2.ResponseTypes,
 | 
			
		||||
		SkipApprovalScreen:      c.OAuth2.SkipApprovalScreen,
 | 
			
		||||
		DiscoveryAllowedOrigins: c.Web.DiscoveryAllowedOrigins,
 | 
			
		||||
		Issuer:                  c.Issuer,
 | 
			
		||||
		Connectors:              connectors,
 | 
			
		||||
		Storage:                 s,
 | 
			
		||||
		Web:                     c.Frontend,
 | 
			
		||||
		EnablePasswordDB:        c.EnablePasswordDB,
 | 
			
		||||
		Logger:                  logger,
 | 
			
		||||
		Now:                     now,
 | 
			
		||||
	}
 | 
			
		||||
	if c.Expiry.SigningKeys != "" {
 | 
			
		||||
		signingKeys, err := time.ParseDuration(c.Expiry.SigningKeys)
 | 
			
		||||
 
 | 
			
		||||
							
								
								
									
										6
									
								
								glide.lock
									
									
									
										generated
									
									
									
								
							
							
						
						
									
										6
									
								
								glide.lock
									
									
									
										generated
									
									
									
								
							@@ -1,5 +1,5 @@
 | 
			
		||||
hash: 22c01c4265c210fbf3cd2d55e9451b924a3301e35053c82ebe35171ab7286c83
 | 
			
		||||
updated: 2017-01-06T15:38:29.812891187-08:00
 | 
			
		||||
hash: 4d7d84f09a330d27458fb821ae7ada243cfa825808dc7ab116db28a08f9166a2
 | 
			
		||||
updated: 2017-01-08T19:23:40.352046548+01:00
 | 
			
		||||
imports:
 | 
			
		||||
- name: github.com/cockroachdb/cockroach-go
 | 
			
		||||
  version: 31611c0501c812f437d4861d87d117053967c955
 | 
			
		||||
@@ -18,6 +18,8 @@ imports:
 | 
			
		||||
  - protoc-gen-go
 | 
			
		||||
- name: github.com/gorilla/context
 | 
			
		||||
  version: aed02d124ae4a0e94fea4541c8effd05bf0c8296
 | 
			
		||||
- name: github.com/gorilla/handlers
 | 
			
		||||
  version: 3a5767ca75ece5f7f1440b1d16975247f8d8b221
 | 
			
		||||
- name: github.com/gorilla/mux
 | 
			
		||||
  version: 9fa818a44c2bf1396a17f9d5a3c0f6dd39d2ff8e
 | 
			
		||||
- name: github.com/gtank/cryptopasta
 | 
			
		||||
 
 | 
			
		||||
@@ -53,6 +53,8 @@ import:
 | 
			
		||||
  version: 9fa818a44c2bf1396a17f9d5a3c0f6dd39d2ff8e
 | 
			
		||||
- package: github.com/gorilla/context
 | 
			
		||||
  version: aed02d124ae4a0e94fea4541c8effd05bf0c8296
 | 
			
		||||
- package: github.com/gorilla/handlers
 | 
			
		||||
  version: 3a5767ca75ece5f7f1440b1d16975247f8d8b221
 | 
			
		||||
 | 
			
		||||
# Package with a bunch of sane crypto defaults. Consider just
 | 
			
		||||
# copy the code (as recommended by the repo itself) instead
 | 
			
		||||
 
 | 
			
		||||
@@ -11,6 +11,7 @@ import (
 | 
			
		||||
	"strings"
 | 
			
		||||
	"time"
 | 
			
		||||
 | 
			
		||||
	"github.com/gorilla/handlers"
 | 
			
		||||
	"github.com/gorilla/mux"
 | 
			
		||||
	jose "gopkg.in/square/go-jose.v2"
 | 
			
		||||
 | 
			
		||||
@@ -101,7 +102,7 @@ type discovery struct {
 | 
			
		||||
	Claims        []string `json:"claims_supported"`
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (s *Server) discoveryHandler() (http.HandlerFunc, error) {
 | 
			
		||||
func (s *Server) discoveryHandler() (http.Handler, error) {
 | 
			
		||||
	d := discovery{
 | 
			
		||||
		Issuer:      s.issuerURL.String(),
 | 
			
		||||
		Auth:        s.absURL("/auth"),
 | 
			
		||||
@@ -127,11 +128,18 @@ func (s *Server) discoveryHandler() (http.HandlerFunc, error) {
 | 
			
		||||
		return nil, fmt.Errorf("failed to marshal discovery data: %v", err)
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	return func(w http.ResponseWriter, r *http.Request) {
 | 
			
		||||
	var discoveryHandler http.Handler
 | 
			
		||||
	discoveryHandler = http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
 | 
			
		||||
		w.Header().Set("Content-Type", "application/json")
 | 
			
		||||
		w.Header().Set("Content-Length", strconv.Itoa(len(data)))
 | 
			
		||||
		w.Write(data)
 | 
			
		||||
	}, nil
 | 
			
		||||
	})
 | 
			
		||||
	if len(s.discoveryAllowedOrigins) > 0 {
 | 
			
		||||
		corsOption := handlers.AllowedOrigins(s.discoveryAllowedOrigins)
 | 
			
		||||
		discoveryHandler = handlers.CORS(corsOption)(discoveryHandler)
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	return discoveryHandler, nil
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// handleAuthorization handles the OAuth2 auth endpoint.
 | 
			
		||||
 
 | 
			
		||||
@@ -22,3 +22,61 @@ func TestHandleHealth(t *testing.T) {
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
var discoveryHandlerCORSTests = []struct {
 | 
			
		||||
	DiscoveryAllowedOrigins []string
 | 
			
		||||
	Origin                  string
 | 
			
		||||
	ResponseAllowOrigin     string //The expected response: same as Origin in case of valid CORS flow
 | 
			
		||||
}{
 | 
			
		||||
	{nil, "http://foo.example", ""}, //Default behavior: cross origin requests not allowed
 | 
			
		||||
	{[]string{}, "http://foo.example", ""},
 | 
			
		||||
	{[]string{"http://foo.example"}, "http://foo.example", "http://foo.example"},
 | 
			
		||||
	{[]string{"http://bar.example", "http://foo.example"}, "http://foo.example", "http://foo.example"},
 | 
			
		||||
	{[]string{"*"}, "http://foo.example", "http://foo.example"},
 | 
			
		||||
	{[]string{"http://bar.example"}, "http://foo.example", ""},
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func TestDiscoveryHandlerCORS(t *testing.T) {
 | 
			
		||||
	ctx, cancel := context.WithCancel(context.Background())
 | 
			
		||||
	defer cancel()
 | 
			
		||||
 | 
			
		||||
	for _, testcase := range discoveryHandlerCORSTests {
 | 
			
		||||
 | 
			
		||||
		httpServer, server := newTestServer(ctx, t, func(c *Config) {
 | 
			
		||||
			c.DiscoveryAllowedOrigins = testcase.DiscoveryAllowedOrigins
 | 
			
		||||
		})
 | 
			
		||||
		defer httpServer.Close()
 | 
			
		||||
 | 
			
		||||
		discoveryHandler, err := server.discoveryHandler()
 | 
			
		||||
		if err != nil {
 | 
			
		||||
			t.Fatalf("failed to get discovery handler: %v", err)
 | 
			
		||||
		}
 | 
			
		||||
 | 
			
		||||
		//Perform preflight request
 | 
			
		||||
		rrPreflight := httptest.NewRecorder()
 | 
			
		||||
		reqPreflight := httptest.NewRequest("OPTIONS", "/.well-kown/openid-configuration", nil)
 | 
			
		||||
		reqPreflight.Header.Set("Origin", testcase.Origin)
 | 
			
		||||
		reqPreflight.Header.Set("Access-Control-Request-Method", "GET")
 | 
			
		||||
		discoveryHandler.ServeHTTP(rrPreflight, reqPreflight)
 | 
			
		||||
		if rrPreflight.Code != http.StatusOK {
 | 
			
		||||
			t.Errorf("expected 200 got %d", rrPreflight.Code)
 | 
			
		||||
		}
 | 
			
		||||
		headerAccessControlPreflight := rrPreflight.HeaderMap.Get("Access-Control-Allow-Origin")
 | 
			
		||||
		if headerAccessControlPreflight != testcase.ResponseAllowOrigin {
 | 
			
		||||
			t.Errorf("expected '%s' got '%s'", testcase.ResponseAllowOrigin, headerAccessControlPreflight)
 | 
			
		||||
		}
 | 
			
		||||
 | 
			
		||||
		//Perform request
 | 
			
		||||
		rr := httptest.NewRecorder()
 | 
			
		||||
		req := httptest.NewRequest("GET", "/.well-kown/openid-configuration", nil)
 | 
			
		||||
		req.Header.Set("Origin", testcase.Origin)
 | 
			
		||||
		discoveryHandler.ServeHTTP(rr, req)
 | 
			
		||||
		if rr.Code != http.StatusOK {
 | 
			
		||||
			t.Errorf("expected 200 got %d", rr.Code)
 | 
			
		||||
		}
 | 
			
		||||
		headerAccessControl := rr.HeaderMap.Get("Access-Control-Allow-Origin")
 | 
			
		||||
		if headerAccessControl != testcase.ResponseAllowOrigin {
 | 
			
		||||
			t.Errorf("expected '%s' got '%s'", testcase.ResponseAllowOrigin, headerAccessControl)
 | 
			
		||||
		}
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
 
 | 
			
		||||
@@ -42,6 +42,11 @@ type Config struct {
 | 
			
		||||
	// flow. If no response types are supplied this value defaults to "code".
 | 
			
		||||
	SupportedResponseTypes []string
 | 
			
		||||
 | 
			
		||||
	// List of allowed origins for CORS requests on discovery endpoint.
 | 
			
		||||
	// If none are indicated, CORS requests are disabled. Passing in "*" will allow any
 | 
			
		||||
	// domain.
 | 
			
		||||
	DiscoveryAllowedOrigins []string
 | 
			
		||||
 | 
			
		||||
	// If enabled, the server won't prompt the user to approve authorization requests.
 | 
			
		||||
	// Logging in implies approval.
 | 
			
		||||
	SkipApprovalScreen bool
 | 
			
		||||
@@ -111,6 +116,8 @@ type Server struct {
 | 
			
		||||
 | 
			
		||||
	supportedResponseTypes map[string]bool
 | 
			
		||||
 | 
			
		||||
	discoveryAllowedOrigins []string
 | 
			
		||||
 | 
			
		||||
	now func() time.Time
 | 
			
		||||
 | 
			
		||||
	idTokensValidFor time.Duration
 | 
			
		||||
@@ -178,15 +185,16 @@ func newServer(ctx context.Context, c Config, rotationStrategy rotationStrategy)
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	s := &Server{
 | 
			
		||||
		issuerURL:              *issuerURL,
 | 
			
		||||
		connectors:             make(map[string]Connector),
 | 
			
		||||
		storage:                newKeyCacher(c.Storage, now),
 | 
			
		||||
		supportedResponseTypes: supported,
 | 
			
		||||
		idTokensValidFor:       value(c.IDTokensValidFor, 24*time.Hour),
 | 
			
		||||
		skipApproval:           c.SkipApprovalScreen,
 | 
			
		||||
		now:                    now,
 | 
			
		||||
		templates:              tmpls,
 | 
			
		||||
		logger:                 c.Logger,
 | 
			
		||||
		issuerURL:               *issuerURL,
 | 
			
		||||
		connectors:              make(map[string]Connector),
 | 
			
		||||
		storage:                 newKeyCacher(c.Storage, now),
 | 
			
		||||
		supportedResponseTypes:  supported,
 | 
			
		||||
		discoveryAllowedOrigins: c.DiscoveryAllowedOrigins,
 | 
			
		||||
		idTokensValidFor:        value(c.IDTokensValidFor, 24*time.Hour),
 | 
			
		||||
		skipApproval:            c.SkipApprovalScreen,
 | 
			
		||||
		now:                     now,
 | 
			
		||||
		templates:               tmpls,
 | 
			
		||||
		logger:                  c.Logger,
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	for _, conn := range c.Connectors {
 | 
			
		||||
@@ -197,6 +205,9 @@ func newServer(ctx context.Context, c Config, rotationStrategy rotationStrategy)
 | 
			
		||||
	handleFunc := func(p string, h http.HandlerFunc) {
 | 
			
		||||
		r.HandleFunc(path.Join(issuerURL.Path, p), h)
 | 
			
		||||
	}
 | 
			
		||||
	handle := func(p string, h http.Handler) {
 | 
			
		||||
		r.Handle(path.Join(issuerURL.Path, p), h)
 | 
			
		||||
	}
 | 
			
		||||
	handlePrefix := func(p string, h http.Handler) {
 | 
			
		||||
		prefix := path.Join(issuerURL.Path, p)
 | 
			
		||||
		r.PathPrefix(prefix).Handler(http.StripPrefix(prefix, h))
 | 
			
		||||
@@ -207,7 +218,7 @@ func newServer(ctx context.Context, c Config, rotationStrategy rotationStrategy)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return nil, err
 | 
			
		||||
	}
 | 
			
		||||
	handleFunc("/.well-known/openid-configuration", discoveryHandler)
 | 
			
		||||
	handle("/.well-known/openid-configuration", discoveryHandler)
 | 
			
		||||
 | 
			
		||||
	// TODO(ericchiang): rate limit certain paths based on IP.
 | 
			
		||||
	handleFunc("/token", s.handleToken)
 | 
			
		||||
 
 | 
			
		||||
							
								
								
									
										22
									
								
								vendor/github.com/gorilla/handlers/LICENSE
									
									
									
										generated
									
									
										vendored
									
									
										Normal file
									
								
							
							
						
						
									
										22
									
								
								vendor/github.com/gorilla/handlers/LICENSE
									
									
									
										generated
									
									
										vendored
									
									
										Normal file
									
								
							@@ -0,0 +1,22 @@
 | 
			
		||||
Copyright (c) 2013 The Gorilla Handlers Authors. All rights reserved.
 | 
			
		||||
 | 
			
		||||
Redistribution and use in source and binary forms, with or without
 | 
			
		||||
modification, are permitted provided that the following conditions are met:
 | 
			
		||||
 | 
			
		||||
  Redistributions of source code must retain the above copyright notice, this
 | 
			
		||||
  list of conditions and the following disclaimer.
 | 
			
		||||
 | 
			
		||||
  Redistributions in binary form must reproduce the above copyright notice,
 | 
			
		||||
  this list of conditions and the following disclaimer in the documentation
 | 
			
		||||
  and/or other materials provided with the distribution.
 | 
			
		||||
 | 
			
		||||
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND
 | 
			
		||||
ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
 | 
			
		||||
WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
 | 
			
		||||
DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
 | 
			
		||||
FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
 | 
			
		||||
DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
 | 
			
		||||
SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
 | 
			
		||||
CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
 | 
			
		||||
OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
 | 
			
		||||
OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
 | 
			
		||||
							
								
								
									
										74
									
								
								vendor/github.com/gorilla/handlers/canonical.go
									
									
									
										generated
									
									
										vendored
									
									
										Normal file
									
								
							
							
						
						
									
										74
									
								
								vendor/github.com/gorilla/handlers/canonical.go
									
									
									
										generated
									
									
										vendored
									
									
										Normal file
									
								
							@@ -0,0 +1,74 @@
 | 
			
		||||
package handlers
 | 
			
		||||
 | 
			
		||||
import (
 | 
			
		||||
	"net/http"
 | 
			
		||||
	"net/url"
 | 
			
		||||
	"strings"
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
type canonical struct {
 | 
			
		||||
	h      http.Handler
 | 
			
		||||
	domain string
 | 
			
		||||
	code   int
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// CanonicalHost is HTTP middleware that re-directs requests to the canonical
 | 
			
		||||
// domain. It accepts a domain and a status code (e.g. 301 or 302) and
 | 
			
		||||
// re-directs clients to this domain. The existing request path is maintained.
 | 
			
		||||
//
 | 
			
		||||
// Note: If the provided domain is considered invalid by url.Parse or otherwise
 | 
			
		||||
// returns an empty scheme or host, clients are not re-directed.
 | 
			
		||||
//
 | 
			
		||||
// Example:
 | 
			
		||||
//
 | 
			
		||||
//  r := mux.NewRouter()
 | 
			
		||||
//  canonical := handlers.CanonicalHost("http://www.gorillatoolkit.org", 302)
 | 
			
		||||
//  r.HandleFunc("/route", YourHandler)
 | 
			
		||||
//
 | 
			
		||||
//  log.Fatal(http.ListenAndServe(":7000", canonical(r)))
 | 
			
		||||
//
 | 
			
		||||
func CanonicalHost(domain string, code int) func(h http.Handler) http.Handler {
 | 
			
		||||
	fn := func(h http.Handler) http.Handler {
 | 
			
		||||
		return canonical{h, domain, code}
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	return fn
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (c canonical) ServeHTTP(w http.ResponseWriter, r *http.Request) {
 | 
			
		||||
	dest, err := url.Parse(c.domain)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		// Call the next handler if the provided domain fails to parse.
 | 
			
		||||
		c.h.ServeHTTP(w, r)
 | 
			
		||||
		return
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	if dest.Scheme == "" || dest.Host == "" {
 | 
			
		||||
		// Call the next handler if the scheme or host are empty.
 | 
			
		||||
		// Note that url.Parse won't fail on in this case.
 | 
			
		||||
		c.h.ServeHTTP(w, r)
 | 
			
		||||
		return
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	if !strings.EqualFold(cleanHost(r.Host), dest.Host) {
 | 
			
		||||
		// Re-build the destination URL
 | 
			
		||||
		dest := dest.Scheme + "://" + dest.Host + r.URL.Path
 | 
			
		||||
		if r.URL.RawQuery != "" {
 | 
			
		||||
			dest += "?" + r.URL.RawQuery
 | 
			
		||||
		}
 | 
			
		||||
		http.Redirect(w, r, dest, c.code)
 | 
			
		||||
		return
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	c.h.ServeHTTP(w, r)
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// cleanHost cleans invalid Host headers by stripping anything after '/' or ' '.
 | 
			
		||||
// This is backported from Go 1.5 (in response to issue #11206) and attempts to
 | 
			
		||||
// mitigate malformed Host headers that do not match the format in RFC7230.
 | 
			
		||||
func cleanHost(in string) string {
 | 
			
		||||
	if i := strings.IndexAny(in, " /"); i != -1 {
 | 
			
		||||
		return in[:i]
 | 
			
		||||
	}
 | 
			
		||||
	return in
 | 
			
		||||
}
 | 
			
		||||
							
								
								
									
										148
									
								
								vendor/github.com/gorilla/handlers/compress.go
									
									
									
										generated
									
									
										vendored
									
									
										Normal file
									
								
							
							
						
						
									
										148
									
								
								vendor/github.com/gorilla/handlers/compress.go
									
									
									
										generated
									
									
										vendored
									
									
										Normal file
									
								
							@@ -0,0 +1,148 @@
 | 
			
		||||
// Copyright 2013 The Gorilla Authors. All rights reserved.
 | 
			
		||||
// Use of this source code is governed by a BSD-style
 | 
			
		||||
// license that can be found in the LICENSE file.
 | 
			
		||||
 | 
			
		||||
package handlers
 | 
			
		||||
 | 
			
		||||
import (
 | 
			
		||||
	"compress/flate"
 | 
			
		||||
	"compress/gzip"
 | 
			
		||||
	"io"
 | 
			
		||||
	"net/http"
 | 
			
		||||
	"strings"
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
type compressResponseWriter struct {
 | 
			
		||||
	io.Writer
 | 
			
		||||
	http.ResponseWriter
 | 
			
		||||
	http.Hijacker
 | 
			
		||||
	http.Flusher
 | 
			
		||||
	http.CloseNotifier
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (w *compressResponseWriter) WriteHeader(c int) {
 | 
			
		||||
	w.ResponseWriter.Header().Del("Content-Length")
 | 
			
		||||
	w.ResponseWriter.WriteHeader(c)
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (w *compressResponseWriter) Header() http.Header {
 | 
			
		||||
	return w.ResponseWriter.Header()
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (w *compressResponseWriter) Write(b []byte) (int, error) {
 | 
			
		||||
	h := w.ResponseWriter.Header()
 | 
			
		||||
	if h.Get("Content-Type") == "" {
 | 
			
		||||
		h.Set("Content-Type", http.DetectContentType(b))
 | 
			
		||||
	}
 | 
			
		||||
	h.Del("Content-Length")
 | 
			
		||||
 | 
			
		||||
	return w.Writer.Write(b)
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
type flusher interface {
 | 
			
		||||
	Flush() error
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (w *compressResponseWriter) Flush() {
 | 
			
		||||
	// Flush compressed data if compressor supports it.
 | 
			
		||||
	if f, ok := w.Writer.(flusher); ok {
 | 
			
		||||
		f.Flush()
 | 
			
		||||
	}
 | 
			
		||||
	// Flush HTTP response.
 | 
			
		||||
	if w.Flusher != nil {
 | 
			
		||||
		w.Flusher.Flush()
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// CompressHandler gzip compresses HTTP responses for clients that support it
 | 
			
		||||
// via the 'Accept-Encoding' header.
 | 
			
		||||
//
 | 
			
		||||
// Compressing TLS traffic may leak the page contents to an attacker if the
 | 
			
		||||
// page contains user input: http://security.stackexchange.com/a/102015/12208
 | 
			
		||||
func CompressHandler(h http.Handler) http.Handler {
 | 
			
		||||
	return CompressHandlerLevel(h, gzip.DefaultCompression)
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// CompressHandlerLevel gzip compresses HTTP responses with specified compression level
 | 
			
		||||
// for clients that support it via the 'Accept-Encoding' header.
 | 
			
		||||
//
 | 
			
		||||
// The compression level should be gzip.DefaultCompression, gzip.NoCompression,
 | 
			
		||||
// or any integer value between gzip.BestSpeed and gzip.BestCompression inclusive.
 | 
			
		||||
// gzip.DefaultCompression is used in case of invalid compression level.
 | 
			
		||||
func CompressHandlerLevel(h http.Handler, level int) http.Handler {
 | 
			
		||||
	if level < gzip.DefaultCompression || level > gzip.BestCompression {
 | 
			
		||||
		level = gzip.DefaultCompression
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
 | 
			
		||||
	L:
 | 
			
		||||
		for _, enc := range strings.Split(r.Header.Get("Accept-Encoding"), ",") {
 | 
			
		||||
			switch strings.TrimSpace(enc) {
 | 
			
		||||
			case "gzip":
 | 
			
		||||
				w.Header().Set("Content-Encoding", "gzip")
 | 
			
		||||
				w.Header().Add("Vary", "Accept-Encoding")
 | 
			
		||||
 | 
			
		||||
				gw, _ := gzip.NewWriterLevel(w, level)
 | 
			
		||||
				defer gw.Close()
 | 
			
		||||
 | 
			
		||||
				h, hok := w.(http.Hijacker)
 | 
			
		||||
				if !hok { /* w is not Hijacker... oh well... */
 | 
			
		||||
					h = nil
 | 
			
		||||
				}
 | 
			
		||||
 | 
			
		||||
				f, fok := w.(http.Flusher)
 | 
			
		||||
				if !fok {
 | 
			
		||||
					f = nil
 | 
			
		||||
				}
 | 
			
		||||
 | 
			
		||||
				cn, cnok := w.(http.CloseNotifier)
 | 
			
		||||
				if !cnok {
 | 
			
		||||
					cn = nil
 | 
			
		||||
				}
 | 
			
		||||
 | 
			
		||||
				w = &compressResponseWriter{
 | 
			
		||||
					Writer:         gw,
 | 
			
		||||
					ResponseWriter: w,
 | 
			
		||||
					Hijacker:       h,
 | 
			
		||||
					Flusher:        f,
 | 
			
		||||
					CloseNotifier:  cn,
 | 
			
		||||
				}
 | 
			
		||||
 | 
			
		||||
				break L
 | 
			
		||||
			case "deflate":
 | 
			
		||||
				w.Header().Set("Content-Encoding", "deflate")
 | 
			
		||||
				w.Header().Add("Vary", "Accept-Encoding")
 | 
			
		||||
 | 
			
		||||
				fw, _ := flate.NewWriter(w, level)
 | 
			
		||||
				defer fw.Close()
 | 
			
		||||
 | 
			
		||||
				h, hok := w.(http.Hijacker)
 | 
			
		||||
				if !hok { /* w is not Hijacker... oh well... */
 | 
			
		||||
					h = nil
 | 
			
		||||
				}
 | 
			
		||||
 | 
			
		||||
				f, fok := w.(http.Flusher)
 | 
			
		||||
				if !fok {
 | 
			
		||||
					f = nil
 | 
			
		||||
				}
 | 
			
		||||
 | 
			
		||||
				cn, cnok := w.(http.CloseNotifier)
 | 
			
		||||
				if !cnok {
 | 
			
		||||
					cn = nil
 | 
			
		||||
				}
 | 
			
		||||
 | 
			
		||||
				w = &compressResponseWriter{
 | 
			
		||||
					Writer:         fw,
 | 
			
		||||
					ResponseWriter: w,
 | 
			
		||||
					Hijacker:       h,
 | 
			
		||||
					Flusher:        f,
 | 
			
		||||
					CloseNotifier:  cn,
 | 
			
		||||
				}
 | 
			
		||||
 | 
			
		||||
				break L
 | 
			
		||||
			}
 | 
			
		||||
		}
 | 
			
		||||
 | 
			
		||||
		h.ServeHTTP(w, r)
 | 
			
		||||
	})
 | 
			
		||||
}
 | 
			
		||||
							
								
								
									
										317
									
								
								vendor/github.com/gorilla/handlers/cors.go
									
									
									
										generated
									
									
										vendored
									
									
										Normal file
									
								
							
							
						
						
									
										317
									
								
								vendor/github.com/gorilla/handlers/cors.go
									
									
									
										generated
									
									
										vendored
									
									
										Normal file
									
								
							@@ -0,0 +1,317 @@
 | 
			
		||||
package handlers
 | 
			
		||||
 | 
			
		||||
import (
 | 
			
		||||
	"net/http"
 | 
			
		||||
	"strconv"
 | 
			
		||||
	"strings"
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
// CORSOption represents a functional option for configuring the CORS middleware.
 | 
			
		||||
type CORSOption func(*cors) error
 | 
			
		||||
 | 
			
		||||
type cors struct {
 | 
			
		||||
	h                      http.Handler
 | 
			
		||||
	allowedHeaders         []string
 | 
			
		||||
	allowedMethods         []string
 | 
			
		||||
	allowedOrigins         []string
 | 
			
		||||
	allowedOriginValidator OriginValidator
 | 
			
		||||
	exposedHeaders         []string
 | 
			
		||||
	maxAge                 int
 | 
			
		||||
	ignoreOptions          bool
 | 
			
		||||
	allowCredentials       bool
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// OriginValidator takes an origin string and returns whether or not that origin is allowed.
 | 
			
		||||
type OriginValidator func(string) bool
 | 
			
		||||
 | 
			
		||||
var (
 | 
			
		||||
	defaultCorsMethods = []string{"GET", "HEAD", "POST"}
 | 
			
		||||
	defaultCorsHeaders = []string{"Accept", "Accept-Language", "Content-Language", "Origin"}
 | 
			
		||||
	// (WebKit/Safari v9 sends the Origin header by default in AJAX requests)
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
const (
 | 
			
		||||
	corsOptionMethod           string = "OPTIONS"
 | 
			
		||||
	corsAllowOriginHeader      string = "Access-Control-Allow-Origin"
 | 
			
		||||
	corsExposeHeadersHeader    string = "Access-Control-Expose-Headers"
 | 
			
		||||
	corsMaxAgeHeader           string = "Access-Control-Max-Age"
 | 
			
		||||
	corsAllowMethodsHeader     string = "Access-Control-Allow-Methods"
 | 
			
		||||
	corsAllowHeadersHeader     string = "Access-Control-Allow-Headers"
 | 
			
		||||
	corsAllowCredentialsHeader string = "Access-Control-Allow-Credentials"
 | 
			
		||||
	corsRequestMethodHeader    string = "Access-Control-Request-Method"
 | 
			
		||||
	corsRequestHeadersHeader   string = "Access-Control-Request-Headers"
 | 
			
		||||
	corsOriginHeader           string = "Origin"
 | 
			
		||||
	corsVaryHeader             string = "Vary"
 | 
			
		||||
	corsOriginMatchAll         string = "*"
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
func (ch *cors) ServeHTTP(w http.ResponseWriter, r *http.Request) {
 | 
			
		||||
	origin := r.Header.Get(corsOriginHeader)
 | 
			
		||||
	if !ch.isOriginAllowed(origin) {
 | 
			
		||||
		ch.h.ServeHTTP(w, r)
 | 
			
		||||
		return
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	if r.Method == corsOptionMethod {
 | 
			
		||||
		if ch.ignoreOptions {
 | 
			
		||||
			ch.h.ServeHTTP(w, r)
 | 
			
		||||
			return
 | 
			
		||||
		}
 | 
			
		||||
 | 
			
		||||
		if _, ok := r.Header[corsRequestMethodHeader]; !ok {
 | 
			
		||||
			w.WriteHeader(http.StatusBadRequest)
 | 
			
		||||
			return
 | 
			
		||||
		}
 | 
			
		||||
 | 
			
		||||
		method := r.Header.Get(corsRequestMethodHeader)
 | 
			
		||||
		if !ch.isMatch(method, ch.allowedMethods) {
 | 
			
		||||
			w.WriteHeader(http.StatusMethodNotAllowed)
 | 
			
		||||
			return
 | 
			
		||||
		}
 | 
			
		||||
 | 
			
		||||
		requestHeaders := strings.Split(r.Header.Get(corsRequestHeadersHeader), ",")
 | 
			
		||||
		allowedHeaders := []string{}
 | 
			
		||||
		for _, v := range requestHeaders {
 | 
			
		||||
			canonicalHeader := http.CanonicalHeaderKey(strings.TrimSpace(v))
 | 
			
		||||
			if canonicalHeader == "" || ch.isMatch(canonicalHeader, defaultCorsHeaders) {
 | 
			
		||||
				continue
 | 
			
		||||
			}
 | 
			
		||||
 | 
			
		||||
			if !ch.isMatch(canonicalHeader, ch.allowedHeaders) {
 | 
			
		||||
				w.WriteHeader(http.StatusForbidden)
 | 
			
		||||
				return
 | 
			
		||||
			}
 | 
			
		||||
 | 
			
		||||
			allowedHeaders = append(allowedHeaders, canonicalHeader)
 | 
			
		||||
		}
 | 
			
		||||
 | 
			
		||||
		if len(allowedHeaders) > 0 {
 | 
			
		||||
			w.Header().Set(corsAllowHeadersHeader, strings.Join(allowedHeaders, ","))
 | 
			
		||||
		}
 | 
			
		||||
 | 
			
		||||
		if ch.maxAge > 0 {
 | 
			
		||||
			w.Header().Set(corsMaxAgeHeader, strconv.Itoa(ch.maxAge))
 | 
			
		||||
		}
 | 
			
		||||
 | 
			
		||||
		if !ch.isMatch(method, defaultCorsMethods) {
 | 
			
		||||
			w.Header().Set(corsAllowMethodsHeader, method)
 | 
			
		||||
		}
 | 
			
		||||
	} else {
 | 
			
		||||
		if len(ch.exposedHeaders) > 0 {
 | 
			
		||||
			w.Header().Set(corsExposeHeadersHeader, strings.Join(ch.exposedHeaders, ","))
 | 
			
		||||
		}
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	if ch.allowCredentials {
 | 
			
		||||
		w.Header().Set(corsAllowCredentialsHeader, "true")
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	if len(ch.allowedOrigins) > 1 {
 | 
			
		||||
		w.Header().Set(corsVaryHeader, corsOriginHeader)
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	w.Header().Set(corsAllowOriginHeader, origin)
 | 
			
		||||
 | 
			
		||||
	if r.Method == corsOptionMethod {
 | 
			
		||||
		return
 | 
			
		||||
	}
 | 
			
		||||
	ch.h.ServeHTTP(w, r)
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// CORS provides Cross-Origin Resource Sharing middleware.
 | 
			
		||||
// Example:
 | 
			
		||||
//
 | 
			
		||||
//  import (
 | 
			
		||||
//      "net/http"
 | 
			
		||||
//
 | 
			
		||||
//      "github.com/gorilla/handlers"
 | 
			
		||||
//      "github.com/gorilla/mux"
 | 
			
		||||
//  )
 | 
			
		||||
//
 | 
			
		||||
//  func main() {
 | 
			
		||||
//      r := mux.NewRouter()
 | 
			
		||||
//      r.HandleFunc("/users", UserEndpoint)
 | 
			
		||||
//      r.HandleFunc("/projects", ProjectEndpoint)
 | 
			
		||||
//
 | 
			
		||||
//      // Apply the CORS middleware to our top-level router, with the defaults.
 | 
			
		||||
//      http.ListenAndServe(":8000", handlers.CORS()(r))
 | 
			
		||||
//  }
 | 
			
		||||
//
 | 
			
		||||
func CORS(opts ...CORSOption) func(http.Handler) http.Handler {
 | 
			
		||||
	return func(h http.Handler) http.Handler {
 | 
			
		||||
		ch := parseCORSOptions(opts...)
 | 
			
		||||
		ch.h = h
 | 
			
		||||
		return ch
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func parseCORSOptions(opts ...CORSOption) *cors {
 | 
			
		||||
	ch := &cors{
 | 
			
		||||
		allowedMethods: defaultCorsMethods,
 | 
			
		||||
		allowedHeaders: defaultCorsHeaders,
 | 
			
		||||
		allowedOrigins: []string{corsOriginMatchAll},
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	for _, option := range opts {
 | 
			
		||||
		option(ch)
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	return ch
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
//
 | 
			
		||||
// Functional options for configuring CORS.
 | 
			
		||||
//
 | 
			
		||||
 | 
			
		||||
// AllowedHeaders adds the provided headers to the list of allowed headers in a
 | 
			
		||||
// CORS request.
 | 
			
		||||
// This is an append operation so the headers Accept, Accept-Language,
 | 
			
		||||
// and Content-Language are always allowed.
 | 
			
		||||
// Content-Type must be explicitly declared if accepting Content-Types other than
 | 
			
		||||
// application/x-www-form-urlencoded, multipart/form-data, or text/plain.
 | 
			
		||||
func AllowedHeaders(headers []string) CORSOption {
 | 
			
		||||
	return func(ch *cors) error {
 | 
			
		||||
		for _, v := range headers {
 | 
			
		||||
			normalizedHeader := http.CanonicalHeaderKey(strings.TrimSpace(v))
 | 
			
		||||
			if normalizedHeader == "" {
 | 
			
		||||
				continue
 | 
			
		||||
			}
 | 
			
		||||
 | 
			
		||||
			if !ch.isMatch(normalizedHeader, ch.allowedHeaders) {
 | 
			
		||||
				ch.allowedHeaders = append(ch.allowedHeaders, normalizedHeader)
 | 
			
		||||
			}
 | 
			
		||||
		}
 | 
			
		||||
 | 
			
		||||
		return nil
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// AllowedMethods can be used to explicitly allow methods in the
 | 
			
		||||
// Access-Control-Allow-Methods header.
 | 
			
		||||
// This is a replacement operation so you must also
 | 
			
		||||
// pass GET, HEAD, and POST if you wish to support those methods.
 | 
			
		||||
func AllowedMethods(methods []string) CORSOption {
 | 
			
		||||
	return func(ch *cors) error {
 | 
			
		||||
		ch.allowedMethods = []string{}
 | 
			
		||||
		for _, v := range methods {
 | 
			
		||||
			normalizedMethod := strings.ToUpper(strings.TrimSpace(v))
 | 
			
		||||
			if normalizedMethod == "" {
 | 
			
		||||
				continue
 | 
			
		||||
			}
 | 
			
		||||
 | 
			
		||||
			if !ch.isMatch(normalizedMethod, ch.allowedMethods) {
 | 
			
		||||
				ch.allowedMethods = append(ch.allowedMethods, normalizedMethod)
 | 
			
		||||
			}
 | 
			
		||||
		}
 | 
			
		||||
 | 
			
		||||
		return nil
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// AllowedOrigins sets the allowed origins for CORS requests, as used in the
 | 
			
		||||
// 'Allow-Access-Control-Origin' HTTP header.
 | 
			
		||||
// Note: Passing in a []string{"*"} will allow any domain.
 | 
			
		||||
func AllowedOrigins(origins []string) CORSOption {
 | 
			
		||||
	return func(ch *cors) error {
 | 
			
		||||
		for _, v := range origins {
 | 
			
		||||
			if v == corsOriginMatchAll {
 | 
			
		||||
				ch.allowedOrigins = []string{corsOriginMatchAll}
 | 
			
		||||
				return nil
 | 
			
		||||
			}
 | 
			
		||||
		}
 | 
			
		||||
 | 
			
		||||
		ch.allowedOrigins = origins
 | 
			
		||||
		return nil
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// AllowedOriginValidator sets a function for evaluating allowed origins in CORS requests, represented by the
 | 
			
		||||
// 'Allow-Access-Control-Origin' HTTP header.
 | 
			
		||||
func AllowedOriginValidator(fn OriginValidator) CORSOption {
 | 
			
		||||
	return func(ch *cors) error {
 | 
			
		||||
		ch.allowedOriginValidator = fn
 | 
			
		||||
		return nil
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// ExposeHeaders can be used to specify headers that are available
 | 
			
		||||
// and will not be stripped out by the user-agent.
 | 
			
		||||
func ExposedHeaders(headers []string) CORSOption {
 | 
			
		||||
	return func(ch *cors) error {
 | 
			
		||||
		ch.exposedHeaders = []string{}
 | 
			
		||||
		for _, v := range headers {
 | 
			
		||||
			normalizedHeader := http.CanonicalHeaderKey(strings.TrimSpace(v))
 | 
			
		||||
			if normalizedHeader == "" {
 | 
			
		||||
				continue
 | 
			
		||||
			}
 | 
			
		||||
 | 
			
		||||
			if !ch.isMatch(normalizedHeader, ch.exposedHeaders) {
 | 
			
		||||
				ch.exposedHeaders = append(ch.exposedHeaders, normalizedHeader)
 | 
			
		||||
			}
 | 
			
		||||
		}
 | 
			
		||||
 | 
			
		||||
		return nil
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// MaxAge determines the maximum age (in seconds) between preflight requests. A
 | 
			
		||||
// maximum of 10 minutes is allowed. An age above this value will default to 10
 | 
			
		||||
// minutes.
 | 
			
		||||
func MaxAge(age int) CORSOption {
 | 
			
		||||
	return func(ch *cors) error {
 | 
			
		||||
		// Maximum of 10 minutes.
 | 
			
		||||
		if age > 600 {
 | 
			
		||||
			age = 600
 | 
			
		||||
		}
 | 
			
		||||
 | 
			
		||||
		ch.maxAge = age
 | 
			
		||||
		return nil
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// IgnoreOptions causes the CORS middleware to ignore OPTIONS requests, instead
 | 
			
		||||
// passing them through to the next handler. This is useful when your application
 | 
			
		||||
// or framework has a pre-existing mechanism for responding to OPTIONS requests.
 | 
			
		||||
func IgnoreOptions() CORSOption {
 | 
			
		||||
	return func(ch *cors) error {
 | 
			
		||||
		ch.ignoreOptions = true
 | 
			
		||||
		return nil
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// AllowCredentials can be used to specify that the user agent may pass
 | 
			
		||||
// authentication details along with the request.
 | 
			
		||||
func AllowCredentials() CORSOption {
 | 
			
		||||
	return func(ch *cors) error {
 | 
			
		||||
		ch.allowCredentials = true
 | 
			
		||||
		return nil
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (ch *cors) isOriginAllowed(origin string) bool {
 | 
			
		||||
	if origin == "" {
 | 
			
		||||
		return false
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	if ch.allowedOriginValidator != nil {
 | 
			
		||||
		return ch.allowedOriginValidator(origin)
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	for _, allowedOrigin := range ch.allowedOrigins {
 | 
			
		||||
		if allowedOrigin == origin || allowedOrigin == corsOriginMatchAll {
 | 
			
		||||
			return true
 | 
			
		||||
		}
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	return false
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (ch *cors) isMatch(needle string, haystack []string) bool {
 | 
			
		||||
	for _, v := range haystack {
 | 
			
		||||
		if v == needle {
 | 
			
		||||
			return true
 | 
			
		||||
		}
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	return false
 | 
			
		||||
}
 | 
			
		||||
							
								
								
									
										9
									
								
								vendor/github.com/gorilla/handlers/doc.go
									
									
									
										generated
									
									
										vendored
									
									
										Normal file
									
								
							
							
						
						
									
										9
									
								
								vendor/github.com/gorilla/handlers/doc.go
									
									
									
										generated
									
									
										vendored
									
									
										Normal file
									
								
							@@ -0,0 +1,9 @@
 | 
			
		||||
/*
 | 
			
		||||
Package handlers is a collection of handlers (aka "HTTP middleware") for use
 | 
			
		||||
with Go's net/http package (or any framework supporting http.Handler).
 | 
			
		||||
 | 
			
		||||
The package includes handlers for logging in standardised formats, compressing
 | 
			
		||||
HTTP responses, validating content types and other useful tools for manipulating
 | 
			
		||||
requests and responses.
 | 
			
		||||
*/
 | 
			
		||||
package handlers
 | 
			
		||||
							
								
								
									
										403
									
								
								vendor/github.com/gorilla/handlers/handlers.go
									
									
									
										generated
									
									
										vendored
									
									
										Normal file
									
								
							
							
						
						
									
										403
									
								
								vendor/github.com/gorilla/handlers/handlers.go
									
									
									
										generated
									
									
										vendored
									
									
										Normal file
									
								
							@@ -0,0 +1,403 @@
 | 
			
		||||
// Copyright 2013 The Gorilla Authors. All rights reserved.
 | 
			
		||||
// Use of this source code is governed by a BSD-style
 | 
			
		||||
// license that can be found in the LICENSE file.
 | 
			
		||||
 | 
			
		||||
package handlers
 | 
			
		||||
 | 
			
		||||
import (
 | 
			
		||||
	"bufio"
 | 
			
		||||
	"fmt"
 | 
			
		||||
	"io"
 | 
			
		||||
	"net"
 | 
			
		||||
	"net/http"
 | 
			
		||||
	"net/url"
 | 
			
		||||
	"sort"
 | 
			
		||||
	"strconv"
 | 
			
		||||
	"strings"
 | 
			
		||||
	"time"
 | 
			
		||||
	"unicode/utf8"
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
// MethodHandler is an http.Handler that dispatches to a handler whose key in the
 | 
			
		||||
// MethodHandler's map matches the name of the HTTP request's method, eg: GET
 | 
			
		||||
//
 | 
			
		||||
// If the request's method is OPTIONS and OPTIONS is not a key in the map then
 | 
			
		||||
// the handler responds with a status of 200 and sets the Allow header to a
 | 
			
		||||
// comma-separated list of available methods.
 | 
			
		||||
//
 | 
			
		||||
// If the request's method doesn't match any of its keys the handler responds
 | 
			
		||||
// with a status of HTTP 405 "Method Not Allowed" and sets the Allow header to a
 | 
			
		||||
// comma-separated list of available methods.
 | 
			
		||||
type MethodHandler map[string]http.Handler
 | 
			
		||||
 | 
			
		||||
func (h MethodHandler) ServeHTTP(w http.ResponseWriter, req *http.Request) {
 | 
			
		||||
	if handler, ok := h[req.Method]; ok {
 | 
			
		||||
		handler.ServeHTTP(w, req)
 | 
			
		||||
	} else {
 | 
			
		||||
		allow := []string{}
 | 
			
		||||
		for k := range h {
 | 
			
		||||
			allow = append(allow, k)
 | 
			
		||||
		}
 | 
			
		||||
		sort.Strings(allow)
 | 
			
		||||
		w.Header().Set("Allow", strings.Join(allow, ", "))
 | 
			
		||||
		if req.Method == "OPTIONS" {
 | 
			
		||||
			w.WriteHeader(http.StatusOK)
 | 
			
		||||
		} else {
 | 
			
		||||
			http.Error(w, "Method not allowed", http.StatusMethodNotAllowed)
 | 
			
		||||
		}
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// loggingHandler is the http.Handler implementation for LoggingHandlerTo and its
 | 
			
		||||
// friends
 | 
			
		||||
type loggingHandler struct {
 | 
			
		||||
	writer  io.Writer
 | 
			
		||||
	handler http.Handler
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// combinedLoggingHandler is the http.Handler implementation for LoggingHandlerTo
 | 
			
		||||
// and its friends
 | 
			
		||||
type combinedLoggingHandler struct {
 | 
			
		||||
	writer  io.Writer
 | 
			
		||||
	handler http.Handler
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (h loggingHandler) ServeHTTP(w http.ResponseWriter, req *http.Request) {
 | 
			
		||||
	t := time.Now()
 | 
			
		||||
	logger := makeLogger(w)
 | 
			
		||||
	url := *req.URL
 | 
			
		||||
	h.handler.ServeHTTP(logger, req)
 | 
			
		||||
	writeLog(h.writer, req, url, t, logger.Status(), logger.Size())
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (h combinedLoggingHandler) ServeHTTP(w http.ResponseWriter, req *http.Request) {
 | 
			
		||||
	t := time.Now()
 | 
			
		||||
	logger := makeLogger(w)
 | 
			
		||||
	url := *req.URL
 | 
			
		||||
	h.handler.ServeHTTP(logger, req)
 | 
			
		||||
	writeCombinedLog(h.writer, req, url, t, logger.Status(), logger.Size())
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func makeLogger(w http.ResponseWriter) loggingResponseWriter {
 | 
			
		||||
	var logger loggingResponseWriter = &responseLogger{w: w}
 | 
			
		||||
	if _, ok := w.(http.Hijacker); ok {
 | 
			
		||||
		logger = &hijackLogger{responseLogger{w: w}}
 | 
			
		||||
	}
 | 
			
		||||
	h, ok1 := logger.(http.Hijacker)
 | 
			
		||||
	c, ok2 := w.(http.CloseNotifier)
 | 
			
		||||
	if ok1 && ok2 {
 | 
			
		||||
		return hijackCloseNotifier{logger, h, c}
 | 
			
		||||
	}
 | 
			
		||||
	if ok2 {
 | 
			
		||||
		return &closeNotifyWriter{logger, c}
 | 
			
		||||
	}
 | 
			
		||||
	return logger
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
type loggingResponseWriter interface {
 | 
			
		||||
	http.ResponseWriter
 | 
			
		||||
	http.Flusher
 | 
			
		||||
	Status() int
 | 
			
		||||
	Size() int
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// responseLogger is wrapper of http.ResponseWriter that keeps track of its HTTP
 | 
			
		||||
// status code and body size
 | 
			
		||||
type responseLogger struct {
 | 
			
		||||
	w      http.ResponseWriter
 | 
			
		||||
	status int
 | 
			
		||||
	size   int
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (l *responseLogger) Header() http.Header {
 | 
			
		||||
	return l.w.Header()
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (l *responseLogger) Write(b []byte) (int, error) {
 | 
			
		||||
	if l.status == 0 {
 | 
			
		||||
		// The status will be StatusOK if WriteHeader has not been called yet
 | 
			
		||||
		l.status = http.StatusOK
 | 
			
		||||
	}
 | 
			
		||||
	size, err := l.w.Write(b)
 | 
			
		||||
	l.size += size
 | 
			
		||||
	return size, err
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (l *responseLogger) WriteHeader(s int) {
 | 
			
		||||
	l.w.WriteHeader(s)
 | 
			
		||||
	l.status = s
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (l *responseLogger) Status() int {
 | 
			
		||||
	return l.status
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (l *responseLogger) Size() int {
 | 
			
		||||
	return l.size
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (l *responseLogger) Flush() {
 | 
			
		||||
	f, ok := l.w.(http.Flusher)
 | 
			
		||||
	if ok {
 | 
			
		||||
		f.Flush()
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
type hijackLogger struct {
 | 
			
		||||
	responseLogger
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (l *hijackLogger) Hijack() (net.Conn, *bufio.ReadWriter, error) {
 | 
			
		||||
	h := l.responseLogger.w.(http.Hijacker)
 | 
			
		||||
	conn, rw, err := h.Hijack()
 | 
			
		||||
	if err == nil && l.responseLogger.status == 0 {
 | 
			
		||||
		// The status will be StatusSwitchingProtocols if there was no error and
 | 
			
		||||
		// WriteHeader has not been called yet
 | 
			
		||||
		l.responseLogger.status = http.StatusSwitchingProtocols
 | 
			
		||||
	}
 | 
			
		||||
	return conn, rw, err
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
type closeNotifyWriter struct {
 | 
			
		||||
	loggingResponseWriter
 | 
			
		||||
	http.CloseNotifier
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
type hijackCloseNotifier struct {
 | 
			
		||||
	loggingResponseWriter
 | 
			
		||||
	http.Hijacker
 | 
			
		||||
	http.CloseNotifier
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
const lowerhex = "0123456789abcdef"
 | 
			
		||||
 | 
			
		||||
func appendQuoted(buf []byte, s string) []byte {
 | 
			
		||||
	var runeTmp [utf8.UTFMax]byte
 | 
			
		||||
	for width := 0; len(s) > 0; s = s[width:] {
 | 
			
		||||
		r := rune(s[0])
 | 
			
		||||
		width = 1
 | 
			
		||||
		if r >= utf8.RuneSelf {
 | 
			
		||||
			r, width = utf8.DecodeRuneInString(s)
 | 
			
		||||
		}
 | 
			
		||||
		if width == 1 && r == utf8.RuneError {
 | 
			
		||||
			buf = append(buf, `\x`...)
 | 
			
		||||
			buf = append(buf, lowerhex[s[0]>>4])
 | 
			
		||||
			buf = append(buf, lowerhex[s[0]&0xF])
 | 
			
		||||
			continue
 | 
			
		||||
		}
 | 
			
		||||
		if r == rune('"') || r == '\\' { // always backslashed
 | 
			
		||||
			buf = append(buf, '\\')
 | 
			
		||||
			buf = append(buf, byte(r))
 | 
			
		||||
			continue
 | 
			
		||||
		}
 | 
			
		||||
		if strconv.IsPrint(r) {
 | 
			
		||||
			n := utf8.EncodeRune(runeTmp[:], r)
 | 
			
		||||
			buf = append(buf, runeTmp[:n]...)
 | 
			
		||||
			continue
 | 
			
		||||
		}
 | 
			
		||||
		switch r {
 | 
			
		||||
		case '\a':
 | 
			
		||||
			buf = append(buf, `\a`...)
 | 
			
		||||
		case '\b':
 | 
			
		||||
			buf = append(buf, `\b`...)
 | 
			
		||||
		case '\f':
 | 
			
		||||
			buf = append(buf, `\f`...)
 | 
			
		||||
		case '\n':
 | 
			
		||||
			buf = append(buf, `\n`...)
 | 
			
		||||
		case '\r':
 | 
			
		||||
			buf = append(buf, `\r`...)
 | 
			
		||||
		case '\t':
 | 
			
		||||
			buf = append(buf, `\t`...)
 | 
			
		||||
		case '\v':
 | 
			
		||||
			buf = append(buf, `\v`...)
 | 
			
		||||
		default:
 | 
			
		||||
			switch {
 | 
			
		||||
			case r < ' ':
 | 
			
		||||
				buf = append(buf, `\x`...)
 | 
			
		||||
				buf = append(buf, lowerhex[s[0]>>4])
 | 
			
		||||
				buf = append(buf, lowerhex[s[0]&0xF])
 | 
			
		||||
			case r > utf8.MaxRune:
 | 
			
		||||
				r = 0xFFFD
 | 
			
		||||
				fallthrough
 | 
			
		||||
			case r < 0x10000:
 | 
			
		||||
				buf = append(buf, `\u`...)
 | 
			
		||||
				for s := 12; s >= 0; s -= 4 {
 | 
			
		||||
					buf = append(buf, lowerhex[r>>uint(s)&0xF])
 | 
			
		||||
				}
 | 
			
		||||
			default:
 | 
			
		||||
				buf = append(buf, `\U`...)
 | 
			
		||||
				for s := 28; s >= 0; s -= 4 {
 | 
			
		||||
					buf = append(buf, lowerhex[r>>uint(s)&0xF])
 | 
			
		||||
				}
 | 
			
		||||
			}
 | 
			
		||||
		}
 | 
			
		||||
	}
 | 
			
		||||
	return buf
 | 
			
		||||
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// buildCommonLogLine builds a log entry for req in Apache Common Log Format.
 | 
			
		||||
// ts is the timestamp with which the entry should be logged.
 | 
			
		||||
// status and size are used to provide the response HTTP status and size.
 | 
			
		||||
func buildCommonLogLine(req *http.Request, url url.URL, ts time.Time, status int, size int) []byte {
 | 
			
		||||
	username := "-"
 | 
			
		||||
	if url.User != nil {
 | 
			
		||||
		if name := url.User.Username(); name != "" {
 | 
			
		||||
			username = name
 | 
			
		||||
		}
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	host, _, err := net.SplitHostPort(req.RemoteAddr)
 | 
			
		||||
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		host = req.RemoteAddr
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	uri := req.RequestURI
 | 
			
		||||
 | 
			
		||||
	// Requests using the CONNECT method over HTTP/2.0 must use
 | 
			
		||||
	// the authority field (aka r.Host) to identify the target.
 | 
			
		||||
	// Refer: https://httpwg.github.io/specs/rfc7540.html#CONNECT
 | 
			
		||||
	if req.ProtoMajor == 2 && req.Method == "CONNECT" {
 | 
			
		||||
		uri = req.Host
 | 
			
		||||
	}
 | 
			
		||||
	if uri == "" {
 | 
			
		||||
		uri = url.RequestURI()
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	buf := make([]byte, 0, 3*(len(host)+len(username)+len(req.Method)+len(uri)+len(req.Proto)+50)/2)
 | 
			
		||||
	buf = append(buf, host...)
 | 
			
		||||
	buf = append(buf, " - "...)
 | 
			
		||||
	buf = append(buf, username...)
 | 
			
		||||
	buf = append(buf, " ["...)
 | 
			
		||||
	buf = append(buf, ts.Format("02/Jan/2006:15:04:05 -0700")...)
 | 
			
		||||
	buf = append(buf, `] "`...)
 | 
			
		||||
	buf = append(buf, req.Method...)
 | 
			
		||||
	buf = append(buf, " "...)
 | 
			
		||||
	buf = appendQuoted(buf, uri)
 | 
			
		||||
	buf = append(buf, " "...)
 | 
			
		||||
	buf = append(buf, req.Proto...)
 | 
			
		||||
	buf = append(buf, `" `...)
 | 
			
		||||
	buf = append(buf, strconv.Itoa(status)...)
 | 
			
		||||
	buf = append(buf, " "...)
 | 
			
		||||
	buf = append(buf, strconv.Itoa(size)...)
 | 
			
		||||
	return buf
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// writeLog writes a log entry for req to w in Apache Common Log Format.
 | 
			
		||||
// ts is the timestamp with which the entry should be logged.
 | 
			
		||||
// status and size are used to provide the response HTTP status and size.
 | 
			
		||||
func writeLog(w io.Writer, req *http.Request, url url.URL, ts time.Time, status, size int) {
 | 
			
		||||
	buf := buildCommonLogLine(req, url, ts, status, size)
 | 
			
		||||
	buf = append(buf, '\n')
 | 
			
		||||
	w.Write(buf)
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// writeCombinedLog writes a log entry for req to w in Apache Combined Log Format.
 | 
			
		||||
// ts is the timestamp with which the entry should be logged.
 | 
			
		||||
// status and size are used to provide the response HTTP status and size.
 | 
			
		||||
func writeCombinedLog(w io.Writer, req *http.Request, url url.URL, ts time.Time, status, size int) {
 | 
			
		||||
	buf := buildCommonLogLine(req, url, ts, status, size)
 | 
			
		||||
	buf = append(buf, ` "`...)
 | 
			
		||||
	buf = appendQuoted(buf, req.Referer())
 | 
			
		||||
	buf = append(buf, `" "`...)
 | 
			
		||||
	buf = appendQuoted(buf, req.UserAgent())
 | 
			
		||||
	buf = append(buf, '"', '\n')
 | 
			
		||||
	w.Write(buf)
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// CombinedLoggingHandler return a http.Handler that wraps h and logs requests to out in
 | 
			
		||||
// Apache Combined Log Format.
 | 
			
		||||
//
 | 
			
		||||
// See http://httpd.apache.org/docs/2.2/logs.html#combined for a description of this format.
 | 
			
		||||
//
 | 
			
		||||
// LoggingHandler always sets the ident field of the log to -
 | 
			
		||||
func CombinedLoggingHandler(out io.Writer, h http.Handler) http.Handler {
 | 
			
		||||
	return combinedLoggingHandler{out, h}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// LoggingHandler return a http.Handler that wraps h and logs requests to out in
 | 
			
		||||
// Apache Common Log Format (CLF).
 | 
			
		||||
//
 | 
			
		||||
// See http://httpd.apache.org/docs/2.2/logs.html#common for a description of this format.
 | 
			
		||||
//
 | 
			
		||||
// LoggingHandler always sets the ident field of the log to -
 | 
			
		||||
//
 | 
			
		||||
// Example:
 | 
			
		||||
//
 | 
			
		||||
//  r := mux.NewRouter()
 | 
			
		||||
//  r.HandleFunc("/", func(w http.ResponseWriter, r *http.Request) {
 | 
			
		||||
//  	w.Write([]byte("This is a catch-all route"))
 | 
			
		||||
//  })
 | 
			
		||||
//  loggedRouter := handlers.LoggingHandler(os.Stdout, r)
 | 
			
		||||
//  http.ListenAndServe(":1123", loggedRouter)
 | 
			
		||||
//
 | 
			
		||||
func LoggingHandler(out io.Writer, h http.Handler) http.Handler {
 | 
			
		||||
	return loggingHandler{out, h}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// isContentType validates the Content-Type header matches the supplied
 | 
			
		||||
// contentType. That is, its type and subtype match.
 | 
			
		||||
func isContentType(h http.Header, contentType string) bool {
 | 
			
		||||
	ct := h.Get("Content-Type")
 | 
			
		||||
	if i := strings.IndexRune(ct, ';'); i != -1 {
 | 
			
		||||
		ct = ct[0:i]
 | 
			
		||||
	}
 | 
			
		||||
	return ct == contentType
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// ContentTypeHandler wraps and returns a http.Handler, validating the request
 | 
			
		||||
// content type is compatible with the contentTypes list. It writes a HTTP 415
 | 
			
		||||
// error if that fails.
 | 
			
		||||
//
 | 
			
		||||
// Only PUT, POST, and PATCH requests are considered.
 | 
			
		||||
func ContentTypeHandler(h http.Handler, contentTypes ...string) http.Handler {
 | 
			
		||||
	return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
 | 
			
		||||
		if !(r.Method == "PUT" || r.Method == "POST" || r.Method == "PATCH") {
 | 
			
		||||
			h.ServeHTTP(w, r)
 | 
			
		||||
			return
 | 
			
		||||
		}
 | 
			
		||||
 | 
			
		||||
		for _, ct := range contentTypes {
 | 
			
		||||
			if isContentType(r.Header, ct) {
 | 
			
		||||
				h.ServeHTTP(w, r)
 | 
			
		||||
				return
 | 
			
		||||
			}
 | 
			
		||||
		}
 | 
			
		||||
		http.Error(w, fmt.Sprintf("Unsupported content type %q; expected one of %q", r.Header.Get("Content-Type"), contentTypes), http.StatusUnsupportedMediaType)
 | 
			
		||||
	})
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
const (
 | 
			
		||||
	// HTTPMethodOverrideHeader is a commonly used
 | 
			
		||||
	// http header to override a request method.
 | 
			
		||||
	HTTPMethodOverrideHeader = "X-HTTP-Method-Override"
 | 
			
		||||
	// HTTPMethodOverrideFormKey is a commonly used
 | 
			
		||||
	// HTML form key to override a request method.
 | 
			
		||||
	HTTPMethodOverrideFormKey = "_method"
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
// HTTPMethodOverrideHandler wraps and returns a http.Handler which checks for
 | 
			
		||||
// the X-HTTP-Method-Override header or the _method form key, and overrides (if
 | 
			
		||||
// valid) request.Method with its value.
 | 
			
		||||
//
 | 
			
		||||
// This is especially useful for HTTP clients that don't support many http verbs.
 | 
			
		||||
// It isn't secure to override e.g a GET to a POST, so only POST requests are
 | 
			
		||||
// considered.  Likewise, the override method can only be a "write" method: PUT,
 | 
			
		||||
// PATCH or DELETE.
 | 
			
		||||
//
 | 
			
		||||
// Form method takes precedence over header method.
 | 
			
		||||
func HTTPMethodOverrideHandler(h http.Handler) http.Handler {
 | 
			
		||||
	return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
 | 
			
		||||
		if r.Method == "POST" {
 | 
			
		||||
			om := r.FormValue(HTTPMethodOverrideFormKey)
 | 
			
		||||
			if om == "" {
 | 
			
		||||
				om = r.Header.Get(HTTPMethodOverrideHeader)
 | 
			
		||||
			}
 | 
			
		||||
			if om == "PUT" || om == "PATCH" || om == "DELETE" {
 | 
			
		||||
				r.Method = om
 | 
			
		||||
			}
 | 
			
		||||
		}
 | 
			
		||||
		h.ServeHTTP(w, r)
 | 
			
		||||
	})
 | 
			
		||||
}
 | 
			
		||||
							
								
								
									
										120
									
								
								vendor/github.com/gorilla/handlers/proxy_headers.go
									
									
									
										generated
									
									
										vendored
									
									
										Normal file
									
								
							
							
						
						
									
										120
									
								
								vendor/github.com/gorilla/handlers/proxy_headers.go
									
									
									
										generated
									
									
										vendored
									
									
										Normal file
									
								
							@@ -0,0 +1,120 @@
 | 
			
		||||
package handlers
 | 
			
		||||
 | 
			
		||||
import (
 | 
			
		||||
	"net/http"
 | 
			
		||||
	"regexp"
 | 
			
		||||
	"strings"
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
var (
 | 
			
		||||
	// De-facto standard header keys.
 | 
			
		||||
	xForwardedFor    = http.CanonicalHeaderKey("X-Forwarded-For")
 | 
			
		||||
	xForwardedHost   = http.CanonicalHeaderKey("X-Forwarded-Host")
 | 
			
		||||
	xForwardedProto  = http.CanonicalHeaderKey("X-Forwarded-Proto")
 | 
			
		||||
	xForwardedScheme = http.CanonicalHeaderKey("X-Forwarded-Scheme")
 | 
			
		||||
	xRealIP          = http.CanonicalHeaderKey("X-Real-IP")
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
var (
 | 
			
		||||
	// RFC7239 defines a new "Forwarded: " header designed to replace the
 | 
			
		||||
	// existing use of X-Forwarded-* headers.
 | 
			
		||||
	// e.g. Forwarded: for=192.0.2.60;proto=https;by=203.0.113.43
 | 
			
		||||
	forwarded = http.CanonicalHeaderKey("Forwarded")
 | 
			
		||||
	// Allows for a sub-match of the first value after 'for=' to the next
 | 
			
		||||
	// comma, semi-colon or space. The match is case-insensitive.
 | 
			
		||||
	forRegex = regexp.MustCompile(`(?i)(?:for=)([^(;|,| )]+)`)
 | 
			
		||||
	// Allows for a sub-match for the first instance of scheme (http|https)
 | 
			
		||||
	// prefixed by 'proto='. The match is case-insensitive.
 | 
			
		||||
	protoRegex = regexp.MustCompile(`(?i)(?:proto=)(https|http)`)
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
// ProxyHeaders inspects common reverse proxy headers and sets the corresponding
 | 
			
		||||
// fields in the HTTP request struct. These are X-Forwarded-For and X-Real-IP
 | 
			
		||||
// for the remote (client) IP address, X-Forwarded-Proto or X-Forwarded-Scheme
 | 
			
		||||
// for the scheme (http|https) and the RFC7239 Forwarded header, which may
 | 
			
		||||
// include both client IPs and schemes.
 | 
			
		||||
//
 | 
			
		||||
// NOTE: This middleware should only be used when behind a reverse
 | 
			
		||||
// proxy like nginx, HAProxy or Apache. Reverse proxies that don't (or are
 | 
			
		||||
// configured not to) strip these headers from client requests, or where these
 | 
			
		||||
// headers are accepted "as is" from a remote client (e.g. when Go is not behind
 | 
			
		||||
// a proxy), can manifest as a vulnerability if your application uses these
 | 
			
		||||
// headers for validating the 'trustworthiness' of a request.
 | 
			
		||||
func ProxyHeaders(h http.Handler) http.Handler {
 | 
			
		||||
	fn := func(w http.ResponseWriter, r *http.Request) {
 | 
			
		||||
		// Set the remote IP with the value passed from the proxy.
 | 
			
		||||
		if fwd := getIP(r); fwd != "" {
 | 
			
		||||
			r.RemoteAddr = fwd
 | 
			
		||||
		}
 | 
			
		||||
 | 
			
		||||
		// Set the scheme (proto) with the value passed from the proxy.
 | 
			
		||||
		if scheme := getScheme(r); scheme != "" {
 | 
			
		||||
			r.URL.Scheme = scheme
 | 
			
		||||
		}
 | 
			
		||||
		// Set the host with the value passed by the proxy
 | 
			
		||||
		if r.Header.Get(xForwardedHost) != "" {
 | 
			
		||||
			r.Host = r.Header.Get(xForwardedHost)
 | 
			
		||||
		}
 | 
			
		||||
		// Call the next handler in the chain.
 | 
			
		||||
		h.ServeHTTP(w, r)
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	return http.HandlerFunc(fn)
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// getIP retrieves the IP from the X-Forwarded-For, X-Real-IP and RFC7239
 | 
			
		||||
// Forwarded headers (in that order).
 | 
			
		||||
func getIP(r *http.Request) string {
 | 
			
		||||
	var addr string
 | 
			
		||||
 | 
			
		||||
	if fwd := r.Header.Get(xForwardedFor); fwd != "" {
 | 
			
		||||
		// Only grab the first (client) address. Note that '192.168.0.1,
 | 
			
		||||
		// 10.1.1.1' is a valid key for X-Forwarded-For where addresses after
 | 
			
		||||
		// the first may represent forwarding proxies earlier in the chain.
 | 
			
		||||
		s := strings.Index(fwd, ", ")
 | 
			
		||||
		if s == -1 {
 | 
			
		||||
			s = len(fwd)
 | 
			
		||||
		}
 | 
			
		||||
		addr = fwd[:s]
 | 
			
		||||
	} else if fwd := r.Header.Get(xRealIP); fwd != "" {
 | 
			
		||||
		// X-Real-IP should only contain one IP address (the client making the
 | 
			
		||||
		// request).
 | 
			
		||||
		addr = fwd
 | 
			
		||||
	} else if fwd := r.Header.Get(forwarded); fwd != "" {
 | 
			
		||||
		// match should contain at least two elements if the protocol was
 | 
			
		||||
		// specified in the Forwarded header. The first element will always be
 | 
			
		||||
		// the 'for=' capture, which we ignore. In the case of multiple IP
 | 
			
		||||
		// addresses (for=8.8.8.8, 8.8.4.4,172.16.1.20 is valid) we only
 | 
			
		||||
		// extract the first, which should be the client IP.
 | 
			
		||||
		if match := forRegex.FindStringSubmatch(fwd); len(match) > 1 {
 | 
			
		||||
			// IPv6 addresses in Forwarded headers are quoted-strings. We strip
 | 
			
		||||
			// these quotes.
 | 
			
		||||
			addr = strings.Trim(match[1], `"`)
 | 
			
		||||
		}
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	return addr
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// getScheme retrieves the scheme from the X-Forwarded-Proto and RFC7239
 | 
			
		||||
// Forwarded headers (in that order).
 | 
			
		||||
func getScheme(r *http.Request) string {
 | 
			
		||||
	var scheme string
 | 
			
		||||
 | 
			
		||||
	// Retrieve the scheme from X-Forwarded-Proto.
 | 
			
		||||
	if proto := r.Header.Get(xForwardedProto); proto != "" {
 | 
			
		||||
		scheme = strings.ToLower(proto)
 | 
			
		||||
	} else if proto = r.Header.Get(xForwardedScheme); proto != "" {
 | 
			
		||||
		scheme = strings.ToLower(proto)
 | 
			
		||||
	} else if proto = r.Header.Get(forwarded); proto != "" {
 | 
			
		||||
		// match should contain at least two elements if the protocol was
 | 
			
		||||
		// specified in the Forwarded header. The first element will always be
 | 
			
		||||
		// the 'proto=' capture, which we ignore. In the case of multiple proto
 | 
			
		||||
		// parameters (invalid) we only extract the first.
 | 
			
		||||
		if match := protoRegex.FindStringSubmatch(proto); len(match) > 1 {
 | 
			
		||||
			scheme = strings.ToLower(match[1])
 | 
			
		||||
		}
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	return scheme
 | 
			
		||||
}
 | 
			
		||||
							
								
								
									
										91
									
								
								vendor/github.com/gorilla/handlers/recovery.go
									
									
									
										generated
									
									
										vendored
									
									
										Normal file
									
								
							
							
						
						
									
										91
									
								
								vendor/github.com/gorilla/handlers/recovery.go
									
									
									
										generated
									
									
										vendored
									
									
										Normal file
									
								
							@@ -0,0 +1,91 @@
 | 
			
		||||
package handlers
 | 
			
		||||
 | 
			
		||||
import (
 | 
			
		||||
	"log"
 | 
			
		||||
	"net/http"
 | 
			
		||||
	"runtime/debug"
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
// RecoveryHandlerLogger is an interface used by the recovering handler to print logs.
 | 
			
		||||
type RecoveryHandlerLogger interface {
 | 
			
		||||
	Println(...interface{})
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
type recoveryHandler struct {
 | 
			
		||||
	handler    http.Handler
 | 
			
		||||
	logger     RecoveryHandlerLogger
 | 
			
		||||
	printStack bool
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// RecoveryOption provides a functional approach to define
 | 
			
		||||
// configuration for a handler; such as setting the logging
 | 
			
		||||
// whether or not to print strack traces on panic.
 | 
			
		||||
type RecoveryOption func(http.Handler)
 | 
			
		||||
 | 
			
		||||
func parseRecoveryOptions(h http.Handler, opts ...RecoveryOption) http.Handler {
 | 
			
		||||
	for _, option := range opts {
 | 
			
		||||
		option(h)
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	return h
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// RecoveryHandler is HTTP middleware that recovers from a panic,
 | 
			
		||||
// logs the panic, writes http.StatusInternalServerError, and
 | 
			
		||||
// continues to the next handler.
 | 
			
		||||
//
 | 
			
		||||
// Example:
 | 
			
		||||
//
 | 
			
		||||
//  r := mux.NewRouter()
 | 
			
		||||
//  r.HandleFunc("/", func(w http.ResponseWriter, r *http.Request) {
 | 
			
		||||
//  	panic("Unexpected error!")
 | 
			
		||||
//  })
 | 
			
		||||
//
 | 
			
		||||
//  http.ListenAndServe(":1123", handlers.RecoveryHandler()(r))
 | 
			
		||||
func RecoveryHandler(opts ...RecoveryOption) func(h http.Handler) http.Handler {
 | 
			
		||||
	return func(h http.Handler) http.Handler {
 | 
			
		||||
		r := &recoveryHandler{handler: h}
 | 
			
		||||
		return parseRecoveryOptions(r, opts...)
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// RecoveryLogger is a functional option to override
 | 
			
		||||
// the default logger
 | 
			
		||||
func RecoveryLogger(logger RecoveryHandlerLogger) RecoveryOption {
 | 
			
		||||
	return func(h http.Handler) {
 | 
			
		||||
		r := h.(*recoveryHandler)
 | 
			
		||||
		r.logger = logger
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// PrintRecoveryStack is a functional option to enable
 | 
			
		||||
// or disable printing stack traces on panic.
 | 
			
		||||
func PrintRecoveryStack(print bool) RecoveryOption {
 | 
			
		||||
	return func(h http.Handler) {
 | 
			
		||||
		r := h.(*recoveryHandler)
 | 
			
		||||
		r.printStack = print
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (h recoveryHandler) ServeHTTP(w http.ResponseWriter, req *http.Request) {
 | 
			
		||||
	defer func() {
 | 
			
		||||
		if err := recover(); err != nil {
 | 
			
		||||
			w.WriteHeader(http.StatusInternalServerError)
 | 
			
		||||
			h.log(err)
 | 
			
		||||
		}
 | 
			
		||||
	}()
 | 
			
		||||
 | 
			
		||||
	h.handler.ServeHTTP(w, req)
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (h recoveryHandler) log(v ...interface{}) {
 | 
			
		||||
	if h.logger != nil {
 | 
			
		||||
		h.logger.Println(v...)
 | 
			
		||||
	} else {
 | 
			
		||||
		log.Println(v...)
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	if h.printStack {
 | 
			
		||||
		debug.PrintStack()
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
		Reference in New Issue
	
	Block a user