Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Constructor replacement #16

Open
wants to merge 2 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 12 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -83,6 +83,18 @@ using Pose;
Shim ctorShim = Shim.Replace(() => new MyClass()).With(() => new MyClass() { MyProperty = 10 });
```

```csharp
using Pose;

Shim ctorShim = Shim.Replace<MyClassBase>(() => new MyClassAInheritingFromBase()).With(() => new MyClassBInheritingFromBase());
```

```csharp
using Pose;

Shim ctorShim = Shim.Replace(() => new MyClassAInheritingFromBase()).With(() => new MyClassBInheritingFromBase(42));
```

### Shim instance method of a Reference Type

```csharp
Expand Down
62 changes: 53 additions & 9 deletions src/Pose/Helpers/ShimHelper.cs
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
using System;
using System.Collections.Generic;
using System.Linq;
using System.Linq.Expressions;
using System.Reflection;
Expand All @@ -13,6 +12,9 @@ internal static class ShimHelper
{
public static MethodBase GetMethodFromExpression(Expression expression, bool setter, out Object instanceOrType)
{
if (expression == null)
throw new ArgumentNullException(nameof(expression));

switch (expression.NodeType)
{
case ExpressionType.MemberAccess:
Expand All @@ -25,8 +27,8 @@ public static MethodBase GetMethodFromExpression(Expression expression, bool set
instanceOrType = GetObjectInstanceOrType(memberExpression.Expression);
return setter ? propertyInfo.GetSetMethod() : propertyInfo.GetGetMethod();
}
else
throw new NotImplementedException("Unsupported expression");

throw new NotImplementedException("Unsupported expression");
}
case ExpressionType.Call:
MethodCallExpression methodCallExpression = expression as MethodCallExpression;
Expand All @@ -41,8 +43,14 @@ public static MethodBase GetMethodFromExpression(Expression expression, bool set
}
}

public static void ValidateReplacementMethodSignature(MethodBase original, MethodInfo replacement, Type type, bool setter)
public static void ValidateReplacementMethodSignature(MethodBase original, MethodInfo replacement, Type type,
bool setter, Type baseType)
{
if (original == null)
throw new ArgumentNullException(nameof(original));
if (replacement == null)
throw new ArgumentNullException(nameof(replacement));

bool isValueType = original.IsForValueType();
bool isStatic = original.IsStatic;
bool isConstructor = original.IsConstructor;
Expand All @@ -56,14 +64,23 @@ public static void ValidateReplacementMethodSignature(MethodBase original, Metho
Type shimOwningType = isStaticOrConstructor
? validOwningType : replacement.GetParameters().Select(p => p.ParameterType).FirstOrDefault();

var validParameterTypes = original.GetParameters().Select(p => p.ParameterType);
var validParameterTypes = original.GetParameters().Select(p => p.ParameterType).ToList();
var shimParameterTypes = replacement.GetParameters()
.Select(p => p.ParameterType)
.Skip(isStaticOrConstructor ? 0 : 1);
.Skip(isStaticOrConstructor ? 0 : 1).ToList();

if (vaildReturnType != shimReturnType)
if (!isConstructor && vaildReturnType != shimReturnType)
throw new InvalidShimSignatureException("Mismatched return types");

if (isConstructor)
{
var isValidReturnType = CheckTypesForAssignability(baseType, shimReturnType);
if (!isValidReturnType)
{
throw new InvalidShimSignatureException("Mismatched construction types");
}
}

if (!isStaticOrConstructor)
{
if (isValueType && !shimOwningType.IsByRef)
Expand All @@ -73,10 +90,10 @@ public static void ValidateReplacementMethodSignature(MethodBase original, Metho
if ((isValueType && !isStaticOrConstructor ? validOwningType.MakeByRefType() : validOwningType) != shimOwningType)
throw new InvalidShimSignatureException("Mismatched instance types");

if (validParameterTypes.Count() != shimParameterTypes.Count())
if (validParameterTypes.Count != shimParameterTypes.Count)
throw new InvalidShimSignatureException("Parameters count do not match");

for (int i = 0; i < validParameterTypes.Count(); i++)
for (int i = 0; i < validParameterTypes.Count; i++)
{
if (validParameterTypes.ElementAt(i) != shimParameterTypes.ElementAt(i))
throw new InvalidShimSignatureException($"Parameter types at {i} do not match");
Expand Down Expand Up @@ -126,5 +143,32 @@ private static void EnsureInstanceNotValueType(object instance)
if (instance.GetType().IsSubclassOf(typeof(ValueType)))
throw new NotSupportedException("You cannot replace methods on specific value type instances");
}

private static bool CheckTypesForAssignability(Type baseType, Type typeToCheck)
{
var stype = typeToCheck;
while (stype != typeof(object))
{
if (baseType.IsAssignableFrom(stype))
{
return true;
}

stype = stype.BaseType;
}

stype = baseType.BaseType;
while (stype != typeof(object))
{
if (stype.IsAssignableFrom(typeToCheck))
{
return true;
}

stype = stype.BaseType;
}

return false;
}
}
}
56 changes: 14 additions & 42 deletions src/Pose/Shim.cs
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
using System;
using System.Linq;
using System.Linq.Expressions;
using System.Reflection;

Expand All @@ -9,69 +8,42 @@ namespace Pose
{
public partial class Shim
{
private MethodBase _original;
private Delegate _replacement;
private Object _instance;
private Type _type;
private Type _baseType;
private bool _setter;

internal MethodBase Original
{
get
{
return _original;
}
}
internal MethodBase Original { get; }

internal Delegate Replacement
{
get
{
return _replacement;
}
}
internal Delegate Replacement { get; private set; }

internal Object Instance
{
get
{
return _instance;
}
}
internal Object Instance { get; }

internal Type Type
{
get
{
return _type;
}
}
internal Type Type { get; }

private Shim(MethodBase original, object instanceOrType)
{
_original = original;
Original = original ?? throw new ArgumentNullException(nameof(original));
if (instanceOrType is Type type)
_type = type;
Type = type;
else
_instance = instanceOrType;
Instance = instanceOrType;
}

public static Shim Replace(Expression<Action> expression, bool setter = false)
=> ReplaceImpl(expression, setter);
=> ReplaceImpl(expression, setter, null);

public static Shim Replace<T>(Expression<Func<T>> expression, bool setter = false)
=> ReplaceImpl(expression, setter);
=> ReplaceImpl(expression, setter, typeof(T));

private static Shim ReplaceImpl<T>(Expression<T> expression, bool setter)
private static Shim ReplaceImpl<T>(Expression<T> expression, bool setter, Type baseType)
{
MethodBase methodBase = ShimHelper.GetMethodFromExpression(expression.Body, setter, out object instance);
return new Shim(methodBase, instance) { _setter = setter };
return new Shim(methodBase, instance) { _setter = setter, _baseType = baseType};
}

private Shim WithImpl(Delegate replacement)
{
_replacement = replacement;
ShimHelper.ValidateReplacementMethodSignature(this._original, this._replacement.Method, _instance?.GetType() ?? _type, _setter);
Replacement = replacement;
ShimHelper.ValidateReplacementMethodSignature(this.Original, this.Replacement.Method, Instance?.GetType() ?? Type, _setter, _baseType);
return this;
}
}
Expand Down
94 changes: 87 additions & 7 deletions test/Pose.Tests/ShimTests.cs
Original file line number Diff line number Diff line change
@@ -1,15 +1,11 @@
using System;
using System.Collections.Generic;
using System.Globalization;
using System.Linq.Expressions;
using System.Reflection;
using System.Threading;

using Pose.Exceptions;
using Pose.Helpers;
using Microsoft.VisualStudio.TestTools.UnitTesting;

using static System.Console;

namespace Pose.Tests
{
[TestClass]
Expand Down Expand Up @@ -81,8 +77,7 @@ public void TestReplacePropertySetter()
Assert.AreEqual(typeof(Thread).GetProperty(nameof(Thread.CurrentCulture), typeof(CultureInfo)).SetMethod, shim.Original);
Assert.IsNull(shim.Replacement);
}



[TestMethod]
public void TestReplacePropertySetterAction()
{
Expand Down Expand Up @@ -114,5 +109,90 @@ public void TestReplacePropertySetterAction()
Assert.IsTrue(getterExecuted, "Getter not executed");
Assert.IsTrue(setterExecuted, "Setter not executed");
}

[TestMethod]
public void TestReplaceConstructor()
{
var dummyShim = Shim.Replace<DummyForConstructorBase>(() => new DummyForConstructor()).With(() => new DummyForConstructorReplacment());
bool? wasCalled = false;

PoseContext.Isolate(() =>
{
wasCalled = new DummyHolder().Dummy.ConstructorCalled;
}, dummyShim);

Assert.IsNotNull(wasCalled);
Assert.IsFalse(wasCalled.Value);
}

[TestMethod]
public void TestReplaceConstructorWithoutGeneric()
{
var dummyShim = Shim.Replace(() => new DummyForConstructor()).With(() => new DummyForConstructorReplacment());
bool? wasCalled = false;

PoseContext.Isolate(() =>
{
wasCalled = new DummyHolder().Dummy.ConstructorCalled;
}, dummyShim);

Assert.IsNotNull(wasCalled);
Assert.IsFalse(wasCalled.Value);
}

[TestMethod]
public void TestReplaceConstructorWithoutGeneric2()
{
var dummyShim = Shim.Replace(() => new DummyForConstructorReplacment()).With(() => new DummyForConstructorReplacment(true));

PoseContext.Isolate(() =>
{
var unused = new DummyForConstructorReplacment();
}, dummyShim);
}

[TestMethod]
public void TestReplaceConstructorWithIncorrectTypes()
{
List<string> Replacement() => new List<string>();
var exception = Assert.ThrowsException<InvalidShimSignatureException>(() => Shim.Replace<DummyForConstructorBase>(() => new DummyForConstructor()).With(Replacement));
Assert.AreEqual("Mismatched construction types", exception.Message);
}

private class DummyHolder
{
public DummyHolder()
{
Dummy = new DummyForConstructor();
}

public DummyForConstructorBase Dummy { get; }
}

private abstract class DummyForConstructorBase
{
public bool? ConstructorCalled { get; protected set; }
}

private class DummyForConstructor : DummyForConstructorBase
{
public DummyForConstructor()
{
ConstructorCalled = true;
}
}

private class DummyForConstructorReplacment : DummyForConstructorBase
{
public DummyForConstructorReplacment()
{
ConstructorCalled = false;
}

public DummyForConstructorReplacment(bool dummyFlag)
{
ConstructorCalled = false;
}
}
}
}