Skip to content

Commit

Permalink
Update JsonRpcConnection.cs
Browse files Browse the repository at this point in the history
  • Loading branch information
LPeter1997 committed Oct 20, 2023
1 parent 6454d67 commit e1f1aa6
Showing 1 changed file with 36 additions and 34 deletions.
70 changes: 36 additions & 34 deletions src/Draco.JsonRpc/JsonRpcConnection.cs
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ namespace Draco.JsonRpc;
/// <typeparam name="TError">The error descriptor.</typeparam>
public abstract class JsonRpcConnection<TMessage, TError> : IJsonRpcConnection
{
internal sealed class JsonRpcResponseException : Exception
protected sealed class JsonRpcResponseException : Exception
{
public TError ResponseError { get; }

Expand All @@ -27,7 +27,7 @@ public JsonRpcResponseException(TError error, string message)
}
}

private interface IOutgoingRequest
protected interface IOutgoingRequest
{
public Task Task { get; }
public Type ResponseType { get; }
Expand Down Expand Up @@ -133,7 +133,7 @@ private async Task ReaderLoopAsync()
catch (JsonException ex)
{
var error = this.CreateJsonExceptionError(ex);
await this.outgoingMessages.Writer.WriteAsync(this.CreateErrorResponseMessage(null!, error));
await this.SendMessageAsync(this.CreateErrorResponseMessage(null!, error));
continue;
}
}
Expand Down Expand Up @@ -184,7 +184,7 @@ private async Task ProcessMessageAsync(TMessage message)
if (this.IsRequestMessage(message))
{
var response = await this.ProcessIncomingRequestAsync(message);
await this.outgoingMessages.Writer.WriteAsync(response);
await this.SendMessageAsync(response);
}
else if (this.IsNotificationMessage(message))
{
Expand All @@ -204,12 +204,9 @@ private async Task<TMessage> ProcessIncomingRequestAsync(TMessage message)

TMessage Error(TError error) => this.CreateErrorResponseMessage(messageId!, error);

// Cancellation handling
if (this.IsCancellationMessage(message))
{
this.CancelIncomingRequest(messageId!);
return this.CreateOkResponseMessage(messageId!, default);
}
// Custom handling
var (customResponse, customHandled) = await this.TryProcessCustomRequest(message);
if (customHandled) return customResponse!;

// Error handling block
if (!this.methodHandlers.TryGetValue(methodName, out var handler))
Expand Down Expand Up @@ -292,16 +289,12 @@ private void ProcessIncomingResponse(TMessage message)

private async Task ProcessIncomingNotificationAsync(TMessage message)
{
var messageId = this.GetMessageId(message);
var methodName = this.GetMessageMethodName(message);
var @params = this.GetMessageParams(message);

// Cancellation handling
if (this.IsCancellationMessage(message))
{
this.CancelIncomingRequest(messageId!);
return;
}
// Custom handling
var customHandled = await this.TryProcessCustomNotification(message);
if (customHandled) return;

// Error handling block
// Note that we can't respond to notifications
Expand Down Expand Up @@ -342,6 +335,9 @@ private async Task ProcessIncomingNotificationAsync(TMessage message)
{
}
}

protected abstract Task<(TMessage? Message, bool Handled)> TryProcessCustomRequest(TMessage message);
protected abstract Task<bool> TryProcessCustomNotification(TMessage message);
#endregion

#region Sending Message
Expand All @@ -361,27 +357,33 @@ private async Task ProcessIncomingNotificationAsync(TMessage message)
cancellationToken.Register(() => this.CancelOutgoingRequest(id));

// Actually send message
this.outgoingMessages.Writer.TryWrite(request);
this.SendMessage(request);
return (Task<TResponse?>)pendingReq.Task;
}

public async Task SendNotificationAsync(string method, object? @params)
{
var serializedParams = JsonSerializer.SerializeToElement(@params, this.JsonSerializerOptions);
var notification = this.CreateNotificationMessage(method, serializedParams);
await this.outgoingMessages.Writer.WriteAsync(notification);
await this.SendMessageAsync(notification);
}

protected Task SendMessageAsync(TMessage message) =>
this.outgoingMessages.Writer.WriteAsync(message).AsTask();

protected void SendMessage(TMessage message) =>
this.outgoingMessages.Writer.TryWrite(message);
#endregion

#region Request Response
private CancellationToken AddIncomingRequest(object id)
protected CancellationToken AddIncomingRequest(object id)
{
var cts = new CancellationTokenSource();
this.pendingIncomingRequests.TryAdd(id, cts);
return cts.Token;
}

private void CancelIncomingRequest(object id)
protected void CancelIncomingRequest(object id)
{
if (this.pendingIncomingRequests.TryRemove(id, out var cts))
{
Expand All @@ -390,29 +392,31 @@ private void CancelIncomingRequest(object id)
}
}

private void CompleteIncomingRequest(object id)
protected void CompleteIncomingRequest(object id)
{
if (this.pendingIncomingRequests.TryRemove(id, out var cts)) cts.Dispose();
if (this.pendingIncomingRequests.TryRemove(id, out var cts))
{
cts.Dispose();
}
}

private IOutgoingRequest AddOutgoingRequest<TResponse>(int id)
protected IOutgoingRequest AddOutgoingRequest<TResponse>(int id)
{
var req = new OutgoingRequest<TResponse>();
this.pendingOutgoingRequests.TryAdd(id, req);
return req;
}

private void CancelOutgoingRequest(int id)
protected virtual void CancelOutgoingRequest(int id)
{
// Cancel the task
if (this.pendingOutgoingRequests.TryRemove(id, out var req)) req.Cancel();

// Send message
var cancelMessage = this.CreateCancelRequestMessage(id);
this.outgoingMessages.Writer.TryWrite(cancelMessage);
if (this.pendingOutgoingRequests.TryRemove(id, out var req))
{
req.Cancel();
}
}

private void CompleteOutgoingRequest(int id, JsonElement? resultJson)
protected void CompleteOutgoingRequest(int id, JsonElement? resultJson)
{
if (this.pendingOutgoingRequests.TryRemove(id, out var req))
{
Expand All @@ -421,7 +425,7 @@ private void CompleteOutgoingRequest(int id, JsonElement? resultJson)
}
}

private void FailOutgoingRequest(int id, JsonRpcResponseException error)
protected void FailOutgoingRequest(int id, JsonRpcResponseException error)
{
if (this.pendingOutgoingRequests.TryRemove(id, out var req))
{
Expand Down Expand Up @@ -550,7 +554,6 @@ void WriteData()

#region Factory Methods
protected abstract TMessage CreateRequestMessage(int id, string method, JsonElement @params);
protected abstract TMessage CreateCancelRequestMessage(int id);
protected abstract TMessage CreateOkResponseMessage(object id, JsonElement okResult);
protected abstract TMessage CreateErrorResponseMessage(object id, TError errorResult);
protected abstract TMessage CreateNotificationMessage(string method, JsonElement @params);
Expand All @@ -566,7 +569,6 @@ void WriteData()
protected abstract bool IsRequestMessage(TMessage message);
protected abstract bool IsResponseMessage(TMessage message);
protected abstract bool IsNotificationMessage(TMessage message);
protected abstract bool IsCancellationMessage(TMessage message);
protected abstract object? GetMessageId(TMessage message);
protected abstract string GetMessageMethodName(TMessage message);
protected abstract JsonElement? GetMessageParams(TMessage message);
Expand Down

0 comments on commit e1f1aa6

Please sign in to comment.