Skip to content

Commit

Permalink
Update to fix exception in InvocationExtractor (#235)
Browse files Browse the repository at this point in the history
  • Loading branch information
mattwhitfield authored Oct 6, 2023
1 parent 1c00070 commit acf6108
Show file tree
Hide file tree
Showing 3 changed files with 24 additions and 33 deletions.
23 changes: 12 additions & 11 deletions src/Unitverse.Core.Tests/Helpers/InvocationExtractorTests.cs
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ namespace Unitverse.Core.Tests.Helpers
using Microsoft.CodeAnalysis.CSharp;
using System.Linq;
using NSubstitute;
using Unitverse.Core.Models;

[TestFixture]
public class InvocationExtractorTests
Expand All @@ -19,7 +20,7 @@ public void CanCallExtractFrom()
var classModel = ClassModelProvider.CreateModel(TestClasses.AutomaticMockGeneration);

var targetFields = new[] { "_dummyService", "_dummyService2" };
var result = InvocationExtractor.ExtractFrom(classModel.Methods.Single(x => x.Name == "SampleNoReturn").Node, classModel.SemanticModel, targetFields);
var result = InvocationExtractor.ExtractFrom(classModel, classModel.Methods.Single(x => x.Name == "SampleNoReturn").Node, targetFields);
result.GetAccessedPropertySymbolsFor("_dummyService2").Single().Name.Should().Be("SomeProp");
result.GetAccessedMethodSymbolsFor("_dummyService").Select(x => x.Name).Should().BeEquivalentTo("NoReturnMethod", "GenericMethod");
result.GetAccessedMethodSymbolsFor("_dummyService2").Select(x => x.Name).Should().BeEquivalentTo("ReturnMethod");
Expand All @@ -31,7 +32,7 @@ public void ExtractFrom_DependencyCalledInsidePrivateMethod_ReturnsCalledMethods
var classModel = ClassModelProvider.CreateModel(TestClasses.AutomaticMockGeneration);

var targetFields = new[] { "_dummyService", "_dummyService2" };
var result = InvocationExtractor.ExtractFrom(classModel.Methods.Single(x => x.Name == "SampleAsyncMethod").Node, classModel.SemanticModel, targetFields);
var result = InvocationExtractor.ExtractFrom(classModel, classModel.Methods.Single(x => x.Name == "SampleAsyncMethod").Node, targetFields);
result.GetAccessedMethodSymbolsFor("_dummyService").Select(x => x.Name).Should().BeEquivalentTo("AsyncMethod");
result.GetAccessedMethodSymbolsFor("_dummyService2").Select(x => x.Name).Should().BeEquivalentTo("AsyncMethod");
}
Expand All @@ -42,7 +43,7 @@ public void ExtractFrom_DependencyCalledInsidePublicMethod_ReturnsCalledMethods(
var classModel = ClassModelProvider.CreateModel(TestClasses.AutomaticMockGeneration);

var targetFields = new[] { "_dummyService", "_dummyService2" };
var result = InvocationExtractor.ExtractFrom(classModel.Methods.Single(x => x.Name == "SampleDependencyCalledInsidePublicMethod").Node, classModel.SemanticModel, targetFields);
var result = InvocationExtractor.ExtractFrom(classModel, classModel.Methods.Single(x => x.Name == "SampleDependencyCalledInsidePublicMethod").Node, targetFields);
result.GetAccessedMethodSymbolsFor("_dummyService").Select(x => x.Name).Should().BeEquivalentTo("AsyncMethod");
result.GetAccessedMethodSymbolsFor("_dummyService2").Select(x => x.Name).Should().BeEquivalentTo("AsyncMethod");
}
Expand All @@ -53,7 +54,7 @@ public void ExtractFrom_DeeperNestedDependencyCall_ReturnsCalledMethods()
var classModel = ClassModelProvider.CreateModel(TestClasses.AutomaticMockGeneration);

var targetFields = new[] { "_dummyService", "_dummyService2" };
var result = InvocationExtractor.ExtractFrom(classModel.Methods.Single(x => x.Name == "SampleDeeperNestedDependencyCall").Node, classModel.SemanticModel, targetFields);
var result = InvocationExtractor.ExtractFrom(classModel, classModel.Methods.Single(x => x.Name == "SampleDeeperNestedDependencyCall").Node, targetFields);
result.GetAccessedMethodSymbolsFor("_dummyService").Select(x => x.Name).Should().BeEquivalentTo("AsyncMethod");
result.GetAccessedMethodSymbolsFor("_dummyService2").Select(x => x.Name).Should().BeEquivalentTo("AsyncMethod");
}
Expand All @@ -64,7 +65,7 @@ public void ExtractFrom_DependencyCalledWithDelegate_ReturnsCalledMethods()
var classModel = ClassModelProvider.CreateModel(TestClasses.AutomaticMockGeneration);

var targetFields = new[] { "_dummyService", "_dummyService2" };
var result = InvocationExtractor.ExtractFrom(classModel.Methods.Single(x => x.Name == "SampleDependencyCalledAsADelegateMethod").Node, classModel.SemanticModel, targetFields);
var result = InvocationExtractor.ExtractFrom(classModel, classModel.Methods.Single(x => x.Name == "SampleDependencyCalledAsADelegateMethod").Node, targetFields);
result.GetAccessedMethodSymbolsFor("_dummyService").Select(x => x.Name).Should().BeEquivalentTo("AsyncMethod");
result.GetAccessedMethodSymbolsFor("_dummyService2").Select(x => x.Name).Should().BeEquivalentTo("AsyncMethod");
}
Expand All @@ -75,7 +76,7 @@ public void ExtractFrom_DependencyCalledWithLambda_ReturnsCalledMethods()
var classModel = ClassModelProvider.CreateModel(TestClasses.AutomaticMockGeneration);

var targetFields = new[] { "_dummyService", "_dummyService2" };
var result = InvocationExtractor.ExtractFrom(classModel.Methods.Single(x => x.Name == "SampleDependencyCalledAsALambdaMethod").Node, classModel.SemanticModel, targetFields);
var result = InvocationExtractor.ExtractFrom(classModel, classModel.Methods.Single(x => x.Name == "SampleDependencyCalledAsALambdaMethod").Node, targetFields);
result.GetAccessedMethodSymbolsFor("_dummyService").Select(x => x.Name).Should().BeEquivalentTo("AsyncMethod");
result.GetAccessedMethodSymbolsFor("_dummyService2").Select(x => x.Name).Should().BeEquivalentTo("AsyncMethod");
}
Expand All @@ -86,27 +87,27 @@ public void ExtractFrom_DependencyCalledWithAction_ReturnsCalledMethods()
var classModel = ClassModelProvider.CreateModel(TestClasses.AutomaticMockGeneration);

var targetFields = new[] { "_dummyService", "_dummyService2" };
var result = InvocationExtractor.ExtractFrom(classModel.Methods.Single(x => x.Name == "SampleDependencyCalledAsAActionMethod").Node, classModel.SemanticModel, targetFields);
var result = InvocationExtractor.ExtractFrom(classModel, classModel.Methods.Single(x => x.Name == "SampleDependencyCalledAsAActionMethod").Node, targetFields);
result.GetAccessedMethodSymbolsFor("_dummyService").Select(x => x.Name).Should().BeEquivalentTo("AsyncMethod");
result.GetAccessedMethodSymbolsFor("_dummyService2").Select(x => x.Name).Should().BeEquivalentTo("AsyncMethod");
}

[Test]
public void CannotCallExtractFromWithNullNode()
{
FluentActions.Invoking(() => InvocationExtractor.ExtractFrom(default(CSharpSyntaxNode), Substitute.For<SemanticModel>(), new[] { "TestValue1478414786", "TestValue1253389239", "TestValue1543172025" })).Should().Throw<ArgumentNullException>();
FluentActions.Invoking(() => InvocationExtractor.ExtractFrom(ClassModelProvider.CreateModel(TestClasses.AutomaticMockGeneration), default(CSharpSyntaxNode), new[] { "TestValue1478414786", "TestValue1253389239", "TestValue1543172025" })).Should().Throw<ArgumentNullException>();
}

[Test]
public void CannotCallExtractFromWithNullSemanticModel()
public void CannotCallExtractFromWithNullModel()
{
FluentActions.Invoking(() => InvocationExtractor.ExtractFrom(SyntaxFactory.IdentifierName("hello"), default(SemanticModel), new[] { "TestValue1562618265", "TestValue1888707362", "TestValue2031161598" })).Should().Throw<ArgumentNullException>();
FluentActions.Invoking(() => InvocationExtractor.ExtractFrom(default(ClassModel), SyntaxFactory.IdentifierName("hello"), new[] { "TestValue1562618265", "TestValue1888707362", "TestValue2031161598" })).Should().Throw<ArgumentNullException>();
}

[Test]
public void CannotCallExtractFromWithNullTargetFields()
{
FluentActions.Invoking(() => InvocationExtractor.ExtractFrom(SyntaxFactory.IdentifierName("hello"), Substitute.For<SemanticModel>(), default(IEnumerable<string>))).Should().Throw<ArgumentNullException>();
FluentActions.Invoking(() => InvocationExtractor.ExtractFrom(ClassModelProvider.CreateModel(TestClasses.AutomaticMockGeneration), SyntaxFactory.IdentifierName("hello"), default(IEnumerable<string>))).Should().Throw<ArgumentNullException>();
}
}
}
32 changes: 11 additions & 21 deletions src/Unitverse.Core/Helpers/InvocationExtractor.cs
Original file line number Diff line number Diff line change
Expand Up @@ -9,36 +9,37 @@

public class InvocationExtractor : CSharpSyntaxWalker
{
private InvocationExtractor(SemanticModel semanticModel, IEnumerable<string> targetFields)
: this(semanticModel, targetFields, new HashSet<MethodDeclarationSyntax>())
private InvocationExtractor(INamedTypeSymbol containingTypeSymbol, SemanticModel semanticModel, IEnumerable<string> targetFields)
: this(containingTypeSymbol, semanticModel, targetFields, new HashSet<MethodDeclarationSyntax>())
{
}

private InvocationExtractor(SemanticModel semanticModel, IEnumerable<string> targetFields, ISet<MethodDeclarationSyntax> visitedMethods)
private InvocationExtractor(INamedTypeSymbol containingTypeSymbol, SemanticModel semanticModel, IEnumerable<string> targetFields, ISet<MethodDeclarationSyntax> visitedMethods)
{
_semanticModel = semanticModel;
_targetFields = new HashSet<string>(targetFields);
_visitedMethods = visitedMethods;
_containingTypeSymbol = containingTypeSymbol;
}

public static DependencyAccessMap ExtractFrom(CSharpSyntaxNode node, SemanticModel semanticModel, IEnumerable<string> targetFields)
public static DependencyAccessMap ExtractFrom(ClassModel model, CSharpSyntaxNode node, IEnumerable<string> targetFields)
{
if (node is null)
{
throw new ArgumentNullException(nameof(node));
}

if (semanticModel is null)
if (model is null)
{
throw new ArgumentNullException(nameof(semanticModel));
throw new ArgumentNullException(nameof(model));
}

if (targetFields is null)
{
throw new ArgumentNullException(nameof(targetFields));
}

var extractor = new InvocationExtractor(semanticModel, targetFields);
var extractor = new InvocationExtractor(model.TypeSymbol, model.SemanticModel, targetFields);
node.Accept(extractor);

return new DependencyAccessMap(extractor._methodCalls, extractor._propertyCalls, extractor._invocationCount, extractor._memberAccessCount);
Expand All @@ -49,6 +50,7 @@ public static DependencyAccessMap ExtractFrom(CSharpSyntaxNode node, SemanticMod
private readonly SemanticModel _semanticModel;
private readonly HashSet<string> _targetFields;
private readonly ISet<MethodDeclarationSyntax> _visitedMethods;
private readonly INamedTypeSymbol _containingTypeSymbol;
private int _invocationCount;
private int _memberAccessCount;

Expand Down Expand Up @@ -120,7 +122,7 @@ private void Descend(ExpressionSyntax node)
}

_visitedMethods.Add(methodDeclaration);
var extractor = new InvocationExtractor(_semanticModel, _targetFields, _visitedMethods);
var extractor = new InvocationExtractor(_containingTypeSymbol, _semanticModel, _targetFields, _visitedMethods);
methodDeclaration?.Accept(extractor);
_methodCalls.AddRange(extractor._methodCalls);
}
Expand Down Expand Up @@ -150,24 +152,12 @@ private bool GetFieldTarget(ExpressionSyntax expressionSyntax, out string fieldT

if (symbol == null
|| symbol.Kind != SymbolKind.Method
|| !IsMethodInTheSameClass(symbol, invocationExpression))
|| symbol.ContainingType != _containingTypeSymbol)
{
return null;
}

return symbol.DeclaringSyntaxReferences[0].GetSyntax() as MethodDeclarationSyntax;
}

private bool IsMethodInTheSameClass(ISymbol methodSymbol, ExpressionSyntax invocationExpression)
{
// private methods can only be invoked from inside the class
if (methodSymbol.DeclaredAccessibility == Accessibility.Private)
{
return true;
}

var invocationContainingType = _semanticModel.GetEnclosingSymbol(invocationExpression.SpanStart)?.ContainingType;
return methodSymbol.ContainingType == invocationContainingType;
}
}
}
2 changes: 1 addition & 1 deletion src/Unitverse.Core/Helpers/MockHelper.cs
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ public static bool PrepareMockCalls(ClassModel model, CSharpSyntaxNode targetBod
}

var mappedInterfaceFields = model.DependencyMap.MappedInterfaceFields.ToList();
var dependencyMap = InvocationExtractor.ExtractFrom(targetBody, model.SemanticModel, mappedInterfaceFields);
var dependencyMap = InvocationExtractor.ExtractFrom(model, targetBody, mappedInterfaceFields);

foreach (var field in mappedInterfaceFields)
{
Expand Down

0 comments on commit acf6108

Please sign in to comment.