diff --git a/src/Ocelot/DependencyInjection/OcelotBuilder.cs b/src/Ocelot/DependencyInjection/OcelotBuilder.cs index 5bd011c50..159d336ca 100644 --- a/src/Ocelot/DependencyInjection/OcelotBuilder.cs +++ b/src/Ocelot/DependencyInjection/OcelotBuilder.cs @@ -37,6 +37,7 @@ using Ocelot.Security.IPSecurity; using Ocelot.ServiceDiscovery; using Ocelot.ServiceDiscovery.Providers; +using Ocelot.WebSockets; using System.Reflection; namespace Ocelot.DependencyInjection @@ -138,6 +139,7 @@ public OcelotBuilder(IServiceCollection services, IConfiguration configurationRo Services.TryAddSingleton(); Services.TryAddSingleton(); Services.TryAddSingleton(); + Services.TryAddSingleton(); // Add security Services.TryAddSingleton(); diff --git a/src/Ocelot/LoadBalancer/LoadBalancers/RoundRobin.cs b/src/Ocelot/LoadBalancer/LoadBalancers/RoundRobin.cs index 87e8e6aa6..834f01e4d 100644 --- a/src/Ocelot/LoadBalancer/LoadBalancers/RoundRobin.cs +++ b/src/Ocelot/LoadBalancer/LoadBalancers/RoundRobin.cs @@ -6,30 +6,35 @@ namespace Ocelot.LoadBalancer.LoadBalancers { public class RoundRobin : ILoadBalancer { - private readonly Func>> _services; + private readonly Func>> _servicesDelegate; private readonly object _lock = new(); private int _last; public RoundRobin(Func>> services) { - _services = services; + _servicesDelegate = services; } public async Task> Lease(HttpContext httpContext) { - var services = await _services(); - lock (_lock) + var services = await _servicesDelegate?.Invoke() ?? new List(); + + if (services?.Count != 0) { - if (_last >= services.Count) + lock (_lock) { - _last = 0; - } + if (_last >= services.Count) + { + _last = 0; + } - var next = services[_last]; - _last++; - return new OkResponse(next.HostAndPort); + var next = services[_last++]; + return new OkResponse(next.HostAndPort); + } } + + return new ErrorResponse(new ServicesAreEmptyError($"There were no services in {nameof(RoundRobin)} during {nameof(Lease)} operation.")); } public void Release(ServiceHostAndPort hostAndPort) diff --git a/src/Ocelot/Middleware/OcelotPipelineExtensions.cs b/src/Ocelot/Middleware/OcelotPipelineExtensions.cs index adfa98e8f..01ec573fb 100644 --- a/src/Ocelot/Middleware/OcelotPipelineExtensions.cs +++ b/src/Ocelot/Middleware/OcelotPipelineExtensions.cs @@ -18,7 +18,7 @@ using Ocelot.RequestId.Middleware; using Ocelot.Responder.Middleware; using Ocelot.Security.Middleware; -using Ocelot.WebSockets.Middleware; +using Ocelot.WebSockets; namespace Ocelot.Middleware { diff --git a/src/Ocelot/WebSockets/ClientWebSocketOptionsProxy.cs b/src/Ocelot/WebSockets/ClientWebSocketOptionsProxy.cs new file mode 100644 index 000000000..fea55c146 --- /dev/null +++ b/src/Ocelot/WebSockets/ClientWebSocketOptionsProxy.cs @@ -0,0 +1,35 @@ +using System.Net.Security; +using System.Net.WebSockets; +using System.Security.Cryptography.X509Certificates; + +namespace Ocelot.WebSockets; + +public class ClientWebSocketOptionsProxy : IClientWebSocketOptions +{ + private readonly ClientWebSocketOptions _real; + + public ClientWebSocketOptionsProxy(ClientWebSocketOptions options) + { + _real = options; + } + + public Version HttpVersion { get => _real.HttpVersion; set => _real.HttpVersion = value; } + public HttpVersionPolicy HttpVersionPolicy { get => _real.HttpVersionPolicy; set => _real.HttpVersionPolicy = value; } + public bool UseDefaultCredentials { get => _real.UseDefaultCredentials; set => _real.UseDefaultCredentials = value; } + public ICredentials Credentials { get => _real.Credentials; set => _real.Credentials = value; } + public IWebProxy Proxy { get => _real.Proxy; set => _real.Proxy = value; } + public X509CertificateCollection ClientCertificates { get => _real.ClientCertificates; set => _real.ClientCertificates = value; } + public RemoteCertificateValidationCallback RemoteCertificateValidationCallback { get => _real.RemoteCertificateValidationCallback; set => _real.RemoteCertificateValidationCallback = value; } + public CookieContainer Cookies { get => _real.Cookies; set => _real.Cookies = value; } + public TimeSpan KeepAliveInterval { get => _real.KeepAliveInterval; set => _real.KeepAliveInterval = value; } + public WebSocketDeflateOptions DangerousDeflateOptions { get => _real.DangerousDeflateOptions; set => _real.DangerousDeflateOptions = value; } + public bool CollectHttpResponseDetails { get => _real.CollectHttpResponseDetails; set => _real.CollectHttpResponseDetails = value; } + + public void AddSubProtocol(string subProtocol) => _real.AddSubProtocol(subProtocol); + + public void SetBuffer(int receiveBufferSize, int sendBufferSize) => _real.SetBuffer(receiveBufferSize, sendBufferSize); + + public void SetBuffer(int receiveBufferSize, int sendBufferSize, ArraySegment buffer) => _real.SetBuffer(receiveBufferSize, sendBufferSize, buffer); + + public void SetRequestHeader(string headerName, string headerValue) => _real.SetRequestHeader(headerName, headerValue); +} diff --git a/src/Ocelot/WebSockets/ClientWebSocketProxy.cs b/src/Ocelot/WebSockets/ClientWebSocketProxy.cs new file mode 100644 index 000000000..e74786591 --- /dev/null +++ b/src/Ocelot/WebSockets/ClientWebSocketProxy.cs @@ -0,0 +1,49 @@ +using System.Net.WebSockets; + +namespace Ocelot.WebSockets; + +public class ClientWebSocketProxy : WebSocket, IClientWebSocket +{ + // RealSubject (Service) class of Proxy design pattern + private readonly ClientWebSocket _realService; + private readonly IClientWebSocketOptions _options; + + public ClientWebSocketProxy() + { + _realService = new ClientWebSocket(); + _options = new ClientWebSocketOptionsProxy(_realService.Options); + } + + // ClientWebSocket implementations + public IClientWebSocketOptions Options => _options; + + public Task ConnectAsync(Uri uri, CancellationToken cancellationToken) + => _realService.ConnectAsync(uri, cancellationToken); + + // WebSocket implementations + public override WebSocketCloseStatus? CloseStatus => _realService.CloseStatus; + + public override string CloseStatusDescription => _realService.CloseStatusDescription; + + public override WebSocketState State => _realService.State; + + public override string SubProtocol => _realService.SubProtocol; + + public override void Abort() => _realService.Abort(); + + public override Task CloseAsync(WebSocketCloseStatus closeStatus, string statusDescription, CancellationToken cancellationToken) + => _realService.CloseAsync(closeStatus, statusDescription, cancellationToken); + + public override Task CloseOutputAsync(WebSocketCloseStatus closeStatus, string statusDescription, CancellationToken cancellationToken) + => _realService.CloseOutputAsync(closeStatus, statusDescription, cancellationToken); + + public override void Dispose() => _realService.Dispose(); + + public override Task ReceiveAsync(ArraySegment buffer, CancellationToken cancellationToken) + => _realService.ReceiveAsync(buffer, cancellationToken); + + public override Task SendAsync(ArraySegment buffer, WebSocketMessageType messageType, bool endOfMessage, CancellationToken cancellationToken) + => _realService.SendAsync(buffer, messageType, endOfMessage, cancellationToken); + + public WebSocket ToWebSocket() => _realService; +} diff --git a/src/Ocelot/WebSockets/IClientWebSocket.cs b/src/Ocelot/WebSockets/IClientWebSocket.cs new file mode 100644 index 000000000..3724f111c --- /dev/null +++ b/src/Ocelot/WebSockets/IClientWebSocket.cs @@ -0,0 +1,24 @@ +using System.Net.WebSockets; + +namespace Ocelot.WebSockets; + +public interface IClientWebSocket +{ + WebSocket ToWebSocket(); + + // ClientWebSocket definitions + IClientWebSocketOptions Options { get; } + Task ConnectAsync(Uri uri, CancellationToken cancellationToken); + + // WebSocket definitions + WebSocketCloseStatus? CloseStatus { get; } + string CloseStatusDescription { get; } + WebSocketState State { get; } + string SubProtocol { get; } + void Abort(); + Task CloseAsync(WebSocketCloseStatus closeStatus, string statusDescription, CancellationToken cancellationToken); + Task CloseOutputAsync(WebSocketCloseStatus closeStatus, string statusDescription, CancellationToken cancellationToken); + void Dispose(); + Task ReceiveAsync(ArraySegment buffer, CancellationToken cancellationToken); + Task SendAsync(ArraySegment buffer, WebSocketMessageType messageType, bool endOfMessage, CancellationToken cancellationToken); +} diff --git a/src/Ocelot/WebSockets/IClientWebSocketOptions.cs b/src/Ocelot/WebSockets/IClientWebSocketOptions.cs new file mode 100644 index 000000000..0f3db6bad --- /dev/null +++ b/src/Ocelot/WebSockets/IClientWebSocketOptions.cs @@ -0,0 +1,24 @@ +using System.Net.Security; +using System.Net.WebSockets; +using System.Security.Cryptography.X509Certificates; + +namespace Ocelot.WebSockets; + +public interface IClientWebSocketOptions +{ + Version HttpVersion { get; set; } + HttpVersionPolicy HttpVersionPolicy { get; set; } + void SetRequestHeader(string headerName, string headerValue); + bool UseDefaultCredentials { get; set; } + ICredentials Credentials { get; set; } + IWebProxy Proxy { get; set; } + X509CertificateCollection ClientCertificates { get; set; } + RemoteCertificateValidationCallback RemoteCertificateValidationCallback { get; set; } + CookieContainer Cookies { get; set; } + void AddSubProtocol(string subProtocol); + TimeSpan KeepAliveInterval { get; set; } + WebSocketDeflateOptions DangerousDeflateOptions { get; set; } + void SetBuffer(int receiveBufferSize, int sendBufferSize); + void SetBuffer(int receiveBufferSize, int sendBufferSize, ArraySegment buffer); + bool CollectHttpResponseDetails { get; set; } +} diff --git a/src/Ocelot/WebSockets/IWebSocketsFactory.cs b/src/Ocelot/WebSockets/IWebSocketsFactory.cs new file mode 100644 index 000000000..536e38fb5 --- /dev/null +++ b/src/Ocelot/WebSockets/IWebSocketsFactory.cs @@ -0,0 +1,6 @@ +namespace Ocelot.WebSockets; + +public interface IWebSocketsFactory +{ + IClientWebSocket CreateClient(); +} diff --git a/src/Ocelot/WebSockets/WebSocketsFactory.cs b/src/Ocelot/WebSockets/WebSocketsFactory.cs new file mode 100644 index 000000000..9482a87af --- /dev/null +++ b/src/Ocelot/WebSockets/WebSocketsFactory.cs @@ -0,0 +1,6 @@ +namespace Ocelot.WebSockets; + +public class WebSocketsFactory : IWebSocketsFactory +{ + public IClientWebSocket CreateClient() => new ClientWebSocketProxy(); +} diff --git a/src/Ocelot/WebSockets/Middleware/WebSocketsProxyMiddleware.cs b/src/Ocelot/WebSockets/WebSocketsProxyMiddleware.cs similarity index 69% rename from src/Ocelot/WebSockets/Middleware/WebSocketsProxyMiddleware.cs rename to src/Ocelot/WebSockets/WebSocketsProxyMiddleware.cs index fb28c48c8..abc8b646e 100644 --- a/src/Ocelot/WebSockets/Middleware/WebSocketsProxyMiddleware.cs +++ b/src/Ocelot/WebSockets/WebSocketsProxyMiddleware.cs @@ -1,26 +1,33 @@ -// Copyright (c) .NET Foundation. All rights reserved. +// Copyright (c) .NET Foundation. All rights reserved. // Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. // Modified https://github.com/aspnet/Proxy websockets class to use in Ocelot. using Microsoft.AspNetCore.Http; +using Ocelot.Configuration; using Ocelot.Logging; using Ocelot.Middleware; using System.Net.WebSockets; -namespace Ocelot.WebSockets.Middleware +namespace Ocelot.WebSockets { public class WebSocketsProxyMiddleware : OcelotMiddleware { - private static readonly string[] NotForwardedWebSocketHeaders = new[] { "Connection", "Host", "Upgrade", "Sec-WebSocket-Accept", "Sec-WebSocket-Protocol", "Sec-WebSocket-Key", "Sec-WebSocket-Version", "Sec-WebSocket-Extensions" }; + private static readonly string[] NotForwardedWebSocketHeaders = new[] + { + "Connection", "Host", "Upgrade", + "Sec-WebSocket-Accept", "Sec-WebSocket-Protocol", "Sec-WebSocket-Key", "Sec-WebSocket-Version", "Sec-WebSocket-Extensions", + }; private const int DefaultWebSocketBufferSize = 4096; - private const int StreamCopyBufferSize = 81920; private readonly RequestDelegate _next; + private readonly IWebSocketsFactory _factory; - public WebSocketsProxyMiddleware(RequestDelegate next, - IOcelotLoggerFactory loggerFactory) - : base(loggerFactory.CreateLogger()) + public WebSocketsProxyMiddleware(IOcelotLoggerFactory loggerFactory, + RequestDelegate next, + IWebSocketsFactory factory) + : base(loggerFactory.CreateLogger()) { _next = next; + _factory = factory; } private static async Task PumpWebSocket(WebSocket source, WebSocket destination, int bufferSize, CancellationToken cancellationToken) @@ -67,10 +74,11 @@ private static async Task PumpWebSocket(WebSocket source, WebSocket destination, public async Task Invoke(HttpContext httpContext) { var uri = httpContext.Items.DownstreamRequest().ToUri(); - await Proxy(httpContext, uri); + var downstreamRoute = httpContext.Items.DownstreamRoute(); + await Proxy(httpContext, uri, downstreamRoute); } - private static async Task Proxy(HttpContext context, string serverEndpoint) + private async Task Proxy(HttpContext context, string serverEndpoint, DownstreamRoute downstreamRoute) { if (context == null) { @@ -87,7 +95,14 @@ private static async Task Proxy(HttpContext context, string serverEndpoint) throw new InvalidOperationException(); } - var client = new ClientWebSocket(); + var client = _factory.CreateClient(); // new ClientWebSocket(); + + if (downstreamRoute.DangerousAcceptAnyServerCertificateValidator) + { + client.Options.RemoteCertificateValidationCallback = (request, certificate, chain, errors) => true; + Logger.LogWarning($"You have ignored all SSL warnings by using {nameof(DownstreamRoute.DangerousAcceptAnyServerCertificateValidator)} for this downstream route! {nameof(DownstreamRoute.UpstreamPathTemplate)}: '{downstreamRoute.UpstreamPathTemplate}', {nameof(DownstreamRoute.DownstreamPathTemplate)}: '{downstreamRoute.DownstreamPathTemplate}'."); + } + foreach (var protocol in context.WebSockets.WebSocketRequestedProtocols) { client.Options.AddSubProtocol(protocol); @@ -112,10 +127,12 @@ private static async Task Proxy(HttpContext context, string serverEndpoint) var destinationUri = new Uri(serverEndpoint); await client.ConnectAsync(destinationUri, context.RequestAborted); + using (var server = await context.WebSockets.AcceptWebSocketAsync(client.SubProtocol)) { - var bufferSize = DefaultWebSocketBufferSize; - await Task.WhenAll(PumpWebSocket(client, server, bufferSize, context.RequestAborted), PumpWebSocket(server, client, bufferSize, context.RequestAborted)); + await Task.WhenAll( + PumpWebSocket(client.ToWebSocket(), server, DefaultWebSocketBufferSize, context.RequestAborted), + PumpWebSocket(server, client.ToWebSocket(), DefaultWebSocketBufferSize, context.RequestAborted)); } } } diff --git a/src/Ocelot/WebSockets/Middleware/WebSocketsProxyMiddlewareExtensions.cs b/src/Ocelot/WebSockets/WebSocketsProxyMiddlewareExtensions.cs similarity index 88% rename from src/Ocelot/WebSockets/Middleware/WebSocketsProxyMiddlewareExtensions.cs rename to src/Ocelot/WebSockets/WebSocketsProxyMiddlewareExtensions.cs index f190c00c6..f02d117be 100644 --- a/src/Ocelot/WebSockets/Middleware/WebSocketsProxyMiddlewareExtensions.cs +++ b/src/Ocelot/WebSockets/WebSocketsProxyMiddlewareExtensions.cs @@ -1,6 +1,6 @@ using Microsoft.AspNetCore.Builder; -namespace Ocelot.WebSockets.Middleware +namespace Ocelot.WebSockets { public static class WebSocketsProxyMiddlewareExtensions { diff --git a/test/Ocelot.AcceptanceTests/ConsulWebSocketTests.cs b/test/Ocelot.AcceptanceTests/ConsulWebSocketTests.cs index c93b1df49..d03d7e6b9 100644 --- a/test/Ocelot.AcceptanceTests/ConsulWebSocketTests.cs +++ b/test/Ocelot.AcceptanceTests/ConsulWebSocketTests.cs @@ -2,6 +2,7 @@ using Microsoft.AspNetCore.Http; using Newtonsoft.Json; using Ocelot.Configuration.File; +using Ocelot.WebSockets; using System.Net.WebSockets; using System.Text; @@ -142,7 +143,7 @@ private async Task WhenIStartTheClients() private async Task StartClient(string url) { - var client = new ClientWebSocket(); + IClientWebSocket client = new ClientWebSocketProxy(); await client.ConnectAsync(new Uri(url), CancellationToken.None); @@ -194,7 +195,7 @@ private async Task StartSecondClient(string url) { await Task.Delay(500); - var client = new ClientWebSocket(); + IClientWebSocket client = new ClientWebSocketProxy(); await client.ConnectAsync(new Uri(url), CancellationToken.None); diff --git a/test/Ocelot.AcceptanceTests/WebSocketTests.cs b/test/Ocelot.AcceptanceTests/WebSocketTests.cs index e8e506993..0ede6baea 100644 --- a/test/Ocelot.AcceptanceTests/WebSocketTests.cs +++ b/test/Ocelot.AcceptanceTests/WebSocketTests.cs @@ -1,4 +1,5 @@ using Ocelot.Configuration.File; +using Ocelot.WebSockets; using System.Net.WebSockets; using System.Text; @@ -124,7 +125,7 @@ private async Task WhenIStartTheClients() private async Task StartClient(string url) { - var client = new ClientWebSocket(); + IClientWebSocket client = new ClientWebSocketProxy(); await client.ConnectAsync(new Uri(url), CancellationToken.None); @@ -176,7 +177,7 @@ private async Task StartSecondClient(string url) { await Task.Delay(500); - var client = new ClientWebSocket(); + IClientWebSocket client = new ClientWebSocketProxy(); await client.ConnectAsync(new Uri(url), CancellationToken.None); diff --git a/test/Ocelot.UnitTests/Middleware/OcelotPipelineExtensionsTests.cs b/test/Ocelot.UnitTests/Middleware/OcelotPipelineExtensionsTests.cs index f9426d816..f7f41ae6a 100644 --- a/test/Ocelot.UnitTests/Middleware/OcelotPipelineExtensionsTests.cs +++ b/test/Ocelot.UnitTests/Middleware/OcelotPipelineExtensionsTests.cs @@ -8,7 +8,7 @@ using Ocelot.LoadBalancer.Middleware; using Ocelot.Middleware; using Ocelot.Request.Middleware; -using Ocelot.WebSockets.Middleware; +using Ocelot.WebSockets; namespace Ocelot.UnitTests.Middleware { diff --git a/test/Ocelot.UnitTests/Ocelot.UnitTests.csproj b/test/Ocelot.UnitTests/Ocelot.UnitTests.csproj index b6335cc4b..0c9fe46f7 100644 --- a/test/Ocelot.UnitTests/Ocelot.UnitTests.csproj +++ b/test/Ocelot.UnitTests/Ocelot.UnitTests.csproj @@ -55,6 +55,7 @@ + runtime; build; native; contentfiles; analyzers; buildtransitive diff --git a/test/Ocelot.UnitTests/WebSockets/MockWebSocket.cs b/test/Ocelot.UnitTests/WebSockets/MockWebSocket.cs new file mode 100644 index 000000000..4e26e19fa --- /dev/null +++ b/test/Ocelot.UnitTests/WebSockets/MockWebSocket.cs @@ -0,0 +1,171 @@ +// Copyright © Kubernetes C# Client +// Repository: https://github.com/kubernetes-client/csharp +// Class: https://github.com/kubernetes-client/csharp/blob/master/tests/KubernetesClient.Tests/Mock/MockWebSocket.cs + +using Nito.AsyncEx; +using System.Collections.Concurrent; +using System.Net.WebSockets; + +namespace Ocelot.UnitTests.WebSockets; + +internal class MockWebSocket : WebSocket +{ + private WebSocketCloseStatus? closeStatus; + private string closeStatusDescription; + private WebSocketState state; + private readonly string subProtocol; + private readonly ConcurrentQueue receiveBuffers = new ConcurrentQueue(); + private readonly AsyncAutoResetEvent receiveEvent = new AsyncAutoResetEvent(false); + private bool disposedValue; + + public MockWebSocket(string subProtocol = null) + { + this.subProtocol = subProtocol; + } + + public void SetState(WebSocketState state) + { + this.state = state; + } + + public EventHandler MessageSent { get; set; } + + public Task InvokeReceiveAsync(ArraySegment buffer, WebSocketMessageType messageType, bool endOfMessage) + { + receiveBuffers.Enqueue(new MessageData() + { + Buffer = buffer, + MessageType = messageType, + EndOfMessage = endOfMessage, + }); + receiveEvent.Set(); + return Task.CompletedTask; + } + + public override WebSocketCloseStatus? CloseStatus => closeStatus; + + public override string CloseStatusDescription => closeStatusDescription; + + public override WebSocketState State => state; + + public override string SubProtocol => subProtocol; + + public override void Abort() + { + throw new NotImplementedException(); + } + + public override Task CloseAsync(WebSocketCloseStatus closeStatus, string statusDescription, + CancellationToken cancellationToken) + { + this.closeStatus = closeStatus; + closeStatusDescription = statusDescription; + receiveBuffers.Enqueue(new MessageData() + { + Buffer = new ArraySegment(new byte[] { }), + EndOfMessage = true, + MessageType = WebSocketMessageType.Close, + }); + receiveEvent.Set(); + return Task.CompletedTask; + } + + public override Task CloseOutputAsync(WebSocketCloseStatus closeStatus, string statusDescription, + CancellationToken cancellationToken) + { + throw new NotImplementedException(); + } + + public override async Task ReceiveAsync( + ArraySegment buffer, + CancellationToken cancellationToken) + { + if (receiveBuffers.IsEmpty) + { + await receiveEvent.WaitAsync(cancellationToken).ConfigureAwait(false); + } + + var bytesReceived = 0; + var endOfMessage = true; + var messageType = WebSocketMessageType.Close; + + MessageData received = null; + if (receiveBuffers.TryPeek(out received)) + { + messageType = received.MessageType; + if (received.Buffer.Count <= buffer.Count) + { + receiveBuffers.TryDequeue(out received); + received.Buffer.CopyTo(buffer); + bytesReceived = received.Buffer.Count; + endOfMessage = received.EndOfMessage; + } + else + { + received.Buffer.Slice(0, buffer.Count).CopyTo(buffer); + bytesReceived = buffer.Count; + endOfMessage = false; + received.Buffer = received.Buffer.Slice(buffer.Count); + } + } + + return new WebSocketReceiveResult(bytesReceived, messageType, endOfMessage); + } + + public override Task SendAsync(ArraySegment buffer, WebSocketMessageType messageType, bool endOfMessage, + CancellationToken cancellationToken) + { + MessageSent?.Invoke( + this, + new MessageDataEventArgs() + { + Data = new MessageData() + { + Buffer = buffer, + MessageType = messageType, + EndOfMessage = endOfMessage, + }, + }); + return Task.CompletedTask; + } + + public class MessageData + { + public ArraySegment Buffer { get; set; } + public WebSocketMessageType MessageType { get; set; } + public bool EndOfMessage { get; set; } + } + + public class MessageDataEventArgs : EventArgs + { + public MessageData Data { get; set; } + } + + protected virtual void Dispose(bool disposing) + { + if (!disposedValue) + { + if (disposing) + { + receiveBuffers.Clear(); + receiveEvent.Set(); + } + + disposedValue = true; + } + } + + // // TODO: override finalizer only if 'Dispose(bool disposing)' has code to free unmanaged resources + // ~MockWebSocket() + // { + // // Do not change this code. Put cleanup code in 'Dispose(bool disposing)' method + // Dispose(disposing: false); + // } + + public override void Dispose() + { + // Do not change this code. Put cleanup code in 'Dispose(bool disposing)' method + Dispose(true); + GC.SuppressFinalize(this); + } +} diff --git a/test/Ocelot.UnitTests/WebSockets/WebSocketsProxyMiddlewareTests.cs b/test/Ocelot.UnitTests/WebSockets/WebSocketsProxyMiddlewareTests.cs new file mode 100644 index 000000000..4f9afcef9 --- /dev/null +++ b/test/Ocelot.UnitTests/WebSockets/WebSocketsProxyMiddlewareTests.cs @@ -0,0 +1,124 @@ +using Microsoft.AspNetCore.Http; +using Ocelot.Configuration.Builder; +using Ocelot.Logging; +using Ocelot.Middleware; +using Ocelot.Request.Middleware; +using Ocelot.WebSockets; +using System.Net.Security; +using System.Net.WebSockets; + +namespace Ocelot.UnitTests.WebSockets; + +public class WebSocketsProxyMiddlewareTests +{ + private readonly WebSocketsProxyMiddleware _middleware; + + private readonly Mock _loggerFactory; + private readonly Mock _next; + private readonly Mock _factory; + + private readonly Mock _context; + private readonly Mock _logger; + + public WebSocketsProxyMiddlewareTests() + { + _loggerFactory = new Mock(); + _next = new Mock(); + _factory = new Mock(); + + _context = new Mock(); + _logger = new Mock(); + _loggerFactory.Setup(x => x.CreateLogger()) + .Returns(_logger.Object); + + _middleware = new WebSocketsProxyMiddleware(_loggerFactory.Object, _next.Object, _factory.Object); + } + + [Fact] + public void ShouldIgnoreAllSslWarnings_WhenDangerousAcceptAnyServerCertificateValidatorIsTrue() + { + this.Given(x => x.GivenPropertyDangerousAcceptAnyServerCertificateValidator(true)) + .And(x => x.AndDoNotSetupProtocolsAndHeaders()) + .And(x => x.AndDoNotConnectReally()) + .When(x => x.WhenInvokeWithHttpContext()) + .Then(x => x.ThenIgnoredAllSslWarnings()) + .BDDfy(); + } + + private void GivenPropertyDangerousAcceptAnyServerCertificateValidator(bool enabled) + { + var request = new HttpRequestMessage(HttpMethod.Get, "http://localhost:80"); + var downstream = new DownstreamRequest(request); + var route = new DownstreamRouteBuilder() + .WithDangerousAcceptAnyServerCertificateValidator(enabled) + .Build(); + _context.SetupGet(x => x.Items).Returns(new Dictionary + { + { "DownstreamRequest", downstream }, + { "DownstreamRoute", route }, + }); + + _context.SetupGet(x => x.WebSockets.IsWebSocketRequest).Returns(true); + + _client = new Mock(); + _factory.Setup(x => x.CreateClient()).Returns(_client.Object); + + _client.SetupSet(x => x.Options.RemoteCertificateValidationCallback = It.IsAny()) + .Callback(value => _callback = value); + + _warning = string.Empty; + _logger.Setup(x => x.LogWarning(It.IsAny())) + .Callback(message => _warning = message); + } + + private void AndDoNotSetupProtocolsAndHeaders() + { + _context.SetupGet(x => x.WebSockets.WebSocketRequestedProtocols).Returns(new List()); + _context.SetupGet(x => x.Request.Headers).Returns(new HeaderDictionary()); + } + + private void AndDoNotConnectReally() + { + _client.Setup(x => x.ConnectAsync(It.IsAny(), It.IsAny())).Verifiable(); + var clientSocket = new Mock(); + var serverSocket = new Mock(); + _client.Setup(x => x.ToWebSocket()).Returns(clientSocket.Object); + _context.Setup(x => x.WebSockets.AcceptWebSocketAsync(It.IsAny())).ReturnsAsync(serverSocket.Object); + + var happyEnd = new WebSocketReceiveResult(1, WebSocketMessageType.Close, true); + clientSocket.Setup(x => x.ReceiveAsync(It.IsAny>(), It.IsAny())) + .ReturnsAsync(happyEnd); + serverSocket.Setup(x => x.ReceiveAsync(It.IsAny>(), It.IsAny())) + .ReturnsAsync(happyEnd); + + clientSocket.Setup(x => x.CloseOutputAsync(It.IsAny(), It.IsAny(), It.IsAny())); + serverSocket.Setup(x => x.CloseOutputAsync(It.IsAny(), It.IsAny(), It.IsAny())); + clientSocket.SetupGet(x => x.CloseStatus).Returns(WebSocketCloseStatus.Empty); + serverSocket.SetupGet(x => x.CloseStatus).Returns(WebSocketCloseStatus.Empty); + } + + private Mock _client; + private RemoteCertificateValidationCallback _callback; + private string _warning; + + private async Task WhenInvokeWithHttpContext() + { + await _middleware.Invoke(_context.Object); + } + + private void ThenIgnoredAllSslWarnings() + { + _context.Object.Items.DownstreamRoute().DangerousAcceptAnyServerCertificateValidator + .ShouldBeTrue(); + + _logger.Verify(x => x.LogWarning(It.IsAny()), Times.Once()); + _warning.ShouldNotBeNullOrEmpty(); + + _client.VerifySet(x => x.Options.RemoteCertificateValidationCallback = It.IsAny(), + Times.Once()); + + _callback.ShouldNotBeNull(); + var validation = _callback.Invoke(null, null, null, SslPolicyErrors.None); + validation.ShouldBeTrue(); + } +}