{cmd,server}: move garbage collection logic to server
This commit is contained in:
		| @@ -9,6 +9,7 @@ import ( | ||||
| 	"net/http" | ||||
|  | ||||
| 	"github.com/spf13/cobra" | ||||
| 	"golang.org/x/net/context" | ||||
| 	"google.golang.org/grpc" | ||||
| 	"google.golang.org/grpc/credentials" | ||||
| 	yaml "gopkg.in/yaml.v2" | ||||
| @@ -124,7 +125,7 @@ func serve(cmd *cobra.Command, args []string) error { | ||||
| 		EnablePasswordDB:       c.EnablePasswordDB, | ||||
| 	} | ||||
|  | ||||
| 	serv, err := server.NewServer(serverConfig) | ||||
| 	serv, err := server.NewServer(context.Background(), serverConfig) | ||||
| 	if err != nil { | ||||
| 		return fmt.Errorf("initializing server: %v", err) | ||||
| 	} | ||||
|   | ||||
| @@ -4,10 +4,15 @@ import ( | ||||
| 	"net/http" | ||||
| 	"net/http/httptest" | ||||
| 	"testing" | ||||
|  | ||||
| 	"golang.org/x/net/context" | ||||
| ) | ||||
|  | ||||
| func TestHandleHealth(t *testing.T) { | ||||
| 	httpServer, server := newTestServer(t, nil) | ||||
| 	ctx, cancel := context.WithCancel(context.Background()) | ||||
| 	defer cancel() | ||||
|  | ||||
| 	httpServer, server := newTestServer(t, ctx, nil) | ||||
| 	defer httpServer.Close() | ||||
|  | ||||
| 	rr := httptest.NewRecorder() | ||||
|   | ||||
| @@ -56,40 +56,34 @@ type keyRotater struct { | ||||
| 	storage.Storage | ||||
|  | ||||
| 	strategy rotationStrategy | ||||
| 	cancel   context.CancelFunc | ||||
|  | ||||
| 	now      func() time.Time | ||||
| } | ||||
|  | ||||
| func storageWithKeyRotation(s storage.Storage, strategy rotationStrategy, now func() time.Time) storage.Storage { | ||||
| 	if now == nil { | ||||
| 		now = time.Now | ||||
| 	} | ||||
| 	ctx, cancel := context.WithCancel(context.Background()) | ||||
| 	rotater := keyRotater{s, strategy, cancel, now} | ||||
| // startKeyRotation begins key rotation in a new goroutine, closing once the context is canceled. | ||||
| // | ||||
| // The method blocks until after the first attempt to rotate keys has completed. That way | ||||
| // healthy storages will return from this call with valid keys. | ||||
| func startKeyRotation(ctx context.Context, s storage.Storage, strategy rotationStrategy, now func() time.Time) { | ||||
| 	rotater := keyRotater{s, strategy, now} | ||||
|  | ||||
| 	// Try to rotate immediately so properly configured storages will return a | ||||
| 	// storage with keys. | ||||
| 	// Try to rotate immediately so properly configured storages will have keys. | ||||
| 	if err := rotater.rotate(); err != nil { | ||||
| 		log.Printf("failed to rotate keys: %v", err) | ||||
| 	} | ||||
|  | ||||
| 	go func() { | ||||
| 		for { | ||||
| 			select { | ||||
| 			case <-ctx.Done(): | ||||
| 				return | ||||
| 		case <-time.After(time.Second * 30): | ||||
| 			case <-time.After(strategy.period): | ||||
| 				if err := rotater.rotate(); err != nil { | ||||
| 					log.Printf("failed to rotate keys: %v", err) | ||||
| 				} | ||||
| 			} | ||||
| 	}() | ||||
| 	return rotater | ||||
| 		} | ||||
|  | ||||
| func (k keyRotater) Close() error { | ||||
| 	k.cancel() | ||||
| 	return k.Storage.Close() | ||||
| 	}() | ||||
| 	return | ||||
| } | ||||
|  | ||||
| func (k keyRotater) rotate() error { | ||||
|   | ||||
| @@ -11,6 +11,7 @@ import ( | ||||
| 	"time" | ||||
|  | ||||
| 	"golang.org/x/crypto/bcrypt" | ||||
| 	"golang.org/x/net/context" | ||||
|  | ||||
| 	"github.com/gorilla/mux" | ||||
|  | ||||
| @@ -48,6 +49,8 @@ type Config struct { | ||||
| 	RotateKeysAfter  time.Duration // Defaults to 6 hours. | ||||
| 	IDTokensValidFor time.Duration // Defaults to 24 hours | ||||
|  | ||||
| 	GCFrequency time.Duration // Defaults to 5 minutes | ||||
|  | ||||
| 	// If specified, the server will use this function for determining time. | ||||
| 	Now func() time.Time | ||||
|  | ||||
| @@ -87,14 +90,14 @@ type Server struct { | ||||
| } | ||||
|  | ||||
| // NewServer constructs a server from the provided config. | ||||
| func NewServer(c Config) (*Server, error) { | ||||
| 	return newServer(c, defaultRotationStrategy( | ||||
| func NewServer(ctx context.Context, c Config) (*Server, error) { | ||||
| 	return newServer(ctx, c, defaultRotationStrategy( | ||||
| 		value(c.RotateKeysAfter, 6*time.Hour), | ||||
| 		value(c.IDTokensValidFor, 24*time.Hour), | ||||
| 	)) | ||||
| } | ||||
|  | ||||
| func newServer(c Config, rotationStrategy rotationStrategy) (*Server, error) { | ||||
| func newServer(ctx context.Context, c Config, rotationStrategy rotationStrategy) (*Server, error) { | ||||
| 	issuerURL, err := url.Parse(c.Issuer) | ||||
| 	if err != nil { | ||||
| 		return nil, fmt.Errorf("server: can't parse issuer URL") | ||||
| @@ -140,12 +143,7 @@ func newServer(c Config, rotationStrategy rotationStrategy) (*Server, error) { | ||||
| 	s := &Server{ | ||||
| 		issuerURL:              *issuerURL, | ||||
| 		connectors:             make(map[string]Connector), | ||||
| 		storage: newKeyCacher( | ||||
| 			storageWithKeyRotation( | ||||
| 				c.Storage, rotationStrategy, now, | ||||
| 			), | ||||
| 			now, | ||||
| 		), | ||||
| 		storage:                newKeyCacher(c.Storage, now), | ||||
| 		supportedResponseTypes: supported, | ||||
| 		idTokensValidFor:       value(c.IDTokensValidFor, 24*time.Hour), | ||||
| 		skipApproval:           c.SkipApprovalScreen, | ||||
| @@ -179,6 +177,9 @@ func newServer(c Config, rotationStrategy rotationStrategy) (*Server, error) { | ||||
| 	handleFunc("/healthz", s.handleHealth) | ||||
| 	s.mux = r | ||||
|  | ||||
| 	startKeyRotation(ctx, c.Storage, rotationStrategy, now) | ||||
| 	startGarbageCollection(ctx, c.Storage, value(c.GCFrequency, 5*time.Minute), now) | ||||
|  | ||||
| 	return s, nil | ||||
| } | ||||
|  | ||||
| @@ -262,3 +263,21 @@ func (k *keyCacher) GetKeys() (storage.Keys, error) { | ||||
| 	} | ||||
| 	return storageKeys, nil | ||||
| } | ||||
|  | ||||
| func startGarbageCollection(ctx context.Context, s storage.Storage, frequency time.Duration, now func() time.Time) { | ||||
| 	go func() { | ||||
| 		for { | ||||
| 			select { | ||||
| 			case <-ctx.Done(): | ||||
| 				return | ||||
| 			case <-time.After(frequency): | ||||
| 				if r, err := s.GarbageCollect(now()); err != nil { | ||||
| 					log.Printf("garbage collection failed: %v", err) | ||||
| 				} else { | ||||
| 					log.Printf("garbage collection run, delete auth requests=%d, auth codes=%d", r.AuthRequests, r.AuthCodes) | ||||
| 				} | ||||
| 			} | ||||
| 		} | ||||
| 	}() | ||||
| 	return | ||||
| } | ||||
|   | ||||
| @@ -69,7 +69,7 @@ FDWV28nTP9sqbtsmU8Tem2jzMvZ7C/Q0AuDoKELFUpux8shm8wfIhyaPnXUGZoAZ | ||||
| Np4vUwMSYV5mopESLWOg3loBxKyLGFtgGKVCjGiQvy6zISQ4fQo= | ||||
| -----END RSA PRIVATE KEY-----`) | ||||
|  | ||||
| func newTestServer(t *testing.T, updateConfig func(c *Config)) (*httptest.Server, *Server) { | ||||
| func newTestServer(t *testing.T, ctx context.Context, updateConfig func(c *Config)) (*httptest.Server, *Server) { | ||||
| 	var server *Server | ||||
| 	s := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { | ||||
| 		server.ServeHTTP(w, r) | ||||
| @@ -91,7 +91,7 @@ func newTestServer(t *testing.T, updateConfig func(c *Config)) (*httptest.Server | ||||
| 	s.URL = config.Issuer | ||||
|  | ||||
| 	var err error | ||||
| 	if server, err = newServer(config, staticRotationStrategy(testKey)); err != nil { | ||||
| 	if server, err = newServer(ctx, config, staticRotationStrategy(testKey)); err != nil { | ||||
| 		t.Fatal(err) | ||||
| 	} | ||||
| 	server.skipApproval = true // Don't prompt for approval, just immediately redirect with code. | ||||
| @@ -99,14 +99,16 @@ func newTestServer(t *testing.T, updateConfig func(c *Config)) (*httptest.Server | ||||
| } | ||||
|  | ||||
| func TestNewTestServer(t *testing.T) { | ||||
| 	newTestServer(t, nil) | ||||
| 	ctx, cancel := context.WithCancel(context.Background()) | ||||
| 	defer cancel() | ||||
| 	newTestServer(t, ctx, nil) | ||||
| } | ||||
|  | ||||
| func TestDiscovery(t *testing.T) { | ||||
| 	ctx, cancel := context.WithCancel(context.Background()) | ||||
| 	defer cancel() | ||||
|  | ||||
| 	httpServer, _ := newTestServer(t, func(c *Config) { | ||||
| 	httpServer, _ := newTestServer(t, ctx, func(c *Config) { | ||||
| 		c.Issuer = c.Issuer + "/non-root-path" | ||||
| 	}) | ||||
| 	defer httpServer.Close() | ||||
| @@ -227,7 +229,7 @@ func TestOAuth2CodeFlow(t *testing.T) { | ||||
| 			ctx, cancel := context.WithCancel(context.Background()) | ||||
| 			defer cancel() | ||||
|  | ||||
| 			httpServer, s := newTestServer(t, func(c *Config) { | ||||
| 			httpServer, s := newTestServer(t, ctx, func(c *Config) { | ||||
| 				c.Issuer = c.Issuer + "/non-root-path" | ||||
| 			}) | ||||
| 			defer httpServer.Close() | ||||
| @@ -340,7 +342,7 @@ func TestOAuth2ImplicitFlow(t *testing.T) { | ||||
| 	ctx, cancel := context.WithCancel(context.Background()) | ||||
| 	defer cancel() | ||||
|  | ||||
| 	httpServer, s := newTestServer(t, func(c *Config) { | ||||
| 	httpServer, s := newTestServer(t, ctx, func(c *Config) { | ||||
| 		// Enable support for the implicit flow. | ||||
| 		c.SupportedResponseTypes = []string{"code", "token"} | ||||
| 	}) | ||||
| @@ -470,7 +472,7 @@ func TestCrossClientScopes(t *testing.T) { | ||||
| 	ctx, cancel := context.WithCancel(context.Background()) | ||||
| 	defer cancel() | ||||
|  | ||||
| 	httpServer, s := newTestServer(t, func(c *Config) { | ||||
| 	httpServer, s := newTestServer(t, ctx, func(c *Config) { | ||||
| 		c.Issuer = c.Issuer + "/non-root-path" | ||||
| 	}) | ||||
| 	defer httpServer.Close() | ||||
|   | ||||
		Reference in New Issue
	
	Block a user