package oidc import ( "context" "encoding/json" "fmt" "io/ioutil" "net/http" "sync" "time" "github.com/pquerna/cachecontrol" jose "gopkg.in/square/go-jose.v2" ) // keysExpiryDelta is the allowed clock skew between a client and the OpenID Connect // server. // // When keys expire, they are valid for this amount of time after. // // If the keys have not expired, and an ID Token claims it was signed by a key not in // the cache, if and only if the keys expire in this amount of time, the keys will be // updated. const keysExpiryDelta = 30 * time.Second func newRemoteKeySet(ctx context.Context, jwksURL string, now func() time.Time) *remoteKeySet { if now == nil { now = time.Now } return &remoteKeySet{jwksURL: jwksURL, ctx: ctx, now: now} } type remoteKeySet struct { jwksURL string ctx context.Context now func() time.Time // guard all other fields mu sync.Mutex // inflightCtx suppresses parallel execution of updateKeys and allows // multiple goroutines to wait for its result. // Its Err() method returns any errors encountered during updateKeys. // // If nil, there is no inflight updateKeys request. inflightCtx *inflight // A set of cached keys and their expiry. cachedKeys []jose.JSONWebKey expiry time.Time } // inflight is used to wait on some in-flight request from multiple goroutines type inflight struct { done chan struct{} err error } // Done returns a channel that is closed when the inflight request finishes. func (i *inflight) Done() <-chan struct{} { return i.done } // Err returns any error encountered during request execution. May be nil. func (i *inflight) Err() error { return i.err } // Cancel signals completion of the inflight request with error err. // Must be called only once for particular inflight instance. func (i *inflight) Cancel(err error) { i.err = err close(i.done) } func (r *remoteKeySet) keysWithIDFromCache(keyIDs []string) ([]jose.JSONWebKey, bool) { r.mu.Lock() keys, expiry := r.cachedKeys, r.expiry r.mu.Unlock() // Have the keys expired? if expiry.Add(keysExpiryDelta).Before(r.now()) { return nil, false } var signingKeys []jose.JSONWebKey for _, key := range keys { if contains(keyIDs, key.KeyID) { signingKeys = append(signingKeys, key) } } if len(signingKeys) == 0 { // Are the keys about to expire? if r.now().Add(keysExpiryDelta).After(expiry) { return nil, false } } return signingKeys, true } func (r *remoteKeySet) keysWithID(ctx context.Context, keyIDs []string) ([]jose.JSONWebKey, error) { keys, ok := r.keysWithIDFromCache(keyIDs) if ok { return keys, nil } var inflightCtx *inflight func() { r.mu.Lock() defer r.mu.Unlock() // If there's not a current inflight request, create one. if r.inflightCtx == nil { inflightCtx := &inflight{make(chan struct{}), nil} r.inflightCtx = inflightCtx go func() { // TODO(ericchiang): Upstream Kubernetes request that we recover every time // we spawn a goroutine, because panics in a goroutine will bring down the // entire program. There's no way to recover from another goroutine's panic. // // Most users actually want to let the panic propagate and bring down the // program because it implies some unrecoverable state. // // Add a context key to allow the recover behavior. // // See: https://github.com/coreos/go-oidc/issues/89 // Sync keys and close inflightCtx when that's done. // Use the remoteKeySet's context instead of the requests context // because a re-sync is unique to the keys set and will span multiple // requests. inflightCtx.Cancel(r.updateKeys(r.ctx)) r.mu.Lock() defer r.mu.Unlock() r.inflightCtx = nil }() } inflightCtx = r.inflightCtx }() select { case <-ctx.Done(): return nil, ctx.Err() case <-inflightCtx.Done(): if err := inflightCtx.Err(); err != nil { return nil, err } } // Since we've just updated keys, we don't care about the cache miss. keys, _ = r.keysWithIDFromCache(keyIDs) return keys, nil } func (r *remoteKeySet) updateKeys(ctx context.Context) error { req, err := http.NewRequest("GET", r.jwksURL, nil) if err != nil { return fmt.Errorf("oidc: can't create request: %v", err) } resp, err := doRequest(ctx, req) if err != nil { return fmt.Errorf("oidc: get keys failed %v", err) } defer resp.Body.Close() body, err := ioutil.ReadAll(resp.Body) if err != nil { return fmt.Errorf("oidc: read response body: %v", err) } if resp.StatusCode != http.StatusOK { return fmt.Errorf("oidc: get keys failed: %s %s", resp.Status, body) } var keySet jose.JSONWebKeySet if err := json.Unmarshal(body, &keySet); err != nil { return fmt.Errorf("oidc: failed to decode keys: %v %s", err, body) } // If the server doesn't provide cache control headers, assume the // keys expire immediately. expiry := r.now() _, e, err := cachecontrol.CachableResponse(req, resp, cachecontrol.Options{}) if err == nil && e.After(expiry) { expiry = e } r.mu.Lock() defer r.mu.Unlock() r.cachedKeys = keySet.Keys r.expiry = expiry return nil }