From 35b503522d3310465c3b6d6085ec12267287fa62 Mon Sep 17 00:00:00 2001 From: Casey Waldren Date: Tue, 3 Sep 2024 13:28:57 -0700 Subject: [PATCH] contract tests + validation --- .../include/data_model/data_model.hpp | 172 ++++++++++++------ .../src/entity_manager.cpp | 20 +- .../server-contract-tests/src/main.cpp | 8 +- libs/server-sdk/src/CMakeLists.txt | 2 + .../payload_filter_validation.cpp | 24 +++ .../payload_filter_validation.hpp | 7 + .../sources/polling/polling_data_source.cpp | 17 +- .../streaming/streaming_data_source.cpp | 14 +- 8 files changed, 191 insertions(+), 73 deletions(-) create mode 100644 libs/server-sdk/src/data_systems/background_sync/detail/payload_filter_validation/payload_filter_validation.cpp create mode 100644 libs/server-sdk/src/data_systems/background_sync/detail/payload_filter_validation/payload_filter_validation.hpp diff --git a/contract-tests/data-model/include/data_model/data_model.hpp b/contract-tests/data-model/include/data_model/data_model.hpp index 51ee897ca..a752ccd7d 100644 --- a/contract-tests/data-model/include/data_model/data_model.hpp +++ b/contract-tests/data-model/include/data_model/data_model.hpp @@ -5,31 +5,41 @@ #include #include "nlohmann/json.hpp" -namespace nlohmann { - -template -struct adl_serializer> { - static void to_json(json& j, std::optional const& opt) { - if (opt == std::nullopt) { - j = nullptr; - } else { - j = *opt; // this will call adl_serializer::to_json which will - // find the free function to_json in T's namespace! +namespace nlohmann +{ + template + struct adl_serializer> + { + static void to_json(json& j, std::optional const& opt) + { + if (opt == std::nullopt) + { + j = nullptr; + } + else + { + j = *opt; // this will call adl_serializer::to_json which will + // find the free function to_json in T's namespace! + } } - } - - static void from_json(json const& j, std::optional& opt) { - if (j.is_null()) { - opt = std::nullopt; - } else { - opt = j.get(); // same as above, but with - // adl_serializer::from_json + + static void from_json(json const& j, std::optional& opt) + { + if (j.is_null()) + { + opt = std::nullopt; + } + else + { + opt = j.get(); // same as above, but with + // adl_serializer::from_json + } } - } -}; -} // namespace nlohmann + }; +} // namespace nlohmann -struct ConfigTLSParams { +struct ConfigTLSParams +{ std::optional skipVerifyPeer; std::optional customCAFile; }; @@ -38,23 +48,30 @@ NLOHMANN_DEFINE_TYPE_NON_INTRUSIVE_WITH_DEFAULT(ConfigTLSParams, skipVerifyPeer, customCAFile); -struct ConfigStreamingParams { +struct ConfigStreamingParams +{ std::optional baseUri; std::optional initialRetryDelayMs; + std::optional filter; }; + NLOHMANN_DEFINE_TYPE_NON_INTRUSIVE_WITH_DEFAULT(ConfigStreamingParams, baseUri, - initialRetryDelayMs); + initialRetryDelayMs, filter); -struct ConfigPollingParams { +struct ConfigPollingParams +{ std::optional baseUri; std::optional pollIntervalMs; + std::optional filter; }; + NLOHMANN_DEFINE_TYPE_NON_INTRUSIVE_WITH_DEFAULT(ConfigPollingParams, baseUri, - pollIntervalMs); + pollIntervalMs, filter); -struct ConfigEventParams { +struct ConfigEventParams +{ std::optional baseUri; std::optional capacity; std::optional enableDiagnostics; @@ -62,6 +79,7 @@ struct ConfigEventParams { std::vector globalPrivateAttributes; std::optional flushIntervalMs; }; + NLOHMANN_DEFINE_TYPE_NON_INTRUSIVE_WITH_DEFAULT(ConfigEventParams, baseUri, capacity, @@ -69,35 +87,43 @@ NLOHMANN_DEFINE_TYPE_NON_INTRUSIVE_WITH_DEFAULT(ConfigEventParams, allAttributesPrivate, globalPrivateAttributes, flushIntervalMs); -struct ConfigServiceEndpointsParams { + +struct ConfigServiceEndpointsParams +{ std::optional streaming; std::optional polling; std::optional events; }; + NLOHMANN_DEFINE_TYPE_NON_INTRUSIVE_WITH_DEFAULT(ConfigServiceEndpointsParams, streaming, polling, events); -struct ConfigClientSideParams { +struct ConfigClientSideParams +{ nlohmann::json initialContext; std::optional evaluationReasons; std::optional useReport; }; + NLOHMANN_DEFINE_TYPE_NON_INTRUSIVE_WITH_DEFAULT(ConfigClientSideParams, initialContext, evaluationReasons, useReport); -struct ConfigTags { +struct ConfigTags +{ std::optional applicationId; std::optional applicationVersion; }; + NLOHMANN_DEFINE_TYPE_NON_INTRUSIVE_WITH_DEFAULT(ConfigTags, applicationId, applicationVersion); -struct ConfigParams { +struct ConfigParams +{ std::string credential; std::optional startWaitTimeMs; std::optional initCanFail; @@ -109,6 +135,7 @@ struct ConfigParams { std::optional tags; std::optional tls; }; + NLOHMANN_DEFINE_TYPE_NON_INTRUSIVE_WITH_DEFAULT(ConfigParams, credential, startWaitTimeMs, @@ -121,7 +148,8 @@ NLOHMANN_DEFINE_TYPE_NON_INTRUSIVE_WITH_DEFAULT(ConfigParams, tags, tls); -struct ContextSingleParams { +struct ContextSingleParams +{ std::optional kind; std::string key; std::optional name; @@ -133,7 +161,8 @@ struct ContextSingleParams { // These are defined manually because of the 'private' field, which is a // reserved keyword in C++. inline void to_json(nlohmann::json& nlohmann_json_j, - ContextSingleParams const& nlohmann_json_t) { + ContextSingleParams const& nlohmann_json_t) +{ nlohmann_json_j["kind"] = nlohmann_json_t.kind; nlohmann_json_j["key"] = nlohmann_json_t.key; nlohmann_json_j["name"] = nlohmann_json_t.name; @@ -141,8 +170,10 @@ inline void to_json(nlohmann::json& nlohmann_json_j, nlohmann_json_j["private"] = nlohmann_json_t._private; nlohmann_json_j["custom"] = nlohmann_json_t.custom; } + inline void from_json(nlohmann::json const& nlohmann_json_j, - ContextSingleParams& nlohmann_json_t) { + ContextSingleParams& nlohmann_json_t) +{ ContextSingleParams nlohmann_json_default_obj; nlohmann_json_t.kind = nlohmann_json_j.value("kind", nlohmann_json_default_obj.kind); @@ -158,7 +189,8 @@ inline void from_json(nlohmann::json const& nlohmann_json_j, nlohmann_json_j.value("custom", nlohmann_json_default_obj.custom); } -struct ContextBuildParams { +struct ContextBuildParams +{ std::optional single; std::optional> multi; }; @@ -167,37 +199,43 @@ NLOHMANN_DEFINE_TYPE_NON_INTRUSIVE_WITH_DEFAULT(ContextBuildParams, single, multi); -struct ContextConvertParams { +struct ContextConvertParams +{ std::string input; }; NLOHMANN_DEFINE_TYPE_NON_INTRUSIVE_WITH_DEFAULT(ContextConvertParams, input); -struct ContextResponse { +struct ContextResponse +{ std::optional output; std::optional error; }; NLOHMANN_DEFINE_TYPE_NON_INTRUSIVE_WITH_DEFAULT(ContextResponse, output, error); -struct CreateInstanceParams { +struct CreateInstanceParams +{ ConfigParams configuration; std::string tag; }; + NLOHMANN_DEFINE_TYPE_NON_INTRUSIVE_WITH_DEFAULT(CreateInstanceParams, configuration, tag); enum class ValueType { Bool = 1, Int, Double, String, Any, Unspecified }; + NLOHMANN_JSON_SERIALIZE_ENUM(ValueType, {{ValueType::Bool, "bool"}, - {ValueType::Int, "int"}, - {ValueType::Double, "double"}, - {ValueType::String, "string"}, - {ValueType::Any, "any"}, - {ValueType::Unspecified, ""}}) + {ValueType::Int, "int"}, + {ValueType::Double, "double"}, + {ValueType::String, "string"}, + {ValueType::Any, "any"}, + {ValueType::Unspecified, ""}}) -struct EvaluateFlagParams { +struct EvaluateFlagParams +{ std::string flagKey; std::optional context; ValueType valueType; @@ -205,6 +243,7 @@ struct EvaluateFlagParams { bool detail; EvaluateFlagParams(); }; + NLOHMANN_DEFINE_TYPE_NON_INTRUSIVE_WITH_DEFAULT(EvaluateFlagParams, flagKey, context, @@ -212,40 +251,49 @@ NLOHMANN_DEFINE_TYPE_NON_INTRUSIVE_WITH_DEFAULT(EvaluateFlagParams, defaultValue, detail); -struct EvaluateFlagResponse { +struct EvaluateFlagResponse +{ nlohmann::json value; std::optional variationIndex; std::optional reason; }; + NLOHMANN_DEFINE_TYPE_NON_INTRUSIVE_WITH_DEFAULT(EvaluateFlagResponse, value, variationIndex, reason); -struct EvaluateAllFlagParams { +struct EvaluateAllFlagParams +{ std::optional context; std::optional withReasons; std::optional clientSideOnly; std::optional detailsOnlyForTrackedFlags; }; + NLOHMANN_DEFINE_TYPE_NON_INTRUSIVE_WITH_DEFAULT(EvaluateAllFlagParams, context, withReasons, clientSideOnly, detailsOnlyForTrackedFlags); -struct EvaluateAllFlagsResponse { + +struct EvaluateAllFlagsResponse +{ nlohmann::json state; }; + NLOHMANN_DEFINE_TYPE_NON_INTRUSIVE_WITH_DEFAULT(EvaluateAllFlagsResponse, state); -struct CustomEventParams { +struct CustomEventParams +{ std::string eventKey; std::optional context; std::optional data; std::optional omitNullData; std::optional metricValue; }; + NLOHMANN_DEFINE_TYPE_NON_INTRUSIVE_WITH_DEFAULT(CustomEventParams, eventKey, context, @@ -253,12 +301,15 @@ NLOHMANN_DEFINE_TYPE_NON_INTRUSIVE_WITH_DEFAULT(CustomEventParams, omitNullData, metricValue); -struct IdentifyEventParams { +struct IdentifyEventParams +{ nlohmann::json context; }; + NLOHMANN_DEFINE_TYPE_NON_INTRUSIVE_WITH_DEFAULT(IdentifyEventParams, context); -enum class Command { +enum class Command +{ Unknown = -1, EvaluateFlag, EvaluateAllFlags, @@ -268,17 +319,19 @@ enum class Command { ContextBuild, ContextConvert }; + NLOHMANN_JSON_SERIALIZE_ENUM(Command, {{Command::Unknown, nullptr}, - {Command::EvaluateFlag, "evaluate"}, - {Command::EvaluateAllFlags, "evaluateAll"}, - {Command::IdentifyEvent, "identifyEvent"}, - {Command::CustomEvent, "customEvent"}, - {Command::FlushEvents, "flushEvents"}, - {Command::ContextBuild, "contextBuild"}, - {Command::ContextConvert, "contextConvert"}}); - -struct CommandParams { + {Command::EvaluateFlag, "evaluate"}, + {Command::EvaluateAllFlags, "evaluateAll"}, + {Command::IdentifyEvent, "identifyEvent"}, + {Command::CustomEvent, "customEvent"}, + {Command::FlushEvents, "flushEvents"}, + {Command::ContextBuild, "contextBuild"}, + {Command::ContextConvert, "contextConvert"}}); + +struct CommandParams +{ Command command; std::optional evaluate; std::optional evaluateAll; @@ -288,6 +341,7 @@ struct CommandParams { std::optional contextConvert; CommandParams(); }; + NLOHMANN_DEFINE_TYPE_NON_INTRUSIVE_WITH_DEFAULT(CommandParams, command, evaluate, diff --git a/contract-tests/server-contract-tests/src/entity_manager.cpp b/contract-tests/server-contract-tests/src/entity_manager.cpp index 68d75d642..c92093d0e 100644 --- a/contract-tests/server-contract-tests/src/entity_manager.cpp +++ b/contract-tests/server-contract-tests/src/entity_manager.cpp @@ -11,7 +11,8 @@ using namespace launchdarkly::server_side; EntityManager::EntityManager(boost::asio::any_io_executor executor, launchdarkly::Logger& logger) - : counter_{0}, executor_{std::move(executor)}, logger_{logger} {} + : counter_{0}, executor_{std::move(executor)}, logger_{logger} { +} std::optional EntityManager::create(ConfigParams const& in) { std::string id = std::to_string(counter_++); @@ -31,9 +32,9 @@ std::optional EntityManager::create(ConfigParams const& in) { auto& endpoints = config_builder.ServiceEndpoints() - .EventsBaseUrl(default_endpoints.EventsBaseUrl()) - .PollingBaseUrl(default_endpoints.PollingBaseUrl()) - .StreamingBaseUrl(default_endpoints.StreamingBaseUrl()); + .EventsBaseUrl(default_endpoints.EventsBaseUrl()) + .PollingBaseUrl(default_endpoints.PollingBaseUrl()) + .StreamingBaseUrl(default_endpoints.StreamingBaseUrl()); if (in.serviceEndpoints) { if (in.serviceEndpoints->streaming) { @@ -52,12 +53,15 @@ std::optional EntityManager::create(ConfigParams const& in) { if (in.streaming->baseUri) { endpoints.StreamingBaseUrl(*in.streaming->baseUri); } + auto streaming = decltype(datasystem)::Streaming(); if (in.streaming->initialRetryDelayMs) { - auto streaming = decltype(datasystem)::Streaming(); streaming.InitialReconnectDelay( std::chrono::milliseconds(*in.streaming->initialRetryDelayMs)); - datasystem.Synchronizer(std::move(streaming)); } + if (in.streaming->filter) { + streaming.Filter(*in.streaming->filter); + } + datasystem.Synchronizer(std::move(streaming)); } if (in.polling) { @@ -72,6 +76,9 @@ std::optional EntityManager::create(ConfigParams const& in) { std::chrono::milliseconds( *in.polling->pollIntervalMs))); } + if (in.polling->filter) { + method.Filter(*in.polling->filter); + } datasystem.Synchronizer(std::move(method)); } } @@ -106,7 +113,6 @@ std::optional EntityManager::create(ConfigParams const& in) { event_config.FlushInterval( std::chrono::milliseconds(*events.flushIntervalMs)); } - } else { event_config.Disable(); } diff --git a/contract-tests/server-contract-tests/src/main.cpp b/contract-tests/server-contract-tests/src/main.cpp index e492b4ab6..c4084d0a6 100644 --- a/contract-tests/server-contract-tests/src/main.cpp +++ b/contract-tests/server-contract-tests/src/main.cpp @@ -25,7 +25,7 @@ int main(int argc, char* argv[]) { std::string port = default_port; if (argc == 2) { port = - argv[1]; // NOLINT(cppcoreguidelines-pro-bounds-pointer-arithmetic) + argv[1]; // NOLINT(cppcoreguidelines-pro-bounds-pointer-arithmetic) } try { @@ -45,7 +45,8 @@ int main(int argc, char* argv[]) { srv.add_capability("tls:verify-peer"); srv.add_capability("tls:skip-verify-peer"); srv.add_capability("tls:custom-ca"); - + srv.add_capability("filtering"); + srv.add_capability("filtering-strict"); net::signal_set signals{ioc, SIGINT, SIGTERM}; boost::asio::spawn(ioc.get_executor(), [&](auto yield) mutable { @@ -56,12 +57,11 @@ int main(int argc, char* argv[]) { ioc.run(); LD_LOG(logger, LogLevel::kInfo) << "bye!"; - } catch (boost::bad_lexical_cast&) { LD_LOG(logger, LogLevel::kError) << "invalid port (" << port << "), provide a number (no arguments defaults " - "to port " + "to port " << default_port << ")"; return EXIT_FAILURE; } catch (std::exception const& e) { diff --git a/libs/server-sdk/src/CMakeLists.txt b/libs/server-sdk/src/CMakeLists.txt index 692e4af38..a440b79ed 100644 --- a/libs/server-sdk/src/CMakeLists.txt +++ b/libs/server-sdk/src/CMakeLists.txt @@ -42,6 +42,8 @@ target_sources(${LIBNAME} data_components/serialization_adapters/json_deserializer.cpp data_components/serialization_adapters/json_destination.hpp data_components/serialization_adapters/json_destination.cpp + data_systems/background_sync/detail/payload_filter_validation/payload_filter_validation.hpp + data_systems/background_sync/detail/payload_filter_validation/payload_filter_validation.cpp data_systems/background_sync/sources/polling/polling_data_source.hpp data_systems/background_sync/sources/polling/polling_data_source.cpp data_systems/background_sync/sources/streaming/streaming_data_source.hpp diff --git a/libs/server-sdk/src/data_systems/background_sync/detail/payload_filter_validation/payload_filter_validation.cpp b/libs/server-sdk/src/data_systems/background_sync/detail/payload_filter_validation/payload_filter_validation.cpp new file mode 100644 index 000000000..24df165b5 --- /dev/null +++ b/libs/server-sdk/src/data_systems/background_sync/detail/payload_filter_validation/payload_filter_validation.cpp @@ -0,0 +1,24 @@ +#include "payload_filter_validation.hpp" + +#include + +namespace launchdarkly::server_side::data_systems::detail { +bool ValidateFilterKey(std::string const& filter_key) { + if (filter_key.empty()) { + return false; + } + try { + return regex_search(filter_key, + boost::regex( + "^[a-zA-Z0-9][._\\-a-zA-Z0-9]*$")); + } catch (boost::bad_expression) { + // boost::bad_expression can be thrown by basic_regex when compiling a + // regular expression. + return false; + } catch (std::runtime_error) { + // std::runtime_error can be thrown when a call + // to regex_search results in an "everlasting" search + return false; + } +} +} // namespace launchdarkly::server_side::data_systems::detail diff --git a/libs/server-sdk/src/data_systems/background_sync/detail/payload_filter_validation/payload_filter_validation.hpp b/libs/server-sdk/src/data_systems/background_sync/detail/payload_filter_validation/payload_filter_validation.hpp new file mode 100644 index 000000000..dcbad651b --- /dev/null +++ b/libs/server-sdk/src/data_systems/background_sync/detail/payload_filter_validation/payload_filter_validation.hpp @@ -0,0 +1,7 @@ +#pragma once + +#include + +namespace launchdarkly::server_side::data_systems::detail { +bool ValidateFilterKey(std::string const& filter_key); +} diff --git a/libs/server-sdk/src/data_systems/background_sync/sources/polling/polling_data_source.cpp b/libs/server-sdk/src/data_systems/background_sync/sources/polling/polling_data_source.cpp index 30fc6d252..0f2b5c5cb 100644 --- a/libs/server-sdk/src/data_systems/background_sync/sources/polling/polling_data_source.cpp +++ b/libs/server-sdk/src/data_systems/background_sync/sources/polling/polling_data_source.cpp @@ -10,6 +10,8 @@ #include +#include "../../detail/payload_filter_validation/payload_filter_validation.hpp" + #include namespace launchdarkly::server_side::data_systems { @@ -20,7 +22,13 @@ static char const* const kErrorPutInvalid = static char const* const kCouldNotParseEndpoint = "Could not parse polling endpoint URL"; +static char const* const kInvalidFilterKey = + "Invalid payload filter configured on polling data source, full environment " + "will be fetched.\nEnsure the filter key is not empty and was copied " + "correctly from LaunchDarkly settings"; + static network::HttpRequest MakeRequest( + Logger const& logger, config::built::BackgroundSyncConfig::PollingConfig const& polling_config, config::built::ServiceEndpoints const& endpoints, config::built::HttpProperties const& http_properties) { @@ -29,7 +37,11 @@ static network::HttpRequest MakeRequest( url = network::AppendUrl(url, polling_config.polling_get_path); if (polling_config.filter_key && url) { - url->append("?filter=" + *polling_config.filter_key); + if (detail::ValidateFilterKey(*polling_config.filter_key)) { + url->append("?filter=" + *polling_config.filter_key); + } else { + LD_LOG(logger, LogLevel::kError) << kInvalidFilterKey; + } } network::HttpRequest::BodyType body; @@ -58,7 +70,8 @@ PollingDataSource::PollingDataSource( status_manager_(status_manager), requester_(ioc, http_properties.Tls()), polling_interval_(data_source_config.poll_interval), - request_(MakeRequest(data_source_config, endpoints, http_properties)), + request_(MakeRequest(logger_, data_source_config, endpoints, + http_properties)), timer_(ioc), sink_(nullptr) { if (polling_interval_ < data_source_config.min_polling_interval) { diff --git a/libs/server-sdk/src/data_systems/background_sync/sources/streaming/streaming_data_source.cpp b/libs/server-sdk/src/data_systems/background_sync/sources/streaming/streaming_data_source.cpp index a77505da4..6ba637364 100644 --- a/libs/server-sdk/src/data_systems/background_sync/sources/streaming/streaming_data_source.cpp +++ b/libs/server-sdk/src/data_systems/background_sync/sources/streaming/streaming_data_source.cpp @@ -2,6 +2,8 @@ #include +#include "../../detail/payload_filter_validation/payload_filter_validation.hpp" + #include #include #include @@ -14,6 +16,12 @@ namespace launchdarkly::server_side::data_systems { static char const* const kCouldNotParseEndpoint = "Could not parse streaming endpoint URL"; +static char const* const kInvalidFilterKey = + "Invalid payload filter configured on polling data source, full environment " + "will be fetched.\nEnsure the filter key is not empty and was copied " + "correctly from LaunchDarkly settings"; + + std::string const& StreamingDataSource::Identity() const { static std::string const identity = "streaming data source"; return identity; @@ -47,7 +55,11 @@ void StreamingDataSource::StartAsync( streaming_config_.streaming_path); if (streaming_config_.filter_key && updated_url) { - updated_url->append("?filter=" + *streaming_config_.filter_key); + if (detail::ValidateFilterKey(*streaming_config_.filter_key)) { + updated_url->append("?filter=" + *streaming_config_.filter_key); + } else { + LD_LOG(logger_, LogLevel::kError) << kInvalidFilterKey; + } } // Bad URL, don't set the client. Start will then report the bad status.