From 7e960214287ad86acf84c31f7b5952d4bf5fd1b9 Mon Sep 17 00:00:00 2001 From: Alex Suraci Date: Thu, 15 Nov 2018 13:17:42 -0500 Subject: [PATCH 1/6] retry on serialization errors --- storage/sql/sql.go | 39 +++++++++++++++++++++++++++------------ 1 file changed, 27 insertions(+), 12 deletions(-) diff --git a/storage/sql/sql.go b/storage/sql/sql.go index dc6be4a1..ddcfae9e 100644 --- a/storage/sql/sql.go +++ b/storage/sql/sql.go @@ -6,10 +6,10 @@ import ( "regexp" "time" + "github.com/lib/pq" "github.com/sirupsen/logrus" // import third party drivers - _ "github.com/lib/pq" _ "github.com/mattn/go-sqlite3" ) @@ -51,19 +51,34 @@ var ( // NOTE(ericchiang): For some reason using `SET SESSION CHARACTERISTICS AS TRANSACTION` at a // session level didn't work for some edge cases. Might be something worth exploring. executeTx: func(db *sql.DB, fn func(sqlTx *sql.Tx) error) error { - tx, err := db.Begin() - if err != nil { - return err - } - defer tx.Rollback() + for { + tx, err := db.Begin() + if err != nil { + return err + } - if _, err := tx.Exec(`SET TRANSACTION ISOLATION LEVEL SERIALIZABLE;`); err != nil { - return err + defer tx.Rollback() + + if _, err := tx.Exec(`SET TRANSACTION ISOLATION LEVEL SERIALIZABLE;`); err != nil { + return err + } + + if err := fn(tx); err != nil { + return err + } + + err = tx.Commit() + if err != nil { + if pqErr, ok := err.(*pq.Error); ok && pqErr.Code == "40001" { + // serialization error; retry + continue + } + + return err + } + + return nil } - if err := fn(tx); err != nil { - return err - } - return tx.Commit() }, supportsTimezones: true, From 9b9013a560a098fe0597fb337bbbb2b4f225fd44 Mon Sep 17 00:00:00 2001 From: Alex Suraci Date: Fri, 16 Nov 2018 19:35:47 +0000 Subject: [PATCH 2/6] postgres: use stdlib to set serializable tx level also use a context for the rollback, which is a bit cleaner since it only results in one 'defer', rather than N from the loop --- storage/sql/sql.go | 16 +++++++++------- 1 file changed, 9 insertions(+), 7 deletions(-) diff --git a/storage/sql/sql.go b/storage/sql/sql.go index ddcfae9e..2ddccab7 100644 --- a/storage/sql/sql.go +++ b/storage/sql/sql.go @@ -2,6 +2,7 @@ package sql import ( + "context" "database/sql" "regexp" "time" @@ -51,18 +52,19 @@ var ( // NOTE(ericchiang): For some reason using `SET SESSION CHARACTERISTICS AS TRANSACTION` at a // session level didn't work for some edge cases. Might be something worth exploring. executeTx: func(db *sql.DB, fn func(sqlTx *sql.Tx) error) error { + ctx, cancel := context.WithCancel(context.TODO()) + defer cancel() + + opts := &sql.TxOptions{ + Isolation: sql.LevelSerializable, + } + for { - tx, err := db.Begin() + tx, err := db.BeginTx(ctx, opts) if err != nil { return err } - defer tx.Rollback() - - if _, err := tx.Exec(`SET TRANSACTION ISOLATION LEVEL SERIALIZABLE;`); err != nil { - return err - } - if err := fn(tx); err != nil { return err } From aa068b667aff461f87d49cc0227e162ad8cbadfb Mon Sep 17 00:00:00 2001 From: Alex Suraci Date: Fri, 16 Nov 2018 19:36:56 +0000 Subject: [PATCH 3/6] postgres: improve readability of error check --- storage/sql/sql.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/storage/sql/sql.go b/storage/sql/sql.go index 2ddccab7..7f20cf9d 100644 --- a/storage/sql/sql.go +++ b/storage/sql/sql.go @@ -71,7 +71,7 @@ var ( err = tx.Commit() if err != nil { - if pqErr, ok := err.(*pq.Error); ok && pqErr.Code == "40001" { + if pqErr, ok := err.(*pq.Error); ok && pqErr.Code.Name() == "serialization_failure" { // serialization error; retry continue } From 5d67da147298737e1273567dea15ac3478f5ea8d Mon Sep 17 00:00:00 2001 From: Alex Suraci Date: Fri, 16 Nov 2018 19:49:39 +0000 Subject: [PATCH 4/6] bump lib/pq --- glide.lock | 6 +- glide.yaml | 2 +- vendor/github.com/lib/pq/array.go | 51 +- vendor/github.com/lib/pq/conn.go | 556 ++++++++++---------- vendor/github.com/lib/pq/conn_go18.go | 129 +++++ vendor/github.com/lib/pq/connector.go | 43 ++ vendor/github.com/lib/pq/copy.go | 29 +- vendor/github.com/lib/pq/doc.go | 53 +- vendor/github.com/lib/pq/encode.go | 20 +- vendor/github.com/lib/pq/error.go | 9 +- vendor/github.com/lib/pq/notify.go | 65 ++- vendor/github.com/lib/pq/oid/gen.go | 59 ++- vendor/github.com/lib/pq/oid/types.go | 184 ++++++- vendor/github.com/lib/pq/rows.go | 93 ++++ vendor/github.com/lib/pq/ssl.go | 175 ++++++ vendor/github.com/lib/pq/ssl_permissions.go | 20 + vendor/github.com/lib/pq/ssl_windows.go | 9 + vendor/github.com/lib/pq/uuid.go | 23 + 18 files changed, 1162 insertions(+), 364 deletions(-) create mode 100644 vendor/github.com/lib/pq/conn_go18.go create mode 100644 vendor/github.com/lib/pq/connector.go create mode 100644 vendor/github.com/lib/pq/rows.go create mode 100644 vendor/github.com/lib/pq/ssl.go create mode 100644 vendor/github.com/lib/pq/ssl_permissions.go create mode 100644 vendor/github.com/lib/pq/ssl_windows.go create mode 100644 vendor/github.com/lib/pq/uuid.go diff --git a/glide.lock b/glide.lock index f33cfe75..e4b00be5 100644 --- a/glide.lock +++ b/glide.lock @@ -1,5 +1,5 @@ -hash: e5972bbdf15ad612d99ce8cd34e19537b9eacb5ff53688f339e0da285eb8ec22 -updated: 2018-11-12T19:38:56.235070564+01:00 +hash: 70e399f3424964c1535cefb66bce0e47af25ea6bb0f32a254e83e91bd774b5f2 +updated: 2018-11-20T09:49:19.83565589-05:00 imports: - name: github.com/beevik/etree version: 4cd0dd976db869f817248477718071a28e978df0 @@ -54,7 +54,7 @@ imports: - diff - pretty - name: github.com/lib/pq - version: 50761b0867bd1d9d069276790bcd4a3bccf2324a + version: 9eb73efc1fcc404148b56765b0d3f61d9a5ef8ee subpackages: - oid - name: github.com/mattn/go-sqlite3 diff --git a/glide.yaml b/glide.yaml index b8f459be..e4909ff4 100644 --- a/glide.yaml +++ b/glide.yaml @@ -114,7 +114,7 @@ import: - package: github.com/mattn/go-sqlite3 version: 3fb7a0e792edd47bf0cf1e919dfc14e2be412e15 - package: github.com/lib/pq - version: 50761b0867bd1d9d069276790bcd4a3bccf2324a + version: 9eb73efc1fcc404148b56765b0d3f61d9a5ef8ee # etcd driver - package: github.com/coreos/etcd diff --git a/vendor/github.com/lib/pq/array.go b/vendor/github.com/lib/pq/array.go index 27eb07a9..e4933e22 100644 --- a/vendor/github.com/lib/pq/array.go +++ b/vendor/github.com/lib/pq/array.go @@ -13,7 +13,7 @@ import ( var typeByteSlice = reflect.TypeOf([]byte{}) var typeDriverValuer = reflect.TypeOf((*driver.Valuer)(nil)).Elem() -var typeSqlScanner = reflect.TypeOf((*sql.Scanner)(nil)).Elem() +var typeSQLScanner = reflect.TypeOf((*sql.Scanner)(nil)).Elem() // Array returns the optimal driver.Valuer and sql.Scanner for an array or // slice of any dimension. @@ -70,6 +70,9 @@ func (a *BoolArray) Scan(src interface{}) error { return a.scanBytes(src) case string: return a.scanBytes([]byte(src)) + case nil: + *a = nil + return nil } return fmt.Errorf("pq: cannot convert %T to BoolArray", src) @@ -80,7 +83,7 @@ func (a *BoolArray) scanBytes(src []byte) error { if err != nil { return err } - if len(elems) == 0 { + if *a != nil && len(elems) == 0 { *a = (*a)[:0] } else { b := make(BoolArray, len(elems)) @@ -141,6 +144,9 @@ func (a *ByteaArray) Scan(src interface{}) error { return a.scanBytes(src) case string: return a.scanBytes([]byte(src)) + case nil: + *a = nil + return nil } return fmt.Errorf("pq: cannot convert %T to ByteaArray", src) @@ -151,7 +157,7 @@ func (a *ByteaArray) scanBytes(src []byte) error { if err != nil { return err } - if len(elems) == 0 { + if *a != nil && len(elems) == 0 { *a = (*a)[:0] } else { b := make(ByteaArray, len(elems)) @@ -210,6 +216,9 @@ func (a *Float64Array) Scan(src interface{}) error { return a.scanBytes(src) case string: return a.scanBytes([]byte(src)) + case nil: + *a = nil + return nil } return fmt.Errorf("pq: cannot convert %T to Float64Array", src) @@ -220,7 +229,7 @@ func (a *Float64Array) scanBytes(src []byte) error { if err != nil { return err } - if len(elems) == 0 { + if *a != nil && len(elems) == 0 { *a = (*a)[:0] } else { b := make(Float64Array, len(elems)) @@ -269,7 +278,7 @@ func (GenericArray) evaluateDestination(rt reflect.Type) (reflect.Type, func([]b // TODO calculate the assign function for other types // TODO repeat this section on the element type of arrays or slices (multidimensional) { - if reflect.PtrTo(rt).Implements(typeSqlScanner) { + if reflect.PtrTo(rt).Implements(typeSQLScanner) { // dest is always addressable because it is an element of a slice. assign = func(src []byte, dest reflect.Value) (err error) { ss := dest.Addr().Interface().(sql.Scanner) @@ -320,6 +329,11 @@ func (a GenericArray) Scan(src interface{}) error { return a.scanBytes(src, dv) case string: return a.scanBytes([]byte(src), dv) + case nil: + if dv.Kind() == reflect.Slice { + dv.Set(reflect.Zero(dv.Type())) + return nil + } } return fmt.Errorf("pq: cannot convert %T to %s", src, dv.Type()) @@ -386,7 +400,13 @@ func (a GenericArray) Value() (driver.Value, error) { rv := reflect.ValueOf(a.A) - if k := rv.Kind(); k != reflect.Array && k != reflect.Slice { + switch rv.Kind() { + case reflect.Slice: + if rv.IsNil() { + return nil, nil + } + case reflect.Array: + default: return nil, fmt.Errorf("pq: Unable to convert %T to array", a.A) } @@ -412,6 +432,9 @@ func (a *Int64Array) Scan(src interface{}) error { return a.scanBytes(src) case string: return a.scanBytes([]byte(src)) + case nil: + *a = nil + return nil } return fmt.Errorf("pq: cannot convert %T to Int64Array", src) @@ -422,7 +445,7 @@ func (a *Int64Array) scanBytes(src []byte) error { if err != nil { return err } - if len(elems) == 0 { + if *a != nil && len(elems) == 0 { *a = (*a)[:0] } else { b := make(Int64Array, len(elems)) @@ -470,6 +493,9 @@ func (a *StringArray) Scan(src interface{}) error { return a.scanBytes(src) case string: return a.scanBytes([]byte(src)) + case nil: + *a = nil + return nil } return fmt.Errorf("pq: cannot convert %T to StringArray", src) @@ -480,7 +506,7 @@ func (a *StringArray) scanBytes(src []byte) error { if err != nil { return err } - if len(elems) == 0 { + if *a != nil && len(elems) == 0 { *a = (*a)[:0] } else { b := make(StringArray, len(elems)) @@ -561,7 +587,7 @@ func appendArrayElement(b []byte, rv reflect.Value) ([]byte, string, error) { } } - var del string = "," + var del = "," var err error var iv interface{} = rv.Interface() @@ -639,6 +665,9 @@ Element: for i < len(src) { switch src[i] { case '{': + if depth == len(dims) { + break Element + } depth++ dims[depth-1] = 0 i++ @@ -680,11 +709,11 @@ Element: } for i < len(src) { - if bytes.HasPrefix(src[i:], del) { + if bytes.HasPrefix(src[i:], del) && depth > 0 { dims[depth-1]++ i += len(del) goto Element - } else if src[i] == '}' { + } else if src[i] == '}' && depth > 0 { dims[depth-1]++ depth-- i++ diff --git a/vendor/github.com/lib/pq/conn.go b/vendor/github.com/lib/pq/conn.go index 8e1aee9f..43c8df29 100644 --- a/vendor/github.com/lib/pq/conn.go +++ b/vendor/github.com/lib/pq/conn.go @@ -3,15 +3,12 @@ package pq import ( "bufio" "crypto/md5" - "crypto/tls" - "crypto/x509" "database/sql" "database/sql/driver" "encoding/binary" "errors" "fmt" "io" - "io/ioutil" "net" "os" "os/user" @@ -30,18 +27,26 @@ var ( ErrNotSupported = errors.New("pq: Unsupported command") ErrInFailedTransaction = errors.New("pq: Could not complete operation in a failed transaction") ErrSSLNotSupported = errors.New("pq: SSL is not enabled on the server") - ErrSSLKeyHasWorldPermissions = errors.New("pq: Private key file has group or world access. Permissions should be u=rw (0600) or less.") - ErrCouldNotDetectUsername = errors.New("pq: Could not detect default username. Please provide one explicitly.") + ErrSSLKeyHasWorldPermissions = errors.New("pq: Private key file has group or world access. Permissions should be u=rw (0600) or less") + ErrCouldNotDetectUsername = errors.New("pq: Could not detect default username. Please provide one explicitly") + + errUnexpectedReady = errors.New("unexpected ReadyForQuery") + errNoRowsAffected = errors.New("no RowsAffected available after the empty statement") + errNoLastInsertID = errors.New("no LastInsertId available after the empty statement") ) -type drv struct{} +// Driver is the Postgres database driver. +type Driver struct{} -func (d *drv) Open(name string) (driver.Conn, error) { +// Open opens a new connection to the database. name is a connection string. +// Most users should only use it through database/sql package from the standard +// library. +func (d *Driver) Open(name string) (driver.Conn, error) { return Open(name) } func init() { - sql.Register("postgres", &drv{}) + sql.Register("postgres", &Driver{}) } type parameterStatus struct { @@ -77,6 +82,8 @@ func (s transactionStatus) String() string { panic("not reached") } +// Dialer is the dialer interface. It can be used to obtain more control over +// how pq creates network connections. type Dialer interface { Dial(network, address string) (net.Conn, error) DialTimeout(network, address string, timeout time.Duration) (net.Conn, error) @@ -97,6 +104,15 @@ type conn struct { namei int scratch [512]byte txnStatus transactionStatus + txnFinish func() + + // Save connection arguments to use during CancelRequest. + dialer Dialer + opts values + + // Cancellation key data for use with CancelRequest messages. + processID int + secretKey int parameterStatus parameterStatus @@ -115,12 +131,15 @@ type conn struct { // Whether to always send []byte parameters over as binary. Enables single // round-trip mode for non-prepared Query calls. binaryParameters bool + + // If true this connection is in the middle of a COPY + inCopy bool } // Handle driver-side settings in parsed connection string. -func (c *conn) handleDriverSettings(o values) (err error) { +func (cn *conn) handleDriverSettings(o values) (err error) { boolSetting := func(key string, val *bool) error { - if value := o.Get(key); value != "" { + if value, ok := o[key]; ok { if value == "yes" { *val = true } else if value == "no" { @@ -132,32 +151,32 @@ func (c *conn) handleDriverSettings(o values) (err error) { return nil } - err = boolSetting("disable_prepared_binary_result", &c.disablePreparedBinaryResult) + err = boolSetting("disable_prepared_binary_result", &cn.disablePreparedBinaryResult) if err != nil { return err } - err = boolSetting("binary_parameters", &c.binaryParameters) - if err != nil { - return err - } - return nil + return boolSetting("binary_parameters", &cn.binaryParameters) } -func (c *conn) handlePgpass(o values) { +func (cn *conn) handlePgpass(o values) { // if a password was supplied, do not process .pgpass - _, ok := o["password"] - if ok { + if _, ok := o["password"]; ok { return } filename := os.Getenv("PGPASSFILE") if filename == "" { // XXX this code doesn't work on Windows where the default filename is // XXX %APPDATA%\postgresql\pgpass.conf - user, err := user.Current() - if err != nil { - return + // Prefer $HOME over user.Current due to glibc bug: golang.org/issue/13470 + userHome := os.Getenv("HOME") + if userHome == "" { + user, err := user.Current() + if err != nil { + return + } + userHome = user.HomeDir } - filename = filepath.Join(user.HomeDir, ".pgpass") + filename = filepath.Join(userHome, ".pgpass") } fileinfo, err := os.Stat(filename) if err != nil { @@ -174,11 +193,11 @@ func (c *conn) handlePgpass(o values) { } defer file.Close() scanner := bufio.NewScanner(io.Reader(file)) - hostname := o.Get("host") + hostname := o["host"] ntw, _ := network(o) - port := o.Get("port") - db := o.Get("dbname") - username := o.Get("user") + port := o["port"] + db := o["dbname"] + username := o["user"] // From: https://github.com/tg/pgpass/blob/master/reader.go getFields := func(s string) []string { fs := make([]string, 0, 5) @@ -217,18 +236,22 @@ func (c *conn) handlePgpass(o values) { } } -func (c *conn) writeBuf(b byte) *writeBuf { - c.scratch[0] = b +func (cn *conn) writeBuf(b byte) *writeBuf { + cn.scratch[0] = b return &writeBuf{ - buf: c.scratch[:5], + buf: cn.scratch[:5], pos: 1, } } +// Open opens a new connection to the database. name is a connection string. +// Most users should only use it through database/sql package from the standard +// library. func Open(name string) (_ driver.Conn, err error) { return DialOpen(defaultDialer{}, name) } +// DialOpen opens a new connection to the database using a dialer. func DialOpen(d Dialer, name string) (_ driver.Conn, err error) { // Handle any panics during connection initialization. Note that we // specifically do *not* want to use errRecover(), as that would turn any @@ -243,13 +266,13 @@ func DialOpen(d Dialer, name string) (_ driver.Conn, err error) { // * Very low precedence defaults applied in every situation // * Environment variables // * Explicitly passed connection information - o.Set("host", "localhost") - o.Set("port", "5432") + o["host"] = "localhost" + o["port"] = "5432" // N.B.: Extra float digits should be set to 3, but that breaks // Postgres 8.4 and older, where the max is 2. - o.Set("extra_float_digits", "2") + o["extra_float_digits"] = "2" for k, v := range parseEnviron(os.Environ()) { - o.Set(k, v) + o[k] = v } if strings.HasPrefix(name, "postgres://") || strings.HasPrefix(name, "postgresql://") { @@ -264,9 +287,9 @@ func DialOpen(d Dialer, name string) (_ driver.Conn, err error) { } // Use the "fallback" application name if necessary - if fallback := o.Get("fallback_application_name"); fallback != "" { - if !o.Isset("application_name") { - o.Set("application_name", fallback) + if fallback, ok := o["fallback_application_name"]; ok { + if _, ok := o["application_name"]; !ok { + o["application_name"] = fallback } } @@ -277,33 +300,35 @@ func DialOpen(d Dialer, name string) (_ driver.Conn, err error) { // parsing its value is not worth it. Instead, we always explicitly send // client_encoding as a separate run-time parameter, which should override // anything set in options. - if enc := o.Get("client_encoding"); enc != "" && !isUTF8(enc) { + if enc, ok := o["client_encoding"]; ok && !isUTF8(enc) { return nil, errors.New("client_encoding must be absent or 'UTF8'") } - o.Set("client_encoding", "UTF8") + o["client_encoding"] = "UTF8" // DateStyle needs a similar treatment. - if datestyle := o.Get("datestyle"); datestyle != "" { + if datestyle, ok := o["datestyle"]; ok { if datestyle != "ISO, MDY" { panic(fmt.Sprintf("setting datestyle must be absent or %v; got %v", "ISO, MDY", datestyle)) } } else { - o.Set("datestyle", "ISO, MDY") + o["datestyle"] = "ISO, MDY" } // If a user is not provided by any other means, the last // resort is to use the current operating system provided user // name. - if o.Get("user") == "" { + if _, ok := o["user"]; !ok { u, err := userCurrent() if err != nil { return nil, err - } else { - o.Set("user", u) } + o["user"] = u } - cn := &conn{} + cn := &conn{ + opts: o, + dialer: d, + } err = cn.handleDriverSettings(o) if err != nil { return nil, err @@ -314,14 +339,28 @@ func DialOpen(d Dialer, name string) (_ driver.Conn, err error) { if err != nil { return nil, err } - cn.ssl(o) + + err = cn.ssl(o) + if err != nil { + return nil, err + } + + // cn.startup panics on error. Make sure we don't leak cn.c. + panicking := true + defer func() { + if panicking { + cn.c.Close() + } + }() + cn.buf = bufio.NewReader(cn.c) cn.startup(o) // reset the deadline, in case one was set (see dial) - if timeout := o.Get("connect_timeout"); timeout != "" && timeout != "0" { + if timeout, ok := o["connect_timeout"]; ok && timeout != "0" { err = cn.c.SetDeadline(time.Time{}) } + panicking = false return cn, err } @@ -333,7 +372,7 @@ func dial(d Dialer, o values) (net.Conn, error) { } // Zero or not specified means wait indefinitely. - if timeout := o.Get("connect_timeout"); timeout != "" && timeout != "0" { + if timeout, ok := o["connect_timeout"]; ok && timeout != "0" { seconds, err := strconv.ParseInt(timeout, 10, 0) if err != nil { return nil, fmt.Errorf("invalid value for parameter connect_timeout: %s", err) @@ -355,31 +394,18 @@ func dial(d Dialer, o values) (net.Conn, error) { } func network(o values) (string, string) { - host := o.Get("host") + host := o["host"] if strings.HasPrefix(host, "/") { - sockPath := path.Join(host, ".s.PGSQL."+o.Get("port")) + sockPath := path.Join(host, ".s.PGSQL."+o["port"]) return "unix", sockPath } - return "tcp", net.JoinHostPort(host, o.Get("port")) + return "tcp", net.JoinHostPort(host, o["port"]) } type values map[string]string -func (vs values) Set(k, v string) { - vs[k] = v -} - -func (vs values) Get(k string) (v string) { - return vs[k] -} - -func (vs values) Isset(k string) bool { - _, ok := vs[k] - return ok -} - // scanner implements a tokenizer for libpq-style option strings. type scanner struct { s []rune @@ -450,7 +476,7 @@ func parseOpts(name string, o values) error { // Skip any whitespace after the = if r, ok = s.SkipSpaces(); !ok { // If we reach the end here, the last value is just an empty string as per libpq. - o.Set(string(keyRunes), "") + o[string(keyRunes)] = "" break } @@ -485,7 +511,7 @@ func parseOpts(name string, o values) error { } } - o.Set(string(keyRunes), string(valRunes)) + o[string(keyRunes)] = string(valRunes) } return nil @@ -504,13 +530,17 @@ func (cn *conn) checkIsInTransaction(intxn bool) { } func (cn *conn) Begin() (_ driver.Tx, err error) { + return cn.begin("") +} + +func (cn *conn) begin(mode string) (_ driver.Tx, err error) { if cn.bad { return nil, driver.ErrBadConn } defer cn.errRecover(&err) cn.checkIsInTransaction(false) - _, commandTag, err := cn.simpleExec("BEGIN") + _, commandTag, err := cn.simpleExec("BEGIN" + mode) if err != nil { return nil, err } @@ -525,7 +555,14 @@ func (cn *conn) Begin() (_ driver.Tx, err error) { return cn, nil } +func (cn *conn) closeTxn() { + if finish := cn.txnFinish; finish != nil { + finish() + } +} + func (cn *conn) Commit() (err error) { + defer cn.closeTxn() if cn.bad { return driver.ErrBadConn } @@ -561,6 +598,7 @@ func (cn *conn) Commit() (err error) { } func (cn *conn) Rollback() (err error) { + defer cn.closeTxn() if cn.bad { return driver.ErrBadConn } @@ -598,11 +636,16 @@ func (cn *conn) simpleExec(q string) (res driver.Result, commandTag string, err res, commandTag = cn.parseComplete(r.string()) case 'Z': cn.processReadyForQuery(r) + if res == nil && err == nil { + err = errUnexpectedReady + } // done return case 'E': err = parseError(r) - case 'T', 'D', 'I': + case 'I': + res = emptyRows + case 'T', 'D': // ignore any results default: cn.bad = true @@ -635,6 +678,12 @@ func (cn *conn) simpleQuery(q string) (res *rows, err error) { cn: cn, } } + // Set the result and tag to the last command complete if there wasn't a + // query already run. Although queries usually return from here and cede + // control to Next, a query with zero results does not. + if t == 'C' && res.colNames == nil { + res.result, res.tag = cn.parseComplete(r.string()) + } res.done = true case 'Z': cn.processReadyForQuery(r) @@ -666,9 +715,23 @@ func (cn *conn) simpleQuery(q string) (res *rows, err error) { } } +type noRows struct{} + +var emptyRows noRows + +var _ driver.Result = noRows{} + +func (noRows) LastInsertId() (int64, error) { + return 0, errNoLastInsertID +} + +func (noRows) RowsAffected() (int64, error) { + return 0, errNoRowsAffected +} + // Decides which column formats to use for a prepared statement. The input is // an array of type oids, one element per result column. -func decideColumnFormats(colTyps []oid.Oid, forceText bool) (colFmts []format, colFmtData []byte) { +func decideColumnFormats(colTyps []fieldDesc, forceText bool) (colFmts []format, colFmtData []byte) { if len(colTyps) == 0 { return nil, colFmtDataAllText } @@ -680,8 +743,8 @@ func decideColumnFormats(colTyps []oid.Oid, forceText bool) (colFmts []format, c allBinary := true allText := true - for i, o := range colTyps { - switch o { + for i, t := range colTyps { + switch t.OID { // This is the list of types to use binary mode for when receiving them // through a prepared statement. If a type appears in this list, it // must also be implemented in binaryDecode in encode.go. @@ -692,6 +755,8 @@ func decideColumnFormats(colTyps []oid.Oid, forceText bool) (colFmts []format, c case oid.T_int4: fallthrough case oid.T_int2: + fallthrough + case oid.T_uuid: colFmts[i] = formatBinary allText = false @@ -743,32 +808,45 @@ func (cn *conn) Prepare(q string) (_ driver.Stmt, err error) { defer cn.errRecover(&err) if len(q) >= 4 && strings.EqualFold(q[:4], "COPY") { - return cn.prepareCopyIn(q) + s, err := cn.prepareCopyIn(q) + if err == nil { + cn.inCopy = true + } + return s, err } return cn.prepareTo(q, cn.gname()), nil } func (cn *conn) Close() (err error) { - if cn.bad { - return driver.ErrBadConn - } + // Skip cn.bad return here because we always want to close a connection. defer cn.errRecover(&err) + // Ensure that cn.c.Close is always run. Since error handling is done with + // panics and cn.errRecover, the Close must be in a defer. + defer func() { + cerr := cn.c.Close() + if err == nil { + err = cerr + } + }() + // Don't go through send(); ListenerConn relies on us not scribbling on the // scratch buffer of this connection. - err = cn.sendSimpleMessage('X') - if err != nil { - return err - } - - return cn.c.Close() + return cn.sendSimpleMessage('X') } // Implement the "Queryer" interface -func (cn *conn) Query(query string, args []driver.Value) (_ driver.Rows, err error) { +func (cn *conn) Query(query string, args []driver.Value) (driver.Rows, error) { + return cn.query(query, args) +} + +func (cn *conn) query(query string, args []driver.Value) (_ *rows, err error) { if cn.bad { return nil, driver.ErrBadConn } + if cn.inCopy { + return nil, errCopyInProgress + } defer cn.errRecover(&err) // Check to see if we can use the "simpleQuery" interface, which is @@ -786,16 +864,15 @@ func (cn *conn) Query(query string, args []driver.Value) (_ driver.Rows, err err rows.colNames, rows.colFmts, rows.colTyps = cn.readPortalDescribeResponse() cn.postExecuteWorkaround() return rows, nil - } else { - st := cn.prepareTo(query, "") - st.exec(args) - return &rows{ - cn: cn, - colNames: st.colNames, - colTyps: st.colTyps, - colFmts: st.colFmts, - }, nil } + st := cn.prepareTo(query, "") + st.exec(args) + return &rows{ + cn: cn, + colNames: st.colNames, + colTyps: st.colTyps, + colFmts: st.colFmts, + }, nil } // Implement the optional "Execer" interface for one-shot queries @@ -822,17 +899,16 @@ func (cn *conn) Exec(query string, args []driver.Value) (res driver.Result, err cn.postExecuteWorkaround() res, _, err = cn.readExecuteResponse("Execute") return res, err - } else { - // Use the unnamed statement to defer planning until bind - // time, or else value-based selectivity estimates cannot be - // used. - st := cn.prepareTo(query, "") - r, err := st.Exec(args) - if err != nil { - panic(err) - } - return r, err } + // Use the unnamed statement to defer planning until bind + // time, or else value-based selectivity estimates cannot be + // used. + st := cn.prepareTo(query, "") + r, err := st.Exec(args) + if err != nil { + panic(err) + } + return r, err } func (cn *conn) send(m *writeBuf) { @@ -842,16 +918,9 @@ func (cn *conn) send(m *writeBuf) { } } -func (cn *conn) sendStartupPacket(m *writeBuf) { - // sanity check - if m.buf[0] != 0 { - panic("oops") - } - +func (cn *conn) sendStartupPacket(m *writeBuf) error { _, err := cn.c.Write((m.wrap())[1:]) - if err != nil { - panic(err) - } + return err } // Send a message of type typ to the server on the other end of cn. The @@ -964,165 +1033,35 @@ func (cn *conn) recv1() (t byte, r *readBuf) { return t, r } -func (cn *conn) ssl(o values) { - verifyCaOnly := false - tlsConf := tls.Config{} - switch mode := o.Get("sslmode"); mode { - // "require" is the default. - case "", "require": - // We must skip TLS's own verification since it requires full - // verification since Go 1.3. - tlsConf.InsecureSkipVerify = true - - // From http://www.postgresql.org/docs/current/static/libpq-ssl.html: - // Note: For backwards compatibility with earlier versions of PostgreSQL, if a - // root CA file exists, the behavior of sslmode=require will be the same as - // that of verify-ca, meaning the server certificate is validated against the - // CA. Relying on this behavior is discouraged, and applications that need - // certificate validation should always use verify-ca or verify-full. - if _, err := os.Stat(o.Get("sslrootcert")); err == nil { - verifyCaOnly = true - } else { - o.Set("sslrootcert", "") - } - case "verify-ca": - // We must skip TLS's own verification since it requires full - // verification since Go 1.3. - tlsConf.InsecureSkipVerify = true - verifyCaOnly = true - case "verify-full": - tlsConf.ServerName = o.Get("host") - case "disable": - return - default: - errorf(`unsupported sslmode %q; only "require" (default), "verify-full", "verify-ca", and "disable" supported`, mode) +func (cn *conn) ssl(o values) error { + upgrade, err := ssl(o) + if err != nil { + return err } - cn.setupSSLClientCertificates(&tlsConf, o) - cn.setupSSLCA(&tlsConf, o) + if upgrade == nil { + // Nothing to do + return nil + } w := cn.writeBuf(0) w.int32(80877103) - cn.sendStartupPacket(w) + if err = cn.sendStartupPacket(w); err != nil { + return err + } b := cn.scratch[:1] - _, err := io.ReadFull(cn.c, b) + _, err = io.ReadFull(cn.c, b) if err != nil { - panic(err) + return err } if b[0] != 'S' { - panic(ErrSSLNotSupported) + return ErrSSLNotSupported } - client := tls.Client(cn.c, &tlsConf) - if verifyCaOnly { - cn.verifyCA(client, &tlsConf) - } - cn.c = client -} - -// verifyCA carries out a TLS handshake to the server and verifies the -// presented certificate against the effective CA, i.e. the one specified in -// sslrootcert or the system CA if sslrootcert was not specified. -func (cn *conn) verifyCA(client *tls.Conn, tlsConf *tls.Config) { - err := client.Handshake() - if err != nil { - panic(err) - } - certs := client.ConnectionState().PeerCertificates - opts := x509.VerifyOptions{ - DNSName: client.ConnectionState().ServerName, - Intermediates: x509.NewCertPool(), - Roots: tlsConf.RootCAs, - } - for i, cert := range certs { - if i == 0 { - continue - } - opts.Intermediates.AddCert(cert) - } - _, err = certs[0].Verify(opts) - if err != nil { - panic(err) - } -} - -// This function sets up SSL client certificates based on either the "sslkey" -// and "sslcert" settings (possibly set via the environment variables PGSSLKEY -// and PGSSLCERT, respectively), or if they aren't set, from the .postgresql -// directory in the user's home directory. If the file paths are set -// explicitly, the files must exist. The key file must also not be -// world-readable, or this function will panic with -// ErrSSLKeyHasWorldPermissions. -func (cn *conn) setupSSLClientCertificates(tlsConf *tls.Config, o values) { - var missingOk bool - - sslkey := o.Get("sslkey") - sslcert := o.Get("sslcert") - if sslkey != "" && sslcert != "" { - // If the user has set an sslkey and sslcert, they *must* exist. - missingOk = false - } else { - // Automatically load certificates from ~/.postgresql. - user, err := user.Current() - if err != nil { - // user.Current() might fail when cross-compiling. We have to - // ignore the error and continue without client certificates, since - // we wouldn't know where to load them from. - return - } - - sslkey = filepath.Join(user.HomeDir, ".postgresql", "postgresql.key") - sslcert = filepath.Join(user.HomeDir, ".postgresql", "postgresql.crt") - missingOk = true - } - - // Check that both files exist, and report the error or stop, depending on - // which behaviour we want. Note that we don't do any more extensive - // checks than this (such as checking that the paths aren't directories); - // LoadX509KeyPair() will take care of the rest. - keyfinfo, err := os.Stat(sslkey) - if err != nil && missingOk { - return - } else if err != nil { - panic(err) - } - _, err = os.Stat(sslcert) - if err != nil && missingOk { - return - } else if err != nil { - panic(err) - } - - // If we got this far, the key file must also have the correct permissions - kmode := keyfinfo.Mode() - if kmode != kmode&0600 { - panic(ErrSSLKeyHasWorldPermissions) - } - - cert, err := tls.LoadX509KeyPair(sslcert, sslkey) - if err != nil { - panic(err) - } - tlsConf.Certificates = []tls.Certificate{cert} -} - -// Sets up RootCAs in the TLS configuration if sslrootcert is set. -func (cn *conn) setupSSLCA(tlsConf *tls.Config, o values) { - if sslrootcert := o.Get("sslrootcert"); sslrootcert != "" { - tlsConf.RootCAs = x509.NewCertPool() - - cert, err := ioutil.ReadFile(sslrootcert) - if err != nil { - panic(err) - } - - ok := tlsConf.RootCAs.AppendCertsFromPEM(cert) - if !ok { - errorf("couldn't parse pem in sslrootcert") - } - } + cn.c, err = upgrade(cn.c) + return err } // isDriverSetting returns true iff a setting is purely for configuring the @@ -1171,12 +1110,15 @@ func (cn *conn) startup(o values) { w.string(v) } w.string("") - cn.sendStartupPacket(w) + if err := cn.sendStartupPacket(w); err != nil { + panic(err) + } for { t, r := cn.recv() switch t { case 'K': + cn.processBackendKeyData(r) case 'S': cn.processParameterStatus(r) case 'R': @@ -1196,7 +1138,7 @@ func (cn *conn) auth(r *readBuf, o values) { // OK case 3: w := cn.writeBuf('p') - w.string(o.Get("password")) + w.string(o["password"]) cn.send(w) t, r := cn.recv() @@ -1210,7 +1152,7 @@ func (cn *conn) auth(r *readBuf, o values) { case 5: s := string(r.next(4)) w := cn.writeBuf('p') - w.string("md5" + md5s(md5s(o.Get("password")+o.Get("user"))+s)) + w.string("md5" + md5s(md5s(o["password"]+o["user"])+s)) cn.send(w) t, r := cn.recv() @@ -1232,10 +1174,10 @@ const formatText format = 0 const formatBinary format = 1 // One result-column format code with the value 1 (i.e. all binary). -var colFmtDataAllBinary []byte = []byte{0, 1, 0, 1} +var colFmtDataAllBinary = []byte{0, 1, 0, 1} // No result-column format codes (i.e. all text). -var colFmtDataAllText []byte = []byte{0, 0} +var colFmtDataAllText = []byte{0, 0} type stmt struct { cn *conn @@ -1243,7 +1185,7 @@ type stmt struct { colNames []string colFmts []format colFmtData []byte - colTyps []oid.Oid + colTyps []fieldDesc paramTyps []oid.Oid closed bool } @@ -1404,21 +1346,32 @@ func (cn *conn) parseComplete(commandTag string) (driver.Result, string) { type rows struct { cn *conn + finish func() colNames []string - colTyps []oid.Oid + colTyps []fieldDesc colFmts []format done bool rb readBuf + result driver.Result + tag string } func (rs *rows) Close() error { + if finish := rs.finish; finish != nil { + defer finish() + } // no need to look at cn.bad as Next() will for { err := rs.Next(nil) switch err { case nil: case io.EOF: - return nil + // rs.Next can return io.EOF on both 'Z' (ready for query) and 'T' (row + // description, used with HasNextResultSet). We need to fetch messages until + // we hit a 'Z', which is done by waiting for done to be set. + if rs.done { + return nil + } default: return err } @@ -1429,6 +1382,17 @@ func (rs *rows) Columns() []string { return rs.colNames } +func (rs *rows) Result() driver.Result { + if rs.result == nil { + return emptyRows + } + return rs.result +} + +func (rs *rows) Tag() string { + return rs.tag +} + func (rs *rows) Next(dest []driver.Value) (err error) { if rs.done { return io.EOF @@ -1446,6 +1410,9 @@ func (rs *rows) Next(dest []driver.Value) (err error) { case 'E': err = parseError(&rs.rb) case 'C', 'I': + if t == 'C' { + rs.result, rs.tag = conn.parseComplete(rs.rb.string()) + } continue case 'Z': conn.processReadyForQuery(&rs.rb) @@ -1469,21 +1436,33 @@ func (rs *rows) Next(dest []driver.Value) (err error) { dest[i] = nil continue } - dest[i] = decode(&conn.parameterStatus, rs.rb.next(l), rs.colTyps[i], rs.colFmts[i]) + dest[i] = decode(&conn.parameterStatus, rs.rb.next(l), rs.colTyps[i].OID, rs.colFmts[i]) } return + case 'T': + rs.colNames, rs.colFmts, rs.colTyps = parsePortalRowDescribe(&rs.rb) + return io.EOF default: errorf("unexpected message after execute: %q", t) } } } +func (rs *rows) HasNextResultSet() bool { + return !rs.done +} + +func (rs *rows) NextResultSet() error { + return nil +} + // QuoteIdentifier quotes an "identifier" (e.g. a table or a column name) to be // used as part of an SQL statement. For example: // // tblname := "my_table" // data := "my_data" -// err = db.Exec(fmt.Sprintf("INSERT INTO %s VALUES ($1)", pq.QuoteIdentifier(tblname)), data) +// quoted := pq.QuoteIdentifier(tblname) +// err := db.Exec(fmt.Sprintf("INSERT INTO %s VALUES ($1)", quoted), data) // // Any double quotes in name will be escaped. The quoted identifier will be // case sensitive when used in a query. If the input string contains a zero @@ -1564,7 +1543,7 @@ func (cn *conn) sendBinaryModeQuery(query string, args []driver.Value) { cn.send(b) } -func (c *conn) processParameterStatus(r *readBuf) { +func (cn *conn) processParameterStatus(r *readBuf) { var err error param := r.string() @@ -1575,13 +1554,13 @@ func (c *conn) processParameterStatus(r *readBuf) { var minor int _, err = fmt.Sscanf(r.string(), "%d.%d.%d", &major1, &major2, &minor) if err == nil { - c.parameterStatus.serverVersion = major1*10000 + major2*100 + minor + cn.parameterStatus.serverVersion = major1*10000 + major2*100 + minor } case "TimeZone": - c.parameterStatus.currentLocation, err = time.LoadLocation(r.string()) + cn.parameterStatus.currentLocation, err = time.LoadLocation(r.string()) if err != nil { - c.parameterStatus.currentLocation = nil + cn.parameterStatus.currentLocation = nil } default: @@ -1589,8 +1568,8 @@ func (c *conn) processParameterStatus(r *readBuf) { } } -func (c *conn) processReadyForQuery(r *readBuf) { - c.txnStatus = transactionStatus(r.byte()) +func (cn *conn) processReadyForQuery(r *readBuf) { + cn.txnStatus = transactionStatus(r.byte()) } func (cn *conn) readReadyForQuery() { @@ -1605,6 +1584,11 @@ func (cn *conn) readReadyForQuery() { } } +func (cn *conn) processBackendKeyData(r *readBuf) { + cn.processID = r.int32() + cn.secretKey = r.int32() +} + func (cn *conn) readParseResponse() { t, r := cn.recv1() switch t { @@ -1620,7 +1604,7 @@ func (cn *conn) readParseResponse() { } } -func (cn *conn) readStatementDescribeResponse() (paramTyps []oid.Oid, colNames []string, colTyps []oid.Oid) { +func (cn *conn) readStatementDescribeResponse() (paramTyps []oid.Oid, colNames []string, colTyps []fieldDesc) { for { t, r := cn.recv1() switch t { @@ -1646,7 +1630,7 @@ func (cn *conn) readStatementDescribeResponse() (paramTyps []oid.Oid, colNames [ } } -func (cn *conn) readPortalDescribeResponse() (colNames []string, colFmts []format, colTyps []oid.Oid) { +func (cn *conn) readPortalDescribeResponse() (colNames []string, colFmts []format, colTyps []fieldDesc) { t, r := cn.recv1() switch t { case 'T': @@ -1720,6 +1704,9 @@ func (cn *conn) readExecuteResponse(protocolState string) (res driver.Result, co res, commandTag = cn.parseComplete(r.string()) case 'Z': cn.processReadyForQuery(r) + if res == nil && err == nil { + err = errUnexpectedReady + } return res, commandTag, err case 'E': err = parseError(r) @@ -1728,6 +1715,9 @@ func (cn *conn) readExecuteResponse(protocolState string) (res driver.Result, co cn.bad = true errorf("unexpected %q after error %s", t, err) } + if t == 'I' { + res = emptyRows + } // ignore any results default: cn.bad = true @@ -1736,31 +1726,33 @@ func (cn *conn) readExecuteResponse(protocolState string) (res driver.Result, co } } -func parseStatementRowDescribe(r *readBuf) (colNames []string, colTyps []oid.Oid) { +func parseStatementRowDescribe(r *readBuf) (colNames []string, colTyps []fieldDesc) { n := r.int16() colNames = make([]string, n) - colTyps = make([]oid.Oid, n) + colTyps = make([]fieldDesc, n) for i := range colNames { colNames[i] = r.string() r.next(6) - colTyps[i] = r.oid() - r.next(6) + colTyps[i].OID = r.oid() + colTyps[i].Len = r.int16() + colTyps[i].Mod = r.int32() // format code not known when describing a statement; always 0 r.next(2) } return } -func parsePortalRowDescribe(r *readBuf) (colNames []string, colFmts []format, colTyps []oid.Oid) { +func parsePortalRowDescribe(r *readBuf) (colNames []string, colFmts []format, colTyps []fieldDesc) { n := r.int16() colNames = make([]string, n) colFmts = make([]format, n) - colTyps = make([]oid.Oid, n) + colTyps = make([]fieldDesc, n) for i := range colNames { colNames[i] = r.string() r.next(6) - colTyps[i] = r.oid() - r.next(6) + colTyps[i].OID = r.oid() + colTyps[i].Len = r.int16() + colTyps[i].Mod = r.int32() colFmts[i] = format(r.int16()) } return diff --git a/vendor/github.com/lib/pq/conn_go18.go b/vendor/github.com/lib/pq/conn_go18.go new file mode 100644 index 00000000..81c9ee47 --- /dev/null +++ b/vendor/github.com/lib/pq/conn_go18.go @@ -0,0 +1,129 @@ +package pq + +import ( + "context" + "database/sql" + "database/sql/driver" + "fmt" + "io" + "io/ioutil" +) + +// Implement the "QueryerContext" interface +func (cn *conn) QueryContext(ctx context.Context, query string, args []driver.NamedValue) (driver.Rows, error) { + list := make([]driver.Value, len(args)) + for i, nv := range args { + list[i] = nv.Value + } + finish := cn.watchCancel(ctx) + r, err := cn.query(query, list) + if err != nil { + if finish != nil { + finish() + } + return nil, err + } + r.finish = finish + return r, nil +} + +// Implement the "ExecerContext" interface +func (cn *conn) ExecContext(ctx context.Context, query string, args []driver.NamedValue) (driver.Result, error) { + list := make([]driver.Value, len(args)) + for i, nv := range args { + list[i] = nv.Value + } + + if finish := cn.watchCancel(ctx); finish != nil { + defer finish() + } + + return cn.Exec(query, list) +} + +// Implement the "ConnBeginTx" interface +func (cn *conn) BeginTx(ctx context.Context, opts driver.TxOptions) (driver.Tx, error) { + var mode string + + switch sql.IsolationLevel(opts.Isolation) { + case sql.LevelDefault: + // Don't touch mode: use the server's default + case sql.LevelReadUncommitted: + mode = " ISOLATION LEVEL READ UNCOMMITTED" + case sql.LevelReadCommitted: + mode = " ISOLATION LEVEL READ COMMITTED" + case sql.LevelRepeatableRead: + mode = " ISOLATION LEVEL REPEATABLE READ" + case sql.LevelSerializable: + mode = " ISOLATION LEVEL SERIALIZABLE" + default: + return nil, fmt.Errorf("pq: isolation level not supported: %d", opts.Isolation) + } + + if opts.ReadOnly { + mode += " READ ONLY" + } else { + mode += " READ WRITE" + } + + tx, err := cn.begin(mode) + if err != nil { + return nil, err + } + cn.txnFinish = cn.watchCancel(ctx) + return tx, nil +} + +func (cn *conn) watchCancel(ctx context.Context) func() { + if done := ctx.Done(); done != nil { + finished := make(chan struct{}) + go func() { + select { + case <-done: + _ = cn.cancel() + finished <- struct{}{} + case <-finished: + } + }() + return func() { + select { + case <-finished: + case finished <- struct{}{}: + } + } + } + return nil +} + +func (cn *conn) cancel() error { + c, err := dial(cn.dialer, cn.opts) + if err != nil { + return err + } + defer c.Close() + + { + can := conn{ + c: c, + } + err = can.ssl(cn.opts) + if err != nil { + return err + } + + w := can.writeBuf(0) + w.int32(80877102) // cancel request code + w.int32(cn.processID) + w.int32(cn.secretKey) + + if err := can.sendStartupPacket(w); err != nil { + return err + } + } + + // Read until EOF to ensure that the server received the cancel. + { + _, err := io.Copy(ioutil.Discard, c) + return err + } +} diff --git a/vendor/github.com/lib/pq/connector.go b/vendor/github.com/lib/pq/connector.go new file mode 100644 index 00000000..9e66eb5d --- /dev/null +++ b/vendor/github.com/lib/pq/connector.go @@ -0,0 +1,43 @@ +// +build go1.10 + +package pq + +import ( + "context" + "database/sql/driver" +) + +// Connector represents a fixed configuration for the pq driver with a given +// name. Connector satisfies the database/sql/driver Connector interface and +// can be used to create any number of DB Conn's via the database/sql OpenDB +// function. +// +// See https://golang.org/pkg/database/sql/driver/#Connector. +// See https://golang.org/pkg/database/sql/#OpenDB. +type connector struct { + name string +} + +// Connect returns a connection to the database using the fixed configuration +// of this Connector. Context is not used. +func (c *connector) Connect(_ context.Context) (driver.Conn, error) { + return (&Driver{}).Open(c.name) +} + +// Driver returnst the underlying driver of this Connector. +func (c *connector) Driver() driver.Driver { + return &Driver{} +} + +var _ driver.Connector = &connector{} + +// NewConnector returns a connector for the pq driver in a fixed configuration +// with the given name. The returned connector can be used to create any number +// of equivalent Conn's. The returned connector is intended to be used with +// database/sql.OpenDB. +// +// See https://golang.org/pkg/database/sql/driver/#Connector. +// See https://golang.org/pkg/database/sql/#OpenDB. +func NewConnector(name string) (driver.Connector, error) { + return &connector{name: name}, nil +} diff --git a/vendor/github.com/lib/pq/copy.go b/vendor/github.com/lib/pq/copy.go index 101f1113..345c2398 100644 --- a/vendor/github.com/lib/pq/copy.go +++ b/vendor/github.com/lib/pq/copy.go @@ -13,6 +13,7 @@ var ( errBinaryCopyNotSupported = errors.New("pq: only text format supported for COPY") errCopyToNotSupported = errors.New("pq: COPY TO is not supported") errCopyNotSupportedOutsideTxn = errors.New("pq: COPY is only allowed inside a transaction") + errCopyInProgress = errors.New("pq: COPY in progress") ) // CopyIn creates a COPY FROM statement which can be prepared with @@ -96,13 +97,13 @@ awaitCopyInResponse: err = parseError(r) case 'Z': if err == nil { - cn.bad = true + ci.setBad() errorf("unexpected ReadyForQuery in response to COPY") } cn.processReadyForQuery(r) return nil, err default: - cn.bad = true + ci.setBad() errorf("unknown response for copy query: %q", t) } } @@ -121,7 +122,7 @@ awaitCopyInResponse: cn.processReadyForQuery(r) return nil, err default: - cn.bad = true + ci.setBad() errorf("unknown response for CopyFail: %q", t) } } @@ -142,7 +143,7 @@ func (ci *copyin) resploop() { var r readBuf t, err := ci.cn.recvMessage(&r) if err != nil { - ci.cn.bad = true + ci.setBad() ci.setError(err) ci.done <- true return @@ -160,7 +161,7 @@ func (ci *copyin) resploop() { err := parseError(&r) ci.setError(err) default: - ci.cn.bad = true + ci.setBad() ci.setError(fmt.Errorf("unknown response during CopyIn: %q", t)) ci.done <- true return @@ -168,6 +169,19 @@ func (ci *copyin) resploop() { } } +func (ci *copyin) setBad() { + ci.Lock() + ci.cn.bad = true + ci.Unlock() +} + +func (ci *copyin) isBad() bool { + ci.Lock() + b := ci.cn.bad + ci.Unlock() + return b +} + func (ci *copyin) isErrorSet() bool { ci.Lock() isSet := (ci.err != nil) @@ -205,7 +219,7 @@ func (ci *copyin) Exec(v []driver.Value) (r driver.Result, err error) { return nil, errCopyInClosed } - if ci.cn.bad { + if ci.isBad() { return nil, driver.ErrBadConn } defer ci.cn.errRecover(&err) @@ -243,7 +257,7 @@ func (ci *copyin) Close() (err error) { } ci.closed = true - if ci.cn.bad { + if ci.isBad() { return driver.ErrBadConn } defer ci.cn.errRecover(&err) @@ -258,6 +272,7 @@ func (ci *copyin) Close() (err error) { } <-ci.done + ci.cn.inCopy = false if ci.isErrorSet() { err = ci.err diff --git a/vendor/github.com/lib/pq/doc.go b/vendor/github.com/lib/pq/doc.go index 19798dfc..2a60054e 100644 --- a/vendor/github.com/lib/pq/doc.go +++ b/vendor/github.com/lib/pq/doc.go @@ -11,7 +11,8 @@ using this package directly. For example: ) func main() { - db, err := sql.Open("postgres", "user=pqgotest dbname=pqgotest sslmode=verify-full") + connStr := "user=pqgotest dbname=pqgotest sslmode=verify-full" + db, err := sql.Open("postgres", connStr) if err != nil { log.Fatal(err) } @@ -23,7 +24,8 @@ using this package directly. For example: You can also connect to a database using a URL. For example: - db, err := sql.Open("postgres", "postgres://pqgotest:password@localhost/pqgotest?sslmode=verify-full") + connStr := "postgres://pqgotest:password@localhost/pqgotest?sslmode=verify-full" + db, err := sql.Open("postgres", connStr) Connection String Parameters @@ -43,21 +45,28 @@ supported: * dbname - The name of the database to connect to * user - The user to sign in as * password - The user's password - * host - The host to connect to. Values that start with / are for unix domain sockets. (default is localhost) + * host - The host to connect to. Values that start with / are for unix + domain sockets. (default is localhost) * port - The port to bind to. (default is 5432) - * sslmode - Whether or not to use SSL (default is require, this is not the default for libpq) + * sslmode - Whether or not to use SSL (default is require, this is not + the default for libpq) * fallback_application_name - An application_name to fall back to if one isn't provided. - * connect_timeout - Maximum wait for connection, in seconds. Zero or not specified means wait indefinitely. + * connect_timeout - Maximum wait for connection, in seconds. Zero or + not specified means wait indefinitely. * sslcert - Cert file location. The file must contain PEM encoded data. * sslkey - Key file location. The file must contain PEM encoded data. - * sslrootcert - The location of the root certificate file. The file must contain PEM encoded data. + * sslrootcert - The location of the root certificate file. The file + must contain PEM encoded data. Valid values for sslmode are: * disable - No SSL * require - Always SSL (skip verification) - * verify-ca - Always SSL (verify that the certificate presented by the server was signed by a trusted CA) - * verify-full - Always SSL (verify that the certification presented by the server was signed by a trusted CA and the server host name matches the one in the certificate) + * verify-ca - Always SSL (verify that the certificate presented by the + server was signed by a trusted CA) + * verify-full - Always SSL (verify that the certification presented by + the server was signed by a trusted CA and the server host name + matches the one in the certificate) See http://www.postgresql.org/docs/current/static/libpq-connect.html#LIBPQ-CONNSTRING for more information about connection string parameters. @@ -68,7 +77,7 @@ Use single quotes for values that contain whitespace: A backslash will escape the next character in values: - "user=space\ man password='it\'s valid' + "user=space\ man password='it\'s valid'" Note that the connection parameter client_encoding (which sets the text encoding for the connection) may be set but must be "UTF8", @@ -89,8 +98,10 @@ provided connection parameters. The pgpass mechanism as described in http://www.postgresql.org/docs/current/static/libpq-pgpass.html is supported, but on Windows PGPASSFILE must be specified explicitly. + Queries + database/sql does not dictate any specific format for parameter markers in query strings, and pq uses the Postgres-native ordinal markers, as shown above. The same marker can be reused for the same parameter: @@ -114,8 +125,30 @@ For more details on RETURNING, see the Postgres documentation: For additional instructions on querying see the documentation for the database/sql package. + +Data Types + + +Parameters pass through driver.DefaultParameterConverter before they are handled +by this package. When the binary_parameters connection option is enabled, +[]byte values are sent directly to the backend as data in binary format. + +This package returns the following types for values from the PostgreSQL backend: + + - integer types smallint, integer, and bigint are returned as int64 + - floating-point types real and double precision are returned as float64 + - character types char, varchar, and text are returned as string + - temporal types date, time, timetz, timestamp, and timestamptz are + returned as time.Time + - the boolean type is returned as bool + - the bytea type is returned as []byte + +All other types are returned directly from the backend as []byte values in text format. + + Errors + pq may return errors of type *pq.Error which can be interrogated for error details: if err, ok := err.(*pq.Error); ok { @@ -206,7 +239,7 @@ for more information). Note that the channel name will be truncated to 63 bytes by the PostgreSQL server. You can find a complete, working example of Listener usage at -http://godoc.org/github.com/lib/pq/listen_example. +https://godoc.org/github.com/lib/pq/example/listen. */ package pq diff --git a/vendor/github.com/lib/pq/encode.go b/vendor/github.com/lib/pq/encode.go index 29e8f6ff..3b0d365f 100644 --- a/vendor/github.com/lib/pq/encode.go +++ b/vendor/github.com/lib/pq/encode.go @@ -76,6 +76,12 @@ func binaryDecode(parameterStatus *parameterStatus, s []byte, typ oid.Oid) inter return int64(int32(binary.BigEndian.Uint32(s))) case oid.T_int2: return int64(int16(binary.BigEndian.Uint16(s))) + case oid.T_uuid: + b, err := decodeUUIDBinary(s) + if err != nil { + panic(err) + } + return b default: errorf("don't know how to decode binary parameter of type %d", uint32(typ)) @@ -361,8 +367,15 @@ func ParseTimestamp(currentLocation *time.Location, str string) (time.Time, erro timeSep := daySep + 3 day := p.mustAtoi(str, daySep+1, timeSep) + minLen := monSep + len("01-01") + 1 + + isBC := strings.HasSuffix(str, " BC") + if isBC { + minLen += 3 + } + var hour, minute, second int - if len(str) > monSep+len("01-01")+1 { + if len(str) > minLen { p.expect(str, ' ', timeSep) minSep := timeSep + 3 p.expect(str, ':', minSep) @@ -418,7 +431,8 @@ func ParseTimestamp(currentLocation *time.Location, str string) (time.Time, erro tzOff = tzSign * ((tzHours * 60 * 60) + (tzMin * 60) + tzSec) } var isoYear int - if remainderIdx+3 <= len(str) && str[remainderIdx:remainderIdx+3] == " BC" { + + if isBC { isoYear = 1 - year remainderIdx += 3 } else { @@ -471,7 +485,7 @@ func FormatTimestamp(t time.Time) []byte { t = t.AddDate((-t.Year())*2+1, 0, 0) bc = true } - b := []byte(t.Format(time.RFC3339Nano)) + b := []byte(t.Format("2006-01-02 15:04:05.999999999Z07:00")) _, offset := t.Zone() offset = offset % 60 diff --git a/vendor/github.com/lib/pq/error.go b/vendor/github.com/lib/pq/error.go index b4bb44ce..96aae29c 100644 --- a/vendor/github.com/lib/pq/error.go +++ b/vendor/github.com/lib/pq/error.go @@ -153,6 +153,7 @@ var errorCodeNames = map[ErrorCode]string{ "22004": "null_value_not_allowed", "22002": "null_value_no_indicator_parameter", "22003": "numeric_value_out_of_range", + "2200H": "sequence_generator_limit_exceeded", "22026": "string_data_length_mismatch", "22001": "string_data_right_truncation", "22011": "substring_error", @@ -459,6 +460,11 @@ func errorf(s string, args ...interface{}) { panic(fmt.Errorf("pq: %s", fmt.Sprintf(s, args...))) } +// TODO(ainar-g) Rename to errorf after removing panics. +func fmterrorf(s string, args ...interface{}) error { + return fmt.Errorf("pq: %s", fmt.Sprintf(s, args...)) +} + func errRecoverNoErrBadConn(err *error) { e := recover() if e == nil { @@ -487,7 +493,8 @@ func (c *conn) errRecover(err *error) { *err = v } case *net.OpError: - *err = driver.ErrBadConn + c.bad = true + *err = v case error: if v == io.EOF || v.(error).Error() == "remote error: handshake failure" { *err = driver.ErrBadConn diff --git a/vendor/github.com/lib/pq/notify.go b/vendor/github.com/lib/pq/notify.go index 09f94244..850bb904 100644 --- a/vendor/github.com/lib/pq/notify.go +++ b/vendor/github.com/lib/pq/notify.go @@ -60,7 +60,7 @@ type ListenerConn struct { replyChan chan message } -// Creates a new ListenerConn. Use NewListener instead. +// NewListenerConn creates a new ListenerConn. Use NewListener instead. func NewListenerConn(name string, notificationChan chan<- *Notification) (*ListenerConn, error) { return newDialListenerConn(defaultDialer{}, name, notificationChan) } @@ -214,17 +214,17 @@ func (l *ListenerConn) listenerConnMain() { // this ListenerConn is done } -// Send a LISTEN query to the server. See ExecSimpleQuery. +// Listen sends a LISTEN query to the server. See ExecSimpleQuery. func (l *ListenerConn) Listen(channel string) (bool, error) { return l.ExecSimpleQuery("LISTEN " + QuoteIdentifier(channel)) } -// Send an UNLISTEN query to the server. See ExecSimpleQuery. +// Unlisten sends an UNLISTEN query to the server. See ExecSimpleQuery. func (l *ListenerConn) Unlisten(channel string) (bool, error) { return l.ExecSimpleQuery("UNLISTEN " + QuoteIdentifier(channel)) } -// Send `UNLISTEN *` to the server. See ExecSimpleQuery. +// UnlistenAll sends an `UNLISTEN *` query to the server. See ExecSimpleQuery. func (l *ListenerConn) UnlistenAll() (bool, error) { return l.ExecSimpleQuery("UNLISTEN *") } @@ -267,8 +267,8 @@ func (l *ListenerConn) sendSimpleQuery(q string) (err error) { return nil } -// Execute a "simple query" (i.e. one with no bindable parameters) on the -// connection. The possible return values are: +// ExecSimpleQuery executes a "simple query" (i.e. one with no bindable +// parameters) on the connection. The possible return values are: // 1) "executed" is true; the query was executed to completion on the // database server. If the query failed, err will be set to the error // returned by the database, otherwise err will be nil. @@ -333,6 +333,7 @@ func (l *ListenerConn) ExecSimpleQuery(q string) (executed bool, err error) { } } +// Close closes the connection. func (l *ListenerConn) Close() error { l.connectionLock.Lock() if l.err != nil { @@ -346,7 +347,7 @@ func (l *ListenerConn) Close() error { return l.cn.c.Close() } -// Err() returns the reason the connection was closed. It is not safe to call +// Err returns the reason the connection was closed. It is not safe to call // this function until l.Notify has been closed. func (l *ListenerConn) Err() error { return l.err @@ -354,32 +355,43 @@ func (l *ListenerConn) Err() error { var errListenerClosed = errors.New("pq: Listener has been closed") +// ErrChannelAlreadyOpen is returned from Listen when a channel is already +// open. var ErrChannelAlreadyOpen = errors.New("pq: channel is already open") + +// ErrChannelNotOpen is returned from Unlisten when a channel is not open. var ErrChannelNotOpen = errors.New("pq: channel is not open") +// ListenerEventType is an enumeration of listener event types. type ListenerEventType int const ( - // Emitted only when the database connection has been initially - // initialized. err will always be nil. + // ListenerEventConnected is emitted only when the database connection + // has been initially initialized. The err argument of the callback + // will always be nil. ListenerEventConnected ListenerEventType = iota - // Emitted after a database connection has been lost, either because of an - // error or because Close has been called. err will be set to the reason - // the database connection was lost. + // ListenerEventDisconnected is emitted after a database connection has + // been lost, either because of an error or because Close has been + // called. The err argument will be set to the reason the database + // connection was lost. ListenerEventDisconnected - // Emitted after a database connection has been re-established after - // connection loss. err will always be nil. After this event has been - // emitted, a nil pq.Notification is sent on the Listener.Notify channel. + // ListenerEventReconnected is emitted after a database connection has + // been re-established after connection loss. The err argument of the + // callback will always be nil. After this event has been emitted, a + // nil pq.Notification is sent on the Listener.Notify channel. ListenerEventReconnected - // Emitted after a connection to the database was attempted, but failed. - // err will be set to an error describing why the connection attempt did - // not succeed. + // ListenerEventConnectionAttemptFailed is emitted after a connection + // to the database was attempted, but failed. The err argument will be + // set to an error describing why the connection attempt did not + // succeed. ListenerEventConnectionAttemptFailed ) +// EventCallbackType is the event callback type. See also ListenerEventType +// constants' documentation. type EventCallbackType func(event ListenerEventType, err error) // Listener provides an interface for listening to notifications from a @@ -454,9 +466,9 @@ func NewDialListener(d Dialer, return l } -// Returns the notification channel for this listener. This is the same -// channel as Notify, and will not be recreated during the life time of the -// Listener. +// NotificationChannel returns the notification channel for this listener. +// This is the same channel as Notify, and will not be recreated during the +// life time of the Listener. func (l *Listener) NotificationChannel() <-chan *Notification { return l.Notify } @@ -625,7 +637,7 @@ func (l *Listener) disconnectCleanup() error { // after the connection has been established. func (l *Listener) resync(cn *ListenerConn, notificationChan <-chan *Notification) error { doneChan := make(chan error) - go func() { + go func(notificationChan <-chan *Notification) { for channel := range l.channels { // If we got a response, return that error to our caller as it's // going to be more descriptive than cn.Err(). @@ -639,14 +651,14 @@ func (l *Listener) resync(cn *ListenerConn, notificationChan <-chan *Notificatio // close and then return the error message from the connection, as // per ListenerConn's interface. if err != nil { - for _ = range notificationChan { + for range notificationChan { } doneChan <- cn.Err() return } } doneChan <- nil - }() + }(notificationChan) // Ignore notifications while synchronization is going on to avoid // deadlocks. We have to send a nil notification over Notify anyway as @@ -713,6 +725,9 @@ func (l *Listener) Close() error { } l.isClosed = true + // Unblock calls to Listen() + l.reconnectCond.Broadcast() + return nil } @@ -772,7 +787,7 @@ func (l *Listener) listenerConnLoop() { } l.emitEvent(ListenerEventDisconnected, err) - time.Sleep(nextReconnect.Sub(time.Now())) + time.Sleep(time.Until(nextReconnect)) } } diff --git a/vendor/github.com/lib/pq/oid/gen.go b/vendor/github.com/lib/pq/oid/gen.go index cd4aea80..7c634cdc 100644 --- a/vendor/github.com/lib/pq/oid/gen.go +++ b/vendor/github.com/lib/pq/oid/gen.go @@ -10,10 +10,22 @@ import ( "log" "os" "os/exec" + "strings" _ "github.com/lib/pq" ) +// OID represent a postgres Object Identifier Type. +type OID struct { + ID int + Type string +} + +// Name returns an upper case version of the oid type. +func (o OID) Name() string { + return strings.ToUpper(o.Type) +} + func main() { datname := os.Getenv("PGDATABASE") sslmode := os.Getenv("PGSSLMODE") @@ -30,6 +42,25 @@ func main() { if err != nil { log.Fatal(err) } + rows, err := db.Query(` + SELECT typname, oid + FROM pg_type WHERE oid < 10000 + ORDER BY oid; + `) + if err != nil { + log.Fatal(err) + } + oids := make([]*OID, 0) + for rows.Next() { + var oid OID + if err = rows.Scan(&oid.Type, &oid.ID); err != nil { + log.Fatal(err) + } + oids = append(oids, &oid) + } + if err = rows.Err(); err != nil { + log.Fatal(err) + } cmd := exec.Command("gofmt") cmd.Stderr = os.Stderr w, err := cmd.StdinPipe() @@ -45,30 +76,18 @@ func main() { if err != nil { log.Fatal(err) } - fmt.Fprintln(w, "// generated by 'go run gen.go'; do not edit") + fmt.Fprintln(w, "// Code generated by gen.go. DO NOT EDIT.") fmt.Fprintln(w, "\npackage oid") fmt.Fprintln(w, "const (") - rows, err := db.Query(` - SELECT typname, oid - FROM pg_type WHERE oid < 10000 - ORDER BY oid; - `) - if err != nil { - log.Fatal(err) - } - var name string - var oid int - for rows.Next() { - err = rows.Scan(&name, &oid) - if err != nil { - log.Fatal(err) - } - fmt.Fprintf(w, "T_%s Oid = %d\n", name, oid) - } - if err = rows.Err(); err != nil { - log.Fatal(err) + for _, oid := range oids { + fmt.Fprintf(w, "T_%s Oid = %d\n", oid.Type, oid.ID) } fmt.Fprintln(w, ")") + fmt.Fprintln(w, "var TypeName = map[Oid]string{") + for _, oid := range oids { + fmt.Fprintf(w, "T_%s: \"%s\",\n", oid.Type, oid.Name()) + } + fmt.Fprintln(w, "}") w.Close() cmd.Wait() } diff --git a/vendor/github.com/lib/pq/oid/types.go b/vendor/github.com/lib/pq/oid/types.go index 03df05a6..ecc84c2c 100644 --- a/vendor/github.com/lib/pq/oid/types.go +++ b/vendor/github.com/lib/pq/oid/types.go @@ -1,4 +1,4 @@ -// generated by 'go run gen.go'; do not edit +// Code generated by gen.go. DO NOT EDIT. package oid @@ -18,6 +18,7 @@ const ( T_xid Oid = 28 T_cid Oid = 29 T_oidvector Oid = 30 + T_pg_ddl_command Oid = 32 T_pg_type Oid = 71 T_pg_attribute Oid = 75 T_pg_proc Oid = 81 @@ -28,6 +29,7 @@ const ( T_pg_node_tree Oid = 194 T__json Oid = 199 T_smgr Oid = 210 + T_index_am_handler Oid = 325 T_point Oid = 600 T_lseg Oid = 601 T_path Oid = 602 @@ -133,6 +135,9 @@ const ( T__uuid Oid = 2951 T_txid_snapshot Oid = 2970 T_fdw_handler Oid = 3115 + T_pg_lsn Oid = 3220 + T__pg_lsn Oid = 3221 + T_tsm_handler Oid = 3310 T_anyenum Oid = 3500 T_tsvector Oid = 3614 T_tsquery Oid = 3615 @@ -144,6 +149,8 @@ const ( T__regconfig Oid = 3735 T_regdictionary Oid = 3769 T__regdictionary Oid = 3770 + T_jsonb Oid = 3802 + T__jsonb Oid = 3807 T_anyrange Oid = 3831 T_event_trigger Oid = 3838 T_int4range Oid = 3904 @@ -158,4 +165,179 @@ const ( T__daterange Oid = 3913 T_int8range Oid = 3926 T__int8range Oid = 3927 + T_pg_shseclabel Oid = 4066 + T_regnamespace Oid = 4089 + T__regnamespace Oid = 4090 + T_regrole Oid = 4096 + T__regrole Oid = 4097 ) + +var TypeName = map[Oid]string{ + T_bool: "BOOL", + T_bytea: "BYTEA", + T_char: "CHAR", + T_name: "NAME", + T_int8: "INT8", + T_int2: "INT2", + T_int2vector: "INT2VECTOR", + T_int4: "INT4", + T_regproc: "REGPROC", + T_text: "TEXT", + T_oid: "OID", + T_tid: "TID", + T_xid: "XID", + T_cid: "CID", + T_oidvector: "OIDVECTOR", + T_pg_ddl_command: "PG_DDL_COMMAND", + T_pg_type: "PG_TYPE", + T_pg_attribute: "PG_ATTRIBUTE", + T_pg_proc: "PG_PROC", + T_pg_class: "PG_CLASS", + T_json: "JSON", + T_xml: "XML", + T__xml: "_XML", + T_pg_node_tree: "PG_NODE_TREE", + T__json: "_JSON", + T_smgr: "SMGR", + T_index_am_handler: "INDEX_AM_HANDLER", + T_point: "POINT", + T_lseg: "LSEG", + T_path: "PATH", + T_box: "BOX", + T_polygon: "POLYGON", + T_line: "LINE", + T__line: "_LINE", + T_cidr: "CIDR", + T__cidr: "_CIDR", + T_float4: "FLOAT4", + T_float8: "FLOAT8", + T_abstime: "ABSTIME", + T_reltime: "RELTIME", + T_tinterval: "TINTERVAL", + T_unknown: "UNKNOWN", + T_circle: "CIRCLE", + T__circle: "_CIRCLE", + T_money: "MONEY", + T__money: "_MONEY", + T_macaddr: "MACADDR", + T_inet: "INET", + T__bool: "_BOOL", + T__bytea: "_BYTEA", + T__char: "_CHAR", + T__name: "_NAME", + T__int2: "_INT2", + T__int2vector: "_INT2VECTOR", + T__int4: "_INT4", + T__regproc: "_REGPROC", + T__text: "_TEXT", + T__tid: "_TID", + T__xid: "_XID", + T__cid: "_CID", + T__oidvector: "_OIDVECTOR", + T__bpchar: "_BPCHAR", + T__varchar: "_VARCHAR", + T__int8: "_INT8", + T__point: "_POINT", + T__lseg: "_LSEG", + T__path: "_PATH", + T__box: "_BOX", + T__float4: "_FLOAT4", + T__float8: "_FLOAT8", + T__abstime: "_ABSTIME", + T__reltime: "_RELTIME", + T__tinterval: "_TINTERVAL", + T__polygon: "_POLYGON", + T__oid: "_OID", + T_aclitem: "ACLITEM", + T__aclitem: "_ACLITEM", + T__macaddr: "_MACADDR", + T__inet: "_INET", + T_bpchar: "BPCHAR", + T_varchar: "VARCHAR", + T_date: "DATE", + T_time: "TIME", + T_timestamp: "TIMESTAMP", + T__timestamp: "_TIMESTAMP", + T__date: "_DATE", + T__time: "_TIME", + T_timestamptz: "TIMESTAMPTZ", + T__timestamptz: "_TIMESTAMPTZ", + T_interval: "INTERVAL", + T__interval: "_INTERVAL", + T__numeric: "_NUMERIC", + T_pg_database: "PG_DATABASE", + T__cstring: "_CSTRING", + T_timetz: "TIMETZ", + T__timetz: "_TIMETZ", + T_bit: "BIT", + T__bit: "_BIT", + T_varbit: "VARBIT", + T__varbit: "_VARBIT", + T_numeric: "NUMERIC", + T_refcursor: "REFCURSOR", + T__refcursor: "_REFCURSOR", + T_regprocedure: "REGPROCEDURE", + T_regoper: "REGOPER", + T_regoperator: "REGOPERATOR", + T_regclass: "REGCLASS", + T_regtype: "REGTYPE", + T__regprocedure: "_REGPROCEDURE", + T__regoper: "_REGOPER", + T__regoperator: "_REGOPERATOR", + T__regclass: "_REGCLASS", + T__regtype: "_REGTYPE", + T_record: "RECORD", + T_cstring: "CSTRING", + T_any: "ANY", + T_anyarray: "ANYARRAY", + T_void: "VOID", + T_trigger: "TRIGGER", + T_language_handler: "LANGUAGE_HANDLER", + T_internal: "INTERNAL", + T_opaque: "OPAQUE", + T_anyelement: "ANYELEMENT", + T__record: "_RECORD", + T_anynonarray: "ANYNONARRAY", + T_pg_authid: "PG_AUTHID", + T_pg_auth_members: "PG_AUTH_MEMBERS", + T__txid_snapshot: "_TXID_SNAPSHOT", + T_uuid: "UUID", + T__uuid: "_UUID", + T_txid_snapshot: "TXID_SNAPSHOT", + T_fdw_handler: "FDW_HANDLER", + T_pg_lsn: "PG_LSN", + T__pg_lsn: "_PG_LSN", + T_tsm_handler: "TSM_HANDLER", + T_anyenum: "ANYENUM", + T_tsvector: "TSVECTOR", + T_tsquery: "TSQUERY", + T_gtsvector: "GTSVECTOR", + T__tsvector: "_TSVECTOR", + T__gtsvector: "_GTSVECTOR", + T__tsquery: "_TSQUERY", + T_regconfig: "REGCONFIG", + T__regconfig: "_REGCONFIG", + T_regdictionary: "REGDICTIONARY", + T__regdictionary: "_REGDICTIONARY", + T_jsonb: "JSONB", + T__jsonb: "_JSONB", + T_anyrange: "ANYRANGE", + T_event_trigger: "EVENT_TRIGGER", + T_int4range: "INT4RANGE", + T__int4range: "_INT4RANGE", + T_numrange: "NUMRANGE", + T__numrange: "_NUMRANGE", + T_tsrange: "TSRANGE", + T__tsrange: "_TSRANGE", + T_tstzrange: "TSTZRANGE", + T__tstzrange: "_TSTZRANGE", + T_daterange: "DATERANGE", + T__daterange: "_DATERANGE", + T_int8range: "INT8RANGE", + T__int8range: "_INT8RANGE", + T_pg_shseclabel: "PG_SHSECLABEL", + T_regnamespace: "REGNAMESPACE", + T__regnamespace: "_REGNAMESPACE", + T_regrole: "REGROLE", + T__regrole: "_REGROLE", +} diff --git a/vendor/github.com/lib/pq/rows.go b/vendor/github.com/lib/pq/rows.go new file mode 100644 index 00000000..c6aa5b9a --- /dev/null +++ b/vendor/github.com/lib/pq/rows.go @@ -0,0 +1,93 @@ +package pq + +import ( + "math" + "reflect" + "time" + + "github.com/lib/pq/oid" +) + +const headerSize = 4 + +type fieldDesc struct { + // The object ID of the data type. + OID oid.Oid + // The data type size (see pg_type.typlen). + // Note that negative values denote variable-width types. + Len int + // The type modifier (see pg_attribute.atttypmod). + // The meaning of the modifier is type-specific. + Mod int +} + +func (fd fieldDesc) Type() reflect.Type { + switch fd.OID { + case oid.T_int8: + return reflect.TypeOf(int64(0)) + case oid.T_int4: + return reflect.TypeOf(int32(0)) + case oid.T_int2: + return reflect.TypeOf(int16(0)) + case oid.T_varchar, oid.T_text: + return reflect.TypeOf("") + case oid.T_bool: + return reflect.TypeOf(false) + case oid.T_date, oid.T_time, oid.T_timetz, oid.T_timestamp, oid.T_timestamptz: + return reflect.TypeOf(time.Time{}) + case oid.T_bytea: + return reflect.TypeOf([]byte(nil)) + default: + return reflect.TypeOf(new(interface{})).Elem() + } +} + +func (fd fieldDesc) Name() string { + return oid.TypeName[fd.OID] +} + +func (fd fieldDesc) Length() (length int64, ok bool) { + switch fd.OID { + case oid.T_text, oid.T_bytea: + return math.MaxInt64, true + case oid.T_varchar, oid.T_bpchar: + return int64(fd.Mod - headerSize), true + default: + return 0, false + } +} + +func (fd fieldDesc) PrecisionScale() (precision, scale int64, ok bool) { + switch fd.OID { + case oid.T_numeric, oid.T__numeric: + mod := fd.Mod - headerSize + precision = int64((mod >> 16) & 0xffff) + scale = int64(mod & 0xffff) + return precision, scale, true + default: + return 0, 0, false + } +} + +// ColumnTypeScanType returns the value type that can be used to scan types into. +func (rs *rows) ColumnTypeScanType(index int) reflect.Type { + return rs.colTyps[index].Type() +} + +// ColumnTypeDatabaseTypeName return the database system type name. +func (rs *rows) ColumnTypeDatabaseTypeName(index int) string { + return rs.colTyps[index].Name() +} + +// ColumnTypeLength returns the length of the column type if the column is a +// variable length type. If the column is not a variable length type ok +// should return false. +func (rs *rows) ColumnTypeLength(index int) (length int64, ok bool) { + return rs.colTyps[index].Length() +} + +// ColumnTypePrecisionScale should return the precision and scale for decimal +// types. If not applicable, ok should be false. +func (rs *rows) ColumnTypePrecisionScale(index int) (precision, scale int64, ok bool) { + return rs.colTyps[index].PrecisionScale() +} diff --git a/vendor/github.com/lib/pq/ssl.go b/vendor/github.com/lib/pq/ssl.go new file mode 100644 index 00000000..d9020845 --- /dev/null +++ b/vendor/github.com/lib/pq/ssl.go @@ -0,0 +1,175 @@ +package pq + +import ( + "crypto/tls" + "crypto/x509" + "io/ioutil" + "net" + "os" + "os/user" + "path/filepath" +) + +// ssl generates a function to upgrade a net.Conn based on the "sslmode" and +// related settings. The function is nil when no upgrade should take place. +func ssl(o values) (func(net.Conn) (net.Conn, error), error) { + verifyCaOnly := false + tlsConf := tls.Config{} + switch mode := o["sslmode"]; mode { + // "require" is the default. + case "", "require": + // We must skip TLS's own verification since it requires full + // verification since Go 1.3. + tlsConf.InsecureSkipVerify = true + + // From http://www.postgresql.org/docs/current/static/libpq-ssl.html: + // + // Note: For backwards compatibility with earlier versions of + // PostgreSQL, if a root CA file exists, the behavior of + // sslmode=require will be the same as that of verify-ca, meaning the + // server certificate is validated against the CA. Relying on this + // behavior is discouraged, and applications that need certificate + // validation should always use verify-ca or verify-full. + if sslrootcert, ok := o["sslrootcert"]; ok { + if _, err := os.Stat(sslrootcert); err == nil { + verifyCaOnly = true + } else { + delete(o, "sslrootcert") + } + } + case "verify-ca": + // We must skip TLS's own verification since it requires full + // verification since Go 1.3. + tlsConf.InsecureSkipVerify = true + verifyCaOnly = true + case "verify-full": + tlsConf.ServerName = o["host"] + case "disable": + return nil, nil + default: + return nil, fmterrorf(`unsupported sslmode %q; only "require" (default), "verify-full", "verify-ca", and "disable" supported`, mode) + } + + err := sslClientCertificates(&tlsConf, o) + if err != nil { + return nil, err + } + err = sslCertificateAuthority(&tlsConf, o) + if err != nil { + return nil, err + } + + // Accept renegotiation requests initiated by the backend. + // + // Renegotiation was deprecated then removed from PostgreSQL 9.5, but + // the default configuration of older versions has it enabled. Redshift + // also initiates renegotiations and cannot be reconfigured. + tlsConf.Renegotiation = tls.RenegotiateFreelyAsClient + + return func(conn net.Conn) (net.Conn, error) { + client := tls.Client(conn, &tlsConf) + if verifyCaOnly { + err := sslVerifyCertificateAuthority(client, &tlsConf) + if err != nil { + return nil, err + } + } + return client, nil + }, nil +} + +// sslClientCertificates adds the certificate specified in the "sslcert" and +// "sslkey" settings, or if they aren't set, from the .postgresql directory +// in the user's home directory. The configured files must exist and have +// the correct permissions. +func sslClientCertificates(tlsConf *tls.Config, o values) error { + // user.Current() might fail when cross-compiling. We have to ignore the + // error and continue without home directory defaults, since we wouldn't + // know from where to load them. + user, _ := user.Current() + + // In libpq, the client certificate is only loaded if the setting is not blank. + // + // https://github.com/postgres/postgres/blob/REL9_6_2/src/interfaces/libpq/fe-secure-openssl.c#L1036-L1037 + sslcert := o["sslcert"] + if len(sslcert) == 0 && user != nil { + sslcert = filepath.Join(user.HomeDir, ".postgresql", "postgresql.crt") + } + // https://github.com/postgres/postgres/blob/REL9_6_2/src/interfaces/libpq/fe-secure-openssl.c#L1045 + if len(sslcert) == 0 { + return nil + } + // https://github.com/postgres/postgres/blob/REL9_6_2/src/interfaces/libpq/fe-secure-openssl.c#L1050:L1054 + if _, err := os.Stat(sslcert); os.IsNotExist(err) { + return nil + } else if err != nil { + return err + } + + // In libpq, the ssl key is only loaded if the setting is not blank. + // + // https://github.com/postgres/postgres/blob/REL9_6_2/src/interfaces/libpq/fe-secure-openssl.c#L1123-L1222 + sslkey := o["sslkey"] + if len(sslkey) == 0 && user != nil { + sslkey = filepath.Join(user.HomeDir, ".postgresql", "postgresql.key") + } + + if len(sslkey) > 0 { + if err := sslKeyPermissions(sslkey); err != nil { + return err + } + } + + cert, err := tls.LoadX509KeyPair(sslcert, sslkey) + if err != nil { + return err + } + + tlsConf.Certificates = []tls.Certificate{cert} + return nil +} + +// sslCertificateAuthority adds the RootCA specified in the "sslrootcert" setting. +func sslCertificateAuthority(tlsConf *tls.Config, o values) error { + // In libpq, the root certificate is only loaded if the setting is not blank. + // + // https://github.com/postgres/postgres/blob/REL9_6_2/src/interfaces/libpq/fe-secure-openssl.c#L950-L951 + if sslrootcert := o["sslrootcert"]; len(sslrootcert) > 0 { + tlsConf.RootCAs = x509.NewCertPool() + + cert, err := ioutil.ReadFile(sslrootcert) + if err != nil { + return err + } + + if !tlsConf.RootCAs.AppendCertsFromPEM(cert) { + return fmterrorf("couldn't parse pem in sslrootcert") + } + } + + return nil +} + +// sslVerifyCertificateAuthority carries out a TLS handshake to the server and +// verifies the presented certificate against the CA, i.e. the one specified in +// sslrootcert or the system CA if sslrootcert was not specified. +func sslVerifyCertificateAuthority(client *tls.Conn, tlsConf *tls.Config) error { + err := client.Handshake() + if err != nil { + return err + } + certs := client.ConnectionState().PeerCertificates + opts := x509.VerifyOptions{ + DNSName: client.ConnectionState().ServerName, + Intermediates: x509.NewCertPool(), + Roots: tlsConf.RootCAs, + } + for i, cert := range certs { + if i == 0 { + continue + } + opts.Intermediates.AddCert(cert) + } + _, err = certs[0].Verify(opts) + return err +} diff --git a/vendor/github.com/lib/pq/ssl_permissions.go b/vendor/github.com/lib/pq/ssl_permissions.go new file mode 100644 index 00000000..3b7c3a2a --- /dev/null +++ b/vendor/github.com/lib/pq/ssl_permissions.go @@ -0,0 +1,20 @@ +// +build !windows + +package pq + +import "os" + +// sslKeyPermissions checks the permissions on user-supplied ssl key files. +// The key file should have very little access. +// +// libpq does not check key file permissions on Windows. +func sslKeyPermissions(sslkey string) error { + info, err := os.Stat(sslkey) + if err != nil { + return err + } + if info.Mode().Perm()&0077 != 0 { + return ErrSSLKeyHasWorldPermissions + } + return nil +} diff --git a/vendor/github.com/lib/pq/ssl_windows.go b/vendor/github.com/lib/pq/ssl_windows.go new file mode 100644 index 00000000..5d2c763c --- /dev/null +++ b/vendor/github.com/lib/pq/ssl_windows.go @@ -0,0 +1,9 @@ +// +build windows + +package pq + +// sslKeyPermissions checks the permissions on user-supplied ssl key files. +// The key file should have very little access. +// +// libpq does not check key file permissions on Windows. +func sslKeyPermissions(string) error { return nil } diff --git a/vendor/github.com/lib/pq/uuid.go b/vendor/github.com/lib/pq/uuid.go new file mode 100644 index 00000000..9a1b9e07 --- /dev/null +++ b/vendor/github.com/lib/pq/uuid.go @@ -0,0 +1,23 @@ +package pq + +import ( + "encoding/hex" + "fmt" +) + +// decodeUUIDBinary interprets the binary format of a uuid, returning it in text format. +func decodeUUIDBinary(src []byte) ([]byte, error) { + if len(src) != 16 { + return nil, fmt.Errorf("pq: unable to decode uuid; bad length: %d", len(src)) + } + + dst := make([]byte, 36) + dst[8], dst[13], dst[18], dst[23] = '-', '-', '-', '-' + hex.Encode(dst[0:], src[0:4]) + hex.Encode(dst[9:], src[4:6]) + hex.Encode(dst[14:], src[6:8]) + hex.Encode(dst[19:], src[8:10]) + hex.Encode(dst[24:], src[10:16]) + + return dst, nil +} From 587081a643af2e9e6011750eccd469c0ef0f16c8 Mon Sep 17 00:00:00 2001 From: Alex Suraci Date: Mon, 19 Nov 2018 11:34:45 -0500 Subject: [PATCH 5/6] postgres: refactor error handling to fix retrying prior to this change, many of the functions in the ExecTx callback would wrap the error before returning it. this made it impossible to check for the error code. instead, the error wrapping has been moved to be external to the `ExecTx` callback, so that the error code can be checked and serialization failures can be retried. --- storage/sql/crud.go | 226 ++++++++++++++++++++++++++++---------------- storage/sql/sql.go | 14 ++- 2 files changed, 152 insertions(+), 88 deletions(-) diff --git a/storage/sql/crud.go b/storage/sql/crud.go index d7c055ab..a1406e20 100644 --- a/storage/sql/crud.go +++ b/storage/sql/crud.go @@ -134,7 +134,7 @@ func (c *conn) CreateAuthRequest(a storage.AuthRequest) error { } func (c *conn) UpdateAuthRequest(id string, updater func(a storage.AuthRequest) (storage.AuthRequest, error)) error { - return c.ExecTx(func(tx *trans) error { + err := c.ExecTx(func(tx *trans) error { r, err := getAuthRequest(tx, id) if err != nil { return err @@ -144,6 +144,7 @@ func (c *conn) UpdateAuthRequest(id string, updater func(a storage.AuthRequest) if err != nil { return err } + _, err = tx.Exec(` update auth_request set @@ -163,21 +164,31 @@ func (c *conn) UpdateAuthRequest(id string, updater func(a storage.AuthRequest) a.ConnectorID, a.ConnectorData, a.Expiry, r.ID, ) - if err != nil { - return fmt.Errorf("update auth request: %v", err) - } - return nil + return err }) + if err != nil { + return fmt.Errorf("update auth request: %v", err) + } + return nil } func (c *conn) GetAuthRequest(id string) (storage.AuthRequest, error) { - return getAuthRequest(c, id) + req, err := getAuthRequest(c, id) + if err != nil { + if err == sql.ErrNoRows { + return storage.AuthRequest{}, storage.ErrNotFound + } + + return storage.AuthRequest{}, fmt.Errorf("select auth request: %v", err) + } + + return req, nil } func getAuthRequest(q querier, id string) (a storage.AuthRequest, err error) { err = q.QueryRow(` - select + select id, client_id, response_types, scopes, redirect_uri, nonce, state, force_approval_prompt, logged_in, claims_user_id, claims_username, claims_email, claims_email_verified, @@ -192,10 +203,7 @@ func getAuthRequest(q querier, id string) (a storage.AuthRequest, err error) { &a.ConnectorID, &a.ConnectorData, &a.Expiry, ) if err != nil { - if err == sql.ErrNoRows { - return a, storage.ErrNotFound - } - return a, fmt.Errorf("select auth request: %v", err) + return a, err } return a, nil } @@ -269,20 +277,22 @@ func (c *conn) CreateRefresh(r storage.RefreshToken) error { if c.alreadyExistsCheck(err) { return storage.ErrAlreadyExists } - return fmt.Errorf("insert refresh_token: %v", err) + return fmt.Errorf("insert refresh token: %v", err) } return nil } func (c *conn) UpdateRefreshToken(id string, updater func(old storage.RefreshToken) (storage.RefreshToken, error)) error { - return c.ExecTx(func(tx *trans) error { + err := c.ExecTx(func(tx *trans) error { r, err := getRefresh(tx, id) if err != nil { return err } + if r, err = updater(r); err != nil { return err } + _, err = tx.Exec(` update refresh_token set @@ -308,15 +318,25 @@ func (c *conn) UpdateRefreshToken(id string, updater func(old storage.RefreshTok r.ConnectorID, r.ConnectorData, r.Token, r.CreatedAt, r.LastUsed, id, ) - if err != nil { - return fmt.Errorf("update refresh token: %v", err) - } - return nil + return err }) + if err != nil { + return fmt.Errorf("update refresh token: %v", err) + } + return nil } func (c *conn) GetRefresh(id string) (storage.RefreshToken, error) { - return getRefresh(c, id) + req, err := getRefresh(c, id) + if err != nil { + if err == sql.ErrNoRows { + return storage.RefreshToken{}, storage.ErrNotFound + } + + return storage.RefreshToken{}, fmt.Errorf("get refresh token: %v", err) + } + + return req, nil } func getRefresh(q querier, id string) (storage.RefreshToken, error) { @@ -342,14 +362,15 @@ func (c *conn) ListRefreshTokens() ([]storage.RefreshToken, error) { from refresh_token; `) if err != nil { - return nil, fmt.Errorf("query: %v", err) + return nil, fmt.Errorf("select refresh tokens: %v", err) } var tokens []storage.RefreshToken for rows.Next() { r, err := scanRefresh(rows) if err != nil { - return nil, err + return nil, fmt.Errorf("scan refresh token: %s", err) } + tokens = append(tokens, r) } if err := rows.Err(); err != nil { @@ -367,10 +388,7 @@ func scanRefresh(s scanner) (r storage.RefreshToken, err error) { &r.Token, &r.CreatedAt, &r.LastUsed, ) if err != nil { - if err == sql.ErrNoRows { - return r, storage.ErrNotFound - } - return r, fmt.Errorf("scan refresh_token: %v", err) + return r, err } return r, nil } @@ -381,12 +399,11 @@ func (c *conn) UpdateKeys(updater func(old storage.Keys) (storage.Keys, error)) // TODO(ericchiang): errors may cause a transaction be rolled back by the SQL // server. Test this, and consider adding a COUNT() command beforehand. old, err := getKeys(tx) - if err != nil { - if err != storage.ErrNotFound { - return fmt.Errorf("get keys: %v", err) - } + if err == sql.ErrNoRows { firstUpdate = true old = storage.Keys{} + } else if err != nil { + return err } nk, err := updater(old) @@ -405,12 +422,12 @@ func (c *conn) UpdateKeys(updater func(old storage.Keys) (storage.Keys, error)) encoder(nk.SigningKeyPub), nk.NextRotation, ) if err != nil { - return fmt.Errorf("insert: %v", err) + return err } } else { _, err = tx.Exec(` update keys - set + set verification_keys = $1, signing_key = $2, signing_key_pub = $3, @@ -421,15 +438,24 @@ func (c *conn) UpdateKeys(updater func(old storage.Keys) (storage.Keys, error)) encoder(nk.SigningKeyPub), nk.NextRotation, keysRowID, ) if err != nil { - return fmt.Errorf("update: %v", err) + return err } } return nil }) } -func (c *conn) GetKeys() (keys storage.Keys, err error) { - return getKeys(c) +func (c *conn) GetKeys() (storage.Keys, error) { + keys, err := getKeys(c) + if err != nil { + if err == sql.ErrNoRows { + return storage.Keys{}, storage.ErrNotFound + } + + return storage.Keys{}, fmt.Errorf("select keys: %s", err) + } + + return keys, nil } func getKeys(q querier) (keys storage.Keys, err error) { @@ -443,20 +469,18 @@ func getKeys(q querier) (keys storage.Keys, err error) { decoder(&keys.SigningKeyPub), &keys.NextRotation, ) if err != nil { - if err == sql.ErrNoRows { - return keys, storage.ErrNotFound - } - return keys, fmt.Errorf("query keys: %v", err) + return keys, err } return keys, nil } func (c *conn) UpdateClient(id string, updater func(old storage.Client) (storage.Client, error)) error { - return c.ExecTx(func(tx *trans) error { + err := c.ExecTx(func(tx *trans) error { cli, err := getClient(tx, id) if err != nil { return err } + nc, err := updater(cli) if err != nil { return err @@ -474,11 +498,13 @@ func (c *conn) UpdateClient(id string, updater func(old storage.Client) (storage where id = $7; `, nc.Secret, encoder(nc.RedirectURIs), encoder(nc.TrustedPeers), nc.Public, nc.Name, nc.LogoURL, id, ) - if err != nil { - return fmt.Errorf("update client: %v", err) - } - return nil + return err }) + if err != nil { + return fmt.Errorf("update client: %v", err) + } + + return nil } func (c *conn) CreateClient(cli storage.Client) error { @@ -509,7 +535,16 @@ func getClient(q querier, id string) (storage.Client, error) { } func (c *conn) GetClient(id string) (storage.Client, error) { - return getClient(c, id) + client, err := getClient(c, id) + if err != nil { + if err == sql.ErrNoRows { + return storage.Client{}, storage.ErrNotFound + } + + return storage.Client{}, fmt.Errorf("select client: %v", err) + } + + return client, nil } func (c *conn) ListClients() ([]storage.Client, error) { @@ -525,12 +560,12 @@ func (c *conn) ListClients() ([]storage.Client, error) { for rows.Next() { cli, err := scanClient(rows) if err != nil { - return nil, err + return nil, fmt.Errorf("scan client: %s", err) } clients = append(clients, cli) } if err := rows.Err(); err != nil { - return nil, err + return nil, fmt.Errorf("scan: %s", err) } return clients, nil } @@ -541,10 +576,7 @@ func scanClient(s scanner) (cli storage.Client, err error) { &cli.Public, &cli.Name, &cli.LogoURL, ) if err != nil { - if err == sql.ErrNoRows { - return cli, storage.ErrNotFound - } - return cli, fmt.Errorf("get client: %v", err) + return cli, err } return cli, nil } @@ -571,7 +603,7 @@ func (c *conn) CreatePassword(p storage.Password) error { } func (c *conn) UpdatePassword(email string, updater func(p storage.Password) (storage.Password, error)) error { - return c.ExecTx(func(tx *trans) error { + err := c.ExecTx(func(tx *trans) error { p, err := getPassword(tx, email) if err != nil { return err @@ -581,6 +613,7 @@ func (c *conn) UpdatePassword(email string, updater func(p storage.Password) (st if err != nil { return err } + _, err = tx.Exec(` update password set @@ -589,15 +622,25 @@ func (c *conn) UpdatePassword(email string, updater func(p storage.Password) (st `, np.Hash, np.Username, np.UserID, p.Email, ) - if err != nil { - return fmt.Errorf("update password: %v", err) - } - return nil + return err }) + if err != nil { + return fmt.Errorf("update password: %v", err) + } + return nil } func (c *conn) GetPassword(email string) (storage.Password, error) { - return getPassword(c, email) + pass, err := getPassword(c, email) + if err != nil { + if err == sql.ErrNoRows { + return storage.Password{}, storage.ErrNotFound + } + + return storage.Password{}, fmt.Errorf("get password: %s", err) + } + + return pass, nil } func getPassword(q querier, email string) (p storage.Password, err error) { @@ -622,12 +665,12 @@ func (c *conn) ListPasswords() ([]storage.Password, error) { for rows.Next() { p, err := scanPassword(rows) if err != nil { - return nil, err + return nil, fmt.Errorf("scan password: %s", err) } passwords = append(passwords, p) } if err := rows.Err(); err != nil { - return nil, err + return nil, fmt.Errorf("scan: %s", err) } return passwords, nil } @@ -637,10 +680,7 @@ func scanPassword(s scanner) (p storage.Password, err error) { &p.Email, &p.Hash, &p.Username, &p.UserID, ) if err != nil { - if err == sql.ErrNoRows { - return p, storage.ErrNotFound - } - return p, fmt.Errorf("select password: %v", err) + return p, err } return p, nil } @@ -666,7 +706,7 @@ func (c *conn) CreateOfflineSessions(s storage.OfflineSessions) error { } func (c *conn) UpdateOfflineSessions(userID string, connID string, updater func(s storage.OfflineSessions) (storage.OfflineSessions, error)) error { - return c.ExecTx(func(tx *trans) error { + err := c.ExecTx(func(tx *trans) error { s, err := getOfflineSessions(tx, userID, connID) if err != nil { return err @@ -676,6 +716,7 @@ func (c *conn) UpdateOfflineSessions(userID string, connID string, updater func( if err != nil { return err } + _, err = tx.Exec(` update offline_session set @@ -684,15 +725,26 @@ func (c *conn) UpdateOfflineSessions(userID string, connID string, updater func( `, encoder(newSession.Refresh), s.UserID, s.ConnID, ) - if err != nil { - return fmt.Errorf("update offline session: %v", err) - } - return nil + return err }) + if err != nil { + return fmt.Errorf("update offline session: %v", err) + } + + return nil } func (c *conn) GetOfflineSessions(userID string, connID string) (storage.OfflineSessions, error) { - return getOfflineSessions(c, userID, connID) + sessions, err := getOfflineSessions(c, userID, connID) + if err != nil { + if err == sql.ErrNoRows { + return storage.OfflineSessions{}, storage.ErrNotFound + } + + return storage.OfflineSessions{}, fmt.Errorf("get offline sessions: %s", err) + } + + return sessions, nil } func getOfflineSessions(q querier, userID string, connID string) (storage.OfflineSessions, error) { @@ -709,10 +761,7 @@ func scanOfflineSessions(s scanner) (o storage.OfflineSessions, err error) { &o.UserID, &o.ConnID, decoder(&o.Refresh), ) if err != nil { - if err == sql.ErrNoRows { - return o, storage.ErrNotFound - } - return o, fmt.Errorf("select offline session: %v", err) + return o, err } return o, nil } @@ -738,7 +787,7 @@ func (c *conn) CreateConnector(connector storage.Connector) error { } func (c *conn) UpdateConnector(id string, updater func(s storage.Connector) (storage.Connector, error)) error { - return c.ExecTx(func(tx *trans) error { + err := c.ExecTx(func(tx *trans) error { connector, err := getConnector(tx, id) if err != nil { return err @@ -748,9 +797,10 @@ func (c *conn) UpdateConnector(id string, updater func(s storage.Connector) (sto if err != nil { return err } + _, err = tx.Exec(` update connector - set + set type = $1, name = $2, resource_version = $3, @@ -759,15 +809,26 @@ func (c *conn) UpdateConnector(id string, updater func(s storage.Connector) (sto `, newConn.Type, newConn.Name, newConn.ResourceVersion, newConn.Config, connector.ID, ) - if err != nil { - return fmt.Errorf("update connector: %v", err) - } - return nil + return err }) + if err != nil { + return fmt.Errorf("update connector: %v", err) + } + + return nil } func (c *conn) GetConnector(id string) (storage.Connector, error) { - return getConnector(c, id) + connector, err := getConnector(c, id) + if err != nil { + if err == sql.ErrNoRows { + return storage.Connector{}, storage.ErrNotFound + } + + return storage.Connector{}, fmt.Errorf("get connector: %s", err) + } + + return connector, nil } func getConnector(q querier, id string) (storage.Connector, error) { @@ -784,10 +845,7 @@ func scanConnector(s scanner) (c storage.Connector, err error) { &c.ID, &c.Type, &c.Name, &c.ResourceVersion, &c.Config, ) if err != nil { - if err == sql.ErrNoRows { - return c, storage.ErrNotFound - } - return c, fmt.Errorf("select connector: %v", err) + return c, err } return c, nil } @@ -805,12 +863,12 @@ func (c *conn) ListConnectors() ([]storage.Connector, error) { for rows.Next() { conn, err := scanConnector(rows) if err != nil { - return nil, err + return nil, fmt.Errorf("scan connector: %s", err) } connectors = append(connectors, conn) } if err := rows.Err(); err != nil { - return nil, err + return nil, fmt.Errorf("scan: %s", err) } return connectors, nil } diff --git a/storage/sql/sql.go b/storage/sql/sql.go index 7f20cf9d..b51f6fcc 100644 --- a/storage/sql/sql.go +++ b/storage/sql/sql.go @@ -44,13 +44,14 @@ var ( // The "github.com/lib/pq" driver is the default flavor. All others are // translations of this. flavorPostgres = flavor{ - // The default behavior for Postgres transactions is consistent reads, not consistent writes. - // For each transaction opened, ensure it has the correct isolation level. + // The default behavior for Postgres transactions is consistent reads, not + // consistent writes. For each transaction opened, ensure it has the + // correct isolation level. // // See: https://www.postgresql.org/docs/9.3/static/sql-set-transaction.html // - // NOTE(ericchiang): For some reason using `SET SESSION CHARACTERISTICS AS TRANSACTION` at a - // session level didn't work for some edge cases. Might be something worth exploring. + // Be careful not to wrap sql errors in the callback 'fn', otherwise + // serialization failures will not be detected and retried. executeTx: func(db *sql.DB, fn func(sqlTx *sql.Tx) error) error { ctx, cancel := context.WithCancel(context.TODO()) defer cancel() @@ -66,6 +67,11 @@ var ( } if err := fn(tx); err != nil { + if pqErr, ok := err.(*pq.Error); ok && pqErr.Code.Name() == "serialization_failure" { + // serialization error; retry + continue + } + return err } From 85dd0684babf0d718dee796cf6176c59a0661ab3 Mon Sep 17 00:00:00 2001 From: Alex Suraci Date: Tue, 20 Nov 2018 09:42:30 -0500 Subject: [PATCH 6/6] extract and document serialization failure check --- storage/sql/sql.go | 21 +++++++++++++++++---- 1 file changed, 17 insertions(+), 4 deletions(-) diff --git a/storage/sql/sql.go b/storage/sql/sql.go index b51f6fcc..69b03cbd 100644 --- a/storage/sql/sql.go +++ b/storage/sql/sql.go @@ -40,6 +40,21 @@ func matchLiteral(s string) *regexp.Regexp { return regexp.MustCompile(`\b` + regexp.QuoteMeta(s) + `\b`) } +// Detect a serialization failure, which should trigger retrying the +// transaction according to PostgreSQL docs: +// +// https://www.postgresql.org/docs/current/transaction-iso.html#XACT-SERIALIZABLE +// +// "applications using this level must be prepared to retry transactions due to +// serialization failures" +func isRetryableSerializationFailure(err error) bool { + if pqErr, ok := err.(*pq.Error); ok { + return pqErr.Code.Name() == "serialization_failure" + } + + return false +} + var ( // The "github.com/lib/pq" driver is the default flavor. All others are // translations of this. @@ -67,8 +82,7 @@ var ( } if err := fn(tx); err != nil { - if pqErr, ok := err.(*pq.Error); ok && pqErr.Code.Name() == "serialization_failure" { - // serialization error; retry + if isRetryableSerializationFailure(err) { continue } @@ -77,8 +91,7 @@ var ( err = tx.Commit() if err != nil { - if pqErr, ok := err.(*pq.Error); ok && pqErr.Code.Name() == "serialization_failure" { - // serialization error; retry + if isRetryableSerializationFailure(err) { continue }