*: verify "state" field before passing request to callback connectors

Let the server handle the state token instead of the connector. As a
result it can throw out bad requests earlier. It can also use that
token to determine which connector was used to generate the request
allowing all connectors to share the same callback URL.

Callbacks now all look like:

    https://dex.example.com/callback

Instead of:

    https://dex.example.com/callback/(connector id)

Even when multiple connectors are being used.
This commit is contained in:
Eric Chiang
2016-10-27 10:08:08 -07:00
parent ba9f6c6cd6
commit a3235d022a
6 changed files with 79 additions and 41 deletions

View File

@@ -33,7 +33,7 @@ type PasswordConnector interface {
// CallbackConnector is an optional interface for callback based connectors.
type CallbackConnector interface {
LoginURL(callbackURL, state string) (string, error)
HandleCallback(r *http.Request) (identity Identity, state string, err error)
HandleCallback(r *http.Request) (identity Identity, err error)
}
// GroupsConnector is an optional interface for connectors which can map a user to groups.

View File

@@ -84,28 +84,28 @@ func (e *oauth2Error) Error() string {
return e.error + ": " + e.errorDescription
}
func (c *githubConnector) HandleCallback(r *http.Request) (identity connector.Identity, state string, err error) {
func (c *githubConnector) HandleCallback(r *http.Request) (identity connector.Identity, err error) {
q := r.URL.Query()
if errType := q.Get("error"); errType != "" {
return identity, "", &oauth2Error{errType, q.Get("error_description")}
return identity, &oauth2Error{errType, q.Get("error_description")}
}
token, err := c.oauth2Config.Exchange(c.ctx, q.Get("code"))
if err != nil {
return identity, "", fmt.Errorf("github: failed to get token: %v", err)
return identity, fmt.Errorf("github: failed to get token: %v", err)
}
resp, err := c.oauth2Config.Client(c.ctx, token).Get(baseURL + "/user")
if err != nil {
return identity, "", fmt.Errorf("github: get URL %v", err)
return identity, fmt.Errorf("github: get URL %v", err)
}
defer resp.Body.Close()
if resp.StatusCode != http.StatusOK {
body, err := ioutil.ReadAll(resp.Body)
if err != nil {
return identity, "", fmt.Errorf("github: read body: %v", err)
return identity, fmt.Errorf("github: read body: %v", err)
}
return identity, "", fmt.Errorf("%s: %s", resp.Status, body)
return identity, fmt.Errorf("%s: %s", resp.Status, body)
}
var user struct {
Name string `json:"name"`
@@ -114,13 +114,13 @@ func (c *githubConnector) HandleCallback(r *http.Request) (identity connector.Id
Email string `json:"email"`
}
if err := json.NewDecoder(resp.Body).Decode(&user); err != nil {
return identity, "", fmt.Errorf("failed to decode response: %v", err)
return identity, fmt.Errorf("failed to decode response: %v", err)
}
data := connectorData{AccessToken: token.AccessToken}
connData, err := json.Marshal(data)
if err != nil {
return identity, "", fmt.Errorf("marshal connector data: %v", err)
return identity, fmt.Errorf("marshal connector data: %v", err)
}
username := user.Name
@@ -134,7 +134,7 @@ func (c *githubConnector) HandleCallback(r *http.Request) (identity connector.Id
EmailVerified: true,
ConnectorData: connData,
}
return identity, q.Get("state"), nil
return identity, nil
}
func (c *githubConnector) Groups(identity connector.Identity) ([]string, error) {

View File

@@ -41,14 +41,14 @@ func (m callbackConnector) LoginURL(callbackURL, state string) (string, error) {
var connectorData = []byte("foobar")
func (m callbackConnector) HandleCallback(r *http.Request) (connector.Identity, string, error) {
func (m callbackConnector) HandleCallback(r *http.Request) (connector.Identity, error) {
return connector.Identity{
UserID: "0-385-28089-0",
Username: "Kilgore Trout",
Email: "kilgore@kilgore.trout",
EmailVerified: true,
ConnectorData: connectorData,
}, r.URL.Query().Get("state"), nil
}, nil
}
func (m callbackConnector) Groups(identity connector.Identity) ([]string, error) {

View File

@@ -95,23 +95,23 @@ func (e *oauth2Error) Error() string {
return e.error + ": " + e.errorDescription
}
func (c *oidcConnector) HandleCallback(r *http.Request) (identity connector.Identity, state string, err error) {
func (c *oidcConnector) HandleCallback(r *http.Request) (identity connector.Identity, err error) {
q := r.URL.Query()
if errType := q.Get("error"); errType != "" {
return identity, "", &oauth2Error{errType, q.Get("error_description")}
return identity, &oauth2Error{errType, q.Get("error_description")}
}
token, err := c.oauth2Config.Exchange(c.ctx, q.Get("code"))
if err != nil {
return identity, "", fmt.Errorf("oidc: failed to get token: %v", err)
return identity, fmt.Errorf("oidc: failed to get token: %v", err)
}
rawIDToken, ok := token.Extra("id_token").(string)
if !ok {
return identity, "", errors.New("oidc: no id_token in token response")
return identity, errors.New("oidc: no id_token in token response")
}
idToken, err := c.verifier.Verify(rawIDToken)
if err != nil {
return identity, "", fmt.Errorf("oidc: failed to verify ID Token: %v", err)
return identity, fmt.Errorf("oidc: failed to verify ID Token: %v", err)
}
var claims struct {
@@ -120,7 +120,7 @@ func (c *oidcConnector) HandleCallback(r *http.Request) (identity connector.Iden
EmailVerified bool `json:"email_verified"`
}
if err := idToken.Claims(&claims); err != nil {
return identity, "", fmt.Errorf("oidc: failed to decode claims: %v", err)
return identity, fmt.Errorf("oidc: failed to decode claims: %v", err)
}
identity = connector.Identity{
@@ -129,5 +129,5 @@ func (c *oidcConnector) HandleCallback(r *http.Request) (identity connector.Iden
Email: claims.Email,
EmailVerified: claims.EmailVerified,
}
return identity, q.Get("state"), nil
return identity, nil
}