From d31f6eabd467dd2b9ddb1aae184b7c67f44f5dd2 Mon Sep 17 00:00:00 2001 From: Andrew Block Date: Thu, 26 Dec 2019 20:32:12 -0600 Subject: [PATCH] Corrected logic in group verification --- connector/openshift/openshift.go | 14 ++++++++------ connector/openshift/openshift_test.go | 24 +++++++++++++++++++++--- 2 files changed, 29 insertions(+), 9 deletions(-) diff --git a/connector/openshift/openshift.go b/connector/openshift/openshift.go index e1974694..6ac5d044 100644 --- a/connector/openshift/openshift.go +++ b/connector/openshift/openshift.go @@ -165,10 +165,12 @@ func (c *openshiftConnector) HandleCallback(s connector.Scopes, r *http.Request) return identity, fmt.Errorf("openshift: get user: %v", err) } - validGroups := validateRequiredGroups(user.Groups, c.groups) + if len(c.groups) > 0 { + validGroups := validateAllowedGroups(user.Groups, c.groups) - if !validGroups { - return identity, fmt.Errorf("openshift: user %q is not in any of the required groups", user.Name) + if !validGroups { + return identity, fmt.Errorf("openshift: user %q is not in any of the required groups", user.Name) + } } identity = connector.Identity{ @@ -211,10 +213,10 @@ func (c *openshiftConnector) user(ctx context.Context, client *http.Client) (u u return u, err } -func validateRequiredGroups(userGroups, requiredGroups []string) bool { - matchingGroups := groups.Filter(userGroups, requiredGroups) +func validateAllowedGroups(userGroups, allowedGroups []string) bool { + matchingGroups := groups.Filter(userGroups, allowedGroups) - return len(requiredGroups) == len(matchingGroups) + return len(matchingGroups) != 0 } // newHTTPClient returns a new HTTP client diff --git a/connector/openshift/openshift_test.go b/connector/openshift/openshift_test.go index 2ed50150..316af60a 100644 --- a/connector/openshift/openshift_test.go +++ b/connector/openshift/openshift_test.go @@ -83,11 +83,29 @@ func TestGetUser(t *testing.T) { expectEquals(t, len(u.Groups), 1) } -func TestVerifyGroupFn(t *testing.T) { - requiredGroups := []string{"users"} +func TestVerifySingleGroupFn(t *testing.T) { + allowedGroups := []string{"users"} groupMembership := []string{"users", "org1"} - validGroupMembership := validateRequiredGroups(groupMembership, requiredGroups) + validGroupMembership := validateAllowedGroups(groupMembership, allowedGroups) + + expectEquals(t, validGroupMembership, true) +} + +func TestVerifySingleGroupFailureFn(t *testing.T) { + allowedGroups := []string{"admins"} + groupMembership := []string{"users"} + + validGroupMembership := validateAllowedGroups(groupMembership, allowedGroups) + + expectEquals(t, validGroupMembership, false) +} + +func TestVerifyMultipleGroupFn(t *testing.T) { + allowedGroups := []string{"users", "admins"} + groupMembership := []string{"users", "org1"} + + validGroupMembership := validateAllowedGroups(groupMembership, allowedGroups) expectEquals(t, validGroupMembership, true) }