Skip to content

Commit

Permalink
Added CustomDelegateSerializationHolder and SafeSerializationBinder.
Browse files Browse the repository at this point in the history
  • Loading branch information
yallie committed Apr 25, 2018
1 parent 288224c commit 73bfede
Show file tree
Hide file tree
Showing 8 changed files with 292 additions and 10 deletions.
12 changes: 6 additions & 6 deletions SafeDeserializationHelpers.Tests/DelegateValidatorTests.cs
Original file line number Diff line number Diff line change
Expand Up @@ -15,35 +15,35 @@ public class DelegateValidatorTests
public void NullDelegateIsValid()
{
// Assert.DoesNotThrow
new DelegateValidator().ValidateDelegate(null);
DelegateValidator.Default.ValidateDelegate(null);
}

[TestMethod]
public void DelegateIsValidUnlessBlacklisted()
{
new DelegateValidator().ValidateDelegate(new Action<int>(x => { }));
DelegateValidator.Default.ValidateDelegate(new Action<int>(x => { }));
}

[TestMethod, ExpectedException(typeof(UnsafeDeserializationException))]
public void SystemDiagnosticsDelegatesAreNotValid()
{
var del = new Func<string, string, Process>(Process.Start);
new DelegateValidator().ValidateDelegate(del);
DelegateValidator.Default.ValidateDelegate(del);
}

[TestMethod, ExpectedException(typeof(UnsafeDeserializationException))]
public void SystemIODelegatesAreNotValid()
{
var del = new Action<string>(File.Delete);
new DelegateValidator().ValidateDelegate(del);
DelegateValidator.Default.ValidateDelegate(del);
}

[TestMethod]
public void MulticastDelegatesAreValidated()
{
var del = new Func<string, string, Process>((a, b) => null);
del = Delegate.Combine(del, del, del) as Func<string, string, Process>;
new DelegateValidator().ValidateDelegate(del);
DelegateValidator.Default.ValidateDelegate(del);
}

[TestMethod, ExpectedException(typeof(UnsafeDeserializationException))]
Expand All @@ -52,7 +52,7 @@ public void MulticastDelegatesWithSystemDiagnosticsMethodsAreNotValid()
var del = new Func<string, string, Process>((a, b) => null);
var start = new Func<string, string, Process>(Process.Start);
del = Delegate.Combine(del, del, start, del, del) as Func<string, string, Process>;
new DelegateValidator().ValidateDelegate(del);
DelegateValidator.Default.ValidateDelegate(del);
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
<AppDesignerFolder>Properties</AppDesignerFolder>
<RootNamespace>SafeDeserializationHelpers.Tests</RootNamespace>
<AssemblyName>SafeDeserializationHelpers.Tests</AssemblyName>
<TargetFrameworkVersion>v4.5</TargetFrameworkVersion>
<TargetFrameworkVersion>v4.7</TargetFrameworkVersion>
<FileAlignment>512</FileAlignment>
<TargetFrameworkProfile />
</PropertyGroup>
Expand All @@ -22,6 +22,7 @@
<ErrorReport>prompt</ErrorReport>
<WarningLevel>4</WarningLevel>
<Prefer32Bit>false</Prefer32Bit>
<LangVersion>latest</LangVersion>
</PropertyGroup>
<PropertyGroup Condition=" '$(Configuration)|$(Platform)' == 'Release|AnyCPU' ">
<DebugType>pdbonly</DebugType>
Expand All @@ -44,6 +45,7 @@
<ItemGroup>
<Compile Include="DelegateValidatorTests.cs" />
<Compile Include="Properties\AssemblyInfo.cs" />
<Compile Include="SafeSerializationBinderTests.cs" />
</ItemGroup>
<ItemGroup>
<ProjectReference Include="..\SafeDeserializationHelpers\SafeDeserializationHelpers.csproj">
Expand Down
176 changes: 176 additions & 0 deletions SafeDeserializationHelpers.Tests/SafeSerializationBinderTests.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,176 @@
using System;
using System.Collections;
using System.Collections.Generic;
using System.Data;
using System.Diagnostics;
using System.IO;
using System.Linq;
using System.Runtime.Serialization.Formatters.Binary;
using System.Text;
using System.Threading.Tasks;
using Microsoft.VisualStudio.TestTools.UnitTesting;

namespace SafeDeserializationHelpers.Tests
{
[TestClass]
public class SafeSerializationBinderTests
{
private void Roundtrip(object graph, bool useBinder)
{
var data = default(byte[]);
var fmt = new BinaryFormatter();
using (var stream = new MemoryStream())
{
fmt.Serialize(stream, graph);
data = stream.ToArray();
}

if (useBinder)
{
fmt.Binder = new SafeSerializationBinder(fmt.Binder);
}

using (var stream = new MemoryStream(data))
{
var deserialized = fmt.Deserialize(stream);
var msg = $"Deserialized data doesn't match when {(useBinder ? string.Empty : "not ")}using binder.";
AssertAreEqual(graph, deserialized, msg);
}
}

private void AssertAreEqual(object expected, object actual, string msg)
{
if (expected is Delegate del1 && actual is Delegate del2)
{
AssertAreEqual(del1.Target, del2.Target, msg);
AssertAreEqual(del1.Method, del2.Method, msg);
return;
}

if (expected is string s1 && actual is string s2)
{
// avoid comparing strings as IEnumerables
Assert.AreEqual(s1, s2, msg);
return;
}

if (expected is IEnumerable enum1 && actual is IEnumerable enum2)
{
Assert.AreEqual(enum1.OfType<object>().Count(), enum2.OfType<object>().Count(), msg);
foreach (var item in enum1.OfType<object>().Zip(enum2.OfType<object>(), (e1, e2) => (e1, e2)))
{
AssertAreEqual(item.e1, item.e2, msg);
}

return;
}

if (expected is IDictionary dic1 && actual is IDictionary dic2)
{
Assert.AreEqual(dic1.Count, dic2.Count, msg);
foreach (var item in dic1.OfType<object>().Zip(dic2.OfType<object>(), (e1, e2) => (e1, e2)))
{
AssertAreEqual(item.e1, item.e2, msg);
}

return;
}

if (expected is DataTable dt1 && actual is DataTable dt2)
{
AssertAreEqual(dt1.TableName, dt2.TableName, msg);
AssertAreEqual(dt1.Columns, dt2.Columns, msg);
AssertAreEqual(dt1.Rows, dt2.Rows, msg);
return;
}

if (expected is DataColumn dc1 && actual is DataColumn dc2)
{
AssertAreEqual(dc1.ColumnName, dc2.ColumnName, msg);
AssertAreEqual(dc1.DataType, dc2.DataType, msg);
return;
}

if (expected is DataRow dr1 && actual is DataRow dr2)
{
AssertAreEqual(dr1.ItemArray, dr2.ItemArray, msg);
return;
}

Assert.AreEqual(expected, actual, msg);
}

[Serializable]
public class PublicSerializable
{
public int X { get; set; }
public override int GetHashCode() => X.GetHashCode();
public override bool Equals(object obj)
{
if (obj is PublicSerializable p) return p.X == X;
return false;
}
}

[Serializable]
private class PrivateSerializable
{
public string Y { get; set; }
public static void SampleMethod(int a, string b, DateTime c) { }
public override int GetHashCode() => Y.GetHashCode();
public override bool Equals(object obj)
{
if (obj is PrivateSerializable p) return p.Y == Y;
return false;
}
}

[TestMethod]
public void SafeSerializationBinderDoesntBreakNormalClasses()
{
// prepare sample data: delegates
var func = new Func<object, bool>(new PublicSerializable().Equals);
var action = new Action<int, string, DateTime>(PrivateSerializable.SampleMethod);

// data tables
var dt = new DataTable("TestTable");
dt.Columns.Add("ID", typeof(long));
dt.Columns.Add("Name", typeof(string));
var row = dt.NewRow();
row["ID"] = 123;
row["Name"] = "432";
dt.Rows.Add(row);

// scalar values, collections and dictionaries
var samples = new object[]
{
1, "Test", func, action, dt,
new List<string> { "abc", "def" },
new Dictionary<int, char> { { 1, 'a' }, { 2, 'b'} },
new Hashtable { { "Hello", "World" } },
new List<Delegate> { func, action, action, func },
new PublicSerializable { X = 123 },
new PrivateSerializable { Y = "Hello" }
};

// make sure that the round-trip doesn't damage any of them
foreach (var sample in samples)
{
Roundtrip(sample, false);
Roundtrip(sample, true);
}
}

[TestMethod]
public void OrdinaryBinaryFormatterDoesntBreakOnProcessStartDelegate()
{
Roundtrip(new Func<string, string, Process>(Process.Start), false);
}

[TestMethod, ExpectedException(typeof(UnsafeDeserializationException))]
public void SafeSerializationBinderBreaksOnProcessStartDelegate()
{
Roundtrip(new Func<string, string, Process>(Process.Start), true);
}
}
}
54 changes: 54 additions & 0 deletions SafeDeserializationHelpers/CustomDelegateSerializationHolder.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,54 @@
namespace SafeDeserializationHelpers
{
using System;
using System.Collections.Generic;
using System.Linq;
using System.Reflection;
using System.Runtime.Serialization;
using System.Text;

/// <summary>
/// Custom replacement for the DelegateSerializationHolder featuring delegate validation.
/// </summary>
[Serializable]
public class CustomDelegateSerializationHolder : ISerializable, IObjectReference
{
/// <summary>
/// Initializes a new instance of the <see cref="CustomDelegateSerializationHolder"/> class.
/// </summary>
/// <param name="info">Serialization info.</param>
/// <param name="context">Streaming context</param>
protected CustomDelegateSerializationHolder(SerializationInfo info, StreamingContext context)
{
Holder = (IObjectReference)Constructor.Invoke(new object[] { info, context });
}

private static Type DelegateSerializationHolderType { get; } = Type.GetType(SafeSerializationBinder.DelegateSerializationHolderTypeName);

private static ConstructorInfo Constructor { get; } = DelegateSerializationHolderType.GetConstructor(
BindingFlags.Instance | BindingFlags.NonPublic,
null,
new[] { typeof(SerializationInfo), typeof(StreamingContext) },
null);

private IObjectReference Holder { get; set; }

/// <inheritdoc cref="ISerializable" />
public void GetObjectData(SerializationInfo info, StreamingContext context)
{
throw new NotSupportedException();
}

/// <inheritdoc cref="IObjectReference" />
public object GetRealObject(StreamingContext context)
{
var result = Holder.GetRealObject(context);
if (result is Delegate del)
{
DelegateValidator.Default.ValidateDelegate(del);
}

return result;
}
}
}
5 changes: 5 additions & 0 deletions SafeDeserializationHelpers/DelegateValidator.cs
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,11 @@ public DelegateValidator(params string[] blacklistedNamespaces)
BlacklistedNamespaces = new HashSet<string>(blacklistedNamespaces, StringComparer.OrdinalIgnoreCase);
}

/// <summary>
/// Gets or sets the default <see cref="DelegateValidator"/> instance.
/// </summary>
public static DelegateValidator Default { get; set; } = new DelegateValidator();

private HashSet<string> BlacklistedNamespaces { get; }

/// <summary>
Expand Down
4 changes: 1 addition & 3 deletions SafeDeserializationHelpers/GlobalSuppressions.cs
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,5 @@
// Project-level suppressions either have no target or are given
// a specific target and scoped to a namespace, type, member, etc.

[assembly: System.Diagnostics.CodeAnalysis.SuppressMessage("StyleCop.CSharp.ReadabilityRules", "SA1101:Prefix local calls with this", Justification = "This is a visual garbage", Scope = "member", Target = "~M:SafeDeserializationHelpers.DelegateValidator.#ctor(System.String[])")]
[assembly: System.Diagnostics.CodeAnalysis.SuppressMessage("StyleCop.CSharp.ReadabilityRules", "SA1101:Prefix local calls with this", Justification = "This is a visual garbage")]
[assembly: System.Diagnostics.CodeAnalysis.SuppressMessage("StyleCop.CSharp.DocumentationRules", "SA1633:File should have header", Justification = "Not necessary for this project")]
[assembly: System.Diagnostics.CodeAnalysis.SuppressMessage("StyleCop.CSharp.ReadabilityRules", "SA1101:Prefix local calls with this", Justification = "<Pending>", Scope = "member", Target = "~M:SafeDeserializationHelpers.DelegateValidator.ValidateDelegate(System.Delegate)")]

2 changes: 2 additions & 0 deletions SafeDeserializationHelpers/SafeDeserializationHelpers.csproj
Original file line number Diff line number Diff line change
Expand Up @@ -39,8 +39,10 @@
<Reference Include="System.Xml" />
</ItemGroup>
<ItemGroup>
<Compile Include="CustomDelegateSerializationHolder.cs" />
<Compile Include="DelegateValidator.cs" />
<Compile Include="GlobalSuppressions.cs" />
<Compile Include="SafeSerializationBinder.cs" />
<Compile Include="UnsafeDeserializationException.cs" />
<Compile Include="Properties\AssemblyInfo.cs" />
</ItemGroup>
Expand Down
45 changes: 45 additions & 0 deletions SafeDeserializationHelpers/SafeSerializationBinder.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,45 @@
namespace SafeDeserializationHelpers
{
using System;
using System.Collections.Generic;
using System.Linq;
using System.Runtime.Serialization;
using System.Text;

/// <inheritdoc cref="SerializationBinder" />
public class SafeSerializationBinder : SerializationBinder
{
/// <summary>
/// Core library assembly name.
/// </summary>
public const string CoreLibraryAssemblyName = "mscorlib";

/// <summary>
/// System.DelegateSerializationHolder type name.
/// </summary>
public const string DelegateSerializationHolderTypeName = "System.DelegateSerializationHolder";

/// <summary>
/// Initializes a new instance of the <see cref="SafeSerializationBinder"/> class.
/// </summary>
/// <param name="nextBinder">Next serialization binder in chain.</param>
public SafeSerializationBinder(SerializationBinder nextBinder = null)
{
NextBinder = nextBinder;
}

private SerializationBinder NextBinder { get; }

/// <inheritdoc cref="SerializationBinder" />
public override Type BindToType(string assemblyName, string typeName)
{
if (typeName == DelegateSerializationHolderTypeName &&
assemblyName.StartsWith(CoreLibraryAssemblyName, StringComparison.InvariantCultureIgnoreCase))
{
return typeof(CustomDelegateSerializationHolder);
}

return NextBinder?.BindToType(assemblyName, typeName);
}
}
}

0 comments on commit 73bfede

Please sign in to comment.