*: validate InResponseTo SAML response field and make issuer optional
This commit is contained in:
		| @@ -66,20 +66,25 @@ type CallbackConnector interface { | ||||
| } | ||||
|  | ||||
| // SAMLConnector represents SAML connectors which implement the HTTP POST binding. | ||||
| //  RelayState is handled by the server. | ||||
| // | ||||
| // RelayState is handled by the server. | ||||
| // See: https://docs.oasis-open.org/security/saml/v2.0/saml-bindings-2.0-os.pdf | ||||
| // "3.5 HTTP POST Binding" | ||||
| type SAMLConnector interface { | ||||
| 	// POSTData returns an encoded SAML request and SSO URL for the server to | ||||
| 	// render a POST form with. | ||||
| 	POSTData(s Scopes) (sooURL, samlRequest string, err error) | ||||
| 	// | ||||
| 	// POSTData should encode the provided request ID in the returned serialized | ||||
| 	// SAML request. | ||||
| 	POSTData(s Scopes, requestID string) (sooURL, samlRequest string, err error) | ||||
|  | ||||
| 	// TODO(ericchiang): Provide expected "InResponseTo" ID value. | ||||
| 	// HandlePOST decodes, verifies, and maps attributes from the SAML response. | ||||
| 	// It passes the expected value of the "InResponseTo" response field, which | ||||
| 	// the connector must ensure matches the response value. | ||||
| 	// | ||||
| 	// See: https://www.oasis-open.org/committees/download.php/35711/sstc-saml-core-errata-2.0-wd-06-diff.pdf | ||||
| 	// "3.2.2 Complex Type StatusResponseType" | ||||
|  | ||||
| 	// HandlePOST decodes, verifies, and maps attributes from the SAML response. | ||||
| 	HandlePOST(s Scopes, samlResponse string) (identity Identity, err error) | ||||
| 	HandlePOST(s Scopes, samlResponse, inResponseTo string) (identity Identity, err error) | ||||
| } | ||||
|  | ||||
| // RefreshConnector is a connector that can update the client claims. | ||||
|   | ||||
| @@ -135,7 +135,6 @@ func (c *Config) openConnector(logger logrus.FieldLogger) (interface { | ||||
| 	requiredFields := []struct { | ||||
| 		name, val string | ||||
| 	}{ | ||||
| 		{"issuer", c.Issuer}, | ||||
| 		{"ssoURL", c.SSOURL}, | ||||
| 		{"usernameAttr", c.UsernameAttr}, | ||||
| 		{"emailAttr", c.EmailAttr}, | ||||
| @@ -240,7 +239,7 @@ type provider struct { | ||||
| 	logger logrus.FieldLogger | ||||
| } | ||||
|  | ||||
| func (p *provider) POSTData(s connector.Scopes) (action, value string, err error) { | ||||
| func (p *provider) POSTData(s connector.Scopes, id string) (action, value string, err error) { | ||||
|  | ||||
| 	// NOTE(ericchiang): If we can't follow up with the identity provider, can we | ||||
| 	// support refresh tokens? | ||||
| @@ -250,28 +249,32 @@ func (p *provider) POSTData(s connector.Scopes) (action, value string, err error | ||||
|  | ||||
| 	r := &authnRequest{ | ||||
| 		ProtocolBinding: bindingPOST, | ||||
| 		ID:              "_" + uuidv4(), | ||||
| 		ID:              id, | ||||
| 		IssueInstant:    xmlTime(p.now()), | ||||
| 		Destination:     p.ssoURL, | ||||
| 		Issuer: &issuer{ | ||||
| 			Issuer: p.issuer, | ||||
| 		}, | ||||
| 		NameIDPolicy: &nameIDPolicy{ | ||||
| 			AllowCreate: true, | ||||
| 			Format:      p.nameIDPolicyFormat, | ||||
| 		}, | ||||
| 		AssertionConsumerServiceURL: p.redirectURI, | ||||
| 	} | ||||
| 	if p.issuer != "" { | ||||
| 		// Issuer for the request is optional. For example, okta always ignores | ||||
| 		// this value. | ||||
| 		r.Issuer = &issuer{Issuer: p.issuer} | ||||
| 	} | ||||
|  | ||||
| 	data, err := xml.MarshalIndent(r, "", "  ") | ||||
| 	if err != nil { | ||||
| 		return "", "", fmt.Errorf("marshal authn request: %v", err) | ||||
| 	} | ||||
|  | ||||
| 	// See: https://docs.oasis-open.org/security/saml/v2.0/saml-bindings-2.0-os.pdf | ||||
| 	// "3.5.4 Message Encoding" | ||||
| 	return p.ssoURL, base64.StdEncoding.EncodeToString(data), nil | ||||
| } | ||||
|  | ||||
| func (p *provider) HandlePOST(s connector.Scopes, samlResponse string) (ident connector.Identity, err error) { | ||||
| func (p *provider) HandlePOST(s connector.Scopes, samlResponse, inResponseTo string) (ident connector.Identity, err error) { | ||||
| 	rawResp, err := base64.StdEncoding.DecodeString(samlResponse) | ||||
| 	if err != nil { | ||||
| 		return ident, fmt.Errorf("decode response: %v", err) | ||||
| @@ -287,6 +290,17 @@ func (p *provider) HandlePOST(s connector.Scopes, samlResponse string) (ident co | ||||
| 		return ident, fmt.Errorf("unmarshal response: %v", err) | ||||
| 	} | ||||
|  | ||||
| 	if p.issuer != "" && resp.Issuer != nil && resp.Issuer.Issuer != p.issuer { | ||||
| 		return ident, fmt.Errorf("expected Issuer value %s, got %s", p.issuer, resp.Issuer.Issuer) | ||||
| 	} | ||||
|  | ||||
| 	// Verify InResponseTo value matches the expected ID associated with | ||||
| 	// the RelayState. | ||||
| 	if resp.InResponseTo != inResponseTo { | ||||
| 		return ident, fmt.Errorf("expected InResponseTo value %s, got %s", inResponseTo, resp.InResponseTo) | ||||
| 	} | ||||
|  | ||||
| 	// Destination is optional. | ||||
| 	if resp.Destination != "" && resp.Destination != p.redirectURI { | ||||
| 		return ident, fmt.Errorf("expected destination %q got %q", p.redirectURI, resp.Destination) | ||||
|  | ||||
| @@ -327,26 +341,26 @@ func (p *provider) HandlePOST(s connector.Scopes, samlResponse string) (ident co | ||||
| 	} | ||||
|  | ||||
| 	if ident.Email, _ = attributes.get(p.emailAttr); ident.Email == "" { | ||||
| 		return ident, fmt.Errorf("no attribute with name %q", p.emailAttr) | ||||
| 		return ident, fmt.Errorf("no attribute with name %q: %s", p.emailAttr, attributes.names()) | ||||
| 	} | ||||
| 	ident.EmailVerified = true | ||||
|  | ||||
| 	if ident.Username, _ = attributes.get(p.usernameAttr); ident.Username == "" { | ||||
| 		return ident, fmt.Errorf("no attribute with name %q", p.usernameAttr) | ||||
| 		return ident, fmt.Errorf("no attribute with name %q: %s", p.usernameAttr, attributes.names()) | ||||
| 	} | ||||
|  | ||||
| 	if s.Groups && p.groupsAttr != "" { | ||||
| 		if p.groupsDelim != "" { | ||||
| 			groupsStr, ok := attributes.get(p.groupsAttr) | ||||
| 			if !ok { | ||||
| 				return ident, fmt.Errorf("no attribute with name %q", p.groupsAttr) | ||||
| 				return ident, fmt.Errorf("no attribute with name %q: %s", p.groupsAttr, attributes.names()) | ||||
| 			} | ||||
| 			// TODO(ericchiang): Do we need to further trim whitespace? | ||||
| 			ident.Groups = strings.Split(groupsStr, p.groupsDelim) | ||||
| 		} else { | ||||
| 			groups, ok := attributes.all(p.groupsAttr) | ||||
| 			if !ok { | ||||
| 				return ident, fmt.Errorf("no attribute with name %q", p.groupsAttr) | ||||
| 				return ident, fmt.Errorf("no attribute with name %q: %s", p.groupsAttr, attributes.names()) | ||||
| 			} | ||||
| 			ident.Groups = groups | ||||
| 		} | ||||
| @@ -427,6 +441,9 @@ func (p *provider) validateSubjectConfirmation(subject *subject) error { | ||||
| } | ||||
|  | ||||
| // Validates the Conditions element and all of it's content | ||||
| // | ||||
| // See: https://docs.oasis-open.org/security/saml/v2.0/saml-core-2.0-os.pdf | ||||
| // "2.3.3 Element <Assertion>" | ||||
| func (p *provider) validateConditions(assertion *assertion) error { | ||||
| 	// Checks if a Conditions element exists | ||||
| 	conditions := assertion.Conditions | ||||
| @@ -452,15 +469,17 @@ func (p *provider) validateConditions(assertion *assertion) error { | ||||
| 	if audienceRestriction != nil { | ||||
| 		audiences := audienceRestriction.Audiences | ||||
| 		if audiences != nil && len(audiences) > 0 { | ||||
| 			values := make([]string, len(audiences)) | ||||
| 			issuerInAudiences := false | ||||
| 			for _, audience := range audiences { | ||||
| 				if audience.Value == p.issuer { | ||||
| 			for i, audience := range audiences { | ||||
| 				if audience.Value == p.redirectURI { | ||||
| 					issuerInAudiences = true | ||||
| 					break | ||||
| 				} | ||||
| 				values[i] = audience.Value | ||||
| 			} | ||||
| 			if !issuerInAudiences { | ||||
| 				return fmt.Errorf("required audience %s was not in Response audiences %s", p.issuer, audiences) | ||||
| 				return fmt.Errorf("required audience %s was not in Response audiences %s", p.redirectURI, values) | ||||
| 			} | ||||
| 		} | ||||
| 	} | ||||
|   | ||||
| @@ -18,8 +18,11 @@ import ( | ||||
| ) | ||||
|  | ||||
| const ( | ||||
| 	defaultIssuer      = "http://localhost:5556/dex/callback" | ||||
| 	defaultIssuer      = "http://www.okta.com/exk91cb99lKkKSYoy0h7" | ||||
| 	defaultRedirectURI = "http://localhost:5556/dex/callback" | ||||
|  | ||||
| 	// Response ID embedded in our testdata. | ||||
| 	testDataResponseID = "_fd1b3ef9-ec09-44a7-a66b-0d39c250f6a0" | ||||
| ) | ||||
|  | ||||
| func loadCert(ca string) (*x509.Certificate, error) { | ||||
| @@ -109,7 +112,7 @@ func TestHandlePOST(t *testing.T) { | ||||
| 	if err != nil { | ||||
| 		t.Fatal(err) | ||||
| 	} | ||||
| 	ident, err := p.HandlePOST(scopes, base64.StdEncoding.EncodeToString(data)) | ||||
| 	ident, err := p.HandlePOST(scopes, base64.StdEncoding.EncodeToString(data), testDataResponseID) | ||||
| 	if err != nil { | ||||
| 		t.Fatal(err) | ||||
| 	} | ||||
| @@ -254,12 +257,12 @@ func TestValidateConditions(t *testing.T) { | ||||
| 	if err != nil { | ||||
| 		t.Fatalf("validation of %q should succeed", "Conditions where notBefore is 15 seconds after now") | ||||
| 	} | ||||
| 	// Audiences contains the issuer | ||||
| 	validAudience := audience{Value: p.issuer} | ||||
| 	// Audiences contains the redirectURI | ||||
| 	validAudience := audience{Value: p.redirectURI} | ||||
| 	cond.AudienceRestriction.Audiences = []audience{validAudience} | ||||
| 	err = p.validateConditions(assert) | ||||
| 	if err != nil { | ||||
| 		t.Fatalf("validation of %q should succeed", "Audiences contains the issuer") | ||||
| 		t.Fatalf("validation of %q should succeed: %v", "Audiences contains the redirectURI", err) | ||||
| 	} | ||||
| 	// Audiences is not empty and not contains the issuer | ||||
| 	invalidAudience := audience{Value: "invalid"} | ||||
|   | ||||
| @@ -162,8 +162,9 @@ type authnContextClassRef struct { | ||||
| type response struct { | ||||
| 	XMLName xml.Name `xml:"urn:oasis:names:tc:SAML:2.0:protocol Response"` | ||||
|  | ||||
| 	ID      string      `xml:"ID,attr"` | ||||
| 	Version samlVersion `xml:"Version,attr"` | ||||
| 	ID           string      `xml:"ID,attr"` | ||||
| 	InResponseTo string      `xml:"InResponseTo,attr"` | ||||
| 	Version      samlVersion `xml:"Version,attr"` | ||||
|  | ||||
| 	Destination string `xml:"Destination,attr,omitempty"` | ||||
|  | ||||
| @@ -221,6 +222,16 @@ func (a *attributeStatement) all(name string) (s []string, ok bool) { | ||||
| 	return | ||||
| } | ||||
|  | ||||
| // names list the names of all attributes in the attribute statement. | ||||
| func (a *attributeStatement) names() []string { | ||||
| 	s := make([]string, len(a.Attributes)) | ||||
|  | ||||
| 	for i, attr := range a.Attributes { | ||||
| 		s[i] = attr.Name | ||||
| 	} | ||||
| 	return s | ||||
| } | ||||
|  | ||||
| type attribute struct { | ||||
| 	XMLName xml.Name `xml:"urn:oasis:names:tc:SAML:2.0:assertion Attribute"` | ||||
|  | ||||
|   | ||||
		Reference in New Issue
	
	Block a user