Merge pull request #793 from rithujohn191/token-revocation
storage: Add OfflineSession object to backend storage.
This commit is contained in:
commit
53e383670a
@ -682,6 +682,75 @@ func (s *Server) handleAuthCode(w http.ResponseWriter, r *http.Request, client s
|
|||||||
s.tokenErrHelper(w, errServerError, "", http.StatusInternalServerError)
|
s.tokenErrHelper(w, errServerError, "", http.StatusInternalServerError)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// deleteToken determines if we need to delete the newly created refresh token
|
||||||
|
// due to a failure in updating/creating the OfflineSession object for the
|
||||||
|
// corresponding user.
|
||||||
|
var deleteToken bool
|
||||||
|
defer func() {
|
||||||
|
if deleteToken {
|
||||||
|
// Delete newly created refresh token from storage.
|
||||||
|
if err := s.storage.DeleteRefresh(refresh.ID); err != nil {
|
||||||
|
s.logger.Errorf("failed to delete refresh token: %v", err)
|
||||||
|
s.tokenErrHelper(w, errServerError, "", http.StatusInternalServerError)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
|
||||||
|
tokenRef := storage.RefreshTokenRef{
|
||||||
|
ID: refresh.ID,
|
||||||
|
ClientID: refresh.ClientID,
|
||||||
|
CreatedAt: refresh.CreatedAt,
|
||||||
|
LastUsed: refresh.LastUsed,
|
||||||
|
}
|
||||||
|
|
||||||
|
// Try to retrieve an existing OfflineSession object for the corresponding user.
|
||||||
|
if session, err := s.storage.GetOfflineSessions(refresh.Claims.UserID, refresh.ConnectorID); err != nil {
|
||||||
|
if err != storage.ErrNotFound {
|
||||||
|
s.logger.Errorf("failed to get offline session: %v", err)
|
||||||
|
s.tokenErrHelper(w, errServerError, "", http.StatusInternalServerError)
|
||||||
|
deleteToken = true
|
||||||
|
return
|
||||||
|
}
|
||||||
|
offlineSessions := storage.OfflineSessions{
|
||||||
|
UserID: refresh.Claims.UserID,
|
||||||
|
ConnID: refresh.ConnectorID,
|
||||||
|
Refresh: make(map[string]*storage.RefreshTokenRef),
|
||||||
|
}
|
||||||
|
offlineSessions.Refresh[tokenRef.ClientID] = &tokenRef
|
||||||
|
|
||||||
|
// Create a new OfflineSession object for the user and add a reference object for
|
||||||
|
// the newly recieved refreshtoken.
|
||||||
|
if err := s.storage.CreateOfflineSessions(offlineSessions); err != nil {
|
||||||
|
s.logger.Errorf("failed to create offline session: %v", err)
|
||||||
|
s.tokenErrHelper(w, errServerError, "", http.StatusInternalServerError)
|
||||||
|
deleteToken = true
|
||||||
|
return
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
if oldTokenRef, ok := session.Refresh[tokenRef.ClientID]; ok {
|
||||||
|
// Delete old refresh token from storage.
|
||||||
|
if err := s.storage.DeleteRefresh(oldTokenRef.ID); err != nil {
|
||||||
|
s.logger.Errorf("failed to delete refresh token: %v", err)
|
||||||
|
s.tokenErrHelper(w, errServerError, "", http.StatusInternalServerError)
|
||||||
|
deleteToken = true
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Update existing OfflineSession obj with new RefreshTokenRef.
|
||||||
|
if err := s.storage.UpdateOfflineSessions(session.UserID, session.ConnID, func(old storage.OfflineSessions) (storage.OfflineSessions, error) {
|
||||||
|
old.Refresh[tokenRef.ClientID] = &tokenRef
|
||||||
|
return old, nil
|
||||||
|
}); err != nil {
|
||||||
|
s.logger.Errorf("failed to update offline session: %v", err)
|
||||||
|
s.tokenErrHelper(w, errServerError, "", http.StatusInternalServerError)
|
||||||
|
deleteToken = true
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
}
|
||||||
}
|
}
|
||||||
s.writeAccessToken(w, idToken, accessToken, refreshToken, expiry)
|
s.writeAccessToken(w, idToken, accessToken, refreshToken, expiry)
|
||||||
}
|
}
|
||||||
@ -815,6 +884,7 @@ func (s *Server) handleRefreshToken(w http.ResponseWriter, r *http.Request, clie
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
lastUsed := s.now()
|
||||||
updater := func(old storage.RefreshToken) (storage.RefreshToken, error) {
|
updater := func(old storage.RefreshToken) (storage.RefreshToken, error) {
|
||||||
if old.Token != refresh.Token {
|
if old.Token != refresh.Token {
|
||||||
return old, errors.New("refresh token claimed twice")
|
return old, errors.New("refresh token claimed twice")
|
||||||
@ -828,14 +898,31 @@ func (s *Server) handleRefreshToken(w http.ResponseWriter, r *http.Request, clie
|
|||||||
old.Claims.EmailVerified = ident.EmailVerified
|
old.Claims.EmailVerified = ident.EmailVerified
|
||||||
old.Claims.Groups = ident.Groups
|
old.Claims.Groups = ident.Groups
|
||||||
old.ConnectorData = ident.ConnectorData
|
old.ConnectorData = ident.ConnectorData
|
||||||
old.LastUsed = s.now()
|
old.LastUsed = lastUsed
|
||||||
return old, nil
|
return old, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Update LastUsed time stamp in refresh token reference object
|
||||||
|
// in offline session for the user.
|
||||||
|
if err := s.storage.UpdateOfflineSessions(refresh.Claims.UserID, refresh.ConnectorID, func(old storage.OfflineSessions) (storage.OfflineSessions, error) {
|
||||||
|
if old.Refresh[refresh.ClientID].ID != refresh.ID {
|
||||||
|
return old, errors.New("refresh token invalid")
|
||||||
|
}
|
||||||
|
old.Refresh[refresh.ClientID].LastUsed = lastUsed
|
||||||
|
return old, nil
|
||||||
|
}); err != nil {
|
||||||
|
s.logger.Errorf("failed to update offline session: %v", err)
|
||||||
|
s.tokenErrHelper(w, errServerError, "", http.StatusInternalServerError)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// Update refresh token in the storage.
|
||||||
if err := s.storage.UpdateRefreshToken(refresh.ID, updater); err != nil {
|
if err := s.storage.UpdateRefreshToken(refresh.ID, updater); err != nil {
|
||||||
s.logger.Errorf("failed to update refresh token: %v", err)
|
s.logger.Errorf("failed to update refresh token: %v", err)
|
||||||
s.tokenErrHelper(w, errServerError, "", http.StatusInternalServerError)
|
s.tokenErrHelper(w, errServerError, "", http.StatusInternalServerError)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
s.writeAccessToken(w, idToken, accessToken, rawNewToken, expiry)
|
s.writeAccessToken(w, idToken, accessToken, rawNewToken, expiry)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -971,3 +971,108 @@ func TestKeyCacher(t *testing.T) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
type oauth2Client struct {
|
||||||
|
config *oauth2.Config
|
||||||
|
token *oauth2.Token
|
||||||
|
server *httptest.Server
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestRefreshTokenFlow tests the refresh token code flow for oauth2. The test verifies
|
||||||
|
// that only valid refresh tokens can be used to refresh an expired token.
|
||||||
|
func TestRefreshTokenFlow(t *testing.T) {
|
||||||
|
state := "state"
|
||||||
|
now := func() time.Time { return time.Now() }
|
||||||
|
ctx, cancel := context.WithCancel(context.Background())
|
||||||
|
defer cancel()
|
||||||
|
|
||||||
|
httpServer, s := newTestServer(ctx, t, func(c *Config) {
|
||||||
|
c.Now = now
|
||||||
|
})
|
||||||
|
defer httpServer.Close()
|
||||||
|
|
||||||
|
p, err := oidc.NewProvider(ctx, httpServer.URL)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("failed to get provider: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
var oauth2Client oauth2Client
|
||||||
|
|
||||||
|
oauth2Client.server = httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
if r.URL.Path != "/callback" {
|
||||||
|
// User is visiting app first time. Redirect to dex.
|
||||||
|
http.Redirect(w, r, oauth2Client.config.AuthCodeURL(state), http.StatusSeeOther)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// User is at '/callback' so they were just redirected _from_ dex.
|
||||||
|
q := r.URL.Query()
|
||||||
|
|
||||||
|
if errType := q.Get("error"); errType != "" {
|
||||||
|
if desc := q.Get("error_description"); desc != "" {
|
||||||
|
t.Errorf("got error from server %s: %s", errType, desc)
|
||||||
|
} else {
|
||||||
|
t.Errorf("got error from server %s", errType)
|
||||||
|
}
|
||||||
|
w.WriteHeader(http.StatusInternalServerError)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// Grab code, exchange for token.
|
||||||
|
if code := q.Get("code"); code != "" {
|
||||||
|
token, err := oauth2Client.config.Exchange(ctx, code)
|
||||||
|
if err != nil {
|
||||||
|
t.Errorf("failed to exchange code for token: %v", err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
oauth2Client.token = token
|
||||||
|
}
|
||||||
|
|
||||||
|
// Ensure state matches.
|
||||||
|
if gotState := q.Get("state"); gotState != state {
|
||||||
|
t.Errorf("state did not match, want=%q got=%q", state, gotState)
|
||||||
|
}
|
||||||
|
w.WriteHeader(http.StatusOK)
|
||||||
|
return
|
||||||
|
}))
|
||||||
|
defer oauth2Client.server.Close()
|
||||||
|
|
||||||
|
// Register the client above with dex.
|
||||||
|
redirectURL := oauth2Client.server.URL + "/callback"
|
||||||
|
client := storage.Client{
|
||||||
|
ID: "testclient",
|
||||||
|
Secret: "testclientsecret",
|
||||||
|
RedirectURIs: []string{redirectURL},
|
||||||
|
}
|
||||||
|
if err := s.storage.CreateClient(client); err != nil {
|
||||||
|
t.Fatalf("failed to create client: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
oauth2Client.config = &oauth2.Config{
|
||||||
|
ClientID: client.ID,
|
||||||
|
ClientSecret: client.Secret,
|
||||||
|
Endpoint: p.Endpoint(),
|
||||||
|
Scopes: []string{oidc.ScopeOpenID, "email", "offline_access"},
|
||||||
|
RedirectURL: redirectURL,
|
||||||
|
}
|
||||||
|
|
||||||
|
if _, err = http.Get(oauth2Client.server.URL + "/login"); err != nil {
|
||||||
|
t.Fatalf("get failed: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
tok := &oauth2.Token{
|
||||||
|
RefreshToken: oauth2Client.token.RefreshToken,
|
||||||
|
Expiry: time.Now().Add(-time.Hour),
|
||||||
|
}
|
||||||
|
|
||||||
|
// Login in again to recieve a new token.
|
||||||
|
if _, err = http.Get(oauth2Client.server.URL + "/login"); err != nil {
|
||||||
|
t.Fatalf("get failed: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// try to refresh expired token with old refresh token.
|
||||||
|
newToken, err := oauth2Client.config.TokenSource(ctx, tok).Token()
|
||||||
|
if newToken != nil {
|
||||||
|
t.Errorf("Token refreshed with invalid refresh token.")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
@ -47,6 +47,7 @@ func RunTests(t *testing.T, newStorage func() storage.Storage) {
|
|||||||
{"RefreshTokenCRUD", testRefreshTokenCRUD},
|
{"RefreshTokenCRUD", testRefreshTokenCRUD},
|
||||||
{"PasswordCRUD", testPasswordCRUD},
|
{"PasswordCRUD", testPasswordCRUD},
|
||||||
{"KeysCRUD", testKeysCRUD},
|
{"KeysCRUD", testKeysCRUD},
|
||||||
|
{"OfflineSessionCRUD", testOfflineSessionCRUD},
|
||||||
{"GarbageCollection", testGC},
|
{"GarbageCollection", testGC},
|
||||||
{"TimezoneSupport", testTimezones},
|
{"TimezoneSupport", testTimezones},
|
||||||
})
|
})
|
||||||
@ -340,6 +341,60 @@ func testPasswordCRUD(t *testing.T, s storage.Storage) {
|
|||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func testOfflineSessionCRUD(t *testing.T, s storage.Storage) {
|
||||||
|
session := storage.OfflineSessions{
|
||||||
|
UserID: "User",
|
||||||
|
ConnID: "Conn",
|
||||||
|
Refresh: make(map[string]*storage.RefreshTokenRef),
|
||||||
|
}
|
||||||
|
|
||||||
|
// Creating an OfflineSession with an empty Refresh list to ensure that
|
||||||
|
// an empty map is translated as expected by the storage.
|
||||||
|
if err := s.CreateOfflineSessions(session); err != nil {
|
||||||
|
t.Fatalf("create offline session: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
getAndCompare := func(userID string, connID string, want storage.OfflineSessions) {
|
||||||
|
gr, err := s.GetOfflineSessions(userID, connID)
|
||||||
|
if err != nil {
|
||||||
|
t.Errorf("get offline session: %v", err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if diff := pretty.Compare(want, gr); diff != "" {
|
||||||
|
t.Errorf("offline session retrieved from storage did not match: %s", diff)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
getAndCompare("User", "Conn", session)
|
||||||
|
|
||||||
|
id := storage.NewID()
|
||||||
|
tokenRef := storage.RefreshTokenRef{
|
||||||
|
ID: id,
|
||||||
|
ClientID: "client_id",
|
||||||
|
CreatedAt: time.Now().UTC().Round(time.Millisecond),
|
||||||
|
LastUsed: time.Now().UTC().Round(time.Millisecond),
|
||||||
|
}
|
||||||
|
session.Refresh[tokenRef.ClientID] = &tokenRef
|
||||||
|
|
||||||
|
if err := s.UpdateOfflineSessions(session.UserID, session.ConnID, func(old storage.OfflineSessions) (storage.OfflineSessions, error) {
|
||||||
|
old.Refresh[tokenRef.ClientID] = &tokenRef
|
||||||
|
return old, nil
|
||||||
|
}); err != nil {
|
||||||
|
t.Fatalf("failed to update offline session: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
getAndCompare("User", "Conn", session)
|
||||||
|
|
||||||
|
if err := s.DeleteOfflineSessions(session.UserID, session.ConnID); err != nil {
|
||||||
|
t.Fatalf("failed to delete offline session: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if _, err := s.GetOfflineSessions(session.UserID, session.ConnID); err != storage.ErrNotFound {
|
||||||
|
t.Errorf("after deleting offline session expected storage.ErrNotFound, got %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
func testKeysCRUD(t *testing.T, s storage.Storage) {
|
func testKeysCRUD(t *testing.T, s storage.Storage) {
|
||||||
updateAndCompare := func(k storage.Keys) {
|
updateAndCompare := func(k storage.Keys) {
|
||||||
err := s.UpdateKeys(func(oldKeys storage.Keys) (storage.Keys, error) {
|
err := s.UpdateKeys(func(oldKeys storage.Keys) (storage.Keys, error) {
|
||||||
|
@ -58,6 +58,12 @@ func (c *client) idToName(s string) string {
|
|||||||
return idToName(s, c.hash)
|
return idToName(s, c.hash)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// offlineTokenName maps two arbitrary IDs, to a single Kubernetes object name.
|
||||||
|
// This is used when more than one field is used to uniquely identify the object.
|
||||||
|
func (c *client) offlineTokenName(userID string, connID string) string {
|
||||||
|
return offlineTokenName(userID, connID, c.hash)
|
||||||
|
}
|
||||||
|
|
||||||
// Kubernetes names must match the regexp '[a-z0-9]([-a-z0-9]*[a-z0-9])?'.
|
// Kubernetes names must match the regexp '[a-z0-9]([-a-z0-9]*[a-z0-9])?'.
|
||||||
var encoding = base32.NewEncoding("abcdefghijklmnopqrstuvwxyz234567")
|
var encoding = base32.NewEncoding("abcdefghijklmnopqrstuvwxyz234567")
|
||||||
|
|
||||||
@ -65,6 +71,12 @@ func idToName(s string, h func() hash.Hash) string {
|
|||||||
return strings.TrimRight(encoding.EncodeToString(h().Sum([]byte(s))), "=")
|
return strings.TrimRight(encoding.EncodeToString(h().Sum([]byte(s))), "=")
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func offlineTokenName(userID string, connID string, h func() hash.Hash) string {
|
||||||
|
h().Write([]byte(userID))
|
||||||
|
h().Write([]byte(connID))
|
||||||
|
return strings.TrimRight(encoding.EncodeToString(h().Sum(nil)), "=")
|
||||||
|
}
|
||||||
|
|
||||||
func (c *client) urlFor(apiVersion, namespace, resource, name string) string {
|
func (c *client) urlFor(apiVersion, namespace, resource, name string) string {
|
||||||
basePath := "apis/"
|
basePath := "apis/"
|
||||||
if apiVersion == "v1" {
|
if apiVersion == "v1" {
|
||||||
|
@ -21,6 +21,7 @@ const (
|
|||||||
kindRefreshToken = "RefreshToken"
|
kindRefreshToken = "RefreshToken"
|
||||||
kindKeys = "SigningKey"
|
kindKeys = "SigningKey"
|
||||||
kindPassword = "Password"
|
kindPassword = "Password"
|
||||||
|
kindOfflineSessions = "OfflineSessions"
|
||||||
)
|
)
|
||||||
|
|
||||||
const (
|
const (
|
||||||
@ -30,6 +31,7 @@ const (
|
|||||||
resourceRefreshToken = "refreshtokens"
|
resourceRefreshToken = "refreshtokens"
|
||||||
resourceKeys = "signingkeies" // Kubernetes attempts to pluralize.
|
resourceKeys = "signingkeies" // Kubernetes attempts to pluralize.
|
||||||
resourcePassword = "passwords"
|
resourcePassword = "passwords"
|
||||||
|
resourceOfflineSessions = "offlinesessions"
|
||||||
)
|
)
|
||||||
|
|
||||||
// Config values for the Kubernetes storage type.
|
// Config values for the Kubernetes storage type.
|
||||||
@ -156,6 +158,10 @@ func (cli *client) CreateRefresh(r storage.RefreshToken) error {
|
|||||||
return cli.post(resourceRefreshToken, cli.fromStorageRefreshToken(r))
|
return cli.post(resourceRefreshToken, cli.fromStorageRefreshToken(r))
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (cli *client) CreateOfflineSessions(o storage.OfflineSessions) error {
|
||||||
|
return cli.post(resourceOfflineSessions, cli.fromStorageOfflineSessions(o))
|
||||||
|
}
|
||||||
|
|
||||||
func (cli *client) GetAuthRequest(id string) (storage.AuthRequest, error) {
|
func (cli *client) GetAuthRequest(id string) (storage.AuthRequest, error) {
|
||||||
var req AuthRequest
|
var req AuthRequest
|
||||||
if err := cli.get(resourceAuthRequest, id, &req); err != nil {
|
if err := cli.get(resourceAuthRequest, id, &req); err != nil {
|
||||||
@ -235,6 +241,25 @@ func (cli *client) getRefreshToken(id string) (r RefreshToken, err error) {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (cli *client) GetOfflineSessions(userID string, connID string) (storage.OfflineSessions, error) {
|
||||||
|
o, err := cli.getOfflineSessions(userID, connID)
|
||||||
|
if err != nil {
|
||||||
|
return storage.OfflineSessions{}, err
|
||||||
|
}
|
||||||
|
return toStorageOfflineSessions(o), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (cli *client) getOfflineSessions(userID string, connID string) (o OfflineSessions, err error) {
|
||||||
|
name := cli.offlineTokenName(userID, connID)
|
||||||
|
if err = cli.get(resourceOfflineSessions, name, &o); err != nil {
|
||||||
|
return OfflineSessions{}, err
|
||||||
|
}
|
||||||
|
if userID != o.UserID || connID != o.ConnID {
|
||||||
|
return OfflineSessions{}, fmt.Errorf("get offline session: wrong session retrieved")
|
||||||
|
}
|
||||||
|
return o, nil
|
||||||
|
}
|
||||||
|
|
||||||
func (cli *client) ListClients() ([]storage.Client, error) {
|
func (cli *client) ListClients() ([]storage.Client, error) {
|
||||||
return nil, errors.New("not implemented")
|
return nil, errors.New("not implemented")
|
||||||
}
|
}
|
||||||
@ -292,6 +317,15 @@ func (cli *client) DeletePassword(email string) error {
|
|||||||
return cli.delete(resourcePassword, p.ObjectMeta.Name)
|
return cli.delete(resourcePassword, p.ObjectMeta.Name)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (cli *client) DeleteOfflineSessions(userID string, connID string) error {
|
||||||
|
// Check for hash collition.
|
||||||
|
o, err := cli.getOfflineSessions(userID, connID)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
return cli.delete(resourceOfflineSessions, o.ObjectMeta.Name)
|
||||||
|
}
|
||||||
|
|
||||||
func (cli *client) UpdateRefreshToken(id string, updater func(old storage.RefreshToken) (storage.RefreshToken, error)) error {
|
func (cli *client) UpdateRefreshToken(id string, updater func(old storage.RefreshToken) (storage.RefreshToken, error)) error {
|
||||||
r, err := cli.getRefreshToken(id)
|
r, err := cli.getRefreshToken(id)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@ -342,6 +376,22 @@ func (cli *client) UpdatePassword(email string, updater func(old storage.Passwor
|
|||||||
return cli.put(resourcePassword, p.ObjectMeta.Name, newPassword)
|
return cli.put(resourcePassword, p.ObjectMeta.Name, newPassword)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (cli *client) UpdateOfflineSessions(userID string, connID string, updater func(old storage.OfflineSessions) (storage.OfflineSessions, error)) error {
|
||||||
|
o, err := cli.getOfflineSessions(userID, connID)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
updated, err := updater(toStorageOfflineSessions(o))
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
newOfflineSessions := cli.fromStorageOfflineSessions(updated)
|
||||||
|
newOfflineSessions.ObjectMeta = o.ObjectMeta
|
||||||
|
return cli.put(resourceOfflineSessions, o.ObjectMeta.Name, newOfflineSessions)
|
||||||
|
}
|
||||||
|
|
||||||
func (cli *client) UpdateKeys(updater func(old storage.Keys) (storage.Keys, error)) error {
|
func (cli *client) UpdateKeys(updater func(old storage.Keys) (storage.Keys, error)) error {
|
||||||
firstUpdate := false
|
firstUpdate := false
|
||||||
var keys Keys
|
var keys Keys
|
||||||
|
@ -66,6 +66,14 @@ var thirdPartyResources = []k8sapi.ThirdPartyResource{
|
|||||||
Description: "Passwords managed by the OIDC server.",
|
Description: "Passwords managed by the OIDC server.",
|
||||||
Versions: []k8sapi.APIVersion{{Name: "v1"}},
|
Versions: []k8sapi.APIVersion{{Name: "v1"}},
|
||||||
},
|
},
|
||||||
|
{
|
||||||
|
ObjectMeta: k8sapi.ObjectMeta{
|
||||||
|
Name: "offline-sessions.oidc.coreos.com",
|
||||||
|
},
|
||||||
|
TypeMeta: tprMeta,
|
||||||
|
Description: "User sessions with an active refresh token.",
|
||||||
|
Versions: []k8sapi.APIVersion{{Name: "v1"}},
|
||||||
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
// There will only ever be a single keys resource. Maintain this by setting a
|
// There will only ever be a single keys resource. Maintain this by setting a
|
||||||
@ -465,3 +473,38 @@ func toStorageKeys(keys Keys) storage.Keys {
|
|||||||
NextRotation: keys.NextRotation,
|
NextRotation: keys.NextRotation,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// OfflineSessions is a mirrored struct from storage with JSON struct tags and Kubernetes
|
||||||
|
// type metadata.
|
||||||
|
type OfflineSessions struct {
|
||||||
|
k8sapi.TypeMeta `json:",inline"`
|
||||||
|
k8sapi.ObjectMeta `json:"metadata,omitempty"`
|
||||||
|
|
||||||
|
UserID string `json:"userID,omitempty"`
|
||||||
|
ConnID string `json:"connID,omitempty"`
|
||||||
|
Refresh map[string]*storage.RefreshTokenRef `json:"refresh,omitempty"`
|
||||||
|
}
|
||||||
|
|
||||||
|
func (cli *client) fromStorageOfflineSessions(o storage.OfflineSessions) OfflineSessions {
|
||||||
|
return OfflineSessions{
|
||||||
|
TypeMeta: k8sapi.TypeMeta{
|
||||||
|
Kind: kindOfflineSessions,
|
||||||
|
APIVersion: cli.apiVersion,
|
||||||
|
},
|
||||||
|
ObjectMeta: k8sapi.ObjectMeta{
|
||||||
|
Name: cli.offlineTokenName(o.UserID, o.ConnID),
|
||||||
|
Namespace: cli.namespace,
|
||||||
|
},
|
||||||
|
UserID: o.UserID,
|
||||||
|
ConnID: o.ConnID,
|
||||||
|
Refresh: o.Refresh,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func toStorageOfflineSessions(o OfflineSessions) storage.OfflineSessions {
|
||||||
|
return storage.OfflineSessions{
|
||||||
|
UserID: o.UserID,
|
||||||
|
ConnID: o.ConnID,
|
||||||
|
Refresh: o.Refresh,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
@ -18,6 +18,7 @@ func New(logger logrus.FieldLogger) storage.Storage {
|
|||||||
refreshTokens: make(map[string]storage.RefreshToken),
|
refreshTokens: make(map[string]storage.RefreshToken),
|
||||||
authReqs: make(map[string]storage.AuthRequest),
|
authReqs: make(map[string]storage.AuthRequest),
|
||||||
passwords: make(map[string]storage.Password),
|
passwords: make(map[string]storage.Password),
|
||||||
|
offlineSessions: make(map[offlineSessionID]storage.OfflineSessions),
|
||||||
logger: logger,
|
logger: logger,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -42,12 +43,18 @@ type memStorage struct {
|
|||||||
refreshTokens map[string]storage.RefreshToken
|
refreshTokens map[string]storage.RefreshToken
|
||||||
authReqs map[string]storage.AuthRequest
|
authReqs map[string]storage.AuthRequest
|
||||||
passwords map[string]storage.Password
|
passwords map[string]storage.Password
|
||||||
|
offlineSessions map[offlineSessionID]storage.OfflineSessions
|
||||||
|
|
||||||
keys storage.Keys
|
keys storage.Keys
|
||||||
|
|
||||||
logger logrus.FieldLogger
|
logger logrus.FieldLogger
|
||||||
}
|
}
|
||||||
|
|
||||||
|
type offlineSessionID struct {
|
||||||
|
userID string
|
||||||
|
connID string
|
||||||
|
}
|
||||||
|
|
||||||
func (s *memStorage) tx(f func()) {
|
func (s *memStorage) tx(f func()) {
|
||||||
s.mu.Lock()
|
s.mu.Lock()
|
||||||
defer s.mu.Unlock()
|
defer s.mu.Unlock()
|
||||||
@ -130,6 +137,32 @@ func (s *memStorage) CreatePassword(p storage.Password) (err error) {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (s *memStorage) CreateOfflineSessions(o storage.OfflineSessions) (err error) {
|
||||||
|
id := offlineSessionID{
|
||||||
|
userID: o.UserID,
|
||||||
|
connID: o.ConnID,
|
||||||
|
}
|
||||||
|
s.tx(func() {
|
||||||
|
if _, ok := s.offlineSessions[id]; ok {
|
||||||
|
err = storage.ErrAlreadyExists
|
||||||
|
} else {
|
||||||
|
s.offlineSessions[id] = o
|
||||||
|
}
|
||||||
|
})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *memStorage) GetAuthCode(id string) (c storage.AuthCode, err error) {
|
||||||
|
s.tx(func() {
|
||||||
|
var ok bool
|
||||||
|
if c, ok = s.authCodes[id]; !ok {
|
||||||
|
err = storage.ErrNotFound
|
||||||
|
return
|
||||||
|
}
|
||||||
|
})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
func (s *memStorage) GetPassword(email string) (p storage.Password, err error) {
|
func (s *memStorage) GetPassword(email string) (p storage.Password, err error) {
|
||||||
email = strings.ToLower(email)
|
email = strings.ToLower(email)
|
||||||
s.tx(func() {
|
s.tx(func() {
|
||||||
@ -156,10 +189,10 @@ func (s *memStorage) GetKeys() (keys storage.Keys, err error) {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *memStorage) GetRefresh(token string) (tok storage.RefreshToken, err error) {
|
func (s *memStorage) GetRefresh(id string) (tok storage.RefreshToken, err error) {
|
||||||
s.tx(func() {
|
s.tx(func() {
|
||||||
var ok bool
|
var ok bool
|
||||||
if tok, ok = s.refreshTokens[token]; !ok {
|
if tok, ok = s.refreshTokens[id]; !ok {
|
||||||
err = storage.ErrNotFound
|
err = storage.ErrNotFound
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
@ -178,6 +211,21 @@ func (s *memStorage) GetAuthRequest(id string) (req storage.AuthRequest, err err
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (s *memStorage) GetOfflineSessions(userID string, connID string) (o storage.OfflineSessions, err error) {
|
||||||
|
id := offlineSessionID{
|
||||||
|
userID: userID,
|
||||||
|
connID: connID,
|
||||||
|
}
|
||||||
|
s.tx(func() {
|
||||||
|
var ok bool
|
||||||
|
if o, ok = s.offlineSessions[id]; !ok {
|
||||||
|
err = storage.ErrNotFound
|
||||||
|
return
|
||||||
|
}
|
||||||
|
})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
func (s *memStorage) ListClients() (clients []storage.Client, err error) {
|
func (s *memStorage) ListClients() (clients []storage.Client, err error) {
|
||||||
s.tx(func() {
|
s.tx(func() {
|
||||||
for _, client := range s.clients {
|
for _, client := range s.clients {
|
||||||
@ -228,13 +276,13 @@ func (s *memStorage) DeleteClient(id string) (err error) {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *memStorage) DeleteRefresh(token string) (err error) {
|
func (s *memStorage) DeleteRefresh(id string) (err error) {
|
||||||
s.tx(func() {
|
s.tx(func() {
|
||||||
if _, ok := s.refreshTokens[token]; !ok {
|
if _, ok := s.refreshTokens[id]; !ok {
|
||||||
err = storage.ErrNotFound
|
err = storage.ErrNotFound
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
delete(s.refreshTokens, token)
|
delete(s.refreshTokens, id)
|
||||||
})
|
})
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
@ -261,13 +309,17 @@ func (s *memStorage) DeleteAuthRequest(id string) (err error) {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *memStorage) GetAuthCode(id string) (c storage.AuthCode, err error) {
|
func (s *memStorage) DeleteOfflineSessions(userID string, connID string) (err error) {
|
||||||
|
id := offlineSessionID{
|
||||||
|
userID: userID,
|
||||||
|
connID: connID,
|
||||||
|
}
|
||||||
s.tx(func() {
|
s.tx(func() {
|
||||||
var ok bool
|
if _, ok := s.offlineSessions[id]; !ok {
|
||||||
if c, ok = s.authCodes[id]; !ok {
|
|
||||||
err = storage.ErrNotFound
|
err = storage.ErrNotFound
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
delete(s.offlineSessions, id)
|
||||||
})
|
})
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
@ -338,3 +390,21 @@ func (s *memStorage) UpdateRefreshToken(id string, updater func(p storage.Refres
|
|||||||
})
|
})
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (s *memStorage) UpdateOfflineSessions(userID string, connID string, updater func(o storage.OfflineSessions) (storage.OfflineSessions, error)) (err error) {
|
||||||
|
id := offlineSessionID{
|
||||||
|
userID: userID,
|
||||||
|
connID: connID,
|
||||||
|
}
|
||||||
|
s.tx(func() {
|
||||||
|
r, ok := s.offlineSessions[id]
|
||||||
|
if !ok {
|
||||||
|
err = storage.ErrNotFound
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if r, err = updater(r); err == nil {
|
||||||
|
s.offlineSessions[id] = r
|
||||||
|
}
|
||||||
|
})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
@ -624,6 +624,75 @@ func scanPassword(s scanner) (p storage.Password, err error) {
|
|||||||
return p, nil
|
return p, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (c *conn) CreateOfflineSessions(s storage.OfflineSessions) error {
|
||||||
|
_, err := c.Exec(`
|
||||||
|
insert into offline_session (
|
||||||
|
user_id, conn_id, refresh
|
||||||
|
)
|
||||||
|
values (
|
||||||
|
$1, $2, $3
|
||||||
|
);
|
||||||
|
`,
|
||||||
|
s.UserID, s.ConnID, encoder(s.Refresh),
|
||||||
|
)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("insert offline session: %v", err)
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *conn) UpdateOfflineSessions(userID string, connID string, updater func(s storage.OfflineSessions) (storage.OfflineSessions, error)) error {
|
||||||
|
return c.ExecTx(func(tx *trans) error {
|
||||||
|
s, err := getOfflineSessions(tx, userID, connID)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
newSession, err := updater(s)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
_, err = tx.Exec(`
|
||||||
|
update offline_session
|
||||||
|
set
|
||||||
|
refresh = $1
|
||||||
|
where user_id = $2 AND conn_id = $3;
|
||||||
|
`,
|
||||||
|
encoder(newSession.Refresh), s.UserID, s.ConnID,
|
||||||
|
)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("update offline session: %v", err)
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *conn) GetOfflineSessions(userID string, connID string) (storage.OfflineSessions, error) {
|
||||||
|
return getOfflineSessions(c, userID, connID)
|
||||||
|
}
|
||||||
|
|
||||||
|
func getOfflineSessions(q querier, userID string, connID string) (storage.OfflineSessions, error) {
|
||||||
|
return scanOfflineSessions(q.QueryRow(`
|
||||||
|
select
|
||||||
|
user_id, conn_id, refresh
|
||||||
|
from offline_session
|
||||||
|
where user_id = $1 AND conn_id = $2;
|
||||||
|
`, userID, connID))
|
||||||
|
}
|
||||||
|
|
||||||
|
func scanOfflineSessions(s scanner) (o storage.OfflineSessions, err error) {
|
||||||
|
err = s.Scan(
|
||||||
|
&o.UserID, &o.ConnID, decoder(&o.Refresh),
|
||||||
|
)
|
||||||
|
if err != nil {
|
||||||
|
if err == sql.ErrNoRows {
|
||||||
|
return o, storage.ErrNotFound
|
||||||
|
}
|
||||||
|
return o, fmt.Errorf("select offline session: %v", err)
|
||||||
|
}
|
||||||
|
return o, nil
|
||||||
|
}
|
||||||
|
|
||||||
func (c *conn) DeleteAuthRequest(id string) error { return c.delete("auth_request", "id", id) }
|
func (c *conn) DeleteAuthRequest(id string) error { return c.delete("auth_request", "id", id) }
|
||||||
func (c *conn) DeleteAuthCode(id string) error { return c.delete("auth_code", "id", id) }
|
func (c *conn) DeleteAuthCode(id string) error { return c.delete("auth_code", "id", id) }
|
||||||
func (c *conn) DeleteClient(id string) error { return c.delete("client", "id", id) }
|
func (c *conn) DeleteClient(id string) error { return c.delete("client", "id", id) }
|
||||||
@ -632,6 +701,24 @@ func (c *conn) DeletePassword(email string) error {
|
|||||||
return c.delete("password", "email", strings.ToLower(email))
|
return c.delete("password", "email", strings.ToLower(email))
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (c *conn) DeleteOfflineSessions(userID string, connID string) error {
|
||||||
|
result, err := c.Exec(`delete from offline_session where user_id = $1 AND conn_id = $2`, userID, connID)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("delete offline_session: user_id = %s, conn_id = %s", userID, connID)
|
||||||
|
}
|
||||||
|
|
||||||
|
// For now mandate that the driver implements RowsAffected. If we ever need to support
|
||||||
|
// a driver that doesn't implement this, we can run this in a transaction with a get beforehand.
|
||||||
|
n, err := result.RowsAffected()
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("rows affected: %v", err)
|
||||||
|
}
|
||||||
|
if n < 1 {
|
||||||
|
return storage.ErrNotFound
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
// Do NOT call directly. Does not escape table.
|
// Do NOT call directly. Does not escape table.
|
||||||
func (c *conn) delete(table, field, id string) error {
|
func (c *conn) delete(table, field, id string) error {
|
||||||
result, err := c.Exec(`delete from `+table+` where `+field+` = $1`, id)
|
result, err := c.Exec(`delete from `+table+` where `+field+` = $1`, id)
|
||||||
|
@ -153,6 +153,7 @@ var migrations = []migration{
|
|||||||
signing_key_pub bytea not null, -- JSON object
|
signing_key_pub bytea not null, -- JSON object
|
||||||
next_rotation timestamptz not null
|
next_rotation timestamptz not null
|
||||||
);
|
);
|
||||||
|
|
||||||
`,
|
`,
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
@ -165,4 +166,14 @@ var migrations = []migration{
|
|||||||
add column last_used timestamptz not null default '0001-01-01 00:00:00 UTC';
|
add column last_used timestamptz not null default '0001-01-01 00:00:00 UTC';
|
||||||
`,
|
`,
|
||||||
},
|
},
|
||||||
|
{
|
||||||
|
stmt: `
|
||||||
|
create table offline_session (
|
||||||
|
user_id text not null,
|
||||||
|
conn_id text not null,
|
||||||
|
refresh bytea not null,
|
||||||
|
PRIMARY KEY (user_id, conn_id)
|
||||||
|
);
|
||||||
|
`,
|
||||||
|
},
|
||||||
}
|
}
|
||||||
|
@ -52,6 +52,7 @@ type Storage interface {
|
|||||||
CreateAuthCode(c AuthCode) error
|
CreateAuthCode(c AuthCode) error
|
||||||
CreateRefresh(r RefreshToken) error
|
CreateRefresh(r RefreshToken) error
|
||||||
CreatePassword(p Password) error
|
CreatePassword(p Password) error
|
||||||
|
CreateOfflineSessions(s OfflineSessions) error
|
||||||
|
|
||||||
// TODO(ericchiang): return (T, bool, error) so we can indicate not found
|
// TODO(ericchiang): return (T, bool, error) so we can indicate not found
|
||||||
// requests that way instead of using ErrNotFound.
|
// requests that way instead of using ErrNotFound.
|
||||||
@ -61,6 +62,7 @@ type Storage interface {
|
|||||||
GetKeys() (Keys, error)
|
GetKeys() (Keys, error)
|
||||||
GetRefresh(id string) (RefreshToken, error)
|
GetRefresh(id string) (RefreshToken, error)
|
||||||
GetPassword(email string) (Password, error)
|
GetPassword(email string) (Password, error)
|
||||||
|
GetOfflineSessions(userID string, connID string) (OfflineSessions, error)
|
||||||
|
|
||||||
ListClients() ([]Client, error)
|
ListClients() ([]Client, error)
|
||||||
ListRefreshTokens() ([]RefreshToken, error)
|
ListRefreshTokens() ([]RefreshToken, error)
|
||||||
@ -72,6 +74,7 @@ type Storage interface {
|
|||||||
DeleteClient(id string) error
|
DeleteClient(id string) error
|
||||||
DeleteRefresh(id string) error
|
DeleteRefresh(id string) error
|
||||||
DeletePassword(email string) error
|
DeletePassword(email string) error
|
||||||
|
DeleteOfflineSessions(userID string, connID string) error
|
||||||
|
|
||||||
// Update methods take a function for updating an object then performs that update within
|
// Update methods take a function for updating an object then performs that update within
|
||||||
// a transaction. "updater" functions may be called multiple times by a single update call.
|
// a transaction. "updater" functions may be called multiple times by a single update call.
|
||||||
@ -92,6 +95,7 @@ type Storage interface {
|
|||||||
UpdateAuthRequest(id string, updater func(a AuthRequest) (AuthRequest, error)) error
|
UpdateAuthRequest(id string, updater func(a AuthRequest) (AuthRequest, error)) error
|
||||||
UpdateRefreshToken(id string, updater func(r RefreshToken) (RefreshToken, error)) error
|
UpdateRefreshToken(id string, updater func(r RefreshToken) (RefreshToken, error)) error
|
||||||
UpdatePassword(email string, updater func(p Password) (Password, error)) error
|
UpdatePassword(email string, updater func(p Password) (Password, error)) error
|
||||||
|
UpdateOfflineSessions(userID string, connID string, updater func(s OfflineSessions) (OfflineSessions, error)) error
|
||||||
|
|
||||||
// GarbageCollect deletes all expired AuthCodes and AuthRequests.
|
// GarbageCollect deletes all expired AuthCodes and AuthRequests.
|
||||||
GarbageCollect(now time.Time) (GCResult, error)
|
GarbageCollect(now time.Time) (GCResult, error)
|
||||||
@ -241,6 +245,30 @@ type RefreshToken struct {
|
|||||||
Nonce string
|
Nonce string
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// RefreshTokenRef is a reference object that contains metadata about refresh tokens.
|
||||||
|
type RefreshTokenRef struct {
|
||||||
|
ID string
|
||||||
|
|
||||||
|
// Client the refresh token is valid for.
|
||||||
|
ClientID string
|
||||||
|
|
||||||
|
CreatedAt time.Time
|
||||||
|
LastUsed time.Time
|
||||||
|
}
|
||||||
|
|
||||||
|
// OfflineSessions objects are sessions pertaining to users with refresh tokens.
|
||||||
|
type OfflineSessions struct {
|
||||||
|
// UserID of an end user who has logged in to the server.
|
||||||
|
UserID string
|
||||||
|
|
||||||
|
// The ID of the connector used to login the user.
|
||||||
|
ConnID string
|
||||||
|
|
||||||
|
// Refresh is a hash table of refresh token reference objects
|
||||||
|
// indexed by the ClientID of the refresh token.
|
||||||
|
Refresh map[string]*RefreshTokenRef
|
||||||
|
}
|
||||||
|
|
||||||
// Password is an email to password mapping managed by the storage.
|
// Password is an email to password mapping managed by the storage.
|
||||||
type Password struct {
|
type Password struct {
|
||||||
// Email and identifying name of the password. Emails are assumed to be valid and
|
// Email and identifying name of the password. Emails are assumed to be valid and
|
||||||
|
Reference in New Issue
Block a user