diff --git a/connector/oidc/oidc.go b/connector/oidc/oidc.go index b5e075ad..3e405d87 100644 --- a/connector/oidc/oidc.go +++ b/connector/oidc/oidc.go @@ -3,12 +3,14 @@ package oidc import ( "context" + "encoding/json" "errors" "fmt" "net/http" "net/url" "strings" "sync" + "time" "github.com/coreos/go-oidc" "golang.org/x/oauth2" @@ -60,6 +62,11 @@ var brokenAuthHeaderDomains = []string{ "oktapreview.com", } +// connectorData stores information for sessions authenticated by this connector +type connectorData struct { + RefreshToken []byte +} + // Detect auth header provider issues for known providers. This lets users // avoid having to explicitly set "basicAuthUnsupported" in their config. // @@ -167,14 +174,19 @@ func (c *oidcConnector) LoginURL(s connector.Scopes, callbackURL, state string) return "", fmt.Errorf("expected callback URL %q did not match the URL in the config %q", callbackURL, c.redirectURI) } + var opts []oauth2.AuthCodeOption if len(c.hostedDomains) > 0 { preferredDomain := c.hostedDomains[0] if len(c.hostedDomains) > 1 { preferredDomain = "*" } - return c.oauth2Config.AuthCodeURL(state, oauth2.SetAuthURLParam("hd", preferredDomain)), nil + opts = append(opts, oauth2.SetAuthURLParam("hd", preferredDomain)) } - return c.oauth2Config.AuthCodeURL(state), nil + + if s.OfflineAccess { + opts = append(opts, oauth2.AccessTypeOffline, oauth2.SetAuthURLParam("prompt", "consent")) + } + return c.oauth2Config.AuthCodeURL(state, opts...), nil } type oauth2Error struct { @@ -199,11 +211,35 @@ func (c *oidcConnector) HandleCallback(s connector.Scopes, r *http.Request) (ide return identity, fmt.Errorf("oidc: failed to get token: %v", err) } + return c.createIdentity(r.Context(), identity, token) +} + +// Refresh is used to refresh a session with the refresh token provided by the IdP +func (c *oidcConnector) Refresh(ctx context.Context, s connector.Scopes, identity connector.Identity) (connector.Identity, error) { + cd := connectorData{} + err := json.Unmarshal(identity.ConnectorData, &cd) + if err != nil { + return identity, fmt.Errorf("oidc: failed to unmarshal connector data: %v", err) + } + + t := &oauth2.Token{ + RefreshToken: string(cd.RefreshToken), + Expiry: time.Now().Add(-time.Hour), + } + token, err := c.oauth2Config.TokenSource(ctx, t).Token() + if err != nil { + return identity, fmt.Errorf("oidc: failed to get refresh token: %v", err) + } + + return c.createIdentity(ctx, identity, token) +} + +func (c *oidcConnector) createIdentity(ctx context.Context, identity connector.Identity, token *oauth2.Token) (connector.Identity, error) { rawIDToken, ok := token.Extra("id_token").(string) if !ok { return identity, errors.New("oidc: no id_token in token response") } - idToken, err := c.verifier.Verify(r.Context(), rawIDToken) + idToken, err := c.verifier.Verify(ctx, rawIDToken) if err != nil { return identity, fmt.Errorf("oidc: failed to verify ID Token: %v", err) } @@ -215,7 +251,7 @@ func (c *oidcConnector) HandleCallback(s connector.Scopes, r *http.Request) (ide // We immediately want to run getUserInfo if configured before we validate the claims if c.getUserInfo { - userInfo, err := c.provider.UserInfo(r.Context(), oauth2.StaticTokenSource(token)) + userInfo, err := c.provider.UserInfo(ctx, oauth2.StaticTokenSource(token)) if err != nil { return identity, fmt.Errorf("oidc: error loading userinfo: %v", err) } @@ -260,11 +296,21 @@ func (c *oidcConnector) HandleCallback(s connector.Scopes, r *http.Request) (ide } } + cd := connectorData{ + RefreshToken: []byte(token.RefreshToken), + } + + connData, err := json.Marshal(&cd) + if err != nil { + return identity, fmt.Errorf("oidc: failed to encode connector data: %v", err) + } + identity = connector.Identity{ UserID: idToken.Subject, Username: name, Email: email, EmailVerified: emailVerified, + ConnectorData: connData, } if c.userIDKey != "" { @@ -277,8 +323,3 @@ func (c *oidcConnector) HandleCallback(s connector.Scopes, r *http.Request) (ide return identity, nil } - -// Refresh is implemented for backwards compatibility, even though it's a no-op. -func (c *oidcConnector) Refresh(ctx context.Context, s connector.Scopes, identity connector.Identity) (connector.Identity, error) { - return identity, nil -} diff --git a/server/handlers.go b/server/handlers.go index b528918f..0f5b0d23 100644 --- a/server/handlers.go +++ b/server/handlers.go @@ -505,7 +505,45 @@ func (s *Server) finalizeLogin(identity connector.Identity, authReq storage.Auth s.logger.Infof("login successful: connector %q, username=%q, preferred_username=%q, email=%q, groups=%q", authReq.ConnectorID, claims.Username, claims.PreferredUsername, email, claims.Groups) - return path.Join(s.issuerURL.Path, "/approval") + "?req=" + authReq.ID, nil + returnURL := path.Join(s.issuerURL.Path, "/approval") + "?req=" + authReq.ID + _, ok := conn.(connector.RefreshConnector) + if !ok { + return returnURL, nil + } + + // Try to retrieve an existing OfflineSession object for the corresponding user. + if session, err := s.storage.GetOfflineSessions(identity.UserID, authReq.ConnectorID); err != nil { + if err != storage.ErrNotFound { + s.logger.Errorf("failed to get offline session: %v", err) + return "", err + } + offlineSessions := storage.OfflineSessions{ + UserID: identity.UserID, + ConnID: authReq.ConnectorID, + Refresh: make(map[string]*storage.RefreshTokenRef), + ConnectorData: identity.ConnectorData, + } + + // Create a new OfflineSession object for the user and add a reference object for + // the newly received refreshtoken. + if err := s.storage.CreateOfflineSessions(offlineSessions); err != nil { + s.logger.Errorf("failed to create offline session: %v", err) + return "", err + } + } else { + // Update existing OfflineSession obj with new RefreshTokenRef. + if err := s.storage.UpdateOfflineSessions(session.UserID, session.ConnID, func(old storage.OfflineSessions) (storage.OfflineSessions, error) { + if len(identity.ConnectorData) > 0 { + old.ConnectorData = identity.ConnectorData + } + return old, nil + }); err != nil { + s.logger.Errorf("failed to update offline session: %v", err) + return "", err + } + } + + return returnURL, nil } func (s *Server) handleApproval(w http.ResponseWriter, r *http.Request) { @@ -962,6 +1000,19 @@ func (s *Server) handleRefreshToken(w http.ResponseWriter, r *http.Request, clie scopes = requestedScopes } + var connectorData []byte + if session, err := s.storage.GetOfflineSessions(refresh.Claims.UserID, refresh.ConnectorID); err != nil { + if err != storage.ErrNotFound { + s.logger.Errorf("failed to get offline session: %v", err) + return + } + } else if len(refresh.ConnectorData) > 0 { + // Use the old connector data if it exists, should be deleted once used + connectorData = refresh.ConnectorData + } else { + connectorData = session.ConnectorData + } + conn, err := s.getConnector(refresh.ConnectorID) if err != nil { s.logger.Errorf("connector with ID %q not found: %v", refresh.ConnectorID, err) @@ -975,7 +1026,7 @@ func (s *Server) handleRefreshToken(w http.ResponseWriter, r *http.Request, clie Email: refresh.Claims.Email, EmailVerified: refresh.Claims.EmailVerified, Groups: refresh.Claims.Groups, - ConnectorData: refresh.ConnectorData, + ConnectorData: connectorData, } // Can the connector refresh the identity? If so, attempt to refresh the data @@ -1041,8 +1092,10 @@ func (s *Server) handleRefreshToken(w http.ResponseWriter, r *http.Request, clie old.Claims.Email = ident.Email old.Claims.EmailVerified = ident.EmailVerified old.Claims.Groups = ident.Groups - old.ConnectorData = ident.ConnectorData old.LastUsed = lastUsed + + // ConnectorData has been moved to OfflineSession + old.ConnectorData = []byte{} return old, nil } @@ -1053,6 +1106,7 @@ func (s *Server) handleRefreshToken(w http.ResponseWriter, r *http.Request, clie return old, errors.New("refresh token invalid") } old.Refresh[refresh.ClientID].LastUsed = lastUsed + old.ConnectorData = ident.ConnectorData return old, nil }); err != nil { s.logger.Errorf("failed to update offline session: %v", err) diff --git a/storage/conformance/conformance.go b/storage/conformance/conformance.go index a1399807..9832a7d8 100644 --- a/storage/conformance/conformance.go +++ b/storage/conformance/conformance.go @@ -515,9 +515,10 @@ func testPasswordCRUD(t *testing.T, s storage.Storage) { func testOfflineSessionCRUD(t *testing.T, s storage.Storage) { userID1 := storage.NewID() session1 := storage.OfflineSessions{ - UserID: userID1, - ConnID: "Conn1", - Refresh: make(map[string]*storage.RefreshTokenRef), + UserID: userID1, + ConnID: "Conn1", + Refresh: make(map[string]*storage.RefreshTokenRef), + ConnectorData: []byte(`{"some":"data"}`), } // Creating an OfflineSession with an empty Refresh list to ensure that @@ -532,9 +533,10 @@ func testOfflineSessionCRUD(t *testing.T, s storage.Storage) { userID2 := storage.NewID() session2 := storage.OfflineSessions{ - UserID: userID2, - ConnID: "Conn2", - Refresh: make(map[string]*storage.RefreshTokenRef), + UserID: userID2, + ConnID: "Conn2", + Refresh: make(map[string]*storage.RefreshTokenRef), + ConnectorData: []byte(`{"some":"data"}`), } if err := s.CreateOfflineSessions(session2); err != nil { diff --git a/storage/etcd/types.go b/storage/etcd/types.go index 8063c69f..a16eae8e 100644 --- a/storage/etcd/types.go +++ b/storage/etcd/types.go @@ -188,24 +188,27 @@ type Keys struct { // OfflineSessions is a mirrored struct from storage with JSON struct tags type OfflineSessions struct { - UserID string `json:"user_id,omitempty"` - ConnID string `json:"conn_id,omitempty"` - Refresh map[string]*storage.RefreshTokenRef `json:"refresh,omitempty"` + UserID string `json:"user_id,omitempty"` + ConnID string `json:"conn_id,omitempty"` + Refresh map[string]*storage.RefreshTokenRef `json:"refresh,omitempty"` + ConnectorData []byte `json:"connectorData,omitempty"` } func fromStorageOfflineSessions(o storage.OfflineSessions) OfflineSessions { return OfflineSessions{ - UserID: o.UserID, - ConnID: o.ConnID, - Refresh: o.Refresh, + UserID: o.UserID, + ConnID: o.ConnID, + Refresh: o.Refresh, + ConnectorData: o.ConnectorData, } } func toStorageOfflineSessions(o OfflineSessions) storage.OfflineSessions { s := storage.OfflineSessions{ - UserID: o.UserID, - ConnID: o.ConnID, - Refresh: o.Refresh, + UserID: o.UserID, + ConnID: o.ConnID, + Refresh: o.Refresh, + ConnectorData: o.ConnectorData, } if s.Refresh == nil { // Server code assumes this will be non-nil. diff --git a/storage/kubernetes/types.go b/storage/kubernetes/types.go index a42238b3..5eda1781 100644 --- a/storage/kubernetes/types.go +++ b/storage/kubernetes/types.go @@ -552,9 +552,10 @@ type OfflineSessions struct { k8sapi.TypeMeta `json:",inline"` k8sapi.ObjectMeta `json:"metadata,omitempty"` - UserID string `json:"userID,omitempty"` - ConnID string `json:"connID,omitempty"` - Refresh map[string]*storage.RefreshTokenRef `json:"refresh,omitempty"` + UserID string `json:"userID,omitempty"` + ConnID string `json:"connID,omitempty"` + Refresh map[string]*storage.RefreshTokenRef `json:"refresh,omitempty"` + ConnectorData []byte `json:"connectorData,omitempty"` } func (cli *client) fromStorageOfflineSessions(o storage.OfflineSessions) OfflineSessions { @@ -567,17 +568,19 @@ func (cli *client) fromStorageOfflineSessions(o storage.OfflineSessions) Offline Name: cli.offlineTokenName(o.UserID, o.ConnID), Namespace: cli.namespace, }, - UserID: o.UserID, - ConnID: o.ConnID, - Refresh: o.Refresh, + UserID: o.UserID, + ConnID: o.ConnID, + Refresh: o.Refresh, + ConnectorData: o.ConnectorData, } } func toStorageOfflineSessions(o OfflineSessions) storage.OfflineSessions { s := storage.OfflineSessions{ - UserID: o.UserID, - ConnID: o.ConnID, - Refresh: o.Refresh, + UserID: o.UserID, + ConnID: o.ConnID, + Refresh: o.Refresh, + ConnectorData: o.ConnectorData, } if s.Refresh == nil { // Server code assumes this will be non-nil. diff --git a/storage/sql/crud.go b/storage/sql/crud.go index e1982928..e96a7b12 100644 --- a/storage/sql/crud.go +++ b/storage/sql/crud.go @@ -108,7 +108,7 @@ func (c *conn) CreateAuthRequest(a storage.AuthRequest) error { insert into auth_request ( id, client_id, response_types, scopes, redirect_uri, nonce, state, force_approval_prompt, logged_in, - claims_user_id, claims_username, claims_preferred_username, + claims_user_id, claims_username, claims_preferred_username, claims_email, claims_email_verified, claims_groups, connector_id, connector_data, expiry @@ -178,7 +178,7 @@ func (c *conn) GetAuthRequest(id string) (storage.AuthRequest, error) { func getAuthRequest(q querier, id string) (a storage.AuthRequest, err error) { err = q.QueryRow(` - select + select id, client_id, response_types, scopes, redirect_uri, nonce, state, force_approval_prompt, logged_in, claims_user_id, claims_username, claims_preferred_username, @@ -299,7 +299,7 @@ func (c *conn) UpdateRefreshToken(id string, updater func(old storage.RefreshTok claims_email_verified = $8, claims_groups = $9, connector_id = $10, - connector_data = $11, + connector_data = $11, token = $12, created_at = $13, last_used = $14 @@ -417,7 +417,7 @@ func (c *conn) UpdateKeys(updater func(old storage.Keys) (storage.Keys, error)) } else { _, err = tx.Exec(` update keys - set + set verification_keys = $1, signing_key = $2, signing_key_pub = $3, @@ -655,13 +655,13 @@ func scanPassword(s scanner) (p storage.Password, err error) { func (c *conn) CreateOfflineSessions(s storage.OfflineSessions) error { _, err := c.Exec(` insert into offline_session ( - user_id, conn_id, refresh + user_id, conn_id, refresh, connector_data ) values ( - $1, $2, $3 + $1, $2, $3, $4 ); `, - s.UserID, s.ConnID, encoder(s.Refresh), + s.UserID, s.ConnID, encoder(s.Refresh), s.ConnectorData, ) if err != nil { if c.alreadyExistsCheck(err) { @@ -686,10 +686,11 @@ func (c *conn) UpdateOfflineSessions(userID string, connID string, updater func( _, err = tx.Exec(` update offline_session set - refresh = $1 - where user_id = $2 AND conn_id = $3; + refresh = $1, + connector_data = $2 + where user_id = $3 AND conn_id = $4; `, - encoder(newSession.Refresh), s.UserID, s.ConnID, + encoder(newSession.Refresh), newSession.ConnectorData, s.UserID, s.ConnID, ) if err != nil { return fmt.Errorf("update offline session: %v", err) @@ -705,7 +706,7 @@ func (c *conn) GetOfflineSessions(userID string, connID string) (storage.Offline func getOfflineSessions(q querier, userID string, connID string) (storage.OfflineSessions, error) { return scanOfflineSessions(q.QueryRow(` select - user_id, conn_id, refresh + user_id, conn_id, refresh, connector_data from offline_session where user_id = $1 AND conn_id = $2; `, userID, connID)) @@ -713,7 +714,7 @@ func getOfflineSessions(q querier, userID string, connID string) (storage.Offlin func scanOfflineSessions(s scanner) (o storage.OfflineSessions, err error) { err = s.Scan( - &o.UserID, &o.ConnID, decoder(&o.Refresh), + &o.UserID, &o.ConnID, decoder(&o.Refresh), &o.ConnectorData, ) if err != nil { if err == sql.ErrNoRows { @@ -757,7 +758,7 @@ func (c *conn) UpdateConnector(id string, updater func(s storage.Connector) (sto } _, err = tx.Exec(` update connector - set + set type = $1, name = $2, resource_version = $3, diff --git a/storage/sql/migrate.go b/storage/sql/migrate.go index 0ef62609..5b86bc78 100644 --- a/storage/sql/migrate.go +++ b/storage/sql/migrate.go @@ -90,18 +90,18 @@ var migrations = []migration{ nonce text not null, state text not null, force_approval_prompt boolean not null, - + logged_in boolean not null, - + claims_user_id text not null, claims_username text not null, claims_email text not null, claims_email_verified boolean not null, claims_groups bytea not null, -- JSON array of strings - + connector_id text not null, connector_data bytea, - + expiry timestamptz not null );`, ` @@ -111,16 +111,16 @@ var migrations = []migration{ scopes bytea not null, -- JSON array of strings nonce text not null, redirect_uri text not null, - + claims_user_id text not null, claims_username text not null, claims_email text not null, claims_email_verified boolean not null, claims_groups bytea not null, -- JSON array of strings - + connector_id text not null, connector_data bytea, - + expiry timestamptz not null );`, ` @@ -129,13 +129,13 @@ var migrations = []migration{ client_id text not null, scopes bytea not null, -- JSON array of strings nonce text not null, - + claims_user_id text not null, claims_username text not null, claims_email text not null, claims_email_verified boolean not null, claims_groups bytea not null, -- JSON array of strings - + connector_id text not null, connector_data bytea );`, @@ -202,4 +202,11 @@ var migrations = []migration{ add column claims_preferred_username text not null default '';`, }, }, + { + stmts: []string{` + alter table offline_session + add column connector_data bytea; + `, + }, + }, } diff --git a/storage/storage.go b/storage/storage.go index 235f74e0..cb2a7e0c 100644 --- a/storage/storage.go +++ b/storage/storage.go @@ -273,6 +273,9 @@ type OfflineSessions struct { // Refresh is a hash table of refresh token reference objects // indexed by the ClientID of the refresh token. Refresh map[string]*RefreshTokenRef + + // Authentication data provided by an upstream source. + ConnectorData []byte } // Password is an email to password mapping managed by the storage.