postgres: use connection string instead of url
otherwise it's impossible to use a Unix socket, as the path gets escaped awkwardly. Signed-off-by: Ciro S. Costa <cscosta@pivotal.io> Signed-off-by: Alex Suraci <suraci.alex@gmail.com>
This commit is contained in:
		
				
					committed by
					
						 Alex Suraci
						Alex Suraci
					
				
			
			
				
	
			
			
			
						parent
						
							2425c6ea63
						
					
				
				
					commit
					f82b904d05
				
			| @@ -3,8 +3,9 @@ package sql | ||||
| import ( | ||||
| 	"database/sql" | ||||
| 	"fmt" | ||||
| 	"net/url" | ||||
| 	"regexp" | ||||
| 	"strconv" | ||||
| 	"strings" | ||||
|  | ||||
| 	"github.com/lib/pq" | ||||
| 	sqlite3 "github.com/mattn/go-sqlite3" | ||||
| @@ -81,6 +82,7 @@ type Postgres struct { | ||||
| 	User     string | ||||
| 	Password string | ||||
| 	Host     string | ||||
| 	Port     uint16 | ||||
|  | ||||
| 	SSL PostgresSSL `json:"ssl" yaml:"ssl"` | ||||
|  | ||||
| @@ -89,45 +91,75 @@ type Postgres struct { | ||||
|  | ||||
| // Open creates a new storage implementation backed by Postgres. | ||||
| func (p *Postgres) Open(logger logrus.FieldLogger) (storage.Storage, error) { | ||||
| 	conn, err := p.open(logger) | ||||
| 	conn, err := p.open(logger, p.createDataSourceName()) | ||||
| 	if err != nil { | ||||
| 		return nil, err | ||||
| 	} | ||||
| 	return conn, nil | ||||
| } | ||||
|  | ||||
| func (p *Postgres) open(logger logrus.FieldLogger) (*conn, error) { | ||||
| 	v := url.Values{} | ||||
| 	set := func(key, val string) { | ||||
| 		if val != "" { | ||||
| 			v.Set(key, val) | ||||
| 		} | ||||
| 	} | ||||
| 	set("connect_timeout", strconv.Itoa(p.ConnectionTimeout)) | ||||
| 	set("sslkey", p.SSL.KeyFile) | ||||
| 	set("sslcert", p.SSL.CertFile) | ||||
| 	set("sslrootcert", p.SSL.CAFile) | ||||
| 	if p.SSL.Mode == "" { | ||||
| 		// Assume the strictest mode if unspecified. | ||||
| 		p.SSL.Mode = sslVerifyFull | ||||
| 	} | ||||
| 	set("sslmode", p.SSL.Mode) | ||||
| var strEsc = regexp.MustCompile(`([\\'])`) | ||||
|  | ||||
| 	u := url.URL{ | ||||
| 		Scheme:   "postgres", | ||||
| 		Host:     p.Host, | ||||
| 		Path:     "/" + p.Database, | ||||
| 		RawQuery: v.Encode(), | ||||
| func dataSourceStr(str string) string { | ||||
| 	return "'" + strEsc.ReplaceAllString(str, `\$1`) + "'" | ||||
| } | ||||
|  | ||||
| // createDataSourceName takes the configuration provided via the Postgres | ||||
| // struct to create a data-source name that Go's database/sql package can | ||||
| // make use of. | ||||
| func (p *Postgres) createDataSourceName() string { | ||||
| 	parameters := []string{} | ||||
|  | ||||
| 	addParam := func(key, val string) { | ||||
| 		parameters = append(parameters, fmt.Sprintf("%s=%s", key, val)) | ||||
| 	} | ||||
|  | ||||
| 	addParam("connect_timeout", strconv.Itoa(p.ConnectionTimeout)) | ||||
|  | ||||
| 	if p.Host != "" { | ||||
| 		addParam("host", dataSourceStr(p.Host)) | ||||
| 	} | ||||
|  | ||||
| 	if p.Port != 0 { | ||||
| 		addParam("port", strconv.Itoa(int(p.Port))) | ||||
| 	} | ||||
|  | ||||
| 	if p.User != "" { | ||||
| 		if p.Password != "" { | ||||
| 			u.User = url.UserPassword(p.User, p.Password) | ||||
| 		} else { | ||||
| 			u.User = url.User(p.User) | ||||
| 		} | ||||
| 		addParam("user", dataSourceStr(p.User)) | ||||
| 	} | ||||
| 	db, err := sql.Open("postgres", u.String()) | ||||
|  | ||||
| 	if p.Password != "" { | ||||
| 		addParam("password", dataSourceStr(p.Password)) | ||||
| 	} | ||||
|  | ||||
| 	if p.Database != "" { | ||||
| 		addParam("dbname", dataSourceStr(p.Database)) | ||||
| 	} | ||||
|  | ||||
| 	if p.SSL.Mode == "" { | ||||
| 		// Assume the strictest mode if unspecified. | ||||
| 		addParam("sslmode", dataSourceStr(sslVerifyFull)) | ||||
| 	} else { | ||||
| 		addParam("sslmode", dataSourceStr(p.SSL.Mode)) | ||||
| 	} | ||||
|  | ||||
| 	if p.SSL.CAFile != "" { | ||||
| 		addParam("sslrootcert", dataSourceStr(p.SSL.CAFile)) | ||||
| 	} | ||||
|  | ||||
| 	if p.SSL.CertFile != "" { | ||||
| 		addParam("sslcert", dataSourceStr(p.SSL.CertFile)) | ||||
| 	} | ||||
|  | ||||
| 	if p.SSL.KeyFile != "" { | ||||
| 		addParam("sslkey", dataSourceStr(p.SSL.KeyFile)) | ||||
| 	} | ||||
|  | ||||
| 	return strings.Join(parameters, " ") | ||||
| } | ||||
|  | ||||
| func (p *Postgres) open(logger logrus.FieldLogger, dataSourceName string) (*conn, error) { | ||||
| 	db, err := sql.Open("postgres", dataSourceName) | ||||
| 	if err != nil { | ||||
| 		return nil, err | ||||
| 	} | ||||
|   | ||||
| @@ -77,6 +77,103 @@ func getenv(key, defaultVal string) string { | ||||
|  | ||||
| const testPostgresEnv = "DEX_POSTGRES_HOST" | ||||
|  | ||||
| func TestCreateDataSourceName(t *testing.T) { | ||||
| 	var testCases = []struct { | ||||
| 		description string | ||||
| 		input       *Postgres | ||||
| 		expected    string | ||||
| 	}{ | ||||
| 		{ | ||||
| 			description: "with no configuration", | ||||
| 			input:       &Postgres{}, | ||||
| 			expected:    "connect_timeout=0 sslmode='verify-full'", | ||||
| 		}, | ||||
| 		{ | ||||
| 			description: "with typical configuration", | ||||
| 			input: &Postgres{ | ||||
| 				Host:     "1.2.3.4", | ||||
| 				Port:     6543, | ||||
| 				User:     "some-user", | ||||
| 				Password: "some-password", | ||||
| 				Database: "some-db", | ||||
| 			}, | ||||
| 			expected: "connect_timeout=0 host='1.2.3.4' port=6543 user='some-user' password='some-password' dbname='some-db' sslmode='verify-full'", | ||||
| 		}, | ||||
| 		{ | ||||
| 			description: "with unix socket host", | ||||
| 			input: &Postgres{ | ||||
| 				Host: "/var/run/postgres", | ||||
| 				SSL: PostgresSSL{ | ||||
| 					Mode: "disable", | ||||
| 				}, | ||||
| 			}, | ||||
| 			expected: "connect_timeout=0 host='/var/run/postgres' sslmode='disable'", | ||||
| 		}, | ||||
| 		{ | ||||
| 			description: "with tcp host", | ||||
| 			input: &Postgres{ | ||||
| 				Host: "coreos.com", | ||||
| 				SSL: PostgresSSL{ | ||||
| 					Mode: "disable", | ||||
| 				}, | ||||
| 			}, | ||||
| 			expected: "connect_timeout=0 host='coreos.com' sslmode='disable'", | ||||
| 		}, | ||||
| 		{ | ||||
| 			description: "with tcp host and port", | ||||
| 			input: &Postgres{ | ||||
| 				Host: "coreos.com", | ||||
| 				Port: 6543, | ||||
| 			}, | ||||
| 			expected: "connect_timeout=0 host='coreos.com' port=6543 sslmode='verify-full'", | ||||
| 		}, | ||||
| 		{ | ||||
| 			description: "with ssl ca cert", | ||||
| 			input: &Postgres{ | ||||
| 				Host: "coreos.com", | ||||
| 				SSL: PostgresSSL{ | ||||
| 					Mode:   "verify-ca", | ||||
| 					CAFile: "/some/file/path", | ||||
| 				}, | ||||
| 			}, | ||||
| 			expected: "connect_timeout=0 host='coreos.com' sslmode='verify-ca' sslrootcert='/some/file/path'", | ||||
| 		}, | ||||
| 		{ | ||||
| 			description: "with ssl client cert", | ||||
| 			input: &Postgres{ | ||||
| 				Host: "coreos.com", | ||||
| 				SSL: PostgresSSL{ | ||||
| 					Mode:     "verify-ca", | ||||
| 					CAFile:   "/some/ca/path", | ||||
| 					CertFile: "/some/cert/path", | ||||
| 					KeyFile:  "/some/key/path", | ||||
| 				}, | ||||
| 			}, | ||||
| 			expected: "connect_timeout=0 host='coreos.com' sslmode='verify-ca' sslrootcert='/some/ca/path' sslcert='/some/cert/path' sslkey='/some/key/path'", | ||||
| 		}, | ||||
| 		{ | ||||
| 			description: "with funny characters in credentials", | ||||
| 			input: &Postgres{ | ||||
| 				Host:     "coreos.com", | ||||
| 				User:     `some'user\slashed`, | ||||
| 				Password: "some'password!", | ||||
| 			}, | ||||
| 			expected: `connect_timeout=0 host='coreos.com' user='some\'user\\slashed' password='some\'password!' sslmode='verify-full'`, | ||||
| 		}, | ||||
| 	} | ||||
|  | ||||
| 	var actual string | ||||
| 	for _, testCase := range testCases { | ||||
| 		t.Run(testCase.description, func(t *testing.T) { | ||||
| 			actual = testCase.input.createDataSourceName() | ||||
|  | ||||
| 			if actual != testCase.expected { | ||||
| 				t.Fatalf("%s != %s", actual, testCase.expected) | ||||
| 			} | ||||
| 		}) | ||||
| 	} | ||||
| } | ||||
|  | ||||
| func TestPostgres(t *testing.T) { | ||||
| 	host := os.Getenv(testPostgresEnv) | ||||
| 	if host == "" { | ||||
| @@ -100,7 +197,7 @@ func TestPostgres(t *testing.T) { | ||||
| 	} | ||||
|  | ||||
| 	newStorage := func() storage.Storage { | ||||
| 		conn, err := p.open(logger) | ||||
| 		conn, err := p.open(logger, p.createDataSourceName()) | ||||
| 		if err != nil { | ||||
| 			fatal(err) | ||||
| 		} | ||||
|   | ||||
		Reference in New Issue
	
	Block a user