From 109e0cea5f7b68db793aec4b7a8c355fa75ab2fb Mon Sep 17 00:00:00 2001 From: Ivo Stoilov Date: Tue, 17 Sep 2024 12:50:40 +0300 Subject: [PATCH] Fix IAsyncQueryProvider implementation (#2) * Fix IAsyncQueryProvider implementation --- .../DbSetTests.cs | 17 ++++ .../MockDbSet.cs | 2 +- .../TestDbAsyncEnumerable.cs | 2 +- .../TestDbAsyncQueryProvider.cs | 82 ++++++++-------- .../TestExpressionVisitor.cs | 6 ++ .../TestQueryProvider.cs | 94 +++++++++++++++++++ 6 files changed, 159 insertions(+), 44 deletions(-) create mode 100644 Telerik.JustMock.EntityFrameworkCore/TestExpressionVisitor.cs create mode 100644 Telerik.JustMock.EntityFrameworkCore/TestQueryProvider.cs diff --git a/Telerik.JustMock.EntityFrameworkCore.Tests/DbSetTests.cs b/Telerik.JustMock.EntityFrameworkCore.Tests/DbSetTests.cs index 2ab2cb6..8d25e93 100644 --- a/Telerik.JustMock.EntityFrameworkCore.Tests/DbSetTests.cs +++ b/Telerik.JustMock.EntityFrameworkCore.Tests/DbSetTests.cs @@ -304,5 +304,22 @@ public async Task Queries_DataIsThere() Assert.AreEqual(2, xes.Count); } + + [TestMethod] + public async Task Queries_AsyncProviderIsSupported() + { + var ctx = Mock.Create().PrepareMock(); + var list = new List + { + new Person { Id = 1, Name = "x" }, + new Person { Id = 2, Name = "x" }, + new Person { Id = 3, Name = "y" } + }; + ctx.People.Bind(list); + + var res = await ctx.People.SingleAsync(x => x.Id == 1); + + Assert.IsNotNull(res); + } } } diff --git a/Telerik.JustMock.EntityFrameworkCore/MockDbSet.cs b/Telerik.JustMock.EntityFrameworkCore/MockDbSet.cs index 5209182..f8a5a40 100644 --- a/Telerik.JustMock.EntityFrameworkCore/MockDbSet.cs +++ b/Telerik.JustMock.EntityFrameworkCore/MockDbSet.cs @@ -187,7 +187,7 @@ public virtual Expression Expression public virtual IQueryProvider Provider { - get { return new TestDbAsyncQueryProvider(this.asQueryable.Provider); } + get { return new TestDbAsyncQueryProvider(this.asQueryable.Expression); } } public override IEntityType EntityType diff --git a/Telerik.JustMock.EntityFrameworkCore/TestDbAsyncEnumerable.cs b/Telerik.JustMock.EntityFrameworkCore/TestDbAsyncEnumerable.cs index faa270f..0806687 100644 --- a/Telerik.JustMock.EntityFrameworkCore/TestDbAsyncEnumerable.cs +++ b/Telerik.JustMock.EntityFrameworkCore/TestDbAsyncEnumerable.cs @@ -22,7 +22,7 @@ public IAsyncEnumerator GetAsyncEnumerator(CancellationToken cancellationToke IQueryProvider IQueryable.Provider { - get { return new TestDbAsyncQueryProvider(this); } + get { return new TestDbAsyncQueryProvider(this.AsEnumerable()); } } } diff --git a/Telerik.JustMock.EntityFrameworkCore/TestDbAsyncQueryProvider.cs b/Telerik.JustMock.EntityFrameworkCore/TestDbAsyncQueryProvider.cs index bfeabfd..cfc77ce 100644 --- a/Telerik.JustMock.EntityFrameworkCore/TestDbAsyncQueryProvider.cs +++ b/Telerik.JustMock.EntityFrameworkCore/TestDbAsyncQueryProvider.cs @@ -1,43 +1,41 @@ -using Microsoft.EntityFrameworkCore.Query; -using System.Linq; -using System.Linq.Expressions; -using System.Threading; -using System.Threading.Tasks; - -namespace Telerik.JustMock.EntityFrameworkCore -{ - internal class TestDbAsyncQueryProvider : IAsyncQueryProvider - { - private readonly IQueryProvider _inner; - - public TestDbAsyncQueryProvider(IQueryProvider inner) - { - _inner = inner; - } - - public IQueryable CreateQuery(Expression expression) - { - return new TestDbAsyncEnumerable(expression); - } - - public IQueryable CreateQuery(Expression expression) - { - return new TestDbAsyncEnumerable(expression); - } - - public object Execute(Expression expression) - { - return _inner.Execute(expression); - } - - public TResult Execute(Expression expression) - { - return _inner.Execute(expression); - } - - public TResult ExecuteAsync(Expression expression, CancellationToken cancellationToken = default) - { - return Task.FromResult(Execute(expression)).GetAwaiter().GetResult(); - } - } +using Microsoft.EntityFrameworkCore.Query; +using System.Collections.Generic; +using System.Linq; +using System.Linq.Expressions; +using System.Threading; +using System.Threading.Tasks; + +namespace Telerik.JustMock.EntityFrameworkCore +{ + internal class TestDbAsyncQueryProvider : TestQueryProvider, IAsyncEnumerable, IAsyncQueryProvider + { + public TestDbAsyncQueryProvider(Expression expression) + : base(expression) + { + } + + public TestDbAsyncQueryProvider(IEnumerable enumerable) + : base(enumerable) + { + } + + public TResult ExecuteAsync(Expression expression, CancellationToken cancellationToken) + { + var expectedResultType = typeof(TResult).GetGenericArguments()[0]; + var executionResult = typeof(IQueryProvider) + .GetMethods() + .Single(method => method.Name == nameof(IQueryProvider.Execute) && method.IsGenericMethod) + .MakeGenericMethod(expectedResultType) + .Invoke(this, new object[] { expression }); + + return (TResult)typeof(Task).GetMethod(nameof(Task.FromResult)) + .MakeGenericMethod(expectedResultType) + .Invoke(null, new[] { executionResult }); + } + + public IAsyncEnumerator GetAsyncEnumerator(CancellationToken cancellationToken = default) + { + return new TestDbAsyncEnumerator(this.AsEnumerable().GetEnumerator()); + } + } } \ No newline at end of file diff --git a/Telerik.JustMock.EntityFrameworkCore/TestExpressionVisitor.cs b/Telerik.JustMock.EntityFrameworkCore/TestExpressionVisitor.cs new file mode 100644 index 0000000..cb50262 --- /dev/null +++ b/Telerik.JustMock.EntityFrameworkCore/TestExpressionVisitor.cs @@ -0,0 +1,6 @@ +using System.Linq.Expressions; + +namespace Telerik.JustMock.EntityFrameworkCore +{ + internal class TestExpressionVisitor : ExpressionVisitor { } +} \ No newline at end of file diff --git a/Telerik.JustMock.EntityFrameworkCore/TestQueryProvider.cs b/Telerik.JustMock.EntityFrameworkCore/TestQueryProvider.cs new file mode 100644 index 0000000..acd68f7 --- /dev/null +++ b/Telerik.JustMock.EntityFrameworkCore/TestQueryProvider.cs @@ -0,0 +1,94 @@ +using System; +using System.Collections; +using System.Collections.Generic; +using System.Linq; +using System.Linq.Expressions; + +namespace Telerik.JustMock.EntityFrameworkCore +{ + internal abstract class TestQueryProvider : IOrderedQueryable, IQueryProvider + { + private IEnumerable _enumerable; + + protected TestQueryProvider(Expression expression) + { + this.Expression = expression; + } + + protected TestQueryProvider(IEnumerable enumerable) + { + _enumerable = enumerable; + this.Expression = enumerable.AsQueryable().Expression; + } + + public IQueryable CreateQuery(Expression expression) + { + if (expression is MethodCallExpression m) + { + var resultType = m.Method.ReturnType; + var tElement = resultType.GetGenericArguments().First(); + return CreateIQueryableInstance(tElement, expression); + } + + return CreateQuery(expression); + } + + public IQueryable CreateQuery(Expression expression) + { + return CreateIQueryableInstance>(typeof(TEntity), expression); + } + + private TQueryable CreateIQueryableInstance(Type elementType, Expression expression) where TQueryable : IQueryable + { + var queryType = this.GetType().GetGenericTypeDefinition().MakeGenericType(elementType); + return (TQueryable)Activator.CreateInstance(queryType, expression); + } + + public object Execute(Expression expression) + { + return CompileExpressionItem(expression); + } + + public TResult Execute(Expression expression) + { + return CompileExpressionItem(expression); + } + + IEnumerator IEnumerable.GetEnumerator() + { + if (_enumerable == null) + { + _enumerable = CompileExpressionItem>(this.Expression); + } + + return _enumerable.GetEnumerator(); + } + + IEnumerator IEnumerable.GetEnumerator() + { + if (_enumerable == null) + { + _enumerable = CompileExpressionItem>(this.Expression); + } + + return _enumerable.GetEnumerator(); + } + + public Type ElementType => typeof(TEntity); + + public Expression Expression { get; } + + public IQueryProvider Provider + { + get { return this; } + } + + private static TResult CompileExpressionItem(Expression expression) + { + var visitor = new TestExpressionVisitor(); + var body = visitor.Visit(expression); + var f = Expression.Lambda>(body ?? throw new InvalidOperationException($"{nameof(body)} is null"), (IEnumerable)null); + return f.Compile()(); + } + } +} \ No newline at end of file