diff --git a/aiida_workgraph/engine/workgraph.py b/aiida_workgraph/engine/workgraph.py index ba0afcc6..8f1bb1a9 100644 --- a/aiida_workgraph/engine/workgraph.py +++ b/aiida_workgraph/engine/workgraph.py @@ -845,19 +845,34 @@ def are_childen_finished(self, name: str) -> tuple[bool, t.Any]: def run_error_handlers(self, task_name: str) -> None: """Run error handler.""" + from inspect import signature + node = self.get_task_state_info(task_name, "process") if not node or not node.exit_status: return for _, data in self.ctx._error_handlers.items(): if task_name in data["tasks"]: handler = data["handler"] + handler_sig = signature(handler) metadata = data["tasks"][task_name] if node.exit_code.status in metadata.get("exit_codes", []): self.report(f"Run error handler: {metadata}") metadata.setdefault("retry", 0) if metadata["retry"] < metadata["max_retries"]: - handler(self, task_name, **metadata.get("kwargs", {})) - metadata["retry"] += 1 + task = self.get_task(task_name) + try: + if "engine" in handler_sig.parameters: + msg = handler( + task, engine=self, **metadata.get("kwargs", {}) + ) + else: + msg = handler(task, **metadata.get("kwargs", {})) + self.update_task(task) + if msg: + self.report(msg) + metadata["retry"] += 1 + except Exception as e: + self.report(f"Error in running error handler: {e}") def is_workgraph_finished(self) -> bool: """Check if the workgraph is finished. diff --git a/aiida_workgraph/workgraph.py b/aiida_workgraph/workgraph.py index 339435a0..c42fb1e9 100644 --- a/aiida_workgraph/workgraph.py +++ b/aiida_workgraph/workgraph.py @@ -457,7 +457,7 @@ def extend(self, wg: "WorkGraph", prefix: str = "") -> None: for link in wg.links: self.links.append(link) - def attach_error_handler(self, handler, name, tasks: dict = None) -> None: + def add_error_handler(self, handler, name, tasks: dict = None) -> None: """Attach an error handler to the workgraph.""" self.error_handlers[name] = {"handler": handler, "tasks": tasks} diff --git a/docs/source/howto/error_resistant.ipynb b/docs/source/howto/error_resistant.ipynb index 09bbb84b..500b25bd 100644 --- a/docs/source/howto/error_resistant.ipynb +++ b/docs/source/howto/error_resistant.ipynb @@ -212,7 +212,7 @@ "wg = WorkGraph(\"normal_graph\")\n", "wg.add_task(ArithmeticAddCalculation, name=\"add1\")\n", "# register error handler\n", - "wg.attach_error_handler(handle_negative_sum, name=\"handle_negative_sum\",\n", + "wg.add_error_handler(handle_negative_sum, name=\"handle_negative_sum\",\n", " tasks={\"add1\": {\"exit_codes\": [410],\n", " \"max_retries\": 5}\n", " })\n", @@ -303,7 +303,7 @@ "wg = WorkGraph(\"normal_graph\")\n", "wg.add_task(ArithmeticAddCalculation, name=\"add1\")\n", "# register error handler\n", - "wg.attach_error_handler(handle_negative_sum, name=\"handle_negative_sum\",\n", + "wg.add_error_handler(handle_negative_sum, name=\"handle_negative_sum\",\n", " tasks={\"add1\": {\"exit_codes\": [410],\n", " \"max_retries\": 5,\n", " \"kwargs\": {\"increment\": 1}}\n", diff --git a/tests/test_error_handler.py b/tests/test_error_handler.py index 7719fdf9..10d014df 100644 --- a/tests/test_error_handler.py +++ b/tests/test_error_handler.py @@ -1,5 +1,5 @@ import pytest -from aiida_workgraph import WorkGraph +from aiida_workgraph import WorkGraph, Task from aiida import orm from aiida.calculations.arithmetic.add import ArithmeticAddCalculation @@ -9,12 +9,10 @@ def test_error_handlers(add_code): """Test error handlers.""" from aiida.cmdline.utils.common import get_workchain_report - def handle_negative_sum(self, task_name: str, **kwargs): + def handle_negative_sum(task: Task): """Handle negative sum by resetting the task and changing the inputs. self is the WorkGraph instance, thus we can access the tasks and the context. """ - self.report("Run error handler: handle_negative_sum.") - task = self.get_task(task_name) # modify task inputs task.set( { @@ -22,11 +20,12 @@ def handle_negative_sum(self, task_name: str, **kwargs): "y": orm.Int(abs(task.inputs["y"].value)), } ) - self.update_task(task) + msg = "Run error handler: handle_negative_sum." + return msg wg = WorkGraph("restart_graph") wg.add_task(ArithmeticAddCalculation, name="add1") - wg.attach_error_handler( + wg.add_error_handler( handle_negative_sum, name="handle_negative_sum", tasks={"add1": {"exit_codes": [410], "max_retries": 5, "kwargs": {}}}, diff --git a/tests/test_python.py b/tests/test_python.py index c760666f..7a9828f1 100644 --- a/tests/test_python.py +++ b/tests/test_python.py @@ -1,5 +1,5 @@ import pytest -from aiida_workgraph import WorkGraph, task +from aiida_workgraph import WorkGraph, task, Task from typing import Any @@ -500,17 +500,14 @@ def add(x: array, y: array) -> array: return {"sum": sum, "exit_code": exit_code} return {"sum": sum} - def handle_negative_sum(self, task_name: str): + def handle_negative_sum(task: Task): """Handle the failure code 410 of the `add`. Simply make the inputs positive by taking the absolute value. """ - self.report("Run error handler: handle_negative_sum.") - # load the task from the WorkGraph engine - task = self.get_task(task_name) - # modify task inputs + task.set({"x": abs(task.inputs["x"].value), "y": abs(task.inputs["y"].value)}) - self.update_task(task) + return "Run error handler: handle_negative_sum." wg = WorkGraph("test_PythonJob") wg.add_task( @@ -522,7 +519,7 @@ def handle_negative_sum(self, task_name: str): code_label=python_executable_path, ) # register error handler - wg.attach_error_handler( + wg.add_error_handler( handle_negative_sum, name="handle_negative_sum", tasks={"add1": {"exit_codes": [410], "max_retries": 5}},