Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Protocol: limit max size hint. #275

Merged
merged 1 commit into from
Apr 11, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
60 changes: 46 additions & 14 deletions src/Tmds.DBus.Protocol/MessageWriter.Basic.cs
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,8 @@ namespace Tmds.DBus.Protocol;

public ref partial struct MessageWriter
{
private const int MaxSizeHint = 4096;

public void WriteBool(bool value) => WriteUInt32(value ? 1u : 0u);

public void WriteByte(byte value) => WritePrimitiveCore<byte>(value, DBusType.Byte);
Expand Down Expand Up @@ -170,33 +172,63 @@ private void WritePrimitiveCore<T>(T value, DBusType type)

private int WriteRaw(ReadOnlySpan<byte> data)
{
int length = data.Length;
var dst = GetSpan(length);
data.CopyTo(dst);
Advance(length);
return length;
int totalLength = data.Length;
if (totalLength <= MaxSizeHint)
{
var dst = GetSpan(totalLength);
data.CopyTo(dst);
Advance(totalLength);
return totalLength;
}
else
{
while (!data.IsEmpty)
{
var dst = GetSpan(1);
int length = Math.Min(data.Length, dst.Length);
data.Slice(0, length).CopyTo(dst);
Advance(length);
data = data.Slice(length);
}
return totalLength;
}
}

private int WriteRaw(string data)
{
#if NETSTANDARD2_1_OR_GREATER || NET
// To use the IBufferWriter we need to flush the Span.
// Avoid it when we're writing small strings.
if (data.Length <= 2048)
const int MaxUtf8BytesPerChar = 3;

if (data.Length <= MaxSizeHint / MaxUtf8BytesPerChar)
{
ReadOnlySpan<char> chars = data.AsSpan();
int byteCount = Encoding.UTF8.GetByteCount(chars);
var dst = GetSpan(byteCount);
byteCount = Encoding.UTF8.GetBytes(data, dst);
byteCount = Encoding.UTF8.GetBytes(data.AsSpan(), dst);
Advance(byteCount);
return byteCount;
}
else
#endif
{
int length = (int)Encoding.UTF8.GetBytes(data.AsSpan(), Writer);
_offset += length;
return length;
ReadOnlySpan<char> chars = data.AsSpan();
Encoder encoder = Encoding.UTF8.GetEncoder();
int totalLength = 0;
do
{
Debug.Assert(!chars.IsEmpty);

var dst = GetSpan(MaxUtf8BytesPerChar);
encoder.Convert(chars, dst, flush: true, out int charsUsed, out int bytesUsed, out bool completed);

Advance(bytesUsed);
totalLength += bytesUsed;

if (completed)
{
return totalLength;
}

chars = chars.Slice(charsUsed);
} while (true);
}
}
}
9 changes: 0 additions & 9 deletions src/Tmds.DBus.Protocol/MessageWriter.cs
Original file line number Diff line number Diff line change
Expand Up @@ -54,15 +54,6 @@ public MessageBuffer CreateMessage()
return message;
}

private IBufferWriter<byte> Writer
{
get
{
Flush();
return _data;
}
}

internal MessageWriter(MessageBufferPool messagePool, uint serial)
{
_message = messagePool.Rent();
Expand Down
40 changes: 0 additions & 40 deletions src/Tmds.DBus.Protocol/Netstandard2_1Extensions.cs
Original file line number Diff line number Diff line change
Expand Up @@ -25,46 +25,6 @@ public static SafeHandle GetSafeHandle(this Socket socket)
return null!;
}

public static long GetBytes(this Encoding encoding, ReadOnlySpan<char> chars, IBufferWriter<byte> writer)
{
if (chars.Length <= MaxInputElementsPerIteration)
{
int byteCount = encoding.GetByteCount(chars);
Span<byte> scratchBuffer = writer.GetSpan(byteCount);

int actualBytesWritten = encoding.GetBytes(chars, scratchBuffer);

writer.Advance(actualBytesWritten);
return actualBytesWritten;
}
else
{
Convert(encoding.GetEncoder(), chars, writer, flush: true, out long totalBytesWritten, out bool completed);
return totalBytesWritten;
}
}

public static void Convert(this Encoder encoder, ReadOnlySpan<char> chars, IBufferWriter<byte> writer, bool flush, out long bytesUsed, out bool completed)
{
long totalBytesWritten = 0;
do
{
int byteCountForThisSlice = (chars.Length <= MaxInputElementsPerIteration)
? encoder.GetByteCount(chars, flush)
: encoder.GetByteCount(chars.Slice(0, MaxInputElementsPerIteration), flush: false);

Span<byte> scratchBuffer = writer.GetSpan(byteCountForThisSlice);

encoder.Convert(chars, scratchBuffer, flush, out int charsUsedJustNow, out int bytesWrittenJustNow, out completed);

chars = chars.Slice(charsUsedJustNow);
writer.Advance(bytesWrittenJustNow);
totalBytesWritten += bytesWrittenJustNow;
} while (!chars.IsEmpty);

bytesUsed = totalBytesWritten;
}

public static async Task ConnectAsync(this Socket socket, EndPoint remoteEP, CancellationToken cancellationToken)
{
using var ctr = cancellationToken.Register(state => ((Socket)state!).Dispose(), socket, useSynchronizationContext: false);
Expand Down
Loading