diff --git a/README.md b/README.md
index 3eb6467..710b10a 100644
--- a/README.md
+++ b/README.md
@@ -9,22 +9,16 @@
 ```shell
 $ pip install timeout_executor
 # or
-$ pip install "timeout_executor[all]"
-# or
-$ pip install "timeout_executor[billiard]"
-# or
-$ pip install "timeout_executor[loky]"
-# or
-$ pip install "timeout_executor[dill]"
-# or
-$ pip install "timeout_executor[cloudpickle]"
+$ pip install "timeout_executor[uvloop]"
 ```
 
 ## how to use
 
 ```python
+from __future__ import annotations
+
 import time
 
-from timeout_executor import TimeoutExecutor
+from timeout_executor import AsyncResult, TimeoutExecutor
 
 
 def sample_func() -> None:
@@ -37,9 +31,11 @@ try:
 except Exception as exc:
     assert isinstance(exc, TimeoutError)
 
-executor = TimeoutExecutor(1, pickler="dill")  # or cloudpickle
+executor = TimeoutExecutor(1)
 result = executor.apply(lambda: "done")
-assert result == "done"
+assert isinstance(result, AsyncResult)
+value = result.result()
+assert value == "done"
 ```
 
 ## License
diff --git a/pyproject.toml b/pyproject.toml
index 753489c..9c968b6 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -16,44 +16,24 @@ classifiers = [
     "Programming Language :: Python :: 3.11",
     "Programming Language :: Python :: 3.12",
     "Programming Language :: Python :: Implementation :: CPython",
+    "Framework :: AnyIO",
+    "Framework :: AsyncIO",
+    "Framework :: Trio",
 ]
 requires-python = ">= 3.8"
 dependencies = [
     "anyio>=4.3.0",
     "typing-extensions>=4.10.0",
+    "cloudpickle>=3.0.0",
+    "async-wrapper>=0.8.3",
+    "tblib>=3.0.0",
 ]
 
 [project.urls]
 Repository = "https://github.com/phi-friday/timeout-executor"
 
 [project.optional-dependencies]
-all = [
-    "billiard>=4.2.0",
-    "cloudpickle>=3.0.0",
-    "dill>=0.3.8",
-    "loky>=3.4.1",
-    "psutil>=5.9.8",
-]
-billiard = [
-    "billiard>=4.2.0",
-    "dill>=0.3.8",
-]
-loky = [
-    "cloudpickle>=3.0.0",
-    "loky>=3.4.1",
-    "psutil>=5.9.8",
-]
-dill = [
-    "dill>=0.3.8",
-]
-cloudpickle = [
-    "cloudpickle>=3.0.0",
-]
-
-
-[build-system]
-requires = ["hatchling"]
-build-backend = "hatchling.build"
+uvloop = ["uvloop>=0.19.0; platform_system != 'Windows'"]
 
 [tool.rye]
 managed = true
@@ -62,10 +42,14 @@ dev-dependencies = [
     "ipykernel>=6.29.3",
    "pre-commit>=3.5.0",
     "pytest>=8.0.2",
-    "pytest-asyncio>=0.23.5",
     "pyyaml>=6.0.1",
+    "trio>=0.24.0",
 ]
 
+[build-system]
+requires = ["hatchling"]
+build-backend = "hatchling.build"
+
 [tool.hatch.metadata]
 allow-direct-references = true
 
@@ -182,4 +166,4 @@ split-on-trailing-comma = false
 include = ["src", "tests"]
 pythonVersion = '3.8'
 pythonPlatform = 'Linux'
-diagnostic = 'basic'
\ No newline at end of file
+diagnostic = 'basic'
diff --git a/src/timeout_executor/__init__.py b/src/timeout_executor/__init__.py
index bd1b0f8..2260608 100644
--- a/src/timeout_executor/__init__.py
+++ b/src/timeout_executor/__init__.py
@@ -2,9 +2,10 @@
 
 from typing import Any
 
-from .executor import TimeoutExecutor, get_executor
+from timeout_executor.executor import TimeoutExecutor, apply_func, delay_func
+from timeout_executor.result import AsyncResult
 
-__all__ = ["TimeoutExecutor", "get_executor"]
+__all__ = ["TimeoutExecutor", "AsyncResult", "apply_func", "delay_func"]
 
 __version__: str
diff --git a/src/timeout_executor/concurrent/__init__.py b/src/timeout_executor/concurrent/__init__.py
deleted file mode 100644
index fe3e10f..0000000
--- a/src/timeout_executor/concurrent/__init__.py
+++ /dev/null
@@ -1,6 +0,0 @@
-from __future__ import annotations
-
-from .
import futures -from .main import get_executor_backend - -__all__ = ["futures", "get_executor_backend"] diff --git a/src/timeout_executor/concurrent/futures/__init__.py b/src/timeout_executor/concurrent/futures/__init__.py deleted file mode 100644 index e69de29..0000000 diff --git a/src/timeout_executor/concurrent/futures/backend/__init__.py b/src/timeout_executor/concurrent/futures/backend/__init__.py deleted file mode 100644 index e69de29..0000000 diff --git a/src/timeout_executor/concurrent/futures/backend/_billiard/__init__.py b/src/timeout_executor/concurrent/futures/backend/_billiard/__init__.py deleted file mode 100644 index 364ac11..0000000 --- a/src/timeout_executor/concurrent/futures/backend/_billiard/__init__.py +++ /dev/null @@ -1,5 +0,0 @@ -from __future__ import annotations - -from .process import ProcessPoolExecutor - -__all__ = ["ProcessPoolExecutor"] diff --git a/src/timeout_executor/concurrent/futures/backend/_billiard/process.py b/src/timeout_executor/concurrent/futures/backend/_billiard/process.py deleted file mode 100644 index 566e024..0000000 --- a/src/timeout_executor/concurrent/futures/backend/_billiard/process.py +++ /dev/null @@ -1,680 +0,0 @@ -"""obtained from concurrent.futures.process""" - -from __future__ import annotations - -import os -import queue -import sys -import threading -import traceback -import weakref -from concurrent.futures import _base -from concurrent.futures.process import ( - _MAX_WINDOWS_WORKERS, - EXTRA_QUEUED_CALLS, - BrokenProcessPool, - _CallItem, - _chain_from_iterable_of_lists, - _check_system_limits, - _get_chunks, - _global_shutdown, - _process_chunk, - _WorkItem, -) -from functools import partial -from traceback import format_exception -from typing import TYPE_CHECKING, Any, Callable, Iterable, Iterator, Union, cast - -from typing_extensions import ParamSpec, TypeAlias, TypeVar, override - -from timeout_executor.exception import ExtraError - -try: - import billiard # type: ignore - from billiard import util as bi_util # type: ignore - from billiard.connection import wait as bi_wait # type: ignore - from billiard.queues import Queue # type: ignore -except ImportError as exc: - error = ExtraError.from_import_error(exc, extra="billiard") - raise error from exc - -if TYPE_CHECKING: - from concurrent.futures import Future - from threading import Lock - from types import TracebackType - - from billiard.connection import Pipe # type: ignore - from billiard.context import ( # type: ignore - DefaultContext, # type: ignore - ForkContext, # type: ignore - ForkServerContext, # type: ignore - SpawnContext, # type: ignore - ) - from billiard.process import Process # type: ignore - from billiard.queues import SimpleQueue # type: ignore - - Context: TypeAlias = Union[ - ForkContext, SpawnContext, ForkServerContext, DefaultContext - ] - - if sys.platform != "win32": - from billiard.connection import Connection # type: ignore - else: - from billiard.connection import PipeConnection as Connection # type: ignore - - _P = ParamSpec("_P") - _T = TypeVar("_T", infer_variance=True) - - -class _ThreadWakeup: - _reader: Connection - _writer: Connection - - def __init__(self) -> None: - self._closed = False - if TYPE_CHECKING: - self._reader, self._writer = Pipe() - else: - self._reader, self._writer = billiard.Pipe(duplex=False) - - def close(self) -> None: - if not self._closed: - self._closed = True - self._writer.close() - self._reader.close() - - def wakeup(self) -> None: - if not self._closed: - self._writer.send_bytes(b"") - - def clear(self) -> None: 
- while self._reader.poll(): - self._reader.recv_bytes() - - -class _RemoteTraceback(Exception): # noqa: N818 - def __init__(self, tb: str) -> None: - self.tb = tb - - def __str__(self) -> str: - return self.tb - - -class _ExceptionWithTraceback: - def __init__(self, exc: BaseException, tb: TracebackType | None) -> None: - tb_text = "".join(format_exception(type(exc), exc, tb)) - self.exc = exc - # Traceback object needs to be garbage-collected as its frames - # contain references to all the objects in the exception scope - self.exc.__traceback__ = None - self.tb = '\n"""\n%s"""' % tb_text - - def __reduce__( - self, - ) -> tuple[ - Callable[[BaseException, str], BaseException], tuple[BaseException, str] - ]: - return _rebuild_exc, (self.exc, self.tb) - - -class _ResultItem: - def __init__( - self, - work_id: int, - exception: Exception | _ExceptionWithTraceback | None = None, - result: Any | None = None, - exit_pid: int | None = None, - ) -> None: - self.work_id = work_id - self.exception = exception - self.result = result - self.exit_pid = exit_pid - - -class _SafeQueue(Queue): - """Safe Queue set exception to the future object linked to a job""" - - def __init__( # noqa: PLR0913 - self, - max_size: int = 0, - *, - ctx: Context, - pending_work_items: dict[int, _WorkItem[Any]], - shutdown_lock: Lock, - thread_wakeup: _ThreadWakeup, - ) -> None: - self.pending_work_items = pending_work_items - self.shutdown_lock = shutdown_lock - self.thread_wakeup = thread_wakeup - super().__init__(max_size, ctx=ctx) - - def _on_queue_feeder_error(self, e: Exception, obj: Any) -> None: - if isinstance(obj, _CallItem): - tb = traceback.format_exception(type(e), e, e.__traceback__) - e.__cause__ = _RemoteTraceback('\n"""\n{}"""'.format("".join(tb))) - work_item = self.pending_work_items.pop(obj.work_id, None) - with self.shutdown_lock: - self.thread_wakeup.wakeup() - - if work_item is not None: - work_item.future.set_exception(e) - else: - self._mp_on_queue_feeder_error(e, obj) - - @staticmethod - def _mp_on_queue_feeder_error(e: Any, obj: Any) -> None: # noqa: ARG004 - import traceback - - traceback.print_exc() - - -class ProcessPoolExecutor(_base.Executor): - """process pool executor""" - - _mp_context: Context - - def __init__( # noqa: C901,PLR0912,PLR0915,PLR0913 - self, - max_workers: int | None = None, - mp_context: Context | None = None, - initializer: Callable[..., Any] | None = None, - initargs: tuple[Any, ...] = (), - *, - max_tasks_per_child: int | None = None, - ) -> None: - """Initializes a new ProcessPoolExecutor instance. - - Args: - max_workers: The maximum number of processes that can be used to - execute the given calls. If None or not given then as many - worker processes will be created as the machine has processors. - mp_context: A multiprocessing context to launch the workers. This - object should provide SimpleQueue, Queue and Process. Useful - to allow specific multiprocessing start methods. - initializer: A callable used to initialize worker processes. - initargs: A tuple of arguments to pass to the initializer. - max_tasks_per_child: The maximum number of tasks a worker process - can complete before it will exit and be replaced with a fresh - worker process. The default of None means worker process will - live as long as the executor. Requires a non-'fork' mp_context - start method. When given, we default to using 'spawn' if no - mp_context is supplied. 
- """ - _check_system_limits() - - if max_workers is None: - self._max_workers = os.cpu_count() or 1 - if sys.platform == "win32": - self._max_workers = min(_MAX_WINDOWS_WORKERS, self._max_workers) - else: - if max_workers <= 0: - raise ValueError("max_workers must be greater than 0") - if sys.platform == "win32" and max_workers > _MAX_WINDOWS_WORKERS: - error_msg = f"max_workers must be <= {_MAX_WINDOWS_WORKERS}" - raise ValueError(error_msg) - - self._max_workers = max_workers - - if mp_context is None: - if TYPE_CHECKING: - mp_context = DefaultContext(None).get_context() - assert mp_context is not None - self._mp_context = mp_context - elif max_tasks_per_child is not None: - self._mp_context = billiard.get_context("spawn") - elif sys.platform in {"linux", "linux2", "darwin"}: - self._mp_context = billiard.get_context("fork") - else: - self._mp_context = billiard.get_context() - mp_context = self._mp_context - else: - self._mp_context = mp_context - mp_context = cast("Context", mp_context) - - self._safe_to_dynamically_spawn_children = ( - self._mp_context.get_start_method(allow_none=False) != "fork" - ) - - if initializer is not None and not callable(initializer): - raise TypeError("initializer must be a callable") - self._initializer = initializer - self._initargs = initargs - - if max_tasks_per_child is not None: - if not isinstance(max_tasks_per_child, int): - raise TypeError("max_tasks_per_child must be an integer") - if max_tasks_per_child <= 0: - raise ValueError("max_tasks_per_child must be >= 1") - if self._mp_context.get_start_method(allow_none=False) == "fork": - raise ValueError( - "max_tasks_per_child is incompatible with" - " the 'fork' multiprocessing start method;" - " supply a different mp_context." - ) - self._max_tasks_per_child = max_tasks_per_child - self._executor_manager_thread = None - self._processes: dict[int | None, Process] = {} - self._shutdown_thread = False - self._shutdown_lock = threading.Lock() - self._idle_worker_semaphore = threading.Semaphore(0) - self._broken: str | bool = False - self._queue_count = 0 - self._pending_work_items = {} - self._cancel_pending_futures = False - - self._executor_manager_thread_wakeup: _ThreadWakeup = _ThreadWakeup() - - queue_size = self._max_workers + EXTRA_QUEUED_CALLS - self._call_queue: _SafeQueue = _SafeQueue( - max_size=queue_size, - ctx=self._mp_context, - pending_work_items=self._pending_work_items, - shutdown_lock=self._shutdown_lock, - thread_wakeup=self._executor_manager_thread_wakeup, - ) - self._call_queue._ignore_epipe = True # noqa: SLF001 - self._result_queue: SimpleQueue = mp_context.SimpleQueue() - self._work_ids = queue.Queue() - - def _start_executor_manager_thread(self) -> None: - if self._executor_manager_thread is None: - if not self._safe_to_dynamically_spawn_children: - self._launch_processes() - self._executor_manager_thread = _ExecutorManagerThread(self) - self._executor_manager_thread.start() - _threads_wakeups[self._executor_manager_thread] = ( - self._executor_manager_thread_wakeup - ) - - def _adjust_process_count(self) -> None: - if self._idle_worker_semaphore.acquire(blocking=False): - return - - process_count = len(self._processes) - if process_count < self._max_workers: - self._spawn_process() - - def _launch_processes(self) -> None: - # https://github.com/python/cpython/issues/90622 - assert not self._executor_manager_thread, ( # noqa: S101 - "Processes cannot be fork()ed after the thread has started, " - "deadlock in the child processes could result." 
- ) - for _ in range(len(self._processes), self._max_workers): - self._spawn_process() - - def _spawn_process(self) -> None: - p = self._mp_context.Process( - target=_process_worker, - args=( - self._call_queue, - self._result_queue, - self._initializer, - self._initargs, - self._max_tasks_per_child, - ), - ) - p.start() - self._processes[p.pid] = p - - @override - def submit( - self, fn: Callable[_P, _T], /, *args: _P.args, **kwargs: _P.kwargs - ) -> Future[_T]: - with self._shutdown_lock: - if self._broken: - raise BrokenProcessPool(self._broken) - if self._shutdown_thread: - raise RuntimeError("cannot schedule new futures after shutdown") - if _global_shutdown: - raise RuntimeError( - "cannot schedule new futures after interpreter shutdown" - ) - - f = _base.Future() - w = _WorkItem(f, fn, args, kwargs) - - self._pending_work_items[self._queue_count] = w - self._work_ids.put(self._queue_count) - self._queue_count += 1 - # Wake up queue management thread - self._executor_manager_thread_wakeup.wakeup() - - if self._safe_to_dynamically_spawn_children: - self._adjust_process_count() - self._start_executor_manager_thread() - return f - - @override - def map( - self, - fn: Callable[..., _T], - *iterables: Iterable[Any], - timeout: float | None = None, - chunksize: int = 1, - ) -> Iterator[_T]: - if chunksize < 1: - raise ValueError("chunksize must be >= 1.") - - results = super().map( - partial(_process_chunk, fn), - _get_chunks(*iterables, chunksize=chunksize), - timeout=timeout, - ) - return _chain_from_iterable_of_lists(results) - - @override - def shutdown(self, wait: bool = True, cancel_futures: bool = False) -> None: - with self._shutdown_lock: - self._cancel_pending_futures = cancel_futures - self._shutdown_thread = True - if self._executor_manager_thread_wakeup is not None: - # Wake up queue management thread - self._executor_manager_thread_wakeup.wakeup() - - if self._executor_manager_thread is not None and wait: - self._executor_manager_thread.join() - # To reduce the risk of opening too many files, remove references to - # objects that use file descriptors. - self._executor_manager_thread = None - self._call_queue = None # type: ignore - if self._result_queue is not None and wait: - self._result_queue.close() - self._result_queue = None # type: ignore - self._processes = None # type: ignore - self._executor_manager_thread_wakeup = None # type: ignore - - -def _process_worker( - call_queue: Queue, - result_queue: SimpleQueue, - initializer: Callable[..., Any], - initargs: tuple[Any, ...], - max_tasks: int | None = None, -) -> None: - """Evaluates calls from call_queue and places the results in result_queue. - - This worker is run in a separate process. - - Args: - call_queue: A ctx.Queue of _CallItems that will be read and - evaluated by the worker. - result_queue: A ctx.Queue of _ResultItems that will written - to by the worker. 
- initializer: A callable initializer, or None - initargs: A tuple of args for the initializer - """ - if initializer is not None: - try: - initializer(*initargs) - except BaseException: # noqa: BLE001 - _base.LOGGER.critical("Exception in initializer:", exc_info=True) - return - num_tasks = 0 - exit_pid = None - while True: - call_item = call_queue.get(block=True) - if call_item is None: - result_queue.put(os.getpid()) - return - - if max_tasks is not None: - num_tasks += 1 - if num_tasks >= max_tasks: - exit_pid = os.getpid() - - try: - r = call_item.fn(*call_item.args, **call_item.kwargs) - except BaseException as e: # noqa: BLE001 - exc = _ExceptionWithTraceback(e, e.__traceback__) - _sendback_result( - result_queue, call_item.work_id, exception=exc, exit_pid=exit_pid - ) - else: - _sendback_result( - result_queue, call_item.work_id, result=r, exit_pid=exit_pid - ) - del r - - del call_item - - if exit_pid is not None: - return - - -class _ExecutorManagerThread(threading.Thread): - """Manages the communication between this process and the worker processes. - - The manager is run in a local thread. - - Args: - executor: A reference to the ProcessPoolExecutor that owns - this thread. A weakref will be own by the manager as well as - references to internal objects used to introspect the state of - the executor. - """ - - def __init__(self, executor: ProcessPoolExecutor) -> None: - self.thread_wakeup = executor._executor_manager_thread_wakeup # noqa: SLF001 - self.shutdown_lock = executor._shutdown_lock # noqa: SLF001 - - def weakref_cb( - _: Any, - thread_wakeup: _ThreadWakeup = self.thread_wakeup, - shutdown_lock: Lock = self.shutdown_lock, - ) -> None: - bi_util.debug( - "Executor collected: triggering callback for QueueManager wakeup" - ) - with shutdown_lock: - thread_wakeup.wakeup() - - self.executor_reference: weakref.ReferenceType[ProcessPoolExecutor] = ( - weakref.ref(executor, weakref_cb) - ) - self.processes = executor._processes # noqa: SLF001 - self.call_queue = executor._call_queue # noqa: SLF001 - self.result_queue = executor._result_queue # noqa: SLF001 - self.work_ids_queue = executor._work_ids # noqa: SLF001 - self.max_tasks_per_child = executor._max_tasks_per_child # noqa: SLF001 - self.pending_work_items = executor._pending_work_items # noqa: SLF001 - - super().__init__() - - def run(self) -> None: - while True: - self.add_call_item_to_queue() - - result_item, is_broken, cause = self.wait_result_broken_or_wakeup() - - if is_broken: - self.terminate_broken(cause) - return - if result_item is not None: - self.process_result_item(result_item) - - process_exited = result_item.exit_pid is not None - if process_exited: - p = self.processes.pop(result_item.exit_pid) - p.join() - - del result_item - - if executor := self.executor_reference(): - if process_exited: - with self.shutdown_lock: - executor._adjust_process_count() # noqa: SLF001 - else: - executor._idle_worker_semaphore.release() # noqa: SLF001 - del executor - - if self.is_shutting_down(): - self.flag_executor_shutting_down() - self.add_call_item_to_queue() - if not self.pending_work_items: - self.join_executor_internals() - return - - def add_call_item_to_queue(self) -> None: - while True: - if self.call_queue.full(): - return - try: - work_id = self.work_ids_queue.get(block=False) - except queue.Empty: - return - else: - work_item = self.pending_work_items[work_id] - - if work_item.future.set_running_or_notify_cancel(): - self.call_queue.put( - _CallItem( - work_id, work_item.fn, work_item.args, work_item.kwargs - 
), - block=True, - ) - else: - del self.pending_work_items[work_id] - continue - - def wait_result_broken_or_wakeup(self) -> tuple[Any, bool, list[str] | None]: - result_reader = self.result_queue._reader # noqa: SLF001 - assert not self.thread_wakeup._closed # noqa: SLF001,S101 - wakeup_reader = self.thread_wakeup._reader # noqa: SLF001 - readers = [result_reader, wakeup_reader] - worker_sentinels = [p.sentinel for p in list(self.processes.values())] - ready = bi_wait(readers + worker_sentinels) - - cause = None - is_broken = True - result_item = None - if result_reader in ready: - try: - result_item = result_reader.recv() - is_broken = False - except BaseException as e: # noqa: BLE001 - cause = format_exception(type(e), e, e.__traceback__) - - elif wakeup_reader in ready: - is_broken = False - - with self.shutdown_lock: - self.thread_wakeup.clear() - - return result_item, is_broken, cause - - def process_result_item(self, result_item: _ResultItem) -> None: - if isinstance(result_item, int): - assert self.is_shutting_down() # noqa: S101 - p = self.processes.pop(result_item) - p.join() - if not self.processes: - self.join_executor_internals() - return - else: - work_item = self.pending_work_items.pop(result_item.work_id, None) - if work_item is not None: - if result_item.exception: - work_item.future.set_exception(result_item.exception) - else: - work_item.future.set_result(result_item.result) - - def is_shutting_down(self) -> bool: - executor = self.executor_reference() - return ( - _global_shutdown or executor is None or executor._shutdown_thread # noqa: SLF001 - ) - - def terminate_broken(self, cause: list[str] | None) -> None: - executor = self.executor_reference() - if executor is not None: - executor._broken = ( # noqa: SLF001 - "A child process terminated " - "abruptly, the process pool is not " - "usable anymore" - ) - executor._shutdown_thread = True # noqa: SLF001 - executor = None - - bpe = BrokenProcessPool( - "A process in the process pool was " - "terminated abruptly while the future was " - "running or pending." 
- ) - if cause is not None: - bpe.__cause__ = _RemoteTraceback(f"\n'''\n{''.join(cause)}'''") - - for work_item in self.pending_work_items.values(): - work_item.future.set_exception(bpe) - del work_item - self.pending_work_items.clear() - - for p in self.processes.values(): - p.terminate() - - self.join_executor_internals() - - def flag_executor_shutting_down(self) -> None: - executor = self.executor_reference() - if executor is not None: - executor._shutdown_thread = True # noqa: SLF001 - if executor._cancel_pending_futures: # noqa: SLF001 - new_pending_work_items = {} - for work_id, work_item in self.pending_work_items.items(): - if not work_item.future.cancel(): - new_pending_work_items[work_id] = work_item - self.pending_work_items = new_pending_work_items - while True: - try: - self.work_ids_queue.get_nowait() - except queue.Empty: # noqa: PERF203 - break - executor._cancel_pending_futures = False # noqa: SLF001 - - def shutdown_workers(self) -> None: - n_children_to_stop = self.get_n_children_alive() - n_sentinels_sent = 0 - while n_sentinels_sent < n_children_to_stop and self.get_n_children_alive() > 0: - for i in range(n_children_to_stop - n_sentinels_sent): # noqa: B007 - try: - self.call_queue.put_nowait(None) - n_sentinels_sent += 1 - except queue.Full: # noqa: PERF203 - break - - def join_executor_internals(self) -> None: - self.shutdown_workers() - self.call_queue.close() - self.call_queue.join_thread() - with self.shutdown_lock: - self.thread_wakeup.close() - for p in self.processes.values(): - p.join() - - def get_n_children_alive(self) -> int: - return sum(p.is_alive() for p in self.processes.values()) - - -def _sendback_result( - result_queue: SimpleQueue, - work_id: int, - result: Any | None = None, - exception: Exception | _ExceptionWithTraceback | None = None, - exit_pid: int | None = None, -) -> None: - """Safely send back the given result or exception""" - try: - result_queue.put( - _ResultItem(work_id, result=result, exception=exception, exit_pid=exit_pid) - ) - except BaseException as e: # noqa: BLE001 - exc = _ExceptionWithTraceback(e, e.__traceback__) - result_queue.put(_ResultItem(work_id, exception=exc, exit_pid=exit_pid)) - - -def _rebuild_exc(exc: BaseException, tb: str) -> BaseException: - exc.__cause__ = _RemoteTraceback(tb) - return exc - - -_threads_wakeups = weakref.WeakKeyDictionary() diff --git a/src/timeout_executor/concurrent/futures/backend/_loky/__init__.py b/src/timeout_executor/concurrent/futures/backend/_loky/__init__.py deleted file mode 100644 index 364ac11..0000000 --- a/src/timeout_executor/concurrent/futures/backend/_loky/__init__.py +++ /dev/null @@ -1,5 +0,0 @@ -from __future__ import annotations - -from .process import ProcessPoolExecutor - -__all__ = ["ProcessPoolExecutor"] diff --git a/src/timeout_executor/concurrent/futures/backend/_loky/process.py b/src/timeout_executor/concurrent/futures/backend/_loky/process.py deleted file mode 100644 index a51ee28..0000000 --- a/src/timeout_executor/concurrent/futures/backend/_loky/process.py +++ /dev/null @@ -1,73 +0,0 @@ -from __future__ import annotations - -from typing import TYPE_CHECKING, Any, Callable, Generic - -from typing_extensions import TypeVar - -from timeout_executor.exception import ExtraError - -try: - from loky.process_executor import ( # type: ignore - ProcessPoolExecutor as LockyProcessPoolExecutor, # type: ignore - ) -except ImportError as exc: - error = ExtraError.from_import_error(exc, extra="loky") - raise error from exc - -__all__ = ["ProcessPoolExecutor"] - - -if 
TYPE_CHECKING: - from multiprocessing.context import ( - DefaultContext, - ForkContext, - ForkServerContext, - SpawnContext, - ) - - from loky._base import Future as LockyFuture # type: ignore - from typing_extensions import ParamSpec, override - - _P = ParamSpec("_P") - _T = TypeVar("_T", infer_variance=True) - - class Future(LockyFuture, Generic[_T]): - @override - def add_done_callback( # type: ignore - self, fn: Callable[[Future[_T]], object] - ) -> None: ... - - @override - def set_result(self, result: _T) -> None: # type: ignore - ... - - class ProcessPoolExecutor(LockyProcessPoolExecutor): - @override - def __init__( - self, - max_workers: int | None = None, - job_reducers: dict[type[Any], Callable[[Any], Any]] | None = None, - result_reducers: dict[type[Any], Callable[[Any], Any]] | None = None, - timeout: float | None = None, - context: ForkContext - | SpawnContext - | DefaultContext - | ForkServerContext - | None = None, - initializer: Callable[[], Any] | None = None, - initargs: tuple[Any, ...] = (), - env: dict[str, Any] | None = None, - ) -> None: ... - - @override - def submit( # type: ignore - self, fn: Callable[_P, _T], /, *args: _P.args, **kwargs: _P.kwargs - ) -> Future[_T]: ... - - @override - def shutdown( # type: ignore - self, wait: bool = True, kill_workers: bool = False - ) -> None: ... - -else: - ProcessPoolExecutor = LockyProcessPoolExecutor diff --git a/src/timeout_executor/concurrent/futures/backend/_multiprocessing/__init__.py b/src/timeout_executor/concurrent/futures/backend/_multiprocessing/__init__.py deleted file mode 100644 index 364ac11..0000000 --- a/src/timeout_executor/concurrent/futures/backend/_multiprocessing/__init__.py +++ /dev/null @@ -1,5 +0,0 @@ -from __future__ import annotations - -from .process import ProcessPoolExecutor - -__all__ = ["ProcessPoolExecutor"] diff --git a/src/timeout_executor/concurrent/futures/backend/_multiprocessing/process.py b/src/timeout_executor/concurrent/futures/backend/_multiprocessing/process.py deleted file mode 100644 index 45f29fb..0000000 --- a/src/timeout_executor/concurrent/futures/backend/_multiprocessing/process.py +++ /dev/null @@ -1,661 +0,0 @@ -"""obtained from concurrent.futures.process""" - -from __future__ import annotations - -import multiprocessing as mp -import os -import queue -import sys -import threading -import traceback -import weakref -from concurrent.futures import _base -from concurrent.futures.process import ( - _MAX_WINDOWS_WORKERS, - EXTRA_QUEUED_CALLS, - BrokenProcessPool, - _CallItem, - _chain_from_iterable_of_lists, - _check_system_limits, - _get_chunks, - _global_shutdown, - _process_chunk, - _WorkItem, -) -from functools import partial -from multiprocessing import util as mp_util -from multiprocessing.connection import wait as mp_wait -from multiprocessing.queues import Queue -from traceback import format_exception -from typing import TYPE_CHECKING, Any, Callable, Iterable, Iterator, Union - -from typing_extensions import ParamSpec, TypeAlias, TypeVar, override - -if TYPE_CHECKING: - from concurrent.futures import Future - from multiprocessing.connection import Connection - from multiprocessing.context import ( - DefaultContext, - ForkContext, - ForkServerContext, - SpawnContext, - ) - from multiprocessing.process import BaseProcess - from multiprocessing.queues import SimpleQueue - from threading import Lock - from types import TracebackType - - Context: TypeAlias = Union[ - SpawnContext, ForkContext, ForkServerContext, DefaultContext - ] - - _P = ParamSpec("_P") - _T = 
TypeVar("_T", infer_variance=True) - - -class _ThreadWakeup: - _reader: Connection - _writer: Connection - - def __init__(self) -> None: - self._closed = False - self._reader, self._writer = mp.Pipe(duplex=False) - - def close(self) -> None: - if not self._closed: - self._closed = True - self._writer.close() - self._reader.close() - - def wakeup(self) -> None: - if not self._closed: - self._writer.send_bytes(b"") - - def clear(self) -> None: - while self._reader.poll(): - self._reader.recv_bytes() - - -class _RemoteTraceback(Exception): # noqa: N818 - def __init__(self, tb: str) -> None: - self.tb = tb - - def __str__(self) -> str: - return self.tb - - -class _ExceptionWithTraceback: - def __init__(self, exc: BaseException, tb: TracebackType | None) -> None: - tb_text = "".join(format_exception(type(exc), exc, tb)) - self.exc = exc - # Traceback object needs to be garbage-collected as its frames - # contain references to all the objects in the exception scope - self.exc.__traceback__ = None - self.tb = '\n"""\n%s"""' % tb_text - - def __reduce__( - self, - ) -> tuple[ - Callable[[BaseException, str], BaseException], tuple[BaseException, str] - ]: - return _rebuild_exc, (self.exc, self.tb) - - -class _ResultItem: - def __init__( - self, - work_id: int, - exception: Exception | _ExceptionWithTraceback | None = None, - result: Any | None = None, - exit_pid: int | None = None, - ) -> None: - self.work_id = work_id - self.exception = exception - self.result = result - self.exit_pid = exit_pid - - -class _SafeQueue(Queue): - """Safe Queue set exception to the future object linked to a job""" - - def __init__( # noqa: PLR0913 - self, - max_size: int = 0, - *, - ctx: Context, - pending_work_items: dict[int, _WorkItem[Any]], - shutdown_lock: Lock, - thread_wakeup: _ThreadWakeup, - ) -> None: - self.pending_work_items = pending_work_items - self.shutdown_lock = shutdown_lock - self.thread_wakeup = thread_wakeup - super().__init__(max_size, ctx=ctx) - - def _on_queue_feeder_error(self, e: Exception, obj: Any) -> None: - if isinstance(obj, _CallItem): - tb = traceback.format_exception(type(e), e, e.__traceback__) - e.__cause__ = _RemoteTraceback('\n"""\n{}"""'.format("".join(tb))) - work_item = self.pending_work_items.pop(obj.work_id, None) - with self.shutdown_lock: - self.thread_wakeup.wakeup() - - if work_item is not None: - work_item.future.set_exception(e) - else: - self._mp_on_queue_feeder_error(e, obj) - - @staticmethod - def _mp_on_queue_feeder_error(e: Any, obj: Any) -> None: # noqa: ARG004 - import traceback - - traceback.print_exc() - - -class ProcessPoolExecutor(_base.Executor): - """process pool executor""" - - _mp_context: Context - - def __init__( # noqa: C901, PLR0912, PLR0913, PLR0915 - self, - max_workers: int | None = None, - mp_context: Context | None = None, - initializer: Callable[..., Any] | None = None, - initargs: tuple[Any, ...] = (), - *, - max_tasks_per_child: int | None = None, - ) -> None: - """Initializes a new ProcessPoolExecutor instance. - - Args: - max_workers: The maximum number of processes that can be used to - execute the given calls. If None or not given then as many - worker processes will be created as the machine has processors. - mp_context: A multiprocessing context to launch the workers. This - object should provide SimpleQueue, Queue and Process. Useful - to allow specific multiprocessing start methods. - initializer: A callable used to initialize worker processes. - initargs: A tuple of arguments to pass to the initializer. 
- max_tasks_per_child: The maximum number of tasks a worker process - can complete before it will exit and be replaced with a fresh - worker process. The default of None means worker process will - live as long as the executor. Requires a non-'fork' mp_context - start method. When given, we default to using 'spawn' if no - mp_context is supplied. - """ - _check_system_limits() - - if max_workers is None: - self._max_workers = os.cpu_count() or 1 - if sys.platform == "win32": - self._max_workers = min(_MAX_WINDOWS_WORKERS, self._max_workers) - else: - if max_workers <= 0: - raise ValueError("max_workers must be greater than 0") - if sys.platform == "win32" and max_workers > _MAX_WINDOWS_WORKERS: - error_msg = f"max_workers must be <= {_MAX_WINDOWS_WORKERS}" - raise ValueError(error_msg) - - self._max_workers = max_workers - - if mp_context is None: - if max_tasks_per_child is not None: - self._mp_context = mp.get_context("spawn") - elif sys.platform in {"linux", "linux2", "darwin"}: - self._mp_context = mp.get_context("fork") - else: - self._mp_context = mp.get_context() - mp_context = self._mp_context - else: - self._mp_context = mp_context - - self._safe_to_dynamically_spawn_children = ( - self._mp_context.get_start_method(allow_none=False) != "fork" - ) - - if initializer is not None and not callable(initializer): - raise TypeError("initializer must be a callable") - self._initializer = initializer - self._initargs = initargs - - if max_tasks_per_child is not None: - if not isinstance(max_tasks_per_child, int): - raise TypeError("max_tasks_per_child must be an integer") - if max_tasks_per_child <= 0: - raise ValueError("max_tasks_per_child must be >= 1") - if self._mp_context.get_start_method(allow_none=False) == "fork": - raise ValueError( - "max_tasks_per_child is incompatible with" - " the 'fork' multiprocessing start method;" - " supply a different mp_context." 
- ) - self._max_tasks_per_child = max_tasks_per_child - self._executor_manager_thread = None - self._processes: dict[int | None, BaseProcess] = {} - self._shutdown_thread = False - self._shutdown_lock = threading.Lock() - self._idle_worker_semaphore = threading.Semaphore(0) - self._broken: str | bool = False - self._queue_count = 0 - self._pending_work_items = {} - self._cancel_pending_futures = False - - self._executor_manager_thread_wakeup: _ThreadWakeup = _ThreadWakeup() - - queue_size = self._max_workers + EXTRA_QUEUED_CALLS - self._call_queue: _SafeQueue = _SafeQueue( - max_size=queue_size, - ctx=self._mp_context, - pending_work_items=self._pending_work_items, - shutdown_lock=self._shutdown_lock, - thread_wakeup=self._executor_manager_thread_wakeup, - ) - self._call_queue._ignore_epipe = True # noqa: SLF001 # type: ignore - self._result_queue: SimpleQueue = mp_context.SimpleQueue() - self._work_ids = queue.Queue() - - def _start_executor_manager_thread(self) -> None: - if self._executor_manager_thread is None: - if not self._safe_to_dynamically_spawn_children: - self._launch_processes() - self._executor_manager_thread = _ExecutorManagerThread(self) - self._executor_manager_thread.start() - _threads_wakeups[self._executor_manager_thread] = ( - self._executor_manager_thread_wakeup - ) - - def _adjust_process_count(self) -> None: - if self._idle_worker_semaphore.acquire(blocking=False): - return - - process_count = len(self._processes) - if process_count < self._max_workers: - self._spawn_process() - - def _launch_processes(self) -> None: - # https://github.com/python/cpython/issues/90622 - assert not self._executor_manager_thread, ( # noqa: S101 - "Processes cannot be fork()ed after the thread has started, " - "deadlock in the child processes could result." 
- ) - for _ in range(len(self._processes), self._max_workers): - self._spawn_process() - - def _spawn_process(self) -> None: - p = self._mp_context.Process( - target=_process_worker, - args=( - self._call_queue, - self._result_queue, - self._initializer, - self._initargs, - self._max_tasks_per_child, - ), - ) - p.start() - self._processes[p.pid] = p - - @override - def submit( - self, fn: Callable[_P, _T], /, *args: _P.args, **kwargs: _P.kwargs - ) -> Future[_T]: - with self._shutdown_lock: - if self._broken: - raise BrokenProcessPool(self._broken) - if self._shutdown_thread: - raise RuntimeError("cannot schedule new futures after shutdown") - if _global_shutdown: - raise RuntimeError( - "cannot schedule new futures after interpreter shutdown" - ) - - f = _base.Future() - w = _WorkItem(f, fn, args, kwargs) - - self._pending_work_items[self._queue_count] = w - self._work_ids.put(self._queue_count) - self._queue_count += 1 - # Wake up queue management thread - self._executor_manager_thread_wakeup.wakeup() - - if self._safe_to_dynamically_spawn_children: - self._adjust_process_count() - self._start_executor_manager_thread() - return f - - @override - def map( - self, - fn: Callable[..., _T], - *iterables: Iterable[Any], - timeout: float | None = None, - chunksize: int = 1, - ) -> Iterator[_T]: - if chunksize < 1: - raise ValueError("chunksize must be >= 1.") - - results = super().map( - partial(_process_chunk, fn), - _get_chunks(*iterables, chunksize=chunksize), - timeout=timeout, - ) - return _chain_from_iterable_of_lists(results) - - @override - def shutdown(self, wait: bool = True, cancel_futures: bool = False) -> None: - with self._shutdown_lock: - self._cancel_pending_futures = cancel_futures - self._shutdown_thread = True - if self._executor_manager_thread_wakeup is not None: - # Wake up queue management thread - self._executor_manager_thread_wakeup.wakeup() - - if self._executor_manager_thread is not None and wait: - self._executor_manager_thread.join() - # To reduce the risk of opening too many files, remove references to - # objects that use file descriptors. - self._executor_manager_thread = None - self._call_queue = None # type: ignore - if self._result_queue is not None and wait: - for _pipe_attr in "_reader", "_writer": - if (_pipe := getattr(self._result_queue, _pipe_attr, None)) is not None: - _pipe.close() - self._result_queue = None # type: ignore - self._processes = None # type: ignore - self._executor_manager_thread_wakeup = None # type: ignore - - -def _process_worker( - call_queue: Queue, - result_queue: SimpleQueue, - initializer: Callable[..., Any], - initargs: tuple[Any, ...], - max_tasks: int | None = None, -) -> None: - """Evaluates calls from call_queue and places the results in result_queue. - - This worker is run in a separate process. - - Args: - call_queue: A ctx.Queue of _CallItems that will be read and - evaluated by the worker. - result_queue: A ctx.Queue of _ResultItems that will written - to by the worker. 
- initializer: A callable initializer, or None - initargs: A tuple of args for the initializer - """ - if initializer is not None: - try: - initializer(*initargs) - except BaseException: # noqa: BLE001 - _base.LOGGER.critical("Exception in initializer:", exc_info=True) - return - num_tasks = 0 - exit_pid = None - while True: - call_item = call_queue.get(block=True) - if call_item is None: - result_queue.put(os.getpid()) - return - - if max_tasks is not None: - num_tasks += 1 - if num_tasks >= max_tasks: - exit_pid = os.getpid() - - try: - r = call_item.fn(*call_item.args, **call_item.kwargs) - except BaseException as e: # noqa: BLE001 - exc = _ExceptionWithTraceback(e, e.__traceback__) - _sendback_result( - result_queue, call_item.work_id, exception=exc, exit_pid=exit_pid - ) - else: - _sendback_result( - result_queue, call_item.work_id, result=r, exit_pid=exit_pid - ) - del r - - del call_item - - if exit_pid is not None: - return - - -class _ExecutorManagerThread(threading.Thread): - """Manages the communication between this process and the worker processes. - - The manager is run in a local thread. - - Args: - executor: A reference to the ProcessPoolExecutor that owns - this thread. A weakref will be own by the manager as well as - references to internal objects used to introspect the state of - the executor. - """ - - def __init__(self, executor: ProcessPoolExecutor) -> None: - self.thread_wakeup = executor._executor_manager_thread_wakeup # noqa: SLF001 - self.shutdown_lock = executor._shutdown_lock # noqa: SLF001 - - def weakref_cb( - _: Any, - thread_wakeup: _ThreadWakeup = self.thread_wakeup, - shutdown_lock: Lock = self.shutdown_lock, - ) -> None: - mp_util.debug( - "Executor collected: triggering callback for QueueManager wakeup" - ) - with shutdown_lock: - thread_wakeup.wakeup() - - self.executor_reference: weakref.ReferenceType[ProcessPoolExecutor] = ( - weakref.ref(executor, weakref_cb) - ) - self.processes = executor._processes # noqa: SLF001 - self.call_queue = executor._call_queue # noqa: SLF001 - self.result_queue = executor._result_queue # noqa: SLF001 - self.work_ids_queue = executor._work_ids # noqa: SLF001 - self.max_tasks_per_child = executor._max_tasks_per_child # noqa: SLF001 - self.pending_work_items = executor._pending_work_items # noqa: SLF001 - - super().__init__() - - def run(self) -> None: - while True: - self.add_call_item_to_queue() - - result_item, is_broken, cause = self.wait_result_broken_or_wakeup() - - if is_broken: - self.terminate_broken(cause) - return - if result_item is not None: - self.process_result_item(result_item) - - process_exited = result_item.exit_pid is not None - if process_exited: - p = self.processes.pop(result_item.exit_pid) - p.join() - - del result_item - - if executor := self.executor_reference(): - if process_exited: - with self.shutdown_lock: - executor._adjust_process_count() # noqa: SLF001 - else: - executor._idle_worker_semaphore.release() # noqa: SLF001 - del executor - - if self.is_shutting_down(): - self.flag_executor_shutting_down() - self.add_call_item_to_queue() - if not self.pending_work_items: - self.join_executor_internals() - return - - def add_call_item_to_queue(self) -> None: - while True: - if self.call_queue.full(): - return - try: - work_id = self.work_ids_queue.get(block=False) - except queue.Empty: - return - else: - work_item = self.pending_work_items[work_id] - - if work_item.future.set_running_or_notify_cancel(): - self.call_queue.put( - _CallItem( - work_id, work_item.fn, work_item.args, work_item.kwargs - 
), - block=True, - ) - else: - del self.pending_work_items[work_id] - continue - - def wait_result_broken_or_wakeup(self) -> tuple[Any, bool, list[str] | None]: - result_reader = self.result_queue._reader # noqa: SLF001 # type: ignore - assert not self.thread_wakeup._closed # noqa: SLF001,S101 - wakeup_reader = self.thread_wakeup._reader # noqa: SLF001 - readers = [result_reader, wakeup_reader] - worker_sentinels = [p.sentinel for p in list(self.processes.values())] - ready = mp_wait(readers + worker_sentinels) - - cause = None - is_broken = True - result_item = None - if result_reader in ready: - try: - result_item = result_reader.recv() - is_broken = False - except BaseException as e: # noqa: BLE001 - cause = format_exception(type(e), e, e.__traceback__) - - elif wakeup_reader in ready: - is_broken = False - - with self.shutdown_lock: - self.thread_wakeup.clear() - - return result_item, is_broken, cause - - def process_result_item(self, result_item: _ResultItem) -> None: - if isinstance(result_item, int): - assert self.is_shutting_down() # noqa: S101 - p = self.processes.pop(result_item) - p.join() - if not self.processes: - self.join_executor_internals() - return - else: - work_item = self.pending_work_items.pop(result_item.work_id, None) - if work_item is not None: - if result_item.exception: - work_item.future.set_exception(result_item.exception) - else: - work_item.future.set_result(result_item.result) - - def is_shutting_down(self) -> bool: - executor = self.executor_reference() - return ( - _global_shutdown or executor is None or executor._shutdown_thread # noqa: SLF001 - ) - - def terminate_broken(self, cause: list[str] | None) -> None: - executor = self.executor_reference() - if executor is not None: - executor._broken = ( # noqa: SLF001 - "A child process terminated " - "abruptly, the process pool is not " - "usable anymore" - ) - executor._shutdown_thread = True # noqa: SLF001 - executor = None - - bpe = BrokenProcessPool( - "A process in the process pool was " - "terminated abruptly while the future was " - "running or pending." 
- ) - if cause is not None: - bpe.__cause__ = _RemoteTraceback(f"\n'''\n{''.join(cause)}'''") - - for work_item in self.pending_work_items.values(): - work_item.future.set_exception(bpe) - del work_item - self.pending_work_items.clear() - - for p in self.processes.values(): - p.terminate() - - self.join_executor_internals() - - def flag_executor_shutting_down(self) -> None: - executor = self.executor_reference() - if executor is not None: - executor._shutdown_thread = True # noqa: SLF001 - if executor._cancel_pending_futures: # noqa: SLF001 - new_pending_work_items = {} - for work_id, work_item in self.pending_work_items.items(): - if not work_item.future.cancel(): - new_pending_work_items[work_id] = work_item - self.pending_work_items = new_pending_work_items - while True: - try: - self.work_ids_queue.get_nowait() - except queue.Empty: # noqa: PERF203 - break - executor._cancel_pending_futures = False # noqa: SLF001 - - def shutdown_workers(self) -> None: - n_children_to_stop = self.get_n_children_alive() - n_sentinels_sent = 0 - while n_sentinels_sent < n_children_to_stop and self.get_n_children_alive() > 0: - for i in range(n_children_to_stop - n_sentinels_sent): # noqa: B007 - try: - self.call_queue.put_nowait(None) - n_sentinels_sent += 1 - except queue.Full: # noqa: PERF203 - break - - def join_executor_internals(self) -> None: - self.shutdown_workers() - self.call_queue.close() - self.call_queue.join_thread() - with self.shutdown_lock: - self.thread_wakeup.close() - for p in self.processes.values(): - p.join() - - def get_n_children_alive(self) -> int: - return sum(p.is_alive() for p in self.processes.values()) - - -def _sendback_result( - result_queue: SimpleQueue, - work_id: int, - result: Any | None = None, - exception: Exception | _ExceptionWithTraceback | None = None, - exit_pid: int | None = None, -) -> None: - """Safely send back the given result or exception""" - try: - result_queue.put( - _ResultItem(work_id, result=result, exception=exception, exit_pid=exit_pid) - ) - except BaseException as e: # noqa: BLE001 - exc = _ExceptionWithTraceback(e, e.__traceback__) - result_queue.put(_ResultItem(work_id, exception=exc, exit_pid=exit_pid)) - - -def _rebuild_exc(exc: BaseException, tb: str) -> BaseException: - exc.__cause__ = _RemoteTraceback(tb) - return exc - - -_threads_wakeups = weakref.WeakKeyDictionary() diff --git a/src/timeout_executor/concurrent/main.py b/src/timeout_executor/concurrent/main.py deleted file mode 100644 index bf1142c..0000000 --- a/src/timeout_executor/concurrent/main.py +++ /dev/null @@ -1,59 +0,0 @@ -from __future__ import annotations - -from importlib import import_module -from importlib.util import find_spec -from typing import TYPE_CHECKING, Literal, overload - -if TYPE_CHECKING: - from .futures.backend import _billiard as billiard_future - from .futures.backend import _loky as loky_future - from .futures.backend import _multiprocessing as multiprocessing_future - -__all__ = ["get_executor_backend"] - -BackendType = Literal["billiard", "multiprocessing", "loky"] -DEFAULT_BACKEND = "multiprocessing" - - -@overload -def get_executor_backend( - backend: Literal["multiprocessing"] | None = ..., -) -> type[multiprocessing_future.ProcessPoolExecutor]: ... - - -@overload -def get_executor_backend( - backend: Literal["billiard"] = ..., -) -> type[billiard_future.ProcessPoolExecutor]: ... - - -@overload -def get_executor_backend( - backend: Literal["loky"] = ..., -) -> type[loky_future.ProcessPoolExecutor]: ... 
- - -def get_executor_backend( - backend: BackendType | None = None, -) -> type[ - billiard_future.ProcessPoolExecutor - | multiprocessing_future.ProcessPoolExecutor - | loky_future.ProcessPoolExecutor -]: - """get pool executor - - Args: - backend: billiard or multiprocessing or loky. - Defaults to None. - - Returns: - ProcessPoolExecutor - """ - backend = backend or DEFAULT_BACKEND - name = f".futures.backend._{backend}" - spec = find_spec(name, __package__) - if spec is None: - error_msg = f"invalid backend: {backend}" - raise ImportError(error_msg) - module = import_module(name, __package__) - return module.ProcessPoolExecutor diff --git a/src/timeout_executor/const.py b/src/timeout_executor/const.py new file mode 100644 index 0000000..05d17fc --- /dev/null +++ b/src/timeout_executor/const.py @@ -0,0 +1,7 @@ +from __future__ import annotations + +__all__ = ["TIMEOUT_EXECUTOR_INPUT_FILE", "SUBPROCESS_COMMAND"] +TIMEOUT_EXECUTOR_INPUT_FILE = "_TIMEOUT_EXECUTOR_INPUT_FILE" +SUBPROCESS_COMMAND = ( + "from timeout_executor.subprocess import run_in_subprocess;run_in_subprocess()" +) diff --git a/src/timeout_executor/exception.py b/src/timeout_executor/exception.py deleted file mode 100644 index dd43de4..0000000 --- a/src/timeout_executor/exception.py +++ /dev/null @@ -1,72 +0,0 @@ -from __future__ import annotations - -import sys -from textwrap import indent -from typing import TYPE_CHECKING, Any, Sequence - -from typing_extensions import Self, override - -if sys.version_info < (3, 11): - from exceptiongroup import ExceptionGroup # type: ignore - -__all__ = ["ExtraError", "ImportErrors"] - - -class ExtraError(ImportError): # noqa: D101 - @override - def __init__( - self, *args: Any, name: str | None = None, path: str | None = None, extra: str - ) -> None: - super().__init__(*args, name=name, path=path) - self._extra = extra - - @property - @override - def msg(self) -> str: # pyright: ignore[reportIncompatibleVariableOverride] - return f"install extra first: {self._extra}" - - def __str__(self) -> str: - return self.msg - - def __repr__(self) -> str: - return f"{type(self).__name__!s}({self.msg!r})" - - @classmethod - def from_import_error(cls, error: ImportError, extra: str) -> Self: - """create from import error - - Args: - error: import error - extra: extra name - - Returns: - extra error - """ - return cls(name=error.name, path=error.path, extra=extra) - - -class ImportErrors(ExceptionGroup[ImportError]): # noqa: D101 - if TYPE_CHECKING: - - @override - def __new__( - cls, __message: str, __exceptions: Sequence[ImportError] - ) -> Self: ... 
- - exceptions: Sequence[ImportError] # pyright: ignore[reportIncompatibleMethodOverride] - - def render(self, depth: int = 0) -> str: # noqa: D102 - msg = str(self) - if depth: - msg = indent(msg, prefix=" ") - return msg - - def __str__(self) -> str: - return ( - super().__str__() - + "\n" - + indent( - "\n".join(f"[{error.name}] {error!s}" for error in self.exceptions), - prefix=" ", - ) - ) diff --git a/src/timeout_executor/executor.py b/src/timeout_executor/executor.py index 963c3b4..c9d17d3 100644 --- a/src/timeout_executor/executor.py +++ b/src/timeout_executor/executor.py @@ -1,274 +1,230 @@ from __future__ import annotations -import asyncio -from concurrent.futures import wait -from contextlib import contextmanager -from functools import partial -from typing import TYPE_CHECKING, Any, Callable, Coroutine, Generator, Literal, overload +import shlex +import subprocess +import sys +import tempfile +from pathlib import Path +from typing import Any, Callable, Coroutine, Generic, overload +from uuid import uuid4 import anyio +import cloudpickle from typing_extensions import ParamSpec, TypeVar -from timeout_executor.concurrent import get_executor_backend -from timeout_executor.log import logger -from timeout_executor.serde import monkey_patch -from timeout_executor.serde.lock import patch_lock +from timeout_executor.const import SUBPROCESS_COMMAND, TIMEOUT_EXECUTOR_INPUT_FILE +from timeout_executor.result import AsyncResult -if TYPE_CHECKING: - from threading import RLock +__all__ = ["TimeoutExecutor", "apply_func", "delay_func"] - from anyio.abc import ObjectSendStream +P = ParamSpec("P") +T = TypeVar("T", infer_variance=True) +P2 = ParamSpec("P2") +T2 = TypeVar("T2", infer_variance=True) - from timeout_executor.concurrent.futures.backend import _billiard as billiard_future - from timeout_executor.concurrent.futures.backend import _loky as loky_future - from timeout_executor.concurrent.futures.backend import ( - _multiprocessing as multiprocessing_future, - ) - from timeout_executor.concurrent.main import BackendType - from timeout_executor.serde.main import PicklerType -__all__ = ["TimeoutExecutor", "get_executor"] +class _Executor(Generic[P, T]): + def __init__(self, timeout: float, func: Callable[P, T]) -> None: + self._timeout = timeout + self._func = func -ParamT = ParamSpec("ParamT") -ResultT = TypeVar("ResultT", infer_variance=True) + def _create_temp_files(self) -> tuple[Path, Path]: + unique_id = uuid4() + temp_dir = Path(tempfile.gettempdir()) / "timeout_executor" + temp_dir.mkdir(exist_ok=True) -class TimeoutExecutor: - """exec with timeout""" + unique_dir = temp_dir / str(unique_id) + unique_dir.mkdir(exist_ok=False) - def __init__( - self, - timeout: float, - backend: BackendType | None = None, - pickler: PicklerType | None = None, - *, - executor: billiard_future.ProcessPoolExecutor - | multiprocessing_future.ProcessPoolExecutor - | loky_future.ProcessPoolExecutor - | None = None, - ) -> None: - self.timeout = timeout - self._init = None - self._args = () - self._kwargs = {} - self._select = (backend, pickler) - self._executor = executor - self._external_executor = executor is not None - - @property - def lock(self) -> RLock: - """patch lock""" - - return patch_lock - - @property - def executor( - self, - ) -> ( - billiard_future.ProcessPoolExecutor - | multiprocessing_future.ProcessPoolExecutor - | loky_future.ProcessPoolExecutor - ): - """process pool executor""" - if self._executor is None: - self._executor = get_executor(self._select[0], self._select[1])( - 1, 
initializer=self._partial_init() - ) - return self._executor - - @executor.setter - def executor( - self, - executor: billiard_future.ProcessPoolExecutor - | multiprocessing_future.ProcessPoolExecutor - | loky_future.ProcessPoolExecutor, - ) -> None: - if self._executor is not None: - raise AttributeError("executor already exists") - self._executor = executor - self._external_executor = True - - @executor.deleter - def executor(self) -> None: - if self._executor is None: - return - if self._executor._executor_manager_thread_wakeup is None: # noqa: SLF001 - self._executor = None - return - self._executor = None - self._external_executor = False - - @contextmanager - def _enter( - self, - ) -> Generator[ - billiard_future.ProcessPoolExecutor - | multiprocessing_future.ProcessPoolExecutor - | loky_future.ProcessPoolExecutor, - None, - None, - ]: - executor = self.executor - try: - yield executor - finally: - if self._external_executor: - logger.warning("shutdown executor yourself") - else: - executor.shutdown(False, True) # noqa: FBT003 - del self.executor - - def _partial_init(self) -> Callable[[], Any] | None: - if self._init is None: - return None - return partial(self._init, *self._args, **self._kwargs) - - def set_init( - self, init: Callable[ParamT, Any], *args: ParamT.args, **kwargs: ParamT.kwargs - ) -> None: - """set init func + input_file = unique_dir / "input.b" + output_file = unique_dir / "output.b" - Args: - init: pickable func - """ - self._init = init - self._args = args - self._kwargs = kwargs + return input_file, output_file - def apply( - self, - func: Callable[ParamT, ResultT], - *args: ParamT.args, - **kwargs: ParamT.kwargs, - ) -> ResultT: - """apply only pickable func + @staticmethod + def _command() -> list[str]: + return shlex.split(f'{sys.executable} -c "{SUBPROCESS_COMMAND}"') - Both args and kwargs should be pickable. + def apply(self, *args: P.args, **kwargs: P.kwargs) -> AsyncResult[T]: + input_file, output_file = self._create_temp_files() - Args: - func: pickable func + input_args = (self._func, args, kwargs, output_file) + input_args_as_bytes = cloudpickle.dumps(input_args) + with input_file.open("wb+") as file: + file.write(input_args_as_bytes) - Raises: - TimeoutError: When the time is exceeded - exc: Error during pickable func execution + command = self._command() + process = subprocess.Popen( + command, # noqa: S603 + env={TIMEOUT_EXECUTOR_INPUT_FILE: input_file.as_posix()}, + stdin=subprocess.DEVNULL, + stdout=subprocess.DEVNULL, + stderr=subprocess.DEVNULL, + ) + return AsyncResult(process, input_file, output_file, self._timeout) - Returns: - pickable func result - """ - with self._enter() as pool: - future = pool.submit(func, *args, **kwargs) - wait([future], timeout=self.timeout) - if not future.done(): - pool.shutdown(False, True) # noqa: FBT003 - error_msg = f"timeout > {self.timeout}s" - raise TimeoutError(error_msg) - return future.result() - - async def apply_async( - self, - func: Callable[ParamT, Coroutine[None, None, ResultT]], - *args: ParamT.args, - **kwargs: ParamT.kwargs, - ) -> ResultT: - """apply only pickable func + async def delay(self, *args: P.args, **kwargs: P.kwargs) -> AsyncResult[T]: + input_file, output_file = self._create_temp_files() + input_file, output_file = anyio.Path(input_file), anyio.Path(output_file) - Both args and kwargs should be pickable. 
+ input_args = (self._func, args, kwargs, output_file) + input_args_as_bytes = cloudpickle.dumps(input_args) + async with await input_file.open("wb+") as file: + await file.write(input_args_as_bytes) - Args: - func: pickable func + command = self._command() + process = subprocess.Popen( # noqa: ASYNC101 + command, # noqa: S603 + env={TIMEOUT_EXECUTOR_INPUT_FILE: input_file.as_posix()}, + stdin=subprocess.DEVNULL, + stdout=subprocess.DEVNULL, + stderr=subprocess.DEVNULL, + ) + return AsyncResult(process, input_file, output_file, self._timeout) - Raises: - TimeoutError: When the time is exceeded - exc: Error during pickable func execution - Returns: - pickable func result - """ - with self._enter() as pool: - try: - future = pool.submit( - _async_run, func, *args, _timeout=self.timeout, **kwargs - ) - coro = asyncio.wrap_future(future) - return await coro - except TimeoutError: - pool.shutdown(False, True) # noqa: FBT003 - raise +@overload +def apply_func( + timeout: float, + func: Callable[P2, Coroutine[Any, Any, T2]], + *args: P2.args, + **kwargs: P2.kwargs, +) -> AsyncResult[T2]: ... @overload -def get_executor( - backend: Literal["multiprocessing"] | None = ..., pickler: PicklerType | None = ... -) -> type[multiprocessing_future.ProcessPoolExecutor]: ... +def apply_func( + timeout: float, func: Callable[P2, T2], *args: P2.args, **kwargs: P2.kwargs +) -> AsyncResult[T2]: ... + + +def apply_func( + timeout: float, func: Callable[P2, Any], *args: P2.args, **kwargs: P2.kwargs +) -> AsyncResult[Any]: + """run function with deadline + + Args: + timeout: deadline + func: func(sync or async) + + Returns: + async result container + """ + executor = _Executor(timeout, func) + return executor.apply(*args, **kwargs) @overload -def get_executor( - backend: Literal["billiard"] = ..., pickler: PicklerType | None = ... -) -> type[billiard_future.ProcessPoolExecutor]: ... +async def delay_func( + timeout: float, + func: Callable[P2, Coroutine[Any, Any, T2]], + *args: P2.args, + **kwargs: P2.kwargs, +) -> AsyncResult[T2]: ... @overload -def get_executor( - backend: Literal["loky"] = ..., pickler: PicklerType | None = ... -) -> type[loky_future.ProcessPoolExecutor]: ... +async def delay_func( + timeout: float, func: Callable[P2, T2], *args: P2.args, **kwargs: P2.kwargs +) -> AsyncResult[T2]: ... -def get_executor( - backend: BackendType | None = None, pickler: PicklerType | None = None -) -> type[ - billiard_future.ProcessPoolExecutor - | multiprocessing_future.ProcessPoolExecutor - | loky_future.ProcessPoolExecutor -]: - """get pool executor +async def delay_func( + timeout: float, func: Callable[P2, Any], *args: P2.args, **kwargs: P2.kwargs +) -> AsyncResult[Any]: + """run function with deadline Args: - backend: backend type as string. Defaults to None. - pickler: pickler type as string. Defaults to None. 
+ timeout: deadline + func: func(sync or async) Returns: - ProcessPoolExecutor + async result container """ - backend = backend or "multiprocessing" - executor = get_executor_backend(backend) - monkey_patch(backend, pickler) - return executor - - -def _async_run( - func: Callable[..., Any], *args: Any, _timeout: float, **kwargs: Any -) -> Any: - return asyncio.run( - _async_run_with_timeout(func, *args, _timeout=_timeout, **kwargs) - ) - - -async def _async_run_with_timeout( - func: Callable[..., Any], *args: Any, _timeout: float, **kwargs: Any -) -> Any: - send, recv = anyio.create_memory_object_stream() - async with anyio.create_task_group() as task_group: - with anyio.fail_after(_timeout): - async with send: - task_group.start_soon( - partial( - _async_run_with_stream, - func, - *args, - _stream=send.clone(), - **kwargs, - ) - ) - async with recv: - result = await recv.receive() - - return result - - -async def _async_run_with_stream( - func: Callable[..., Any], *args: Any, _stream: ObjectSendStream[Any], **kwargs: Any -) -> None: - async with _stream: - result = await func(*args, **kwargs) - await _stream.send(result) + executor = _Executor(timeout, func) + return await executor.delay(*args, **kwargs) + + +class TimeoutExecutor: + """timeout executor""" + + def __init__(self, timeout: float) -> None: + self._timeout = timeout + + def _create_executor(self, func: Callable[P, T]) -> _Executor[P, T]: + return _Executor(self._timeout, func) + + @overload + def apply( + self, + func: Callable[P, Coroutine[Any, Any, T]], + *args: P.args, + **kwargs: P.kwargs, + ) -> AsyncResult[T]: ... + @overload + def apply( + self, func: Callable[P, T], *args: P.args, **kwargs: P.kwargs + ) -> AsyncResult[T]: ... + def apply( + self, func: Callable[P, Any], *args: P.args, **kwargs: P.kwargs + ) -> AsyncResult[Any]: + """run function with deadline + + Args: + func: func(sync or async) + + Returns: + async result container + """ + return apply_func(self._timeout, func, *args, **kwargs) + + @overload + async def delay( + self, + func: Callable[P, Coroutine[Any, Any, T]], + *args: P.args, + **kwargs: P.kwargs, + ) -> AsyncResult[T]: ... + @overload + async def delay( + self, func: Callable[P, T], *args: P.args, **kwargs: P.kwargs + ) -> AsyncResult[T]: ... + async def delay( + self, func: Callable[P, Any], *args: P.args, **kwargs: P.kwargs + ) -> AsyncResult[Any]: + """run function with deadline + + Args: + func: func(sync or async) + + Returns: + async result container + """ + return await delay_func(self._timeout, func, *args, **kwargs) + + @overload + async def apply_async( + self, + func: Callable[P, Coroutine[Any, Any, T]], + *args: P.args, + **kwargs: P.kwargs, + ) -> AsyncResult[T]: ... + @overload + async def apply_async( + self, func: Callable[P, T], *args: P.args, **kwargs: P.kwargs + ) -> AsyncResult[T]: ... + async def apply_async( + self, func: Callable[P, Any], *args: P.args, **kwargs: P.kwargs + ) -> AsyncResult[Any]: + """run function with deadline. 
+ + alias of `delay` + + Args: + func: func(sync or async) + + Returns: + async result container + """ + return await self.delay(func, *args, **kwargs) diff --git a/src/timeout_executor/log.py b/src/timeout_executor/log.py deleted file mode 100644 index ae52568..0000000 --- a/src/timeout_executor/log.py +++ /dev/null @@ -1,8 +0,0 @@ -from __future__ import annotations - -import logging - -__all__ = ["logger"] - -logger = logging.getLogger("timeout_executor") -logger.setLevel(logging.INFO) diff --git a/src/timeout_executor/readonly.py b/src/timeout_executor/readonly.py deleted file mode 100644 index 56bdbea..0000000 --- a/src/timeout_executor/readonly.py +++ /dev/null @@ -1,33 +0,0 @@ -from __future__ import annotations - -from typing import Any, Generic, NoReturn, final - -from typing_extensions import TypeVar - -ValueT = TypeVar("ValueT", infer_variance=True) - -__all__ = ["ReadOnly"] - - -@final -class ReadOnly(Generic[ValueT]): # noqa: D101 - def __init__(self, value: ValueT) -> None: - self._value = value - - @property - def value(self) -> ValueT: # noqa: D102 - return self._value - - @value.setter - def value(self, value: Any) -> NoReturn: - raise NotImplementedError - - @value.deleter - def value(self) -> NoReturn: - raise NotImplementedError - - def __eq__(self, value: object) -> bool: - return type(value) is type(self.value) and value == self.value # noqa: E721 - - def force_set(self, value: ValueT) -> None: # noqa: D102 - self._value = value diff --git a/src/timeout_executor/result.py b/src/timeout_executor/result.py new file mode 100644 index 0000000..c8ccdd0 --- /dev/null +++ b/src/timeout_executor/result.py @@ -0,0 +1,117 @@ +from __future__ import annotations + +import subprocess +from functools import partial +from typing import TYPE_CHECKING, Any, Generic + +import anyio +import cloudpickle +from async_wrapper import async_to_sync, sync_to_async +from typing_extensions import TypeVar + +from timeout_executor.serde import SerializedError, loads_error + +if TYPE_CHECKING: + from pathlib import Path + +__all__ = ["AsyncResult"] + +T = TypeVar("T", infer_variance=True) + +SENTINEL = object() + + +class AsyncResult(Generic[T]): + """async result container""" + + _result: Any + + def __init__( + self, + process: subprocess.Popen, + input_file: Path | anyio.Path, + output_file: Path | anyio.Path, + timeout: float, + ) -> None: + self._process = process + self._timeout = timeout + self._result = SENTINEL + + if not isinstance(output_file, anyio.Path): + output_file = anyio.Path(output_file) + self._output = output_file + + if not isinstance(input_file, anyio.Path): + input_file = anyio.Path(input_file) + self._input = input_file + + def result(self, timeout: float | None = None) -> T: + """get value sync method""" + future = async_to_sync(self.delay) + return future(timeout) + + async def delay(self, timeout: float | None = None) -> T: + """get value async method""" + try: + return await self._delay(timeout) + finally: + with anyio.CancelScope(shield=True): + self._process.terminate() + + async def _delay(self, timeout: float | None) -> T: + if self._process.returncode is not None: + return await self._load_output() + + if timeout is None: + timeout = self._timeout + + try: + await wait_process(self._process, timeout, self._input) + except subprocess.TimeoutExpired as exc: + raise TimeoutError(exc.timeout) from exc + except TimeoutError as exc: + if not exc.args: + raise TimeoutError(timeout) from exc + raise + + return await self._load_output() + + async def _load_output(self) -> 
T: + if self._result is not SENTINEL: + if isinstance(self._result, SerializedError): + self._result = loads_error(self._result) + if isinstance(self._result, Exception): + raise self._result + return self._result + + if self._process.returncode is None: + raise RuntimeError("process is running") + + if not await self._output.exists(): + raise FileNotFoundError(self._output) + + async with await self._output.open("rb") as file: + value = await file.read() + self._result = cloudpickle.loads(value) + + async with anyio.create_task_group() as task_group: + task_group.start_soon(self._output.unlink, True) # noqa: FBT003 + task_group.start_soon(self._input.unlink, True) # noqa: FBT003 + await self._output.parent.rmdir() + return await self._load_output() + + +async def wait_process( + process: subprocess.Popen, timeout: float, input_file: Path | anyio.Path +) -> None: + wait_func = partial(sync_to_async(process.wait), timeout) + if not isinstance(input_file, anyio.Path): + input_file = anyio.Path(input_file) + + try: + with anyio.fail_after(timeout): + await wait_func() + finally: + with anyio.CancelScope(shield=True): + if process.returncode is not None: + await input_file.unlink(missing_ok=False) diff --git a/src/timeout_executor/serde.py b/src/timeout_executor/serde.py new file mode 100644 index 0000000..cd42ce6 --- /dev/null +++ b/src/timeout_executor/serde.py @@ -0,0 +1,99 @@ +from __future__ import annotations + +from collections import deque +from dataclasses import dataclass +from operator import itemgetter +from types import TracebackType +from typing import Any + +import cloudpickle +from tblib.pickling_support import ( + pickle_exception, + pickle_traceback, + unpickle_exception, + unpickle_traceback, +) + +__all__ = ["dumps_error", "loads_error", "serialize_error", "deserialize_error"] + + +@dataclass(frozen=True) +class SerializedError: + arg_exception: tuple[Any, ...] + arg_tracebacks: tuple[tuple[int, tuple[Any, ...]], ...] + + exception: tuple[Any, ...] + tracebacks: tuple[tuple[int, tuple[Any, ...]], ...] 
+ + +def serialize_traceback(traceback: TracebackType) -> tuple[Any, ...]: + return pickle_traceback(traceback) + + +def serialize_error(error: Exception) -> SerializedError: + """serialize exception""" + exception = pickle_exception(error)[1:] + + exception_args, exception = exception[0], exception[1:] + + arg_result: deque[Any] = deque() + arg_tracebacks: deque[tuple[int, tuple[Any, ...]]] = deque() + + exception_result: deque[Any] = deque() + tracebacks: deque[tuple[int, tuple[Any, ...]]] = deque() + + for result, tb_result, args in zip( + (arg_result, exception_result), + (arg_tracebacks, tracebacks), + (exception_args, exception), + ): + for index, value in enumerate(args): + if not isinstance(value, TracebackType): + result.append(value) + continue + new = serialize_traceback(value)[1] + tb_result.append((index, new)) + + return SerializedError( + arg_exception=tuple(arg_result), + arg_tracebacks=tuple(arg_tracebacks), + exception=tuple(exception_result), + tracebacks=tuple(tracebacks), + ) + + +def deserialize_error(error: SerializedError) -> Exception: + """deserialize exception""" + arg_exception: deque[Any] = deque(error.arg_exception) + arg_tracebacks: deque[tuple[int, tuple[Any, ...]]] = deque(error.arg_tracebacks) + + exception: deque[Any] = deque(error.exception) + tracebacks: deque[tuple[int, tuple[Any, ...]]] = deque(error.tracebacks) + + for result, tb_result in zip( + (arg_exception, exception), (arg_tracebacks, tracebacks) + ): + for salt, (index, value) in enumerate(sorted(tb_result, key=itemgetter(0))): + traceback = unpickle_traceback(*value) + result.insert(index + salt, traceback) + + return unpickle_exception(*arg_exception, *exception) + + +def dumps_error(error: Exception | SerializedError) -> bytes: + """serialize exception as bytes""" + if not isinstance(error, SerializedError): + error = serialize_error(error) + + return cloudpickle.dumps(error) + + +def loads_error(error: bytes | SerializedError) -> Exception: + """deserialize exception from bytes""" + if isinstance(error, bytes): + error = cloudpickle.loads(error) + if not isinstance(error, SerializedError): + error_msg = f"error is not SerializedError object: {type(error).__name__}" + raise TypeError(error_msg) + + return deserialize_error(error) diff --git a/src/timeout_executor/serde/__init__.py b/src/timeout_executor/serde/__init__.py deleted file mode 100644 index c68ff9d..0000000 --- a/src/timeout_executor/serde/__init__.py +++ /dev/null @@ -1,5 +0,0 @@ -from __future__ import annotations - -from .main import monkey_patch - -__all__ = ["monkey_patch"] diff --git a/src/timeout_executor/serde/backend/__init__.py b/src/timeout_executor/serde/backend/__init__.py deleted file mode 100644 index e69de29..0000000 diff --git a/src/timeout_executor/serde/backend/_billiard/__init__.py b/src/timeout_executor/serde/backend/_billiard/__init__.py deleted file mode 100644 index 150de4a..0000000 --- a/src/timeout_executor/serde/backend/_billiard/__init__.py +++ /dev/null @@ -1,6 +0,0 @@ -from __future__ import annotations - -from .const import order, replace, unpatch -from .patch import monkey_patch, monkey_unpatch - -__all__ = ["monkey_patch", "monkey_unpatch", "replace", "unpatch", "order"] diff --git a/src/timeout_executor/serde/backend/_billiard/const.py b/src/timeout_executor/serde/backend/_billiard/const.py deleted file mode 100644 index 5a846e5..0000000 --- a/src/timeout_executor/serde/backend/_billiard/const.py +++ /dev/null @@ -1,11 +0,0 @@ -from __future__ import annotations - -from typing import 
TYPE_CHECKING - -if TYPE_CHECKING: - from timeout_executor.serde.main import PicklerType - - -unpatch: frozenset[PicklerType] = frozenset({"pickle"}) -replace: dict[PicklerType, PicklerType] = {"pickle": "dill"} -order: tuple[PicklerType, ...] = ("dill", "cloudpickle") diff --git a/src/timeout_executor/serde/backend/_billiard/patch.py b/src/timeout_executor/serde/backend/_billiard/patch.py deleted file mode 100644 index be0e73b..0000000 --- a/src/timeout_executor/serde/backend/_billiard/patch.py +++ /dev/null @@ -1,91 +0,0 @@ -from __future__ import annotations - -from typing import TYPE_CHECKING, Any, Callable, Final - -from timeout_executor.exception import ExtraError -from timeout_executor.readonly import ReadOnly - -if TYPE_CHECKING: - from timeout_executor.serde.base import Pickler - -__all__ = ["monkey_patch", "monkey_unpatch"] - -billiard_origin: ReadOnly[type[Pickler]] = ReadOnly(None) # type: ignore -billiard_origin_status: Final[str] = "billiard" -billiard_status: ReadOnly[str] = ReadOnly(billiard_origin_status) - - -def monkey_patch(name: str, pickler: type[Pickler]) -> None: - """patch billiard""" - from timeout_executor.serde.lock import patch_lock - - with patch_lock: - if billiard_status == name: - return - - _set_origin() - try: - from billiard import ( # type: ignore - connection, # type: ignore - queues, # type: ignore - reduction, # type: ignore - sharedctypes, # type: ignore - ) - except ImportError as exc: - error = ExtraError.from_import_error(exc, extra="billiard") - raise error from exc - - origin_register: dict[type[Any], Callable[[Any], Any]] = ( - reduction.ForkingPickler._extra_reducers # noqa: SLF001 # pyright: ignore[reportAttributeAccessIssue] - ) # type: ignore - reduction.ForkingPickler = pickler - reduction.register = pickler.register - pickler._extra_reducers.update(origin_register) # noqa: SLF001 - queues.ForkingPickler = pickler - connection.ForkingPickler = pickler - sharedctypes.ForkingPickler = pickler - - billiard_status.force_set(name) - - -def monkey_unpatch() -> None: - """unpatch billiard""" - from timeout_executor.serde.lock import patch_lock - - with patch_lock: - try: - from billiard import ( # type: ignore - connection, # type: ignore - queues, # type: ignore - reduction, # type: ignore - sharedctypes, # type: ignore - ) - except ImportError as exc: - error = ExtraError.from_import_error(exc, extra="billiard") - raise error from exc - - if billiard_status == billiard_origin_status: - return - if billiard_origin.value is None: - raise RuntimeError("origin is None") - - reduction.ForkingPickler = billiard_origin.value - reduction.register = billiard_origin.value.register - queues.ForkingPickler = billiard_origin.value - connection.ForkingPickler = billiard_origin.value - sharedctypes.ForkingPickler = billiard_origin.value - - billiard_status.force_set(billiard_origin_status) - - -def _set_origin() -> None: - if billiard_origin.value is not None: - return - - try: - from billiard.reduction import ForkingPickler - except ImportError as exc: - error = ExtraError.from_import_error(exc, extra="billiard") - raise error from exc - - billiard_origin.force_set(ForkingPickler) # type: ignore diff --git a/src/timeout_executor/serde/backend/_loky/__init__.py b/src/timeout_executor/serde/backend/_loky/__init__.py deleted file mode 100644 index 150de4a..0000000 --- a/src/timeout_executor/serde/backend/_loky/__init__.py +++ /dev/null @@ -1,6 +0,0 @@ -from __future__ import annotations - -from .const import order, replace, unpatch -from .patch import 
monkey_patch, monkey_unpatch - -__all__ = ["monkey_patch", "monkey_unpatch", "replace", "unpatch", "order"] diff --git a/src/timeout_executor/serde/backend/_loky/const.py b/src/timeout_executor/serde/backend/_loky/const.py deleted file mode 100644 index 38c44b8..0000000 --- a/src/timeout_executor/serde/backend/_loky/const.py +++ /dev/null @@ -1,11 +0,0 @@ -from __future__ import annotations - -from typing import TYPE_CHECKING - -if TYPE_CHECKING: - from timeout_executor.serde.main import PicklerType - - -unpatch: frozenset[PicklerType] = frozenset({"pickle", "cloudpickle"}) -replace: dict[PicklerType, PicklerType] = {"pickle": "cloudpickle"} -order: tuple[PicklerType, ...] = ("cloudpickle", "dill") diff --git a/src/timeout_executor/serde/backend/_loky/patch.py b/src/timeout_executor/serde/backend/_loky/patch.py deleted file mode 100644 index 79cfd06..0000000 --- a/src/timeout_executor/serde/backend/_loky/patch.py +++ /dev/null @@ -1,52 +0,0 @@ -from __future__ import annotations - -from typing import TYPE_CHECKING - -from timeout_executor.exception import ExtraError -from timeout_executor.log import logger - -if TYPE_CHECKING: - from timeout_executor.serde.base import Pickler - -__all__ = ["monkey_patch", "monkey_unpatch"] - - -def monkey_patch(name: str, pickler: type[Pickler]) -> None: # noqa: ARG001 - """patch loky""" - from timeout_executor.serde.lock import patch_lock - - with patch_lock: - if name == "pickle": - logger.warning("loky uses cloudpickle as the default") - name = "cloudpickle" - try: - from loky.backend.reduction import ( # type: ignore - get_loky_pickler_name, # type: ignore - set_loky_pickler, # type: ignore - ) - except ImportError as exc: - error = ExtraError.from_import_error(exc, extra="loky") - raise error from exc - - if get_loky_pickler_name() == name: - return - set_loky_pickler(name) - - -def monkey_unpatch() -> None: - """unpatch loky""" - from timeout_executor.serde.lock import patch_lock - - with patch_lock: - try: - from loky.backend.reduction import ( # type: ignore - get_loky_pickler_name, # type: ignore - set_loky_pickler, # type: ignore - ) - except ImportError as exc: - error = ExtraError.from_import_error(exc, extra="loky") - raise error from exc - - if get_loky_pickler_name == "cloudpickle": - return - set_loky_pickler() diff --git a/src/timeout_executor/serde/backend/_multiprocessing/__init__.py b/src/timeout_executor/serde/backend/_multiprocessing/__init__.py deleted file mode 100644 index 150de4a..0000000 --- a/src/timeout_executor/serde/backend/_multiprocessing/__init__.py +++ /dev/null @@ -1,6 +0,0 @@ -from __future__ import annotations - -from .const import order, replace, unpatch -from .patch import monkey_patch, monkey_unpatch - -__all__ = ["monkey_patch", "monkey_unpatch", "replace", "unpatch", "order"] diff --git a/src/timeout_executor/serde/backend/_multiprocessing/const.py b/src/timeout_executor/serde/backend/_multiprocessing/const.py deleted file mode 100644 index 062bcea..0000000 --- a/src/timeout_executor/serde/backend/_multiprocessing/const.py +++ /dev/null @@ -1,11 +0,0 @@ -from __future__ import annotations - -from typing import TYPE_CHECKING - -if TYPE_CHECKING: - from timeout_executor.serde.main import PicklerType - - -unpatch: frozenset[PicklerType] = frozenset({"pickle"}) -replace: dict[PicklerType, PicklerType] = {} -order: tuple[PicklerType, ...] 
= ("dill", "cloudpickle", "pickle") diff --git a/src/timeout_executor/serde/backend/_multiprocessing/patch.py b/src/timeout_executor/serde/backend/_multiprocessing/patch.py deleted file mode 100644 index a3413ba..0000000 --- a/src/timeout_executor/serde/backend/_multiprocessing/patch.py +++ /dev/null @@ -1,78 +0,0 @@ -from __future__ import annotations - -from typing import TYPE_CHECKING, Any, Callable, Final - -from timeout_executor.readonly import ReadOnly - -if TYPE_CHECKING: - from timeout_executor.serde.base import Pickler - -__all__ = ["monkey_patch", "monkey_unpatch"] - -multiprocessing_origin: ReadOnly[type[Pickler]] = ReadOnly(None) # type: ignore -multiprocessing_origin_status: Final[str] = "multiprocessing" -multiprocessing_status: ReadOnly[str] = ReadOnly(multiprocessing_origin_status) - - -def monkey_patch(name: str, pickler: type[Pickler]) -> None: - """patch multiprocessing""" - from timeout_executor.serde.lock import patch_lock - - with patch_lock: - if multiprocessing_status == name: - return - - _set_origin() - from multiprocessing import connection, queues, reduction, sharedctypes - - origin_register: dict[type[Any], Callable[[Any], Any]] = ( - reduction.ForkingPickler._extra_reducers # noqa: SLF001 # pyright: ignore[reportAttributeAccessIssue] - ) # type: ignore - reduction.ForkingPickler = pickler - reduction.register = pickler.register # type: ignore - pickler._extra_reducers.update(origin_register) # noqa: SLF001 - reduction.AbstractReducer.ForkingPickler = pickler # type: ignore - queues._ForkingPickler = pickler # noqa: SLF001 # type: ignore - connection._ForkingPickler = pickler # noqa: SLF001 # type: ignore - sharedctypes._ForkingPickler = pickler # noqa: SLF001 # type: ignore - - multiprocessing_status.force_set(name) - - -def monkey_unpatch() -> None: - """unpatch multiprocessing""" - from timeout_executor.serde.lock import patch_lock - - with patch_lock: - from multiprocessing import connection, queues, reduction, sharedctypes - - if multiprocessing_status == multiprocessing_origin_status: - return - if multiprocessing_origin.value is None: - raise RuntimeError("origin is None") - - reduction.ForkingPickler = multiprocessing_origin.value - reduction.register = multiprocessing_origin.value.register - reduction.AbstractReducer.ForkingPickler = ( # type: ignore - multiprocessing_origin.value - ) - queues._ForkingPickler = ( # noqa: SLF001 # type: ignore - multiprocessing_origin.value - ) - connection._ForkingPickler = ( # noqa: SLF001 # type: ignore - multiprocessing_origin.value - ) - sharedctypes._ForkingPickler = ( # noqa: SLF001 # type: ignore - multiprocessing_origin.value - ) - - multiprocessing_status.force_set(multiprocessing_origin_status) - - -def _set_origin() -> None: - if multiprocessing_origin.value is not None: - return - - from multiprocessing.reduction import ForkingPickler - - multiprocessing_origin.force_set(ForkingPickler) # type: ignore diff --git a/src/timeout_executor/serde/base.py b/src/timeout_executor/serde/base.py deleted file mode 100644 index 6b7f428..0000000 --- a/src/timeout_executor/serde/base.py +++ /dev/null @@ -1,65 +0,0 @@ -from __future__ import annotations - -from typing import ( - TYPE_CHECKING, - Any, - Callable, - ClassVar, - Protocol, - Type, - runtime_checkable, -) - -from typing_extensions import TypeVar - -if TYPE_CHECKING: - import io - from types import ModuleType - - from timeout_executor.serde.main import PicklerType - - class BackendModule(ModuleType): # noqa: D101 - unpatch: frozenset[PicklerType] - replace: 
dict[PicklerType, PicklerType] - order: tuple[PicklerType] - - monkey_patch: Monkey - monkey_unpatch: UnMonkey - - class PicklerModule(ModuleType): # noqa: D101 - Pickler: type[Pickler] - - -ValueT = TypeVar("ValueT", infer_variance=True) - -__all__ = ["Pickler", "Monkey", "UnMonkey", "BackendModule", "PicklerModule"] - - -@runtime_checkable -class Pickler(Protocol): # noqa: D101 - _extra_reducers: ClassVar[dict[type[Any], Callable[[Any], Any]]] - _copyreg_dispatch_table: ClassVar[dict[type[Any], Callable[[Any], Any]]] - - @classmethod - def register( - cls, - type: type[ValueT], # noqa: A002 - reduce: Callable[[ValueT], Any], - ) -> None: - """Register a reduce function for a type.""" - - @classmethod - def dumps( # noqa: D102 - cls, obj: Any, protocol: int | None = None - ) -> memoryview: ... - - @classmethod - def loadbuf( # noqa: D102 - cls, buf: io.BytesIO, protocol: int | None = None - ) -> Any: ... - - loads: Callable[..., Any] - - -Monkey = Callable[[str, Type[Pickler]], None] -UnMonkey = Callable[[], None] diff --git a/src/timeout_executor/serde/lock.py b/src/timeout_executor/serde/lock.py deleted file mode 100644 index 7b5a8b6..0000000 --- a/src/timeout_executor/serde/lock.py +++ /dev/null @@ -1,7 +0,0 @@ -from __future__ import annotations - -from threading import RLock - -__all__ = ["patch_lock"] - -patch_lock = RLock() diff --git a/src/timeout_executor/serde/main.py b/src/timeout_executor/serde/main.py deleted file mode 100644 index 18e726e..0000000 --- a/src/timeout_executor/serde/main.py +++ /dev/null @@ -1,100 +0,0 @@ -from __future__ import annotations - -from importlib import import_module -from importlib.util import find_spec -from typing import TYPE_CHECKING, Literal - -from timeout_executor.exception import ImportErrors -from timeout_executor.log import logger - -if TYPE_CHECKING: - from timeout_executor.concurrent.main import BackendType - from timeout_executor.serde.base import BackendModule, PicklerModule - -__all__ = ["monkey_patch"] - -PicklerType = Literal["pickle", "dill", "cloudpickle"] - - -def monkey_patch(backend: BackendType, pickler: PicklerType | None) -> None: - """monkey patch or unpatch""" - backend_module = _import_backend(backend) - pickler, pickler_module = _try_import_pickler(backend, backend_module, pickler) - if pickler_module is None: - logger.debug("backend: %r, %r will be set to the default.", backend, pickler) - logger.debug("backend: %r: unpatch", backend) - backend_module.monkey_unpatch() - return - logger.debug("backend: %r, pickler: %r: patch", backend, pickler) - backend_module.monkey_patch(pickler, pickler_module.Pickler) - - -def _import_backend(backend: BackendType) -> BackendModule: - name = f".backend._{backend}" - spec = find_spec(name, __package__) - if spec is None: - error_msg = f"invalid backend: {backend}" - raise ImportError(error_msg) - return import_module(name, __package__) # type: ignore - - -def _import_pickler(pickler: PicklerType) -> PicklerModule: - name = f".pickler._{pickler}" - spec = find_spec(name, __package__) - if spec is None: - error_msg = f"invalid pickler: {pickler}" - raise ImportError(error_msg) - return import_module(name, __package__) # type: ignore - - -def _validate_pickler( - backend_name: BackendType, backend: BackendModule, pickler: PicklerType | None -) -> PicklerType: - if not pickler: - logger.debug( - "backend: %r, pickler is not specified. 
use default: %r.", - backend_name, - backend.order[0], - ) - pickler = backend.order[0] - if pickler in backend.replace: - logger.debug( - "backend: %r, %r is replaced by %r.", - backend_name, - pickler, - backend.replace[pickler], - ) - pickler = backend.replace[pickler] - return pickler - - -def _try_import_pickler( - backend_name: BackendType, backend: BackendModule, pickler: PicklerType | None -) -> tuple[PicklerType, PicklerModule | None]: - pickler = _validate_pickler(backend_name, backend, pickler) - if pickler in backend.unpatch: - return pickler, None - - try: - pickler_idx = backend.order.index(pickler) - except ValueError: - error_msg = f"invalid pickler: {pickler}" - raise ImportError(error_msg) # noqa: TRY200,B904 - - pickler_queue: tuple[PicklerType, ...] - if pickler_idx + 1 < len(backend.order): - pickler_queue = backend.order[pickler_idx + 1 :] - else: - pickler_queue = () - - errors: tuple[ImportError, ...] = () - for sub_pickler in (pickler, *pickler_queue): - try: - pickler_module = _import_pickler(sub_pickler) - except ImportError as exc: # noqa: PERF203 - errors = (*errors, exc) - else: - return sub_pickler, pickler_module - - error_msg = "failed import pickler modules" - raise ImportErrors(error_msg, errors) diff --git a/src/timeout_executor/serde/pickler/__init__.py b/src/timeout_executor/serde/pickler/__init__.py deleted file mode 100644 index e69de29..0000000 diff --git a/src/timeout_executor/serde/pickler/_cloudpickle.py b/src/timeout_executor/serde/pickler/_cloudpickle.py deleted file mode 100644 index 31ebfb1..0000000 --- a/src/timeout_executor/serde/pickler/_cloudpickle.py +++ /dev/null @@ -1,60 +0,0 @@ -from __future__ import annotations - -import copyreg -import io -from typing import Any, Callable, ClassVar - -from typing_extensions import TypeVar - -from timeout_executor.exception import ExtraError -from timeout_executor.serde.base import Pickler as BasePickler - -try: - import cloudpickle # type: ignore -except ImportError as exc: - error = ExtraError.from_import_error(exc, extra="cloudpickle") - raise error from exc - -ValueT = TypeVar("ValueT", infer_variance=True) - -__all__ = ["Pickler"] - - -class Pickler(cloudpickle.Pickler): - _extra_reducers: ClassVar[dict[type[Any], Callable[[Any], Any]]] = {} - _copyreg_dispatch_table = copyreg.dispatch_table - - def __init__(self, *args: Any) -> None: - super().__init__(*args) - self.dispatch_table = self._copyreg_dispatch_table.copy() - self.dispatch_table.update(self._extra_reducers) - - @classmethod - def register( - cls, - type: type[ValueT], # noqa: A002 - reduce: Callable[[ValueT], Any], - ) -> None: - """Register a reduce function for a type.""" - cls._extra_reducers[type] = reduce - - @classmethod - def dumps(cls, obj: Any, protocol: int | None = None) -> memoryview: - buf = io.BytesIO() - cls(buf, protocol).dump(obj) - return buf.getbuffer() - - @classmethod - def loadbuf( - cls, - buf: io.BytesIO, - protocol: int | None = None, # noqa: ARG003 - ) -> Any: - return cls.loads(buf.getbuffer()) # type: ignore - - loads = cloudpickle.loads - - -if not isinstance(Pickler, BasePickler): - error_msg = f"{__name__}.Pickler is not Pickler type" - raise TypeError(error_msg) diff --git a/src/timeout_executor/serde/pickler/_dill.py b/src/timeout_executor/serde/pickler/_dill.py deleted file mode 100644 index 1ce8715..0000000 --- a/src/timeout_executor/serde/pickler/_dill.py +++ /dev/null @@ -1,60 +0,0 @@ -from __future__ import annotations - -import copyreg -import io -from typing import Any, Callable, ClassVar - 
-from typing_extensions import TypeVar - -from timeout_executor.exception import ExtraError -from timeout_executor.serde.base import Pickler as BasePickler - -try: - import dill # type: ignore -except ImportError as exc: - error = ExtraError.from_import_error(exc, extra="dill") - raise error from exc - -ValueT = TypeVar("ValueT", infer_variance=True) - -__all__ = ["Pickler"] - - -class Pickler(dill.Pickler): - _extra_reducers: ClassVar[dict[type[Any], Callable[[Any], Any]]] = {} - _copyreg_dispatch_table = copyreg.dispatch_table - - def __init__(self, *args: Any) -> None: - super().__init__(*args) - self.dispatch_table = self._copyreg_dispatch_table.copy() - self.dispatch_table.update(self._extra_reducers) - - @classmethod - def register( - cls, - type: type[ValueT], # noqa: A002 - reduce: Callable[[ValueT], Any], - ) -> None: - """Register a reduce function for a type.""" - cls._extra_reducers[type] = reduce - - @classmethod - def dumps(cls, obj: Any, protocol: int | None = None) -> memoryview: - buf = io.BytesIO() - cls(buf, protocol).dump(obj) - return buf.getbuffer() - - @classmethod - def loadbuf( - cls, - buf: io.BytesIO, - protocol: int | None = None, # noqa: ARG003 - ) -> Any: - return cls.loads(buf.getbuffer()) - - loads = dill.loads - - -if not isinstance(Pickler, BasePickler): - error_msg = f"{__name__}.Pickler is not Pickler type" - raise TypeError(error_msg) diff --git a/src/timeout_executor/subprocess.py b/src/timeout_executor/subprocess.py new file mode 100644 index 0000000..d604f27 --- /dev/null +++ b/src/timeout_executor/subprocess.py @@ -0,0 +1,103 @@ +"""only using in subprocess""" + +from __future__ import annotations + +from functools import wraps +from inspect import iscoroutinefunction +from os import environ +from pathlib import Path +from typing import Any, Callable, Coroutine + +import anyio +import cloudpickle +from async_wrapper import async_to_sync +from typing_extensions import ParamSpec, TypeVar + +from timeout_executor.const import TIMEOUT_EXECUTOR_INPUT_FILE +from timeout_executor.serde import dumps_error + +__all__ = [] + +P = ParamSpec("P") +T = TypeVar("T", infer_variance=True) + + +def run_in_subprocess() -> None: + input_file = Path(environ.get(TIMEOUT_EXECUTOR_INPUT_FILE, "")) + with input_file.open("rb") as file_io: + func, args, kwargs, output_file = cloudpickle.load(file_io) + new_func = output_to_file(output_file)(func) + + if iscoroutinefunction(new_func): + new_func = async_to_sync(new_func) + + new_func(*args, **kwargs) + + +def output_to_file( + file: Path | anyio.Path, +) -> Callable[[Callable[P, T]], Callable[P, T]]: + def wrapper(func: Callable[P, Any]) -> Callable[P, Any]: + if iscoroutinefunction(func): + return _output_to_file_async(file)(func) + return _output_to_file_sync(file)(func) + + return wrapper + + +def _output_to_file_sync( + file: Path | anyio.Path, +) -> Callable[[Callable[P, T]], Callable[P, T]]: + if isinstance(file, anyio.Path): + file = file._path # noqa: SLF001 + + def wrapper(func: Callable[P, T]) -> Callable[P, T]: + @wraps(func) + def inner(*args: P.args, **kwargs: P.kwargs) -> T: + dump = b"" + try: + result = func(*args, **kwargs) + except Exception as exc: + dump = dumps_error(exc) + raise + else: + dump = cloudpickle.dumps(result) + return result + finally: + with file.open("wb+") as file_io: + file_io.write(dump) + + return inner + + return wrapper + + +def _output_to_file_async( + file: Path | anyio.Path, +) -> Callable[ + [Callable[P, Coroutine[Any, Any, T]]], Callable[P, Coroutine[Any, Any, T]] +]: + if 
isinstance(file, Path): + file = anyio.Path(file) + + def wrapper( + func: Callable[P, Coroutine[Any, Any, T]], + ) -> Callable[P, Coroutine[Any, Any, T]]: + @wraps(func) + async def inner(*args: P.args, **kwargs: P.kwargs) -> T: + dump = b"" + try: + result = await func(*args, **kwargs) + except Exception as exc: + dump = dumps_error(exc) + raise + else: + dump = cloudpickle.dumps(result) + return result + finally: + async with await file.open("wb+") as file_io: + await file_io.write(dump) + + return inner + + return wrapper diff --git a/tests/conftest.py b/tests/conftest.py index c788583..c9df477 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -1,13 +1,18 @@ from __future__ import annotations -import asyncio +from typing import Any import pytest -@pytest.fixture(scope="session") -def event_loop(): - policy = asyncio.get_event_loop_policy() - loop = policy.new_event_loop() - yield loop - loop.close() +@pytest.fixture( + params=[ + pytest.param(("asyncio", {"use_uvloop": False}), id="asyncio"), + pytest.param(("asyncio", {"use_uvloop": True}), id="asyncio-uvloop"), + pytest.param( + ("trio", {"restrict_keyboard_interrupt_to_checkpoints": True}), id="trio" + ), + ] +) +def anyio_backend(request) -> tuple[str, dict[str, Any]]: + return request.param diff --git a/tests/test_executor.py b/tests/test_executor.py index 7b3c374..4900828 100644 --- a/tests/test_executor.py +++ b/tests/test_executor.py @@ -1,24 +1,14 @@ from __future__ import annotations import asyncio -import sys import time -from collections import deque -from functools import partial -from itertools import combinations, product -from pickle import PicklingError -from typing import Any, Callable +from itertools import product +from typing import Any import anyio import pytest -from anyio.abc import ObjectSendStream -from timeout_executor import TimeoutExecutor -from timeout_executor.concurrent.main import BackendType -from timeout_executor.serde.main import PicklerType - -if sys.version_info < (3, 11): - from exceptiongroup import ExceptionGroup # type: ignore +from timeout_executor import AsyncResult, TimeoutExecutor TEST_SIZE = 3 @@ -28,6 +18,8 @@ class TestExecutorSync: def test_apply_args(self, x: int, y: int): executor = TimeoutExecutor(1) result = executor.apply(sample_func, x, y) + assert isinstance(result, AsyncResult) + result = result.result() assert isinstance(result, tuple) assert len(result) == 2 assert not result[1] @@ -38,56 +30,50 @@ def test_apply_args(self, x: int, y: int): def test_apply_kwargs(self, *, x: int, y: int): executor = TimeoutExecutor(1) result = executor.apply(sample_func, x=x, y=y) + assert isinstance(result, AsyncResult) + result = result.result() assert isinstance(result, tuple) assert len(result) == 2 assert not result[0] assert isinstance(result[1], dict) assert result[1] == {"x": x, "y": y} - @pytest.mark.parametrize("x", range(TEST_SIZE)) - def test_apply_init(self, x: int): + def test_apply_timeout(self): executor = TimeoutExecutor(1) - name = f"x_{x}" - alter_left, alter_right = f"x_{x - 1}", f"x_{x + 1}" - executor.set_init(sample_init_set, **{name: x}) - result = executor.apply(sample_init_get, name, alter_left, alter_right) - assert isinstance(result, dict) - assert result - assert result.get(name) == str(x) - assert result.get(alter_left) is None - assert result.get(alter_right) is None + result = executor.apply(time.sleep, 1.5) + assert isinstance(result, AsyncResult) + pytest.raises(TimeoutError, result.result) - def test_apply_timeout(self): + @pytest.mark.parametrize("x", 
range(TEST_SIZE)) + def test_apply_lambda(self, x: int): executor = TimeoutExecutor(1) - executor.apply(time.sleep, 0.5) - pytest.raises(TimeoutError, executor.apply, time.sleep, 1.5) - - @pytest.mark.parametrize( - ("backend", "pickler", "x"), - product( # pyright: ignore[reportCallIssue] - ("billiard", "multiprocessing", "loky"), - ("dill", "cloudpickle"), - range(TEST_SIZE), # pyright: ignore[reportArgumentType] - ), - ) - def test_apply_lambda(self, backend: BackendType, pickler: PicklerType, x: int): - executor = TimeoutExecutor(1, backend, pickler=pickler) result = executor.apply(lambda: x) + assert isinstance(result, AsyncResult) + result = result.result() assert isinstance(result, int) assert result == x @pytest.mark.parametrize("x", range(TEST_SIZE)) def test_apply_lambda_error(self, x: int): - executor = TimeoutExecutor(1, backend="multiprocessing", pickler="pickle") - pytest.raises((PicklingError, AttributeError), executor.apply, lambda: x) + executor = TimeoutExecutor(1) + + def temp_func(x: int) -> None: + raise RuntimeError(x) + lambda_func = lambda: temp_func(x) # noqa: E731 + result = executor.apply(lambda_func) + assert isinstance(result, AsyncResult) + pytest.raises(RuntimeError, result.result) -@pytest.mark.asyncio() + +@pytest.mark.anyio() class TestExecutorAsync: @pytest.mark.parametrize(("x", "y"), product(range(TEST_SIZE), range(TEST_SIZE))) async def test_apply_args(self, x: int, y: int): executor = TimeoutExecutor(1) - result = await executor.apply_async(sample_async_func, x, y) + result = await executor.delay(sample_async_func, x, y) + assert isinstance(result, AsyncResult) + result = await result.delay() assert isinstance(result, tuple) assert len(result) == 2 assert not result[1] @@ -97,116 +83,51 @@ async def test_apply_args(self, x: int, y: int): @pytest.mark.parametrize(("x", "y"), product(range(TEST_SIZE), range(TEST_SIZE))) async def test_apply_kwargs(self, *, x: int, y: int): executor = TimeoutExecutor(1) - result = await executor.apply_async(sample_async_func, x=x, y=y) + result = await executor.delay(sample_async_func, x=x, y=y) + assert isinstance(result, AsyncResult) + result = await result.delay() assert isinstance(result, tuple) assert len(result) == 2 assert not result[0] assert isinstance(result[1], dict) assert result[1] == {"x": x, "y": y} - @pytest.mark.parametrize(("x", "y"), combinations(range(TEST_SIZE * 2), 2)) - async def test_apply_gather(self, *, x: int, y: int): - executor = TimeoutExecutor(1) - - send, recv = anyio.create_memory_object_stream() - result = deque() - with executor.lock: - async with anyio.create_task_group() as task_group: - async with send: - task_group.start_soon( - send_result, - send.clone(), - executor.apply_async, - partial(sample_async_func, y=y), - x, - ) - task_group.start_soon( - send_result, - send.clone(), - executor.apply_async, - partial(sample_async_func, x=x), - y, - ) - async with recv: - async for value in recv: - result.append(value) - - assert len(result) == 2 - for value in result: - assert isinstance(value, tuple) - assert len(value) == 2 - assert isinstance(value[0], tuple) - assert isinstance(value[1], dict) - assert value[0] - assert value[1] - if value[0][0] == x: - assert "y" in value[1] - assert value[1]["y"] == y - else: - assert "x" in value[1] - assert value[1]["x"] == x - - @pytest.mark.parametrize("x", range(TEST_SIZE)) - async def test_apply_init(self, x: int): - executor = TimeoutExecutor(1) - name = f"x_{x}" - alter_left, alter_right = f"x_{x - 1}", f"x_{x + 1}" - 
executor.set_init(sample_init_set, **{name: x}) - result = await executor.apply_async( - sample_init_async_get, name, alter_left, alter_right - ) - assert isinstance(result, dict) - assert result - assert result.get(name) == str(x) - assert result.get(alter_left) is None - assert result.get(alter_right) is None - async def test_apply_timeout(self): executor = TimeoutExecutor(1) - await executor.apply_async(asyncio.sleep, 0.5) + result = await executor.delay(anyio.sleep, 1.5) try: - await executor.apply_async(asyncio.sleep, 1.5) - except ExceptionGroup as exc_group: - for exc in exc_group.exceptions: - assert isinstance(exc, TimeoutError) + await result.delay(0.1) except Exception as exc: # noqa: BLE001 assert isinstance(exc, TimeoutError) # noqa: PT017 else: raise Exception("TimeoutError does not occur") # noqa: TRY002 - @pytest.mark.parametrize( - ("backend", "pickler", "x"), - product( # pyright: ignore[reportCallIssue] - ("billiard", "multiprocessing", "loky"), - ("dill", "cloudpickle"), - range(TEST_SIZE), # pyright: ignore[reportArgumentType] - ), - ) - async def test_apply_lambda( - self, backend: BackendType, pickler: PicklerType, x: int - ): - executor = TimeoutExecutor(1, backend, pickler=pickler) + @pytest.mark.parametrize("x", range(TEST_SIZE)) + async def test_apply_lambda(self, x: int): + executor = TimeoutExecutor(1) async def lambdalike() -> int: await asyncio.sleep(0.1) return x - result = await executor.apply_async(lambdalike) + result = await executor.delay(lambdalike) + assert isinstance(result, AsyncResult) + result = await result.delay() assert isinstance(result, int) assert result == x - @pytest.mark.parametrize("x", range(TEST_SIZE)) - async def test_apply_lambda_error(self, x: int): - executor = TimeoutExecutor(1, backend="multiprocessing", pickler="pickle") + async def test_apply_lambda_error(self): + executor = TimeoutExecutor(1) async def lambdalike() -> int: - await asyncio.sleep(0.1) - return x + await asyncio.sleep(10) + raise RuntimeError("error") + result = await executor.delay(lambdalike) try: - await executor.apply_async(lambdalike) + await result.delay(0.1) except Exception as exc: # noqa: BLE001 - assert isinstance(exc, (PicklingError, AttributeError)) # noqa: PT017 + assert isinstance(exc, TimeoutError) # noqa: PT017 else: raise Exception("PicklingError does not occur") # noqa: TRY002 @@ -220,29 +141,3 @@ async def sample_async_func( ) -> tuple[tuple[Any, ...], dict[str, Any]]: await asyncio.sleep(0.1) return sample_func(*args, **kwargs) - - -def sample_init_set(**kwargs: Any) -> None: - from os import environ - - for key, value in kwargs.items(): - environ.setdefault(key, str(value)) - - -def sample_init_get(*names: str) -> dict[str, Any]: - from os import environ - - return {name: environ.get(name, None) for name in names} - - -async def sample_init_async_get(*names: str) -> dict[str, Any]: - await asyncio.sleep(0.1) - return sample_init_get(*names) - - -async def send_result( - stream: ObjectSendStream[Any], func: Callable[..., Any], *args: Any, **kwargs: Any -) -> None: - async with stream: - result = await func(*args, **kwargs) - await stream.send(result)
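
A minimal sketch (not part of the diff) of the async path introduced above: `TimeoutExecutor.delay()` pickles the call with cloudpickle, hands it to a helper subprocess, and `AsyncResult.delay()` awaits the output file. The function `slow_add` and the timings are illustrative assumptions, not code from this repository; the sketch assumes the package from this diff is installed and an anyio-compatible event loop (asyncio or trio) is available.

```python
from __future__ import annotations

import anyio

from timeout_executor import AsyncResult, TimeoutExecutor


async def slow_add(x: int, y: int) -> int:
    # executed in the helper subprocess, so it must be picklable by cloudpickle
    await anyio.sleep(0.1)
    return x + y


async def main() -> None:
    executor = TimeoutExecutor(1)

    # delay() writes the pickled call to a temp file and starts the subprocess
    result: AsyncResult[int] = await executor.delay(slow_add, 1, 2)
    assert await result.delay() == 3

    # a call that outlives the deadline surfaces as TimeoutError when collected
    too_slow = await executor.delay(anyio.sleep, 10)
    try:
        await too_slow.delay(0.5)
    except TimeoutError:
        pass  # expected: the subprocess is terminated and the error re-raised


anyio.run(main)
```

The same container backs the sync path shown in the tests: `AsyncResult.result()` simply bridges `delay()` through `async_to_sync`, so both entry points share the subprocess-wait and cleanup logic.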