Add support for refresh tokens for openshift connector.

Signed-off-by: Daniel Haus <dhaus@redhat.com>
This commit is contained in:
Daniel Haus 2021-11-23 19:39:23 +01:00
parent e00e75b773
commit 6d55fe1c80
No known key found for this signature in database
GPG Key ID: 262B7643F39EB8A9
2 changed files with 111 additions and 3 deletions

View File

@ -21,6 +21,11 @@ import (
"github.com/dexidp/dex/storage/kubernetes/k8sapi" "github.com/dexidp/dex/storage/kubernetes/k8sapi"
) )
const (
wellKnownURLPath = "/.well-known/oauth-authorization-server"
usersURLPath = "/apis/user.openshift.io/v1/users/~"
)
// Config holds configuration options for OpenShift login // Config holds configuration options for OpenShift login
type Config struct { type Config struct {
Issuer string `json:"issuer"` Issuer string `json:"issuer"`
@ -33,6 +38,7 @@ type Config struct {
} }
var _ connector.CallbackConnector = (*openshiftConnector)(nil) var _ connector.CallbackConnector = (*openshiftConnector)(nil)
var _ connector.RefreshConnector = (*openshiftConnector)(nil)
type openshiftConnector struct { type openshiftConnector struct {
apiURL string apiURL string
@ -61,7 +67,7 @@ type user struct {
func (c *Config) Open(id string, logger log.Logger) (conn connector.Connector, err error) { func (c *Config) Open(id string, logger log.Logger) (conn connector.Connector, err error) {
ctx, cancel := context.WithCancel(context.Background()) ctx, cancel := context.WithCancel(context.Background())
wellKnownURL := strings.TrimSuffix(c.Issuer, "/") + "/.well-known/oauth-authorization-server" wellKnownURL := strings.TrimSuffix(c.Issuer, "/") + wellKnownURLPath
req, err := http.NewRequest(http.MethodGet, wellKnownURL, nil) req, err := http.NewRequest(http.MethodGet, wellKnownURL, nil)
openshiftConnector := openshiftConnector{ openshiftConnector := openshiftConnector{
@ -154,8 +160,23 @@ func (c *openshiftConnector) HandleCallback(s connector.Scopes, r *http.Request)
return identity, fmt.Errorf("oidc: failed to get token: %v", err) return identity, fmt.Errorf("oidc: failed to get token: %v", err)
} }
client := c.oauth2Config.Client(ctx, token) return c.identity(ctx, s, token)
}
func (c *openshiftConnector) Refresh(ctx context.Context, s connector.Scopes, oldID connector.Identity) (connector.Identity, error) {
var token oauth2.Token
err := json.Unmarshal(oldID.ConnectorData, &token)
if err != nil {
return connector.Identity{}, fmt.Errorf("parsing token: %w", err)
}
if c.httpClient != nil {
ctx = context.WithValue(ctx, oauth2.HTTPClient, c.httpClient)
}
return c.identity(ctx, s, &token)
}
func (c *openshiftConnector) identity(ctx context.Context, s connector.Scopes, token *oauth2.Token) (identity connector.Identity, err error) {
client := c.oauth2Config.Client(ctx, token)
user, err := c.user(ctx, client) user, err := c.user(ctx, client)
if err != nil { if err != nil {
return identity, fmt.Errorf("openshift: get user: %v", err) return identity, fmt.Errorf("openshift: get user: %v", err)
@ -177,12 +198,20 @@ func (c *openshiftConnector) HandleCallback(s connector.Scopes, r *http.Request)
Groups: user.Groups, Groups: user.Groups,
} }
if s.OfflineAccess {
connData, err := json.Marshal(token)
if err != nil {
return identity, fmt.Errorf("marshal connector data: %v", err)
}
identity.ConnectorData = connData
}
return identity, nil return identity, nil
} }
// user function returns the OpenShift user associated with the authenticated user // user function returns the OpenShift user associated with the authenticated user
func (c *openshiftConnector) user(ctx context.Context, client *http.Client) (u user, err error) { func (c *openshiftConnector) user(ctx context.Context, client *http.Client) (u user, err error) {
url := c.apiURL + "/apis/user.openshift.io/v1/users/~" url := c.apiURL + usersURLPath
req, err := http.NewRequest("GET", url, nil) req, err := http.NewRequest("GET", url, nil)
if err != nil { if err != nil {

View File

@ -9,6 +9,7 @@ import (
"net/url" "net/url"
"reflect" "reflect"
"testing" "testing"
"time"
"github.com/sirupsen/logrus" "github.com/sirupsen/logrus"
"golang.org/x/oauth2" "golang.org/x/oauth2"
@ -184,6 +185,78 @@ func TestCallbackIdentity(t *testing.T) {
expectEquals(t, identity.Groups[0], "users") expectEquals(t, identity.Groups[0], "users")
} }
func TestRefreshIdentity(t *testing.T) {
s := newTestServer(map[string]interface{}{
usersURLPath: user{
ObjectMeta: k8sapi.ObjectMeta{
Name: "jdoe",
UID: "12345",
},
FullName: "John Doe",
Groups: []string{"users"},
},
})
defer s.Close()
h, err := newHTTPClient(true, "")
expectNil(t, err)
oc := openshiftConnector{apiURL: s.URL, httpClient: h, oauth2Config: &oauth2.Config{
Endpoint: oauth2.Endpoint{
AuthURL: fmt.Sprintf("%s/oauth/authorize", s.URL),
TokenURL: fmt.Sprintf("%s/oauth/token", s.URL),
},
}}
data, err := json.Marshal(oauth2.Token{AccessToken: "fFAGRNJru1FTz70BzhT3Zg"})
expectNil(t, err)
oldID := connector.Identity{ConnectorData: data}
identity, err := oc.Refresh(context.Background(), connector.Scopes{Groups: true}, oldID)
expectNil(t, err)
expectEquals(t, identity.UserID, "12345")
expectEquals(t, identity.Username, "jdoe")
expectEquals(t, identity.PreferredUsername, "jdoe")
expectEquals(t, identity.Email, "jdoe")
expectEquals(t, len(identity.Groups), 1)
expectEquals(t, identity.Groups[0], "users")
}
func TestRefreshIdentityFailure(t *testing.T) {
s := newTestServer(map[string]interface{}{
usersURLPath: user{
ObjectMeta: k8sapi.ObjectMeta{
Name: "jdoe",
UID: "12345",
},
FullName: "John Doe",
Groups: []string{"users"},
},
})
defer s.Close()
h, err := newHTTPClient(true, "")
expectNil(t, err)
oc := openshiftConnector{apiURL: s.URL, httpClient: h, oauth2Config: &oauth2.Config{
Endpoint: oauth2.Endpoint{
AuthURL: fmt.Sprintf("%s/oauth/authorize", s.URL),
TokenURL: fmt.Sprintf("%s/oauth/token", s.URL),
},
}}
data, err := json.Marshal(oauth2.Token{AccessToken: "oRzxVjCnohYRHEYEhZshkmakKmoyVoTjfUGC", Expiry: time.Now().Add(-time.Hour)})
expectNil(t, err)
oldID := connector.Identity{ConnectorData: data}
identity, err := oc.Refresh(context.Background(), connector.Scopes{Groups: true}, oldID)
expectNotNil(t, err)
expectEquals(t, connector.Identity{}, identity)
}
func newTestServer(responses map[string]interface{}) *httptest.Server { func newTestServer(responses map[string]interface{}) *httptest.Server {
var s *httptest.Server var s *httptest.Server
s = httptest.NewTLSServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { s = httptest.NewTLSServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
@ -216,3 +289,9 @@ func expectEquals(t *testing.T, a interface{}, b interface{}) {
t.Errorf("Expected %+v to equal %+v", a, b) t.Errorf("Expected %+v to equal %+v", a, b)
} }
} }
func expectNotNil(t *testing.T, a interface{}) {
if a == nil {
t.Errorf("Expected %+v to not equal nil", a)
}
}