package server

import (
	"encoding/json"
	"fmt"
	"log"
	"net/http"
	"net/url"
	"path"
	"strconv"
	"strings"
	"time"

	"github.com/gorilla/mux"
	jose "gopkg.in/square/go-jose.v2"

	"github.com/coreos/poke/connector"
	"github.com/coreos/poke/storage"
)

func (s *Server) handlePublicKeys(w http.ResponseWriter, r *http.Request) {
	// TODO(ericchiang): Cache this.
	keys, err := s.storage.GetKeys()
	if err != nil {
		log.Printf("failed to get keys: %v", err)
		http.Error(w, "Internal server error", http.StatusInternalServerError)
		return
	}

	if keys.SigningKeyPub == nil {
		log.Printf("No public keys found.")
		http.Error(w, "Internal server error", http.StatusInternalServerError)
		return
	}

	jwks := jose.JSONWebKeySet{
		Keys: make([]jose.JSONWebKey, len(keys.VerificationKeys)+1),
	}
	jwks.Keys[0] = *keys.SigningKeyPub
	for i, verificationKey := range keys.VerificationKeys {
		jwks.Keys[i+1] = *verificationKey.PublicKey
	}

	data, err := json.MarshalIndent(jwks, "", "  ")
	if err != nil {
		log.Printf("failed to marshal discovery data: %v", err)
		http.Error(w, "Internal server error", http.StatusInternalServerError)
		return
	}
	maxAge := keys.NextRotation.Sub(s.now())
	if maxAge < (time.Minute * 2) {
		maxAge = time.Minute * 2
	}

	w.Header().Set("Cache-Control", fmt.Sprintf("max-age=%d, must-revalidate", maxAge))
	w.Header().Set("Content-Type", "application/json")
	w.Header().Set("Content-Length", strconv.Itoa(len(data)))
	w.Write(data)
}

type discovery struct {
	Issuer        string   `json:"issuer"`
	Auth          string   `json:"authorization_endpoint"`
	Token         string   `json:"token_endpoint"`
	Keys          string   `json:"jwks_uri"`
	ResponseTypes []string `json:"response_types_supported"`
	Subjects      []string `json:"subject_types_supported"`
	IDTokenAlgs   []string `json:"id_token_signing_alg_values_supported"`
	Scopes        []string `json:"scopes_supported"`
	AuthMethods   []string `json:"token_endpoint_auth_methods_supported"`
	Claims        []string `json:"claims_supported"`
}

func (s *Server) handleDiscovery(w http.ResponseWriter, r *http.Request) {
	// TODO(ericchiang): Cache this
	d := discovery{
		Issuer:        s.issuerURL.String(),
		Auth:          s.absURL("/auth"),
		Token:         s.absURL("/token"),
		Keys:          s.absURL("/keys"),
		ResponseTypes: []string{"code"},
		Subjects:      []string{"public"},
		IDTokenAlgs:   []string{string(jose.RS256)},
		Scopes:        []string{"openid", "email", "profile"},
		AuthMethods:   []string{"client_secret_basic"},
		Claims: []string{
			"aud", "email", "email_verified", "exp", "family_name", "given_name",
			"iat", "iss", "locale", "name", "sub",
		},
	}
	data, err := json.MarshalIndent(d, "", "  ")
	if err != nil {
		log.Printf("failed to marshal discovery data: %v", err)
		http.Error(w, "Internal server error", http.StatusInternalServerError)
		return
	}
	w.Header().Set("Content-Type", "application/json")
	w.Header().Set("Content-Length", strconv.Itoa(len(data)))
	w.Write(data)
}

// handleAuthorization handles the OAuth2 auth endpoint.
func (s *Server) handleAuthorization(w http.ResponseWriter, r *http.Request) {
	authReq, err := parseAuthorizationRequest(s.storage, r)
	if err != nil {
		s.renderError(w, http.StatusInternalServerError, err.Type, err.Description)
		return
	}
	if err := s.storage.CreateAuthRequest(authReq); err != nil {
		log.Printf("Failed to create authorization request: %v", err)
		s.renderError(w, http.StatusInternalServerError, errServerError, "")
		return
	}
	state := authReq.ID

	if len(s.connectors) == 1 {
		for id := range s.connectors {
			http.Redirect(w, r, s.absPath("/auth", id)+"?state="+state, http.StatusFound)
			return
		}
	}

	connectorInfos := make([]connectorInfo, len(s.connectors))
	i := 0
	for id := range s.connectors {
		connectorInfos[i] = connectorInfo{
			DisplayName: id,
			URL:         s.absPath("/auth", id) + "?state=" + state,
		}
		i++
	}

	renderLoginOptions(w, connectorInfos, state)
}

func (s *Server) handleConnectorLogin(w http.ResponseWriter, r *http.Request) {
	connID := mux.Vars(r)["connector"]
	conn, ok := s.connectors[connID]
	if !ok {
		s.notFound(w, r)
		return
	}

	// TODO(ericchiang): cache user identity.

	state := r.FormValue("state")
	switch r.Method {
	case "GET":
		switch conn := conn.Connector.(type) {
		case connector.CallbackConnector:
			callbackURL, err := conn.LoginURL(s.absURL("/callback", connID), state)
			if err != nil {
				log.Printf("Connector %q returned error when creating callback: %v", connID, err)
				s.renderError(w, http.StatusInternalServerError, errServerError, "")
				return
			}
			http.Redirect(w, r, callbackURL, http.StatusFound)
		case connector.PasswordConnector:
			renderPasswordTmpl(w, state, r.URL.String(), "")
		default:
			s.notFound(w, r)
		}
	case "POST":
		passwordConnector, ok := conn.Connector.(connector.PasswordConnector)
		if !ok {
			s.notFound(w, r)
			return
		}

		username := r.FormValue("username")
		password := r.FormValue("password")

		identity, ok, err := passwordConnector.Login(username, password)
		if err != nil {
			log.Printf("Failed to login user: %v", err)
			s.renderError(w, http.StatusInternalServerError, errServerError, "")
			return
		}
		if !ok {
			renderPasswordTmpl(w, state, r.URL.String(), "Invalid credentials")
			return
		}

		groups, ok, err := s.groups(identity, state, conn.Connector)
		if err != nil {
			s.renderError(w, http.StatusInternalServerError, errServerError, "")
			return
		}
		if ok {
			identity.Groups = groups
		}

		s.redirectToApproval(w, r, identity, connID, state)
	default:
		s.notFound(w, r)
	}
}

func (s *Server) handleConnectorCallback(w http.ResponseWriter, r *http.Request) {
	connID := mux.Vars(r)["connector"]
	conn, ok := s.connectors[connID]
	if !ok {
		s.notFound(w, r)
		return
	}
	callbackConnector, ok := conn.Connector.(connector.CallbackConnector)
	if !ok {
		s.notFound(w, r)
		return
	}

	identity, state, err := callbackConnector.HandleCallback(r)
	if err != nil {
		log.Printf("Failed to authenticate: %v", err)
		s.renderError(w, http.StatusInternalServerError, errServerError, "")
		return
	}
	groups, ok, err := s.groups(identity, state, conn.Connector)
	if err != nil {
		s.renderError(w, http.StatusInternalServerError, errServerError, "")
		return
	}
	if ok {
		identity.Groups = groups
	}
	s.redirectToApproval(w, r, identity, connID, state)
}

func (s *Server) redirectToApproval(w http.ResponseWriter, r *http.Request, identity storage.Identity, connectorID, state string) {
	updater := func(a storage.AuthRequest) (storage.AuthRequest, error) {
		a.Identity = &identity
		a.ConnectorID = connectorID
		return a, nil
	}
	if err := s.storage.UpdateAuthRequest(state, updater); err != nil {
		log.Printf("Failed to updated auth request with identity: %v", err)
		s.renderError(w, http.StatusInternalServerError, errServerError, "")
		return
	}
	http.Redirect(w, r, path.Join(s.issuerURL.Path, "/approval")+"?state="+state, http.StatusSeeOther)
}

func (s *Server) groups(identity storage.Identity, authReqID string, conn connector.Connector) ([]string, bool, error) {
	groupsConn, ok := conn.(connector.GroupsConnector)
	if !ok {
		return nil, false, nil
	}
	authReq, err := s.storage.GetAuthRequest(authReqID)
	if err != nil {
		log.Printf("get auth request: %v", err)
		return nil, false, err
	}
	reqGroups := func() bool {
		for _, scope := range authReq.Scopes {
			if scope == scopeGroups {
				return true
			}
		}
		return false
	}()
	if !reqGroups {
		return nil, false, nil
	}
	groups, err := groupsConn.Groups(identity)
	return groups, true, err
}

func (s *Server) handleApproval(w http.ResponseWriter, r *http.Request) {
	authReq, err := s.storage.GetAuthRequest(r.FormValue("state"))
	if err != nil {
		log.Printf("Failed to get auth request: %v", err)
		s.renderError(w, http.StatusInternalServerError, errServerError, "")
		return
	}
	if authReq.Identity == nil {
		log.Printf("Auth request does not have an identity for approval")
		s.renderError(w, http.StatusInternalServerError, errServerError, "")
		return
	}

	switch r.Method {
	case "GET":
		if s.skipApproval {
			s.sendCodeResponse(w, r, authReq, *authReq.Identity)
			return
		}
		client, err := s.storage.GetClient(authReq.ClientID)
		if err != nil {
			log.Printf("Failed to get client %q: %v", authReq.ClientID, err)
			s.renderError(w, http.StatusInternalServerError, errServerError, "")
			return
		}
		renderApprovalTmpl(w, authReq.ID, *authReq.Identity, client, authReq.Scopes)
	case "POST":
		if r.FormValue("approval") != "approve" {
			s.renderError(w, http.StatusInternalServerError, "approval rejected", "")
			return
		}
		s.sendCodeResponse(w, r, authReq, *authReq.Identity)
	}
}

func (s *Server) sendCodeResponse(w http.ResponseWriter, r *http.Request, authReq storage.AuthRequest, identity storage.Identity) {
	if authReq.Expiry.After(s.now()) {
		s.renderError(w, http.StatusBadRequest, errInvalidRequest, "Authorization request period has expired.")
		return
	}

	if err := s.storage.DeleteAuthRequest(authReq.ID); err != nil {
		if err != storage.ErrNotFound {
			log.Printf("Failed to delete authorization request: %v", err)
			s.renderError(w, http.StatusInternalServerError, errServerError, "")
		} else {
			s.renderError(w, http.StatusBadRequest, errInvalidRequest, "Authorization request has already been completed.")
		}
		return
	}
	code := storage.AuthCode{
		ID:          storage.NewNonce(),
		ClientID:    authReq.ClientID,
		ConnectorID: authReq.ConnectorID,
		Nonce:       authReq.Nonce,
		Scopes:      authReq.Scopes,
		Identity:    *authReq.Identity,
		Expiry:      s.now().Add(time.Minute * 5),
		RedirectURI: authReq.RedirectURI,
	}
	if err := s.storage.CreateAuthCode(code); err != nil {
		log.Printf("Failed to create auth code: %v", err)
		s.renderError(w, http.StatusInternalServerError, errServerError, "")
		return
	}

	if authReq.RedirectURI == "urn:ietf:wg:oauth:2.0:oob" {
		// TODO(ericchiang): Add a proper template.
		fmt.Fprintf(w, "Code: %s", code.ID)
		return
	}

	u, err := url.Parse(authReq.RedirectURI)
	if err != nil {
		s.renderError(w, http.StatusInternalServerError, errServerError, "Invalid redirect URI.")
		return
	}
	q := u.Query()
	q.Set("code", code.ID)
	q.Set("state", authReq.State)
	u.RawQuery = q.Encode()
	http.Redirect(w, r, u.String(), http.StatusSeeOther)
}

func (s *Server) handleToken(w http.ResponseWriter, r *http.Request) {
	clientID, clientSecret, ok := r.BasicAuth()
	if ok {
		var err error
		if clientID, err = url.QueryUnescape(clientID); err != nil {
			tokenErr(w, errInvalidRequest, "client_id improperly encoded", http.StatusBadRequest)
			return
		}
		if clientSecret, err = url.QueryUnescape(clientSecret); err != nil {
			tokenErr(w, errInvalidRequest, "client_secret improperly encoded", http.StatusBadRequest)
			return
		}
	} else {
		clientID = r.PostFormValue("client_id")
		clientSecret = r.PostFormValue("client_secret")
	}

	client, err := s.storage.GetClient(clientID)
	if err != nil {
		if err != storage.ErrNotFound {
			log.Printf("failed to get client: %v", err)
			tokenErr(w, errServerError, "", http.StatusInternalServerError)
		} else {
			tokenErr(w, errInvalidClient, "Invalid client credentials.", http.StatusUnauthorized)
		}
		return
	}
	if client.Secret != clientSecret {
		tokenErr(w, errInvalidClient, "Invalid client credentials.", http.StatusUnauthorized)
		return
	}

	grantType := r.PostFormValue("grant_type")
	switch grantType {
	case "authorization_code":
		s.handleAuthCode(w, r, client)
	case "refresh_token":
		s.handleRefreshToken(w, r, client)
	default:
		tokenErr(w, errInvalidGrant, "", http.StatusBadRequest)
	}
}

// handle an access token request https://tools.ietf.org/html/rfc6749#section-4.1.3
func (s *Server) handleAuthCode(w http.ResponseWriter, r *http.Request, client storage.Client) {
	code := r.PostFormValue("code")
	redirectURI := r.PostFormValue("redirect_uri")

	authCode, err := s.storage.GetAuthCode(code)
	if err != nil || s.now().After(authCode.Expiry) || authCode.ClientID != client.ID {
		if err != storage.ErrNotFound {
			log.Printf("failed to get auth code: %v", err)
			tokenErr(w, errServerError, "", http.StatusInternalServerError)
		} else {
			tokenErr(w, errInvalidRequest, "Invalid or expired code parameter.", http.StatusBadRequest)
		}
		return
	}

	if authCode.RedirectURI != redirectURI {
		tokenErr(w, errInvalidRequest, "redirect_uri did not match URI from initial request.", http.StatusBadRequest)
		return
	}

	idToken, expiry, err := s.newIDToken(client.ID, authCode.Identity, authCode.Scopes, authCode.Nonce)
	if err != nil {
		log.Printf("failed to create ID token: %v", err)
		tokenErr(w, errServerError, "", http.StatusInternalServerError)
		return
	}

	if err := s.storage.DeleteAuthCode(code); err != nil {
		log.Printf("failed to delete auth code: %v", err)
		tokenErr(w, errServerError, "", http.StatusInternalServerError)
		return
	}

	reqRefresh := func() bool {
		for _, scope := range authCode.Scopes {
			if scope == scopeOfflineAccess {
				return true
			}
		}
		return false
	}()
	var refreshToken string
	if reqRefresh {
		refresh := storage.Refresh{
			RefreshToken: storage.NewNonce(),
			ClientID:     authCode.ClientID,
			ConnectorID:  authCode.ConnectorID,
			Scopes:       authCode.Scopes,
			Identity:     authCode.Identity,
			Nonce:        authCode.Nonce,
		}
		if err := s.storage.CreateRefresh(refresh); err != nil {
			log.Printf("failed to create refresh token: %v", err)
			tokenErr(w, errServerError, "", http.StatusInternalServerError)
			return
		}
		refreshToken = refresh.RefreshToken
	}
	s.writeAccessToken(w, idToken, refreshToken, expiry)
}

// handle a refresh token request https://tools.ietf.org/html/rfc6749#section-6
func (s *Server) handleRefreshToken(w http.ResponseWriter, r *http.Request, client storage.Client) {
	code := r.PostFormValue("refresh_token")
	scope := r.PostFormValue("scope")
	if code == "" {
		tokenErr(w, errInvalidRequest, "No refresh token in request.", http.StatusBadRequest)
		return
	}

	refresh, err := s.storage.GetRefresh(code)
	if err != nil || refresh.ClientID != client.ID {
		if err != storage.ErrNotFound {
			log.Printf("failed to get auth code: %v", err)
			tokenErr(w, errServerError, "", http.StatusInternalServerError)
		} else {
			tokenErr(w, errInvalidRequest, "Refresh token is invalid or has already been claimed by another client.", http.StatusBadRequest)
		}
		return
	}

	scopes := refresh.Scopes
	if scope != "" {
		requestedScopes := strings.Split(scope, " ")
		contains := func() bool {
		Loop:
			for _, s := range requestedScopes {
				for _, scope := range refresh.Scopes {
					if s == scope {
						continue Loop
					}
				}
				return false
			}
			return true
		}()
		if !contains {
			tokenErr(w, errInvalidRequest, "Requested scopes did not contain authorized scopes.", http.StatusBadRequest)
			return
		}
		scopes = requestedScopes
	}

	// TODO(ericchiang): re-auth with backends

	idToken, expiry, err := s.newIDToken(client.ID, refresh.Identity, scopes, refresh.Nonce)
	if err != nil {
		log.Printf("failed to create ID token: %v", err)
		tokenErr(w, errServerError, "", http.StatusInternalServerError)
		return
	}

	if err := s.storage.DeleteRefresh(code); err != nil {
		log.Printf("failed to delete auth code: %v", err)
		tokenErr(w, errServerError, "", http.StatusInternalServerError)
		return
	}
	refresh.RefreshToken = storage.NewNonce()
	if err := s.storage.CreateRefresh(refresh); err != nil {
		log.Printf("failed to create refresh token: %v", err)
		tokenErr(w, errServerError, "", http.StatusInternalServerError)
		return
	}
	s.writeAccessToken(w, idToken, refresh.RefreshToken, expiry)
}

func (s *Server) writeAccessToken(w http.ResponseWriter, idToken, refreshToken string, expiry time.Time) {
	// TODO(ericchiang): figure out an access token story and support the user info
	// endpoint. For now use a random value so no one depends on the access_token
	// holding a specific structure.
	resp := struct {
		AccessToken  string `json:"access_token"`
		TokenType    string `json:"token_type"`
		ExpiresIn    int    `json:"expires_in"`
		RefreshToken string `json:"refresh_token,omitempty"`
		IDToken      string `json:"id_token"`
	}{
		storage.NewNonce(),
		"bearer",
		int(expiry.Sub(s.now())),
		refreshToken,
		idToken,
	}
	data, err := json.Marshal(resp)
	if err != nil {
		log.Printf("failed to marshal access token response: %v", err)
		tokenErr(w, errServerError, "", http.StatusInternalServerError)
		return
	}
	w.Header().Set("Content-Type", "application/json")
	w.Header().Set("Content-Length", strconv.Itoa(len(data)))
	w.Write(data)
}

func (s *Server) renderError(w http.ResponseWriter, status int, err, description string) {
	http.Error(w, fmt.Sprintf("%s: %s", err, description), status)
}

func (s *Server) notFound(w http.ResponseWriter, r *http.Request) {
	http.NotFound(w, r)
}