initial commit
This commit is contained in:
2
server/doc.go
Normal file
2
server/doc.go
Normal file
@@ -0,0 +1,2 @@
|
||||
// Package server implements an OpenID Connect server with federated logins.
|
||||
package server
|
556
server/handlers.go
Normal file
556
server/handlers.go
Normal file
@@ -0,0 +1,556 @@
|
||||
package server
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"log"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"path"
|
||||
"strconv"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/gorilla/mux"
|
||||
jose "gopkg.in/square/go-jose.v2"
|
||||
|
||||
"github.com/coreos/poke/connector"
|
||||
"github.com/coreos/poke/storage"
|
||||
)
|
||||
|
||||
func (s *Server) handlePublicKeys(w http.ResponseWriter, r *http.Request) {
|
||||
// TODO(ericchiang): Cache this.
|
||||
keys, err := s.storage.GetKeys()
|
||||
if err != nil {
|
||||
log.Printf("failed to get keys: %v", err)
|
||||
http.Error(w, "Internal server error", http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
|
||||
if keys.SigningKeyPub == nil {
|
||||
log.Printf("No public keys found.")
|
||||
http.Error(w, "Internal server error", http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
|
||||
jwks := jose.JSONWebKeySet{
|
||||
Keys: make([]jose.JSONWebKey, len(keys.VerificationKeys)+1),
|
||||
}
|
||||
jwks.Keys[0] = *keys.SigningKeyPub
|
||||
for i, verificationKey := range keys.VerificationKeys {
|
||||
jwks.Keys[i+1] = *verificationKey.PublicKey
|
||||
}
|
||||
|
||||
data, err := json.MarshalIndent(jwks, "", " ")
|
||||
if err != nil {
|
||||
log.Printf("failed to marshal discovery data: %v", err)
|
||||
http.Error(w, "Internal server error", http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
maxAge := keys.NextRotation.Sub(s.now())
|
||||
if maxAge < (time.Minute * 2) {
|
||||
maxAge = time.Minute * 2
|
||||
}
|
||||
|
||||
w.Header().Set("Cache-Control", fmt.Sprintf("max-age=%d, must-revalidate", maxAge))
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
w.Header().Set("Content-Length", strconv.Itoa(len(data)))
|
||||
w.Write(data)
|
||||
}
|
||||
|
||||
type discovery struct {
|
||||
Issuer string `json:"issuer"`
|
||||
Auth string `json:"authorization_endpoint"`
|
||||
Token string `json:"token_endpoint"`
|
||||
Keys string `json:"jwks_uri"`
|
||||
ResponseTypes []string `json:"response_types_supported"`
|
||||
Subjects []string `json:"subject_types_supported"`
|
||||
IDTokenAlgs []string `json:"id_token_signing_alg_values_supported"`
|
||||
Scopes []string `json:"scopes_supported"`
|
||||
AuthMethods []string `json:"token_endpoint_auth_methods_supported"`
|
||||
Claims []string `json:"claims_supported"`
|
||||
}
|
||||
|
||||
func (s *Server) handleDiscovery(w http.ResponseWriter, r *http.Request) {
|
||||
// TODO(ericchiang): Cache this
|
||||
d := discovery{
|
||||
Issuer: s.issuerURL.String(),
|
||||
Auth: s.absURL("/auth"),
|
||||
Token: s.absURL("/token"),
|
||||
Keys: s.absURL("/keys"),
|
||||
ResponseTypes: []string{"code"},
|
||||
Subjects: []string{"public"},
|
||||
IDTokenAlgs: []string{string(jose.RS256)},
|
||||
Scopes: []string{"openid", "email", "profile"},
|
||||
AuthMethods: []string{"client_secret_basic"},
|
||||
Claims: []string{
|
||||
"aud", "email", "email_verified", "exp", "family_name", "given_name",
|
||||
"iat", "iss", "locale", "name", "sub",
|
||||
},
|
||||
}
|
||||
data, err := json.MarshalIndent(d, "", " ")
|
||||
if err != nil {
|
||||
log.Printf("failed to marshal discovery data: %v", err)
|
||||
http.Error(w, "Internal server error", http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
w.Header().Set("Content-Length", strconv.Itoa(len(data)))
|
||||
w.Write(data)
|
||||
}
|
||||
|
||||
// handleAuthorization handles the OAuth2 auth endpoint.
|
||||
func (s *Server) handleAuthorization(w http.ResponseWriter, r *http.Request) {
|
||||
authReq, err := parseAuthorizationRequest(s.storage, r)
|
||||
if err != nil {
|
||||
s.renderError(w, http.StatusInternalServerError, err.Type, err.Description)
|
||||
return
|
||||
}
|
||||
if err := s.storage.CreateAuthRequest(authReq); err != nil {
|
||||
log.Printf("Failed to create authorization request: %v", err)
|
||||
s.renderError(w, http.StatusInternalServerError, errServerError, "")
|
||||
return
|
||||
}
|
||||
state := authReq.ID
|
||||
|
||||
if len(s.connectors) == 1 {
|
||||
for id := range s.connectors {
|
||||
http.Redirect(w, r, s.absPath("/auth", id)+"?state="+state, http.StatusFound)
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
connectorInfos := make([]connectorInfo, len(s.connectors))
|
||||
i := 0
|
||||
for id := range s.connectors {
|
||||
connectorInfos[i] = connectorInfo{
|
||||
DisplayName: id,
|
||||
URL: s.absPath("/auth", id) + "?state=" + state,
|
||||
}
|
||||
i++
|
||||
}
|
||||
|
||||
renderLoginOptions(w, connectorInfos, state)
|
||||
}
|
||||
|
||||
func (s *Server) handleConnectorLogin(w http.ResponseWriter, r *http.Request) {
|
||||
connID := mux.Vars(r)["connector"]
|
||||
conn, ok := s.connectors[connID]
|
||||
if !ok {
|
||||
s.notFound(w, r)
|
||||
return
|
||||
}
|
||||
|
||||
// TODO(ericchiang): cache user identity.
|
||||
|
||||
state := r.FormValue("state")
|
||||
switch r.Method {
|
||||
case "GET":
|
||||
switch conn := conn.Connector.(type) {
|
||||
case connector.CallbackConnector:
|
||||
callbackURL, err := conn.LoginURL(s.absURL("/callback", connID), state)
|
||||
if err != nil {
|
||||
log.Printf("Connector %q returned error when creating callback: %v", connID, err)
|
||||
s.renderError(w, http.StatusInternalServerError, errServerError, "")
|
||||
return
|
||||
}
|
||||
http.Redirect(w, r, callbackURL, http.StatusFound)
|
||||
case connector.PasswordConnector:
|
||||
renderPasswordTmpl(w, state, r.URL.String(), "")
|
||||
default:
|
||||
s.notFound(w, r)
|
||||
}
|
||||
case "POST":
|
||||
passwordConnector, ok := conn.Connector.(connector.PasswordConnector)
|
||||
if !ok {
|
||||
s.notFound(w, r)
|
||||
return
|
||||
}
|
||||
|
||||
username := r.FormValue("username")
|
||||
password := r.FormValue("password")
|
||||
|
||||
identity, ok, err := passwordConnector.Login(username, password)
|
||||
if err != nil {
|
||||
log.Printf("Failed to login user: %v", err)
|
||||
s.renderError(w, http.StatusInternalServerError, errServerError, "")
|
||||
return
|
||||
}
|
||||
if !ok {
|
||||
renderPasswordTmpl(w, state, r.URL.String(), "Invalid credentials")
|
||||
return
|
||||
}
|
||||
|
||||
groups, ok, err := s.groups(identity, state, conn.Connector)
|
||||
if err != nil {
|
||||
s.renderError(w, http.StatusInternalServerError, errServerError, "")
|
||||
return
|
||||
}
|
||||
if ok {
|
||||
identity.Groups = groups
|
||||
}
|
||||
|
||||
s.redirectToApproval(w, r, identity, connID, state)
|
||||
default:
|
||||
s.notFound(w, r)
|
||||
}
|
||||
}
|
||||
|
||||
func (s *Server) handleConnectorCallback(w http.ResponseWriter, r *http.Request) {
|
||||
connID := mux.Vars(r)["connector"]
|
||||
conn, ok := s.connectors[connID]
|
||||
if !ok {
|
||||
s.notFound(w, r)
|
||||
return
|
||||
}
|
||||
callbackConnector, ok := conn.Connector.(connector.CallbackConnector)
|
||||
if !ok {
|
||||
s.notFound(w, r)
|
||||
return
|
||||
}
|
||||
|
||||
identity, state, err := callbackConnector.HandleCallback(r)
|
||||
if err != nil {
|
||||
log.Printf("Failed to authenticate: %v", err)
|
||||
s.renderError(w, http.StatusInternalServerError, errServerError, "")
|
||||
return
|
||||
}
|
||||
groups, ok, err := s.groups(identity, state, conn.Connector)
|
||||
if err != nil {
|
||||
s.renderError(w, http.StatusInternalServerError, errServerError, "")
|
||||
return
|
||||
}
|
||||
if ok {
|
||||
identity.Groups = groups
|
||||
}
|
||||
s.redirectToApproval(w, r, identity, connID, state)
|
||||
}
|
||||
|
||||
func (s *Server) redirectToApproval(w http.ResponseWriter, r *http.Request, identity storage.Identity, connectorID, state string) {
|
||||
updater := func(a storage.AuthRequest) (storage.AuthRequest, error) {
|
||||
a.Identity = &identity
|
||||
a.ConnectorID = connectorID
|
||||
return a, nil
|
||||
}
|
||||
if err := s.storage.UpdateAuthRequest(state, updater); err != nil {
|
||||
log.Printf("Failed to updated auth request with identity: %v", err)
|
||||
s.renderError(w, http.StatusInternalServerError, errServerError, "")
|
||||
return
|
||||
}
|
||||
http.Redirect(w, r, path.Join(s.issuerURL.Path, "/approval")+"?state="+state, http.StatusSeeOther)
|
||||
}
|
||||
|
||||
func (s *Server) groups(identity storage.Identity, authReqID string, conn connector.Connector) ([]string, bool, error) {
|
||||
groupsConn, ok := conn.(connector.GroupsConnector)
|
||||
if !ok {
|
||||
return nil, false, nil
|
||||
}
|
||||
authReq, err := s.storage.GetAuthRequest(authReqID)
|
||||
if err != nil {
|
||||
log.Printf("get auth request: %v", err)
|
||||
return nil, false, err
|
||||
}
|
||||
reqGroups := func() bool {
|
||||
for _, scope := range authReq.Scopes {
|
||||
if scope == scopeGroups {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}()
|
||||
if !reqGroups {
|
||||
return nil, false, nil
|
||||
}
|
||||
groups, err := groupsConn.Groups(identity)
|
||||
return groups, true, err
|
||||
}
|
||||
|
||||
func (s *Server) handleApproval(w http.ResponseWriter, r *http.Request) {
|
||||
authReq, err := s.storage.GetAuthRequest(r.FormValue("state"))
|
||||
if err != nil {
|
||||
log.Printf("Failed to get auth request: %v", err)
|
||||
s.renderError(w, http.StatusInternalServerError, errServerError, "")
|
||||
return
|
||||
}
|
||||
if authReq.Identity == nil {
|
||||
log.Printf("Auth request does not have an identity for approval")
|
||||
s.renderError(w, http.StatusInternalServerError, errServerError, "")
|
||||
return
|
||||
}
|
||||
|
||||
switch r.Method {
|
||||
case "GET":
|
||||
if s.skipApproval {
|
||||
s.sendCodeResponse(w, r, authReq, *authReq.Identity)
|
||||
return
|
||||
}
|
||||
client, err := s.storage.GetClient(authReq.ClientID)
|
||||
if err != nil {
|
||||
log.Printf("Failed to get client %q: %v", authReq.ClientID, err)
|
||||
s.renderError(w, http.StatusInternalServerError, errServerError, "")
|
||||
return
|
||||
}
|
||||
renderApprovalTmpl(w, authReq.ID, *authReq.Identity, client, authReq.Scopes)
|
||||
case "POST":
|
||||
if r.FormValue("approval") != "approve" {
|
||||
s.renderError(w, http.StatusInternalServerError, "approval rejected", "")
|
||||
return
|
||||
}
|
||||
s.sendCodeResponse(w, r, authReq, *authReq.Identity)
|
||||
}
|
||||
}
|
||||
|
||||
func (s *Server) sendCodeResponse(w http.ResponseWriter, r *http.Request, authReq storage.AuthRequest, identity storage.Identity) {
|
||||
if authReq.Expiry.After(s.now()) {
|
||||
s.renderError(w, http.StatusBadRequest, errInvalidRequest, "Authorization request period has expired.")
|
||||
return
|
||||
}
|
||||
|
||||
if err := s.storage.DeleteAuthRequest(authReq.ID); err != nil {
|
||||
if err != storage.ErrNotFound {
|
||||
log.Printf("Failed to delete authorization request: %v", err)
|
||||
s.renderError(w, http.StatusInternalServerError, errServerError, "")
|
||||
} else {
|
||||
s.renderError(w, http.StatusBadRequest, errInvalidRequest, "Authorization request has already been completed.")
|
||||
}
|
||||
return
|
||||
}
|
||||
code := storage.AuthCode{
|
||||
ID: storage.NewNonce(),
|
||||
ClientID: authReq.ClientID,
|
||||
ConnectorID: authReq.ConnectorID,
|
||||
Nonce: authReq.Nonce,
|
||||
Scopes: authReq.Scopes,
|
||||
Identity: *authReq.Identity,
|
||||
Expiry: s.now().Add(time.Minute * 5),
|
||||
RedirectURI: authReq.RedirectURI,
|
||||
}
|
||||
if err := s.storage.CreateAuthCode(code); err != nil {
|
||||
log.Printf("Failed to create auth code: %v", err)
|
||||
s.renderError(w, http.StatusInternalServerError, errServerError, "")
|
||||
return
|
||||
}
|
||||
|
||||
if authReq.RedirectURI == "urn:ietf:wg:oauth:2.0:oob" {
|
||||
// TODO(ericchiang): Add a proper template.
|
||||
fmt.Fprintf(w, "Code: %s", code.ID)
|
||||
return
|
||||
}
|
||||
|
||||
u, err := url.Parse(authReq.RedirectURI)
|
||||
if err != nil {
|
||||
s.renderError(w, http.StatusInternalServerError, errServerError, "Invalid redirect URI.")
|
||||
return
|
||||
}
|
||||
q := u.Query()
|
||||
q.Set("code", code.ID)
|
||||
q.Set("state", authReq.State)
|
||||
u.RawQuery = q.Encode()
|
||||
http.Redirect(w, r, u.String(), http.StatusSeeOther)
|
||||
}
|
||||
|
||||
func (s *Server) handleToken(w http.ResponseWriter, r *http.Request) {
|
||||
clientID, clientSecret, ok := r.BasicAuth()
|
||||
if ok {
|
||||
var err error
|
||||
if clientID, err = url.QueryUnescape(clientID); err != nil {
|
||||
tokenErr(w, errInvalidRequest, "client_id improperly encoded", http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
if clientSecret, err = url.QueryUnescape(clientSecret); err != nil {
|
||||
tokenErr(w, errInvalidRequest, "client_secret improperly encoded", http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
} else {
|
||||
clientID = r.PostFormValue("client_id")
|
||||
clientSecret = r.PostFormValue("client_secret")
|
||||
}
|
||||
|
||||
client, err := s.storage.GetClient(clientID)
|
||||
if err != nil {
|
||||
if err != storage.ErrNotFound {
|
||||
log.Printf("failed to get client: %v", err)
|
||||
tokenErr(w, errServerError, "", http.StatusInternalServerError)
|
||||
} else {
|
||||
tokenErr(w, errInvalidClient, "Invalid client credentials.", http.StatusUnauthorized)
|
||||
}
|
||||
return
|
||||
}
|
||||
if client.Secret != clientSecret {
|
||||
tokenErr(w, errInvalidClient, "Invalid client credentials.", http.StatusUnauthorized)
|
||||
return
|
||||
}
|
||||
|
||||
grantType := r.PostFormValue("grant_type")
|
||||
switch grantType {
|
||||
case "authorization_code":
|
||||
s.handleAuthCode(w, r, client)
|
||||
case "refresh_token":
|
||||
s.handleRefreshToken(w, r, client)
|
||||
default:
|
||||
tokenErr(w, errInvalidGrant, "", http.StatusBadRequest)
|
||||
}
|
||||
}
|
||||
|
||||
// handle an access token request https://tools.ietf.org/html/rfc6749#section-4.1.3
|
||||
func (s *Server) handleAuthCode(w http.ResponseWriter, r *http.Request, client storage.Client) {
|
||||
code := r.PostFormValue("code")
|
||||
redirectURI := r.PostFormValue("redirect_uri")
|
||||
|
||||
authCode, err := s.storage.GetAuthCode(code)
|
||||
if err != nil || s.now().After(authCode.Expiry) || authCode.ClientID != client.ID {
|
||||
if err != storage.ErrNotFound {
|
||||
log.Printf("failed to get auth code: %v", err)
|
||||
tokenErr(w, errServerError, "", http.StatusInternalServerError)
|
||||
} else {
|
||||
tokenErr(w, errInvalidRequest, "Invalid or expired code parameter.", http.StatusBadRequest)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
if authCode.RedirectURI != redirectURI {
|
||||
tokenErr(w, errInvalidRequest, "redirect_uri did not match URI from initial request.", http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
|
||||
idToken, expiry, err := s.newIDToken(client.ID, authCode.Identity, authCode.Scopes, authCode.Nonce)
|
||||
if err != nil {
|
||||
log.Printf("failed to create ID token: %v", err)
|
||||
tokenErr(w, errServerError, "", http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
|
||||
if err := s.storage.DeleteAuthCode(code); err != nil {
|
||||
log.Printf("failed to delete auth code: %v", err)
|
||||
tokenErr(w, errServerError, "", http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
|
||||
reqRefresh := func() bool {
|
||||
for _, scope := range authCode.Scopes {
|
||||
if scope == scopeOfflineAccess {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}()
|
||||
var refreshToken string
|
||||
if reqRefresh {
|
||||
refresh := storage.Refresh{
|
||||
RefreshToken: storage.NewNonce(),
|
||||
ClientID: authCode.ClientID,
|
||||
ConnectorID: authCode.ConnectorID,
|
||||
Scopes: authCode.Scopes,
|
||||
Identity: authCode.Identity,
|
||||
Nonce: authCode.Nonce,
|
||||
}
|
||||
if err := s.storage.CreateRefresh(refresh); err != nil {
|
||||
log.Printf("failed to create refresh token: %v", err)
|
||||
tokenErr(w, errServerError, "", http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
refreshToken = refresh.RefreshToken
|
||||
}
|
||||
s.writeAccessToken(w, idToken, refreshToken, expiry)
|
||||
}
|
||||
|
||||
// handle a refresh token request https://tools.ietf.org/html/rfc6749#section-6
|
||||
func (s *Server) handleRefreshToken(w http.ResponseWriter, r *http.Request, client storage.Client) {
|
||||
code := r.PostFormValue("refresh_token")
|
||||
scope := r.PostFormValue("scope")
|
||||
if code == "" {
|
||||
tokenErr(w, errInvalidRequest, "No refresh token in request.", http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
|
||||
refresh, err := s.storage.GetRefresh(code)
|
||||
if err != nil || refresh.ClientID != client.ID {
|
||||
if err != storage.ErrNotFound {
|
||||
log.Printf("failed to get auth code: %v", err)
|
||||
tokenErr(w, errServerError, "", http.StatusInternalServerError)
|
||||
} else {
|
||||
tokenErr(w, errInvalidRequest, "Refresh token is invalid or has already been claimed by another client.", http.StatusBadRequest)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
scopes := refresh.Scopes
|
||||
if scope != "" {
|
||||
requestedScopes := strings.Split(scope, " ")
|
||||
contains := func() bool {
|
||||
Loop:
|
||||
for _, s := range requestedScopes {
|
||||
for _, scope := range refresh.Scopes {
|
||||
if s == scope {
|
||||
continue Loop
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
return true
|
||||
}()
|
||||
if !contains {
|
||||
tokenErr(w, errInvalidRequest, "Requested scopes did not contain authorized scopes.", http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
scopes = requestedScopes
|
||||
}
|
||||
|
||||
// TODO(ericchiang): re-auth with backends
|
||||
|
||||
idToken, expiry, err := s.newIDToken(client.ID, refresh.Identity, scopes, refresh.Nonce)
|
||||
if err != nil {
|
||||
log.Printf("failed to create ID token: %v", err)
|
||||
tokenErr(w, errServerError, "", http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
|
||||
if err := s.storage.DeleteRefresh(code); err != nil {
|
||||
log.Printf("failed to delete auth code: %v", err)
|
||||
tokenErr(w, errServerError, "", http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
refresh.RefreshToken = storage.NewNonce()
|
||||
if err := s.storage.CreateRefresh(refresh); err != nil {
|
||||
log.Printf("failed to create refresh token: %v", err)
|
||||
tokenErr(w, errServerError, "", http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
s.writeAccessToken(w, idToken, refresh.RefreshToken, expiry)
|
||||
}
|
||||
|
||||
func (s *Server) writeAccessToken(w http.ResponseWriter, idToken, refreshToken string, expiry time.Time) {
|
||||
// TODO(ericchiang): figure out an access token story and support the user info
|
||||
// endpoint. For now use a random value so no one depends on the access_token
|
||||
// holding a specific structure.
|
||||
resp := struct {
|
||||
AccessToken string `json:"access_token"`
|
||||
TokenType string `json:"token_type"`
|
||||
ExpiresIn int `json:"expires_in"`
|
||||
RefreshToken string `json:"refresh_token,omitempty"`
|
||||
IDToken string `json:"id_token"`
|
||||
}{
|
||||
storage.NewNonce(),
|
||||
"bearer",
|
||||
int(expiry.Sub(s.now())),
|
||||
refreshToken,
|
||||
idToken,
|
||||
}
|
||||
data, err := json.Marshal(resp)
|
||||
if err != nil {
|
||||
log.Printf("failed to marshal access token response: %v", err)
|
||||
tokenErr(w, errServerError, "", http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
w.Header().Set("Content-Length", strconv.Itoa(len(data)))
|
||||
w.Write(data)
|
||||
}
|
||||
|
||||
func (s *Server) renderError(w http.ResponseWriter, status int, err, description string) {
|
||||
http.Error(w, fmt.Sprintf("%s: %s", err, description), status)
|
||||
}
|
||||
|
||||
func (s *Server) notFound(w http.ResponseWriter, r *http.Request) {
|
||||
http.NotFound(w, r)
|
||||
}
|
1
server/handlers_test.go
Normal file
1
server/handlers_test.go
Normal file
@@ -0,0 +1 @@
|
||||
package server
|
339
server/oauth2.go
Normal file
339
server/oauth2.go
Normal file
@@ -0,0 +1,339 @@
|
||||
package server
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"log"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"strconv"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/coreos/poke/storage"
|
||||
)
|
||||
|
||||
// TODO(ericchiang): clean this file up and figure out more idiomatic error handling.
|
||||
|
||||
// authErr is an error response to an authorization request.
|
||||
// See: https://tools.ietf.org/html/rfc6749#section-4.1.2.1
|
||||
type authErr struct {
|
||||
State string
|
||||
RedirectURI string
|
||||
Type string
|
||||
Description string
|
||||
}
|
||||
|
||||
func (err *authErr) ServeHTTP(w http.ResponseWriter, r *http.Request) {
|
||||
v := url.Values{}
|
||||
v.Add("state", err.State)
|
||||
v.Add("error", err.Type)
|
||||
if err.Description != "" {
|
||||
v.Add("error_description", err.Description)
|
||||
}
|
||||
var redirectURI string
|
||||
if strings.Contains(err.RedirectURI, "?") {
|
||||
redirectURI = err.RedirectURI + "&" + v.Encode()
|
||||
} else {
|
||||
redirectURI = err.RedirectURI + "?" + v.Encode()
|
||||
}
|
||||
http.Redirect(w, r, redirectURI, http.StatusSeeOther)
|
||||
}
|
||||
|
||||
func tokenErr(w http.ResponseWriter, typ, description string, statusCode int) {
|
||||
data := struct {
|
||||
Error string `json:"error"`
|
||||
Description string `json:"error_description,omitempty"`
|
||||
}{typ, description}
|
||||
body, err := json.Marshal(data)
|
||||
if err != nil {
|
||||
log.Printf("failed to marshal token error response: %v", err)
|
||||
return
|
||||
}
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
w.Header().Set("Content-Length", strconv.Itoa(len(body)))
|
||||
w.Write(body)
|
||||
}
|
||||
|
||||
const (
|
||||
errInvalidRequest = "invalid_request"
|
||||
errUnauthorizedClient = "unauthorized_client"
|
||||
errAccessDenied = "access_denied"
|
||||
errUnsupportedResponseType = "unsupported_response_type"
|
||||
errInvalidScope = "invalid_scope"
|
||||
errServerError = "server_error"
|
||||
errTemporarilyUnavailable = "temporarily_unavailable"
|
||||
errUnsupportedGrantType = "unsupported_grant_type"
|
||||
errInvalidGrant = "invalid_grant"
|
||||
errInvalidClient = "invalid_client"
|
||||
)
|
||||
|
||||
const (
|
||||
scopeOfflineAccess = "offline_access" // Request a refresh token.
|
||||
scopeOpenID = "openid"
|
||||
scopeGroups = "groups"
|
||||
scopeEmail = "email"
|
||||
scopeProfile = "profile"
|
||||
scopeCrossClientPrefix = "oauth2:server:client_id:"
|
||||
)
|
||||
|
||||
const (
|
||||
grantTypeAuthorizationCode = "code"
|
||||
grantTypeRefreshToken = "refresh_token"
|
||||
)
|
||||
|
||||
const (
|
||||
responseTypeCode = "code" // "Regular" flow
|
||||
responseTypeToken = "token" // Implicit flow for frontend apps.
|
||||
responseTypeIDToken = "id_token" // ID Token in url fragment
|
||||
)
|
||||
|
||||
var validResponseTypes = map[string]bool{
|
||||
"code": true,
|
||||
"token": true,
|
||||
"id_token": true,
|
||||
}
|
||||
|
||||
type audience []string
|
||||
|
||||
func (a audience) MarshalJSON() ([]byte, error) {
|
||||
if len(a) == 1 {
|
||||
return json.Marshal(a[0])
|
||||
}
|
||||
return json.Marshal(a)
|
||||
}
|
||||
|
||||
type idTokenClaims struct {
|
||||
Issuer string `json:"iss"`
|
||||
Subject string `json:"sub"`
|
||||
Audience audience `json:"aud"`
|
||||
Expiry int64 `json:"exp"`
|
||||
IssuedAt int64 `json:"iat"`
|
||||
AuthorizingParty string `json:"azp,omitempty"`
|
||||
Nonce string `json:"nonce,omitempty"`
|
||||
|
||||
Email string `json:"email,omitempty"`
|
||||
EmailVerified *bool `json:"email_verified,omitempty"`
|
||||
|
||||
Groups []string `json:"groups,omitempty"`
|
||||
|
||||
Name string `json:"name,omitempty"`
|
||||
}
|
||||
|
||||
func (s *Server) newIDToken(clientID string, claims storage.Identity, scopes []string, nonce string) (idToken string, expiry time.Time, err error) {
|
||||
issuedAt := s.now()
|
||||
expiry = issuedAt.Add(s.idTokensValidFor)
|
||||
|
||||
tok := idTokenClaims{
|
||||
Issuer: s.issuerURL.String(),
|
||||
Subject: claims.UserID,
|
||||
Nonce: nonce,
|
||||
Expiry: expiry.Unix(),
|
||||
IssuedAt: issuedAt.Unix(),
|
||||
}
|
||||
|
||||
for _, scope := range scopes {
|
||||
switch {
|
||||
case scope == scopeEmail:
|
||||
tok.Email = claims.Email
|
||||
tok.EmailVerified = &claims.EmailVerified
|
||||
case scope == scopeGroups:
|
||||
tok.Groups = claims.Groups
|
||||
case scope == scopeProfile:
|
||||
tok.Name = claims.Username
|
||||
default:
|
||||
peerID, ok := parseCrossClientScope(scope)
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
isTrusted, err := validateCrossClientTrust(s.storage, clientID, peerID)
|
||||
if err != nil {
|
||||
return "", expiry, err
|
||||
}
|
||||
if !isTrusted {
|
||||
// TODO(ericchiang): propagate this error to the client.
|
||||
return "", expiry, fmt.Errorf("peer (%s) does not trust client", peerID)
|
||||
}
|
||||
tok.Audience = append(tok.Audience, peerID)
|
||||
}
|
||||
}
|
||||
if len(tok.Audience) == 0 {
|
||||
tok.Audience = audience{clientID}
|
||||
} else {
|
||||
tok.AuthorizingParty = clientID
|
||||
}
|
||||
|
||||
payload, err := json.Marshal(tok)
|
||||
if err != nil {
|
||||
return "", expiry, fmt.Errorf("could not serialize claims: %v", err)
|
||||
}
|
||||
|
||||
keys, err := s.storage.GetKeys()
|
||||
if err != nil {
|
||||
log.Printf("Failed to get keys: %v", err)
|
||||
return "", expiry, err
|
||||
}
|
||||
if idToken, err = keys.Sign(payload); err != nil {
|
||||
return "", expiry, fmt.Errorf("failed to sign payload: %v", err)
|
||||
}
|
||||
return idToken, expiry, nil
|
||||
}
|
||||
|
||||
// parse the initial request from the OAuth2 client.
|
||||
//
|
||||
// For correctness the logic is largely copied from https://github.com/RangelReale/osin.
|
||||
func parseAuthorizationRequest(s storage.Storage, r *http.Request) (req storage.AuthRequest, oauth2Err *authErr) {
|
||||
if err := r.ParseForm(); err != nil {
|
||||
return req, &authErr{"", "", errInvalidRequest, "Failed to parse request."}
|
||||
}
|
||||
|
||||
redirectURI, err := url.QueryUnescape(r.Form.Get("redirect_uri"))
|
||||
if err != nil {
|
||||
return req, &authErr{"", "", errInvalidRequest, "No redirect_uri provided."}
|
||||
}
|
||||
state := r.FormValue("state")
|
||||
|
||||
clientID := r.Form.Get("client_id")
|
||||
|
||||
client, err := s.GetClient(clientID)
|
||||
if err != nil {
|
||||
if err == storage.ErrNotFound {
|
||||
description := fmt.Sprintf("Invalid client_id (%q).", clientID)
|
||||
return req, &authErr{"", "", errUnauthorizedClient, description}
|
||||
}
|
||||
log.Printf("Failed to get client: %v", err)
|
||||
return req, &authErr{"", "", errServerError, ""}
|
||||
}
|
||||
|
||||
if !validateRedirectURI(client, redirectURI) {
|
||||
description := fmt.Sprintf("Unregistered redirect_uri (%q).", redirectURI)
|
||||
return req, &authErr{"", "", errInvalidRequest, description}
|
||||
}
|
||||
|
||||
newErr := func(typ, format string, a ...interface{}) *authErr {
|
||||
return &authErr{state, redirectURI, typ, fmt.Sprintf(format, a...)}
|
||||
}
|
||||
|
||||
scopes := strings.Split(r.Form.Get("scope"), " ")
|
||||
|
||||
var (
|
||||
unrecognized []string
|
||||
invalidScopes []string
|
||||
)
|
||||
hasOpenIDScope := false
|
||||
for _, scope := range scopes {
|
||||
switch scope {
|
||||
case scopeOpenID:
|
||||
hasOpenIDScope = true
|
||||
case scopeOfflineAccess, scopeEmail, scopeProfile, scopeGroups:
|
||||
default:
|
||||
peerID, ok := parseCrossClientScope(scope)
|
||||
if !ok {
|
||||
unrecognized = append(unrecognized, scope)
|
||||
continue
|
||||
}
|
||||
|
||||
isTrusted, err := validateCrossClientTrust(s, clientID, peerID)
|
||||
if err != nil {
|
||||
return req, newErr(errServerError, "")
|
||||
}
|
||||
if !isTrusted {
|
||||
invalidScopes = append(invalidScopes, scope)
|
||||
}
|
||||
}
|
||||
}
|
||||
if !hasOpenIDScope {
|
||||
return req, newErr("invalid_scope", `Missing required scope(s) ["openid"].`)
|
||||
}
|
||||
if len(unrecognized) > 0 {
|
||||
return req, newErr("invalid_scope", "Unrecognized scope(s) %q", unrecognized)
|
||||
}
|
||||
if len(invalidScopes) > 0 {
|
||||
return req, newErr("invalid_scope", "Client can't request scope(s) %q", invalidScopes)
|
||||
}
|
||||
|
||||
responseTypes := strings.Split(r.Form.Get("response_type"), " ")
|
||||
for _, responseType := range responseTypes {
|
||||
if !validResponseTypes[responseType] {
|
||||
return req, newErr("invalid_request", "Invalid response type %q", responseType)
|
||||
}
|
||||
}
|
||||
|
||||
return storage.AuthRequest{
|
||||
ID: storage.NewNonce(),
|
||||
ClientID: client.ID,
|
||||
State: r.Form.Get("state"),
|
||||
Nonce: r.Form.Get("nonce"),
|
||||
ForceApprovalPrompt: r.Form.Get("approval_prompt") == "force",
|
||||
Scopes: scopes,
|
||||
RedirectURI: redirectURI,
|
||||
ResponseTypes: responseTypes,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func parseCrossClientScope(scope string) (peerID string, ok bool) {
|
||||
if ok = strings.HasPrefix(scope, scopeCrossClientPrefix); ok {
|
||||
peerID = scope[len(scopeCrossClientPrefix):]
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
func validateCrossClientTrust(s storage.Storage, clientID, peerID string) (trusted bool, err error) {
|
||||
if peerID == clientID {
|
||||
return true, nil
|
||||
}
|
||||
peer, err := s.GetClient(peerID)
|
||||
if err != nil {
|
||||
if err != storage.ErrNotFound {
|
||||
log.Printf("Failed to get client: %v", err)
|
||||
return false, err
|
||||
}
|
||||
return false, nil
|
||||
}
|
||||
for _, id := range peer.TrustedPeers {
|
||||
if id == clientID {
|
||||
return true, nil
|
||||
}
|
||||
}
|
||||
return false, nil
|
||||
}
|
||||
|
||||
func validateRedirectURI(client storage.Client, redirectURI string) bool {
|
||||
if !client.Public {
|
||||
for _, uri := range client.RedirectURIs {
|
||||
if redirectURI == uri {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
if redirectURI == "urn:ietf:wg:oauth:2.0:oob" {
|
||||
return true
|
||||
}
|
||||
if !strings.HasPrefix(redirectURI, "http://localhost:") {
|
||||
return false
|
||||
}
|
||||
n, err := strconv.Atoi(strings.TrimPrefix(redirectURI, "https://localhost:"))
|
||||
return err == nil && n <= 0
|
||||
}
|
||||
|
||||
type tokenRequest struct {
|
||||
Client storage.Client
|
||||
IsRefresh bool
|
||||
Token string
|
||||
RedirectURI string
|
||||
Scopes []string
|
||||
}
|
||||
|
||||
func handleTokenRequest(s storage.Storage, w http.ResponseWriter, r *http.Request) *authErr {
|
||||
return nil
|
||||
}
|
||||
|
||||
func handleRefreshRequest(s storage.Storage, w http.ResponseWriter, r *http.Request, client storage.Client) *authErr {
|
||||
return nil
|
||||
}
|
||||
|
||||
func handleCodeRequest(s storage.Storage, w http.ResponseWriter, r *http.Request, client storage.Client) *authErr {
|
||||
return nil
|
||||
}
|
1
server/oauth2_test.go
Normal file
1
server/oauth2_test.go
Normal file
@@ -0,0 +1 @@
|
||||
package server
|
165
server/rotation.go
Normal file
165
server/rotation.go
Normal file
@@ -0,0 +1,165 @@
|
||||
package server
|
||||
|
||||
import (
|
||||
"crypto/rand"
|
||||
"crypto/rsa"
|
||||
"encoding/hex"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"log"
|
||||
"time"
|
||||
|
||||
"golang.org/x/net/context"
|
||||
"gopkg.in/square/go-jose.v2"
|
||||
|
||||
"github.com/coreos/poke/storage"
|
||||
)
|
||||
|
||||
// rotationStrategy describes a strategy for generating cryptographic keys, how
|
||||
// often to rotate them, and how long they can validate signatures after rotation.
|
||||
type rotationStrategy struct {
|
||||
// Time between rotations.
|
||||
period time.Duration
|
||||
|
||||
// After being rotated how long can a key validate signatues?
|
||||
verifyFor time.Duration
|
||||
|
||||
// Keys are always RSA keys. Though cryptopasta recommends ECDSA keys, not every
|
||||
// client may support these (e.g. github.com/coreos/go-oidc/oidc).
|
||||
key func() (*rsa.PrivateKey, error)
|
||||
}
|
||||
|
||||
// staticRotationStrategy returns a strategy which never rotates keys.
|
||||
func staticRotationStrategy(key *rsa.PrivateKey) rotationStrategy {
|
||||
return rotationStrategy{
|
||||
// Setting these values to 100 years is easier than having a flag indicating no rotation.
|
||||
period: time.Hour * 8760 * 100,
|
||||
verifyFor: time.Hour * 8760 * 100,
|
||||
key: func() (*rsa.PrivateKey, error) { return key, nil },
|
||||
}
|
||||
}
|
||||
|
||||
// defaultRotationStrategy returns a strategy which rotates keys every provided period,
|
||||
// holding onto the public parts for some specified amount of time.
|
||||
func defaultRotationStrategy(rotationPeriod, verifyFor time.Duration) rotationStrategy {
|
||||
return rotationStrategy{
|
||||
period: rotationPeriod,
|
||||
verifyFor: verifyFor,
|
||||
key: func() (*rsa.PrivateKey, error) {
|
||||
return rsa.GenerateKey(rand.Reader, 2048)
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
type keyRotater struct {
|
||||
storage.Storage
|
||||
|
||||
strategy rotationStrategy
|
||||
cancel context.CancelFunc
|
||||
|
||||
now func() time.Time
|
||||
}
|
||||
|
||||
func storageWithKeyRotation(s storage.Storage, strategy rotationStrategy, now func() time.Time) storage.Storage {
|
||||
if now == nil {
|
||||
now = time.Now
|
||||
}
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
rotater := keyRotater{s, strategy, cancel, now}
|
||||
|
||||
// Try to rotate immediately so properly configured storages will return a
|
||||
// storage with keys.
|
||||
if err := rotater.rotate(); err != nil {
|
||||
log.Printf("failed to rotate keys: %v", err)
|
||||
}
|
||||
|
||||
go func() {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return
|
||||
case <-time.After(time.Second * 30):
|
||||
if err := rotater.rotate(); err != nil {
|
||||
log.Printf("failed to rotate keys: %v", err)
|
||||
}
|
||||
}
|
||||
}()
|
||||
return rotater
|
||||
}
|
||||
|
||||
func (k keyRotater) Close() error {
|
||||
k.cancel()
|
||||
return k.Storage.Close()
|
||||
}
|
||||
|
||||
func (k keyRotater) rotate() error {
|
||||
keys, err := k.GetKeys()
|
||||
if err != nil && err != storage.ErrNotFound {
|
||||
return fmt.Errorf("get keys: %v", err)
|
||||
}
|
||||
if k.now().Before(keys.NextRotation) {
|
||||
return nil
|
||||
}
|
||||
log.Println("keys expired, rotating")
|
||||
|
||||
// Generate the key outside of a storage transaction.
|
||||
key, err := k.strategy.key()
|
||||
if err != nil {
|
||||
return fmt.Errorf("generate key: %v", err)
|
||||
}
|
||||
b := make([]byte, 20)
|
||||
if _, err := io.ReadFull(rand.Reader, b); err != nil {
|
||||
panic(err)
|
||||
}
|
||||
keyID := hex.EncodeToString(b)
|
||||
priv := &jose.JSONWebKey{
|
||||
Key: key,
|
||||
KeyID: keyID,
|
||||
Algorithm: "RS256",
|
||||
Use: "sig",
|
||||
}
|
||||
pub := &jose.JSONWebKey{
|
||||
Key: key.Public(),
|
||||
KeyID: keyID,
|
||||
Algorithm: "RS256",
|
||||
Use: "sig",
|
||||
}
|
||||
|
||||
var nextRotation time.Time
|
||||
err = k.Storage.UpdateKeys(func(keys storage.Keys) (storage.Keys, error) {
|
||||
tNow := k.now()
|
||||
if tNow.Before(keys.NextRotation) {
|
||||
return storage.Keys{}, errors.New("keys already rotated")
|
||||
}
|
||||
|
||||
// Remove expired verification keys.
|
||||
i := 0
|
||||
for _, key := range keys.VerificationKeys {
|
||||
if !key.Expiry.After(tNow) {
|
||||
keys.VerificationKeys[i] = key
|
||||
i++
|
||||
}
|
||||
}
|
||||
keys.VerificationKeys = keys.VerificationKeys[:i]
|
||||
|
||||
if keys.SigningKeyPub != nil {
|
||||
// Move current signing key to a verification only key.
|
||||
verificationKey := storage.VerificationKey{
|
||||
PublicKey: keys.SigningKeyPub,
|
||||
Expiry: tNow.Add(k.strategy.verifyFor),
|
||||
}
|
||||
keys.VerificationKeys = append(keys.VerificationKeys, verificationKey)
|
||||
}
|
||||
|
||||
nextRotation = k.now().Add(k.strategy.period)
|
||||
keys.SigningKey = priv
|
||||
keys.SigningKeyPub = pub
|
||||
keys.NextRotation = nextRotation
|
||||
return keys, nil
|
||||
})
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
log.Printf("keys rotated, next rotation: %s", nextRotation)
|
||||
return nil
|
||||
}
|
1
server/rotation_test.go
Normal file
1
server/rotation_test.go
Normal file
@@ -0,0 +1 @@
|
||||
package server
|
141
server/server.go
Normal file
141
server/server.go
Normal file
@@ -0,0 +1,141 @@
|
||||
package server
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"path"
|
||||
"time"
|
||||
|
||||
"github.com/gorilla/mux"
|
||||
|
||||
"github.com/coreos/poke/connector"
|
||||
"github.com/coreos/poke/storage"
|
||||
)
|
||||
|
||||
// Connector is a connector with metadata.
|
||||
type Connector struct {
|
||||
ID string
|
||||
DisplayName string
|
||||
Connector connector.Connector
|
||||
}
|
||||
|
||||
// Config holds the server's configuration options.
|
||||
type Config struct {
|
||||
Issuer string
|
||||
|
||||
// The backing persistence layer.
|
||||
Storage storage.Storage
|
||||
|
||||
// Strategies for federated identity.
|
||||
Connectors []Connector
|
||||
|
||||
// NOTE: Multiple servers using the same storage are expected to set rotation and
|
||||
// validity periods to the same values.
|
||||
RotateKeysAfter time.Duration // Defaults to 6 hours.
|
||||
IDTokensValidFor time.Duration // Defaults to 24 hours
|
||||
|
||||
// If specified, the server will use this function for determining time.
|
||||
Now func() time.Time
|
||||
}
|
||||
|
||||
func value(val, defaultValue time.Duration) time.Duration {
|
||||
if val == 0 {
|
||||
return defaultValue
|
||||
}
|
||||
return val
|
||||
}
|
||||
|
||||
// Server is the top level object.
|
||||
type Server struct {
|
||||
issuerURL url.URL
|
||||
|
||||
// Read-only map of connector IDs to connectors.
|
||||
connectors map[string]Connector
|
||||
|
||||
storage storage.Storage
|
||||
|
||||
mux http.Handler
|
||||
|
||||
// If enabled, don't prompt user for approval after logging in through connector.
|
||||
// No package level API to set this, only used in tests.
|
||||
skipApproval bool
|
||||
|
||||
now func() time.Time
|
||||
|
||||
idTokensValidFor time.Duration
|
||||
}
|
||||
|
||||
// New constructs a server from the provided config.
|
||||
func New(c Config) (*Server, error) {
|
||||
return newServer(c, defaultRotationStrategy(
|
||||
value(c.RotateKeysAfter, 6*time.Hour),
|
||||
value(c.IDTokensValidFor, 24*time.Hour),
|
||||
))
|
||||
}
|
||||
|
||||
func newServer(c Config, rotationStrategy rotationStrategy) (*Server, error) {
|
||||
issuerURL, err := url.Parse(c.Issuer)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("server: can't parse issuer URL")
|
||||
}
|
||||
if len(c.Connectors) == 0 {
|
||||
return nil, errors.New("server: no connectors specified")
|
||||
}
|
||||
if c.Storage == nil {
|
||||
return nil, errors.New("server: storage cannot be nil")
|
||||
}
|
||||
|
||||
now := c.Now
|
||||
if now == nil {
|
||||
now = time.Now
|
||||
}
|
||||
|
||||
s := &Server{
|
||||
issuerURL: *issuerURL,
|
||||
connectors: make(map[string]Connector),
|
||||
storage: storageWithKeyRotation(c.Storage, rotationStrategy, now),
|
||||
idTokensValidFor: value(c.IDTokensValidFor, 24*time.Hour),
|
||||
now: now,
|
||||
}
|
||||
|
||||
for _, conn := range c.Connectors {
|
||||
s.connectors[conn.ID] = conn
|
||||
}
|
||||
|
||||
r := mux.NewRouter()
|
||||
handleFunc := func(p string, h http.HandlerFunc) {
|
||||
r.HandleFunc(path.Join(issuerURL.Path, p), h)
|
||||
}
|
||||
r.NotFoundHandler = http.HandlerFunc(s.notFound)
|
||||
|
||||
// TODO(ericchiang): rate limit certain paths based on IP.
|
||||
handleFunc("/.well-known/openid-configuration", s.handleDiscovery)
|
||||
handleFunc("/token", s.handleToken)
|
||||
handleFunc("/keys", s.handlePublicKeys)
|
||||
handleFunc("/auth", s.handleAuthorization)
|
||||
handleFunc("/auth/{connector}", s.handleConnectorLogin)
|
||||
handleFunc("/callback/{connector}", s.handleConnectorCallback)
|
||||
handleFunc("/approval", s.handleApproval)
|
||||
s.mux = r
|
||||
|
||||
return s, nil
|
||||
}
|
||||
|
||||
func (s *Server) ServeHTTP(w http.ResponseWriter, r *http.Request) {
|
||||
s.mux.ServeHTTP(w, r)
|
||||
}
|
||||
|
||||
func (s *Server) absPath(pathItems ...string) string {
|
||||
paths := make([]string, len(pathItems)+1)
|
||||
paths[0] = s.issuerURL.Path
|
||||
copy(paths[1:], pathItems)
|
||||
return path.Join(paths...)
|
||||
}
|
||||
|
||||
func (s *Server) absURL(pathItems ...string) string {
|
||||
u := s.issuerURL
|
||||
u.Path = s.absPath(pathItems...)
|
||||
return u.String()
|
||||
}
|
221
server/server_test.go
Normal file
221
server/server_test.go
Normal file
@@ -0,0 +1,221 @@
|
||||
package server
|
||||
|
||||
import (
|
||||
"crypto/rsa"
|
||||
"crypto/x509"
|
||||
"encoding/pem"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"net/http/httputil"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/ericchiang/oidc"
|
||||
"golang.org/x/net/context"
|
||||
"golang.org/x/oauth2"
|
||||
|
||||
"github.com/coreos/poke/connector/mock"
|
||||
"github.com/coreos/poke/storage"
|
||||
"github.com/coreos/poke/storage/memory"
|
||||
)
|
||||
|
||||
func mustLoad(s string) *rsa.PrivateKey {
|
||||
block, _ := pem.Decode([]byte(s))
|
||||
if block == nil {
|
||||
panic("no pem data found")
|
||||
}
|
||||
key, err := x509.ParsePKCS1PrivateKey(block.Bytes)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
return key
|
||||
}
|
||||
|
||||
var testKey = mustLoad(`-----BEGIN RSA PRIVATE KEY-----
|
||||
MIIEogIBAAKCAQEArmoiX5G36MKPiVGS1sicruEaGRrbhPbIKOf97aGGQRjXVngo
|
||||
Knwd2L4T9CRyABgQm3tLHHcT5crODoy46wX2g9onTZWViWWuhJ5wxXNmUbCAPWHb
|
||||
j9SunW53WuLYZ/IJLNZt5XYCAFPjAakWp8uMuuDwWo5EyFaw85X3FSMhVmmaYDd0
|
||||
cn+1H4+NS/52wX7tWmyvGUNJ8lzjFAnnOtBJByvkyIC7HDphkLQV4j//sMNY1mPX
|
||||
HbsYgFv2J/LIJtkjdYO2UoDhZG3Gvj16fMy2JE2owA8IX4/s+XAmA2PiTfd0J5b4
|
||||
drAKEcdDl83G6L3depEkTkfvp0ZLsh9xupAvIwIDAQABAoIBABKGgWonPyKA7+AF
|
||||
AxS/MC0/CZebC6/+ylnV8lm4K1tkuRKdJp8EmeL4pYPsDxPFepYZLWwzlbB1rxdK
|
||||
iSWld36fwEb0WXLDkxrQ/Wdrj3Wjyqs6ZqjLTVS5dAH6UEQSKDlT+U5DD4lbX6RA
|
||||
goCGFUeQNtdXfyTMWHU2+4yKM7NKzUpczFky+0d10Mg0ANj3/4IILdr3hqkmMSI9
|
||||
1TB9ksWBXJxt3nGxAjzSFihQFUlc231cey/HhYbvAX5fN0xhLxOk88adDcdXE7br
|
||||
3Ser1q6XaaFQSMj4oi1+h3RAT9MUjJ6johEqjw0PbEZtOqXvA1x5vfFdei6SqgKn
|
||||
Am3BspkCgYEA2lIiKEkT/Je6ZH4Omhv9atbGoBdETAstL3FnNQjkyVau9f6bxQkl
|
||||
4/sz985JpaiasORQBiTGY8JDT/hXjROkut91agi2Vafhr29L/mto7KZglfDsT4b2
|
||||
9z/EZH8wHw7eYhvdoBbMbqNDSI8RrGa4mpLpuN+E0wsFTzSZEL+QMQUCgYEAzIQh
|
||||
xnreQvDAhNradMqLmxRpayn1ORaPReD4/off+mi7hZRLKtP0iNgEVEWHJ6HEqqi1
|
||||
r38XAc8ap/lfOVMar2MLyCFOhYspdHZ+TGLZfr8gg/Fzeq9IRGKYadmIKVwjMeyH
|
||||
REPqg1tyrvMOE0HI5oqkko8JTDJ0OyVC0Vc6+AcCgYAqCzkywugLc/jcU35iZVOH
|
||||
WLdFq1Vmw5w/D7rNdtoAgCYPj6nV5y4Z2o2mgl6ifXbU7BMRK9Hc8lNeOjg6HfdS
|
||||
WahV9DmRA1SuIWPkKjE5qczd81i+9AHpmakrpWbSBF4FTNKAewOBpwVVGuBPcDTK
|
||||
59IE3V7J+cxa9YkotYuCNQKBgCwGla7AbHBEm2z+H+DcaUktD7R+B8gOTzFfyLoi
|
||||
Tdj+CsAquDO0BQQgXG43uWySql+CifoJhc5h4v8d853HggsXa0XdxaWB256yk2Wm
|
||||
MePTCRDePVm/ufLetqiyp1kf+IOaw1Oyux0j5oA62mDS3Iikd+EE4Z+BjPvefY/L
|
||||
E2qpAoGAZo5Wwwk7q8b1n9n/ACh4LpE+QgbFdlJxlfFLJCKstl37atzS8UewOSZj
|
||||
FDWV28nTP9sqbtsmU8Tem2jzMvZ7C/Q0AuDoKELFUpux8shm8wfIhyaPnXUGZoAZ
|
||||
Np4vUwMSYV5mopESLWOg3loBxKyLGFtgGKVCjGiQvy6zISQ4fQo=
|
||||
-----END RSA PRIVATE KEY-----`)
|
||||
|
||||
func newTestServer() (*httptest.Server, *Server) {
|
||||
var server *Server
|
||||
s := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
server.ServeHTTP(w, r)
|
||||
}))
|
||||
config := Config{
|
||||
Issuer: s.URL,
|
||||
Storage: memory.New(),
|
||||
Connectors: []Connector{
|
||||
{
|
||||
ID: "mock",
|
||||
DisplayName: "Mock",
|
||||
Connector: mock.New(),
|
||||
},
|
||||
},
|
||||
}
|
||||
var err error
|
||||
if server, err = newServer(config, staticRotationStrategy(testKey)); err != nil {
|
||||
panic(err)
|
||||
}
|
||||
server.skipApproval = true // Don't prompt for approval, just immediately redirect with code.
|
||||
return s, server
|
||||
}
|
||||
|
||||
func TestNewTestServer(t *testing.T) {
|
||||
newTestServer()
|
||||
}
|
||||
|
||||
func TestDiscovery(t *testing.T) {
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
defer cancel()
|
||||
|
||||
httpServer, _ := newTestServer()
|
||||
defer httpServer.Close()
|
||||
|
||||
p, err := oidc.NewProvider(ctx, httpServer.URL)
|
||||
if err != nil {
|
||||
t.Fatalf("failed to get provider: %v", err)
|
||||
}
|
||||
required := []struct {
|
||||
name, val string
|
||||
}{
|
||||
{"issuer", p.Issuer},
|
||||
{"authorization_endpoint", p.AuthURL},
|
||||
{"token_endpoint", p.TokenURL},
|
||||
{"jwks_uri", p.JWKSURL},
|
||||
}
|
||||
for _, field := range required {
|
||||
if field.val == "" {
|
||||
t.Errorf("server discovery is missing required field %q", field.name)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestOAuth2Flow(t *testing.T) {
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
defer cancel()
|
||||
|
||||
httpServer, s := newTestServer()
|
||||
defer httpServer.Close()
|
||||
|
||||
p, err := oidc.NewProvider(ctx, httpServer.URL)
|
||||
if err != nil {
|
||||
t.Fatalf("failed to get provider: %v", err)
|
||||
}
|
||||
|
||||
var (
|
||||
reqDump, respDump []byte
|
||||
gotCode bool
|
||||
state = "a_state"
|
||||
)
|
||||
defer func() {
|
||||
if !gotCode {
|
||||
t.Errorf("never got a code in callback\n%s\n%s", reqDump, respDump)
|
||||
}
|
||||
}()
|
||||
|
||||
var oauth2Config *oauth2.Config
|
||||
oauth2Server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
if r.URL.Path == "/callback" {
|
||||
q := r.URL.Query()
|
||||
if errType := q.Get("error"); errType != "" {
|
||||
if desc := q.Get("error_description"); desc != "" {
|
||||
t.Errorf("got error from server %s: %s", errType, desc)
|
||||
} else {
|
||||
t.Errorf("got error from server %s", errType)
|
||||
}
|
||||
w.WriteHeader(http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
|
||||
if code := q.Get("code"); code != "" {
|
||||
gotCode = true
|
||||
token, err := oauth2Config.Exchange(ctx, code)
|
||||
if err != nil {
|
||||
t.Errorf("failed to exchange code for token: %v", err)
|
||||
return
|
||||
}
|
||||
idToken, ok := token.Extra("id_token").(string)
|
||||
if !ok {
|
||||
t.Errorf("no id token found: %v", err)
|
||||
return
|
||||
}
|
||||
// TODO(ericchiang): validate id token
|
||||
_ = idToken
|
||||
|
||||
token.Expiry = time.Now().Add(time.Second * -10)
|
||||
if token.Valid() {
|
||||
t.Errorf("token shouldn't be valid")
|
||||
}
|
||||
|
||||
newToken, err := oauth2Config.TokenSource(ctx, token).Token()
|
||||
if err != nil {
|
||||
t.Errorf("failed to refresh token: %v", err)
|
||||
return
|
||||
}
|
||||
if token.RefreshToken == newToken.RefreshToken {
|
||||
t.Errorf("old refresh token was the same as the new token %q", token.RefreshToken)
|
||||
}
|
||||
}
|
||||
if gotState := q.Get("state"); gotState != state {
|
||||
t.Errorf("state did not match, want=%q got=%q", state, gotState)
|
||||
}
|
||||
w.WriteHeader(http.StatusOK)
|
||||
return
|
||||
}
|
||||
http.Redirect(w, r, oauth2Config.AuthCodeURL(state), http.StatusSeeOther)
|
||||
}))
|
||||
|
||||
defer oauth2Server.Close()
|
||||
|
||||
redirectURL := oauth2Server.URL + "/callback"
|
||||
client := storage.Client{
|
||||
ID: "testclient",
|
||||
Secret: "testclientsecret",
|
||||
RedirectURIs: []string{redirectURL},
|
||||
}
|
||||
if err := s.storage.CreateClient(client); err != nil {
|
||||
t.Fatalf("failed to create client: %v", err)
|
||||
}
|
||||
|
||||
oauth2Config = &oauth2.Config{
|
||||
ClientID: client.ID,
|
||||
ClientSecret: client.Secret,
|
||||
Endpoint: p.Endpoint(),
|
||||
Scopes: []string{oidc.ScopeOpenID, "profile", "email", "offline_access"},
|
||||
RedirectURL: redirectURL,
|
||||
}
|
||||
|
||||
resp, err := http.Get(oauth2Server.URL + "/login")
|
||||
if err != nil {
|
||||
t.Fatalf("get failed: %v", err)
|
||||
}
|
||||
if reqDump, err = httputil.DumpRequest(resp.Request, false); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if respDump, err = httputil.DumpResponse(resp, true); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
}
|
101
server/templates.go
Normal file
101
server/templates.go
Normal file
@@ -0,0 +1,101 @@
|
||||
package server
|
||||
|
||||
import (
|
||||
"log"
|
||||
"net/http"
|
||||
"text/template"
|
||||
|
||||
"github.com/coreos/poke/storage"
|
||||
)
|
||||
|
||||
type connectorInfo struct {
|
||||
DisplayName string
|
||||
URL string
|
||||
}
|
||||
|
||||
var loginTmpl = template.Must(template.New("login-template").Parse(`<html>
|
||||
<head></head>
|
||||
<body>
|
||||
<p>Login options</p>
|
||||
{{ range $i, $connector := .Connectors }}
|
||||
<a href="{{ $connector.URL }}?state={{ $.State }}">{{ $connector.DisplayName }}</a>
|
||||
{{ end }}
|
||||
</body>
|
||||
</html>`))
|
||||
|
||||
func renderLoginOptions(w http.ResponseWriter, connectors []connectorInfo, state string) {
|
||||
data := struct {
|
||||
Connectors []connectorInfo
|
||||
State string
|
||||
}{connectors, state}
|
||||
renderTemplate(w, loginTmpl, data)
|
||||
}
|
||||
|
||||
var passwordTmpl = template.Must(template.New("password-template").Parse(`<html>
|
||||
<body>
|
||||
<p>Login</p>
|
||||
<form action="{{ .Callback }}" method="POST">
|
||||
Login: <input type="text" name="login"/><br/>
|
||||
Password: <input type="password" name="password"/><br/>
|
||||
<input type="hidden" name="state" value="{{ .State }}"/>
|
||||
<input type="submit"/>
|
||||
{{ if .Message }}
|
||||
<p>Error: {{ .Message }}</p>
|
||||
{{ end }}
|
||||
</form>
|
||||
</body>
|
||||
</html>`))
|
||||
|
||||
func renderPasswordTmpl(w http.ResponseWriter, state, callback, message string) {
|
||||
data := struct {
|
||||
State string
|
||||
Callback string
|
||||
Message string
|
||||
}{state, callback, message}
|
||||
renderTemplate(w, passwordTmpl, data)
|
||||
}
|
||||
|
||||
var approvalTmpl = template.Must(template.New("approval-template").Parse(`<html>
|
||||
<body>
|
||||
<p>User: {{ .User }}</p>
|
||||
<p>Client: {{ .ClientName }}</p>
|
||||
<form method="post">
|
||||
<input type="hidden" name="state" value="{{ .State }}"/>
|
||||
<input type="hidden" name="approval" value="approve">
|
||||
<button type="submit">Approve</button>
|
||||
</form>
|
||||
<form method="post">
|
||||
<input type="hidden" name="state" value="{{ .State }}"/>
|
||||
<input type="hidden" name="approval" value="reject">
|
||||
<button type="submit">Reject</button>
|
||||
</form>
|
||||
</body>
|
||||
</html>`))
|
||||
|
||||
func renderApprovalTmpl(w http.ResponseWriter, state string, identity storage.Identity, client storage.Client, scopes []string) {
|
||||
data := struct {
|
||||
User string
|
||||
ClientName string
|
||||
State string
|
||||
}{identity.Email, client.Name, state}
|
||||
renderTemplate(w, approvalTmpl, data)
|
||||
}
|
||||
|
||||
func renderTemplate(w http.ResponseWriter, tmpl *template.Template, data interface{}) {
|
||||
err := tmpl.Execute(w, data)
|
||||
if err == nil {
|
||||
return
|
||||
}
|
||||
|
||||
switch err := err.(type) {
|
||||
case template.ExecError:
|
||||
// An ExecError guarentees that Execute has not written to the underlying reader.
|
||||
log.Printf("Error rendering template %s: %s", tmpl.Name(), err)
|
||||
|
||||
// TODO(ericchiang): replace with better internal server error.
|
||||
http.Error(w, "Internal server error", http.StatusInternalServerError)
|
||||
default:
|
||||
// An error with the underlying write, such as the connection being
|
||||
// dropped. Ignore for now.
|
||||
}
|
||||
}
|
1
server/templates_test.go
Normal file
1
server/templates_test.go
Normal file
@@ -0,0 +1 @@
|
||||
package server
|
Reference in New Issue
Block a user