feat: Add ent-based sqlite3 storage
Signed-off-by: m.nabokikh <maksim.nabokikh@flant.com>
This commit is contained in:
parent
674631c9ab
commit
11859166d0
7
Makefile
7
Makefile
@ -26,7 +26,10 @@ PROTOC_VERSION = 3.15.6
|
||||
PROTOC_GEN_GO_VERSION = 1.26.0
|
||||
PROTOC_GEN_GO_GRPC_VERSION = 1.1.0
|
||||
|
||||
build: bin/dex
|
||||
generate:
|
||||
@go generate $(REPO_PATH)/storage/ent/
|
||||
|
||||
build: generate bin/dex
|
||||
|
||||
bin/dex:
|
||||
@mkdir -p bin/
|
||||
@ -42,7 +45,7 @@ bin/example-app:
|
||||
@mkdir -p bin/
|
||||
@cd examples/ && go install -v -ldflags $(LD_FLAGS) $(REPO_PATH)/examples/example-app
|
||||
|
||||
.PHONY: release-binary
|
||||
.PHONY: generate release-binary
|
||||
release-binary:
|
||||
@go build -o /go/bin/dex -v -ldflags $(LD_FLAGS) $(REPO_PATH)/cmd/dex
|
||||
|
||||
|
@ -13,6 +13,7 @@ import (
|
||||
"github.com/dexidp/dex/pkg/log"
|
||||
"github.com/dexidp/dex/server"
|
||||
"github.com/dexidp/dex/storage"
|
||||
"github.com/dexidp/dex/storage/ent"
|
||||
"github.com/dexidp/dex/storage/etcd"
|
||||
"github.com/dexidp/dex/storage/kubernetes"
|
||||
"github.com/dexidp/dex/storage/memory"
|
||||
@ -173,13 +174,32 @@ type StorageConfig interface {
|
||||
Open(logger log.Logger) (storage.Storage, error)
|
||||
}
|
||||
|
||||
var (
|
||||
_ StorageConfig = (*etcd.Etcd)(nil)
|
||||
_ StorageConfig = (*kubernetes.Config)(nil)
|
||||
_ StorageConfig = (*memory.Config)(nil)
|
||||
_ StorageConfig = (*sql.SQLite3)(nil)
|
||||
_ StorageConfig = (*sql.Postgres)(nil)
|
||||
_ StorageConfig = (*sql.MySQL)(nil)
|
||||
_ StorageConfig = (*ent.SQLite3)(nil)
|
||||
)
|
||||
|
||||
func getORMBasedSQLiteStorage() StorageConfig {
|
||||
switch os.Getenv("DEX_ENT_ENABLED") {
|
||||
case "true", "yes":
|
||||
return new(ent.SQLite3)
|
||||
default:
|
||||
return new(sql.SQLite3)
|
||||
}
|
||||
}
|
||||
|
||||
var storages = map[string]func() StorageConfig{
|
||||
"etcd": func() StorageConfig { return new(etcd.Etcd) },
|
||||
"kubernetes": func() StorageConfig { return new(kubernetes.Config) },
|
||||
"memory": func() StorageConfig { return new(memory.Config) },
|
||||
"sqlite3": func() StorageConfig { return new(sql.SQLite3) },
|
||||
"postgres": func() StorageConfig { return new(sql.Postgres) },
|
||||
"mysql": func() StorageConfig { return new(sql.MySQL) },
|
||||
"sqlite3": getORMBasedSQLiteStorage,
|
||||
}
|
||||
|
||||
// isExpandEnvEnabled returns if os.ExpandEnv should be used for each storage and connector config.
|
||||
|
3
go.mod
3
go.mod
@ -7,7 +7,8 @@ require (
|
||||
github.com/beevik/etree v1.1.0
|
||||
github.com/coreos/go-oidc/v3 v3.0.0
|
||||
github.com/dexidp/dex/api/v2 v2.0.0
|
||||
github.com/felixge/httpsnoop v1.0.2
|
||||
github.com/facebook/ent v0.5.3
|
||||
github.com/felixge/httpsnoop v1.0.1
|
||||
github.com/ghodss/yaml v1.0.0
|
||||
github.com/go-ldap/ldap/v3 v3.3.0
|
||||
github.com/go-sql-driver/mysql v1.6.0
|
||||
|
2
go.sum
2
go.sum
@ -128,6 +128,8 @@ github.com/envoyproxy/go-control-plane v0.9.7/go.mod h1:cwu0lG7PUMfa9snN8LXBig5y
|
||||
github.com/envoyproxy/go-control-plane v0.9.9-0.20201210154907-fd9021fe5dad/go.mod h1:cXg6YxExXjJnVBQHBLXeUAgxn2UodCpnH306RInaBQk=
|
||||
github.com/envoyproxy/go-control-plane v0.9.9-0.20210217033140-668b12f5399d/go.mod h1:cXg6YxExXjJnVBQHBLXeUAgxn2UodCpnH306RInaBQk=
|
||||
github.com/envoyproxy/protoc-gen-validate v0.1.0/go.mod h1:iSmxcyjqTsJpI2R4NaDN7+kN2VEUnK/pcBlmesArF7c=
|
||||
github.com/facebook/ent v0.5.3 h1:YT3Sl28n7gGGOkQeYgeJsZmizJ1Iiy7psgkOtEk0aq4=
|
||||
github.com/facebook/ent v0.5.3/go.mod h1:tlWP+qCd3x2EeO7B/EqlJQ4dWu/2IeYFhP/szzDKAi8=
|
||||
github.com/fatih/color v1.7.0/go.mod h1:Zm6kSWBoL9eyXnKyktHP6abPY2pDugNf5KwzbycvMj4=
|
||||
github.com/felixge/httpsnoop v1.0.1/go.mod h1:m8KPJKqk1gH5J9DgRY2ASl2lWCfGKXixSwevea8zH2U=
|
||||
github.com/felixge/httpsnoop v1.0.2 h1:+nS9g82KMXccJ/wp0zyRW9ZBHFETmMGtkk+2CTTrW4o=
|
||||
|
52
storage/ent/client/authcode.go
Normal file
52
storage/ent/client/authcode.go
Normal file
@ -0,0 +1,52 @@
|
||||
package client
|
||||
|
||||
import (
|
||||
"context"
|
||||
|
||||
"github.com/dexidp/dex/storage"
|
||||
)
|
||||
|
||||
// CreateAuthCode saves provided auth code into the database.
|
||||
func (d *Database) CreateAuthCode(code storage.AuthCode) error {
|
||||
_, err := d.client.AuthCode.Create().
|
||||
SetID(code.ID).
|
||||
SetClientID(code.ClientID).
|
||||
SetScopes(code.Scopes).
|
||||
SetRedirectURI(code.RedirectURI).
|
||||
SetNonce(code.Nonce).
|
||||
SetClaimsUserID(code.Claims.UserID).
|
||||
SetClaimsEmail(code.Claims.Email).
|
||||
SetClaimsEmailVerified(code.Claims.EmailVerified).
|
||||
SetClaimsUsername(code.Claims.Username).
|
||||
SetClaimsPreferredUsername(code.Claims.PreferredUsername).
|
||||
SetClaimsGroups(code.Claims.Groups).
|
||||
SetCodeChallenge(code.PKCE.CodeChallenge).
|
||||
SetCodeChallengeMethod(code.PKCE.CodeChallengeMethod).
|
||||
// Save utc time into database because ent doesn't support comparing dates with different timezones
|
||||
SetExpiry(code.Expiry.UTC()).
|
||||
SetConnectorID(code.ConnectorID).
|
||||
SetConnectorData(code.ConnectorData).
|
||||
Save(context.TODO())
|
||||
if err != nil {
|
||||
return convertDBError("create auth code: %w", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// GetAuthCode extracts an auth code from the database by id.
|
||||
func (d *Database) GetAuthCode(id string) (storage.AuthCode, error) {
|
||||
authCode, err := d.client.AuthCode.Get(context.TODO(), id)
|
||||
if err != nil {
|
||||
return storage.AuthCode{}, convertDBError("get auth code: %w", err)
|
||||
}
|
||||
return toStorageAuthCode(authCode), nil
|
||||
}
|
||||
|
||||
// DeleteAuthCode deletes an auth code from the database by id.
|
||||
func (d *Database) DeleteAuthCode(id string) error {
|
||||
err := d.client.AuthCode.DeleteOneID(id).Exec(context.TODO())
|
||||
if err != nil {
|
||||
return convertDBError("delete auth code: %w", err)
|
||||
}
|
||||
return nil
|
||||
}
|
107
storage/ent/client/authrequest.go
Normal file
107
storage/ent/client/authrequest.go
Normal file
@ -0,0 +1,107 @@
|
||||
package client
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
|
||||
"github.com/dexidp/dex/storage"
|
||||
)
|
||||
|
||||
// CreateAuthRequest saves provided auth request into the database.
|
||||
func (d *Database) CreateAuthRequest(authRequest storage.AuthRequest) error {
|
||||
_, err := d.client.AuthRequest.Create().
|
||||
SetID(authRequest.ID).
|
||||
SetClientID(authRequest.ClientID).
|
||||
SetScopes(authRequest.Scopes).
|
||||
SetResponseTypes(authRequest.ResponseTypes).
|
||||
SetRedirectURI(authRequest.RedirectURI).
|
||||
SetState(authRequest.State).
|
||||
SetNonce(authRequest.Nonce).
|
||||
SetForceApprovalPrompt(authRequest.ForceApprovalPrompt).
|
||||
SetLoggedIn(authRequest.LoggedIn).
|
||||
SetClaimsUserID(authRequest.Claims.UserID).
|
||||
SetClaimsEmail(authRequest.Claims.Email).
|
||||
SetClaimsEmailVerified(authRequest.Claims.EmailVerified).
|
||||
SetClaimsUsername(authRequest.Claims.Username).
|
||||
SetClaimsPreferredUsername(authRequest.Claims.PreferredUsername).
|
||||
SetClaimsGroups(authRequest.Claims.Groups).
|
||||
SetCodeChallenge(authRequest.PKCE.CodeChallenge).
|
||||
SetCodeChallengeMethod(authRequest.PKCE.CodeChallengeMethod).
|
||||
// Save utc time into database because ent doesn't support comparing dates with different timezones
|
||||
SetExpiry(authRequest.Expiry.UTC()).
|
||||
SetConnectorID(authRequest.ConnectorID).
|
||||
SetConnectorData(authRequest.ConnectorData).
|
||||
Save(context.TODO())
|
||||
if err != nil {
|
||||
return convertDBError("create auth request: %w", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// GetAuthRequest extracts an auth request from the database by id.
|
||||
func (d *Database) GetAuthRequest(id string) (storage.AuthRequest, error) {
|
||||
authRequest, err := d.client.AuthRequest.Get(context.TODO(), id)
|
||||
if err != nil {
|
||||
return storage.AuthRequest{}, convertDBError("get auth request: %w", err)
|
||||
}
|
||||
return toStorageAuthRequest(authRequest), nil
|
||||
}
|
||||
|
||||
// DeleteAuthRequest deletes an auth request from the database by id.
|
||||
func (d *Database) DeleteAuthRequest(id string) error {
|
||||
err := d.client.AuthRequest.DeleteOneID(id).Exec(context.TODO())
|
||||
if err != nil {
|
||||
return convertDBError("delete auth request: %w", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// UpdateAuthRequest changes an auth request by id using an updater function and saves it to the database.
|
||||
func (d *Database) UpdateAuthRequest(id string, updater func(old storage.AuthRequest) (storage.AuthRequest, error)) error {
|
||||
tx, err := d.client.Tx(context.TODO())
|
||||
if err != nil {
|
||||
return fmt.Errorf("update auth request tx: %w", err)
|
||||
}
|
||||
|
||||
authRequest, err := tx.AuthRequest.Get(context.TODO(), id)
|
||||
if err != nil {
|
||||
return rollback(tx, "update auth request database: %w", err)
|
||||
}
|
||||
|
||||
newAuthRequest, err := updater(toStorageAuthRequest(authRequest))
|
||||
if err != nil {
|
||||
return rollback(tx, "update auth request updating: %w", err)
|
||||
}
|
||||
|
||||
_, err = tx.AuthRequest.UpdateOneID(newAuthRequest.ID).
|
||||
SetClientID(newAuthRequest.ClientID).
|
||||
SetScopes(newAuthRequest.Scopes).
|
||||
SetResponseTypes(newAuthRequest.ResponseTypes).
|
||||
SetRedirectURI(newAuthRequest.RedirectURI).
|
||||
SetState(newAuthRequest.State).
|
||||
SetNonce(newAuthRequest.Nonce).
|
||||
SetForceApprovalPrompt(newAuthRequest.ForceApprovalPrompt).
|
||||
SetLoggedIn(newAuthRequest.LoggedIn).
|
||||
SetClaimsUserID(newAuthRequest.Claims.UserID).
|
||||
SetClaimsEmail(newAuthRequest.Claims.Email).
|
||||
SetClaimsEmailVerified(newAuthRequest.Claims.EmailVerified).
|
||||
SetClaimsUsername(newAuthRequest.Claims.Username).
|
||||
SetClaimsPreferredUsername(newAuthRequest.Claims.PreferredUsername).
|
||||
SetClaimsGroups(newAuthRequest.Claims.Groups).
|
||||
SetCodeChallenge(newAuthRequest.PKCE.CodeChallenge).
|
||||
SetCodeChallengeMethod(newAuthRequest.PKCE.CodeChallengeMethod).
|
||||
// Save utc time into database because ent doesn't support comparing dates with different timezones
|
||||
SetExpiry(newAuthRequest.Expiry.UTC()).
|
||||
SetConnectorID(newAuthRequest.ConnectorID).
|
||||
SetConnectorData(newAuthRequest.ConnectorData).
|
||||
Save(context.TODO())
|
||||
if err != nil {
|
||||
return rollback(tx, "update auth request uploading: %w", err)
|
||||
}
|
||||
|
||||
if err = tx.Commit(); err != nil {
|
||||
return rollback(tx, "update auth request commit: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
92
storage/ent/client/client.go
Normal file
92
storage/ent/client/client.go
Normal file
@ -0,0 +1,92 @@
|
||||
package client
|
||||
|
||||
import (
|
||||
"context"
|
||||
|
||||
"github.com/dexidp/dex/storage"
|
||||
)
|
||||
|
||||
// CreateClient saves provided oauth2 client settings into the database.
|
||||
func (d *Database) CreateClient(client storage.Client) error {
|
||||
_, err := d.client.OAuth2Client.Create().
|
||||
SetID(client.ID).
|
||||
SetName(client.Name).
|
||||
SetSecret(client.Secret).
|
||||
SetPublic(client.Public).
|
||||
SetLogoURL(client.LogoURL).
|
||||
SetRedirectUris(client.RedirectURIs).
|
||||
SetTrustedPeers(client.TrustedPeers).
|
||||
Save(context.TODO())
|
||||
if err != nil {
|
||||
return convertDBError("create oauth2 client: %w", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// ListClients extracts an array of oauth2 clients from the database.
|
||||
func (d *Database) ListClients() ([]storage.Client, error) {
|
||||
clients, err := d.client.OAuth2Client.Query().All(context.TODO())
|
||||
if err != nil {
|
||||
return nil, convertDBError("list clients: %w", err)
|
||||
}
|
||||
|
||||
storageClients := make([]storage.Client, 0, len(clients))
|
||||
for _, c := range clients {
|
||||
storageClients = append(storageClients, toStorageClient(c))
|
||||
}
|
||||
return storageClients, nil
|
||||
}
|
||||
|
||||
// GetClient extracts an oauth2 client from the database by id.
|
||||
func (d *Database) GetClient(id string) (storage.Client, error) {
|
||||
client, err := d.client.OAuth2Client.Get(context.TODO(), id)
|
||||
if err != nil {
|
||||
return storage.Client{}, convertDBError("get client: %w", err)
|
||||
}
|
||||
return toStorageClient(client), nil
|
||||
}
|
||||
|
||||
// DeleteClient deletes an oauth2 client from the database by id.
|
||||
func (d *Database) DeleteClient(id string) error {
|
||||
err := d.client.OAuth2Client.DeleteOneID(id).Exec(context.TODO())
|
||||
if err != nil {
|
||||
return convertDBError("delete client: %w", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// UpdateClient changes an oauth2 client by id using an updater function and saves it to the database.
|
||||
func (d *Database) UpdateClient(id string, updater func(old storage.Client) (storage.Client, error)) error {
|
||||
tx, err := d.client.Tx(context.TODO())
|
||||
if err != nil {
|
||||
return convertDBError("update client tx: %w", err)
|
||||
}
|
||||
|
||||
client, err := tx.OAuth2Client.Get(context.TODO(), id)
|
||||
if err != nil {
|
||||
return rollback(tx, "update client database: %w", err)
|
||||
}
|
||||
|
||||
newClient, err := updater(toStorageClient(client))
|
||||
if err != nil {
|
||||
return rollback(tx, "update client updating: %w", err)
|
||||
}
|
||||
|
||||
_, err = tx.OAuth2Client.UpdateOneID(newClient.ID).
|
||||
SetName(newClient.Name).
|
||||
SetSecret(newClient.Secret).
|
||||
SetPublic(newClient.Public).
|
||||
SetLogoURL(newClient.LogoURL).
|
||||
SetRedirectUris(newClient.RedirectURIs).
|
||||
SetTrustedPeers(newClient.TrustedPeers).
|
||||
Save(context.TODO())
|
||||
if err != nil {
|
||||
return rollback(tx, "update client uploading: %w", err)
|
||||
}
|
||||
|
||||
if err = tx.Commit(); err != nil {
|
||||
return rollback(tx, "update auth request commit: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
88
storage/ent/client/connector.go
Normal file
88
storage/ent/client/connector.go
Normal file
@ -0,0 +1,88 @@
|
||||
package client
|
||||
|
||||
import (
|
||||
"context"
|
||||
|
||||
"github.com/dexidp/dex/storage"
|
||||
)
|
||||
|
||||
// CreateConnector saves a connector into the database.
|
||||
func (d *Database) CreateConnector(connector storage.Connector) error {
|
||||
_, err := d.client.Connector.Create().
|
||||
SetID(connector.ID).
|
||||
SetName(connector.Name).
|
||||
SetType(connector.Type).
|
||||
SetResourceVersion(connector.ResourceVersion).
|
||||
SetConfig(connector.Config).
|
||||
Save(context.TODO())
|
||||
if err != nil {
|
||||
return convertDBError("create connector: %w", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// ListConnectors extracts an array of connectors from the database.
|
||||
func (d *Database) ListConnectors() ([]storage.Connector, error) {
|
||||
connectors, err := d.client.Connector.Query().All(context.TODO())
|
||||
if err != nil {
|
||||
return nil, convertDBError("list connectors: %w", err)
|
||||
}
|
||||
|
||||
storageConnectors := make([]storage.Connector, 0, len(connectors))
|
||||
for _, c := range connectors {
|
||||
storageConnectors = append(storageConnectors, toStorageConnector(c))
|
||||
}
|
||||
return storageConnectors, nil
|
||||
}
|
||||
|
||||
// GetConnector extracts a connector from the database by id.
|
||||
func (d *Database) GetConnector(id string) (storage.Connector, error) {
|
||||
connector, err := d.client.Connector.Get(context.TODO(), id)
|
||||
if err != nil {
|
||||
return storage.Connector{}, convertDBError("get connector: %w", err)
|
||||
}
|
||||
return toStorageConnector(connector), nil
|
||||
}
|
||||
|
||||
// DeleteConnector deletes a connector from the database by id.
|
||||
func (d *Database) DeleteConnector(id string) error {
|
||||
err := d.client.Connector.DeleteOneID(id).Exec(context.TODO())
|
||||
if err != nil {
|
||||
return convertDBError("delete connector: %w", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// UpdateConnector changes a connector by id using an updater function and saves it to the database.
|
||||
func (d *Database) UpdateConnector(id string, updater func(old storage.Connector) (storage.Connector, error)) error {
|
||||
tx, err := d.client.Tx(context.TODO())
|
||||
if err != nil {
|
||||
return convertDBError("update connector tx: %w", err)
|
||||
}
|
||||
|
||||
connector, err := tx.Connector.Get(context.TODO(), id)
|
||||
if err != nil {
|
||||
return rollback(tx, "update connector database: %w", err)
|
||||
}
|
||||
|
||||
newConnector, err := updater(toStorageConnector(connector))
|
||||
if err != nil {
|
||||
return rollback(tx, "update connector updating: %w", err)
|
||||
}
|
||||
|
||||
_, err = tx.Connector.UpdateOneID(newConnector.ID).
|
||||
SetName(newConnector.Name).
|
||||
SetType(newConnector.Type).
|
||||
SetResourceVersion(newConnector.ResourceVersion).
|
||||
SetConfig(newConnector.Config).
|
||||
Save(context.TODO())
|
||||
if err != nil {
|
||||
return rollback(tx, "update connector uploading: %w", err)
|
||||
}
|
||||
|
||||
if err = tx.Commit(); err != nil {
|
||||
return rollback(tx, "update connector commit: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
36
storage/ent/client/devicerequest.go
Normal file
36
storage/ent/client/devicerequest.go
Normal file
@ -0,0 +1,36 @@
|
||||
package client
|
||||
|
||||
import (
|
||||
"context"
|
||||
|
||||
"github.com/dexidp/dex/storage"
|
||||
"github.com/dexidp/dex/storage/ent/db/devicerequest"
|
||||
)
|
||||
|
||||
// CreateDeviceRequest saves provided device request into the database.
|
||||
func (d *Database) CreateDeviceRequest(request storage.DeviceRequest) error {
|
||||
_, err := d.client.DeviceRequest.Create().
|
||||
SetClientID(request.ClientID).
|
||||
SetClientSecret(request.ClientSecret).
|
||||
SetScopes(request.Scopes).
|
||||
SetUserCode(request.UserCode).
|
||||
SetDeviceCode(request.DeviceCode).
|
||||
// Save utc time into database because ent doesn't support comparing dates with different timezones
|
||||
SetExpiry(request.Expiry.UTC()).
|
||||
Save(context.TODO())
|
||||
if err != nil {
|
||||
return convertDBError("create device request: %w", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// GetDeviceRequest extracts a device request from the database by user code.
|
||||
func (d *Database) GetDeviceRequest(userCode string) (storage.DeviceRequest, error) {
|
||||
deviceRequest, err := d.client.DeviceRequest.Query().
|
||||
Where(devicerequest.UserCode(userCode)).
|
||||
Only(context.TODO())
|
||||
if err != nil {
|
||||
return storage.DeviceRequest{}, convertDBError("get device request: %w", err)
|
||||
}
|
||||
return toStorageDeviceRequest(deviceRequest), nil
|
||||
}
|
76
storage/ent/client/devicetoken.go
Normal file
76
storage/ent/client/devicetoken.go
Normal file
@ -0,0 +1,76 @@
|
||||
package client
|
||||
|
||||
import (
|
||||
"context"
|
||||
|
||||
"github.com/dexidp/dex/storage"
|
||||
"github.com/dexidp/dex/storage/ent/db/devicetoken"
|
||||
)
|
||||
|
||||
// CreateDeviceToken saves provided token into the database.
|
||||
func (d *Database) CreateDeviceToken(token storage.DeviceToken) error {
|
||||
_, err := d.client.DeviceToken.Create().
|
||||
SetDeviceCode(token.DeviceCode).
|
||||
SetToken([]byte(token.Token)).
|
||||
SetPollInterval(token.PollIntervalSeconds).
|
||||
// Save utc time into database because ent doesn't support comparing dates with different timezones
|
||||
SetExpiry(token.Expiry.UTC()).
|
||||
SetLastRequest(token.LastRequestTime.UTC()).
|
||||
SetStatus(token.Status).
|
||||
Save(context.TODO())
|
||||
if err != nil {
|
||||
return convertDBError("create device token: %w", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// GetDeviceToken extracts a token from the database by device code.
|
||||
func (d *Database) GetDeviceToken(deviceCode string) (storage.DeviceToken, error) {
|
||||
deviceToken, err := d.client.DeviceToken.Query().
|
||||
Where(devicetoken.DeviceCode(deviceCode)).
|
||||
Only(context.TODO())
|
||||
if err != nil {
|
||||
return storage.DeviceToken{}, convertDBError("get device token: %w", err)
|
||||
}
|
||||
return toStorageDeviceToken(deviceToken), nil
|
||||
}
|
||||
|
||||
// UpdateDeviceToken changes a token by device code using an updater function and saves it to the database.
|
||||
func (d *Database) UpdateDeviceToken(deviceCode string, updater func(old storage.DeviceToken) (storage.DeviceToken, error)) error {
|
||||
tx, err := d.client.Tx(context.TODO())
|
||||
if err != nil {
|
||||
return convertDBError("update device token tx: %w", err)
|
||||
}
|
||||
|
||||
token, err := tx.DeviceToken.Query().
|
||||
Where(devicetoken.DeviceCode(deviceCode)).
|
||||
Only(context.TODO())
|
||||
if err != nil {
|
||||
return rollback(tx, "update device token database: %w", err)
|
||||
}
|
||||
|
||||
newToken, err := updater(toStorageDeviceToken(token))
|
||||
if err != nil {
|
||||
return rollback(tx, "update device token updating: %w", err)
|
||||
}
|
||||
|
||||
_, err = tx.DeviceToken.Update().
|
||||
Where(devicetoken.DeviceCode(newToken.DeviceCode)).
|
||||
SetDeviceCode(newToken.DeviceCode).
|
||||
SetToken([]byte(newToken.Token)).
|
||||
SetPollInterval(newToken.PollIntervalSeconds).
|
||||
// Save utc time into database because ent doesn't support comparing dates with different timezones
|
||||
SetExpiry(newToken.Expiry.UTC()).
|
||||
SetLastRequest(newToken.LastRequestTime.UTC()).
|
||||
SetStatus(newToken.Status).
|
||||
Save(context.TODO())
|
||||
if err != nil {
|
||||
return rollback(tx, "update device token uploading: %w", err)
|
||||
}
|
||||
|
||||
if err = tx.Commit(); err != nil {
|
||||
return rollback(tx, "update device token commit: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
81
storage/ent/client/keys.go
Normal file
81
storage/ent/client/keys.go
Normal file
@ -0,0 +1,81 @@
|
||||
package client
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
|
||||
"github.com/dexidp/dex/storage"
|
||||
"github.com/dexidp/dex/storage/ent/db"
|
||||
)
|
||||
|
||||
func getKeys(client *db.KeysClient) (storage.Keys, error) {
|
||||
rawKeys, err := client.Get(context.TODO(), keysRowID)
|
||||
if err != nil {
|
||||
return storage.Keys{}, convertDBError("get keys: %w", err)
|
||||
}
|
||||
|
||||
return toStorageKeys(rawKeys), nil
|
||||
}
|
||||
|
||||
// GetKeys returns signing keys, public keys and verification keys from the database.
|
||||
func (d *Database) GetKeys() (storage.Keys, error) {
|
||||
return getKeys(d.client.Keys)
|
||||
}
|
||||
|
||||
// UpdateKeys rotates keys using updater function.
|
||||
func (d *Database) UpdateKeys(updater func(old storage.Keys) (storage.Keys, error)) error {
|
||||
firstUpdate := false
|
||||
|
||||
tx, err := d.client.Tx(context.TODO())
|
||||
if err != nil {
|
||||
return convertDBError("update keys tx: %w", err)
|
||||
}
|
||||
|
||||
storageKeys, err := getKeys(tx.Keys)
|
||||
if err != nil {
|
||||
if !errors.Is(err, storage.ErrNotFound) {
|
||||
return rollback(tx, "update keys get: %w", err)
|
||||
}
|
||||
firstUpdate = true
|
||||
}
|
||||
|
||||
newKeys, err := updater(storageKeys)
|
||||
if err != nil {
|
||||
return rollback(tx, "update keys updating: %w", err)
|
||||
}
|
||||
|
||||
// ent doesn't have an upsert support yet
|
||||
// https://github.com/facebook/ent/issues/139
|
||||
if firstUpdate {
|
||||
_, err = tx.Keys.Create().
|
||||
SetID(keysRowID).
|
||||
SetNextRotation(newKeys.NextRotation).
|
||||
SetSigningKey(*newKeys.SigningKey).
|
||||
SetSigningKeyPub(*newKeys.SigningKeyPub).
|
||||
SetVerificationKeys(newKeys.VerificationKeys).
|
||||
Save(context.TODO())
|
||||
if err != nil {
|
||||
return rollback(tx, "create keys: %w", err)
|
||||
}
|
||||
if err = tx.Commit(); err != nil {
|
||||
return rollback(tx, "update keys commit: %w", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
err = tx.Keys.UpdateOneID(keysRowID).
|
||||
SetNextRotation(newKeys.NextRotation.UTC()).
|
||||
SetSigningKey(*newKeys.SigningKey).
|
||||
SetSigningKeyPub(*newKeys.SigningKeyPub).
|
||||
SetVerificationKeys(newKeys.VerificationKeys).
|
||||
Exec(context.TODO())
|
||||
if err != nil {
|
||||
return rollback(tx, "update keys uploading: %w", err)
|
||||
}
|
||||
|
||||
if err = tx.Commit(); err != nil {
|
||||
return rollback(tx, "update keys commit: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
95
storage/ent/client/main.go
Normal file
95
storage/ent/client/main.go
Normal file
@ -0,0 +1,95 @@
|
||||
package client
|
||||
|
||||
import (
|
||||
"context"
|
||||
"hash"
|
||||
"time"
|
||||
|
||||
"github.com/dexidp/dex/storage"
|
||||
"github.com/dexidp/dex/storage/ent/db"
|
||||
"github.com/dexidp/dex/storage/ent/db/authcode"
|
||||
"github.com/dexidp/dex/storage/ent/db/authrequest"
|
||||
"github.com/dexidp/dex/storage/ent/db/devicerequest"
|
||||
"github.com/dexidp/dex/storage/ent/db/devicetoken"
|
||||
"github.com/dexidp/dex/storage/ent/db/migrate"
|
||||
)
|
||||
|
||||
var _ storage.Storage = (*Database)(nil)
|
||||
|
||||
type Database struct {
|
||||
client *db.Client
|
||||
hasher func() hash.Hash
|
||||
}
|
||||
|
||||
// NewDatabase returns new database client with set options.
|
||||
func NewDatabase(opts ...func(*Database)) *Database {
|
||||
database := &Database{}
|
||||
for _, f := range opts {
|
||||
f(database)
|
||||
}
|
||||
return database
|
||||
}
|
||||
|
||||
// WithClient sets client option of a Database object.
|
||||
func WithClient(c *db.Client) func(*Database) {
|
||||
return func(s *Database) {
|
||||
s.client = c
|
||||
}
|
||||
}
|
||||
|
||||
// WithHasher sets client option of a Database object.
|
||||
func WithHasher(h func() hash.Hash) func(*Database) {
|
||||
return func(s *Database) {
|
||||
s.hasher = h
|
||||
}
|
||||
}
|
||||
|
||||
// Schema exposes migration schema to perform migrations.
|
||||
func (d *Database) Schema() *migrate.Schema {
|
||||
return d.client.Schema
|
||||
}
|
||||
|
||||
// Close calls the corresponding method of the ent database client.
|
||||
func (d *Database) Close() error {
|
||||
return d.client.Close()
|
||||
}
|
||||
|
||||
// GarbageCollect removes expired entities from the database.
|
||||
func (d *Database) GarbageCollect(now time.Time) (storage.GCResult, error) {
|
||||
result := storage.GCResult{}
|
||||
utcNow := now.UTC()
|
||||
|
||||
q, err := d.client.AuthRequest.Delete().
|
||||
Where(authrequest.ExpiryLT(utcNow)).
|
||||
Exec(context.TODO())
|
||||
if err != nil {
|
||||
return result, convertDBError("gc auth request: %w", err)
|
||||
}
|
||||
result.AuthRequests = int64(q)
|
||||
|
||||
q, err = d.client.AuthCode.Delete().
|
||||
Where(authcode.ExpiryLT(utcNow)).
|
||||
Exec(context.TODO())
|
||||
if err != nil {
|
||||
return result, convertDBError("gc auth code: %w", err)
|
||||
}
|
||||
result.AuthCodes = int64(q)
|
||||
|
||||
q, err = d.client.DeviceRequest.Delete().
|
||||
Where(devicerequest.ExpiryLT(utcNow)).
|
||||
Exec(context.TODO())
|
||||
if err != nil {
|
||||
return result, convertDBError("gc device request: %w", err)
|
||||
}
|
||||
result.DeviceRequests = int64(q)
|
||||
|
||||
q, err = d.client.DeviceToken.Delete().
|
||||
Where(devicetoken.ExpiryLT(utcNow)).
|
||||
Exec(context.TODO())
|
||||
if err != nil {
|
||||
return result, convertDBError("gc device token: %w", err)
|
||||
}
|
||||
result.DeviceTokens = int64(q)
|
||||
|
||||
return result, err
|
||||
}
|
93
storage/ent/client/offlinesession.go
Normal file
93
storage/ent/client/offlinesession.go
Normal file
@ -0,0 +1,93 @@
|
||||
package client
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
|
||||
"github.com/dexidp/dex/storage"
|
||||
)
|
||||
|
||||
// CreateOfflineSessions saves provided offline session into the database.
|
||||
func (d *Database) CreateOfflineSessions(session storage.OfflineSessions) error {
|
||||
encodedRefresh, err := json.Marshal(session.Refresh)
|
||||
if err != nil {
|
||||
return fmt.Errorf("encode refresh offline session: %w", err)
|
||||
}
|
||||
|
||||
id := offlineSessionID(session.UserID, session.ConnID, d.hasher)
|
||||
_, err = d.client.OfflineSession.Create().
|
||||
SetID(id).
|
||||
SetUserID(session.UserID).
|
||||
SetConnID(session.ConnID).
|
||||
SetConnectorData(session.ConnectorData).
|
||||
SetRefresh(encodedRefresh).
|
||||
Save(context.TODO())
|
||||
if err != nil {
|
||||
return convertDBError("create offline session: %w", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// GetOfflineSessions extracts an offline session from the database by user id and connector id.
|
||||
func (d *Database) GetOfflineSessions(userID, connID string) (storage.OfflineSessions, error) {
|
||||
id := offlineSessionID(userID, connID, d.hasher)
|
||||
|
||||
offlineSession, err := d.client.OfflineSession.Get(context.TODO(), id)
|
||||
if err != nil {
|
||||
return storage.OfflineSessions{}, convertDBError("get offline session: %w", err)
|
||||
}
|
||||
return toStorageOfflineSession(offlineSession), nil
|
||||
}
|
||||
|
||||
// DeleteOfflineSessions deletes an offline session from the database by user id and connector id.
|
||||
func (d *Database) DeleteOfflineSessions(userID, connID string) error {
|
||||
id := offlineSessionID(userID, connID, d.hasher)
|
||||
|
||||
err := d.client.OfflineSession.DeleteOneID(id).Exec(context.TODO())
|
||||
if err != nil {
|
||||
return convertDBError("delete offline session: %w", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// UpdatePassword changes an offline session by user id and connector id using an updater function.
|
||||
func (d *Database) UpdateOfflineSessions(userID string, connID string, updater func(s storage.OfflineSessions) (storage.OfflineSessions, error)) error {
|
||||
id := offlineSessionID(userID, connID, d.hasher)
|
||||
|
||||
tx, err := d.client.Tx(context.TODO())
|
||||
if err != nil {
|
||||
return convertDBError("update offline session tx: %w", err)
|
||||
}
|
||||
|
||||
offlineSession, err := tx.OfflineSession.Get(context.TODO(), id)
|
||||
if err != nil {
|
||||
return rollback(tx, "update offline session database: %w", err)
|
||||
}
|
||||
|
||||
newOfflineSession, err := updater(toStorageOfflineSession(offlineSession))
|
||||
if err != nil {
|
||||
return rollback(tx, "update offline session updating: %w", err)
|
||||
}
|
||||
|
||||
encodedRefresh, err := json.Marshal(newOfflineSession.Refresh)
|
||||
if err != nil {
|
||||
return rollback(tx, "encode refresh offline session: %w", err)
|
||||
}
|
||||
|
||||
_, err = tx.OfflineSession.UpdateOneID(id).
|
||||
SetUserID(newOfflineSession.UserID).
|
||||
SetConnID(newOfflineSession.ConnID).
|
||||
SetConnectorData(newOfflineSession.ConnectorData).
|
||||
SetRefresh(encodedRefresh).
|
||||
Save(context.TODO())
|
||||
if err != nil {
|
||||
return rollback(tx, "update offline session uploading: %w", err)
|
||||
}
|
||||
|
||||
if err = tx.Commit(); err != nil {
|
||||
return rollback(tx, "update password commit: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
100
storage/ent/client/password.go
Normal file
100
storage/ent/client/password.go
Normal file
@ -0,0 +1,100 @@
|
||||
package client
|
||||
|
||||
import (
|
||||
"context"
|
||||
"strings"
|
||||
|
||||
"github.com/dexidp/dex/storage"
|
||||
"github.com/dexidp/dex/storage/ent/db/password"
|
||||
)
|
||||
|
||||
// CreatePassword saves provided password into the database.
|
||||
func (d *Database) CreatePassword(password storage.Password) error {
|
||||
_, err := d.client.Password.Create().
|
||||
SetEmail(password.Email).
|
||||
SetHash(password.Hash).
|
||||
SetUsername(password.Username).
|
||||
SetUserID(password.UserID).
|
||||
Save(context.TODO())
|
||||
if err != nil {
|
||||
return convertDBError("create password: %w", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// ListPasswords extracts an array of passwords from the database.
|
||||
func (d *Database) ListPasswords() ([]storage.Password, error) {
|
||||
passwords, err := d.client.Password.Query().All(context.TODO())
|
||||
if err != nil {
|
||||
return nil, convertDBError("list passwords: %w", err)
|
||||
}
|
||||
|
||||
storagePasswords := make([]storage.Password, 0, len(passwords))
|
||||
for _, p := range passwords {
|
||||
storagePasswords = append(storagePasswords, toStoragePassword(p))
|
||||
}
|
||||
return storagePasswords, nil
|
||||
}
|
||||
|
||||
// GetPassword extracts a password from the database by email.
|
||||
func (d *Database) GetPassword(email string) (storage.Password, error) {
|
||||
email = strings.ToLower(email)
|
||||
passwordFromStorage, err := d.client.Password.Query().
|
||||
Where(password.Email(email)).
|
||||
Only(context.TODO())
|
||||
if err != nil {
|
||||
return storage.Password{}, convertDBError("get password: %w", err)
|
||||
}
|
||||
return toStoragePassword(passwordFromStorage), nil
|
||||
}
|
||||
|
||||
// DeletePassword deletes a password from the database by email.
|
||||
func (d *Database) DeletePassword(email string) error {
|
||||
email = strings.ToLower(email)
|
||||
_, err := d.client.Password.Delete().
|
||||
Where(password.Email(email)).
|
||||
Exec(context.TODO())
|
||||
if err != nil {
|
||||
return convertDBError("delete password: %w", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// UpdatePassword changes a password by email using an updater function and saves it to the database.
|
||||
func (d *Database) UpdatePassword(email string, updater func(old storage.Password) (storage.Password, error)) error {
|
||||
email = strings.ToLower(email)
|
||||
|
||||
tx, err := d.client.Tx(context.TODO())
|
||||
if err != nil {
|
||||
return convertDBError("update connector tx: %w", err)
|
||||
}
|
||||
|
||||
passwordToUpdate, err := tx.Password.Query().
|
||||
Where(password.Email(email)).
|
||||
Only(context.TODO())
|
||||
if err != nil {
|
||||
return rollback(tx, "update password database: %w", err)
|
||||
}
|
||||
|
||||
newPassword, err := updater(toStoragePassword(passwordToUpdate))
|
||||
if err != nil {
|
||||
return rollback(tx, "update password updating: %w", err)
|
||||
}
|
||||
|
||||
_, err = tx.Password.Update().
|
||||
Where(password.Email(newPassword.Email)).
|
||||
SetEmail(newPassword.Email).
|
||||
SetHash(newPassword.Hash).
|
||||
SetUsername(newPassword.Username).
|
||||
SetUserID(newPassword.UserID).
|
||||
Save(context.TODO())
|
||||
if err != nil {
|
||||
return rollback(tx, "update password uploading: %w", err)
|
||||
}
|
||||
|
||||
if err = tx.Commit(); err != nil {
|
||||
return rollback(tx, "update password commit: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
109
storage/ent/client/refreshtoken.go
Normal file
109
storage/ent/client/refreshtoken.go
Normal file
@ -0,0 +1,109 @@
|
||||
package client
|
||||
|
||||
import (
|
||||
"context"
|
||||
|
||||
"github.com/dexidp/dex/storage"
|
||||
)
|
||||
|
||||
// CreateRefresh saves provided refresh token into the database.
|
||||
func (d *Database) CreateRefresh(refresh storage.RefreshToken) error {
|
||||
_, err := d.client.RefreshToken.Create().
|
||||
SetID(refresh.ID).
|
||||
SetClientID(refresh.ClientID).
|
||||
SetScopes(refresh.Scopes).
|
||||
SetNonce(refresh.Nonce).
|
||||
SetClaimsUserID(refresh.Claims.UserID).
|
||||
SetClaimsEmail(refresh.Claims.Email).
|
||||
SetClaimsEmailVerified(refresh.Claims.EmailVerified).
|
||||
SetClaimsUsername(refresh.Claims.Username).
|
||||
SetClaimsPreferredUsername(refresh.Claims.PreferredUsername).
|
||||
SetClaimsGroups(refresh.Claims.Groups).
|
||||
SetConnectorID(refresh.ConnectorID).
|
||||
SetConnectorData(refresh.ConnectorData).
|
||||
SetToken(refresh.Token).
|
||||
// Save utc time into database because ent doesn't support comparing dates with different timezones
|
||||
SetLastUsed(refresh.LastUsed.UTC()).
|
||||
SetCreatedAt(refresh.CreatedAt.UTC()).
|
||||
Save(context.TODO())
|
||||
if err != nil {
|
||||
return convertDBError("create refresh token: %w", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// ListRefreshTokens extracts an array of refresh tokens from the database.
|
||||
func (d *Database) ListRefreshTokens() ([]storage.RefreshToken, error) {
|
||||
refreshTokens, err := d.client.RefreshToken.Query().All(context.TODO())
|
||||
if err != nil {
|
||||
return nil, convertDBError("list refresh tokens: %w", err)
|
||||
}
|
||||
|
||||
storageRefreshTokens := make([]storage.RefreshToken, 0, len(refreshTokens))
|
||||
for _, r := range refreshTokens {
|
||||
storageRefreshTokens = append(storageRefreshTokens, toStorageRefreshToken(r))
|
||||
}
|
||||
return storageRefreshTokens, nil
|
||||
}
|
||||
|
||||
// GetRefresh extracts a refresh token from the database by id.
|
||||
func (d *Database) GetRefresh(id string) (storage.RefreshToken, error) {
|
||||
refreshToken, err := d.client.RefreshToken.Get(context.TODO(), id)
|
||||
if err != nil {
|
||||
return storage.RefreshToken{}, convertDBError("get refresh token: %w", err)
|
||||
}
|
||||
return toStorageRefreshToken(refreshToken), nil
|
||||
}
|
||||
|
||||
// DeleteRefresh deletes a refresh token from the database by id.
|
||||
func (d *Database) DeleteRefresh(id string) error {
|
||||
err := d.client.RefreshToken.DeleteOneID(id).Exec(context.TODO())
|
||||
if err != nil {
|
||||
return convertDBError("delete refresh token: %w", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// UpdateRefreshToken changes a refresh token by id using an updater function and saves it to the database.
|
||||
func (d *Database) UpdateRefreshToken(id string, updater func(old storage.RefreshToken) (storage.RefreshToken, error)) error {
|
||||
tx, err := d.client.Tx(context.TODO())
|
||||
if err != nil {
|
||||
return convertDBError("update refresh token tx: %w", err)
|
||||
}
|
||||
|
||||
token, err := tx.RefreshToken.Get(context.TODO(), id)
|
||||
if err != nil {
|
||||
return rollback(tx, "update refresh token database: %w", err)
|
||||
}
|
||||
|
||||
newtToken, err := updater(toStorageRefreshToken(token))
|
||||
if err != nil {
|
||||
return rollback(tx, "update refresh token updating: %w", err)
|
||||
}
|
||||
|
||||
_, err = tx.RefreshToken.UpdateOneID(newtToken.ID).
|
||||
SetClientID(newtToken.ClientID).
|
||||
SetScopes(newtToken.Scopes).
|
||||
SetNonce(newtToken.Nonce).
|
||||
SetClaimsUserID(newtToken.Claims.UserID).
|
||||
SetClaimsEmail(newtToken.Claims.Email).
|
||||
SetClaimsEmailVerified(newtToken.Claims.EmailVerified).
|
||||
SetClaimsUsername(newtToken.Claims.Username).
|
||||
SetClaimsPreferredUsername(newtToken.Claims.PreferredUsername).
|
||||
SetClaimsGroups(newtToken.Claims.Groups).
|
||||
SetConnectorID(newtToken.ConnectorID).
|
||||
SetConnectorData(newtToken.ConnectorData).
|
||||
SetToken(newtToken.Token).
|
||||
// Save utc time into database because ent doesn't support comparing dates with different timezones
|
||||
SetLastUsed(newtToken.LastUsed.UTC()).
|
||||
SetCreatedAt(newtToken.CreatedAt.UTC()).
|
||||
Save(context.TODO())
|
||||
if err != nil {
|
||||
return rollback(tx, "update refresh token uploading: %w", err)
|
||||
}
|
||||
|
||||
if err = tx.Commit(); err != nil {
|
||||
return rollback(tx, "update refresh token commit: %w", err)
|
||||
}
|
||||
return nil
|
||||
}
|
167
storage/ent/client/types.go
Normal file
167
storage/ent/client/types.go
Normal file
@ -0,0 +1,167 @@
|
||||
package client
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"strings"
|
||||
|
||||
"github.com/dexidp/dex/storage"
|
||||
"github.com/dexidp/dex/storage/ent/db"
|
||||
)
|
||||
|
||||
const keysRowID = "keys"
|
||||
|
||||
func toStorageKeys(keys *db.Keys) storage.Keys {
|
||||
return storage.Keys{
|
||||
SigningKey: &keys.SigningKey,
|
||||
SigningKeyPub: &keys.SigningKeyPub,
|
||||
VerificationKeys: keys.VerificationKeys,
|
||||
NextRotation: keys.NextRotation,
|
||||
}
|
||||
}
|
||||
|
||||
func toStorageAuthRequest(a *db.AuthRequest) storage.AuthRequest {
|
||||
return storage.AuthRequest{
|
||||
ID: a.ID,
|
||||
ClientID: a.ClientID,
|
||||
ResponseTypes: a.ResponseTypes,
|
||||
Scopes: a.Scopes,
|
||||
RedirectURI: a.RedirectURI,
|
||||
Nonce: a.Nonce,
|
||||
State: a.State,
|
||||
ForceApprovalPrompt: a.ForceApprovalPrompt,
|
||||
LoggedIn: a.LoggedIn,
|
||||
ConnectorID: a.ConnectorID,
|
||||
ConnectorData: *a.ConnectorData,
|
||||
Expiry: a.Expiry,
|
||||
Claims: storage.Claims{
|
||||
UserID: a.ClaimsUserID,
|
||||
Username: a.ClaimsUsername,
|
||||
PreferredUsername: a.ClaimsPreferredUsername,
|
||||
Email: a.ClaimsEmail,
|
||||
EmailVerified: a.ClaimsEmailVerified,
|
||||
Groups: a.ClaimsGroups,
|
||||
},
|
||||
PKCE: storage.PKCE{
|
||||
CodeChallenge: a.CodeChallenge,
|
||||
CodeChallengeMethod: a.CodeChallengeMethod,
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
func toStorageAuthCode(a *db.AuthCode) storage.AuthCode {
|
||||
return storage.AuthCode{
|
||||
ID: a.ID,
|
||||
ClientID: a.ClientID,
|
||||
Scopes: a.Scopes,
|
||||
RedirectURI: a.RedirectURI,
|
||||
Nonce: a.Nonce,
|
||||
ConnectorID: a.ConnectorID,
|
||||
ConnectorData: *a.ConnectorData,
|
||||
Expiry: a.Expiry,
|
||||
Claims: storage.Claims{
|
||||
UserID: a.ClaimsUserID,
|
||||
Username: a.ClaimsUsername,
|
||||
PreferredUsername: a.ClaimsPreferredUsername,
|
||||
Email: a.ClaimsEmail,
|
||||
EmailVerified: a.ClaimsEmailVerified,
|
||||
Groups: a.ClaimsGroups,
|
||||
},
|
||||
PKCE: storage.PKCE{
|
||||
CodeChallenge: a.CodeChallenge,
|
||||
CodeChallengeMethod: a.CodeChallengeMethod,
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
func toStorageClient(c *db.OAuth2Client) storage.Client {
|
||||
return storage.Client{
|
||||
ID: c.ID,
|
||||
Secret: c.Secret,
|
||||
RedirectURIs: c.RedirectUris,
|
||||
TrustedPeers: c.TrustedPeers,
|
||||
Public: c.Public,
|
||||
Name: c.Name,
|
||||
LogoURL: c.LogoURL,
|
||||
}
|
||||
}
|
||||
|
||||
func toStorageConnector(c *db.Connector) storage.Connector {
|
||||
return storage.Connector{
|
||||
ID: c.ID,
|
||||
Type: c.Type,
|
||||
Name: c.Name,
|
||||
Config: c.Config,
|
||||
}
|
||||
}
|
||||
|
||||
func toStorageOfflineSession(o *db.OfflineSession) storage.OfflineSessions {
|
||||
s := storage.OfflineSessions{
|
||||
UserID: o.UserID,
|
||||
ConnID: o.ConnID,
|
||||
ConnectorData: *o.ConnectorData,
|
||||
}
|
||||
|
||||
if o.Refresh != nil {
|
||||
if err := json.Unmarshal(o.Refresh, &s.Refresh); err != nil {
|
||||
// Correctness of json structure if guaranteed on uploading
|
||||
panic(err)
|
||||
}
|
||||
} else {
|
||||
// Server code assumes this will be non-nil.
|
||||
s.Refresh = make(map[string]*storage.RefreshTokenRef)
|
||||
}
|
||||
return s
|
||||
}
|
||||
|
||||
func toStorageRefreshToken(r *db.RefreshToken) storage.RefreshToken {
|
||||
return storage.RefreshToken{
|
||||
ID: r.ID,
|
||||
Token: r.Token,
|
||||
CreatedAt: r.CreatedAt,
|
||||
LastUsed: r.LastUsed,
|
||||
ClientID: r.ClientID,
|
||||
ConnectorID: r.ConnectorID,
|
||||
ConnectorData: *r.ConnectorData,
|
||||
Scopes: r.Scopes,
|
||||
Nonce: r.Nonce,
|
||||
Claims: storage.Claims{
|
||||
UserID: r.ClaimsUserID,
|
||||
Username: r.ClaimsUsername,
|
||||
PreferredUsername: r.ClaimsPreferredUsername,
|
||||
Email: r.ClaimsEmail,
|
||||
EmailVerified: r.ClaimsEmailVerified,
|
||||
Groups: r.ClaimsGroups,
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
func toStoragePassword(p *db.Password) storage.Password {
|
||||
return storage.Password{
|
||||
Email: p.Email,
|
||||
Hash: p.Hash,
|
||||
Username: p.Username,
|
||||
UserID: p.UserID,
|
||||
}
|
||||
}
|
||||
|
||||
func toStorageDeviceRequest(r *db.DeviceRequest) storage.DeviceRequest {
|
||||
return storage.DeviceRequest{
|
||||
UserCode: strings.ToUpper(r.UserCode),
|
||||
DeviceCode: r.DeviceCode,
|
||||
ClientID: r.ClientID,
|
||||
ClientSecret: r.ClientSecret,
|
||||
Scopes: r.Scopes,
|
||||
Expiry: r.Expiry,
|
||||
}
|
||||
}
|
||||
|
||||
func toStorageDeviceToken(t *db.DeviceToken) storage.DeviceToken {
|
||||
return storage.DeviceToken{
|
||||
DeviceCode: t.DeviceCode,
|
||||
Status: t.Status,
|
||||
Token: string(*t.Token),
|
||||
Expiry: t.Expiry,
|
||||
LastRequestTime: t.LastRequest,
|
||||
PollIntervalSeconds: t.PollInterval,
|
||||
}
|
||||
}
|
44
storage/ent/client/utils.go
Normal file
44
storage/ent/client/utils.go
Normal file
@ -0,0 +1,44 @@
|
||||
package client
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"hash"
|
||||
|
||||
"github.com/pkg/errors"
|
||||
|
||||
"github.com/dexidp/dex/storage"
|
||||
"github.com/dexidp/dex/storage/ent/db"
|
||||
)
|
||||
|
||||
func rollback(tx *db.Tx, t string, err error) error {
|
||||
rerr := tx.Rollback()
|
||||
err = convertDBError(t, err)
|
||||
|
||||
if rerr == nil {
|
||||
return err
|
||||
}
|
||||
return errors.Wrapf(err, "rolling back transaction: %v", rerr)
|
||||
}
|
||||
|
||||
func convertDBError(t string, err error) error {
|
||||
if db.IsNotFound(err) {
|
||||
return storage.ErrNotFound
|
||||
}
|
||||
|
||||
if db.IsConstraintError(err) {
|
||||
return storage.ErrAlreadyExists
|
||||
}
|
||||
|
||||
return fmt.Errorf(t, err)
|
||||
}
|
||||
|
||||
// compose hashed id from user and connection id to use it as primary key
|
||||
// ent doesn't support multi-key primary yet
|
||||
// https://github.com/facebook/ent/issues/400
|
||||
func offlineSessionID(userID string, connID string, hasher func() hash.Hash) string {
|
||||
h := hasher()
|
||||
|
||||
h.Write([]byte(userID))
|
||||
h.Write([]byte(connID))
|
||||
return fmt.Sprintf("%x", h.Sum(nil))
|
||||
}
|
3
storage/ent/generate.go
Normal file
3
storage/ent/generate.go
Normal file
@ -0,0 +1,3 @@
|
||||
package ent
|
||||
|
||||
//go:generate go run github.com/facebook/ent/cmd/entc generate ./schema --target ./db
|
89
storage/ent/schema/authcode.go
Normal file
89
storage/ent/schema/authcode.go
Normal file
@ -0,0 +1,89 @@
|
||||
package schema
|
||||
|
||||
import (
|
||||
"github.com/facebook/ent"
|
||||
"github.com/facebook/ent/schema/field"
|
||||
)
|
||||
|
||||
/* Original SQL table:
|
||||
create table auth_code
|
||||
(
|
||||
id text not null primary key,
|
||||
client_id text not null,
|
||||
scopes blob not null,
|
||||
nonce text not null,
|
||||
redirect_uri text not null,
|
||||
claims_user_id text not null,
|
||||
claims_username text not null,
|
||||
claims_email text not null,
|
||||
claims_email_verified integer not null,
|
||||
claims_groups blob not null,
|
||||
connector_id text not null,
|
||||
connector_data blob,
|
||||
expiry timestamp not null,
|
||||
claims_preferred_username text default '' not null,
|
||||
code_challenge text default '' not null,
|
||||
code_challenge_method text default '' not null
|
||||
);
|
||||
*/
|
||||
|
||||
// AuthCode holds the schema definition for the AuthCode entity.
|
||||
type AuthCode struct {
|
||||
ent.Schema
|
||||
}
|
||||
|
||||
// Fields of the AuthCode.
|
||||
func (AuthCode) Fields() []ent.Field {
|
||||
return []ent.Field{
|
||||
field.Text("id").
|
||||
SchemaType(textSchema).
|
||||
NotEmpty().
|
||||
Unique(),
|
||||
field.Text("client_id").
|
||||
SchemaType(textSchema).
|
||||
NotEmpty(),
|
||||
field.JSON("scopes", []string{}).
|
||||
Optional(),
|
||||
field.Text("nonce").
|
||||
SchemaType(textSchema).
|
||||
NotEmpty(),
|
||||
field.Text("redirect_uri").
|
||||
SchemaType(textSchema).
|
||||
NotEmpty(),
|
||||
|
||||
field.Text("claims_user_id").
|
||||
SchemaType(textSchema).
|
||||
NotEmpty(),
|
||||
field.Text("claims_username").
|
||||
SchemaType(textSchema).
|
||||
NotEmpty(),
|
||||
field.Text("claims_email").
|
||||
SchemaType(textSchema).
|
||||
NotEmpty(),
|
||||
field.Bool("claims_email_verified"),
|
||||
field.JSON("claims_groups", []string{}).
|
||||
Optional(),
|
||||
field.Text("claims_preferred_username").
|
||||
SchemaType(textSchema).
|
||||
Default(""),
|
||||
|
||||
field.Text("connector_id").
|
||||
SchemaType(textSchema).
|
||||
NotEmpty(),
|
||||
field.Bytes("connector_data").
|
||||
Nillable().
|
||||
Optional(),
|
||||
field.Time("expiry"),
|
||||
field.Text("code_challenge").
|
||||
SchemaType(textSchema).
|
||||
Default(""),
|
||||
field.Text("code_challenge_method").
|
||||
SchemaType(textSchema).
|
||||
Default(""),
|
||||
}
|
||||
}
|
||||
|
||||
// Edges of the AuthCode.
|
||||
func (AuthCode) Edges() []ent.Edge {
|
||||
return []ent.Edge{}
|
||||
}
|
94
storage/ent/schema/authrequest.go
Normal file
94
storage/ent/schema/authrequest.go
Normal file
@ -0,0 +1,94 @@
|
||||
package schema
|
||||
|
||||
import (
|
||||
"github.com/facebook/ent"
|
||||
"github.com/facebook/ent/schema/field"
|
||||
)
|
||||
|
||||
/* Original SQL table:
|
||||
create table auth_request
|
||||
(
|
||||
id text not null primary key,
|
||||
client_id text not null,
|
||||
response_types blob not null,
|
||||
scopes blob not null,
|
||||
redirect_uri text not null,
|
||||
nonce text not null,
|
||||
state text not null,
|
||||
force_approval_prompt integer not null,
|
||||
logged_in integer not null,
|
||||
claims_user_id text not null,
|
||||
claims_username text not null,
|
||||
claims_email text not null,
|
||||
claims_email_verified integer not null,
|
||||
claims_groups blob not null,
|
||||
connector_id text not null,
|
||||
connector_data blob,
|
||||
expiry timestamp not null,
|
||||
claims_preferred_username text default '' not null,
|
||||
code_challenge text default '' not null,
|
||||
code_challenge_method text default '' not null
|
||||
);
|
||||
*/
|
||||
|
||||
// AuthRequest holds the schema definition for the AuthRequest entity.
|
||||
type AuthRequest struct {
|
||||
ent.Schema
|
||||
}
|
||||
|
||||
// Fields of the AuthRequest.
|
||||
func (AuthRequest) Fields() []ent.Field {
|
||||
return []ent.Field{
|
||||
field.Text("id").
|
||||
SchemaType(textSchema).
|
||||
NotEmpty().
|
||||
Unique(),
|
||||
field.Text("client_id").
|
||||
SchemaType(textSchema),
|
||||
field.JSON("scopes", []string{}).
|
||||
Optional(),
|
||||
field.JSON("response_types", []string{}).
|
||||
Optional(),
|
||||
field.Text("redirect_uri").
|
||||
SchemaType(textSchema),
|
||||
field.Text("nonce").
|
||||
SchemaType(textSchema),
|
||||
field.Text("state").
|
||||
SchemaType(textSchema),
|
||||
|
||||
field.Bool("force_approval_prompt"),
|
||||
field.Bool("logged_in"),
|
||||
|
||||
field.Text("claims_user_id").
|
||||
SchemaType(textSchema),
|
||||
field.Text("claims_username").
|
||||
SchemaType(textSchema),
|
||||
field.Text("claims_email").
|
||||
SchemaType(textSchema),
|
||||
field.Bool("claims_email_verified"),
|
||||
field.JSON("claims_groups", []string{}).
|
||||
Optional(),
|
||||
field.Text("claims_preferred_username").
|
||||
SchemaType(textSchema).
|
||||
Default(""),
|
||||
|
||||
field.Text("connector_id").
|
||||
SchemaType(textSchema),
|
||||
field.Bytes("connector_data").
|
||||
Nillable().
|
||||
Optional(),
|
||||
field.Time("expiry"),
|
||||
|
||||
field.Text("code_challenge").
|
||||
SchemaType(textSchema).
|
||||
Default(""),
|
||||
field.Text("code_challenge_method").
|
||||
SchemaType(textSchema).
|
||||
Default(""),
|
||||
}
|
||||
}
|
||||
|
||||
// Edges of the AuthRequest.
|
||||
func (AuthRequest) Edges() []ent.Edge {
|
||||
return []ent.Edge{}
|
||||
}
|
53
storage/ent/schema/client.go
Normal file
53
storage/ent/schema/client.go
Normal file
@ -0,0 +1,53 @@
|
||||
package schema
|
||||
|
||||
import (
|
||||
"github.com/facebook/ent"
|
||||
"github.com/facebook/ent/schema/field"
|
||||
)
|
||||
|
||||
/* Original SQL table:
|
||||
create table client
|
||||
(
|
||||
id text not null primary key,
|
||||
secret text not null,
|
||||
redirect_uris blob not null,
|
||||
trusted_peers blob not null,
|
||||
public integer not null,
|
||||
name text not null,
|
||||
logo_url text not null
|
||||
);
|
||||
*/
|
||||
|
||||
// OAuth2Client holds the schema definition for the Client entity.
|
||||
type OAuth2Client struct {
|
||||
ent.Schema
|
||||
}
|
||||
|
||||
// Fields of the OAuth2Client.
|
||||
func (OAuth2Client) Fields() []ent.Field {
|
||||
return []ent.Field{
|
||||
field.Text("id").
|
||||
SchemaType(textSchema).
|
||||
NotEmpty().
|
||||
Unique(),
|
||||
field.Text("secret").
|
||||
SchemaType(textSchema).
|
||||
NotEmpty(),
|
||||
field.JSON("redirect_uris", []string{}).
|
||||
Optional(),
|
||||
field.JSON("trusted_peers", []string{}).
|
||||
Optional(),
|
||||
field.Bool("public"),
|
||||
field.Text("name").
|
||||
SchemaType(textSchema).
|
||||
NotEmpty(),
|
||||
field.Text("logo_url").
|
||||
SchemaType(textSchema).
|
||||
NotEmpty(),
|
||||
}
|
||||
}
|
||||
|
||||
// Edges of the OAuth2Client.
|
||||
func (OAuth2Client) Edges() []ent.Edge {
|
||||
return []ent.Edge{}
|
||||
}
|
46
storage/ent/schema/connector.go
Normal file
46
storage/ent/schema/connector.go
Normal file
@ -0,0 +1,46 @@
|
||||
package schema
|
||||
|
||||
import (
|
||||
"github.com/facebook/ent"
|
||||
"github.com/facebook/ent/schema/field"
|
||||
)
|
||||
|
||||
/* Original SQL table:
|
||||
create table connector
|
||||
(
|
||||
id text not null primary key,
|
||||
type text not null,
|
||||
name text not null,
|
||||
resource_version text not null,
|
||||
config blob
|
||||
);
|
||||
*/
|
||||
|
||||
// Connector holds the schema definition for the Client entity.
|
||||
type Connector struct {
|
||||
ent.Schema
|
||||
}
|
||||
|
||||
// Fields of the Connector.
|
||||
func (Connector) Fields() []ent.Field {
|
||||
return []ent.Field{
|
||||
field.Text("id").
|
||||
SchemaType(textSchema).
|
||||
NotEmpty().
|
||||
Unique(),
|
||||
field.Text("type").
|
||||
SchemaType(textSchema).
|
||||
NotEmpty(),
|
||||
field.Text("name").
|
||||
SchemaType(textSchema).
|
||||
NotEmpty(),
|
||||
field.Text("resource_version").
|
||||
SchemaType(textSchema),
|
||||
field.Bytes("config"),
|
||||
}
|
||||
}
|
||||
|
||||
// Edges of the Connector.
|
||||
func (Connector) Edges() []ent.Edge {
|
||||
return []ent.Edge{}
|
||||
}
|
50
storage/ent/schema/devicerequest.go
Normal file
50
storage/ent/schema/devicerequest.go
Normal file
@ -0,0 +1,50 @@
|
||||
package schema
|
||||
|
||||
import (
|
||||
"github.com/facebook/ent"
|
||||
"github.com/facebook/ent/schema/field"
|
||||
)
|
||||
|
||||
/* Original SQL table:
|
||||
create table device_request
|
||||
(
|
||||
user_code text not null primary key,
|
||||
device_code text not null,
|
||||
client_id text not null,
|
||||
client_secret text,
|
||||
scopes blob not null,
|
||||
expiry timestamp not null
|
||||
);
|
||||
*/
|
||||
|
||||
// DeviceRequest holds the schema definition for the DeviceRequest entity.
|
||||
type DeviceRequest struct {
|
||||
ent.Schema
|
||||
}
|
||||
|
||||
// Fields of the DeviceRequest.
|
||||
func (DeviceRequest) Fields() []ent.Field {
|
||||
return []ent.Field{
|
||||
field.Text("user_code").
|
||||
SchemaType(textSchema).
|
||||
NotEmpty().
|
||||
Unique(),
|
||||
field.Text("device_code").
|
||||
SchemaType(textSchema).
|
||||
NotEmpty(),
|
||||
field.Text("client_id").
|
||||
SchemaType(textSchema).
|
||||
NotEmpty(),
|
||||
field.Text("client_secret").
|
||||
SchemaType(textSchema).
|
||||
NotEmpty(),
|
||||
field.JSON("scopes", []string{}).
|
||||
Optional(),
|
||||
field.Time("expiry"),
|
||||
}
|
||||
}
|
||||
|
||||
// Edges of the DeviceRequest.
|
||||
func (DeviceRequest) Edges() []ent.Edge {
|
||||
return []ent.Edge{}
|
||||
}
|
45
storage/ent/schema/devicetoken.go
Normal file
45
storage/ent/schema/devicetoken.go
Normal file
@ -0,0 +1,45 @@
|
||||
package schema
|
||||
|
||||
import (
|
||||
"github.com/facebook/ent"
|
||||
"github.com/facebook/ent/schema/field"
|
||||
)
|
||||
|
||||
/* Original SQL table:
|
||||
create table device_token
|
||||
(
|
||||
device_code text not null primary key,
|
||||
status text not null,
|
||||
token blob,
|
||||
expiry timestamp not null,
|
||||
last_request timestamp not null,
|
||||
poll_interval integer not null
|
||||
);
|
||||
*/
|
||||
|
||||
// DeviceToken holds the schema definition for the DeviceToken entity.
|
||||
type DeviceToken struct {
|
||||
ent.Schema
|
||||
}
|
||||
|
||||
// Fields of the DeviceToken.
|
||||
func (DeviceToken) Fields() []ent.Field {
|
||||
return []ent.Field{
|
||||
field.Text("device_code").
|
||||
SchemaType(textSchema).
|
||||
NotEmpty().
|
||||
Unique(),
|
||||
field.Text("status").
|
||||
SchemaType(textSchema).
|
||||
NotEmpty(),
|
||||
field.Bytes("token").Nillable().Optional(),
|
||||
field.Time("expiry"),
|
||||
field.Time("last_request"),
|
||||
field.Int("poll_interval"),
|
||||
}
|
||||
}
|
||||
|
||||
// Edges of the DeviceToken.
|
||||
func (DeviceToken) Edges() []ent.Edge {
|
||||
return []ent.Edge{}
|
||||
}
|
44
storage/ent/schema/keys.go
Normal file
44
storage/ent/schema/keys.go
Normal file
@ -0,0 +1,44 @@
|
||||
package schema
|
||||
|
||||
import (
|
||||
"github.com/facebook/ent"
|
||||
"github.com/facebook/ent/schema/field"
|
||||
"gopkg.in/square/go-jose.v2"
|
||||
|
||||
"github.com/dexidp/dex/storage"
|
||||
)
|
||||
|
||||
/* Original SQL table:
|
||||
create table keys
|
||||
(
|
||||
id text not null primary key,
|
||||
verification_keys blob not null,
|
||||
signing_key blob not null,
|
||||
signing_key_pub blob not null,
|
||||
next_rotation timestamp not null
|
||||
);
|
||||
*/
|
||||
|
||||
// Keys holds the schema definition for the Keys entity.
|
||||
type Keys struct {
|
||||
ent.Schema
|
||||
}
|
||||
|
||||
// Fields of the Keys.
|
||||
func (Keys) Fields() []ent.Field {
|
||||
return []ent.Field{
|
||||
field.Text("id").
|
||||
SchemaType(textSchema).
|
||||
NotEmpty().
|
||||
Unique(),
|
||||
field.JSON("verification_keys", []storage.VerificationKey{}),
|
||||
field.JSON("signing_key", jose.JSONWebKey{}),
|
||||
field.JSON("signing_key_pub", jose.JSONWebKey{}),
|
||||
field.Time("next_rotation"),
|
||||
}
|
||||
}
|
||||
|
||||
// Edges of the Keys.
|
||||
func (Keys) Edges() []ent.Edge {
|
||||
return []ent.Edge{}
|
||||
}
|
46
storage/ent/schema/offlinesession.go
Normal file
46
storage/ent/schema/offlinesession.go
Normal file
@ -0,0 +1,46 @@
|
||||
package schema
|
||||
|
||||
import (
|
||||
"github.com/facebook/ent"
|
||||
"github.com/facebook/ent/schema/field"
|
||||
)
|
||||
|
||||
/* Original SQL table:
|
||||
create table offline_session
|
||||
(
|
||||
user_id text not null,
|
||||
conn_id text not null,
|
||||
refresh blob not null,
|
||||
connector_data blob,
|
||||
primary key (user_id, conn_id)
|
||||
);
|
||||
*/
|
||||
|
||||
// OfflineSession holds the schema definition for the OfflineSession entity.
|
||||
type OfflineSession struct {
|
||||
ent.Schema
|
||||
}
|
||||
|
||||
// Fields of the OfflineSession.
|
||||
func (OfflineSession) Fields() []ent.Field {
|
||||
return []ent.Field{
|
||||
// Using id field here because it's impossible to create multi-key primary yet
|
||||
field.Text("id").
|
||||
SchemaType(textSchema).
|
||||
NotEmpty().
|
||||
Unique(),
|
||||
field.Text("user_id").
|
||||
SchemaType(textSchema).
|
||||
NotEmpty(),
|
||||
field.Text("conn_id").
|
||||
SchemaType(textSchema).
|
||||
NotEmpty(),
|
||||
field.Bytes("refresh"),
|
||||
field.Bytes("connector_data").Nillable().Optional(),
|
||||
}
|
||||
}
|
||||
|
||||
// Edges of the OfflineSession.
|
||||
func (OfflineSession) Edges() []ent.Edge {
|
||||
return []ent.Edge{}
|
||||
}
|
44
storage/ent/schema/password.go
Normal file
44
storage/ent/schema/password.go
Normal file
@ -0,0 +1,44 @@
|
||||
package schema
|
||||
|
||||
import (
|
||||
"github.com/facebook/ent"
|
||||
"github.com/facebook/ent/schema/field"
|
||||
)
|
||||
|
||||
/* Original SQL table:
|
||||
create table password
|
||||
(
|
||||
email text not null primary key,
|
||||
hash blob not null,
|
||||
username text not null,
|
||||
user_id text not null
|
||||
);
|
||||
*/
|
||||
|
||||
// Password holds the schema definition for the Password entity.
|
||||
type Password struct {
|
||||
ent.Schema
|
||||
}
|
||||
|
||||
// Fields of the Password.
|
||||
func (Password) Fields() []ent.Field {
|
||||
return []ent.Field{
|
||||
field.Text("email").
|
||||
SchemaType(textSchema).
|
||||
StorageKey("email"). // use email as ID field to make querying easier
|
||||
NotEmpty().
|
||||
Unique(),
|
||||
field.Bytes("hash"),
|
||||
field.Text("username").
|
||||
SchemaType(textSchema).
|
||||
NotEmpty(),
|
||||
field.Text("user_id").
|
||||
SchemaType(textSchema).
|
||||
NotEmpty(),
|
||||
}
|
||||
}
|
||||
|
||||
// Edges of the Password.
|
||||
func (Password) Edges() []ent.Edge {
|
||||
return []ent.Edge{}
|
||||
}
|
89
storage/ent/schema/refreshtoken.go
Normal file
89
storage/ent/schema/refreshtoken.go
Normal file
@ -0,0 +1,89 @@
|
||||
package schema
|
||||
|
||||
import (
|
||||
"time"
|
||||
|
||||
"github.com/facebook/ent"
|
||||
"github.com/facebook/ent/schema/field"
|
||||
)
|
||||
|
||||
/* Original SQL table:
|
||||
create table refresh_token
|
||||
(
|
||||
id text not null primary key,
|
||||
client_id text not null,
|
||||
scopes blob not null,
|
||||
nonce text not null,
|
||||
claims_user_id text not null,
|
||||
claims_username text not null,
|
||||
claims_email text not null,
|
||||
claims_email_verified integer not null,
|
||||
claims_groups blob not null,
|
||||
connector_id text not null,
|
||||
connector_data blob,
|
||||
token text default '' not null,
|
||||
created_at timestamp default '0001-01-01 00:00:00 UTC' not null,
|
||||
last_used timestamp default '0001-01-01 00:00:00 UTC' not null,
|
||||
claims_preferred_username text default '' not null
|
||||
);
|
||||
*/
|
||||
|
||||
// RefreshToken holds the schema definition for the RefreshToken entity.
|
||||
type RefreshToken struct {
|
||||
ent.Schema
|
||||
}
|
||||
|
||||
// Fields of the RefreshToken.
|
||||
func (RefreshToken) Fields() []ent.Field {
|
||||
return []ent.Field{
|
||||
field.Text("id").
|
||||
SchemaType(textSchema).
|
||||
NotEmpty().
|
||||
Unique(),
|
||||
field.Text("client_id").
|
||||
SchemaType(textSchema).
|
||||
NotEmpty(),
|
||||
field.JSON("scopes", []string{}).
|
||||
Optional(),
|
||||
field.Text("nonce").
|
||||
SchemaType(textSchema).
|
||||
NotEmpty(),
|
||||
|
||||
field.Text("claims_user_id").
|
||||
SchemaType(textSchema).
|
||||
NotEmpty(),
|
||||
field.Text("claims_username").
|
||||
SchemaType(textSchema).
|
||||
NotEmpty(),
|
||||
field.Text("claims_email").
|
||||
SchemaType(textSchema).
|
||||
NotEmpty(),
|
||||
field.Bool("claims_email_verified"),
|
||||
field.JSON("claims_groups", []string{}).
|
||||
Optional(),
|
||||
field.Text("claims_preferred_username").
|
||||
SchemaType(textSchema).
|
||||
Default(""),
|
||||
|
||||
field.Text("connector_id").
|
||||
SchemaType(textSchema).
|
||||
NotEmpty(),
|
||||
field.Bytes("connector_data").
|
||||
Nillable().
|
||||
Optional(),
|
||||
|
||||
field.Text("token").
|
||||
SchemaType(textSchema).
|
||||
Default(""),
|
||||
|
||||
field.Time("created_at").
|
||||
Default(time.Now),
|
||||
field.Time("last_used").
|
||||
Default(time.Now),
|
||||
}
|
||||
}
|
||||
|
||||
// Edges of the RefreshToken.
|
||||
func (RefreshToken) Edges() []ent.Edge {
|
||||
return []ent.Edge{}
|
||||
}
|
9
storage/ent/schema/types.go
Normal file
9
storage/ent/schema/types.go
Normal file
@ -0,0 +1,9 @@
|
||||
package schema
|
||||
|
||||
import (
|
||||
"github.com/facebook/ent/dialect"
|
||||
)
|
||||
|
||||
var textSchema = map[string]string{
|
||||
dialect.SQLite: "text",
|
||||
}
|
65
storage/ent/sqlite.go
Normal file
65
storage/ent/sqlite.go
Normal file
@ -0,0 +1,65 @@
|
||||
package ent
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/sha256"
|
||||
"strings"
|
||||
|
||||
"github.com/facebook/ent/dialect/sql"
|
||||
|
||||
// Register sqlite driver.
|
||||
_ "github.com/mattn/go-sqlite3"
|
||||
|
||||
"github.com/dexidp/dex/pkg/log"
|
||||
"github.com/dexidp/dex/storage"
|
||||
"github.com/dexidp/dex/storage/ent/client"
|
||||
"github.com/dexidp/dex/storage/ent/db"
|
||||
)
|
||||
|
||||
// SQLite3 options for creating an SQL db.
|
||||
type SQLite3 struct {
|
||||
File string `json:"file"`
|
||||
}
|
||||
|
||||
// Open always returns a new in sqlite3 storage.
|
||||
func (s *SQLite3) Open(logger log.Logger) (storage.Storage, error) {
|
||||
logger.Debug("experimental ent-based storage driver is enabled")
|
||||
|
||||
// Implicitly set foreign_keys pragma to "on" because it is required by ent
|
||||
s.File = addFK(s.File)
|
||||
|
||||
drv, err := sql.Open("sqlite3", s.File)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
pool := drv.DB()
|
||||
if s.File == ":memory:" {
|
||||
// sqlite3 uses file locks to coordinate concurrent access. In memory
|
||||
// doesn't support this, so limit the number of connections to 1.
|
||||
pool.SetMaxOpenConns(1)
|
||||
}
|
||||
|
||||
databaseClient := client.NewDatabase(
|
||||
client.WithClient(db.NewClient(db.Driver(drv))),
|
||||
client.WithHasher(sha256.New),
|
||||
)
|
||||
|
||||
if err := databaseClient.Schema().Create(context.TODO()); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return databaseClient, nil
|
||||
}
|
||||
|
||||
func addFK(dsn string) string {
|
||||
if strings.Contains(dsn, "_fk") {
|
||||
return dsn
|
||||
}
|
||||
|
||||
delim := "?"
|
||||
if strings.Contains(dsn, "?") {
|
||||
delim = "&"
|
||||
}
|
||||
return dsn + delim + "_fk=1"
|
||||
}
|
31
storage/ent/sqlite_test.go
Normal file
31
storage/ent/sqlite_test.go
Normal file
@ -0,0 +1,31 @@
|
||||
package ent
|
||||
|
||||
import (
|
||||
"os"
|
||||
"testing"
|
||||
|
||||
"github.com/sirupsen/logrus"
|
||||
|
||||
"github.com/dexidp/dex/storage"
|
||||
"github.com/dexidp/dex/storage/conformance"
|
||||
)
|
||||
|
||||
func newStorage() storage.Storage {
|
||||
logger := &logrus.Logger{
|
||||
Out: os.Stderr,
|
||||
Formatter: &logrus.TextFormatter{DisableColors: true},
|
||||
Level: logrus.DebugLevel,
|
||||
}
|
||||
|
||||
cfg := SQLite3{File: ":memory:"}
|
||||
s, err := cfg.Open(logger)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
return s
|
||||
}
|
||||
|
||||
func TestSQLite3(t *testing.T) {
|
||||
conformance.RunTests(t, newStorage)
|
||||
conformance.RunTransactionTests(t, newStorage)
|
||||
}
|
Reference in New Issue
Block a user