connector: add RefreshConnector interface
This commit is contained in:
parent
27fb7c523e
commit
952e0f81f5
@ -1,14 +1,25 @@
|
|||||||
// Package connector defines interfaces for federated identity strategies.
|
// Package connector defines interfaces for federated identity strategies.
|
||||||
package connector
|
package connector
|
||||||
|
|
||||||
import "net/http"
|
import (
|
||||||
|
"net/http"
|
||||||
|
|
||||||
|
"golang.org/x/net/context"
|
||||||
|
)
|
||||||
|
|
||||||
// Connector is a mechanism for federating login to a remote identity service.
|
// Connector is a mechanism for federating login to a remote identity service.
|
||||||
//
|
//
|
||||||
// Implementations are expected to implement either the PasswordConnector or
|
// Implementations are expected to implement either the PasswordConnector or
|
||||||
// CallbackConnector interface.
|
// CallbackConnector interface.
|
||||||
type Connector interface {
|
type Connector interface{}
|
||||||
Close() error
|
|
||||||
|
// Scopes represents additional data requested by the clients about the end user.
|
||||||
|
type Scopes struct {
|
||||||
|
// The client has requested a refresh token from the server.
|
||||||
|
OfflineAccess bool
|
||||||
|
|
||||||
|
// The client has requested group information about the end user.
|
||||||
|
Groups bool
|
||||||
}
|
}
|
||||||
|
|
||||||
// Identity represents the ID Token claims supported by the server.
|
// Identity represents the ID Token claims supported by the server.
|
||||||
@ -18,6 +29,8 @@ type Identity struct {
|
|||||||
Email string
|
Email string
|
||||||
EmailVerified bool
|
EmailVerified bool
|
||||||
|
|
||||||
|
Groups []string
|
||||||
|
|
||||||
// ConnectorData holds data used by the connector for subsequent requests after initial
|
// ConnectorData holds data used by the connector for subsequent requests after initial
|
||||||
// authentication, such as access tokens for upstream provides.
|
// authentication, such as access tokens for upstream provides.
|
||||||
//
|
//
|
||||||
@ -25,18 +38,38 @@ type Identity struct {
|
|||||||
ConnectorData []byte
|
ConnectorData []byte
|
||||||
}
|
}
|
||||||
|
|
||||||
// PasswordConnector is an optional interface for password based connectors.
|
// PasswordConnector is an interface implemented by connectors which take a
|
||||||
|
// username and password.
|
||||||
type PasswordConnector interface {
|
type PasswordConnector interface {
|
||||||
Login(username, password string) (identity Identity, validPassword bool, err error)
|
Login(ctx context.Context, s Scopes, username, password string) (identity Identity, validPassword bool, err error)
|
||||||
}
|
}
|
||||||
|
|
||||||
// CallbackConnector is an optional interface for callback based connectors.
|
// CallbackConnector is an interface implemented by connectors which use an OAuth
|
||||||
|
// style redirect flow to determine user information.
|
||||||
type CallbackConnector interface {
|
type CallbackConnector interface {
|
||||||
LoginURL(callbackURL, state string) (string, error)
|
// The initial URL to redirect the user to.
|
||||||
HandleCallback(r *http.Request) (identity Identity, err error)
|
//
|
||||||
|
// OAuth2 implementations should request different scopes from the upstream
|
||||||
|
// identity provider based on the scopes requested by the downstream client.
|
||||||
|
// For example, if the downstream client requests a refresh token from the
|
||||||
|
// server, the connector should also request a token from the provider.
|
||||||
|
//
|
||||||
|
// Many identity providers have arbitrary restrictions on refresh tokens. For
|
||||||
|
// example Google only allows a single refresh token per client/user/scopes
|
||||||
|
// combination, and wont return a refresh token even if offline access is
|
||||||
|
// requested if one has already been issues. There's no good general answer
|
||||||
|
// for these kind of restrictions, and may require this package to become more
|
||||||
|
// aware of the global set of user/connector interactions.
|
||||||
|
LoginURL(s Scopes, callbackURL, state string) (string, error)
|
||||||
|
|
||||||
|
// Handle the callback to the server and return an identity.
|
||||||
|
HandleCallback(s Scopes, r *http.Request) (identity Identity, err error)
|
||||||
}
|
}
|
||||||
|
|
||||||
// GroupsConnector is an optional interface for connectors which can map a user to groups.
|
// RefreshConnector is a connector that can update the client claims.
|
||||||
type GroupsConnector interface {
|
type RefreshConnector interface {
|
||||||
Groups(identity Identity) ([]string, error)
|
// Refresh is called when a client attempts to claim a refresh token. The
|
||||||
|
// connector should attempt to update the identity object to reflect any
|
||||||
|
// changes since the token was last refreshed.
|
||||||
|
Refresh(ctx context.Context, s Scopes, identity Identity) (Identity, error)
|
||||||
}
|
}
|
||||||
|
@ -3,6 +3,7 @@ package github
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
"io/ioutil"
|
"io/ioutil"
|
||||||
"net/http"
|
"net/http"
|
||||||
@ -15,7 +16,11 @@ import (
|
|||||||
"github.com/coreos/dex/connector"
|
"github.com/coreos/dex/connector"
|
||||||
)
|
)
|
||||||
|
|
||||||
const baseURL = "https://api.github.com"
|
const (
|
||||||
|
baseURL = "https://api.github.com"
|
||||||
|
scopeEmail = "user:email"
|
||||||
|
scopeOrgs = "read:org"
|
||||||
|
)
|
||||||
|
|
||||||
// Config holds configuration options for github logins.
|
// Config holds configuration options for github logins.
|
||||||
type Config struct {
|
type Config struct {
|
||||||
@ -30,15 +35,8 @@ func (c *Config) Open() (connector.Connector, error) {
|
|||||||
return &githubConnector{
|
return &githubConnector{
|
||||||
redirectURI: c.RedirectURI,
|
redirectURI: c.RedirectURI,
|
||||||
org: c.Org,
|
org: c.Org,
|
||||||
oauth2Config: &oauth2.Config{
|
clientID: c.ClientID,
|
||||||
ClientID: c.ClientID,
|
clientSecret: c.ClientSecret,
|
||||||
ClientSecret: c.ClientSecret,
|
|
||||||
Endpoint: github.Endpoint,
|
|
||||||
Scopes: []string{
|
|
||||||
"user:email", // View user's email
|
|
||||||
"read:org", // View user's org teams.
|
|
||||||
},
|
|
||||||
},
|
|
||||||
}, nil
|
}, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -49,26 +47,36 @@ type connectorData struct {
|
|||||||
|
|
||||||
var (
|
var (
|
||||||
_ connector.CallbackConnector = (*githubConnector)(nil)
|
_ connector.CallbackConnector = (*githubConnector)(nil)
|
||||||
_ connector.GroupsConnector = (*githubConnector)(nil)
|
_ connector.RefreshConnector = (*githubConnector)(nil)
|
||||||
)
|
)
|
||||||
|
|
||||||
type githubConnector struct {
|
type githubConnector struct {
|
||||||
redirectURI string
|
redirectURI string
|
||||||
org string
|
org string
|
||||||
oauth2Config *oauth2.Config
|
clientID string
|
||||||
ctx context.Context
|
clientSecret string
|
||||||
cancel context.CancelFunc
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *githubConnector) Close() error {
|
func (c *githubConnector) oauth2Config(scopes connector.Scopes) *oauth2.Config {
|
||||||
return nil
|
var githubScopes []string
|
||||||
|
if scopes.Groups {
|
||||||
|
githubScopes = []string{scopeEmail, scopeOrgs}
|
||||||
|
} else {
|
||||||
|
githubScopes = []string{scopeEmail}
|
||||||
|
}
|
||||||
|
return &oauth2.Config{
|
||||||
|
ClientID: c.clientID,
|
||||||
|
ClientSecret: c.clientSecret,
|
||||||
|
Endpoint: github.Endpoint,
|
||||||
|
Scopes: githubScopes,
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *githubConnector) LoginURL(callbackURL, state string) (string, error) {
|
func (c *githubConnector) LoginURL(scopes connector.Scopes, callbackURL, state string) (string, error) {
|
||||||
if c.redirectURI != callbackURL {
|
if c.redirectURI != callbackURL {
|
||||||
return "", fmt.Errorf("expected callback URL did not match the URL in the config")
|
return "", fmt.Errorf("expected callback URL did not match the URL in the config")
|
||||||
}
|
}
|
||||||
return c.oauth2Config.AuthCodeURL(state), nil
|
return c.oauth2Config(scopes).AuthCodeURL(state), nil
|
||||||
}
|
}
|
||||||
|
|
||||||
type oauth2Error struct {
|
type oauth2Error struct {
|
||||||
@ -83,43 +91,25 @@ func (e *oauth2Error) Error() string {
|
|||||||
return e.error + ": " + e.errorDescription
|
return e.error + ": " + e.errorDescription
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *githubConnector) HandleCallback(r *http.Request) (identity connector.Identity, err error) {
|
func (c *githubConnector) HandleCallback(s connector.Scopes, r *http.Request) (identity connector.Identity, err error) {
|
||||||
q := r.URL.Query()
|
q := r.URL.Query()
|
||||||
if errType := q.Get("error"); errType != "" {
|
if errType := q.Get("error"); errType != "" {
|
||||||
return identity, &oauth2Error{errType, q.Get("error_description")}
|
return identity, &oauth2Error{errType, q.Get("error_description")}
|
||||||
}
|
}
|
||||||
token, err := c.oauth2Config.Exchange(c.ctx, q.Get("code"))
|
|
||||||
|
oauth2Config := c.oauth2Config(s)
|
||||||
|
ctx := r.Context()
|
||||||
|
|
||||||
|
token, err := oauth2Config.Exchange(ctx, q.Get("code"))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return identity, fmt.Errorf("github: failed to get token: %v", err)
|
return identity, fmt.Errorf("github: failed to get token: %v", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
resp, err := c.oauth2Config.Client(c.ctx, token).Get(baseURL + "/user")
|
client := oauth2Config.Client(ctx, token)
|
||||||
if err != nil {
|
|
||||||
return identity, fmt.Errorf("github: get URL %v", err)
|
|
||||||
}
|
|
||||||
defer resp.Body.Close()
|
|
||||||
|
|
||||||
if resp.StatusCode != http.StatusOK {
|
user, err := c.user(ctx, client)
|
||||||
body, err := ioutil.ReadAll(resp.Body)
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return identity, fmt.Errorf("github: read body: %v", err)
|
return identity, fmt.Errorf("github: get user: %v", err)
|
||||||
}
|
|
||||||
return identity, fmt.Errorf("%s: %s", resp.Status, body)
|
|
||||||
}
|
|
||||||
var user struct {
|
|
||||||
Name string `json:"name"`
|
|
||||||
Login string `json:"login"`
|
|
||||||
ID int `json:"id"`
|
|
||||||
Email string `json:"email"`
|
|
||||||
}
|
|
||||||
if err := json.NewDecoder(resp.Body).Decode(&user); err != nil {
|
|
||||||
return identity, fmt.Errorf("failed to decode response: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
data := connectorData{AccessToken: token.AccessToken}
|
|
||||||
connData, err := json.Marshal(data)
|
|
||||||
if err != nil {
|
|
||||||
return identity, fmt.Errorf("marshal connector data: %v", err)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
username := user.Name
|
username := user.Name
|
||||||
@ -131,22 +121,114 @@ func (c *githubConnector) HandleCallback(r *http.Request) (identity connector.Id
|
|||||||
Username: username,
|
Username: username,
|
||||||
Email: user.Email,
|
Email: user.Email,
|
||||||
EmailVerified: true,
|
EmailVerified: true,
|
||||||
ConnectorData: connData,
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if s.Groups && c.org != "" {
|
||||||
|
groups, err := c.teams(ctx, client, c.org)
|
||||||
|
if err != nil {
|
||||||
|
return identity, fmt.Errorf("github: get teams: %v", err)
|
||||||
|
}
|
||||||
|
identity.Groups = groups
|
||||||
|
}
|
||||||
|
|
||||||
|
if s.OfflineAccess {
|
||||||
|
data := connectorData{AccessToken: token.AccessToken}
|
||||||
|
connData, err := json.Marshal(data)
|
||||||
|
if err != nil {
|
||||||
|
return identity, fmt.Errorf("marshal connector data: %v", err)
|
||||||
|
}
|
||||||
|
identity.ConnectorData = connData
|
||||||
|
}
|
||||||
|
|
||||||
return identity, nil
|
return identity, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *githubConnector) Groups(identity connector.Identity) ([]string, error) {
|
func (c *githubConnector) Refresh(ctx context.Context, s connector.Scopes, ident connector.Identity) (connector.Identity, error) {
|
||||||
var data connectorData
|
if len(ident.ConnectorData) == 0 {
|
||||||
if err := json.Unmarshal(identity.ConnectorData, &data); err != nil {
|
return ident, errors.New("no upstream access token found")
|
||||||
return nil, fmt.Errorf("decode connector data: %v", err)
|
|
||||||
}
|
}
|
||||||
token := &oauth2.Token{AccessToken: data.AccessToken}
|
|
||||||
resp, err := c.oauth2Config.Client(c.ctx, token).Get(baseURL + "/user/teams")
|
var data connectorData
|
||||||
|
if err := json.Unmarshal(ident.ConnectorData, &data); err != nil {
|
||||||
|
return ident, fmt.Errorf("github: unmarshal access token: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
client := c.oauth2Config(s).Client(ctx, &oauth2.Token{AccessToken: data.AccessToken})
|
||||||
|
user, err := c.user(ctx, client)
|
||||||
|
if err != nil {
|
||||||
|
return ident, fmt.Errorf("github: get user: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
username := user.Name
|
||||||
|
if username == "" {
|
||||||
|
username = user.Login
|
||||||
|
}
|
||||||
|
ident.Username = username
|
||||||
|
ident.Email = user.Email
|
||||||
|
|
||||||
|
if s.Groups && c.org != "" {
|
||||||
|
groups, err := c.teams(ctx, client, c.org)
|
||||||
|
if err != nil {
|
||||||
|
return ident, fmt.Errorf("github: get teams: %v", err)
|
||||||
|
}
|
||||||
|
ident.Groups = groups
|
||||||
|
}
|
||||||
|
return ident, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
type user struct {
|
||||||
|
Name string `json:"name"`
|
||||||
|
Login string `json:"login"`
|
||||||
|
ID int `json:"id"`
|
||||||
|
Email string `json:"email"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// user queries the GitHub API for profile information using the provided client. The HTTP
|
||||||
|
// client is expected to be constructed by the golang.org/x/oauth2 package, which inserts
|
||||||
|
// a bearer token as part of the request.
|
||||||
|
func (c *githubConnector) user(ctx context.Context, client *http.Client) (user, error) {
|
||||||
|
var u user
|
||||||
|
req, err := http.NewRequest("GET", baseURL+"/user", nil)
|
||||||
|
if err != nil {
|
||||||
|
return u, fmt.Errorf("github: new req: %v", err)
|
||||||
|
}
|
||||||
|
req = req.WithContext(ctx)
|
||||||
|
resp, err := client.Do(req)
|
||||||
|
if err != nil {
|
||||||
|
return u, fmt.Errorf("github: get URL %v", err)
|
||||||
|
}
|
||||||
|
defer resp.Body.Close()
|
||||||
|
|
||||||
|
if resp.StatusCode != http.StatusOK {
|
||||||
|
body, err := ioutil.ReadAll(resp.Body)
|
||||||
|
if err != nil {
|
||||||
|
return u, fmt.Errorf("github: read body: %v", err)
|
||||||
|
}
|
||||||
|
return u, fmt.Errorf("%s: %s", resp.Status, body)
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := json.NewDecoder(resp.Body).Decode(&u); err != nil {
|
||||||
|
return u, fmt.Errorf("failed to decode response: %v", err)
|
||||||
|
}
|
||||||
|
return u, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// teams queries the GitHub API for team membership within a specific organization.
|
||||||
|
//
|
||||||
|
// The HTTP passed client is expected to be constructed by the golang.org/x/oauth2 package,
|
||||||
|
// which inserts a bearer token as part of the request.
|
||||||
|
func (c *githubConnector) teams(ctx context.Context, client *http.Client, org string) ([]string, error) {
|
||||||
|
req, err := http.NewRequest("GET", baseURL+"/user/teams", nil)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("github: new req: %v", err)
|
||||||
|
}
|
||||||
|
req = req.WithContext(ctx)
|
||||||
|
resp, err := client.Do(req)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("github: get teams: %v", err)
|
return nil, fmt.Errorf("github: get teams: %v", err)
|
||||||
}
|
}
|
||||||
defer resp.Body.Close()
|
defer resp.Body.Close()
|
||||||
|
|
||||||
if resp.StatusCode != http.StatusOK {
|
if resp.StatusCode != http.StatusOK {
|
||||||
body, err := ioutil.ReadAll(resp.Body)
|
body, err := ioutil.ReadAll(resp.Body)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@ -167,7 +249,7 @@ func (c *githubConnector) Groups(identity connector.Identity) ([]string, error)
|
|||||||
}
|
}
|
||||||
groups := []string{}
|
groups := []string{}
|
||||||
for _, team := range teams {
|
for _, team := range teams {
|
||||||
if team.Org.Login == c.org {
|
if team.Org.Login == org {
|
||||||
groups = append(groups, team.Name)
|
groups = append(groups, team.Name)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -10,6 +10,7 @@ import (
|
|||||||
"log"
|
"log"
|
||||||
"net"
|
"net"
|
||||||
|
|
||||||
|
"golang.org/x/net/context"
|
||||||
"gopkg.in/ldap.v2"
|
"gopkg.in/ldap.v2"
|
||||||
|
|
||||||
"github.com/coreos/dex/connector"
|
"github.com/coreos/dex/connector"
|
||||||
@ -57,6 +58,9 @@ type Config struct {
|
|||||||
// Required if LDAP host does not use TLS.
|
// Required if LDAP host does not use TLS.
|
||||||
InsecureNoSSL bool `json:"insecureNoSSL"`
|
InsecureNoSSL bool `json:"insecureNoSSL"`
|
||||||
|
|
||||||
|
// Don't verify the CA.
|
||||||
|
InsecureSkipVerify bool `json:"insecureSkipVerify"`
|
||||||
|
|
||||||
// Path to a trusted root certificate file.
|
// Path to a trusted root certificate file.
|
||||||
RootCA string `json:"rootCA"`
|
RootCA string `json:"rootCA"`
|
||||||
|
|
||||||
@ -139,11 +143,16 @@ func (c *Config) Open() (connector.Connector, error) {
|
|||||||
return connector.Connector(conn), nil
|
return connector.Connector(conn), nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
type refreshData struct {
|
||||||
|
Username string `json:"username"`
|
||||||
|
Entry ldap.Entry `json:"entry"`
|
||||||
|
}
|
||||||
|
|
||||||
// OpenConnector is the same as Open but returns a type with all implemented connector interfaces.
|
// OpenConnector is the same as Open but returns a type with all implemented connector interfaces.
|
||||||
func (c *Config) OpenConnector() (interface {
|
func (c *Config) OpenConnector() (interface {
|
||||||
connector.Connector
|
connector.Connector
|
||||||
connector.PasswordConnector
|
connector.PasswordConnector
|
||||||
connector.GroupsConnector
|
connector.RefreshConnector
|
||||||
}, error) {
|
}, error) {
|
||||||
|
|
||||||
requiredFields := []struct {
|
requiredFields := []struct {
|
||||||
@ -174,7 +183,7 @@ func (c *Config) OpenConnector() (interface {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
tlsConfig := &tls.Config{ServerName: host}
|
tlsConfig := &tls.Config{ServerName: host, InsecureSkipVerify: c.InsecureSkipVerify}
|
||||||
if c.RootCA != "" || len(c.RootCAData) != 0 {
|
if c.RootCA != "" || len(c.RootCAData) != 0 {
|
||||||
data := c.RootCAData
|
data := c.RootCAData
|
||||||
if len(data) == 0 {
|
if len(data) == 0 {
|
||||||
@ -209,12 +218,16 @@ type ldapConnector struct {
|
|||||||
tlsConfig *tls.Config
|
tlsConfig *tls.Config
|
||||||
}
|
}
|
||||||
|
|
||||||
var _ connector.PasswordConnector = (*ldapConnector)(nil)
|
var (
|
||||||
|
_ connector.PasswordConnector = (*ldapConnector)(nil)
|
||||||
|
_ connector.RefreshConnector = (*ldapConnector)(nil)
|
||||||
|
)
|
||||||
|
|
||||||
// do initializes a connection to the LDAP directory and passes it to the
|
// do initializes a connection to the LDAP directory and passes it to the
|
||||||
// provided function. It then performs appropriate teardown or reuse before
|
// provided function. It then performs appropriate teardown or reuse before
|
||||||
// returning.
|
// returning.
|
||||||
func (c *ldapConnector) do(f func(c *ldap.Conn) error) error {
|
func (c *ldapConnector) do(ctx context.Context, f func(c *ldap.Conn) error) error {
|
||||||
|
// TODO(ericchiang): support context here
|
||||||
var (
|
var (
|
||||||
conn *ldap.Conn
|
conn *ldap.Conn
|
||||||
err error
|
err error
|
||||||
@ -253,13 +266,32 @@ func getAttr(e ldap.Entry, name string) string {
|
|||||||
return ""
|
return ""
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *ldapConnector) Login(username, password string) (ident connector.Identity, validPass bool, err error) {
|
func (c *ldapConnector) identityFromEntry(user ldap.Entry) (ident connector.Identity, err error) {
|
||||||
var (
|
// If we're missing any attributes, such as email or ID, we want to report
|
||||||
// We want to return a different error if the user's password is incorrect vs
|
// an error rather than continuing.
|
||||||
// if there was an error.
|
missing := []string{}
|
||||||
incorrectPass = false
|
|
||||||
user ldap.Entry
|
// Fill the identity struct using the attributes from the user entry.
|
||||||
)
|
if ident.UserID = getAttr(user, c.UserSearch.IDAttr); ident.UserID == "" {
|
||||||
|
missing = append(missing, c.UserSearch.IDAttr)
|
||||||
|
}
|
||||||
|
if ident.Email = getAttr(user, c.UserSearch.EmailAttr); ident.Email == "" {
|
||||||
|
missing = append(missing, c.UserSearch.EmailAttr)
|
||||||
|
}
|
||||||
|
if c.UserSearch.NameAttr != "" {
|
||||||
|
if ident.Username = getAttr(user, c.UserSearch.NameAttr); ident.Username == "" {
|
||||||
|
missing = append(missing, c.UserSearch.NameAttr)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(missing) != 0 {
|
||||||
|
err := fmt.Errorf("ldap: entry %q missing following required attribute(s): %q", user.DN, missing)
|
||||||
|
return connector.Identity{}, err
|
||||||
|
}
|
||||||
|
return ident, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *ldapConnector) userEntry(conn *ldap.Conn, username string) (user ldap.Entry, found bool, err error) {
|
||||||
|
|
||||||
filter := fmt.Sprintf("(%s=%s)", c.UserSearch.Username, ldap.EscapeFilter(username))
|
filter := fmt.Sprintf("(%s=%s)", c.UserSearch.Username, ldap.EscapeFilter(username))
|
||||||
if c.UserSearch.Filter != "" {
|
if c.UserSearch.Filter != "" {
|
||||||
@ -283,24 +315,40 @@ func (c *ldapConnector) Login(username, password string) (ident connector.Identi
|
|||||||
if c.UserSearch.NameAttr != "" {
|
if c.UserSearch.NameAttr != "" {
|
||||||
req.Attributes = append(req.Attributes, c.UserSearch.NameAttr)
|
req.Attributes = append(req.Attributes, c.UserSearch.NameAttr)
|
||||||
}
|
}
|
||||||
|
|
||||||
err = c.do(func(conn *ldap.Conn) error {
|
|
||||||
resp, err := conn.Search(req)
|
resp, err := conn.Search(req)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("ldap: search with filter %q failed: %v", req.Filter, err)
|
return ldap.Entry{}, false, fmt.Errorf("ldap: search with filter %q failed: %v", req.Filter, err)
|
||||||
}
|
}
|
||||||
|
|
||||||
switch n := len(resp.Entries); n {
|
switch n := len(resp.Entries); n {
|
||||||
case 0:
|
case 0:
|
||||||
log.Printf("ldap: no results returned for filter: %q", filter)
|
log.Printf("ldap: no results returned for filter: %q", filter)
|
||||||
incorrectPass = true
|
return ldap.Entry{}, false, nil
|
||||||
return nil
|
|
||||||
case 1:
|
case 1:
|
||||||
|
return *resp.Entries[0], true, nil
|
||||||
default:
|
default:
|
||||||
return fmt.Errorf("ldap: filter returned multiple (%d) results: %q", n, filter)
|
return ldap.Entry{}, false, fmt.Errorf("ldap: filter returned multiple (%d) results: %q", n, filter)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
user = *resp.Entries[0]
|
func (c *ldapConnector) Login(ctx context.Context, s connector.Scopes, username, password string) (ident connector.Identity, validPass bool, err error) {
|
||||||
|
var (
|
||||||
|
// We want to return a different error if the user's password is incorrect vs
|
||||||
|
// if there was an error.
|
||||||
|
incorrectPass = false
|
||||||
|
user ldap.Entry
|
||||||
|
)
|
||||||
|
|
||||||
|
err = c.do(ctx, func(conn *ldap.Conn) error {
|
||||||
|
entry, found, err := c.userEntry(conn, username)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
if !found {
|
||||||
|
incorrectPass = true
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
user = entry
|
||||||
|
|
||||||
// Try to authenticate as the distinguished name.
|
// Try to authenticate as the distinguished name.
|
||||||
if err := conn.Bind(user.DN, password); err != nil {
|
if err := conn.Bind(user.DN, password); err != nil {
|
||||||
@ -323,44 +371,75 @@ func (c *ldapConnector) Login(username, password string) (ident connector.Identi
|
|||||||
return connector.Identity{}, false, nil
|
return connector.Identity{}, false, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if ident, err = c.identityFromEntry(user); err != nil {
|
||||||
|
return connector.Identity{}, false, err
|
||||||
|
}
|
||||||
|
|
||||||
|
if s.Groups {
|
||||||
|
groups, err := c.groups(ctx, user)
|
||||||
|
if err != nil {
|
||||||
|
return connector.Identity{}, false, fmt.Errorf("ldap: failed to query groups: %v", err)
|
||||||
|
}
|
||||||
|
ident.Groups = groups
|
||||||
|
}
|
||||||
|
|
||||||
|
if s.OfflineAccess {
|
||||||
|
refresh := refreshData{
|
||||||
|
Username: username,
|
||||||
|
Entry: user,
|
||||||
|
}
|
||||||
// Encode entry for follow up requests such as the groups query and
|
// Encode entry for follow up requests such as the groups query and
|
||||||
// refresh attempts.
|
// refresh attempts.
|
||||||
if ident.ConnectorData, err = json.Marshal(user); err != nil {
|
if ident.ConnectorData, err = json.Marshal(refresh); err != nil {
|
||||||
return connector.Identity{}, false, fmt.Errorf("ldap: marshal entry: %v", err)
|
return connector.Identity{}, false, fmt.Errorf("ldap: marshal entry: %v", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
// If we're missing any attributes, such as email or ID, we want to report
|
|
||||||
// an error rather than continuing.
|
|
||||||
missing := []string{}
|
|
||||||
|
|
||||||
// Fill the identity struct using the attributes from the user entry.
|
|
||||||
if ident.UserID = getAttr(user, c.UserSearch.IDAttr); ident.UserID == "" {
|
|
||||||
missing = append(missing, c.UserSearch.IDAttr)
|
|
||||||
}
|
|
||||||
if ident.Email = getAttr(user, c.UserSearch.EmailAttr); ident.Email == "" {
|
|
||||||
missing = append(missing, c.UserSearch.EmailAttr)
|
|
||||||
}
|
|
||||||
if c.UserSearch.NameAttr != "" {
|
|
||||||
if ident.Username = getAttr(user, c.UserSearch.NameAttr); ident.Username == "" {
|
|
||||||
missing = append(missing, c.UserSearch.NameAttr)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
if len(missing) != 0 {
|
|
||||||
err := fmt.Errorf("ldap: entry %q missing following required attribute(s): %q", user.DN, missing)
|
|
||||||
return connector.Identity{}, false, err
|
|
||||||
}
|
}
|
||||||
|
|
||||||
return ident, true, nil
|
return ident, true, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *ldapConnector) Groups(ident connector.Identity) ([]string, error) {
|
func (c *ldapConnector) Refresh(ctx context.Context, s connector.Scopes, ident connector.Identity) (connector.Identity, error) {
|
||||||
// Decode the user entry from the identity.
|
var data refreshData
|
||||||
var user ldap.Entry
|
if err := json.Unmarshal(ident.ConnectorData, &data); err != nil {
|
||||||
if err := json.Unmarshal(ident.ConnectorData, &user); err != nil {
|
return ident, fmt.Errorf("ldap: failed to unamrshal internal data: %v", err)
|
||||||
return nil, fmt.Errorf("ldap: failed to unmarshal connector data: %v", err)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
var user ldap.Entry
|
||||||
|
err := c.do(ctx, func(conn *ldap.Conn) error {
|
||||||
|
entry, found, err := c.userEntry(conn, data.Username)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
if !found {
|
||||||
|
return fmt.Errorf("ldap: user not found %q", data.Username)
|
||||||
|
}
|
||||||
|
user = entry
|
||||||
|
return nil
|
||||||
|
})
|
||||||
|
if err != nil {
|
||||||
|
return ident, err
|
||||||
|
}
|
||||||
|
if user.DN != data.Entry.DN {
|
||||||
|
return ident, fmt.Errorf("ldap: refresh for username %q expected DN %q got %q", data.Username, data.Entry.DN, user.DN)
|
||||||
|
}
|
||||||
|
|
||||||
|
newIdent, err := c.identityFromEntry(user)
|
||||||
|
if err != nil {
|
||||||
|
return ident, err
|
||||||
|
}
|
||||||
|
newIdent.ConnectorData = ident.ConnectorData
|
||||||
|
|
||||||
|
if s.Groups {
|
||||||
|
groups, err := c.groups(ctx, user)
|
||||||
|
if err != nil {
|
||||||
|
return connector.Identity{}, fmt.Errorf("ldap: failed to query groups: %v", err)
|
||||||
|
}
|
||||||
|
newIdent.Groups = groups
|
||||||
|
}
|
||||||
|
return newIdent, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *ldapConnector) groups(ctx context.Context, user ldap.Entry) ([]string, error) {
|
||||||
filter := fmt.Sprintf("(%s=%s)", c.GroupSearch.GroupAttr, ldap.EscapeFilter(getAttr(user, c.GroupSearch.UserAttr)))
|
filter := fmt.Sprintf("(%s=%s)", c.GroupSearch.GroupAttr, ldap.EscapeFilter(getAttr(user, c.GroupSearch.UserAttr)))
|
||||||
if c.GroupSearch.Filter != "" {
|
if c.GroupSearch.Filter != "" {
|
||||||
filter = fmt.Sprintf("(&%s%s)", c.GroupSearch.Filter, filter)
|
filter = fmt.Sprintf("(&%s%s)", c.GroupSearch.Filter, filter)
|
||||||
@ -374,7 +453,7 @@ func (c *ldapConnector) Groups(ident connector.Identity) ([]string, error) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
var groups []*ldap.Entry
|
var groups []*ldap.Entry
|
||||||
if err := c.do(func(conn *ldap.Conn) error {
|
if err := c.do(ctx, func(conn *ldap.Conn) error {
|
||||||
resp, err := conn.Search(req)
|
resp, err := conn.Search(req)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("ldap: search failed: %v", err)
|
return fmt.Errorf("ldap: search failed: %v", err)
|
||||||
@ -406,7 +485,3 @@ func (c *ldapConnector) Groups(ident connector.Identity) ([]string, error) {
|
|||||||
}
|
}
|
||||||
return groupNames, nil
|
return groupNames, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *ldapConnector) Close() error {
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
@ -2,12 +2,13 @@
|
|||||||
package mock
|
package mock
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"bytes"
|
|
||||||
"errors"
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
"net/http"
|
"net/http"
|
||||||
"net/url"
|
"net/url"
|
||||||
|
|
||||||
|
"golang.org/x/net/context"
|
||||||
|
|
||||||
"github.com/coreos/dex/connector"
|
"github.com/coreos/dex/connector"
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -19,7 +20,6 @@ func NewCallbackConnector() connector.Connector {
|
|||||||
|
|
||||||
var (
|
var (
|
||||||
_ connector.CallbackConnector = callbackConnector{}
|
_ connector.CallbackConnector = callbackConnector{}
|
||||||
_ connector.GroupsConnector = callbackConnector{}
|
|
||||||
|
|
||||||
_ connector.PasswordConnector = passwordConnector{}
|
_ connector.PasswordConnector = passwordConnector{}
|
||||||
)
|
)
|
||||||
@ -28,7 +28,7 @@ type callbackConnector struct{}
|
|||||||
|
|
||||||
func (m callbackConnector) Close() error { return nil }
|
func (m callbackConnector) Close() error { return nil }
|
||||||
|
|
||||||
func (m callbackConnector) LoginURL(callbackURL, state string) (string, error) {
|
func (m callbackConnector) LoginURL(s connector.Scopes, callbackURL, state string) (string, error) {
|
||||||
u, err := url.Parse(callbackURL)
|
u, err := url.Parse(callbackURL)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return "", fmt.Errorf("failed to parse callbackURL %q: %v", callbackURL, err)
|
return "", fmt.Errorf("failed to parse callbackURL %q: %v", callbackURL, err)
|
||||||
@ -41,23 +41,22 @@ func (m callbackConnector) LoginURL(callbackURL, state string) (string, error) {
|
|||||||
|
|
||||||
var connectorData = []byte("foobar")
|
var connectorData = []byte("foobar")
|
||||||
|
|
||||||
func (m callbackConnector) HandleCallback(r *http.Request) (connector.Identity, error) {
|
func (m callbackConnector) HandleCallback(s connector.Scopes, r *http.Request) (connector.Identity, error) {
|
||||||
|
var groups []string
|
||||||
|
if s.Groups {
|
||||||
|
groups = []string{"authors"}
|
||||||
|
}
|
||||||
|
|
||||||
return connector.Identity{
|
return connector.Identity{
|
||||||
UserID: "0-385-28089-0",
|
UserID: "0-385-28089-0",
|
||||||
Username: "Kilgore Trout",
|
Username: "Kilgore Trout",
|
||||||
Email: "kilgore@kilgore.trout",
|
Email: "kilgore@kilgore.trout",
|
||||||
EmailVerified: true,
|
EmailVerified: true,
|
||||||
|
Groups: groups,
|
||||||
ConnectorData: connectorData,
|
ConnectorData: connectorData,
|
||||||
}, nil
|
}, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (m callbackConnector) Groups(identity connector.Identity) ([]string, error) {
|
|
||||||
if !bytes.Equal(identity.ConnectorData, connectorData) {
|
|
||||||
return nil, errors.New("connector data mismatch")
|
|
||||||
}
|
|
||||||
return []string{"authors"}, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// CallbackConfig holds the configuration parameters for a connector which requires no interaction.
|
// CallbackConfig holds the configuration parameters for a connector which requires no interaction.
|
||||||
type CallbackConfig struct{}
|
type CallbackConfig struct{}
|
||||||
|
|
||||||
@ -91,7 +90,7 @@ type passwordConnector struct {
|
|||||||
|
|
||||||
func (p passwordConnector) Close() error { return nil }
|
func (p passwordConnector) Close() error { return nil }
|
||||||
|
|
||||||
func (p passwordConnector) Login(username, password string) (identity connector.Identity, validPassword bool, err error) {
|
func (p passwordConnector) Login(ctx context.Context, s connector.Scopes, username, password string) (identity connector.Identity, validPassword bool, err error) {
|
||||||
if username == p.username && password == p.password {
|
if username == p.username && password == p.password {
|
||||||
return connector.Identity{
|
return connector.Identity{
|
||||||
UserID: "0-385-28089-0",
|
UserID: "0-385-28089-0",
|
||||||
|
@ -75,7 +75,7 @@ func (c *oidcConnector) Close() error {
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *oidcConnector) LoginURL(callbackURL, state string) (string, error) {
|
func (c *oidcConnector) LoginURL(s connector.Scopes, callbackURL, state string) (string, error) {
|
||||||
if c.redirectURI != callbackURL {
|
if c.redirectURI != callbackURL {
|
||||||
return "", fmt.Errorf("expected callback URL did not match the URL in the config")
|
return "", fmt.Errorf("expected callback URL did not match the URL in the config")
|
||||||
}
|
}
|
||||||
@ -94,7 +94,7 @@ func (e *oauth2Error) Error() string {
|
|||||||
return e.error + ": " + e.errorDescription
|
return e.error + ": " + e.errorDescription
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *oidcConnector) HandleCallback(r *http.Request) (identity connector.Identity, err error) {
|
func (c *oidcConnector) HandleCallback(s connector.Scopes, r *http.Request) (identity connector.Identity, err error) {
|
||||||
q := r.URL.Query()
|
q := r.URL.Query()
|
||||||
if errType := q.Get("error"); errType != "" {
|
if errType := q.Get("error"); errType != "" {
|
||||||
return identity, &oauth2Error{errType, q.Get("error_description")}
|
return identity, &oauth2Error{errType, q.Get("error_description")}
|
||||||
|
@ -179,7 +179,13 @@ func (s *Server) handleConnectorLogin(w http.ResponseWriter, r *http.Request) {
|
|||||||
|
|
||||||
authReqID := r.FormValue("req")
|
authReqID := r.FormValue("req")
|
||||||
|
|
||||||
// TODO(ericchiang): cache user identity.
|
authReq, err := s.storage.GetAuthRequest(authReqID)
|
||||||
|
if err != nil {
|
||||||
|
log.Printf("Failed to get auth request: %v", err)
|
||||||
|
s.renderError(w, http.StatusInternalServerError, errServerError, "")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
scopes := parseScopes(authReq.Scopes)
|
||||||
|
|
||||||
switch r.Method {
|
switch r.Method {
|
||||||
case "GET":
|
case "GET":
|
||||||
@ -199,7 +205,7 @@ func (s *Server) handleConnectorLogin(w http.ResponseWriter, r *http.Request) {
|
|||||||
// Use the auth request ID as the "state" token.
|
// Use the auth request ID as the "state" token.
|
||||||
//
|
//
|
||||||
// TODO(ericchiang): Is this appropriate or should we also be using a nonce?
|
// TODO(ericchiang): Is this appropriate or should we also be using a nonce?
|
||||||
callbackURL, err := conn.LoginURL(s.absURL("/callback"), authReqID)
|
callbackURL, err := conn.LoginURL(scopes, s.absURL("/callback"), authReqID)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Printf("Connector %q returned error when creating callback: %v", connID, err)
|
log.Printf("Connector %q returned error when creating callback: %v", connID, err)
|
||||||
s.renderError(w, http.StatusInternalServerError, errServerError, "")
|
s.renderError(w, http.StatusInternalServerError, errServerError, "")
|
||||||
@ -221,7 +227,7 @@ func (s *Server) handleConnectorLogin(w http.ResponseWriter, r *http.Request) {
|
|||||||
username := r.FormValue("login")
|
username := r.FormValue("login")
|
||||||
password := r.FormValue("password")
|
password := r.FormValue("password")
|
||||||
|
|
||||||
identity, ok, err := passwordConnector.Login(username, password)
|
identity, ok, err := passwordConnector.Login(r.Context(), scopes, username, password)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Printf("Failed to login user: %v", err)
|
log.Printf("Failed to login user: %v", err)
|
||||||
s.renderError(w, http.StatusInternalServerError, errServerError, "")
|
s.renderError(w, http.StatusInternalServerError, errServerError, "")
|
||||||
@ -231,12 +237,6 @@ func (s *Server) handleConnectorLogin(w http.ResponseWriter, r *http.Request) {
|
|||||||
s.templates.password(w, authReqID, r.URL.String(), username, true)
|
s.templates.password(w, authReqID, r.URL.String(), username, true)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
authReq, err := s.storage.GetAuthRequest(authReqID)
|
|
||||||
if err != nil {
|
|
||||||
log.Printf("Failed to get auth request: %v", err)
|
|
||||||
s.renderError(w, http.StatusInternalServerError, errServerError, "")
|
|
||||||
return
|
|
||||||
}
|
|
||||||
redirectURL, err := s.finalizeLogin(identity, authReq, conn.Connector)
|
redirectURL, err := s.finalizeLogin(identity, authReq, conn.Connector)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Printf("Failed to finalize login: %v", err)
|
log.Printf("Failed to finalize login: %v", err)
|
||||||
@ -286,7 +286,7 @@ func (s *Server) handleConnectorCallback(w http.ResponseWriter, r *http.Request)
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
identity, err := callbackConnector.HandleCallback(r)
|
identity, err := callbackConnector.HandleCallback(parseScopes(authReq.Scopes), r)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Printf("Failed to authenticate: %v", err)
|
log.Printf("Failed to authenticate: %v", err)
|
||||||
s.renderError(w, http.StatusInternalServerError, errServerError, "")
|
s.renderError(w, http.StatusInternalServerError, errServerError, "")
|
||||||
@ -304,34 +304,12 @@ func (s *Server) handleConnectorCallback(w http.ResponseWriter, r *http.Request)
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (s *Server) finalizeLogin(identity connector.Identity, authReq storage.AuthRequest, conn connector.Connector) (string, error) {
|
func (s *Server) finalizeLogin(identity connector.Identity, authReq storage.AuthRequest, conn connector.Connector) (string, error) {
|
||||||
if authReq.ConnectorID == "" {
|
|
||||||
|
|
||||||
}
|
|
||||||
|
|
||||||
claims := storage.Claims{
|
claims := storage.Claims{
|
||||||
UserID: identity.UserID,
|
UserID: identity.UserID,
|
||||||
Username: identity.Username,
|
Username: identity.Username,
|
||||||
Email: identity.Email,
|
Email: identity.Email,
|
||||||
EmailVerified: identity.EmailVerified,
|
EmailVerified: identity.EmailVerified,
|
||||||
}
|
Groups: identity.Groups,
|
||||||
|
|
||||||
groupsConn, ok := conn.(connector.GroupsConnector)
|
|
||||||
if ok {
|
|
||||||
reqGroups := func() bool {
|
|
||||||
for _, scope := range authReq.Scopes {
|
|
||||||
if scope == scopeGroups {
|
|
||||||
return true
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return false
|
|
||||||
}()
|
|
||||||
if reqGroups {
|
|
||||||
groups, err := groupsConn.Groups(identity)
|
|
||||||
if err != nil {
|
|
||||||
return "", fmt.Errorf("getting groups: %v", err)
|
|
||||||
}
|
|
||||||
claims.Groups = groups
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
updater := func(a storage.AuthRequest) (storage.AuthRequest, error) {
|
updater := func(a storage.AuthRequest) (storage.AuthRequest, error) {
|
||||||
@ -415,6 +393,7 @@ func (s *Server) sendCodeResponse(w http.ResponseWriter, r *http.Request, authRe
|
|||||||
Claims: authReq.Claims,
|
Claims: authReq.Claims,
|
||||||
Expiry: s.now().Add(time.Minute * 30),
|
Expiry: s.now().Add(time.Minute * 30),
|
||||||
RedirectURI: authReq.RedirectURI,
|
RedirectURI: authReq.RedirectURI,
|
||||||
|
ConnectorData: authReq.ConnectorData,
|
||||||
}
|
}
|
||||||
if err := s.storage.CreateAuthCode(code); err != nil {
|
if err := s.storage.CreateAuthCode(code); err != nil {
|
||||||
log.Printf("Failed to create auth code: %v", err)
|
log.Printf("Failed to create auth code: %v", err)
|
||||||
@ -543,6 +522,7 @@ func (s *Server) handleAuthCode(w http.ResponseWriter, r *http.Request, client s
|
|||||||
Scopes: authCode.Scopes,
|
Scopes: authCode.Scopes,
|
||||||
Claims: authCode.Claims,
|
Claims: authCode.Claims,
|
||||||
Nonce: authCode.Nonce,
|
Nonce: authCode.Nonce,
|
||||||
|
ConnectorData: authCode.ConnectorData,
|
||||||
}
|
}
|
||||||
if err := s.storage.CreateRefresh(refresh); err != nil {
|
if err := s.storage.CreateRefresh(refresh); err != nil {
|
||||||
log.Printf("failed to create refresh token: %v", err)
|
log.Printf("failed to create refresh token: %v", err)
|
||||||
@ -574,6 +554,10 @@ func (s *Server) handleRefreshToken(w http.ResponseWriter, r *http.Request, clie
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Per the OAuth2 spec, if the client has omitted the scopes, default to the original
|
||||||
|
// authorized scopes.
|
||||||
|
//
|
||||||
|
// https://tools.ietf.org/html/rfc6749#section-6
|
||||||
scopes := refresh.Scopes
|
scopes := refresh.Scopes
|
||||||
if scope != "" {
|
if scope != "" {
|
||||||
requestedScopes := strings.Fields(scope)
|
requestedScopes := strings.Fields(scope)
|
||||||
@ -601,7 +585,43 @@ func (s *Server) handleRefreshToken(w http.ResponseWriter, r *http.Request, clie
|
|||||||
scopes = requestedScopes
|
scopes = requestedScopes
|
||||||
}
|
}
|
||||||
|
|
||||||
// TODO(ericchiang): re-auth with backends
|
conn, ok := s.connectors[refresh.ConnectorID]
|
||||||
|
if !ok {
|
||||||
|
log.Printf("connector ID not found: %q", refresh.ConnectorID)
|
||||||
|
tokenErr(w, errServerError, "", http.StatusInternalServerError)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// Can the connector refresh the identity? If so, attempt to refresh the data
|
||||||
|
// in the connector.
|
||||||
|
//
|
||||||
|
// TODO(ericchiang): We may want a strict mode where connectors that don't implement
|
||||||
|
// this interface can't perform refreshing.
|
||||||
|
if refreshConn, ok := conn.Connector.(connector.RefreshConnector); ok {
|
||||||
|
ident := connector.Identity{
|
||||||
|
UserID: refresh.Claims.UserID,
|
||||||
|
Username: refresh.Claims.Username,
|
||||||
|
Email: refresh.Claims.Email,
|
||||||
|
EmailVerified: refresh.Claims.EmailVerified,
|
||||||
|
Groups: refresh.Claims.Groups,
|
||||||
|
ConnectorData: refresh.ConnectorData,
|
||||||
|
}
|
||||||
|
ident, err := refreshConn.Refresh(r.Context(), parseScopes(scopes), ident)
|
||||||
|
if err != nil {
|
||||||
|
log.Printf("failed to refresh identity: %v", err)
|
||||||
|
tokenErr(w, errServerError, "", http.StatusInternalServerError)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// Update the claims of the refresh token.
|
||||||
|
//
|
||||||
|
// UserID intentionally ignored for now.
|
||||||
|
refresh.Claims.Username = ident.Username
|
||||||
|
refresh.Claims.Email = ident.Email
|
||||||
|
refresh.Claims.EmailVerified = ident.EmailVerified
|
||||||
|
refresh.Claims.Groups = ident.Groups
|
||||||
|
refresh.ConnectorData = ident.ConnectorData
|
||||||
|
}
|
||||||
|
|
||||||
idToken, expiry, err := s.newIDToken(client.ID, refresh.Claims, scopes, refresh.Nonce)
|
idToken, expiry, err := s.newIDToken(client.ID, refresh.Claims, scopes, refresh.Nonce)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@ -610,6 +630,8 @@ func (s *Server) handleRefreshToken(w http.ResponseWriter, r *http.Request, clie
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Refresh tokens are claimed exactly once. Delete the current token and
|
||||||
|
// create a new one.
|
||||||
if err := s.storage.DeleteRefresh(code); err != nil {
|
if err := s.storage.DeleteRefresh(code); err != nil {
|
||||||
log.Printf("failed to delete auth code: %v", err)
|
log.Printf("failed to delete auth code: %v", err)
|
||||||
tokenErr(w, errServerError, "", http.StatusInternalServerError)
|
tokenErr(w, errServerError, "", http.StatusInternalServerError)
|
||||||
|
@ -10,6 +10,7 @@ import (
|
|||||||
"strings"
|
"strings"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
|
"github.com/coreos/dex/connector"
|
||||||
"github.com/coreos/dex/storage"
|
"github.com/coreos/dex/storage"
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -93,6 +94,19 @@ const (
|
|||||||
responseTypeIDToken = "id_token" // ID Token in url fragment
|
responseTypeIDToken = "id_token" // ID Token in url fragment
|
||||||
)
|
)
|
||||||
|
|
||||||
|
func parseScopes(scopes []string) connector.Scopes {
|
||||||
|
var s connector.Scopes
|
||||||
|
for _, scope := range scopes {
|
||||||
|
switch scope {
|
||||||
|
case scopeOfflineAccess:
|
||||||
|
s.OfflineAccess = true
|
||||||
|
case scopeGroups:
|
||||||
|
s.Groups = true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return s
|
||||||
|
}
|
||||||
|
|
||||||
type audience []string
|
type audience []string
|
||||||
|
|
||||||
func (a audience) MarshalJSON() ([]byte, error) {
|
func (a audience) MarshalJSON() ([]byte, error) {
|
||||||
|
@ -211,9 +211,7 @@ type passwordDB struct {
|
|||||||
s storage.Storage
|
s storage.Storage
|
||||||
}
|
}
|
||||||
|
|
||||||
func (db passwordDB) Close() error { return nil }
|
func (db passwordDB) Login(ctx context.Context, s connector.Scopes, email, password string) (connector.Identity, bool, error) {
|
||||||
|
|
||||||
func (db passwordDB) Login(email, password string) (connector.Identity, bool, error) {
|
|
||||||
p, err := db.s.GetPassword(email)
|
p, err := db.s.GetPassword(email)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
if err != storage.ErrNotFound {
|
if err != storage.ErrNotFound {
|
||||||
@ -233,6 +231,31 @@ func (db passwordDB) Login(email, password string) (connector.Identity, bool, er
|
|||||||
}, true, nil
|
}, true, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (db passwordDB) Refresh(ctx context.Context, s connector.Scopes, identity connector.Identity) (connector.Identity, error) {
|
||||||
|
// If the user has been deleted, the refresh token will be rejected.
|
||||||
|
p, err := db.s.GetPassword(identity.Email)
|
||||||
|
if err != nil {
|
||||||
|
if err == storage.ErrNotFound {
|
||||||
|
return connector.Identity{}, errors.New("user not found")
|
||||||
|
}
|
||||||
|
return connector.Identity{}, fmt.Errorf("get password: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// User removed but a new user with the same email exists.
|
||||||
|
if p.UserID != identity.UserID {
|
||||||
|
return connector.Identity{}, errors.New("user not found")
|
||||||
|
}
|
||||||
|
|
||||||
|
// If a user has updated their username, that will be reflected in the
|
||||||
|
// refreshed token.
|
||||||
|
//
|
||||||
|
// No other fields are expected to be refreshable as email is effectively used
|
||||||
|
// as an ID and this implementation doesn't deal with groups.
|
||||||
|
identity.Username = p.Username
|
||||||
|
|
||||||
|
return identity, nil
|
||||||
|
}
|
||||||
|
|
||||||
// newKeyCacher returns a storage which caches keys so long as the next
|
// newKeyCacher returns a storage which caches keys so long as the next
|
||||||
func newKeyCacher(s storage.Storage, now func() time.Time) storage.Storage {
|
func newKeyCacher(s storage.Storage, now func() time.Time) storage.Storage {
|
||||||
if now == nil {
|
if now == nil {
|
||||||
|
@ -662,7 +662,6 @@ func TestCrossClientScopes(t *testing.T) {
|
|||||||
func TestPasswordDB(t *testing.T) {
|
func TestPasswordDB(t *testing.T) {
|
||||||
s := memory.New()
|
s := memory.New()
|
||||||
conn := newPasswordDB(s)
|
conn := newPasswordDB(s)
|
||||||
defer conn.Close()
|
|
||||||
|
|
||||||
pw := "hi"
|
pw := "hi"
|
||||||
|
|
||||||
@ -712,7 +711,7 @@ func TestPasswordDB(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
for _, tc := range tests {
|
for _, tc := range tests {
|
||||||
ident, valid, err := conn.Login(tc.username, tc.password)
|
ident, valid, err := conn.Login(context.Background(), connector.Scopes{}, tc.username, tc.password)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
if !tc.wantErr {
|
if !tc.wantErr {
|
||||||
t.Errorf("%s: %v", tc.name, err)
|
t.Errorf("%s: %v", tc.name, err)
|
||||||
|
Reference in New Issue
Block a user