diff --git a/Makefile b/Makefile index 44c5406e..b585092e 100644 --- a/Makefile +++ b/Makefile @@ -21,11 +21,14 @@ else endif -build: bin/poke +build: bin/poke bin/example-app bin/poke: FORCE @go install $(REPO_PATH)/cmd/poke +bin/example-app: FORCE + @go install $(REPO_PATH)/cmd/example-app + test: @go test $(shell go list ./... | grep -v '/vendor/') diff --git a/cmd/example-app/main.go b/cmd/example-app/main.go new file mode 100644 index 00000000..60223e4a --- /dev/null +++ b/cmd/example-app/main.go @@ -0,0 +1,225 @@ +package main + +import ( + "bytes" + "crypto/tls" + "crypto/x509" + "encoding/json" + "errors" + "fmt" + "io/ioutil" + "log" + "net" + "net/http" + "net/url" + "os" + "strings" + "time" + + "github.com/ericchiang/oidc" + "github.com/spf13/cobra" + "golang.org/x/net/context" + "golang.org/x/oauth2" +) + +type app struct { + clientID string + clientSecret string + redirectURI string + + verifier *oidc.IDTokenVerifier + provider *oidc.Provider + + ctx context.Context + cancel context.CancelFunc +} + +// return an HTTP client which trusts the provided root CAs. +func httpClientForRootCAs(rootCAs string) (*http.Client, error) { + tlsConfig := tls.Config{RootCAs: x509.NewCertPool()} + rootCABytes, err := ioutil.ReadFile(rootCAs) + if err != nil { + return nil, fmt.Errorf("failed to read root-ca: %v", err) + } + if !tlsConfig.RootCAs.AppendCertsFromPEM(rootCABytes) { + return nil, fmt.Errorf("no certs found in root CA file %q", rootCAs) + } + return &http.Client{ + Transport: &http.Transport{ + TLSClientConfig: &tlsConfig, + Proxy: http.ProxyFromEnvironment, + Dial: (&net.Dialer{ + Timeout: 30 * time.Second, + KeepAlive: 30 * time.Second, + }).Dial, + TLSHandshakeTimeout: 10 * time.Second, + ExpectContinueTimeout: 1 * time.Second, + }, + }, nil +} + +func cmd() *cobra.Command { + var ( + a app + issuerURL string + listen string + tlsCert string + tlsKey string + rootCAs string + ) + c := cobra.Command{ + Use: "example-app", + Short: "An example OpenID Connect client", + Long: "", + RunE: func(cmd *cobra.Command, args []string) error { + if len(args) != 0 { + return errors.New("surplus arguments provided") + } + + u, err := url.Parse(a.redirectURI) + if err != nil { + return fmt.Errorf("parse redirect-uri: %v", err) + } + listenURL, err := url.Parse(listen) + if err != nil { + return fmt.Errorf("parse listen address: %v", err) + } + + a.ctx, a.cancel = context.WithCancel(context.Background()) + + if rootCAs != "" { + client, err := httpClientForRootCAs(rootCAs) + if err != nil { + return err + } + + // This sets the OAuth2 client and oidc client. + a.ctx = context.WithValue(a.ctx, oauth2.HTTPClient, &client) + } + + // TODO(ericchiang): Retry with backoff + provider, err := oidc.NewProvider(a.ctx, issuerURL) + if err != nil { + return fmt.Errorf("Failed to query provider %q: %v", issuerURL, err) + } + a.provider = provider + a.verifier = provider.NewVerifier(a.ctx, oidc.VerifyAudience(a.clientID)) + + http.HandleFunc("/", a.handleIndex) + http.HandleFunc("/login", a.handleLogin) + http.HandleFunc(u.Path, a.handleCallback) + + switch listenURL.Scheme { + case "http": + log.Printf("listening on %s", listen) + return http.ListenAndServe(listenURL.Host, nil) + case "https": + log.Printf("listening on %s", listen) + return http.ListenAndServeTLS(listenURL.Host, tlsCert, tlsKey, nil) + default: + return fmt.Errorf("listen address %q is not using http or https", listen) + } + }, + } + c.Flags().StringVar(&a.clientID, "client-id", "example-app", "OAuth2 client ID of this application.") + c.Flags().StringVar(&a.clientSecret, "client-secret", "ZXhhbXBsZS1hcHAtc2VjcmV0", "OAuth2 client secret of this application.") + c.Flags().StringVar(&a.redirectURI, "redirect-uri", "http://127.0.0.1:5555/callback", "Callback URL for OAuth2 responses.") + c.Flags().StringVar(&issuerURL, "issuer", "http://127.0.0.1:5556", "URL of the OpenID Connect issuer.") + c.Flags().StringVar(&listen, "listen", "http://127.0.0.1:5555", "HTTP(S) address to listen at.") + c.Flags().StringVar(&tlsCert, "tls-cert", "", "X509 cert file to present when serving HTTPS.") + c.Flags().StringVar(&tlsKey, "tls-key", "", "Private key for the HTTPS cert.") + c.Flags().StringVar(&rootCAs, "issuer-root-ca", "", "Root certificate authorities for the issuer. Defaults to host certs.") + return &c +} + +func main() { + if err := cmd().Execute(); err != nil { + fmt.Fprintf(os.Stderr, "error: %v\n", err) + os.Exit(2) + } +} + +func (a *app) handleIndex(w http.ResponseWriter, r *http.Request) { + renderIndex(w) +} + +func (a *app) oauth2Config(scopes []string) *oauth2.Config { + return &oauth2.Config{ + ClientID: a.clientID, + ClientSecret: a.clientSecret, + Endpoint: a.provider.Endpoint(), + Scopes: scopes, + RedirectURL: a.redirectURI, + } +} + +func (a *app) handleLogin(w http.ResponseWriter, r *http.Request) { + var scopes []string + if extraScopes := r.FormValue("extra_scopes"); extraScopes != "" { + scopes = strings.Split(extraScopes, " ") + } + var clients []string + if crossClients := r.FormValue("cross_client"); crossClients != "" { + clients = strings.Split(crossClients, " ") + } + for _, client := range clients { + scopes = append(scopes, "audience:server:client_id:"+client) + } + + // TODO(ericchiang): Determine if provider does not support "offline_access" or has + // some other mechanism for requesting refresh tokens. + scopes = append(scopes, "openid", "profile", "email", "offline_access") + http.Redirect(w, r, a.oauth2Config(scopes).AuthCodeURL(""), http.StatusSeeOther) +} + +func (a *app) handleCallback(w http.ResponseWriter, r *http.Request) { + if errMsg := r.FormValue("error"); errMsg != "" { + http.Error(w, errMsg+": "+r.FormValue("error_description"), http.StatusBadRequest) + return + } + + code := r.FormValue("code") + refresh := r.FormValue("refresh_token") + var ( + err error + token *oauth2.Token + ) + oauth2Config := a.oauth2Config(nil) + switch { + case code != "": + token, err = oauth2Config.Exchange(a.ctx, code) + case refresh != "": + t := &oauth2.Token{ + RefreshToken: refresh, + Expiry: time.Now().Add(-time.Hour), + } + token, err = oauth2Config.TokenSource(a.ctx, t).Token() + default: + http.Error(w, "no code in request", http.StatusBadRequest) + return + } + + if err != nil { + http.Error(w, fmt.Sprintf("failed to get token: %v", err), http.StatusInternalServerError) + return + } + + rawIDToken, ok := token.Extra("id_token").(string) + if !ok { + http.Error(w, "no id_token in token response", http.StatusInternalServerError) + return + } + + idToken, err := a.verifier.Verify(rawIDToken) + if err != nil { + http.Error(w, fmt.Sprintf("Failed to verify ID token: %v", err), http.StatusInternalServerError) + return + } + var claims json.RawMessage + idToken.Claims(&claims) + + buff := new(bytes.Buffer) + json.Indent(buff, []byte(claims), "", " ") + + renderToken(w, a.redirectURI, rawIDToken, token.RefreshToken, buff.Bytes()) +} diff --git a/cmd/example-app/templates.go b/cmd/example-app/templates.go new file mode 100644 index 00000000..34bb23e4 --- /dev/null +++ b/cmd/example-app/templates.go @@ -0,0 +1,82 @@ +package main + +import ( + "html/template" + "log" + "net/http" +) + +var indexTmpl = template.Must(template.New("index.html").Parse(` +
+ + +`)) + +func renderIndex(w http.ResponseWriter) { + renderTemplate(w, indexTmpl, nil) +} + +type tokenTmplData struct { + IDToken string + RefreshToken string + RedirectURL string + Claims string +} + +var tokenTmpl = template.Must(template.New("token.html").Parse(` + + + + +Token:
{{ .IDToken }}
+ Claims:
{{ .Claims }}
+ Refresh Token:
{{ .RefreshToken }}
+ + + +`)) + +func renderToken(w http.ResponseWriter, redirectURL, idToken, refreshToken string, claims []byte) { + renderTemplate(w, tokenTmpl, tokenTmplData{ + IDToken: idToken, + RefreshToken: refreshToken, + RedirectURL: redirectURL, + Claims: string(claims), + }) +} + +func renderTemplate(w http.ResponseWriter, tmpl *template.Template, data interface{}) { + err := tmpl.Execute(w, data) + if err == nil { + return + } + + switch err := err.(type) { + case *template.Error: + // An ExecError guarentees that Execute has not written to the underlying reader. + log.Printf("Error rendering template %s: %s", tmpl.Name(), err) + + // TODO(ericchiang): replace with better internal server error. + http.Error(w, "Internal server error", http.StatusInternalServerError) + default: + // An error with the underlying write, such as the connection being + // dropped. Ignore for now. + } +}