Skip to content

Commit

Permalink
Merge pull request #235 from VitaliyMF/develop
Browse files Browse the repository at this point in the history
RegisterTableFunction errors handling #234
  • Loading branch information
Giorgi authored Nov 25, 2024
2 parents 1c96434 + 7ea612c commit 344fa96
Show file tree
Hide file tree
Showing 3 changed files with 109 additions and 49 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,9 @@ public static class TableFunction
[DllImport(DuckDbLibrary, CallingConvention = CallingConvention.Cdecl, EntryPoint = "duckdb_bind_set_bind_data")]
public static extern unsafe void DuckDBBindSetBindData(IntPtr info, IntPtr bindData, delegate* unmanaged[Cdecl]<IntPtr, void> destroy);

[DllImport(DuckDbLibrary, CallingConvention = CallingConvention.Cdecl, EntryPoint = "duckdb_bind_set_error")]
public static extern unsafe void DuckDBBindSetError(IntPtr info, SafeUnmanagedMemoryHandle error);

#endregion

#region TableFunction
Expand All @@ -61,6 +64,9 @@ public static class TableFunction
[DllImport(DuckDbLibrary, CallingConvention = CallingConvention.Cdecl, EntryPoint = "duckdb_function_get_bind_data")]
public static extern unsafe IntPtr DuckDBFunctionGetBindData(IntPtr info);

[DllImport(DuckDbLibrary, CallingConvention = CallingConvention.Cdecl, EntryPoint = "duckdb_function_set_error")]
public static extern unsafe void DuckDBFunctionSetError(IntPtr info, SafeUnmanagedMemoryHandle error);

#endregion
}
}
125 changes: 76 additions & 49 deletions DuckDB.NET.Data/DuckDBConnection.TableFunction.cs
Original file line number Diff line number Diff line change
Expand Up @@ -102,37 +102,54 @@ private unsafe void RegisterTableFunctionInternal(string name, Func<IReadOnlyLis
[UnmanagedCallersOnly(CallConvs = [typeof(CallConvCdecl)])]
public static unsafe void Bind(IntPtr info)
{
var handle = GCHandle.FromIntPtr(NativeMethods.TableFunction.DuckDBBindGetExtraInfo(info));

if (handle.Target is not TableFunctionInfo functionInfo)
IDuckDBValueReader[]? parameters = null;
try
{
throw new InvalidOperationException("User defined table function bind failed. Bind extra info is null");
}
var handle = GCHandle.FromIntPtr(NativeMethods.TableFunction.DuckDBBindGetExtraInfo(info));

var parameters = new IDuckDBValueReader[NativeMethods.TableFunction.DuckDBBindGetParameterCount(info)];
if (handle.Target is not TableFunctionInfo functionInfo)
{
throw new InvalidOperationException("User defined table function bind failed. Bind extra info is null");
}

for (var i = 0; i < parameters.Length; i++)
{
var value = NativeMethods.TableFunction.DuckDBBindGetParameter(info, (ulong)i);
parameters[i] = value;
}
parameters = new IDuckDBValueReader[NativeMethods.TableFunction.DuckDBBindGetParameterCount(info)];

for (var i = 0; i < parameters.Length; i++)
{
var value = NativeMethods.TableFunction.DuckDBBindGetParameter(info, (ulong)i);
parameters[i] = value;
}

var tableFunctionData = functionInfo.Bind(parameters);
var tableFunctionData = functionInfo.Bind(parameters);

foreach (var parameter in parameters)
foreach (var columnInfo in tableFunctionData.Columns)
{
using var logicalType = DuckDBTypeMap.GetLogicalType(columnInfo.Type);
NativeMethods.TableFunction.DuckDBBindAddResultColumn(info, columnInfo.Name.ToUnmanagedString(), logicalType);
}

var bindData = new TableFunctionBindData(tableFunctionData.Columns, tableFunctionData.Data.GetEnumerator());

NativeMethods.TableFunction.DuckDBBindSetBindData(info, bindData.ToHandle(), &DestroyExtraInfo);
}
catch (Exception ex)
{
((DuckDBValue)parameter).Dispose();
using (var errMsgHandle = ex.Message.ToUnmanagedString())
{
NativeMethods.TableFunction.DuckDBBindSetError(info, errMsgHandle);
}
return;
}

foreach (var columnInfo in tableFunctionData.Columns)
finally
{
using var logicalType = DuckDBTypeMap.GetLogicalType(columnInfo.Type);
NativeMethods.TableFunction.DuckDBBindAddResultColumn(info, columnInfo.Name.ToUnmanagedString(), logicalType);
if (parameters!=null)
foreach (var parameter in parameters)
{
if (parameter != null)
((DuckDBValue)parameter).Dispose();
}
}

var bindData = new TableFunctionBindData(tableFunctionData.Columns, tableFunctionData.Data.GetEnumerator());

NativeMethods.TableFunction.DuckDBBindSetBindData(info, bindData.ToHandle(), &DestroyExtraInfo);
}

[UnmanagedCallersOnly(CallConvs = [typeof(CallConvCdecl)])]
Expand All @@ -141,46 +158,56 @@ public static void Init(IntPtr info) { }
[UnmanagedCallersOnly(CallConvs = [typeof(CallConvCdecl)])]
public static void TableFunction(IntPtr info, IntPtr chunk)
{
var bindData = GCHandle.FromIntPtr(NativeMethods.TableFunction.DuckDBFunctionGetBindData(info));
var extraInfo = GCHandle.FromIntPtr(NativeMethods.TableFunction.DuckDBFunctionGetExtraInfo(info));

if (bindData.Target is not TableFunctionBindData tableFunctionBindData)
try
{
throw new InvalidOperationException("User defined table function failed. Function bind data is null");
}
var bindData = GCHandle.FromIntPtr(NativeMethods.TableFunction.DuckDBFunctionGetBindData(info));
var extraInfo = GCHandle.FromIntPtr(NativeMethods.TableFunction.DuckDBFunctionGetExtraInfo(info));

if (extraInfo.Target is not TableFunctionInfo tableFunctionInfo)
{
throw new InvalidOperationException("User defined table function failed. Function extra info is null");
}
if (bindData.Target is not TableFunctionBindData tableFunctionBindData)
{
throw new InvalidOperationException("User defined table function failed. Function bind data is null");
}

var dataChunk = new DuckDBDataChunk(chunk);
if (extraInfo.Target is not TableFunctionInfo tableFunctionInfo)
{
throw new InvalidOperationException("User defined table function failed. Function extra info is null");
}

var writers = new VectorDataWriterBase[tableFunctionBindData.Columns.Count];
for (var columnIndex = 0; columnIndex < tableFunctionBindData.Columns.Count; columnIndex++)
{
var column = tableFunctionBindData.Columns[columnIndex];
var vector = NativeMethods.DataChunks.DuckDBDataChunkGetVector(dataChunk, columnIndex);
var dataChunk = new DuckDBDataChunk(chunk);

using var logicalType = DuckDBTypeMap.GetLogicalType(column.Type);
writers[columnIndex] = VectorDataWriterFactory.CreateWriter(vector, logicalType);
}
var writers = new VectorDataWriterBase[tableFunctionBindData.Columns.Count];
for (var columnIndex = 0; columnIndex < tableFunctionBindData.Columns.Count; columnIndex++)
{
var column = tableFunctionBindData.Columns[columnIndex];
var vector = NativeMethods.DataChunks.DuckDBDataChunkGetVector(dataChunk, columnIndex);

ulong size = 0;
using var logicalType = DuckDBTypeMap.GetLogicalType(column.Type);
writers[columnIndex] = VectorDataWriterFactory.CreateWriter(vector, logicalType);
}

for (; size < DuckDBGlobalData.VectorSize; size++)
{
if (tableFunctionBindData.DataEnumerator.MoveNext())
ulong size = 0;

for (; size < DuckDBGlobalData.VectorSize; size++)
{
tableFunctionInfo.Mapper(tableFunctionBindData.DataEnumerator.Current, writers, size);
if (tableFunctionBindData.DataEnumerator.MoveNext())
{
tableFunctionInfo.Mapper(tableFunctionBindData.DataEnumerator.Current, writers, size);
}
else
{
break;
}
}
else

NativeMethods.DataChunks.DuckDBDataChunkSetSize(dataChunk, size);
}
catch (Exception ex)
{
using (var errMsgHandle = ex.Message.ToUnmanagedString())
{
break;
NativeMethods.TableFunction.DuckDBFunctionSetError(info, errMsgHandle);
}
}

NativeMethods.DataChunks.DuckDBDataChunkSetSize(dataChunk, size);
}
#endif
}
27 changes: 27 additions & 0 deletions DuckDB.NET.Test/TableFunctionTests.cs
Original file line number Diff line number Diff line change
Expand Up @@ -174,4 +174,31 @@ public void RegisterTableFunctionWithBigInteger()

data.Should().BeEquivalentTo(BigInteger.Parse("123456789876543210").ToByteArray().Select(b => TimeSpan.FromDays(1 + b)));
}

[Fact]
public void RegisterTableFunctionWithErrors() {
Connection.RegisterTableFunction<string>("bind_err", parameters => {
throw new Exception("bind_err_msg");
}, (item, writer, rowIndex) => {
});

Assert.Contains("bind_err_msg",
Assert.Throws<DuckDBException>(() => {
var data = Connection.Query<int>($"SELECT * FROM bind_err('')").ToList();
}).Message);

Connection.RegisterTableFunction<string>("map_err", parameters => {
return new TableFunction(
new[] { new ColumnInfo("t1", typeof(string)) },
new[] { "a" }
);
}, (item, writer, rowIndex) => {
throw new NotSupportedException("map_err_msg");
});
Assert.Contains("map_err_msg",
Assert.Throws<DuckDBException>(() => {
var data = Connection.Query<int>($"SELECT * FROM map_err('')").ToList();
}).Message);
}

}

0 comments on commit 344fa96

Please sign in to comment.