{cmd,server}: move garbage collection logic to server
This commit is contained in:
		| @@ -9,6 +9,7 @@ import ( | |||||||
| 	"net/http" | 	"net/http" | ||||||
|  |  | ||||||
| 	"github.com/spf13/cobra" | 	"github.com/spf13/cobra" | ||||||
|  | 	"golang.org/x/net/context" | ||||||
| 	"google.golang.org/grpc" | 	"google.golang.org/grpc" | ||||||
| 	"google.golang.org/grpc/credentials" | 	"google.golang.org/grpc/credentials" | ||||||
| 	yaml "gopkg.in/yaml.v2" | 	yaml "gopkg.in/yaml.v2" | ||||||
| @@ -124,7 +125,7 @@ func serve(cmd *cobra.Command, args []string) error { | |||||||
| 		EnablePasswordDB:       c.EnablePasswordDB, | 		EnablePasswordDB:       c.EnablePasswordDB, | ||||||
| 	} | 	} | ||||||
|  |  | ||||||
| 	serv, err := server.NewServer(serverConfig) | 	serv, err := server.NewServer(context.Background(), serverConfig) | ||||||
| 	if err != nil { | 	if err != nil { | ||||||
| 		return fmt.Errorf("initializing server: %v", err) | 		return fmt.Errorf("initializing server: %v", err) | ||||||
| 	} | 	} | ||||||
|   | |||||||
| @@ -4,10 +4,15 @@ import ( | |||||||
| 	"net/http" | 	"net/http" | ||||||
| 	"net/http/httptest" | 	"net/http/httptest" | ||||||
| 	"testing" | 	"testing" | ||||||
|  |  | ||||||
|  | 	"golang.org/x/net/context" | ||||||
| ) | ) | ||||||
|  |  | ||||||
| func TestHandleHealth(t *testing.T) { | func TestHandleHealth(t *testing.T) { | ||||||
| 	httpServer, server := newTestServer(t, nil) | 	ctx, cancel := context.WithCancel(context.Background()) | ||||||
|  | 	defer cancel() | ||||||
|  |  | ||||||
|  | 	httpServer, server := newTestServer(t, ctx, nil) | ||||||
| 	defer httpServer.Close() | 	defer httpServer.Close() | ||||||
|  |  | ||||||
| 	rr := httptest.NewRecorder() | 	rr := httptest.NewRecorder() | ||||||
|   | |||||||
| @@ -56,40 +56,34 @@ type keyRotater struct { | |||||||
| 	storage.Storage | 	storage.Storage | ||||||
|  |  | ||||||
| 	strategy rotationStrategy | 	strategy rotationStrategy | ||||||
| 	cancel   context.CancelFunc |  | ||||||
|  |  | ||||||
| 	now      func() time.Time | 	now      func() time.Time | ||||||
| } | } | ||||||
|  |  | ||||||
| func storageWithKeyRotation(s storage.Storage, strategy rotationStrategy, now func() time.Time) storage.Storage { | // startKeyRotation begins key rotation in a new goroutine, closing once the context is canceled. | ||||||
| 	if now == nil { | // | ||||||
| 		now = time.Now | // The method blocks until after the first attempt to rotate keys has completed. That way | ||||||
| 	} | // healthy storages will return from this call with valid keys. | ||||||
| 	ctx, cancel := context.WithCancel(context.Background()) | func startKeyRotation(ctx context.Context, s storage.Storage, strategy rotationStrategy, now func() time.Time) { | ||||||
| 	rotater := keyRotater{s, strategy, cancel, now} | 	rotater := keyRotater{s, strategy, now} | ||||||
|  |  | ||||||
| 	// Try to rotate immediately so properly configured storages will return a | 	// Try to rotate immediately so properly configured storages will have keys. | ||||||
| 	// storage with keys. |  | ||||||
| 	if err := rotater.rotate(); err != nil { | 	if err := rotater.rotate(); err != nil { | ||||||
| 		log.Printf("failed to rotate keys: %v", err) | 		log.Printf("failed to rotate keys: %v", err) | ||||||
| 	} | 	} | ||||||
|  |  | ||||||
| 	go func() { | 	go func() { | ||||||
|  | 		for { | ||||||
| 			select { | 			select { | ||||||
| 			case <-ctx.Done(): | 			case <-ctx.Done(): | ||||||
| 				return | 				return | ||||||
| 		case <-time.After(time.Second * 30): | 			case <-time.After(strategy.period): | ||||||
| 				if err := rotater.rotate(); err != nil { | 				if err := rotater.rotate(); err != nil { | ||||||
| 					log.Printf("failed to rotate keys: %v", err) | 					log.Printf("failed to rotate keys: %v", err) | ||||||
| 				} | 				} | ||||||
| 			} | 			} | ||||||
|  | 		} | ||||||
| 	}() | 	}() | ||||||
| 	return rotater | 	return | ||||||
| } |  | ||||||
|  |  | ||||||
| func (k keyRotater) Close() error { |  | ||||||
| 	k.cancel() |  | ||||||
| 	return k.Storage.Close() |  | ||||||
| } | } | ||||||
|  |  | ||||||
| func (k keyRotater) rotate() error { | func (k keyRotater) rotate() error { | ||||||
|   | |||||||
| @@ -11,6 +11,7 @@ import ( | |||||||
| 	"time" | 	"time" | ||||||
|  |  | ||||||
| 	"golang.org/x/crypto/bcrypt" | 	"golang.org/x/crypto/bcrypt" | ||||||
|  | 	"golang.org/x/net/context" | ||||||
|  |  | ||||||
| 	"github.com/gorilla/mux" | 	"github.com/gorilla/mux" | ||||||
|  |  | ||||||
| @@ -48,6 +49,8 @@ type Config struct { | |||||||
| 	RotateKeysAfter  time.Duration // Defaults to 6 hours. | 	RotateKeysAfter  time.Duration // Defaults to 6 hours. | ||||||
| 	IDTokensValidFor time.Duration // Defaults to 24 hours | 	IDTokensValidFor time.Duration // Defaults to 24 hours | ||||||
|  |  | ||||||
|  | 	GCFrequency time.Duration // Defaults to 5 minutes | ||||||
|  |  | ||||||
| 	// If specified, the server will use this function for determining time. | 	// If specified, the server will use this function for determining time. | ||||||
| 	Now func() time.Time | 	Now func() time.Time | ||||||
|  |  | ||||||
| @@ -87,14 +90,14 @@ type Server struct { | |||||||
| } | } | ||||||
|  |  | ||||||
| // NewServer constructs a server from the provided config. | // NewServer constructs a server from the provided config. | ||||||
| func NewServer(c Config) (*Server, error) { | func NewServer(ctx context.Context, c Config) (*Server, error) { | ||||||
| 	return newServer(c, defaultRotationStrategy( | 	return newServer(ctx, c, defaultRotationStrategy( | ||||||
| 		value(c.RotateKeysAfter, 6*time.Hour), | 		value(c.RotateKeysAfter, 6*time.Hour), | ||||||
| 		value(c.IDTokensValidFor, 24*time.Hour), | 		value(c.IDTokensValidFor, 24*time.Hour), | ||||||
| 	)) | 	)) | ||||||
| } | } | ||||||
|  |  | ||||||
| func newServer(c Config, rotationStrategy rotationStrategy) (*Server, error) { | func newServer(ctx context.Context, c Config, rotationStrategy rotationStrategy) (*Server, error) { | ||||||
| 	issuerURL, err := url.Parse(c.Issuer) | 	issuerURL, err := url.Parse(c.Issuer) | ||||||
| 	if err != nil { | 	if err != nil { | ||||||
| 		return nil, fmt.Errorf("server: can't parse issuer URL") | 		return nil, fmt.Errorf("server: can't parse issuer URL") | ||||||
| @@ -140,12 +143,7 @@ func newServer(c Config, rotationStrategy rotationStrategy) (*Server, error) { | |||||||
| 	s := &Server{ | 	s := &Server{ | ||||||
| 		issuerURL:              *issuerURL, | 		issuerURL:              *issuerURL, | ||||||
| 		connectors:             make(map[string]Connector), | 		connectors:             make(map[string]Connector), | ||||||
| 		storage: newKeyCacher( | 		storage:                newKeyCacher(c.Storage, now), | ||||||
| 			storageWithKeyRotation( |  | ||||||
| 				c.Storage, rotationStrategy, now, |  | ||||||
| 			), |  | ||||||
| 			now, |  | ||||||
| 		), |  | ||||||
| 		supportedResponseTypes: supported, | 		supportedResponseTypes: supported, | ||||||
| 		idTokensValidFor:       value(c.IDTokensValidFor, 24*time.Hour), | 		idTokensValidFor:       value(c.IDTokensValidFor, 24*time.Hour), | ||||||
| 		skipApproval:           c.SkipApprovalScreen, | 		skipApproval:           c.SkipApprovalScreen, | ||||||
| @@ -179,6 +177,9 @@ func newServer(c Config, rotationStrategy rotationStrategy) (*Server, error) { | |||||||
| 	handleFunc("/healthz", s.handleHealth) | 	handleFunc("/healthz", s.handleHealth) | ||||||
| 	s.mux = r | 	s.mux = r | ||||||
|  |  | ||||||
|  | 	startKeyRotation(ctx, c.Storage, rotationStrategy, now) | ||||||
|  | 	startGarbageCollection(ctx, c.Storage, value(c.GCFrequency, 5*time.Minute), now) | ||||||
|  |  | ||||||
| 	return s, nil | 	return s, nil | ||||||
| } | } | ||||||
|  |  | ||||||
| @@ -262,3 +263,21 @@ func (k *keyCacher) GetKeys() (storage.Keys, error) { | |||||||
| 	} | 	} | ||||||
| 	return storageKeys, nil | 	return storageKeys, nil | ||||||
| } | } | ||||||
|  |  | ||||||
|  | func startGarbageCollection(ctx context.Context, s storage.Storage, frequency time.Duration, now func() time.Time) { | ||||||
|  | 	go func() { | ||||||
|  | 		for { | ||||||
|  | 			select { | ||||||
|  | 			case <-ctx.Done(): | ||||||
|  | 				return | ||||||
|  | 			case <-time.After(frequency): | ||||||
|  | 				if r, err := s.GarbageCollect(now()); err != nil { | ||||||
|  | 					log.Printf("garbage collection failed: %v", err) | ||||||
|  | 				} else { | ||||||
|  | 					log.Printf("garbage collection run, delete auth requests=%d, auth codes=%d", r.AuthRequests, r.AuthCodes) | ||||||
|  | 				} | ||||||
|  | 			} | ||||||
|  | 		} | ||||||
|  | 	}() | ||||||
|  | 	return | ||||||
|  | } | ||||||
|   | |||||||
| @@ -69,7 +69,7 @@ FDWV28nTP9sqbtsmU8Tem2jzMvZ7C/Q0AuDoKELFUpux8shm8wfIhyaPnXUGZoAZ | |||||||
| Np4vUwMSYV5mopESLWOg3loBxKyLGFtgGKVCjGiQvy6zISQ4fQo= | Np4vUwMSYV5mopESLWOg3loBxKyLGFtgGKVCjGiQvy6zISQ4fQo= | ||||||
| -----END RSA PRIVATE KEY-----`) | -----END RSA PRIVATE KEY-----`) | ||||||
|  |  | ||||||
| func newTestServer(t *testing.T, updateConfig func(c *Config)) (*httptest.Server, *Server) { | func newTestServer(t *testing.T, ctx context.Context, updateConfig func(c *Config)) (*httptest.Server, *Server) { | ||||||
| 	var server *Server | 	var server *Server | ||||||
| 	s := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { | 	s := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { | ||||||
| 		server.ServeHTTP(w, r) | 		server.ServeHTTP(w, r) | ||||||
| @@ -91,7 +91,7 @@ func newTestServer(t *testing.T, updateConfig func(c *Config)) (*httptest.Server | |||||||
| 	s.URL = config.Issuer | 	s.URL = config.Issuer | ||||||
|  |  | ||||||
| 	var err error | 	var err error | ||||||
| 	if server, err = newServer(config, staticRotationStrategy(testKey)); err != nil { | 	if server, err = newServer(ctx, config, staticRotationStrategy(testKey)); err != nil { | ||||||
| 		t.Fatal(err) | 		t.Fatal(err) | ||||||
| 	} | 	} | ||||||
| 	server.skipApproval = true // Don't prompt for approval, just immediately redirect with code. | 	server.skipApproval = true // Don't prompt for approval, just immediately redirect with code. | ||||||
| @@ -99,14 +99,16 @@ func newTestServer(t *testing.T, updateConfig func(c *Config)) (*httptest.Server | |||||||
| } | } | ||||||
|  |  | ||||||
| func TestNewTestServer(t *testing.T) { | func TestNewTestServer(t *testing.T) { | ||||||
| 	newTestServer(t, nil) | 	ctx, cancel := context.WithCancel(context.Background()) | ||||||
|  | 	defer cancel() | ||||||
|  | 	newTestServer(t, ctx, nil) | ||||||
| } | } | ||||||
|  |  | ||||||
| func TestDiscovery(t *testing.T) { | func TestDiscovery(t *testing.T) { | ||||||
| 	ctx, cancel := context.WithCancel(context.Background()) | 	ctx, cancel := context.WithCancel(context.Background()) | ||||||
| 	defer cancel() | 	defer cancel() | ||||||
|  |  | ||||||
| 	httpServer, _ := newTestServer(t, func(c *Config) { | 	httpServer, _ := newTestServer(t, ctx, func(c *Config) { | ||||||
| 		c.Issuer = c.Issuer + "/non-root-path" | 		c.Issuer = c.Issuer + "/non-root-path" | ||||||
| 	}) | 	}) | ||||||
| 	defer httpServer.Close() | 	defer httpServer.Close() | ||||||
| @@ -227,7 +229,7 @@ func TestOAuth2CodeFlow(t *testing.T) { | |||||||
| 			ctx, cancel := context.WithCancel(context.Background()) | 			ctx, cancel := context.WithCancel(context.Background()) | ||||||
| 			defer cancel() | 			defer cancel() | ||||||
|  |  | ||||||
| 			httpServer, s := newTestServer(t, func(c *Config) { | 			httpServer, s := newTestServer(t, ctx, func(c *Config) { | ||||||
| 				c.Issuer = c.Issuer + "/non-root-path" | 				c.Issuer = c.Issuer + "/non-root-path" | ||||||
| 			}) | 			}) | ||||||
| 			defer httpServer.Close() | 			defer httpServer.Close() | ||||||
| @@ -340,7 +342,7 @@ func TestOAuth2ImplicitFlow(t *testing.T) { | |||||||
| 	ctx, cancel := context.WithCancel(context.Background()) | 	ctx, cancel := context.WithCancel(context.Background()) | ||||||
| 	defer cancel() | 	defer cancel() | ||||||
|  |  | ||||||
| 	httpServer, s := newTestServer(t, func(c *Config) { | 	httpServer, s := newTestServer(t, ctx, func(c *Config) { | ||||||
| 		// Enable support for the implicit flow. | 		// Enable support for the implicit flow. | ||||||
| 		c.SupportedResponseTypes = []string{"code", "token"} | 		c.SupportedResponseTypes = []string{"code", "token"} | ||||||
| 	}) | 	}) | ||||||
| @@ -470,7 +472,7 @@ func TestCrossClientScopes(t *testing.T) { | |||||||
| 	ctx, cancel := context.WithCancel(context.Background()) | 	ctx, cancel := context.WithCancel(context.Background()) | ||||||
| 	defer cancel() | 	defer cancel() | ||||||
|  |  | ||||||
| 	httpServer, s := newTestServer(t, func(c *Config) { | 	httpServer, s := newTestServer(t, ctx, func(c *Config) { | ||||||
| 		c.Issuer = c.Issuer + "/non-root-path" | 		c.Issuer = c.Issuer + "/non-root-path" | ||||||
| 	}) | 	}) | ||||||
| 	defer httpServer.Close() | 	defer httpServer.Close() | ||||||
|   | |||||||
		Reference in New Issue
	
	Block a user