Add obsolete tokens, resolve conflicts, bump ent

Signed-off-by: m.nabokikh <maksim.nabokikh@flant.com>
This commit is contained in:
m.nabokikh
2021-04-30 18:31:49 +04:00
parent 24fa4def5b
commit 8553309db3
44 changed files with 766 additions and 111 deletions

View File

@@ -10,6 +10,16 @@ import (
"entgo.io/ent/dialect"
"entgo.io/ent/dialect/sql"
"entgo.io/ent/dialect/sql/sqlgraph"
"github.com/dexidp/dex/storage/ent/db/authcode"
"github.com/dexidp/dex/storage/ent/db/authrequest"
"github.com/dexidp/dex/storage/ent/db/connector"
"github.com/dexidp/dex/storage/ent/db/devicerequest"
"github.com/dexidp/dex/storage/ent/db/devicetoken"
"github.com/dexidp/dex/storage/ent/db/keys"
"github.com/dexidp/dex/storage/ent/db/oauth2client"
"github.com/dexidp/dex/storage/ent/db/offlinesession"
"github.com/dexidp/dex/storage/ent/db/password"
"github.com/dexidp/dex/storage/ent/db/refreshtoken"
)
// ent aliases to avoid import conflicts in user's code.
@@ -25,36 +35,64 @@ type (
)
// OrderFunc applies an ordering on the sql selector.
type OrderFunc func(*sql.Selector, func(string) bool)
type OrderFunc func(*sql.Selector)
// columnChecker returns a function indicates if the column exists in the given column.
func columnChecker(table string) func(string) error {
checks := map[string]func(string) bool{
authcode.Table: authcode.ValidColumn,
authrequest.Table: authrequest.ValidColumn,
connector.Table: connector.ValidColumn,
devicerequest.Table: devicerequest.ValidColumn,
devicetoken.Table: devicetoken.ValidColumn,
keys.Table: keys.ValidColumn,
oauth2client.Table: oauth2client.ValidColumn,
offlinesession.Table: offlinesession.ValidColumn,
password.Table: password.ValidColumn,
refreshtoken.Table: refreshtoken.ValidColumn,
}
check, ok := checks[table]
if !ok {
return func(string) error {
return fmt.Errorf("unknown table %q", table)
}
}
return func(column string) error {
if !check(column) {
return fmt.Errorf("unknown column %q for table %q", column, table)
}
return nil
}
}
// Asc applies the given fields in ASC order.
func Asc(fields ...string) OrderFunc {
return func(s *sql.Selector, check func(string) bool) {
return func(s *sql.Selector) {
check := columnChecker(s.TableName())
for _, f := range fields {
if check(f) {
s.OrderBy(sql.Asc(f))
} else {
s.AddError(&ValidationError{Name: f, err: fmt.Errorf("invalid field %q for ordering", f)})
if err := check(f); err != nil {
s.AddError(&ValidationError{Name: f, err: fmt.Errorf("db: %w", err)})
}
s.OrderBy(sql.Asc(s.C(f)))
}
}
}
// Desc applies the given fields in DESC order.
func Desc(fields ...string) OrderFunc {
return func(s *sql.Selector, check func(string) bool) {
return func(s *sql.Selector) {
check := columnChecker(s.TableName())
for _, f := range fields {
if check(f) {
s.OrderBy(sql.Desc(f))
} else {
s.AddError(&ValidationError{Name: f, err: fmt.Errorf("invalid field %q for ordering", f)})
if err := check(f); err != nil {
s.AddError(&ValidationError{Name: f, err: fmt.Errorf("db: %w", err)})
}
s.OrderBy(sql.Desc(s.C(f)))
}
}
}
// AggregateFunc applies an aggregation step on the group-by traversal/selector.
type AggregateFunc func(*sql.Selector, func(string) bool) string
type AggregateFunc func(*sql.Selector) string
// As is a pseudo aggregation function for renaming another other functions with custom names. For example:
//
@@ -63,23 +101,24 @@ type AggregateFunc func(*sql.Selector, func(string) bool) string
// Scan(ctx, &v)
//
func As(fn AggregateFunc, end string) AggregateFunc {
return func(s *sql.Selector, check func(string) bool) string {
return sql.As(fn(s, check), end)
return func(s *sql.Selector) string {
return sql.As(fn(s), end)
}
}
// Count applies the "count" aggregation function on each group.
func Count() AggregateFunc {
return func(s *sql.Selector, _ func(string) bool) string {
return func(s *sql.Selector) string {
return sql.Count("*")
}
}
// Max applies the "max" aggregation function on the given field of each group.
func Max(field string) AggregateFunc {
return func(s *sql.Selector, check func(string) bool) string {
if !check(field) {
s.AddError(&ValidationError{Name: field, err: fmt.Errorf("invalid field %q for grouping", field)})
return func(s *sql.Selector) string {
check := columnChecker(s.TableName())
if err := check(field); err != nil {
s.AddError(&ValidationError{Name: field, err: fmt.Errorf("db: %w", err)})
return ""
}
return sql.Max(s.C(field))
@@ -88,9 +127,10 @@ func Max(field string) AggregateFunc {
// Mean applies the "mean" aggregation function on the given field of each group.
func Mean(field string) AggregateFunc {
return func(s *sql.Selector, check func(string) bool) string {
if !check(field) {
s.AddError(&ValidationError{Name: field, err: fmt.Errorf("invalid field %q for grouping", field)})
return func(s *sql.Selector) string {
check := columnChecker(s.TableName())
if err := check(field); err != nil {
s.AddError(&ValidationError{Name: field, err: fmt.Errorf("db: %w", err)})
return ""
}
return sql.Avg(s.C(field))
@@ -99,9 +139,10 @@ func Mean(field string) AggregateFunc {
// Min applies the "min" aggregation function on the given field of each group.
func Min(field string) AggregateFunc {
return func(s *sql.Selector, check func(string) bool) string {
if !check(field) {
s.AddError(&ValidationError{Name: field, err: fmt.Errorf("invalid field %q for grouping", field)})
return func(s *sql.Selector) string {
check := columnChecker(s.TableName())
if err := check(field); err != nil {
s.AddError(&ValidationError{Name: field, err: fmt.Errorf("db: %w", err)})
return ""
}
return sql.Min(s.C(field))
@@ -110,9 +151,10 @@ func Min(field string) AggregateFunc {
// Sum applies the "sum" aggregation function on the given field of each group.
func Sum(field string) AggregateFunc {
return func(s *sql.Selector, check func(string) bool) string {
if !check(field) {
s.AddError(&ValidationError{Name: field, err: fmt.Errorf("invalid field %q for grouping", field)})
return func(s *sql.Selector) string {
check := columnChecker(s.TableName())
if err := check(field); err != nil {
s.AddError(&ValidationError{Name: field, err: fmt.Errorf("db: %w", err)})
return ""
}
return sql.Sum(s.C(field))