diff --git a/src/Tmds.DBus.Protocol/DBusConnection.cs b/src/Tmds.DBus.Protocol/DBusConnection.cs index 87b901f6..b86924ca 100644 --- a/src/Tmds.DBus.Protocol/DBusConnection.cs +++ b/src/Tmds.DBus.Protocol/DBusConnection.cs @@ -118,6 +118,7 @@ public void Invoke(Exception? exception, Message message) private Observer? _currentObserver; private SynchronizationContext? _currentSynchronizationContext; private TaskCompletionSource? _disconnectedTcs; + private CancellationTokenSource _abortedCts; private bool _isMonitor; private Action? _monitorHandler; @@ -140,6 +141,7 @@ public DBusConnection(Connection parent, string machineId) _matchedObservers = new(); _pathNodes = new(); _machineId = machineId; + _abortedCts = new(); } // For tests. @@ -319,7 +321,7 @@ private async void HandleMessages(Exception? exception, Message message) if (isMethodCall) { - methodContext = new MethodContext(_parentConnection, message); // TODO: pool. + methodContext = new MethodContext(_parentConnection, message, _abortedCts.Token); // TODO: pool. if (message.PathIsSet) { @@ -519,6 +521,8 @@ public void Dispose() _messageStream?.Close(disconnectReason); + _abortedCts.Cancel(); + if (_pendingCalls is not null) { foreach (var pendingCall in _pendingCalls.Values) diff --git a/src/Tmds.DBus.Protocol/MethodContext.cs b/src/Tmds.DBus.Protocol/MethodContext.cs index 33250f06..db8380b0 100644 --- a/src/Tmds.DBus.Protocol/MethodContext.cs +++ b/src/Tmds.DBus.Protocol/MethodContext.cs @@ -2,14 +2,16 @@ namespace Tmds.DBus.Protocol; public class MethodContext { - internal MethodContext(Connection connection, Message request) + internal MethodContext(Connection connection, Message request, CancellationToken requestAborted) { Connection = connection; Request = request; + RequestAborted = requestAborted; } public Message Request { get; } public Connection Connection { get; } + public CancellationToken RequestAborted { get; } public bool ReplySent { get; private set; } diff --git a/test/Tmds.DBus.Protocol.Tests/ConnectionTests.cs b/test/Tmds.DBus.Protocol.Tests/ConnectionTests.cs index 8d259075..37d05900 100644 --- a/test/Tmds.DBus.Protocol.Tests/ConnectionTests.cs +++ b/test/Tmds.DBus.Protocol.Tests/ConnectionTests.cs @@ -40,6 +40,35 @@ public async Task DisconnectedException() await Assert.ThrowsAsync(() => proxy.ConcatAsync("hello ", "world")); } + [Fact] + public async Task DisposeTriggersRequestAborted() + { + var connections = PairedConnection.CreatePair(); + using var conn1 = connections.Item1; + using var conn2 = connections.Item2; + + var handler = new WaitForCancellationHandler(); + conn2.AddMethodHandler(handler); + + Task pendingCall = conn1.CallMethodAsync(CreateMessage()); + + conn2.Dispose(); + + await Assert.ThrowsAsync(() => pendingCall); + + await handler.WaitForCancelledAsync().WaitAsync(new CancellationTokenSource(TimeSpan.FromSeconds(30)).Token); + + MessageBuffer CreateMessage() + { + using var writer = conn1.GetMessageWriter(); + writer.WriteMethodCallHeader( + path: handler.Path, + @interface: "org.any", + member: "Any"); + return writer.CreateMessage(); + } + } + [Theory] [InlineData(true)] [InlineData(false)] @@ -386,6 +415,35 @@ MessageBuffer CreateAddMessage() } } + class WaitForCancellationHandler : IMethodHandler + { + public string Path => "/"; + + private readonly TaskCompletionSource _cancelled = new(); + + public async ValueTask HandleMethodAsync(MethodContext context) + { + try + { + while (true) + { + await Task.Delay(int.MaxValue, context.RequestAborted); + } + } + catch (OperationCanceledException) + { + _cancelled.SetResult(); + + throw; + } + } + + public Task WaitForCancelledAsync() => _cancelled.Task; + + public bool RunMethodHandlerSynchronously(Message message) + => true; + } + class StringOperations : IMethodHandler { public string Path => "/tmds/dbus/tests/stringoperations"; diff --git a/test/Tmds.DBus.Protocol.Tests/PathNodeDictionaryTests.cs b/test/Tmds.DBus.Protocol.Tests/PathNodeDictionaryTests.cs index 2772dad0..d4f4a11a 100644 --- a/test/Tmds.DBus.Protocol.Tests/PathNodeDictionaryTests.cs +++ b/test/Tmds.DBus.Protocol.Tests/PathNodeDictionaryTests.cs @@ -275,7 +275,7 @@ public void RemoveHandlersDoesntRemovePreExistingParentNodes() private void AssertChildNames(string[] expectedChildNames, PathNode node) { - var methodContext = new MethodContext(null!, null!); + var methodContext = new MethodContext(null!, null!, default); node.CopyChildNamesTo(methodContext); if (methodContext.IntrospectChildNameList == null) {