diff --git a/connector/saml/saml.go b/connector/saml/saml.go index 8210a641..2b092b3f 100644 --- a/connector/saml/saml.go +++ b/connector/saml/saml.go @@ -2,6 +2,7 @@ package saml import ( + "bytes" "crypto/x509" "encoding/base64" "encoding/pem" @@ -200,6 +201,10 @@ func (c *Config) openConnector(logger logrus.FieldLogger) (*provider, error) { for { block, caData = pem.Decode(caData) if block == nil { + caData = bytes.TrimSpace(caData) + if len(caData) > 0 { // if there's some left, we've been given bad caData + return nil, fmt.Errorf("parse cert: trailing data: %q", string(caData)) + } break } cert, err := x509.ParseCertificate(block.Bytes) diff --git a/connector/saml/saml_test.go b/connector/saml/saml_test.go index 2f319473..4497d059 100644 --- a/connector/saml/saml_test.go +++ b/connector/saml/saml_test.go @@ -337,6 +337,93 @@ func (r responseTest) run(t *testing.T) { } } +func TestConfigCAData(t *testing.T) { + logger := logrus.New() + validPEM, err := ioutil.ReadFile("testdata/ca.crt") + if err != nil { + t.Fatal(err) + } + valid2ndPEM, err := ioutil.ReadFile("testdata/okta-ca.pem") + if err != nil { + t.Fatal(err) + } + + // copy helper, avoid messing with the byte slice among different cases + c := func(bs []byte) []byte { + return append([]byte(nil), bs...) + } + + tests := []struct { + name string + caData []byte + wantErr bool + }{ + { + name: "one valid PEM entry", + caData: c(validPEM), + }, + { + name: "one valid PEM entry with trailing newline", + caData: append(c(validPEM), []byte("\n")...), + }, + { + name: "one valid PEM entry with trailing spaces", + caData: append(c(validPEM), []byte(" ")...), + }, + { + name: "one valid PEM entry with two trailing newlines", + caData: append(c(validPEM), []byte("\n\n")...), + }, + { + name: "two valid PEM entries", + caData: append(c(validPEM), c(valid2ndPEM)...), + }, + { + name: "two valid PEM entries with newline in between", + caData: append(append(c(validPEM), []byte("\n")...), c(valid2ndPEM)...), + }, + { + name: "two valid PEM entries with trailing newline", + caData: append(c(valid2ndPEM), append(c(validPEM), []byte("\n")...)...), + }, + { + name: "empty", + caData: []byte{}, + wantErr: true, + }, + { + name: "one valid PEM entry with trailing data", + caData: append(c(validPEM), []byte("yaddayadda")...), + wantErr: true, + }, + { + name: "one valid PEM entry with bad data before", + caData: append([]byte("yaddayadda"), c(validPEM)...), + wantErr: true, + }, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + c := Config{ + CAData: tc.caData, + UsernameAttr: "user", + EmailAttr: "email", + RedirectURI: "http://127.0.0.1:5556/dex/callback", + SSOURL: "http://foo.bar/", + } + _, err := (&c).Open("samltest", logger) + if tc.wantErr { + if err == nil { + t.Error("expected error, got nil") + } + } else if err != nil { + t.Errorf("expected no error, got %v", err) + } + }) + } +} + const ( defaultSSOIssuer = "http://www.okta.com/exk91cb99lKkKSYoy0h7" defaultRedirectURI = "http://localhost:5556/dex/callback"