From 6dbb5c4de62f1b19017ca5bc4af9c022db207617 Mon Sep 17 00:00:00 2001 From: Eric Chiang Date: Fri, 7 Oct 2016 11:27:18 -0700 Subject: [PATCH] server: fix cross client scope prefix --- server/oauth2.go | 4 +- server/server_test.go | 124 ++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 126 insertions(+), 2 deletions(-) diff --git a/server/oauth2.go b/server/oauth2.go index e236e0c4..003d8947 100644 --- a/server/oauth2.go +++ b/server/oauth2.go @@ -74,7 +74,7 @@ const ( scopeGroups = "groups" scopeEmail = "email" scopeProfile = "profile" - scopeCrossClientPrefix = "oauth2:server:client_id:" + scopeCrossClientPrefix = "audience:server:client_id:" ) const ( @@ -98,7 +98,7 @@ func (a audience) MarshalJSON() ([]byte, error) { if len(a) == 1 { return json.Marshal(a[0]) } - return json.Marshal(a) + return json.Marshal([]string(a)) } type idTokenClaims struct { diff --git a/server/server_test.go b/server/server_test.go index 46fcc710..ea0793aa 100644 --- a/server/server_test.go +++ b/server/server_test.go @@ -10,6 +10,8 @@ import ( "net/http/httptest" "net/http/httputil" "net/url" + "reflect" + "sort" "strings" "sync" "testing" @@ -384,6 +386,128 @@ func TestOAuth2ImplicitFlow(t *testing.T) { } } +func TestCrossClientScopes(t *testing.T) { + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + httpServer, s := newTestServer(t, func(c *Config) { + c.Issuer = c.Issuer + "/non-root-path" + }) + defer httpServer.Close() + + p, err := oidc.NewProvider(ctx, httpServer.URL) + if err != nil { + t.Fatalf("failed to get provider: %v", err) + } + + var ( + reqDump, respDump []byte + gotCode bool + state = "a_state" + ) + defer func() { + if !gotCode { + t.Errorf("never got a code in callback\n%s\n%s", reqDump, respDump) + } + }() + + testClientID := "testclient" + peerID := "peer" + + var oauth2Config *oauth2.Config + oauth2Server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if r.URL.Path == "/callback" { + q := r.URL.Query() + if errType := q.Get("error"); errType != "" { + if desc := q.Get("error_description"); desc != "" { + t.Errorf("got error from server %s: %s", errType, desc) + } else { + t.Errorf("got error from server %s", errType) + } + w.WriteHeader(http.StatusInternalServerError) + return + } + + if code := q.Get("code"); code != "" { + gotCode = true + token, err := oauth2Config.Exchange(ctx, code) + if err != nil { + t.Errorf("failed to exchange code for token: %v", err) + return + } + rawIDToken, ok := token.Extra("id_token").(string) + if !ok { + t.Errorf("no id token found: %v", err) + return + } + idToken, err := p.NewVerifier(ctx).Verify(rawIDToken) + if err != nil { + t.Errorf("failed to parse ID Token: %v", err) + return + } + + sort.Strings(idToken.Audience) + expAudience := []string{peerID, testClientID} + if !reflect.DeepEqual(idToken.Audience, expAudience) { + t.Errorf("expected audience %q, got %q", expAudience, idToken.Audience) + } + + } + if gotState := q.Get("state"); gotState != state { + t.Errorf("state did not match, want=%q got=%q", state, gotState) + } + w.WriteHeader(http.StatusOK) + return + } + http.Redirect(w, r, oauth2Config.AuthCodeURL(state), http.StatusSeeOther) + })) + + defer oauth2Server.Close() + + redirectURL := oauth2Server.URL + "/callback" + client := storage.Client{ + ID: testClientID, + Secret: "testclientsecret", + RedirectURIs: []string{redirectURL}, + } + if err := s.storage.CreateClient(client); err != nil { + t.Fatalf("failed to create client: %v", err) + } + + peer := storage.Client{ + ID: peerID, + Secret: "foobar", + TrustedPeers: []string{"testclient"}, + } + + if err := s.storage.CreateClient(peer); err != nil { + t.Fatalf("failed to create client: %v", err) + } + + oauth2Config = &oauth2.Config{ + ClientID: client.ID, + ClientSecret: client.Secret, + Endpoint: p.Endpoint(), + Scopes: []string{ + oidc.ScopeOpenID, "profile", "email", + "audience:server:client_id:" + client.ID, + "audience:server:client_id:" + peer.ID, + }, + RedirectURL: redirectURL, + } + + resp, err := http.Get(oauth2Server.URL + "/login") + if err != nil { + t.Fatalf("get failed: %v", err) + } + if reqDump, err = httputil.DumpRequest(resp.Request, false); err != nil { + t.Fatal(err) + } + if respDump, err = httputil.DumpResponse(resp, true); err != nil { + t.Fatal(err) + } +} + func TestPasswordDB(t *testing.T) { s := memory.New() conn := newPasswordDB(s)