diff --git a/src/timeout_executor/subprocess.py b/src/timeout_executor/subprocess.py index 33e8d16..81010ea 100644 --- a/src/timeout_executor/subprocess.py +++ b/src/timeout_executor/subprocess.py @@ -2,12 +2,15 @@ from __future__ import annotations -from collections.abc import Awaitable +from functools import partial +from inspect import isawaitable from os import environ from pathlib import Path from typing import TYPE_CHECKING, Any, Callable +import anyio import cloudpickle +from anyio.lowlevel import checkpoint from timeout_executor.const import TIMEOUT_EXECUTOR_INPUT_FILE @@ -39,11 +42,12 @@ def dumps_value(value: Any) -> bytes: def output_to_file(file: str) -> Callable[[Callable[P, T]], Callable[P, T]]: def wrapper(func: Callable[P, T]) -> Callable[P, T]: + func = wrap_function_as_sync(func) + def inner(*args: P.args, **kwargs: P.kwargs) -> T: dump = b"" try: result = func(*args, **kwargs) - result = wrap_awaitable(result) except BaseException as exc: dump = dumps_value(exc) raise @@ -59,10 +63,22 @@ def inner(*args: P.args, **kwargs: P.kwargs) -> T: return wrapper -def wrap_awaitable(value: Any) -> Any: - if not isinstance(value, Awaitable): - return value +def wrap_function_as_async(func: Callable[P, Any]) -> Callable[P, Any]: + async def wrapped(*args: P.args, **kwargs: P.kwargs) -> Any: + await checkpoint() + result = func(*args, **kwargs) + if isawaitable(result): + return await result + return result + + return wrapped + + +def wrap_function_as_sync(func: Callable[P, Any]) -> Callable[P, Any]: + async_wrapped = wrap_function_as_async(func) - from async_wrapper import async_to_sync + def wrapped(*args: P.args, **kwargs: P.kwargs) -> Any: + new_func = partial(async_wrapped, *args, **kwargs) + return anyio.run(new_func) - return async_to_sync(value)() + return wrapped