From e1f1aa6f729f1ac721120cf2a33f041dbb641f4e Mon Sep 17 00:00:00 2001 From: LPeter1997 Date: Fri, 20 Oct 2023 19:09:11 +0200 Subject: [PATCH] Update JsonRpcConnection.cs --- src/Draco.JsonRpc/JsonRpcConnection.cs | 70 +++++++++++++------------- 1 file changed, 36 insertions(+), 34 deletions(-) diff --git a/src/Draco.JsonRpc/JsonRpcConnection.cs b/src/Draco.JsonRpc/JsonRpcConnection.cs index 5ed27b228..f41082b6e 100644 --- a/src/Draco.JsonRpc/JsonRpcConnection.cs +++ b/src/Draco.JsonRpc/JsonRpcConnection.cs @@ -16,7 +16,7 @@ namespace Draco.JsonRpc; /// The error descriptor. public abstract class JsonRpcConnection : IJsonRpcConnection { - internal sealed class JsonRpcResponseException : Exception + protected sealed class JsonRpcResponseException : Exception { public TError ResponseError { get; } @@ -27,7 +27,7 @@ public JsonRpcResponseException(TError error, string message) } } - private interface IOutgoingRequest + protected interface IOutgoingRequest { public Task Task { get; } public Type ResponseType { get; } @@ -133,7 +133,7 @@ private async Task ReaderLoopAsync() catch (JsonException ex) { var error = this.CreateJsonExceptionError(ex); - await this.outgoingMessages.Writer.WriteAsync(this.CreateErrorResponseMessage(null!, error)); + await this.SendMessageAsync(this.CreateErrorResponseMessage(null!, error)); continue; } } @@ -184,7 +184,7 @@ private async Task ProcessMessageAsync(TMessage message) if (this.IsRequestMessage(message)) { var response = await this.ProcessIncomingRequestAsync(message); - await this.outgoingMessages.Writer.WriteAsync(response); + await this.SendMessageAsync(response); } else if (this.IsNotificationMessage(message)) { @@ -204,12 +204,9 @@ private async Task ProcessIncomingRequestAsync(TMessage message) TMessage Error(TError error) => this.CreateErrorResponseMessage(messageId!, error); - // Cancellation handling - if (this.IsCancellationMessage(message)) - { - this.CancelIncomingRequest(messageId!); - return this.CreateOkResponseMessage(messageId!, default); - } + // Custom handling + var (customResponse, customHandled) = await this.TryProcessCustomRequest(message); + if (customHandled) return customResponse!; // Error handling block if (!this.methodHandlers.TryGetValue(methodName, out var handler)) @@ -292,16 +289,12 @@ private void ProcessIncomingResponse(TMessage message) private async Task ProcessIncomingNotificationAsync(TMessage message) { - var messageId = this.GetMessageId(message); var methodName = this.GetMessageMethodName(message); var @params = this.GetMessageParams(message); - // Cancellation handling - if (this.IsCancellationMessage(message)) - { - this.CancelIncomingRequest(messageId!); - return; - } + // Custom handling + var customHandled = await this.TryProcessCustomNotification(message); + if (customHandled) return; // Error handling block // Note that we can't respond to notifications @@ -342,6 +335,9 @@ private async Task ProcessIncomingNotificationAsync(TMessage message) { } } + + protected abstract Task<(TMessage? Message, bool Handled)> TryProcessCustomRequest(TMessage message); + protected abstract Task TryProcessCustomNotification(TMessage message); #endregion #region Sending Message @@ -361,7 +357,7 @@ private async Task ProcessIncomingNotificationAsync(TMessage message) cancellationToken.Register(() => this.CancelOutgoingRequest(id)); // Actually send message - this.outgoingMessages.Writer.TryWrite(request); + this.SendMessage(request); return (Task)pendingReq.Task; } @@ -369,19 +365,25 @@ public async Task SendNotificationAsync(string method, object? @params) { var serializedParams = JsonSerializer.SerializeToElement(@params, this.JsonSerializerOptions); var notification = this.CreateNotificationMessage(method, serializedParams); - await this.outgoingMessages.Writer.WriteAsync(notification); + await this.SendMessageAsync(notification); } + + protected Task SendMessageAsync(TMessage message) => + this.outgoingMessages.Writer.WriteAsync(message).AsTask(); + + protected void SendMessage(TMessage message) => + this.outgoingMessages.Writer.TryWrite(message); #endregion #region Request Response - private CancellationToken AddIncomingRequest(object id) + protected CancellationToken AddIncomingRequest(object id) { var cts = new CancellationTokenSource(); this.pendingIncomingRequests.TryAdd(id, cts); return cts.Token; } - private void CancelIncomingRequest(object id) + protected void CancelIncomingRequest(object id) { if (this.pendingIncomingRequests.TryRemove(id, out var cts)) { @@ -390,29 +392,31 @@ private void CancelIncomingRequest(object id) } } - private void CompleteIncomingRequest(object id) + protected void CompleteIncomingRequest(object id) { - if (this.pendingIncomingRequests.TryRemove(id, out var cts)) cts.Dispose(); + if (this.pendingIncomingRequests.TryRemove(id, out var cts)) + { + cts.Dispose(); + } } - private IOutgoingRequest AddOutgoingRequest(int id) + protected IOutgoingRequest AddOutgoingRequest(int id) { var req = new OutgoingRequest(); this.pendingOutgoingRequests.TryAdd(id, req); return req; } - private void CancelOutgoingRequest(int id) + protected virtual void CancelOutgoingRequest(int id) { // Cancel the task - if (this.pendingOutgoingRequests.TryRemove(id, out var req)) req.Cancel(); - - // Send message - var cancelMessage = this.CreateCancelRequestMessage(id); - this.outgoingMessages.Writer.TryWrite(cancelMessage); + if (this.pendingOutgoingRequests.TryRemove(id, out var req)) + { + req.Cancel(); + } } - private void CompleteOutgoingRequest(int id, JsonElement? resultJson) + protected void CompleteOutgoingRequest(int id, JsonElement? resultJson) { if (this.pendingOutgoingRequests.TryRemove(id, out var req)) { @@ -421,7 +425,7 @@ private void CompleteOutgoingRequest(int id, JsonElement? resultJson) } } - private void FailOutgoingRequest(int id, JsonRpcResponseException error) + protected void FailOutgoingRequest(int id, JsonRpcResponseException error) { if (this.pendingOutgoingRequests.TryRemove(id, out var req)) { @@ -550,7 +554,6 @@ void WriteData() #region Factory Methods protected abstract TMessage CreateRequestMessage(int id, string method, JsonElement @params); - protected abstract TMessage CreateCancelRequestMessage(int id); protected abstract TMessage CreateOkResponseMessage(object id, JsonElement okResult); protected abstract TMessage CreateErrorResponseMessage(object id, TError errorResult); protected abstract TMessage CreateNotificationMessage(string method, JsonElement @params); @@ -566,7 +569,6 @@ void WriteData() protected abstract bool IsRequestMessage(TMessage message); protected abstract bool IsResponseMessage(TMessage message); protected abstract bool IsNotificationMessage(TMessage message); - protected abstract bool IsCancellationMessage(TMessage message); protected abstract object? GetMessageId(TMessage message); protected abstract string GetMessageMethodName(TMessage message); protected abstract JsonElement? GetMessageParams(TMessage message);