From a3235d022a79ba199687701c9ba9ec39f5ec7047 Mon Sep 17 00:00:00 2001 From: Eric Chiang Date: Thu, 27 Oct 2016 10:08:08 -0700 Subject: [PATCH] *: verify "state" field before passing request to callback connectors Let the server handle the state token instead of the connector. As a result it can throw out bad requests earlier. It can also use that token to determine which connector was used to generate the request allowing all connectors to share the same callback URL. Callbacks now all look like: https://dex.example.com/callback Instead of: https://dex.example.com/callback/(connector id) Even when multiple connectors are being used. --- connector/connector.go | 2 +- connector/github/github.go | 18 ++++---- connector/mock/connectortest.go | 4 +- connector/oidc/oidc.go | 14 +++--- server/handlers.go | 80 ++++++++++++++++++++++++--------- server/server.go | 2 +- 6 files changed, 79 insertions(+), 41 deletions(-) diff --git a/connector/connector.go b/connector/connector.go index 8235caae..9f84d3e6 100644 --- a/connector/connector.go +++ b/connector/connector.go @@ -33,7 +33,7 @@ type PasswordConnector interface { // CallbackConnector is an optional interface for callback based connectors. type CallbackConnector interface { LoginURL(callbackURL, state string) (string, error) - HandleCallback(r *http.Request) (identity Identity, state string, err error) + HandleCallback(r *http.Request) (identity Identity, err error) } // GroupsConnector is an optional interface for connectors which can map a user to groups. diff --git a/connector/github/github.go b/connector/github/github.go index b679e0ed..172b908d 100644 --- a/connector/github/github.go +++ b/connector/github/github.go @@ -84,28 +84,28 @@ func (e *oauth2Error) Error() string { return e.error + ": " + e.errorDescription } -func (c *githubConnector) HandleCallback(r *http.Request) (identity connector.Identity, state string, err error) { +func (c *githubConnector) HandleCallback(r *http.Request) (identity connector.Identity, err error) { q := r.URL.Query() if errType := q.Get("error"); errType != "" { - return identity, "", &oauth2Error{errType, q.Get("error_description")} + return identity, &oauth2Error{errType, q.Get("error_description")} } token, err := c.oauth2Config.Exchange(c.ctx, q.Get("code")) if err != nil { - return identity, "", fmt.Errorf("github: failed to get token: %v", err) + return identity, fmt.Errorf("github: failed to get token: %v", err) } resp, err := c.oauth2Config.Client(c.ctx, token).Get(baseURL + "/user") if err != nil { - return identity, "", fmt.Errorf("github: get URL %v", err) + return identity, fmt.Errorf("github: get URL %v", err) } defer resp.Body.Close() if resp.StatusCode != http.StatusOK { body, err := ioutil.ReadAll(resp.Body) if err != nil { - return identity, "", fmt.Errorf("github: read body: %v", err) + return identity, fmt.Errorf("github: read body: %v", err) } - return identity, "", fmt.Errorf("%s: %s", resp.Status, body) + return identity, fmt.Errorf("%s: %s", resp.Status, body) } var user struct { Name string `json:"name"` @@ -114,13 +114,13 @@ func (c *githubConnector) HandleCallback(r *http.Request) (identity connector.Id Email string `json:"email"` } if err := json.NewDecoder(resp.Body).Decode(&user); err != nil { - return identity, "", fmt.Errorf("failed to decode response: %v", err) + return identity, fmt.Errorf("failed to decode response: %v", err) } data := connectorData{AccessToken: token.AccessToken} connData, err := json.Marshal(data) if err != nil { - return identity, "", fmt.Errorf("marshal connector data: %v", err) + return identity, fmt.Errorf("marshal connector data: %v", err) } username := user.Name @@ -134,7 +134,7 @@ func (c *githubConnector) HandleCallback(r *http.Request) (identity connector.Id EmailVerified: true, ConnectorData: connData, } - return identity, q.Get("state"), nil + return identity, nil } func (c *githubConnector) Groups(identity connector.Identity) ([]string, error) { diff --git a/connector/mock/connectortest.go b/connector/mock/connectortest.go index 0d4b87ba..1ceeb361 100644 --- a/connector/mock/connectortest.go +++ b/connector/mock/connectortest.go @@ -41,14 +41,14 @@ func (m callbackConnector) LoginURL(callbackURL, state string) (string, error) { var connectorData = []byte("foobar") -func (m callbackConnector) HandleCallback(r *http.Request) (connector.Identity, string, error) { +func (m callbackConnector) HandleCallback(r *http.Request) (connector.Identity, error) { return connector.Identity{ UserID: "0-385-28089-0", Username: "Kilgore Trout", Email: "kilgore@kilgore.trout", EmailVerified: true, ConnectorData: connectorData, - }, r.URL.Query().Get("state"), nil + }, nil } func (m callbackConnector) Groups(identity connector.Identity) ([]string, error) { diff --git a/connector/oidc/oidc.go b/connector/oidc/oidc.go index 6f19b7d4..d7fede04 100644 --- a/connector/oidc/oidc.go +++ b/connector/oidc/oidc.go @@ -95,23 +95,23 @@ func (e *oauth2Error) Error() string { return e.error + ": " + e.errorDescription } -func (c *oidcConnector) HandleCallback(r *http.Request) (identity connector.Identity, state string, err error) { +func (c *oidcConnector) HandleCallback(r *http.Request) (identity connector.Identity, err error) { q := r.URL.Query() if errType := q.Get("error"); errType != "" { - return identity, "", &oauth2Error{errType, q.Get("error_description")} + return identity, &oauth2Error{errType, q.Get("error_description")} } token, err := c.oauth2Config.Exchange(c.ctx, q.Get("code")) if err != nil { - return identity, "", fmt.Errorf("oidc: failed to get token: %v", err) + return identity, fmt.Errorf("oidc: failed to get token: %v", err) } rawIDToken, ok := token.Extra("id_token").(string) if !ok { - return identity, "", errors.New("oidc: no id_token in token response") + return identity, errors.New("oidc: no id_token in token response") } idToken, err := c.verifier.Verify(rawIDToken) if err != nil { - return identity, "", fmt.Errorf("oidc: failed to verify ID Token: %v", err) + return identity, fmt.Errorf("oidc: failed to verify ID Token: %v", err) } var claims struct { @@ -120,7 +120,7 @@ func (c *oidcConnector) HandleCallback(r *http.Request) (identity connector.Iden EmailVerified bool `json:"email_verified"` } if err := idToken.Claims(&claims); err != nil { - return identity, "", fmt.Errorf("oidc: failed to decode claims: %v", err) + return identity, fmt.Errorf("oidc: failed to decode claims: %v", err) } identity = connector.Identity{ @@ -129,5 +129,5 @@ func (c *oidcConnector) HandleCallback(r *http.Request) (identity connector.Iden Email: claims.Email, EmailVerified: claims.EmailVerified, } - return identity, q.Get("state"), nil + return identity, nil } diff --git a/server/handlers.go b/server/handlers.go index 83ef5376..6017b126 100644 --- a/server/handlers.go +++ b/server/handlers.go @@ -2,7 +2,6 @@ package server import ( "encoding/json" - "errors" "fmt" "log" "net/http" @@ -180,14 +179,26 @@ func (s *Server) handleConnectorLogin(w http.ResponseWriter, r *http.Request) { return } + authReqID := r.FormValue("state") + // TODO(ericchiang): cache user identity. - state := r.FormValue("state") switch r.Method { case "GET": + // Set the connector being used for the login. + updater := func(a storage.AuthRequest) (storage.AuthRequest, error) { + a.ConnectorID = connID + return a, nil + } + if err := s.storage.UpdateAuthRequest(authReqID, updater); err != nil { + log.Printf("Failed to set connector ID on auth request: %v", err) + s.renderError(w, http.StatusInternalServerError, errServerError, "") + return + } + switch conn := conn.Connector.(type) { case connector.CallbackConnector: - callbackURL, err := conn.LoginURL(s.absURL("/callback", connID), state) + callbackURL, err := conn.LoginURL(s.absURL("/callback"), authReqID) if err != nil { log.Printf("Connector %q returned error when creating callback: %v", connID, err) s.renderError(w, http.StatusInternalServerError, errServerError, "") @@ -195,7 +206,7 @@ func (s *Server) handleConnectorLogin(w http.ResponseWriter, r *http.Request) { } http.Redirect(w, r, callbackURL, http.StatusFound) case connector.PasswordConnector: - s.templates.password(w, state, r.URL.String(), "", false) + s.templates.password(w, authReqID, r.URL.String(), "", false) default: s.notFound(w, r) } @@ -216,10 +227,16 @@ func (s *Server) handleConnectorLogin(w http.ResponseWriter, r *http.Request) { return } if !ok { - s.templates.password(w, state, r.URL.String(), username, true) + s.templates.password(w, authReqID, r.URL.String(), username, true) return } - redirectURL, err := s.finalizeLogin(identity, state, connID, conn.Connector) + authReq, err := s.storage.GetAuthRequest(authReqID) + if err != nil { + log.Printf("Failed to get auth request: %v", err) + s.renderError(w, http.StatusInternalServerError, errServerError, "") + return + } + redirectURL, err := s.finalizeLogin(identity, authReq, conn.Connector) if err != nil { log.Printf("Failed to finalize login: %v", err) s.renderError(w, http.StatusInternalServerError, errServerError, "") @@ -233,8 +250,31 @@ func (s *Server) handleConnectorLogin(w http.ResponseWriter, r *http.Request) { } func (s *Server) handleConnectorCallback(w http.ResponseWriter, r *http.Request) { - connID := mux.Vars(r)["connector"] - conn, ok := s.connectors[connID] + // SAML redirect bindings use the "RelayState" URL query field. When we support + // SAML, we'll have to check that field too and possibly let callback connectors + // indicate which field is used to determine the state. + // + // See: + // https://docs.oasis-open.org/security/saml/v2.0/saml-bindings-2.0-os.pdf + // Section: "3.4.3 RelayState" + state := r.URL.Query().Get("state") + if state == "" { + s.renderError(w, http.StatusBadRequest, errInvalidRequest, "no 'state' parameter provided") + return + } + + authReq, err := s.storage.GetAuthRequest(state) + if err != nil { + if err == storage.ErrNotFound { + s.renderError(w, http.StatusBadRequest, errInvalidRequest, "invalid 'state' parameter provided") + return + } + log.Printf("Failed to get auth request: %v", err) + s.renderError(w, http.StatusInternalServerError, errServerError, "") + return + } + + conn, ok := s.connectors[authReq.ConnectorID] if !ok { s.notFound(w, r) return @@ -245,14 +285,14 @@ func (s *Server) handleConnectorCallback(w http.ResponseWriter, r *http.Request) return } - identity, state, err := callbackConnector.HandleCallback(r) + identity, err := callbackConnector.HandleCallback(r) if err != nil { log.Printf("Failed to authenticate: %v", err) s.renderError(w, http.StatusInternalServerError, errServerError, "") return } - redirectURL, err := s.finalizeLogin(identity, state, connID, conn.Connector) + redirectURL, err := s.finalizeLogin(identity, authReq, conn.Connector) if err != nil { log.Printf("Failed to finalize login: %v", err) s.renderError(w, http.StatusInternalServerError, errServerError, "") @@ -262,10 +302,11 @@ func (s *Server) handleConnectorCallback(w http.ResponseWriter, r *http.Request) http.Redirect(w, r, redirectURL, http.StatusSeeOther) } -func (s *Server) finalizeLogin(identity connector.Identity, authReqID, connectorID string, conn connector.Connector) (string, error) { - if authReqID == "" { - return "", errors.New("no auth request ID passed") +func (s *Server) finalizeLogin(identity connector.Identity, authReq storage.AuthRequest, conn connector.Connector) (string, error) { + if authReq.ConnectorID == "" { + } + claims := storage.Claims{ UserID: identity.UserID, Username: identity.Username, @@ -275,10 +316,6 @@ func (s *Server) finalizeLogin(identity connector.Identity, authReqID, connector groupsConn, ok := conn.(connector.GroupsConnector) if ok { - authReq, err := s.storage.GetAuthRequest(authReqID) - if err != nil { - return "", fmt.Errorf("get auth request: %v", err) - } reqGroups := func() bool { for _, scope := range authReq.Scopes { if scope == scopeGroups { @@ -288,23 +325,24 @@ func (s *Server) finalizeLogin(identity connector.Identity, authReqID, connector return false }() if reqGroups { - if claims.Groups, err = groupsConn.Groups(identity); err != nil { + groups, err := groupsConn.Groups(identity) + if err != nil { return "", fmt.Errorf("getting groups: %v", err) } + claims.Groups = groups } } updater := func(a storage.AuthRequest) (storage.AuthRequest, error) { a.LoggedIn = true a.Claims = claims - a.ConnectorID = connectorID a.ConnectorData = identity.ConnectorData return a, nil } - if err := s.storage.UpdateAuthRequest(authReqID, updater); err != nil { + if err := s.storage.UpdateAuthRequest(authReq.ID, updater); err != nil { return "", fmt.Errorf("failed to update auth request: %v", err) } - return path.Join(s.issuerURL.Path, "/approval") + "?state=" + authReqID, nil + return path.Join(s.issuerURL.Path, "/approval") + "?state=" + authReq.ID, nil } func (s *Server) handleApproval(w http.ResponseWriter, r *http.Request) { diff --git a/server/server.go b/server/server.go index 628dfb46..603a23cb 100644 --- a/server/server.go +++ b/server/server.go @@ -172,7 +172,7 @@ func newServer(ctx context.Context, c Config, rotationStrategy rotationStrategy) handleFunc("/keys", s.handlePublicKeys) handleFunc("/auth", s.handleAuthorization) handleFunc("/auth/{connector}", s.handleConnectorLogin) - handleFunc("/callback/{connector}", s.handleConnectorCallback) + handleFunc("/callback", s.handleConnectorCallback) handleFunc("/approval", s.handleApproval) handleFunc("/healthz", s.handleHealth) s.mux = r