Skip to content

Commit

Permalink
Fix IAsyncQueryProvider implementation (#2)
Browse files Browse the repository at this point in the history
* Fix IAsyncQueryProvider implementation
  • Loading branch information
ivo-stoilov authored Sep 17, 2024
1 parent 841781e commit 109e0ce
Show file tree
Hide file tree
Showing 6 changed files with 159 additions and 44 deletions.
17 changes: 17 additions & 0 deletions Telerik.JustMock.EntityFrameworkCore.Tests/DbSetTests.cs
Original file line number Diff line number Diff line change
Expand Up @@ -304,5 +304,22 @@ public async Task Queries_DataIsThere()

Assert.AreEqual(2, xes.Count);
}

[TestMethod]
public async Task Queries_AsyncProviderIsSupported()
{
var ctx = Mock.Create<TheDbContext>().PrepareMock();
var list = new List<Person>
{
new Person { Id = 1, Name = "x" },
new Person { Id = 2, Name = "x" },
new Person { Id = 3, Name = "y" }
};
ctx.People.Bind(list);

var res = await ctx.People.SingleAsync(x => x.Id == 1);

Assert.IsNotNull(res);
}
}
}
2 changes: 1 addition & 1 deletion Telerik.JustMock.EntityFrameworkCore/MockDbSet.cs
Original file line number Diff line number Diff line change
Expand Up @@ -187,7 +187,7 @@ public virtual Expression Expression

public virtual IQueryProvider Provider
{
get { return new TestDbAsyncQueryProvider<TEntity>(this.asQueryable.Provider); }
get { return new TestDbAsyncQueryProvider<TEntity>(this.asQueryable.Expression); }
}

public override IEntityType EntityType
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ public IAsyncEnumerator<T> GetAsyncEnumerator(CancellationToken cancellationToke

IQueryProvider IQueryable.Provider
{
get { return new TestDbAsyncQueryProvider<T>(this); }
get { return new TestDbAsyncQueryProvider<T>(this.AsEnumerable()); }
}
}

Expand Down
82 changes: 40 additions & 42 deletions Telerik.JustMock.EntityFrameworkCore/TestDbAsyncQueryProvider.cs
Original file line number Diff line number Diff line change
@@ -1,43 +1,41 @@
using Microsoft.EntityFrameworkCore.Query;
using System.Linq;
using System.Linq.Expressions;
using System.Threading;
using System.Threading.Tasks;

namespace Telerik.JustMock.EntityFrameworkCore
{
internal class TestDbAsyncQueryProvider<TEntity> : IAsyncQueryProvider
{
private readonly IQueryProvider _inner;

public TestDbAsyncQueryProvider(IQueryProvider inner)
{
_inner = inner;
}

public IQueryable CreateQuery(Expression expression)
{
return new TestDbAsyncEnumerable<TEntity>(expression);
}

public IQueryable<TElement> CreateQuery<TElement>(Expression expression)
{
return new TestDbAsyncEnumerable<TElement>(expression);
}

public object Execute(Expression expression)
{
return _inner.Execute(expression);
}

public TResult Execute<TResult>(Expression expression)
{
return _inner.Execute<TResult>(expression);
}

public TResult ExecuteAsync<TResult>(Expression expression, CancellationToken cancellationToken = default)
{
return Task.FromResult(Execute<TResult>(expression)).GetAwaiter().GetResult();
}
}
using Microsoft.EntityFrameworkCore.Query;
using System.Collections.Generic;
using System.Linq;
using System.Linq.Expressions;
using System.Threading;
using System.Threading.Tasks;

namespace Telerik.JustMock.EntityFrameworkCore
{
internal class TestDbAsyncQueryProvider<TEntity> : TestQueryProvider<TEntity>, IAsyncEnumerable<TEntity>, IAsyncQueryProvider
{
public TestDbAsyncQueryProvider(Expression expression)
: base(expression)
{
}

public TestDbAsyncQueryProvider(IEnumerable<TEntity> enumerable)
: base(enumerable)
{
}

public TResult ExecuteAsync<TResult>(Expression expression, CancellationToken cancellationToken)
{
var expectedResultType = typeof(TResult).GetGenericArguments()[0];
var executionResult = typeof(IQueryProvider)
.GetMethods()
.Single(method => method.Name == nameof(IQueryProvider.Execute) && method.IsGenericMethod)
.MakeGenericMethod(expectedResultType)
.Invoke(this, new object[] { expression });

return (TResult)typeof(Task).GetMethod(nameof(Task.FromResult))
.MakeGenericMethod(expectedResultType)
.Invoke(null, new[] { executionResult });
}

public IAsyncEnumerator<TEntity> GetAsyncEnumerator(CancellationToken cancellationToken = default)
{
return new TestDbAsyncEnumerator<TEntity>(this.AsEnumerable().GetEnumerator());
}
}
}
6 changes: 6 additions & 0 deletions Telerik.JustMock.EntityFrameworkCore/TestExpressionVisitor.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
using System.Linq.Expressions;

namespace Telerik.JustMock.EntityFrameworkCore
{
internal class TestExpressionVisitor : ExpressionVisitor { }
}
94 changes: 94 additions & 0 deletions Telerik.JustMock.EntityFrameworkCore/TestQueryProvider.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,94 @@
using System;
using System.Collections;
using System.Collections.Generic;
using System.Linq;
using System.Linq.Expressions;

namespace Telerik.JustMock.EntityFrameworkCore
{
internal abstract class TestQueryProvider<TEntity> : IOrderedQueryable<TEntity>, IQueryProvider
{
private IEnumerable<TEntity> _enumerable;

protected TestQueryProvider(Expression expression)
{
this.Expression = expression;
}

protected TestQueryProvider(IEnumerable<TEntity> enumerable)
{
_enumerable = enumerable;
this.Expression = enumerable.AsQueryable().Expression;
}

public IQueryable CreateQuery(Expression expression)
{
if (expression is MethodCallExpression m)
{
var resultType = m.Method.ReturnType;
var tElement = resultType.GetGenericArguments().First();
return CreateIQueryableInstance<IQueryable>(tElement, expression);
}

return CreateQuery<TEntity>(expression);
}

public IQueryable<TEntity> CreateQuery<TEntity>(Expression expression)
{
return CreateIQueryableInstance<IQueryable<TEntity>>(typeof(TEntity), expression);
}

private TQueryable CreateIQueryableInstance<TQueryable>(Type elementType, Expression expression) where TQueryable : IQueryable
{
var queryType = this.GetType().GetGenericTypeDefinition().MakeGenericType(elementType);
return (TQueryable)Activator.CreateInstance(queryType, expression);
}

public object Execute(Expression expression)
{
return CompileExpressionItem<object>(expression);
}

public TResult Execute<TResult>(Expression expression)
{
return CompileExpressionItem<TResult>(expression);
}

IEnumerator<TEntity> IEnumerable<TEntity>.GetEnumerator()
{
if (_enumerable == null)
{
_enumerable = CompileExpressionItem<IEnumerable<TEntity>>(this.Expression);
}

return _enumerable.GetEnumerator();
}

IEnumerator IEnumerable.GetEnumerator()
{
if (_enumerable == null)
{
_enumerable = CompileExpressionItem<IEnumerable<TEntity>>(this.Expression);
}

return _enumerable.GetEnumerator();
}

public Type ElementType => typeof(TEntity);

public Expression Expression { get; }

public IQueryProvider Provider
{
get { return this; }
}

private static TResult CompileExpressionItem<TResult>(Expression expression)
{
var visitor = new TestExpressionVisitor();
var body = visitor.Visit(expression);
var f = Expression.Lambda<Func<TResult>>(body ?? throw new InvalidOperationException($"{nameof(body)} is null"), (IEnumerable<ParameterExpression>)null);
return f.Compile()();
}
}
}

0 comments on commit 109e0ce

Please sign in to comment.