initial commit
This commit is contained in:
188
vendor/github.com/ericchiang/oidc/jwks.go
generated
vendored
Normal file
188
vendor/github.com/ericchiang/oidc/jwks.go
generated
vendored
Normal file
@@ -0,0 +1,188 @@
|
||||
package oidc
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"io/ioutil"
|
||||
"net/http"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
|
||||
"github.com/pquerna/cachecontrol"
|
||||
"golang.org/x/net/context"
|
||||
jose "gopkg.in/square/go-jose.v1"
|
||||
)
|
||||
|
||||
// No matter what insist on caching keys. This is so our request code can be
|
||||
// asynchronous from matching keys. If the request code retrieved keys that
|
||||
// expired immediately, the goroutine to match a JWT to a key would always see
|
||||
// expired keys.
|
||||
//
|
||||
// TODO(ericchiang): Review this logic.
|
||||
var minCache = 2 * time.Minute
|
||||
|
||||
type cachedKeys struct {
|
||||
keys map[string]jose.JsonWebKey // immutable
|
||||
expiry time.Time
|
||||
}
|
||||
|
||||
type remoteKeySet struct {
|
||||
client *http.Client
|
||||
|
||||
// "jwks_uri" from discovery.
|
||||
keysURL string
|
||||
|
||||
// The value is always of type *cachedKeys.
|
||||
//
|
||||
// To ensure consistency always call keyCache.Store when holding cond.L.
|
||||
keyCache atomic.Value
|
||||
|
||||
// cond.L guards all following fields. sync.Cond is used in place of a mutex
|
||||
// so multiple processes can wait on a single request to update keys.
|
||||
cond sync.Cond
|
||||
// Is there an existing request to get the remote keys?
|
||||
inflight bool
|
||||
// If the last attempt to refresh keys failed, the error will be saved here.
|
||||
//
|
||||
// TODO(ericchiang): If a routine sets this before calling cond.Broadcast(),
|
||||
// there's no guarentee that a routine calling cond.Wait() will actual see
|
||||
// the error called by the previous routine. Since Broadcast() unlocks
|
||||
// cond.L and Wait() must reacquire the lock, other routines waiting on the
|
||||
// lock might acquire it first. Maybe just log the error?
|
||||
lastErr error
|
||||
}
|
||||
|
||||
func newRemoteKeySet(ctx context.Context, jwksURL string) *remoteKeySet {
|
||||
r := &remoteKeySet{
|
||||
client: contextClient(ctx),
|
||||
keysURL: jwksURL,
|
||||
cond: sync.Cond{L: new(sync.Mutex)},
|
||||
}
|
||||
return r
|
||||
}
|
||||
|
||||
func (r *remoteKeySet) verifyJWT(jwt string) (payload []byte, err error) {
|
||||
jws, err := jose.ParseSigned(jwt)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("parsing jwt: %v", err)
|
||||
}
|
||||
keyIDs := make([]string, len(jws.Signatures))
|
||||
for i, signature := range jws.Signatures {
|
||||
keyIDs[i] = signature.Header.KeyID
|
||||
}
|
||||
key, err := r.getKey(keyIDs)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("oidc: %s", err)
|
||||
}
|
||||
return jws.Verify(key)
|
||||
}
|
||||
|
||||
func (r *remoteKeySet) getKeyFromCache(keyIDs []string) (*jose.JsonWebKey, bool) {
|
||||
cachedKeys, ok := r.keyCache.Load().(*cachedKeys)
|
||||
if !ok {
|
||||
return nil, false
|
||||
}
|
||||
if time.Now().After(cachedKeys.expiry) {
|
||||
return nil, false
|
||||
}
|
||||
for _, keyID := range keyIDs {
|
||||
if key, ok := cachedKeys.keys[keyID]; ok {
|
||||
return &key, true
|
||||
}
|
||||
}
|
||||
return nil, false
|
||||
}
|
||||
|
||||
func (r *remoteKeySet) getKey(keyIDs []string) (*jose.JsonWebKey, error) {
|
||||
// Fast path. Just do an atomic load.
|
||||
if key, ok := r.getKeyFromCache(keyIDs); ok {
|
||||
return key, nil
|
||||
}
|
||||
|
||||
// Didn't find keys, use the slow path.
|
||||
r.cond.L.Lock()
|
||||
defer r.cond.L.Unlock()
|
||||
|
||||
// Check again within the mutex.
|
||||
if key, ok := r.getKeyFromCache(keyIDs); ok {
|
||||
return key, nil
|
||||
}
|
||||
|
||||
// Keys have expired or we're trying to verify a JWT we don't have a key for.
|
||||
|
||||
if !r.inflight {
|
||||
// There isn't currently an inflight request to update keys, start a
|
||||
// goroutine to do so.
|
||||
r.inflight = true
|
||||
go func() {
|
||||
newKeys, newExpiry, err := requestKeys(r.client, r.keysURL)
|
||||
|
||||
r.cond.L.Lock()
|
||||
defer r.cond.L.Unlock()
|
||||
|
||||
r.inflight = false
|
||||
if err != nil {
|
||||
r.lastErr = err
|
||||
} else {
|
||||
r.keyCache.Store(&cachedKeys{newKeys, newExpiry})
|
||||
r.lastErr = nil
|
||||
}
|
||||
|
||||
r.cond.Broadcast() // Wake all r.cond.Wait() calls.
|
||||
}()
|
||||
}
|
||||
|
||||
// Wait for r.cond.Broadcast() to be called. This unlocks r.cond.L and
|
||||
// reacquires it after its done waiting.
|
||||
r.cond.Wait()
|
||||
|
||||
if key, ok := r.getKeyFromCache(keyIDs); ok {
|
||||
return key, nil
|
||||
}
|
||||
if r.lastErr != nil {
|
||||
return nil, r.lastErr
|
||||
}
|
||||
return nil, errors.New("no signing keys can validate the signature")
|
||||
}
|
||||
|
||||
func requestKeys(client *http.Client, keysURL string) (map[string]jose.JsonWebKey, time.Time, error) {
|
||||
req, err := http.NewRequest("GET", keysURL, nil)
|
||||
if err != nil {
|
||||
return nil, time.Time{}, fmt.Errorf("can't create request: %v", err)
|
||||
}
|
||||
resp, err := client.Do(req)
|
||||
if err != nil {
|
||||
return nil, time.Time{}, fmt.Errorf("can't GET new keys %v", err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
body, err := ioutil.ReadAll(io.LimitReader(resp.Body, 1<<20))
|
||||
if err != nil {
|
||||
return nil, time.Time{}, fmt.Errorf("can't fetch new keys: %v", err)
|
||||
}
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
return nil, time.Time{}, fmt.Errorf("can't fetch new keys: %s %s", resp.Status, body)
|
||||
}
|
||||
|
||||
var keySet jose.JsonWebKeySet
|
||||
if err := json.Unmarshal(body, &keySet); err != nil {
|
||||
return nil, time.Time{}, fmt.Errorf("can't decode keys: %v %s", err, body)
|
||||
}
|
||||
|
||||
keys := make(map[string]jose.JsonWebKey, len(keySet.Keys))
|
||||
for _, key := range keySet.Keys {
|
||||
keys[key.KeyID] = key
|
||||
}
|
||||
|
||||
minExpiry := time.Now().Add(minCache)
|
||||
|
||||
if _, expiry, err := cachecontrol.CachableResponse(req, resp, cachecontrol.Options{}); err == nil {
|
||||
if minExpiry.Before(expiry) {
|
||||
return keys, expiry, nil
|
||||
}
|
||||
}
|
||||
return keys, minExpiry, nil
|
||||
}
|
||||
Reference in New Issue
Block a user