Merge branch 'master' into conformance_tests_improvements

This commit is contained in:
Nándor István Krácser
2019-12-20 09:56:59 +01:00
committed by GitHub
241 changed files with 46456 additions and 10473 deletions

View File

@@ -511,15 +511,15 @@ func testPasswordCRUD(t *testing.T, s storage.Storage) {
_, err = s.GetPassword(password1.Email)
mustBeErrNotFound(t, "password", err)
}
func testOfflineSessionCRUD(t *testing.T, s storage.Storage) {
userID1 := storage.NewID()
session1 := storage.OfflineSessions{
UserID: userID1,
ConnID: "Conn1",
Refresh: make(map[string]*storage.RefreshTokenRef),
UserID: userID1,
ConnID: "Conn1",
Refresh: make(map[string]*storage.RefreshTokenRef),
ConnectorData: []byte(`{"some":"data"}`),
}
// Creating an OfflineSession with an empty Refresh list to ensure that
@@ -534,9 +534,10 @@ func testOfflineSessionCRUD(t *testing.T, s storage.Storage) {
userID2 := storage.NewID()
session2 := storage.OfflineSessions{
UserID: userID2,
ConnID: "Conn2",
Refresh: make(map[string]*storage.RefreshTokenRef),
UserID: userID2,
ConnID: "Conn2",
Refresh: make(map[string]*storage.RefreshTokenRef),
ConnectorData: []byte(`{"some":"data"}`),
}
if err := s.CreateOfflineSessions(session2); err != nil {

View File

@@ -156,7 +156,7 @@ func (c *conn) UpdateRefreshToken(id string, updater func(old storage.RefreshTok
return c.txnUpdate(ctx, keyID(refreshTokenPrefix, id), func(currentValue []byte) ([]byte, error) {
var current RefreshToken
if len(currentValue) > 0 {
if err := json.Unmarshal([]byte(currentValue), &current); err != nil {
if err := json.Unmarshal(currentValue, &current); err != nil {
return nil, err
}
}

View File

@@ -148,30 +148,33 @@ func fromStorageRefreshToken(r storage.RefreshToken) RefreshToken {
// Claims is a mirrored struct from storage with JSON struct tags.
type Claims struct {
UserID string `json:"userID"`
Username string `json:"username"`
Email string `json:"email"`
EmailVerified bool `json:"emailVerified"`
Groups []string `json:"groups,omitempty"`
UserID string `json:"userID"`
Username string `json:"username"`
PreferredUsername string `json:"preferredUsername"`
Email string `json:"email"`
EmailVerified bool `json:"emailVerified"`
Groups []string `json:"groups,omitempty"`
}
func fromStorageClaims(i storage.Claims) Claims {
return Claims{
UserID: i.UserID,
Username: i.Username,
Email: i.Email,
EmailVerified: i.EmailVerified,
Groups: i.Groups,
UserID: i.UserID,
Username: i.Username,
PreferredUsername: i.PreferredUsername,
Email: i.Email,
EmailVerified: i.EmailVerified,
Groups: i.Groups,
}
}
func toStorageClaims(i Claims) storage.Claims {
return storage.Claims{
UserID: i.UserID,
Username: i.Username,
Email: i.Email,
EmailVerified: i.EmailVerified,
Groups: i.Groups,
UserID: i.UserID,
Username: i.Username,
PreferredUsername: i.PreferredUsername,
Email: i.Email,
EmailVerified: i.EmailVerified,
Groups: i.Groups,
}
}
@@ -185,24 +188,27 @@ type Keys struct {
// OfflineSessions is a mirrored struct from storage with JSON struct tags
type OfflineSessions struct {
UserID string `json:"user_id,omitempty"`
ConnID string `json:"conn_id,omitempty"`
Refresh map[string]*storage.RefreshTokenRef `json:"refresh,omitempty"`
UserID string `json:"user_id,omitempty"`
ConnID string `json:"conn_id,omitempty"`
Refresh map[string]*storage.RefreshTokenRef `json:"refresh,omitempty"`
ConnectorData []byte `json:"connectorData,omitempty"`
}
func fromStorageOfflineSessions(o storage.OfflineSessions) OfflineSessions {
return OfflineSessions{
UserID: o.UserID,
ConnID: o.ConnID,
Refresh: o.Refresh,
UserID: o.UserID,
ConnID: o.ConnID,
Refresh: o.Refresh,
ConnectorData: o.ConnectorData,
}
}
func toStorageOfflineSessions(o OfflineSessions) storage.OfflineSessions {
s := storage.OfflineSessions{
UserID: o.UserID,
ConnID: o.ConnID,
Refresh: o.Refresh,
UserID: o.UserID,
ConnID: o.ConnID,
Refresh: o.Refresh,
ConnectorData: o.ConnectorData,
}
if s.Refresh == nil {
// Server code assumes this will be non-nil.

View File

@@ -55,14 +55,14 @@ type client struct {
}
// idToName maps an arbitrary ID, such as an email or client ID to a Kubernetes object name.
func (c *client) idToName(s string) string {
return idToName(s, c.hash)
func (cli *client) idToName(s string) string {
return idToName(s, cli.hash)
}
// offlineTokenName maps two arbitrary IDs, to a single Kubernetes object name.
// This is used when more than one field is used to uniquely identify the object.
func (c *client) offlineTokenName(userID string, connID string) string {
return offlineTokenName(userID, connID, c.hash)
func (cli *client) offlineTokenName(userID string, connID string) string {
return offlineTokenName(userID, connID, cli.hash)
}
// Kubernetes names must match the regexp '[a-z0-9]([-a-z0-9]*[a-z0-9])?'.
@@ -79,7 +79,7 @@ func offlineTokenName(userID string, connID string, h func() hash.Hash) string {
return strings.TrimRight(encoding.EncodeToString(hash.Sum(nil)), "=")
}
func (c *client) urlFor(apiVersion, namespace, resource, name string) string {
func (cli *client) urlFor(apiVersion, namespace, resource, name string) string {
basePath := "apis/"
if apiVersion == "v1" {
basePath = "api/"
@@ -91,10 +91,10 @@ func (c *client) urlFor(apiVersion, namespace, resource, name string) string {
} else {
p = path.Join(basePath, apiVersion, resource, name)
}
if strings.HasSuffix(c.baseURL, "/") {
return c.baseURL + p
if strings.HasSuffix(cli.baseURL, "/") {
return cli.baseURL + p
}
return c.baseURL + "/" + p
return cli.baseURL + "/" + p
}
// Define an error interface so we can get at the underlying status code if it's
@@ -156,13 +156,13 @@ func closeResp(r *http.Response) {
r.Body.Close()
}
func (c *client) get(resource, name string, v interface{}) error {
return c.getResource(c.apiVersion, c.namespace, resource, name, v)
func (cli *client) get(resource, name string, v interface{}) error {
return cli.getResource(cli.apiVersion, cli.namespace, resource, name, v)
}
func (c *client) getResource(apiVersion, namespace, resource, name string, v interface{}) error {
url := c.urlFor(apiVersion, namespace, resource, name)
resp, err := c.client.Get(url)
func (cli *client) getResource(apiVersion, namespace, resource, name string, v interface{}) error {
url := cli.urlFor(apiVersion, namespace, resource, name)
resp, err := cli.client.Get(url)
if err != nil {
return err
}
@@ -173,22 +173,22 @@ func (c *client) getResource(apiVersion, namespace, resource, name string, v int
return json.NewDecoder(resp.Body).Decode(v)
}
func (c *client) list(resource string, v interface{}) error {
return c.get(resource, "", v)
func (cli *client) list(resource string, v interface{}) error {
return cli.get(resource, "", v)
}
func (c *client) post(resource string, v interface{}) error {
return c.postResource(c.apiVersion, c.namespace, resource, v)
func (cli *client) post(resource string, v interface{}) error {
return cli.postResource(cli.apiVersion, cli.namespace, resource, v)
}
func (c *client) postResource(apiVersion, namespace, resource string, v interface{}) error {
func (cli *client) postResource(apiVersion, namespace, resource string, v interface{}) error {
body, err := json.Marshal(v)
if err != nil {
return fmt.Errorf("marshal object: %v", err)
}
url := c.urlFor(apiVersion, namespace, resource, "")
resp, err := c.client.Post(url, "application/json", bytes.NewReader(body))
url := cli.urlFor(apiVersion, namespace, resource, "")
resp, err := cli.client.Post(url, "application/json", bytes.NewReader(body))
if err != nil {
return err
}
@@ -196,13 +196,13 @@ func (c *client) postResource(apiVersion, namespace, resource string, v interfac
return checkHTTPErr(resp, http.StatusCreated)
}
func (c *client) delete(resource, name string) error {
url := c.urlFor(c.apiVersion, c.namespace, resource, name)
func (cli *client) delete(resource, name string) error {
url := cli.urlFor(cli.apiVersion, cli.namespace, resource, name)
req, err := http.NewRequest("DELETE", url, nil)
if err != nil {
return fmt.Errorf("create delete request: %v", err)
}
resp, err := c.client.Do(req)
resp, err := cli.client.Do(req)
if err != nil {
return fmt.Errorf("delete request: %v", err)
}
@@ -210,7 +210,7 @@ func (c *client) delete(resource, name string) error {
return checkHTTPErr(resp, http.StatusOK)
}
func (c *client) deleteAll(resource string) error {
func (cli *client) deleteAll(resource string) error {
var list struct {
k8sapi.TypeMeta `json:",inline"`
k8sapi.ListMeta `json:"metadata,omitempty"`
@@ -219,24 +219,24 @@ func (c *client) deleteAll(resource string) error {
k8sapi.ObjectMeta `json:"metadata,omitempty"`
} `json:"items"`
}
if err := c.list(resource, &list); err != nil {
if err := cli.list(resource, &list); err != nil {
return err
}
for _, item := range list.Items {
if err := c.delete(resource, item.Name); err != nil {
if err := cli.delete(resource, item.Name); err != nil {
return err
}
}
return nil
}
func (c *client) put(resource, name string, v interface{}) error {
func (cli *client) put(resource, name string, v interface{}) error {
body, err := json.Marshal(v)
if err != nil {
return fmt.Errorf("marshal object: %v", err)
}
url := c.urlFor(c.apiVersion, c.namespace, resource, name)
url := cli.urlFor(cli.apiVersion, cli.namespace, resource, name)
req, err := http.NewRequest("PUT", url, bytes.NewReader(body))
if err != nil {
return fmt.Errorf("create patch request: %v", err)
@@ -244,7 +244,7 @@ func (c *client) put(resource, name string, v interface{}) error {
req.Header.Set("Content-Length", strconv.Itoa(len(body)))
resp, err := c.client.Do(req)
resp, err := cli.client.Do(req)
if err != nil {
return fmt.Errorf("patch request: %v", err)
}

View File

@@ -43,7 +43,7 @@ type CustomResourceDefinitionNames struct {
ListKind string `json:"listKind,omitempty" protobuf:"bytes,5,opt,name=listKind"`
}
// ResourceScope is an enum defining the different scopes availabe to a custom resource
// ResourceScope is an enum defining the different scopes available to a custom resource
type ResourceScope string
const (

View File

@@ -210,30 +210,33 @@ func toStorageClient(c Client) storage.Client {
// Claims is a mirrored struct from storage with JSON struct tags.
type Claims struct {
UserID string `json:"userID"`
Username string `json:"username"`
Email string `json:"email"`
EmailVerified bool `json:"emailVerified"`
Groups []string `json:"groups,omitempty"`
UserID string `json:"userID"`
Username string `json:"username"`
PreferredUsername string `json:"preferredUsername"`
Email string `json:"email"`
EmailVerified bool `json:"emailVerified"`
Groups []string `json:"groups,omitempty"`
}
func fromStorageClaims(i storage.Claims) Claims {
return Claims{
UserID: i.UserID,
Username: i.Username,
Email: i.Email,
EmailVerified: i.EmailVerified,
Groups: i.Groups,
UserID: i.UserID,
Username: i.Username,
PreferredUsername: i.PreferredUsername,
Email: i.Email,
EmailVerified: i.EmailVerified,
Groups: i.Groups,
}
}
func toStorageClaims(i Claims) storage.Claims {
return storage.Claims{
UserID: i.UserID,
Username: i.Username,
Email: i.Email,
EmailVerified: i.EmailVerified,
Groups: i.Groups,
UserID: i.UserID,
Username: i.Username,
PreferredUsername: i.PreferredUsername,
Email: i.Email,
EmailVerified: i.EmailVerified,
Groups: i.Groups,
}
}
@@ -549,9 +552,10 @@ type OfflineSessions struct {
k8sapi.TypeMeta `json:",inline"`
k8sapi.ObjectMeta `json:"metadata,omitempty"`
UserID string `json:"userID,omitempty"`
ConnID string `json:"connID,omitempty"`
Refresh map[string]*storage.RefreshTokenRef `json:"refresh,omitempty"`
UserID string `json:"userID,omitempty"`
ConnID string `json:"connID,omitempty"`
Refresh map[string]*storage.RefreshTokenRef `json:"refresh,omitempty"`
ConnectorData []byte `json:"connectorData,omitempty"`
}
func (cli *client) fromStorageOfflineSessions(o storage.OfflineSessions) OfflineSessions {
@@ -564,17 +568,19 @@ func (cli *client) fromStorageOfflineSessions(o storage.OfflineSessions) Offline
Name: cli.offlineTokenName(o.UserID, o.ConnID),
Namespace: cli.namespace,
},
UserID: o.UserID,
ConnID: o.ConnID,
Refresh: o.Refresh,
UserID: o.UserID,
ConnID: o.ConnID,
Refresh: o.Refresh,
ConnectorData: o.ConnectorData,
}
}
func toStorageOfflineSessions(o OfflineSessions) storage.OfflineSessions {
s := storage.OfflineSessions{
UserID: o.UserID,
ConnID: o.ConnID,
Refresh: o.Refresh,
UserID: o.UserID,
ConnID: o.ConnID,
Refresh: o.Refresh,
ConnectorData: o.ConnectorData,
}
if s.Refresh == nil {
// Server code assumes this will be non-nil.

View File

@@ -29,6 +29,7 @@ const (
// MySQL error codes
mysqlErrDupEntry = 1062
mysqlErrDupEntryWithKeyName = 1586
mysqlErrUnknownSysVar = 1193
)
// SQLite3 options for creating an SQL db.
@@ -307,6 +308,26 @@ func (s *MySQL) open(logger log.Logger) (*conn, error) {
return nil, err
}
err = db.Ping()
if err != nil {
if mysqlErr, ok := err.(*mysql.MySQLError); ok && mysqlErr.Number == mysqlErrUnknownSysVar {
logger.Info("reconnecting with MySQL pre-5.7.20 compatibility mode")
// MySQL 5.7.20 introduced transaction_isolation and deprecated tx_isolation.
// MySQL 8.0 doesn't have tx_isolation at all.
// https://dev.mysql.com/doc/refman/5.7/en/server-system-variables.html#sysvar_transaction_isolation
delete(cfg.Params, "transaction_isolation")
cfg.Params["tx_isolation"] = "'SERIALIZABLE'"
db, err = sql.Open("mysql", cfg.FormatDSN())
if err != nil {
return nil, err
}
} else {
return nil, err
}
}
errCheck := func(err error) bool {
sqlErr, ok := err.(*mysql.MySQLError)
if !ok {

View File

@@ -4,6 +4,7 @@ import (
"fmt"
"os"
"runtime"
"strconv"
"testing"
"time"
@@ -220,12 +221,24 @@ func TestPostgres(t *testing.T) {
if host == "" {
t.Skipf("test environment variable %q not set, skipping", testPostgresEnv)
}
port := uint64(5432)
if rawPort := os.Getenv("DEX_POSTGRES_PORT"); rawPort != "" {
var err error
port, err = strconv.ParseUint(rawPort, 10, 32)
if err != nil {
t.Fatalf("invalid postgres port %q: %s", rawPort, err)
}
}
p := &Postgres{
NetworkDB: NetworkDB{
Database: getenv("DEX_POSTGRES_DATABASE", "postgres"),
User: getenv("DEX_POSTGRES_USER", "postgres"),
Password: getenv("DEX_POSTGRES_PASSWORD", "postgres"),
Host: host,
Port: uint16(port),
ConnectionTimeout: 5,
},
SSL: SSL{
@@ -242,12 +255,24 @@ func TestMySQL(t *testing.T) {
if host == "" {
t.Skipf("test environment variable %q not set, skipping", testMySQLEnv)
}
port := uint64(3306)
if rawPort := os.Getenv("DEX_MYSQL_PORT"); rawPort != "" {
var err error
port, err = strconv.ParseUint(rawPort, 10, 32)
if err != nil {
t.Fatalf("invalid mysql port %q: %s", rawPort, err)
}
}
s := &MySQL{
NetworkDB: NetworkDB{
Database: getenv("DEX_MYSQL_DATABASE", "mysql"),
User: getenv("DEX_MYSQL_USER", "mysql"),
Password: getenv("DEX_MYSQL_PASSWORD", ""),
Host: host,
Port: uint16(port),
ConnectionTimeout: 5,
},
SSL: SSL{

View File

@@ -108,19 +108,19 @@ func (c *conn) CreateAuthRequest(a storage.AuthRequest) error {
insert into auth_request (
id, client_id, response_types, scopes, redirect_uri, nonce, state,
force_approval_prompt, logged_in,
claims_user_id, claims_username, claims_email, claims_email_verified,
claims_groups,
claims_user_id, claims_username, claims_preferred_username,
claims_email, claims_email_verified, claims_groups,
connector_id, connector_data,
expiry
)
values (
$1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12, $13, $14, $15, $16, $17
$1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12, $13, $14, $15, $16, $17, $18
);
`,
a.ID, a.ClientID, encoder(a.ResponseTypes), encoder(a.Scopes), a.RedirectURI, a.Nonce, a.State,
a.ForceApprovalPrompt, a.LoggedIn,
a.Claims.UserID, a.Claims.Username, a.Claims.Email, a.Claims.EmailVerified,
encoder(a.Claims.Groups),
a.Claims.UserID, a.Claims.Username, a.Claims.PreferredUsername,
a.Claims.Email, a.Claims.EmailVerified, encoder(a.Claims.Groups),
a.ConnectorID, a.ConnectorData,
a.Expiry,
)
@@ -149,16 +149,17 @@ func (c *conn) UpdateAuthRequest(id string, updater func(a storage.AuthRequest)
set
client_id = $1, response_types = $2, scopes = $3, redirect_uri = $4,
nonce = $5, state = $6, force_approval_prompt = $7, logged_in = $8,
claims_user_id = $9, claims_username = $10, claims_email = $11,
claims_email_verified = $12,
claims_groups = $13,
connector_id = $14, connector_data = $15,
expiry = $16
where id = $17;
claims_user_id = $9, claims_username = $10, claims_preferred_username = $11,
claims_email = $12, claims_email_verified = $13,
claims_groups = $14,
connector_id = $15, connector_data = $16,
expiry = $17
where id = $18;
`,
a.ClientID, encoder(a.ResponseTypes), encoder(a.Scopes), a.RedirectURI, a.Nonce, a.State,
a.ForceApprovalPrompt, a.LoggedIn,
a.Claims.UserID, a.Claims.Username, a.Claims.Email, a.Claims.EmailVerified,
a.Claims.UserID, a.Claims.Username, a.Claims.PreferredUsername,
a.Claims.Email, a.Claims.EmailVerified,
encoder(a.Claims.Groups),
a.ConnectorID, a.ConnectorData,
a.Expiry, r.ID,
@@ -168,7 +169,6 @@ func (c *conn) UpdateAuthRequest(id string, updater func(a storage.AuthRequest)
}
return nil
})
}
func (c *conn) GetAuthRequest(id string) (storage.AuthRequest, error) {
@@ -177,17 +177,18 @@ func (c *conn) GetAuthRequest(id string) (storage.AuthRequest, error) {
func getAuthRequest(q querier, id string) (a storage.AuthRequest, err error) {
err = q.QueryRow(`
select
select
id, client_id, response_types, scopes, redirect_uri, nonce, state,
force_approval_prompt, logged_in,
claims_user_id, claims_username, claims_email, claims_email_verified,
claims_groups,
claims_user_id, claims_username, claims_preferred_username,
claims_email, claims_email_verified, claims_groups,
connector_id, connector_data, expiry
from auth_request where id = $1;
`, id).Scan(
&a.ID, &a.ClientID, decoder(&a.ResponseTypes), decoder(&a.Scopes), &a.RedirectURI, &a.Nonce, &a.State,
&a.ForceApprovalPrompt, &a.LoggedIn,
&a.Claims.UserID, &a.Claims.Username, &a.Claims.Email, &a.Claims.EmailVerified,
&a.Claims.UserID, &a.Claims.Username, &a.Claims.PreferredUsername,
&a.Claims.Email, &a.Claims.EmailVerified,
decoder(&a.Claims.Groups),
&a.ConnectorID, &a.ConnectorData, &a.Expiry,
)
@@ -204,16 +205,16 @@ func (c *conn) CreateAuthCode(a storage.AuthCode) error {
_, err := c.Exec(`
insert into auth_code (
id, client_id, scopes, nonce, redirect_uri,
claims_user_id, claims_username,
claims_user_id, claims_username, claims_preferred_username,
claims_email, claims_email_verified, claims_groups,
connector_id, connector_data,
expiry
)
values ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12, $13);
values ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12, $13, $14);
`,
a.ID, a.ClientID, encoder(a.Scopes), a.Nonce, a.RedirectURI, a.Claims.UserID,
a.Claims.Username, a.Claims.Email, a.Claims.EmailVerified, encoder(a.Claims.Groups),
a.ConnectorID, a.ConnectorData, a.Expiry,
a.Claims.Username, a.Claims.PreferredUsername, a.Claims.Email, a.Claims.EmailVerified,
encoder(a.Claims.Groups), a.ConnectorID, a.ConnectorData, a.Expiry,
)
if err != nil {
@@ -229,15 +230,15 @@ func (c *conn) GetAuthCode(id string) (a storage.AuthCode, err error) {
err = c.QueryRow(`
select
id, client_id, scopes, nonce, redirect_uri,
claims_user_id, claims_username,
claims_user_id, claims_username, claims_preferred_username,
claims_email, claims_email_verified, claims_groups,
connector_id, connector_data,
expiry
from auth_code where id = $1;
`, id).Scan(
&a.ID, &a.ClientID, decoder(&a.Scopes), &a.Nonce, &a.RedirectURI, &a.Claims.UserID,
&a.Claims.Username, &a.Claims.Email, &a.Claims.EmailVerified, decoder(&a.Claims.Groups),
&a.ConnectorID, &a.ConnectorData, &a.Expiry,
&a.Claims.Username, &a.Claims.PreferredUsername, &a.Claims.Email, &a.Claims.EmailVerified,
decoder(&a.Claims.Groups), &a.ConnectorID, &a.ConnectorData, &a.Expiry,
)
if err != nil {
if err == sql.ErrNoRows {
@@ -252,15 +253,16 @@ func (c *conn) CreateRefresh(r storage.RefreshToken) error {
_, err := c.Exec(`
insert into refresh_token (
id, client_id, scopes, nonce,
claims_user_id, claims_username, claims_email, claims_email_verified,
claims_groups,
claims_user_id, claims_username, claims_preferred_username,
claims_email, claims_email_verified, claims_groups,
connector_id, connector_data,
token, created_at, last_used
)
values ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12, $13, $14);
values ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12, $13, $14, $15);
`,
r.ID, r.ClientID, encoder(r.Scopes), r.Nonce,
r.Claims.UserID, r.Claims.Username, r.Claims.Email, r.Claims.EmailVerified,
r.Claims.UserID, r.Claims.Username, r.Claims.PreferredUsername,
r.Claims.Email, r.Claims.EmailVerified,
encoder(r.Claims.Groups),
r.ConnectorID, r.ConnectorData,
r.Token, r.CreatedAt, r.LastUsed,
@@ -291,19 +293,21 @@ func (c *conn) UpdateRefreshToken(id string, updater func(old storage.RefreshTok
nonce = $3,
claims_user_id = $4,
claims_username = $5,
claims_email = $6,
claims_email_verified = $7,
claims_groups = $8,
connector_id = $9,
connector_data = $10,
token = $11,
created_at = $12,
last_used = $13
claims_preferred_username = $6,
claims_email = $7,
claims_email_verified = $8,
claims_groups = $9,
connector_id = $10,
connector_data = $11,
token = $12,
created_at = $13,
last_used = $14
where
id = $14
id = $15
`,
r.ClientID, encoder(r.Scopes), r.Nonce,
r.Claims.UserID, r.Claims.Username, r.Claims.Email, r.Claims.EmailVerified,
r.Claims.UserID, r.Claims.Username, r.Claims.PreferredUsername,
r.Claims.Email, r.Claims.EmailVerified,
encoder(r.Claims.Groups),
r.ConnectorID, r.ConnectorData,
r.Token, r.CreatedAt, r.LastUsed, id,
@@ -323,7 +327,8 @@ func getRefresh(q querier, id string) (storage.RefreshToken, error) {
return scanRefresh(q.QueryRow(`
select
id, client_id, scopes, nonce,
claims_user_id, claims_username, claims_email, claims_email_verified,
claims_user_id, claims_username, claims_preferred_username,
claims_email, claims_email_verified,
claims_groups,
connector_id, connector_data,
token, created_at, last_used
@@ -335,8 +340,8 @@ func (c *conn) ListRefreshTokens() ([]storage.RefreshToken, error) {
rows, err := c.Query(`
select
id, client_id, scopes, nonce,
claims_user_id, claims_username, claims_email, claims_email_verified,
claims_groups,
claims_user_id, claims_username, claims_preferred_username,
claims_email, claims_email_verified, claims_groups,
connector_id, connector_data,
token, created_at, last_used
from refresh_token;
@@ -361,7 +366,8 @@ func (c *conn) ListRefreshTokens() ([]storage.RefreshToken, error) {
func scanRefresh(s scanner) (r storage.RefreshToken, err error) {
err = s.Scan(
&r.ID, &r.ClientID, decoder(&r.Scopes), &r.Nonce,
&r.Claims.UserID, &r.Claims.Username, &r.Claims.Email, &r.Claims.EmailVerified,
&r.Claims.UserID, &r.Claims.Username, &r.Claims.PreferredUsername,
&r.Claims.Email, &r.Claims.EmailVerified,
decoder(&r.Claims.Groups),
&r.ConnectorID, &r.ConnectorData,
&r.Token, &r.CreatedAt, &r.LastUsed,
@@ -410,7 +416,7 @@ func (c *conn) UpdateKeys(updater func(old storage.Keys) (storage.Keys, error))
} else {
_, err = tx.Exec(`
update keys
set
set
verification_keys = $1,
signing_key = $2,
signing_key_pub = $3,
@@ -648,13 +654,13 @@ func scanPassword(s scanner) (p storage.Password, err error) {
func (c *conn) CreateOfflineSessions(s storage.OfflineSessions) error {
_, err := c.Exec(`
insert into offline_session (
user_id, conn_id, refresh
user_id, conn_id, refresh, connector_data
)
values (
$1, $2, $3
$1, $2, $3, $4
);
`,
s.UserID, s.ConnID, encoder(s.Refresh),
s.UserID, s.ConnID, encoder(s.Refresh), s.ConnectorData,
)
if err != nil {
if c.alreadyExistsCheck(err) {
@@ -679,10 +685,11 @@ func (c *conn) UpdateOfflineSessions(userID string, connID string, updater func(
_, err = tx.Exec(`
update offline_session
set
refresh = $1
where user_id = $2 AND conn_id = $3;
refresh = $1,
connector_data = $2
where user_id = $3 AND conn_id = $4;
`,
encoder(newSession.Refresh), s.UserID, s.ConnID,
encoder(newSession.Refresh), newSession.ConnectorData, s.UserID, s.ConnID,
)
if err != nil {
return fmt.Errorf("update offline session: %v", err)
@@ -698,7 +705,7 @@ func (c *conn) GetOfflineSessions(userID string, connID string) (storage.Offline
func getOfflineSessions(q querier, userID string, connID string) (storage.OfflineSessions, error) {
return scanOfflineSessions(q.QueryRow(`
select
user_id, conn_id, refresh
user_id, conn_id, refresh, connector_data
from offline_session
where user_id = $1 AND conn_id = $2;
`, userID, connID))
@@ -706,7 +713,7 @@ func getOfflineSessions(q querier, userID string, connID string) (storage.Offlin
func scanOfflineSessions(s scanner) (o storage.OfflineSessions, err error) {
err = s.Scan(
&o.UserID, &o.ConnID, decoder(&o.Refresh),
&o.UserID, &o.ConnID, decoder(&o.Refresh), &o.ConnectorData,
)
if err != nil {
if err == sql.ErrNoRows {
@@ -750,7 +757,7 @@ func (c *conn) UpdateConnector(id string, updater func(s storage.Connector) (sto
}
_, err = tx.Exec(`
update connector
set
set
type = $1,
name = $2,
resource_version = $3,

View File

@@ -90,18 +90,18 @@ var migrations = []migration{
nonce text not null,
state text not null,
force_approval_prompt boolean not null,
logged_in boolean not null,
claims_user_id text not null,
claims_username text not null,
claims_email text not null,
claims_email_verified boolean not null,
claims_groups bytea not null, -- JSON array of strings
connector_id text not null,
connector_data bytea,
expiry timestamptz not null
);`,
`
@@ -111,16 +111,16 @@ var migrations = []migration{
scopes bytea not null, -- JSON array of strings
nonce text not null,
redirect_uri text not null,
claims_user_id text not null,
claims_username text not null,
claims_email text not null,
claims_email_verified boolean not null,
claims_groups bytea not null, -- JSON array of strings
connector_id text not null,
connector_data bytea,
expiry timestamptz not null
);`,
`
@@ -129,13 +129,13 @@ var migrations = []migration{
client_id text not null,
scopes bytea not null, -- JSON array of strings
nonce text not null,
claims_user_id text not null,
claims_username text not null,
claims_email text not null,
claims_email_verified boolean not null,
claims_groups bytea not null, -- JSON array of strings
connector_id text not null,
connector_data bytea
);`,
@@ -190,4 +190,23 @@ var migrations = []migration{
);`,
},
},
{
stmts: []string{`
alter table auth_code
add column claims_preferred_username text not null default '';`,
`
alter table auth_request
add column claims_preferred_username text not null default '';`,
`
alter table refresh_token
add column claims_preferred_username text not null default '';`,
},
},
{
stmts: []string{`
alter table offline_session
add column connector_data bytea;
`,
},
},
}

View File

@@ -4,6 +4,7 @@ package sql
import (
"os"
"strconv"
"testing"
)
@@ -12,12 +13,24 @@ func TestPostgresTunables(t *testing.T) {
if host == "" {
t.Skipf("test environment variable %q not set, skipping", testPostgresEnv)
}
port := uint64(5432)
if rawPort := os.Getenv("DEX_POSTGRES_PORT"); rawPort != "" {
var err error
port, err = strconv.ParseUint(rawPort, 10, 32)
if err != nil {
t.Fatalf("invalid postgres port %q: %s", rawPort, err)
}
}
baseCfg := &Postgres{
NetworkDB: NetworkDB{
Database: getenv("DEX_POSTGRES_DATABASE", "postgres"),
User: getenv("DEX_POSTGRES_USER", "postgres"),
Password: getenv("DEX_POSTGRES_PASSWORD", "postgres"),
Host: host,
Port: uint16(port),
},
SSL: SSL{
Mode: pgSSLDisable, // Postgres container doesn't support SSL.

View File

@@ -137,10 +137,11 @@ type Client struct {
// Claims represents the ID Token claims supported by the server.
type Claims struct {
UserID string
Username string
Email string
EmailVerified bool
UserID string
Username string
PreferredUsername string
Email string
EmailVerified bool
Groups []string
}
@@ -272,6 +273,9 @@ type OfflineSessions struct {
// Refresh is a hash table of refresh token reference objects
// indexed by the ClientID of the refresh token.
Refresh map[string]*RefreshTokenRef
// Authentication data provided by an upstream source.
ConnectorData []byte
}
// Password is an email to password mapping managed by the storage.