storage/sql: add garbage collection method

This commit is contained in:
Eric Chiang 2016-10-12 18:48:09 -07:00
parent c14ab3c44e
commit 9ce05ecf73
5 changed files with 24 additions and 111 deletions

View File

@ -5,7 +5,6 @@ import (
"fmt" "fmt"
"net/url" "net/url"
"strconv" "strconv"
"time"
"github.com/coreos/dex/storage" "github.com/coreos/dex/storage"
) )
@ -22,7 +21,7 @@ func (s *SQLite3) Open() (storage.Storage, error) {
if err != nil { if err != nil {
return nil, err return nil, err
} }
return withGC(conn, time.Now), nil return conn, nil
} }
func (s *SQLite3) open() (*conn, error) { func (s *SQLite3) open() (*conn, error) {
@ -76,7 +75,7 @@ func (p *Postgres) Open() (storage.Storage, error) {
if err != nil { if err != nil {
return nil, err return nil, err
} }
return withGC(conn, time.Now), nil return conn, nil
} }
func (p *Postgres) open() (*conn, error) { func (p *Postgres) open() (*conn, error) {

View File

@ -54,7 +54,7 @@ func TestSQLite3(t *testing.T) {
} }
withTimeout(time.Second*10, func() { withTimeout(time.Second*10, func() {
conformance.RunTestSuite(t, newStorage) conformance.RunTests(t, newStorage)
}) })
} }
@ -85,6 +85,6 @@ func TestPostgres(t *testing.T) {
return conn return conn
} }
withTimeout(time.Minute*1, func() { withTimeout(time.Minute*1, func() {
conformance.RunTestSuite(t, newStorage) conformance.RunTests(t, newStorage)
}) })
} }

View File

@ -7,6 +7,7 @@ import (
"errors" "errors"
"fmt" "fmt"
"strings" "strings"
"time"
"github.com/coreos/dex/storage" "github.com/coreos/dex/storage"
) )
@ -83,6 +84,25 @@ type scanner interface {
Scan(dest ...interface{}) error Scan(dest ...interface{}) error
} }
func (c *conn) GarbageCollect(now time.Time) (result storage.GCResult, err error) {
r, err := c.Exec(`delete from auth_request where expiry < $1`, now)
if err != nil {
return result, fmt.Errorf("gc auth_request: %v", err)
}
if n, err := r.RowsAffected(); err == nil {
result.AuthRequests = n
}
r, err = c.Exec(`delete from auth_code where expiry < $1`, now)
if err != nil {
return result, fmt.Errorf("gc auth_code: %v", err)
}
if n, err := r.RowsAffected(); err == nil {
result.AuthCodes = n
}
return
}
func (c *conn) CreateAuthRequest(a storage.AuthRequest) error { func (c *conn) CreateAuthRequest(a storage.AuthRequest) error {
_, err := c.Exec(` _, err := c.Exec(`
insert into auth_request ( insert into auth_request (

View File

@ -1,53 +0,0 @@
package sql
import (
"context"
"fmt"
"log"
"time"
"github.com/coreos/dex/storage"
)
type gc struct {
now func() time.Time
conn *conn
}
func (gc gc) run() error {
for _, table := range []string{"auth_request", "auth_code"} {
_, err := gc.conn.Exec(`delete from `+table+` where expiry < $1`, gc.now())
if err != nil {
return fmt.Errorf("gc %s: %v", table, err)
}
// TODO(ericchiang): when we have levelled logging print how many rows were gc'd
}
return nil
}
type withCancel struct {
storage.Storage
cancel context.CancelFunc
}
func (w withCancel) Close() error {
w.cancel()
return w.Storage.Close()
}
func withGC(conn *conn, now func() time.Time) storage.Storage {
ctx, cancel := context.WithCancel(context.Background())
run := (gc{now, conn}).run
go func() {
for {
select {
case <-time.After(time.Second * 30):
if err := run(); err != nil {
log.Printf("gc failed: %v", err)
}
case <-ctx.Done():
}
}
}()
return withCancel{conn, cancel}
}

View File

@ -1,53 +0,0 @@
package sql
import (
"testing"
"time"
"github.com/coreos/dex/storage"
)
func TestGC(t *testing.T) {
// TODO(ericchiang): Add a GarbageCollect method to the storage interface so
// we can write conformance tests instead of directly testing each implementation.
s := &SQLite3{":memory:"}
conn, err := s.open()
if err != nil {
t.Fatal(err)
}
defer conn.Close()
clock := time.Now()
now := func() time.Time { return clock }
runGC := (gc{now, conn}).run
a := storage.AuthRequest{
ID: storage.NewID(),
Expiry: now().Add(time.Second),
}
if err := conn.CreateAuthRequest(a); err != nil {
t.Fatal(err)
}
if err := runGC(); err != nil {
t.Errorf("gc failed: %v", err)
}
if _, err := conn.GetAuthRequest(a.ID); err != nil {
t.Errorf("failed to get auth request after gc: %v", err)
}
clock = clock.Add(time.Minute)
if err := runGC(); err != nil {
t.Errorf("gc failed: %v", err)
}
if _, err := conn.GetAuthRequest(a.ID); err == nil {
t.Errorf("expected error after gc'ing auth request: %v", err)
} else if err != storage.ErrNotFound {
t.Errorf("expected error storage.NotFound got: %v", err)
}
}