connector: add RefreshConnector interface
This commit is contained in:
		| @@ -1,14 +1,25 @@ | ||||
| // Package connector defines interfaces for federated identity strategies. | ||||
| package connector | ||||
|  | ||||
| import "net/http" | ||||
| import ( | ||||
| 	"net/http" | ||||
|  | ||||
| 	"golang.org/x/net/context" | ||||
| ) | ||||
|  | ||||
| // Connector is a mechanism for federating login to a remote identity service. | ||||
| // | ||||
| // Implementations are expected to implement either the PasswordConnector or | ||||
| // CallbackConnector interface. | ||||
| type Connector interface { | ||||
| 	Close() error | ||||
| type Connector interface{} | ||||
|  | ||||
| // Scopes represents additional data requested by the clients about the end user. | ||||
| type Scopes struct { | ||||
| 	// The client has requested a refresh token from the server. | ||||
| 	OfflineAccess bool | ||||
|  | ||||
| 	// The client has requested group information about the end user. | ||||
| 	Groups bool | ||||
| } | ||||
|  | ||||
| // Identity represents the ID Token claims supported by the server. | ||||
| @@ -18,6 +29,8 @@ type Identity struct { | ||||
| 	Email         string | ||||
| 	EmailVerified bool | ||||
|  | ||||
| 	Groups []string | ||||
|  | ||||
| 	// ConnectorData holds data used by the connector for subsequent requests after initial | ||||
| 	// authentication, such as access tokens for upstream provides. | ||||
| 	// | ||||
| @@ -25,18 +38,38 @@ type Identity struct { | ||||
| 	ConnectorData []byte | ||||
| } | ||||
|  | ||||
| // PasswordConnector is an optional interface for password based connectors. | ||||
| // PasswordConnector is an interface implemented by connectors which take a | ||||
| // username and password. | ||||
| type PasswordConnector interface { | ||||
| 	Login(username, password string) (identity Identity, validPassword bool, err error) | ||||
| 	Login(ctx context.Context, s Scopes, username, password string) (identity Identity, validPassword bool, err error) | ||||
| } | ||||
|  | ||||
| // CallbackConnector is an optional interface for callback based connectors. | ||||
| // CallbackConnector is an interface implemented by connectors which use an OAuth | ||||
| // style redirect flow to determine user information. | ||||
| type CallbackConnector interface { | ||||
| 	LoginURL(callbackURL, state string) (string, error) | ||||
| 	HandleCallback(r *http.Request) (identity Identity, err error) | ||||
| 	// The initial URL to redirect the user to. | ||||
| 	// | ||||
| 	// OAuth2 implementations should request different scopes from the upstream | ||||
| 	// identity provider based on the scopes requested by the downstream client. | ||||
| 	// For example, if the downstream client requests a refresh token from the | ||||
| 	// server, the connector should also request a token from the provider. | ||||
| 	// | ||||
| 	// Many identity providers have arbitrary restrictions on refresh tokens. For | ||||
| 	// example Google only allows a single refresh token per client/user/scopes | ||||
| 	// combination, and wont return a refresh token even if offline access is | ||||
| 	// requested if one has already been issues. There's no good general answer | ||||
| 	// for these kind of restrictions, and may require this package to become more | ||||
| 	// aware of the global set of user/connector interactions. | ||||
| 	LoginURL(s Scopes, callbackURL, state string) (string, error) | ||||
|  | ||||
| 	// Handle the callback to the server and return an identity. | ||||
| 	HandleCallback(s Scopes, r *http.Request) (identity Identity, err error) | ||||
| } | ||||
|  | ||||
| // GroupsConnector is an optional interface for connectors which can map a user to groups. | ||||
| type GroupsConnector interface { | ||||
| 	Groups(identity Identity) ([]string, error) | ||||
| // RefreshConnector is a connector that can update the client claims. | ||||
| type RefreshConnector interface { | ||||
| 	// Refresh is called when a client attempts to claim a refresh token. The | ||||
| 	// connector should attempt to update the identity object to reflect any | ||||
| 	// changes since the token was last refreshed. | ||||
| 	Refresh(ctx context.Context, s Scopes, identity Identity) (Identity, error) | ||||
| } | ||||
|   | ||||
| @@ -3,6 +3,7 @@ package github | ||||
|  | ||||
| import ( | ||||
| 	"encoding/json" | ||||
| 	"errors" | ||||
| 	"fmt" | ||||
| 	"io/ioutil" | ||||
| 	"net/http" | ||||
| @@ -15,7 +16,11 @@ import ( | ||||
| 	"github.com/coreos/dex/connector" | ||||
| ) | ||||
|  | ||||
| const baseURL = "https://api.github.com" | ||||
| const ( | ||||
| 	baseURL    = "https://api.github.com" | ||||
| 	scopeEmail = "user:email" | ||||
| 	scopeOrgs  = "read:org" | ||||
| ) | ||||
|  | ||||
| // Config holds configuration options for github logins. | ||||
| type Config struct { | ||||
| @@ -30,15 +35,8 @@ func (c *Config) Open() (connector.Connector, error) { | ||||
| 	return &githubConnector{ | ||||
| 		redirectURI:  c.RedirectURI, | ||||
| 		org:          c.Org, | ||||
| 		oauth2Config: &oauth2.Config{ | ||||
| 			ClientID:     c.ClientID, | ||||
| 			ClientSecret: c.ClientSecret, | ||||
| 			Endpoint:     github.Endpoint, | ||||
| 			Scopes: []string{ | ||||
| 				"user:email", // View user's email | ||||
| 				"read:org",   // View user's org teams. | ||||
| 			}, | ||||
| 		}, | ||||
| 		clientID:     c.ClientID, | ||||
| 		clientSecret: c.ClientSecret, | ||||
| 	}, nil | ||||
| } | ||||
|  | ||||
| @@ -49,26 +47,36 @@ type connectorData struct { | ||||
|  | ||||
| var ( | ||||
| 	_ connector.CallbackConnector = (*githubConnector)(nil) | ||||
| 	_ connector.GroupsConnector   = (*githubConnector)(nil) | ||||
| 	_ connector.RefreshConnector  = (*githubConnector)(nil) | ||||
| ) | ||||
|  | ||||
| type githubConnector struct { | ||||
| 	redirectURI  string | ||||
| 	org          string | ||||
| 	oauth2Config *oauth2.Config | ||||
| 	ctx          context.Context | ||||
| 	cancel       context.CancelFunc | ||||
| 	clientID     string | ||||
| 	clientSecret string | ||||
| } | ||||
|  | ||||
| func (c *githubConnector) Close() error { | ||||
| 	return nil | ||||
| func (c *githubConnector) oauth2Config(scopes connector.Scopes) *oauth2.Config { | ||||
| 	var githubScopes []string | ||||
| 	if scopes.Groups { | ||||
| 		githubScopes = []string{scopeEmail, scopeOrgs} | ||||
| 	} else { | ||||
| 		githubScopes = []string{scopeEmail} | ||||
| 	} | ||||
| 	return &oauth2.Config{ | ||||
| 		ClientID:     c.clientID, | ||||
| 		ClientSecret: c.clientSecret, | ||||
| 		Endpoint:     github.Endpoint, | ||||
| 		Scopes:       githubScopes, | ||||
| 	} | ||||
| } | ||||
|  | ||||
| func (c *githubConnector) LoginURL(callbackURL, state string) (string, error) { | ||||
| func (c *githubConnector) LoginURL(scopes connector.Scopes, callbackURL, state string) (string, error) { | ||||
| 	if c.redirectURI != callbackURL { | ||||
| 		return "", fmt.Errorf("expected callback URL did not match the URL in the config") | ||||
| 	} | ||||
| 	return c.oauth2Config.AuthCodeURL(state), nil | ||||
| 	return c.oauth2Config(scopes).AuthCodeURL(state), nil | ||||
| } | ||||
|  | ||||
| type oauth2Error struct { | ||||
| @@ -83,43 +91,25 @@ func (e *oauth2Error) Error() string { | ||||
| 	return e.error + ": " + e.errorDescription | ||||
| } | ||||
|  | ||||
| func (c *githubConnector) HandleCallback(r *http.Request) (identity connector.Identity, err error) { | ||||
| func (c *githubConnector) HandleCallback(s connector.Scopes, r *http.Request) (identity connector.Identity, err error) { | ||||
| 	q := r.URL.Query() | ||||
| 	if errType := q.Get("error"); errType != "" { | ||||
| 		return identity, &oauth2Error{errType, q.Get("error_description")} | ||||
| 	} | ||||
| 	token, err := c.oauth2Config.Exchange(c.ctx, q.Get("code")) | ||||
|  | ||||
| 	oauth2Config := c.oauth2Config(s) | ||||
| 	ctx := r.Context() | ||||
|  | ||||
| 	token, err := oauth2Config.Exchange(ctx, q.Get("code")) | ||||
| 	if err != nil { | ||||
| 		return identity, fmt.Errorf("github: failed to get token: %v", err) | ||||
| 	} | ||||
|  | ||||
| 	resp, err := c.oauth2Config.Client(c.ctx, token).Get(baseURL + "/user") | ||||
| 	if err != nil { | ||||
| 		return identity, fmt.Errorf("github: get URL %v", err) | ||||
| 	} | ||||
| 	defer resp.Body.Close() | ||||
| 	client := oauth2Config.Client(ctx, token) | ||||
|  | ||||
| 	if resp.StatusCode != http.StatusOK { | ||||
| 		body, err := ioutil.ReadAll(resp.Body) | ||||
| 	user, err := c.user(ctx, client) | ||||
| 	if err != nil { | ||||
| 			return identity, fmt.Errorf("github: read body: %v", err) | ||||
| 		} | ||||
| 		return identity, fmt.Errorf("%s: %s", resp.Status, body) | ||||
| 	} | ||||
| 	var user struct { | ||||
| 		Name  string `json:"name"` | ||||
| 		Login string `json:"login"` | ||||
| 		ID    int    `json:"id"` | ||||
| 		Email string `json:"email"` | ||||
| 	} | ||||
| 	if err := json.NewDecoder(resp.Body).Decode(&user); err != nil { | ||||
| 		return identity, fmt.Errorf("failed to decode response: %v", err) | ||||
| 	} | ||||
|  | ||||
| 	data := connectorData{AccessToken: token.AccessToken} | ||||
| 	connData, err := json.Marshal(data) | ||||
| 	if err != nil { | ||||
| 		return identity, fmt.Errorf("marshal connector data: %v", err) | ||||
| 		return identity, fmt.Errorf("github: get user: %v", err) | ||||
| 	} | ||||
|  | ||||
| 	username := user.Name | ||||
| @@ -131,22 +121,114 @@ func (c *githubConnector) HandleCallback(r *http.Request) (identity connector.Id | ||||
| 		Username:      username, | ||||
| 		Email:         user.Email, | ||||
| 		EmailVerified: true, | ||||
| 		ConnectorData: connData, | ||||
| 	} | ||||
|  | ||||
| 	if s.Groups && c.org != "" { | ||||
| 		groups, err := c.teams(ctx, client, c.org) | ||||
| 		if err != nil { | ||||
| 			return identity, fmt.Errorf("github: get teams: %v", err) | ||||
| 		} | ||||
| 		identity.Groups = groups | ||||
| 	} | ||||
|  | ||||
| 	if s.OfflineAccess { | ||||
| 		data := connectorData{AccessToken: token.AccessToken} | ||||
| 		connData, err := json.Marshal(data) | ||||
| 		if err != nil { | ||||
| 			return identity, fmt.Errorf("marshal connector data: %v", err) | ||||
| 		} | ||||
| 		identity.ConnectorData = connData | ||||
| 	} | ||||
|  | ||||
| 	return identity, nil | ||||
| } | ||||
|  | ||||
| func (c *githubConnector) Groups(identity connector.Identity) ([]string, error) { | ||||
| 	var data connectorData | ||||
| 	if err := json.Unmarshal(identity.ConnectorData, &data); err != nil { | ||||
| 		return nil, fmt.Errorf("decode connector data: %v", err) | ||||
| func (c *githubConnector) Refresh(ctx context.Context, s connector.Scopes, ident connector.Identity) (connector.Identity, error) { | ||||
| 	if len(ident.ConnectorData) == 0 { | ||||
| 		return ident, errors.New("no upstream access token found") | ||||
| 	} | ||||
| 	token := &oauth2.Token{AccessToken: data.AccessToken} | ||||
| 	resp, err := c.oauth2Config.Client(c.ctx, token).Get(baseURL + "/user/teams") | ||||
|  | ||||
| 	var data connectorData | ||||
| 	if err := json.Unmarshal(ident.ConnectorData, &data); err != nil { | ||||
| 		return ident, fmt.Errorf("github: unmarshal access token: %v", err) | ||||
| 	} | ||||
|  | ||||
| 	client := c.oauth2Config(s).Client(ctx, &oauth2.Token{AccessToken: data.AccessToken}) | ||||
| 	user, err := c.user(ctx, client) | ||||
| 	if err != nil { | ||||
| 		return ident, fmt.Errorf("github: get user: %v", err) | ||||
| 	} | ||||
|  | ||||
| 	username := user.Name | ||||
| 	if username == "" { | ||||
| 		username = user.Login | ||||
| 	} | ||||
| 	ident.Username = username | ||||
| 	ident.Email = user.Email | ||||
|  | ||||
| 	if s.Groups && c.org != "" { | ||||
| 		groups, err := c.teams(ctx, client, c.org) | ||||
| 		if err != nil { | ||||
| 			return ident, fmt.Errorf("github: get teams: %v", err) | ||||
| 		} | ||||
| 		ident.Groups = groups | ||||
| 	} | ||||
| 	return ident, nil | ||||
| } | ||||
|  | ||||
| type user struct { | ||||
| 	Name  string `json:"name"` | ||||
| 	Login string `json:"login"` | ||||
| 	ID    int    `json:"id"` | ||||
| 	Email string `json:"email"` | ||||
| } | ||||
|  | ||||
| // user queries the GitHub API for profile information using the provided client. The HTTP | ||||
| // client is expected to be constructed by the golang.org/x/oauth2 package, which inserts | ||||
| // a bearer token as part of the request. | ||||
| func (c *githubConnector) user(ctx context.Context, client *http.Client) (user, error) { | ||||
| 	var u user | ||||
| 	req, err := http.NewRequest("GET", baseURL+"/user", nil) | ||||
| 	if err != nil { | ||||
| 		return u, fmt.Errorf("github: new req: %v", err) | ||||
| 	} | ||||
| 	req = req.WithContext(ctx) | ||||
| 	resp, err := client.Do(req) | ||||
| 	if err != nil { | ||||
| 		return u, fmt.Errorf("github: get URL %v", err) | ||||
| 	} | ||||
| 	defer resp.Body.Close() | ||||
|  | ||||
| 	if resp.StatusCode != http.StatusOK { | ||||
| 		body, err := ioutil.ReadAll(resp.Body) | ||||
| 		if err != nil { | ||||
| 			return u, fmt.Errorf("github: read body: %v", err) | ||||
| 		} | ||||
| 		return u, fmt.Errorf("%s: %s", resp.Status, body) | ||||
| 	} | ||||
|  | ||||
| 	if err := json.NewDecoder(resp.Body).Decode(&u); err != nil { | ||||
| 		return u, fmt.Errorf("failed to decode response: %v", err) | ||||
| 	} | ||||
| 	return u, nil | ||||
| } | ||||
|  | ||||
| // teams queries the GitHub API for team membership within a specific organization. | ||||
| // | ||||
| // The HTTP passed client is expected to be constructed by the golang.org/x/oauth2 package, | ||||
| // which inserts a bearer token as part of the request. | ||||
| func (c *githubConnector) teams(ctx context.Context, client *http.Client, org string) ([]string, error) { | ||||
| 	req, err := http.NewRequest("GET", baseURL+"/user/teams", nil) | ||||
| 	if err != nil { | ||||
| 		return nil, fmt.Errorf("github: new req: %v", err) | ||||
| 	} | ||||
| 	req = req.WithContext(ctx) | ||||
| 	resp, err := client.Do(req) | ||||
| 	if err != nil { | ||||
| 		return nil, fmt.Errorf("github: get teams: %v", err) | ||||
| 	} | ||||
| 	defer resp.Body.Close() | ||||
|  | ||||
| 	if resp.StatusCode != http.StatusOK { | ||||
| 		body, err := ioutil.ReadAll(resp.Body) | ||||
| 		if err != nil { | ||||
| @@ -167,7 +249,7 @@ func (c *githubConnector) Groups(identity connector.Identity) ([]string, error) | ||||
| 	} | ||||
| 	groups := []string{} | ||||
| 	for _, team := range teams { | ||||
| 		if team.Org.Login == c.org { | ||||
| 		if team.Org.Login == org { | ||||
| 			groups = append(groups, team.Name) | ||||
| 		} | ||||
| 	} | ||||
|   | ||||
| @@ -10,6 +10,7 @@ import ( | ||||
| 	"log" | ||||
| 	"net" | ||||
|  | ||||
| 	"golang.org/x/net/context" | ||||
| 	"gopkg.in/ldap.v2" | ||||
|  | ||||
| 	"github.com/coreos/dex/connector" | ||||
| @@ -57,6 +58,9 @@ type Config struct { | ||||
| 	// Required if LDAP host does not use TLS. | ||||
| 	InsecureNoSSL bool `json:"insecureNoSSL"` | ||||
|  | ||||
| 	// Don't verify the CA. | ||||
| 	InsecureSkipVerify bool `json:"insecureSkipVerify"` | ||||
|  | ||||
| 	// Path to a trusted root certificate file. | ||||
| 	RootCA string `json:"rootCA"` | ||||
|  | ||||
| @@ -139,11 +143,16 @@ func (c *Config) Open() (connector.Connector, error) { | ||||
| 	return connector.Connector(conn), nil | ||||
| } | ||||
|  | ||||
| type refreshData struct { | ||||
| 	Username string     `json:"username"` | ||||
| 	Entry    ldap.Entry `json:"entry"` | ||||
| } | ||||
|  | ||||
| // OpenConnector is the same as Open but returns a type with all implemented connector interfaces. | ||||
| func (c *Config) OpenConnector() (interface { | ||||
| 	connector.Connector | ||||
| 	connector.PasswordConnector | ||||
| 	connector.GroupsConnector | ||||
| 	connector.RefreshConnector | ||||
| }, error) { | ||||
|  | ||||
| 	requiredFields := []struct { | ||||
| @@ -174,7 +183,7 @@ func (c *Config) OpenConnector() (interface { | ||||
| 		} | ||||
| 	} | ||||
|  | ||||
| 	tlsConfig := &tls.Config{ServerName: host} | ||||
| 	tlsConfig := &tls.Config{ServerName: host, InsecureSkipVerify: c.InsecureSkipVerify} | ||||
| 	if c.RootCA != "" || len(c.RootCAData) != 0 { | ||||
| 		data := c.RootCAData | ||||
| 		if len(data) == 0 { | ||||
| @@ -209,12 +218,16 @@ type ldapConnector struct { | ||||
| 	tlsConfig *tls.Config | ||||
| } | ||||
|  | ||||
| var _ connector.PasswordConnector = (*ldapConnector)(nil) | ||||
| var ( | ||||
| 	_ connector.PasswordConnector = (*ldapConnector)(nil) | ||||
| 	_ connector.RefreshConnector  = (*ldapConnector)(nil) | ||||
| ) | ||||
|  | ||||
| // do initializes a connection to the LDAP directory and passes it to the | ||||
| // provided function. It then performs appropriate teardown or reuse before | ||||
| // returning. | ||||
| func (c *ldapConnector) do(f func(c *ldap.Conn) error) error { | ||||
| func (c *ldapConnector) do(ctx context.Context, f func(c *ldap.Conn) error) error { | ||||
| 	// TODO(ericchiang): support context here | ||||
| 	var ( | ||||
| 		conn *ldap.Conn | ||||
| 		err  error | ||||
| @@ -253,13 +266,32 @@ func getAttr(e ldap.Entry, name string) string { | ||||
| 	return "" | ||||
| } | ||||
|  | ||||
| func (c *ldapConnector) Login(username, password string) (ident connector.Identity, validPass bool, err error) { | ||||
| 	var ( | ||||
| 		// We want to return a different error if the user's password is incorrect vs | ||||
| 		// if there was an error. | ||||
| 		incorrectPass = false | ||||
| 		user          ldap.Entry | ||||
| 	) | ||||
| func (c *ldapConnector) identityFromEntry(user ldap.Entry) (ident connector.Identity, err error) { | ||||
| 	// If we're missing any attributes, such as email or ID, we want to report | ||||
| 	// an error rather than continuing. | ||||
| 	missing := []string{} | ||||
|  | ||||
| 	// Fill the identity struct using the attributes from the user entry. | ||||
| 	if ident.UserID = getAttr(user, c.UserSearch.IDAttr); ident.UserID == "" { | ||||
| 		missing = append(missing, c.UserSearch.IDAttr) | ||||
| 	} | ||||
| 	if ident.Email = getAttr(user, c.UserSearch.EmailAttr); ident.Email == "" { | ||||
| 		missing = append(missing, c.UserSearch.EmailAttr) | ||||
| 	} | ||||
| 	if c.UserSearch.NameAttr != "" { | ||||
| 		if ident.Username = getAttr(user, c.UserSearch.NameAttr); ident.Username == "" { | ||||
| 			missing = append(missing, c.UserSearch.NameAttr) | ||||
| 		} | ||||
| 	} | ||||
|  | ||||
| 	if len(missing) != 0 { | ||||
| 		err := fmt.Errorf("ldap: entry %q missing following required attribute(s): %q", user.DN, missing) | ||||
| 		return connector.Identity{}, err | ||||
| 	} | ||||
| 	return ident, nil | ||||
| } | ||||
|  | ||||
| func (c *ldapConnector) userEntry(conn *ldap.Conn, username string) (user ldap.Entry, found bool, err error) { | ||||
|  | ||||
| 	filter := fmt.Sprintf("(%s=%s)", c.UserSearch.Username, ldap.EscapeFilter(username)) | ||||
| 	if c.UserSearch.Filter != "" { | ||||
| @@ -283,24 +315,40 @@ func (c *ldapConnector) Login(username, password string) (ident connector.Identi | ||||
| 	if c.UserSearch.NameAttr != "" { | ||||
| 		req.Attributes = append(req.Attributes, c.UserSearch.NameAttr) | ||||
| 	} | ||||
|  | ||||
| 	err = c.do(func(conn *ldap.Conn) error { | ||||
| 	resp, err := conn.Search(req) | ||||
| 	if err != nil { | ||||
| 			return fmt.Errorf("ldap: search with filter %q failed: %v", req.Filter, err) | ||||
| 		return ldap.Entry{}, false, fmt.Errorf("ldap: search with filter %q failed: %v", req.Filter, err) | ||||
| 	} | ||||
|  | ||||
| 	switch n := len(resp.Entries); n { | ||||
| 	case 0: | ||||
| 		log.Printf("ldap: no results returned for filter: %q", filter) | ||||
| 		return ldap.Entry{}, false, nil | ||||
| 	case 1: | ||||
| 		return *resp.Entries[0], true, nil | ||||
| 	default: | ||||
| 		return ldap.Entry{}, false, fmt.Errorf("ldap: filter returned multiple (%d) results: %q", n, filter) | ||||
| 	} | ||||
| } | ||||
|  | ||||
| func (c *ldapConnector) Login(ctx context.Context, s connector.Scopes, username, password string) (ident connector.Identity, validPass bool, err error) { | ||||
| 	var ( | ||||
| 		// We want to return a different error if the user's password is incorrect vs | ||||
| 		// if there was an error. | ||||
| 		incorrectPass = false | ||||
| 		user          ldap.Entry | ||||
| 	) | ||||
|  | ||||
| 	err = c.do(ctx, func(conn *ldap.Conn) error { | ||||
| 		entry, found, err := c.userEntry(conn, username) | ||||
| 		if err != nil { | ||||
| 			return err | ||||
| 		} | ||||
| 		if !found { | ||||
| 			incorrectPass = true | ||||
| 			return nil | ||||
| 		case 1: | ||||
| 		default: | ||||
| 			return fmt.Errorf("ldap: filter returned multiple (%d) results: %q", n, filter) | ||||
| 		} | ||||
|  | ||||
| 		user = *resp.Entries[0] | ||||
| 		user = entry | ||||
|  | ||||
| 		// Try to authenticate as the distinguished name. | ||||
| 		if err := conn.Bind(user.DN, password); err != nil { | ||||
| @@ -323,44 +371,75 @@ func (c *ldapConnector) Login(username, password string) (ident connector.Identi | ||||
| 		return connector.Identity{}, false, nil | ||||
| 	} | ||||
|  | ||||
| 	if ident, err = c.identityFromEntry(user); err != nil { | ||||
| 		return connector.Identity{}, false, err | ||||
| 	} | ||||
|  | ||||
| 	if s.Groups { | ||||
| 		groups, err := c.groups(ctx, user) | ||||
| 		if err != nil { | ||||
| 			return connector.Identity{}, false, fmt.Errorf("ldap: failed to query groups: %v", err) | ||||
| 		} | ||||
| 		ident.Groups = groups | ||||
| 	} | ||||
|  | ||||
| 	if s.OfflineAccess { | ||||
| 		refresh := refreshData{ | ||||
| 			Username: username, | ||||
| 			Entry:    user, | ||||
| 		} | ||||
| 		// Encode entry for follow up requests such as the groups query and | ||||
| 		// refresh attempts. | ||||
| 	if ident.ConnectorData, err = json.Marshal(user); err != nil { | ||||
| 		if ident.ConnectorData, err = json.Marshal(refresh); err != nil { | ||||
| 			return connector.Identity{}, false, fmt.Errorf("ldap: marshal entry: %v", err) | ||||
| 		} | ||||
|  | ||||
| 	// If we're missing any attributes, such as email or ID, we want to report | ||||
| 	// an error rather than continuing. | ||||
| 	missing := []string{} | ||||
|  | ||||
| 	// Fill the identity struct using the attributes from the user entry. | ||||
| 	if ident.UserID = getAttr(user, c.UserSearch.IDAttr); ident.UserID == "" { | ||||
| 		missing = append(missing, c.UserSearch.IDAttr) | ||||
| 	} | ||||
| 	if ident.Email = getAttr(user, c.UserSearch.EmailAttr); ident.Email == "" { | ||||
| 		missing = append(missing, c.UserSearch.EmailAttr) | ||||
| 	} | ||||
| 	if c.UserSearch.NameAttr != "" { | ||||
| 		if ident.Username = getAttr(user, c.UserSearch.NameAttr); ident.Username == "" { | ||||
| 			missing = append(missing, c.UserSearch.NameAttr) | ||||
| 		} | ||||
| 	} | ||||
|  | ||||
| 	if len(missing) != 0 { | ||||
| 		err := fmt.Errorf("ldap: entry %q missing following required attribute(s): %q", user.DN, missing) | ||||
| 		return connector.Identity{}, false, err | ||||
| 	} | ||||
|  | ||||
| 	return ident, true, nil | ||||
| } | ||||
|  | ||||
| func (c *ldapConnector) Groups(ident connector.Identity) ([]string, error) { | ||||
| 	// Decode the user entry from the identity. | ||||
| 	var user ldap.Entry | ||||
| 	if err := json.Unmarshal(ident.ConnectorData, &user); err != nil { | ||||
| 		return nil, fmt.Errorf("ldap: failed to unmarshal connector data: %v", err) | ||||
| func (c *ldapConnector) Refresh(ctx context.Context, s connector.Scopes, ident connector.Identity) (connector.Identity, error) { | ||||
| 	var data refreshData | ||||
| 	if err := json.Unmarshal(ident.ConnectorData, &data); err != nil { | ||||
| 		return ident, fmt.Errorf("ldap: failed to unamrshal internal data: %v", err) | ||||
| 	} | ||||
|  | ||||
| 	var user ldap.Entry | ||||
| 	err := c.do(ctx, func(conn *ldap.Conn) error { | ||||
| 		entry, found, err := c.userEntry(conn, data.Username) | ||||
| 		if err != nil { | ||||
| 			return err | ||||
| 		} | ||||
| 		if !found { | ||||
| 			return fmt.Errorf("ldap: user not found %q", data.Username) | ||||
| 		} | ||||
| 		user = entry | ||||
| 		return nil | ||||
| 	}) | ||||
| 	if err != nil { | ||||
| 		return ident, err | ||||
| 	} | ||||
| 	if user.DN != data.Entry.DN { | ||||
| 		return ident, fmt.Errorf("ldap: refresh for username %q expected DN %q got %q", data.Username, data.Entry.DN, user.DN) | ||||
| 	} | ||||
|  | ||||
| 	newIdent, err := c.identityFromEntry(user) | ||||
| 	if err != nil { | ||||
| 		return ident, err | ||||
| 	} | ||||
| 	newIdent.ConnectorData = ident.ConnectorData | ||||
|  | ||||
| 	if s.Groups { | ||||
| 		groups, err := c.groups(ctx, user) | ||||
| 		if err != nil { | ||||
| 			return connector.Identity{}, fmt.Errorf("ldap: failed to query groups: %v", err) | ||||
| 		} | ||||
| 		newIdent.Groups = groups | ||||
| 	} | ||||
| 	return newIdent, nil | ||||
| } | ||||
|  | ||||
| func (c *ldapConnector) groups(ctx context.Context, user ldap.Entry) ([]string, error) { | ||||
| 	filter := fmt.Sprintf("(%s=%s)", c.GroupSearch.GroupAttr, ldap.EscapeFilter(getAttr(user, c.GroupSearch.UserAttr))) | ||||
| 	if c.GroupSearch.Filter != "" { | ||||
| 		filter = fmt.Sprintf("(&%s%s)", c.GroupSearch.Filter, filter) | ||||
| @@ -374,7 +453,7 @@ func (c *ldapConnector) Groups(ident connector.Identity) ([]string, error) { | ||||
| 	} | ||||
|  | ||||
| 	var groups []*ldap.Entry | ||||
| 	if err := c.do(func(conn *ldap.Conn) error { | ||||
| 	if err := c.do(ctx, func(conn *ldap.Conn) error { | ||||
| 		resp, err := conn.Search(req) | ||||
| 		if err != nil { | ||||
| 			return fmt.Errorf("ldap: search failed: %v", err) | ||||
| @@ -406,7 +485,3 @@ func (c *ldapConnector) Groups(ident connector.Identity) ([]string, error) { | ||||
| 	} | ||||
| 	return groupNames, nil | ||||
| } | ||||
|  | ||||
| func (c *ldapConnector) Close() error { | ||||
| 	return nil | ||||
| } | ||||
|   | ||||
| @@ -2,12 +2,13 @@ | ||||
| package mock | ||||
|  | ||||
| import ( | ||||
| 	"bytes" | ||||
| 	"errors" | ||||
| 	"fmt" | ||||
| 	"net/http" | ||||
| 	"net/url" | ||||
|  | ||||
| 	"golang.org/x/net/context" | ||||
|  | ||||
| 	"github.com/coreos/dex/connector" | ||||
| ) | ||||
|  | ||||
| @@ -19,7 +20,6 @@ func NewCallbackConnector() connector.Connector { | ||||
|  | ||||
| var ( | ||||
| 	_ connector.CallbackConnector = callbackConnector{} | ||||
| 	_ connector.GroupsConnector   = callbackConnector{} | ||||
|  | ||||
| 	_ connector.PasswordConnector = passwordConnector{} | ||||
| ) | ||||
| @@ -28,7 +28,7 @@ type callbackConnector struct{} | ||||
|  | ||||
| func (m callbackConnector) Close() error { return nil } | ||||
|  | ||||
| func (m callbackConnector) LoginURL(callbackURL, state string) (string, error) { | ||||
| func (m callbackConnector) LoginURL(s connector.Scopes, callbackURL, state string) (string, error) { | ||||
| 	u, err := url.Parse(callbackURL) | ||||
| 	if err != nil { | ||||
| 		return "", fmt.Errorf("failed to parse callbackURL %q: %v", callbackURL, err) | ||||
| @@ -41,23 +41,22 @@ func (m callbackConnector) LoginURL(callbackURL, state string) (string, error) { | ||||
|  | ||||
| var connectorData = []byte("foobar") | ||||
|  | ||||
| func (m callbackConnector) HandleCallback(r *http.Request) (connector.Identity, error) { | ||||
| func (m callbackConnector) HandleCallback(s connector.Scopes, r *http.Request) (connector.Identity, error) { | ||||
| 	var groups []string | ||||
| 	if s.Groups { | ||||
| 		groups = []string{"authors"} | ||||
| 	} | ||||
|  | ||||
| 	return connector.Identity{ | ||||
| 		UserID:        "0-385-28089-0", | ||||
| 		Username:      "Kilgore Trout", | ||||
| 		Email:         "kilgore@kilgore.trout", | ||||
| 		EmailVerified: true, | ||||
| 		Groups:        groups, | ||||
| 		ConnectorData: connectorData, | ||||
| 	}, nil | ||||
| } | ||||
|  | ||||
| func (m callbackConnector) Groups(identity connector.Identity) ([]string, error) { | ||||
| 	if !bytes.Equal(identity.ConnectorData, connectorData) { | ||||
| 		return nil, errors.New("connector data mismatch") | ||||
| 	} | ||||
| 	return []string{"authors"}, nil | ||||
| } | ||||
|  | ||||
| // CallbackConfig holds the configuration parameters for a connector which requires no interaction. | ||||
| type CallbackConfig struct{} | ||||
|  | ||||
| @@ -91,7 +90,7 @@ type passwordConnector struct { | ||||
|  | ||||
| func (p passwordConnector) Close() error { return nil } | ||||
|  | ||||
| func (p passwordConnector) Login(username, password string) (identity connector.Identity, validPassword bool, err error) { | ||||
| func (p passwordConnector) Login(ctx context.Context, s connector.Scopes, username, password string) (identity connector.Identity, validPassword bool, err error) { | ||||
| 	if username == p.username && password == p.password { | ||||
| 		return connector.Identity{ | ||||
| 			UserID:        "0-385-28089-0", | ||||
|   | ||||
| @@ -75,7 +75,7 @@ func (c *oidcConnector) Close() error { | ||||
| 	return nil | ||||
| } | ||||
|  | ||||
| func (c *oidcConnector) LoginURL(callbackURL, state string) (string, error) { | ||||
| func (c *oidcConnector) LoginURL(s connector.Scopes, callbackURL, state string) (string, error) { | ||||
| 	if c.redirectURI != callbackURL { | ||||
| 		return "", fmt.Errorf("expected callback URL did not match the URL in the config") | ||||
| 	} | ||||
| @@ -94,7 +94,7 @@ func (e *oauth2Error) Error() string { | ||||
| 	return e.error + ": " + e.errorDescription | ||||
| } | ||||
|  | ||||
| func (c *oidcConnector) HandleCallback(r *http.Request) (identity connector.Identity, err error) { | ||||
| func (c *oidcConnector) HandleCallback(s connector.Scopes, r *http.Request) (identity connector.Identity, err error) { | ||||
| 	q := r.URL.Query() | ||||
| 	if errType := q.Get("error"); errType != "" { | ||||
| 		return identity, &oauth2Error{errType, q.Get("error_description")} | ||||
|   | ||||
| @@ -179,7 +179,13 @@ func (s *Server) handleConnectorLogin(w http.ResponseWriter, r *http.Request) { | ||||
|  | ||||
| 	authReqID := r.FormValue("req") | ||||
|  | ||||
| 	// TODO(ericchiang): cache user identity. | ||||
| 	authReq, err := s.storage.GetAuthRequest(authReqID) | ||||
| 	if err != nil { | ||||
| 		log.Printf("Failed to get auth request: %v", err) | ||||
| 		s.renderError(w, http.StatusInternalServerError, errServerError, "") | ||||
| 		return | ||||
| 	} | ||||
| 	scopes := parseScopes(authReq.Scopes) | ||||
|  | ||||
| 	switch r.Method { | ||||
| 	case "GET": | ||||
| @@ -199,7 +205,7 @@ func (s *Server) handleConnectorLogin(w http.ResponseWriter, r *http.Request) { | ||||
| 			// Use the auth request ID as the "state" token. | ||||
| 			// | ||||
| 			// TODO(ericchiang): Is this appropriate or should we also be using a nonce? | ||||
| 			callbackURL, err := conn.LoginURL(s.absURL("/callback"), authReqID) | ||||
| 			callbackURL, err := conn.LoginURL(scopes, s.absURL("/callback"), authReqID) | ||||
| 			if err != nil { | ||||
| 				log.Printf("Connector %q returned error when creating callback: %v", connID, err) | ||||
| 				s.renderError(w, http.StatusInternalServerError, errServerError, "") | ||||
| @@ -221,7 +227,7 @@ func (s *Server) handleConnectorLogin(w http.ResponseWriter, r *http.Request) { | ||||
| 		username := r.FormValue("login") | ||||
| 		password := r.FormValue("password") | ||||
|  | ||||
| 		identity, ok, err := passwordConnector.Login(username, password) | ||||
| 		identity, ok, err := passwordConnector.Login(r.Context(), scopes, username, password) | ||||
| 		if err != nil { | ||||
| 			log.Printf("Failed to login user: %v", err) | ||||
| 			s.renderError(w, http.StatusInternalServerError, errServerError, "") | ||||
| @@ -231,12 +237,6 @@ func (s *Server) handleConnectorLogin(w http.ResponseWriter, r *http.Request) { | ||||
| 			s.templates.password(w, authReqID, r.URL.String(), username, true) | ||||
| 			return | ||||
| 		} | ||||
| 		authReq, err := s.storage.GetAuthRequest(authReqID) | ||||
| 		if err != nil { | ||||
| 			log.Printf("Failed to get auth request: %v", err) | ||||
| 			s.renderError(w, http.StatusInternalServerError, errServerError, "") | ||||
| 			return | ||||
| 		} | ||||
| 		redirectURL, err := s.finalizeLogin(identity, authReq, conn.Connector) | ||||
| 		if err != nil { | ||||
| 			log.Printf("Failed to finalize login: %v", err) | ||||
| @@ -286,7 +286,7 @@ func (s *Server) handleConnectorCallback(w http.ResponseWriter, r *http.Request) | ||||
| 		return | ||||
| 	} | ||||
|  | ||||
| 	identity, err := callbackConnector.HandleCallback(r) | ||||
| 	identity, err := callbackConnector.HandleCallback(parseScopes(authReq.Scopes), r) | ||||
| 	if err != nil { | ||||
| 		log.Printf("Failed to authenticate: %v", err) | ||||
| 		s.renderError(w, http.StatusInternalServerError, errServerError, "") | ||||
| @@ -304,34 +304,12 @@ func (s *Server) handleConnectorCallback(w http.ResponseWriter, r *http.Request) | ||||
| } | ||||
|  | ||||
| func (s *Server) finalizeLogin(identity connector.Identity, authReq storage.AuthRequest, conn connector.Connector) (string, error) { | ||||
| 	if authReq.ConnectorID == "" { | ||||
|  | ||||
| 	} | ||||
|  | ||||
| 	claims := storage.Claims{ | ||||
| 		UserID:        identity.UserID, | ||||
| 		Username:      identity.Username, | ||||
| 		Email:         identity.Email, | ||||
| 		EmailVerified: identity.EmailVerified, | ||||
| 	} | ||||
|  | ||||
| 	groupsConn, ok := conn.(connector.GroupsConnector) | ||||
| 	if ok { | ||||
| 		reqGroups := func() bool { | ||||
| 			for _, scope := range authReq.Scopes { | ||||
| 				if scope == scopeGroups { | ||||
| 					return true | ||||
| 				} | ||||
| 			} | ||||
| 			return false | ||||
| 		}() | ||||
| 		if reqGroups { | ||||
| 			groups, err := groupsConn.Groups(identity) | ||||
| 			if err != nil { | ||||
| 				return "", fmt.Errorf("getting groups: %v", err) | ||||
| 			} | ||||
| 			claims.Groups = groups | ||||
| 		} | ||||
| 		Groups:        identity.Groups, | ||||
| 	} | ||||
|  | ||||
| 	updater := func(a storage.AuthRequest) (storage.AuthRequest, error) { | ||||
| @@ -415,6 +393,7 @@ func (s *Server) sendCodeResponse(w http.ResponseWriter, r *http.Request, authRe | ||||
| 				Claims:        authReq.Claims, | ||||
| 				Expiry:        s.now().Add(time.Minute * 30), | ||||
| 				RedirectURI:   authReq.RedirectURI, | ||||
| 				ConnectorData: authReq.ConnectorData, | ||||
| 			} | ||||
| 			if err := s.storage.CreateAuthCode(code); err != nil { | ||||
| 				log.Printf("Failed to create auth code: %v", err) | ||||
| @@ -543,6 +522,7 @@ func (s *Server) handleAuthCode(w http.ResponseWriter, r *http.Request, client s | ||||
| 			Scopes:        authCode.Scopes, | ||||
| 			Claims:        authCode.Claims, | ||||
| 			Nonce:         authCode.Nonce, | ||||
| 			ConnectorData: authCode.ConnectorData, | ||||
| 		} | ||||
| 		if err := s.storage.CreateRefresh(refresh); err != nil { | ||||
| 			log.Printf("failed to create refresh token: %v", err) | ||||
| @@ -574,6 +554,10 @@ func (s *Server) handleRefreshToken(w http.ResponseWriter, r *http.Request, clie | ||||
| 		return | ||||
| 	} | ||||
|  | ||||
| 	// Per the OAuth2 spec, if the client has omitted the scopes, default to the original | ||||
| 	// authorized scopes. | ||||
| 	// | ||||
| 	// https://tools.ietf.org/html/rfc6749#section-6 | ||||
| 	scopes := refresh.Scopes | ||||
| 	if scope != "" { | ||||
| 		requestedScopes := strings.Fields(scope) | ||||
| @@ -601,7 +585,43 @@ func (s *Server) handleRefreshToken(w http.ResponseWriter, r *http.Request, clie | ||||
| 		scopes = requestedScopes | ||||
| 	} | ||||
|  | ||||
| 	// TODO(ericchiang): re-auth with backends | ||||
| 	conn, ok := s.connectors[refresh.ConnectorID] | ||||
| 	if !ok { | ||||
| 		log.Printf("connector ID not found: %q", refresh.ConnectorID) | ||||
| 		tokenErr(w, errServerError, "", http.StatusInternalServerError) | ||||
| 		return | ||||
| 	} | ||||
|  | ||||
| 	// Can the connector refresh the identity? If so, attempt to refresh the data | ||||
| 	// in the connector. | ||||
| 	// | ||||
| 	// TODO(ericchiang): We may want a strict mode where connectors that don't implement | ||||
| 	// this interface can't perform refreshing. | ||||
| 	if refreshConn, ok := conn.Connector.(connector.RefreshConnector); ok { | ||||
| 		ident := connector.Identity{ | ||||
| 			UserID:        refresh.Claims.UserID, | ||||
| 			Username:      refresh.Claims.Username, | ||||
| 			Email:         refresh.Claims.Email, | ||||
| 			EmailVerified: refresh.Claims.EmailVerified, | ||||
| 			Groups:        refresh.Claims.Groups, | ||||
| 			ConnectorData: refresh.ConnectorData, | ||||
| 		} | ||||
| 		ident, err := refreshConn.Refresh(r.Context(), parseScopes(scopes), ident) | ||||
| 		if err != nil { | ||||
| 			log.Printf("failed to refresh identity: %v", err) | ||||
| 			tokenErr(w, errServerError, "", http.StatusInternalServerError) | ||||
| 			return | ||||
| 		} | ||||
|  | ||||
| 		// Update the claims of the refresh token. | ||||
| 		// | ||||
| 		// UserID intentionally ignored for now. | ||||
| 		refresh.Claims.Username = ident.Username | ||||
| 		refresh.Claims.Email = ident.Email | ||||
| 		refresh.Claims.EmailVerified = ident.EmailVerified | ||||
| 		refresh.Claims.Groups = ident.Groups | ||||
| 		refresh.ConnectorData = ident.ConnectorData | ||||
| 	} | ||||
|  | ||||
| 	idToken, expiry, err := s.newIDToken(client.ID, refresh.Claims, scopes, refresh.Nonce) | ||||
| 	if err != nil { | ||||
| @@ -610,6 +630,8 @@ func (s *Server) handleRefreshToken(w http.ResponseWriter, r *http.Request, clie | ||||
| 		return | ||||
| 	} | ||||
|  | ||||
| 	// Refresh tokens are claimed exactly once. Delete the current token and | ||||
| 	// create a new one. | ||||
| 	if err := s.storage.DeleteRefresh(code); err != nil { | ||||
| 		log.Printf("failed to delete auth code: %v", err) | ||||
| 		tokenErr(w, errServerError, "", http.StatusInternalServerError) | ||||
|   | ||||
| @@ -10,6 +10,7 @@ import ( | ||||
| 	"strings" | ||||
| 	"time" | ||||
|  | ||||
| 	"github.com/coreos/dex/connector" | ||||
| 	"github.com/coreos/dex/storage" | ||||
| ) | ||||
|  | ||||
| @@ -93,6 +94,19 @@ const ( | ||||
| 	responseTypeIDToken = "id_token" // ID Token in url fragment | ||||
| ) | ||||
|  | ||||
| func parseScopes(scopes []string) connector.Scopes { | ||||
| 	var s connector.Scopes | ||||
| 	for _, scope := range scopes { | ||||
| 		switch scope { | ||||
| 		case scopeOfflineAccess: | ||||
| 			s.OfflineAccess = true | ||||
| 		case scopeGroups: | ||||
| 			s.Groups = true | ||||
| 		} | ||||
| 	} | ||||
| 	return s | ||||
| } | ||||
|  | ||||
| type audience []string | ||||
|  | ||||
| func (a audience) MarshalJSON() ([]byte, error) { | ||||
|   | ||||
| @@ -211,9 +211,7 @@ type passwordDB struct { | ||||
| 	s storage.Storage | ||||
| } | ||||
|  | ||||
| func (db passwordDB) Close() error { return nil } | ||||
|  | ||||
| func (db passwordDB) Login(email, password string) (connector.Identity, bool, error) { | ||||
| func (db passwordDB) Login(ctx context.Context, s connector.Scopes, email, password string) (connector.Identity, bool, error) { | ||||
| 	p, err := db.s.GetPassword(email) | ||||
| 	if err != nil { | ||||
| 		if err != storage.ErrNotFound { | ||||
| @@ -233,6 +231,31 @@ func (db passwordDB) Login(email, password string) (connector.Identity, bool, er | ||||
| 	}, true, nil | ||||
| } | ||||
|  | ||||
| func (db passwordDB) Refresh(ctx context.Context, s connector.Scopes, identity connector.Identity) (connector.Identity, error) { | ||||
| 	// If the user has been deleted, the refresh token will be rejected. | ||||
| 	p, err := db.s.GetPassword(identity.Email) | ||||
| 	if err != nil { | ||||
| 		if err == storage.ErrNotFound { | ||||
| 			return connector.Identity{}, errors.New("user not found") | ||||
| 		} | ||||
| 		return connector.Identity{}, fmt.Errorf("get password: %v", err) | ||||
| 	} | ||||
|  | ||||
| 	// User removed but a new user with the same email exists. | ||||
| 	if p.UserID != identity.UserID { | ||||
| 		return connector.Identity{}, errors.New("user not found") | ||||
| 	} | ||||
|  | ||||
| 	// If a user has updated their username, that will be reflected in the | ||||
| 	// refreshed token. | ||||
| 	// | ||||
| 	// No other fields are expected to be refreshable as email is effectively used | ||||
| 	// as an ID and this implementation doesn't deal with groups. | ||||
| 	identity.Username = p.Username | ||||
|  | ||||
| 	return identity, nil | ||||
| } | ||||
|  | ||||
| // newKeyCacher returns a storage which caches keys so long as the next | ||||
| func newKeyCacher(s storage.Storage, now func() time.Time) storage.Storage { | ||||
| 	if now == nil { | ||||
|   | ||||
| @@ -662,7 +662,6 @@ func TestCrossClientScopes(t *testing.T) { | ||||
| func TestPasswordDB(t *testing.T) { | ||||
| 	s := memory.New() | ||||
| 	conn := newPasswordDB(s) | ||||
| 	defer conn.Close() | ||||
|  | ||||
| 	pw := "hi" | ||||
|  | ||||
| @@ -712,7 +711,7 @@ func TestPasswordDB(t *testing.T) { | ||||
| 	} | ||||
|  | ||||
| 	for _, tc := range tests { | ||||
| 		ident, valid, err := conn.Login(tc.username, tc.password) | ||||
| 		ident, valid, err := conn.Login(context.Background(), connector.Scopes{}, tc.username, tc.password) | ||||
| 		if err != nil { | ||||
| 			if !tc.wantErr { | ||||
| 				t.Errorf("%s: %v", tc.name, err) | ||||
|   | ||||
		Reference in New Issue
	
	Block a user