Remove external setting, enable injection of HTTP client to config.
Signed-off-by: Daniel Haus <dhaus@redhat.com>
This commit is contained in:
		| @@ -35,7 +35,6 @@ type Config struct { | |||||||
| 	Groups       []string `json:"groups"` | 	Groups       []string `json:"groups"` | ||||||
| 	InsecureCA   bool     `json:"insecureCA"` | 	InsecureCA   bool     `json:"insecureCA"` | ||||||
| 	RootCA       string   `json:"rootCA"` | 	RootCA       string   `json:"rootCA"` | ||||||
| 	IncludeSystemRootCAs bool     `json:"includeSystemRootCAs"` |  | ||||||
| } | } | ||||||
|  |  | ||||||
| var ( | var ( | ||||||
| @@ -54,7 +53,6 @@ type openshiftConnector struct { | |||||||
| 	oauth2Config *oauth2.Config | 	oauth2Config *oauth2.Config | ||||||
| 	insecureCA   bool | 	insecureCA   bool | ||||||
| 	rootCA       string | 	rootCA       string | ||||||
| 	includeSystemRootCAs bool |  | ||||||
| 	groups       []string | 	groups       []string | ||||||
| } | } | ||||||
|  |  | ||||||
| @@ -69,6 +67,18 @@ type user struct { | |||||||
| // Open returns a connector which can be used to login users through an upstream | // Open returns a connector which can be used to login users through an upstream | ||||||
| // OpenShift OAuth2 provider. | // OpenShift OAuth2 provider. | ||||||
| func (c *Config) Open(id string, logger log.Logger) (conn connector.Connector, err error) { | func (c *Config) Open(id string, logger log.Logger) (conn connector.Connector, err error) { | ||||||
|  | 	httpClient, err := newHTTPClient(c.InsecureCA, c.RootCA) | ||||||
|  | 	if err != nil { | ||||||
|  | 		return nil, fmt.Errorf("failed to create HTTP client: %w", err) | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	return c.OpenWithHTTPClient(id, logger, httpClient) | ||||||
|  | } | ||||||
|  |  | ||||||
|  | // OpenWithHTTPClient returns a connector which can be used to login users through an upstream | ||||||
|  | // OpenShift OAuth2 provider. It provides the ability to inject a http.Client. | ||||||
|  | func (c *Config) OpenWithHTTPClient(id string, logger log.Logger, | ||||||
|  | 	httpClient *http.Client) (conn connector.Connector, err error) { | ||||||
| 	ctx, cancel := context.WithCancel(context.Background()) | 	ctx, cancel := context.WithCancel(context.Background()) | ||||||
|  |  | ||||||
| 	wellKnownURL := strings.TrimSuffix(c.Issuer, "/") + wellKnownURLPath | 	wellKnownURL := strings.TrimSuffix(c.Issuer, "/") + wellKnownURLPath | ||||||
| @@ -83,13 +93,8 @@ func (c *Config) Open(id string, logger log.Logger) (conn connector.Connector, e | |||||||
| 		logger:       logger, | 		logger:       logger, | ||||||
| 		redirectURI:  c.RedirectURI, | 		redirectURI:  c.RedirectURI, | ||||||
| 		rootCA:       c.RootCA, | 		rootCA:       c.RootCA, | ||||||
| 		includeSystemRootCAs: c.IncludeSystemRootCAs, |  | ||||||
| 		groups:       c.Groups, | 		groups:       c.Groups, | ||||||
| 	} | 		httpClient:   httpClient, | ||||||
|  |  | ||||||
| 	if openshiftConnector.httpClient, err = newHTTPClient(c.InsecureCA, c.RootCA, c.IncludeSystemRootCAs); err != nil { |  | ||||||
| 		cancel() |  | ||||||
| 		return nil, fmt.Errorf("failed to create HTTP client: %v", err) |  | ||||||
| 	} | 	} | ||||||
|  |  | ||||||
| 	var metadata struct { | 	var metadata struct { | ||||||
| @@ -100,14 +105,14 @@ func (c *Config) Open(id string, logger log.Logger) (conn connector.Connector, e | |||||||
| 	resp, err := openshiftConnector.httpClient.Do(req.WithContext(ctx)) | 	resp, err := openshiftConnector.httpClient.Do(req.WithContext(ctx)) | ||||||
| 	if err != nil { | 	if err != nil { | ||||||
| 		cancel() | 		cancel() | ||||||
| 		return nil, fmt.Errorf("failed to query OpenShift endpoint %v", err) | 		return nil, fmt.Errorf("failed to query OpenShift endpoint %w", err) | ||||||
| 	} | 	} | ||||||
|  |  | ||||||
| 	defer resp.Body.Close() | 	defer resp.Body.Close() | ||||||
|  |  | ||||||
| 	if err := json.NewDecoder(resp.Body).Decode(&metadata); err != nil { | 	if err := json.NewDecoder(resp.Body).Decode(&metadata); err != nil { | ||||||
| 		cancel() | 		cancel() | ||||||
| 		return nil, fmt.Errorf("discovery through endpoint %s failed to decode body: %v", | 		return nil, fmt.Errorf("discovery through endpoint %s failed to decode body: %w", | ||||||
| 			wellKnownURL, err) | 			wellKnownURL, err) | ||||||
| 	} | 	} | ||||||
|  |  | ||||||
| @@ -131,7 +136,8 @@ func (c *openshiftConnector) Close() error { | |||||||
| // LoginURL returns the URL to redirect the user to login with. | // LoginURL returns the URL to redirect the user to login with. | ||||||
| func (c *openshiftConnector) LoginURL(scopes connector.Scopes, callbackURL, state string) (string, error) { | func (c *openshiftConnector) LoginURL(scopes connector.Scopes, callbackURL, state string) (string, error) { | ||||||
| 	if c.redirectURI != callbackURL { | 	if c.redirectURI != callbackURL { | ||||||
| 		return "", fmt.Errorf("expected callback URL %q did not match the URL in the config %q", callbackURL, c.redirectURI) | 		return "", fmt.Errorf("expected callback URL %q did not match the URL in the config %q", | ||||||
|  | 			callbackURL, c.redirectURI) | ||||||
| 	} | 	} | ||||||
| 	return c.oauth2Config.AuthCodeURL(state), nil | 	return c.oauth2Config.AuthCodeURL(state), nil | ||||||
| } | } | ||||||
| @@ -149,7 +155,8 @@ func (e *oauth2Error) Error() string { | |||||||
| } | } | ||||||
|  |  | ||||||
| // HandleCallback parses the request and returns the user's identity | // HandleCallback parses the request and returns the user's identity | ||||||
| func (c *openshiftConnector) HandleCallback(s connector.Scopes, r *http.Request) (identity connector.Identity, err error) { | func (c *openshiftConnector) HandleCallback(s connector.Scopes, | ||||||
|  | 	r *http.Request) (identity connector.Identity, err error) { | ||||||
| 	q := r.URL.Query() | 	q := r.URL.Query() | ||||||
| 	if errType := q.Get("error"); errType != "" { | 	if errType := q.Get("error"); errType != "" { | ||||||
| 		return identity, &oauth2Error{errType, q.Get("error_description")} | 		return identity, &oauth2Error{errType, q.Get("error_description")} | ||||||
| @@ -168,7 +175,8 @@ func (c *openshiftConnector) HandleCallback(s connector.Scopes, r *http.Request) | |||||||
| 	return c.identity(ctx, s, token) | 	return c.identity(ctx, s, token) | ||||||
| } | } | ||||||
|  |  | ||||||
| func (c *openshiftConnector) Refresh(ctx context.Context, s connector.Scopes, oldID connector.Identity) (connector.Identity, error) { | func (c *openshiftConnector) Refresh(ctx context.Context, s connector.Scopes, | ||||||
|  | 	oldID connector.Identity) (connector.Identity, error) { | ||||||
| 	var token oauth2.Token | 	var token oauth2.Token | ||||||
| 	err := json.Unmarshal(oldID.ConnectorData, &token) | 	err := json.Unmarshal(oldID.ConnectorData, &token) | ||||||
| 	if err != nil { | 	if err != nil { | ||||||
| @@ -180,7 +188,8 @@ func (c *openshiftConnector) Refresh(ctx context.Context, s connector.Scopes, ol | |||||||
| 	return c.identity(ctx, s, &token) | 	return c.identity(ctx, s, &token) | ||||||
| } | } | ||||||
|  |  | ||||||
| func (c *openshiftConnector) identity(ctx context.Context, s connector.Scopes, token *oauth2.Token) (identity connector.Identity, err error) { | func (c *openshiftConnector) identity(ctx context.Context, s connector.Scopes, | ||||||
|  | 	token *oauth2.Token) (identity connector.Identity, err error) { | ||||||
| 	client := c.oauth2Config.Client(ctx, token) | 	client := c.oauth2Config.Client(ctx, token) | ||||||
| 	user, err := c.user(ctx, client) | 	user, err := c.user(ctx, client) | ||||||
| 	if err != nil { | 	if err != nil { | ||||||
| @@ -251,21 +260,12 @@ func validateAllowedGroups(userGroups, allowedGroups []string) bool { | |||||||
| } | } | ||||||
|  |  | ||||||
| // newHTTPClient returns a new HTTP client | // newHTTPClient returns a new HTTP client | ||||||
| func newHTTPClient(insecureCA bool, rootCA string, includeSystemRootCAs bool) (*http.Client, error) { | func newHTTPClient(insecureCA bool, rootCA string) (*http.Client, error) { | ||||||
| 	tlsConfig := tls.Config{} | 	tlsConfig := tls.Config{} | ||||||
|  |  | ||||||
| 	if insecureCA { | 	if insecureCA { | ||||||
| 		tlsConfig = tls.Config{InsecureSkipVerify: true} | 		tlsConfig = tls.Config{InsecureSkipVerify: true} | ||||||
| 	} else if rootCA != "" { | 	} else if rootCA != "" { | ||||||
| 		if !includeSystemRootCAs { |  | ||||||
| 		tlsConfig = tls.Config{RootCAs: x509.NewCertPool()} | 		tlsConfig = tls.Config{RootCAs: x509.NewCertPool()} | ||||||
| 		} else { |  | ||||||
| 			systemCAs, err := x509.SystemCertPool() |  | ||||||
| 			if err != nil { |  | ||||||
| 				return nil, fmt.Errorf("failed to read host CA: %w", err) |  | ||||||
| 			} |  | ||||||
| 			tlsConfig = tls.Config{RootCAs: systemCAs} |  | ||||||
| 		} |  | ||||||
| 		rootCABytes, err := os.ReadFile(rootCA) | 		rootCABytes, err := os.ReadFile(rootCA) | ||||||
| 		if err != nil { | 		if err != nil { | ||||||
| 			return nil, fmt.Errorf("failed to read root-ca: %w", err) | 			return nil, fmt.Errorf("failed to read root-ca: %w", err) | ||||||
|   | |||||||
| @@ -70,7 +70,7 @@ func TestGetUser(t *testing.T) { | |||||||
| 	_, err = http.NewRequest("GET", hostURL.String(), nil) | 	_, err = http.NewRequest("GET", hostURL.String(), nil) | ||||||
| 	expectNil(t, err) | 	expectNil(t, err) | ||||||
|  |  | ||||||
| 	h, err := newHTTPClient(true, "", false) | 	h, err := newHTTPClient(true, "") | ||||||
|  |  | ||||||
| 	expectNil(t, err) | 	expectNil(t, err) | ||||||
|  |  | ||||||
| @@ -128,7 +128,7 @@ func TestVerifyGroup(t *testing.T) { | |||||||
| 	_, err = http.NewRequest("GET", hostURL.String(), nil) | 	_, err = http.NewRequest("GET", hostURL.String(), nil) | ||||||
| 	expectNil(t, err) | 	expectNil(t, err) | ||||||
|  |  | ||||||
| 	h, err := newHTTPClient(true, "", false) | 	h, err := newHTTPClient(true, "") | ||||||
|  |  | ||||||
| 	expectNil(t, err) | 	expectNil(t, err) | ||||||
|  |  | ||||||
| @@ -164,7 +164,7 @@ func TestCallbackIdentity(t *testing.T) { | |||||||
| 	req, err := http.NewRequest("GET", hostURL.String(), nil) | 	req, err := http.NewRequest("GET", hostURL.String(), nil) | ||||||
| 	expectNil(t, err) | 	expectNil(t, err) | ||||||
|  |  | ||||||
| 	h, err := newHTTPClient(true, "", false) | 	h, err := newHTTPClient(true, "") | ||||||
|  |  | ||||||
| 	expectNil(t, err) | 	expectNil(t, err) | ||||||
|  |  | ||||||
| @@ -198,7 +198,7 @@ func TestRefreshIdentity(t *testing.T) { | |||||||
| 	}) | 	}) | ||||||
| 	defer s.Close() | 	defer s.Close() | ||||||
|  |  | ||||||
| 	h, err := newHTTPClient(true, "", false) | 	h, err := newHTTPClient(true, "") | ||||||
| 	expectNil(t, err) | 	expectNil(t, err) | ||||||
|  |  | ||||||
| 	oc := openshiftConnector{apiURL: s.URL, httpClient: h, oauth2Config: &oauth2.Config{ | 	oc := openshiftConnector{apiURL: s.URL, httpClient: h, oauth2Config: &oauth2.Config{ | ||||||
| @@ -237,7 +237,7 @@ func TestRefreshIdentityFailure(t *testing.T) { | |||||||
| 	}) | 	}) | ||||||
| 	defer s.Close() | 	defer s.Close() | ||||||
|  |  | ||||||
| 	h, err := newHTTPClient(true, "", false) | 	h, err := newHTTPClient(true, "") | ||||||
| 	expectNil(t, err) | 	expectNil(t, err) | ||||||
|  |  | ||||||
| 	oc := openshiftConnector{apiURL: s.URL, httpClient: h, oauth2Config: &oauth2.Config{ | 	oc := openshiftConnector{apiURL: s.URL, httpClient: h, oauth2Config: &oauth2.Config{ | ||||||
|   | |||||||
		Reference in New Issue
	
	Block a user