Skip to content

Commit

Permalink
Add support for named parameters (#227)
Browse files Browse the repository at this point in the history
* Add support for named parameters

* Re-build static libraries

* consistent error check

* fix pr comments

---------

Co-authored-by: ajzo90 <[email protected]>
  • Loading branch information
ajzo90 and ajzo90 authored Jun 6, 2024
1 parent 2ac253b commit 0d97dd3
Show file tree
Hide file tree
Showing 4 changed files with 44 additions and 3 deletions.
Binary file modified deps/darwin_amd64/libduckdb.a
Binary file not shown.
Binary file modified deps/darwin_arm64/libduckdb.a
Binary file not shown.
28 changes: 25 additions & 3 deletions statement.go
Original file line number Diff line number Diff line change
Expand Up @@ -46,14 +46,36 @@ func (s *stmt) NumInput() int {
}

func (s *stmt) bind(args []driver.NamedValue) error {
if s.NumInput() != len(args) {
if s.NumInput() > len(args) {
return fmt.Errorf("incorrect argument count for command: have %d want %d", len(args), s.NumInput())
}

// FIXME (feature): we can't pass nested types as parameters (bind_value) yet

for i, v := range args {
switch v := v.Value.(type) {
// relaxed length check allow for unused parameters.
for i := 0; i < s.NumInput(); i++ {
name := C.duckdb_parameter_name(*s.stmt, C.idx_t(i+1))
paramName := C.GoString(name)
C.duckdb_free(unsafe.Pointer(name))

// fallback on index position
var arg = args[i]

// override with ordinal if set
for _, v := range args {
if v.Ordinal == i+1 {
arg = v
}
}

// override with name if set
for _, v := range args {
if v.Name == paramName {
arg = v
}
}

switch v := arg.Value.(type) {
case bool:
if rv := C.duckdb_bind_boolean(*s.stmt, C.idx_t(i+1), C.bool(v)); rv == C.DuckDBError {
return errCouldNotBind
Expand Down
19 changes: 19 additions & 0 deletions statement_test.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
package duckdb

import (
"context"
"database/sql"
"testing"

"github.com/stretchr/testify/require"
Expand All @@ -20,6 +22,23 @@ func TestPrepareQuery(t *testing.T) {
defer rows.Close()
}

func TestPrepareQueryNamed(t *testing.T) {
db := openDB(t)
defer db.Close()
createFooTable(db, t)

stmt, err := db.PrepareContext(context.Background(), "SELECT $foo, $bar, $baz, $foo")
require.NoError(t, err)
defer stmt.Close()
var foo, bar, foo2 int
var baz string
err = stmt.QueryRow(sql.Named("baz", "x"), sql.Named("foo", 1), sql.Named("bar", 2)).Scan(&foo, &bar, &baz, &foo2)
require.NoError(t, err)
if foo != 1 || bar != 2 || baz != "x" || foo2 != 1 {
require.Fail(t, "bad values: %d %d %s %d", foo, bar, baz, foo2)
}
}

func TestPrepareWithError(t *testing.T) {
db := openDB(t)
defer db.Close()
Expand Down

0 comments on commit 0d97dd3

Please sign in to comment.