Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Added Support for translating Array.IndexOf methods for byte arrays for SqlServer & SQLite #34457

Open
wants to merge 5 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,12 @@ namespace Microsoft.EntityFrameworkCore.SqlServer.Query.Internal;
/// </summary>
public class SqlServerByteArrayMethodTranslator : IMethodCallTranslator
{
private static readonly MethodInfo ArrayIndexOf
= typeof(Array).GetMethod(nameof(Array.IndexOf), 1, BindingFlags.Public | BindingFlags.Static | BindingFlags.DeclaredOnly, null, CallingConventions.Any, [Type.MakeGenericMethodParameter(0).MakeArrayType(), Type.MakeGenericMethodParameter(0)], null)!;

private static readonly MethodInfo ArrayIndexOfWithStartIndex
= typeof(Array).GetMethod(nameof(Array.IndexOf), 1, BindingFlags.Public | BindingFlags.Static | BindingFlags.DeclaredOnly, null, CallingConventions.Any, [Type.MakeGenericMethodParameter(0).MakeArrayType(), Type.MakeGenericMethodParameter(0), typeof(int)], null)!;

private readonly ISqlExpressionFactory _sqlExpressionFactory;

/// <summary>
Expand All @@ -38,40 +44,102 @@ public SqlServerByteArrayMethodTranslator(ISqlExpressionFactory sqlExpressionFac
IDiagnosticsLogger<DbLoggerCategory.Query> logger)
{
if (method.IsGenericMethod
&& method.GetGenericMethodDefinition().Equals(EnumerableMethods.Contains)
&& arguments.Count >= 1
&& arguments[0].Type == typeof(byte[]))
{
var source = arguments[0];
var sourceTypeMapping = source.TypeMapping;
var methodDefinition = method.GetGenericMethodDefinition();
if (methodDefinition.Equals(EnumerableMethods.Contains))
{
var source = arguments[0];
var sourceTypeMapping = source.TypeMapping;

var value = arguments[1] is SqlConstantExpression constantValue
? _sqlExpressionFactory.Constant(new[] { (byte)constantValue.Value! }, sourceTypeMapping)
: _sqlExpressionFactory.Convert(arguments[1], typeof(byte[]), sourceTypeMapping);
var value = arguments[1] is SqlConstantExpression constantValue
? _sqlExpressionFactory.Constant(new[] { (byte)constantValue.Value! }, sourceTypeMapping)
: _sqlExpressionFactory.Convert(arguments[1], typeof(byte[]), sourceTypeMapping);

return _sqlExpressionFactory.GreaterThan(
_sqlExpressionFactory.Function(
"CHARINDEX",
[value, source],
nullable: true,
argumentsPropagateNullability: [true, true],
typeof(int)),
_sqlExpressionFactory.Constant(0));
}
return _sqlExpressionFactory.GreaterThan(
_sqlExpressionFactory.Function(
"CHARINDEX",
[value, source],
nullable: true,
argumentsPropagateNullability: [true, true],
typeof(int)),
_sqlExpressionFactory.Constant(0));
}

if (method.IsGenericMethod
&& method.GetGenericMethodDefinition().Equals(EnumerableMethods.FirstWithoutPredicate)
&& arguments[0].Type == typeof(byte[]))
{
return _sqlExpressionFactory.Convert(
if (methodDefinition.Equals(EnumerableMethods.FirstWithoutPredicate))
{
return _sqlExpressionFactory.Convert(
_sqlExpressionFactory.Function(
"SUBSTRING",
[arguments[0], _sqlExpressionFactory.Constant(1), _sqlExpressionFactory.Constant(1)],
nullable: true,
argumentsPropagateNullability: [true, true, true],
typeof(byte[])),
method.ReturnType);
}

if (methodDefinition.Equals(ArrayIndexOf))
{
return TranslateByteArrayIndexOf(method, arguments[0], arguments[1], null);
}

if (methodDefinition.Equals(ArrayIndexOfWithStartIndex))
{
return TranslateByteArrayIndexOf(method, arguments[0], arguments[1], arguments[2]);
}
}

return null;
}

private SqlExpression TranslateByteArrayIndexOf(
MethodInfo method,
SqlExpression source,
SqlExpression valueToSearch,
SqlExpression? startIndex)
{
var sourceTypeMapping = source.TypeMapping;
var sqlArguments = new List<SqlExpression>
{
valueToSearch is SqlConstantExpression { Value: byte constantValue }
? _sqlExpressionFactory.Constant(new byte[] { constantValue }, sourceTypeMapping)
: _sqlExpressionFactory.Convert(valueToSearch, typeof(byte[]), sourceTypeMapping),
source
};

if (startIndex is not null)
{
sqlArguments.Add(
startIndex is SqlConstantExpression { Value : int index }
? _sqlExpressionFactory.Constant(index + 1, typeof(int))
: _sqlExpressionFactory.Add(startIndex, _sqlExpressionFactory.Constant(1)));
}

var argumentsPropagateNullability = Enumerable.Repeat(true, sqlArguments.Count);

SqlExpression charIndexExpr;
var storeType = sourceTypeMapping?.StoreType;
if (storeType == "varbinary(max)")
nikhil197 marked this conversation as resolved.
Show resolved Hide resolved
{
charIndexExpr = GetCharIndexSqlFunctionExpression(sqlArguments, argumentsPropagateNullability, typeof(long));
charIndexExpr = _sqlExpressionFactory.Convert(charIndexExpr, typeof(int));
}
else
{
charIndexExpr = GetCharIndexSqlFunctionExpression(sqlArguments, argumentsPropagateNullability, method.ReturnType);
}

return _sqlExpressionFactory.Subtract(charIndexExpr, _sqlExpressionFactory.Constant(1));

SqlExpression GetCharIndexSqlFunctionExpression(List<SqlExpression> sqlArguments, IEnumerable<bool> argumentsPropagateNullability, Type returnType)
{
return _sqlExpressionFactory.Function(
"CHARINDEX",
sqlArguments,
nullable: true,
argumentsPropagateNullability: argumentsPropagateNullability,
returnType);
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,9 @@ namespace Microsoft.EntityFrameworkCore.Sqlite.Query.Internal;
/// </summary>
public class SqliteByteArrayMethodTranslator : IMethodCallTranslator
{
private static readonly MethodInfo ArrayIndexOf
= typeof(Array).GetMethod(nameof(Array.IndexOf), 1, BindingFlags.Public | BindingFlags.Static | BindingFlags.DeclaredOnly, null, CallingConventions.Any, [Type.MakeGenericMethodParameter(0).MakeArrayType(), Type.MakeGenericMethodParameter(0)], null)!;

private readonly ISqlExpressionFactory _sqlExpressionFactory;

/// <summary>
Expand All @@ -38,28 +41,26 @@ public SqliteByteArrayMethodTranslator(ISqlExpressionFactory sqlExpressionFactor
IDiagnosticsLogger<DbLoggerCategory.Query> logger)
{
if (method.IsGenericMethod
&& method.GetGenericMethodDefinition().Equals(EnumerableMethods.Contains)
&& arguments.Count >= 1
&& arguments[0].Type == typeof(byte[]))
{
var source = arguments[0];
var genericMethodDefinition = method.GetGenericMethodDefinition();
if (genericMethodDefinition.Equals(EnumerableMethods.Contains))
{
return _sqlExpressionFactory.GreaterThan(
GetInStrSqlFunctionExpression(arguments[0], arguments[1]),
_sqlExpressionFactory.Constant(0));

var value = arguments[1] is SqlConstantExpression constantValue
? (SqlExpression)_sqlExpressionFactory.Constant(new[] { (byte)constantValue.Value! }, source.TypeMapping)
: _sqlExpressionFactory.Function(
"char",
new[] { arguments[1] },
nullable: false,
argumentsPropagateNullability: new[] { false },
typeof(string));
}

if (genericMethodDefinition.Equals(ArrayIndexOf))
{
return _sqlExpressionFactory.Subtract(
nikhil197 marked this conversation as resolved.
Show resolved Hide resolved
GetInStrSqlFunctionExpression(arguments[0], arguments[1]),
_sqlExpressionFactory.Constant(1));
}

return _sqlExpressionFactory.GreaterThan(
_sqlExpressionFactory.Function(
"instr",
new[] { source, value },
nullable: true,
argumentsPropagateNullability: new[] { true, true },
typeof(int)),
_sqlExpressionFactory.Constant(0));
// NOTE: IndexOf Method with a starting position is not supported by SQLite
}

// See issue#16428
Expand Down Expand Up @@ -89,5 +90,24 @@ public SqliteByteArrayMethodTranslator(ISqlExpressionFactory sqlExpressionFactor
//}

return null;

SqlExpression GetInStrSqlFunctionExpression(SqlExpression source, SqlExpression valueToSearch)
{
var value = valueToSearch is SqlConstantExpression { Value: byte constantValue }
? _sqlExpressionFactory.Constant(new byte[] { constantValue }, source.TypeMapping)
: _sqlExpressionFactory.Function(
"char",
[valueToSearch],
nullable: false,
argumentsPropagateNullability: [false],
typeof(string));

return _sqlExpressionFactory.Function(
"instr",
[source, value],
nullable: true,
argumentsPropagateNullability: [true, true],
typeof(int));
}
}
}
82 changes: 82 additions & 0 deletions test/EFCore.Specification.Tests/Query/GearsOfWarQueryTestBase.cs
Original file line number Diff line number Diff line change
Expand Up @@ -6258,6 +6258,88 @@ public virtual Task Byte_array_filter_by_length_parameter(bool async)
ss => ss.Set<Squad>().Where(w => w.Banner != null && w.Banner.Length == someByteArr.Length));
}

#region Byte Array IndexOf

[ConditionalTheory]
[MemberData(nameof(IsAsyncData))]
public virtual Task Byte_array_IndexOf_with_literal(bool async)
=> AssertQuery(
async,
ss => ss.Set<Squad>().Where(w => Array.IndexOf(w.Banner, (byte)1) == 1),
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The cast here shouldn't be needed, no?

Suggested change
ss => ss.Set<Squad>().Where(w => Array.IndexOf(w.Banner, (byte)1) == 1),
ss => ss.Set<Squad>().Where(w => Array.IndexOf(w.Banner, 1) == 1),

Copy link
Author

@nikhil197 nikhil197 Aug 19, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

No actually, it is needed. Without this it's picking non-generic version of the method IndexOf(Array arr, object value) (because 1 is an Int32 by default and I haven't specified the type argument on the IndexOf).

Do we want to support that too?

ss => ss.Set<Squad>().Where(w => w.Banner != null && Array.IndexOf(w.Banner, (byte)1) == 1));

[ConditionalTheory]
[MemberData(nameof(IsAsyncData))]
public virtual Task Byte_array_IndexOf_with_parameter(bool async)
{
byte b = 0;
return AssertQuery(
async,
ss => ss.Set<Squad>().Where(w => Array.IndexOf(w.Banner, b) == 0),
ss => ss.Set<Squad>().Where(w => w.Banner != null && Array.IndexOf(w.Banner, b) == 0));
}

[ConditionalTheory]
[MemberData(nameof(IsAsyncData))]
public virtual Task Byte_array_with_length_IndexOf_with_literal(bool async)
=> AssertQuery(
async,
ss => ss.Set<Squad>().Where(w => Array.IndexOf(w.Banner5, (byte)5) == 1),
ss => ss.Set<Squad>().Where(w => w.Banner != null && Array.IndexOf(w.Banner5, (byte)5) == 1));

[ConditionalTheory]
[MemberData(nameof(IsAsyncData))]
public virtual Task Byte_array_with_length_IndexOf_with_parameter(bool async)
{
byte b = 4;
return AssertQuery(
async,
ss => ss.Set<Squad>().Where(w => Array.IndexOf(w.Banner5, b) == 0),
ss => ss.Set<Squad>().Where(w => w.Banner != null && Array.IndexOf(w.Banner5, b) == 0));
}

[ConditionalTheory]
[MemberData(nameof(IsAsyncData))]
public virtual Task Byte_array_IndexOf_with_startIndex_with_literals(bool async)
=> AssertQuery(
async,
ss => ss.Set<Squad>().Where(w => Array.IndexOf(w.Banner, (byte)1, 1) == 1),
ss => ss.Set<Squad>().Where(w => w.Banner != null && Array.IndexOf(w.Banner, (byte)1, 1) == 1));

[ConditionalTheory]
[MemberData(nameof(IsAsyncData))]
public virtual Task Byte_array_IndexOf_with_startIndex_with_parameters(bool async)
{
byte b = 0;
var startPos = 0;
return AssertQuery(
async,
ss => ss.Set<Squad>().Where(w => Array.IndexOf(w.Banner, b, startPos) == 0),
ss => ss.Set<Squad>().Where(w => w.Banner != null && Array.IndexOf(w.Banner, b, startPos) == 0));
}

[ConditionalTheory]
[MemberData(nameof(IsAsyncData))]
public virtual Task Byte_array_with_length_IndexOf_with_startIndex_with_literals(bool async)
=> AssertQuery(
async,
ss => ss.Set<Squad>().Where(w => Array.IndexOf(w.Banner5, (byte)5, 1) == 1),
ss => ss.Set<Squad>().Where(w => w.Banner != null && Array.IndexOf(w.Banner5, (byte)5, 1) == 1));

[ConditionalTheory]
[MemberData(nameof(IsAsyncData))]
public virtual Task Byte_array_with_length_IndexOf_with_startIndex_with_parameters(bool async)
{
byte b = 4;
var startPos = 0;
return AssertQuery(
async,
ss => ss.Set<Squad>().Where(w => Array.IndexOf(w.Banner5, b, startPos) == 0),
ss => ss.Set<Squad>().Where(w => w.Banner != null && Array.IndexOf(w.Banner5, b, startPos) == 0));
}

#endregion

[ConditionalTheory]
[MemberData(nameof(IsAsyncData))]
public virtual Task OrderBy_bool_coming_from_optional_navigation(bool async)
Expand Down
Loading