initial commit

This commit is contained in:
Eric Chiang
2016-07-25 13:00:28 -07:00
commit cab271f304
1438 changed files with 335968 additions and 0 deletions

2
server/doc.go Normal file
View File

@@ -0,0 +1,2 @@
// Package server implements an OpenID Connect server with federated logins.
package server

556
server/handlers.go Normal file
View File

@@ -0,0 +1,556 @@
package server
import (
"encoding/json"
"fmt"
"log"
"net/http"
"net/url"
"path"
"strconv"
"strings"
"time"
"github.com/gorilla/mux"
jose "gopkg.in/square/go-jose.v2"
"github.com/coreos/poke/connector"
"github.com/coreos/poke/storage"
)
func (s *Server) handlePublicKeys(w http.ResponseWriter, r *http.Request) {
// TODO(ericchiang): Cache this.
keys, err := s.storage.GetKeys()
if err != nil {
log.Printf("failed to get keys: %v", err)
http.Error(w, "Internal server error", http.StatusInternalServerError)
return
}
if keys.SigningKeyPub == nil {
log.Printf("No public keys found.")
http.Error(w, "Internal server error", http.StatusInternalServerError)
return
}
jwks := jose.JSONWebKeySet{
Keys: make([]jose.JSONWebKey, len(keys.VerificationKeys)+1),
}
jwks.Keys[0] = *keys.SigningKeyPub
for i, verificationKey := range keys.VerificationKeys {
jwks.Keys[i+1] = *verificationKey.PublicKey
}
data, err := json.MarshalIndent(jwks, "", " ")
if err != nil {
log.Printf("failed to marshal discovery data: %v", err)
http.Error(w, "Internal server error", http.StatusInternalServerError)
return
}
maxAge := keys.NextRotation.Sub(s.now())
if maxAge < (time.Minute * 2) {
maxAge = time.Minute * 2
}
w.Header().Set("Cache-Control", fmt.Sprintf("max-age=%d, must-revalidate", maxAge))
w.Header().Set("Content-Type", "application/json")
w.Header().Set("Content-Length", strconv.Itoa(len(data)))
w.Write(data)
}
type discovery struct {
Issuer string `json:"issuer"`
Auth string `json:"authorization_endpoint"`
Token string `json:"token_endpoint"`
Keys string `json:"jwks_uri"`
ResponseTypes []string `json:"response_types_supported"`
Subjects []string `json:"subject_types_supported"`
IDTokenAlgs []string `json:"id_token_signing_alg_values_supported"`
Scopes []string `json:"scopes_supported"`
AuthMethods []string `json:"token_endpoint_auth_methods_supported"`
Claims []string `json:"claims_supported"`
}
func (s *Server) handleDiscovery(w http.ResponseWriter, r *http.Request) {
// TODO(ericchiang): Cache this
d := discovery{
Issuer: s.issuerURL.String(),
Auth: s.absURL("/auth"),
Token: s.absURL("/token"),
Keys: s.absURL("/keys"),
ResponseTypes: []string{"code"},
Subjects: []string{"public"},
IDTokenAlgs: []string{string(jose.RS256)},
Scopes: []string{"openid", "email", "profile"},
AuthMethods: []string{"client_secret_basic"},
Claims: []string{
"aud", "email", "email_verified", "exp", "family_name", "given_name",
"iat", "iss", "locale", "name", "sub",
},
}
data, err := json.MarshalIndent(d, "", " ")
if err != nil {
log.Printf("failed to marshal discovery data: %v", err)
http.Error(w, "Internal server error", http.StatusInternalServerError)
return
}
w.Header().Set("Content-Type", "application/json")
w.Header().Set("Content-Length", strconv.Itoa(len(data)))
w.Write(data)
}
// handleAuthorization handles the OAuth2 auth endpoint.
func (s *Server) handleAuthorization(w http.ResponseWriter, r *http.Request) {
authReq, err := parseAuthorizationRequest(s.storage, r)
if err != nil {
s.renderError(w, http.StatusInternalServerError, err.Type, err.Description)
return
}
if err := s.storage.CreateAuthRequest(authReq); err != nil {
log.Printf("Failed to create authorization request: %v", err)
s.renderError(w, http.StatusInternalServerError, errServerError, "")
return
}
state := authReq.ID
if len(s.connectors) == 1 {
for id := range s.connectors {
http.Redirect(w, r, s.absPath("/auth", id)+"?state="+state, http.StatusFound)
return
}
}
connectorInfos := make([]connectorInfo, len(s.connectors))
i := 0
for id := range s.connectors {
connectorInfos[i] = connectorInfo{
DisplayName: id,
URL: s.absPath("/auth", id) + "?state=" + state,
}
i++
}
renderLoginOptions(w, connectorInfos, state)
}
func (s *Server) handleConnectorLogin(w http.ResponseWriter, r *http.Request) {
connID := mux.Vars(r)["connector"]
conn, ok := s.connectors[connID]
if !ok {
s.notFound(w, r)
return
}
// TODO(ericchiang): cache user identity.
state := r.FormValue("state")
switch r.Method {
case "GET":
switch conn := conn.Connector.(type) {
case connector.CallbackConnector:
callbackURL, err := conn.LoginURL(s.absURL("/callback", connID), state)
if err != nil {
log.Printf("Connector %q returned error when creating callback: %v", connID, err)
s.renderError(w, http.StatusInternalServerError, errServerError, "")
return
}
http.Redirect(w, r, callbackURL, http.StatusFound)
case connector.PasswordConnector:
renderPasswordTmpl(w, state, r.URL.String(), "")
default:
s.notFound(w, r)
}
case "POST":
passwordConnector, ok := conn.Connector.(connector.PasswordConnector)
if !ok {
s.notFound(w, r)
return
}
username := r.FormValue("username")
password := r.FormValue("password")
identity, ok, err := passwordConnector.Login(username, password)
if err != nil {
log.Printf("Failed to login user: %v", err)
s.renderError(w, http.StatusInternalServerError, errServerError, "")
return
}
if !ok {
renderPasswordTmpl(w, state, r.URL.String(), "Invalid credentials")
return
}
groups, ok, err := s.groups(identity, state, conn.Connector)
if err != nil {
s.renderError(w, http.StatusInternalServerError, errServerError, "")
return
}
if ok {
identity.Groups = groups
}
s.redirectToApproval(w, r, identity, connID, state)
default:
s.notFound(w, r)
}
}
func (s *Server) handleConnectorCallback(w http.ResponseWriter, r *http.Request) {
connID := mux.Vars(r)["connector"]
conn, ok := s.connectors[connID]
if !ok {
s.notFound(w, r)
return
}
callbackConnector, ok := conn.Connector.(connector.CallbackConnector)
if !ok {
s.notFound(w, r)
return
}
identity, state, err := callbackConnector.HandleCallback(r)
if err != nil {
log.Printf("Failed to authenticate: %v", err)
s.renderError(w, http.StatusInternalServerError, errServerError, "")
return
}
groups, ok, err := s.groups(identity, state, conn.Connector)
if err != nil {
s.renderError(w, http.StatusInternalServerError, errServerError, "")
return
}
if ok {
identity.Groups = groups
}
s.redirectToApproval(w, r, identity, connID, state)
}
func (s *Server) redirectToApproval(w http.ResponseWriter, r *http.Request, identity storage.Identity, connectorID, state string) {
updater := func(a storage.AuthRequest) (storage.AuthRequest, error) {
a.Identity = &identity
a.ConnectorID = connectorID
return a, nil
}
if err := s.storage.UpdateAuthRequest(state, updater); err != nil {
log.Printf("Failed to updated auth request with identity: %v", err)
s.renderError(w, http.StatusInternalServerError, errServerError, "")
return
}
http.Redirect(w, r, path.Join(s.issuerURL.Path, "/approval")+"?state="+state, http.StatusSeeOther)
}
func (s *Server) groups(identity storage.Identity, authReqID string, conn connector.Connector) ([]string, bool, error) {
groupsConn, ok := conn.(connector.GroupsConnector)
if !ok {
return nil, false, nil
}
authReq, err := s.storage.GetAuthRequest(authReqID)
if err != nil {
log.Printf("get auth request: %v", err)
return nil, false, err
}
reqGroups := func() bool {
for _, scope := range authReq.Scopes {
if scope == scopeGroups {
return true
}
}
return false
}()
if !reqGroups {
return nil, false, nil
}
groups, err := groupsConn.Groups(identity)
return groups, true, err
}
func (s *Server) handleApproval(w http.ResponseWriter, r *http.Request) {
authReq, err := s.storage.GetAuthRequest(r.FormValue("state"))
if err != nil {
log.Printf("Failed to get auth request: %v", err)
s.renderError(w, http.StatusInternalServerError, errServerError, "")
return
}
if authReq.Identity == nil {
log.Printf("Auth request does not have an identity for approval")
s.renderError(w, http.StatusInternalServerError, errServerError, "")
return
}
switch r.Method {
case "GET":
if s.skipApproval {
s.sendCodeResponse(w, r, authReq, *authReq.Identity)
return
}
client, err := s.storage.GetClient(authReq.ClientID)
if err != nil {
log.Printf("Failed to get client %q: %v", authReq.ClientID, err)
s.renderError(w, http.StatusInternalServerError, errServerError, "")
return
}
renderApprovalTmpl(w, authReq.ID, *authReq.Identity, client, authReq.Scopes)
case "POST":
if r.FormValue("approval") != "approve" {
s.renderError(w, http.StatusInternalServerError, "approval rejected", "")
return
}
s.sendCodeResponse(w, r, authReq, *authReq.Identity)
}
}
func (s *Server) sendCodeResponse(w http.ResponseWriter, r *http.Request, authReq storage.AuthRequest, identity storage.Identity) {
if authReq.Expiry.After(s.now()) {
s.renderError(w, http.StatusBadRequest, errInvalidRequest, "Authorization request period has expired.")
return
}
if err := s.storage.DeleteAuthRequest(authReq.ID); err != nil {
if err != storage.ErrNotFound {
log.Printf("Failed to delete authorization request: %v", err)
s.renderError(w, http.StatusInternalServerError, errServerError, "")
} else {
s.renderError(w, http.StatusBadRequest, errInvalidRequest, "Authorization request has already been completed.")
}
return
}
code := storage.AuthCode{
ID: storage.NewNonce(),
ClientID: authReq.ClientID,
ConnectorID: authReq.ConnectorID,
Nonce: authReq.Nonce,
Scopes: authReq.Scopes,
Identity: *authReq.Identity,
Expiry: s.now().Add(time.Minute * 5),
RedirectURI: authReq.RedirectURI,
}
if err := s.storage.CreateAuthCode(code); err != nil {
log.Printf("Failed to create auth code: %v", err)
s.renderError(w, http.StatusInternalServerError, errServerError, "")
return
}
if authReq.RedirectURI == "urn:ietf:wg:oauth:2.0:oob" {
// TODO(ericchiang): Add a proper template.
fmt.Fprintf(w, "Code: %s", code.ID)
return
}
u, err := url.Parse(authReq.RedirectURI)
if err != nil {
s.renderError(w, http.StatusInternalServerError, errServerError, "Invalid redirect URI.")
return
}
q := u.Query()
q.Set("code", code.ID)
q.Set("state", authReq.State)
u.RawQuery = q.Encode()
http.Redirect(w, r, u.String(), http.StatusSeeOther)
}
func (s *Server) handleToken(w http.ResponseWriter, r *http.Request) {
clientID, clientSecret, ok := r.BasicAuth()
if ok {
var err error
if clientID, err = url.QueryUnescape(clientID); err != nil {
tokenErr(w, errInvalidRequest, "client_id improperly encoded", http.StatusBadRequest)
return
}
if clientSecret, err = url.QueryUnescape(clientSecret); err != nil {
tokenErr(w, errInvalidRequest, "client_secret improperly encoded", http.StatusBadRequest)
return
}
} else {
clientID = r.PostFormValue("client_id")
clientSecret = r.PostFormValue("client_secret")
}
client, err := s.storage.GetClient(clientID)
if err != nil {
if err != storage.ErrNotFound {
log.Printf("failed to get client: %v", err)
tokenErr(w, errServerError, "", http.StatusInternalServerError)
} else {
tokenErr(w, errInvalidClient, "Invalid client credentials.", http.StatusUnauthorized)
}
return
}
if client.Secret != clientSecret {
tokenErr(w, errInvalidClient, "Invalid client credentials.", http.StatusUnauthorized)
return
}
grantType := r.PostFormValue("grant_type")
switch grantType {
case "authorization_code":
s.handleAuthCode(w, r, client)
case "refresh_token":
s.handleRefreshToken(w, r, client)
default:
tokenErr(w, errInvalidGrant, "", http.StatusBadRequest)
}
}
// handle an access token request https://tools.ietf.org/html/rfc6749#section-4.1.3
func (s *Server) handleAuthCode(w http.ResponseWriter, r *http.Request, client storage.Client) {
code := r.PostFormValue("code")
redirectURI := r.PostFormValue("redirect_uri")
authCode, err := s.storage.GetAuthCode(code)
if err != nil || s.now().After(authCode.Expiry) || authCode.ClientID != client.ID {
if err != storage.ErrNotFound {
log.Printf("failed to get auth code: %v", err)
tokenErr(w, errServerError, "", http.StatusInternalServerError)
} else {
tokenErr(w, errInvalidRequest, "Invalid or expired code parameter.", http.StatusBadRequest)
}
return
}
if authCode.RedirectURI != redirectURI {
tokenErr(w, errInvalidRequest, "redirect_uri did not match URI from initial request.", http.StatusBadRequest)
return
}
idToken, expiry, err := s.newIDToken(client.ID, authCode.Identity, authCode.Scopes, authCode.Nonce)
if err != nil {
log.Printf("failed to create ID token: %v", err)
tokenErr(w, errServerError, "", http.StatusInternalServerError)
return
}
if err := s.storage.DeleteAuthCode(code); err != nil {
log.Printf("failed to delete auth code: %v", err)
tokenErr(w, errServerError, "", http.StatusInternalServerError)
return
}
reqRefresh := func() bool {
for _, scope := range authCode.Scopes {
if scope == scopeOfflineAccess {
return true
}
}
return false
}()
var refreshToken string
if reqRefresh {
refresh := storage.Refresh{
RefreshToken: storage.NewNonce(),
ClientID: authCode.ClientID,
ConnectorID: authCode.ConnectorID,
Scopes: authCode.Scopes,
Identity: authCode.Identity,
Nonce: authCode.Nonce,
}
if err := s.storage.CreateRefresh(refresh); err != nil {
log.Printf("failed to create refresh token: %v", err)
tokenErr(w, errServerError, "", http.StatusInternalServerError)
return
}
refreshToken = refresh.RefreshToken
}
s.writeAccessToken(w, idToken, refreshToken, expiry)
}
// handle a refresh token request https://tools.ietf.org/html/rfc6749#section-6
func (s *Server) handleRefreshToken(w http.ResponseWriter, r *http.Request, client storage.Client) {
code := r.PostFormValue("refresh_token")
scope := r.PostFormValue("scope")
if code == "" {
tokenErr(w, errInvalidRequest, "No refresh token in request.", http.StatusBadRequest)
return
}
refresh, err := s.storage.GetRefresh(code)
if err != nil || refresh.ClientID != client.ID {
if err != storage.ErrNotFound {
log.Printf("failed to get auth code: %v", err)
tokenErr(w, errServerError, "", http.StatusInternalServerError)
} else {
tokenErr(w, errInvalidRequest, "Refresh token is invalid or has already been claimed by another client.", http.StatusBadRequest)
}
return
}
scopes := refresh.Scopes
if scope != "" {
requestedScopes := strings.Split(scope, " ")
contains := func() bool {
Loop:
for _, s := range requestedScopes {
for _, scope := range refresh.Scopes {
if s == scope {
continue Loop
}
}
return false
}
return true
}()
if !contains {
tokenErr(w, errInvalidRequest, "Requested scopes did not contain authorized scopes.", http.StatusBadRequest)
return
}
scopes = requestedScopes
}
// TODO(ericchiang): re-auth with backends
idToken, expiry, err := s.newIDToken(client.ID, refresh.Identity, scopes, refresh.Nonce)
if err != nil {
log.Printf("failed to create ID token: %v", err)
tokenErr(w, errServerError, "", http.StatusInternalServerError)
return
}
if err := s.storage.DeleteRefresh(code); err != nil {
log.Printf("failed to delete auth code: %v", err)
tokenErr(w, errServerError, "", http.StatusInternalServerError)
return
}
refresh.RefreshToken = storage.NewNonce()
if err := s.storage.CreateRefresh(refresh); err != nil {
log.Printf("failed to create refresh token: %v", err)
tokenErr(w, errServerError, "", http.StatusInternalServerError)
return
}
s.writeAccessToken(w, idToken, refresh.RefreshToken, expiry)
}
func (s *Server) writeAccessToken(w http.ResponseWriter, idToken, refreshToken string, expiry time.Time) {
// TODO(ericchiang): figure out an access token story and support the user info
// endpoint. For now use a random value so no one depends on the access_token
// holding a specific structure.
resp := struct {
AccessToken string `json:"access_token"`
TokenType string `json:"token_type"`
ExpiresIn int `json:"expires_in"`
RefreshToken string `json:"refresh_token,omitempty"`
IDToken string `json:"id_token"`
}{
storage.NewNonce(),
"bearer",
int(expiry.Sub(s.now())),
refreshToken,
idToken,
}
data, err := json.Marshal(resp)
if err != nil {
log.Printf("failed to marshal access token response: %v", err)
tokenErr(w, errServerError, "", http.StatusInternalServerError)
return
}
w.Header().Set("Content-Type", "application/json")
w.Header().Set("Content-Length", strconv.Itoa(len(data)))
w.Write(data)
}
func (s *Server) renderError(w http.ResponseWriter, status int, err, description string) {
http.Error(w, fmt.Sprintf("%s: %s", err, description), status)
}
func (s *Server) notFound(w http.ResponseWriter, r *http.Request) {
http.NotFound(w, r)
}

1
server/handlers_test.go Normal file
View File

@@ -0,0 +1 @@
package server

339
server/oauth2.go Normal file
View File

@@ -0,0 +1,339 @@
package server
import (
"encoding/json"
"fmt"
"log"
"net/http"
"net/url"
"strconv"
"strings"
"time"
"github.com/coreos/poke/storage"
)
// TODO(ericchiang): clean this file up and figure out more idiomatic error handling.
// authErr is an error response to an authorization request.
// See: https://tools.ietf.org/html/rfc6749#section-4.1.2.1
type authErr struct {
State string
RedirectURI string
Type string
Description string
}
func (err *authErr) ServeHTTP(w http.ResponseWriter, r *http.Request) {
v := url.Values{}
v.Add("state", err.State)
v.Add("error", err.Type)
if err.Description != "" {
v.Add("error_description", err.Description)
}
var redirectURI string
if strings.Contains(err.RedirectURI, "?") {
redirectURI = err.RedirectURI + "&" + v.Encode()
} else {
redirectURI = err.RedirectURI + "?" + v.Encode()
}
http.Redirect(w, r, redirectURI, http.StatusSeeOther)
}
func tokenErr(w http.ResponseWriter, typ, description string, statusCode int) {
data := struct {
Error string `json:"error"`
Description string `json:"error_description,omitempty"`
}{typ, description}
body, err := json.Marshal(data)
if err != nil {
log.Printf("failed to marshal token error response: %v", err)
return
}
w.Header().Set("Content-Type", "application/json")
w.Header().Set("Content-Length", strconv.Itoa(len(body)))
w.Write(body)
}
const (
errInvalidRequest = "invalid_request"
errUnauthorizedClient = "unauthorized_client"
errAccessDenied = "access_denied"
errUnsupportedResponseType = "unsupported_response_type"
errInvalidScope = "invalid_scope"
errServerError = "server_error"
errTemporarilyUnavailable = "temporarily_unavailable"
errUnsupportedGrantType = "unsupported_grant_type"
errInvalidGrant = "invalid_grant"
errInvalidClient = "invalid_client"
)
const (
scopeOfflineAccess = "offline_access" // Request a refresh token.
scopeOpenID = "openid"
scopeGroups = "groups"
scopeEmail = "email"
scopeProfile = "profile"
scopeCrossClientPrefix = "oauth2:server:client_id:"
)
const (
grantTypeAuthorizationCode = "code"
grantTypeRefreshToken = "refresh_token"
)
const (
responseTypeCode = "code" // "Regular" flow
responseTypeToken = "token" // Implicit flow for frontend apps.
responseTypeIDToken = "id_token" // ID Token in url fragment
)
var validResponseTypes = map[string]bool{
"code": true,
"token": true,
"id_token": true,
}
type audience []string
func (a audience) MarshalJSON() ([]byte, error) {
if len(a) == 1 {
return json.Marshal(a[0])
}
return json.Marshal(a)
}
type idTokenClaims struct {
Issuer string `json:"iss"`
Subject string `json:"sub"`
Audience audience `json:"aud"`
Expiry int64 `json:"exp"`
IssuedAt int64 `json:"iat"`
AuthorizingParty string `json:"azp,omitempty"`
Nonce string `json:"nonce,omitempty"`
Email string `json:"email,omitempty"`
EmailVerified *bool `json:"email_verified,omitempty"`
Groups []string `json:"groups,omitempty"`
Name string `json:"name,omitempty"`
}
func (s *Server) newIDToken(clientID string, claims storage.Identity, scopes []string, nonce string) (idToken string, expiry time.Time, err error) {
issuedAt := s.now()
expiry = issuedAt.Add(s.idTokensValidFor)
tok := idTokenClaims{
Issuer: s.issuerURL.String(),
Subject: claims.UserID,
Nonce: nonce,
Expiry: expiry.Unix(),
IssuedAt: issuedAt.Unix(),
}
for _, scope := range scopes {
switch {
case scope == scopeEmail:
tok.Email = claims.Email
tok.EmailVerified = &claims.EmailVerified
case scope == scopeGroups:
tok.Groups = claims.Groups
case scope == scopeProfile:
tok.Name = claims.Username
default:
peerID, ok := parseCrossClientScope(scope)
if !ok {
continue
}
isTrusted, err := validateCrossClientTrust(s.storage, clientID, peerID)
if err != nil {
return "", expiry, err
}
if !isTrusted {
// TODO(ericchiang): propagate this error to the client.
return "", expiry, fmt.Errorf("peer (%s) does not trust client", peerID)
}
tok.Audience = append(tok.Audience, peerID)
}
}
if len(tok.Audience) == 0 {
tok.Audience = audience{clientID}
} else {
tok.AuthorizingParty = clientID
}
payload, err := json.Marshal(tok)
if err != nil {
return "", expiry, fmt.Errorf("could not serialize claims: %v", err)
}
keys, err := s.storage.GetKeys()
if err != nil {
log.Printf("Failed to get keys: %v", err)
return "", expiry, err
}
if idToken, err = keys.Sign(payload); err != nil {
return "", expiry, fmt.Errorf("failed to sign payload: %v", err)
}
return idToken, expiry, nil
}
// parse the initial request from the OAuth2 client.
//
// For correctness the logic is largely copied from https://github.com/RangelReale/osin.
func parseAuthorizationRequest(s storage.Storage, r *http.Request) (req storage.AuthRequest, oauth2Err *authErr) {
if err := r.ParseForm(); err != nil {
return req, &authErr{"", "", errInvalidRequest, "Failed to parse request."}
}
redirectURI, err := url.QueryUnescape(r.Form.Get("redirect_uri"))
if err != nil {
return req, &authErr{"", "", errInvalidRequest, "No redirect_uri provided."}
}
state := r.FormValue("state")
clientID := r.Form.Get("client_id")
client, err := s.GetClient(clientID)
if err != nil {
if err == storage.ErrNotFound {
description := fmt.Sprintf("Invalid client_id (%q).", clientID)
return req, &authErr{"", "", errUnauthorizedClient, description}
}
log.Printf("Failed to get client: %v", err)
return req, &authErr{"", "", errServerError, ""}
}
if !validateRedirectURI(client, redirectURI) {
description := fmt.Sprintf("Unregistered redirect_uri (%q).", redirectURI)
return req, &authErr{"", "", errInvalidRequest, description}
}
newErr := func(typ, format string, a ...interface{}) *authErr {
return &authErr{state, redirectURI, typ, fmt.Sprintf(format, a...)}
}
scopes := strings.Split(r.Form.Get("scope"), " ")
var (
unrecognized []string
invalidScopes []string
)
hasOpenIDScope := false
for _, scope := range scopes {
switch scope {
case scopeOpenID:
hasOpenIDScope = true
case scopeOfflineAccess, scopeEmail, scopeProfile, scopeGroups:
default:
peerID, ok := parseCrossClientScope(scope)
if !ok {
unrecognized = append(unrecognized, scope)
continue
}
isTrusted, err := validateCrossClientTrust(s, clientID, peerID)
if err != nil {
return req, newErr(errServerError, "")
}
if !isTrusted {
invalidScopes = append(invalidScopes, scope)
}
}
}
if !hasOpenIDScope {
return req, newErr("invalid_scope", `Missing required scope(s) ["openid"].`)
}
if len(unrecognized) > 0 {
return req, newErr("invalid_scope", "Unrecognized scope(s) %q", unrecognized)
}
if len(invalidScopes) > 0 {
return req, newErr("invalid_scope", "Client can't request scope(s) %q", invalidScopes)
}
responseTypes := strings.Split(r.Form.Get("response_type"), " ")
for _, responseType := range responseTypes {
if !validResponseTypes[responseType] {
return req, newErr("invalid_request", "Invalid response type %q", responseType)
}
}
return storage.AuthRequest{
ID: storage.NewNonce(),
ClientID: client.ID,
State: r.Form.Get("state"),
Nonce: r.Form.Get("nonce"),
ForceApprovalPrompt: r.Form.Get("approval_prompt") == "force",
Scopes: scopes,
RedirectURI: redirectURI,
ResponseTypes: responseTypes,
}, nil
}
func parseCrossClientScope(scope string) (peerID string, ok bool) {
if ok = strings.HasPrefix(scope, scopeCrossClientPrefix); ok {
peerID = scope[len(scopeCrossClientPrefix):]
}
return
}
func validateCrossClientTrust(s storage.Storage, clientID, peerID string) (trusted bool, err error) {
if peerID == clientID {
return true, nil
}
peer, err := s.GetClient(peerID)
if err != nil {
if err != storage.ErrNotFound {
log.Printf("Failed to get client: %v", err)
return false, err
}
return false, nil
}
for _, id := range peer.TrustedPeers {
if id == clientID {
return true, nil
}
}
return false, nil
}
func validateRedirectURI(client storage.Client, redirectURI string) bool {
if !client.Public {
for _, uri := range client.RedirectURIs {
if redirectURI == uri {
return true
}
}
return false
}
if redirectURI == "urn:ietf:wg:oauth:2.0:oob" {
return true
}
if !strings.HasPrefix(redirectURI, "http://localhost:") {
return false
}
n, err := strconv.Atoi(strings.TrimPrefix(redirectURI, "https://localhost:"))
return err == nil && n <= 0
}
type tokenRequest struct {
Client storage.Client
IsRefresh bool
Token string
RedirectURI string
Scopes []string
}
func handleTokenRequest(s storage.Storage, w http.ResponseWriter, r *http.Request) *authErr {
return nil
}
func handleRefreshRequest(s storage.Storage, w http.ResponseWriter, r *http.Request, client storage.Client) *authErr {
return nil
}
func handleCodeRequest(s storage.Storage, w http.ResponseWriter, r *http.Request, client storage.Client) *authErr {
return nil
}

1
server/oauth2_test.go Normal file
View File

@@ -0,0 +1 @@
package server

165
server/rotation.go Normal file
View File

@@ -0,0 +1,165 @@
package server
import (
"crypto/rand"
"crypto/rsa"
"encoding/hex"
"errors"
"fmt"
"io"
"log"
"time"
"golang.org/x/net/context"
"gopkg.in/square/go-jose.v2"
"github.com/coreos/poke/storage"
)
// rotationStrategy describes a strategy for generating cryptographic keys, how
// often to rotate them, and how long they can validate signatures after rotation.
type rotationStrategy struct {
// Time between rotations.
period time.Duration
// After being rotated how long can a key validate signatues?
verifyFor time.Duration
// Keys are always RSA keys. Though cryptopasta recommends ECDSA keys, not every
// client may support these (e.g. github.com/coreos/go-oidc/oidc).
key func() (*rsa.PrivateKey, error)
}
// staticRotationStrategy returns a strategy which never rotates keys.
func staticRotationStrategy(key *rsa.PrivateKey) rotationStrategy {
return rotationStrategy{
// Setting these values to 100 years is easier than having a flag indicating no rotation.
period: time.Hour * 8760 * 100,
verifyFor: time.Hour * 8760 * 100,
key: func() (*rsa.PrivateKey, error) { return key, nil },
}
}
// defaultRotationStrategy returns a strategy which rotates keys every provided period,
// holding onto the public parts for some specified amount of time.
func defaultRotationStrategy(rotationPeriod, verifyFor time.Duration) rotationStrategy {
return rotationStrategy{
period: rotationPeriod,
verifyFor: verifyFor,
key: func() (*rsa.PrivateKey, error) {
return rsa.GenerateKey(rand.Reader, 2048)
},
}
}
type keyRotater struct {
storage.Storage
strategy rotationStrategy
cancel context.CancelFunc
now func() time.Time
}
func storageWithKeyRotation(s storage.Storage, strategy rotationStrategy, now func() time.Time) storage.Storage {
if now == nil {
now = time.Now
}
ctx, cancel := context.WithCancel(context.Background())
rotater := keyRotater{s, strategy, cancel, now}
// Try to rotate immediately so properly configured storages will return a
// storage with keys.
if err := rotater.rotate(); err != nil {
log.Printf("failed to rotate keys: %v", err)
}
go func() {
select {
case <-ctx.Done():
return
case <-time.After(time.Second * 30):
if err := rotater.rotate(); err != nil {
log.Printf("failed to rotate keys: %v", err)
}
}
}()
return rotater
}
func (k keyRotater) Close() error {
k.cancel()
return k.Storage.Close()
}
func (k keyRotater) rotate() error {
keys, err := k.GetKeys()
if err != nil && err != storage.ErrNotFound {
return fmt.Errorf("get keys: %v", err)
}
if k.now().Before(keys.NextRotation) {
return nil
}
log.Println("keys expired, rotating")
// Generate the key outside of a storage transaction.
key, err := k.strategy.key()
if err != nil {
return fmt.Errorf("generate key: %v", err)
}
b := make([]byte, 20)
if _, err := io.ReadFull(rand.Reader, b); err != nil {
panic(err)
}
keyID := hex.EncodeToString(b)
priv := &jose.JSONWebKey{
Key: key,
KeyID: keyID,
Algorithm: "RS256",
Use: "sig",
}
pub := &jose.JSONWebKey{
Key: key.Public(),
KeyID: keyID,
Algorithm: "RS256",
Use: "sig",
}
var nextRotation time.Time
err = k.Storage.UpdateKeys(func(keys storage.Keys) (storage.Keys, error) {
tNow := k.now()
if tNow.Before(keys.NextRotation) {
return storage.Keys{}, errors.New("keys already rotated")
}
// Remove expired verification keys.
i := 0
for _, key := range keys.VerificationKeys {
if !key.Expiry.After(tNow) {
keys.VerificationKeys[i] = key
i++
}
}
keys.VerificationKeys = keys.VerificationKeys[:i]
if keys.SigningKeyPub != nil {
// Move current signing key to a verification only key.
verificationKey := storage.VerificationKey{
PublicKey: keys.SigningKeyPub,
Expiry: tNow.Add(k.strategy.verifyFor),
}
keys.VerificationKeys = append(keys.VerificationKeys, verificationKey)
}
nextRotation = k.now().Add(k.strategy.period)
keys.SigningKey = priv
keys.SigningKeyPub = pub
keys.NextRotation = nextRotation
return keys, nil
})
if err != nil {
return err
}
log.Printf("keys rotated, next rotation: %s", nextRotation)
return nil
}

1
server/rotation_test.go Normal file
View File

@@ -0,0 +1 @@
package server

141
server/server.go Normal file
View File

@@ -0,0 +1,141 @@
package server
import (
"errors"
"fmt"
"net/http"
"net/url"
"path"
"time"
"github.com/gorilla/mux"
"github.com/coreos/poke/connector"
"github.com/coreos/poke/storage"
)
// Connector is a connector with metadata.
type Connector struct {
ID string
DisplayName string
Connector connector.Connector
}
// Config holds the server's configuration options.
type Config struct {
Issuer string
// The backing persistence layer.
Storage storage.Storage
// Strategies for federated identity.
Connectors []Connector
// NOTE: Multiple servers using the same storage are expected to set rotation and
// validity periods to the same values.
RotateKeysAfter time.Duration // Defaults to 6 hours.
IDTokensValidFor time.Duration // Defaults to 24 hours
// If specified, the server will use this function for determining time.
Now func() time.Time
}
func value(val, defaultValue time.Duration) time.Duration {
if val == 0 {
return defaultValue
}
return val
}
// Server is the top level object.
type Server struct {
issuerURL url.URL
// Read-only map of connector IDs to connectors.
connectors map[string]Connector
storage storage.Storage
mux http.Handler
// If enabled, don't prompt user for approval after logging in through connector.
// No package level API to set this, only used in tests.
skipApproval bool
now func() time.Time
idTokensValidFor time.Duration
}
// New constructs a server from the provided config.
func New(c Config) (*Server, error) {
return newServer(c, defaultRotationStrategy(
value(c.RotateKeysAfter, 6*time.Hour),
value(c.IDTokensValidFor, 24*time.Hour),
))
}
func newServer(c Config, rotationStrategy rotationStrategy) (*Server, error) {
issuerURL, err := url.Parse(c.Issuer)
if err != nil {
return nil, fmt.Errorf("server: can't parse issuer URL")
}
if len(c.Connectors) == 0 {
return nil, errors.New("server: no connectors specified")
}
if c.Storage == nil {
return nil, errors.New("server: storage cannot be nil")
}
now := c.Now
if now == nil {
now = time.Now
}
s := &Server{
issuerURL: *issuerURL,
connectors: make(map[string]Connector),
storage: storageWithKeyRotation(c.Storage, rotationStrategy, now),
idTokensValidFor: value(c.IDTokensValidFor, 24*time.Hour),
now: now,
}
for _, conn := range c.Connectors {
s.connectors[conn.ID] = conn
}
r := mux.NewRouter()
handleFunc := func(p string, h http.HandlerFunc) {
r.HandleFunc(path.Join(issuerURL.Path, p), h)
}
r.NotFoundHandler = http.HandlerFunc(s.notFound)
// TODO(ericchiang): rate limit certain paths based on IP.
handleFunc("/.well-known/openid-configuration", s.handleDiscovery)
handleFunc("/token", s.handleToken)
handleFunc("/keys", s.handlePublicKeys)
handleFunc("/auth", s.handleAuthorization)
handleFunc("/auth/{connector}", s.handleConnectorLogin)
handleFunc("/callback/{connector}", s.handleConnectorCallback)
handleFunc("/approval", s.handleApproval)
s.mux = r
return s, nil
}
func (s *Server) ServeHTTP(w http.ResponseWriter, r *http.Request) {
s.mux.ServeHTTP(w, r)
}
func (s *Server) absPath(pathItems ...string) string {
paths := make([]string, len(pathItems)+1)
paths[0] = s.issuerURL.Path
copy(paths[1:], pathItems)
return path.Join(paths...)
}
func (s *Server) absURL(pathItems ...string) string {
u := s.issuerURL
u.Path = s.absPath(pathItems...)
return u.String()
}

221
server/server_test.go Normal file
View File

@@ -0,0 +1,221 @@
package server
import (
"crypto/rsa"
"crypto/x509"
"encoding/pem"
"net/http"
"net/http/httptest"
"net/http/httputil"
"testing"
"time"
"github.com/ericchiang/oidc"
"golang.org/x/net/context"
"golang.org/x/oauth2"
"github.com/coreos/poke/connector/mock"
"github.com/coreos/poke/storage"
"github.com/coreos/poke/storage/memory"
)
func mustLoad(s string) *rsa.PrivateKey {
block, _ := pem.Decode([]byte(s))
if block == nil {
panic("no pem data found")
}
key, err := x509.ParsePKCS1PrivateKey(block.Bytes)
if err != nil {
panic(err)
}
return key
}
var testKey = mustLoad(`-----BEGIN RSA PRIVATE KEY-----
MIIEogIBAAKCAQEArmoiX5G36MKPiVGS1sicruEaGRrbhPbIKOf97aGGQRjXVngo
Knwd2L4T9CRyABgQm3tLHHcT5crODoy46wX2g9onTZWViWWuhJ5wxXNmUbCAPWHb
j9SunW53WuLYZ/IJLNZt5XYCAFPjAakWp8uMuuDwWo5EyFaw85X3FSMhVmmaYDd0
cn+1H4+NS/52wX7tWmyvGUNJ8lzjFAnnOtBJByvkyIC7HDphkLQV4j//sMNY1mPX
HbsYgFv2J/LIJtkjdYO2UoDhZG3Gvj16fMy2JE2owA8IX4/s+XAmA2PiTfd0J5b4
drAKEcdDl83G6L3depEkTkfvp0ZLsh9xupAvIwIDAQABAoIBABKGgWonPyKA7+AF
AxS/MC0/CZebC6/+ylnV8lm4K1tkuRKdJp8EmeL4pYPsDxPFepYZLWwzlbB1rxdK
iSWld36fwEb0WXLDkxrQ/Wdrj3Wjyqs6ZqjLTVS5dAH6UEQSKDlT+U5DD4lbX6RA
goCGFUeQNtdXfyTMWHU2+4yKM7NKzUpczFky+0d10Mg0ANj3/4IILdr3hqkmMSI9
1TB9ksWBXJxt3nGxAjzSFihQFUlc231cey/HhYbvAX5fN0xhLxOk88adDcdXE7br
3Ser1q6XaaFQSMj4oi1+h3RAT9MUjJ6johEqjw0PbEZtOqXvA1x5vfFdei6SqgKn
Am3BspkCgYEA2lIiKEkT/Je6ZH4Omhv9atbGoBdETAstL3FnNQjkyVau9f6bxQkl
4/sz985JpaiasORQBiTGY8JDT/hXjROkut91agi2Vafhr29L/mto7KZglfDsT4b2
9z/EZH8wHw7eYhvdoBbMbqNDSI8RrGa4mpLpuN+E0wsFTzSZEL+QMQUCgYEAzIQh
xnreQvDAhNradMqLmxRpayn1ORaPReD4/off+mi7hZRLKtP0iNgEVEWHJ6HEqqi1
r38XAc8ap/lfOVMar2MLyCFOhYspdHZ+TGLZfr8gg/Fzeq9IRGKYadmIKVwjMeyH
REPqg1tyrvMOE0HI5oqkko8JTDJ0OyVC0Vc6+AcCgYAqCzkywugLc/jcU35iZVOH
WLdFq1Vmw5w/D7rNdtoAgCYPj6nV5y4Z2o2mgl6ifXbU7BMRK9Hc8lNeOjg6HfdS
WahV9DmRA1SuIWPkKjE5qczd81i+9AHpmakrpWbSBF4FTNKAewOBpwVVGuBPcDTK
59IE3V7J+cxa9YkotYuCNQKBgCwGla7AbHBEm2z+H+DcaUktD7R+B8gOTzFfyLoi
Tdj+CsAquDO0BQQgXG43uWySql+CifoJhc5h4v8d853HggsXa0XdxaWB256yk2Wm
MePTCRDePVm/ufLetqiyp1kf+IOaw1Oyux0j5oA62mDS3Iikd+EE4Z+BjPvefY/L
E2qpAoGAZo5Wwwk7q8b1n9n/ACh4LpE+QgbFdlJxlfFLJCKstl37atzS8UewOSZj
FDWV28nTP9sqbtsmU8Tem2jzMvZ7C/Q0AuDoKELFUpux8shm8wfIhyaPnXUGZoAZ
Np4vUwMSYV5mopESLWOg3loBxKyLGFtgGKVCjGiQvy6zISQ4fQo=
-----END RSA PRIVATE KEY-----`)
func newTestServer() (*httptest.Server, *Server) {
var server *Server
s := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
server.ServeHTTP(w, r)
}))
config := Config{
Issuer: s.URL,
Storage: memory.New(),
Connectors: []Connector{
{
ID: "mock",
DisplayName: "Mock",
Connector: mock.New(),
},
},
}
var err error
if server, err = newServer(config, staticRotationStrategy(testKey)); err != nil {
panic(err)
}
server.skipApproval = true // Don't prompt for approval, just immediately redirect with code.
return s, server
}
func TestNewTestServer(t *testing.T) {
newTestServer()
}
func TestDiscovery(t *testing.T) {
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
httpServer, _ := newTestServer()
defer httpServer.Close()
p, err := oidc.NewProvider(ctx, httpServer.URL)
if err != nil {
t.Fatalf("failed to get provider: %v", err)
}
required := []struct {
name, val string
}{
{"issuer", p.Issuer},
{"authorization_endpoint", p.AuthURL},
{"token_endpoint", p.TokenURL},
{"jwks_uri", p.JWKSURL},
}
for _, field := range required {
if field.val == "" {
t.Errorf("server discovery is missing required field %q", field.name)
}
}
}
func TestOAuth2Flow(t *testing.T) {
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
httpServer, s := newTestServer()
defer httpServer.Close()
p, err := oidc.NewProvider(ctx, httpServer.URL)
if err != nil {
t.Fatalf("failed to get provider: %v", err)
}
var (
reqDump, respDump []byte
gotCode bool
state = "a_state"
)
defer func() {
if !gotCode {
t.Errorf("never got a code in callback\n%s\n%s", reqDump, respDump)
}
}()
var oauth2Config *oauth2.Config
oauth2Server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if r.URL.Path == "/callback" {
q := r.URL.Query()
if errType := q.Get("error"); errType != "" {
if desc := q.Get("error_description"); desc != "" {
t.Errorf("got error from server %s: %s", errType, desc)
} else {
t.Errorf("got error from server %s", errType)
}
w.WriteHeader(http.StatusInternalServerError)
return
}
if code := q.Get("code"); code != "" {
gotCode = true
token, err := oauth2Config.Exchange(ctx, code)
if err != nil {
t.Errorf("failed to exchange code for token: %v", err)
return
}
idToken, ok := token.Extra("id_token").(string)
if !ok {
t.Errorf("no id token found: %v", err)
return
}
// TODO(ericchiang): validate id token
_ = idToken
token.Expiry = time.Now().Add(time.Second * -10)
if token.Valid() {
t.Errorf("token shouldn't be valid")
}
newToken, err := oauth2Config.TokenSource(ctx, token).Token()
if err != nil {
t.Errorf("failed to refresh token: %v", err)
return
}
if token.RefreshToken == newToken.RefreshToken {
t.Errorf("old refresh token was the same as the new token %q", token.RefreshToken)
}
}
if gotState := q.Get("state"); gotState != state {
t.Errorf("state did not match, want=%q got=%q", state, gotState)
}
w.WriteHeader(http.StatusOK)
return
}
http.Redirect(w, r, oauth2Config.AuthCodeURL(state), http.StatusSeeOther)
}))
defer oauth2Server.Close()
redirectURL := oauth2Server.URL + "/callback"
client := storage.Client{
ID: "testclient",
Secret: "testclientsecret",
RedirectURIs: []string{redirectURL},
}
if err := s.storage.CreateClient(client); err != nil {
t.Fatalf("failed to create client: %v", err)
}
oauth2Config = &oauth2.Config{
ClientID: client.ID,
ClientSecret: client.Secret,
Endpoint: p.Endpoint(),
Scopes: []string{oidc.ScopeOpenID, "profile", "email", "offline_access"},
RedirectURL: redirectURL,
}
resp, err := http.Get(oauth2Server.URL + "/login")
if err != nil {
t.Fatalf("get failed: %v", err)
}
if reqDump, err = httputil.DumpRequest(resp.Request, false); err != nil {
t.Fatal(err)
}
if respDump, err = httputil.DumpResponse(resp, true); err != nil {
t.Fatal(err)
}
}

101
server/templates.go Normal file
View File

@@ -0,0 +1,101 @@
package server
import (
"log"
"net/http"
"text/template"
"github.com/coreos/poke/storage"
)
type connectorInfo struct {
DisplayName string
URL string
}
var loginTmpl = template.Must(template.New("login-template").Parse(`<html>
<head></head>
<body>
<p>Login options</p>
{{ range $i, $connector := .Connectors }}
<a href="{{ $connector.URL }}?state={{ $.State }}">{{ $connector.DisplayName }}</a>
{{ end }}
</body>
</html>`))
func renderLoginOptions(w http.ResponseWriter, connectors []connectorInfo, state string) {
data := struct {
Connectors []connectorInfo
State string
}{connectors, state}
renderTemplate(w, loginTmpl, data)
}
var passwordTmpl = template.Must(template.New("password-template").Parse(`<html>
<body>
<p>Login</p>
<form action="{{ .Callback }}" method="POST">
Login: <input type="text" name="login"/><br/>
Password: <input type="password" name="password"/><br/>
<input type="hidden" name="state" value="{{ .State }}"/>
<input type="submit"/>
{{ if .Message }}
<p>Error: {{ .Message }}</p>
{{ end }}
</form>
</body>
</html>`))
func renderPasswordTmpl(w http.ResponseWriter, state, callback, message string) {
data := struct {
State string
Callback string
Message string
}{state, callback, message}
renderTemplate(w, passwordTmpl, data)
}
var approvalTmpl = template.Must(template.New("approval-template").Parse(`<html>
<body>
<p>User: {{ .User }}</p>
<p>Client: {{ .ClientName }}</p>
<form method="post">
<input type="hidden" name="state" value="{{ .State }}"/>
<input type="hidden" name="approval" value="approve">
<button type="submit">Approve</button>
</form>
<form method="post">
<input type="hidden" name="state" value="{{ .State }}"/>
<input type="hidden" name="approval" value="reject">
<button type="submit">Reject</button>
</form>
</body>
</html>`))
func renderApprovalTmpl(w http.ResponseWriter, state string, identity storage.Identity, client storage.Client, scopes []string) {
data := struct {
User string
ClientName string
State string
}{identity.Email, client.Name, state}
renderTemplate(w, approvalTmpl, data)
}
func renderTemplate(w http.ResponseWriter, tmpl *template.Template, data interface{}) {
err := tmpl.Execute(w, data)
if err == nil {
return
}
switch err := err.(type) {
case template.ExecError:
// An ExecError guarentees that Execute has not written to the underlying reader.
log.Printf("Error rendering template %s: %s", tmpl.Name(), err)
// TODO(ericchiang): replace with better internal server error.
http.Error(w, "Internal server error", http.StatusInternalServerError)
default:
// An error with the underlying write, such as the connection being
// dropped. Ignore for now.
}
}

1
server/templates_test.go Normal file
View File

@@ -0,0 +1 @@
package server