Merge pull request #749 from ericchiang/postgres-timezones
storage: fix postgres timezone handling
This commit is contained in:
		| @@ -48,6 +48,7 @@ func RunTests(t *testing.T, newStorage func() storage.Storage) { | |||||||
| 		{"PasswordCRUD", testPasswordCRUD}, | 		{"PasswordCRUD", testPasswordCRUD}, | ||||||
| 		{"KeysCRUD", testKeysCRUD}, | 		{"KeysCRUD", testKeysCRUD}, | ||||||
| 		{"GarbageCollection", testGC}, | 		{"GarbageCollection", testGC}, | ||||||
|  | 		{"TimezoneSupport", testTimezones}, | ||||||
| 	}) | 	}) | ||||||
| } | } | ||||||
|  |  | ||||||
| @@ -370,14 +371,23 @@ func testKeysCRUD(t *testing.T, s storage.Storage) { | |||||||
| } | } | ||||||
|  |  | ||||||
| func testGC(t *testing.T, s storage.Storage) { | func testGC(t *testing.T, s storage.Storage) { | ||||||
| 	n := time.Now().UTC() | 	est, err := time.LoadLocation("America/New_York") | ||||||
|  | 	if err != nil { | ||||||
|  | 		t.Fatal(err) | ||||||
|  | 	} | ||||||
|  | 	pst, err := time.LoadLocation("America/Los_Angeles") | ||||||
|  | 	if err != nil { | ||||||
|  | 		t.Fatal(err) | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	expiry := time.Now().In(est) | ||||||
| 	c := storage.AuthCode{ | 	c := storage.AuthCode{ | ||||||
| 		ID:            storage.NewID(), | 		ID:            storage.NewID(), | ||||||
| 		ClientID:      "foobar", | 		ClientID:      "foobar", | ||||||
| 		RedirectURI:   "https://localhost:80/callback", | 		RedirectURI:   "https://localhost:80/callback", | ||||||
| 		Nonce:         "foobar", | 		Nonce:         "foobar", | ||||||
| 		Scopes:        []string{"openid", "email"}, | 		Scopes:        []string{"openid", "email"}, | ||||||
| 		Expiry:        n.Add(time.Second), | 		Expiry:        expiry, | ||||||
| 		ConnectorID:   "ldap", | 		ConnectorID:   "ldap", | ||||||
| 		ConnectorData: []byte(`{"some":"data"}`), | 		ConnectorData: []byte(`{"some":"data"}`), | ||||||
| 		Claims: storage.Claims{ | 		Claims: storage.Claims{ | ||||||
| @@ -393,14 +403,21 @@ func testGC(t *testing.T, s storage.Storage) { | |||||||
| 		t.Fatalf("failed creating auth code: %v", err) | 		t.Fatalf("failed creating auth code: %v", err) | ||||||
| 	} | 	} | ||||||
|  |  | ||||||
| 	if _, err := s.GarbageCollect(n); err != nil { | 	for _, tz := range []*time.Location{time.UTC, est, pst} { | ||||||
| 		t.Errorf("garbage collection failed: %v", err) | 		result, err := s.GarbageCollect(expiry.Add(-time.Hour).In(tz)) | ||||||
| 	} | 		if err != nil { | ||||||
| 	if _, err := s.GetAuthCode(c.ID); err != nil { | 			t.Errorf("garbage collection failed: %v", err) | ||||||
| 		t.Errorf("expected to be able to get auth code after GC: %v", err) | 		} else { | ||||||
|  | 			if result.AuthCodes != 0 || result.AuthRequests != 0 { | ||||||
|  | 				t.Errorf("expected no garbage collection results, got %#v", result) | ||||||
|  | 			} | ||||||
|  | 		} | ||||||
|  | 		if _, err := s.GetAuthCode(c.ID); err != nil { | ||||||
|  | 			t.Errorf("expected to be able to get auth code after GC: %v", err) | ||||||
|  | 		} | ||||||
| 	} | 	} | ||||||
|  |  | ||||||
| 	if r, err := s.GarbageCollect(n.Add(time.Minute)); err != nil { | 	if r, err := s.GarbageCollect(expiry.Add(time.Hour)); err != nil { | ||||||
| 		t.Errorf("garbage collection failed: %v", err) | 		t.Errorf("garbage collection failed: %v", err) | ||||||
| 	} else if r.AuthCodes != 1 { | 	} else if r.AuthCodes != 1 { | ||||||
| 		t.Errorf("expected to garbage collect 1 objects, got %d", r.AuthCodes) | 		t.Errorf("expected to garbage collect 1 objects, got %d", r.AuthCodes) | ||||||
| @@ -422,7 +439,7 @@ func testGC(t *testing.T, s storage.Storage) { | |||||||
| 		State:               "bar", | 		State:               "bar", | ||||||
| 		ForceApprovalPrompt: true, | 		ForceApprovalPrompt: true, | ||||||
| 		LoggedIn:            true, | 		LoggedIn:            true, | ||||||
| 		Expiry:              n, | 		Expiry:              expiry, | ||||||
| 		ConnectorID:         "ldap", | 		ConnectorID:         "ldap", | ||||||
| 		ConnectorData:       []byte(`{"some":"data"}`), | 		ConnectorData:       []byte(`{"some":"data"}`), | ||||||
| 		Claims: storage.Claims{ | 		Claims: storage.Claims{ | ||||||
| @@ -438,14 +455,21 @@ func testGC(t *testing.T, s storage.Storage) { | |||||||
| 		t.Fatalf("failed creating auth request: %v", err) | 		t.Fatalf("failed creating auth request: %v", err) | ||||||
| 	} | 	} | ||||||
|  |  | ||||||
| 	if _, err := s.GarbageCollect(n); err != nil { | 	for _, tz := range []*time.Location{time.UTC, est, pst} { | ||||||
| 		t.Errorf("garbage collection failed: %v", err) | 		result, err := s.GarbageCollect(expiry.Add(-time.Hour).In(tz)) | ||||||
| 	} | 		if err != nil { | ||||||
| 	if _, err := s.GetAuthRequest(a.ID); err != nil { | 			t.Errorf("garbage collection failed: %v", err) | ||||||
| 		t.Errorf("expected to be able to get auth code after GC: %v", err) | 		} else { | ||||||
|  | 			if result.AuthCodes != 0 || result.AuthRequests != 0 { | ||||||
|  | 				t.Errorf("expected no garbage collection results, got %#v", result) | ||||||
|  | 			} | ||||||
|  | 		} | ||||||
|  | 		if _, err := s.GetAuthRequest(a.ID); err != nil { | ||||||
|  | 			t.Errorf("expected to be able to get auth code after GC: %v", err) | ||||||
|  | 		} | ||||||
| 	} | 	} | ||||||
|  |  | ||||||
| 	if r, err := s.GarbageCollect(n.Add(time.Minute)); err != nil { | 	if r, err := s.GarbageCollect(expiry.Add(time.Hour)); err != nil { | ||||||
| 		t.Errorf("garbage collection failed: %v", err) | 		t.Errorf("garbage collection failed: %v", err) | ||||||
| 	} else if r.AuthRequests != 1 { | 	} else if r.AuthRequests != 1 { | ||||||
| 		t.Errorf("expected to garbage collect 1 objects, got %d", r.AuthRequests) | 		t.Errorf("expected to garbage collect 1 objects, got %d", r.AuthRequests) | ||||||
| @@ -457,3 +481,49 @@ func testGC(t *testing.T, s storage.Storage) { | |||||||
| 		t.Errorf("expected storage.ErrNotFound, got %v", err) | 		t.Errorf("expected storage.ErrNotFound, got %v", err) | ||||||
| 	} | 	} | ||||||
| } | } | ||||||
|  |  | ||||||
|  | // testTimezones tests that backends either fully support timezones or | ||||||
|  | // do the correct standardization. | ||||||
|  | func testTimezones(t *testing.T, s storage.Storage) { | ||||||
|  | 	est, err := time.LoadLocation("America/New_York") | ||||||
|  | 	if err != nil { | ||||||
|  | 		t.Fatal(err) | ||||||
|  | 	} | ||||||
|  | 	// Create an expiry with timezone info. Only expect backends to be | ||||||
|  | 	// accurate to the millisecond | ||||||
|  | 	expiry := time.Now().In(est).Round(time.Millisecond) | ||||||
|  |  | ||||||
|  | 	c := storage.AuthCode{ | ||||||
|  | 		ID:            storage.NewID(), | ||||||
|  | 		ClientID:      "foobar", | ||||||
|  | 		RedirectURI:   "https://localhost:80/callback", | ||||||
|  | 		Nonce:         "foobar", | ||||||
|  | 		Scopes:        []string{"openid", "email"}, | ||||||
|  | 		Expiry:        expiry, | ||||||
|  | 		ConnectorID:   "ldap", | ||||||
|  | 		ConnectorData: []byte(`{"some":"data"}`), | ||||||
|  | 		Claims: storage.Claims{ | ||||||
|  | 			UserID:        "1", | ||||||
|  | 			Username:      "jane", | ||||||
|  | 			Email:         "jane.doe@example.com", | ||||||
|  | 			EmailVerified: true, | ||||||
|  | 			Groups:        []string{"a", "b"}, | ||||||
|  | 		}, | ||||||
|  | 	} | ||||||
|  | 	if err := s.CreateAuthCode(c); err != nil { | ||||||
|  | 		t.Fatalf("failed creating auth code: %v", err) | ||||||
|  | 	} | ||||||
|  | 	got, err := s.GetAuthCode(c.ID) | ||||||
|  | 	if err != nil { | ||||||
|  | 		t.Fatalf("failed to get auth code: %v", err) | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	// Ensure that if the resulting time is converted to the same | ||||||
|  | 	// timezone, it's the same value. We DO NOT expect timezones | ||||||
|  | 	// to be preserved. | ||||||
|  | 	gotTime := got.Expiry.In(est) | ||||||
|  | 	wantTime := expiry | ||||||
|  | 	if !gotTime.Equal(wantTime) { | ||||||
|  | 		t.Fatalf("expected expiry %v got %v", wantTime, gotTime) | ||||||
|  | 	} | ||||||
|  | } | ||||||
|   | |||||||
| @@ -9,7 +9,7 @@ func (c *conn) migrate() (int, error) { | |||||||
| 	_, err := c.Exec(` | 	_, err := c.Exec(` | ||||||
| 		create table if not exists migrations ( | 		create table if not exists migrations ( | ||||||
| 			num integer not null, | 			num integer not null, | ||||||
| 			at timestamp not null | 			at timestamptz not null | ||||||
| 		); | 		); | ||||||
| 	`) | 	`) | ||||||
| 	if err != nil { | 	if err != nil { | ||||||
| @@ -100,7 +100,7 @@ var migrations = []migration{ | |||||||
| 				connector_id text not null, | 				connector_id text not null, | ||||||
| 				connector_data bytea, | 				connector_data bytea, | ||||||
| 		 | 		 | ||||||
| 				expiry timestamp not null | 				expiry timestamptz not null | ||||||
| 			); | 			); | ||||||
| 		 | 		 | ||||||
| 			create table auth_code ( | 			create table auth_code ( | ||||||
| @@ -119,7 +119,7 @@ var migrations = []migration{ | |||||||
| 				connector_id text not null, | 				connector_id text not null, | ||||||
| 				connector_data bytea, | 				connector_data bytea, | ||||||
| 		 | 		 | ||||||
| 				expiry timestamp not null | 				expiry timestamptz not null | ||||||
| 			); | 			); | ||||||
| 		 | 		 | ||||||
| 			create table refresh_token ( | 			create table refresh_token ( | ||||||
| @@ -151,7 +151,7 @@ var migrations = []migration{ | |||||||
| 				verification_keys bytea not null, -- JSON array | 				verification_keys bytea not null, -- JSON array | ||||||
| 				signing_key bytea not null,       -- JSON object | 				signing_key bytea not null,       -- JSON object | ||||||
| 				signing_key_pub bytea not null,   -- JSON object | 				signing_key_pub bytea not null,   -- JSON object | ||||||
| 				next_rotation timestamp not null | 				next_rotation timestamptz not null | ||||||
| 			); | 			); | ||||||
| 		`, | 		`, | ||||||
| 	}, | 	}, | ||||||
|   | |||||||
| @@ -4,6 +4,7 @@ package sql | |||||||
| import ( | import ( | ||||||
| 	"database/sql" | 	"database/sql" | ||||||
| 	"regexp" | 	"regexp" | ||||||
|  | 	"time" | ||||||
|  |  | ||||||
| 	"github.com/Sirupsen/logrus" | 	"github.com/Sirupsen/logrus" | ||||||
| 	"github.com/cockroachdb/cockroach-go/crdb" | 	"github.com/cockroachdb/cockroach-go/crdb" | ||||||
| @@ -28,6 +29,9 @@ type flavor struct { | |||||||
| 	// | 	// | ||||||
| 	// See: https://github.com/cockroachdb/docs/blob/63761c2e/_includes/app/txn-sample.go#L41-L44 | 	// See: https://github.com/cockroachdb/docs/blob/63761c2e/_includes/app/txn-sample.go#L41-L44 | ||||||
| 	executeTx func(db *sql.DB, fn func(*sql.Tx) error) error | 	executeTx func(db *sql.DB, fn func(*sql.Tx) error) error | ||||||
|  |  | ||||||
|  | 	// Does the flavor support timezones? | ||||||
|  | 	supportsTimezones bool | ||||||
| } | } | ||||||
|  |  | ||||||
| // A regexp with a replacement string. | // A regexp with a replacement string. | ||||||
| @@ -69,6 +73,8 @@ var ( | |||||||
| 			} | 			} | ||||||
| 			return tx.Commit() | 			return tx.Commit() | ||||||
| 		}, | 		}, | ||||||
|  |  | ||||||
|  | 		supportsTimezones: true, | ||||||
| 	} | 	} | ||||||
|  |  | ||||||
| 	flavorSQLite3 = flavor{ | 	flavorSQLite3 = flavor{ | ||||||
| @@ -80,7 +86,7 @@ var ( | |||||||
| 			{matchLiteral("boolean"), "integer"}, | 			{matchLiteral("boolean"), "integer"}, | ||||||
| 			// Translate other types. | 			// Translate other types. | ||||||
| 			{matchLiteral("bytea"), "blob"}, | 			{matchLiteral("bytea"), "blob"}, | ||||||
| 			// {matchLiteral("timestamp"), "integer"}, | 			{matchLiteral("timestamptz"), "timestamp"}, | ||||||
| 			// SQLite doesn't have a "now()" method, replace with "date('now')" | 			// SQLite doesn't have a "now()" method, replace with "date('now')" | ||||||
| 			{regexp.MustCompile(`\bnow\(\)`), "date('now')"}, | 			{regexp.MustCompile(`\bnow\(\)`), "date('now')"}, | ||||||
| 		}, | 		}, | ||||||
| @@ -107,6 +113,22 @@ func (f flavor) translate(query string) string { | |||||||
| 	return query | 	return query | ||||||
| } | } | ||||||
|  |  | ||||||
|  | // translateArgs translates query parameters that may be unique to | ||||||
|  | // a specific SQL flavor. For example, standardizing "time.Time" | ||||||
|  | // types to UTC for clients that don't provide timezone support. | ||||||
|  | func (c *conn) translateArgs(args []interface{}) []interface{} { | ||||||
|  | 	if c.flavor.supportsTimezones { | ||||||
|  | 		return args | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	for i, arg := range args { | ||||||
|  | 		if t, ok := arg.(time.Time); ok { | ||||||
|  | 			args[i] = t.UTC() | ||||||
|  | 		} | ||||||
|  | 	} | ||||||
|  | 	return args | ||||||
|  | } | ||||||
|  |  | ||||||
| // conn is the main database connection. | // conn is the main database connection. | ||||||
| type conn struct { | type conn struct { | ||||||
| 	db     *sql.DB | 	db     *sql.DB | ||||||
| @@ -122,17 +144,17 @@ func (c *conn) Close() error { | |||||||
|  |  | ||||||
| func (c *conn) Exec(query string, args ...interface{}) (sql.Result, error) { | func (c *conn) Exec(query string, args ...interface{}) (sql.Result, error) { | ||||||
| 	query = c.flavor.translate(query) | 	query = c.flavor.translate(query) | ||||||
| 	return c.db.Exec(query, args...) | 	return c.db.Exec(query, c.translateArgs(args)...) | ||||||
| } | } | ||||||
|  |  | ||||||
| func (c *conn) Query(query string, args ...interface{}) (*sql.Rows, error) { | func (c *conn) Query(query string, args ...interface{}) (*sql.Rows, error) { | ||||||
| 	query = c.flavor.translate(query) | 	query = c.flavor.translate(query) | ||||||
| 	return c.db.Query(query, args...) | 	return c.db.Query(query, c.translateArgs(args)...) | ||||||
| } | } | ||||||
|  |  | ||||||
| func (c *conn) QueryRow(query string, args ...interface{}) *sql.Row { | func (c *conn) QueryRow(query string, args ...interface{}) *sql.Row { | ||||||
| 	query = c.flavor.translate(query) | 	query = c.flavor.translate(query) | ||||||
| 	return c.db.QueryRow(query, args...) | 	return c.db.QueryRow(query, c.translateArgs(args)...) | ||||||
| } | } | ||||||
|  |  | ||||||
| // ExecTx runs a method which operates on a transaction. | // ExecTx runs a method which operates on a transaction. | ||||||
| @@ -163,15 +185,15 @@ type trans struct { | |||||||
|  |  | ||||||
| func (t *trans) Exec(query string, args ...interface{}) (sql.Result, error) { | func (t *trans) Exec(query string, args ...interface{}) (sql.Result, error) { | ||||||
| 	query = t.c.flavor.translate(query) | 	query = t.c.flavor.translate(query) | ||||||
| 	return t.tx.Exec(query, args...) | 	return t.tx.Exec(query, t.c.translateArgs(args)...) | ||||||
| } | } | ||||||
|  |  | ||||||
| func (t *trans) Query(query string, args ...interface{}) (*sql.Rows, error) { | func (t *trans) Query(query string, args ...interface{}) (*sql.Rows, error) { | ||||||
| 	query = t.c.flavor.translate(query) | 	query = t.c.flavor.translate(query) | ||||||
| 	return t.tx.Query(query, args...) | 	return t.tx.Query(query, t.c.translateArgs(args)...) | ||||||
| } | } | ||||||
|  |  | ||||||
| func (t *trans) QueryRow(query string, args ...interface{}) *sql.Row { | func (t *trans) QueryRow(query string, args ...interface{}) *sql.Row { | ||||||
| 	query = t.c.flavor.translate(query) | 	query = t.c.flavor.translate(query) | ||||||
| 	return t.tx.QueryRow(query, args...) | 	return t.tx.QueryRow(query, t.c.translateArgs(args)...) | ||||||
| } | } | ||||||
|   | |||||||
| @@ -44,11 +44,9 @@ type GCResult struct { | |||||||
| 	AuthCodes    int64 | 	AuthCodes    int64 | ||||||
| } | } | ||||||
|  |  | ||||||
| // Storage is the storage interface used by the server. Implementations, at minimum | // Storage is the storage interface used by the server. Implementations are | ||||||
| // require compare-and-swap atomic actions. | // required to be able to perform atomic compare-and-swap updates and either | ||||||
| // | // support timezones or standardize on UTC. | ||||||
| // Implementations are expected to perform their own garbage collection of |  | ||||||
| // expired objects (expect keys, which are handled by the server). |  | ||||||
| type Storage interface { | type Storage interface { | ||||||
| 	Close() error | 	Close() error | ||||||
|  |  | ||||||
|   | |||||||
		Reference in New Issue
	
	Block a user