connectors: refactor filter code into a helper package
I hope I didn't miss any :D Signed-off-by: Stephan Renatus <srenatus@chef.io>
This commit is contained in:
		@@ -6,7 +6,6 @@ import (
 | 
				
			|||||||
	"encoding/json"
 | 
						"encoding/json"
 | 
				
			||||||
	"errors"
 | 
						"errors"
 | 
				
			||||||
	"fmt"
 | 
						"fmt"
 | 
				
			||||||
	"github.com/dexidp/dex/pkg/log"
 | 
					 | 
				
			||||||
	"io/ioutil"
 | 
						"io/ioutil"
 | 
				
			||||||
	"net/http"
 | 
						"net/http"
 | 
				
			||||||
	"sync"
 | 
						"sync"
 | 
				
			||||||
@@ -16,6 +15,8 @@ import (
 | 
				
			|||||||
	"golang.org/x/oauth2/bitbucket"
 | 
						"golang.org/x/oauth2/bitbucket"
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	"github.com/dexidp/dex/connector"
 | 
						"github.com/dexidp/dex/connector"
 | 
				
			||||||
 | 
						"github.com/dexidp/dex/pkg/groups"
 | 
				
			||||||
 | 
						"github.com/dexidp/dex/pkg/log"
 | 
				
			||||||
)
 | 
					)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
const (
 | 
					const (
 | 
				
			||||||
@@ -350,7 +351,7 @@ func (b *bitbucketConnector) getGroups(ctx context.Context, client *http.Client,
 | 
				
			|||||||
	}
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	if len(b.teams) > 0 {
 | 
						if len(b.teams) > 0 {
 | 
				
			||||||
		filteredTeams := filterTeams(bitbucketTeams, b.teams)
 | 
							filteredTeams := groups.Filter(bitbucketTeams, b.teams)
 | 
				
			||||||
		if len(filteredTeams) == 0 {
 | 
							if len(filteredTeams) == 0 {
 | 
				
			||||||
			return nil, fmt.Errorf("bitbucket: user %q is not in any of the required teams", userLogin)
 | 
								return nil, fmt.Errorf("bitbucket: user %q is not in any of the required teams", userLogin)
 | 
				
			||||||
		}
 | 
							}
 | 
				
			||||||
@@ -362,21 +363,6 @@ func (b *bitbucketConnector) getGroups(ctx context.Context, client *http.Client,
 | 
				
			|||||||
	return nil, nil
 | 
						return nil, nil
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
// Filter the users' team memberships by 'teams' from config.
 | 
					 | 
				
			||||||
func filterTeams(userTeams, configTeams []string) []string {
 | 
					 | 
				
			||||||
	teams := []string{}
 | 
					 | 
				
			||||||
	teamFilter := make(map[string]struct{})
 | 
					 | 
				
			||||||
	for _, team := range configTeams {
 | 
					 | 
				
			||||||
		teamFilter[team] = struct{}{}
 | 
					 | 
				
			||||||
	}
 | 
					 | 
				
			||||||
	for _, team := range userTeams {
 | 
					 | 
				
			||||||
		if _, ok := teamFilter[team]; ok {
 | 
					 | 
				
			||||||
			teams = append(teams, team)
 | 
					 | 
				
			||||||
		}
 | 
					 | 
				
			||||||
	}
 | 
					 | 
				
			||||||
	return teams
 | 
					 | 
				
			||||||
}
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
type team struct {
 | 
					type team struct {
 | 
				
			||||||
	Name string `json:"username"` // The "username" from Bitbucket Cloud is actually the team name here
 | 
						Name string `json:"username"` // The "username" from Bitbucket Cloud is actually the team name here
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 
 | 
				
			|||||||
@@ -20,6 +20,7 @@ import (
 | 
				
			|||||||
	"golang.org/x/oauth2/github"
 | 
						"golang.org/x/oauth2/github"
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	"github.com/dexidp/dex/connector"
 | 
						"github.com/dexidp/dex/connector"
 | 
				
			||||||
 | 
						groups_pkg "github.com/dexidp/dex/pkg/groups"
 | 
				
			||||||
	"github.com/dexidp/dex/pkg/log"
 | 
						"github.com/dexidp/dex/pkg/log"
 | 
				
			||||||
)
 | 
					)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
@@ -375,7 +376,7 @@ func (c *githubConnector) groupsForOrgs(ctx context.Context, client *http.Client
 | 
				
			|||||||
		// 'teams' list in config.
 | 
							// 'teams' list in config.
 | 
				
			||||||
		if len(org.Teams) == 0 {
 | 
							if len(org.Teams) == 0 {
 | 
				
			||||||
			inOrgNoTeams = true
 | 
								inOrgNoTeams = true
 | 
				
			||||||
		} else if teams = filterTeams(teams, org.Teams); len(teams) == 0 {
 | 
							} else if teams = groups_pkg.Filter(teams, org.Teams); len(teams) == 0 {
 | 
				
			||||||
			c.logger.Infof("github: user %q in org %q but no teams", userName, org.Name)
 | 
								c.logger.Infof("github: user %q in org %q but no teams", userName, org.Name)
 | 
				
			||||||
		}
 | 
							}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
@@ -466,22 +467,6 @@ func (c *githubConnector) userOrgTeams(ctx context.Context, client *http.Client)
 | 
				
			|||||||
	return groups, nil
 | 
						return groups, nil
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
// Filter the users' team memberships by 'teams' from config.
 | 
					 | 
				
			||||||
func filterTeams(userTeams, configTeams []string) (teams []string) {
 | 
					 | 
				
			||||||
	teamFilter := make(map[string]struct{})
 | 
					 | 
				
			||||||
	for _, team := range configTeams {
 | 
					 | 
				
			||||||
		if _, ok := teamFilter[team]; !ok {
 | 
					 | 
				
			||||||
			teamFilter[team] = struct{}{}
 | 
					 | 
				
			||||||
		}
 | 
					 | 
				
			||||||
	}
 | 
					 | 
				
			||||||
	for _, team := range userTeams {
 | 
					 | 
				
			||||||
		if _, ok := teamFilter[team]; ok {
 | 
					 | 
				
			||||||
			teams = append(teams, team)
 | 
					 | 
				
			||||||
		}
 | 
					 | 
				
			||||||
	}
 | 
					 | 
				
			||||||
	return
 | 
					 | 
				
			||||||
}
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
// get creates a "GET `apiURL`" request with context, sends the request using
 | 
					// get creates a "GET `apiURL`" request with context, sends the request using
 | 
				
			||||||
// the client, and decodes the resulting response body into v. A pagination URL
 | 
					// the client, and decodes the resulting response body into v. A pagination URL
 | 
				
			||||||
// is returned if one exists. Any errors encountered when building requests,
 | 
					// is returned if one exists. Any errors encountered when building requests,
 | 
				
			||||||
 
 | 
				
			|||||||
@@ -13,6 +13,7 @@ import (
 | 
				
			|||||||
	"golang.org/x/oauth2"
 | 
						"golang.org/x/oauth2"
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	"github.com/dexidp/dex/connector"
 | 
						"github.com/dexidp/dex/connector"
 | 
				
			||||||
 | 
						"github.com/dexidp/dex/pkg/groups"
 | 
				
			||||||
	"github.com/dexidp/dex/pkg/log"
 | 
						"github.com/dexidp/dex/pkg/log"
 | 
				
			||||||
)
 | 
					)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
@@ -273,7 +274,7 @@ func (c *gitlabConnector) getGroups(ctx context.Context, client *http.Client, gr
 | 
				
			|||||||
	}
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	if len(c.groups) > 0 {
 | 
						if len(c.groups) > 0 {
 | 
				
			||||||
		filteredGroups := filterGroups(gitlabGroups, c.groups)
 | 
							filteredGroups := groups.Filter(gitlabGroups, c.groups)
 | 
				
			||||||
		if len(filteredGroups) == 0 {
 | 
							if len(filteredGroups) == 0 {
 | 
				
			||||||
			return nil, fmt.Errorf("gitlab: user %q is not in any of the required groups", userLogin)
 | 
								return nil, fmt.Errorf("gitlab: user %q is not in any of the required groups", userLogin)
 | 
				
			||||||
		}
 | 
							}
 | 
				
			||||||
@@ -284,18 +285,3 @@ func (c *gitlabConnector) getGroups(ctx context.Context, client *http.Client, gr
 | 
				
			|||||||
 | 
					
 | 
				
			||||||
	return nil, nil
 | 
						return nil, nil
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					 | 
				
			||||||
// Filter the users' group memberships by 'groups' from config.
 | 
					 | 
				
			||||||
func filterGroups(userGroups, configGroups []string) []string {
 | 
					 | 
				
			||||||
	groups := []string{}
 | 
					 | 
				
			||||||
	groupFilter := make(map[string]struct{})
 | 
					 | 
				
			||||||
	for _, group := range configGroups {
 | 
					 | 
				
			||||||
		groupFilter[group] = struct{}{}
 | 
					 | 
				
			||||||
	}
 | 
					 | 
				
			||||||
	for _, group := range userGroups {
 | 
					 | 
				
			||||||
		if _, ok := groupFilter[group]; ok {
 | 
					 | 
				
			||||||
			groups = append(groups, group)
 | 
					 | 
				
			||||||
		}
 | 
					 | 
				
			||||||
	}
 | 
					 | 
				
			||||||
	return groups
 | 
					 | 
				
			||||||
}
 | 
					 | 
				
			||||||
 
 | 
				
			|||||||
@@ -15,6 +15,7 @@ import (
 | 
				
			|||||||
	"golang.org/x/oauth2"
 | 
						"golang.org/x/oauth2"
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	"github.com/dexidp/dex/connector"
 | 
						"github.com/dexidp/dex/connector"
 | 
				
			||||||
 | 
						groups_pkg "github.com/dexidp/dex/pkg/groups"
 | 
				
			||||||
	"github.com/dexidp/dex/pkg/log"
 | 
						"github.com/dexidp/dex/pkg/log"
 | 
				
			||||||
)
 | 
					)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
@@ -311,22 +312,9 @@ func (c *microsoftConnector) getGroups(ctx context.Context, client *http.Client,
 | 
				
			|||||||
	}
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	// ensure that the user is in at least one required group
 | 
						// ensure that the user is in at least one required group
 | 
				
			||||||
	isInGroups := false
 | 
						filteredGroups := groups_pkg.Filter(groups, c.groups)
 | 
				
			||||||
	if len(c.groups) > 0 {
 | 
						if len(c.groups) > 0 && len(filteredGroups) == 0 {
 | 
				
			||||||
		gs := make(map[string]struct{})
 | 
							return nil, fmt.Errorf("microsoft: user %v not in any of the required groups", userID)
 | 
				
			||||||
		for _, g := range c.groups {
 | 
					 | 
				
			||||||
			gs[g] = struct{}{}
 | 
					 | 
				
			||||||
		}
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
		for _, g := range groups {
 | 
					 | 
				
			||||||
			if _, ok := gs[g]; ok {
 | 
					 | 
				
			||||||
				isInGroups = true
 | 
					 | 
				
			||||||
				break
 | 
					 | 
				
			||||||
			}
 | 
					 | 
				
			||||||
		}
 | 
					 | 
				
			||||||
	}
 | 
					 | 
				
			||||||
	if len(c.groups) > 0 && !isInGroups {
 | 
					 | 
				
			||||||
		return nil, fmt.Errorf("microsoft: user %v not in required groups", userID)
 | 
					 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	return
 | 
						return
 | 
				
			||||||
 
 | 
				
			|||||||
							
								
								
									
										18
									
								
								pkg/groups/groups.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										18
									
								
								pkg/groups/groups.go
									
									
									
									
									
										Normal file
									
								
							@@ -0,0 +1,18 @@
 | 
				
			|||||||
 | 
					// Package groups contains helper functions related to groups
 | 
				
			||||||
 | 
					package groups
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					// Filter filters out any groups of given that are not in required. Thus it may
 | 
				
			||||||
 | 
					// happen that the resulting slice is empty.
 | 
				
			||||||
 | 
					func Filter(given, required []string) []string {
 | 
				
			||||||
 | 
						groups := []string{}
 | 
				
			||||||
 | 
						groupFilter := make(map[string]struct{})
 | 
				
			||||||
 | 
						for _, group := range required {
 | 
				
			||||||
 | 
							groupFilter[group] = struct{}{}
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
						for _, group := range given {
 | 
				
			||||||
 | 
							if _, ok := groupFilter[group]; ok {
 | 
				
			||||||
 | 
								groups = append(groups, group)
 | 
				
			||||||
 | 
							}
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
						return groups
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
							
								
								
									
										26
									
								
								pkg/groups/groups_test.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										26
									
								
								pkg/groups/groups_test.go
									
									
									
									
									
										Normal file
									
								
							@@ -0,0 +1,26 @@
 | 
				
			|||||||
 | 
					package groups_test
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					import (
 | 
				
			||||||
 | 
						"testing"
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						"github.com/stretchr/testify/assert"
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						"github.com/dexidp/dex/pkg/groups"
 | 
				
			||||||
 | 
					)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					func TestFilter(t *testing.T) {
 | 
				
			||||||
 | 
						cases := map[string]struct {
 | 
				
			||||||
 | 
							given, required, expected []string
 | 
				
			||||||
 | 
						}{
 | 
				
			||||||
 | 
							"nothing given":                 {given: []string{}, required: []string{"ops"}, expected: []string{}},
 | 
				
			||||||
 | 
							"exactly one match":             {given: []string{"foo"}, required: []string{"foo"}, expected: []string{"foo"}},
 | 
				
			||||||
 | 
							"no group of the required ones": {given: []string{"foo", "bar"}, required: []string{"baz"}, expected: []string{}},
 | 
				
			||||||
 | 
							"subset matching":               {given: []string{"foo", "bar", "baz"}, required: []string{"bar", "baz"}, expected: []string{"bar", "baz"}},
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
						for name, tc := range cases {
 | 
				
			||||||
 | 
							t.Run(name, func(t *testing.T) {
 | 
				
			||||||
 | 
								actual := groups.Filter(tc.given, tc.required)
 | 
				
			||||||
 | 
								assert.ElementsMatch(t, tc.expected, actual)
 | 
				
			||||||
 | 
							})
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
		Reference in New Issue
	
	Block a user