From dd1457ce30ddacbd06214cb8b0c46b4cb9cf54ad Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Alex=20R=C3=B8nne=20Petersen?= Date: Sat, 20 Jul 2024 17:42:16 +0200 Subject: [PATCH] Implement #415 and fix a small mistake in #416. (#417) * Revert a change in daab8012c4c3fdd74ad427d020e7fee554e202cb. This was wrong; it could bubble an OCEx up in parts of the public API that are not expected to do so. * JsonRpcConnection: Remove some unnecessary async/await churn. Both channels used here are unbounded, so TryWrite() will never fail. Additionally, we are not interested in cancellation support for the act of adding a message to a channel. So, there is no reason for any of this to be done asynchronously. * Draco.JsonRpc: Rework connection shutdown to use a CancellationToken. Also reflect the change in Draco.Lsp and Draco.Dap, allowing the user to pass in a CT when calling LanguageServer.RunAsync() and DebugAdapter.RunAsync(), resp. Closes #415. --- src/Draco.Dap/Adapter/DebugAdapter.cs | 6 +- src/Draco.JsonRpc/IJsonRpcConnection.cs | 8 +-- src/Draco.JsonRpc/JsonRpcConnection.cs | 57 +++++++++---------- src/Draco.Lsp/Server/LanguageServer.cs | 11 +++- .../Server/LanguageServerLifecycle.cs | 7 ++- 5 files changed, 45 insertions(+), 44 deletions(-) diff --git a/src/Draco.Dap/Adapter/DebugAdapter.cs b/src/Draco.Dap/Adapter/DebugAdapter.cs index 925ba59fe..4e31bb648 100644 --- a/src/Draco.Dap/Adapter/DebugAdapter.cs +++ b/src/Draco.Dap/Adapter/DebugAdapter.cs @@ -2,6 +2,7 @@ using System.IO.Pipelines; using System.Linq; using System.Reflection; +using System.Threading; using System.Threading.Tasks; using Draco.Dap.Attributes; using Draco.JsonRpc; @@ -35,8 +36,9 @@ public static IDebugClient Connect(IDuplexPipe stream) /// /// The debug client. /// The debug adapter. + /// The cancellation token. /// The task that completes when the communication is over. - public static async Task RunAsync(this IDebugClient client, IDebugAdapter adapter) + public static async Task RunAsync(this IDebugClient client, IDebugAdapter adapter, CancellationToken cancellationToken = default) { var connection = ((DebugClientProxy)client).Connection; @@ -48,7 +50,7 @@ public static async Task RunAsync(this IDebugClient client, IDebugAdapter adapte RegisterAdapterRpcMethods(lifecycle, connection); // Done, now we can actually start - await connection.ListenAsync(); + await connection.ListenAsync(cancellationToken); } private static void RegisterAdapterRpcMethods(object target, IJsonRpcConnection connection) diff --git a/src/Draco.JsonRpc/IJsonRpcConnection.cs b/src/Draco.JsonRpc/IJsonRpcConnection.cs index d4406bc5f..0a76bb988 100644 --- a/src/Draco.JsonRpc/IJsonRpcConnection.cs +++ b/src/Draco.JsonRpc/IJsonRpcConnection.cs @@ -14,13 +14,9 @@ internal interface IJsonRpcConnection /// /// Starts listening on the connection. /// + /// The cancellation token. /// The task that completes when the connection closes. - public Task ListenAsync(); - - /// - /// Shuts down this connection. - /// - public void Shutdown(); + public Task ListenAsync(CancellationToken cancellationToken = default); /// /// Sends a request to the client. diff --git a/src/Draco.JsonRpc/JsonRpcConnection.cs b/src/Draco.JsonRpc/JsonRpcConnection.cs index f355a0412..d0a2e5c75 100644 --- a/src/Draco.JsonRpc/JsonRpcConnection.cs +++ b/src/Draco.JsonRpc/JsonRpcConnection.cs @@ -77,20 +77,15 @@ private sealed class OutgoingRequest : IOutgoingRequest private readonly ConcurrentDictionary pendingIncomingRequests = new(); private readonly ConcurrentDictionary pendingOutgoingRequests = new(); - // Shutdown - private readonly CancellationTokenSource shutdownTokenSource = new(); - // Communication state private int lastMessageId = 0; public void AddHandler(IJsonRpcMethodHandler handler) => this.methodHandlers.Add(handler.MethodName, handler); - public Task ListenAsync() => Task.WhenAll( - this.ReaderLoopAsync(), - this.WriterLoopAsync(), - this.ProcessorLoopAsync()); - - public void Shutdown() => this.shutdownTokenSource.Cancel(); + public Task ListenAsync(CancellationToken cancellationToken = default) => Task.WhenAll( + this.ReaderLoopAsync(cancellationToken), + this.WriterLoopAsync(cancellationToken), + this.ProcessorLoopAsync(cancellationToken)); /// /// Generates a new message ID. @@ -99,13 +94,13 @@ public Task ListenAsync() => Task.WhenAll( protected int NextMessageId() => Interlocked.Increment(ref this.lastMessageId); #region Message Loops - private async Task ReaderLoopAsync() + private async Task ReaderLoopAsync(CancellationToken cancellationToken) { while (true) { try { - var (message, foundMessage) = await this.ReadMessageAsync(); + var (message, foundMessage) = await this.ReadMessageAsync(cancellationToken); if (!foundMessage) break; if (this.IsResponseMessage(message!)) @@ -114,37 +109,37 @@ private async Task ReaderLoopAsync() } else { - await this.incomingMessages.Writer.WriteAsync(message!, this.shutdownTokenSource.Token); + _ = this.incomingMessages.Writer.TryWrite(message!); } } - catch (OperationCanceledException oce) when (oce.CancellationToken == this.shutdownTokenSource.Token) + catch (OperationCanceledException oce) when (oce.CancellationToken == cancellationToken) { break; } catch (JsonException ex) { var error = this.CreateJsonExceptionError(ex); - await this.SendMessageAsync(this.CreateErrorResponseMessage(default!, error)); + this.SendMessage(this.CreateErrorResponseMessage(default!, error)); continue; } } } - private async Task WriterLoopAsync() + private async Task WriterLoopAsync(CancellationToken cancellationToken) { try { - await foreach (var message in this.outgoingMessages.Reader.ReadAllAsync(this.shutdownTokenSource.Token)) + await foreach (var message in this.outgoingMessages.Reader.ReadAllAsync(cancellationToken)) { - await this.WriteMessageAsync(message); + await this.WriteMessageAsync(message, cancellationToken); } } - catch (OperationCanceledException oce) when (oce.CancellationToken == this.shutdownTokenSource.Token) + catch (OperationCanceledException oce) when (oce.CancellationToken == cancellationToken) { } } - private async Task ProcessorLoopAsync() + private async Task ProcessorLoopAsync(CancellationToken cancellationToken) { bool IsMutating(TMessage message) { @@ -156,7 +151,7 @@ bool IsMutating(TMessage message) try { var currentTasks = new List(); - await foreach (var message in this.incomingMessages.Reader.ReadAllAsync(this.shutdownTokenSource.Token)) + await foreach (var message in this.incomingMessages.Reader.ReadAllAsync(cancellationToken)) { if (IsMutating(message)) { @@ -174,7 +169,7 @@ bool IsMutating(TMessage message) await Task.WhenAll(currentTasks); } - catch (OperationCanceledException oce) when (oce.CancellationToken == this.shutdownTokenSource.Token) + catch (OperationCanceledException oce) when (oce.CancellationToken == cancellationToken) { } } @@ -188,7 +183,7 @@ private async Task ProcessMessageAsync(TMessage message) if (this.IsRequestMessage(message)) { var response = await this.ProcessIncomingRequestAsync(message); - await this.SendMessageAsync(response); + this.SendMessage(response); } else if (this.IsNotificationMessage(message)) { @@ -363,15 +358,15 @@ private async Task ProcessIncomingNotificationAsync(TMessage message) return (Task)pendingReq.Task; } - public async Task SendNotificationAsync(string method, object? @params) + public Task SendNotificationAsync(string method, object? @params) { var serializedParams = JsonSerializer.SerializeToElement(@params, this.JsonSerializerOptions); var notification = this.CreateNotificationMessage(method, serializedParams); - await this.SendMessageAsync(notification); - } - protected async Task SendMessageAsync(TMessage message) => - await this.outgoingMessages.Writer.WriteAsync(message, this.shutdownTokenSource.Token); + this.SendMessage(notification); + + return Task.CompletedTask; + } protected void SendMessage(TMessage message) => this.outgoingMessages.Writer.TryWrite(message); @@ -437,14 +432,14 @@ protected void FailOutgoingRequest(int id, JsonRpcResponseException error) #endregion #region Serialization - private async Task<(TMessage? Message, bool Found)> ReadMessageAsync() + private async Task<(TMessage? Message, bool Found)> ReadMessageAsync(CancellationToken cancellationToken) { var contentLength = -1; var reader = this.Transport.Input; while (true) { - var result = await reader.ReadAsync(this.shutdownTokenSource.Token); + var result = await reader.ReadAsync(cancellationToken); var buffer = result.Buffer; var foundJson = this.TryParseMessage(ref buffer, ref contentLength, out var message); @@ -519,7 +514,7 @@ private bool TryParseMessage( return false; } - private ValueTask WriteMessageAsync(TMessage message) + private ValueTask WriteMessageAsync(TMessage message, CancellationToken cancellationToken) { var writer = this.Transport.Output; @@ -550,7 +545,7 @@ void WriteData() } WriteData(); - return writer.FlushAsync(this.shutdownTokenSource.Token); + return writer.FlushAsync(cancellationToken); } #endregion diff --git a/src/Draco.Lsp/Server/LanguageServer.cs b/src/Draco.Lsp/Server/LanguageServer.cs index aee33fa5d..a747e4651 100644 --- a/src/Draco.Lsp/Server/LanguageServer.cs +++ b/src/Draco.Lsp/Server/LanguageServer.cs @@ -2,6 +2,7 @@ using System.IO.Pipelines; using System.Linq; using System.Reflection; +using System.Threading; using System.Threading.Tasks; using Draco.JsonRpc; using Draco.Lsp.Attributes; @@ -35,20 +36,24 @@ public static ILanguageClient Connect(IDuplexPipe stream) /// /// The language client. /// The language server. + /// The cancellation token. /// The task that completes when the communication is over. - public static async Task RunAsync(this ILanguageClient client, ILanguageServer server) + public static async Task RunAsync(this ILanguageClient client, ILanguageServer server, CancellationToken cancellationToken = default) { var connection = ((LanguageClientProxy)client).Connection; // Register server methods RegisterServerRpcMethods(server, connection); + // The lifecycle needs to be able to cancel the server too, so build a linked CTS on top of the given CT. + using var cts = CancellationTokenSource.CreateLinkedTokenSource(cancellationToken); + // Register builtin server methods. In the future, we should consider making this extensible in some way. - var lifecycle = new LanguageServerLifecycle(server, connection); + var lifecycle = new LanguageServerLifecycle(server, connection, cts); RegisterServerRpcMethods(lifecycle, connection); // Done, now we can actually start - await connection.ListenAsync(); + await connection.ListenAsync(cts.Token); } private static void RegisterServerRpcMethods(object target, IJsonRpcConnection connection) diff --git a/src/Draco.Lsp/Server/LanguageServerLifecycle.cs b/src/Draco.Lsp/Server/LanguageServerLifecycle.cs index b00c3d778..e4591700d 100644 --- a/src/Draco.Lsp/Server/LanguageServerLifecycle.cs +++ b/src/Draco.Lsp/Server/LanguageServerLifecycle.cs @@ -4,6 +4,7 @@ using System.Linq; using System.Reflection; using System.Text.Json; +using System.Threading; using System.Threading.Tasks; using Draco.JsonRpc; using Draco.Lsp.Attributes; @@ -16,7 +17,8 @@ namespace Draco.Lsp.Server; /// internal sealed class LanguageServerLifecycle( ILanguageServer server, - IJsonRpcConnection connection) : ILanguageServerLifecycle + IJsonRpcConnection connection, + CancellationTokenSource cancellation) : ILanguageServerLifecycle { private ClientCapabilities clientCapabilities = null!; @@ -48,7 +50,8 @@ public async Task InitializedAsync(InitializedParams param) public Task ExitAsync() { - connection.Shutdown(); + cancellation.Cancel(); + return Task.CompletedTask; }