diff --git a/include/unifex/at_coroutine_exit.hpp b/include/unifex/at_coroutine_exit.hpp index 775fc0b9..ae9f0339 100644 --- a/include/unifex/at_coroutine_exit.hpp +++ b/include/unifex/at_coroutine_exit.hpp @@ -27,6 +27,7 @@ #include #include #include +#include #if UNIFEX_NO_COROUTINES # error "Coroutine support is required to use this header" @@ -57,10 +58,25 @@ inline constexpr struct _fn { } // namespace _xchg_cont using _xchg_cont::exchange_continuation; +template struct _cleanup_promise_base { struct final_awaitable { bool await_ready() const noexcept { return false; } + template + coro::coroutine_handle<> await_suspend_impl( + coro::coroutine_handle h) const noexcept { + if constexpr (WithAsyncStackSupport) { + if (h.promise().parentFrame_ != nullptr) { + popAsyncStackFrameCallee(h.promise().frame_); + } + } + + auto continuation = h.promise().next(); + h.destroy(); // The cleanup action has finished executing. Destroy it. + return continuation; + } + // Clang before clang-12 has a bug with coroutines that self-destruct in an // await_suspend that uses symmetric transfer. It appears that MSVC has the // same bug, while Emscripten, the WebAssembly compiler just doesn't support @@ -81,18 +97,14 @@ struct _cleanup_promise_base { # endif void await_suspend(coro::coroutine_handle h) const noexcept { - auto continuation = h.promise().next(); - h.destroy(); // The cleanup action has finished executing. Destroy it. - continuation.resume(); + await_suspend_impl(h).resume(); } #else // No bugs here! OK to use symmetric transfer. template coro::coroutine_handle<> await_suspend(coro::coroutine_handle h) const noexcept { - auto continuation = h.promise().next(); - h.destroy(); // The cleanup action has finished executing. Destroy it. - return continuation; + return await_suspend_impl(h); } #endif @@ -135,10 +147,24 @@ struct _cleanup_promise_base { return p.sched_; } + template(typename Promise) // + (requires WithAsyncStackSupport AND + convertible_to) // + friend constexpr AsyncStackFrame* tag_invoke( + tag_t, const Promise& p) noexcept { + return &p.frame_; + } + inline static constexpr inline_scheduler _default_scheduler{}; continuation_handle<> continuation_{}; any_scheduler sched_{_default_scheduler}; bool isUnhandledDone_{false}; + UNIFEX_NO_UNIQUE_ADDRESS mutable std:: + conditional_t> + parentFrame_{}; + UNIFEX_NO_UNIQUE_ADDRESS mutable std:: + conditional_t> + frame_; }; // The die_on_done algorithm implemented here could be implemented in terms of @@ -233,16 +259,16 @@ struct _die_on_done_fn { } }; -template +template struct _cleanup_task; -template -struct _cleanup_promise : _cleanup_promise_base { +template +struct _cleanup_promise : _cleanup_promise_base { template explicit _cleanup_promise(Action&&, Ts&... ts) : args_(ts...) {} - _cleanup_task get_return_object() noexcept { - return _cleanup_task( + _cleanup_task get_return_object() noexcept { + return _cleanup_task( coro::coroutine_handle<_cleanup_promise>::from_promise(*this)); } @@ -253,6 +279,12 @@ struct _cleanup_promise : _cleanup_promise_base { template decltype(auto) await_transform(Value&& value) noexcept(noexcept( unifex::await_transform(*this, _die_on_done_fn{}((Value&&)value)))) { + if constexpr (WithAsyncStackSupport) { + if (this->parentFrame_ != nullptr) { + pushAsyncStackFrameCallerCallee(*this->parentFrame_, this->frame_); + } + } + return unifex::await_transform(*this, _die_on_done_fn{}((Value&&)value)); } @@ -261,15 +293,15 @@ struct _cleanup_promise : _cleanup_promise_base { // Record that we are processing an unhandled done signal. This is checked // in the final_suspend of the cleanup action to know which subsequent // continuation to resume. - isUnhandledDone_ = true; + this->isUnhandledDone_ = true; // On unhandled_done, run the cleanup action: return coro::coroutine_handle<_cleanup_promise>::from_promise(*this); }); }; -template +template struct [[nodiscard]] _cleanup_task { - using promise_type = _cleanup_promise; + using promise_type = _cleanup_promise; explicit _cleanup_task(coro::coroutine_handle coro) noexcept : continuation_(coro) {} @@ -279,29 +311,46 @@ struct [[nodiscard]] _cleanup_task { ~_cleanup_task() { UNIFEX_ASSERT(!continuation_); } - bool await_ready() const noexcept { return false; } + struct awaiter { + bool await_ready() const noexcept { return false; } - template - bool await_suspend_impl_(Promise& parent) noexcept { - continuation_.promise().continuation_ = - exchange_continuation(parent, continuation_); - continuation_.promise().sched_ = get_scheduler(parent); - return false; - } + template + bool await_suspend_impl_( + Promise& parent, + [[maybe_unused]] instruction_ptr returnAddress = + instruction_ptr::read_return_address()) noexcept { + continuation_.promise().continuation_ = + exchange_continuation(parent, continuation_); + continuation_.promise().sched_ = get_scheduler(parent); + if constexpr (WithAsyncStackSupport) { + continuation_.promise().parentFrame_ = get_async_stack_frame(parent); + continuation_.promise().frame_.setReturnAddress(returnAddress); + } + return false; + } - template - bool await_suspend(coro::coroutine_handle parent) noexcept { - return await_suspend_impl_(parent.promise()); - } + template + UNIFEX_NO_INLINE bool + await_suspend(coro::coroutine_handle parent) noexcept { + return await_suspend_impl_(parent.promise()); + } - std::tuple await_resume() noexcept { - return std::move(std::exchange(continuation_, {}).promise().args_); - } + std::tuple await_resume() noexcept { + return std::move(std::exchange(continuation_, {}).promise().args_); + } + + // TODO: how do we address always-inline awaitables + friend constexpr auto tag_invoke(tag_t, const awaiter&) noexcept { + return blocking_kind::always_inline; + } - // TODO: how do we address always-inline awaitables - friend constexpr auto - tag_invoke(tag_t, const _cleanup_task&) noexcept { - return blocking_kind::always_inline; + continuation_handle continuation_; + }; + + template + friend awaiter + tag_invoke(tag_t, Promise&, _cleanup_task task) noexcept { + return awaiter{std::exchange(task.continuation_, {})}; } private: @@ -311,18 +360,23 @@ struct [[nodiscard]] _cleanup_task { namespace _at_coroutine_exit { inline constexpr struct _fn { private: - template - static _cleanup_task at_coroutine_exit(Action action, Ts... ts) { + template + static _cleanup_task + at_coroutine_exit(Action action, Ts... ts) { co_await std::move(action)(std::move(ts)...); } public: - template(typename Action, typename... Ts) // + template( + typename Action, + typename... Ts, + bool WithAsyncStackSupport = !UNIFEX_NO_ASYNC_STACKS) // (requires std:: is_invocable_v, std::decay_t...>) // - _cleanup_task...> + _cleanup_task...> operator()(Action&& action, Ts&&... ts) const { - return _fn::at_coroutine_exit((Action&&)action, (Ts&&)ts...); + return _fn::at_coroutine_exit( + (Action&&)action, (Ts&&)ts...); } } at_coroutine_exit{}; } // namespace _at_coroutine_exit diff --git a/include/unifex/await_transform.hpp b/include/unifex/await_transform.hpp index 10530968..e0d08ed2 100644 --- a/include/unifex/await_transform.hpp +++ b/include/unifex/await_transform.hpp @@ -68,18 +68,18 @@ struct _expected { namespace _await_tfx { using namespace _util; -template +template struct _awaitable_base { struct type; }; -template +template struct _awaitable { struct type; }; -template -struct _awaitable_base::type { +template +struct _awaitable_base::type { struct _rec { public: explicit _rec( @@ -92,6 +92,19 @@ struct _awaitable_base::type { : result_(std::exchange(r.result_, nullptr)) , continuation_(std::move(r.continuation_)) {} + void complete() noexcept { + if constexpr (WithAsyncStackSupport) { + if (auto* frame = get_async_stack_frame(continuation_.promise())) { + detail::ScopedAsyncStackRoot root; + root.activateFrame(*frame); + return continuation_.resume(); + } + } + + // run this when stacks are disabled and when the parent hasn't got one + continuation_.resume(); + } + template(class... Us) // (requires( constructible_from || @@ -101,13 +114,13 @@ struct _awaitable_base::type { std::is_void_v) { unifex::activate_union_member(result_->value_, (Us&&)us...); result_->state_ = _state::value; - continuation_.resume(); + complete(); } void set_error(std::exception_ptr eptr) && noexcept { unifex::activate_union_member(result_->exception_, std::move(eptr)); result_->state_ = _state::exception; - continuation_.resume(); + complete(); } void set_error(std::error_code code) && noexcept { @@ -117,6 +130,23 @@ struct _awaitable_base::type { void set_done() && noexcept { result_->state_ = _state::done; + + if constexpr (WithAsyncStackSupport) { + if (auto* parentFrame = + get_async_stack_frame(continuation_.promise())) { + // we need a dummy frame for the waiting coroutine's unhandled_done() + // to pop for us + AsyncStackFrame frame; + frame.setParentFrame(*parentFrame); + + detail::ScopedAsyncStackRoot root; + root.activateFrame(frame); + + return continuation_.resume_done(); + } + } + + // run this when stacks are disabled and when the parent hasn't got one continuation_.resume_done(); } @@ -158,18 +188,21 @@ struct _awaitable_base::type { _expected result_; }; -template +template using _awaitable_base_t = typename _awaitable_base< Promise, - sender_single_value_return_type_t>>::type; + sender_single_value_return_type_t>, + WithAsyncStackSupport>::type; -template -using _receiver_t = typename _awaitable_base_t::_rec; +template +using _receiver_t = + typename _awaitable_base_t::_rec; -template -struct _awaitable::type : _awaitable_base_t { +template +struct _awaitable::type + : _awaitable_base_t { private: - using _rec = _receiver_t; + using _rec = _receiver_t; connect_result_t op_; public: @@ -177,13 +210,288 @@ struct _awaitable::type : _awaitable_base_t { is_nothrow_connectable_v) : op_(unifex::connect((Sender&&)sender, _rec{&this->result_, h})) {} - void await_suspend(coro::coroutine_handle) noexcept { + void await_suspend(coro::coroutine_handle handle) noexcept { + if constexpr (WithAsyncStackSupport) { + auto* frame = get_async_stack_frame(handle.promise()); + if (frame) { + deactivateAsyncStackFrame((*frame)); + } + } unifex::start(op_); } }; -template -using _as_awaitable = typename _awaitable::type; +template +using _as_awaitable = + typename _awaitable::type; + +template +struct is_resumer_promise : std::false_type {}; + +template +struct is_resumer_promise : std::true_type {}; + +template +constexpr bool is_resumer_promise_v = is_resumer_promise::value; + +template +struct _coro_resumer final { + struct type; +}; + +template +struct _coro_resumer::type final { + struct promise_type { + using resumer_promise_t = void; + + static_assert(!is_resumer_promise_v); + + promise_type(coro::coroutine_handle& h) noexcept : handle_(h) {} + + type get_return_object() noexcept { + return type{coro::coroutine_handle::from_promise(*this)}; + } + + coro::suspend_always initial_suspend() noexcept { return {}; } + + [[noreturn]] coro::suspend_always final_suspend() noexcept { + std::terminate(); + } + + // TODO: unhandled_done()? + + [[noreturn]] void return_void() noexcept { std::terminate(); } + + [[noreturn]] void unhandled_exception() noexcept { std::terminate(); } + + struct awaiter { + coro::coroutine_handle h; + + bool await_ready() noexcept { return false; } + + void await_suspend(coro::coroutine_handle<>) noexcept { + auto* frame = get_async_stack_frame(h.promise()); + if (frame) { + detail::ScopedAsyncStackRoot root; + root.activateFrame(*frame); + + h.resume(); + + root.ensureFrameDeactivated(frame); + } else { + h.resume(); + } + } + + [[noreturn]] void await_resume() noexcept { std::terminate(); } + }; + + awaiter await_transform(coro::coroutine_handle h) noexcept { + return awaiter{h}; + } + + template(typename CPO) // + (requires is_receiver_query_cpo_v) // + friend auto tag_invoke(CPO cpo, const promise_type& self) noexcept( + is_nothrow_tag_invocable_v) + -> tag_invoke_result_t { + return tag_invoke(std::move(cpo), std::as_const(self.handle_.promise())); + } + + continuation_handle handle_; + }; + + type() noexcept = default; + + type(type&& other) noexcept : h_(std::exchange(other.h_, {})) {} + + ~type() { + if (h_) { + h_.destroy(); + } + } + + type& operator=(type rhs) noexcept { + std::swap(h_, rhs.h_); + return *this; + } + + coro::coroutine_handle handle() && noexcept { + return std::exchange(h_, {}); + } + +private: + explicit type(coro::coroutine_handle h) noexcept : h_(h) {} + + coro::coroutine_handle h_; +}; + +template +using coro_resumer = typename _coro_resumer::type; + +template +coro_resumer +resume_with_stack_root(coro::coroutine_handle h) { + co_await h; +} + +template +struct _awaitable_wrapper final { + class type; +}; + +template +class _awaitable_wrapper::type final { + using awaiter_t = awaiter_type_t; + + Awaitable&& awaitable_; + awaiter_t awaiter_; + coro::coroutine_handle<> coro_; + +public: + using awaitable_wrapper_t = void; + + explicit type(Awaitable&& awaitable) + : awaitable_(std::forward(awaitable)) + , awaiter_(get_awaiter(std::forward(awaitable))) {} + + type(type&& other) noexcept(std::is_nothrow_move_constructible_v) + : awaitable_(std::move(other.awaitable_)) + , awaiter_(std::move(other.awaiter_)) + , coro_(std::exchange(other.coro_, {})) { + // we should only be move-constructed before being awaited + UNIFEX_ASSERT(!coro_); + } + + ~type() { + if (coro_) { + coro_.destroy(); + } + } + + bool await_ready() noexcept(noexcept(awaiter_.await_ready())) { + return awaiter_.await_ready(); + } + + template + using resume_coro_handle_t = + coro::coroutine_handle::promise_type>; + + template + using _suspend_result_t = decltype(awaiter_.await_suspend( + resume_coro_handle_t::from_address(nullptr))); + + template + using suspend_result_t = std::conditional_t< + convertible_to<_suspend_result_t, coro::coroutine_handle<>>, + coro::coroutine_handle<>, + _suspend_result_t>; + + template(typename Promise) // + (requires same_as>) // + bool await_suspend_impl( + coro::coroutine_handle h, AsyncStackFrame* frame) { + auto* root = frame->getStackRoot(); + + auto resumer = resume_with_stack_root(h).handle(); + + // save for later destruction + coro_ = resumer; + + // ensure that it's safe for the resumer coroutine to activate h's stack + // frame on resumption + deactivateAsyncStackFrame(*frame); + + if (awaiter_.await_suspend(resumer)) { + // suspend + return true; + } else { + // we're not actually suspending so undo the stack manipulation we just + // did + activateAsyncStackFrame(*root, *frame); + + // proactively destroy the unneeded coro_resumer + std::exchange(coro_, {}).destroy(); + + // resume the caller + return false; + } + } + + template(typename Promise) // + (requires(!same_as>)) // + suspend_result_t await_suspend_impl( + coro::coroutine_handle h, AsyncStackFrame* frame) { + auto resumer = resume_with_stack_root(h).handle(); + + // save for later destruction + coro_ = resumer; + + // ensure that it's safe for the resumer coroutine to activate h's stack + // frame on resumption + deactivateAsyncStackFrame(*frame); + + return awaiter_.await_suspend(resumer); + } + + template + suspend_result_t await_suspend(coro::coroutine_handle h) { + if (auto* frame = get_async_stack_frame(h.promise())) { + return await_suspend_impl(h, frame); + } + + using awaiter_suspend_result_t = decltype(awaiter_.await_suspend(h)); + + // Note: it's technically possible for an awaitable's implementation of + // await_suspend() to return different types depending on its argument + // type. This is easily handled if the "different types" are different + // coroutine_handle<> types: just convert them all to + // coro::coroutine_handle<>; but it's a pain if the different return + // types mix-and-match between void, bool, and coroutine handles. If + // any reports ever come in that these static asserts are breaking + // builds, we can handle it by forcing *our* return type to always be + // coro::coroutine_handle<> and just map the void and bool cases to + // the appropriate handle, but let's avoid that complexity until it's + // proven necessary. + if constexpr (same_as>) { + static_assert(same_as); + } else if constexpr (same_as>) { + static_assert(same_as); + } else { + static_assert( + convertible_to>); + } + + return awaiter_.await_suspend(h); + } + + auto await_resume() noexcept(noexcept(awaiter_.await_resume())) + -> decltype(awaiter_.await_resume()) { + return awaiter_.await_resume(); + } + + template(typename CPO) // + (requires same_as, CPO> AND + std::is_invocable_v) // + friend auto tag_invoke(CPO cpo, const type& self) noexcept( + std::is_nothrow_invocable_v) + -> std::invoke_result_t { + return std::move(cpo)(std::as_const(self.awaitable)); + } +}; + +template +using awaitable_wrapper = typename _awaitable_wrapper::type; + +template +struct is_awaitable_wrapper : std::false_type {}; + +template +struct is_awaitable_wrapper + : std::true_type {}; + +template +constexpr bool is_awaitable_wrapper_v = is_awaitable_wrapper::value; struct _fn { // Call custom implementation if present. @@ -201,26 +509,41 @@ struct _fn { } // Default implementation for naturally awaitable types - template(typename Promise, typename Value) // + template( + typename Promise, + typename Value, + bool WithAsyncStackSupport = !UNIFEX_NO_ASYNC_STACKS) // (requires(!tag_invocable<_fn, Promise&, Value>) AND detail::_awaitable) // - Value&& + decltype(auto) operator()(Promise&, Value&& value) const noexcept { - return std::forward(value); + if constexpr ( + WithAsyncStackSupport && + !is_awaitable_wrapper_v>) { + return awaitable_wrapper{std::forward(value)}; + } else { + return std::forward(value); + } } // Default implementation for non-awaitable senders - template(typename Promise, typename Value) // + template( + typename Promise, + typename Value, + bool WithAsyncStackSupport = !UNIFEX_NO_ASYNC_STACKS) // (requires(!tag_invocable<_fn, Promise&, Value>) AND(!detail::_awaitable) AND unifex::sender) // decltype(auto) operator()(Promise& promise, Value&& value) const { static_assert( - unifex::sender_to>, + unifex::sender_to< + Value, + _receiver_t>, "This sender is not awaitable in this coroutine type."); auto h = coro::coroutine_handle::from_promise(promise); - return _as_awaitable{(Value&&)value, h}; + return _as_awaitable{ + (Value&&)value, h}; } // Fall back to returning the argument if none of the above conditions are met diff --git a/include/unifex/connect_awaitable.hpp b/include/unifex/connect_awaitable.hpp index e8f5586c..16894665 100644 --- a/include/unifex/connect_awaitable.hpp +++ b/include/unifex/connect_awaitable.hpp @@ -21,6 +21,8 @@ #include #include #include +#include +#include #include #include @@ -36,19 +38,28 @@ namespace unifex { namespace _await { -template +template struct _sender_task { class type; }; -template -using sender_task = typename _sender_task::type; +template +using sender_task = + typename _sender_task::type; -template -class _sender_task::type { +template +class _sender_task::type { public: struct promise_type { template - explicit promise_type(Awaitable&, Receiver& r) noexcept : receiver_(r) {} + explicit promise_type( + Awaitable&, + Receiver& r, + [[maybe_unused]] instruction_ptr returnAddress) noexcept + : receiver_(r) { + if constexpr (WithAsyncStackSupport) { + frame_.setReturnAddress(returnAddress); + } + } type get_return_object() noexcept { return type{coro::coroutine_handle::from_promise(*this)}; @@ -68,8 +79,12 @@ class _sender_task::type { struct awaiter { Func&& func_; bool await_ready() noexcept { return false; } - void await_suspend(coro::coroutine_handle) noexcept( + void await_suspend(coro::coroutine_handle h) noexcept( std::is_nothrow_invocable_v) { + if constexpr (WithAsyncStackSupport) { + deactivateAsyncStackFrame(h.promise().frame_); + } + std::forward(func_)(); } [[noreturn]] void await_resume() noexcept { std::terminate(); } @@ -99,12 +114,27 @@ class _sender_task::type { friend auto tag_invoke(CPO cpo, const promise_type& p) noexcept( std::is_nothrow_invocable_v) -> std::invoke_result_t { - return cpo(std::as_const(p.receiver_)); + if constexpr ( + WithAsyncStackSupport && same_as>) { + return &p.frame_; + } else { + return std::move(cpo)(std::as_const(p.receiver_)); + } } Receiver& receiver_; - done_coro doneCoro_ = unifex::unhandled_done( - [this]() noexcept { unifex::set_done(std::move(receiver_)); }); + done_coro doneCoro_ = unifex::unhandled_done([this]() noexcept { + if constexpr (WithAsyncStackSupport) { + popAsyncStackFrameFromCaller(frame_); + deactivateAsyncStackFrame(frame_); + } + + unifex::set_done(std::move(receiver_)); + }); + + UNIFEX_NO_UNIQUE_ADDRESS mutable std:: + conditional_t> + frame_; }; coro::coroutine_handle coro_; @@ -119,7 +149,24 @@ class _sender_task::type { coro_.destroy(); } - void start() & noexcept { coro_.resume(); } + void start() & noexcept { + if constexpr (WithAsyncStackSupport) { + detail::ScopedAsyncStackRoot root; + + auto* frame = &coro_.promise().frame_; + if (auto parentFrame = get_async_stack_frame(coro_.promise().receiver_)) { + frame->setParentFrame(*parentFrame); + } + + root.activateFrame(*frame); + + coro_.resume(); + + root.ensureFrameDeactivated(frame); + } else { + coro_.resume(); + } + } }; } // namespace _await @@ -138,9 +185,10 @@ inline const struct _fn { operator unit() const noexcept { return {}; } }; - template - static auto connect_impl(Awaitable awaitable, Receiver receiver) - -> _await::sender_task { + template + static auto + connect_impl(Awaitable awaitable, Receiver receiver, instruction_ptr) + -> _await::sender_task { #if !UNIFEX_NO_EXCEPTIONS std::exception_ptr ex; try { @@ -149,7 +197,8 @@ inline const struct _fn { // The _sender_task's promise type has an await_transform that passes the // awaitable through unifex::await_transform. So take that into // consideration when computing the result type: - using promise_type = typename _await::sender_task::promise_type; + using promise_type = typename _await:: + sender_task::promise_type; using awaitable_type = std::invoke_result_t< tag_t, promise_type&, @@ -165,12 +214,17 @@ inline const struct _fn { // after the coroutine is suspended so that it is safe // for the receiver to destroy the coroutine. co_yield [&](result_type&& result) { - return [&] { - if constexpr (std::is_void_v>) { - unifex::set_value(std::move(receiver)); - } else { - unifex::set_value( - std::move(receiver), std::forward(result)); + return [&]() noexcept { + UNIFEX_TRY { + if constexpr (std::is_void_v>) { + unifex::set_value(std::move(receiver)); + } else { + unifex::set_value( + std::move(receiver), std::forward(result)); + } + } + UNIFEX_CATCH(...) { + unifex::set_error(std::move(receiver), std::current_exception()); } }; // The _comma_hack here makes this well-formed when the co_await @@ -189,11 +243,16 @@ inline const struct _fn { } public: - template + template < + typename Awaitable, + typename Receiver, + bool WithAsyncStackSupport = !UNIFEX_NO_ASYNC_STACKS> auto operator()(Awaitable&& awaitable, Receiver&& receiver) const - -> _await::sender_task> { - return connect_impl( - std::forward(awaitable), std::forward(receiver)); + -> _await::sender_task, WithAsyncStackSupport> { + return connect_impl( + std::forward(awaitable), + std::forward(receiver), + instruction_ptr::read_return_address()); } } connect_awaitable{}; } // namespace _await_cpo @@ -297,6 +356,7 @@ struct _fn { (requires detail::_awaitable) // _sender> operator()(Awaitable&& awaitable) const { + // TODO: this is going to generate an unfortunate return address return _sender>{ std::forward(awaitable), instruction_ptr::read_return_address()}; diff --git a/include/unifex/stop_if_requested.hpp b/include/unifex/stop_if_requested.hpp index 733a6064..edb25960 100644 --- a/include/unifex/stop_if_requested.hpp +++ b/include/unifex/stop_if_requested.hpp @@ -1,5 +1,5 @@ /* - * Copyright (c) Facebook, Inc. and its affiliates. + * Copyright (c) Meta Platforms, Inc. and affiliates. * * Licensed under the Apache License Version 2.0 with LLVM Exceptions * (the "License"); you may not use this file except in compliance with @@ -16,11 +16,14 @@ #pragma once #include +#include #include #include #include #include +#include #include +#include #include @@ -37,35 +40,60 @@ struct _fn { void start() & noexcept { UNIFEX_TRY { if (get_stop_token(std::as_const(rec_)).stop_requested()) { - unifex::set_done((Receiver &&) rec_); + unifex::set_done((Receiver&&)rec_); } else { - unifex::set_value((Receiver &&) rec_); + unifex::set_value((Receiver&&)rec_); } } UNIFEX_CATCH(...) { - unifex::set_error((Receiver &&) rec_, std::current_exception()); + unifex::set_error((Receiver&&)rec_, std::current_exception()); } } }; }; - public: #if !UNIFEX_NO_COROUTINES + template + struct awaiter { + UNIFEX_NO_UNIQUE_ADDRESS + std::conditional_t< + WithAsyncStackSupport, + AsyncStackFrame, + detail::_empty<0>> + frame; + + bool await_ready() const noexcept { return false; } + + template + coro::coroutine_handle<> + await_suspend(coro::coroutine_handle coro) noexcept { + if (get_stop_token(coro.promise()).stop_requested()) { + if constexpr (WithAsyncStackSupport) { + frame.setReturnAddress(); + if (auto parentFrame = get_async_stack_frame(coro.promise())) { + pushAsyncStackFrameCallerCallee(*parentFrame, frame); + } + } + + return coro.promise().unhandled_done(); + } + return coro; // don't suspend + } + void await_resume() const noexcept {} + }; + // Provide an awaiter interface in addition to the sender interface // because as an awaiter we can take advantage of symmetric transfer // to save stack space: - bool await_ready() const noexcept { return false; } - template - coro::coroutine_handle<> - await_suspend(coro::coroutine_handle coro) const noexcept { - if (get_stop_token(coro.promise()).stop_requested()) { - return coro.promise().unhandled_done(); - } - return coro; // don't suspend + template < + typename Promise, + bool WithAsyncStackSupport = !UNIFEX_NO_ASYNC_STACKS> + friend awaiter + tag_invoke(tag_t, Promise&, _sender) noexcept { + return {}; } - void await_resume() const noexcept {} #endif - + public: template < template class Variant, @@ -84,7 +112,7 @@ struct _fn { (requires receiver_of) // auto connect(Receiver&& rec) const -> typename _op>::type { - return typename _op>::type{(Receiver &&) rec}; + return typename _op>::type{(Receiver&&)rec}; } }; diff --git a/include/unifex/task.hpp b/include/unifex/task.hpp index eab2f88b..d75eab78 100644 --- a/include/unifex/task.hpp +++ b/include/unifex/task.hpp @@ -16,6 +16,7 @@ #pragma once #include +#include #include #include #include @@ -33,6 +34,8 @@ #include #include #include +#include +#include #include #include #include @@ -181,6 +184,11 @@ struct _promise_base { return std::exchange(p.continuation_, std::move(action)); } + friend constexpr AsyncStackFrame* + tag_invoke(tag_t, const _promise_base& p) noexcept { + return p.frame_; + } + #ifdef UNIFEX_ENABLE_CONTINUATION_VISITATIONS template friend void @@ -201,6 +209,10 @@ struct _promise_base { done_coro doneCoro_; // gets set to the return address of the ramp function instruction_ptr returnAddress_; + // the async stack frame corresponding to this coroutine + // null until this coroutine is awaited; stays null when async stack support + // is disabled + AsyncStackFrame* frame_{}; }; /** @@ -208,8 +220,13 @@ struct _promise_base { */ struct _task_promise_base : _promise_base { _task_promise_base() - : _promise_base([this]() noexcept { return continuation_.done_handle(); }) { - } + : _promise_base([this]() noexcept { + if (frame_) { + popAsyncStackFrameFromCaller(*frame_); + } + + return continuation_.done_handle(); + }) {} // the implementation of the magic of co_await schedule(s); this is to be // ripped out and replaced with something more explicit @@ -347,7 +364,9 @@ struct _promise final { return awaiter{}; } - template + template < + typename Value, + bool WithAsyncStackSupport = !UNIFEX_NO_ASYNC_STACKS> // todo: consider if this should be nothrow or not // NOTE: Magic rescheduling is not currently supported by nothrow tasks decltype(auto) await_transform(Value&& value) { @@ -366,9 +385,9 @@ struct _promise final { } else if constexpr ( tag_invocable, type&, Value> || detail::_awaitable) { - // Either await_transform has been customized or Value is an awaitable. - // Either way, we can dispatch to the await_transform CPO, then insert a - // transition back to the correct execution context if necessary. + // await_transform has been customized so we can dispatch to the + // await_transform CPO, then insert a transition back to the correct + // execution context if necessary. return with_scheduler_affinity( *this, unifex::await_transform(*this, static_cast(value)), @@ -401,18 +420,28 @@ struct _promise final { }; }; +struct _frame_state { + _frame_state() noexcept = default; + + explicit _frame_state(AsyncStackFrame& frame, AsyncStackRoot& root) noexcept + : frame_(&frame) + , root_(&root) {} + + void restore_frame_state() const noexcept { + if (frame_) { + activateAsyncStackFrame(*root_, *frame_); + } + } + +private: + AsyncStackFrame* frame_{}; + AsyncStackRoot* root_; // only conditionally initialized +}; + struct _sr_thunk_promise_base : _promise_base { _sr_thunk_promise_base() : _promise_base([this]() noexcept -> coro::coroutine_handle<> { - callback_.destruct(); - - whoToContinue_ = continuation::DONE; - - if (refCount_.fetch_sub(1, std::memory_order_acq_rel) == 1) { - return continuation_.done_handle(); - } else { - return coro::noop_coroutine(); - } + return complete_and_choose_continuation(continuation_.done_handle()); }) {} friend inplace_stop_token @@ -444,11 +473,15 @@ struct _sr_thunk_promise_base : _promise_base { void set_value(bool) noexcept { if (self->refCount_.fetch_sub(1, std::memory_order_acq_rel) == 1) { - if (self->whoToContinue_ == continuation::PRIMARY) { - self->continuation_.resume(); + UNIFEX_ASSERT(self->whoToContinue_); + + if (self->frame_) { + unifex::detail::ScopedAsyncStackRoot root; + root.activateFrame(*self->frame_); + + self->whoToContinue_.resume(); } else { - UNIFEX_ASSERT(self->whoToContinue_ == continuation::DONE); - self->continuation_.resume_done(); + self->whoToContinue_.resume(); } } } @@ -480,17 +513,71 @@ struct _sr_thunk_promise_base : _promise_base { std::atomic refCount_{1}; - enum class continuation : uint8_t { - UNSET, - PRIMARY, - DONE, - }; - - continuation whoToContinue_{continuation::UNSET}; + coro::coroutine_handle<> whoToContinue_{}; void register_stop_callback() noexcept { callback_.construct(stoken_, stop_callback{this}); } + + _frame_state ensure_frame_deactivated() noexcept { + if (frame_ != nullptr) { + if (whoToContinue_ == continuation_.done_handle()) { + popAsyncStackFrameFromCaller(*frame_); + } + + auto* root = frame_->getStackRoot(); + // this asserts that root is not null + deactivateAsyncStackFrame(*frame_); + + return _frame_state(*frame_, *root); + } + + return {}; + } + + // performs the final steps of completing this coroutine: + // - destroy (and thus synchronize with) the stop callback if it exists + // - record the continuation (normal or done) that should be resumed + // - ensure the async stack state is correct + // - decrement the refcount + // + // returns the coroutine handle to resume, which will either be the argument, + // or the no-op coroutine, depending on whether there's an outstanding + // deferred stop request to wait for + coro::coroutine_handle<> complete_and_choose_continuation( + coro::coroutine_handle<> whoToContinue) noexcept { + UNIFEX_ASSERT( + whoToContinue == continuation_.handle() || + whoToContinue == continuation_.done_handle()); + + callback_.destruct(); + + // whoToContinue_ needs to be written before we decrement the refcount + // to ensure that we synchronize this write with the corresponding + // read in the deferred stop callback's completion + whoToContinue_ = whoToContinue; + + // deactivate our async stack frame before decrementing the refcount + // + // Once the refcount has been decremented, it's possible for the + // deferred stop callback to resume our continuation and it must + // activate our frame on a new stack root before doing; for that to be + // safe, it can't be active on any other stack root. If it turns out + // *we* are going to resume our continuation then we have to + // reactivate our frame to undo this proactivate deactivation. + const auto frameState = ensure_frame_deactivated(); + + // if we're last to complete, continue our continuation; otherwise do + // nothing and wait for the async stop request to do it + if (refCount_.fetch_sub(1, std::memory_order_acq_rel) == 1) { + frameState.restore_frame_state(); + + return whoToContinue; + } else { + // the deferred stop callback will reactivate this frame + return coro::noop_coroutine(); + } + } }; // TODO: determine if this should also be nothrow @@ -526,48 +613,24 @@ struct _sr_thunk_promise final { auto final_suspend() noexcept { struct awaiter final : _final_suspend_awaiter_base { + coro::coroutine_handle<> + await_suspend_impl(coro::coroutine_handle h) noexcept { + auto& p = h.promise(); + + return p.complete_and_choose_continuation(p.continuation_.handle()); + } + #if (defined(_MSC_VER) && !defined(__clang__)) || defined(__EMSCRIPTEN__) // MSVC doesn't seem to like symmetric transfer in this final awaiter // and the Emscripten (WebAssembly) compiler doesn't support tail-calls void await_suspend(coro::coroutine_handle h) noexcept { - auto& p = h.promise(); - - p.callback_.destruct(); - - // this needs to be written before we decrement the refcount to ensure - // that we synchronize this write with the corresponding read in the - // deferred stop callback's completion - p.whoToContinue_ = continuation::PRIMARY; - - // if we're last to complete, continue our continuation; otherwise do - // nothing and wait for the async stop request to do it - if (p.refCount_.fetch_sub(1, std::memory_order_acq_rel) == 1) { - return h.promise().continuation_.handle().resume(); - } - - // don't resume anything here; wait for the deferred stop request to - // resume our continuation + await_suspend_impl(h).resume(); } #else coro::coroutine_handle<> await_suspend(coro::coroutine_handle h) noexcept { - auto& p = h.promise(); - - p.callback_.destruct(); - - // this needs to be written before we decrement the refcount to ensure - // that we synchronize this write with the corresponding read in the - // deferred stop callback's completion - p.whoToContinue_ = continuation::PRIMARY; - - // if we're last to complete, continue our continuation; otherwise do - // nothing and wait for the async stop request to do it - if (p.refCount_.fetch_sub(1, std::memory_order_acq_rel) == 1) { - return h.promise().continuation_.handle(); - } else { - return coro::noop_coroutine(); - } - } + return await_suspend_impl(h); + }; #endif }; @@ -581,7 +644,10 @@ struct _sr_thunk_promise final { }; }; -template +template < + typename ThisPromise, + typename OtherPromise, + bool WithAsyncStackSupport> struct _awaiter final { /** * An awaitable type that knows how to await a task<>, sa_task<>, or @@ -636,6 +702,9 @@ struct _awaiter final { promise.register_stop_callback(); + maybePushAsyncStackFrame( + promise, h.promise(), instruction_ptr::read_return_address()); + return thisCoro; } @@ -648,6 +717,9 @@ struct _awaiter final { auto thisCoro = coro::coroutine_handle::from_address( (void*)std::exchange(--coro_, 0)); coro_holder destroyOnExit{thisCoro}; + + maybePopAsyncStackFrame(); + return thisCoro.promise().result(); } @@ -658,6 +730,28 @@ struct _awaiter final { std::bool_constant>; using needs_stop_token_t = std::bool_constant>; + using needs_async_stack_frame_t = std::bool_constant; + + void maybePushAsyncStackFrame( + [[maybe_unused]] ThisPromise& callee, + [[maybe_unused]] OtherPromise& caller, + [[maybe_unused]] instruction_ptr returnAddress) noexcept { + if constexpr (WithAsyncStackSupport) { + if (auto* callerFrame = get_async_stack_frame(caller)) { + frame_.setReturnAddress(returnAddress); + callee.frame_ = &frame_; + pushAsyncStackFrameCallerCallee(*callerFrame, frame_); + } + } + } + + void maybePopAsyncStackFrame() noexcept { + if constexpr (WithAsyncStackSupport) { + if (frame_.getParentFrame() != nullptr) { + popAsyncStackFrameCallee(frame_); + } + } + } // Only store the scheduler and the stop_token in the awaiter if we need to // type erase them. Otherwise, these members are "empty" and should take up @@ -676,6 +770,12 @@ struct _awaiter final { inplace_stop_token_adapter, detail::_empty<1>> stopTokenAdapter_; + UNIFEX_NO_UNIQUE_ADDRESS + conditional_t< + needs_async_stack_frame_t::value, + AsyncStackFrame, + detail::_empty<2>> + frame_; }; }; @@ -689,16 +789,20 @@ struct _sr_thunk_task::type final : coro_holder { friend promise_type; private: - template - using awaiter = typename _awaiter::type; + template + using awaiter = + typename _awaiter:: + type; explicit type(coro::coroutine_handle h) noexcept : coro_holder(h) {} - template - friend awaiter + template < + typename Promise, + bool WithAsyncStackSupport = !UNIFEX_NO_ASYNC_STACKS> + friend awaiter tag_invoke(tag_t, Promise&, type&& t) noexcept { - return awaiter{std::exchange(t.coro_, {})}; + return awaiter{std::exchange(t.coro_, {})}; } friend instruction_ptr @@ -827,18 +931,22 @@ struct _sa_task::type final : public _task::type { type(base&& t) noexcept : base(std::move(t)) {} - template - using awaiter = - typename _awaiter::type; + template + using awaiter = typename _awaiter< + typename base::promise_type, + OtherPromise, + WithAsyncStackSupport>::type; // given that we're awaited in a scheduler-affine context, we are ourselves // scheduler-affine static constexpr bool is_always_scheduler_affine = true; - template - friend awaiter + template < + typename Promise, + bool WithAsyncStackSupport = !UNIFEX_NO_ASYNC_STACKS> + friend awaiter tag_invoke(tag_t, Promise&, type&& t) noexcept { - return awaiter{std::exchange(t.coro_, {})}; + return awaiter{std::exchange(t.coro_, {})}; } template diff --git a/include/unifex/tracing/async_stack-inl.hpp b/include/unifex/tracing/async_stack-inl.hpp index 787ddc9a..add04bde 100644 --- a/include/unifex/tracing/async_stack-inl.hpp +++ b/include/unifex/tracing/async_stack-inl.hpp @@ -67,6 +67,16 @@ popAsyncStackFrameCallee(unifex::AsyncStackFrame& calleeFrame) noexcept { calleeFrame.stackRoot = nullptr; } +inline void popAsyncStackFrameFromCaller( + [[maybe_unused]] unifex::AsyncStackFrame& callerFrame) noexcept { + auto root = tryGetCurrentAsyncStackRoot(); + assert(root != nullptr); + auto topFrame = root->getTopFrame(); + assert(topFrame != nullptr); + assert(topFrame->getParentFrame() == &callerFrame); + popAsyncStackFrameCallee(*topFrame); +} + inline std::size_t getAsyncStackTraceFromInitialFrame( unifex::AsyncStackFrame* initialFrame, std::uintptr_t* addresses, diff --git a/include/unifex/tracing/async_stack.hpp b/include/unifex/tracing/async_stack.hpp index 18c747d0..ca52d394 100644 --- a/include/unifex/tracing/async_stack.hpp +++ b/include/unifex/tracing/async_stack.hpp @@ -221,6 +221,9 @@ void pushAsyncStackFrameCallerCallee( // the current AsyncStackRoot. void popAsyncStackFrameCallee(unifex::AsyncStackFrame& calleeFrame) noexcept; +void popAsyncStackFrameFromCaller( + unifex::AsyncStackFrame& callerFrame) noexcept; + // Get a pointer to a special frame that can be used as the root-frame // for a chain of AsyncStackFrame that does not chain onto a normal // call-stack. @@ -516,7 +519,7 @@ class ScopedAsyncStackRoot { assert(tryGetCurrentAsyncStackRoot() == &root_); [[maybe_unused]] auto topFrame = root_.topFrame.exchange(nullptr, std::memory_order_relaxed); - assert(topFrame == possiblyDeadFrame); + assert(topFrame == nullptr || topFrame == possiblyDeadFrame); } private: diff --git a/source/task.cpp b/source/task.cpp index acbc9401..94b4970e 100644 --- a/source/task.cpp +++ b/source/task.cpp @@ -1,5 +1,5 @@ /* - * Copyright (c) Facebook, Inc. and its affiliates. + * Copyright (c) Meta Platforms, Inc. and affiliates. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -27,7 +27,8 @@ void _task_promise_base::transform_schedule_sender_impl_( // correct scheduler, do so now: if (!std::exchange(this->rescheduled_, true)) { // Create a cleanup action that transitions back onto the current scheduler: - auto cleanupTask = at_coroutine_exit(schedule, this->sched_); + auto cleanupTask = + await_transform(*this, at_coroutine_exit(schedule, this->sched_)); // Insert the cleanup action into the head of the continuation chain by // making direct calls to the cleanup task's awaiter member functions. See // type _cleanup_task in at_coroutine_exit.hpp: