Skip to content

Commit

Permalink
Fix append string methods with invalid encoding implementations
Browse files Browse the repository at this point in the history
  • Loading branch information
afxres committed Jul 6, 2024
1 parent 4c11c43 commit 817f606
Show file tree
Hide file tree
Showing 6 changed files with 212 additions and 46 deletions.
160 changes: 160 additions & 0 deletions code/Binary.Tests/Contexts/AllocatorStringTests.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,160 @@
namespace Mikodev.Binary.Tests.Contexts;

using System;
using System.Collections.Generic;
using System.Text;
using Xunit;

public class AllocatorStringTests
{
private delegate int GetByteCountDelegate(ReadOnlySpan<char> chars);

private delegate int GetBytesDelegate(ReadOnlySpan<char> chars, Span<byte> bytes);

private class FakeEncoding : Encoding
{
public required Func<int, int> GetMaxByteCountCallback { get; init; }

public required GetByteCountDelegate GetByteCountCallback { get; init; }

public required GetBytesDelegate GetBytesCallback { get; init; }

public override int GetByteCount(char[] chars, int index, int count) => throw new NotSupportedException();

public override int GetBytes(char[] chars, int charIndex, int charCount, byte[] bytes, int byteIndex) => throw new NotSupportedException();

public override int GetCharCount(byte[] bytes, int index, int count) => throw new NotSupportedException();

public override int GetChars(byte[] bytes, int byteIndex, int byteCount, char[] chars, int charIndex) => throw new NotSupportedException();

public override int GetMaxCharCount(int byteCount) => throw new NotSupportedException();

public override int GetMaxByteCount(int charCount) => GetMaxByteCountCallback.Invoke(charCount);

public override int GetByteCount(ReadOnlySpan<char> chars) => GetByteCountCallback.Invoke(chars);

public override int GetBytes(ReadOnlySpan<char> chars, Span<byte> bytes) => GetBytesCallback.Invoke(chars, bytes);
}

public static IEnumerable<object[]> StringData()
{
yield return new object[] { string.Empty };
yield return new object[] { "Alpha" };
yield return new object[] { "一二三四" };
}

[Theory(DisplayName = "Append String UTF8 Encoding")]
[MemberData(nameof(StringData))]
public void AppendStringUTF8Encoding(string text)
{
var encoding = Encoding.UTF8;
var allocator = new Allocator();
Assert.Equal(0, allocator.Length);
Assert.Equal(0, allocator.Capacity);
Allocator.Append(ref allocator, text, encoding);
var expected = encoding.GetBytes(text);
Assert.Equal(expected, allocator.ToArray());
}

[Theory(DisplayName = "Append String With Length Prefix UTF8 Encoding")]
[MemberData(nameof(StringData))]
public void AppendStringWithLengthPrefixUTF8Encoding(string text)
{
var encoding = Encoding.UTF8;
var allocator = new Allocator();
Assert.Equal(0, allocator.Length);
Assert.Equal(0, allocator.Capacity);
Allocator.AppendWithLengthPrefix(ref allocator, text, encoding);
var expected = encoding.GetBytes(text);
var buffer = allocator.AsSpan();
var result = Converter.DecodeWithLengthPrefix(ref buffer);
Assert.Equal(0, buffer.Length);
Assert.Equal(expected, result.ToArray());
}

[Theory(DisplayName = "Append String UTF8 Encoding Medium Length Test")]
[InlineData(72, 96, 96)]
[InlineData(72, 0, 256)]
[InlineData(72, 1024, 1024)]
public void AppendStringUTF8EncodingMediumLengthTest(int stringLength, int allocatorInitialCapacity, int allocatorFinalCapacity)
{
var encoding = Encoding.UTF8;
var text = new string('a', stringLength);
var allocator = new Allocator(new Span<byte>(new byte[allocatorInitialCapacity]));
Assert.True(encoding.GetByteCount(text) < 128);
Assert.True(encoding.GetMaxByteCount(text.Length) > 128);
Allocator.Append(ref allocator, text, encoding);
Assert.Equal(stringLength, allocator.Length);
Assert.Equal(allocatorFinalCapacity, allocator.Capacity);
}

[Theory(DisplayName = "Append String With Length Prefix UTF8 Encoding Medium Length Test")]
[InlineData(48, 80, 80, 1)]
[InlineData(48, 0, 256, 1)]
[InlineData(48, 1024, 1024, 4)]
public void AppendStringWithLengthPrefixUTF8EncodingMediumLengthTest(int stringLength, int allocatorInitialCapacity, int allocatorFinalCapacity, int prefixLength)
{
var encoding = Encoding.UTF8;
var text = new string('a', stringLength);
var allocator = new Allocator(new Span<byte>(new byte[allocatorInitialCapacity]));
Assert.True(encoding.GetByteCount(text) < 128);
Assert.True(encoding.GetMaxByteCount(text.Length) > 128);
Allocator.AppendWithLengthPrefix(ref allocator, text, encoding);
var buffer = allocator.AsSpan();
var actualIntentLength = Converter.Decode(buffer, out var actualPrefixLength);
Assert.Equal(stringLength, actualIntentLength);
Assert.Equal(prefixLength, actualPrefixLength);
Assert.Equal(stringLength + prefixLength, allocator.Length);
Assert.Equal(allocatorFinalCapacity, allocator.Capacity);
}

[Theory(DisplayName = "Append String Fake Encoding Invalid 'GetBytes()' Return Test")]
[InlineData(1024, 256, -1, -1)]
[InlineData(1024, 256, -1, 257)]
[InlineData(256, 256, -1, -1)]
[InlineData(256, 256, -1, 257)]
[InlineData(0, 128, 0, 1)]
[InlineData(0, 128, 0, -1)]
public void AppendStringFakeEncodingInvalidGetBytesReturnTest(int allocatorInitialCapacity, int getMaxByteCountReturn, int getByteCountReturn, int getBytesReturn)
{
var encoding = new FakeEncoding
{
GetMaxByteCountCallback = _ => getMaxByteCountReturn,
GetByteCountCallback = getByteCountReturn is -1 ? (_ => throw new NotSupportedException()) : (_ => getByteCountReturn),
GetBytesCallback = (_, _) => getBytesReturn,
};
var error = Assert.Throws<InvalidOperationException>(() =>
{
var allocator = new Allocator(new Span<byte>(new byte[allocatorInitialCapacity]));
Assert.Equal(0, allocator.Length);
Assert.Equal(allocatorInitialCapacity, allocator.Capacity);
Allocator.Append(ref allocator, string.Empty, encoding);
});
Assert.Equal("Invalid return value.", error.Message);
}

[Theory(DisplayName = "Append String With Length Prefix Fake Encoding Invalid 'GetBytes()' Return Test")]
[InlineData(196, 192, -1, -1)]
[InlineData(196, 192, -1, 193)]
[InlineData(192, 192, 192, -1)]
[InlineData(192, 192, 192, 193)]
[InlineData(0, 1, 0, 1)]
[InlineData(0, 1, 0, -1)]
public void AppendStringWithLengthPrefixInvalidGetBytesReturnTest(int allocatorInitialCapacity, int getMaxByteCountReturn, int getByteCountReturn, int getBytesReturn)
{
var encoding = new FakeEncoding
{
GetMaxByteCountCallback = _ => getMaxByteCountReturn,
GetByteCountCallback = getByteCountReturn is -1 ? (_ => throw new NotSupportedException()) : (_ => getByteCountReturn),
GetBytesCallback = (_, _) => getBytesReturn,
};
var error = Assert.Throws<InvalidOperationException>(() =>
{
var allocator = new Allocator(new Span<byte>(new byte[allocatorInitialCapacity]));
Assert.Equal(0, allocator.Length);
Assert.Equal(allocatorInitialCapacity, allocator.Capacity);
Allocator.AppendWithLengthPrefix(ref allocator, string.Empty, encoding);
});
Assert.Equal("Invalid return value.", error.Message);
}
}
18 changes: 18 additions & 0 deletions code/Binary/Allocator.Invoke.cs
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,24 @@ private static void Resize(ref Allocator allocator, int length)
Debug.Assert(offset <= allocator.bounds);
}

[MethodImpl(MethodImplOptions.AggressiveInlining)]
private static bool TryEnsure(ref Allocator allocator, int length)
{
Debug.Assert(allocator.bounds >= 0);
Debug.Assert(allocator.offset >= 0);
Debug.Assert(allocator.bounds >= allocator.offset);
return (uint)allocator.bounds >= (ulong)(uint)allocator.offset + (uint)length;
}

[MethodImpl(MethodImplOptions.AggressiveInlining)]
private static ref byte TryCreate(ref Allocator allocator, int length)
{
if (TryEnsure(ref allocator, length) is false)
return ref Unsafe.NullRef<byte>();
var offset = allocator.offset;
return ref Unsafe.Add(ref allocator.target, offset);
}

[MethodImpl(MethodImplOptions.AggressiveInlining)]
private static ref byte Create(ref Allocator allocator, int length)
{
Expand Down
2 changes: 1 addition & 1 deletion code/Binary/Allocator.Module.cs
Original file line number Diff line number Diff line change
Expand Up @@ -73,7 +73,7 @@ public static void AppendWithLengthPrefix<T>(ref Allocator allocator, T data, Al
[MethodImpl(MethodImplOptions.AggressiveInlining)]
public static void Ensure(ref Allocator allocator, int length)
{
if ((ulong)(uint)allocator.offset + (uint)length > (uint)allocator.bounds)
if (TryEnsure(ref allocator, length) is false)
Resize(ref allocator, length);
Debug.Assert(allocator.bounds <= allocator.MaxCapacity);
Debug.Assert(allocator.bounds >= allocator.offset + length);
Expand Down
40 changes: 26 additions & 14 deletions code/Binary/Allocator.String.cs
Original file line number Diff line number Diff line change
Expand Up @@ -12,24 +12,36 @@ public ref partial struct Allocator
public static void Append(ref Allocator allocator, scoped ReadOnlySpan<char> span, Encoding encoding)
{
ArgumentNullException.ThrowIfNull(encoding);
var targetLimits = SharedModule.GetMaxByteCount(span, encoding);
Debug.Assert(targetLimits <= encoding.GetMaxByteCount(span.Length));
if (targetLimits is 0)
return;
ref var target = ref Create(ref allocator, targetLimits);
var targetLength = encoding.GetBytes(span, MemoryMarshal.CreateSpan(ref target, targetLimits));
FinishCreate(ref allocator, targetLength);
var limits = encoding.GetMaxByteCount(span.Length);
ref var target = ref TryCreate(ref allocator, limits);
if (Unsafe.IsNullRef(ref target))
target = ref Create(ref allocator, limits = encoding.GetByteCount(span));
var actual = encoding.GetBytes(span, MemoryMarshal.CreateSpan(ref target, limits));
if ((uint)actual > (uint)limits)
ThrowHelper.ThrowInvalidReturnValue();
Debug.Assert(actual >= 0);
Debug.Assert(actual <= limits);
FinishCreate(ref allocator, actual);
}

public static void AppendWithLengthPrefix(ref Allocator allocator, scoped ReadOnlySpan<char> span, Encoding encoding)
{
ArgumentNullException.ThrowIfNull(encoding);
var targetLimits = SharedModule.GetMaxByteCount(span, encoding);
Debug.Assert(targetLimits <= encoding.GetMaxByteCount(span.Length));
var prefixLength = NumberModule.EncodeLength((uint)targetLimits);
ref var target = ref Create(ref allocator, prefixLength + targetLimits);
var targetLength = targetLimits is 0 ? 0 : encoding.GetBytes(span, MemoryMarshal.CreateSpan(ref Unsafe.Add(ref target, prefixLength), targetLimits));
NumberModule.Encode(ref target, (uint)targetLength, prefixLength);
FinishCreate(ref allocator, targetLength + prefixLength);
var limits = encoding.GetMaxByteCount(span.Length);
var numberLength = NumberModule.EncodeLength((uint)limits);
ref var target = ref TryCreate(ref allocator, limits + numberLength);
if (Unsafe.IsNullRef(ref target))
{
limits = encoding.GetByteCount(span);
numberLength = NumberModule.EncodeLength((uint)limits);
target = ref Create(ref allocator, limits + numberLength);
}
var actual = encoding.GetBytes(span, MemoryMarshal.CreateSpan(ref Unsafe.Add(ref target, numberLength), limits));
if ((uint)actual > (uint)limits)
ThrowHelper.ThrowInvalidReturnValue();
Debug.Assert(actual >= 0);
Debug.Assert(actual <= limits);
NumberModule.Encode(ref target, (uint)actual, numberLength);
FinishCreate(ref allocator, actual + numberLength);
}
}
14 changes: 7 additions & 7 deletions code/Binary/Creators.Isolated.Primitive/StringConverter.cs
Original file line number Diff line number Diff line change
@@ -1,19 +1,19 @@
namespace Mikodev.Binary.Creators.Isolated.Primitive;

using Mikodev.Binary.Internal;
using System;
using System.Text;

internal sealed class StringConverter : Converter<string>
{
public override void Encode(ref Allocator allocator, string? item) => Allocator.Append(ref allocator, item.AsSpan(), SharedModule.Encoding);
public override void Encode(ref Allocator allocator, string? item) => Allocator.Append(ref allocator, item.AsSpan(), Encoding.UTF8);

public override void EncodeAuto(ref Allocator allocator, string? item) => Allocator.AppendWithLengthPrefix(ref allocator, item.AsSpan(), SharedModule.Encoding);
public override void EncodeAuto(ref Allocator allocator, string? item) => Allocator.AppendWithLengthPrefix(ref allocator, item.AsSpan(), Encoding.UTF8);

public override void EncodeWithLengthPrefix(ref Allocator allocator, string? item) => Allocator.AppendWithLengthPrefix(ref allocator, item.AsSpan(), SharedModule.Encoding);
public override void EncodeWithLengthPrefix(ref Allocator allocator, string? item) => Allocator.AppendWithLengthPrefix(ref allocator, item.AsSpan(), Encoding.UTF8);

public override string Decode(in ReadOnlySpan<byte> span) => SharedModule.Encoding.GetString(span);
public override string Decode(in ReadOnlySpan<byte> span) => Encoding.UTF8.GetString(span);

public override string DecodeAuto(ref ReadOnlySpan<byte> span) => SharedModule.Encoding.GetString(Converter.DecodeWithLengthPrefix(ref span));
public override string DecodeAuto(ref ReadOnlySpan<byte> span) => Encoding.UTF8.GetString(Converter.DecodeWithLengthPrefix(ref span));

public override string DecodeWithLengthPrefix(ref ReadOnlySpan<byte> span) => SharedModule.Encoding.GetString(Converter.DecodeWithLengthPrefix(ref span));
public override string DecodeWithLengthPrefix(ref ReadOnlySpan<byte> span) => Encoding.UTF8.GetString(Converter.DecodeWithLengthPrefix(ref span));
}
24 changes: 0 additions & 24 deletions code/Binary/Internal/SharedModule.cs

This file was deleted.

0 comments on commit 817f606

Please sign in to comment.