From 8553309db385268066470e1ed5250dce3a2f544a Mon Sep 17 00:00:00 2001 From: "m.nabokikh" Date: Fri, 30 Apr 2021 18:31:49 +0400 Subject: [PATCH] Add obsolete tokens, resolve conflicts, bump ent Signed-off-by: m.nabokikh --- go.mod | 4 +- go.sum | 17 ++- storage/ent/client/refreshtoken.go | 2 + storage/ent/client/types.go | 1 + storage/ent/db/authcode.go | 8 +- storage/ent/db/authcode_query.go | 17 ++- storage/ent/db/authcode_update.go | 20 ++++ storage/ent/db/authrequest.go | 8 +- storage/ent/db/authrequest_query.go | 17 ++- storage/ent/db/authrequest_update.go | 20 ++++ storage/ent/db/client.go | 40 +++++-- storage/ent/db/connector.go | 4 +- storage/ent/db/connector_query.go | 17 ++- storage/ent/db/connector_update.go | 20 ++++ storage/ent/db/devicerequest.go | 8 +- storage/ent/db/devicerequest_query.go | 17 ++- storage/ent/db/devicerequest_update.go | 20 ++++ storage/ent/db/devicetoken.go | 8 +- storage/ent/db/devicetoken_query.go | 17 ++- storage/ent/db/devicetoken_update.go | 20 ++++ storage/ent/db/ent.go | 96 +++++++++++----- storage/ent/db/keys.go | 6 +- storage/ent/db/keys_query.go | 17 ++- storage/ent/db/keys_update.go | 20 ++++ storage/ent/db/migrate/schema.go | 1 + storage/ent/db/mutation.go | 56 +++++++++- storage/ent/db/oauth2client.go | 6 +- storage/ent/db/oauth2client_query.go | 17 ++- storage/ent/db/oauth2client_update.go | 20 ++++ storage/ent/db/offlinesession.go | 4 +- storage/ent/db/offlinesession_query.go | 17 ++- storage/ent/db/offlinesession_update.go | 20 ++++ storage/ent/db/password.go | 6 +- storage/ent/db/password_query.go | 17 ++- storage/ent/db/password_update.go | 20 ++++ storage/ent/db/refreshtoken.go | 20 +++- storage/ent/db/refreshtoken/refreshtoken.go | 5 + storage/ent/db/refreshtoken/where.go | 118 ++++++++++++++++++++ storage/ent/db/refreshtoken_create.go | 29 +++++ storage/ent/db/refreshtoken_query.go | 17 ++- storage/ent/db/refreshtoken_update.go | 62 ++++++++++ storage/ent/db/runtime.go | 8 +- storage/ent/db/runtime/runtime.go | 4 +- storage/ent/schema/refreshtoken.go | 6 +- 44 files changed, 766 insertions(+), 111 deletions(-) diff --git a/go.mod b/go.mod index ec276810..f5e01257 100644 --- a/go.mod +++ b/go.mod @@ -3,12 +3,12 @@ module github.com/dexidp/dex go 1.16 require ( - entgo.io/ent v0.7.0 + entgo.io/ent v0.8.0 github.com/AppsFlyer/go-sundheit v0.3.1 github.com/beevik/etree v1.1.0 github.com/coreos/go-oidc/v3 v3.0.0 github.com/dexidp/dex/api/v2 v2.0.0 - github.com/felixge/httpsnoop v1.0.1 + github.com/felixge/httpsnoop v1.0.2 github.com/ghodss/yaml v1.0.0 github.com/go-ldap/ldap/v3 v3.3.0 github.com/go-sql-driver/mysql v1.6.0 diff --git a/go.sum b/go.sum index baf9c97d..f80698ee 100644 --- a/go.sum +++ b/go.sum @@ -38,6 +38,8 @@ cloud.google.com/go/storage v1.6.0/go.mod h1:N7U0C8pVQ/+NIKOBQyamJIeKQKkZ+mxpohl cloud.google.com/go/storage v1.8.0/go.mod h1:Wv1Oy7z6Yz3DshWRJFhqM/UCfaWIRTdp0RXyy7KQOVs= cloud.google.com/go/storage v1.10.0/go.mod h1:FLPqc6j+Ki4BU591ie1oL6qBQGu2Bl/tZ9ullr3+Kg0= dmitri.shuralyov.com/gpu/mtl v0.0.0-20190408044501-666a987793e9/go.mod h1:H6x//7gZCb22OMCxBHrMx7a5I7Hp++hsVxbQ4BYO7hU= +entgo.io/ent v0.8.0 h1:xirrW//1oda7pp0bz+XssSOv4/C3nmgYQOxjIfljFt8= +entgo.io/ent v0.8.0/go.mod h1:KNjsukat/NJi6zJh1utwRadsbGOZsBbAZNDxkW7tMCc= github.com/AppsFlyer/go-sundheit v0.3.1 h1:Zqnr3wV3WQmXonc234k9XZAoV2KHUHw3osR5k2iHQZE= github.com/AppsFlyer/go-sundheit v0.3.1/go.mod h1:iZ8zWMS7idcvmqewf5mEymWWgoOiG/0WD4+aeh+heX4= github.com/Azure/go-ntlmssp v0.0.0-20200615164410-66371956d46c h1:/IBSNwUN8+eKzUzbJPqhK839ygXJ82sde8x3ogr6R28= @@ -45,6 +47,8 @@ github.com/Azure/go-ntlmssp v0.0.0-20200615164410-66371956d46c/go.mod h1:chxPXzS github.com/BurntSushi/toml v0.3.1 h1:WXkYYl6Yr3qBf1K79EBnL4mak0OimBfB0XUf9Vl28OQ= github.com/BurntSushi/toml v0.3.1/go.mod h1:xHWCNGjB5oqiDr8zfno3MHue2Ht5sIBksp03qcyfWMU= github.com/BurntSushi/xgb v0.0.0-20160522181843-27f122750802/go.mod h1:IVnqGOEym/WlBOVXweHU+Q+/VP0lqqI8lqeDx9IjBqo= +github.com/DATA-DOG/go-sqlmock v1.5.0 h1:Shsta01QNfFxHCfpW6YH2STWB0MudeXXEWMr20OEh60= +github.com/DATA-DOG/go-sqlmock v1.5.0/go.mod h1:f/Ixk793poVmq4qj/V1dPUg2JEAKC73Q5eFN3EC/SaM= github.com/Knetic/govaluate v3.0.1-0.20171022003610-9aa49832a739+incompatible/go.mod h1:r7JcOSlj0wfOMncg0iLm8Leh48TZaKVeNIfJntJ2wa0= github.com/OneOfOne/xxhash v1.2.2/go.mod h1:HSdplMjZKSmBqAxg5vPj2TmRDmfkzw+cTzAElWljhcU= github.com/Shopify/sarama v1.19.0/go.mod h1:FVkBWblsNy7DGZRfXLU0O9RCGt5g3g3yEuWXgklEdEo= @@ -128,8 +132,6 @@ github.com/envoyproxy/go-control-plane v0.9.7/go.mod h1:cwu0lG7PUMfa9snN8LXBig5y github.com/envoyproxy/go-control-plane v0.9.9-0.20201210154907-fd9021fe5dad/go.mod h1:cXg6YxExXjJnVBQHBLXeUAgxn2UodCpnH306RInaBQk= github.com/envoyproxy/go-control-plane v0.9.9-0.20210217033140-668b12f5399d/go.mod h1:cXg6YxExXjJnVBQHBLXeUAgxn2UodCpnH306RInaBQk= github.com/envoyproxy/protoc-gen-validate v0.1.0/go.mod h1:iSmxcyjqTsJpI2R4NaDN7+kN2VEUnK/pcBlmesArF7c= -github.com/facebook/ent v0.5.3 h1:YT3Sl28n7gGGOkQeYgeJsZmizJ1Iiy7psgkOtEk0aq4= -github.com/facebook/ent v0.5.3/go.mod h1:tlWP+qCd3x2EeO7B/EqlJQ4dWu/2IeYFhP/szzDKAi8= github.com/fatih/color v1.7.0/go.mod h1:Zm6kSWBoL9eyXnKyktHP6abPY2pDugNf5KwzbycvMj4= github.com/felixge/httpsnoop v1.0.1/go.mod h1:m8KPJKqk1gH5J9DgRY2ASl2lWCfGKXixSwevea8zH2U= github.com/felixge/httpsnoop v1.0.2 h1:+nS9g82KMXccJ/wp0zyRW9ZBHFETmMGtkk+2CTTrW4o= @@ -143,6 +145,7 @@ github.com/ghodss/yaml v1.0.0 h1:wQHKEahhL6wmXdzwWG11gIVCkOv05bNOh+Rxn0yngAk= github.com/ghodss/yaml v1.0.0/go.mod h1:4dBDuWmgqj2HViK6kFavaiC9ZROes6MMH2rRYeMEF04= github.com/go-asn1-ber/asn1-ber v1.5.1 h1:pDbRAunXzIUXfx4CB2QJFv5IuPiuoW+sWvr/Us009o8= github.com/go-asn1-ber/asn1-ber v1.5.1/go.mod h1:hEBeB/ic+5LoWskz+yKT7vGhhPYkProFKoKdwZRWMe0= +github.com/go-bindata/go-bindata v1.0.1-0.20190711162640-ee3c2418e368/go.mod h1:7xCgX1lzlrXPHkfvn3EhumqHkmSlzt8at9q7v0ax19c= github.com/go-gl/glfw v0.0.0-20190409004039-e6da0acd62b1/go.mod h1:vR7hzQXu2zJy9AVAgeJqvqgH9Q5CA+iKCZ2gyEVpxRU= github.com/go-gl/glfw/v3.3/glfw v0.0.0-20191125211704-12ad95a8df72/go.mod h1:tQ2UAYgL5IevRw8kRxooKSPJfGvJ9fJQFa0TUsXzTg8= github.com/go-gl/glfw/v3.3/glfw v0.0.0-20200222043503-6f7a984d4dc4/go.mod h1:tQ2UAYgL5IevRw8kRxooKSPJfGvJ9fJQFa0TUsXzTg8= @@ -154,7 +157,9 @@ github.com/go-ldap/ldap/v3 v3.3.0/go.mod h1:iYS1MdmrmceOJ1QOTnRXrIs7i3kloqtmGQjR github.com/go-logfmt/logfmt v0.3.0/go.mod h1:Qt1PoO58o5twSAckw1HlFXLmHsOX5/0LbT9GBnD5lWE= github.com/go-logfmt/logfmt v0.4.0/go.mod h1:3RMwSq7FuexP4Kalkev3ejPJsZTpXXBr9+V4qmtdjCk= github.com/go-logfmt/logfmt v0.5.0/go.mod h1:wCYkCAKZfumFQihp8CzCvQ3paCTfi41vtzG1KdI/P7A= +github.com/go-openapi/inflect v0.19.0/go.mod h1:lHpZVlpIQqLyKwJ4N+YSc9hchQy/i12fJykb83CRBH4= github.com/go-sql-driver/mysql v1.4.0/go.mod h1:zAC/RDZ24gD3HViQzih4MyKcchzm+sOG5ZlKdlhCg5w= +github.com/go-sql-driver/mysql v1.5.1-0.20200311113236-681ffa848bae/go.mod h1:DCzpHaOWr8IXmIStZouvnhqoel9Qv2LBy8hT2VhHyBg= github.com/go-sql-driver/mysql v1.6.0 h1:BCTh4TKNUYmOmMUcQ3IipzF5prigylS7XXjEkfCHuOE= github.com/go-sql-driver/mysql v1.6.0/go.mod h1:DCzpHaOWr8IXmIStZouvnhqoel9Qv2LBy8hT2VhHyBg= github.com/go-stack/stack v1.8.0/go.mod h1:v0f6uXyyMGvRgIKkXu+yp6POWl0qKG85gN/melR3HDY= @@ -232,6 +237,8 @@ github.com/google/pprof v0.0.0-20210226084205-cbba55b83ad5/go.mod h1:kpwsk12EmLe github.com/google/renameio v0.1.0/go.mod h1:KWCgfxg9yswjAJkECMjeO8J8rahYeXnNhOm40UhjYkI= github.com/google/uuid v1.0.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo= github.com/google/uuid v1.1.2/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo= +github.com/google/uuid v1.2.0 h1:qJYtXnJRWmpe7m/3XlyhrsLrEURqHRM2kxzoxXqyUDs= +github.com/google/uuid v1.2.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo= github.com/googleapis/gax-go/v2 v2.0.4/go.mod h1:0Wqv26UfaUD9n4G6kQubkQ+KchISgw+vpHVxEJEs9eg= github.com/googleapis/gax-go/v2 v2.0.5 h1:sjZBwGj9Jlw33ImPtvFviGYvseOtDM7hkSKB7+Tv3SM= github.com/googleapis/gax-go/v2 v2.0.5/go.mod h1:DWXyrwAJ9X0FpwwEdw+IPEYBICEFu5mhpdKc/us6bOk= @@ -283,6 +290,7 @@ github.com/ianlancetaylor/demangle v0.0.0-20200824232613-28f6c0f3b639/go.mod h1: github.com/inconshreveable/mousetrap v1.0.0 h1:Z8tu5sraLXCXIcARxBp/8cbvlwVa7Z1NHg9XEKhtSvM= github.com/inconshreveable/mousetrap v1.0.0/go.mod h1:PxqpIevigyE2G7u3NXJIT2ANytuPF1OarO4DADm73n8= github.com/influxdata/influxdb1-client v0.0.0-20191209144304-8bf82d3c094d/go.mod h1:qj24IKcXYK6Iy9ceXlo3Tc+vtHo9lIhSX5JddghvEPo= +github.com/jessevdk/go-flags v1.5.0/go.mod h1:Fw0T6WPc1dYxT4mKEZRfG5kJhaTDP9pj1c2EWnYs/m4= github.com/jmespath/go-jmespath v0.0.0-20180206201540-c2b33e8439af/go.mod h1:Nht3zPeWKUH0NzdCt2Blrr5ys8VGpn0CEB0cQHVjt7k= github.com/jonboulle/clockwork v0.1.0/go.mod h1:Ii8DK3G1RaLaWxj9trq07+26W01tbo22gdxWY5EU2bo= github.com/jonboulle/clockwork v0.2.0 h1:J2SLSdy7HgElq8ekSl2Mxh6vrRNFxqbXGenYH2I02Vs= @@ -311,6 +319,7 @@ github.com/kr/text v0.1.0 h1:45sCR5RtlFHMR4UwH9sdQ5TC8v0qDQCHnXt+kaKSTVE= github.com/kr/text v0.1.0/go.mod h1:4Jbv+DJW3UT/LiOwJeYQe1efqtUx/iVham/4vfdArNI= github.com/kylelemons/godebug v1.1.0 h1:RPNrshWIDI6G2gRW9EHilWtl7Z6Sb1BR0xunSBf0SNc= github.com/kylelemons/godebug v1.1.0/go.mod h1:9/0rRGxNHcop5bhtWyNeEfOS8JIWk580+fNqagV/RAw= +github.com/lib/pq v1.10.0/go.mod h1:AlVN5x4E4T544tWzH6hKfbfQvm3HdbOxrmggDNAPY9o= github.com/lib/pq v1.10.1 h1:6VXZrLU0jHBYyAqrSPa+MgPfnSvTPuMgK+k0o5kVFWo= github.com/lib/pq v1.10.1/go.mod h1:AlVN5x4E4T544tWzH6hKfbfQvm3HdbOxrmggDNAPY9o= github.com/lightstep/lightstep-tracer-common/golang/gogo v0.0.0-20190605223551-bc2310a04743/go.mod h1:qklhhLq1aX+mtWk9cPHPzaBjWImj5ULL6C7HFJtXQMM= @@ -323,6 +332,8 @@ github.com/mattn/go-colorable v0.0.9/go.mod h1:9vuHe8Xs5qXnSaW/c/ABM9alt+Vo+STaO github.com/mattn/go-isatty v0.0.3/go.mod h1:M+lRXTBqGeGNdLjl/ufCoiOlB5xdOkqRJdNxMWT7Zi4= github.com/mattn/go-isatty v0.0.4/go.mod h1:M+lRXTBqGeGNdLjl/ufCoiOlB5xdOkqRJdNxMWT7Zi4= github.com/mattn/go-runewidth v0.0.2/go.mod h1:LwmH8dsx7+W8Uxz3IHJYH5QSwggIsqBzpuz5H//U1FU= +github.com/mattn/go-runewidth v0.0.9/go.mod h1:H031xJmbD/WCDINGzjvQ9THkh0rPKHF+m2gUSrubnMI= +github.com/mattn/go-sqlite3 v1.14.6/go.mod h1:NyWgC/yNuGj7Q9rpYnZvas74GogHl5/Z4A/KQRfk6bU= github.com/mattn/go-sqlite3 v1.14.7 h1:fxWBnXkxfM6sRiuH3bqJ4CfzZojMOLVc0UTsTglEghA= github.com/mattn/go-sqlite3 v1.14.7/go.mod h1:NyWgC/yNuGj7Q9rpYnZvas74GogHl5/Z4A/KQRfk6bU= github.com/matttproud/golang_protobuf_extensions v1.0.1 h1:4hp9jkHxhMHkqkrB3Ix0jegS5sx/RkqARlsWZ6pIwiU= @@ -336,6 +347,7 @@ github.com/mitchellh/gox v0.4.0/go.mod h1:Sd9lOJ0+aimLBi73mGofS1ycjY8lL3uZM3JPS4 github.com/mitchellh/iochan v1.0.0/go.mod h1:JwYml1nuB7xOzsp52dPpHFffvOCDupsG0QubkSMEySY= github.com/mitchellh/mapstructure v0.0.0-20160808181253-ca63d7c062ee/go.mod h1:FVVH3fgwuzCH5S8UJGiWEs2h04kUh9fWfEaFds41c1Y= github.com/mitchellh/mapstructure v1.1.2/go.mod h1:FVVH3fgwuzCH5S8UJGiWEs2h04kUh9fWfEaFds41c1Y= +github.com/mitchellh/mapstructure v1.4.1/go.mod h1:bFUtVrKA4DC2yAKiSyO/QUcy7e+RRV2QTWOzhPopBRo= github.com/modern-go/concurrent v0.0.0-20180228061459-e0a39a4cb421/go.mod h1:6dJC0mAP4ikYIbvyc7fijjWJddQyLn8Ig3JB5CqoB9Q= github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd/go.mod h1:6dJC0mAP4ikYIbvyc7fijjWJddQyLn8Ig3JB5CqoB9Q= github.com/modern-go/reflect2 v0.0.0-20180701023420-4b7aa43c6742/go.mod h1:bx2lNnkwVCuqBIxFjflWJWanXIb3RllmbCylyMrvgv0= @@ -355,6 +367,7 @@ github.com/oklog/run v1.1.0 h1:GEenZ1cK0+q0+wsJew9qUg/DyD8k3JzYsZAi5gYi2mA= github.com/oklog/run v1.1.0/go.mod h1:sVPdnTZT1zYwAJeCMu2Th4T21pA3FPOQRfWjQlk7DVU= github.com/oklog/ulid v1.3.1/go.mod h1:CirwcVhetQ6Lv90oh/F+FBtV6XMibvdAFo93nm5qn4U= github.com/olekukonko/tablewriter v0.0.0-20170122224234-a0225b3f23b5/go.mod h1:vsDQFd/mU46D+Z4whnwzcISnGGzXWMclvtLoiIKAKIo= +github.com/olekukonko/tablewriter v0.0.5/go.mod h1:hPp6KlRPjbx+hW8ykQs1w3UBbZlj6HuIJcUGPhkA7kY= github.com/onsi/ginkgo v1.4.0/go.mod h1:lLunBs/Ym6LB5Z9jYTR76FiuTmxDTDusOGeTQH+WWjE= github.com/onsi/ginkgo v1.6.0/go.mod h1:lLunBs/Ym6LB5Z9jYTR76FiuTmxDTDusOGeTQH+WWjE= github.com/onsi/ginkgo v1.7.0 h1:WSHQ+IS43OoUrWtD1/bbclrwK8TTH5hzp+umCiuxHgs= diff --git a/storage/ent/client/refreshtoken.go b/storage/ent/client/refreshtoken.go index 90f3c6ae..0b90233d 100644 --- a/storage/ent/client/refreshtoken.go +++ b/storage/ent/client/refreshtoken.go @@ -22,6 +22,7 @@ func (d *Database) CreateRefresh(refresh storage.RefreshToken) error { SetConnectorID(refresh.ConnectorID). SetConnectorData(refresh.ConnectorData). SetToken(refresh.Token). + SetObsoleteToken(refresh.ObsoleteToken). // Save utc time into database because ent doesn't support comparing dates with different timezones SetLastUsed(refresh.LastUsed.UTC()). SetCreatedAt(refresh.CreatedAt.UTC()). @@ -94,6 +95,7 @@ func (d *Database) UpdateRefreshToken(id string, updater func(old storage.Refres SetConnectorID(newtToken.ConnectorID). SetConnectorData(newtToken.ConnectorData). SetToken(newtToken.Token). + SetObsoleteToken(newtToken.ObsoleteToken). // Save utc time into database because ent doesn't support comparing dates with different timezones SetLastUsed(newtToken.LastUsed.UTC()). SetCreatedAt(newtToken.CreatedAt.UTC()). diff --git a/storage/ent/client/types.go b/storage/ent/client/types.go index 388ef3e5..57f1c0a7 100644 --- a/storage/ent/client/types.go +++ b/storage/ent/client/types.go @@ -117,6 +117,7 @@ func toStorageRefreshToken(r *db.RefreshToken) storage.RefreshToken { return storage.RefreshToken{ ID: r.ID, Token: r.Token, + ObsoleteToken: r.ObsoleteToken, CreatedAt: r.CreatedAt, LastUsed: r.LastUsed, ClientID: r.ClientID, diff --git a/storage/ent/db/authcode.go b/storage/ent/db/authcode.go index 6b177880..29b5e4f5 100644 --- a/storage/ent/db/authcode.go +++ b/storage/ent/db/authcode.go @@ -55,13 +55,13 @@ func (*AuthCode) scanValues(columns []string) ([]interface{}, error) { for i := range columns { switch columns[i] { case authcode.FieldScopes, authcode.FieldClaimsGroups, authcode.FieldConnectorData: - values[i] = &[]byte{} + values[i] = new([]byte) case authcode.FieldClaimsEmailVerified: - values[i] = &sql.NullBool{} + values[i] = new(sql.NullBool) case authcode.FieldID, authcode.FieldClientID, authcode.FieldNonce, authcode.FieldRedirectURI, authcode.FieldClaimsUserID, authcode.FieldClaimsUsername, authcode.FieldClaimsEmail, authcode.FieldClaimsPreferredUsername, authcode.FieldConnectorID, authcode.FieldCodeChallenge, authcode.FieldCodeChallengeMethod: - values[i] = &sql.NullString{} + values[i] = new(sql.NullString) case authcode.FieldExpiry: - values[i] = &sql.NullTime{} + values[i] = new(sql.NullTime) default: return nil, fmt.Errorf("unexpected column %q for type AuthCode", columns[i]) } diff --git a/storage/ent/db/authcode_query.go b/storage/ent/db/authcode_query.go index 3d2010c0..96b6a485 100644 --- a/storage/ent/db/authcode_query.go +++ b/storage/ent/db/authcode_query.go @@ -20,6 +20,7 @@ type AuthCodeQuery struct { config limit *int offset *int + unique *bool order []OrderFunc fields []string predicates []predicate.AuthCode @@ -46,6 +47,13 @@ func (acq *AuthCodeQuery) Offset(offset int) *AuthCodeQuery { return acq } +// Unique configures the query builder to filter duplicate records on query. +// By default, unique is set to true, and can be disabled using this method. +func (acq *AuthCodeQuery) Unique(unique bool) *AuthCodeQuery { + acq.unique = &unique + return acq +} + // Order adds an order step to the query. func (acq *AuthCodeQuery) Order(o ...OrderFunc) *AuthCodeQuery { acq.order = append(acq.order, o...) @@ -352,6 +360,9 @@ func (acq *AuthCodeQuery) querySpec() *sqlgraph.QuerySpec { From: acq.sql, Unique: true, } + if unique := acq.unique; unique != nil { + _spec.Unique = *unique + } if fields := acq.fields; len(fields) > 0 { _spec.Node.Columns = make([]string, 0, len(fields)) _spec.Node.Columns = append(_spec.Node.Columns, authcode.FieldID) @@ -377,7 +388,7 @@ func (acq *AuthCodeQuery) querySpec() *sqlgraph.QuerySpec { if ps := acq.order; len(ps) > 0 { _spec.Order = func(selector *sql.Selector) { for i := range ps { - ps[i](selector, authcode.ValidColumn) + ps[i](selector) } } } @@ -396,7 +407,7 @@ func (acq *AuthCodeQuery) sqlQuery(ctx context.Context) *sql.Selector { p(selector) } for _, p := range acq.order { - p(selector, authcode.ValidColumn) + p(selector) } if offset := acq.offset; offset != nil { // limit is mandatory for offset clause. We start @@ -662,7 +673,7 @@ func (acgb *AuthCodeGroupBy) sqlQuery() *sql.Selector { columns := make([]string, 0, len(acgb.fields)+len(acgb.fns)) columns = append(columns, acgb.fields...) for _, fn := range acgb.fns { - columns = append(columns, fn(selector, authcode.ValidColumn)) + columns = append(columns, fn(selector)) } return selector.Select(columns...).GroupBy(acgb.fields...) } diff --git a/storage/ent/db/authcode_update.go b/storage/ent/db/authcode_update.go index 86eb8c5f..08374bd3 100644 --- a/storage/ent/db/authcode_update.go +++ b/storage/ent/db/authcode_update.go @@ -416,6 +416,7 @@ func (acu *AuthCodeUpdate) sqlSave(ctx context.Context) (n int, err error) { // AuthCodeUpdateOne is the builder for updating a single AuthCode entity. type AuthCodeUpdateOne struct { config + fields []string hooks []Hook mutation *AuthCodeMutation } @@ -557,6 +558,13 @@ func (acuo *AuthCodeUpdateOne) Mutation() *AuthCodeMutation { return acuo.mutation } +// Select allows selecting one or more fields (columns) of the returned entity. +// The default is selecting all fields defined in the entity schema. +func (acuo *AuthCodeUpdateOne) Select(field string, fields ...string) *AuthCodeUpdateOne { + acuo.fields = append([]string{field}, fields...) + return acuo +} + // Save executes the query and returns the updated AuthCode entity. func (acuo *AuthCodeUpdateOne) Save(ctx context.Context) (*AuthCode, error) { var ( @@ -670,6 +678,18 @@ func (acuo *AuthCodeUpdateOne) sqlSave(ctx context.Context) (_node *AuthCode, er return nil, &ValidationError{Name: "ID", err: fmt.Errorf("missing AuthCode.ID for update")} } _spec.Node.ID.Value = id + if fields := acuo.fields; len(fields) > 0 { + _spec.Node.Columns = make([]string, 0, len(fields)) + _spec.Node.Columns = append(_spec.Node.Columns, authcode.FieldID) + for _, f := range fields { + if !authcode.ValidColumn(f) { + return nil, &ValidationError{Name: f, err: fmt.Errorf("db: invalid field %q for query", f)} + } + if f != authcode.FieldID { + _spec.Node.Columns = append(_spec.Node.Columns, f) + } + } + } if ps := acuo.mutation.predicates; len(ps) > 0 { _spec.Predicate = func(selector *sql.Selector) { for i := range ps { diff --git a/storage/ent/db/authrequest.go b/storage/ent/db/authrequest.go index 669185fe..ed64d9f6 100644 --- a/storage/ent/db/authrequest.go +++ b/storage/ent/db/authrequest.go @@ -63,13 +63,13 @@ func (*AuthRequest) scanValues(columns []string) ([]interface{}, error) { for i := range columns { switch columns[i] { case authrequest.FieldScopes, authrequest.FieldResponseTypes, authrequest.FieldClaimsGroups, authrequest.FieldConnectorData: - values[i] = &[]byte{} + values[i] = new([]byte) case authrequest.FieldForceApprovalPrompt, authrequest.FieldLoggedIn, authrequest.FieldClaimsEmailVerified: - values[i] = &sql.NullBool{} + values[i] = new(sql.NullBool) case authrequest.FieldID, authrequest.FieldClientID, authrequest.FieldRedirectURI, authrequest.FieldNonce, authrequest.FieldState, authrequest.FieldClaimsUserID, authrequest.FieldClaimsUsername, authrequest.FieldClaimsEmail, authrequest.FieldClaimsPreferredUsername, authrequest.FieldConnectorID, authrequest.FieldCodeChallenge, authrequest.FieldCodeChallengeMethod: - values[i] = &sql.NullString{} + values[i] = new(sql.NullString) case authrequest.FieldExpiry: - values[i] = &sql.NullTime{} + values[i] = new(sql.NullTime) default: return nil, fmt.Errorf("unexpected column %q for type AuthRequest", columns[i]) } diff --git a/storage/ent/db/authrequest_query.go b/storage/ent/db/authrequest_query.go index 4c4573ea..b55861cf 100644 --- a/storage/ent/db/authrequest_query.go +++ b/storage/ent/db/authrequest_query.go @@ -20,6 +20,7 @@ type AuthRequestQuery struct { config limit *int offset *int + unique *bool order []OrderFunc fields []string predicates []predicate.AuthRequest @@ -46,6 +47,13 @@ func (arq *AuthRequestQuery) Offset(offset int) *AuthRequestQuery { return arq } +// Unique configures the query builder to filter duplicate records on query. +// By default, unique is set to true, and can be disabled using this method. +func (arq *AuthRequestQuery) Unique(unique bool) *AuthRequestQuery { + arq.unique = &unique + return arq +} + // Order adds an order step to the query. func (arq *AuthRequestQuery) Order(o ...OrderFunc) *AuthRequestQuery { arq.order = append(arq.order, o...) @@ -352,6 +360,9 @@ func (arq *AuthRequestQuery) querySpec() *sqlgraph.QuerySpec { From: arq.sql, Unique: true, } + if unique := arq.unique; unique != nil { + _spec.Unique = *unique + } if fields := arq.fields; len(fields) > 0 { _spec.Node.Columns = make([]string, 0, len(fields)) _spec.Node.Columns = append(_spec.Node.Columns, authrequest.FieldID) @@ -377,7 +388,7 @@ func (arq *AuthRequestQuery) querySpec() *sqlgraph.QuerySpec { if ps := arq.order; len(ps) > 0 { _spec.Order = func(selector *sql.Selector) { for i := range ps { - ps[i](selector, authrequest.ValidColumn) + ps[i](selector) } } } @@ -396,7 +407,7 @@ func (arq *AuthRequestQuery) sqlQuery(ctx context.Context) *sql.Selector { p(selector) } for _, p := range arq.order { - p(selector, authrequest.ValidColumn) + p(selector) } if offset := arq.offset; offset != nil { // limit is mandatory for offset clause. We start @@ -662,7 +673,7 @@ func (argb *AuthRequestGroupBy) sqlQuery() *sql.Selector { columns := make([]string, 0, len(argb.fields)+len(argb.fns)) columns = append(columns, argb.fields...) for _, fn := range argb.fns { - columns = append(columns, fn(selector, authrequest.ValidColumn)) + columns = append(columns, fn(selector)) } return selector.Select(columns...).GroupBy(argb.fields...) } diff --git a/storage/ent/db/authrequest_update.go b/storage/ent/db/authrequest_update.go index 87057605..2d3f8594 100644 --- a/storage/ent/db/authrequest_update.go +++ b/storage/ent/db/authrequest_update.go @@ -434,6 +434,7 @@ func (aru *AuthRequestUpdate) sqlSave(ctx context.Context) (n int, err error) { // AuthRequestUpdateOne is the builder for updating a single AuthRequest entity. type AuthRequestUpdateOne struct { config + fields []string hooks []Hook mutation *AuthRequestMutation } @@ -605,6 +606,13 @@ func (aruo *AuthRequestUpdateOne) Mutation() *AuthRequestMutation { return aruo.mutation } +// Select allows selecting one or more fields (columns) of the returned entity. +// The default is selecting all fields defined in the entity schema. +func (aruo *AuthRequestUpdateOne) Select(field string, fields ...string) *AuthRequestUpdateOne { + aruo.fields = append([]string{field}, fields...) + return aruo +} + // Save executes the query and returns the updated AuthRequest entity. func (aruo *AuthRequestUpdateOne) Save(ctx context.Context) (*AuthRequest, error) { var ( @@ -672,6 +680,18 @@ func (aruo *AuthRequestUpdateOne) sqlSave(ctx context.Context) (_node *AuthReque return nil, &ValidationError{Name: "ID", err: fmt.Errorf("missing AuthRequest.ID for update")} } _spec.Node.ID.Value = id + if fields := aruo.fields; len(fields) > 0 { + _spec.Node.Columns = make([]string, 0, len(fields)) + _spec.Node.Columns = append(_spec.Node.Columns, authrequest.FieldID) + for _, f := range fields { + if !authrequest.ValidColumn(f) { + return nil, &ValidationError{Name: f, err: fmt.Errorf("db: invalid field %q for query", f)} + } + if f != authrequest.FieldID { + _spec.Node.Columns = append(_spec.Node.Columns, f) + } + } + } if ps := aruo.mutation.predicates; len(ps) > 0 { _spec.Predicate = func(selector *sql.Selector) { for i := range ps { diff --git a/storage/ent/db/client.go b/storage/ent/db/client.go index 1cd21708..a27286a0 100644 --- a/storage/ent/db/client.go +++ b/storage/ent/db/client.go @@ -250,7 +250,9 @@ func (c *AuthCodeClient) DeleteOneID(id string) *AuthCodeDeleteOne { // Query returns a query builder for AuthCode. func (c *AuthCodeClient) Query() *AuthCodeQuery { - return &AuthCodeQuery{config: c.config} + return &AuthCodeQuery{ + config: c.config, + } } // Get returns a AuthCode entity by its id. @@ -338,7 +340,9 @@ func (c *AuthRequestClient) DeleteOneID(id string) *AuthRequestDeleteOne { // Query returns a query builder for AuthRequest. func (c *AuthRequestClient) Query() *AuthRequestQuery { - return &AuthRequestQuery{config: c.config} + return &AuthRequestQuery{ + config: c.config, + } } // Get returns a AuthRequest entity by its id. @@ -426,7 +430,9 @@ func (c *ConnectorClient) DeleteOneID(id string) *ConnectorDeleteOne { // Query returns a query builder for Connector. func (c *ConnectorClient) Query() *ConnectorQuery { - return &ConnectorQuery{config: c.config} + return &ConnectorQuery{ + config: c.config, + } } // Get returns a Connector entity by its id. @@ -514,7 +520,9 @@ func (c *DeviceRequestClient) DeleteOneID(id int) *DeviceRequestDeleteOne { // Query returns a query builder for DeviceRequest. func (c *DeviceRequestClient) Query() *DeviceRequestQuery { - return &DeviceRequestQuery{config: c.config} + return &DeviceRequestQuery{ + config: c.config, + } } // Get returns a DeviceRequest entity by its id. @@ -602,7 +610,9 @@ func (c *DeviceTokenClient) DeleteOneID(id int) *DeviceTokenDeleteOne { // Query returns a query builder for DeviceToken. func (c *DeviceTokenClient) Query() *DeviceTokenQuery { - return &DeviceTokenQuery{config: c.config} + return &DeviceTokenQuery{ + config: c.config, + } } // Get returns a DeviceToken entity by its id. @@ -690,7 +700,9 @@ func (c *KeysClient) DeleteOneID(id string) *KeysDeleteOne { // Query returns a query builder for Keys. func (c *KeysClient) Query() *KeysQuery { - return &KeysQuery{config: c.config} + return &KeysQuery{ + config: c.config, + } } // Get returns a Keys entity by its id. @@ -778,7 +790,9 @@ func (c *OAuth2ClientClient) DeleteOneID(id string) *OAuth2ClientDeleteOne { // Query returns a query builder for OAuth2Client. func (c *OAuth2ClientClient) Query() *OAuth2ClientQuery { - return &OAuth2ClientQuery{config: c.config} + return &OAuth2ClientQuery{ + config: c.config, + } } // Get returns a OAuth2Client entity by its id. @@ -866,7 +880,9 @@ func (c *OfflineSessionClient) DeleteOneID(id string) *OfflineSessionDeleteOne { // Query returns a query builder for OfflineSession. func (c *OfflineSessionClient) Query() *OfflineSessionQuery { - return &OfflineSessionQuery{config: c.config} + return &OfflineSessionQuery{ + config: c.config, + } } // Get returns a OfflineSession entity by its id. @@ -954,7 +970,9 @@ func (c *PasswordClient) DeleteOneID(id int) *PasswordDeleteOne { // Query returns a query builder for Password. func (c *PasswordClient) Query() *PasswordQuery { - return &PasswordQuery{config: c.config} + return &PasswordQuery{ + config: c.config, + } } // Get returns a Password entity by its id. @@ -1042,7 +1060,9 @@ func (c *RefreshTokenClient) DeleteOneID(id string) *RefreshTokenDeleteOne { // Query returns a query builder for RefreshToken. func (c *RefreshTokenClient) Query() *RefreshTokenQuery { - return &RefreshTokenQuery{config: c.config} + return &RefreshTokenQuery{ + config: c.config, + } } // Get returns a RefreshToken entity by its id. diff --git a/storage/ent/db/connector.go b/storage/ent/db/connector.go index 94614c44..3bcb7ee5 100644 --- a/storage/ent/db/connector.go +++ b/storage/ent/db/connector.go @@ -31,9 +31,9 @@ func (*Connector) scanValues(columns []string) ([]interface{}, error) { for i := range columns { switch columns[i] { case connector.FieldConfig: - values[i] = &[]byte{} + values[i] = new([]byte) case connector.FieldID, connector.FieldType, connector.FieldName, connector.FieldResourceVersion: - values[i] = &sql.NullString{} + values[i] = new(sql.NullString) default: return nil, fmt.Errorf("unexpected column %q for type Connector", columns[i]) } diff --git a/storage/ent/db/connector_query.go b/storage/ent/db/connector_query.go index 89d0acfb..2b4c7872 100644 --- a/storage/ent/db/connector_query.go +++ b/storage/ent/db/connector_query.go @@ -20,6 +20,7 @@ type ConnectorQuery struct { config limit *int offset *int + unique *bool order []OrderFunc fields []string predicates []predicate.Connector @@ -46,6 +47,13 @@ func (cq *ConnectorQuery) Offset(offset int) *ConnectorQuery { return cq } +// Unique configures the query builder to filter duplicate records on query. +// By default, unique is set to true, and can be disabled using this method. +func (cq *ConnectorQuery) Unique(unique bool) *ConnectorQuery { + cq.unique = &unique + return cq +} + // Order adds an order step to the query. func (cq *ConnectorQuery) Order(o ...OrderFunc) *ConnectorQuery { cq.order = append(cq.order, o...) @@ -352,6 +360,9 @@ func (cq *ConnectorQuery) querySpec() *sqlgraph.QuerySpec { From: cq.sql, Unique: true, } + if unique := cq.unique; unique != nil { + _spec.Unique = *unique + } if fields := cq.fields; len(fields) > 0 { _spec.Node.Columns = make([]string, 0, len(fields)) _spec.Node.Columns = append(_spec.Node.Columns, connector.FieldID) @@ -377,7 +388,7 @@ func (cq *ConnectorQuery) querySpec() *sqlgraph.QuerySpec { if ps := cq.order; len(ps) > 0 { _spec.Order = func(selector *sql.Selector) { for i := range ps { - ps[i](selector, connector.ValidColumn) + ps[i](selector) } } } @@ -396,7 +407,7 @@ func (cq *ConnectorQuery) sqlQuery(ctx context.Context) *sql.Selector { p(selector) } for _, p := range cq.order { - p(selector, connector.ValidColumn) + p(selector) } if offset := cq.offset; offset != nil { // limit is mandatory for offset clause. We start @@ -662,7 +673,7 @@ func (cgb *ConnectorGroupBy) sqlQuery() *sql.Selector { columns := make([]string, 0, len(cgb.fields)+len(cgb.fns)) columns = append(columns, cgb.fields...) for _, fn := range cgb.fns { - columns = append(columns, fn(selector, connector.ValidColumn)) + columns = append(columns, fn(selector)) } return selector.Select(columns...).GroupBy(cgb.fields...) } diff --git a/storage/ent/db/connector_update.go b/storage/ent/db/connector_update.go index 5e53ea0b..90c972e4 100644 --- a/storage/ent/db/connector_update.go +++ b/storage/ent/db/connector_update.go @@ -187,6 +187,7 @@ func (cu *ConnectorUpdate) sqlSave(ctx context.Context) (n int, err error) { // ConnectorUpdateOne is the builder for updating a single Connector entity. type ConnectorUpdateOne struct { config + fields []string hooks []Hook mutation *ConnectorMutation } @@ -220,6 +221,13 @@ func (cuo *ConnectorUpdateOne) Mutation() *ConnectorMutation { return cuo.mutation } +// Select allows selecting one or more fields (columns) of the returned entity. +// The default is selecting all fields defined in the entity schema. +func (cuo *ConnectorUpdateOne) Select(field string, fields ...string) *ConnectorUpdateOne { + cuo.fields = append([]string{field}, fields...) + return cuo +} + // Save executes the query and returns the updated Connector entity. func (cuo *ConnectorUpdateOne) Save(ctx context.Context) (*Connector, error) { var ( @@ -308,6 +316,18 @@ func (cuo *ConnectorUpdateOne) sqlSave(ctx context.Context) (_node *Connector, e return nil, &ValidationError{Name: "ID", err: fmt.Errorf("missing Connector.ID for update")} } _spec.Node.ID.Value = id + if fields := cuo.fields; len(fields) > 0 { + _spec.Node.Columns = make([]string, 0, len(fields)) + _spec.Node.Columns = append(_spec.Node.Columns, connector.FieldID) + for _, f := range fields { + if !connector.ValidColumn(f) { + return nil, &ValidationError{Name: f, err: fmt.Errorf("db: invalid field %q for query", f)} + } + if f != connector.FieldID { + _spec.Node.Columns = append(_spec.Node.Columns, f) + } + } + } if ps := cuo.mutation.predicates; len(ps) > 0 { _spec.Predicate = func(selector *sql.Selector) { for i := range ps { diff --git a/storage/ent/db/devicerequest.go b/storage/ent/db/devicerequest.go index d66435d5..d50a7c83 100644 --- a/storage/ent/db/devicerequest.go +++ b/storage/ent/db/devicerequest.go @@ -37,13 +37,13 @@ func (*DeviceRequest) scanValues(columns []string) ([]interface{}, error) { for i := range columns { switch columns[i] { case devicerequest.FieldScopes: - values[i] = &[]byte{} + values[i] = new([]byte) case devicerequest.FieldID: - values[i] = &sql.NullInt64{} + values[i] = new(sql.NullInt64) case devicerequest.FieldUserCode, devicerequest.FieldDeviceCode, devicerequest.FieldClientID, devicerequest.FieldClientSecret: - values[i] = &sql.NullString{} + values[i] = new(sql.NullString) case devicerequest.FieldExpiry: - values[i] = &sql.NullTime{} + values[i] = new(sql.NullTime) default: return nil, fmt.Errorf("unexpected column %q for type DeviceRequest", columns[i]) } diff --git a/storage/ent/db/devicerequest_query.go b/storage/ent/db/devicerequest_query.go index 520812bd..08c76871 100644 --- a/storage/ent/db/devicerequest_query.go +++ b/storage/ent/db/devicerequest_query.go @@ -20,6 +20,7 @@ type DeviceRequestQuery struct { config limit *int offset *int + unique *bool order []OrderFunc fields []string predicates []predicate.DeviceRequest @@ -46,6 +47,13 @@ func (drq *DeviceRequestQuery) Offset(offset int) *DeviceRequestQuery { return drq } +// Unique configures the query builder to filter duplicate records on query. +// By default, unique is set to true, and can be disabled using this method. +func (drq *DeviceRequestQuery) Unique(unique bool) *DeviceRequestQuery { + drq.unique = &unique + return drq +} + // Order adds an order step to the query. func (drq *DeviceRequestQuery) Order(o ...OrderFunc) *DeviceRequestQuery { drq.order = append(drq.order, o...) @@ -352,6 +360,9 @@ func (drq *DeviceRequestQuery) querySpec() *sqlgraph.QuerySpec { From: drq.sql, Unique: true, } + if unique := drq.unique; unique != nil { + _spec.Unique = *unique + } if fields := drq.fields; len(fields) > 0 { _spec.Node.Columns = make([]string, 0, len(fields)) _spec.Node.Columns = append(_spec.Node.Columns, devicerequest.FieldID) @@ -377,7 +388,7 @@ func (drq *DeviceRequestQuery) querySpec() *sqlgraph.QuerySpec { if ps := drq.order; len(ps) > 0 { _spec.Order = func(selector *sql.Selector) { for i := range ps { - ps[i](selector, devicerequest.ValidColumn) + ps[i](selector) } } } @@ -396,7 +407,7 @@ func (drq *DeviceRequestQuery) sqlQuery(ctx context.Context) *sql.Selector { p(selector) } for _, p := range drq.order { - p(selector, devicerequest.ValidColumn) + p(selector) } if offset := drq.offset; offset != nil { // limit is mandatory for offset clause. We start @@ -662,7 +673,7 @@ func (drgb *DeviceRequestGroupBy) sqlQuery() *sql.Selector { columns := make([]string, 0, len(drgb.fields)+len(drgb.fns)) columns = append(columns, drgb.fields...) for _, fn := range drgb.fns { - columns = append(columns, fn(selector, devicerequest.ValidColumn)) + columns = append(columns, fn(selector)) } return selector.Select(columns...).GroupBy(drgb.fields...) } diff --git a/storage/ent/db/devicerequest_update.go b/storage/ent/db/devicerequest_update.go index aba544e1..d71ca0ed 100644 --- a/storage/ent/db/devicerequest_update.go +++ b/storage/ent/db/devicerequest_update.go @@ -236,6 +236,7 @@ func (dru *DeviceRequestUpdate) sqlSave(ctx context.Context) (n int, err error) // DeviceRequestUpdateOne is the builder for updating a single DeviceRequest entity. type DeviceRequestUpdateOne struct { config + fields []string hooks []Hook mutation *DeviceRequestMutation } @@ -287,6 +288,13 @@ func (druo *DeviceRequestUpdateOne) Mutation() *DeviceRequestMutation { return druo.mutation } +// Select allows selecting one or more fields (columns) of the returned entity. +// The default is selecting all fields defined in the entity schema. +func (druo *DeviceRequestUpdateOne) Select(field string, fields ...string) *DeviceRequestUpdateOne { + druo.fields = append([]string{field}, fields...) + return druo +} + // Save executes the query and returns the updated DeviceRequest entity. func (druo *DeviceRequestUpdateOne) Save(ctx context.Context) (*DeviceRequest, error) { var ( @@ -385,6 +393,18 @@ func (druo *DeviceRequestUpdateOne) sqlSave(ctx context.Context) (_node *DeviceR return nil, &ValidationError{Name: "ID", err: fmt.Errorf("missing DeviceRequest.ID for update")} } _spec.Node.ID.Value = id + if fields := druo.fields; len(fields) > 0 { + _spec.Node.Columns = make([]string, 0, len(fields)) + _spec.Node.Columns = append(_spec.Node.Columns, devicerequest.FieldID) + for _, f := range fields { + if !devicerequest.ValidColumn(f) { + return nil, &ValidationError{Name: f, err: fmt.Errorf("db: invalid field %q for query", f)} + } + if f != devicerequest.FieldID { + _spec.Node.Columns = append(_spec.Node.Columns, f) + } + } + } if ps := druo.mutation.predicates; len(ps) > 0 { _spec.Predicate = func(selector *sql.Selector) { for i := range ps { diff --git a/storage/ent/db/devicetoken.go b/storage/ent/db/devicetoken.go index 6a88b5d5..1731d1f0 100644 --- a/storage/ent/db/devicetoken.go +++ b/storage/ent/db/devicetoken.go @@ -36,13 +36,13 @@ func (*DeviceToken) scanValues(columns []string) ([]interface{}, error) { for i := range columns { switch columns[i] { case devicetoken.FieldToken: - values[i] = &[]byte{} + values[i] = new([]byte) case devicetoken.FieldID, devicetoken.FieldPollInterval: - values[i] = &sql.NullInt64{} + values[i] = new(sql.NullInt64) case devicetoken.FieldDeviceCode, devicetoken.FieldStatus: - values[i] = &sql.NullString{} + values[i] = new(sql.NullString) case devicetoken.FieldExpiry, devicetoken.FieldLastRequest: - values[i] = &sql.NullTime{} + values[i] = new(sql.NullTime) default: return nil, fmt.Errorf("unexpected column %q for type DeviceToken", columns[i]) } diff --git a/storage/ent/db/devicetoken_query.go b/storage/ent/db/devicetoken_query.go index df399481..e085440d 100644 --- a/storage/ent/db/devicetoken_query.go +++ b/storage/ent/db/devicetoken_query.go @@ -20,6 +20,7 @@ type DeviceTokenQuery struct { config limit *int offset *int + unique *bool order []OrderFunc fields []string predicates []predicate.DeviceToken @@ -46,6 +47,13 @@ func (dtq *DeviceTokenQuery) Offset(offset int) *DeviceTokenQuery { return dtq } +// Unique configures the query builder to filter duplicate records on query. +// By default, unique is set to true, and can be disabled using this method. +func (dtq *DeviceTokenQuery) Unique(unique bool) *DeviceTokenQuery { + dtq.unique = &unique + return dtq +} + // Order adds an order step to the query. func (dtq *DeviceTokenQuery) Order(o ...OrderFunc) *DeviceTokenQuery { dtq.order = append(dtq.order, o...) @@ -352,6 +360,9 @@ func (dtq *DeviceTokenQuery) querySpec() *sqlgraph.QuerySpec { From: dtq.sql, Unique: true, } + if unique := dtq.unique; unique != nil { + _spec.Unique = *unique + } if fields := dtq.fields; len(fields) > 0 { _spec.Node.Columns = make([]string, 0, len(fields)) _spec.Node.Columns = append(_spec.Node.Columns, devicetoken.FieldID) @@ -377,7 +388,7 @@ func (dtq *DeviceTokenQuery) querySpec() *sqlgraph.QuerySpec { if ps := dtq.order; len(ps) > 0 { _spec.Order = func(selector *sql.Selector) { for i := range ps { - ps[i](selector, devicetoken.ValidColumn) + ps[i](selector) } } } @@ -396,7 +407,7 @@ func (dtq *DeviceTokenQuery) sqlQuery(ctx context.Context) *sql.Selector { p(selector) } for _, p := range dtq.order { - p(selector, devicetoken.ValidColumn) + p(selector) } if offset := dtq.offset; offset != nil { // limit is mandatory for offset clause. We start @@ -662,7 +673,7 @@ func (dtgb *DeviceTokenGroupBy) sqlQuery() *sql.Selector { columns := make([]string, 0, len(dtgb.fields)+len(dtgb.fns)) columns = append(columns, dtgb.fields...) for _, fn := range dtgb.fns { - columns = append(columns, fn(selector, devicetoken.ValidColumn)) + columns = append(columns, fn(selector)) } return selector.Select(columns...).GroupBy(dtgb.fields...) } diff --git a/storage/ent/db/devicetoken_update.go b/storage/ent/db/devicetoken_update.go index 123b6227..51a4efe0 100644 --- a/storage/ent/db/devicetoken_update.go +++ b/storage/ent/db/devicetoken_update.go @@ -240,6 +240,7 @@ func (dtu *DeviceTokenUpdate) sqlSave(ctx context.Context) (n int, err error) { // DeviceTokenUpdateOne is the builder for updating a single DeviceToken entity. type DeviceTokenUpdateOne struct { config + fields []string hooks []Hook mutation *DeviceTokenMutation } @@ -298,6 +299,13 @@ func (dtuo *DeviceTokenUpdateOne) Mutation() *DeviceTokenMutation { return dtuo.mutation } +// Select allows selecting one or more fields (columns) of the returned entity. +// The default is selecting all fields defined in the entity schema. +func (dtuo *DeviceTokenUpdateOne) Select(field string, fields ...string) *DeviceTokenUpdateOne { + dtuo.fields = append([]string{field}, fields...) + return dtuo +} + // Save executes the query and returns the updated DeviceToken entity. func (dtuo *DeviceTokenUpdateOne) Save(ctx context.Context) (*DeviceToken, error) { var ( @@ -386,6 +394,18 @@ func (dtuo *DeviceTokenUpdateOne) sqlSave(ctx context.Context) (_node *DeviceTok return nil, &ValidationError{Name: "ID", err: fmt.Errorf("missing DeviceToken.ID for update")} } _spec.Node.ID.Value = id + if fields := dtuo.fields; len(fields) > 0 { + _spec.Node.Columns = make([]string, 0, len(fields)) + _spec.Node.Columns = append(_spec.Node.Columns, devicetoken.FieldID) + for _, f := range fields { + if !devicetoken.ValidColumn(f) { + return nil, &ValidationError{Name: f, err: fmt.Errorf("db: invalid field %q for query", f)} + } + if f != devicetoken.FieldID { + _spec.Node.Columns = append(_spec.Node.Columns, f) + } + } + } if ps := dtuo.mutation.predicates; len(ps) > 0 { _spec.Predicate = func(selector *sql.Selector) { for i := range ps { diff --git a/storage/ent/db/ent.go b/storage/ent/db/ent.go index b66677b1..d84e721d 100644 --- a/storage/ent/db/ent.go +++ b/storage/ent/db/ent.go @@ -10,6 +10,16 @@ import ( "entgo.io/ent/dialect" "entgo.io/ent/dialect/sql" "entgo.io/ent/dialect/sql/sqlgraph" + "github.com/dexidp/dex/storage/ent/db/authcode" + "github.com/dexidp/dex/storage/ent/db/authrequest" + "github.com/dexidp/dex/storage/ent/db/connector" + "github.com/dexidp/dex/storage/ent/db/devicerequest" + "github.com/dexidp/dex/storage/ent/db/devicetoken" + "github.com/dexidp/dex/storage/ent/db/keys" + "github.com/dexidp/dex/storage/ent/db/oauth2client" + "github.com/dexidp/dex/storage/ent/db/offlinesession" + "github.com/dexidp/dex/storage/ent/db/password" + "github.com/dexidp/dex/storage/ent/db/refreshtoken" ) // ent aliases to avoid import conflicts in user's code. @@ -25,36 +35,64 @@ type ( ) // OrderFunc applies an ordering on the sql selector. -type OrderFunc func(*sql.Selector, func(string) bool) +type OrderFunc func(*sql.Selector) + +// columnChecker returns a function indicates if the column exists in the given column. +func columnChecker(table string) func(string) error { + checks := map[string]func(string) bool{ + authcode.Table: authcode.ValidColumn, + authrequest.Table: authrequest.ValidColumn, + connector.Table: connector.ValidColumn, + devicerequest.Table: devicerequest.ValidColumn, + devicetoken.Table: devicetoken.ValidColumn, + keys.Table: keys.ValidColumn, + oauth2client.Table: oauth2client.ValidColumn, + offlinesession.Table: offlinesession.ValidColumn, + password.Table: password.ValidColumn, + refreshtoken.Table: refreshtoken.ValidColumn, + } + check, ok := checks[table] + if !ok { + return func(string) error { + return fmt.Errorf("unknown table %q", table) + } + } + return func(column string) error { + if !check(column) { + return fmt.Errorf("unknown column %q for table %q", column, table) + } + return nil + } +} // Asc applies the given fields in ASC order. func Asc(fields ...string) OrderFunc { - return func(s *sql.Selector, check func(string) bool) { + return func(s *sql.Selector) { + check := columnChecker(s.TableName()) for _, f := range fields { - if check(f) { - s.OrderBy(sql.Asc(f)) - } else { - s.AddError(&ValidationError{Name: f, err: fmt.Errorf("invalid field %q for ordering", f)}) + if err := check(f); err != nil { + s.AddError(&ValidationError{Name: f, err: fmt.Errorf("db: %w", err)}) } + s.OrderBy(sql.Asc(s.C(f))) } } } // Desc applies the given fields in DESC order. func Desc(fields ...string) OrderFunc { - return func(s *sql.Selector, check func(string) bool) { + return func(s *sql.Selector) { + check := columnChecker(s.TableName()) for _, f := range fields { - if check(f) { - s.OrderBy(sql.Desc(f)) - } else { - s.AddError(&ValidationError{Name: f, err: fmt.Errorf("invalid field %q for ordering", f)}) + if err := check(f); err != nil { + s.AddError(&ValidationError{Name: f, err: fmt.Errorf("db: %w", err)}) } + s.OrderBy(sql.Desc(s.C(f))) } } } // AggregateFunc applies an aggregation step on the group-by traversal/selector. -type AggregateFunc func(*sql.Selector, func(string) bool) string +type AggregateFunc func(*sql.Selector) string // As is a pseudo aggregation function for renaming another other functions with custom names. For example: // @@ -63,23 +101,24 @@ type AggregateFunc func(*sql.Selector, func(string) bool) string // Scan(ctx, &v) // func As(fn AggregateFunc, end string) AggregateFunc { - return func(s *sql.Selector, check func(string) bool) string { - return sql.As(fn(s, check), end) + return func(s *sql.Selector) string { + return sql.As(fn(s), end) } } // Count applies the "count" aggregation function on each group. func Count() AggregateFunc { - return func(s *sql.Selector, _ func(string) bool) string { + return func(s *sql.Selector) string { return sql.Count("*") } } // Max applies the "max" aggregation function on the given field of each group. func Max(field string) AggregateFunc { - return func(s *sql.Selector, check func(string) bool) string { - if !check(field) { - s.AddError(&ValidationError{Name: field, err: fmt.Errorf("invalid field %q for grouping", field)}) + return func(s *sql.Selector) string { + check := columnChecker(s.TableName()) + if err := check(field); err != nil { + s.AddError(&ValidationError{Name: field, err: fmt.Errorf("db: %w", err)}) return "" } return sql.Max(s.C(field)) @@ -88,9 +127,10 @@ func Max(field string) AggregateFunc { // Mean applies the "mean" aggregation function on the given field of each group. func Mean(field string) AggregateFunc { - return func(s *sql.Selector, check func(string) bool) string { - if !check(field) { - s.AddError(&ValidationError{Name: field, err: fmt.Errorf("invalid field %q for grouping", field)}) + return func(s *sql.Selector) string { + check := columnChecker(s.TableName()) + if err := check(field); err != nil { + s.AddError(&ValidationError{Name: field, err: fmt.Errorf("db: %w", err)}) return "" } return sql.Avg(s.C(field)) @@ -99,9 +139,10 @@ func Mean(field string) AggregateFunc { // Min applies the "min" aggregation function on the given field of each group. func Min(field string) AggregateFunc { - return func(s *sql.Selector, check func(string) bool) string { - if !check(field) { - s.AddError(&ValidationError{Name: field, err: fmt.Errorf("invalid field %q for grouping", field)}) + return func(s *sql.Selector) string { + check := columnChecker(s.TableName()) + if err := check(field); err != nil { + s.AddError(&ValidationError{Name: field, err: fmt.Errorf("db: %w", err)}) return "" } return sql.Min(s.C(field)) @@ -110,9 +151,10 @@ func Min(field string) AggregateFunc { // Sum applies the "sum" aggregation function on the given field of each group. func Sum(field string) AggregateFunc { - return func(s *sql.Selector, check func(string) bool) string { - if !check(field) { - s.AddError(&ValidationError{Name: field, err: fmt.Errorf("invalid field %q for grouping", field)}) + return func(s *sql.Selector) string { + check := columnChecker(s.TableName()) + if err := check(field); err != nil { + s.AddError(&ValidationError{Name: field, err: fmt.Errorf("db: %w", err)}) return "" } return sql.Sum(s.C(field)) diff --git a/storage/ent/db/keys.go b/storage/ent/db/keys.go index 790e4a9a..d0312c92 100644 --- a/storage/ent/db/keys.go +++ b/storage/ent/db/keys.go @@ -35,11 +35,11 @@ func (*Keys) scanValues(columns []string) ([]interface{}, error) { for i := range columns { switch columns[i] { case keys.FieldVerificationKeys, keys.FieldSigningKey, keys.FieldSigningKeyPub: - values[i] = &[]byte{} + values[i] = new([]byte) case keys.FieldID: - values[i] = &sql.NullString{} + values[i] = new(sql.NullString) case keys.FieldNextRotation: - values[i] = &sql.NullTime{} + values[i] = new(sql.NullTime) default: return nil, fmt.Errorf("unexpected column %q for type Keys", columns[i]) } diff --git a/storage/ent/db/keys_query.go b/storage/ent/db/keys_query.go index 34f30d04..6d6b00f9 100644 --- a/storage/ent/db/keys_query.go +++ b/storage/ent/db/keys_query.go @@ -20,6 +20,7 @@ type KeysQuery struct { config limit *int offset *int + unique *bool order []OrderFunc fields []string predicates []predicate.Keys @@ -46,6 +47,13 @@ func (kq *KeysQuery) Offset(offset int) *KeysQuery { return kq } +// Unique configures the query builder to filter duplicate records on query. +// By default, unique is set to true, and can be disabled using this method. +func (kq *KeysQuery) Unique(unique bool) *KeysQuery { + kq.unique = &unique + return kq +} + // Order adds an order step to the query. func (kq *KeysQuery) Order(o ...OrderFunc) *KeysQuery { kq.order = append(kq.order, o...) @@ -352,6 +360,9 @@ func (kq *KeysQuery) querySpec() *sqlgraph.QuerySpec { From: kq.sql, Unique: true, } + if unique := kq.unique; unique != nil { + _spec.Unique = *unique + } if fields := kq.fields; len(fields) > 0 { _spec.Node.Columns = make([]string, 0, len(fields)) _spec.Node.Columns = append(_spec.Node.Columns, keys.FieldID) @@ -377,7 +388,7 @@ func (kq *KeysQuery) querySpec() *sqlgraph.QuerySpec { if ps := kq.order; len(ps) > 0 { _spec.Order = func(selector *sql.Selector) { for i := range ps { - ps[i](selector, keys.ValidColumn) + ps[i](selector) } } } @@ -396,7 +407,7 @@ func (kq *KeysQuery) sqlQuery(ctx context.Context) *sql.Selector { p(selector) } for _, p := range kq.order { - p(selector, keys.ValidColumn) + p(selector) } if offset := kq.offset; offset != nil { // limit is mandatory for offset clause. We start @@ -662,7 +673,7 @@ func (kgb *KeysGroupBy) sqlQuery() *sql.Selector { columns := make([]string, 0, len(kgb.fields)+len(kgb.fns)) columns = append(columns, kgb.fields...) for _, fn := range kgb.fns { - columns = append(columns, fn(selector, keys.ValidColumn)) + columns = append(columns, fn(selector)) } return selector.Select(columns...).GroupBy(kgb.fields...) } diff --git a/storage/ent/db/keys_update.go b/storage/ent/db/keys_update.go index e82c631c..8bc0ed3e 100644 --- a/storage/ent/db/keys_update.go +++ b/storage/ent/db/keys_update.go @@ -169,6 +169,7 @@ func (ku *KeysUpdate) sqlSave(ctx context.Context) (n int, err error) { // KeysUpdateOne is the builder for updating a single Keys entity. type KeysUpdateOne struct { config + fields []string hooks []Hook mutation *KeysMutation } @@ -202,6 +203,13 @@ func (kuo *KeysUpdateOne) Mutation() *KeysMutation { return kuo.mutation } +// Select allows selecting one or more fields (columns) of the returned entity. +// The default is selecting all fields defined in the entity schema. +func (kuo *KeysUpdateOne) Select(field string, fields ...string) *KeysUpdateOne { + kuo.fields = append([]string{field}, fields...) + return kuo +} + // Save executes the query and returns the updated Keys entity. func (kuo *KeysUpdateOne) Save(ctx context.Context) (*Keys, error) { var ( @@ -269,6 +277,18 @@ func (kuo *KeysUpdateOne) sqlSave(ctx context.Context) (_node *Keys, err error) return nil, &ValidationError{Name: "ID", err: fmt.Errorf("missing Keys.ID for update")} } _spec.Node.ID.Value = id + if fields := kuo.fields; len(fields) > 0 { + _spec.Node.Columns = make([]string, 0, len(fields)) + _spec.Node.Columns = append(_spec.Node.Columns, keys.FieldID) + for _, f := range fields { + if !keys.ValidColumn(f) { + return nil, &ValidationError{Name: f, err: fmt.Errorf("db: invalid field %q for query", f)} + } + if f != keys.FieldID { + _spec.Node.Columns = append(_spec.Node.Columns, f) + } + } + } if ps := kuo.mutation.predicates; len(ps) > 0 { _spec.Predicate = func(selector *sql.Selector) { for i := range ps { diff --git a/storage/ent/db/migrate/schema.go b/storage/ent/db/migrate/schema.go index 87874137..d5b1f535 100644 --- a/storage/ent/db/migrate/schema.go +++ b/storage/ent/db/migrate/schema.go @@ -190,6 +190,7 @@ var ( {Name: "connector_id", Type: field.TypeString, Size: 2147483647, SchemaType: map[string]string{"sqlite3": "text"}}, {Name: "connector_data", Type: field.TypeBytes, Nullable: true}, {Name: "token", Type: field.TypeString, Size: 2147483647, Default: "", SchemaType: map[string]string{"sqlite3": "text"}}, + {Name: "obsolete_token", Type: field.TypeString, Size: 2147483647, Default: "", SchemaType: map[string]string{"sqlite3": "text"}}, {Name: "created_at", Type: field.TypeTime}, {Name: "last_used", Type: field.TypeTime}, } diff --git a/storage/ent/db/mutation.go b/storage/ent/db/mutation.go index ac72af61..7ccab3f2 100644 --- a/storage/ent/db/mutation.go +++ b/storage/ent/db/mutation.go @@ -6151,6 +6151,7 @@ type RefreshTokenMutation struct { connector_id *string connector_data *[]byte token *string + obsolete_token *string created_at *time.Time last_used *time.Time clearedFields map[string]struct{} @@ -6715,6 +6716,42 @@ func (m *RefreshTokenMutation) ResetToken() { m.token = nil } +// SetObsoleteToken sets the "obsolete_token" field. +func (m *RefreshTokenMutation) SetObsoleteToken(s string) { + m.obsolete_token = &s +} + +// ObsoleteToken returns the value of the "obsolete_token" field in the mutation. +func (m *RefreshTokenMutation) ObsoleteToken() (r string, exists bool) { + v := m.obsolete_token + if v == nil { + return + } + return *v, true +} + +// OldObsoleteToken returns the old "obsolete_token" field's value of the RefreshToken entity. +// If the RefreshToken object wasn't provided to the builder, the object is fetched from the database. +// An error is returned if the mutation operation is not UpdateOne, or the database query fails. +func (m *RefreshTokenMutation) OldObsoleteToken(ctx context.Context) (v string, err error) { + if !m.op.Is(OpUpdateOne) { + return v, fmt.Errorf("OldObsoleteToken is only allowed on UpdateOne operations") + } + if m.id == nil || m.oldValue == nil { + return v, fmt.Errorf("OldObsoleteToken requires an ID field in the mutation") + } + oldValue, err := m.oldValue(ctx) + if err != nil { + return v, fmt.Errorf("querying old value for OldObsoleteToken: %w", err) + } + return oldValue.ObsoleteToken, nil +} + +// ResetObsoleteToken resets all changes to the "obsolete_token" field. +func (m *RefreshTokenMutation) ResetObsoleteToken() { + m.obsolete_token = nil +} + // SetCreatedAt sets the "created_at" field. func (m *RefreshTokenMutation) SetCreatedAt(t time.Time) { m.created_at = &t @@ -6801,7 +6838,7 @@ func (m *RefreshTokenMutation) Type() string { // order to get all numeric fields that were incremented/decremented, call // AddedFields(). func (m *RefreshTokenMutation) Fields() []string { - fields := make([]string, 0, 14) + fields := make([]string, 0, 15) if m.client_id != nil { fields = append(fields, refreshtoken.FieldClientID) } @@ -6838,6 +6875,9 @@ func (m *RefreshTokenMutation) Fields() []string { if m.token != nil { fields = append(fields, refreshtoken.FieldToken) } + if m.obsolete_token != nil { + fields = append(fields, refreshtoken.FieldObsoleteToken) + } if m.created_at != nil { fields = append(fields, refreshtoken.FieldCreatedAt) } @@ -6876,6 +6916,8 @@ func (m *RefreshTokenMutation) Field(name string) (ent.Value, bool) { return m.ConnectorData() case refreshtoken.FieldToken: return m.Token() + case refreshtoken.FieldObsoleteToken: + return m.ObsoleteToken() case refreshtoken.FieldCreatedAt: return m.CreatedAt() case refreshtoken.FieldLastUsed: @@ -6913,6 +6955,8 @@ func (m *RefreshTokenMutation) OldField(ctx context.Context, name string) (ent.V return m.OldConnectorData(ctx) case refreshtoken.FieldToken: return m.OldToken(ctx) + case refreshtoken.FieldObsoleteToken: + return m.OldObsoleteToken(ctx) case refreshtoken.FieldCreatedAt: return m.OldCreatedAt(ctx) case refreshtoken.FieldLastUsed: @@ -7010,6 +7054,13 @@ func (m *RefreshTokenMutation) SetField(name string, value ent.Value) error { } m.SetToken(v) return nil + case refreshtoken.FieldObsoleteToken: + v, ok := value.(string) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.SetObsoleteToken(v) + return nil case refreshtoken.FieldCreatedAt: v, ok := value.(time.Time) if !ok { @@ -7130,6 +7181,9 @@ func (m *RefreshTokenMutation) ResetField(name string) error { case refreshtoken.FieldToken: m.ResetToken() return nil + case refreshtoken.FieldObsoleteToken: + m.ResetObsoleteToken() + return nil case refreshtoken.FieldCreatedAt: m.ResetCreatedAt() return nil diff --git a/storage/ent/db/oauth2client.go b/storage/ent/db/oauth2client.go index c2a79713..687a6e69 100644 --- a/storage/ent/db/oauth2client.go +++ b/storage/ent/db/oauth2client.go @@ -36,11 +36,11 @@ func (*OAuth2Client) scanValues(columns []string) ([]interface{}, error) { for i := range columns { switch columns[i] { case oauth2client.FieldRedirectUris, oauth2client.FieldTrustedPeers: - values[i] = &[]byte{} + values[i] = new([]byte) case oauth2client.FieldPublic: - values[i] = &sql.NullBool{} + values[i] = new(sql.NullBool) case oauth2client.FieldID, oauth2client.FieldSecret, oauth2client.FieldName, oauth2client.FieldLogoURL: - values[i] = &sql.NullString{} + values[i] = new(sql.NullString) default: return nil, fmt.Errorf("unexpected column %q for type OAuth2Client", columns[i]) } diff --git a/storage/ent/db/oauth2client_query.go b/storage/ent/db/oauth2client_query.go index ea9a3f33..558542f1 100644 --- a/storage/ent/db/oauth2client_query.go +++ b/storage/ent/db/oauth2client_query.go @@ -20,6 +20,7 @@ type OAuth2ClientQuery struct { config limit *int offset *int + unique *bool order []OrderFunc fields []string predicates []predicate.OAuth2Client @@ -46,6 +47,13 @@ func (oq *OAuth2ClientQuery) Offset(offset int) *OAuth2ClientQuery { return oq } +// Unique configures the query builder to filter duplicate records on query. +// By default, unique is set to true, and can be disabled using this method. +func (oq *OAuth2ClientQuery) Unique(unique bool) *OAuth2ClientQuery { + oq.unique = &unique + return oq +} + // Order adds an order step to the query. func (oq *OAuth2ClientQuery) Order(o ...OrderFunc) *OAuth2ClientQuery { oq.order = append(oq.order, o...) @@ -352,6 +360,9 @@ func (oq *OAuth2ClientQuery) querySpec() *sqlgraph.QuerySpec { From: oq.sql, Unique: true, } + if unique := oq.unique; unique != nil { + _spec.Unique = *unique + } if fields := oq.fields; len(fields) > 0 { _spec.Node.Columns = make([]string, 0, len(fields)) _spec.Node.Columns = append(_spec.Node.Columns, oauth2client.FieldID) @@ -377,7 +388,7 @@ func (oq *OAuth2ClientQuery) querySpec() *sqlgraph.QuerySpec { if ps := oq.order; len(ps) > 0 { _spec.Order = func(selector *sql.Selector) { for i := range ps { - ps[i](selector, oauth2client.ValidColumn) + ps[i](selector) } } } @@ -396,7 +407,7 @@ func (oq *OAuth2ClientQuery) sqlQuery(ctx context.Context) *sql.Selector { p(selector) } for _, p := range oq.order { - p(selector, oauth2client.ValidColumn) + p(selector) } if offset := oq.offset; offset != nil { // limit is mandatory for offset clause. We start @@ -662,7 +673,7 @@ func (ogb *OAuth2ClientGroupBy) sqlQuery() *sql.Selector { columns := make([]string, 0, len(ogb.fields)+len(ogb.fns)) columns = append(columns, ogb.fields...) for _, fn := range ogb.fns { - columns = append(columns, fn(selector, oauth2client.ValidColumn)) + columns = append(columns, fn(selector)) } return selector.Select(columns...).GroupBy(ogb.fields...) } diff --git a/storage/ent/db/oauth2client_update.go b/storage/ent/db/oauth2client_update.go index f1bfd681..32982418 100644 --- a/storage/ent/db/oauth2client_update.go +++ b/storage/ent/db/oauth2client_update.go @@ -242,6 +242,7 @@ func (ou *OAuth2ClientUpdate) sqlSave(ctx context.Context) (n int, err error) { // OAuth2ClientUpdateOne is the builder for updating a single OAuth2Client entity. type OAuth2ClientUpdateOne struct { config + fields []string hooks []Hook mutation *OAuth2ClientMutation } @@ -299,6 +300,13 @@ func (ouo *OAuth2ClientUpdateOne) Mutation() *OAuth2ClientMutation { return ouo.mutation } +// Select allows selecting one or more fields (columns) of the returned entity. +// The default is selecting all fields defined in the entity schema. +func (ouo *OAuth2ClientUpdateOne) Select(field string, fields ...string) *OAuth2ClientUpdateOne { + ouo.fields = append([]string{field}, fields...) + return ouo +} + // Save executes the query and returns the updated OAuth2Client entity. func (ouo *OAuth2ClientUpdateOne) Save(ctx context.Context) (*OAuth2Client, error) { var ( @@ -392,6 +400,18 @@ func (ouo *OAuth2ClientUpdateOne) sqlSave(ctx context.Context) (_node *OAuth2Cli return nil, &ValidationError{Name: "ID", err: fmt.Errorf("missing OAuth2Client.ID for update")} } _spec.Node.ID.Value = id + if fields := ouo.fields; len(fields) > 0 { + _spec.Node.Columns = make([]string, 0, len(fields)) + _spec.Node.Columns = append(_spec.Node.Columns, oauth2client.FieldID) + for _, f := range fields { + if !oauth2client.ValidColumn(f) { + return nil, &ValidationError{Name: f, err: fmt.Errorf("db: invalid field %q for query", f)} + } + if f != oauth2client.FieldID { + _spec.Node.Columns = append(_spec.Node.Columns, f) + } + } + } if ps := ouo.mutation.predicates; len(ps) > 0 { _spec.Predicate = func(selector *sql.Selector) { for i := range ps { diff --git a/storage/ent/db/offlinesession.go b/storage/ent/db/offlinesession.go index 8eba0202..6bfa5d2a 100644 --- a/storage/ent/db/offlinesession.go +++ b/storage/ent/db/offlinesession.go @@ -31,9 +31,9 @@ func (*OfflineSession) scanValues(columns []string) ([]interface{}, error) { for i := range columns { switch columns[i] { case offlinesession.FieldRefresh, offlinesession.FieldConnectorData: - values[i] = &[]byte{} + values[i] = new([]byte) case offlinesession.FieldID, offlinesession.FieldUserID, offlinesession.FieldConnID: - values[i] = &sql.NullString{} + values[i] = new(sql.NullString) default: return nil, fmt.Errorf("unexpected column %q for type OfflineSession", columns[i]) } diff --git a/storage/ent/db/offlinesession_query.go b/storage/ent/db/offlinesession_query.go index a790188c..a4fbe1fd 100644 --- a/storage/ent/db/offlinesession_query.go +++ b/storage/ent/db/offlinesession_query.go @@ -20,6 +20,7 @@ type OfflineSessionQuery struct { config limit *int offset *int + unique *bool order []OrderFunc fields []string predicates []predicate.OfflineSession @@ -46,6 +47,13 @@ func (osq *OfflineSessionQuery) Offset(offset int) *OfflineSessionQuery { return osq } +// Unique configures the query builder to filter duplicate records on query. +// By default, unique is set to true, and can be disabled using this method. +func (osq *OfflineSessionQuery) Unique(unique bool) *OfflineSessionQuery { + osq.unique = &unique + return osq +} + // Order adds an order step to the query. func (osq *OfflineSessionQuery) Order(o ...OrderFunc) *OfflineSessionQuery { osq.order = append(osq.order, o...) @@ -352,6 +360,9 @@ func (osq *OfflineSessionQuery) querySpec() *sqlgraph.QuerySpec { From: osq.sql, Unique: true, } + if unique := osq.unique; unique != nil { + _spec.Unique = *unique + } if fields := osq.fields; len(fields) > 0 { _spec.Node.Columns = make([]string, 0, len(fields)) _spec.Node.Columns = append(_spec.Node.Columns, offlinesession.FieldID) @@ -377,7 +388,7 @@ func (osq *OfflineSessionQuery) querySpec() *sqlgraph.QuerySpec { if ps := osq.order; len(ps) > 0 { _spec.Order = func(selector *sql.Selector) { for i := range ps { - ps[i](selector, offlinesession.ValidColumn) + ps[i](selector) } } } @@ -396,7 +407,7 @@ func (osq *OfflineSessionQuery) sqlQuery(ctx context.Context) *sql.Selector { p(selector) } for _, p := range osq.order { - p(selector, offlinesession.ValidColumn) + p(selector) } if offset := osq.offset; offset != nil { // limit is mandatory for offset clause. We start @@ -662,7 +673,7 @@ func (osgb *OfflineSessionGroupBy) sqlQuery() *sql.Selector { columns := make([]string, 0, len(osgb.fields)+len(osgb.fns)) columns = append(columns, osgb.fields...) for _, fn := range osgb.fns { - columns = append(columns, fn(selector, offlinesession.ValidColumn)) + columns = append(columns, fn(selector)) } return selector.Select(columns...).GroupBy(osgb.fields...) } diff --git a/storage/ent/db/offlinesession_update.go b/storage/ent/db/offlinesession_update.go index 14141b47..d6edd522 100644 --- a/storage/ent/db/offlinesession_update.go +++ b/storage/ent/db/offlinesession_update.go @@ -199,6 +199,7 @@ func (osu *OfflineSessionUpdate) sqlSave(ctx context.Context) (n int, err error) // OfflineSessionUpdateOne is the builder for updating a single OfflineSession entity. type OfflineSessionUpdateOne struct { config + fields []string hooks []Hook mutation *OfflineSessionMutation } @@ -238,6 +239,13 @@ func (osuo *OfflineSessionUpdateOne) Mutation() *OfflineSessionMutation { return osuo.mutation } +// Select allows selecting one or more fields (columns) of the returned entity. +// The default is selecting all fields defined in the entity schema. +func (osuo *OfflineSessionUpdateOne) Select(field string, fields ...string) *OfflineSessionUpdateOne { + osuo.fields = append([]string{field}, fields...) + return osuo +} + // Save executes the query and returns the updated OfflineSession entity. func (osuo *OfflineSessionUpdateOne) Save(ctx context.Context) (*OfflineSession, error) { var ( @@ -326,6 +334,18 @@ func (osuo *OfflineSessionUpdateOne) sqlSave(ctx context.Context) (_node *Offlin return nil, &ValidationError{Name: "ID", err: fmt.Errorf("missing OfflineSession.ID for update")} } _spec.Node.ID.Value = id + if fields := osuo.fields; len(fields) > 0 { + _spec.Node.Columns = make([]string, 0, len(fields)) + _spec.Node.Columns = append(_spec.Node.Columns, offlinesession.FieldID) + for _, f := range fields { + if !offlinesession.ValidColumn(f) { + return nil, &ValidationError{Name: f, err: fmt.Errorf("db: invalid field %q for query", f)} + } + if f != offlinesession.FieldID { + _spec.Node.Columns = append(_spec.Node.Columns, f) + } + } + } if ps := osuo.mutation.predicates; len(ps) > 0 { _spec.Predicate = func(selector *sql.Selector) { for i := range ps { diff --git a/storage/ent/db/password.go b/storage/ent/db/password.go index 702f33e0..9a1043d6 100644 --- a/storage/ent/db/password.go +++ b/storage/ent/db/password.go @@ -31,11 +31,11 @@ func (*Password) scanValues(columns []string) ([]interface{}, error) { for i := range columns { switch columns[i] { case password.FieldHash: - values[i] = &[]byte{} + values[i] = new([]byte) case password.FieldID: - values[i] = &sql.NullInt64{} + values[i] = new(sql.NullInt64) case password.FieldEmail, password.FieldUsername, password.FieldUserID: - values[i] = &sql.NullString{} + values[i] = new(sql.NullString) default: return nil, fmt.Errorf("unexpected column %q for type Password", columns[i]) } diff --git a/storage/ent/db/password_query.go b/storage/ent/db/password_query.go index aafc820d..8bfe9a83 100644 --- a/storage/ent/db/password_query.go +++ b/storage/ent/db/password_query.go @@ -20,6 +20,7 @@ type PasswordQuery struct { config limit *int offset *int + unique *bool order []OrderFunc fields []string predicates []predicate.Password @@ -46,6 +47,13 @@ func (pq *PasswordQuery) Offset(offset int) *PasswordQuery { return pq } +// Unique configures the query builder to filter duplicate records on query. +// By default, unique is set to true, and can be disabled using this method. +func (pq *PasswordQuery) Unique(unique bool) *PasswordQuery { + pq.unique = &unique + return pq +} + // Order adds an order step to the query. func (pq *PasswordQuery) Order(o ...OrderFunc) *PasswordQuery { pq.order = append(pq.order, o...) @@ -352,6 +360,9 @@ func (pq *PasswordQuery) querySpec() *sqlgraph.QuerySpec { From: pq.sql, Unique: true, } + if unique := pq.unique; unique != nil { + _spec.Unique = *unique + } if fields := pq.fields; len(fields) > 0 { _spec.Node.Columns = make([]string, 0, len(fields)) _spec.Node.Columns = append(_spec.Node.Columns, password.FieldID) @@ -377,7 +388,7 @@ func (pq *PasswordQuery) querySpec() *sqlgraph.QuerySpec { if ps := pq.order; len(ps) > 0 { _spec.Order = func(selector *sql.Selector) { for i := range ps { - ps[i](selector, password.ValidColumn) + ps[i](selector) } } } @@ -396,7 +407,7 @@ func (pq *PasswordQuery) sqlQuery(ctx context.Context) *sql.Selector { p(selector) } for _, p := range pq.order { - p(selector, password.ValidColumn) + p(selector) } if offset := pq.offset; offset != nil { // limit is mandatory for offset clause. We start @@ -662,7 +673,7 @@ func (pgb *PasswordGroupBy) sqlQuery() *sql.Selector { columns := make([]string, 0, len(pgb.fields)+len(pgb.fns)) columns = append(columns, pgb.fields...) for _, fn := range pgb.fns { - columns = append(columns, fn(selector, password.ValidColumn)) + columns = append(columns, fn(selector)) } return selector.Select(columns...).GroupBy(pgb.fields...) } diff --git a/storage/ent/db/password_update.go b/storage/ent/db/password_update.go index f89e2334..0eb1cb61 100644 --- a/storage/ent/db/password_update.go +++ b/storage/ent/db/password_update.go @@ -192,6 +192,7 @@ func (pu *PasswordUpdate) sqlSave(ctx context.Context) (n int, err error) { // PasswordUpdateOne is the builder for updating a single Password entity. type PasswordUpdateOne struct { config + fields []string hooks []Hook mutation *PasswordMutation } @@ -225,6 +226,13 @@ func (puo *PasswordUpdateOne) Mutation() *PasswordMutation { return puo.mutation } +// Select allows selecting one or more fields (columns) of the returned entity. +// The default is selecting all fields defined in the entity schema. +func (puo *PasswordUpdateOne) Select(field string, fields ...string) *PasswordUpdateOne { + puo.fields = append([]string{field}, fields...) + return puo +} + // Save executes the query and returns the updated Password entity. func (puo *PasswordUpdateOne) Save(ctx context.Context) (*Password, error) { var ( @@ -318,6 +326,18 @@ func (puo *PasswordUpdateOne) sqlSave(ctx context.Context) (_node *Password, err return nil, &ValidationError{Name: "ID", err: fmt.Errorf("missing Password.ID for update")} } _spec.Node.ID.Value = id + if fields := puo.fields; len(fields) > 0 { + _spec.Node.Columns = make([]string, 0, len(fields)) + _spec.Node.Columns = append(_spec.Node.Columns, password.FieldID) + for _, f := range fields { + if !password.ValidColumn(f) { + return nil, &ValidationError{Name: f, err: fmt.Errorf("db: invalid field %q for query", f)} + } + if f != password.FieldID { + _spec.Node.Columns = append(_spec.Node.Columns, f) + } + } + } if ps := puo.mutation.predicates; len(ps) > 0 { _spec.Predicate = func(selector *sql.Selector) { for i := range ps { diff --git a/storage/ent/db/refreshtoken.go b/storage/ent/db/refreshtoken.go index 057be8d6..7e527079 100644 --- a/storage/ent/db/refreshtoken.go +++ b/storage/ent/db/refreshtoken.go @@ -41,6 +41,8 @@ type RefreshToken struct { ConnectorData *[]byte `json:"connector_data,omitempty"` // Token holds the value of the "token" field. Token string `json:"token,omitempty"` + // ObsoleteToken holds the value of the "obsolete_token" field. + ObsoleteToken string `json:"obsolete_token,omitempty"` // CreatedAt holds the value of the "created_at" field. CreatedAt time.Time `json:"created_at,omitempty"` // LastUsed holds the value of the "last_used" field. @@ -53,13 +55,13 @@ func (*RefreshToken) scanValues(columns []string) ([]interface{}, error) { for i := range columns { switch columns[i] { case refreshtoken.FieldScopes, refreshtoken.FieldClaimsGroups, refreshtoken.FieldConnectorData: - values[i] = &[]byte{} + values[i] = new([]byte) case refreshtoken.FieldClaimsEmailVerified: - values[i] = &sql.NullBool{} - case refreshtoken.FieldID, refreshtoken.FieldClientID, refreshtoken.FieldNonce, refreshtoken.FieldClaimsUserID, refreshtoken.FieldClaimsUsername, refreshtoken.FieldClaimsEmail, refreshtoken.FieldClaimsPreferredUsername, refreshtoken.FieldConnectorID, refreshtoken.FieldToken: - values[i] = &sql.NullString{} + values[i] = new(sql.NullBool) + case refreshtoken.FieldID, refreshtoken.FieldClientID, refreshtoken.FieldNonce, refreshtoken.FieldClaimsUserID, refreshtoken.FieldClaimsUsername, refreshtoken.FieldClaimsEmail, refreshtoken.FieldClaimsPreferredUsername, refreshtoken.FieldConnectorID, refreshtoken.FieldToken, refreshtoken.FieldObsoleteToken: + values[i] = new(sql.NullString) case refreshtoken.FieldCreatedAt, refreshtoken.FieldLastUsed: - values[i] = &sql.NullTime{} + values[i] = new(sql.NullTime) default: return nil, fmt.Errorf("unexpected column %q for type RefreshToken", columns[i]) } @@ -159,6 +161,12 @@ func (rt *RefreshToken) assignValues(columns []string, values []interface{}) err } else if value.Valid { rt.Token = value.String } + case refreshtoken.FieldObsoleteToken: + if value, ok := values[i].(*sql.NullString); !ok { + return fmt.Errorf("unexpected type %T for field obsolete_token", values[i]) + } else if value.Valid { + rt.ObsoleteToken = value.String + } case refreshtoken.FieldCreatedAt: if value, ok := values[i].(*sql.NullTime); !ok { return fmt.Errorf("unexpected type %T for field created_at", values[i]) @@ -225,6 +233,8 @@ func (rt *RefreshToken) String() string { } builder.WriteString(", token=") builder.WriteString(rt.Token) + builder.WriteString(", obsolete_token=") + builder.WriteString(rt.ObsoleteToken) builder.WriteString(", created_at=") builder.WriteString(rt.CreatedAt.Format(time.ANSIC)) builder.WriteString(", last_used=") diff --git a/storage/ent/db/refreshtoken/refreshtoken.go b/storage/ent/db/refreshtoken/refreshtoken.go index 0e28ef67..38efcc22 100644 --- a/storage/ent/db/refreshtoken/refreshtoken.go +++ b/storage/ent/db/refreshtoken/refreshtoken.go @@ -35,6 +35,8 @@ const ( FieldConnectorData = "connector_data" // FieldToken holds the string denoting the token field in the database. FieldToken = "token" + // FieldObsoleteToken holds the string denoting the obsolete_token field in the database. + FieldObsoleteToken = "obsolete_token" // FieldCreatedAt holds the string denoting the created_at field in the database. FieldCreatedAt = "created_at" // FieldLastUsed holds the string denoting the last_used field in the database. @@ -58,6 +60,7 @@ var Columns = []string{ FieldConnectorID, FieldConnectorData, FieldToken, + FieldObsoleteToken, FieldCreatedAt, FieldLastUsed, } @@ -89,6 +92,8 @@ var ( ConnectorIDValidator func(string) error // DefaultToken holds the default value on creation for the "token" field. DefaultToken string + // DefaultObsoleteToken holds the default value on creation for the "obsolete_token" field. + DefaultObsoleteToken string // DefaultCreatedAt holds the default value on creation for the "created_at" field. DefaultCreatedAt func() time.Time // DefaultLastUsed holds the default value on creation for the "last_used" field. diff --git a/storage/ent/db/refreshtoken/where.go b/storage/ent/db/refreshtoken/where.go index 0acde92f..43a46093 100644 --- a/storage/ent/db/refreshtoken/where.go +++ b/storage/ent/db/refreshtoken/where.go @@ -162,6 +162,13 @@ func Token(v string) predicate.RefreshToken { }) } +// ObsoleteToken applies equality check predicate on the "obsolete_token" field. It's identical to ObsoleteTokenEQ. +func ObsoleteToken(v string) predicate.RefreshToken { + return predicate.RefreshToken(func(s *sql.Selector) { + s.Where(sql.EQ(s.C(FieldObsoleteToken), v)) + }) +} + // CreatedAt applies equality check predicate on the "created_at" field. It's identical to CreatedAtEQ. func CreatedAt(v time.Time) predicate.RefreshToken { return predicate.RefreshToken(func(s *sql.Selector) { @@ -1196,6 +1203,117 @@ func TokenContainsFold(v string) predicate.RefreshToken { }) } +// ObsoleteTokenEQ applies the EQ predicate on the "obsolete_token" field. +func ObsoleteTokenEQ(v string) predicate.RefreshToken { + return predicate.RefreshToken(func(s *sql.Selector) { + s.Where(sql.EQ(s.C(FieldObsoleteToken), v)) + }) +} + +// ObsoleteTokenNEQ applies the NEQ predicate on the "obsolete_token" field. +func ObsoleteTokenNEQ(v string) predicate.RefreshToken { + return predicate.RefreshToken(func(s *sql.Selector) { + s.Where(sql.NEQ(s.C(FieldObsoleteToken), v)) + }) +} + +// ObsoleteTokenIn applies the In predicate on the "obsolete_token" field. +func ObsoleteTokenIn(vs ...string) predicate.RefreshToken { + v := make([]interface{}, len(vs)) + for i := range v { + v[i] = vs[i] + } + return predicate.RefreshToken(func(s *sql.Selector) { + // if not arguments were provided, append the FALSE constants, + // since we can't apply "IN ()". This will make this predicate falsy. + if len(v) == 0 { + s.Where(sql.False()) + return + } + s.Where(sql.In(s.C(FieldObsoleteToken), v...)) + }) +} + +// ObsoleteTokenNotIn applies the NotIn predicate on the "obsolete_token" field. +func ObsoleteTokenNotIn(vs ...string) predicate.RefreshToken { + v := make([]interface{}, len(vs)) + for i := range v { + v[i] = vs[i] + } + return predicate.RefreshToken(func(s *sql.Selector) { + // if not arguments were provided, append the FALSE constants, + // since we can't apply "IN ()". This will make this predicate falsy. + if len(v) == 0 { + s.Where(sql.False()) + return + } + s.Where(sql.NotIn(s.C(FieldObsoleteToken), v...)) + }) +} + +// ObsoleteTokenGT applies the GT predicate on the "obsolete_token" field. +func ObsoleteTokenGT(v string) predicate.RefreshToken { + return predicate.RefreshToken(func(s *sql.Selector) { + s.Where(sql.GT(s.C(FieldObsoleteToken), v)) + }) +} + +// ObsoleteTokenGTE applies the GTE predicate on the "obsolete_token" field. +func ObsoleteTokenGTE(v string) predicate.RefreshToken { + return predicate.RefreshToken(func(s *sql.Selector) { + s.Where(sql.GTE(s.C(FieldObsoleteToken), v)) + }) +} + +// ObsoleteTokenLT applies the LT predicate on the "obsolete_token" field. +func ObsoleteTokenLT(v string) predicate.RefreshToken { + return predicate.RefreshToken(func(s *sql.Selector) { + s.Where(sql.LT(s.C(FieldObsoleteToken), v)) + }) +} + +// ObsoleteTokenLTE applies the LTE predicate on the "obsolete_token" field. +func ObsoleteTokenLTE(v string) predicate.RefreshToken { + return predicate.RefreshToken(func(s *sql.Selector) { + s.Where(sql.LTE(s.C(FieldObsoleteToken), v)) + }) +} + +// ObsoleteTokenContains applies the Contains predicate on the "obsolete_token" field. +func ObsoleteTokenContains(v string) predicate.RefreshToken { + return predicate.RefreshToken(func(s *sql.Selector) { + s.Where(sql.Contains(s.C(FieldObsoleteToken), v)) + }) +} + +// ObsoleteTokenHasPrefix applies the HasPrefix predicate on the "obsolete_token" field. +func ObsoleteTokenHasPrefix(v string) predicate.RefreshToken { + return predicate.RefreshToken(func(s *sql.Selector) { + s.Where(sql.HasPrefix(s.C(FieldObsoleteToken), v)) + }) +} + +// ObsoleteTokenHasSuffix applies the HasSuffix predicate on the "obsolete_token" field. +func ObsoleteTokenHasSuffix(v string) predicate.RefreshToken { + return predicate.RefreshToken(func(s *sql.Selector) { + s.Where(sql.HasSuffix(s.C(FieldObsoleteToken), v)) + }) +} + +// ObsoleteTokenEqualFold applies the EqualFold predicate on the "obsolete_token" field. +func ObsoleteTokenEqualFold(v string) predicate.RefreshToken { + return predicate.RefreshToken(func(s *sql.Selector) { + s.Where(sql.EqualFold(s.C(FieldObsoleteToken), v)) + }) +} + +// ObsoleteTokenContainsFold applies the ContainsFold predicate on the "obsolete_token" field. +func ObsoleteTokenContainsFold(v string) predicate.RefreshToken { + return predicate.RefreshToken(func(s *sql.Selector) { + s.Where(sql.ContainsFold(s.C(FieldObsoleteToken), v)) + }) +} + // CreatedAtEQ applies the EQ predicate on the "created_at" field. func CreatedAtEQ(v time.Time) predicate.RefreshToken { return predicate.RefreshToken(func(s *sql.Selector) { diff --git a/storage/ent/db/refreshtoken_create.go b/storage/ent/db/refreshtoken_create.go index 4035b110..e73f276a 100644 --- a/storage/ent/db/refreshtoken_create.go +++ b/storage/ent/db/refreshtoken_create.go @@ -108,6 +108,20 @@ func (rtc *RefreshTokenCreate) SetNillableToken(s *string) *RefreshTokenCreate { return rtc } +// SetObsoleteToken sets the "obsolete_token" field. +func (rtc *RefreshTokenCreate) SetObsoleteToken(s string) *RefreshTokenCreate { + rtc.mutation.SetObsoleteToken(s) + return rtc +} + +// SetNillableObsoleteToken sets the "obsolete_token" field if the given value is not nil. +func (rtc *RefreshTokenCreate) SetNillableObsoleteToken(s *string) *RefreshTokenCreate { + if s != nil { + rtc.SetObsoleteToken(*s) + } + return rtc +} + // SetCreatedAt sets the "created_at" field. func (rtc *RefreshTokenCreate) SetCreatedAt(t time.Time) *RefreshTokenCreate { rtc.mutation.SetCreatedAt(t) @@ -202,6 +216,10 @@ func (rtc *RefreshTokenCreate) defaults() { v := refreshtoken.DefaultToken rtc.mutation.SetToken(v) } + if _, ok := rtc.mutation.ObsoleteToken(); !ok { + v := refreshtoken.DefaultObsoleteToken + rtc.mutation.SetObsoleteToken(v) + } if _, ok := rtc.mutation.CreatedAt(); !ok { v := refreshtoken.DefaultCreatedAt() rtc.mutation.SetCreatedAt(v) @@ -271,6 +289,9 @@ func (rtc *RefreshTokenCreate) check() error { if _, ok := rtc.mutation.Token(); !ok { return &ValidationError{Name: "token", err: errors.New("db: missing required field \"token\"")} } + if _, ok := rtc.mutation.ObsoleteToken(); !ok { + return &ValidationError{Name: "obsolete_token", err: errors.New("db: missing required field \"obsolete_token\"")} + } if _, ok := rtc.mutation.CreatedAt(); !ok { return &ValidationError{Name: "created_at", err: errors.New("db: missing required field \"created_at\"")} } @@ -407,6 +428,14 @@ func (rtc *RefreshTokenCreate) createSpec() (*RefreshToken, *sqlgraph.CreateSpec }) _node.Token = value } + if value, ok := rtc.mutation.ObsoleteToken(); ok { + _spec.Fields = append(_spec.Fields, &sqlgraph.FieldSpec{ + Type: field.TypeString, + Value: value, + Column: refreshtoken.FieldObsoleteToken, + }) + _node.ObsoleteToken = value + } if value, ok := rtc.mutation.CreatedAt(); ok { _spec.Fields = append(_spec.Fields, &sqlgraph.FieldSpec{ Type: field.TypeTime, diff --git a/storage/ent/db/refreshtoken_query.go b/storage/ent/db/refreshtoken_query.go index 14fa475d..503e606f 100644 --- a/storage/ent/db/refreshtoken_query.go +++ b/storage/ent/db/refreshtoken_query.go @@ -20,6 +20,7 @@ type RefreshTokenQuery struct { config limit *int offset *int + unique *bool order []OrderFunc fields []string predicates []predicate.RefreshToken @@ -46,6 +47,13 @@ func (rtq *RefreshTokenQuery) Offset(offset int) *RefreshTokenQuery { return rtq } +// Unique configures the query builder to filter duplicate records on query. +// By default, unique is set to true, and can be disabled using this method. +func (rtq *RefreshTokenQuery) Unique(unique bool) *RefreshTokenQuery { + rtq.unique = &unique + return rtq +} + // Order adds an order step to the query. func (rtq *RefreshTokenQuery) Order(o ...OrderFunc) *RefreshTokenQuery { rtq.order = append(rtq.order, o...) @@ -352,6 +360,9 @@ func (rtq *RefreshTokenQuery) querySpec() *sqlgraph.QuerySpec { From: rtq.sql, Unique: true, } + if unique := rtq.unique; unique != nil { + _spec.Unique = *unique + } if fields := rtq.fields; len(fields) > 0 { _spec.Node.Columns = make([]string, 0, len(fields)) _spec.Node.Columns = append(_spec.Node.Columns, refreshtoken.FieldID) @@ -377,7 +388,7 @@ func (rtq *RefreshTokenQuery) querySpec() *sqlgraph.QuerySpec { if ps := rtq.order; len(ps) > 0 { _spec.Order = func(selector *sql.Selector) { for i := range ps { - ps[i](selector, refreshtoken.ValidColumn) + ps[i](selector) } } } @@ -396,7 +407,7 @@ func (rtq *RefreshTokenQuery) sqlQuery(ctx context.Context) *sql.Selector { p(selector) } for _, p := range rtq.order { - p(selector, refreshtoken.ValidColumn) + p(selector) } if offset := rtq.offset; offset != nil { // limit is mandatory for offset clause. We start @@ -662,7 +673,7 @@ func (rtgb *RefreshTokenGroupBy) sqlQuery() *sql.Selector { columns := make([]string, 0, len(rtgb.fields)+len(rtgb.fns)) columns = append(columns, rtgb.fields...) for _, fn := range rtgb.fns { - columns = append(columns, fn(selector, refreshtoken.ValidColumn)) + columns = append(columns, fn(selector)) } return selector.Select(columns...).GroupBy(rtgb.fields...) } diff --git a/storage/ent/db/refreshtoken_update.go b/storage/ent/db/refreshtoken_update.go index 5c745192..87ccfcd0 100644 --- a/storage/ent/db/refreshtoken_update.go +++ b/storage/ent/db/refreshtoken_update.go @@ -133,6 +133,20 @@ func (rtu *RefreshTokenUpdate) SetNillableToken(s *string) *RefreshTokenUpdate { return rtu } +// SetObsoleteToken sets the "obsolete_token" field. +func (rtu *RefreshTokenUpdate) SetObsoleteToken(s string) *RefreshTokenUpdate { + rtu.mutation.SetObsoleteToken(s) + return rtu +} + +// SetNillableObsoleteToken sets the "obsolete_token" field if the given value is not nil. +func (rtu *RefreshTokenUpdate) SetNillableObsoleteToken(s *string) *RefreshTokenUpdate { + if s != nil { + rtu.SetObsoleteToken(*s) + } + return rtu +} + // SetCreatedAt sets the "created_at" field. func (rtu *RefreshTokenUpdate) SetCreatedAt(t time.Time) *RefreshTokenUpdate { rtu.mutation.SetCreatedAt(t) @@ -378,6 +392,13 @@ func (rtu *RefreshTokenUpdate) sqlSave(ctx context.Context) (n int, err error) { Column: refreshtoken.FieldToken, }) } + if value, ok := rtu.mutation.ObsoleteToken(); ok { + _spec.Fields.Set = append(_spec.Fields.Set, &sqlgraph.FieldSpec{ + Type: field.TypeString, + Value: value, + Column: refreshtoken.FieldObsoleteToken, + }) + } if value, ok := rtu.mutation.CreatedAt(); ok { _spec.Fields.Set = append(_spec.Fields.Set, &sqlgraph.FieldSpec{ Type: field.TypeTime, @@ -406,6 +427,7 @@ func (rtu *RefreshTokenUpdate) sqlSave(ctx context.Context) (n int, err error) { // RefreshTokenUpdateOne is the builder for updating a single RefreshToken entity. type RefreshTokenUpdateOne struct { config + fields []string hooks []Hook mutation *RefreshTokenMutation } @@ -516,6 +538,20 @@ func (rtuo *RefreshTokenUpdateOne) SetNillableToken(s *string) *RefreshTokenUpda return rtuo } +// SetObsoleteToken sets the "obsolete_token" field. +func (rtuo *RefreshTokenUpdateOne) SetObsoleteToken(s string) *RefreshTokenUpdateOne { + rtuo.mutation.SetObsoleteToken(s) + return rtuo +} + +// SetNillableObsoleteToken sets the "obsolete_token" field if the given value is not nil. +func (rtuo *RefreshTokenUpdateOne) SetNillableObsoleteToken(s *string) *RefreshTokenUpdateOne { + if s != nil { + rtuo.SetObsoleteToken(*s) + } + return rtuo +} + // SetCreatedAt sets the "created_at" field. func (rtuo *RefreshTokenUpdateOne) SetCreatedAt(t time.Time) *RefreshTokenUpdateOne { rtuo.mutation.SetCreatedAt(t) @@ -549,6 +585,13 @@ func (rtuo *RefreshTokenUpdateOne) Mutation() *RefreshTokenMutation { return rtuo.mutation } +// Select allows selecting one or more fields (columns) of the returned entity. +// The default is selecting all fields defined in the entity schema. +func (rtuo *RefreshTokenUpdateOne) Select(field string, fields ...string) *RefreshTokenUpdateOne { + rtuo.fields = append([]string{field}, fields...) + return rtuo +} + // Save executes the query and returns the updated RefreshToken entity. func (rtuo *RefreshTokenUpdateOne) Save(ctx context.Context) (*RefreshToken, error) { var ( @@ -657,6 +700,18 @@ func (rtuo *RefreshTokenUpdateOne) sqlSave(ctx context.Context) (_node *RefreshT return nil, &ValidationError{Name: "ID", err: fmt.Errorf("missing RefreshToken.ID for update")} } _spec.Node.ID.Value = id + if fields := rtuo.fields; len(fields) > 0 { + _spec.Node.Columns = make([]string, 0, len(fields)) + _spec.Node.Columns = append(_spec.Node.Columns, refreshtoken.FieldID) + for _, f := range fields { + if !refreshtoken.ValidColumn(f) { + return nil, &ValidationError{Name: f, err: fmt.Errorf("db: invalid field %q for query", f)} + } + if f != refreshtoken.FieldID { + _spec.Node.Columns = append(_spec.Node.Columns, f) + } + } + } if ps := rtuo.mutation.predicates; len(ps) > 0 { _spec.Predicate = func(selector *sql.Selector) { for i := range ps { @@ -766,6 +821,13 @@ func (rtuo *RefreshTokenUpdateOne) sqlSave(ctx context.Context) (_node *RefreshT Column: refreshtoken.FieldToken, }) } + if value, ok := rtuo.mutation.ObsoleteToken(); ok { + _spec.Fields.Set = append(_spec.Fields.Set, &sqlgraph.FieldSpec{ + Type: field.TypeString, + Value: value, + Column: refreshtoken.FieldObsoleteToken, + }) + } if value, ok := rtuo.mutation.CreatedAt(); ok { _spec.Fields.Set = append(_spec.Fields.Set, &sqlgraph.FieldSpec{ Type: field.TypeTime, diff --git a/storage/ent/db/runtime.go b/storage/ent/db/runtime.go index 2d7249d7..49f4157a 100644 --- a/storage/ent/db/runtime.go +++ b/storage/ent/db/runtime.go @@ -214,12 +214,16 @@ func init() { refreshtokenDescToken := refreshtokenFields[12].Descriptor() // refreshtoken.DefaultToken holds the default value on creation for the token field. refreshtoken.DefaultToken = refreshtokenDescToken.Default.(string) + // refreshtokenDescObsoleteToken is the schema descriptor for obsolete_token field. + refreshtokenDescObsoleteToken := refreshtokenFields[13].Descriptor() + // refreshtoken.DefaultObsoleteToken holds the default value on creation for the obsolete_token field. + refreshtoken.DefaultObsoleteToken = refreshtokenDescObsoleteToken.Default.(string) // refreshtokenDescCreatedAt is the schema descriptor for created_at field. - refreshtokenDescCreatedAt := refreshtokenFields[13].Descriptor() + refreshtokenDescCreatedAt := refreshtokenFields[14].Descriptor() // refreshtoken.DefaultCreatedAt holds the default value on creation for the created_at field. refreshtoken.DefaultCreatedAt = refreshtokenDescCreatedAt.Default.(func() time.Time) // refreshtokenDescLastUsed is the schema descriptor for last_used field. - refreshtokenDescLastUsed := refreshtokenFields[14].Descriptor() + refreshtokenDescLastUsed := refreshtokenFields[15].Descriptor() // refreshtoken.DefaultLastUsed holds the default value on creation for the last_used field. refreshtoken.DefaultLastUsed = refreshtokenDescLastUsed.Default.(func() time.Time) // refreshtokenDescID is the schema descriptor for id field. diff --git a/storage/ent/db/runtime/runtime.go b/storage/ent/db/runtime/runtime.go index 2a1016b1..6f056d2d 100644 --- a/storage/ent/db/runtime/runtime.go +++ b/storage/ent/db/runtime/runtime.go @@ -5,6 +5,6 @@ package runtime // The schema-stitching logic is generated in github.com/dexidp/dex/storage/ent/db/runtime.go const ( - Version = "v0.7.0" // Version of ent codegen. - Sum = "h1:E3EjO0cUL61DvUg5ZEZdxa4yTL+4SuZv0LqBExo8CQA=" // Sum of ent codegen. + Version = "v0.8.0" // Version of ent codegen. + Sum = "h1:xirrW//1oda7pp0bz+XssSOv4/C3nmgYQOxjIfljFt8=" // Sum of ent codegen. ) diff --git a/storage/ent/schema/refreshtoken.go b/storage/ent/schema/refreshtoken.go index 4df128db..00c640d4 100644 --- a/storage/ent/schema/refreshtoken.go +++ b/storage/ent/schema/refreshtoken.go @@ -24,7 +24,8 @@ create table refresh_token token text default '' not null, created_at timestamp default '0001-01-01 00:00:00 UTC' not null, last_used timestamp default '0001-01-01 00:00:00 UTC' not null, - claims_preferred_username text default '' not null + claims_preferred_username text default '' not null, + obsolete_token text default '' ); */ @@ -75,6 +76,9 @@ func (RefreshToken) Fields() []ent.Field { field.Text("token"). SchemaType(textSchema). Default(""), + field.Text("obsolete_token"). + SchemaType(textSchema). + Default(""), field.Time("created_at"). Default(time.Now),