Skip to content

Commit

Permalink
Merge pull request #14 from realratchet/main
Browse files Browse the repository at this point in the history
Re-raise child exceptions
  • Loading branch information
realratchet authored Mar 27, 2024
2 parents 873e0d9 + 84815c6 commit fbd0eed
Show file tree
Hide file tree
Showing 3 changed files with 136 additions and 34 deletions.
146 changes: 113 additions & 33 deletions mplite/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,19 +6,24 @@
from tqdm import tqdm as _tqdm
import queue
from itertools import count
from typing import Callable, Any, Union, Tuple
from typing import Callable, Any, Union, Tuple, Literal
from multiprocessing.context import BaseContext
import tblib.pickling_support as pklex

major, minor, patch = 1, 2, 7
major, minor, patch = 1, 3, 0
__version_info__ = (major, minor, patch)
__version__ = '.'.join(str(i) for i in __version_info__)
default_context = "spawn"

ERR_MODE_STR = "str"
ERR_MODE_EXCEPTION = "exception"


class Task(object):
task_id_counter = count(start=1)

def __init__(self, f, *args, **kwargs) -> None:

if not callable(f):
raise TypeError(f"{f} is not callable")
self.f = f
Expand All @@ -37,15 +42,8 @@ def __repr__(self) -> str:
return f"Task(f={self.f.__name__}, *{self.args}, **{self.kwargs})"

def execute(self):
try:
return self.f(*self.args, **self.kwargs)
except Exception as e:
f = io.StringIO()
traceback.print_exc(limit=3, file=f)
f.seek(0)
error = f.read()
f.close()
return error
return self.f(*self.args, **self.kwargs)


class TaskChain(object):
def __init__(self, task: Task, next_task: Callable[[Task, Any], Union[Task, "TaskChain"]] = None) -> None:
Expand Down Expand Up @@ -75,13 +73,13 @@ def resolve(self, result):

return task
raise StopIteration()

def __str__(self) -> str:
return repr(self)

def __repr__(self) -> str:
return f"TaskChain(f={self.task.f.__name__}, *{self.task.args}, **{self.task.kwargs}, is_last={self.next is None})"

def execute(self):
""" execute task chain synchronously """
t = self
Expand All @@ -96,7 +94,7 @@ def execute(self):


class Worker(object):
def __init__(self, ctx: BaseContext, name: str, tq: multiprocessing.Queue, rq: multiprocessing.Queue, init: Task):
def __init__(self, ctx: BaseContext, name: str, tq: multiprocessing.Queue, rq: multiprocessing.Queue, init: Task, error_mode: Literal["str", "exception"]):
"""
Worker class responsible for executing tasks in parallel, created by TaskManager.
Expand All @@ -112,21 +110,25 @@ def __init__(self, ctx: BaseContext, name: str, tq: multiprocessing.Queue, rq: m
Result queue
init: Task
Task executed when worker starts.
error_mode: 'str' | 'exception'
Which error mode to use, 'str' for legacy where exception is returned as string or 'exception' where exception is returned as pickled object.
"""
assert error_mode in (ERR_MODE_STR, ERR_MODE_EXCEPTION), f"Error mode must be in ('{ERR_MODE_STR}', '{ERR_MODE_EXCEPTION}'), got '{error_mode}'"
self.ctx = ctx
self.exit = ctx.Event()
self.tq = tq # workers task queue
self.rq = rq # workers result queue
self.init = init

self.err_mode = error_mode
self.process = ctx.Process(group=None, target=self.update, name=name, daemon=False)

def start(self):
self.process.start()

def is_alive(self):
return self.process.is_alive()

@property
def exitcode(self):
return self.process.exitcode
Expand All @@ -135,6 +137,8 @@ def update(self):
if self.init:
self.init.f(*self.init.args, **self.init.kwargs)

do_task = _do_task_exception_mode if self.err_mode == ERR_MODE_EXCEPTION else _do_task_str_mode

while True:
try:
task = self.tq.get_nowait()
Expand All @@ -147,23 +151,43 @@ def update(self):
break

elif isinstance(task, Task):
result = task.execute()
self.rq.put((task.id, result))
self.rq.put((task.id, do_task(task)))
else:
time.sleep(0.01)


class TaskManager(object):
def __init__(self, cpu_count: int = None, context=default_context, worker_init: Task = None) -> None:
def __init__(self, cpu_count: int = None, context=default_context, worker_init: Task = None, error_mode: Literal["str", "exception"] = ERR_MODE_STR) -> None:
"""
Class responsible for managing worker processes and tasks.
OPTIONAL
--------
cpu_count: int
Number of worker processes to use.
Default: {cpu core count}.
ctx: BaseContext
Process spawning context ForkContext/SpawnContext. Note: Windows cannot fork.
Default: "spawn"
worker_init: Task | None
Task executed when worker starts.
Default: None
error_mode: 'str' | 'exception'
Which error mode to use, 'str' for legacy where exception is returned as string or 'exception' where exception is returned as pickled object.
Default: 'str'
"""

assert error_mode in (ERR_MODE_STR, ERR_MODE_EXCEPTION), f"Error mode must be in ('{ERR_MODE_STR}', '{ERR_MODE_EXCEPTION}'), got '{error_mode}'"
assert worker_init is None or isinstance(worker_init, Task), "Init is not (None, type[Task])"

self._ctx = multiprocessing.get_context(context)
self._cpus = multiprocessing.cpu_count() if cpu_count is None else cpu_count
self.tq = self._ctx.Queue()
self.rq = self._ctx.Queue()
self.pool: list[Worker] = []
self._open_tasks = 0

assert worker_init is None or isinstance(worker_init, Task)
self._open_tasks: list[int] = []

self.error_mode = error_mode
self.worker_init = worker_init

def __enter__(self):
Expand All @@ -175,13 +199,13 @@ def __exit__(self, exc_type, exc_val, exc_tb): # signature requires these, thou

def start(self):
for i in range(self._cpus): # create workers
worker = Worker(self._ctx, name=str(i), tq=self.tq, rq=self.rq, init=self.worker_init)
worker = Worker(self._ctx, name=str(i), tq=self.tq, rq=self.rq, init=self.worker_init, error_mode=self.error_mode)
self.pool.append(worker)
worker.start()
while not all(p.is_alive() for p in self.pool):
time.sleep(0.01)

def execute(self, tasks: "list[Union[Task, TaskChain]]", tqdm=_tqdm, pbar: _tqdm=None):
def execute(self, tasks: "list[Union[Task, TaskChain]]", tqdm=_tqdm, pbar: _tqdm = None):
"""
Execute tasks using mplite
Expand All @@ -207,7 +231,8 @@ def execute(self, tasks: "list[Union[Task, TaskChain]]", tqdm=_tqdm, pbar: _tqdm
if None is provided, progress bar will be created using tqdm callable provided by tqdm parameter.
"""
task_count = len(tasks)
self._open_tasks += task_count
tasks_running = [t.id for t in tasks]
self._open_tasks.extend(tasks_running)
task_indices: dict[int, Tuple[int, Union[Task, TaskChain]]] = {}

for i, t in enumerate(tasks):
Expand All @@ -217,14 +242,20 @@ def execute(self, tasks: "list[Union[Task, TaskChain]]", tqdm=_tqdm, pbar: _tqdm

if pbar is None:
""" if pbar object was not passed, create a new tqdm compatible object """
pbar = tqdm(total=self._open_tasks, unit='tasks')
pbar = tqdm(total=task_count, unit='tasks')

while self._open_tasks != 0:
while len(tasks_running) > 0:
try:
task_key, res = self.rq.get_nowait()
task_key, (success, res) = self.rq.get_nowait()

if not success and self.error_mode == ERR_MODE_EXCEPTION:
[self._open_tasks.remove(idx) for idx in tasks_running]
raise unpickle_exception(res)

idx, t = task_indices[task_key]
if isinstance(t, Task) or t.next is None:
self._open_tasks -= 1
self._open_tasks.remove(t.id)
tasks_running.remove(t.id)
results[idx] = res
pbar.update(1)
else:
Expand All @@ -248,21 +279,26 @@ def submit(self, task: Task):
""" permits asynchronous submission of tasks. """
if not isinstance(task, Task):
raise TypeError(f"expected mplite.Task, not {type(task)}")
self._open_tasks += 1
self._open_tasks.append(task.id)
self.tq.put(task)

def take(self):
""" permits asynchronous retrieval of results """
try:
_, result = self.rq.get_nowait()
self._open_tasks -= 1
task_id, (success, result) = self.rq.get_nowait()

self._open_tasks.remove(task_id)

if not success and self.error_mode == ERR_MODE_EXCEPTION:
raise unpickle_exception(result)

except queue.Empty:
result = None
return result

@property
def open_tasks(self):
return self._open_tasks
return len(self._open_tasks)

def stop(self):
for _ in range(self._cpus):
Expand All @@ -274,3 +310,47 @@ def stop(self):
_ = self.tq.get_nowait()
while not self.rq.empty:
_ = self.rq.get_nowait()


def pickle_exception(e: Exception):
if e.__traceback__ is not None:
tback = pklex.pickle_traceback(e.__traceback__)
e.__traceback__ = None
else:
tback = None

fn_ex, (ex_cls, ex_txt, ex_rsn, _, *others) = pklex.pickle_exception(e)

return fn_ex, (ex_cls, ex_txt, ex_rsn, tback, *others)


def unpickle_exception(e):
fn_ex, (ex_cls, ex_txt, ex_rsn, tback, *others) = e

if tback is not None:
fn_tback, args_tback = tback
tback = fn_tback(*args_tback)

return fn_ex(ex_cls, ex_txt, ex_rsn, tback, *others)


def _do_task_exception_mode(task: Task):
""" execute task in exception mode"""
try:
return True, task.execute()
except Exception as e:
return False, pickle_exception(e)


def _do_task_str_mode(task: Task):
""" execute task in legacy string mode """
try:
return True, task.execute()
except Exception:
f = io.StringIO()
traceback.print_exc(limit=3, file=f)
f.seek(0)
error = f.read()
f.close()

return False, error
3 changes: 2 additions & 1 deletion requirements.txt
Original file line number Diff line number Diff line change
@@ -1 +1,2 @@
tqdm>=4.63.0
tqdm>=4.63.0
tblib
21 changes: 21 additions & 0 deletions tests/test_basics.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
import signal
from mplite import TaskManager, Task, TaskChain
import time
import traceback
import random

def test_alpha():
Expand Down Expand Up @@ -204,5 +205,25 @@ def post_1(prev, res):

assert res == [3, 3, 3, 3, 3]

def task_exception(i):
if i == 4:
raise ValueError(f"my exception: {i}")

return i

def test_exception_mode():
tasks = [Task(task_exception, i) for i in range(10)]

with TaskManager(10, error_mode="exception") as tm:
try:
[k for k, *_ in tm.execute(tasks)]
assert False
except Exception as e:
assert tm.open_tasks == 0, "there should be no left-over tasks"
assert str(e) == "my exception: 4", "wrong exception"
assert isinstance(e, ValueError), "wrong exception type"
assert type(e.__traceback__).__name__ == "traceback", "not a traceback"
assert 'in task_exception\n raise ValueError(f"my exception: {i}")\n' in traceback.format_tb(e.__traceback__)[-1], "wrong callstack"

if __name__ == "__main__":
test_task_order()

0 comments on commit fbd0eed

Please sign in to comment.