storage: Add OfflineSession object to backend storage.
This commit is contained in:
		@@ -47,6 +47,7 @@ func RunTests(t *testing.T, newStorage func() storage.Storage) {
 | 
			
		||||
		{"RefreshTokenCRUD", testRefreshTokenCRUD},
 | 
			
		||||
		{"PasswordCRUD", testPasswordCRUD},
 | 
			
		||||
		{"KeysCRUD", testKeysCRUD},
 | 
			
		||||
		{"OfflineSessionCRUD", testOfflineSessionCRUD},
 | 
			
		||||
		{"GarbageCollection", testGC},
 | 
			
		||||
		{"TimezoneSupport", testTimezones},
 | 
			
		||||
	})
 | 
			
		||||
@@ -340,6 +341,60 @@ func testPasswordCRUD(t *testing.T, s storage.Storage) {
 | 
			
		||||
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func testOfflineSessionCRUD(t *testing.T, s storage.Storage) {
 | 
			
		||||
	session := storage.OfflineSessions{
 | 
			
		||||
		UserID:  "User",
 | 
			
		||||
		ConnID:  "Conn",
 | 
			
		||||
		Refresh: make(map[string]*storage.RefreshTokenRef),
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	// Creating an OfflineSession with an empty Refresh list to ensure that
 | 
			
		||||
	// an empty map is translated as expected by the storage.
 | 
			
		||||
	if err := s.CreateOfflineSessions(session); err != nil {
 | 
			
		||||
		t.Fatalf("create offline session: %v", err)
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	getAndCompare := func(userID string, connID string, want storage.OfflineSessions) {
 | 
			
		||||
		gr, err := s.GetOfflineSessions(userID, connID)
 | 
			
		||||
		if err != nil {
 | 
			
		||||
			t.Errorf("get offline session: %v", err)
 | 
			
		||||
			return
 | 
			
		||||
		}
 | 
			
		||||
		if diff := pretty.Compare(want, gr); diff != "" {
 | 
			
		||||
			t.Errorf("offline session retrieved from storage did not match: %s", diff)
 | 
			
		||||
		}
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	getAndCompare("User", "Conn", session)
 | 
			
		||||
 | 
			
		||||
	id := storage.NewID()
 | 
			
		||||
	tokenRef := storage.RefreshTokenRef{
 | 
			
		||||
		ID:        id,
 | 
			
		||||
		ClientID:  "client_id",
 | 
			
		||||
		CreatedAt: time.Now().UTC().Round(time.Millisecond),
 | 
			
		||||
		LastUsed:  time.Now().UTC().Round(time.Millisecond),
 | 
			
		||||
	}
 | 
			
		||||
	session.Refresh[tokenRef.ClientID] = &tokenRef
 | 
			
		||||
 | 
			
		||||
	if err := s.UpdateOfflineSessions(session.UserID, session.ConnID, func(old storage.OfflineSessions) (storage.OfflineSessions, error) {
 | 
			
		||||
		old.Refresh[tokenRef.ClientID] = &tokenRef
 | 
			
		||||
		return old, nil
 | 
			
		||||
	}); err != nil {
 | 
			
		||||
		t.Fatalf("failed to update offline session: %v", err)
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	getAndCompare("User", "Conn", session)
 | 
			
		||||
 | 
			
		||||
	if err := s.DeleteOfflineSessions(session.UserID, session.ConnID); err != nil {
 | 
			
		||||
		t.Fatalf("failed to delete offline session: %v", err)
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	if _, err := s.GetOfflineSessions(session.UserID, session.ConnID); err != storage.ErrNotFound {
 | 
			
		||||
		t.Errorf("after deleting offline session expected storage.ErrNotFound, got %v", err)
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func testKeysCRUD(t *testing.T, s storage.Storage) {
 | 
			
		||||
	updateAndCompare := func(k storage.Keys) {
 | 
			
		||||
		err := s.UpdateKeys(func(oldKeys storage.Keys) (storage.Keys, error) {
 | 
			
		||||
 
 | 
			
		||||
@@ -58,6 +58,12 @@ func (c *client) idToName(s string) string {
 | 
			
		||||
	return idToName(s, c.hash)
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// offlineTokenName maps two arbitrary IDs, to a single Kubernetes object name.
 | 
			
		||||
// This is used when more than one field is used to uniquely identify the object.
 | 
			
		||||
func (c *client) offlineTokenName(userID string, connID string) string {
 | 
			
		||||
	return offlineTokenName(userID, connID, c.hash)
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// Kubernetes names must match the regexp '[a-z0-9]([-a-z0-9]*[a-z0-9])?'.
 | 
			
		||||
var encoding = base32.NewEncoding("abcdefghijklmnopqrstuvwxyz234567")
 | 
			
		||||
 | 
			
		||||
@@ -65,6 +71,12 @@ func idToName(s string, h func() hash.Hash) string {
 | 
			
		||||
	return strings.TrimRight(encoding.EncodeToString(h().Sum([]byte(s))), "=")
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func offlineTokenName(userID string, connID string, h func() hash.Hash) string {
 | 
			
		||||
	h().Write([]byte(userID))
 | 
			
		||||
	h().Write([]byte(connID))
 | 
			
		||||
	return strings.TrimRight(encoding.EncodeToString(h().Sum(nil)), "=")
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (c *client) urlFor(apiVersion, namespace, resource, name string) string {
 | 
			
		||||
	basePath := "apis/"
 | 
			
		||||
	if apiVersion == "v1" {
 | 
			
		||||
 
 | 
			
		||||
@@ -15,21 +15,23 @@ import (
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
const (
 | 
			
		||||
	kindAuthCode     = "AuthCode"
 | 
			
		||||
	kindAuthRequest  = "AuthRequest"
 | 
			
		||||
	kindClient       = "OAuth2Client"
 | 
			
		||||
	kindRefreshToken = "RefreshToken"
 | 
			
		||||
	kindKeys         = "SigningKey"
 | 
			
		||||
	kindPassword     = "Password"
 | 
			
		||||
	kindAuthCode        = "AuthCode"
 | 
			
		||||
	kindAuthRequest     = "AuthRequest"
 | 
			
		||||
	kindClient          = "OAuth2Client"
 | 
			
		||||
	kindRefreshToken    = "RefreshToken"
 | 
			
		||||
	kindKeys            = "SigningKey"
 | 
			
		||||
	kindPassword        = "Password"
 | 
			
		||||
	kindOfflineSessions = "OfflineSessions"
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
const (
 | 
			
		||||
	resourceAuthCode     = "authcodes"
 | 
			
		||||
	resourceAuthRequest  = "authrequests"
 | 
			
		||||
	resourceClient       = "oauth2clients"
 | 
			
		||||
	resourceRefreshToken = "refreshtokens"
 | 
			
		||||
	resourceKeys         = "signingkeies" // Kubernetes attempts to pluralize.
 | 
			
		||||
	resourcePassword     = "passwords"
 | 
			
		||||
	resourceAuthCode        = "authcodes"
 | 
			
		||||
	resourceAuthRequest     = "authrequests"
 | 
			
		||||
	resourceClient          = "oauth2clients"
 | 
			
		||||
	resourceRefreshToken    = "refreshtokens"
 | 
			
		||||
	resourceKeys            = "signingkeies" // Kubernetes attempts to pluralize.
 | 
			
		||||
	resourcePassword        = "passwords"
 | 
			
		||||
	resourceOfflineSessions = "offlinesessions"
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
// Config values for the Kubernetes storage type.
 | 
			
		||||
@@ -156,6 +158,10 @@ func (cli *client) CreateRefresh(r storage.RefreshToken) error {
 | 
			
		||||
	return cli.post(resourceRefreshToken, cli.fromStorageRefreshToken(r))
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (cli *client) CreateOfflineSessions(o storage.OfflineSessions) error {
 | 
			
		||||
	return cli.post(resourceOfflineSessions, cli.fromStorageOfflineSessions(o))
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (cli *client) GetAuthRequest(id string) (storage.AuthRequest, error) {
 | 
			
		||||
	var req AuthRequest
 | 
			
		||||
	if err := cli.get(resourceAuthRequest, id, &req); err != nil {
 | 
			
		||||
@@ -235,6 +241,25 @@ func (cli *client) getRefreshToken(id string) (r RefreshToken, err error) {
 | 
			
		||||
	return
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (cli *client) GetOfflineSessions(userID string, connID string) (storage.OfflineSessions, error) {
 | 
			
		||||
	o, err := cli.getOfflineSessions(userID, connID)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return storage.OfflineSessions{}, err
 | 
			
		||||
	}
 | 
			
		||||
	return toStorageOfflineSessions(o), nil
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (cli *client) getOfflineSessions(userID string, connID string) (o OfflineSessions, err error) {
 | 
			
		||||
	name := cli.offlineTokenName(userID, connID)
 | 
			
		||||
	if err = cli.get(resourceOfflineSessions, name, &o); err != nil {
 | 
			
		||||
		return OfflineSessions{}, err
 | 
			
		||||
	}
 | 
			
		||||
	if userID != o.UserID || connID != o.ConnID {
 | 
			
		||||
		return OfflineSessions{}, fmt.Errorf("get offline session: wrong session retrieved")
 | 
			
		||||
	}
 | 
			
		||||
	return o, nil
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (cli *client) ListClients() ([]storage.Client, error) {
 | 
			
		||||
	return nil, errors.New("not implemented")
 | 
			
		||||
}
 | 
			
		||||
@@ -292,6 +317,15 @@ func (cli *client) DeletePassword(email string) error {
 | 
			
		||||
	return cli.delete(resourcePassword, p.ObjectMeta.Name)
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (cli *client) DeleteOfflineSessions(userID string, connID string) error {
 | 
			
		||||
	// Check for hash collition.
 | 
			
		||||
	o, err := cli.getOfflineSessions(userID, connID)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return err
 | 
			
		||||
	}
 | 
			
		||||
	return cli.delete(resourceOfflineSessions, o.ObjectMeta.Name)
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (cli *client) UpdateRefreshToken(id string, updater func(old storage.RefreshToken) (storage.RefreshToken, error)) error {
 | 
			
		||||
	r, err := cli.getRefreshToken(id)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
@@ -342,6 +376,22 @@ func (cli *client) UpdatePassword(email string, updater func(old storage.Passwor
 | 
			
		||||
	return cli.put(resourcePassword, p.ObjectMeta.Name, newPassword)
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (cli *client) UpdateOfflineSessions(userID string, connID string, updater func(old storage.OfflineSessions) (storage.OfflineSessions, error)) error {
 | 
			
		||||
	o, err := cli.getOfflineSessions(userID, connID)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return err
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	updated, err := updater(toStorageOfflineSessions(o))
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return err
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	newOfflineSessions := cli.fromStorageOfflineSessions(updated)
 | 
			
		||||
	newOfflineSessions.ObjectMeta = o.ObjectMeta
 | 
			
		||||
	return cli.put(resourceOfflineSessions, o.ObjectMeta.Name, newOfflineSessions)
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (cli *client) UpdateKeys(updater func(old storage.Keys) (storage.Keys, error)) error {
 | 
			
		||||
	firstUpdate := false
 | 
			
		||||
	var keys Keys
 | 
			
		||||
 
 | 
			
		||||
@@ -66,6 +66,14 @@ var thirdPartyResources = []k8sapi.ThirdPartyResource{
 | 
			
		||||
		Description: "Passwords managed by the OIDC server.",
 | 
			
		||||
		Versions:    []k8sapi.APIVersion{{Name: "v1"}},
 | 
			
		||||
	},
 | 
			
		||||
	{
 | 
			
		||||
		ObjectMeta: k8sapi.ObjectMeta{
 | 
			
		||||
			Name: "offline-sessions.oidc.coreos.com",
 | 
			
		||||
		},
 | 
			
		||||
		TypeMeta:    tprMeta,
 | 
			
		||||
		Description: "User sessions with an active refresh token.",
 | 
			
		||||
		Versions:    []k8sapi.APIVersion{{Name: "v1"}},
 | 
			
		||||
	},
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// There will only ever be a single keys resource. Maintain this by setting a
 | 
			
		||||
@@ -465,3 +473,38 @@ func toStorageKeys(keys Keys) storage.Keys {
 | 
			
		||||
		NextRotation:     keys.NextRotation,
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// OfflineSessions is a mirrored struct from storage with JSON struct tags and Kubernetes
 | 
			
		||||
// type metadata.
 | 
			
		||||
type OfflineSessions struct {
 | 
			
		||||
	k8sapi.TypeMeta   `json:",inline"`
 | 
			
		||||
	k8sapi.ObjectMeta `json:"metadata,omitempty"`
 | 
			
		||||
 | 
			
		||||
	UserID  string                              `json:"userID,omitempty"`
 | 
			
		||||
	ConnID  string                              `json:"connID,omitempty"`
 | 
			
		||||
	Refresh map[string]*storage.RefreshTokenRef `json:"refresh,omitempty"`
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (cli *client) fromStorageOfflineSessions(o storage.OfflineSessions) OfflineSessions {
 | 
			
		||||
	return OfflineSessions{
 | 
			
		||||
		TypeMeta: k8sapi.TypeMeta{
 | 
			
		||||
			Kind:       kindOfflineSessions,
 | 
			
		||||
			APIVersion: cli.apiVersion,
 | 
			
		||||
		},
 | 
			
		||||
		ObjectMeta: k8sapi.ObjectMeta{
 | 
			
		||||
			Name:      cli.offlineTokenName(o.UserID, o.ConnID),
 | 
			
		||||
			Namespace: cli.namespace,
 | 
			
		||||
		},
 | 
			
		||||
		UserID:  o.UserID,
 | 
			
		||||
		ConnID:  o.ConnID,
 | 
			
		||||
		Refresh: o.Refresh,
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func toStorageOfflineSessions(o OfflineSessions) storage.OfflineSessions {
 | 
			
		||||
	return storage.OfflineSessions{
 | 
			
		||||
		UserID:  o.UserID,
 | 
			
		||||
		ConnID:  o.ConnID,
 | 
			
		||||
		Refresh: o.Refresh,
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
 
 | 
			
		||||
@@ -13,12 +13,13 @@ import (
 | 
			
		||||
// New returns an in memory storage.
 | 
			
		||||
func New(logger logrus.FieldLogger) storage.Storage {
 | 
			
		||||
	return &memStorage{
 | 
			
		||||
		clients:       make(map[string]storage.Client),
 | 
			
		||||
		authCodes:     make(map[string]storage.AuthCode),
 | 
			
		||||
		refreshTokens: make(map[string]storage.RefreshToken),
 | 
			
		||||
		authReqs:      make(map[string]storage.AuthRequest),
 | 
			
		||||
		passwords:     make(map[string]storage.Password),
 | 
			
		||||
		logger:        logger,
 | 
			
		||||
		clients:         make(map[string]storage.Client),
 | 
			
		||||
		authCodes:       make(map[string]storage.AuthCode),
 | 
			
		||||
		refreshTokens:   make(map[string]storage.RefreshToken),
 | 
			
		||||
		authReqs:        make(map[string]storage.AuthRequest),
 | 
			
		||||
		passwords:       make(map[string]storage.Password),
 | 
			
		||||
		offlineSessions: make(map[offlineSessionID]storage.OfflineSessions),
 | 
			
		||||
		logger:          logger,
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
@@ -37,17 +38,23 @@ func (c *Config) Open(logger logrus.FieldLogger) (storage.Storage, error) {
 | 
			
		||||
type memStorage struct {
 | 
			
		||||
	mu sync.Mutex
 | 
			
		||||
 | 
			
		||||
	clients       map[string]storage.Client
 | 
			
		||||
	authCodes     map[string]storage.AuthCode
 | 
			
		||||
	refreshTokens map[string]storage.RefreshToken
 | 
			
		||||
	authReqs      map[string]storage.AuthRequest
 | 
			
		||||
	passwords     map[string]storage.Password
 | 
			
		||||
	clients         map[string]storage.Client
 | 
			
		||||
	authCodes       map[string]storage.AuthCode
 | 
			
		||||
	refreshTokens   map[string]storage.RefreshToken
 | 
			
		||||
	authReqs        map[string]storage.AuthRequest
 | 
			
		||||
	passwords       map[string]storage.Password
 | 
			
		||||
	offlineSessions map[offlineSessionID]storage.OfflineSessions
 | 
			
		||||
 | 
			
		||||
	keys storage.Keys
 | 
			
		||||
 | 
			
		||||
	logger logrus.FieldLogger
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
type offlineSessionID struct {
 | 
			
		||||
	userID string
 | 
			
		||||
	connID string
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (s *memStorage) tx(f func()) {
 | 
			
		||||
	s.mu.Lock()
 | 
			
		||||
	defer s.mu.Unlock()
 | 
			
		||||
@@ -130,6 +137,32 @@ func (s *memStorage) CreatePassword(p storage.Password) (err error) {
 | 
			
		||||
	return
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (s *memStorage) CreateOfflineSessions(o storage.OfflineSessions) (err error) {
 | 
			
		||||
	id := offlineSessionID{
 | 
			
		||||
		userID: o.UserID,
 | 
			
		||||
		connID: o.ConnID,
 | 
			
		||||
	}
 | 
			
		||||
	s.tx(func() {
 | 
			
		||||
		if _, ok := s.offlineSessions[id]; ok {
 | 
			
		||||
			err = storage.ErrAlreadyExists
 | 
			
		||||
		} else {
 | 
			
		||||
			s.offlineSessions[id] = o
 | 
			
		||||
		}
 | 
			
		||||
	})
 | 
			
		||||
	return
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (s *memStorage) GetAuthCode(id string) (c storage.AuthCode, err error) {
 | 
			
		||||
	s.tx(func() {
 | 
			
		||||
		var ok bool
 | 
			
		||||
		if c, ok = s.authCodes[id]; !ok {
 | 
			
		||||
			err = storage.ErrNotFound
 | 
			
		||||
			return
 | 
			
		||||
		}
 | 
			
		||||
	})
 | 
			
		||||
	return
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (s *memStorage) GetPassword(email string) (p storage.Password, err error) {
 | 
			
		||||
	email = strings.ToLower(email)
 | 
			
		||||
	s.tx(func() {
 | 
			
		||||
@@ -156,10 +189,10 @@ func (s *memStorage) GetKeys() (keys storage.Keys, err error) {
 | 
			
		||||
	return
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (s *memStorage) GetRefresh(token string) (tok storage.RefreshToken, err error) {
 | 
			
		||||
func (s *memStorage) GetRefresh(id string) (tok storage.RefreshToken, err error) {
 | 
			
		||||
	s.tx(func() {
 | 
			
		||||
		var ok bool
 | 
			
		||||
		if tok, ok = s.refreshTokens[token]; !ok {
 | 
			
		||||
		if tok, ok = s.refreshTokens[id]; !ok {
 | 
			
		||||
			err = storage.ErrNotFound
 | 
			
		||||
			return
 | 
			
		||||
		}
 | 
			
		||||
@@ -178,6 +211,21 @@ func (s *memStorage) GetAuthRequest(id string) (req storage.AuthRequest, err err
 | 
			
		||||
	return
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (s *memStorage) GetOfflineSessions(userID string, connID string) (o storage.OfflineSessions, err error) {
 | 
			
		||||
	id := offlineSessionID{
 | 
			
		||||
		userID: userID,
 | 
			
		||||
		connID: connID,
 | 
			
		||||
	}
 | 
			
		||||
	s.tx(func() {
 | 
			
		||||
		var ok bool
 | 
			
		||||
		if o, ok = s.offlineSessions[id]; !ok {
 | 
			
		||||
			err = storage.ErrNotFound
 | 
			
		||||
			return
 | 
			
		||||
		}
 | 
			
		||||
	})
 | 
			
		||||
	return
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (s *memStorage) ListClients() (clients []storage.Client, err error) {
 | 
			
		||||
	s.tx(func() {
 | 
			
		||||
		for _, client := range s.clients {
 | 
			
		||||
@@ -228,13 +276,13 @@ func (s *memStorage) DeleteClient(id string) (err error) {
 | 
			
		||||
	return
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (s *memStorage) DeleteRefresh(token string) (err error) {
 | 
			
		||||
func (s *memStorage) DeleteRefresh(id string) (err error) {
 | 
			
		||||
	s.tx(func() {
 | 
			
		||||
		if _, ok := s.refreshTokens[token]; !ok {
 | 
			
		||||
		if _, ok := s.refreshTokens[id]; !ok {
 | 
			
		||||
			err = storage.ErrNotFound
 | 
			
		||||
			return
 | 
			
		||||
		}
 | 
			
		||||
		delete(s.refreshTokens, token)
 | 
			
		||||
		delete(s.refreshTokens, id)
 | 
			
		||||
	})
 | 
			
		||||
	return
 | 
			
		||||
}
 | 
			
		||||
@@ -261,13 +309,17 @@ func (s *memStorage) DeleteAuthRequest(id string) (err error) {
 | 
			
		||||
	return
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (s *memStorage) GetAuthCode(id string) (c storage.AuthCode, err error) {
 | 
			
		||||
func (s *memStorage) DeleteOfflineSessions(userID string, connID string) (err error) {
 | 
			
		||||
	id := offlineSessionID{
 | 
			
		||||
		userID: userID,
 | 
			
		||||
		connID: connID,
 | 
			
		||||
	}
 | 
			
		||||
	s.tx(func() {
 | 
			
		||||
		var ok bool
 | 
			
		||||
		if c, ok = s.authCodes[id]; !ok {
 | 
			
		||||
		if _, ok := s.offlineSessions[id]; !ok {
 | 
			
		||||
			err = storage.ErrNotFound
 | 
			
		||||
			return
 | 
			
		||||
		}
 | 
			
		||||
		delete(s.offlineSessions, id)
 | 
			
		||||
	})
 | 
			
		||||
	return
 | 
			
		||||
}
 | 
			
		||||
@@ -338,3 +390,21 @@ func (s *memStorage) UpdateRefreshToken(id string, updater func(p storage.Refres
 | 
			
		||||
	})
 | 
			
		||||
	return
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (s *memStorage) UpdateOfflineSessions(userID string, connID string, updater func(o storage.OfflineSessions) (storage.OfflineSessions, error)) (err error) {
 | 
			
		||||
	id := offlineSessionID{
 | 
			
		||||
		userID: userID,
 | 
			
		||||
		connID: connID,
 | 
			
		||||
	}
 | 
			
		||||
	s.tx(func() {
 | 
			
		||||
		r, ok := s.offlineSessions[id]
 | 
			
		||||
		if !ok {
 | 
			
		||||
			err = storage.ErrNotFound
 | 
			
		||||
			return
 | 
			
		||||
		}
 | 
			
		||||
		if r, err = updater(r); err == nil {
 | 
			
		||||
			s.offlineSessions[id] = r
 | 
			
		||||
		}
 | 
			
		||||
	})
 | 
			
		||||
	return
 | 
			
		||||
}
 | 
			
		||||
 
 | 
			
		||||
@@ -624,6 +624,75 @@ func scanPassword(s scanner) (p storage.Password, err error) {
 | 
			
		||||
	return p, nil
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (c *conn) CreateOfflineSessions(s storage.OfflineSessions) error {
 | 
			
		||||
	_, err := c.Exec(`
 | 
			
		||||
		insert into offline_session (
 | 
			
		||||
			user_id, conn_id, refresh
 | 
			
		||||
		)
 | 
			
		||||
		values (
 | 
			
		||||
			$1, $2, $3
 | 
			
		||||
		);
 | 
			
		||||
	`,
 | 
			
		||||
		s.UserID, s.ConnID, encoder(s.Refresh),
 | 
			
		||||
	)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return fmt.Errorf("insert offline session: %v", err)
 | 
			
		||||
	}
 | 
			
		||||
	return nil
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (c *conn) UpdateOfflineSessions(userID string, connID string, updater func(s storage.OfflineSessions) (storage.OfflineSessions, error)) error {
 | 
			
		||||
	return c.ExecTx(func(tx *trans) error {
 | 
			
		||||
		s, err := getOfflineSessions(tx, userID, connID)
 | 
			
		||||
		if err != nil {
 | 
			
		||||
			return err
 | 
			
		||||
		}
 | 
			
		||||
 | 
			
		||||
		newSession, err := updater(s)
 | 
			
		||||
		if err != nil {
 | 
			
		||||
			return err
 | 
			
		||||
		}
 | 
			
		||||
		_, err = tx.Exec(`
 | 
			
		||||
			update offline_session
 | 
			
		||||
			set
 | 
			
		||||
				refresh = $1
 | 
			
		||||
			where user_id = $2 AND conn_id = $3;
 | 
			
		||||
		`,
 | 
			
		||||
			encoder(newSession.Refresh), s.UserID, s.ConnID,
 | 
			
		||||
		)
 | 
			
		||||
		if err != nil {
 | 
			
		||||
			return fmt.Errorf("update offline session: %v", err)
 | 
			
		||||
		}
 | 
			
		||||
		return nil
 | 
			
		||||
	})
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (c *conn) GetOfflineSessions(userID string, connID string) (storage.OfflineSessions, error) {
 | 
			
		||||
	return getOfflineSessions(c, userID, connID)
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func getOfflineSessions(q querier, userID string, connID string) (storage.OfflineSessions, error) {
 | 
			
		||||
	return scanOfflineSessions(q.QueryRow(`
 | 
			
		||||
		select
 | 
			
		||||
			user_id, conn_id, refresh
 | 
			
		||||
		from offline_session
 | 
			
		||||
		where user_id = $1 AND conn_id = $2;
 | 
			
		||||
		`, userID, connID))
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func scanOfflineSessions(s scanner) (o storage.OfflineSessions, err error) {
 | 
			
		||||
	err = s.Scan(
 | 
			
		||||
		&o.UserID, &o.ConnID, decoder(&o.Refresh),
 | 
			
		||||
	)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		if err == sql.ErrNoRows {
 | 
			
		||||
			return o, storage.ErrNotFound
 | 
			
		||||
		}
 | 
			
		||||
		return o, fmt.Errorf("select offline session: %v", err)
 | 
			
		||||
	}
 | 
			
		||||
	return o, nil
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (c *conn) DeleteAuthRequest(id string) error { return c.delete("auth_request", "id", id) }
 | 
			
		||||
func (c *conn) DeleteAuthCode(id string) error    { return c.delete("auth_code", "id", id) }
 | 
			
		||||
func (c *conn) DeleteClient(id string) error      { return c.delete("client", "id", id) }
 | 
			
		||||
@@ -632,6 +701,24 @@ func (c *conn) DeletePassword(email string) error {
 | 
			
		||||
	return c.delete("password", "email", strings.ToLower(email))
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (c *conn) DeleteOfflineSessions(userID string, connID string) error {
 | 
			
		||||
	result, err := c.Exec(`delete from offline_session where user_id = $1 AND conn_id = $2`, userID, connID)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return fmt.Errorf("delete offline_session: user_id = %s, conn_id = %s", userID, connID)
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	// For now mandate that the driver implements RowsAffected. If we ever need to support
 | 
			
		||||
	// a driver that doesn't implement this, we can run this in a transaction with a get beforehand.
 | 
			
		||||
	n, err := result.RowsAffected()
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return fmt.Errorf("rows affected: %v", err)
 | 
			
		||||
	}
 | 
			
		||||
	if n < 1 {
 | 
			
		||||
		return storage.ErrNotFound
 | 
			
		||||
	}
 | 
			
		||||
	return nil
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// Do NOT call directly. Does not escape table.
 | 
			
		||||
func (c *conn) delete(table, field, id string) error {
 | 
			
		||||
	result, err := c.Exec(`delete from `+table+` where `+field+` = $1`, id)
 | 
			
		||||
 
 | 
			
		||||
@@ -153,6 +153,7 @@ var migrations = []migration{
 | 
			
		||||
				signing_key_pub bytea not null,   -- JSON object
 | 
			
		||||
				next_rotation timestamptz not null
 | 
			
		||||
			);
 | 
			
		||||
 | 
			
		||||
		`,
 | 
			
		||||
	},
 | 
			
		||||
	{
 | 
			
		||||
@@ -165,4 +166,14 @@ var migrations = []migration{
 | 
			
		||||
				add column last_used timestamptz not null default '0001-01-01 00:00:00 UTC';
 | 
			
		||||
		`,
 | 
			
		||||
	},
 | 
			
		||||
	{
 | 
			
		||||
		stmt: `
 | 
			
		||||
			create table offline_session (
 | 
			
		||||
				user_id text not null,
 | 
			
		||||
				conn_id text not null,
 | 
			
		||||
				refresh bytea not null,
 | 
			
		||||
				PRIMARY KEY (user_id, conn_id)
 | 
			
		||||
			);
 | 
			
		||||
		`,
 | 
			
		||||
	},
 | 
			
		||||
}
 | 
			
		||||
 
 | 
			
		||||
@@ -52,6 +52,7 @@ type Storage interface {
 | 
			
		||||
	CreateAuthCode(c AuthCode) error
 | 
			
		||||
	CreateRefresh(r RefreshToken) error
 | 
			
		||||
	CreatePassword(p Password) error
 | 
			
		||||
	CreateOfflineSessions(s OfflineSessions) error
 | 
			
		||||
 | 
			
		||||
	// TODO(ericchiang): return (T, bool, error) so we can indicate not found
 | 
			
		||||
	// requests that way instead of using ErrNotFound.
 | 
			
		||||
@@ -61,6 +62,7 @@ type Storage interface {
 | 
			
		||||
	GetKeys() (Keys, error)
 | 
			
		||||
	GetRefresh(id string) (RefreshToken, error)
 | 
			
		||||
	GetPassword(email string) (Password, error)
 | 
			
		||||
	GetOfflineSessions(userID string, connID string) (OfflineSessions, error)
 | 
			
		||||
 | 
			
		||||
	ListClients() ([]Client, error)
 | 
			
		||||
	ListRefreshTokens() ([]RefreshToken, error)
 | 
			
		||||
@@ -72,6 +74,7 @@ type Storage interface {
 | 
			
		||||
	DeleteClient(id string) error
 | 
			
		||||
	DeleteRefresh(id string) error
 | 
			
		||||
	DeletePassword(email string) error
 | 
			
		||||
	DeleteOfflineSessions(userID string, connID string) error
 | 
			
		||||
 | 
			
		||||
	// Update methods take a function for updating an object then performs that update within
 | 
			
		||||
	// a transaction. "updater" functions may be called multiple times by a single update call.
 | 
			
		||||
@@ -92,6 +95,7 @@ type Storage interface {
 | 
			
		||||
	UpdateAuthRequest(id string, updater func(a AuthRequest) (AuthRequest, error)) error
 | 
			
		||||
	UpdateRefreshToken(id string, updater func(r RefreshToken) (RefreshToken, error)) error
 | 
			
		||||
	UpdatePassword(email string, updater func(p Password) (Password, error)) error
 | 
			
		||||
	UpdateOfflineSessions(userID string, connID string, updater func(s OfflineSessions) (OfflineSessions, error)) error
 | 
			
		||||
 | 
			
		||||
	// GarbageCollect deletes all expired AuthCodes and AuthRequests.
 | 
			
		||||
	GarbageCollect(now time.Time) (GCResult, error)
 | 
			
		||||
@@ -241,6 +245,30 @@ type RefreshToken struct {
 | 
			
		||||
	Nonce string
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// RefreshTokenRef is a reference object that contains metadata about refresh tokens.
 | 
			
		||||
type RefreshTokenRef struct {
 | 
			
		||||
	ID string
 | 
			
		||||
 | 
			
		||||
	// Client the refresh token is valid for.
 | 
			
		||||
	ClientID string
 | 
			
		||||
 | 
			
		||||
	CreatedAt time.Time
 | 
			
		||||
	LastUsed  time.Time
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// OfflineSessions objects are sessions pertaining to users with refresh tokens.
 | 
			
		||||
type OfflineSessions struct {
 | 
			
		||||
	// UserID of an end user who has logged in to the server.
 | 
			
		||||
	UserID string
 | 
			
		||||
 | 
			
		||||
	// The ID of the connector used to login the user.
 | 
			
		||||
	ConnID string
 | 
			
		||||
 | 
			
		||||
	// Refresh is a hash table of refresh token reference objects
 | 
			
		||||
	// indexed by the ClientID of the refresh token.
 | 
			
		||||
	Refresh map[string]*RefreshTokenRef
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// Password is an email to password mapping managed by the storage.
 | 
			
		||||
type Password struct {
 | 
			
		||||
	// Email and identifying name of the password. Emails are assumed to be valid and
 | 
			
		||||
 
 | 
			
		||||
		Reference in New Issue
	
	Block a user