diff --git a/tensorpipe/transport/connection_impl_boilerplate.h b/tensorpipe/transport/connection_impl_boilerplate.h index 2fb8d60f1..ec21b6075 100644 --- a/tensorpipe/transport/connection_impl_boilerplate.h +++ b/tensorpipe/transport/connection_impl_boilerplate.h @@ -126,8 +126,6 @@ class ConnectionImplBoilerplate : public std::enable_shared_from_this { // Deal with an error. void handleError(); - ClosingReceiver closingReceiver_; - // A sequence number for the calls to read and write. uint64_t nextBufferBeingRead_{0}; uint64_t nextBufferBeingWritten_{0}; @@ -145,9 +143,7 @@ ConnectionImplBoilerplate::ConnectionImplBoilerplate( ConstructorToken /* unused */, std::shared_ptr context, std::string id) - : context_(std::move(context)), - id_(std::move(id)), - closingReceiver_(context_, context_->getClosingEmitter()) {} + : context_(std::move(context)), id_(std::move(id)) {} template void ConnectionImplBoilerplate::init() { @@ -166,8 +162,6 @@ void ConnectionImplBoilerplate::initFromLoop() { return; } - closingReceiver_.activate(*this); - initImplFromLoop(); } diff --git a/tensorpipe/transport/context_impl_boilerplate.h b/tensorpipe/transport/context_impl_boilerplate.h index fd3571fc5..1d9e421f7 100644 --- a/tensorpipe/transport/context_impl_boilerplate.h +++ b/tensorpipe/transport/context_impl_boilerplate.h @@ -14,7 +14,6 @@ #include #include -#include #include #include #include @@ -55,8 +54,6 @@ class ContextImplBoilerplate : public virtual DeferredExecutor, // this must be called from within the loop. bool closed(); - ClosingEmitter& getClosingEmitter(); - void setId(std::string id); void close(); @@ -77,7 +74,6 @@ class ContextImplBoilerplate : public virtual DeferredExecutor, private: std::atomic closed_{false}; std::atomic joined_{false}; - ClosingEmitter closingEmitter_; const std::string domainDescriptor_; @@ -171,12 +167,6 @@ bool ContextImplBoilerplate::closed() { return closed_; }; -template -ClosingEmitter& ContextImplBoilerplate:: - getClosingEmitter() { - return closingEmitter_; -}; - template void ContextImplBoilerplate::setId(std::string id) { TP_VLOG(7) << "Transport context " << id_ << " was renamed to " << id; @@ -191,7 +181,21 @@ void ContextImplBoilerplate::close() { if (!closed_.exchange(true)) { TP_VLOG(7) << "Transport context " << id_ << " is closing"; - closingEmitter_.close(); + // Make a copy as they could unenroll themselves inline. + decltype(listeners_) listenersCopy = listeners_; + decltype(connections_) connectionsCopy = connections_; + // We call closeFromLoop, rather than just close, because we need these + // objects to transition _immediately_ to error, "atomically". If we just + // deferred closing to later, this could come after some already-enqueued + // operations that could try to access the context, which would be closed, + // and this could fail. + for (auto& iter : listenersCopy) { + iter.second->closeFromLoop(); + } + for (auto& iter : connectionsCopy) { + iter.second->closeFromLoop(); + } + closeImpl(); TP_VLOG(7) << "Transport context " << id_ << " done closing"; diff --git a/tensorpipe/transport/listener_impl_boilerplate.h b/tensorpipe/transport/listener_impl_boilerplate.h index 94790a543..efb13a27b 100644 --- a/tensorpipe/transport/listener_impl_boilerplate.h +++ b/tensorpipe/transport/listener_impl_boilerplate.h @@ -104,8 +104,6 @@ class ListenerImplBoilerplate : public std::enable_shared_from_this { // Deal with an error. void handleError(); - ClosingReceiver closingReceiver_; - // A sequence number for the calls to accept. uint64_t nextConnectionBeingAccepted_{0}; @@ -113,6 +111,12 @@ class ListenerImplBoilerplate : public std::enable_shared_from_this { // create their identifiers based off this listener's identifier. They will // only be used for logging and debugging. std::atomic connectionCounter_{0}; + + // Contexts do sometimes need to call directly into closeForLoop, in order to + // make sure that some of their operations can happen "atomically" on the + // connection, without possibly other operations occurring in between (e.g., + // an error). + friend ContextImplBoilerplate; }; template @@ -120,9 +124,7 @@ ListenerImplBoilerplate::ListenerImplBoilerplate( ConstructorToken /* unused */, std::shared_ptr context, std::string id) - : context_(std::move(context)), - id_(std::move(id)), - closingReceiver_(context_, context_->getClosingEmitter()) {} + : context_(std::move(context)), id_(std::move(id)) {} template void ListenerImplBoilerplate::init() { @@ -141,8 +143,6 @@ void ListenerImplBoilerplate::initFromLoop() { return; } - closingReceiver_.activate(*this); - initImplFromLoop(); }