server: set sane bcrypt cost upper bound
This commit is contained in:
		@@ -21,8 +21,16 @@ import (
 | 
			
		||||
// to determine if the server supports specific features.
 | 
			
		||||
const apiVersion = 2
 | 
			
		||||
 | 
			
		||||
// recCost is the recommended bcrypt cost, which balances hash strength and time
 | 
			
		||||
const recCost = 12
 | 
			
		||||
const (
 | 
			
		||||
	// recCost is the recommended bcrypt cost, which balances hash strength and
 | 
			
		||||
	// efficiency.
 | 
			
		||||
	recCost = 12
 | 
			
		||||
 | 
			
		||||
	// upBoundCost is a sane upper bound on bcrypt cost determined by benchmarking:
 | 
			
		||||
	// high enough to ensure secure encryption, low enough to not put unnecessary
 | 
			
		||||
	// load on a dex server.
 | 
			
		||||
	upBoundCost = 16
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
// NewAPI returns a server which implements the gRPC API interface.
 | 
			
		||||
func NewAPI(s storage.Storage, logger logrus.FieldLogger) api.DexServer {
 | 
			
		||||
@@ -83,16 +91,20 @@ func (d dexAPI) DeleteClient(ctx context.Context, req *api.DeleteClientReq) (*ap
 | 
			
		||||
	return &api.DeleteClientResp{}, nil
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// checkCost returns an error if the hash provided does not meet minimum cost requirement, and the actual bcrypt cost
 | 
			
		||||
func checkCost(hash []byte) (int, error) {
 | 
			
		||||
// checkCost returns an error if the hash provided does not meet lower or upper
 | 
			
		||||
// bound cost requirements.
 | 
			
		||||
func checkCost(hash []byte) error {
 | 
			
		||||
	actual, err := bcrypt.Cost(hash)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return 0, fmt.Errorf("parsing bcrypt hash: %v", err)
 | 
			
		||||
		return fmt.Errorf("parsing bcrypt hash: %v", err)
 | 
			
		||||
	}
 | 
			
		||||
	if actual < bcrypt.DefaultCost {
 | 
			
		||||
		return actual, fmt.Errorf("given hash cost = %d, does not meet minimum cost requirement = %d", actual, bcrypt.DefaultCost)
 | 
			
		||||
		return fmt.Errorf("given hash cost = %d does not meet minimum cost requirement = %d", actual, bcrypt.DefaultCost)
 | 
			
		||||
	}
 | 
			
		||||
	return actual, nil
 | 
			
		||||
	if actual > upBoundCost {
 | 
			
		||||
		return fmt.Errorf("given hash cost = %d is above upper bound cost = %d, recommended cost = %d", actual, upBoundCost, recCost)
 | 
			
		||||
	}
 | 
			
		||||
	return nil
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (d dexAPI) CreatePassword(ctx context.Context, req *api.CreatePasswordReq) (*api.CreatePasswordResp, error) {
 | 
			
		||||
@@ -103,13 +115,9 @@ func (d dexAPI) CreatePassword(ctx context.Context, req *api.CreatePasswordReq)
 | 
			
		||||
		return nil, errors.New("no user ID supplied")
 | 
			
		||||
	}
 | 
			
		||||
	if req.Password.Hash != nil {
 | 
			
		||||
		cost, err := checkCost(req.Password.Hash)
 | 
			
		||||
		if err != nil {
 | 
			
		||||
		if err := checkCost(req.Password.Hash); err != nil {
 | 
			
		||||
			return nil, err
 | 
			
		||||
		}
 | 
			
		||||
		if cost > recCost {
 | 
			
		||||
			d.logger.Warnln("bcrypt cost = %d, password encryption might timeout. Recommended bcrypt cost is 12", cost)
 | 
			
		||||
		}
 | 
			
		||||
	} else {
 | 
			
		||||
		return nil, errors.New("no hash of password supplied")
 | 
			
		||||
	}
 | 
			
		||||
@@ -140,13 +148,9 @@ func (d dexAPI) UpdatePassword(ctx context.Context, req *api.UpdatePasswordReq)
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	if req.NewHash != nil {
 | 
			
		||||
		cost, err := checkCost(req.NewHash)
 | 
			
		||||
		if err != nil {
 | 
			
		||||
		if err := checkCost(req.NewHash); err != nil {
 | 
			
		||||
			return nil, err
 | 
			
		||||
		}
 | 
			
		||||
		if cost > recCost {
 | 
			
		||||
			d.logger.Warnln("bcrypt cost = %d, password encryption might timeout. Recommended bcrypt cost is 12", cost)
 | 
			
		||||
		}
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	updater := func(old storage.Password) (storage.Password, error) {
 | 
			
		||||
 
 | 
			
		||||
@@ -134,14 +134,13 @@ func TestCheckCost(t *testing.T) {
 | 
			
		||||
	tests := []struct {
 | 
			
		||||
		name      string
 | 
			
		||||
		inputHash []byte
 | 
			
		||||
		expectedCost int
 | 
			
		||||
 | 
			
		||||
		wantErr bool
 | 
			
		||||
	}{
 | 
			
		||||
		{
 | 
			
		||||
			name: "valid cost",
 | 
			
		||||
			// bcrypt hash of the value "test1" with cost 12
 | 
			
		||||
			// bcrypt hash of the value "test1" with cost 12 (default)
 | 
			
		||||
			inputHash: []byte("$2a$12$M2Ot95Qty1MuQdubh1acWOiYadJDzeVg3ve4n5b.dgcgPdjCseKx2"),
 | 
			
		||||
			expectedCost: recCost,
 | 
			
		||||
		},
 | 
			
		||||
		{
 | 
			
		||||
			name:      "invalid hash",
 | 
			
		||||
@@ -156,15 +155,14 @@ func TestCheckCost(t *testing.T) {
 | 
			
		||||
		},
 | 
			
		||||
		{
 | 
			
		||||
			name: "cost above recommendation",
 | 
			
		||||
			// bcrypt hash of the value "test1" with cost 20
 | 
			
		||||
			inputHash:    []byte("$2a$20$yODn5quqK9MZdePqYLs6Y.Jr4cOO1P0aXsKz0eTa2rxOmu8e7ETpi"),
 | 
			
		||||
			expectedCost: 20,
 | 
			
		||||
			// bcrypt hash of the value "test1" with cost 17
 | 
			
		||||
			inputHash: []byte("$2a$17$tWuZkTxtSmRyWZAGWVHQE.7npdl.TgP8adjzLJD.SyjpFznKBftPe"),
 | 
			
		||||
			wantErr:   true,
 | 
			
		||||
		},
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	for _, tc := range tests {
 | 
			
		||||
		cost, err := checkCost(tc.inputHash)
 | 
			
		||||
		if err != nil {
 | 
			
		||||
		if err := checkCost(tc.inputHash); err != nil {
 | 
			
		||||
			if !tc.wantErr {
 | 
			
		||||
				t.Errorf("%s: %s", tc.name, err)
 | 
			
		||||
			}
 | 
			
		||||
@@ -175,10 +173,6 @@ func TestCheckCost(t *testing.T) {
 | 
			
		||||
			t.Errorf("%s: expected err", tc.name)
 | 
			
		||||
			continue
 | 
			
		||||
		}
 | 
			
		||||
 | 
			
		||||
		if cost != tc.expectedCost {
 | 
			
		||||
			t.Errorf("%s: exepcted cost = %d but got cost = %d", tc.name, tc.expectedCost, cost)
 | 
			
		||||
		}
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
 
 | 
			
		||||
@@ -288,25 +288,14 @@ func (db passwordDB) Login(ctx context.Context, s connector.Scopes, email, passw
 | 
			
		||||
		}
 | 
			
		||||
		return connector.Identity{}, false, nil
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	// Return an error if password-hash comparison takes longer than 10 seconds
 | 
			
		||||
	errCh := make(chan error, 1)
 | 
			
		||||
	go func() {
 | 
			
		||||
		errCh <- bcrypt.CompareHashAndPassword(p.Hash, []byte(password))
 | 
			
		||||
	}()
 | 
			
		||||
	select {
 | 
			
		||||
	case err = <-errCh:
 | 
			
		||||
		if err != nil {
 | 
			
		||||
			return connector.Identity{}, false, nil
 | 
			
		||||
		}
 | 
			
		||||
	case <-time.After(time.Second * 10):
 | 
			
		||||
		var cost int
 | 
			
		||||
		if cost, err = bcrypt.Cost(p.Hash); err == nil {
 | 
			
		||||
			err = fmt.Errorf("password-hash comparison timeout: your bcrypt cost = %d, recommended cost = %d", cost, recCost)
 | 
			
		||||
		}
 | 
			
		||||
	// This check prevents dex users from logging in using static passwords
 | 
			
		||||
	// configured with hash costs that are too high or low.
 | 
			
		||||
	if err := checkCost(p.Hash); err != nil {
 | 
			
		||||
		return connector.Identity{}, false, err
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	if err := bcrypt.CompareHashAndPassword(p.Hash, []byte(password)); err != nil {
 | 
			
		||||
		return connector.Identity{}, false, nil
 | 
			
		||||
	}
 | 
			
		||||
	return connector.Identity{
 | 
			
		||||
		UserID:        p.UserID,
 | 
			
		||||
		Username:      p.Username,
 | 
			
		||||
 
 | 
			
		||||
@@ -818,7 +818,7 @@ func TestPasswordDB(t *testing.T) {
 | 
			
		||||
 | 
			
		||||
	pw := "hi"
 | 
			
		||||
 | 
			
		||||
	h, err := bcrypt.GenerateFromPassword([]byte(pw), bcrypt.MinCost)
 | 
			
		||||
	h, err := bcrypt.GenerateFromPassword([]byte(pw), bcrypt.DefaultCost)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		t.Fatal(err)
 | 
			
		||||
	}
 | 
			
		||||
@@ -896,116 +896,6 @@ func TestPasswordDB(t *testing.T) {
 | 
			
		||||
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// A warning message should be logged if password-hash comparison takes longer than 10s
 | 
			
		||||
func TestLoginTimeout(t *testing.T) {
 | 
			
		||||
	s := memory.New(logger)
 | 
			
		||||
	conn := newPasswordDB(s)
 | 
			
		||||
 | 
			
		||||
	pw := "test1"
 | 
			
		||||
 | 
			
		||||
	tests := []struct {
 | 
			
		||||
		name, email, password string
 | 
			
		||||
		pwHash                []byte
 | 
			
		||||
		wantIdentity          connector.Identity
 | 
			
		||||
		wantInvalid           bool
 | 
			
		||||
		wantedErr             string
 | 
			
		||||
	}{
 | 
			
		||||
		{
 | 
			
		||||
			name:     "valid password min cost",
 | 
			
		||||
			email:    "jane@example.com",
 | 
			
		||||
			password: pw,
 | 
			
		||||
			// bcrypt hash of the value "test1" with cost 4
 | 
			
		||||
			pwHash: []byte("$2a$04$lGqOe5gnlpsfreQ1OJHxGOO7f5FyyESyICkswSFATM1cnBVgCyyuG"),
 | 
			
		||||
			wantIdentity: connector.Identity{
 | 
			
		||||
				Email:         "jane@example.com",
 | 
			
		||||
				Username:      "jane",
 | 
			
		||||
				UserID:        "foobar",
 | 
			
		||||
				EmailVerified: true,
 | 
			
		||||
			},
 | 
			
		||||
		},
 | 
			
		||||
		{
 | 
			
		||||
			name:     "valid password reccomended cost",
 | 
			
		||||
			email:    "jane@example.com",
 | 
			
		||||
			password: pw,
 | 
			
		||||
			// bcrypt hash of the value "test1" with cost 12
 | 
			
		||||
			pwHash: []byte("$2a$12$VZNNjuCUGX2NG5S1ci.3Ku9mI9DmA9XeXyrr7YzJuyTxuVBGdRAbm"),
 | 
			
		||||
			wantIdentity: connector.Identity{
 | 
			
		||||
				Email:         "jane@example.com",
 | 
			
		||||
				Username:      "jane",
 | 
			
		||||
				UserID:        "foobar",
 | 
			
		||||
				EmailVerified: true,
 | 
			
		||||
			},
 | 
			
		||||
		},
 | 
			
		||||
		{
 | 
			
		||||
			name:     "valid password timeout cost",
 | 
			
		||||
			email:    "jane@example.com",
 | 
			
		||||
			password: pw,
 | 
			
		||||
			// bcrypt hash of the value "test1" with cost 20
 | 
			
		||||
			pwHash:    []byte("$2a$20$yODn5quqK9MZdePqYLs6Y.Jr4cOO1P0aXsKz0eTa2rxOmu8e7ETpi"),
 | 
			
		||||
			wantedErr: fmt.Sprintf("password-hash comparison timeout: your bcrypt cost = 20, recommended cost = %d", recCost),
 | 
			
		||||
		},
 | 
			
		||||
		{
 | 
			
		||||
			name:     "invalid password min cost",
 | 
			
		||||
			email:    "jane@example.com",
 | 
			
		||||
			password: pw,
 | 
			
		||||
			// bcrypt hash of the value "test2" with cost 4
 | 
			
		||||
			pwHash:      []byte("$2a$04$pX8wwwpxw8xlXrToYaEgZemK0JIibMZYXPsgau7aPDoGyHPF73br."),
 | 
			
		||||
			wantInvalid: true,
 | 
			
		||||
		},
 | 
			
		||||
		{
 | 
			
		||||
			name:     "invalid password timeout cost",
 | 
			
		||||
			email:    "jane@example.com",
 | 
			
		||||
			password: pw,
 | 
			
		||||
			// bcrypt hash of the value "test2" with cost 20
 | 
			
		||||
			pwHash:    []byte("$2a$20$WBD9cs63Zf0zqS99yyrQhODoDXphWw8MlYqVYRiftJH.lRJ1stnAa"),
 | 
			
		||||
			wantedErr: fmt.Sprintf("password-hash comparison timeout: your bcrypt cost = 20, recommended cost = %d", recCost),
 | 
			
		||||
		},
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	sEmail, sUsername, sUserID := "jane@example.com", "jane", "foobar"
 | 
			
		||||
	for _, tc := range tests {
 | 
			
		||||
		// Clean up before new test case
 | 
			
		||||
		s.DeletePassword(sEmail)
 | 
			
		||||
 | 
			
		||||
		s.CreatePassword(storage.Password{
 | 
			
		||||
			Email:    sEmail,
 | 
			
		||||
			Username: sUsername,
 | 
			
		||||
			UserID:   sUserID,
 | 
			
		||||
			Hash:     tc.pwHash,
 | 
			
		||||
		})
 | 
			
		||||
 | 
			
		||||
		ident, valid, err := conn.Login(context.Background(), connector.Scopes{}, tc.email, tc.password)
 | 
			
		||||
		if err != nil {
 | 
			
		||||
			if err.Error() != tc.wantedErr {
 | 
			
		||||
				t.Errorf("%s: error was incorrect:\n%v", tc.name, err)
 | 
			
		||||
			}
 | 
			
		||||
			continue
 | 
			
		||||
		}
 | 
			
		||||
 | 
			
		||||
		if tc.wantedErr != "" {
 | 
			
		||||
			t.Errorf("%s: expected error", tc.name)
 | 
			
		||||
			continue
 | 
			
		||||
		}
 | 
			
		||||
 | 
			
		||||
		if !valid {
 | 
			
		||||
			if !tc.wantInvalid {
 | 
			
		||||
				t.Errorf("%s: expected valid response", tc.name)
 | 
			
		||||
			}
 | 
			
		||||
			continue
 | 
			
		||||
		}
 | 
			
		||||
 | 
			
		||||
		if tc.wantInvalid {
 | 
			
		||||
			t.Errorf("%s: expected invalid response", tc.name)
 | 
			
		||||
			continue
 | 
			
		||||
		}
 | 
			
		||||
 | 
			
		||||
		if diff := pretty.Compare(tc.wantIdentity, ident); diff != "" {
 | 
			
		||||
			t.Errorf("%s: %s", tc.email, diff)
 | 
			
		||||
		}
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
type storageWithKeysTrigger struct {
 | 
			
		||||
	storage.Storage
 | 
			
		||||
	f func()
 | 
			
		||||
 
 | 
			
		||||
		Reference in New Issue
	
	Block a user