feat: Add ent-based sqlite3 storage
Signed-off-by: m.nabokikh <maksim.nabokikh@flant.com>
This commit is contained in:
		
							
								
								
									
										52
									
								
								storage/ent/client/authcode.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										52
									
								
								storage/ent/client/authcode.go
									
									
									
									
									
										Normal file
									
								
							@@ -0,0 +1,52 @@
 | 
			
		||||
package client
 | 
			
		||||
 | 
			
		||||
import (
 | 
			
		||||
	"context"
 | 
			
		||||
 | 
			
		||||
	"github.com/dexidp/dex/storage"
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
// CreateAuthCode saves provided auth code into the database.
 | 
			
		||||
func (d *Database) CreateAuthCode(code storage.AuthCode) error {
 | 
			
		||||
	_, err := d.client.AuthCode.Create().
 | 
			
		||||
		SetID(code.ID).
 | 
			
		||||
		SetClientID(code.ClientID).
 | 
			
		||||
		SetScopes(code.Scopes).
 | 
			
		||||
		SetRedirectURI(code.RedirectURI).
 | 
			
		||||
		SetNonce(code.Nonce).
 | 
			
		||||
		SetClaimsUserID(code.Claims.UserID).
 | 
			
		||||
		SetClaimsEmail(code.Claims.Email).
 | 
			
		||||
		SetClaimsEmailVerified(code.Claims.EmailVerified).
 | 
			
		||||
		SetClaimsUsername(code.Claims.Username).
 | 
			
		||||
		SetClaimsPreferredUsername(code.Claims.PreferredUsername).
 | 
			
		||||
		SetClaimsGroups(code.Claims.Groups).
 | 
			
		||||
		SetCodeChallenge(code.PKCE.CodeChallenge).
 | 
			
		||||
		SetCodeChallengeMethod(code.PKCE.CodeChallengeMethod).
 | 
			
		||||
		// Save utc time into database because ent doesn't support comparing dates with different timezones
 | 
			
		||||
		SetExpiry(code.Expiry.UTC()).
 | 
			
		||||
		SetConnectorID(code.ConnectorID).
 | 
			
		||||
		SetConnectorData(code.ConnectorData).
 | 
			
		||||
		Save(context.TODO())
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return convertDBError("create auth code: %w", err)
 | 
			
		||||
	}
 | 
			
		||||
	return nil
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// GetAuthCode extracts an auth code from the database by id.
 | 
			
		||||
func (d *Database) GetAuthCode(id string) (storage.AuthCode, error) {
 | 
			
		||||
	authCode, err := d.client.AuthCode.Get(context.TODO(), id)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return storage.AuthCode{}, convertDBError("get auth code: %w", err)
 | 
			
		||||
	}
 | 
			
		||||
	return toStorageAuthCode(authCode), nil
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// DeleteAuthCode deletes an auth code from the database by id.
 | 
			
		||||
func (d *Database) DeleteAuthCode(id string) error {
 | 
			
		||||
	err := d.client.AuthCode.DeleteOneID(id).Exec(context.TODO())
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return convertDBError("delete auth code: %w", err)
 | 
			
		||||
	}
 | 
			
		||||
	return nil
 | 
			
		||||
}
 | 
			
		||||
							
								
								
									
										107
									
								
								storage/ent/client/authrequest.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										107
									
								
								storage/ent/client/authrequest.go
									
									
									
									
									
										Normal file
									
								
							@@ -0,0 +1,107 @@
 | 
			
		||||
package client
 | 
			
		||||
 | 
			
		||||
import (
 | 
			
		||||
	"context"
 | 
			
		||||
	"fmt"
 | 
			
		||||
 | 
			
		||||
	"github.com/dexidp/dex/storage"
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
// CreateAuthRequest saves provided auth request into the database.
 | 
			
		||||
func (d *Database) CreateAuthRequest(authRequest storage.AuthRequest) error {
 | 
			
		||||
	_, err := d.client.AuthRequest.Create().
 | 
			
		||||
		SetID(authRequest.ID).
 | 
			
		||||
		SetClientID(authRequest.ClientID).
 | 
			
		||||
		SetScopes(authRequest.Scopes).
 | 
			
		||||
		SetResponseTypes(authRequest.ResponseTypes).
 | 
			
		||||
		SetRedirectURI(authRequest.RedirectURI).
 | 
			
		||||
		SetState(authRequest.State).
 | 
			
		||||
		SetNonce(authRequest.Nonce).
 | 
			
		||||
		SetForceApprovalPrompt(authRequest.ForceApprovalPrompt).
 | 
			
		||||
		SetLoggedIn(authRequest.LoggedIn).
 | 
			
		||||
		SetClaimsUserID(authRequest.Claims.UserID).
 | 
			
		||||
		SetClaimsEmail(authRequest.Claims.Email).
 | 
			
		||||
		SetClaimsEmailVerified(authRequest.Claims.EmailVerified).
 | 
			
		||||
		SetClaimsUsername(authRequest.Claims.Username).
 | 
			
		||||
		SetClaimsPreferredUsername(authRequest.Claims.PreferredUsername).
 | 
			
		||||
		SetClaimsGroups(authRequest.Claims.Groups).
 | 
			
		||||
		SetCodeChallenge(authRequest.PKCE.CodeChallenge).
 | 
			
		||||
		SetCodeChallengeMethod(authRequest.PKCE.CodeChallengeMethod).
 | 
			
		||||
		// Save utc time into database because ent doesn't support comparing dates with different timezones
 | 
			
		||||
		SetExpiry(authRequest.Expiry.UTC()).
 | 
			
		||||
		SetConnectorID(authRequest.ConnectorID).
 | 
			
		||||
		SetConnectorData(authRequest.ConnectorData).
 | 
			
		||||
		Save(context.TODO())
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return convertDBError("create auth request: %w", err)
 | 
			
		||||
	}
 | 
			
		||||
	return nil
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// GetAuthRequest extracts an auth request from the database by id.
 | 
			
		||||
func (d *Database) GetAuthRequest(id string) (storage.AuthRequest, error) {
 | 
			
		||||
	authRequest, err := d.client.AuthRequest.Get(context.TODO(), id)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return storage.AuthRequest{}, convertDBError("get auth request: %w", err)
 | 
			
		||||
	}
 | 
			
		||||
	return toStorageAuthRequest(authRequest), nil
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// DeleteAuthRequest deletes an auth request from the database by id.
 | 
			
		||||
func (d *Database) DeleteAuthRequest(id string) error {
 | 
			
		||||
	err := d.client.AuthRequest.DeleteOneID(id).Exec(context.TODO())
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return convertDBError("delete auth request: %w", err)
 | 
			
		||||
	}
 | 
			
		||||
	return nil
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// UpdateAuthRequest changes an auth request by id using an updater function and saves it to the database.
 | 
			
		||||
func (d *Database) UpdateAuthRequest(id string, updater func(old storage.AuthRequest) (storage.AuthRequest, error)) error {
 | 
			
		||||
	tx, err := d.client.Tx(context.TODO())
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return fmt.Errorf("update auth request tx: %w", err)
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	authRequest, err := tx.AuthRequest.Get(context.TODO(), id)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return rollback(tx, "update auth request database: %w", err)
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	newAuthRequest, err := updater(toStorageAuthRequest(authRequest))
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return rollback(tx, "update auth request updating: %w", err)
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	_, err = tx.AuthRequest.UpdateOneID(newAuthRequest.ID).
 | 
			
		||||
		SetClientID(newAuthRequest.ClientID).
 | 
			
		||||
		SetScopes(newAuthRequest.Scopes).
 | 
			
		||||
		SetResponseTypes(newAuthRequest.ResponseTypes).
 | 
			
		||||
		SetRedirectURI(newAuthRequest.RedirectURI).
 | 
			
		||||
		SetState(newAuthRequest.State).
 | 
			
		||||
		SetNonce(newAuthRequest.Nonce).
 | 
			
		||||
		SetForceApprovalPrompt(newAuthRequest.ForceApprovalPrompt).
 | 
			
		||||
		SetLoggedIn(newAuthRequest.LoggedIn).
 | 
			
		||||
		SetClaimsUserID(newAuthRequest.Claims.UserID).
 | 
			
		||||
		SetClaimsEmail(newAuthRequest.Claims.Email).
 | 
			
		||||
		SetClaimsEmailVerified(newAuthRequest.Claims.EmailVerified).
 | 
			
		||||
		SetClaimsUsername(newAuthRequest.Claims.Username).
 | 
			
		||||
		SetClaimsPreferredUsername(newAuthRequest.Claims.PreferredUsername).
 | 
			
		||||
		SetClaimsGroups(newAuthRequest.Claims.Groups).
 | 
			
		||||
		SetCodeChallenge(newAuthRequest.PKCE.CodeChallenge).
 | 
			
		||||
		SetCodeChallengeMethod(newAuthRequest.PKCE.CodeChallengeMethod).
 | 
			
		||||
		// Save utc time into database because ent doesn't support comparing dates with different timezones
 | 
			
		||||
		SetExpiry(newAuthRequest.Expiry.UTC()).
 | 
			
		||||
		SetConnectorID(newAuthRequest.ConnectorID).
 | 
			
		||||
		SetConnectorData(newAuthRequest.ConnectorData).
 | 
			
		||||
		Save(context.TODO())
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return rollback(tx, "update auth request uploading: %w", err)
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	if err = tx.Commit(); err != nil {
 | 
			
		||||
		return rollback(tx, "update auth request commit: %w", err)
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	return nil
 | 
			
		||||
}
 | 
			
		||||
							
								
								
									
										92
									
								
								storage/ent/client/client.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										92
									
								
								storage/ent/client/client.go
									
									
									
									
									
										Normal file
									
								
							@@ -0,0 +1,92 @@
 | 
			
		||||
package client
 | 
			
		||||
 | 
			
		||||
import (
 | 
			
		||||
	"context"
 | 
			
		||||
 | 
			
		||||
	"github.com/dexidp/dex/storage"
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
// CreateClient saves provided oauth2 client settings into the database.
 | 
			
		||||
func (d *Database) CreateClient(client storage.Client) error {
 | 
			
		||||
	_, err := d.client.OAuth2Client.Create().
 | 
			
		||||
		SetID(client.ID).
 | 
			
		||||
		SetName(client.Name).
 | 
			
		||||
		SetSecret(client.Secret).
 | 
			
		||||
		SetPublic(client.Public).
 | 
			
		||||
		SetLogoURL(client.LogoURL).
 | 
			
		||||
		SetRedirectUris(client.RedirectURIs).
 | 
			
		||||
		SetTrustedPeers(client.TrustedPeers).
 | 
			
		||||
		Save(context.TODO())
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return convertDBError("create oauth2 client: %w", err)
 | 
			
		||||
	}
 | 
			
		||||
	return nil
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// ListClients extracts an array of oauth2 clients from the database.
 | 
			
		||||
func (d *Database) ListClients() ([]storage.Client, error) {
 | 
			
		||||
	clients, err := d.client.OAuth2Client.Query().All(context.TODO())
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return nil, convertDBError("list clients: %w", err)
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	storageClients := make([]storage.Client, 0, len(clients))
 | 
			
		||||
	for _, c := range clients {
 | 
			
		||||
		storageClients = append(storageClients, toStorageClient(c))
 | 
			
		||||
	}
 | 
			
		||||
	return storageClients, nil
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// GetClient extracts an oauth2 client from the database by id.
 | 
			
		||||
func (d *Database) GetClient(id string) (storage.Client, error) {
 | 
			
		||||
	client, err := d.client.OAuth2Client.Get(context.TODO(), id)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return storage.Client{}, convertDBError("get client: %w", err)
 | 
			
		||||
	}
 | 
			
		||||
	return toStorageClient(client), nil
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// DeleteClient deletes an oauth2 client from the database by id.
 | 
			
		||||
func (d *Database) DeleteClient(id string) error {
 | 
			
		||||
	err := d.client.OAuth2Client.DeleteOneID(id).Exec(context.TODO())
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return convertDBError("delete client: %w", err)
 | 
			
		||||
	}
 | 
			
		||||
	return nil
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// UpdateClient changes an oauth2 client by id using an updater function and saves it to the database.
 | 
			
		||||
func (d *Database) UpdateClient(id string, updater func(old storage.Client) (storage.Client, error)) error {
 | 
			
		||||
	tx, err := d.client.Tx(context.TODO())
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return convertDBError("update client tx: %w", err)
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	client, err := tx.OAuth2Client.Get(context.TODO(), id)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return rollback(tx, "update client database: %w", err)
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	newClient, err := updater(toStorageClient(client))
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return rollback(tx, "update client updating: %w", err)
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	_, err = tx.OAuth2Client.UpdateOneID(newClient.ID).
 | 
			
		||||
		SetName(newClient.Name).
 | 
			
		||||
		SetSecret(newClient.Secret).
 | 
			
		||||
		SetPublic(newClient.Public).
 | 
			
		||||
		SetLogoURL(newClient.LogoURL).
 | 
			
		||||
		SetRedirectUris(newClient.RedirectURIs).
 | 
			
		||||
		SetTrustedPeers(newClient.TrustedPeers).
 | 
			
		||||
		Save(context.TODO())
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return rollback(tx, "update client uploading: %w", err)
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	if err = tx.Commit(); err != nil {
 | 
			
		||||
		return rollback(tx, "update auth request commit: %w", err)
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	return nil
 | 
			
		||||
}
 | 
			
		||||
							
								
								
									
										88
									
								
								storage/ent/client/connector.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										88
									
								
								storage/ent/client/connector.go
									
									
									
									
									
										Normal file
									
								
							@@ -0,0 +1,88 @@
 | 
			
		||||
package client
 | 
			
		||||
 | 
			
		||||
import (
 | 
			
		||||
	"context"
 | 
			
		||||
 | 
			
		||||
	"github.com/dexidp/dex/storage"
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
// CreateConnector saves a connector into the database.
 | 
			
		||||
func (d *Database) CreateConnector(connector storage.Connector) error {
 | 
			
		||||
	_, err := d.client.Connector.Create().
 | 
			
		||||
		SetID(connector.ID).
 | 
			
		||||
		SetName(connector.Name).
 | 
			
		||||
		SetType(connector.Type).
 | 
			
		||||
		SetResourceVersion(connector.ResourceVersion).
 | 
			
		||||
		SetConfig(connector.Config).
 | 
			
		||||
		Save(context.TODO())
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return convertDBError("create connector: %w", err)
 | 
			
		||||
	}
 | 
			
		||||
	return nil
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// ListConnectors extracts an array of connectors from the database.
 | 
			
		||||
func (d *Database) ListConnectors() ([]storage.Connector, error) {
 | 
			
		||||
	connectors, err := d.client.Connector.Query().All(context.TODO())
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return nil, convertDBError("list connectors: %w", err)
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	storageConnectors := make([]storage.Connector, 0, len(connectors))
 | 
			
		||||
	for _, c := range connectors {
 | 
			
		||||
		storageConnectors = append(storageConnectors, toStorageConnector(c))
 | 
			
		||||
	}
 | 
			
		||||
	return storageConnectors, nil
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// GetConnector extracts a connector from the database by id.
 | 
			
		||||
func (d *Database) GetConnector(id string) (storage.Connector, error) {
 | 
			
		||||
	connector, err := d.client.Connector.Get(context.TODO(), id)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return storage.Connector{}, convertDBError("get connector: %w", err)
 | 
			
		||||
	}
 | 
			
		||||
	return toStorageConnector(connector), nil
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// DeleteConnector deletes a connector from the database by id.
 | 
			
		||||
func (d *Database) DeleteConnector(id string) error {
 | 
			
		||||
	err := d.client.Connector.DeleteOneID(id).Exec(context.TODO())
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return convertDBError("delete connector: %w", err)
 | 
			
		||||
	}
 | 
			
		||||
	return nil
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// UpdateConnector changes a connector by id using an updater function and saves it to the database.
 | 
			
		||||
func (d *Database) UpdateConnector(id string, updater func(old storage.Connector) (storage.Connector, error)) error {
 | 
			
		||||
	tx, err := d.client.Tx(context.TODO())
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return convertDBError("update connector tx: %w", err)
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	connector, err := tx.Connector.Get(context.TODO(), id)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return rollback(tx, "update connector database: %w", err)
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	newConnector, err := updater(toStorageConnector(connector))
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return rollback(tx, "update connector updating: %w", err)
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	_, err = tx.Connector.UpdateOneID(newConnector.ID).
 | 
			
		||||
		SetName(newConnector.Name).
 | 
			
		||||
		SetType(newConnector.Type).
 | 
			
		||||
		SetResourceVersion(newConnector.ResourceVersion).
 | 
			
		||||
		SetConfig(newConnector.Config).
 | 
			
		||||
		Save(context.TODO())
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return rollback(tx, "update connector uploading: %w", err)
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	if err = tx.Commit(); err != nil {
 | 
			
		||||
		return rollback(tx, "update connector commit: %w", err)
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	return nil
 | 
			
		||||
}
 | 
			
		||||
							
								
								
									
										36
									
								
								storage/ent/client/devicerequest.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										36
									
								
								storage/ent/client/devicerequest.go
									
									
									
									
									
										Normal file
									
								
							@@ -0,0 +1,36 @@
 | 
			
		||||
package client
 | 
			
		||||
 | 
			
		||||
import (
 | 
			
		||||
	"context"
 | 
			
		||||
 | 
			
		||||
	"github.com/dexidp/dex/storage"
 | 
			
		||||
	"github.com/dexidp/dex/storage/ent/db/devicerequest"
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
// CreateDeviceRequest saves provided device request into the database.
 | 
			
		||||
func (d *Database) CreateDeviceRequest(request storage.DeviceRequest) error {
 | 
			
		||||
	_, err := d.client.DeviceRequest.Create().
 | 
			
		||||
		SetClientID(request.ClientID).
 | 
			
		||||
		SetClientSecret(request.ClientSecret).
 | 
			
		||||
		SetScopes(request.Scopes).
 | 
			
		||||
		SetUserCode(request.UserCode).
 | 
			
		||||
		SetDeviceCode(request.DeviceCode).
 | 
			
		||||
		// Save utc time into database because ent doesn't support comparing dates with different timezones
 | 
			
		||||
		SetExpiry(request.Expiry.UTC()).
 | 
			
		||||
		Save(context.TODO())
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return convertDBError("create device request: %w", err)
 | 
			
		||||
	}
 | 
			
		||||
	return nil
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// GetDeviceRequest extracts a device request from the database by user code.
 | 
			
		||||
func (d *Database) GetDeviceRequest(userCode string) (storage.DeviceRequest, error) {
 | 
			
		||||
	deviceRequest, err := d.client.DeviceRequest.Query().
 | 
			
		||||
		Where(devicerequest.UserCode(userCode)).
 | 
			
		||||
		Only(context.TODO())
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return storage.DeviceRequest{}, convertDBError("get device request: %w", err)
 | 
			
		||||
	}
 | 
			
		||||
	return toStorageDeviceRequest(deviceRequest), nil
 | 
			
		||||
}
 | 
			
		||||
							
								
								
									
										76
									
								
								storage/ent/client/devicetoken.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										76
									
								
								storage/ent/client/devicetoken.go
									
									
									
									
									
										Normal file
									
								
							@@ -0,0 +1,76 @@
 | 
			
		||||
package client
 | 
			
		||||
 | 
			
		||||
import (
 | 
			
		||||
	"context"
 | 
			
		||||
 | 
			
		||||
	"github.com/dexidp/dex/storage"
 | 
			
		||||
	"github.com/dexidp/dex/storage/ent/db/devicetoken"
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
// CreateDeviceToken saves provided token into the database.
 | 
			
		||||
func (d *Database) CreateDeviceToken(token storage.DeviceToken) error {
 | 
			
		||||
	_, err := d.client.DeviceToken.Create().
 | 
			
		||||
		SetDeviceCode(token.DeviceCode).
 | 
			
		||||
		SetToken([]byte(token.Token)).
 | 
			
		||||
		SetPollInterval(token.PollIntervalSeconds).
 | 
			
		||||
		// Save utc time into database because ent doesn't support comparing dates with different timezones
 | 
			
		||||
		SetExpiry(token.Expiry.UTC()).
 | 
			
		||||
		SetLastRequest(token.LastRequestTime.UTC()).
 | 
			
		||||
		SetStatus(token.Status).
 | 
			
		||||
		Save(context.TODO())
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return convertDBError("create device token: %w", err)
 | 
			
		||||
	}
 | 
			
		||||
	return nil
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// GetDeviceToken extracts a token from the database by device code.
 | 
			
		||||
func (d *Database) GetDeviceToken(deviceCode string) (storage.DeviceToken, error) {
 | 
			
		||||
	deviceToken, err := d.client.DeviceToken.Query().
 | 
			
		||||
		Where(devicetoken.DeviceCode(deviceCode)).
 | 
			
		||||
		Only(context.TODO())
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return storage.DeviceToken{}, convertDBError("get device token: %w", err)
 | 
			
		||||
	}
 | 
			
		||||
	return toStorageDeviceToken(deviceToken), nil
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// UpdateDeviceToken changes a token by device code using an updater function and saves it to the database.
 | 
			
		||||
func (d *Database) UpdateDeviceToken(deviceCode string, updater func(old storage.DeviceToken) (storage.DeviceToken, error)) error {
 | 
			
		||||
	tx, err := d.client.Tx(context.TODO())
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return convertDBError("update device token tx: %w", err)
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	token, err := tx.DeviceToken.Query().
 | 
			
		||||
		Where(devicetoken.DeviceCode(deviceCode)).
 | 
			
		||||
		Only(context.TODO())
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return rollback(tx, "update device token database: %w", err)
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	newToken, err := updater(toStorageDeviceToken(token))
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return rollback(tx, "update device token updating: %w", err)
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	_, err = tx.DeviceToken.Update().
 | 
			
		||||
		Where(devicetoken.DeviceCode(newToken.DeviceCode)).
 | 
			
		||||
		SetDeviceCode(newToken.DeviceCode).
 | 
			
		||||
		SetToken([]byte(newToken.Token)).
 | 
			
		||||
		SetPollInterval(newToken.PollIntervalSeconds).
 | 
			
		||||
		// Save utc time into database because ent doesn't support comparing dates with different timezones
 | 
			
		||||
		SetExpiry(newToken.Expiry.UTC()).
 | 
			
		||||
		SetLastRequest(newToken.LastRequestTime.UTC()).
 | 
			
		||||
		SetStatus(newToken.Status).
 | 
			
		||||
		Save(context.TODO())
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return rollback(tx, "update device token uploading: %w", err)
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	if err = tx.Commit(); err != nil {
 | 
			
		||||
		return rollback(tx, "update device token commit: %w", err)
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	return nil
 | 
			
		||||
}
 | 
			
		||||
							
								
								
									
										81
									
								
								storage/ent/client/keys.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										81
									
								
								storage/ent/client/keys.go
									
									
									
									
									
										Normal file
									
								
							@@ -0,0 +1,81 @@
 | 
			
		||||
package client
 | 
			
		||||
 | 
			
		||||
import (
 | 
			
		||||
	"context"
 | 
			
		||||
	"errors"
 | 
			
		||||
 | 
			
		||||
	"github.com/dexidp/dex/storage"
 | 
			
		||||
	"github.com/dexidp/dex/storage/ent/db"
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
func getKeys(client *db.KeysClient) (storage.Keys, error) {
 | 
			
		||||
	rawKeys, err := client.Get(context.TODO(), keysRowID)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return storage.Keys{}, convertDBError("get keys: %w", err)
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	return toStorageKeys(rawKeys), nil
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// GetKeys returns signing keys, public keys and verification keys from the database.
 | 
			
		||||
func (d *Database) GetKeys() (storage.Keys, error) {
 | 
			
		||||
	return getKeys(d.client.Keys)
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// UpdateKeys rotates keys using updater function.
 | 
			
		||||
func (d *Database) UpdateKeys(updater func(old storage.Keys) (storage.Keys, error)) error {
 | 
			
		||||
	firstUpdate := false
 | 
			
		||||
 | 
			
		||||
	tx, err := d.client.Tx(context.TODO())
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return convertDBError("update keys tx: %w", err)
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	storageKeys, err := getKeys(tx.Keys)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		if !errors.Is(err, storage.ErrNotFound) {
 | 
			
		||||
			return rollback(tx, "update keys get: %w", err)
 | 
			
		||||
		}
 | 
			
		||||
		firstUpdate = true
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	newKeys, err := updater(storageKeys)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return rollback(tx, "update keys updating: %w", err)
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	// ent doesn't have an upsert support yet
 | 
			
		||||
	// https://github.com/facebook/ent/issues/139
 | 
			
		||||
	if firstUpdate {
 | 
			
		||||
		_, err = tx.Keys.Create().
 | 
			
		||||
			SetID(keysRowID).
 | 
			
		||||
			SetNextRotation(newKeys.NextRotation).
 | 
			
		||||
			SetSigningKey(*newKeys.SigningKey).
 | 
			
		||||
			SetSigningKeyPub(*newKeys.SigningKeyPub).
 | 
			
		||||
			SetVerificationKeys(newKeys.VerificationKeys).
 | 
			
		||||
			Save(context.TODO())
 | 
			
		||||
		if err != nil {
 | 
			
		||||
			return rollback(tx, "create keys: %w", err)
 | 
			
		||||
		}
 | 
			
		||||
		if err = tx.Commit(); err != nil {
 | 
			
		||||
			return rollback(tx, "update keys commit: %w", err)
 | 
			
		||||
		}
 | 
			
		||||
		return nil
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	err = tx.Keys.UpdateOneID(keysRowID).
 | 
			
		||||
		SetNextRotation(newKeys.NextRotation.UTC()).
 | 
			
		||||
		SetSigningKey(*newKeys.SigningKey).
 | 
			
		||||
		SetSigningKeyPub(*newKeys.SigningKeyPub).
 | 
			
		||||
		SetVerificationKeys(newKeys.VerificationKeys).
 | 
			
		||||
		Exec(context.TODO())
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return rollback(tx, "update keys uploading: %w", err)
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	if err = tx.Commit(); err != nil {
 | 
			
		||||
		return rollback(tx, "update keys commit: %w", err)
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	return nil
 | 
			
		||||
}
 | 
			
		||||
							
								
								
									
										95
									
								
								storage/ent/client/main.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										95
									
								
								storage/ent/client/main.go
									
									
									
									
									
										Normal file
									
								
							@@ -0,0 +1,95 @@
 | 
			
		||||
package client
 | 
			
		||||
 | 
			
		||||
import (
 | 
			
		||||
	"context"
 | 
			
		||||
	"hash"
 | 
			
		||||
	"time"
 | 
			
		||||
 | 
			
		||||
	"github.com/dexidp/dex/storage"
 | 
			
		||||
	"github.com/dexidp/dex/storage/ent/db"
 | 
			
		||||
	"github.com/dexidp/dex/storage/ent/db/authcode"
 | 
			
		||||
	"github.com/dexidp/dex/storage/ent/db/authrequest"
 | 
			
		||||
	"github.com/dexidp/dex/storage/ent/db/devicerequest"
 | 
			
		||||
	"github.com/dexidp/dex/storage/ent/db/devicetoken"
 | 
			
		||||
	"github.com/dexidp/dex/storage/ent/db/migrate"
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
var _ storage.Storage = (*Database)(nil)
 | 
			
		||||
 | 
			
		||||
type Database struct {
 | 
			
		||||
	client *db.Client
 | 
			
		||||
	hasher func() hash.Hash
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// NewDatabase returns new database client with set options.
 | 
			
		||||
func NewDatabase(opts ...func(*Database)) *Database {
 | 
			
		||||
	database := &Database{}
 | 
			
		||||
	for _, f := range opts {
 | 
			
		||||
		f(database)
 | 
			
		||||
	}
 | 
			
		||||
	return database
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// WithClient sets client option of a Database object.
 | 
			
		||||
func WithClient(c *db.Client) func(*Database) {
 | 
			
		||||
	return func(s *Database) {
 | 
			
		||||
		s.client = c
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// WithHasher sets client option of a Database object.
 | 
			
		||||
func WithHasher(h func() hash.Hash) func(*Database) {
 | 
			
		||||
	return func(s *Database) {
 | 
			
		||||
		s.hasher = h
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// Schema exposes migration schema to perform migrations.
 | 
			
		||||
func (d *Database) Schema() *migrate.Schema {
 | 
			
		||||
	return d.client.Schema
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// Close calls the corresponding method of the ent database client.
 | 
			
		||||
func (d *Database) Close() error {
 | 
			
		||||
	return d.client.Close()
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// GarbageCollect removes expired entities from the database.
 | 
			
		||||
func (d *Database) GarbageCollect(now time.Time) (storage.GCResult, error) {
 | 
			
		||||
	result := storage.GCResult{}
 | 
			
		||||
	utcNow := now.UTC()
 | 
			
		||||
 | 
			
		||||
	q, err := d.client.AuthRequest.Delete().
 | 
			
		||||
		Where(authrequest.ExpiryLT(utcNow)).
 | 
			
		||||
		Exec(context.TODO())
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return result, convertDBError("gc auth request: %w", err)
 | 
			
		||||
	}
 | 
			
		||||
	result.AuthRequests = int64(q)
 | 
			
		||||
 | 
			
		||||
	q, err = d.client.AuthCode.Delete().
 | 
			
		||||
		Where(authcode.ExpiryLT(utcNow)).
 | 
			
		||||
		Exec(context.TODO())
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return result, convertDBError("gc auth code: %w", err)
 | 
			
		||||
	}
 | 
			
		||||
	result.AuthCodes = int64(q)
 | 
			
		||||
 | 
			
		||||
	q, err = d.client.DeviceRequest.Delete().
 | 
			
		||||
		Where(devicerequest.ExpiryLT(utcNow)).
 | 
			
		||||
		Exec(context.TODO())
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return result, convertDBError("gc device request: %w", err)
 | 
			
		||||
	}
 | 
			
		||||
	result.DeviceRequests = int64(q)
 | 
			
		||||
 | 
			
		||||
	q, err = d.client.DeviceToken.Delete().
 | 
			
		||||
		Where(devicetoken.ExpiryLT(utcNow)).
 | 
			
		||||
		Exec(context.TODO())
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return result, convertDBError("gc device token: %w", err)
 | 
			
		||||
	}
 | 
			
		||||
	result.DeviceTokens = int64(q)
 | 
			
		||||
 | 
			
		||||
	return result, err
 | 
			
		||||
}
 | 
			
		||||
							
								
								
									
										93
									
								
								storage/ent/client/offlinesession.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										93
									
								
								storage/ent/client/offlinesession.go
									
									
									
									
									
										Normal file
									
								
							@@ -0,0 +1,93 @@
 | 
			
		||||
package client
 | 
			
		||||
 | 
			
		||||
import (
 | 
			
		||||
	"context"
 | 
			
		||||
	"encoding/json"
 | 
			
		||||
	"fmt"
 | 
			
		||||
 | 
			
		||||
	"github.com/dexidp/dex/storage"
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
// CreateOfflineSessions saves provided offline session into the database.
 | 
			
		||||
func (d *Database) CreateOfflineSessions(session storage.OfflineSessions) error {
 | 
			
		||||
	encodedRefresh, err := json.Marshal(session.Refresh)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return fmt.Errorf("encode refresh offline session: %w", err)
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	id := offlineSessionID(session.UserID, session.ConnID, d.hasher)
 | 
			
		||||
	_, err = d.client.OfflineSession.Create().
 | 
			
		||||
		SetID(id).
 | 
			
		||||
		SetUserID(session.UserID).
 | 
			
		||||
		SetConnID(session.ConnID).
 | 
			
		||||
		SetConnectorData(session.ConnectorData).
 | 
			
		||||
		SetRefresh(encodedRefresh).
 | 
			
		||||
		Save(context.TODO())
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return convertDBError("create offline session: %w", err)
 | 
			
		||||
	}
 | 
			
		||||
	return nil
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// GetOfflineSessions extracts an offline session from the database by user id and connector id.
 | 
			
		||||
func (d *Database) GetOfflineSessions(userID, connID string) (storage.OfflineSessions, error) {
 | 
			
		||||
	id := offlineSessionID(userID, connID, d.hasher)
 | 
			
		||||
 | 
			
		||||
	offlineSession, err := d.client.OfflineSession.Get(context.TODO(), id)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return storage.OfflineSessions{}, convertDBError("get offline session: %w", err)
 | 
			
		||||
	}
 | 
			
		||||
	return toStorageOfflineSession(offlineSession), nil
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// DeleteOfflineSessions deletes an offline session from the database by user id and connector id.
 | 
			
		||||
func (d *Database) DeleteOfflineSessions(userID, connID string) error {
 | 
			
		||||
	id := offlineSessionID(userID, connID, d.hasher)
 | 
			
		||||
 | 
			
		||||
	err := d.client.OfflineSession.DeleteOneID(id).Exec(context.TODO())
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return convertDBError("delete offline session: %w", err)
 | 
			
		||||
	}
 | 
			
		||||
	return nil
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// UpdatePassword changes an offline session by user id and connector id using an updater function.
 | 
			
		||||
func (d *Database) UpdateOfflineSessions(userID string, connID string, updater func(s storage.OfflineSessions) (storage.OfflineSessions, error)) error {
 | 
			
		||||
	id := offlineSessionID(userID, connID, d.hasher)
 | 
			
		||||
 | 
			
		||||
	tx, err := d.client.Tx(context.TODO())
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return convertDBError("update offline session tx: %w", err)
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	offlineSession, err := tx.OfflineSession.Get(context.TODO(), id)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return rollback(tx, "update offline session database: %w", err)
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	newOfflineSession, err := updater(toStorageOfflineSession(offlineSession))
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return rollback(tx, "update offline session updating: %w", err)
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	encodedRefresh, err := json.Marshal(newOfflineSession.Refresh)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return rollback(tx, "encode refresh offline session: %w", err)
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	_, err = tx.OfflineSession.UpdateOneID(id).
 | 
			
		||||
		SetUserID(newOfflineSession.UserID).
 | 
			
		||||
		SetConnID(newOfflineSession.ConnID).
 | 
			
		||||
		SetConnectorData(newOfflineSession.ConnectorData).
 | 
			
		||||
		SetRefresh(encodedRefresh).
 | 
			
		||||
		Save(context.TODO())
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return rollback(tx, "update offline session uploading: %w", err)
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	if err = tx.Commit(); err != nil {
 | 
			
		||||
		return rollback(tx, "update password commit: %w", err)
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	return nil
 | 
			
		||||
}
 | 
			
		||||
							
								
								
									
										100
									
								
								storage/ent/client/password.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										100
									
								
								storage/ent/client/password.go
									
									
									
									
									
										Normal file
									
								
							@@ -0,0 +1,100 @@
 | 
			
		||||
package client
 | 
			
		||||
 | 
			
		||||
import (
 | 
			
		||||
	"context"
 | 
			
		||||
	"strings"
 | 
			
		||||
 | 
			
		||||
	"github.com/dexidp/dex/storage"
 | 
			
		||||
	"github.com/dexidp/dex/storage/ent/db/password"
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
// CreatePassword saves provided password into the database.
 | 
			
		||||
func (d *Database) CreatePassword(password storage.Password) error {
 | 
			
		||||
	_, err := d.client.Password.Create().
 | 
			
		||||
		SetEmail(password.Email).
 | 
			
		||||
		SetHash(password.Hash).
 | 
			
		||||
		SetUsername(password.Username).
 | 
			
		||||
		SetUserID(password.UserID).
 | 
			
		||||
		Save(context.TODO())
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return convertDBError("create password: %w", err)
 | 
			
		||||
	}
 | 
			
		||||
	return nil
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// ListPasswords extracts an array of passwords from the database.
 | 
			
		||||
func (d *Database) ListPasswords() ([]storage.Password, error) {
 | 
			
		||||
	passwords, err := d.client.Password.Query().All(context.TODO())
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return nil, convertDBError("list passwords: %w", err)
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	storagePasswords := make([]storage.Password, 0, len(passwords))
 | 
			
		||||
	for _, p := range passwords {
 | 
			
		||||
		storagePasswords = append(storagePasswords, toStoragePassword(p))
 | 
			
		||||
	}
 | 
			
		||||
	return storagePasswords, nil
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// GetPassword extracts a password from the database by email.
 | 
			
		||||
func (d *Database) GetPassword(email string) (storage.Password, error) {
 | 
			
		||||
	email = strings.ToLower(email)
 | 
			
		||||
	passwordFromStorage, err := d.client.Password.Query().
 | 
			
		||||
		Where(password.Email(email)).
 | 
			
		||||
		Only(context.TODO())
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return storage.Password{}, convertDBError("get password: %w", err)
 | 
			
		||||
	}
 | 
			
		||||
	return toStoragePassword(passwordFromStorage), nil
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// DeletePassword deletes a password from the database by email.
 | 
			
		||||
func (d *Database) DeletePassword(email string) error {
 | 
			
		||||
	email = strings.ToLower(email)
 | 
			
		||||
	_, err := d.client.Password.Delete().
 | 
			
		||||
		Where(password.Email(email)).
 | 
			
		||||
		Exec(context.TODO())
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return convertDBError("delete password: %w", err)
 | 
			
		||||
	}
 | 
			
		||||
	return nil
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// UpdatePassword changes a password by email using an updater function and saves it to the database.
 | 
			
		||||
func (d *Database) UpdatePassword(email string, updater func(old storage.Password) (storage.Password, error)) error {
 | 
			
		||||
	email = strings.ToLower(email)
 | 
			
		||||
 | 
			
		||||
	tx, err := d.client.Tx(context.TODO())
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return convertDBError("update connector tx: %w", err)
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	passwordToUpdate, err := tx.Password.Query().
 | 
			
		||||
		Where(password.Email(email)).
 | 
			
		||||
		Only(context.TODO())
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return rollback(tx, "update password database: %w", err)
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	newPassword, err := updater(toStoragePassword(passwordToUpdate))
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return rollback(tx, "update password updating: %w", err)
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	_, err = tx.Password.Update().
 | 
			
		||||
		Where(password.Email(newPassword.Email)).
 | 
			
		||||
		SetEmail(newPassword.Email).
 | 
			
		||||
		SetHash(newPassword.Hash).
 | 
			
		||||
		SetUsername(newPassword.Username).
 | 
			
		||||
		SetUserID(newPassword.UserID).
 | 
			
		||||
		Save(context.TODO())
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return rollback(tx, "update password uploading: %w", err)
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	if err = tx.Commit(); err != nil {
 | 
			
		||||
		return rollback(tx, "update password commit: %w", err)
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	return nil
 | 
			
		||||
}
 | 
			
		||||
							
								
								
									
										109
									
								
								storage/ent/client/refreshtoken.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										109
									
								
								storage/ent/client/refreshtoken.go
									
									
									
									
									
										Normal file
									
								
							@@ -0,0 +1,109 @@
 | 
			
		||||
package client
 | 
			
		||||
 | 
			
		||||
import (
 | 
			
		||||
	"context"
 | 
			
		||||
 | 
			
		||||
	"github.com/dexidp/dex/storage"
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
// CreateRefresh saves provided refresh token into the database.
 | 
			
		||||
func (d *Database) CreateRefresh(refresh storage.RefreshToken) error {
 | 
			
		||||
	_, err := d.client.RefreshToken.Create().
 | 
			
		||||
		SetID(refresh.ID).
 | 
			
		||||
		SetClientID(refresh.ClientID).
 | 
			
		||||
		SetScopes(refresh.Scopes).
 | 
			
		||||
		SetNonce(refresh.Nonce).
 | 
			
		||||
		SetClaimsUserID(refresh.Claims.UserID).
 | 
			
		||||
		SetClaimsEmail(refresh.Claims.Email).
 | 
			
		||||
		SetClaimsEmailVerified(refresh.Claims.EmailVerified).
 | 
			
		||||
		SetClaimsUsername(refresh.Claims.Username).
 | 
			
		||||
		SetClaimsPreferredUsername(refresh.Claims.PreferredUsername).
 | 
			
		||||
		SetClaimsGroups(refresh.Claims.Groups).
 | 
			
		||||
		SetConnectorID(refresh.ConnectorID).
 | 
			
		||||
		SetConnectorData(refresh.ConnectorData).
 | 
			
		||||
		SetToken(refresh.Token).
 | 
			
		||||
		// Save utc time into database because ent doesn't support comparing dates with different timezones
 | 
			
		||||
		SetLastUsed(refresh.LastUsed.UTC()).
 | 
			
		||||
		SetCreatedAt(refresh.CreatedAt.UTC()).
 | 
			
		||||
		Save(context.TODO())
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return convertDBError("create refresh token: %w", err)
 | 
			
		||||
	}
 | 
			
		||||
	return nil
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// ListRefreshTokens extracts an array of refresh tokens from the database.
 | 
			
		||||
func (d *Database) ListRefreshTokens() ([]storage.RefreshToken, error) {
 | 
			
		||||
	refreshTokens, err := d.client.RefreshToken.Query().All(context.TODO())
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return nil, convertDBError("list refresh tokens: %w", err)
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	storageRefreshTokens := make([]storage.RefreshToken, 0, len(refreshTokens))
 | 
			
		||||
	for _, r := range refreshTokens {
 | 
			
		||||
		storageRefreshTokens = append(storageRefreshTokens, toStorageRefreshToken(r))
 | 
			
		||||
	}
 | 
			
		||||
	return storageRefreshTokens, nil
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// GetRefresh extracts a refresh token from the database by id.
 | 
			
		||||
func (d *Database) GetRefresh(id string) (storage.RefreshToken, error) {
 | 
			
		||||
	refreshToken, err := d.client.RefreshToken.Get(context.TODO(), id)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return storage.RefreshToken{}, convertDBError("get refresh token: %w", err)
 | 
			
		||||
	}
 | 
			
		||||
	return toStorageRefreshToken(refreshToken), nil
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// DeleteRefresh deletes a refresh token from the database by id.
 | 
			
		||||
func (d *Database) DeleteRefresh(id string) error {
 | 
			
		||||
	err := d.client.RefreshToken.DeleteOneID(id).Exec(context.TODO())
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return convertDBError("delete refresh token: %w", err)
 | 
			
		||||
	}
 | 
			
		||||
	return nil
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// UpdateRefreshToken changes a refresh token by id using an updater function and saves it to the database.
 | 
			
		||||
func (d *Database) UpdateRefreshToken(id string, updater func(old storage.RefreshToken) (storage.RefreshToken, error)) error {
 | 
			
		||||
	tx, err := d.client.Tx(context.TODO())
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return convertDBError("update refresh token tx: %w", err)
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	token, err := tx.RefreshToken.Get(context.TODO(), id)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return rollback(tx, "update refresh token database: %w", err)
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	newtToken, err := updater(toStorageRefreshToken(token))
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return rollback(tx, "update refresh token updating: %w", err)
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	_, err = tx.RefreshToken.UpdateOneID(newtToken.ID).
 | 
			
		||||
		SetClientID(newtToken.ClientID).
 | 
			
		||||
		SetScopes(newtToken.Scopes).
 | 
			
		||||
		SetNonce(newtToken.Nonce).
 | 
			
		||||
		SetClaimsUserID(newtToken.Claims.UserID).
 | 
			
		||||
		SetClaimsEmail(newtToken.Claims.Email).
 | 
			
		||||
		SetClaimsEmailVerified(newtToken.Claims.EmailVerified).
 | 
			
		||||
		SetClaimsUsername(newtToken.Claims.Username).
 | 
			
		||||
		SetClaimsPreferredUsername(newtToken.Claims.PreferredUsername).
 | 
			
		||||
		SetClaimsGroups(newtToken.Claims.Groups).
 | 
			
		||||
		SetConnectorID(newtToken.ConnectorID).
 | 
			
		||||
		SetConnectorData(newtToken.ConnectorData).
 | 
			
		||||
		SetToken(newtToken.Token).
 | 
			
		||||
		// Save utc time into database because ent doesn't support comparing dates with different timezones
 | 
			
		||||
		SetLastUsed(newtToken.LastUsed.UTC()).
 | 
			
		||||
		SetCreatedAt(newtToken.CreatedAt.UTC()).
 | 
			
		||||
		Save(context.TODO())
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return rollback(tx, "update refresh token uploading: %w", err)
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	if err = tx.Commit(); err != nil {
 | 
			
		||||
		return rollback(tx, "update refresh token commit: %w", err)
 | 
			
		||||
	}
 | 
			
		||||
	return nil
 | 
			
		||||
}
 | 
			
		||||
							
								
								
									
										167
									
								
								storage/ent/client/types.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										167
									
								
								storage/ent/client/types.go
									
									
									
									
									
										Normal file
									
								
							@@ -0,0 +1,167 @@
 | 
			
		||||
package client
 | 
			
		||||
 | 
			
		||||
import (
 | 
			
		||||
	"encoding/json"
 | 
			
		||||
	"strings"
 | 
			
		||||
 | 
			
		||||
	"github.com/dexidp/dex/storage"
 | 
			
		||||
	"github.com/dexidp/dex/storage/ent/db"
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
const keysRowID = "keys"
 | 
			
		||||
 | 
			
		||||
func toStorageKeys(keys *db.Keys) storage.Keys {
 | 
			
		||||
	return storage.Keys{
 | 
			
		||||
		SigningKey:       &keys.SigningKey,
 | 
			
		||||
		SigningKeyPub:    &keys.SigningKeyPub,
 | 
			
		||||
		VerificationKeys: keys.VerificationKeys,
 | 
			
		||||
		NextRotation:     keys.NextRotation,
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func toStorageAuthRequest(a *db.AuthRequest) storage.AuthRequest {
 | 
			
		||||
	return storage.AuthRequest{
 | 
			
		||||
		ID:                  a.ID,
 | 
			
		||||
		ClientID:            a.ClientID,
 | 
			
		||||
		ResponseTypes:       a.ResponseTypes,
 | 
			
		||||
		Scopes:              a.Scopes,
 | 
			
		||||
		RedirectURI:         a.RedirectURI,
 | 
			
		||||
		Nonce:               a.Nonce,
 | 
			
		||||
		State:               a.State,
 | 
			
		||||
		ForceApprovalPrompt: a.ForceApprovalPrompt,
 | 
			
		||||
		LoggedIn:            a.LoggedIn,
 | 
			
		||||
		ConnectorID:         a.ConnectorID,
 | 
			
		||||
		ConnectorData:       *a.ConnectorData,
 | 
			
		||||
		Expiry:              a.Expiry,
 | 
			
		||||
		Claims: storage.Claims{
 | 
			
		||||
			UserID:            a.ClaimsUserID,
 | 
			
		||||
			Username:          a.ClaimsUsername,
 | 
			
		||||
			PreferredUsername: a.ClaimsPreferredUsername,
 | 
			
		||||
			Email:             a.ClaimsEmail,
 | 
			
		||||
			EmailVerified:     a.ClaimsEmailVerified,
 | 
			
		||||
			Groups:            a.ClaimsGroups,
 | 
			
		||||
		},
 | 
			
		||||
		PKCE: storage.PKCE{
 | 
			
		||||
			CodeChallenge:       a.CodeChallenge,
 | 
			
		||||
			CodeChallengeMethod: a.CodeChallengeMethod,
 | 
			
		||||
		},
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func toStorageAuthCode(a *db.AuthCode) storage.AuthCode {
 | 
			
		||||
	return storage.AuthCode{
 | 
			
		||||
		ID:            a.ID,
 | 
			
		||||
		ClientID:      a.ClientID,
 | 
			
		||||
		Scopes:        a.Scopes,
 | 
			
		||||
		RedirectURI:   a.RedirectURI,
 | 
			
		||||
		Nonce:         a.Nonce,
 | 
			
		||||
		ConnectorID:   a.ConnectorID,
 | 
			
		||||
		ConnectorData: *a.ConnectorData,
 | 
			
		||||
		Expiry:        a.Expiry,
 | 
			
		||||
		Claims: storage.Claims{
 | 
			
		||||
			UserID:            a.ClaimsUserID,
 | 
			
		||||
			Username:          a.ClaimsUsername,
 | 
			
		||||
			PreferredUsername: a.ClaimsPreferredUsername,
 | 
			
		||||
			Email:             a.ClaimsEmail,
 | 
			
		||||
			EmailVerified:     a.ClaimsEmailVerified,
 | 
			
		||||
			Groups:            a.ClaimsGroups,
 | 
			
		||||
		},
 | 
			
		||||
		PKCE: storage.PKCE{
 | 
			
		||||
			CodeChallenge:       a.CodeChallenge,
 | 
			
		||||
			CodeChallengeMethod: a.CodeChallengeMethod,
 | 
			
		||||
		},
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func toStorageClient(c *db.OAuth2Client) storage.Client {
 | 
			
		||||
	return storage.Client{
 | 
			
		||||
		ID:           c.ID,
 | 
			
		||||
		Secret:       c.Secret,
 | 
			
		||||
		RedirectURIs: c.RedirectUris,
 | 
			
		||||
		TrustedPeers: c.TrustedPeers,
 | 
			
		||||
		Public:       c.Public,
 | 
			
		||||
		Name:         c.Name,
 | 
			
		||||
		LogoURL:      c.LogoURL,
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func toStorageConnector(c *db.Connector) storage.Connector {
 | 
			
		||||
	return storage.Connector{
 | 
			
		||||
		ID:     c.ID,
 | 
			
		||||
		Type:   c.Type,
 | 
			
		||||
		Name:   c.Name,
 | 
			
		||||
		Config: c.Config,
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func toStorageOfflineSession(o *db.OfflineSession) storage.OfflineSessions {
 | 
			
		||||
	s := storage.OfflineSessions{
 | 
			
		||||
		UserID:        o.UserID,
 | 
			
		||||
		ConnID:        o.ConnID,
 | 
			
		||||
		ConnectorData: *o.ConnectorData,
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	if o.Refresh != nil {
 | 
			
		||||
		if err := json.Unmarshal(o.Refresh, &s.Refresh); err != nil {
 | 
			
		||||
			// Correctness of json structure if guaranteed on uploading
 | 
			
		||||
			panic(err)
 | 
			
		||||
		}
 | 
			
		||||
	} else {
 | 
			
		||||
		// Server code assumes this will be non-nil.
 | 
			
		||||
		s.Refresh = make(map[string]*storage.RefreshTokenRef)
 | 
			
		||||
	}
 | 
			
		||||
	return s
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func toStorageRefreshToken(r *db.RefreshToken) storage.RefreshToken {
 | 
			
		||||
	return storage.RefreshToken{
 | 
			
		||||
		ID:            r.ID,
 | 
			
		||||
		Token:         r.Token,
 | 
			
		||||
		CreatedAt:     r.CreatedAt,
 | 
			
		||||
		LastUsed:      r.LastUsed,
 | 
			
		||||
		ClientID:      r.ClientID,
 | 
			
		||||
		ConnectorID:   r.ConnectorID,
 | 
			
		||||
		ConnectorData: *r.ConnectorData,
 | 
			
		||||
		Scopes:        r.Scopes,
 | 
			
		||||
		Nonce:         r.Nonce,
 | 
			
		||||
		Claims: storage.Claims{
 | 
			
		||||
			UserID:            r.ClaimsUserID,
 | 
			
		||||
			Username:          r.ClaimsUsername,
 | 
			
		||||
			PreferredUsername: r.ClaimsPreferredUsername,
 | 
			
		||||
			Email:             r.ClaimsEmail,
 | 
			
		||||
			EmailVerified:     r.ClaimsEmailVerified,
 | 
			
		||||
			Groups:            r.ClaimsGroups,
 | 
			
		||||
		},
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func toStoragePassword(p *db.Password) storage.Password {
 | 
			
		||||
	return storage.Password{
 | 
			
		||||
		Email:    p.Email,
 | 
			
		||||
		Hash:     p.Hash,
 | 
			
		||||
		Username: p.Username,
 | 
			
		||||
		UserID:   p.UserID,
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func toStorageDeviceRequest(r *db.DeviceRequest) storage.DeviceRequest {
 | 
			
		||||
	return storage.DeviceRequest{
 | 
			
		||||
		UserCode:     strings.ToUpper(r.UserCode),
 | 
			
		||||
		DeviceCode:   r.DeviceCode,
 | 
			
		||||
		ClientID:     r.ClientID,
 | 
			
		||||
		ClientSecret: r.ClientSecret,
 | 
			
		||||
		Scopes:       r.Scopes,
 | 
			
		||||
		Expiry:       r.Expiry,
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func toStorageDeviceToken(t *db.DeviceToken) storage.DeviceToken {
 | 
			
		||||
	return storage.DeviceToken{
 | 
			
		||||
		DeviceCode:          t.DeviceCode,
 | 
			
		||||
		Status:              t.Status,
 | 
			
		||||
		Token:               string(*t.Token),
 | 
			
		||||
		Expiry:              t.Expiry,
 | 
			
		||||
		LastRequestTime:     t.LastRequest,
 | 
			
		||||
		PollIntervalSeconds: t.PollInterval,
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
							
								
								
									
										44
									
								
								storage/ent/client/utils.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										44
									
								
								storage/ent/client/utils.go
									
									
									
									
									
										Normal file
									
								
							@@ -0,0 +1,44 @@
 | 
			
		||||
package client
 | 
			
		||||
 | 
			
		||||
import (
 | 
			
		||||
	"fmt"
 | 
			
		||||
	"hash"
 | 
			
		||||
 | 
			
		||||
	"github.com/pkg/errors"
 | 
			
		||||
 | 
			
		||||
	"github.com/dexidp/dex/storage"
 | 
			
		||||
	"github.com/dexidp/dex/storage/ent/db"
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
func rollback(tx *db.Tx, t string, err error) error {
 | 
			
		||||
	rerr := tx.Rollback()
 | 
			
		||||
	err = convertDBError(t, err)
 | 
			
		||||
 | 
			
		||||
	if rerr == nil {
 | 
			
		||||
		return err
 | 
			
		||||
	}
 | 
			
		||||
	return errors.Wrapf(err, "rolling back transaction: %v", rerr)
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func convertDBError(t string, err error) error {
 | 
			
		||||
	if db.IsNotFound(err) {
 | 
			
		||||
		return storage.ErrNotFound
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	if db.IsConstraintError(err) {
 | 
			
		||||
		return storage.ErrAlreadyExists
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	return fmt.Errorf(t, err)
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// compose hashed id from user and connection id to use it as primary key
 | 
			
		||||
// ent doesn't support multi-key primary yet
 | 
			
		||||
// https://github.com/facebook/ent/issues/400
 | 
			
		||||
func offlineSessionID(userID string, connID string, hasher func() hash.Hash) string {
 | 
			
		||||
	h := hasher()
 | 
			
		||||
 | 
			
		||||
	h.Write([]byte(userID))
 | 
			
		||||
	h.Write([]byte(connID))
 | 
			
		||||
	return fmt.Sprintf("%x", h.Sum(nil))
 | 
			
		||||
}
 | 
			
		||||
							
								
								
									
										3
									
								
								storage/ent/generate.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										3
									
								
								storage/ent/generate.go
									
									
									
									
									
										Normal file
									
								
							@@ -0,0 +1,3 @@
 | 
			
		||||
package ent
 | 
			
		||||
 | 
			
		||||
//go:generate go run github.com/facebook/ent/cmd/entc generate ./schema --target ./db
 | 
			
		||||
							
								
								
									
										89
									
								
								storage/ent/schema/authcode.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										89
									
								
								storage/ent/schema/authcode.go
									
									
									
									
									
										Normal file
									
								
							@@ -0,0 +1,89 @@
 | 
			
		||||
package schema
 | 
			
		||||
 | 
			
		||||
import (
 | 
			
		||||
	"github.com/facebook/ent"
 | 
			
		||||
	"github.com/facebook/ent/schema/field"
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
/* Original SQL table:
 | 
			
		||||
create table auth_code
 | 
			
		||||
(
 | 
			
		||||
    id                        text      not null  primary key,
 | 
			
		||||
    client_id                 text      not null,
 | 
			
		||||
    scopes                    blob      not null,
 | 
			
		||||
    nonce                     text      not null,
 | 
			
		||||
    redirect_uri              text      not null,
 | 
			
		||||
    claims_user_id            text      not null,
 | 
			
		||||
    claims_username           text      not null,
 | 
			
		||||
    claims_email              text      not null,
 | 
			
		||||
    claims_email_verified     integer   not null,
 | 
			
		||||
    claims_groups             blob      not null,
 | 
			
		||||
    connector_id              text      not null,
 | 
			
		||||
    connector_data            blob,
 | 
			
		||||
    expiry                    timestamp not null,
 | 
			
		||||
    claims_preferred_username text default '' not null,
 | 
			
		||||
    code_challenge            text default '' not null,
 | 
			
		||||
    code_challenge_method     text default '' not null
 | 
			
		||||
);
 | 
			
		||||
*/
 | 
			
		||||
 | 
			
		||||
// AuthCode holds the schema definition for the AuthCode entity.
 | 
			
		||||
type AuthCode struct {
 | 
			
		||||
	ent.Schema
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// Fields of the AuthCode.
 | 
			
		||||
func (AuthCode) Fields() []ent.Field {
 | 
			
		||||
	return []ent.Field{
 | 
			
		||||
		field.Text("id").
 | 
			
		||||
			SchemaType(textSchema).
 | 
			
		||||
			NotEmpty().
 | 
			
		||||
			Unique(),
 | 
			
		||||
		field.Text("client_id").
 | 
			
		||||
			SchemaType(textSchema).
 | 
			
		||||
			NotEmpty(),
 | 
			
		||||
		field.JSON("scopes", []string{}).
 | 
			
		||||
			Optional(),
 | 
			
		||||
		field.Text("nonce").
 | 
			
		||||
			SchemaType(textSchema).
 | 
			
		||||
			NotEmpty(),
 | 
			
		||||
		field.Text("redirect_uri").
 | 
			
		||||
			SchemaType(textSchema).
 | 
			
		||||
			NotEmpty(),
 | 
			
		||||
 | 
			
		||||
		field.Text("claims_user_id").
 | 
			
		||||
			SchemaType(textSchema).
 | 
			
		||||
			NotEmpty(),
 | 
			
		||||
		field.Text("claims_username").
 | 
			
		||||
			SchemaType(textSchema).
 | 
			
		||||
			NotEmpty(),
 | 
			
		||||
		field.Text("claims_email").
 | 
			
		||||
			SchemaType(textSchema).
 | 
			
		||||
			NotEmpty(),
 | 
			
		||||
		field.Bool("claims_email_verified"),
 | 
			
		||||
		field.JSON("claims_groups", []string{}).
 | 
			
		||||
			Optional(),
 | 
			
		||||
		field.Text("claims_preferred_username").
 | 
			
		||||
			SchemaType(textSchema).
 | 
			
		||||
			Default(""),
 | 
			
		||||
 | 
			
		||||
		field.Text("connector_id").
 | 
			
		||||
			SchemaType(textSchema).
 | 
			
		||||
			NotEmpty(),
 | 
			
		||||
		field.Bytes("connector_data").
 | 
			
		||||
			Nillable().
 | 
			
		||||
			Optional(),
 | 
			
		||||
		field.Time("expiry"),
 | 
			
		||||
		field.Text("code_challenge").
 | 
			
		||||
			SchemaType(textSchema).
 | 
			
		||||
			Default(""),
 | 
			
		||||
		field.Text("code_challenge_method").
 | 
			
		||||
			SchemaType(textSchema).
 | 
			
		||||
			Default(""),
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// Edges of the AuthCode.
 | 
			
		||||
func (AuthCode) Edges() []ent.Edge {
 | 
			
		||||
	return []ent.Edge{}
 | 
			
		||||
}
 | 
			
		||||
							
								
								
									
										94
									
								
								storage/ent/schema/authrequest.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										94
									
								
								storage/ent/schema/authrequest.go
									
									
									
									
									
										Normal file
									
								
							@@ -0,0 +1,94 @@
 | 
			
		||||
package schema
 | 
			
		||||
 | 
			
		||||
import (
 | 
			
		||||
	"github.com/facebook/ent"
 | 
			
		||||
	"github.com/facebook/ent/schema/field"
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
/* Original SQL table:
 | 
			
		||||
create table auth_request
 | 
			
		||||
(
 | 
			
		||||
    id                        text      not null  primary key,
 | 
			
		||||
    client_id                 text      not null,
 | 
			
		||||
    response_types            blob      not null,
 | 
			
		||||
    scopes                    blob      not null,
 | 
			
		||||
    redirect_uri              text      not null,
 | 
			
		||||
    nonce                     text      not null,
 | 
			
		||||
    state                     text      not null,
 | 
			
		||||
    force_approval_prompt     integer   not null,
 | 
			
		||||
    logged_in                 integer   not null,
 | 
			
		||||
    claims_user_id            text      not null,
 | 
			
		||||
    claims_username           text      not null,
 | 
			
		||||
    claims_email              text      not null,
 | 
			
		||||
    claims_email_verified     integer   not null,
 | 
			
		||||
    claims_groups             blob      not null,
 | 
			
		||||
    connector_id              text      not null,
 | 
			
		||||
    connector_data            blob,
 | 
			
		||||
    expiry                    timestamp not null,
 | 
			
		||||
    claims_preferred_username text default '' not null,
 | 
			
		||||
    code_challenge            text default '' not null,
 | 
			
		||||
    code_challenge_method     text default '' not null
 | 
			
		||||
);
 | 
			
		||||
*/
 | 
			
		||||
 | 
			
		||||
// AuthRequest holds the schema definition for the AuthRequest entity.
 | 
			
		||||
type AuthRequest struct {
 | 
			
		||||
	ent.Schema
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// Fields of the AuthRequest.
 | 
			
		||||
func (AuthRequest) Fields() []ent.Field {
 | 
			
		||||
	return []ent.Field{
 | 
			
		||||
		field.Text("id").
 | 
			
		||||
			SchemaType(textSchema).
 | 
			
		||||
			NotEmpty().
 | 
			
		||||
			Unique(),
 | 
			
		||||
		field.Text("client_id").
 | 
			
		||||
			SchemaType(textSchema),
 | 
			
		||||
		field.JSON("scopes", []string{}).
 | 
			
		||||
			Optional(),
 | 
			
		||||
		field.JSON("response_types", []string{}).
 | 
			
		||||
			Optional(),
 | 
			
		||||
		field.Text("redirect_uri").
 | 
			
		||||
			SchemaType(textSchema),
 | 
			
		||||
		field.Text("nonce").
 | 
			
		||||
			SchemaType(textSchema),
 | 
			
		||||
		field.Text("state").
 | 
			
		||||
			SchemaType(textSchema),
 | 
			
		||||
 | 
			
		||||
		field.Bool("force_approval_prompt"),
 | 
			
		||||
		field.Bool("logged_in"),
 | 
			
		||||
 | 
			
		||||
		field.Text("claims_user_id").
 | 
			
		||||
			SchemaType(textSchema),
 | 
			
		||||
		field.Text("claims_username").
 | 
			
		||||
			SchemaType(textSchema),
 | 
			
		||||
		field.Text("claims_email").
 | 
			
		||||
			SchemaType(textSchema),
 | 
			
		||||
		field.Bool("claims_email_verified"),
 | 
			
		||||
		field.JSON("claims_groups", []string{}).
 | 
			
		||||
			Optional(),
 | 
			
		||||
		field.Text("claims_preferred_username").
 | 
			
		||||
			SchemaType(textSchema).
 | 
			
		||||
			Default(""),
 | 
			
		||||
 | 
			
		||||
		field.Text("connector_id").
 | 
			
		||||
			SchemaType(textSchema),
 | 
			
		||||
		field.Bytes("connector_data").
 | 
			
		||||
			Nillable().
 | 
			
		||||
			Optional(),
 | 
			
		||||
		field.Time("expiry"),
 | 
			
		||||
 | 
			
		||||
		field.Text("code_challenge").
 | 
			
		||||
			SchemaType(textSchema).
 | 
			
		||||
			Default(""),
 | 
			
		||||
		field.Text("code_challenge_method").
 | 
			
		||||
			SchemaType(textSchema).
 | 
			
		||||
			Default(""),
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// Edges of the AuthRequest.
 | 
			
		||||
func (AuthRequest) Edges() []ent.Edge {
 | 
			
		||||
	return []ent.Edge{}
 | 
			
		||||
}
 | 
			
		||||
							
								
								
									
										53
									
								
								storage/ent/schema/client.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										53
									
								
								storage/ent/schema/client.go
									
									
									
									
									
										Normal file
									
								
							@@ -0,0 +1,53 @@
 | 
			
		||||
package schema
 | 
			
		||||
 | 
			
		||||
import (
 | 
			
		||||
	"github.com/facebook/ent"
 | 
			
		||||
	"github.com/facebook/ent/schema/field"
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
/* Original SQL table:
 | 
			
		||||
create table client
 | 
			
		||||
(
 | 
			
		||||
    id            text    not null  primary key,
 | 
			
		||||
    secret        text    not null,
 | 
			
		||||
    redirect_uris blob    not null,
 | 
			
		||||
    trusted_peers blob    not null,
 | 
			
		||||
    public        integer not null,
 | 
			
		||||
    name          text    not null,
 | 
			
		||||
    logo_url      text    not null
 | 
			
		||||
);
 | 
			
		||||
*/
 | 
			
		||||
 | 
			
		||||
// OAuth2Client holds the schema definition for the Client entity.
 | 
			
		||||
type OAuth2Client struct {
 | 
			
		||||
	ent.Schema
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// Fields of the OAuth2Client.
 | 
			
		||||
func (OAuth2Client) Fields() []ent.Field {
 | 
			
		||||
	return []ent.Field{
 | 
			
		||||
		field.Text("id").
 | 
			
		||||
			SchemaType(textSchema).
 | 
			
		||||
			NotEmpty().
 | 
			
		||||
			Unique(),
 | 
			
		||||
		field.Text("secret").
 | 
			
		||||
			SchemaType(textSchema).
 | 
			
		||||
			NotEmpty(),
 | 
			
		||||
		field.JSON("redirect_uris", []string{}).
 | 
			
		||||
			Optional(),
 | 
			
		||||
		field.JSON("trusted_peers", []string{}).
 | 
			
		||||
			Optional(),
 | 
			
		||||
		field.Bool("public"),
 | 
			
		||||
		field.Text("name").
 | 
			
		||||
			SchemaType(textSchema).
 | 
			
		||||
			NotEmpty(),
 | 
			
		||||
		field.Text("logo_url").
 | 
			
		||||
			SchemaType(textSchema).
 | 
			
		||||
			NotEmpty(),
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// Edges of the OAuth2Client.
 | 
			
		||||
func (OAuth2Client) Edges() []ent.Edge {
 | 
			
		||||
	return []ent.Edge{}
 | 
			
		||||
}
 | 
			
		||||
							
								
								
									
										46
									
								
								storage/ent/schema/connector.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										46
									
								
								storage/ent/schema/connector.go
									
									
									
									
									
										Normal file
									
								
							@@ -0,0 +1,46 @@
 | 
			
		||||
package schema
 | 
			
		||||
 | 
			
		||||
import (
 | 
			
		||||
	"github.com/facebook/ent"
 | 
			
		||||
	"github.com/facebook/ent/schema/field"
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
/* Original SQL table:
 | 
			
		||||
create table connector
 | 
			
		||||
(
 | 
			
		||||
    id               text not null  primary key,
 | 
			
		||||
    type             text not null,
 | 
			
		||||
    name             text not null,
 | 
			
		||||
    resource_version text not null,
 | 
			
		||||
    config           blob
 | 
			
		||||
);
 | 
			
		||||
*/
 | 
			
		||||
 | 
			
		||||
// Connector holds the schema definition for the Client entity.
 | 
			
		||||
type Connector struct {
 | 
			
		||||
	ent.Schema
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// Fields of the Connector.
 | 
			
		||||
func (Connector) Fields() []ent.Field {
 | 
			
		||||
	return []ent.Field{
 | 
			
		||||
		field.Text("id").
 | 
			
		||||
			SchemaType(textSchema).
 | 
			
		||||
			NotEmpty().
 | 
			
		||||
			Unique(),
 | 
			
		||||
		field.Text("type").
 | 
			
		||||
			SchemaType(textSchema).
 | 
			
		||||
			NotEmpty(),
 | 
			
		||||
		field.Text("name").
 | 
			
		||||
			SchemaType(textSchema).
 | 
			
		||||
			NotEmpty(),
 | 
			
		||||
		field.Text("resource_version").
 | 
			
		||||
			SchemaType(textSchema),
 | 
			
		||||
		field.Bytes("config"),
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// Edges of the Connector.
 | 
			
		||||
func (Connector) Edges() []ent.Edge {
 | 
			
		||||
	return []ent.Edge{}
 | 
			
		||||
}
 | 
			
		||||
							
								
								
									
										50
									
								
								storage/ent/schema/devicerequest.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										50
									
								
								storage/ent/schema/devicerequest.go
									
									
									
									
									
										Normal file
									
								
							@@ -0,0 +1,50 @@
 | 
			
		||||
package schema
 | 
			
		||||
 | 
			
		||||
import (
 | 
			
		||||
	"github.com/facebook/ent"
 | 
			
		||||
	"github.com/facebook/ent/schema/field"
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
/* Original SQL table:
 | 
			
		||||
create table device_request
 | 
			
		||||
(
 | 
			
		||||
    user_code     text      not null  primary key,
 | 
			
		||||
    device_code   text      not null,
 | 
			
		||||
    client_id     text      not null,
 | 
			
		||||
    client_secret text,
 | 
			
		||||
    scopes        blob      not null,
 | 
			
		||||
    expiry        timestamp not null
 | 
			
		||||
);
 | 
			
		||||
*/
 | 
			
		||||
 | 
			
		||||
// DeviceRequest holds the schema definition for the DeviceRequest entity.
 | 
			
		||||
type DeviceRequest struct {
 | 
			
		||||
	ent.Schema
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// Fields of the DeviceRequest.
 | 
			
		||||
func (DeviceRequest) Fields() []ent.Field {
 | 
			
		||||
	return []ent.Field{
 | 
			
		||||
		field.Text("user_code").
 | 
			
		||||
			SchemaType(textSchema).
 | 
			
		||||
			NotEmpty().
 | 
			
		||||
			Unique(),
 | 
			
		||||
		field.Text("device_code").
 | 
			
		||||
			SchemaType(textSchema).
 | 
			
		||||
			NotEmpty(),
 | 
			
		||||
		field.Text("client_id").
 | 
			
		||||
			SchemaType(textSchema).
 | 
			
		||||
			NotEmpty(),
 | 
			
		||||
		field.Text("client_secret").
 | 
			
		||||
			SchemaType(textSchema).
 | 
			
		||||
			NotEmpty(),
 | 
			
		||||
		field.JSON("scopes", []string{}).
 | 
			
		||||
			Optional(),
 | 
			
		||||
		field.Time("expiry"),
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// Edges of the DeviceRequest.
 | 
			
		||||
func (DeviceRequest) Edges() []ent.Edge {
 | 
			
		||||
	return []ent.Edge{}
 | 
			
		||||
}
 | 
			
		||||
							
								
								
									
										45
									
								
								storage/ent/schema/devicetoken.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										45
									
								
								storage/ent/schema/devicetoken.go
									
									
									
									
									
										Normal file
									
								
							@@ -0,0 +1,45 @@
 | 
			
		||||
package schema
 | 
			
		||||
 | 
			
		||||
import (
 | 
			
		||||
	"github.com/facebook/ent"
 | 
			
		||||
	"github.com/facebook/ent/schema/field"
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
/* Original SQL table:
 | 
			
		||||
create table device_token
 | 
			
		||||
(
 | 
			
		||||
    device_code   text      not null primary key,
 | 
			
		||||
    status        text      not null,
 | 
			
		||||
    token         blob,
 | 
			
		||||
    expiry        timestamp not null,
 | 
			
		||||
    last_request  timestamp not null,
 | 
			
		||||
    poll_interval integer   not null
 | 
			
		||||
);
 | 
			
		||||
*/
 | 
			
		||||
 | 
			
		||||
// DeviceToken holds the schema definition for the DeviceToken entity.
 | 
			
		||||
type DeviceToken struct {
 | 
			
		||||
	ent.Schema
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// Fields of the DeviceToken.
 | 
			
		||||
func (DeviceToken) Fields() []ent.Field {
 | 
			
		||||
	return []ent.Field{
 | 
			
		||||
		field.Text("device_code").
 | 
			
		||||
			SchemaType(textSchema).
 | 
			
		||||
			NotEmpty().
 | 
			
		||||
			Unique(),
 | 
			
		||||
		field.Text("status").
 | 
			
		||||
			SchemaType(textSchema).
 | 
			
		||||
			NotEmpty(),
 | 
			
		||||
		field.Bytes("token").Nillable().Optional(),
 | 
			
		||||
		field.Time("expiry"),
 | 
			
		||||
		field.Time("last_request"),
 | 
			
		||||
		field.Int("poll_interval"),
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// Edges of the DeviceToken.
 | 
			
		||||
func (DeviceToken) Edges() []ent.Edge {
 | 
			
		||||
	return []ent.Edge{}
 | 
			
		||||
}
 | 
			
		||||
							
								
								
									
										44
									
								
								storage/ent/schema/keys.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										44
									
								
								storage/ent/schema/keys.go
									
									
									
									
									
										Normal file
									
								
							@@ -0,0 +1,44 @@
 | 
			
		||||
package schema
 | 
			
		||||
 | 
			
		||||
import (
 | 
			
		||||
	"github.com/facebook/ent"
 | 
			
		||||
	"github.com/facebook/ent/schema/field"
 | 
			
		||||
	"gopkg.in/square/go-jose.v2"
 | 
			
		||||
 | 
			
		||||
	"github.com/dexidp/dex/storage"
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
/* Original SQL table:
 | 
			
		||||
create table keys
 | 
			
		||||
(
 | 
			
		||||
    id                text      not null  primary key,
 | 
			
		||||
    verification_keys blob      not null,
 | 
			
		||||
    signing_key       blob      not null,
 | 
			
		||||
    signing_key_pub   blob      not null,
 | 
			
		||||
    next_rotation     timestamp not null
 | 
			
		||||
);
 | 
			
		||||
*/
 | 
			
		||||
 | 
			
		||||
// Keys holds the schema definition for the Keys entity.
 | 
			
		||||
type Keys struct {
 | 
			
		||||
	ent.Schema
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// Fields of the Keys.
 | 
			
		||||
func (Keys) Fields() []ent.Field {
 | 
			
		||||
	return []ent.Field{
 | 
			
		||||
		field.Text("id").
 | 
			
		||||
			SchemaType(textSchema).
 | 
			
		||||
			NotEmpty().
 | 
			
		||||
			Unique(),
 | 
			
		||||
		field.JSON("verification_keys", []storage.VerificationKey{}),
 | 
			
		||||
		field.JSON("signing_key", jose.JSONWebKey{}),
 | 
			
		||||
		field.JSON("signing_key_pub", jose.JSONWebKey{}),
 | 
			
		||||
		field.Time("next_rotation"),
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// Edges of the Keys.
 | 
			
		||||
func (Keys) Edges() []ent.Edge {
 | 
			
		||||
	return []ent.Edge{}
 | 
			
		||||
}
 | 
			
		||||
							
								
								
									
										46
									
								
								storage/ent/schema/offlinesession.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										46
									
								
								storage/ent/schema/offlinesession.go
									
									
									
									
									
										Normal file
									
								
							@@ -0,0 +1,46 @@
 | 
			
		||||
package schema
 | 
			
		||||
 | 
			
		||||
import (
 | 
			
		||||
	"github.com/facebook/ent"
 | 
			
		||||
	"github.com/facebook/ent/schema/field"
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
/* Original SQL table:
 | 
			
		||||
create table offline_session
 | 
			
		||||
(
 | 
			
		||||
    user_id        text not null,
 | 
			
		||||
    conn_id        text not null,
 | 
			
		||||
    refresh        blob not null,
 | 
			
		||||
    connector_data blob,
 | 
			
		||||
    primary key (user_id, conn_id)
 | 
			
		||||
);
 | 
			
		||||
*/
 | 
			
		||||
 | 
			
		||||
// OfflineSession holds the schema definition for the OfflineSession entity.
 | 
			
		||||
type OfflineSession struct {
 | 
			
		||||
	ent.Schema
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// Fields of the OfflineSession.
 | 
			
		||||
func (OfflineSession) Fields() []ent.Field {
 | 
			
		||||
	return []ent.Field{
 | 
			
		||||
		// Using id field here because it's impossible to create multi-key primary yet
 | 
			
		||||
		field.Text("id").
 | 
			
		||||
			SchemaType(textSchema).
 | 
			
		||||
			NotEmpty().
 | 
			
		||||
			Unique(),
 | 
			
		||||
		field.Text("user_id").
 | 
			
		||||
			SchemaType(textSchema).
 | 
			
		||||
			NotEmpty(),
 | 
			
		||||
		field.Text("conn_id").
 | 
			
		||||
			SchemaType(textSchema).
 | 
			
		||||
			NotEmpty(),
 | 
			
		||||
		field.Bytes("refresh"),
 | 
			
		||||
		field.Bytes("connector_data").Nillable().Optional(),
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// Edges of the OfflineSession.
 | 
			
		||||
func (OfflineSession) Edges() []ent.Edge {
 | 
			
		||||
	return []ent.Edge{}
 | 
			
		||||
}
 | 
			
		||||
							
								
								
									
										44
									
								
								storage/ent/schema/password.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										44
									
								
								storage/ent/schema/password.go
									
									
									
									
									
										Normal file
									
								
							@@ -0,0 +1,44 @@
 | 
			
		||||
package schema
 | 
			
		||||
 | 
			
		||||
import (
 | 
			
		||||
	"github.com/facebook/ent"
 | 
			
		||||
	"github.com/facebook/ent/schema/field"
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
/* Original SQL table:
 | 
			
		||||
create table password
 | 
			
		||||
(
 | 
			
		||||
    email    text not null  primary key,
 | 
			
		||||
    hash     blob not null,
 | 
			
		||||
    username text not null,
 | 
			
		||||
    user_id  text not null
 | 
			
		||||
);
 | 
			
		||||
*/
 | 
			
		||||
 | 
			
		||||
// Password holds the schema definition for the Password entity.
 | 
			
		||||
type Password struct {
 | 
			
		||||
	ent.Schema
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// Fields of the Password.
 | 
			
		||||
func (Password) Fields() []ent.Field {
 | 
			
		||||
	return []ent.Field{
 | 
			
		||||
		field.Text("email").
 | 
			
		||||
			SchemaType(textSchema).
 | 
			
		||||
			StorageKey("email"). // use email as ID field to make querying easier
 | 
			
		||||
			NotEmpty().
 | 
			
		||||
			Unique(),
 | 
			
		||||
		field.Bytes("hash"),
 | 
			
		||||
		field.Text("username").
 | 
			
		||||
			SchemaType(textSchema).
 | 
			
		||||
			NotEmpty(),
 | 
			
		||||
		field.Text("user_id").
 | 
			
		||||
			SchemaType(textSchema).
 | 
			
		||||
			NotEmpty(),
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// Edges of the Password.
 | 
			
		||||
func (Password) Edges() []ent.Edge {
 | 
			
		||||
	return []ent.Edge{}
 | 
			
		||||
}
 | 
			
		||||
							
								
								
									
										89
									
								
								storage/ent/schema/refreshtoken.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										89
									
								
								storage/ent/schema/refreshtoken.go
									
									
									
									
									
										Normal file
									
								
							@@ -0,0 +1,89 @@
 | 
			
		||||
package schema
 | 
			
		||||
 | 
			
		||||
import (
 | 
			
		||||
	"time"
 | 
			
		||||
 | 
			
		||||
	"github.com/facebook/ent"
 | 
			
		||||
	"github.com/facebook/ent/schema/field"
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
/* Original SQL table:
 | 
			
		||||
create table refresh_token
 | 
			
		||||
(
 | 
			
		||||
    id                        text    not null  primary key,
 | 
			
		||||
    client_id                 text    not null,
 | 
			
		||||
    scopes                    blob    not null,
 | 
			
		||||
    nonce                     text    not null,
 | 
			
		||||
    claims_user_id            text    not null,
 | 
			
		||||
    claims_username           text    not null,
 | 
			
		||||
    claims_email              text    not null,
 | 
			
		||||
    claims_email_verified     integer not null,
 | 
			
		||||
    claims_groups             blob    not null,
 | 
			
		||||
    connector_id              text    not null,
 | 
			
		||||
    connector_data            blob,
 | 
			
		||||
    token                     text      default '' not null,
 | 
			
		||||
    created_at                timestamp default '0001-01-01 00:00:00 UTC' not null,
 | 
			
		||||
    last_used                 timestamp default '0001-01-01 00:00:00 UTC' not null,
 | 
			
		||||
    claims_preferred_username text      default '' not null
 | 
			
		||||
);
 | 
			
		||||
*/
 | 
			
		||||
 | 
			
		||||
// RefreshToken holds the schema definition for the RefreshToken entity.
 | 
			
		||||
type RefreshToken struct {
 | 
			
		||||
	ent.Schema
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// Fields of the RefreshToken.
 | 
			
		||||
func (RefreshToken) Fields() []ent.Field {
 | 
			
		||||
	return []ent.Field{
 | 
			
		||||
		field.Text("id").
 | 
			
		||||
			SchemaType(textSchema).
 | 
			
		||||
			NotEmpty().
 | 
			
		||||
			Unique(),
 | 
			
		||||
		field.Text("client_id").
 | 
			
		||||
			SchemaType(textSchema).
 | 
			
		||||
			NotEmpty(),
 | 
			
		||||
		field.JSON("scopes", []string{}).
 | 
			
		||||
			Optional(),
 | 
			
		||||
		field.Text("nonce").
 | 
			
		||||
			SchemaType(textSchema).
 | 
			
		||||
			NotEmpty(),
 | 
			
		||||
 | 
			
		||||
		field.Text("claims_user_id").
 | 
			
		||||
			SchemaType(textSchema).
 | 
			
		||||
			NotEmpty(),
 | 
			
		||||
		field.Text("claims_username").
 | 
			
		||||
			SchemaType(textSchema).
 | 
			
		||||
			NotEmpty(),
 | 
			
		||||
		field.Text("claims_email").
 | 
			
		||||
			SchemaType(textSchema).
 | 
			
		||||
			NotEmpty(),
 | 
			
		||||
		field.Bool("claims_email_verified"),
 | 
			
		||||
		field.JSON("claims_groups", []string{}).
 | 
			
		||||
			Optional(),
 | 
			
		||||
		field.Text("claims_preferred_username").
 | 
			
		||||
			SchemaType(textSchema).
 | 
			
		||||
			Default(""),
 | 
			
		||||
 | 
			
		||||
		field.Text("connector_id").
 | 
			
		||||
			SchemaType(textSchema).
 | 
			
		||||
			NotEmpty(),
 | 
			
		||||
		field.Bytes("connector_data").
 | 
			
		||||
			Nillable().
 | 
			
		||||
			Optional(),
 | 
			
		||||
 | 
			
		||||
		field.Text("token").
 | 
			
		||||
			SchemaType(textSchema).
 | 
			
		||||
			Default(""),
 | 
			
		||||
 | 
			
		||||
		field.Time("created_at").
 | 
			
		||||
			Default(time.Now),
 | 
			
		||||
		field.Time("last_used").
 | 
			
		||||
			Default(time.Now),
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// Edges of the RefreshToken.
 | 
			
		||||
func (RefreshToken) Edges() []ent.Edge {
 | 
			
		||||
	return []ent.Edge{}
 | 
			
		||||
}
 | 
			
		||||
							
								
								
									
										9
									
								
								storage/ent/schema/types.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										9
									
								
								storage/ent/schema/types.go
									
									
									
									
									
										Normal file
									
								
							@@ -0,0 +1,9 @@
 | 
			
		||||
package schema
 | 
			
		||||
 | 
			
		||||
import (
 | 
			
		||||
	"github.com/facebook/ent/dialect"
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
var textSchema = map[string]string{
 | 
			
		||||
	dialect.SQLite: "text",
 | 
			
		||||
}
 | 
			
		||||
							
								
								
									
										65
									
								
								storage/ent/sqlite.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										65
									
								
								storage/ent/sqlite.go
									
									
									
									
									
										Normal file
									
								
							@@ -0,0 +1,65 @@
 | 
			
		||||
package ent
 | 
			
		||||
 | 
			
		||||
import (
 | 
			
		||||
	"context"
 | 
			
		||||
	"crypto/sha256"
 | 
			
		||||
	"strings"
 | 
			
		||||
 | 
			
		||||
	"github.com/facebook/ent/dialect/sql"
 | 
			
		||||
 | 
			
		||||
	// Register sqlite driver.
 | 
			
		||||
	_ "github.com/mattn/go-sqlite3"
 | 
			
		||||
 | 
			
		||||
	"github.com/dexidp/dex/pkg/log"
 | 
			
		||||
	"github.com/dexidp/dex/storage"
 | 
			
		||||
	"github.com/dexidp/dex/storage/ent/client"
 | 
			
		||||
	"github.com/dexidp/dex/storage/ent/db"
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
// SQLite3 options for creating an SQL db.
 | 
			
		||||
type SQLite3 struct {
 | 
			
		||||
	File string `json:"file"`
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// Open always returns a new in sqlite3 storage.
 | 
			
		||||
func (s *SQLite3) Open(logger log.Logger) (storage.Storage, error) {
 | 
			
		||||
	logger.Debug("experimental ent-based storage driver is enabled")
 | 
			
		||||
 | 
			
		||||
	// Implicitly set foreign_keys pragma to "on" because it is required by ent
 | 
			
		||||
	s.File = addFK(s.File)
 | 
			
		||||
 | 
			
		||||
	drv, err := sql.Open("sqlite3", s.File)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return nil, err
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	pool := drv.DB()
 | 
			
		||||
	if s.File == ":memory:" {
 | 
			
		||||
		// sqlite3 uses file locks to coordinate concurrent access. In memory
 | 
			
		||||
		// doesn't support this, so limit the number of connections to 1.
 | 
			
		||||
		pool.SetMaxOpenConns(1)
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	databaseClient := client.NewDatabase(
 | 
			
		||||
		client.WithClient(db.NewClient(db.Driver(drv))),
 | 
			
		||||
		client.WithHasher(sha256.New),
 | 
			
		||||
	)
 | 
			
		||||
 | 
			
		||||
	if err := databaseClient.Schema().Create(context.TODO()); err != nil {
 | 
			
		||||
		return nil, err
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	return databaseClient, nil
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func addFK(dsn string) string {
 | 
			
		||||
	if strings.Contains(dsn, "_fk") {
 | 
			
		||||
		return dsn
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	delim := "?"
 | 
			
		||||
	if strings.Contains(dsn, "?") {
 | 
			
		||||
		delim = "&"
 | 
			
		||||
	}
 | 
			
		||||
	return dsn + delim + "_fk=1"
 | 
			
		||||
}
 | 
			
		||||
							
								
								
									
										31
									
								
								storage/ent/sqlite_test.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										31
									
								
								storage/ent/sqlite_test.go
									
									
									
									
									
										Normal file
									
								
							@@ -0,0 +1,31 @@
 | 
			
		||||
package ent
 | 
			
		||||
 | 
			
		||||
import (
 | 
			
		||||
	"os"
 | 
			
		||||
	"testing"
 | 
			
		||||
 | 
			
		||||
	"github.com/sirupsen/logrus"
 | 
			
		||||
 | 
			
		||||
	"github.com/dexidp/dex/storage"
 | 
			
		||||
	"github.com/dexidp/dex/storage/conformance"
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
func newStorage() storage.Storage {
 | 
			
		||||
	logger := &logrus.Logger{
 | 
			
		||||
		Out:       os.Stderr,
 | 
			
		||||
		Formatter: &logrus.TextFormatter{DisableColors: true},
 | 
			
		||||
		Level:     logrus.DebugLevel,
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	cfg := SQLite3{File: ":memory:"}
 | 
			
		||||
	s, err := cfg.Open(logger)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		panic(err)
 | 
			
		||||
	}
 | 
			
		||||
	return s
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func TestSQLite3(t *testing.T) {
 | 
			
		||||
	conformance.RunTests(t, newStorage)
 | 
			
		||||
	conformance.RunTransactionTests(t, newStorage)
 | 
			
		||||
}
 | 
			
		||||
		Reference in New Issue
	
	Block a user