feat: Add ent-based sqlite3 storage
Signed-off-by: m.nabokikh <maksim.nabokikh@flant.com>
This commit is contained in:
52
storage/ent/client/authcode.go
Normal file
52
storage/ent/client/authcode.go
Normal file
@@ -0,0 +1,52 @@
|
||||
package client
|
||||
|
||||
import (
|
||||
"context"
|
||||
|
||||
"github.com/dexidp/dex/storage"
|
||||
)
|
||||
|
||||
// CreateAuthCode saves provided auth code into the database.
|
||||
func (d *Database) CreateAuthCode(code storage.AuthCode) error {
|
||||
_, err := d.client.AuthCode.Create().
|
||||
SetID(code.ID).
|
||||
SetClientID(code.ClientID).
|
||||
SetScopes(code.Scopes).
|
||||
SetRedirectURI(code.RedirectURI).
|
||||
SetNonce(code.Nonce).
|
||||
SetClaimsUserID(code.Claims.UserID).
|
||||
SetClaimsEmail(code.Claims.Email).
|
||||
SetClaimsEmailVerified(code.Claims.EmailVerified).
|
||||
SetClaimsUsername(code.Claims.Username).
|
||||
SetClaimsPreferredUsername(code.Claims.PreferredUsername).
|
||||
SetClaimsGroups(code.Claims.Groups).
|
||||
SetCodeChallenge(code.PKCE.CodeChallenge).
|
||||
SetCodeChallengeMethod(code.PKCE.CodeChallengeMethod).
|
||||
// Save utc time into database because ent doesn't support comparing dates with different timezones
|
||||
SetExpiry(code.Expiry.UTC()).
|
||||
SetConnectorID(code.ConnectorID).
|
||||
SetConnectorData(code.ConnectorData).
|
||||
Save(context.TODO())
|
||||
if err != nil {
|
||||
return convertDBError("create auth code: %w", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// GetAuthCode extracts an auth code from the database by id.
|
||||
func (d *Database) GetAuthCode(id string) (storage.AuthCode, error) {
|
||||
authCode, err := d.client.AuthCode.Get(context.TODO(), id)
|
||||
if err != nil {
|
||||
return storage.AuthCode{}, convertDBError("get auth code: %w", err)
|
||||
}
|
||||
return toStorageAuthCode(authCode), nil
|
||||
}
|
||||
|
||||
// DeleteAuthCode deletes an auth code from the database by id.
|
||||
func (d *Database) DeleteAuthCode(id string) error {
|
||||
err := d.client.AuthCode.DeleteOneID(id).Exec(context.TODO())
|
||||
if err != nil {
|
||||
return convertDBError("delete auth code: %w", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
107
storage/ent/client/authrequest.go
Normal file
107
storage/ent/client/authrequest.go
Normal file
@@ -0,0 +1,107 @@
|
||||
package client
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
|
||||
"github.com/dexidp/dex/storage"
|
||||
)
|
||||
|
||||
// CreateAuthRequest saves provided auth request into the database.
|
||||
func (d *Database) CreateAuthRequest(authRequest storage.AuthRequest) error {
|
||||
_, err := d.client.AuthRequest.Create().
|
||||
SetID(authRequest.ID).
|
||||
SetClientID(authRequest.ClientID).
|
||||
SetScopes(authRequest.Scopes).
|
||||
SetResponseTypes(authRequest.ResponseTypes).
|
||||
SetRedirectURI(authRequest.RedirectURI).
|
||||
SetState(authRequest.State).
|
||||
SetNonce(authRequest.Nonce).
|
||||
SetForceApprovalPrompt(authRequest.ForceApprovalPrompt).
|
||||
SetLoggedIn(authRequest.LoggedIn).
|
||||
SetClaimsUserID(authRequest.Claims.UserID).
|
||||
SetClaimsEmail(authRequest.Claims.Email).
|
||||
SetClaimsEmailVerified(authRequest.Claims.EmailVerified).
|
||||
SetClaimsUsername(authRequest.Claims.Username).
|
||||
SetClaimsPreferredUsername(authRequest.Claims.PreferredUsername).
|
||||
SetClaimsGroups(authRequest.Claims.Groups).
|
||||
SetCodeChallenge(authRequest.PKCE.CodeChallenge).
|
||||
SetCodeChallengeMethod(authRequest.PKCE.CodeChallengeMethod).
|
||||
// Save utc time into database because ent doesn't support comparing dates with different timezones
|
||||
SetExpiry(authRequest.Expiry.UTC()).
|
||||
SetConnectorID(authRequest.ConnectorID).
|
||||
SetConnectorData(authRequest.ConnectorData).
|
||||
Save(context.TODO())
|
||||
if err != nil {
|
||||
return convertDBError("create auth request: %w", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// GetAuthRequest extracts an auth request from the database by id.
|
||||
func (d *Database) GetAuthRequest(id string) (storage.AuthRequest, error) {
|
||||
authRequest, err := d.client.AuthRequest.Get(context.TODO(), id)
|
||||
if err != nil {
|
||||
return storage.AuthRequest{}, convertDBError("get auth request: %w", err)
|
||||
}
|
||||
return toStorageAuthRequest(authRequest), nil
|
||||
}
|
||||
|
||||
// DeleteAuthRequest deletes an auth request from the database by id.
|
||||
func (d *Database) DeleteAuthRequest(id string) error {
|
||||
err := d.client.AuthRequest.DeleteOneID(id).Exec(context.TODO())
|
||||
if err != nil {
|
||||
return convertDBError("delete auth request: %w", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// UpdateAuthRequest changes an auth request by id using an updater function and saves it to the database.
|
||||
func (d *Database) UpdateAuthRequest(id string, updater func(old storage.AuthRequest) (storage.AuthRequest, error)) error {
|
||||
tx, err := d.client.Tx(context.TODO())
|
||||
if err != nil {
|
||||
return fmt.Errorf("update auth request tx: %w", err)
|
||||
}
|
||||
|
||||
authRequest, err := tx.AuthRequest.Get(context.TODO(), id)
|
||||
if err != nil {
|
||||
return rollback(tx, "update auth request database: %w", err)
|
||||
}
|
||||
|
||||
newAuthRequest, err := updater(toStorageAuthRequest(authRequest))
|
||||
if err != nil {
|
||||
return rollback(tx, "update auth request updating: %w", err)
|
||||
}
|
||||
|
||||
_, err = tx.AuthRequest.UpdateOneID(newAuthRequest.ID).
|
||||
SetClientID(newAuthRequest.ClientID).
|
||||
SetScopes(newAuthRequest.Scopes).
|
||||
SetResponseTypes(newAuthRequest.ResponseTypes).
|
||||
SetRedirectURI(newAuthRequest.RedirectURI).
|
||||
SetState(newAuthRequest.State).
|
||||
SetNonce(newAuthRequest.Nonce).
|
||||
SetForceApprovalPrompt(newAuthRequest.ForceApprovalPrompt).
|
||||
SetLoggedIn(newAuthRequest.LoggedIn).
|
||||
SetClaimsUserID(newAuthRequest.Claims.UserID).
|
||||
SetClaimsEmail(newAuthRequest.Claims.Email).
|
||||
SetClaimsEmailVerified(newAuthRequest.Claims.EmailVerified).
|
||||
SetClaimsUsername(newAuthRequest.Claims.Username).
|
||||
SetClaimsPreferredUsername(newAuthRequest.Claims.PreferredUsername).
|
||||
SetClaimsGroups(newAuthRequest.Claims.Groups).
|
||||
SetCodeChallenge(newAuthRequest.PKCE.CodeChallenge).
|
||||
SetCodeChallengeMethod(newAuthRequest.PKCE.CodeChallengeMethod).
|
||||
// Save utc time into database because ent doesn't support comparing dates with different timezones
|
||||
SetExpiry(newAuthRequest.Expiry.UTC()).
|
||||
SetConnectorID(newAuthRequest.ConnectorID).
|
||||
SetConnectorData(newAuthRequest.ConnectorData).
|
||||
Save(context.TODO())
|
||||
if err != nil {
|
||||
return rollback(tx, "update auth request uploading: %w", err)
|
||||
}
|
||||
|
||||
if err = tx.Commit(); err != nil {
|
||||
return rollback(tx, "update auth request commit: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
92
storage/ent/client/client.go
Normal file
92
storage/ent/client/client.go
Normal file
@@ -0,0 +1,92 @@
|
||||
package client
|
||||
|
||||
import (
|
||||
"context"
|
||||
|
||||
"github.com/dexidp/dex/storage"
|
||||
)
|
||||
|
||||
// CreateClient saves provided oauth2 client settings into the database.
|
||||
func (d *Database) CreateClient(client storage.Client) error {
|
||||
_, err := d.client.OAuth2Client.Create().
|
||||
SetID(client.ID).
|
||||
SetName(client.Name).
|
||||
SetSecret(client.Secret).
|
||||
SetPublic(client.Public).
|
||||
SetLogoURL(client.LogoURL).
|
||||
SetRedirectUris(client.RedirectURIs).
|
||||
SetTrustedPeers(client.TrustedPeers).
|
||||
Save(context.TODO())
|
||||
if err != nil {
|
||||
return convertDBError("create oauth2 client: %w", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// ListClients extracts an array of oauth2 clients from the database.
|
||||
func (d *Database) ListClients() ([]storage.Client, error) {
|
||||
clients, err := d.client.OAuth2Client.Query().All(context.TODO())
|
||||
if err != nil {
|
||||
return nil, convertDBError("list clients: %w", err)
|
||||
}
|
||||
|
||||
storageClients := make([]storage.Client, 0, len(clients))
|
||||
for _, c := range clients {
|
||||
storageClients = append(storageClients, toStorageClient(c))
|
||||
}
|
||||
return storageClients, nil
|
||||
}
|
||||
|
||||
// GetClient extracts an oauth2 client from the database by id.
|
||||
func (d *Database) GetClient(id string) (storage.Client, error) {
|
||||
client, err := d.client.OAuth2Client.Get(context.TODO(), id)
|
||||
if err != nil {
|
||||
return storage.Client{}, convertDBError("get client: %w", err)
|
||||
}
|
||||
return toStorageClient(client), nil
|
||||
}
|
||||
|
||||
// DeleteClient deletes an oauth2 client from the database by id.
|
||||
func (d *Database) DeleteClient(id string) error {
|
||||
err := d.client.OAuth2Client.DeleteOneID(id).Exec(context.TODO())
|
||||
if err != nil {
|
||||
return convertDBError("delete client: %w", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// UpdateClient changes an oauth2 client by id using an updater function and saves it to the database.
|
||||
func (d *Database) UpdateClient(id string, updater func(old storage.Client) (storage.Client, error)) error {
|
||||
tx, err := d.client.Tx(context.TODO())
|
||||
if err != nil {
|
||||
return convertDBError("update client tx: %w", err)
|
||||
}
|
||||
|
||||
client, err := tx.OAuth2Client.Get(context.TODO(), id)
|
||||
if err != nil {
|
||||
return rollback(tx, "update client database: %w", err)
|
||||
}
|
||||
|
||||
newClient, err := updater(toStorageClient(client))
|
||||
if err != nil {
|
||||
return rollback(tx, "update client updating: %w", err)
|
||||
}
|
||||
|
||||
_, err = tx.OAuth2Client.UpdateOneID(newClient.ID).
|
||||
SetName(newClient.Name).
|
||||
SetSecret(newClient.Secret).
|
||||
SetPublic(newClient.Public).
|
||||
SetLogoURL(newClient.LogoURL).
|
||||
SetRedirectUris(newClient.RedirectURIs).
|
||||
SetTrustedPeers(newClient.TrustedPeers).
|
||||
Save(context.TODO())
|
||||
if err != nil {
|
||||
return rollback(tx, "update client uploading: %w", err)
|
||||
}
|
||||
|
||||
if err = tx.Commit(); err != nil {
|
||||
return rollback(tx, "update auth request commit: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
88
storage/ent/client/connector.go
Normal file
88
storage/ent/client/connector.go
Normal file
@@ -0,0 +1,88 @@
|
||||
package client
|
||||
|
||||
import (
|
||||
"context"
|
||||
|
||||
"github.com/dexidp/dex/storage"
|
||||
)
|
||||
|
||||
// CreateConnector saves a connector into the database.
|
||||
func (d *Database) CreateConnector(connector storage.Connector) error {
|
||||
_, err := d.client.Connector.Create().
|
||||
SetID(connector.ID).
|
||||
SetName(connector.Name).
|
||||
SetType(connector.Type).
|
||||
SetResourceVersion(connector.ResourceVersion).
|
||||
SetConfig(connector.Config).
|
||||
Save(context.TODO())
|
||||
if err != nil {
|
||||
return convertDBError("create connector: %w", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// ListConnectors extracts an array of connectors from the database.
|
||||
func (d *Database) ListConnectors() ([]storage.Connector, error) {
|
||||
connectors, err := d.client.Connector.Query().All(context.TODO())
|
||||
if err != nil {
|
||||
return nil, convertDBError("list connectors: %w", err)
|
||||
}
|
||||
|
||||
storageConnectors := make([]storage.Connector, 0, len(connectors))
|
||||
for _, c := range connectors {
|
||||
storageConnectors = append(storageConnectors, toStorageConnector(c))
|
||||
}
|
||||
return storageConnectors, nil
|
||||
}
|
||||
|
||||
// GetConnector extracts a connector from the database by id.
|
||||
func (d *Database) GetConnector(id string) (storage.Connector, error) {
|
||||
connector, err := d.client.Connector.Get(context.TODO(), id)
|
||||
if err != nil {
|
||||
return storage.Connector{}, convertDBError("get connector: %w", err)
|
||||
}
|
||||
return toStorageConnector(connector), nil
|
||||
}
|
||||
|
||||
// DeleteConnector deletes a connector from the database by id.
|
||||
func (d *Database) DeleteConnector(id string) error {
|
||||
err := d.client.Connector.DeleteOneID(id).Exec(context.TODO())
|
||||
if err != nil {
|
||||
return convertDBError("delete connector: %w", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// UpdateConnector changes a connector by id using an updater function and saves it to the database.
|
||||
func (d *Database) UpdateConnector(id string, updater func(old storage.Connector) (storage.Connector, error)) error {
|
||||
tx, err := d.client.Tx(context.TODO())
|
||||
if err != nil {
|
||||
return convertDBError("update connector tx: %w", err)
|
||||
}
|
||||
|
||||
connector, err := tx.Connector.Get(context.TODO(), id)
|
||||
if err != nil {
|
||||
return rollback(tx, "update connector database: %w", err)
|
||||
}
|
||||
|
||||
newConnector, err := updater(toStorageConnector(connector))
|
||||
if err != nil {
|
||||
return rollback(tx, "update connector updating: %w", err)
|
||||
}
|
||||
|
||||
_, err = tx.Connector.UpdateOneID(newConnector.ID).
|
||||
SetName(newConnector.Name).
|
||||
SetType(newConnector.Type).
|
||||
SetResourceVersion(newConnector.ResourceVersion).
|
||||
SetConfig(newConnector.Config).
|
||||
Save(context.TODO())
|
||||
if err != nil {
|
||||
return rollback(tx, "update connector uploading: %w", err)
|
||||
}
|
||||
|
||||
if err = tx.Commit(); err != nil {
|
||||
return rollback(tx, "update connector commit: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
36
storage/ent/client/devicerequest.go
Normal file
36
storage/ent/client/devicerequest.go
Normal file
@@ -0,0 +1,36 @@
|
||||
package client
|
||||
|
||||
import (
|
||||
"context"
|
||||
|
||||
"github.com/dexidp/dex/storage"
|
||||
"github.com/dexidp/dex/storage/ent/db/devicerequest"
|
||||
)
|
||||
|
||||
// CreateDeviceRequest saves provided device request into the database.
|
||||
func (d *Database) CreateDeviceRequest(request storage.DeviceRequest) error {
|
||||
_, err := d.client.DeviceRequest.Create().
|
||||
SetClientID(request.ClientID).
|
||||
SetClientSecret(request.ClientSecret).
|
||||
SetScopes(request.Scopes).
|
||||
SetUserCode(request.UserCode).
|
||||
SetDeviceCode(request.DeviceCode).
|
||||
// Save utc time into database because ent doesn't support comparing dates with different timezones
|
||||
SetExpiry(request.Expiry.UTC()).
|
||||
Save(context.TODO())
|
||||
if err != nil {
|
||||
return convertDBError("create device request: %w", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// GetDeviceRequest extracts a device request from the database by user code.
|
||||
func (d *Database) GetDeviceRequest(userCode string) (storage.DeviceRequest, error) {
|
||||
deviceRequest, err := d.client.DeviceRequest.Query().
|
||||
Where(devicerequest.UserCode(userCode)).
|
||||
Only(context.TODO())
|
||||
if err != nil {
|
||||
return storage.DeviceRequest{}, convertDBError("get device request: %w", err)
|
||||
}
|
||||
return toStorageDeviceRequest(deviceRequest), nil
|
||||
}
|
||||
76
storage/ent/client/devicetoken.go
Normal file
76
storage/ent/client/devicetoken.go
Normal file
@@ -0,0 +1,76 @@
|
||||
package client
|
||||
|
||||
import (
|
||||
"context"
|
||||
|
||||
"github.com/dexidp/dex/storage"
|
||||
"github.com/dexidp/dex/storage/ent/db/devicetoken"
|
||||
)
|
||||
|
||||
// CreateDeviceToken saves provided token into the database.
|
||||
func (d *Database) CreateDeviceToken(token storage.DeviceToken) error {
|
||||
_, err := d.client.DeviceToken.Create().
|
||||
SetDeviceCode(token.DeviceCode).
|
||||
SetToken([]byte(token.Token)).
|
||||
SetPollInterval(token.PollIntervalSeconds).
|
||||
// Save utc time into database because ent doesn't support comparing dates with different timezones
|
||||
SetExpiry(token.Expiry.UTC()).
|
||||
SetLastRequest(token.LastRequestTime.UTC()).
|
||||
SetStatus(token.Status).
|
||||
Save(context.TODO())
|
||||
if err != nil {
|
||||
return convertDBError("create device token: %w", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// GetDeviceToken extracts a token from the database by device code.
|
||||
func (d *Database) GetDeviceToken(deviceCode string) (storage.DeviceToken, error) {
|
||||
deviceToken, err := d.client.DeviceToken.Query().
|
||||
Where(devicetoken.DeviceCode(deviceCode)).
|
||||
Only(context.TODO())
|
||||
if err != nil {
|
||||
return storage.DeviceToken{}, convertDBError("get device token: %w", err)
|
||||
}
|
||||
return toStorageDeviceToken(deviceToken), nil
|
||||
}
|
||||
|
||||
// UpdateDeviceToken changes a token by device code using an updater function and saves it to the database.
|
||||
func (d *Database) UpdateDeviceToken(deviceCode string, updater func(old storage.DeviceToken) (storage.DeviceToken, error)) error {
|
||||
tx, err := d.client.Tx(context.TODO())
|
||||
if err != nil {
|
||||
return convertDBError("update device token tx: %w", err)
|
||||
}
|
||||
|
||||
token, err := tx.DeviceToken.Query().
|
||||
Where(devicetoken.DeviceCode(deviceCode)).
|
||||
Only(context.TODO())
|
||||
if err != nil {
|
||||
return rollback(tx, "update device token database: %w", err)
|
||||
}
|
||||
|
||||
newToken, err := updater(toStorageDeviceToken(token))
|
||||
if err != nil {
|
||||
return rollback(tx, "update device token updating: %w", err)
|
||||
}
|
||||
|
||||
_, err = tx.DeviceToken.Update().
|
||||
Where(devicetoken.DeviceCode(newToken.DeviceCode)).
|
||||
SetDeviceCode(newToken.DeviceCode).
|
||||
SetToken([]byte(newToken.Token)).
|
||||
SetPollInterval(newToken.PollIntervalSeconds).
|
||||
// Save utc time into database because ent doesn't support comparing dates with different timezones
|
||||
SetExpiry(newToken.Expiry.UTC()).
|
||||
SetLastRequest(newToken.LastRequestTime.UTC()).
|
||||
SetStatus(newToken.Status).
|
||||
Save(context.TODO())
|
||||
if err != nil {
|
||||
return rollback(tx, "update device token uploading: %w", err)
|
||||
}
|
||||
|
||||
if err = tx.Commit(); err != nil {
|
||||
return rollback(tx, "update device token commit: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
81
storage/ent/client/keys.go
Normal file
81
storage/ent/client/keys.go
Normal file
@@ -0,0 +1,81 @@
|
||||
package client
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
|
||||
"github.com/dexidp/dex/storage"
|
||||
"github.com/dexidp/dex/storage/ent/db"
|
||||
)
|
||||
|
||||
func getKeys(client *db.KeysClient) (storage.Keys, error) {
|
||||
rawKeys, err := client.Get(context.TODO(), keysRowID)
|
||||
if err != nil {
|
||||
return storage.Keys{}, convertDBError("get keys: %w", err)
|
||||
}
|
||||
|
||||
return toStorageKeys(rawKeys), nil
|
||||
}
|
||||
|
||||
// GetKeys returns signing keys, public keys and verification keys from the database.
|
||||
func (d *Database) GetKeys() (storage.Keys, error) {
|
||||
return getKeys(d.client.Keys)
|
||||
}
|
||||
|
||||
// UpdateKeys rotates keys using updater function.
|
||||
func (d *Database) UpdateKeys(updater func(old storage.Keys) (storage.Keys, error)) error {
|
||||
firstUpdate := false
|
||||
|
||||
tx, err := d.client.Tx(context.TODO())
|
||||
if err != nil {
|
||||
return convertDBError("update keys tx: %w", err)
|
||||
}
|
||||
|
||||
storageKeys, err := getKeys(tx.Keys)
|
||||
if err != nil {
|
||||
if !errors.Is(err, storage.ErrNotFound) {
|
||||
return rollback(tx, "update keys get: %w", err)
|
||||
}
|
||||
firstUpdate = true
|
||||
}
|
||||
|
||||
newKeys, err := updater(storageKeys)
|
||||
if err != nil {
|
||||
return rollback(tx, "update keys updating: %w", err)
|
||||
}
|
||||
|
||||
// ent doesn't have an upsert support yet
|
||||
// https://github.com/facebook/ent/issues/139
|
||||
if firstUpdate {
|
||||
_, err = tx.Keys.Create().
|
||||
SetID(keysRowID).
|
||||
SetNextRotation(newKeys.NextRotation).
|
||||
SetSigningKey(*newKeys.SigningKey).
|
||||
SetSigningKeyPub(*newKeys.SigningKeyPub).
|
||||
SetVerificationKeys(newKeys.VerificationKeys).
|
||||
Save(context.TODO())
|
||||
if err != nil {
|
||||
return rollback(tx, "create keys: %w", err)
|
||||
}
|
||||
if err = tx.Commit(); err != nil {
|
||||
return rollback(tx, "update keys commit: %w", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
err = tx.Keys.UpdateOneID(keysRowID).
|
||||
SetNextRotation(newKeys.NextRotation.UTC()).
|
||||
SetSigningKey(*newKeys.SigningKey).
|
||||
SetSigningKeyPub(*newKeys.SigningKeyPub).
|
||||
SetVerificationKeys(newKeys.VerificationKeys).
|
||||
Exec(context.TODO())
|
||||
if err != nil {
|
||||
return rollback(tx, "update keys uploading: %w", err)
|
||||
}
|
||||
|
||||
if err = tx.Commit(); err != nil {
|
||||
return rollback(tx, "update keys commit: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
95
storage/ent/client/main.go
Normal file
95
storage/ent/client/main.go
Normal file
@@ -0,0 +1,95 @@
|
||||
package client
|
||||
|
||||
import (
|
||||
"context"
|
||||
"hash"
|
||||
"time"
|
||||
|
||||
"github.com/dexidp/dex/storage"
|
||||
"github.com/dexidp/dex/storage/ent/db"
|
||||
"github.com/dexidp/dex/storage/ent/db/authcode"
|
||||
"github.com/dexidp/dex/storage/ent/db/authrequest"
|
||||
"github.com/dexidp/dex/storage/ent/db/devicerequest"
|
||||
"github.com/dexidp/dex/storage/ent/db/devicetoken"
|
||||
"github.com/dexidp/dex/storage/ent/db/migrate"
|
||||
)
|
||||
|
||||
var _ storage.Storage = (*Database)(nil)
|
||||
|
||||
type Database struct {
|
||||
client *db.Client
|
||||
hasher func() hash.Hash
|
||||
}
|
||||
|
||||
// NewDatabase returns new database client with set options.
|
||||
func NewDatabase(opts ...func(*Database)) *Database {
|
||||
database := &Database{}
|
||||
for _, f := range opts {
|
||||
f(database)
|
||||
}
|
||||
return database
|
||||
}
|
||||
|
||||
// WithClient sets client option of a Database object.
|
||||
func WithClient(c *db.Client) func(*Database) {
|
||||
return func(s *Database) {
|
||||
s.client = c
|
||||
}
|
||||
}
|
||||
|
||||
// WithHasher sets client option of a Database object.
|
||||
func WithHasher(h func() hash.Hash) func(*Database) {
|
||||
return func(s *Database) {
|
||||
s.hasher = h
|
||||
}
|
||||
}
|
||||
|
||||
// Schema exposes migration schema to perform migrations.
|
||||
func (d *Database) Schema() *migrate.Schema {
|
||||
return d.client.Schema
|
||||
}
|
||||
|
||||
// Close calls the corresponding method of the ent database client.
|
||||
func (d *Database) Close() error {
|
||||
return d.client.Close()
|
||||
}
|
||||
|
||||
// GarbageCollect removes expired entities from the database.
|
||||
func (d *Database) GarbageCollect(now time.Time) (storage.GCResult, error) {
|
||||
result := storage.GCResult{}
|
||||
utcNow := now.UTC()
|
||||
|
||||
q, err := d.client.AuthRequest.Delete().
|
||||
Where(authrequest.ExpiryLT(utcNow)).
|
||||
Exec(context.TODO())
|
||||
if err != nil {
|
||||
return result, convertDBError("gc auth request: %w", err)
|
||||
}
|
||||
result.AuthRequests = int64(q)
|
||||
|
||||
q, err = d.client.AuthCode.Delete().
|
||||
Where(authcode.ExpiryLT(utcNow)).
|
||||
Exec(context.TODO())
|
||||
if err != nil {
|
||||
return result, convertDBError("gc auth code: %w", err)
|
||||
}
|
||||
result.AuthCodes = int64(q)
|
||||
|
||||
q, err = d.client.DeviceRequest.Delete().
|
||||
Where(devicerequest.ExpiryLT(utcNow)).
|
||||
Exec(context.TODO())
|
||||
if err != nil {
|
||||
return result, convertDBError("gc device request: %w", err)
|
||||
}
|
||||
result.DeviceRequests = int64(q)
|
||||
|
||||
q, err = d.client.DeviceToken.Delete().
|
||||
Where(devicetoken.ExpiryLT(utcNow)).
|
||||
Exec(context.TODO())
|
||||
if err != nil {
|
||||
return result, convertDBError("gc device token: %w", err)
|
||||
}
|
||||
result.DeviceTokens = int64(q)
|
||||
|
||||
return result, err
|
||||
}
|
||||
93
storage/ent/client/offlinesession.go
Normal file
93
storage/ent/client/offlinesession.go
Normal file
@@ -0,0 +1,93 @@
|
||||
package client
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
|
||||
"github.com/dexidp/dex/storage"
|
||||
)
|
||||
|
||||
// CreateOfflineSessions saves provided offline session into the database.
|
||||
func (d *Database) CreateOfflineSessions(session storage.OfflineSessions) error {
|
||||
encodedRefresh, err := json.Marshal(session.Refresh)
|
||||
if err != nil {
|
||||
return fmt.Errorf("encode refresh offline session: %w", err)
|
||||
}
|
||||
|
||||
id := offlineSessionID(session.UserID, session.ConnID, d.hasher)
|
||||
_, err = d.client.OfflineSession.Create().
|
||||
SetID(id).
|
||||
SetUserID(session.UserID).
|
||||
SetConnID(session.ConnID).
|
||||
SetConnectorData(session.ConnectorData).
|
||||
SetRefresh(encodedRefresh).
|
||||
Save(context.TODO())
|
||||
if err != nil {
|
||||
return convertDBError("create offline session: %w", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// GetOfflineSessions extracts an offline session from the database by user id and connector id.
|
||||
func (d *Database) GetOfflineSessions(userID, connID string) (storage.OfflineSessions, error) {
|
||||
id := offlineSessionID(userID, connID, d.hasher)
|
||||
|
||||
offlineSession, err := d.client.OfflineSession.Get(context.TODO(), id)
|
||||
if err != nil {
|
||||
return storage.OfflineSessions{}, convertDBError("get offline session: %w", err)
|
||||
}
|
||||
return toStorageOfflineSession(offlineSession), nil
|
||||
}
|
||||
|
||||
// DeleteOfflineSessions deletes an offline session from the database by user id and connector id.
|
||||
func (d *Database) DeleteOfflineSessions(userID, connID string) error {
|
||||
id := offlineSessionID(userID, connID, d.hasher)
|
||||
|
||||
err := d.client.OfflineSession.DeleteOneID(id).Exec(context.TODO())
|
||||
if err != nil {
|
||||
return convertDBError("delete offline session: %w", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// UpdatePassword changes an offline session by user id and connector id using an updater function.
|
||||
func (d *Database) UpdateOfflineSessions(userID string, connID string, updater func(s storage.OfflineSessions) (storage.OfflineSessions, error)) error {
|
||||
id := offlineSessionID(userID, connID, d.hasher)
|
||||
|
||||
tx, err := d.client.Tx(context.TODO())
|
||||
if err != nil {
|
||||
return convertDBError("update offline session tx: %w", err)
|
||||
}
|
||||
|
||||
offlineSession, err := tx.OfflineSession.Get(context.TODO(), id)
|
||||
if err != nil {
|
||||
return rollback(tx, "update offline session database: %w", err)
|
||||
}
|
||||
|
||||
newOfflineSession, err := updater(toStorageOfflineSession(offlineSession))
|
||||
if err != nil {
|
||||
return rollback(tx, "update offline session updating: %w", err)
|
||||
}
|
||||
|
||||
encodedRefresh, err := json.Marshal(newOfflineSession.Refresh)
|
||||
if err != nil {
|
||||
return rollback(tx, "encode refresh offline session: %w", err)
|
||||
}
|
||||
|
||||
_, err = tx.OfflineSession.UpdateOneID(id).
|
||||
SetUserID(newOfflineSession.UserID).
|
||||
SetConnID(newOfflineSession.ConnID).
|
||||
SetConnectorData(newOfflineSession.ConnectorData).
|
||||
SetRefresh(encodedRefresh).
|
||||
Save(context.TODO())
|
||||
if err != nil {
|
||||
return rollback(tx, "update offline session uploading: %w", err)
|
||||
}
|
||||
|
||||
if err = tx.Commit(); err != nil {
|
||||
return rollback(tx, "update password commit: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
100
storage/ent/client/password.go
Normal file
100
storage/ent/client/password.go
Normal file
@@ -0,0 +1,100 @@
|
||||
package client
|
||||
|
||||
import (
|
||||
"context"
|
||||
"strings"
|
||||
|
||||
"github.com/dexidp/dex/storage"
|
||||
"github.com/dexidp/dex/storage/ent/db/password"
|
||||
)
|
||||
|
||||
// CreatePassword saves provided password into the database.
|
||||
func (d *Database) CreatePassword(password storage.Password) error {
|
||||
_, err := d.client.Password.Create().
|
||||
SetEmail(password.Email).
|
||||
SetHash(password.Hash).
|
||||
SetUsername(password.Username).
|
||||
SetUserID(password.UserID).
|
||||
Save(context.TODO())
|
||||
if err != nil {
|
||||
return convertDBError("create password: %w", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// ListPasswords extracts an array of passwords from the database.
|
||||
func (d *Database) ListPasswords() ([]storage.Password, error) {
|
||||
passwords, err := d.client.Password.Query().All(context.TODO())
|
||||
if err != nil {
|
||||
return nil, convertDBError("list passwords: %w", err)
|
||||
}
|
||||
|
||||
storagePasswords := make([]storage.Password, 0, len(passwords))
|
||||
for _, p := range passwords {
|
||||
storagePasswords = append(storagePasswords, toStoragePassword(p))
|
||||
}
|
||||
return storagePasswords, nil
|
||||
}
|
||||
|
||||
// GetPassword extracts a password from the database by email.
|
||||
func (d *Database) GetPassword(email string) (storage.Password, error) {
|
||||
email = strings.ToLower(email)
|
||||
passwordFromStorage, err := d.client.Password.Query().
|
||||
Where(password.Email(email)).
|
||||
Only(context.TODO())
|
||||
if err != nil {
|
||||
return storage.Password{}, convertDBError("get password: %w", err)
|
||||
}
|
||||
return toStoragePassword(passwordFromStorage), nil
|
||||
}
|
||||
|
||||
// DeletePassword deletes a password from the database by email.
|
||||
func (d *Database) DeletePassword(email string) error {
|
||||
email = strings.ToLower(email)
|
||||
_, err := d.client.Password.Delete().
|
||||
Where(password.Email(email)).
|
||||
Exec(context.TODO())
|
||||
if err != nil {
|
||||
return convertDBError("delete password: %w", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// UpdatePassword changes a password by email using an updater function and saves it to the database.
|
||||
func (d *Database) UpdatePassword(email string, updater func(old storage.Password) (storage.Password, error)) error {
|
||||
email = strings.ToLower(email)
|
||||
|
||||
tx, err := d.client.Tx(context.TODO())
|
||||
if err != nil {
|
||||
return convertDBError("update connector tx: %w", err)
|
||||
}
|
||||
|
||||
passwordToUpdate, err := tx.Password.Query().
|
||||
Where(password.Email(email)).
|
||||
Only(context.TODO())
|
||||
if err != nil {
|
||||
return rollback(tx, "update password database: %w", err)
|
||||
}
|
||||
|
||||
newPassword, err := updater(toStoragePassword(passwordToUpdate))
|
||||
if err != nil {
|
||||
return rollback(tx, "update password updating: %w", err)
|
||||
}
|
||||
|
||||
_, err = tx.Password.Update().
|
||||
Where(password.Email(newPassword.Email)).
|
||||
SetEmail(newPassword.Email).
|
||||
SetHash(newPassword.Hash).
|
||||
SetUsername(newPassword.Username).
|
||||
SetUserID(newPassword.UserID).
|
||||
Save(context.TODO())
|
||||
if err != nil {
|
||||
return rollback(tx, "update password uploading: %w", err)
|
||||
}
|
||||
|
||||
if err = tx.Commit(); err != nil {
|
||||
return rollback(tx, "update password commit: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
109
storage/ent/client/refreshtoken.go
Normal file
109
storage/ent/client/refreshtoken.go
Normal file
@@ -0,0 +1,109 @@
|
||||
package client
|
||||
|
||||
import (
|
||||
"context"
|
||||
|
||||
"github.com/dexidp/dex/storage"
|
||||
)
|
||||
|
||||
// CreateRefresh saves provided refresh token into the database.
|
||||
func (d *Database) CreateRefresh(refresh storage.RefreshToken) error {
|
||||
_, err := d.client.RefreshToken.Create().
|
||||
SetID(refresh.ID).
|
||||
SetClientID(refresh.ClientID).
|
||||
SetScopes(refresh.Scopes).
|
||||
SetNonce(refresh.Nonce).
|
||||
SetClaimsUserID(refresh.Claims.UserID).
|
||||
SetClaimsEmail(refresh.Claims.Email).
|
||||
SetClaimsEmailVerified(refresh.Claims.EmailVerified).
|
||||
SetClaimsUsername(refresh.Claims.Username).
|
||||
SetClaimsPreferredUsername(refresh.Claims.PreferredUsername).
|
||||
SetClaimsGroups(refresh.Claims.Groups).
|
||||
SetConnectorID(refresh.ConnectorID).
|
||||
SetConnectorData(refresh.ConnectorData).
|
||||
SetToken(refresh.Token).
|
||||
// Save utc time into database because ent doesn't support comparing dates with different timezones
|
||||
SetLastUsed(refresh.LastUsed.UTC()).
|
||||
SetCreatedAt(refresh.CreatedAt.UTC()).
|
||||
Save(context.TODO())
|
||||
if err != nil {
|
||||
return convertDBError("create refresh token: %w", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// ListRefreshTokens extracts an array of refresh tokens from the database.
|
||||
func (d *Database) ListRefreshTokens() ([]storage.RefreshToken, error) {
|
||||
refreshTokens, err := d.client.RefreshToken.Query().All(context.TODO())
|
||||
if err != nil {
|
||||
return nil, convertDBError("list refresh tokens: %w", err)
|
||||
}
|
||||
|
||||
storageRefreshTokens := make([]storage.RefreshToken, 0, len(refreshTokens))
|
||||
for _, r := range refreshTokens {
|
||||
storageRefreshTokens = append(storageRefreshTokens, toStorageRefreshToken(r))
|
||||
}
|
||||
return storageRefreshTokens, nil
|
||||
}
|
||||
|
||||
// GetRefresh extracts a refresh token from the database by id.
|
||||
func (d *Database) GetRefresh(id string) (storage.RefreshToken, error) {
|
||||
refreshToken, err := d.client.RefreshToken.Get(context.TODO(), id)
|
||||
if err != nil {
|
||||
return storage.RefreshToken{}, convertDBError("get refresh token: %w", err)
|
||||
}
|
||||
return toStorageRefreshToken(refreshToken), nil
|
||||
}
|
||||
|
||||
// DeleteRefresh deletes a refresh token from the database by id.
|
||||
func (d *Database) DeleteRefresh(id string) error {
|
||||
err := d.client.RefreshToken.DeleteOneID(id).Exec(context.TODO())
|
||||
if err != nil {
|
||||
return convertDBError("delete refresh token: %w", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// UpdateRefreshToken changes a refresh token by id using an updater function and saves it to the database.
|
||||
func (d *Database) UpdateRefreshToken(id string, updater func(old storage.RefreshToken) (storage.RefreshToken, error)) error {
|
||||
tx, err := d.client.Tx(context.TODO())
|
||||
if err != nil {
|
||||
return convertDBError("update refresh token tx: %w", err)
|
||||
}
|
||||
|
||||
token, err := tx.RefreshToken.Get(context.TODO(), id)
|
||||
if err != nil {
|
||||
return rollback(tx, "update refresh token database: %w", err)
|
||||
}
|
||||
|
||||
newtToken, err := updater(toStorageRefreshToken(token))
|
||||
if err != nil {
|
||||
return rollback(tx, "update refresh token updating: %w", err)
|
||||
}
|
||||
|
||||
_, err = tx.RefreshToken.UpdateOneID(newtToken.ID).
|
||||
SetClientID(newtToken.ClientID).
|
||||
SetScopes(newtToken.Scopes).
|
||||
SetNonce(newtToken.Nonce).
|
||||
SetClaimsUserID(newtToken.Claims.UserID).
|
||||
SetClaimsEmail(newtToken.Claims.Email).
|
||||
SetClaimsEmailVerified(newtToken.Claims.EmailVerified).
|
||||
SetClaimsUsername(newtToken.Claims.Username).
|
||||
SetClaimsPreferredUsername(newtToken.Claims.PreferredUsername).
|
||||
SetClaimsGroups(newtToken.Claims.Groups).
|
||||
SetConnectorID(newtToken.ConnectorID).
|
||||
SetConnectorData(newtToken.ConnectorData).
|
||||
SetToken(newtToken.Token).
|
||||
// Save utc time into database because ent doesn't support comparing dates with different timezones
|
||||
SetLastUsed(newtToken.LastUsed.UTC()).
|
||||
SetCreatedAt(newtToken.CreatedAt.UTC()).
|
||||
Save(context.TODO())
|
||||
if err != nil {
|
||||
return rollback(tx, "update refresh token uploading: %w", err)
|
||||
}
|
||||
|
||||
if err = tx.Commit(); err != nil {
|
||||
return rollback(tx, "update refresh token commit: %w", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
167
storage/ent/client/types.go
Normal file
167
storage/ent/client/types.go
Normal file
@@ -0,0 +1,167 @@
|
||||
package client
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"strings"
|
||||
|
||||
"github.com/dexidp/dex/storage"
|
||||
"github.com/dexidp/dex/storage/ent/db"
|
||||
)
|
||||
|
||||
const keysRowID = "keys"
|
||||
|
||||
func toStorageKeys(keys *db.Keys) storage.Keys {
|
||||
return storage.Keys{
|
||||
SigningKey: &keys.SigningKey,
|
||||
SigningKeyPub: &keys.SigningKeyPub,
|
||||
VerificationKeys: keys.VerificationKeys,
|
||||
NextRotation: keys.NextRotation,
|
||||
}
|
||||
}
|
||||
|
||||
func toStorageAuthRequest(a *db.AuthRequest) storage.AuthRequest {
|
||||
return storage.AuthRequest{
|
||||
ID: a.ID,
|
||||
ClientID: a.ClientID,
|
||||
ResponseTypes: a.ResponseTypes,
|
||||
Scopes: a.Scopes,
|
||||
RedirectURI: a.RedirectURI,
|
||||
Nonce: a.Nonce,
|
||||
State: a.State,
|
||||
ForceApprovalPrompt: a.ForceApprovalPrompt,
|
||||
LoggedIn: a.LoggedIn,
|
||||
ConnectorID: a.ConnectorID,
|
||||
ConnectorData: *a.ConnectorData,
|
||||
Expiry: a.Expiry,
|
||||
Claims: storage.Claims{
|
||||
UserID: a.ClaimsUserID,
|
||||
Username: a.ClaimsUsername,
|
||||
PreferredUsername: a.ClaimsPreferredUsername,
|
||||
Email: a.ClaimsEmail,
|
||||
EmailVerified: a.ClaimsEmailVerified,
|
||||
Groups: a.ClaimsGroups,
|
||||
},
|
||||
PKCE: storage.PKCE{
|
||||
CodeChallenge: a.CodeChallenge,
|
||||
CodeChallengeMethod: a.CodeChallengeMethod,
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
func toStorageAuthCode(a *db.AuthCode) storage.AuthCode {
|
||||
return storage.AuthCode{
|
||||
ID: a.ID,
|
||||
ClientID: a.ClientID,
|
||||
Scopes: a.Scopes,
|
||||
RedirectURI: a.RedirectURI,
|
||||
Nonce: a.Nonce,
|
||||
ConnectorID: a.ConnectorID,
|
||||
ConnectorData: *a.ConnectorData,
|
||||
Expiry: a.Expiry,
|
||||
Claims: storage.Claims{
|
||||
UserID: a.ClaimsUserID,
|
||||
Username: a.ClaimsUsername,
|
||||
PreferredUsername: a.ClaimsPreferredUsername,
|
||||
Email: a.ClaimsEmail,
|
||||
EmailVerified: a.ClaimsEmailVerified,
|
||||
Groups: a.ClaimsGroups,
|
||||
},
|
||||
PKCE: storage.PKCE{
|
||||
CodeChallenge: a.CodeChallenge,
|
||||
CodeChallengeMethod: a.CodeChallengeMethod,
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
func toStorageClient(c *db.OAuth2Client) storage.Client {
|
||||
return storage.Client{
|
||||
ID: c.ID,
|
||||
Secret: c.Secret,
|
||||
RedirectURIs: c.RedirectUris,
|
||||
TrustedPeers: c.TrustedPeers,
|
||||
Public: c.Public,
|
||||
Name: c.Name,
|
||||
LogoURL: c.LogoURL,
|
||||
}
|
||||
}
|
||||
|
||||
func toStorageConnector(c *db.Connector) storage.Connector {
|
||||
return storage.Connector{
|
||||
ID: c.ID,
|
||||
Type: c.Type,
|
||||
Name: c.Name,
|
||||
Config: c.Config,
|
||||
}
|
||||
}
|
||||
|
||||
func toStorageOfflineSession(o *db.OfflineSession) storage.OfflineSessions {
|
||||
s := storage.OfflineSessions{
|
||||
UserID: o.UserID,
|
||||
ConnID: o.ConnID,
|
||||
ConnectorData: *o.ConnectorData,
|
||||
}
|
||||
|
||||
if o.Refresh != nil {
|
||||
if err := json.Unmarshal(o.Refresh, &s.Refresh); err != nil {
|
||||
// Correctness of json structure if guaranteed on uploading
|
||||
panic(err)
|
||||
}
|
||||
} else {
|
||||
// Server code assumes this will be non-nil.
|
||||
s.Refresh = make(map[string]*storage.RefreshTokenRef)
|
||||
}
|
||||
return s
|
||||
}
|
||||
|
||||
func toStorageRefreshToken(r *db.RefreshToken) storage.RefreshToken {
|
||||
return storage.RefreshToken{
|
||||
ID: r.ID,
|
||||
Token: r.Token,
|
||||
CreatedAt: r.CreatedAt,
|
||||
LastUsed: r.LastUsed,
|
||||
ClientID: r.ClientID,
|
||||
ConnectorID: r.ConnectorID,
|
||||
ConnectorData: *r.ConnectorData,
|
||||
Scopes: r.Scopes,
|
||||
Nonce: r.Nonce,
|
||||
Claims: storage.Claims{
|
||||
UserID: r.ClaimsUserID,
|
||||
Username: r.ClaimsUsername,
|
||||
PreferredUsername: r.ClaimsPreferredUsername,
|
||||
Email: r.ClaimsEmail,
|
||||
EmailVerified: r.ClaimsEmailVerified,
|
||||
Groups: r.ClaimsGroups,
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
func toStoragePassword(p *db.Password) storage.Password {
|
||||
return storage.Password{
|
||||
Email: p.Email,
|
||||
Hash: p.Hash,
|
||||
Username: p.Username,
|
||||
UserID: p.UserID,
|
||||
}
|
||||
}
|
||||
|
||||
func toStorageDeviceRequest(r *db.DeviceRequest) storage.DeviceRequest {
|
||||
return storage.DeviceRequest{
|
||||
UserCode: strings.ToUpper(r.UserCode),
|
||||
DeviceCode: r.DeviceCode,
|
||||
ClientID: r.ClientID,
|
||||
ClientSecret: r.ClientSecret,
|
||||
Scopes: r.Scopes,
|
||||
Expiry: r.Expiry,
|
||||
}
|
||||
}
|
||||
|
||||
func toStorageDeviceToken(t *db.DeviceToken) storage.DeviceToken {
|
||||
return storage.DeviceToken{
|
||||
DeviceCode: t.DeviceCode,
|
||||
Status: t.Status,
|
||||
Token: string(*t.Token),
|
||||
Expiry: t.Expiry,
|
||||
LastRequestTime: t.LastRequest,
|
||||
PollIntervalSeconds: t.PollInterval,
|
||||
}
|
||||
}
|
||||
44
storage/ent/client/utils.go
Normal file
44
storage/ent/client/utils.go
Normal file
@@ -0,0 +1,44 @@
|
||||
package client
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"hash"
|
||||
|
||||
"github.com/pkg/errors"
|
||||
|
||||
"github.com/dexidp/dex/storage"
|
||||
"github.com/dexidp/dex/storage/ent/db"
|
||||
)
|
||||
|
||||
func rollback(tx *db.Tx, t string, err error) error {
|
||||
rerr := tx.Rollback()
|
||||
err = convertDBError(t, err)
|
||||
|
||||
if rerr == nil {
|
||||
return err
|
||||
}
|
||||
return errors.Wrapf(err, "rolling back transaction: %v", rerr)
|
||||
}
|
||||
|
||||
func convertDBError(t string, err error) error {
|
||||
if db.IsNotFound(err) {
|
||||
return storage.ErrNotFound
|
||||
}
|
||||
|
||||
if db.IsConstraintError(err) {
|
||||
return storage.ErrAlreadyExists
|
||||
}
|
||||
|
||||
return fmt.Errorf(t, err)
|
||||
}
|
||||
|
||||
// compose hashed id from user and connection id to use it as primary key
|
||||
// ent doesn't support multi-key primary yet
|
||||
// https://github.com/facebook/ent/issues/400
|
||||
func offlineSessionID(userID string, connID string, hasher func() hash.Hash) string {
|
||||
h := hasher()
|
||||
|
||||
h.Write([]byte(userID))
|
||||
h.Write([]byte(connID))
|
||||
return fmt.Sprintf("%x", h.Sum(nil))
|
||||
}
|
||||
Reference in New Issue
Block a user