Skip to content

Commit

Permalink
IsOverloaded func to take const refs
Browse files Browse the repository at this point in the history
Summary:
Changes parameters for the IsOverloaded user provided function to us references instead of pointers.

Some user methods were modified to more cleanly utilize the references, however many user methods can be further refined to propagate references to their own called methods instead of reconverting them to pointers, this will be left as an exercise for the users.

Reviewed By: sazonovkirill

Differential Revision: D65897036

fbshipit-source-id: 31860aa797f8213a70cbb688a2f3ad4c3ca9f3c6
  • Loading branch information
Charlie Marquez Cook authored and facebook-github-bot committed Nov 23, 2024
1 parent 4ebc0c5 commit 6980aa9
Show file tree
Hide file tree
Showing 11 changed files with 27 additions and 30 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -548,8 +548,8 @@ void Cpp2Connection::requestReceived(
server->getGetHeaderHandler()(hreq->getHeader(), context_.getPeerAddress());
}

if (auto overloadResult = server->checkOverload(
&hreq->getHeader()->getHeaders(), &methodName)) {
if (auto overloadResult =
server->checkOverload(hreq->getHeader()->getHeaders(), methodName)) {
killRequestServerOverloaded(std::move(hreq), std::move(*overloadResult));
return;
}
Expand Down
4 changes: 2 additions & 2 deletions third-party/thrift/src/thrift/lib/cpp2/server/ServerConfigs.h
Original file line number Diff line number Diff line change
Expand Up @@ -99,8 +99,8 @@ class ServerConfigs {

// @see ThriftServer::checkOverload function.
virtual folly::Optional<OverloadResult> checkOverload(
const transport::THeader::StringToStringMap* readHeaders,
const std::string* method) = 0;
const transport::THeader::StringToStringMap& readHeaders,
const std::string& method) = 0;

// @see ThriftServer::preprocess function.
virtual PreprocessResult preprocess(const PreprocessParams& params) const = 0;
Expand Down
13 changes: 5 additions & 8 deletions third-party/thrift/src/thrift/lib/cpp2/server/ThriftServer.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1934,12 +1934,11 @@ PreprocessResult ThriftServer::preprocess(
}

folly::Optional<OverloadResult> ThriftServer::checkOverload(
const transport::THeader::StringToStringMap* readHeaders,
const std::string* method) {
const transport::THeader::StringToStringMap& readHeaders,
const std::string& method) {
if (UNLIKELY(
isOverloaded_ &&
(method == nullptr ||
!getMethodsBypassMaxRequestsLimit().contains(*method)) &&
!getMethodsBypassMaxRequestsLimit().contains(method) &&
isOverloaded_(readHeaders, method))) {
return OverloadResult{
kAppOverloadedErrorCode,
Expand All @@ -1956,8 +1955,7 @@ folly::Optional<OverloadResult> ThriftServer::checkOverload(
THRIFT_FLAG(enforce_queue_concurrency_resource_pools);
if (!isActiveRequestsTrackingDisabled() && !useQueueConcurrency) {
if (auto maxRequests = getMaxRequests(); maxRequests > 0 &&
(method == nullptr ||
!getMethodsBypassMaxRequestsLimit().contains(*method)) &&
!getMethodsBypassMaxRequestsLimit().contains(method) &&
static_cast<uint32_t>(getActiveRequests()) >= maxRequests) {
LoadShedder loadShedder = LoadShedder::MAX_REQUESTS;
if (getCPUConcurrencyController().requestShed(
Expand All @@ -1975,8 +1973,7 @@ folly::Optional<OverloadResult> ThriftServer::checkOverload(

if (auto maxQps = getMaxQps(); maxQps > 0 &&
FLAGS_thrift_server_enforces_qps_limit &&
(method == nullptr ||
!getMethodsBypassMaxRequestsLimit().contains(*method)) &&
!getMethodsBypassMaxRequestsLimit().contains(method) &&
!qpsTokenBucket_.consume(1.0, maxQps, maxQps)) {
LoadShedder loadShedder = LoadShedder::MAX_QPS;
if (getCPUConcurrencyController().requestShed(
Expand Down
6 changes: 3 additions & 3 deletions third-party/thrift/src/thrift/lib/cpp2/server/ThriftServer.h
Original file line number Diff line number Diff line change
Expand Up @@ -197,7 +197,7 @@ class ThriftServerStopController final {
};

using IsOverloadedFunc = folly::Function<bool(
const transport::THeader::StringToStringMap*, const std::string*) const>;
const transport::THeader::StringToStringMap&, const std::string&) const>;

typedef std::function<void(
folly::EventBase*,
Expand Down Expand Up @@ -2243,8 +2243,8 @@ class ThriftServer : public apache::thrift::concurrency::Runnable,

// if overloaded, returns applicable overloaded exception code.
folly::Optional<OverloadResult> checkOverload(
const transport::THeader::StringToStringMap* readHeaders = nullptr,
const std::string* = nullptr) final;
const transport::THeader::StringToStringMap& readHeaders,
const std::string& method) final;

// returns descriptive error if application is unable to process request
PreprocessResult preprocess(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ folly::Optional<OverloadResult> CurrentlyOverloadedChecker::checkOverload(
(params.method == nullptr ||
!config_.getMethodsBypassMaxRequestsLimit().contains(
*params.method)) &&
isOverloaded_(params.readHeaders, params.method))) {
isOverloaded_(*params.readHeaders, *params.method))) {
return OverloadResult{
kAppOverloadedErrorCode,
fmt::format(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -69,8 +69,8 @@ class MockServerConfigs : public apache::thrift::server::ServerConfigs {
MOCK_METHOD(
folly::Optional<OverloadResult>,
checkOverload,
(const apache::thrift::transport::THeader::StringToStringMap*,
const std::string* method),
(const apache::thrift::transport::THeader::StringToStringMap&,
const std::string& method),
(override));
MOCK_METHOD(
apache::thrift::PreprocessResult,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -469,8 +469,8 @@ void doLoadHeaderTest(bool isRocket) {
});

server.setIsOverloaded(
[&nCalls](const auto*, const std::string* method) {
EXPECT_EQ("voidResponse", *method);
[&nCalls](const auto&, const std::string& method) {
EXPECT_EQ("voidResponse", method);
return ++nCalls == 4;
});
});
Expand Down Expand Up @@ -2046,9 +2046,9 @@ TEST_P(OverloadTest, DISABLED_Test) {
auto client = makeClient(runner, &base);

runner.getThriftServer().setIsOverloaded(
[&](const auto*, const string* method) {
[&](const auto&, const string& method) {
if (errorType == ErrorType::AppOverload) {
EXPECT_EQ("voidResponse", *method);
EXPECT_EQ("voidResponse", method);
return true;
}
return false;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -66,8 +66,8 @@ class ServerConfigsMock : public ServerConfigs {
}

folly::Optional<OverloadResult> checkOverload(
const transport::THeader::StringToStringMap*,
const std::string*) override {
const transport::THeader::StringToStringMap&,
const std::string&) override {
return {};
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -700,7 +700,7 @@ void TransportCompatibilityTest::TestRequestResponse_Saturation() {
void TransportCompatibilityTest::TestRequestResponse_IsOverloaded() {
// make sure server is overloaded
server_->getServer()->setIsOverloaded(
[](const transport::THeader::StringToStringMap*, const std::string*) {
[](const transport::THeader::StringToStringMap&, const std::string&) {
return true;
});
connectToServer([this](std::unique_ptr<TestServiceAsyncClient> client) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -609,7 +609,7 @@ void ThriftRocketServerHandler::handleRequestCommon(
const auto& headers = request->getTHeader().getHeaders();
const auto& name = request->getMethodName();

auto overloadResult = serverConfigs_->checkOverload(&headers, &name);
auto overloadResult = serverConfigs_->checkOverload(headers, name);
serverConfigs_->incActiveRequests();
if (UNLIKELY(overloadResult.has_value())) {
handleRequestOverloadedServer(
Expand Down
8 changes: 4 additions & 4 deletions third-party/thrift/src/thrift/lib/py3/test/is_overload/func.h
Original file line number Diff line number Diff line change
Expand Up @@ -24,16 +24,16 @@ namespace py3 {
namespace test {

bool isOverloaded(
const apache::thrift::transport::THeader::StringToStringMap* headers,
const std::string* method_name) {
return *method_name == "overloaded_method";
const apache::thrift::transport::THeader::StringToStringMap& /* headers */,
const std::string& method_name) {
return method_name == "overloaded_method";
}

bool checkOverload(
const std::shared_ptr<apache::thrift::ThriftServer> server,
const std::string method_name) {
// dummy test doesn't use the headers, so pass nullptr
auto ret = server->checkOverload(nullptr, &method_name);
auto ret = server->checkOverload({}, method_name);
// ret will contain the error code if there is an overload
// otherwise, it will return no value
return ret.hasValue();
Expand Down

0 comments on commit 6980aa9

Please sign in to comment.