diff --git a/DuckDB.NET.Bindings/NativeMethods/NativeMethods.LogicalType.cs b/DuckDB.NET.Bindings/NativeMethods/NativeMethods.LogicalType.cs index 1c101577..84afe0c8 100644 --- a/DuckDB.NET.Bindings/NativeMethods/NativeMethods.LogicalType.cs +++ b/DuckDB.NET.Bindings/NativeMethods/NativeMethods.LogicalType.cs @@ -26,6 +26,9 @@ public static class LogicalType [DllImport(DuckDbLibrary, CallingConvention = CallingConvention.Cdecl, EntryPoint = "duckdb_enum_internal_type")] public static extern DuckDBType DuckDBEnumInternalType(DuckDBLogicalType type); + [DllImport(DuckDbLibrary, CallingConvention = CallingConvention.Cdecl, EntryPoint = "duckdb_enum_dictionary_size")] + public static extern uint DuckDBEnumDictionarySize(DuckDBLogicalType type); + [DllImport(DuckDbLibrary, CallingConvention = CallingConvention.Cdecl, EntryPoint = "duckdb_enum_dictionary_value")] public static extern IntPtr DuckDBEnumDictionaryValue(DuckDBLogicalType type, long index); diff --git a/DuckDB.NET.Data/DuckDBAppenderRow.cs b/DuckDB.NET.Data/DuckDBAppenderRow.cs index 66db1994..869bc31b 100644 --- a/DuckDB.NET.Data/DuckDBAppenderRow.cs +++ b/DuckDB.NET.Data/DuckDBAppenderRow.cs @@ -68,6 +68,12 @@ public void EndRow() #endregion + #region Append Enum + + public DuckDBAppenderRow AppendValue(TEnum value) where TEnum : Enum => AppendValueInternal(value); + + #endregion + #region Append Float public DuckDBAppenderRow AppendValue(float? value) => AppendValueInternal(value); diff --git a/DuckDB.NET.Data/Internal/Writer/EnumVectorDataWriter.cs b/DuckDB.NET.Data/Internal/Writer/EnumVectorDataWriter.cs new file mode 100644 index 00000000..5dc6183a --- /dev/null +++ b/DuckDB.NET.Data/Internal/Writer/EnumVectorDataWriter.cs @@ -0,0 +1,93 @@ +using System; +using System.Collections.Generic; +using System.Runtime.CompilerServices; +using System.Threading; +using DuckDB.NET.Native; + +namespace DuckDB.NET.Data.Internal.Writer; + +internal sealed unsafe class EnumVectorDataWriter : VectorDataWriterBase +{ + private readonly DuckDBType enumType; + + private readonly uint enumDictionarySize; + + private readonly Dictionary enumValues; + + public EnumVectorDataWriter(IntPtr vector, void* vectorData, DuckDBLogicalType logicalType, DuckDBType columnType) : base(vector, vectorData, columnType) + { + enumType = NativeMethods.LogicalType.DuckDBEnumInternalType(logicalType); + enumDictionarySize = NativeMethods.LogicalType.DuckDBEnumDictionarySize(logicalType); + + uint maxEnumDictionarySize = enumType switch + { + DuckDBType.UnsignedTinyInt => byte.MaxValue, + DuckDBType.UnsignedSmallInt => ushort.MaxValue, + DuckDBType.UnsignedInteger => uint.MaxValue, + _ => throw new NotSupportedException($"The internal enum type must be utinyint, usmallint, or uinteger."), + }; + if (enumDictionarySize > maxEnumDictionarySize) + { + // This exception should only be thrown if the DuckDB library has a bug. + throw new InvalidOperationException($"The internal enum type is \"{enumType}\" but the enum dictionary size is greater than {maxEnumDictionarySize}."); + } + + enumValues = []; + for (uint index = 0; index < enumDictionarySize; index++) + { + string enumValueName = NativeMethods.LogicalType.DuckDBEnumDictionaryValue(logicalType, index).ToManagedString(); + enumValues.Add(enumValueName, index); + } + } + + internal override bool AppendString(string value, int rowIndex) + { + if (enumValues.TryGetValue(value, out uint enumValue)) + { + // The following casts to byte and ushort are safe because we ensure in the constructor that the value enumDictionarySize is not too high. + return enumType switch + { + DuckDBType.UnsignedTinyInt => AppendValueInternal((byte)enumValue, rowIndex), + DuckDBType.UnsignedSmallInt => AppendValueInternal((ushort)enumValue, rowIndex), + DuckDBType.UnsignedInteger => AppendValueInternal(enumValue, rowIndex), + _ => throw new InvalidOperationException($"Failed to write Enum column because the internal enum type must be utinyint, usmallint, or uinteger."), + }; + } + + throw new InvalidOperationException($"Failed to write Enum column because the value \"{value}\" is not valid."); + } + + internal override bool AppendEnum(TEnum value, int rowIndex) + { + ulong enumValue = ConvertEnumValueToUInt64(value); + if (enumValue < enumDictionarySize) + { + // The following casts to byte, ushort and uint are safe because we ensure in the constructor that the value enumDictionarySize is not too high. + return enumType switch + { + DuckDBType.UnsignedTinyInt => AppendValueInternal((byte)enumValue, rowIndex), + DuckDBType.UnsignedSmallInt => AppendValueInternal((ushort)enumValue, rowIndex), + DuckDBType.UnsignedInteger => AppendValueInternal((uint)enumValue, rowIndex), + _ => throw new InvalidOperationException($"Failed to write Enum column because the internal enum type must be utinyint, usmallint, or uinteger."), + }; + } + + throw new InvalidOperationException($"Failed to write Enum column because the value is outside the range (0-{enumDictionarySize-1})."); + } + + private static ulong ConvertEnumValueToUInt64(TEnum value) where TEnum : Enum + { + return Convert.GetTypeCode(value) switch + { + TypeCode.SByte => (ulong)Convert.ToSByte(value), + TypeCode.Byte => Convert.ToByte(value), + TypeCode.Int16 => (ulong)Convert.ToInt16(value), + TypeCode.UInt16 => Convert.ToUInt16(value), + TypeCode.Int32 => (ulong)Convert.ToInt32(value), + TypeCode.UInt32 => Convert.ToUInt32(value), + TypeCode.Int64 => (ulong)Convert.ToInt64(value), + TypeCode.UInt64 => Convert.ToUInt64(value), + _ => throw new InvalidOperationException($"Failed to convert the enum value {value} to ulong."), + }; + } +} diff --git a/DuckDB.NET.Data/Internal/Writer/ListVectorDataWriter.cs b/DuckDB.NET.Data/Internal/Writer/ListVectorDataWriter.cs index ad57c155..f4d6dadb 100644 --- a/DuckDB.NET.Data/Internal/Writer/ListVectorDataWriter.cs +++ b/DuckDB.NET.Data/Internal/Writer/ListVectorDataWriter.cs @@ -82,7 +82,7 @@ internal override bool AppendCollection(ICollection value, int rowIndex) IEnumerable items => WriteItems(items), IEnumerable items => WriteItems(items), - _ => WriteItems((IEnumerable)value) + _ => WriteItemsFallback(value), }; var duckDBListEntry = new DuckDBListEntry(offset, count); @@ -108,6 +108,23 @@ int WriteItems(IEnumerable items) return 0; } + + int WriteItemsFallback(IEnumerable items) + { + if (IsList == false && count != arraySize) + { + throw new InvalidOperationException($"Column has Array size of {arraySize} but the specified value has size of {count}"); + } + + var index = 0; + + foreach (var item in items) + { + listItemWriter.AppendValue(item, (int)offset + (index++)); + } + + return 0; + } } private void ResizeVector(int rowIndex, ulong count) diff --git a/DuckDB.NET.Data/Internal/Writer/VectorDataWriterBase.cs b/DuckDB.NET.Data/Internal/Writer/VectorDataWriterBase.cs index 67c3250c..3ee95bf8 100644 --- a/DuckDB.NET.Data/Internal/Writer/VectorDataWriterBase.cs +++ b/DuckDB.NET.Data/Internal/Writer/VectorDataWriterBase.cs @@ -50,6 +50,8 @@ public void AppendValue(T value, int rowIndex) decimal val => AppendDecimal(val, rowIndex), BigInteger val => AppendBigInteger(val, rowIndex), + Enum val => AppendEnum(val, rowIndex), + string val => AppendString(val, rowIndex), Guid val => AppendGuid(val, rowIndex), DateTime val => AppendDateTime(val, rowIndex), @@ -96,6 +98,8 @@ public void AppendValue(T value, int rowIndex) internal virtual bool AppendBigInteger(BigInteger value, int rowIndex) => ThrowException(); + internal virtual bool AppendEnum(TEnum value, int rowIndex) where TEnum : Enum => ThrowException(); + internal virtual bool AppendCollection(ICollection value, int rowIndex) => ThrowException(); private bool ThrowException() diff --git a/DuckDB.NET.Data/Internal/Writer/VectorDataWriterFactory.cs b/DuckDB.NET.Data/Internal/Writer/VectorDataWriterFactory.cs index 98f828a4..d96d8299 100644 --- a/DuckDB.NET.Data/Internal/Writer/VectorDataWriterFactory.cs +++ b/DuckDB.NET.Data/Internal/Writer/VectorDataWriterFactory.cs @@ -27,7 +27,7 @@ public static unsafe VectorDataWriterBase CreateWriter(IntPtr vector, DuckDBLogi DuckDBType.Blob => new StringVectorDataWriter(vector, dataPointer, columnType), DuckDBType.Varchar => new StringVectorDataWriter(vector, dataPointer, columnType), DuckDBType.Bit => throw new NotImplementedException($"Writing {columnType} to data chunk is not yet supported"), - DuckDBType.Enum => throw new NotImplementedException($"Writing {columnType} to data chunk is not yet supported"), + DuckDBType.Enum => new EnumVectorDataWriter(vector, dataPointer, logicalType, columnType), DuckDBType.Struct => throw new NotImplementedException($"Writing {columnType} to data chunk is not yet supported"), DuckDBType.Decimal => new DecimalVectorDataWriter(vector, dataPointer, logicalType, columnType), DuckDBType.TimestampS => new DateTimeVectorDataWriter(vector, dataPointer, columnType), diff --git a/DuckDB.NET.Test/DuckDBManagedAppenderListTests.cs b/DuckDB.NET.Test/DuckDBManagedAppenderListTests.cs index 1adf5eb1..a9a78dff 100644 --- a/DuckDB.NET.Test/DuckDBManagedAppenderListTests.cs +++ b/DuckDB.NET.Test/DuckDBManagedAppenderListTests.cs @@ -208,6 +208,16 @@ public void ArrayValuesInt() ListValuesInternal("Integer", faker => faker.Random.Int(), 5); } + [Fact] + public void ListValuesEnum() + { + Command.CommandText = "CREATE TYPE test_enum AS ENUM('test1','test2','test3');"; + Command.ExecuteNonQuery(); + + ListValuesInternal("test_enum", faker => faker.Random.CollectionItem([null, "test1", "test2", "test3"])); + ListValuesInternal("test_enum", faker => faker.Random.CollectionItem([null, TestEnum.Test1, TestEnum.Test2, TestEnum.Test3])); + } + private void ListValuesInternal(string typeName, Func generator, int? length = null) { var rows = 2000; @@ -268,4 +278,11 @@ private void ListValuesInternal(string typeName, Func generator, in .Should().Throw().Where(exception => exception.Message.Contains(length.ToString())); } } + + private enum TestEnum + { + Test1 = 0, + Test2 = 1, + Test3 = 2, + } } \ No newline at end of file diff --git a/DuckDB.NET.Test/DuckDBManagedAppenderTests.cs b/DuckDB.NET.Test/DuckDBManagedAppenderTests.cs index c7a8828e..b6f937a3 100644 --- a/DuckDB.NET.Test/DuckDBManagedAppenderTests.cs +++ b/DuckDB.NET.Test/DuckDBManagedAppenderTests.cs @@ -11,6 +11,7 @@ using System.Numerics; using Bogus; using Xunit; +using System.Text; namespace DuckDB.NET.Test; @@ -275,6 +276,50 @@ public void TemporalValues() result.Select(tuple => tuple.Item8).Should().BeEquivalentTo(dates.Select(TimeOnly.FromDateTime)); } + [Fact] + public void EnumValues() + { + Command.CommandText = GetCreateEnumTypeSql("test_enum1", "test", 3); + Command.ExecuteNonQuery(); + + Command.CommandText = GetCreateEnumTypeSql("test_enum2", "test", 1000); + Command.ExecuteNonQuery(); + + Command.CommandText = GetCreateEnumTypeSql("test_enum3", "test", 100000); + Command.ExecuteNonQuery(); + + Command.CommandText = "CREATE TABLE managedAppenderEnum(a test_enum1, b test_enum1, c test_enum1, d test_enum1, e test_enum1, f test_enum2, g test_enum2, h test_enum3, i test_enum3);"; + Command.ExecuteNonQuery(); + + using (var appender = Connection.CreateAppender("managedAppenderEnum")) + { + appender + .CreateRow() + .AppendNullValue() + .AppendNullValue() + .AppendValue("test1") + .AppendValue(TestEnum1.Test2) + .AppendValue(TestEnum1.Test3) + .AppendValue("test327") + .AppendValue(TestEnum2.Test1000) + .AppendValue("test100000") + .AppendValue(TestEnum3.Test6699) + .EndRow(); + } + + var queryResult = Connection.Query<(string, TestEnum1?, TestEnum1, string, TestEnum1, TestEnum2, string, string, TestEnum3)>("SELECT a, b, c, d, e, f, g, h, i FROM managedAppenderEnum").ToList(); + var result = queryResult[0]; + result.Item1.Should().BeNull(); + result.Item2.Should().BeNull(); + result.Item3.Should().Be(TestEnum1.Test1); + result.Item4.Should().Be("test2"); + result.Item5.Should().Be(TestEnum1.Test3); + result.Item6.Should().Be(TestEnum2.Test327); + result.Item7.Should().Be("test1000"); + result.Item8.Should().Be("test100000"); + result.Item9.Should().Be(TestEnum3.Test6699); + } + [Fact] public void IncompleteRowThrowsException() { @@ -366,6 +411,35 @@ public void ClosedAdapterThrowException() }).Should().Throw(); } + [Fact] + public void EnumNotValidValueThrowException() + { + Command.CommandText = GetCreateEnumTypeSql("enum_not_valid_value_test_enum", "test", 100); + Command.ExecuteNonQuery(); + + var table = "CREATE TABLE managedAppenderEnumNotValidValueTest(a enum_not_valid_value_test_enum);"; + Command.CommandText = table; + Command.ExecuteNonQuery(); + + Connection.Invoking(dbConnection => + { + using var appender = dbConnection.CreateAppender("managedAppenderEnumNotValidValueTest"); + appender + .CreateRow() + .AppendValue("test12345") + .EndRow(); + }).Should().Throw(); + + Connection.Invoking(dbConnection => + { + using var appender = dbConnection.CreateAppender("managedAppenderEnumNotValidValueTest"); + appender + .CreateRow() + .AppendValue(EnumNotValidValueTestEnum.NotValid) + .EndRow(); + }).Should().Throw(); + } + [Fact] public void TableWithSchema() { @@ -504,9 +578,55 @@ public void ManagedAppenderOnTableAndColumnsWithSpecialCharacters(string schemaN } } + private static string GetCreateEnumTypeSql(string enumName, string enumValueNamePrefix, int count) + { + var stringBuilder = new StringBuilder(); + stringBuilder.AppendFormat(CultureInfo.InvariantCulture, "CREATE TYPE {0} AS ENUM(", enumName); + + for (int i = 1; i <= count; i++) + { + if (i > 1) + { + stringBuilder.Append(','); + } + + stringBuilder.Append('\''); + stringBuilder.Append(enumValueNamePrefix); + stringBuilder.Append(i); + stringBuilder.Append('\''); + } + + stringBuilder.Append(");"); + return stringBuilder.ToString(); + } + private static string GetQualifiedObjectName(params string[] parts) => string.Join('.', parts. Where(p => !string.IsNullOrWhiteSpace(p)). Select(p => '"' + p + '"') ); + + private enum TestEnum1 + { + Test1 = 0, + Test2 = 1, + Test3 = 2, + } + + private enum TestEnum2 : short + { + Test327 = 326, + Test1000 = 999, + } + + private enum TestEnum3 : ulong + { + Test6699 = 6698, + Test100000 = 99999, + } + + private enum EnumNotValidValueTestEnum + { + NotValid = 12345, + } } \ No newline at end of file