From a3235d022a79ba199687701c9ba9ec39f5ec7047 Mon Sep 17 00:00:00 2001 From: Eric Chiang Date: Thu, 27 Oct 2016 10:08:08 -0700 Subject: [PATCH 1/2] *: verify "state" field before passing request to callback connectors Let the server handle the state token instead of the connector. As a result it can throw out bad requests earlier. It can also use that token to determine which connector was used to generate the request allowing all connectors to share the same callback URL. Callbacks now all look like: https://dex.example.com/callback Instead of: https://dex.example.com/callback/(connector id) Even when multiple connectors are being used. --- connector/connector.go | 2 +- connector/github/github.go | 18 ++++---- connector/mock/connectortest.go | 4 +- connector/oidc/oidc.go | 14 +++--- server/handlers.go | 80 ++++++++++++++++++++++++--------- server/server.go | 2 +- 6 files changed, 79 insertions(+), 41 deletions(-) diff --git a/connector/connector.go b/connector/connector.go index 8235caae..9f84d3e6 100644 --- a/connector/connector.go +++ b/connector/connector.go @@ -33,7 +33,7 @@ type PasswordConnector interface { // CallbackConnector is an optional interface for callback based connectors. type CallbackConnector interface { LoginURL(callbackURL, state string) (string, error) - HandleCallback(r *http.Request) (identity Identity, state string, err error) + HandleCallback(r *http.Request) (identity Identity, err error) } // GroupsConnector is an optional interface for connectors which can map a user to groups. diff --git a/connector/github/github.go b/connector/github/github.go index b679e0ed..172b908d 100644 --- a/connector/github/github.go +++ b/connector/github/github.go @@ -84,28 +84,28 @@ func (e *oauth2Error) Error() string { return e.error + ": " + e.errorDescription } -func (c *githubConnector) HandleCallback(r *http.Request) (identity connector.Identity, state string, err error) { +func (c *githubConnector) HandleCallback(r *http.Request) (identity connector.Identity, err error) { q := r.URL.Query() if errType := q.Get("error"); errType != "" { - return identity, "", &oauth2Error{errType, q.Get("error_description")} + return identity, &oauth2Error{errType, q.Get("error_description")} } token, err := c.oauth2Config.Exchange(c.ctx, q.Get("code")) if err != nil { - return identity, "", fmt.Errorf("github: failed to get token: %v", err) + return identity, fmt.Errorf("github: failed to get token: %v", err) } resp, err := c.oauth2Config.Client(c.ctx, token).Get(baseURL + "/user") if err != nil { - return identity, "", fmt.Errorf("github: get URL %v", err) + return identity, fmt.Errorf("github: get URL %v", err) } defer resp.Body.Close() if resp.StatusCode != http.StatusOK { body, err := ioutil.ReadAll(resp.Body) if err != nil { - return identity, "", fmt.Errorf("github: read body: %v", err) + return identity, fmt.Errorf("github: read body: %v", err) } - return identity, "", fmt.Errorf("%s: %s", resp.Status, body) + return identity, fmt.Errorf("%s: %s", resp.Status, body) } var user struct { Name string `json:"name"` @@ -114,13 +114,13 @@ func (c *githubConnector) HandleCallback(r *http.Request) (identity connector.Id Email string `json:"email"` } if err := json.NewDecoder(resp.Body).Decode(&user); err != nil { - return identity, "", fmt.Errorf("failed to decode response: %v", err) + return identity, fmt.Errorf("failed to decode response: %v", err) } data := connectorData{AccessToken: token.AccessToken} connData, err := json.Marshal(data) if err != nil { - return identity, "", fmt.Errorf("marshal connector data: %v", err) + return identity, fmt.Errorf("marshal connector data: %v", err) } username := user.Name @@ -134,7 +134,7 @@ func (c *githubConnector) HandleCallback(r *http.Request) (identity connector.Id EmailVerified: true, ConnectorData: connData, } - return identity, q.Get("state"), nil + return identity, nil } func (c *githubConnector) Groups(identity connector.Identity) ([]string, error) { diff --git a/connector/mock/connectortest.go b/connector/mock/connectortest.go index 0d4b87ba..1ceeb361 100644 --- a/connector/mock/connectortest.go +++ b/connector/mock/connectortest.go @@ -41,14 +41,14 @@ func (m callbackConnector) LoginURL(callbackURL, state string) (string, error) { var connectorData = []byte("foobar") -func (m callbackConnector) HandleCallback(r *http.Request) (connector.Identity, string, error) { +func (m callbackConnector) HandleCallback(r *http.Request) (connector.Identity, error) { return connector.Identity{ UserID: "0-385-28089-0", Username: "Kilgore Trout", Email: "kilgore@kilgore.trout", EmailVerified: true, ConnectorData: connectorData, - }, r.URL.Query().Get("state"), nil + }, nil } func (m callbackConnector) Groups(identity connector.Identity) ([]string, error) { diff --git a/connector/oidc/oidc.go b/connector/oidc/oidc.go index 6f19b7d4..d7fede04 100644 --- a/connector/oidc/oidc.go +++ b/connector/oidc/oidc.go @@ -95,23 +95,23 @@ func (e *oauth2Error) Error() string { return e.error + ": " + e.errorDescription } -func (c *oidcConnector) HandleCallback(r *http.Request) (identity connector.Identity, state string, err error) { +func (c *oidcConnector) HandleCallback(r *http.Request) (identity connector.Identity, err error) { q := r.URL.Query() if errType := q.Get("error"); errType != "" { - return identity, "", &oauth2Error{errType, q.Get("error_description")} + return identity, &oauth2Error{errType, q.Get("error_description")} } token, err := c.oauth2Config.Exchange(c.ctx, q.Get("code")) if err != nil { - return identity, "", fmt.Errorf("oidc: failed to get token: %v", err) + return identity, fmt.Errorf("oidc: failed to get token: %v", err) } rawIDToken, ok := token.Extra("id_token").(string) if !ok { - return identity, "", errors.New("oidc: no id_token in token response") + return identity, errors.New("oidc: no id_token in token response") } idToken, err := c.verifier.Verify(rawIDToken) if err != nil { - return identity, "", fmt.Errorf("oidc: failed to verify ID Token: %v", err) + return identity, fmt.Errorf("oidc: failed to verify ID Token: %v", err) } var claims struct { @@ -120,7 +120,7 @@ func (c *oidcConnector) HandleCallback(r *http.Request) (identity connector.Iden EmailVerified bool `json:"email_verified"` } if err := idToken.Claims(&claims); err != nil { - return identity, "", fmt.Errorf("oidc: failed to decode claims: %v", err) + return identity, fmt.Errorf("oidc: failed to decode claims: %v", err) } identity = connector.Identity{ @@ -129,5 +129,5 @@ func (c *oidcConnector) HandleCallback(r *http.Request) (identity connector.Iden Email: claims.Email, EmailVerified: claims.EmailVerified, } - return identity, q.Get("state"), nil + return identity, nil } diff --git a/server/handlers.go b/server/handlers.go index 83ef5376..6017b126 100644 --- a/server/handlers.go +++ b/server/handlers.go @@ -2,7 +2,6 @@ package server import ( "encoding/json" - "errors" "fmt" "log" "net/http" @@ -180,14 +179,26 @@ func (s *Server) handleConnectorLogin(w http.ResponseWriter, r *http.Request) { return } + authReqID := r.FormValue("state") + // TODO(ericchiang): cache user identity. - state := r.FormValue("state") switch r.Method { case "GET": + // Set the connector being used for the login. + updater := func(a storage.AuthRequest) (storage.AuthRequest, error) { + a.ConnectorID = connID + return a, nil + } + if err := s.storage.UpdateAuthRequest(authReqID, updater); err != nil { + log.Printf("Failed to set connector ID on auth request: %v", err) + s.renderError(w, http.StatusInternalServerError, errServerError, "") + return + } + switch conn := conn.Connector.(type) { case connector.CallbackConnector: - callbackURL, err := conn.LoginURL(s.absURL("/callback", connID), state) + callbackURL, err := conn.LoginURL(s.absURL("/callback"), authReqID) if err != nil { log.Printf("Connector %q returned error when creating callback: %v", connID, err) s.renderError(w, http.StatusInternalServerError, errServerError, "") @@ -195,7 +206,7 @@ func (s *Server) handleConnectorLogin(w http.ResponseWriter, r *http.Request) { } http.Redirect(w, r, callbackURL, http.StatusFound) case connector.PasswordConnector: - s.templates.password(w, state, r.URL.String(), "", false) + s.templates.password(w, authReqID, r.URL.String(), "", false) default: s.notFound(w, r) } @@ -216,10 +227,16 @@ func (s *Server) handleConnectorLogin(w http.ResponseWriter, r *http.Request) { return } if !ok { - s.templates.password(w, state, r.URL.String(), username, true) + s.templates.password(w, authReqID, r.URL.String(), username, true) return } - redirectURL, err := s.finalizeLogin(identity, state, connID, conn.Connector) + authReq, err := s.storage.GetAuthRequest(authReqID) + if err != nil { + log.Printf("Failed to get auth request: %v", err) + s.renderError(w, http.StatusInternalServerError, errServerError, "") + return + } + redirectURL, err := s.finalizeLogin(identity, authReq, conn.Connector) if err != nil { log.Printf("Failed to finalize login: %v", err) s.renderError(w, http.StatusInternalServerError, errServerError, "") @@ -233,8 +250,31 @@ func (s *Server) handleConnectorLogin(w http.ResponseWriter, r *http.Request) { } func (s *Server) handleConnectorCallback(w http.ResponseWriter, r *http.Request) { - connID := mux.Vars(r)["connector"] - conn, ok := s.connectors[connID] + // SAML redirect bindings use the "RelayState" URL query field. When we support + // SAML, we'll have to check that field too and possibly let callback connectors + // indicate which field is used to determine the state. + // + // See: + // https://docs.oasis-open.org/security/saml/v2.0/saml-bindings-2.0-os.pdf + // Section: "3.4.3 RelayState" + state := r.URL.Query().Get("state") + if state == "" { + s.renderError(w, http.StatusBadRequest, errInvalidRequest, "no 'state' parameter provided") + return + } + + authReq, err := s.storage.GetAuthRequest(state) + if err != nil { + if err == storage.ErrNotFound { + s.renderError(w, http.StatusBadRequest, errInvalidRequest, "invalid 'state' parameter provided") + return + } + log.Printf("Failed to get auth request: %v", err) + s.renderError(w, http.StatusInternalServerError, errServerError, "") + return + } + + conn, ok := s.connectors[authReq.ConnectorID] if !ok { s.notFound(w, r) return @@ -245,14 +285,14 @@ func (s *Server) handleConnectorCallback(w http.ResponseWriter, r *http.Request) return } - identity, state, err := callbackConnector.HandleCallback(r) + identity, err := callbackConnector.HandleCallback(r) if err != nil { log.Printf("Failed to authenticate: %v", err) s.renderError(w, http.StatusInternalServerError, errServerError, "") return } - redirectURL, err := s.finalizeLogin(identity, state, connID, conn.Connector) + redirectURL, err := s.finalizeLogin(identity, authReq, conn.Connector) if err != nil { log.Printf("Failed to finalize login: %v", err) s.renderError(w, http.StatusInternalServerError, errServerError, "") @@ -262,10 +302,11 @@ func (s *Server) handleConnectorCallback(w http.ResponseWriter, r *http.Request) http.Redirect(w, r, redirectURL, http.StatusSeeOther) } -func (s *Server) finalizeLogin(identity connector.Identity, authReqID, connectorID string, conn connector.Connector) (string, error) { - if authReqID == "" { - return "", errors.New("no auth request ID passed") +func (s *Server) finalizeLogin(identity connector.Identity, authReq storage.AuthRequest, conn connector.Connector) (string, error) { + if authReq.ConnectorID == "" { + } + claims := storage.Claims{ UserID: identity.UserID, Username: identity.Username, @@ -275,10 +316,6 @@ func (s *Server) finalizeLogin(identity connector.Identity, authReqID, connector groupsConn, ok := conn.(connector.GroupsConnector) if ok { - authReq, err := s.storage.GetAuthRequest(authReqID) - if err != nil { - return "", fmt.Errorf("get auth request: %v", err) - } reqGroups := func() bool { for _, scope := range authReq.Scopes { if scope == scopeGroups { @@ -288,23 +325,24 @@ func (s *Server) finalizeLogin(identity connector.Identity, authReqID, connector return false }() if reqGroups { - if claims.Groups, err = groupsConn.Groups(identity); err != nil { + groups, err := groupsConn.Groups(identity) + if err != nil { return "", fmt.Errorf("getting groups: %v", err) } + claims.Groups = groups } } updater := func(a storage.AuthRequest) (storage.AuthRequest, error) { a.LoggedIn = true a.Claims = claims - a.ConnectorID = connectorID a.ConnectorData = identity.ConnectorData return a, nil } - if err := s.storage.UpdateAuthRequest(authReqID, updater); err != nil { + if err := s.storage.UpdateAuthRequest(authReq.ID, updater); err != nil { return "", fmt.Errorf("failed to update auth request: %v", err) } - return path.Join(s.issuerURL.Path, "/approval") + "?state=" + authReqID, nil + return path.Join(s.issuerURL.Path, "/approval") + "?state=" + authReq.ID, nil } func (s *Server) handleApproval(w http.ResponseWriter, r *http.Request) { diff --git a/server/server.go b/server/server.go index 628dfb46..603a23cb 100644 --- a/server/server.go +++ b/server/server.go @@ -172,7 +172,7 @@ func newServer(ctx context.Context, c Config, rotationStrategy rotationStrategy) handleFunc("/keys", s.handlePublicKeys) handleFunc("/auth", s.handleAuthorization) handleFunc("/auth/{connector}", s.handleConnectorLogin) - handleFunc("/callback/{connector}", s.handleConnectorCallback) + handleFunc("/callback", s.handleConnectorCallback) handleFunc("/approval", s.handleApproval) handleFunc("/healthz", s.handleHealth) s.mux = r From 7c2289e0decadb96f0d5e977feccecfa69d410bc Mon Sep 17 00:00:00 2001 From: Eric Chiang Date: Thu, 27 Oct 2016 10:20:30 -0700 Subject: [PATCH 2/2] *: rename internally used "state" form value to "req" "state" means something specific to OAuth2 and SAML so we don't want to confuse developers who are working on this. Also don't use "session" which could easily be confused with HTTP cookies. --- server/handlers.go | 15 ++++++++------- server/templates.go | 30 +++++++++++++++--------------- server/templates_default.go | 8 ++++---- web/templates/approval.html | 4 ++-- web/templates/login.html | 2 +- web/templates/password.html | 2 +- 6 files changed, 31 insertions(+), 30 deletions(-) diff --git a/server/handlers.go b/server/handlers.go index 6017b126..055c0b83 100644 --- a/server/handlers.go +++ b/server/handlers.go @@ -148,11 +148,9 @@ func (s *Server) handleAuthorization(w http.ResponseWriter, r *http.Request) { s.renderError(w, http.StatusInternalServerError, errServerError, "") return } - state := authReq.ID - if len(s.connectors) == 1 { for id := range s.connectors { - http.Redirect(w, r, s.absPath("/auth", id)+"?state="+state, http.StatusFound) + http.Redirect(w, r, s.absPath("/auth", id)+"?req="+authReq.ID, http.StatusFound) return } } @@ -168,7 +166,7 @@ func (s *Server) handleAuthorization(w http.ResponseWriter, r *http.Request) { i++ } - s.templates.login(w, connectorInfos, state) + s.templates.login(w, connectorInfos, authReq.ID) } func (s *Server) handleConnectorLogin(w http.ResponseWriter, r *http.Request) { @@ -179,7 +177,7 @@ func (s *Server) handleConnectorLogin(w http.ResponseWriter, r *http.Request) { return } - authReqID := r.FormValue("state") + authReqID := r.FormValue("req") // TODO(ericchiang): cache user identity. @@ -198,6 +196,9 @@ func (s *Server) handleConnectorLogin(w http.ResponseWriter, r *http.Request) { switch conn := conn.Connector.(type) { case connector.CallbackConnector: + // Use the auth request ID as the "state" token. + // + // TODO(ericchiang): Is this appropriate or should we also be using a nonce? callbackURL, err := conn.LoginURL(s.absURL("/callback"), authReqID) if err != nil { log.Printf("Connector %q returned error when creating callback: %v", connID, err) @@ -342,11 +343,11 @@ func (s *Server) finalizeLogin(identity connector.Identity, authReq storage.Auth if err := s.storage.UpdateAuthRequest(authReq.ID, updater); err != nil { return "", fmt.Errorf("failed to update auth request: %v", err) } - return path.Join(s.issuerURL.Path, "/approval") + "?state=" + authReq.ID, nil + return path.Join(s.issuerURL.Path, "/approval") + "?req=" + authReq.ID, nil } func (s *Server) handleApproval(w http.ResponseWriter, r *http.Request) { - authReq, err := s.storage.GetAuthRequest(r.FormValue("state")) + authReq, err := s.storage.GetAuthRequest(r.FormValue("req")) if err != nil { log.Printf("Failed to get auth request: %v", err) s.renderError(w, http.StatusInternalServerError, errServerError, "") diff --git a/server/templates.go b/server/templates.go index 117d12c5..e8285fe3 100644 --- a/server/templates.go +++ b/server/templates.go @@ -138,29 +138,29 @@ func (n byName) Len() int { return len(n) } func (n byName) Less(i, j int) bool { return n[i].Name < n[j].Name } func (n byName) Swap(i, j int) { n[i], n[j] = n[j], n[i] } -func (t *templates) login(w http.ResponseWriter, connectors []connectorInfo, state string) { +func (t *templates) login(w http.ResponseWriter, connectors []connectorInfo, authReqID string) { sort.Sort(byName(connectors)) data := struct { TemplateConfig Connectors []connectorInfo - State string - }{t.globalData, connectors, state} + AuthReqID string + }{t.globalData, connectors, authReqID} renderTemplate(w, t.loginTmpl, data) } -func (t *templates) password(w http.ResponseWriter, state, callback, lastUsername string, lastWasInvalid bool) { +func (t *templates) password(w http.ResponseWriter, authReqID, callback, lastUsername string, lastWasInvalid bool) { data := struct { TemplateConfig - State string - PostURL string - Username string - Invalid bool - }{t.globalData, state, callback, lastUsername, lastWasInvalid} + AuthReqID string + PostURL string + Username string + Invalid bool + }{t.globalData, authReqID, callback, lastUsername, lastWasInvalid} renderTemplate(w, t.passwordTmpl, data) } -func (t *templates) approval(w http.ResponseWriter, state, username, clientName string, scopes []string) { +func (t *templates) approval(w http.ResponseWriter, authReqID, username, clientName string, scopes []string) { accesses := []string{} for _, scope := range scopes { access, ok := scopeDescriptions[scope] @@ -171,11 +171,11 @@ func (t *templates) approval(w http.ResponseWriter, state, username, clientName sort.Strings(accesses) data := struct { TemplateConfig - User string - Client string - State string - Scopes []string - }{t.globalData, username, clientName, state, accesses} + User string + Client string + AuthReqID string + Scopes []string + }{t.globalData, username, clientName, authReqID, accesses} renderTemplate(w, t.approvalTmpl, data) } diff --git a/server/templates_default.go b/server/templates_default.go index 3a3d031a..651c7411 100644 --- a/server/templates_default.go +++ b/server/templates_default.go @@ -25,7 +25,7 @@ var defaultTemplates = map[string]string{
- +
- +
- + {{ if .Invalid }}