*: switch dex to the ported templates

This commit is contained in:
Eric Chiang 2016-08-25 13:10:19 -07:00
parent 027e3d366c
commit 608d8ba984
7 changed files with 216 additions and 92 deletions

View File

@ -8,6 +8,7 @@ import (
"github.com/coreos/dex/connector/ldap" "github.com/coreos/dex/connector/ldap"
"github.com/coreos/dex/connector/mock" "github.com/coreos/dex/connector/mock"
"github.com/coreos/dex/connector/oidc" "github.com/coreos/dex/connector/oidc"
"github.com/coreos/dex/server"
"github.com/coreos/dex/storage" "github.com/coreos/dex/storage"
"github.com/coreos/dex/storage/kubernetes" "github.com/coreos/dex/storage/kubernetes"
"github.com/coreos/dex/storage/memory" "github.com/coreos/dex/storage/memory"
@ -21,6 +22,8 @@ type Config struct {
Web Web `yaml:"web"` Web Web `yaml:"web"`
OAuth2 OAuth2 `yaml:"oauth2"` OAuth2 OAuth2 `yaml:"oauth2"`
Templates server.TemplateConfig `yaml:"templates"`
StaticClients []storage.Client `yaml:"staticClients"` StaticClients []storage.Client `yaml:"staticClients"`
} }

View File

@ -89,11 +89,11 @@ func serve(cmd *cobra.Command, args []string) error {
} }
serverConfig := server.Config{ serverConfig := server.Config{
SupportedResponseTypes: c.OAuth2.ResponseTypes,
Issuer: c.Issuer, Issuer: c.Issuer,
Connectors: connectors, Connectors: connectors,
Storage: s, Storage: s,
TemplateConfig: c.Templates,
SupportedResponseTypes: c.OAuth2.ResponseTypes,
} }
serv, err := server.New(serverConfig) serv, err := server.New(serverConfig)

View File

@ -129,15 +129,16 @@ func (s *Server) handleAuthorization(w http.ResponseWriter, r *http.Request) {
connectorInfos := make([]connectorInfo, len(s.connectors)) connectorInfos := make([]connectorInfo, len(s.connectors))
i := 0 i := 0
for id := range s.connectors { for id, conn := range s.connectors {
connectorInfos[i] = connectorInfo{ connectorInfos[i] = connectorInfo{
DisplayName: id, ID: id,
Name: conn.DisplayName,
URL: s.absPath("/auth", id), URL: s.absPath("/auth", id),
} }
i++ i++
} }
renderLoginOptions(w, connectorInfos, state) s.templates.login(w, connectorInfos, state)
} }
func (s *Server) handleConnectorLogin(w http.ResponseWriter, r *http.Request) { func (s *Server) handleConnectorLogin(w http.ResponseWriter, r *http.Request) {
@ -163,7 +164,7 @@ func (s *Server) handleConnectorLogin(w http.ResponseWriter, r *http.Request) {
} }
http.Redirect(w, r, callbackURL, http.StatusFound) http.Redirect(w, r, callbackURL, http.StatusFound)
case connector.PasswordConnector: case connector.PasswordConnector:
renderPasswordTmpl(w, state, r.URL.String(), "") s.templates.password(w, state, r.URL.String(), "", false)
default: default:
s.notFound(w, r) s.notFound(w, r)
} }
@ -174,7 +175,7 @@ func (s *Server) handleConnectorLogin(w http.ResponseWriter, r *http.Request) {
return return
} }
username := r.FormValue("username") username := r.FormValue("login")
password := r.FormValue("password") password := r.FormValue("password")
identity, ok, err := passwordConnector.Login(username, password) identity, ok, err := passwordConnector.Login(username, password)
@ -184,7 +185,7 @@ func (s *Server) handleConnectorLogin(w http.ResponseWriter, r *http.Request) {
return return
} }
if !ok { if !ok {
renderPasswordTmpl(w, state, r.URL.String(), "Invalid credentials") s.templates.password(w, state, r.URL.String(), username, true)
return return
} }
redirectURL, err := s.finalizeLogin(identity, state, connID, conn.Connector) redirectURL, err := s.finalizeLogin(identity, state, connID, conn.Connector)
@ -299,7 +300,7 @@ func (s *Server) handleApproval(w http.ResponseWriter, r *http.Request) {
s.renderError(w, http.StatusInternalServerError, errServerError, "") s.renderError(w, http.StatusInternalServerError, errServerError, "")
return return
} }
renderApprovalTmpl(w, authReq.ID, *authReq.Claims, client, authReq.Scopes) s.templates.approval(w, authReq.ID, authReq.Claims.Username, client.Name, authReq.Scopes)
case "POST": case "POST":
if r.FormValue("approval") != "approve" { if r.FormValue("approval") != "approve" {
s.renderError(w, http.StatusInternalServerError, "approval rejected", "") s.renderError(w, http.StatusInternalServerError, "approval rejected", "")

View File

@ -43,6 +43,8 @@ type Config struct {
// If specified, the server will use this function for determining time. // If specified, the server will use this function for determining time.
Now func() time.Time Now func() time.Time
TemplateConfig TemplateConfig
} }
func value(val, defaultValue time.Duration) time.Duration { func value(val, defaultValue time.Duration) time.Duration {
@ -63,6 +65,8 @@ type Server struct {
mux http.Handler mux http.Handler
templates *templates
// If enabled, don't prompt user for approval after logging in through connector. // If enabled, don't prompt user for approval after logging in through connector.
// No package level API to set this, only used in tests. // No package level API to set this, only used in tests.
skipApproval bool skipApproval bool
@ -107,6 +111,11 @@ func newServer(c Config, rotationStrategy rotationStrategy) (*Server, error) {
supported[respType] = true supported[respType] = true
} }
tmpls, err := loadTemplates(c.TemplateConfig)
if err != nil {
return nil, fmt.Errorf("server: failed to load templates: %v", err)
}
now := c.Now now := c.Now
if now == nil { if now == nil {
now = time.Now now = time.Now
@ -124,6 +133,7 @@ func newServer(c Config, rotationStrategy rotationStrategy) (*Server, error) {
supportedResponseTypes: supported, supportedResponseTypes: supported,
idTokensValidFor: value(c.IDTokensValidFor, 24*time.Hour), idTokensValidFor: value(c.IDTokensValidFor, 24*time.Hour),
now: now, now: now,
templates: tmpls,
} }
for _, conn := range c.Connectors { for _, conn := range c.Connectors {

View File

@ -64,7 +64,7 @@ FDWV28nTP9sqbtsmU8Tem2jzMvZ7C/Q0AuDoKELFUpux8shm8wfIhyaPnXUGZoAZ
Np4vUwMSYV5mopESLWOg3loBxKyLGFtgGKVCjGiQvy6zISQ4fQo= Np4vUwMSYV5mopESLWOg3loBxKyLGFtgGKVCjGiQvy6zISQ4fQo=
-----END RSA PRIVATE KEY-----`) -----END RSA PRIVATE KEY-----`)
func newTestServer(updateConfig func(c *Config)) (*httptest.Server, *Server) { func newTestServer(t *testing.T, updateConfig func(c *Config)) (*httptest.Server, *Server) {
var server *Server var server *Server
s := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { s := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
server.ServeHTTP(w, r) server.ServeHTTP(w, r)
@ -76,7 +76,7 @@ func newTestServer(updateConfig func(c *Config)) (*httptest.Server, *Server) {
{ {
ID: "mock", ID: "mock",
DisplayName: "Mock", DisplayName: "Mock",
Connector: mock.New(), Connector: mock.NewCallbackConnector(),
}, },
}, },
} }
@ -87,21 +87,21 @@ func newTestServer(updateConfig func(c *Config)) (*httptest.Server, *Server) {
var err error var err error
if server, err = newServer(config, staticRotationStrategy(testKey)); err != nil { if server, err = newServer(config, staticRotationStrategy(testKey)); err != nil {
panic(err) t.Fatal(err)
} }
server.skipApproval = true // Don't prompt for approval, just immediately redirect with code. server.skipApproval = true // Don't prompt for approval, just immediately redirect with code.
return s, server return s, server
} }
func TestNewTestServer(t *testing.T) { func TestNewTestServer(t *testing.T) {
newTestServer(nil) newTestServer(t, nil)
} }
func TestDiscovery(t *testing.T) { func TestDiscovery(t *testing.T) {
ctx, cancel := context.WithCancel(context.Background()) ctx, cancel := context.WithCancel(context.Background())
defer cancel() defer cancel()
httpServer, _ := newTestServer(func(c *Config) { httpServer, _ := newTestServer(t, func(c *Config) {
c.Issuer = c.Issuer + "/non-root-path" c.Issuer = c.Issuer + "/non-root-path"
}) })
defer httpServer.Close() defer httpServer.Close()
@ -129,7 +129,7 @@ func TestOAuth2CodeFlow(t *testing.T) {
ctx, cancel := context.WithCancel(context.Background()) ctx, cancel := context.WithCancel(context.Background())
defer cancel() defer cancel()
httpServer, s := newTestServer(func(c *Config) { httpServer, s := newTestServer(t, func(c *Config) {
c.Issuer = c.Issuer + "/non-root-path" c.Issuer = c.Issuer + "/non-root-path"
}) })
defer httpServer.Close() defer httpServer.Close()
@ -255,7 +255,7 @@ func TestOAuth2ImplicitFlow(t *testing.T) {
ctx, cancel := context.WithCancel(context.Background()) ctx, cancel := context.WithCancel(context.Background())
defer cancel() defer cancel()
httpServer, s := newTestServer(func(c *Config) { httpServer, s := newTestServer(t, func(c *Config) {
// Enable support for the implicit flow. // Enable support for the implicit flow.
c.SupportedResponseTypes = []string{"code", "token"} c.SupportedResponseTypes = []string{"code", "token"}
}) })

View File

@ -1,101 +1,196 @@
package server package server
import ( import (
"fmt"
"io"
"io/ioutil"
"log" "log"
"net/http" "net/http"
"path/filepath"
"sort"
"text/template" "text/template"
"github.com/coreos/dex/storage"
) )
const (
tmplApproval = "approval.html"
tmplLogin = "login.html"
tmplPassword = "password.html"
)
const coreOSLogoURL = "https://coreos.com/assets/images/brand/coreos-wordmark-135x40px.png"
var requiredTmpls = []string{
tmplApproval,
tmplLogin,
tmplPassword,
}
// TemplateConfig describes.
type TemplateConfig struct {
// Directory of the templates. If empty, these will be loaded from memory.
Dir string `yaml:"dir"`
// Defaults to the CoreOS logo and "dex".
LogoURL string `yaml:"logoURL"`
Issuer string `yaml:"issuerName"`
}
type globalData struct {
LogoURL string
Issuer string
}
func loadTemplates(config TemplateConfig) (*templates, error) {
var tmpls *template.Template
if config.Dir != "" {
files, err := ioutil.ReadDir(config.Dir)
if err != nil {
return nil, fmt.Errorf("read dir: %v", err)
}
filenames := []string{}
for _, file := range files {
if file.IsDir() {
continue
}
filenames = append(filenames, filepath.Join(config.Dir, file.Name()))
}
if len(filenames) == 0 {
return nil, fmt.Errorf("no files in template dir %s", config.Dir)
}
if tmpls, err = template.ParseFiles(filenames...); err != nil {
return nil, fmt.Errorf("parse files: %v", err)
}
} else {
// Load templates from memory. This code is largely copied from the standard library's
// ParseFiles source code.
// See: https://goo.gl/6Wm4mN
for name, data := range defaultTemplates {
var t *template.Template
if tmpls == nil {
tmpls = template.New(name)
}
if name == tmpls.Name() {
t = tmpls
} else {
t = tmpls.New(name)
}
if _, err := t.Parse(data); err != nil {
return nil, fmt.Errorf("parsing %s: %v", name, err)
}
}
}
missingTmpls := []string{}
for _, tmplName := range requiredTmpls {
if tmpls.Lookup(tmplName) == nil {
missingTmpls = append(missingTmpls, tmplName)
}
}
if len(missingTmpls) > 0 {
return nil, fmt.Errorf("missing template(s): %s", missingTmpls)
}
if config.LogoURL == "" {
config.LogoURL = coreOSLogoURL
}
if config.Issuer == "" {
config.Issuer = "dex"
}
return &templates{
globalData: config,
loginTmpl: tmpls.Lookup(tmplLogin),
approvalTmpl: tmpls.Lookup(tmplApproval),
passwordTmpl: tmpls.Lookup(tmplPassword),
}, nil
}
var scopeDescriptions = map[string]string{
"offline_access": "Have offline access",
"profile": "View basic profile information",
"email": "View your email",
}
type templates struct {
globalData TemplateConfig
loginTmpl *template.Template
approvalTmpl *template.Template
passwordTmpl *template.Template
}
type connectorInfo struct { type connectorInfo struct {
DisplayName string ID string
Name string
URL string URL string
} }
var loginTmpl = template.Must(template.New("login-template").Parse(`<html> type byName []connectorInfo
<head></head>
<body> func (n byName) Len() int { return len(n) }
<p>Login options</p> func (n byName) Less(i, j int) bool { return n[i].Name < n[j].Name }
{{ range $i, $connector := .Connectors }} func (n byName) Swap(i, j int) { n[i], n[j] = n[j], n[i] }
<a href="{{ $connector.URL }}?state={{ $.State }}">{{ $connector.DisplayName }}</a>
{{ end }} func (t *templates) login(w http.ResponseWriter, connectors []connectorInfo, state string) {
</body> sort.Sort(byName(connectors))
</html>`))
func renderLoginOptions(w http.ResponseWriter, connectors []connectorInfo, state string) {
data := struct { data := struct {
TemplateConfig
Connectors []connectorInfo Connectors []connectorInfo
State string State string
}{connectors, state} }{t.globalData, connectors, state}
renderTemplate(w, loginTmpl, data) renderTemplate(w, t.loginTmpl, data)
} }
var passwordTmpl = template.Must(template.New("password-template").Parse(`<html> func (t *templates) password(w http.ResponseWriter, state, callback, lastUsername string, lastWasInvalid bool) {
<body>
<p>Login</p>
<form action="{{ .Callback }}" method="POST">
Login: <input type="text" name="login"/><br/>
Password: <input type="password" name="password"/><br/>
<input type="hidden" name="state" value="{{ .State }}"/>
<input type="submit"/>
{{ if .Message }}
<p>Error: {{ .Message }}</p>
{{ end }}
</form>
</body>
</html>`))
func renderPasswordTmpl(w http.ResponseWriter, state, callback, message string) {
data := struct { data := struct {
TemplateConfig
State string State string
Callback string PostURL string
Message string Username string
}{state, callback, message} Invalid bool
renderTemplate(w, passwordTmpl, data) }{t.globalData, state, callback, lastUsername, lastWasInvalid}
renderTemplate(w, t.passwordTmpl, data)
} }
var approvalTmpl = template.Must(template.New("approval-template").Parse(`<html> func (t *templates) approval(w http.ResponseWriter, state, username, clientName string, scopes []string) {
<body> accesses := []string{}
<p>User: {{ .User }}</p> for _, scope := range scopes {
<p>Client: {{ .ClientName }}</p> access, ok := scopeDescriptions[scope]
<form method="post"> if ok {
<input type="hidden" name="state" value="{{ .State }}"/> accesses = append(accesses, access)
<input type="hidden" name="approval" value="approve"> }
<button type="submit">Approve</button> }
</form> sort.Strings(accesses)
<form method="post">
<input type="hidden" name="state" value="{{ .State }}"/>
<input type="hidden" name="approval" value="reject">
<button type="submit">Reject</button>
</form>
</body>
</html>`))
func renderApprovalTmpl(w http.ResponseWriter, state string, identity storage.Claims, client storage.Client, scopes []string) {
data := struct { data := struct {
TemplateConfig
User string User string
ClientName string Client string
State string State string
}{identity.Email, client.Name, state} Scopes []string
renderTemplate(w, approvalTmpl, data) }{t.globalData, username, clientName, state, accesses}
renderTemplate(w, t.approvalTmpl, data)
}
// small io.Writer utilitiy to determine if executing the template wrote to the underlying response writer.
type writeRecorder struct {
wrote bool
w io.Writer
}
func (w *writeRecorder) Write(p []byte) (n int, err error) {
w.wrote = true
return w.w.Write(p)
} }
func renderTemplate(w http.ResponseWriter, tmpl *template.Template, data interface{}) { func renderTemplate(w http.ResponseWriter, tmpl *template.Template, data interface{}) {
err := tmpl.Execute(w, data) wr := &writeRecorder{w: w}
if err == nil { if err := tmpl.Execute(wr, data); err != nil {
return
}
switch err := err.(type) {
case template.ExecError:
// An ExecError guarentees that Execute has not written to the underlying reader.
log.Printf("Error rendering template %s: %s", tmpl.Name(), err) log.Printf("Error rendering template %s: %s", tmpl.Name(), err)
if !wr.wrote {
// TODO(ericchiang): replace with better internal server error. // TODO(ericchiang): replace with better internal server error.
http.Error(w, "Internal server error", http.StatusInternalServerError) http.Error(w, "Internal server error", http.StatusInternalServerError)
default:
// An error with the underlying write, such as the connection being
// dropped. Ignore for now.
} }
} }
return
}

View File

@ -1 +1,16 @@
package server package server
import "testing"
func TestNewTemplates(t *testing.T) {
var config TemplateConfig
if _, err := loadTemplates(config); err != nil {
t.Fatal(err)
}
}
func TestLoadTemplates(t *testing.T) {
var config TemplateConfig
config.Dir = "../web/templates"
}