Skip to content

Commit

Permalink
#1375 #1237 #925 #920 Fix DownstreamRoute DangerousAcceptAnyServerCer…
Browse files Browse the repository at this point in the history
…tificateValidator (#1377)

* Fix/1375 fix DownstreamRoute DangerousAcceptAnyServerCertificateValidator does not work

* Fix the exception when RoundRobin services is empty

* Fix build errors

* IDE0063 'using' statement can be simplified

* IDE0051 Private member 'StreamCopyBufferSize' is unused

* Use nameof() in string interpolations

* @RaynaldM code review

* Code review. Refactor method

* Organize folders for WebSockets feature

* Add unit tests class for WebSockets feature

* Refactor middleware to make it suitable for unit testing

* Add unit test

* Review current acceptance tests for WebSockets

* Review

---------

Co-authored-by: raman-m <[email protected]>
  • Loading branch information
zqlovejyc and raman-m authored Sep 28, 2023
1 parent 5fd5bf9 commit fa179bf
Show file tree
Hide file tree
Showing 17 changed files with 495 additions and 29 deletions.
2 changes: 2 additions & 0 deletions src/Ocelot/DependencyInjection/OcelotBuilder.cs
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@
using Ocelot.Security.IPSecurity;
using Ocelot.ServiceDiscovery;
using Ocelot.ServiceDiscovery.Providers;
using Ocelot.WebSockets;
using System.Reflection;

namespace Ocelot.DependencyInjection
Expand Down Expand Up @@ -138,6 +139,7 @@ public OcelotBuilder(IServiceCollection services, IConfiguration configurationRo
Services.TryAddSingleton<IQoSFactory, QoSFactory>();
Services.TryAddSingleton<IExceptionToErrorMapper, HttpExceptionToErrorMapper>();
Services.TryAddSingleton<IVersionCreator, HttpVersionCreator>();
Services.TryAddSingleton<IWebSocketsFactory, WebSocketsFactory>();

// Add security
Services.TryAddSingleton<ISecurityOptionsCreator, SecurityOptionsCreator>();
Expand Down
25 changes: 15 additions & 10 deletions src/Ocelot/LoadBalancer/LoadBalancers/RoundRobin.cs
Original file line number Diff line number Diff line change
Expand Up @@ -6,30 +6,35 @@ namespace Ocelot.LoadBalancer.LoadBalancers
{
public class RoundRobin : ILoadBalancer
{
private readonly Func<Task<List<Service>>> _services;
private readonly Func<Task<List<Service>>> _servicesDelegate;
private readonly object _lock = new();

private int _last;

public RoundRobin(Func<Task<List<Service>>> services)
{
_services = services;
_servicesDelegate = services;
}

public async Task<Response<ServiceHostAndPort>> Lease(HttpContext httpContext)
{
var services = await _services();
lock (_lock)
var services = await _servicesDelegate?.Invoke() ?? new List<Service>();

if (services?.Count != 0)
{
if (_last >= services.Count)
lock (_lock)
{
_last = 0;
}
if (_last >= services.Count)
{
_last = 0;
}

var next = services[_last];
_last++;
return new OkResponse<ServiceHostAndPort>(next.HostAndPort);
var next = services[_last++];
return new OkResponse<ServiceHostAndPort>(next.HostAndPort);
}
}

return new ErrorResponse<ServiceHostAndPort>(new ServicesAreEmptyError($"There were no services in {nameof(RoundRobin)} during {nameof(Lease)} operation."));
}

public void Release(ServiceHostAndPort hostAndPort)
Expand Down
2 changes: 1 addition & 1 deletion src/Ocelot/Middleware/OcelotPipelineExtensions.cs
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
using Ocelot.RequestId.Middleware;
using Ocelot.Responder.Middleware;
using Ocelot.Security.Middleware;
using Ocelot.WebSockets.Middleware;
using Ocelot.WebSockets;

namespace Ocelot.Middleware
{
Expand Down
35 changes: 35 additions & 0 deletions src/Ocelot/WebSockets/ClientWebSocketOptionsProxy.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,35 @@
using System.Net.Security;
using System.Net.WebSockets;
using System.Security.Cryptography.X509Certificates;

namespace Ocelot.WebSockets;

public class ClientWebSocketOptionsProxy : IClientWebSocketOptions
{
private readonly ClientWebSocketOptions _real;

public ClientWebSocketOptionsProxy(ClientWebSocketOptions options)
{
_real = options;
}

public Version HttpVersion { get => _real.HttpVersion; set => _real.HttpVersion = value; }
public HttpVersionPolicy HttpVersionPolicy { get => _real.HttpVersionPolicy; set => _real.HttpVersionPolicy = value; }
public bool UseDefaultCredentials { get => _real.UseDefaultCredentials; set => _real.UseDefaultCredentials = value; }
public ICredentials Credentials { get => _real.Credentials; set => _real.Credentials = value; }
public IWebProxy Proxy { get => _real.Proxy; set => _real.Proxy = value; }
public X509CertificateCollection ClientCertificates { get => _real.ClientCertificates; set => _real.ClientCertificates = value; }
public RemoteCertificateValidationCallback RemoteCertificateValidationCallback { get => _real.RemoteCertificateValidationCallback; set => _real.RemoteCertificateValidationCallback = value; }
public CookieContainer Cookies { get => _real.Cookies; set => _real.Cookies = value; }
public TimeSpan KeepAliveInterval { get => _real.KeepAliveInterval; set => _real.KeepAliveInterval = value; }
public WebSocketDeflateOptions DangerousDeflateOptions { get => _real.DangerousDeflateOptions; set => _real.DangerousDeflateOptions = value; }
public bool CollectHttpResponseDetails { get => _real.CollectHttpResponseDetails; set => _real.CollectHttpResponseDetails = value; }

public void AddSubProtocol(string subProtocol) => _real.AddSubProtocol(subProtocol);

public void SetBuffer(int receiveBufferSize, int sendBufferSize) => _real.SetBuffer(receiveBufferSize, sendBufferSize);

public void SetBuffer(int receiveBufferSize, int sendBufferSize, ArraySegment<byte> buffer) => _real.SetBuffer(receiveBufferSize, sendBufferSize, buffer);

public void SetRequestHeader(string headerName, string headerValue) => _real.SetRequestHeader(headerName, headerValue);
}
49 changes: 49 additions & 0 deletions src/Ocelot/WebSockets/ClientWebSocketProxy.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,49 @@
using System.Net.WebSockets;

namespace Ocelot.WebSockets;

public class ClientWebSocketProxy : WebSocket, IClientWebSocket
{
// RealSubject (Service) class of Proxy design pattern
private readonly ClientWebSocket _realService;
private readonly IClientWebSocketOptions _options;

public ClientWebSocketProxy()
{
_realService = new ClientWebSocket();
_options = new ClientWebSocketOptionsProxy(_realService.Options);
}

// ClientWebSocket implementations
public IClientWebSocketOptions Options => _options;

public Task ConnectAsync(Uri uri, CancellationToken cancellationToken)
=> _realService.ConnectAsync(uri, cancellationToken);

// WebSocket implementations
public override WebSocketCloseStatus? CloseStatus => _realService.CloseStatus;

public override string CloseStatusDescription => _realService.CloseStatusDescription;

public override WebSocketState State => _realService.State;

public override string SubProtocol => _realService.SubProtocol;

public override void Abort() => _realService.Abort();

public override Task CloseAsync(WebSocketCloseStatus closeStatus, string statusDescription, CancellationToken cancellationToken)
=> _realService.CloseAsync(closeStatus, statusDescription, cancellationToken);

public override Task CloseOutputAsync(WebSocketCloseStatus closeStatus, string statusDescription, CancellationToken cancellationToken)
=> _realService.CloseOutputAsync(closeStatus, statusDescription, cancellationToken);

public override void Dispose() => _realService.Dispose();

public override Task<WebSocketReceiveResult> ReceiveAsync(ArraySegment<byte> buffer, CancellationToken cancellationToken)
=> _realService.ReceiveAsync(buffer, cancellationToken);

public override Task SendAsync(ArraySegment<byte> buffer, WebSocketMessageType messageType, bool endOfMessage, CancellationToken cancellationToken)
=> _realService.SendAsync(buffer, messageType, endOfMessage, cancellationToken);

public WebSocket ToWebSocket() => _realService;
}
24 changes: 24 additions & 0 deletions src/Ocelot/WebSockets/IClientWebSocket.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
using System.Net.WebSockets;

namespace Ocelot.WebSockets;

public interface IClientWebSocket
{
WebSocket ToWebSocket();

// ClientWebSocket definitions
IClientWebSocketOptions Options { get; }
Task ConnectAsync(Uri uri, CancellationToken cancellationToken);

// WebSocket definitions
WebSocketCloseStatus? CloseStatus { get; }
string CloseStatusDescription { get; }
WebSocketState State { get; }
string SubProtocol { get; }
void Abort();
Task CloseAsync(WebSocketCloseStatus closeStatus, string statusDescription, CancellationToken cancellationToken);
Task CloseOutputAsync(WebSocketCloseStatus closeStatus, string statusDescription, CancellationToken cancellationToken);
void Dispose();
Task<WebSocketReceiveResult> ReceiveAsync(ArraySegment<byte> buffer, CancellationToken cancellationToken);
Task SendAsync(ArraySegment<byte> buffer, WebSocketMessageType messageType, bool endOfMessage, CancellationToken cancellationToken);
}
24 changes: 24 additions & 0 deletions src/Ocelot/WebSockets/IClientWebSocketOptions.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
using System.Net.Security;
using System.Net.WebSockets;
using System.Security.Cryptography.X509Certificates;

namespace Ocelot.WebSockets;

public interface IClientWebSocketOptions
{
Version HttpVersion { get; set; }
HttpVersionPolicy HttpVersionPolicy { get; set; }
void SetRequestHeader(string headerName, string headerValue);
bool UseDefaultCredentials { get; set; }
ICredentials Credentials { get; set; }
IWebProxy Proxy { get; set; }
X509CertificateCollection ClientCertificates { get; set; }
RemoteCertificateValidationCallback RemoteCertificateValidationCallback { get; set; }
CookieContainer Cookies { get; set; }
void AddSubProtocol(string subProtocol);
TimeSpan KeepAliveInterval { get; set; }
WebSocketDeflateOptions DangerousDeflateOptions { get; set; }
void SetBuffer(int receiveBufferSize, int sendBufferSize);
void SetBuffer(int receiveBufferSize, int sendBufferSize, ArraySegment<byte> buffer);
bool CollectHttpResponseDetails { get; set; }
}
6 changes: 6 additions & 0 deletions src/Ocelot/WebSockets/IWebSocketsFactory.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
namespace Ocelot.WebSockets;

public interface IWebSocketsFactory
{
IClientWebSocket CreateClient();
}
6 changes: 6 additions & 0 deletions src/Ocelot/WebSockets/WebSocketsFactory.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
namespace Ocelot.WebSockets;

public class WebSocketsFactory : IWebSocketsFactory
{
public IClientWebSocket CreateClient() => new ClientWebSocketProxy();
}
Original file line number Diff line number Diff line change
@@ -1,26 +1,33 @@
// Copyright (c) .NET Foundation. All rights reserved.
// Copyright (c) .NET Foundation. All rights reserved.
// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information.
// Modified https://github.com/aspnet/Proxy websockets class to use in Ocelot.

using Microsoft.AspNetCore.Http;
using Ocelot.Configuration;
using Ocelot.Logging;
using Ocelot.Middleware;
using System.Net.WebSockets;

namespace Ocelot.WebSockets.Middleware
namespace Ocelot.WebSockets
{
public class WebSocketsProxyMiddleware : OcelotMiddleware
{
private static readonly string[] NotForwardedWebSocketHeaders = new[] { "Connection", "Host", "Upgrade", "Sec-WebSocket-Accept", "Sec-WebSocket-Protocol", "Sec-WebSocket-Key", "Sec-WebSocket-Version", "Sec-WebSocket-Extensions" };
private static readonly string[] NotForwardedWebSocketHeaders = new[]
{
"Connection", "Host", "Upgrade",
"Sec-WebSocket-Accept", "Sec-WebSocket-Protocol", "Sec-WebSocket-Key", "Sec-WebSocket-Version", "Sec-WebSocket-Extensions",
};
private const int DefaultWebSocketBufferSize = 4096;
private const int StreamCopyBufferSize = 81920;
private readonly RequestDelegate _next;
private readonly IWebSocketsFactory _factory;

public WebSocketsProxyMiddleware(RequestDelegate next,
IOcelotLoggerFactory loggerFactory)
: base(loggerFactory.CreateLogger<WebSocketsProxyMiddleware>())
public WebSocketsProxyMiddleware(IOcelotLoggerFactory loggerFactory,
RequestDelegate next,
IWebSocketsFactory factory)
: base(loggerFactory.CreateLogger<WebSocketsProxyMiddleware>())
{
_next = next;
_factory = factory;
}

private static async Task PumpWebSocket(WebSocket source, WebSocket destination, int bufferSize, CancellationToken cancellationToken)
Expand Down Expand Up @@ -67,10 +74,11 @@ private static async Task PumpWebSocket(WebSocket source, WebSocket destination,
public async Task Invoke(HttpContext httpContext)
{
var uri = httpContext.Items.DownstreamRequest().ToUri();
await Proxy(httpContext, uri);
var downstreamRoute = httpContext.Items.DownstreamRoute();
await Proxy(httpContext, uri, downstreamRoute);
}

private static async Task Proxy(HttpContext context, string serverEndpoint)
private async Task Proxy(HttpContext context, string serverEndpoint, DownstreamRoute downstreamRoute)
{
if (context == null)
{
Expand All @@ -87,7 +95,14 @@ private static async Task Proxy(HttpContext context, string serverEndpoint)
throw new InvalidOperationException();
}

var client = new ClientWebSocket();
var client = _factory.CreateClient(); // new ClientWebSocket();

if (downstreamRoute.DangerousAcceptAnyServerCertificateValidator)
{
client.Options.RemoteCertificateValidationCallback = (request, certificate, chain, errors) => true;
Logger.LogWarning($"You have ignored all SSL warnings by using {nameof(DownstreamRoute.DangerousAcceptAnyServerCertificateValidator)} for this downstream route! {nameof(DownstreamRoute.UpstreamPathTemplate)}: '{downstreamRoute.UpstreamPathTemplate}', {nameof(DownstreamRoute.DownstreamPathTemplate)}: '{downstreamRoute.DownstreamPathTemplate}'.");
}

foreach (var protocol in context.WebSockets.WebSocketRequestedProtocols)
{
client.Options.AddSubProtocol(protocol);
Expand All @@ -112,10 +127,12 @@ private static async Task Proxy(HttpContext context, string serverEndpoint)

var destinationUri = new Uri(serverEndpoint);
await client.ConnectAsync(destinationUri, context.RequestAborted);

using (var server = await context.WebSockets.AcceptWebSocketAsync(client.SubProtocol))
{
var bufferSize = DefaultWebSocketBufferSize;
await Task.WhenAll(PumpWebSocket(client, server, bufferSize, context.RequestAborted), PumpWebSocket(server, client, bufferSize, context.RequestAborted));
await Task.WhenAll(
PumpWebSocket(client.ToWebSocket(), server, DefaultWebSocketBufferSize, context.RequestAborted),
PumpWebSocket(server, client.ToWebSocket(), DefaultWebSocketBufferSize, context.RequestAborted));
}
}
}
Expand Down
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
using Microsoft.AspNetCore.Builder;

namespace Ocelot.WebSockets.Middleware
namespace Ocelot.WebSockets
{
public static class WebSocketsProxyMiddlewareExtensions
{
Expand Down
5 changes: 3 additions & 2 deletions test/Ocelot.AcceptanceTests/ConsulWebSocketTests.cs
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
using Microsoft.AspNetCore.Http;
using Newtonsoft.Json;
using Ocelot.Configuration.File;
using Ocelot.WebSockets;
using System.Net.WebSockets;
using System.Text;

Expand Down Expand Up @@ -142,7 +143,7 @@ private async Task WhenIStartTheClients()

private async Task StartClient(string url)
{
var client = new ClientWebSocket();
IClientWebSocket client = new ClientWebSocketProxy();

await client.ConnectAsync(new Uri(url), CancellationToken.None);

Expand Down Expand Up @@ -194,7 +195,7 @@ private async Task StartSecondClient(string url)
{
await Task.Delay(500);

var client = new ClientWebSocket();
IClientWebSocket client = new ClientWebSocketProxy();

await client.ConnectAsync(new Uri(url), CancellationToken.None);

Expand Down
5 changes: 3 additions & 2 deletions test/Ocelot.AcceptanceTests/WebSocketTests.cs
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
using Ocelot.Configuration.File;
using Ocelot.WebSockets;
using System.Net.WebSockets;
using System.Text;

Expand Down Expand Up @@ -124,7 +125,7 @@ private async Task WhenIStartTheClients()

private async Task StartClient(string url)
{
var client = new ClientWebSocket();
IClientWebSocket client = new ClientWebSocketProxy();

await client.ConnectAsync(new Uri(url), CancellationToken.None);

Expand Down Expand Up @@ -176,7 +177,7 @@ private async Task StartSecondClient(string url)
{
await Task.Delay(500);

var client = new ClientWebSocket();
IClientWebSocket client = new ClientWebSocketProxy();

await client.ConnectAsync(new Uri(url), CancellationToken.None);

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
using Ocelot.LoadBalancer.Middleware;
using Ocelot.Middleware;
using Ocelot.Request.Middleware;
using Ocelot.WebSockets.Middleware;
using Ocelot.WebSockets;

namespace Ocelot.UnitTests.Middleware
{
Expand Down
1 change: 1 addition & 0 deletions test/Ocelot.UnitTests/Ocelot.UnitTests.csproj
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,7 @@

<ItemGroup>
<PackageReference Include="Microsoft.NET.Test.Sdk" Version="17.6.3" />
<PackageReference Include="Nito.AsyncEx" Version="5.1.2" />
<PackageReference Include="xunit" Version="2.5.0" />
<PackageReference Include="xunit.runner.visualstudio" Version="2.5.0">
<IncludeAssets>runtime; build; native; contentfiles; analyzers; buildtransitive</IncludeAssets>
Expand Down
Loading

0 comments on commit fa179bf

Please sign in to comment.