Merge pull request #2342 from dhaus67/refresh-token-openshift-connector
Add support for RefreshConnector for openshift connector.
This commit is contained in:
		| @@ -21,6 +21,11 @@ import ( | |||||||
| 	"github.com/dexidp/dex/storage/kubernetes/k8sapi" | 	"github.com/dexidp/dex/storage/kubernetes/k8sapi" | ||||||
| ) | ) | ||||||
|  |  | ||||||
|  | const ( | ||||||
|  | 	wellKnownURLPath = "/.well-known/oauth-authorization-server" | ||||||
|  | 	usersURLPath     = "/apis/user.openshift.io/v1/users/~" | ||||||
|  | ) | ||||||
|  |  | ||||||
| // Config holds configuration options for OpenShift login | // Config holds configuration options for OpenShift login | ||||||
| type Config struct { | type Config struct { | ||||||
| 	Issuer       string   `json:"issuer"` | 	Issuer       string   `json:"issuer"` | ||||||
| @@ -32,7 +37,10 @@ type Config struct { | |||||||
| 	RootCA       string   `json:"rootCA"` | 	RootCA       string   `json:"rootCA"` | ||||||
| } | } | ||||||
|  |  | ||||||
| var _ connector.CallbackConnector = (*openshiftConnector)(nil) | var ( | ||||||
|  | 	_ connector.CallbackConnector = (*openshiftConnector)(nil) | ||||||
|  | 	_ connector.RefreshConnector  = (*openshiftConnector)(nil) | ||||||
|  | ) | ||||||
|  |  | ||||||
| type openshiftConnector struct { | type openshiftConnector struct { | ||||||
| 	apiURL       string | 	apiURL       string | ||||||
| @@ -61,7 +69,7 @@ type user struct { | |||||||
| func (c *Config) Open(id string, logger log.Logger) (conn connector.Connector, err error) { | func (c *Config) Open(id string, logger log.Logger) (conn connector.Connector, err error) { | ||||||
| 	ctx, cancel := context.WithCancel(context.Background()) | 	ctx, cancel := context.WithCancel(context.Background()) | ||||||
|  |  | ||||||
| 	wellKnownURL := strings.TrimSuffix(c.Issuer, "/") + "/.well-known/oauth-authorization-server" | 	wellKnownURL := strings.TrimSuffix(c.Issuer, "/") + wellKnownURLPath | ||||||
| 	req, err := http.NewRequest(http.MethodGet, wellKnownURL, nil) | 	req, err := http.NewRequest(http.MethodGet, wellKnownURL, nil) | ||||||
|  |  | ||||||
| 	openshiftConnector := openshiftConnector{ | 	openshiftConnector := openshiftConnector{ | ||||||
| @@ -154,8 +162,23 @@ func (c *openshiftConnector) HandleCallback(s connector.Scopes, r *http.Request) | |||||||
| 		return identity, fmt.Errorf("oidc: failed to get token: %v", err) | 		return identity, fmt.Errorf("oidc: failed to get token: %v", err) | ||||||
| 	} | 	} | ||||||
|  |  | ||||||
| 	client := c.oauth2Config.Client(ctx, token) | 	return c.identity(ctx, s, token) | ||||||
|  | } | ||||||
|  |  | ||||||
|  | func (c *openshiftConnector) Refresh(ctx context.Context, s connector.Scopes, oldID connector.Identity) (connector.Identity, error) { | ||||||
|  | 	var token oauth2.Token | ||||||
|  | 	err := json.Unmarshal(oldID.ConnectorData, &token) | ||||||
|  | 	if err != nil { | ||||||
|  | 		return connector.Identity{}, fmt.Errorf("parsing token: %w", err) | ||||||
|  | 	} | ||||||
|  | 	if c.httpClient != nil { | ||||||
|  | 		ctx = context.WithValue(ctx, oauth2.HTTPClient, c.httpClient) | ||||||
|  | 	} | ||||||
|  | 	return c.identity(ctx, s, &token) | ||||||
|  | } | ||||||
|  |  | ||||||
|  | func (c *openshiftConnector) identity(ctx context.Context, s connector.Scopes, token *oauth2.Token) (identity connector.Identity, err error) { | ||||||
|  | 	client := c.oauth2Config.Client(ctx, token) | ||||||
| 	user, err := c.user(ctx, client) | 	user, err := c.user(ctx, client) | ||||||
| 	if err != nil { | 	if err != nil { | ||||||
| 		return identity, fmt.Errorf("openshift: get user: %v", err) | 		return identity, fmt.Errorf("openshift: get user: %v", err) | ||||||
| @@ -177,12 +200,20 @@ func (c *openshiftConnector) HandleCallback(s connector.Scopes, r *http.Request) | |||||||
| 		Groups:            user.Groups, | 		Groups:            user.Groups, | ||||||
| 	} | 	} | ||||||
|  |  | ||||||
|  | 	if s.OfflineAccess { | ||||||
|  | 		connData, err := json.Marshal(token) | ||||||
|  | 		if err != nil { | ||||||
|  | 			return identity, fmt.Errorf("marshal connector data: %v", err) | ||||||
|  | 		} | ||||||
|  | 		identity.ConnectorData = connData | ||||||
|  | 	} | ||||||
|  |  | ||||||
| 	return identity, nil | 	return identity, nil | ||||||
| } | } | ||||||
|  |  | ||||||
| // user function returns the OpenShift user associated with the authenticated user | // user function returns the OpenShift user associated with the authenticated user | ||||||
| func (c *openshiftConnector) user(ctx context.Context, client *http.Client) (u user, err error) { | func (c *openshiftConnector) user(ctx context.Context, client *http.Client) (u user, err error) { | ||||||
| 	url := c.apiURL + "/apis/user.openshift.io/v1/users/~" | 	url := c.apiURL + usersURLPath | ||||||
|  |  | ||||||
| 	req, err := http.NewRequest("GET", url, nil) | 	req, err := http.NewRequest("GET", url, nil) | ||||||
| 	if err != nil { | 	if err != nil { | ||||||
|   | |||||||
| @@ -9,6 +9,7 @@ import ( | |||||||
| 	"net/url" | 	"net/url" | ||||||
| 	"reflect" | 	"reflect" | ||||||
| 	"testing" | 	"testing" | ||||||
|  | 	"time" | ||||||
|  |  | ||||||
| 	"github.com/sirupsen/logrus" | 	"github.com/sirupsen/logrus" | ||||||
| 	"golang.org/x/oauth2" | 	"golang.org/x/oauth2" | ||||||
| @@ -184,6 +185,78 @@ func TestCallbackIdentity(t *testing.T) { | |||||||
| 	expectEquals(t, identity.Groups[0], "users") | 	expectEquals(t, identity.Groups[0], "users") | ||||||
| } | } | ||||||
|  |  | ||||||
|  | func TestRefreshIdentity(t *testing.T) { | ||||||
|  | 	s := newTestServer(map[string]interface{}{ | ||||||
|  | 		usersURLPath: user{ | ||||||
|  | 			ObjectMeta: k8sapi.ObjectMeta{ | ||||||
|  | 				Name: "jdoe", | ||||||
|  | 				UID:  "12345", | ||||||
|  | 			}, | ||||||
|  | 			FullName: "John Doe", | ||||||
|  | 			Groups:   []string{"users"}, | ||||||
|  | 		}, | ||||||
|  | 	}) | ||||||
|  | 	defer s.Close() | ||||||
|  |  | ||||||
|  | 	h, err := newHTTPClient(true, "") | ||||||
|  | 	expectNil(t, err) | ||||||
|  |  | ||||||
|  | 	oc := openshiftConnector{apiURL: s.URL, httpClient: h, oauth2Config: &oauth2.Config{ | ||||||
|  | 		Endpoint: oauth2.Endpoint{ | ||||||
|  | 			AuthURL:  fmt.Sprintf("%s/oauth/authorize", s.URL), | ||||||
|  | 			TokenURL: fmt.Sprintf("%s/oauth/token", s.URL), | ||||||
|  | 		}, | ||||||
|  | 	}} | ||||||
|  |  | ||||||
|  | 	data, err := json.Marshal(oauth2.Token{AccessToken: "fFAGRNJru1FTz70BzhT3Zg"}) | ||||||
|  | 	expectNil(t, err) | ||||||
|  |  | ||||||
|  | 	oldID := connector.Identity{ConnectorData: data} | ||||||
|  |  | ||||||
|  | 	identity, err := oc.Refresh(context.Background(), connector.Scopes{Groups: true}, oldID) | ||||||
|  |  | ||||||
|  | 	expectNil(t, err) | ||||||
|  | 	expectEquals(t, identity.UserID, "12345") | ||||||
|  | 	expectEquals(t, identity.Username, "jdoe") | ||||||
|  | 	expectEquals(t, identity.PreferredUsername, "jdoe") | ||||||
|  | 	expectEquals(t, identity.Email, "jdoe") | ||||||
|  | 	expectEquals(t, len(identity.Groups), 1) | ||||||
|  | 	expectEquals(t, identity.Groups[0], "users") | ||||||
|  | } | ||||||
|  |  | ||||||
|  | func TestRefreshIdentityFailure(t *testing.T) { | ||||||
|  | 	s := newTestServer(map[string]interface{}{ | ||||||
|  | 		usersURLPath: user{ | ||||||
|  | 			ObjectMeta: k8sapi.ObjectMeta{ | ||||||
|  | 				Name: "jdoe", | ||||||
|  | 				UID:  "12345", | ||||||
|  | 			}, | ||||||
|  | 			FullName: "John Doe", | ||||||
|  | 			Groups:   []string{"users"}, | ||||||
|  | 		}, | ||||||
|  | 	}) | ||||||
|  | 	defer s.Close() | ||||||
|  |  | ||||||
|  | 	h, err := newHTTPClient(true, "") | ||||||
|  | 	expectNil(t, err) | ||||||
|  |  | ||||||
|  | 	oc := openshiftConnector{apiURL: s.URL, httpClient: h, oauth2Config: &oauth2.Config{ | ||||||
|  | 		Endpoint: oauth2.Endpoint{ | ||||||
|  | 			AuthURL:  fmt.Sprintf("%s/oauth/authorize", s.URL), | ||||||
|  | 			TokenURL: fmt.Sprintf("%s/oauth/token", s.URL), | ||||||
|  | 		}, | ||||||
|  | 	}} | ||||||
|  |  | ||||||
|  | 	data, err := json.Marshal(oauth2.Token{AccessToken: "oRzxVjCnohYRHEYEhZshkmakKmoyVoTjfUGC", Expiry: time.Now().Add(-time.Hour)}) | ||||||
|  | 	expectNil(t, err) | ||||||
|  |  | ||||||
|  | 	oldID := connector.Identity{ConnectorData: data} | ||||||
|  |  | ||||||
|  | 	identity, err := oc.Refresh(context.Background(), connector.Scopes{Groups: true}, oldID) | ||||||
|  | 	expectNotNil(t, err) | ||||||
|  | 	expectEquals(t, connector.Identity{}, identity) | ||||||
|  | } | ||||||
|  |  | ||||||
| func newTestServer(responses map[string]interface{}) *httptest.Server { | func newTestServer(responses map[string]interface{}) *httptest.Server { | ||||||
| 	var s *httptest.Server | 	var s *httptest.Server | ||||||
| 	s = httptest.NewTLSServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { | 	s = httptest.NewTLSServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { | ||||||
| @@ -216,3 +289,9 @@ func expectEquals(t *testing.T, a interface{}, b interface{}) { | |||||||
| 		t.Errorf("Expected %+v to equal %+v", a, b) | 		t.Errorf("Expected %+v to equal %+v", a, b) | ||||||
| 	} | 	} | ||||||
| } | } | ||||||
|  |  | ||||||
|  | func expectNotNil(t *testing.T, a interface{}) { | ||||||
|  | 	if a == nil { | ||||||
|  | 		t.Errorf("Expected %+v to not equal nil", a) | ||||||
|  | 	} | ||||||
|  | } | ||||||
|   | |||||||
		Reference in New Issue
	
	Block a user