Skip to content

Commit

Permalink
PostgreSQL dialog URL parser (#346)
Browse files Browse the repository at this point in the history
* Fixes a parsing issue when PG connection string provided (works with Google Cloud SQL)

* Makes PostgresSQL dialog compatible with the lib/pg Connection String Parameters

* Renamed tests to follow naming convention
  • Loading branch information
tombiscan authored and stanislas-m committed Feb 17, 2019
1 parent 9aeb9ba commit 3e6cf4d
Show file tree
Hide file tree
Showing 2 changed files with 244 additions and 0 deletions.
163 changes: 163 additions & 0 deletions dialect_postgresql.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,13 +5,16 @@ import (
"fmt"
"io"
"os/exec"
"strings"
"sync"
"unicode"

"github.com/gobuffalo/fizz"
"github.com/gobuffalo/fizz/translators"
"github.com/gobuffalo/pop/columns"
"github.com/gobuffalo/pop/logging"
"github.com/jmoiron/sqlx"
pg "github.com/lib/pq"
"github.com/markbates/going/defaults"
"github.com/pkg/errors"
)
Expand All @@ -23,6 +26,7 @@ func init() {
AvailableDialects = append(AvailableDialects, namePostgreSQL)
dialectSynonyms["postgresql"] = namePostgreSQL
dialectSynonyms["pg"] = namePostgreSQL
urlParser[namePostgreSQL] = urlParserPostgreSQL
finalizer[namePostgreSQL] = finalizerPostgreSQL
newConnection[namePostgreSQL] = newPostgreSQL
}
Expand Down Expand Up @@ -208,6 +212,51 @@ func newPostgreSQL(deets *ConnectionDetails) (dialect, error) {
return cd, nil
}

// urlParserPostgreSQL parses the options the same way official lib/pg does:
// https://godoc.org/github.com/lib/pq#hdr-Connection_String_Parameters
// After parsed, they are set to ConnectionDetails instance
func urlParserPostgreSQL(cd *ConnectionDetails) error {
var err error
name := cd.URL
if strings.HasPrefix(name, "postgres://") || strings.HasPrefix(name, "postgresql://") {
name, err = pg.ParseURL(name)
if err != nil {
return err
}
}

o := make(values)
if err := parseOpts(name, o); err != nil {
return err
}

if dbname, ok := o["dbname"]; ok {
cd.Database = dbname
}
if host, ok := o["host"]; ok {
cd.Host = host
}
if password, ok := o["password"]; ok {
cd.Password = password
}
if user, ok := o["user"]; ok {
cd.User = user
}
if port, ok := o["port"]; ok {
cd.Port = port
}

options := []string{"sslmode", "fallback_application_name", "connect_timeout", "sslcert", "sslkey", "sslrootcert"}

for i := range options {
if opt, ok := o[options[i]]; ok {
cd.Options[options[i]] = opt
}
}

return nil
}

func finalizerPostgreSQL(cd *ConnectionDetails) {
cd.Options["sslmode"] = defaults.String(cd.Options["sslmode"], "disable")
cd.Port = defaults.String(cd.Port, portPostgreSQL)
Expand All @@ -230,3 +279,117 @@ BEGIN
END LOOP;
END
$func$;`

// Code below is ported from: https://github.com/lib/pq/blob/master/conn.go
type values map[string]string

// scanner implements a tokenizer for libpq-style option strings.
type scanner struct {
s []rune
i int
}

// newScanner returns a new scanner initialized with the option string s.
func newScanner(s string) *scanner {
return &scanner{[]rune(s), 0}
}

// Next returns the next rune.
// It returns 0, false if the end of the text has been reached.
func (s *scanner) Next() (rune, bool) {
if s.i >= len(s.s) {
return 0, false
}
r := s.s[s.i]
s.i++
return r, true
}

// SkipSpaces returns the next non-whitespace rune.
// It returns 0, false if the end of the text has been reached.
func (s *scanner) SkipSpaces() (rune, bool) {
r, ok := s.Next()
for unicode.IsSpace(r) && ok {
r, ok = s.Next()
}
return r, ok
}

// parseOpts parses the options from name and adds them to the values.
//
// The parsing code is based on conninfo_parse from libpq's fe-connect.c
func parseOpts(name string, o values) error {
s := newScanner(name)

for {
var (
keyRunes, valRunes []rune
r rune
ok bool
)

if r, ok = s.SkipSpaces(); !ok {
break
}

// Scan the key
for !unicode.IsSpace(r) && r != '=' {
keyRunes = append(keyRunes, r)
if r, ok = s.Next(); !ok {
break
}
}

// Skip any whitespace if we're not at the = yet
if r != '=' {
r, ok = s.SkipSpaces()
}

// The current character should be =
if r != '=' || !ok {
return fmt.Errorf(`missing "=" after %q in connection info string"`, string(keyRunes))
}

// Skip any whitespace after the =
if r, ok = s.SkipSpaces(); !ok {
// If we reach the end here, the last value is just an empty string as per libpq.
o[string(keyRunes)] = ""
break
}

if r != '\'' {
for !unicode.IsSpace(r) {
if r == '\\' {
if r, ok = s.Next(); !ok {
return fmt.Errorf(`missing character after backslash`)
}
}
valRunes = append(valRunes, r)

if r, ok = s.Next(); !ok {
break
}
}
} else {
quote:
for {
if r, ok = s.Next(); !ok {
return fmt.Errorf(`unterminated quoted string literal in connection string`)
}
switch r {
case '\'':
break quote
case '\\':
r, _ = s.Next()
fallthrough
default:
valRunes = append(valRunes, r)
}
}
}

o[string(keyRunes)] = string(valRunes)
}

return nil
}
81 changes: 81 additions & 0 deletions dialect_postgresql_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,81 @@
package pop

import (
"testing"

"github.com/stretchr/testify/require"
)

func Test_PostgreSQL_Connection_String(t *testing.T) {
r := require.New(t)

url := "host=host port=port dbname=database user=user password=pass"
cd := &ConnectionDetails{
Dialect: "postgres",
URL: url,
}
err := cd.Finalize()
r.NoError(err)

r.Equal(url, cd.URL)
r.Equal("postgres", cd.Dialect)
r.Equal("host", cd.Host)
r.Equal("pass", cd.Password)
r.Equal("port", cd.Port)
r.Equal("user", cd.User)
r.Equal("database", cd.Database)
}

func Test_PostgreSQL_Connection_String_Options(t *testing.T) {
r := require.New(t)

url := "host=host port=port dbname=database user=user password=pass sslmode=disable fallback_application_name=test_app connect_timeout=10 sslcert=/some/location sslkey=/some/other/location sslrootcert=/root/location"
cd := &ConnectionDetails{
Dialect: "postgres",
URL: url,
}
err := cd.Finalize()
r.NoError(err)

r.Equal(url, cd.URL)

r.Equal("disable", cd.Options["sslmode"])
r.Equal("test_app", cd.Options["fallback_application_name"])
r.Equal("10", cd.Options["connect_timeout"])
r.Equal("/some/location", cd.Options["sslcert"])
r.Equal("/some/other/location", cd.Options["sslkey"])
r.Equal("/root/location", cd.Options["sslrootcert"])
}

func Test_PostgreSQL_Connection_String_Without_User(t *testing.T) {
r := require.New(t)

url := "dbname=database"
cd := &ConnectionDetails{
Dialect: "postgres",
URL: url,
}
err := cd.Finalize()
r.NoError(err)

r.Equal(url, cd.URL)
r.Equal("postgres", cd.Dialect)
r.Equal("", cd.Host)
r.Equal("", cd.Password)
r.Equal(portPostgreSQL, cd.Port) // fallback
r.Equal("", cd.User)
r.Equal("database", cd.Database)
}

func Test_PostgreSQL_Connection_String_Failure(t *testing.T) {
r := require.New(t)

url := "abc"
cd := &ConnectionDetails{
Dialect: "postgres",
URL: url,
}
err := cd.Finalize()
r.Error(err)
r.Equal("postgres", cd.Dialect)
}

0 comments on commit 3e6cf4d

Please sign in to comment.