connector/saml: clean up SAML verification logic and comments
This commit is contained in:
@@ -2,10 +2,8 @@
|
||||
package saml
|
||||
|
||||
import (
|
||||
"crypto/rand"
|
||||
"crypto/x509"
|
||||
"encoding/base64"
|
||||
"encoding/hex"
|
||||
"encoding/pem"
|
||||
"encoding/xml"
|
||||
"errors"
|
||||
@@ -270,12 +268,22 @@ func (p *provider) POSTData(s connector.Scopes, id string) (action, value string
|
||||
return p.ssoURL, base64.StdEncoding.EncodeToString(data), nil
|
||||
}
|
||||
|
||||
// HandlePOST interprets a request from a SAML provider attempting to verify a
|
||||
// user's identity.
|
||||
//
|
||||
// The steps taken are:
|
||||
//
|
||||
// * Verify signature on XML document (or verify sig on assertion elements).
|
||||
// * Verify various parts of the Assertion element. Conditions, audience, etc.
|
||||
// * Map the Assertion's attribute elements to user info.
|
||||
//
|
||||
func (p *provider) HandlePOST(s connector.Scopes, samlResponse, inResponseTo string) (ident connector.Identity, err error) {
|
||||
rawResp, err := base64.StdEncoding.DecodeString(samlResponse)
|
||||
if err != nil {
|
||||
return ident, fmt.Errorf("decode response: %v", err)
|
||||
}
|
||||
|
||||
// Root element is allowed to not be signed if the Assertion element is.
|
||||
rootElementSigned := true
|
||||
if p.validator != nil {
|
||||
rawResp, rootElementSigned, err = verifyResponseSig(p.validator, rawResp)
|
||||
@@ -289,6 +297,8 @@ func (p *provider) HandlePOST(s connector.Scopes, samlResponse, inResponseTo str
|
||||
return ident, fmt.Errorf("unmarshal response: %v", err)
|
||||
}
|
||||
|
||||
// If the root element isn't signed, there's no reason to inspect these
|
||||
// elements. They're not verified.
|
||||
if rootElementSigned {
|
||||
if p.ssoIssuer != "" && resp.Issuer != nil && resp.Issuer.Issuer != p.ssoIssuer {
|
||||
return ident, fmt.Errorf("expected Issuer value %s, got %s", p.ssoIssuer, resp.Issuer.Issuer)
|
||||
@@ -303,10 +313,14 @@ func (p *provider) HandlePOST(s connector.Scopes, samlResponse, inResponseTo str
|
||||
// Destination is optional.
|
||||
if resp.Destination != "" && resp.Destination != p.redirectURI {
|
||||
return ident, fmt.Errorf("expected destination %q got %q", p.redirectURI, resp.Destination)
|
||||
|
||||
}
|
||||
|
||||
if err = p.validateStatus(&resp); err != nil {
|
||||
// Status is a required element.
|
||||
if resp.Status == nil {
|
||||
return ident, fmt.Errorf("Response did not contain a Status element")
|
||||
}
|
||||
|
||||
if err = p.validateStatus(resp.Status); err != nil {
|
||||
return ident, err
|
||||
}
|
||||
}
|
||||
@@ -315,16 +329,25 @@ func (p *provider) HandlePOST(s connector.Scopes, samlResponse, inResponseTo str
|
||||
if assertion == nil {
|
||||
return ident, fmt.Errorf("response did not contain an assertion")
|
||||
}
|
||||
|
||||
// Subject is usually optional, but we need it for the user ID, so complain
|
||||
// if it's not present.
|
||||
subject := assertion.Subject
|
||||
if subject == nil {
|
||||
return ident, fmt.Errorf("response did not contain a subject")
|
||||
}
|
||||
|
||||
if err = p.validateConditions(assertion); err != nil {
|
||||
// Validate that the response is to the request we originally sent.
|
||||
if err = p.validateSubject(subject, inResponseTo); err != nil {
|
||||
return ident, err
|
||||
}
|
||||
if err = p.validateSubjectConfirmation(subject); err != nil {
|
||||
return ident, err
|
||||
|
||||
// Conditions element is optional, but must be validated if present.
|
||||
if assertion.Conditions != nil {
|
||||
// Validate that dex is the intended audience of this response.
|
||||
if err = p.validateConditions(assertion.Conditions); err != nil {
|
||||
return ident, err
|
||||
}
|
||||
}
|
||||
|
||||
switch {
|
||||
@@ -336,53 +359,57 @@ func (p *provider) HandlePOST(s connector.Scopes, samlResponse, inResponseTo str
|
||||
return ident, fmt.Errorf("subject does not contain an NameID element")
|
||||
}
|
||||
|
||||
// After verifying the assertion, map data in the attribute statements to
|
||||
// various user info.
|
||||
attributes := assertion.AttributeStatement
|
||||
if attributes == nil {
|
||||
return ident, fmt.Errorf("response did not contain a AttributeStatement")
|
||||
}
|
||||
|
||||
// Grab the email.
|
||||
if ident.Email, _ = attributes.get(p.emailAttr); ident.Email == "" {
|
||||
return ident, fmt.Errorf("no attribute with name %q: %s", p.emailAttr, attributes.names())
|
||||
}
|
||||
// TODO(ericchiang): Does SAML have an email_verified equivalent?
|
||||
ident.EmailVerified = true
|
||||
|
||||
// Grab the username.
|
||||
if ident.Username, _ = attributes.get(p.usernameAttr); ident.Username == "" {
|
||||
return ident, fmt.Errorf("no attribute with name %q: %s", p.usernameAttr, attributes.names())
|
||||
}
|
||||
|
||||
if s.Groups && p.groupsAttr != "" {
|
||||
if p.groupsDelim != "" {
|
||||
groupsStr, ok := attributes.get(p.groupsAttr)
|
||||
if !ok {
|
||||
return ident, fmt.Errorf("no attribute with name %q: %s", p.groupsAttr, attributes.names())
|
||||
}
|
||||
// TODO(ericchiang): Do we need to further trim whitespace?
|
||||
ident.Groups = strings.Split(groupsStr, p.groupsDelim)
|
||||
} else {
|
||||
groups, ok := attributes.all(p.groupsAttr)
|
||||
if !ok {
|
||||
return ident, fmt.Errorf("no attribute with name %q: %s", p.groupsAttr, attributes.names())
|
||||
}
|
||||
ident.Groups = groups
|
||||
}
|
||||
if !s.Groups || p.groupsAttr == "" {
|
||||
// Groups not requested or not configured. We're done.
|
||||
return ident, nil
|
||||
}
|
||||
|
||||
// Grab the groups.
|
||||
if p.groupsDelim != "" {
|
||||
groupsStr, ok := attributes.get(p.groupsAttr)
|
||||
if !ok {
|
||||
return ident, fmt.Errorf("no attribute with name %q: %s", p.groupsAttr, attributes.names())
|
||||
}
|
||||
// TODO(ericchiang): Do we need to further trim whitespace?
|
||||
ident.Groups = strings.Split(groupsStr, p.groupsDelim)
|
||||
} else {
|
||||
groups, ok := attributes.all(p.groupsAttr)
|
||||
if !ok {
|
||||
return ident, fmt.Errorf("no attribute with name %q: %s", p.groupsAttr, attributes.names())
|
||||
}
|
||||
ident.Groups = groups
|
||||
}
|
||||
return ident, nil
|
||||
}
|
||||
|
||||
// Validate that the StatusCode of the Response is success.
|
||||
// Otherwise return a human readable message to the end user
|
||||
func (p *provider) validateStatus(resp *response) error {
|
||||
// Status is mandatory in the Response type
|
||||
status := resp.Status
|
||||
if status == nil {
|
||||
return fmt.Errorf("response did not contain a Status")
|
||||
}
|
||||
// validateStatus verifies that the response has a good status code or
|
||||
// formats a human readble error based on the bad status.
|
||||
func (p *provider) validateStatus(status *status) error {
|
||||
// StatusCode is mandatory in the Status type
|
||||
statusCode := status.StatusCode
|
||||
if statusCode == nil {
|
||||
return fmt.Errorf("response did not contain a StatusCode")
|
||||
}
|
||||
|
||||
if statusCode.Value != statusCodeSuccess {
|
||||
parts := strings.Split(statusCode.Value, ":")
|
||||
lastPart := parts[len(parts)-1]
|
||||
@@ -396,96 +423,107 @@ func (p *provider) validateStatus(resp *response) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
// Multiple subject SubjectConfirmation can be in the assertion
|
||||
// and at least one SubjectConfirmation must be valid.
|
||||
// validateSubject ensures the response is to the request we expect.
|
||||
//
|
||||
// This is described in the spec "Profiles for the OASIS Security
|
||||
// Assertion Markup Language" in section 3.3 Bearer.
|
||||
// see https://www.oasis-open.org/committees/download.php/35389/sstc-saml-profiles-errata-2.0-wd-06-diff.pdf
|
||||
func (p *provider) validateSubjectConfirmation(subject *subject) error {
|
||||
validSubjectConfirmation := false
|
||||
subjectConfirmations := subject.SubjectConfirmations
|
||||
if subjectConfirmations != nil && len(subjectConfirmations) > 0 {
|
||||
for _, subjectConfirmation := range subjectConfirmations {
|
||||
// skip if method is wrong
|
||||
method := subjectConfirmation.Method
|
||||
if method != "" && method != subjectConfirmationMethodBearer {
|
||||
continue
|
||||
//
|
||||
// Some of these fields are optional, but we're going to be strict here since
|
||||
// we have no other way of guarenteeing that this is actually the response to
|
||||
// the request we expect.
|
||||
func (p *provider) validateSubject(subject *subject, inResponseTo string) error {
|
||||
// Optional according to the spec, but again, we're going to be strict here.
|
||||
if len(subject.SubjectConfirmations) == 0 {
|
||||
return fmt.Errorf("Subject contained no SubjectConfrimations")
|
||||
}
|
||||
|
||||
var errs []error
|
||||
// One of these must match our assumptions, not all.
|
||||
for _, c := range subject.SubjectConfirmations {
|
||||
err := func() error {
|
||||
if c.Method != subjectConfirmationMethodBearer {
|
||||
return fmt.Errorf("unexpected subject confirmation method: %v", c.Method)
|
||||
}
|
||||
subjectConfirmationData := subjectConfirmation.SubjectConfirmationData
|
||||
if subjectConfirmationData == nil {
|
||||
continue
|
||||
|
||||
data := c.SubjectConfirmationData
|
||||
if data == nil {
|
||||
return fmt.Errorf("SubjectConfirmation contained no SubjectConfirmationData")
|
||||
}
|
||||
inResponseTo := subjectConfirmationData.InResponseTo
|
||||
if inResponseTo != "" {
|
||||
// TODO also validate InResponseTo if present
|
||||
if data.InResponseTo != inResponseTo {
|
||||
return fmt.Errorf("expected SubjectConfirmationData InResponseTo value %q, got %q", inResponseTo, data.InResponseTo)
|
||||
}
|
||||
// only validate that subjectConfirmationData is not expired
|
||||
|
||||
notBefore := time.Time(data.NotBefore)
|
||||
notOnOrAfter := time.Time(data.NotOnOrAfter)
|
||||
now := p.now()
|
||||
notOnOrAfter := time.Time(subjectConfirmationData.NotOnOrAfter)
|
||||
if !notOnOrAfter.IsZero() {
|
||||
if now.After(notOnOrAfter) {
|
||||
continue
|
||||
}
|
||||
if !notBefore.IsZero() && before(now, notBefore) {
|
||||
return fmt.Errorf("at %s got response that cannot be processed before %s", now, notBefore)
|
||||
}
|
||||
// validate recipient if present
|
||||
recipient := subjectConfirmationData.Recipient
|
||||
if recipient != "" && recipient != p.redirectURI {
|
||||
continue
|
||||
if !notOnOrAfter.IsZero() && after(now, notOnOrAfter) {
|
||||
return fmt.Errorf("at %s got response that cannot be processed because it expired at %s", now, notOnOrAfter)
|
||||
}
|
||||
validSubjectConfirmation = true
|
||||
if r := data.Recipient; r != "" && r != p.redirectURI {
|
||||
return fmt.Errorf("expected Recipient %q got %q", p.redirectURI, r)
|
||||
}
|
||||
return nil
|
||||
}()
|
||||
if err == nil {
|
||||
// Subject is valid.
|
||||
return nil
|
||||
}
|
||||
errs = append(errs, err)
|
||||
}
|
||||
if !validSubjectConfirmation {
|
||||
return fmt.Errorf("no valid SubjectConfirmation was found on this Response")
|
||||
|
||||
if len(errs) == 1 {
|
||||
return fmt.Errorf("failed to validate subject confirmation: %v", errs[0])
|
||||
}
|
||||
return nil
|
||||
return fmt.Errorf("failed to validate subject confirmation: %v", errs)
|
||||
}
|
||||
|
||||
// Validates the Conditions element and all of it's content
|
||||
// validationConditions ensures that dex is the intended audience
|
||||
// for the request, and not another service provider.
|
||||
//
|
||||
// See: https://docs.oasis-open.org/security/saml/v2.0/saml-core-2.0-os.pdf
|
||||
// "2.3.3 Element <Assertion>"
|
||||
func (p *provider) validateConditions(assertion *assertion) error {
|
||||
// Checks if a Conditions element exists
|
||||
conditions := assertion.Conditions
|
||||
if conditions == nil {
|
||||
return nil
|
||||
}
|
||||
// Validates Assertion timestamps
|
||||
func (p *provider) validateConditions(conditions *conditions) error {
|
||||
// Ensure the conditions haven't expired.
|
||||
now := p.now()
|
||||
notBefore := time.Time(conditions.NotBefore)
|
||||
if !notBefore.IsZero() {
|
||||
if now.Add(allowedClockDrift).Before(notBefore) {
|
||||
return fmt.Errorf("at %s got response that cannot be processed before %s", now, notBefore)
|
||||
}
|
||||
if !notBefore.IsZero() && before(now, notBefore) {
|
||||
return fmt.Errorf("at %s got response that cannot be processed before %s", now, notBefore)
|
||||
}
|
||||
|
||||
notOnOrAfter := time.Time(conditions.NotOnOrAfter)
|
||||
if !notOnOrAfter.IsZero() {
|
||||
if now.After(notOnOrAfter.Add(allowedClockDrift)) {
|
||||
return fmt.Errorf("at %s got response that cannot be processed because it expired at %s", now, notOnOrAfter)
|
||||
if !notOnOrAfter.IsZero() && after(now, notOnOrAfter) {
|
||||
return fmt.Errorf("at %s got response that cannot be processed because it expired at %s", now, notOnOrAfter)
|
||||
}
|
||||
|
||||
// Sometimes, dex's issuer string can be different than the redirect URI,
|
||||
// but if dex's issuer isn't explicitly provided assume the redirect URI.
|
||||
expAud := p.entityIssuer
|
||||
if expAud == "" {
|
||||
expAud = p.redirectURI
|
||||
}
|
||||
|
||||
// AudienceRestriction elements indicate the intended audience(s) of an
|
||||
// assertion. If dex isn't in these audiences, reject the assertion.
|
||||
//
|
||||
// Note that if there are multiple AudienceRestriction elements, each must
|
||||
// individually contain dex in their audience list.
|
||||
for _, r := range conditions.AudienceRestriction {
|
||||
values := make([]string, len(r.Audiences))
|
||||
issuerInAudiences := false
|
||||
for i, aud := range r.Audiences {
|
||||
if aud.Value == expAud {
|
||||
issuerInAudiences = true
|
||||
break
|
||||
}
|
||||
values[i] = aud.Value
|
||||
}
|
||||
}
|
||||
// Validates audience
|
||||
audienceValue := p.entityIssuer
|
||||
if audienceValue == "" {
|
||||
audienceValue = p.redirectURI
|
||||
}
|
||||
audienceRestriction := conditions.AudienceRestriction
|
||||
if audienceRestriction != nil {
|
||||
audiences := audienceRestriction.Audiences
|
||||
if audiences != nil && len(audiences) > 0 {
|
||||
values := make([]string, len(audiences))
|
||||
issuerInAudiences := false
|
||||
for i, audience := range audiences {
|
||||
if audience.Value == audienceValue {
|
||||
issuerInAudiences = true
|
||||
break
|
||||
}
|
||||
values[i] = audience.Value
|
||||
}
|
||||
if !issuerInAudiences {
|
||||
return fmt.Errorf("required audience %s was not in Response audiences %s", audienceValue, values)
|
||||
}
|
||||
|
||||
if !issuerInAudiences {
|
||||
return fmt.Errorf("required audience %s was not in Response audiences %s", expAud, values)
|
||||
}
|
||||
}
|
||||
return nil
|
||||
@@ -544,24 +582,14 @@ func verifyResponseSig(validator *dsig.ValidationContext, data []byte) (signed [
|
||||
return signed, false, err
|
||||
}
|
||||
|
||||
func uuidv4() string {
|
||||
u := make([]byte, 16)
|
||||
if _, err := rand.Read(u); err != nil {
|
||||
panic(err)
|
||||
}
|
||||
u[6] = (u[6] | 0x40) & 0x4F
|
||||
u[8] = (u[8] | 0x80) & 0xBF
|
||||
|
||||
r := make([]byte, 36)
|
||||
r[8] = '-'
|
||||
r[13] = '-'
|
||||
r[18] = '-'
|
||||
r[23] = '-'
|
||||
hex.Encode(r, u[0:4])
|
||||
hex.Encode(r[9:], u[4:6])
|
||||
hex.Encode(r[14:], u[6:8])
|
||||
hex.Encode(r[19:], u[8:10])
|
||||
hex.Encode(r[24:], u[10:])
|
||||
|
||||
return string(r)
|
||||
// before determines if a given time is before the current time, with an
|
||||
// allowed clock drift.
|
||||
func before(now, notBefore time.Time) bool {
|
||||
return now.Add(allowedClockDrift).Before(notBefore)
|
||||
}
|
||||
|
||||
// after determines if a given time is after the current time, with an
|
||||
// allowed clock drift.
|
||||
func after(now, notOnOrAfter time.Time) bool {
|
||||
return now.After(notOnOrAfter.Add(allowedClockDrift))
|
||||
}
|
||||
|
Reference in New Issue
Block a user