Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

all: switch internal API's to use driver.NamedValue instead of driver.Value #1068

Open
wants to merge 1 commit into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
45 changes: 23 additions & 22 deletions conn.go
Original file line number Diff line number Diff line change
Expand Up @@ -867,12 +867,20 @@ func (cn *conn) Close() (err error) {
return cn.sendSimpleMessage('X')
}

func toNamedValue(v []driver.Value) []driver.NamedValue {
v2 := make([]driver.NamedValue, len(v))
for i := range v {
v2[i] = driver.NamedValue{Ordinal: i + 1, Value: v[i]}
}
return v2
}

// Implement the "Queryer" interface
func (cn *conn) Query(query string, args []driver.Value) (driver.Rows, error) {
return cn.query(query, args)
return cn.query(query, toNamedValue(args))
}

func (cn *conn) query(query string, args []driver.Value) (_ *rows, err error) {
func (cn *conn) query(query string, args []driver.NamedValue) (_ *rows, err error) {
if err := cn.err.get(); err != nil {
return nil, err
}
Expand Down Expand Up @@ -921,7 +929,7 @@ func (cn *conn) Exec(query string, args []driver.Value) (res driver.Result, err
}

if cn.binaryParameters {
cn.sendBinaryModeQuery(query, args)
cn.sendBinaryModeQuery(query, toNamedValue(args))

cn.readParseResponse()
cn.readBindResponse()
Expand Down Expand Up @@ -1379,10 +1387,10 @@ func (st *stmt) Close() (err error) {
}

func (st *stmt) Query(v []driver.Value) (r driver.Rows, err error) {
return st.query(v)
return st.query(toNamedValue(v))
}

func (st *stmt) query(v []driver.Value) (r *rows, err error) {
func (st *stmt) query(v []driver.NamedValue) (r *rows, err error) {
if err := st.cn.err.get(); err != nil {
return nil, err
}
Expand All @@ -1395,18 +1403,11 @@ func (st *stmt) query(v []driver.Value) (r *rows, err error) {
}, nil
}

func (st *stmt) Exec(v []driver.Value) (res driver.Result, err error) {
if err := st.cn.err.get(); err != nil {
return nil, err
}
defer st.cn.errRecover(&err)

st.exec(v)
res, _, err = st.cn.readExecuteResponse("simple query")
return res, err
func (st *stmt) Exec(v []driver.Value) (driver.Result, error) {
return st.ExecContext(context.Background(), toNamedValue(v))
}

func (st *stmt) exec(v []driver.Value) {
func (st *stmt) exec(v []driver.NamedValue) {
if len(v) >= 65536 {
errorf("got %d parameters but PostgreSQL only supports 65535 parameters", len(v))
}
Expand All @@ -1425,10 +1426,10 @@ func (st *stmt) exec(v []driver.Value) {
w.int16(0)
w.int16(len(v))
for i, x := range v {
if x == nil {
if x.Value == nil {
w.int32(-1)
} else {
b := encode(&cn.parameterStatus, x, st.paramTyps[i])
b := encode(&cn.parameterStatus, x.Value, st.paramTyps[i])
w.int32(len(b))
w.bytes(b)
}
Expand Down Expand Up @@ -1684,13 +1685,13 @@ func md5s(s string) string {
return fmt.Sprintf("%x", h.Sum(nil))
}

func (cn *conn) sendBinaryParameters(b *writeBuf, args []driver.Value) {
func (cn *conn) sendBinaryParameters(b *writeBuf, args []driver.NamedValue) {
// Do one pass over the parameters to see if we're going to send any of
// them over in binary. If we are, create a paramFormats array at the
// same time.
var paramFormats []int
for i, x := range args {
_, ok := x.([]byte)
_, ok := x.Value.([]byte)
if ok {
if paramFormats == nil {
paramFormats = make([]int, len(args))
Expand All @@ -1709,17 +1710,17 @@ func (cn *conn) sendBinaryParameters(b *writeBuf, args []driver.Value) {

b.int16(len(args))
for _, x := range args {
if x == nil {
if x.Value == nil {
b.int32(-1)
} else {
datum := binaryEncode(&cn.parameterStatus, x)
datum := binaryEncode(&cn.parameterStatus, x.Value)
b.int32(len(datum))
b.bytes(datum)
}
}
}

func (cn *conn) sendBinaryModeQuery(query string, args []driver.Value) {
func (cn *conn) sendBinaryModeQuery(query string, args []driver.NamedValue) {
if len(args) >= 65536 {
errorf("got %d parameters but PostgreSQL only supports 65535 parameters", len(args))
}
Expand Down
28 changes: 11 additions & 17 deletions conn_go18.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,12 +16,8 @@ const (

// Implement the "QueryerContext" interface
func (cn *conn) QueryContext(ctx context.Context, query string, args []driver.NamedValue) (driver.Rows, error) {
list := make([]driver.Value, len(args))
for i, nv := range args {
list[i] = nv.Value
}
finish := cn.watchCancel(ctx)
r, err := cn.query(query, list)
r, err := cn.query(query, args)
if err != nil {
if finish != nil {
finish()
Expand Down Expand Up @@ -183,12 +179,8 @@ func (cn *conn) cancel(ctx context.Context) error {

// Implement the "StmtQueryContext" interface
func (st *stmt) QueryContext(ctx context.Context, args []driver.NamedValue) (driver.Rows, error) {
list := make([]driver.Value, len(args))
for i, nv := range args {
list[i] = nv.Value
}
finish := st.watchCancel(ctx)
r, err := st.query(list)
r, err := st.query(args)
if err != nil {
if finish != nil {
finish()
Expand All @@ -200,17 +192,19 @@ func (st *stmt) QueryContext(ctx context.Context, args []driver.NamedValue) (dri
}

// Implement the "StmtExecContext" interface
func (st *stmt) ExecContext(ctx context.Context, args []driver.NamedValue) (driver.Result, error) {
list := make([]driver.Value, len(args))
for i, nv := range args {
list[i] = nv.Value
}

func (st *stmt) ExecContext(ctx context.Context, args []driver.NamedValue) (res driver.Result, err error) {
if finish := st.watchCancel(ctx); finish != nil {
defer finish()
}

return st.Exec(list)
if err := st.cn.err.get(); err != nil {
return nil, err
}
defer st.cn.errRecover(&err)

st.exec(args)
res, _, err = st.cn.readExecuteResponse("simple query")
return res, err
}

// watchCancel is implemented on stmt in order to not mark the parent conn as bad
Expand Down