diff --git a/dialect_postgresql.go b/dialect_postgresql.go index 7f657206..6e2c61eb 100644 --- a/dialect_postgresql.go +++ b/dialect_postgresql.go @@ -5,13 +5,16 @@ import ( "fmt" "io" "os/exec" + "strings" "sync" + "unicode" "github.com/gobuffalo/fizz" "github.com/gobuffalo/fizz/translators" "github.com/gobuffalo/pop/columns" "github.com/gobuffalo/pop/logging" "github.com/jmoiron/sqlx" + pg "github.com/lib/pq" "github.com/markbates/going/defaults" "github.com/pkg/errors" ) @@ -23,6 +26,7 @@ func init() { AvailableDialects = append(AvailableDialects, namePostgreSQL) dialectSynonyms["postgresql"] = namePostgreSQL dialectSynonyms["pg"] = namePostgreSQL + urlParser[namePostgreSQL] = urlParserPostgreSQL finalizer[namePostgreSQL] = finalizerPostgreSQL newConnection[namePostgreSQL] = newPostgreSQL } @@ -208,6 +212,51 @@ func newPostgreSQL(deets *ConnectionDetails) (dialect, error) { return cd, nil } +// urlParserPostgreSQL parses the options the same way official lib/pg does: +// https://godoc.org/github.com/lib/pq#hdr-Connection_String_Parameters +// After parsed, they are set to ConnectionDetails instance +func urlParserPostgreSQL(cd *ConnectionDetails) error { + var err error + name := cd.URL + if strings.HasPrefix(name, "postgres://") || strings.HasPrefix(name, "postgresql://") { + name, err = pg.ParseURL(name) + if err != nil { + return err + } + } + + o := make(values) + if err := parseOpts(name, o); err != nil { + return err + } + + if dbname, ok := o["dbname"]; ok { + cd.Database = dbname + } + if host, ok := o["host"]; ok { + cd.Host = host + } + if password, ok := o["password"]; ok { + cd.Password = password + } + if user, ok := o["user"]; ok { + cd.User = user + } + if port, ok := o["port"]; ok { + cd.Port = port + } + + options := []string{"sslmode", "fallback_application_name", "connect_timeout", "sslcert", "sslkey", "sslrootcert"} + + for i := range options { + if opt, ok := o[options[i]]; ok { + cd.Options[options[i]] = opt + } + } + + return nil +} + func finalizerPostgreSQL(cd *ConnectionDetails) { cd.Options["sslmode"] = defaults.String(cd.Options["sslmode"], "disable") cd.Port = defaults.String(cd.Port, portPostgreSQL) @@ -230,3 +279,117 @@ BEGIN END LOOP; END $func$;` + +// Code below is ported from: https://github.com/lib/pq/blob/master/conn.go +type values map[string]string + +// scanner implements a tokenizer for libpq-style option strings. +type scanner struct { + s []rune + i int +} + +// newScanner returns a new scanner initialized with the option string s. +func newScanner(s string) *scanner { + return &scanner{[]rune(s), 0} +} + +// Next returns the next rune. +// It returns 0, false if the end of the text has been reached. +func (s *scanner) Next() (rune, bool) { + if s.i >= len(s.s) { + return 0, false + } + r := s.s[s.i] + s.i++ + return r, true +} + +// SkipSpaces returns the next non-whitespace rune. +// It returns 0, false if the end of the text has been reached. +func (s *scanner) SkipSpaces() (rune, bool) { + r, ok := s.Next() + for unicode.IsSpace(r) && ok { + r, ok = s.Next() + } + return r, ok +} + +// parseOpts parses the options from name and adds them to the values. +// +// The parsing code is based on conninfo_parse from libpq's fe-connect.c +func parseOpts(name string, o values) error { + s := newScanner(name) + + for { + var ( + keyRunes, valRunes []rune + r rune + ok bool + ) + + if r, ok = s.SkipSpaces(); !ok { + break + } + + // Scan the key + for !unicode.IsSpace(r) && r != '=' { + keyRunes = append(keyRunes, r) + if r, ok = s.Next(); !ok { + break + } + } + + // Skip any whitespace if we're not at the = yet + if r != '=' { + r, ok = s.SkipSpaces() + } + + // The current character should be = + if r != '=' || !ok { + return fmt.Errorf(`missing "=" after %q in connection info string"`, string(keyRunes)) + } + + // Skip any whitespace after the = + if r, ok = s.SkipSpaces(); !ok { + // If we reach the end here, the last value is just an empty string as per libpq. + o[string(keyRunes)] = "" + break + } + + if r != '\'' { + for !unicode.IsSpace(r) { + if r == '\\' { + if r, ok = s.Next(); !ok { + return fmt.Errorf(`missing character after backslash`) + } + } + valRunes = append(valRunes, r) + + if r, ok = s.Next(); !ok { + break + } + } + } else { + quote: + for { + if r, ok = s.Next(); !ok { + return fmt.Errorf(`unterminated quoted string literal in connection string`) + } + switch r { + case '\'': + break quote + case '\\': + r, _ = s.Next() + fallthrough + default: + valRunes = append(valRunes, r) + } + } + } + + o[string(keyRunes)] = string(valRunes) + } + + return nil +} diff --git a/dialect_postgresql_test.go b/dialect_postgresql_test.go new file mode 100644 index 00000000..643b0b21 --- /dev/null +++ b/dialect_postgresql_test.go @@ -0,0 +1,81 @@ +package pop + +import ( + "testing" + + "github.com/stretchr/testify/require" +) + +func Test_PostgreSQL_Connection_String(t *testing.T) { + r := require.New(t) + + url := "host=host port=port dbname=database user=user password=pass" + cd := &ConnectionDetails{ + Dialect: "postgres", + URL: url, + } + err := cd.Finalize() + r.NoError(err) + + r.Equal(url, cd.URL) + r.Equal("postgres", cd.Dialect) + r.Equal("host", cd.Host) + r.Equal("pass", cd.Password) + r.Equal("port", cd.Port) + r.Equal("user", cd.User) + r.Equal("database", cd.Database) +} + +func Test_PostgreSQL_Connection_String_Options(t *testing.T) { + r := require.New(t) + + url := "host=host port=port dbname=database user=user password=pass sslmode=disable fallback_application_name=test_app connect_timeout=10 sslcert=/some/location sslkey=/some/other/location sslrootcert=/root/location" + cd := &ConnectionDetails{ + Dialect: "postgres", + URL: url, + } + err := cd.Finalize() + r.NoError(err) + + r.Equal(url, cd.URL) + + r.Equal("disable", cd.Options["sslmode"]) + r.Equal("test_app", cd.Options["fallback_application_name"]) + r.Equal("10", cd.Options["connect_timeout"]) + r.Equal("/some/location", cd.Options["sslcert"]) + r.Equal("/some/other/location", cd.Options["sslkey"]) + r.Equal("/root/location", cd.Options["sslrootcert"]) +} + +func Test_PostgreSQL_Connection_String_Without_User(t *testing.T) { + r := require.New(t) + + url := "dbname=database" + cd := &ConnectionDetails{ + Dialect: "postgres", + URL: url, + } + err := cd.Finalize() + r.NoError(err) + + r.Equal(url, cd.URL) + r.Equal("postgres", cd.Dialect) + r.Equal("", cd.Host) + r.Equal("", cd.Password) + r.Equal(portPostgreSQL, cd.Port) // fallback + r.Equal("", cd.User) + r.Equal("database", cd.Database) +} + +func Test_PostgreSQL_Connection_String_Failure(t *testing.T) { + r := require.New(t) + + url := "abc" + cd := &ConnectionDetails{ + Dialect: "postgres", + URL: url, + } + err := cd.Finalize() + r.Error(err) + r.Equal("postgres", cd.Dialect) +}