diff --git a/drivers/cassandra/cassandra.go b/drivers/cassandra/cassandra.go index 124cbebb5bc..aa80c04cb65 100644 --- a/drivers/cassandra/cassandra.go +++ b/drivers/cassandra/cassandra.go @@ -20,41 +20,6 @@ import ( "github.com/xo/usql/drivers" ) -// logger is a null logger that satisfies the gocql.StdLogger and the io.Writer -// interfaces in order to capture the last error issued by the cql/gocql -// packages, since the cql package does not (at this time) return any error -// other than sql.ErrBadConn. -type logger struct { - debug bool - last string -} - -func (l *logger) Print(v ...interface{}) { - if l.debug { - log.Print(v...) - } -} - -func (l *logger) Printf(s string, v ...interface{}) { - if l.debug { - log.Printf(s, v...) - } -} - -func (l *logger) Println(v ...interface{}) { - if l.debug { - log.Println(v...) - } -} - -func (l *logger) Write(buf []byte) (int, error) { - if l.debug { - log.Printf("WRITE: %s", string(buf)) - } - l.last = string(buf) - return len(buf), nil -} - func init() { var debug bool if s := os.Getenv("CQL_DEBUG"); s != "" { @@ -76,7 +41,7 @@ func init() { u.RawQuery = q.Encode() } }, - Open: func(u *dburl.URL, stdout, stderr func() io.Writer) (func(string, string) (*sql.DB, error), error) { + Open: func(_ context.Context, u *dburl.URL, stdout, stderr func() io.Writer) (func(string, string) (*sql.DB, error), error) { // override cql and gocql loggers l = &logger{debug: debug} gocql.Logger, cql.CqlDriver.Logger = l, log.New(l, "", 0) @@ -124,3 +89,38 @@ func init() { }, }) } + +// logger is a null logger that satisfies the gocql.StdLogger and the io.Writer +// interfaces in order to capture the last error issued by the cql/gocql +// packages, since the cql package does not (at this time) return any error +// other than sql.ErrBadConn. +type logger struct { + debug bool + last string +} + +func (l *logger) Print(v ...interface{}) { + if l.debug { + log.Print(v...) + } +} + +func (l *logger) Printf(s string, v ...interface{}) { + if l.debug { + log.Printf(s, v...) + } +} + +func (l *logger) Println(v ...interface{}) { + if l.debug { + log.Println(v...) + } +} + +func (l *logger) Write(buf []byte) (int, error) { + if l.debug { + log.Printf("WRITE: %s", string(buf)) + } + l.last = string(buf) + return len(buf), nil +} diff --git a/drivers/drivers.go b/drivers/drivers.go index 351891ebae0..d45f6dbb413 100644 --- a/drivers/drivers.go +++ b/drivers/drivers.go @@ -64,7 +64,7 @@ type Driver struct { // ForceParams will be used to force parameters if defined. ForceParams func(*dburl.URL) // Open will be used by Open if defined. - Open func(*dburl.URL, func() io.Writer, func() io.Writer) (func(string, string) (*sql.DB, error), error) + Open func(context.Context, *dburl.URL, func() io.Writer, func() io.Writer) (func(string, string) (*sql.DB, error), error) // Version will be used by Version if defined. Version func(context.Context, DB) (string, error) // User will be used by User if defined. @@ -160,7 +160,7 @@ func ForceParams(u *dburl.URL) { } // Open opens a sql.DB connection for a driver. -func Open(u *dburl.URL, stdout, stderr func() io.Writer) (*sql.DB, error) { +func Open(ctx context.Context, u *dburl.URL, stdout, stderr func() io.Writer) (*sql.DB, error) { d, ok := drivers[u.Driver] if !ok { return nil, WrapErr(u.Driver, text.ErrDriverNotAvailable) @@ -168,7 +168,7 @@ func Open(u *dburl.URL, stdout, stderr func() io.Writer) (*sql.DB, error) { f := sql.Open if d.Open != nil { var err error - if f, err = d.Open(u, stdout, stderr); err != nil { + if f, err = d.Open(ctx, u, stdout, stderr); err != nil { return nil, WrapErr(u.Driver, err) } } @@ -506,7 +506,7 @@ func Copy(ctx context.Context, u *dburl.URL, stdout, stderr func() io.Writer, ro if d.Copy == nil { return 0, fmt.Errorf(text.NotSupportedByDriver, "copy", u.Driver) } - db, err := Open(u, stdout, stderr) + db, err := Open(ctx, u, stdout, stderr) if err != nil { return 0, err } diff --git a/drivers/moderncsqlite/moderncsqlite.go b/drivers/moderncsqlite/moderncsqlite.go index 484b3a1e3c1..1a010e62c6e 100644 --- a/drivers/moderncsqlite/moderncsqlite.go +++ b/drivers/moderncsqlite/moderncsqlite.go @@ -19,7 +19,7 @@ import ( func init() { drivers.Register("moderncsqlite", drivers.Driver{ AllowMultilineComments: true, - Open: func(u *dburl.URL, stdout, stderr func() io.Writer) (func(string, string) (*sql.DB, error), error) { + Open: func(_ context.Context, u *dburl.URL, stdout, stderr func() io.Writer) (func(string, string) (*sql.DB, error), error) { return func(_ string, params string) (*sql.DB, error) { return sql.Open("sqlite", params) }, nil diff --git a/drivers/postgres/postgres.go b/drivers/postgres/postgres.go index 9d28890a9a3..31118b23b53 100644 --- a/drivers/postgres/postgres.go +++ b/drivers/postgres/postgres.go @@ -10,6 +10,7 @@ package postgres import ( "context" "database/sql" + "errors" "fmt" "io" "strings" @@ -19,10 +20,32 @@ import ( "github.com/xo/usql/drivers" "github.com/xo/usql/drivers/metadata" pgmeta "github.com/xo/usql/drivers/metadata/postgres" + "github.com/xo/usql/env" "github.com/xo/usql/text" ) func init() { + openConn := func(stdout, stderr func() io.Writer, dsn string) (*sql.DB, error) { + conn, err := pq.NewConnector(dsn) + if err != nil { + return nil, err + } + noticeConn := pq.ConnectorWithNoticeHandler(conn, func(notice *pq.Error) { + out := stderr() + fmt.Fprintln(out, notice.Severity+": ", notice.Message) + if notice.Hint != "" { + fmt.Fprintln(out, "HINT: ", notice.Hint) + } + }) + notificationConn := pq.ConnectorWithNotificationHandler(noticeConn, func(notification *pq.Notification) { + var payload string + if notification.Extra != "" { + payload = fmt.Sprintf(text.NotificationPayload, notification.Extra) + } + fmt.Fprintln(stdout(), fmt.Sprintf(text.NotificationReceived, notification.Channel, payload, notification.BePid)) + }) + return sql.OpenDB(notificationConn), nil + } drivers.Register("postgres", drivers.Driver{ Name: "pq", AllowDollar: true, @@ -33,27 +56,27 @@ func init() { drivers.ForceQueryParameters([]string{"sslmode", "disable"})(u) } }, - Open: func(u *dburl.URL, stdout, stderr func() io.Writer) (func(string, string) (*sql.DB, error), error) { - return func(typ, dsn string) (*sql.DB, error) { - conn, err := pq.NewConnector(dsn) + Open: func(ctx context.Context, u *dburl.URL, stdout, stderr func() io.Writer) (func(string, string) (*sql.DB, error), error) { + return func(_, dsn string) (*sql.DB, error) { + conn, err := openConn(stdout, stderr, dsn) if err != nil { return nil, err } - noticeConn := pq.ConnectorWithNoticeHandler(conn, func(notice *pq.Error) { - out := stderr() - fmt.Fprintln(out, notice.Severity+": ", notice.Message) - if notice.Hint != "" { - fmt.Fprintln(out, "HINT: ", notice.Hint) + // special retry handling case, since there's no lib/pq retry mode + if env.Get("SSLMODE") == "retry" && !u.Query().Has("sslmode") { + switch err = conn.PingContext(ctx); { + case errors.Is(err, pq.ErrSSLNotSupported): + s := "sslmode=disable " + dsn + conn, err = openConn(stdout, stderr, s) + if err != nil { + return nil, err + } + u.DSN = s + case err != nil: + return nil, err } - }) - notificationConn := pq.ConnectorWithNotificationHandler(noticeConn, func(notification *pq.Notification) { - var payload string - if notification.Extra != "" { - payload = fmt.Sprintf(text.NotificationPayload, notification.Extra) - } - fmt.Fprintln(stdout(), fmt.Sprintf(text.NotificationReceived, notification.Channel, payload, notification.BePid)) - }) - return sql.OpenDB(notificationConn), nil + } + return conn, nil }, nil }, Version: func(ctx context.Context, db drivers.DB) (string, error) { diff --git a/env/types.go b/env/types.go index 11ec04603bd..938d58a0cfe 100644 --- a/env/types.go +++ b/env/types.go @@ -164,6 +164,10 @@ var envVarNames = []varName{ text.CommandUpper() + "RC", "alternative location for the user's .usqlrc file", }, + { + text.CommandUpper() + "_SSLMODE, SSLMODE", + "when set to 'retry', allows postgres connections to attempt to reconnect when no ?sslmode= was specified on the url", + }, { "SYNTAX_HL", "enable syntax highlighting", @@ -207,9 +211,10 @@ func (v Vars) All() map[string]string { var vars, pvars Vars func init() { + cmdNameUpper := strings.ToUpper(text.CommandName) // get USQL_* variables enableHostInformation := "true" - if v, _ := Getenv(strings.ToUpper(text.CommandName) + "_SHOW_HOST_INFORMATION"); v != "" { + if v, _ := Getenv(cmdNameUpper + "_SHOW_HOST_INFORMATION"); v != "" { enableHostInformation = v } // get color level @@ -219,7 +224,7 @@ func init() { enableSyntaxHL = "false" } // pager - pagerCmd, ok := Getenv(strings.ToUpper(text.CommandName)+"_PAGER", "PAGER") + pagerCmd, ok := Getenv(cmdNameUpper+"_PAGER", "PAGER") pager := "off" if !ok { for _, s := range []string{"less", "more"} { @@ -233,7 +238,12 @@ func init() { pager = "on" } // editor - editorCmd, _ := Getenv(strings.ToUpper(text.CommandName)+"_EDITOR", "EDITOR", "VISUAL") + editorCmd, _ := Getenv(cmdNameUpper+"_EDITOR", "EDITOR", "VISUAL") + // sslmode + sslmode, ok := Getenv(cmdNameUpper+"_SSLMODE", "SSLMODE") + if !ok { + sslmode = "retry" + } vars = Vars{ // usql related logic "SHOW_HOST_INFORMATION": enableHostInformation, @@ -247,6 +257,7 @@ func init() { "SYNTAX_HL_FORMAT": colorLevel.ChromaFormatterName(), "SYNTAX_HL_STYLE": "monokai", "SYNTAX_HL_OVERRIDE_BG": "true", + "SSLMODE": sslmode, } // determine locale locale := "en-US" diff --git a/handler/handler.go b/handler/handler.go index 6f173f43253..24a86a388ce 100644 --- a/handler/handler.go +++ b/handler/handler.go @@ -744,7 +744,7 @@ func (h *Handler) Open(ctx context.Context, params ...string) error { } // open connection var err error - h.db, err = drivers.Open(h.u, h.GetOutput, h.IO().Stderr) + h.db, err = drivers.Open(ctx, h.u, h.GetOutput, h.IO().Stderr) if err != nil && !drivers.IsPasswordErr(h.u, err) { defer h.Close() return err diff --git a/metacmd/cmds.go b/metacmd/cmds.go index 57cc381b7e8..67ac99319c2 100644 --- a/metacmd/cmds.go +++ b/metacmd/cmds.go @@ -833,6 +833,7 @@ func init() { "copy": {"copy query from source url to columns of table on destination url", "SRC DST QUERY TABLE(A,...)"}, }, Process: func(p *Params) error { + ctx := context.Background() stdout, stderr := p.Handler.IO().Stdout, p.Handler.IO().Stderr srcDsn, err := p.Get(true) if err != nil { @@ -858,17 +859,17 @@ func init() { if err != nil { return err } - src, err := drivers.Open(srcURL, stdout, stderr) + src, err := drivers.Open(ctx, srcURL, stdout, stderr) if err != nil { return err } defer src.Close() - dest, err := drivers.Open(destURL, stdout, stderr) + dest, err := drivers.Open(ctx, destURL, stdout, stderr) if err != nil { return err } defer dest.Close() - ctx, cancel := signal.NotifyContext(context.Background(), os.Interrupt) + ctx, cancel := signal.NotifyContext(ctx, os.Interrupt) defer cancel() // get the result set r, err := src.QueryContext(ctx, query)