connector: add RefreshConnector interface

This commit is contained in:
Eric Chiang
2016-11-18 13:40:41 -08:00
parent 27fb7c523e
commit 952e0f81f5
9 changed files with 438 additions and 191 deletions

View File

@@ -1,14 +1,25 @@
// Package connector defines interfaces for federated identity strategies.
package connector
import "net/http"
import (
"net/http"
"golang.org/x/net/context"
)
// Connector is a mechanism for federating login to a remote identity service.
//
// Implementations are expected to implement either the PasswordConnector or
// CallbackConnector interface.
type Connector interface {
Close() error
type Connector interface{}
// Scopes represents additional data requested by the clients about the end user.
type Scopes struct {
// The client has requested a refresh token from the server.
OfflineAccess bool
// The client has requested group information about the end user.
Groups bool
}
// Identity represents the ID Token claims supported by the server.
@@ -18,6 +29,8 @@ type Identity struct {
Email string
EmailVerified bool
Groups []string
// ConnectorData holds data used by the connector for subsequent requests after initial
// authentication, such as access tokens for upstream provides.
//
@@ -25,18 +38,38 @@ type Identity struct {
ConnectorData []byte
}
// PasswordConnector is an optional interface for password based connectors.
// PasswordConnector is an interface implemented by connectors which take a
// username and password.
type PasswordConnector interface {
Login(username, password string) (identity Identity, validPassword bool, err error)
Login(ctx context.Context, s Scopes, username, password string) (identity Identity, validPassword bool, err error)
}
// CallbackConnector is an optional interface for callback based connectors.
// CallbackConnector is an interface implemented by connectors which use an OAuth
// style redirect flow to determine user information.
type CallbackConnector interface {
LoginURL(callbackURL, state string) (string, error)
HandleCallback(r *http.Request) (identity Identity, err error)
// The initial URL to redirect the user to.
//
// OAuth2 implementations should request different scopes from the upstream
// identity provider based on the scopes requested by the downstream client.
// For example, if the downstream client requests a refresh token from the
// server, the connector should also request a token from the provider.
//
// Many identity providers have arbitrary restrictions on refresh tokens. For
// example Google only allows a single refresh token per client/user/scopes
// combination, and wont return a refresh token even if offline access is
// requested if one has already been issues. There's no good general answer
// for these kind of restrictions, and may require this package to become more
// aware of the global set of user/connector interactions.
LoginURL(s Scopes, callbackURL, state string) (string, error)
// Handle the callback to the server and return an identity.
HandleCallback(s Scopes, r *http.Request) (identity Identity, err error)
}
// GroupsConnector is an optional interface for connectors which can map a user to groups.
type GroupsConnector interface {
Groups(identity Identity) ([]string, error)
// RefreshConnector is a connector that can update the client claims.
type RefreshConnector interface {
// Refresh is called when a client attempts to claim a refresh token. The
// connector should attempt to update the identity object to reflect any
// changes since the token was last refreshed.
Refresh(ctx context.Context, s Scopes, identity Identity) (Identity, error)
}

View File

@@ -3,6 +3,7 @@ package github
import (
"encoding/json"
"errors"
"fmt"
"io/ioutil"
"net/http"
@@ -15,7 +16,11 @@ import (
"github.com/coreos/dex/connector"
)
const baseURL = "https://api.github.com"
const (
baseURL = "https://api.github.com"
scopeEmail = "user:email"
scopeOrgs = "read:org"
)
// Config holds configuration options for github logins.
type Config struct {
@@ -28,17 +33,10 @@ type Config struct {
// Open returns a strategy for logging in through GitHub.
func (c *Config) Open() (connector.Connector, error) {
return &githubConnector{
redirectURI: c.RedirectURI,
org: c.Org,
oauth2Config: &oauth2.Config{
ClientID: c.ClientID,
ClientSecret: c.ClientSecret,
Endpoint: github.Endpoint,
Scopes: []string{
"user:email", // View user's email
"read:org", // View user's org teams.
},
},
redirectURI: c.RedirectURI,
org: c.Org,
clientID: c.ClientID,
clientSecret: c.ClientSecret,
}, nil
}
@@ -49,26 +47,36 @@ type connectorData struct {
var (
_ connector.CallbackConnector = (*githubConnector)(nil)
_ connector.GroupsConnector = (*githubConnector)(nil)
_ connector.RefreshConnector = (*githubConnector)(nil)
)
type githubConnector struct {
redirectURI string
org string
oauth2Config *oauth2.Config
ctx context.Context
cancel context.CancelFunc
clientID string
clientSecret string
}
func (c *githubConnector) Close() error {
return nil
func (c *githubConnector) oauth2Config(scopes connector.Scopes) *oauth2.Config {
var githubScopes []string
if scopes.Groups {
githubScopes = []string{scopeEmail, scopeOrgs}
} else {
githubScopes = []string{scopeEmail}
}
return &oauth2.Config{
ClientID: c.clientID,
ClientSecret: c.clientSecret,
Endpoint: github.Endpoint,
Scopes: githubScopes,
}
}
func (c *githubConnector) LoginURL(callbackURL, state string) (string, error) {
func (c *githubConnector) LoginURL(scopes connector.Scopes, callbackURL, state string) (string, error) {
if c.redirectURI != callbackURL {
return "", fmt.Errorf("expected callback URL did not match the URL in the config")
}
return c.oauth2Config.AuthCodeURL(state), nil
return c.oauth2Config(scopes).AuthCodeURL(state), nil
}
type oauth2Error struct {
@@ -83,43 +91,25 @@ func (e *oauth2Error) Error() string {
return e.error + ": " + e.errorDescription
}
func (c *githubConnector) HandleCallback(r *http.Request) (identity connector.Identity, err error) {
func (c *githubConnector) HandleCallback(s connector.Scopes, r *http.Request) (identity connector.Identity, err error) {
q := r.URL.Query()
if errType := q.Get("error"); errType != "" {
return identity, &oauth2Error{errType, q.Get("error_description")}
}
token, err := c.oauth2Config.Exchange(c.ctx, q.Get("code"))
oauth2Config := c.oauth2Config(s)
ctx := r.Context()
token, err := oauth2Config.Exchange(ctx, q.Get("code"))
if err != nil {
return identity, fmt.Errorf("github: failed to get token: %v", err)
}
resp, err := c.oauth2Config.Client(c.ctx, token).Get(baseURL + "/user")
if err != nil {
return identity, fmt.Errorf("github: get URL %v", err)
}
defer resp.Body.Close()
client := oauth2Config.Client(ctx, token)
if resp.StatusCode != http.StatusOK {
body, err := ioutil.ReadAll(resp.Body)
if err != nil {
return identity, fmt.Errorf("github: read body: %v", err)
}
return identity, fmt.Errorf("%s: %s", resp.Status, body)
}
var user struct {
Name string `json:"name"`
Login string `json:"login"`
ID int `json:"id"`
Email string `json:"email"`
}
if err := json.NewDecoder(resp.Body).Decode(&user); err != nil {
return identity, fmt.Errorf("failed to decode response: %v", err)
}
data := connectorData{AccessToken: token.AccessToken}
connData, err := json.Marshal(data)
user, err := c.user(ctx, client)
if err != nil {
return identity, fmt.Errorf("marshal connector data: %v", err)
return identity, fmt.Errorf("github: get user: %v", err)
}
username := user.Name
@@ -131,22 +121,114 @@ func (c *githubConnector) HandleCallback(r *http.Request) (identity connector.Id
Username: username,
Email: user.Email,
EmailVerified: true,
ConnectorData: connData,
}
if s.Groups && c.org != "" {
groups, err := c.teams(ctx, client, c.org)
if err != nil {
return identity, fmt.Errorf("github: get teams: %v", err)
}
identity.Groups = groups
}
if s.OfflineAccess {
data := connectorData{AccessToken: token.AccessToken}
connData, err := json.Marshal(data)
if err != nil {
return identity, fmt.Errorf("marshal connector data: %v", err)
}
identity.ConnectorData = connData
}
return identity, nil
}
func (c *githubConnector) Groups(identity connector.Identity) ([]string, error) {
var data connectorData
if err := json.Unmarshal(identity.ConnectorData, &data); err != nil {
return nil, fmt.Errorf("decode connector data: %v", err)
func (c *githubConnector) Refresh(ctx context.Context, s connector.Scopes, ident connector.Identity) (connector.Identity, error) {
if len(ident.ConnectorData) == 0 {
return ident, errors.New("no upstream access token found")
}
token := &oauth2.Token{AccessToken: data.AccessToken}
resp, err := c.oauth2Config.Client(c.ctx, token).Get(baseURL + "/user/teams")
var data connectorData
if err := json.Unmarshal(ident.ConnectorData, &data); err != nil {
return ident, fmt.Errorf("github: unmarshal access token: %v", err)
}
client := c.oauth2Config(s).Client(ctx, &oauth2.Token{AccessToken: data.AccessToken})
user, err := c.user(ctx, client)
if err != nil {
return ident, fmt.Errorf("github: get user: %v", err)
}
username := user.Name
if username == "" {
username = user.Login
}
ident.Username = username
ident.Email = user.Email
if s.Groups && c.org != "" {
groups, err := c.teams(ctx, client, c.org)
if err != nil {
return ident, fmt.Errorf("github: get teams: %v", err)
}
ident.Groups = groups
}
return ident, nil
}
type user struct {
Name string `json:"name"`
Login string `json:"login"`
ID int `json:"id"`
Email string `json:"email"`
}
// user queries the GitHub API for profile information using the provided client. The HTTP
// client is expected to be constructed by the golang.org/x/oauth2 package, which inserts
// a bearer token as part of the request.
func (c *githubConnector) user(ctx context.Context, client *http.Client) (user, error) {
var u user
req, err := http.NewRequest("GET", baseURL+"/user", nil)
if err != nil {
return u, fmt.Errorf("github: new req: %v", err)
}
req = req.WithContext(ctx)
resp, err := client.Do(req)
if err != nil {
return u, fmt.Errorf("github: get URL %v", err)
}
defer resp.Body.Close()
if resp.StatusCode != http.StatusOK {
body, err := ioutil.ReadAll(resp.Body)
if err != nil {
return u, fmt.Errorf("github: read body: %v", err)
}
return u, fmt.Errorf("%s: %s", resp.Status, body)
}
if err := json.NewDecoder(resp.Body).Decode(&u); err != nil {
return u, fmt.Errorf("failed to decode response: %v", err)
}
return u, nil
}
// teams queries the GitHub API for team membership within a specific organization.
//
// The HTTP passed client is expected to be constructed by the golang.org/x/oauth2 package,
// which inserts a bearer token as part of the request.
func (c *githubConnector) teams(ctx context.Context, client *http.Client, org string) ([]string, error) {
req, err := http.NewRequest("GET", baseURL+"/user/teams", nil)
if err != nil {
return nil, fmt.Errorf("github: new req: %v", err)
}
req = req.WithContext(ctx)
resp, err := client.Do(req)
if err != nil {
return nil, fmt.Errorf("github: get teams: %v", err)
}
defer resp.Body.Close()
if resp.StatusCode != http.StatusOK {
body, err := ioutil.ReadAll(resp.Body)
if err != nil {
@@ -167,7 +249,7 @@ func (c *githubConnector) Groups(identity connector.Identity) ([]string, error)
}
groups := []string{}
for _, team := range teams {
if team.Org.Login == c.org {
if team.Org.Login == org {
groups = append(groups, team.Name)
}
}

View File

@@ -10,6 +10,7 @@ import (
"log"
"net"
"golang.org/x/net/context"
"gopkg.in/ldap.v2"
"github.com/coreos/dex/connector"
@@ -57,6 +58,9 @@ type Config struct {
// Required if LDAP host does not use TLS.
InsecureNoSSL bool `json:"insecureNoSSL"`
// Don't verify the CA.
InsecureSkipVerify bool `json:"insecureSkipVerify"`
// Path to a trusted root certificate file.
RootCA string `json:"rootCA"`
@@ -139,11 +143,16 @@ func (c *Config) Open() (connector.Connector, error) {
return connector.Connector(conn), nil
}
type refreshData struct {
Username string `json:"username"`
Entry ldap.Entry `json:"entry"`
}
// OpenConnector is the same as Open but returns a type with all implemented connector interfaces.
func (c *Config) OpenConnector() (interface {
connector.Connector
connector.PasswordConnector
connector.GroupsConnector
connector.RefreshConnector
}, error) {
requiredFields := []struct {
@@ -174,7 +183,7 @@ func (c *Config) OpenConnector() (interface {
}
}
tlsConfig := &tls.Config{ServerName: host}
tlsConfig := &tls.Config{ServerName: host, InsecureSkipVerify: c.InsecureSkipVerify}
if c.RootCA != "" || len(c.RootCAData) != 0 {
data := c.RootCAData
if len(data) == 0 {
@@ -209,12 +218,16 @@ type ldapConnector struct {
tlsConfig *tls.Config
}
var _ connector.PasswordConnector = (*ldapConnector)(nil)
var (
_ connector.PasswordConnector = (*ldapConnector)(nil)
_ connector.RefreshConnector = (*ldapConnector)(nil)
)
// do initializes a connection to the LDAP directory and passes it to the
// provided function. It then performs appropriate teardown or reuse before
// returning.
func (c *ldapConnector) do(f func(c *ldap.Conn) error) error {
func (c *ldapConnector) do(ctx context.Context, f func(c *ldap.Conn) error) error {
// TODO(ericchiang): support context here
var (
conn *ldap.Conn
err error
@@ -253,13 +266,32 @@ func getAttr(e ldap.Entry, name string) string {
return ""
}
func (c *ldapConnector) Login(username, password string) (ident connector.Identity, validPass bool, err error) {
var (
// We want to return a different error if the user's password is incorrect vs
// if there was an error.
incorrectPass = false
user ldap.Entry
)
func (c *ldapConnector) identityFromEntry(user ldap.Entry) (ident connector.Identity, err error) {
// If we're missing any attributes, such as email or ID, we want to report
// an error rather than continuing.
missing := []string{}
// Fill the identity struct using the attributes from the user entry.
if ident.UserID = getAttr(user, c.UserSearch.IDAttr); ident.UserID == "" {
missing = append(missing, c.UserSearch.IDAttr)
}
if ident.Email = getAttr(user, c.UserSearch.EmailAttr); ident.Email == "" {
missing = append(missing, c.UserSearch.EmailAttr)
}
if c.UserSearch.NameAttr != "" {
if ident.Username = getAttr(user, c.UserSearch.NameAttr); ident.Username == "" {
missing = append(missing, c.UserSearch.NameAttr)
}
}
if len(missing) != 0 {
err := fmt.Errorf("ldap: entry %q missing following required attribute(s): %q", user.DN, missing)
return connector.Identity{}, err
}
return ident, nil
}
func (c *ldapConnector) userEntry(conn *ldap.Conn, username string) (user ldap.Entry, found bool, err error) {
filter := fmt.Sprintf("(%s=%s)", c.UserSearch.Username, ldap.EscapeFilter(username))
if c.UserSearch.Filter != "" {
@@ -283,24 +315,40 @@ func (c *ldapConnector) Login(username, password string) (ident connector.Identi
if c.UserSearch.NameAttr != "" {
req.Attributes = append(req.Attributes, c.UserSearch.NameAttr)
}
resp, err := conn.Search(req)
if err != nil {
return ldap.Entry{}, false, fmt.Errorf("ldap: search with filter %q failed: %v", req.Filter, err)
}
err = c.do(func(conn *ldap.Conn) error {
resp, err := conn.Search(req)
switch n := len(resp.Entries); n {
case 0:
log.Printf("ldap: no results returned for filter: %q", filter)
return ldap.Entry{}, false, nil
case 1:
return *resp.Entries[0], true, nil
default:
return ldap.Entry{}, false, fmt.Errorf("ldap: filter returned multiple (%d) results: %q", n, filter)
}
}
func (c *ldapConnector) Login(ctx context.Context, s connector.Scopes, username, password string) (ident connector.Identity, validPass bool, err error) {
var (
// We want to return a different error if the user's password is incorrect vs
// if there was an error.
incorrectPass = false
user ldap.Entry
)
err = c.do(ctx, func(conn *ldap.Conn) error {
entry, found, err := c.userEntry(conn, username)
if err != nil {
return fmt.Errorf("ldap: search with filter %q failed: %v", req.Filter, err)
return err
}
switch n := len(resp.Entries); n {
case 0:
log.Printf("ldap: no results returned for filter: %q", filter)
if !found {
incorrectPass = true
return nil
case 1:
default:
return fmt.Errorf("ldap: filter returned multiple (%d) results: %q", n, filter)
}
user = *resp.Entries[0]
user = entry
// Try to authenticate as the distinguished name.
if err := conn.Bind(user.DN, password); err != nil {
@@ -323,44 +371,75 @@ func (c *ldapConnector) Login(username, password string) (ident connector.Identi
return connector.Identity{}, false, nil
}
// Encode entry for follow up requests such as the groups query and
// refresh attempts.
if ident.ConnectorData, err = json.Marshal(user); err != nil {
return connector.Identity{}, false, fmt.Errorf("ldap: marshal entry: %v", err)
}
// If we're missing any attributes, such as email or ID, we want to report
// an error rather than continuing.
missing := []string{}
// Fill the identity struct using the attributes from the user entry.
if ident.UserID = getAttr(user, c.UserSearch.IDAttr); ident.UserID == "" {
missing = append(missing, c.UserSearch.IDAttr)
}
if ident.Email = getAttr(user, c.UserSearch.EmailAttr); ident.Email == "" {
missing = append(missing, c.UserSearch.EmailAttr)
}
if c.UserSearch.NameAttr != "" {
if ident.Username = getAttr(user, c.UserSearch.NameAttr); ident.Username == "" {
missing = append(missing, c.UserSearch.NameAttr)
}
}
if len(missing) != 0 {
err := fmt.Errorf("ldap: entry %q missing following required attribute(s): %q", user.DN, missing)
if ident, err = c.identityFromEntry(user); err != nil {
return connector.Identity{}, false, err
}
if s.Groups {
groups, err := c.groups(ctx, user)
if err != nil {
return connector.Identity{}, false, fmt.Errorf("ldap: failed to query groups: %v", err)
}
ident.Groups = groups
}
if s.OfflineAccess {
refresh := refreshData{
Username: username,
Entry: user,
}
// Encode entry for follow up requests such as the groups query and
// refresh attempts.
if ident.ConnectorData, err = json.Marshal(refresh); err != nil {
return connector.Identity{}, false, fmt.Errorf("ldap: marshal entry: %v", err)
}
}
return ident, true, nil
}
func (c *ldapConnector) Groups(ident connector.Identity) ([]string, error) {
// Decode the user entry from the identity.
var user ldap.Entry
if err := json.Unmarshal(ident.ConnectorData, &user); err != nil {
return nil, fmt.Errorf("ldap: failed to unmarshal connector data: %v", err)
func (c *ldapConnector) Refresh(ctx context.Context, s connector.Scopes, ident connector.Identity) (connector.Identity, error) {
var data refreshData
if err := json.Unmarshal(ident.ConnectorData, &data); err != nil {
return ident, fmt.Errorf("ldap: failed to unamrshal internal data: %v", err)
}
var user ldap.Entry
err := c.do(ctx, func(conn *ldap.Conn) error {
entry, found, err := c.userEntry(conn, data.Username)
if err != nil {
return err
}
if !found {
return fmt.Errorf("ldap: user not found %q", data.Username)
}
user = entry
return nil
})
if err != nil {
return ident, err
}
if user.DN != data.Entry.DN {
return ident, fmt.Errorf("ldap: refresh for username %q expected DN %q got %q", data.Username, data.Entry.DN, user.DN)
}
newIdent, err := c.identityFromEntry(user)
if err != nil {
return ident, err
}
newIdent.ConnectorData = ident.ConnectorData
if s.Groups {
groups, err := c.groups(ctx, user)
if err != nil {
return connector.Identity{}, fmt.Errorf("ldap: failed to query groups: %v", err)
}
newIdent.Groups = groups
}
return newIdent, nil
}
func (c *ldapConnector) groups(ctx context.Context, user ldap.Entry) ([]string, error) {
filter := fmt.Sprintf("(%s=%s)", c.GroupSearch.GroupAttr, ldap.EscapeFilter(getAttr(user, c.GroupSearch.UserAttr)))
if c.GroupSearch.Filter != "" {
filter = fmt.Sprintf("(&%s%s)", c.GroupSearch.Filter, filter)
@@ -374,7 +453,7 @@ func (c *ldapConnector) Groups(ident connector.Identity) ([]string, error) {
}
var groups []*ldap.Entry
if err := c.do(func(conn *ldap.Conn) error {
if err := c.do(ctx, func(conn *ldap.Conn) error {
resp, err := conn.Search(req)
if err != nil {
return fmt.Errorf("ldap: search failed: %v", err)
@@ -406,7 +485,3 @@ func (c *ldapConnector) Groups(ident connector.Identity) ([]string, error) {
}
return groupNames, nil
}
func (c *ldapConnector) Close() error {
return nil
}

View File

@@ -2,12 +2,13 @@
package mock
import (
"bytes"
"errors"
"fmt"
"net/http"
"net/url"
"golang.org/x/net/context"
"github.com/coreos/dex/connector"
)
@@ -19,7 +20,6 @@ func NewCallbackConnector() connector.Connector {
var (
_ connector.CallbackConnector = callbackConnector{}
_ connector.GroupsConnector = callbackConnector{}
_ connector.PasswordConnector = passwordConnector{}
)
@@ -28,7 +28,7 @@ type callbackConnector struct{}
func (m callbackConnector) Close() error { return nil }
func (m callbackConnector) LoginURL(callbackURL, state string) (string, error) {
func (m callbackConnector) LoginURL(s connector.Scopes, callbackURL, state string) (string, error) {
u, err := url.Parse(callbackURL)
if err != nil {
return "", fmt.Errorf("failed to parse callbackURL %q: %v", callbackURL, err)
@@ -41,23 +41,22 @@ func (m callbackConnector) LoginURL(callbackURL, state string) (string, error) {
var connectorData = []byte("foobar")
func (m callbackConnector) HandleCallback(r *http.Request) (connector.Identity, error) {
func (m callbackConnector) HandleCallback(s connector.Scopes, r *http.Request) (connector.Identity, error) {
var groups []string
if s.Groups {
groups = []string{"authors"}
}
return connector.Identity{
UserID: "0-385-28089-0",
Username: "Kilgore Trout",
Email: "kilgore@kilgore.trout",
EmailVerified: true,
Groups: groups,
ConnectorData: connectorData,
}, nil
}
func (m callbackConnector) Groups(identity connector.Identity) ([]string, error) {
if !bytes.Equal(identity.ConnectorData, connectorData) {
return nil, errors.New("connector data mismatch")
}
return []string{"authors"}, nil
}
// CallbackConfig holds the configuration parameters for a connector which requires no interaction.
type CallbackConfig struct{}
@@ -91,7 +90,7 @@ type passwordConnector struct {
func (p passwordConnector) Close() error { return nil }
func (p passwordConnector) Login(username, password string) (identity connector.Identity, validPassword bool, err error) {
func (p passwordConnector) Login(ctx context.Context, s connector.Scopes, username, password string) (identity connector.Identity, validPassword bool, err error) {
if username == p.username && password == p.password {
return connector.Identity{
UserID: "0-385-28089-0",

View File

@@ -75,7 +75,7 @@ func (c *oidcConnector) Close() error {
return nil
}
func (c *oidcConnector) LoginURL(callbackURL, state string) (string, error) {
func (c *oidcConnector) LoginURL(s connector.Scopes, callbackURL, state string) (string, error) {
if c.redirectURI != callbackURL {
return "", fmt.Errorf("expected callback URL did not match the URL in the config")
}
@@ -94,7 +94,7 @@ func (e *oauth2Error) Error() string {
return e.error + ": " + e.errorDescription
}
func (c *oidcConnector) HandleCallback(r *http.Request) (identity connector.Identity, err error) {
func (c *oidcConnector) HandleCallback(s connector.Scopes, r *http.Request) (identity connector.Identity, err error) {
q := r.URL.Query()
if errType := q.Get("error"); errType != "" {
return identity, &oauth2Error{errType, q.Get("error_description")}