Use GitLab's refresh_token during Refresh. (#2352)
Signed-off-by: Daniel Haus <dhaus@redhat.com>
This commit is contained in:
		| @@ -9,6 +9,7 @@ import ( | |||||||
| 	"io" | 	"io" | ||||||
| 	"net/http" | 	"net/http" | ||||||
| 	"strconv" | 	"strconv" | ||||||
|  | 	"time" | ||||||
|  |  | ||||||
| 	"golang.org/x/oauth2" | 	"golang.org/x/oauth2" | ||||||
|  |  | ||||||
| @@ -61,8 +62,9 @@ func (c *Config) Open(id string, logger log.Logger) (connector.Connector, error) | |||||||
| } | } | ||||||
|  |  | ||||||
| type connectorData struct { | type connectorData struct { | ||||||
| 	// GitLab's OAuth2 tokens never expire. We don't need a refresh token. | 	// Support GitLab's Access Tokens and Refresh tokens. | ||||||
| 	AccessToken string `json:"accessToken"` | 	AccessToken  string `json:"accessToken"` | ||||||
|  | 	RefreshToken string `json:"refreshToken"` | ||||||
| } | } | ||||||
|  |  | ||||||
| var ( | var ( | ||||||
| @@ -135,6 +137,11 @@ func (c *gitlabConnector) HandleCallback(s connector.Scopes, r *http.Request) (i | |||||||
| 		return identity, fmt.Errorf("gitlab: failed to get token: %v", err) | 		return identity, fmt.Errorf("gitlab: failed to get token: %v", err) | ||||||
| 	} | 	} | ||||||
|  |  | ||||||
|  | 	return c.identity(ctx, s, token) | ||||||
|  | } | ||||||
|  |  | ||||||
|  | func (c *gitlabConnector) identity(ctx context.Context, s connector.Scopes, token *oauth2.Token) (identity connector.Identity, err error) { | ||||||
|  | 	oauth2Config := c.oauth2Config(s) | ||||||
| 	client := oauth2Config.Client(ctx, token) | 	client := oauth2Config.Client(ctx, token) | ||||||
|  |  | ||||||
| 	user, err := c.user(ctx, client) | 	user, err := c.user(ctx, client) | ||||||
| @@ -146,6 +153,7 @@ func (c *gitlabConnector) HandleCallback(s connector.Scopes, r *http.Request) (i | |||||||
| 	if username == "" { | 	if username == "" { | ||||||
| 		username = user.Email | 		username = user.Email | ||||||
| 	} | 	} | ||||||
|  |  | ||||||
| 	identity = connector.Identity{ | 	identity = connector.Identity{ | ||||||
| 		UserID:            strconv.Itoa(user.ID), | 		UserID:            strconv.Itoa(user.ID), | ||||||
| 		Username:          username, | 		Username:          username, | ||||||
| @@ -166,10 +174,10 @@ func (c *gitlabConnector) HandleCallback(s connector.Scopes, r *http.Request) (i | |||||||
| 	} | 	} | ||||||
|  |  | ||||||
| 	if s.OfflineAccess { | 	if s.OfflineAccess { | ||||||
| 		data := connectorData{AccessToken: token.AccessToken} | 		data := connectorData{RefreshToken: token.RefreshToken, AccessToken: token.AccessToken} | ||||||
| 		connData, err := json.Marshal(data) | 		connData, err := json.Marshal(data) | ||||||
| 		if err != nil { | 		if err != nil { | ||||||
| 			return identity, fmt.Errorf("marshal connector data: %v", err) | 			return identity, fmt.Errorf("gitlab: marshal connector data: %v", err) | ||||||
| 		} | 		} | ||||||
| 		identity.ConnectorData = connData | 		identity.ConnectorData = connData | ||||||
| 	} | 	} | ||||||
| @@ -178,37 +186,39 @@ func (c *gitlabConnector) HandleCallback(s connector.Scopes, r *http.Request) (i | |||||||
| } | } | ||||||
|  |  | ||||||
| func (c *gitlabConnector) Refresh(ctx context.Context, s connector.Scopes, ident connector.Identity) (connector.Identity, error) { | func (c *gitlabConnector) Refresh(ctx context.Context, s connector.Scopes, ident connector.Identity) (connector.Identity, error) { | ||||||
| 	if len(ident.ConnectorData) == 0 { |  | ||||||
| 		return ident, errors.New("no upstream access token found") |  | ||||||
| 	} |  | ||||||
|  |  | ||||||
| 	var data connectorData | 	var data connectorData | ||||||
| 	if err := json.Unmarshal(ident.ConnectorData, &data); err != nil { | 	if err := json.Unmarshal(ident.ConnectorData, &data); err != nil { | ||||||
| 		return ident, fmt.Errorf("gitlab: unmarshal access token: %v", err) | 		return ident, fmt.Errorf("gitlab: unmarshal connector data: %v", err) | ||||||
|  | 	} | ||||||
|  | 	oauth2Config := c.oauth2Config(s) | ||||||
|  |  | ||||||
|  | 	if c.httpClient != nil { | ||||||
|  | 		ctx = context.WithValue(ctx, oauth2.HTTPClient, c.httpClient) | ||||||
| 	} | 	} | ||||||
|  |  | ||||||
| 	client := c.oauth2Config(s).Client(ctx, &oauth2.Token{AccessToken: data.AccessToken}) | 	switch { | ||||||
| 	user, err := c.user(ctx, client) | 	case data.RefreshToken != "": | ||||||
| 	if err != nil { | 		{ | ||||||
| 		return ident, fmt.Errorf("gitlab: get user: %v", err) | 			t := &oauth2.Token{ | ||||||
| 	} | 				RefreshToken: data.RefreshToken, | ||||||
|  | 				Expiry:       time.Now().Add(-time.Hour), | ||||||
| 	username := user.Name | 			} | ||||||
| 	if username == "" { | 			token, err := oauth2Config.TokenSource(ctx, t).Token() | ||||||
| 		username = user.Email | 			if err != nil { | ||||||
| 	} | 				return ident, fmt.Errorf("gitlab: failed to get refresh token: %v", err) | ||||||
| 	ident.Username = username | 			} | ||||||
| 	ident.PreferredUsername = user.Username | 			return c.identity(ctx, s, token) | ||||||
| 	ident.Email = user.Email |  | ||||||
|  |  | ||||||
| 	if c.groupsRequired(s.Groups) { |  | ||||||
| 		groups, err := c.getGroups(ctx, client, s.Groups, user.Username) |  | ||||||
| 		if err != nil { |  | ||||||
| 			return ident, fmt.Errorf("gitlab: get groups: %v", err) |  | ||||||
| 		} | 		} | ||||||
| 		ident.Groups = groups | 	case data.AccessToken != "": | ||||||
|  | 		{ | ||||||
|  | 			token := &oauth2.Token{ | ||||||
|  | 				AccessToken: data.AccessToken, | ||||||
|  | 			} | ||||||
|  | 			return c.identity(ctx, s, token) | ||||||
|  | 		} | ||||||
|  | 	default: | ||||||
|  | 		return ident, errors.New("no refresh or access token found") | ||||||
| 	} | 	} | ||||||
| 	return ident, nil |  | ||||||
| } | } | ||||||
|  |  | ||||||
| func (c *gitlabConnector) groupsRequired(groupScope bool) bool { | func (c *gitlabConnector) groupsRequired(groupScope bool) bool { | ||||||
|   | |||||||
| @@ -180,6 +180,75 @@ func TestLoginWithTeamNonWhitelisted(t *testing.T) { | |||||||
| 	expectEquals(t, err.Error(), "gitlab: get groups: gitlab: user \"joebloggs\" is not in any of the required groups") | 	expectEquals(t, err.Error(), "gitlab: get groups: gitlab: user \"joebloggs\" is not in any of the required groups") | ||||||
| } | } | ||||||
|  |  | ||||||
|  | func TestRefresh(t *testing.T) { | ||||||
|  | 	s := newTestServer(map[string]interface{}{ | ||||||
|  | 		"/api/v4/user": gitlabUser{Email: "some@email.com", ID: 12345678}, | ||||||
|  | 		"/oauth/token": map[string]interface{}{ | ||||||
|  | 			"access_token":  "eyJhbGciOiJSUzI1NiIsInR5cCI6IkpXVCJ9", | ||||||
|  | 			"refresh_token": "oRzxVjCnohYRHEYEhZshkmakKmoyVoTjfUGC", | ||||||
|  | 			"expires_in":    "30", | ||||||
|  | 		}, | ||||||
|  | 		"/oauth/userinfo": userInfo{ | ||||||
|  | 			Groups: []string{"team-1"}, | ||||||
|  | 		}, | ||||||
|  | 	}) | ||||||
|  | 	defer s.Close() | ||||||
|  |  | ||||||
|  | 	hostURL, err := url.Parse(s.URL) | ||||||
|  | 	expectNil(t, err) | ||||||
|  |  | ||||||
|  | 	req, err := http.NewRequest("GET", hostURL.String(), nil) | ||||||
|  | 	expectNil(t, err) | ||||||
|  |  | ||||||
|  | 	c := gitlabConnector{baseURL: s.URL, httpClient: newClient()} | ||||||
|  |  | ||||||
|  | 	expectedConnectorData, err := json.Marshal(connectorData{ | ||||||
|  | 		RefreshToken: "oRzxVjCnohYRHEYEhZshkmakKmoyVoTjfUGC", | ||||||
|  | 		AccessToken:  "eyJhbGciOiJSUzI1NiIsInR5cCI6IkpXVCJ9", | ||||||
|  | 	}) | ||||||
|  | 	expectNil(t, err) | ||||||
|  |  | ||||||
|  | 	identity, err := c.HandleCallback(connector.Scopes{OfflineAccess: true}, req) | ||||||
|  | 	expectNil(t, err) | ||||||
|  | 	expectEquals(t, identity.Username, "some@email.com") | ||||||
|  | 	expectEquals(t, identity.UserID, "12345678") | ||||||
|  | 	expectEquals(t, identity.ConnectorData, expectedConnectorData) | ||||||
|  |  | ||||||
|  | 	identity, err = c.Refresh(context.Background(), connector.Scopes{OfflineAccess: true}, identity) | ||||||
|  | 	expectNil(t, err) | ||||||
|  | 	expectEquals(t, identity.Username, "some@email.com") | ||||||
|  | 	expectEquals(t, identity.UserID, "12345678") | ||||||
|  | 	expectEquals(t, identity.ConnectorData, expectedConnectorData) | ||||||
|  | } | ||||||
|  |  | ||||||
|  | func TestRefreshWithEmptyConnectorData(t *testing.T) { | ||||||
|  | 	s := newTestServer(map[string]interface{}{ | ||||||
|  | 		"/api/v4/user": gitlabUser{Email: "some@email.com", ID: 12345678}, | ||||||
|  | 		"/oauth/token": map[string]interface{}{ | ||||||
|  | 			"access_token":  "eyJhbGciOiJSUzI1NiIsInR5cCI6IkpXVCJ9", | ||||||
|  | 			"refresh_token": "oRzxVjCnohYRHEYEhZshkmakKmoyVoTjfUGC", | ||||||
|  | 			"expires_in":    "30", | ||||||
|  | 		}, | ||||||
|  | 		"/oauth/userinfo": userInfo{ | ||||||
|  | 			Groups: []string{"team-1"}, | ||||||
|  | 		}, | ||||||
|  | 	}) | ||||||
|  | 	defer s.Close() | ||||||
|  |  | ||||||
|  | 	emptyConnectorData, err := json.Marshal(connectorData{ | ||||||
|  | 		RefreshToken: "", | ||||||
|  | 		AccessToken:  "", | ||||||
|  | 	}) | ||||||
|  | 	expectNil(t, err) | ||||||
|  |  | ||||||
|  | 	c := gitlabConnector{baseURL: s.URL, httpClient: newClient()} | ||||||
|  | 	emptyIdentity := connector.Identity{ConnectorData: emptyConnectorData} | ||||||
|  |  | ||||||
|  | 	identity, err := c.Refresh(context.Background(), connector.Scopes{OfflineAccess: true}, emptyIdentity) | ||||||
|  | 	expectNotNil(t, err, "Refresh error") | ||||||
|  | 	expectEquals(t, emptyIdentity, identity) | ||||||
|  | } | ||||||
|  |  | ||||||
| func newTestServer(responses map[string]interface{}) *httptest.Server { | func newTestServer(responses map[string]interface{}) *httptest.Server { | ||||||
| 	return httptest.NewTLSServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { | 	return httptest.NewTLSServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { | ||||||
| 		response := responses[r.RequestURI] | 		response := responses[r.RequestURI] | ||||||
|   | |||||||
		Reference in New Issue
	
	Block a user