Skip to content

Commit

Permalink
Add a test for a simple UDF
Browse files Browse the repository at this point in the history
  • Loading branch information
JAicewizard committed May 4, 2024
1 parent 8e5d169 commit d419d43
Show file tree
Hide file tree
Showing 2 changed files with 128 additions and 3 deletions.
5 changes: 2 additions & 3 deletions examples/udf/udf.go
Original file line number Diff line number Diff line change
Expand Up @@ -40,17 +40,16 @@ func (d *tableUDF) BindArguments(args ...interface{}) []duckdb.ColumnName {
}
}

func (d *tableUDF) FillRow(row duckdb.Row) bool{
func (d *tableUDF) FillRow(row duckdb.Row) bool {
fmt.Println(d.count, d.n)
if d.count > d.n {
return false
}
d.count++
duckdb.SetRowValue[int64](row, 0, d.count)
duckdb.SetRowValue[int64](row, 0, d.count)
return true
}


func main() {
var err error
db, err = sql.Open("duckdb", "?access_mode=READ_WRITE")
Expand Down
126 changes: 126 additions & 0 deletions udf_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,126 @@
package duckdb

import (
"context"
"database/sql"
"fmt"
"reflect"
"testing"
)

type wrongValueError struct {
rowIdx int
colIdx int
colName string
expected any
got any
}

type tableUDF struct {
n int64
count int64
}

func (wve wrongValueError) Error() string {
return fmt.Sprintf("Wrong value at row %d, column %d(%s): Expected %v of type %[4]T, found %v of type %[5]T", wve.rowIdx, wve.colIdx, wve.colName, wve.expected, wve.got)
}

func (d *tableUDF) GetArguments() []interface{} {
return []interface{}{
int64(0),
}
}

func (d *tableUDF) BindArguments(args ...interface{}) []ColumnName {
d.count = 0
d.n = args[0].(int64)
return []ColumnName{
{Name: "result", V: int64(0)},
}
}

func (d *tableUDF) FillRow(row Row) bool {
if d.count > d.n {
return false
}
d.count++
SetRowValue[int64](row, 0, d.count)
return true
}

func (d *tableUDF) GetValue(r, c int) any {
return int64(r + 1)
}

func BenchmarkTableUDF(b *testing.B) {
b.StopTimer()
var err error
db, err := sql.Open("duckdb", "?access_mode=READ_WRITE")
if err != nil {
b.Fatal(err)
}
defer db.Close()
conn, _ := db.Conn(context.Background())
var fun tableUDF
RegisterTableUDF(conn, &fun)
b.StartTimer()
for n := 0; n < b.N; n++ {
rows, err := db.QueryContext(context.Background(), "SELECT * FROM whoo(2048)")
if err != nil {
b.Fatal(err)
}
defer rows.Close()
}
}

func TestTableUDF(t *testing.T) {
var err error
db, err := sql.Open("duckdb", "?access_mode=READ_WRITE")
if err != nil {
t.Fatal(err)
}
defer db.Close()
conn, _ := db.Conn(context.Background())
var fun tableUDF
RegisterTableUDF(conn, &fun)
rows, err := db.QueryContext(context.Background(), "SELECT * FROM whoo(2048)")
if err != nil {
t.Fatal(err)
}

//TODO: check column names
columns, err := rows.Columns()
if err != nil {
t.Fatal(err)
}

values := make([]interface{}, len(columns))
scanArgs := make([]interface{}, len(values))
for i := range values {
scanArgs[i] = &values[i]
}

// Fetch rows
var r int
for rows.Next() {
err = rows.Scan(scanArgs...)
if err != nil {
panic(err.Error())
}
for i, value := range values {
expected := fun.GetValue(r, i)
if !reflect.DeepEqual(expected, value) {
err := wrongValueError{
rowIdx: r,
colIdx: i,
colName: columns[i],
expected: expected,
got: value,
}
t.Log(err)
t.Fail()
}
}
r++
}
}

0 comments on commit d419d43

Please sign in to comment.