Skip to content

Commit

Permalink
Implement API Key whitelisting
Browse files Browse the repository at this point in the history
  • Loading branch information
AlexMacocian committed Aug 12, 2024
1 parent 45fc018 commit 88b01a0
Show file tree
Hide file tree
Showing 11 changed files with 84 additions and 19 deletions.
1 change: 1 addition & 0 deletions .github/workflows/docker-deploy.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@ jobs:
$trimmedIps = "${{ secrets.COMMA_SEPARATED_IPS }}" -split ',' | ForEach-Object { $_.Trim() }
$newAddressesArray = $jsonContent.IpWhitelistOptions.Addresses + $trimmedIps
$jsonContent.IpWhitelistOptions.Addresses = $newAddressesArray
$jsonContent.ApiWhitelistOptions.Key = "${{ secrets.API_KEY }}"
$updatedJsonContent = $jsonContent | ConvertTo-Json -Depth 32
Set-Content -Path Config.json -Value $updatedJsonContent
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -454,7 +454,7 @@ class _RealWebSocket : public easywsclient::WebSocket
};


easywsclient::WebSocket::pointer from_url(const std::string& url, bool useMask, const std::string& origin, const std::string& userAgent) {
easywsclient::WebSocket::pointer from_url(const std::string& url, bool useMask, const std::string& origin, const std::string& user_agent, const std::string& api_key) {
char host[512];
int port;
char path[512];
Expand Down Expand Up @@ -501,7 +501,8 @@ easywsclient::WebSocket::pointer from_url(const std::string& url, bool useMask,
else {
snprintf(line, 1024, "Host: %s:%d\r\n", host, port); ::send(sockfd, line, strlen(line), 0);
}
snprintf(line, 1024, "User-Agent: %s\r\n", userAgent.c_str()); ::send(sockfd, line, strlen(line), 0);
snprintf(line, 1024, "User-Agent: %s\r\n", user_agent.c_str()); ::send(sockfd, line, strlen(line), 0);
snprintf(line, 1024, "X-Api-Key: %s\r\n", api_key.c_str()); ::send(sockfd, line, strlen(line), 0);
snprintf(line, 1024, "Upgrade: websocket\r\n"); ::send(sockfd, line, strlen(line), 0);
snprintf(line, 1024, "Connection: Upgrade\r\n"); ::send(sockfd, line, strlen(line), 0);
if (!origin.empty()) {
Expand Down Expand Up @@ -544,12 +545,12 @@ WebSocket::pointer WebSocket::create_dummy() {
}


WebSocket::pointer WebSocket::from_url(const std::string& url, const std::string& userAgent, const std::string& origin) {
return ::from_url(url, true, origin, userAgent);
WebSocket::pointer WebSocket::from_url(const std::string& url, const std::string& user_agent, const std::string& api_key, const std::string& origin) {
return ::from_url(url, true, origin, user_agent, api_key);
}

WebSocket::pointer WebSocket::from_url_no_mask(const std::string& url, const std::string& userAgent, const std::string& origin) {
return ::from_url(url, false, origin, userAgent);
WebSocket::pointer WebSocket::from_url_no_mask(const std::string& url, const std::string& user_agent, const std::string& api_key, const std::string& origin) {
return ::from_url(url, false, origin, user_agent, api_key);
}


Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,8 +23,8 @@ class WebSocket {

// Factories:
static pointer create_dummy();
static pointer from_url(const std::string& url, const std::string& userAgent, const std::string& origin = std::string());
static pointer from_url_no_mask(const std::string& url, const std::string& userAgent, const std::string& origin = std::string());
static pointer from_url(const std::string& url, const std::string& user_agent, const std::string& api_key, const std::string& origin = std::string());
static pointer from_url_no_mask(const std::string& url, const std::string& user_agent, const std::string& api_key, const std::string& origin = std::string());

// Interfaces:
virtual ~WebSocket() { }
Expand Down
7 changes: 6 additions & 1 deletion GuildWarsPartySearch.Bot/main.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,7 @@ extern "C" {

struct BotConfiguration {
std::string web_socket_url = "";
std::string api_key = "development";
uint32_t map_id = 857; // Embark beach
District district = District::DISTRICT_AMERICAN;
uint32_t district_number = 0;
Expand Down Expand Up @@ -333,6 +334,10 @@ static void load_configuration() {
bot_configuration.connection_retries = stoi(get_next_argument(i));
i++;
}
else if (arg == "-api-key") {
bot_configuration.api_key = get_next_argument(i);
i++;
}
}
}
catch (std::exception) {
Expand Down Expand Up @@ -420,7 +425,7 @@ static easywsclient::WebSocket::pointer connect_websocket() {
disconnect_websocket();

LogInfo("Attempting to connect. Try %d/%d", i + 1, connect_retries);
ws = easywsclient::WebSocket::from_url(bot_configuration.web_socket_url, user_agent);
ws = easywsclient::WebSocket::from_url(bot_configuration.web_socket_url, user_agent, bot_configuration.api_key);
if (!ws) {
// Sleep before retry
time_sleep_sec(5);
Expand Down
3 changes: 3 additions & 0 deletions GuildWarsPartySearch/Config.Debug.json
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,9 @@
"localhost"
]
},
"ApiWhitelistOptions": {
"Key": "development"
},
"ContentOptions": {
"StagingFolder": "Content"
}
Expand Down
3 changes: 3 additions & 0 deletions GuildWarsPartySearch/Config.Production.json
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,9 @@
"localhost"
]
},
"ApiWhitelistOptions": {
"Key": "development"
},
"ContentOptions": {
"StagingFolder": "Content"
}
Expand Down
2 changes: 1 addition & 1 deletion GuildWarsPartySearch/Endpoints/PostPartySearch.cs
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@

namespace GuildWarsPartySearch.Server.Endpoints;

[ServiceFilter<IpWhitelistFilter>]
[ServiceFilter<ApiWhitelistFilter>]
[ServiceFilter<UserAgentRequired>]
public sealed class PostPartySearch : WebSocketRouteBase<PostPartySearchRequest, PostPartySearchResponse>
{
Expand Down
16 changes: 8 additions & 8 deletions GuildWarsPartySearch/Endpoints/StatusController.cs
Original file line number Diff line number Diff line change
Expand Up @@ -22,10 +22,10 @@ public StatusController(
}

[HttpGet("bot-activity/all")]
[ServiceFilter<IpWhitelistFilter>]
[ServiceFilter<ApiWhitelistFilter>]
[ProducesResponseType(200)]
[ProducesResponseType(403)]
[SwaggerOperation(Description = $"Protected by *IP whitelisting*.\r\n\r\n Disabled in *Production* \r\n\r\n")]
[SwaggerOperation(Description = $"Protected by *API Key whitelisting*.\r\n\r\n Disabled in *Production* \r\n\r\n")]
public async Task<IActionResult> GetAllBotActivity()
{
if (this.environment.IsProduction())
Expand All @@ -37,10 +37,10 @@ public async Task<IActionResult> GetAllBotActivity()
}

[HttpGet("bot-activity/bots/{botName}")]
[ServiceFilter<IpWhitelistFilter>]
[ServiceFilter<ApiWhitelistFilter>]
[ProducesResponseType(200)]
[ProducesResponseType(403)]
[SwaggerOperation(Description = $"Protected by *IP whitelisting*.\r\n\r\n Disabled in *Production* \r\n\r\n")]
[SwaggerOperation(Description = $"Protected by *API Key whitelisting*.\r\n\r\n Disabled in *Production* \r\n\r\n")]
public async Task<IActionResult> GetBotActivityByName(string botName)
{
if (this.environment.IsProduction())
Expand All @@ -52,10 +52,10 @@ public async Task<IActionResult> GetBotActivityByName(string botName)
}

[HttpGet("bot-activity/maps/{map}")]
[ServiceFilter<IpWhitelistFilter>]
[ServiceFilter<ApiWhitelistFilter>]
[ProducesResponseType(200)]
[ProducesResponseType(403)]
[SwaggerOperation(Description = $"Protected by *IP whitelisting*.\r\n\r\n Disabled in *Production* \r\n\r\n")]
[SwaggerOperation(Description = $"Protected by *API Key whitelisting*.\r\n\r\n Disabled in *Production* \r\n\r\n")]
public async Task<IActionResult> GetBotActivityByMap(string map)
{
if (this.environment.IsProduction())
Expand All @@ -74,10 +74,10 @@ public async Task<IActionResult> GetBotActivityByMap(string map)
}

[HttpGet("bots")]
[ServiceFilter<IpWhitelistFilter>]
[ServiceFilter<ApiWhitelistFilter>]
[ProducesResponseType(200)]
[ProducesResponseType(403)]
[SwaggerOperation(Description = $"Protected by *IP whitelisting*.\r\n\r\n")]
[SwaggerOperation(Description = $"Protected by *API Key whitelisting*.\r\n\r\n Disabled in *Production* \r\n\r\n")]
public async Task<IActionResult> GetBotStatus()
{
if (this.environment.IsProduction())
Expand Down
44 changes: 44 additions & 0 deletions GuildWarsPartySearch/Filters/ApiWhitelistFilter.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,44 @@
using GuildWarsPartySearch.Server.Options;
using Microsoft.AspNetCore.Mvc.Filters;
using Microsoft.Extensions.Options;
using System.Core.Extensions;
using System.Extensions;

namespace GuildWarsPartySearch.Server.Filters
{
public class ApiWhitelistFilter : IAsyncActionFilter
{
private const string XApiKeyHeaderKey = "X-Api-Key";

private readonly ApiWhitelistOptions apiWhitelistOptions;
private readonly ILogger<ApiWhitelistFilter> logger;

public ApiWhitelistFilter(
IOptions<ApiWhitelistOptions> options,
ILogger<ApiWhitelistFilter> logger)
{
this.apiWhitelistOptions = options.ThrowIfNull().Value.ThrowIfNull();
this.logger = logger.ThrowIfNull();
}

public async Task OnActionExecutionAsync(ActionExecutingContext context, ActionExecutionDelegate next)
{
var address = context.HttpContext.Connection.RemoteIpAddress?.ToString();
var scopedLogger = this.logger.CreateScopedLogger(nameof(this.OnActionExecutionAsync), address ?? string.Empty);
scopedLogger.LogDebug($"Received request");
if (context.HttpContext.Request.Headers.TryGetValue(XApiKeyHeaderKey, out var xApiKeyvalues))
{
scopedLogger.LogDebug($"X-Api-Key {string.Join(',', xApiKeyvalues.Select(s => s))}");
}

if (xApiKeyvalues.None(k => k == this.apiWhitelistOptions.Key))
{
context.Result = new ForbiddenResponseActionResult("Forbidden");
return;
}

await next();
}
}

}
4 changes: 3 additions & 1 deletion GuildWarsPartySearch/Launch/ServerConfiguration.cs
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,8 @@ public static WebApplicationBuilder SetupOptions(this WebApplicationBuilder buil
.ConfigureExtended<ServerOptions>()
.ConfigureExtended<IpWhitelistOptions>()
.ConfigureExtended<BotHistoryDatabaseOptions>()
.ConfigureExtended<SQLiteDatabaseOptions>();
.ConfigureExtended<SQLiteDatabaseOptions>()
.ConfigureExtended<ApiWhitelistOptions>();
}

public static IServiceCollection SetupServices(this IServiceCollection services)
Expand All @@ -83,6 +84,7 @@ public static IServiceCollection SetupServices(this IServiceCollection services)
services.AddSingleton<IBotHistoryDatabase, BotHistorySqliteDatabase>();
services.AddScoped<UserAgentRequired>();
services.AddScoped<IpWhitelistFilter>();
services.AddScoped<ApiWhitelistFilter>();
services.AddScoped<IPartySearchService, PartySearchService>();
services.AddScoped<ICharNameValidator, CharNameValidator>();
return services;
Expand Down
6 changes: 6 additions & 0 deletions GuildWarsPartySearch/Options/ApiWhitelistOptions.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
namespace GuildWarsPartySearch.Server.Options;

public class ApiWhitelistOptions
{
public string? Key { get; set; }
}

0 comments on commit 88b01a0

Please sign in to comment.