Merge pull request #1346 from concourse/pr/postgres-unix-sockets
Use pq connection parameters instead of URLs for postgres connections This enables the use of socket paths like /var/run/postgresql for the 'host' instead of requiring TCP. Also, we know allow using a non-default port.
This commit is contained in:
		| @@ -3,8 +3,9 @@ package sql | |||||||
| import ( | import ( | ||||||
| 	"database/sql" | 	"database/sql" | ||||||
| 	"fmt" | 	"fmt" | ||||||
| 	"net/url" | 	"regexp" | ||||||
| 	"strconv" | 	"strconv" | ||||||
|  | 	"strings" | ||||||
|  |  | ||||||
| 	"github.com/lib/pq" | 	"github.com/lib/pq" | ||||||
| 	sqlite3 "github.com/mattn/go-sqlite3" | 	sqlite3 "github.com/mattn/go-sqlite3" | ||||||
| @@ -81,6 +82,7 @@ type Postgres struct { | |||||||
| 	User     string | 	User     string | ||||||
| 	Password string | 	Password string | ||||||
| 	Host     string | 	Host     string | ||||||
|  | 	Port     uint16 | ||||||
|  |  | ||||||
| 	SSL PostgresSSL `json:"ssl" yaml:"ssl"` | 	SSL PostgresSSL `json:"ssl" yaml:"ssl"` | ||||||
|  |  | ||||||
| @@ -89,45 +91,75 @@ type Postgres struct { | |||||||
|  |  | ||||||
| // Open creates a new storage implementation backed by Postgres. | // Open creates a new storage implementation backed by Postgres. | ||||||
| func (p *Postgres) Open(logger logrus.FieldLogger) (storage.Storage, error) { | func (p *Postgres) Open(logger logrus.FieldLogger) (storage.Storage, error) { | ||||||
| 	conn, err := p.open(logger) | 	conn, err := p.open(logger, p.createDataSourceName()) | ||||||
| 	if err != nil { | 	if err != nil { | ||||||
| 		return nil, err | 		return nil, err | ||||||
| 	} | 	} | ||||||
| 	return conn, nil | 	return conn, nil | ||||||
| } | } | ||||||
|  |  | ||||||
| func (p *Postgres) open(logger logrus.FieldLogger) (*conn, error) { | var strEsc = regexp.MustCompile(`([\\'])`) | ||||||
| 	v := url.Values{} |  | ||||||
| 	set := func(key, val string) { |  | ||||||
| 		if val != "" { |  | ||||||
| 			v.Set(key, val) |  | ||||||
| 		} |  | ||||||
| 	} |  | ||||||
| 	set("connect_timeout", strconv.Itoa(p.ConnectionTimeout)) |  | ||||||
| 	set("sslkey", p.SSL.KeyFile) |  | ||||||
| 	set("sslcert", p.SSL.CertFile) |  | ||||||
| 	set("sslrootcert", p.SSL.CAFile) |  | ||||||
| 	if p.SSL.Mode == "" { |  | ||||||
| 		// Assume the strictest mode if unspecified. |  | ||||||
| 		p.SSL.Mode = sslVerifyFull |  | ||||||
| 	} |  | ||||||
| 	set("sslmode", p.SSL.Mode) |  | ||||||
|  |  | ||||||
| 	u := url.URL{ | func dataSourceStr(str string) string { | ||||||
| 		Scheme:   "postgres", | 	return "'" + strEsc.ReplaceAllString(str, `\$1`) + "'" | ||||||
| 		Host:     p.Host, | } | ||||||
| 		Path:     "/" + p.Database, |  | ||||||
| 		RawQuery: v.Encode(), | // createDataSourceName takes the configuration provided via the Postgres | ||||||
|  | // struct to create a data-source name that Go's database/sql package can | ||||||
|  | // make use of. | ||||||
|  | func (p *Postgres) createDataSourceName() string { | ||||||
|  | 	parameters := []string{} | ||||||
|  |  | ||||||
|  | 	addParam := func(key, val string) { | ||||||
|  | 		parameters = append(parameters, fmt.Sprintf("%s=%s", key, val)) | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	addParam("connect_timeout", strconv.Itoa(p.ConnectionTimeout)) | ||||||
|  |  | ||||||
|  | 	if p.Host != "" { | ||||||
|  | 		addParam("host", dataSourceStr(p.Host)) | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	if p.Port != 0 { | ||||||
|  | 		addParam("port", strconv.Itoa(int(p.Port))) | ||||||
| 	} | 	} | ||||||
|  |  | ||||||
| 	if p.User != "" { | 	if p.User != "" { | ||||||
| 		if p.Password != "" { | 		addParam("user", dataSourceStr(p.User)) | ||||||
| 			u.User = url.UserPassword(p.User, p.Password) |  | ||||||
| 		} else { |  | ||||||
| 			u.User = url.User(p.User) |  | ||||||
| 		} |  | ||||||
| 	} | 	} | ||||||
| 	db, err := sql.Open("postgres", u.String()) |  | ||||||
|  | 	if p.Password != "" { | ||||||
|  | 		addParam("password", dataSourceStr(p.Password)) | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	if p.Database != "" { | ||||||
|  | 		addParam("dbname", dataSourceStr(p.Database)) | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	if p.SSL.Mode == "" { | ||||||
|  | 		// Assume the strictest mode if unspecified. | ||||||
|  | 		addParam("sslmode", dataSourceStr(sslVerifyFull)) | ||||||
|  | 	} else { | ||||||
|  | 		addParam("sslmode", dataSourceStr(p.SSL.Mode)) | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	if p.SSL.CAFile != "" { | ||||||
|  | 		addParam("sslrootcert", dataSourceStr(p.SSL.CAFile)) | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	if p.SSL.CertFile != "" { | ||||||
|  | 		addParam("sslcert", dataSourceStr(p.SSL.CertFile)) | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	if p.SSL.KeyFile != "" { | ||||||
|  | 		addParam("sslkey", dataSourceStr(p.SSL.KeyFile)) | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	return strings.Join(parameters, " ") | ||||||
|  | } | ||||||
|  |  | ||||||
|  | func (p *Postgres) open(logger logrus.FieldLogger, dataSourceName string) (*conn, error) { | ||||||
|  | 	db, err := sql.Open("postgres", dataSourceName) | ||||||
| 	if err != nil { | 	if err != nil { | ||||||
| 		return nil, err | 		return nil, err | ||||||
| 	} | 	} | ||||||
|   | |||||||
| @@ -77,6 +77,103 @@ func getenv(key, defaultVal string) string { | |||||||
|  |  | ||||||
| const testPostgresEnv = "DEX_POSTGRES_HOST" | const testPostgresEnv = "DEX_POSTGRES_HOST" | ||||||
|  |  | ||||||
|  | func TestCreateDataSourceName(t *testing.T) { | ||||||
|  | 	var testCases = []struct { | ||||||
|  | 		description string | ||||||
|  | 		input       *Postgres | ||||||
|  | 		expected    string | ||||||
|  | 	}{ | ||||||
|  | 		{ | ||||||
|  | 			description: "with no configuration", | ||||||
|  | 			input:       &Postgres{}, | ||||||
|  | 			expected:    "connect_timeout=0 sslmode='verify-full'", | ||||||
|  | 		}, | ||||||
|  | 		{ | ||||||
|  | 			description: "with typical configuration", | ||||||
|  | 			input: &Postgres{ | ||||||
|  | 				Host:     "1.2.3.4", | ||||||
|  | 				Port:     6543, | ||||||
|  | 				User:     "some-user", | ||||||
|  | 				Password: "some-password", | ||||||
|  | 				Database: "some-db", | ||||||
|  | 			}, | ||||||
|  | 			expected: "connect_timeout=0 host='1.2.3.4' port=6543 user='some-user' password='some-password' dbname='some-db' sslmode='verify-full'", | ||||||
|  | 		}, | ||||||
|  | 		{ | ||||||
|  | 			description: "with unix socket host", | ||||||
|  | 			input: &Postgres{ | ||||||
|  | 				Host: "/var/run/postgres", | ||||||
|  | 				SSL: PostgresSSL{ | ||||||
|  | 					Mode: "disable", | ||||||
|  | 				}, | ||||||
|  | 			}, | ||||||
|  | 			expected: "connect_timeout=0 host='/var/run/postgres' sslmode='disable'", | ||||||
|  | 		}, | ||||||
|  | 		{ | ||||||
|  | 			description: "with tcp host", | ||||||
|  | 			input: &Postgres{ | ||||||
|  | 				Host: "coreos.com", | ||||||
|  | 				SSL: PostgresSSL{ | ||||||
|  | 					Mode: "disable", | ||||||
|  | 				}, | ||||||
|  | 			}, | ||||||
|  | 			expected: "connect_timeout=0 host='coreos.com' sslmode='disable'", | ||||||
|  | 		}, | ||||||
|  | 		{ | ||||||
|  | 			description: "with tcp host and port", | ||||||
|  | 			input: &Postgres{ | ||||||
|  | 				Host: "coreos.com", | ||||||
|  | 				Port: 6543, | ||||||
|  | 			}, | ||||||
|  | 			expected: "connect_timeout=0 host='coreos.com' port=6543 sslmode='verify-full'", | ||||||
|  | 		}, | ||||||
|  | 		{ | ||||||
|  | 			description: "with ssl ca cert", | ||||||
|  | 			input: &Postgres{ | ||||||
|  | 				Host: "coreos.com", | ||||||
|  | 				SSL: PostgresSSL{ | ||||||
|  | 					Mode:   "verify-ca", | ||||||
|  | 					CAFile: "/some/file/path", | ||||||
|  | 				}, | ||||||
|  | 			}, | ||||||
|  | 			expected: "connect_timeout=0 host='coreos.com' sslmode='verify-ca' sslrootcert='/some/file/path'", | ||||||
|  | 		}, | ||||||
|  | 		{ | ||||||
|  | 			description: "with ssl client cert", | ||||||
|  | 			input: &Postgres{ | ||||||
|  | 				Host: "coreos.com", | ||||||
|  | 				SSL: PostgresSSL{ | ||||||
|  | 					Mode:     "verify-ca", | ||||||
|  | 					CAFile:   "/some/ca/path", | ||||||
|  | 					CertFile: "/some/cert/path", | ||||||
|  | 					KeyFile:  "/some/key/path", | ||||||
|  | 				}, | ||||||
|  | 			}, | ||||||
|  | 			expected: "connect_timeout=0 host='coreos.com' sslmode='verify-ca' sslrootcert='/some/ca/path' sslcert='/some/cert/path' sslkey='/some/key/path'", | ||||||
|  | 		}, | ||||||
|  | 		{ | ||||||
|  | 			description: "with funny characters in credentials", | ||||||
|  | 			input: &Postgres{ | ||||||
|  | 				Host:     "coreos.com", | ||||||
|  | 				User:     `some'user\slashed`, | ||||||
|  | 				Password: "some'password!", | ||||||
|  | 			}, | ||||||
|  | 			expected: `connect_timeout=0 host='coreos.com' user='some\'user\\slashed' password='some\'password!' sslmode='verify-full'`, | ||||||
|  | 		}, | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	var actual string | ||||||
|  | 	for _, testCase := range testCases { | ||||||
|  | 		t.Run(testCase.description, func(t *testing.T) { | ||||||
|  | 			actual = testCase.input.createDataSourceName() | ||||||
|  |  | ||||||
|  | 			if actual != testCase.expected { | ||||||
|  | 				t.Fatalf("%s != %s", actual, testCase.expected) | ||||||
|  | 			} | ||||||
|  | 		}) | ||||||
|  | 	} | ||||||
|  | } | ||||||
|  |  | ||||||
| func TestPostgres(t *testing.T) { | func TestPostgres(t *testing.T) { | ||||||
| 	host := os.Getenv(testPostgresEnv) | 	host := os.Getenv(testPostgresEnv) | ||||||
| 	if host == "" { | 	if host == "" { | ||||||
| @@ -100,7 +197,7 @@ func TestPostgres(t *testing.T) { | |||||||
| 	} | 	} | ||||||
|  |  | ||||||
| 	newStorage := func() storage.Storage { | 	newStorage := func() storage.Storage { | ||||||
| 		conn, err := p.open(logger) | 		conn, err := p.open(logger, p.createDataSourceName()) | ||||||
| 		if err != nil { | 		if err != nil { | ||||||
| 			fatal(err) | 			fatal(err) | ||||||
| 		} | 		} | ||||||
|   | |||||||
		Reference in New Issue
	
	Block a user