storage/sql: add garbage collection method
This commit is contained in:
		| @@ -5,7 +5,6 @@ import ( | |||||||
| 	"fmt" | 	"fmt" | ||||||
| 	"net/url" | 	"net/url" | ||||||
| 	"strconv" | 	"strconv" | ||||||
| 	"time" |  | ||||||
|  |  | ||||||
| 	"github.com/coreos/dex/storage" | 	"github.com/coreos/dex/storage" | ||||||
| ) | ) | ||||||
| @@ -22,7 +21,7 @@ func (s *SQLite3) Open() (storage.Storage, error) { | |||||||
| 	if err != nil { | 	if err != nil { | ||||||
| 		return nil, err | 		return nil, err | ||||||
| 	} | 	} | ||||||
| 	return withGC(conn, time.Now), nil | 	return conn, nil | ||||||
| } | } | ||||||
|  |  | ||||||
| func (s *SQLite3) open() (*conn, error) { | func (s *SQLite3) open() (*conn, error) { | ||||||
| @@ -76,7 +75,7 @@ func (p *Postgres) Open() (storage.Storage, error) { | |||||||
| 	if err != nil { | 	if err != nil { | ||||||
| 		return nil, err | 		return nil, err | ||||||
| 	} | 	} | ||||||
| 	return withGC(conn, time.Now), nil | 	return conn, nil | ||||||
| } | } | ||||||
|  |  | ||||||
| func (p *Postgres) open() (*conn, error) { | func (p *Postgres) open() (*conn, error) { | ||||||
|   | |||||||
| @@ -54,7 +54,7 @@ func TestSQLite3(t *testing.T) { | |||||||
| 	} | 	} | ||||||
|  |  | ||||||
| 	withTimeout(time.Second*10, func() { | 	withTimeout(time.Second*10, func() { | ||||||
| 		conformance.RunTestSuite(t, newStorage) | 		conformance.RunTests(t, newStorage) | ||||||
| 	}) | 	}) | ||||||
| } | } | ||||||
|  |  | ||||||
| @@ -85,6 +85,6 @@ func TestPostgres(t *testing.T) { | |||||||
| 		return conn | 		return conn | ||||||
| 	} | 	} | ||||||
| 	withTimeout(time.Minute*1, func() { | 	withTimeout(time.Minute*1, func() { | ||||||
| 		conformance.RunTestSuite(t, newStorage) | 		conformance.RunTests(t, newStorage) | ||||||
| 	}) | 	}) | ||||||
| } | } | ||||||
|   | |||||||
| @@ -7,6 +7,7 @@ import ( | |||||||
| 	"errors" | 	"errors" | ||||||
| 	"fmt" | 	"fmt" | ||||||
| 	"strings" | 	"strings" | ||||||
|  | 	"time" | ||||||
|  |  | ||||||
| 	"github.com/coreos/dex/storage" | 	"github.com/coreos/dex/storage" | ||||||
| ) | ) | ||||||
| @@ -83,6 +84,25 @@ type scanner interface { | |||||||
| 	Scan(dest ...interface{}) error | 	Scan(dest ...interface{}) error | ||||||
| } | } | ||||||
|  |  | ||||||
|  | func (c *conn) GarbageCollect(now time.Time) (result storage.GCResult, err error) { | ||||||
|  | 	r, err := c.Exec(`delete from auth_request where expiry < $1`, now) | ||||||
|  | 	if err != nil { | ||||||
|  | 		return result, fmt.Errorf("gc auth_request: %v", err) | ||||||
|  | 	} | ||||||
|  | 	if n, err := r.RowsAffected(); err == nil { | ||||||
|  | 		result.AuthRequests = n | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	r, err = c.Exec(`delete from auth_code where expiry < $1`, now) | ||||||
|  | 	if err != nil { | ||||||
|  | 		return result, fmt.Errorf("gc auth_code: %v", err) | ||||||
|  | 	} | ||||||
|  | 	if n, err := r.RowsAffected(); err == nil { | ||||||
|  | 		result.AuthCodes = n | ||||||
|  | 	} | ||||||
|  | 	return | ||||||
|  | } | ||||||
|  |  | ||||||
| func (c *conn) CreateAuthRequest(a storage.AuthRequest) error { | func (c *conn) CreateAuthRequest(a storage.AuthRequest) error { | ||||||
| 	_, err := c.Exec(` | 	_, err := c.Exec(` | ||||||
| 		insert into auth_request ( | 		insert into auth_request ( | ||||||
|   | |||||||
| @@ -1,53 +0,0 @@ | |||||||
| package sql |  | ||||||
|  |  | ||||||
| import ( |  | ||||||
| 	"context" |  | ||||||
| 	"fmt" |  | ||||||
| 	"log" |  | ||||||
| 	"time" |  | ||||||
|  |  | ||||||
| 	"github.com/coreos/dex/storage" |  | ||||||
| ) |  | ||||||
|  |  | ||||||
| type gc struct { |  | ||||||
| 	now  func() time.Time |  | ||||||
| 	conn *conn |  | ||||||
| } |  | ||||||
|  |  | ||||||
| func (gc gc) run() error { |  | ||||||
| 	for _, table := range []string{"auth_request", "auth_code"} { |  | ||||||
| 		_, err := gc.conn.Exec(`delete from `+table+` where expiry < $1`, gc.now()) |  | ||||||
| 		if err != nil { |  | ||||||
| 			return fmt.Errorf("gc %s: %v", table, err) |  | ||||||
| 		} |  | ||||||
| 		// TODO(ericchiang): when we have levelled logging print how many rows were gc'd |  | ||||||
| 	} |  | ||||||
| 	return nil |  | ||||||
| } |  | ||||||
|  |  | ||||||
| type withCancel struct { |  | ||||||
| 	storage.Storage |  | ||||||
| 	cancel context.CancelFunc |  | ||||||
| } |  | ||||||
|  |  | ||||||
| func (w withCancel) Close() error { |  | ||||||
| 	w.cancel() |  | ||||||
| 	return w.Storage.Close() |  | ||||||
| } |  | ||||||
|  |  | ||||||
| func withGC(conn *conn, now func() time.Time) storage.Storage { |  | ||||||
| 	ctx, cancel := context.WithCancel(context.Background()) |  | ||||||
| 	run := (gc{now, conn}).run |  | ||||||
| 	go func() { |  | ||||||
| 		for { |  | ||||||
| 			select { |  | ||||||
| 			case <-time.After(time.Second * 30): |  | ||||||
| 				if err := run(); err != nil { |  | ||||||
| 					log.Printf("gc failed: %v", err) |  | ||||||
| 				} |  | ||||||
| 			case <-ctx.Done(): |  | ||||||
| 			} |  | ||||||
| 		} |  | ||||||
| 	}() |  | ||||||
| 	return withCancel{conn, cancel} |  | ||||||
| } |  | ||||||
| @@ -1,53 +0,0 @@ | |||||||
| package sql |  | ||||||
|  |  | ||||||
| import ( |  | ||||||
| 	"testing" |  | ||||||
| 	"time" |  | ||||||
|  |  | ||||||
| 	"github.com/coreos/dex/storage" |  | ||||||
| ) |  | ||||||
|  |  | ||||||
| func TestGC(t *testing.T) { |  | ||||||
| 	// TODO(ericchiang): Add a GarbageCollect method to the storage interface so |  | ||||||
| 	// we can write conformance tests instead of directly testing each implementation. |  | ||||||
| 	s := &SQLite3{":memory:"} |  | ||||||
| 	conn, err := s.open() |  | ||||||
| 	if err != nil { |  | ||||||
| 		t.Fatal(err) |  | ||||||
| 	} |  | ||||||
| 	defer conn.Close() |  | ||||||
|  |  | ||||||
| 	clock := time.Now() |  | ||||||
| 	now := func() time.Time { return clock } |  | ||||||
|  |  | ||||||
| 	runGC := (gc{now, conn}).run |  | ||||||
|  |  | ||||||
| 	a := storage.AuthRequest{ |  | ||||||
| 		ID:     storage.NewID(), |  | ||||||
| 		Expiry: now().Add(time.Second), |  | ||||||
| 	} |  | ||||||
|  |  | ||||||
| 	if err := conn.CreateAuthRequest(a); err != nil { |  | ||||||
| 		t.Fatal(err) |  | ||||||
| 	} |  | ||||||
|  |  | ||||||
| 	if err := runGC(); err != nil { |  | ||||||
| 		t.Errorf("gc failed: %v", err) |  | ||||||
| 	} |  | ||||||
|  |  | ||||||
| 	if _, err := conn.GetAuthRequest(a.ID); err != nil { |  | ||||||
| 		t.Errorf("failed to get auth request after gc: %v", err) |  | ||||||
| 	} |  | ||||||
|  |  | ||||||
| 	clock = clock.Add(time.Minute) |  | ||||||
|  |  | ||||||
| 	if err := runGC(); err != nil { |  | ||||||
| 		t.Errorf("gc failed: %v", err) |  | ||||||
| 	} |  | ||||||
|  |  | ||||||
| 	if _, err := conn.GetAuthRequest(a.ID); err == nil { |  | ||||||
| 		t.Errorf("expected error after gc'ing auth request: %v", err) |  | ||||||
| 	} else if err != storage.ErrNotFound { |  | ||||||
| 		t.Errorf("expected error storage.NotFound got: %v", err) |  | ||||||
| 	} |  | ||||||
| } |  | ||||||
		Reference in New Issue
	
	Block a user