vendor: revendor using glide-vc
This commit is contained in:
21
vendor/github.com/coreos/go-oidc/example/README.md
generated
vendored
21
vendor/github.com/coreos/go-oidc/example/README.md
generated
vendored
@@ -1,21 +0,0 @@
|
||||
# Examples
|
||||
|
||||
These are example uses of the oidc package. Each requires a Google account and the client ID and secret of a registered OAuth2 application. To create one:
|
||||
|
||||
1. Visit your [Google Developer Console][google-developer-console].
|
||||
2. Click "Credentials" on the left column.
|
||||
3. Click the "Create credentials" button followed by "OAuth client ID".
|
||||
4. Select "Web application" and add "http://127.0.0.1:5556/auth/google/callback" as an authorized redirect URI.
|
||||
5. Click create and add the printed client ID and secret to your environment using the following variables:
|
||||
|
||||
```
|
||||
GOOGLE_OAUTH2_CLIENT_ID
|
||||
GOOGLE_OAUTH2_CLIENT_SECRET
|
||||
```
|
||||
|
||||
Finally run the examples using the Go tool and navigate to http://127.0.0.1:5556.
|
||||
|
||||
```
|
||||
go run ./examples/idtoken/app.go
|
||||
```
|
||||
[google-developer-console]: https://console.developers.google.com/apis/dashboard
|
||||
89
vendor/github.com/coreos/go-oidc/example/idtoken/app.go
generated
vendored
89
vendor/github.com/coreos/go-oidc/example/idtoken/app.go
generated
vendored
@@ -1,89 +0,0 @@
|
||||
/*
|
||||
This is an example application to demonstrate parsing an ID Token.
|
||||
*/
|
||||
package main
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"log"
|
||||
"net/http"
|
||||
"os"
|
||||
|
||||
oidc "github.com/coreos/go-oidc"
|
||||
|
||||
"golang.org/x/net/context"
|
||||
"golang.org/x/oauth2"
|
||||
)
|
||||
|
||||
var (
|
||||
clientID = os.Getenv("GOOGLE_OAUTH2_CLIENT_ID")
|
||||
clientSecret = os.Getenv("GOOGLE_OAUTH2_CLIENT_SECRET")
|
||||
)
|
||||
|
||||
func main() {
|
||||
ctx := context.Background()
|
||||
|
||||
provider, err := oidc.NewProvider(ctx, "https://accounts.google.com")
|
||||
if err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
verifier := provider.Verifier()
|
||||
|
||||
config := oauth2.Config{
|
||||
ClientID: clientID,
|
||||
ClientSecret: clientSecret,
|
||||
Endpoint: provider.Endpoint(),
|
||||
RedirectURL: "http://127.0.0.1:5556/auth/google/callback",
|
||||
Scopes: []string{oidc.ScopeOpenID, "profile", "email"},
|
||||
}
|
||||
|
||||
state := "foobar" // Don't do this in production.
|
||||
|
||||
http.HandleFunc("/", func(w http.ResponseWriter, r *http.Request) {
|
||||
http.Redirect(w, r, config.AuthCodeURL(state), http.StatusFound)
|
||||
})
|
||||
|
||||
http.HandleFunc("/auth/google/callback", func(w http.ResponseWriter, r *http.Request) {
|
||||
if r.URL.Query().Get("state") != state {
|
||||
http.Error(w, "state did not match", http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
|
||||
oauth2Token, err := config.Exchange(ctx, r.URL.Query().Get("code"))
|
||||
if err != nil {
|
||||
http.Error(w, "Failed to exchange token: "+err.Error(), http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
rawIDToken, ok := oauth2Token.Extra("id_token").(string)
|
||||
if !ok {
|
||||
http.Error(w, "No id_token field in oauth2 token.", http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
idToken, err := verifier.Verify(ctx, rawIDToken)
|
||||
if err != nil {
|
||||
http.Error(w, "Failed to verify ID Token: "+err.Error(), http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
|
||||
oauth2Token.AccessToken = "*REDACTED*"
|
||||
|
||||
resp := struct {
|
||||
OAuth2Token *oauth2.Token
|
||||
IDTokenClaims *json.RawMessage // ID Token payload is just JSON.
|
||||
}{oauth2Token, new(json.RawMessage)}
|
||||
|
||||
if err := idToken.Claims(&resp.IDTokenClaims); err != nil {
|
||||
http.Error(w, err.Error(), http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
data, err := json.MarshalIndent(resp, "", " ")
|
||||
if err != nil {
|
||||
http.Error(w, err.Error(), http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
w.Write(data)
|
||||
})
|
||||
|
||||
log.Printf("listening on http://%s/", "127.0.0.1:5556")
|
||||
log.Fatal(http.ListenAndServe("127.0.0.1:5556", nil))
|
||||
}
|
||||
104
vendor/github.com/coreos/go-oidc/example/nonce/app.go
generated
vendored
104
vendor/github.com/coreos/go-oidc/example/nonce/app.go
generated
vendored
@@ -1,104 +0,0 @@
|
||||
/*
|
||||
This is an example application to demonstrate verifying an ID Token with a nonce.
|
||||
*/
|
||||
package main
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"log"
|
||||
"net/http"
|
||||
"os"
|
||||
|
||||
oidc "github.com/coreos/go-oidc"
|
||||
|
||||
"golang.org/x/net/context"
|
||||
"golang.org/x/oauth2"
|
||||
)
|
||||
|
||||
var (
|
||||
clientID = os.Getenv("GOOGLE_OAUTH2_CLIENT_ID")
|
||||
clientSecret = os.Getenv("GOOGLE_OAUTH2_CLIENT_SECRET")
|
||||
)
|
||||
|
||||
const appNonce = "a super secret nonce"
|
||||
|
||||
// Create a nonce source.
|
||||
type nonceSource struct{}
|
||||
|
||||
func (n nonceSource) ClaimNonce(nonce string) error {
|
||||
if nonce != appNonce {
|
||||
return errors.New("unregonized nonce")
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func main() {
|
||||
ctx := context.Background()
|
||||
|
||||
provider, err := oidc.NewProvider(ctx, "https://accounts.google.com")
|
||||
if err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
|
||||
// Use the nonce source to create a custom ID Token verifier.
|
||||
nonceEnabledVerifier := provider.Verifier(oidc.VerifyNonce(nonceSource{}))
|
||||
|
||||
config := oauth2.Config{
|
||||
ClientID: clientID,
|
||||
ClientSecret: clientSecret,
|
||||
Endpoint: provider.Endpoint(),
|
||||
RedirectURL: "http://127.0.0.1:5556/auth/google/callback",
|
||||
Scopes: []string{oidc.ScopeOpenID, "profile", "email"},
|
||||
}
|
||||
|
||||
state := "foobar" // Don't do this in production.
|
||||
|
||||
http.HandleFunc("/", func(w http.ResponseWriter, r *http.Request) {
|
||||
http.Redirect(w, r, config.AuthCodeURL(state, oidc.Nonce(appNonce)), http.StatusFound)
|
||||
})
|
||||
|
||||
http.HandleFunc("/auth/google/callback", func(w http.ResponseWriter, r *http.Request) {
|
||||
if r.URL.Query().Get("state") != state {
|
||||
http.Error(w, "state did not match", http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
|
||||
oauth2Token, err := config.Exchange(ctx, r.URL.Query().Get("code"))
|
||||
if err != nil {
|
||||
http.Error(w, "Failed to exchange token: "+err.Error(), http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
|
||||
rawIDToken, ok := oauth2Token.Extra("id_token").(string)
|
||||
if !ok {
|
||||
http.Error(w, "No id_token field in oauth2 token.", http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
// Verify the ID Token signature and nonce.
|
||||
idToken, err := nonceEnabledVerifier.Verify(ctx, rawIDToken)
|
||||
if err != nil {
|
||||
http.Error(w, "Failed to verify ID Token: "+err.Error(), http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
|
||||
resp := struct {
|
||||
OAuth2Token *oauth2.Token
|
||||
IDTokenClaims *json.RawMessage // ID Token payload is just JSON.
|
||||
}{oauth2Token, new(json.RawMessage)}
|
||||
|
||||
if err := idToken.Claims(&resp.IDTokenClaims); err != nil {
|
||||
http.Error(w, err.Error(), http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
data, err := json.MarshalIndent(resp, "", " ")
|
||||
if err != nil {
|
||||
http.Error(w, err.Error(), http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
w.Write(data)
|
||||
})
|
||||
|
||||
log.Printf("listening on http://%s/", "127.0.0.1:5556")
|
||||
log.Fatal(http.ListenAndServe("127.0.0.1:5556", nil))
|
||||
}
|
||||
76
vendor/github.com/coreos/go-oidc/example/userinfo/app.go
generated
vendored
76
vendor/github.com/coreos/go-oidc/example/userinfo/app.go
generated
vendored
@@ -1,76 +0,0 @@
|
||||
/*
|
||||
This is an example application to demonstrate querying the user info endpoint.
|
||||
*/
|
||||
package main
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"log"
|
||||
"net/http"
|
||||
"os"
|
||||
|
||||
oidc "github.com/coreos/go-oidc"
|
||||
|
||||
"golang.org/x/net/context"
|
||||
"golang.org/x/oauth2"
|
||||
)
|
||||
|
||||
var (
|
||||
clientID = os.Getenv("GOOGLE_OAUTH2_CLIENT_ID")
|
||||
clientSecret = os.Getenv("GOOGLE_OAUTH2_CLIENT_SECRET")
|
||||
)
|
||||
|
||||
func main() {
|
||||
ctx := context.Background()
|
||||
|
||||
provider, err := oidc.NewProvider(ctx, "https://accounts.google.com")
|
||||
if err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
config := oauth2.Config{
|
||||
ClientID: clientID,
|
||||
ClientSecret: clientSecret,
|
||||
Endpoint: provider.Endpoint(),
|
||||
RedirectURL: "http://127.0.0.1:5556/auth/google/callback",
|
||||
Scopes: []string{oidc.ScopeOpenID, "profile", "email"},
|
||||
}
|
||||
|
||||
state := "foobar" // Don't do this in production.
|
||||
|
||||
http.HandleFunc("/", func(w http.ResponseWriter, r *http.Request) {
|
||||
http.Redirect(w, r, config.AuthCodeURL(state), http.StatusFound)
|
||||
})
|
||||
|
||||
http.HandleFunc("/auth/google/callback", func(w http.ResponseWriter, r *http.Request) {
|
||||
if r.URL.Query().Get("state") != state {
|
||||
http.Error(w, "state did not match", http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
|
||||
oauth2Token, err := config.Exchange(ctx, r.URL.Query().Get("code"))
|
||||
if err != nil {
|
||||
http.Error(w, "Failed to exchange token: "+err.Error(), http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
|
||||
userInfo, err := provider.UserInfo(ctx, oauth2.StaticTokenSource(oauth2Token))
|
||||
if err != nil {
|
||||
http.Error(w, "Failed to get userinfo: "+err.Error(), http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
|
||||
resp := struct {
|
||||
OAuth2Token *oauth2.Token
|
||||
UserInfo *oidc.UserInfo
|
||||
}{oauth2Token, userInfo}
|
||||
data, err := json.MarshalIndent(resp, "", " ")
|
||||
if err != nil {
|
||||
http.Error(w, err.Error(), http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
w.Write(data)
|
||||
})
|
||||
|
||||
log.Printf("listening on http://%s/", "127.0.0.1:5556")
|
||||
log.Fatal(http.ListenAndServe("127.0.0.1:5556", nil))
|
||||
}
|
||||
7
vendor/github.com/coreos/go-oidc/http/client.go
generated
vendored
7
vendor/github.com/coreos/go-oidc/http/client.go
generated
vendored
@@ -1,7 +0,0 @@
|
||||
package http
|
||||
|
||||
import "net/http"
|
||||
|
||||
type Client interface {
|
||||
Do(*http.Request) (*http.Response, error)
|
||||
}
|
||||
2
vendor/github.com/coreos/go-oidc/http/doc.go
generated
vendored
2
vendor/github.com/coreos/go-oidc/http/doc.go
generated
vendored
@@ -1,2 +0,0 @@
|
||||
// Package http is DEPRECATED. Use net/http instead.
|
||||
package http
|
||||
156
vendor/github.com/coreos/go-oidc/http/http.go
generated
vendored
156
vendor/github.com/coreos/go-oidc/http/http.go
generated
vendored
@@ -1,156 +0,0 @@
|
||||
package http
|
||||
|
||||
import (
|
||||
"encoding/base64"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"log"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"path"
|
||||
"strconv"
|
||||
"strings"
|
||||
"time"
|
||||
)
|
||||
|
||||
func WriteError(w http.ResponseWriter, code int, msg string) {
|
||||
e := struct {
|
||||
Error string `json:"error"`
|
||||
}{
|
||||
Error: msg,
|
||||
}
|
||||
b, err := json.Marshal(e)
|
||||
if err != nil {
|
||||
log.Printf("go-oidc: failed to marshal %#v: %v", e, err)
|
||||
code = http.StatusInternalServerError
|
||||
b = []byte(`{"error":"server_error"}`)
|
||||
}
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
w.WriteHeader(code)
|
||||
w.Write(b)
|
||||
}
|
||||
|
||||
// BasicAuth parses a username and password from the request's
|
||||
// Authorization header. This was pulled from golang master:
|
||||
// https://codereview.appspot.com/76540043
|
||||
func BasicAuth(r *http.Request) (username, password string, ok bool) {
|
||||
auth := r.Header.Get("Authorization")
|
||||
if auth == "" {
|
||||
return
|
||||
}
|
||||
|
||||
if !strings.HasPrefix(auth, "Basic ") {
|
||||
return
|
||||
}
|
||||
c, err := base64.StdEncoding.DecodeString(strings.TrimPrefix(auth, "Basic "))
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
cs := string(c)
|
||||
s := strings.IndexByte(cs, ':')
|
||||
if s < 0 {
|
||||
return
|
||||
}
|
||||
return cs[:s], cs[s+1:], true
|
||||
}
|
||||
|
||||
func cacheControlMaxAge(hdr string) (time.Duration, bool, error) {
|
||||
for _, field := range strings.Split(hdr, ",") {
|
||||
parts := strings.SplitN(strings.TrimSpace(field), "=", 2)
|
||||
k := strings.ToLower(strings.TrimSpace(parts[0]))
|
||||
if k != "max-age" {
|
||||
continue
|
||||
}
|
||||
|
||||
if len(parts) == 1 {
|
||||
return 0, false, errors.New("max-age has no value")
|
||||
}
|
||||
|
||||
v := strings.TrimSpace(parts[1])
|
||||
if v == "" {
|
||||
return 0, false, errors.New("max-age has empty value")
|
||||
}
|
||||
|
||||
age, err := strconv.Atoi(v)
|
||||
if err != nil {
|
||||
return 0, false, err
|
||||
}
|
||||
|
||||
if age <= 0 {
|
||||
return 0, false, nil
|
||||
}
|
||||
|
||||
return time.Duration(age) * time.Second, true, nil
|
||||
}
|
||||
|
||||
return 0, false, nil
|
||||
}
|
||||
|
||||
func expires(date, expires string) (time.Duration, bool, error) {
|
||||
if date == "" || expires == "" {
|
||||
return 0, false, nil
|
||||
}
|
||||
|
||||
te, err := time.Parse(time.RFC1123, expires)
|
||||
if err != nil {
|
||||
return 0, false, err
|
||||
}
|
||||
|
||||
td, err := time.Parse(time.RFC1123, date)
|
||||
if err != nil {
|
||||
return 0, false, err
|
||||
}
|
||||
|
||||
ttl := te.Sub(td)
|
||||
|
||||
// headers indicate data already expired, caller should not
|
||||
// have to care about this case
|
||||
if ttl <= 0 {
|
||||
return 0, false, nil
|
||||
}
|
||||
|
||||
return ttl, true, nil
|
||||
}
|
||||
|
||||
func Cacheable(hdr http.Header) (time.Duration, bool, error) {
|
||||
ttl, ok, err := cacheControlMaxAge(hdr.Get("Cache-Control"))
|
||||
if err != nil || ok {
|
||||
return ttl, ok, err
|
||||
}
|
||||
|
||||
return expires(hdr.Get("Date"), hdr.Get("Expires"))
|
||||
}
|
||||
|
||||
// MergeQuery appends additional query values to an existing URL.
|
||||
func MergeQuery(u url.URL, q url.Values) url.URL {
|
||||
uv := u.Query()
|
||||
for k, vs := range q {
|
||||
for _, v := range vs {
|
||||
uv.Add(k, v)
|
||||
}
|
||||
}
|
||||
u.RawQuery = uv.Encode()
|
||||
return u
|
||||
}
|
||||
|
||||
// NewResourceLocation appends a resource id to the end of the requested URL path.
|
||||
func NewResourceLocation(reqURL *url.URL, id string) string {
|
||||
var u url.URL
|
||||
u = *reqURL
|
||||
u.Path = path.Join(u.Path, id)
|
||||
u.RawQuery = ""
|
||||
u.Fragment = ""
|
||||
return u.String()
|
||||
}
|
||||
|
||||
// CopyRequest returns a clone of the provided *http.Request.
|
||||
// The returned object is a shallow copy of the struct and a
|
||||
// deep copy of its Header field.
|
||||
func CopyRequest(r *http.Request) *http.Request {
|
||||
r2 := *r
|
||||
r2.Header = make(http.Header)
|
||||
for k, s := range r.Header {
|
||||
r2.Header[k] = s
|
||||
}
|
||||
return &r2
|
||||
}
|
||||
380
vendor/github.com/coreos/go-oidc/http/http_test.go
generated
vendored
380
vendor/github.com/coreos/go-oidc/http/http_test.go
generated
vendored
@@ -1,380 +0,0 @@
|
||||
package http
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"net/url"
|
||||
"reflect"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
|
||||
func TestCacheControlMaxAgeSuccess(t *testing.T) {
|
||||
tests := []struct {
|
||||
hdr string
|
||||
wantAge time.Duration
|
||||
wantOK bool
|
||||
}{
|
||||
{"max-age=12", 12 * time.Second, true},
|
||||
{"max-age=-12", 0, false},
|
||||
{"max-age=0", 0, false},
|
||||
{"public, max-age=12", 12 * time.Second, true},
|
||||
{"public, max-age=40192, must-revalidate", 40192 * time.Second, true},
|
||||
{"public, not-max-age=12, must-revalidate", time.Duration(0), false},
|
||||
}
|
||||
|
||||
for i, tt := range tests {
|
||||
maxAge, ok, err := cacheControlMaxAge(tt.hdr)
|
||||
if err != nil {
|
||||
t.Errorf("case %d: err=%v", i, err)
|
||||
}
|
||||
if tt.wantAge != maxAge {
|
||||
t.Errorf("case %d: want=%d got=%d", i, tt.wantAge, maxAge)
|
||||
}
|
||||
if tt.wantOK != ok {
|
||||
t.Errorf("case %d: incorrect ok value: want=%t got=%t", i, tt.wantOK, ok)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestCacheControlMaxAgeFail(t *testing.T) {
|
||||
tests := []string{
|
||||
"max-age=aasdf",
|
||||
"max-age=",
|
||||
"max-age",
|
||||
}
|
||||
|
||||
for i, tt := range tests {
|
||||
_, ok, err := cacheControlMaxAge(tt)
|
||||
if ok {
|
||||
t.Errorf("case %d: want ok=false, got true", i)
|
||||
}
|
||||
if err == nil {
|
||||
t.Errorf("case %d: want non-nil err", i)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestMergeQuery(t *testing.T) {
|
||||
tests := []struct {
|
||||
u string
|
||||
q url.Values
|
||||
w string
|
||||
}{
|
||||
// No values
|
||||
{
|
||||
u: "http://example.com",
|
||||
q: nil,
|
||||
w: "http://example.com",
|
||||
},
|
||||
// No additional values
|
||||
{
|
||||
u: "http://example.com?foo=bar",
|
||||
q: nil,
|
||||
w: "http://example.com?foo=bar",
|
||||
},
|
||||
// Simple addition
|
||||
{
|
||||
u: "http://example.com",
|
||||
q: url.Values{
|
||||
"foo": []string{"bar"},
|
||||
},
|
||||
w: "http://example.com?foo=bar",
|
||||
},
|
||||
// Addition with existing values
|
||||
{
|
||||
u: "http://example.com?dog=boo",
|
||||
q: url.Values{
|
||||
"foo": []string{"bar"},
|
||||
},
|
||||
w: "http://example.com?dog=boo&foo=bar",
|
||||
},
|
||||
// Merge
|
||||
{
|
||||
u: "http://example.com?dog=boo",
|
||||
q: url.Values{
|
||||
"dog": []string{"elroy"},
|
||||
},
|
||||
w: "http://example.com?dog=boo&dog=elroy",
|
||||
},
|
||||
// Add and merge
|
||||
{
|
||||
u: "http://example.com?dog=boo",
|
||||
q: url.Values{
|
||||
"dog": []string{"elroy"},
|
||||
"foo": []string{"bar"},
|
||||
},
|
||||
w: "http://example.com?dog=boo&dog=elroy&foo=bar",
|
||||
},
|
||||
// Multivalue merge
|
||||
{
|
||||
u: "http://example.com?dog=boo",
|
||||
q: url.Values{
|
||||
"dog": []string{"elroy", "penny"},
|
||||
},
|
||||
w: "http://example.com?dog=boo&dog=elroy&dog=penny",
|
||||
},
|
||||
}
|
||||
|
||||
for i, tt := range tests {
|
||||
ur, err := url.Parse(tt.u)
|
||||
if err != nil {
|
||||
t.Errorf("case %d: failed parsing test url: %v, error: %v", i, tt.u, err)
|
||||
}
|
||||
|
||||
got := MergeQuery(*ur, tt.q)
|
||||
want, err := url.Parse(tt.w)
|
||||
if err != nil {
|
||||
t.Errorf("case %d: failed parsing want url: %v, error: %v", i, tt.w, err)
|
||||
}
|
||||
|
||||
if !reflect.DeepEqual(*want, got) {
|
||||
t.Errorf("case %d: want: %v, got: %v", i, *want, got)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestExpiresPass(t *testing.T) {
|
||||
tests := []struct {
|
||||
date string
|
||||
exp string
|
||||
wantTTL time.Duration
|
||||
wantOK bool
|
||||
}{
|
||||
// Expires and Date properly set
|
||||
{
|
||||
date: "Thu, 01 Dec 1983 22:00:00 GMT",
|
||||
exp: "Fri, 02 Dec 1983 01:00:00 GMT",
|
||||
wantTTL: 10800 * time.Second,
|
||||
wantOK: true,
|
||||
},
|
||||
// empty headers
|
||||
{
|
||||
date: "",
|
||||
exp: "",
|
||||
wantOK: false,
|
||||
},
|
||||
// lack of Expirs short-ciruits Date parsing
|
||||
{
|
||||
date: "foo",
|
||||
exp: "",
|
||||
wantOK: false,
|
||||
},
|
||||
// lack of Date short-ciruits Expires parsing
|
||||
{
|
||||
date: "",
|
||||
exp: "foo",
|
||||
wantOK: false,
|
||||
},
|
||||
// no Date
|
||||
{
|
||||
exp: "Thu, 01 Dec 1983 22:00:00 GMT",
|
||||
wantTTL: 0,
|
||||
wantOK: false,
|
||||
},
|
||||
// no Expires
|
||||
{
|
||||
date: "Thu, 01 Dec 1983 22:00:00 GMT",
|
||||
wantTTL: 0,
|
||||
wantOK: false,
|
||||
},
|
||||
// Expires < Date
|
||||
{
|
||||
date: "Fri, 02 Dec 1983 01:00:00 GMT",
|
||||
exp: "Thu, 01 Dec 1983 22:00:00 GMT",
|
||||
wantTTL: 0,
|
||||
wantOK: false,
|
||||
},
|
||||
}
|
||||
|
||||
for i, tt := range tests {
|
||||
ttl, ok, err := expires(tt.date, tt.exp)
|
||||
if err != nil {
|
||||
t.Errorf("case %d: err=%v", i, err)
|
||||
}
|
||||
if tt.wantTTL != ttl {
|
||||
t.Errorf("case %d: want=%d got=%d", i, tt.wantTTL, ttl)
|
||||
}
|
||||
if tt.wantOK != ok {
|
||||
t.Errorf("case %d: incorrect ok value: want=%t got=%t", i, tt.wantOK, ok)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestExpiresFail(t *testing.T) {
|
||||
tests := []struct {
|
||||
date string
|
||||
exp string
|
||||
}{
|
||||
// malformed Date header
|
||||
{
|
||||
date: "foo",
|
||||
exp: "Fri, 02 Dec 1983 01:00:00 GMT",
|
||||
},
|
||||
// malformed exp header
|
||||
{
|
||||
date: "Fri, 02 Dec 1983 01:00:00 GMT",
|
||||
exp: "bar",
|
||||
},
|
||||
}
|
||||
|
||||
for i, tt := range tests {
|
||||
_, _, err := expires(tt.date, tt.exp)
|
||||
if err == nil {
|
||||
t.Errorf("case %d: expected non-nil error", i)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestCacheablePass(t *testing.T) {
|
||||
tests := []struct {
|
||||
headers http.Header
|
||||
wantTTL time.Duration
|
||||
wantOK bool
|
||||
}{
|
||||
// valid Cache-Control
|
||||
{
|
||||
headers: http.Header{
|
||||
"Cache-Control": []string{"max-age=100"},
|
||||
},
|
||||
wantTTL: 100 * time.Second,
|
||||
wantOK: true,
|
||||
},
|
||||
// valid Date/Expires
|
||||
{
|
||||
headers: http.Header{
|
||||
"Date": []string{"Thu, 01 Dec 1983 22:00:00 GMT"},
|
||||
"Expires": []string{"Fri, 02 Dec 1983 01:00:00 GMT"},
|
||||
},
|
||||
wantTTL: 10800 * time.Second,
|
||||
wantOK: true,
|
||||
},
|
||||
// Cache-Control supersedes Date/Expires
|
||||
{
|
||||
headers: http.Header{
|
||||
"Cache-Control": []string{"max-age=100"},
|
||||
"Date": []string{"Thu, 01 Dec 1983 22:00:00 GMT"},
|
||||
"Expires": []string{"Fri, 02 Dec 1983 01:00:00 GMT"},
|
||||
},
|
||||
wantTTL: 100 * time.Second,
|
||||
wantOK: true,
|
||||
},
|
||||
// no caching headers
|
||||
{
|
||||
headers: http.Header{},
|
||||
wantOK: false,
|
||||
},
|
||||
}
|
||||
|
||||
for i, tt := range tests {
|
||||
ttl, ok, err := Cacheable(tt.headers)
|
||||
if err != nil {
|
||||
t.Errorf("case %d: err=%v", i, err)
|
||||
continue
|
||||
}
|
||||
if tt.wantTTL != ttl {
|
||||
t.Errorf("case %d: want=%d got=%d", i, tt.wantTTL, ttl)
|
||||
}
|
||||
if tt.wantOK != ok {
|
||||
t.Errorf("case %d: incorrect ok value: want=%t got=%t", i, tt.wantOK, ok)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestCacheableFail(t *testing.T) {
|
||||
tests := []http.Header{
|
||||
// invalid Cache-Control short-circuits
|
||||
http.Header{
|
||||
"Cache-Control": []string{"max-age"},
|
||||
"Date": []string{"Thu, 01 Dec 1983 22:00:00 GMT"},
|
||||
"Expires": []string{"Fri, 02 Dec 1983 01:00:00 GMT"},
|
||||
},
|
||||
// no Cache-Control, invalid Expires
|
||||
http.Header{
|
||||
"Date": []string{"Thu, 01 Dec 1983 22:00:00 GMT"},
|
||||
"Expires": []string{"boo"},
|
||||
},
|
||||
}
|
||||
|
||||
for i, tt := range tests {
|
||||
_, _, err := Cacheable(tt)
|
||||
if err == nil {
|
||||
t.Errorf("case %d: want non-nil err", i)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestNewResourceLocation(t *testing.T) {
|
||||
tests := []struct {
|
||||
ru *url.URL
|
||||
id string
|
||||
want string
|
||||
}{
|
||||
{
|
||||
ru: &url.URL{
|
||||
Scheme: "http",
|
||||
Host: "example.com",
|
||||
},
|
||||
id: "foo",
|
||||
want: "http://example.com/foo",
|
||||
},
|
||||
// https
|
||||
{
|
||||
ru: &url.URL{
|
||||
Scheme: "https",
|
||||
Host: "example.com",
|
||||
},
|
||||
id: "foo",
|
||||
want: "https://example.com/foo",
|
||||
},
|
||||
// with path
|
||||
{
|
||||
ru: &url.URL{
|
||||
Scheme: "http",
|
||||
Host: "example.com",
|
||||
Path: "one/two/three",
|
||||
},
|
||||
id: "foo",
|
||||
want: "http://example.com/one/two/three/foo",
|
||||
},
|
||||
// with fragment
|
||||
{
|
||||
ru: &url.URL{
|
||||
Scheme: "http",
|
||||
Host: "example.com",
|
||||
Fragment: "frag",
|
||||
},
|
||||
id: "foo",
|
||||
want: "http://example.com/foo",
|
||||
},
|
||||
// with query
|
||||
{
|
||||
ru: &url.URL{
|
||||
Scheme: "http",
|
||||
Host: "example.com",
|
||||
RawQuery: "dog=elroy",
|
||||
},
|
||||
id: "foo",
|
||||
want: "http://example.com/foo",
|
||||
},
|
||||
}
|
||||
|
||||
for i, tt := range tests {
|
||||
got := NewResourceLocation(tt.ru, tt.id)
|
||||
if tt.want != got {
|
||||
t.Errorf("case %d: want=%s, got=%s", i, tt.want, got)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestCopyRequest(t *testing.T) {
|
||||
r1, err := http.NewRequest("GET", "http://example.com", strings.NewReader("foo"))
|
||||
if err != nil {
|
||||
t.Fatalf("Unexpected error: %v", err)
|
||||
}
|
||||
|
||||
r2 := CopyRequest(r1)
|
||||
if !reflect.DeepEqual(r1, r2) {
|
||||
t.Fatalf("Result of CopyRequest incorrect: %#v != %#v", r1, r2)
|
||||
}
|
||||
}
|
||||
29
vendor/github.com/coreos/go-oidc/http/url.go
generated
vendored
29
vendor/github.com/coreos/go-oidc/http/url.go
generated
vendored
@@ -1,29 +0,0 @@
|
||||
package http
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"net/url"
|
||||
)
|
||||
|
||||
// ParseNonEmptyURL checks that a string is a parsable URL which is also not empty
|
||||
// since `url.Parse("")` does not return an error. Must contian a scheme and a host.
|
||||
func ParseNonEmptyURL(u string) (*url.URL, error) {
|
||||
if u == "" {
|
||||
return nil, errors.New("url is empty")
|
||||
}
|
||||
|
||||
ur, err := url.Parse(u)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if ur.Scheme == "" {
|
||||
return nil, errors.New("url scheme is empty")
|
||||
}
|
||||
|
||||
if ur.Host == "" {
|
||||
return nil, errors.New("url host is empty")
|
||||
}
|
||||
|
||||
return ur, nil
|
||||
}
|
||||
49
vendor/github.com/coreos/go-oidc/http/url_test.go
generated
vendored
49
vendor/github.com/coreos/go-oidc/http/url_test.go
generated
vendored
@@ -1,49 +0,0 @@
|
||||
package http
|
||||
|
||||
import (
|
||||
"net/url"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestParseNonEmptyURL(t *testing.T) {
|
||||
tests := []struct {
|
||||
u string
|
||||
ok bool
|
||||
}{
|
||||
{"", false},
|
||||
{"http://", false},
|
||||
{"example.com", false},
|
||||
{"example", false},
|
||||
{"http://example", true},
|
||||
{"http://example:1234", true},
|
||||
{"http://example.com", true},
|
||||
{"http://example.com:1234", true},
|
||||
}
|
||||
|
||||
for i, tt := range tests {
|
||||
u, err := ParseNonEmptyURL(tt.u)
|
||||
if err != nil {
|
||||
t.Logf("err: %v", err)
|
||||
if tt.ok {
|
||||
t.Errorf("case %d: unexpected error: %v", i, err)
|
||||
} else {
|
||||
continue
|
||||
}
|
||||
}
|
||||
|
||||
if !tt.ok {
|
||||
t.Errorf("case %d: expected error but got none", i)
|
||||
continue
|
||||
}
|
||||
|
||||
uu, err := url.Parse(tt.u)
|
||||
if err != nil {
|
||||
t.Errorf("case %d: unexpected error: %v", i, err)
|
||||
continue
|
||||
}
|
||||
|
||||
if uu.String() != u.String() {
|
||||
t.Errorf("case %d: incorrect url value, want: %q, got: %q", i, uu.String(), u.String())
|
||||
}
|
||||
}
|
||||
}
|
||||
126
vendor/github.com/coreos/go-oidc/jose/claims.go
generated
vendored
126
vendor/github.com/coreos/go-oidc/jose/claims.go
generated
vendored
@@ -1,126 +0,0 @@
|
||||
package jose
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"math"
|
||||
"time"
|
||||
)
|
||||
|
||||
type Claims map[string]interface{}
|
||||
|
||||
func (c Claims) Add(name string, value interface{}) {
|
||||
c[name] = value
|
||||
}
|
||||
|
||||
func (c Claims) StringClaim(name string) (string, bool, error) {
|
||||
cl, ok := c[name]
|
||||
if !ok {
|
||||
return "", false, nil
|
||||
}
|
||||
|
||||
v, ok := cl.(string)
|
||||
if !ok {
|
||||
return "", false, fmt.Errorf("unable to parse claim as string: %v", name)
|
||||
}
|
||||
|
||||
return v, true, nil
|
||||
}
|
||||
|
||||
func (c Claims) StringsClaim(name string) ([]string, bool, error) {
|
||||
cl, ok := c[name]
|
||||
if !ok {
|
||||
return nil, false, nil
|
||||
}
|
||||
|
||||
if v, ok := cl.([]string); ok {
|
||||
return v, true, nil
|
||||
}
|
||||
|
||||
// When unmarshaled, []string will become []interface{}.
|
||||
if v, ok := cl.([]interface{}); ok {
|
||||
var ret []string
|
||||
for _, vv := range v {
|
||||
str, ok := vv.(string)
|
||||
if !ok {
|
||||
return nil, false, fmt.Errorf("unable to parse claim as string array: %v", name)
|
||||
}
|
||||
ret = append(ret, str)
|
||||
}
|
||||
return ret, true, nil
|
||||
}
|
||||
|
||||
return nil, false, fmt.Errorf("unable to parse claim as string array: %v", name)
|
||||
}
|
||||
|
||||
func (c Claims) Int64Claim(name string) (int64, bool, error) {
|
||||
cl, ok := c[name]
|
||||
if !ok {
|
||||
return 0, false, nil
|
||||
}
|
||||
|
||||
v, ok := cl.(int64)
|
||||
if !ok {
|
||||
vf, ok := cl.(float64)
|
||||
if !ok {
|
||||
return 0, false, fmt.Errorf("unable to parse claim as int64: %v", name)
|
||||
}
|
||||
v = int64(vf)
|
||||
}
|
||||
|
||||
return v, true, nil
|
||||
}
|
||||
|
||||
func (c Claims) Float64Claim(name string) (float64, bool, error) {
|
||||
cl, ok := c[name]
|
||||
if !ok {
|
||||
return 0, false, nil
|
||||
}
|
||||
|
||||
v, ok := cl.(float64)
|
||||
if !ok {
|
||||
vi, ok := cl.(int64)
|
||||
if !ok {
|
||||
return 0, false, fmt.Errorf("unable to parse claim as float64: %v", name)
|
||||
}
|
||||
v = float64(vi)
|
||||
}
|
||||
|
||||
return v, true, nil
|
||||
}
|
||||
|
||||
func (c Claims) TimeClaim(name string) (time.Time, bool, error) {
|
||||
v, ok, err := c.Float64Claim(name)
|
||||
if !ok || err != nil {
|
||||
return time.Time{}, ok, err
|
||||
}
|
||||
|
||||
s := math.Trunc(v)
|
||||
ns := (v - s) * math.Pow(10, 9)
|
||||
return time.Unix(int64(s), int64(ns)).UTC(), true, nil
|
||||
}
|
||||
|
||||
func decodeClaims(payload []byte) (Claims, error) {
|
||||
var c Claims
|
||||
if err := json.Unmarshal(payload, &c); err != nil {
|
||||
return nil, fmt.Errorf("malformed JWT claims, unable to decode: %v", err)
|
||||
}
|
||||
return c, nil
|
||||
}
|
||||
|
||||
func marshalClaims(c Claims) ([]byte, error) {
|
||||
b, err := json.Marshal(c)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return b, nil
|
||||
}
|
||||
|
||||
func encodeClaims(c Claims) (string, error) {
|
||||
b, err := marshalClaims(c)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
return encodeSegment(b), nil
|
||||
}
|
||||
328
vendor/github.com/coreos/go-oidc/jose/claims_test.go
generated
vendored
328
vendor/github.com/coreos/go-oidc/jose/claims_test.go
generated
vendored
@@ -1,328 +0,0 @@
|
||||
package jose
|
||||
|
||||
import (
|
||||
"reflect"
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
|
||||
func TestString(t *testing.T) {
|
||||
tests := []struct {
|
||||
cl Claims
|
||||
key string
|
||||
ok bool
|
||||
err bool
|
||||
val string
|
||||
}{
|
||||
// ok, no err, claim exists
|
||||
{
|
||||
cl: Claims{
|
||||
"foo": "bar",
|
||||
},
|
||||
key: "foo",
|
||||
val: "bar",
|
||||
ok: true,
|
||||
err: false,
|
||||
},
|
||||
// no claims
|
||||
{
|
||||
cl: Claims{},
|
||||
key: "foo",
|
||||
val: "",
|
||||
ok: false,
|
||||
err: false,
|
||||
},
|
||||
// missing claim
|
||||
{
|
||||
cl: Claims{
|
||||
"foo": "bar",
|
||||
},
|
||||
key: "xxx",
|
||||
val: "",
|
||||
ok: false,
|
||||
err: false,
|
||||
},
|
||||
// unparsable: type
|
||||
{
|
||||
cl: Claims{
|
||||
"foo": struct{}{},
|
||||
},
|
||||
key: "foo",
|
||||
val: "",
|
||||
ok: false,
|
||||
err: true,
|
||||
},
|
||||
// unparsable: nil value
|
||||
{
|
||||
cl: Claims{
|
||||
"foo": nil,
|
||||
},
|
||||
key: "foo",
|
||||
val: "",
|
||||
ok: false,
|
||||
err: true,
|
||||
},
|
||||
}
|
||||
|
||||
for i, tt := range tests {
|
||||
val, ok, err := tt.cl.StringClaim(tt.key)
|
||||
|
||||
if tt.err && err == nil {
|
||||
t.Errorf("case %d: want err=non-nil, got err=nil", i)
|
||||
} else if !tt.err && err != nil {
|
||||
t.Errorf("case %d: want err=nil, got err=%v", i, err)
|
||||
}
|
||||
|
||||
if tt.ok != ok {
|
||||
t.Errorf("case %d: want ok=%v, got ok=%v", i, tt.ok, ok)
|
||||
}
|
||||
|
||||
if tt.val != val {
|
||||
t.Errorf("case %d: want val=%v, got val=%v", i, tt.val, val)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestInt64(t *testing.T) {
|
||||
tests := []struct {
|
||||
cl Claims
|
||||
key string
|
||||
ok bool
|
||||
err bool
|
||||
val int64
|
||||
}{
|
||||
// ok, no err, claim exists
|
||||
{
|
||||
cl: Claims{
|
||||
"foo": int64(100),
|
||||
},
|
||||
key: "foo",
|
||||
val: int64(100),
|
||||
ok: true,
|
||||
err: false,
|
||||
},
|
||||
// no claims
|
||||
{
|
||||
cl: Claims{},
|
||||
key: "foo",
|
||||
val: 0,
|
||||
ok: false,
|
||||
err: false,
|
||||
},
|
||||
// missing claim
|
||||
{
|
||||
cl: Claims{
|
||||
"foo": "bar",
|
||||
},
|
||||
key: "xxx",
|
||||
val: 0,
|
||||
ok: false,
|
||||
err: false,
|
||||
},
|
||||
// unparsable: type
|
||||
{
|
||||
cl: Claims{
|
||||
"foo": struct{}{},
|
||||
},
|
||||
key: "foo",
|
||||
val: 0,
|
||||
ok: false,
|
||||
err: true,
|
||||
},
|
||||
// unparsable: nil value
|
||||
{
|
||||
cl: Claims{
|
||||
"foo": nil,
|
||||
},
|
||||
key: "foo",
|
||||
val: 0,
|
||||
ok: false,
|
||||
err: true,
|
||||
},
|
||||
}
|
||||
|
||||
for i, tt := range tests {
|
||||
val, ok, err := tt.cl.Int64Claim(tt.key)
|
||||
|
||||
if tt.err && err == nil {
|
||||
t.Errorf("case %d: want err=non-nil, got err=nil", i)
|
||||
} else if !tt.err && err != nil {
|
||||
t.Errorf("case %d: want err=nil, got err=%v", i, err)
|
||||
}
|
||||
|
||||
if tt.ok != ok {
|
||||
t.Errorf("case %d: want ok=%v, got ok=%v", i, tt.ok, ok)
|
||||
}
|
||||
|
||||
if tt.val != val {
|
||||
t.Errorf("case %d: want val=%v, got val=%v", i, tt.val, val)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestTime(t *testing.T) {
|
||||
now := time.Now().UTC()
|
||||
unixNow := now.Unix()
|
||||
|
||||
tests := []struct {
|
||||
cl Claims
|
||||
key string
|
||||
ok bool
|
||||
err bool
|
||||
val time.Time
|
||||
}{
|
||||
// ok, no err, claim exists
|
||||
{
|
||||
cl: Claims{
|
||||
"foo": unixNow,
|
||||
},
|
||||
key: "foo",
|
||||
val: time.Unix(now.Unix(), 0).UTC(),
|
||||
ok: true,
|
||||
err: false,
|
||||
},
|
||||
// no claims
|
||||
{
|
||||
cl: Claims{},
|
||||
key: "foo",
|
||||
val: time.Time{},
|
||||
ok: false,
|
||||
err: false,
|
||||
},
|
||||
// missing claim
|
||||
{
|
||||
cl: Claims{
|
||||
"foo": "bar",
|
||||
},
|
||||
key: "xxx",
|
||||
val: time.Time{},
|
||||
ok: false,
|
||||
err: false,
|
||||
},
|
||||
// unparsable: type
|
||||
{
|
||||
cl: Claims{
|
||||
"foo": struct{}{},
|
||||
},
|
||||
key: "foo",
|
||||
val: time.Time{},
|
||||
ok: false,
|
||||
err: true,
|
||||
},
|
||||
// unparsable: nil value
|
||||
{
|
||||
cl: Claims{
|
||||
"foo": nil,
|
||||
},
|
||||
key: "foo",
|
||||
val: time.Time{},
|
||||
ok: false,
|
||||
err: true,
|
||||
},
|
||||
}
|
||||
|
||||
for i, tt := range tests {
|
||||
val, ok, err := tt.cl.TimeClaim(tt.key)
|
||||
|
||||
if tt.err && err == nil {
|
||||
t.Errorf("case %d: want err=non-nil, got err=nil", i)
|
||||
} else if !tt.err && err != nil {
|
||||
t.Errorf("case %d: want err=nil, got err=%v", i, err)
|
||||
}
|
||||
|
||||
if tt.ok != ok {
|
||||
t.Errorf("case %d: want ok=%v, got ok=%v", i, tt.ok, ok)
|
||||
}
|
||||
|
||||
if tt.val != val {
|
||||
t.Errorf("case %d: want val=%v, got val=%v", i, tt.val, val)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestStringArray(t *testing.T) {
|
||||
tests := []struct {
|
||||
cl Claims
|
||||
key string
|
||||
ok bool
|
||||
err bool
|
||||
val []string
|
||||
}{
|
||||
// ok, no err, claim exists
|
||||
{
|
||||
cl: Claims{
|
||||
"foo": []string{"bar", "faf"},
|
||||
},
|
||||
key: "foo",
|
||||
val: []string{"bar", "faf"},
|
||||
ok: true,
|
||||
err: false,
|
||||
},
|
||||
// ok, no err, []interface{}
|
||||
{
|
||||
cl: Claims{
|
||||
"foo": []interface{}{"bar", "faf"},
|
||||
},
|
||||
key: "foo",
|
||||
val: []string{"bar", "faf"},
|
||||
ok: true,
|
||||
err: false,
|
||||
},
|
||||
// no claims
|
||||
{
|
||||
cl: Claims{},
|
||||
key: "foo",
|
||||
val: nil,
|
||||
ok: false,
|
||||
err: false,
|
||||
},
|
||||
// missing claim
|
||||
{
|
||||
cl: Claims{
|
||||
"foo": "bar",
|
||||
},
|
||||
key: "xxx",
|
||||
val: nil,
|
||||
ok: false,
|
||||
err: false,
|
||||
},
|
||||
// unparsable: type
|
||||
{
|
||||
cl: Claims{
|
||||
"foo": struct{}{},
|
||||
},
|
||||
key: "foo",
|
||||
val: nil,
|
||||
ok: false,
|
||||
err: true,
|
||||
},
|
||||
// unparsable: nil value
|
||||
{
|
||||
cl: Claims{
|
||||
"foo": nil,
|
||||
},
|
||||
key: "foo",
|
||||
val: nil,
|
||||
ok: false,
|
||||
err: true,
|
||||
},
|
||||
}
|
||||
|
||||
for i, tt := range tests {
|
||||
val, ok, err := tt.cl.StringsClaim(tt.key)
|
||||
|
||||
if tt.err && err == nil {
|
||||
t.Errorf("case %d: want err=non-nil, got err=nil", i)
|
||||
} else if !tt.err && err != nil {
|
||||
t.Errorf("case %d: want err=nil, got err=%v", i, err)
|
||||
}
|
||||
|
||||
if tt.ok != ok {
|
||||
t.Errorf("case %d: want ok=%v, got ok=%v", i, tt.ok, ok)
|
||||
}
|
||||
|
||||
if !reflect.DeepEqual(tt.val, val) {
|
||||
t.Errorf("case %d: want val=%v, got val=%v", i, tt.val, val)
|
||||
}
|
||||
}
|
||||
}
|
||||
2
vendor/github.com/coreos/go-oidc/jose/doc.go
generated
vendored
2
vendor/github.com/coreos/go-oidc/jose/doc.go
generated
vendored
@@ -1,2 +0,0 @@
|
||||
// Package jose is DEPRECATED. Use gopkg.in/square/go-jose.v2 instead.
|
||||
package jose
|
||||
112
vendor/github.com/coreos/go-oidc/jose/jose.go
generated
vendored
112
vendor/github.com/coreos/go-oidc/jose/jose.go
generated
vendored
@@ -1,112 +0,0 @@
|
||||
package jose
|
||||
|
||||
import (
|
||||
"encoding/base64"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"strings"
|
||||
)
|
||||
|
||||
const (
|
||||
HeaderMediaType = "typ"
|
||||
HeaderKeyAlgorithm = "alg"
|
||||
HeaderKeyID = "kid"
|
||||
)
|
||||
|
||||
const (
|
||||
// Encryption Algorithm Header Parameter Values for JWS
|
||||
// See: https://tools.ietf.org/html/draft-ietf-jose-json-web-algorithms-40#page-6
|
||||
AlgHS256 = "HS256"
|
||||
AlgHS384 = "HS384"
|
||||
AlgHS512 = "HS512"
|
||||
AlgRS256 = "RS256"
|
||||
AlgRS384 = "RS384"
|
||||
AlgRS512 = "RS512"
|
||||
AlgES256 = "ES256"
|
||||
AlgES384 = "ES384"
|
||||
AlgES512 = "ES512"
|
||||
AlgPS256 = "PS256"
|
||||
AlgPS384 = "PS384"
|
||||
AlgPS512 = "PS512"
|
||||
AlgNone = "none"
|
||||
)
|
||||
|
||||
const (
|
||||
// Algorithm Header Parameter Values for JWE
|
||||
// See: https://tools.ietf.org/html/draft-ietf-jose-json-web-algorithms-40#section-4.1
|
||||
AlgRSA15 = "RSA1_5"
|
||||
AlgRSAOAEP = "RSA-OAEP"
|
||||
AlgRSAOAEP256 = "RSA-OAEP-256"
|
||||
AlgA128KW = "A128KW"
|
||||
AlgA192KW = "A192KW"
|
||||
AlgA256KW = "A256KW"
|
||||
AlgDir = "dir"
|
||||
AlgECDHES = "ECDH-ES"
|
||||
AlgECDHESA128KW = "ECDH-ES+A128KW"
|
||||
AlgECDHESA192KW = "ECDH-ES+A192KW"
|
||||
AlgECDHESA256KW = "ECDH-ES+A256KW"
|
||||
AlgA128GCMKW = "A128GCMKW"
|
||||
AlgA192GCMKW = "A192GCMKW"
|
||||
AlgA256GCMKW = "A256GCMKW"
|
||||
AlgPBES2HS256A128KW = "PBES2-HS256+A128KW"
|
||||
AlgPBES2HS384A192KW = "PBES2-HS384+A192KW"
|
||||
AlgPBES2HS512A256KW = "PBES2-HS512+A256KW"
|
||||
)
|
||||
|
||||
const (
|
||||
// Encryption Algorithm Header Parameter Values for JWE
|
||||
// See: https://tools.ietf.org/html/draft-ietf-jose-json-web-algorithms-40#page-22
|
||||
EncA128CBCHS256 = "A128CBC-HS256"
|
||||
EncA128CBCHS384 = "A128CBC-HS384"
|
||||
EncA256CBCHS512 = "A256CBC-HS512"
|
||||
EncA128GCM = "A128GCM"
|
||||
EncA192GCM = "A192GCM"
|
||||
EncA256GCM = "A256GCM"
|
||||
)
|
||||
|
||||
type JOSEHeader map[string]string
|
||||
|
||||
func (j JOSEHeader) Validate() error {
|
||||
if _, exists := j[HeaderKeyAlgorithm]; !exists {
|
||||
return fmt.Errorf("header missing %q parameter", HeaderKeyAlgorithm)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func decodeHeader(seg string) (JOSEHeader, error) {
|
||||
b, err := decodeSegment(seg)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
var h JOSEHeader
|
||||
err = json.Unmarshal(b, &h)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return h, nil
|
||||
}
|
||||
|
||||
func encodeHeader(h JOSEHeader) (string, error) {
|
||||
b, err := json.Marshal(h)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
return encodeSegment(b), nil
|
||||
}
|
||||
|
||||
// Decode JWT specific base64url encoding with padding stripped
|
||||
func decodeSegment(seg string) ([]byte, error) {
|
||||
if l := len(seg) % 4; l != 0 {
|
||||
seg += strings.Repeat("=", 4-l)
|
||||
}
|
||||
return base64.URLEncoding.DecodeString(seg)
|
||||
}
|
||||
|
||||
// Encode JWT specific base64url encoding with padding stripped
|
||||
func encodeSegment(seg []byte) string {
|
||||
return strings.TrimRight(base64.URLEncoding.EncodeToString(seg), "=")
|
||||
}
|
||||
135
vendor/github.com/coreos/go-oidc/jose/jwk.go
generated
vendored
135
vendor/github.com/coreos/go-oidc/jose/jwk.go
generated
vendored
@@ -1,135 +0,0 @@
|
||||
package jose
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/base64"
|
||||
"encoding/binary"
|
||||
"encoding/json"
|
||||
"math/big"
|
||||
"strings"
|
||||
)
|
||||
|
||||
// JSON Web Key
|
||||
// https://tools.ietf.org/html/draft-ietf-jose-json-web-key-36#page-5
|
||||
type JWK struct {
|
||||
ID string
|
||||
Type string
|
||||
Alg string
|
||||
Use string
|
||||
Exponent int
|
||||
Modulus *big.Int
|
||||
Secret []byte
|
||||
}
|
||||
|
||||
type jwkJSON struct {
|
||||
ID string `json:"kid"`
|
||||
Type string `json:"kty"`
|
||||
Alg string `json:"alg"`
|
||||
Use string `json:"use"`
|
||||
Exponent string `json:"e"`
|
||||
Modulus string `json:"n"`
|
||||
}
|
||||
|
||||
func (j *JWK) MarshalJSON() ([]byte, error) {
|
||||
t := jwkJSON{
|
||||
ID: j.ID,
|
||||
Type: j.Type,
|
||||
Alg: j.Alg,
|
||||
Use: j.Use,
|
||||
Exponent: encodeExponent(j.Exponent),
|
||||
Modulus: encodeModulus(j.Modulus),
|
||||
}
|
||||
|
||||
return json.Marshal(&t)
|
||||
}
|
||||
|
||||
func (j *JWK) UnmarshalJSON(data []byte) error {
|
||||
var t jwkJSON
|
||||
err := json.Unmarshal(data, &t)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
e, err := decodeExponent(t.Exponent)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
n, err := decodeModulus(t.Modulus)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
j.ID = t.ID
|
||||
j.Type = t.Type
|
||||
j.Alg = t.Alg
|
||||
j.Use = t.Use
|
||||
j.Exponent = e
|
||||
j.Modulus = n
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
type JWKSet struct {
|
||||
Keys []JWK `json:"keys"`
|
||||
}
|
||||
|
||||
func decodeExponent(e string) (int, error) {
|
||||
decE, err := decodeBase64URLPaddingOptional(e)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
var eBytes []byte
|
||||
if len(decE) < 8 {
|
||||
eBytes = make([]byte, 8-len(decE), 8)
|
||||
eBytes = append(eBytes, decE...)
|
||||
} else {
|
||||
eBytes = decE
|
||||
}
|
||||
eReader := bytes.NewReader(eBytes)
|
||||
var E uint64
|
||||
err = binary.Read(eReader, binary.BigEndian, &E)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
return int(E), nil
|
||||
}
|
||||
|
||||
func encodeExponent(e int) string {
|
||||
b := make([]byte, 8)
|
||||
binary.BigEndian.PutUint64(b, uint64(e))
|
||||
var idx int
|
||||
for ; idx < 8; idx++ {
|
||||
if b[idx] != 0x0 {
|
||||
break
|
||||
}
|
||||
}
|
||||
return base64.RawURLEncoding.EncodeToString(b[idx:])
|
||||
}
|
||||
|
||||
// Turns a URL encoded modulus of a key into a big int.
|
||||
func decodeModulus(n string) (*big.Int, error) {
|
||||
decN, err := decodeBase64URLPaddingOptional(n)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
N := big.NewInt(0)
|
||||
N.SetBytes(decN)
|
||||
return N, nil
|
||||
}
|
||||
|
||||
func encodeModulus(n *big.Int) string {
|
||||
return base64.RawURLEncoding.EncodeToString(n.Bytes())
|
||||
}
|
||||
|
||||
// decodeBase64URLPaddingOptional decodes Base64 whether there is padding or not.
|
||||
// The stdlib version currently doesn't handle this.
|
||||
// We can get rid of this is if this bug:
|
||||
// https://github.com/golang/go/issues/4237
|
||||
// ever closes.
|
||||
func decodeBase64URLPaddingOptional(e string) ([]byte, error) {
|
||||
if m := len(e) % 4; m != 0 {
|
||||
e += strings.Repeat("=", 4-m)
|
||||
}
|
||||
return base64.URLEncoding.DecodeString(e)
|
||||
}
|
||||
64
vendor/github.com/coreos/go-oidc/jose/jwk_test.go
generated
vendored
64
vendor/github.com/coreos/go-oidc/jose/jwk_test.go
generated
vendored
@@ -1,64 +0,0 @@
|
||||
package jose
|
||||
|
||||
import (
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestDecodeBase64URLPaddingOptional(t *testing.T) {
|
||||
tests := []struct {
|
||||
encoded string
|
||||
decoded string
|
||||
err bool
|
||||
}{
|
||||
{
|
||||
// With padding
|
||||
encoded: "VGVjdG9uaWM=",
|
||||
decoded: "Tectonic",
|
||||
},
|
||||
{
|
||||
// Without padding
|
||||
encoded: "VGVjdG9uaWM",
|
||||
decoded: "Tectonic",
|
||||
},
|
||||
{
|
||||
// Even More padding
|
||||
encoded: "VGVjdG9uaQ==",
|
||||
decoded: "Tectoni",
|
||||
},
|
||||
{
|
||||
// And take it away!
|
||||
encoded: "VGVjdG9uaQ",
|
||||
decoded: "Tectoni",
|
||||
},
|
||||
{
|
||||
// Too much padding.
|
||||
encoded: "VGVjdG9uaWNh=",
|
||||
decoded: "",
|
||||
err: true,
|
||||
},
|
||||
{
|
||||
// Too much padding.
|
||||
encoded: "VGVjdG9uaWNh=",
|
||||
decoded: "",
|
||||
err: true,
|
||||
},
|
||||
}
|
||||
|
||||
for i, tt := range tests {
|
||||
got, err := decodeBase64URLPaddingOptional(tt.encoded)
|
||||
if tt.err {
|
||||
if err == nil {
|
||||
t.Errorf("case %d: expected non-nil err", i)
|
||||
}
|
||||
continue
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
t.Errorf("case %d: want nil err, got: %v", i, err)
|
||||
}
|
||||
|
||||
if string(got) != tt.decoded {
|
||||
t.Errorf("case %d: want=%q, got=%q", i, tt.decoded, got)
|
||||
}
|
||||
}
|
||||
}
|
||||
51
vendor/github.com/coreos/go-oidc/jose/jws.go
generated
vendored
51
vendor/github.com/coreos/go-oidc/jose/jws.go
generated
vendored
@@ -1,51 +0,0 @@
|
||||
package jose
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"strings"
|
||||
)
|
||||
|
||||
type JWS struct {
|
||||
RawHeader string
|
||||
Header JOSEHeader
|
||||
RawPayload string
|
||||
Payload []byte
|
||||
Signature []byte
|
||||
}
|
||||
|
||||
// Given a raw encoded JWS token parses it and verifies the structure.
|
||||
func ParseJWS(raw string) (JWS, error) {
|
||||
parts := strings.Split(raw, ".")
|
||||
if len(parts) != 3 {
|
||||
return JWS{}, fmt.Errorf("malformed JWS, only %d segments", len(parts))
|
||||
}
|
||||
|
||||
rawSig := parts[2]
|
||||
jws := JWS{
|
||||
RawHeader: parts[0],
|
||||
RawPayload: parts[1],
|
||||
}
|
||||
|
||||
header, err := decodeHeader(jws.RawHeader)
|
||||
if err != nil {
|
||||
return JWS{}, fmt.Errorf("malformed JWS, unable to decode header, %s", err)
|
||||
}
|
||||
if err = header.Validate(); err != nil {
|
||||
return JWS{}, fmt.Errorf("malformed JWS, %s", err)
|
||||
}
|
||||
jws.Header = header
|
||||
|
||||
payload, err := decodeSegment(jws.RawPayload)
|
||||
if err != nil {
|
||||
return JWS{}, fmt.Errorf("malformed JWS, unable to decode payload: %s", err)
|
||||
}
|
||||
jws.Payload = payload
|
||||
|
||||
sig, err := decodeSegment(rawSig)
|
||||
if err != nil {
|
||||
return JWS{}, fmt.Errorf("malformed JWS, unable to decode signature: %s", err)
|
||||
}
|
||||
jws.Signature = sig
|
||||
|
||||
return jws, nil
|
||||
}
|
||||
74
vendor/github.com/coreos/go-oidc/jose/jws_test.go
generated
vendored
74
vendor/github.com/coreos/go-oidc/jose/jws_test.go
generated
vendored
@@ -1,74 +0,0 @@
|
||||
package jose
|
||||
|
||||
import (
|
||||
"strings"
|
||||
"testing"
|
||||
)
|
||||
|
||||
type testCase struct{ t string }
|
||||
|
||||
var validInput []testCase
|
||||
|
||||
var invalidInput []testCase
|
||||
|
||||
func init() {
|
||||
validInput = []testCase{
|
||||
{
|
||||
"eyJ0eXAiOiJKV1QiLA0KICJhbGciOiJIUzI1NiJ9.eyJpc3MiOiJqb2UiLA0KICJleHAiOjEzMDA4MTkzODAsDQogImh0dHA6Ly9leGFtcGxlLmNvbS9pc19yb290Ijp0cnVlfQ.dBjftJeZ4CVP-mB92K27uhbUJU1p1r_wW1gFWFOEjXk",
|
||||
},
|
||||
}
|
||||
|
||||
invalidInput = []testCase{
|
||||
// empty
|
||||
{
|
||||
"",
|
||||
},
|
||||
// undecodeable
|
||||
{
|
||||
"aaa.bbb.ccc",
|
||||
},
|
||||
// missing parts
|
||||
{
|
||||
"aaa",
|
||||
},
|
||||
// missing parts
|
||||
{
|
||||
"aaa.bbb",
|
||||
},
|
||||
// too many parts
|
||||
{
|
||||
"aaa.bbb.ccc.ddd",
|
||||
},
|
||||
// invalid header
|
||||
// EncodeHeader(map[string]string{"foo": "bar"})
|
||||
{
|
||||
"eyJmb28iOiJiYXIifQ.bbb.ccc",
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
func TestParseJWS(t *testing.T) {
|
||||
for i, tt := range validInput {
|
||||
jws, err := ParseJWS(tt.t)
|
||||
if err != nil {
|
||||
t.Errorf("test: %d. expected: valid, actual: invalid", i)
|
||||
}
|
||||
|
||||
expectedHeader := strings.Split(tt.t, ".")[0]
|
||||
if jws.RawHeader != expectedHeader {
|
||||
t.Errorf("test: %d. expected: %s, actual: %s", i, expectedHeader, jws.RawHeader)
|
||||
}
|
||||
|
||||
expectedPayload := strings.Split(tt.t, ".")[1]
|
||||
if jws.RawPayload != expectedPayload {
|
||||
t.Errorf("test: %d. expected: %s, actual: %s", i, expectedPayload, jws.RawPayload)
|
||||
}
|
||||
}
|
||||
|
||||
for i, tt := range invalidInput {
|
||||
_, err := ParseJWS(tt.t)
|
||||
if err == nil {
|
||||
t.Errorf("test: %d. expected: invalid, actual: valid", i)
|
||||
}
|
||||
}
|
||||
}
|
||||
82
vendor/github.com/coreos/go-oidc/jose/jwt.go
generated
vendored
82
vendor/github.com/coreos/go-oidc/jose/jwt.go
generated
vendored
@@ -1,82 +0,0 @@
|
||||
package jose
|
||||
|
||||
import "strings"
|
||||
|
||||
type JWT JWS
|
||||
|
||||
func ParseJWT(token string) (jwt JWT, err error) {
|
||||
jws, err := ParseJWS(token)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
return JWT(jws), nil
|
||||
}
|
||||
|
||||
func NewJWT(header JOSEHeader, claims Claims) (jwt JWT, err error) {
|
||||
jwt = JWT{}
|
||||
|
||||
jwt.Header = header
|
||||
jwt.Header[HeaderMediaType] = "JWT"
|
||||
|
||||
claimBytes, err := marshalClaims(claims)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
jwt.Payload = claimBytes
|
||||
|
||||
eh, err := encodeHeader(header)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
jwt.RawHeader = eh
|
||||
|
||||
ec, err := encodeClaims(claims)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
jwt.RawPayload = ec
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
func (j *JWT) KeyID() (string, bool) {
|
||||
kID, ok := j.Header[HeaderKeyID]
|
||||
return kID, ok
|
||||
}
|
||||
|
||||
func (j *JWT) Claims() (Claims, error) {
|
||||
return decodeClaims(j.Payload)
|
||||
}
|
||||
|
||||
// Encoded data part of the token which may be signed.
|
||||
func (j *JWT) Data() string {
|
||||
return strings.Join([]string{j.RawHeader, j.RawPayload}, ".")
|
||||
}
|
||||
|
||||
// Full encoded JWT token string in format: header.claims.signature
|
||||
func (j *JWT) Encode() string {
|
||||
d := j.Data()
|
||||
s := encodeSegment(j.Signature)
|
||||
return strings.Join([]string{d, s}, ".")
|
||||
}
|
||||
|
||||
func NewSignedJWT(claims Claims, s Signer) (*JWT, error) {
|
||||
header := JOSEHeader{
|
||||
HeaderKeyAlgorithm: s.Alg(),
|
||||
HeaderKeyID: s.ID(),
|
||||
}
|
||||
|
||||
jwt, err := NewJWT(header, claims)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
sig, err := s.Sign([]byte(jwt.Data()))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
jwt.Signature = sig
|
||||
|
||||
return &jwt, nil
|
||||
}
|
||||
94
vendor/github.com/coreos/go-oidc/jose/jwt_test.go
generated
vendored
94
vendor/github.com/coreos/go-oidc/jose/jwt_test.go
generated
vendored
@@ -1,94 +0,0 @@
|
||||
package jose
|
||||
|
||||
import (
|
||||
"reflect"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestParseJWT(t *testing.T) {
|
||||
tests := []struct {
|
||||
r string
|
||||
h JOSEHeader
|
||||
c Claims
|
||||
}{
|
||||
{
|
||||
// Example from JWT spec:
|
||||
// http://self-issued.info/docs/draft-ietf-oauth-json-web-token.html#ExampleJWT
|
||||
"eyJ0eXAiOiJKV1QiLA0KICJhbGciOiJIUzI1NiJ9.eyJpc3MiOiJqb2UiLA0KICJleHAiOjEzMDA4MTkzODAsDQogImh0dHA6Ly9leGFtcGxlLmNvbS9pc19yb290Ijp0cnVlfQ.dBjftJeZ4CVP-mB92K27uhbUJU1p1r_wW1gFWFOEjXk",
|
||||
JOSEHeader{
|
||||
HeaderMediaType: "JWT",
|
||||
HeaderKeyAlgorithm: "HS256",
|
||||
},
|
||||
Claims{
|
||||
"iss": "joe",
|
||||
// NOTE: test numbers must be floats for equality checks to work since values are converted form interface{} to float64 by default.
|
||||
"exp": 1300819380.0,
|
||||
"http://example.com/is_root": true,
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
for i, tt := range tests {
|
||||
jwt, err := ParseJWT(tt.r)
|
||||
if err != nil {
|
||||
t.Errorf("raw token should parse. test: %d. expected: valid, actual: invalid. err=%v", i, err)
|
||||
}
|
||||
|
||||
if !reflect.DeepEqual(tt.h, jwt.Header) {
|
||||
t.Errorf("JOSE headers should match. test: %d. expected: %v, actual: %v", i, tt.h, jwt.Header)
|
||||
}
|
||||
|
||||
claims, err := jwt.Claims()
|
||||
if err != nil {
|
||||
t.Errorf("test: %d. expected: valid claim parsing. err=%v", i, err)
|
||||
}
|
||||
if !reflect.DeepEqual(tt.c, claims) {
|
||||
t.Errorf("claims should match. test: %d. expected: %v, actual: %v", i, tt.c, claims)
|
||||
}
|
||||
|
||||
enc := jwt.Encode()
|
||||
if enc != tt.r {
|
||||
t.Errorf("encoded jwt should match raw jwt. test: %d. expected: %v, actual: %v", i, tt.r, enc)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestNewJWTHeaderType(t *testing.T) {
|
||||
jwt, err := NewJWT(JOSEHeader{}, Claims{})
|
||||
if err != nil {
|
||||
t.Fatalf("Unexpected error: %v", err)
|
||||
}
|
||||
|
||||
want := "JWT"
|
||||
got := jwt.Header[HeaderMediaType]
|
||||
if want != got {
|
||||
t.Fatalf("Header %q incorrect: want=%s got=%s", HeaderMediaType, want, got)
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
func TestNewJWTHeaderKeyID(t *testing.T) {
|
||||
jwt, err := NewJWT(JOSEHeader{HeaderKeyID: "foo"}, Claims{})
|
||||
if err != nil {
|
||||
t.Fatalf("Unexpected error: %v", err)
|
||||
}
|
||||
|
||||
want := "foo"
|
||||
got, ok := jwt.KeyID()
|
||||
if !ok {
|
||||
t.Fatalf("KeyID not set")
|
||||
} else if want != got {
|
||||
t.Fatalf("KeyID incorrect: want=%s got=%s", want, got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestNewJWTHeaderKeyIDNotSet(t *testing.T) {
|
||||
jwt, err := NewJWT(JOSEHeader{}, Claims{})
|
||||
if err != nil {
|
||||
t.Fatalf("Unexpected error: %v", err)
|
||||
}
|
||||
|
||||
if _, ok := jwt.KeyID(); ok {
|
||||
t.Fatalf("KeyID set, but should not be")
|
||||
}
|
||||
}
|
||||
24
vendor/github.com/coreos/go-oidc/jose/sig.go
generated
vendored
24
vendor/github.com/coreos/go-oidc/jose/sig.go
generated
vendored
@@ -1,24 +0,0 @@
|
||||
package jose
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
)
|
||||
|
||||
type Verifier interface {
|
||||
ID() string
|
||||
Alg() string
|
||||
Verify(sig []byte, data []byte) error
|
||||
}
|
||||
|
||||
type Signer interface {
|
||||
Verifier
|
||||
Sign(data []byte) (sig []byte, err error)
|
||||
}
|
||||
|
||||
func NewVerifier(jwk JWK) (Verifier, error) {
|
||||
if jwk.Type != "RSA" {
|
||||
return nil, fmt.Errorf("unsupported key type %q", jwk.Type)
|
||||
}
|
||||
|
||||
return NewVerifierRSA(jwk)
|
||||
}
|
||||
67
vendor/github.com/coreos/go-oidc/jose/sig_rsa.go
generated
vendored
67
vendor/github.com/coreos/go-oidc/jose/sig_rsa.go
generated
vendored
@@ -1,67 +0,0 @@
|
||||
package jose
|
||||
|
||||
import (
|
||||
"crypto"
|
||||
"crypto/rand"
|
||||
"crypto/rsa"
|
||||
"fmt"
|
||||
)
|
||||
|
||||
type VerifierRSA struct {
|
||||
KeyID string
|
||||
Hash crypto.Hash
|
||||
PublicKey rsa.PublicKey
|
||||
}
|
||||
|
||||
type SignerRSA struct {
|
||||
PrivateKey rsa.PrivateKey
|
||||
VerifierRSA
|
||||
}
|
||||
|
||||
func NewVerifierRSA(jwk JWK) (*VerifierRSA, error) {
|
||||
if jwk.Alg != "" && jwk.Alg != "RS256" {
|
||||
return nil, fmt.Errorf("unsupported key algorithm %q", jwk.Alg)
|
||||
}
|
||||
|
||||
v := VerifierRSA{
|
||||
KeyID: jwk.ID,
|
||||
PublicKey: rsa.PublicKey{
|
||||
N: jwk.Modulus,
|
||||
E: jwk.Exponent,
|
||||
},
|
||||
Hash: crypto.SHA256,
|
||||
}
|
||||
|
||||
return &v, nil
|
||||
}
|
||||
|
||||
func NewSignerRSA(kid string, key rsa.PrivateKey) *SignerRSA {
|
||||
return &SignerRSA{
|
||||
PrivateKey: key,
|
||||
VerifierRSA: VerifierRSA{
|
||||
KeyID: kid,
|
||||
PublicKey: key.PublicKey,
|
||||
Hash: crypto.SHA256,
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
func (v *VerifierRSA) ID() string {
|
||||
return v.KeyID
|
||||
}
|
||||
|
||||
func (v *VerifierRSA) Alg() string {
|
||||
return "RS256"
|
||||
}
|
||||
|
||||
func (v *VerifierRSA) Verify(sig []byte, data []byte) error {
|
||||
h := v.Hash.New()
|
||||
h.Write(data)
|
||||
return rsa.VerifyPKCS1v15(&v.PublicKey, v.Hash, h.Sum(nil), sig)
|
||||
}
|
||||
|
||||
func (s *SignerRSA) Sign(data []byte) ([]byte, error) {
|
||||
h := s.Hash.New()
|
||||
h.Write(data)
|
||||
return rsa.SignPKCS1v15(rand.Reader, &s.PrivateKey, s.Hash, h.Sum(nil))
|
||||
}
|
||||
2
vendor/github.com/coreos/go-oidc/key/doc.go
generated
vendored
2
vendor/github.com/coreos/go-oidc/key/doc.go
generated
vendored
@@ -1,2 +0,0 @@
|
||||
// Package key is DEPRECATED. Use github.com/coreos/go-oidc instead.
|
||||
package key
|
||||
153
vendor/github.com/coreos/go-oidc/key/key.go
generated
vendored
153
vendor/github.com/coreos/go-oidc/key/key.go
generated
vendored
@@ -1,153 +0,0 @@
|
||||
package key
|
||||
|
||||
import (
|
||||
"crypto/rand"
|
||||
"crypto/rsa"
|
||||
"encoding/hex"
|
||||
"encoding/json"
|
||||
"io"
|
||||
"time"
|
||||
|
||||
"github.com/coreos/go-oidc/jose"
|
||||
)
|
||||
|
||||
func NewPublicKey(jwk jose.JWK) *PublicKey {
|
||||
return &PublicKey{jwk: jwk}
|
||||
}
|
||||
|
||||
type PublicKey struct {
|
||||
jwk jose.JWK
|
||||
}
|
||||
|
||||
func (k *PublicKey) MarshalJSON() ([]byte, error) {
|
||||
return json.Marshal(&k.jwk)
|
||||
}
|
||||
|
||||
func (k *PublicKey) UnmarshalJSON(data []byte) error {
|
||||
var jwk jose.JWK
|
||||
if err := json.Unmarshal(data, &jwk); err != nil {
|
||||
return err
|
||||
}
|
||||
k.jwk = jwk
|
||||
return nil
|
||||
}
|
||||
|
||||
func (k *PublicKey) ID() string {
|
||||
return k.jwk.ID
|
||||
}
|
||||
|
||||
func (k *PublicKey) Verifier() (jose.Verifier, error) {
|
||||
return jose.NewVerifierRSA(k.jwk)
|
||||
}
|
||||
|
||||
type PrivateKey struct {
|
||||
KeyID string
|
||||
PrivateKey *rsa.PrivateKey
|
||||
}
|
||||
|
||||
func (k *PrivateKey) ID() string {
|
||||
return k.KeyID
|
||||
}
|
||||
|
||||
func (k *PrivateKey) Signer() jose.Signer {
|
||||
return jose.NewSignerRSA(k.ID(), *k.PrivateKey)
|
||||
}
|
||||
|
||||
func (k *PrivateKey) JWK() jose.JWK {
|
||||
return jose.JWK{
|
||||
ID: k.KeyID,
|
||||
Type: "RSA",
|
||||
Alg: "RS256",
|
||||
Use: "sig",
|
||||
Exponent: k.PrivateKey.PublicKey.E,
|
||||
Modulus: k.PrivateKey.PublicKey.N,
|
||||
}
|
||||
}
|
||||
|
||||
type KeySet interface {
|
||||
ExpiresAt() time.Time
|
||||
}
|
||||
|
||||
type PublicKeySet struct {
|
||||
keys []PublicKey
|
||||
index map[string]*PublicKey
|
||||
expiresAt time.Time
|
||||
}
|
||||
|
||||
func NewPublicKeySet(jwks []jose.JWK, exp time.Time) *PublicKeySet {
|
||||
keys := make([]PublicKey, len(jwks))
|
||||
index := make(map[string]*PublicKey)
|
||||
for i, jwk := range jwks {
|
||||
keys[i] = *NewPublicKey(jwk)
|
||||
index[keys[i].ID()] = &keys[i]
|
||||
}
|
||||
return &PublicKeySet{
|
||||
keys: keys,
|
||||
index: index,
|
||||
expiresAt: exp,
|
||||
}
|
||||
}
|
||||
|
||||
func (s *PublicKeySet) ExpiresAt() time.Time {
|
||||
return s.expiresAt
|
||||
}
|
||||
|
||||
func (s *PublicKeySet) Keys() []PublicKey {
|
||||
return s.keys
|
||||
}
|
||||
|
||||
func (s *PublicKeySet) Key(id string) *PublicKey {
|
||||
return s.index[id]
|
||||
}
|
||||
|
||||
type PrivateKeySet struct {
|
||||
keys []*PrivateKey
|
||||
ActiveKeyID string
|
||||
expiresAt time.Time
|
||||
}
|
||||
|
||||
func NewPrivateKeySet(keys []*PrivateKey, exp time.Time) *PrivateKeySet {
|
||||
return &PrivateKeySet{
|
||||
keys: keys,
|
||||
ActiveKeyID: keys[0].ID(),
|
||||
expiresAt: exp.UTC(),
|
||||
}
|
||||
}
|
||||
|
||||
func (s *PrivateKeySet) Keys() []*PrivateKey {
|
||||
return s.keys
|
||||
}
|
||||
|
||||
func (s *PrivateKeySet) ExpiresAt() time.Time {
|
||||
return s.expiresAt
|
||||
}
|
||||
|
||||
func (s *PrivateKeySet) Active() *PrivateKey {
|
||||
for i, k := range s.keys {
|
||||
if k.ID() == s.ActiveKeyID {
|
||||
return s.keys[i]
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
type GeneratePrivateKeyFunc func() (*PrivateKey, error)
|
||||
|
||||
func GeneratePrivateKey() (*PrivateKey, error) {
|
||||
pk, err := rsa.GenerateKey(rand.Reader, 2048)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
keyID := make([]byte, 20)
|
||||
if _, err := io.ReadFull(rand.Reader, keyID); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
k := PrivateKey{
|
||||
KeyID: hex.EncodeToString(keyID),
|
||||
PrivateKey: pk,
|
||||
}
|
||||
|
||||
return &k, nil
|
||||
}
|
||||
103
vendor/github.com/coreos/go-oidc/key/key_test.go
generated
vendored
103
vendor/github.com/coreos/go-oidc/key/key_test.go
generated
vendored
@@ -1,103 +0,0 @@
|
||||
package key
|
||||
|
||||
import (
|
||||
"crypto/rsa"
|
||||
"math/big"
|
||||
"reflect"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/coreos/go-oidc/jose"
|
||||
)
|
||||
|
||||
func TestPrivateRSAKeyJWK(t *testing.T) {
|
||||
n := big.NewInt(int64(17))
|
||||
if n == nil {
|
||||
panic("NewInt returned nil")
|
||||
}
|
||||
|
||||
k := &PrivateKey{
|
||||
KeyID: "foo",
|
||||
PrivateKey: &rsa.PrivateKey{
|
||||
PublicKey: rsa.PublicKey{N: n, E: 65537},
|
||||
},
|
||||
}
|
||||
|
||||
want := jose.JWK{
|
||||
ID: "foo",
|
||||
Type: "RSA",
|
||||
Alg: "RS256",
|
||||
Use: "sig",
|
||||
Modulus: n,
|
||||
Exponent: 65537,
|
||||
}
|
||||
|
||||
got := k.JWK()
|
||||
if !reflect.DeepEqual(want, got) {
|
||||
t.Fatalf("JWK mismatch: want=%#v got=%#v", want, got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestPublicKeySetKey(t *testing.T) {
|
||||
n := big.NewInt(int64(17))
|
||||
if n == nil {
|
||||
panic("NewInt returned nil")
|
||||
}
|
||||
|
||||
k := jose.JWK{
|
||||
ID: "foo",
|
||||
Type: "RSA",
|
||||
Alg: "RS256",
|
||||
Use: "sig",
|
||||
Modulus: n,
|
||||
Exponent: 65537,
|
||||
}
|
||||
now := time.Now().UTC()
|
||||
ks := NewPublicKeySet([]jose.JWK{k}, now)
|
||||
|
||||
want := &PublicKey{jwk: k}
|
||||
got := ks.Key("foo")
|
||||
if !reflect.DeepEqual(want, got) {
|
||||
t.Errorf("Unexpected response from PublicKeySet.Key: want=%#v got=%#v", want, got)
|
||||
}
|
||||
|
||||
got = ks.Key("bar")
|
||||
if got != nil {
|
||||
t.Errorf("Expected nil response from PublicKeySet.Key, got %#v", got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestPublicKeyMarshalJSON(t *testing.T) {
|
||||
k := jose.JWK{
|
||||
ID: "foo",
|
||||
Type: "RSA",
|
||||
Alg: "RS256",
|
||||
Use: "sig",
|
||||
Modulus: big.NewInt(int64(17)),
|
||||
Exponent: 65537,
|
||||
}
|
||||
want := `{"kid":"foo","kty":"RSA","alg":"RS256","use":"sig","e":"AQAB","n":"EQ"}`
|
||||
pubKey := NewPublicKey(k)
|
||||
gotBytes, err := pubKey.MarshalJSON()
|
||||
if err != nil {
|
||||
t.Fatalf("failed to marshal public key: %v", err)
|
||||
}
|
||||
got := string(gotBytes)
|
||||
if got != want {
|
||||
t.Errorf("got != want:\n%s\n%s", got, want)
|
||||
}
|
||||
}
|
||||
|
||||
func TestGeneratePrivateKeyIDs(t *testing.T) {
|
||||
key1, err := GeneratePrivateKey()
|
||||
if err != nil {
|
||||
t.Fatalf("GeneratePrivateKey(): %v", err)
|
||||
}
|
||||
key2, err := GeneratePrivateKey()
|
||||
if err != nil {
|
||||
t.Fatalf("GeneratePrivateKey(): %v", err)
|
||||
}
|
||||
if key1.KeyID == key2.KeyID {
|
||||
t.Fatalf("expected different keys to have different key IDs")
|
||||
}
|
||||
}
|
||||
99
vendor/github.com/coreos/go-oidc/key/manager.go
generated
vendored
99
vendor/github.com/coreos/go-oidc/key/manager.go
generated
vendored
@@ -1,99 +0,0 @@
|
||||
package key
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"time"
|
||||
|
||||
"github.com/jonboulle/clockwork"
|
||||
|
||||
"github.com/coreos/go-oidc/jose"
|
||||
"github.com/coreos/pkg/health"
|
||||
)
|
||||
|
||||
type PrivateKeyManager interface {
|
||||
ExpiresAt() time.Time
|
||||
Signer() (jose.Signer, error)
|
||||
JWKs() ([]jose.JWK, error)
|
||||
PublicKeys() ([]PublicKey, error)
|
||||
|
||||
WritableKeySetRepo
|
||||
health.Checkable
|
||||
}
|
||||
|
||||
func NewPrivateKeyManager() PrivateKeyManager {
|
||||
return &privateKeyManager{
|
||||
clock: clockwork.NewRealClock(),
|
||||
}
|
||||
}
|
||||
|
||||
type privateKeyManager struct {
|
||||
keySet *PrivateKeySet
|
||||
clock clockwork.Clock
|
||||
}
|
||||
|
||||
func (m *privateKeyManager) ExpiresAt() time.Time {
|
||||
if m.keySet == nil {
|
||||
return m.clock.Now().UTC()
|
||||
}
|
||||
|
||||
return m.keySet.ExpiresAt()
|
||||
}
|
||||
|
||||
func (m *privateKeyManager) Signer() (jose.Signer, error) {
|
||||
if err := m.Healthy(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return m.keySet.Active().Signer(), nil
|
||||
}
|
||||
|
||||
func (m *privateKeyManager) JWKs() ([]jose.JWK, error) {
|
||||
if err := m.Healthy(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
keys := m.keySet.Keys()
|
||||
jwks := make([]jose.JWK, len(keys))
|
||||
for i, k := range keys {
|
||||
jwks[i] = k.JWK()
|
||||
}
|
||||
return jwks, nil
|
||||
}
|
||||
|
||||
func (m *privateKeyManager) PublicKeys() ([]PublicKey, error) {
|
||||
jwks, err := m.JWKs()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
keys := make([]PublicKey, len(jwks))
|
||||
for i, jwk := range jwks {
|
||||
keys[i] = *NewPublicKey(jwk)
|
||||
}
|
||||
return keys, nil
|
||||
}
|
||||
|
||||
func (m *privateKeyManager) Healthy() error {
|
||||
if m.keySet == nil {
|
||||
return errors.New("private key manager uninitialized")
|
||||
}
|
||||
|
||||
if len(m.keySet.Keys()) == 0 {
|
||||
return errors.New("private key manager zero keys")
|
||||
}
|
||||
|
||||
if m.keySet.ExpiresAt().Before(m.clock.Now().UTC()) {
|
||||
return errors.New("private key manager keys expired")
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *privateKeyManager) Set(keySet KeySet) error {
|
||||
privKeySet, ok := keySet.(*PrivateKeySet)
|
||||
if !ok {
|
||||
return errors.New("unable to cast to PrivateKeySet")
|
||||
}
|
||||
|
||||
m.keySet = privKeySet
|
||||
return nil
|
||||
}
|
||||
225
vendor/github.com/coreos/go-oidc/key/manager_test.go
generated
vendored
225
vendor/github.com/coreos/go-oidc/key/manager_test.go
generated
vendored
@@ -1,225 +0,0 @@
|
||||
package key
|
||||
|
||||
import (
|
||||
"crypto/rsa"
|
||||
"math/big"
|
||||
"reflect"
|
||||
"strconv"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/jonboulle/clockwork"
|
||||
|
||||
"github.com/coreos/go-oidc/jose"
|
||||
)
|
||||
|
||||
var (
|
||||
jwk1 jose.JWK
|
||||
jwk2 jose.JWK
|
||||
jwk3 jose.JWK
|
||||
)
|
||||
|
||||
func init() {
|
||||
jwk1 = jose.JWK{
|
||||
ID: "1",
|
||||
Type: "RSA",
|
||||
Alg: "RS256",
|
||||
Use: "sig",
|
||||
Modulus: big.NewInt(1),
|
||||
Exponent: 65537,
|
||||
}
|
||||
|
||||
jwk2 = jose.JWK{
|
||||
ID: "2",
|
||||
Type: "RSA",
|
||||
Alg: "RS256",
|
||||
Use: "sig",
|
||||
Modulus: big.NewInt(2),
|
||||
Exponent: 65537,
|
||||
}
|
||||
|
||||
jwk3 = jose.JWK{
|
||||
ID: "3",
|
||||
Type: "RSA",
|
||||
Alg: "RS256",
|
||||
Use: "sig",
|
||||
Modulus: big.NewInt(3),
|
||||
Exponent: 65537,
|
||||
}
|
||||
}
|
||||
|
||||
func generatePrivateKeyStatic(t *testing.T, idAndN int) *PrivateKey {
|
||||
n := big.NewInt(int64(idAndN))
|
||||
if n == nil {
|
||||
t.Fatalf("Call to NewInt(%d) failed", idAndN)
|
||||
}
|
||||
|
||||
pk := &rsa.PrivateKey{
|
||||
PublicKey: rsa.PublicKey{N: n, E: 65537},
|
||||
}
|
||||
|
||||
return &PrivateKey{
|
||||
KeyID: strconv.Itoa(idAndN),
|
||||
PrivateKey: pk,
|
||||
}
|
||||
}
|
||||
|
||||
func TestPrivateKeyManagerJWKsRotate(t *testing.T) {
|
||||
k1 := generatePrivateKeyStatic(t, 1)
|
||||
k2 := generatePrivateKeyStatic(t, 2)
|
||||
k3 := generatePrivateKeyStatic(t, 3)
|
||||
km := NewPrivateKeyManager()
|
||||
err := km.Set(&PrivateKeySet{
|
||||
keys: []*PrivateKey{k1, k2, k3},
|
||||
ActiveKeyID: k1.KeyID,
|
||||
expiresAt: time.Now().Add(time.Minute),
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("Unexpected error: %v", err)
|
||||
}
|
||||
|
||||
want := []jose.JWK{jwk1, jwk2, jwk3}
|
||||
got, err := km.JWKs()
|
||||
if err != nil {
|
||||
t.Fatalf("Unexpected error: %v", err)
|
||||
}
|
||||
if !reflect.DeepEqual(want, got) {
|
||||
t.Fatalf("JWK mismatch: want=%#v got=%#v", want, got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestPrivateKeyManagerSigner(t *testing.T) {
|
||||
k := generatePrivateKeyStatic(t, 13)
|
||||
|
||||
km := NewPrivateKeyManager()
|
||||
err := km.Set(&PrivateKeySet{
|
||||
keys: []*PrivateKey{k},
|
||||
ActiveKeyID: k.KeyID,
|
||||
expiresAt: time.Now().Add(time.Minute),
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("Unexpected error: %v", err)
|
||||
}
|
||||
|
||||
signer, err := km.Signer()
|
||||
if err != nil {
|
||||
t.Fatalf("Unexpected error: %v", err)
|
||||
}
|
||||
|
||||
wantID := "13"
|
||||
gotID := signer.ID()
|
||||
if wantID != gotID {
|
||||
t.Fatalf("Signer has incorrect ID: want=%s got=%s", wantID, gotID)
|
||||
}
|
||||
}
|
||||
|
||||
func TestPrivateKeyManagerHealthyFail(t *testing.T) {
|
||||
keyFixture := generatePrivateKeyStatic(t, 1)
|
||||
tests := []*privateKeyManager{
|
||||
// keySet nil
|
||||
&privateKeyManager{
|
||||
keySet: nil,
|
||||
clock: clockwork.NewRealClock(),
|
||||
},
|
||||
// zero keys
|
||||
&privateKeyManager{
|
||||
keySet: &PrivateKeySet{
|
||||
keys: []*PrivateKey{},
|
||||
expiresAt: time.Now().Add(time.Minute),
|
||||
},
|
||||
clock: clockwork.NewRealClock(),
|
||||
},
|
||||
// key set expired
|
||||
&privateKeyManager{
|
||||
keySet: &PrivateKeySet{
|
||||
keys: []*PrivateKey{keyFixture},
|
||||
expiresAt: time.Now().Add(-1 * time.Minute),
|
||||
},
|
||||
clock: clockwork.NewRealClock(),
|
||||
},
|
||||
}
|
||||
|
||||
for i, tt := range tests {
|
||||
if err := tt.Healthy(); err == nil {
|
||||
t.Errorf("case %d: nil error", i)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestPrivateKeyManagerHealthyFailsOtherMethods(t *testing.T) {
|
||||
km := NewPrivateKeyManager()
|
||||
if _, err := km.JWKs(); err == nil {
|
||||
t.Fatalf("Expected non-nil error")
|
||||
}
|
||||
if _, err := km.Signer(); err == nil {
|
||||
t.Fatalf("Expected non-nil error")
|
||||
}
|
||||
}
|
||||
|
||||
func TestPrivateKeyManagerExpiresAt(t *testing.T) {
|
||||
fc := clockwork.NewFakeClock()
|
||||
now := fc.Now().UTC()
|
||||
|
||||
k := generatePrivateKeyStatic(t, 17)
|
||||
km := &privateKeyManager{
|
||||
clock: fc,
|
||||
}
|
||||
|
||||
want := fc.Now().UTC()
|
||||
got := km.ExpiresAt()
|
||||
if want != got {
|
||||
t.Fatalf("Incorrect expiration time: want=%v got=%v", want, got)
|
||||
}
|
||||
|
||||
err := km.Set(&PrivateKeySet{
|
||||
keys: []*PrivateKey{k},
|
||||
ActiveKeyID: k.KeyID,
|
||||
expiresAt: now.Add(2 * time.Minute),
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("Unexpected error: %v", err)
|
||||
}
|
||||
|
||||
want = fc.Now().UTC().Add(2 * time.Minute)
|
||||
got = km.ExpiresAt()
|
||||
if want != got {
|
||||
t.Fatalf("Incorrect expiration time: want=%v got=%v", want, got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestPublicKeys(t *testing.T) {
|
||||
km := NewPrivateKeyManager()
|
||||
k1 := generatePrivateKeyStatic(t, 1)
|
||||
k2 := generatePrivateKeyStatic(t, 2)
|
||||
k3 := generatePrivateKeyStatic(t, 3)
|
||||
|
||||
tests := [][]*PrivateKey{
|
||||
[]*PrivateKey{k1},
|
||||
[]*PrivateKey{k1, k2},
|
||||
[]*PrivateKey{k1, k2, k3},
|
||||
}
|
||||
|
||||
for i, tt := range tests {
|
||||
ks := &PrivateKeySet{
|
||||
keys: tt,
|
||||
expiresAt: time.Now().Add(time.Hour),
|
||||
}
|
||||
km.Set(ks)
|
||||
|
||||
jwks, err := km.JWKs()
|
||||
if err != nil {
|
||||
t.Fatalf("Unexpected error: %v", err)
|
||||
}
|
||||
|
||||
pks := NewPublicKeySet(jwks, time.Now().Add(time.Hour))
|
||||
want := pks.Keys()
|
||||
got, err := km.PublicKeys()
|
||||
if err != nil {
|
||||
t.Fatalf("Unexpected error: %v", err)
|
||||
}
|
||||
|
||||
if !reflect.DeepEqual(want, got) {
|
||||
t.Errorf("case %d: Invalid public keys: want=%v got=%v", i, want, got)
|
||||
}
|
||||
}
|
||||
}
|
||||
55
vendor/github.com/coreos/go-oidc/key/repo.go
generated
vendored
55
vendor/github.com/coreos/go-oidc/key/repo.go
generated
vendored
@@ -1,55 +0,0 @@
|
||||
package key
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"sync"
|
||||
)
|
||||
|
||||
var ErrorNoKeys = errors.New("no keys found")
|
||||
|
||||
type WritableKeySetRepo interface {
|
||||
Set(KeySet) error
|
||||
}
|
||||
|
||||
type ReadableKeySetRepo interface {
|
||||
Get() (KeySet, error)
|
||||
}
|
||||
|
||||
type PrivateKeySetRepo interface {
|
||||
WritableKeySetRepo
|
||||
ReadableKeySetRepo
|
||||
}
|
||||
|
||||
func NewPrivateKeySetRepo() PrivateKeySetRepo {
|
||||
return &memPrivateKeySetRepo{}
|
||||
}
|
||||
|
||||
type memPrivateKeySetRepo struct {
|
||||
mu sync.RWMutex
|
||||
pks PrivateKeySet
|
||||
}
|
||||
|
||||
func (r *memPrivateKeySetRepo) Set(ks KeySet) error {
|
||||
pks, ok := ks.(*PrivateKeySet)
|
||||
if !ok {
|
||||
return errors.New("unable to cast to PrivateKeySet")
|
||||
} else if pks == nil {
|
||||
return errors.New("nil KeySet")
|
||||
}
|
||||
|
||||
r.mu.Lock()
|
||||
defer r.mu.Unlock()
|
||||
|
||||
r.pks = *pks
|
||||
return nil
|
||||
}
|
||||
|
||||
func (r *memPrivateKeySetRepo) Get() (KeySet, error) {
|
||||
r.mu.RLock()
|
||||
defer r.mu.RUnlock()
|
||||
|
||||
if r.pks.keys == nil {
|
||||
return nil, ErrorNoKeys
|
||||
}
|
||||
return KeySet(&r.pks), nil
|
||||
}
|
||||
159
vendor/github.com/coreos/go-oidc/key/rotate.go
generated
vendored
159
vendor/github.com/coreos/go-oidc/key/rotate.go
generated
vendored
@@ -1,159 +0,0 @@
|
||||
package key
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"log"
|
||||
"time"
|
||||
|
||||
ptime "github.com/coreos/pkg/timeutil"
|
||||
"github.com/jonboulle/clockwork"
|
||||
)
|
||||
|
||||
var (
|
||||
ErrorPrivateKeysExpired = errors.New("private keys have expired")
|
||||
)
|
||||
|
||||
func NewPrivateKeyRotator(repo PrivateKeySetRepo, ttl time.Duration) *PrivateKeyRotator {
|
||||
return &PrivateKeyRotator{
|
||||
repo: repo,
|
||||
ttl: ttl,
|
||||
|
||||
keep: 2,
|
||||
generateKey: GeneratePrivateKey,
|
||||
clock: clockwork.NewRealClock(),
|
||||
}
|
||||
}
|
||||
|
||||
type PrivateKeyRotator struct {
|
||||
repo PrivateKeySetRepo
|
||||
generateKey GeneratePrivateKeyFunc
|
||||
clock clockwork.Clock
|
||||
keep int
|
||||
ttl time.Duration
|
||||
}
|
||||
|
||||
func (r *PrivateKeyRotator) expiresAt() time.Time {
|
||||
return r.clock.Now().UTC().Add(r.ttl)
|
||||
}
|
||||
|
||||
func (r *PrivateKeyRotator) Healthy() error {
|
||||
pks, err := r.privateKeySet()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if r.clock.Now().After(pks.ExpiresAt()) {
|
||||
return ErrorPrivateKeysExpired
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (r *PrivateKeyRotator) privateKeySet() (*PrivateKeySet, error) {
|
||||
ks, err := r.repo.Get()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
pks, ok := ks.(*PrivateKeySet)
|
||||
if !ok {
|
||||
return nil, errors.New("unable to cast to PrivateKeySet")
|
||||
}
|
||||
return pks, nil
|
||||
}
|
||||
|
||||
func (r *PrivateKeyRotator) nextRotation() (time.Duration, error) {
|
||||
pks, err := r.privateKeySet()
|
||||
if err == ErrorNoKeys {
|
||||
return 0, nil
|
||||
}
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
|
||||
now := r.clock.Now()
|
||||
|
||||
// Ideally, we want to rotate after half the TTL has elapsed.
|
||||
idealRotationTime := pks.ExpiresAt().Add(-r.ttl / 2)
|
||||
|
||||
// If we are past the ideal rotation time, rotate immediatly.
|
||||
return max(0, idealRotationTime.Sub(now)), nil
|
||||
}
|
||||
|
||||
func max(a, b time.Duration) time.Duration {
|
||||
if a > b {
|
||||
return a
|
||||
}
|
||||
return b
|
||||
}
|
||||
|
||||
func (r *PrivateKeyRotator) Run() chan struct{} {
|
||||
attempt := func() {
|
||||
k, err := r.generateKey()
|
||||
if err != nil {
|
||||
log.Printf("go-oidc: failed generating signing key: %v", err)
|
||||
return
|
||||
}
|
||||
|
||||
exp := r.expiresAt()
|
||||
if err := rotatePrivateKeys(r.repo, k, r.keep, exp); err != nil {
|
||||
log.Printf("go-oidc: key rotation failed: %v", err)
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
stop := make(chan struct{})
|
||||
go func() {
|
||||
for {
|
||||
var nextRotation time.Duration
|
||||
var sleep time.Duration
|
||||
var err error
|
||||
for {
|
||||
if nextRotation, err = r.nextRotation(); err == nil {
|
||||
break
|
||||
}
|
||||
sleep = ptime.ExpBackoff(sleep, time.Minute)
|
||||
log.Printf("go-oidc: error getting nextRotation, retrying in %v: %v", sleep, err)
|
||||
time.Sleep(sleep)
|
||||
}
|
||||
|
||||
select {
|
||||
case <-r.clock.After(nextRotation):
|
||||
attempt()
|
||||
case <-stop:
|
||||
return
|
||||
}
|
||||
}
|
||||
}()
|
||||
|
||||
return stop
|
||||
}
|
||||
|
||||
func rotatePrivateKeys(repo PrivateKeySetRepo, k *PrivateKey, keep int, exp time.Time) error {
|
||||
ks, err := repo.Get()
|
||||
if err != nil && err != ErrorNoKeys {
|
||||
return err
|
||||
}
|
||||
|
||||
var keys []*PrivateKey
|
||||
if ks != nil {
|
||||
pks, ok := ks.(*PrivateKeySet)
|
||||
if !ok {
|
||||
return errors.New("unable to cast to PrivateKeySet")
|
||||
}
|
||||
keys = pks.Keys()
|
||||
}
|
||||
|
||||
keys = append([]*PrivateKey{k}, keys...)
|
||||
if l := len(keys); l > keep {
|
||||
keys = keys[0:keep]
|
||||
}
|
||||
|
||||
nks := PrivateKeySet{
|
||||
keys: keys,
|
||||
ActiveKeyID: k.ID(),
|
||||
expiresAt: exp,
|
||||
}
|
||||
|
||||
return repo.Set(KeySet(&nks))
|
||||
}
|
||||
311
vendor/github.com/coreos/go-oidc/key/rotate_test.go
generated
vendored
311
vendor/github.com/coreos/go-oidc/key/rotate_test.go
generated
vendored
@@ -1,311 +0,0 @@
|
||||
package key
|
||||
|
||||
import (
|
||||
"reflect"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/jonboulle/clockwork"
|
||||
)
|
||||
|
||||
func generatePrivateKeySerialFunc(t *testing.T) GeneratePrivateKeyFunc {
|
||||
var n int
|
||||
return func() (*PrivateKey, error) {
|
||||
n++
|
||||
return generatePrivateKeyStatic(t, n), nil
|
||||
}
|
||||
}
|
||||
|
||||
func TestRotate(t *testing.T) {
|
||||
now := time.Now()
|
||||
k1 := generatePrivateKeyStatic(t, 1)
|
||||
k2 := generatePrivateKeyStatic(t, 2)
|
||||
k3 := generatePrivateKeyStatic(t, 3)
|
||||
|
||||
tests := []struct {
|
||||
start *PrivateKeySet
|
||||
key *PrivateKey
|
||||
keep int
|
||||
exp time.Time
|
||||
want *PrivateKeySet
|
||||
}{
|
||||
// start with nil keys
|
||||
{
|
||||
start: nil,
|
||||
key: k1,
|
||||
keep: 2,
|
||||
exp: now.Add(time.Second),
|
||||
want: &PrivateKeySet{
|
||||
keys: []*PrivateKey{k1},
|
||||
ActiveKeyID: k1.KeyID,
|
||||
expiresAt: now.Add(time.Second),
|
||||
},
|
||||
},
|
||||
// start with zero keys
|
||||
{
|
||||
start: &PrivateKeySet{},
|
||||
key: k1,
|
||||
keep: 2,
|
||||
exp: now.Add(time.Second),
|
||||
want: &PrivateKeySet{
|
||||
keys: []*PrivateKey{k1},
|
||||
ActiveKeyID: k1.KeyID,
|
||||
expiresAt: now.Add(time.Second),
|
||||
},
|
||||
},
|
||||
// add second key
|
||||
{
|
||||
start: &PrivateKeySet{
|
||||
keys: []*PrivateKey{k1},
|
||||
ActiveKeyID: k1.KeyID,
|
||||
expiresAt: now,
|
||||
},
|
||||
key: k2,
|
||||
keep: 2,
|
||||
exp: now.Add(time.Second),
|
||||
want: &PrivateKeySet{
|
||||
keys: []*PrivateKey{k2, k1},
|
||||
ActiveKeyID: k2.KeyID,
|
||||
expiresAt: now.Add(time.Second),
|
||||
},
|
||||
},
|
||||
// rotate in third key
|
||||
{
|
||||
start: &PrivateKeySet{
|
||||
keys: []*PrivateKey{k2, k1},
|
||||
ActiveKeyID: k2.KeyID,
|
||||
expiresAt: now,
|
||||
},
|
||||
key: k3,
|
||||
keep: 2,
|
||||
exp: now.Add(time.Second),
|
||||
want: &PrivateKeySet{
|
||||
keys: []*PrivateKey{k3, k2},
|
||||
ActiveKeyID: k3.KeyID,
|
||||
expiresAt: now.Add(time.Second),
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
for i, tt := range tests {
|
||||
repo := NewPrivateKeySetRepo()
|
||||
if tt.start != nil {
|
||||
err := repo.Set(tt.start)
|
||||
if err != nil {
|
||||
t.Fatalf("case %d: unexpected error: %v", i, err)
|
||||
}
|
||||
}
|
||||
|
||||
rotatePrivateKeys(repo, tt.key, tt.keep, tt.exp)
|
||||
got, err := repo.Get()
|
||||
if err != nil {
|
||||
t.Errorf("case %d: unexpected error: %v", i, err)
|
||||
continue
|
||||
}
|
||||
if !reflect.DeepEqual(tt.want, got) {
|
||||
t.Errorf("case %d: unexpected result: want=%#v got=%#v", i, tt.want, got)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestPrivateKeyRotatorRun(t *testing.T) {
|
||||
fc := clockwork.NewFakeClock()
|
||||
now := fc.Now().UTC()
|
||||
|
||||
k1 := generatePrivateKeyStatic(t, 1)
|
||||
k2 := generatePrivateKeyStatic(t, 2)
|
||||
k3 := generatePrivateKeyStatic(t, 3)
|
||||
k4 := generatePrivateKeyStatic(t, 4)
|
||||
|
||||
kRepo := NewPrivateKeySetRepo()
|
||||
krot := NewPrivateKeyRotator(kRepo, 4*time.Second)
|
||||
krot.clock = fc
|
||||
krot.generateKey = generatePrivateKeySerialFunc(t)
|
||||
|
||||
steps := []*PrivateKeySet{
|
||||
&PrivateKeySet{
|
||||
keys: []*PrivateKey{k1},
|
||||
ActiveKeyID: k1.KeyID,
|
||||
expiresAt: now.Add(4 * time.Second),
|
||||
},
|
||||
&PrivateKeySet{
|
||||
keys: []*PrivateKey{k2, k1},
|
||||
ActiveKeyID: k2.KeyID,
|
||||
expiresAt: now.Add(6 * time.Second),
|
||||
},
|
||||
&PrivateKeySet{
|
||||
keys: []*PrivateKey{k3, k2},
|
||||
ActiveKeyID: k3.KeyID,
|
||||
expiresAt: now.Add(8 * time.Second),
|
||||
},
|
||||
&PrivateKeySet{
|
||||
keys: []*PrivateKey{k4, k3},
|
||||
ActiveKeyID: k4.KeyID,
|
||||
expiresAt: now.Add(10 * time.Second),
|
||||
},
|
||||
}
|
||||
|
||||
stop := krot.Run()
|
||||
defer close(stop)
|
||||
|
||||
for i, st := range steps {
|
||||
// wait for the rotater to get sleepy
|
||||
fc.BlockUntil(1)
|
||||
|
||||
got, err := kRepo.Get()
|
||||
if err != nil {
|
||||
t.Fatalf("step %d: unexpected error: %v", i, err)
|
||||
}
|
||||
if !reflect.DeepEqual(st, got) {
|
||||
t.Fatalf("step %d: unexpected state: want=%#v got=%#v", i, st, got)
|
||||
}
|
||||
fc.Advance(2 * time.Second)
|
||||
}
|
||||
}
|
||||
|
||||
func TestPrivateKeyRotatorExpiresAt(t *testing.T) {
|
||||
fc := clockwork.NewFakeClock()
|
||||
krot := &PrivateKeyRotator{
|
||||
clock: fc,
|
||||
ttl: time.Minute,
|
||||
}
|
||||
got := krot.expiresAt()
|
||||
want := fc.Now().UTC().Add(time.Minute)
|
||||
if !reflect.DeepEqual(want, got) {
|
||||
t.Errorf("Incorrect expiration time: want=%v got=%v", want, got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestNextRotation(t *testing.T) {
|
||||
fc := clockwork.NewFakeClock()
|
||||
now := fc.Now().UTC()
|
||||
|
||||
tests := []struct {
|
||||
expiresAt time.Time
|
||||
ttl time.Duration
|
||||
numKeys int
|
||||
expected time.Duration
|
||||
}{
|
||||
{
|
||||
// closest to prod
|
||||
expiresAt: now.Add(time.Hour * 24),
|
||||
ttl: time.Hour * 24,
|
||||
numKeys: 2,
|
||||
expected: time.Hour * 12,
|
||||
},
|
||||
{
|
||||
expiresAt: now.Add(time.Hour * 2),
|
||||
ttl: time.Hour * 4,
|
||||
numKeys: 2,
|
||||
expected: 0,
|
||||
},
|
||||
{
|
||||
// No keys.
|
||||
expiresAt: now.Add(time.Hour * 2),
|
||||
ttl: time.Hour * 4,
|
||||
numKeys: 0,
|
||||
expected: 0,
|
||||
},
|
||||
{
|
||||
// Nil keyset.
|
||||
expiresAt: now.Add(time.Hour * 2),
|
||||
ttl: time.Hour * 4,
|
||||
numKeys: -1,
|
||||
expected: 0,
|
||||
},
|
||||
{
|
||||
// KeySet expired.
|
||||
expiresAt: now.Add(time.Hour * -2),
|
||||
ttl: time.Hour * 4,
|
||||
numKeys: 2,
|
||||
expected: 0,
|
||||
},
|
||||
{
|
||||
// Expiry past now + TTL
|
||||
expiresAt: now.Add(time.Hour * 5),
|
||||
ttl: time.Hour * 4,
|
||||
numKeys: 2,
|
||||
expected: 3 * time.Hour,
|
||||
},
|
||||
}
|
||||
|
||||
for i, tt := range tests {
|
||||
kRepo := NewPrivateKeySetRepo()
|
||||
krot := NewPrivateKeyRotator(kRepo, tt.ttl)
|
||||
krot.clock = fc
|
||||
pks := &PrivateKeySet{
|
||||
expiresAt: tt.expiresAt,
|
||||
}
|
||||
if tt.numKeys != -1 {
|
||||
for n := 0; n < tt.numKeys; n++ {
|
||||
pks.keys = append(pks.keys, generatePrivateKeyStatic(t, n))
|
||||
}
|
||||
err := kRepo.Set(pks)
|
||||
if err != nil {
|
||||
t.Fatalf("case %d: unexpected error: %v", i, err)
|
||||
}
|
||||
|
||||
}
|
||||
actual, err := krot.nextRotation()
|
||||
if err != nil {
|
||||
t.Errorf("case %d: error calling shouldRotate(): %v", i, err)
|
||||
}
|
||||
if actual != tt.expected {
|
||||
t.Errorf("case %d: actual == %v, want %v", i, actual, tt.expected)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestHealthy(t *testing.T) {
|
||||
fc := clockwork.NewFakeClock()
|
||||
now := fc.Now().UTC()
|
||||
|
||||
tests := []struct {
|
||||
expiresAt time.Time
|
||||
numKeys int
|
||||
expected error
|
||||
}{
|
||||
{
|
||||
expiresAt: now.Add(time.Hour),
|
||||
numKeys: 2,
|
||||
expected: nil,
|
||||
},
|
||||
{
|
||||
expiresAt: now.Add(time.Hour),
|
||||
numKeys: -1,
|
||||
expected: ErrorNoKeys,
|
||||
},
|
||||
{
|
||||
expiresAt: now.Add(time.Hour),
|
||||
numKeys: 0,
|
||||
expected: ErrorNoKeys,
|
||||
},
|
||||
{
|
||||
expiresAt: now.Add(-time.Hour),
|
||||
numKeys: 2,
|
||||
expected: ErrorPrivateKeysExpired,
|
||||
},
|
||||
}
|
||||
|
||||
for i, tt := range tests {
|
||||
kRepo := NewPrivateKeySetRepo()
|
||||
krot := NewPrivateKeyRotator(kRepo, time.Hour)
|
||||
krot.clock = fc
|
||||
pks := &PrivateKeySet{
|
||||
expiresAt: tt.expiresAt,
|
||||
}
|
||||
if tt.numKeys != -1 {
|
||||
for n := 0; n < tt.numKeys; n++ {
|
||||
pks.keys = append(pks.keys, generatePrivateKeyStatic(t, n))
|
||||
}
|
||||
err := kRepo.Set(pks)
|
||||
if err != nil {
|
||||
t.Fatalf("case %d: unexpected error: %v", i, err)
|
||||
}
|
||||
|
||||
}
|
||||
if err := krot.Healthy(); err != tt.expected {
|
||||
t.Errorf("case %d: got==%q, want==%q", i, err, tt.expected)
|
||||
}
|
||||
}
|
||||
}
|
||||
91
vendor/github.com/coreos/go-oidc/key/sync.go
generated
vendored
91
vendor/github.com/coreos/go-oidc/key/sync.go
generated
vendored
@@ -1,91 +0,0 @@
|
||||
package key
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"log"
|
||||
"time"
|
||||
|
||||
"github.com/jonboulle/clockwork"
|
||||
|
||||
"github.com/coreos/pkg/timeutil"
|
||||
)
|
||||
|
||||
func NewKeySetSyncer(r ReadableKeySetRepo, w WritableKeySetRepo) *KeySetSyncer {
|
||||
return &KeySetSyncer{
|
||||
readable: r,
|
||||
writable: w,
|
||||
clock: clockwork.NewRealClock(),
|
||||
}
|
||||
}
|
||||
|
||||
type KeySetSyncer struct {
|
||||
readable ReadableKeySetRepo
|
||||
writable WritableKeySetRepo
|
||||
clock clockwork.Clock
|
||||
}
|
||||
|
||||
func (s *KeySetSyncer) Run() chan struct{} {
|
||||
stop := make(chan struct{})
|
||||
go func() {
|
||||
var failing bool
|
||||
var next time.Duration
|
||||
for {
|
||||
exp, err := syncKeySet(s.readable, s.writable, s.clock)
|
||||
if err != nil || exp == 0 {
|
||||
if !failing {
|
||||
failing = true
|
||||
next = time.Second
|
||||
} else {
|
||||
next = timeutil.ExpBackoff(next, time.Minute)
|
||||
}
|
||||
if exp == 0 {
|
||||
log.Printf("Synced to already expired key set, retrying in %v: %v", next, err)
|
||||
|
||||
} else {
|
||||
log.Printf("Failed syncing key set, retrying in %v: %v", next, err)
|
||||
}
|
||||
} else {
|
||||
failing = false
|
||||
next = exp / 2
|
||||
}
|
||||
|
||||
select {
|
||||
case <-s.clock.After(next):
|
||||
continue
|
||||
case <-stop:
|
||||
return
|
||||
}
|
||||
}
|
||||
}()
|
||||
|
||||
return stop
|
||||
}
|
||||
|
||||
func Sync(r ReadableKeySetRepo, w WritableKeySetRepo) (time.Duration, error) {
|
||||
return syncKeySet(r, w, clockwork.NewRealClock())
|
||||
}
|
||||
|
||||
// syncKeySet copies the keyset from r to the KeySet at w and returns the duration in which the KeySet will expire.
|
||||
// If keyset has already expired, returns a zero duration.
|
||||
func syncKeySet(r ReadableKeySetRepo, w WritableKeySetRepo, clock clockwork.Clock) (exp time.Duration, err error) {
|
||||
var ks KeySet
|
||||
ks, err = r.Get()
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
if ks == nil {
|
||||
err = errors.New("no source KeySet")
|
||||
return
|
||||
}
|
||||
|
||||
if err = w.Set(ks); err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
now := clock.Now()
|
||||
if ks.ExpiresAt().After(now) {
|
||||
exp = ks.ExpiresAt().Sub(now)
|
||||
}
|
||||
return
|
||||
}
|
||||
214
vendor/github.com/coreos/go-oidc/key/sync_test.go
generated
vendored
214
vendor/github.com/coreos/go-oidc/key/sync_test.go
generated
vendored
@@ -1,214 +0,0 @@
|
||||
package key
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"reflect"
|
||||
"sync"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/jonboulle/clockwork"
|
||||
)
|
||||
|
||||
type staticReadableKeySetRepo struct {
|
||||
mu sync.RWMutex
|
||||
ks KeySet
|
||||
err error
|
||||
}
|
||||
|
||||
func (r *staticReadableKeySetRepo) Get() (KeySet, error) {
|
||||
r.mu.RLock()
|
||||
defer r.mu.RUnlock()
|
||||
return r.ks, r.err
|
||||
}
|
||||
|
||||
func (r *staticReadableKeySetRepo) set(ks KeySet, err error) {
|
||||
r.mu.Lock()
|
||||
defer r.mu.Unlock()
|
||||
r.ks, r.err = ks, err
|
||||
}
|
||||
|
||||
func TestKeySyncerSync(t *testing.T) {
|
||||
fc := clockwork.NewFakeClock()
|
||||
now := fc.Now().UTC()
|
||||
|
||||
k1 := generatePrivateKeyStatic(t, 1)
|
||||
k2 := generatePrivateKeyStatic(t, 2)
|
||||
k3 := generatePrivateKeyStatic(t, 3)
|
||||
|
||||
steps := []struct {
|
||||
fromKS KeySet
|
||||
fromErr error
|
||||
advance time.Duration
|
||||
want *PrivateKeySet
|
||||
}{
|
||||
// on startup, first sync should trigger within a second
|
||||
{
|
||||
fromKS: &PrivateKeySet{
|
||||
keys: []*PrivateKey{k1},
|
||||
ActiveKeyID: k1.KeyID,
|
||||
expiresAt: now.Add(10 * time.Second),
|
||||
},
|
||||
advance: time.Second,
|
||||
want: &PrivateKeySet{
|
||||
keys: []*PrivateKey{k1},
|
||||
ActiveKeyID: k1.KeyID,
|
||||
expiresAt: now.Add(10 * time.Second),
|
||||
},
|
||||
},
|
||||
// advance halfway into TTL, triggering sync
|
||||
{
|
||||
fromKS: &PrivateKeySet{
|
||||
keys: []*PrivateKey{k2, k1},
|
||||
ActiveKeyID: k2.KeyID,
|
||||
expiresAt: now.Add(15 * time.Second),
|
||||
},
|
||||
advance: 5 * time.Second,
|
||||
want: &PrivateKeySet{
|
||||
keys: []*PrivateKey{k2, k1},
|
||||
ActiveKeyID: k2.KeyID,
|
||||
expiresAt: now.Add(15 * time.Second),
|
||||
},
|
||||
},
|
||||
|
||||
// advance halfway into TTL, triggering sync that fails
|
||||
{
|
||||
fromErr: errors.New("fail!"),
|
||||
advance: 10 * time.Second,
|
||||
want: &PrivateKeySet{
|
||||
keys: []*PrivateKey{k2, k1},
|
||||
ActiveKeyID: k2.KeyID,
|
||||
expiresAt: now.Add(15 * time.Second),
|
||||
},
|
||||
},
|
||||
|
||||
// sync retries quickly, and succeeds with fixed data
|
||||
{
|
||||
fromKS: &PrivateKeySet{
|
||||
keys: []*PrivateKey{k3, k2, k1},
|
||||
ActiveKeyID: k3.KeyID,
|
||||
expiresAt: now.Add(25 * time.Second),
|
||||
},
|
||||
advance: 3 * time.Second,
|
||||
want: &PrivateKeySet{
|
||||
keys: []*PrivateKey{k3, k2, k1},
|
||||
ActiveKeyID: k3.KeyID,
|
||||
expiresAt: now.Add(25 * time.Second),
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
from := &staticReadableKeySetRepo{}
|
||||
to := NewPrivateKeySetRepo()
|
||||
|
||||
syncer := NewKeySetSyncer(from, to)
|
||||
syncer.clock = fc
|
||||
stop := syncer.Run()
|
||||
defer close(stop)
|
||||
|
||||
for i, st := range steps {
|
||||
from.set(st.fromKS, st.fromErr)
|
||||
|
||||
fc.Advance(st.advance)
|
||||
fc.BlockUntil(1)
|
||||
|
||||
ks, err := to.Get()
|
||||
if err != nil {
|
||||
t.Fatalf("step %d: unable to get keys: %v", i, err)
|
||||
}
|
||||
if !reflect.DeepEqual(st.want, ks) {
|
||||
t.Fatalf("step %d: incorrect state: want=%#v got=%#v", i, st.want, ks)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestSync(t *testing.T) {
|
||||
fc := clockwork.NewFakeClock()
|
||||
now := fc.Now().UTC()
|
||||
|
||||
k1 := generatePrivateKeyStatic(t, 1)
|
||||
k2 := generatePrivateKeyStatic(t, 2)
|
||||
k3 := generatePrivateKeyStatic(t, 3)
|
||||
|
||||
tests := []struct {
|
||||
keySet *PrivateKeySet
|
||||
want time.Duration
|
||||
}{
|
||||
{
|
||||
keySet: &PrivateKeySet{
|
||||
keys: []*PrivateKey{k1},
|
||||
ActiveKeyID: k1.KeyID,
|
||||
expiresAt: now.Add(time.Minute),
|
||||
},
|
||||
want: time.Minute,
|
||||
},
|
||||
{
|
||||
keySet: &PrivateKeySet{
|
||||
keys: []*PrivateKey{k2, k1},
|
||||
ActiveKeyID: k2.KeyID,
|
||||
expiresAt: now.Add(time.Minute),
|
||||
},
|
||||
want: time.Minute,
|
||||
},
|
||||
{
|
||||
keySet: &PrivateKeySet{
|
||||
keys: []*PrivateKey{k3, k2, k1},
|
||||
ActiveKeyID: k2.KeyID,
|
||||
expiresAt: now.Add(time.Minute),
|
||||
},
|
||||
want: time.Minute,
|
||||
},
|
||||
{
|
||||
keySet: &PrivateKeySet{
|
||||
keys: []*PrivateKey{k2, k1},
|
||||
ActiveKeyID: k2.KeyID,
|
||||
expiresAt: now.Add(time.Hour),
|
||||
},
|
||||
want: time.Hour,
|
||||
},
|
||||
{
|
||||
keySet: &PrivateKeySet{
|
||||
keys: []*PrivateKey{k1},
|
||||
ActiveKeyID: k1.KeyID,
|
||||
expiresAt: now.Add(-time.Hour),
|
||||
},
|
||||
want: 0,
|
||||
},
|
||||
}
|
||||
|
||||
for i, tt := range tests {
|
||||
from := NewPrivateKeySetRepo()
|
||||
to := NewPrivateKeySetRepo()
|
||||
|
||||
err := from.Set(tt.keySet)
|
||||
if err != nil {
|
||||
t.Errorf("case %d: unexpected error: %v", i, err)
|
||||
continue
|
||||
}
|
||||
exp, err := syncKeySet(from, to, fc)
|
||||
if err != nil {
|
||||
t.Errorf("case %d: unexpected error: %v", i, err)
|
||||
continue
|
||||
}
|
||||
|
||||
if tt.want != exp {
|
||||
t.Errorf("case %d: want=%v got=%v", i, tt.want, exp)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestSyncFail(t *testing.T) {
|
||||
tests := []error{
|
||||
nil,
|
||||
errors.New("fail!"),
|
||||
}
|
||||
|
||||
for i, tt := range tests {
|
||||
from := &staticReadableKeySetRepo{ks: nil, err: tt}
|
||||
to := NewPrivateKeySetRepo()
|
||||
|
||||
if _, err := syncKeySet(from, to, clockwork.NewFakeClock()); err == nil {
|
||||
t.Errorf("case %d: expected non-nil error", i)
|
||||
}
|
||||
}
|
||||
}
|
||||
2
vendor/github.com/coreos/go-oidc/oauth2/doc.go
generated
vendored
2
vendor/github.com/coreos/go-oidc/oauth2/doc.go
generated
vendored
@@ -1,2 +0,0 @@
|
||||
// Package oauth2 is DEPRECATED. Use golang.org/x/oauth instead.
|
||||
package oauth2
|
||||
29
vendor/github.com/coreos/go-oidc/oauth2/error.go
generated
vendored
29
vendor/github.com/coreos/go-oidc/oauth2/error.go
generated
vendored
@@ -1,29 +0,0 @@
|
||||
package oauth2
|
||||
|
||||
const (
|
||||
ErrorAccessDenied = "access_denied"
|
||||
ErrorInvalidClient = "invalid_client"
|
||||
ErrorInvalidGrant = "invalid_grant"
|
||||
ErrorInvalidRequest = "invalid_request"
|
||||
ErrorServerError = "server_error"
|
||||
ErrorUnauthorizedClient = "unauthorized_client"
|
||||
ErrorUnsupportedGrantType = "unsupported_grant_type"
|
||||
ErrorUnsupportedResponseType = "unsupported_response_type"
|
||||
)
|
||||
|
||||
type Error struct {
|
||||
Type string `json:"error"`
|
||||
Description string `json:"error_description,omitempty"`
|
||||
State string `json:"state,omitempty"`
|
||||
}
|
||||
|
||||
func (e *Error) Error() string {
|
||||
if e.Description != "" {
|
||||
return e.Type + ": " + e.Description
|
||||
}
|
||||
return e.Type
|
||||
}
|
||||
|
||||
func NewError(typ string) *Error {
|
||||
return &Error{Type: typ}
|
||||
}
|
||||
416
vendor/github.com/coreos/go-oidc/oauth2/oauth2.go
generated
vendored
416
vendor/github.com/coreos/go-oidc/oauth2/oauth2.go
generated
vendored
@@ -1,416 +0,0 @@
|
||||
package oauth2
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io/ioutil"
|
||||
"mime"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"sort"
|
||||
"strconv"
|
||||
"strings"
|
||||
|
||||
phttp "github.com/coreos/go-oidc/http"
|
||||
)
|
||||
|
||||
// ResponseTypesEqual compares two response_type values. If either
|
||||
// contains a space, it is treated as an unordered list. For example,
|
||||
// comparing "code id_token" and "id_token code" would evaluate to true.
|
||||
func ResponseTypesEqual(r1, r2 string) bool {
|
||||
if !strings.Contains(r1, " ") || !strings.Contains(r2, " ") {
|
||||
// fast route, no split needed
|
||||
return r1 == r2
|
||||
}
|
||||
|
||||
// split, sort, and compare
|
||||
r1Fields := strings.Fields(r1)
|
||||
r2Fields := strings.Fields(r2)
|
||||
if len(r1Fields) != len(r2Fields) {
|
||||
return false
|
||||
}
|
||||
sort.Strings(r1Fields)
|
||||
sort.Strings(r2Fields)
|
||||
for i, r1Field := range r1Fields {
|
||||
if r1Field != r2Fields[i] {
|
||||
return false
|
||||
}
|
||||
}
|
||||
return true
|
||||
}
|
||||
|
||||
const (
|
||||
// OAuth2.0 response types registered by OIDC.
|
||||
//
|
||||
// See: https://openid.net/specs/oauth-v2-multiple-response-types-1_0.html#RegistryContents
|
||||
ResponseTypeCode = "code"
|
||||
ResponseTypeCodeIDToken = "code id_token"
|
||||
ResponseTypeCodeIDTokenToken = "code id_token token"
|
||||
ResponseTypeIDToken = "id_token"
|
||||
ResponseTypeIDTokenToken = "id_token token"
|
||||
ResponseTypeToken = "token"
|
||||
ResponseTypeNone = "none"
|
||||
)
|
||||
|
||||
const (
|
||||
GrantTypeAuthCode = "authorization_code"
|
||||
GrantTypeClientCreds = "client_credentials"
|
||||
GrantTypeUserCreds = "password"
|
||||
GrantTypeImplicit = "implicit"
|
||||
GrantTypeRefreshToken = "refresh_token"
|
||||
|
||||
AuthMethodClientSecretPost = "client_secret_post"
|
||||
AuthMethodClientSecretBasic = "client_secret_basic"
|
||||
AuthMethodClientSecretJWT = "client_secret_jwt"
|
||||
AuthMethodPrivateKeyJWT = "private_key_jwt"
|
||||
)
|
||||
|
||||
type Config struct {
|
||||
Credentials ClientCredentials
|
||||
Scope []string
|
||||
RedirectURL string
|
||||
AuthURL string
|
||||
TokenURL string
|
||||
|
||||
// Must be one of the AuthMethodXXX methods above. Right now, only
|
||||
// AuthMethodClientSecretPost and AuthMethodClientSecretBasic are supported.
|
||||
AuthMethod string
|
||||
}
|
||||
|
||||
type Client struct {
|
||||
hc phttp.Client
|
||||
creds ClientCredentials
|
||||
scope []string
|
||||
authURL *url.URL
|
||||
redirectURL *url.URL
|
||||
tokenURL *url.URL
|
||||
authMethod string
|
||||
}
|
||||
|
||||
type ClientCredentials struct {
|
||||
ID string
|
||||
Secret string
|
||||
}
|
||||
|
||||
func NewClient(hc phttp.Client, cfg Config) (c *Client, err error) {
|
||||
if len(cfg.Credentials.ID) == 0 {
|
||||
err = errors.New("missing client id")
|
||||
return
|
||||
}
|
||||
|
||||
if len(cfg.Credentials.Secret) == 0 {
|
||||
err = errors.New("missing client secret")
|
||||
return
|
||||
}
|
||||
|
||||
if cfg.AuthMethod == "" {
|
||||
cfg.AuthMethod = AuthMethodClientSecretBasic
|
||||
} else if cfg.AuthMethod != AuthMethodClientSecretPost && cfg.AuthMethod != AuthMethodClientSecretBasic {
|
||||
err = fmt.Errorf("auth method %q is not supported", cfg.AuthMethod)
|
||||
return
|
||||
}
|
||||
|
||||
au, err := phttp.ParseNonEmptyURL(cfg.AuthURL)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
tu, err := phttp.ParseNonEmptyURL(cfg.TokenURL)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
// Allow empty redirect URL in the case where the client
|
||||
// only needs to verify a given token.
|
||||
ru, err := url.Parse(cfg.RedirectURL)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
c = &Client{
|
||||
creds: cfg.Credentials,
|
||||
scope: cfg.Scope,
|
||||
redirectURL: ru,
|
||||
authURL: au,
|
||||
tokenURL: tu,
|
||||
hc: hc,
|
||||
authMethod: cfg.AuthMethod,
|
||||
}
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
// Return the embedded HTTP client
|
||||
func (c *Client) HttpClient() phttp.Client {
|
||||
return c.hc
|
||||
}
|
||||
|
||||
// Generate the url for initial redirect to oauth provider.
|
||||
func (c *Client) AuthCodeURL(state, accessType, prompt string) string {
|
||||
v := c.commonURLValues()
|
||||
v.Set("state", state)
|
||||
if strings.ToLower(accessType) == "offline" {
|
||||
v.Set("access_type", "offline")
|
||||
}
|
||||
|
||||
if prompt != "" {
|
||||
v.Set("prompt", prompt)
|
||||
}
|
||||
v.Set("response_type", "code")
|
||||
|
||||
q := v.Encode()
|
||||
u := *c.authURL
|
||||
if u.RawQuery == "" {
|
||||
u.RawQuery = q
|
||||
} else {
|
||||
u.RawQuery += "&" + q
|
||||
}
|
||||
return u.String()
|
||||
}
|
||||
|
||||
func (c *Client) commonURLValues() url.Values {
|
||||
return url.Values{
|
||||
"redirect_uri": {c.redirectURL.String()},
|
||||
"scope": {strings.Join(c.scope, " ")},
|
||||
"client_id": {c.creds.ID},
|
||||
}
|
||||
}
|
||||
|
||||
func (c *Client) newAuthenticatedRequest(urlToken string, values url.Values) (*http.Request, error) {
|
||||
var req *http.Request
|
||||
var err error
|
||||
switch c.authMethod {
|
||||
case AuthMethodClientSecretPost:
|
||||
values.Set("client_secret", c.creds.Secret)
|
||||
req, err = http.NewRequest("POST", urlToken, strings.NewReader(values.Encode()))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
case AuthMethodClientSecretBasic:
|
||||
req, err = http.NewRequest("POST", urlToken, strings.NewReader(values.Encode()))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
encodedID := url.QueryEscape(c.creds.ID)
|
||||
encodedSecret := url.QueryEscape(c.creds.Secret)
|
||||
req.SetBasicAuth(encodedID, encodedSecret)
|
||||
default:
|
||||
panic("misconfigured client: auth method not supported")
|
||||
}
|
||||
|
||||
req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
|
||||
return req, nil
|
||||
|
||||
}
|
||||
|
||||
// ClientCredsToken posts the client id and secret to obtain a token scoped to the OAuth2 client via the "client_credentials" grant type.
|
||||
// May not be supported by all OAuth2 servers.
|
||||
func (c *Client) ClientCredsToken(scope []string) (result TokenResponse, err error) {
|
||||
v := url.Values{
|
||||
"scope": {strings.Join(scope, " ")},
|
||||
"grant_type": {GrantTypeClientCreds},
|
||||
}
|
||||
|
||||
req, err := c.newAuthenticatedRequest(c.tokenURL.String(), v)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
resp, err := c.hc.Do(req)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
return parseTokenResponse(resp)
|
||||
}
|
||||
|
||||
// UserCredsToken posts the username and password to obtain a token scoped to the OAuth2 client via the "password" grant_type
|
||||
// May not be supported by all OAuth2 servers.
|
||||
func (c *Client) UserCredsToken(username, password string) (result TokenResponse, err error) {
|
||||
v := url.Values{
|
||||
"scope": {strings.Join(c.scope, " ")},
|
||||
"grant_type": {GrantTypeUserCreds},
|
||||
"username": {username},
|
||||
"password": {password},
|
||||
}
|
||||
|
||||
req, err := c.newAuthenticatedRequest(c.tokenURL.String(), v)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
resp, err := c.hc.Do(req)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
return parseTokenResponse(resp)
|
||||
}
|
||||
|
||||
// RequestToken requests a token from the Token Endpoint with the specified grantType.
|
||||
// If 'grantType' == GrantTypeAuthCode, then 'value' should be the authorization code.
|
||||
// If 'grantType' == GrantTypeRefreshToken, then 'value' should be the refresh token.
|
||||
func (c *Client) RequestToken(grantType, value string) (result TokenResponse, err error) {
|
||||
v := c.commonURLValues()
|
||||
|
||||
v.Set("grant_type", grantType)
|
||||
v.Set("client_secret", c.creds.Secret)
|
||||
switch grantType {
|
||||
case GrantTypeAuthCode:
|
||||
v.Set("code", value)
|
||||
case GrantTypeRefreshToken:
|
||||
v.Set("refresh_token", value)
|
||||
default:
|
||||
err = fmt.Errorf("unsupported grant_type: %v", grantType)
|
||||
return
|
||||
}
|
||||
|
||||
req, err := c.newAuthenticatedRequest(c.tokenURL.String(), v)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
resp, err := c.hc.Do(req)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
return parseTokenResponse(resp)
|
||||
}
|
||||
|
||||
func parseTokenResponse(resp *http.Response) (result TokenResponse, err error) {
|
||||
body, err := ioutil.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
badStatusCode := resp.StatusCode < 200 || resp.StatusCode > 299
|
||||
|
||||
contentType, _, err := mime.ParseMediaType(resp.Header.Get("Content-Type"))
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
result = TokenResponse{
|
||||
RawBody: body,
|
||||
}
|
||||
|
||||
newError := func(typ, desc, state string) error {
|
||||
if typ == "" {
|
||||
return fmt.Errorf("unrecognized error %s", body)
|
||||
}
|
||||
return &Error{typ, desc, state}
|
||||
}
|
||||
|
||||
if contentType == "application/x-www-form-urlencoded" || contentType == "text/plain" {
|
||||
var vals url.Values
|
||||
vals, err = url.ParseQuery(string(body))
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
if error := vals.Get("error"); error != "" || badStatusCode {
|
||||
err = newError(error, vals.Get("error_description"), vals.Get("state"))
|
||||
return
|
||||
}
|
||||
e := vals.Get("expires_in")
|
||||
if e == "" {
|
||||
e = vals.Get("expires")
|
||||
}
|
||||
if e != "" {
|
||||
result.Expires, err = strconv.Atoi(e)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
}
|
||||
result.AccessToken = vals.Get("access_token")
|
||||
result.TokenType = vals.Get("token_type")
|
||||
result.IDToken = vals.Get("id_token")
|
||||
result.RefreshToken = vals.Get("refresh_token")
|
||||
result.Scope = vals.Get("scope")
|
||||
} else {
|
||||
var r struct {
|
||||
AccessToken string `json:"access_token"`
|
||||
TokenType string `json:"token_type"`
|
||||
IDToken string `json:"id_token"`
|
||||
RefreshToken string `json:"refresh_token"`
|
||||
Scope string `json:"scope"`
|
||||
State string `json:"state"`
|
||||
ExpiresIn json.Number `json:"expires_in"` // Azure AD returns string
|
||||
Expires int `json:"expires"`
|
||||
Error string `json:"error"`
|
||||
Desc string `json:"error_description"`
|
||||
}
|
||||
if err = json.Unmarshal(body, &r); err != nil {
|
||||
return
|
||||
}
|
||||
if r.Error != "" || badStatusCode {
|
||||
err = newError(r.Error, r.Desc, r.State)
|
||||
return
|
||||
}
|
||||
result.AccessToken = r.AccessToken
|
||||
result.TokenType = r.TokenType
|
||||
result.IDToken = r.IDToken
|
||||
result.RefreshToken = r.RefreshToken
|
||||
result.Scope = r.Scope
|
||||
if expiresIn, err := r.ExpiresIn.Int64(); err != nil {
|
||||
result.Expires = r.Expires
|
||||
} else {
|
||||
result.Expires = int(expiresIn)
|
||||
}
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
type TokenResponse struct {
|
||||
AccessToken string
|
||||
TokenType string
|
||||
Expires int
|
||||
IDToken string
|
||||
RefreshToken string // OPTIONAL.
|
||||
Scope string // OPTIONAL, if identical to the scope requested by the client, otherwise, REQUIRED.
|
||||
RawBody []byte // In case callers need some other non-standard info from the token response
|
||||
}
|
||||
|
||||
type AuthCodeRequest struct {
|
||||
ResponseType string
|
||||
ClientID string
|
||||
RedirectURL *url.URL
|
||||
Scope []string
|
||||
State string
|
||||
}
|
||||
|
||||
func ParseAuthCodeRequest(q url.Values) (AuthCodeRequest, error) {
|
||||
acr := AuthCodeRequest{
|
||||
ResponseType: q.Get("response_type"),
|
||||
ClientID: q.Get("client_id"),
|
||||
State: q.Get("state"),
|
||||
Scope: make([]string, 0),
|
||||
}
|
||||
|
||||
qs := strings.TrimSpace(q.Get("scope"))
|
||||
if qs != "" {
|
||||
acr.Scope = strings.Split(qs, " ")
|
||||
}
|
||||
|
||||
err := func() error {
|
||||
if acr.ClientID == "" {
|
||||
return NewError(ErrorInvalidRequest)
|
||||
}
|
||||
|
||||
redirectURL := q.Get("redirect_uri")
|
||||
if redirectURL != "" {
|
||||
ru, err := url.Parse(redirectURL)
|
||||
if err != nil {
|
||||
return NewError(ErrorInvalidRequest)
|
||||
}
|
||||
acr.RedirectURL = ru
|
||||
}
|
||||
|
||||
return nil
|
||||
}()
|
||||
|
||||
return acr, err
|
||||
}
|
||||
518
vendor/github.com/coreos/go-oidc/oauth2/oauth2_test.go
generated
vendored
518
vendor/github.com/coreos/go-oidc/oauth2/oauth2_test.go
generated
vendored
@@ -1,518 +0,0 @@
|
||||
package oauth2
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"io/ioutil"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"reflect"
|
||||
"strconv"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
phttp "github.com/coreos/go-oidc/http"
|
||||
)
|
||||
|
||||
func TestResponseTypesEqual(t *testing.T) {
|
||||
tests := []struct {
|
||||
r1, r2 string
|
||||
want bool
|
||||
}{
|
||||
{"code", "code", true},
|
||||
{"id_token", "code", false},
|
||||
{"code token", "token code", true},
|
||||
{"code token", "code token", true},
|
||||
{"foo", "bar code", false},
|
||||
{"code token id_token", "token id_token code", true},
|
||||
{"code token id_token", "token id_token code zoo", false},
|
||||
}
|
||||
|
||||
for i, tt := range tests {
|
||||
got1 := ResponseTypesEqual(tt.r1, tt.r2)
|
||||
got2 := ResponseTypesEqual(tt.r2, tt.r1)
|
||||
if got1 != got2 {
|
||||
t.Errorf("case %d: got different answers with different orders", i)
|
||||
}
|
||||
if tt.want != got1 {
|
||||
t.Errorf("case %d: want=%t, got=%t", i, tt.want, got1)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestParseAuthCodeRequest(t *testing.T) {
|
||||
tests := []struct {
|
||||
query url.Values
|
||||
wantACR AuthCodeRequest
|
||||
wantErr error
|
||||
}{
|
||||
// no redirect_uri
|
||||
{
|
||||
query: url.Values{
|
||||
"response_type": []string{"code"},
|
||||
"scope": []string{"foo bar baz"},
|
||||
"client_id": []string{"XXX"},
|
||||
"state": []string{"pants"},
|
||||
},
|
||||
wantACR: AuthCodeRequest{
|
||||
ResponseType: "code",
|
||||
ClientID: "XXX",
|
||||
Scope: []string{"foo", "bar", "baz"},
|
||||
State: "pants",
|
||||
RedirectURL: nil,
|
||||
},
|
||||
},
|
||||
|
||||
// with redirect_uri
|
||||
{
|
||||
query: url.Values{
|
||||
"response_type": []string{"code"},
|
||||
"redirect_uri": []string{"https://127.0.0.1:5555/callback?foo=bar"},
|
||||
"scope": []string{"foo bar baz"},
|
||||
"client_id": []string{"XXX"},
|
||||
"state": []string{"pants"},
|
||||
},
|
||||
wantACR: AuthCodeRequest{
|
||||
ResponseType: "code",
|
||||
ClientID: "XXX",
|
||||
Scope: []string{"foo", "bar", "baz"},
|
||||
State: "pants",
|
||||
RedirectURL: &url.URL{
|
||||
Scheme: "https",
|
||||
Host: "127.0.0.1:5555",
|
||||
Path: "/callback",
|
||||
RawQuery: "foo=bar",
|
||||
},
|
||||
},
|
||||
},
|
||||
|
||||
// unsupported response_type doesn't trigger error
|
||||
{
|
||||
query: url.Values{
|
||||
"response_type": []string{"token"},
|
||||
"redirect_uri": []string{"https://127.0.0.1:5555/callback?foo=bar"},
|
||||
"scope": []string{"foo bar baz"},
|
||||
"client_id": []string{"XXX"},
|
||||
"state": []string{"pants"},
|
||||
},
|
||||
wantACR: AuthCodeRequest{
|
||||
ResponseType: "token",
|
||||
ClientID: "XXX",
|
||||
Scope: []string{"foo", "bar", "baz"},
|
||||
State: "pants",
|
||||
RedirectURL: &url.URL{
|
||||
Scheme: "https",
|
||||
Host: "127.0.0.1:5555",
|
||||
Path: "/callback",
|
||||
RawQuery: "foo=bar",
|
||||
},
|
||||
},
|
||||
},
|
||||
|
||||
// unparseable redirect_uri
|
||||
{
|
||||
query: url.Values{
|
||||
"response_type": []string{"code"},
|
||||
"redirect_uri": []string{":"},
|
||||
"scope": []string{"foo bar baz"},
|
||||
"client_id": []string{"XXX"},
|
||||
"state": []string{"pants"},
|
||||
},
|
||||
wantACR: AuthCodeRequest{
|
||||
ResponseType: "code",
|
||||
ClientID: "XXX",
|
||||
Scope: []string{"foo", "bar", "baz"},
|
||||
State: "pants",
|
||||
},
|
||||
wantErr: NewError(ErrorInvalidRequest),
|
||||
},
|
||||
|
||||
// no client_id, redirect_uri not parsed
|
||||
{
|
||||
query: url.Values{
|
||||
"response_type": []string{"code"},
|
||||
"redirect_uri": []string{"https://127.0.0.1:5555/callback?foo=bar"},
|
||||
"scope": []string{"foo bar baz"},
|
||||
"client_id": []string{},
|
||||
"state": []string{"pants"},
|
||||
},
|
||||
wantACR: AuthCodeRequest{
|
||||
ResponseType: "code",
|
||||
ClientID: "",
|
||||
Scope: []string{"foo", "bar", "baz"},
|
||||
State: "pants",
|
||||
RedirectURL: nil,
|
||||
},
|
||||
wantErr: NewError(ErrorInvalidRequest),
|
||||
},
|
||||
}
|
||||
|
||||
for i, tt := range tests {
|
||||
got, err := ParseAuthCodeRequest(tt.query)
|
||||
if !reflect.DeepEqual(tt.wantErr, err) {
|
||||
t.Errorf("case %d: incorrect error value: want=%q got=%q", i, tt.wantErr, err)
|
||||
}
|
||||
|
||||
if !reflect.DeepEqual(tt.wantACR, got) {
|
||||
t.Errorf("case %d: incorrect AuthCodeRequest value: want=%#v got=%#v", i, tt.wantACR, got)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
type fakeBadClient struct {
|
||||
Request *http.Request
|
||||
err error
|
||||
}
|
||||
|
||||
func (f *fakeBadClient) Do(r *http.Request) (*http.Response, error) {
|
||||
f.Request = r
|
||||
return nil, f.err
|
||||
}
|
||||
|
||||
func TestClientCredsToken(t *testing.T) {
|
||||
hc := &fakeBadClient{nil, errors.New("error")}
|
||||
cfg := Config{
|
||||
Credentials: ClientCredentials{ID: "c#id", Secret: "c secret"},
|
||||
Scope: []string{"foo-scope", "bar-scope"},
|
||||
TokenURL: "http://example.com/token",
|
||||
AuthMethod: AuthMethodClientSecretBasic,
|
||||
RedirectURL: "http://example.com/redirect",
|
||||
AuthURL: "http://example.com/auth",
|
||||
}
|
||||
|
||||
c, err := NewClient(hc, cfg)
|
||||
if err != nil {
|
||||
t.Errorf("unexpected error %v", err)
|
||||
}
|
||||
|
||||
scope := []string{"openid"}
|
||||
c.ClientCredsToken(scope)
|
||||
if hc.Request == nil {
|
||||
t.Error("request is empty")
|
||||
}
|
||||
|
||||
tu := hc.Request.URL.String()
|
||||
if cfg.TokenURL != tu {
|
||||
t.Errorf("wrong token url, want=%v, got=%v", cfg.TokenURL, tu)
|
||||
}
|
||||
|
||||
ct := hc.Request.Header.Get("Content-Type")
|
||||
if ct != "application/x-www-form-urlencoded" {
|
||||
t.Errorf("wrong content-type, want=application/x-www-form-urlencoded, got=%v", ct)
|
||||
}
|
||||
|
||||
cid, secret, ok := phttp.BasicAuth(hc.Request)
|
||||
if !ok {
|
||||
t.Error("unexpected error parsing basic auth")
|
||||
}
|
||||
|
||||
if url.QueryEscape(cfg.Credentials.ID) != cid {
|
||||
t.Errorf("wrong client ID, want=%v, got=%v", cfg.Credentials.ID, cid)
|
||||
}
|
||||
|
||||
if url.QueryEscape(cfg.Credentials.Secret) != secret {
|
||||
t.Errorf("wrong client secret, want=%v, got=%v", cfg.Credentials.Secret, secret)
|
||||
}
|
||||
|
||||
err = hc.Request.ParseForm()
|
||||
if err != nil {
|
||||
t.Error("unexpected error parsing form")
|
||||
}
|
||||
|
||||
gt := hc.Request.PostForm.Get("grant_type")
|
||||
if gt != GrantTypeClientCreds {
|
||||
t.Errorf("wrong grant_type, want=%v, got=%v", GrantTypeClientCreds, gt)
|
||||
}
|
||||
|
||||
sc := strings.Split(hc.Request.PostForm.Get("scope"), " ")
|
||||
if !reflect.DeepEqual(scope, sc) {
|
||||
t.Errorf("wrong scope, want=%v, got=%v", scope, sc)
|
||||
}
|
||||
}
|
||||
|
||||
func TestUserCredsToken(t *testing.T) {
|
||||
hc := &fakeBadClient{nil, errors.New("error")}
|
||||
cfg := Config{
|
||||
Credentials: ClientCredentials{ID: "c#id", Secret: "c secret"},
|
||||
Scope: []string{"foo-scope", "bar-scope"},
|
||||
TokenURL: "http://example.com/token",
|
||||
AuthMethod: AuthMethodClientSecretBasic,
|
||||
RedirectURL: "http://example.com/redirect",
|
||||
AuthURL: "http://example.com/auth",
|
||||
}
|
||||
|
||||
c, err := NewClient(hc, cfg)
|
||||
if err != nil {
|
||||
t.Errorf("unexpected error %v", err)
|
||||
}
|
||||
|
||||
c.UserCredsToken("username", "password")
|
||||
if hc.Request == nil {
|
||||
t.Error("request is empty")
|
||||
}
|
||||
|
||||
tu := hc.Request.URL.String()
|
||||
if cfg.TokenURL != tu {
|
||||
t.Errorf("wrong token url, want=%v, got=%v", cfg.TokenURL, tu)
|
||||
}
|
||||
|
||||
ct := hc.Request.Header.Get("Content-Type")
|
||||
if ct != "application/x-www-form-urlencoded" {
|
||||
t.Errorf("wrong content-type, want=application/x-www-form-urlencoded, got=%v", ct)
|
||||
}
|
||||
|
||||
cid, secret, ok := phttp.BasicAuth(hc.Request)
|
||||
if !ok {
|
||||
t.Error("unexpected error parsing basic auth")
|
||||
}
|
||||
|
||||
if url.QueryEscape(cfg.Credentials.ID) != cid {
|
||||
t.Errorf("wrong client ID, want=%v, got=%v", cfg.Credentials.ID, cid)
|
||||
}
|
||||
|
||||
if url.QueryEscape(cfg.Credentials.Secret) != secret {
|
||||
t.Errorf("wrong client secret, want=%v, got=%v", cfg.Credentials.Secret, secret)
|
||||
}
|
||||
|
||||
err = hc.Request.ParseForm()
|
||||
if err != nil {
|
||||
t.Error("unexpected error parsing form")
|
||||
}
|
||||
|
||||
gt := hc.Request.PostForm.Get("grant_type")
|
||||
if gt != GrantTypeUserCreds {
|
||||
t.Errorf("wrong grant_type, want=%v, got=%v", GrantTypeUserCreds, gt)
|
||||
}
|
||||
|
||||
sc := strings.Split(hc.Request.PostForm.Get("scope"), " ")
|
||||
if !reflect.DeepEqual(c.scope, sc) {
|
||||
t.Errorf("wrong scope, want=%v, got=%v", c.scope, sc)
|
||||
}
|
||||
}
|
||||
|
||||
func TestNewAuthenticatedRequest(t *testing.T) {
|
||||
tests := []struct {
|
||||
authMethod string
|
||||
url string
|
||||
values url.Values
|
||||
}{
|
||||
{
|
||||
authMethod: AuthMethodClientSecretBasic,
|
||||
url: "http://example.com/token",
|
||||
values: url.Values{},
|
||||
},
|
||||
{
|
||||
authMethod: AuthMethodClientSecretPost,
|
||||
url: "http://example.com/token",
|
||||
values: url.Values{},
|
||||
},
|
||||
}
|
||||
|
||||
for i, tt := range tests {
|
||||
cfg := Config{
|
||||
Credentials: ClientCredentials{ID: "c#id", Secret: "c secret"},
|
||||
Scope: []string{"foo-scope", "bar-scope"},
|
||||
TokenURL: "http://example.com/token",
|
||||
AuthURL: "http://example.com/auth",
|
||||
RedirectURL: "http://example.com/redirect",
|
||||
AuthMethod: tt.authMethod,
|
||||
}
|
||||
c, err := NewClient(nil, cfg)
|
||||
req, err := c.newAuthenticatedRequest(tt.url, tt.values)
|
||||
if err != nil {
|
||||
t.Errorf("case %d: unexpected error: %v", i, err)
|
||||
continue
|
||||
}
|
||||
err = req.ParseForm()
|
||||
if err != nil {
|
||||
t.Errorf("case %d: want nil err, got %v", i, err)
|
||||
}
|
||||
|
||||
if tt.authMethod == AuthMethodClientSecretBasic {
|
||||
cid, secret, ok := phttp.BasicAuth(req)
|
||||
if !ok {
|
||||
t.Errorf("case %d: !ok parsing Basic Auth headers", i)
|
||||
continue
|
||||
}
|
||||
if cid != url.QueryEscape(cfg.Credentials.ID) {
|
||||
t.Errorf("case %d: want CID == %q, got CID == %q", i, cfg.Credentials.ID, cid)
|
||||
}
|
||||
if secret != url.QueryEscape(cfg.Credentials.Secret) {
|
||||
t.Errorf("case %d: want secret == %q, got secret == %q", i, cfg.Credentials.Secret, secret)
|
||||
}
|
||||
} else if tt.authMethod == AuthMethodClientSecretPost {
|
||||
if req.PostFormValue("client_secret") != cfg.Credentials.Secret {
|
||||
t.Errorf("case %d: want client_secret == %q, got client_secret == %q",
|
||||
i, cfg.Credentials.Secret, req.PostFormValue("client_secret"))
|
||||
}
|
||||
}
|
||||
|
||||
for k, v := range tt.values {
|
||||
if !reflect.DeepEqual(v, req.PostForm[k]) {
|
||||
t.Errorf("case %d: key:%q want==%q, got==%q", i, k, v, req.PostForm[k])
|
||||
}
|
||||
}
|
||||
|
||||
if req.URL.String() != tt.url {
|
||||
t.Errorf("case %d: want URL==%q, got URL==%q", i, tt.url, req.URL.String())
|
||||
}
|
||||
|
||||
}
|
||||
}
|
||||
|
||||
func TestParseTokenResponse(t *testing.T) {
|
||||
type response struct {
|
||||
body string
|
||||
contentType string
|
||||
statusCode int // defaults to http.StatusOK
|
||||
}
|
||||
tests := []struct {
|
||||
resp response
|
||||
wantResp TokenResponse
|
||||
wantError *Error
|
||||
}{
|
||||
{
|
||||
resp: response{
|
||||
body: "{ \"error\": \"invalid_client\", \"state\": \"foo\" }",
|
||||
contentType: "application/json",
|
||||
statusCode: http.StatusBadRequest,
|
||||
},
|
||||
wantError: &Error{Type: "invalid_client", State: "foo"},
|
||||
},
|
||||
{
|
||||
resp: response{
|
||||
body: "{ \"error\": \"invalid_request\", \"state\": \"bar\" }",
|
||||
contentType: "application/json",
|
||||
statusCode: http.StatusBadRequest,
|
||||
},
|
||||
wantError: &Error{Type: "invalid_request", State: "bar"},
|
||||
},
|
||||
{
|
||||
// Actual response from bitbucket
|
||||
resp: response{
|
||||
body: `{"error_description": "Invalid OAuth client credentials", "error": "unauthorized_client"}`,
|
||||
contentType: "application/json",
|
||||
statusCode: http.StatusBadRequest,
|
||||
},
|
||||
wantError: &Error{Type: "unauthorized_client", Description: "Invalid OAuth client credentials"},
|
||||
},
|
||||
{
|
||||
// Actual response from github
|
||||
resp: response{
|
||||
body: `error=incorrect_client_credentials&error_description=The+client_id+and%2For+client_secret+passed+are+incorrect.&error_uri=https%3A%2F%2Fdeveloper.github.com%2Fv3%2Foauth%2F%23incorrect-client-credentials`,
|
||||
contentType: "application/x-www-form-urlencoded; charset=utf-8",
|
||||
},
|
||||
wantError: &Error{Type: "incorrect_client_credentials", Description: "The client_id and/or client_secret passed are incorrect."},
|
||||
},
|
||||
{
|
||||
resp: response{
|
||||
body: `{"access_token":"e72e16c7e42f292c6912e7710c838347ae178b4a", "scope":"repo,gist", "token_type":"bearer"}`,
|
||||
contentType: "application/json",
|
||||
},
|
||||
wantResp: TokenResponse{
|
||||
AccessToken: "e72e16c7e42f292c6912e7710c838347ae178b4a",
|
||||
TokenType: "bearer",
|
||||
Scope: "repo,gist",
|
||||
},
|
||||
},
|
||||
{
|
||||
resp: response{
|
||||
body: `access_token=e72e16c7e42f292c6912e7710c838347ae178b4a&scope=user%2Cgist&token_type=bearer`,
|
||||
contentType: "application/x-www-form-urlencoded",
|
||||
},
|
||||
wantResp: TokenResponse{
|
||||
AccessToken: "e72e16c7e42f292c6912e7710c838347ae178b4a",
|
||||
TokenType: "bearer",
|
||||
Scope: "user,gist",
|
||||
},
|
||||
},
|
||||
{
|
||||
resp: response{
|
||||
body: `{"access_token":"foo","id_token":"bar","expires_in":200,"token_type":"bearer","refresh_token":"spam"}`,
|
||||
contentType: "application/json; charset=utf-8",
|
||||
},
|
||||
wantResp: TokenResponse{
|
||||
AccessToken: "foo",
|
||||
IDToken: "bar",
|
||||
Expires: 200,
|
||||
TokenType: "bearer",
|
||||
RefreshToken: "spam",
|
||||
},
|
||||
},
|
||||
{
|
||||
// Azure AD returns "expires_in" value as string
|
||||
resp: response{
|
||||
body: `{"access_token":"foo","id_token":"bar","expires_in":"300","token_type":"bearer","refresh_token":"spam"}`,
|
||||
contentType: "application/json; charset=utf-8",
|
||||
},
|
||||
wantResp: TokenResponse{
|
||||
AccessToken: "foo",
|
||||
IDToken: "bar",
|
||||
Expires: 300,
|
||||
TokenType: "bearer",
|
||||
RefreshToken: "spam",
|
||||
},
|
||||
},
|
||||
{
|
||||
resp: response{
|
||||
body: `{"access_token":"foo","id_token":"bar","expires":200,"token_type":"bearer","refresh_token":"spam"}`,
|
||||
contentType: "application/json; charset=utf-8",
|
||||
},
|
||||
wantResp: TokenResponse{
|
||||
AccessToken: "foo",
|
||||
IDToken: "bar",
|
||||
Expires: 200,
|
||||
TokenType: "bearer",
|
||||
RefreshToken: "spam",
|
||||
},
|
||||
},
|
||||
{
|
||||
resp: response{
|
||||
body: `access_token=foo&id_token=bar&expires_in=200&token_type=bearer&refresh_token=spam`,
|
||||
contentType: "application/x-www-form-urlencoded",
|
||||
},
|
||||
wantResp: TokenResponse{
|
||||
AccessToken: "foo",
|
||||
IDToken: "bar",
|
||||
Expires: 200,
|
||||
TokenType: "bearer",
|
||||
RefreshToken: "spam",
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
for i, tt := range tests {
|
||||
r := &http.Response{
|
||||
StatusCode: http.StatusOK,
|
||||
Header: http.Header{
|
||||
"Content-Type": []string{tt.resp.contentType},
|
||||
"Content-Length": []string{strconv.Itoa(len([]byte(tt.resp.body)))},
|
||||
},
|
||||
Body: ioutil.NopCloser(strings.NewReader(tt.resp.body)),
|
||||
ContentLength: int64(len([]byte(tt.resp.body))),
|
||||
}
|
||||
if tt.resp.statusCode != 0 {
|
||||
r.StatusCode = tt.resp.statusCode
|
||||
}
|
||||
|
||||
result, err := parseTokenResponse(r)
|
||||
if err != nil {
|
||||
if tt.wantError == nil {
|
||||
t.Errorf("case %d: got error==%v", i, err)
|
||||
continue
|
||||
}
|
||||
if !reflect.DeepEqual(tt.wantError, err) {
|
||||
t.Errorf("case %d: want=%+v, got=%+v", i, tt.wantError, err)
|
||||
}
|
||||
} else {
|
||||
if tt.wantError != nil {
|
||||
t.Errorf("case %d: want error==%v, got==nil", i, tt.wantError)
|
||||
continue
|
||||
}
|
||||
// don't compare the raw body (it's really big and clogs error messages)
|
||||
result.RawBody = tt.wantResp.RawBody
|
||||
if !reflect.DeepEqual(tt.wantResp, result) {
|
||||
t.Errorf("case %d: want=%+v, got=%+v", i, tt.wantResp, result)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
846
vendor/github.com/coreos/go-oidc/oidc/client.go
generated
vendored
846
vendor/github.com/coreos/go-oidc/oidc/client.go
generated
vendored
@@ -1,846 +0,0 @@
|
||||
package oidc
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"net/mail"
|
||||
"net/url"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
phttp "github.com/coreos/go-oidc/http"
|
||||
"github.com/coreos/go-oidc/jose"
|
||||
"github.com/coreos/go-oidc/key"
|
||||
"github.com/coreos/go-oidc/oauth2"
|
||||
)
|
||||
|
||||
const (
|
||||
// amount of time that must pass after the last key sync
|
||||
// completes before another attempt may begin
|
||||
keySyncWindow = 5 * time.Second
|
||||
)
|
||||
|
||||
var (
|
||||
DefaultScope = []string{"openid", "email", "profile"}
|
||||
|
||||
supportedAuthMethods = map[string]struct{}{
|
||||
oauth2.AuthMethodClientSecretBasic: struct{}{},
|
||||
oauth2.AuthMethodClientSecretPost: struct{}{},
|
||||
}
|
||||
)
|
||||
|
||||
type ClientCredentials oauth2.ClientCredentials
|
||||
|
||||
type ClientIdentity struct {
|
||||
Credentials ClientCredentials
|
||||
Metadata ClientMetadata
|
||||
}
|
||||
|
||||
type JWAOptions struct {
|
||||
// SigningAlg specifies an JWA alg for signing JWTs.
|
||||
//
|
||||
// Specifying this field implies different actions depending on the context. It may
|
||||
// require objects be serialized and signed as a JWT instead of plain JSON, or
|
||||
// require an existing JWT object use the specified alg.
|
||||
//
|
||||
// See: http://openid.net/specs/openid-connect-registration-1_0.html#ClientMetadata
|
||||
SigningAlg string
|
||||
// EncryptionAlg, if provided, specifies that the returned or sent object be stored
|
||||
// (or nested) within a JWT object and encrypted with the provided JWA alg.
|
||||
EncryptionAlg string
|
||||
// EncryptionEnc specifies the JWA enc algorithm to use with EncryptionAlg. If
|
||||
// EncryptionAlg is provided and EncryptionEnc is omitted, this field defaults
|
||||
// to A128CBC-HS256.
|
||||
//
|
||||
// If EncryptionEnc is provided EncryptionAlg must also be specified.
|
||||
EncryptionEnc string
|
||||
}
|
||||
|
||||
func (opt JWAOptions) valid() error {
|
||||
if opt.EncryptionEnc != "" && opt.EncryptionAlg == "" {
|
||||
return errors.New("encryption encoding provided with no encryption algorithm")
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (opt JWAOptions) defaults() JWAOptions {
|
||||
if opt.EncryptionAlg != "" && opt.EncryptionEnc == "" {
|
||||
opt.EncryptionEnc = jose.EncA128CBCHS256
|
||||
}
|
||||
return opt
|
||||
}
|
||||
|
||||
var (
|
||||
// Ensure ClientMetadata satisfies these interfaces.
|
||||
_ json.Marshaler = &ClientMetadata{}
|
||||
_ json.Unmarshaler = &ClientMetadata{}
|
||||
)
|
||||
|
||||
// ClientMetadata holds metadata that the authorization server associates
|
||||
// with a client identifier. The fields range from human-facing display
|
||||
// strings such as client name, to items that impact the security of the
|
||||
// protocol, such as the list of valid redirect URIs.
|
||||
//
|
||||
// See http://openid.net/specs/openid-connect-registration-1_0.html#ClientMetadata
|
||||
//
|
||||
// TODO: support language specific claim representations
|
||||
// http://openid.net/specs/openid-connect-registration-1_0.html#LanguagesAndScripts
|
||||
type ClientMetadata struct {
|
||||
RedirectURIs []url.URL // Required
|
||||
|
||||
// A list of OAuth 2.0 "response_type" values that the client wishes to restrict
|
||||
// itself to. Either "code", "token", or another registered extension.
|
||||
//
|
||||
// If omitted, only "code" will be used.
|
||||
ResponseTypes []string
|
||||
// A list of OAuth 2.0 grant types the client wishes to restrict itself to.
|
||||
// The grant type values used by OIDC are "authorization_code", "implicit",
|
||||
// and "refresh_token".
|
||||
//
|
||||
// If ommitted, only "authorization_code" will be used.
|
||||
GrantTypes []string
|
||||
// "native" or "web". If omitted, "web".
|
||||
ApplicationType string
|
||||
|
||||
// List of email addresses.
|
||||
Contacts []mail.Address
|
||||
// Name of client to be presented to the end-user.
|
||||
ClientName string
|
||||
// URL that references a logo for the Client application.
|
||||
LogoURI *url.URL
|
||||
// URL of the home page of the Client.
|
||||
ClientURI *url.URL
|
||||
// Profile data policies and terms of use to be provided to the end user.
|
||||
PolicyURI *url.URL
|
||||
TermsOfServiceURI *url.URL
|
||||
|
||||
// URL to or the value of the client's JSON Web Key Set document.
|
||||
JWKSURI *url.URL
|
||||
JWKS *jose.JWKSet
|
||||
|
||||
// URL referencing a flie with a single JSON array of redirect URIs.
|
||||
SectorIdentifierURI *url.URL
|
||||
|
||||
SubjectType string
|
||||
|
||||
// Options to restrict the JWS alg and enc values used for server responses and requests.
|
||||
IDTokenResponseOptions JWAOptions
|
||||
UserInfoResponseOptions JWAOptions
|
||||
RequestObjectOptions JWAOptions
|
||||
|
||||
// Client requested authorization method and signing options for the token endpoint.
|
||||
//
|
||||
// Defaults to "client_secret_basic"
|
||||
TokenEndpointAuthMethod string
|
||||
TokenEndpointAuthSigningAlg string
|
||||
|
||||
// DefaultMaxAge specifies the maximum amount of time in seconds before an authorized
|
||||
// user must reauthroize.
|
||||
//
|
||||
// If 0, no limitation is placed on the maximum.
|
||||
DefaultMaxAge int64
|
||||
// RequireAuthTime specifies if the auth_time claim in the ID token is required.
|
||||
RequireAuthTime bool
|
||||
|
||||
// Default Authentication Context Class Reference values for authentication requests.
|
||||
DefaultACRValues []string
|
||||
|
||||
// URI that a third party can use to initiate a login by the relaying party.
|
||||
//
|
||||
// See: http://openid.net/specs/openid-connect-core-1_0.html#ThirdPartyInitiatedLogin
|
||||
InitiateLoginURI *url.URL
|
||||
// Pre-registered request_uri values that may be cached by the server.
|
||||
RequestURIs []url.URL
|
||||
}
|
||||
|
||||
// Defaults returns a shallow copy of ClientMetadata with default
|
||||
// values replacing omitted fields.
|
||||
func (m ClientMetadata) Defaults() ClientMetadata {
|
||||
if len(m.ResponseTypes) == 0 {
|
||||
m.ResponseTypes = []string{oauth2.ResponseTypeCode}
|
||||
}
|
||||
if len(m.GrantTypes) == 0 {
|
||||
m.GrantTypes = []string{oauth2.GrantTypeAuthCode}
|
||||
}
|
||||
if m.ApplicationType == "" {
|
||||
m.ApplicationType = "web"
|
||||
}
|
||||
if m.TokenEndpointAuthMethod == "" {
|
||||
m.TokenEndpointAuthMethod = oauth2.AuthMethodClientSecretBasic
|
||||
}
|
||||
m.IDTokenResponseOptions = m.IDTokenResponseOptions.defaults()
|
||||
m.UserInfoResponseOptions = m.UserInfoResponseOptions.defaults()
|
||||
m.RequestObjectOptions = m.RequestObjectOptions.defaults()
|
||||
return m
|
||||
}
|
||||
|
||||
func (m *ClientMetadata) MarshalJSON() ([]byte, error) {
|
||||
e := m.toEncodableStruct()
|
||||
return json.Marshal(&e)
|
||||
}
|
||||
|
||||
func (m *ClientMetadata) UnmarshalJSON(data []byte) error {
|
||||
var e encodableClientMetadata
|
||||
if err := json.Unmarshal(data, &e); err != nil {
|
||||
return err
|
||||
}
|
||||
meta, err := e.toStruct()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if err := meta.Valid(); err != nil {
|
||||
return err
|
||||
}
|
||||
*m = meta
|
||||
return nil
|
||||
}
|
||||
|
||||
type encodableClientMetadata struct {
|
||||
RedirectURIs []string `json:"redirect_uris"` // Required
|
||||
ResponseTypes []string `json:"response_types,omitempty"`
|
||||
GrantTypes []string `json:"grant_types,omitempty"`
|
||||
ApplicationType string `json:"application_type,omitempty"`
|
||||
Contacts []string `json:"contacts,omitempty"`
|
||||
ClientName string `json:"client_name,omitempty"`
|
||||
LogoURI string `json:"logo_uri,omitempty"`
|
||||
ClientURI string `json:"client_uri,omitempty"`
|
||||
PolicyURI string `json:"policy_uri,omitempty"`
|
||||
TermsOfServiceURI string `json:"tos_uri,omitempty"`
|
||||
JWKSURI string `json:"jwks_uri,omitempty"`
|
||||
JWKS *jose.JWKSet `json:"jwks,omitempty"`
|
||||
SectorIdentifierURI string `json:"sector_identifier_uri,omitempty"`
|
||||
SubjectType string `json:"subject_type,omitempty"`
|
||||
IDTokenSignedResponseAlg string `json:"id_token_signed_response_alg,omitempty"`
|
||||
IDTokenEncryptedResponseAlg string `json:"id_token_encrypted_response_alg,omitempty"`
|
||||
IDTokenEncryptedResponseEnc string `json:"id_token_encrypted_response_enc,omitempty"`
|
||||
UserInfoSignedResponseAlg string `json:"userinfo_signed_response_alg,omitempty"`
|
||||
UserInfoEncryptedResponseAlg string `json:"userinfo_encrypted_response_alg,omitempty"`
|
||||
UserInfoEncryptedResponseEnc string `json:"userinfo_encrypted_response_enc,omitempty"`
|
||||
RequestObjectSigningAlg string `json:"request_object_signing_alg,omitempty"`
|
||||
RequestObjectEncryptionAlg string `json:"request_object_encryption_alg,omitempty"`
|
||||
RequestObjectEncryptionEnc string `json:"request_object_encryption_enc,omitempty"`
|
||||
TokenEndpointAuthMethod string `json:"token_endpoint_auth_method,omitempty"`
|
||||
TokenEndpointAuthSigningAlg string `json:"token_endpoint_auth_signing_alg,omitempty"`
|
||||
DefaultMaxAge int64 `json:"default_max_age,omitempty"`
|
||||
RequireAuthTime bool `json:"require_auth_time,omitempty"`
|
||||
DefaultACRValues []string `json:"default_acr_values,omitempty"`
|
||||
InitiateLoginURI string `json:"initiate_login_uri,omitempty"`
|
||||
RequestURIs []string `json:"request_uris,omitempty"`
|
||||
}
|
||||
|
||||
func (c *encodableClientMetadata) toStruct() (ClientMetadata, error) {
|
||||
p := stickyErrParser{}
|
||||
m := ClientMetadata{
|
||||
RedirectURIs: p.parseURIs(c.RedirectURIs, "redirect_uris"),
|
||||
ResponseTypes: c.ResponseTypes,
|
||||
GrantTypes: c.GrantTypes,
|
||||
ApplicationType: c.ApplicationType,
|
||||
Contacts: p.parseEmails(c.Contacts, "contacts"),
|
||||
ClientName: c.ClientName,
|
||||
LogoURI: p.parseURI(c.LogoURI, "logo_uri"),
|
||||
ClientURI: p.parseURI(c.ClientURI, "client_uri"),
|
||||
PolicyURI: p.parseURI(c.PolicyURI, "policy_uri"),
|
||||
TermsOfServiceURI: p.parseURI(c.TermsOfServiceURI, "tos_uri"),
|
||||
JWKSURI: p.parseURI(c.JWKSURI, "jwks_uri"),
|
||||
JWKS: c.JWKS,
|
||||
SectorIdentifierURI: p.parseURI(c.SectorIdentifierURI, "sector_identifier_uri"),
|
||||
SubjectType: c.SubjectType,
|
||||
TokenEndpointAuthMethod: c.TokenEndpointAuthMethod,
|
||||
TokenEndpointAuthSigningAlg: c.TokenEndpointAuthSigningAlg,
|
||||
DefaultMaxAge: c.DefaultMaxAge,
|
||||
RequireAuthTime: c.RequireAuthTime,
|
||||
DefaultACRValues: c.DefaultACRValues,
|
||||
InitiateLoginURI: p.parseURI(c.InitiateLoginURI, "initiate_login_uri"),
|
||||
RequestURIs: p.parseURIs(c.RequestURIs, "request_uris"),
|
||||
IDTokenResponseOptions: JWAOptions{
|
||||
c.IDTokenSignedResponseAlg,
|
||||
c.IDTokenEncryptedResponseAlg,
|
||||
c.IDTokenEncryptedResponseEnc,
|
||||
},
|
||||
UserInfoResponseOptions: JWAOptions{
|
||||
c.UserInfoSignedResponseAlg,
|
||||
c.UserInfoEncryptedResponseAlg,
|
||||
c.UserInfoEncryptedResponseEnc,
|
||||
},
|
||||
RequestObjectOptions: JWAOptions{
|
||||
c.RequestObjectSigningAlg,
|
||||
c.RequestObjectEncryptionAlg,
|
||||
c.RequestObjectEncryptionEnc,
|
||||
},
|
||||
}
|
||||
if p.firstErr != nil {
|
||||
return ClientMetadata{}, p.firstErr
|
||||
}
|
||||
return m, nil
|
||||
}
|
||||
|
||||
// stickyErrParser parses URIs and email addresses. Once it encounters
|
||||
// a parse error, subsequent calls become no-op.
|
||||
type stickyErrParser struct {
|
||||
firstErr error
|
||||
}
|
||||
|
||||
func (p *stickyErrParser) parseURI(s, field string) *url.URL {
|
||||
if p.firstErr != nil || s == "" {
|
||||
return nil
|
||||
}
|
||||
u, err := url.Parse(s)
|
||||
if err == nil {
|
||||
if u.Host == "" {
|
||||
err = errors.New("no host in URI")
|
||||
} else if u.Scheme != "http" && u.Scheme != "https" {
|
||||
err = errors.New("invalid URI scheme")
|
||||
}
|
||||
}
|
||||
if err != nil {
|
||||
p.firstErr = fmt.Errorf("failed to parse %s: %v", field, err)
|
||||
return nil
|
||||
}
|
||||
return u
|
||||
}
|
||||
|
||||
func (p *stickyErrParser) parseURIs(s []string, field string) []url.URL {
|
||||
if p.firstErr != nil || len(s) == 0 {
|
||||
return nil
|
||||
}
|
||||
uris := make([]url.URL, len(s))
|
||||
for i, val := range s {
|
||||
if val == "" {
|
||||
p.firstErr = fmt.Errorf("invalid URI in field %s", field)
|
||||
return nil
|
||||
}
|
||||
if u := p.parseURI(val, field); u != nil {
|
||||
uris[i] = *u
|
||||
}
|
||||
}
|
||||
return uris
|
||||
}
|
||||
|
||||
func (p *stickyErrParser) parseEmails(s []string, field string) []mail.Address {
|
||||
if p.firstErr != nil || len(s) == 0 {
|
||||
return nil
|
||||
}
|
||||
addrs := make([]mail.Address, len(s))
|
||||
for i, addr := range s {
|
||||
if addr == "" {
|
||||
p.firstErr = fmt.Errorf("invalid email in field %s", field)
|
||||
return nil
|
||||
}
|
||||
a, err := mail.ParseAddress(addr)
|
||||
if err != nil {
|
||||
p.firstErr = fmt.Errorf("invalid email in field %s: %v", field, err)
|
||||
return nil
|
||||
}
|
||||
addrs[i] = *a
|
||||
}
|
||||
return addrs
|
||||
}
|
||||
|
||||
func (m *ClientMetadata) toEncodableStruct() encodableClientMetadata {
|
||||
return encodableClientMetadata{
|
||||
RedirectURIs: urisToStrings(m.RedirectURIs),
|
||||
ResponseTypes: m.ResponseTypes,
|
||||
GrantTypes: m.GrantTypes,
|
||||
ApplicationType: m.ApplicationType,
|
||||
Contacts: emailsToStrings(m.Contacts),
|
||||
ClientName: m.ClientName,
|
||||
LogoURI: uriToString(m.LogoURI),
|
||||
ClientURI: uriToString(m.ClientURI),
|
||||
PolicyURI: uriToString(m.PolicyURI),
|
||||
TermsOfServiceURI: uriToString(m.TermsOfServiceURI),
|
||||
JWKSURI: uriToString(m.JWKSURI),
|
||||
JWKS: m.JWKS,
|
||||
SectorIdentifierURI: uriToString(m.SectorIdentifierURI),
|
||||
SubjectType: m.SubjectType,
|
||||
IDTokenSignedResponseAlg: m.IDTokenResponseOptions.SigningAlg,
|
||||
IDTokenEncryptedResponseAlg: m.IDTokenResponseOptions.EncryptionAlg,
|
||||
IDTokenEncryptedResponseEnc: m.IDTokenResponseOptions.EncryptionEnc,
|
||||
UserInfoSignedResponseAlg: m.UserInfoResponseOptions.SigningAlg,
|
||||
UserInfoEncryptedResponseAlg: m.UserInfoResponseOptions.EncryptionAlg,
|
||||
UserInfoEncryptedResponseEnc: m.UserInfoResponseOptions.EncryptionEnc,
|
||||
RequestObjectSigningAlg: m.RequestObjectOptions.SigningAlg,
|
||||
RequestObjectEncryptionAlg: m.RequestObjectOptions.EncryptionAlg,
|
||||
RequestObjectEncryptionEnc: m.RequestObjectOptions.EncryptionEnc,
|
||||
TokenEndpointAuthMethod: m.TokenEndpointAuthMethod,
|
||||
TokenEndpointAuthSigningAlg: m.TokenEndpointAuthSigningAlg,
|
||||
DefaultMaxAge: m.DefaultMaxAge,
|
||||
RequireAuthTime: m.RequireAuthTime,
|
||||
DefaultACRValues: m.DefaultACRValues,
|
||||
InitiateLoginURI: uriToString(m.InitiateLoginURI),
|
||||
RequestURIs: urisToStrings(m.RequestURIs),
|
||||
}
|
||||
}
|
||||
|
||||
func uriToString(u *url.URL) string {
|
||||
if u == nil {
|
||||
return ""
|
||||
}
|
||||
return u.String()
|
||||
}
|
||||
|
||||
func urisToStrings(urls []url.URL) []string {
|
||||
if len(urls) == 0 {
|
||||
return nil
|
||||
}
|
||||
sli := make([]string, len(urls))
|
||||
for i, u := range urls {
|
||||
sli[i] = u.String()
|
||||
}
|
||||
return sli
|
||||
}
|
||||
|
||||
func emailsToStrings(addrs []mail.Address) []string {
|
||||
if len(addrs) == 0 {
|
||||
return nil
|
||||
}
|
||||
sli := make([]string, len(addrs))
|
||||
for i, addr := range addrs {
|
||||
sli[i] = addr.String()
|
||||
}
|
||||
return sli
|
||||
}
|
||||
|
||||
// Valid determines if a ClientMetadata conforms with the OIDC specification.
|
||||
//
|
||||
// Valid is called by UnmarshalJSON.
|
||||
//
|
||||
// NOTE(ericchiang): For development purposes Valid does not mandate 'https' for
|
||||
// URLs fields where the OIDC spec requires it. This may change in future releases
|
||||
// of this package. See: https://github.com/coreos/go-oidc/issues/34
|
||||
func (m *ClientMetadata) Valid() error {
|
||||
if len(m.RedirectURIs) == 0 {
|
||||
return errors.New("zero redirect URLs")
|
||||
}
|
||||
|
||||
validURI := func(u *url.URL, fieldName string) error {
|
||||
if u.Host == "" {
|
||||
return fmt.Errorf("no host for uri field %s", fieldName)
|
||||
}
|
||||
if u.Scheme != "http" && u.Scheme != "https" {
|
||||
return fmt.Errorf("uri field %s scheme is not http or https", fieldName)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
uris := []struct {
|
||||
val *url.URL
|
||||
name string
|
||||
}{
|
||||
{m.LogoURI, "logo_uri"},
|
||||
{m.ClientURI, "client_uri"},
|
||||
{m.PolicyURI, "policy_uri"},
|
||||
{m.TermsOfServiceURI, "tos_uri"},
|
||||
{m.JWKSURI, "jwks_uri"},
|
||||
{m.SectorIdentifierURI, "sector_identifier_uri"},
|
||||
{m.InitiateLoginURI, "initiate_login_uri"},
|
||||
}
|
||||
|
||||
for _, uri := range uris {
|
||||
if uri.val == nil {
|
||||
continue
|
||||
}
|
||||
if err := validURI(uri.val, uri.name); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
uriLists := []struct {
|
||||
vals []url.URL
|
||||
name string
|
||||
}{
|
||||
{m.RedirectURIs, "redirect_uris"},
|
||||
{m.RequestURIs, "request_uris"},
|
||||
}
|
||||
for _, list := range uriLists {
|
||||
for _, uri := range list.vals {
|
||||
if err := validURI(&uri, list.name); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
options := []struct {
|
||||
option JWAOptions
|
||||
name string
|
||||
}{
|
||||
{m.IDTokenResponseOptions, "id_token response"},
|
||||
{m.UserInfoResponseOptions, "userinfo response"},
|
||||
{m.RequestObjectOptions, "request_object"},
|
||||
}
|
||||
for _, option := range options {
|
||||
if err := option.option.valid(); err != nil {
|
||||
return fmt.Errorf("invalid JWA values for %s: %v", option.name, err)
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
type ClientRegistrationResponse struct {
|
||||
ClientID string // Required
|
||||
ClientSecret string
|
||||
RegistrationAccessToken string
|
||||
RegistrationClientURI string
|
||||
// If IsZero is true, unspecified.
|
||||
ClientIDIssuedAt time.Time
|
||||
// Time at which the client_secret will expire.
|
||||
// If IsZero is true, it will not expire.
|
||||
ClientSecretExpiresAt time.Time
|
||||
|
||||
ClientMetadata
|
||||
}
|
||||
|
||||
type encodableClientRegistrationResponse struct {
|
||||
ClientID string `json:"client_id"` // Required
|
||||
ClientSecret string `json:"client_secret,omitempty"`
|
||||
RegistrationAccessToken string `json:"registration_access_token,omitempty"`
|
||||
RegistrationClientURI string `json:"registration_client_uri,omitempty"`
|
||||
ClientIDIssuedAt int64 `json:"client_id_issued_at,omitempty"`
|
||||
// Time at which the client_secret will expire, in seconds since the epoch.
|
||||
// If 0 it will not expire.
|
||||
ClientSecretExpiresAt int64 `json:"client_secret_expires_at"` // Required
|
||||
|
||||
encodableClientMetadata
|
||||
}
|
||||
|
||||
func unixToSec(t time.Time) int64 {
|
||||
if t.IsZero() {
|
||||
return 0
|
||||
}
|
||||
return t.Unix()
|
||||
}
|
||||
|
||||
func (c *ClientRegistrationResponse) MarshalJSON() ([]byte, error) {
|
||||
e := encodableClientRegistrationResponse{
|
||||
ClientID: c.ClientID,
|
||||
ClientSecret: c.ClientSecret,
|
||||
RegistrationAccessToken: c.RegistrationAccessToken,
|
||||
RegistrationClientURI: c.RegistrationClientURI,
|
||||
ClientIDIssuedAt: unixToSec(c.ClientIDIssuedAt),
|
||||
ClientSecretExpiresAt: unixToSec(c.ClientSecretExpiresAt),
|
||||
encodableClientMetadata: c.ClientMetadata.toEncodableStruct(),
|
||||
}
|
||||
return json.Marshal(&e)
|
||||
}
|
||||
|
||||
func secToUnix(sec int64) time.Time {
|
||||
if sec == 0 {
|
||||
return time.Time{}
|
||||
}
|
||||
return time.Unix(sec, 0)
|
||||
}
|
||||
|
||||
func (c *ClientRegistrationResponse) UnmarshalJSON(data []byte) error {
|
||||
var e encodableClientRegistrationResponse
|
||||
if err := json.Unmarshal(data, &e); err != nil {
|
||||
return err
|
||||
}
|
||||
if e.ClientID == "" {
|
||||
return errors.New("no client_id in client registration response")
|
||||
}
|
||||
metadata, err := e.encodableClientMetadata.toStruct()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
*c = ClientRegistrationResponse{
|
||||
ClientID: e.ClientID,
|
||||
ClientSecret: e.ClientSecret,
|
||||
RegistrationAccessToken: e.RegistrationAccessToken,
|
||||
RegistrationClientURI: e.RegistrationClientURI,
|
||||
ClientIDIssuedAt: secToUnix(e.ClientIDIssuedAt),
|
||||
ClientSecretExpiresAt: secToUnix(e.ClientSecretExpiresAt),
|
||||
ClientMetadata: metadata,
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
type ClientConfig struct {
|
||||
HTTPClient phttp.Client
|
||||
Credentials ClientCredentials
|
||||
Scope []string
|
||||
RedirectURL string
|
||||
ProviderConfig ProviderConfig
|
||||
KeySet key.PublicKeySet
|
||||
}
|
||||
|
||||
func NewClient(cfg ClientConfig) (*Client, error) {
|
||||
// Allow empty redirect URL in the case where the client
|
||||
// only needs to verify a given token.
|
||||
ru, err := url.Parse(cfg.RedirectURL)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("invalid redirect URL: %v", err)
|
||||
}
|
||||
|
||||
c := Client{
|
||||
credentials: cfg.Credentials,
|
||||
httpClient: cfg.HTTPClient,
|
||||
scope: cfg.Scope,
|
||||
redirectURL: ru.String(),
|
||||
providerConfig: newProviderConfigRepo(cfg.ProviderConfig),
|
||||
keySet: cfg.KeySet,
|
||||
}
|
||||
|
||||
if c.httpClient == nil {
|
||||
c.httpClient = http.DefaultClient
|
||||
}
|
||||
|
||||
if c.scope == nil {
|
||||
c.scope = make([]string, len(DefaultScope))
|
||||
copy(c.scope, DefaultScope)
|
||||
}
|
||||
|
||||
return &c, nil
|
||||
}
|
||||
|
||||
type Client struct {
|
||||
httpClient phttp.Client
|
||||
providerConfig *providerConfigRepo
|
||||
credentials ClientCredentials
|
||||
redirectURL string
|
||||
scope []string
|
||||
keySet key.PublicKeySet
|
||||
providerSyncer *ProviderConfigSyncer
|
||||
|
||||
keySetSyncMutex sync.RWMutex
|
||||
lastKeySetSync time.Time
|
||||
}
|
||||
|
||||
func (c *Client) Healthy() error {
|
||||
now := time.Now().UTC()
|
||||
|
||||
cfg := c.providerConfig.Get()
|
||||
|
||||
if cfg.Empty() {
|
||||
return errors.New("oidc client provider config empty")
|
||||
}
|
||||
|
||||
if !cfg.ExpiresAt.IsZero() && cfg.ExpiresAt.Before(now) {
|
||||
return errors.New("oidc client provider config expired")
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *Client) OAuthClient() (*oauth2.Client, error) {
|
||||
cfg := c.providerConfig.Get()
|
||||
authMethod, err := chooseAuthMethod(cfg)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
ocfg := oauth2.Config{
|
||||
Credentials: oauth2.ClientCredentials(c.credentials),
|
||||
RedirectURL: c.redirectURL,
|
||||
AuthURL: cfg.AuthEndpoint.String(),
|
||||
TokenURL: cfg.TokenEndpoint.String(),
|
||||
Scope: c.scope,
|
||||
AuthMethod: authMethod,
|
||||
}
|
||||
|
||||
return oauth2.NewClient(c.httpClient, ocfg)
|
||||
}
|
||||
|
||||
func chooseAuthMethod(cfg ProviderConfig) (string, error) {
|
||||
if len(cfg.TokenEndpointAuthMethodsSupported) == 0 {
|
||||
return oauth2.AuthMethodClientSecretBasic, nil
|
||||
}
|
||||
|
||||
for _, authMethod := range cfg.TokenEndpointAuthMethodsSupported {
|
||||
if _, ok := supportedAuthMethods[authMethod]; ok {
|
||||
return authMethod, nil
|
||||
}
|
||||
}
|
||||
|
||||
return "", errors.New("no supported auth methods")
|
||||
}
|
||||
|
||||
// SyncProviderConfig starts the provider config syncer
|
||||
func (c *Client) SyncProviderConfig(discoveryURL string) chan struct{} {
|
||||
r := NewHTTPProviderConfigGetter(c.httpClient, discoveryURL)
|
||||
s := NewProviderConfigSyncer(r, c.providerConfig)
|
||||
stop := s.Run()
|
||||
s.WaitUntilInitialSync()
|
||||
return stop
|
||||
}
|
||||
|
||||
func (c *Client) maybeSyncKeys() error {
|
||||
tooSoon := func() bool {
|
||||
return time.Now().UTC().Before(c.lastKeySetSync.Add(keySyncWindow))
|
||||
}
|
||||
|
||||
// ignore request to sync keys if a sync operation has been
|
||||
// attempted too recently
|
||||
if tooSoon() {
|
||||
return nil
|
||||
}
|
||||
|
||||
c.keySetSyncMutex.Lock()
|
||||
defer c.keySetSyncMutex.Unlock()
|
||||
|
||||
// check again, as another goroutine may have been holding
|
||||
// the lock while updating the keys
|
||||
if tooSoon() {
|
||||
return nil
|
||||
}
|
||||
|
||||
cfg := c.providerConfig.Get()
|
||||
r := NewRemotePublicKeyRepo(c.httpClient, cfg.KeysEndpoint.String())
|
||||
w := &clientKeyRepo{client: c}
|
||||
_, err := key.Sync(r, w)
|
||||
c.lastKeySetSync = time.Now().UTC()
|
||||
|
||||
return err
|
||||
}
|
||||
|
||||
type clientKeyRepo struct {
|
||||
client *Client
|
||||
}
|
||||
|
||||
func (r *clientKeyRepo) Set(ks key.KeySet) error {
|
||||
pks, ok := ks.(*key.PublicKeySet)
|
||||
if !ok {
|
||||
return errors.New("unable to cast to PublicKey")
|
||||
}
|
||||
r.client.keySet = *pks
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *Client) ClientCredsToken(scope []string) (jose.JWT, error) {
|
||||
cfg := c.providerConfig.Get()
|
||||
|
||||
if !cfg.SupportsGrantType(oauth2.GrantTypeClientCreds) {
|
||||
return jose.JWT{}, fmt.Errorf("%v grant type is not supported", oauth2.GrantTypeClientCreds)
|
||||
}
|
||||
|
||||
oac, err := c.OAuthClient()
|
||||
if err != nil {
|
||||
return jose.JWT{}, err
|
||||
}
|
||||
|
||||
t, err := oac.ClientCredsToken(scope)
|
||||
if err != nil {
|
||||
return jose.JWT{}, err
|
||||
}
|
||||
|
||||
jwt, err := jose.ParseJWT(t.IDToken)
|
||||
if err != nil {
|
||||
return jose.JWT{}, err
|
||||
}
|
||||
|
||||
return jwt, c.VerifyJWT(jwt)
|
||||
}
|
||||
|
||||
// ExchangeAuthCode exchanges an OAuth2 auth code for an OIDC JWT ID token.
|
||||
func (c *Client) ExchangeAuthCode(code string) (jose.JWT, error) {
|
||||
oac, err := c.OAuthClient()
|
||||
if err != nil {
|
||||
return jose.JWT{}, err
|
||||
}
|
||||
|
||||
t, err := oac.RequestToken(oauth2.GrantTypeAuthCode, code)
|
||||
if err != nil {
|
||||
return jose.JWT{}, err
|
||||
}
|
||||
|
||||
jwt, err := jose.ParseJWT(t.IDToken)
|
||||
if err != nil {
|
||||
return jose.JWT{}, err
|
||||
}
|
||||
|
||||
return jwt, c.VerifyJWT(jwt)
|
||||
}
|
||||
|
||||
// RefreshToken uses a refresh token to exchange for a new OIDC JWT ID Token.
|
||||
func (c *Client) RefreshToken(refreshToken string) (jose.JWT, error) {
|
||||
oac, err := c.OAuthClient()
|
||||
if err != nil {
|
||||
return jose.JWT{}, err
|
||||
}
|
||||
|
||||
t, err := oac.RequestToken(oauth2.GrantTypeRefreshToken, refreshToken)
|
||||
if err != nil {
|
||||
return jose.JWT{}, err
|
||||
}
|
||||
|
||||
jwt, err := jose.ParseJWT(t.IDToken)
|
||||
if err != nil {
|
||||
return jose.JWT{}, err
|
||||
}
|
||||
|
||||
return jwt, c.VerifyJWT(jwt)
|
||||
}
|
||||
|
||||
func (c *Client) VerifyJWT(jwt jose.JWT) error {
|
||||
var keysFunc func() []key.PublicKey
|
||||
if kID, ok := jwt.KeyID(); ok {
|
||||
keysFunc = c.keysFuncWithID(kID)
|
||||
} else {
|
||||
keysFunc = c.keysFuncAll()
|
||||
}
|
||||
|
||||
v := NewJWTVerifier(
|
||||
c.providerConfig.Get().Issuer.String(),
|
||||
c.credentials.ID,
|
||||
c.maybeSyncKeys, keysFunc)
|
||||
|
||||
return v.Verify(jwt)
|
||||
}
|
||||
|
||||
// keysFuncWithID returns a function that retrieves at most unexpired
|
||||
// public key from the Client that matches the provided ID
|
||||
func (c *Client) keysFuncWithID(kID string) func() []key.PublicKey {
|
||||
return func() []key.PublicKey {
|
||||
c.keySetSyncMutex.RLock()
|
||||
defer c.keySetSyncMutex.RUnlock()
|
||||
|
||||
if c.keySet.ExpiresAt().Before(time.Now()) {
|
||||
return []key.PublicKey{}
|
||||
}
|
||||
|
||||
k := c.keySet.Key(kID)
|
||||
if k == nil {
|
||||
return []key.PublicKey{}
|
||||
}
|
||||
|
||||
return []key.PublicKey{*k}
|
||||
}
|
||||
}
|
||||
|
||||
// keysFuncAll returns a function that retrieves all unexpired public
|
||||
// keys from the Client
|
||||
func (c *Client) keysFuncAll() func() []key.PublicKey {
|
||||
return func() []key.PublicKey {
|
||||
c.keySetSyncMutex.RLock()
|
||||
defer c.keySetSyncMutex.RUnlock()
|
||||
|
||||
if c.keySet.ExpiresAt().Before(time.Now()) {
|
||||
return []key.PublicKey{}
|
||||
}
|
||||
|
||||
return c.keySet.Keys()
|
||||
}
|
||||
}
|
||||
|
||||
type providerConfigRepo struct {
|
||||
mu sync.RWMutex
|
||||
config ProviderConfig // do not access directly, use Get()
|
||||
}
|
||||
|
||||
func newProviderConfigRepo(pc ProviderConfig) *providerConfigRepo {
|
||||
return &providerConfigRepo{sync.RWMutex{}, pc}
|
||||
}
|
||||
|
||||
// returns an error to implement ProviderConfigSetter
|
||||
func (r *providerConfigRepo) Set(cfg ProviderConfig) error {
|
||||
r.mu.Lock()
|
||||
defer r.mu.Unlock()
|
||||
r.config = cfg
|
||||
return nil
|
||||
}
|
||||
|
||||
func (r *providerConfigRepo) Get() ProviderConfig {
|
||||
r.mu.RLock()
|
||||
defer r.mu.RUnlock()
|
||||
return r.config
|
||||
}
|
||||
81
vendor/github.com/coreos/go-oidc/oidc/client_race_test.go
generated
vendored
81
vendor/github.com/coreos/go-oidc/oidc/client_race_test.go
generated
vendored
@@ -1,81 +0,0 @@
|
||||
// This file contains tests which depend on the race detector being enabled.
|
||||
// +build race
|
||||
|
||||
package oidc
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"net/url"
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
|
||||
type testProvider struct {
|
||||
baseURL *url.URL
|
||||
}
|
||||
|
||||
func (p *testProvider) ServeHTTP(w http.ResponseWriter, r *http.Request) {
|
||||
if r.URL.Path != discoveryConfigPath {
|
||||
http.NotFound(w, r)
|
||||
return
|
||||
}
|
||||
|
||||
cfg := ProviderConfig{
|
||||
Issuer: p.baseURL,
|
||||
ExpiresAt: time.Now().Add(time.Second),
|
||||
}
|
||||
cfg = fillRequiredProviderFields(cfg)
|
||||
json.NewEncoder(w).Encode(&cfg)
|
||||
}
|
||||
|
||||
// This test fails by triggering the race detector, not by calling t.Error or t.Fatal.
|
||||
func TestProviderSyncRace(t *testing.T) {
|
||||
|
||||
prov := &testProvider{}
|
||||
|
||||
s := httptest.NewServer(prov)
|
||||
defer s.Close()
|
||||
u, err := url.Parse(s.URL)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
prov.baseURL = u
|
||||
|
||||
prevValue := minimumProviderConfigSyncInterval
|
||||
defer func() { minimumProviderConfigSyncInterval = prevValue }()
|
||||
|
||||
// Reduce the sync interval to increase the write frequencey.
|
||||
minimumProviderConfigSyncInterval = 5 * time.Millisecond
|
||||
|
||||
cliCfg := ClientConfig{
|
||||
HTTPClient: http.DefaultClient,
|
||||
}
|
||||
cli, err := NewClient(cliCfg)
|
||||
if err != nil {
|
||||
t.Error(err)
|
||||
return
|
||||
}
|
||||
|
||||
if !cli.providerConfig.Get().Empty() {
|
||||
t.Errorf("want c.ProviderConfig == nil, got c.ProviderConfig=%#v")
|
||||
}
|
||||
|
||||
// SyncProviderConfig beings a goroutine which writes to the client's provider config.
|
||||
c := cli.SyncProviderConfig(s.URL)
|
||||
if cli.providerConfig.Get().Empty() {
|
||||
t.Errorf("want c.ProviderConfig != nil")
|
||||
}
|
||||
|
||||
defer func() {
|
||||
// stop the background process
|
||||
c <- struct{}{}
|
||||
}()
|
||||
|
||||
for i := 0; i < 10; i++ {
|
||||
time.Sleep(5 * time.Millisecond)
|
||||
// Creating an OAuth client reads from the provider config.
|
||||
cli.OAuthClient()
|
||||
}
|
||||
}
|
||||
654
vendor/github.com/coreos/go-oidc/oidc/client_test.go
generated
vendored
654
vendor/github.com/coreos/go-oidc/oidc/client_test.go
generated
vendored
@@ -1,654 +0,0 @@
|
||||
package oidc
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"net/mail"
|
||||
"net/url"
|
||||
"reflect"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/coreos/go-oidc/jose"
|
||||
"github.com/coreos/go-oidc/key"
|
||||
"github.com/coreos/go-oidc/oauth2"
|
||||
"github.com/kylelemons/godebug/pretty"
|
||||
)
|
||||
|
||||
func TestNewClientScopeDefault(t *testing.T) {
|
||||
tests := []struct {
|
||||
c ClientConfig
|
||||
e []string
|
||||
}{
|
||||
{
|
||||
// No scope
|
||||
c: ClientConfig{RedirectURL: "http://example.com/redirect"},
|
||||
e: DefaultScope,
|
||||
},
|
||||
{
|
||||
// Nil scope
|
||||
c: ClientConfig{RedirectURL: "http://example.com/redirect", Scope: nil},
|
||||
e: DefaultScope,
|
||||
},
|
||||
{
|
||||
// Empty scope
|
||||
c: ClientConfig{RedirectURL: "http://example.com/redirect", Scope: []string{}},
|
||||
e: []string{},
|
||||
},
|
||||
{
|
||||
// Custom scope equal to default
|
||||
c: ClientConfig{RedirectURL: "http://example.com/redirect", Scope: []string{"openid", "email", "profile"}},
|
||||
e: DefaultScope,
|
||||
},
|
||||
{
|
||||
// Custom scope not including defaults
|
||||
c: ClientConfig{RedirectURL: "http://example.com/redirect", Scope: []string{"foo", "bar"}},
|
||||
e: []string{"foo", "bar"},
|
||||
},
|
||||
{
|
||||
// Custom scopes overlapping with defaults
|
||||
c: ClientConfig{RedirectURL: "http://example.com/redirect", Scope: []string{"openid", "foo"}},
|
||||
e: []string{"openid", "foo"},
|
||||
},
|
||||
}
|
||||
|
||||
for i, tt := range tests {
|
||||
c, err := NewClient(tt.c)
|
||||
if err != nil {
|
||||
t.Errorf("case %d: unexpected error from NewClient: %v", i, err)
|
||||
continue
|
||||
}
|
||||
if !reflect.DeepEqual(tt.e, c.scope) {
|
||||
t.Errorf("case %d: want: %v, got: %v", i, tt.e, c.scope)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestHealthy(t *testing.T) {
|
||||
now := time.Now().UTC()
|
||||
|
||||
tests := []struct {
|
||||
p ProviderConfig
|
||||
h bool
|
||||
}{
|
||||
// all ok
|
||||
{
|
||||
p: ProviderConfig{
|
||||
Issuer: &url.URL{Scheme: "http", Host: "example.com"},
|
||||
ExpiresAt: now.Add(time.Hour),
|
||||
},
|
||||
h: true,
|
||||
},
|
||||
// zero-value ProviderConfig.ExpiresAt
|
||||
{
|
||||
p: ProviderConfig{
|
||||
Issuer: &url.URL{Scheme: "http", Host: "example.com"},
|
||||
},
|
||||
h: true,
|
||||
},
|
||||
// expired ProviderConfig
|
||||
{
|
||||
p: ProviderConfig{
|
||||
Issuer: &url.URL{Scheme: "http", Host: "example.com"},
|
||||
ExpiresAt: now.Add(time.Hour * -1),
|
||||
},
|
||||
h: false,
|
||||
},
|
||||
// empty ProviderConfig
|
||||
{
|
||||
p: ProviderConfig{},
|
||||
h: false,
|
||||
},
|
||||
}
|
||||
|
||||
for i, tt := range tests {
|
||||
c := &Client{providerConfig: newProviderConfigRepo(tt.p)}
|
||||
err := c.Healthy()
|
||||
want := tt.h
|
||||
got := (err == nil)
|
||||
|
||||
if want != got {
|
||||
t.Errorf("case %d: want: healthy=%v, got: healhty=%v, err: %v", i, want, got, err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestClientKeysFuncAll(t *testing.T) {
|
||||
priv1, err := key.GeneratePrivateKey()
|
||||
if err != nil {
|
||||
t.Fatalf("failed to generate private key, error=%v", err)
|
||||
}
|
||||
|
||||
priv2, err := key.GeneratePrivateKey()
|
||||
if err != nil {
|
||||
t.Fatalf("failed to generate private key, error=%v", err)
|
||||
}
|
||||
|
||||
now := time.Now()
|
||||
future := now.Add(time.Hour)
|
||||
past := now.Add(-1 * time.Hour)
|
||||
|
||||
tests := []struct {
|
||||
keySet *key.PublicKeySet
|
||||
want []key.PublicKey
|
||||
}{
|
||||
// two keys, non-expired set
|
||||
{
|
||||
keySet: key.NewPublicKeySet([]jose.JWK{priv2.JWK(), priv1.JWK()}, future),
|
||||
want: []key.PublicKey{*key.NewPublicKey(priv2.JWK()), *key.NewPublicKey(priv1.JWK())},
|
||||
},
|
||||
|
||||
// no keys, non-expired set
|
||||
{
|
||||
keySet: key.NewPublicKeySet([]jose.JWK{}, future),
|
||||
want: []key.PublicKey{},
|
||||
},
|
||||
|
||||
// two keys, expired set
|
||||
{
|
||||
keySet: key.NewPublicKeySet([]jose.JWK{priv2.JWK(), priv1.JWK()}, past),
|
||||
want: []key.PublicKey{},
|
||||
},
|
||||
|
||||
// no keys, expired set
|
||||
{
|
||||
keySet: key.NewPublicKeySet([]jose.JWK{}, past),
|
||||
want: []key.PublicKey{},
|
||||
},
|
||||
}
|
||||
|
||||
for i, tt := range tests {
|
||||
var c Client
|
||||
c.keySet = *tt.keySet
|
||||
keysFunc := c.keysFuncAll()
|
||||
got := keysFunc()
|
||||
if !reflect.DeepEqual(tt.want, got) {
|
||||
t.Errorf("case %d: want=%#v got=%#v", i, tt.want, got)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestClientKeysFuncWithID(t *testing.T) {
|
||||
priv1, err := key.GeneratePrivateKey()
|
||||
if err != nil {
|
||||
t.Fatalf("failed to generate private key, error=%v", err)
|
||||
}
|
||||
|
||||
priv2, err := key.GeneratePrivateKey()
|
||||
if err != nil {
|
||||
t.Fatalf("failed to generate private key, error=%v", err)
|
||||
}
|
||||
|
||||
now := time.Now()
|
||||
future := now.Add(time.Hour)
|
||||
past := now.Add(-1 * time.Hour)
|
||||
|
||||
tests := []struct {
|
||||
keySet *key.PublicKeySet
|
||||
argID string
|
||||
want []key.PublicKey
|
||||
}{
|
||||
// two keys, match, non-expired set
|
||||
{
|
||||
keySet: key.NewPublicKeySet([]jose.JWK{priv2.JWK(), priv1.JWK()}, future),
|
||||
argID: priv2.ID(),
|
||||
want: []key.PublicKey{*key.NewPublicKey(priv2.JWK())},
|
||||
},
|
||||
|
||||
// two keys, no match, non-expired set
|
||||
{
|
||||
keySet: key.NewPublicKeySet([]jose.JWK{priv2.JWK(), priv1.JWK()}, future),
|
||||
argID: "XXX",
|
||||
want: []key.PublicKey{},
|
||||
},
|
||||
|
||||
// no keys, no match, non-expired set
|
||||
{
|
||||
keySet: key.NewPublicKeySet([]jose.JWK{}, future),
|
||||
argID: priv2.ID(),
|
||||
want: []key.PublicKey{},
|
||||
},
|
||||
|
||||
// two keys, match, expired set
|
||||
{
|
||||
keySet: key.NewPublicKeySet([]jose.JWK{priv2.JWK(), priv1.JWK()}, past),
|
||||
argID: priv2.ID(),
|
||||
want: []key.PublicKey{},
|
||||
},
|
||||
|
||||
// no keys, no match, expired set
|
||||
{
|
||||
keySet: key.NewPublicKeySet([]jose.JWK{}, past),
|
||||
argID: priv2.ID(),
|
||||
want: []key.PublicKey{},
|
||||
},
|
||||
}
|
||||
|
||||
for i, tt := range tests {
|
||||
var c Client
|
||||
c.keySet = *tt.keySet
|
||||
keysFunc := c.keysFuncWithID(tt.argID)
|
||||
got := keysFunc()
|
||||
if !reflect.DeepEqual(tt.want, got) {
|
||||
t.Errorf("case %d: want=%#v got=%#v", i, tt.want, got)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestClientMetadataValid(t *testing.T) {
|
||||
tests := []ClientMetadata{
|
||||
// one RedirectURL
|
||||
ClientMetadata{
|
||||
RedirectURIs: []url.URL{url.URL{Scheme: "http", Host: "example.com"}},
|
||||
},
|
||||
|
||||
// one RedirectURL w/ nonempty path
|
||||
ClientMetadata{
|
||||
RedirectURIs: []url.URL{url.URL{Scheme: "http", Host: "example.com", Path: "/foo"}},
|
||||
},
|
||||
|
||||
// two RedirectURIs
|
||||
ClientMetadata{
|
||||
RedirectURIs: []url.URL{
|
||||
url.URL{Scheme: "http", Host: "foo.example.com"},
|
||||
url.URL{Scheme: "http", Host: "bar.example.com"},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
for i, tt := range tests {
|
||||
if err := tt.Valid(); err != nil {
|
||||
t.Errorf("case %d: unexpected error: %v", i, err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestClientMetadataInvalid(t *testing.T) {
|
||||
tests := []ClientMetadata{
|
||||
// nil RedirectURls slice
|
||||
ClientMetadata{
|
||||
RedirectURIs: nil,
|
||||
},
|
||||
|
||||
// empty RedirectURIs slice
|
||||
ClientMetadata{
|
||||
RedirectURIs: []url.URL{},
|
||||
},
|
||||
|
||||
// empty url.URL
|
||||
ClientMetadata{
|
||||
RedirectURIs: []url.URL{url.URL{}},
|
||||
},
|
||||
|
||||
// empty url.URL following OK item
|
||||
ClientMetadata{
|
||||
RedirectURIs: []url.URL{url.URL{Scheme: "http", Host: "example.com"}, url.URL{}},
|
||||
},
|
||||
|
||||
// url.URL with empty Host
|
||||
ClientMetadata{
|
||||
RedirectURIs: []url.URL{url.URL{Scheme: "http", Host: ""}},
|
||||
},
|
||||
|
||||
// url.URL with empty Scheme
|
||||
ClientMetadata{
|
||||
RedirectURIs: []url.URL{url.URL{Scheme: "", Host: "example.com"}},
|
||||
},
|
||||
|
||||
// url.URL with non-HTTP(S) Scheme
|
||||
ClientMetadata{
|
||||
RedirectURIs: []url.URL{url.URL{Scheme: "tcp", Host: "127.0.0.1"}},
|
||||
},
|
||||
|
||||
// EncryptionEnc without EncryptionAlg
|
||||
ClientMetadata{
|
||||
RedirectURIs: []url.URL{url.URL{Scheme: "http", Host: "example.com"}},
|
||||
IDTokenResponseOptions: JWAOptions{
|
||||
EncryptionEnc: "A128CBC-HS256",
|
||||
},
|
||||
},
|
||||
|
||||
// List of URIs with one empty element
|
||||
ClientMetadata{
|
||||
RedirectURIs: []url.URL{url.URL{Scheme: "http", Host: "example.com"}},
|
||||
RequestURIs: []url.URL{
|
||||
url.URL{Scheme: "http", Host: "example.com"},
|
||||
url.URL{},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
for i, tt := range tests {
|
||||
if err := tt.Valid(); err == nil {
|
||||
t.Errorf("case %d: expected non-nil error", i)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestChooseAuthMethod(t *testing.T) {
|
||||
tests := []struct {
|
||||
supported []string
|
||||
chosen string
|
||||
err bool
|
||||
}{
|
||||
{
|
||||
supported: []string{},
|
||||
chosen: oauth2.AuthMethodClientSecretBasic,
|
||||
},
|
||||
{
|
||||
supported: []string{oauth2.AuthMethodClientSecretBasic},
|
||||
chosen: oauth2.AuthMethodClientSecretBasic,
|
||||
},
|
||||
{
|
||||
supported: []string{oauth2.AuthMethodClientSecretPost},
|
||||
chosen: oauth2.AuthMethodClientSecretPost,
|
||||
},
|
||||
{
|
||||
supported: []string{oauth2.AuthMethodClientSecretPost, oauth2.AuthMethodClientSecretBasic},
|
||||
chosen: oauth2.AuthMethodClientSecretPost,
|
||||
},
|
||||
{
|
||||
supported: []string{oauth2.AuthMethodClientSecretBasic, oauth2.AuthMethodClientSecretPost},
|
||||
chosen: oauth2.AuthMethodClientSecretBasic,
|
||||
},
|
||||
{
|
||||
supported: []string{oauth2.AuthMethodClientSecretJWT, oauth2.AuthMethodClientSecretPost},
|
||||
chosen: oauth2.AuthMethodClientSecretPost,
|
||||
},
|
||||
{
|
||||
supported: []string{oauth2.AuthMethodClientSecretJWT},
|
||||
chosen: "",
|
||||
err: true,
|
||||
},
|
||||
}
|
||||
|
||||
for i, tt := range tests {
|
||||
cfg := ProviderConfig{
|
||||
TokenEndpointAuthMethodsSupported: tt.supported,
|
||||
}
|
||||
got, err := chooseAuthMethod(cfg)
|
||||
if tt.err {
|
||||
if err == nil {
|
||||
t.Errorf("case %d: expected non-nil err", i)
|
||||
}
|
||||
continue
|
||||
}
|
||||
|
||||
if got != tt.chosen {
|
||||
t.Errorf("case %d: want=%q, got=%q", i, tt.chosen, got)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestClientMetadataUnmarshal(t *testing.T) {
|
||||
tests := []struct {
|
||||
data string
|
||||
want ClientMetadata
|
||||
wantErr bool
|
||||
}{
|
||||
{
|
||||
`{"redirect_uris":["https://example.com"]}`,
|
||||
ClientMetadata{
|
||||
RedirectURIs: []url.URL{
|
||||
{Scheme: "https", Host: "example.com"},
|
||||
},
|
||||
},
|
||||
false,
|
||||
},
|
||||
{
|
||||
// redirect_uris required
|
||||
`{}`,
|
||||
ClientMetadata{},
|
||||
true,
|
||||
},
|
||||
{
|
||||
// must have at least one redirect_uris
|
||||
`{"redirect_uris":[]}`,
|
||||
ClientMetadata{},
|
||||
true,
|
||||
},
|
||||
{
|
||||
`{"redirect_uris":["https://example.com"],"contacts":["Ms. Foo <foo@example.com>"]}`,
|
||||
ClientMetadata{
|
||||
RedirectURIs: []url.URL{
|
||||
{Scheme: "https", Host: "example.com"},
|
||||
},
|
||||
Contacts: []mail.Address{
|
||||
{Name: "Ms. Foo", Address: "foo@example.com"},
|
||||
},
|
||||
},
|
||||
false,
|
||||
},
|
||||
{
|
||||
// invalid URI provided for field
|
||||
`{"redirect_uris":["https://example.com"],"logo_uri":"not a valid uri"}`,
|
||||
ClientMetadata{},
|
||||
true,
|
||||
},
|
||||
{
|
||||
// logo_uri can't be a list
|
||||
`{"redirect_uris":["https://example.com"],"logo_uri":["https://example.com/logo"]}`,
|
||||
ClientMetadata{},
|
||||
true,
|
||||
},
|
||||
{
|
||||
`{
|
||||
"redirect_uris":["https://example.com"],
|
||||
"userinfo_encrypted_response_alg":"RSA1_5",
|
||||
"userinfo_encrypted_response_enc":"A128CBC-HS256",
|
||||
"contacts": [
|
||||
"jane doe <jane.doe@example.com>", "john doe <john.doe@example.com>"
|
||||
]
|
||||
}`,
|
||||
ClientMetadata{
|
||||
RedirectURIs: []url.URL{
|
||||
{Scheme: "https", Host: "example.com"},
|
||||
},
|
||||
UserInfoResponseOptions: JWAOptions{
|
||||
EncryptionAlg: "RSA1_5",
|
||||
EncryptionEnc: "A128CBC-HS256",
|
||||
},
|
||||
Contacts: []mail.Address{
|
||||
{Name: "jane doe", Address: "jane.doe@example.com"},
|
||||
{Name: "john doe", Address: "john.doe@example.com"},
|
||||
},
|
||||
},
|
||||
false,
|
||||
},
|
||||
{
|
||||
// If encrypted_response_enc is provided encrypted_response_alg must also be.
|
||||
`{
|
||||
"redirect_uris":["https://example.com"],
|
||||
"userinfo_encrypted_response_enc":"A128CBC-HS256"
|
||||
}`,
|
||||
ClientMetadata{},
|
||||
true,
|
||||
},
|
||||
}
|
||||
for i, tt := range tests {
|
||||
var got ClientMetadata
|
||||
if err := got.UnmarshalJSON([]byte(tt.data)); err != nil {
|
||||
if !tt.wantErr {
|
||||
t.Errorf("case %d: unmarshal failed: %v", i, err)
|
||||
}
|
||||
continue
|
||||
}
|
||||
if tt.wantErr {
|
||||
t.Errorf("case %d: expected unmarshal to produce error", i)
|
||||
continue
|
||||
}
|
||||
|
||||
if diff := pretty.Compare(tt.want, got); diff != "" {
|
||||
t.Errorf("case %d: results not equal: %s", i, diff)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestClientMetadataMarshal(t *testing.T) {
|
||||
|
||||
tests := []struct {
|
||||
metadata ClientMetadata
|
||||
want string
|
||||
}{
|
||||
{
|
||||
ClientMetadata{
|
||||
RedirectURIs: []url.URL{
|
||||
{Scheme: "https", Host: "example.com", Path: "/callback"},
|
||||
},
|
||||
},
|
||||
`{"redirect_uris":["https://example.com/callback"]}`,
|
||||
},
|
||||
{
|
||||
ClientMetadata{
|
||||
RedirectURIs: []url.URL{
|
||||
{Scheme: "https", Host: "example.com", Path: "/callback"},
|
||||
},
|
||||
RequestObjectOptions: JWAOptions{
|
||||
EncryptionAlg: "RSA1_5",
|
||||
EncryptionEnc: "A128CBC-HS256",
|
||||
},
|
||||
},
|
||||
`{"redirect_uris":["https://example.com/callback"],"request_object_encryption_alg":"RSA1_5","request_object_encryption_enc":"A128CBC-HS256"}`,
|
||||
},
|
||||
}
|
||||
|
||||
for i, tt := range tests {
|
||||
got, err := json.Marshal(&tt.metadata)
|
||||
if err != nil {
|
||||
t.Errorf("case %d: failed to marshal metadata: %v", i, err)
|
||||
continue
|
||||
}
|
||||
if string(got) != tt.want {
|
||||
t.Errorf("case %d: marshaled string did not match expected string", i)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestClientMetadataMarshalRoundTrip(t *testing.T) {
|
||||
tests := []ClientMetadata{
|
||||
{
|
||||
RedirectURIs: []url.URL{
|
||||
{Scheme: "https", Host: "example.com", Path: "/callback"},
|
||||
},
|
||||
LogoURI: &url.URL{Scheme: "https", Host: "example.com", Path: "/logo"},
|
||||
RequestObjectOptions: JWAOptions{
|
||||
EncryptionAlg: "RSA1_5",
|
||||
EncryptionEnc: "A128CBC-HS256",
|
||||
},
|
||||
ApplicationType: "native",
|
||||
TokenEndpointAuthMethod: "client_secret_basic",
|
||||
},
|
||||
}
|
||||
|
||||
for i, want := range tests {
|
||||
data, err := json.Marshal(&want)
|
||||
if err != nil {
|
||||
t.Errorf("case %d: failed to marshal metadata: %v", i, err)
|
||||
continue
|
||||
}
|
||||
var got ClientMetadata
|
||||
if err := json.Unmarshal(data, &got); err != nil {
|
||||
t.Errorf("case %d: failed to unmarshal metadata: %v", i, err)
|
||||
continue
|
||||
}
|
||||
if diff := pretty.Compare(want, got); diff != "" {
|
||||
t.Errorf("case %d: struct did not survive a marshaling round trip: %s", i, diff)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestClientRegistrationResponseUnmarshal(t *testing.T) {
|
||||
tests := []struct {
|
||||
data string
|
||||
want ClientRegistrationResponse
|
||||
wantErr bool
|
||||
secretExpires bool
|
||||
}{
|
||||
{
|
||||
`{
|
||||
"client_id":"foo",
|
||||
"client_secret":"bar",
|
||||
"client_secret_expires_at": 1577858400,
|
||||
"redirect_uris":[
|
||||
"https://client.example.org/callback",
|
||||
"https://client.example.org/callback2"
|
||||
],
|
||||
"client_name":"my_example"
|
||||
}`,
|
||||
ClientRegistrationResponse{
|
||||
ClientID: "foo",
|
||||
ClientSecret: "bar",
|
||||
ClientSecretExpiresAt: time.Unix(1577858400, 0),
|
||||
ClientMetadata: ClientMetadata{
|
||||
RedirectURIs: []url.URL{
|
||||
{Scheme: "https", Host: "client.example.org", Path: "/callback"},
|
||||
{Scheme: "https", Host: "client.example.org", Path: "/callback2"},
|
||||
},
|
||||
ClientName: "my_example",
|
||||
},
|
||||
},
|
||||
false,
|
||||
true,
|
||||
},
|
||||
{
|
||||
`{
|
||||
"client_id":"foo",
|
||||
"client_secret_expires_at": 0,
|
||||
"redirect_uris":[
|
||||
"https://client.example.org/callback",
|
||||
"https://client.example.org/callback2"
|
||||
],
|
||||
"client_name":"my_example"
|
||||
}`,
|
||||
ClientRegistrationResponse{
|
||||
ClientID: "foo",
|
||||
ClientMetadata: ClientMetadata{
|
||||
RedirectURIs: []url.URL{
|
||||
{Scheme: "https", Host: "client.example.org", Path: "/callback"},
|
||||
{Scheme: "https", Host: "client.example.org", Path: "/callback2"},
|
||||
},
|
||||
ClientName: "my_example",
|
||||
},
|
||||
},
|
||||
false,
|
||||
false,
|
||||
},
|
||||
{
|
||||
// no client id
|
||||
`{
|
||||
"client_secret_expires_at": 0,
|
||||
"redirect_uris":[
|
||||
"https://client.example.org/callback",
|
||||
"https://client.example.org/callback2"
|
||||
],
|
||||
"client_name":"my_example"
|
||||
}`,
|
||||
ClientRegistrationResponse{},
|
||||
true,
|
||||
false,
|
||||
},
|
||||
}
|
||||
|
||||
for i, tt := range tests {
|
||||
var got ClientRegistrationResponse
|
||||
if err := json.Unmarshal([]byte(tt.data), &got); err != nil {
|
||||
if !tt.wantErr {
|
||||
t.Errorf("case %d: unmarshal failed: %v", i, err)
|
||||
}
|
||||
continue
|
||||
}
|
||||
|
||||
if tt.wantErr {
|
||||
t.Errorf("case %d: expected unmarshal to produce error", i)
|
||||
continue
|
||||
}
|
||||
|
||||
if diff := pretty.Compare(tt.want, got); diff != "" {
|
||||
t.Errorf("case %d: results not equal: %s", i, diff)
|
||||
}
|
||||
if tt.secretExpires && got.ClientSecretExpiresAt.IsZero() {
|
||||
t.Errorf("case %d: expected client_secret to expire, but it doesn't", i)
|
||||
} else if !tt.secretExpires && !got.ClientSecretExpiresAt.IsZero() {
|
||||
t.Errorf("case %d: expected client_secret to not expire, but it does", i)
|
||||
}
|
||||
}
|
||||
}
|
||||
2
vendor/github.com/coreos/go-oidc/oidc/doc.go
generated
vendored
2
vendor/github.com/coreos/go-oidc/oidc/doc.go
generated
vendored
@@ -1,2 +0,0 @@
|
||||
// Package oidc is DEPRECATED. Use github.com/coreos/go-oidc instead.
|
||||
package oidc
|
||||
44
vendor/github.com/coreos/go-oidc/oidc/identity.go
generated
vendored
44
vendor/github.com/coreos/go-oidc/oidc/identity.go
generated
vendored
@@ -1,44 +0,0 @@
|
||||
package oidc
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"time"
|
||||
|
||||
"github.com/coreos/go-oidc/jose"
|
||||
)
|
||||
|
||||
type Identity struct {
|
||||
ID string
|
||||
Name string
|
||||
Email string
|
||||
ExpiresAt time.Time
|
||||
}
|
||||
|
||||
func IdentityFromClaims(claims jose.Claims) (*Identity, error) {
|
||||
if claims == nil {
|
||||
return nil, errors.New("nil claim set")
|
||||
}
|
||||
|
||||
var ident Identity
|
||||
var err error
|
||||
var ok bool
|
||||
|
||||
if ident.ID, ok, err = claims.StringClaim("sub"); err != nil {
|
||||
return nil, err
|
||||
} else if !ok {
|
||||
return nil, errors.New("missing required claim: sub")
|
||||
}
|
||||
|
||||
if ident.Email, _, err = claims.StringClaim("email"); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
exp, ok, err := claims.TimeClaim("exp")
|
||||
if err != nil {
|
||||
return nil, err
|
||||
} else if ok {
|
||||
ident.ExpiresAt = exp
|
||||
}
|
||||
|
||||
return &ident, nil
|
||||
}
|
||||
113
vendor/github.com/coreos/go-oidc/oidc/identity_test.go
generated
vendored
113
vendor/github.com/coreos/go-oidc/oidc/identity_test.go
generated
vendored
@@ -1,113 +0,0 @@
|
||||
package oidc
|
||||
|
||||
import (
|
||||
"reflect"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/coreos/go-oidc/jose"
|
||||
)
|
||||
|
||||
func TestIdentityFromClaims(t *testing.T) {
|
||||
tests := []struct {
|
||||
claims jose.Claims
|
||||
want Identity
|
||||
}{
|
||||
{
|
||||
claims: jose.Claims{
|
||||
"sub": "123850281",
|
||||
"name": "Elroy",
|
||||
"email": "elroy@example.com",
|
||||
"exp": float64(1.416935146e+09),
|
||||
},
|
||||
want: Identity{
|
||||
ID: "123850281",
|
||||
Name: "",
|
||||
Email: "elroy@example.com",
|
||||
ExpiresAt: time.Date(2014, time.November, 25, 17, 05, 46, 0, time.UTC),
|
||||
},
|
||||
},
|
||||
{
|
||||
claims: jose.Claims{
|
||||
"sub": "123850281",
|
||||
"name": "Elroy",
|
||||
"exp": float64(1.416935146e+09),
|
||||
},
|
||||
want: Identity{
|
||||
ID: "123850281",
|
||||
Name: "",
|
||||
Email: "",
|
||||
ExpiresAt: time.Date(2014, time.November, 25, 17, 05, 46, 0, time.UTC),
|
||||
},
|
||||
},
|
||||
{
|
||||
claims: jose.Claims{
|
||||
"sub": "123850281",
|
||||
"name": "Elroy",
|
||||
"email": "elroy@example.com",
|
||||
"exp": int64(1416935146),
|
||||
},
|
||||
want: Identity{
|
||||
ID: "123850281",
|
||||
Name: "",
|
||||
Email: "elroy@example.com",
|
||||
ExpiresAt: time.Date(2014, time.November, 25, 17, 05, 46, 0, time.UTC),
|
||||
},
|
||||
},
|
||||
{
|
||||
claims: jose.Claims{
|
||||
"sub": "123850281",
|
||||
"name": "Elroy",
|
||||
"email": "elroy@example.com",
|
||||
},
|
||||
want: Identity{
|
||||
ID: "123850281",
|
||||
Name: "",
|
||||
Email: "elroy@example.com",
|
||||
ExpiresAt: time.Time{},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
for i, tt := range tests {
|
||||
got, err := IdentityFromClaims(tt.claims)
|
||||
if err != nil {
|
||||
t.Errorf("case %d: unexpected error: %v", i, err)
|
||||
continue
|
||||
}
|
||||
if !reflect.DeepEqual(tt.want, *got) {
|
||||
t.Errorf("case %d: want=%#v got=%#v", i, tt.want, *got)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestIdentityFromClaimsFail(t *testing.T) {
|
||||
tests := []jose.Claims{
|
||||
// sub incorrect type
|
||||
jose.Claims{
|
||||
"sub": 123,
|
||||
"name": "foo",
|
||||
"email": "elroy@example.com",
|
||||
},
|
||||
// email incorrect type
|
||||
jose.Claims{
|
||||
"sub": "123850281",
|
||||
"name": "Elroy",
|
||||
"email": false,
|
||||
},
|
||||
// exp incorrect type
|
||||
jose.Claims{
|
||||
"sub": "123850281",
|
||||
"name": "Elroy",
|
||||
"email": "elroy@example.com",
|
||||
"exp": "2014-11-25 18:05:46 +0000 UTC",
|
||||
},
|
||||
}
|
||||
|
||||
for i, tt := range tests {
|
||||
_, err := IdentityFromClaims(tt)
|
||||
if err == nil {
|
||||
t.Errorf("case %d: expected non-nil error", i)
|
||||
}
|
||||
}
|
||||
}
|
||||
3
vendor/github.com/coreos/go-oidc/oidc/interface.go
generated
vendored
3
vendor/github.com/coreos/go-oidc/oidc/interface.go
generated
vendored
@@ -1,3 +0,0 @@
|
||||
package oidc
|
||||
|
||||
type LoginFunc func(ident Identity, sessionKey string) (redirectURL string, err error)
|
||||
67
vendor/github.com/coreos/go-oidc/oidc/key.go
generated
vendored
67
vendor/github.com/coreos/go-oidc/oidc/key.go
generated
vendored
@@ -1,67 +0,0 @@
|
||||
package oidc
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"net/http"
|
||||
"time"
|
||||
|
||||
phttp "github.com/coreos/go-oidc/http"
|
||||
"github.com/coreos/go-oidc/jose"
|
||||
"github.com/coreos/go-oidc/key"
|
||||
)
|
||||
|
||||
// DefaultPublicKeySetTTL is the default TTL set on the PublicKeySet if no
|
||||
// Cache-Control header is provided by the JWK Set document endpoint.
|
||||
const DefaultPublicKeySetTTL = 24 * time.Hour
|
||||
|
||||
// NewRemotePublicKeyRepo is responsible for fetching the JWK Set document.
|
||||
func NewRemotePublicKeyRepo(hc phttp.Client, ep string) *remotePublicKeyRepo {
|
||||
return &remotePublicKeyRepo{hc: hc, ep: ep}
|
||||
}
|
||||
|
||||
type remotePublicKeyRepo struct {
|
||||
hc phttp.Client
|
||||
ep string
|
||||
}
|
||||
|
||||
// Get returns a PublicKeySet fetched from the JWK Set document endpoint. A TTL
|
||||
// is set on the Key Set to avoid it having to be re-retrieved for every
|
||||
// encryption event. This TTL is typically controlled by the endpoint returning
|
||||
// a Cache-Control header, but defaults to 24 hours if no Cache-Control header
|
||||
// is found.
|
||||
func (r *remotePublicKeyRepo) Get() (key.KeySet, error) {
|
||||
req, err := http.NewRequest("GET", r.ep, nil)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
resp, err := r.hc.Do(req)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
var d struct {
|
||||
Keys []jose.JWK `json:"keys"`
|
||||
}
|
||||
if err := json.NewDecoder(resp.Body).Decode(&d); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if len(d.Keys) == 0 {
|
||||
return nil, errors.New("zero keys in response")
|
||||
}
|
||||
|
||||
ttl, ok, err := phttp.Cacheable(resp.Header)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if !ok {
|
||||
ttl = DefaultPublicKeySetTTL
|
||||
}
|
||||
|
||||
exp := time.Now().UTC().Add(ttl)
|
||||
ks := key.NewPublicKeySet(d.Keys, exp)
|
||||
return ks, nil
|
||||
}
|
||||
690
vendor/github.com/coreos/go-oidc/oidc/provider.go
generated
vendored
690
vendor/github.com/coreos/go-oidc/oidc/provider.go
generated
vendored
@@ -1,690 +0,0 @@
|
||||
package oidc
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"log"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/coreos/pkg/timeutil"
|
||||
"github.com/jonboulle/clockwork"
|
||||
|
||||
phttp "github.com/coreos/go-oidc/http"
|
||||
"github.com/coreos/go-oidc/oauth2"
|
||||
)
|
||||
|
||||
const (
|
||||
// Subject Identifier types defined by the OIDC spec. Specifies if the provider
|
||||
// should provide the same sub claim value to all clients (public) or a unique
|
||||
// value for each client (pairwise).
|
||||
//
|
||||
// See: http://openid.net/specs/openid-connect-core-1_0.html#SubjectIDTypes
|
||||
SubjectTypePublic = "public"
|
||||
SubjectTypePairwise = "pairwise"
|
||||
)
|
||||
|
||||
var (
|
||||
// Default values for omitted provider config fields.
|
||||
//
|
||||
// Use ProviderConfig's Defaults method to fill a provider config with these values.
|
||||
DefaultGrantTypesSupported = []string{oauth2.GrantTypeAuthCode, oauth2.GrantTypeImplicit}
|
||||
DefaultResponseModesSupported = []string{"query", "fragment"}
|
||||
DefaultTokenEndpointAuthMethodsSupported = []string{oauth2.AuthMethodClientSecretBasic}
|
||||
DefaultClaimTypesSupported = []string{"normal"}
|
||||
)
|
||||
|
||||
const (
|
||||
MaximumProviderConfigSyncInterval = 24 * time.Hour
|
||||
MinimumProviderConfigSyncInterval = time.Minute
|
||||
|
||||
discoveryConfigPath = "/.well-known/openid-configuration"
|
||||
)
|
||||
|
||||
// internally configurable for tests
|
||||
var minimumProviderConfigSyncInterval = MinimumProviderConfigSyncInterval
|
||||
|
||||
var (
|
||||
// Ensure ProviderConfig satisfies these interfaces.
|
||||
_ json.Marshaler = &ProviderConfig{}
|
||||
_ json.Unmarshaler = &ProviderConfig{}
|
||||
)
|
||||
|
||||
// ProviderConfig represents the OpenID Provider Metadata specifying what
|
||||
// configurations a provider supports.
|
||||
//
|
||||
// See: http://openid.net/specs/openid-connect-discovery-1_0.html#ProviderMetadata
|
||||
type ProviderConfig struct {
|
||||
Issuer *url.URL // Required
|
||||
AuthEndpoint *url.URL // Required
|
||||
TokenEndpoint *url.URL // Required if grant types other than "implicit" are supported
|
||||
UserInfoEndpoint *url.URL
|
||||
KeysEndpoint *url.URL // Required
|
||||
RegistrationEndpoint *url.URL
|
||||
EndSessionEndpoint *url.URL
|
||||
CheckSessionIFrame *url.URL
|
||||
|
||||
// Servers MAY choose not to advertise some supported scope values even when this
|
||||
// parameter is used, although those defined in OpenID Core SHOULD be listed, if supported.
|
||||
ScopesSupported []string
|
||||
// OAuth2.0 response types supported.
|
||||
ResponseTypesSupported []string // Required
|
||||
// OAuth2.0 response modes supported.
|
||||
//
|
||||
// If omitted, defaults to DefaultResponseModesSupported.
|
||||
ResponseModesSupported []string
|
||||
// OAuth2.0 grant types supported.
|
||||
//
|
||||
// If omitted, defaults to DefaultGrantTypesSupported.
|
||||
GrantTypesSupported []string
|
||||
ACRValuesSupported []string
|
||||
// SubjectTypesSupported specifies strategies for providing values for the sub claim.
|
||||
SubjectTypesSupported []string // Required
|
||||
|
||||
// JWA signing and encryption algorith values supported for ID tokens.
|
||||
IDTokenSigningAlgValues []string // Required
|
||||
IDTokenEncryptionAlgValues []string
|
||||
IDTokenEncryptionEncValues []string
|
||||
|
||||
// JWA signing and encryption algorith values supported for user info responses.
|
||||
UserInfoSigningAlgValues []string
|
||||
UserInfoEncryptionAlgValues []string
|
||||
UserInfoEncryptionEncValues []string
|
||||
|
||||
// JWA signing and encryption algorith values supported for request objects.
|
||||
ReqObjSigningAlgValues []string
|
||||
ReqObjEncryptionAlgValues []string
|
||||
ReqObjEncryptionEncValues []string
|
||||
|
||||
TokenEndpointAuthMethodsSupported []string
|
||||
TokenEndpointAuthSigningAlgValuesSupported []string
|
||||
DisplayValuesSupported []string
|
||||
ClaimTypesSupported []string
|
||||
ClaimsSupported []string
|
||||
ServiceDocs *url.URL
|
||||
ClaimsLocalsSupported []string
|
||||
UILocalsSupported []string
|
||||
ClaimsParameterSupported bool
|
||||
RequestParameterSupported bool
|
||||
RequestURIParamaterSupported bool
|
||||
RequireRequestURIRegistration bool
|
||||
|
||||
Policy *url.URL
|
||||
TermsOfService *url.URL
|
||||
|
||||
// Not part of the OpenID Provider Metadata
|
||||
ExpiresAt time.Time
|
||||
}
|
||||
|
||||
// Defaults returns a shallow copy of ProviderConfig with default
|
||||
// values replacing omitted fields.
|
||||
//
|
||||
// var cfg oidc.ProviderConfig
|
||||
// // Fill provider config with default values for omitted fields.
|
||||
// cfg = cfg.Defaults()
|
||||
//
|
||||
func (p ProviderConfig) Defaults() ProviderConfig {
|
||||
setDefault := func(val *[]string, defaultVal []string) {
|
||||
if len(*val) == 0 {
|
||||
*val = defaultVal
|
||||
}
|
||||
}
|
||||
setDefault(&p.GrantTypesSupported, DefaultGrantTypesSupported)
|
||||
setDefault(&p.ResponseModesSupported, DefaultResponseModesSupported)
|
||||
setDefault(&p.TokenEndpointAuthMethodsSupported, DefaultTokenEndpointAuthMethodsSupported)
|
||||
setDefault(&p.ClaimTypesSupported, DefaultClaimTypesSupported)
|
||||
return p
|
||||
}
|
||||
|
||||
func (p *ProviderConfig) MarshalJSON() ([]byte, error) {
|
||||
e := p.toEncodableStruct()
|
||||
return json.Marshal(&e)
|
||||
}
|
||||
|
||||
func (p *ProviderConfig) UnmarshalJSON(data []byte) error {
|
||||
var e encodableProviderConfig
|
||||
if err := json.Unmarshal(data, &e); err != nil {
|
||||
return err
|
||||
}
|
||||
conf, err := e.toStruct()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if err := conf.Valid(); err != nil {
|
||||
return err
|
||||
}
|
||||
*p = conf
|
||||
return nil
|
||||
}
|
||||
|
||||
type encodableProviderConfig struct {
|
||||
Issuer string `json:"issuer"`
|
||||
AuthEndpoint string `json:"authorization_endpoint"`
|
||||
TokenEndpoint string `json:"token_endpoint"`
|
||||
UserInfoEndpoint string `json:"userinfo_endpoint,omitempty"`
|
||||
KeysEndpoint string `json:"jwks_uri"`
|
||||
RegistrationEndpoint string `json:"registration_endpoint,omitempty"`
|
||||
EndSessionEndpoint string `json:"end_session_endpoint,omitempty"`
|
||||
CheckSessionIFrame string `json:"check_session_iframe,omitempty"`
|
||||
|
||||
// Use 'omitempty' for all slices as per OIDC spec:
|
||||
// "Claims that return multiple values are represented as JSON arrays.
|
||||
// Claims with zero elements MUST be omitted from the response."
|
||||
// http://openid.net/specs/openid-connect-discovery-1_0.html#ProviderConfigurationResponse
|
||||
|
||||
ScopesSupported []string `json:"scopes_supported,omitempty"`
|
||||
ResponseTypesSupported []string `json:"response_types_supported,omitempty"`
|
||||
ResponseModesSupported []string `json:"response_modes_supported,omitempty"`
|
||||
GrantTypesSupported []string `json:"grant_types_supported,omitempty"`
|
||||
ACRValuesSupported []string `json:"acr_values_supported,omitempty"`
|
||||
SubjectTypesSupported []string `json:"subject_types_supported,omitempty"`
|
||||
|
||||
IDTokenSigningAlgValues []string `json:"id_token_signing_alg_values_supported,omitempty"`
|
||||
IDTokenEncryptionAlgValues []string `json:"id_token_encryption_alg_values_supported,omitempty"`
|
||||
IDTokenEncryptionEncValues []string `json:"id_token_encryption_enc_values_supported,omitempty"`
|
||||
UserInfoSigningAlgValues []string `json:"userinfo_signing_alg_values_supported,omitempty"`
|
||||
UserInfoEncryptionAlgValues []string `json:"userinfo_encryption_alg_values_supported,omitempty"`
|
||||
UserInfoEncryptionEncValues []string `json:"userinfo_encryption_enc_values_supported,omitempty"`
|
||||
ReqObjSigningAlgValues []string `json:"request_object_signing_alg_values_supported,omitempty"`
|
||||
ReqObjEncryptionAlgValues []string `json:"request_object_encryption_alg_values_supported,omitempty"`
|
||||
ReqObjEncryptionEncValues []string `json:"request_object_encryption_enc_values_supported,omitempty"`
|
||||
|
||||
TokenEndpointAuthMethodsSupported []string `json:"token_endpoint_auth_methods_supported,omitempty"`
|
||||
TokenEndpointAuthSigningAlgValuesSupported []string `json:"token_endpoint_auth_signing_alg_values_supported,omitempty"`
|
||||
|
||||
DisplayValuesSupported []string `json:"display_values_supported,omitempty"`
|
||||
ClaimTypesSupported []string `json:"claim_types_supported,omitempty"`
|
||||
ClaimsSupported []string `json:"claims_supported,omitempty"`
|
||||
ServiceDocs string `json:"service_documentation,omitempty"`
|
||||
ClaimsLocalsSupported []string `json:"claims_locales_supported,omitempty"`
|
||||
UILocalsSupported []string `json:"ui_locales_supported,omitempty"`
|
||||
ClaimsParameterSupported bool `json:"claims_parameter_supported,omitempty"`
|
||||
RequestParameterSupported bool `json:"request_parameter_supported,omitempty"`
|
||||
RequestURIParamaterSupported bool `json:"request_uri_parameter_supported,omitempty"`
|
||||
RequireRequestURIRegistration bool `json:"require_request_uri_registration,omitempty"`
|
||||
|
||||
Policy string `json:"op_policy_uri,omitempty"`
|
||||
TermsOfService string `json:"op_tos_uri,omitempty"`
|
||||
}
|
||||
|
||||
func (cfg ProviderConfig) toEncodableStruct() encodableProviderConfig {
|
||||
return encodableProviderConfig{
|
||||
Issuer: uriToString(cfg.Issuer),
|
||||
AuthEndpoint: uriToString(cfg.AuthEndpoint),
|
||||
TokenEndpoint: uriToString(cfg.TokenEndpoint),
|
||||
UserInfoEndpoint: uriToString(cfg.UserInfoEndpoint),
|
||||
KeysEndpoint: uriToString(cfg.KeysEndpoint),
|
||||
RegistrationEndpoint: uriToString(cfg.RegistrationEndpoint),
|
||||
EndSessionEndpoint: uriToString(cfg.EndSessionEndpoint),
|
||||
CheckSessionIFrame: uriToString(cfg.CheckSessionIFrame),
|
||||
ScopesSupported: cfg.ScopesSupported,
|
||||
ResponseTypesSupported: cfg.ResponseTypesSupported,
|
||||
ResponseModesSupported: cfg.ResponseModesSupported,
|
||||
GrantTypesSupported: cfg.GrantTypesSupported,
|
||||
ACRValuesSupported: cfg.ACRValuesSupported,
|
||||
SubjectTypesSupported: cfg.SubjectTypesSupported,
|
||||
IDTokenSigningAlgValues: cfg.IDTokenSigningAlgValues,
|
||||
IDTokenEncryptionAlgValues: cfg.IDTokenEncryptionAlgValues,
|
||||
IDTokenEncryptionEncValues: cfg.IDTokenEncryptionEncValues,
|
||||
UserInfoSigningAlgValues: cfg.UserInfoSigningAlgValues,
|
||||
UserInfoEncryptionAlgValues: cfg.UserInfoEncryptionAlgValues,
|
||||
UserInfoEncryptionEncValues: cfg.UserInfoEncryptionEncValues,
|
||||
ReqObjSigningAlgValues: cfg.ReqObjSigningAlgValues,
|
||||
ReqObjEncryptionAlgValues: cfg.ReqObjEncryptionAlgValues,
|
||||
ReqObjEncryptionEncValues: cfg.ReqObjEncryptionEncValues,
|
||||
TokenEndpointAuthMethodsSupported: cfg.TokenEndpointAuthMethodsSupported,
|
||||
TokenEndpointAuthSigningAlgValuesSupported: cfg.TokenEndpointAuthSigningAlgValuesSupported,
|
||||
DisplayValuesSupported: cfg.DisplayValuesSupported,
|
||||
ClaimTypesSupported: cfg.ClaimTypesSupported,
|
||||
ClaimsSupported: cfg.ClaimsSupported,
|
||||
ServiceDocs: uriToString(cfg.ServiceDocs),
|
||||
ClaimsLocalsSupported: cfg.ClaimsLocalsSupported,
|
||||
UILocalsSupported: cfg.UILocalsSupported,
|
||||
ClaimsParameterSupported: cfg.ClaimsParameterSupported,
|
||||
RequestParameterSupported: cfg.RequestParameterSupported,
|
||||
RequestURIParamaterSupported: cfg.RequestURIParamaterSupported,
|
||||
RequireRequestURIRegistration: cfg.RequireRequestURIRegistration,
|
||||
Policy: uriToString(cfg.Policy),
|
||||
TermsOfService: uriToString(cfg.TermsOfService),
|
||||
}
|
||||
}
|
||||
|
||||
func (e encodableProviderConfig) toStruct() (ProviderConfig, error) {
|
||||
p := stickyErrParser{}
|
||||
conf := ProviderConfig{
|
||||
Issuer: p.parseURI(e.Issuer, "issuer"),
|
||||
AuthEndpoint: p.parseURI(e.AuthEndpoint, "authorization_endpoint"),
|
||||
TokenEndpoint: p.parseURI(e.TokenEndpoint, "token_endpoint"),
|
||||
UserInfoEndpoint: p.parseURI(e.UserInfoEndpoint, "userinfo_endpoint"),
|
||||
KeysEndpoint: p.parseURI(e.KeysEndpoint, "jwks_uri"),
|
||||
RegistrationEndpoint: p.parseURI(e.RegistrationEndpoint, "registration_endpoint"),
|
||||
EndSessionEndpoint: p.parseURI(e.EndSessionEndpoint, "end_session_endpoint"),
|
||||
CheckSessionIFrame: p.parseURI(e.CheckSessionIFrame, "check_session_iframe"),
|
||||
ScopesSupported: e.ScopesSupported,
|
||||
ResponseTypesSupported: e.ResponseTypesSupported,
|
||||
ResponseModesSupported: e.ResponseModesSupported,
|
||||
GrantTypesSupported: e.GrantTypesSupported,
|
||||
ACRValuesSupported: e.ACRValuesSupported,
|
||||
SubjectTypesSupported: e.SubjectTypesSupported,
|
||||
IDTokenSigningAlgValues: e.IDTokenSigningAlgValues,
|
||||
IDTokenEncryptionAlgValues: e.IDTokenEncryptionAlgValues,
|
||||
IDTokenEncryptionEncValues: e.IDTokenEncryptionEncValues,
|
||||
UserInfoSigningAlgValues: e.UserInfoSigningAlgValues,
|
||||
UserInfoEncryptionAlgValues: e.UserInfoEncryptionAlgValues,
|
||||
UserInfoEncryptionEncValues: e.UserInfoEncryptionEncValues,
|
||||
ReqObjSigningAlgValues: e.ReqObjSigningAlgValues,
|
||||
ReqObjEncryptionAlgValues: e.ReqObjEncryptionAlgValues,
|
||||
ReqObjEncryptionEncValues: e.ReqObjEncryptionEncValues,
|
||||
TokenEndpointAuthMethodsSupported: e.TokenEndpointAuthMethodsSupported,
|
||||
TokenEndpointAuthSigningAlgValuesSupported: e.TokenEndpointAuthSigningAlgValuesSupported,
|
||||
DisplayValuesSupported: e.DisplayValuesSupported,
|
||||
ClaimTypesSupported: e.ClaimTypesSupported,
|
||||
ClaimsSupported: e.ClaimsSupported,
|
||||
ServiceDocs: p.parseURI(e.ServiceDocs, "service_documentation"),
|
||||
ClaimsLocalsSupported: e.ClaimsLocalsSupported,
|
||||
UILocalsSupported: e.UILocalsSupported,
|
||||
ClaimsParameterSupported: e.ClaimsParameterSupported,
|
||||
RequestParameterSupported: e.RequestParameterSupported,
|
||||
RequestURIParamaterSupported: e.RequestURIParamaterSupported,
|
||||
RequireRequestURIRegistration: e.RequireRequestURIRegistration,
|
||||
Policy: p.parseURI(e.Policy, "op_policy-uri"),
|
||||
TermsOfService: p.parseURI(e.TermsOfService, "op_tos_uri"),
|
||||
}
|
||||
if p.firstErr != nil {
|
||||
return ProviderConfig{}, p.firstErr
|
||||
}
|
||||
return conf, nil
|
||||
}
|
||||
|
||||
// Empty returns if a ProviderConfig holds no information.
|
||||
//
|
||||
// This case generally indicates a ProviderConfigGetter has experienced an error
|
||||
// and has nothing to report.
|
||||
func (p ProviderConfig) Empty() bool {
|
||||
return p.Issuer == nil
|
||||
}
|
||||
|
||||
func contains(sli []string, ele string) bool {
|
||||
for _, s := range sli {
|
||||
if s == ele {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// Valid determines if a ProviderConfig conforms with the OIDC specification.
|
||||
// If Valid returns successfully it guarantees required field are non-nil and
|
||||
// URLs are well formed.
|
||||
//
|
||||
// Valid is called by UnmarshalJSON.
|
||||
//
|
||||
// NOTE(ericchiang): For development purposes Valid does not mandate 'https' for
|
||||
// URLs fields where the OIDC spec requires it. This may change in future releases
|
||||
// of this package. See: https://github.com/coreos/go-oidc/issues/34
|
||||
func (p ProviderConfig) Valid() error {
|
||||
grantTypes := p.GrantTypesSupported
|
||||
if len(grantTypes) == 0 {
|
||||
grantTypes = DefaultGrantTypesSupported
|
||||
}
|
||||
implicitOnly := true
|
||||
for _, grantType := range grantTypes {
|
||||
if grantType != oauth2.GrantTypeImplicit {
|
||||
implicitOnly = false
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
if len(p.SubjectTypesSupported) == 0 {
|
||||
return errors.New("missing required field subject_types_supported")
|
||||
}
|
||||
if len(p.IDTokenSigningAlgValues) == 0 {
|
||||
return errors.New("missing required field id_token_signing_alg_values_supported")
|
||||
}
|
||||
|
||||
if len(p.ScopesSupported) != 0 && !contains(p.ScopesSupported, "openid") {
|
||||
return errors.New("scoped_supported must be unspecified or include 'openid'")
|
||||
}
|
||||
|
||||
if !contains(p.IDTokenSigningAlgValues, "RS256") {
|
||||
return errors.New("id_token_signing_alg_values_supported must include 'RS256'")
|
||||
}
|
||||
if contains(p.TokenEndpointAuthMethodsSupported, "none") {
|
||||
return errors.New("token_endpoint_auth_signing_alg_values_supported cannot include 'none'")
|
||||
}
|
||||
|
||||
uris := []struct {
|
||||
val *url.URL
|
||||
name string
|
||||
required bool
|
||||
}{
|
||||
{p.Issuer, "issuer", true},
|
||||
{p.AuthEndpoint, "authorization_endpoint", true},
|
||||
{p.TokenEndpoint, "token_endpoint", !implicitOnly},
|
||||
{p.UserInfoEndpoint, "userinfo_endpoint", false},
|
||||
{p.KeysEndpoint, "jwks_uri", true},
|
||||
{p.RegistrationEndpoint, "registration_endpoint", false},
|
||||
{p.EndSessionEndpoint, "end_session_endpoint", false},
|
||||
{p.CheckSessionIFrame, "check_session_iframe", false},
|
||||
{p.ServiceDocs, "service_documentation", false},
|
||||
{p.Policy, "op_policy_uri", false},
|
||||
{p.TermsOfService, "op_tos_uri", false},
|
||||
}
|
||||
|
||||
for _, uri := range uris {
|
||||
if uri.val == nil {
|
||||
if !uri.required {
|
||||
continue
|
||||
}
|
||||
return fmt.Errorf("empty value for required uri field %s", uri.name)
|
||||
}
|
||||
if uri.val.Host == "" {
|
||||
return fmt.Errorf("no host for uri field %s", uri.name)
|
||||
}
|
||||
if uri.val.Scheme != "http" && uri.val.Scheme != "https" {
|
||||
return fmt.Errorf("uri field %s schemeis not http or https", uri.name)
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// Supports determines if provider supports a client given their respective metadata.
|
||||
func (p ProviderConfig) Supports(c ClientMetadata) error {
|
||||
if err := p.Valid(); err != nil {
|
||||
return fmt.Errorf("invalid provider config: %v", err)
|
||||
}
|
||||
if err := c.Valid(); err != nil {
|
||||
return fmt.Errorf("invalid client config: %v", err)
|
||||
}
|
||||
|
||||
// Fill default values for omitted fields
|
||||
c = c.Defaults()
|
||||
p = p.Defaults()
|
||||
|
||||
// Do the supported values list the requested one?
|
||||
supports := []struct {
|
||||
supported []string
|
||||
requested string
|
||||
name string
|
||||
}{
|
||||
{p.IDTokenSigningAlgValues, c.IDTokenResponseOptions.SigningAlg, "id_token_signed_response_alg"},
|
||||
{p.IDTokenEncryptionAlgValues, c.IDTokenResponseOptions.EncryptionAlg, "id_token_encryption_response_alg"},
|
||||
{p.IDTokenEncryptionEncValues, c.IDTokenResponseOptions.EncryptionEnc, "id_token_encryption_response_enc"},
|
||||
{p.UserInfoSigningAlgValues, c.UserInfoResponseOptions.SigningAlg, "userinfo_signed_response_alg"},
|
||||
{p.UserInfoEncryptionAlgValues, c.UserInfoResponseOptions.EncryptionAlg, "userinfo_encryption_response_alg"},
|
||||
{p.UserInfoEncryptionEncValues, c.UserInfoResponseOptions.EncryptionEnc, "userinfo_encryption_response_enc"},
|
||||
{p.ReqObjSigningAlgValues, c.RequestObjectOptions.SigningAlg, "request_object_signing_alg"},
|
||||
{p.ReqObjEncryptionAlgValues, c.RequestObjectOptions.EncryptionAlg, "request_object_encryption_alg"},
|
||||
{p.ReqObjEncryptionEncValues, c.RequestObjectOptions.EncryptionEnc, "request_object_encryption_enc"},
|
||||
}
|
||||
for _, field := range supports {
|
||||
if field.requested == "" {
|
||||
continue
|
||||
}
|
||||
if !contains(field.supported, field.requested) {
|
||||
return fmt.Errorf("provider does not support requested value for field %s", field.name)
|
||||
}
|
||||
}
|
||||
|
||||
stringsEqual := func(s1, s2 string) bool { return s1 == s2 }
|
||||
|
||||
// For lists, are the list of requested values a subset of the supported ones?
|
||||
supportsAll := []struct {
|
||||
supported []string
|
||||
requested []string
|
||||
name string
|
||||
// OAuth2.0 response_type can be space separated lists where order doesn't matter.
|
||||
// For example "id_token token" is the same as "token id_token"
|
||||
// Support a custom compare method.
|
||||
comp func(s1, s2 string) bool
|
||||
}{
|
||||
{p.GrantTypesSupported, c.GrantTypes, "grant_types", stringsEqual},
|
||||
{p.ResponseTypesSupported, c.ResponseTypes, "response_type", oauth2.ResponseTypesEqual},
|
||||
}
|
||||
for _, field := range supportsAll {
|
||||
requestLoop:
|
||||
for _, req := range field.requested {
|
||||
for _, sup := range field.supported {
|
||||
if field.comp(req, sup) {
|
||||
continue requestLoop
|
||||
}
|
||||
}
|
||||
return fmt.Errorf("provider does not support requested value for field %s", field.name)
|
||||
}
|
||||
}
|
||||
|
||||
// TODO(ericchiang): Are there more checks we feel comfortable with begin strict about?
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (p ProviderConfig) SupportsGrantType(grantType string) bool {
|
||||
var supported []string
|
||||
if len(p.GrantTypesSupported) == 0 {
|
||||
supported = DefaultGrantTypesSupported
|
||||
} else {
|
||||
supported = p.GrantTypesSupported
|
||||
}
|
||||
|
||||
for _, t := range supported {
|
||||
if t == grantType {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
type ProviderConfigGetter interface {
|
||||
Get() (ProviderConfig, error)
|
||||
}
|
||||
|
||||
type ProviderConfigSetter interface {
|
||||
Set(ProviderConfig) error
|
||||
}
|
||||
|
||||
type ProviderConfigSyncer struct {
|
||||
from ProviderConfigGetter
|
||||
to ProviderConfigSetter
|
||||
clock clockwork.Clock
|
||||
|
||||
initialSyncDone bool
|
||||
initialSyncWait sync.WaitGroup
|
||||
}
|
||||
|
||||
func NewProviderConfigSyncer(from ProviderConfigGetter, to ProviderConfigSetter) *ProviderConfigSyncer {
|
||||
return &ProviderConfigSyncer{
|
||||
from: from,
|
||||
to: to,
|
||||
clock: clockwork.NewRealClock(),
|
||||
}
|
||||
}
|
||||
|
||||
func (s *ProviderConfigSyncer) Run() chan struct{} {
|
||||
stop := make(chan struct{})
|
||||
|
||||
var next pcsStepper
|
||||
next = &pcsStepNext{aft: time.Duration(0)}
|
||||
|
||||
s.initialSyncWait.Add(1)
|
||||
go func() {
|
||||
for {
|
||||
select {
|
||||
case <-s.clock.After(next.after()):
|
||||
next = next.step(s.sync)
|
||||
case <-stop:
|
||||
return
|
||||
}
|
||||
}
|
||||
}()
|
||||
|
||||
return stop
|
||||
}
|
||||
|
||||
func (s *ProviderConfigSyncer) WaitUntilInitialSync() {
|
||||
s.initialSyncWait.Wait()
|
||||
}
|
||||
|
||||
func (s *ProviderConfigSyncer) sync() (time.Duration, error) {
|
||||
cfg, err := s.from.Get()
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
|
||||
if err = s.to.Set(cfg); err != nil {
|
||||
return 0, fmt.Errorf("error setting provider config: %v", err)
|
||||
}
|
||||
|
||||
if !s.initialSyncDone {
|
||||
s.initialSyncWait.Done()
|
||||
s.initialSyncDone = true
|
||||
}
|
||||
|
||||
return nextSyncAfter(cfg.ExpiresAt, s.clock), nil
|
||||
}
|
||||
|
||||
type pcsStepFunc func() (time.Duration, error)
|
||||
|
||||
type pcsStepper interface {
|
||||
after() time.Duration
|
||||
step(pcsStepFunc) pcsStepper
|
||||
}
|
||||
|
||||
type pcsStepNext struct {
|
||||
aft time.Duration
|
||||
}
|
||||
|
||||
func (n *pcsStepNext) after() time.Duration {
|
||||
return n.aft
|
||||
}
|
||||
|
||||
func (n *pcsStepNext) step(fn pcsStepFunc) (next pcsStepper) {
|
||||
ttl, err := fn()
|
||||
if err == nil {
|
||||
next = &pcsStepNext{aft: ttl}
|
||||
} else {
|
||||
next = &pcsStepRetry{aft: time.Second}
|
||||
log.Printf("go-oidc: provider config sync failed, retrying in %v: %v", next.after(), err)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
type pcsStepRetry struct {
|
||||
aft time.Duration
|
||||
}
|
||||
|
||||
func (r *pcsStepRetry) after() time.Duration {
|
||||
return r.aft
|
||||
}
|
||||
|
||||
func (r *pcsStepRetry) step(fn pcsStepFunc) (next pcsStepper) {
|
||||
ttl, err := fn()
|
||||
if err == nil {
|
||||
next = &pcsStepNext{aft: ttl}
|
||||
} else {
|
||||
next = &pcsStepRetry{aft: timeutil.ExpBackoff(r.aft, time.Minute)}
|
||||
log.Printf("go-oidc: provider config sync failed, retrying in %v: %v", next.after(), err)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
func nextSyncAfter(exp time.Time, clock clockwork.Clock) time.Duration {
|
||||
if exp.IsZero() {
|
||||
return MaximumProviderConfigSyncInterval
|
||||
}
|
||||
|
||||
t := exp.Sub(clock.Now()) / 2
|
||||
if t > MaximumProviderConfigSyncInterval {
|
||||
t = MaximumProviderConfigSyncInterval
|
||||
} else if t < minimumProviderConfigSyncInterval {
|
||||
t = minimumProviderConfigSyncInterval
|
||||
}
|
||||
|
||||
return t
|
||||
}
|
||||
|
||||
type httpProviderConfigGetter struct {
|
||||
hc phttp.Client
|
||||
issuerURL string
|
||||
clock clockwork.Clock
|
||||
}
|
||||
|
||||
func NewHTTPProviderConfigGetter(hc phttp.Client, issuerURL string) *httpProviderConfigGetter {
|
||||
return &httpProviderConfigGetter{
|
||||
hc: hc,
|
||||
issuerURL: issuerURL,
|
||||
clock: clockwork.NewRealClock(),
|
||||
}
|
||||
}
|
||||
|
||||
func (r *httpProviderConfigGetter) Get() (cfg ProviderConfig, err error) {
|
||||
// If the Issuer value contains a path component, any terminating / MUST be removed before
|
||||
// appending /.well-known/openid-configuration.
|
||||
// https://openid.net/specs/openid-connect-discovery-1_0.html#ProviderConfigurationRequest
|
||||
discoveryURL := strings.TrimSuffix(r.issuerURL, "/") + discoveryConfigPath
|
||||
req, err := http.NewRequest("GET", discoveryURL, nil)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
resp, err := r.hc.Do(req)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
if err = json.NewDecoder(resp.Body).Decode(&cfg); err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
var ttl time.Duration
|
||||
var ok bool
|
||||
ttl, ok, err = phttp.Cacheable(resp.Header)
|
||||
if err != nil {
|
||||
return
|
||||
} else if ok {
|
||||
cfg.ExpiresAt = r.clock.Now().UTC().Add(ttl)
|
||||
}
|
||||
|
||||
// The issuer value returned MUST be identical to the Issuer URL that was directly used to retrieve the configuration information.
|
||||
// http://openid.net/specs/openid-connect-discovery-1_0.html#ProviderConfigurationValidation
|
||||
if !urlEqual(cfg.Issuer.String(), r.issuerURL) {
|
||||
err = fmt.Errorf(`"issuer" in config (%v) does not match provided issuer URL (%v)`, cfg.Issuer, r.issuerURL)
|
||||
return
|
||||
}
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
func FetchProviderConfig(hc phttp.Client, issuerURL string) (ProviderConfig, error) {
|
||||
if hc == nil {
|
||||
hc = http.DefaultClient
|
||||
}
|
||||
|
||||
g := NewHTTPProviderConfigGetter(hc, issuerURL)
|
||||
return g.Get()
|
||||
}
|
||||
|
||||
func WaitForProviderConfig(hc phttp.Client, issuerURL string) (pcfg ProviderConfig) {
|
||||
return waitForProviderConfig(hc, issuerURL, clockwork.NewRealClock())
|
||||
}
|
||||
|
||||
func waitForProviderConfig(hc phttp.Client, issuerURL string, clock clockwork.Clock) (pcfg ProviderConfig) {
|
||||
var sleep time.Duration
|
||||
var err error
|
||||
for {
|
||||
pcfg, err = FetchProviderConfig(hc, issuerURL)
|
||||
if err == nil {
|
||||
break
|
||||
}
|
||||
|
||||
sleep = timeutil.ExpBackoff(sleep, time.Minute)
|
||||
fmt.Printf("Failed fetching provider config, trying again in %v: %v\n", sleep, err)
|
||||
time.Sleep(sleep)
|
||||
}
|
||||
|
||||
return
|
||||
}
|
||||
940
vendor/github.com/coreos/go-oidc/oidc/provider_test.go
generated
vendored
940
vendor/github.com/coreos/go-oidc/oidc/provider_test.go
generated
vendored
@@ -1,940 +0,0 @@
|
||||
package oidc
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io/ioutil"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"net/url"
|
||||
"reflect"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/jonboulle/clockwork"
|
||||
"github.com/kylelemons/godebug/diff"
|
||||
"github.com/kylelemons/godebug/pretty"
|
||||
|
||||
"github.com/coreos/go-oidc/jose"
|
||||
"github.com/coreos/go-oidc/oauth2"
|
||||
)
|
||||
|
||||
func TestProviderConfigDefaults(t *testing.T) {
|
||||
var cfg ProviderConfig
|
||||
cfg = cfg.Defaults()
|
||||
tests := []struct {
|
||||
got, want []string
|
||||
name string
|
||||
}{
|
||||
{cfg.GrantTypesSupported, DefaultGrantTypesSupported, "grant types"},
|
||||
{cfg.ResponseModesSupported, DefaultResponseModesSupported, "response modes"},
|
||||
{cfg.ClaimTypesSupported, DefaultClaimTypesSupported, "claim types"},
|
||||
{
|
||||
cfg.TokenEndpointAuthMethodsSupported,
|
||||
DefaultTokenEndpointAuthMethodsSupported,
|
||||
"token endpoint auth methods",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
if diff := pretty.Compare(tt.want, tt.got); diff != "" {
|
||||
t.Errorf("%s: did not match %s", tt.name, diff)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestProviderConfigUnmarshal(t *testing.T) {
|
||||
|
||||
// helper for quickly creating uris
|
||||
uri := func(path string) *url.URL {
|
||||
return &url.URL{
|
||||
Scheme: "https",
|
||||
Host: "server.example.com",
|
||||
Path: path,
|
||||
}
|
||||
}
|
||||
|
||||
tests := []struct {
|
||||
data string
|
||||
want ProviderConfig
|
||||
wantErr bool
|
||||
}{
|
||||
{
|
||||
data: `{
|
||||
"issuer": "https://server.example.com",
|
||||
"authorization_endpoint": "https://server.example.com/connect/authorize",
|
||||
"token_endpoint": "https://server.example.com/connect/token",
|
||||
"token_endpoint_auth_methods_supported": ["client_secret_basic", "private_key_jwt"],
|
||||
"token_endpoint_auth_signing_alg_values_supported": ["RS256", "ES256"],
|
||||
"userinfo_endpoint": "https://server.example.com/connect/userinfo",
|
||||
"jwks_uri": "https://server.example.com/jwks.json",
|
||||
"registration_endpoint": "https://server.example.com/connect/register",
|
||||
"scopes_supported": [
|
||||
"openid", "profile", "email", "address", "phone", "offline_access"
|
||||
],
|
||||
"response_types_supported": [
|
||||
"code", "code id_token", "id_token", "id_token token"
|
||||
],
|
||||
"acr_values_supported": [
|
||||
"urn:mace:incommon:iap:silver", "urn:mace:incommon:iap:bronze"
|
||||
],
|
||||
"subject_types_supported": ["public", "pairwise"],
|
||||
"userinfo_signing_alg_values_supported": ["RS256", "ES256", "HS256"],
|
||||
"userinfo_encryption_alg_values_supported": ["RSA1_5", "A128KW"],
|
||||
"userinfo_encryption_enc_values_supported": ["A128CBC-HS256", "A128GCM"],
|
||||
"id_token_signing_alg_values_supported": ["RS256", "ES256", "HS256"],
|
||||
"id_token_encryption_alg_values_supported": ["RSA1_5", "A128KW"],
|
||||
"id_token_encryption_enc_values_supported": ["A128CBC-HS256", "A128GCM"],
|
||||
"request_object_signing_alg_values_supported": ["none", "RS256", "ES256"],
|
||||
"display_values_supported": ["page", "popup"],
|
||||
"claim_types_supported": ["normal", "distributed"],
|
||||
"claims_supported": [
|
||||
"sub", "iss", "auth_time", "acr", "name", "given_name",
|
||||
"family_name", "nickname", "profile", "picture", "website",
|
||||
"email", "email_verified", "locale", "zoneinfo",
|
||||
"http://example.info/claims/groups"
|
||||
],
|
||||
"claims_parameter_supported": true,
|
||||
"service_documentation": "https://server.example.com/connect/service_documentation.html",
|
||||
"ui_locales_supported": ["en-US", "en-GB", "en-CA", "fr-FR", "fr-CA"]
|
||||
}
|
||||
`,
|
||||
want: ProviderConfig{
|
||||
Issuer: &url.URL{Scheme: "https", Host: "server.example.com"},
|
||||
AuthEndpoint: uri("/connect/authorize"),
|
||||
TokenEndpoint: uri("/connect/token"),
|
||||
TokenEndpointAuthMethodsSupported: []string{
|
||||
oauth2.AuthMethodClientSecretBasic, oauth2.AuthMethodPrivateKeyJWT,
|
||||
},
|
||||
TokenEndpointAuthSigningAlgValuesSupported: []string{
|
||||
jose.AlgRS256, jose.AlgES256,
|
||||
},
|
||||
UserInfoEndpoint: uri("/connect/userinfo"),
|
||||
KeysEndpoint: uri("/jwks.json"),
|
||||
RegistrationEndpoint: uri("/connect/register"),
|
||||
ScopesSupported: []string{
|
||||
"openid", "profile", "email", "address", "phone", "offline_access",
|
||||
},
|
||||
ResponseTypesSupported: []string{
|
||||
oauth2.ResponseTypeCode, oauth2.ResponseTypeCodeIDToken,
|
||||
oauth2.ResponseTypeIDToken, oauth2.ResponseTypeIDTokenToken,
|
||||
},
|
||||
ACRValuesSupported: []string{
|
||||
"urn:mace:incommon:iap:silver", "urn:mace:incommon:iap:bronze",
|
||||
},
|
||||
SubjectTypesSupported: []string{
|
||||
SubjectTypePublic, SubjectTypePairwise,
|
||||
},
|
||||
UserInfoSigningAlgValues: []string{jose.AlgRS256, jose.AlgES256, jose.AlgHS256},
|
||||
UserInfoEncryptionAlgValues: []string{"RSA1_5", "A128KW"},
|
||||
UserInfoEncryptionEncValues: []string{"A128CBC-HS256", "A128GCM"},
|
||||
IDTokenSigningAlgValues: []string{jose.AlgRS256, jose.AlgES256, jose.AlgHS256},
|
||||
IDTokenEncryptionAlgValues: []string{"RSA1_5", "A128KW"},
|
||||
IDTokenEncryptionEncValues: []string{"A128CBC-HS256", "A128GCM"},
|
||||
ReqObjSigningAlgValues: []string{jose.AlgNone, jose.AlgRS256, jose.AlgES256},
|
||||
DisplayValuesSupported: []string{"page", "popup"},
|
||||
ClaimTypesSupported: []string{"normal", "distributed"},
|
||||
ClaimsSupported: []string{
|
||||
"sub", "iss", "auth_time", "acr", "name", "given_name",
|
||||
"family_name", "nickname", "profile", "picture", "website",
|
||||
"email", "email_verified", "locale", "zoneinfo",
|
||||
"http://example.info/claims/groups",
|
||||
},
|
||||
ClaimsParameterSupported: true,
|
||||
ServiceDocs: uri("/connect/service_documentation.html"),
|
||||
UILocalsSupported: []string{"en-US", "en-GB", "en-CA", "fr-FR", "fr-CA"},
|
||||
},
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
// missing a lot of required field
|
||||
data: `{}`,
|
||||
wantErr: true,
|
||||
},
|
||||
{
|
||||
data: `{
|
||||
"issuer": "https://server.example.com",
|
||||
"authorization_endpoint": "https://server.example.com/connect/authorize",
|
||||
"token_endpoint": "https://server.example.com/connect/token",
|
||||
"jwks_uri": "https://server.example.com/jwks.json",
|
||||
"response_types_supported": [
|
||||
"code", "code id_token", "id_token", "id_token token"
|
||||
],
|
||||
"subject_types_supported": ["public", "pairwise"],
|
||||
"id_token_signing_alg_values_supported": ["RS256", "ES256", "HS256"]
|
||||
}
|
||||
`,
|
||||
want: ProviderConfig{
|
||||
Issuer: &url.URL{Scheme: "https", Host: "server.example.com"},
|
||||
AuthEndpoint: uri("/connect/authorize"),
|
||||
TokenEndpoint: uri("/connect/token"),
|
||||
KeysEndpoint: uri("/jwks.json"),
|
||||
ResponseTypesSupported: []string{
|
||||
oauth2.ResponseTypeCode, oauth2.ResponseTypeCodeIDToken,
|
||||
oauth2.ResponseTypeIDToken, oauth2.ResponseTypeIDTokenToken,
|
||||
},
|
||||
SubjectTypesSupported: []string{
|
||||
SubjectTypePublic, SubjectTypePairwise,
|
||||
},
|
||||
IDTokenSigningAlgValues: []string{jose.AlgRS256, jose.AlgES256, jose.AlgHS256},
|
||||
},
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
// invalid scheme 'ftp://'
|
||||
data: `{
|
||||
"issuer": "https://server.example.com",
|
||||
"authorization_endpoint": "https://server.example.com/connect/authorize",
|
||||
"token_endpoint": "https://server.example.com/connect/token",
|
||||
"jwks_uri": "ftp://server.example.com/jwks.json",
|
||||
"response_types_supported": [
|
||||
"code", "code id_token", "id_token", "id_token token"
|
||||
],
|
||||
"subject_types_supported": ["public", "pairwise"],
|
||||
"id_token_signing_alg_values_supported": ["RS256", "ES256", "HS256"]
|
||||
}
|
||||
`,
|
||||
wantErr: true,
|
||||
},
|
||||
}
|
||||
for i, tt := range tests {
|
||||
var got ProviderConfig
|
||||
if err := json.Unmarshal([]byte(tt.data), &got); err != nil {
|
||||
if !tt.wantErr {
|
||||
t.Errorf("case %d: failed to unmarshal provider config: %v", i, err)
|
||||
}
|
||||
continue
|
||||
}
|
||||
if tt.wantErr {
|
||||
t.Errorf("case %d: expected error", i)
|
||||
continue
|
||||
}
|
||||
if diff := pretty.Compare(tt.want, got); diff != "" {
|
||||
t.Errorf("case %d: unmarshaled struct did not match expected %s", i, diff)
|
||||
}
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
func TestProviderConfigMarshal(t *testing.T) {
|
||||
tests := []struct {
|
||||
cfg ProviderConfig
|
||||
want string
|
||||
}{
|
||||
{
|
||||
cfg: ProviderConfig{
|
||||
Issuer: &url.URL{Scheme: "https", Host: "auth.example.com"},
|
||||
AuthEndpoint: &url.URL{
|
||||
Scheme: "https", Host: "auth.example.com", Path: "/auth",
|
||||
},
|
||||
TokenEndpoint: &url.URL{
|
||||
Scheme: "https", Host: "auth.example.com", Path: "/token",
|
||||
},
|
||||
UserInfoEndpoint: &url.URL{
|
||||
Scheme: "https", Host: "auth.example.com", Path: "/userinfo",
|
||||
},
|
||||
KeysEndpoint: &url.URL{
|
||||
Scheme: "https", Host: "auth.example.com", Path: "/jwk",
|
||||
},
|
||||
ResponseTypesSupported: []string{oauth2.ResponseTypeCode},
|
||||
SubjectTypesSupported: []string{SubjectTypePublic},
|
||||
IDTokenSigningAlgValues: []string{jose.AlgRS256},
|
||||
},
|
||||
// spacing must match json.MarshalIndent(cfg, "", "\t")
|
||||
want: `{
|
||||
"issuer": "https://auth.example.com",
|
||||
"authorization_endpoint": "https://auth.example.com/auth",
|
||||
"token_endpoint": "https://auth.example.com/token",
|
||||
"userinfo_endpoint": "https://auth.example.com/userinfo",
|
||||
"jwks_uri": "https://auth.example.com/jwk",
|
||||
"response_types_supported": [
|
||||
"code"
|
||||
],
|
||||
"subject_types_supported": [
|
||||
"public"
|
||||
],
|
||||
"id_token_signing_alg_values_supported": [
|
||||
"RS256"
|
||||
]
|
||||
}`,
|
||||
},
|
||||
{
|
||||
cfg: ProviderConfig{
|
||||
Issuer: &url.URL{Scheme: "https", Host: "auth.example.com"},
|
||||
AuthEndpoint: &url.URL{
|
||||
Scheme: "https", Host: "auth.example.com", Path: "/auth",
|
||||
},
|
||||
TokenEndpoint: &url.URL{
|
||||
Scheme: "https", Host: "auth.example.com", Path: "/token",
|
||||
},
|
||||
UserInfoEndpoint: &url.URL{
|
||||
Scheme: "https", Host: "auth.example.com", Path: "/userinfo",
|
||||
},
|
||||
KeysEndpoint: &url.URL{
|
||||
Scheme: "https", Host: "auth.example.com", Path: "/jwk",
|
||||
},
|
||||
RegistrationEndpoint: &url.URL{
|
||||
Scheme: "https", Host: "auth.example.com", Path: "/register",
|
||||
},
|
||||
ScopesSupported: DefaultScope,
|
||||
ResponseTypesSupported: []string{oauth2.ResponseTypeCode},
|
||||
ResponseModesSupported: DefaultResponseModesSupported,
|
||||
GrantTypesSupported: []string{oauth2.GrantTypeAuthCode},
|
||||
SubjectTypesSupported: []string{SubjectTypePublic},
|
||||
IDTokenSigningAlgValues: []string{jose.AlgRS256},
|
||||
ServiceDocs: &url.URL{Scheme: "https", Host: "example.com", Path: "/docs"},
|
||||
},
|
||||
// spacing must match json.MarshalIndent(cfg, "", "\t")
|
||||
want: `{
|
||||
"issuer": "https://auth.example.com",
|
||||
"authorization_endpoint": "https://auth.example.com/auth",
|
||||
"token_endpoint": "https://auth.example.com/token",
|
||||
"userinfo_endpoint": "https://auth.example.com/userinfo",
|
||||
"jwks_uri": "https://auth.example.com/jwk",
|
||||
"registration_endpoint": "https://auth.example.com/register",
|
||||
"scopes_supported": [
|
||||
"openid",
|
||||
"email",
|
||||
"profile"
|
||||
],
|
||||
"response_types_supported": [
|
||||
"code"
|
||||
],
|
||||
"response_modes_supported": [
|
||||
"query",
|
||||
"fragment"
|
||||
],
|
||||
"grant_types_supported": [
|
||||
"authorization_code"
|
||||
],
|
||||
"subject_types_supported": [
|
||||
"public"
|
||||
],
|
||||
"id_token_signing_alg_values_supported": [
|
||||
"RS256"
|
||||
],
|
||||
"service_documentation": "https://example.com/docs"
|
||||
}`,
|
||||
},
|
||||
}
|
||||
|
||||
for i, tt := range tests {
|
||||
got, err := json.MarshalIndent(&tt.cfg, "", "\t")
|
||||
if err != nil {
|
||||
t.Errorf("case %d: failed to marshal config: %v", i, err)
|
||||
continue
|
||||
}
|
||||
if d := diff.Diff(string(got), string(tt.want)); d != "" {
|
||||
t.Errorf("case %d: expected did not match result: %s", i, d)
|
||||
}
|
||||
|
||||
var cfg ProviderConfig
|
||||
if err := json.Unmarshal(got, &cfg); err != nil {
|
||||
t.Errorf("case %d: could not unmarshal marshal response: %v", i, err)
|
||||
continue
|
||||
}
|
||||
|
||||
if d := pretty.Compare(tt.cfg, cfg); d != "" {
|
||||
t.Errorf("case %d: config did not survive JSON marshaling round trip: %s", i, d)
|
||||
}
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
func TestProviderConfigSupports(t *testing.T) {
|
||||
tests := []struct {
|
||||
provider ProviderConfig
|
||||
client ClientMetadata
|
||||
fillRequiredProviderFields bool
|
||||
ok bool
|
||||
}{
|
||||
{
|
||||
provider: ProviderConfig{},
|
||||
client: ClientMetadata{
|
||||
RedirectURIs: []url.URL{
|
||||
{Scheme: "https", Host: "example.com", Path: "/callback"},
|
||||
},
|
||||
},
|
||||
fillRequiredProviderFields: true,
|
||||
ok: true,
|
||||
},
|
||||
{
|
||||
// invalid provider config
|
||||
provider: ProviderConfig{},
|
||||
client: ClientMetadata{
|
||||
RedirectURIs: []url.URL{
|
||||
{Scheme: "https", Host: "example.com", Path: "/callback"},
|
||||
},
|
||||
},
|
||||
fillRequiredProviderFields: false,
|
||||
ok: false,
|
||||
},
|
||||
{
|
||||
// invalid client config
|
||||
provider: ProviderConfig{},
|
||||
client: ClientMetadata{},
|
||||
fillRequiredProviderFields: true,
|
||||
ok: false,
|
||||
},
|
||||
}
|
||||
|
||||
for i, tt := range tests {
|
||||
if tt.fillRequiredProviderFields {
|
||||
tt.provider = fillRequiredProviderFields(tt.provider)
|
||||
}
|
||||
|
||||
err := tt.provider.Supports(tt.client)
|
||||
if err == nil && !tt.ok {
|
||||
t.Errorf("case %d: expected non-nil error", i)
|
||||
}
|
||||
if err != nil && tt.ok {
|
||||
t.Errorf("case %d: supports failed: %v", i, err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func newValidProviderConfig() ProviderConfig {
|
||||
var cfg ProviderConfig
|
||||
return fillRequiredProviderFields(cfg)
|
||||
}
|
||||
|
||||
// fill a provider config with enough information to be valid
|
||||
func fillRequiredProviderFields(cfg ProviderConfig) ProviderConfig {
|
||||
if cfg.Issuer == nil {
|
||||
cfg.Issuer = &url.URL{Scheme: "https", Host: "auth.example.com"}
|
||||
}
|
||||
urlPath := func(path string) *url.URL {
|
||||
var u url.URL
|
||||
u = *cfg.Issuer
|
||||
u.Path = path
|
||||
return &u
|
||||
}
|
||||
cfg.AuthEndpoint = urlPath("/auth")
|
||||
cfg.TokenEndpoint = urlPath("/token")
|
||||
cfg.UserInfoEndpoint = urlPath("/userinfo")
|
||||
cfg.KeysEndpoint = urlPath("/jwk")
|
||||
cfg.ResponseTypesSupported = []string{oauth2.ResponseTypeCode}
|
||||
cfg.SubjectTypesSupported = []string{SubjectTypePublic}
|
||||
cfg.IDTokenSigningAlgValues = []string{jose.AlgRS256}
|
||||
return cfg
|
||||
}
|
||||
|
||||
type fakeProviderConfigGetterSetter struct {
|
||||
cfg *ProviderConfig
|
||||
getCount int
|
||||
setCount int
|
||||
}
|
||||
|
||||
func (g *fakeProviderConfigGetterSetter) Get() (ProviderConfig, error) {
|
||||
g.getCount++
|
||||
return *g.cfg, nil
|
||||
}
|
||||
|
||||
func (g *fakeProviderConfigGetterSetter) Set(cfg ProviderConfig) error {
|
||||
g.cfg = &cfg
|
||||
g.setCount++
|
||||
return nil
|
||||
}
|
||||
|
||||
type fakeProviderConfigHandler struct {
|
||||
cfg ProviderConfig
|
||||
maxAge time.Duration
|
||||
}
|
||||
|
||||
func (s *fakeProviderConfigHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
|
||||
b, _ := json.Marshal(&s.cfg)
|
||||
if s.maxAge.Seconds() >= 0 {
|
||||
w.Header().Set("Cache-Control", fmt.Sprintf("public, max-age=%d", int(s.maxAge.Seconds())))
|
||||
}
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
w.Write(b)
|
||||
}
|
||||
|
||||
func TestProviderConfigRequiredFields(t *testing.T) {
|
||||
// Ensure provider metadata responses have all the required fields.
|
||||
// taken from https://openid.net/specs/openid-connect-discovery-1_0.html#ProviderMetadata
|
||||
requiredFields := []string{
|
||||
"issuer",
|
||||
"authorization_endpoint",
|
||||
"token_endpoint", // "This is REQUIRED unless only the Implicit Flow is used."
|
||||
"jwks_uri",
|
||||
"response_types_supported",
|
||||
"subject_types_supported",
|
||||
"id_token_signing_alg_values_supported",
|
||||
}
|
||||
|
||||
svr := &fakeProviderConfigHandler{
|
||||
cfg: ProviderConfig{
|
||||
Issuer: &url.URL{Scheme: "http", Host: "example.com"},
|
||||
ExpiresAt: time.Now().Add(time.Minute),
|
||||
},
|
||||
maxAge: time.Minute,
|
||||
}
|
||||
svr.cfg = fillRequiredProviderFields(svr.cfg)
|
||||
s := httptest.NewServer(svr)
|
||||
defer s.Close()
|
||||
|
||||
resp, err := http.Get(s.URL + "/")
|
||||
if err != nil {
|
||||
t.Errorf("get: %v", err)
|
||||
return
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
var data map[string]interface{}
|
||||
if err := json.NewDecoder(resp.Body).Decode(&data); err != nil {
|
||||
t.Errorf("decode: %v", err)
|
||||
return
|
||||
}
|
||||
for _, field := range requiredFields {
|
||||
if _, ok := data[field]; !ok {
|
||||
t.Errorf("provider metadata does not have required field '%s'", field)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
type handlerClient struct {
|
||||
Handler http.Handler
|
||||
}
|
||||
|
||||
func (hc *handlerClient) Do(r *http.Request) (*http.Response, error) {
|
||||
w := httptest.NewRecorder()
|
||||
hc.Handler.ServeHTTP(w, r)
|
||||
|
||||
resp := http.Response{
|
||||
StatusCode: w.Code,
|
||||
Header: w.Header(),
|
||||
Body: ioutil.NopCloser(w.Body),
|
||||
}
|
||||
|
||||
return &resp, nil
|
||||
}
|
||||
|
||||
func TestHTTPProviderConfigGetter(t *testing.T) {
|
||||
svr := &fakeProviderConfigHandler{}
|
||||
hc := &handlerClient{Handler: svr}
|
||||
fc := clockwork.NewFakeClock()
|
||||
now := fc.Now().UTC()
|
||||
|
||||
tests := []struct {
|
||||
dsc string
|
||||
age time.Duration
|
||||
cfg ProviderConfig
|
||||
ok bool
|
||||
}{
|
||||
// everything is good
|
||||
{
|
||||
dsc: "https://example.com",
|
||||
age: time.Minute,
|
||||
cfg: ProviderConfig{
|
||||
Issuer: &url.URL{Scheme: "https", Host: "example.com"},
|
||||
ExpiresAt: now.Add(time.Minute),
|
||||
},
|
||||
ok: true,
|
||||
},
|
||||
// iss and disco url differ by scheme only (how google works)
|
||||
{
|
||||
dsc: "https://example.com",
|
||||
age: time.Minute,
|
||||
cfg: ProviderConfig{
|
||||
Issuer: &url.URL{Scheme: "https", Host: "example.com"},
|
||||
ExpiresAt: now.Add(time.Minute),
|
||||
},
|
||||
ok: true,
|
||||
},
|
||||
// issuer and discovery URL mismatch
|
||||
{
|
||||
dsc: "https://foo.com",
|
||||
age: time.Minute,
|
||||
cfg: ProviderConfig{
|
||||
Issuer: &url.URL{Scheme: "https", Host: "example.com"},
|
||||
ExpiresAt: now.Add(time.Minute),
|
||||
},
|
||||
ok: false,
|
||||
},
|
||||
// missing cache header results in zero ExpiresAt
|
||||
{
|
||||
dsc: "https://example.com",
|
||||
age: -1,
|
||||
cfg: ProviderConfig{
|
||||
Issuer: &url.URL{Scheme: "https", Host: "example.com"},
|
||||
},
|
||||
ok: true,
|
||||
},
|
||||
}
|
||||
|
||||
for i, tt := range tests {
|
||||
tt.cfg = fillRequiredProviderFields(tt.cfg)
|
||||
svr.cfg = tt.cfg
|
||||
svr.maxAge = tt.age
|
||||
getter := NewHTTPProviderConfigGetter(hc, tt.dsc)
|
||||
getter.clock = fc
|
||||
|
||||
got, err := getter.Get()
|
||||
if err != nil {
|
||||
if tt.ok {
|
||||
t.Errorf("test %d: unexpected error: %v", i, err)
|
||||
}
|
||||
continue
|
||||
}
|
||||
|
||||
if !tt.ok {
|
||||
t.Errorf("test %d: expected error", i)
|
||||
continue
|
||||
}
|
||||
|
||||
if !reflect.DeepEqual(tt.cfg, got) {
|
||||
t.Errorf("test %d: want: %#v, got: %#v", i, tt.cfg, got)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestProviderConfigSyncerRun(t *testing.T) {
|
||||
c1 := &ProviderConfig{
|
||||
Issuer: &url.URL{Scheme: "https", Host: "example.com"},
|
||||
}
|
||||
c2 := &ProviderConfig{
|
||||
Issuer: &url.URL{Scheme: "https", Host: "example.com"},
|
||||
}
|
||||
|
||||
tests := []struct {
|
||||
first *ProviderConfig
|
||||
advance time.Duration
|
||||
second *ProviderConfig
|
||||
firstExp time.Duration
|
||||
secondExp time.Duration
|
||||
count int
|
||||
}{
|
||||
// exp is 10m, should have same config after 1s
|
||||
{
|
||||
first: c1,
|
||||
firstExp: time.Duration(10 * time.Minute),
|
||||
advance: time.Minute,
|
||||
second: c1,
|
||||
secondExp: time.Duration(10 * time.Minute),
|
||||
count: 1,
|
||||
},
|
||||
// exp is 10m, should have new config after 10/2 = 5m
|
||||
{
|
||||
first: c1,
|
||||
firstExp: time.Duration(10 * time.Minute),
|
||||
advance: time.Duration(5 * time.Minute),
|
||||
second: c2,
|
||||
secondExp: time.Duration(10 * time.Minute),
|
||||
count: 2,
|
||||
},
|
||||
// exp is 20m, should have new config after 20/2 = 10m
|
||||
{
|
||||
first: c1,
|
||||
firstExp: time.Duration(20 * time.Minute),
|
||||
advance: time.Duration(10 * time.Minute),
|
||||
second: c2,
|
||||
secondExp: time.Duration(30 * time.Minute),
|
||||
count: 2,
|
||||
},
|
||||
}
|
||||
|
||||
assertCfg := func(i int, to *fakeProviderConfigGetterSetter, want ProviderConfig) {
|
||||
got, err := to.Get()
|
||||
if err != nil {
|
||||
t.Fatalf("test %d: unable to get config: %v", i, err)
|
||||
}
|
||||
if !reflect.DeepEqual(want, got) {
|
||||
t.Fatalf("test %d: incorrect state:\nwant=%#v\ngot=%#v", i, want, got)
|
||||
}
|
||||
}
|
||||
|
||||
for i, tt := range tests {
|
||||
from := &fakeProviderConfigGetterSetter{}
|
||||
to := &fakeProviderConfigGetterSetter{}
|
||||
|
||||
fc := clockwork.NewFakeClock()
|
||||
now := fc.Now().UTC()
|
||||
syncer := NewProviderConfigSyncer(from, to)
|
||||
syncer.clock = fc
|
||||
|
||||
tt.first.ExpiresAt = now.Add(tt.firstExp)
|
||||
tt.second.ExpiresAt = now.Add(tt.secondExp)
|
||||
if err := from.Set(*tt.first); err != nil {
|
||||
t.Fatalf("test %d: unexpected error: %v", i, err)
|
||||
}
|
||||
|
||||
stop := syncer.Run()
|
||||
defer close(stop)
|
||||
fc.BlockUntil(1)
|
||||
|
||||
// first sync
|
||||
assertCfg(i, to, *tt.first)
|
||||
|
||||
if err := from.Set(*tt.second); err != nil {
|
||||
t.Fatalf("test %d: unexpected error: %v", i, err)
|
||||
}
|
||||
|
||||
fc.Advance(tt.advance)
|
||||
fc.BlockUntil(1)
|
||||
|
||||
// second sync
|
||||
assertCfg(i, to, *tt.second)
|
||||
|
||||
if tt.count != from.getCount {
|
||||
t.Fatalf("test %d: want: %v, got: %v", i, tt.count, from.getCount)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
type staticProviderConfigGetter struct {
|
||||
cfg ProviderConfig
|
||||
err error
|
||||
}
|
||||
|
||||
func (g *staticProviderConfigGetter) Get() (ProviderConfig, error) {
|
||||
return g.cfg, g.err
|
||||
}
|
||||
|
||||
type staticProviderConfigSetter struct {
|
||||
cfg *ProviderConfig
|
||||
err error
|
||||
}
|
||||
|
||||
func (s *staticProviderConfigSetter) Set(cfg ProviderConfig) error {
|
||||
s.cfg = &cfg
|
||||
return s.err
|
||||
}
|
||||
|
||||
func TestProviderConfigSyncerSyncFailure(t *testing.T) {
|
||||
fc := clockwork.NewFakeClock()
|
||||
|
||||
tests := []struct {
|
||||
from *staticProviderConfigGetter
|
||||
to *staticProviderConfigSetter
|
||||
|
||||
// want indicates what ProviderConfig should be passed to Set.
|
||||
// If nil, the Set should not be called.
|
||||
want *ProviderConfig
|
||||
}{
|
||||
// generic Get failure
|
||||
{
|
||||
from: &staticProviderConfigGetter{err: errors.New("fail")},
|
||||
to: &staticProviderConfigSetter{},
|
||||
want: nil,
|
||||
},
|
||||
// generic Set failure
|
||||
{
|
||||
from: &staticProviderConfigGetter{cfg: ProviderConfig{ExpiresAt: fc.Now().Add(time.Minute)}},
|
||||
to: &staticProviderConfigSetter{err: errors.New("fail")},
|
||||
want: &ProviderConfig{ExpiresAt: fc.Now().Add(time.Minute)},
|
||||
},
|
||||
}
|
||||
|
||||
for i, tt := range tests {
|
||||
pcs := &ProviderConfigSyncer{
|
||||
from: tt.from,
|
||||
to: tt.to,
|
||||
clock: fc,
|
||||
}
|
||||
_, err := pcs.sync()
|
||||
if err == nil {
|
||||
t.Errorf("case %d: expected non-nil error", i)
|
||||
}
|
||||
if !reflect.DeepEqual(tt.want, tt.to.cfg) {
|
||||
t.Errorf("case %d: Set mismatch: want=%#v got=%#v", i, tt.want, tt.to.cfg)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestNextSyncAfter(t *testing.T) {
|
||||
fc := clockwork.NewFakeClock()
|
||||
|
||||
tests := []struct {
|
||||
exp time.Time
|
||||
want time.Duration
|
||||
}{
|
||||
{
|
||||
exp: fc.Now().Add(time.Hour),
|
||||
want: 30 * time.Minute,
|
||||
},
|
||||
// override large values with the maximum
|
||||
{
|
||||
exp: fc.Now().Add(168 * time.Hour), // one week
|
||||
want: 24 * time.Hour,
|
||||
},
|
||||
// override "now" values with the minimum
|
||||
{
|
||||
exp: fc.Now(),
|
||||
want: time.Minute,
|
||||
},
|
||||
// override negative values with the minimum
|
||||
{
|
||||
exp: fc.Now().Add(-1 * time.Minute),
|
||||
want: time.Minute,
|
||||
},
|
||||
// zero-value Time results in maximum sync interval
|
||||
{
|
||||
exp: time.Time{},
|
||||
want: 24 * time.Hour,
|
||||
},
|
||||
}
|
||||
|
||||
for i, tt := range tests {
|
||||
got := nextSyncAfter(tt.exp, fc)
|
||||
if tt.want != got {
|
||||
t.Errorf("case %d: want=%v got=%v", i, tt.want, got)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestProviderConfigEmpty(t *testing.T) {
|
||||
cfg := ProviderConfig{}
|
||||
if !cfg.Empty() {
|
||||
t.Fatalf("Empty provider config reports non-empty")
|
||||
}
|
||||
cfg = ProviderConfig{
|
||||
Issuer: &url.URL{Scheme: "https", Host: "example.com"},
|
||||
}
|
||||
if cfg.Empty() {
|
||||
t.Fatalf("Non-empty provider config reports empty")
|
||||
}
|
||||
}
|
||||
|
||||
func TestPCSStepAfter(t *testing.T) {
|
||||
pass := func() (time.Duration, error) { return 7 * time.Second, nil }
|
||||
fail := func() (time.Duration, error) { return 0, errors.New("fail") }
|
||||
|
||||
tests := []struct {
|
||||
stepper pcsStepper
|
||||
stepFunc pcsStepFunc
|
||||
want pcsStepper
|
||||
}{
|
||||
// good step results in retry at TTL
|
||||
{
|
||||
stepper: &pcsStepNext{},
|
||||
stepFunc: pass,
|
||||
want: &pcsStepNext{aft: 7 * time.Second},
|
||||
},
|
||||
|
||||
// good step after failed step results results in retry at TTL
|
||||
{
|
||||
stepper: &pcsStepRetry{aft: 2 * time.Second},
|
||||
stepFunc: pass,
|
||||
want: &pcsStepNext{aft: 7 * time.Second},
|
||||
},
|
||||
|
||||
// failed step results in a retry in 1s
|
||||
{
|
||||
stepper: &pcsStepNext{},
|
||||
stepFunc: fail,
|
||||
want: &pcsStepRetry{aft: time.Second},
|
||||
},
|
||||
|
||||
// failed retry backs off by a factor of 2
|
||||
{
|
||||
stepper: &pcsStepRetry{aft: time.Second},
|
||||
stepFunc: fail,
|
||||
want: &pcsStepRetry{aft: 2 * time.Second},
|
||||
},
|
||||
|
||||
// failed retry backs off by a factor of 2, up to 1m
|
||||
{
|
||||
stepper: &pcsStepRetry{aft: 32 * time.Second},
|
||||
stepFunc: fail,
|
||||
want: &pcsStepRetry{aft: 60 * time.Second},
|
||||
},
|
||||
}
|
||||
|
||||
for i, tt := range tests {
|
||||
got := tt.stepper.step(tt.stepFunc)
|
||||
if !reflect.DeepEqual(tt.want, got) {
|
||||
t.Errorf("case %d: want=%#v got=%#v", i, tt.want, got)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestProviderConfigSupportsGrantType(t *testing.T) {
|
||||
tests := []struct {
|
||||
types []string
|
||||
typ string
|
||||
want bool
|
||||
}{
|
||||
// explicitly supported
|
||||
{
|
||||
types: []string{"foo_type"},
|
||||
typ: "foo_type",
|
||||
want: true,
|
||||
},
|
||||
|
||||
// explicitly unsupported
|
||||
{
|
||||
types: []string{"bar_type"},
|
||||
typ: "foo_type",
|
||||
want: false,
|
||||
},
|
||||
|
||||
// default type explicitly unsupported
|
||||
{
|
||||
types: []string{oauth2.GrantTypeImplicit},
|
||||
typ: oauth2.GrantTypeAuthCode,
|
||||
want: false,
|
||||
},
|
||||
|
||||
// type not found in default set
|
||||
{
|
||||
types: []string{},
|
||||
typ: "foo_type",
|
||||
want: false,
|
||||
},
|
||||
|
||||
// type found in default set
|
||||
{
|
||||
types: []string{},
|
||||
typ: oauth2.GrantTypeAuthCode,
|
||||
want: true,
|
||||
},
|
||||
}
|
||||
|
||||
for i, tt := range tests {
|
||||
cfg := ProviderConfig{
|
||||
GrantTypesSupported: tt.types,
|
||||
}
|
||||
got := cfg.SupportsGrantType(tt.typ)
|
||||
if tt.want != got {
|
||||
t.Errorf("case %d: assert %v supports %v: want=%t got=%t", i, tt.types, tt.typ, tt.want, got)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
type fakeClient struct {
|
||||
resp *http.Response
|
||||
}
|
||||
|
||||
func (f *fakeClient) Do(req *http.Request) (*http.Response, error) {
|
||||
return f.resp, nil
|
||||
}
|
||||
|
||||
func TestWaitForProviderConfigImmediateSuccess(t *testing.T) {
|
||||
cfg := newValidProviderConfig()
|
||||
b, err := json.Marshal(&cfg)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed marshaling provider config")
|
||||
}
|
||||
|
||||
resp := http.Response{Body: ioutil.NopCloser(bytes.NewBuffer(b))}
|
||||
hc := &fakeClient{&resp}
|
||||
fc := clockwork.NewFakeClock()
|
||||
|
||||
reschan := make(chan ProviderConfig)
|
||||
go func() {
|
||||
reschan <- waitForProviderConfig(hc, cfg.Issuer.String(), fc)
|
||||
}()
|
||||
|
||||
var got ProviderConfig
|
||||
select {
|
||||
case got = <-reschan:
|
||||
case <-time.After(time.Second):
|
||||
t.Fatalf("Did not receive result within 1s")
|
||||
}
|
||||
|
||||
if !reflect.DeepEqual(cfg, got) {
|
||||
t.Fatalf("Received incorrect provider config: want=%#v got=%#v", cfg, got)
|
||||
}
|
||||
}
|
||||
88
vendor/github.com/coreos/go-oidc/oidc/transport.go
generated
vendored
88
vendor/github.com/coreos/go-oidc/oidc/transport.go
generated
vendored
@@ -1,88 +0,0 @@
|
||||
package oidc
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"net/http"
|
||||
"sync"
|
||||
|
||||
phttp "github.com/coreos/go-oidc/http"
|
||||
"github.com/coreos/go-oidc/jose"
|
||||
)
|
||||
|
||||
type TokenRefresher interface {
|
||||
// Verify checks if the provided token is currently valid or not.
|
||||
Verify(jose.JWT) error
|
||||
|
||||
// Refresh attempts to authenticate and retrieve a new token.
|
||||
Refresh() (jose.JWT, error)
|
||||
}
|
||||
|
||||
type ClientCredsTokenRefresher struct {
|
||||
Issuer string
|
||||
OIDCClient *Client
|
||||
}
|
||||
|
||||
func (c *ClientCredsTokenRefresher) Verify(jwt jose.JWT) (err error) {
|
||||
_, err = VerifyClientClaims(jwt, c.Issuer)
|
||||
return
|
||||
}
|
||||
|
||||
func (c *ClientCredsTokenRefresher) Refresh() (jwt jose.JWT, err error) {
|
||||
if err = c.OIDCClient.Healthy(); err != nil {
|
||||
err = fmt.Errorf("unable to authenticate, unhealthy OIDC client: %v", err)
|
||||
return
|
||||
}
|
||||
|
||||
jwt, err = c.OIDCClient.ClientCredsToken([]string{"openid"})
|
||||
if err != nil {
|
||||
err = fmt.Errorf("unable to verify auth code with issuer: %v", err)
|
||||
return
|
||||
}
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
type AuthenticatedTransport struct {
|
||||
TokenRefresher
|
||||
http.RoundTripper
|
||||
|
||||
mu sync.Mutex
|
||||
jwt jose.JWT
|
||||
}
|
||||
|
||||
func (t *AuthenticatedTransport) verifiedJWT() (jose.JWT, error) {
|
||||
t.mu.Lock()
|
||||
defer t.mu.Unlock()
|
||||
|
||||
if t.TokenRefresher.Verify(t.jwt) == nil {
|
||||
return t.jwt, nil
|
||||
}
|
||||
|
||||
jwt, err := t.TokenRefresher.Refresh()
|
||||
if err != nil {
|
||||
return jose.JWT{}, fmt.Errorf("unable to acquire valid JWT: %v", err)
|
||||
}
|
||||
|
||||
t.jwt = jwt
|
||||
return t.jwt, nil
|
||||
}
|
||||
|
||||
// SetJWT sets the JWT held by the Transport.
|
||||
// This is useful for cases in which you want to set an initial JWT.
|
||||
func (t *AuthenticatedTransport) SetJWT(jwt jose.JWT) {
|
||||
t.mu.Lock()
|
||||
defer t.mu.Unlock()
|
||||
|
||||
t.jwt = jwt
|
||||
}
|
||||
|
||||
func (t *AuthenticatedTransport) RoundTrip(r *http.Request) (*http.Response, error) {
|
||||
jwt, err := t.verifiedJWT()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
req := phttp.CopyRequest(r)
|
||||
req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", jwt.Encode()))
|
||||
return t.RoundTripper.RoundTrip(req)
|
||||
}
|
||||
176
vendor/github.com/coreos/go-oidc/oidc/transport_test.go
generated
vendored
176
vendor/github.com/coreos/go-oidc/oidc/transport_test.go
generated
vendored
@@ -1,176 +0,0 @@
|
||||
package oidc
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"net/http"
|
||||
"reflect"
|
||||
"testing"
|
||||
|
||||
"github.com/coreos/go-oidc/jose"
|
||||
)
|
||||
|
||||
type staticTokenRefresher struct {
|
||||
verify func(jose.JWT) error
|
||||
refresh func() (jose.JWT, error)
|
||||
}
|
||||
|
||||
func (s *staticTokenRefresher) Verify(jwt jose.JWT) error {
|
||||
return s.verify(jwt)
|
||||
}
|
||||
|
||||
func (s *staticTokenRefresher) Refresh() (jose.JWT, error) {
|
||||
return s.refresh()
|
||||
}
|
||||
|
||||
func TestAuthenticatedTransportVerifiedJWT(t *testing.T) {
|
||||
tests := []struct {
|
||||
refresher TokenRefresher
|
||||
startJWT jose.JWT
|
||||
wantJWT jose.JWT
|
||||
wantError error
|
||||
}{
|
||||
// verification succeeds, so refresh is not called
|
||||
{
|
||||
refresher: &staticTokenRefresher{
|
||||
verify: func(jose.JWT) error { return nil },
|
||||
refresh: func() (jose.JWT, error) { return jose.JWT{RawPayload: "2"}, nil },
|
||||
},
|
||||
startJWT: jose.JWT{RawPayload: "1"},
|
||||
wantJWT: jose.JWT{RawPayload: "1"},
|
||||
},
|
||||
|
||||
// verification fails, refresh succeeds so cached JWT changes
|
||||
{
|
||||
refresher: &staticTokenRefresher{
|
||||
verify: func(jose.JWT) error { return errors.New("fail!") },
|
||||
refresh: func() (jose.JWT, error) { return jose.JWT{RawPayload: "2"}, nil },
|
||||
},
|
||||
startJWT: jose.JWT{RawPayload: "1"},
|
||||
wantJWT: jose.JWT{RawPayload: "2"},
|
||||
},
|
||||
|
||||
// verification succeeds, so failing refresh isn't attempted
|
||||
{
|
||||
refresher: &staticTokenRefresher{
|
||||
verify: func(jose.JWT) error { return nil },
|
||||
refresh: func() (jose.JWT, error) { return jose.JWT{}, errors.New("fail!") },
|
||||
},
|
||||
startJWT: jose.JWT{RawPayload: "1"},
|
||||
wantJWT: jose.JWT{RawPayload: "1"},
|
||||
},
|
||||
|
||||
// verification fails, but refresh fails, too
|
||||
{
|
||||
refresher: &staticTokenRefresher{
|
||||
verify: func(jose.JWT) error { return errors.New("fail!") },
|
||||
refresh: func() (jose.JWT, error) { return jose.JWT{}, errors.New("fail!") },
|
||||
},
|
||||
startJWT: jose.JWT{RawPayload: "1"},
|
||||
wantJWT: jose.JWT{},
|
||||
wantError: errors.New("unable to acquire valid JWT: fail!"),
|
||||
},
|
||||
}
|
||||
|
||||
for i, tt := range tests {
|
||||
at := &AuthenticatedTransport{
|
||||
TokenRefresher: tt.refresher,
|
||||
}
|
||||
at.SetJWT(tt.startJWT)
|
||||
|
||||
gotJWT, err := at.verifiedJWT()
|
||||
if !reflect.DeepEqual(tt.wantError, err) {
|
||||
t.Errorf("#%d: unexpected error: want=%#v got=%#v", i, tt.wantError, err)
|
||||
}
|
||||
if !reflect.DeepEqual(tt.wantJWT, gotJWT) {
|
||||
t.Errorf("#%d: incorrect JWT returned from verifiedJWT: want=%#v got=%#v", i, tt.wantJWT, gotJWT)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestAuthenticatedTransportJWTCaching(t *testing.T) {
|
||||
at := &AuthenticatedTransport{
|
||||
TokenRefresher: &staticTokenRefresher{
|
||||
verify: func(jose.JWT) error { return errors.New("fail!") },
|
||||
refresh: func() (jose.JWT, error) { return jose.JWT{RawPayload: "2"}, nil },
|
||||
},
|
||||
jwt: jose.JWT{RawPayload: "1"},
|
||||
}
|
||||
|
||||
wantJWT := jose.JWT{RawPayload: "2"}
|
||||
gotJWT, err := at.verifiedJWT()
|
||||
if err != nil {
|
||||
t.Fatalf("got non-nil error: %#v", err)
|
||||
}
|
||||
if !reflect.DeepEqual(wantJWT, gotJWT) {
|
||||
t.Fatalf("incorrect JWT returned from verifiedJWT: want=%#v got=%#v", wantJWT, gotJWT)
|
||||
}
|
||||
|
||||
at.TokenRefresher = &staticTokenRefresher{
|
||||
verify: func(jose.JWT) error { return nil },
|
||||
refresh: func() (jose.JWT, error) { return jose.JWT{RawPayload: "3"}, nil },
|
||||
}
|
||||
|
||||
// the previous JWT should still be cached on the AuthenticatedTransport since
|
||||
// it is still valid, even though there's a new token ready to refresh
|
||||
gotJWT, err = at.verifiedJWT()
|
||||
if err != nil {
|
||||
t.Fatalf("got non-nil error: %#v", err)
|
||||
}
|
||||
if !reflect.DeepEqual(wantJWT, gotJWT) {
|
||||
t.Fatalf("incorrect JWT returned from verifiedJWT: want=%#v got=%#v", wantJWT, gotJWT)
|
||||
}
|
||||
}
|
||||
|
||||
type fakeRoundTripper struct {
|
||||
Request *http.Request
|
||||
resp *http.Response
|
||||
}
|
||||
|
||||
func (r *fakeRoundTripper) RoundTrip(req *http.Request) (*http.Response, error) {
|
||||
r.Request = req
|
||||
return r.resp, nil
|
||||
}
|
||||
|
||||
func TestAuthenticatedTransportRoundTrip(t *testing.T) {
|
||||
rr := &fakeRoundTripper{nil, &http.Response{StatusCode: http.StatusOK}}
|
||||
at := &AuthenticatedTransport{
|
||||
TokenRefresher: &staticTokenRefresher{
|
||||
verify: func(jose.JWT) error { return nil },
|
||||
},
|
||||
RoundTripper: rr,
|
||||
jwt: jose.JWT{RawPayload: "1"},
|
||||
}
|
||||
|
||||
req := http.Request{}
|
||||
_, err := at.RoundTrip(&req)
|
||||
if err != nil {
|
||||
t.Errorf("unexpected error: %v", err)
|
||||
}
|
||||
|
||||
if !reflect.DeepEqual(req, http.Request{}) {
|
||||
t.Errorf("http.Request object was modified")
|
||||
}
|
||||
|
||||
want := []string{"Bearer .1."}
|
||||
got := rr.Request.Header["Authorization"]
|
||||
if !reflect.DeepEqual(want, got) {
|
||||
t.Errorf("incorrect Authorization header: want=%#v got=%#v", want, got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestAuthenticatedTransportRoundTripRefreshFail(t *testing.T) {
|
||||
rr := &fakeRoundTripper{nil, &http.Response{StatusCode: http.StatusOK}}
|
||||
at := &AuthenticatedTransport{
|
||||
TokenRefresher: &staticTokenRefresher{
|
||||
verify: func(jose.JWT) error { return errors.New("fail!") },
|
||||
refresh: func() (jose.JWT, error) { return jose.JWT{}, errors.New("fail!") },
|
||||
},
|
||||
RoundTripper: rr,
|
||||
jwt: jose.JWT{RawPayload: "1"},
|
||||
}
|
||||
|
||||
_, err := at.RoundTrip(&http.Request{})
|
||||
if err == nil {
|
||||
t.Errorf("expected non-nil error")
|
||||
}
|
||||
}
|
||||
109
vendor/github.com/coreos/go-oidc/oidc/util.go
generated
vendored
109
vendor/github.com/coreos/go-oidc/oidc/util.go
generated
vendored
@@ -1,109 +0,0 @@
|
||||
package oidc
|
||||
|
||||
import (
|
||||
"crypto/rand"
|
||||
"encoding/base64"
|
||||
"errors"
|
||||
"fmt"
|
||||
"net"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/coreos/go-oidc/jose"
|
||||
)
|
||||
|
||||
// RequestTokenExtractor funcs extract a raw encoded token from a request.
|
||||
type RequestTokenExtractor func(r *http.Request) (string, error)
|
||||
|
||||
// ExtractBearerToken is a RequestTokenExtractor which extracts a bearer token from a request's
|
||||
// Authorization header.
|
||||
func ExtractBearerToken(r *http.Request) (string, error) {
|
||||
ah := r.Header.Get("Authorization")
|
||||
if ah == "" {
|
||||
return "", errors.New("missing Authorization header")
|
||||
}
|
||||
|
||||
if len(ah) <= 6 || strings.ToUpper(ah[0:6]) != "BEARER" {
|
||||
return "", errors.New("should be a bearer token")
|
||||
}
|
||||
|
||||
val := ah[7:]
|
||||
if len(val) == 0 {
|
||||
return "", errors.New("bearer token is empty")
|
||||
}
|
||||
|
||||
return val, nil
|
||||
}
|
||||
|
||||
// CookieTokenExtractor returns a RequestTokenExtractor which extracts a token from the named cookie in a request.
|
||||
func CookieTokenExtractor(cookieName string) RequestTokenExtractor {
|
||||
return func(r *http.Request) (string, error) {
|
||||
ck, err := r.Cookie(cookieName)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("token cookie not found in request: %v", err)
|
||||
}
|
||||
|
||||
if ck.Value == "" {
|
||||
return "", errors.New("token cookie found but is empty")
|
||||
}
|
||||
|
||||
return ck.Value, nil
|
||||
}
|
||||
}
|
||||
|
||||
func NewClaims(iss, sub string, aud interface{}, iat, exp time.Time) jose.Claims {
|
||||
return jose.Claims{
|
||||
// required
|
||||
"iss": iss,
|
||||
"sub": sub,
|
||||
"aud": aud,
|
||||
"iat": iat.Unix(),
|
||||
"exp": exp.Unix(),
|
||||
}
|
||||
}
|
||||
|
||||
func GenClientID(hostport string) (string, error) {
|
||||
b, err := randBytes(32)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
var host string
|
||||
if strings.Contains(hostport, ":") {
|
||||
host, _, err = net.SplitHostPort(hostport)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
} else {
|
||||
host = hostport
|
||||
}
|
||||
|
||||
return fmt.Sprintf("%s@%s", base64.URLEncoding.EncodeToString(b), host), nil
|
||||
}
|
||||
|
||||
func randBytes(n int) ([]byte, error) {
|
||||
b := make([]byte, n)
|
||||
got, err := rand.Read(b)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
} else if n != got {
|
||||
return nil, errors.New("unable to generate enough random data")
|
||||
}
|
||||
return b, nil
|
||||
}
|
||||
|
||||
// urlEqual checks two urls for equality using only the host and path portions.
|
||||
func urlEqual(url1, url2 string) bool {
|
||||
u1, err := url.Parse(url1)
|
||||
if err != nil {
|
||||
return false
|
||||
}
|
||||
u2, err := url.Parse(url2)
|
||||
if err != nil {
|
||||
return false
|
||||
}
|
||||
|
||||
return strings.ToLower(u1.Host+u1.Path) == strings.ToLower(u2.Host+u2.Path)
|
||||
}
|
||||
110
vendor/github.com/coreos/go-oidc/oidc/util_test.go
generated
vendored
110
vendor/github.com/coreos/go-oidc/oidc/util_test.go
generated
vendored
@@ -1,110 +0,0 @@
|
||||
package oidc
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"net/http"
|
||||
"reflect"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/coreos/go-oidc/jose"
|
||||
)
|
||||
|
||||
func TestCookieTokenExtractorInvalid(t *testing.T) {
|
||||
ckName := "tokenCookie"
|
||||
tests := []*http.Cookie{
|
||||
&http.Cookie{},
|
||||
&http.Cookie{Name: ckName},
|
||||
&http.Cookie{Name: ckName, Value: ""},
|
||||
}
|
||||
|
||||
for i, tt := range tests {
|
||||
r, _ := http.NewRequest("", "", nil)
|
||||
r.AddCookie(tt)
|
||||
_, err := CookieTokenExtractor(ckName)(r)
|
||||
if err == nil {
|
||||
t.Errorf("case %d: want: error for invalid cookie token, got: no error.", i)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestCookieTokenExtractorValid(t *testing.T) {
|
||||
validToken := "eyJ0eXAiOiJKV1QiLA0KICJhbGciOiJIUzI1NiJ9.eyJpc3MiOiJqb2UiLA0KICJleHAiOjEzMDA4MTkzODAsDQogImh0dHA6Ly9leGFtcGxlLmNvbS9pc19yb290Ijp0cnVlfQ.dBjftJeZ4CVP-mB92K27uhbUJU1p1r_wW1gFWFOEjXk"
|
||||
ckName := "tokenCookie"
|
||||
tests := []*http.Cookie{
|
||||
&http.Cookie{Name: ckName, Value: "some non-empty value"},
|
||||
&http.Cookie{Name: ckName, Value: validToken},
|
||||
}
|
||||
|
||||
for i, tt := range tests {
|
||||
r, _ := http.NewRequest("", "", nil)
|
||||
r.AddCookie(tt)
|
||||
_, err := CookieTokenExtractor(ckName)(r)
|
||||
if err != nil {
|
||||
t.Errorf("case %d: want: valid cookie with no error, got: %v", i, err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestExtractBearerTokenInvalid(t *testing.T) {
|
||||
tests := []string{"", "x", "Bearer", "xxxxxxx", "Bearer "}
|
||||
|
||||
for i, tt := range tests {
|
||||
r, _ := http.NewRequest("", "", nil)
|
||||
r.Header.Add("Authorization", tt)
|
||||
_, err := ExtractBearerToken(r)
|
||||
if err == nil {
|
||||
t.Errorf("case %d: want: invalid Authorization header, got: valid Authorization header.", i)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestExtractBearerTokenValid(t *testing.T) {
|
||||
validToken := "eyJ0eXAiOiJKV1QiLA0KICJhbGciOiJIUzI1NiJ9.eyJpc3MiOiJqb2UiLA0KICJleHAiOjEzMDA4MTkzODAsDQogImh0dHA6Ly9leGFtcGxlLmNvbS9pc19yb290Ijp0cnVlfQ.dBjftJeZ4CVP-mB92K27uhbUJU1p1r_wW1gFWFOEjXk"
|
||||
tests := []string{
|
||||
fmt.Sprintf("Bearer %s", validToken),
|
||||
}
|
||||
|
||||
for i, tt := range tests {
|
||||
r, _ := http.NewRequest("", "", nil)
|
||||
r.Header.Add("Authorization", tt)
|
||||
_, err := ExtractBearerToken(r)
|
||||
if err != nil {
|
||||
t.Errorf("case %d: want: valid Authorization header, got: invalid Authorization header: %v.", i, err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestNewClaims(t *testing.T) {
|
||||
issAt := time.Date(2, time.January, 1, 0, 0, 0, 0, time.UTC)
|
||||
expAt := time.Date(2, time.January, 1, 1, 0, 0, 0, time.UTC)
|
||||
|
||||
want := jose.Claims{
|
||||
"iss": "https://example.com",
|
||||
"sub": "user-123",
|
||||
"aud": "client-abc",
|
||||
"iat": issAt.Unix(),
|
||||
"exp": expAt.Unix(),
|
||||
}
|
||||
|
||||
got := NewClaims("https://example.com", "user-123", "client-abc", issAt, expAt)
|
||||
|
||||
if !reflect.DeepEqual(want, got) {
|
||||
t.Fatalf("want=%#v got=%#v", want, got)
|
||||
}
|
||||
|
||||
want2 := jose.Claims{
|
||||
"iss": "https://example.com",
|
||||
"sub": "user-123",
|
||||
"aud": []string{"client-abc", "client-def"},
|
||||
"iat": issAt.Unix(),
|
||||
"exp": expAt.Unix(),
|
||||
}
|
||||
|
||||
got2 := NewClaims("https://example.com", "user-123", []string{"client-abc", "client-def"}, issAt, expAt)
|
||||
|
||||
if !reflect.DeepEqual(want2, got2) {
|
||||
t.Fatalf("want=%#v got=%#v", want2, got2)
|
||||
}
|
||||
|
||||
}
|
||||
190
vendor/github.com/coreos/go-oidc/oidc/verification.go
generated
vendored
190
vendor/github.com/coreos/go-oidc/oidc/verification.go
generated
vendored
@@ -1,190 +0,0 @@
|
||||
package oidc
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"time"
|
||||
|
||||
"github.com/jonboulle/clockwork"
|
||||
|
||||
"github.com/coreos/go-oidc/jose"
|
||||
"github.com/coreos/go-oidc/key"
|
||||
)
|
||||
|
||||
func VerifySignature(jwt jose.JWT, keys []key.PublicKey) (bool, error) {
|
||||
jwtBytes := []byte(jwt.Data())
|
||||
for _, k := range keys {
|
||||
v, err := k.Verifier()
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
if v.Verify(jwt.Signature, jwtBytes) == nil {
|
||||
return true, nil
|
||||
}
|
||||
}
|
||||
return false, nil
|
||||
}
|
||||
|
||||
// containsString returns true if the given string(needle) is found
|
||||
// in the string array(haystack).
|
||||
func containsString(needle string, haystack []string) bool {
|
||||
for _, v := range haystack {
|
||||
if v == needle {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// Verify claims in accordance with OIDC spec
|
||||
// http://openid.net/specs/openid-connect-basic-1_0.html#IDTokenValidation
|
||||
func VerifyClaims(jwt jose.JWT, issuer, clientID string) error {
|
||||
now := time.Now().UTC()
|
||||
|
||||
claims, err := jwt.Claims()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
ident, err := IdentityFromClaims(claims)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if ident.ExpiresAt.Before(now) {
|
||||
return errors.New("token is expired")
|
||||
}
|
||||
|
||||
// iss REQUIRED. Issuer Identifier for the Issuer of the response.
|
||||
// The iss value is a case sensitive URL using the https scheme that contains scheme,
|
||||
// host, and optionally, port number and path components and no query or fragment components.
|
||||
if iss, exists := claims["iss"].(string); exists {
|
||||
if !urlEqual(iss, issuer) {
|
||||
return fmt.Errorf("invalid claim value: 'iss'. expected=%s, found=%s.", issuer, iss)
|
||||
}
|
||||
} else {
|
||||
return errors.New("missing claim: 'iss'")
|
||||
}
|
||||
|
||||
// iat REQUIRED. Time at which the JWT was issued.
|
||||
// Its value is a JSON number representing the number of seconds from 1970-01-01T0:0:0Z
|
||||
// as measured in UTC until the date/time.
|
||||
if _, exists := claims["iat"].(float64); !exists {
|
||||
return errors.New("missing claim: 'iat'")
|
||||
}
|
||||
|
||||
// aud REQUIRED. Audience(s) that this ID Token is intended for.
|
||||
// It MUST contain the OAuth 2.0 client_id of the Relying Party as an audience value.
|
||||
// It MAY also contain identifiers for other audiences. In the general case, the aud
|
||||
// value is an array of case sensitive strings. In the common special case when there
|
||||
// is one audience, the aud value MAY be a single case sensitive string.
|
||||
if aud, ok, err := claims.StringClaim("aud"); err == nil && ok {
|
||||
if aud != clientID {
|
||||
return fmt.Errorf("invalid claims, 'aud' claim and 'client_id' do not match, aud=%s, client_id=%s", aud, clientID)
|
||||
}
|
||||
} else if aud, ok, err := claims.StringsClaim("aud"); err == nil && ok {
|
||||
if !containsString(clientID, aud) {
|
||||
return fmt.Errorf("invalid claims, cannot find 'client_id' in 'aud' claim, aud=%v, client_id=%s", aud, clientID)
|
||||
}
|
||||
} else {
|
||||
return errors.New("invalid claim value: 'aud' is required, and should be either string or string array")
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// VerifyClientClaims verifies all the required claims are valid for a "client credentials" JWT.
|
||||
// Returns the client ID if valid, or an error if invalid.
|
||||
func VerifyClientClaims(jwt jose.JWT, issuer string) (string, error) {
|
||||
claims, err := jwt.Claims()
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("failed to parse JWT claims: %v", err)
|
||||
}
|
||||
|
||||
iss, ok, err := claims.StringClaim("iss")
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("failed to parse 'iss' claim: %v", err)
|
||||
} else if !ok {
|
||||
return "", errors.New("missing required 'iss' claim")
|
||||
} else if !urlEqual(iss, issuer) {
|
||||
return "", fmt.Errorf("'iss' claim does not match expected issuer, iss=%s", iss)
|
||||
}
|
||||
|
||||
sub, ok, err := claims.StringClaim("sub")
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("failed to parse 'sub' claim: %v", err)
|
||||
} else if !ok {
|
||||
return "", errors.New("missing required 'sub' claim")
|
||||
}
|
||||
|
||||
if aud, ok, err := claims.StringClaim("aud"); err == nil && ok {
|
||||
if aud != sub {
|
||||
return "", fmt.Errorf("invalid claims, 'aud' claim and 'sub' claim do not match, aud=%s, sub=%s", aud, sub)
|
||||
}
|
||||
} else if aud, ok, err := claims.StringsClaim("aud"); err == nil && ok {
|
||||
if !containsString(sub, aud) {
|
||||
return "", fmt.Errorf("invalid claims, cannot find 'sud' in 'aud' claim, aud=%v, sub=%s", aud, sub)
|
||||
}
|
||||
} else {
|
||||
return "", errors.New("invalid claim value: 'aud' is required, and should be either string or string array")
|
||||
}
|
||||
|
||||
now := time.Now().UTC()
|
||||
exp, ok, err := claims.TimeClaim("exp")
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("failed to parse 'exp' claim: %v", err)
|
||||
} else if !ok {
|
||||
return "", errors.New("missing required 'exp' claim")
|
||||
} else if exp.Before(now) {
|
||||
return "", fmt.Errorf("token already expired at: %v", exp)
|
||||
}
|
||||
|
||||
return sub, nil
|
||||
}
|
||||
|
||||
type JWTVerifier struct {
|
||||
issuer string
|
||||
clientID string
|
||||
syncFunc func() error
|
||||
keysFunc func() []key.PublicKey
|
||||
clock clockwork.Clock
|
||||
}
|
||||
|
||||
func NewJWTVerifier(issuer, clientID string, syncFunc func() error, keysFunc func() []key.PublicKey) JWTVerifier {
|
||||
return JWTVerifier{
|
||||
issuer: issuer,
|
||||
clientID: clientID,
|
||||
syncFunc: syncFunc,
|
||||
keysFunc: keysFunc,
|
||||
clock: clockwork.NewRealClock(),
|
||||
}
|
||||
}
|
||||
|
||||
func (v *JWTVerifier) Verify(jwt jose.JWT) error {
|
||||
// Verify claims before verifying the signature. This is an optimization to throw out
|
||||
// tokens we know are invalid without undergoing an expensive signature check and
|
||||
// possibly a re-sync event.
|
||||
if err := VerifyClaims(jwt, v.issuer, v.clientID); err != nil {
|
||||
return fmt.Errorf("oidc: JWT claims invalid: %v", err)
|
||||
}
|
||||
|
||||
ok, err := VerifySignature(jwt, v.keysFunc())
|
||||
if err != nil {
|
||||
return fmt.Errorf("oidc: JWT signature verification failed: %v", err)
|
||||
} else if ok {
|
||||
return nil
|
||||
}
|
||||
|
||||
if err = v.syncFunc(); err != nil {
|
||||
return fmt.Errorf("oidc: failed syncing KeySet: %v", err)
|
||||
}
|
||||
|
||||
ok, err = VerifySignature(jwt, v.keysFunc())
|
||||
if err != nil {
|
||||
return fmt.Errorf("oidc: JWT signature verification failed: %v", err)
|
||||
} else if !ok {
|
||||
return errors.New("oidc: unable to verify JWT signature: no matching keys")
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
380
vendor/github.com/coreos/go-oidc/oidc/verification_test.go
generated
vendored
380
vendor/github.com/coreos/go-oidc/oidc/verification_test.go
generated
vendored
@@ -1,380 +0,0 @@
|
||||
package oidc
|
||||
|
||||
import (
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/coreos/go-oidc/jose"
|
||||
"github.com/coreos/go-oidc/key"
|
||||
)
|
||||
|
||||
func TestVerifyClientClaims(t *testing.T) {
|
||||
validIss := "https://example.com"
|
||||
validClientID := "valid-client"
|
||||
now := time.Now()
|
||||
tomorrow := now.Add(24 * time.Hour)
|
||||
header := jose.JOSEHeader{
|
||||
jose.HeaderKeyAlgorithm: "test-alg",
|
||||
jose.HeaderKeyID: "1",
|
||||
}
|
||||
|
||||
tests := []struct {
|
||||
claims jose.Claims
|
||||
ok bool
|
||||
}{
|
||||
// valid token
|
||||
{
|
||||
claims: jose.Claims{
|
||||
"iss": validIss,
|
||||
"sub": validClientID,
|
||||
"aud": validClientID,
|
||||
"iat": float64(now.Unix()),
|
||||
"exp": float64(tomorrow.Unix()),
|
||||
},
|
||||
ok: true,
|
||||
},
|
||||
// valid token, ('aud' claim is []string)
|
||||
{
|
||||
claims: jose.Claims{
|
||||
"iss": validIss,
|
||||
"sub": validClientID,
|
||||
"aud": []string{"foo", validClientID},
|
||||
"iat": float64(now.Unix()),
|
||||
"exp": float64(tomorrow.Unix()),
|
||||
},
|
||||
ok: true,
|
||||
},
|
||||
// valid token, ('aud' claim is []interface{})
|
||||
{
|
||||
claims: jose.Claims{
|
||||
"iss": validIss,
|
||||
"sub": validClientID,
|
||||
"aud": []interface{}{"foo", validClientID},
|
||||
"iat": float64(now.Unix()),
|
||||
"exp": float64(tomorrow.Unix()),
|
||||
},
|
||||
ok: true,
|
||||
},
|
||||
// missing 'iss' claim
|
||||
{
|
||||
claims: jose.Claims{
|
||||
"sub": validClientID,
|
||||
"aud": validClientID,
|
||||
"iat": float64(now.Unix()),
|
||||
"exp": float64(tomorrow.Unix()),
|
||||
},
|
||||
ok: false,
|
||||
},
|
||||
// invalid 'iss' claim
|
||||
{
|
||||
claims: jose.Claims{
|
||||
"iss": "INVALID",
|
||||
"sub": validClientID,
|
||||
"aud": validClientID,
|
||||
"iat": float64(now.Unix()),
|
||||
"exp": float64(tomorrow.Unix()),
|
||||
},
|
||||
ok: false,
|
||||
},
|
||||
// missing 'sub' claim
|
||||
{
|
||||
claims: jose.Claims{
|
||||
"iss": validIss,
|
||||
"aud": validClientID,
|
||||
"iat": float64(now.Unix()),
|
||||
"exp": float64(tomorrow.Unix()),
|
||||
},
|
||||
ok: false,
|
||||
},
|
||||
// invalid 'sub' claim
|
||||
{
|
||||
claims: jose.Claims{
|
||||
"iss": validIss,
|
||||
"sub": "INVALID",
|
||||
"aud": validClientID,
|
||||
"iat": float64(now.Unix()),
|
||||
"exp": float64(tomorrow.Unix()),
|
||||
},
|
||||
ok: false,
|
||||
},
|
||||
// missing 'aud' claim
|
||||
{
|
||||
claims: jose.Claims{
|
||||
"iss": validIss,
|
||||
"sub": validClientID,
|
||||
"iat": float64(now.Unix()),
|
||||
"exp": float64(tomorrow.Unix()),
|
||||
},
|
||||
ok: false,
|
||||
},
|
||||
// invalid 'aud' claim
|
||||
{
|
||||
claims: jose.Claims{
|
||||
"iss": validIss,
|
||||
"sub": validClientID,
|
||||
"aud": "INVALID",
|
||||
"iat": float64(now.Unix()),
|
||||
"exp": float64(tomorrow.Unix()),
|
||||
},
|
||||
ok: false,
|
||||
},
|
||||
// invalid 'aud' claim
|
||||
{
|
||||
claims: jose.Claims{
|
||||
"iss": validIss,
|
||||
"sub": validClientID,
|
||||
"aud": []string{"INVALID1", "INVALID2"},
|
||||
"iat": float64(now.Unix()),
|
||||
"exp": float64(tomorrow.Unix()),
|
||||
},
|
||||
ok: false,
|
||||
},
|
||||
// invalid 'aud' type
|
||||
{
|
||||
claims: jose.Claims{
|
||||
"iss": validIss,
|
||||
"sub": validClientID,
|
||||
"aud": struct{}{},
|
||||
"iat": float64(now.Unix()),
|
||||
"exp": float64(tomorrow.Unix()),
|
||||
},
|
||||
ok: false,
|
||||
},
|
||||
// expired
|
||||
{
|
||||
claims: jose.Claims{
|
||||
"iss": validIss,
|
||||
"sub": validClientID,
|
||||
"aud": validClientID,
|
||||
"iat": float64(now.Unix()),
|
||||
"exp": float64(now.Unix()),
|
||||
},
|
||||
ok: false,
|
||||
},
|
||||
}
|
||||
|
||||
for i, tt := range tests {
|
||||
jwt, err := jose.NewJWT(header, tt.claims)
|
||||
if err != nil {
|
||||
t.Fatalf("case %d: Failed to generate JWT, error=%v", i, err)
|
||||
}
|
||||
|
||||
got, err := VerifyClientClaims(jwt, validIss)
|
||||
if tt.ok {
|
||||
if err != nil {
|
||||
t.Errorf("case %d: unexpected error, err=%v", i, err)
|
||||
}
|
||||
if got != validClientID {
|
||||
t.Errorf("case %d: incorrect client ID, want=%s, got=%s", i, validClientID, got)
|
||||
}
|
||||
} else if err == nil {
|
||||
t.Errorf("case %d: expected error but err is nil", i)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestJWTVerifier(t *testing.T) {
|
||||
iss := "http://example.com"
|
||||
now := time.Now()
|
||||
future12 := now.Add(12 * time.Hour)
|
||||
past36 := now.Add(-36 * time.Hour)
|
||||
past12 := now.Add(-12 * time.Hour)
|
||||
|
||||
priv1, err := key.GeneratePrivateKey()
|
||||
if err != nil {
|
||||
t.Fatalf("failed to generate private key, error=%v", err)
|
||||
}
|
||||
pk1 := *key.NewPublicKey(priv1.JWK())
|
||||
|
||||
priv2, err := key.GeneratePrivateKey()
|
||||
if err != nil {
|
||||
t.Fatalf("failed to generate private key, error=%v", err)
|
||||
}
|
||||
pk2 := *key.NewPublicKey(priv2.JWK())
|
||||
|
||||
newJWT := func(issuer, subject string, aud interface{}, issuedAt, exp time.Time, signer jose.Signer) jose.JWT {
|
||||
jwt, err := jose.NewSignedJWT(NewClaims(issuer, subject, aud, issuedAt, exp), signer)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
return *jwt
|
||||
}
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
verifier JWTVerifier
|
||||
jwt jose.JWT
|
||||
wantErr bool
|
||||
}{
|
||||
{
|
||||
name: "JWT signed with available key",
|
||||
verifier: JWTVerifier{
|
||||
issuer: "example.com",
|
||||
clientID: "XXX",
|
||||
syncFunc: func() error { return nil },
|
||||
keysFunc: func() []key.PublicKey {
|
||||
return []key.PublicKey{pk1}
|
||||
},
|
||||
},
|
||||
jwt: newJWT(iss, "XXX", "XXX", past12, future12, priv1.Signer()),
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "JWT signed with available key, with bad claims",
|
||||
verifier: JWTVerifier{
|
||||
issuer: "example.com",
|
||||
clientID: "XXX",
|
||||
syncFunc: func() error { return nil },
|
||||
keysFunc: func() []key.PublicKey {
|
||||
return []key.PublicKey{pk1}
|
||||
},
|
||||
},
|
||||
jwt: newJWT(iss, "XXX", "YYY", past12, future12, priv1.Signer()),
|
||||
wantErr: true,
|
||||
},
|
||||
|
||||
{
|
||||
name: "JWT signed with available key",
|
||||
verifier: JWTVerifier{
|
||||
issuer: "example.com",
|
||||
clientID: "XXX",
|
||||
syncFunc: func() error { return nil },
|
||||
keysFunc: func() []key.PublicKey {
|
||||
return []key.PublicKey{pk1}
|
||||
},
|
||||
},
|
||||
jwt: newJWT(iss, "XXX", []string{"YYY", "ZZZ"}, past12, future12, priv1.Signer()),
|
||||
wantErr: true,
|
||||
},
|
||||
|
||||
{
|
||||
name: "expired JWT signed with available key",
|
||||
verifier: JWTVerifier{
|
||||
issuer: "example.com",
|
||||
clientID: "XXX",
|
||||
syncFunc: func() error { return nil },
|
||||
keysFunc: func() []key.PublicKey {
|
||||
return []key.PublicKey{pk1}
|
||||
},
|
||||
},
|
||||
jwt: newJWT(iss, "XXX", "XXX", past36, past12, priv1.Signer()),
|
||||
wantErr: true,
|
||||
},
|
||||
|
||||
{
|
||||
name: "JWT signed with unrecognized key, verifiable after sync",
|
||||
verifier: JWTVerifier{
|
||||
issuer: "example.com",
|
||||
clientID: "XXX",
|
||||
syncFunc: func() error { return nil },
|
||||
keysFunc: func() func() []key.PublicKey {
|
||||
var i int
|
||||
return func() []key.PublicKey {
|
||||
defer func() { i++ }()
|
||||
return [][]key.PublicKey{
|
||||
[]key.PublicKey{pk1},
|
||||
[]key.PublicKey{pk2},
|
||||
}[i]
|
||||
}
|
||||
}(),
|
||||
},
|
||||
jwt: newJWT(iss, "XXX", "XXX", past36, future12, priv2.Signer()),
|
||||
wantErr: false,
|
||||
},
|
||||
|
||||
{
|
||||
name: "JWT signed with unrecognized key, not verifiable after sync",
|
||||
verifier: JWTVerifier{
|
||||
issuer: "example.com",
|
||||
clientID: "XXX",
|
||||
syncFunc: func() error { return nil },
|
||||
keysFunc: func() []key.PublicKey {
|
||||
return []key.PublicKey{pk1}
|
||||
},
|
||||
},
|
||||
jwt: newJWT(iss, "XXX", "XXX", past12, future12, priv2.Signer()),
|
||||
wantErr: true,
|
||||
},
|
||||
|
||||
{
|
||||
name: "verifier gets no keys from keysFunc, still not verifiable after sync",
|
||||
verifier: JWTVerifier{
|
||||
issuer: "example.com",
|
||||
clientID: "XXX",
|
||||
syncFunc: func() error { return nil },
|
||||
keysFunc: func() []key.PublicKey {
|
||||
return []key.PublicKey{}
|
||||
},
|
||||
},
|
||||
jwt: newJWT(iss, "XXX", "XXX", past12, future12, priv1.Signer()),
|
||||
wantErr: true,
|
||||
},
|
||||
|
||||
{
|
||||
name: "verifier gets no keys from keysFunc, verifiable after sync",
|
||||
verifier: JWTVerifier{
|
||||
issuer: "example.com",
|
||||
clientID: "XXX",
|
||||
syncFunc: func() error { return nil },
|
||||
keysFunc: func() func() []key.PublicKey {
|
||||
var i int
|
||||
return func() []key.PublicKey {
|
||||
defer func() { i++ }()
|
||||
return [][]key.PublicKey{
|
||||
[]key.PublicKey{},
|
||||
[]key.PublicKey{pk2},
|
||||
}[i]
|
||||
}
|
||||
}(),
|
||||
},
|
||||
jwt: newJWT(iss, "XXX", "XXX", past12, future12, priv2.Signer()),
|
||||
wantErr: false,
|
||||
},
|
||||
|
||||
{
|
||||
name: "JWT signed with available key, 'aud' is a string array",
|
||||
verifier: JWTVerifier{
|
||||
issuer: "example.com",
|
||||
clientID: "XXX",
|
||||
syncFunc: func() error { return nil },
|
||||
keysFunc: func() []key.PublicKey {
|
||||
return []key.PublicKey{pk1}
|
||||
},
|
||||
},
|
||||
jwt: newJWT(iss, "XXX", []string{"ZZZ", "XXX"}, past12, future12, priv1.Signer()),
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "invalid issuer claim shouldn't trigger sync",
|
||||
verifier: JWTVerifier{
|
||||
issuer: "example.com",
|
||||
clientID: "XXX",
|
||||
syncFunc: func() error {
|
||||
t.Errorf("invalid issuer claim shouldn't trigger a sync")
|
||||
return nil
|
||||
},
|
||||
keysFunc: func() func() []key.PublicKey {
|
||||
var i int
|
||||
return func() []key.PublicKey {
|
||||
defer func() { i++ }()
|
||||
return [][]key.PublicKey{
|
||||
[]key.PublicKey{},
|
||||
[]key.PublicKey{pk2},
|
||||
}[i]
|
||||
}
|
||||
}(),
|
||||
},
|
||||
jwt: newJWT("invalid-issuer", "XXX", []string{"ZZZ", "XXX"}, past12, future12, priv2.Signer()),
|
||||
wantErr: true,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
err := tt.verifier.Verify(tt.jwt)
|
||||
if tt.wantErr && (err == nil) {
|
||||
t.Errorf("case %q: wanted non-nil error", tt.name)
|
||||
} else if !tt.wantErr && (err != nil) {
|
||||
t.Errorf("case %q: wanted nil error, got %v", tt.name, err)
|
||||
}
|
||||
}
|
||||
}
|
||||
Reference in New Issue
Block a user