Merge pull request #749 from ericchiang/postgres-timezones
storage: fix postgres timezone handling
This commit is contained in:
		| @@ -48,6 +48,7 @@ func RunTests(t *testing.T, newStorage func() storage.Storage) { | ||||
| 		{"PasswordCRUD", testPasswordCRUD}, | ||||
| 		{"KeysCRUD", testKeysCRUD}, | ||||
| 		{"GarbageCollection", testGC}, | ||||
| 		{"TimezoneSupport", testTimezones}, | ||||
| 	}) | ||||
| } | ||||
|  | ||||
| @@ -370,14 +371,23 @@ func testKeysCRUD(t *testing.T, s storage.Storage) { | ||||
| } | ||||
|  | ||||
| func testGC(t *testing.T, s storage.Storage) { | ||||
| 	n := time.Now().UTC() | ||||
| 	est, err := time.LoadLocation("America/New_York") | ||||
| 	if err != nil { | ||||
| 		t.Fatal(err) | ||||
| 	} | ||||
| 	pst, err := time.LoadLocation("America/Los_Angeles") | ||||
| 	if err != nil { | ||||
| 		t.Fatal(err) | ||||
| 	} | ||||
|  | ||||
| 	expiry := time.Now().In(est) | ||||
| 	c := storage.AuthCode{ | ||||
| 		ID:            storage.NewID(), | ||||
| 		ClientID:      "foobar", | ||||
| 		RedirectURI:   "https://localhost:80/callback", | ||||
| 		Nonce:         "foobar", | ||||
| 		Scopes:        []string{"openid", "email"}, | ||||
| 		Expiry:        n.Add(time.Second), | ||||
| 		Expiry:        expiry, | ||||
| 		ConnectorID:   "ldap", | ||||
| 		ConnectorData: []byte(`{"some":"data"}`), | ||||
| 		Claims: storage.Claims{ | ||||
| @@ -393,14 +403,21 @@ func testGC(t *testing.T, s storage.Storage) { | ||||
| 		t.Fatalf("failed creating auth code: %v", err) | ||||
| 	} | ||||
|  | ||||
| 	if _, err := s.GarbageCollect(n); err != nil { | ||||
| 	for _, tz := range []*time.Location{time.UTC, est, pst} { | ||||
| 		result, err := s.GarbageCollect(expiry.Add(-time.Hour).In(tz)) | ||||
| 		if err != nil { | ||||
| 			t.Errorf("garbage collection failed: %v", err) | ||||
| 		} else { | ||||
| 			if result.AuthCodes != 0 || result.AuthRequests != 0 { | ||||
| 				t.Errorf("expected no garbage collection results, got %#v", result) | ||||
| 			} | ||||
| 		} | ||||
| 		if _, err := s.GetAuthCode(c.ID); err != nil { | ||||
| 			t.Errorf("expected to be able to get auth code after GC: %v", err) | ||||
| 		} | ||||
| 	} | ||||
|  | ||||
| 	if r, err := s.GarbageCollect(n.Add(time.Minute)); err != nil { | ||||
| 	if r, err := s.GarbageCollect(expiry.Add(time.Hour)); err != nil { | ||||
| 		t.Errorf("garbage collection failed: %v", err) | ||||
| 	} else if r.AuthCodes != 1 { | ||||
| 		t.Errorf("expected to garbage collect 1 objects, got %d", r.AuthCodes) | ||||
| @@ -422,7 +439,7 @@ func testGC(t *testing.T, s storage.Storage) { | ||||
| 		State:               "bar", | ||||
| 		ForceApprovalPrompt: true, | ||||
| 		LoggedIn:            true, | ||||
| 		Expiry:              n, | ||||
| 		Expiry:              expiry, | ||||
| 		ConnectorID:         "ldap", | ||||
| 		ConnectorData:       []byte(`{"some":"data"}`), | ||||
| 		Claims: storage.Claims{ | ||||
| @@ -438,14 +455,21 @@ func testGC(t *testing.T, s storage.Storage) { | ||||
| 		t.Fatalf("failed creating auth request: %v", err) | ||||
| 	} | ||||
|  | ||||
| 	if _, err := s.GarbageCollect(n); err != nil { | ||||
| 	for _, tz := range []*time.Location{time.UTC, est, pst} { | ||||
| 		result, err := s.GarbageCollect(expiry.Add(-time.Hour).In(tz)) | ||||
| 		if err != nil { | ||||
| 			t.Errorf("garbage collection failed: %v", err) | ||||
| 		} else { | ||||
| 			if result.AuthCodes != 0 || result.AuthRequests != 0 { | ||||
| 				t.Errorf("expected no garbage collection results, got %#v", result) | ||||
| 			} | ||||
| 		} | ||||
| 		if _, err := s.GetAuthRequest(a.ID); err != nil { | ||||
| 			t.Errorf("expected to be able to get auth code after GC: %v", err) | ||||
| 		} | ||||
| 	} | ||||
|  | ||||
| 	if r, err := s.GarbageCollect(n.Add(time.Minute)); err != nil { | ||||
| 	if r, err := s.GarbageCollect(expiry.Add(time.Hour)); err != nil { | ||||
| 		t.Errorf("garbage collection failed: %v", err) | ||||
| 	} else if r.AuthRequests != 1 { | ||||
| 		t.Errorf("expected to garbage collect 1 objects, got %d", r.AuthRequests) | ||||
| @@ -457,3 +481,49 @@ func testGC(t *testing.T, s storage.Storage) { | ||||
| 		t.Errorf("expected storage.ErrNotFound, got %v", err) | ||||
| 	} | ||||
| } | ||||
|  | ||||
| // testTimezones tests that backends either fully support timezones or | ||||
| // do the correct standardization. | ||||
| func testTimezones(t *testing.T, s storage.Storage) { | ||||
| 	est, err := time.LoadLocation("America/New_York") | ||||
| 	if err != nil { | ||||
| 		t.Fatal(err) | ||||
| 	} | ||||
| 	// Create an expiry with timezone info. Only expect backends to be | ||||
| 	// accurate to the millisecond | ||||
| 	expiry := time.Now().In(est).Round(time.Millisecond) | ||||
|  | ||||
| 	c := storage.AuthCode{ | ||||
| 		ID:            storage.NewID(), | ||||
| 		ClientID:      "foobar", | ||||
| 		RedirectURI:   "https://localhost:80/callback", | ||||
| 		Nonce:         "foobar", | ||||
| 		Scopes:        []string{"openid", "email"}, | ||||
| 		Expiry:        expiry, | ||||
| 		ConnectorID:   "ldap", | ||||
| 		ConnectorData: []byte(`{"some":"data"}`), | ||||
| 		Claims: storage.Claims{ | ||||
| 			UserID:        "1", | ||||
| 			Username:      "jane", | ||||
| 			Email:         "jane.doe@example.com", | ||||
| 			EmailVerified: true, | ||||
| 			Groups:        []string{"a", "b"}, | ||||
| 		}, | ||||
| 	} | ||||
| 	if err := s.CreateAuthCode(c); err != nil { | ||||
| 		t.Fatalf("failed creating auth code: %v", err) | ||||
| 	} | ||||
| 	got, err := s.GetAuthCode(c.ID) | ||||
| 	if err != nil { | ||||
| 		t.Fatalf("failed to get auth code: %v", err) | ||||
| 	} | ||||
|  | ||||
| 	// Ensure that if the resulting time is converted to the same | ||||
| 	// timezone, it's the same value. We DO NOT expect timezones | ||||
| 	// to be preserved. | ||||
| 	gotTime := got.Expiry.In(est) | ||||
| 	wantTime := expiry | ||||
| 	if !gotTime.Equal(wantTime) { | ||||
| 		t.Fatalf("expected expiry %v got %v", wantTime, gotTime) | ||||
| 	} | ||||
| } | ||||
|   | ||||
| @@ -9,7 +9,7 @@ func (c *conn) migrate() (int, error) { | ||||
| 	_, err := c.Exec(` | ||||
| 		create table if not exists migrations ( | ||||
| 			num integer not null, | ||||
| 			at timestamp not null | ||||
| 			at timestamptz not null | ||||
| 		); | ||||
| 	`) | ||||
| 	if err != nil { | ||||
| @@ -100,7 +100,7 @@ var migrations = []migration{ | ||||
| 				connector_id text not null, | ||||
| 				connector_data bytea, | ||||
| 		 | ||||
| 				expiry timestamp not null | ||||
| 				expiry timestamptz not null | ||||
| 			); | ||||
| 		 | ||||
| 			create table auth_code ( | ||||
| @@ -119,7 +119,7 @@ var migrations = []migration{ | ||||
| 				connector_id text not null, | ||||
| 				connector_data bytea, | ||||
| 		 | ||||
| 				expiry timestamp not null | ||||
| 				expiry timestamptz not null | ||||
| 			); | ||||
| 		 | ||||
| 			create table refresh_token ( | ||||
| @@ -151,7 +151,7 @@ var migrations = []migration{ | ||||
| 				verification_keys bytea not null, -- JSON array | ||||
| 				signing_key bytea not null,       -- JSON object | ||||
| 				signing_key_pub bytea not null,   -- JSON object | ||||
| 				next_rotation timestamp not null | ||||
| 				next_rotation timestamptz not null | ||||
| 			); | ||||
| 		`, | ||||
| 	}, | ||||
|   | ||||
| @@ -4,6 +4,7 @@ package sql | ||||
| import ( | ||||
| 	"database/sql" | ||||
| 	"regexp" | ||||
| 	"time" | ||||
|  | ||||
| 	"github.com/Sirupsen/logrus" | ||||
| 	"github.com/cockroachdb/cockroach-go/crdb" | ||||
| @@ -28,6 +29,9 @@ type flavor struct { | ||||
| 	// | ||||
| 	// See: https://github.com/cockroachdb/docs/blob/63761c2e/_includes/app/txn-sample.go#L41-L44 | ||||
| 	executeTx func(db *sql.DB, fn func(*sql.Tx) error) error | ||||
|  | ||||
| 	// Does the flavor support timezones? | ||||
| 	supportsTimezones bool | ||||
| } | ||||
|  | ||||
| // A regexp with a replacement string. | ||||
| @@ -69,6 +73,8 @@ var ( | ||||
| 			} | ||||
| 			return tx.Commit() | ||||
| 		}, | ||||
|  | ||||
| 		supportsTimezones: true, | ||||
| 	} | ||||
|  | ||||
| 	flavorSQLite3 = flavor{ | ||||
| @@ -80,7 +86,7 @@ var ( | ||||
| 			{matchLiteral("boolean"), "integer"}, | ||||
| 			// Translate other types. | ||||
| 			{matchLiteral("bytea"), "blob"}, | ||||
| 			// {matchLiteral("timestamp"), "integer"}, | ||||
| 			{matchLiteral("timestamptz"), "timestamp"}, | ||||
| 			// SQLite doesn't have a "now()" method, replace with "date('now')" | ||||
| 			{regexp.MustCompile(`\bnow\(\)`), "date('now')"}, | ||||
| 		}, | ||||
| @@ -107,6 +113,22 @@ func (f flavor) translate(query string) string { | ||||
| 	return query | ||||
| } | ||||
|  | ||||
| // translateArgs translates query parameters that may be unique to | ||||
| // a specific SQL flavor. For example, standardizing "time.Time" | ||||
| // types to UTC for clients that don't provide timezone support. | ||||
| func (c *conn) translateArgs(args []interface{}) []interface{} { | ||||
| 	if c.flavor.supportsTimezones { | ||||
| 		return args | ||||
| 	} | ||||
|  | ||||
| 	for i, arg := range args { | ||||
| 		if t, ok := arg.(time.Time); ok { | ||||
| 			args[i] = t.UTC() | ||||
| 		} | ||||
| 	} | ||||
| 	return args | ||||
| } | ||||
|  | ||||
| // conn is the main database connection. | ||||
| type conn struct { | ||||
| 	db     *sql.DB | ||||
| @@ -122,17 +144,17 @@ func (c *conn) Close() error { | ||||
|  | ||||
| func (c *conn) Exec(query string, args ...interface{}) (sql.Result, error) { | ||||
| 	query = c.flavor.translate(query) | ||||
| 	return c.db.Exec(query, args...) | ||||
| 	return c.db.Exec(query, c.translateArgs(args)...) | ||||
| } | ||||
|  | ||||
| func (c *conn) Query(query string, args ...interface{}) (*sql.Rows, error) { | ||||
| 	query = c.flavor.translate(query) | ||||
| 	return c.db.Query(query, args...) | ||||
| 	return c.db.Query(query, c.translateArgs(args)...) | ||||
| } | ||||
|  | ||||
| func (c *conn) QueryRow(query string, args ...interface{}) *sql.Row { | ||||
| 	query = c.flavor.translate(query) | ||||
| 	return c.db.QueryRow(query, args...) | ||||
| 	return c.db.QueryRow(query, c.translateArgs(args)...) | ||||
| } | ||||
|  | ||||
| // ExecTx runs a method which operates on a transaction. | ||||
| @@ -163,15 +185,15 @@ type trans struct { | ||||
|  | ||||
| func (t *trans) Exec(query string, args ...interface{}) (sql.Result, error) { | ||||
| 	query = t.c.flavor.translate(query) | ||||
| 	return t.tx.Exec(query, args...) | ||||
| 	return t.tx.Exec(query, t.c.translateArgs(args)...) | ||||
| } | ||||
|  | ||||
| func (t *trans) Query(query string, args ...interface{}) (*sql.Rows, error) { | ||||
| 	query = t.c.flavor.translate(query) | ||||
| 	return t.tx.Query(query, args...) | ||||
| 	return t.tx.Query(query, t.c.translateArgs(args)...) | ||||
| } | ||||
|  | ||||
| func (t *trans) QueryRow(query string, args ...interface{}) *sql.Row { | ||||
| 	query = t.c.flavor.translate(query) | ||||
| 	return t.tx.QueryRow(query, args...) | ||||
| 	return t.tx.QueryRow(query, t.c.translateArgs(args)...) | ||||
| } | ||||
|   | ||||
| @@ -44,11 +44,9 @@ type GCResult struct { | ||||
| 	AuthCodes    int64 | ||||
| } | ||||
|  | ||||
| // Storage is the storage interface used by the server. Implementations, at minimum | ||||
| // require compare-and-swap atomic actions. | ||||
| // | ||||
| // Implementations are expected to perform their own garbage collection of | ||||
| // expired objects (expect keys, which are handled by the server). | ||||
| // Storage is the storage interface used by the server. Implementations are | ||||
| // required to be able to perform atomic compare-and-swap updates and either | ||||
| // support timezones or standardize on UTC. | ||||
| type Storage interface { | ||||
| 	Close() error | ||||
|  | ||||
|   | ||||
		Reference in New Issue
	
	Block a user