Merge branch 'master' into conformance_tests_improvements
This commit is contained in:
@@ -511,15 +511,15 @@ func testPasswordCRUD(t *testing.T, s storage.Storage) {
|
||||
|
||||
_, err = s.GetPassword(password1.Email)
|
||||
mustBeErrNotFound(t, "password", err)
|
||||
|
||||
}
|
||||
|
||||
func testOfflineSessionCRUD(t *testing.T, s storage.Storage) {
|
||||
userID1 := storage.NewID()
|
||||
session1 := storage.OfflineSessions{
|
||||
UserID: userID1,
|
||||
ConnID: "Conn1",
|
||||
Refresh: make(map[string]*storage.RefreshTokenRef),
|
||||
UserID: userID1,
|
||||
ConnID: "Conn1",
|
||||
Refresh: make(map[string]*storage.RefreshTokenRef),
|
||||
ConnectorData: []byte(`{"some":"data"}`),
|
||||
}
|
||||
|
||||
// Creating an OfflineSession with an empty Refresh list to ensure that
|
||||
@@ -534,9 +534,10 @@ func testOfflineSessionCRUD(t *testing.T, s storage.Storage) {
|
||||
|
||||
userID2 := storage.NewID()
|
||||
session2 := storage.OfflineSessions{
|
||||
UserID: userID2,
|
||||
ConnID: "Conn2",
|
||||
Refresh: make(map[string]*storage.RefreshTokenRef),
|
||||
UserID: userID2,
|
||||
ConnID: "Conn2",
|
||||
Refresh: make(map[string]*storage.RefreshTokenRef),
|
||||
ConnectorData: []byte(`{"some":"data"}`),
|
||||
}
|
||||
|
||||
if err := s.CreateOfflineSessions(session2); err != nil {
|
||||
|
@@ -156,7 +156,7 @@ func (c *conn) UpdateRefreshToken(id string, updater func(old storage.RefreshTok
|
||||
return c.txnUpdate(ctx, keyID(refreshTokenPrefix, id), func(currentValue []byte) ([]byte, error) {
|
||||
var current RefreshToken
|
||||
if len(currentValue) > 0 {
|
||||
if err := json.Unmarshal([]byte(currentValue), ¤t); err != nil {
|
||||
if err := json.Unmarshal(currentValue, ¤t); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
|
@@ -148,30 +148,33 @@ func fromStorageRefreshToken(r storage.RefreshToken) RefreshToken {
|
||||
|
||||
// Claims is a mirrored struct from storage with JSON struct tags.
|
||||
type Claims struct {
|
||||
UserID string `json:"userID"`
|
||||
Username string `json:"username"`
|
||||
Email string `json:"email"`
|
||||
EmailVerified bool `json:"emailVerified"`
|
||||
Groups []string `json:"groups,omitempty"`
|
||||
UserID string `json:"userID"`
|
||||
Username string `json:"username"`
|
||||
PreferredUsername string `json:"preferredUsername"`
|
||||
Email string `json:"email"`
|
||||
EmailVerified bool `json:"emailVerified"`
|
||||
Groups []string `json:"groups,omitempty"`
|
||||
}
|
||||
|
||||
func fromStorageClaims(i storage.Claims) Claims {
|
||||
return Claims{
|
||||
UserID: i.UserID,
|
||||
Username: i.Username,
|
||||
Email: i.Email,
|
||||
EmailVerified: i.EmailVerified,
|
||||
Groups: i.Groups,
|
||||
UserID: i.UserID,
|
||||
Username: i.Username,
|
||||
PreferredUsername: i.PreferredUsername,
|
||||
Email: i.Email,
|
||||
EmailVerified: i.EmailVerified,
|
||||
Groups: i.Groups,
|
||||
}
|
||||
}
|
||||
|
||||
func toStorageClaims(i Claims) storage.Claims {
|
||||
return storage.Claims{
|
||||
UserID: i.UserID,
|
||||
Username: i.Username,
|
||||
Email: i.Email,
|
||||
EmailVerified: i.EmailVerified,
|
||||
Groups: i.Groups,
|
||||
UserID: i.UserID,
|
||||
Username: i.Username,
|
||||
PreferredUsername: i.PreferredUsername,
|
||||
Email: i.Email,
|
||||
EmailVerified: i.EmailVerified,
|
||||
Groups: i.Groups,
|
||||
}
|
||||
}
|
||||
|
||||
@@ -185,24 +188,27 @@ type Keys struct {
|
||||
|
||||
// OfflineSessions is a mirrored struct from storage with JSON struct tags
|
||||
type OfflineSessions struct {
|
||||
UserID string `json:"user_id,omitempty"`
|
||||
ConnID string `json:"conn_id,omitempty"`
|
||||
Refresh map[string]*storage.RefreshTokenRef `json:"refresh,omitempty"`
|
||||
UserID string `json:"user_id,omitempty"`
|
||||
ConnID string `json:"conn_id,omitempty"`
|
||||
Refresh map[string]*storage.RefreshTokenRef `json:"refresh,omitempty"`
|
||||
ConnectorData []byte `json:"connectorData,omitempty"`
|
||||
}
|
||||
|
||||
func fromStorageOfflineSessions(o storage.OfflineSessions) OfflineSessions {
|
||||
return OfflineSessions{
|
||||
UserID: o.UserID,
|
||||
ConnID: o.ConnID,
|
||||
Refresh: o.Refresh,
|
||||
UserID: o.UserID,
|
||||
ConnID: o.ConnID,
|
||||
Refresh: o.Refresh,
|
||||
ConnectorData: o.ConnectorData,
|
||||
}
|
||||
}
|
||||
|
||||
func toStorageOfflineSessions(o OfflineSessions) storage.OfflineSessions {
|
||||
s := storage.OfflineSessions{
|
||||
UserID: o.UserID,
|
||||
ConnID: o.ConnID,
|
||||
Refresh: o.Refresh,
|
||||
UserID: o.UserID,
|
||||
ConnID: o.ConnID,
|
||||
Refresh: o.Refresh,
|
||||
ConnectorData: o.ConnectorData,
|
||||
}
|
||||
if s.Refresh == nil {
|
||||
// Server code assumes this will be non-nil.
|
||||
|
@@ -55,14 +55,14 @@ type client struct {
|
||||
}
|
||||
|
||||
// idToName maps an arbitrary ID, such as an email or client ID to a Kubernetes object name.
|
||||
func (c *client) idToName(s string) string {
|
||||
return idToName(s, c.hash)
|
||||
func (cli *client) idToName(s string) string {
|
||||
return idToName(s, cli.hash)
|
||||
}
|
||||
|
||||
// offlineTokenName maps two arbitrary IDs, to a single Kubernetes object name.
|
||||
// This is used when more than one field is used to uniquely identify the object.
|
||||
func (c *client) offlineTokenName(userID string, connID string) string {
|
||||
return offlineTokenName(userID, connID, c.hash)
|
||||
func (cli *client) offlineTokenName(userID string, connID string) string {
|
||||
return offlineTokenName(userID, connID, cli.hash)
|
||||
}
|
||||
|
||||
// Kubernetes names must match the regexp '[a-z0-9]([-a-z0-9]*[a-z0-9])?'.
|
||||
@@ -79,7 +79,7 @@ func offlineTokenName(userID string, connID string, h func() hash.Hash) string {
|
||||
return strings.TrimRight(encoding.EncodeToString(hash.Sum(nil)), "=")
|
||||
}
|
||||
|
||||
func (c *client) urlFor(apiVersion, namespace, resource, name string) string {
|
||||
func (cli *client) urlFor(apiVersion, namespace, resource, name string) string {
|
||||
basePath := "apis/"
|
||||
if apiVersion == "v1" {
|
||||
basePath = "api/"
|
||||
@@ -91,10 +91,10 @@ func (c *client) urlFor(apiVersion, namespace, resource, name string) string {
|
||||
} else {
|
||||
p = path.Join(basePath, apiVersion, resource, name)
|
||||
}
|
||||
if strings.HasSuffix(c.baseURL, "/") {
|
||||
return c.baseURL + p
|
||||
if strings.HasSuffix(cli.baseURL, "/") {
|
||||
return cli.baseURL + p
|
||||
}
|
||||
return c.baseURL + "/" + p
|
||||
return cli.baseURL + "/" + p
|
||||
}
|
||||
|
||||
// Define an error interface so we can get at the underlying status code if it's
|
||||
@@ -156,13 +156,13 @@ func closeResp(r *http.Response) {
|
||||
r.Body.Close()
|
||||
}
|
||||
|
||||
func (c *client) get(resource, name string, v interface{}) error {
|
||||
return c.getResource(c.apiVersion, c.namespace, resource, name, v)
|
||||
func (cli *client) get(resource, name string, v interface{}) error {
|
||||
return cli.getResource(cli.apiVersion, cli.namespace, resource, name, v)
|
||||
}
|
||||
|
||||
func (c *client) getResource(apiVersion, namespace, resource, name string, v interface{}) error {
|
||||
url := c.urlFor(apiVersion, namespace, resource, name)
|
||||
resp, err := c.client.Get(url)
|
||||
func (cli *client) getResource(apiVersion, namespace, resource, name string, v interface{}) error {
|
||||
url := cli.urlFor(apiVersion, namespace, resource, name)
|
||||
resp, err := cli.client.Get(url)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
@@ -173,22 +173,22 @@ func (c *client) getResource(apiVersion, namespace, resource, name string, v int
|
||||
return json.NewDecoder(resp.Body).Decode(v)
|
||||
}
|
||||
|
||||
func (c *client) list(resource string, v interface{}) error {
|
||||
return c.get(resource, "", v)
|
||||
func (cli *client) list(resource string, v interface{}) error {
|
||||
return cli.get(resource, "", v)
|
||||
}
|
||||
|
||||
func (c *client) post(resource string, v interface{}) error {
|
||||
return c.postResource(c.apiVersion, c.namespace, resource, v)
|
||||
func (cli *client) post(resource string, v interface{}) error {
|
||||
return cli.postResource(cli.apiVersion, cli.namespace, resource, v)
|
||||
}
|
||||
|
||||
func (c *client) postResource(apiVersion, namespace, resource string, v interface{}) error {
|
||||
func (cli *client) postResource(apiVersion, namespace, resource string, v interface{}) error {
|
||||
body, err := json.Marshal(v)
|
||||
if err != nil {
|
||||
return fmt.Errorf("marshal object: %v", err)
|
||||
}
|
||||
|
||||
url := c.urlFor(apiVersion, namespace, resource, "")
|
||||
resp, err := c.client.Post(url, "application/json", bytes.NewReader(body))
|
||||
url := cli.urlFor(apiVersion, namespace, resource, "")
|
||||
resp, err := cli.client.Post(url, "application/json", bytes.NewReader(body))
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
@@ -196,13 +196,13 @@ func (c *client) postResource(apiVersion, namespace, resource string, v interfac
|
||||
return checkHTTPErr(resp, http.StatusCreated)
|
||||
}
|
||||
|
||||
func (c *client) delete(resource, name string) error {
|
||||
url := c.urlFor(c.apiVersion, c.namespace, resource, name)
|
||||
func (cli *client) delete(resource, name string) error {
|
||||
url := cli.urlFor(cli.apiVersion, cli.namespace, resource, name)
|
||||
req, err := http.NewRequest("DELETE", url, nil)
|
||||
if err != nil {
|
||||
return fmt.Errorf("create delete request: %v", err)
|
||||
}
|
||||
resp, err := c.client.Do(req)
|
||||
resp, err := cli.client.Do(req)
|
||||
if err != nil {
|
||||
return fmt.Errorf("delete request: %v", err)
|
||||
}
|
||||
@@ -210,7 +210,7 @@ func (c *client) delete(resource, name string) error {
|
||||
return checkHTTPErr(resp, http.StatusOK)
|
||||
}
|
||||
|
||||
func (c *client) deleteAll(resource string) error {
|
||||
func (cli *client) deleteAll(resource string) error {
|
||||
var list struct {
|
||||
k8sapi.TypeMeta `json:",inline"`
|
||||
k8sapi.ListMeta `json:"metadata,omitempty"`
|
||||
@@ -219,24 +219,24 @@ func (c *client) deleteAll(resource string) error {
|
||||
k8sapi.ObjectMeta `json:"metadata,omitempty"`
|
||||
} `json:"items"`
|
||||
}
|
||||
if err := c.list(resource, &list); err != nil {
|
||||
if err := cli.list(resource, &list); err != nil {
|
||||
return err
|
||||
}
|
||||
for _, item := range list.Items {
|
||||
if err := c.delete(resource, item.Name); err != nil {
|
||||
if err := cli.delete(resource, item.Name); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *client) put(resource, name string, v interface{}) error {
|
||||
func (cli *client) put(resource, name string, v interface{}) error {
|
||||
body, err := json.Marshal(v)
|
||||
if err != nil {
|
||||
return fmt.Errorf("marshal object: %v", err)
|
||||
}
|
||||
|
||||
url := c.urlFor(c.apiVersion, c.namespace, resource, name)
|
||||
url := cli.urlFor(cli.apiVersion, cli.namespace, resource, name)
|
||||
req, err := http.NewRequest("PUT", url, bytes.NewReader(body))
|
||||
if err != nil {
|
||||
return fmt.Errorf("create patch request: %v", err)
|
||||
@@ -244,7 +244,7 @@ func (c *client) put(resource, name string, v interface{}) error {
|
||||
|
||||
req.Header.Set("Content-Length", strconv.Itoa(len(body)))
|
||||
|
||||
resp, err := c.client.Do(req)
|
||||
resp, err := cli.client.Do(req)
|
||||
if err != nil {
|
||||
return fmt.Errorf("patch request: %v", err)
|
||||
}
|
||||
|
@@ -43,7 +43,7 @@ type CustomResourceDefinitionNames struct {
|
||||
ListKind string `json:"listKind,omitempty" protobuf:"bytes,5,opt,name=listKind"`
|
||||
}
|
||||
|
||||
// ResourceScope is an enum defining the different scopes availabe to a custom resource
|
||||
// ResourceScope is an enum defining the different scopes available to a custom resource
|
||||
type ResourceScope string
|
||||
|
||||
const (
|
||||
|
@@ -210,30 +210,33 @@ func toStorageClient(c Client) storage.Client {
|
||||
|
||||
// Claims is a mirrored struct from storage with JSON struct tags.
|
||||
type Claims struct {
|
||||
UserID string `json:"userID"`
|
||||
Username string `json:"username"`
|
||||
Email string `json:"email"`
|
||||
EmailVerified bool `json:"emailVerified"`
|
||||
Groups []string `json:"groups,omitempty"`
|
||||
UserID string `json:"userID"`
|
||||
Username string `json:"username"`
|
||||
PreferredUsername string `json:"preferredUsername"`
|
||||
Email string `json:"email"`
|
||||
EmailVerified bool `json:"emailVerified"`
|
||||
Groups []string `json:"groups,omitempty"`
|
||||
}
|
||||
|
||||
func fromStorageClaims(i storage.Claims) Claims {
|
||||
return Claims{
|
||||
UserID: i.UserID,
|
||||
Username: i.Username,
|
||||
Email: i.Email,
|
||||
EmailVerified: i.EmailVerified,
|
||||
Groups: i.Groups,
|
||||
UserID: i.UserID,
|
||||
Username: i.Username,
|
||||
PreferredUsername: i.PreferredUsername,
|
||||
Email: i.Email,
|
||||
EmailVerified: i.EmailVerified,
|
||||
Groups: i.Groups,
|
||||
}
|
||||
}
|
||||
|
||||
func toStorageClaims(i Claims) storage.Claims {
|
||||
return storage.Claims{
|
||||
UserID: i.UserID,
|
||||
Username: i.Username,
|
||||
Email: i.Email,
|
||||
EmailVerified: i.EmailVerified,
|
||||
Groups: i.Groups,
|
||||
UserID: i.UserID,
|
||||
Username: i.Username,
|
||||
PreferredUsername: i.PreferredUsername,
|
||||
Email: i.Email,
|
||||
EmailVerified: i.EmailVerified,
|
||||
Groups: i.Groups,
|
||||
}
|
||||
}
|
||||
|
||||
@@ -549,9 +552,10 @@ type OfflineSessions struct {
|
||||
k8sapi.TypeMeta `json:",inline"`
|
||||
k8sapi.ObjectMeta `json:"metadata,omitempty"`
|
||||
|
||||
UserID string `json:"userID,omitempty"`
|
||||
ConnID string `json:"connID,omitempty"`
|
||||
Refresh map[string]*storage.RefreshTokenRef `json:"refresh,omitempty"`
|
||||
UserID string `json:"userID,omitempty"`
|
||||
ConnID string `json:"connID,omitempty"`
|
||||
Refresh map[string]*storage.RefreshTokenRef `json:"refresh,omitempty"`
|
||||
ConnectorData []byte `json:"connectorData,omitempty"`
|
||||
}
|
||||
|
||||
func (cli *client) fromStorageOfflineSessions(o storage.OfflineSessions) OfflineSessions {
|
||||
@@ -564,17 +568,19 @@ func (cli *client) fromStorageOfflineSessions(o storage.OfflineSessions) Offline
|
||||
Name: cli.offlineTokenName(o.UserID, o.ConnID),
|
||||
Namespace: cli.namespace,
|
||||
},
|
||||
UserID: o.UserID,
|
||||
ConnID: o.ConnID,
|
||||
Refresh: o.Refresh,
|
||||
UserID: o.UserID,
|
||||
ConnID: o.ConnID,
|
||||
Refresh: o.Refresh,
|
||||
ConnectorData: o.ConnectorData,
|
||||
}
|
||||
}
|
||||
|
||||
func toStorageOfflineSessions(o OfflineSessions) storage.OfflineSessions {
|
||||
s := storage.OfflineSessions{
|
||||
UserID: o.UserID,
|
||||
ConnID: o.ConnID,
|
||||
Refresh: o.Refresh,
|
||||
UserID: o.UserID,
|
||||
ConnID: o.ConnID,
|
||||
Refresh: o.Refresh,
|
||||
ConnectorData: o.ConnectorData,
|
||||
}
|
||||
if s.Refresh == nil {
|
||||
// Server code assumes this will be non-nil.
|
||||
|
@@ -29,6 +29,7 @@ const (
|
||||
// MySQL error codes
|
||||
mysqlErrDupEntry = 1062
|
||||
mysqlErrDupEntryWithKeyName = 1586
|
||||
mysqlErrUnknownSysVar = 1193
|
||||
)
|
||||
|
||||
// SQLite3 options for creating an SQL db.
|
||||
@@ -307,6 +308,26 @@ func (s *MySQL) open(logger log.Logger) (*conn, error) {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
err = db.Ping()
|
||||
if err != nil {
|
||||
if mysqlErr, ok := err.(*mysql.MySQLError); ok && mysqlErr.Number == mysqlErrUnknownSysVar {
|
||||
logger.Info("reconnecting with MySQL pre-5.7.20 compatibility mode")
|
||||
|
||||
// MySQL 5.7.20 introduced transaction_isolation and deprecated tx_isolation.
|
||||
// MySQL 8.0 doesn't have tx_isolation at all.
|
||||
// https://dev.mysql.com/doc/refman/5.7/en/server-system-variables.html#sysvar_transaction_isolation
|
||||
delete(cfg.Params, "transaction_isolation")
|
||||
cfg.Params["tx_isolation"] = "'SERIALIZABLE'"
|
||||
|
||||
db, err = sql.Open("mysql", cfg.FormatDSN())
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
} else {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
|
||||
errCheck := func(err error) bool {
|
||||
sqlErr, ok := err.(*mysql.MySQLError)
|
||||
if !ok {
|
||||
|
@@ -4,6 +4,7 @@ import (
|
||||
"fmt"
|
||||
"os"
|
||||
"runtime"
|
||||
"strconv"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
@@ -220,12 +221,24 @@ func TestPostgres(t *testing.T) {
|
||||
if host == "" {
|
||||
t.Skipf("test environment variable %q not set, skipping", testPostgresEnv)
|
||||
}
|
||||
|
||||
port := uint64(5432)
|
||||
if rawPort := os.Getenv("DEX_POSTGRES_PORT"); rawPort != "" {
|
||||
var err error
|
||||
|
||||
port, err = strconv.ParseUint(rawPort, 10, 32)
|
||||
if err != nil {
|
||||
t.Fatalf("invalid postgres port %q: %s", rawPort, err)
|
||||
}
|
||||
}
|
||||
|
||||
p := &Postgres{
|
||||
NetworkDB: NetworkDB{
|
||||
Database: getenv("DEX_POSTGRES_DATABASE", "postgres"),
|
||||
User: getenv("DEX_POSTGRES_USER", "postgres"),
|
||||
Password: getenv("DEX_POSTGRES_PASSWORD", "postgres"),
|
||||
Host: host,
|
||||
Port: uint16(port),
|
||||
ConnectionTimeout: 5,
|
||||
},
|
||||
SSL: SSL{
|
||||
@@ -242,12 +255,24 @@ func TestMySQL(t *testing.T) {
|
||||
if host == "" {
|
||||
t.Skipf("test environment variable %q not set, skipping", testMySQLEnv)
|
||||
}
|
||||
|
||||
port := uint64(3306)
|
||||
if rawPort := os.Getenv("DEX_MYSQL_PORT"); rawPort != "" {
|
||||
var err error
|
||||
|
||||
port, err = strconv.ParseUint(rawPort, 10, 32)
|
||||
if err != nil {
|
||||
t.Fatalf("invalid mysql port %q: %s", rawPort, err)
|
||||
}
|
||||
}
|
||||
|
||||
s := &MySQL{
|
||||
NetworkDB: NetworkDB{
|
||||
Database: getenv("DEX_MYSQL_DATABASE", "mysql"),
|
||||
User: getenv("DEX_MYSQL_USER", "mysql"),
|
||||
Password: getenv("DEX_MYSQL_PASSWORD", ""),
|
||||
Host: host,
|
||||
Port: uint16(port),
|
||||
ConnectionTimeout: 5,
|
||||
},
|
||||
SSL: SSL{
|
||||
|
@@ -108,19 +108,19 @@ func (c *conn) CreateAuthRequest(a storage.AuthRequest) error {
|
||||
insert into auth_request (
|
||||
id, client_id, response_types, scopes, redirect_uri, nonce, state,
|
||||
force_approval_prompt, logged_in,
|
||||
claims_user_id, claims_username, claims_email, claims_email_verified,
|
||||
claims_groups,
|
||||
claims_user_id, claims_username, claims_preferred_username,
|
||||
claims_email, claims_email_verified, claims_groups,
|
||||
connector_id, connector_data,
|
||||
expiry
|
||||
)
|
||||
values (
|
||||
$1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12, $13, $14, $15, $16, $17
|
||||
$1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12, $13, $14, $15, $16, $17, $18
|
||||
);
|
||||
`,
|
||||
a.ID, a.ClientID, encoder(a.ResponseTypes), encoder(a.Scopes), a.RedirectURI, a.Nonce, a.State,
|
||||
a.ForceApprovalPrompt, a.LoggedIn,
|
||||
a.Claims.UserID, a.Claims.Username, a.Claims.Email, a.Claims.EmailVerified,
|
||||
encoder(a.Claims.Groups),
|
||||
a.Claims.UserID, a.Claims.Username, a.Claims.PreferredUsername,
|
||||
a.Claims.Email, a.Claims.EmailVerified, encoder(a.Claims.Groups),
|
||||
a.ConnectorID, a.ConnectorData,
|
||||
a.Expiry,
|
||||
)
|
||||
@@ -149,16 +149,17 @@ func (c *conn) UpdateAuthRequest(id string, updater func(a storage.AuthRequest)
|
||||
set
|
||||
client_id = $1, response_types = $2, scopes = $3, redirect_uri = $4,
|
||||
nonce = $5, state = $6, force_approval_prompt = $7, logged_in = $8,
|
||||
claims_user_id = $9, claims_username = $10, claims_email = $11,
|
||||
claims_email_verified = $12,
|
||||
claims_groups = $13,
|
||||
connector_id = $14, connector_data = $15,
|
||||
expiry = $16
|
||||
where id = $17;
|
||||
claims_user_id = $9, claims_username = $10, claims_preferred_username = $11,
|
||||
claims_email = $12, claims_email_verified = $13,
|
||||
claims_groups = $14,
|
||||
connector_id = $15, connector_data = $16,
|
||||
expiry = $17
|
||||
where id = $18;
|
||||
`,
|
||||
a.ClientID, encoder(a.ResponseTypes), encoder(a.Scopes), a.RedirectURI, a.Nonce, a.State,
|
||||
a.ForceApprovalPrompt, a.LoggedIn,
|
||||
a.Claims.UserID, a.Claims.Username, a.Claims.Email, a.Claims.EmailVerified,
|
||||
a.Claims.UserID, a.Claims.Username, a.Claims.PreferredUsername,
|
||||
a.Claims.Email, a.Claims.EmailVerified,
|
||||
encoder(a.Claims.Groups),
|
||||
a.ConnectorID, a.ConnectorData,
|
||||
a.Expiry, r.ID,
|
||||
@@ -168,7 +169,6 @@ func (c *conn) UpdateAuthRequest(id string, updater func(a storage.AuthRequest)
|
||||
}
|
||||
return nil
|
||||
})
|
||||
|
||||
}
|
||||
|
||||
func (c *conn) GetAuthRequest(id string) (storage.AuthRequest, error) {
|
||||
@@ -177,17 +177,18 @@ func (c *conn) GetAuthRequest(id string) (storage.AuthRequest, error) {
|
||||
|
||||
func getAuthRequest(q querier, id string) (a storage.AuthRequest, err error) {
|
||||
err = q.QueryRow(`
|
||||
select
|
||||
select
|
||||
id, client_id, response_types, scopes, redirect_uri, nonce, state,
|
||||
force_approval_prompt, logged_in,
|
||||
claims_user_id, claims_username, claims_email, claims_email_verified,
|
||||
claims_groups,
|
||||
claims_user_id, claims_username, claims_preferred_username,
|
||||
claims_email, claims_email_verified, claims_groups,
|
||||
connector_id, connector_data, expiry
|
||||
from auth_request where id = $1;
|
||||
`, id).Scan(
|
||||
&a.ID, &a.ClientID, decoder(&a.ResponseTypes), decoder(&a.Scopes), &a.RedirectURI, &a.Nonce, &a.State,
|
||||
&a.ForceApprovalPrompt, &a.LoggedIn,
|
||||
&a.Claims.UserID, &a.Claims.Username, &a.Claims.Email, &a.Claims.EmailVerified,
|
||||
&a.Claims.UserID, &a.Claims.Username, &a.Claims.PreferredUsername,
|
||||
&a.Claims.Email, &a.Claims.EmailVerified,
|
||||
decoder(&a.Claims.Groups),
|
||||
&a.ConnectorID, &a.ConnectorData, &a.Expiry,
|
||||
)
|
||||
@@ -204,16 +205,16 @@ func (c *conn) CreateAuthCode(a storage.AuthCode) error {
|
||||
_, err := c.Exec(`
|
||||
insert into auth_code (
|
||||
id, client_id, scopes, nonce, redirect_uri,
|
||||
claims_user_id, claims_username,
|
||||
claims_user_id, claims_username, claims_preferred_username,
|
||||
claims_email, claims_email_verified, claims_groups,
|
||||
connector_id, connector_data,
|
||||
expiry
|
||||
)
|
||||
values ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12, $13);
|
||||
values ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12, $13, $14);
|
||||
`,
|
||||
a.ID, a.ClientID, encoder(a.Scopes), a.Nonce, a.RedirectURI, a.Claims.UserID,
|
||||
a.Claims.Username, a.Claims.Email, a.Claims.EmailVerified, encoder(a.Claims.Groups),
|
||||
a.ConnectorID, a.ConnectorData, a.Expiry,
|
||||
a.Claims.Username, a.Claims.PreferredUsername, a.Claims.Email, a.Claims.EmailVerified,
|
||||
encoder(a.Claims.Groups), a.ConnectorID, a.ConnectorData, a.Expiry,
|
||||
)
|
||||
|
||||
if err != nil {
|
||||
@@ -229,15 +230,15 @@ func (c *conn) GetAuthCode(id string) (a storage.AuthCode, err error) {
|
||||
err = c.QueryRow(`
|
||||
select
|
||||
id, client_id, scopes, nonce, redirect_uri,
|
||||
claims_user_id, claims_username,
|
||||
claims_user_id, claims_username, claims_preferred_username,
|
||||
claims_email, claims_email_verified, claims_groups,
|
||||
connector_id, connector_data,
|
||||
expiry
|
||||
from auth_code where id = $1;
|
||||
`, id).Scan(
|
||||
&a.ID, &a.ClientID, decoder(&a.Scopes), &a.Nonce, &a.RedirectURI, &a.Claims.UserID,
|
||||
&a.Claims.Username, &a.Claims.Email, &a.Claims.EmailVerified, decoder(&a.Claims.Groups),
|
||||
&a.ConnectorID, &a.ConnectorData, &a.Expiry,
|
||||
&a.Claims.Username, &a.Claims.PreferredUsername, &a.Claims.Email, &a.Claims.EmailVerified,
|
||||
decoder(&a.Claims.Groups), &a.ConnectorID, &a.ConnectorData, &a.Expiry,
|
||||
)
|
||||
if err != nil {
|
||||
if err == sql.ErrNoRows {
|
||||
@@ -252,15 +253,16 @@ func (c *conn) CreateRefresh(r storage.RefreshToken) error {
|
||||
_, err := c.Exec(`
|
||||
insert into refresh_token (
|
||||
id, client_id, scopes, nonce,
|
||||
claims_user_id, claims_username, claims_email, claims_email_verified,
|
||||
claims_groups,
|
||||
claims_user_id, claims_username, claims_preferred_username,
|
||||
claims_email, claims_email_verified, claims_groups,
|
||||
connector_id, connector_data,
|
||||
token, created_at, last_used
|
||||
)
|
||||
values ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12, $13, $14);
|
||||
values ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12, $13, $14, $15);
|
||||
`,
|
||||
r.ID, r.ClientID, encoder(r.Scopes), r.Nonce,
|
||||
r.Claims.UserID, r.Claims.Username, r.Claims.Email, r.Claims.EmailVerified,
|
||||
r.Claims.UserID, r.Claims.Username, r.Claims.PreferredUsername,
|
||||
r.Claims.Email, r.Claims.EmailVerified,
|
||||
encoder(r.Claims.Groups),
|
||||
r.ConnectorID, r.ConnectorData,
|
||||
r.Token, r.CreatedAt, r.LastUsed,
|
||||
@@ -291,19 +293,21 @@ func (c *conn) UpdateRefreshToken(id string, updater func(old storage.RefreshTok
|
||||
nonce = $3,
|
||||
claims_user_id = $4,
|
||||
claims_username = $5,
|
||||
claims_email = $6,
|
||||
claims_email_verified = $7,
|
||||
claims_groups = $8,
|
||||
connector_id = $9,
|
||||
connector_data = $10,
|
||||
token = $11,
|
||||
created_at = $12,
|
||||
last_used = $13
|
||||
claims_preferred_username = $6,
|
||||
claims_email = $7,
|
||||
claims_email_verified = $8,
|
||||
claims_groups = $9,
|
||||
connector_id = $10,
|
||||
connector_data = $11,
|
||||
token = $12,
|
||||
created_at = $13,
|
||||
last_used = $14
|
||||
where
|
||||
id = $14
|
||||
id = $15
|
||||
`,
|
||||
r.ClientID, encoder(r.Scopes), r.Nonce,
|
||||
r.Claims.UserID, r.Claims.Username, r.Claims.Email, r.Claims.EmailVerified,
|
||||
r.Claims.UserID, r.Claims.Username, r.Claims.PreferredUsername,
|
||||
r.Claims.Email, r.Claims.EmailVerified,
|
||||
encoder(r.Claims.Groups),
|
||||
r.ConnectorID, r.ConnectorData,
|
||||
r.Token, r.CreatedAt, r.LastUsed, id,
|
||||
@@ -323,7 +327,8 @@ func getRefresh(q querier, id string) (storage.RefreshToken, error) {
|
||||
return scanRefresh(q.QueryRow(`
|
||||
select
|
||||
id, client_id, scopes, nonce,
|
||||
claims_user_id, claims_username, claims_email, claims_email_verified,
|
||||
claims_user_id, claims_username, claims_preferred_username,
|
||||
claims_email, claims_email_verified,
|
||||
claims_groups,
|
||||
connector_id, connector_data,
|
||||
token, created_at, last_used
|
||||
@@ -335,8 +340,8 @@ func (c *conn) ListRefreshTokens() ([]storage.RefreshToken, error) {
|
||||
rows, err := c.Query(`
|
||||
select
|
||||
id, client_id, scopes, nonce,
|
||||
claims_user_id, claims_username, claims_email, claims_email_verified,
|
||||
claims_groups,
|
||||
claims_user_id, claims_username, claims_preferred_username,
|
||||
claims_email, claims_email_verified, claims_groups,
|
||||
connector_id, connector_data,
|
||||
token, created_at, last_used
|
||||
from refresh_token;
|
||||
@@ -361,7 +366,8 @@ func (c *conn) ListRefreshTokens() ([]storage.RefreshToken, error) {
|
||||
func scanRefresh(s scanner) (r storage.RefreshToken, err error) {
|
||||
err = s.Scan(
|
||||
&r.ID, &r.ClientID, decoder(&r.Scopes), &r.Nonce,
|
||||
&r.Claims.UserID, &r.Claims.Username, &r.Claims.Email, &r.Claims.EmailVerified,
|
||||
&r.Claims.UserID, &r.Claims.Username, &r.Claims.PreferredUsername,
|
||||
&r.Claims.Email, &r.Claims.EmailVerified,
|
||||
decoder(&r.Claims.Groups),
|
||||
&r.ConnectorID, &r.ConnectorData,
|
||||
&r.Token, &r.CreatedAt, &r.LastUsed,
|
||||
@@ -410,7 +416,7 @@ func (c *conn) UpdateKeys(updater func(old storage.Keys) (storage.Keys, error))
|
||||
} else {
|
||||
_, err = tx.Exec(`
|
||||
update keys
|
||||
set
|
||||
set
|
||||
verification_keys = $1,
|
||||
signing_key = $2,
|
||||
signing_key_pub = $3,
|
||||
@@ -648,13 +654,13 @@ func scanPassword(s scanner) (p storage.Password, err error) {
|
||||
func (c *conn) CreateOfflineSessions(s storage.OfflineSessions) error {
|
||||
_, err := c.Exec(`
|
||||
insert into offline_session (
|
||||
user_id, conn_id, refresh
|
||||
user_id, conn_id, refresh, connector_data
|
||||
)
|
||||
values (
|
||||
$1, $2, $3
|
||||
$1, $2, $3, $4
|
||||
);
|
||||
`,
|
||||
s.UserID, s.ConnID, encoder(s.Refresh),
|
||||
s.UserID, s.ConnID, encoder(s.Refresh), s.ConnectorData,
|
||||
)
|
||||
if err != nil {
|
||||
if c.alreadyExistsCheck(err) {
|
||||
@@ -679,10 +685,11 @@ func (c *conn) UpdateOfflineSessions(userID string, connID string, updater func(
|
||||
_, err = tx.Exec(`
|
||||
update offline_session
|
||||
set
|
||||
refresh = $1
|
||||
where user_id = $2 AND conn_id = $3;
|
||||
refresh = $1,
|
||||
connector_data = $2
|
||||
where user_id = $3 AND conn_id = $4;
|
||||
`,
|
||||
encoder(newSession.Refresh), s.UserID, s.ConnID,
|
||||
encoder(newSession.Refresh), newSession.ConnectorData, s.UserID, s.ConnID,
|
||||
)
|
||||
if err != nil {
|
||||
return fmt.Errorf("update offline session: %v", err)
|
||||
@@ -698,7 +705,7 @@ func (c *conn) GetOfflineSessions(userID string, connID string) (storage.Offline
|
||||
func getOfflineSessions(q querier, userID string, connID string) (storage.OfflineSessions, error) {
|
||||
return scanOfflineSessions(q.QueryRow(`
|
||||
select
|
||||
user_id, conn_id, refresh
|
||||
user_id, conn_id, refresh, connector_data
|
||||
from offline_session
|
||||
where user_id = $1 AND conn_id = $2;
|
||||
`, userID, connID))
|
||||
@@ -706,7 +713,7 @@ func getOfflineSessions(q querier, userID string, connID string) (storage.Offlin
|
||||
|
||||
func scanOfflineSessions(s scanner) (o storage.OfflineSessions, err error) {
|
||||
err = s.Scan(
|
||||
&o.UserID, &o.ConnID, decoder(&o.Refresh),
|
||||
&o.UserID, &o.ConnID, decoder(&o.Refresh), &o.ConnectorData,
|
||||
)
|
||||
if err != nil {
|
||||
if err == sql.ErrNoRows {
|
||||
@@ -750,7 +757,7 @@ func (c *conn) UpdateConnector(id string, updater func(s storage.Connector) (sto
|
||||
}
|
||||
_, err = tx.Exec(`
|
||||
update connector
|
||||
set
|
||||
set
|
||||
type = $1,
|
||||
name = $2,
|
||||
resource_version = $3,
|
||||
|
@@ -90,18 +90,18 @@ var migrations = []migration{
|
||||
nonce text not null,
|
||||
state text not null,
|
||||
force_approval_prompt boolean not null,
|
||||
|
||||
|
||||
logged_in boolean not null,
|
||||
|
||||
|
||||
claims_user_id text not null,
|
||||
claims_username text not null,
|
||||
claims_email text not null,
|
||||
claims_email_verified boolean not null,
|
||||
claims_groups bytea not null, -- JSON array of strings
|
||||
|
||||
|
||||
connector_id text not null,
|
||||
connector_data bytea,
|
||||
|
||||
|
||||
expiry timestamptz not null
|
||||
);`,
|
||||
`
|
||||
@@ -111,16 +111,16 @@ var migrations = []migration{
|
||||
scopes bytea not null, -- JSON array of strings
|
||||
nonce text not null,
|
||||
redirect_uri text not null,
|
||||
|
||||
|
||||
claims_user_id text not null,
|
||||
claims_username text not null,
|
||||
claims_email text not null,
|
||||
claims_email_verified boolean not null,
|
||||
claims_groups bytea not null, -- JSON array of strings
|
||||
|
||||
|
||||
connector_id text not null,
|
||||
connector_data bytea,
|
||||
|
||||
|
||||
expiry timestamptz not null
|
||||
);`,
|
||||
`
|
||||
@@ -129,13 +129,13 @@ var migrations = []migration{
|
||||
client_id text not null,
|
||||
scopes bytea not null, -- JSON array of strings
|
||||
nonce text not null,
|
||||
|
||||
|
||||
claims_user_id text not null,
|
||||
claims_username text not null,
|
||||
claims_email text not null,
|
||||
claims_email_verified boolean not null,
|
||||
claims_groups bytea not null, -- JSON array of strings
|
||||
|
||||
|
||||
connector_id text not null,
|
||||
connector_data bytea
|
||||
);`,
|
||||
@@ -190,4 +190,23 @@ var migrations = []migration{
|
||||
);`,
|
||||
},
|
||||
},
|
||||
{
|
||||
stmts: []string{`
|
||||
alter table auth_code
|
||||
add column claims_preferred_username text not null default '';`,
|
||||
`
|
||||
alter table auth_request
|
||||
add column claims_preferred_username text not null default '';`,
|
||||
`
|
||||
alter table refresh_token
|
||||
add column claims_preferred_username text not null default '';`,
|
||||
},
|
||||
},
|
||||
{
|
||||
stmts: []string{`
|
||||
alter table offline_session
|
||||
add column connector_data bytea;
|
||||
`,
|
||||
},
|
||||
},
|
||||
}
|
||||
|
@@ -4,6 +4,7 @@ package sql
|
||||
|
||||
import (
|
||||
"os"
|
||||
"strconv"
|
||||
"testing"
|
||||
)
|
||||
|
||||
@@ -12,12 +13,24 @@ func TestPostgresTunables(t *testing.T) {
|
||||
if host == "" {
|
||||
t.Skipf("test environment variable %q not set, skipping", testPostgresEnv)
|
||||
}
|
||||
|
||||
port := uint64(5432)
|
||||
if rawPort := os.Getenv("DEX_POSTGRES_PORT"); rawPort != "" {
|
||||
var err error
|
||||
|
||||
port, err = strconv.ParseUint(rawPort, 10, 32)
|
||||
if err != nil {
|
||||
t.Fatalf("invalid postgres port %q: %s", rawPort, err)
|
||||
}
|
||||
}
|
||||
|
||||
baseCfg := &Postgres{
|
||||
NetworkDB: NetworkDB{
|
||||
Database: getenv("DEX_POSTGRES_DATABASE", "postgres"),
|
||||
User: getenv("DEX_POSTGRES_USER", "postgres"),
|
||||
Password: getenv("DEX_POSTGRES_PASSWORD", "postgres"),
|
||||
Host: host,
|
||||
Port: uint16(port),
|
||||
},
|
||||
SSL: SSL{
|
||||
Mode: pgSSLDisable, // Postgres container doesn't support SSL.
|
||||
|
@@ -137,10 +137,11 @@ type Client struct {
|
||||
|
||||
// Claims represents the ID Token claims supported by the server.
|
||||
type Claims struct {
|
||||
UserID string
|
||||
Username string
|
||||
Email string
|
||||
EmailVerified bool
|
||||
UserID string
|
||||
Username string
|
||||
PreferredUsername string
|
||||
Email string
|
||||
EmailVerified bool
|
||||
|
||||
Groups []string
|
||||
}
|
||||
@@ -272,6 +273,9 @@ type OfflineSessions struct {
|
||||
// Refresh is a hash table of refresh token reference objects
|
||||
// indexed by the ClientID of the refresh token.
|
||||
Refresh map[string]*RefreshTokenRef
|
||||
|
||||
// Authentication data provided by an upstream source.
|
||||
ConnectorData []byte
|
||||
}
|
||||
|
||||
// Password is an email to password mapping managed by the storage.
|
||||
|
Reference in New Issue
Block a user