*: verify "state" field before passing request to callback connectors
Let the server handle the state token instead of the connector. As a
result it can throw out bad requests earlier. It can also use that
token to determine which connector was used to generate the request
allowing all connectors to share the same callback URL.
Callbacks now all look like:
https://dex.example.com/callback
Instead of:
https://dex.example.com/callback/(connector id)
Even when multiple connectors are being used.
This commit is contained in:
@@ -2,7 +2,6 @@ package server
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"log"
|
||||
"net/http"
|
||||
@@ -180,14 +179,26 @@ func (s *Server) handleConnectorLogin(w http.ResponseWriter, r *http.Request) {
|
||||
return
|
||||
}
|
||||
|
||||
authReqID := r.FormValue("state")
|
||||
|
||||
// TODO(ericchiang): cache user identity.
|
||||
|
||||
state := r.FormValue("state")
|
||||
switch r.Method {
|
||||
case "GET":
|
||||
// Set the connector being used for the login.
|
||||
updater := func(a storage.AuthRequest) (storage.AuthRequest, error) {
|
||||
a.ConnectorID = connID
|
||||
return a, nil
|
||||
}
|
||||
if err := s.storage.UpdateAuthRequest(authReqID, updater); err != nil {
|
||||
log.Printf("Failed to set connector ID on auth request: %v", err)
|
||||
s.renderError(w, http.StatusInternalServerError, errServerError, "")
|
||||
return
|
||||
}
|
||||
|
||||
switch conn := conn.Connector.(type) {
|
||||
case connector.CallbackConnector:
|
||||
callbackURL, err := conn.LoginURL(s.absURL("/callback", connID), state)
|
||||
callbackURL, err := conn.LoginURL(s.absURL("/callback"), authReqID)
|
||||
if err != nil {
|
||||
log.Printf("Connector %q returned error when creating callback: %v", connID, err)
|
||||
s.renderError(w, http.StatusInternalServerError, errServerError, "")
|
||||
@@ -195,7 +206,7 @@ func (s *Server) handleConnectorLogin(w http.ResponseWriter, r *http.Request) {
|
||||
}
|
||||
http.Redirect(w, r, callbackURL, http.StatusFound)
|
||||
case connector.PasswordConnector:
|
||||
s.templates.password(w, state, r.URL.String(), "", false)
|
||||
s.templates.password(w, authReqID, r.URL.String(), "", false)
|
||||
default:
|
||||
s.notFound(w, r)
|
||||
}
|
||||
@@ -216,10 +227,16 @@ func (s *Server) handleConnectorLogin(w http.ResponseWriter, r *http.Request) {
|
||||
return
|
||||
}
|
||||
if !ok {
|
||||
s.templates.password(w, state, r.URL.String(), username, true)
|
||||
s.templates.password(w, authReqID, r.URL.String(), username, true)
|
||||
return
|
||||
}
|
||||
redirectURL, err := s.finalizeLogin(identity, state, connID, conn.Connector)
|
||||
authReq, err := s.storage.GetAuthRequest(authReqID)
|
||||
if err != nil {
|
||||
log.Printf("Failed to get auth request: %v", err)
|
||||
s.renderError(w, http.StatusInternalServerError, errServerError, "")
|
||||
return
|
||||
}
|
||||
redirectURL, err := s.finalizeLogin(identity, authReq, conn.Connector)
|
||||
if err != nil {
|
||||
log.Printf("Failed to finalize login: %v", err)
|
||||
s.renderError(w, http.StatusInternalServerError, errServerError, "")
|
||||
@@ -233,8 +250,31 @@ func (s *Server) handleConnectorLogin(w http.ResponseWriter, r *http.Request) {
|
||||
}
|
||||
|
||||
func (s *Server) handleConnectorCallback(w http.ResponseWriter, r *http.Request) {
|
||||
connID := mux.Vars(r)["connector"]
|
||||
conn, ok := s.connectors[connID]
|
||||
// SAML redirect bindings use the "RelayState" URL query field. When we support
|
||||
// SAML, we'll have to check that field too and possibly let callback connectors
|
||||
// indicate which field is used to determine the state.
|
||||
//
|
||||
// See:
|
||||
// https://docs.oasis-open.org/security/saml/v2.0/saml-bindings-2.0-os.pdf
|
||||
// Section: "3.4.3 RelayState"
|
||||
state := r.URL.Query().Get("state")
|
||||
if state == "" {
|
||||
s.renderError(w, http.StatusBadRequest, errInvalidRequest, "no 'state' parameter provided")
|
||||
return
|
||||
}
|
||||
|
||||
authReq, err := s.storage.GetAuthRequest(state)
|
||||
if err != nil {
|
||||
if err == storage.ErrNotFound {
|
||||
s.renderError(w, http.StatusBadRequest, errInvalidRequest, "invalid 'state' parameter provided")
|
||||
return
|
||||
}
|
||||
log.Printf("Failed to get auth request: %v", err)
|
||||
s.renderError(w, http.StatusInternalServerError, errServerError, "")
|
||||
return
|
||||
}
|
||||
|
||||
conn, ok := s.connectors[authReq.ConnectorID]
|
||||
if !ok {
|
||||
s.notFound(w, r)
|
||||
return
|
||||
@@ -245,14 +285,14 @@ func (s *Server) handleConnectorCallback(w http.ResponseWriter, r *http.Request)
|
||||
return
|
||||
}
|
||||
|
||||
identity, state, err := callbackConnector.HandleCallback(r)
|
||||
identity, err := callbackConnector.HandleCallback(r)
|
||||
if err != nil {
|
||||
log.Printf("Failed to authenticate: %v", err)
|
||||
s.renderError(w, http.StatusInternalServerError, errServerError, "")
|
||||
return
|
||||
}
|
||||
|
||||
redirectURL, err := s.finalizeLogin(identity, state, connID, conn.Connector)
|
||||
redirectURL, err := s.finalizeLogin(identity, authReq, conn.Connector)
|
||||
if err != nil {
|
||||
log.Printf("Failed to finalize login: %v", err)
|
||||
s.renderError(w, http.StatusInternalServerError, errServerError, "")
|
||||
@@ -262,10 +302,11 @@ func (s *Server) handleConnectorCallback(w http.ResponseWriter, r *http.Request)
|
||||
http.Redirect(w, r, redirectURL, http.StatusSeeOther)
|
||||
}
|
||||
|
||||
func (s *Server) finalizeLogin(identity connector.Identity, authReqID, connectorID string, conn connector.Connector) (string, error) {
|
||||
if authReqID == "" {
|
||||
return "", errors.New("no auth request ID passed")
|
||||
func (s *Server) finalizeLogin(identity connector.Identity, authReq storage.AuthRequest, conn connector.Connector) (string, error) {
|
||||
if authReq.ConnectorID == "" {
|
||||
|
||||
}
|
||||
|
||||
claims := storage.Claims{
|
||||
UserID: identity.UserID,
|
||||
Username: identity.Username,
|
||||
@@ -275,10 +316,6 @@ func (s *Server) finalizeLogin(identity connector.Identity, authReqID, connector
|
||||
|
||||
groupsConn, ok := conn.(connector.GroupsConnector)
|
||||
if ok {
|
||||
authReq, err := s.storage.GetAuthRequest(authReqID)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("get auth request: %v", err)
|
||||
}
|
||||
reqGroups := func() bool {
|
||||
for _, scope := range authReq.Scopes {
|
||||
if scope == scopeGroups {
|
||||
@@ -288,23 +325,24 @@ func (s *Server) finalizeLogin(identity connector.Identity, authReqID, connector
|
||||
return false
|
||||
}()
|
||||
if reqGroups {
|
||||
if claims.Groups, err = groupsConn.Groups(identity); err != nil {
|
||||
groups, err := groupsConn.Groups(identity)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("getting groups: %v", err)
|
||||
}
|
||||
claims.Groups = groups
|
||||
}
|
||||
}
|
||||
|
||||
updater := func(a storage.AuthRequest) (storage.AuthRequest, error) {
|
||||
a.LoggedIn = true
|
||||
a.Claims = claims
|
||||
a.ConnectorID = connectorID
|
||||
a.ConnectorData = identity.ConnectorData
|
||||
return a, nil
|
||||
}
|
||||
if err := s.storage.UpdateAuthRequest(authReqID, updater); err != nil {
|
||||
if err := s.storage.UpdateAuthRequest(authReq.ID, updater); err != nil {
|
||||
return "", fmt.Errorf("failed to update auth request: %v", err)
|
||||
}
|
||||
return path.Join(s.issuerURL.Path, "/approval") + "?state=" + authReqID, nil
|
||||
return path.Join(s.issuerURL.Path, "/approval") + "?state=" + authReq.ID, nil
|
||||
}
|
||||
|
||||
func (s *Server) handleApproval(w http.ResponseWriter, r *http.Request) {
|
||||
|
||||
@@ -172,7 +172,7 @@ func newServer(ctx context.Context, c Config, rotationStrategy rotationStrategy)
|
||||
handleFunc("/keys", s.handlePublicKeys)
|
||||
handleFunc("/auth", s.handleAuthorization)
|
||||
handleFunc("/auth/{connector}", s.handleConnectorLogin)
|
||||
handleFunc("/callback/{connector}", s.handleConnectorCallback)
|
||||
handleFunc("/callback", s.handleConnectorCallback)
|
||||
handleFunc("/approval", s.handleApproval)
|
||||
handleFunc("/healthz", s.handleHealth)
|
||||
s.mux = r
|
||||
|
||||
Reference in New Issue
Block a user