connector/ldap: use gopkg.in/ldap.v2's escape filter
Use the escape filter method provided by the upstream LDAP package instead of rolling our own.
This commit is contained in:
		| @@ -2,17 +2,13 @@ | |||||||
| package ldap | package ldap | ||||||
|  |  | ||||||
| import ( | import ( | ||||||
| 	"bytes" |  | ||||||
| 	"crypto/tls" | 	"crypto/tls" | ||||||
| 	"crypto/x509" | 	"crypto/x509" | ||||||
| 	"encoding/hex" |  | ||||||
| 	"encoding/json" | 	"encoding/json" | ||||||
| 	"fmt" | 	"fmt" | ||||||
| 	"io/ioutil" | 	"io/ioutil" | ||||||
| 	"log" | 	"log" | ||||||
| 	"net" | 	"net" | ||||||
| 	"strings" |  | ||||||
| 	"unicode" |  | ||||||
|  |  | ||||||
| 	"gopkg.in/ldap.v2" | 	"gopkg.in/ldap.v2" | ||||||
|  |  | ||||||
| @@ -134,43 +130,6 @@ func parseScope(s string) (int, bool) { | |||||||
| 	return 0, false | 	return 0, false | ||||||
| } | } | ||||||
|  |  | ||||||
| // escapeRune maps a rune to a hex encoded value. For example 'é' would become '\\c3\\a9' |  | ||||||
| func escapeRune(buff *bytes.Buffer, r rune) { |  | ||||||
| 	// Really inefficient, but it seems correct. |  | ||||||
| 	for _, b := range []byte(string(r)) { |  | ||||||
| 		buff.WriteString("\\") |  | ||||||
| 		buff.WriteString(hex.EncodeToString([]byte{b})) |  | ||||||
| 	} |  | ||||||
| } |  | ||||||
|  |  | ||||||
| // NOTE(ericchiang): There are no good documents on how to escape an LDAP string. |  | ||||||
| // This implementation is inspired by an Oracle document, and is purposefully |  | ||||||
| // extremely restrictive. |  | ||||||
| // |  | ||||||
| // See: https://docs.oracle.com/cd/E19424-01/820-4811/gdxpo/index.html |  | ||||||
| func escapeFilter(s string) string { |  | ||||||
| 	r := strings.NewReader(s) |  | ||||||
| 	buff := new(bytes.Buffer) |  | ||||||
| 	for { |  | ||||||
| 		ru, _, err := r.ReadRune() |  | ||||||
| 		if err != nil { |  | ||||||
| 			// ignore decoding issues |  | ||||||
| 			return buff.String() |  | ||||||
| 		} |  | ||||||
|  |  | ||||||
| 		switch { |  | ||||||
| 		case ru > unicode.MaxASCII: // Not ASCII |  | ||||||
| 			escapeRune(buff, ru) |  | ||||||
| 		case !unicode.IsPrint(ru): // Not printable |  | ||||||
| 			escapeRune(buff, ru) |  | ||||||
| 		case strings.ContainsRune(`*\()`, ru): // Reserved characters |  | ||||||
| 			escapeRune(buff, ru) |  | ||||||
| 		default: |  | ||||||
| 			buff.WriteRune(ru) |  | ||||||
| 		} |  | ||||||
| 	} |  | ||||||
| } |  | ||||||
|  |  | ||||||
| // Open returns an authentication strategy using LDAP. | // Open returns an authentication strategy using LDAP. | ||||||
| func (c *Config) Open() (connector.Connector, error) { | func (c *Config) Open() (connector.Connector, error) { | ||||||
| 	conn, err := c.OpenConnector() | 	conn, err := c.OpenConnector() | ||||||
| @@ -302,7 +261,7 @@ func (c *ldapConnector) Login(username, password string) (ident connector.Identi | |||||||
| 		user          ldap.Entry | 		user          ldap.Entry | ||||||
| 	) | 	) | ||||||
|  |  | ||||||
| 	filter := fmt.Sprintf("(%s=%s)", c.UserSearch.Username, escapeFilter(username)) | 	filter := fmt.Sprintf("(%s=%s)", c.UserSearch.Username, ldap.EscapeFilter(username)) | ||||||
| 	if c.UserSearch.Filter != "" { | 	if c.UserSearch.Filter != "" { | ||||||
| 		filter = fmt.Sprintf("(&%s%s)", c.UserSearch.Filter, filter) | 		filter = fmt.Sprintf("(&%s%s)", c.UserSearch.Filter, filter) | ||||||
| 	} | 	} | ||||||
| @@ -402,7 +361,7 @@ func (c *ldapConnector) Groups(ident connector.Identity) ([]string, error) { | |||||||
| 		return nil, fmt.Errorf("ldap: failed to unmarshal connector data: %v", err) | 		return nil, fmt.Errorf("ldap: failed to unmarshal connector data: %v", err) | ||||||
| 	} | 	} | ||||||
|  |  | ||||||
| 	filter := fmt.Sprintf("(%s=%s)", c.GroupSearch.GroupAttr, escapeFilter(getAttr(user, c.GroupSearch.UserAttr))) | 	filter := fmt.Sprintf("(%s=%s)", c.GroupSearch.GroupAttr, ldap.EscapeFilter(getAttr(user, c.GroupSearch.UserAttr))) | ||||||
| 	if c.GroupSearch.Filter != "" { | 	if c.GroupSearch.Filter != "" { | ||||||
| 		filter = fmt.Sprintf("(&%s%s)", c.GroupSearch.Filter, filter) | 		filter = fmt.Sprintf("(&%s%s)", c.GroupSearch.Filter, filter) | ||||||
| 	} | 	} | ||||||
|   | |||||||
| @@ -1,23 +0,0 @@ | |||||||
| package ldap |  | ||||||
|  |  | ||||||
| import "testing" |  | ||||||
|  |  | ||||||
| func TestEscapeFilter(t *testing.T) { |  | ||||||
| 	tests := []struct { |  | ||||||
| 		val  string |  | ||||||
| 		want string |  | ||||||
| 	}{ |  | ||||||
| 		{"Five*Star", "Five\\2aStar"}, |  | ||||||
| 		{"c:\\File", "c:\\5cFile"}, |  | ||||||
| 		{"John (2nd)", "John \\282nd\\29"}, |  | ||||||
| 		{string([]byte{0, 0, 0, 4}), "\\00\\00\\00\\04"}, |  | ||||||
| 		{"Chloé", "Chlo\\c3\\a9"}, |  | ||||||
| 	} |  | ||||||
|  |  | ||||||
| 	for _, tc := range tests { |  | ||||||
| 		got := escapeFilter(tc.val) |  | ||||||
| 		if tc.want != got { |  | ||||||
| 			t.Errorf("value %q want=%q, got=%q", tc.val, tc.want, got) |  | ||||||
| 		} |  | ||||||
| 	} |  | ||||||
| } |  | ||||||
		Reference in New Issue
	
	Block a user