Skip to content

Commit

Permalink
Limit live feed websocket connections to 2 per IP, excluding whitelis…
Browse files Browse the repository at this point in the history
…ted IPs

Closes #75
  • Loading branch information
AlexMacocian committed Aug 11, 2024
1 parent 72aaa02 commit 14d7444
Show file tree
Hide file tree
Showing 4 changed files with 86 additions and 24 deletions.
16 changes: 12 additions & 4 deletions GuildWarsPartySearch/Endpoints/LiveFeed.cs
Original file line number Diff line number Diff line change
Expand Up @@ -35,8 +35,15 @@ public override async Task ExecuteAsync(TextContent? content, CancellationToken

public override async Task SocketAccepted(CancellationToken cancellationToken)
{
var scopedLogger = this.logger.CreateScopedLogger(nameof(this.SocketAccepted), this.Context?.Connection.RemoteIpAddress?.ToString() ?? string.Empty);
this.liveFeedService.AddClient(this.WebSocket!);
var ipAddress = this.Context?.Connection.RemoteIpAddress?.ToString();
var scopedLogger = this.logger.CreateScopedLogger(nameof(this.SocketAccepted), ipAddress ?? string.Empty);
if (!await this.liveFeedService.AddClient(this.WebSocket!, ipAddress, cancellationToken))
{
scopedLogger.LogError("Client rejected");
this.WebSocket?.CloseAsync(System.Net.WebSockets.WebSocketCloseStatus.NormalClosure, "Connection rejected", cancellationToken);
return;
}

scopedLogger.LogDebug("Client accepted to livefeed");

scopedLogger.LogDebug("Sending all party searches");
Expand All @@ -46,8 +53,9 @@ public override async Task SocketAccepted(CancellationToken cancellationToken)

public override Task SocketClosed()
{
var scopedLogger = this.logger.CreateScopedLogger(nameof(this.SocketAccepted), this.Context?.Connection.RemoteIpAddress?.ToString() ?? string.Empty);
this.liveFeedService.RemoveClient(this.WebSocket!);
var ipAddress = this.Context?.Connection.RemoteIpAddress?.ToString();
var scopedLogger = this.logger.CreateScopedLogger(nameof(this.SocketAccepted), ipAddress ?? string.Empty);
this.liveFeedService.RemoveClient(this.WebSocket!, ipAddress);
scopedLogger.LogDebug("Client removed from livefeed");
return Task.CompletedTask;
}
Expand Down
2 changes: 1 addition & 1 deletion GuildWarsPartySearch/Options/IpWhitelistOptions.cs
Original file line number Diff line number Diff line change
Expand Up @@ -5,5 +5,5 @@ namespace GuildWarsPartySearch.Server.Options;
public class IpWhitelistOptions
{
[JsonPropertyName(nameof(Addresses))]
public List<string> Addresses { get; set; } = [ "127.0.0.1" ];
public List<string> Addresses { get; set; } = [];
}
4 changes: 2 additions & 2 deletions GuildWarsPartySearch/Services/Feed/ILiveFeedService.cs
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ namespace GuildWarsPartySearch.Server.Services.Feed;

public interface ILiveFeedService
{
void AddClient(WebSocket webSocket);
void RemoveClient(WebSocket webSocket);
Task<bool> AddClient(WebSocket webSocket, string? ipAddress, CancellationToken cancellationToken);
void RemoveClient(WebSocket webSocket, string? ipAddress);
Task PushUpdate(Models.PartySearch partySearchUpdate, CancellationToken cancellationToken);
}
88 changes: 71 additions & 17 deletions GuildWarsPartySearch/Services/Feed/LiveFeedService.cs
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
using GuildWarsPartySearch.Server.Models.Endpoints;
using GuildWarsPartySearch.Server.Services.Database;
using System.Core.Extensions;
using System.Extensions;
using System.Net.WebSockets;
using System.Text;
using System.Text.Json;
Expand All @@ -8,32 +10,37 @@ namespace GuildWarsPartySearch.Server.Services.Feed;

public sealed class LiveFeedService : ILiveFeedService
{
private const int MaxConnectionsPerIP = 2;

private readonly SemaphoreSlim semaphore = new(1);
private readonly List<WebSocket> clients = [];
private readonly Dictionary<string, List<WebSocket>> clients = [];
private readonly IIpWhitelistDatabase ipWhitelistDatabase;
private readonly JsonSerializerOptions jsonSerializerOptions;
private readonly ILogger<LiveFeedService> logger;

public LiveFeedService(
IHostApplicationLifetime lifetime,
IIpWhitelistDatabase ipWhitelistDatabase,
JsonSerializerOptions jsonSerializerOptions,
ILogger<LiveFeedService> logger)
{
lifetime.ApplicationStopping.Register(this.ShutDownConnections);
this.ipWhitelistDatabase = ipWhitelistDatabase.ThrowIfNull();
this.jsonSerializerOptions = jsonSerializerOptions.ThrowIfNull();
this.logger = logger.ThrowIfNull();
}

public void AddClient(WebSocket client)
public Task<bool> AddClient(WebSocket client, string? ipAddress, CancellationToken cancellationToken)
{
AddClientInternal(client);
return AddClientInternal(client, ipAddress, cancellationToken);
}

public async Task PushUpdate(Models.PartySearch partySearchUpdate, CancellationToken cancellationToken)
{
// Since LiveFeed endpoint expects a PartySearchList, so we send a PartySearchList with only the update to keep it consistent
var payloadString = JsonSerializer.Serialize(new PartySearchList { Searches = [partySearchUpdate] }, this.jsonSerializerOptions);
var payload = Encoding.UTF8.GetBytes(payloadString);
await ExecuteOnClientsInternal(async client =>
await ExecuteOnClientsInternal(async (address, client) =>
{
try
{
Expand All @@ -42,27 +49,71 @@ await ExecuteOnClientsInternal(async client =>
catch(Exception ex)
{
this.logger.LogError(ex, $"Encountered exception while broadcasting update");
RemoveClientInternal(client);
RemoveClientInternal(client, address);
}
});
}

public void RemoveClient(WebSocket client)
public void RemoveClient(WebSocket client, string? ipAddress)
{
RemoveClientInternal(client);
RemoveClientInternal(client, ipAddress);
}

private void AddClientInternal(WebSocket client)
private async Task<bool> AddClientInternal(WebSocket client, string? ipAddress, CancellationToken cancellationToken)
{
this.semaphore.Wait();
this.clients.Add(client);
this.semaphore.Release();
var scopedLogger = this.logger.CreateScopedLogger(nameof(this.AddClientInternal), ipAddress ?? string.Empty);

await this.semaphore.WaitAsync(cancellationToken);
try
{
if (ipAddress is null ||
ipAddress.IsNullOrWhiteSpace())
{
return false;
}

var whitelistedIps = await this.ipWhitelistDatabase.GetWhitelistedAddresses(cancellationToken);
if (whitelistedIps.None(addr => addr == ipAddress) &&
this.clients.TryGetValue(ipAddress, out var sockets) &&
sockets.Count >= 2)
{
scopedLogger.LogError("Too many live connections. Rejecting");
return false;
}

if (!this.clients.TryGetValue(ipAddress, out var existingSockets))
{
existingSockets = [];
this.clients[ipAddress] = existingSockets;
}

existingSockets.Add(client);
return true;
}
finally
{
this.semaphore.Release();
}
}

private void RemoveClientInternal(WebSocket client)
private void RemoveClientInternal(WebSocket client, string? ipAddress)
{
this.semaphore.Wait();
this.clients.Remove(client);
if (ipAddress is null ||
ipAddress.IsNullOrWhiteSpace())
{
return;
}

if (this.clients.TryGetValue(ipAddress, out var sockets))
{
sockets.Remove(client);
if (sockets.Count == 0)
{
this.clients.Remove(ipAddress);
}
}

if (client?.State is not WebSocketState.Closed or WebSocketState.Aborted)
{
client?.Abort();
Expand All @@ -71,19 +122,22 @@ private void RemoveClientInternal(WebSocket client)
this.semaphore.Release();
}

private async Task ExecuteOnClientsInternal(Func<WebSocket, Task> action)
private async Task ExecuteOnClientsInternal(Func<string, WebSocket, Task> action)
{
await this.semaphore.WaitAsync();
await Task.WhenAll(this.clients.Select(client => action(client)));
await Task.WhenAll(this.clients.SelectMany(pair => pair.Value.Select(client => action(pair.Key, client))));
this.semaphore.Release();
}

private void ShutDownConnections()
{
this.semaphore.Wait();
foreach(var client in this.clients)
foreach(var sockets in this.clients.Values)
{
client.Abort();
foreach(var client in sockets)
{
client.Abort();
}
}

this.semaphore.Release();
Expand Down

0 comments on commit 14d7444

Please sign in to comment.