storage/static.go: storage backend should not explicitly lower-case email ids.
This commit is contained in:
		| @@ -144,7 +144,7 @@ func serve(cmd *cobra.Command, args []string) error { | |||||||
| 		for i, p := range c.StaticPasswords { | 		for i, p := range c.StaticPasswords { | ||||||
| 			passwords[i] = storage.Password(p) | 			passwords[i] = storage.Password(p) | ||||||
| 		} | 		} | ||||||
| 		s = storage.WithStaticPasswords(s, passwords) | 		s = storage.WithStaticPasswords(s, passwords, logger) | ||||||
| 	} | 	} | ||||||
|  |  | ||||||
| 	storageConnectors := make([]storage.Connector, len(c.StaticConnectors)) | 	storageConnectors := make([]storage.Connector, len(c.StaticConnectors)) | ||||||
|   | |||||||
| @@ -128,12 +128,12 @@ func (s *memStorage) CreateAuthRequest(a storage.AuthRequest) (err error) { | |||||||
| } | } | ||||||
|  |  | ||||||
| func (s *memStorage) CreatePassword(p storage.Password) (err error) { | func (s *memStorage) CreatePassword(p storage.Password) (err error) { | ||||||
| 	p.Email = strings.ToLower(p.Email) | 	lowerEmail := strings.ToLower(p.Email) | ||||||
| 	s.tx(func() { | 	s.tx(func() { | ||||||
| 		if _, ok := s.passwords[p.Email]; ok { | 		if _, ok := s.passwords[lowerEmail]; ok { | ||||||
| 			err = storage.ErrAlreadyExists | 			err = storage.ErrAlreadyExists | ||||||
| 		} else { | 		} else { | ||||||
| 			s.passwords[p.Email] = p | 			s.passwords[lowerEmail] = p | ||||||
| 		} | 		} | ||||||
| 	}) | 	}) | ||||||
| 	return | 	return | ||||||
|   | |||||||
| @@ -108,9 +108,10 @@ func TestStaticPasswords(t *testing.T) { | |||||||
| 	p1 := storage.Password{Email: "foo@example.com", Username: "foo_secret"} | 	p1 := storage.Password{Email: "foo@example.com", Username: "foo_secret"} | ||||||
| 	p2 := storage.Password{Email: "bar@example.com", Username: "bar_secret"} | 	p2 := storage.Password{Email: "bar@example.com", Username: "bar_secret"} | ||||||
| 	p3 := storage.Password{Email: "spam@example.com", Username: "spam_secret"} | 	p3 := storage.Password{Email: "spam@example.com", Username: "spam_secret"} | ||||||
|  | 	p4 := storage.Password{Email: "Spam@example.com", Username: "Spam_secret"} | ||||||
|  |  | ||||||
| 	backing.CreatePassword(p1) | 	backing.CreatePassword(p1) | ||||||
| 	s := storage.WithStaticPasswords(backing, []storage.Password{p2}) | 	s := storage.WithStaticPasswords(backing, []storage.Password{p2}, logger) | ||||||
|  |  | ||||||
| 	tests := []struct { | 	tests := []struct { | ||||||
| 		name    string | 		name    string | ||||||
| @@ -159,6 +160,29 @@ func TestStaticPasswords(t *testing.T) { | |||||||
| 				return s.UpdatePassword(p1.Email, updater) | 				return s.UpdatePassword(p1.Email, updater) | ||||||
| 			}, | 			}, | ||||||
| 		}, | 		}, | ||||||
|  | 		{ | ||||||
|  | 			name: "create passwords", | ||||||
|  | 			action: func() error { | ||||||
|  | 				if err := s.CreatePassword(p4); err != nil { | ||||||
|  | 					return err | ||||||
|  | 				} | ||||||
|  | 				return s.CreatePassword(p3) | ||||||
|  | 			}, | ||||||
|  | 			wantErr: true, | ||||||
|  | 		}, | ||||||
|  | 		{ | ||||||
|  | 			name: "get password", | ||||||
|  | 			action: func() error { | ||||||
|  | 				p, err := s.GetPassword(p4.Email) | ||||||
|  | 				if err != nil { | ||||||
|  | 					return err | ||||||
|  | 				} | ||||||
|  | 				if strings.Compare(p.Email, p4.Email) != 0 { | ||||||
|  | 					return fmt.Errorf("expected %s passwords got %s", p4.Email, p.Email) | ||||||
|  | 				} | ||||||
|  | 				return nil | ||||||
|  | 			}, | ||||||
|  | 		}, | ||||||
| 		{ | 		{ | ||||||
| 			name: "list passwords", | 			name: "list passwords", | ||||||
| 			action: func() error { | 			action: func() error { | ||||||
| @@ -166,18 +190,12 @@ func TestStaticPasswords(t *testing.T) { | |||||||
| 				if err != nil { | 				if err != nil { | ||||||
| 					return err | 					return err | ||||||
| 				} | 				} | ||||||
| 				if n := len(passwords); n != 2 { | 				if n := len(passwords); n != 3 { | ||||||
| 					return fmt.Errorf("expected 2 passwords got %d", n) | 					return fmt.Errorf("expected 3 passwords got %d", n) | ||||||
| 				} | 				} | ||||||
| 				return nil | 				return nil | ||||||
| 			}, | 			}, | ||||||
| 		}, | 		}, | ||||||
| 		{ |  | ||||||
| 			name: "create password", |  | ||||||
| 			action: func() error { |  | ||||||
| 				return s.CreatePassword(p3) |  | ||||||
| 			}, |  | ||||||
| 		}, |  | ||||||
| 	} | 	} | ||||||
|  |  | ||||||
| 	for _, tc := range tests { | 	for _, tc := range tests { | ||||||
|   | |||||||
| @@ -3,6 +3,8 @@ package storage | |||||||
| import ( | import ( | ||||||
| 	"errors" | 	"errors" | ||||||
| 	"strings" | 	"strings" | ||||||
|  |  | ||||||
|  | 	"github.com/sirupsen/logrus" | ||||||
| ) | ) | ||||||
|  |  | ||||||
| // Tests for this code are in the "memory" package, since this package doesn't | // Tests for this code are in the "memory" package, since this package doesn't | ||||||
| @@ -25,6 +27,7 @@ func WithStaticClients(s Storage, staticClients []Client) Storage { | |||||||
| 	for _, client := range staticClients { | 	for _, client := range staticClients { | ||||||
| 		clientsByID[client.ID] = client | 		clientsByID[client.ID] = client | ||||||
| 	} | 	} | ||||||
|  |  | ||||||
| 	return staticClientsStorage{s, staticClients, clientsByID} | 	return staticClientsStorage{s, staticClients, clientsByID} | ||||||
| } | } | ||||||
|  |  | ||||||
| @@ -83,18 +86,25 @@ type staticPasswordsStorage struct { | |||||||
|  |  | ||||||
| 	// A read-only set of passwords. | 	// A read-only set of passwords. | ||||||
| 	passwords []Password | 	passwords []Password | ||||||
|  | 	// A map of passwords that is indexed by lower-case email ids | ||||||
| 	passwordsByEmail map[string]Password | 	passwordsByEmail map[string]Password | ||||||
|  |  | ||||||
|  | 	logger logrus.FieldLogger | ||||||
| } | } | ||||||
|  |  | ||||||
| // WithStaticPasswords returns a storage with a read-only set of passwords. Write actions, | // WithStaticPasswords returns a storage with a read-only set of passwords. | ||||||
| // such as creating other passwords, will fail. | func WithStaticPasswords(s Storage, staticPasswords []Password, logger logrus.FieldLogger) Storage { | ||||||
| func WithStaticPasswords(s Storage, staticPasswords []Password) Storage { |  | ||||||
| 	passwordsByEmail := make(map[string]Password, len(staticPasswords)) | 	passwordsByEmail := make(map[string]Password, len(staticPasswords)) | ||||||
| 	for _, p := range staticPasswords { | 	for _, p := range staticPasswords { | ||||||
| 		p.Email = strings.ToLower(p.Email) | 		//Enable case insensitive email comparison. | ||||||
| 		passwordsByEmail[p.Email] = p | 		lowerEmail := strings.ToLower(p.Email) | ||||||
|  | 		if _, ok := passwordsByEmail[lowerEmail]; ok { | ||||||
|  | 			logger.Errorf("Attempting to create StaticPasswords with the same email id: %s", p.Email) | ||||||
| 		} | 		} | ||||||
| 	return staticPasswordsStorage{s, staticPasswords, passwordsByEmail} | 		passwordsByEmail[lowerEmail] = p | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	return staticPasswordsStorage{s, staticPasswords, passwordsByEmail, logger} | ||||||
| } | } | ||||||
|  |  | ||||||
| func (s staticPasswordsStorage) isStatic(email string) bool { | func (s staticPasswordsStorage) isStatic(email string) bool { | ||||||
|   | |||||||
		Reference in New Issue
	
	Block a user