storage/sql: add a SQL storage implementation
This change adds support for SQLite3, and Postgres.
This commit is contained in:
		
							
								
								
									
										113
									
								
								storage/sql/config.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										113
									
								
								storage/sql/config.go
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,113 @@ | |||||||
|  | package sql | ||||||
|  |  | ||||||
|  | import ( | ||||||
|  | 	"database/sql" | ||||||
|  | 	"fmt" | ||||||
|  | 	"net/url" | ||||||
|  | 	"strconv" | ||||||
|  |  | ||||||
|  | 	"github.com/coreos/dex/storage" | ||||||
|  | ) | ||||||
|  |  | ||||||
|  | // SQLite3 options for creating an SQL db. | ||||||
|  | type SQLite3 struct { | ||||||
|  | 	// File to | ||||||
|  | 	File string `yaml:"file"` | ||||||
|  | } | ||||||
|  |  | ||||||
|  | // Open creates a new storage implementation backed by SQLite3 | ||||||
|  | func (s *SQLite3) Open() (storage.Storage, error) { | ||||||
|  | 	return s.open() | ||||||
|  | } | ||||||
|  |  | ||||||
|  | func (s *SQLite3) open() (*conn, error) { | ||||||
|  | 	db, err := sql.Open("sqlite3", s.File) | ||||||
|  | 	if err != nil { | ||||||
|  | 		return nil, err | ||||||
|  | 	} | ||||||
|  | 	if s.File == ":memory:" { | ||||||
|  | 		// sqlite3 uses file locks to coordinate concurrent access. In memory | ||||||
|  | 		// doesn't support this, so limit the number of connections to 1. | ||||||
|  | 		db.SetMaxOpenConns(1) | ||||||
|  | 	} | ||||||
|  | 	c := &conn{db, flavorSQLite3} | ||||||
|  | 	if _, err := c.migrate(); err != nil { | ||||||
|  | 		return nil, fmt.Errorf("failed to perform migrations: %v", err) | ||||||
|  | 	} | ||||||
|  | 	return c, nil | ||||||
|  | } | ||||||
|  |  | ||||||
|  | const ( | ||||||
|  | 	sslDisable    = "disable" | ||||||
|  | 	sslRequire    = "require" | ||||||
|  | 	sslVerifyCA   = "verify-ca" | ||||||
|  | 	sslVerifyFull = "verify-full" | ||||||
|  | ) | ||||||
|  |  | ||||||
|  | // PostgresSSL represents SSL options for Postgres databases. | ||||||
|  | type PostgresSSL struct { | ||||||
|  | 	Mode   string | ||||||
|  | 	CAFile string | ||||||
|  | 	// Files for client auth. | ||||||
|  | 	KeyFile  string | ||||||
|  | 	CertFile string | ||||||
|  | } | ||||||
|  |  | ||||||
|  | // Postgres options for creating an SQL db. | ||||||
|  | type Postgres struct { | ||||||
|  | 	Database string | ||||||
|  | 	User     string | ||||||
|  | 	Password string | ||||||
|  | 	Host     string | ||||||
|  |  | ||||||
|  | 	SSL PostgresSSL `json:"ssl" yaml:"ssl"` | ||||||
|  |  | ||||||
|  | 	ConnectionTimeout int // Seconds | ||||||
|  | } | ||||||
|  |  | ||||||
|  | // Open creates a new storage implementation backed by Postgres. | ||||||
|  | func (p *Postgres) Open() (storage.Storage, error) { | ||||||
|  | 	return p.open() | ||||||
|  | } | ||||||
|  |  | ||||||
|  | func (p *Postgres) open() (*conn, error) { | ||||||
|  | 	v := url.Values{} | ||||||
|  | 	set := func(key, val string) { | ||||||
|  | 		if val != "" { | ||||||
|  | 			v.Set(key, val) | ||||||
|  | 		} | ||||||
|  | 	} | ||||||
|  | 	set("connect_timeout", strconv.Itoa(p.ConnectionTimeout)) | ||||||
|  | 	set("sslkey", p.SSL.KeyFile) | ||||||
|  | 	set("sslcert", p.SSL.CertFile) | ||||||
|  | 	set("sslrootcert", p.SSL.CAFile) | ||||||
|  | 	if p.SSL.Mode == "" { | ||||||
|  | 		// Assume the strictest mode if unspecified. | ||||||
|  | 		p.SSL.Mode = sslVerifyFull | ||||||
|  | 	} | ||||||
|  | 	set("sslmode", p.SSL.Mode) | ||||||
|  |  | ||||||
|  | 	u := url.URL{ | ||||||
|  | 		Scheme:   "postgres", | ||||||
|  | 		Host:     p.Host, | ||||||
|  | 		Path:     "/" + p.Database, | ||||||
|  | 		RawQuery: v.Encode(), | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	if p.User != "" { | ||||||
|  | 		if p.Password != "" { | ||||||
|  | 			u.User = url.UserPassword(p.User, p.Password) | ||||||
|  | 		} else { | ||||||
|  | 			u.User = url.User(p.User) | ||||||
|  | 		} | ||||||
|  | 	} | ||||||
|  | 	db, err := sql.Open("postgres", u.String()) | ||||||
|  | 	if err != nil { | ||||||
|  | 		return nil, err | ||||||
|  | 	} | ||||||
|  | 	c := &conn{db, flavorPostgres} | ||||||
|  | 	if _, err := c.migrate(); err != nil { | ||||||
|  | 		return nil, fmt.Errorf("failed to perform migrations: %v", err) | ||||||
|  | 	} | ||||||
|  | 	return c, nil | ||||||
|  | } | ||||||
							
								
								
									
										1
									
								
								storage/sql/config_test.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										1
									
								
								storage/sql/config_test.go
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1 @@ | |||||||
|  | package sql | ||||||
							
								
								
									
										487
									
								
								storage/sql/crud.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										487
									
								
								storage/sql/crud.go
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,487 @@ | |||||||
|  | package sql | ||||||
|  |  | ||||||
|  | import ( | ||||||
|  | 	"database/sql" | ||||||
|  | 	"database/sql/driver" | ||||||
|  | 	"encoding/json" | ||||||
|  | 	"errors" | ||||||
|  | 	"fmt" | ||||||
|  |  | ||||||
|  | 	"github.com/coreos/dex/storage" | ||||||
|  | ) | ||||||
|  |  | ||||||
|  | // TODO(ericchiang): The update, insert, and select methods queries are all | ||||||
|  | // very repetivite. Consider creating them programatically. | ||||||
|  |  | ||||||
|  | // keysRowID is the ID of the only row we expect to populate the "keys" table. | ||||||
|  | const keysRowID = "keys" | ||||||
|  |  | ||||||
|  | // encoder wraps the underlying value in a JSON marshaler which is automatically | ||||||
|  | // called by the database/sql package. | ||||||
|  | // | ||||||
|  | //		s := []string{"planes", "bears"} | ||||||
|  | //		err := db.Exec(`insert into t1 (id, things) values (1, $1)`, encoder(s)) | ||||||
|  | //		if err != nil { | ||||||
|  | //			// handle error | ||||||
|  | //		} | ||||||
|  | // | ||||||
|  | //		var r []byte | ||||||
|  | //		err = db.QueryRow(`select things from t1 where id = 1;`).Scan(&r) | ||||||
|  | //		if err != nil { | ||||||
|  | //			// handle error | ||||||
|  | //		} | ||||||
|  | //		fmt.Printf("%s\n", r) // ["planes","bears"] | ||||||
|  | // | ||||||
|  | func encoder(i interface{}) driver.Valuer { | ||||||
|  | 	return jsonEncoder{i} | ||||||
|  | } | ||||||
|  |  | ||||||
|  | // decoder wraps the underlying value in a JSON unmarshaler which can then be passed | ||||||
|  | // to a database Scan() method. | ||||||
|  | func decoder(i interface{}) sql.Scanner { | ||||||
|  | 	return jsonDecoder{i} | ||||||
|  | } | ||||||
|  |  | ||||||
|  | type jsonEncoder struct { | ||||||
|  | 	i interface{} | ||||||
|  | } | ||||||
|  |  | ||||||
|  | func (j jsonEncoder) Value() (driver.Value, error) { | ||||||
|  | 	b, err := json.Marshal(j.i) | ||||||
|  | 	if err != nil { | ||||||
|  | 		return nil, fmt.Errorf("marshal: %v", err) | ||||||
|  | 	} | ||||||
|  | 	return b, nil | ||||||
|  | } | ||||||
|  |  | ||||||
|  | type jsonDecoder struct { | ||||||
|  | 	i interface{} | ||||||
|  | } | ||||||
|  |  | ||||||
|  | func (j jsonDecoder) Scan(dest interface{}) error { | ||||||
|  | 	if dest == nil { | ||||||
|  | 		return errors.New("nil value") | ||||||
|  | 	} | ||||||
|  | 	b, ok := dest.([]byte) | ||||||
|  | 	if !ok { | ||||||
|  | 		return fmt.Errorf("expected []byte got %T", dest) | ||||||
|  | 	} | ||||||
|  | 	if err := json.Unmarshal(b, &j.i); err != nil { | ||||||
|  | 		return fmt.Errorf("unmarshal: %v", err) | ||||||
|  | 	} | ||||||
|  | 	return nil | ||||||
|  | } | ||||||
|  |  | ||||||
|  | // Abstract conn vs trans. | ||||||
|  | type querier interface { | ||||||
|  | 	QueryRow(query string, args ...interface{}) *sql.Row | ||||||
|  | } | ||||||
|  |  | ||||||
|  | // Abstract row vs rows. | ||||||
|  | type scanner interface { | ||||||
|  | 	Scan(dest ...interface{}) error | ||||||
|  | } | ||||||
|  |  | ||||||
|  | func (c *conn) CreateAuthRequest(a storage.AuthRequest) error { | ||||||
|  | 	_, err := c.Exec(` | ||||||
|  | 		insert into auth_request ( | ||||||
|  | 			id, client_id, response_types, scopes, redirect_uri, nonce, state, | ||||||
|  | 			force_approval_prompt, logged_in, | ||||||
|  | 			claims_user_id, claims_username, claims_email, claims_email_verified, | ||||||
|  | 			claims_groups, | ||||||
|  | 			connector_id, connector_data, | ||||||
|  | 			expiry | ||||||
|  | 		) | ||||||
|  | 		values ( | ||||||
|  | 			$1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12, $13, $14, $15, $16, $17 | ||||||
|  | 		); | ||||||
|  | 	`, | ||||||
|  | 		a.ID, a.ClientID, encoder(a.ResponseTypes), encoder(a.Scopes), a.RedirectURI, a.Nonce, a.State, | ||||||
|  | 		a.ForceApprovalPrompt, a.LoggedIn, | ||||||
|  | 		a.Claims.UserID, a.Claims.Username, a.Claims.Email, a.Claims.EmailVerified, | ||||||
|  | 		encoder(a.Claims.Groups), | ||||||
|  | 		a.ConnectorID, a.ConnectorData, | ||||||
|  | 		a.Expiry, | ||||||
|  | 	) | ||||||
|  | 	if err != nil { | ||||||
|  | 		return fmt.Errorf("insert auth request: %v", err) | ||||||
|  | 	} | ||||||
|  | 	return nil | ||||||
|  | } | ||||||
|  |  | ||||||
|  | func (c *conn) UpdateAuthRequest(id string, updater func(a storage.AuthRequest) (storage.AuthRequest, error)) error { | ||||||
|  | 	return c.ExecTx(func(tx *trans) error { | ||||||
|  | 		r, err := getAuthRequest(tx, id) | ||||||
|  | 		if err != nil { | ||||||
|  | 			return err | ||||||
|  | 		} | ||||||
|  |  | ||||||
|  | 		a, err := updater(r) | ||||||
|  | 		if err != nil { | ||||||
|  | 			return err | ||||||
|  | 		} | ||||||
|  | 		_, err = tx.Exec(` | ||||||
|  | 			update auth_request | ||||||
|  | 			set | ||||||
|  | 				client_id = $1, response_types = $2, scopes = $3, redirect_uri = $4, | ||||||
|  | 				nonce = $5, state = $6, force_approval_prompt = $7, logged_in = $8, | ||||||
|  | 				claims_user_id = $9, claims_username = $10, claims_email = $11, | ||||||
|  | 				claims_email_verified = $12, | ||||||
|  | 				claims_groups = $13, | ||||||
|  | 				connector_id = $14, connector_data = $15, | ||||||
|  | 				expiry = $16 | ||||||
|  | 			where id = $17; | ||||||
|  | 		`, | ||||||
|  | 			a.ClientID, encoder(a.ResponseTypes), encoder(a.Scopes), a.RedirectURI, a.Nonce, a.State, | ||||||
|  | 			a.ForceApprovalPrompt, a.LoggedIn, | ||||||
|  | 			a.Claims.UserID, a.Claims.Username, a.Claims.Email, a.Claims.EmailVerified, | ||||||
|  | 			encoder(a.Claims.Groups), | ||||||
|  | 			a.ConnectorID, a.ConnectorData, | ||||||
|  | 			a.Expiry, a.ID, | ||||||
|  | 		) | ||||||
|  | 		if err != nil { | ||||||
|  | 			return fmt.Errorf("update auth request: %v", err) | ||||||
|  | 		} | ||||||
|  | 		return nil | ||||||
|  | 	}) | ||||||
|  |  | ||||||
|  | } | ||||||
|  |  | ||||||
|  | func (c *conn) GetAuthRequest(id string) (storage.AuthRequest, error) { | ||||||
|  | 	return getAuthRequest(c, id) | ||||||
|  | } | ||||||
|  |  | ||||||
|  | func getAuthRequest(q querier, id string) (a storage.AuthRequest, err error) { | ||||||
|  | 	err = q.QueryRow(` | ||||||
|  | 		select  | ||||||
|  | 			id, client_id, response_types, scopes, redirect_uri, nonce, state, | ||||||
|  | 			force_approval_prompt, logged_in, | ||||||
|  | 			claims_user_id, claims_username, claims_email, claims_email_verified, | ||||||
|  | 			claims_groups, | ||||||
|  | 			connector_id, connector_data, expiry | ||||||
|  | 		from auth_request where id = $1; | ||||||
|  | 	`, id).Scan( | ||||||
|  | 		&a.ID, &a.ClientID, decoder(&a.ResponseTypes), decoder(&a.Scopes), &a.RedirectURI, &a.Nonce, &a.State, | ||||||
|  | 		&a.ForceApprovalPrompt, &a.LoggedIn, | ||||||
|  | 		&a.Claims.UserID, &a.Claims.Username, &a.Claims.Email, &a.Claims.EmailVerified, | ||||||
|  | 		decoder(&a.Claims.Groups), | ||||||
|  | 		&a.ConnectorID, &a.ConnectorData, &a.Expiry, | ||||||
|  | 	) | ||||||
|  | 	if err != nil { | ||||||
|  | 		if err == sql.ErrNoRows { | ||||||
|  | 			return a, storage.ErrNotFound | ||||||
|  | 		} | ||||||
|  | 		return a, fmt.Errorf("select auth request: %v", err) | ||||||
|  | 	} | ||||||
|  | 	return a, nil | ||||||
|  | } | ||||||
|  |  | ||||||
|  | func (c *conn) CreateAuthCode(a storage.AuthCode) error { | ||||||
|  | 	_, err := c.Exec(` | ||||||
|  | 		insert into auth_code ( | ||||||
|  | 			id, client_id, scopes, nonce, redirect_uri, | ||||||
|  | 			claims_user_id, claims_username, | ||||||
|  | 			claims_email, claims_email_verified, claims_groups, | ||||||
|  | 			connector_id, connector_data, | ||||||
|  | 			expiry | ||||||
|  | 		) | ||||||
|  | 		values ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12, $13); | ||||||
|  | 	`, | ||||||
|  | 		a.ID, a.ClientID, encoder(a.Scopes), a.Nonce, a.RedirectURI, a.Claims.UserID, | ||||||
|  | 		a.Claims.Username, a.Claims.Email, a.Claims.EmailVerified, encoder(a.Claims.Groups), | ||||||
|  | 		a.ConnectorID, a.ConnectorData, a.Expiry, | ||||||
|  | 	) | ||||||
|  | 	return err | ||||||
|  | } | ||||||
|  |  | ||||||
|  | func (c *conn) GetAuthCode(id string) (a storage.AuthCode, err error) { | ||||||
|  | 	err = c.QueryRow(` | ||||||
|  | 		select | ||||||
|  | 			id, client_id, scopes, nonce, redirect_uri, | ||||||
|  | 			claims_user_id, claims_username, | ||||||
|  | 			claims_email, claims_email_verified, claims_groups, | ||||||
|  | 			connector_id, connector_data, | ||||||
|  | 			expiry | ||||||
|  | 		from auth_code where id = $1; | ||||||
|  | 	`, id).Scan( | ||||||
|  | 		&a.ID, &a.ClientID, decoder(&a.Scopes), &a.Nonce, &a.RedirectURI, &a.Claims.UserID, | ||||||
|  | 		&a.Claims.Username, &a.Claims.Email, &a.Claims.EmailVerified, decoder(&a.Claims.Groups), | ||||||
|  | 		&a.ConnectorID, &a.ConnectorData, &a.Expiry, | ||||||
|  | 	) | ||||||
|  | 	if err != nil { | ||||||
|  | 		if err == sql.ErrNoRows { | ||||||
|  | 			return a, storage.ErrNotFound | ||||||
|  | 		} | ||||||
|  | 		return a, fmt.Errorf("select auth code: %v", err) | ||||||
|  | 	} | ||||||
|  | 	return a, nil | ||||||
|  | } | ||||||
|  |  | ||||||
|  | func (c *conn) CreateRefresh(r storage.RefreshToken) error { | ||||||
|  | 	_, err := c.Exec(` | ||||||
|  | 		insert into refresh_token ( | ||||||
|  | 			id, client_id, scopes, nonce, | ||||||
|  | 			claims_user_id, claims_username, claims_email, claims_email_verified, | ||||||
|  | 			claims_groups, | ||||||
|  | 			connector_id, connector_data | ||||||
|  | 		) | ||||||
|  | 		values ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11); | ||||||
|  | 	`, | ||||||
|  | 		r.RefreshToken, r.ClientID, encoder(r.Scopes), r.Nonce, | ||||||
|  | 		r.Claims.UserID, r.Claims.Username, r.Claims.Email, r.Claims.EmailVerified, | ||||||
|  | 		encoder(r.Claims.Groups), | ||||||
|  | 		r.ConnectorID, r.ConnectorData, | ||||||
|  | 	) | ||||||
|  | 	if err != nil { | ||||||
|  | 		return fmt.Errorf("insert refresh_token: %v", err) | ||||||
|  | 	} | ||||||
|  | 	return nil | ||||||
|  | } | ||||||
|  |  | ||||||
|  | func (c *conn) GetRefresh(id string) (storage.RefreshToken, error) { | ||||||
|  | 	return scanRefresh(c.QueryRow(` | ||||||
|  | 		select | ||||||
|  | 			id, client_id, scopes, nonce, | ||||||
|  | 			claims_user_id, claims_username, claims_email, claims_email_verified, | ||||||
|  | 			claims_groups, | ||||||
|  | 			connector_id, connector_data | ||||||
|  | 		from refresh_token where id = $1; | ||||||
|  | 	`, id)) | ||||||
|  | } | ||||||
|  |  | ||||||
|  | func (c *conn) ListRefreshTokens() ([]storage.RefreshToken, error) { | ||||||
|  | 	rows, err := c.Query(` | ||||||
|  | 		select | ||||||
|  | 			id, client_id, scopes, nonce, | ||||||
|  | 			claims_user_id, claims_username, claims_email, claims_email_verified, | ||||||
|  | 			claims_groups, | ||||||
|  | 			connector_id, connector_data | ||||||
|  | 		from refresh_token; | ||||||
|  | 	`) | ||||||
|  | 	if err != nil { | ||||||
|  | 		return nil, fmt.Errorf("query: %v", err) | ||||||
|  | 	} | ||||||
|  | 	var tokens []storage.RefreshToken | ||||||
|  | 	for rows.Next() { | ||||||
|  | 		r, err := scanRefresh(rows) | ||||||
|  | 		if err != nil { | ||||||
|  | 			return nil, err | ||||||
|  | 		} | ||||||
|  | 		tokens = append(tokens, r) | ||||||
|  | 	} | ||||||
|  | 	if err := rows.Err(); err != nil { | ||||||
|  | 		return nil, fmt.Errorf("scan: %v", err) | ||||||
|  | 	} | ||||||
|  | 	return tokens, nil | ||||||
|  | } | ||||||
|  |  | ||||||
|  | func scanRefresh(s scanner) (r storage.RefreshToken, err error) { | ||||||
|  | 	err = s.Scan( | ||||||
|  | 		&r.RefreshToken, &r.ClientID, decoder(&r.Scopes), &r.Nonce, | ||||||
|  | 		&r.Claims.UserID, &r.Claims.Username, &r.Claims.Email, &r.Claims.EmailVerified, | ||||||
|  | 		decoder(&r.Claims.Groups), | ||||||
|  | 		&r.ConnectorID, &r.ConnectorData, | ||||||
|  | 	) | ||||||
|  | 	if err != nil { | ||||||
|  | 		if err == sql.ErrNoRows { | ||||||
|  | 			return r, storage.ErrNotFound | ||||||
|  | 		} | ||||||
|  | 		return r, fmt.Errorf("scan refresh_token: %v", err) | ||||||
|  | 	} | ||||||
|  | 	return r, nil | ||||||
|  | } | ||||||
|  |  | ||||||
|  | func (c *conn) UpdateKeys(updater func(old storage.Keys) (storage.Keys, error)) error { | ||||||
|  | 	return c.ExecTx(func(tx *trans) error { | ||||||
|  | 		firstUpdate := false | ||||||
|  | 		// TODO(ericchiang): errors may cause a transaction be rolled back by the SQL | ||||||
|  | 		// server. Test this, and consider adding a COUNT() command beforehand. | ||||||
|  | 		old, err := getKeys(tx) | ||||||
|  | 		if err != nil { | ||||||
|  | 			if err != storage.ErrNotFound { | ||||||
|  | 				return fmt.Errorf("get keys: %v", err) | ||||||
|  | 			} | ||||||
|  | 			firstUpdate = true | ||||||
|  | 			old = storage.Keys{} | ||||||
|  | 		} | ||||||
|  |  | ||||||
|  | 		nk, err := updater(old) | ||||||
|  | 		if err != nil { | ||||||
|  | 			return err | ||||||
|  | 		} | ||||||
|  |  | ||||||
|  | 		if firstUpdate { | ||||||
|  | 			_, err = tx.Exec(` | ||||||
|  | 				insert into keys ( | ||||||
|  | 					id, verification_keys, signing_key, signing_key_pub, next_rotation | ||||||
|  | 				) | ||||||
|  | 				values ($1, $2, $3, $4, $5); | ||||||
|  | 			`, | ||||||
|  | 				keysRowID, encoder(nk.VerificationKeys), encoder(nk.SigningKey), | ||||||
|  | 				encoder(nk.SigningKeyPub), nk.NextRotation, | ||||||
|  | 			) | ||||||
|  | 			if err != nil { | ||||||
|  | 				return fmt.Errorf("insert: %v", err) | ||||||
|  | 			} | ||||||
|  | 		} else { | ||||||
|  | 			_, err = tx.Exec(` | ||||||
|  | 				update keys | ||||||
|  | 				set  | ||||||
|  | 				    verification_keys = $1, | ||||||
|  | 					signing_key = $2, | ||||||
|  | 					singing_key_pub = $3, | ||||||
|  | 					next_rotation = $4 | ||||||
|  | 				where id = $5; | ||||||
|  | 			`, | ||||||
|  | 				encoder(nk.VerificationKeys), encoder(nk.SigningKey), | ||||||
|  | 				encoder(nk.SigningKeyPub), nk.NextRotation, keysRowID, | ||||||
|  | 			) | ||||||
|  | 			if err != nil { | ||||||
|  | 				return fmt.Errorf("update: %v", err) | ||||||
|  | 			} | ||||||
|  | 		} | ||||||
|  | 		return nil | ||||||
|  | 	}) | ||||||
|  | } | ||||||
|  |  | ||||||
|  | func (c *conn) GetKeys() (keys storage.Keys, err error) { | ||||||
|  | 	return getKeys(c) | ||||||
|  | } | ||||||
|  |  | ||||||
|  | func getKeys(q querier) (keys storage.Keys, err error) { | ||||||
|  | 	err = q.QueryRow(` | ||||||
|  | 		select | ||||||
|  | 			verification_keys, signing_key, signing_key_pub, next_rotation | ||||||
|  | 		from keys | ||||||
|  | 		where id=$q | ||||||
|  | 	`, keysRowID).Scan( | ||||||
|  | 		decoder(&keys.VerificationKeys), decoder(&keys.SigningKey), | ||||||
|  | 		decoder(&keys.SigningKeyPub), &keys.NextRotation, | ||||||
|  | 	) | ||||||
|  | 	if err != nil { | ||||||
|  | 		if err == sql.ErrNoRows { | ||||||
|  | 			return keys, storage.ErrNotFound | ||||||
|  | 		} | ||||||
|  | 		return keys, fmt.Errorf("query keys: %v", err) | ||||||
|  | 	} | ||||||
|  | 	return keys, nil | ||||||
|  | } | ||||||
|  |  | ||||||
|  | func (c *conn) UpdateClient(id string, updater func(old storage.Client) (storage.Client, error)) error { | ||||||
|  | 	return c.ExecTx(func(tx *trans) error { | ||||||
|  | 		cli, err := getClient(tx, id) | ||||||
|  | 		if err != nil { | ||||||
|  | 			return err | ||||||
|  | 		} | ||||||
|  | 		nc, err := updater(cli) | ||||||
|  | 		if err != nil { | ||||||
|  | 			return err | ||||||
|  | 		} | ||||||
|  |  | ||||||
|  | 		_, err = tx.Exec(` | ||||||
|  | 			update client | ||||||
|  | 			set | ||||||
|  | 				secret = $1, | ||||||
|  | 				redirect_uris = $2, | ||||||
|  | 				trusted_peers = $3, | ||||||
|  | 				public = $4, | ||||||
|  | 				name = $5, | ||||||
|  | 				logo_url = $6 | ||||||
|  | 			where id = $7; | ||||||
|  | 		`, nc.Secret, encoder(nc.RedirectURIs), encoder(nc.TrustedPeers), nc.Public, nc.Name, nc.LogoURL, id, | ||||||
|  | 		) | ||||||
|  | 		if err != nil { | ||||||
|  | 			return fmt.Errorf("update client: %v", err) | ||||||
|  | 		} | ||||||
|  | 		return nil | ||||||
|  | 	}) | ||||||
|  | } | ||||||
|  |  | ||||||
|  | func (c *conn) CreateClient(cli storage.Client) error { | ||||||
|  | 	_, err := c.Exec(` | ||||||
|  | 		insert into client ( | ||||||
|  | 			id, secret, redirect_uris, trusted_peers, public, name, logo_url | ||||||
|  | 		) | ||||||
|  | 		values ($1, $2, $3, $4, $5, $6, $7); | ||||||
|  | 	`, | ||||||
|  | 		cli.ID, cli.Secret, encoder(cli.RedirectURIs), encoder(cli.TrustedPeers), | ||||||
|  | 		cli.Public, cli.Name, cli.LogoURL, | ||||||
|  | 	) | ||||||
|  | 	if err != nil { | ||||||
|  | 		return fmt.Errorf("insert client: %v", err) | ||||||
|  | 	} | ||||||
|  | 	return nil | ||||||
|  | } | ||||||
|  |  | ||||||
|  | func getClient(q querier, id string) (storage.Client, error) { | ||||||
|  | 	return scanClient(q.QueryRow(` | ||||||
|  | 		select | ||||||
|  | 			id, secret, redirect_uris, trusted_peers, public, name, logo_url | ||||||
|  | 	    from client where id = $1; | ||||||
|  | 	`, id)) | ||||||
|  | } | ||||||
|  |  | ||||||
|  | func (c *conn) GetClient(id string) (storage.Client, error) { | ||||||
|  | 	return getClient(c, id) | ||||||
|  | } | ||||||
|  |  | ||||||
|  | func (c *conn) ListClients() ([]storage.Client, error) { | ||||||
|  | 	rows, err := c.Query(` | ||||||
|  | 		select | ||||||
|  | 			id, secret, redirect_uris, trusted_peers, public, name, logo_url | ||||||
|  | 		from client; | ||||||
|  | 	`) | ||||||
|  | 	if err != nil { | ||||||
|  | 		return nil, err | ||||||
|  | 	} | ||||||
|  | 	var clients []storage.Client | ||||||
|  | 	for rows.Next() { | ||||||
|  | 		cli, err := scanClient(rows) | ||||||
|  | 		if err != nil { | ||||||
|  | 			return nil, err | ||||||
|  | 		} | ||||||
|  | 		clients = append(clients, cli) | ||||||
|  | 	} | ||||||
|  | 	if err := rows.Err(); err != nil { | ||||||
|  | 		return nil, err | ||||||
|  | 	} | ||||||
|  | 	return clients, nil | ||||||
|  | } | ||||||
|  |  | ||||||
|  | func scanClient(s scanner) (cli storage.Client, err error) { | ||||||
|  | 	err = s.Scan( | ||||||
|  | 		&cli.ID, &cli.Secret, decoder(&cli.RedirectURIs), decoder(&cli.TrustedPeers), | ||||||
|  | 		&cli.Public, &cli.Name, &cli.LogoURL, | ||||||
|  | 	) | ||||||
|  | 	if err != nil { | ||||||
|  | 		if err == sql.ErrNoRows { | ||||||
|  | 			return cli, storage.ErrNotFound | ||||||
|  | 		} | ||||||
|  | 		return cli, fmt.Errorf("get client: %v", err) | ||||||
|  | 	} | ||||||
|  | 	return cli, nil | ||||||
|  | } | ||||||
|  |  | ||||||
|  | func (c *conn) DeleteAuthRequest(id string) error { return c.delete("auth_request", id) } | ||||||
|  | func (c *conn) DeleteAuthCode(id string) error    { return c.delete("auth_code", id) } | ||||||
|  | func (c *conn) DeleteClient(id string) error      { return c.delete("client", id) } | ||||||
|  | func (c *conn) DeleteRefresh(id string) error     { return c.delete("refresh_token", id) } | ||||||
|  |  | ||||||
|  | // Do NOT call directly. Does not escape table. | ||||||
|  | func (c *conn) delete(table, id string) error { | ||||||
|  | 	result, err := c.Exec(`delete from `+table+` where id = $1`, id) | ||||||
|  | 	if err != nil { | ||||||
|  | 		return fmt.Errorf("delete %s: %v", table, id) | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	// For now mandate that the driver implements RowsAffected. If we ever need to support | ||||||
|  | 	// a driver that doesn't implement this, we can run this in a transaction with a get beforehand. | ||||||
|  | 	n, err := result.RowsAffected() | ||||||
|  | 	if err != nil { | ||||||
|  | 		return fmt.Errorf("rows affected: %v", err) | ||||||
|  | 	} | ||||||
|  | 	if n < 1 { | ||||||
|  | 		return storage.ErrNotFound | ||||||
|  | 	} | ||||||
|  | 	return nil | ||||||
|  | } | ||||||
							
								
								
									
										55
									
								
								storage/sql/crud_test.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										55
									
								
								storage/sql/crud_test.go
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,55 @@ | |||||||
|  | package sql | ||||||
|  |  | ||||||
|  | import ( | ||||||
|  | 	"database/sql" | ||||||
|  | 	"reflect" | ||||||
|  | 	"testing" | ||||||
|  | ) | ||||||
|  |  | ||||||
|  | func TestDecoder(t *testing.T) { | ||||||
|  | 	db, err := sql.Open("sqlite3", ":memory:") | ||||||
|  | 	if err != nil { | ||||||
|  | 		t.Fatal(err) | ||||||
|  | 	} | ||||||
|  | 	defer db.Close() | ||||||
|  |  | ||||||
|  | 	if _, err := db.Exec(`create table foo ( id integer primary key, bar blob );`); err != nil { | ||||||
|  | 		t.Fatal(err) | ||||||
|  | 	} | ||||||
|  | 	if _, err := db.Exec(`insert into foo ( id, bar ) values (1, ?);`, []byte(`["a", "b"]`)); err != nil { | ||||||
|  | 		t.Fatal(err) | ||||||
|  | 	} | ||||||
|  | 	var got []string | ||||||
|  | 	if err := db.QueryRow(`select bar from foo where id = 1;`).Scan(decoder(&got)); err != nil { | ||||||
|  | 		t.Fatal(err) | ||||||
|  | 	} | ||||||
|  | 	want := []string{"a", "b"} | ||||||
|  | 	if !reflect.DeepEqual(got, want) { | ||||||
|  | 		t.Errorf("wanted %q got %q", want, got) | ||||||
|  | 	} | ||||||
|  | } | ||||||
|  |  | ||||||
|  | func TestEncoder(t *testing.T) { | ||||||
|  | 	db, err := sql.Open("sqlite3", ":memory:") | ||||||
|  | 	if err != nil { | ||||||
|  | 		t.Fatal(err) | ||||||
|  | 	} | ||||||
|  | 	defer db.Close() | ||||||
|  |  | ||||||
|  | 	if _, err := db.Exec(`create table foo ( id integer primary key, bar blob );`); err != nil { | ||||||
|  | 		t.Fatal(err) | ||||||
|  | 	} | ||||||
|  | 	put := []string{"a", "b"} | ||||||
|  | 	if _, err := db.Exec(`insert into foo ( id, bar ) values (1, ?)`, encoder(put)); err != nil { | ||||||
|  | 		t.Fatal(err) | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	var got []byte | ||||||
|  | 	if err := db.QueryRow(`select bar from foo where id = 1;`).Scan(&got); err != nil { | ||||||
|  | 		t.Fatal(err) | ||||||
|  | 	} | ||||||
|  | 	want := []byte(`["a","b"]`) | ||||||
|  | 	if !reflect.DeepEqual(got, want) { | ||||||
|  | 		t.Errorf("wanted %q got %q", want, got) | ||||||
|  | 	} | ||||||
|  | } | ||||||
							
								
								
									
										24
									
								
								storage/sql/gc.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										24
									
								
								storage/sql/gc.go
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,24 @@ | |||||||
|  | package sql | ||||||
|  |  | ||||||
|  | import ( | ||||||
|  | 	"fmt" | ||||||
|  | 	"time" | ||||||
|  | ) | ||||||
|  |  | ||||||
|  | type gc struct { | ||||||
|  | 	now  func() time.Time | ||||||
|  | 	conn *conn | ||||||
|  | } | ||||||
|  |  | ||||||
|  | var tablesWithGC = []string{"auth_request", "auth_code"} | ||||||
|  |  | ||||||
|  | func (gc gc) run() error { | ||||||
|  | 	for _, table := range tablesWithGC { | ||||||
|  | 		_, err := gc.conn.Exec(`delete from `+table+` where expiry < $1`, gc.now()) | ||||||
|  | 		if err != nil { | ||||||
|  | 			return fmt.Errorf("gc %s: %v", table, err) | ||||||
|  | 		} | ||||||
|  | 		// TODO(ericchiang): when we have levelled logging print how many rows were gc'd | ||||||
|  | 	} | ||||||
|  | 	return nil | ||||||
|  | } | ||||||
							
								
								
									
										53
									
								
								storage/sql/gc_test.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										53
									
								
								storage/sql/gc_test.go
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,53 @@ | |||||||
|  | package sql | ||||||
|  |  | ||||||
|  | import ( | ||||||
|  | 	"testing" | ||||||
|  | 	"time" | ||||||
|  |  | ||||||
|  | 	"github.com/coreos/dex/storage" | ||||||
|  | ) | ||||||
|  |  | ||||||
|  | func TestGC(t *testing.T) { | ||||||
|  | 	// TODO(ericchiang): Add a GarbageCollect method to the storage interface so | ||||||
|  | 	// we can write conformance tests instead of directly testing each implementation. | ||||||
|  | 	s := &SQLite3{":memory:"} | ||||||
|  | 	conn, err := s.open() | ||||||
|  | 	if err != nil { | ||||||
|  | 		t.Fatal(err) | ||||||
|  | 	} | ||||||
|  | 	defer conn.Close() | ||||||
|  |  | ||||||
|  | 	clock := time.Now() | ||||||
|  | 	now := func() time.Time { return clock } | ||||||
|  |  | ||||||
|  | 	runGC := (gc{now, conn}).run | ||||||
|  |  | ||||||
|  | 	a := storage.AuthRequest{ | ||||||
|  | 		ID:     storage.NewID(), | ||||||
|  | 		Expiry: now().Add(time.Second), | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	if err := conn.CreateAuthRequest(a); err != nil { | ||||||
|  | 		t.Fatal(err) | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	if err := runGC(); err != nil { | ||||||
|  | 		t.Errorf("gc failed: %v", err) | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	if _, err := conn.GetAuthRequest(a.ID); err != nil { | ||||||
|  | 		t.Errorf("failed to get auth request after gc: %v", err) | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	clock = clock.Add(time.Minute) | ||||||
|  |  | ||||||
|  | 	if err := runGC(); err != nil { | ||||||
|  | 		t.Errorf("gc failed: %v", err) | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	if _, err := conn.GetAuthRequest(a.ID); err == nil { | ||||||
|  | 		t.Errorf("expected error after gc'ing auth request: %v", err) | ||||||
|  | 	} else if err != storage.ErrNotFound { | ||||||
|  | 		t.Errorf("expected error storage.NotFound got: %v", err) | ||||||
|  | 	} | ||||||
|  | } | ||||||
							
								
								
									
										151
									
								
								storage/sql/migrate.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										151
									
								
								storage/sql/migrate.go
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,151 @@ | |||||||
|  | package sql | ||||||
|  |  | ||||||
|  | import ( | ||||||
|  | 	"database/sql" | ||||||
|  | 	"fmt" | ||||||
|  | ) | ||||||
|  |  | ||||||
|  | func (c *conn) migrate() (int, error) { | ||||||
|  | 	_, err := c.Exec(` | ||||||
|  | 		create table if not exists migrations ( | ||||||
|  | 			num integer not null, | ||||||
|  | 			at timestamp not null | ||||||
|  | 		); | ||||||
|  | 	`) | ||||||
|  | 	if err != nil { | ||||||
|  | 		return 0, fmt.Errorf("creating migration table: %v", err) | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	i := 0 | ||||||
|  | 	done := false | ||||||
|  | 	for { | ||||||
|  | 		err := c.ExecTx(func(tx *trans) error { | ||||||
|  | 			// Within a transaction, perform a single migration. | ||||||
|  | 			var ( | ||||||
|  | 				num sql.NullInt64 | ||||||
|  | 				n   int | ||||||
|  | 			) | ||||||
|  | 			if err := tx.QueryRow(`select max(num) from migrations;`).Scan(&num); err != nil { | ||||||
|  | 				return fmt.Errorf("select max migration: %v", err) | ||||||
|  | 			} | ||||||
|  | 			if num.Valid { | ||||||
|  | 				n = int(num.Int64) | ||||||
|  | 			} | ||||||
|  | 			if n >= len(migrations) { | ||||||
|  | 				done = true | ||||||
|  | 				return nil | ||||||
|  | 			} | ||||||
|  |  | ||||||
|  | 			migrationNum := n + 1 | ||||||
|  | 			m := migrations[n] | ||||||
|  | 			if _, err := tx.Exec(m.stmt); err != nil { | ||||||
|  | 				return fmt.Errorf("migration %d failed: %v", migrationNum, err) | ||||||
|  | 			} | ||||||
|  |  | ||||||
|  | 			q := `insert into migrations (num, at) values ($1, now());` | ||||||
|  | 			if _, err := tx.Exec(q, migrationNum); err != nil { | ||||||
|  | 				return fmt.Errorf("update migration table: %v", err) | ||||||
|  | 			} | ||||||
|  | 			return nil | ||||||
|  | 		}) | ||||||
|  | 		if err != nil { | ||||||
|  | 			return i, err | ||||||
|  | 		} | ||||||
|  | 		if done { | ||||||
|  | 			break | ||||||
|  | 		} | ||||||
|  | 		i++ | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	return i, nil | ||||||
|  | } | ||||||
|  |  | ||||||
|  | type migration struct { | ||||||
|  | 	stmt string | ||||||
|  | 	// TODO(ericchiang): consider adding additional fields like "forDrivers" | ||||||
|  | } | ||||||
|  |  | ||||||
|  | // All SQL flavors share migration strategies. | ||||||
|  | var migrations = []migration{ | ||||||
|  | 	{ | ||||||
|  | 		stmt: ` | ||||||
|  | 			create table client ( | ||||||
|  | 				id text not null primary key, | ||||||
|  | 				secret text not null, | ||||||
|  | 				redirect_uris bytea not null, -- JSON array of strings | ||||||
|  | 				trusted_peers bytea not null, -- JSON array of strings | ||||||
|  | 				public boolean not null, | ||||||
|  | 				name text not null, | ||||||
|  | 				logo_url text not null | ||||||
|  | 			); | ||||||
|  | 		 | ||||||
|  | 			create table auth_request ( | ||||||
|  | 				id text not null primary key, | ||||||
|  | 				client_id text not null, | ||||||
|  | 				response_types bytea not null, -- JSON array of strings | ||||||
|  | 				scopes bytea not null,         -- JSON array of strings | ||||||
|  | 				redirect_uri text not null, | ||||||
|  | 				nonce text not null, | ||||||
|  | 				state text not null, | ||||||
|  | 				force_approval_prompt boolean not null, | ||||||
|  | 		 | ||||||
|  | 				logged_in boolean not null, | ||||||
|  | 		 | ||||||
|  | 				claims_user_id text not null, | ||||||
|  | 				claims_username text not null, | ||||||
|  | 				claims_email text not null, | ||||||
|  | 				claims_email_verified boolean not null, | ||||||
|  | 				claims_groups bytea not null, -- JSON array of strings | ||||||
|  | 		 | ||||||
|  | 				connector_id text not null, | ||||||
|  | 				connector_data bytea, | ||||||
|  | 		 | ||||||
|  | 				expiry timestamp not null | ||||||
|  | 			); | ||||||
|  | 		 | ||||||
|  | 			create table auth_code ( | ||||||
|  | 				id text not null primary key, | ||||||
|  | 				client_id text not null, | ||||||
|  | 				scopes bytea not null, -- JSON array of strings | ||||||
|  | 				nonce text not null, | ||||||
|  | 				redirect_uri text not null, | ||||||
|  | 		 | ||||||
|  | 				claims_user_id text not null, | ||||||
|  | 				claims_username text not null, | ||||||
|  | 				claims_email text not null, | ||||||
|  | 				claims_email_verified boolean not null, | ||||||
|  | 				claims_groups bytea not null, -- JSON array of strings | ||||||
|  | 		 | ||||||
|  | 				connector_id text not null, | ||||||
|  | 				connector_data bytea, | ||||||
|  | 		 | ||||||
|  | 				expiry timestamp not null | ||||||
|  | 			); | ||||||
|  | 		 | ||||||
|  | 			create table refresh_token ( | ||||||
|  | 				id text not null primary key, | ||||||
|  | 				client_id text not null, | ||||||
|  | 				scopes bytea not null, -- JSON array of strings | ||||||
|  | 				nonce text not null, | ||||||
|  | 		 | ||||||
|  | 				claims_user_id text not null, | ||||||
|  | 				claims_username text not null, | ||||||
|  | 				claims_email text not null, | ||||||
|  | 				claims_email_verified boolean not null, | ||||||
|  | 				claims_groups bytea not null, -- JSON array of strings | ||||||
|  | 		 | ||||||
|  | 				connector_id text not null, | ||||||
|  | 				connector_data bytea | ||||||
|  | 			); | ||||||
|  | 		 | ||||||
|  | 			-- keys is a weird table because we only ever expect there to be a single row | ||||||
|  | 			create table keys ( | ||||||
|  | 				id text not null primary key, | ||||||
|  | 				verification_keys bytea not null, -- JSON array | ||||||
|  | 				signing_key bytea not null,       -- JSON object | ||||||
|  | 				signing_key_pub bytea not null,   -- JSON object | ||||||
|  | 				next_rotation timestamp not null | ||||||
|  | 			); | ||||||
|  | 		`, | ||||||
|  | 	}, | ||||||
|  | } | ||||||
							
								
								
									
										25
									
								
								storage/sql/migrate_test.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										25
									
								
								storage/sql/migrate_test.go
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,25 @@ | |||||||
|  | package sql | ||||||
|  |  | ||||||
|  | import ( | ||||||
|  | 	"database/sql" | ||||||
|  | 	"testing" | ||||||
|  | ) | ||||||
|  |  | ||||||
|  | func TestMigrate(t *testing.T) { | ||||||
|  | 	db, err := sql.Open("sqlite3", ":memory:") | ||||||
|  | 	if err != nil { | ||||||
|  | 		t.Fatal(err) | ||||||
|  | 	} | ||||||
|  | 	defer db.Close() | ||||||
|  |  | ||||||
|  | 	c := &conn{db, flavorSQLite3} | ||||||
|  | 	for _, want := range []int{len(migrations), 0} { | ||||||
|  | 		got, err := c.migrate() | ||||||
|  | 		if err != nil { | ||||||
|  | 			t.Fatal(err) | ||||||
|  | 		} | ||||||
|  | 		if got != want { | ||||||
|  | 			t.Errorf("expected %d migrations, got %d", want, got) | ||||||
|  | 		} | ||||||
|  | 	} | ||||||
|  | } | ||||||
							
								
								
									
										152
									
								
								storage/sql/sql.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										152
									
								
								storage/sql/sql.go
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,152 @@ | |||||||
|  | // Package sql provides SQL implementations of the storage interface. | ||||||
|  | package sql | ||||||
|  |  | ||||||
|  | import ( | ||||||
|  | 	"database/sql" | ||||||
|  | 	"regexp" | ||||||
|  |  | ||||||
|  | 	"github.com/cockroachdb/cockroach-go/crdb" | ||||||
|  |  | ||||||
|  | 	// import third party drivers | ||||||
|  | 	_ "github.com/go-sql-driver/mysql" | ||||||
|  | 	_ "github.com/lib/pq" | ||||||
|  | 	_ "github.com/mattn/go-sqlite3" | ||||||
|  | ) | ||||||
|  |  | ||||||
|  | // flavor represents a specific SQL implementation, and is used to translate query strings | ||||||
|  | // between different drivers. Flavors shouldn't aim to translate all possible SQL statements, | ||||||
|  | // only the specific queries used by the SQL storages. | ||||||
|  | type flavor struct { | ||||||
|  | 	queryReplacers []replacer | ||||||
|  |  | ||||||
|  | 	// Optional function to create and finish a transaction. This is mainly for | ||||||
|  | 	// cockroachdb support which requires special retry logic provided by their | ||||||
|  | 	// client package. | ||||||
|  | 	// | ||||||
|  | 	// This will be nil for most flavors. | ||||||
|  | 	// | ||||||
|  | 	// See: https://github.com/cockroachdb/docs/blob/63761c2e/_includes/app/txn-sample.go#L41-L44 | ||||||
|  | 	executeTx func(db *sql.DB, fn func(*sql.Tx) error) error | ||||||
|  | } | ||||||
|  |  | ||||||
|  | // A regexp with a replacement string. | ||||||
|  | type replacer struct { | ||||||
|  | 	re   *regexp.Regexp | ||||||
|  | 	with string | ||||||
|  | } | ||||||
|  |  | ||||||
|  | // Match a postgres query binds. E.g. "$1", "$12", etc. | ||||||
|  | var bindRegexp = regexp.MustCompile(`\$\d+`) | ||||||
|  |  | ||||||
|  | func matchLiteral(s string) *regexp.Regexp { | ||||||
|  | 	return regexp.MustCompile(`\b` + regexp.QuoteMeta(s) + `\b`) | ||||||
|  | } | ||||||
|  |  | ||||||
|  | var ( | ||||||
|  | 	// The "github.com/lib/pq" driver is the default flavor. All others are | ||||||
|  | 	// translations of this. | ||||||
|  | 	flavorPostgres = flavor{} | ||||||
|  |  | ||||||
|  | 	flavorSQLite3 = flavor{ | ||||||
|  | 		queryReplacers: []replacer{ | ||||||
|  | 			{bindRegexp, "?"}, | ||||||
|  | 			// Translate for booleans to integers. | ||||||
|  | 			{matchLiteral("true"), "1"}, | ||||||
|  | 			{matchLiteral("false"), "0"}, | ||||||
|  | 			{matchLiteral("boolean"), "integer"}, | ||||||
|  | 			// Translate other types. | ||||||
|  | 			{matchLiteral("bytea"), "blob"}, | ||||||
|  | 			// {matchLiteral("timestamp"), "integer"}, | ||||||
|  | 			// SQLite doesn't have a "now()" method, replace with "date('now')" | ||||||
|  | 			{regexp.MustCompile(`\bnow\(\)`), "date('now')"}, | ||||||
|  | 		}, | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	// Incomplete. | ||||||
|  | 	flavorMySQL = flavor{ | ||||||
|  | 		queryReplacers: []replacer{ | ||||||
|  | 			{bindRegexp, "?"}, | ||||||
|  | 		}, | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	// Not tested. | ||||||
|  | 	flavorCockroach = flavor{ | ||||||
|  | 		executeTx: crdb.ExecuteTx, | ||||||
|  | 	} | ||||||
|  | ) | ||||||
|  |  | ||||||
|  | func (f flavor) translate(query string) string { | ||||||
|  | 	// TODO(ericchiang): Heavy cashing. | ||||||
|  | 	for _, r := range f.queryReplacers { | ||||||
|  | 		query = r.re.ReplaceAllString(query, r.with) | ||||||
|  | 	} | ||||||
|  | 	return query | ||||||
|  | } | ||||||
|  |  | ||||||
|  | // conn is the main database connection. | ||||||
|  | type conn struct { | ||||||
|  | 	db     *sql.DB | ||||||
|  | 	flavor flavor | ||||||
|  | } | ||||||
|  |  | ||||||
|  | func (c *conn) Close() error { | ||||||
|  | 	return c.db.Close() | ||||||
|  | } | ||||||
|  |  | ||||||
|  | // conn implements the same method signatures as encoding/sql.DB. | ||||||
|  |  | ||||||
|  | func (c *conn) Exec(query string, args ...interface{}) (sql.Result, error) { | ||||||
|  | 	query = c.flavor.translate(query) | ||||||
|  | 	return c.db.Exec(query, args...) | ||||||
|  | } | ||||||
|  |  | ||||||
|  | func (c *conn) Query(query string, args ...interface{}) (*sql.Rows, error) { | ||||||
|  | 	query = c.flavor.translate(query) | ||||||
|  | 	return c.db.Query(query, args...) | ||||||
|  | } | ||||||
|  |  | ||||||
|  | func (c *conn) QueryRow(query string, args ...interface{}) *sql.Row { | ||||||
|  | 	query = c.flavor.translate(query) | ||||||
|  | 	return c.db.QueryRow(query, args...) | ||||||
|  | } | ||||||
|  |  | ||||||
|  | // ExecTx runs a method which operates on a transaction. | ||||||
|  | func (c *conn) ExecTx(fn func(tx *trans) error) error { | ||||||
|  | 	if c.flavor.executeTx != nil { | ||||||
|  | 		return c.flavor.executeTx(c.db, func(sqlTx *sql.Tx) error { | ||||||
|  | 			return fn(&trans{sqlTx, c}) | ||||||
|  | 		}) | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	sqlTx, err := c.db.Begin() | ||||||
|  | 	if err != nil { | ||||||
|  | 		return err | ||||||
|  | 	} | ||||||
|  | 	if err := fn(&trans{sqlTx, c}); err != nil { | ||||||
|  | 		sqlTx.Rollback() | ||||||
|  | 		return err | ||||||
|  | 	} | ||||||
|  | 	return sqlTx.Commit() | ||||||
|  | } | ||||||
|  |  | ||||||
|  | type trans struct { | ||||||
|  | 	tx *sql.Tx | ||||||
|  | 	c  *conn | ||||||
|  | } | ||||||
|  |  | ||||||
|  | // trans implements the same method signatures as encoding/sql.Tx. | ||||||
|  |  | ||||||
|  | func (t *trans) Exec(query string, args ...interface{}) (sql.Result, error) { | ||||||
|  | 	query = t.c.flavor.translate(query) | ||||||
|  | 	return t.tx.Exec(query, args...) | ||||||
|  | } | ||||||
|  |  | ||||||
|  | func (t *trans) Query(query string, args ...interface{}) (*sql.Rows, error) { | ||||||
|  | 	query = t.c.flavor.translate(query) | ||||||
|  | 	return t.tx.Query(query, args...) | ||||||
|  | } | ||||||
|  |  | ||||||
|  | func (t *trans) QueryRow(query string, args ...interface{}) *sql.Row { | ||||||
|  | 	query = t.c.flavor.translate(query) | ||||||
|  | 	return t.tx.QueryRow(query, args...) | ||||||
|  | } | ||||||
							
								
								
									
										55
									
								
								storage/sql/sql_test.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										55
									
								
								storage/sql/sql_test.go
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,55 @@ | |||||||
|  | package sql | ||||||
|  |  | ||||||
|  | import "testing" | ||||||
|  |  | ||||||
|  | func TestTranslate(t *testing.T) { | ||||||
|  | 	tests := []struct { | ||||||
|  | 		testCase string | ||||||
|  | 		flavor   flavor | ||||||
|  | 		query    string | ||||||
|  | 		exp      string | ||||||
|  | 	}{ | ||||||
|  | 		{ | ||||||
|  | 			"sqlite3 query bind replacement", | ||||||
|  | 			flavorSQLite3, | ||||||
|  | 			`select foo from bar where foo.zam = $1;`, | ||||||
|  | 			`select foo from bar where foo.zam = ?;`, | ||||||
|  | 		}, | ||||||
|  | 		{ | ||||||
|  | 			"sqlite3 query bind replacement at newline", | ||||||
|  | 			flavorSQLite3, | ||||||
|  | 			`select foo from bar where foo.zam = $1`, | ||||||
|  | 			`select foo from bar where foo.zam = ?`, | ||||||
|  | 		}, | ||||||
|  | 		{ | ||||||
|  | 			"sqlite3 query true", | ||||||
|  | 			flavorSQLite3, | ||||||
|  | 			`select foo from bar where foo.zam = true`, | ||||||
|  | 			`select foo from bar where foo.zam = 1`, | ||||||
|  | 		}, | ||||||
|  | 		{ | ||||||
|  | 			"sqlite3 query false", | ||||||
|  | 			flavorSQLite3, | ||||||
|  | 			`select foo from bar where foo.zam = false`, | ||||||
|  | 			`select foo from bar where foo.zam = 0`, | ||||||
|  | 		}, | ||||||
|  | 		{ | ||||||
|  | 			"sqlite3 bytea", | ||||||
|  | 			flavorSQLite3, | ||||||
|  | 			`"connector_data" bytea not null,`, | ||||||
|  | 			`"connector_data" blob not null,`, | ||||||
|  | 		}, | ||||||
|  | 		{ | ||||||
|  | 			"sqlite3 now", | ||||||
|  | 			flavorSQLite3, | ||||||
|  | 			`now(),`, | ||||||
|  | 			`date('now'),`, | ||||||
|  | 		}, | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	for _, tc := range tests { | ||||||
|  | 		if got := tc.flavor.translate(tc.query); got != tc.exp { | ||||||
|  | 			t.Errorf("%s: want=%q, got=%q", tc.testCase, tc.exp, got) | ||||||
|  | 		} | ||||||
|  | 	} | ||||||
|  | } | ||||||
		Reference in New Issue
	
	Block a user