Skip to content

Commit

Permalink
Implement #415 and fix a small mistake in #416. (#417)
Browse files Browse the repository at this point in the history
* Revert a change in daab801.

This was wrong; it could bubble an OCEx up in parts of the public API that are
not expected to do so.

* JsonRpcConnection: Remove some unnecessary async/await churn.

Both channels used here are unbounded, so TryWrite() will never fail.
Additionally, we are not interested in cancellation support for the act of
adding a message to a channel. So, there is no reason for any of this to be
done asynchronously.

* Draco.JsonRpc: Rework connection shutdown to use a CancellationToken.

Also reflect the change in Draco.Lsp and Draco.Dap, allowing the user to pass in
a CT when calling LanguageServer.RunAsync() and DebugAdapter.RunAsync(), resp.

Closes #415.
  • Loading branch information
alexrp authored Jul 20, 2024
1 parent daab801 commit dd1457c
Show file tree
Hide file tree
Showing 5 changed files with 45 additions and 44 deletions.
6 changes: 4 additions & 2 deletions src/Draco.Dap/Adapter/DebugAdapter.cs
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
using System.IO.Pipelines;
using System.Linq;
using System.Reflection;
using System.Threading;
using System.Threading.Tasks;
using Draco.Dap.Attributes;
using Draco.JsonRpc;
Expand Down Expand Up @@ -35,8 +36,9 @@ public static IDebugClient Connect(IDuplexPipe stream)
/// </summary>
/// <param name="client">The debug client.</param>
/// <param name="adapter">The debug adapter.</param>
/// <param name="cancellationToken">The cancellation token.</param>
/// <returns>The task that completes when the communication is over.</returns>
public static async Task RunAsync(this IDebugClient client, IDebugAdapter adapter)
public static async Task RunAsync(this IDebugClient client, IDebugAdapter adapter, CancellationToken cancellationToken = default)
{
var connection = ((DebugClientProxy)client).Connection;

Expand All @@ -48,7 +50,7 @@ public static async Task RunAsync(this IDebugClient client, IDebugAdapter adapte
RegisterAdapterRpcMethods(lifecycle, connection);

// Done, now we can actually start
await connection.ListenAsync();
await connection.ListenAsync(cancellationToken);
}

private static void RegisterAdapterRpcMethods(object target, IJsonRpcConnection connection)
Expand Down
8 changes: 2 additions & 6 deletions src/Draco.JsonRpc/IJsonRpcConnection.cs
Original file line number Diff line number Diff line change
Expand Up @@ -14,13 +14,9 @@ internal interface IJsonRpcConnection
/// <summary>
/// Starts listening on the connection.
/// </summary>
/// <param name="cancellationToken">The cancellation token.</param>
/// <returns>The task that completes when the connection closes.</returns>
public Task ListenAsync();

/// <summary>
/// Shuts down this connection.
/// </summary>
public void Shutdown();
public Task ListenAsync(CancellationToken cancellationToken = default);

/// <summary>
/// Sends a request to the client.
Expand Down
57 changes: 26 additions & 31 deletions src/Draco.JsonRpc/JsonRpcConnection.cs
Original file line number Diff line number Diff line change
Expand Up @@ -77,20 +77,15 @@ private sealed class OutgoingRequest<TResponse> : IOutgoingRequest
private readonly ConcurrentDictionary<object, CancellationTokenSource> pendingIncomingRequests = new();
private readonly ConcurrentDictionary<int, IOutgoingRequest> pendingOutgoingRequests = new();

// Shutdown
private readonly CancellationTokenSource shutdownTokenSource = new();

// Communication state
private int lastMessageId = 0;

public void AddHandler(IJsonRpcMethodHandler handler) => this.methodHandlers.Add(handler.MethodName, handler);

public Task ListenAsync() => Task.WhenAll(
this.ReaderLoopAsync(),
this.WriterLoopAsync(),
this.ProcessorLoopAsync());

public void Shutdown() => this.shutdownTokenSource.Cancel();
public Task ListenAsync(CancellationToken cancellationToken = default) => Task.WhenAll(
this.ReaderLoopAsync(cancellationToken),
this.WriterLoopAsync(cancellationToken),
this.ProcessorLoopAsync(cancellationToken));

/// <summary>
/// Generates a new message ID.
Expand All @@ -99,13 +94,13 @@ public Task ListenAsync() => Task.WhenAll(
protected int NextMessageId() => Interlocked.Increment(ref this.lastMessageId);

#region Message Loops
private async Task ReaderLoopAsync()
private async Task ReaderLoopAsync(CancellationToken cancellationToken)
{
while (true)
{
try
{
var (message, foundMessage) = await this.ReadMessageAsync();
var (message, foundMessage) = await this.ReadMessageAsync(cancellationToken);
if (!foundMessage) break;

if (this.IsResponseMessage(message!))
Expand All @@ -114,37 +109,37 @@ private async Task ReaderLoopAsync()
}
else
{
await this.incomingMessages.Writer.WriteAsync(message!, this.shutdownTokenSource.Token);
_ = this.incomingMessages.Writer.TryWrite(message!);
}
}
catch (OperationCanceledException oce) when (oce.CancellationToken == this.shutdownTokenSource.Token)
catch (OperationCanceledException oce) when (oce.CancellationToken == cancellationToken)
{
break;
}
catch (JsonException ex)
{
var error = this.CreateJsonExceptionError(ex);
await this.SendMessageAsync(this.CreateErrorResponseMessage(default!, error));
this.SendMessage(this.CreateErrorResponseMessage(default!, error));
continue;
}
}
}

private async Task WriterLoopAsync()
private async Task WriterLoopAsync(CancellationToken cancellationToken)
{
try
{
await foreach (var message in this.outgoingMessages.Reader.ReadAllAsync(this.shutdownTokenSource.Token))
await foreach (var message in this.outgoingMessages.Reader.ReadAllAsync(cancellationToken))
{
await this.WriteMessageAsync(message);
await this.WriteMessageAsync(message, cancellationToken);
}
}
catch (OperationCanceledException oce) when (oce.CancellationToken == this.shutdownTokenSource.Token)
catch (OperationCanceledException oce) when (oce.CancellationToken == cancellationToken)
{
}
}

private async Task ProcessorLoopAsync()
private async Task ProcessorLoopAsync(CancellationToken cancellationToken)
{
bool IsMutating(TMessage message)
{
Expand All @@ -156,7 +151,7 @@ bool IsMutating(TMessage message)
try
{
var currentTasks = new List<Task>();
await foreach (var message in this.incomingMessages.Reader.ReadAllAsync(this.shutdownTokenSource.Token))
await foreach (var message in this.incomingMessages.Reader.ReadAllAsync(cancellationToken))
{
if (IsMutating(message))
{
Expand All @@ -174,7 +169,7 @@ bool IsMutating(TMessage message)

await Task.WhenAll(currentTasks);
}
catch (OperationCanceledException oce) when (oce.CancellationToken == this.shutdownTokenSource.Token)
catch (OperationCanceledException oce) when (oce.CancellationToken == cancellationToken)
{
}
}
Expand All @@ -188,7 +183,7 @@ private async Task ProcessMessageAsync(TMessage message)
if (this.IsRequestMessage(message))
{
var response = await this.ProcessIncomingRequestAsync(message);
await this.SendMessageAsync(response);
this.SendMessage(response);
}
else if (this.IsNotificationMessage(message))
{
Expand Down Expand Up @@ -363,15 +358,15 @@ private async Task ProcessIncomingNotificationAsync(TMessage message)
return (Task<TResponse?>)pendingReq.Task;
}

public async Task SendNotificationAsync(string method, object? @params)
public Task SendNotificationAsync(string method, object? @params)
{
var serializedParams = JsonSerializer.SerializeToElement(@params, this.JsonSerializerOptions);
var notification = this.CreateNotificationMessage(method, serializedParams);
await this.SendMessageAsync(notification);
}

protected async Task SendMessageAsync(TMessage message) =>
await this.outgoingMessages.Writer.WriteAsync(message, this.shutdownTokenSource.Token);
this.SendMessage(notification);

return Task.CompletedTask;
}

protected void SendMessage(TMessage message) =>
this.outgoingMessages.Writer.TryWrite(message);
Expand Down Expand Up @@ -437,14 +432,14 @@ protected void FailOutgoingRequest(int id, JsonRpcResponseException error)
#endregion

#region Serialization
private async Task<(TMessage? Message, bool Found)> ReadMessageAsync()
private async Task<(TMessage? Message, bool Found)> ReadMessageAsync(CancellationToken cancellationToken)
{
var contentLength = -1;
var reader = this.Transport.Input;

while (true)
{
var result = await reader.ReadAsync(this.shutdownTokenSource.Token);
var result = await reader.ReadAsync(cancellationToken);
var buffer = result.Buffer;

var foundJson = this.TryParseMessage(ref buffer, ref contentLength, out var message);
Expand Down Expand Up @@ -519,7 +514,7 @@ private bool TryParseMessage(
return false;
}

private ValueTask<FlushResult> WriteMessageAsync(TMessage message)
private ValueTask<FlushResult> WriteMessageAsync(TMessage message, CancellationToken cancellationToken)
{
var writer = this.Transport.Output;

Expand Down Expand Up @@ -550,7 +545,7 @@ void WriteData()
}

WriteData();
return writer.FlushAsync(this.shutdownTokenSource.Token);
return writer.FlushAsync(cancellationToken);
}
#endregion

Expand Down
11 changes: 8 additions & 3 deletions src/Draco.Lsp/Server/LanguageServer.cs
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
using System.IO.Pipelines;
using System.Linq;
using System.Reflection;
using System.Threading;
using System.Threading.Tasks;
using Draco.JsonRpc;
using Draco.Lsp.Attributes;
Expand Down Expand Up @@ -35,20 +36,24 @@ public static ILanguageClient Connect(IDuplexPipe stream)
/// </summary>
/// <param name="client">The language client.</param>
/// <param name="server">The language server.</param>
/// <param name="cancellationToken">The cancellation token.</param>
/// <returns>The task that completes when the communication is over.</returns>
public static async Task RunAsync(this ILanguageClient client, ILanguageServer server)
public static async Task RunAsync(this ILanguageClient client, ILanguageServer server, CancellationToken cancellationToken = default)
{
var connection = ((LanguageClientProxy)client).Connection;

// Register server methods
RegisterServerRpcMethods(server, connection);

// The lifecycle needs to be able to cancel the server too, so build a linked CTS on top of the given CT.
using var cts = CancellationTokenSource.CreateLinkedTokenSource(cancellationToken);

// Register builtin server methods. In the future, we should consider making this extensible in some way.
var lifecycle = new LanguageServerLifecycle(server, connection);
var lifecycle = new LanguageServerLifecycle(server, connection, cts);
RegisterServerRpcMethods(lifecycle, connection);

// Done, now we can actually start
await connection.ListenAsync();
await connection.ListenAsync(cts.Token);
}

private static void RegisterServerRpcMethods(object target, IJsonRpcConnection connection)
Expand Down
7 changes: 5 additions & 2 deletions src/Draco.Lsp/Server/LanguageServerLifecycle.cs
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
using System.Linq;
using System.Reflection;
using System.Text.Json;
using System.Threading;
using System.Threading.Tasks;
using Draco.JsonRpc;
using Draco.Lsp.Attributes;
Expand All @@ -16,7 +17,8 @@ namespace Draco.Lsp.Server;
/// </summary>
internal sealed class LanguageServerLifecycle(
ILanguageServer server,
IJsonRpcConnection connection) : ILanguageServerLifecycle
IJsonRpcConnection connection,
CancellationTokenSource cancellation) : ILanguageServerLifecycle
{
private ClientCapabilities clientCapabilities = null!;

Expand Down Expand Up @@ -48,7 +50,8 @@ public async Task InitializedAsync(InitializedParams param)

public Task ExitAsync()
{
connection.Shutdown();
cancellation.Cancel();

return Task.CompletedTask;
}

Expand Down

0 comments on commit dd1457c

Please sign in to comment.