templates: add new relativeURL function
Signed-off-by: Yannis Zarkadas <yanniszark@arrikto.com>
This commit is contained in:
		@@ -6,7 +6,9 @@ import (
 | 
				
			|||||||
	"io"
 | 
						"io"
 | 
				
			||||||
	"io/ioutil"
 | 
						"io/ioutil"
 | 
				
			||||||
	"net/http"
 | 
						"net/http"
 | 
				
			||||||
 | 
						"net/url"
 | 
				
			||||||
	"os"
 | 
						"os"
 | 
				
			||||||
 | 
						"path"
 | 
				
			||||||
	"path/filepath"
 | 
						"path/filepath"
 | 
				
			||||||
	"sort"
 | 
						"sort"
 | 
				
			||||||
	"strings"
 | 
						"strings"
 | 
				
			||||||
@@ -94,7 +96,7 @@ func loadWebConfig(c webConfig) (static, theme http.Handler, templates *template
 | 
				
			|||||||
		c.dir = "./web"
 | 
							c.dir = "./web"
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
	if c.logoURL == "" {
 | 
						if c.logoURL == "" {
 | 
				
			||||||
		c.logoURL = join(c.issuerURL, "theme/logo.png")
 | 
							c.logoURL = "theme/logo.png"
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	if err := dirExists(c.dir); err != nil {
 | 
						if err := dirExists(c.dir); err != nil {
 | 
				
			||||||
@@ -136,10 +138,15 @@ func loadTemplates(c webConfig, templatesDir string) (*templates, error) {
 | 
				
			|||||||
		return nil, fmt.Errorf("no files in template dir %q", templatesDir)
 | 
							return nil, fmt.Errorf("no files in template dir %q", templatesDir)
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						issuerURL, err := url.Parse(c.issuerURL)
 | 
				
			||||||
 | 
						if err != nil {
 | 
				
			||||||
 | 
							return nil, fmt.Errorf("error parsing issuerURL: %v", err)
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	funcs := map[string]interface{}{
 | 
						funcs := map[string]interface{}{
 | 
				
			||||||
		"issuer": func() string { return c.issuer },
 | 
							"issuer": func() string { return c.issuer },
 | 
				
			||||||
		"logo":   func() string { return c.logoURL },
 | 
							"logo":   func() string { return c.logoURL },
 | 
				
			||||||
		"url":    func(s string) string { return join(c.issuerURL, s) },
 | 
							"url":    func(reqPath, assetPath string) string { return relativeURL(issuerURL.Path, reqPath, assetPath) },
 | 
				
			||||||
		"lower":  strings.ToLower,
 | 
							"lower":  strings.ToLower,
 | 
				
			||||||
		"extra":  func(k string) string { return c.extra[k] },
 | 
							"extra":  func(k string) string { return c.extra[k] },
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
@@ -166,6 +173,69 @@ func loadTemplates(c webConfig, templatesDir string) (*templates, error) {
 | 
				
			|||||||
	}, nil
 | 
						}, nil
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					// relativeURL returns the URL of the asset relative to the URL of the request path.
 | 
				
			||||||
 | 
					// The serverPath is consulted to trim any prefix due in case it is not listening
 | 
				
			||||||
 | 
					// to the root path.
 | 
				
			||||||
 | 
					//
 | 
				
			||||||
 | 
					// Algorithm:
 | 
				
			||||||
 | 
					// 1. Remove common prefix of serverPath and reqPath
 | 
				
			||||||
 | 
					// 2. Remove common prefix of assetPath and reqPath
 | 
				
			||||||
 | 
					// 3. For each part of reqPath remaining(minus one), go up one level (..)
 | 
				
			||||||
 | 
					// 4. For each part of assetPath remaining, append it to result
 | 
				
			||||||
 | 
					//
 | 
				
			||||||
 | 
					//eg
 | 
				
			||||||
 | 
					//server listens at localhost/dex so serverPath is dex
 | 
				
			||||||
 | 
					//reqPath is /dex/auth
 | 
				
			||||||
 | 
					//assetPath is static/main.css
 | 
				
			||||||
 | 
					//relativeURL("/dex", "/dex/auth", "static/main.css") = "../static/main.css"
 | 
				
			||||||
 | 
					func relativeURL(serverPath, reqPath, assetPath string) string {
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						splitPath := func(p string) []string {
 | 
				
			||||||
 | 
							res := []string{}
 | 
				
			||||||
 | 
							parts := strings.Split(path.Clean(p), "/")
 | 
				
			||||||
 | 
							for _, part := range parts {
 | 
				
			||||||
 | 
								if part != "" {
 | 
				
			||||||
 | 
									res = append(res, part)
 | 
				
			||||||
 | 
								}
 | 
				
			||||||
 | 
							}
 | 
				
			||||||
 | 
							return res
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						stripCommonParts := func(s1, s2 []string) ([]string, []string) {
 | 
				
			||||||
 | 
							min := len(s1)
 | 
				
			||||||
 | 
							if len(s2) < min {
 | 
				
			||||||
 | 
								min = len(s2)
 | 
				
			||||||
 | 
							}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
							splitIndex := min
 | 
				
			||||||
 | 
							for i := 0; i < min; i++ {
 | 
				
			||||||
 | 
								if s1[i] != s2[i] {
 | 
				
			||||||
 | 
									splitIndex = i
 | 
				
			||||||
 | 
									break
 | 
				
			||||||
 | 
								}
 | 
				
			||||||
 | 
							}
 | 
				
			||||||
 | 
							return s1[splitIndex:], s2[splitIndex:]
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						server, req, asset := splitPath(serverPath), splitPath(reqPath), splitPath(assetPath)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						// Remove common prefix of request path with server path
 | 
				
			||||||
 | 
						server, req = stripCommonParts(server, req)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						// Remove common prefix of request path with asset path
 | 
				
			||||||
 | 
						asset, req = stripCommonParts(asset, req)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						// For each part of the request remaining (minus one) -> go up one level (..)
 | 
				
			||||||
 | 
						// For each part of the asset remaining               -> append it
 | 
				
			||||||
 | 
						var relativeURL string
 | 
				
			||||||
 | 
						for i := 0; i < len(req)-1; i++ {
 | 
				
			||||||
 | 
							relativeURL = path.Join("..", relativeURL)
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
						relativeURL = path.Join(relativeURL, path.Join(asset...))
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						return relativeURL
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
var scopeDescriptions = map[string]string{
 | 
					var scopeDescriptions = map[string]string{
 | 
				
			||||||
	"offline_access": "Have offline access",
 | 
						"offline_access": "Have offline access",
 | 
				
			||||||
	"profile":        "View basic profile information",
 | 
						"profile":        "View basic profile information",
 | 
				
			||||||
@@ -184,26 +254,28 @@ func (n byName) Len() int           { return len(n) }
 | 
				
			|||||||
func (n byName) Less(i, j int) bool { return n[i].Name < n[j].Name }
 | 
					func (n byName) Less(i, j int) bool { return n[i].Name < n[j].Name }
 | 
				
			||||||
func (n byName) Swap(i, j int)      { n[i], n[j] = n[j], n[i] }
 | 
					func (n byName) Swap(i, j int)      { n[i], n[j] = n[j], n[i] }
 | 
				
			||||||
 | 
					
 | 
				
			||||||
func (t *templates) login(w http.ResponseWriter, connectors []connectorInfo) error {
 | 
					func (t *templates) login(r *http.Request, w http.ResponseWriter, connectors []connectorInfo, reqPath string) error {
 | 
				
			||||||
	sort.Sort(byName(connectors))
 | 
						sort.Sort(byName(connectors))
 | 
				
			||||||
	data := struct {
 | 
						data := struct {
 | 
				
			||||||
		Connectors []connectorInfo
 | 
							Connectors []connectorInfo
 | 
				
			||||||
	}{connectors}
 | 
							ReqPath    string
 | 
				
			||||||
 | 
						}{connectors, r.URL.Path}
 | 
				
			||||||
	return renderTemplate(w, t.loginTmpl, data)
 | 
						return renderTemplate(w, t.loginTmpl, data)
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
func (t *templates) password(w http.ResponseWriter, postURL, lastUsername, usernamePrompt string, lastWasInvalid, showBacklink bool) error {
 | 
					func (t *templates) password(r *http.Request, w http.ResponseWriter, postURL, lastUsername, usernamePrompt string, lastWasInvalid, showBacklink bool, reqPath string) error {
 | 
				
			||||||
	data := struct {
 | 
						data := struct {
 | 
				
			||||||
		PostURL        string
 | 
							PostURL        string
 | 
				
			||||||
		BackLink       bool
 | 
							BackLink       bool
 | 
				
			||||||
		Username       string
 | 
							Username       string
 | 
				
			||||||
		UsernamePrompt string
 | 
							UsernamePrompt string
 | 
				
			||||||
		Invalid        bool
 | 
							Invalid        bool
 | 
				
			||||||
	}{postURL, showBacklink, lastUsername, usernamePrompt, lastWasInvalid}
 | 
							ReqPath        string
 | 
				
			||||||
 | 
						}{postURL, showBacklink, lastUsername, usernamePrompt, lastWasInvalid, r.URL.Path}
 | 
				
			||||||
	return renderTemplate(w, t.passwordTmpl, data)
 | 
						return renderTemplate(w, t.passwordTmpl, data)
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
func (t *templates) approval(w http.ResponseWriter, authReqID, username, clientName string, scopes []string) error {
 | 
					func (t *templates) approval(r *http.Request, w http.ResponseWriter, authReqID, username, clientName string, scopes []string, reqPath string) error {
 | 
				
			||||||
	accesses := []string{}
 | 
						accesses := []string{}
 | 
				
			||||||
	for _, scope := range scopes {
 | 
						for _, scope := range scopes {
 | 
				
			||||||
		access, ok := scopeDescriptions[scope]
 | 
							access, ok := scopeDescriptions[scope]
 | 
				
			||||||
@@ -217,23 +289,26 @@ func (t *templates) approval(w http.ResponseWriter, authReqID, username, clientN
 | 
				
			|||||||
		Client    string
 | 
							Client    string
 | 
				
			||||||
		AuthReqID string
 | 
							AuthReqID string
 | 
				
			||||||
		Scopes    []string
 | 
							Scopes    []string
 | 
				
			||||||
	}{username, clientName, authReqID, accesses}
 | 
							ReqPath   string
 | 
				
			||||||
 | 
						}{username, clientName, authReqID, accesses, r.URL.Path}
 | 
				
			||||||
	return renderTemplate(w, t.approvalTmpl, data)
 | 
						return renderTemplate(w, t.approvalTmpl, data)
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
func (t *templates) oob(w http.ResponseWriter, code string) error {
 | 
					func (t *templates) oob(r *http.Request, w http.ResponseWriter, code string, reqPath string) error {
 | 
				
			||||||
	data := struct {
 | 
						data := struct {
 | 
				
			||||||
		Code    string
 | 
							Code    string
 | 
				
			||||||
	}{code}
 | 
							ReqPath string
 | 
				
			||||||
 | 
						}{code, r.URL.Path}
 | 
				
			||||||
	return renderTemplate(w, t.oobTmpl, data)
 | 
						return renderTemplate(w, t.oobTmpl, data)
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
func (t *templates) err(w http.ResponseWriter, errCode int, errMsg string) error {
 | 
					func (t *templates) err(r *http.Request, w http.ResponseWriter, errCode int, errMsg string) error {
 | 
				
			||||||
	w.WriteHeader(errCode)
 | 
						w.WriteHeader(errCode)
 | 
				
			||||||
	data := struct {
 | 
						data := struct {
 | 
				
			||||||
		ErrType string
 | 
							ErrType string
 | 
				
			||||||
		ErrMsg  string
 | 
							ErrMsg  string
 | 
				
			||||||
	}{http.StatusText(errCode), errMsg}
 | 
							ReqPath string
 | 
				
			||||||
 | 
						}{http.StatusText(errCode), errMsg, r.URL.Path}
 | 
				
			||||||
	if err := t.errorTmpl.Execute(w, data); err != nil {
 | 
						if err := t.errorTmpl.Execute(w, data); err != nil {
 | 
				
			||||||
		return fmt.Errorf("Error rendering template %s: %s", t.errorTmpl.Name(), err)
 | 
							return fmt.Errorf("Error rendering template %s: %s", t.errorTmpl.Name(), err)
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
 
 | 
				
			|||||||
@@ -1 +1,44 @@
 | 
				
			|||||||
package server
 | 
					package server
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					import "testing"
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					func TestRelativeURL(t *testing.T) {
 | 
				
			||||||
 | 
						tests := []struct {
 | 
				
			||||||
 | 
							name       string
 | 
				
			||||||
 | 
							serverPath string
 | 
				
			||||||
 | 
							reqPath    string
 | 
				
			||||||
 | 
							assetPath  string
 | 
				
			||||||
 | 
							expected   string
 | 
				
			||||||
 | 
						}{
 | 
				
			||||||
 | 
							{
 | 
				
			||||||
 | 
								name:       "server-root-req-one-level-asset-two-level",
 | 
				
			||||||
 | 
								serverPath: "/",
 | 
				
			||||||
 | 
								reqPath:    "/auth",
 | 
				
			||||||
 | 
								assetPath:  "/theme/main.css",
 | 
				
			||||||
 | 
								expected:   "theme/main.css",
 | 
				
			||||||
 | 
							},
 | 
				
			||||||
 | 
							{
 | 
				
			||||||
 | 
								name:       "server-one-level-req-one-level-asset-two-level",
 | 
				
			||||||
 | 
								serverPath: "/dex",
 | 
				
			||||||
 | 
								reqPath:    "/dex/auth",
 | 
				
			||||||
 | 
								assetPath:  "/theme/main.css",
 | 
				
			||||||
 | 
								expected:   "theme/main.css",
 | 
				
			||||||
 | 
							},
 | 
				
			||||||
 | 
							{
 | 
				
			||||||
 | 
								name:       "server-root-req-two-level-asset-three-level",
 | 
				
			||||||
 | 
								serverPath: "/dex",
 | 
				
			||||||
 | 
								reqPath:    "/dex/auth/connector",
 | 
				
			||||||
 | 
								assetPath:  "assets/css/main.css",
 | 
				
			||||||
 | 
								expected:   "../assets/css/main.css",
 | 
				
			||||||
 | 
							},
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						for _, test := range tests {
 | 
				
			||||||
 | 
							t.Run(test.name, func(t *testing.T) {
 | 
				
			||||||
 | 
								actual := relativeURL(test.serverPath, test.reqPath, test.assetPath)
 | 
				
			||||||
 | 
								if actual != test.expected {
 | 
				
			||||||
 | 
									t.Fatalf("Got '%s'. Expected '%s'", actual, test.expected)
 | 
				
			||||||
 | 
								}
 | 
				
			||||||
 | 
							})
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 
 | 
				
			|||||||
		Reference in New Issue
	
	Block a user