Skip to content

Commit

Permalink
Added Support for Translating ByteArray.IndexOf for SQLite
Browse files Browse the repository at this point in the history
  • Loading branch information
nikhil197 committed Aug 18, 2024
1 parent d81a857 commit ec0b63d
Show file tree
Hide file tree
Showing 11 changed files with 299 additions and 200 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -16,21 +16,6 @@ public class SqlServerByteArrayMethodTranslator : IMethodCallTranslator
{
private readonly ISqlExpressionFactory _sqlExpressionFactory;

// NOTE: Might want to move these to a shared file, similar to EnumerableMethods
private static readonly MethodInfo IndexOfMethodInfo
= typeof(Array)
.GetGenericMethod(nameof(Array.IndexOf), 1, BindingFlags.Public | BindingFlags.Static, (_, t) =>
{
return [t[0].MakeArrayType(), t[0]];
})!;

private static readonly MethodInfo IndexOfWithStartingPositionMethodInfo
= typeof(Array)
.GetGenericMethod(nameof(Array.IndexOf), 1, BindingFlags.Public | BindingFlags.Static, (_, t) =>
{
return [t[0].MakeArrayType(), t[0], typeof(int)];
})!;

/// <summary>
/// This is an internal API that supports the Entity Framework Core infrastructure and not subject to
/// the same compatibility standards as public APIs. It may be changed or removed without notice in
Expand Down Expand Up @@ -61,7 +46,6 @@ public SqlServerByteArrayMethodTranslator(ISqlExpressionFactory sqlExpressionFac
var methodDefinition = method.GetGenericMethodDefinition();
if (methodDefinition.Equals(EnumerableMethods.Contains))
{
// NOTE: Should this be refactored to use the TranslateIndexOf method?? Everything is same expect one check
var source = arguments[0];
var sourceTypeMapping = source.TypeMapping;

Expand Down Expand Up @@ -91,12 +75,12 @@ public SqlServerByteArrayMethodTranslator(ISqlExpressionFactory sqlExpressionFac
method.ReturnType);
}

if (methodDefinition.Equals(IndexOfMethodInfo))
if (methodDefinition.Equals(ArrayMethods.IndexOf))
{
return TranslateIndexOf(method, arguments[0], arguments[1], null);
}

if (methodDefinition.Equals(IndexOfWithStartingPositionMethodInfo))
if (methodDefinition.Equals(ArrayMethods.IndexOfWithStartingPosition))
{
return TranslateIndexOf(method, arguments[0], arguments[1], arguments[2]);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -40,28 +40,30 @@ public SqliteByteArrayMethodTranslator(ISqlExpressionFactory sqlExpressionFactor
IDiagnosticsLogger<DbLoggerCategory.Query> logger)
{
if (method.IsGenericMethod
&& method.GetGenericMethodDefinition().Equals(EnumerableMethods.Contains)
&& arguments.Count >= 1
&& arguments[0].Type == typeof(byte[]))
{
var source = arguments[0];
var genericMethodDefinition = method.GetGenericMethodDefinition();
if (genericMethodDefinition.Equals(EnumerableMethods.Contains))
{
return _sqlExpressionFactory.GreaterThan(
GetInStrSqlFunctionExpression(arguments[0], arguments[1]),
_sqlExpressionFactory.Constant(0));

var value = arguments[1] is SqlConstantExpression constantValue
? (SqlExpression)_sqlExpressionFactory.Constant(new[] { (byte)constantValue.Value! }, source.TypeMapping)
: _sqlExpressionFactory.Function(
"char",
new[] { arguments[1] },
nullable: false,
argumentsPropagateNullability: new[] { false },
typeof(string));
}

return _sqlExpressionFactory.GreaterThan(
_sqlExpressionFactory.Function(
"instr",
new[] { source, value },
nullable: true,
argumentsPropagateNullability: new[] { true, true },
typeof(int)),
_sqlExpressionFactory.Constant(0));
if (genericMethodDefinition.Equals(ArrayMethods.IndexOf))
{
return _sqlExpressionFactory.Subtract(
GetInStrSqlFunctionExpression(arguments[0], arguments[1]),
_sqlExpressionFactory.Constant(1));
}

if (genericMethodDefinition.Equals(ArrayMethods.IndexOfWithStartingPosition))
{
// NOTE: IndexOf Method with a starting position is not supported by SQLite
return null;
}
}

// See issue#16428
Expand Down Expand Up @@ -92,4 +94,23 @@ public SqliteByteArrayMethodTranslator(ISqlExpressionFactory sqlExpressionFactor

return null;
}

private SqlExpression GetInStrSqlFunctionExpression(SqlExpression source, SqlExpression valueToSearch)
{
var value = valueToSearch is SqlConstantExpression { Value: byte constantValue }
? _sqlExpressionFactory.Constant(new byte[] { constantValue }, source.TypeMapping)
: _sqlExpressionFactory.Function(
"char",
[valueToSearch],
nullable: false,
argumentsPropagateNullability: [false],
typeof(string));

return _sqlExpressionFactory.Function(
"instr",
[source, value],
nullable: true,
argumentsPropagateNullability: [true, true],
typeof(int));
}
}
36 changes: 36 additions & 0 deletions src/Shared/ArrayMethods.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,36 @@
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.

namespace Microsoft.EntityFrameworkCore;

internal static class ArrayMethods
{
public static MethodInfo IndexOf { get; }

public static MethodInfo IndexOfWithStartingPosition { get; }

static ArrayMethods()
{
var arrayGenericMethods = typeof(Array)
.GetMethods(BindingFlags.Public | BindingFlags.Static | BindingFlags.DeclaredOnly)
.Where(m => m.IsGenericMethod)
.GroupBy(m => m.Name)
.ToDictionary(m => m.Key, l => l.ToList());

IndexOf = GetMethod(nameof(Array.IndexOf), 1, (t) =>
{
return [t[0].MakeArrayType(), t[0]];
});

IndexOfWithStartingPosition = GetMethod(nameof(Array.IndexOf), 1, (t) =>
{
return [t[0].MakeArrayType(), t[0], typeof(int)];
});

MethodInfo GetMethod(string name, int genericParameterCount, Func<Type[], Type[]> parameterGenerator)
=> arrayGenericMethods[name].Single(
mi => mi.IsGenericMethod && mi.GetGenericArguments().Length == genericParameterCount
&& mi.GetParameters().Select(e => e.ParameterType).SequenceEqual(
parameterGenerator(mi.IsGenericMethod ? mi.GetGenericArguments() : [])));
}
}
36 changes: 14 additions & 22 deletions test/EFCore.Specification.Tests/Query/GearsOfWarQueryTestBase.cs
Original file line number Diff line number Diff line change
Expand Up @@ -6263,18 +6263,16 @@ public virtual Task Byte_array_filter_by_length_parameter(bool async)
[ConditionalTheory]
[MemberData(nameof(IsAsyncData))]
public virtual Task Byte_array_of_type_varbinary_max_filter_by_index_of_literal_casts_to_int(bool async)
{
return AssertQuery(
public virtual Task Byte_array_with_max_possible_length_filter_by_index_of_literal(bool async)
=> AssertQuery(
async,
ss => ss.Set<Squad>().Where(w => Array.IndexOf(w.Banner, (byte)1) == 1),
ss => ss.Set<Squad>().Where(w => w.Banner != null && Array.IndexOf(w.Banner, (byte)1) == 1)
);
}
[ConditionalTheory]
[MemberData(nameof(IsAsyncData))]
public virtual Task Byte_array_of_type_varbinary_max_filter_by_index_of_parameter_casts_to_int(bool async)
public virtual Task Byte_array_with_max_possible_length_filter_by_index_of_parameter(bool async)
{
byte b = 0;
return AssertQuery(
Expand All @@ -6286,18 +6284,16 @@ public virtual Task Byte_array_of_type_varbinary_max_filter_by_index_of_paramete
[ConditionalTheory]
[MemberData(nameof(IsAsyncData))]
public virtual Task Byte_array_of_type_varbinary_n_filter_by_index_of_literal_does_not_cast(bool async)
{
return AssertQuery(
public virtual Task Byte_array_with_length_n_filter_by_index_of_literal(bool async)
=> AssertQuery(
async,
ss => ss.Set<Squad>().Where(w => Array.IndexOf(w.Banner5, (byte)5) == 1),
ss => ss.Set<Squad>().Where(w => w.Banner != null && Array.IndexOf(w.Banner5, (byte)5) == 1)
);
}
[ConditionalTheory]
[MemberData(nameof(IsAsyncData))]
public virtual Task Byte_array_of_type_varbinary_n_filter_by_index_of_parameter_does_not_cast(bool async)
public virtual Task Byte_array_with_lenght_n_filter_by_index_of_parameter(bool async)
{
byte b = 4;
return AssertQuery(
Expand All @@ -6309,21 +6305,19 @@ public virtual Task Byte_array_of_type_varbinary_n_filter_by_index_of_parameter_
[ConditionalTheory]
[MemberData(nameof(IsAsyncData))]
public virtual Task Byte_array_of_type_varbinary_max_filter_by_index_of_with_starting_position_literal_casts_to_int(bool async)
{
return AssertQuery(
public virtual Task Byte_array_with_max_possible_length_filter_by_index_of_with_starting_position(bool async)
=> AssertQuery(
async,
ss => ss.Set<Squad>().Where(w => Array.IndexOf(w.Banner, (byte)1, 1) == 1),
ss => ss.Set<Squad>().Where(w => w.Banner != null && Array.IndexOf(w.Banner, (byte)1, 1) == 1)
);
}
[ConditionalTheory]
[MemberData(nameof(IsAsyncData))]
public virtual Task Byte_array_of_type_varbinary_max_filter_by_index_of_with_starting_position_parameter_casts_to_int(bool async)
public virtual Task Byte_array_with_max_possible_length_filter_by_index_of_with_starting_position_parameter(bool async)
{
byte b = 0;
int startPos = 0;
var startPos = 0;
return AssertQuery(
async,
ss => ss.Set<Squad>().Where(w => Array.IndexOf(w.Banner, b, startPos) == 0),
Expand All @@ -6333,21 +6327,19 @@ public virtual Task Byte_array_of_type_varbinary_max_filter_by_index_of_with_sta
[ConditionalTheory]
[MemberData(nameof(IsAsyncData))]
public virtual Task Byte_array_of_type_varbinary_n_filter_by_index_of_with_starting_position_literal_does_not_cast(bool async)
{
return AssertQuery(
public virtual Task Byte_array_with_length_n_filter_by_index_of_with_starting_position_literal(bool async)
=> AssertQuery(
async,
ss => ss.Set<Squad>().Where(w => Array.IndexOf(w.Banner5, (byte)5, 1) == 1),
ss => ss.Set<Squad>().Where(w => w.Banner != null && Array.IndexOf(w.Banner5, (byte)5, 1) == 1)
);
}
[ConditionalTheory]
[MemberData(nameof(IsAsyncData))]
public virtual Task Byte_array_of_type_varbinary_n_filter_by_index_of_with_starting_position_parameter_does_not_cast(bool async)
public virtual Task Byte_array_with_length_n_filter_by_index_of_with_starting_position_parameter(bool async)
{
byte b = 4;
int startPos = 0;
var startPos = 0;
return AssertQuery(
async,
ss => ss.Set<Squad>().Where(w => Array.IndexOf(w.Banner5, b, startPos) == 0),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7678,9 +7678,9 @@ WHERE DATALENGTH([s].[Banner5]) = 5

#region Byte Array IndexOf Translation

public override async Task Byte_array_of_type_varbinary_max_filter_by_index_of_literal_casts_to_int(bool async)
public override async Task Byte_array_with_max_possible_length_filter_by_index_of_literal(bool async)
{
await base.Byte_array_of_type_varbinary_max_filter_by_index_of_literal_casts_to_int(async);
await base.Byte_array_with_max_possible_length_filter_by_index_of_literal(async);

AssertSql(
"""
Expand All @@ -7690,9 +7690,9 @@ WHERE CAST(CHARINDEX(0x01, [s].[Banner]) AS int) - 1 = 1
""");
}

public override async Task Byte_array_of_type_varbinary_max_filter_by_index_of_parameter_casts_to_int(bool async)
public override async Task Byte_array_with_max_possible_length_filter_by_index_of_parameter(bool async)
{
await base.Byte_array_of_type_varbinary_max_filter_by_index_of_parameter_casts_to_int(async);
await base.Byte_array_with_max_possible_length_filter_by_index_of_parameter(async);

AssertSql(
"""
Expand All @@ -7704,9 +7704,9 @@ WHERE CAST(CHARINDEX(CAST(@__b_0 AS varbinary(max)), [s].[Banner]) AS int) - 1 =
""");
}

public override async Task Byte_array_of_type_varbinary_n_filter_by_index_of_literal_does_not_cast(bool async)
public override async Task Byte_array_with_length_n_filter_by_index_of_literal(bool async)
{
await base.Byte_array_of_type_varbinary_n_filter_by_index_of_literal_does_not_cast(async);
await base.Byte_array_with_length_n_filter_by_index_of_literal(async);

AssertSql(
"""
Expand All @@ -7716,9 +7716,9 @@ WHERE CHARINDEX(0x05, [s].[Banner5]) - 1 = 1
""");
}

public override async Task Byte_array_of_type_varbinary_n_filter_by_index_of_parameter_does_not_cast(bool async)
public override async Task Byte_array_with_lenght_n_filter_by_index_of_parameter(bool async)
{
await base.Byte_array_of_type_varbinary_n_filter_by_index_of_parameter_does_not_cast(async);
await base.Byte_array_with_lenght_n_filter_by_index_of_parameter(async);

AssertSql(
"""
Expand All @@ -7730,9 +7730,9 @@ WHERE CHARINDEX(CAST(@__b_0 AS varbinary(5)), [s].[Banner5]) - 1 = 0
""");
}

public override async Task Byte_array_of_type_varbinary_max_filter_by_index_of_with_starting_position_literal_casts_to_int(bool async)
public override async Task Byte_array_with_max_possible_length_filter_by_index_of_with_starting_position(bool async)
{
await base.Byte_array_of_type_varbinary_max_filter_by_index_of_with_starting_position_literal_casts_to_int(async);
await base.Byte_array_with_max_possible_length_filter_by_index_of_with_starting_position(async);

AssertSql(
"""
Expand All @@ -7742,9 +7742,9 @@ WHERE CAST(CHARINDEX(0x01, [s].[Banner], 2) AS int) - 1 = 1
""");
}

public override async Task Byte_array_of_type_varbinary_max_filter_by_index_of_with_starting_position_parameter_casts_to_int(bool async)
public override async Task Byte_array_with_max_possible_length_filter_by_index_of_with_starting_position_parameter(bool async)
{
await base.Byte_array_of_type_varbinary_max_filter_by_index_of_with_starting_position_parameter_casts_to_int(async);
await base.Byte_array_with_max_possible_length_filter_by_index_of_with_starting_position_parameter(async);

AssertSql(
"""
Expand All @@ -7757,9 +7757,9 @@ WHERE CAST(CHARINDEX(CAST(@__b_0 AS varbinary(max)), [s].[Banner], @__startPos_1
""");
}

public override async Task Byte_array_of_type_varbinary_n_filter_by_index_of_with_starting_position_literal_does_not_cast(bool async)
public override async Task Byte_array_with_length_n_filter_by_index_of_with_starting_position_literal(bool async)
{
await base.Byte_array_of_type_varbinary_n_filter_by_index_of_with_starting_position_literal_does_not_cast(async);
await base.Byte_array_with_length_n_filter_by_index_of_with_starting_position_literal(async);

AssertSql(
"""
Expand All @@ -7769,9 +7769,9 @@ WHERE CHARINDEX(0x05, [s].[Banner5], 2) - 1 = 1
""");
}

public override async Task Byte_array_of_type_varbinary_n_filter_by_index_of_with_starting_position_parameter_does_not_cast(bool async)
public override async Task Byte_array_with_length_n_filter_by_index_of_with_starting_position_parameter(bool async)
{
await base.Byte_array_of_type_varbinary_n_filter_by_index_of_with_starting_position_parameter_does_not_cast(async);
await base.Byte_array_with_length_n_filter_by_index_of_with_starting_position_parameter(async);

AssertSql(
"""
Expand Down
Loading

0 comments on commit ec0b63d

Please sign in to comment.