diff --git a/server/server.go b/server/server.go index 778054fa..1bec7b69 100644 --- a/server/server.go +++ b/server/server.go @@ -6,6 +6,7 @@ import ( "net/http" "net/url" "path" + "sync/atomic" "time" "github.com/gorilla/mux" @@ -93,9 +94,14 @@ func newServer(c Config, rotationStrategy rotationStrategy) (*Server, error) { } s := &Server{ - issuerURL: *issuerURL, - connectors: make(map[string]Connector), - storage: storageWithKeyRotation(c.Storage, rotationStrategy, now), + issuerURL: *issuerURL, + connectors: make(map[string]Connector), + storage: newKeyCacher( + storageWithKeyRotation( + c.Storage, rotationStrategy, now, + ), + now, + ), idTokensValidFor: value(c.IDTokensValidFor, 24*time.Hour), now: now, } @@ -139,3 +145,35 @@ func (s *Server) absURL(pathItems ...string) string { u.Path = s.absPath(pathItems...) return u.String() } + +// newKeyCacher returns a storage which caches keys so long as the next +func newKeyCacher(s storage.Storage, now func() time.Time) storage.Storage { + if now == nil { + now = time.Now + } + return &keyCacher{Storage: s, now: now} +} + +type keyCacher struct { + storage.Storage + + now func() time.Time + keys atomic.Value // Always holds nil or type *storage.Keys. +} + +func (k *keyCacher) GetKeys() (storage.Keys, error) { + keys, ok := k.keys.Load().(*storage.Keys) + if ok && keys != nil && k.now().Before(keys.NextRotation) { + return *keys, nil + } + + storageKeys, err := k.Storage.GetKeys() + if err != nil { + return storageKeys, err + } + + if k.now().Before(storageKeys.NextRotation) { + k.keys.Store(&storageKeys) + } + return storageKeys, nil +} diff --git a/server/server_test.go b/server/server_test.go index 6cc1d80f..a542421b 100644 --- a/server/server_test.go +++ b/server/server_test.go @@ -219,3 +219,74 @@ func TestOAuth2Flow(t *testing.T) { t.Fatal(err) } } + +type storageWithKeysTrigger struct { + storage.Storage + f func() +} + +func (s storageWithKeysTrigger) GetKeys() (storage.Keys, error) { + s.f() + return s.Storage.GetKeys() +} + +func TestKeyCacher(t *testing.T) { + tNow := time.Now() + now := func() time.Time { return tNow } + + s := memory.New() + + tests := []struct { + before func() + wantCallToStorage bool + }{ + { + before: func() {}, + wantCallToStorage: true, + }, + { + before: func() { + s.UpdateKeys(func(old storage.Keys) (storage.Keys, error) { + old.NextRotation = tNow.Add(time.Minute) + return old, nil + }) + }, + wantCallToStorage: true, + }, + { + before: func() {}, + wantCallToStorage: false, + }, + { + before: func() { + tNow = tNow.Add(time.Hour) + }, + wantCallToStorage: true, + }, + { + before: func() { + tNow = tNow.Add(time.Hour) + s.UpdateKeys(func(old storage.Keys) (storage.Keys, error) { + old.NextRotation = tNow.Add(time.Minute) + return old, nil + }) + }, + wantCallToStorage: true, + }, + { + before: func() {}, + wantCallToStorage: false, + }, + } + + gotCall := false + s = newKeyCacher(storageWithKeysTrigger{s, func() { gotCall = true }}, now) + for i, tc := range tests { + gotCall = false + tc.before() + s.GetKeys() + if gotCall != tc.wantCallToStorage { + t.Errorf("case %d: expected call to storage=%t got call to storage=%t", i, tc.wantCallToStorage, gotCall) + } + } +}