Skip to content

Commit

Permalink
[5.2.2] | Fix ArgumentNullException on SqlDataRecord.GetValue when us…
Browse files Browse the repository at this point in the history
…ing Udt data type (#2448) (#2816)
  • Loading branch information
dauinsight authored Aug 26, 2024
1 parent b5edb42 commit c7de60b
Show file tree
Hide file tree
Showing 3 changed files with 36 additions and 5 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -500,11 +500,7 @@ internal static SmiExtendedMetaData SqlMetaDataToSmiExtendedMetaData(SqlMetaData
source.Scale,
source.LocaleId,
source.CompareOptions,
#if NETFRAMEWORK
source.Type,
#else
null,
#endif
source.Name,
typeSpecificNamePart1,
typeSpecificNamePart2,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -94,6 +94,8 @@
<PackageReference Condition="$(ReferenceType.Contains('NetStandard'))" Include="System.Diagnostics.DiagnosticSource" Version="$(SystemDiagnosticsDiagnosticSourceVersion)" />
<PackageReference Include="Newtonsoft.Json" Version="$(NewtonsoftJsonVersion)" />
<Reference Condition="'$(TargetGroup)'=='netfx'" Include="System.Transactions" />
<PackageReference Condition="'$(TargetGroup)'=='netfx'" Include="Microsoft.SqlServer.Types" Version="$(MicrosoftSqlServerTypesVersion)" />
<PackageReference Condition="'$(TargetGroup)'=='netcoreapp'" Include="Microsoft.SqlServer.Types" Version="$(MicrosoftSqlServerTypesVersionNet)" />
</ItemGroup>
<ItemGroup Condition="'$(TargetGroup)' == 'netcoreapp'">
<PackageReference Include="System.Data.Odbc" Version="$(SystemDataOdbcVersion)" />
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
using System.Data;
using System.Data.SqlTypes;
using Microsoft.Data.SqlClient.Server;
using Microsoft.SqlServer.Types;
using Xunit;

namespace Microsoft.Data.SqlClient.Tests
Expand Down Expand Up @@ -318,6 +319,19 @@ public void GetChar_ThrowsNotSupported()
Assert.Throws<NotSupportedException>(() => record.GetChar(0));
}

[Theory]
[ClassData(typeof(GetUdtTypeTestData))]
public void GetUdt_ReturnsValue(Type udtType, object value, string serverTypeName)
{
SqlMetaData[] metadata = new SqlMetaData[] { new SqlMetaData(nameof(udtType.Name), SqlDbType.Udt, udtType, serverTypeName) };

SqlDataRecord record = new SqlDataRecord(metadata);

record.SetValue(0, value);

Assert.Equal(value.ToString(), record.GetValue(0).ToString());
}

[Theory]
[ClassData(typeof(GetXXXBadTypeTestData))]
public void GetXXX_ThrowsIfBadType(Func<SqlDataRecord, object> getXXX)
Expand All @@ -342,8 +356,8 @@ public void GetXXX_ReturnValue(SqlDbType dbType, object value, Func<SqlDataRecor
};
SqlDataRecord record = new SqlDataRecord(metaData);
record.SetValue(0, value);
Assert.Equal(value, record.GetValue(0));
Assert.Equal(value, getXXX(record));

}
}

Expand All @@ -369,6 +383,21 @@ IEnumerator IEnumerable.GetEnumerator()
}
}

public class GetUdtTypeTestData : IEnumerable<object[]>
{
public IEnumerator<object[]> GetEnumerator()
{
yield return new object[] { typeof(SqlGeography), SqlGeography.Point(43, -81, 4326), "Geography" };
yield return new object[] { typeof(SqlGeometry), SqlGeometry.Point(43, -81, 4326), "Geometry" };
yield return new object[] { typeof(SqlHierarchyId), SqlHierarchyId.Parse("/"), "HierarchyId" };
}

IEnumerator IEnumerable.GetEnumerator()
{
return GetEnumerator();
}
}

public class GetXXXCheckValueTestData : IEnumerable<object[]>
{
public IEnumerator<object[]> GetEnumerator()
Expand All @@ -383,6 +412,10 @@ public IEnumerator<object[]> GetEnumerator()
yield return new object[] { SqlDbType.DateTime, DateTime.Now, new Func<SqlDataRecord, object>(r => r.GetDateTime(0)) };
yield return new object[] { SqlDbType.DateTimeOffset, new DateTimeOffset(DateTime.Now), new Func<SqlDataRecord, object>(r => r.GetDateTimeOffset(0)) };
yield return new object[] { SqlDbType.Time, TimeSpan.FromHours(1), new Func<SqlDataRecord, object>(r => r.GetTimeSpan(0)) };
yield return new object[] { SqlDbType.Date, DateTime.Now.Date, new Func<SqlDataRecord, object>(r => r.GetDateTime(0)) };
yield return new object[] { SqlDbType.Bit, bool.Parse(bool.TrueString), new Func<SqlDataRecord, object>(r => r.GetBoolean(0)) };
yield return new object[] { SqlDbType.SmallDateTime, DateTime.Now, new Func<SqlDataRecord, object>(r => r.GetDateTime(0)) };
yield return new object[] { SqlDbType.TinyInt, (byte)1, new Func<SqlDataRecord, object>(r => r.GetByte(0)) };
}

IEnumerator IEnumerable.GetEnumerator()
Expand Down

0 comments on commit c7de60b

Please sign in to comment.