server: add an option to enable emails and passwords from the database
This commit is contained in:
		| @@ -3,12 +3,15 @@ package server | |||||||
| import ( | import ( | ||||||
| 	"errors" | 	"errors" | ||||||
| 	"fmt" | 	"fmt" | ||||||
|  | 	"log" | ||||||
| 	"net/http" | 	"net/http" | ||||||
| 	"net/url" | 	"net/url" | ||||||
| 	"path" | 	"path" | ||||||
| 	"sync/atomic" | 	"sync/atomic" | ||||||
| 	"time" | 	"time" | ||||||
|  |  | ||||||
|  | 	"golang.org/x/crypto/bcrypt" | ||||||
|  |  | ||||||
| 	"github.com/gorilla/mux" | 	"github.com/gorilla/mux" | ||||||
|  |  | ||||||
| 	"github.com/coreos/dex/connector" | 	"github.com/coreos/dex/connector" | ||||||
| @@ -44,6 +47,8 @@ type Config struct { | |||||||
| 	// If specified, the server will use this function for determining time. | 	// If specified, the server will use this function for determining time. | ||||||
| 	Now func() time.Time | 	Now func() time.Time | ||||||
|  |  | ||||||
|  | 	EnablePasswordDB bool | ||||||
|  |  | ||||||
| 	TemplateConfig TemplateConfig | 	TemplateConfig TemplateConfig | ||||||
| } | } | ||||||
|  |  | ||||||
| @@ -91,6 +96,14 @@ func newServer(c Config, rotationStrategy rotationStrategy) (*Server, error) { | |||||||
| 	if err != nil { | 	if err != nil { | ||||||
| 		return nil, fmt.Errorf("server: can't parse issuer URL") | 		return nil, fmt.Errorf("server: can't parse issuer URL") | ||||||
| 	} | 	} | ||||||
|  | 	if c.EnablePasswordDB { | ||||||
|  | 		c.Connectors = append(c.Connectors, Connector{ | ||||||
|  | 			ID:          "local", | ||||||
|  | 			DisplayName: "Email", | ||||||
|  | 			Connector:   newPasswordDB(c.Storage), | ||||||
|  | 		}) | ||||||
|  | 	} | ||||||
|  |  | ||||||
| 	if len(c.Connectors) == 0 { | 	if len(c.Connectors) == 0 { | ||||||
| 		return nil, errors.New("server: no connectors specified") | 		return nil, errors.New("server: no connectors specified") | ||||||
| 	} | 	} | ||||||
| @@ -182,6 +195,38 @@ func (s *Server) absURL(pathItems ...string) string { | |||||||
| 	return u.String() | 	return u.String() | ||||||
| } | } | ||||||
|  |  | ||||||
|  | func newPasswordDB(s storage.Storage) interface { | ||||||
|  | 	connector.Connector | ||||||
|  | 	connector.PasswordConnector | ||||||
|  | } { | ||||||
|  | 	return passwordDB{s} | ||||||
|  | } | ||||||
|  |  | ||||||
|  | type passwordDB struct { | ||||||
|  | 	s storage.Storage | ||||||
|  | } | ||||||
|  |  | ||||||
|  | func (db passwordDB) Close() error { return nil } | ||||||
|  |  | ||||||
|  | func (db passwordDB) Login(email, password string) (connector.Identity, bool, error) { | ||||||
|  | 	p, err := db.s.GetPassword(email) | ||||||
|  | 	if err != nil { | ||||||
|  | 		if err != storage.ErrNotFound { | ||||||
|  | 			log.Printf("get password: %v", err) | ||||||
|  | 		} | ||||||
|  | 		return connector.Identity{}, false, err | ||||||
|  | 	} | ||||||
|  | 	if err := bcrypt.CompareHashAndPassword(p.Hash, []byte(password)); err != nil { | ||||||
|  | 		return connector.Identity{}, false, nil | ||||||
|  | 	} | ||||||
|  | 	return connector.Identity{ | ||||||
|  | 		UserID:        p.UserID, | ||||||
|  | 		Username:      p.Username, | ||||||
|  | 		Email:         p.Email, | ||||||
|  | 		EmailVerified: true, | ||||||
|  | 	}, true, nil | ||||||
|  | } | ||||||
|  |  | ||||||
| // newKeyCacher returns a storage which caches keys so long as the next | // newKeyCacher returns a storage which caches keys so long as the next | ||||||
| func newKeyCacher(s storage.Storage, now func() time.Time) storage.Storage { | func newKeyCacher(s storage.Storage, now func() time.Time) storage.Storage { | ||||||
| 	if now == nil { | 	if now == nil { | ||||||
|   | |||||||
| @@ -16,9 +16,12 @@ import ( | |||||||
| 	"time" | 	"time" | ||||||
|  |  | ||||||
| 	"github.com/ericchiang/oidc" | 	"github.com/ericchiang/oidc" | ||||||
|  | 	"github.com/kylelemons/godebug/pretty" | ||||||
|  | 	"golang.org/x/crypto/bcrypt" | ||||||
| 	"golang.org/x/net/context" | 	"golang.org/x/net/context" | ||||||
| 	"golang.org/x/oauth2" | 	"golang.org/x/oauth2" | ||||||
|  |  | ||||||
|  | 	"github.com/coreos/dex/connector" | ||||||
| 	"github.com/coreos/dex/connector/mock" | 	"github.com/coreos/dex/connector/mock" | ||||||
| 	"github.com/coreos/dex/storage" | 	"github.com/coreos/dex/storage" | ||||||
| 	"github.com/coreos/dex/storage/memory" | 	"github.com/coreos/dex/storage/memory" | ||||||
| @@ -381,6 +384,91 @@ func TestOAuth2ImplicitFlow(t *testing.T) { | |||||||
| 	} | 	} | ||||||
| } | } | ||||||
|  |  | ||||||
|  | func TestPasswordDB(t *testing.T) { | ||||||
|  | 	s := memory.New() | ||||||
|  | 	conn := newPasswordDB(s) | ||||||
|  | 	defer conn.Close() | ||||||
|  |  | ||||||
|  | 	pw := "hi" | ||||||
|  |  | ||||||
|  | 	h, err := bcrypt.GenerateFromPassword([]byte(pw), bcrypt.MinCost) | ||||||
|  | 	if err != nil { | ||||||
|  | 		t.Fatal(err) | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	s.CreatePassword(storage.Password{ | ||||||
|  | 		Email:    "jane@example.com", | ||||||
|  | 		Username: "jane", | ||||||
|  | 		UserID:   "foobar", | ||||||
|  | 		Hash:     h, | ||||||
|  | 	}) | ||||||
|  |  | ||||||
|  | 	tests := []struct { | ||||||
|  | 		name         string | ||||||
|  | 		username     string | ||||||
|  | 		password     string | ||||||
|  | 		wantIdentity connector.Identity | ||||||
|  | 		wantInvalid  bool | ||||||
|  | 		wantErr      bool | ||||||
|  | 	}{ | ||||||
|  | 		{ | ||||||
|  | 			name:     "valid password", | ||||||
|  | 			username: "jane@example.com", | ||||||
|  | 			password: pw, | ||||||
|  | 			wantIdentity: connector.Identity{ | ||||||
|  | 				Email:         "jane@example.com", | ||||||
|  | 				Username:      "jane", | ||||||
|  | 				UserID:        "foobar", | ||||||
|  | 				EmailVerified: true, | ||||||
|  | 			}, | ||||||
|  | 		}, | ||||||
|  | 		{ | ||||||
|  | 			name:     "unknown user", | ||||||
|  | 			username: "john@example.com", | ||||||
|  | 			password: pw, | ||||||
|  | 			wantErr:  true, | ||||||
|  | 		}, | ||||||
|  | 		{ | ||||||
|  | 			name:        "invalid password", | ||||||
|  | 			username:    "jane@example.com", | ||||||
|  | 			password:    "not the correct password", | ||||||
|  | 			wantInvalid: true, | ||||||
|  | 		}, | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	for _, tc := range tests { | ||||||
|  | 		ident, valid, err := conn.Login(tc.username, tc.password) | ||||||
|  | 		if err != nil { | ||||||
|  | 			if !tc.wantErr { | ||||||
|  | 				t.Errorf("%s: %v", tc.name, err) | ||||||
|  | 			} | ||||||
|  | 			continue | ||||||
|  | 		} | ||||||
|  |  | ||||||
|  | 		if tc.wantErr { | ||||||
|  | 			t.Errorf("%s: expected error", tc.name) | ||||||
|  | 			continue | ||||||
|  | 		} | ||||||
|  |  | ||||||
|  | 		if !valid { | ||||||
|  | 			if !tc.wantInvalid { | ||||||
|  | 				t.Errorf("%s: expected valid password", tc.name) | ||||||
|  | 			} | ||||||
|  | 			continue | ||||||
|  | 		} | ||||||
|  |  | ||||||
|  | 		if tc.wantInvalid { | ||||||
|  | 			t.Errorf("%s: expected invalid password", tc.name) | ||||||
|  | 			continue | ||||||
|  | 		} | ||||||
|  |  | ||||||
|  | 		if diff := pretty.Compare(tc.wantIdentity, ident); diff != "" { | ||||||
|  | 			t.Errorf("%s: %s", tc.name, diff) | ||||||
|  | 		} | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | } | ||||||
|  |  | ||||||
| type storageWithKeysTrigger struct { | type storageWithKeysTrigger struct { | ||||||
| 	storage.Storage | 	storage.Storage | ||||||
| 	f func() | 	f func() | ||||||
|   | |||||||
		Reference in New Issue
	
	Block a user