From 5aa6a40a8a794e28956ad0325d16143db1016796 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Ram=C3=B3n=20Garc=C3=ADa=20Fern=C3=A1ndez?= Date: Sat, 6 Apr 2024 17:14:32 +0200 Subject: [PATCH] timeout: handle the case of running blocking code inside the Tokio runtime. --- src/blocking/wait.rs | 27 ++++++++++++++++----------- tests/blocking.rs | 17 +++++++++++++++++ 2 files changed, 33 insertions(+), 11 deletions(-) diff --git a/src/blocking/wait.rs b/src/blocking/wait.rs index 659f9615e..fd13d9e15 100644 --- a/src/blocking/wait.rs +++ b/src/blocking/wait.rs @@ -10,7 +10,22 @@ pub(crate) fn timeout(fut: F, timeout: Option) -> Result>, { - enter(); + + let try_tokio_handle = tokio::runtime::Handle::try_current(); + if let Ok(tokio_handle) = try_tokio_handle { + return tokio::task::block_in_place(|| + tokio_handle.block_on(async { + if let Some(actual_timeout) = timeout { + tokio::select! { + result = fut => result.map_err(|e| Waited::Inner(e)), + _ = tokio::time::sleep(actual_timeout) => Err(Waited::TimedOut(crate::error::TimedOut)) + } + } else { + fut.await.map_err(|e| Waited::Inner(e)) + } + }) + ) + } let deadline = timeout.map(|d| { log::trace!("wait at most {d:?}"); @@ -66,13 +81,3 @@ impl futures_util::task::ArcWake for ThreadWaker { } } -fn enter() { - // Check we aren't already in a runtime - #[cfg(debug_assertions)] - { - let _enter = tokio::runtime::Builder::new_current_thread() - .build() - .expect("build shell runtime") - .enter(); - } -} diff --git a/tests/blocking.rs b/tests/blocking.rs index 7d57db7b4..394d19baa 100644 --- a/tests/blocking.rs +++ b/tests/blocking.rs @@ -107,6 +107,23 @@ fn test_get() { assert_eq!(res.text().unwrap().len(), 0) } +#[test] +fn test_get_inside_tokio() { + let server = server::http(move |_req| async { http::Response::default() }); + tokio::runtime::Builder::new_multi_thread() + .enable_all() + .build() + .unwrap() + .block_on(async { + let url = format!("http://{}/1", server.addr()); + let res = reqwest::blocking::get(&url).unwrap(); + assert_eq!(res.url().as_str(), &url); + assert_eq!(res.status(), reqwest::StatusCode::OK); + assert_eq!(res.remote_addr(), Some(server.addr())); + assert_eq!(res.text().unwrap().len(), 0) + }); +} + #[test] fn test_post() { let server = server::http(move |req| async move {