Skip to content

Commit

Permalink
Add support for nested Lists and Arrays in prepared statements
Browse files Browse the repository at this point in the history
  • Loading branch information
Giorgi committed Nov 29, 2024
1 parent c8e9944 commit 1f1f024
Show file tree
Hide file tree
Showing 3 changed files with 69 additions and 152 deletions.
12 changes: 12 additions & 0 deletions DuckDB.NET.Bindings/NativeMethods/NativeMethods.Value.cs
Original file line number Diff line number Diff line change
Expand Up @@ -133,6 +133,9 @@ public static class Value

[DllImport(DuckDbLibrary, CallingConvention = CallingConvention.Cdecl, EntryPoint = "duckdb_create_list_value")]
public static extern DuckDBValue DuckDBCreateListValue(DuckDBLogicalType logicalType, IntPtr[] values, long count);

[DllImport(DuckDbLibrary, CallingConvention = CallingConvention.Cdecl, EntryPoint = "duckdb_create_array_value")]
public static extern DuckDBValue DuckDBCreateArrayValue(DuckDBLogicalType logicalType, IntPtr[] values, long count);

[DllImport(DuckDbLibrary, CallingConvention = CallingConvention.Cdecl, EntryPoint = "duckdb_create_null_value")]
public static extern DuckDBValue DuckDBCreateNullValue();
Expand All @@ -148,5 +151,14 @@ public static DuckDBValue DuckDBCreateListValue(DuckDBLogicalType logicalType, D

return duckDBValue;
}

public static DuckDBValue DuckDBCreateArrayValue(DuckDBLogicalType logicalType, DuckDBValue[] values, int count)
{
var duckDBValue = DuckDBCreateArrayValue(logicalType, values.Select(item => item.DangerousGetHandle()).ToArray(), count);

duckDBValue.SetChildValues(values);

return duckDBValue;
}
}
}
200 changes: 50 additions & 150 deletions DuckDB.NET.Data/Internal/ClrToDuckDBConverter.cs
Original file line number Diff line number Diff line change
@@ -1,160 +1,14 @@
using DuckDB.NET.Data.Extensions;
using DuckDB.NET.Native;
using System;
using System.Collections;
using System.Collections.Generic;
using System.Globalization;
using System.Numerics;
using DuckDB.NET.Data.Extensions;
using DuckDB.NET.Native;

namespace DuckDB.NET.Data.Internal;

internal static class ClrToDuckDBConverter
{
public static DuckDBValue ToDuckDBValue(this object? value)
{
if (value.IsNull())
{
return NativeMethods.Value.DuckDBCreateNullValue();
}

return value switch
{
bool val => NativeMethods.Value.DuckDBCreateBool(val),

sbyte val => NativeMethods.Value.DuckDBCreateInt8(val),
short val => NativeMethods.Value.DuckDBCreateInt16(val),
int val => NativeMethods.Value.DuckDBCreateInt32(val),
long val => NativeMethods.Value.DuckDBCreateInt64(val),

byte val => NativeMethods.Value.DuckDBCreateUInt8(val),
ushort val => NativeMethods.Value.DuckDBCreateUInt16(val),
uint val => NativeMethods.Value.DuckDBCreateUInt32(val),
ulong val => NativeMethods.Value.DuckDBCreateUInt64(val),

float val => NativeMethods.Value.DuckDBCreateFloat(val),
double val => NativeMethods.Value.DuckDBCreateDouble(val),

decimal val => DecimalToDuckDBValue(val),
BigInteger val => NativeMethods.Value.DuckDBCreateHugeInt(new DuckDBHugeInt(val)),

string val => StringToDuckDBValue(val),
Guid val => GuidToDuckDBValue(val),
DateTime val => NativeMethods.Value.DuckDBCreateTimestamp(NativeMethods.DateTimeHelpers.DuckDBToTimestamp(DuckDBTimestamp.FromDateTime(val))),
TimeSpan val => NativeMethods.Value.DuckDBCreateInterval(val),
DuckDBDateOnly val => NativeMethods.Value.DuckDBCreateDate(NativeMethods.DateTimeHelpers.DuckDBToDate(val)),
DuckDBTimeOnly val => NativeMethods.Value.DuckDBCreateTime(NativeMethods.DateTimeHelpers.DuckDBToTime(val)),
#if NET6_0_OR_GREATER
DateOnly val => NativeMethods.Value.DuckDBCreateDate(NativeMethods.DateTimeHelpers.DuckDBToDate(val)),
TimeOnly val => NativeMethods.Value.DuckDBCreateTime(NativeMethods.DateTimeHelpers.DuckDBToTime(val)),
#endif
DateTimeOffset val => DateTimeOffsetToDuckDBValue(val),
byte[] val => NativeMethods.Value.DuckDBCreateBlob(val, val.Length),

ICollection val => CreateCollectionValue(val),
_ => throw new InvalidCastException($"Cannot convert value of type {value.GetType().FullName} to DuckDBValue.")
};
}

private static DuckDBValue DateTimeOffsetToDuckDBValue(DateTimeOffset val)
{
var duckDBToTime = NativeMethods.DateTimeHelpers.DuckDBToTime((DuckDBTimeOnly)val.DateTime);
var duckDBCreateTimeTz = NativeMethods.DateTimeHelpers.DuckDBCreateTimeTz(duckDBToTime.Micros, (int)val.Offset.TotalSeconds);
return NativeMethods.Value.DuckDBCreateTimeTz(duckDBCreateTimeTz);
}

private static DuckDBValue GuidToDuckDBValue(Guid value)
{
using var handle = value.ToString().ToUnmanagedString();
return NativeMethods.Value.DuckDBCreateVarchar(handle);
}

private static DuckDBValue DecimalToDuckDBValue(decimal value)
{
using var handle = value.ToString(CultureInfo.InvariantCulture).ToUnmanagedString();
return NativeMethods.Value.DuckDBCreateVarchar(handle);
}

private static DuckDBValue StringToDuckDBValue(string value)
{
using var handle = value.ToUnmanagedString();
return NativeMethods.Value.DuckDBCreateVarchar(handle);
}

private static DuckDBValue CreateCollectionValue(ICollection collection)
{
return collection switch
{
ICollection<bool> items => CreateCollectionValue(DuckDBType.Boolean, items),
ICollection<bool?> items => CreateCollectionValue(DuckDBType.Boolean, items),

ICollection<sbyte> items => CreateCollectionValue(DuckDBType.TinyInt, items),
ICollection<sbyte?> items => CreateCollectionValue(DuckDBType.TinyInt, items),
ICollection<short> items => CreateCollectionValue(DuckDBType.SmallInt, items),
ICollection<short?> items => CreateCollectionValue(DuckDBType.SmallInt, items),
ICollection<int> items => CreateCollectionValue(DuckDBType.Integer, items),
ICollection<int?> items => CreateCollectionValue(DuckDBType.Integer, items),
ICollection<long> items => CreateCollectionValue(DuckDBType.BigInt, items),
ICollection<long?> items => CreateCollectionValue(DuckDBType.BigInt, items),

ICollection<byte> items => CreateCollectionValue(DuckDBType.UnsignedTinyInt, items),
ICollection<byte?> items => CreateCollectionValue(DuckDBType.UnsignedTinyInt, items),
ICollection<ushort> items => CreateCollectionValue(DuckDBType.UnsignedSmallInt, items),
ICollection<ushort?> items => CreateCollectionValue(DuckDBType.UnsignedSmallInt, items),
ICollection<uint> items => CreateCollectionValue(DuckDBType.UnsignedInteger, items),
ICollection<uint?> items => CreateCollectionValue(DuckDBType.UnsignedInteger, items),
ICollection<ulong> items => CreateCollectionValue(DuckDBType.UnsignedBigInt, items),
ICollection<ulong?> items => CreateCollectionValue(DuckDBType.UnsignedBigInt, items),

ICollection<float> items => CreateCollectionValue(DuckDBType.Float, items),
ICollection<float?> items => CreateCollectionValue(DuckDBType.Float, items),
ICollection<double> items => CreateCollectionValue(DuckDBType.Double, items),
ICollection<double?> items => CreateCollectionValue(DuckDBType.Double, items),

ICollection<decimal> items => CreateCollectionValue(DuckDBType.Varchar, items),
ICollection<decimal?> items => CreateCollectionValue(DuckDBType.Varchar, items),
ICollection<BigInteger> items => CreateCollectionValue(DuckDBType.HugeInt, items),
ICollection<BigInteger?> items => CreateCollectionValue(DuckDBType.HugeInt, items),

ICollection<string> items => CreateCollectionValue(DuckDBType.Varchar, items),
ICollection<Guid> items => CreateCollectionValue(DuckDBType.Varchar, items),
ICollection<Guid?> items => CreateCollectionValue(DuckDBType.Varchar, items),
ICollection<DateTime> items => CreateCollectionValue(DuckDBType.Date, items),
ICollection<DateTime?> items => CreateCollectionValue(DuckDBType.Date, items),
ICollection<TimeSpan> items => CreateCollectionValue(DuckDBType.Interval, items),
ICollection<TimeSpan?> items => CreateCollectionValue(DuckDBType.Interval, items),
ICollection<DuckDBDateOnly> items => CreateCollectionValue(DuckDBType.Date, items),
ICollection<DuckDBDateOnly?> items => CreateCollectionValue(DuckDBType.Date, items),
ICollection<DuckDBTimeOnly> items => CreateCollectionValue(DuckDBType.Time, items),
ICollection<DuckDBTimeOnly?> items => CreateCollectionValue(DuckDBType.Time, items),
#if NET6_0_OR_GREATER
ICollection<DateOnly> items => CreateCollectionValue(DuckDBType.Date, items),
ICollection<DateOnly?> items => CreateCollectionValue(DuckDBType.Date, items),
ICollection<TimeOnly> items => CreateCollectionValue(DuckDBType.Time, items),
ICollection<TimeOnly?> items => CreateCollectionValue(DuckDBType.Time, items),
#endif
ICollection<DateTimeOffset> items => CreateCollectionValue(DuckDBType.TimeTz, items),
ICollection<DateTimeOffset?> items => CreateCollectionValue(DuckDBType.TimeTz, items),
_ => throw new InvalidOperationException($"Cannot convert collection type {collection.GetType().FullName} to DuckDBValue.")
};
}

private static DuckDBValue CreateCollectionValue<T>(DuckDBType duckDBType, ICollection<T> collection)
{
using var listItemType = NativeMethods.LogicalType.DuckDBCreateLogicalType(duckDBType);

var values = new DuckDBValue[collection.Count];

var index = 0;
foreach (var item in collection)
{
var duckDBValue = item.ToDuckDBValue();
values[index] = duckDBValue;
index++;
}

return NativeMethods.Value.DuckDBCreateListValue(listItemType, values, collection.Count);
}

public static DuckDBValue ToDuckDBValue(this object? item, DuckDBLogicalType logicalType)
{
if (item.IsNull())
Expand Down Expand Up @@ -189,16 +43,18 @@ public static DuckDBValue ToDuckDBValue(this object? item, DuckDBLogicalType log

(DuckDBType.Timestamp, DateTime value) => NativeMethods.Value.DuckDBCreateTimestamp(NativeMethods.DateTimeHelpers.DuckDBToTimestamp(DuckDBTimestamp.FromDateTime(value))),
(DuckDBType.Interval, TimeSpan value) => NativeMethods.Value.DuckDBCreateInterval(value),
(DuckDBType.Date, DateTime value) => NativeMethods.Value.DuckDBCreateDate(NativeMethods.DateTimeHelpers.DuckDBToDate((DuckDBDateOnly)value)),
(DuckDBType.Date, DuckDBDateOnly value) => NativeMethods.Value.DuckDBCreateDate(NativeMethods.DateTimeHelpers.DuckDBToDate(value)),
(DuckDBType.Time, DateTime value) => NativeMethods.Value.DuckDBCreateTime(NativeMethods.DateTimeHelpers.DuckDBToTime((DuckDBTimeOnly)value)),
(DuckDBType.Time, DuckDBTimeOnly value) => NativeMethods.Value.DuckDBCreateTime(NativeMethods.DateTimeHelpers.DuckDBToTime(value)),
#if NET6_0_OR_GREATER
(DuckDBType.Date, DateOnly value) => NativeMethods.Value.DuckDBCreateDate(NativeMethods.DateTimeHelpers.DuckDBToDate(value)),
(DuckDBType.Time, TimeOnly value) => NativeMethods.Value.DuckDBCreateTime(NativeMethods.DateTimeHelpers.DuckDBToTime(value)),
#endif
(DuckDBType.TimeTz, DateTimeOffset value) => DateTimeOffsetToDuckDBValue(value),
(DuckDBType.Blob, byte[] value) => NativeMethods.Value.DuckDBCreateBlob(value, value.Length),
(DuckDBType.List, ICollection value) => CreateCollectionValue(value),
(DuckDBType.Array, ICollection value) => CreateCollectionValue(value),
(DuckDBType.List, ICollection value) => CreateCollectionValue(logicalType, value, true),
(DuckDBType.Array, ICollection value) => CreateCollectionValue(logicalType, value, false),
_ => throw new InvalidOperationException($"Cannot bind parameter type {item.GetType().FullName} to column of type {duckDBType}")
};

Expand All @@ -214,4 +70,48 @@ T ConvertTo<T>()
}
}
}

private static DuckDBValue CreateCollectionValue(DuckDBLogicalType logicalType, ICollection collection, bool isList)
{
using var collectionItemType = isList ? NativeMethods.LogicalType.DuckDBListTypeChildType(logicalType) :
NativeMethods.LogicalType.DuckDBArrayTypeChildType(logicalType);

var values = new DuckDBValue[collection.Count];

var index = 0;
foreach (var item in collection)
{
var duckDBValue = item.ToDuckDBValue(collectionItemType);
values[index] = duckDBValue;
index++;
}

return isList ? NativeMethods.Value.DuckDBCreateListValue(collectionItemType, values, collection.Count)
: NativeMethods.Value.DuckDBCreateArrayValue(collectionItemType, values, collection.Count);
}

private static DuckDBValue GuidToDuckDBValue(Guid value)
{
using var handle = value.ToString().ToUnmanagedString();
return NativeMethods.Value.DuckDBCreateVarchar(handle);
}

private static DuckDBValue StringToDuckDBValue(string value)
{
using var handle = value.ToUnmanagedString();
return NativeMethods.Value.DuckDBCreateVarchar(handle);
}

private static DuckDBValue DecimalToDuckDBValue(decimal value)
{
using var handle = value.ToString(CultureInfo.InvariantCulture).ToUnmanagedString();
return NativeMethods.Value.DuckDBCreateVarchar(handle);
}

private static DuckDBValue DateTimeOffsetToDuckDBValue(DateTimeOffset val)
{
var duckDBToTime = NativeMethods.DateTimeHelpers.DuckDBToTime((DuckDBTimeOnly)val.DateTime);
var duckDBCreateTimeTz = NativeMethods.DateTimeHelpers.DuckDBCreateTimeTz(duckDBToTime.Micros, (int)val.Offset.TotalSeconds);
return NativeMethods.Value.DuckDBCreateTimeTz(duckDBCreateTimeTz);
}
}
9 changes: 7 additions & 2 deletions DuckDB.NET.Test/Parameters/ListParameterTests.cs
Original file line number Diff line number Diff line change
Expand Up @@ -16,13 +16,15 @@ public class ListParameterTests(DuckDBDatabaseFixture db) : DuckDBTestBase(db)
private void TestInsertSelect<T>(string duckDbType, Func<Faker, T> generator, int? length = null)
{
var list = GetRandomList(generator, length ?? Random.Shared.Next(10, 200));
var nestedList = new List<List<T>> { GetRandomList(generator, 5), GetRandomList(generator, 10), GetRandomList(generator, 20) };

Command.CommandText = $"CREATE OR REPLACE TABLE ParameterListTest (a {duckDbType}[], b {duckDbType}[10]);";
Command.CommandText = $"CREATE OR REPLACE TABLE ParameterListTest (a {duckDbType}[], b {duckDbType}[10], c {duckDbType}[][]);";
Command.ExecuteNonQuery();

Command.CommandText = "INSERT INTO ParameterListTest (a, b) VALUES ($list, $array);";
Command.CommandText = "INSERT INTO ParameterListTest (a, b, c) VALUES ($list, $array, $nestedList);";
Command.Parameters.Add(new DuckDBParameter(list));
Command.Parameters.Add(new DuckDBParameter(list.Take(10).ToList()));
Command.Parameters.Add(new DuckDBParameter(nestedList));
Command.ExecuteNonQuery();

Command.CommandText = "SELECT * FROM ParameterListTest;";
Expand All @@ -36,6 +38,9 @@ private void TestInsertSelect<T>(string duckDbType, Func<Faker, T> generator, in
var arrayValue = reader.GetFieldValue<List<T>>(1);
arrayValue.Should().BeEquivalentTo(list.Take(10));

var nestedListValue = reader.GetFieldValue<List<List<T>>>(2);
nestedListValue.Should().BeEquivalentTo(nestedList);

Command.CommandText = "DROP TABLE ParameterListTest";
Command.ExecuteNonQuery();
}
Expand Down

0 comments on commit 1f1f024

Please sign in to comment.