Skip to content

Commit

Permalink
Added protection against the new WindowsIdentity gadget:
Browse files Browse the repository at this point in the history
  • Loading branch information
yallie committed Apr 26, 2018
1 parent 7292308 commit 243fad1
Show file tree
Hide file tree
Showing 13 changed files with 203 additions and 27 deletions.
12 changes: 10 additions & 2 deletions SafeDeserializationHelpers.Tests/SafeSerializationBinderTests.cs
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
using System.Diagnostics;
using System.IO;
using System.Management.Automation;
using System.Security.Principal;
using Microsoft.VisualStudio.TestTools.UnitTesting;

namespace SafeDeserializationHelpers.Tests
Expand Down Expand Up @@ -59,19 +60,23 @@ public void SafeSerializationBinderDoesntBreakNormalClasses()
row["Value"] = "Hello";
rdt.Rows.Add(row);

// data sets
// data set in XML format
var ds = new DataSet("XmlTestData");
ds.Tables.Add(dt);
ds.RemotingFormat = SerializationFormat.Xml;

// data set in binary format
var rds = new DataSet("BinaryTestData");
rds.RemotingFormat = SerializationFormat.Binary;
rds.Tables.Add(rdt);

// identity
var id = WindowsIdentity.GetAnonymous();

// scalar values, collections and dictionaries
var samples = new object[]
{
1, "Test", func, action, dt, ds, rds,
1, "Test", func, action, dt, ds, rds, id,
new List<string> { "abc", "def" },
new Dictionary<int, char> { { 1, 'a' }, { 2, 'b'} },
new Hashtable { { "Hello", "World" } },
Expand All @@ -80,6 +85,9 @@ public void SafeSerializationBinderDoesntBreakNormalClasses()
new PrivateSerializable { Y = "Hello" },
new DataSet[] { ds, null, null, rds, ds },
new List<DataSet> { null, rds, ds },
new DataTable[] { dt, rdt, null, rdt, dt },
new List<DataTable> { rdt, null, dt },
new IIdentity[] { null, id, null },
new Dictionary<string, DataSet> { { ds.DataSetName, ds }, { rds.DataSetName, rds } },
};

Expand Down
38 changes: 33 additions & 5 deletions SafeDeserializationHelpers.Tests/TestBase.cs
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
using System.Management.Automation;
using System.Runtime.Serialization;
using System.Runtime.Serialization.Formatters.Binary;
using System.Security.Principal;
using Microsoft.VisualStudio.TestTools.UnitTesting;

namespace SafeDeserializationHelpers.Tests
Expand All @@ -14,24 +15,43 @@ public class TestBase
{
protected void Roundtrip(object graph, bool safe)
{
var data = default(byte[]);
// make sure that safe mode doesn't affect the serialization
var originalData = Serialize(graph, false);
var safeData = Serialize(graph, true);
Assert_AreEqual(originalData, safeData);

// make sure that deserialized graph is the same as the original
var deserialized = Deserialize(originalData, safe);
var msg = $"Deserialized data doesn't match when {(safe ? string.Empty : "not ")}using binder.";
Assert_AreEqual(graph, deserialized, msg);
}

private byte[] Serialize(object graph, bool safe)
{
var fmt = new BinaryFormatter();
if (safe)
{
fmt = fmt.Safe();
}

using (var stream = new MemoryStream())
{
fmt.Serialize(stream, graph);
data = stream.ToArray();
return stream.ToArray();
}
}

private object Deserialize(byte[] data, bool safe)
{
var fmt = new BinaryFormatter();
if (safe)
{
fmt = fmt.Safe();
}

using (var stream = new MemoryStream(data))
{
var deserialized = fmt.Deserialize(stream);
var msg = $"Deserialized data doesn't match when {(safe ? string.Empty : "not ")}using binder.";
Assert_AreEqual(graph, deserialized, msg);
return fmt.Deserialize(stream);
}
}

Expand Down Expand Up @@ -118,6 +138,14 @@ protected void Assert_AreEqual(object expected, object actual, string msg = "Two
return;
}

if (expected is IIdentity id1 && actual is IIdentity id2)
{
Assert_AreEqual(id1.Name, id2.Name, msg);
Assert_AreEqual(id1.IsAuthenticated, id2.IsAuthenticated, msg);
Assert_AreEqual(id1.AuthenticationType, id2.AuthenticationType, msg);
return;
}

if (expected != null && actual != null &&
expected.GetType().IsValueType && !expected.GetType().IsPrimitive &&
actual.GetType().IsValueType && !actual.GetType().IsPrimitive)
Expand Down
9 changes: 8 additions & 1 deletion SafeDeserializationHelpers.Tests/YsoserialGadgetTests.cs

Large diffs are not rendered by default.

25 changes: 16 additions & 9 deletions SafeDeserializationHelpers/BinaryFormatterExtensions.cs
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
namespace SafeDeserializationHelpers
{
using System.Data;
using System.Runtime.Serialization;
using System;
using System.Runtime.Serialization.Formatters.Binary;

/// <summary>
Expand All @@ -16,16 +15,24 @@ public static class BinaryFormatterExtensions
/// <returns>The safe version of the <see cref="BinaryFormatter"/>.</returns>
public static BinaryFormatter Safe(this BinaryFormatter fmt)
{
if (fmt == null)
{
throw new ArgumentNullException(nameof(fmt), "BinaryFormatter is not specified.");
}

// safe type binder prevents delegate deserialization attacks
var binder = new SafeSerializationBinder(fmt.Binder);
fmt.Binder = binder;
if (!(fmt.Binder is SafeSerializationBinder))
{
fmt.Binder = new SafeSerializationBinder(fmt.Binder);
}

// DataSet surrogate validates binary-serialized datasets
var ss = new SurrogateSelector();
ss.AddSurrogate(typeof(DataSet), new StreamingContext(StreamingContextStates.All), new DataSetSurrogate());
fmt.SurrogateSelector = ss;
// surrogates validate binary-serialized data before deserializing them
if (!(fmt.SurrogateSelector is SafeSurrogateSelector))
{
// create a new surrogate selector and chain to the existing one, if any
fmt.SurrogateSelector = new SafeSurrogateSelector(fmt.SurrogateSelector);
}

// TODO: do we need to chain surrogate selectors?
return fmt;
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
/// Custom replacement for the DelegateSerializationHolder featuring delegate validation.
/// </summary>
[Serializable]
public sealed class CustomDelegateSerializationHolder : ISerializable, IObjectReference
internal sealed class CustomDelegateSerializationHolder : ISerializable, IObjectReference
{
/// <summary>
/// Initializes a new instance of the <see cref="CustomDelegateSerializationHolder"/> class.
Expand Down
2 changes: 1 addition & 1 deletion SafeDeserializationHelpers/DataSetSurrogate.cs
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
/// <summary>
/// Deserialization surrogate for the DataSet class.
/// </summary>
public class DataSetSurrogate : ISerializationSurrogate
internal class DataSetSurrogate : ISerializationSurrogate
{
private static ConstructorInfo Constructor { get; } = typeof(DataSet).GetConstructor(
BindingFlags.Instance | BindingFlags.NonPublic,
Expand Down
4 changes: 1 addition & 3 deletions SafeDeserializationHelpers/Properties/AssemblyInfo.cs
Original file line number Diff line number Diff line change
Expand Up @@ -29,8 +29,6 @@
// Build Number
// Revision
//
// You can specify all the values or you can default the Build and Revision Numbers
// by using the '*' as shown below:
// [assembly: AssemblyVersion("1.0.*")]
[assembly: AssemblyVersion("1.0.0.0")]
[assembly: AssemblyFileVersion("1.0.0.0")]
[assembly: InternalsVisibleTo("SafeDeserializationHelpers.Tests")]
2 changes: 2 additions & 0 deletions SafeDeserializationHelpers/SafeDeserializationHelpers.csproj
Original file line number Diff line number Diff line change
Expand Up @@ -50,10 +50,12 @@
<Compile Include="IDelegateValidator.cs" />
<Compile Include="ITypeNameValidator.cs" />
<Compile Include="SafeSerializationBinder.cs" />
<Compile Include="SafeSurrogateSelector.cs" />
<Compile Include="TypeFullName.cs" />
<Compile Include="TypeNameValidator.cs" />
<Compile Include="UnsafeDeserializationException.cs" />
<Compile Include="Properties\AssemblyInfo.cs" />
<Compile Include="WindowsIdentitySurrogate.cs" />
</ItemGroup>
<ItemGroup>
<None Include="packages.config" />
Expand Down
2 changes: 1 addition & 1 deletion SafeDeserializationHelpers/SafeSerializationBinder.cs
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
using System.Runtime.Serialization;

/// <inheritdoc cref="SerializationBinder" />
public class SafeSerializationBinder : SerializationBinder
internal sealed class SafeSerializationBinder : SerializationBinder
{
/// <summary>
/// Core library assembly name.
Expand Down
58 changes: 58 additions & 0 deletions SafeDeserializationHelpers/SafeSurrogateSelector.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,58 @@
namespace SafeDeserializationHelpers
{
using System;
using System.Data;
using System.Runtime.Serialization;
using System.Security.Permissions;
using System.Security.Principal;

/// <summary>
/// Safe surrogate selector provides surrogates for DataSet and WindowsIdentity classes.
/// </summary>
internal sealed class SafeSurrogateSelector : ISurrogateSelector
{
/// <summary>
/// Initializes a new instance of the <see cref="SafeSurrogateSelector"/> class.
/// </summary>
/// <param name="nextSelector">Next <see cref="ISurrogateSelector"/>, optional.</param>
[SecurityPermission(SecurityAction.LinkDemand, Flags = SecurityPermissionFlag.SerializationFormatter)]
public SafeSurrogateSelector(ISurrogateSelector nextSelector = null)
{
if (nextSelector != null)
{
SurrogateSelector.ChainSelector(nextSelector);
}

// register known surrogates for all streaming contexts
var ctx = new StreamingContext(StreamingContextStates.All);
SurrogateSelector.AddSurrogate(typeof(DataSet), ctx, new DataSetSurrogate());
SurrogateSelector.AddSurrogate(typeof(WindowsIdentity), ctx, new WindowsIdentitySurrogate());
}

private SurrogateSelector SurrogateSelector { get; } = new SurrogateSelector();

/// <inheritdoc cref="ISurrogateSelector" />
[SecurityPermission(SecurityAction.LinkDemand, Flags = SecurityPermissionFlag.SerializationFormatter)]
public void ChainSelector(ISurrogateSelector selector)
{
if (selector != null)
{
SurrogateSelector.ChainSelector(selector);
}
}

/// <inheritdoc cref="ISurrogateSelector" />
[SecurityPermission(SecurityAction.LinkDemand, Flags = SecurityPermissionFlag.SerializationFormatter)]
public ISurrogateSelector GetNextSelector()
{
return SurrogateSelector.GetNextSelector();
}

/// <inheritdoc cref="ISurrogateSelector" />
[SecurityPermission(SecurityAction.LinkDemand, Flags = SecurityPermissionFlag.SerializationFormatter)]
public ISerializationSurrogate GetSurrogate(Type type, StreamingContext context, out ISurrogateSelector selector)
{
return SurrogateSelector.GetSurrogate(type, context, out selector);
}
}
}
2 changes: 0 additions & 2 deletions SafeDeserializationHelpers/TypeFullName.cs
Original file line number Diff line number Diff line change
@@ -1,9 +1,7 @@
namespace SafeDeserializationHelpers
{
using System;
using System.Collections.Generic;
using System.Linq;
using System.Text;

/// <summary>
/// Represents the name of a .NET type.
Expand Down
2 changes: 0 additions & 2 deletions SafeDeserializationHelpers/TypeNameValidator.cs
Original file line number Diff line number Diff line change
@@ -1,9 +1,7 @@
namespace SafeDeserializationHelpers
{
using System;
using System.Collections.Generic;
using System.Linq;
using System.Text;

/// <summary>
/// Validates the type names before loading them for deserialization.
Expand Down
72 changes: 72 additions & 0 deletions SafeDeserializationHelpers/WindowsIdentitySurrogate.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,72 @@
namespace SafeDeserializationHelpers
{
using System;
using System.IO;
using System.Reflection;
using System.Runtime.Serialization;
using System.Runtime.Serialization.Formatters.Binary;
using System.Security.Permissions;
using System.Security.Principal;

/// <summary>
/// Deserialization surrogate for the WindowsIdentity class.
/// </summary>
internal class WindowsIdentitySurrogate : ISerializationSurrogate
{
private static ConstructorInfo Constructor { get; } = typeof(WindowsIdentity).GetConstructor(
BindingFlags.Instance | BindingFlags.NonPublic | BindingFlags.Public,
null,
new[] { typeof(SerializationInfo), typeof(StreamingContext) },
null);

/// <inheritdoc cref="ISerializationSurrogate" />
[SecurityPermission(SecurityAction.LinkDemand, Flags = SecurityPermissionFlag.SerializationFormatter)]
public void GetObjectData(object obj, SerializationInfo info, StreamingContext context)
{
var ds = obj as ISerializable;
ds.GetObjectData(info, context);
}

/// <inheritdoc cref="ISerializationSurrogate" />
[SecurityPermission(SecurityAction.LinkDemand, Flags = SecurityPermissionFlag.SerializationFormatter)]
public object SetObjectData(object obj, SerializationInfo info, StreamingContext context, ISurrogateSelector selector)
{
Validate(info, context);

// discard obj
var result = Constructor.Invoke(new object[] { info, context });
return result;
}

private void Validate(SerializationInfo info, StreamingContext context)
{
// check the serialized data using a guarded BinaryFormatter
var fmt = new BinaryFormatter().Safe();

var e = info.GetEnumerator();
while (e.MoveNext())
{
switch (e.Name)
{
case "System.Security.ClaimsIdentity.actor":
case "System.Security.ClaimsIdentity.claims":
case "System.Security.ClaimsIdentity.bootstrapContext":
var base64 = info.GetString(e.Name);
if (string.IsNullOrEmpty(base64))
{
continue;
}

// safe BinaryFormatter will throw on malicious payload
var buffer = Convert.FromBase64String(base64);
using (var ms = new MemoryStream(buffer))
{
fmt.Deserialize(ms);
}

break;
}
}
}
}
}

0 comments on commit 243fad1

Please sign in to comment.