This repository has been archived by the owner on Jul 1, 2023. It is now read-only.

Add test that stresses races during transport shutdown #250

Open · wants to merge 6 commits into base: main
8 changes: 8 additions & 0 deletions tensorpipe/common/deferred_executor.h
@@ -160,6 +160,12 @@ class EventLoopDeferredExecutor : public virtual DeferredExecutor {
// subclasses and called inside the thread owned by this parent class.
virtual void eventLoop() = 0;

// This is called after the event loop has terminated, still within the
// thread that used to run that event loop. It is invoked after this class
// has transitioned control to the on-demand deferred executor, and it thus
// allows subclasses to clean up any resources without worrying about new
// work coming in.
virtual void cleanUpLoop() {}

// This function is called by the parent class when a function is deferred to
// it, and must be implemented by subclasses, which are required to have their
// event loop call runDeferredFunctionsFromEventLoop as soon as possible. This
@@ -230,6 +236,8 @@ class EventLoopDeferredExecutor : public virtual DeferredExecutor {
fn();
}
}

cleanUpLoop();
}

std::thread thread_;
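To make the new contract concrete, here is a minimal sketch (not part of this PR) of a subclass that pairs eventLoop with cleanUpLoop. It assumes the tensorpipe namespace and that eventLoop and wakeupEventLoopToDeferFunction are the only hooks a subclass must override; CondVarLoop, its members, and its stop protocol are purely illustrative:

#include <condition_variable>
#include <mutex>

#include <tensorpipe/common/deferred_executor.h>

// Illustrative subclass: a trivial event loop driven by a condition variable.
class CondVarLoop : public tensorpipe::EventLoopDeferredExecutor {
 public:
  // Hypothetical stop entry point; the real subclasses terminate their loop
  // as part of the close/join protocol, which is elided here.
  void stop() {
    std::lock_guard<std::mutex> lock(mutex_);
    done_ = true;
    cv_.notify_one();
  }

 private:
  void eventLoop() override {
    std::unique_lock<std::mutex> lock(mutex_);
    while (!done_) {
      cv_.wait(lock, [this]() { return done_ || wakeups_ > 0; });
      wakeups_ = 0;
      lock.unlock();
      // Run the functions that were deferred to this executor.
      runDeferredFunctionsFromEventLoop();
      lock.lock();
    }
  }

  void cleanUpLoop() override {
    // Still on the loop thread, but the parent class has already handed
    // control to the on-demand deferred executor: no new work can reach this
    // loop, so releasing loop-private resources here cannot race with it.
  }

  void wakeupEventLoopToDeferFunction() override {
    std::lock_guard<std::mutex> lock(mutex_);
    ++wakeups_;
    cv_.notify_one();
  }

  std::mutex mutex_;
  std::condition_variable cv_;
  int wakeups_{0};
  bool done_{false};
};

The value of the hook shows in cleanUpLoop: teardown runs after the hand-off to the on-demand executor, which is exactly how the uv::Loop change at the bottom of this diff uses it to close its async handle.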
36 changes: 36 additions & 0 deletions tensorpipe/test/transport/connection_test.cc
@@ -245,3 +245,39 @@ TEST_P(TransportTest, DISABLED_Connection_EmptyBuffer) {
peers_->join(PeerGroup::kClient);
});
}

TEST_P(TransportTest, Connection_SpamAtClosing) {
using namespace std::chrono_literals;

std::shared_ptr<Context> ctx = GetParam()->getContext();
ctx->setId("loopback");

std::string addr = GetParam()->defaultAddr();
std::shared_ptr<Listener> listener = ctx->listen(addr);

std::atomic<bool> stopSpamming{false};
std::function<void()> spam = [&]() {
if (stopSpamming) {
return;
}
std::shared_ptr<Connection> conn = ctx->connect(addr);
conn->read(
[&](const Error& error, const void* /* unused */, size_t /* unused */) {
EXPECT_TRUE(error);
spam();
});
conn->close();
};

spam();

std::this_thread::sleep_for(10ms);

ctx->close();

std::this_thread::sleep_for(10ms);

stopSpamming = true;

ctx->join();
}
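In short: each failing read callback schedules the next connection, so the test keeps a self-perpetuating chain of connect/read/close calls alive while the context is closed out from under it, stressing exactly the init-while-closed and close-from-the-loop paths changed below; every read is expected to fail, hence the EXPECT_TRUE(error).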
12 changes: 12 additions & 0 deletions tensorpipe/transport/connection_boilerplate.h
@@ -31,6 +31,8 @@ class ConnectionBoilerplate : public Connection {
std::string id,
Args... args);

explicit ConnectionBoilerplate(std::shared_ptr<TConn> connection);

ConnectionBoilerplate(const ConnectionBoilerplate&) = delete;
ConnectionBoilerplate(ConnectionBoilerplate&&) = delete;
ConnectionBoilerplate& operator=(const ConnectionBoilerplate&) = delete;
@@ -79,6 +81,16 @@ ConnectionBoilerplate<TCtx, TList, TConn>::ConnectionBoilerplate(
impl_->init();
}

template <typename TCtx, typename TList, typename TConn>
ConnectionBoilerplate<TCtx, TList, TConn>::ConnectionBoilerplate(
std::shared_ptr<TConn> connection)
: impl_(std::move(connection)) {
static_assert(
std::is_base_of<ConnectionImplBoilerplate<TCtx, TList, TConn>, TConn>::
value,
"");
}

template <typename TCtx, typename TList, typename TConn>
void ConnectionBoilerplate<TCtx, TList, TConn>::read(read_callback_fn fn) {
impl_->read(std::move(fn));
22 changes: 16 additions & 6 deletions tensorpipe/transport/connection_impl_boilerplate.h
@@ -126,21 +126,24 @@ class ConnectionImplBoilerplate : public std::enable_shared_from_this<TConn> {
// Deal with an error.
void handleError();

ClosingReceiver closingReceiver_;

// A sequence number for the calls to read and write.
uint64_t nextBufferBeingRead_{0};
uint64_t nextBufferBeingWritten_{0};

// Contexts and listeners sometimes need to call directly into initFromLoop
// and closeFromLoop, in order to make sure that some of their operations
// happen "atomically" on the connection, without other operations (e.g., an
// error) possibly occurring in between.
friend ContextImplBoilerplate<TCtx, TList, TConn>;
friend ListenerImplBoilerplate<TCtx, TList, TConn>;
};

template <typename TCtx, typename TList, typename TConn>
ConnectionImplBoilerplate<TCtx, TList, TConn>::ConnectionImplBoilerplate(
ConstructorToken /* unused */,
std::shared_ptr<TCtx> context,
std::string id)
: context_(std::move(context)),
id_(std::move(id)),
closingReceiver_(context_, context_->getClosingEmitter()) {}
: context_(std::move(context)), id_(std::move(id)) {}

template <typename TCtx, typename TList, typename TConn>
void ConnectionImplBoilerplate<TCtx, TList, TConn>::init() {
@@ -150,7 +153,14 @@ void ConnectionImplBoilerplate<TCtx, TList, TConn>::init() {

template <typename TCtx, typename TList, typename TConn>
void ConnectionImplBoilerplate<TCtx, TList, TConn>::initFromLoop() {
closingReceiver_.activate(*this);
if (context_->closed()) {
// Set the error without calling setError, because we do not want to invoke
// the subclass's handleErrorImpl: it would find itself in a weird state,
// since initImplFromLoop would never have been called.
error_ = TP_CREATE_ERROR(ConnectionClosedError);
TP_VLOG(7) << "Connection " << id_ << " is closing (without initing)";
return;
}

initImplFromLoop();
}
55 changes: 41 additions & 14 deletions tensorpipe/transport/context_impl_boilerplate.h
@@ -9,12 +9,12 @@
#pragma once

#include <atomic>
#include <future>
#include <memory>
#include <string>
#include <unordered_map>
#include <utility>

#include <tensorpipe/common/callback.h>
#include <tensorpipe/common/defs.h>
#include <tensorpipe/transport/connection_boilerplate.h>
#include <tensorpipe/transport/listener_boilerplate.h>
@@ -51,7 +51,9 @@ class ContextImplBoilerplate : public virtual DeferredExecutor,
void unenroll(TList& listener);
void unenroll(TConn& connection);

ClosingEmitter& getClosingEmitter();
// Return whether the context is in a closed state. To avoid race conditions,
// this must be called from within the loop.
bool closed();

void setId(std::string id);

@@ -73,7 +75,6 @@
private:
std::atomic<bool> closed_{false};
std::atomic<bool> joined_{false};
ClosingEmitter closingEmitter_;

const std::string domainDescriptor_;

@@ -162,9 +163,9 @@ void ContextImplBoilerplate<TCtx, TList, TConn>::unenroll(TConn& connection) {
}

template <typename TCtx, typename TList, typename TConn>
ClosingEmitter& ContextImplBoilerplate<TCtx, TList, TConn>::
getClosingEmitter() {
return closingEmitter_;
bool ContextImplBoilerplate<TCtx, TList, TConn>::closed() {
TP_DCHECK(inLoop());
return closed_;
};

template <typename TCtx, typename TList, typename TConn>
Expand All @@ -175,14 +176,32 @@ void ContextImplBoilerplate<TCtx, TList, TConn>::setId(std::string id) {

template <typename TCtx, typename TList, typename TConn>
void ContextImplBoilerplate<TCtx, TList, TConn>::close() {
if (!closed_.exchange(true)) {
TP_VLOG(7) << "Transport context " << id_ << " is closing";

closingEmitter_.close();
closeImpl();

TP_VLOG(7) << "Transport context " << id_ << " done closing";
}
// Defer this to the loop so that it won't race with other code reading the
// closed_ flag (in other words: any code running in the loop can assume that
// the flag won't change under it).
deferToLoop([this]() {
if (!closed_.exchange(true)) {
TP_VLOG(7) << "Transport context " << id_ << " is closing";

// Make copies, as the objects could unenroll themselves inline.
decltype(listeners_) listenersCopy = listeners_;
decltype(connections_) connectionsCopy = connections_;
// We call closeFromLoop, rather than just close, because we need these
// objects to transition to error _immediately_ and "atomically". If we
// instead deferred the closing to later, it could run after already-enqueued
// operations that try to access the context, which by then would be closed,
// and those operations could fail.
for (auto& iter : listenersCopy) {
iter.second->closeFromLoop();
}
for (auto& iter : connectionsCopy) {
iter.second->closeFromLoop();
}

closeImpl();

TP_VLOG(7) << "Transport context " << id_ << " done closing";
}
});
}

template <typename TCtx, typename TList, typename TConn>
@@ -192,6 +211,14 @@ void ContextImplBoilerplate<TCtx, TList, TConn>::join() {
if (!joined_.exchange(true)) {
TP_VLOG(7) << "Transport context " << id_ << " is joining";

// As closing is deferred to the loop, we must wait for closeImpl to actually
// have been called before we call joinImpl, to avoid race conditions. To do
// so, we defer another task to the loop, which we know will run after the
// closing, and wait for that task to complete.
std::promise<void> hasClosed;
deferToLoop([&]() { hasClosed.set_value(); });
hasClosed.get_future().wait();

joinImpl();

TP_VLOG(7) << "Transport context " << id_ << " done joining";
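The defer-and-wait fence that join() now uses is a general idiom for synchronizing with a single-threaded executor. Here is a standalone sketch of the idiom, independent of tensorpipe (MiniLoop and all of its members are illustrative):

#include <atomic>
#include <condition_variable>
#include <functional>
#include <future>
#include <mutex>
#include <queue>
#include <thread>

// A toy single-threaded executor: deferred tasks run in FIFO order.
class MiniLoop {
 public:
  MiniLoop() {
    thread_ = std::thread([this]() { run(); });
  }

  ~MiniLoop() {
    deferToLoop([this]() { stop_ = true; });
    thread_.join();
  }

  void deferToLoop(std::function<void()> fn) {
    std::lock_guard<std::mutex> lock(mutex_);
    tasks_.push(std::move(fn));
    cv_.notify_one();
  }

  // The fence: defer a marker task and wait for it. Because tasks run in
  // FIFO order, once the marker has run, every task deferred before it is
  // guaranteed to have run as well.
  void fence() {
    std::promise<void> done;
    deferToLoop([&]() { done.set_value(); });
    done.get_future().wait();
  }

 private:
  void run() {
    std::unique_lock<std::mutex> lock(mutex_);
    while (!stop_) {
      cv_.wait(lock, [this]() { return !tasks_.empty(); });
      std::function<void()> fn = std::move(tasks_.front());
      tasks_.pop();
      lock.unlock();
      fn();  // Run outside the lock so a task may defer further work.
      lock.lock();
    }
  }

  std::thread thread_;
  std::mutex mutex_;
  std::condition_variable cv_;
  std::queue<std::function<void()>> tasks_;
  std::atomic<bool> stop_{false};
};

join() above is this fence followed by joinImpl: the marker task necessarily runs after the (also deferred) closing task, so closeImpl has completed by the time joinImpl starts.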
2 changes: 1 addition & 1 deletion tensorpipe/transport/ibv/listener_impl.cc
@@ -146,7 +146,7 @@ void ListenerImpl::handleEventsFromLoop(int events) {
if (fns_.empty()) {
context_->unregisterDescriptor(socket_.fd());
}
fn(Error::kSuccess, createConnection(std::move(socket)));
fn(Error::kSuccess, createAndInitConnection(std::move(socket)));
}

} // namespace ibv
36 changes: 27 additions & 9 deletions tensorpipe/transport/listener_impl_boilerplate.h
@@ -78,7 +78,7 @@ class ListenerImplBoilerplate : public std::enable_shared_from_this<TList> {
Error error_{Error::kSuccess};

template <typename... Args>
std::shared_ptr<Connection> createConnection(Args&&... args);
std::shared_ptr<Connection> createAndInitConnection(Args&&... args);

// An identifier for the listener, composed of the identifier for the context,
// combined with an increasing sequence number. It will be used as a prefix
@@ -104,25 +104,27 @@ class ListenerImplBoilerplate : public std::enable_shared_from_this<TList> {
// Deal with an error.
void handleError();

ClosingReceiver closingReceiver_;

// A sequence number for the calls to accept.
uint64_t nextConnectionBeingAccepted_{0};

// Sequence numbers for the connections created by this listener, used to
// create their identifiers based off this listener's identifier. They will
// only be used for logging and debugging.
std::atomic<uint64_t> connectionCounter_{0};

// Contexts sometimes need to call directly into closeFromLoop, in order to
// make sure that some of their operations happen "atomically" on the
// listener, without other operations (e.g., an error) possibly occurring in
// between.
friend ContextImplBoilerplate<TCtx, TList, TConn>;
};

template <typename TCtx, typename TList, typename TConn>
ListenerImplBoilerplate<TCtx, TList, TConn>::ListenerImplBoilerplate(
ConstructorToken /* unused */,
std::shared_ptr<TCtx> context,
std::string id)
: context_(std::move(context)),
id_(std::move(id)),
closingReceiver_(context_, context_->getClosingEmitter()) {}
: context_(std::move(context)), id_(std::move(id)) {}

template <typename TCtx, typename TList, typename TConn>
void ListenerImplBoilerplate<TCtx, TList, TConn>::init() {
@@ -132,7 +134,14 @@ void ListenerImplBoilerplate<TCtx, TList, TConn>::init() {

template <typename TCtx, typename TList, typename TConn>
void ListenerImplBoilerplate<TCtx, TList, TConn>::initFromLoop() {
closingReceiver_.activate(*this);
if (context_->closed()) {
// Set the error without calling setError, because we do not want to invoke
// the subclass's handleErrorImpl: it would find itself in a weird state,
// since initImplFromLoop would never have been called.
error_ = TP_CREATE_ERROR(ListenerClosedError);
TP_VLOG(7) << "Listener " << id_ << " is closing (without initing)";
return;
}

initImplFromLoop();
}
@@ -189,15 +198,24 @@ std::string ListenerImplBoilerplate<TCtx, TList, TConn>::addrFromLoop() const {
template <typename TCtx, typename TList, typename TConn>
template <typename... Args>
std::shared_ptr<Connection> ListenerImplBoilerplate<TCtx, TList, TConn>::
createConnection(Args&&... args) {
createAndInitConnection(Args&&... args) {
TP_DCHECK(context_->inLoop());
std::string connectionId = id_ + ".c" + std::to_string(connectionCounter_++);
TP_VLOG(7) << "Listener " << id_ << " is opening connection " << connectionId;
return std::make_shared<ConnectionBoilerplate<TCtx, TList, TConn>>(
auto connection = std::make_shared<TConn>(
typename ConnectionImplBoilerplate<TCtx, TList, TConn>::
ConstructorToken(),
context_,
std::move(connectionId),
std::forward<Args>(args)...);
// We initialize the connection from the loop immediately, inline, because the
// initialization of a connection accepted by a listener typically happens
// partly in the listener (e.g., opening and accepting the socket) and partly
// in the connection's initFromLoop, and we need these two steps to happen
// "atomically", so that no error can occur in between them.
connection->initFromLoop();
return std::make_shared<ConnectionBoilerplate<TCtx, TList, TConn>>(
std::move(connection));
}

template <typename TCtx, typename TList, typename TConn>
2 changes: 1 addition & 1 deletion tensorpipe/transport/shm/listener_impl.cc
@@ -136,7 +136,7 @@ void ListenerImpl::handleEventsFromLoop(int events) {
if (fns_.empty()) {
context_->unregisterDescriptor(socket_.fd());
}
fn(Error::kSuccess, createConnection(std::move(socket)));
fn(Error::kSuccess, createAndInitConnection(std::move(socket)));
}

} // namespace shm
2 changes: 2 additions & 0 deletions tensorpipe/transport/uv/connection_impl.cc
@@ -52,6 +52,8 @@ ConnectionImpl::ConnectionImpl(
void ConnectionImpl::initImplFromLoop() {
context_->enroll(*this);

TP_VLOG(9) << "Connection " << id_ << " is initializing in loop";

if (sockaddr_.has_value()) {
handle_->initFromLoop();
handle_->connectFromLoop(sockaddr_.value(), [this](int status) {
5 changes: 4 additions & 1 deletion tensorpipe/transport/uv/listener_impl.cc
@@ -36,6 +36,8 @@ ListenerImpl::ListenerImpl(
void ListenerImpl::initImplFromLoop() {
context_->enroll(*this);

TP_VLOG(9) << "Listener " << id_ << " is initializing in loop";

handle_->initFromLoop();
auto rv = handle_->bindFromLoop(sockaddr_);
TP_THROW_UV_IF(rv < 0, rv);
@@ -67,7 +69,8 @@ void ListenerImpl::connectionCallbackFromLoop(int status) {
auto connection = context_->createHandle();
connection->initFromLoop();
handle_->acceptFromLoop(*connection);
callback_.trigger(Error::kSuccess, createConnection(std::move(connection)));
callback_.trigger(
Error::kSuccess, createAndInitConnection(std::move(connection)));
}

void ListenerImpl::closeCallbackFromLoop() {
4 changes: 4 additions & 0 deletions tensorpipe/transport/uv/loop.cc
@@ -58,6 +58,10 @@ void Loop::eventLoop() {
rv = uv_run(&loop_, UV_RUN_DEFAULT);
TP_THROW_ASSERT_IF(rv > 0)
<< ": uv_run returned with active handles or requests";
}

void Loop::cleanUpLoop() {
int rv;

uv_ref(reinterpret_cast<uv_handle_t*>(&async_));
uv_close(reinterpret_cast<uv_handle_t*>(&async_), nullptr);
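Moving the handling of the async handle into cleanUpLoop follows the usual libuv teardown rule: every handle must be closed, and its close event processed, before uv_loop_close can succeed. A standalone sketch of that rule, assuming only a stock libuv install (unrelated to the tensorpipe classes above):

#include <uv.h>

int main() {
  uv_loop_t loop;
  if (uv_loop_init(&loop) != 0) {
    return 1;
  }

  // An async handle, like the one uv::Loop uses to wake itself up.
  uv_async_t async;
  if (uv_async_init(&loop, &async, /*async_cb=*/nullptr) != 0) {
    return 1;
  }

  // (The loop would normally run here, woken via uv_async_send(&async).)

  // Teardown: uv_loop_close fails with UV_EBUSY while handles are still
  // open, so close the handle first, then run the loop once more so that
  // the close is actually processed.
  uv_close(reinterpret_cast<uv_handle_t*>(&async), /*close_cb=*/nullptr);
  uv_run(&loop, UV_RUN_DEFAULT);
  return uv_loop_close(&loop) == 0 ? 0 : 1;
}

The uv_ref call in cleanUpLoop presumably undoes an earlier uv_unref on the async handle, so that the final uv_run does not return before the close callback has been processed.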
7 changes: 7 additions & 0 deletions tensorpipe/transport/uv/loop.h
@@ -32,6 +32,10 @@ class Loop final : public EventLoopDeferredExecutor {
return &loop_;
}

bool closed() {
return closed_;
}

void close();

void join();
@@ -42,6 +46,9 @@
// Event loop thread entry function.
void eventLoop() override;

// Clean up after the event loop has transitioned to the on-demand executor.
void cleanUpLoop() override;

// Wake up the event loop.
void wakeupEventLoopToDeferFunction() override;
