feat: Add ent-based sqlite3 storage

Signed-off-by: m.nabokikh <maksim.nabokikh@flant.com>
This commit is contained in:
m.nabokikh 2020-12-31 02:07:32 +04:00
parent 674631c9ab
commit 11859166d0
31 changed files with 1878 additions and 4 deletions

View File

@ -26,7 +26,10 @@ PROTOC_VERSION = 3.15.6
PROTOC_GEN_GO_VERSION = 1.26.0
PROTOC_GEN_GO_GRPC_VERSION = 1.1.0
build: bin/dex
generate:
@go generate $(REPO_PATH)/storage/ent/
build: generate bin/dex
bin/dex:
@mkdir -p bin/
@ -42,7 +45,7 @@ bin/example-app:
@mkdir -p bin/
@cd examples/ && go install -v -ldflags $(LD_FLAGS) $(REPO_PATH)/examples/example-app
.PHONY: release-binary
.PHONY: generate release-binary
release-binary:
@go build -o /go/bin/dex -v -ldflags $(LD_FLAGS) $(REPO_PATH)/cmd/dex

View File

@ -13,6 +13,7 @@ import (
"github.com/dexidp/dex/pkg/log"
"github.com/dexidp/dex/server"
"github.com/dexidp/dex/storage"
"github.com/dexidp/dex/storage/ent"
"github.com/dexidp/dex/storage/etcd"
"github.com/dexidp/dex/storage/kubernetes"
"github.com/dexidp/dex/storage/memory"
@ -173,13 +174,32 @@ type StorageConfig interface {
Open(logger log.Logger) (storage.Storage, error)
}
var (
_ StorageConfig = (*etcd.Etcd)(nil)
_ StorageConfig = (*kubernetes.Config)(nil)
_ StorageConfig = (*memory.Config)(nil)
_ StorageConfig = (*sql.SQLite3)(nil)
_ StorageConfig = (*sql.Postgres)(nil)
_ StorageConfig = (*sql.MySQL)(nil)
_ StorageConfig = (*ent.SQLite3)(nil)
)
func getORMBasedSQLiteStorage() StorageConfig {
switch os.Getenv("DEX_ENT_ENABLED") {
case "true", "yes":
return new(ent.SQLite3)
default:
return new(sql.SQLite3)
}
}
var storages = map[string]func() StorageConfig{
"etcd": func() StorageConfig { return new(etcd.Etcd) },
"kubernetes": func() StorageConfig { return new(kubernetes.Config) },
"memory": func() StorageConfig { return new(memory.Config) },
"sqlite3": func() StorageConfig { return new(sql.SQLite3) },
"postgres": func() StorageConfig { return new(sql.Postgres) },
"mysql": func() StorageConfig { return new(sql.MySQL) },
"sqlite3": getORMBasedSQLiteStorage,
}
// isExpandEnvEnabled returns if os.ExpandEnv should be used for each storage and connector config.

3
go.mod
View File

@ -7,7 +7,8 @@ require (
github.com/beevik/etree v1.1.0
github.com/coreos/go-oidc/v3 v3.0.0
github.com/dexidp/dex/api/v2 v2.0.0
github.com/felixge/httpsnoop v1.0.2
github.com/facebook/ent v0.5.3
github.com/felixge/httpsnoop v1.0.1
github.com/ghodss/yaml v1.0.0
github.com/go-ldap/ldap/v3 v3.3.0
github.com/go-sql-driver/mysql v1.6.0

2
go.sum
View File

@ -128,6 +128,8 @@ github.com/envoyproxy/go-control-plane v0.9.7/go.mod h1:cwu0lG7PUMfa9snN8LXBig5y
github.com/envoyproxy/go-control-plane v0.9.9-0.20201210154907-fd9021fe5dad/go.mod h1:cXg6YxExXjJnVBQHBLXeUAgxn2UodCpnH306RInaBQk=
github.com/envoyproxy/go-control-plane v0.9.9-0.20210217033140-668b12f5399d/go.mod h1:cXg6YxExXjJnVBQHBLXeUAgxn2UodCpnH306RInaBQk=
github.com/envoyproxy/protoc-gen-validate v0.1.0/go.mod h1:iSmxcyjqTsJpI2R4NaDN7+kN2VEUnK/pcBlmesArF7c=
github.com/facebook/ent v0.5.3 h1:YT3Sl28n7gGGOkQeYgeJsZmizJ1Iiy7psgkOtEk0aq4=
github.com/facebook/ent v0.5.3/go.mod h1:tlWP+qCd3x2EeO7B/EqlJQ4dWu/2IeYFhP/szzDKAi8=
github.com/fatih/color v1.7.0/go.mod h1:Zm6kSWBoL9eyXnKyktHP6abPY2pDugNf5KwzbycvMj4=
github.com/felixge/httpsnoop v1.0.1/go.mod h1:m8KPJKqk1gH5J9DgRY2ASl2lWCfGKXixSwevea8zH2U=
github.com/felixge/httpsnoop v1.0.2 h1:+nS9g82KMXccJ/wp0zyRW9ZBHFETmMGtkk+2CTTrW4o=

View File

@ -0,0 +1,52 @@
package client
import (
"context"
"github.com/dexidp/dex/storage"
)
// CreateAuthCode saves provided auth code into the database.
func (d *Database) CreateAuthCode(code storage.AuthCode) error {
_, err := d.client.AuthCode.Create().
SetID(code.ID).
SetClientID(code.ClientID).
SetScopes(code.Scopes).
SetRedirectURI(code.RedirectURI).
SetNonce(code.Nonce).
SetClaimsUserID(code.Claims.UserID).
SetClaimsEmail(code.Claims.Email).
SetClaimsEmailVerified(code.Claims.EmailVerified).
SetClaimsUsername(code.Claims.Username).
SetClaimsPreferredUsername(code.Claims.PreferredUsername).
SetClaimsGroups(code.Claims.Groups).
SetCodeChallenge(code.PKCE.CodeChallenge).
SetCodeChallengeMethod(code.PKCE.CodeChallengeMethod).
// Save utc time into database because ent doesn't support comparing dates with different timezones
SetExpiry(code.Expiry.UTC()).
SetConnectorID(code.ConnectorID).
SetConnectorData(code.ConnectorData).
Save(context.TODO())
if err != nil {
return convertDBError("create auth code: %w", err)
}
return nil
}
// GetAuthCode extracts an auth code from the database by id.
func (d *Database) GetAuthCode(id string) (storage.AuthCode, error) {
authCode, err := d.client.AuthCode.Get(context.TODO(), id)
if err != nil {
return storage.AuthCode{}, convertDBError("get auth code: %w", err)
}
return toStorageAuthCode(authCode), nil
}
// DeleteAuthCode deletes an auth code from the database by id.
func (d *Database) DeleteAuthCode(id string) error {
err := d.client.AuthCode.DeleteOneID(id).Exec(context.TODO())
if err != nil {
return convertDBError("delete auth code: %w", err)
}
return nil
}

View File

@ -0,0 +1,107 @@
package client
import (
"context"
"fmt"
"github.com/dexidp/dex/storage"
)
// CreateAuthRequest saves provided auth request into the database.
func (d *Database) CreateAuthRequest(authRequest storage.AuthRequest) error {
_, err := d.client.AuthRequest.Create().
SetID(authRequest.ID).
SetClientID(authRequest.ClientID).
SetScopes(authRequest.Scopes).
SetResponseTypes(authRequest.ResponseTypes).
SetRedirectURI(authRequest.RedirectURI).
SetState(authRequest.State).
SetNonce(authRequest.Nonce).
SetForceApprovalPrompt(authRequest.ForceApprovalPrompt).
SetLoggedIn(authRequest.LoggedIn).
SetClaimsUserID(authRequest.Claims.UserID).
SetClaimsEmail(authRequest.Claims.Email).
SetClaimsEmailVerified(authRequest.Claims.EmailVerified).
SetClaimsUsername(authRequest.Claims.Username).
SetClaimsPreferredUsername(authRequest.Claims.PreferredUsername).
SetClaimsGroups(authRequest.Claims.Groups).
SetCodeChallenge(authRequest.PKCE.CodeChallenge).
SetCodeChallengeMethod(authRequest.PKCE.CodeChallengeMethod).
// Save utc time into database because ent doesn't support comparing dates with different timezones
SetExpiry(authRequest.Expiry.UTC()).
SetConnectorID(authRequest.ConnectorID).
SetConnectorData(authRequest.ConnectorData).
Save(context.TODO())
if err != nil {
return convertDBError("create auth request: %w", err)
}
return nil
}
// GetAuthRequest extracts an auth request from the database by id.
func (d *Database) GetAuthRequest(id string) (storage.AuthRequest, error) {
authRequest, err := d.client.AuthRequest.Get(context.TODO(), id)
if err != nil {
return storage.AuthRequest{}, convertDBError("get auth request: %w", err)
}
return toStorageAuthRequest(authRequest), nil
}
// DeleteAuthRequest deletes an auth request from the database by id.
func (d *Database) DeleteAuthRequest(id string) error {
err := d.client.AuthRequest.DeleteOneID(id).Exec(context.TODO())
if err != nil {
return convertDBError("delete auth request: %w", err)
}
return nil
}
// UpdateAuthRequest changes an auth request by id using an updater function and saves it to the database.
func (d *Database) UpdateAuthRequest(id string, updater func(old storage.AuthRequest) (storage.AuthRequest, error)) error {
tx, err := d.client.Tx(context.TODO())
if err != nil {
return fmt.Errorf("update auth request tx: %w", err)
}
authRequest, err := tx.AuthRequest.Get(context.TODO(), id)
if err != nil {
return rollback(tx, "update auth request database: %w", err)
}
newAuthRequest, err := updater(toStorageAuthRequest(authRequest))
if err != nil {
return rollback(tx, "update auth request updating: %w", err)
}
_, err = tx.AuthRequest.UpdateOneID(newAuthRequest.ID).
SetClientID(newAuthRequest.ClientID).
SetScopes(newAuthRequest.Scopes).
SetResponseTypes(newAuthRequest.ResponseTypes).
SetRedirectURI(newAuthRequest.RedirectURI).
SetState(newAuthRequest.State).
SetNonce(newAuthRequest.Nonce).
SetForceApprovalPrompt(newAuthRequest.ForceApprovalPrompt).
SetLoggedIn(newAuthRequest.LoggedIn).
SetClaimsUserID(newAuthRequest.Claims.UserID).
SetClaimsEmail(newAuthRequest.Claims.Email).
SetClaimsEmailVerified(newAuthRequest.Claims.EmailVerified).
SetClaimsUsername(newAuthRequest.Claims.Username).
SetClaimsPreferredUsername(newAuthRequest.Claims.PreferredUsername).
SetClaimsGroups(newAuthRequest.Claims.Groups).
SetCodeChallenge(newAuthRequest.PKCE.CodeChallenge).
SetCodeChallengeMethod(newAuthRequest.PKCE.CodeChallengeMethod).
// Save utc time into database because ent doesn't support comparing dates with different timezones
SetExpiry(newAuthRequest.Expiry.UTC()).
SetConnectorID(newAuthRequest.ConnectorID).
SetConnectorData(newAuthRequest.ConnectorData).
Save(context.TODO())
if err != nil {
return rollback(tx, "update auth request uploading: %w", err)
}
if err = tx.Commit(); err != nil {
return rollback(tx, "update auth request commit: %w", err)
}
return nil
}

View File

@ -0,0 +1,92 @@
package client
import (
"context"
"github.com/dexidp/dex/storage"
)
// CreateClient saves provided oauth2 client settings into the database.
func (d *Database) CreateClient(client storage.Client) error {
_, err := d.client.OAuth2Client.Create().
SetID(client.ID).
SetName(client.Name).
SetSecret(client.Secret).
SetPublic(client.Public).
SetLogoURL(client.LogoURL).
SetRedirectUris(client.RedirectURIs).
SetTrustedPeers(client.TrustedPeers).
Save(context.TODO())
if err != nil {
return convertDBError("create oauth2 client: %w", err)
}
return nil
}
// ListClients extracts an array of oauth2 clients from the database.
func (d *Database) ListClients() ([]storage.Client, error) {
clients, err := d.client.OAuth2Client.Query().All(context.TODO())
if err != nil {
return nil, convertDBError("list clients: %w", err)
}
storageClients := make([]storage.Client, 0, len(clients))
for _, c := range clients {
storageClients = append(storageClients, toStorageClient(c))
}
return storageClients, nil
}
// GetClient extracts an oauth2 client from the database by id.
func (d *Database) GetClient(id string) (storage.Client, error) {
client, err := d.client.OAuth2Client.Get(context.TODO(), id)
if err != nil {
return storage.Client{}, convertDBError("get client: %w", err)
}
return toStorageClient(client), nil
}
// DeleteClient deletes an oauth2 client from the database by id.
func (d *Database) DeleteClient(id string) error {
err := d.client.OAuth2Client.DeleteOneID(id).Exec(context.TODO())
if err != nil {
return convertDBError("delete client: %w", err)
}
return nil
}
// UpdateClient changes an oauth2 client by id using an updater function and saves it to the database.
func (d *Database) UpdateClient(id string, updater func(old storage.Client) (storage.Client, error)) error {
tx, err := d.client.Tx(context.TODO())
if err != nil {
return convertDBError("update client tx: %w", err)
}
client, err := tx.OAuth2Client.Get(context.TODO(), id)
if err != nil {
return rollback(tx, "update client database: %w", err)
}
newClient, err := updater(toStorageClient(client))
if err != nil {
return rollback(tx, "update client updating: %w", err)
}
_, err = tx.OAuth2Client.UpdateOneID(newClient.ID).
SetName(newClient.Name).
SetSecret(newClient.Secret).
SetPublic(newClient.Public).
SetLogoURL(newClient.LogoURL).
SetRedirectUris(newClient.RedirectURIs).
SetTrustedPeers(newClient.TrustedPeers).
Save(context.TODO())
if err != nil {
return rollback(tx, "update client uploading: %w", err)
}
if err = tx.Commit(); err != nil {
return rollback(tx, "update auth request commit: %w", err)
}
return nil
}

View File

@ -0,0 +1,88 @@
package client
import (
"context"
"github.com/dexidp/dex/storage"
)
// CreateConnector saves a connector into the database.
func (d *Database) CreateConnector(connector storage.Connector) error {
_, err := d.client.Connector.Create().
SetID(connector.ID).
SetName(connector.Name).
SetType(connector.Type).
SetResourceVersion(connector.ResourceVersion).
SetConfig(connector.Config).
Save(context.TODO())
if err != nil {
return convertDBError("create connector: %w", err)
}
return nil
}
// ListConnectors extracts an array of connectors from the database.
func (d *Database) ListConnectors() ([]storage.Connector, error) {
connectors, err := d.client.Connector.Query().All(context.TODO())
if err != nil {
return nil, convertDBError("list connectors: %w", err)
}
storageConnectors := make([]storage.Connector, 0, len(connectors))
for _, c := range connectors {
storageConnectors = append(storageConnectors, toStorageConnector(c))
}
return storageConnectors, nil
}
// GetConnector extracts a connector from the database by id.
func (d *Database) GetConnector(id string) (storage.Connector, error) {
connector, err := d.client.Connector.Get(context.TODO(), id)
if err != nil {
return storage.Connector{}, convertDBError("get connector: %w", err)
}
return toStorageConnector(connector), nil
}
// DeleteConnector deletes a connector from the database by id.
func (d *Database) DeleteConnector(id string) error {
err := d.client.Connector.DeleteOneID(id).Exec(context.TODO())
if err != nil {
return convertDBError("delete connector: %w", err)
}
return nil
}
// UpdateConnector changes a connector by id using an updater function and saves it to the database.
func (d *Database) UpdateConnector(id string, updater func(old storage.Connector) (storage.Connector, error)) error {
tx, err := d.client.Tx(context.TODO())
if err != nil {
return convertDBError("update connector tx: %w", err)
}
connector, err := tx.Connector.Get(context.TODO(), id)
if err != nil {
return rollback(tx, "update connector database: %w", err)
}
newConnector, err := updater(toStorageConnector(connector))
if err != nil {
return rollback(tx, "update connector updating: %w", err)
}
_, err = tx.Connector.UpdateOneID(newConnector.ID).
SetName(newConnector.Name).
SetType(newConnector.Type).
SetResourceVersion(newConnector.ResourceVersion).
SetConfig(newConnector.Config).
Save(context.TODO())
if err != nil {
return rollback(tx, "update connector uploading: %w", err)
}
if err = tx.Commit(); err != nil {
return rollback(tx, "update connector commit: %w", err)
}
return nil
}

View File

@ -0,0 +1,36 @@
package client
import (
"context"
"github.com/dexidp/dex/storage"
"github.com/dexidp/dex/storage/ent/db/devicerequest"
)
// CreateDeviceRequest saves provided device request into the database.
func (d *Database) CreateDeviceRequest(request storage.DeviceRequest) error {
_, err := d.client.DeviceRequest.Create().
SetClientID(request.ClientID).
SetClientSecret(request.ClientSecret).
SetScopes(request.Scopes).
SetUserCode(request.UserCode).
SetDeviceCode(request.DeviceCode).
// Save utc time into database because ent doesn't support comparing dates with different timezones
SetExpiry(request.Expiry.UTC()).
Save(context.TODO())
if err != nil {
return convertDBError("create device request: %w", err)
}
return nil
}
// GetDeviceRequest extracts a device request from the database by user code.
func (d *Database) GetDeviceRequest(userCode string) (storage.DeviceRequest, error) {
deviceRequest, err := d.client.DeviceRequest.Query().
Where(devicerequest.UserCode(userCode)).
Only(context.TODO())
if err != nil {
return storage.DeviceRequest{}, convertDBError("get device request: %w", err)
}
return toStorageDeviceRequest(deviceRequest), nil
}

View File

@ -0,0 +1,76 @@
package client
import (
"context"
"github.com/dexidp/dex/storage"
"github.com/dexidp/dex/storage/ent/db/devicetoken"
)
// CreateDeviceToken saves provided token into the database.
func (d *Database) CreateDeviceToken(token storage.DeviceToken) error {
_, err := d.client.DeviceToken.Create().
SetDeviceCode(token.DeviceCode).
SetToken([]byte(token.Token)).
SetPollInterval(token.PollIntervalSeconds).
// Save utc time into database because ent doesn't support comparing dates with different timezones
SetExpiry(token.Expiry.UTC()).
SetLastRequest(token.LastRequestTime.UTC()).
SetStatus(token.Status).
Save(context.TODO())
if err != nil {
return convertDBError("create device token: %w", err)
}
return nil
}
// GetDeviceToken extracts a token from the database by device code.
func (d *Database) GetDeviceToken(deviceCode string) (storage.DeviceToken, error) {
deviceToken, err := d.client.DeviceToken.Query().
Where(devicetoken.DeviceCode(deviceCode)).
Only(context.TODO())
if err != nil {
return storage.DeviceToken{}, convertDBError("get device token: %w", err)
}
return toStorageDeviceToken(deviceToken), nil
}
// UpdateDeviceToken changes a token by device code using an updater function and saves it to the database.
func (d *Database) UpdateDeviceToken(deviceCode string, updater func(old storage.DeviceToken) (storage.DeviceToken, error)) error {
tx, err := d.client.Tx(context.TODO())
if err != nil {
return convertDBError("update device token tx: %w", err)
}
token, err := tx.DeviceToken.Query().
Where(devicetoken.DeviceCode(deviceCode)).
Only(context.TODO())
if err != nil {
return rollback(tx, "update device token database: %w", err)
}
newToken, err := updater(toStorageDeviceToken(token))
if err != nil {
return rollback(tx, "update device token updating: %w", err)
}
_, err = tx.DeviceToken.Update().
Where(devicetoken.DeviceCode(newToken.DeviceCode)).
SetDeviceCode(newToken.DeviceCode).
SetToken([]byte(newToken.Token)).
SetPollInterval(newToken.PollIntervalSeconds).
// Save utc time into database because ent doesn't support comparing dates with different timezones
SetExpiry(newToken.Expiry.UTC()).
SetLastRequest(newToken.LastRequestTime.UTC()).
SetStatus(newToken.Status).
Save(context.TODO())
if err != nil {
return rollback(tx, "update device token uploading: %w", err)
}
if err = tx.Commit(); err != nil {
return rollback(tx, "update device token commit: %w", err)
}
return nil
}

View File

@ -0,0 +1,81 @@
package client
import (
"context"
"errors"
"github.com/dexidp/dex/storage"
"github.com/dexidp/dex/storage/ent/db"
)
func getKeys(client *db.KeysClient) (storage.Keys, error) {
rawKeys, err := client.Get(context.TODO(), keysRowID)
if err != nil {
return storage.Keys{}, convertDBError("get keys: %w", err)
}
return toStorageKeys(rawKeys), nil
}
// GetKeys returns signing keys, public keys and verification keys from the database.
func (d *Database) GetKeys() (storage.Keys, error) {
return getKeys(d.client.Keys)
}
// UpdateKeys rotates keys using updater function.
func (d *Database) UpdateKeys(updater func(old storage.Keys) (storage.Keys, error)) error {
firstUpdate := false
tx, err := d.client.Tx(context.TODO())
if err != nil {
return convertDBError("update keys tx: %w", err)
}
storageKeys, err := getKeys(tx.Keys)
if err != nil {
if !errors.Is(err, storage.ErrNotFound) {
return rollback(tx, "update keys get: %w", err)
}
firstUpdate = true
}
newKeys, err := updater(storageKeys)
if err != nil {
return rollback(tx, "update keys updating: %w", err)
}
// ent doesn't have an upsert support yet
// https://github.com/facebook/ent/issues/139
if firstUpdate {
_, err = tx.Keys.Create().
SetID(keysRowID).
SetNextRotation(newKeys.NextRotation).
SetSigningKey(*newKeys.SigningKey).
SetSigningKeyPub(*newKeys.SigningKeyPub).
SetVerificationKeys(newKeys.VerificationKeys).
Save(context.TODO())
if err != nil {
return rollback(tx, "create keys: %w", err)
}
if err = tx.Commit(); err != nil {
return rollback(tx, "update keys commit: %w", err)
}
return nil
}
err = tx.Keys.UpdateOneID(keysRowID).
SetNextRotation(newKeys.NextRotation.UTC()).
SetSigningKey(*newKeys.SigningKey).
SetSigningKeyPub(*newKeys.SigningKeyPub).
SetVerificationKeys(newKeys.VerificationKeys).
Exec(context.TODO())
if err != nil {
return rollback(tx, "update keys uploading: %w", err)
}
if err = tx.Commit(); err != nil {
return rollback(tx, "update keys commit: %w", err)
}
return nil
}

View File

@ -0,0 +1,95 @@
package client
import (
"context"
"hash"
"time"
"github.com/dexidp/dex/storage"
"github.com/dexidp/dex/storage/ent/db"
"github.com/dexidp/dex/storage/ent/db/authcode"
"github.com/dexidp/dex/storage/ent/db/authrequest"
"github.com/dexidp/dex/storage/ent/db/devicerequest"
"github.com/dexidp/dex/storage/ent/db/devicetoken"
"github.com/dexidp/dex/storage/ent/db/migrate"
)
var _ storage.Storage = (*Database)(nil)
type Database struct {
client *db.Client
hasher func() hash.Hash
}
// NewDatabase returns new database client with set options.
func NewDatabase(opts ...func(*Database)) *Database {
database := &Database{}
for _, f := range opts {
f(database)
}
return database
}
// WithClient sets client option of a Database object.
func WithClient(c *db.Client) func(*Database) {
return func(s *Database) {
s.client = c
}
}
// WithHasher sets client option of a Database object.
func WithHasher(h func() hash.Hash) func(*Database) {
return func(s *Database) {
s.hasher = h
}
}
// Schema exposes migration schema to perform migrations.
func (d *Database) Schema() *migrate.Schema {
return d.client.Schema
}
// Close calls the corresponding method of the ent database client.
func (d *Database) Close() error {
return d.client.Close()
}
// GarbageCollect removes expired entities from the database.
func (d *Database) GarbageCollect(now time.Time) (storage.GCResult, error) {
result := storage.GCResult{}
utcNow := now.UTC()
q, err := d.client.AuthRequest.Delete().
Where(authrequest.ExpiryLT(utcNow)).
Exec(context.TODO())
if err != nil {
return result, convertDBError("gc auth request: %w", err)
}
result.AuthRequests = int64(q)
q, err = d.client.AuthCode.Delete().
Where(authcode.ExpiryLT(utcNow)).
Exec(context.TODO())
if err != nil {
return result, convertDBError("gc auth code: %w", err)
}
result.AuthCodes = int64(q)
q, err = d.client.DeviceRequest.Delete().
Where(devicerequest.ExpiryLT(utcNow)).
Exec(context.TODO())
if err != nil {
return result, convertDBError("gc device request: %w", err)
}
result.DeviceRequests = int64(q)
q, err = d.client.DeviceToken.Delete().
Where(devicetoken.ExpiryLT(utcNow)).
Exec(context.TODO())
if err != nil {
return result, convertDBError("gc device token: %w", err)
}
result.DeviceTokens = int64(q)
return result, err
}

View File

@ -0,0 +1,93 @@
package client
import (
"context"
"encoding/json"
"fmt"
"github.com/dexidp/dex/storage"
)
// CreateOfflineSessions saves provided offline session into the database.
func (d *Database) CreateOfflineSessions(session storage.OfflineSessions) error {
encodedRefresh, err := json.Marshal(session.Refresh)
if err != nil {
return fmt.Errorf("encode refresh offline session: %w", err)
}
id := offlineSessionID(session.UserID, session.ConnID, d.hasher)
_, err = d.client.OfflineSession.Create().
SetID(id).
SetUserID(session.UserID).
SetConnID(session.ConnID).
SetConnectorData(session.ConnectorData).
SetRefresh(encodedRefresh).
Save(context.TODO())
if err != nil {
return convertDBError("create offline session: %w", err)
}
return nil
}
// GetOfflineSessions extracts an offline session from the database by user id and connector id.
func (d *Database) GetOfflineSessions(userID, connID string) (storage.OfflineSessions, error) {
id := offlineSessionID(userID, connID, d.hasher)
offlineSession, err := d.client.OfflineSession.Get(context.TODO(), id)
if err != nil {
return storage.OfflineSessions{}, convertDBError("get offline session: %w", err)
}
return toStorageOfflineSession(offlineSession), nil
}
// DeleteOfflineSessions deletes an offline session from the database by user id and connector id.
func (d *Database) DeleteOfflineSessions(userID, connID string) error {
id := offlineSessionID(userID, connID, d.hasher)
err := d.client.OfflineSession.DeleteOneID(id).Exec(context.TODO())
if err != nil {
return convertDBError("delete offline session: %w", err)
}
return nil
}
// UpdatePassword changes an offline session by user id and connector id using an updater function.
func (d *Database) UpdateOfflineSessions(userID string, connID string, updater func(s storage.OfflineSessions) (storage.OfflineSessions, error)) error {
id := offlineSessionID(userID, connID, d.hasher)
tx, err := d.client.Tx(context.TODO())
if err != nil {
return convertDBError("update offline session tx: %w", err)
}
offlineSession, err := tx.OfflineSession.Get(context.TODO(), id)
if err != nil {
return rollback(tx, "update offline session database: %w", err)
}
newOfflineSession, err := updater(toStorageOfflineSession(offlineSession))
if err != nil {
return rollback(tx, "update offline session updating: %w", err)
}
encodedRefresh, err := json.Marshal(newOfflineSession.Refresh)
if err != nil {
return rollback(tx, "encode refresh offline session: %w", err)
}
_, err = tx.OfflineSession.UpdateOneID(id).
SetUserID(newOfflineSession.UserID).
SetConnID(newOfflineSession.ConnID).
SetConnectorData(newOfflineSession.ConnectorData).
SetRefresh(encodedRefresh).
Save(context.TODO())
if err != nil {
return rollback(tx, "update offline session uploading: %w", err)
}
if err = tx.Commit(); err != nil {
return rollback(tx, "update password commit: %w", err)
}
return nil
}

View File

@ -0,0 +1,100 @@
package client
import (
"context"
"strings"
"github.com/dexidp/dex/storage"
"github.com/dexidp/dex/storage/ent/db/password"
)
// CreatePassword saves provided password into the database.
func (d *Database) CreatePassword(password storage.Password) error {
_, err := d.client.Password.Create().
SetEmail(password.Email).
SetHash(password.Hash).
SetUsername(password.Username).
SetUserID(password.UserID).
Save(context.TODO())
if err != nil {
return convertDBError("create password: %w", err)
}
return nil
}
// ListPasswords extracts an array of passwords from the database.
func (d *Database) ListPasswords() ([]storage.Password, error) {
passwords, err := d.client.Password.Query().All(context.TODO())
if err != nil {
return nil, convertDBError("list passwords: %w", err)
}
storagePasswords := make([]storage.Password, 0, len(passwords))
for _, p := range passwords {
storagePasswords = append(storagePasswords, toStoragePassword(p))
}
return storagePasswords, nil
}
// GetPassword extracts a password from the database by email.
func (d *Database) GetPassword(email string) (storage.Password, error) {
email = strings.ToLower(email)
passwordFromStorage, err := d.client.Password.Query().
Where(password.Email(email)).
Only(context.TODO())
if err != nil {
return storage.Password{}, convertDBError("get password: %w", err)
}
return toStoragePassword(passwordFromStorage), nil
}
// DeletePassword deletes a password from the database by email.
func (d *Database) DeletePassword(email string) error {
email = strings.ToLower(email)
_, err := d.client.Password.Delete().
Where(password.Email(email)).
Exec(context.TODO())
if err != nil {
return convertDBError("delete password: %w", err)
}
return nil
}
// UpdatePassword changes a password by email using an updater function and saves it to the database.
func (d *Database) UpdatePassword(email string, updater func(old storage.Password) (storage.Password, error)) error {
email = strings.ToLower(email)
tx, err := d.client.Tx(context.TODO())
if err != nil {
return convertDBError("update connector tx: %w", err)
}
passwordToUpdate, err := tx.Password.Query().
Where(password.Email(email)).
Only(context.TODO())
if err != nil {
return rollback(tx, "update password database: %w", err)
}
newPassword, err := updater(toStoragePassword(passwordToUpdate))
if err != nil {
return rollback(tx, "update password updating: %w", err)
}
_, err = tx.Password.Update().
Where(password.Email(newPassword.Email)).
SetEmail(newPassword.Email).
SetHash(newPassword.Hash).
SetUsername(newPassword.Username).
SetUserID(newPassword.UserID).
Save(context.TODO())
if err != nil {
return rollback(tx, "update password uploading: %w", err)
}
if err = tx.Commit(); err != nil {
return rollback(tx, "update password commit: %w", err)
}
return nil
}

View File

@ -0,0 +1,109 @@
package client
import (
"context"
"github.com/dexidp/dex/storage"
)
// CreateRefresh saves provided refresh token into the database.
func (d *Database) CreateRefresh(refresh storage.RefreshToken) error {
_, err := d.client.RefreshToken.Create().
SetID(refresh.ID).
SetClientID(refresh.ClientID).
SetScopes(refresh.Scopes).
SetNonce(refresh.Nonce).
SetClaimsUserID(refresh.Claims.UserID).
SetClaimsEmail(refresh.Claims.Email).
SetClaimsEmailVerified(refresh.Claims.EmailVerified).
SetClaimsUsername(refresh.Claims.Username).
SetClaimsPreferredUsername(refresh.Claims.PreferredUsername).
SetClaimsGroups(refresh.Claims.Groups).
SetConnectorID(refresh.ConnectorID).
SetConnectorData(refresh.ConnectorData).
SetToken(refresh.Token).
// Save utc time into database because ent doesn't support comparing dates with different timezones
SetLastUsed(refresh.LastUsed.UTC()).
SetCreatedAt(refresh.CreatedAt.UTC()).
Save(context.TODO())
if err != nil {
return convertDBError("create refresh token: %w", err)
}
return nil
}
// ListRefreshTokens extracts an array of refresh tokens from the database.
func (d *Database) ListRefreshTokens() ([]storage.RefreshToken, error) {
refreshTokens, err := d.client.RefreshToken.Query().All(context.TODO())
if err != nil {
return nil, convertDBError("list refresh tokens: %w", err)
}
storageRefreshTokens := make([]storage.RefreshToken, 0, len(refreshTokens))
for _, r := range refreshTokens {
storageRefreshTokens = append(storageRefreshTokens, toStorageRefreshToken(r))
}
return storageRefreshTokens, nil
}
// GetRefresh extracts a refresh token from the database by id.
func (d *Database) GetRefresh(id string) (storage.RefreshToken, error) {
refreshToken, err := d.client.RefreshToken.Get(context.TODO(), id)
if err != nil {
return storage.RefreshToken{}, convertDBError("get refresh token: %w", err)
}
return toStorageRefreshToken(refreshToken), nil
}
// DeleteRefresh deletes a refresh token from the database by id.
func (d *Database) DeleteRefresh(id string) error {
err := d.client.RefreshToken.DeleteOneID(id).Exec(context.TODO())
if err != nil {
return convertDBError("delete refresh token: %w", err)
}
return nil
}
// UpdateRefreshToken changes a refresh token by id using an updater function and saves it to the database.
func (d *Database) UpdateRefreshToken(id string, updater func(old storage.RefreshToken) (storage.RefreshToken, error)) error {
tx, err := d.client.Tx(context.TODO())
if err != nil {
return convertDBError("update refresh token tx: %w", err)
}
token, err := tx.RefreshToken.Get(context.TODO(), id)
if err != nil {
return rollback(tx, "update refresh token database: %w", err)
}
newtToken, err := updater(toStorageRefreshToken(token))
if err != nil {
return rollback(tx, "update refresh token updating: %w", err)
}
_, err = tx.RefreshToken.UpdateOneID(newtToken.ID).
SetClientID(newtToken.ClientID).
SetScopes(newtToken.Scopes).
SetNonce(newtToken.Nonce).
SetClaimsUserID(newtToken.Claims.UserID).
SetClaimsEmail(newtToken.Claims.Email).
SetClaimsEmailVerified(newtToken.Claims.EmailVerified).
SetClaimsUsername(newtToken.Claims.Username).
SetClaimsPreferredUsername(newtToken.Claims.PreferredUsername).
SetClaimsGroups(newtToken.Claims.Groups).
SetConnectorID(newtToken.ConnectorID).
SetConnectorData(newtToken.ConnectorData).
SetToken(newtToken.Token).
// Save utc time into database because ent doesn't support comparing dates with different timezones
SetLastUsed(newtToken.LastUsed.UTC()).
SetCreatedAt(newtToken.CreatedAt.UTC()).
Save(context.TODO())
if err != nil {
return rollback(tx, "update refresh token uploading: %w", err)
}
if err = tx.Commit(); err != nil {
return rollback(tx, "update refresh token commit: %w", err)
}
return nil
}

167
storage/ent/client/types.go Normal file
View File

@ -0,0 +1,167 @@
package client
import (
"encoding/json"
"strings"
"github.com/dexidp/dex/storage"
"github.com/dexidp/dex/storage/ent/db"
)
const keysRowID = "keys"
func toStorageKeys(keys *db.Keys) storage.Keys {
return storage.Keys{
SigningKey: &keys.SigningKey,
SigningKeyPub: &keys.SigningKeyPub,
VerificationKeys: keys.VerificationKeys,
NextRotation: keys.NextRotation,
}
}
func toStorageAuthRequest(a *db.AuthRequest) storage.AuthRequest {
return storage.AuthRequest{
ID: a.ID,
ClientID: a.ClientID,
ResponseTypes: a.ResponseTypes,
Scopes: a.Scopes,
RedirectURI: a.RedirectURI,
Nonce: a.Nonce,
State: a.State,
ForceApprovalPrompt: a.ForceApprovalPrompt,
LoggedIn: a.LoggedIn,
ConnectorID: a.ConnectorID,
ConnectorData: *a.ConnectorData,
Expiry: a.Expiry,
Claims: storage.Claims{
UserID: a.ClaimsUserID,
Username: a.ClaimsUsername,
PreferredUsername: a.ClaimsPreferredUsername,
Email: a.ClaimsEmail,
EmailVerified: a.ClaimsEmailVerified,
Groups: a.ClaimsGroups,
},
PKCE: storage.PKCE{
CodeChallenge: a.CodeChallenge,
CodeChallengeMethod: a.CodeChallengeMethod,
},
}
}
func toStorageAuthCode(a *db.AuthCode) storage.AuthCode {
return storage.AuthCode{
ID: a.ID,
ClientID: a.ClientID,
Scopes: a.Scopes,
RedirectURI: a.RedirectURI,
Nonce: a.Nonce,
ConnectorID: a.ConnectorID,
ConnectorData: *a.ConnectorData,
Expiry: a.Expiry,
Claims: storage.Claims{
UserID: a.ClaimsUserID,
Username: a.ClaimsUsername,
PreferredUsername: a.ClaimsPreferredUsername,
Email: a.ClaimsEmail,
EmailVerified: a.ClaimsEmailVerified,
Groups: a.ClaimsGroups,
},
PKCE: storage.PKCE{
CodeChallenge: a.CodeChallenge,
CodeChallengeMethod: a.CodeChallengeMethod,
},
}
}
func toStorageClient(c *db.OAuth2Client) storage.Client {
return storage.Client{
ID: c.ID,
Secret: c.Secret,
RedirectURIs: c.RedirectUris,
TrustedPeers: c.TrustedPeers,
Public: c.Public,
Name: c.Name,
LogoURL: c.LogoURL,
}
}
func toStorageConnector(c *db.Connector) storage.Connector {
return storage.Connector{
ID: c.ID,
Type: c.Type,
Name: c.Name,
Config: c.Config,
}
}
func toStorageOfflineSession(o *db.OfflineSession) storage.OfflineSessions {
s := storage.OfflineSessions{
UserID: o.UserID,
ConnID: o.ConnID,
ConnectorData: *o.ConnectorData,
}
if o.Refresh != nil {
if err := json.Unmarshal(o.Refresh, &s.Refresh); err != nil {
// Correctness of json structure if guaranteed on uploading
panic(err)
}
} else {
// Server code assumes this will be non-nil.
s.Refresh = make(map[string]*storage.RefreshTokenRef)
}
return s
}
func toStorageRefreshToken(r *db.RefreshToken) storage.RefreshToken {
return storage.RefreshToken{
ID: r.ID,
Token: r.Token,
CreatedAt: r.CreatedAt,
LastUsed: r.LastUsed,
ClientID: r.ClientID,
ConnectorID: r.ConnectorID,
ConnectorData: *r.ConnectorData,
Scopes: r.Scopes,
Nonce: r.Nonce,
Claims: storage.Claims{
UserID: r.ClaimsUserID,
Username: r.ClaimsUsername,
PreferredUsername: r.ClaimsPreferredUsername,
Email: r.ClaimsEmail,
EmailVerified: r.ClaimsEmailVerified,
Groups: r.ClaimsGroups,
},
}
}
func toStoragePassword(p *db.Password) storage.Password {
return storage.Password{
Email: p.Email,
Hash: p.Hash,
Username: p.Username,
UserID: p.UserID,
}
}
func toStorageDeviceRequest(r *db.DeviceRequest) storage.DeviceRequest {
return storage.DeviceRequest{
UserCode: strings.ToUpper(r.UserCode),
DeviceCode: r.DeviceCode,
ClientID: r.ClientID,
ClientSecret: r.ClientSecret,
Scopes: r.Scopes,
Expiry: r.Expiry,
}
}
func toStorageDeviceToken(t *db.DeviceToken) storage.DeviceToken {
return storage.DeviceToken{
DeviceCode: t.DeviceCode,
Status: t.Status,
Token: string(*t.Token),
Expiry: t.Expiry,
LastRequestTime: t.LastRequest,
PollIntervalSeconds: t.PollInterval,
}
}

View File

@ -0,0 +1,44 @@
package client
import (
"fmt"
"hash"
"github.com/pkg/errors"
"github.com/dexidp/dex/storage"
"github.com/dexidp/dex/storage/ent/db"
)
func rollback(tx *db.Tx, t string, err error) error {
rerr := tx.Rollback()
err = convertDBError(t, err)
if rerr == nil {
return err
}
return errors.Wrapf(err, "rolling back transaction: %v", rerr)
}
func convertDBError(t string, err error) error {
if db.IsNotFound(err) {
return storage.ErrNotFound
}
if db.IsConstraintError(err) {
return storage.ErrAlreadyExists
}
return fmt.Errorf(t, err)
}
// compose hashed id from user and connection id to use it as primary key
// ent doesn't support multi-key primary yet
// https://github.com/facebook/ent/issues/400
func offlineSessionID(userID string, connID string, hasher func() hash.Hash) string {
h := hasher()
h.Write([]byte(userID))
h.Write([]byte(connID))
return fmt.Sprintf("%x", h.Sum(nil))
}

3
storage/ent/generate.go Normal file
View File

@ -0,0 +1,3 @@
package ent
//go:generate go run github.com/facebook/ent/cmd/entc generate ./schema --target ./db

View File

@ -0,0 +1,89 @@
package schema
import (
"github.com/facebook/ent"
"github.com/facebook/ent/schema/field"
)
/* Original SQL table:
create table auth_code
(
id text not null primary key,
client_id text not null,
scopes blob not null,
nonce text not null,
redirect_uri text not null,
claims_user_id text not null,
claims_username text not null,
claims_email text not null,
claims_email_verified integer not null,
claims_groups blob not null,
connector_id text not null,
connector_data blob,
expiry timestamp not null,
claims_preferred_username text default '' not null,
code_challenge text default '' not null,
code_challenge_method text default '' not null
);
*/
// AuthCode holds the schema definition for the AuthCode entity.
type AuthCode struct {
ent.Schema
}
// Fields of the AuthCode.
func (AuthCode) Fields() []ent.Field {
return []ent.Field{
field.Text("id").
SchemaType(textSchema).
NotEmpty().
Unique(),
field.Text("client_id").
SchemaType(textSchema).
NotEmpty(),
field.JSON("scopes", []string{}).
Optional(),
field.Text("nonce").
SchemaType(textSchema).
NotEmpty(),
field.Text("redirect_uri").
SchemaType(textSchema).
NotEmpty(),
field.Text("claims_user_id").
SchemaType(textSchema).
NotEmpty(),
field.Text("claims_username").
SchemaType(textSchema).
NotEmpty(),
field.Text("claims_email").
SchemaType(textSchema).
NotEmpty(),
field.Bool("claims_email_verified"),
field.JSON("claims_groups", []string{}).
Optional(),
field.Text("claims_preferred_username").
SchemaType(textSchema).
Default(""),
field.Text("connector_id").
SchemaType(textSchema).
NotEmpty(),
field.Bytes("connector_data").
Nillable().
Optional(),
field.Time("expiry"),
field.Text("code_challenge").
SchemaType(textSchema).
Default(""),
field.Text("code_challenge_method").
SchemaType(textSchema).
Default(""),
}
}
// Edges of the AuthCode.
func (AuthCode) Edges() []ent.Edge {
return []ent.Edge{}
}

View File

@ -0,0 +1,94 @@
package schema
import (
"github.com/facebook/ent"
"github.com/facebook/ent/schema/field"
)
/* Original SQL table:
create table auth_request
(
id text not null primary key,
client_id text not null,
response_types blob not null,
scopes blob not null,
redirect_uri text not null,
nonce text not null,
state text not null,
force_approval_prompt integer not null,
logged_in integer not null,
claims_user_id text not null,
claims_username text not null,
claims_email text not null,
claims_email_verified integer not null,
claims_groups blob not null,
connector_id text not null,
connector_data blob,
expiry timestamp not null,
claims_preferred_username text default '' not null,
code_challenge text default '' not null,
code_challenge_method text default '' not null
);
*/
// AuthRequest holds the schema definition for the AuthRequest entity.
type AuthRequest struct {
ent.Schema
}
// Fields of the AuthRequest.
func (AuthRequest) Fields() []ent.Field {
return []ent.Field{
field.Text("id").
SchemaType(textSchema).
NotEmpty().
Unique(),
field.Text("client_id").
SchemaType(textSchema),
field.JSON("scopes", []string{}).
Optional(),
field.JSON("response_types", []string{}).
Optional(),
field.Text("redirect_uri").
SchemaType(textSchema),
field.Text("nonce").
SchemaType(textSchema),
field.Text("state").
SchemaType(textSchema),
field.Bool("force_approval_prompt"),
field.Bool("logged_in"),
field.Text("claims_user_id").
SchemaType(textSchema),
field.Text("claims_username").
SchemaType(textSchema),
field.Text("claims_email").
SchemaType(textSchema),
field.Bool("claims_email_verified"),
field.JSON("claims_groups", []string{}).
Optional(),
field.Text("claims_preferred_username").
SchemaType(textSchema).
Default(""),
field.Text("connector_id").
SchemaType(textSchema),
field.Bytes("connector_data").
Nillable().
Optional(),
field.Time("expiry"),
field.Text("code_challenge").
SchemaType(textSchema).
Default(""),
field.Text("code_challenge_method").
SchemaType(textSchema).
Default(""),
}
}
// Edges of the AuthRequest.
func (AuthRequest) Edges() []ent.Edge {
return []ent.Edge{}
}

View File

@ -0,0 +1,53 @@
package schema
import (
"github.com/facebook/ent"
"github.com/facebook/ent/schema/field"
)
/* Original SQL table:
create table client
(
id text not null primary key,
secret text not null,
redirect_uris blob not null,
trusted_peers blob not null,
public integer not null,
name text not null,
logo_url text not null
);
*/
// OAuth2Client holds the schema definition for the Client entity.
type OAuth2Client struct {
ent.Schema
}
// Fields of the OAuth2Client.
func (OAuth2Client) Fields() []ent.Field {
return []ent.Field{
field.Text("id").
SchemaType(textSchema).
NotEmpty().
Unique(),
field.Text("secret").
SchemaType(textSchema).
NotEmpty(),
field.JSON("redirect_uris", []string{}).
Optional(),
field.JSON("trusted_peers", []string{}).
Optional(),
field.Bool("public"),
field.Text("name").
SchemaType(textSchema).
NotEmpty(),
field.Text("logo_url").
SchemaType(textSchema).
NotEmpty(),
}
}
// Edges of the OAuth2Client.
func (OAuth2Client) Edges() []ent.Edge {
return []ent.Edge{}
}

View File

@ -0,0 +1,46 @@
package schema
import (
"github.com/facebook/ent"
"github.com/facebook/ent/schema/field"
)
/* Original SQL table:
create table connector
(
id text not null primary key,
type text not null,
name text not null,
resource_version text not null,
config blob
);
*/
// Connector holds the schema definition for the Client entity.
type Connector struct {
ent.Schema
}
// Fields of the Connector.
func (Connector) Fields() []ent.Field {
return []ent.Field{
field.Text("id").
SchemaType(textSchema).
NotEmpty().
Unique(),
field.Text("type").
SchemaType(textSchema).
NotEmpty(),
field.Text("name").
SchemaType(textSchema).
NotEmpty(),
field.Text("resource_version").
SchemaType(textSchema),
field.Bytes("config"),
}
}
// Edges of the Connector.
func (Connector) Edges() []ent.Edge {
return []ent.Edge{}
}

View File

@ -0,0 +1,50 @@
package schema
import (
"github.com/facebook/ent"
"github.com/facebook/ent/schema/field"
)
/* Original SQL table:
create table device_request
(
user_code text not null primary key,
device_code text not null,
client_id text not null,
client_secret text,
scopes blob not null,
expiry timestamp not null
);
*/
// DeviceRequest holds the schema definition for the DeviceRequest entity.
type DeviceRequest struct {
ent.Schema
}
// Fields of the DeviceRequest.
func (DeviceRequest) Fields() []ent.Field {
return []ent.Field{
field.Text("user_code").
SchemaType(textSchema).
NotEmpty().
Unique(),
field.Text("device_code").
SchemaType(textSchema).
NotEmpty(),
field.Text("client_id").
SchemaType(textSchema).
NotEmpty(),
field.Text("client_secret").
SchemaType(textSchema).
NotEmpty(),
field.JSON("scopes", []string{}).
Optional(),
field.Time("expiry"),
}
}
// Edges of the DeviceRequest.
func (DeviceRequest) Edges() []ent.Edge {
return []ent.Edge{}
}

View File

@ -0,0 +1,45 @@
package schema
import (
"github.com/facebook/ent"
"github.com/facebook/ent/schema/field"
)
/* Original SQL table:
create table device_token
(
device_code text not null primary key,
status text not null,
token blob,
expiry timestamp not null,
last_request timestamp not null,
poll_interval integer not null
);
*/
// DeviceToken holds the schema definition for the DeviceToken entity.
type DeviceToken struct {
ent.Schema
}
// Fields of the DeviceToken.
func (DeviceToken) Fields() []ent.Field {
return []ent.Field{
field.Text("device_code").
SchemaType(textSchema).
NotEmpty().
Unique(),
field.Text("status").
SchemaType(textSchema).
NotEmpty(),
field.Bytes("token").Nillable().Optional(),
field.Time("expiry"),
field.Time("last_request"),
field.Int("poll_interval"),
}
}
// Edges of the DeviceToken.
func (DeviceToken) Edges() []ent.Edge {
return []ent.Edge{}
}

View File

@ -0,0 +1,44 @@
package schema
import (
"github.com/facebook/ent"
"github.com/facebook/ent/schema/field"
"gopkg.in/square/go-jose.v2"
"github.com/dexidp/dex/storage"
)
/* Original SQL table:
create table keys
(
id text not null primary key,
verification_keys blob not null,
signing_key blob not null,
signing_key_pub blob not null,
next_rotation timestamp not null
);
*/
// Keys holds the schema definition for the Keys entity.
type Keys struct {
ent.Schema
}
// Fields of the Keys.
func (Keys) Fields() []ent.Field {
return []ent.Field{
field.Text("id").
SchemaType(textSchema).
NotEmpty().
Unique(),
field.JSON("verification_keys", []storage.VerificationKey{}),
field.JSON("signing_key", jose.JSONWebKey{}),
field.JSON("signing_key_pub", jose.JSONWebKey{}),
field.Time("next_rotation"),
}
}
// Edges of the Keys.
func (Keys) Edges() []ent.Edge {
return []ent.Edge{}
}

View File

@ -0,0 +1,46 @@
package schema
import (
"github.com/facebook/ent"
"github.com/facebook/ent/schema/field"
)
/* Original SQL table:
create table offline_session
(
user_id text not null,
conn_id text not null,
refresh blob not null,
connector_data blob,
primary key (user_id, conn_id)
);
*/
// OfflineSession holds the schema definition for the OfflineSession entity.
type OfflineSession struct {
ent.Schema
}
// Fields of the OfflineSession.
func (OfflineSession) Fields() []ent.Field {
return []ent.Field{
// Using id field here because it's impossible to create multi-key primary yet
field.Text("id").
SchemaType(textSchema).
NotEmpty().
Unique(),
field.Text("user_id").
SchemaType(textSchema).
NotEmpty(),
field.Text("conn_id").
SchemaType(textSchema).
NotEmpty(),
field.Bytes("refresh"),
field.Bytes("connector_data").Nillable().Optional(),
}
}
// Edges of the OfflineSession.
func (OfflineSession) Edges() []ent.Edge {
return []ent.Edge{}
}

View File

@ -0,0 +1,44 @@
package schema
import (
"github.com/facebook/ent"
"github.com/facebook/ent/schema/field"
)
/* Original SQL table:
create table password
(
email text not null primary key,
hash blob not null,
username text not null,
user_id text not null
);
*/
// Password holds the schema definition for the Password entity.
type Password struct {
ent.Schema
}
// Fields of the Password.
func (Password) Fields() []ent.Field {
return []ent.Field{
field.Text("email").
SchemaType(textSchema).
StorageKey("email"). // use email as ID field to make querying easier
NotEmpty().
Unique(),
field.Bytes("hash"),
field.Text("username").
SchemaType(textSchema).
NotEmpty(),
field.Text("user_id").
SchemaType(textSchema).
NotEmpty(),
}
}
// Edges of the Password.
func (Password) Edges() []ent.Edge {
return []ent.Edge{}
}

View File

@ -0,0 +1,89 @@
package schema
import (
"time"
"github.com/facebook/ent"
"github.com/facebook/ent/schema/field"
)
/* Original SQL table:
create table refresh_token
(
id text not null primary key,
client_id text not null,
scopes blob not null,
nonce text not null,
claims_user_id text not null,
claims_username text not null,
claims_email text not null,
claims_email_verified integer not null,
claims_groups blob not null,
connector_id text not null,
connector_data blob,
token text default '' not null,
created_at timestamp default '0001-01-01 00:00:00 UTC' not null,
last_used timestamp default '0001-01-01 00:00:00 UTC' not null,
claims_preferred_username text default '' not null
);
*/
// RefreshToken holds the schema definition for the RefreshToken entity.
type RefreshToken struct {
ent.Schema
}
// Fields of the RefreshToken.
func (RefreshToken) Fields() []ent.Field {
return []ent.Field{
field.Text("id").
SchemaType(textSchema).
NotEmpty().
Unique(),
field.Text("client_id").
SchemaType(textSchema).
NotEmpty(),
field.JSON("scopes", []string{}).
Optional(),
field.Text("nonce").
SchemaType(textSchema).
NotEmpty(),
field.Text("claims_user_id").
SchemaType(textSchema).
NotEmpty(),
field.Text("claims_username").
SchemaType(textSchema).
NotEmpty(),
field.Text("claims_email").
SchemaType(textSchema).
NotEmpty(),
field.Bool("claims_email_verified"),
field.JSON("claims_groups", []string{}).
Optional(),
field.Text("claims_preferred_username").
SchemaType(textSchema).
Default(""),
field.Text("connector_id").
SchemaType(textSchema).
NotEmpty(),
field.Bytes("connector_data").
Nillable().
Optional(),
field.Text("token").
SchemaType(textSchema).
Default(""),
field.Time("created_at").
Default(time.Now),
field.Time("last_used").
Default(time.Now),
}
}
// Edges of the RefreshToken.
func (RefreshToken) Edges() []ent.Edge {
return []ent.Edge{}
}

View File

@ -0,0 +1,9 @@
package schema
import (
"github.com/facebook/ent/dialect"
)
var textSchema = map[string]string{
dialect.SQLite: "text",
}

65
storage/ent/sqlite.go Normal file
View File

@ -0,0 +1,65 @@
package ent
import (
"context"
"crypto/sha256"
"strings"
"github.com/facebook/ent/dialect/sql"
// Register sqlite driver.
_ "github.com/mattn/go-sqlite3"
"github.com/dexidp/dex/pkg/log"
"github.com/dexidp/dex/storage"
"github.com/dexidp/dex/storage/ent/client"
"github.com/dexidp/dex/storage/ent/db"
)
// SQLite3 options for creating an SQL db.
type SQLite3 struct {
File string `json:"file"`
}
// Open always returns a new in sqlite3 storage.
func (s *SQLite3) Open(logger log.Logger) (storage.Storage, error) {
logger.Debug("experimental ent-based storage driver is enabled")
// Implicitly set foreign_keys pragma to "on" because it is required by ent
s.File = addFK(s.File)
drv, err := sql.Open("sqlite3", s.File)
if err != nil {
return nil, err
}
pool := drv.DB()
if s.File == ":memory:" {
// sqlite3 uses file locks to coordinate concurrent access. In memory
// doesn't support this, so limit the number of connections to 1.
pool.SetMaxOpenConns(1)
}
databaseClient := client.NewDatabase(
client.WithClient(db.NewClient(db.Driver(drv))),
client.WithHasher(sha256.New),
)
if err := databaseClient.Schema().Create(context.TODO()); err != nil {
return nil, err
}
return databaseClient, nil
}
func addFK(dsn string) string {
if strings.Contains(dsn, "_fk") {
return dsn
}
delim := "?"
if strings.Contains(dsn, "?") {
delim = "&"
}
return dsn + delim + "_fk=1"
}

View File

@ -0,0 +1,31 @@
package ent
import (
"os"
"testing"
"github.com/sirupsen/logrus"
"github.com/dexidp/dex/storage"
"github.com/dexidp/dex/storage/conformance"
)
func newStorage() storage.Storage {
logger := &logrus.Logger{
Out: os.Stderr,
Formatter: &logrus.TextFormatter{DisableColors: true},
Level: logrus.DebugLevel,
}
cfg := SQLite3{File: ":memory:"}
s, err := cfg.Open(logger)
if err != nil {
panic(err)
}
return s
}
func TestSQLite3(t *testing.T) {
conformance.RunTests(t, newStorage)
conformance.RunTransactionTests(t, newStorage)
}