Add support for refresh tokens for openshift connector.
Signed-off-by: Daniel Haus <dhaus@redhat.com>
This commit is contained in:
		| @@ -21,6 +21,11 @@ import ( | ||||
| 	"github.com/dexidp/dex/storage/kubernetes/k8sapi" | ||||
| ) | ||||
|  | ||||
| const ( | ||||
| 	wellKnownURLPath = "/.well-known/oauth-authorization-server" | ||||
| 	usersURLPath     = "/apis/user.openshift.io/v1/users/~" | ||||
| ) | ||||
|  | ||||
| // Config holds configuration options for OpenShift login | ||||
| type Config struct { | ||||
| 	Issuer       string   `json:"issuer"` | ||||
| @@ -33,6 +38,7 @@ type Config struct { | ||||
| } | ||||
|  | ||||
| var _ connector.CallbackConnector = (*openshiftConnector)(nil) | ||||
| var _ connector.RefreshConnector = (*openshiftConnector)(nil) | ||||
|  | ||||
| type openshiftConnector struct { | ||||
| 	apiURL       string | ||||
| @@ -61,7 +67,7 @@ type user struct { | ||||
| func (c *Config) Open(id string, logger log.Logger) (conn connector.Connector, err error) { | ||||
| 	ctx, cancel := context.WithCancel(context.Background()) | ||||
|  | ||||
| 	wellKnownURL := strings.TrimSuffix(c.Issuer, "/") + "/.well-known/oauth-authorization-server" | ||||
| 	wellKnownURL := strings.TrimSuffix(c.Issuer, "/") + wellKnownURLPath | ||||
| 	req, err := http.NewRequest(http.MethodGet, wellKnownURL, nil) | ||||
|  | ||||
| 	openshiftConnector := openshiftConnector{ | ||||
| @@ -154,8 +160,23 @@ func (c *openshiftConnector) HandleCallback(s connector.Scopes, r *http.Request) | ||||
| 		return identity, fmt.Errorf("oidc: failed to get token: %v", err) | ||||
| 	} | ||||
|  | ||||
| 	client := c.oauth2Config.Client(ctx, token) | ||||
| 	return c.identity(ctx, s, token) | ||||
| } | ||||
|  | ||||
| func (c *openshiftConnector) Refresh(ctx context.Context, s connector.Scopes, oldID connector.Identity) (connector.Identity, error) { | ||||
| 	var token oauth2.Token | ||||
| 	err := json.Unmarshal(oldID.ConnectorData, &token) | ||||
| 	if err != nil { | ||||
| 		return connector.Identity{}, fmt.Errorf("parsing token: %w", err) | ||||
| 	} | ||||
| 	if c.httpClient != nil { | ||||
| 		ctx = context.WithValue(ctx, oauth2.HTTPClient, c.httpClient) | ||||
| 	} | ||||
| 	return c.identity(ctx, s, &token) | ||||
| } | ||||
|  | ||||
| func (c *openshiftConnector) identity(ctx context.Context, s connector.Scopes, token *oauth2.Token) (identity connector.Identity, err error) { | ||||
| 	client := c.oauth2Config.Client(ctx, token) | ||||
| 	user, err := c.user(ctx, client) | ||||
| 	if err != nil { | ||||
| 		return identity, fmt.Errorf("openshift: get user: %v", err) | ||||
| @@ -177,12 +198,20 @@ func (c *openshiftConnector) HandleCallback(s connector.Scopes, r *http.Request) | ||||
| 		Groups:            user.Groups, | ||||
| 	} | ||||
|  | ||||
| 	if s.OfflineAccess { | ||||
| 		connData, err := json.Marshal(token) | ||||
| 		if err != nil { | ||||
| 			return identity, fmt.Errorf("marshal connector data: %v", err) | ||||
| 		} | ||||
| 		identity.ConnectorData = connData | ||||
| 	} | ||||
|  | ||||
| 	return identity, nil | ||||
| } | ||||
|  | ||||
| // user function returns the OpenShift user associated with the authenticated user | ||||
| func (c *openshiftConnector) user(ctx context.Context, client *http.Client) (u user, err error) { | ||||
| 	url := c.apiURL + "/apis/user.openshift.io/v1/users/~" | ||||
| 	url := c.apiURL + usersURLPath | ||||
|  | ||||
| 	req, err := http.NewRequest("GET", url, nil) | ||||
| 	if err != nil { | ||||
|   | ||||
| @@ -9,6 +9,7 @@ import ( | ||||
| 	"net/url" | ||||
| 	"reflect" | ||||
| 	"testing" | ||||
| 	"time" | ||||
|  | ||||
| 	"github.com/sirupsen/logrus" | ||||
| 	"golang.org/x/oauth2" | ||||
| @@ -184,6 +185,78 @@ func TestCallbackIdentity(t *testing.T) { | ||||
| 	expectEquals(t, identity.Groups[0], "users") | ||||
| } | ||||
|  | ||||
| func TestRefreshIdentity(t *testing.T) { | ||||
| 	s := newTestServer(map[string]interface{}{ | ||||
| 		usersURLPath: user{ | ||||
| 			ObjectMeta: k8sapi.ObjectMeta{ | ||||
| 				Name: "jdoe", | ||||
| 				UID:  "12345", | ||||
| 			}, | ||||
| 			FullName: "John Doe", | ||||
| 			Groups:   []string{"users"}, | ||||
| 		}, | ||||
| 	}) | ||||
| 	defer s.Close() | ||||
|  | ||||
| 	h, err := newHTTPClient(true, "") | ||||
| 	expectNil(t, err) | ||||
|  | ||||
| 	oc := openshiftConnector{apiURL: s.URL, httpClient: h, oauth2Config: &oauth2.Config{ | ||||
| 		Endpoint: oauth2.Endpoint{ | ||||
| 			AuthURL: fmt.Sprintf("%s/oauth/authorize", s.URL), | ||||
| 			TokenURL: fmt.Sprintf("%s/oauth/token", s.URL), | ||||
| 		}, | ||||
| 	}} | ||||
|  | ||||
| 	data, err := json.Marshal(oauth2.Token{AccessToken: "fFAGRNJru1FTz70BzhT3Zg"}) | ||||
| 	expectNil(t, err) | ||||
|  | ||||
| 	oldID := connector.Identity{ConnectorData: data} | ||||
|  | ||||
| 	identity, err := oc.Refresh(context.Background(), connector.Scopes{Groups: true}, oldID) | ||||
|  | ||||
| 	expectNil(t, err) | ||||
| 	expectEquals(t, identity.UserID, "12345") | ||||
| 	expectEquals(t, identity.Username, "jdoe") | ||||
| 	expectEquals(t, identity.PreferredUsername, "jdoe") | ||||
| 	expectEquals(t, identity.Email, "jdoe") | ||||
| 	expectEquals(t, len(identity.Groups), 1) | ||||
| 	expectEquals(t, identity.Groups[0], "users") | ||||
| } | ||||
|  | ||||
| func TestRefreshIdentityFailure(t *testing.T) { | ||||
| 	s := newTestServer(map[string]interface{}{ | ||||
| 		usersURLPath: user{ | ||||
| 			ObjectMeta: k8sapi.ObjectMeta{ | ||||
| 				Name: "jdoe", | ||||
| 				UID:  "12345", | ||||
| 			}, | ||||
| 			FullName: "John Doe", | ||||
| 			Groups:   []string{"users"}, | ||||
| 		}, | ||||
| 	}) | ||||
| 	defer s.Close() | ||||
|  | ||||
| 	h, err := newHTTPClient(true, "") | ||||
| 	expectNil(t, err) | ||||
|  | ||||
| 	oc := openshiftConnector{apiURL: s.URL, httpClient: h, oauth2Config: &oauth2.Config{ | ||||
| 		Endpoint: oauth2.Endpoint{ | ||||
| 			AuthURL: fmt.Sprintf("%s/oauth/authorize", s.URL), | ||||
| 			TokenURL: fmt.Sprintf("%s/oauth/token", s.URL), | ||||
| 		}, | ||||
| 	}} | ||||
|  | ||||
| 	data, err := json.Marshal(oauth2.Token{AccessToken: "oRzxVjCnohYRHEYEhZshkmakKmoyVoTjfUGC", Expiry: time.Now().Add(-time.Hour)}) | ||||
| 	expectNil(t, err) | ||||
|  | ||||
| 	oldID := connector.Identity{ConnectorData: data} | ||||
|  | ||||
| 	identity, err := oc.Refresh(context.Background(), connector.Scopes{Groups: true}, oldID) | ||||
| 	expectNotNil(t, err) | ||||
| 	expectEquals(t, connector.Identity{}, identity) | ||||
| } | ||||
|  | ||||
| func newTestServer(responses map[string]interface{}) *httptest.Server { | ||||
| 	var s *httptest.Server | ||||
| 	s = httptest.NewTLSServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { | ||||
| @@ -216,3 +289,9 @@ func expectEquals(t *testing.T, a interface{}, b interface{}) { | ||||
| 		t.Errorf("Expected %+v to equal %+v", a, b) | ||||
| 	} | ||||
| } | ||||
|  | ||||
| func expectNotNil(t *testing.T, a interface{}) { | ||||
| 	if a == nil { | ||||
| 		t.Errorf("Expected %+v to not equal nil", a) | ||||
| 	} | ||||
| } | ||||
|   | ||||
		Reference in New Issue
	
	Block a user