Merge pull request #1180 from JoelSpeed/refresh-tokens

Implement refreshing with Google
This commit is contained in:
Nándor István Krácser 2019-11-19 17:39:23 +01:00 committed by GitHub
commit b1e98d8590
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
8 changed files with 172 additions and 58 deletions

View File

@ -3,12 +3,14 @@ package oidc
import ( import (
"context" "context"
"encoding/json"
"errors" "errors"
"fmt" "fmt"
"net/http" "net/http"
"net/url" "net/url"
"strings" "strings"
"sync" "sync"
"time"
"github.com/coreos/go-oidc" "github.com/coreos/go-oidc"
"golang.org/x/oauth2" "golang.org/x/oauth2"
@ -60,6 +62,11 @@ var brokenAuthHeaderDomains = []string{
"oktapreview.com", "oktapreview.com",
} }
// connectorData stores information for sessions authenticated by this connector
type connectorData struct {
RefreshToken []byte
}
// Detect auth header provider issues for known providers. This lets users // Detect auth header provider issues for known providers. This lets users
// avoid having to explicitly set "basicAuthUnsupported" in their config. // avoid having to explicitly set "basicAuthUnsupported" in their config.
// //
@ -167,14 +174,19 @@ func (c *oidcConnector) LoginURL(s connector.Scopes, callbackURL, state string)
return "", fmt.Errorf("expected callback URL %q did not match the URL in the config %q", callbackURL, c.redirectURI) return "", fmt.Errorf("expected callback URL %q did not match the URL in the config %q", callbackURL, c.redirectURI)
} }
var opts []oauth2.AuthCodeOption
if len(c.hostedDomains) > 0 { if len(c.hostedDomains) > 0 {
preferredDomain := c.hostedDomains[0] preferredDomain := c.hostedDomains[0]
if len(c.hostedDomains) > 1 { if len(c.hostedDomains) > 1 {
preferredDomain = "*" preferredDomain = "*"
} }
return c.oauth2Config.AuthCodeURL(state, oauth2.SetAuthURLParam("hd", preferredDomain)), nil opts = append(opts, oauth2.SetAuthURLParam("hd", preferredDomain))
} }
return c.oauth2Config.AuthCodeURL(state), nil
if s.OfflineAccess {
opts = append(opts, oauth2.AccessTypeOffline, oauth2.SetAuthURLParam("prompt", "consent"))
}
return c.oauth2Config.AuthCodeURL(state, opts...), nil
} }
type oauth2Error struct { type oauth2Error struct {
@ -199,11 +211,35 @@ func (c *oidcConnector) HandleCallback(s connector.Scopes, r *http.Request) (ide
return identity, fmt.Errorf("oidc: failed to get token: %v", err) return identity, fmt.Errorf("oidc: failed to get token: %v", err)
} }
return c.createIdentity(r.Context(), identity, token)
}
// Refresh is used to refresh a session with the refresh token provided by the IdP
func (c *oidcConnector) Refresh(ctx context.Context, s connector.Scopes, identity connector.Identity) (connector.Identity, error) {
cd := connectorData{}
err := json.Unmarshal(identity.ConnectorData, &cd)
if err != nil {
return identity, fmt.Errorf("oidc: failed to unmarshal connector data: %v", err)
}
t := &oauth2.Token{
RefreshToken: string(cd.RefreshToken),
Expiry: time.Now().Add(-time.Hour),
}
token, err := c.oauth2Config.TokenSource(ctx, t).Token()
if err != nil {
return identity, fmt.Errorf("oidc: failed to get refresh token: %v", err)
}
return c.createIdentity(ctx, identity, token)
}
func (c *oidcConnector) createIdentity(ctx context.Context, identity connector.Identity, token *oauth2.Token) (connector.Identity, error) {
rawIDToken, ok := token.Extra("id_token").(string) rawIDToken, ok := token.Extra("id_token").(string)
if !ok { if !ok {
return identity, errors.New("oidc: no id_token in token response") return identity, errors.New("oidc: no id_token in token response")
} }
idToken, err := c.verifier.Verify(r.Context(), rawIDToken) idToken, err := c.verifier.Verify(ctx, rawIDToken)
if err != nil { if err != nil {
return identity, fmt.Errorf("oidc: failed to verify ID Token: %v", err) return identity, fmt.Errorf("oidc: failed to verify ID Token: %v", err)
} }
@ -215,7 +251,7 @@ func (c *oidcConnector) HandleCallback(s connector.Scopes, r *http.Request) (ide
// We immediately want to run getUserInfo if configured before we validate the claims // We immediately want to run getUserInfo if configured before we validate the claims
if c.getUserInfo { if c.getUserInfo {
userInfo, err := c.provider.UserInfo(r.Context(), oauth2.StaticTokenSource(token)) userInfo, err := c.provider.UserInfo(ctx, oauth2.StaticTokenSource(token))
if err != nil { if err != nil {
return identity, fmt.Errorf("oidc: error loading userinfo: %v", err) return identity, fmt.Errorf("oidc: error loading userinfo: %v", err)
} }
@ -260,11 +296,21 @@ func (c *oidcConnector) HandleCallback(s connector.Scopes, r *http.Request) (ide
} }
} }
cd := connectorData{
RefreshToken: []byte(token.RefreshToken),
}
connData, err := json.Marshal(&cd)
if err != nil {
return identity, fmt.Errorf("oidc: failed to encode connector data: %v", err)
}
identity = connector.Identity{ identity = connector.Identity{
UserID: idToken.Subject, UserID: idToken.Subject,
Username: name, Username: name,
Email: email, Email: email,
EmailVerified: emailVerified, EmailVerified: emailVerified,
ConnectorData: connData,
} }
if c.userIDKey != "" { if c.userIDKey != "" {
@ -277,8 +323,3 @@ func (c *oidcConnector) HandleCallback(s connector.Scopes, r *http.Request) (ide
return identity, nil return identity, nil
} }
// Refresh is implemented for backwards compatibility, even though it's a no-op.
func (c *oidcConnector) Refresh(ctx context.Context, s connector.Scopes, identity connector.Identity) (connector.Identity, error) {
return identity, nil
}

View File

@ -505,7 +505,45 @@ func (s *Server) finalizeLogin(identity connector.Identity, authReq storage.Auth
s.logger.Infof("login successful: connector %q, username=%q, preferred_username=%q, email=%q, groups=%q", s.logger.Infof("login successful: connector %q, username=%q, preferred_username=%q, email=%q, groups=%q",
authReq.ConnectorID, claims.Username, claims.PreferredUsername, email, claims.Groups) authReq.ConnectorID, claims.Username, claims.PreferredUsername, email, claims.Groups)
return path.Join(s.issuerURL.Path, "/approval") + "?req=" + authReq.ID, nil returnURL := path.Join(s.issuerURL.Path, "/approval") + "?req=" + authReq.ID
_, ok := conn.(connector.RefreshConnector)
if !ok {
return returnURL, nil
}
// Try to retrieve an existing OfflineSession object for the corresponding user.
if session, err := s.storage.GetOfflineSessions(identity.UserID, authReq.ConnectorID); err != nil {
if err != storage.ErrNotFound {
s.logger.Errorf("failed to get offline session: %v", err)
return "", err
}
offlineSessions := storage.OfflineSessions{
UserID: identity.UserID,
ConnID: authReq.ConnectorID,
Refresh: make(map[string]*storage.RefreshTokenRef),
ConnectorData: identity.ConnectorData,
}
// Create a new OfflineSession object for the user and add a reference object for
// the newly received refreshtoken.
if err := s.storage.CreateOfflineSessions(offlineSessions); err != nil {
s.logger.Errorf("failed to create offline session: %v", err)
return "", err
}
} else {
// Update existing OfflineSession obj with new RefreshTokenRef.
if err := s.storage.UpdateOfflineSessions(session.UserID, session.ConnID, func(old storage.OfflineSessions) (storage.OfflineSessions, error) {
if len(identity.ConnectorData) > 0 {
old.ConnectorData = identity.ConnectorData
}
return old, nil
}); err != nil {
s.logger.Errorf("failed to update offline session: %v", err)
return "", err
}
}
return returnURL, nil
} }
func (s *Server) handleApproval(w http.ResponseWriter, r *http.Request) { func (s *Server) handleApproval(w http.ResponseWriter, r *http.Request) {
@ -962,6 +1000,19 @@ func (s *Server) handleRefreshToken(w http.ResponseWriter, r *http.Request, clie
scopes = requestedScopes scopes = requestedScopes
} }
var connectorData []byte
if session, err := s.storage.GetOfflineSessions(refresh.Claims.UserID, refresh.ConnectorID); err != nil {
if err != storage.ErrNotFound {
s.logger.Errorf("failed to get offline session: %v", err)
return
}
} else if len(refresh.ConnectorData) > 0 {
// Use the old connector data if it exists, should be deleted once used
connectorData = refresh.ConnectorData
} else {
connectorData = session.ConnectorData
}
conn, err := s.getConnector(refresh.ConnectorID) conn, err := s.getConnector(refresh.ConnectorID)
if err != nil { if err != nil {
s.logger.Errorf("connector with ID %q not found: %v", refresh.ConnectorID, err) s.logger.Errorf("connector with ID %q not found: %v", refresh.ConnectorID, err)
@ -975,7 +1026,7 @@ func (s *Server) handleRefreshToken(w http.ResponseWriter, r *http.Request, clie
Email: refresh.Claims.Email, Email: refresh.Claims.Email,
EmailVerified: refresh.Claims.EmailVerified, EmailVerified: refresh.Claims.EmailVerified,
Groups: refresh.Claims.Groups, Groups: refresh.Claims.Groups,
ConnectorData: refresh.ConnectorData, ConnectorData: connectorData,
} }
// Can the connector refresh the identity? If so, attempt to refresh the data // Can the connector refresh the identity? If so, attempt to refresh the data
@ -1041,8 +1092,10 @@ func (s *Server) handleRefreshToken(w http.ResponseWriter, r *http.Request, clie
old.Claims.Email = ident.Email old.Claims.Email = ident.Email
old.Claims.EmailVerified = ident.EmailVerified old.Claims.EmailVerified = ident.EmailVerified
old.Claims.Groups = ident.Groups old.Claims.Groups = ident.Groups
old.ConnectorData = ident.ConnectorData
old.LastUsed = lastUsed old.LastUsed = lastUsed
// ConnectorData has been moved to OfflineSession
old.ConnectorData = []byte{}
return old, nil return old, nil
} }
@ -1053,6 +1106,7 @@ func (s *Server) handleRefreshToken(w http.ResponseWriter, r *http.Request, clie
return old, errors.New("refresh token invalid") return old, errors.New("refresh token invalid")
} }
old.Refresh[refresh.ClientID].LastUsed = lastUsed old.Refresh[refresh.ClientID].LastUsed = lastUsed
old.ConnectorData = ident.ConnectorData
return old, nil return old, nil
}); err != nil { }); err != nil {
s.logger.Errorf("failed to update offline session: %v", err) s.logger.Errorf("failed to update offline session: %v", err)

View File

@ -515,9 +515,10 @@ func testPasswordCRUD(t *testing.T, s storage.Storage) {
func testOfflineSessionCRUD(t *testing.T, s storage.Storage) { func testOfflineSessionCRUD(t *testing.T, s storage.Storage) {
userID1 := storage.NewID() userID1 := storage.NewID()
session1 := storage.OfflineSessions{ session1 := storage.OfflineSessions{
UserID: userID1, UserID: userID1,
ConnID: "Conn1", ConnID: "Conn1",
Refresh: make(map[string]*storage.RefreshTokenRef), Refresh: make(map[string]*storage.RefreshTokenRef),
ConnectorData: []byte(`{"some":"data"}`),
} }
// Creating an OfflineSession with an empty Refresh list to ensure that // Creating an OfflineSession with an empty Refresh list to ensure that
@ -532,9 +533,10 @@ func testOfflineSessionCRUD(t *testing.T, s storage.Storage) {
userID2 := storage.NewID() userID2 := storage.NewID()
session2 := storage.OfflineSessions{ session2 := storage.OfflineSessions{
UserID: userID2, UserID: userID2,
ConnID: "Conn2", ConnID: "Conn2",
Refresh: make(map[string]*storage.RefreshTokenRef), Refresh: make(map[string]*storage.RefreshTokenRef),
ConnectorData: []byte(`{"some":"data"}`),
} }
if err := s.CreateOfflineSessions(session2); err != nil { if err := s.CreateOfflineSessions(session2); err != nil {

View File

@ -188,24 +188,27 @@ type Keys struct {
// OfflineSessions is a mirrored struct from storage with JSON struct tags // OfflineSessions is a mirrored struct from storage with JSON struct tags
type OfflineSessions struct { type OfflineSessions struct {
UserID string `json:"user_id,omitempty"` UserID string `json:"user_id,omitempty"`
ConnID string `json:"conn_id,omitempty"` ConnID string `json:"conn_id,omitempty"`
Refresh map[string]*storage.RefreshTokenRef `json:"refresh,omitempty"` Refresh map[string]*storage.RefreshTokenRef `json:"refresh,omitempty"`
ConnectorData []byte `json:"connectorData,omitempty"`
} }
func fromStorageOfflineSessions(o storage.OfflineSessions) OfflineSessions { func fromStorageOfflineSessions(o storage.OfflineSessions) OfflineSessions {
return OfflineSessions{ return OfflineSessions{
UserID: o.UserID, UserID: o.UserID,
ConnID: o.ConnID, ConnID: o.ConnID,
Refresh: o.Refresh, Refresh: o.Refresh,
ConnectorData: o.ConnectorData,
} }
} }
func toStorageOfflineSessions(o OfflineSessions) storage.OfflineSessions { func toStorageOfflineSessions(o OfflineSessions) storage.OfflineSessions {
s := storage.OfflineSessions{ s := storage.OfflineSessions{
UserID: o.UserID, UserID: o.UserID,
ConnID: o.ConnID, ConnID: o.ConnID,
Refresh: o.Refresh, Refresh: o.Refresh,
ConnectorData: o.ConnectorData,
} }
if s.Refresh == nil { if s.Refresh == nil {
// Server code assumes this will be non-nil. // Server code assumes this will be non-nil.

View File

@ -552,9 +552,10 @@ type OfflineSessions struct {
k8sapi.TypeMeta `json:",inline"` k8sapi.TypeMeta `json:",inline"`
k8sapi.ObjectMeta `json:"metadata,omitempty"` k8sapi.ObjectMeta `json:"metadata,omitempty"`
UserID string `json:"userID,omitempty"` UserID string `json:"userID,omitempty"`
ConnID string `json:"connID,omitempty"` ConnID string `json:"connID,omitempty"`
Refresh map[string]*storage.RefreshTokenRef `json:"refresh,omitempty"` Refresh map[string]*storage.RefreshTokenRef `json:"refresh,omitempty"`
ConnectorData []byte `json:"connectorData,omitempty"`
} }
func (cli *client) fromStorageOfflineSessions(o storage.OfflineSessions) OfflineSessions { func (cli *client) fromStorageOfflineSessions(o storage.OfflineSessions) OfflineSessions {
@ -567,17 +568,19 @@ func (cli *client) fromStorageOfflineSessions(o storage.OfflineSessions) Offline
Name: cli.offlineTokenName(o.UserID, o.ConnID), Name: cli.offlineTokenName(o.UserID, o.ConnID),
Namespace: cli.namespace, Namespace: cli.namespace,
}, },
UserID: o.UserID, UserID: o.UserID,
ConnID: o.ConnID, ConnID: o.ConnID,
Refresh: o.Refresh, Refresh: o.Refresh,
ConnectorData: o.ConnectorData,
} }
} }
func toStorageOfflineSessions(o OfflineSessions) storage.OfflineSessions { func toStorageOfflineSessions(o OfflineSessions) storage.OfflineSessions {
s := storage.OfflineSessions{ s := storage.OfflineSessions{
UserID: o.UserID, UserID: o.UserID,
ConnID: o.ConnID, ConnID: o.ConnID,
Refresh: o.Refresh, Refresh: o.Refresh,
ConnectorData: o.ConnectorData,
} }
if s.Refresh == nil { if s.Refresh == nil {
// Server code assumes this will be non-nil. // Server code assumes this will be non-nil.

View File

@ -108,7 +108,7 @@ func (c *conn) CreateAuthRequest(a storage.AuthRequest) error {
insert into auth_request ( insert into auth_request (
id, client_id, response_types, scopes, redirect_uri, nonce, state, id, client_id, response_types, scopes, redirect_uri, nonce, state,
force_approval_prompt, logged_in, force_approval_prompt, logged_in,
claims_user_id, claims_username, claims_preferred_username, claims_user_id, claims_username, claims_preferred_username,
claims_email, claims_email_verified, claims_groups, claims_email, claims_email_verified, claims_groups,
connector_id, connector_data, connector_id, connector_data,
expiry expiry
@ -178,7 +178,7 @@ func (c *conn) GetAuthRequest(id string) (storage.AuthRequest, error) {
func getAuthRequest(q querier, id string) (a storage.AuthRequest, err error) { func getAuthRequest(q querier, id string) (a storage.AuthRequest, err error) {
err = q.QueryRow(` err = q.QueryRow(`
select select
id, client_id, response_types, scopes, redirect_uri, nonce, state, id, client_id, response_types, scopes, redirect_uri, nonce, state,
force_approval_prompt, logged_in, force_approval_prompt, logged_in,
claims_user_id, claims_username, claims_preferred_username, claims_user_id, claims_username, claims_preferred_username,
@ -299,7 +299,7 @@ func (c *conn) UpdateRefreshToken(id string, updater func(old storage.RefreshTok
claims_email_verified = $8, claims_email_verified = $8,
claims_groups = $9, claims_groups = $9,
connector_id = $10, connector_id = $10,
connector_data = $11, connector_data = $11,
token = $12, token = $12,
created_at = $13, created_at = $13,
last_used = $14 last_used = $14
@ -417,7 +417,7 @@ func (c *conn) UpdateKeys(updater func(old storage.Keys) (storage.Keys, error))
} else { } else {
_, err = tx.Exec(` _, err = tx.Exec(`
update keys update keys
set set
verification_keys = $1, verification_keys = $1,
signing_key = $2, signing_key = $2,
signing_key_pub = $3, signing_key_pub = $3,
@ -655,13 +655,13 @@ func scanPassword(s scanner) (p storage.Password, err error) {
func (c *conn) CreateOfflineSessions(s storage.OfflineSessions) error { func (c *conn) CreateOfflineSessions(s storage.OfflineSessions) error {
_, err := c.Exec(` _, err := c.Exec(`
insert into offline_session ( insert into offline_session (
user_id, conn_id, refresh user_id, conn_id, refresh, connector_data
) )
values ( values (
$1, $2, $3 $1, $2, $3, $4
); );
`, `,
s.UserID, s.ConnID, encoder(s.Refresh), s.UserID, s.ConnID, encoder(s.Refresh), s.ConnectorData,
) )
if err != nil { if err != nil {
if c.alreadyExistsCheck(err) { if c.alreadyExistsCheck(err) {
@ -686,10 +686,11 @@ func (c *conn) UpdateOfflineSessions(userID string, connID string, updater func(
_, err = tx.Exec(` _, err = tx.Exec(`
update offline_session update offline_session
set set
refresh = $1 refresh = $1,
where user_id = $2 AND conn_id = $3; connector_data = $2
where user_id = $3 AND conn_id = $4;
`, `,
encoder(newSession.Refresh), s.UserID, s.ConnID, encoder(newSession.Refresh), newSession.ConnectorData, s.UserID, s.ConnID,
) )
if err != nil { if err != nil {
return fmt.Errorf("update offline session: %v", err) return fmt.Errorf("update offline session: %v", err)
@ -705,7 +706,7 @@ func (c *conn) GetOfflineSessions(userID string, connID string) (storage.Offline
func getOfflineSessions(q querier, userID string, connID string) (storage.OfflineSessions, error) { func getOfflineSessions(q querier, userID string, connID string) (storage.OfflineSessions, error) {
return scanOfflineSessions(q.QueryRow(` return scanOfflineSessions(q.QueryRow(`
select select
user_id, conn_id, refresh user_id, conn_id, refresh, connector_data
from offline_session from offline_session
where user_id = $1 AND conn_id = $2; where user_id = $1 AND conn_id = $2;
`, userID, connID)) `, userID, connID))
@ -713,7 +714,7 @@ func getOfflineSessions(q querier, userID string, connID string) (storage.Offlin
func scanOfflineSessions(s scanner) (o storage.OfflineSessions, err error) { func scanOfflineSessions(s scanner) (o storage.OfflineSessions, err error) {
err = s.Scan( err = s.Scan(
&o.UserID, &o.ConnID, decoder(&o.Refresh), &o.UserID, &o.ConnID, decoder(&o.Refresh), &o.ConnectorData,
) )
if err != nil { if err != nil {
if err == sql.ErrNoRows { if err == sql.ErrNoRows {
@ -757,7 +758,7 @@ func (c *conn) UpdateConnector(id string, updater func(s storage.Connector) (sto
} }
_, err = tx.Exec(` _, err = tx.Exec(`
update connector update connector
set set
type = $1, type = $1,
name = $2, name = $2,
resource_version = $3, resource_version = $3,

View File

@ -90,18 +90,18 @@ var migrations = []migration{
nonce text not null, nonce text not null,
state text not null, state text not null,
force_approval_prompt boolean not null, force_approval_prompt boolean not null,
logged_in boolean not null, logged_in boolean not null,
claims_user_id text not null, claims_user_id text not null,
claims_username text not null, claims_username text not null,
claims_email text not null, claims_email text not null,
claims_email_verified boolean not null, claims_email_verified boolean not null,
claims_groups bytea not null, -- JSON array of strings claims_groups bytea not null, -- JSON array of strings
connector_id text not null, connector_id text not null,
connector_data bytea, connector_data bytea,
expiry timestamptz not null expiry timestamptz not null
);`, );`,
` `
@ -111,16 +111,16 @@ var migrations = []migration{
scopes bytea not null, -- JSON array of strings scopes bytea not null, -- JSON array of strings
nonce text not null, nonce text not null,
redirect_uri text not null, redirect_uri text not null,
claims_user_id text not null, claims_user_id text not null,
claims_username text not null, claims_username text not null,
claims_email text not null, claims_email text not null,
claims_email_verified boolean not null, claims_email_verified boolean not null,
claims_groups bytea not null, -- JSON array of strings claims_groups bytea not null, -- JSON array of strings
connector_id text not null, connector_id text not null,
connector_data bytea, connector_data bytea,
expiry timestamptz not null expiry timestamptz not null
);`, );`,
` `
@ -129,13 +129,13 @@ var migrations = []migration{
client_id text not null, client_id text not null,
scopes bytea not null, -- JSON array of strings scopes bytea not null, -- JSON array of strings
nonce text not null, nonce text not null,
claims_user_id text not null, claims_user_id text not null,
claims_username text not null, claims_username text not null,
claims_email text not null, claims_email text not null,
claims_email_verified boolean not null, claims_email_verified boolean not null,
claims_groups bytea not null, -- JSON array of strings claims_groups bytea not null, -- JSON array of strings
connector_id text not null, connector_id text not null,
connector_data bytea connector_data bytea
);`, );`,
@ -202,4 +202,11 @@ var migrations = []migration{
add column claims_preferred_username text not null default '';`, add column claims_preferred_username text not null default '';`,
}, },
}, },
{
stmts: []string{`
alter table offline_session
add column connector_data bytea;
`,
},
},
} }

View File

@ -273,6 +273,9 @@ type OfflineSessions struct {
// Refresh is a hash table of refresh token reference objects // Refresh is a hash table of refresh token reference objects
// indexed by the ClientID of the refresh token. // indexed by the ClientID of the refresh token.
Refresh map[string]*RefreshTokenRef Refresh map[string]*RefreshTokenRef
// Authentication data provided by an upstream source.
ConnectorData []byte
} }
// Password is an email to password mapping managed by the storage. // Password is an email to password mapping managed by the storage.