diff --git a/src/Clients/js/dotnet/UiPath.CoreIpc.NodeInterop/Program.cs b/src/Clients/js/dotnet/UiPath.CoreIpc.NodeInterop/Program.cs index 47d71ba8..bb5ed827 100644 --- a/src/Clients/js/dotnet/UiPath.CoreIpc.NodeInterop/Program.cs +++ b/src/Clients/js/dotnet/UiPath.CoreIpc.NodeInterop/Program.cs @@ -120,15 +120,21 @@ IEnumerable EnumeratePings() { if (webSocketUrl is not null) { - yield return new WebSocketClient() + yield return new IpcClient { - Uri = new(webSocketUrl), - ServiceProvider = sp, - RequestTimeout = TimeSpan.FromHours(5), - Callbacks = new() + Config = new() { - { typeof(IArithmetic), callback } + ServiceProvider = sp, + RequestTimeout = TimeSpan.FromHours(5), + Callbacks = new() + { + { typeof(IArithmetic), callback } + }, }, + Transport = new WebSocketTransport + { + Uri = new(webSocketUrl), + } } .GetProxy() .Ping(); @@ -136,14 +142,20 @@ IEnumerable EnumeratePings() if (pipeName is not null) { - yield return new NamedPipeClient() + yield return new IpcClient { - PipeName = pipeName, - ServiceProvider = sp, - RequestTimeout = TimeSpan.FromHours(5), - Callbacks = new() + Config = new() + { + ServiceProvider = sp, + RequestTimeout = TimeSpan.FromHours(5), + Callbacks = new() + { + { typeof(IArithmetic), callback } + } + }, + Transport = new NamedPipeTransport() { - { typeof(IArithmetic), callback } + PipeName = pipeName, } } .GetProxy() diff --git a/src/Playground/Program.cs b/src/Playground/Program.cs index b5dd6925..0cf75920 100644 --- a/src/Playground/Program.cs +++ b/src/Playground/Program.cs @@ -40,7 +40,14 @@ private static async Task Main(string[] args) ServiceProvider = serverSP, Endpoints = new() { - typeof(Contracts.IServerOperations), + typeof(Contracts.IServerOperations), // DEVINE + new EndpointSettings(typeof(Contracts.IServerOperations)) // ASTALALT + { + BeforeCall = async (callInfo, _) => + { + Console.WriteLine($"Server: {callInfo.Method.Name}"); + } + }, typeof(Contracts.IClientOperations2) }, Listeners = [ @@ -73,21 +80,66 @@ private static async Task Main(string[] args) throw; } - var proxy1 = new NamedPipeClient() + var c1 = new IpcClient() { - PipeName = Contracts.PipeName, - ServerName = ".", - AllowImpersonation = false, + Config = new() + { + Callbacks = new() + { + typeof(Contracts.IClientOperations), + { typeof(Contracts.IClientOperations2), new Impl.Client2() }, + }, + ServiceProvider = clientSP, + Scheduler = clientScheduler, + }, + Transport = new NamedPipeTransport() + { + PipeName = Contracts.PipeName, + ServerName = ".", + AllowImpersonation = false, + }, + }; + + var c2 = new IpcClient() + { + Config = new() + { + ServiceProvider = clientSP, + Callbacks = new() + { + typeof(Contracts.IClientOperations), + { typeof(Contracts.IClientOperations2), new Impl.Client2() }, + }, + Scheduler = clientScheduler, + }, + Transport = new NamedPipeTransport() + { + PipeName = Contracts.PipeName, + ServerName = ".", + AllowImpersonation = false, + }, + }; - ServiceProvider = clientSP, - Callbacks = new() + var proxy1 = new IpcClient() + { + Config = new() { - typeof(Contracts.IClientOperations), - { typeof(Contracts.IClientOperations2), new Impl.Client2() } + ServiceProvider = clientSP, + Callbacks = new() + { + typeof(Contracts.IClientOperations), + { typeof(Contracts.IClientOperations2), new Impl.Client2() }, + }, + Scheduler = clientScheduler, }, - Scheduler = clientScheduler - } - .GetProxy(); + Transport = new NamedPipeTransport() + { + PipeName = Contracts.PipeName, + ServerName = ".", + AllowImpersonation = false, + }, + }.GetProxy(); + await proxy1.Register(); await proxy1.Broadcast("Hello Bidirectional Http!"); diff --git a/src/UiPath.CoreIpc/Client/IpcProxy.cs b/src/UiPath.CoreIpc/Client/IpcProxy.cs new file mode 100644 index 00000000..9203b1a5 --- /dev/null +++ b/src/UiPath.CoreIpc/Client/IpcProxy.cs @@ -0,0 +1,21 @@ +namespace UiPath.Ipc; + +public class IpcProxy : DispatchProxy, IDisposable +{ + internal ServiceClient ServiceClient { get; set; } = null!; + + protected override object? Invoke(MethodInfo? targetMethod, object?[]? args) + => ServiceClient.Invoke(targetMethod!, args!); + + public void Dispose() => ServiceClient?.Dispose(); + + public ValueTask CloseConnection() => ServiceClient.CloseConnection(); + + public event EventHandler ConnectionClosed + { + add => ServiceClient.ConnectionClosed += value; + remove => ServiceClient.ConnectionClosed -= value; + } + + public Stream? Network => ServiceClient.Network; +} diff --git a/src/UiPath.CoreIpc/Client/ServiceClient.cs b/src/UiPath.CoreIpc/Client/ServiceClient.cs index 4939651d..e546e4fd 100644 --- a/src/UiPath.CoreIpc/Client/ServiceClient.cs +++ b/src/UiPath.CoreIpc/Client/ServiceClient.cs @@ -1,8 +1,5 @@ namespace UiPath.Ipc; -using System.Linq.Expressions; -using ServiceClientProperFactory = Func; - internal abstract class ServiceClient : IDisposable { #region " NonGeneric-Generic adapter cache " @@ -19,7 +16,6 @@ private static InvokeDelegate CreateInvokeDelegate(Type returnType) #endregion protected abstract TimeSpan RequestTimeout { get; } - protected abstract ConnectionFactory? ConnectionFactory { get; } protected abstract BeforeCallHandler? BeforeCall { get; } protected abstract ILogger? Log { get; } protected abstract string DebugName { get; } @@ -151,47 +147,11 @@ private void Dispose(bool disposing) public override string ToString() => DebugName; } -internal static class ServiceClientProper -{ - private static readonly ConcurrentDictionary CachedFactories = new(); - private static ServiceClientProperFactory GetFactory(Type clientType) => CachedFactories.GetOrAdd(clientType, CreateFactory); - private static ServiceClientProperFactory CreateFactory(Type clientType) - { - if (clientType - .GetInterfaces() - .SingleOrDefault(x => x.IsGenericType && x.GetGenericTypeDefinition() == typeof(IClient<,>)) - ?.GetGenericArguments() - is not [var clientStateType, var clientType2] || clientType2 != clientType) - { - throw new ArgumentOutOfRangeException(nameof(clientType), "The client implements 0 or more than 1 IClient<,> interfaces or the single interface's 2nd generic argument is not the client type itself."); - } - - var ctor = typeof(ServiceClientProper<,>) - .MakeGenericType(clientType, clientStateType) - .GetConstructor([clientType, typeof(Type)])!; - - var paramofClientBase = Expression.Parameter(typeof(ClientBase)); - var paramofType = Expression.Parameter(typeof(Type)); - return Expression.Lambda( - Expression.New( - ctor, - Expression.Convert(paramofClientBase, clientType), - paramofType), - paramofClientBase, - paramofType).Compile(); - } - - public static ServiceClient Create(ClientBase client, Type proxyType) - => GetFactory(client.GetType())(client, proxyType); -} - -internal sealed class ServiceClientProper : ServiceClient - where TClient : ClientBase, IClient - where TClientState : class, IClientState, new() +internal sealed class ServiceClientProper : ServiceClient { private readonly FastAsyncLock _lock = new(); - private readonly TClientState _clientState = new(); - private readonly TClient _client; + private readonly IpcClient _client; + private readonly IClientState _clientState; private Connection? _latestConnection; private Server? _latestServer; @@ -222,9 +182,10 @@ private Connection? LatestConnection public override Stream? Network => LatestConnection?.Network; - public ServiceClientProper(TClient client, Type interfaceType) : base(interfaceType) + public ServiceClientProper(IpcClient client, Type interfaceType) : base(interfaceType) { _client = client; + _clientState = client.Transport.CreateState(); } public override async ValueTask CloseConnection() @@ -248,8 +209,8 @@ public override async ValueTask CloseConnection() } LatestConnection = new Connection(await Connect(ct), Serializer, Log, DebugName); - var router = new Router(_client.CreateCallbackRouterConfig(), _client.ServiceProvider); - _latestServer = new Server(router, _client.RequestTimeout, LatestConnection); + var router = new Router(_client.Config.CreateCallbackRouterConfig(), _client.Config.ServiceProvider); + _latestServer = new Server(router, _client.Config.RequestTimeout, LatestConnection); LatestConnection.Listen().LogException(Log, DebugName); return (LatestConnection, newlyConnected: true); } @@ -257,12 +218,6 @@ public override async ValueTask CloseConnection() private async Task Connect(CancellationToken ct) { - if (ConnectionFactory is not null - && await ConnectionFactory(_clientState.Network, ct) is { } userProvidedNetwork) - { - return userProvidedNetwork; - } - await _clientState.Connect(_client, ct); if (_clientState.Network is not { } network) @@ -273,12 +228,11 @@ private async Task Connect(CancellationToken ct) return network; } - protected override TimeSpan RequestTimeout => _client.RequestTimeout; - protected override ConnectionFactory? ConnectionFactory => _client.ConnectionFactory; - protected override BeforeCallHandler? BeforeCall => _client.BeforeCall; - protected override ILogger? Log => _client.Logger; - protected override string DebugName => _client.ToString(); - protected override ISerializer? Serializer => _client.Serializer; + protected override TimeSpan RequestTimeout => _client.Config.RequestTimeout; + protected override BeforeCallHandler? BeforeCall => _client.Config.BeforeCall; + protected override ILogger? Log => _client.Config.Logger; + protected override string DebugName => _client.Transport.ToString(); + protected override ISerializer? Serializer => _client.Config.Serializer; } internal sealed class ServiceClientForCallback : ServiceClient @@ -298,29 +252,8 @@ public ServiceClientForCallback(Connection connection, Listener listener, Type i => Task.FromResult((_connection, newlyConnected: false)); protected override TimeSpan RequestTimeout => _listener.Config.RequestTimeout; - protected override ConnectionFactory? ConnectionFactory => null; protected override BeforeCallHandler? BeforeCall => null; protected override ILogger? Log => null; protected override string DebugName => $"ReverseClient for {_listener}"; protected override ISerializer? Serializer => null; } - -public class IpcProxy : DispatchProxy, IDisposable -{ - internal ServiceClient ServiceClient { get; set; } = null!; - - protected override object? Invoke(MethodInfo? targetMethod, object?[]? args) - => ServiceClient.Invoke(targetMethod!, args!); - - public void Dispose() => ServiceClient?.Dispose(); - - public ValueTask CloseConnection() => ServiceClient.CloseConnection(); - - public event EventHandler ConnectionClosed - { - add => ServiceClient.ConnectionClosed += value; - remove => ServiceClient.ConnectionClosed -= value; - } - - public Stream? Network => ServiceClient.Network; -} diff --git a/src/UiPath.CoreIpc/Extensibility/ClientBase.cs b/src/UiPath.CoreIpc/Config/ClientConfig.cs similarity index 60% rename from src/UiPath.CoreIpc/Extensibility/ClientBase.cs rename to src/UiPath.CoreIpc/Config/ClientConfig.cs index 584d7449..f2aed93a 100644 --- a/src/UiPath.CoreIpc/Extensibility/ClientBase.cs +++ b/src/UiPath.CoreIpc/Config/ClientConfig.cs @@ -1,24 +1,16 @@ namespace UiPath.Ipc; -public abstract record ClientBase : EndpointConfig +public sealed record ClientConfig : EndpointConfig { - private readonly ConcurrentDictionary _clients = new(); - private ServiceClient GetServiceClient(Type proxyType) => _clients.GetOrAdd(proxyType, ServiceClientProper.Create(this, proxyType)); + public EndpointCollection? Callbacks { get; init; } public IServiceProvider? ServiceProvider { get; init; } - public EndpointCollection? Callbacks { get; init; } public ILogger? Logger { get; init; } - public ConnectionFactory? ConnectionFactory { get; init; } public BeforeCallHandler? BeforeCall { get; init; } public TaskScheduler? Scheduler { get; init; } public ISerializer? Serializer { get; set; } - public virtual void Validate() { } - - public TProxy GetProxy() where TProxy : class - => GetServiceClient(typeof(TProxy)).GetProxy(); - - internal void ValidateInternal() + internal void Validate() { var haveDeferredInjectedCallbacks = Callbacks?.Any(x => x.Service.MaybeGetServiceProvider() is null && x.Service.MaybeGetInstance() is null) ?? false; @@ -26,8 +18,6 @@ internal void ValidateInternal() { throw new InvalidOperationException("ServiceProvider is required when you register injectable callbacks. Consider registering a callback instance."); } - - Validate(); } internal ILogger? GetLogger(string name) @@ -55,16 +45,10 @@ internal override RouterConfig CreateCallbackRouterConfig() }); } -public interface IClient - where TSelf : ClientBase, IClient - where TState : class, IClientState, new() { } - -public interface IClientState : IDisposable - where TSelf : class, IClientState, new() - where TClient : ClientBase, IClient +public interface IClientState : IDisposable { Network? Network { get; } bool IsConnected(); - ValueTask Connect(TClient client, CancellationToken ct); + ValueTask Connect(IpcClient client, CancellationToken ct); } \ No newline at end of file diff --git a/src/UiPath.CoreIpc/Config/ClientTransport.cs b/src/UiPath.CoreIpc/Config/ClientTransport.cs new file mode 100644 index 00000000..2e21728a --- /dev/null +++ b/src/UiPath.CoreIpc/Config/ClientTransport.cs @@ -0,0 +1,7 @@ +namespace UiPath.Ipc; + +public abstract record ClientTransport +{ + public abstract IClientState CreateState(); + public abstract void Validate(); +} diff --git a/src/UiPath.CoreIpc/Config/IpcClient.cs b/src/UiPath.CoreIpc/Config/IpcClient.cs new file mode 100644 index 00000000..28ff2f37 --- /dev/null +++ b/src/UiPath.CoreIpc/Config/IpcClient.cs @@ -0,0 +1,32 @@ +namespace UiPath.Ipc; + +public sealed class IpcClient +{ + public required ClientConfig Config { get; init; } + public required ClientTransport Transport { get; init; } + + private readonly ConcurrentDictionary _clients = new(); + private ServiceClient GetServiceClient(Type proxyType) + { + return _clients.GetOrAdd(proxyType, Create); + + ServiceClient Create(Type proxyType) => new ServiceClientProper(this, proxyType); + } + public TProxy GetProxy() where TProxy : class + => GetServiceClient(typeof(TProxy)).GetProxy(); + + internal void Validate() + { + if (Config is null) + { + throw new InvalidOperationException($"{Config} is required."); + } + if (Transport is null) + { + throw new InvalidOperationException($"{Transport} is required."); + } + + Config.Validate(); + Transport.Validate(); + } +} diff --git a/src/UiPath.CoreIpc/GlobalUsings.cs b/src/UiPath.CoreIpc/GlobalUsings.cs index 40d72368..c3b8f439 100644 --- a/src/UiPath.CoreIpc/GlobalUsings.cs +++ b/src/UiPath.CoreIpc/GlobalUsings.cs @@ -1,6 +1,5 @@ global using UiPath.Ipc.Extensibility; global using BeforeCallHandler = System.Func; -global using ConnectionFactory = System.Func?, System.Threading.CancellationToken, System.Threading.Tasks.Task?>>; global using InvokeDelegate = System.Func; global using Accept = System.Func>; global using Network = UiPath.Ipc.Extensibility.OneOf; diff --git a/src/UiPath.CoreIpc/Server/EndpointSettings.cs b/src/UiPath.CoreIpc/Server/EndpointSettings.cs index 9f346b9d..26a39d6b 100644 --- a/src/UiPath.CoreIpc/Server/EndpointSettings.cs +++ b/src/UiPath.CoreIpc/Server/EndpointSettings.cs @@ -34,6 +34,9 @@ public EndpointSettings(Type contractType, IServiceProvider serviceProvider) : t private protected EndpointSettings(ServiceFactory service) => Service = service; + public virtual EndpointSettings WithServiceProvider(IServiceProvider? serviceProvider) + => new(Service.WithProvider(serviceProvider)); + public void Validate() { Validator.Validate(Service.Type); @@ -48,4 +51,8 @@ public sealed record EndpointSettings : EndpointSettings where TContr { public EndpointSettings(TContract? serviceInstance = null) : base(typeof(TContract), serviceInstance) { } public EndpointSettings(IServiceProvider serviceProvider) : base(typeof(TContract), serviceProvider) { } + private EndpointSettings(ServiceFactory service) : base(service) { } + + public override EndpointSettings WithServiceProvider(IServiceProvider? serviceProvider) + => new EndpointSettings(Service.WithProvider(serviceProvider)); } diff --git a/src/UiPath.CoreIpc/Transport/NamedPipe/NamedPipeClient.cs b/src/UiPath.CoreIpc/Transport/NamedPipe/NamedPipeClient.cs deleted file mode 100644 index 5f970e56..00000000 --- a/src/UiPath.CoreIpc/Transport/NamedPipe/NamedPipeClient.cs +++ /dev/null @@ -1,34 +0,0 @@ -using System.IO.Pipes; -using System.Security.Principal; - -namespace UiPath.Ipc.Transport.NamedPipe; - -public sealed record NamedPipeClient : ClientBase, IClient -{ - public required string PipeName { get; init; } - public string ServerName { get; init; } = "."; - public bool AllowImpersonation { get; init; } = false; - - public override string ToString() => $"ClientPipe={PipeName}"; -} - -internal sealed class NamedPipeClientState : IClientState -{ - private NamedPipeClientStream? _pipe; - - public Network? Network => _pipe; - public bool IsConnected() => _pipe?.IsConnected is true; - - public async ValueTask Connect(NamedPipeClient client, CancellationToken ct) - { - _pipe = new NamedPipeClientStream( - client.ServerName, - client.PipeName, - PipeDirection.InOut, - PipeOptions.Asynchronous, - client.AllowImpersonation ? TokenImpersonationLevel.Impersonation : TokenImpersonationLevel.Identification); - await _pipe.ConnectAsync(ct); - } - - public void Dispose() => _pipe?.Dispose(); -} diff --git a/src/UiPath.CoreIpc/Transport/NamedPipe/NamedPipeTransport.cs b/src/UiPath.CoreIpc/Transport/NamedPipe/NamedPipeTransport.cs new file mode 100644 index 00000000..56f60d36 --- /dev/null +++ b/src/UiPath.CoreIpc/Transport/NamedPipe/NamedPipeTransport.cs @@ -0,0 +1,51 @@ +using System.IO.Pipes; +using System.Security.Principal; + +namespace UiPath.Ipc.Transport.NamedPipe; + +public sealed record NamedPipeTransport : ClientTransport +{ + public required string PipeName { get; init; } + public string ServerName { get; init; } = "."; + public bool AllowImpersonation { get; init; } + + public override string ToString() => $"ClientPipe={PipeName}"; + + public override IClientState CreateState() => new NamedPipeClientState(); + + public override void Validate() + { + if (PipeName is null or "") + { + throw new InvalidOperationException($"{nameof(PipeName)} is required."); + } + if (ServerName is null or "") + { + throw new InvalidOperationException($"{nameof(ServerName)} is required."); + } + } +} + +internal sealed class NamedPipeClientState : IClientState +{ + private NamedPipeClientStream? _pipe; + + public Network? Network => _pipe; + public bool IsConnected() => _pipe?.IsConnected is true; + + public async ValueTask Connect(IpcClient client, CancellationToken ct) + { + var transport = client.Transport as NamedPipeTransport ?? throw new InvalidOperationException(); + + _pipe = new NamedPipeClientStream( + transport.ServerName, + transport.PipeName, + PipeDirection.InOut, + PipeOptions.Asynchronous, + transport.AllowImpersonation ? TokenImpersonationLevel.Impersonation : TokenImpersonationLevel.Identification); + + await _pipe.ConnectAsync(ct); + } + + public void Dispose() => _pipe?.Dispose(); +} diff --git a/src/UiPath.CoreIpc/Transport/Tcp/TcpClient.cs b/src/UiPath.CoreIpc/Transport/Tcp/TcpTransport.cs similarity index 53% rename from src/UiPath.CoreIpc/Transport/Tcp/TcpClient.cs rename to src/UiPath.CoreIpc/Transport/Tcp/TcpTransport.cs index 340d24d1..e89b8f8a 100644 --- a/src/UiPath.CoreIpc/Transport/Tcp/TcpClient.cs +++ b/src/UiPath.CoreIpc/Transport/Tcp/TcpTransport.cs @@ -2,14 +2,24 @@ namespace UiPath.Ipc.Transport.Tcp; -public sealed record TcpClient : ClientBase, IClient +public sealed record TcpTransport : ClientTransport { public required IPEndPoint EndPoint { get; init; } public override string ToString() => $"TcpClient={EndPoint}"; + + public override IClientState CreateState() => new TcpClientState(); + + public override void Validate() + { + if (EndPoint is null) + { + throw new InvalidOperationException($"{nameof(EndPoint)} is required."); + } + } } -internal sealed class TcpClientState : IClientState +internal sealed class TcpClientState : IClientState { private System.Net.Sockets.TcpClient? _tcpClient; @@ -20,14 +30,16 @@ public bool IsConnected() return _tcpClient?.Client?.Connected is true; } - public async ValueTask Connect(TcpClient client, CancellationToken ct) + public async ValueTask Connect(IpcClient client, CancellationToken ct) { + var transport = client.Transport as TcpTransport ?? throw new InvalidOperationException(); + _tcpClient = new System.Net.Sockets.TcpClient(); #if NET461 using var ctreg = ct.Register(_tcpClient.Dispose); try { - await _tcpClient.ConnectAsync(client.EndPoint.Address, client.EndPoint.Port); + await _tcpClient.ConnectAsync(transport.EndPoint.Address, transport.EndPoint.Port); } catch (ObjectDisposedException) { @@ -35,7 +47,7 @@ public async ValueTask Connect(TcpClient client, CancellationToken ct) throw new OperationCanceledException(ct); } #else - await _tcpClient.ConnectAsync(client.EndPoint.Address, client.EndPoint.Port, ct); + await _tcpClient.ConnectAsync(transport.EndPoint.Address, transport.EndPoint.Port, ct); #endif Network = _tcpClient.GetStream(); } diff --git a/src/UiPath.CoreIpc/Transport/WebSocket/WebSocketClient.cs b/src/UiPath.CoreIpc/Transport/WebSocket/WebSocketClient.cs deleted file mode 100644 index 7b9cd0a6..00000000 --- a/src/UiPath.CoreIpc/Transport/WebSocket/WebSocketClient.cs +++ /dev/null @@ -1,27 +0,0 @@ -using System.Net.WebSockets; - -namespace UiPath.Ipc.Transport.WebSocket; - -public sealed record WebSocketClient : ClientBase, IClient -{ - public required Uri Uri { get; init; } - public override string ToString() => $"WebSocketClient={Uri}"; -} - -internal sealed class WebSocketClientState : IClientState -{ - private ClientWebSocket? _clientWebSocket; - - public Network? Network { get; private set; } - - public bool IsConnected() => _clientWebSocket?.State is WebSocketState.Open; - - public async ValueTask Connect(WebSocketClient client, CancellationToken ct) - { - _clientWebSocket = new(); - await _clientWebSocket.ConnectAsync(client.Uri, ct); - Network = new WebSocketStream(_clientWebSocket); - } - - public void Dispose() => _clientWebSocket?.Dispose(); -} diff --git a/src/UiPath.CoreIpc/Transport/WebSocket/WebSocketTransport.cs b/src/UiPath.CoreIpc/Transport/WebSocket/WebSocketTransport.cs new file mode 100644 index 00000000..aa2fe47c --- /dev/null +++ b/src/UiPath.CoreIpc/Transport/WebSocket/WebSocketTransport.cs @@ -0,0 +1,39 @@ +using System.Net.WebSockets; + +namespace UiPath.Ipc.Transport.WebSocket; + +public sealed record WebSocketTransport : ClientTransport +{ + public required Uri Uri { get; init; } + public override string ToString() => $"WebSocketClient={Uri}"; + + public override IClientState CreateState() => new WebSocketClientState(); + + public override void Validate() + { + if (Uri is null) + { + throw new InvalidOperationException($"{nameof(Uri)} is required."); + } + } +} + +internal sealed class WebSocketClientState : IClientState +{ + private ClientWebSocket? _clientWebSocket; + + public Network? Network { get; private set; } + + public bool IsConnected() => _clientWebSocket?.State is WebSocketState.Open; + + public async ValueTask Connect(IpcClient client, CancellationToken ct) + { + var transport = client.Transport as WebSocketTransport ?? throw new InvalidOperationException(); + + _clientWebSocket = new(); + await _clientWebSocket.ConnectAsync(transport.Uri, ct); + Network = new WebSocketStream(_clientWebSocket); + } + + public void Dispose() => _clientWebSocket?.Dispose(); +} diff --git a/src/UiPath.Ipc.Tests/ComputingTests.cs b/src/UiPath.Ipc.Tests/ComputingTests.cs index 2a03e5e2..1fea8e13 100644 --- a/src/UiPath.Ipc.Tests/ComputingTests.cs +++ b/src/UiPath.Ipc.Tests/ComputingTests.cs @@ -30,10 +30,10 @@ protected override ListenerConfig ConfigTransportAgnostic(ListenerConfig listene { ConcurrentAccepts = 10, RequestTimeout = Timeouts.DefaultRequest, - MaxReceivedMessageSizeInMegabytes = 1, + MaxReceivedMessageSizeInMegabytes = 1, }; - protected override ClientBase ConfigTransportAgnostic(ClientBase client) - => client with + protected override ClientConfig CreateClientConfig() + => new() { RequestTimeout = Timeouts.DefaultRequest, Scheduler = GuiScheduler, @@ -93,8 +93,7 @@ await Proxy.GetCallbackThreadName( private sealed class ShortClientTimeout : OverrideConfig { - public override ClientBase Override(ClientBase client) - => client with { RequestTimeout = TimeSpan.FromMilliseconds(10) }; + public override IpcClient Override(IpcClient client) => client.WithRequestTimeout(TimeSpan.FromMilliseconds(10)); } [Theory, IpcAutoData] @@ -125,4 +124,21 @@ public async Task BeforeCall_ShouldApplyToCallsButNotToToCallbacks() _serverBeforeCalls.ShouldContain(x => x.Method.Name == nameof(IComputingService.GetCallbackThreadName)); _serverBeforeCalls.ShouldNotContain(x => x.Method.Name == nameof(IComputingCallback.GetThreadName)); } + + [Fact] + public async Task ServerBeforeCall_WhenSync_ShouldShareAsyncLocalContextWithTheTargetMethodCall() + { + await Proxy.GetCallContext().ShouldBeAsync(null); + + var id = $"{Guid.NewGuid():N}"; + var expectedCallContext = $"{nameof(IComputingService.GetCallContext)}-{id}"; + + _tailBeforeCall = (callInfo, _) => + { + ComputingService.CallContext = $"{callInfo.Method.Name}-{id}"; + return Task.CompletedTask; + }; + + await Proxy.GetCallContext().ShouldBeAsync(expectedCallContext); + } } diff --git a/src/UiPath.Ipc.Tests/ComputingTestsOverNamedPipes.cs b/src/UiPath.Ipc.Tests/ComputingTestsOverNamedPipes.cs index 3f020789..8811da92 100644 --- a/src/UiPath.Ipc.Tests/ComputingTestsOverNamedPipes.cs +++ b/src/UiPath.Ipc.Tests/ComputingTestsOverNamedPipes.cs @@ -13,9 +13,9 @@ public ComputingTestsOverNamedPipes(ITestOutputHelper outputHelper) : base(outpu { PipeName = PipeName }; - protected override ClientBase CreateClient() => new NamedPipeClient() + protected override ClientTransport CreateClientTransport() => new NamedPipeTransport() { PipeName = PipeName, AllowImpersonation = true, - }; + }; } diff --git a/src/UiPath.Ipc.Tests/ComputingTestsOverTcp.cs b/src/UiPath.Ipc.Tests/ComputingTestsOverTcp.cs index 70cc9c07..d6fdf0ae 100644 --- a/src/UiPath.Ipc.Tests/ComputingTestsOverTcp.cs +++ b/src/UiPath.Ipc.Tests/ComputingTestsOverTcp.cs @@ -16,9 +16,6 @@ protected override ListenerConfig CreateListener() EndPoint = _endPoint, }; - protected override ClientBase CreateClient() - => new TcpClient() - { - EndPoint = _endPoint, - }; + protected override ClientTransport CreateClientTransport() + => new TcpTransport() { EndPoint = _endPoint }; } diff --git a/src/UiPath.Ipc.Tests/ComputingTestsOverWebSockets.cs b/src/UiPath.Ipc.Tests/ComputingTestsOverWebSockets.cs index 3f4dcdb2..fd571df5 100644 --- a/src/UiPath.Ipc.Tests/ComputingTestsOverWebSockets.cs +++ b/src/UiPath.Ipc.Tests/ComputingTestsOverWebSockets.cs @@ -19,8 +19,6 @@ protected override async Task DisposeAsync() { Accept = _webSocketContext.Accept, }; - protected override ClientBase CreateClient() => new WebSocketClient() - { - Uri = _webSocketContext.ClientUri, - }; + protected override ClientTransport CreateClientTransport() + => new WebSocketTransport() { Uri = _webSocketContext.ClientUri }; } \ No newline at end of file diff --git a/src/UiPath.Ipc.Tests/Config/OverrideConfigAttribute.cs b/src/UiPath.Ipc.Tests/Config/OverrideConfigAttribute.cs index 2898fdbf..102278c8 100644 --- a/src/UiPath.Ipc.Tests/Config/OverrideConfigAttribute.cs +++ b/src/UiPath.Ipc.Tests/Config/OverrideConfigAttribute.cs @@ -35,5 +35,5 @@ public OverrideConfigAttribute(Type overrideConfigType) public abstract class OverrideConfig { public virtual ListenerConfig Override(ListenerConfig listener) => listener; - public virtual ClientBase Override(ClientBase client) => client; + public virtual IpcClient Override(IpcClient client) => client; } \ No newline at end of file diff --git a/src/UiPath.Ipc.Tests/Helpers/IpcHelpers.cs b/src/UiPath.Ipc.Tests/Helpers/IpcHelpers.cs index c3142da4..d3263494 100644 --- a/src/UiPath.Ipc.Tests/Helpers/IpcHelpers.cs +++ b/src/UiPath.Ipc.Tests/Helpers/IpcHelpers.cs @@ -32,3 +32,20 @@ public static IServiceProvider GetRequired(this IServiceProvider serviceProvi return serviceProvider; } } + +internal static class IpcClientExtensions +{ + public static IpcClient WithRequestTimeout(this IpcClient ipcClient, TimeSpan requestTimeout) + => new() + { + Config = ipcClient.Config with { RequestTimeout = requestTimeout }, + Transport = ipcClient.Transport, + }; + + public static IpcClient WithCallbacks(this IpcClient ipcClient, EndpointCollection callbacks) + => new() + { + Config = ipcClient.Config with { Callbacks = callbacks }, + Transport = ipcClient.Transport, + }; +} \ No newline at end of file diff --git a/src/UiPath.Ipc.Tests/Services/ComputingService.cs b/src/UiPath.Ipc.Tests/Services/ComputingService.cs index 7e623163..7f34afa4 100644 --- a/src/UiPath.Ipc.Tests/Services/ComputingService.cs +++ b/src/UiPath.Ipc.Tests/Services/ComputingService.cs @@ -5,6 +5,13 @@ namespace UiPath.Ipc.Tests; public sealed class ComputingService(ILogger logger) : IComputingService { + private static readonly AsyncLocal CallContextStorage = new(); + public static string? CallContext + { + get => CallContextStorage.Value; + set => CallContextStorage.Value = value; + } + public async Task AddFloats(float a, float b, CancellationToken ct = default) { logger.LogInformation($"{nameof(AddFloats)} called."); @@ -51,4 +58,10 @@ public async Task MultiplyInts(int x, int y, Message message = null!) return result; } + + public async Task GetCallContext() + { + await Task.Delay(1).ConfigureAwait(continueOnCapturedContext: false); + return CallContext; + } } diff --git a/src/UiPath.Ipc.Tests/Services/IComputingService.cs b/src/UiPath.Ipc.Tests/Services/IComputingService.cs index c0de004d..1db7cefb 100644 --- a/src/UiPath.Ipc.Tests/Services/IComputingService.cs +++ b/src/UiPath.Ipc.Tests/Services/IComputingService.cs @@ -13,6 +13,7 @@ public interface IComputingService : IComputingServiceBase Task GetCallbackThreadName(TimeSpan duration, Message message = null!, CancellationToken cancellationToken = default); Task AddComplexNumberList(IReadOnlyList numbers); Task MultiplyInts(int x, int y, Message message = null!); + Task GetCallContext(); } public interface IComputingCallbackBase diff --git a/src/UiPath.Ipc.Tests/SystemTests.cs b/src/UiPath.Ipc.Tests/SystemTests.cs index 6c77b913..aec5837f 100644 --- a/src/UiPath.Ipc.Tests/SystemTests.cs +++ b/src/UiPath.Ipc.Tests/SystemTests.cs @@ -29,8 +29,8 @@ protected override ListenerConfig ConfigTransportAgnostic(ListenerConfig listene RequestTimeout = Timeouts.DefaultRequest, MaxReceivedMessageSizeInMegabytes = 1, }; - protected override ClientBase ConfigTransportAgnostic(ClientBase client) - => client with + protected override ClientConfig CreateClientConfig() + => new() { RequestTimeout = Timeouts.DefaultRequest, ServiceProvider = ServiceProvider @@ -86,13 +86,15 @@ public async Task ClientWaitingForTooLongACall_ShouldThrowTimeout() private sealed class ServerExecutingTooLongACall_ShouldThrowTimeout_Config : OverrideConfig { public override ListenerConfig Override(ListenerConfig listener) => listener with { RequestTimeout = Timeouts.Short }; - public override ClientBase Override(ClientBase client) => client with { RequestTimeout = Timeout.InfiniteTimeSpan }; + public override IpcClient Override(IpcClient client) + => client.WithRequestTimeout(Timeout.InfiniteTimeSpan); } private sealed class ClientWaitingForTooLongACall_ShouldThrowTimeout_Config : OverrideConfig { public override ListenerConfig Override(ListenerConfig listener) => listener with { RequestTimeout = Timeout.InfiniteTimeSpan }; - public override ClientBase Override(ClientBase client) => client with { RequestTimeout = Timeouts.IpcRoundtrip }; + public override IpcClient Override(IpcClient client) + => client.WithRequestTimeout(Timeouts.IpcRoundtrip); } private ListenerConfig ShortClientTimeout(ListenerConfig listener) => listener with { RequestTimeout = TimeSpan.FromMilliseconds(100) }; @@ -155,15 +157,12 @@ public async Task ServerCallingMultipleCallbackTypes_ShouldWork() private sealed class RegisterCallbacks : OverrideConfig { - public override ClientBase Override(ClientBase client) - => client with + public override IpcClient Override(IpcClient client) + => client.WithCallbacks(new() { - Callbacks = new() - { - { typeof(IComputingCallback), new ComputingCallback() }, - { typeof(IArithmeticCallback), new ArithmeticCallback() }, - } - }; + { typeof(IComputingCallback), new ComputingCallback() }, + { typeof(IArithmeticCallback), new ArithmeticCallback() }, + }); } [Fact] diff --git a/src/UiPath.Ipc.Tests/SystemTestsOverNamedPipes.cs b/src/UiPath.Ipc.Tests/SystemTestsOverNamedPipes.cs index 4e2ced4c..02b6c6da 100644 --- a/src/UiPath.Ipc.Tests/SystemTestsOverNamedPipes.cs +++ b/src/UiPath.Ipc.Tests/SystemTestsOverNamedPipes.cs @@ -13,7 +13,7 @@ public SystemTestsOverNamedPipes(ITestOutputHelper outputHelper) : base(outputHe { PipeName = PipeName }; - protected sealed override ClientBase CreateClient() => new NamedPipeClient() + protected sealed override ClientTransport CreateClientTransport() => new NamedPipeTransport() { PipeName = PipeName, AllowImpersonation = true, diff --git a/src/UiPath.Ipc.Tests/SystemTestsOverTcp.cs b/src/UiPath.Ipc.Tests/SystemTestsOverTcp.cs index ae0ad4b8..cb159518 100644 --- a/src/UiPath.Ipc.Tests/SystemTestsOverTcp.cs +++ b/src/UiPath.Ipc.Tests/SystemTestsOverTcp.cs @@ -16,9 +16,6 @@ protected sealed override ListenerConfig CreateListener() EndPoint = _endPoint, }; - protected override ClientBase CreateClient() - => new TcpClient() - { - EndPoint = _endPoint, - }; + protected override ClientTransport CreateClientTransport() + => new TcpTransport() { EndPoint = _endPoint }; } diff --git a/src/UiPath.Ipc.Tests/SystemTestsOverWebSockets.cs b/src/UiPath.Ipc.Tests/SystemTestsOverWebSockets.cs index 0762ea07..2884106f 100644 --- a/src/UiPath.Ipc.Tests/SystemTestsOverWebSockets.cs +++ b/src/UiPath.Ipc.Tests/SystemTestsOverWebSockets.cs @@ -19,8 +19,6 @@ protected override async Task DisposeAsync() { Accept = _webSocketContext.Accept, }; - protected override ClientBase CreateClient() => new WebSocketClient() - { - Uri = _webSocketContext.ClientUri, - }; + protected override ClientTransport CreateClientTransport() + => new WebSocketTransport() { Uri = _webSocketContext.ClientUri }; } diff --git a/src/UiPath.Ipc.Tests/TestBase.cs b/src/UiPath.Ipc.Tests/TestBase.cs index d7a6b342..70518a36 100644 --- a/src/UiPath.Ipc.Tests/TestBase.cs +++ b/src/UiPath.Ipc.Tests/TestBase.cs @@ -12,6 +12,7 @@ public abstract class TestBase : IAsyncLifetime private readonly ServiceProvider _serviceProvider; private readonly AsyncContext _guiThread = new AsyncContextThread().Context; private readonly Lazy _ipcServer; + private readonly Lazy _ipcClient; private readonly OverrideConfig? _overrideConfig; protected TestRunId TestRunId { get; } = TestRunId.New(); @@ -22,6 +23,7 @@ public abstract class TestBase : IAsyncLifetime protected abstract Type ContractType { get; } protected readonly ConcurrentBag _serverBeforeCalls = new(); + protected Func? _tailBeforeCall = null; public TestBase(ITestOutputHelper outputHelper) { @@ -43,18 +45,8 @@ public TestBase(ITestOutputHelper outputHelper) _guiThread.SynchronizationContext.Send(() => Thread.CurrentThread.Name = Names.GuiThreadName); _serviceProvider = IpcHelpers.ConfigureServices(_outputHelper); - _ipcServer = new(() => new() - { - Endpoints = new() { - new EndpointSettings(ContractType) - { - BeforeCall = async (callInfo, _) => _serverBeforeCalls.Add(callInfo) - } - }, - Listeners = [CreateListenerAndConfigure()], - ServiceProvider = _serviceProvider, - Scheduler = GuiScheduler - }); + _ipcServer = new(CreateServer); + _ipcClient = new(CreateClient); OverrideConfig? GetOverrideConfig() { @@ -94,35 +86,47 @@ private ListenerConfig CreateListenerAndConfigure() _outputHelper.WriteLine($" Result:\r\n\t\t{listener}\r\n"); return listener; } - private TContract CreateClientAndConfigure() where TContract : class + + private IpcServer CreateServer() + => new() { - _outputHelper.WriteLine("Creating client..."); - _outputHelper.WriteLine(" - Creating transport specific client..."); - var client = CreateClient(); - client = ConfigTransportAgnostic(client); - _outputHelper.WriteLine($" Result:\r\n\t\t{client}"); - _outputHelper.WriteLine(" - Applying transport agnostic configuration..."); - _outputHelper.WriteLine($" Result:\r\n\t\t{client}"); - if (_overrideConfig is null) - { - _outputHelper.WriteLine($" - No configuration override found for method {CustomTestFramework.Context?.Method.Name}"); - } - else + Endpoints = new() { + new EndpointSettings(ContractType) + { + BeforeCall = (callInfo, ct) => + { + _serverBeforeCalls.Add(callInfo); + return _tailBeforeCall?.Invoke(callInfo, ct) ?? Task.CompletedTask; + } + } + }, + Listeners = [CreateListenerAndConfigure()], + ServiceProvider = _serviceProvider, + Scheduler = GuiScheduler + }; + private IpcClient CreateClient() + { + var config = CreateClientConfig(); + var transport = CreateClientTransport(); + var client = new IpcClient { - _outputHelper.WriteLine($" - Applying configuration override provided by \"{_overrideConfig.GetType().Name}\" ..."); - } + Config = config, + Transport = transport + }; client = _overrideConfig?.Override(client) ?? client; - _outputHelper.WriteLine($" Result:\r\n\t\t{client}\r\n"); - return client.GetProxy(); + return client; } + private TContract GetProxy() where TContract : class + => _ipcClient.Value.GetProxy(); - protected void CreateLazyProxy(out Lazy lazy) where TContract : class => lazy = new(CreateClientAndConfigure); + protected void CreateLazyProxy(out Lazy lazy) where TContract : class => lazy = new(GetProxy); protected abstract ListenerConfig CreateListener(); - protected abstract ClientBase CreateClient(); + + protected abstract ClientConfig CreateClientConfig(); + protected abstract ClientTransport CreateClientTransport(); protected abstract ListenerConfig ConfigTransportAgnostic(ListenerConfig listener); - protected abstract ClientBase ConfigTransportAgnostic(ClientBase client); protected virtual async Task DisposeAsync() {