server: cache signing keys

This commit is contained in:
Eric Chiang 2016-08-10 20:51:58 -07:00
parent d313e5d493
commit 4cbe9bbc82
2 changed files with 112 additions and 3 deletions

View File

@ -6,6 +6,7 @@ import (
"net/http" "net/http"
"net/url" "net/url"
"path" "path"
"sync/atomic"
"time" "time"
"github.com/gorilla/mux" "github.com/gorilla/mux"
@ -93,9 +94,14 @@ func newServer(c Config, rotationStrategy rotationStrategy) (*Server, error) {
} }
s := &Server{ s := &Server{
issuerURL: *issuerURL, issuerURL: *issuerURL,
connectors: make(map[string]Connector), connectors: make(map[string]Connector),
storage: storageWithKeyRotation(c.Storage, rotationStrategy, now), storage: newKeyCacher(
storageWithKeyRotation(
c.Storage, rotationStrategy, now,
),
now,
),
idTokensValidFor: value(c.IDTokensValidFor, 24*time.Hour), idTokensValidFor: value(c.IDTokensValidFor, 24*time.Hour),
now: now, now: now,
} }
@ -139,3 +145,35 @@ func (s *Server) absURL(pathItems ...string) string {
u.Path = s.absPath(pathItems...) u.Path = s.absPath(pathItems...)
return u.String() return u.String()
} }
// newKeyCacher returns a storage which caches keys so long as the next
func newKeyCacher(s storage.Storage, now func() time.Time) storage.Storage {
if now == nil {
now = time.Now
}
return &keyCacher{Storage: s, now: now}
}
type keyCacher struct {
storage.Storage
now func() time.Time
keys atomic.Value // Always holds nil or type *storage.Keys.
}
func (k *keyCacher) GetKeys() (storage.Keys, error) {
keys, ok := k.keys.Load().(*storage.Keys)
if ok && keys != nil && k.now().Before(keys.NextRotation) {
return *keys, nil
}
storageKeys, err := k.Storage.GetKeys()
if err != nil {
return storageKeys, err
}
if k.now().Before(storageKeys.NextRotation) {
k.keys.Store(&storageKeys)
}
return storageKeys, nil
}

View File

@ -219,3 +219,74 @@ func TestOAuth2Flow(t *testing.T) {
t.Fatal(err) t.Fatal(err)
} }
} }
type storageWithKeysTrigger struct {
storage.Storage
f func()
}
func (s storageWithKeysTrigger) GetKeys() (storage.Keys, error) {
s.f()
return s.Storage.GetKeys()
}
func TestKeyCacher(t *testing.T) {
tNow := time.Now()
now := func() time.Time { return tNow }
s := memory.New()
tests := []struct {
before func()
wantCallToStorage bool
}{
{
before: func() {},
wantCallToStorage: true,
},
{
before: func() {
s.UpdateKeys(func(old storage.Keys) (storage.Keys, error) {
old.NextRotation = tNow.Add(time.Minute)
return old, nil
})
},
wantCallToStorage: true,
},
{
before: func() {},
wantCallToStorage: false,
},
{
before: func() {
tNow = tNow.Add(time.Hour)
},
wantCallToStorage: true,
},
{
before: func() {
tNow = tNow.Add(time.Hour)
s.UpdateKeys(func(old storage.Keys) (storage.Keys, error) {
old.NextRotation = tNow.Add(time.Minute)
return old, nil
})
},
wantCallToStorage: true,
},
{
before: func() {},
wantCallToStorage: false,
},
}
gotCall := false
s = newKeyCacher(storageWithKeysTrigger{s, func() { gotCall = true }}, now)
for i, tc := range tests {
gotCall = false
tc.before()
s.GetKeys()
if gotCall != tc.wantCallToStorage {
t.Errorf("case %d: expected call to storage=%t got call to storage=%t", i, tc.wantCallToStorage, gotCall)
}
}
}