package oidc

import (
	"encoding/json"
	"errors"
	"fmt"
	"io"
	"io/ioutil"
	"net/http"
	"sync"
	"sync/atomic"
	"time"

	"github.com/pquerna/cachecontrol"
	"golang.org/x/net/context"
	jose "gopkg.in/square/go-jose.v1"
)

// minCache is the minimum duration for which fetched keys are cached, no
// matter what expiry the key response itself advertises. This is so our
// request code can be asynchronous from matching keys: if the request code
// retrieved keys that expired immediately, the goroutine matching a JWT to a
// key would always see expired keys.
//
// TODO(ericchiang): Review this logic.
var minCache = 2 * time.Minute

// cachedKeys is a snapshot of a fetched key set, indexed by key ID, together
// with the time after which the snapshot is no longer served from the cache.
// A snapshot is never modified once built; a refresh stores a whole new value.
type cachedKeys struct {
	keys   map[string]jose.JsonWebKey // immutable
	expiry time.Time
}

// remoteKeySet verifies JWTs against a set of keys fetched over HTTP from
// keysURL. Fetched keys are cached, and at most one refresh request is in
// flight at any time; callers that need a refresh wait on cond for it.
type remoteKeySet struct {
	// HTTP client used to fetch the keys.
	client *http.Client

	// "jwks_uri" from discovery.
	keysURL string

	// The value is always of type *cachedKeys.
	//
	// To ensure consistency always call keyCache.Store when holding cond.L.
	keyCache atomic.Value

	// cond.L guards all following fields. sync.Cond is used in place of a mutex
	// so multiple goroutines can wait on a single request to update keys.
	cond sync.Cond
	// Is there an existing request to get the remote keys?
	inflight bool
	// If the last attempt to refresh keys failed, the error will be saved here.
	//
	// TODO(ericchiang): If a routine sets this before calling cond.Broadcast(),
	// there's no guarantee that a routine calling cond.Wait() will actually see
	// the error set by the previous routine. The broadcasting routine releases
	// cond.L only after Broadcast() and Wait() must reacquire the lock, so
	// other routines waiting on the lock might acquire it first. Maybe just
	// log the error?
	lastErr error
}

// newRemoteKeySet returns a key set that will fetch its keys from jwksURL
// using the HTTP client that contextClient derives from ctx. No request is
// issued here; the key cache starts out empty.
func newRemoteKeySet(ctx context.Context, jwksURL string) *remoteKeySet {
	ks := new(remoteKeySet)
	ks.client = contextClient(ctx)
	ks.keysURL = jwksURL
	// The condition variable needs its own lock before anyone can Wait on it.
	ks.cond.L = new(sync.Mutex)
	return ks
}

// verifyJWT parses jwt, looks up a key matching any of the key IDs named in
// its signature headers, and verifies the token with that key.
//
// On success it returns the payload produced by the signature check. It
// returns an error when the token cannot be parsed, when no matching key can
// be obtained, or when verification itself fails.
func (r *remoteKeySet) verifyJWT(jwt string) (payload []byte, err error) {
	parsed, err := jose.ParseSigned(jwt)
	if err != nil {
		return nil, fmt.Errorf("parsing jwt: %v", err)
	}

	// Collect the key ID of every signature; any one of them may match.
	ids := make([]string, 0, len(parsed.Signatures))
	for _, sig := range parsed.Signatures {
		ids = append(ids, sig.Header.KeyID)
	}

	verificationKey, err := r.getKey(ids)
	if err != nil {
		return nil, fmt.Errorf("oidc: %s", err)
	}
	return parsed.Verify(verificationKey)
}

// getKeyFromCache looks for a key matching one of keyIDs in the currently
// cached key set, trying the IDs in order and returning the first hit.
//
// It reports false when nothing has been cached yet, when the cached set has
// passed its expiry, or when none of the IDs is present. It takes no lock;
// it only performs an atomic load of the cache.
func (r *remoteKeySet) getKeyFromCache(keyIDs []string) (*jose.JsonWebKey, bool) {
	snapshot, loaded := r.keyCache.Load().(*cachedKeys)
	if !loaded || time.Now().After(snapshot.expiry) {
		return nil, false
	}

	for i := range keyIDs {
		match, found := snapshot.keys[keyIDs[i]]
		if !found {
			continue
		}
		// match is a copy of the map value, so the caller's pointer does not
		// alias the cached map.
		return &match, true
	}
	return nil, false
}

// getKey returns a key matching one of keyIDs, refreshing the key set from
// keysURL when the cache cannot supply one.
//
// The cache is consulted first without locking. On a miss the caller takes
// cond.L, re-checks the cache, starts a refresh goroutine unless one is
// already in flight, and then blocks until a refresh broadcasts completion.
// After waking it checks the cache once more; if the key is still missing it
// returns lastErr when set, otherwise a generic "no signing keys" error.
//
// Note that there is no timeout on the wait in this function: the caller
// blocks for as long as the refresh request takes.
func (r *remoteKeySet) getKey(keyIDs []string) (*jose.JsonWebKey, error) {
	// Fast path. Just do an atomic load.
	if key, ok := r.getKeyFromCache(keyIDs); ok {
		return key, nil
	}

	// Didn't find keys, use the slow path.
	r.cond.L.Lock()
	defer r.cond.L.Unlock()

	// Check again within the mutex: a refresh may have stored new keys between
	// the fast-path miss and acquiring the lock.
	if key, ok := r.getKeyFromCache(keyIDs); ok {
		return key, nil
	}

	// Keys have expired or we're trying to verify a JWT we don't have a key for.

	if !r.inflight {
		// There isn't currently an inflight request to update keys, start a
		// goroutine to do so.
		r.inflight = true
		go func() {
			// The HTTP request runs without holding cond.L; the lock is only
			// taken afterwards to publish the result.
			newKeys, newExpiry, err := requestKeys(r.client, r.keysURL)

			r.cond.L.Lock()
			defer r.cond.L.Unlock()

			r.inflight = false
			if err != nil {
				// Leave the previously cached keys in place on failure.
				r.lastErr = err
			} else {
				r.keyCache.Store(&cachedKeys{newKeys, newExpiry})
				r.lastErr = nil
			}

			r.cond.Broadcast() // Wake all r.cond.Wait() calls.
		}()
	}

	// Wait for r.cond.Broadcast() to be called. This unlocks r.cond.L and
	// reacquires it after it's done waiting.
	r.cond.Wait()

	if key, ok := r.getKeyFromCache(keyIDs); ok {
		return key, nil
	}
	// lastErr reflects the most recently completed refresh, which is not
	// necessarily the one this caller waited on.
	if r.lastErr != nil {
		return nil, r.lastErr
	}
	return nil, errors.New("no signing keys can validate the signature")
}

// requestKeys performs a GET on keysURL with client and decodes the response
// body as a JSON Web Key Set.
//
// It returns the keys indexed by key ID and the time until which they may be
// cached. That time is the expiry derived from the response's caching
// headers when it lies beyond now+minCache, and now+minCache otherwise. An
// error is returned if the request cannot be built or sent, the body cannot
// be read, the status is not 200 OK, or the body is not valid JSON. At most
// 1 MiB of the body is read.
func requestKeys(client *http.Client, keysURL string) (map[string]jose.JsonWebKey, time.Time, error) {
	var noExpiry time.Time

	req, err := http.NewRequest("GET", keysURL, nil)
	if err != nil {
		return nil, noExpiry, fmt.Errorf("can't create request: %v", err)
	}

	resp, err := client.Do(req)
	if err != nil {
		return nil, noExpiry, fmt.Errorf("can't GET new keys %v", err)
	}
	defer resp.Body.Close()

	// Read the body before checking the status so that it can be included in
	// the error for non-200 responses.
	payload, err := ioutil.ReadAll(io.LimitReader(resp.Body, 1<<20))
	if err != nil {
		return nil, noExpiry, fmt.Errorf("can't fetch new keys: %v", err)
	}
	if resp.StatusCode != http.StatusOK {
		return nil, noExpiry, fmt.Errorf("can't fetch new keys: %s %s", resp.Status, payload)
	}

	var jwks jose.JsonWebKeySet
	if err := json.Unmarshal(payload, &jwks); err != nil {
		return nil, noExpiry, fmt.Errorf("can't decode keys: %v %s", err, payload)
	}

	byID := make(map[string]jose.JsonWebKey, len(jwks.Keys))
	for i := range jwks.Keys {
		byID[jwks.Keys[i].KeyID] = jwks.Keys[i]
	}

	// Start from the enforced minimum and extend it only when the response's
	// caching headers yield a later expiry.
	expiry := time.Now().Add(minCache)
	_, advertised, err := cachecontrol.CachableResponse(req, resp, cachecontrol.Options{})
	if err == nil && expiry.Before(advertised) {
		expiry = advertised
	}
	return byID, expiry, nil
}