diff --git a/README.md b/README.md index d9f2bb0..7dac40c 100644 --- a/README.md +++ b/README.md @@ -83,6 +83,18 @@ using Pose; Shim ctorShim = Shim.Replace(() => new MyClass()).With(() => new MyClass() { MyProperty = 10 }); ``` +```csharp +using Pose; + +Shim ctorShim = Shim.Replace(() => new MyClassAInheritingFromBase()).With(() => new MyClassBInheritingFromBase()); +``` + +```csharp +using Pose; + +Shim ctorShim = Shim.Replace(() => new MyClassAInheritingFromBase()).With(() => new MyClassBInheritingFromBase(42)); +``` + ### Shim instance method of a Reference Type ```csharp diff --git a/src/Pose/Helpers/ShimHelper.cs b/src/Pose/Helpers/ShimHelper.cs index 1b1df7d..f06cc0c 100644 --- a/src/Pose/Helpers/ShimHelper.cs +++ b/src/Pose/Helpers/ShimHelper.cs @@ -1,5 +1,4 @@ using System; -using System.Collections.Generic; using System.Linq; using System.Linq.Expressions; using System.Reflection; @@ -13,6 +12,9 @@ internal static class ShimHelper { public static MethodBase GetMethodFromExpression(Expression expression, bool setter, out Object instanceOrType) { + if (expression == null) + throw new ArgumentNullException(nameof(expression)); + switch (expression.NodeType) { case ExpressionType.MemberAccess: @@ -25,8 +27,8 @@ public static MethodBase GetMethodFromExpression(Expression expression, bool set instanceOrType = GetObjectInstanceOrType(memberExpression.Expression); return setter ? propertyInfo.GetSetMethod() : propertyInfo.GetGetMethod(); } - else - throw new NotImplementedException("Unsupported expression"); + + throw new NotImplementedException("Unsupported expression"); } case ExpressionType.Call: MethodCallExpression methodCallExpression = expression as MethodCallExpression; @@ -41,8 +43,14 @@ public static MethodBase GetMethodFromExpression(Expression expression, bool set } } - public static void ValidateReplacementMethodSignature(MethodBase original, MethodInfo replacement, Type type, bool setter) + public static void ValidateReplacementMethodSignature(MethodBase original, MethodInfo replacement, Type type, + bool setter, Type baseType) { + if (original == null) + throw new ArgumentNullException(nameof(original)); + if (replacement == null) + throw new ArgumentNullException(nameof(replacement)); + bool isValueType = original.IsForValueType(); bool isStatic = original.IsStatic; bool isConstructor = original.IsConstructor; @@ -56,14 +64,23 @@ public static void ValidateReplacementMethodSignature(MethodBase original, Metho Type shimOwningType = isStaticOrConstructor ? validOwningType : replacement.GetParameters().Select(p => p.ParameterType).FirstOrDefault(); - var validParameterTypes = original.GetParameters().Select(p => p.ParameterType); + var validParameterTypes = original.GetParameters().Select(p => p.ParameterType).ToList(); var shimParameterTypes = replacement.GetParameters() .Select(p => p.ParameterType) - .Skip(isStaticOrConstructor ? 0 : 1); + .Skip(isStaticOrConstructor ? 0 : 1).ToList(); - if (vaildReturnType != shimReturnType) + if (!isConstructor && vaildReturnType != shimReturnType) throw new InvalidShimSignatureException("Mismatched return types"); + if (isConstructor) + { + var isValidReturnType = CheckTypesForAssignability(baseType, shimReturnType); + if (!isValidReturnType) + { + throw new InvalidShimSignatureException("Mismatched construction types"); + } + } + if (!isStaticOrConstructor) { if (isValueType && !shimOwningType.IsByRef) @@ -73,10 +90,10 @@ public static void ValidateReplacementMethodSignature(MethodBase original, Metho if ((isValueType && !isStaticOrConstructor ? validOwningType.MakeByRefType() : validOwningType) != shimOwningType) throw new InvalidShimSignatureException("Mismatched instance types"); - if (validParameterTypes.Count() != shimParameterTypes.Count()) + if (validParameterTypes.Count != shimParameterTypes.Count) throw new InvalidShimSignatureException("Parameters count do not match"); - for (int i = 0; i < validParameterTypes.Count(); i++) + for (int i = 0; i < validParameterTypes.Count; i++) { if (validParameterTypes.ElementAt(i) != shimParameterTypes.ElementAt(i)) throw new InvalidShimSignatureException($"Parameter types at {i} do not match"); @@ -126,5 +143,32 @@ private static void EnsureInstanceNotValueType(object instance) if (instance.GetType().IsSubclassOf(typeof(ValueType))) throw new NotSupportedException("You cannot replace methods on specific value type instances"); } + + private static bool CheckTypesForAssignability(Type baseType, Type typeToCheck) + { + var stype = typeToCheck; + while (stype != typeof(object)) + { + if (baseType.IsAssignableFrom(stype)) + { + return true; + } + + stype = stype.BaseType; + } + + stype = baseType.BaseType; + while (stype != typeof(object)) + { + if (stype.IsAssignableFrom(typeToCheck)) + { + return true; + } + + stype = stype.BaseType; + } + + return false; + } } } \ No newline at end of file diff --git a/src/Pose/Shim.cs b/src/Pose/Shim.cs index 670b299..a944a17 100644 --- a/src/Pose/Shim.cs +++ b/src/Pose/Shim.cs @@ -1,5 +1,4 @@ using System; -using System.Linq; using System.Linq.Expressions; using System.Reflection; @@ -9,69 +8,42 @@ namespace Pose { public partial class Shim { - private MethodBase _original; - private Delegate _replacement; - private Object _instance; - private Type _type; + private Type _baseType; private bool _setter; - internal MethodBase Original - { - get - { - return _original; - } - } + internal MethodBase Original { get; } - internal Delegate Replacement - { - get - { - return _replacement; - } - } + internal Delegate Replacement { get; private set; } - internal Object Instance - { - get - { - return _instance; - } - } + internal Object Instance { get; } - internal Type Type - { - get - { - return _type; - } - } + internal Type Type { get; } private Shim(MethodBase original, object instanceOrType) { - _original = original; + Original = original ?? throw new ArgumentNullException(nameof(original)); if (instanceOrType is Type type) - _type = type; + Type = type; else - _instance = instanceOrType; + Instance = instanceOrType; } public static Shim Replace(Expression expression, bool setter = false) - => ReplaceImpl(expression, setter); + => ReplaceImpl(expression, setter, null); public static Shim Replace(Expression> expression, bool setter = false) - => ReplaceImpl(expression, setter); + => ReplaceImpl(expression, setter, typeof(T)); - private static Shim ReplaceImpl(Expression expression, bool setter) + private static Shim ReplaceImpl(Expression expression, bool setter, Type baseType) { MethodBase methodBase = ShimHelper.GetMethodFromExpression(expression.Body, setter, out object instance); - return new Shim(methodBase, instance) { _setter = setter }; + return new Shim(methodBase, instance) { _setter = setter, _baseType = baseType}; } private Shim WithImpl(Delegate replacement) { - _replacement = replacement; - ShimHelper.ValidateReplacementMethodSignature(this._original, this._replacement.Method, _instance?.GetType() ?? _type, _setter); + Replacement = replacement; + ShimHelper.ValidateReplacementMethodSignature(this.Original, this.Replacement.Method, Instance?.GetType() ?? Type, _setter, _baseType); return this; } } diff --git a/test/Pose.Tests/ShimTests.cs b/test/Pose.Tests/ShimTests.cs index 29c7641..1f14ad4 100644 --- a/test/Pose.Tests/ShimTests.cs +++ b/test/Pose.Tests/ShimTests.cs @@ -1,15 +1,11 @@ using System; +using System.Collections.Generic; using System.Globalization; -using System.Linq.Expressions; -using System.Reflection; using System.Threading; using Pose.Exceptions; -using Pose.Helpers; using Microsoft.VisualStudio.TestTools.UnitTesting; -using static System.Console; - namespace Pose.Tests { [TestClass] @@ -81,8 +77,7 @@ public void TestReplacePropertySetter() Assert.AreEqual(typeof(Thread).GetProperty(nameof(Thread.CurrentCulture), typeof(CultureInfo)).SetMethod, shim.Original); Assert.IsNull(shim.Replacement); } - - + [TestMethod] public void TestReplacePropertySetterAction() { @@ -114,5 +109,90 @@ public void TestReplacePropertySetterAction() Assert.IsTrue(getterExecuted, "Getter not executed"); Assert.IsTrue(setterExecuted, "Setter not executed"); } + + [TestMethod] + public void TestReplaceConstructor() + { + var dummyShim = Shim.Replace(() => new DummyForConstructor()).With(() => new DummyForConstructorReplacment()); + bool? wasCalled = false; + + PoseContext.Isolate(() => + { + wasCalled = new DummyHolder().Dummy.ConstructorCalled; + }, dummyShim); + + Assert.IsNotNull(wasCalled); + Assert.IsFalse(wasCalled.Value); + } + + [TestMethod] + public void TestReplaceConstructorWithoutGeneric() + { + var dummyShim = Shim.Replace(() => new DummyForConstructor()).With(() => new DummyForConstructorReplacment()); + bool? wasCalled = false; + + PoseContext.Isolate(() => + { + wasCalled = new DummyHolder().Dummy.ConstructorCalled; + }, dummyShim); + + Assert.IsNotNull(wasCalled); + Assert.IsFalse(wasCalled.Value); + } + + [TestMethod] + public void TestReplaceConstructorWithoutGeneric2() + { + var dummyShim = Shim.Replace(() => new DummyForConstructorReplacment()).With(() => new DummyForConstructorReplacment(true)); + + PoseContext.Isolate(() => + { + var unused = new DummyForConstructorReplacment(); + }, dummyShim); + } + + [TestMethod] + public void TestReplaceConstructorWithIncorrectTypes() + { + List Replacement() => new List(); + var exception = Assert.ThrowsException(() => Shim.Replace(() => new DummyForConstructor()).With(Replacement)); + Assert.AreEqual("Mismatched construction types", exception.Message); + } + + private class DummyHolder + { + public DummyHolder() + { + Dummy = new DummyForConstructor(); + } + + public DummyForConstructorBase Dummy { get; } + } + + private abstract class DummyForConstructorBase + { + public bool? ConstructorCalled { get; protected set; } + } + + private class DummyForConstructor : DummyForConstructorBase + { + public DummyForConstructor() + { + ConstructorCalled = true; + } + } + + private class DummyForConstructorReplacment : DummyForConstructorBase + { + public DummyForConstructorReplacment() + { + ConstructorCalled = false; + } + + public DummyForConstructorReplacment(bool dummyFlag) + { + ConstructorCalled = false; + } + } } }