Skip to content

Commit

Permalink
Honor user-provided Json IContractResolver while maintaining marshali…
Browse files Browse the repository at this point in the history
…ng capabilities (#783)

Addressing failures similar to
```
StreamJsonRpc.RemoteMethodNotFoundException : Unable to find method 'xxx' on {no object} for the following reasons: Deserializing JSON-RPC argument with name "observer" and position 1 to type "xxx" failed: Could not create an instance of type xxx. Type is an interface or abstract class and cannot be instantiated. Path 'params[1].__jsonrpc_marshaled'.
```
when calling APIs under the following conditions:
- using Newtonsoft.Json
- the API has marshalable (`IObserver<>`, `IDisposable` or interfaces with `RpcMarshalableAttribute`) objects as parameters or return values
- the user replaces `JsonMessageFormatter`'s `ContractResolver`.

This issue has worsened since #777 due to more behavior being delegated to the `ContractResolver`

This PR integrates the `ContractResolver` provided by the user with `MarshalContractResolver` before the first serialization or deserialization.
  • Loading branch information
matteo-prosperi authored Mar 22, 2022
1 parent 7479dbb commit dd9e0c2
Show file tree
Hide file tree
Showing 3 changed files with 254 additions and 8 deletions.
44 changes: 41 additions & 3 deletions src/StreamJsonRpc/JsonMessageFormatter.cs
Original file line number Diff line number Diff line change
Expand Up @@ -84,6 +84,11 @@ public class JsonMessageFormatter : IJsonRpcAsyncMessageTextFormatter, IJsonRpcF
/// </summary>
private readonly SequenceTextReader sequenceTextReader = new SequenceTextReader();

/// <summary>
/// Object used to lock when running mutually exclusive operations related to this <see cref="JsonMessageFormatter"/> instance.
/// </summary>
private readonly object syncObject = new();

/// <summary>
/// Backing field for the <see cref="MultiplexingStream"/> property.
/// </summary>
Expand Down Expand Up @@ -162,6 +167,11 @@ public class JsonMessageFormatter : IJsonRpcAsyncMessageTextFormatter, IJsonRpcF
/// </summary>
private JsonRpcMessage? deserializingMessage;

/// <summary>
/// Whether <see cref="EnforceFormatterIsInitialized"/> has been executed.
/// </summary>
private bool formatterInitializationChecked;

/// <summary>
/// Initializes a new instance of the <see cref="JsonMessageFormatter"/> class
/// that uses JsonProgress (without the preamble) for its text encoding.
Expand Down Expand Up @@ -390,6 +400,8 @@ public JsonRpcMessage Deserialize(JToken json)
{
Requires.NotNull(json, nameof(json));

this.EnforceFormatterIsInitialized();

try
{
switch (this.ProtocolVersion.Major)
Expand Down Expand Up @@ -432,6 +444,8 @@ public JsonRpcMessage Deserialize(JToken json)
/// <returns>The JSON of the message.</returns>
public JToken Serialize(JsonRpcMessage message)
{
this.EnforceFormatterIsInitialized();

try
{
this.observedTransmittedRequestWithStringId |= message is JsonRpcRequest request && request.RequestId.String is not null;
Expand Down Expand Up @@ -528,6 +542,23 @@ private static object[] PartiallyParsePositionalArguments(JArray args)
return jtokenArray;
}

private void EnforceFormatterIsInitialized()
{
lock (this.syncObject)
{
if (!this.formatterInitializationChecked)
{
IContractResolver? originalContractResolver = this.JsonSerializer.ContractResolver;
if (originalContractResolver is not MarshalContractResolver)
{
this.JsonSerializer.ContractResolver = new MarshalContractResolver(this, originalContractResolver);
}

this.formatterInitializationChecked = true;
}
}
}

private void VerifyProtocolCompliance(bool condition, JToken message, string? explanation = null)
{
if (!condition)
Expand Down Expand Up @@ -1664,16 +1695,23 @@ public override void WriteJson(JsonWriter writer, Exception? value, JsonSerializ
}
}

private class MarshalContractResolver : DefaultContractResolver
private class MarshalContractResolver : IContractResolver
{
private readonly JsonMessageFormatter formatter;
private readonly IContractResolver underlyingContractResolver;

public MarshalContractResolver(JsonMessageFormatter formatter)
: this(formatter, null)
{
}

public MarshalContractResolver(JsonMessageFormatter formatter, IContractResolver? underlyingContractResolver)
{
this.formatter = formatter;
this.underlyingContractResolver = underlyingContractResolver ?? new DefaultContractResolver();
}

public override JsonContract ResolveContract(Type type)
public JsonContract ResolveContract(Type type)
{
if (this.formatter.TryGetMarshaledJsonConverter(type, out RpcMarshalableConverter? converter))
{
Expand All @@ -1684,7 +1722,7 @@ public override JsonContract ResolveContract(Type type)
};
}

JsonContract? result = base.ResolveContract(type);
JsonContract? result = this.underlyingContractResolver.ResolveContract(type);
switch (result)
{
case JsonObjectContract objectContract:
Expand Down
213 changes: 213 additions & 0 deletions test/StreamJsonRpc.Tests/JsonContractResolverTests.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,213 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT license. See LICENSE file in the project root for full license information.

using System.Diagnostics;
using System.Runtime.Serialization;
using Microsoft.VisualStudio.Threading;
using Nerdbank.Streams;
using Newtonsoft.Json.Serialization;
using StreamJsonRpc;
using Xunit;
using Xunit.Abstractions;

public partial class JsonContractResolverTest : TestBase
{
private readonly Server server = new Server();
private readonly JsonRpc serverRpc;
private readonly JsonRpc clientRpc;
private readonly IServer client;

public JsonContractResolverTest(ITestOutputHelper logger)
: base(logger)
{
var pipes = FullDuplexStream.CreatePipePair();

this.client = JsonRpc.Attach<IServer>(new LengthHeaderMessageHandler(pipes.Item1, this.CreateFormatter()));
this.clientRpc = ((IJsonRpcClientProxy)this.client).JsonRpc;

this.serverRpc = new JsonRpc(new LengthHeaderMessageHandler(pipes.Item2, this.CreateFormatter()));
this.serverRpc.AddLocalRpcTarget(this.server);

this.serverRpc.TraceSource = new TraceSource("Server", SourceLevels.Verbose);
this.clientRpc.TraceSource = new TraceSource("Client", SourceLevels.Verbose);

this.serverRpc.TraceSource.Listeners.Add(new XunitTraceListener(this.Logger));
this.clientRpc.TraceSource.Listeners.Add(new XunitTraceListener(this.Logger));

this.serverRpc.StartListening();
}

protected interface IServer : IDisposable
{
Task GiveObserver(IObserver<int> observer);

Task GiveObserverContainer(Container<IObserver<int>> observerContainer);

Task<IObserver<int>> GetObserver();

Task<Container<IObserver<int>>> GetObserverContainer();

Task GiveMarshalable(IMarshalable marshalable);

Task GiveMarshalableContainer(Container<IMarshalable> marshalableContainer);

Task<IMarshalable> GetMarshalable();

Task<Container<IMarshalable>> GetMarshalableContainer();
}

[RpcMarshalable]
protected interface IMarshalable : IDisposable
{
void DoSomething();
}

[Fact]
public async Task GiveObserverTest()
{
var observer = new MockObserver();
await Task.Run(() => this.client.GiveObserver(observer)).WithCancellation(this.TimeoutToken);
await observer.Completion.WithCancellation(this.TimeoutToken);
}

[Fact]
public async Task GiveObserverContainerTest()
{
var observer = new MockObserver();
await Task.Run(() => this.client.GiveObserverContainer(new Container<IObserver<int>> { Field = observer })).WithCancellation(this.TimeoutToken);
await observer.Completion.WithCancellation(this.TimeoutToken);
}

[Fact]
public async Task GetObserverTest()
{
var observer = await this.client.GetObserver();
observer.OnCompleted();
await this.server.Observer.Completion.WithCancellation(this.TimeoutToken);
}

[Fact]
public async Task GetObserverContainerTest()
{
var observer = (await this.client.GetObserverContainer()).Field!;
observer.OnCompleted();
await this.server.Observer.Completion.WithCancellation(this.TimeoutToken);
}

[Fact]
public async Task GiveMarshalableTest()
{
var marshalable = new MockMarshalable();
await Task.Run(() => this.client.GiveMarshalable(marshalable)).WithCancellation(this.TimeoutToken);
await marshalable.Completion.WithCancellation(this.TimeoutToken);
}

[Fact]
public async Task GiveMarshalableContainerTest()
{
var marshalable = new MockMarshalable();
await Task.Run(() => this.client.GiveMarshalableContainer(new Container<IMarshalable> { Field = marshalable })).WithCancellation(this.TimeoutToken);
await marshalable.Completion.WithCancellation(this.TimeoutToken);
}

[Fact]
public async Task GetMarshalableTest()
{
var marshalable = await this.client.GetMarshalable();
marshalable.DoSomething();
await this.server.Marshalable.Completion.WithCancellation(this.TimeoutToken);
}

[Fact]
public async Task GetMarshalableContainerTest()
{
var marshalable = (await this.client.GetMarshalableContainer()).Field!;
marshalable.DoSomething();
await this.server.Marshalable.Completion.WithCancellation(this.TimeoutToken);
}

private IJsonRpcMessageFormatter CreateFormatter()
{
var formatter = new JsonMessageFormatter();
formatter.JsonSerializer.ContractResolver = new CamelCasePropertyNamesContractResolver();

return formatter;
}

[DataContract]
protected class Container<T>
where T : class
{
[DataMember]
public T? Field { get; set; }
}

private class Server : IServer
{
internal MockObserver Observer { get; } = new MockObserver();

internal MockMarshalable Marshalable { get; } = new MockMarshalable();

public Task GiveObserver(IObserver<int> observer)
{
observer.OnCompleted();
return Task.CompletedTask;
}

public Task GiveObserverContainer(Container<IObserver<int>> observerContainer)
{
observerContainer.Field?.OnCompleted();
return Task.CompletedTask;
}

public Task<IObserver<int>> GetObserver() => Task.FromResult<IObserver<int>>(this.Observer);

public Task<Container<IObserver<int>>> GetObserverContainer() => Task.FromResult(new Container<IObserver<int>>() { Field = this.Observer });

public Task GiveMarshalable(IMarshalable marshalable)
{
marshalable.DoSomething();
return Task.CompletedTask;
}

public Task GiveMarshalableContainer(Container<IMarshalable> marshalableContainer)
{
marshalableContainer.Field?.DoSomething();
return Task.CompletedTask;
}

public Task<IMarshalable> GetMarshalable() => Task.FromResult<IMarshalable>(this.Marshalable);

public Task<Container<IMarshalable>> GetMarshalableContainer() => Task.FromResult(new Container<IMarshalable>() { Field = this.Marshalable });

void IDisposable.Dispose()
{
}
}

private class MockObserver : IObserver<int>
{
private readonly TaskCompletionSource<bool> completed = new TaskCompletionSource<bool>();

internal Task Completion => this.completed.Task;

public void OnCompleted() => this.completed.SetResult(true);

public void OnError(Exception error) => throw new NotImplementedException();

public void OnNext(int value) => throw new NotImplementedException();
}

private class MockMarshalable : IMarshalable
{
private readonly TaskCompletionSource<bool> completed = new TaskCompletionSource<bool>();

internal Task Completion => this.completed.Task;

public void DoSomething() => this.completed.SetResult(true);

public void Dispose()
{
}
}
}
5 changes: 0 additions & 5 deletions test/StreamJsonRpc.Tests/ObserverMarshalingTests.cs
Original file line number Diff line number Diff line change
@@ -1,12 +1,9 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT license. See LICENSE file in the project root for full license information.

using System;
using System.Collections.Immutable;
using System.Diagnostics;
using System.Linq;
using System.Runtime.CompilerServices;
using System.Threading.Tasks;
using Microsoft;
using Microsoft.VisualStudio.Threading;
using Nerdbank.Streams;
Expand Down Expand Up @@ -362,10 +359,8 @@ protected class MockObserver<T> : IObserver<T>

internal event EventHandler<T>? Next;

[System.Runtime.Serialization.IgnoreDataMember]
internal ImmutableList<T> ReceivedValues { get; private set; } = ImmutableList<T>.Empty;

[System.Runtime.Serialization.IgnoreDataMember]
internal Task<ImmutableList<T>> Completion => this.completed.Task;

internal AsyncAutoResetEvent ItemReceived { get; } = new AsyncAutoResetEvent();
Expand Down

0 comments on commit dd9e0c2

Please sign in to comment.