server: log bcrypt cost if > 12, error on runtime > 10s
The bcrypt hashing algorithm runtime grows exponentially with cost, and might cause a timeout if the cost is too high. Notifying the user of high cost and of long running calculations will help with tuning and debugging.
This commit is contained in:
		| @@ -21,6 +21,9 @@ import ( | |||||||
| // to determine if the server supports specific features. | // to determine if the server supports specific features. | ||||||
| const apiVersion = 2 | const apiVersion = 2 | ||||||
|  |  | ||||||
|  | // recCost is the recommended bcrypt cost, which balances hash strength and time | ||||||
|  | const recCost = 12 | ||||||
|  |  | ||||||
| // NewAPI returns a server which implements the gRPC API interface. | // NewAPI returns a server which implements the gRPC API interface. | ||||||
| func NewAPI(s storage.Storage, logger logrus.FieldLogger) api.DexServer { | func NewAPI(s storage.Storage, logger logrus.FieldLogger) api.DexServer { | ||||||
| 	return dexAPI{ | 	return dexAPI{ | ||||||
| @@ -80,16 +83,16 @@ func (d dexAPI) DeleteClient(ctx context.Context, req *api.DeleteClientReq) (*ap | |||||||
| 	return &api.DeleteClientResp{}, nil | 	return &api.DeleteClientResp{}, nil | ||||||
| } | } | ||||||
|  |  | ||||||
| // checkCost returns an error if the hash provided does not meet minimum cost requirement | // checkCost returns an error if the hash provided does not meet minimum cost requirement, and the actual bcrypt cost | ||||||
| func checkCost(hash []byte) error { | func checkCost(hash []byte) (int, error) { | ||||||
| 	actual, err := bcrypt.Cost(hash) | 	actual, err := bcrypt.Cost(hash) | ||||||
| 	if err != nil { | 	if err != nil { | ||||||
| 		return fmt.Errorf("parsing bcrypt hash: %v", err) | 		return 0, fmt.Errorf("parsing bcrypt hash: %v", err) | ||||||
| 	} | 	} | ||||||
| 	if actual < bcrypt.DefaultCost { | 	if actual < bcrypt.DefaultCost { | ||||||
| 		return fmt.Errorf("given hash cost = %d, does not meet minimum cost requirement = %d", actual, bcrypt.DefaultCost) | 		return actual, fmt.Errorf("given hash cost = %d, does not meet minimum cost requirement = %d", actual, bcrypt.DefaultCost) | ||||||
| 	} | 	} | ||||||
| 	return nil | 	return actual, nil | ||||||
| } | } | ||||||
|  |  | ||||||
| func (d dexAPI) CreatePassword(ctx context.Context, req *api.CreatePasswordReq) (*api.CreatePasswordResp, error) { | func (d dexAPI) CreatePassword(ctx context.Context, req *api.CreatePasswordReq) (*api.CreatePasswordResp, error) { | ||||||
| @@ -100,9 +103,13 @@ func (d dexAPI) CreatePassword(ctx context.Context, req *api.CreatePasswordReq) | |||||||
| 		return nil, errors.New("no user ID supplied") | 		return nil, errors.New("no user ID supplied") | ||||||
| 	} | 	} | ||||||
| 	if req.Password.Hash != nil { | 	if req.Password.Hash != nil { | ||||||
| 		if err := checkCost(req.Password.Hash); err != nil { | 		cost, err := checkCost(req.Password.Hash) | ||||||
|  | 		if err != nil { | ||||||
| 			return nil, err | 			return nil, err | ||||||
| 		} | 		} | ||||||
|  | 		if cost > recCost { | ||||||
|  | 			d.logger.Warnln("bcrypt cost = %d, password encryption might timeout. Recommended bcrypt cost is 12", cost) | ||||||
|  | 		} | ||||||
| 	} else { | 	} else { | ||||||
| 		return nil, errors.New("no hash of password supplied") | 		return nil, errors.New("no hash of password supplied") | ||||||
| 	} | 	} | ||||||
| @@ -133,9 +140,13 @@ func (d dexAPI) UpdatePassword(ctx context.Context, req *api.UpdatePasswordReq) | |||||||
| 	} | 	} | ||||||
|  |  | ||||||
| 	if req.NewHash != nil { | 	if req.NewHash != nil { | ||||||
| 		if err := checkCost(req.NewHash); err != nil { | 		cost, err := checkCost(req.NewHash) | ||||||
|  | 		if err != nil { | ||||||
| 			return nil, err | 			return nil, err | ||||||
| 		} | 		} | ||||||
|  | 		if cost > recCost { | ||||||
|  | 			d.logger.Warnln("bcrypt cost = %d, password encryption might timeout. Recommended bcrypt cost is 12", cost) | ||||||
|  | 		} | ||||||
| 	} | 	} | ||||||
|  |  | ||||||
| 	updater := func(old storage.Password) (storage.Password, error) { | 	updater := func(old storage.Password) (storage.Password, error) { | ||||||
|   | |||||||
| @@ -119,6 +119,69 @@ func TestPassword(t *testing.T) { | |||||||
|  |  | ||||||
| } | } | ||||||
|  |  | ||||||
|  | // Ensures checkCost returns expected values | ||||||
|  | func TestCheckCost(t *testing.T) { | ||||||
|  | 	logger := &logrus.Logger{ | ||||||
|  | 		Out:       os.Stderr, | ||||||
|  | 		Formatter: &logrus.TextFormatter{DisableColors: true}, | ||||||
|  | 		Level:     logrus.DebugLevel, | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	s := memory.New(logger) | ||||||
|  | 	client := newAPI(s, logger, t) | ||||||
|  | 	defer client.Close() | ||||||
|  |  | ||||||
|  | 	tests := []struct { | ||||||
|  | 		name         string | ||||||
|  | 		inputHash    []byte | ||||||
|  | 		expectedCost int | ||||||
|  | 		wantErr      bool | ||||||
|  | 	}{ | ||||||
|  | 		{ | ||||||
|  | 			name: "valid cost", | ||||||
|  | 			// bcrypt hash of the value "test1" with cost 12 | ||||||
|  | 			inputHash:    []byte("$2a$12$M2Ot95Qty1MuQdubh1acWOiYadJDzeVg3ve4n5b.dgcgPdjCseKx2"), | ||||||
|  | 			expectedCost: recCost, | ||||||
|  | 		}, | ||||||
|  | 		{ | ||||||
|  | 			name:      "invalid hash", | ||||||
|  | 			inputHash: []byte(""), | ||||||
|  | 			wantErr:   true, | ||||||
|  | 		}, | ||||||
|  | 		{ | ||||||
|  | 			name: "cost below default", | ||||||
|  | 			// bcrypt hash of the value "test1" with cost 4 | ||||||
|  | 			inputHash: []byte("$2a$04$8bSTbuVCLpKzaqB3BmgI7edDigG5tIQKkjYUu/mEO9gQgIkw9m7eG"), | ||||||
|  | 			wantErr:   true, | ||||||
|  | 		}, | ||||||
|  | 		{ | ||||||
|  | 			name: "cost above recommendation", | ||||||
|  | 			// bcrypt hash of the value "test1" with cost 20 | ||||||
|  | 			inputHash:    []byte("$2a$20$yODn5quqK9MZdePqYLs6Y.Jr4cOO1P0aXsKz0eTa2rxOmu8e7ETpi"), | ||||||
|  | 			expectedCost: 20, | ||||||
|  | 		}, | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	for _, tc := range tests { | ||||||
|  | 		cost, err := checkCost(tc.inputHash) | ||||||
|  | 		if err != nil { | ||||||
|  | 			if !tc.wantErr { | ||||||
|  | 				t.Errorf("%s: %s", tc.name, err) | ||||||
|  | 			} | ||||||
|  | 			continue | ||||||
|  | 		} | ||||||
|  |  | ||||||
|  | 		if tc.wantErr { | ||||||
|  | 			t.Errorf("%s: expected err", tc.name) | ||||||
|  | 			continue | ||||||
|  | 		} | ||||||
|  |  | ||||||
|  | 		if cost != tc.expectedCost { | ||||||
|  | 			t.Errorf("%s: exepcted cost = %d but got cost = %d", tc.name, tc.expectedCost, cost) | ||||||
|  | 		} | ||||||
|  | 	} | ||||||
|  | } | ||||||
|  |  | ||||||
| // Attempts to list and revoke an exisiting refresh token. | // Attempts to list and revoke an exisiting refresh token. | ||||||
| func TestRefreshToken(t *testing.T) { | func TestRefreshToken(t *testing.T) { | ||||||
| 	logger := &logrus.Logger{ | 	logger := &logrus.Logger{ | ||||||
|   | |||||||
| @@ -288,9 +288,25 @@ func (db passwordDB) Login(ctx context.Context, s connector.Scopes, email, passw | |||||||
| 		} | 		} | ||||||
| 		return connector.Identity{}, false, nil | 		return connector.Identity{}, false, nil | ||||||
| 	} | 	} | ||||||
| 	if err := bcrypt.CompareHashAndPassword(p.Hash, []byte(password)); err != nil { |  | ||||||
| 		return connector.Identity{}, false, nil | 	// Return an error if password-hash comparison takes longer than 10 seconds | ||||||
|  | 	errCh := make(chan error, 1) | ||||||
|  | 	go func() { | ||||||
|  | 		errCh <- bcrypt.CompareHashAndPassword(p.Hash, []byte(password)) | ||||||
|  | 	}() | ||||||
|  | 	select { | ||||||
|  | 	case err = <-errCh: | ||||||
|  | 		if err != nil { | ||||||
|  | 			return connector.Identity{}, false, nil | ||||||
|  | 		} | ||||||
|  | 	case <-time.After(time.Second * 10): | ||||||
|  | 		var cost int | ||||||
|  | 		if cost, err = bcrypt.Cost(p.Hash); err == nil { | ||||||
|  | 			err = fmt.Errorf("password-hash comparison timeout: your bcrypt cost = %d, recommended cost = %d", cost, recCost) | ||||||
|  | 		} | ||||||
|  | 		return connector.Identity{}, false, err | ||||||
| 	} | 	} | ||||||
|  |  | ||||||
| 	return connector.Identity{ | 	return connector.Identity{ | ||||||
| 		UserID:        p.UserID, | 		UserID:        p.UserID, | ||||||
| 		Username:      p.Username, | 		Username:      p.Username, | ||||||
|   | |||||||
| @@ -896,6 +896,116 @@ func TestPasswordDB(t *testing.T) { | |||||||
|  |  | ||||||
| } | } | ||||||
|  |  | ||||||
|  | // A warning message should be logged if password-hash comparison takes longer than 10s | ||||||
|  | func TestLoginTimeout(t *testing.T) { | ||||||
|  | 	s := memory.New(logger) | ||||||
|  | 	conn := newPasswordDB(s) | ||||||
|  |  | ||||||
|  | 	pw := "test1" | ||||||
|  |  | ||||||
|  | 	tests := []struct { | ||||||
|  | 		name, email, password string | ||||||
|  | 		pwHash                []byte | ||||||
|  | 		wantIdentity          connector.Identity | ||||||
|  | 		wantInvalid           bool | ||||||
|  | 		wantedErr             string | ||||||
|  | 	}{ | ||||||
|  | 		{ | ||||||
|  | 			name:     "valid password min cost", | ||||||
|  | 			email:    "jane@example.com", | ||||||
|  | 			password: pw, | ||||||
|  | 			// bcrypt hash of the value "test1" with cost 4 | ||||||
|  | 			pwHash: []byte("$2a$04$lGqOe5gnlpsfreQ1OJHxGOO7f5FyyESyICkswSFATM1cnBVgCyyuG"), | ||||||
|  | 			wantIdentity: connector.Identity{ | ||||||
|  | 				Email:         "jane@example.com", | ||||||
|  | 				Username:      "jane", | ||||||
|  | 				UserID:        "foobar", | ||||||
|  | 				EmailVerified: true, | ||||||
|  | 			}, | ||||||
|  | 		}, | ||||||
|  | 		{ | ||||||
|  | 			name:     "valid password reccomended cost", | ||||||
|  | 			email:    "jane@example.com", | ||||||
|  | 			password: pw, | ||||||
|  | 			// bcrypt hash of the value "test1" with cost 12 | ||||||
|  | 			pwHash: []byte("$2a$12$VZNNjuCUGX2NG5S1ci.3Ku9mI9DmA9XeXyrr7YzJuyTxuVBGdRAbm"), | ||||||
|  | 			wantIdentity: connector.Identity{ | ||||||
|  | 				Email:         "jane@example.com", | ||||||
|  | 				Username:      "jane", | ||||||
|  | 				UserID:        "foobar", | ||||||
|  | 				EmailVerified: true, | ||||||
|  | 			}, | ||||||
|  | 		}, | ||||||
|  | 		{ | ||||||
|  | 			name:     "valid password timeout cost", | ||||||
|  | 			email:    "jane@example.com", | ||||||
|  | 			password: pw, | ||||||
|  | 			// bcrypt hash of the value "test1" with cost 20 | ||||||
|  | 			pwHash:    []byte("$2a$20$yODn5quqK9MZdePqYLs6Y.Jr4cOO1P0aXsKz0eTa2rxOmu8e7ETpi"), | ||||||
|  | 			wantedErr: fmt.Sprintf("password-hash comparison timeout: your bcrypt cost = 20, recommended cost = %d", recCost), | ||||||
|  | 		}, | ||||||
|  | 		{ | ||||||
|  | 			name:     "invalid password min cost", | ||||||
|  | 			email:    "jane@example.com", | ||||||
|  | 			password: pw, | ||||||
|  | 			// bcrypt hash of the value "test2" with cost 4 | ||||||
|  | 			pwHash:      []byte("$2a$04$pX8wwwpxw8xlXrToYaEgZemK0JIibMZYXPsgau7aPDoGyHPF73br."), | ||||||
|  | 			wantInvalid: true, | ||||||
|  | 		}, | ||||||
|  | 		{ | ||||||
|  | 			name:     "invalid password timeout cost", | ||||||
|  | 			email:    "jane@example.com", | ||||||
|  | 			password: pw, | ||||||
|  | 			// bcrypt hash of the value "test2" with cost 20 | ||||||
|  | 			pwHash:    []byte("$2a$20$WBD9cs63Zf0zqS99yyrQhODoDXphWw8MlYqVYRiftJH.lRJ1stnAa"), | ||||||
|  | 			wantedErr: fmt.Sprintf("password-hash comparison timeout: your bcrypt cost = 20, recommended cost = %d", recCost), | ||||||
|  | 		}, | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	sEmail, sUsername, sUserID := "jane@example.com", "jane", "foobar" | ||||||
|  | 	for _, tc := range tests { | ||||||
|  | 		// Clean up before new test case | ||||||
|  | 		s.DeletePassword(sEmail) | ||||||
|  |  | ||||||
|  | 		s.CreatePassword(storage.Password{ | ||||||
|  | 			Email:    sEmail, | ||||||
|  | 			Username: sUsername, | ||||||
|  | 			UserID:   sUserID, | ||||||
|  | 			Hash:     tc.pwHash, | ||||||
|  | 		}) | ||||||
|  |  | ||||||
|  | 		ident, valid, err := conn.Login(context.Background(), connector.Scopes{}, tc.email, tc.password) | ||||||
|  | 		if err != nil { | ||||||
|  | 			if err.Error() != tc.wantedErr { | ||||||
|  | 				t.Errorf("%s: error was incorrect:\n%v", tc.name, err) | ||||||
|  | 			} | ||||||
|  | 			continue | ||||||
|  | 		} | ||||||
|  |  | ||||||
|  | 		if tc.wantedErr != "" { | ||||||
|  | 			t.Errorf("%s: expected error", tc.name) | ||||||
|  | 			continue | ||||||
|  | 		} | ||||||
|  |  | ||||||
|  | 		if !valid { | ||||||
|  | 			if !tc.wantInvalid { | ||||||
|  | 				t.Errorf("%s: expected valid response", tc.name) | ||||||
|  | 			} | ||||||
|  | 			continue | ||||||
|  | 		} | ||||||
|  |  | ||||||
|  | 		if tc.wantInvalid { | ||||||
|  | 			t.Errorf("%s: expected invalid response", tc.name) | ||||||
|  | 			continue | ||||||
|  | 		} | ||||||
|  |  | ||||||
|  | 		if diff := pretty.Compare(tc.wantIdentity, ident); diff != "" { | ||||||
|  | 			t.Errorf("%s: %s", tc.email, diff) | ||||||
|  | 		} | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | } | ||||||
|  |  | ||||||
| type storageWithKeysTrigger struct { | type storageWithKeysTrigger struct { | ||||||
| 	storage.Storage | 	storage.Storage | ||||||
| 	f func() | 	f func() | ||||||
|   | |||||||
		Reference in New Issue
	
	Block a user