From acf61086da546680d60d408a4356b7c8b394854c Mon Sep 17 00:00:00 2001 From: Matt Whitfield Date: Fri, 6 Oct 2023 18:16:28 +0100 Subject: [PATCH] Update to fix exception in InvocationExtractor (#235) --- .../Helpers/InvocationExtractorTests.cs | 23 ++++++------- .../Helpers/InvocationExtractor.cs | 32 +++++++------------ src/Unitverse.Core/Helpers/MockHelper.cs | 2 +- 3 files changed, 24 insertions(+), 33 deletions(-) diff --git a/src/Unitverse.Core.Tests/Helpers/InvocationExtractorTests.cs b/src/Unitverse.Core.Tests/Helpers/InvocationExtractorTests.cs index 3aaf64d4..6decf544 100644 --- a/src/Unitverse.Core.Tests/Helpers/InvocationExtractorTests.cs +++ b/src/Unitverse.Core.Tests/Helpers/InvocationExtractorTests.cs @@ -9,6 +9,7 @@ namespace Unitverse.Core.Tests.Helpers using Microsoft.CodeAnalysis.CSharp; using System.Linq; using NSubstitute; + using Unitverse.Core.Models; [TestFixture] public class InvocationExtractorTests @@ -19,7 +20,7 @@ public void CanCallExtractFrom() var classModel = ClassModelProvider.CreateModel(TestClasses.AutomaticMockGeneration); var targetFields = new[] { "_dummyService", "_dummyService2" }; - var result = InvocationExtractor.ExtractFrom(classModel.Methods.Single(x => x.Name == "SampleNoReturn").Node, classModel.SemanticModel, targetFields); + var result = InvocationExtractor.ExtractFrom(classModel, classModel.Methods.Single(x => x.Name == "SampleNoReturn").Node, targetFields); result.GetAccessedPropertySymbolsFor("_dummyService2").Single().Name.Should().Be("SomeProp"); result.GetAccessedMethodSymbolsFor("_dummyService").Select(x => x.Name).Should().BeEquivalentTo("NoReturnMethod", "GenericMethod"); result.GetAccessedMethodSymbolsFor("_dummyService2").Select(x => x.Name).Should().BeEquivalentTo("ReturnMethod"); @@ -31,7 +32,7 @@ public void ExtractFrom_DependencyCalledInsidePrivateMethod_ReturnsCalledMethods var classModel = ClassModelProvider.CreateModel(TestClasses.AutomaticMockGeneration); var targetFields = new[] { "_dummyService", "_dummyService2" }; - var result = InvocationExtractor.ExtractFrom(classModel.Methods.Single(x => x.Name == "SampleAsyncMethod").Node, classModel.SemanticModel, targetFields); + var result = InvocationExtractor.ExtractFrom(classModel, classModel.Methods.Single(x => x.Name == "SampleAsyncMethod").Node, targetFields); result.GetAccessedMethodSymbolsFor("_dummyService").Select(x => x.Name).Should().BeEquivalentTo("AsyncMethod"); result.GetAccessedMethodSymbolsFor("_dummyService2").Select(x => x.Name).Should().BeEquivalentTo("AsyncMethod"); } @@ -42,7 +43,7 @@ public void ExtractFrom_DependencyCalledInsidePublicMethod_ReturnsCalledMethods( var classModel = ClassModelProvider.CreateModel(TestClasses.AutomaticMockGeneration); var targetFields = new[] { "_dummyService", "_dummyService2" }; - var result = InvocationExtractor.ExtractFrom(classModel.Methods.Single(x => x.Name == "SampleDependencyCalledInsidePublicMethod").Node, classModel.SemanticModel, targetFields); + var result = InvocationExtractor.ExtractFrom(classModel, classModel.Methods.Single(x => x.Name == "SampleDependencyCalledInsidePublicMethod").Node, targetFields); result.GetAccessedMethodSymbolsFor("_dummyService").Select(x => x.Name).Should().BeEquivalentTo("AsyncMethod"); result.GetAccessedMethodSymbolsFor("_dummyService2").Select(x => x.Name).Should().BeEquivalentTo("AsyncMethod"); } @@ -53,7 +54,7 @@ public void ExtractFrom_DeeperNestedDependencyCall_ReturnsCalledMethods() var classModel = ClassModelProvider.CreateModel(TestClasses.AutomaticMockGeneration); var targetFields = new[] { "_dummyService", "_dummyService2" }; - var result = InvocationExtractor.ExtractFrom(classModel.Methods.Single(x => x.Name == "SampleDeeperNestedDependencyCall").Node, classModel.SemanticModel, targetFields); + var result = InvocationExtractor.ExtractFrom(classModel, classModel.Methods.Single(x => x.Name == "SampleDeeperNestedDependencyCall").Node, targetFields); result.GetAccessedMethodSymbolsFor("_dummyService").Select(x => x.Name).Should().BeEquivalentTo("AsyncMethod"); result.GetAccessedMethodSymbolsFor("_dummyService2").Select(x => x.Name).Should().BeEquivalentTo("AsyncMethod"); } @@ -64,7 +65,7 @@ public void ExtractFrom_DependencyCalledWithDelegate_ReturnsCalledMethods() var classModel = ClassModelProvider.CreateModel(TestClasses.AutomaticMockGeneration); var targetFields = new[] { "_dummyService", "_dummyService2" }; - var result = InvocationExtractor.ExtractFrom(classModel.Methods.Single(x => x.Name == "SampleDependencyCalledAsADelegateMethod").Node, classModel.SemanticModel, targetFields); + var result = InvocationExtractor.ExtractFrom(classModel, classModel.Methods.Single(x => x.Name == "SampleDependencyCalledAsADelegateMethod").Node, targetFields); result.GetAccessedMethodSymbolsFor("_dummyService").Select(x => x.Name).Should().BeEquivalentTo("AsyncMethod"); result.GetAccessedMethodSymbolsFor("_dummyService2").Select(x => x.Name).Should().BeEquivalentTo("AsyncMethod"); } @@ -75,7 +76,7 @@ public void ExtractFrom_DependencyCalledWithLambda_ReturnsCalledMethods() var classModel = ClassModelProvider.CreateModel(TestClasses.AutomaticMockGeneration); var targetFields = new[] { "_dummyService", "_dummyService2" }; - var result = InvocationExtractor.ExtractFrom(classModel.Methods.Single(x => x.Name == "SampleDependencyCalledAsALambdaMethod").Node, classModel.SemanticModel, targetFields); + var result = InvocationExtractor.ExtractFrom(classModel, classModel.Methods.Single(x => x.Name == "SampleDependencyCalledAsALambdaMethod").Node, targetFields); result.GetAccessedMethodSymbolsFor("_dummyService").Select(x => x.Name).Should().BeEquivalentTo("AsyncMethod"); result.GetAccessedMethodSymbolsFor("_dummyService2").Select(x => x.Name).Should().BeEquivalentTo("AsyncMethod"); } @@ -86,7 +87,7 @@ public void ExtractFrom_DependencyCalledWithAction_ReturnsCalledMethods() var classModel = ClassModelProvider.CreateModel(TestClasses.AutomaticMockGeneration); var targetFields = new[] { "_dummyService", "_dummyService2" }; - var result = InvocationExtractor.ExtractFrom(classModel.Methods.Single(x => x.Name == "SampleDependencyCalledAsAActionMethod").Node, classModel.SemanticModel, targetFields); + var result = InvocationExtractor.ExtractFrom(classModel, classModel.Methods.Single(x => x.Name == "SampleDependencyCalledAsAActionMethod").Node, targetFields); result.GetAccessedMethodSymbolsFor("_dummyService").Select(x => x.Name).Should().BeEquivalentTo("AsyncMethod"); result.GetAccessedMethodSymbolsFor("_dummyService2").Select(x => x.Name).Should().BeEquivalentTo("AsyncMethod"); } @@ -94,19 +95,19 @@ public void ExtractFrom_DependencyCalledWithAction_ReturnsCalledMethods() [Test] public void CannotCallExtractFromWithNullNode() { - FluentActions.Invoking(() => InvocationExtractor.ExtractFrom(default(CSharpSyntaxNode), Substitute.For(), new[] { "TestValue1478414786", "TestValue1253389239", "TestValue1543172025" })).Should().Throw(); + FluentActions.Invoking(() => InvocationExtractor.ExtractFrom(ClassModelProvider.CreateModel(TestClasses.AutomaticMockGeneration), default(CSharpSyntaxNode), new[] { "TestValue1478414786", "TestValue1253389239", "TestValue1543172025" })).Should().Throw(); } [Test] - public void CannotCallExtractFromWithNullSemanticModel() + public void CannotCallExtractFromWithNullModel() { - FluentActions.Invoking(() => InvocationExtractor.ExtractFrom(SyntaxFactory.IdentifierName("hello"), default(SemanticModel), new[] { "TestValue1562618265", "TestValue1888707362", "TestValue2031161598" })).Should().Throw(); + FluentActions.Invoking(() => InvocationExtractor.ExtractFrom(default(ClassModel), SyntaxFactory.IdentifierName("hello"), new[] { "TestValue1562618265", "TestValue1888707362", "TestValue2031161598" })).Should().Throw(); } [Test] public void CannotCallExtractFromWithNullTargetFields() { - FluentActions.Invoking(() => InvocationExtractor.ExtractFrom(SyntaxFactory.IdentifierName("hello"), Substitute.For(), default(IEnumerable))).Should().Throw(); + FluentActions.Invoking(() => InvocationExtractor.ExtractFrom(ClassModelProvider.CreateModel(TestClasses.AutomaticMockGeneration), SyntaxFactory.IdentifierName("hello"), default(IEnumerable))).Should().Throw(); } } } \ No newline at end of file diff --git a/src/Unitverse.Core/Helpers/InvocationExtractor.cs b/src/Unitverse.Core/Helpers/InvocationExtractor.cs index f83aefbb..88c3a77e 100644 --- a/src/Unitverse.Core/Helpers/InvocationExtractor.cs +++ b/src/Unitverse.Core/Helpers/InvocationExtractor.cs @@ -9,28 +9,29 @@ public class InvocationExtractor : CSharpSyntaxWalker { - private InvocationExtractor(SemanticModel semanticModel, IEnumerable targetFields) - : this(semanticModel, targetFields, new HashSet()) + private InvocationExtractor(INamedTypeSymbol containingTypeSymbol, SemanticModel semanticModel, IEnumerable targetFields) + : this(containingTypeSymbol, semanticModel, targetFields, new HashSet()) { } - private InvocationExtractor(SemanticModel semanticModel, IEnumerable targetFields, ISet visitedMethods) + private InvocationExtractor(INamedTypeSymbol containingTypeSymbol, SemanticModel semanticModel, IEnumerable targetFields, ISet visitedMethods) { _semanticModel = semanticModel; _targetFields = new HashSet(targetFields); _visitedMethods = visitedMethods; + _containingTypeSymbol = containingTypeSymbol; } - public static DependencyAccessMap ExtractFrom(CSharpSyntaxNode node, SemanticModel semanticModel, IEnumerable targetFields) + public static DependencyAccessMap ExtractFrom(ClassModel model, CSharpSyntaxNode node, IEnumerable targetFields) { if (node is null) { throw new ArgumentNullException(nameof(node)); } - if (semanticModel is null) + if (model is null) { - throw new ArgumentNullException(nameof(semanticModel)); + throw new ArgumentNullException(nameof(model)); } if (targetFields is null) @@ -38,7 +39,7 @@ public static DependencyAccessMap ExtractFrom(CSharpSyntaxNode node, SemanticMod throw new ArgumentNullException(nameof(targetFields)); } - var extractor = new InvocationExtractor(semanticModel, targetFields); + var extractor = new InvocationExtractor(model.TypeSymbol, model.SemanticModel, targetFields); node.Accept(extractor); return new DependencyAccessMap(extractor._methodCalls, extractor._propertyCalls, extractor._invocationCount, extractor._memberAccessCount); @@ -49,6 +50,7 @@ public static DependencyAccessMap ExtractFrom(CSharpSyntaxNode node, SemanticMod private readonly SemanticModel _semanticModel; private readonly HashSet _targetFields; private readonly ISet _visitedMethods; + private readonly INamedTypeSymbol _containingTypeSymbol; private int _invocationCount; private int _memberAccessCount; @@ -120,7 +122,7 @@ private void Descend(ExpressionSyntax node) } _visitedMethods.Add(methodDeclaration); - var extractor = new InvocationExtractor(_semanticModel, _targetFields, _visitedMethods); + var extractor = new InvocationExtractor(_containingTypeSymbol, _semanticModel, _targetFields, _visitedMethods); methodDeclaration?.Accept(extractor); _methodCalls.AddRange(extractor._methodCalls); } @@ -150,24 +152,12 @@ private bool GetFieldTarget(ExpressionSyntax expressionSyntax, out string fieldT if (symbol == null || symbol.Kind != SymbolKind.Method - || !IsMethodInTheSameClass(symbol, invocationExpression)) + || symbol.ContainingType != _containingTypeSymbol) { return null; } return symbol.DeclaringSyntaxReferences[0].GetSyntax() as MethodDeclarationSyntax; } - - private bool IsMethodInTheSameClass(ISymbol methodSymbol, ExpressionSyntax invocationExpression) - { - // private methods can only be invoked from inside the class - if (methodSymbol.DeclaredAccessibility == Accessibility.Private) - { - return true; - } - - var invocationContainingType = _semanticModel.GetEnclosingSymbol(invocationExpression.SpanStart)?.ContainingType; - return methodSymbol.ContainingType == invocationContainingType; - } } } diff --git a/src/Unitverse.Core/Helpers/MockHelper.cs b/src/Unitverse.Core/Helpers/MockHelper.cs index 2c6623fd..5e1f46a9 100644 --- a/src/Unitverse.Core/Helpers/MockHelper.cs +++ b/src/Unitverse.Core/Helpers/MockHelper.cs @@ -23,7 +23,7 @@ public static bool PrepareMockCalls(ClassModel model, CSharpSyntaxNode targetBod } var mappedInterfaceFields = model.DependencyMap.MappedInterfaceFields.ToList(); - var dependencyMap = InvocationExtractor.ExtractFrom(targetBody, model.SemanticModel, mappedInterfaceFields); + var dependencyMap = InvocationExtractor.ExtractFrom(model, targetBody, mappedInterfaceFields); foreach (var field in mappedInterfaceFields) {