Skip to content

Commit

Permalink
Replacement scan (#208)
Browse files Browse the repository at this point in the history
* Initial commit for replacement scan

* lintfix

* use singe struct from replacement state

* consistent error check in test

* remove error return for RegisterReplacementScan

* update test for replacement scan on range

* remove lock and fix dealloc for duckdb_value

* consistent error reporting

* store replacement scan function inside duckdb instance instead of global variable

* fix gco
  • Loading branch information
ajzo90 authored Jun 10, 2024
1 parent 0b53b1e commit 9eca001
Show file tree
Hide file tree
Showing 2 changed files with 108 additions and 0 deletions.
64 changes: 64 additions & 0 deletions replacement_scan.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,64 @@
package duckdb

/*
#include <stdlib.h>
#include <duckdb.h>
void replacement_scan_cb(duckdb_replacement_scan_info info, const char *table_name, void *data);
typedef const char cchar_t;
void replacement_scan_destroy_data(void *);
*/
import "C"
import (
"runtime/cgo"
"unsafe"
)

type ReplacementScanCallback func(tableName string) (string, []any, error)

func RegisterReplacementScan(connector *Connector, cb ReplacementScanCallback) {
handle := cgo.NewHandle(cb)
C.duckdb_add_replacement_scan(connector.db, C.duckdb_replacement_callback_t(C.replacement_scan_cb), unsafe.Pointer(&handle), C.duckdb_delete_callback_t(C.replacement_scan_destroy_data))
}

//export replacement_scan_destroy_data
func replacement_scan_destroy_data(data unsafe.Pointer) {
h := *(*cgo.Handle)(data)
h.Delete()
}

//export replacement_scan_cb
func replacement_scan_cb(info C.duckdb_replacement_scan_info, table_name *C.cchar_t, data *C.void) {
h := *(*cgo.Handle)(unsafe.Pointer(data))
scanner := h.Value().(ReplacementScanCallback)
tFunc, params, err := scanner(C.GoString(table_name))
if err != nil {
errstr := C.CString(err.Error())
C.duckdb_replacement_scan_set_error(info, errstr)
C.free(unsafe.Pointer(errstr))
return
}

fNameStr := C.CString(tFunc)
C.duckdb_replacement_scan_set_function_name(info, fNameStr)
defer C.free(unsafe.Pointer(fNameStr))

for _, v := range params {
switch x := v.(type) {
case string:
str := C.CString(x)
val := C.duckdb_create_varchar(str)
C.duckdb_replacement_scan_add_parameter(info, val)
C.free(unsafe.Pointer(str))
C.duckdb_destroy_value(&val)
case int64:
val := C.duckdb_create_int64(C.int64_t(x))
C.duckdb_replacement_scan_add_parameter(info, val)
C.duckdb_destroy_value(&val)
default:
errstr := C.CString("invalid type")
C.duckdb_replacement_scan_set_error(info, errstr)
C.free(unsafe.Pointer(errstr))
return
}
}
}
44 changes: 44 additions & 0 deletions replacement_scan_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,44 @@
package duckdb

import (
"database/sql"
"database/sql/driver"
"github.com/stretchr/testify/require"
"testing"
)

func TestReplacementScan(t *testing.T) {

connector, err := NewConnector("", func(execer driver.ExecerContext) error {
return nil
})

require.NoError(t, err)
defer connector.Close()

var rangeRows = 100
RegisterReplacementScan(connector, func(tableName string) (string, []any, error) {
return "range", []any{int64(rangeRows)}, nil
})

db := sql.OpenDB(connector)
rows, err := db.Query("select * from any_table")
require.NoError(t, err)
defer rows.Close()

for i := 0; rows.Next(); i++ {
var val int
require.NoError(t, rows.Scan(&val))
if val != i {
require.Fail(t, "expected %d, got %d", i, val)
}
rangeRows--
}
if rows.Err() != nil {
require.NoError(t, rows.Err())
}
if rangeRows != 0 {
require.Fail(t, "expected 0, got %d", rangeRows)
}

}

0 comments on commit 9eca001

Please sign in to comment.