Skip to content

Commit

Permalink
Add support for cardinality and projection pushdown
Browse files Browse the repository at this point in the history
  • Loading branch information
JAicewizard committed May 12, 2024
1 parent 1436c7c commit 7454a0f
Show file tree
Hide file tree
Showing 3 changed files with 74 additions and 8 deletions.
8 changes: 7 additions & 1 deletion examples/udf/udf.go
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,6 @@ func (d *tableUDF) Init() duckdb.TableFunctionInitData {
}

func (d *tableUDF) FillRow(row duckdb.Row) bool {
fmt.Println(d.count, d.n)
if d.count > d.n {
return false
}
Expand All @@ -56,6 +55,13 @@ func (d *tableUDF) FillRow(row duckdb.Row) bool {
return true
}

func (d * tableUDF) Cardinality() *duckdb.CardinalityData {
return &duckdb.CardinalityData{
Cardinality: uint(d.n),
IsExact: true,
}
}

func main() {
var err error
db, err = sql.Open("duckdb", "?access_mode=READ_WRITE")
Expand Down
65 changes: 58 additions & 7 deletions udf.go
Original file line number Diff line number Diff line change
Expand Up @@ -30,31 +30,56 @@ type (
vectors []vector
r C.idx_t
info C.duckdb_function_info
projection []int
}

ColumnName struct {
Name string
V any
}

CardinalityData struct {
Cardinality uint
IsExact bool
}

TableFunctionInitData struct {
MaxThreads int
}

//TODO: tableFunctionData
tableFunctionMetaInstance struct {
fun TableFunctionInstance
projection []int
}

//TODO: TableFunction
TableFunctionInstance interface {
Init() TableFunctionInitData
FillRow(Row) bool
Cardinality() *CardinalityData
}

//TODO: TableFunctionProvider
TableFunction interface {
GetArguments() []any
BindArguments(args ...interface{}) (TableFunctionInstance, []ColumnName)
}
)

// Returns whether or now the column is projected
func (r Row) IsProjected(c int) bool {
return r.projection[c] != -1
}

func SetRowValue[T any](row Row, c int, val T) error {
vec := row.vectors[c]
return setVectorVal[T](&vec, row.r, val)
if !row.IsProjected(c) {
// we want to allow setting to columns that are not projected,
// it should just be a nop.
return nil
}
vec := row.vectors[row.projection[c]]
return setVectorVal(&vec, row.r, val)
}

func (row Row) SetRowValue(c int, val any) {
Expand Down Expand Up @@ -109,28 +134,53 @@ func udf_bind(info C.duckdb_bind_info) {
C.duckdb_bind_add_result_column(info, colName, t)
}

handle := cgo.NewHandle(instance)


cardinality := instance.Cardinality()
if cardinality != nil {
C.duckdb_bind_set_cardinality(info, C.idx_t(cardinality.Cardinality), C.bool(cardinality.IsExact))
}

instanceData := tableFunctionMetaInstance{
fun: instance,
projection: make([]int, len(returnvalues)),
}

for i := range returnvalues {
instanceData.projection[i] = -1
}

handle := cgo.NewHandle(instanceData)
C.duckdb_bind_set_bind_data(info, unsafe.Pointer(&handle), C.duckdb_delete_callback_t(C.udf_destroy_data))
}

//export udf_init
func udf_init(info C.duckdb_init_info) {
instanceRef := C.duckdb_init_get_bind_data(info)
h := *(*cgo.Handle)(instanceRef)
instance := h.Value().(TableFunctionInstance)
initData := instance.Init()
instance := h.Value().(tableFunctionMetaInstance)
initData := instance.fun.Init()

columnCount := C.duckdb_init_get_column_count(info)
for i := C.idx_t(0); i < columnCount; i++ {
srcPos := int(C.duckdb_init_get_column_index(info, i))
instance.projection[srcPos] = int(i)
}

C.duckdb_init_set_max_threads(info, C.idx_t(initData.MaxThreads))
}

//export udf_callback
func udf_callback(info C.duckdb_function_info, output C.duckdb_data_chunk) {
instanceRef := C.duckdb_function_get_bind_data(info)
h := *(*cgo.Handle)(instanceRef)
instance := h.Value().(TableFunctionInstance)
instance := h.Value().(tableFunctionMetaInstance)
fun := instance.fun

columnCount := C.duckdb_data_chunk_get_column_count(output)
var row Row
row.vectors = make([]vector, columnCount)
row.projection = instance.projection
var err error
for i := C.idx_t(0); i < columnCount; i++ {
duckdbVector := C.duckdb_data_chunk_get_vector(output, i)
Expand All @@ -151,7 +201,7 @@ func udf_callback(info C.duckdb_function_info, output C.duckdb_data_chunk) {
maxSize := C.duckdb_vector_size()
// At the end of the loop row.r must be the index one past the last added row
for row.r = 0; row.r < maxSize; row.r++ {
nextResults := instance.FillRow(row)
nextResults := fun.FillRow(row)
if !nextResults {
break
}
Expand All @@ -174,6 +224,7 @@ func RegisterTableUDF(c *sql.Conn, name string, function TableFunction) error {
C.duckdb_table_function_set_init(tableFunction, C.init(C.udf_init))
C.duckdb_table_function_set_function(tableFunction, C.callback(C.udf_callback))
C.duckdb_table_function_set_extra_info(tableFunction, unsafe.Pointer(&handle), C.duckdb_delete_callback_t(C.udf_destroy_data))
C.duckdb_table_function_supports_projection_pushdown(tableFunction, C.bool(true))

argumentvalues := function.GetArguments()

Expand Down
9 changes: 9 additions & 0 deletions udf_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -89,6 +89,10 @@ func (d *incTableUDF) GetTypes() []any {
}
}

func (d *incTableUDF) Cardinality() *CardinalityData {
return nil
}

func (d *structTableUDF) GetArguments() []interface{} {
return []interface{}{
int64(0),
Expand Down Expand Up @@ -134,6 +138,11 @@ func (d *structTableUDF) GetValue(r, c int) any {
}
}

func (d *structTableUDF) Cardinality() *CardinalityData {
return nil
}


func TestTableUDF(t *testing.T) {
for _, fun := range tudfs {
_fun := fun
Expand Down

0 comments on commit 7454a0f

Please sign in to comment.