189 lines
4.9 KiB
Go
189 lines
4.9 KiB
Go
|
package oidc
|
||
|
|
||
|
import (
|
||
|
"encoding/json"
|
||
|
"errors"
|
||
|
"fmt"
|
||
|
"io"
|
||
|
"io/ioutil"
|
||
|
"net/http"
|
||
|
"sync"
|
||
|
"sync/atomic"
|
||
|
"time"
|
||
|
|
||
|
"github.com/pquerna/cachecontrol"
|
||
|
"golang.org/x/net/context"
|
||
|
jose "gopkg.in/square/go-jose.v1"
|
||
|
)
|
||
|
|
||
|
// minCache is the minimum amount of time a freshly fetched key set stays
// cached. requestKeys uses it as a floor on the expiry, regardless of what
// the server's HTTP caching headers request.
//
// Keys are always cached so that the code requesting keys can run
// asynchronously from the code matching keys. If the request code retrieved
// keys that expired immediately, the goroutine matching a JWT to a key would
// always see expired keys.
//
// TODO(ericchiang): Review this logic.
var (
	minCache = 2 * time.Minute
)
|
||
|
|
||
|
// cachedKeys is a snapshot of a fetched key set together with the time until
// which it may be used. Values of this type are what remoteKeySet.keyCache
// holds.
type cachedKeys struct {
	keys map[string]jose.JsonWebKey // immutable; indexed by each key's KeyID ("kid")
	// After this instant getKeyFromCache treats the snapshot as stale.
	expiry time.Time
}
|
||
|
|
||
|
// remoteKeySet verifies JWT signatures against a JSON Web Key Set fetched
// from a remote URL. Fetched keys are cached in keyCache and are refreshed by
// a background goroutine when a lookup misses or the cached set has expired
// (see getKey).
type remoteKeySet struct {
	// HTTP client used to fetch the key set.
	client *http.Client

	// "jwks_uri" from discovery.
	keysURL string

	// The value is always of type *cachedKeys. It is read without holding any
	// lock (keyCache.Load), which is what makes the cache fast path lock-free.
	//
	// To ensure consistency always call keyCache.Store when holding cond.L.
	keyCache atomic.Value

	// cond.L guards all following fields. sync.Cond is used in place of a mutex
	// so multiple goroutines can wait on a single request to update keys.
	cond sync.Cond
	// Is there an existing request to get the remote keys?
	inflight bool
	// If the last attempt to refresh keys failed, the error will be saved here.
	//
	// TODO(ericchiang): If a routine sets this before calling cond.Broadcast(),
	// there's no guarantee that a routine calling cond.Wait() will actually see
	// the error set by that routine. Broadcast() is called while cond.L is held,
	// and every woken waiter must reacquire cond.L before Wait() returns, so
	// other routines contending for the lock (possibly starting and finishing
	// another refresh) might acquire it first. Maybe just log the error?
	lastErr error
}
|
||
|
|
||
|
func newRemoteKeySet(ctx context.Context, jwksURL string) *remoteKeySet {
|
||
|
r := &remoteKeySet{
|
||
|
client: contextClient(ctx),
|
||
|
keysURL: jwksURL,
|
||
|
cond: sync.Cond{L: new(sync.Mutex)},
|
||
|
}
|
||
|
return r
|
||
|
}
|
||
|
|
||
|
func (r *remoteKeySet) verifyJWT(jwt string) (payload []byte, err error) {
|
||
|
jws, err := jose.ParseSigned(jwt)
|
||
|
if err != nil {
|
||
|
return nil, fmt.Errorf("parsing jwt: %v", err)
|
||
|
}
|
||
|
keyIDs := make([]string, len(jws.Signatures))
|
||
|
for i, signature := range jws.Signatures {
|
||
|
keyIDs[i] = signature.Header.KeyID
|
||
|
}
|
||
|
key, err := r.getKey(keyIDs)
|
||
|
if err != nil {
|
||
|
return nil, fmt.Errorf("oidc: %s", err)
|
||
|
}
|
||
|
return jws.Verify(key)
|
||
|
}
|
||
|
|
||
|
func (r *remoteKeySet) getKeyFromCache(keyIDs []string) (*jose.JsonWebKey, bool) {
|
||
|
cachedKeys, ok := r.keyCache.Load().(*cachedKeys)
|
||
|
if !ok {
|
||
|
return nil, false
|
||
|
}
|
||
|
if time.Now().After(cachedKeys.expiry) {
|
||
|
return nil, false
|
||
|
}
|
||
|
for _, keyID := range keyIDs {
|
||
|
if key, ok := cachedKeys.keys[keyID]; ok {
|
||
|
return &key, true
|
||
|
}
|
||
|
}
|
||
|
return nil, false
|
||
|
}
|
||
|
|
||
|
// getKey returns the key whose KeyID matches the first entry of keyIDs found
// in the remote key set.
//
// It first consults the cache with a lock-free atomic load. On a miss
// (including an expired cache) it takes cond.L and re-checks the cache. If no
// refresh is already in flight, it starts a goroutine that fetches the key set
// with requestKeys. The caller then blocks in cond.Wait until a refresh
// broadcasts completion, and checks the cache once more.
//
// If the key is still missing, the error saved by the most recent failed
// refresh (lastErr) is returned when set. Otherwise the returned error reports
// that no signing key matched.
//
// NOTE(review): the refresh uses r.client with no deadline of its own, so a
// hung request blocks every waiter; confirm the client from contextClient
// carries a timeout.
func (r *remoteKeySet) getKey(keyIDs []string) (*jose.JsonWebKey, error) {
	// Fast path. Just do an atomic load.
	if key, ok := r.getKeyFromCache(keyIDs); ok {
		return key, nil
	}

	// Didn't find keys, use the slow path.
	r.cond.L.Lock()
	defer r.cond.L.Unlock()

	// Check again within the mutex: a refresh may have stored new keys while
	// this goroutine was waiting for the lock.
	if key, ok := r.getKeyFromCache(keyIDs); ok {
		return key, nil
	}

	// Keys have expired or we're trying to verify a JWT we don't have a key for.

	if !r.inflight {
		// There isn't currently an inflight request to update keys, start a
		// goroutine to do so.
		r.inflight = true
		go func() {
			// The HTTP request runs without holding cond.L, so other callers
			// can queue up in cond.Wait meanwhile.
			newKeys, newExpiry, err := requestKeys(r.client, r.keysURL)

			r.cond.L.Lock()
			defer r.cond.L.Unlock()

			r.inflight = false
			if err != nil {
				// On failure the previously cached set (if any) is left in place.
				r.lastErr = err
			} else {
				r.keyCache.Store(&cachedKeys{newKeys, newExpiry})
				r.lastErr = nil
			}

			r.cond.Broadcast() // Wake all r.cond.Wait() calls.
		}()
	}

	// Wait for r.cond.Broadcast() to be called. This unlocks r.cond.L and
	// reacquires it after it's done waiting.
	r.cond.Wait()

	if key, ok := r.getKeyFromCache(keyIDs); ok {
		return key, nil
	}
	if r.lastErr != nil {
		return nil, r.lastErr
	}
	return nil, errors.New("no signing keys can validate the signature")
}
|
||
|
|
||
|
func requestKeys(client *http.Client, keysURL string) (map[string]jose.JsonWebKey, time.Time, error) {
|
||
|
req, err := http.NewRequest("GET", keysURL, nil)
|
||
|
if err != nil {
|
||
|
return nil, time.Time{}, fmt.Errorf("can't create request: %v", err)
|
||
|
}
|
||
|
resp, err := client.Do(req)
|
||
|
if err != nil {
|
||
|
return nil, time.Time{}, fmt.Errorf("can't GET new keys %v", err)
|
||
|
}
|
||
|
defer resp.Body.Close()
|
||
|
|
||
|
body, err := ioutil.ReadAll(io.LimitReader(resp.Body, 1<<20))
|
||
|
if err != nil {
|
||
|
return nil, time.Time{}, fmt.Errorf("can't fetch new keys: %v", err)
|
||
|
}
|
||
|
if resp.StatusCode != http.StatusOK {
|
||
|
return nil, time.Time{}, fmt.Errorf("can't fetch new keys: %s %s", resp.Status, body)
|
||
|
}
|
||
|
|
||
|
var keySet jose.JsonWebKeySet
|
||
|
if err := json.Unmarshal(body, &keySet); err != nil {
|
||
|
return nil, time.Time{}, fmt.Errorf("can't decode keys: %v %s", err, body)
|
||
|
}
|
||
|
|
||
|
keys := make(map[string]jose.JsonWebKey, len(keySet.Keys))
|
||
|
for _, key := range keySet.Keys {
|
||
|
keys[key.KeyID] = key
|
||
|
}
|
||
|
|
||
|
minExpiry := time.Now().Add(minCache)
|
||
|
|
||
|
if _, expiry, err := cachecontrol.CachableResponse(req, resp, cachecontrol.Options{}); err == nil {
|
||
|
if minExpiry.Before(expiry) {
|
||
|
return keys, expiry, nil
|
||
|
}
|
||
|
}
|
||
|
return keys, minExpiry, nil
|
||
|
}
|