diff --git a/connector/connector.go b/connector/connector.go index 9f84d3e6..95a7ec13 100644 --- a/connector/connector.go +++ b/connector/connector.go @@ -1,14 +1,25 @@ // Package connector defines interfaces for federated identity strategies. package connector -import "net/http" +import ( + "net/http" + + "golang.org/x/net/context" +) // Connector is a mechanism for federating login to a remote identity service. // // Implementations are expected to implement either the PasswordConnector or // CallbackConnector interface. -type Connector interface { - Close() error +type Connector interface{} + +// Scopes represents additional data requested by the clients about the end user. +type Scopes struct { + // The client has requested a refresh token from the server. + OfflineAccess bool + + // The client has requested group information about the end user. + Groups bool } // Identity represents the ID Token claims supported by the server. @@ -18,6 +29,8 @@ type Identity struct { Email string EmailVerified bool + Groups []string + // ConnectorData holds data used by the connector for subsequent requests after initial // authentication, such as access tokens for upstream provides. // @@ -25,18 +38,38 @@ type Identity struct { ConnectorData []byte } -// PasswordConnector is an optional interface for password based connectors. +// PasswordConnector is an interface implemented by connectors which take a +// username and password. type PasswordConnector interface { - Login(username, password string) (identity Identity, validPassword bool, err error) + Login(ctx context.Context, s Scopes, username, password string) (identity Identity, validPassword bool, err error) } -// CallbackConnector is an optional interface for callback based connectors. +// CallbackConnector is an interface implemented by connectors which use an OAuth +// style redirect flow to determine user information. type CallbackConnector interface { - LoginURL(callbackURL, state string) (string, error) - HandleCallback(r *http.Request) (identity Identity, err error) + // The initial URL to redirect the user to. + // + // OAuth2 implementations should request different scopes from the upstream + // identity provider based on the scopes requested by the downstream client. + // For example, if the downstream client requests a refresh token from the + // server, the connector should also request a token from the provider. + // + // Many identity providers have arbitrary restrictions on refresh tokens. For + // example Google only allows a single refresh token per client/user/scopes + // combination, and wont return a refresh token even if offline access is + // requested if one has already been issues. There's no good general answer + // for these kind of restrictions, and may require this package to become more + // aware of the global set of user/connector interactions. + LoginURL(s Scopes, callbackURL, state string) (string, error) + + // Handle the callback to the server and return an identity. + HandleCallback(s Scopes, r *http.Request) (identity Identity, err error) } -// GroupsConnector is an optional interface for connectors which can map a user to groups. -type GroupsConnector interface { - Groups(identity Identity) ([]string, error) +// RefreshConnector is a connector that can update the client claims. +type RefreshConnector interface { + // Refresh is called when a client attempts to claim a refresh token. The + // connector should attempt to update the identity object to reflect any + // changes since the token was last refreshed. + Refresh(ctx context.Context, s Scopes, identity Identity) (Identity, error) } diff --git a/connector/github/github.go b/connector/github/github.go index 9e14f60f..149883aa 100644 --- a/connector/github/github.go +++ b/connector/github/github.go @@ -3,6 +3,7 @@ package github import ( "encoding/json" + "errors" "fmt" "io/ioutil" "net/http" @@ -15,7 +16,11 @@ import ( "github.com/coreos/dex/connector" ) -const baseURL = "https://api.github.com" +const ( + baseURL = "https://api.github.com" + scopeEmail = "user:email" + scopeOrgs = "read:org" +) // Config holds configuration options for github logins. type Config struct { @@ -28,17 +33,10 @@ type Config struct { // Open returns a strategy for logging in through GitHub. func (c *Config) Open() (connector.Connector, error) { return &githubConnector{ - redirectURI: c.RedirectURI, - org: c.Org, - oauth2Config: &oauth2.Config{ - ClientID: c.ClientID, - ClientSecret: c.ClientSecret, - Endpoint: github.Endpoint, - Scopes: []string{ - "user:email", // View user's email - "read:org", // View user's org teams. - }, - }, + redirectURI: c.RedirectURI, + org: c.Org, + clientID: c.ClientID, + clientSecret: c.ClientSecret, }, nil } @@ -49,26 +47,36 @@ type connectorData struct { var ( _ connector.CallbackConnector = (*githubConnector)(nil) - _ connector.GroupsConnector = (*githubConnector)(nil) + _ connector.RefreshConnector = (*githubConnector)(nil) ) type githubConnector struct { redirectURI string org string - oauth2Config *oauth2.Config - ctx context.Context - cancel context.CancelFunc + clientID string + clientSecret string } -func (c *githubConnector) Close() error { - return nil +func (c *githubConnector) oauth2Config(scopes connector.Scopes) *oauth2.Config { + var githubScopes []string + if scopes.Groups { + githubScopes = []string{scopeEmail, scopeOrgs} + } else { + githubScopes = []string{scopeEmail} + } + return &oauth2.Config{ + ClientID: c.clientID, + ClientSecret: c.clientSecret, + Endpoint: github.Endpoint, + Scopes: githubScopes, + } } -func (c *githubConnector) LoginURL(callbackURL, state string) (string, error) { +func (c *githubConnector) LoginURL(scopes connector.Scopes, callbackURL, state string) (string, error) { if c.redirectURI != callbackURL { return "", fmt.Errorf("expected callback URL did not match the URL in the config") } - return c.oauth2Config.AuthCodeURL(state), nil + return c.oauth2Config(scopes).AuthCodeURL(state), nil } type oauth2Error struct { @@ -83,43 +91,25 @@ func (e *oauth2Error) Error() string { return e.error + ": " + e.errorDescription } -func (c *githubConnector) HandleCallback(r *http.Request) (identity connector.Identity, err error) { +func (c *githubConnector) HandleCallback(s connector.Scopes, r *http.Request) (identity connector.Identity, err error) { q := r.URL.Query() if errType := q.Get("error"); errType != "" { return identity, &oauth2Error{errType, q.Get("error_description")} } - token, err := c.oauth2Config.Exchange(c.ctx, q.Get("code")) + + oauth2Config := c.oauth2Config(s) + ctx := r.Context() + + token, err := oauth2Config.Exchange(ctx, q.Get("code")) if err != nil { return identity, fmt.Errorf("github: failed to get token: %v", err) } - resp, err := c.oauth2Config.Client(c.ctx, token).Get(baseURL + "/user") - if err != nil { - return identity, fmt.Errorf("github: get URL %v", err) - } - defer resp.Body.Close() + client := oauth2Config.Client(ctx, token) - if resp.StatusCode != http.StatusOK { - body, err := ioutil.ReadAll(resp.Body) - if err != nil { - return identity, fmt.Errorf("github: read body: %v", err) - } - return identity, fmt.Errorf("%s: %s", resp.Status, body) - } - var user struct { - Name string `json:"name"` - Login string `json:"login"` - ID int `json:"id"` - Email string `json:"email"` - } - if err := json.NewDecoder(resp.Body).Decode(&user); err != nil { - return identity, fmt.Errorf("failed to decode response: %v", err) - } - - data := connectorData{AccessToken: token.AccessToken} - connData, err := json.Marshal(data) + user, err := c.user(ctx, client) if err != nil { - return identity, fmt.Errorf("marshal connector data: %v", err) + return identity, fmt.Errorf("github: get user: %v", err) } username := user.Name @@ -131,22 +121,114 @@ func (c *githubConnector) HandleCallback(r *http.Request) (identity connector.Id Username: username, Email: user.Email, EmailVerified: true, - ConnectorData: connData, } + + if s.Groups && c.org != "" { + groups, err := c.teams(ctx, client, c.org) + if err != nil { + return identity, fmt.Errorf("github: get teams: %v", err) + } + identity.Groups = groups + } + + if s.OfflineAccess { + data := connectorData{AccessToken: token.AccessToken} + connData, err := json.Marshal(data) + if err != nil { + return identity, fmt.Errorf("marshal connector data: %v", err) + } + identity.ConnectorData = connData + } + return identity, nil } -func (c *githubConnector) Groups(identity connector.Identity) ([]string, error) { - var data connectorData - if err := json.Unmarshal(identity.ConnectorData, &data); err != nil { - return nil, fmt.Errorf("decode connector data: %v", err) +func (c *githubConnector) Refresh(ctx context.Context, s connector.Scopes, ident connector.Identity) (connector.Identity, error) { + if len(ident.ConnectorData) == 0 { + return ident, errors.New("no upstream access token found") } - token := &oauth2.Token{AccessToken: data.AccessToken} - resp, err := c.oauth2Config.Client(c.ctx, token).Get(baseURL + "/user/teams") + + var data connectorData + if err := json.Unmarshal(ident.ConnectorData, &data); err != nil { + return ident, fmt.Errorf("github: unmarshal access token: %v", err) + } + + client := c.oauth2Config(s).Client(ctx, &oauth2.Token{AccessToken: data.AccessToken}) + user, err := c.user(ctx, client) + if err != nil { + return ident, fmt.Errorf("github: get user: %v", err) + } + + username := user.Name + if username == "" { + username = user.Login + } + ident.Username = username + ident.Email = user.Email + + if s.Groups && c.org != "" { + groups, err := c.teams(ctx, client, c.org) + if err != nil { + return ident, fmt.Errorf("github: get teams: %v", err) + } + ident.Groups = groups + } + return ident, nil +} + +type user struct { + Name string `json:"name"` + Login string `json:"login"` + ID int `json:"id"` + Email string `json:"email"` +} + +// user queries the GitHub API for profile information using the provided client. The HTTP +// client is expected to be constructed by the golang.org/x/oauth2 package, which inserts +// a bearer token as part of the request. +func (c *githubConnector) user(ctx context.Context, client *http.Client) (user, error) { + var u user + req, err := http.NewRequest("GET", baseURL+"/user", nil) + if err != nil { + return u, fmt.Errorf("github: new req: %v", err) + } + req = req.WithContext(ctx) + resp, err := client.Do(req) + if err != nil { + return u, fmt.Errorf("github: get URL %v", err) + } + defer resp.Body.Close() + + if resp.StatusCode != http.StatusOK { + body, err := ioutil.ReadAll(resp.Body) + if err != nil { + return u, fmt.Errorf("github: read body: %v", err) + } + return u, fmt.Errorf("%s: %s", resp.Status, body) + } + + if err := json.NewDecoder(resp.Body).Decode(&u); err != nil { + return u, fmt.Errorf("failed to decode response: %v", err) + } + return u, nil +} + +// teams queries the GitHub API for team membership within a specific organization. +// +// The HTTP passed client is expected to be constructed by the golang.org/x/oauth2 package, +// which inserts a bearer token as part of the request. +func (c *githubConnector) teams(ctx context.Context, client *http.Client, org string) ([]string, error) { + req, err := http.NewRequest("GET", baseURL+"/user/teams", nil) + if err != nil { + return nil, fmt.Errorf("github: new req: %v", err) + } + req = req.WithContext(ctx) + resp, err := client.Do(req) if err != nil { return nil, fmt.Errorf("github: get teams: %v", err) } defer resp.Body.Close() + if resp.StatusCode != http.StatusOK { body, err := ioutil.ReadAll(resp.Body) if err != nil { @@ -167,7 +249,7 @@ func (c *githubConnector) Groups(identity connector.Identity) ([]string, error) } groups := []string{} for _, team := range teams { - if team.Org.Login == c.org { + if team.Org.Login == org { groups = append(groups, team.Name) } } diff --git a/connector/ldap/ldap.go b/connector/ldap/ldap.go index 05588626..50688c35 100644 --- a/connector/ldap/ldap.go +++ b/connector/ldap/ldap.go @@ -10,6 +10,7 @@ import ( "log" "net" + "golang.org/x/net/context" "gopkg.in/ldap.v2" "github.com/coreos/dex/connector" @@ -57,6 +58,9 @@ type Config struct { // Required if LDAP host does not use TLS. InsecureNoSSL bool `json:"insecureNoSSL"` + // Don't verify the CA. + InsecureSkipVerify bool `json:"insecureSkipVerify"` + // Path to a trusted root certificate file. RootCA string `json:"rootCA"` @@ -139,11 +143,16 @@ func (c *Config) Open() (connector.Connector, error) { return connector.Connector(conn), nil } +type refreshData struct { + Username string `json:"username"` + Entry ldap.Entry `json:"entry"` +} + // OpenConnector is the same as Open but returns a type with all implemented connector interfaces. func (c *Config) OpenConnector() (interface { connector.Connector connector.PasswordConnector - connector.GroupsConnector + connector.RefreshConnector }, error) { requiredFields := []struct { @@ -174,7 +183,7 @@ func (c *Config) OpenConnector() (interface { } } - tlsConfig := &tls.Config{ServerName: host} + tlsConfig := &tls.Config{ServerName: host, InsecureSkipVerify: c.InsecureSkipVerify} if c.RootCA != "" || len(c.RootCAData) != 0 { data := c.RootCAData if len(data) == 0 { @@ -209,12 +218,16 @@ type ldapConnector struct { tlsConfig *tls.Config } -var _ connector.PasswordConnector = (*ldapConnector)(nil) +var ( + _ connector.PasswordConnector = (*ldapConnector)(nil) + _ connector.RefreshConnector = (*ldapConnector)(nil) +) // do initializes a connection to the LDAP directory and passes it to the // provided function. It then performs appropriate teardown or reuse before // returning. -func (c *ldapConnector) do(f func(c *ldap.Conn) error) error { +func (c *ldapConnector) do(ctx context.Context, f func(c *ldap.Conn) error) error { + // TODO(ericchiang): support context here var ( conn *ldap.Conn err error @@ -253,13 +266,32 @@ func getAttr(e ldap.Entry, name string) string { return "" } -func (c *ldapConnector) Login(username, password string) (ident connector.Identity, validPass bool, err error) { - var ( - // We want to return a different error if the user's password is incorrect vs - // if there was an error. - incorrectPass = false - user ldap.Entry - ) +func (c *ldapConnector) identityFromEntry(user ldap.Entry) (ident connector.Identity, err error) { + // If we're missing any attributes, such as email or ID, we want to report + // an error rather than continuing. + missing := []string{} + + // Fill the identity struct using the attributes from the user entry. + if ident.UserID = getAttr(user, c.UserSearch.IDAttr); ident.UserID == "" { + missing = append(missing, c.UserSearch.IDAttr) + } + if ident.Email = getAttr(user, c.UserSearch.EmailAttr); ident.Email == "" { + missing = append(missing, c.UserSearch.EmailAttr) + } + if c.UserSearch.NameAttr != "" { + if ident.Username = getAttr(user, c.UserSearch.NameAttr); ident.Username == "" { + missing = append(missing, c.UserSearch.NameAttr) + } + } + + if len(missing) != 0 { + err := fmt.Errorf("ldap: entry %q missing following required attribute(s): %q", user.DN, missing) + return connector.Identity{}, err + } + return ident, nil +} + +func (c *ldapConnector) userEntry(conn *ldap.Conn, username string) (user ldap.Entry, found bool, err error) { filter := fmt.Sprintf("(%s=%s)", c.UserSearch.Username, ldap.EscapeFilter(username)) if c.UserSearch.Filter != "" { @@ -283,24 +315,40 @@ func (c *ldapConnector) Login(username, password string) (ident connector.Identi if c.UserSearch.NameAttr != "" { req.Attributes = append(req.Attributes, c.UserSearch.NameAttr) } + resp, err := conn.Search(req) + if err != nil { + return ldap.Entry{}, false, fmt.Errorf("ldap: search with filter %q failed: %v", req.Filter, err) + } - err = c.do(func(conn *ldap.Conn) error { - resp, err := conn.Search(req) + switch n := len(resp.Entries); n { + case 0: + log.Printf("ldap: no results returned for filter: %q", filter) + return ldap.Entry{}, false, nil + case 1: + return *resp.Entries[0], true, nil + default: + return ldap.Entry{}, false, fmt.Errorf("ldap: filter returned multiple (%d) results: %q", n, filter) + } +} + +func (c *ldapConnector) Login(ctx context.Context, s connector.Scopes, username, password string) (ident connector.Identity, validPass bool, err error) { + var ( + // We want to return a different error if the user's password is incorrect vs + // if there was an error. + incorrectPass = false + user ldap.Entry + ) + + err = c.do(ctx, func(conn *ldap.Conn) error { + entry, found, err := c.userEntry(conn, username) if err != nil { - return fmt.Errorf("ldap: search with filter %q failed: %v", req.Filter, err) + return err } - - switch n := len(resp.Entries); n { - case 0: - log.Printf("ldap: no results returned for filter: %q", filter) + if !found { incorrectPass = true return nil - case 1: - default: - return fmt.Errorf("ldap: filter returned multiple (%d) results: %q", n, filter) } - - user = *resp.Entries[0] + user = entry // Try to authenticate as the distinguished name. if err := conn.Bind(user.DN, password); err != nil { @@ -323,44 +371,75 @@ func (c *ldapConnector) Login(username, password string) (ident connector.Identi return connector.Identity{}, false, nil } - // Encode entry for follow up requests such as the groups query and - // refresh attempts. - if ident.ConnectorData, err = json.Marshal(user); err != nil { - return connector.Identity{}, false, fmt.Errorf("ldap: marshal entry: %v", err) - } - - // If we're missing any attributes, such as email or ID, we want to report - // an error rather than continuing. - missing := []string{} - - // Fill the identity struct using the attributes from the user entry. - if ident.UserID = getAttr(user, c.UserSearch.IDAttr); ident.UserID == "" { - missing = append(missing, c.UserSearch.IDAttr) - } - if ident.Email = getAttr(user, c.UserSearch.EmailAttr); ident.Email == "" { - missing = append(missing, c.UserSearch.EmailAttr) - } - if c.UserSearch.NameAttr != "" { - if ident.Username = getAttr(user, c.UserSearch.NameAttr); ident.Username == "" { - missing = append(missing, c.UserSearch.NameAttr) - } - } - - if len(missing) != 0 { - err := fmt.Errorf("ldap: entry %q missing following required attribute(s): %q", user.DN, missing) + if ident, err = c.identityFromEntry(user); err != nil { return connector.Identity{}, false, err } + if s.Groups { + groups, err := c.groups(ctx, user) + if err != nil { + return connector.Identity{}, false, fmt.Errorf("ldap: failed to query groups: %v", err) + } + ident.Groups = groups + } + + if s.OfflineAccess { + refresh := refreshData{ + Username: username, + Entry: user, + } + // Encode entry for follow up requests such as the groups query and + // refresh attempts. + if ident.ConnectorData, err = json.Marshal(refresh); err != nil { + return connector.Identity{}, false, fmt.Errorf("ldap: marshal entry: %v", err) + } + } + return ident, true, nil } -func (c *ldapConnector) Groups(ident connector.Identity) ([]string, error) { - // Decode the user entry from the identity. - var user ldap.Entry - if err := json.Unmarshal(ident.ConnectorData, &user); err != nil { - return nil, fmt.Errorf("ldap: failed to unmarshal connector data: %v", err) +func (c *ldapConnector) Refresh(ctx context.Context, s connector.Scopes, ident connector.Identity) (connector.Identity, error) { + var data refreshData + if err := json.Unmarshal(ident.ConnectorData, &data); err != nil { + return ident, fmt.Errorf("ldap: failed to unamrshal internal data: %v", err) } + var user ldap.Entry + err := c.do(ctx, func(conn *ldap.Conn) error { + entry, found, err := c.userEntry(conn, data.Username) + if err != nil { + return err + } + if !found { + return fmt.Errorf("ldap: user not found %q", data.Username) + } + user = entry + return nil + }) + if err != nil { + return ident, err + } + if user.DN != data.Entry.DN { + return ident, fmt.Errorf("ldap: refresh for username %q expected DN %q got %q", data.Username, data.Entry.DN, user.DN) + } + + newIdent, err := c.identityFromEntry(user) + if err != nil { + return ident, err + } + newIdent.ConnectorData = ident.ConnectorData + + if s.Groups { + groups, err := c.groups(ctx, user) + if err != nil { + return connector.Identity{}, fmt.Errorf("ldap: failed to query groups: %v", err) + } + newIdent.Groups = groups + } + return newIdent, nil +} + +func (c *ldapConnector) groups(ctx context.Context, user ldap.Entry) ([]string, error) { filter := fmt.Sprintf("(%s=%s)", c.GroupSearch.GroupAttr, ldap.EscapeFilter(getAttr(user, c.GroupSearch.UserAttr))) if c.GroupSearch.Filter != "" { filter = fmt.Sprintf("(&%s%s)", c.GroupSearch.Filter, filter) @@ -374,7 +453,7 @@ func (c *ldapConnector) Groups(ident connector.Identity) ([]string, error) { } var groups []*ldap.Entry - if err := c.do(func(conn *ldap.Conn) error { + if err := c.do(ctx, func(conn *ldap.Conn) error { resp, err := conn.Search(req) if err != nil { return fmt.Errorf("ldap: search failed: %v", err) @@ -406,7 +485,3 @@ func (c *ldapConnector) Groups(ident connector.Identity) ([]string, error) { } return groupNames, nil } - -func (c *ldapConnector) Close() error { - return nil -} diff --git a/connector/mock/connectortest.go b/connector/mock/connectortest.go index 58b3e10b..3f3a1ffd 100644 --- a/connector/mock/connectortest.go +++ b/connector/mock/connectortest.go @@ -2,12 +2,13 @@ package mock import ( - "bytes" "errors" "fmt" "net/http" "net/url" + "golang.org/x/net/context" + "github.com/coreos/dex/connector" ) @@ -19,7 +20,6 @@ func NewCallbackConnector() connector.Connector { var ( _ connector.CallbackConnector = callbackConnector{} - _ connector.GroupsConnector = callbackConnector{} _ connector.PasswordConnector = passwordConnector{} ) @@ -28,7 +28,7 @@ type callbackConnector struct{} func (m callbackConnector) Close() error { return nil } -func (m callbackConnector) LoginURL(callbackURL, state string) (string, error) { +func (m callbackConnector) LoginURL(s connector.Scopes, callbackURL, state string) (string, error) { u, err := url.Parse(callbackURL) if err != nil { return "", fmt.Errorf("failed to parse callbackURL %q: %v", callbackURL, err) @@ -41,23 +41,22 @@ func (m callbackConnector) LoginURL(callbackURL, state string) (string, error) { var connectorData = []byte("foobar") -func (m callbackConnector) HandleCallback(r *http.Request) (connector.Identity, error) { +func (m callbackConnector) HandleCallback(s connector.Scopes, r *http.Request) (connector.Identity, error) { + var groups []string + if s.Groups { + groups = []string{"authors"} + } + return connector.Identity{ UserID: "0-385-28089-0", Username: "Kilgore Trout", Email: "kilgore@kilgore.trout", EmailVerified: true, + Groups: groups, ConnectorData: connectorData, }, nil } -func (m callbackConnector) Groups(identity connector.Identity) ([]string, error) { - if !bytes.Equal(identity.ConnectorData, connectorData) { - return nil, errors.New("connector data mismatch") - } - return []string{"authors"}, nil -} - // CallbackConfig holds the configuration parameters for a connector which requires no interaction. type CallbackConfig struct{} @@ -91,7 +90,7 @@ type passwordConnector struct { func (p passwordConnector) Close() error { return nil } -func (p passwordConnector) Login(username, password string) (identity connector.Identity, validPassword bool, err error) { +func (p passwordConnector) Login(ctx context.Context, s connector.Scopes, username, password string) (identity connector.Identity, validPassword bool, err error) { if username == p.username && password == p.password { return connector.Identity{ UserID: "0-385-28089-0", diff --git a/connector/oidc/oidc.go b/connector/oidc/oidc.go index cd9656f1..c9d88191 100644 --- a/connector/oidc/oidc.go +++ b/connector/oidc/oidc.go @@ -75,7 +75,7 @@ func (c *oidcConnector) Close() error { return nil } -func (c *oidcConnector) LoginURL(callbackURL, state string) (string, error) { +func (c *oidcConnector) LoginURL(s connector.Scopes, callbackURL, state string) (string, error) { if c.redirectURI != callbackURL { return "", fmt.Errorf("expected callback URL did not match the URL in the config") } @@ -94,7 +94,7 @@ func (e *oauth2Error) Error() string { return e.error + ": " + e.errorDescription } -func (c *oidcConnector) HandleCallback(r *http.Request) (identity connector.Identity, err error) { +func (c *oidcConnector) HandleCallback(s connector.Scopes, r *http.Request) (identity connector.Identity, err error) { q := r.URL.Query() if errType := q.Get("error"); errType != "" { return identity, &oauth2Error{errType, q.Get("error_description")} diff --git a/server/handlers.go b/server/handlers.go index 4fba9304..6ec1f4ff 100644 --- a/server/handlers.go +++ b/server/handlers.go @@ -179,7 +179,13 @@ func (s *Server) handleConnectorLogin(w http.ResponseWriter, r *http.Request) { authReqID := r.FormValue("req") - // TODO(ericchiang): cache user identity. + authReq, err := s.storage.GetAuthRequest(authReqID) + if err != nil { + log.Printf("Failed to get auth request: %v", err) + s.renderError(w, http.StatusInternalServerError, errServerError, "") + return + } + scopes := parseScopes(authReq.Scopes) switch r.Method { case "GET": @@ -199,7 +205,7 @@ func (s *Server) handleConnectorLogin(w http.ResponseWriter, r *http.Request) { // Use the auth request ID as the "state" token. // // TODO(ericchiang): Is this appropriate or should we also be using a nonce? - callbackURL, err := conn.LoginURL(s.absURL("/callback"), authReqID) + callbackURL, err := conn.LoginURL(scopes, s.absURL("/callback"), authReqID) if err != nil { log.Printf("Connector %q returned error when creating callback: %v", connID, err) s.renderError(w, http.StatusInternalServerError, errServerError, "") @@ -221,7 +227,7 @@ func (s *Server) handleConnectorLogin(w http.ResponseWriter, r *http.Request) { username := r.FormValue("login") password := r.FormValue("password") - identity, ok, err := passwordConnector.Login(username, password) + identity, ok, err := passwordConnector.Login(r.Context(), scopes, username, password) if err != nil { log.Printf("Failed to login user: %v", err) s.renderError(w, http.StatusInternalServerError, errServerError, "") @@ -231,12 +237,6 @@ func (s *Server) handleConnectorLogin(w http.ResponseWriter, r *http.Request) { s.templates.password(w, authReqID, r.URL.String(), username, true) return } - authReq, err := s.storage.GetAuthRequest(authReqID) - if err != nil { - log.Printf("Failed to get auth request: %v", err) - s.renderError(w, http.StatusInternalServerError, errServerError, "") - return - } redirectURL, err := s.finalizeLogin(identity, authReq, conn.Connector) if err != nil { log.Printf("Failed to finalize login: %v", err) @@ -286,7 +286,7 @@ func (s *Server) handleConnectorCallback(w http.ResponseWriter, r *http.Request) return } - identity, err := callbackConnector.HandleCallback(r) + identity, err := callbackConnector.HandleCallback(parseScopes(authReq.Scopes), r) if err != nil { log.Printf("Failed to authenticate: %v", err) s.renderError(w, http.StatusInternalServerError, errServerError, "") @@ -304,34 +304,12 @@ func (s *Server) handleConnectorCallback(w http.ResponseWriter, r *http.Request) } func (s *Server) finalizeLogin(identity connector.Identity, authReq storage.AuthRequest, conn connector.Connector) (string, error) { - if authReq.ConnectorID == "" { - - } - claims := storage.Claims{ UserID: identity.UserID, Username: identity.Username, Email: identity.Email, EmailVerified: identity.EmailVerified, - } - - groupsConn, ok := conn.(connector.GroupsConnector) - if ok { - reqGroups := func() bool { - for _, scope := range authReq.Scopes { - if scope == scopeGroups { - return true - } - } - return false - }() - if reqGroups { - groups, err := groupsConn.Groups(identity) - if err != nil { - return "", fmt.Errorf("getting groups: %v", err) - } - claims.Groups = groups - } + Groups: identity.Groups, } updater := func(a storage.AuthRequest) (storage.AuthRequest, error) { @@ -407,14 +385,15 @@ func (s *Server) sendCodeResponse(w http.ResponseWriter, r *http.Request, authRe switch responseType { case responseTypeCode: code := storage.AuthCode{ - ID: storage.NewID(), - ClientID: authReq.ClientID, - ConnectorID: authReq.ConnectorID, - Nonce: authReq.Nonce, - Scopes: authReq.Scopes, - Claims: authReq.Claims, - Expiry: s.now().Add(time.Minute * 30), - RedirectURI: authReq.RedirectURI, + ID: storage.NewID(), + ClientID: authReq.ClientID, + ConnectorID: authReq.ConnectorID, + Nonce: authReq.Nonce, + Scopes: authReq.Scopes, + Claims: authReq.Claims, + Expiry: s.now().Add(time.Minute * 30), + RedirectURI: authReq.RedirectURI, + ConnectorData: authReq.ConnectorData, } if err := s.storage.CreateAuthCode(code); err != nil { log.Printf("Failed to create auth code: %v", err) @@ -537,12 +516,13 @@ func (s *Server) handleAuthCode(w http.ResponseWriter, r *http.Request, client s var refreshToken string if reqRefresh { refresh := storage.RefreshToken{ - RefreshToken: storage.NewID(), - ClientID: authCode.ClientID, - ConnectorID: authCode.ConnectorID, - Scopes: authCode.Scopes, - Claims: authCode.Claims, - Nonce: authCode.Nonce, + RefreshToken: storage.NewID(), + ClientID: authCode.ClientID, + ConnectorID: authCode.ConnectorID, + Scopes: authCode.Scopes, + Claims: authCode.Claims, + Nonce: authCode.Nonce, + ConnectorData: authCode.ConnectorData, } if err := s.storage.CreateRefresh(refresh); err != nil { log.Printf("failed to create refresh token: %v", err) @@ -574,6 +554,10 @@ func (s *Server) handleRefreshToken(w http.ResponseWriter, r *http.Request, clie return } + // Per the OAuth2 spec, if the client has omitted the scopes, default to the original + // authorized scopes. + // + // https://tools.ietf.org/html/rfc6749#section-6 scopes := refresh.Scopes if scope != "" { requestedScopes := strings.Fields(scope) @@ -601,7 +585,43 @@ func (s *Server) handleRefreshToken(w http.ResponseWriter, r *http.Request, clie scopes = requestedScopes } - // TODO(ericchiang): re-auth with backends + conn, ok := s.connectors[refresh.ConnectorID] + if !ok { + log.Printf("connector ID not found: %q", refresh.ConnectorID) + tokenErr(w, errServerError, "", http.StatusInternalServerError) + return + } + + // Can the connector refresh the identity? If so, attempt to refresh the data + // in the connector. + // + // TODO(ericchiang): We may want a strict mode where connectors that don't implement + // this interface can't perform refreshing. + if refreshConn, ok := conn.Connector.(connector.RefreshConnector); ok { + ident := connector.Identity{ + UserID: refresh.Claims.UserID, + Username: refresh.Claims.Username, + Email: refresh.Claims.Email, + EmailVerified: refresh.Claims.EmailVerified, + Groups: refresh.Claims.Groups, + ConnectorData: refresh.ConnectorData, + } + ident, err := refreshConn.Refresh(r.Context(), parseScopes(scopes), ident) + if err != nil { + log.Printf("failed to refresh identity: %v", err) + tokenErr(w, errServerError, "", http.StatusInternalServerError) + return + } + + // Update the claims of the refresh token. + // + // UserID intentionally ignored for now. + refresh.Claims.Username = ident.Username + refresh.Claims.Email = ident.Email + refresh.Claims.EmailVerified = ident.EmailVerified + refresh.Claims.Groups = ident.Groups + refresh.ConnectorData = ident.ConnectorData + } idToken, expiry, err := s.newIDToken(client.ID, refresh.Claims, scopes, refresh.Nonce) if err != nil { @@ -610,6 +630,8 @@ func (s *Server) handleRefreshToken(w http.ResponseWriter, r *http.Request, clie return } + // Refresh tokens are claimed exactly once. Delete the current token and + // create a new one. if err := s.storage.DeleteRefresh(code); err != nil { log.Printf("failed to delete auth code: %v", err) tokenErr(w, errServerError, "", http.StatusInternalServerError) diff --git a/server/oauth2.go b/server/oauth2.go index e8ace97d..c9ea6273 100644 --- a/server/oauth2.go +++ b/server/oauth2.go @@ -10,6 +10,7 @@ import ( "strings" "time" + "github.com/coreos/dex/connector" "github.com/coreos/dex/storage" ) @@ -93,6 +94,19 @@ const ( responseTypeIDToken = "id_token" // ID Token in url fragment ) +func parseScopes(scopes []string) connector.Scopes { + var s connector.Scopes + for _, scope := range scopes { + switch scope { + case scopeOfflineAccess: + s.OfflineAccess = true + case scopeGroups: + s.Groups = true + } + } + return s +} + type audience []string func (a audience) MarshalJSON() ([]byte, error) { diff --git a/server/server.go b/server/server.go index 3f347013..b8d2c8d3 100644 --- a/server/server.go +++ b/server/server.go @@ -211,9 +211,7 @@ type passwordDB struct { s storage.Storage } -func (db passwordDB) Close() error { return nil } - -func (db passwordDB) Login(email, password string) (connector.Identity, bool, error) { +func (db passwordDB) Login(ctx context.Context, s connector.Scopes, email, password string) (connector.Identity, bool, error) { p, err := db.s.GetPassword(email) if err != nil { if err != storage.ErrNotFound { @@ -233,6 +231,31 @@ func (db passwordDB) Login(email, password string) (connector.Identity, bool, er }, true, nil } +func (db passwordDB) Refresh(ctx context.Context, s connector.Scopes, identity connector.Identity) (connector.Identity, error) { + // If the user has been deleted, the refresh token will be rejected. + p, err := db.s.GetPassword(identity.Email) + if err != nil { + if err == storage.ErrNotFound { + return connector.Identity{}, errors.New("user not found") + } + return connector.Identity{}, fmt.Errorf("get password: %v", err) + } + + // User removed but a new user with the same email exists. + if p.UserID != identity.UserID { + return connector.Identity{}, errors.New("user not found") + } + + // If a user has updated their username, that will be reflected in the + // refreshed token. + // + // No other fields are expected to be refreshable as email is effectively used + // as an ID and this implementation doesn't deal with groups. + identity.Username = p.Username + + return identity, nil +} + // newKeyCacher returns a storage which caches keys so long as the next func newKeyCacher(s storage.Storage, now func() time.Time) storage.Storage { if now == nil { diff --git a/server/server_test.go b/server/server_test.go index 44b1fd52..b3a02fab 100644 --- a/server/server_test.go +++ b/server/server_test.go @@ -662,7 +662,6 @@ func TestCrossClientScopes(t *testing.T) { func TestPasswordDB(t *testing.T) { s := memory.New() conn := newPasswordDB(s) - defer conn.Close() pw := "hi" @@ -712,7 +711,7 @@ func TestPasswordDB(t *testing.T) { } for _, tc := range tests { - ident, valid, err := conn.Login(tc.username, tc.password) + ident, valid, err := conn.Login(context.Background(), connector.Scopes{}, tc.username, tc.password) if err != nil { if !tc.wantErr { t.Errorf("%s: %v", tc.name, err)