diff --git a/tensorpipe/common/deferred_executor.h b/tensorpipe/common/deferred_executor.h
index 95afc4041..653a03720 100644
--- a/tensorpipe/common/deferred_executor.h
+++ b/tensorpipe/common/deferred_executor.h
@@ -160,6 +160,12 @@ class EventLoopDeferredExecutor : public virtual DeferredExecutor {
   // subclasses and called inside the thread owned by this parent class.
   virtual void eventLoop() = 0;
 
+  // This is called after the event loop has terminated, still within the
+  // thread that used to run that event loop. It is called after this class
+  // has transitioned control to the on-demand deferred executor, and thus
+  // allows cleaning up resources without worrying about new work coming in.
+  virtual void cleanUpLoop() {}
+
   // This function is called by the parent class when a function is deferred to
   // it, and must be implemented by subclasses, which are required to have their
   // event loop call runDeferredFunctionsFromEventLoop as soon as possible. This
@@ -230,6 +236,8 @@ class EventLoopDeferredExecutor : public virtual DeferredExecutor {
         fn();
       }
     }
+
+    cleanUpLoop();
   }
 
   std::thread thread_;
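Note: the contract of the new hook, namely that cleanUpLoop() runs on the same thread as eventLoop() and strictly after it returns, can be illustrated with a standalone sketch. This is not tensorpipe code: EventLoopThread, MyLoop, and start() are hypothetical names, and the hand-off to the on-demand executor is elided.

```cpp
#include <iostream>
#include <thread>

// Minimal stand-in for EventLoopDeferredExecutor's thread lifecycle.
class EventLoopThread {
 public:
  virtual ~EventLoopThread() = default;

  void start() {
    thread_ = std::thread([this]() {
      eventLoop();
      // The real class first hands pending work off to an on-demand
      // executor; only then does the cleanup hook run, on the same thread.
      cleanUpLoop();
    });
  }

  void join() {
    thread_.join();
  }

 protected:
  virtual void eventLoop() = 0;
  virtual void cleanUpLoop() {}

 private:
  std::thread thread_;
};

class MyLoop : public EventLoopThread {
 protected:
  void eventLoop() override {
    std::cout << "event loop running\n";
  }
  void cleanUpLoop() override {
    std::cout << "event loop done, releasing resources\n";
  }
};

int main() {
  MyLoop loop;
  loop.start();
  loop.join();
}
```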
diff --git a/tensorpipe/test/transport/connection_test.cc b/tensorpipe/test/transport/connection_test.cc
index af60258cc..a82725a62 100644
--- a/tensorpipe/test/transport/connection_test.cc
+++ b/tensorpipe/test/transport/connection_test.cc
@@ -245,3 +245,39 @@ TEST_P(TransportTest, DISABLED_Connection_EmptyBuffer) {
     peers_->join(PeerGroup::kClient);
   });
 }
+
+TEST_P(TransportTest, Connection_SpamAtClosing) {
+  using namespace std::chrono_literals;
+
+  std::shared_ptr<Context> ctx = GetParam()->getContext();
+  ctx->setId("loopback");
+
+  std::string addr = GetParam()->defaultAddr();
+  std::shared_ptr<Listener> listener = ctx->listen(addr);
+
+  std::atomic<bool> stopSpamming{false};
+  std::function<void()> spam = [&]() {
+    if (stopSpamming) {
+      return;
+    }
+    std::shared_ptr<Connection> conn = ctx->connect(addr);
+    conn->read(
+        [&](const Error& error, const void* /* unused */, size_t /* unused */) {
+          EXPECT_TRUE(error);
+          spam();
+        });
+    conn->close();
+  };
+
+  spam();
+
+  std::this_thread::sleep_for(10ms);
+
+  ctx->close();
+
+  std::this_thread::sleep_for(10ms);
+
+  stopSpamming = true;
+
+  ctx->join();
+}
diff --git a/tensorpipe/transport/connection_boilerplate.h b/tensorpipe/transport/connection_boilerplate.h
index df5d3e9a2..a1044a814 100644
--- a/tensorpipe/transport/connection_boilerplate.h
+++ b/tensorpipe/transport/connection_boilerplate.h
@@ -31,6 +31,8 @@ class ConnectionBoilerplate : public Connection {
       std::string id,
       Args... args);
 
+  explicit ConnectionBoilerplate(std::shared_ptr<TConn> connection);
+
   ConnectionBoilerplate(const ConnectionBoilerplate&) = delete;
   ConnectionBoilerplate(ConnectionBoilerplate&&) = delete;
   ConnectionBoilerplate& operator=(const ConnectionBoilerplate&) = delete;
@@ -79,6 +81,16 @@ ConnectionBoilerplate<TCtx, TList, TConn>::ConnectionBoilerplate(
   impl_->init();
 }
 
+template <typename TCtx, typename TList, typename TConn>
+ConnectionBoilerplate<TCtx, TList, TConn>::ConnectionBoilerplate(
+    std::shared_ptr<TConn> connection)
+    : impl_(std::move(connection)) {
+  static_assert(
+      std::is_base_of<ConnectionImplBoilerplate<TCtx, TList, TConn>, TConn>::
+          value,
+      "");
+}
+
 template <typename TCtx, typename TList, typename TConn>
 void ConnectionBoilerplate<TCtx, TList, TConn>::read(read_callback_fn fn) {
   impl_->read(std::move(fn));
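Note: the static_assert in the new wrapping constructor is a conventional compile-time base-class check. Here is a standalone sketch of the same pattern, with hypothetical names (ImplBase, Wrapper, MyImpl) rather than tensorpipe's classes:

```cpp
#include <memory>
#include <type_traits>
#include <utility>

// CRTP-style base, mirroring ConnectionImplBoilerplate's role.
template <typename TImpl>
class ImplBase {};

class MyImpl : public ImplBase<MyImpl> {};

// Wrapper accepting a pre-built implementation, mirroring the new
// ConnectionBoilerplate constructor: the assert rejects, at compile time,
// any TImpl that does not derive from the expected base.
template <typename TImpl>
class Wrapper {
 public:
  explicit Wrapper(std::shared_ptr<TImpl> impl) : impl_(std::move(impl)) {
    static_assert(
        std::is_base_of<ImplBase<TImpl>, TImpl>::value,
        "TImpl must derive from ImplBase<TImpl>");
  }

 private:
  std::shared_ptr<TImpl> impl_;
};

int main() {
  Wrapper<MyImpl> wrapper(std::make_shared<MyImpl>());
}
```

The check costs nothing at runtime; it only rules out instantiating the wrapper with an implementation type that the rest of the boilerplate could not call into.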
diff --git a/tensorpipe/transport/connection_impl_boilerplate.h b/tensorpipe/transport/connection_impl_boilerplate.h
index b38962f3c..ec21b6075 100644
--- a/tensorpipe/transport/connection_impl_boilerplate.h
+++ b/tensorpipe/transport/connection_impl_boilerplate.h
@@ -126,11 +126,16 @@ class ConnectionImplBoilerplate : public std::enable_shared_from_this<TConn> {
   // Deal with an error.
   void handleError();
 
-  ClosingReceiver closingReceiver_;
-
   // A sequence number for the calls to read and write.
   uint64_t nextBufferBeingRead_{0};
   uint64_t nextBufferBeingWritten_{0};
+
+  // Contexts and listeners do sometimes need to call directly into
+  // initFromLoop and closeFromLoop, in order to make sure that some of their
+  // operations can happen "atomically" on the connection, without other
+  // operations possibly occurring in between (e.g., an error).
+  friend ContextImplBoilerplate<TCtx, TList, TConn>;
+  friend ListenerImplBoilerplate<TCtx, TList, TConn>;
 };
 
 template <typename TCtx, typename TList, typename TConn>
@@ -138,9 +143,7 @@ ConnectionImplBoilerplate<TCtx, TList, TConn>::ConnectionImplBoilerplate(
     ConstructorToken /* unused */,
     std::shared_ptr<TCtx> context,
     std::string id)
-    : context_(std::move(context)),
-      id_(std::move(id)),
-      closingReceiver_(context_, context_->getClosingEmitter()) {}
+    : context_(std::move(context)), id_(std::move(id)) {}
 
 template <typename TCtx, typename TList, typename TConn>
 void ConnectionImplBoilerplate<TCtx, TList, TConn>::init() {
@@ -150,7 +153,14 @@ void ConnectionImplBoilerplate<TCtx, TList, TConn>::initFromLoop() {
-  closingReceiver_.activate(*this);
+  if (context_->closed()) {
+    // Set the error without calling setError, because we do not want to
+    // invoke the subclass's handleErrorImpl, as it would find itself in a
+    // weird state (since initImplFromLoop wouldn't have been called).
+    error_ = TP_CREATE_ERROR(ConnectionClosedError);
+    TP_VLOG(7) << "Connection " << id_ << " is closing (without initing)";
+    return;
+  }
   initImplFromLoop();
 }
diff --git a/tensorpipe/transport/context_impl_boilerplate.h b/tensorpipe/transport/context_impl_boilerplate.h
index 7d42cdcec..16dc5025a 100644
--- a/tensorpipe/transport/context_impl_boilerplate.h
+++ b/tensorpipe/transport/context_impl_boilerplate.h
@@ -9,12 +9,12 @@
 #pragma once
 
 #include <atomic>
+#include <future>
 #include <memory>
 #include <string>
 #include <unordered_map>
 #include <utility>
 
-#include <tensorpipe/common/callback.h>
 #include <tensorpipe/common/defs.h>
 #include <tensorpipe/common/deferred_executor.h>
 #include <tensorpipe/common/error.h>
@@ -51,7 +51,9 @@ class ContextImplBoilerplate : public virtual DeferredExecutor,
   void unenroll(TList& listener);
   void unenroll(TConn& connection);
 
-  ClosingEmitter& getClosingEmitter();
+  // Return whether the context is in a closed state. To avoid race conditions,
+  // this must be called from within the loop.
+  bool closed();
 
   void setId(std::string id);
 
@@ -73,7 +75,6 @@ class ContextImplBoilerplate : public virtual DeferredExecutor,
  private:
   std::atomic<bool> closed_{false};
   std::atomic<bool> joined_{false};
-  ClosingEmitter closingEmitter_;
 
   const std::string domainDescriptor_;
 
@@ -162,9 +163,9 @@ void ContextImplBoilerplate<TCtx, TList, TConn>::unenroll(TConn& connection) {
 }
 
 template <typename TCtx, typename TList, typename TConn>
-ClosingEmitter& ContextImplBoilerplate<TCtx, TList, TConn>::
-    getClosingEmitter() {
-  return closingEmitter_;
+bool ContextImplBoilerplate<TCtx, TList, TConn>::closed() {
+  TP_DCHECK(inLoop());
+  return closed_;
 };
 
 template <typename TCtx, typename TList, typename TConn>
@@ -175,14 +176,32 @@ void ContextImplBoilerplate<TCtx, TList, TConn>::setId(std::string id) {
 
 template <typename TCtx, typename TList, typename TConn>
 void ContextImplBoilerplate<TCtx, TList, TConn>::close() {
-  if (!closed_.exchange(true)) {
-    TP_VLOG(7) << "Transport context " << id_ << " is closing";
-
-    closingEmitter_.close();
-    closeImpl();
-
-    TP_VLOG(7) << "Transport context " << id_ << " done closing";
-  }
+  // Defer this to the loop so that it won't race with other code accessing it
+  // (in other words: any code in the loop can assume that this won't change).
+  deferToLoop([this]() {
+    if (!closed_.exchange(true)) {
+      TP_VLOG(7) << "Transport context " << id_ << " is closing";
+
+      // Make a copy, as they could unenroll themselves inline.
+      decltype(listeners_) listenersCopy = listeners_;
+      decltype(connections_) connectionsCopy = connections_;
+      // We call closeFromLoop, rather than just close, because we need these
+      // objects to transition _immediately_ to error, "atomically". If we
+      // merely deferred their closing, it could be processed after already-
+      // enqueued operations that try to access the context, which by then
+      // would be closed, and those operations could fail.
+      for (auto& iter : listenersCopy) {
+        iter.second->closeFromLoop();
+      }
+      for (auto& iter : connectionsCopy) {
+        iter.second->closeFromLoop();
+      }
+
+      closeImpl();
+
+      TP_VLOG(7) << "Transport context " << id_ << " done closing";
+    }
+  });
 }
 
 template <typename TCtx, typename TList, typename TConn>
@@ -192,6 +211,14 @@ void ContextImplBoilerplate<TCtx, TList, TConn>::join() {
   if (!joined_.exchange(true)) {
     TP_VLOG(7) << "Transport context " << id_ << " is joining";
 
+    // As closing is deferred to the loop, we must wait for closeImpl to
+    // actually have been called before we call joinImpl, to avoid race
+    // conditions. For this, we defer another task to the loop, which we know
+    // will run after the closing, and then wait for that task to be run.
+    std::promise<void> hasClosed;
+    deferToLoop([&]() { hasClosed.set_value(); });
+    hasClosed.get_future().wait();
+
     joinImpl();
 
     TP_VLOG(7) << "Transport context " << id_ << " done joining";
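Note: the barrier added to join() relies only on the FIFO ordering of deferred functions: a marker task enqueued after the (deferred) close can only run once the closing logic has finished. Here is a standalone sketch, not tensorpipe code, in which a plain queue and thread stand in for the loop:

```cpp
#include <cassert>
#include <functional>
#include <future>
#include <queue>
#include <thread>

int main() {
  // Stands in for the loop's queue of deferred functions (FIFO).
  std::queue<std::function<void()>> tasks;

  bool closed = false;
  tasks.push([&]() { closed = true; }); // the deferred close()

  // The marker task from join(): enqueued after the close, so it can only
  // run once the closing logic has finished.
  std::promise<void> hasClosed;
  tasks.push([&]() { hasClosed.set_value(); });

  // Stands in for the event loop thread, draining the queue in order.
  std::thread loop([&]() {
    while (!tasks.empty()) {
      tasks.front()();
      tasks.pop();
    }
  });

  hasClosed.get_future().wait();
  assert(closed); // guaranteed: FIFO order plus promise synchronization

  loop.join();
}
```

The future's wait() synchronizes with set_value() on the loop thread, so once it returns, everything the loop ran before the marker (including the close) is visible to the caller.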
diff --git a/tensorpipe/transport/ibv/listener_impl.cc b/tensorpipe/transport/ibv/listener_impl.cc
index 50aa92a7c..4558d7321 100644
--- a/tensorpipe/transport/ibv/listener_impl.cc
+++ b/tensorpipe/transport/ibv/listener_impl.cc
@@ -146,7 +146,7 @@ void ListenerImpl::handleEventsFromLoop(int events) {
   if (fns_.empty()) {
     context_->unregisterDescriptor(socket_.fd());
   }
-  fn(Error::kSuccess, createConnection(std::move(socket)));
+  fn(Error::kSuccess, createAndInitConnection(std::move(socket)));
 }
 
 } // namespace ibv
diff --git a/tensorpipe/transport/listener_impl_boilerplate.h b/tensorpipe/transport/listener_impl_boilerplate.h
index a169df956..efb13a27b 100644
--- a/tensorpipe/transport/listener_impl_boilerplate.h
+++ b/tensorpipe/transport/listener_impl_boilerplate.h
@@ -78,7 +78,7 @@ class ListenerImplBoilerplate : public std::enable_shared_from_this<TList> {
   Error error_{Error::kSuccess};
 
   template <typename... Args>
-  std::shared_ptr<Connection> createConnection(Args&&... args);
+  std::shared_ptr<Connection> createAndInitConnection(Args&&... args);
 
   // An identifier for the listener, composed of the identifier for the context,
   // combined with an increasing sequence number. It will be used as a prefix
@@ -104,8 +104,6 @@ class ListenerImplBoilerplate : public std::enable_shared_from_this<TList> {
   // Deal with an error.
   void handleError();
 
-  ClosingReceiver closingReceiver_;
-
   // A sequence number for the calls to accept.
   uint64_t nextConnectionBeingAccepted_{0};
 
@@ -113,6 +111,12 @@ class ListenerImplBoilerplate : public std::enable_shared_from_this<TList> {
   // create their identifiers based off this listener's identifier. They will
   // only be used for logging and debugging.
   std::atomic<uint64_t> connectionCounter_{0};
+
+  // Contexts do sometimes need to call directly into closeFromLoop, in order
+  // to make sure that some of their operations can happen "atomically" on
+  // the listener, without other operations possibly occurring in between
+  // (e.g., an error).
+  friend ContextImplBoilerplate<TCtx, TList, TConn>;
 };
 
 template <typename TCtx, typename TList, typename TConn>
@@ -120,9 +124,7 @@ ListenerImplBoilerplate<TCtx, TList, TConn>::ListenerImplBoilerplate(
     ConstructorToken /* unused */,
     std::shared_ptr<TCtx> context,
     std::string id)
-    : context_(std::move(context)),
-      id_(std::move(id)),
-      closingReceiver_(context_, context_->getClosingEmitter()) {}
+    : context_(std::move(context)), id_(std::move(id)) {}
 
 template <typename TCtx, typename TList, typename TConn>
 void ListenerImplBoilerplate<TCtx, TList, TConn>::init() {
@@ -132,7 +134,14 @@ void ListenerImplBoilerplate<TCtx, TList, TConn>::initFromLoop() {
-  closingReceiver_.activate(*this);
+  if (context_->closed()) {
+    // Set the error without calling setError, because we do not want to
+    // invoke the subclass's handleErrorImpl, as it would find itself in a
+    // weird state (since initImplFromLoop wouldn't have been called).
+    error_ = TP_CREATE_ERROR(ListenerClosedError);
+    TP_VLOG(7) << "Listener " << id_ << " is closing (without initing)";
+    return;
+  }
   initImplFromLoop();
 }
@@ -189,15 +198,24 @@ std::string ListenerImplBoilerplate<TCtx, TList, TConn>::addrFromLoop() const {
 
 template <typename TCtx, typename TList, typename TConn>
 template <typename... Args>
 std::shared_ptr<Connection> ListenerImplBoilerplate<TCtx, TList, TConn>::
-    createConnection(Args&&... args) {
+    createAndInitConnection(Args&&... args) {
+  TP_DCHECK(context_->inLoop());
   std::string connectionId = id_ + ".c" + std::to_string(connectionCounter_++);
   TP_VLOG(7) << "Listener " << id_ << " is opening connection " << connectionId;
-  return std::make_shared<ConnectionBoilerplate<TCtx, TList, TConn>>(
+  auto connection = std::make_shared<TConn>(
       typename ConnectionImplBoilerplate<TCtx, TList, TConn>::
           ConstructorToken(),
       context_,
       std::move(connectionId),
       std::forward<Args>(args)...);
+  // We initialize the connection from the loop immediately, inline, because
+  // the initialization of a connection accepted by a listener typically
+  // happens partly in the listener (e.g., opening and accepting the socket)
+  // and partly in the connection's initFromLoop, and we need these two steps
+  // to happen "atomically", so that no error can occur in between.
+  connection->initFromLoop();
+  return std::make_shared<ConnectionBoilerplate<TCtx, TList, TConn>>(
+      std::move(connection));
 }
 
 template <typename TCtx, typename TList, typename TConn>
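Note: the guard added to both initFromLoop() methods (connection and listener) can be shown in isolation. In this standalone sketch — Endpoint and initialized_ are hypothetical names, not tensorpipe code — the error is recorded directly because the error path assumes state that only a completed init would have created:

```cpp
#include <cassert>
#include <string>

class Endpoint {
 public:
  void initFromLoop(bool contextClosed) {
    if (contextClosed) {
      // Record the error without invoking the error path: nothing has been
      // initialized yet, so there is nothing for handleError() to undo.
      error_ = "closed";
      return;
    }
    initialized_ = true;
  }

  void handleError() {
    // The error path assumes init has run; calling it on an uninitialized
    // object would put it "in a weird state", as the comment above puts it.
    assert(initialized_);
    // ... tear down resources created during init ...
  }

 private:
  bool initialized_ = false;
  std::string error_;
};

int main() {
  Endpoint e;
  e.initFromLoop(/*contextClosed=*/true); // error recorded, no error path run
}
```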
diff --git a/tensorpipe/transport/shm/listener_impl.cc b/tensorpipe/transport/shm/listener_impl.cc
index 953d8b217..126d9b316 100644
--- a/tensorpipe/transport/shm/listener_impl.cc
+++ b/tensorpipe/transport/shm/listener_impl.cc
@@ -136,7 +136,7 @@ void ListenerImpl::handleEventsFromLoop(int events) {
   if (fns_.empty()) {
     context_->unregisterDescriptor(socket_.fd());
   }
-  fn(Error::kSuccess, createConnection(std::move(socket)));
+  fn(Error::kSuccess, createAndInitConnection(std::move(socket)));
 }
 
 } // namespace shm
diff --git a/tensorpipe/transport/uv/connection_impl.cc b/tensorpipe/transport/uv/connection_impl.cc
index 21366d181..12a25b161 100644
--- a/tensorpipe/transport/uv/connection_impl.cc
+++ b/tensorpipe/transport/uv/connection_impl.cc
@@ -52,6 +52,8 @@ ConnectionImpl::ConnectionImpl(
 void ConnectionImpl::initImplFromLoop() {
   context_->enroll(*this);
 
+  TP_VLOG(9) << "Connection " << id_ << " is initializing in loop";
+
   if (sockaddr_.has_value()) {
     handle_->initFromLoop();
     handle_->connectFromLoop(sockaddr_.value(), [this](int status) {
diff --git a/tensorpipe/transport/uv/listener_impl.cc b/tensorpipe/transport/uv/listener_impl.cc
index f052f1d33..b9ed1f448 100644
--- a/tensorpipe/transport/uv/listener_impl.cc
+++ b/tensorpipe/transport/uv/listener_impl.cc
@@ -36,6 +36,8 @@ ListenerImpl::ListenerImpl(
 void ListenerImpl::initImplFromLoop() {
   context_->enroll(*this);
 
+  TP_VLOG(9) << "Listener " << id_ << " is initializing in loop";
+
   handle_->initFromLoop();
   auto rv = handle_->bindFromLoop(sockaddr_);
   TP_THROW_UV_IF(rv < 0, rv);
@@ -67,7 +69,8 @@ void ListenerImpl::connectionCallbackFromLoop(int status) {
   auto connection = context_->createHandle();
   connection->initFromLoop();
   handle_->acceptFromLoop(*connection);
-  callback_.trigger(Error::kSuccess, createConnection(std::move(connection)));
+  callback_.trigger(
+      Error::kSuccess, createAndInitConnection(std::move(connection)));
 }
 
 void ListenerImpl::closeCallbackFromLoop() {
diff --git a/tensorpipe/transport/uv/loop.cc b/tensorpipe/transport/uv/loop.cc
index 52152ce99..2c59c0026 100644
--- a/tensorpipe/transport/uv/loop.cc
+++ b/tensorpipe/transport/uv/loop.cc
@@ -58,6 +58,10 @@ void Loop::eventLoop() {
   rv = uv_run(&loop_, UV_RUN_DEFAULT);
   TP_THROW_ASSERT_IF(rv > 0)
       << ": uv_run returned with active handles or requests";
+}
+
+void Loop::cleanUpLoop() {
+  int rv;
 
   uv_ref(reinterpret_cast<uv_handle_t*>(&async_));
   uv_close(reinterpret_cast<uv_handle_t*>(&async_), nullptr);
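Note: moving the uv_close of the wakeup handle into cleanUpLoop() matches how libuv tears handles down: uv_close() only requests closure, the close completes on a later uv_run() pass, and uv_loop_close() fails while any handle is still open. A standalone sketch of that sequence (assuming libuv is available; this is not tensorpipe code):

```cpp
#include <cassert>

#include <uv.h>

int main() {
  uv_loop_t loop;
  int rv = uv_loop_init(&loop);
  assert(rv == 0);

  uv_async_t async;
  rv = uv_async_init(&loop, &async, /*async_cb=*/nullptr);
  assert(rv == 0);

  // An unref'd handle doesn't keep the loop alive, so with no other work this
  // run returns immediately (tensorpipe likewise unrefs its wakeup handle so
  // that the main uv_run in eventLoop() can terminate).
  uv_unref(reinterpret_cast<uv_handle_t*>(&async));
  uv_run(&loop, UV_RUN_DEFAULT);

  // The cleanup phase, mirroring what moved into Loop::cleanUpLoop(): re-ref
  // the handle, request its closure, and spin the loop once more so that the
  // close actually completes.
  uv_ref(reinterpret_cast<uv_handle_t*>(&async));
  uv_close(reinterpret_cast<uv_handle_t*>(&async), /*close_cb=*/nullptr);
  uv_run(&loop, UV_RUN_DEFAULT);

  // Only now does closing the loop itself succeed; with handles still open it
  // would return UV_EBUSY.
  rv = uv_loop_close(&loop);
  assert(rv == 0);
  (void)rv;
}
```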
diff --git a/tensorpipe/transport/uv/loop.h b/tensorpipe/transport/uv/loop.h
index 2383f122d..a6e31f8bc 100644
--- a/tensorpipe/transport/uv/loop.h
+++ b/tensorpipe/transport/uv/loop.h
@@ -32,6 +32,10 @@ class Loop final : public EventLoopDeferredExecutor {
     return &loop_;
   }
 
+  bool closed() {
+    return closed_;
+  }
+
   void close();
 
   void join();
@@ -42,6 +46,9 @@ class Loop final : public EventLoopDeferredExecutor {
   // Event loop thread entry function.
   void eventLoop() override;
 
+  // Clean up after the event loop has transitioned to on-demand mode.
+  void cleanUpLoop() override;
+
   // Wake up the event loop.
   void wakeupEventLoopToDeferFunction() override;
diff --git a/tensorpipe/transport/uv/uv.h b/tensorpipe/transport/uv/uv.h
index 01771e0a2..41ba616f7 100644
--- a/tensorpipe/transport/uv/uv.h
+++ b/tensorpipe/transport/uv/uv.h
@@ -261,6 +261,7 @@ class TCPHandle : public StreamHandle<TCPHandle, uv_tcp_t> {
   void initFromLoop() {
     TP_DCHECK(this->loop_.inLoop());
+    TP_THROW_ASSERT_IF(loop_.closed());
     int rv;
     rv = uv_tcp_init(loop_.ptr(), this->ptr());
     TP_THROW_UV_IF(rv < 0, rv);
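Note: the new TP_THROW_ASSERT_IF(loop_.closed()) guard turns late handle creation — e.g., a deferred task racing with context closure — into a loud failure instead of a handle that would outlive the loop's cleanup pass. A standalone sketch of the same idea (MiniLoop is hypothetical, not tensorpipe code):

```cpp
#include <cassert>
#include <functional>
#include <vector>

// Hypothetical stand-in for the uv Loop: a closed flag plus deferred tasks.
struct MiniLoop {
  bool closed = false;
  std::vector<std::function<void()>> pending;

  void defer(std::function<void()> fn) {
    pending.push_back(std::move(fn));
  }

  void initHandle() {
    // Mirrors TP_THROW_ASSERT_IF(loop_.closed()): fail loudly rather than
    // create a handle on a loop that is already tearing down.
    assert(!closed && "must not create handles on a closed loop");
    // ... uv_tcp_init would go here ...
  }
};

int main() {
  MiniLoop loop;
  // A task enqueued before the close...
  loop.defer([&loop]() { loop.initHandle(); });
  // ...but the loop is closed before the task gets to run. Running the task
  // now would trip the assert instead of silently registering a handle that
  // the cleanup pass would never close:
  loop.closed = true;
  // loop.pending.front()();  // intentionally left commented
}
```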