Skip to content

Commit

Permalink
Added unit tests for a few known ysoserial.net exploits.
Browse files Browse the repository at this point in the history
  • Loading branch information
yallie committed Apr 25, 2018
1 parent a322e83 commit fa02939
Show file tree
Hide file tree
Showing 8 changed files with 232 additions and 55 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -52,8 +52,10 @@
<Compile Include="Properties\AssemblyInfo.cs" />
<Compile Include="SafeSerializationBinderTests.cs" />
<Compile Include="TestBase.cs" />
<Compile Include="TestBaseTests.cs" />
<Compile Include="TypeFullNameTests.cs" />
<Compile Include="TypeNameValidatorTests.cs" />
<Compile Include="YsoserialGadgetTests.cs" />
</ItemGroup>
<ItemGroup>
<ProjectReference Include="..\SafeDeserializationHelpers\SafeDeserializationHelpers.csproj">
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,10 @@ public void SafeSerializationBinderDoesntBreakNormalClasses()
new Hashtable { { "Hello", "World" } },
new List<Delegate> { func, action, action, func },
new PublicSerializable { X = 123 },
new PrivateSerializable { Y = "Hello" }
new PrivateSerializable { Y = "Hello" },
new DataSet[] { ds, null, null, ds, ds },
new List<DataSet> { null, ds, null },
new Dictionary<string, DataSet> { { ds.DataSetName, ds } }
};

// make sure that the round-trip doesn't damage any of them
Expand Down
74 changes: 21 additions & 53 deletions SafeDeserializationHelpers.Tests/TestBase.cs
Original file line number Diff line number Diff line change
@@ -1,66 +1,16 @@
using System;
using System.Collections;
using System.Data;
using System.Diagnostics;
using System.IO;
using System.Linq;
using System.Management.Automation;
using System.Runtime.Serialization.Formatters.Binary;
using System.Security;
using Microsoft.VisualStudio.TestTools.UnitTesting;

namespace SafeDeserializationHelpers.Tests
{
[TestClass]
public class TestBase
{
[TestMethod, ExpectedException(typeof(AssertFailedException))]
public void AssertDoesntThrowFailsIfExceptionWasThrown()
{
Assert_DoesNotThrow(() => { throw new SecurityException(); });
}

[TestMethod, ExpectedException(typeof(AssertFailedException))]
public void AssertThrowsFailsIfNoExceptionWasThrown()
{
Assert_Throws<SecurityException>(() => { });
}

[TestMethod, ExpectedException(typeof(AssertFailedException))]
public void AssertThrowsFailsIfUnexpectedExceptionWasThrown()
{
Assert_Throws<SecurityException>(() => { new ArgumentNullException(); });
}

[TestMethod]
public void AssertAreEqualWorksOnLotsOfDataTypes()
{
Assert_DoesNotThrow(() =>
{
Assert_AreEqual(1, 1);
Assert_AreEqual("Hello", "Hello");
Assert_AreEqual(new DataTable("Test"), new DataTable("Test"));
Assert_AreEqual(new Func<string, string, Process>(Process.Start), new Func<string, string, Process>(Process.Start));
});

Assert_Throws<AssertFailedException>(() =>
{
Assert_AreEqual(1, 2);
});
}

[TestMethod]
public void RoundtripIsExpectedToSucceedOnPrimitives()
{
Assert_DoesNotThrow(() =>
{
Roundtrip(1, false);
Roundtrip(1, true);
Roundtrip("Hello", false);
Roundtrip("Hello", true);
});
}

protected void Roundtrip(object graph, bool useBinder)
{
var data = default(byte[]);
Expand Down Expand Up @@ -167,6 +117,24 @@ protected void Assert_AreEqual(object expected, object actual, string msg = "Two
return;
}

if (expected != null && actual != null &&
expected.GetType().IsValueType && !expected.GetType().IsPrimitive &&
actual.GetType().IsValueType && !actual.GetType().IsPrimitive)
{
var expectedType = expected.GetType();
var actualType = actual.GetType();
Assert_AreEqual(expectedType, actualType, msg);

foreach (var prop in expectedType.GetProperties())
{
var exp = prop.GetValue(expected);
var act = prop.GetValue(actual);
Assert_AreEqual(exp, act, msg);
}

return;
}

Assert.AreEqual(expected, actual, msg);
}

Expand All @@ -182,7 +150,7 @@ protected void Assert_DoesNotThrow(Action action)
}
}

protected void Assert_Throws<T>(Action action) where T : Exception
protected void Assert_Throws<T>(Action action, string msg = null) where T : Exception
{
try
{
Expand All @@ -194,10 +162,10 @@ protected void Assert_Throws<T>(Action action) where T : Exception
}
catch (Exception ex)
{
Assert.Fail($"Expected to catch exception of type {typeof(T).Name}, but here is what we caught: {ex.ToString()}");
Assert.Fail(msg ?? $"Expected to catch exception of type {typeof(T).Name}, but here is what we caught: {ex.ToString()}");
}

Assert.Fail($"Expected to catch exception of type {typeof(T).Name}, but no exception was thrown.");
Assert.Fail(msg ?? $"Expected to catch exception of type {typeof(T).Name}, but no exception was thrown.");
}
}
}
63 changes: 63 additions & 0 deletions SafeDeserializationHelpers.Tests/TestBaseTests.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,63 @@
using System;
using System.Collections.Generic;
using System.Data;
using System.Diagnostics;
using System.Linq;
using System.Security;
using System.Text;
using System.Threading.Tasks;
using Microsoft.VisualStudio.TestTools.UnitTesting;

namespace SafeDeserializationHelpers.Tests
{
[TestClass]
public class TestBaseTests : TestBase
{
[TestMethod, ExpectedException(typeof(AssertFailedException))]
public void AssertDoesntThrowFailsIfExceptionWasThrown()
{
Assert_DoesNotThrow(() => { throw new SecurityException(); });
}

[TestMethod, ExpectedException(typeof(AssertFailedException))]
public void AssertThrowsFailsIfNoExceptionWasThrown()
{
Assert_Throws<SecurityException>(() => { });
}

[TestMethod, ExpectedException(typeof(AssertFailedException))]
public void AssertThrowsFailsIfUnexpectedExceptionWasThrown()
{
Assert_Throws<SecurityException>(() => { new ArgumentNullException(); });
}

[TestMethod]
public void AssertAreEqualWorksOnLotsOfDataTypes()
{
Assert_DoesNotThrow(() =>
{
Assert_AreEqual(1, 1);
Assert_AreEqual("Hello", "Hello");
Assert_AreEqual(new DataTable("Test"), new DataTable("Test"));
Assert_AreEqual(new Func<string, string, Process>(Process.Start), new Func<string, string, Process>(Process.Start));
});

Assert_Throws<AssertFailedException>(() =>
{
Assert_AreEqual(1, 2);
});
}

[TestMethod]
public void RoundtripIsExpectedToSucceedOnPrimitives()
{
Assert_DoesNotThrow(() =>
{
Roundtrip(1, false);
Roundtrip(1, true);
Roundtrip("Hello", false);
Roundtrip("Hello", true);
});
}
}
}
45 changes: 45 additions & 0 deletions SafeDeserializationHelpers.Tests/YsoserialGadgetTests.cs

Large diffs are not rendered by default.

76 changes: 76 additions & 0 deletions SafeDeserializationHelpers/CustomDataSetDeserializer.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,76 @@
namespace SafeDeserializationHelpers
{
using System;
using System.Collections.Generic;
using System.Data;
using System.Linq;
using System.Reflection;
using System.Runtime.Serialization;
using System.Security.Permissions;
using System.Text;

/// <summary>
/// Custom <see cref="DataSet"/> descendant controlling the deserialization.
/// </summary>
[Serializable]
public sealed class CustomDataSetDeserializer : DataSet, IObjectReference
{
/// <summary>
/// Initializes a new instance of the <see cref="CustomDataSetDeserializer"/> class.
/// </summary>
/// <param name="info">Serialization info.</param>
/// <param name="context">Streaming context</param>
private CustomDataSetDeserializer(SerializationInfo info, StreamingContext context)
{
Info = info;
Context = context;
}

private static ConstructorInfo Constructor { get; } = typeof(DataSet).GetConstructor(
BindingFlags.Instance | BindingFlags.NonPublic,
null,
new[] { typeof(SerializationInfo), typeof(StreamingContext) },
null);

private SerializationInfo Info { get; }

private StreamingContext Context { get; }

/// <inheritdoc cref="IObjectReference" />
[SecurityPermission(SecurityAction.LinkDemand, Flags = SecurityPermissionFlag.SerializationFormatter)]
public object GetRealObject(StreamingContext context)
{
Validate(Info);
return Constructor.Invoke(new object[] { Info, context });
}

private void Validate(SerializationInfo info)
{
var remotingFormat = SerializationFormat.Xml;
var schemaSerializationMode = SchemaSerializationMode.IncludeSchema;

var e = info.GetEnumerator();
while (e.MoveNext())
{
switch (e.Name)
{
case "DataSet.RemotingFormat": // DataSet.RemotingFormat does not exist in V1/V1.1 versions
remotingFormat = (SerializationFormat)e.Value;
break;

case "SchemaSerializationMode.DataSet": // SchemaSerializationMode.DataSet does not exist in V1/V1.1 versions
schemaSerializationMode = (SchemaSerializationMode)e.Value;
break;
}
}

if (remotingFormat == SerializationFormat.Xml)
{
// XML dataset serialization isn't vulnerable
return;
}

throw new UnsafeDeserializationException("Serialized DataSet probably includes malicious data.");
}
}
}
3 changes: 3 additions & 0 deletions SafeDeserializationHelpers/SafeDeserializationHelpers.csproj
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,9 @@
<Reference Include="System.Xml" />
</ItemGroup>
<ItemGroup>
<Compile Include="CustomDataSetDeserializer.cs">
<SubType>Component</SubType>
</Compile>
<Compile Include="CustomDelegateSerializationHolder.cs" />
<Compile Include="DelegateValidator.cs" />
<Compile Include="GlobalSuppressions.cs" />
Expand Down
19 changes: 18 additions & 1 deletion SafeDeserializationHelpers/SafeSerializationBinder.cs
Original file line number Diff line number Diff line change
Expand Up @@ -11,11 +11,21 @@ public class SafeSerializationBinder : SerializationBinder
/// </summary>
public const string CoreLibraryAssemblyName = "mscorlib";

/// <summary>
/// System.Data assembly name.
/// </summary>
public const string SystemDataAssemblyName = "System.Data";

/// <summary>
/// System.DelegateSerializationHolder type name.
/// </summary>
public const string DelegateSerializationHolderTypeName = "System.DelegateSerializationHolder";

/// <summary>
/// System.Data.DataSet type name.
/// </summary>
public const string DataSetTypeName = "System.Data.DataSet";

/// <summary>
/// Initializes a new instance of the <see cref="SafeSerializationBinder"/> class.
/// </summary>
Expand All @@ -30,13 +40,20 @@ public SafeSerializationBinder(SerializationBinder nextBinder = null)
/// <inheritdoc cref="SerializationBinder" />
public override Type BindToType(string assemblyName, string typeName)
{
// prevent delegate serialization attack
// prevent delegate deserialization attack
if (typeName == DelegateSerializationHolderTypeName &&
assemblyName.StartsWith(CoreLibraryAssemblyName, StringComparison.InvariantCultureIgnoreCase))
{
return typeof(CustomDelegateSerializationHolder);
}

////// prevent DataSet-based deserialization attack
////if (typeName == DataSetTypeName &&
//// assemblyName.StartsWith(SystemDataAssemblyName, StringComparison.InvariantCultureIgnoreCase))
////{
//// return typeof(CustomDataSetDeserializer);
////}

// suppress known blacklisted types
TypeNameValidator.Default.ValidateTypeName(assemblyName, typeName);

Expand Down

0 comments on commit fa02939

Please sign in to comment.