Skip to content

Commit

Permalink
Review loop tasks context reference copying
Browse files Browse the repository at this point in the history
  • Loading branch information
gi0baro committed May 16, 2023
1 parent 2c19283 commit 2fff960
Show file tree
Hide file tree
Showing 3 changed files with 56 additions and 23 deletions.
12 changes: 8 additions & 4 deletions src/asgi/callbacks.rs
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,7 @@ impl CallbackRunnerHTTP {
pub(crate) struct CallbackTaskHTTP {
proto: Py<HTTPProtocol>,
context: TaskLocals,
pycontext: PyObject,
cb: PyObject
}

Expand All @@ -67,7 +68,8 @@ impl CallbackTaskHTTP {
proto: Py<HTTPProtocol>,
context: TaskLocals
) -> PyResult<Self> {
Ok(Self { proto, context: context.copy_context(py)?, cb })
let pyctx = context.context(py);
Ok(Self { proto, context, pycontext: pyctx.call_method0("copy")?.into(), cb })
}

fn done(&self, py: Python) {
Expand All @@ -91,7 +93,7 @@ impl CallbackTaskHTTP {

#[pymethods]
impl CallbackTaskHTTP {
fn _loop_step(pyself: PyRef<'_, Self>, py: Python) -> PyResult<PyObject> {
fn _loop_step(pyself: PyRef<'_, Self>, py: Python) -> PyResult<()> {
callback_impl_loop_step!(pyself, py)
}

Expand Down Expand Up @@ -136,6 +138,7 @@ impl CallbackRunnerWebsocket {
pub(crate) struct CallbackTaskWebsocket {
proto: Py<WebsocketProtocol>,
context: TaskLocals,
pycontext: PyObject,
cb: PyObject
}

Expand All @@ -146,7 +149,8 @@ impl CallbackTaskWebsocket {
proto: Py<WebsocketProtocol>,
context: TaskLocals
) -> PyResult<Self> {
Ok(Self { proto, context: context.copy_context(py)?, cb })
let pyctx = context.context(py);
Ok(Self { proto, context, pycontext: pyctx.call_method0("copy")?.into(), cb })
}

fn done(&self, py: Python) {
Expand All @@ -168,7 +172,7 @@ impl CallbackTaskWebsocket {

#[pymethods]
impl CallbackTaskWebsocket {
fn _loop_step(pyself: PyRef<'_, Self>, py: Python) -> PyResult<PyObject> {
fn _loop_step(pyself: PyRef<'_, Self>, py: Python) -> PyResult<()> {
callback_impl_loop_step!(pyself, py)
}

Expand Down
55 changes: 40 additions & 15 deletions src/callbacks.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,11 @@
use once_cell::sync::OnceCell;
use pyo3::prelude::*;
use pyo3::pyclass::IterNextOutput;


static CONTEXTVARS: OnceCell<PyObject> = OnceCell::new();
static CONTEXT: OnceCell<PyObject> = OnceCell::new();

#[derive(Clone)]
pub(crate) struct CallbackWrapper {
pub callback: PyObject,
Expand Down Expand Up @@ -189,20 +193,41 @@ impl PyFutureAwaitableResult {
}
}

fn contextvars(py: Python) -> PyResult<&PyAny> {
Ok(CONTEXTVARS
.get_or_try_init(|| py.import("contextvars").map(|m| m.into()))?
.as_ref(py))
}

pub fn empty_pycontext(py: Python) -> PyResult<&PyAny> {
Ok(CONTEXT
.get_or_try_init(|| contextvars(py)?.getattr("Context")?.call0().map(|c| c.into()))?
.as_ref(py))
}

macro_rules! callback_impl_run {
() => {
pub fn run<'p>(self, py: Python<'p>) -> PyResult<&'p PyAny> {
let event_loop = self.context.event_loop(py);
let target = self.into_py(py).getattr(py, pyo3::intern!(py, "_loop_task"))?;
event_loop.call_method1(pyo3::intern!(py, "call_soon_threadsafe"), (target,))
let kwctx = pyo3::types::PyDict::new(py);
kwctx.set_item(
pyo3::intern!(py, "context"),
crate::callbacks::empty_pycontext(py)?
)?;
event_loop.call_method(
pyo3::intern!(py, "call_soon_threadsafe"),
(target,),
Some(kwctx)
)
}
};
}

macro_rules! callback_impl_loop_run {
() => {
pub fn run<'p>(self, py: Python<'p>) -> PyResult<&'p PyAny> {
let context = self.context.context(py);
let context = self.pycontext.clone().into_ref(py);
context.call_method1(
pyo3::intern!(py, "run"),
(self.into_py(py).getattr(py, pyo3::intern!(py, "_loop_step"))?,)
Expand All @@ -223,9 +248,9 @@ macro_rules! callback_impl_loop_step {
_ => false
};

let ctx = $pyself.context.context($py);
let ctx = $pyself.pycontext.clone();
let kwctx = pyo3::types::PyDict::new($py);
kwctx.set_item("context", ctx)?;
kwctx.set_item(pyo3::intern!($py, "context"), ctx)?;

match blocking {
true => {
Expand All @@ -244,7 +269,7 @@ macro_rules! callback_impl_loop_step {
),
Some(kwctx)
)?;
Ok($py.None())
Ok(())
},
false => {
let event_loop = $pyself.context.event_loop($py);
Expand All @@ -257,20 +282,20 @@ macro_rules! callback_impl_loop_step {
),
Some(kwctx)
)?;
Ok($py.None())
Ok(())
}
}
},
Err(err) => {
match err.is_instance_of::<pyo3::exceptions::PyStopIteration>($py) {
true => {
$pyself.done($py);
Ok($py.None())
},
false => {
$pyself.err($py);
Err(err)
}
if (
err.is_instance_of::<pyo3::exceptions::PyStopIteration>($py) ||
err.is_instance_of::<pyo3::exceptions::asyncio::CancelledError>($py)
) {
$pyself.done($py);
Ok(())
} else {
$pyself.err($py);
Err(err)
}
}
}
Expand Down
12 changes: 8 additions & 4 deletions src/rsgi/callbacks.rs
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,7 @@ impl CallbackRunnerHTTP {
pub(crate) struct CallbackTaskHTTP {
proto: Py<HTTPProtocol>,
context: TaskLocals,
pycontext: PyObject,
cb: PyObject
}

Expand All @@ -66,7 +67,8 @@ impl CallbackTaskHTTP {
proto: Py<HTTPProtocol>,
context: TaskLocals
) -> PyResult<Self> {
Ok(Self { proto, context: context.copy_context(py)?, cb })
let pyctx = context.context(py);
Ok(Self { proto, context, pycontext: pyctx.call_method0(pyo3::intern!(py, "copy"))?.into(), cb })
}

fn done(&self, py: Python) {
Expand All @@ -90,7 +92,7 @@ impl CallbackTaskHTTP {

#[pymethods]
impl CallbackTaskHTTP {
fn _loop_step(pyself: PyRef<'_, Self>, py: Python) -> PyResult<PyObject> {
fn _loop_step(pyself: PyRef<'_, Self>, py: Python) -> PyResult<()> {
callback_impl_loop_step!(pyself, py)
}

Expand Down Expand Up @@ -135,6 +137,7 @@ impl CallbackRunnerWebsocket {
pub(crate) struct CallbackTaskWebsocket {
proto: Py<WebsocketProtocol>,
context: TaskLocals,
pycontext: PyObject,
cb: PyObject
}

Expand All @@ -145,7 +148,8 @@ impl CallbackTaskWebsocket {
proto: Py<WebsocketProtocol>,
context: TaskLocals
) -> PyResult<Self> {
Ok(Self { proto, context: context.copy_context(py)?, cb })
let pyctx = context.context(py);
Ok(Self { proto, context, pycontext: pyctx.call_method0(pyo3::intern!(py, "copy"))?.into(), cb })
}

fn done(&self, py: Python) {
Expand All @@ -167,7 +171,7 @@ impl CallbackTaskWebsocket {

#[pymethods]
impl CallbackTaskWebsocket {
fn _loop_step(pyself: PyRef<'_, Self>, py: Python) -> PyResult<PyObject> {
fn _loop_step(pyself: PyRef<'_, Self>, py: Python) -> PyResult<()> {
callback_impl_loop_step!(pyself, py)
}

Expand Down

0 comments on commit 2fff960

Please sign in to comment.