server: set sane bcrypt cost upper bound
This commit is contained in:
		@@ -21,8 +21,16 @@ import (
 | 
				
			|||||||
// to determine if the server supports specific features.
 | 
					// to determine if the server supports specific features.
 | 
				
			||||||
const apiVersion = 2
 | 
					const apiVersion = 2
 | 
				
			||||||
 | 
					
 | 
				
			||||||
// recCost is the recommended bcrypt cost, which balances hash strength and time
 | 
					const (
 | 
				
			||||||
const recCost = 12
 | 
						// recCost is the recommended bcrypt cost, which balances hash strength and
 | 
				
			||||||
 | 
						// efficiency.
 | 
				
			||||||
 | 
						recCost = 12
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						// upBoundCost is a sane upper bound on bcrypt cost determined by benchmarking:
 | 
				
			||||||
 | 
						// high enough to ensure secure encryption, low enough to not put unnecessary
 | 
				
			||||||
 | 
						// load on a dex server.
 | 
				
			||||||
 | 
						upBoundCost = 16
 | 
				
			||||||
 | 
					)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
// NewAPI returns a server which implements the gRPC API interface.
 | 
					// NewAPI returns a server which implements the gRPC API interface.
 | 
				
			||||||
func NewAPI(s storage.Storage, logger logrus.FieldLogger) api.DexServer {
 | 
					func NewAPI(s storage.Storage, logger logrus.FieldLogger) api.DexServer {
 | 
				
			||||||
@@ -83,16 +91,20 @@ func (d dexAPI) DeleteClient(ctx context.Context, req *api.DeleteClientReq) (*ap
 | 
				
			|||||||
	return &api.DeleteClientResp{}, nil
 | 
						return &api.DeleteClientResp{}, nil
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
// checkCost returns an error if the hash provided does not meet minimum cost requirement, and the actual bcrypt cost
 | 
					// checkCost returns an error if the hash provided does not meet lower or upper
 | 
				
			||||||
func checkCost(hash []byte) (int, error) {
 | 
					// bound cost requirements.
 | 
				
			||||||
 | 
					func checkCost(hash []byte) error {
 | 
				
			||||||
	actual, err := bcrypt.Cost(hash)
 | 
						actual, err := bcrypt.Cost(hash)
 | 
				
			||||||
	if err != nil {
 | 
						if err != nil {
 | 
				
			||||||
		return 0, fmt.Errorf("parsing bcrypt hash: %v", err)
 | 
							return fmt.Errorf("parsing bcrypt hash: %v", err)
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
	if actual < bcrypt.DefaultCost {
 | 
						if actual < bcrypt.DefaultCost {
 | 
				
			||||||
		return actual, fmt.Errorf("given hash cost = %d, does not meet minimum cost requirement = %d", actual, bcrypt.DefaultCost)
 | 
							return fmt.Errorf("given hash cost = %d does not meet minimum cost requirement = %d", actual, bcrypt.DefaultCost)
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
	return actual, nil
 | 
						if actual > upBoundCost {
 | 
				
			||||||
 | 
							return fmt.Errorf("given hash cost = %d is above upper bound cost = %d, recommended cost = %d", actual, upBoundCost, recCost)
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
						return nil
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
func (d dexAPI) CreatePassword(ctx context.Context, req *api.CreatePasswordReq) (*api.CreatePasswordResp, error) {
 | 
					func (d dexAPI) CreatePassword(ctx context.Context, req *api.CreatePasswordReq) (*api.CreatePasswordResp, error) {
 | 
				
			||||||
@@ -103,13 +115,9 @@ func (d dexAPI) CreatePassword(ctx context.Context, req *api.CreatePasswordReq)
 | 
				
			|||||||
		return nil, errors.New("no user ID supplied")
 | 
							return nil, errors.New("no user ID supplied")
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
	if req.Password.Hash != nil {
 | 
						if req.Password.Hash != nil {
 | 
				
			||||||
		cost, err := checkCost(req.Password.Hash)
 | 
							if err := checkCost(req.Password.Hash); err != nil {
 | 
				
			||||||
		if err != nil {
 | 
					 | 
				
			||||||
			return nil, err
 | 
								return nil, err
 | 
				
			||||||
		}
 | 
							}
 | 
				
			||||||
		if cost > recCost {
 | 
					 | 
				
			||||||
			d.logger.Warnln("bcrypt cost = %d, password encryption might timeout. Recommended bcrypt cost is 12", cost)
 | 
					 | 
				
			||||||
		}
 | 
					 | 
				
			||||||
	} else {
 | 
						} else {
 | 
				
			||||||
		return nil, errors.New("no hash of password supplied")
 | 
							return nil, errors.New("no hash of password supplied")
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
@@ -140,13 +148,9 @@ func (d dexAPI) UpdatePassword(ctx context.Context, req *api.UpdatePasswordReq)
 | 
				
			|||||||
	}
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	if req.NewHash != nil {
 | 
						if req.NewHash != nil {
 | 
				
			||||||
		cost, err := checkCost(req.NewHash)
 | 
							if err := checkCost(req.NewHash); err != nil {
 | 
				
			||||||
		if err != nil {
 | 
					 | 
				
			||||||
			return nil, err
 | 
								return nil, err
 | 
				
			||||||
		}
 | 
							}
 | 
				
			||||||
		if cost > recCost {
 | 
					 | 
				
			||||||
			d.logger.Warnln("bcrypt cost = %d, password encryption might timeout. Recommended bcrypt cost is 12", cost)
 | 
					 | 
				
			||||||
		}
 | 
					 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	updater := func(old storage.Password) (storage.Password, error) {
 | 
						updater := func(old storage.Password) (storage.Password, error) {
 | 
				
			||||||
 
 | 
				
			|||||||
@@ -132,16 +132,15 @@ func TestCheckCost(t *testing.T) {
 | 
				
			|||||||
	defer client.Close()
 | 
						defer client.Close()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	tests := []struct {
 | 
						tests := []struct {
 | 
				
			||||||
		name         string
 | 
							name      string
 | 
				
			||||||
		inputHash    []byte
 | 
							inputHash []byte
 | 
				
			||||||
		expectedCost int
 | 
					
 | 
				
			||||||
		wantErr      bool
 | 
							wantErr bool
 | 
				
			||||||
	}{
 | 
						}{
 | 
				
			||||||
		{
 | 
							{
 | 
				
			||||||
			name: "valid cost",
 | 
								name: "valid cost",
 | 
				
			||||||
			// bcrypt hash of the value "test1" with cost 12
 | 
								// bcrypt hash of the value "test1" with cost 12 (default)
 | 
				
			||||||
			inputHash:    []byte("$2a$12$M2Ot95Qty1MuQdubh1acWOiYadJDzeVg3ve4n5b.dgcgPdjCseKx2"),
 | 
								inputHash: []byte("$2a$12$M2Ot95Qty1MuQdubh1acWOiYadJDzeVg3ve4n5b.dgcgPdjCseKx2"),
 | 
				
			||||||
			expectedCost: recCost,
 | 
					 | 
				
			||||||
		},
 | 
							},
 | 
				
			||||||
		{
 | 
							{
 | 
				
			||||||
			name:      "invalid hash",
 | 
								name:      "invalid hash",
 | 
				
			||||||
@@ -156,15 +155,14 @@ func TestCheckCost(t *testing.T) {
 | 
				
			|||||||
		},
 | 
							},
 | 
				
			||||||
		{
 | 
							{
 | 
				
			||||||
			name: "cost above recommendation",
 | 
								name: "cost above recommendation",
 | 
				
			||||||
			// bcrypt hash of the value "test1" with cost 20
 | 
								// bcrypt hash of the value "test1" with cost 17
 | 
				
			||||||
			inputHash:    []byte("$2a$20$yODn5quqK9MZdePqYLs6Y.Jr4cOO1P0aXsKz0eTa2rxOmu8e7ETpi"),
 | 
								inputHash: []byte("$2a$17$tWuZkTxtSmRyWZAGWVHQE.7npdl.TgP8adjzLJD.SyjpFznKBftPe"),
 | 
				
			||||||
			expectedCost: 20,
 | 
								wantErr:   true,
 | 
				
			||||||
		},
 | 
							},
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	for _, tc := range tests {
 | 
						for _, tc := range tests {
 | 
				
			||||||
		cost, err := checkCost(tc.inputHash)
 | 
							if err := checkCost(tc.inputHash); err != nil {
 | 
				
			||||||
		if err != nil {
 | 
					 | 
				
			||||||
			if !tc.wantErr {
 | 
								if !tc.wantErr {
 | 
				
			||||||
				t.Errorf("%s: %s", tc.name, err)
 | 
									t.Errorf("%s: %s", tc.name, err)
 | 
				
			||||||
			}
 | 
								}
 | 
				
			||||||
@@ -175,10 +173,6 @@ func TestCheckCost(t *testing.T) {
 | 
				
			|||||||
			t.Errorf("%s: expected err", tc.name)
 | 
								t.Errorf("%s: expected err", tc.name)
 | 
				
			||||||
			continue
 | 
								continue
 | 
				
			||||||
		}
 | 
							}
 | 
				
			||||||
 | 
					 | 
				
			||||||
		if cost != tc.expectedCost {
 | 
					 | 
				
			||||||
			t.Errorf("%s: exepcted cost = %d but got cost = %d", tc.name, tc.expectedCost, cost)
 | 
					 | 
				
			||||||
		}
 | 
					 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 
 | 
				
			|||||||
@@ -288,25 +288,14 @@ func (db passwordDB) Login(ctx context.Context, s connector.Scopes, email, passw
 | 
				
			|||||||
		}
 | 
							}
 | 
				
			||||||
		return connector.Identity{}, false, nil
 | 
							return connector.Identity{}, false, nil
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
 | 
						// This check prevents dex users from logging in using static passwords
 | 
				
			||||||
	// Return an error if password-hash comparison takes longer than 10 seconds
 | 
						// configured with hash costs that are too high or low.
 | 
				
			||||||
	errCh := make(chan error, 1)
 | 
						if err := checkCost(p.Hash); err != nil {
 | 
				
			||||||
	go func() {
 | 
					 | 
				
			||||||
		errCh <- bcrypt.CompareHashAndPassword(p.Hash, []byte(password))
 | 
					 | 
				
			||||||
	}()
 | 
					 | 
				
			||||||
	select {
 | 
					 | 
				
			||||||
	case err = <-errCh:
 | 
					 | 
				
			||||||
		if err != nil {
 | 
					 | 
				
			||||||
			return connector.Identity{}, false, nil
 | 
					 | 
				
			||||||
		}
 | 
					 | 
				
			||||||
	case <-time.After(time.Second * 10):
 | 
					 | 
				
			||||||
		var cost int
 | 
					 | 
				
			||||||
		if cost, err = bcrypt.Cost(p.Hash); err == nil {
 | 
					 | 
				
			||||||
			err = fmt.Errorf("password-hash comparison timeout: your bcrypt cost = %d, recommended cost = %d", cost, recCost)
 | 
					 | 
				
			||||||
		}
 | 
					 | 
				
			||||||
		return connector.Identity{}, false, err
 | 
							return connector.Identity{}, false, err
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
 | 
						if err := bcrypt.CompareHashAndPassword(p.Hash, []byte(password)); err != nil {
 | 
				
			||||||
 | 
							return connector.Identity{}, false, nil
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
	return connector.Identity{
 | 
						return connector.Identity{
 | 
				
			||||||
		UserID:        p.UserID,
 | 
							UserID:        p.UserID,
 | 
				
			||||||
		Username:      p.Username,
 | 
							Username:      p.Username,
 | 
				
			||||||
 
 | 
				
			|||||||
@@ -818,7 +818,7 @@ func TestPasswordDB(t *testing.T) {
 | 
				
			|||||||
 | 
					
 | 
				
			||||||
	pw := "hi"
 | 
						pw := "hi"
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	h, err := bcrypt.GenerateFromPassword([]byte(pw), bcrypt.MinCost)
 | 
						h, err := bcrypt.GenerateFromPassword([]byte(pw), bcrypt.DefaultCost)
 | 
				
			||||||
	if err != nil {
 | 
						if err != nil {
 | 
				
			||||||
		t.Fatal(err)
 | 
							t.Fatal(err)
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
@@ -896,116 +896,6 @@ func TestPasswordDB(t *testing.T) {
 | 
				
			|||||||
 | 
					
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
// A warning message should be logged if password-hash comparison takes longer than 10s
 | 
					 | 
				
			||||||
func TestLoginTimeout(t *testing.T) {
 | 
					 | 
				
			||||||
	s := memory.New(logger)
 | 
					 | 
				
			||||||
	conn := newPasswordDB(s)
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
	pw := "test1"
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
	tests := []struct {
 | 
					 | 
				
			||||||
		name, email, password string
 | 
					 | 
				
			||||||
		pwHash                []byte
 | 
					 | 
				
			||||||
		wantIdentity          connector.Identity
 | 
					 | 
				
			||||||
		wantInvalid           bool
 | 
					 | 
				
			||||||
		wantedErr             string
 | 
					 | 
				
			||||||
	}{
 | 
					 | 
				
			||||||
		{
 | 
					 | 
				
			||||||
			name:     "valid password min cost",
 | 
					 | 
				
			||||||
			email:    "jane@example.com",
 | 
					 | 
				
			||||||
			password: pw,
 | 
					 | 
				
			||||||
			// bcrypt hash of the value "test1" with cost 4
 | 
					 | 
				
			||||||
			pwHash: []byte("$2a$04$lGqOe5gnlpsfreQ1OJHxGOO7f5FyyESyICkswSFATM1cnBVgCyyuG"),
 | 
					 | 
				
			||||||
			wantIdentity: connector.Identity{
 | 
					 | 
				
			||||||
				Email:         "jane@example.com",
 | 
					 | 
				
			||||||
				Username:      "jane",
 | 
					 | 
				
			||||||
				UserID:        "foobar",
 | 
					 | 
				
			||||||
				EmailVerified: true,
 | 
					 | 
				
			||||||
			},
 | 
					 | 
				
			||||||
		},
 | 
					 | 
				
			||||||
		{
 | 
					 | 
				
			||||||
			name:     "valid password reccomended cost",
 | 
					 | 
				
			||||||
			email:    "jane@example.com",
 | 
					 | 
				
			||||||
			password: pw,
 | 
					 | 
				
			||||||
			// bcrypt hash of the value "test1" with cost 12
 | 
					 | 
				
			||||||
			pwHash: []byte("$2a$12$VZNNjuCUGX2NG5S1ci.3Ku9mI9DmA9XeXyrr7YzJuyTxuVBGdRAbm"),
 | 
					 | 
				
			||||||
			wantIdentity: connector.Identity{
 | 
					 | 
				
			||||||
				Email:         "jane@example.com",
 | 
					 | 
				
			||||||
				Username:      "jane",
 | 
					 | 
				
			||||||
				UserID:        "foobar",
 | 
					 | 
				
			||||||
				EmailVerified: true,
 | 
					 | 
				
			||||||
			},
 | 
					 | 
				
			||||||
		},
 | 
					 | 
				
			||||||
		{
 | 
					 | 
				
			||||||
			name:     "valid password timeout cost",
 | 
					 | 
				
			||||||
			email:    "jane@example.com",
 | 
					 | 
				
			||||||
			password: pw,
 | 
					 | 
				
			||||||
			// bcrypt hash of the value "test1" with cost 20
 | 
					 | 
				
			||||||
			pwHash:    []byte("$2a$20$yODn5quqK9MZdePqYLs6Y.Jr4cOO1P0aXsKz0eTa2rxOmu8e7ETpi"),
 | 
					 | 
				
			||||||
			wantedErr: fmt.Sprintf("password-hash comparison timeout: your bcrypt cost = 20, recommended cost = %d", recCost),
 | 
					 | 
				
			||||||
		},
 | 
					 | 
				
			||||||
		{
 | 
					 | 
				
			||||||
			name:     "invalid password min cost",
 | 
					 | 
				
			||||||
			email:    "jane@example.com",
 | 
					 | 
				
			||||||
			password: pw,
 | 
					 | 
				
			||||||
			// bcrypt hash of the value "test2" with cost 4
 | 
					 | 
				
			||||||
			pwHash:      []byte("$2a$04$pX8wwwpxw8xlXrToYaEgZemK0JIibMZYXPsgau7aPDoGyHPF73br."),
 | 
					 | 
				
			||||||
			wantInvalid: true,
 | 
					 | 
				
			||||||
		},
 | 
					 | 
				
			||||||
		{
 | 
					 | 
				
			||||||
			name:     "invalid password timeout cost",
 | 
					 | 
				
			||||||
			email:    "jane@example.com",
 | 
					 | 
				
			||||||
			password: pw,
 | 
					 | 
				
			||||||
			// bcrypt hash of the value "test2" with cost 20
 | 
					 | 
				
			||||||
			pwHash:    []byte("$2a$20$WBD9cs63Zf0zqS99yyrQhODoDXphWw8MlYqVYRiftJH.lRJ1stnAa"),
 | 
					 | 
				
			||||||
			wantedErr: fmt.Sprintf("password-hash comparison timeout: your bcrypt cost = 20, recommended cost = %d", recCost),
 | 
					 | 
				
			||||||
		},
 | 
					 | 
				
			||||||
	}
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
	sEmail, sUsername, sUserID := "jane@example.com", "jane", "foobar"
 | 
					 | 
				
			||||||
	for _, tc := range tests {
 | 
					 | 
				
			||||||
		// Clean up before new test case
 | 
					 | 
				
			||||||
		s.DeletePassword(sEmail)
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
		s.CreatePassword(storage.Password{
 | 
					 | 
				
			||||||
			Email:    sEmail,
 | 
					 | 
				
			||||||
			Username: sUsername,
 | 
					 | 
				
			||||||
			UserID:   sUserID,
 | 
					 | 
				
			||||||
			Hash:     tc.pwHash,
 | 
					 | 
				
			||||||
		})
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
		ident, valid, err := conn.Login(context.Background(), connector.Scopes{}, tc.email, tc.password)
 | 
					 | 
				
			||||||
		if err != nil {
 | 
					 | 
				
			||||||
			if err.Error() != tc.wantedErr {
 | 
					 | 
				
			||||||
				t.Errorf("%s: error was incorrect:\n%v", tc.name, err)
 | 
					 | 
				
			||||||
			}
 | 
					 | 
				
			||||||
			continue
 | 
					 | 
				
			||||||
		}
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
		if tc.wantedErr != "" {
 | 
					 | 
				
			||||||
			t.Errorf("%s: expected error", tc.name)
 | 
					 | 
				
			||||||
			continue
 | 
					 | 
				
			||||||
		}
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
		if !valid {
 | 
					 | 
				
			||||||
			if !tc.wantInvalid {
 | 
					 | 
				
			||||||
				t.Errorf("%s: expected valid response", tc.name)
 | 
					 | 
				
			||||||
			}
 | 
					 | 
				
			||||||
			continue
 | 
					 | 
				
			||||||
		}
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
		if tc.wantInvalid {
 | 
					 | 
				
			||||||
			t.Errorf("%s: expected invalid response", tc.name)
 | 
					 | 
				
			||||||
			continue
 | 
					 | 
				
			||||||
		}
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
		if diff := pretty.Compare(tc.wantIdentity, ident); diff != "" {
 | 
					 | 
				
			||||||
			t.Errorf("%s: %s", tc.email, diff)
 | 
					 | 
				
			||||||
		}
 | 
					 | 
				
			||||||
	}
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
}
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
type storageWithKeysTrigger struct {
 | 
					type storageWithKeysTrigger struct {
 | 
				
			||||||
	storage.Storage
 | 
						storage.Storage
 | 
				
			||||||
	f func()
 | 
						f func()
 | 
				
			||||||
 
 | 
				
			|||||||
		Reference in New Issue
	
	Block a user