-
Notifications
You must be signed in to change notification settings - Fork 113
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge pull request #201 from JAicewizard/main
Table function UDFs
- Loading branch information
Showing
13 changed files
with
1,622 additions
and
15 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,100 @@ | ||
package main | ||
|
||
import ( | ||
"context" | ||
"database/sql" | ||
"fmt" | ||
|
||
"github.com/marcboeker/go-duckdb" | ||
) | ||
|
||
type incrementTableUDF struct { | ||
tableSize int64 | ||
currentRow int64 | ||
} | ||
|
||
func bindTableUDF(namedArgs map[string]any, args ...interface{}) (duckdb.RowTableSource, error) { | ||
return &incrementTableUDF{ | ||
currentRow: 0, | ||
tableSize: args[0].(int64), | ||
}, nil | ||
} | ||
|
||
func (udf *incrementTableUDF) ColumnInfos() []duckdb.ColumnInfo { | ||
t, err := duckdb.NewTypeInfo(duckdb.TYPE_BIGINT) | ||
check(err) | ||
return []duckdb.ColumnInfo{{Name: "result", T: t}} | ||
} | ||
|
||
func (udf *incrementTableUDF) Init() {} | ||
|
||
func (udf *incrementTableUDF) FillRow(row duckdb.Row) (bool, error) { | ||
if udf.currentRow+1 > udf.tableSize { | ||
return false, nil | ||
} | ||
udf.currentRow++ | ||
err := duckdb.SetRowValue(row, 0, udf.currentRow) | ||
return true, err | ||
} | ||
|
||
func (udf *incrementTableUDF) Cardinality() *duckdb.CardinalityInfo { | ||
return &duckdb.CardinalityInfo{ | ||
Cardinality: uint(udf.tableSize), | ||
Exact: true, | ||
} | ||
} | ||
|
||
func main() { | ||
db, err := sql.Open("duckdb", "?access_mode=READ_WRITE") | ||
check(err) | ||
|
||
conn, err := db.Conn(context.Background()) | ||
check(err) | ||
|
||
t, err := duckdb.NewTypeInfo(duckdb.TYPE_BIGINT) | ||
check(err) | ||
udf := duckdb.RowTableFunction{ | ||
Config: duckdb.TableFunctionConfig{ | ||
Arguments: []duckdb.TypeInfo{t}, | ||
}, | ||
BindArguments: bindTableUDF, | ||
} | ||
|
||
err = duckdb.RegisterTableUDF(conn, "increment", udf) | ||
check(err) | ||
|
||
rows, err := db.QueryContext(context.Background(), `SELECT * FROM increment(100)`) | ||
check(err) | ||
|
||
// Get the column names. | ||
columns, err := rows.Columns() | ||
check(err) | ||
|
||
values := make([]interface{}, len(columns)) | ||
args := make([]interface{}, len(values)) | ||
for i := range values { | ||
args[i] = &values[i] | ||
} | ||
|
||
rowSum := int64(0) | ||
for rows.Next() { | ||
err = rows.Scan(args...) | ||
check(err) | ||
for _, value := range values { | ||
// Keep in mind that the value can be nil for NULL values. | ||
// This never happens in this example, so we don't check for it. | ||
rowSum += value.(int64) | ||
} | ||
} | ||
fmt.Printf("row sum: %d", rowSum) | ||
|
||
check(rows.Close()) | ||
check(conn.Close()) | ||
check(db.Close()) | ||
} | ||
|
||
func check(err interface{}) { | ||
if err != nil { | ||
panic(err) | ||
} | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,156 @@ | ||
package main | ||
|
||
import ( | ||
"context" | ||
"database/sql" | ||
"fmt" | ||
"sync" | ||
|
||
"github.com/marcboeker/go-duckdb" | ||
) | ||
|
||
type ( | ||
parallelIncrementTableUDF struct { | ||
lock *sync.Mutex | ||
claimed int64 | ||
n int64 | ||
} | ||
|
||
localTableState struct { | ||
start int64 | ||
end int64 | ||
} | ||
) | ||
|
||
func bindParallelTableUDF(namedArgs map[string]any, args ...interface{}) (duckdb.ParallelRowTableSource, error) { | ||
return ¶llelIncrementTableUDF{ | ||
lock: &sync.Mutex{}, | ||
claimed: 0, | ||
n: args[0].(int64), | ||
}, nil | ||
} | ||
|
||
func (udf *parallelIncrementTableUDF) ColumnInfos() []duckdb.ColumnInfo { | ||
t, err := duckdb.NewTypeInfo(duckdb.TYPE_BIGINT) | ||
check(err) | ||
return []duckdb.ColumnInfo{{Name: "result", T: t}} | ||
} | ||
|
||
func (udf *parallelIncrementTableUDF) Init() duckdb.ParallelTableSourceInfo { | ||
return duckdb.ParallelTableSourceInfo{ | ||
MaxThreads: 8, | ||
} | ||
} | ||
|
||
func (udf *parallelIncrementTableUDF) NewLocalState() any { | ||
return &localTableState{ | ||
start: 0, | ||
end: -1, | ||
} | ||
} | ||
|
||
func (udf *parallelIncrementTableUDF) FillRow(localState any, row duckdb.Row) (bool, error) { | ||
state := localState.(*localTableState) | ||
|
||
if state.start >= state.end { | ||
// Claim a new work unit. | ||
udf.lock.Lock() | ||
remaining := udf.n - udf.claimed | ||
|
||
if remaining <= 0 { | ||
// No more work. | ||
udf.lock.Unlock() | ||
return false, nil | ||
} else if remaining >= 2024 { | ||
remaining = 2024 | ||
} | ||
|
||
state.start = udf.claimed | ||
udf.claimed += remaining | ||
state.end = udf.claimed | ||
udf.lock.Unlock() | ||
} | ||
|
||
state.start++ | ||
err := duckdb.SetRowValue(row, 0, state.start) | ||
return true, err | ||
} | ||
|
||
func (udf *parallelIncrementTableUDF) GetValue(r, c int) any { | ||
return int64(r + 1) | ||
} | ||
|
||
func (udf *parallelIncrementTableUDF) GetTypes() []any { | ||
return []any{0} | ||
} | ||
|
||
func (udf *parallelIncrementTableUDF) Cardinality() *duckdb.CardinalityInfo { | ||
return nil | ||
} | ||
|
||
func (udf *parallelIncrementTableUDF) GetFunction() duckdb.ParallelRowTableFunction { | ||
t, err := duckdb.NewTypeInfo(duckdb.TYPE_BIGINT) | ||
check(err) | ||
|
||
return duckdb.ParallelRowTableFunction{ | ||
Config: duckdb.TableFunctionConfig{ | ||
Arguments: []duckdb.TypeInfo{t}, | ||
}, | ||
BindArguments: bindParallelTableUDF, | ||
} | ||
} | ||
|
||
func main() { | ||
db, err := sql.Open("duckdb", "?access_mode=READ_WRITE") | ||
check(err) | ||
|
||
conn, err := db.Conn(context.Background()) | ||
check(err) | ||
|
||
t, err := duckdb.NewTypeInfo(duckdb.TYPE_BIGINT) | ||
check(err) | ||
udf := duckdb.ParallelRowTableFunction{ | ||
Config: duckdb.TableFunctionConfig{ | ||
Arguments: []duckdb.TypeInfo{t}, | ||
}, | ||
BindArguments: bindParallelTableUDF, | ||
} | ||
|
||
err = duckdb.RegisterTableUDF(conn, "increment", udf) | ||
check(err) | ||
|
||
rows, err := db.QueryContext(context.Background(), `SELECT * FROM increment(10000)`) | ||
check(err) | ||
|
||
// Get the column names. | ||
columns, err := rows.Columns() | ||
check(err) | ||
|
||
values := make([]interface{}, len(columns)) | ||
args := make([]interface{}, len(values)) | ||
for i := range values { | ||
args[i] = &values[i] | ||
} | ||
|
||
rowSum := int64(0) | ||
for rows.Next() { | ||
err = rows.Scan(args...) | ||
check(err) | ||
for _, value := range values { | ||
// Keep in mind that the value could be nil for NULL values. | ||
// This never happens in our case, so we don't check for it. | ||
rowSum += value.(int64) | ||
} | ||
} | ||
fmt.Printf("row sum: %d", rowSum) | ||
|
||
check(rows.Close()) | ||
check(conn.Close()) | ||
check(db.Close()) | ||
} | ||
|
||
func check(err interface{}) { | ||
if err != nil { | ||
panic(err) | ||
} | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,37 @@ | ||
package duckdb | ||
|
||
/* | ||
#include <duckdb.h> | ||
*/ | ||
import "C" | ||
|
||
// Row represents one row in duckdb. It references the internal vectors. | ||
type Row struct { | ||
chunk *DataChunk | ||
r C.idx_t | ||
projection []int | ||
} | ||
|
||
// IsProjected returns whether the column is projected. | ||
func (r Row) IsProjected(colIdx int) bool { | ||
return r.projection[colIdx] != -1 | ||
} | ||
|
||
// SetRowValue sets the value at colIdx to val. | ||
// Returns an error on failure, and nil for non-projected columns. | ||
func SetRowValue[T any](row Row, colIdx int, val T) error { | ||
projectedIdx := row.projection[colIdx] | ||
if projectedIdx < 0 || projectedIdx >= len(row.chunk.columns) { | ||
return nil | ||
} | ||
vec := row.chunk.columns[projectedIdx] | ||
return setVectorVal(&vec, row.r, val) | ||
} | ||
|
||
// SetRowValue sets the value at colIdx to val. Returns an error on failure. | ||
func (r Row) SetRowValue(colIdx int, val any) error { | ||
if !r.IsProjected(colIdx) { | ||
return nil | ||
} | ||
return r.chunk.SetValue(colIdx, int(r.r), val) | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
File renamed without changes.
Oops, something went wrong.