From 05f30f1f7de59cdb4d2228e13e5a844470e51a8d Mon Sep 17 00:00:00 2001 From: Ratchet Date: Wed, 27 Mar 2024 11:28:01 +0200 Subject: [PATCH 1/4] added 'exception' mode --- mplite/__init__.py | 146 +++++++++++++++++++++++++++++++++---------- requirements.txt | 3 +- tests/test_basics.py | 21 +++++++ 3 files changed, 136 insertions(+), 34 deletions(-) diff --git a/mplite/__init__.py b/mplite/__init__.py index 8f8bd9a..ca49468 100644 --- a/mplite/__init__.py +++ b/mplite/__init__.py @@ -6,19 +6,24 @@ from tqdm import tqdm as _tqdm import queue from itertools import count -from typing import Callable, Any, Union, Tuple +from typing import Callable, Any, Union, Tuple, Literal from multiprocessing.context import BaseContext +import tblib.pickling_support as pklex -major, minor, patch = 1, 2, 7 +major, minor, patch = 1, 3, 0 __version_info__ = (major, minor, patch) __version__ = '.'.join(str(i) for i in __version_info__) default_context = "spawn" +ERR_MODE_STR = "str" +ERR_MODE_EXCEPTION = "exception" + + class Task(object): task_id_counter = count(start=1) def __init__(self, f, *args, **kwargs) -> None: - + if not callable(f): raise TypeError(f"{f} is not callable") self.f = f @@ -37,15 +42,8 @@ def __repr__(self) -> str: return f"Task(f={self.f.__name__}, *{self.args}, **{self.kwargs})" def execute(self): - try: - return self.f(*self.args, **self.kwargs) - except Exception as e: - f = io.StringIO() - traceback.print_exc(limit=3, file=f) - f.seek(0) - error = f.read() - f.close() - return error + return self.f(*self.args, **self.kwargs) + class TaskChain(object): def __init__(self, task: Task, next_task: Callable[[Task, Any], Union[Task, "TaskChain"]] = None) -> None: @@ -75,13 +73,13 @@ def resolve(self, result): return task raise StopIteration() - + def __str__(self) -> str: return repr(self) def __repr__(self) -> str: return f"TaskChain(f={self.task.f.__name__}, *{self.task.args}, **{self.task.kwargs}, is_last={self.next is None})" - + def execute(self): """ execute task chain synchronously """ t = self @@ -96,7 +94,7 @@ def execute(self): class Worker(object): - def __init__(self, ctx: BaseContext, name: str, tq: multiprocessing.Queue, rq: multiprocessing.Queue, init: Task): + def __init__(self, ctx: BaseContext, name: str, tq: multiprocessing.Queue, rq: multiprocessing.Queue, init: Task, error_mode: Literal["str", "exception"]): """ Worker class responsible for executing tasks in parallel, created by TaskManager. @@ -112,13 +110,17 @@ def __init__(self, ctx: BaseContext, name: str, tq: multiprocessing.Queue, rq: m Result queue init: Task Task executed when worker starts. + error_mode: 'str' | 'exception' + Which error mode to use, 'str' for legacy where exception is returned as string or 'exception' where exception is returned as pickled object. """ + assert error_mode in (ERR_MODE_STR, ERR_MODE_EXCEPTION), f"Error mode must be in ('{ERR_MODE_STR}', '{ERR_MODE_EXCEPTION}'), got '{error_mode}'" self.ctx = ctx self.exit = ctx.Event() self.tq = tq # workers task queue self.rq = rq # workers result queue self.init = init + self.err_mode = error_mode self.process = ctx.Process(group=None, target=self.update, name=name, daemon=False) def start(self): @@ -126,7 +128,7 @@ def start(self): def is_alive(self): return self.process.is_alive() - + @property def exitcode(self): return self.process.exitcode @@ -135,6 +137,8 @@ def update(self): if self.init: self.init.f(*self.init.args, **self.init.kwargs) + do_task = _do_task_exception_mode if self.err_mode == ERR_MODE_EXCEPTION else _do_task_str_mode + while True: try: task = self.tq.get_nowait() @@ -147,23 +151,43 @@ def update(self): break elif isinstance(task, Task): - result = task.execute() - self.rq.put((task.id, result)) + self.rq.put((task.id, do_task(task))) else: time.sleep(0.01) class TaskManager(object): - def __init__(self, cpu_count: int = None, context=default_context, worker_init: Task = None) -> None: + def __init__(self, cpu_count: int = None, context=default_context, worker_init: Task = None, error_mode: Literal["str", "exception"] = ERR_MODE_STR) -> None: + """ + Class responsible for managing worker processes and tasks. + + OPTIONAL + -------- + cpu_count: int + Number of worker processes to use. + Default: {cpu core count}. + ctx: BaseContext + Process spawning context ForkContext/SpawnContext. Note: Windows cannot fork. + Default: "spawn" + worker_init: Task | None + Task executed when worker starts. + Default: None + error_mode: 'str' | 'exception' + Which error mode to use, 'str' for legacy where exception is returned as string or 'exception' where exception is returned as pickled object. + Default: 'str' + """ + + assert error_mode in (ERR_MODE_STR, ERR_MODE_EXCEPTION), f"Error mode must be in ('{ERR_MODE_STR}', '{ERR_MODE_EXCEPTION}'), got '{error_mode}'" + assert worker_init is None or isinstance(worker_init, Task), "Init is not (None, type[Task])" + self._ctx = multiprocessing.get_context(context) self._cpus = multiprocessing.cpu_count() if cpu_count is None else cpu_count self.tq = self._ctx.Queue() self.rq = self._ctx.Queue() self.pool: list[Worker] = [] - self._open_tasks = 0 - - assert worker_init is None or isinstance(worker_init, Task) + self._open_tasks: list[int] = [] + self.error_mode = error_mode self.worker_init = worker_init def __enter__(self): @@ -175,13 +199,13 @@ def __exit__(self, exc_type, exc_val, exc_tb): # signature requires these, thou def start(self): for i in range(self._cpus): # create workers - worker = Worker(self._ctx, name=str(i), tq=self.tq, rq=self.rq, init=self.worker_init) + worker = Worker(self._ctx, name=str(i), tq=self.tq, rq=self.rq, init=self.worker_init, error_mode=self.error_mode) self.pool.append(worker) worker.start() while not all(p.is_alive() for p in self.pool): time.sleep(0.01) - def execute(self, tasks: "list[Union[Task, TaskChain]]", tqdm=_tqdm, pbar: _tqdm=None): + def execute(self, tasks: "list[Union[Task, TaskChain]]", tqdm=_tqdm, pbar: _tqdm = None): """ Execute tasks using mplite @@ -207,7 +231,8 @@ def execute(self, tasks: "list[Union[Task, TaskChain]]", tqdm=_tqdm, pbar: _tqdm if None is provided, progress bar will be created using tqdm callable provided by tqdm parameter. """ task_count = len(tasks) - self._open_tasks += task_count + tasks_running = [t.id for t in tasks] + self._open_tasks.extend(tasks_running) task_indices: dict[int, Tuple[int, Union[Task, TaskChain]]] = {} for i, t in enumerate(tasks): @@ -217,14 +242,20 @@ def execute(self, tasks: "list[Union[Task, TaskChain]]", tqdm=_tqdm, pbar: _tqdm if pbar is None: """ if pbar object was not passed, create a new tqdm compatible object """ - pbar = tqdm(total=self._open_tasks, unit='tasks') + pbar = tqdm(total=task_count, unit='tasks') - while self._open_tasks != 0: + while len(tasks_running) > 0: try: - task_key, res = self.rq.get_nowait() + task_key, (success, res) = self.rq.get_nowait() + + if not success and self.error_mode == ERR_MODE_EXCEPTION: + [self._open_tasks.remove(idx) for idx in tasks_running] + raise unpickle_exception(res) + idx, t = task_indices[task_key] if isinstance(t, Task) or t.next is None: - self._open_tasks -= 1 + self._open_tasks.remove(t.id) + tasks_running.remove(t.id) results[idx] = res pbar.update(1) else: @@ -248,21 +279,26 @@ def submit(self, task: Task): """ permits asynchronous submission of tasks. """ if not isinstance(task, Task): raise TypeError(f"expected mplite.Task, not {type(task)}") - self._open_tasks += 1 + self._open_tasks.append(task.id) self.tq.put(task) def take(self): """ permits asynchronous retrieval of results """ try: - _, result = self.rq.get_nowait() - self._open_tasks -= 1 + task_id, (success, result) = self.rq.get_nowait() + + self._open_tasks.remove(task_id) + + if not success and self.error_mode == ERR_MODE_EXCEPTION: + raise unpickle_exception(result) + except queue.Empty: result = None return result @property def open_tasks(self): - return self._open_tasks + return len(self._open_tasks) def stop(self): for _ in range(self._cpus): @@ -274,3 +310,47 @@ def stop(self): _ = self.tq.get_nowait() while not self.rq.empty: _ = self.rq.get_nowait() + + +def pickle_exception(e: Exception): + if e.__traceback__ is not None: + tback = pklex.pickle_traceback(e.__traceback__) + e.__traceback__ = None + else: + tback = None + + fn_ex, (ex_cls, ex_txt, ex_rsn, _) = pklex.pickle_exception(e) + + return fn_ex, (ex_cls, ex_txt, ex_rsn, tback) + + +def unpickle_exception(e): + fn_ex, (ex_cls, ex_txt, ex_rsn, tback) = e + + if tback is not None: + fn_tback, args_tback = tback + tback = fn_tback(*args_tback) + + return fn_ex(ex_cls, ex_txt, ex_rsn, tback) + + +def _do_task_exception_mode(task: Task): + """ execute task in exception mode""" + try: + return True, task.execute() + except Exception as e: + return False, pickle_exception(e) + + +def _do_task_str_mode(task: Task): + """ execute task in legacy string mode """ + try: + return True, task.execute() + except Exception: + f = io.StringIO() + traceback.print_exc(limit=3, file=f) + f.seek(0) + error = f.read() + f.close() + + return False, error diff --git a/requirements.txt b/requirements.txt index f7c2ebe..c540684 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1 +1,2 @@ -tqdm>=4.63.0 \ No newline at end of file +tqdm>=4.63.0 +tblib \ No newline at end of file diff --git a/tests/test_basics.py b/tests/test_basics.py index f892049..ff6a063 100644 --- a/tests/test_basics.py +++ b/tests/test_basics.py @@ -3,6 +3,7 @@ import signal from mplite import TaskManager, Task, TaskChain import time +import traceback import random def test_alpha(): @@ -204,5 +205,25 @@ def post_1(prev, res): assert res == [3, 3, 3, 3, 3] +def task_exception(i): + if i == 4: + raise ValueError(f"my exception: {i}") + + return i + +def test_exception_mode(): + tasks = [Task(task_exception, i) for i in range(10)] + + with TaskManager(10, error_mode="exception") as tm: + try: + [k for k, *_ in tm.execute(tasks)] + assert False + except Exception as e: + assert tm.open_tasks == 0, "there should be no left-over tasks" + assert str(e) == "my exception: 4", "wrong exception" + assert isinstance(e, ValueError), "wrong exception type" + assert type(e.__traceback__).__name__ == "traceback", "not a traceback" + assert traceback.format_tb(e.__traceback__)[-1].endswith('in task_exception\n raise ValueError(f"my exception: {i}")\n'), "wrong callastack" + if __name__ == "__main__": test_task_order() \ No newline at end of file From ec8ec74120c607720b0ac7476d3f3536740fc577 Mon Sep 17 00:00:00 2001 From: Ratchet Date: Wed, 27 Mar 2024 11:44:11 +0200 Subject: [PATCH 2/4] why do pipelines fail? --- mplite/__init__.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/mplite/__init__.py b/mplite/__init__.py index ca49468..8a646d0 100644 --- a/mplite/__init__.py +++ b/mplite/__init__.py @@ -313,13 +313,16 @@ def stop(self): def pickle_exception(e: Exception): + print(e) if e.__traceback__ is not None: tback = pklex.pickle_traceback(e.__traceback__) e.__traceback__ = None else: tback = None - fn_ex, (ex_cls, ex_txt, ex_rsn, _) = pklex.pickle_exception(e) + pkld = pklex.pickle_exception(e) + print(pkld) + fn_ex, (ex_cls, ex_txt, ex_rsn, _) = pkld return fn_ex, (ex_cls, ex_txt, ex_rsn, tback) From 144dd6edcbd03ca8bceb6ed50626345c21c35e08 Mon Sep 17 00:00:00 2001 From: Ratchet Date: Wed, 27 Mar 2024 11:46:55 +0200 Subject: [PATCH 3/4] ooh.. --- mplite/__init__.py | 11 ++++------- 1 file changed, 4 insertions(+), 7 deletions(-) diff --git a/mplite/__init__.py b/mplite/__init__.py index 8a646d0..f9fb664 100644 --- a/mplite/__init__.py +++ b/mplite/__init__.py @@ -313,28 +313,25 @@ def stop(self): def pickle_exception(e: Exception): - print(e) if e.__traceback__ is not None: tback = pklex.pickle_traceback(e.__traceback__) e.__traceback__ = None else: tback = None - pkld = pklex.pickle_exception(e) - print(pkld) - fn_ex, (ex_cls, ex_txt, ex_rsn, _) = pkld + fn_ex, (ex_cls, ex_txt, ex_rsn, _, *others) = pklex.pickle_exception(e) - return fn_ex, (ex_cls, ex_txt, ex_rsn, tback) + return fn_ex, (ex_cls, ex_txt, ex_rsn, tback, *others) def unpickle_exception(e): - fn_ex, (ex_cls, ex_txt, ex_rsn, tback) = e + fn_ex, (ex_cls, ex_txt, ex_rsn, tback, *others) = e if tback is not None: fn_tback, args_tback = tback tback = fn_tback(*args_tback) - return fn_ex(ex_cls, ex_txt, ex_rsn, tback) + return fn_ex(ex_cls, ex_txt, ex_rsn, tback, *others) def _do_task_exception_mode(task: Task): From 84815c61f1350c8037fa09d8493b824d0e6cfa1e Mon Sep 17 00:00:00 2001 From: Ratchet Date: Wed, 27 Mar 2024 11:52:43 +0200 Subject: [PATCH 4/4] py 311 adds additional newline at the end --- tests/test_basics.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/test_basics.py b/tests/test_basics.py index ff6a063..bfac634 100644 --- a/tests/test_basics.py +++ b/tests/test_basics.py @@ -223,7 +223,7 @@ def test_exception_mode(): assert str(e) == "my exception: 4", "wrong exception" assert isinstance(e, ValueError), "wrong exception type" assert type(e.__traceback__).__name__ == "traceback", "not a traceback" - assert traceback.format_tb(e.__traceback__)[-1].endswith('in task_exception\n raise ValueError(f"my exception: {i}")\n'), "wrong callastack" + assert 'in task_exception\n raise ValueError(f"my exception: {i}")\n' in traceback.format_tb(e.__traceback__)[-1], "wrong callstack" if __name__ == "__main__": test_task_order() \ No newline at end of file