storage/sql: add a SQL storage implementation
This change adds support for SQLite3, and Postgres.
This commit is contained in:
		
							
								
								
									
										113
									
								
								storage/sql/config.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										113
									
								
								storage/sql/config.go
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,113 @@ | ||||
| package sql | ||||
|  | ||||
| import ( | ||||
| 	"database/sql" | ||||
| 	"fmt" | ||||
| 	"net/url" | ||||
| 	"strconv" | ||||
|  | ||||
| 	"github.com/coreos/dex/storage" | ||||
| ) | ||||
|  | ||||
| // SQLite3 options for creating an SQL db. | ||||
| type SQLite3 struct { | ||||
| 	// File to | ||||
| 	File string `yaml:"file"` | ||||
| } | ||||
|  | ||||
| // Open creates a new storage implementation backed by SQLite3 | ||||
| func (s *SQLite3) Open() (storage.Storage, error) { | ||||
| 	return s.open() | ||||
| } | ||||
|  | ||||
| func (s *SQLite3) open() (*conn, error) { | ||||
| 	db, err := sql.Open("sqlite3", s.File) | ||||
| 	if err != nil { | ||||
| 		return nil, err | ||||
| 	} | ||||
| 	if s.File == ":memory:" { | ||||
| 		// sqlite3 uses file locks to coordinate concurrent access. In memory | ||||
| 		// doesn't support this, so limit the number of connections to 1. | ||||
| 		db.SetMaxOpenConns(1) | ||||
| 	} | ||||
| 	c := &conn{db, flavorSQLite3} | ||||
| 	if _, err := c.migrate(); err != nil { | ||||
| 		return nil, fmt.Errorf("failed to perform migrations: %v", err) | ||||
| 	} | ||||
| 	return c, nil | ||||
| } | ||||
|  | ||||
| const ( | ||||
| 	sslDisable    = "disable" | ||||
| 	sslRequire    = "require" | ||||
| 	sslVerifyCA   = "verify-ca" | ||||
| 	sslVerifyFull = "verify-full" | ||||
| ) | ||||
|  | ||||
| // PostgresSSL represents SSL options for Postgres databases. | ||||
| type PostgresSSL struct { | ||||
| 	Mode   string | ||||
| 	CAFile string | ||||
| 	// Files for client auth. | ||||
| 	KeyFile  string | ||||
| 	CertFile string | ||||
| } | ||||
|  | ||||
| // Postgres options for creating an SQL db. | ||||
| type Postgres struct { | ||||
| 	Database string | ||||
| 	User     string | ||||
| 	Password string | ||||
| 	Host     string | ||||
|  | ||||
| 	SSL PostgresSSL `json:"ssl" yaml:"ssl"` | ||||
|  | ||||
| 	ConnectionTimeout int // Seconds | ||||
| } | ||||
|  | ||||
| // Open creates a new storage implementation backed by Postgres. | ||||
| func (p *Postgres) Open() (storage.Storage, error) { | ||||
| 	return p.open() | ||||
| } | ||||
|  | ||||
| func (p *Postgres) open() (*conn, error) { | ||||
| 	v := url.Values{} | ||||
| 	set := func(key, val string) { | ||||
| 		if val != "" { | ||||
| 			v.Set(key, val) | ||||
| 		} | ||||
| 	} | ||||
| 	set("connect_timeout", strconv.Itoa(p.ConnectionTimeout)) | ||||
| 	set("sslkey", p.SSL.KeyFile) | ||||
| 	set("sslcert", p.SSL.CertFile) | ||||
| 	set("sslrootcert", p.SSL.CAFile) | ||||
| 	if p.SSL.Mode == "" { | ||||
| 		// Assume the strictest mode if unspecified. | ||||
| 		p.SSL.Mode = sslVerifyFull | ||||
| 	} | ||||
| 	set("sslmode", p.SSL.Mode) | ||||
|  | ||||
| 	u := url.URL{ | ||||
| 		Scheme:   "postgres", | ||||
| 		Host:     p.Host, | ||||
| 		Path:     "/" + p.Database, | ||||
| 		RawQuery: v.Encode(), | ||||
| 	} | ||||
|  | ||||
| 	if p.User != "" { | ||||
| 		if p.Password != "" { | ||||
| 			u.User = url.UserPassword(p.User, p.Password) | ||||
| 		} else { | ||||
| 			u.User = url.User(p.User) | ||||
| 		} | ||||
| 	} | ||||
| 	db, err := sql.Open("postgres", u.String()) | ||||
| 	if err != nil { | ||||
| 		return nil, err | ||||
| 	} | ||||
| 	c := &conn{db, flavorPostgres} | ||||
| 	if _, err := c.migrate(); err != nil { | ||||
| 		return nil, fmt.Errorf("failed to perform migrations: %v", err) | ||||
| 	} | ||||
| 	return c, nil | ||||
| } | ||||
							
								
								
									
										1
									
								
								storage/sql/config_test.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										1
									
								
								storage/sql/config_test.go
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1 @@ | ||||
| package sql | ||||
							
								
								
									
										487
									
								
								storage/sql/crud.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										487
									
								
								storage/sql/crud.go
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,487 @@ | ||||
| package sql | ||||
|  | ||||
| import ( | ||||
| 	"database/sql" | ||||
| 	"database/sql/driver" | ||||
| 	"encoding/json" | ||||
| 	"errors" | ||||
| 	"fmt" | ||||
|  | ||||
| 	"github.com/coreos/dex/storage" | ||||
| ) | ||||
|  | ||||
| // TODO(ericchiang): The update, insert, and select methods queries are all | ||||
| // very repetivite. Consider creating them programatically. | ||||
|  | ||||
| // keysRowID is the ID of the only row we expect to populate the "keys" table. | ||||
| const keysRowID = "keys" | ||||
|  | ||||
| // encoder wraps the underlying value in a JSON marshaler which is automatically | ||||
| // called by the database/sql package. | ||||
| // | ||||
| //		s := []string{"planes", "bears"} | ||||
| //		err := db.Exec(`insert into t1 (id, things) values (1, $1)`, encoder(s)) | ||||
| //		if err != nil { | ||||
| //			// handle error | ||||
| //		} | ||||
| // | ||||
| //		var r []byte | ||||
| //		err = db.QueryRow(`select things from t1 where id = 1;`).Scan(&r) | ||||
| //		if err != nil { | ||||
| //			// handle error | ||||
| //		} | ||||
| //		fmt.Printf("%s\n", r) // ["planes","bears"] | ||||
| // | ||||
| func encoder(i interface{}) driver.Valuer { | ||||
| 	return jsonEncoder{i} | ||||
| } | ||||
|  | ||||
| // decoder wraps the underlying value in a JSON unmarshaler which can then be passed | ||||
| // to a database Scan() method. | ||||
| func decoder(i interface{}) sql.Scanner { | ||||
| 	return jsonDecoder{i} | ||||
| } | ||||
|  | ||||
| type jsonEncoder struct { | ||||
| 	i interface{} | ||||
| } | ||||
|  | ||||
| func (j jsonEncoder) Value() (driver.Value, error) { | ||||
| 	b, err := json.Marshal(j.i) | ||||
| 	if err != nil { | ||||
| 		return nil, fmt.Errorf("marshal: %v", err) | ||||
| 	} | ||||
| 	return b, nil | ||||
| } | ||||
|  | ||||
| type jsonDecoder struct { | ||||
| 	i interface{} | ||||
| } | ||||
|  | ||||
| func (j jsonDecoder) Scan(dest interface{}) error { | ||||
| 	if dest == nil { | ||||
| 		return errors.New("nil value") | ||||
| 	} | ||||
| 	b, ok := dest.([]byte) | ||||
| 	if !ok { | ||||
| 		return fmt.Errorf("expected []byte got %T", dest) | ||||
| 	} | ||||
| 	if err := json.Unmarshal(b, &j.i); err != nil { | ||||
| 		return fmt.Errorf("unmarshal: %v", err) | ||||
| 	} | ||||
| 	return nil | ||||
| } | ||||
|  | ||||
| // Abstract conn vs trans. | ||||
| type querier interface { | ||||
| 	QueryRow(query string, args ...interface{}) *sql.Row | ||||
| } | ||||
|  | ||||
| // Abstract row vs rows. | ||||
| type scanner interface { | ||||
| 	Scan(dest ...interface{}) error | ||||
| } | ||||
|  | ||||
| func (c *conn) CreateAuthRequest(a storage.AuthRequest) error { | ||||
| 	_, err := c.Exec(` | ||||
| 		insert into auth_request ( | ||||
| 			id, client_id, response_types, scopes, redirect_uri, nonce, state, | ||||
| 			force_approval_prompt, logged_in, | ||||
| 			claims_user_id, claims_username, claims_email, claims_email_verified, | ||||
| 			claims_groups, | ||||
| 			connector_id, connector_data, | ||||
| 			expiry | ||||
| 		) | ||||
| 		values ( | ||||
| 			$1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12, $13, $14, $15, $16, $17 | ||||
| 		); | ||||
| 	`, | ||||
| 		a.ID, a.ClientID, encoder(a.ResponseTypes), encoder(a.Scopes), a.RedirectURI, a.Nonce, a.State, | ||||
| 		a.ForceApprovalPrompt, a.LoggedIn, | ||||
| 		a.Claims.UserID, a.Claims.Username, a.Claims.Email, a.Claims.EmailVerified, | ||||
| 		encoder(a.Claims.Groups), | ||||
| 		a.ConnectorID, a.ConnectorData, | ||||
| 		a.Expiry, | ||||
| 	) | ||||
| 	if err != nil { | ||||
| 		return fmt.Errorf("insert auth request: %v", err) | ||||
| 	} | ||||
| 	return nil | ||||
| } | ||||
|  | ||||
| func (c *conn) UpdateAuthRequest(id string, updater func(a storage.AuthRequest) (storage.AuthRequest, error)) error { | ||||
| 	return c.ExecTx(func(tx *trans) error { | ||||
| 		r, err := getAuthRequest(tx, id) | ||||
| 		if err != nil { | ||||
| 			return err | ||||
| 		} | ||||
|  | ||||
| 		a, err := updater(r) | ||||
| 		if err != nil { | ||||
| 			return err | ||||
| 		} | ||||
| 		_, err = tx.Exec(` | ||||
| 			update auth_request | ||||
| 			set | ||||
| 				client_id = $1, response_types = $2, scopes = $3, redirect_uri = $4, | ||||
| 				nonce = $5, state = $6, force_approval_prompt = $7, logged_in = $8, | ||||
| 				claims_user_id = $9, claims_username = $10, claims_email = $11, | ||||
| 				claims_email_verified = $12, | ||||
| 				claims_groups = $13, | ||||
| 				connector_id = $14, connector_data = $15, | ||||
| 				expiry = $16 | ||||
| 			where id = $17; | ||||
| 		`, | ||||
| 			a.ClientID, encoder(a.ResponseTypes), encoder(a.Scopes), a.RedirectURI, a.Nonce, a.State, | ||||
| 			a.ForceApprovalPrompt, a.LoggedIn, | ||||
| 			a.Claims.UserID, a.Claims.Username, a.Claims.Email, a.Claims.EmailVerified, | ||||
| 			encoder(a.Claims.Groups), | ||||
| 			a.ConnectorID, a.ConnectorData, | ||||
| 			a.Expiry, a.ID, | ||||
| 		) | ||||
| 		if err != nil { | ||||
| 			return fmt.Errorf("update auth request: %v", err) | ||||
| 		} | ||||
| 		return nil | ||||
| 	}) | ||||
|  | ||||
| } | ||||
|  | ||||
| func (c *conn) GetAuthRequest(id string) (storage.AuthRequest, error) { | ||||
| 	return getAuthRequest(c, id) | ||||
| } | ||||
|  | ||||
| func getAuthRequest(q querier, id string) (a storage.AuthRequest, err error) { | ||||
| 	err = q.QueryRow(` | ||||
| 		select  | ||||
| 			id, client_id, response_types, scopes, redirect_uri, nonce, state, | ||||
| 			force_approval_prompt, logged_in, | ||||
| 			claims_user_id, claims_username, claims_email, claims_email_verified, | ||||
| 			claims_groups, | ||||
| 			connector_id, connector_data, expiry | ||||
| 		from auth_request where id = $1; | ||||
| 	`, id).Scan( | ||||
| 		&a.ID, &a.ClientID, decoder(&a.ResponseTypes), decoder(&a.Scopes), &a.RedirectURI, &a.Nonce, &a.State, | ||||
| 		&a.ForceApprovalPrompt, &a.LoggedIn, | ||||
| 		&a.Claims.UserID, &a.Claims.Username, &a.Claims.Email, &a.Claims.EmailVerified, | ||||
| 		decoder(&a.Claims.Groups), | ||||
| 		&a.ConnectorID, &a.ConnectorData, &a.Expiry, | ||||
| 	) | ||||
| 	if err != nil { | ||||
| 		if err == sql.ErrNoRows { | ||||
| 			return a, storage.ErrNotFound | ||||
| 		} | ||||
| 		return a, fmt.Errorf("select auth request: %v", err) | ||||
| 	} | ||||
| 	return a, nil | ||||
| } | ||||
|  | ||||
| func (c *conn) CreateAuthCode(a storage.AuthCode) error { | ||||
| 	_, err := c.Exec(` | ||||
| 		insert into auth_code ( | ||||
| 			id, client_id, scopes, nonce, redirect_uri, | ||||
| 			claims_user_id, claims_username, | ||||
| 			claims_email, claims_email_verified, claims_groups, | ||||
| 			connector_id, connector_data, | ||||
| 			expiry | ||||
| 		) | ||||
| 		values ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12, $13); | ||||
| 	`, | ||||
| 		a.ID, a.ClientID, encoder(a.Scopes), a.Nonce, a.RedirectURI, a.Claims.UserID, | ||||
| 		a.Claims.Username, a.Claims.Email, a.Claims.EmailVerified, encoder(a.Claims.Groups), | ||||
| 		a.ConnectorID, a.ConnectorData, a.Expiry, | ||||
| 	) | ||||
| 	return err | ||||
| } | ||||
|  | ||||
| func (c *conn) GetAuthCode(id string) (a storage.AuthCode, err error) { | ||||
| 	err = c.QueryRow(` | ||||
| 		select | ||||
| 			id, client_id, scopes, nonce, redirect_uri, | ||||
| 			claims_user_id, claims_username, | ||||
| 			claims_email, claims_email_verified, claims_groups, | ||||
| 			connector_id, connector_data, | ||||
| 			expiry | ||||
| 		from auth_code where id = $1; | ||||
| 	`, id).Scan( | ||||
| 		&a.ID, &a.ClientID, decoder(&a.Scopes), &a.Nonce, &a.RedirectURI, &a.Claims.UserID, | ||||
| 		&a.Claims.Username, &a.Claims.Email, &a.Claims.EmailVerified, decoder(&a.Claims.Groups), | ||||
| 		&a.ConnectorID, &a.ConnectorData, &a.Expiry, | ||||
| 	) | ||||
| 	if err != nil { | ||||
| 		if err == sql.ErrNoRows { | ||||
| 			return a, storage.ErrNotFound | ||||
| 		} | ||||
| 		return a, fmt.Errorf("select auth code: %v", err) | ||||
| 	} | ||||
| 	return a, nil | ||||
| } | ||||
|  | ||||
| func (c *conn) CreateRefresh(r storage.RefreshToken) error { | ||||
| 	_, err := c.Exec(` | ||||
| 		insert into refresh_token ( | ||||
| 			id, client_id, scopes, nonce, | ||||
| 			claims_user_id, claims_username, claims_email, claims_email_verified, | ||||
| 			claims_groups, | ||||
| 			connector_id, connector_data | ||||
| 		) | ||||
| 		values ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11); | ||||
| 	`, | ||||
| 		r.RefreshToken, r.ClientID, encoder(r.Scopes), r.Nonce, | ||||
| 		r.Claims.UserID, r.Claims.Username, r.Claims.Email, r.Claims.EmailVerified, | ||||
| 		encoder(r.Claims.Groups), | ||||
| 		r.ConnectorID, r.ConnectorData, | ||||
| 	) | ||||
| 	if err != nil { | ||||
| 		return fmt.Errorf("insert refresh_token: %v", err) | ||||
| 	} | ||||
| 	return nil | ||||
| } | ||||
|  | ||||
| func (c *conn) GetRefresh(id string) (storage.RefreshToken, error) { | ||||
| 	return scanRefresh(c.QueryRow(` | ||||
| 		select | ||||
| 			id, client_id, scopes, nonce, | ||||
| 			claims_user_id, claims_username, claims_email, claims_email_verified, | ||||
| 			claims_groups, | ||||
| 			connector_id, connector_data | ||||
| 		from refresh_token where id = $1; | ||||
| 	`, id)) | ||||
| } | ||||
|  | ||||
| func (c *conn) ListRefreshTokens() ([]storage.RefreshToken, error) { | ||||
| 	rows, err := c.Query(` | ||||
| 		select | ||||
| 			id, client_id, scopes, nonce, | ||||
| 			claims_user_id, claims_username, claims_email, claims_email_verified, | ||||
| 			claims_groups, | ||||
| 			connector_id, connector_data | ||||
| 		from refresh_token; | ||||
| 	`) | ||||
| 	if err != nil { | ||||
| 		return nil, fmt.Errorf("query: %v", err) | ||||
| 	} | ||||
| 	var tokens []storage.RefreshToken | ||||
| 	for rows.Next() { | ||||
| 		r, err := scanRefresh(rows) | ||||
| 		if err != nil { | ||||
| 			return nil, err | ||||
| 		} | ||||
| 		tokens = append(tokens, r) | ||||
| 	} | ||||
| 	if err := rows.Err(); err != nil { | ||||
| 		return nil, fmt.Errorf("scan: %v", err) | ||||
| 	} | ||||
| 	return tokens, nil | ||||
| } | ||||
|  | ||||
| func scanRefresh(s scanner) (r storage.RefreshToken, err error) { | ||||
| 	err = s.Scan( | ||||
| 		&r.RefreshToken, &r.ClientID, decoder(&r.Scopes), &r.Nonce, | ||||
| 		&r.Claims.UserID, &r.Claims.Username, &r.Claims.Email, &r.Claims.EmailVerified, | ||||
| 		decoder(&r.Claims.Groups), | ||||
| 		&r.ConnectorID, &r.ConnectorData, | ||||
| 	) | ||||
| 	if err != nil { | ||||
| 		if err == sql.ErrNoRows { | ||||
| 			return r, storage.ErrNotFound | ||||
| 		} | ||||
| 		return r, fmt.Errorf("scan refresh_token: %v", err) | ||||
| 	} | ||||
| 	return r, nil | ||||
| } | ||||
|  | ||||
| func (c *conn) UpdateKeys(updater func(old storage.Keys) (storage.Keys, error)) error { | ||||
| 	return c.ExecTx(func(tx *trans) error { | ||||
| 		firstUpdate := false | ||||
| 		// TODO(ericchiang): errors may cause a transaction be rolled back by the SQL | ||||
| 		// server. Test this, and consider adding a COUNT() command beforehand. | ||||
| 		old, err := getKeys(tx) | ||||
| 		if err != nil { | ||||
| 			if err != storage.ErrNotFound { | ||||
| 				return fmt.Errorf("get keys: %v", err) | ||||
| 			} | ||||
| 			firstUpdate = true | ||||
| 			old = storage.Keys{} | ||||
| 		} | ||||
|  | ||||
| 		nk, err := updater(old) | ||||
| 		if err != nil { | ||||
| 			return err | ||||
| 		} | ||||
|  | ||||
| 		if firstUpdate { | ||||
| 			_, err = tx.Exec(` | ||||
| 				insert into keys ( | ||||
| 					id, verification_keys, signing_key, signing_key_pub, next_rotation | ||||
| 				) | ||||
| 				values ($1, $2, $3, $4, $5); | ||||
| 			`, | ||||
| 				keysRowID, encoder(nk.VerificationKeys), encoder(nk.SigningKey), | ||||
| 				encoder(nk.SigningKeyPub), nk.NextRotation, | ||||
| 			) | ||||
| 			if err != nil { | ||||
| 				return fmt.Errorf("insert: %v", err) | ||||
| 			} | ||||
| 		} else { | ||||
| 			_, err = tx.Exec(` | ||||
| 				update keys | ||||
| 				set  | ||||
| 				    verification_keys = $1, | ||||
| 					signing_key = $2, | ||||
| 					singing_key_pub = $3, | ||||
| 					next_rotation = $4 | ||||
| 				where id = $5; | ||||
| 			`, | ||||
| 				encoder(nk.VerificationKeys), encoder(nk.SigningKey), | ||||
| 				encoder(nk.SigningKeyPub), nk.NextRotation, keysRowID, | ||||
| 			) | ||||
| 			if err != nil { | ||||
| 				return fmt.Errorf("update: %v", err) | ||||
| 			} | ||||
| 		} | ||||
| 		return nil | ||||
| 	}) | ||||
| } | ||||
|  | ||||
| func (c *conn) GetKeys() (keys storage.Keys, err error) { | ||||
| 	return getKeys(c) | ||||
| } | ||||
|  | ||||
| func getKeys(q querier) (keys storage.Keys, err error) { | ||||
| 	err = q.QueryRow(` | ||||
| 		select | ||||
| 			verification_keys, signing_key, signing_key_pub, next_rotation | ||||
| 		from keys | ||||
| 		where id=$q | ||||
| 	`, keysRowID).Scan( | ||||
| 		decoder(&keys.VerificationKeys), decoder(&keys.SigningKey), | ||||
| 		decoder(&keys.SigningKeyPub), &keys.NextRotation, | ||||
| 	) | ||||
| 	if err != nil { | ||||
| 		if err == sql.ErrNoRows { | ||||
| 			return keys, storage.ErrNotFound | ||||
| 		} | ||||
| 		return keys, fmt.Errorf("query keys: %v", err) | ||||
| 	} | ||||
| 	return keys, nil | ||||
| } | ||||
|  | ||||
| func (c *conn) UpdateClient(id string, updater func(old storage.Client) (storage.Client, error)) error { | ||||
| 	return c.ExecTx(func(tx *trans) error { | ||||
| 		cli, err := getClient(tx, id) | ||||
| 		if err != nil { | ||||
| 			return err | ||||
| 		} | ||||
| 		nc, err := updater(cli) | ||||
| 		if err != nil { | ||||
| 			return err | ||||
| 		} | ||||
|  | ||||
| 		_, err = tx.Exec(` | ||||
| 			update client | ||||
| 			set | ||||
| 				secret = $1, | ||||
| 				redirect_uris = $2, | ||||
| 				trusted_peers = $3, | ||||
| 				public = $4, | ||||
| 				name = $5, | ||||
| 				logo_url = $6 | ||||
| 			where id = $7; | ||||
| 		`, nc.Secret, encoder(nc.RedirectURIs), encoder(nc.TrustedPeers), nc.Public, nc.Name, nc.LogoURL, id, | ||||
| 		) | ||||
| 		if err != nil { | ||||
| 			return fmt.Errorf("update client: %v", err) | ||||
| 		} | ||||
| 		return nil | ||||
| 	}) | ||||
| } | ||||
|  | ||||
| func (c *conn) CreateClient(cli storage.Client) error { | ||||
| 	_, err := c.Exec(` | ||||
| 		insert into client ( | ||||
| 			id, secret, redirect_uris, trusted_peers, public, name, logo_url | ||||
| 		) | ||||
| 		values ($1, $2, $3, $4, $5, $6, $7); | ||||
| 	`, | ||||
| 		cli.ID, cli.Secret, encoder(cli.RedirectURIs), encoder(cli.TrustedPeers), | ||||
| 		cli.Public, cli.Name, cli.LogoURL, | ||||
| 	) | ||||
| 	if err != nil { | ||||
| 		return fmt.Errorf("insert client: %v", err) | ||||
| 	} | ||||
| 	return nil | ||||
| } | ||||
|  | ||||
| func getClient(q querier, id string) (storage.Client, error) { | ||||
| 	return scanClient(q.QueryRow(` | ||||
| 		select | ||||
| 			id, secret, redirect_uris, trusted_peers, public, name, logo_url | ||||
| 	    from client where id = $1; | ||||
| 	`, id)) | ||||
| } | ||||
|  | ||||
| func (c *conn) GetClient(id string) (storage.Client, error) { | ||||
| 	return getClient(c, id) | ||||
| } | ||||
|  | ||||
| func (c *conn) ListClients() ([]storage.Client, error) { | ||||
| 	rows, err := c.Query(` | ||||
| 		select | ||||
| 			id, secret, redirect_uris, trusted_peers, public, name, logo_url | ||||
| 		from client; | ||||
| 	`) | ||||
| 	if err != nil { | ||||
| 		return nil, err | ||||
| 	} | ||||
| 	var clients []storage.Client | ||||
| 	for rows.Next() { | ||||
| 		cli, err := scanClient(rows) | ||||
| 		if err != nil { | ||||
| 			return nil, err | ||||
| 		} | ||||
| 		clients = append(clients, cli) | ||||
| 	} | ||||
| 	if err := rows.Err(); err != nil { | ||||
| 		return nil, err | ||||
| 	} | ||||
| 	return clients, nil | ||||
| } | ||||
|  | ||||
| func scanClient(s scanner) (cli storage.Client, err error) { | ||||
| 	err = s.Scan( | ||||
| 		&cli.ID, &cli.Secret, decoder(&cli.RedirectURIs), decoder(&cli.TrustedPeers), | ||||
| 		&cli.Public, &cli.Name, &cli.LogoURL, | ||||
| 	) | ||||
| 	if err != nil { | ||||
| 		if err == sql.ErrNoRows { | ||||
| 			return cli, storage.ErrNotFound | ||||
| 		} | ||||
| 		return cli, fmt.Errorf("get client: %v", err) | ||||
| 	} | ||||
| 	return cli, nil | ||||
| } | ||||
|  | ||||
| func (c *conn) DeleteAuthRequest(id string) error { return c.delete("auth_request", id) } | ||||
| func (c *conn) DeleteAuthCode(id string) error    { return c.delete("auth_code", id) } | ||||
| func (c *conn) DeleteClient(id string) error      { return c.delete("client", id) } | ||||
| func (c *conn) DeleteRefresh(id string) error     { return c.delete("refresh_token", id) } | ||||
|  | ||||
| // Do NOT call directly. Does not escape table. | ||||
| func (c *conn) delete(table, id string) error { | ||||
| 	result, err := c.Exec(`delete from `+table+` where id = $1`, id) | ||||
| 	if err != nil { | ||||
| 		return fmt.Errorf("delete %s: %v", table, id) | ||||
| 	} | ||||
|  | ||||
| 	// For now mandate that the driver implements RowsAffected. If we ever need to support | ||||
| 	// a driver that doesn't implement this, we can run this in a transaction with a get beforehand. | ||||
| 	n, err := result.RowsAffected() | ||||
| 	if err != nil { | ||||
| 		return fmt.Errorf("rows affected: %v", err) | ||||
| 	} | ||||
| 	if n < 1 { | ||||
| 		return storage.ErrNotFound | ||||
| 	} | ||||
| 	return nil | ||||
| } | ||||
							
								
								
									
										55
									
								
								storage/sql/crud_test.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										55
									
								
								storage/sql/crud_test.go
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,55 @@ | ||||
| package sql | ||||
|  | ||||
| import ( | ||||
| 	"database/sql" | ||||
| 	"reflect" | ||||
| 	"testing" | ||||
| ) | ||||
|  | ||||
| func TestDecoder(t *testing.T) { | ||||
| 	db, err := sql.Open("sqlite3", ":memory:") | ||||
| 	if err != nil { | ||||
| 		t.Fatal(err) | ||||
| 	} | ||||
| 	defer db.Close() | ||||
|  | ||||
| 	if _, err := db.Exec(`create table foo ( id integer primary key, bar blob );`); err != nil { | ||||
| 		t.Fatal(err) | ||||
| 	} | ||||
| 	if _, err := db.Exec(`insert into foo ( id, bar ) values (1, ?);`, []byte(`["a", "b"]`)); err != nil { | ||||
| 		t.Fatal(err) | ||||
| 	} | ||||
| 	var got []string | ||||
| 	if err := db.QueryRow(`select bar from foo where id = 1;`).Scan(decoder(&got)); err != nil { | ||||
| 		t.Fatal(err) | ||||
| 	} | ||||
| 	want := []string{"a", "b"} | ||||
| 	if !reflect.DeepEqual(got, want) { | ||||
| 		t.Errorf("wanted %q got %q", want, got) | ||||
| 	} | ||||
| } | ||||
|  | ||||
| func TestEncoder(t *testing.T) { | ||||
| 	db, err := sql.Open("sqlite3", ":memory:") | ||||
| 	if err != nil { | ||||
| 		t.Fatal(err) | ||||
| 	} | ||||
| 	defer db.Close() | ||||
|  | ||||
| 	if _, err := db.Exec(`create table foo ( id integer primary key, bar blob );`); err != nil { | ||||
| 		t.Fatal(err) | ||||
| 	} | ||||
| 	put := []string{"a", "b"} | ||||
| 	if _, err := db.Exec(`insert into foo ( id, bar ) values (1, ?)`, encoder(put)); err != nil { | ||||
| 		t.Fatal(err) | ||||
| 	} | ||||
|  | ||||
| 	var got []byte | ||||
| 	if err := db.QueryRow(`select bar from foo where id = 1;`).Scan(&got); err != nil { | ||||
| 		t.Fatal(err) | ||||
| 	} | ||||
| 	want := []byte(`["a","b"]`) | ||||
| 	if !reflect.DeepEqual(got, want) { | ||||
| 		t.Errorf("wanted %q got %q", want, got) | ||||
| 	} | ||||
| } | ||||
							
								
								
									
										24
									
								
								storage/sql/gc.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										24
									
								
								storage/sql/gc.go
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,24 @@ | ||||
| package sql | ||||
|  | ||||
| import ( | ||||
| 	"fmt" | ||||
| 	"time" | ||||
| ) | ||||
|  | ||||
| type gc struct { | ||||
| 	now  func() time.Time | ||||
| 	conn *conn | ||||
| } | ||||
|  | ||||
| var tablesWithGC = []string{"auth_request", "auth_code"} | ||||
|  | ||||
| func (gc gc) run() error { | ||||
| 	for _, table := range tablesWithGC { | ||||
| 		_, err := gc.conn.Exec(`delete from `+table+` where expiry < $1`, gc.now()) | ||||
| 		if err != nil { | ||||
| 			return fmt.Errorf("gc %s: %v", table, err) | ||||
| 		} | ||||
| 		// TODO(ericchiang): when we have levelled logging print how many rows were gc'd | ||||
| 	} | ||||
| 	return nil | ||||
| } | ||||
							
								
								
									
										53
									
								
								storage/sql/gc_test.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										53
									
								
								storage/sql/gc_test.go
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,53 @@ | ||||
| package sql | ||||
|  | ||||
| import ( | ||||
| 	"testing" | ||||
| 	"time" | ||||
|  | ||||
| 	"github.com/coreos/dex/storage" | ||||
| ) | ||||
|  | ||||
| func TestGC(t *testing.T) { | ||||
| 	// TODO(ericchiang): Add a GarbageCollect method to the storage interface so | ||||
| 	// we can write conformance tests instead of directly testing each implementation. | ||||
| 	s := &SQLite3{":memory:"} | ||||
| 	conn, err := s.open() | ||||
| 	if err != nil { | ||||
| 		t.Fatal(err) | ||||
| 	} | ||||
| 	defer conn.Close() | ||||
|  | ||||
| 	clock := time.Now() | ||||
| 	now := func() time.Time { return clock } | ||||
|  | ||||
| 	runGC := (gc{now, conn}).run | ||||
|  | ||||
| 	a := storage.AuthRequest{ | ||||
| 		ID:     storage.NewID(), | ||||
| 		Expiry: now().Add(time.Second), | ||||
| 	} | ||||
|  | ||||
| 	if err := conn.CreateAuthRequest(a); err != nil { | ||||
| 		t.Fatal(err) | ||||
| 	} | ||||
|  | ||||
| 	if err := runGC(); err != nil { | ||||
| 		t.Errorf("gc failed: %v", err) | ||||
| 	} | ||||
|  | ||||
| 	if _, err := conn.GetAuthRequest(a.ID); err != nil { | ||||
| 		t.Errorf("failed to get auth request after gc: %v", err) | ||||
| 	} | ||||
|  | ||||
| 	clock = clock.Add(time.Minute) | ||||
|  | ||||
| 	if err := runGC(); err != nil { | ||||
| 		t.Errorf("gc failed: %v", err) | ||||
| 	} | ||||
|  | ||||
| 	if _, err := conn.GetAuthRequest(a.ID); err == nil { | ||||
| 		t.Errorf("expected error after gc'ing auth request: %v", err) | ||||
| 	} else if err != storage.ErrNotFound { | ||||
| 		t.Errorf("expected error storage.NotFound got: %v", err) | ||||
| 	} | ||||
| } | ||||
							
								
								
									
										151
									
								
								storage/sql/migrate.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										151
									
								
								storage/sql/migrate.go
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,151 @@ | ||||
| package sql | ||||
|  | ||||
| import ( | ||||
| 	"database/sql" | ||||
| 	"fmt" | ||||
| ) | ||||
|  | ||||
| func (c *conn) migrate() (int, error) { | ||||
| 	_, err := c.Exec(` | ||||
| 		create table if not exists migrations ( | ||||
| 			num integer not null, | ||||
| 			at timestamp not null | ||||
| 		); | ||||
| 	`) | ||||
| 	if err != nil { | ||||
| 		return 0, fmt.Errorf("creating migration table: %v", err) | ||||
| 	} | ||||
|  | ||||
| 	i := 0 | ||||
| 	done := false | ||||
| 	for { | ||||
| 		err := c.ExecTx(func(tx *trans) error { | ||||
| 			// Within a transaction, perform a single migration. | ||||
| 			var ( | ||||
| 				num sql.NullInt64 | ||||
| 				n   int | ||||
| 			) | ||||
| 			if err := tx.QueryRow(`select max(num) from migrations;`).Scan(&num); err != nil { | ||||
| 				return fmt.Errorf("select max migration: %v", err) | ||||
| 			} | ||||
| 			if num.Valid { | ||||
| 				n = int(num.Int64) | ||||
| 			} | ||||
| 			if n >= len(migrations) { | ||||
| 				done = true | ||||
| 				return nil | ||||
| 			} | ||||
|  | ||||
| 			migrationNum := n + 1 | ||||
| 			m := migrations[n] | ||||
| 			if _, err := tx.Exec(m.stmt); err != nil { | ||||
| 				return fmt.Errorf("migration %d failed: %v", migrationNum, err) | ||||
| 			} | ||||
|  | ||||
| 			q := `insert into migrations (num, at) values ($1, now());` | ||||
| 			if _, err := tx.Exec(q, migrationNum); err != nil { | ||||
| 				return fmt.Errorf("update migration table: %v", err) | ||||
| 			} | ||||
| 			return nil | ||||
| 		}) | ||||
| 		if err != nil { | ||||
| 			return i, err | ||||
| 		} | ||||
| 		if done { | ||||
| 			break | ||||
| 		} | ||||
| 		i++ | ||||
| 	} | ||||
|  | ||||
| 	return i, nil | ||||
| } | ||||
|  | ||||
| type migration struct { | ||||
| 	stmt string | ||||
| 	// TODO(ericchiang): consider adding additional fields like "forDrivers" | ||||
| } | ||||
|  | ||||
| // All SQL flavors share migration strategies. | ||||
| var migrations = []migration{ | ||||
| 	{ | ||||
| 		stmt: ` | ||||
| 			create table client ( | ||||
| 				id text not null primary key, | ||||
| 				secret text not null, | ||||
| 				redirect_uris bytea not null, -- JSON array of strings | ||||
| 				trusted_peers bytea not null, -- JSON array of strings | ||||
| 				public boolean not null, | ||||
| 				name text not null, | ||||
| 				logo_url text not null | ||||
| 			); | ||||
| 		 | ||||
| 			create table auth_request ( | ||||
| 				id text not null primary key, | ||||
| 				client_id text not null, | ||||
| 				response_types bytea not null, -- JSON array of strings | ||||
| 				scopes bytea not null,         -- JSON array of strings | ||||
| 				redirect_uri text not null, | ||||
| 				nonce text not null, | ||||
| 				state text not null, | ||||
| 				force_approval_prompt boolean not null, | ||||
| 		 | ||||
| 				logged_in boolean not null, | ||||
| 		 | ||||
| 				claims_user_id text not null, | ||||
| 				claims_username text not null, | ||||
| 				claims_email text not null, | ||||
| 				claims_email_verified boolean not null, | ||||
| 				claims_groups bytea not null, -- JSON array of strings | ||||
| 		 | ||||
| 				connector_id text not null, | ||||
| 				connector_data bytea, | ||||
| 		 | ||||
| 				expiry timestamp not null | ||||
| 			); | ||||
| 		 | ||||
| 			create table auth_code ( | ||||
| 				id text not null primary key, | ||||
| 				client_id text not null, | ||||
| 				scopes bytea not null, -- JSON array of strings | ||||
| 				nonce text not null, | ||||
| 				redirect_uri text not null, | ||||
| 		 | ||||
| 				claims_user_id text not null, | ||||
| 				claims_username text not null, | ||||
| 				claims_email text not null, | ||||
| 				claims_email_verified boolean not null, | ||||
| 				claims_groups bytea not null, -- JSON array of strings | ||||
| 		 | ||||
| 				connector_id text not null, | ||||
| 				connector_data bytea, | ||||
| 		 | ||||
| 				expiry timestamp not null | ||||
| 			); | ||||
| 		 | ||||
| 			create table refresh_token ( | ||||
| 				id text not null primary key, | ||||
| 				client_id text not null, | ||||
| 				scopes bytea not null, -- JSON array of strings | ||||
| 				nonce text not null, | ||||
| 		 | ||||
| 				claims_user_id text not null, | ||||
| 				claims_username text not null, | ||||
| 				claims_email text not null, | ||||
| 				claims_email_verified boolean not null, | ||||
| 				claims_groups bytea not null, -- JSON array of strings | ||||
| 		 | ||||
| 				connector_id text not null, | ||||
| 				connector_data bytea | ||||
| 			); | ||||
| 		 | ||||
| 			-- keys is a weird table because we only ever expect there to be a single row | ||||
| 			create table keys ( | ||||
| 				id text not null primary key, | ||||
| 				verification_keys bytea not null, -- JSON array | ||||
| 				signing_key bytea not null,       -- JSON object | ||||
| 				signing_key_pub bytea not null,   -- JSON object | ||||
| 				next_rotation timestamp not null | ||||
| 			); | ||||
| 		`, | ||||
| 	}, | ||||
| } | ||||
							
								
								
									
										25
									
								
								storage/sql/migrate_test.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										25
									
								
								storage/sql/migrate_test.go
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,25 @@ | ||||
| package sql | ||||
|  | ||||
| import ( | ||||
| 	"database/sql" | ||||
| 	"testing" | ||||
| ) | ||||
|  | ||||
| func TestMigrate(t *testing.T) { | ||||
| 	db, err := sql.Open("sqlite3", ":memory:") | ||||
| 	if err != nil { | ||||
| 		t.Fatal(err) | ||||
| 	} | ||||
| 	defer db.Close() | ||||
|  | ||||
| 	c := &conn{db, flavorSQLite3} | ||||
| 	for _, want := range []int{len(migrations), 0} { | ||||
| 		got, err := c.migrate() | ||||
| 		if err != nil { | ||||
| 			t.Fatal(err) | ||||
| 		} | ||||
| 		if got != want { | ||||
| 			t.Errorf("expected %d migrations, got %d", want, got) | ||||
| 		} | ||||
| 	} | ||||
| } | ||||
							
								
								
									
										152
									
								
								storage/sql/sql.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										152
									
								
								storage/sql/sql.go
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,152 @@ | ||||
| // Package sql provides SQL implementations of the storage interface. | ||||
| package sql | ||||
|  | ||||
| import ( | ||||
| 	"database/sql" | ||||
| 	"regexp" | ||||
|  | ||||
| 	"github.com/cockroachdb/cockroach-go/crdb" | ||||
|  | ||||
| 	// import third party drivers | ||||
| 	_ "github.com/go-sql-driver/mysql" | ||||
| 	_ "github.com/lib/pq" | ||||
| 	_ "github.com/mattn/go-sqlite3" | ||||
| ) | ||||
|  | ||||
| // flavor represents a specific SQL implementation, and is used to translate query strings | ||||
| // between different drivers. Flavors shouldn't aim to translate all possible SQL statements, | ||||
| // only the specific queries used by the SQL storages. | ||||
| type flavor struct { | ||||
| 	queryReplacers []replacer | ||||
|  | ||||
| 	// Optional function to create and finish a transaction. This is mainly for | ||||
| 	// cockroachdb support which requires special retry logic provided by their | ||||
| 	// client package. | ||||
| 	// | ||||
| 	// This will be nil for most flavors. | ||||
| 	// | ||||
| 	// See: https://github.com/cockroachdb/docs/blob/63761c2e/_includes/app/txn-sample.go#L41-L44 | ||||
| 	executeTx func(db *sql.DB, fn func(*sql.Tx) error) error | ||||
| } | ||||
|  | ||||
| // A regexp with a replacement string. | ||||
| type replacer struct { | ||||
| 	re   *regexp.Regexp | ||||
| 	with string | ||||
| } | ||||
|  | ||||
| // Match a postgres query binds. E.g. "$1", "$12", etc. | ||||
| var bindRegexp = regexp.MustCompile(`\$\d+`) | ||||
|  | ||||
| func matchLiteral(s string) *regexp.Regexp { | ||||
| 	return regexp.MustCompile(`\b` + regexp.QuoteMeta(s) + `\b`) | ||||
| } | ||||
|  | ||||
| var ( | ||||
| 	// The "github.com/lib/pq" driver is the default flavor. All others are | ||||
| 	// translations of this. | ||||
| 	flavorPostgres = flavor{} | ||||
|  | ||||
| 	flavorSQLite3 = flavor{ | ||||
| 		queryReplacers: []replacer{ | ||||
| 			{bindRegexp, "?"}, | ||||
| 			// Translate for booleans to integers. | ||||
| 			{matchLiteral("true"), "1"}, | ||||
| 			{matchLiteral("false"), "0"}, | ||||
| 			{matchLiteral("boolean"), "integer"}, | ||||
| 			// Translate other types. | ||||
| 			{matchLiteral("bytea"), "blob"}, | ||||
| 			// {matchLiteral("timestamp"), "integer"}, | ||||
| 			// SQLite doesn't have a "now()" method, replace with "date('now')" | ||||
| 			{regexp.MustCompile(`\bnow\(\)`), "date('now')"}, | ||||
| 		}, | ||||
| 	} | ||||
|  | ||||
| 	// Incomplete. | ||||
| 	flavorMySQL = flavor{ | ||||
| 		queryReplacers: []replacer{ | ||||
| 			{bindRegexp, "?"}, | ||||
| 		}, | ||||
| 	} | ||||
|  | ||||
| 	// Not tested. | ||||
| 	flavorCockroach = flavor{ | ||||
| 		executeTx: crdb.ExecuteTx, | ||||
| 	} | ||||
| ) | ||||
|  | ||||
| func (f flavor) translate(query string) string { | ||||
| 	// TODO(ericchiang): Heavy cashing. | ||||
| 	for _, r := range f.queryReplacers { | ||||
| 		query = r.re.ReplaceAllString(query, r.with) | ||||
| 	} | ||||
| 	return query | ||||
| } | ||||
|  | ||||
| // conn is the main database connection. | ||||
| type conn struct { | ||||
| 	db     *sql.DB | ||||
| 	flavor flavor | ||||
| } | ||||
|  | ||||
| func (c *conn) Close() error { | ||||
| 	return c.db.Close() | ||||
| } | ||||
|  | ||||
| // conn implements the same method signatures as encoding/sql.DB. | ||||
|  | ||||
| func (c *conn) Exec(query string, args ...interface{}) (sql.Result, error) { | ||||
| 	query = c.flavor.translate(query) | ||||
| 	return c.db.Exec(query, args...) | ||||
| } | ||||
|  | ||||
| func (c *conn) Query(query string, args ...interface{}) (*sql.Rows, error) { | ||||
| 	query = c.flavor.translate(query) | ||||
| 	return c.db.Query(query, args...) | ||||
| } | ||||
|  | ||||
| func (c *conn) QueryRow(query string, args ...interface{}) *sql.Row { | ||||
| 	query = c.flavor.translate(query) | ||||
| 	return c.db.QueryRow(query, args...) | ||||
| } | ||||
|  | ||||
| // ExecTx runs a method which operates on a transaction. | ||||
| func (c *conn) ExecTx(fn func(tx *trans) error) error { | ||||
| 	if c.flavor.executeTx != nil { | ||||
| 		return c.flavor.executeTx(c.db, func(sqlTx *sql.Tx) error { | ||||
| 			return fn(&trans{sqlTx, c}) | ||||
| 		}) | ||||
| 	} | ||||
|  | ||||
| 	sqlTx, err := c.db.Begin() | ||||
| 	if err != nil { | ||||
| 		return err | ||||
| 	} | ||||
| 	if err := fn(&trans{sqlTx, c}); err != nil { | ||||
| 		sqlTx.Rollback() | ||||
| 		return err | ||||
| 	} | ||||
| 	return sqlTx.Commit() | ||||
| } | ||||
|  | ||||
| type trans struct { | ||||
| 	tx *sql.Tx | ||||
| 	c  *conn | ||||
| } | ||||
|  | ||||
| // trans implements the same method signatures as encoding/sql.Tx. | ||||
|  | ||||
| func (t *trans) Exec(query string, args ...interface{}) (sql.Result, error) { | ||||
| 	query = t.c.flavor.translate(query) | ||||
| 	return t.tx.Exec(query, args...) | ||||
| } | ||||
|  | ||||
| func (t *trans) Query(query string, args ...interface{}) (*sql.Rows, error) { | ||||
| 	query = t.c.flavor.translate(query) | ||||
| 	return t.tx.Query(query, args...) | ||||
| } | ||||
|  | ||||
| func (t *trans) QueryRow(query string, args ...interface{}) *sql.Row { | ||||
| 	query = t.c.flavor.translate(query) | ||||
| 	return t.tx.QueryRow(query, args...) | ||||
| } | ||||
							
								
								
									
										55
									
								
								storage/sql/sql_test.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										55
									
								
								storage/sql/sql_test.go
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,55 @@ | ||||
| package sql | ||||
|  | ||||
| import "testing" | ||||
|  | ||||
| func TestTranslate(t *testing.T) { | ||||
| 	tests := []struct { | ||||
| 		testCase string | ||||
| 		flavor   flavor | ||||
| 		query    string | ||||
| 		exp      string | ||||
| 	}{ | ||||
| 		{ | ||||
| 			"sqlite3 query bind replacement", | ||||
| 			flavorSQLite3, | ||||
| 			`select foo from bar where foo.zam = $1;`, | ||||
| 			`select foo from bar where foo.zam = ?;`, | ||||
| 		}, | ||||
| 		{ | ||||
| 			"sqlite3 query bind replacement at newline", | ||||
| 			flavorSQLite3, | ||||
| 			`select foo from bar where foo.zam = $1`, | ||||
| 			`select foo from bar where foo.zam = ?`, | ||||
| 		}, | ||||
| 		{ | ||||
| 			"sqlite3 query true", | ||||
| 			flavorSQLite3, | ||||
| 			`select foo from bar where foo.zam = true`, | ||||
| 			`select foo from bar where foo.zam = 1`, | ||||
| 		}, | ||||
| 		{ | ||||
| 			"sqlite3 query false", | ||||
| 			flavorSQLite3, | ||||
| 			`select foo from bar where foo.zam = false`, | ||||
| 			`select foo from bar where foo.zam = 0`, | ||||
| 		}, | ||||
| 		{ | ||||
| 			"sqlite3 bytea", | ||||
| 			flavorSQLite3, | ||||
| 			`"connector_data" bytea not null,`, | ||||
| 			`"connector_data" blob not null,`, | ||||
| 		}, | ||||
| 		{ | ||||
| 			"sqlite3 now", | ||||
| 			flavorSQLite3, | ||||
| 			`now(),`, | ||||
| 			`date('now'),`, | ||||
| 		}, | ||||
| 	} | ||||
|  | ||||
| 	for _, tc := range tests { | ||||
| 		if got := tc.flavor.translate(tc.query); got != tc.exp { | ||||
| 			t.Errorf("%s: want=%q, got=%q", tc.testCase, tc.exp, got) | ||||
| 		} | ||||
| 	} | ||||
| } | ||||
		Reference in New Issue
	
	Block a user