connector: add RefreshConnector interface
This commit is contained in:
@@ -10,6 +10,7 @@ import (
|
||||
"log"
|
||||
"net"
|
||||
|
||||
"golang.org/x/net/context"
|
||||
"gopkg.in/ldap.v2"
|
||||
|
||||
"github.com/coreos/dex/connector"
|
||||
@@ -57,6 +58,9 @@ type Config struct {
|
||||
// Required if LDAP host does not use TLS.
|
||||
InsecureNoSSL bool `json:"insecureNoSSL"`
|
||||
|
||||
// Don't verify the CA.
|
||||
InsecureSkipVerify bool `json:"insecureSkipVerify"`
|
||||
|
||||
// Path to a trusted root certificate file.
|
||||
RootCA string `json:"rootCA"`
|
||||
|
||||
@@ -139,11 +143,16 @@ func (c *Config) Open() (connector.Connector, error) {
|
||||
return connector.Connector(conn), nil
|
||||
}
|
||||
|
||||
type refreshData struct {
|
||||
Username string `json:"username"`
|
||||
Entry ldap.Entry `json:"entry"`
|
||||
}
|
||||
|
||||
// OpenConnector is the same as Open but returns a type with all implemented connector interfaces.
|
||||
func (c *Config) OpenConnector() (interface {
|
||||
connector.Connector
|
||||
connector.PasswordConnector
|
||||
connector.GroupsConnector
|
||||
connector.RefreshConnector
|
||||
}, error) {
|
||||
|
||||
requiredFields := []struct {
|
||||
@@ -174,7 +183,7 @@ func (c *Config) OpenConnector() (interface {
|
||||
}
|
||||
}
|
||||
|
||||
tlsConfig := &tls.Config{ServerName: host}
|
||||
tlsConfig := &tls.Config{ServerName: host, InsecureSkipVerify: c.InsecureSkipVerify}
|
||||
if c.RootCA != "" || len(c.RootCAData) != 0 {
|
||||
data := c.RootCAData
|
||||
if len(data) == 0 {
|
||||
@@ -209,12 +218,16 @@ type ldapConnector struct {
|
||||
tlsConfig *tls.Config
|
||||
}
|
||||
|
||||
var _ connector.PasswordConnector = (*ldapConnector)(nil)
|
||||
var (
|
||||
_ connector.PasswordConnector = (*ldapConnector)(nil)
|
||||
_ connector.RefreshConnector = (*ldapConnector)(nil)
|
||||
)
|
||||
|
||||
// do initializes a connection to the LDAP directory and passes it to the
|
||||
// provided function. It then performs appropriate teardown or reuse before
|
||||
// returning.
|
||||
func (c *ldapConnector) do(f func(c *ldap.Conn) error) error {
|
||||
func (c *ldapConnector) do(ctx context.Context, f func(c *ldap.Conn) error) error {
|
||||
// TODO(ericchiang): support context here
|
||||
var (
|
||||
conn *ldap.Conn
|
||||
err error
|
||||
@@ -253,13 +266,32 @@ func getAttr(e ldap.Entry, name string) string {
|
||||
return ""
|
||||
}
|
||||
|
||||
func (c *ldapConnector) Login(username, password string) (ident connector.Identity, validPass bool, err error) {
|
||||
var (
|
||||
// We want to return a different error if the user's password is incorrect vs
|
||||
// if there was an error.
|
||||
incorrectPass = false
|
||||
user ldap.Entry
|
||||
)
|
||||
func (c *ldapConnector) identityFromEntry(user ldap.Entry) (ident connector.Identity, err error) {
|
||||
// If we're missing any attributes, such as email or ID, we want to report
|
||||
// an error rather than continuing.
|
||||
missing := []string{}
|
||||
|
||||
// Fill the identity struct using the attributes from the user entry.
|
||||
if ident.UserID = getAttr(user, c.UserSearch.IDAttr); ident.UserID == "" {
|
||||
missing = append(missing, c.UserSearch.IDAttr)
|
||||
}
|
||||
if ident.Email = getAttr(user, c.UserSearch.EmailAttr); ident.Email == "" {
|
||||
missing = append(missing, c.UserSearch.EmailAttr)
|
||||
}
|
||||
if c.UserSearch.NameAttr != "" {
|
||||
if ident.Username = getAttr(user, c.UserSearch.NameAttr); ident.Username == "" {
|
||||
missing = append(missing, c.UserSearch.NameAttr)
|
||||
}
|
||||
}
|
||||
|
||||
if len(missing) != 0 {
|
||||
err := fmt.Errorf("ldap: entry %q missing following required attribute(s): %q", user.DN, missing)
|
||||
return connector.Identity{}, err
|
||||
}
|
||||
return ident, nil
|
||||
}
|
||||
|
||||
func (c *ldapConnector) userEntry(conn *ldap.Conn, username string) (user ldap.Entry, found bool, err error) {
|
||||
|
||||
filter := fmt.Sprintf("(%s=%s)", c.UserSearch.Username, ldap.EscapeFilter(username))
|
||||
if c.UserSearch.Filter != "" {
|
||||
@@ -283,24 +315,40 @@ func (c *ldapConnector) Login(username, password string) (ident connector.Identi
|
||||
if c.UserSearch.NameAttr != "" {
|
||||
req.Attributes = append(req.Attributes, c.UserSearch.NameAttr)
|
||||
}
|
||||
resp, err := conn.Search(req)
|
||||
if err != nil {
|
||||
return ldap.Entry{}, false, fmt.Errorf("ldap: search with filter %q failed: %v", req.Filter, err)
|
||||
}
|
||||
|
||||
err = c.do(func(conn *ldap.Conn) error {
|
||||
resp, err := conn.Search(req)
|
||||
switch n := len(resp.Entries); n {
|
||||
case 0:
|
||||
log.Printf("ldap: no results returned for filter: %q", filter)
|
||||
return ldap.Entry{}, false, nil
|
||||
case 1:
|
||||
return *resp.Entries[0], true, nil
|
||||
default:
|
||||
return ldap.Entry{}, false, fmt.Errorf("ldap: filter returned multiple (%d) results: %q", n, filter)
|
||||
}
|
||||
}
|
||||
|
||||
func (c *ldapConnector) Login(ctx context.Context, s connector.Scopes, username, password string) (ident connector.Identity, validPass bool, err error) {
|
||||
var (
|
||||
// We want to return a different error if the user's password is incorrect vs
|
||||
// if there was an error.
|
||||
incorrectPass = false
|
||||
user ldap.Entry
|
||||
)
|
||||
|
||||
err = c.do(ctx, func(conn *ldap.Conn) error {
|
||||
entry, found, err := c.userEntry(conn, username)
|
||||
if err != nil {
|
||||
return fmt.Errorf("ldap: search with filter %q failed: %v", req.Filter, err)
|
||||
return err
|
||||
}
|
||||
|
||||
switch n := len(resp.Entries); n {
|
||||
case 0:
|
||||
log.Printf("ldap: no results returned for filter: %q", filter)
|
||||
if !found {
|
||||
incorrectPass = true
|
||||
return nil
|
||||
case 1:
|
||||
default:
|
||||
return fmt.Errorf("ldap: filter returned multiple (%d) results: %q", n, filter)
|
||||
}
|
||||
|
||||
user = *resp.Entries[0]
|
||||
user = entry
|
||||
|
||||
// Try to authenticate as the distinguished name.
|
||||
if err := conn.Bind(user.DN, password); err != nil {
|
||||
@@ -323,44 +371,75 @@ func (c *ldapConnector) Login(username, password string) (ident connector.Identi
|
||||
return connector.Identity{}, false, nil
|
||||
}
|
||||
|
||||
// Encode entry for follow up requests such as the groups query and
|
||||
// refresh attempts.
|
||||
if ident.ConnectorData, err = json.Marshal(user); err != nil {
|
||||
return connector.Identity{}, false, fmt.Errorf("ldap: marshal entry: %v", err)
|
||||
}
|
||||
|
||||
// If we're missing any attributes, such as email or ID, we want to report
|
||||
// an error rather than continuing.
|
||||
missing := []string{}
|
||||
|
||||
// Fill the identity struct using the attributes from the user entry.
|
||||
if ident.UserID = getAttr(user, c.UserSearch.IDAttr); ident.UserID == "" {
|
||||
missing = append(missing, c.UserSearch.IDAttr)
|
||||
}
|
||||
if ident.Email = getAttr(user, c.UserSearch.EmailAttr); ident.Email == "" {
|
||||
missing = append(missing, c.UserSearch.EmailAttr)
|
||||
}
|
||||
if c.UserSearch.NameAttr != "" {
|
||||
if ident.Username = getAttr(user, c.UserSearch.NameAttr); ident.Username == "" {
|
||||
missing = append(missing, c.UserSearch.NameAttr)
|
||||
}
|
||||
}
|
||||
|
||||
if len(missing) != 0 {
|
||||
err := fmt.Errorf("ldap: entry %q missing following required attribute(s): %q", user.DN, missing)
|
||||
if ident, err = c.identityFromEntry(user); err != nil {
|
||||
return connector.Identity{}, false, err
|
||||
}
|
||||
|
||||
if s.Groups {
|
||||
groups, err := c.groups(ctx, user)
|
||||
if err != nil {
|
||||
return connector.Identity{}, false, fmt.Errorf("ldap: failed to query groups: %v", err)
|
||||
}
|
||||
ident.Groups = groups
|
||||
}
|
||||
|
||||
if s.OfflineAccess {
|
||||
refresh := refreshData{
|
||||
Username: username,
|
||||
Entry: user,
|
||||
}
|
||||
// Encode entry for follow up requests such as the groups query and
|
||||
// refresh attempts.
|
||||
if ident.ConnectorData, err = json.Marshal(refresh); err != nil {
|
||||
return connector.Identity{}, false, fmt.Errorf("ldap: marshal entry: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
return ident, true, nil
|
||||
}
|
||||
|
||||
func (c *ldapConnector) Groups(ident connector.Identity) ([]string, error) {
|
||||
// Decode the user entry from the identity.
|
||||
var user ldap.Entry
|
||||
if err := json.Unmarshal(ident.ConnectorData, &user); err != nil {
|
||||
return nil, fmt.Errorf("ldap: failed to unmarshal connector data: %v", err)
|
||||
func (c *ldapConnector) Refresh(ctx context.Context, s connector.Scopes, ident connector.Identity) (connector.Identity, error) {
|
||||
var data refreshData
|
||||
if err := json.Unmarshal(ident.ConnectorData, &data); err != nil {
|
||||
return ident, fmt.Errorf("ldap: failed to unamrshal internal data: %v", err)
|
||||
}
|
||||
|
||||
var user ldap.Entry
|
||||
err := c.do(ctx, func(conn *ldap.Conn) error {
|
||||
entry, found, err := c.userEntry(conn, data.Username)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if !found {
|
||||
return fmt.Errorf("ldap: user not found %q", data.Username)
|
||||
}
|
||||
user = entry
|
||||
return nil
|
||||
})
|
||||
if err != nil {
|
||||
return ident, err
|
||||
}
|
||||
if user.DN != data.Entry.DN {
|
||||
return ident, fmt.Errorf("ldap: refresh for username %q expected DN %q got %q", data.Username, data.Entry.DN, user.DN)
|
||||
}
|
||||
|
||||
newIdent, err := c.identityFromEntry(user)
|
||||
if err != nil {
|
||||
return ident, err
|
||||
}
|
||||
newIdent.ConnectorData = ident.ConnectorData
|
||||
|
||||
if s.Groups {
|
||||
groups, err := c.groups(ctx, user)
|
||||
if err != nil {
|
||||
return connector.Identity{}, fmt.Errorf("ldap: failed to query groups: %v", err)
|
||||
}
|
||||
newIdent.Groups = groups
|
||||
}
|
||||
return newIdent, nil
|
||||
}
|
||||
|
||||
func (c *ldapConnector) groups(ctx context.Context, user ldap.Entry) ([]string, error) {
|
||||
filter := fmt.Sprintf("(%s=%s)", c.GroupSearch.GroupAttr, ldap.EscapeFilter(getAttr(user, c.GroupSearch.UserAttr)))
|
||||
if c.GroupSearch.Filter != "" {
|
||||
filter = fmt.Sprintf("(&%s%s)", c.GroupSearch.Filter, filter)
|
||||
@@ -374,7 +453,7 @@ func (c *ldapConnector) Groups(ident connector.Identity) ([]string, error) {
|
||||
}
|
||||
|
||||
var groups []*ldap.Entry
|
||||
if err := c.do(func(conn *ldap.Conn) error {
|
||||
if err := c.do(ctx, func(conn *ldap.Conn) error {
|
||||
resp, err := conn.Search(req)
|
||||
if err != nil {
|
||||
return fmt.Errorf("ldap: search failed: %v", err)
|
||||
@@ -406,7 +485,3 @@ func (c *ldapConnector) Groups(ident connector.Identity) ([]string, error) {
|
||||
}
|
||||
return groupNames, nil
|
||||
}
|
||||
|
||||
func (c *ldapConnector) Close() error {
|
||||
return nil
|
||||
}
|
||||
|
Reference in New Issue
Block a user