connectors: refactor filter code into a helper package
I hope I didn't miss any :D Signed-off-by: Stephan Renatus <srenatus@chef.io>
This commit is contained in:
		| @@ -6,7 +6,6 @@ import ( | |||||||
| 	"encoding/json" | 	"encoding/json" | ||||||
| 	"errors" | 	"errors" | ||||||
| 	"fmt" | 	"fmt" | ||||||
| 	"github.com/dexidp/dex/pkg/log" |  | ||||||
| 	"io/ioutil" | 	"io/ioutil" | ||||||
| 	"net/http" | 	"net/http" | ||||||
| 	"sync" | 	"sync" | ||||||
| @@ -16,6 +15,8 @@ import ( | |||||||
| 	"golang.org/x/oauth2/bitbucket" | 	"golang.org/x/oauth2/bitbucket" | ||||||
|  |  | ||||||
| 	"github.com/dexidp/dex/connector" | 	"github.com/dexidp/dex/connector" | ||||||
|  | 	"github.com/dexidp/dex/pkg/groups" | ||||||
|  | 	"github.com/dexidp/dex/pkg/log" | ||||||
| ) | ) | ||||||
|  |  | ||||||
| const ( | const ( | ||||||
| @@ -350,7 +351,7 @@ func (b *bitbucketConnector) getGroups(ctx context.Context, client *http.Client, | |||||||
| 	} | 	} | ||||||
|  |  | ||||||
| 	if len(b.teams) > 0 { | 	if len(b.teams) > 0 { | ||||||
| 		filteredTeams := filterTeams(bitbucketTeams, b.teams) | 		filteredTeams := groups.Filter(bitbucketTeams, b.teams) | ||||||
| 		if len(filteredTeams) == 0 { | 		if len(filteredTeams) == 0 { | ||||||
| 			return nil, fmt.Errorf("bitbucket: user %q is not in any of the required teams", userLogin) | 			return nil, fmt.Errorf("bitbucket: user %q is not in any of the required teams", userLogin) | ||||||
| 		} | 		} | ||||||
| @@ -362,21 +363,6 @@ func (b *bitbucketConnector) getGroups(ctx context.Context, client *http.Client, | |||||||
| 	return nil, nil | 	return nil, nil | ||||||
| } | } | ||||||
|  |  | ||||||
| // Filter the users' team memberships by 'teams' from config. |  | ||||||
| func filterTeams(userTeams, configTeams []string) []string { |  | ||||||
| 	teams := []string{} |  | ||||||
| 	teamFilter := make(map[string]struct{}) |  | ||||||
| 	for _, team := range configTeams { |  | ||||||
| 		teamFilter[team] = struct{}{} |  | ||||||
| 	} |  | ||||||
| 	for _, team := range userTeams { |  | ||||||
| 		if _, ok := teamFilter[team]; ok { |  | ||||||
| 			teams = append(teams, team) |  | ||||||
| 		} |  | ||||||
| 	} |  | ||||||
| 	return teams |  | ||||||
| } |  | ||||||
|  |  | ||||||
| type team struct { | type team struct { | ||||||
| 	Name string `json:"username"` // The "username" from Bitbucket Cloud is actually the team name here | 	Name string `json:"username"` // The "username" from Bitbucket Cloud is actually the team name here | ||||||
| } | } | ||||||
|   | |||||||
| @@ -20,6 +20,7 @@ import ( | |||||||
| 	"golang.org/x/oauth2/github" | 	"golang.org/x/oauth2/github" | ||||||
|  |  | ||||||
| 	"github.com/dexidp/dex/connector" | 	"github.com/dexidp/dex/connector" | ||||||
|  | 	groups_pkg "github.com/dexidp/dex/pkg/groups" | ||||||
| 	"github.com/dexidp/dex/pkg/log" | 	"github.com/dexidp/dex/pkg/log" | ||||||
| ) | ) | ||||||
|  |  | ||||||
| @@ -375,7 +376,7 @@ func (c *githubConnector) groupsForOrgs(ctx context.Context, client *http.Client | |||||||
| 		// 'teams' list in config. | 		// 'teams' list in config. | ||||||
| 		if len(org.Teams) == 0 { | 		if len(org.Teams) == 0 { | ||||||
| 			inOrgNoTeams = true | 			inOrgNoTeams = true | ||||||
| 		} else if teams = filterTeams(teams, org.Teams); len(teams) == 0 { | 		} else if teams = groups_pkg.Filter(teams, org.Teams); len(teams) == 0 { | ||||||
| 			c.logger.Infof("github: user %q in org %q but no teams", userName, org.Name) | 			c.logger.Infof("github: user %q in org %q but no teams", userName, org.Name) | ||||||
| 		} | 		} | ||||||
|  |  | ||||||
| @@ -466,22 +467,6 @@ func (c *githubConnector) userOrgTeams(ctx context.Context, client *http.Client) | |||||||
| 	return groups, nil | 	return groups, nil | ||||||
| } | } | ||||||
|  |  | ||||||
| // Filter the users' team memberships by 'teams' from config. |  | ||||||
| func filterTeams(userTeams, configTeams []string) (teams []string) { |  | ||||||
| 	teamFilter := make(map[string]struct{}) |  | ||||||
| 	for _, team := range configTeams { |  | ||||||
| 		if _, ok := teamFilter[team]; !ok { |  | ||||||
| 			teamFilter[team] = struct{}{} |  | ||||||
| 		} |  | ||||||
| 	} |  | ||||||
| 	for _, team := range userTeams { |  | ||||||
| 		if _, ok := teamFilter[team]; ok { |  | ||||||
| 			teams = append(teams, team) |  | ||||||
| 		} |  | ||||||
| 	} |  | ||||||
| 	return |  | ||||||
| } |  | ||||||
|  |  | ||||||
| // get creates a "GET `apiURL`" request with context, sends the request using | // get creates a "GET `apiURL`" request with context, sends the request using | ||||||
| // the client, and decodes the resulting response body into v. A pagination URL | // the client, and decodes the resulting response body into v. A pagination URL | ||||||
| // is returned if one exists. Any errors encountered when building requests, | // is returned if one exists. Any errors encountered when building requests, | ||||||
|   | |||||||
| @@ -13,6 +13,7 @@ import ( | |||||||
| 	"golang.org/x/oauth2" | 	"golang.org/x/oauth2" | ||||||
|  |  | ||||||
| 	"github.com/dexidp/dex/connector" | 	"github.com/dexidp/dex/connector" | ||||||
|  | 	"github.com/dexidp/dex/pkg/groups" | ||||||
| 	"github.com/dexidp/dex/pkg/log" | 	"github.com/dexidp/dex/pkg/log" | ||||||
| ) | ) | ||||||
|  |  | ||||||
| @@ -273,7 +274,7 @@ func (c *gitlabConnector) getGroups(ctx context.Context, client *http.Client, gr | |||||||
| 	} | 	} | ||||||
|  |  | ||||||
| 	if len(c.groups) > 0 { | 	if len(c.groups) > 0 { | ||||||
| 		filteredGroups := filterGroups(gitlabGroups, c.groups) | 		filteredGroups := groups.Filter(gitlabGroups, c.groups) | ||||||
| 		if len(filteredGroups) == 0 { | 		if len(filteredGroups) == 0 { | ||||||
| 			return nil, fmt.Errorf("gitlab: user %q is not in any of the required groups", userLogin) | 			return nil, fmt.Errorf("gitlab: user %q is not in any of the required groups", userLogin) | ||||||
| 		} | 		} | ||||||
| @@ -284,18 +285,3 @@ func (c *gitlabConnector) getGroups(ctx context.Context, client *http.Client, gr | |||||||
|  |  | ||||||
| 	return nil, nil | 	return nil, nil | ||||||
| } | } | ||||||
|  |  | ||||||
| // Filter the users' group memberships by 'groups' from config. |  | ||||||
| func filterGroups(userGroups, configGroups []string) []string { |  | ||||||
| 	groups := []string{} |  | ||||||
| 	groupFilter := make(map[string]struct{}) |  | ||||||
| 	for _, group := range configGroups { |  | ||||||
| 		groupFilter[group] = struct{}{} |  | ||||||
| 	} |  | ||||||
| 	for _, group := range userGroups { |  | ||||||
| 		if _, ok := groupFilter[group]; ok { |  | ||||||
| 			groups = append(groups, group) |  | ||||||
| 		} |  | ||||||
| 	} |  | ||||||
| 	return groups |  | ||||||
| } |  | ||||||
|   | |||||||
| @@ -15,6 +15,7 @@ import ( | |||||||
| 	"golang.org/x/oauth2" | 	"golang.org/x/oauth2" | ||||||
|  |  | ||||||
| 	"github.com/dexidp/dex/connector" | 	"github.com/dexidp/dex/connector" | ||||||
|  | 	groups_pkg "github.com/dexidp/dex/pkg/groups" | ||||||
| 	"github.com/dexidp/dex/pkg/log" | 	"github.com/dexidp/dex/pkg/log" | ||||||
| ) | ) | ||||||
|  |  | ||||||
| @@ -311,22 +312,9 @@ func (c *microsoftConnector) getGroups(ctx context.Context, client *http.Client, | |||||||
| 	} | 	} | ||||||
|  |  | ||||||
| 	// ensure that the user is in at least one required group | 	// ensure that the user is in at least one required group | ||||||
| 	isInGroups := false | 	filteredGroups := groups_pkg.Filter(groups, c.groups) | ||||||
| 	if len(c.groups) > 0 { | 	if len(c.groups) > 0 && len(filteredGroups) == 0 { | ||||||
| 		gs := make(map[string]struct{}) | 		return nil, fmt.Errorf("microsoft: user %v not in any of the required groups", userID) | ||||||
| 		for _, g := range c.groups { |  | ||||||
| 			gs[g] = struct{}{} |  | ||||||
| 		} |  | ||||||
|  |  | ||||||
| 		for _, g := range groups { |  | ||||||
| 			if _, ok := gs[g]; ok { |  | ||||||
| 				isInGroups = true |  | ||||||
| 				break |  | ||||||
| 			} |  | ||||||
| 		} |  | ||||||
| 	} |  | ||||||
| 	if len(c.groups) > 0 && !isInGroups { |  | ||||||
| 		return nil, fmt.Errorf("microsoft: user %v not in required groups", userID) |  | ||||||
| 	} | 	} | ||||||
|  |  | ||||||
| 	return | 	return | ||||||
|   | |||||||
							
								
								
									
										18
									
								
								pkg/groups/groups.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										18
									
								
								pkg/groups/groups.go
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,18 @@ | |||||||
|  | // Package groups contains helper functions related to groups | ||||||
|  | package groups | ||||||
|  |  | ||||||
|  | // Filter filters out any groups of given that are not in required. Thus it may | ||||||
|  | // happen that the resulting slice is empty. | ||||||
|  | func Filter(given, required []string) []string { | ||||||
|  | 	groups := []string{} | ||||||
|  | 	groupFilter := make(map[string]struct{}) | ||||||
|  | 	for _, group := range required { | ||||||
|  | 		groupFilter[group] = struct{}{} | ||||||
|  | 	} | ||||||
|  | 	for _, group := range given { | ||||||
|  | 		if _, ok := groupFilter[group]; ok { | ||||||
|  | 			groups = append(groups, group) | ||||||
|  | 		} | ||||||
|  | 	} | ||||||
|  | 	return groups | ||||||
|  | } | ||||||
							
								
								
									
										26
									
								
								pkg/groups/groups_test.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										26
									
								
								pkg/groups/groups_test.go
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,26 @@ | |||||||
|  | package groups_test | ||||||
|  |  | ||||||
|  | import ( | ||||||
|  | 	"testing" | ||||||
|  |  | ||||||
|  | 	"github.com/stretchr/testify/assert" | ||||||
|  |  | ||||||
|  | 	"github.com/dexidp/dex/pkg/groups" | ||||||
|  | ) | ||||||
|  |  | ||||||
|  | func TestFilter(t *testing.T) { | ||||||
|  | 	cases := map[string]struct { | ||||||
|  | 		given, required, expected []string | ||||||
|  | 	}{ | ||||||
|  | 		"nothing given":                 {given: []string{}, required: []string{"ops"}, expected: []string{}}, | ||||||
|  | 		"exactly one match":             {given: []string{"foo"}, required: []string{"foo"}, expected: []string{"foo"}}, | ||||||
|  | 		"no group of the required ones": {given: []string{"foo", "bar"}, required: []string{"baz"}, expected: []string{}}, | ||||||
|  | 		"subset matching":               {given: []string{"foo", "bar", "baz"}, required: []string{"bar", "baz"}, expected: []string{"bar", "baz"}}, | ||||||
|  | 	} | ||||||
|  | 	for name, tc := range cases { | ||||||
|  | 		t.Run(name, func(t *testing.T) { | ||||||
|  | 			actual := groups.Filter(tc.given, tc.required) | ||||||
|  | 			assert.ElementsMatch(t, tc.expected, actual) | ||||||
|  | 		}) | ||||||
|  | 	} | ||||||
|  | } | ||||||
		Reference in New Issue
	
	Block a user