vendor: revendor
This commit is contained in:
846
vendor/github.com/coreos/go-oidc/oidc/client.go
generated
vendored
Normal file
846
vendor/github.com/coreos/go-oidc/oidc/client.go
generated
vendored
Normal file
@@ -0,0 +1,846 @@
|
||||
package oidc
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"net/mail"
|
||||
"net/url"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
phttp "github.com/coreos/go-oidc/http"
|
||||
"github.com/coreos/go-oidc/jose"
|
||||
"github.com/coreos/go-oidc/key"
|
||||
"github.com/coreos/go-oidc/oauth2"
|
||||
)
|
||||
|
||||
const (
|
||||
// amount of time that must pass after the last key sync
|
||||
// completes before another attempt may begin
|
||||
keySyncWindow = 5 * time.Second
|
||||
)
|
||||
|
||||
var (
|
||||
DefaultScope = []string{"openid", "email", "profile"}
|
||||
|
||||
supportedAuthMethods = map[string]struct{}{
|
||||
oauth2.AuthMethodClientSecretBasic: struct{}{},
|
||||
oauth2.AuthMethodClientSecretPost: struct{}{},
|
||||
}
|
||||
)
|
||||
|
||||
type ClientCredentials oauth2.ClientCredentials
|
||||
|
||||
type ClientIdentity struct {
|
||||
Credentials ClientCredentials
|
||||
Metadata ClientMetadata
|
||||
}
|
||||
|
||||
type JWAOptions struct {
|
||||
// SigningAlg specifies an JWA alg for signing JWTs.
|
||||
//
|
||||
// Specifying this field implies different actions depending on the context. It may
|
||||
// require objects be serialized and signed as a JWT instead of plain JSON, or
|
||||
// require an existing JWT object use the specified alg.
|
||||
//
|
||||
// See: http://openid.net/specs/openid-connect-registration-1_0.html#ClientMetadata
|
||||
SigningAlg string
|
||||
// EncryptionAlg, if provided, specifies that the returned or sent object be stored
|
||||
// (or nested) within a JWT object and encrypted with the provided JWA alg.
|
||||
EncryptionAlg string
|
||||
// EncryptionEnc specifies the JWA enc algorithm to use with EncryptionAlg. If
|
||||
// EncryptionAlg is provided and EncryptionEnc is omitted, this field defaults
|
||||
// to A128CBC-HS256.
|
||||
//
|
||||
// If EncryptionEnc is provided EncryptionAlg must also be specified.
|
||||
EncryptionEnc string
|
||||
}
|
||||
|
||||
func (opt JWAOptions) valid() error {
|
||||
if opt.EncryptionEnc != "" && opt.EncryptionAlg == "" {
|
||||
return errors.New("encryption encoding provided with no encryption algorithm")
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (opt JWAOptions) defaults() JWAOptions {
|
||||
if opt.EncryptionAlg != "" && opt.EncryptionEnc == "" {
|
||||
opt.EncryptionEnc = jose.EncA128CBCHS256
|
||||
}
|
||||
return opt
|
||||
}
|
||||
|
||||
var (
|
||||
// Ensure ClientMetadata satisfies these interfaces.
|
||||
_ json.Marshaler = &ClientMetadata{}
|
||||
_ json.Unmarshaler = &ClientMetadata{}
|
||||
)
|
||||
|
||||
// ClientMetadata holds metadata that the authorization server associates
|
||||
// with a client identifier. The fields range from human-facing display
|
||||
// strings such as client name, to items that impact the security of the
|
||||
// protocol, such as the list of valid redirect URIs.
|
||||
//
|
||||
// See http://openid.net/specs/openid-connect-registration-1_0.html#ClientMetadata
|
||||
//
|
||||
// TODO: support language specific claim representations
|
||||
// http://openid.net/specs/openid-connect-registration-1_0.html#LanguagesAndScripts
|
||||
type ClientMetadata struct {
|
||||
RedirectURIs []url.URL // Required
|
||||
|
||||
// A list of OAuth 2.0 "response_type" values that the client wishes to restrict
|
||||
// itself to. Either "code", "token", or another registered extension.
|
||||
//
|
||||
// If omitted, only "code" will be used.
|
||||
ResponseTypes []string
|
||||
// A list of OAuth 2.0 grant types the client wishes to restrict itself to.
|
||||
// The grant type values used by OIDC are "authorization_code", "implicit",
|
||||
// and "refresh_token".
|
||||
//
|
||||
// If ommitted, only "authorization_code" will be used.
|
||||
GrantTypes []string
|
||||
// "native" or "web". If omitted, "web".
|
||||
ApplicationType string
|
||||
|
||||
// List of email addresses.
|
||||
Contacts []mail.Address
|
||||
// Name of client to be presented to the end-user.
|
||||
ClientName string
|
||||
// URL that references a logo for the Client application.
|
||||
LogoURI *url.URL
|
||||
// URL of the home page of the Client.
|
||||
ClientURI *url.URL
|
||||
// Profile data policies and terms of use to be provided to the end user.
|
||||
PolicyURI *url.URL
|
||||
TermsOfServiceURI *url.URL
|
||||
|
||||
// URL to or the value of the client's JSON Web Key Set document.
|
||||
JWKSURI *url.URL
|
||||
JWKS *jose.JWKSet
|
||||
|
||||
// URL referencing a flie with a single JSON array of redirect URIs.
|
||||
SectorIdentifierURI *url.URL
|
||||
|
||||
SubjectType string
|
||||
|
||||
// Options to restrict the JWS alg and enc values used for server responses and requests.
|
||||
IDTokenResponseOptions JWAOptions
|
||||
UserInfoResponseOptions JWAOptions
|
||||
RequestObjectOptions JWAOptions
|
||||
|
||||
// Client requested authorization method and signing options for the token endpoint.
|
||||
//
|
||||
// Defaults to "client_secret_basic"
|
||||
TokenEndpointAuthMethod string
|
||||
TokenEndpointAuthSigningAlg string
|
||||
|
||||
// DefaultMaxAge specifies the maximum amount of time in seconds before an authorized
|
||||
// user must reauthroize.
|
||||
//
|
||||
// If 0, no limitation is placed on the maximum.
|
||||
DefaultMaxAge int64
|
||||
// RequireAuthTime specifies if the auth_time claim in the ID token is required.
|
||||
RequireAuthTime bool
|
||||
|
||||
// Default Authentication Context Class Reference values for authentication requests.
|
||||
DefaultACRValues []string
|
||||
|
||||
// URI that a third party can use to initiate a login by the relaying party.
|
||||
//
|
||||
// See: http://openid.net/specs/openid-connect-core-1_0.html#ThirdPartyInitiatedLogin
|
||||
InitiateLoginURI *url.URL
|
||||
// Pre-registered request_uri values that may be cached by the server.
|
||||
RequestURIs []url.URL
|
||||
}
|
||||
|
||||
// Defaults returns a shallow copy of ClientMetadata with default
|
||||
// values replacing omitted fields.
|
||||
func (m ClientMetadata) Defaults() ClientMetadata {
|
||||
if len(m.ResponseTypes) == 0 {
|
||||
m.ResponseTypes = []string{oauth2.ResponseTypeCode}
|
||||
}
|
||||
if len(m.GrantTypes) == 0 {
|
||||
m.GrantTypes = []string{oauth2.GrantTypeAuthCode}
|
||||
}
|
||||
if m.ApplicationType == "" {
|
||||
m.ApplicationType = "web"
|
||||
}
|
||||
if m.TokenEndpointAuthMethod == "" {
|
||||
m.TokenEndpointAuthMethod = oauth2.AuthMethodClientSecretBasic
|
||||
}
|
||||
m.IDTokenResponseOptions = m.IDTokenResponseOptions.defaults()
|
||||
m.UserInfoResponseOptions = m.UserInfoResponseOptions.defaults()
|
||||
m.RequestObjectOptions = m.RequestObjectOptions.defaults()
|
||||
return m
|
||||
}
|
||||
|
||||
func (m *ClientMetadata) MarshalJSON() ([]byte, error) {
|
||||
e := m.toEncodableStruct()
|
||||
return json.Marshal(&e)
|
||||
}
|
||||
|
||||
func (m *ClientMetadata) UnmarshalJSON(data []byte) error {
|
||||
var e encodableClientMetadata
|
||||
if err := json.Unmarshal(data, &e); err != nil {
|
||||
return err
|
||||
}
|
||||
meta, err := e.toStruct()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if err := meta.Valid(); err != nil {
|
||||
return err
|
||||
}
|
||||
*m = meta
|
||||
return nil
|
||||
}
|
||||
|
||||
type encodableClientMetadata struct {
|
||||
RedirectURIs []string `json:"redirect_uris"` // Required
|
||||
ResponseTypes []string `json:"response_types,omitempty"`
|
||||
GrantTypes []string `json:"grant_types,omitempty"`
|
||||
ApplicationType string `json:"application_type,omitempty"`
|
||||
Contacts []string `json:"contacts,omitempty"`
|
||||
ClientName string `json:"client_name,omitempty"`
|
||||
LogoURI string `json:"logo_uri,omitempty"`
|
||||
ClientURI string `json:"client_uri,omitempty"`
|
||||
PolicyURI string `json:"policy_uri,omitempty"`
|
||||
TermsOfServiceURI string `json:"tos_uri,omitempty"`
|
||||
JWKSURI string `json:"jwks_uri,omitempty"`
|
||||
JWKS *jose.JWKSet `json:"jwks,omitempty"`
|
||||
SectorIdentifierURI string `json:"sector_identifier_uri,omitempty"`
|
||||
SubjectType string `json:"subject_type,omitempty"`
|
||||
IDTokenSignedResponseAlg string `json:"id_token_signed_response_alg,omitempty"`
|
||||
IDTokenEncryptedResponseAlg string `json:"id_token_encrypted_response_alg,omitempty"`
|
||||
IDTokenEncryptedResponseEnc string `json:"id_token_encrypted_response_enc,omitempty"`
|
||||
UserInfoSignedResponseAlg string `json:"userinfo_signed_response_alg,omitempty"`
|
||||
UserInfoEncryptedResponseAlg string `json:"userinfo_encrypted_response_alg,omitempty"`
|
||||
UserInfoEncryptedResponseEnc string `json:"userinfo_encrypted_response_enc,omitempty"`
|
||||
RequestObjectSigningAlg string `json:"request_object_signing_alg,omitempty"`
|
||||
RequestObjectEncryptionAlg string `json:"request_object_encryption_alg,omitempty"`
|
||||
RequestObjectEncryptionEnc string `json:"request_object_encryption_enc,omitempty"`
|
||||
TokenEndpointAuthMethod string `json:"token_endpoint_auth_method,omitempty"`
|
||||
TokenEndpointAuthSigningAlg string `json:"token_endpoint_auth_signing_alg,omitempty"`
|
||||
DefaultMaxAge int64 `json:"default_max_age,omitempty"`
|
||||
RequireAuthTime bool `json:"require_auth_time,omitempty"`
|
||||
DefaultACRValues []string `json:"default_acr_values,omitempty"`
|
||||
InitiateLoginURI string `json:"initiate_login_uri,omitempty"`
|
||||
RequestURIs []string `json:"request_uris,omitempty"`
|
||||
}
|
||||
|
||||
func (c *encodableClientMetadata) toStruct() (ClientMetadata, error) {
|
||||
p := stickyErrParser{}
|
||||
m := ClientMetadata{
|
||||
RedirectURIs: p.parseURIs(c.RedirectURIs, "redirect_uris"),
|
||||
ResponseTypes: c.ResponseTypes,
|
||||
GrantTypes: c.GrantTypes,
|
||||
ApplicationType: c.ApplicationType,
|
||||
Contacts: p.parseEmails(c.Contacts, "contacts"),
|
||||
ClientName: c.ClientName,
|
||||
LogoURI: p.parseURI(c.LogoURI, "logo_uri"),
|
||||
ClientURI: p.parseURI(c.ClientURI, "client_uri"),
|
||||
PolicyURI: p.parseURI(c.PolicyURI, "policy_uri"),
|
||||
TermsOfServiceURI: p.parseURI(c.TermsOfServiceURI, "tos_uri"),
|
||||
JWKSURI: p.parseURI(c.JWKSURI, "jwks_uri"),
|
||||
JWKS: c.JWKS,
|
||||
SectorIdentifierURI: p.parseURI(c.SectorIdentifierURI, "sector_identifier_uri"),
|
||||
SubjectType: c.SubjectType,
|
||||
TokenEndpointAuthMethod: c.TokenEndpointAuthMethod,
|
||||
TokenEndpointAuthSigningAlg: c.TokenEndpointAuthSigningAlg,
|
||||
DefaultMaxAge: c.DefaultMaxAge,
|
||||
RequireAuthTime: c.RequireAuthTime,
|
||||
DefaultACRValues: c.DefaultACRValues,
|
||||
InitiateLoginURI: p.parseURI(c.InitiateLoginURI, "initiate_login_uri"),
|
||||
RequestURIs: p.parseURIs(c.RequestURIs, "request_uris"),
|
||||
IDTokenResponseOptions: JWAOptions{
|
||||
c.IDTokenSignedResponseAlg,
|
||||
c.IDTokenEncryptedResponseAlg,
|
||||
c.IDTokenEncryptedResponseEnc,
|
||||
},
|
||||
UserInfoResponseOptions: JWAOptions{
|
||||
c.UserInfoSignedResponseAlg,
|
||||
c.UserInfoEncryptedResponseAlg,
|
||||
c.UserInfoEncryptedResponseEnc,
|
||||
},
|
||||
RequestObjectOptions: JWAOptions{
|
||||
c.RequestObjectSigningAlg,
|
||||
c.RequestObjectEncryptionAlg,
|
||||
c.RequestObjectEncryptionEnc,
|
||||
},
|
||||
}
|
||||
if p.firstErr != nil {
|
||||
return ClientMetadata{}, p.firstErr
|
||||
}
|
||||
return m, nil
|
||||
}
|
||||
|
||||
// stickyErrParser parses URIs and email addresses. Once it encounters
|
||||
// a parse error, subsequent calls become no-op.
|
||||
type stickyErrParser struct {
|
||||
firstErr error
|
||||
}
|
||||
|
||||
func (p *stickyErrParser) parseURI(s, field string) *url.URL {
|
||||
if p.firstErr != nil || s == "" {
|
||||
return nil
|
||||
}
|
||||
u, err := url.Parse(s)
|
||||
if err == nil {
|
||||
if u.Host == "" {
|
||||
err = errors.New("no host in URI")
|
||||
} else if u.Scheme != "http" && u.Scheme != "https" {
|
||||
err = errors.New("invalid URI scheme")
|
||||
}
|
||||
}
|
||||
if err != nil {
|
||||
p.firstErr = fmt.Errorf("failed to parse %s: %v", field, err)
|
||||
return nil
|
||||
}
|
||||
return u
|
||||
}
|
||||
|
||||
func (p *stickyErrParser) parseURIs(s []string, field string) []url.URL {
|
||||
if p.firstErr != nil || len(s) == 0 {
|
||||
return nil
|
||||
}
|
||||
uris := make([]url.URL, len(s))
|
||||
for i, val := range s {
|
||||
if val == "" {
|
||||
p.firstErr = fmt.Errorf("invalid URI in field %s", field)
|
||||
return nil
|
||||
}
|
||||
if u := p.parseURI(val, field); u != nil {
|
||||
uris[i] = *u
|
||||
}
|
||||
}
|
||||
return uris
|
||||
}
|
||||
|
||||
func (p *stickyErrParser) parseEmails(s []string, field string) []mail.Address {
|
||||
if p.firstErr != nil || len(s) == 0 {
|
||||
return nil
|
||||
}
|
||||
addrs := make([]mail.Address, len(s))
|
||||
for i, addr := range s {
|
||||
if addr == "" {
|
||||
p.firstErr = fmt.Errorf("invalid email in field %s", field)
|
||||
return nil
|
||||
}
|
||||
a, err := mail.ParseAddress(addr)
|
||||
if err != nil {
|
||||
p.firstErr = fmt.Errorf("invalid email in field %s: %v", field, err)
|
||||
return nil
|
||||
}
|
||||
addrs[i] = *a
|
||||
}
|
||||
return addrs
|
||||
}
|
||||
|
||||
func (m *ClientMetadata) toEncodableStruct() encodableClientMetadata {
|
||||
return encodableClientMetadata{
|
||||
RedirectURIs: urisToStrings(m.RedirectURIs),
|
||||
ResponseTypes: m.ResponseTypes,
|
||||
GrantTypes: m.GrantTypes,
|
||||
ApplicationType: m.ApplicationType,
|
||||
Contacts: emailsToStrings(m.Contacts),
|
||||
ClientName: m.ClientName,
|
||||
LogoURI: uriToString(m.LogoURI),
|
||||
ClientURI: uriToString(m.ClientURI),
|
||||
PolicyURI: uriToString(m.PolicyURI),
|
||||
TermsOfServiceURI: uriToString(m.TermsOfServiceURI),
|
||||
JWKSURI: uriToString(m.JWKSURI),
|
||||
JWKS: m.JWKS,
|
||||
SectorIdentifierURI: uriToString(m.SectorIdentifierURI),
|
||||
SubjectType: m.SubjectType,
|
||||
IDTokenSignedResponseAlg: m.IDTokenResponseOptions.SigningAlg,
|
||||
IDTokenEncryptedResponseAlg: m.IDTokenResponseOptions.EncryptionAlg,
|
||||
IDTokenEncryptedResponseEnc: m.IDTokenResponseOptions.EncryptionEnc,
|
||||
UserInfoSignedResponseAlg: m.UserInfoResponseOptions.SigningAlg,
|
||||
UserInfoEncryptedResponseAlg: m.UserInfoResponseOptions.EncryptionAlg,
|
||||
UserInfoEncryptedResponseEnc: m.UserInfoResponseOptions.EncryptionEnc,
|
||||
RequestObjectSigningAlg: m.RequestObjectOptions.SigningAlg,
|
||||
RequestObjectEncryptionAlg: m.RequestObjectOptions.EncryptionAlg,
|
||||
RequestObjectEncryptionEnc: m.RequestObjectOptions.EncryptionEnc,
|
||||
TokenEndpointAuthMethod: m.TokenEndpointAuthMethod,
|
||||
TokenEndpointAuthSigningAlg: m.TokenEndpointAuthSigningAlg,
|
||||
DefaultMaxAge: m.DefaultMaxAge,
|
||||
RequireAuthTime: m.RequireAuthTime,
|
||||
DefaultACRValues: m.DefaultACRValues,
|
||||
InitiateLoginURI: uriToString(m.InitiateLoginURI),
|
||||
RequestURIs: urisToStrings(m.RequestURIs),
|
||||
}
|
||||
}
|
||||
|
||||
func uriToString(u *url.URL) string {
|
||||
if u == nil {
|
||||
return ""
|
||||
}
|
||||
return u.String()
|
||||
}
|
||||
|
||||
func urisToStrings(urls []url.URL) []string {
|
||||
if len(urls) == 0 {
|
||||
return nil
|
||||
}
|
||||
sli := make([]string, len(urls))
|
||||
for i, u := range urls {
|
||||
sli[i] = u.String()
|
||||
}
|
||||
return sli
|
||||
}
|
||||
|
||||
func emailsToStrings(addrs []mail.Address) []string {
|
||||
if len(addrs) == 0 {
|
||||
return nil
|
||||
}
|
||||
sli := make([]string, len(addrs))
|
||||
for i, addr := range addrs {
|
||||
sli[i] = addr.String()
|
||||
}
|
||||
return sli
|
||||
}
|
||||
|
||||
// Valid determines if a ClientMetadata conforms with the OIDC specification.
|
||||
//
|
||||
// Valid is called by UnmarshalJSON.
|
||||
//
|
||||
// NOTE(ericchiang): For development purposes Valid does not mandate 'https' for
|
||||
// URLs fields where the OIDC spec requires it. This may change in future releases
|
||||
// of this package. See: https://github.com/coreos/go-oidc/issues/34
|
||||
func (m *ClientMetadata) Valid() error {
|
||||
if len(m.RedirectURIs) == 0 {
|
||||
return errors.New("zero redirect URLs")
|
||||
}
|
||||
|
||||
validURI := func(u *url.URL, fieldName string) error {
|
||||
if u.Host == "" {
|
||||
return fmt.Errorf("no host for uri field %s", fieldName)
|
||||
}
|
||||
if u.Scheme != "http" && u.Scheme != "https" {
|
||||
return fmt.Errorf("uri field %s scheme is not http or https", fieldName)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
uris := []struct {
|
||||
val *url.URL
|
||||
name string
|
||||
}{
|
||||
{m.LogoURI, "logo_uri"},
|
||||
{m.ClientURI, "client_uri"},
|
||||
{m.PolicyURI, "policy_uri"},
|
||||
{m.TermsOfServiceURI, "tos_uri"},
|
||||
{m.JWKSURI, "jwks_uri"},
|
||||
{m.SectorIdentifierURI, "sector_identifier_uri"},
|
||||
{m.InitiateLoginURI, "initiate_login_uri"},
|
||||
}
|
||||
|
||||
for _, uri := range uris {
|
||||
if uri.val == nil {
|
||||
continue
|
||||
}
|
||||
if err := validURI(uri.val, uri.name); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
uriLists := []struct {
|
||||
vals []url.URL
|
||||
name string
|
||||
}{
|
||||
{m.RedirectURIs, "redirect_uris"},
|
||||
{m.RequestURIs, "request_uris"},
|
||||
}
|
||||
for _, list := range uriLists {
|
||||
for _, uri := range list.vals {
|
||||
if err := validURI(&uri, list.name); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
options := []struct {
|
||||
option JWAOptions
|
||||
name string
|
||||
}{
|
||||
{m.IDTokenResponseOptions, "id_token response"},
|
||||
{m.UserInfoResponseOptions, "userinfo response"},
|
||||
{m.RequestObjectOptions, "request_object"},
|
||||
}
|
||||
for _, option := range options {
|
||||
if err := option.option.valid(); err != nil {
|
||||
return fmt.Errorf("invalid JWA values for %s: %v", option.name, err)
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
type ClientRegistrationResponse struct {
|
||||
ClientID string // Required
|
||||
ClientSecret string
|
||||
RegistrationAccessToken string
|
||||
RegistrationClientURI string
|
||||
// If IsZero is true, unspecified.
|
||||
ClientIDIssuedAt time.Time
|
||||
// Time at which the client_secret will expire.
|
||||
// If IsZero is true, it will not expire.
|
||||
ClientSecretExpiresAt time.Time
|
||||
|
||||
ClientMetadata
|
||||
}
|
||||
|
||||
type encodableClientRegistrationResponse struct {
|
||||
ClientID string `json:"client_id"` // Required
|
||||
ClientSecret string `json:"client_secret,omitempty"`
|
||||
RegistrationAccessToken string `json:"registration_access_token,omitempty"`
|
||||
RegistrationClientURI string `json:"registration_client_uri,omitempty"`
|
||||
ClientIDIssuedAt int64 `json:"client_id_issued_at,omitempty"`
|
||||
// Time at which the client_secret will expire, in seconds since the epoch.
|
||||
// If 0 it will not expire.
|
||||
ClientSecretExpiresAt int64 `json:"client_secret_expires_at"` // Required
|
||||
|
||||
encodableClientMetadata
|
||||
}
|
||||
|
||||
func unixToSec(t time.Time) int64 {
|
||||
if t.IsZero() {
|
||||
return 0
|
||||
}
|
||||
return t.Unix()
|
||||
}
|
||||
|
||||
func (c *ClientRegistrationResponse) MarshalJSON() ([]byte, error) {
|
||||
e := encodableClientRegistrationResponse{
|
||||
ClientID: c.ClientID,
|
||||
ClientSecret: c.ClientSecret,
|
||||
RegistrationAccessToken: c.RegistrationAccessToken,
|
||||
RegistrationClientURI: c.RegistrationClientURI,
|
||||
ClientIDIssuedAt: unixToSec(c.ClientIDIssuedAt),
|
||||
ClientSecretExpiresAt: unixToSec(c.ClientSecretExpiresAt),
|
||||
encodableClientMetadata: c.ClientMetadata.toEncodableStruct(),
|
||||
}
|
||||
return json.Marshal(&e)
|
||||
}
|
||||
|
||||
func secToUnix(sec int64) time.Time {
|
||||
if sec == 0 {
|
||||
return time.Time{}
|
||||
}
|
||||
return time.Unix(sec, 0)
|
||||
}
|
||||
|
||||
func (c *ClientRegistrationResponse) UnmarshalJSON(data []byte) error {
|
||||
var e encodableClientRegistrationResponse
|
||||
if err := json.Unmarshal(data, &e); err != nil {
|
||||
return err
|
||||
}
|
||||
if e.ClientID == "" {
|
||||
return errors.New("no client_id in client registration response")
|
||||
}
|
||||
metadata, err := e.encodableClientMetadata.toStruct()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
*c = ClientRegistrationResponse{
|
||||
ClientID: e.ClientID,
|
||||
ClientSecret: e.ClientSecret,
|
||||
RegistrationAccessToken: e.RegistrationAccessToken,
|
||||
RegistrationClientURI: e.RegistrationClientURI,
|
||||
ClientIDIssuedAt: secToUnix(e.ClientIDIssuedAt),
|
||||
ClientSecretExpiresAt: secToUnix(e.ClientSecretExpiresAt),
|
||||
ClientMetadata: metadata,
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
type ClientConfig struct {
|
||||
HTTPClient phttp.Client
|
||||
Credentials ClientCredentials
|
||||
Scope []string
|
||||
RedirectURL string
|
||||
ProviderConfig ProviderConfig
|
||||
KeySet key.PublicKeySet
|
||||
}
|
||||
|
||||
func NewClient(cfg ClientConfig) (*Client, error) {
|
||||
// Allow empty redirect URL in the case where the client
|
||||
// only needs to verify a given token.
|
||||
ru, err := url.Parse(cfg.RedirectURL)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("invalid redirect URL: %v", err)
|
||||
}
|
||||
|
||||
c := Client{
|
||||
credentials: cfg.Credentials,
|
||||
httpClient: cfg.HTTPClient,
|
||||
scope: cfg.Scope,
|
||||
redirectURL: ru.String(),
|
||||
providerConfig: newProviderConfigRepo(cfg.ProviderConfig),
|
||||
keySet: cfg.KeySet,
|
||||
}
|
||||
|
||||
if c.httpClient == nil {
|
||||
c.httpClient = http.DefaultClient
|
||||
}
|
||||
|
||||
if c.scope == nil {
|
||||
c.scope = make([]string, len(DefaultScope))
|
||||
copy(c.scope, DefaultScope)
|
||||
}
|
||||
|
||||
return &c, nil
|
||||
}
|
||||
|
||||
type Client struct {
|
||||
httpClient phttp.Client
|
||||
providerConfig *providerConfigRepo
|
||||
credentials ClientCredentials
|
||||
redirectURL string
|
||||
scope []string
|
||||
keySet key.PublicKeySet
|
||||
providerSyncer *ProviderConfigSyncer
|
||||
|
||||
keySetSyncMutex sync.RWMutex
|
||||
lastKeySetSync time.Time
|
||||
}
|
||||
|
||||
func (c *Client) Healthy() error {
|
||||
now := time.Now().UTC()
|
||||
|
||||
cfg := c.providerConfig.Get()
|
||||
|
||||
if cfg.Empty() {
|
||||
return errors.New("oidc client provider config empty")
|
||||
}
|
||||
|
||||
if !cfg.ExpiresAt.IsZero() && cfg.ExpiresAt.Before(now) {
|
||||
return errors.New("oidc client provider config expired")
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *Client) OAuthClient() (*oauth2.Client, error) {
|
||||
cfg := c.providerConfig.Get()
|
||||
authMethod, err := chooseAuthMethod(cfg)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
ocfg := oauth2.Config{
|
||||
Credentials: oauth2.ClientCredentials(c.credentials),
|
||||
RedirectURL: c.redirectURL,
|
||||
AuthURL: cfg.AuthEndpoint.String(),
|
||||
TokenURL: cfg.TokenEndpoint.String(),
|
||||
Scope: c.scope,
|
||||
AuthMethod: authMethod,
|
||||
}
|
||||
|
||||
return oauth2.NewClient(c.httpClient, ocfg)
|
||||
}
|
||||
|
||||
func chooseAuthMethod(cfg ProviderConfig) (string, error) {
|
||||
if len(cfg.TokenEndpointAuthMethodsSupported) == 0 {
|
||||
return oauth2.AuthMethodClientSecretBasic, nil
|
||||
}
|
||||
|
||||
for _, authMethod := range cfg.TokenEndpointAuthMethodsSupported {
|
||||
if _, ok := supportedAuthMethods[authMethod]; ok {
|
||||
return authMethod, nil
|
||||
}
|
||||
}
|
||||
|
||||
return "", errors.New("no supported auth methods")
|
||||
}
|
||||
|
||||
// SyncProviderConfig starts the provider config syncer
|
||||
func (c *Client) SyncProviderConfig(discoveryURL string) chan struct{} {
|
||||
r := NewHTTPProviderConfigGetter(c.httpClient, discoveryURL)
|
||||
s := NewProviderConfigSyncer(r, c.providerConfig)
|
||||
stop := s.Run()
|
||||
s.WaitUntilInitialSync()
|
||||
return stop
|
||||
}
|
||||
|
||||
func (c *Client) maybeSyncKeys() error {
|
||||
tooSoon := func() bool {
|
||||
return time.Now().UTC().Before(c.lastKeySetSync.Add(keySyncWindow))
|
||||
}
|
||||
|
||||
// ignore request to sync keys if a sync operation has been
|
||||
// attempted too recently
|
||||
if tooSoon() {
|
||||
return nil
|
||||
}
|
||||
|
||||
c.keySetSyncMutex.Lock()
|
||||
defer c.keySetSyncMutex.Unlock()
|
||||
|
||||
// check again, as another goroutine may have been holding
|
||||
// the lock while updating the keys
|
||||
if tooSoon() {
|
||||
return nil
|
||||
}
|
||||
|
||||
cfg := c.providerConfig.Get()
|
||||
r := NewRemotePublicKeyRepo(c.httpClient, cfg.KeysEndpoint.String())
|
||||
w := &clientKeyRepo{client: c}
|
||||
_, err := key.Sync(r, w)
|
||||
c.lastKeySetSync = time.Now().UTC()
|
||||
|
||||
return err
|
||||
}
|
||||
|
||||
type clientKeyRepo struct {
|
||||
client *Client
|
||||
}
|
||||
|
||||
func (r *clientKeyRepo) Set(ks key.KeySet) error {
|
||||
pks, ok := ks.(*key.PublicKeySet)
|
||||
if !ok {
|
||||
return errors.New("unable to cast to PublicKey")
|
||||
}
|
||||
r.client.keySet = *pks
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *Client) ClientCredsToken(scope []string) (jose.JWT, error) {
|
||||
cfg := c.providerConfig.Get()
|
||||
|
||||
if !cfg.SupportsGrantType(oauth2.GrantTypeClientCreds) {
|
||||
return jose.JWT{}, fmt.Errorf("%v grant type is not supported", oauth2.GrantTypeClientCreds)
|
||||
}
|
||||
|
||||
oac, err := c.OAuthClient()
|
||||
if err != nil {
|
||||
return jose.JWT{}, err
|
||||
}
|
||||
|
||||
t, err := oac.ClientCredsToken(scope)
|
||||
if err != nil {
|
||||
return jose.JWT{}, err
|
||||
}
|
||||
|
||||
jwt, err := jose.ParseJWT(t.IDToken)
|
||||
if err != nil {
|
||||
return jose.JWT{}, err
|
||||
}
|
||||
|
||||
return jwt, c.VerifyJWT(jwt)
|
||||
}
|
||||
|
||||
// ExchangeAuthCode exchanges an OAuth2 auth code for an OIDC JWT ID token.
|
||||
func (c *Client) ExchangeAuthCode(code string) (jose.JWT, error) {
|
||||
oac, err := c.OAuthClient()
|
||||
if err != nil {
|
||||
return jose.JWT{}, err
|
||||
}
|
||||
|
||||
t, err := oac.RequestToken(oauth2.GrantTypeAuthCode, code)
|
||||
if err != nil {
|
||||
return jose.JWT{}, err
|
||||
}
|
||||
|
||||
jwt, err := jose.ParseJWT(t.IDToken)
|
||||
if err != nil {
|
||||
return jose.JWT{}, err
|
||||
}
|
||||
|
||||
return jwt, c.VerifyJWT(jwt)
|
||||
}
|
||||
|
||||
// RefreshToken uses a refresh token to exchange for a new OIDC JWT ID Token.
|
||||
func (c *Client) RefreshToken(refreshToken string) (jose.JWT, error) {
|
||||
oac, err := c.OAuthClient()
|
||||
if err != nil {
|
||||
return jose.JWT{}, err
|
||||
}
|
||||
|
||||
t, err := oac.RequestToken(oauth2.GrantTypeRefreshToken, refreshToken)
|
||||
if err != nil {
|
||||
return jose.JWT{}, err
|
||||
}
|
||||
|
||||
jwt, err := jose.ParseJWT(t.IDToken)
|
||||
if err != nil {
|
||||
return jose.JWT{}, err
|
||||
}
|
||||
|
||||
return jwt, c.VerifyJWT(jwt)
|
||||
}
|
||||
|
||||
func (c *Client) VerifyJWT(jwt jose.JWT) error {
|
||||
var keysFunc func() []key.PublicKey
|
||||
if kID, ok := jwt.KeyID(); ok {
|
||||
keysFunc = c.keysFuncWithID(kID)
|
||||
} else {
|
||||
keysFunc = c.keysFuncAll()
|
||||
}
|
||||
|
||||
v := NewJWTVerifier(
|
||||
c.providerConfig.Get().Issuer.String(),
|
||||
c.credentials.ID,
|
||||
c.maybeSyncKeys, keysFunc)
|
||||
|
||||
return v.Verify(jwt)
|
||||
}
|
||||
|
||||
// keysFuncWithID returns a function that retrieves at most unexpired
|
||||
// public key from the Client that matches the provided ID
|
||||
func (c *Client) keysFuncWithID(kID string) func() []key.PublicKey {
|
||||
return func() []key.PublicKey {
|
||||
c.keySetSyncMutex.RLock()
|
||||
defer c.keySetSyncMutex.RUnlock()
|
||||
|
||||
if c.keySet.ExpiresAt().Before(time.Now()) {
|
||||
return []key.PublicKey{}
|
||||
}
|
||||
|
||||
k := c.keySet.Key(kID)
|
||||
if k == nil {
|
||||
return []key.PublicKey{}
|
||||
}
|
||||
|
||||
return []key.PublicKey{*k}
|
||||
}
|
||||
}
|
||||
|
||||
// keysFuncAll returns a function that retrieves all unexpired public
|
||||
// keys from the Client
|
||||
func (c *Client) keysFuncAll() func() []key.PublicKey {
|
||||
return func() []key.PublicKey {
|
||||
c.keySetSyncMutex.RLock()
|
||||
defer c.keySetSyncMutex.RUnlock()
|
||||
|
||||
if c.keySet.ExpiresAt().Before(time.Now()) {
|
||||
return []key.PublicKey{}
|
||||
}
|
||||
|
||||
return c.keySet.Keys()
|
||||
}
|
||||
}
|
||||
|
||||
type providerConfigRepo struct {
|
||||
mu sync.RWMutex
|
||||
config ProviderConfig // do not access directly, use Get()
|
||||
}
|
||||
|
||||
func newProviderConfigRepo(pc ProviderConfig) *providerConfigRepo {
|
||||
return &providerConfigRepo{sync.RWMutex{}, pc}
|
||||
}
|
||||
|
||||
// returns an error to implement ProviderConfigSetter
|
||||
func (r *providerConfigRepo) Set(cfg ProviderConfig) error {
|
||||
r.mu.Lock()
|
||||
defer r.mu.Unlock()
|
||||
r.config = cfg
|
||||
return nil
|
||||
}
|
||||
|
||||
func (r *providerConfigRepo) Get() ProviderConfig {
|
||||
r.mu.RLock()
|
||||
defer r.mu.RUnlock()
|
||||
return r.config
|
||||
}
|
81
vendor/github.com/coreos/go-oidc/oidc/client_race_test.go
generated
vendored
Normal file
81
vendor/github.com/coreos/go-oidc/oidc/client_race_test.go
generated
vendored
Normal file
@@ -0,0 +1,81 @@
|
||||
// This file contains tests which depend on the race detector being enabled.
|
||||
// +build race
|
||||
|
||||
package oidc
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"net/url"
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
|
||||
type testProvider struct {
|
||||
baseURL *url.URL
|
||||
}
|
||||
|
||||
func (p *testProvider) ServeHTTP(w http.ResponseWriter, r *http.Request) {
|
||||
if r.URL.Path != discoveryConfigPath {
|
||||
http.NotFound(w, r)
|
||||
return
|
||||
}
|
||||
|
||||
cfg := ProviderConfig{
|
||||
Issuer: p.baseURL,
|
||||
ExpiresAt: time.Now().Add(time.Second),
|
||||
}
|
||||
cfg = fillRequiredProviderFields(cfg)
|
||||
json.NewEncoder(w).Encode(&cfg)
|
||||
}
|
||||
|
||||
// This test fails by triggering the race detector, not by calling t.Error or t.Fatal.
|
||||
func TestProviderSyncRace(t *testing.T) {
|
||||
|
||||
prov := &testProvider{}
|
||||
|
||||
s := httptest.NewServer(prov)
|
||||
defer s.Close()
|
||||
u, err := url.Parse(s.URL)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
prov.baseURL = u
|
||||
|
||||
prevValue := minimumProviderConfigSyncInterval
|
||||
defer func() { minimumProviderConfigSyncInterval = prevValue }()
|
||||
|
||||
// Reduce the sync interval to increase the write frequencey.
|
||||
minimumProviderConfigSyncInterval = 5 * time.Millisecond
|
||||
|
||||
cliCfg := ClientConfig{
|
||||
HTTPClient: http.DefaultClient,
|
||||
}
|
||||
cli, err := NewClient(cliCfg)
|
||||
if err != nil {
|
||||
t.Error(err)
|
||||
return
|
||||
}
|
||||
|
||||
if !cli.providerConfig.Get().Empty() {
|
||||
t.Errorf("want c.ProviderConfig == nil, got c.ProviderConfig=%#v")
|
||||
}
|
||||
|
||||
// SyncProviderConfig beings a goroutine which writes to the client's provider config.
|
||||
c := cli.SyncProviderConfig(s.URL)
|
||||
if cli.providerConfig.Get().Empty() {
|
||||
t.Errorf("want c.ProviderConfig != nil")
|
||||
}
|
||||
|
||||
defer func() {
|
||||
// stop the background process
|
||||
c <- struct{}{}
|
||||
}()
|
||||
|
||||
for i := 0; i < 10; i++ {
|
||||
time.Sleep(5 * time.Millisecond)
|
||||
// Creating an OAuth client reads from the provider config.
|
||||
cli.OAuthClient()
|
||||
}
|
||||
}
|
654
vendor/github.com/coreos/go-oidc/oidc/client_test.go
generated
vendored
Normal file
654
vendor/github.com/coreos/go-oidc/oidc/client_test.go
generated
vendored
Normal file
@@ -0,0 +1,654 @@
|
||||
package oidc
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"net/mail"
|
||||
"net/url"
|
||||
"reflect"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/coreos/go-oidc/jose"
|
||||
"github.com/coreos/go-oidc/key"
|
||||
"github.com/coreos/go-oidc/oauth2"
|
||||
"github.com/kylelemons/godebug/pretty"
|
||||
)
|
||||
|
||||
func TestNewClientScopeDefault(t *testing.T) {
|
||||
tests := []struct {
|
||||
c ClientConfig
|
||||
e []string
|
||||
}{
|
||||
{
|
||||
// No scope
|
||||
c: ClientConfig{RedirectURL: "http://example.com/redirect"},
|
||||
e: DefaultScope,
|
||||
},
|
||||
{
|
||||
// Nil scope
|
||||
c: ClientConfig{RedirectURL: "http://example.com/redirect", Scope: nil},
|
||||
e: DefaultScope,
|
||||
},
|
||||
{
|
||||
// Empty scope
|
||||
c: ClientConfig{RedirectURL: "http://example.com/redirect", Scope: []string{}},
|
||||
e: []string{},
|
||||
},
|
||||
{
|
||||
// Custom scope equal to default
|
||||
c: ClientConfig{RedirectURL: "http://example.com/redirect", Scope: []string{"openid", "email", "profile"}},
|
||||
e: DefaultScope,
|
||||
},
|
||||
{
|
||||
// Custom scope not including defaults
|
||||
c: ClientConfig{RedirectURL: "http://example.com/redirect", Scope: []string{"foo", "bar"}},
|
||||
e: []string{"foo", "bar"},
|
||||
},
|
||||
{
|
||||
// Custom scopes overlapping with defaults
|
||||
c: ClientConfig{RedirectURL: "http://example.com/redirect", Scope: []string{"openid", "foo"}},
|
||||
e: []string{"openid", "foo"},
|
||||
},
|
||||
}
|
||||
|
||||
for i, tt := range tests {
|
||||
c, err := NewClient(tt.c)
|
||||
if err != nil {
|
||||
t.Errorf("case %d: unexpected error from NewClient: %v", i, err)
|
||||
continue
|
||||
}
|
||||
if !reflect.DeepEqual(tt.e, c.scope) {
|
||||
t.Errorf("case %d: want: %v, got: %v", i, tt.e, c.scope)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestHealthy(t *testing.T) {
|
||||
now := time.Now().UTC()
|
||||
|
||||
tests := []struct {
|
||||
p ProviderConfig
|
||||
h bool
|
||||
}{
|
||||
// all ok
|
||||
{
|
||||
p: ProviderConfig{
|
||||
Issuer: &url.URL{Scheme: "http", Host: "example.com"},
|
||||
ExpiresAt: now.Add(time.Hour),
|
||||
},
|
||||
h: true,
|
||||
},
|
||||
// zero-value ProviderConfig.ExpiresAt
|
||||
{
|
||||
p: ProviderConfig{
|
||||
Issuer: &url.URL{Scheme: "http", Host: "example.com"},
|
||||
},
|
||||
h: true,
|
||||
},
|
||||
// expired ProviderConfig
|
||||
{
|
||||
p: ProviderConfig{
|
||||
Issuer: &url.URL{Scheme: "http", Host: "example.com"},
|
||||
ExpiresAt: now.Add(time.Hour * -1),
|
||||
},
|
||||
h: false,
|
||||
},
|
||||
// empty ProviderConfig
|
||||
{
|
||||
p: ProviderConfig{},
|
||||
h: false,
|
||||
},
|
||||
}
|
||||
|
||||
for i, tt := range tests {
|
||||
c := &Client{providerConfig: newProviderConfigRepo(tt.p)}
|
||||
err := c.Healthy()
|
||||
want := tt.h
|
||||
got := (err == nil)
|
||||
|
||||
if want != got {
|
||||
t.Errorf("case %d: want: healthy=%v, got: healhty=%v, err: %v", i, want, got, err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestClientKeysFuncAll(t *testing.T) {
|
||||
priv1, err := key.GeneratePrivateKey()
|
||||
if err != nil {
|
||||
t.Fatalf("failed to generate private key, error=%v", err)
|
||||
}
|
||||
|
||||
priv2, err := key.GeneratePrivateKey()
|
||||
if err != nil {
|
||||
t.Fatalf("failed to generate private key, error=%v", err)
|
||||
}
|
||||
|
||||
now := time.Now()
|
||||
future := now.Add(time.Hour)
|
||||
past := now.Add(-1 * time.Hour)
|
||||
|
||||
tests := []struct {
|
||||
keySet *key.PublicKeySet
|
||||
want []key.PublicKey
|
||||
}{
|
||||
// two keys, non-expired set
|
||||
{
|
||||
keySet: key.NewPublicKeySet([]jose.JWK{priv2.JWK(), priv1.JWK()}, future),
|
||||
want: []key.PublicKey{*key.NewPublicKey(priv2.JWK()), *key.NewPublicKey(priv1.JWK())},
|
||||
},
|
||||
|
||||
// no keys, non-expired set
|
||||
{
|
||||
keySet: key.NewPublicKeySet([]jose.JWK{}, future),
|
||||
want: []key.PublicKey{},
|
||||
},
|
||||
|
||||
// two keys, expired set
|
||||
{
|
||||
keySet: key.NewPublicKeySet([]jose.JWK{priv2.JWK(), priv1.JWK()}, past),
|
||||
want: []key.PublicKey{},
|
||||
},
|
||||
|
||||
// no keys, expired set
|
||||
{
|
||||
keySet: key.NewPublicKeySet([]jose.JWK{}, past),
|
||||
want: []key.PublicKey{},
|
||||
},
|
||||
}
|
||||
|
||||
for i, tt := range tests {
|
||||
var c Client
|
||||
c.keySet = *tt.keySet
|
||||
keysFunc := c.keysFuncAll()
|
||||
got := keysFunc()
|
||||
if !reflect.DeepEqual(tt.want, got) {
|
||||
t.Errorf("case %d: want=%#v got=%#v", i, tt.want, got)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestClientKeysFuncWithID(t *testing.T) {
|
||||
priv1, err := key.GeneratePrivateKey()
|
||||
if err != nil {
|
||||
t.Fatalf("failed to generate private key, error=%v", err)
|
||||
}
|
||||
|
||||
priv2, err := key.GeneratePrivateKey()
|
||||
if err != nil {
|
||||
t.Fatalf("failed to generate private key, error=%v", err)
|
||||
}
|
||||
|
||||
now := time.Now()
|
||||
future := now.Add(time.Hour)
|
||||
past := now.Add(-1 * time.Hour)
|
||||
|
||||
tests := []struct {
|
||||
keySet *key.PublicKeySet
|
||||
argID string
|
||||
want []key.PublicKey
|
||||
}{
|
||||
// two keys, match, non-expired set
|
||||
{
|
||||
keySet: key.NewPublicKeySet([]jose.JWK{priv2.JWK(), priv1.JWK()}, future),
|
||||
argID: priv2.ID(),
|
||||
want: []key.PublicKey{*key.NewPublicKey(priv2.JWK())},
|
||||
},
|
||||
|
||||
// two keys, no match, non-expired set
|
||||
{
|
||||
keySet: key.NewPublicKeySet([]jose.JWK{priv2.JWK(), priv1.JWK()}, future),
|
||||
argID: "XXX",
|
||||
want: []key.PublicKey{},
|
||||
},
|
||||
|
||||
// no keys, no match, non-expired set
|
||||
{
|
||||
keySet: key.NewPublicKeySet([]jose.JWK{}, future),
|
||||
argID: priv2.ID(),
|
||||
want: []key.PublicKey{},
|
||||
},
|
||||
|
||||
// two keys, match, expired set
|
||||
{
|
||||
keySet: key.NewPublicKeySet([]jose.JWK{priv2.JWK(), priv1.JWK()}, past),
|
||||
argID: priv2.ID(),
|
||||
want: []key.PublicKey{},
|
||||
},
|
||||
|
||||
// no keys, no match, expired set
|
||||
{
|
||||
keySet: key.NewPublicKeySet([]jose.JWK{}, past),
|
||||
argID: priv2.ID(),
|
||||
want: []key.PublicKey{},
|
||||
},
|
||||
}
|
||||
|
||||
for i, tt := range tests {
|
||||
var c Client
|
||||
c.keySet = *tt.keySet
|
||||
keysFunc := c.keysFuncWithID(tt.argID)
|
||||
got := keysFunc()
|
||||
if !reflect.DeepEqual(tt.want, got) {
|
||||
t.Errorf("case %d: want=%#v got=%#v", i, tt.want, got)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestClientMetadataValid(t *testing.T) {
|
||||
tests := []ClientMetadata{
|
||||
// one RedirectURL
|
||||
ClientMetadata{
|
||||
RedirectURIs: []url.URL{url.URL{Scheme: "http", Host: "example.com"}},
|
||||
},
|
||||
|
||||
// one RedirectURL w/ nonempty path
|
||||
ClientMetadata{
|
||||
RedirectURIs: []url.URL{url.URL{Scheme: "http", Host: "example.com", Path: "/foo"}},
|
||||
},
|
||||
|
||||
// two RedirectURIs
|
||||
ClientMetadata{
|
||||
RedirectURIs: []url.URL{
|
||||
url.URL{Scheme: "http", Host: "foo.example.com"},
|
||||
url.URL{Scheme: "http", Host: "bar.example.com"},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
for i, tt := range tests {
|
||||
if err := tt.Valid(); err != nil {
|
||||
t.Errorf("case %d: unexpected error: %v", i, err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestClientMetadataInvalid(t *testing.T) {
|
||||
tests := []ClientMetadata{
|
||||
// nil RedirectURls slice
|
||||
ClientMetadata{
|
||||
RedirectURIs: nil,
|
||||
},
|
||||
|
||||
// empty RedirectURIs slice
|
||||
ClientMetadata{
|
||||
RedirectURIs: []url.URL{},
|
||||
},
|
||||
|
||||
// empty url.URL
|
||||
ClientMetadata{
|
||||
RedirectURIs: []url.URL{url.URL{}},
|
||||
},
|
||||
|
||||
// empty url.URL following OK item
|
||||
ClientMetadata{
|
||||
RedirectURIs: []url.URL{url.URL{Scheme: "http", Host: "example.com"}, url.URL{}},
|
||||
},
|
||||
|
||||
// url.URL with empty Host
|
||||
ClientMetadata{
|
||||
RedirectURIs: []url.URL{url.URL{Scheme: "http", Host: ""}},
|
||||
},
|
||||
|
||||
// url.URL with empty Scheme
|
||||
ClientMetadata{
|
||||
RedirectURIs: []url.URL{url.URL{Scheme: "", Host: "example.com"}},
|
||||
},
|
||||
|
||||
// url.URL with non-HTTP(S) Scheme
|
||||
ClientMetadata{
|
||||
RedirectURIs: []url.URL{url.URL{Scheme: "tcp", Host: "127.0.0.1"}},
|
||||
},
|
||||
|
||||
// EncryptionEnc without EncryptionAlg
|
||||
ClientMetadata{
|
||||
RedirectURIs: []url.URL{url.URL{Scheme: "http", Host: "example.com"}},
|
||||
IDTokenResponseOptions: JWAOptions{
|
||||
EncryptionEnc: "A128CBC-HS256",
|
||||
},
|
||||
},
|
||||
|
||||
// List of URIs with one empty element
|
||||
ClientMetadata{
|
||||
RedirectURIs: []url.URL{url.URL{Scheme: "http", Host: "example.com"}},
|
||||
RequestURIs: []url.URL{
|
||||
url.URL{Scheme: "http", Host: "example.com"},
|
||||
url.URL{},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
for i, tt := range tests {
|
||||
if err := tt.Valid(); err == nil {
|
||||
t.Errorf("case %d: expected non-nil error", i)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestChooseAuthMethod(t *testing.T) {
|
||||
tests := []struct {
|
||||
supported []string
|
||||
chosen string
|
||||
err bool
|
||||
}{
|
||||
{
|
||||
supported: []string{},
|
||||
chosen: oauth2.AuthMethodClientSecretBasic,
|
||||
},
|
||||
{
|
||||
supported: []string{oauth2.AuthMethodClientSecretBasic},
|
||||
chosen: oauth2.AuthMethodClientSecretBasic,
|
||||
},
|
||||
{
|
||||
supported: []string{oauth2.AuthMethodClientSecretPost},
|
||||
chosen: oauth2.AuthMethodClientSecretPost,
|
||||
},
|
||||
{
|
||||
supported: []string{oauth2.AuthMethodClientSecretPost, oauth2.AuthMethodClientSecretBasic},
|
||||
chosen: oauth2.AuthMethodClientSecretPost,
|
||||
},
|
||||
{
|
||||
supported: []string{oauth2.AuthMethodClientSecretBasic, oauth2.AuthMethodClientSecretPost},
|
||||
chosen: oauth2.AuthMethodClientSecretBasic,
|
||||
},
|
||||
{
|
||||
supported: []string{oauth2.AuthMethodClientSecretJWT, oauth2.AuthMethodClientSecretPost},
|
||||
chosen: oauth2.AuthMethodClientSecretPost,
|
||||
},
|
||||
{
|
||||
supported: []string{oauth2.AuthMethodClientSecretJWT},
|
||||
chosen: "",
|
||||
err: true,
|
||||
},
|
||||
}
|
||||
|
||||
for i, tt := range tests {
|
||||
cfg := ProviderConfig{
|
||||
TokenEndpointAuthMethodsSupported: tt.supported,
|
||||
}
|
||||
got, err := chooseAuthMethod(cfg)
|
||||
if tt.err {
|
||||
if err == nil {
|
||||
t.Errorf("case %d: expected non-nil err", i)
|
||||
}
|
||||
continue
|
||||
}
|
||||
|
||||
if got != tt.chosen {
|
||||
t.Errorf("case %d: want=%q, got=%q", i, tt.chosen, got)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestClientMetadataUnmarshal(t *testing.T) {
|
||||
tests := []struct {
|
||||
data string
|
||||
want ClientMetadata
|
||||
wantErr bool
|
||||
}{
|
||||
{
|
||||
`{"redirect_uris":["https://example.com"]}`,
|
||||
ClientMetadata{
|
||||
RedirectURIs: []url.URL{
|
||||
{Scheme: "https", Host: "example.com"},
|
||||
},
|
||||
},
|
||||
false,
|
||||
},
|
||||
{
|
||||
// redirect_uris required
|
||||
`{}`,
|
||||
ClientMetadata{},
|
||||
true,
|
||||
},
|
||||
{
|
||||
// must have at least one redirect_uris
|
||||
`{"redirect_uris":[]}`,
|
||||
ClientMetadata{},
|
||||
true,
|
||||
},
|
||||
{
|
||||
`{"redirect_uris":["https://example.com"],"contacts":["Ms. Foo <foo@example.com>"]}`,
|
||||
ClientMetadata{
|
||||
RedirectURIs: []url.URL{
|
||||
{Scheme: "https", Host: "example.com"},
|
||||
},
|
||||
Contacts: []mail.Address{
|
||||
{Name: "Ms. Foo", Address: "foo@example.com"},
|
||||
},
|
||||
},
|
||||
false,
|
||||
},
|
||||
{
|
||||
// invalid URI provided for field
|
||||
`{"redirect_uris":["https://example.com"],"logo_uri":"not a valid uri"}`,
|
||||
ClientMetadata{},
|
||||
true,
|
||||
},
|
||||
{
|
||||
// logo_uri can't be a list
|
||||
`{"redirect_uris":["https://example.com"],"logo_uri":["https://example.com/logo"]}`,
|
||||
ClientMetadata{},
|
||||
true,
|
||||
},
|
||||
{
|
||||
`{
|
||||
"redirect_uris":["https://example.com"],
|
||||
"userinfo_encrypted_response_alg":"RSA1_5",
|
||||
"userinfo_encrypted_response_enc":"A128CBC-HS256",
|
||||
"contacts": [
|
||||
"jane doe <jane.doe@example.com>", "john doe <john.doe@example.com>"
|
||||
]
|
||||
}`,
|
||||
ClientMetadata{
|
||||
RedirectURIs: []url.URL{
|
||||
{Scheme: "https", Host: "example.com"},
|
||||
},
|
||||
UserInfoResponseOptions: JWAOptions{
|
||||
EncryptionAlg: "RSA1_5",
|
||||
EncryptionEnc: "A128CBC-HS256",
|
||||
},
|
||||
Contacts: []mail.Address{
|
||||
{Name: "jane doe", Address: "jane.doe@example.com"},
|
||||
{Name: "john doe", Address: "john.doe@example.com"},
|
||||
},
|
||||
},
|
||||
false,
|
||||
},
|
||||
{
|
||||
// If encrypted_response_enc is provided encrypted_response_alg must also be.
|
||||
`{
|
||||
"redirect_uris":["https://example.com"],
|
||||
"userinfo_encrypted_response_enc":"A128CBC-HS256"
|
||||
}`,
|
||||
ClientMetadata{},
|
||||
true,
|
||||
},
|
||||
}
|
||||
for i, tt := range tests {
|
||||
var got ClientMetadata
|
||||
if err := got.UnmarshalJSON([]byte(tt.data)); err != nil {
|
||||
if !tt.wantErr {
|
||||
t.Errorf("case %d: unmarshal failed: %v", i, err)
|
||||
}
|
||||
continue
|
||||
}
|
||||
if tt.wantErr {
|
||||
t.Errorf("case %d: expected unmarshal to produce error", i)
|
||||
continue
|
||||
}
|
||||
|
||||
if diff := pretty.Compare(tt.want, got); diff != "" {
|
||||
t.Errorf("case %d: results not equal: %s", i, diff)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestClientMetadataMarshal(t *testing.T) {
|
||||
|
||||
tests := []struct {
|
||||
metadata ClientMetadata
|
||||
want string
|
||||
}{
|
||||
{
|
||||
ClientMetadata{
|
||||
RedirectURIs: []url.URL{
|
||||
{Scheme: "https", Host: "example.com", Path: "/callback"},
|
||||
},
|
||||
},
|
||||
`{"redirect_uris":["https://example.com/callback"]}`,
|
||||
},
|
||||
{
|
||||
ClientMetadata{
|
||||
RedirectURIs: []url.URL{
|
||||
{Scheme: "https", Host: "example.com", Path: "/callback"},
|
||||
},
|
||||
RequestObjectOptions: JWAOptions{
|
||||
EncryptionAlg: "RSA1_5",
|
||||
EncryptionEnc: "A128CBC-HS256",
|
||||
},
|
||||
},
|
||||
`{"redirect_uris":["https://example.com/callback"],"request_object_encryption_alg":"RSA1_5","request_object_encryption_enc":"A128CBC-HS256"}`,
|
||||
},
|
||||
}
|
||||
|
||||
for i, tt := range tests {
|
||||
got, err := json.Marshal(&tt.metadata)
|
||||
if err != nil {
|
||||
t.Errorf("case %d: failed to marshal metadata: %v", i, err)
|
||||
continue
|
||||
}
|
||||
if string(got) != tt.want {
|
||||
t.Errorf("case %d: marshaled string did not match expected string", i)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestClientMetadataMarshalRoundTrip(t *testing.T) {
|
||||
tests := []ClientMetadata{
|
||||
{
|
||||
RedirectURIs: []url.URL{
|
||||
{Scheme: "https", Host: "example.com", Path: "/callback"},
|
||||
},
|
||||
LogoURI: &url.URL{Scheme: "https", Host: "example.com", Path: "/logo"},
|
||||
RequestObjectOptions: JWAOptions{
|
||||
EncryptionAlg: "RSA1_5",
|
||||
EncryptionEnc: "A128CBC-HS256",
|
||||
},
|
||||
ApplicationType: "native",
|
||||
TokenEndpointAuthMethod: "client_secret_basic",
|
||||
},
|
||||
}
|
||||
|
||||
for i, want := range tests {
|
||||
data, err := json.Marshal(&want)
|
||||
if err != nil {
|
||||
t.Errorf("case %d: failed to marshal metadata: %v", i, err)
|
||||
continue
|
||||
}
|
||||
var got ClientMetadata
|
||||
if err := json.Unmarshal(data, &got); err != nil {
|
||||
t.Errorf("case %d: failed to unmarshal metadata: %v", i, err)
|
||||
continue
|
||||
}
|
||||
if diff := pretty.Compare(want, got); diff != "" {
|
||||
t.Errorf("case %d: struct did not survive a marshaling round trip: %s", i, diff)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestClientRegistrationResponseUnmarshal(t *testing.T) {
|
||||
tests := []struct {
|
||||
data string
|
||||
want ClientRegistrationResponse
|
||||
wantErr bool
|
||||
secretExpires bool
|
||||
}{
|
||||
{
|
||||
`{
|
||||
"client_id":"foo",
|
||||
"client_secret":"bar",
|
||||
"client_secret_expires_at": 1577858400,
|
||||
"redirect_uris":[
|
||||
"https://client.example.org/callback",
|
||||
"https://client.example.org/callback2"
|
||||
],
|
||||
"client_name":"my_example"
|
||||
}`,
|
||||
ClientRegistrationResponse{
|
||||
ClientID: "foo",
|
||||
ClientSecret: "bar",
|
||||
ClientSecretExpiresAt: time.Unix(1577858400, 0),
|
||||
ClientMetadata: ClientMetadata{
|
||||
RedirectURIs: []url.URL{
|
||||
{Scheme: "https", Host: "client.example.org", Path: "/callback"},
|
||||
{Scheme: "https", Host: "client.example.org", Path: "/callback2"},
|
||||
},
|
||||
ClientName: "my_example",
|
||||
},
|
||||
},
|
||||
false,
|
||||
true,
|
||||
},
|
||||
{
|
||||
`{
|
||||
"client_id":"foo",
|
||||
"client_secret_expires_at": 0,
|
||||
"redirect_uris":[
|
||||
"https://client.example.org/callback",
|
||||
"https://client.example.org/callback2"
|
||||
],
|
||||
"client_name":"my_example"
|
||||
}`,
|
||||
ClientRegistrationResponse{
|
||||
ClientID: "foo",
|
||||
ClientMetadata: ClientMetadata{
|
||||
RedirectURIs: []url.URL{
|
||||
{Scheme: "https", Host: "client.example.org", Path: "/callback"},
|
||||
{Scheme: "https", Host: "client.example.org", Path: "/callback2"},
|
||||
},
|
||||
ClientName: "my_example",
|
||||
},
|
||||
},
|
||||
false,
|
||||
false,
|
||||
},
|
||||
{
|
||||
// no client id
|
||||
`{
|
||||
"client_secret_expires_at": 0,
|
||||
"redirect_uris":[
|
||||
"https://client.example.org/callback",
|
||||
"https://client.example.org/callback2"
|
||||
],
|
||||
"client_name":"my_example"
|
||||
}`,
|
||||
ClientRegistrationResponse{},
|
||||
true,
|
||||
false,
|
||||
},
|
||||
}
|
||||
|
||||
for i, tt := range tests {
|
||||
var got ClientRegistrationResponse
|
||||
if err := json.Unmarshal([]byte(tt.data), &got); err != nil {
|
||||
if !tt.wantErr {
|
||||
t.Errorf("case %d: unmarshal failed: %v", i, err)
|
||||
}
|
||||
continue
|
||||
}
|
||||
|
||||
if tt.wantErr {
|
||||
t.Errorf("case %d: expected unmarshal to produce error", i)
|
||||
continue
|
||||
}
|
||||
|
||||
if diff := pretty.Compare(tt.want, got); diff != "" {
|
||||
t.Errorf("case %d: results not equal: %s", i, diff)
|
||||
}
|
||||
if tt.secretExpires && got.ClientSecretExpiresAt.IsZero() {
|
||||
t.Errorf("case %d: expected client_secret to expire, but it doesn't", i)
|
||||
} else if !tt.secretExpires && !got.ClientSecretExpiresAt.IsZero() {
|
||||
t.Errorf("case %d: expected client_secret to not expire, but it does", i)
|
||||
}
|
||||
}
|
||||
}
|
2
vendor/github.com/coreos/go-oidc/oidc/doc.go
generated
vendored
Normal file
2
vendor/github.com/coreos/go-oidc/oidc/doc.go
generated
vendored
Normal file
@@ -0,0 +1,2 @@
|
||||
// Package oidc is DEPRECATED. Use github.com/coreos/go-oidc instead.
|
||||
package oidc
|
44
vendor/github.com/coreos/go-oidc/oidc/identity.go
generated
vendored
Normal file
44
vendor/github.com/coreos/go-oidc/oidc/identity.go
generated
vendored
Normal file
@@ -0,0 +1,44 @@
|
||||
package oidc
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"time"
|
||||
|
||||
"github.com/coreos/go-oidc/jose"
|
||||
)
|
||||
|
||||
type Identity struct {
|
||||
ID string
|
||||
Name string
|
||||
Email string
|
||||
ExpiresAt time.Time
|
||||
}
|
||||
|
||||
func IdentityFromClaims(claims jose.Claims) (*Identity, error) {
|
||||
if claims == nil {
|
||||
return nil, errors.New("nil claim set")
|
||||
}
|
||||
|
||||
var ident Identity
|
||||
var err error
|
||||
var ok bool
|
||||
|
||||
if ident.ID, ok, err = claims.StringClaim("sub"); err != nil {
|
||||
return nil, err
|
||||
} else if !ok {
|
||||
return nil, errors.New("missing required claim: sub")
|
||||
}
|
||||
|
||||
if ident.Email, _, err = claims.StringClaim("email"); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
exp, ok, err := claims.TimeClaim("exp")
|
||||
if err != nil {
|
||||
return nil, err
|
||||
} else if ok {
|
||||
ident.ExpiresAt = exp
|
||||
}
|
||||
|
||||
return &ident, nil
|
||||
}
|
113
vendor/github.com/coreos/go-oidc/oidc/identity_test.go
generated
vendored
Normal file
113
vendor/github.com/coreos/go-oidc/oidc/identity_test.go
generated
vendored
Normal file
@@ -0,0 +1,113 @@
|
||||
package oidc
|
||||
|
||||
import (
|
||||
"reflect"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/coreos/go-oidc/jose"
|
||||
)
|
||||
|
||||
func TestIdentityFromClaims(t *testing.T) {
|
||||
tests := []struct {
|
||||
claims jose.Claims
|
||||
want Identity
|
||||
}{
|
||||
{
|
||||
claims: jose.Claims{
|
||||
"sub": "123850281",
|
||||
"name": "Elroy",
|
||||
"email": "elroy@example.com",
|
||||
"exp": float64(1.416935146e+09),
|
||||
},
|
||||
want: Identity{
|
||||
ID: "123850281",
|
||||
Name: "",
|
||||
Email: "elroy@example.com",
|
||||
ExpiresAt: time.Date(2014, time.November, 25, 17, 05, 46, 0, time.UTC),
|
||||
},
|
||||
},
|
||||
{
|
||||
claims: jose.Claims{
|
||||
"sub": "123850281",
|
||||
"name": "Elroy",
|
||||
"exp": float64(1.416935146e+09),
|
||||
},
|
||||
want: Identity{
|
||||
ID: "123850281",
|
||||
Name: "",
|
||||
Email: "",
|
||||
ExpiresAt: time.Date(2014, time.November, 25, 17, 05, 46, 0, time.UTC),
|
||||
},
|
||||
},
|
||||
{
|
||||
claims: jose.Claims{
|
||||
"sub": "123850281",
|
||||
"name": "Elroy",
|
||||
"email": "elroy@example.com",
|
||||
"exp": int64(1416935146),
|
||||
},
|
||||
want: Identity{
|
||||
ID: "123850281",
|
||||
Name: "",
|
||||
Email: "elroy@example.com",
|
||||
ExpiresAt: time.Date(2014, time.November, 25, 17, 05, 46, 0, time.UTC),
|
||||
},
|
||||
},
|
||||
{
|
||||
claims: jose.Claims{
|
||||
"sub": "123850281",
|
||||
"name": "Elroy",
|
||||
"email": "elroy@example.com",
|
||||
},
|
||||
want: Identity{
|
||||
ID: "123850281",
|
||||
Name: "",
|
||||
Email: "elroy@example.com",
|
||||
ExpiresAt: time.Time{},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
for i, tt := range tests {
|
||||
got, err := IdentityFromClaims(tt.claims)
|
||||
if err != nil {
|
||||
t.Errorf("case %d: unexpected error: %v", i, err)
|
||||
continue
|
||||
}
|
||||
if !reflect.DeepEqual(tt.want, *got) {
|
||||
t.Errorf("case %d: want=%#v got=%#v", i, tt.want, *got)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestIdentityFromClaimsFail(t *testing.T) {
|
||||
tests := []jose.Claims{
|
||||
// sub incorrect type
|
||||
jose.Claims{
|
||||
"sub": 123,
|
||||
"name": "foo",
|
||||
"email": "elroy@example.com",
|
||||
},
|
||||
// email incorrect type
|
||||
jose.Claims{
|
||||
"sub": "123850281",
|
||||
"name": "Elroy",
|
||||
"email": false,
|
||||
},
|
||||
// exp incorrect type
|
||||
jose.Claims{
|
||||
"sub": "123850281",
|
||||
"name": "Elroy",
|
||||
"email": "elroy@example.com",
|
||||
"exp": "2014-11-25 18:05:46 +0000 UTC",
|
||||
},
|
||||
}
|
||||
|
||||
for i, tt := range tests {
|
||||
_, err := IdentityFromClaims(tt)
|
||||
if err == nil {
|
||||
t.Errorf("case %d: expected non-nil error", i)
|
||||
}
|
||||
}
|
||||
}
|
3
vendor/github.com/coreos/go-oidc/oidc/interface.go
generated
vendored
Normal file
3
vendor/github.com/coreos/go-oidc/oidc/interface.go
generated
vendored
Normal file
@@ -0,0 +1,3 @@
|
||||
package oidc
|
||||
|
||||
type LoginFunc func(ident Identity, sessionKey string) (redirectURL string, err error)
|
67
vendor/github.com/coreos/go-oidc/oidc/key.go
generated
vendored
Executable file
67
vendor/github.com/coreos/go-oidc/oidc/key.go
generated
vendored
Executable file
@@ -0,0 +1,67 @@
|
||||
package oidc
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"net/http"
|
||||
"time"
|
||||
|
||||
phttp "github.com/coreos/go-oidc/http"
|
||||
"github.com/coreos/go-oidc/jose"
|
||||
"github.com/coreos/go-oidc/key"
|
||||
)
|
||||
|
||||
// DefaultPublicKeySetTTL is the default TTL set on the PublicKeySet if no
|
||||
// Cache-Control header is provided by the JWK Set document endpoint.
|
||||
const DefaultPublicKeySetTTL = 24 * time.Hour
|
||||
|
||||
// NewRemotePublicKeyRepo is responsible for fetching the JWK Set document.
|
||||
func NewRemotePublicKeyRepo(hc phttp.Client, ep string) *remotePublicKeyRepo {
|
||||
return &remotePublicKeyRepo{hc: hc, ep: ep}
|
||||
}
|
||||
|
||||
type remotePublicKeyRepo struct {
|
||||
hc phttp.Client
|
||||
ep string
|
||||
}
|
||||
|
||||
// Get returns a PublicKeySet fetched from the JWK Set document endpoint. A TTL
|
||||
// is set on the Key Set to avoid it having to be re-retrieved for every
|
||||
// encryption event. This TTL is typically controlled by the endpoint returning
|
||||
// a Cache-Control header, but defaults to 24 hours if no Cache-Control header
|
||||
// is found.
|
||||
func (r *remotePublicKeyRepo) Get() (key.KeySet, error) {
|
||||
req, err := http.NewRequest("GET", r.ep, nil)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
resp, err := r.hc.Do(req)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
var d struct {
|
||||
Keys []jose.JWK `json:"keys"`
|
||||
}
|
||||
if err := json.NewDecoder(resp.Body).Decode(&d); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if len(d.Keys) == 0 {
|
||||
return nil, errors.New("zero keys in response")
|
||||
}
|
||||
|
||||
ttl, ok, err := phttp.Cacheable(resp.Header)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if !ok {
|
||||
ttl = DefaultPublicKeySetTTL
|
||||
}
|
||||
|
||||
exp := time.Now().UTC().Add(ttl)
|
||||
ks := key.NewPublicKeySet(d.Keys, exp)
|
||||
return ks, nil
|
||||
}
|
690
vendor/github.com/coreos/go-oidc/oidc/provider.go
generated
vendored
Normal file
690
vendor/github.com/coreos/go-oidc/oidc/provider.go
generated
vendored
Normal file
@@ -0,0 +1,690 @@
|
||||
package oidc
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"log"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/coreos/pkg/timeutil"
|
||||
"github.com/jonboulle/clockwork"
|
||||
|
||||
phttp "github.com/coreos/go-oidc/http"
|
||||
"github.com/coreos/go-oidc/oauth2"
|
||||
)
|
||||
|
||||
const (
|
||||
// Subject Identifier types defined by the OIDC spec. Specifies if the provider
|
||||
// should provide the same sub claim value to all clients (public) or a unique
|
||||
// value for each client (pairwise).
|
||||
//
|
||||
// See: http://openid.net/specs/openid-connect-core-1_0.html#SubjectIDTypes
|
||||
SubjectTypePublic = "public"
|
||||
SubjectTypePairwise = "pairwise"
|
||||
)
|
||||
|
||||
var (
|
||||
// Default values for omitted provider config fields.
|
||||
//
|
||||
// Use ProviderConfig's Defaults method to fill a provider config with these values.
|
||||
DefaultGrantTypesSupported = []string{oauth2.GrantTypeAuthCode, oauth2.GrantTypeImplicit}
|
||||
DefaultResponseModesSupported = []string{"query", "fragment"}
|
||||
DefaultTokenEndpointAuthMethodsSupported = []string{oauth2.AuthMethodClientSecretBasic}
|
||||
DefaultClaimTypesSupported = []string{"normal"}
|
||||
)
|
||||
|
||||
const (
|
||||
MaximumProviderConfigSyncInterval = 24 * time.Hour
|
||||
MinimumProviderConfigSyncInterval = time.Minute
|
||||
|
||||
discoveryConfigPath = "/.well-known/openid-configuration"
|
||||
)
|
||||
|
||||
// internally configurable for tests
|
||||
var minimumProviderConfigSyncInterval = MinimumProviderConfigSyncInterval
|
||||
|
||||
var (
|
||||
// Ensure ProviderConfig satisfies these interfaces.
|
||||
_ json.Marshaler = &ProviderConfig{}
|
||||
_ json.Unmarshaler = &ProviderConfig{}
|
||||
)
|
||||
|
||||
// ProviderConfig represents the OpenID Provider Metadata specifying what
|
||||
// configurations a provider supports.
|
||||
//
|
||||
// See: http://openid.net/specs/openid-connect-discovery-1_0.html#ProviderMetadata
|
||||
type ProviderConfig struct {
|
||||
Issuer *url.URL // Required
|
||||
AuthEndpoint *url.URL // Required
|
||||
TokenEndpoint *url.URL // Required if grant types other than "implicit" are supported
|
||||
UserInfoEndpoint *url.URL
|
||||
KeysEndpoint *url.URL // Required
|
||||
RegistrationEndpoint *url.URL
|
||||
EndSessionEndpoint *url.URL
|
||||
CheckSessionIFrame *url.URL
|
||||
|
||||
// Servers MAY choose not to advertise some supported scope values even when this
|
||||
// parameter is used, although those defined in OpenID Core SHOULD be listed, if supported.
|
||||
ScopesSupported []string
|
||||
// OAuth2.0 response types supported.
|
||||
ResponseTypesSupported []string // Required
|
||||
// OAuth2.0 response modes supported.
|
||||
//
|
||||
// If omitted, defaults to DefaultResponseModesSupported.
|
||||
ResponseModesSupported []string
|
||||
// OAuth2.0 grant types supported.
|
||||
//
|
||||
// If omitted, defaults to DefaultGrantTypesSupported.
|
||||
GrantTypesSupported []string
|
||||
ACRValuesSupported []string
|
||||
// SubjectTypesSupported specifies strategies for providing values for the sub claim.
|
||||
SubjectTypesSupported []string // Required
|
||||
|
||||
// JWA signing and encryption algorith values supported for ID tokens.
|
||||
IDTokenSigningAlgValues []string // Required
|
||||
IDTokenEncryptionAlgValues []string
|
||||
IDTokenEncryptionEncValues []string
|
||||
|
||||
// JWA signing and encryption algorith values supported for user info responses.
|
||||
UserInfoSigningAlgValues []string
|
||||
UserInfoEncryptionAlgValues []string
|
||||
UserInfoEncryptionEncValues []string
|
||||
|
||||
// JWA signing and encryption algorith values supported for request objects.
|
||||
ReqObjSigningAlgValues []string
|
||||
ReqObjEncryptionAlgValues []string
|
||||
ReqObjEncryptionEncValues []string
|
||||
|
||||
TokenEndpointAuthMethodsSupported []string
|
||||
TokenEndpointAuthSigningAlgValuesSupported []string
|
||||
DisplayValuesSupported []string
|
||||
ClaimTypesSupported []string
|
||||
ClaimsSupported []string
|
||||
ServiceDocs *url.URL
|
||||
ClaimsLocalsSupported []string
|
||||
UILocalsSupported []string
|
||||
ClaimsParameterSupported bool
|
||||
RequestParameterSupported bool
|
||||
RequestURIParamaterSupported bool
|
||||
RequireRequestURIRegistration bool
|
||||
|
||||
Policy *url.URL
|
||||
TermsOfService *url.URL
|
||||
|
||||
// Not part of the OpenID Provider Metadata
|
||||
ExpiresAt time.Time
|
||||
}
|
||||
|
||||
// Defaults returns a shallow copy of ProviderConfig with default
|
||||
// values replacing omitted fields.
|
||||
//
|
||||
// var cfg oidc.ProviderConfig
|
||||
// // Fill provider config with default values for omitted fields.
|
||||
// cfg = cfg.Defaults()
|
||||
//
|
||||
func (p ProviderConfig) Defaults() ProviderConfig {
|
||||
setDefault := func(val *[]string, defaultVal []string) {
|
||||
if len(*val) == 0 {
|
||||
*val = defaultVal
|
||||
}
|
||||
}
|
||||
setDefault(&p.GrantTypesSupported, DefaultGrantTypesSupported)
|
||||
setDefault(&p.ResponseModesSupported, DefaultResponseModesSupported)
|
||||
setDefault(&p.TokenEndpointAuthMethodsSupported, DefaultTokenEndpointAuthMethodsSupported)
|
||||
setDefault(&p.ClaimTypesSupported, DefaultClaimTypesSupported)
|
||||
return p
|
||||
}
|
||||
|
||||
func (p *ProviderConfig) MarshalJSON() ([]byte, error) {
|
||||
e := p.toEncodableStruct()
|
||||
return json.Marshal(&e)
|
||||
}
|
||||
|
||||
func (p *ProviderConfig) UnmarshalJSON(data []byte) error {
|
||||
var e encodableProviderConfig
|
||||
if err := json.Unmarshal(data, &e); err != nil {
|
||||
return err
|
||||
}
|
||||
conf, err := e.toStruct()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if err := conf.Valid(); err != nil {
|
||||
return err
|
||||
}
|
||||
*p = conf
|
||||
return nil
|
||||
}
|
||||
|
||||
type encodableProviderConfig struct {
|
||||
Issuer string `json:"issuer"`
|
||||
AuthEndpoint string `json:"authorization_endpoint"`
|
||||
TokenEndpoint string `json:"token_endpoint"`
|
||||
UserInfoEndpoint string `json:"userinfo_endpoint,omitempty"`
|
||||
KeysEndpoint string `json:"jwks_uri"`
|
||||
RegistrationEndpoint string `json:"registration_endpoint,omitempty"`
|
||||
EndSessionEndpoint string `json:"end_session_endpoint,omitempty"`
|
||||
CheckSessionIFrame string `json:"check_session_iframe,omitempty"`
|
||||
|
||||
// Use 'omitempty' for all slices as per OIDC spec:
|
||||
// "Claims that return multiple values are represented as JSON arrays.
|
||||
// Claims with zero elements MUST be omitted from the response."
|
||||
// http://openid.net/specs/openid-connect-discovery-1_0.html#ProviderConfigurationResponse
|
||||
|
||||
ScopesSupported []string `json:"scopes_supported,omitempty"`
|
||||
ResponseTypesSupported []string `json:"response_types_supported,omitempty"`
|
||||
ResponseModesSupported []string `json:"response_modes_supported,omitempty"`
|
||||
GrantTypesSupported []string `json:"grant_types_supported,omitempty"`
|
||||
ACRValuesSupported []string `json:"acr_values_supported,omitempty"`
|
||||
SubjectTypesSupported []string `json:"subject_types_supported,omitempty"`
|
||||
|
||||
IDTokenSigningAlgValues []string `json:"id_token_signing_alg_values_supported,omitempty"`
|
||||
IDTokenEncryptionAlgValues []string `json:"id_token_encryption_alg_values_supported,omitempty"`
|
||||
IDTokenEncryptionEncValues []string `json:"id_token_encryption_enc_values_supported,omitempty"`
|
||||
UserInfoSigningAlgValues []string `json:"userinfo_signing_alg_values_supported,omitempty"`
|
||||
UserInfoEncryptionAlgValues []string `json:"userinfo_encryption_alg_values_supported,omitempty"`
|
||||
UserInfoEncryptionEncValues []string `json:"userinfo_encryption_enc_values_supported,omitempty"`
|
||||
ReqObjSigningAlgValues []string `json:"request_object_signing_alg_values_supported,omitempty"`
|
||||
ReqObjEncryptionAlgValues []string `json:"request_object_encryption_alg_values_supported,omitempty"`
|
||||
ReqObjEncryptionEncValues []string `json:"request_object_encryption_enc_values_supported,omitempty"`
|
||||
|
||||
TokenEndpointAuthMethodsSupported []string `json:"token_endpoint_auth_methods_supported,omitempty"`
|
||||
TokenEndpointAuthSigningAlgValuesSupported []string `json:"token_endpoint_auth_signing_alg_values_supported,omitempty"`
|
||||
|
||||
DisplayValuesSupported []string `json:"display_values_supported,omitempty"`
|
||||
ClaimTypesSupported []string `json:"claim_types_supported,omitempty"`
|
||||
ClaimsSupported []string `json:"claims_supported,omitempty"`
|
||||
ServiceDocs string `json:"service_documentation,omitempty"`
|
||||
ClaimsLocalsSupported []string `json:"claims_locales_supported,omitempty"`
|
||||
UILocalsSupported []string `json:"ui_locales_supported,omitempty"`
|
||||
ClaimsParameterSupported bool `json:"claims_parameter_supported,omitempty"`
|
||||
RequestParameterSupported bool `json:"request_parameter_supported,omitempty"`
|
||||
RequestURIParamaterSupported bool `json:"request_uri_parameter_supported,omitempty"`
|
||||
RequireRequestURIRegistration bool `json:"require_request_uri_registration,omitempty"`
|
||||
|
||||
Policy string `json:"op_policy_uri,omitempty"`
|
||||
TermsOfService string `json:"op_tos_uri,omitempty"`
|
||||
}
|
||||
|
||||
func (cfg ProviderConfig) toEncodableStruct() encodableProviderConfig {
|
||||
return encodableProviderConfig{
|
||||
Issuer: uriToString(cfg.Issuer),
|
||||
AuthEndpoint: uriToString(cfg.AuthEndpoint),
|
||||
TokenEndpoint: uriToString(cfg.TokenEndpoint),
|
||||
UserInfoEndpoint: uriToString(cfg.UserInfoEndpoint),
|
||||
KeysEndpoint: uriToString(cfg.KeysEndpoint),
|
||||
RegistrationEndpoint: uriToString(cfg.RegistrationEndpoint),
|
||||
EndSessionEndpoint: uriToString(cfg.EndSessionEndpoint),
|
||||
CheckSessionIFrame: uriToString(cfg.CheckSessionIFrame),
|
||||
ScopesSupported: cfg.ScopesSupported,
|
||||
ResponseTypesSupported: cfg.ResponseTypesSupported,
|
||||
ResponseModesSupported: cfg.ResponseModesSupported,
|
||||
GrantTypesSupported: cfg.GrantTypesSupported,
|
||||
ACRValuesSupported: cfg.ACRValuesSupported,
|
||||
SubjectTypesSupported: cfg.SubjectTypesSupported,
|
||||
IDTokenSigningAlgValues: cfg.IDTokenSigningAlgValues,
|
||||
IDTokenEncryptionAlgValues: cfg.IDTokenEncryptionAlgValues,
|
||||
IDTokenEncryptionEncValues: cfg.IDTokenEncryptionEncValues,
|
||||
UserInfoSigningAlgValues: cfg.UserInfoSigningAlgValues,
|
||||
UserInfoEncryptionAlgValues: cfg.UserInfoEncryptionAlgValues,
|
||||
UserInfoEncryptionEncValues: cfg.UserInfoEncryptionEncValues,
|
||||
ReqObjSigningAlgValues: cfg.ReqObjSigningAlgValues,
|
||||
ReqObjEncryptionAlgValues: cfg.ReqObjEncryptionAlgValues,
|
||||
ReqObjEncryptionEncValues: cfg.ReqObjEncryptionEncValues,
|
||||
TokenEndpointAuthMethodsSupported: cfg.TokenEndpointAuthMethodsSupported,
|
||||
TokenEndpointAuthSigningAlgValuesSupported: cfg.TokenEndpointAuthSigningAlgValuesSupported,
|
||||
DisplayValuesSupported: cfg.DisplayValuesSupported,
|
||||
ClaimTypesSupported: cfg.ClaimTypesSupported,
|
||||
ClaimsSupported: cfg.ClaimsSupported,
|
||||
ServiceDocs: uriToString(cfg.ServiceDocs),
|
||||
ClaimsLocalsSupported: cfg.ClaimsLocalsSupported,
|
||||
UILocalsSupported: cfg.UILocalsSupported,
|
||||
ClaimsParameterSupported: cfg.ClaimsParameterSupported,
|
||||
RequestParameterSupported: cfg.RequestParameterSupported,
|
||||
RequestURIParamaterSupported: cfg.RequestURIParamaterSupported,
|
||||
RequireRequestURIRegistration: cfg.RequireRequestURIRegistration,
|
||||
Policy: uriToString(cfg.Policy),
|
||||
TermsOfService: uriToString(cfg.TermsOfService),
|
||||
}
|
||||
}
|
||||
|
||||
func (e encodableProviderConfig) toStruct() (ProviderConfig, error) {
|
||||
p := stickyErrParser{}
|
||||
conf := ProviderConfig{
|
||||
Issuer: p.parseURI(e.Issuer, "issuer"),
|
||||
AuthEndpoint: p.parseURI(e.AuthEndpoint, "authorization_endpoint"),
|
||||
TokenEndpoint: p.parseURI(e.TokenEndpoint, "token_endpoint"),
|
||||
UserInfoEndpoint: p.parseURI(e.UserInfoEndpoint, "userinfo_endpoint"),
|
||||
KeysEndpoint: p.parseURI(e.KeysEndpoint, "jwks_uri"),
|
||||
RegistrationEndpoint: p.parseURI(e.RegistrationEndpoint, "registration_endpoint"),
|
||||
EndSessionEndpoint: p.parseURI(e.EndSessionEndpoint, "end_session_endpoint"),
|
||||
CheckSessionIFrame: p.parseURI(e.CheckSessionIFrame, "check_session_iframe"),
|
||||
ScopesSupported: e.ScopesSupported,
|
||||
ResponseTypesSupported: e.ResponseTypesSupported,
|
||||
ResponseModesSupported: e.ResponseModesSupported,
|
||||
GrantTypesSupported: e.GrantTypesSupported,
|
||||
ACRValuesSupported: e.ACRValuesSupported,
|
||||
SubjectTypesSupported: e.SubjectTypesSupported,
|
||||
IDTokenSigningAlgValues: e.IDTokenSigningAlgValues,
|
||||
IDTokenEncryptionAlgValues: e.IDTokenEncryptionAlgValues,
|
||||
IDTokenEncryptionEncValues: e.IDTokenEncryptionEncValues,
|
||||
UserInfoSigningAlgValues: e.UserInfoSigningAlgValues,
|
||||
UserInfoEncryptionAlgValues: e.UserInfoEncryptionAlgValues,
|
||||
UserInfoEncryptionEncValues: e.UserInfoEncryptionEncValues,
|
||||
ReqObjSigningAlgValues: e.ReqObjSigningAlgValues,
|
||||
ReqObjEncryptionAlgValues: e.ReqObjEncryptionAlgValues,
|
||||
ReqObjEncryptionEncValues: e.ReqObjEncryptionEncValues,
|
||||
TokenEndpointAuthMethodsSupported: e.TokenEndpointAuthMethodsSupported,
|
||||
TokenEndpointAuthSigningAlgValuesSupported: e.TokenEndpointAuthSigningAlgValuesSupported,
|
||||
DisplayValuesSupported: e.DisplayValuesSupported,
|
||||
ClaimTypesSupported: e.ClaimTypesSupported,
|
||||
ClaimsSupported: e.ClaimsSupported,
|
||||
ServiceDocs: p.parseURI(e.ServiceDocs, "service_documentation"),
|
||||
ClaimsLocalsSupported: e.ClaimsLocalsSupported,
|
||||
UILocalsSupported: e.UILocalsSupported,
|
||||
ClaimsParameterSupported: e.ClaimsParameterSupported,
|
||||
RequestParameterSupported: e.RequestParameterSupported,
|
||||
RequestURIParamaterSupported: e.RequestURIParamaterSupported,
|
||||
RequireRequestURIRegistration: e.RequireRequestURIRegistration,
|
||||
Policy: p.parseURI(e.Policy, "op_policy-uri"),
|
||||
TermsOfService: p.parseURI(e.TermsOfService, "op_tos_uri"),
|
||||
}
|
||||
if p.firstErr != nil {
|
||||
return ProviderConfig{}, p.firstErr
|
||||
}
|
||||
return conf, nil
|
||||
}
|
||||
|
||||
// Empty returns if a ProviderConfig holds no information.
|
||||
//
|
||||
// This case generally indicates a ProviderConfigGetter has experienced an error
|
||||
// and has nothing to report.
|
||||
func (p ProviderConfig) Empty() bool {
|
||||
return p.Issuer == nil
|
||||
}
|
||||
|
||||
func contains(sli []string, ele string) bool {
|
||||
for _, s := range sli {
|
||||
if s == ele {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// Valid determines if a ProviderConfig conforms with the OIDC specification.
|
||||
// If Valid returns successfully it guarantees required field are non-nil and
|
||||
// URLs are well formed.
|
||||
//
|
||||
// Valid is called by UnmarshalJSON.
|
||||
//
|
||||
// NOTE(ericchiang): For development purposes Valid does not mandate 'https' for
|
||||
// URLs fields where the OIDC spec requires it. This may change in future releases
|
||||
// of this package. See: https://github.com/coreos/go-oidc/issues/34
|
||||
func (p ProviderConfig) Valid() error {
|
||||
grantTypes := p.GrantTypesSupported
|
||||
if len(grantTypes) == 0 {
|
||||
grantTypes = DefaultGrantTypesSupported
|
||||
}
|
||||
implicitOnly := true
|
||||
for _, grantType := range grantTypes {
|
||||
if grantType != oauth2.GrantTypeImplicit {
|
||||
implicitOnly = false
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
if len(p.SubjectTypesSupported) == 0 {
|
||||
return errors.New("missing required field subject_types_supported")
|
||||
}
|
||||
if len(p.IDTokenSigningAlgValues) == 0 {
|
||||
return errors.New("missing required field id_token_signing_alg_values_supported")
|
||||
}
|
||||
|
||||
if len(p.ScopesSupported) != 0 && !contains(p.ScopesSupported, "openid") {
|
||||
return errors.New("scoped_supported must be unspecified or include 'openid'")
|
||||
}
|
||||
|
||||
if !contains(p.IDTokenSigningAlgValues, "RS256") {
|
||||
return errors.New("id_token_signing_alg_values_supported must include 'RS256'")
|
||||
}
|
||||
if contains(p.TokenEndpointAuthMethodsSupported, "none") {
|
||||
return errors.New("token_endpoint_auth_signing_alg_values_supported cannot include 'none'")
|
||||
}
|
||||
|
||||
uris := []struct {
|
||||
val *url.URL
|
||||
name string
|
||||
required bool
|
||||
}{
|
||||
{p.Issuer, "issuer", true},
|
||||
{p.AuthEndpoint, "authorization_endpoint", true},
|
||||
{p.TokenEndpoint, "token_endpoint", !implicitOnly},
|
||||
{p.UserInfoEndpoint, "userinfo_endpoint", false},
|
||||
{p.KeysEndpoint, "jwks_uri", true},
|
||||
{p.RegistrationEndpoint, "registration_endpoint", false},
|
||||
{p.EndSessionEndpoint, "end_session_endpoint", false},
|
||||
{p.CheckSessionIFrame, "check_session_iframe", false},
|
||||
{p.ServiceDocs, "service_documentation", false},
|
||||
{p.Policy, "op_policy_uri", false},
|
||||
{p.TermsOfService, "op_tos_uri", false},
|
||||
}
|
||||
|
||||
for _, uri := range uris {
|
||||
if uri.val == nil {
|
||||
if !uri.required {
|
||||
continue
|
||||
}
|
||||
return fmt.Errorf("empty value for required uri field %s", uri.name)
|
||||
}
|
||||
if uri.val.Host == "" {
|
||||
return fmt.Errorf("no host for uri field %s", uri.name)
|
||||
}
|
||||
if uri.val.Scheme != "http" && uri.val.Scheme != "https" {
|
||||
return fmt.Errorf("uri field %s schemeis not http or https", uri.name)
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// Supports determines if provider supports a client given their respective metadata.
|
||||
func (p ProviderConfig) Supports(c ClientMetadata) error {
|
||||
if err := p.Valid(); err != nil {
|
||||
return fmt.Errorf("invalid provider config: %v", err)
|
||||
}
|
||||
if err := c.Valid(); err != nil {
|
||||
return fmt.Errorf("invalid client config: %v", err)
|
||||
}
|
||||
|
||||
// Fill default values for omitted fields
|
||||
c = c.Defaults()
|
||||
p = p.Defaults()
|
||||
|
||||
// Do the supported values list the requested one?
|
||||
supports := []struct {
|
||||
supported []string
|
||||
requested string
|
||||
name string
|
||||
}{
|
||||
{p.IDTokenSigningAlgValues, c.IDTokenResponseOptions.SigningAlg, "id_token_signed_response_alg"},
|
||||
{p.IDTokenEncryptionAlgValues, c.IDTokenResponseOptions.EncryptionAlg, "id_token_encryption_response_alg"},
|
||||
{p.IDTokenEncryptionEncValues, c.IDTokenResponseOptions.EncryptionEnc, "id_token_encryption_response_enc"},
|
||||
{p.UserInfoSigningAlgValues, c.UserInfoResponseOptions.SigningAlg, "userinfo_signed_response_alg"},
|
||||
{p.UserInfoEncryptionAlgValues, c.UserInfoResponseOptions.EncryptionAlg, "userinfo_encryption_response_alg"},
|
||||
{p.UserInfoEncryptionEncValues, c.UserInfoResponseOptions.EncryptionEnc, "userinfo_encryption_response_enc"},
|
||||
{p.ReqObjSigningAlgValues, c.RequestObjectOptions.SigningAlg, "request_object_signing_alg"},
|
||||
{p.ReqObjEncryptionAlgValues, c.RequestObjectOptions.EncryptionAlg, "request_object_encryption_alg"},
|
||||
{p.ReqObjEncryptionEncValues, c.RequestObjectOptions.EncryptionEnc, "request_object_encryption_enc"},
|
||||
}
|
||||
for _, field := range supports {
|
||||
if field.requested == "" {
|
||||
continue
|
||||
}
|
||||
if !contains(field.supported, field.requested) {
|
||||
return fmt.Errorf("provider does not support requested value for field %s", field.name)
|
||||
}
|
||||
}
|
||||
|
||||
stringsEqual := func(s1, s2 string) bool { return s1 == s2 }
|
||||
|
||||
// For lists, are the list of requested values a subset of the supported ones?
|
||||
supportsAll := []struct {
|
||||
supported []string
|
||||
requested []string
|
||||
name string
|
||||
// OAuth2.0 response_type can be space separated lists where order doesn't matter.
|
||||
// For example "id_token token" is the same as "token id_token"
|
||||
// Support a custom compare method.
|
||||
comp func(s1, s2 string) bool
|
||||
}{
|
||||
{p.GrantTypesSupported, c.GrantTypes, "grant_types", stringsEqual},
|
||||
{p.ResponseTypesSupported, c.ResponseTypes, "response_type", oauth2.ResponseTypesEqual},
|
||||
}
|
||||
for _, field := range supportsAll {
|
||||
requestLoop:
|
||||
for _, req := range field.requested {
|
||||
for _, sup := range field.supported {
|
||||
if field.comp(req, sup) {
|
||||
continue requestLoop
|
||||
}
|
||||
}
|
||||
return fmt.Errorf("provider does not support requested value for field %s", field.name)
|
||||
}
|
||||
}
|
||||
|
||||
// TODO(ericchiang): Are there more checks we feel comfortable with begin strict about?
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (p ProviderConfig) SupportsGrantType(grantType string) bool {
|
||||
var supported []string
|
||||
if len(p.GrantTypesSupported) == 0 {
|
||||
supported = DefaultGrantTypesSupported
|
||||
} else {
|
||||
supported = p.GrantTypesSupported
|
||||
}
|
||||
|
||||
for _, t := range supported {
|
||||
if t == grantType {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
type ProviderConfigGetter interface {
|
||||
Get() (ProviderConfig, error)
|
||||
}
|
||||
|
||||
type ProviderConfigSetter interface {
|
||||
Set(ProviderConfig) error
|
||||
}
|
||||
|
||||
type ProviderConfigSyncer struct {
|
||||
from ProviderConfigGetter
|
||||
to ProviderConfigSetter
|
||||
clock clockwork.Clock
|
||||
|
||||
initialSyncDone bool
|
||||
initialSyncWait sync.WaitGroup
|
||||
}
|
||||
|
||||
func NewProviderConfigSyncer(from ProviderConfigGetter, to ProviderConfigSetter) *ProviderConfigSyncer {
|
||||
return &ProviderConfigSyncer{
|
||||
from: from,
|
||||
to: to,
|
||||
clock: clockwork.NewRealClock(),
|
||||
}
|
||||
}
|
||||
|
||||
func (s *ProviderConfigSyncer) Run() chan struct{} {
|
||||
stop := make(chan struct{})
|
||||
|
||||
var next pcsStepper
|
||||
next = &pcsStepNext{aft: time.Duration(0)}
|
||||
|
||||
s.initialSyncWait.Add(1)
|
||||
go func() {
|
||||
for {
|
||||
select {
|
||||
case <-s.clock.After(next.after()):
|
||||
next = next.step(s.sync)
|
||||
case <-stop:
|
||||
return
|
||||
}
|
||||
}
|
||||
}()
|
||||
|
||||
return stop
|
||||
}
|
||||
|
||||
func (s *ProviderConfigSyncer) WaitUntilInitialSync() {
|
||||
s.initialSyncWait.Wait()
|
||||
}
|
||||
|
||||
func (s *ProviderConfigSyncer) sync() (time.Duration, error) {
|
||||
cfg, err := s.from.Get()
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
|
||||
if err = s.to.Set(cfg); err != nil {
|
||||
return 0, fmt.Errorf("error setting provider config: %v", err)
|
||||
}
|
||||
|
||||
if !s.initialSyncDone {
|
||||
s.initialSyncWait.Done()
|
||||
s.initialSyncDone = true
|
||||
}
|
||||
|
||||
return nextSyncAfter(cfg.ExpiresAt, s.clock), nil
|
||||
}
|
||||
|
||||
type pcsStepFunc func() (time.Duration, error)
|
||||
|
||||
type pcsStepper interface {
|
||||
after() time.Duration
|
||||
step(pcsStepFunc) pcsStepper
|
||||
}
|
||||
|
||||
type pcsStepNext struct {
|
||||
aft time.Duration
|
||||
}
|
||||
|
||||
func (n *pcsStepNext) after() time.Duration {
|
||||
return n.aft
|
||||
}
|
||||
|
||||
func (n *pcsStepNext) step(fn pcsStepFunc) (next pcsStepper) {
|
||||
ttl, err := fn()
|
||||
if err == nil {
|
||||
next = &pcsStepNext{aft: ttl}
|
||||
} else {
|
||||
next = &pcsStepRetry{aft: time.Second}
|
||||
log.Printf("go-oidc: provider config sync falied, retyring in %v: %v", next.after(), err)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
type pcsStepRetry struct {
|
||||
aft time.Duration
|
||||
}
|
||||
|
||||
func (r *pcsStepRetry) after() time.Duration {
|
||||
return r.aft
|
||||
}
|
||||
|
||||
func (r *pcsStepRetry) step(fn pcsStepFunc) (next pcsStepper) {
|
||||
ttl, err := fn()
|
||||
if err == nil {
|
||||
next = &pcsStepNext{aft: ttl}
|
||||
} else {
|
||||
next = &pcsStepRetry{aft: timeutil.ExpBackoff(r.aft, time.Minute)}
|
||||
log.Printf("go-oidc: provider config sync falied, retyring in %v: %v", next.after(), err)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
func nextSyncAfter(exp time.Time, clock clockwork.Clock) time.Duration {
|
||||
if exp.IsZero() {
|
||||
return MaximumProviderConfigSyncInterval
|
||||
}
|
||||
|
||||
t := exp.Sub(clock.Now()) / 2
|
||||
if t > MaximumProviderConfigSyncInterval {
|
||||
t = MaximumProviderConfigSyncInterval
|
||||
} else if t < minimumProviderConfigSyncInterval {
|
||||
t = minimumProviderConfigSyncInterval
|
||||
}
|
||||
|
||||
return t
|
||||
}
|
||||
|
||||
type httpProviderConfigGetter struct {
|
||||
hc phttp.Client
|
||||
issuerURL string
|
||||
clock clockwork.Clock
|
||||
}
|
||||
|
||||
func NewHTTPProviderConfigGetter(hc phttp.Client, issuerURL string) *httpProviderConfigGetter {
|
||||
return &httpProviderConfigGetter{
|
||||
hc: hc,
|
||||
issuerURL: issuerURL,
|
||||
clock: clockwork.NewRealClock(),
|
||||
}
|
||||
}
|
||||
|
||||
func (r *httpProviderConfigGetter) Get() (cfg ProviderConfig, err error) {
|
||||
// If the Issuer value contains a path component, any terminating / MUST be removed before
|
||||
// appending /.well-known/openid-configuration.
|
||||
// https://openid.net/specs/openid-connect-discovery-1_0.html#ProviderConfigurationRequest
|
||||
discoveryURL := strings.TrimSuffix(r.issuerURL, "/") + discoveryConfigPath
|
||||
req, err := http.NewRequest("GET", discoveryURL, nil)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
resp, err := r.hc.Do(req)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
if err = json.NewDecoder(resp.Body).Decode(&cfg); err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
var ttl time.Duration
|
||||
var ok bool
|
||||
ttl, ok, err = phttp.Cacheable(resp.Header)
|
||||
if err != nil {
|
||||
return
|
||||
} else if ok {
|
||||
cfg.ExpiresAt = r.clock.Now().UTC().Add(ttl)
|
||||
}
|
||||
|
||||
// The issuer value returned MUST be identical to the Issuer URL that was directly used to retrieve the configuration information.
|
||||
// http://openid.net/specs/openid-connect-discovery-1_0.html#ProviderConfigurationValidation
|
||||
if !urlEqual(cfg.Issuer.String(), r.issuerURL) {
|
||||
err = fmt.Errorf(`"issuer" in config (%v) does not match provided issuer URL (%v)`, cfg.Issuer, r.issuerURL)
|
||||
return
|
||||
}
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
func FetchProviderConfig(hc phttp.Client, issuerURL string) (ProviderConfig, error) {
|
||||
if hc == nil {
|
||||
hc = http.DefaultClient
|
||||
}
|
||||
|
||||
g := NewHTTPProviderConfigGetter(hc, issuerURL)
|
||||
return g.Get()
|
||||
}
|
||||
|
||||
func WaitForProviderConfig(hc phttp.Client, issuerURL string) (pcfg ProviderConfig) {
|
||||
return waitForProviderConfig(hc, issuerURL, clockwork.NewRealClock())
|
||||
}
|
||||
|
||||
func waitForProviderConfig(hc phttp.Client, issuerURL string, clock clockwork.Clock) (pcfg ProviderConfig) {
|
||||
var sleep time.Duration
|
||||
var err error
|
||||
for {
|
||||
pcfg, err = FetchProviderConfig(hc, issuerURL)
|
||||
if err == nil {
|
||||
break
|
||||
}
|
||||
|
||||
sleep = timeutil.ExpBackoff(sleep, time.Minute)
|
||||
fmt.Printf("Failed fetching provider config, trying again in %v: %v\n", sleep, err)
|
||||
time.Sleep(sleep)
|
||||
}
|
||||
|
||||
return
|
||||
}
|
940
vendor/github.com/coreos/go-oidc/oidc/provider_test.go
generated
vendored
Normal file
940
vendor/github.com/coreos/go-oidc/oidc/provider_test.go
generated
vendored
Normal file
@@ -0,0 +1,940 @@
|
||||
package oidc
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io/ioutil"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"net/url"
|
||||
"reflect"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/jonboulle/clockwork"
|
||||
"github.com/kylelemons/godebug/diff"
|
||||
"github.com/kylelemons/godebug/pretty"
|
||||
|
||||
"github.com/coreos/go-oidc/jose"
|
||||
"github.com/coreos/go-oidc/oauth2"
|
||||
)
|
||||
|
||||
func TestProviderConfigDefaults(t *testing.T) {
|
||||
var cfg ProviderConfig
|
||||
cfg = cfg.Defaults()
|
||||
tests := []struct {
|
||||
got, want []string
|
||||
name string
|
||||
}{
|
||||
{cfg.GrantTypesSupported, DefaultGrantTypesSupported, "grant types"},
|
||||
{cfg.ResponseModesSupported, DefaultResponseModesSupported, "response modes"},
|
||||
{cfg.ClaimTypesSupported, DefaultClaimTypesSupported, "claim types"},
|
||||
{
|
||||
cfg.TokenEndpointAuthMethodsSupported,
|
||||
DefaultTokenEndpointAuthMethodsSupported,
|
||||
"token endpoint auth methods",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
if diff := pretty.Compare(tt.want, tt.got); diff != "" {
|
||||
t.Errorf("%s: did not match %s", tt.name, diff)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestProviderConfigUnmarshal(t *testing.T) {
|
||||
|
||||
// helper for quickly creating uris
|
||||
uri := func(path string) *url.URL {
|
||||
return &url.URL{
|
||||
Scheme: "https",
|
||||
Host: "server.example.com",
|
||||
Path: path,
|
||||
}
|
||||
}
|
||||
|
||||
tests := []struct {
|
||||
data string
|
||||
want ProviderConfig
|
||||
wantErr bool
|
||||
}{
|
||||
{
|
||||
data: `{
|
||||
"issuer": "https://server.example.com",
|
||||
"authorization_endpoint": "https://server.example.com/connect/authorize",
|
||||
"token_endpoint": "https://server.example.com/connect/token",
|
||||
"token_endpoint_auth_methods_supported": ["client_secret_basic", "private_key_jwt"],
|
||||
"token_endpoint_auth_signing_alg_values_supported": ["RS256", "ES256"],
|
||||
"userinfo_endpoint": "https://server.example.com/connect/userinfo",
|
||||
"jwks_uri": "https://server.example.com/jwks.json",
|
||||
"registration_endpoint": "https://server.example.com/connect/register",
|
||||
"scopes_supported": [
|
||||
"openid", "profile", "email", "address", "phone", "offline_access"
|
||||
],
|
||||
"response_types_supported": [
|
||||
"code", "code id_token", "id_token", "id_token token"
|
||||
],
|
||||
"acr_values_supported": [
|
||||
"urn:mace:incommon:iap:silver", "urn:mace:incommon:iap:bronze"
|
||||
],
|
||||
"subject_types_supported": ["public", "pairwise"],
|
||||
"userinfo_signing_alg_values_supported": ["RS256", "ES256", "HS256"],
|
||||
"userinfo_encryption_alg_values_supported": ["RSA1_5", "A128KW"],
|
||||
"userinfo_encryption_enc_values_supported": ["A128CBC-HS256", "A128GCM"],
|
||||
"id_token_signing_alg_values_supported": ["RS256", "ES256", "HS256"],
|
||||
"id_token_encryption_alg_values_supported": ["RSA1_5", "A128KW"],
|
||||
"id_token_encryption_enc_values_supported": ["A128CBC-HS256", "A128GCM"],
|
||||
"request_object_signing_alg_values_supported": ["none", "RS256", "ES256"],
|
||||
"display_values_supported": ["page", "popup"],
|
||||
"claim_types_supported": ["normal", "distributed"],
|
||||
"claims_supported": [
|
||||
"sub", "iss", "auth_time", "acr", "name", "given_name",
|
||||
"family_name", "nickname", "profile", "picture", "website",
|
||||
"email", "email_verified", "locale", "zoneinfo",
|
||||
"http://example.info/claims/groups"
|
||||
],
|
||||
"claims_parameter_supported": true,
|
||||
"service_documentation": "https://server.example.com/connect/service_documentation.html",
|
||||
"ui_locales_supported": ["en-US", "en-GB", "en-CA", "fr-FR", "fr-CA"]
|
||||
}
|
||||
`,
|
||||
want: ProviderConfig{
|
||||
Issuer: &url.URL{Scheme: "https", Host: "server.example.com"},
|
||||
AuthEndpoint: uri("/connect/authorize"),
|
||||
TokenEndpoint: uri("/connect/token"),
|
||||
TokenEndpointAuthMethodsSupported: []string{
|
||||
oauth2.AuthMethodClientSecretBasic, oauth2.AuthMethodPrivateKeyJWT,
|
||||
},
|
||||
TokenEndpointAuthSigningAlgValuesSupported: []string{
|
||||
jose.AlgRS256, jose.AlgES256,
|
||||
},
|
||||
UserInfoEndpoint: uri("/connect/userinfo"),
|
||||
KeysEndpoint: uri("/jwks.json"),
|
||||
RegistrationEndpoint: uri("/connect/register"),
|
||||
ScopesSupported: []string{
|
||||
"openid", "profile", "email", "address", "phone", "offline_access",
|
||||
},
|
||||
ResponseTypesSupported: []string{
|
||||
oauth2.ResponseTypeCode, oauth2.ResponseTypeCodeIDToken,
|
||||
oauth2.ResponseTypeIDToken, oauth2.ResponseTypeIDTokenToken,
|
||||
},
|
||||
ACRValuesSupported: []string{
|
||||
"urn:mace:incommon:iap:silver", "urn:mace:incommon:iap:bronze",
|
||||
},
|
||||
SubjectTypesSupported: []string{
|
||||
SubjectTypePublic, SubjectTypePairwise,
|
||||
},
|
||||
UserInfoSigningAlgValues: []string{jose.AlgRS256, jose.AlgES256, jose.AlgHS256},
|
||||
UserInfoEncryptionAlgValues: []string{"RSA1_5", "A128KW"},
|
||||
UserInfoEncryptionEncValues: []string{"A128CBC-HS256", "A128GCM"},
|
||||
IDTokenSigningAlgValues: []string{jose.AlgRS256, jose.AlgES256, jose.AlgHS256},
|
||||
IDTokenEncryptionAlgValues: []string{"RSA1_5", "A128KW"},
|
||||
IDTokenEncryptionEncValues: []string{"A128CBC-HS256", "A128GCM"},
|
||||
ReqObjSigningAlgValues: []string{jose.AlgNone, jose.AlgRS256, jose.AlgES256},
|
||||
DisplayValuesSupported: []string{"page", "popup"},
|
||||
ClaimTypesSupported: []string{"normal", "distributed"},
|
||||
ClaimsSupported: []string{
|
||||
"sub", "iss", "auth_time", "acr", "name", "given_name",
|
||||
"family_name", "nickname", "profile", "picture", "website",
|
||||
"email", "email_verified", "locale", "zoneinfo",
|
||||
"http://example.info/claims/groups",
|
||||
},
|
||||
ClaimsParameterSupported: true,
|
||||
ServiceDocs: uri("/connect/service_documentation.html"),
|
||||
UILocalsSupported: []string{"en-US", "en-GB", "en-CA", "fr-FR", "fr-CA"},
|
||||
},
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
// missing a lot of required field
|
||||
data: `{}`,
|
||||
wantErr: true,
|
||||
},
|
||||
{
|
||||
data: `{
|
||||
"issuer": "https://server.example.com",
|
||||
"authorization_endpoint": "https://server.example.com/connect/authorize",
|
||||
"token_endpoint": "https://server.example.com/connect/token",
|
||||
"jwks_uri": "https://server.example.com/jwks.json",
|
||||
"response_types_supported": [
|
||||
"code", "code id_token", "id_token", "id_token token"
|
||||
],
|
||||
"subject_types_supported": ["public", "pairwise"],
|
||||
"id_token_signing_alg_values_supported": ["RS256", "ES256", "HS256"]
|
||||
}
|
||||
`,
|
||||
want: ProviderConfig{
|
||||
Issuer: &url.URL{Scheme: "https", Host: "server.example.com"},
|
||||
AuthEndpoint: uri("/connect/authorize"),
|
||||
TokenEndpoint: uri("/connect/token"),
|
||||
KeysEndpoint: uri("/jwks.json"),
|
||||
ResponseTypesSupported: []string{
|
||||
oauth2.ResponseTypeCode, oauth2.ResponseTypeCodeIDToken,
|
||||
oauth2.ResponseTypeIDToken, oauth2.ResponseTypeIDTokenToken,
|
||||
},
|
||||
SubjectTypesSupported: []string{
|
||||
SubjectTypePublic, SubjectTypePairwise,
|
||||
},
|
||||
IDTokenSigningAlgValues: []string{jose.AlgRS256, jose.AlgES256, jose.AlgHS256},
|
||||
},
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
// invalid scheme 'ftp://'
|
||||
data: `{
|
||||
"issuer": "https://server.example.com",
|
||||
"authorization_endpoint": "https://server.example.com/connect/authorize",
|
||||
"token_endpoint": "https://server.example.com/connect/token",
|
||||
"jwks_uri": "ftp://server.example.com/jwks.json",
|
||||
"response_types_supported": [
|
||||
"code", "code id_token", "id_token", "id_token token"
|
||||
],
|
||||
"subject_types_supported": ["public", "pairwise"],
|
||||
"id_token_signing_alg_values_supported": ["RS256", "ES256", "HS256"]
|
||||
}
|
||||
`,
|
||||
wantErr: true,
|
||||
},
|
||||
}
|
||||
for i, tt := range tests {
|
||||
var got ProviderConfig
|
||||
if err := json.Unmarshal([]byte(tt.data), &got); err != nil {
|
||||
if !tt.wantErr {
|
||||
t.Errorf("case %d: failed to unmarshal provider config: %v", i, err)
|
||||
}
|
||||
continue
|
||||
}
|
||||
if tt.wantErr {
|
||||
t.Errorf("case %d: expected error", i)
|
||||
continue
|
||||
}
|
||||
if diff := pretty.Compare(tt.want, got); diff != "" {
|
||||
t.Errorf("case %d: unmarshaled struct did not match expected %s", i, diff)
|
||||
}
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
func TestProviderConfigMarshal(t *testing.T) {
|
||||
tests := []struct {
|
||||
cfg ProviderConfig
|
||||
want string
|
||||
}{
|
||||
{
|
||||
cfg: ProviderConfig{
|
||||
Issuer: &url.URL{Scheme: "https", Host: "auth.example.com"},
|
||||
AuthEndpoint: &url.URL{
|
||||
Scheme: "https", Host: "auth.example.com", Path: "/auth",
|
||||
},
|
||||
TokenEndpoint: &url.URL{
|
||||
Scheme: "https", Host: "auth.example.com", Path: "/token",
|
||||
},
|
||||
UserInfoEndpoint: &url.URL{
|
||||
Scheme: "https", Host: "auth.example.com", Path: "/userinfo",
|
||||
},
|
||||
KeysEndpoint: &url.URL{
|
||||
Scheme: "https", Host: "auth.example.com", Path: "/jwk",
|
||||
},
|
||||
ResponseTypesSupported: []string{oauth2.ResponseTypeCode},
|
||||
SubjectTypesSupported: []string{SubjectTypePublic},
|
||||
IDTokenSigningAlgValues: []string{jose.AlgRS256},
|
||||
},
|
||||
// spacing must match json.MarshalIndent(cfg, "", "\t")
|
||||
want: `{
|
||||
"issuer": "https://auth.example.com",
|
||||
"authorization_endpoint": "https://auth.example.com/auth",
|
||||
"token_endpoint": "https://auth.example.com/token",
|
||||
"userinfo_endpoint": "https://auth.example.com/userinfo",
|
||||
"jwks_uri": "https://auth.example.com/jwk",
|
||||
"response_types_supported": [
|
||||
"code"
|
||||
],
|
||||
"subject_types_supported": [
|
||||
"public"
|
||||
],
|
||||
"id_token_signing_alg_values_supported": [
|
||||
"RS256"
|
||||
]
|
||||
}`,
|
||||
},
|
||||
{
|
||||
cfg: ProviderConfig{
|
||||
Issuer: &url.URL{Scheme: "https", Host: "auth.example.com"},
|
||||
AuthEndpoint: &url.URL{
|
||||
Scheme: "https", Host: "auth.example.com", Path: "/auth",
|
||||
},
|
||||
TokenEndpoint: &url.URL{
|
||||
Scheme: "https", Host: "auth.example.com", Path: "/token",
|
||||
},
|
||||
UserInfoEndpoint: &url.URL{
|
||||
Scheme: "https", Host: "auth.example.com", Path: "/userinfo",
|
||||
},
|
||||
KeysEndpoint: &url.URL{
|
||||
Scheme: "https", Host: "auth.example.com", Path: "/jwk",
|
||||
},
|
||||
RegistrationEndpoint: &url.URL{
|
||||
Scheme: "https", Host: "auth.example.com", Path: "/register",
|
||||
},
|
||||
ScopesSupported: DefaultScope,
|
||||
ResponseTypesSupported: []string{oauth2.ResponseTypeCode},
|
||||
ResponseModesSupported: DefaultResponseModesSupported,
|
||||
GrantTypesSupported: []string{oauth2.GrantTypeAuthCode},
|
||||
SubjectTypesSupported: []string{SubjectTypePublic},
|
||||
IDTokenSigningAlgValues: []string{jose.AlgRS256},
|
||||
ServiceDocs: &url.URL{Scheme: "https", Host: "example.com", Path: "/docs"},
|
||||
},
|
||||
// spacing must match json.MarshalIndent(cfg, "", "\t")
|
||||
want: `{
|
||||
"issuer": "https://auth.example.com",
|
||||
"authorization_endpoint": "https://auth.example.com/auth",
|
||||
"token_endpoint": "https://auth.example.com/token",
|
||||
"userinfo_endpoint": "https://auth.example.com/userinfo",
|
||||
"jwks_uri": "https://auth.example.com/jwk",
|
||||
"registration_endpoint": "https://auth.example.com/register",
|
||||
"scopes_supported": [
|
||||
"openid",
|
||||
"email",
|
||||
"profile"
|
||||
],
|
||||
"response_types_supported": [
|
||||
"code"
|
||||
],
|
||||
"response_modes_supported": [
|
||||
"query",
|
||||
"fragment"
|
||||
],
|
||||
"grant_types_supported": [
|
||||
"authorization_code"
|
||||
],
|
||||
"subject_types_supported": [
|
||||
"public"
|
||||
],
|
||||
"id_token_signing_alg_values_supported": [
|
||||
"RS256"
|
||||
],
|
||||
"service_documentation": "https://example.com/docs"
|
||||
}`,
|
||||
},
|
||||
}
|
||||
|
||||
for i, tt := range tests {
|
||||
got, err := json.MarshalIndent(&tt.cfg, "", "\t")
|
||||
if err != nil {
|
||||
t.Errorf("case %d: failed to marshal config: %v", i, err)
|
||||
continue
|
||||
}
|
||||
if d := diff.Diff(string(got), string(tt.want)); d != "" {
|
||||
t.Errorf("case %d: expected did not match result: %s", i, d)
|
||||
}
|
||||
|
||||
var cfg ProviderConfig
|
||||
if err := json.Unmarshal(got, &cfg); err != nil {
|
||||
t.Errorf("case %d: could not unmarshal marshal response: %v", i, err)
|
||||
continue
|
||||
}
|
||||
|
||||
if d := pretty.Compare(tt.cfg, cfg); d != "" {
|
||||
t.Errorf("case %d: config did not survive JSON marshaling round trip: %s", i, d)
|
||||
}
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
func TestProviderConfigSupports(t *testing.T) {
|
||||
tests := []struct {
|
||||
provider ProviderConfig
|
||||
client ClientMetadata
|
||||
fillRequiredProviderFields bool
|
||||
ok bool
|
||||
}{
|
||||
{
|
||||
provider: ProviderConfig{},
|
||||
client: ClientMetadata{
|
||||
RedirectURIs: []url.URL{
|
||||
{Scheme: "https", Host: "example.com", Path: "/callback"},
|
||||
},
|
||||
},
|
||||
fillRequiredProviderFields: true,
|
||||
ok: true,
|
||||
},
|
||||
{
|
||||
// invalid provider config
|
||||
provider: ProviderConfig{},
|
||||
client: ClientMetadata{
|
||||
RedirectURIs: []url.URL{
|
||||
{Scheme: "https", Host: "example.com", Path: "/callback"},
|
||||
},
|
||||
},
|
||||
fillRequiredProviderFields: false,
|
||||
ok: false,
|
||||
},
|
||||
{
|
||||
// invalid client config
|
||||
provider: ProviderConfig{},
|
||||
client: ClientMetadata{},
|
||||
fillRequiredProviderFields: true,
|
||||
ok: false,
|
||||
},
|
||||
}
|
||||
|
||||
for i, tt := range tests {
|
||||
if tt.fillRequiredProviderFields {
|
||||
tt.provider = fillRequiredProviderFields(tt.provider)
|
||||
}
|
||||
|
||||
err := tt.provider.Supports(tt.client)
|
||||
if err == nil && !tt.ok {
|
||||
t.Errorf("case %d: expected non-nil error", i)
|
||||
}
|
||||
if err != nil && tt.ok {
|
||||
t.Errorf("case %d: supports failed: %v", i, err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func newValidProviderConfig() ProviderConfig {
|
||||
var cfg ProviderConfig
|
||||
return fillRequiredProviderFields(cfg)
|
||||
}
|
||||
|
||||
// fill a provider config with enough information to be valid
|
||||
func fillRequiredProviderFields(cfg ProviderConfig) ProviderConfig {
|
||||
if cfg.Issuer == nil {
|
||||
cfg.Issuer = &url.URL{Scheme: "https", Host: "auth.example.com"}
|
||||
}
|
||||
urlPath := func(path string) *url.URL {
|
||||
var u url.URL
|
||||
u = *cfg.Issuer
|
||||
u.Path = path
|
||||
return &u
|
||||
}
|
||||
cfg.AuthEndpoint = urlPath("/auth")
|
||||
cfg.TokenEndpoint = urlPath("/token")
|
||||
cfg.UserInfoEndpoint = urlPath("/userinfo")
|
||||
cfg.KeysEndpoint = urlPath("/jwk")
|
||||
cfg.ResponseTypesSupported = []string{oauth2.ResponseTypeCode}
|
||||
cfg.SubjectTypesSupported = []string{SubjectTypePublic}
|
||||
cfg.IDTokenSigningAlgValues = []string{jose.AlgRS256}
|
||||
return cfg
|
||||
}
|
||||
|
||||
type fakeProviderConfigGetterSetter struct {
|
||||
cfg *ProviderConfig
|
||||
getCount int
|
||||
setCount int
|
||||
}
|
||||
|
||||
func (g *fakeProviderConfigGetterSetter) Get() (ProviderConfig, error) {
|
||||
g.getCount++
|
||||
return *g.cfg, nil
|
||||
}
|
||||
|
||||
func (g *fakeProviderConfigGetterSetter) Set(cfg ProviderConfig) error {
|
||||
g.cfg = &cfg
|
||||
g.setCount++
|
||||
return nil
|
||||
}
|
||||
|
||||
type fakeProviderConfigHandler struct {
|
||||
cfg ProviderConfig
|
||||
maxAge time.Duration
|
||||
}
|
||||
|
||||
func (s *fakeProviderConfigHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
|
||||
b, _ := json.Marshal(&s.cfg)
|
||||
if s.maxAge.Seconds() >= 0 {
|
||||
w.Header().Set("Cache-Control", fmt.Sprintf("public, max-age=%d", int(s.maxAge.Seconds())))
|
||||
}
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
w.Write(b)
|
||||
}
|
||||
|
||||
func TestProviderConfigRequiredFields(t *testing.T) {
|
||||
// Ensure provider metadata responses have all the required fields.
|
||||
// taken from https://openid.net/specs/openid-connect-discovery-1_0.html#ProviderMetadata
|
||||
requiredFields := []string{
|
||||
"issuer",
|
||||
"authorization_endpoint",
|
||||
"token_endpoint", // "This is REQUIRED unless only the Implicit Flow is used."
|
||||
"jwks_uri",
|
||||
"response_types_supported",
|
||||
"subject_types_supported",
|
||||
"id_token_signing_alg_values_supported",
|
||||
}
|
||||
|
||||
svr := &fakeProviderConfigHandler{
|
||||
cfg: ProviderConfig{
|
||||
Issuer: &url.URL{Scheme: "http", Host: "example.com"},
|
||||
ExpiresAt: time.Now().Add(time.Minute),
|
||||
},
|
||||
maxAge: time.Minute,
|
||||
}
|
||||
svr.cfg = fillRequiredProviderFields(svr.cfg)
|
||||
s := httptest.NewServer(svr)
|
||||
defer s.Close()
|
||||
|
||||
resp, err := http.Get(s.URL + "/")
|
||||
if err != nil {
|
||||
t.Errorf("get: %v", err)
|
||||
return
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
var data map[string]interface{}
|
||||
if err := json.NewDecoder(resp.Body).Decode(&data); err != nil {
|
||||
t.Errorf("decode: %v", err)
|
||||
return
|
||||
}
|
||||
for _, field := range requiredFields {
|
||||
if _, ok := data[field]; !ok {
|
||||
t.Errorf("provider metadata does not have required field '%s'", field)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
type handlerClient struct {
|
||||
Handler http.Handler
|
||||
}
|
||||
|
||||
func (hc *handlerClient) Do(r *http.Request) (*http.Response, error) {
|
||||
w := httptest.NewRecorder()
|
||||
hc.Handler.ServeHTTP(w, r)
|
||||
|
||||
resp := http.Response{
|
||||
StatusCode: w.Code,
|
||||
Header: w.Header(),
|
||||
Body: ioutil.NopCloser(w.Body),
|
||||
}
|
||||
|
||||
return &resp, nil
|
||||
}
|
||||
|
||||
func TestHTTPProviderConfigGetter(t *testing.T) {
|
||||
svr := &fakeProviderConfigHandler{}
|
||||
hc := &handlerClient{Handler: svr}
|
||||
fc := clockwork.NewFakeClock()
|
||||
now := fc.Now().UTC()
|
||||
|
||||
tests := []struct {
|
||||
dsc string
|
||||
age time.Duration
|
||||
cfg ProviderConfig
|
||||
ok bool
|
||||
}{
|
||||
// everything is good
|
||||
{
|
||||
dsc: "https://example.com",
|
||||
age: time.Minute,
|
||||
cfg: ProviderConfig{
|
||||
Issuer: &url.URL{Scheme: "https", Host: "example.com"},
|
||||
ExpiresAt: now.Add(time.Minute),
|
||||
},
|
||||
ok: true,
|
||||
},
|
||||
// iss and disco url differ by scheme only (how google works)
|
||||
{
|
||||
dsc: "https://example.com",
|
||||
age: time.Minute,
|
||||
cfg: ProviderConfig{
|
||||
Issuer: &url.URL{Scheme: "https", Host: "example.com"},
|
||||
ExpiresAt: now.Add(time.Minute),
|
||||
},
|
||||
ok: true,
|
||||
},
|
||||
// issuer and discovery URL mismatch
|
||||
{
|
||||
dsc: "https://foo.com",
|
||||
age: time.Minute,
|
||||
cfg: ProviderConfig{
|
||||
Issuer: &url.URL{Scheme: "https", Host: "example.com"},
|
||||
ExpiresAt: now.Add(time.Minute),
|
||||
},
|
||||
ok: false,
|
||||
},
|
||||
// missing cache header results in zero ExpiresAt
|
||||
{
|
||||
dsc: "https://example.com",
|
||||
age: -1,
|
||||
cfg: ProviderConfig{
|
||||
Issuer: &url.URL{Scheme: "https", Host: "example.com"},
|
||||
},
|
||||
ok: true,
|
||||
},
|
||||
}
|
||||
|
||||
for i, tt := range tests {
|
||||
tt.cfg = fillRequiredProviderFields(tt.cfg)
|
||||
svr.cfg = tt.cfg
|
||||
svr.maxAge = tt.age
|
||||
getter := NewHTTPProviderConfigGetter(hc, tt.dsc)
|
||||
getter.clock = fc
|
||||
|
||||
got, err := getter.Get()
|
||||
if err != nil {
|
||||
if tt.ok {
|
||||
t.Errorf("test %d: unexpected error: %v", i, err)
|
||||
}
|
||||
continue
|
||||
}
|
||||
|
||||
if !tt.ok {
|
||||
t.Errorf("test %d: expected error", i)
|
||||
continue
|
||||
}
|
||||
|
||||
if !reflect.DeepEqual(tt.cfg, got) {
|
||||
t.Errorf("test %d: want: %#v, got: %#v", i, tt.cfg, got)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestProviderConfigSyncerRun(t *testing.T) {
|
||||
c1 := &ProviderConfig{
|
||||
Issuer: &url.URL{Scheme: "https", Host: "example.com"},
|
||||
}
|
||||
c2 := &ProviderConfig{
|
||||
Issuer: &url.URL{Scheme: "https", Host: "example.com"},
|
||||
}
|
||||
|
||||
tests := []struct {
|
||||
first *ProviderConfig
|
||||
advance time.Duration
|
||||
second *ProviderConfig
|
||||
firstExp time.Duration
|
||||
secondExp time.Duration
|
||||
count int
|
||||
}{
|
||||
// exp is 10m, should have same config after 1s
|
||||
{
|
||||
first: c1,
|
||||
firstExp: time.Duration(10 * time.Minute),
|
||||
advance: time.Minute,
|
||||
second: c1,
|
||||
secondExp: time.Duration(10 * time.Minute),
|
||||
count: 1,
|
||||
},
|
||||
// exp is 10m, should have new config after 10/2 = 5m
|
||||
{
|
||||
first: c1,
|
||||
firstExp: time.Duration(10 * time.Minute),
|
||||
advance: time.Duration(5 * time.Minute),
|
||||
second: c2,
|
||||
secondExp: time.Duration(10 * time.Minute),
|
||||
count: 2,
|
||||
},
|
||||
// exp is 20m, should have new config after 20/2 = 10m
|
||||
{
|
||||
first: c1,
|
||||
firstExp: time.Duration(20 * time.Minute),
|
||||
advance: time.Duration(10 * time.Minute),
|
||||
second: c2,
|
||||
secondExp: time.Duration(30 * time.Minute),
|
||||
count: 2,
|
||||
},
|
||||
}
|
||||
|
||||
assertCfg := func(i int, to *fakeProviderConfigGetterSetter, want ProviderConfig) {
|
||||
got, err := to.Get()
|
||||
if err != nil {
|
||||
t.Fatalf("test %d: unable to get config: %v", i, err)
|
||||
}
|
||||
if !reflect.DeepEqual(want, got) {
|
||||
t.Fatalf("test %d: incorrect state:\nwant=%#v\ngot=%#v", i, want, got)
|
||||
}
|
||||
}
|
||||
|
||||
for i, tt := range tests {
|
||||
from := &fakeProviderConfigGetterSetter{}
|
||||
to := &fakeProviderConfigGetterSetter{}
|
||||
|
||||
fc := clockwork.NewFakeClock()
|
||||
now := fc.Now().UTC()
|
||||
syncer := NewProviderConfigSyncer(from, to)
|
||||
syncer.clock = fc
|
||||
|
||||
tt.first.ExpiresAt = now.Add(tt.firstExp)
|
||||
tt.second.ExpiresAt = now.Add(tt.secondExp)
|
||||
if err := from.Set(*tt.first); err != nil {
|
||||
t.Fatalf("test %d: unexpected error: %v", i, err)
|
||||
}
|
||||
|
||||
stop := syncer.Run()
|
||||
defer close(stop)
|
||||
fc.BlockUntil(1)
|
||||
|
||||
// first sync
|
||||
assertCfg(i, to, *tt.first)
|
||||
|
||||
if err := from.Set(*tt.second); err != nil {
|
||||
t.Fatalf("test %d: unexpected error: %v", i, err)
|
||||
}
|
||||
|
||||
fc.Advance(tt.advance)
|
||||
fc.BlockUntil(1)
|
||||
|
||||
// second sync
|
||||
assertCfg(i, to, *tt.second)
|
||||
|
||||
if tt.count != from.getCount {
|
||||
t.Fatalf("test %d: want: %v, got: %v", i, tt.count, from.getCount)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
type staticProviderConfigGetter struct {
|
||||
cfg ProviderConfig
|
||||
err error
|
||||
}
|
||||
|
||||
func (g *staticProviderConfigGetter) Get() (ProviderConfig, error) {
|
||||
return g.cfg, g.err
|
||||
}
|
||||
|
||||
type staticProviderConfigSetter struct {
|
||||
cfg *ProviderConfig
|
||||
err error
|
||||
}
|
||||
|
||||
func (s *staticProviderConfigSetter) Set(cfg ProviderConfig) error {
|
||||
s.cfg = &cfg
|
||||
return s.err
|
||||
}
|
||||
|
||||
func TestProviderConfigSyncerSyncFailure(t *testing.T) {
|
||||
fc := clockwork.NewFakeClock()
|
||||
|
||||
tests := []struct {
|
||||
from *staticProviderConfigGetter
|
||||
to *staticProviderConfigSetter
|
||||
|
||||
// want indicates what ProviderConfig should be passed to Set.
|
||||
// If nil, the Set should not be called.
|
||||
want *ProviderConfig
|
||||
}{
|
||||
// generic Get failure
|
||||
{
|
||||
from: &staticProviderConfigGetter{err: errors.New("fail")},
|
||||
to: &staticProviderConfigSetter{},
|
||||
want: nil,
|
||||
},
|
||||
// generic Set failure
|
||||
{
|
||||
from: &staticProviderConfigGetter{cfg: ProviderConfig{ExpiresAt: fc.Now().Add(time.Minute)}},
|
||||
to: &staticProviderConfigSetter{err: errors.New("fail")},
|
||||
want: &ProviderConfig{ExpiresAt: fc.Now().Add(time.Minute)},
|
||||
},
|
||||
}
|
||||
|
||||
for i, tt := range tests {
|
||||
pcs := &ProviderConfigSyncer{
|
||||
from: tt.from,
|
||||
to: tt.to,
|
||||
clock: fc,
|
||||
}
|
||||
_, err := pcs.sync()
|
||||
if err == nil {
|
||||
t.Errorf("case %d: expected non-nil error", i)
|
||||
}
|
||||
if !reflect.DeepEqual(tt.want, tt.to.cfg) {
|
||||
t.Errorf("case %d: Set mismatch: want=%#v got=%#v", i, tt.want, tt.to.cfg)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestNextSyncAfter(t *testing.T) {
|
||||
fc := clockwork.NewFakeClock()
|
||||
|
||||
tests := []struct {
|
||||
exp time.Time
|
||||
want time.Duration
|
||||
}{
|
||||
{
|
||||
exp: fc.Now().Add(time.Hour),
|
||||
want: 30 * time.Minute,
|
||||
},
|
||||
// override large values with the maximum
|
||||
{
|
||||
exp: fc.Now().Add(168 * time.Hour), // one week
|
||||
want: 24 * time.Hour,
|
||||
},
|
||||
// override "now" values with the minimum
|
||||
{
|
||||
exp: fc.Now(),
|
||||
want: time.Minute,
|
||||
},
|
||||
// override negative values with the minimum
|
||||
{
|
||||
exp: fc.Now().Add(-1 * time.Minute),
|
||||
want: time.Minute,
|
||||
},
|
||||
// zero-value Time results in maximum sync interval
|
||||
{
|
||||
exp: time.Time{},
|
||||
want: 24 * time.Hour,
|
||||
},
|
||||
}
|
||||
|
||||
for i, tt := range tests {
|
||||
got := nextSyncAfter(tt.exp, fc)
|
||||
if tt.want != got {
|
||||
t.Errorf("case %d: want=%v got=%v", i, tt.want, got)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestProviderConfigEmpty(t *testing.T) {
|
||||
cfg := ProviderConfig{}
|
||||
if !cfg.Empty() {
|
||||
t.Fatalf("Empty provider config reports non-empty")
|
||||
}
|
||||
cfg = ProviderConfig{
|
||||
Issuer: &url.URL{Scheme: "https", Host: "example.com"},
|
||||
}
|
||||
if cfg.Empty() {
|
||||
t.Fatalf("Non-empty provider config reports empty")
|
||||
}
|
||||
}
|
||||
|
||||
func TestPCSStepAfter(t *testing.T) {
|
||||
pass := func() (time.Duration, error) { return 7 * time.Second, nil }
|
||||
fail := func() (time.Duration, error) { return 0, errors.New("fail") }
|
||||
|
||||
tests := []struct {
|
||||
stepper pcsStepper
|
||||
stepFunc pcsStepFunc
|
||||
want pcsStepper
|
||||
}{
|
||||
// good step results in retry at TTL
|
||||
{
|
||||
stepper: &pcsStepNext{},
|
||||
stepFunc: pass,
|
||||
want: &pcsStepNext{aft: 7 * time.Second},
|
||||
},
|
||||
|
||||
// good step after failed step results results in retry at TTL
|
||||
{
|
||||
stepper: &pcsStepRetry{aft: 2 * time.Second},
|
||||
stepFunc: pass,
|
||||
want: &pcsStepNext{aft: 7 * time.Second},
|
||||
},
|
||||
|
||||
// failed step results in a retry in 1s
|
||||
{
|
||||
stepper: &pcsStepNext{},
|
||||
stepFunc: fail,
|
||||
want: &pcsStepRetry{aft: time.Second},
|
||||
},
|
||||
|
||||
// failed retry backs off by a factor of 2
|
||||
{
|
||||
stepper: &pcsStepRetry{aft: time.Second},
|
||||
stepFunc: fail,
|
||||
want: &pcsStepRetry{aft: 2 * time.Second},
|
||||
},
|
||||
|
||||
// failed retry backs off by a factor of 2, up to 1m
|
||||
{
|
||||
stepper: &pcsStepRetry{aft: 32 * time.Second},
|
||||
stepFunc: fail,
|
||||
want: &pcsStepRetry{aft: 60 * time.Second},
|
||||
},
|
||||
}
|
||||
|
||||
for i, tt := range tests {
|
||||
got := tt.stepper.step(tt.stepFunc)
|
||||
if !reflect.DeepEqual(tt.want, got) {
|
||||
t.Errorf("case %d: want=%#v got=%#v", i, tt.want, got)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestProviderConfigSupportsGrantType(t *testing.T) {
|
||||
tests := []struct {
|
||||
types []string
|
||||
typ string
|
||||
want bool
|
||||
}{
|
||||
// explicitly supported
|
||||
{
|
||||
types: []string{"foo_type"},
|
||||
typ: "foo_type",
|
||||
want: true,
|
||||
},
|
||||
|
||||
// explicitly unsupported
|
||||
{
|
||||
types: []string{"bar_type"},
|
||||
typ: "foo_type",
|
||||
want: false,
|
||||
},
|
||||
|
||||
// default type explicitly unsupported
|
||||
{
|
||||
types: []string{oauth2.GrantTypeImplicit},
|
||||
typ: oauth2.GrantTypeAuthCode,
|
||||
want: false,
|
||||
},
|
||||
|
||||
// type not found in default set
|
||||
{
|
||||
types: []string{},
|
||||
typ: "foo_type",
|
||||
want: false,
|
||||
},
|
||||
|
||||
// type found in default set
|
||||
{
|
||||
types: []string{},
|
||||
typ: oauth2.GrantTypeAuthCode,
|
||||
want: true,
|
||||
},
|
||||
}
|
||||
|
||||
for i, tt := range tests {
|
||||
cfg := ProviderConfig{
|
||||
GrantTypesSupported: tt.types,
|
||||
}
|
||||
got := cfg.SupportsGrantType(tt.typ)
|
||||
if tt.want != got {
|
||||
t.Errorf("case %d: assert %v supports %v: want=%t got=%t", i, tt.types, tt.typ, tt.want, got)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
type fakeClient struct {
|
||||
resp *http.Response
|
||||
}
|
||||
|
||||
func (f *fakeClient) Do(req *http.Request) (*http.Response, error) {
|
||||
return f.resp, nil
|
||||
}
|
||||
|
||||
func TestWaitForProviderConfigImmediateSuccess(t *testing.T) {
|
||||
cfg := newValidProviderConfig()
|
||||
b, err := json.Marshal(&cfg)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed marshaling provider config")
|
||||
}
|
||||
|
||||
resp := http.Response{Body: ioutil.NopCloser(bytes.NewBuffer(b))}
|
||||
hc := &fakeClient{&resp}
|
||||
fc := clockwork.NewFakeClock()
|
||||
|
||||
reschan := make(chan ProviderConfig)
|
||||
go func() {
|
||||
reschan <- waitForProviderConfig(hc, cfg.Issuer.String(), fc)
|
||||
}()
|
||||
|
||||
var got ProviderConfig
|
||||
select {
|
||||
case got = <-reschan:
|
||||
case <-time.After(time.Second):
|
||||
t.Fatalf("Did not receive result within 1s")
|
||||
}
|
||||
|
||||
if !reflect.DeepEqual(cfg, got) {
|
||||
t.Fatalf("Received incorrect provider config: want=%#v got=%#v", cfg, got)
|
||||
}
|
||||
}
|
88
vendor/github.com/coreos/go-oidc/oidc/transport.go
generated
vendored
Normal file
88
vendor/github.com/coreos/go-oidc/oidc/transport.go
generated
vendored
Normal file
@@ -0,0 +1,88 @@
|
||||
package oidc
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"net/http"
|
||||
"sync"
|
||||
|
||||
phttp "github.com/coreos/go-oidc/http"
|
||||
"github.com/coreos/go-oidc/jose"
|
||||
)
|
||||
|
||||
type TokenRefresher interface {
|
||||
// Verify checks if the provided token is currently valid or not.
|
||||
Verify(jose.JWT) error
|
||||
|
||||
// Refresh attempts to authenticate and retrieve a new token.
|
||||
Refresh() (jose.JWT, error)
|
||||
}
|
||||
|
||||
type ClientCredsTokenRefresher struct {
|
||||
Issuer string
|
||||
OIDCClient *Client
|
||||
}
|
||||
|
||||
func (c *ClientCredsTokenRefresher) Verify(jwt jose.JWT) (err error) {
|
||||
_, err = VerifyClientClaims(jwt, c.Issuer)
|
||||
return
|
||||
}
|
||||
|
||||
func (c *ClientCredsTokenRefresher) Refresh() (jwt jose.JWT, err error) {
|
||||
if err = c.OIDCClient.Healthy(); err != nil {
|
||||
err = fmt.Errorf("unable to authenticate, unhealthy OIDC client: %v", err)
|
||||
return
|
||||
}
|
||||
|
||||
jwt, err = c.OIDCClient.ClientCredsToken([]string{"openid"})
|
||||
if err != nil {
|
||||
err = fmt.Errorf("unable to verify auth code with issuer: %v", err)
|
||||
return
|
||||
}
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
type AuthenticatedTransport struct {
|
||||
TokenRefresher
|
||||
http.RoundTripper
|
||||
|
||||
mu sync.Mutex
|
||||
jwt jose.JWT
|
||||
}
|
||||
|
||||
func (t *AuthenticatedTransport) verifiedJWT() (jose.JWT, error) {
|
||||
t.mu.Lock()
|
||||
defer t.mu.Unlock()
|
||||
|
||||
if t.TokenRefresher.Verify(t.jwt) == nil {
|
||||
return t.jwt, nil
|
||||
}
|
||||
|
||||
jwt, err := t.TokenRefresher.Refresh()
|
||||
if err != nil {
|
||||
return jose.JWT{}, fmt.Errorf("unable to acquire valid JWT: %v", err)
|
||||
}
|
||||
|
||||
t.jwt = jwt
|
||||
return t.jwt, nil
|
||||
}
|
||||
|
||||
// SetJWT sets the JWT held by the Transport.
|
||||
// This is useful for cases in which you want to set an initial JWT.
|
||||
func (t *AuthenticatedTransport) SetJWT(jwt jose.JWT) {
|
||||
t.mu.Lock()
|
||||
defer t.mu.Unlock()
|
||||
|
||||
t.jwt = jwt
|
||||
}
|
||||
|
||||
func (t *AuthenticatedTransport) RoundTrip(r *http.Request) (*http.Response, error) {
|
||||
jwt, err := t.verifiedJWT()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
req := phttp.CopyRequest(r)
|
||||
req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", jwt.Encode()))
|
||||
return t.RoundTripper.RoundTrip(req)
|
||||
}
|
176
vendor/github.com/coreos/go-oidc/oidc/transport_test.go
generated
vendored
Normal file
176
vendor/github.com/coreos/go-oidc/oidc/transport_test.go
generated
vendored
Normal file
@@ -0,0 +1,176 @@
|
||||
package oidc
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"net/http"
|
||||
"reflect"
|
||||
"testing"
|
||||
|
||||
"github.com/coreos/go-oidc/jose"
|
||||
)
|
||||
|
||||
type staticTokenRefresher struct {
|
||||
verify func(jose.JWT) error
|
||||
refresh func() (jose.JWT, error)
|
||||
}
|
||||
|
||||
func (s *staticTokenRefresher) Verify(jwt jose.JWT) error {
|
||||
return s.verify(jwt)
|
||||
}
|
||||
|
||||
func (s *staticTokenRefresher) Refresh() (jose.JWT, error) {
|
||||
return s.refresh()
|
||||
}
|
||||
|
||||
func TestAuthenticatedTransportVerifiedJWT(t *testing.T) {
|
||||
tests := []struct {
|
||||
refresher TokenRefresher
|
||||
startJWT jose.JWT
|
||||
wantJWT jose.JWT
|
||||
wantError error
|
||||
}{
|
||||
// verification succeeds, so refresh is not called
|
||||
{
|
||||
refresher: &staticTokenRefresher{
|
||||
verify: func(jose.JWT) error { return nil },
|
||||
refresh: func() (jose.JWT, error) { return jose.JWT{RawPayload: "2"}, nil },
|
||||
},
|
||||
startJWT: jose.JWT{RawPayload: "1"},
|
||||
wantJWT: jose.JWT{RawPayload: "1"},
|
||||
},
|
||||
|
||||
// verification fails, refresh succeeds so cached JWT changes
|
||||
{
|
||||
refresher: &staticTokenRefresher{
|
||||
verify: func(jose.JWT) error { return errors.New("fail!") },
|
||||
refresh: func() (jose.JWT, error) { return jose.JWT{RawPayload: "2"}, nil },
|
||||
},
|
||||
startJWT: jose.JWT{RawPayload: "1"},
|
||||
wantJWT: jose.JWT{RawPayload: "2"},
|
||||
},
|
||||
|
||||
// verification succeeds, so failing refresh isn't attempted
|
||||
{
|
||||
refresher: &staticTokenRefresher{
|
||||
verify: func(jose.JWT) error { return nil },
|
||||
refresh: func() (jose.JWT, error) { return jose.JWT{}, errors.New("fail!") },
|
||||
},
|
||||
startJWT: jose.JWT{RawPayload: "1"},
|
||||
wantJWT: jose.JWT{RawPayload: "1"},
|
||||
},
|
||||
|
||||
// verification fails, but refresh fails, too
|
||||
{
|
||||
refresher: &staticTokenRefresher{
|
||||
verify: func(jose.JWT) error { return errors.New("fail!") },
|
||||
refresh: func() (jose.JWT, error) { return jose.JWT{}, errors.New("fail!") },
|
||||
},
|
||||
startJWT: jose.JWT{RawPayload: "1"},
|
||||
wantJWT: jose.JWT{},
|
||||
wantError: errors.New("unable to acquire valid JWT: fail!"),
|
||||
},
|
||||
}
|
||||
|
||||
for i, tt := range tests {
|
||||
at := &AuthenticatedTransport{
|
||||
TokenRefresher: tt.refresher,
|
||||
}
|
||||
at.SetJWT(tt.startJWT)
|
||||
|
||||
gotJWT, err := at.verifiedJWT()
|
||||
if !reflect.DeepEqual(tt.wantError, err) {
|
||||
t.Errorf("#%d: unexpected error: want=%#v got=%#v", i, tt.wantError, err)
|
||||
}
|
||||
if !reflect.DeepEqual(tt.wantJWT, gotJWT) {
|
||||
t.Errorf("#%d: incorrect JWT returned from verifiedJWT: want=%#v got=%#v", i, tt.wantJWT, gotJWT)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestAuthenticatedTransportJWTCaching(t *testing.T) {
|
||||
at := &AuthenticatedTransport{
|
||||
TokenRefresher: &staticTokenRefresher{
|
||||
verify: func(jose.JWT) error { return errors.New("fail!") },
|
||||
refresh: func() (jose.JWT, error) { return jose.JWT{RawPayload: "2"}, nil },
|
||||
},
|
||||
jwt: jose.JWT{RawPayload: "1"},
|
||||
}
|
||||
|
||||
wantJWT := jose.JWT{RawPayload: "2"}
|
||||
gotJWT, err := at.verifiedJWT()
|
||||
if err != nil {
|
||||
t.Fatalf("got non-nil error: %#v", err)
|
||||
}
|
||||
if !reflect.DeepEqual(wantJWT, gotJWT) {
|
||||
t.Fatalf("incorrect JWT returned from verifiedJWT: want=%#v got=%#v", wantJWT, gotJWT)
|
||||
}
|
||||
|
||||
at.TokenRefresher = &staticTokenRefresher{
|
||||
verify: func(jose.JWT) error { return nil },
|
||||
refresh: func() (jose.JWT, error) { return jose.JWT{RawPayload: "3"}, nil },
|
||||
}
|
||||
|
||||
// the previous JWT should still be cached on the AuthenticatedTransport since
|
||||
// it is still valid, even though there's a new token ready to refresh
|
||||
gotJWT, err = at.verifiedJWT()
|
||||
if err != nil {
|
||||
t.Fatalf("got non-nil error: %#v", err)
|
||||
}
|
||||
if !reflect.DeepEqual(wantJWT, gotJWT) {
|
||||
t.Fatalf("incorrect JWT returned from verifiedJWT: want=%#v got=%#v", wantJWT, gotJWT)
|
||||
}
|
||||
}
|
||||
|
||||
type fakeRoundTripper struct {
|
||||
Request *http.Request
|
||||
resp *http.Response
|
||||
}
|
||||
|
||||
func (r *fakeRoundTripper) RoundTrip(req *http.Request) (*http.Response, error) {
|
||||
r.Request = req
|
||||
return r.resp, nil
|
||||
}
|
||||
|
||||
func TestAuthenticatedTransportRoundTrip(t *testing.T) {
|
||||
rr := &fakeRoundTripper{nil, &http.Response{StatusCode: http.StatusOK}}
|
||||
at := &AuthenticatedTransport{
|
||||
TokenRefresher: &staticTokenRefresher{
|
||||
verify: func(jose.JWT) error { return nil },
|
||||
},
|
||||
RoundTripper: rr,
|
||||
jwt: jose.JWT{RawPayload: "1"},
|
||||
}
|
||||
|
||||
req := http.Request{}
|
||||
_, err := at.RoundTrip(&req)
|
||||
if err != nil {
|
||||
t.Errorf("unexpected error: %v", err)
|
||||
}
|
||||
|
||||
if !reflect.DeepEqual(req, http.Request{}) {
|
||||
t.Errorf("http.Request object was modified")
|
||||
}
|
||||
|
||||
want := []string{"Bearer .1."}
|
||||
got := rr.Request.Header["Authorization"]
|
||||
if !reflect.DeepEqual(want, got) {
|
||||
t.Errorf("incorrect Authorization header: want=%#v got=%#v", want, got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestAuthenticatedTransportRoundTripRefreshFail(t *testing.T) {
|
||||
rr := &fakeRoundTripper{nil, &http.Response{StatusCode: http.StatusOK}}
|
||||
at := &AuthenticatedTransport{
|
||||
TokenRefresher: &staticTokenRefresher{
|
||||
verify: func(jose.JWT) error { return errors.New("fail!") },
|
||||
refresh: func() (jose.JWT, error) { return jose.JWT{}, errors.New("fail!") },
|
||||
},
|
||||
RoundTripper: rr,
|
||||
jwt: jose.JWT{RawPayload: "1"},
|
||||
}
|
||||
|
||||
_, err := at.RoundTrip(&http.Request{})
|
||||
if err == nil {
|
||||
t.Errorf("expected non-nil error")
|
||||
}
|
||||
}
|
109
vendor/github.com/coreos/go-oidc/oidc/util.go
generated
vendored
Normal file
109
vendor/github.com/coreos/go-oidc/oidc/util.go
generated
vendored
Normal file
@@ -0,0 +1,109 @@
|
||||
package oidc
|
||||
|
||||
import (
|
||||
"crypto/rand"
|
||||
"encoding/base64"
|
||||
"errors"
|
||||
"fmt"
|
||||
"net"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/coreos/go-oidc/jose"
|
||||
)
|
||||
|
||||
// RequestTokenExtractor funcs extract a raw encoded token from a request.
|
||||
type RequestTokenExtractor func(r *http.Request) (string, error)
|
||||
|
||||
// ExtractBearerToken is a RequestTokenExtractor which extracts a bearer token from a request's
|
||||
// Authorization header.
|
||||
func ExtractBearerToken(r *http.Request) (string, error) {
|
||||
ah := r.Header.Get("Authorization")
|
||||
if ah == "" {
|
||||
return "", errors.New("missing Authorization header")
|
||||
}
|
||||
|
||||
if len(ah) <= 6 || strings.ToUpper(ah[0:6]) != "BEARER" {
|
||||
return "", errors.New("should be a bearer token")
|
||||
}
|
||||
|
||||
val := ah[7:]
|
||||
if len(val) == 0 {
|
||||
return "", errors.New("bearer token is empty")
|
||||
}
|
||||
|
||||
return val, nil
|
||||
}
|
||||
|
||||
// CookieTokenExtractor returns a RequestTokenExtractor which extracts a token from the named cookie in a request.
|
||||
func CookieTokenExtractor(cookieName string) RequestTokenExtractor {
|
||||
return func(r *http.Request) (string, error) {
|
||||
ck, err := r.Cookie(cookieName)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("token cookie not found in request: %v", err)
|
||||
}
|
||||
|
||||
if ck.Value == "" {
|
||||
return "", errors.New("token cookie found but is empty")
|
||||
}
|
||||
|
||||
return ck.Value, nil
|
||||
}
|
||||
}
|
||||
|
||||
func NewClaims(iss, sub string, aud interface{}, iat, exp time.Time) jose.Claims {
|
||||
return jose.Claims{
|
||||
// required
|
||||
"iss": iss,
|
||||
"sub": sub,
|
||||
"aud": aud,
|
||||
"iat": iat.Unix(),
|
||||
"exp": exp.Unix(),
|
||||
}
|
||||
}
|
||||
|
||||
func GenClientID(hostport string) (string, error) {
|
||||
b, err := randBytes(32)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
var host string
|
||||
if strings.Contains(hostport, ":") {
|
||||
host, _, err = net.SplitHostPort(hostport)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
} else {
|
||||
host = hostport
|
||||
}
|
||||
|
||||
return fmt.Sprintf("%s@%s", base64.URLEncoding.EncodeToString(b), host), nil
|
||||
}
|
||||
|
||||
func randBytes(n int) ([]byte, error) {
|
||||
b := make([]byte, n)
|
||||
got, err := rand.Read(b)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
} else if n != got {
|
||||
return nil, errors.New("unable to generate enough random data")
|
||||
}
|
||||
return b, nil
|
||||
}
|
||||
|
||||
// urlEqual checks two urls for equality using only the host and path portions.
|
||||
func urlEqual(url1, url2 string) bool {
|
||||
u1, err := url.Parse(url1)
|
||||
if err != nil {
|
||||
return false
|
||||
}
|
||||
u2, err := url.Parse(url2)
|
||||
if err != nil {
|
||||
return false
|
||||
}
|
||||
|
||||
return strings.ToLower(u1.Host+u1.Path) == strings.ToLower(u2.Host+u2.Path)
|
||||
}
|
110
vendor/github.com/coreos/go-oidc/oidc/util_test.go
generated
vendored
Normal file
110
vendor/github.com/coreos/go-oidc/oidc/util_test.go
generated
vendored
Normal file
@@ -0,0 +1,110 @@
|
||||
package oidc
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"net/http"
|
||||
"reflect"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/coreos/go-oidc/jose"
|
||||
)
|
||||
|
||||
func TestCookieTokenExtractorInvalid(t *testing.T) {
|
||||
ckName := "tokenCookie"
|
||||
tests := []*http.Cookie{
|
||||
&http.Cookie{},
|
||||
&http.Cookie{Name: ckName},
|
||||
&http.Cookie{Name: ckName, Value: ""},
|
||||
}
|
||||
|
||||
for i, tt := range tests {
|
||||
r, _ := http.NewRequest("", "", nil)
|
||||
r.AddCookie(tt)
|
||||
_, err := CookieTokenExtractor(ckName)(r)
|
||||
if err == nil {
|
||||
t.Errorf("case %d: want: error for invalid cookie token, got: no error.", i)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestCookieTokenExtractorValid(t *testing.T) {
|
||||
validToken := "eyJ0eXAiOiJKV1QiLA0KICJhbGciOiJIUzI1NiJ9.eyJpc3MiOiJqb2UiLA0KICJleHAiOjEzMDA4MTkzODAsDQogImh0dHA6Ly9leGFtcGxlLmNvbS9pc19yb290Ijp0cnVlfQ.dBjftJeZ4CVP-mB92K27uhbUJU1p1r_wW1gFWFOEjXk"
|
||||
ckName := "tokenCookie"
|
||||
tests := []*http.Cookie{
|
||||
&http.Cookie{Name: ckName, Value: "some non-empty value"},
|
||||
&http.Cookie{Name: ckName, Value: validToken},
|
||||
}
|
||||
|
||||
for i, tt := range tests {
|
||||
r, _ := http.NewRequest("", "", nil)
|
||||
r.AddCookie(tt)
|
||||
_, err := CookieTokenExtractor(ckName)(r)
|
||||
if err != nil {
|
||||
t.Errorf("case %d: want: valid cookie with no error, got: %v", i, err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestExtractBearerTokenInvalid(t *testing.T) {
|
||||
tests := []string{"", "x", "Bearer", "xxxxxxx", "Bearer "}
|
||||
|
||||
for i, tt := range tests {
|
||||
r, _ := http.NewRequest("", "", nil)
|
||||
r.Header.Add("Authorization", tt)
|
||||
_, err := ExtractBearerToken(r)
|
||||
if err == nil {
|
||||
t.Errorf("case %d: want: invalid Authorization header, got: valid Authorization header.", i)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestExtractBearerTokenValid(t *testing.T) {
|
||||
validToken := "eyJ0eXAiOiJKV1QiLA0KICJhbGciOiJIUzI1NiJ9.eyJpc3MiOiJqb2UiLA0KICJleHAiOjEzMDA4MTkzODAsDQogImh0dHA6Ly9leGFtcGxlLmNvbS9pc19yb290Ijp0cnVlfQ.dBjftJeZ4CVP-mB92K27uhbUJU1p1r_wW1gFWFOEjXk"
|
||||
tests := []string{
|
||||
fmt.Sprintf("Bearer %s", validToken),
|
||||
}
|
||||
|
||||
for i, tt := range tests {
|
||||
r, _ := http.NewRequest("", "", nil)
|
||||
r.Header.Add("Authorization", tt)
|
||||
_, err := ExtractBearerToken(r)
|
||||
if err != nil {
|
||||
t.Errorf("case %d: want: valid Authorization header, got: invalid Authorization header: %v.", i, err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestNewClaims(t *testing.T) {
|
||||
issAt := time.Date(2, time.January, 1, 0, 0, 0, 0, time.UTC)
|
||||
expAt := time.Date(2, time.January, 1, 1, 0, 0, 0, time.UTC)
|
||||
|
||||
want := jose.Claims{
|
||||
"iss": "https://example.com",
|
||||
"sub": "user-123",
|
||||
"aud": "client-abc",
|
||||
"iat": issAt.Unix(),
|
||||
"exp": expAt.Unix(),
|
||||
}
|
||||
|
||||
got := NewClaims("https://example.com", "user-123", "client-abc", issAt, expAt)
|
||||
|
||||
if !reflect.DeepEqual(want, got) {
|
||||
t.Fatalf("want=%#v got=%#v", want, got)
|
||||
}
|
||||
|
||||
want2 := jose.Claims{
|
||||
"iss": "https://example.com",
|
||||
"sub": "user-123",
|
||||
"aud": []string{"client-abc", "client-def"},
|
||||
"iat": issAt.Unix(),
|
||||
"exp": expAt.Unix(),
|
||||
}
|
||||
|
||||
got2 := NewClaims("https://example.com", "user-123", []string{"client-abc", "client-def"}, issAt, expAt)
|
||||
|
||||
if !reflect.DeepEqual(want2, got2) {
|
||||
t.Fatalf("want=%#v got=%#v", want2, got2)
|
||||
}
|
||||
|
||||
}
|
190
vendor/github.com/coreos/go-oidc/oidc/verification.go
generated
vendored
Normal file
190
vendor/github.com/coreos/go-oidc/oidc/verification.go
generated
vendored
Normal file
@@ -0,0 +1,190 @@
|
||||
package oidc
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"time"
|
||||
|
||||
"github.com/jonboulle/clockwork"
|
||||
|
||||
"github.com/coreos/go-oidc/jose"
|
||||
"github.com/coreos/go-oidc/key"
|
||||
)
|
||||
|
||||
func VerifySignature(jwt jose.JWT, keys []key.PublicKey) (bool, error) {
|
||||
jwtBytes := []byte(jwt.Data())
|
||||
for _, k := range keys {
|
||||
v, err := k.Verifier()
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
if v.Verify(jwt.Signature, jwtBytes) == nil {
|
||||
return true, nil
|
||||
}
|
||||
}
|
||||
return false, nil
|
||||
}
|
||||
|
||||
// containsString returns true if the given string(needle) is found
|
||||
// in the string array(haystack).
|
||||
func containsString(needle string, haystack []string) bool {
|
||||
for _, v := range haystack {
|
||||
if v == needle {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// Verify claims in accordance with OIDC spec
|
||||
// http://openid.net/specs/openid-connect-basic-1_0.html#IDTokenValidation
|
||||
func VerifyClaims(jwt jose.JWT, issuer, clientID string) error {
|
||||
now := time.Now().UTC()
|
||||
|
||||
claims, err := jwt.Claims()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
ident, err := IdentityFromClaims(claims)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if ident.ExpiresAt.Before(now) {
|
||||
return errors.New("token is expired")
|
||||
}
|
||||
|
||||
// iss REQUIRED. Issuer Identifier for the Issuer of the response.
|
||||
// The iss value is a case sensitive URL using the https scheme that contains scheme,
|
||||
// host, and optionally, port number and path components and no query or fragment components.
|
||||
if iss, exists := claims["iss"].(string); exists {
|
||||
if !urlEqual(iss, issuer) {
|
||||
return fmt.Errorf("invalid claim value: 'iss'. expected=%s, found=%s.", issuer, iss)
|
||||
}
|
||||
} else {
|
||||
return errors.New("missing claim: 'iss'")
|
||||
}
|
||||
|
||||
// iat REQUIRED. Time at which the JWT was issued.
|
||||
// Its value is a JSON number representing the number of seconds from 1970-01-01T0:0:0Z
|
||||
// as measured in UTC until the date/time.
|
||||
if _, exists := claims["iat"].(float64); !exists {
|
||||
return errors.New("missing claim: 'iat'")
|
||||
}
|
||||
|
||||
// aud REQUIRED. Audience(s) that this ID Token is intended for.
|
||||
// It MUST contain the OAuth 2.0 client_id of the Relying Party as an audience value.
|
||||
// It MAY also contain identifiers for other audiences. In the general case, the aud
|
||||
// value is an array of case sensitive strings. In the common special case when there
|
||||
// is one audience, the aud value MAY be a single case sensitive string.
|
||||
if aud, ok, err := claims.StringClaim("aud"); err == nil && ok {
|
||||
if aud != clientID {
|
||||
return fmt.Errorf("invalid claims, 'aud' claim and 'client_id' do not match, aud=%s, client_id=%s", aud, clientID)
|
||||
}
|
||||
} else if aud, ok, err := claims.StringsClaim("aud"); err == nil && ok {
|
||||
if !containsString(clientID, aud) {
|
||||
return fmt.Errorf("invalid claims, cannot find 'client_id' in 'aud' claim, aud=%v, client_id=%s", aud, clientID)
|
||||
}
|
||||
} else {
|
||||
return errors.New("invalid claim value: 'aud' is required, and should be either string or string array")
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// VerifyClientClaims verifies all the required claims are valid for a "client credentials" JWT.
|
||||
// Returns the client ID if valid, or an error if invalid.
|
||||
func VerifyClientClaims(jwt jose.JWT, issuer string) (string, error) {
|
||||
claims, err := jwt.Claims()
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("failed to parse JWT claims: %v", err)
|
||||
}
|
||||
|
||||
iss, ok, err := claims.StringClaim("iss")
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("failed to parse 'iss' claim: %v", err)
|
||||
} else if !ok {
|
||||
return "", errors.New("missing required 'iss' claim")
|
||||
} else if !urlEqual(iss, issuer) {
|
||||
return "", fmt.Errorf("'iss' claim does not match expected issuer, iss=%s", iss)
|
||||
}
|
||||
|
||||
sub, ok, err := claims.StringClaim("sub")
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("failed to parse 'sub' claim: %v", err)
|
||||
} else if !ok {
|
||||
return "", errors.New("missing required 'sub' claim")
|
||||
}
|
||||
|
||||
if aud, ok, err := claims.StringClaim("aud"); err == nil && ok {
|
||||
if aud != sub {
|
||||
return "", fmt.Errorf("invalid claims, 'aud' claim and 'sub' claim do not match, aud=%s, sub=%s", aud, sub)
|
||||
}
|
||||
} else if aud, ok, err := claims.StringsClaim("aud"); err == nil && ok {
|
||||
if !containsString(sub, aud) {
|
||||
return "", fmt.Errorf("invalid claims, cannot find 'sud' in 'aud' claim, aud=%v, sub=%s", aud, sub)
|
||||
}
|
||||
} else {
|
||||
return "", errors.New("invalid claim value: 'aud' is required, and should be either string or string array")
|
||||
}
|
||||
|
||||
now := time.Now().UTC()
|
||||
exp, ok, err := claims.TimeClaim("exp")
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("failed to parse 'exp' claim: %v", err)
|
||||
} else if !ok {
|
||||
return "", errors.New("missing required 'exp' claim")
|
||||
} else if exp.Before(now) {
|
||||
return "", fmt.Errorf("token already expired at: %v", exp)
|
||||
}
|
||||
|
||||
return sub, nil
|
||||
}
|
||||
|
||||
type JWTVerifier struct {
|
||||
issuer string
|
||||
clientID string
|
||||
syncFunc func() error
|
||||
keysFunc func() []key.PublicKey
|
||||
clock clockwork.Clock
|
||||
}
|
||||
|
||||
func NewJWTVerifier(issuer, clientID string, syncFunc func() error, keysFunc func() []key.PublicKey) JWTVerifier {
|
||||
return JWTVerifier{
|
||||
issuer: issuer,
|
||||
clientID: clientID,
|
||||
syncFunc: syncFunc,
|
||||
keysFunc: keysFunc,
|
||||
clock: clockwork.NewRealClock(),
|
||||
}
|
||||
}
|
||||
|
||||
func (v *JWTVerifier) Verify(jwt jose.JWT) error {
|
||||
// Verify claims before verifying the signature. This is an optimization to throw out
|
||||
// tokens we know are invalid without undergoing an expensive signature check and
|
||||
// possibly a re-sync event.
|
||||
if err := VerifyClaims(jwt, v.issuer, v.clientID); err != nil {
|
||||
return fmt.Errorf("oidc: JWT claims invalid: %v", err)
|
||||
}
|
||||
|
||||
ok, err := VerifySignature(jwt, v.keysFunc())
|
||||
if err != nil {
|
||||
return fmt.Errorf("oidc: JWT signature verification failed: %v", err)
|
||||
} else if ok {
|
||||
return nil
|
||||
}
|
||||
|
||||
if err = v.syncFunc(); err != nil {
|
||||
return fmt.Errorf("oidc: failed syncing KeySet: %v", err)
|
||||
}
|
||||
|
||||
ok, err = VerifySignature(jwt, v.keysFunc())
|
||||
if err != nil {
|
||||
return fmt.Errorf("oidc: JWT signature verification failed: %v", err)
|
||||
} else if !ok {
|
||||
return errors.New("oidc: unable to verify JWT signature: no matching keys")
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
380
vendor/github.com/coreos/go-oidc/oidc/verification_test.go
generated
vendored
Normal file
380
vendor/github.com/coreos/go-oidc/oidc/verification_test.go
generated
vendored
Normal file
@@ -0,0 +1,380 @@
|
||||
package oidc
|
||||
|
||||
import (
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/coreos/go-oidc/jose"
|
||||
"github.com/coreos/go-oidc/key"
|
||||
)
|
||||
|
||||
func TestVerifyClientClaims(t *testing.T) {
|
||||
validIss := "https://example.com"
|
||||
validClientID := "valid-client"
|
||||
now := time.Now()
|
||||
tomorrow := now.Add(24 * time.Hour)
|
||||
header := jose.JOSEHeader{
|
||||
jose.HeaderKeyAlgorithm: "test-alg",
|
||||
jose.HeaderKeyID: "1",
|
||||
}
|
||||
|
||||
tests := []struct {
|
||||
claims jose.Claims
|
||||
ok bool
|
||||
}{
|
||||
// valid token
|
||||
{
|
||||
claims: jose.Claims{
|
||||
"iss": validIss,
|
||||
"sub": validClientID,
|
||||
"aud": validClientID,
|
||||
"iat": float64(now.Unix()),
|
||||
"exp": float64(tomorrow.Unix()),
|
||||
},
|
||||
ok: true,
|
||||
},
|
||||
// valid token, ('aud' claim is []string)
|
||||
{
|
||||
claims: jose.Claims{
|
||||
"iss": validIss,
|
||||
"sub": validClientID,
|
||||
"aud": []string{"foo", validClientID},
|
||||
"iat": float64(now.Unix()),
|
||||
"exp": float64(tomorrow.Unix()),
|
||||
},
|
||||
ok: true,
|
||||
},
|
||||
// valid token, ('aud' claim is []interface{})
|
||||
{
|
||||
claims: jose.Claims{
|
||||
"iss": validIss,
|
||||
"sub": validClientID,
|
||||
"aud": []interface{}{"foo", validClientID},
|
||||
"iat": float64(now.Unix()),
|
||||
"exp": float64(tomorrow.Unix()),
|
||||
},
|
||||
ok: true,
|
||||
},
|
||||
// missing 'iss' claim
|
||||
{
|
||||
claims: jose.Claims{
|
||||
"sub": validClientID,
|
||||
"aud": validClientID,
|
||||
"iat": float64(now.Unix()),
|
||||
"exp": float64(tomorrow.Unix()),
|
||||
},
|
||||
ok: false,
|
||||
},
|
||||
// invalid 'iss' claim
|
||||
{
|
||||
claims: jose.Claims{
|
||||
"iss": "INVALID",
|
||||
"sub": validClientID,
|
||||
"aud": validClientID,
|
||||
"iat": float64(now.Unix()),
|
||||
"exp": float64(tomorrow.Unix()),
|
||||
},
|
||||
ok: false,
|
||||
},
|
||||
// missing 'sub' claim
|
||||
{
|
||||
claims: jose.Claims{
|
||||
"iss": validIss,
|
||||
"aud": validClientID,
|
||||
"iat": float64(now.Unix()),
|
||||
"exp": float64(tomorrow.Unix()),
|
||||
},
|
||||
ok: false,
|
||||
},
|
||||
// invalid 'sub' claim
|
||||
{
|
||||
claims: jose.Claims{
|
||||
"iss": validIss,
|
||||
"sub": "INVALID",
|
||||
"aud": validClientID,
|
||||
"iat": float64(now.Unix()),
|
||||
"exp": float64(tomorrow.Unix()),
|
||||
},
|
||||
ok: false,
|
||||
},
|
||||
// missing 'aud' claim
|
||||
{
|
||||
claims: jose.Claims{
|
||||
"iss": validIss,
|
||||
"sub": validClientID,
|
||||
"iat": float64(now.Unix()),
|
||||
"exp": float64(tomorrow.Unix()),
|
||||
},
|
||||
ok: false,
|
||||
},
|
||||
// invalid 'aud' claim
|
||||
{
|
||||
claims: jose.Claims{
|
||||
"iss": validIss,
|
||||
"sub": validClientID,
|
||||
"aud": "INVALID",
|
||||
"iat": float64(now.Unix()),
|
||||
"exp": float64(tomorrow.Unix()),
|
||||
},
|
||||
ok: false,
|
||||
},
|
||||
// invalid 'aud' claim
|
||||
{
|
||||
claims: jose.Claims{
|
||||
"iss": validIss,
|
||||
"sub": validClientID,
|
||||
"aud": []string{"INVALID1", "INVALID2"},
|
||||
"iat": float64(now.Unix()),
|
||||
"exp": float64(tomorrow.Unix()),
|
||||
},
|
||||
ok: false,
|
||||
},
|
||||
// invalid 'aud' type
|
||||
{
|
||||
claims: jose.Claims{
|
||||
"iss": validIss,
|
||||
"sub": validClientID,
|
||||
"aud": struct{}{},
|
||||
"iat": float64(now.Unix()),
|
||||
"exp": float64(tomorrow.Unix()),
|
||||
},
|
||||
ok: false,
|
||||
},
|
||||
// expired
|
||||
{
|
||||
claims: jose.Claims{
|
||||
"iss": validIss,
|
||||
"sub": validClientID,
|
||||
"aud": validClientID,
|
||||
"iat": float64(now.Unix()),
|
||||
"exp": float64(now.Unix()),
|
||||
},
|
||||
ok: false,
|
||||
},
|
||||
}
|
||||
|
||||
for i, tt := range tests {
|
||||
jwt, err := jose.NewJWT(header, tt.claims)
|
||||
if err != nil {
|
||||
t.Fatalf("case %d: Failed to generate JWT, error=%v", i, err)
|
||||
}
|
||||
|
||||
got, err := VerifyClientClaims(jwt, validIss)
|
||||
if tt.ok {
|
||||
if err != nil {
|
||||
t.Errorf("case %d: unexpected error, err=%v", i, err)
|
||||
}
|
||||
if got != validClientID {
|
||||
t.Errorf("case %d: incorrect client ID, want=%s, got=%s", i, validClientID, got)
|
||||
}
|
||||
} else if err == nil {
|
||||
t.Errorf("case %d: expected error but err is nil", i)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestJWTVerifier(t *testing.T) {
|
||||
iss := "http://example.com"
|
||||
now := time.Now()
|
||||
future12 := now.Add(12 * time.Hour)
|
||||
past36 := now.Add(-36 * time.Hour)
|
||||
past12 := now.Add(-12 * time.Hour)
|
||||
|
||||
priv1, err := key.GeneratePrivateKey()
|
||||
if err != nil {
|
||||
t.Fatalf("failed to generate private key, error=%v", err)
|
||||
}
|
||||
pk1 := *key.NewPublicKey(priv1.JWK())
|
||||
|
||||
priv2, err := key.GeneratePrivateKey()
|
||||
if err != nil {
|
||||
t.Fatalf("failed to generate private key, error=%v", err)
|
||||
}
|
||||
pk2 := *key.NewPublicKey(priv2.JWK())
|
||||
|
||||
newJWT := func(issuer, subject string, aud interface{}, issuedAt, exp time.Time, signer jose.Signer) jose.JWT {
|
||||
jwt, err := jose.NewSignedJWT(NewClaims(issuer, subject, aud, issuedAt, exp), signer)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
return *jwt
|
||||
}
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
verifier JWTVerifier
|
||||
jwt jose.JWT
|
||||
wantErr bool
|
||||
}{
|
||||
{
|
||||
name: "JWT signed with available key",
|
||||
verifier: JWTVerifier{
|
||||
issuer: "example.com",
|
||||
clientID: "XXX",
|
||||
syncFunc: func() error { return nil },
|
||||
keysFunc: func() []key.PublicKey {
|
||||
return []key.PublicKey{pk1}
|
||||
},
|
||||
},
|
||||
jwt: newJWT(iss, "XXX", "XXX", past12, future12, priv1.Signer()),
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "JWT signed with available key, with bad claims",
|
||||
verifier: JWTVerifier{
|
||||
issuer: "example.com",
|
||||
clientID: "XXX",
|
||||
syncFunc: func() error { return nil },
|
||||
keysFunc: func() []key.PublicKey {
|
||||
return []key.PublicKey{pk1}
|
||||
},
|
||||
},
|
||||
jwt: newJWT(iss, "XXX", "YYY", past12, future12, priv1.Signer()),
|
||||
wantErr: true,
|
||||
},
|
||||
|
||||
{
|
||||
name: "JWT signed with available key",
|
||||
verifier: JWTVerifier{
|
||||
issuer: "example.com",
|
||||
clientID: "XXX",
|
||||
syncFunc: func() error { return nil },
|
||||
keysFunc: func() []key.PublicKey {
|
||||
return []key.PublicKey{pk1}
|
||||
},
|
||||
},
|
||||
jwt: newJWT(iss, "XXX", []string{"YYY", "ZZZ"}, past12, future12, priv1.Signer()),
|
||||
wantErr: true,
|
||||
},
|
||||
|
||||
{
|
||||
name: "expired JWT signed with available key",
|
||||
verifier: JWTVerifier{
|
||||
issuer: "example.com",
|
||||
clientID: "XXX",
|
||||
syncFunc: func() error { return nil },
|
||||
keysFunc: func() []key.PublicKey {
|
||||
return []key.PublicKey{pk1}
|
||||
},
|
||||
},
|
||||
jwt: newJWT(iss, "XXX", "XXX", past36, past12, priv1.Signer()),
|
||||
wantErr: true,
|
||||
},
|
||||
|
||||
{
|
||||
name: "JWT signed with unrecognized key, verifiable after sync",
|
||||
verifier: JWTVerifier{
|
||||
issuer: "example.com",
|
||||
clientID: "XXX",
|
||||
syncFunc: func() error { return nil },
|
||||
keysFunc: func() func() []key.PublicKey {
|
||||
var i int
|
||||
return func() []key.PublicKey {
|
||||
defer func() { i++ }()
|
||||
return [][]key.PublicKey{
|
||||
[]key.PublicKey{pk1},
|
||||
[]key.PublicKey{pk2},
|
||||
}[i]
|
||||
}
|
||||
}(),
|
||||
},
|
||||
jwt: newJWT(iss, "XXX", "XXX", past36, future12, priv2.Signer()),
|
||||
wantErr: false,
|
||||
},
|
||||
|
||||
{
|
||||
name: "JWT signed with unrecognized key, not verifiable after sync",
|
||||
verifier: JWTVerifier{
|
||||
issuer: "example.com",
|
||||
clientID: "XXX",
|
||||
syncFunc: func() error { return nil },
|
||||
keysFunc: func() []key.PublicKey {
|
||||
return []key.PublicKey{pk1}
|
||||
},
|
||||
},
|
||||
jwt: newJWT(iss, "XXX", "XXX", past12, future12, priv2.Signer()),
|
||||
wantErr: true,
|
||||
},
|
||||
|
||||
{
|
||||
name: "verifier gets no keys from keysFunc, still not verifiable after sync",
|
||||
verifier: JWTVerifier{
|
||||
issuer: "example.com",
|
||||
clientID: "XXX",
|
||||
syncFunc: func() error { return nil },
|
||||
keysFunc: func() []key.PublicKey {
|
||||
return []key.PublicKey{}
|
||||
},
|
||||
},
|
||||
jwt: newJWT(iss, "XXX", "XXX", past12, future12, priv1.Signer()),
|
||||
wantErr: true,
|
||||
},
|
||||
|
||||
{
|
||||
name: "verifier gets no keys from keysFunc, verifiable after sync",
|
||||
verifier: JWTVerifier{
|
||||
issuer: "example.com",
|
||||
clientID: "XXX",
|
||||
syncFunc: func() error { return nil },
|
||||
keysFunc: func() func() []key.PublicKey {
|
||||
var i int
|
||||
return func() []key.PublicKey {
|
||||
defer func() { i++ }()
|
||||
return [][]key.PublicKey{
|
||||
[]key.PublicKey{},
|
||||
[]key.PublicKey{pk2},
|
||||
}[i]
|
||||
}
|
||||
}(),
|
||||
},
|
||||
jwt: newJWT(iss, "XXX", "XXX", past12, future12, priv2.Signer()),
|
||||
wantErr: false,
|
||||
},
|
||||
|
||||
{
|
||||
name: "JWT signed with available key, 'aud' is a string array",
|
||||
verifier: JWTVerifier{
|
||||
issuer: "example.com",
|
||||
clientID: "XXX",
|
||||
syncFunc: func() error { return nil },
|
||||
keysFunc: func() []key.PublicKey {
|
||||
return []key.PublicKey{pk1}
|
||||
},
|
||||
},
|
||||
jwt: newJWT(iss, "XXX", []string{"ZZZ", "XXX"}, past12, future12, priv1.Signer()),
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "invalid issuer claim shouldn't trigger sync",
|
||||
verifier: JWTVerifier{
|
||||
issuer: "example.com",
|
||||
clientID: "XXX",
|
||||
syncFunc: func() error {
|
||||
t.Errorf("invalid issuer claim shouldn't trigger a sync")
|
||||
return nil
|
||||
},
|
||||
keysFunc: func() func() []key.PublicKey {
|
||||
var i int
|
||||
return func() []key.PublicKey {
|
||||
defer func() { i++ }()
|
||||
return [][]key.PublicKey{
|
||||
[]key.PublicKey{},
|
||||
[]key.PublicKey{pk2},
|
||||
}[i]
|
||||
}
|
||||
}(),
|
||||
},
|
||||
jwt: newJWT("invalid-issuer", "XXX", []string{"ZZZ", "XXX"}, past12, future12, priv2.Signer()),
|
||||
wantErr: true,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
err := tt.verifier.Verify(tt.jwt)
|
||||
if tt.wantErr && (err == nil) {
|
||||
t.Errorf("case %q: wanted non-nil error", tt.name)
|
||||
} else if !tt.wantErr && (err != nil) {
|
||||
t.Errorf("case %q: wanted nil error, got %v", tt.name, err)
|
||||
}
|
||||
}
|
||||
}
|
Reference in New Issue
Block a user