From 4bcb0aaae9d0b12be0cdf4a3f00acd50142ca76c Mon Sep 17 00:00:00 2001 From: Eric Stroczynski Date: Tue, 25 Jul 2017 14:26:47 -0700 Subject: [PATCH] server: log bcrypt cost if > 12, error on runtime > 10s The bcrypt hashing algorithm runtime grows exponentially with cost, and might cause a timeout if the cost is too high. Notifying the user of high cost and of long running calculations will help with tuning and debugging. --- server/api.go | 25 +++++++--- server/api_test.go | 63 ++++++++++++++++++++++++ server/server.go | 20 +++++++- server/server_test.go | 110 ++++++++++++++++++++++++++++++++++++++++++ 4 files changed, 209 insertions(+), 9 deletions(-) diff --git a/server/api.go b/server/api.go index c6a9731f..5e8fad76 100644 --- a/server/api.go +++ b/server/api.go @@ -21,6 +21,9 @@ import ( // to determine if the server supports specific features. const apiVersion = 2 +// recCost is the recommended bcrypt cost, which balances hash strength and time +const recCost = 12 + // NewAPI returns a server which implements the gRPC API interface. func NewAPI(s storage.Storage, logger logrus.FieldLogger) api.DexServer { return dexAPI{ @@ -80,16 +83,16 @@ func (d dexAPI) DeleteClient(ctx context.Context, req *api.DeleteClientReq) (*ap return &api.DeleteClientResp{}, nil } -// checkCost returns an error if the hash provided does not meet minimum cost requirement -func checkCost(hash []byte) error { +// checkCost returns an error if the hash provided does not meet minimum cost requirement, and the actual bcrypt cost +func checkCost(hash []byte) (int, error) { actual, err := bcrypt.Cost(hash) if err != nil { - return fmt.Errorf("parsing bcrypt hash: %v", err) + return 0, fmt.Errorf("parsing bcrypt hash: %v", err) } if actual < bcrypt.DefaultCost { - return fmt.Errorf("given hash cost = %d, does not meet minimum cost requirement = %d", actual, bcrypt.DefaultCost) + return actual, fmt.Errorf("given hash cost = %d, does not meet minimum cost requirement = %d", actual, bcrypt.DefaultCost) } - return nil + return actual, nil } func (d dexAPI) CreatePassword(ctx context.Context, req *api.CreatePasswordReq) (*api.CreatePasswordResp, error) { @@ -100,9 +103,13 @@ func (d dexAPI) CreatePassword(ctx context.Context, req *api.CreatePasswordReq) return nil, errors.New("no user ID supplied") } if req.Password.Hash != nil { - if err := checkCost(req.Password.Hash); err != nil { + cost, err := checkCost(req.Password.Hash) + if err != nil { return nil, err } + if cost > recCost { + d.logger.Warnln("bcrypt cost = %d, password encryption might timeout. Recommended bcrypt cost is 12", cost) + } } else { return nil, errors.New("no hash of password supplied") } @@ -133,9 +140,13 @@ func (d dexAPI) UpdatePassword(ctx context.Context, req *api.UpdatePasswordReq) } if req.NewHash != nil { - if err := checkCost(req.NewHash); err != nil { + cost, err := checkCost(req.NewHash) + if err != nil { return nil, err } + if cost > recCost { + d.logger.Warnln("bcrypt cost = %d, password encryption might timeout. Recommended bcrypt cost is 12", cost) + } } updater := func(old storage.Password) (storage.Password, error) { diff --git a/server/api_test.go b/server/api_test.go index 3c35403c..e9e2ea7d 100644 --- a/server/api_test.go +++ b/server/api_test.go @@ -119,6 +119,69 @@ func TestPassword(t *testing.T) { } +// Ensures checkCost returns expected values +func TestCheckCost(t *testing.T) { + logger := &logrus.Logger{ + Out: os.Stderr, + Formatter: &logrus.TextFormatter{DisableColors: true}, + Level: logrus.DebugLevel, + } + + s := memory.New(logger) + client := newAPI(s, logger, t) + defer client.Close() + + tests := []struct { + name string + inputHash []byte + expectedCost int + wantErr bool + }{ + { + name: "valid cost", + // bcrypt hash of the value "test1" with cost 12 + inputHash: []byte("$2a$12$M2Ot95Qty1MuQdubh1acWOiYadJDzeVg3ve4n5b.dgcgPdjCseKx2"), + expectedCost: recCost, + }, + { + name: "invalid hash", + inputHash: []byte(""), + wantErr: true, + }, + { + name: "cost below default", + // bcrypt hash of the value "test1" with cost 4 + inputHash: []byte("$2a$04$8bSTbuVCLpKzaqB3BmgI7edDigG5tIQKkjYUu/mEO9gQgIkw9m7eG"), + wantErr: true, + }, + { + name: "cost above recommendation", + // bcrypt hash of the value "test1" with cost 20 + inputHash: []byte("$2a$20$yODn5quqK9MZdePqYLs6Y.Jr4cOO1P0aXsKz0eTa2rxOmu8e7ETpi"), + expectedCost: 20, + }, + } + + for _, tc := range tests { + cost, err := checkCost(tc.inputHash) + if err != nil { + if !tc.wantErr { + t.Errorf("%s: %s", tc.name, err) + } + continue + } + + if tc.wantErr { + t.Errorf("%s: expected err", tc.name) + continue + } + + if cost != tc.expectedCost { + t.Errorf("%s: exepcted cost = %d but got cost = %d", tc.name, tc.expectedCost, cost) + } + } +} + // Attempts to list and revoke an exisiting refresh token. func TestRefreshToken(t *testing.T) { logger := &logrus.Logger{ diff --git a/server/server.go b/server/server.go index e78a33c1..41c28853 100644 --- a/server/server.go +++ b/server/server.go @@ -288,9 +288,25 @@ func (db passwordDB) Login(ctx context.Context, s connector.Scopes, email, passw } return connector.Identity{}, false, nil } - if err := bcrypt.CompareHashAndPassword(p.Hash, []byte(password)); err != nil { - return connector.Identity{}, false, nil + + // Return an error if password-hash comparison takes longer than 10 seconds + errCh := make(chan error, 1) + go func() { + errCh <- bcrypt.CompareHashAndPassword(p.Hash, []byte(password)) + }() + select { + case err = <-errCh: + if err != nil { + return connector.Identity{}, false, nil + } + case <-time.After(time.Second * 10): + var cost int + if cost, err = bcrypt.Cost(p.Hash); err == nil { + err = fmt.Errorf("password-hash comparison timeout: your bcrypt cost = %d, recommended cost = %d", cost, recCost) + } + return connector.Identity{}, false, err } + return connector.Identity{ UserID: p.UserID, Username: p.Username, diff --git a/server/server_test.go b/server/server_test.go index cb7bbc0a..a77c92bf 100644 --- a/server/server_test.go +++ b/server/server_test.go @@ -896,6 +896,116 @@ func TestPasswordDB(t *testing.T) { } +// A warning message should be logged if password-hash comparison takes longer than 10s +func TestLoginTimeout(t *testing.T) { + s := memory.New(logger) + conn := newPasswordDB(s) + + pw := "test1" + + tests := []struct { + name, email, password string + pwHash []byte + wantIdentity connector.Identity + wantInvalid bool + wantedErr string + }{ + { + name: "valid password min cost", + email: "jane@example.com", + password: pw, + // bcrypt hash of the value "test1" with cost 4 + pwHash: []byte("$2a$04$lGqOe5gnlpsfreQ1OJHxGOO7f5FyyESyICkswSFATM1cnBVgCyyuG"), + wantIdentity: connector.Identity{ + Email: "jane@example.com", + Username: "jane", + UserID: "foobar", + EmailVerified: true, + }, + }, + { + name: "valid password reccomended cost", + email: "jane@example.com", + password: pw, + // bcrypt hash of the value "test1" with cost 12 + pwHash: []byte("$2a$12$VZNNjuCUGX2NG5S1ci.3Ku9mI9DmA9XeXyrr7YzJuyTxuVBGdRAbm"), + wantIdentity: connector.Identity{ + Email: "jane@example.com", + Username: "jane", + UserID: "foobar", + EmailVerified: true, + }, + }, + { + name: "valid password timeout cost", + email: "jane@example.com", + password: pw, + // bcrypt hash of the value "test1" with cost 20 + pwHash: []byte("$2a$20$yODn5quqK9MZdePqYLs6Y.Jr4cOO1P0aXsKz0eTa2rxOmu8e7ETpi"), + wantedErr: fmt.Sprintf("password-hash comparison timeout: your bcrypt cost = 20, recommended cost = %d", recCost), + }, + { + name: "invalid password min cost", + email: "jane@example.com", + password: pw, + // bcrypt hash of the value "test2" with cost 4 + pwHash: []byte("$2a$04$pX8wwwpxw8xlXrToYaEgZemK0JIibMZYXPsgau7aPDoGyHPF73br."), + wantInvalid: true, + }, + { + name: "invalid password timeout cost", + email: "jane@example.com", + password: pw, + // bcrypt hash of the value "test2" with cost 20 + pwHash: []byte("$2a$20$WBD9cs63Zf0zqS99yyrQhODoDXphWw8MlYqVYRiftJH.lRJ1stnAa"), + wantedErr: fmt.Sprintf("password-hash comparison timeout: your bcrypt cost = 20, recommended cost = %d", recCost), + }, + } + + sEmail, sUsername, sUserID := "jane@example.com", "jane", "foobar" + for _, tc := range tests { + // Clean up before new test case + s.DeletePassword(sEmail) + + s.CreatePassword(storage.Password{ + Email: sEmail, + Username: sUsername, + UserID: sUserID, + Hash: tc.pwHash, + }) + + ident, valid, err := conn.Login(context.Background(), connector.Scopes{}, tc.email, tc.password) + if err != nil { + if err.Error() != tc.wantedErr { + t.Errorf("%s: error was incorrect:\n%v", tc.name, err) + } + continue + } + + if tc.wantedErr != "" { + t.Errorf("%s: expected error", tc.name) + continue + } + + if !valid { + if !tc.wantInvalid { + t.Errorf("%s: expected valid response", tc.name) + } + continue + } + + if tc.wantInvalid { + t.Errorf("%s: expected invalid response", tc.name) + continue + } + + if diff := pretty.Compare(tc.wantIdentity, ident); diff != "" { + t.Errorf("%s: %s", tc.email, diff) + } + } + +} + type storageWithKeysTrigger struct { storage.Storage f func()