Skip to content

Commit

Permalink
Change the error handler's signature
Browse files Browse the repository at this point in the history
1) pass task as args, and engine as kwargs
2) return msg to report
3) update the task in the engine
  • Loading branch information
superstar54 committed Sep 11, 2024
1 parent 3b6d616 commit df4e625
Show file tree
Hide file tree
Showing 5 changed files with 30 additions and 19 deletions.
19 changes: 17 additions & 2 deletions aiida_workgraph/engine/workgraph.py
Original file line number Diff line number Diff line change
Expand Up @@ -845,19 +845,34 @@ def are_childen_finished(self, name: str) -> tuple[bool, t.Any]:

def run_error_handlers(self, task_name: str) -> None:
"""Run error handler."""
from inspect import signature

node = self.get_task_state_info(task_name, "process")
if not node or not node.exit_status:
return
for _, data in self.ctx._error_handlers.items():
if task_name in data["tasks"]:
handler = data["handler"]
handler_sig = signature(handler)
metadata = data["tasks"][task_name]
if node.exit_code.status in metadata.get("exit_codes", []):
self.report(f"Run error handler: {metadata}")
metadata.setdefault("retry", 0)
if metadata["retry"] < metadata["max_retries"]:
handler(self, task_name, **metadata.get("kwargs", {}))
metadata["retry"] += 1
task = self.get_task(task_name)
try:
if "engine" in handler_sig.parameters:
msg = handler(
task, engine=self, **metadata.get("kwargs", {})
)
else:
msg = handler(task, **metadata.get("kwargs", {}))
self.update_task(task)
if msg:
self.report(msg)
metadata["retry"] += 1
except Exception as e:
self.report(f"Error in running error handler: {e}")

def is_workgraph_finished(self) -> bool:
"""Check if the workgraph is finished.
Expand Down
2 changes: 1 addition & 1 deletion aiida_workgraph/workgraph.py
Original file line number Diff line number Diff line change
Expand Up @@ -457,7 +457,7 @@ def extend(self, wg: "WorkGraph", prefix: str = "") -> None:
for link in wg.links:
self.links.append(link)

def attach_error_handler(self, handler, name, tasks: dict = None) -> None:
def add_error_handler(self, handler, name, tasks: dict = None) -> None:
"""Attach an error handler to the workgraph."""
self.error_handlers[name] = {"handler": handler, "tasks": tasks}

Expand Down
4 changes: 2 additions & 2 deletions docs/source/howto/error_resistant.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -212,7 +212,7 @@
"wg = WorkGraph(\"normal_graph\")\n",
"wg.add_task(ArithmeticAddCalculation, name=\"add1\")\n",
"# register error handler\n",
"wg.attach_error_handler(handle_negative_sum, name=\"handle_negative_sum\",\n",
"wg.add_error_handler(handle_negative_sum, name=\"handle_negative_sum\",\n",
" tasks={\"add1\": {\"exit_codes\": [410],\n",
" \"max_retries\": 5}\n",
" })\n",
Expand Down Expand Up @@ -303,7 +303,7 @@
"wg = WorkGraph(\"normal_graph\")\n",
"wg.add_task(ArithmeticAddCalculation, name=\"add1\")\n",
"# register error handler\n",
"wg.attach_error_handler(handle_negative_sum, name=\"handle_negative_sum\",\n",
"wg.add_error_handler(handle_negative_sum, name=\"handle_negative_sum\",\n",
" tasks={\"add1\": {\"exit_codes\": [410],\n",
" \"max_retries\": 5,\n",
" \"kwargs\": {\"increment\": 1}}\n",
Expand Down
11 changes: 5 additions & 6 deletions tests/test_error_handler.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import pytest
from aiida_workgraph import WorkGraph
from aiida_workgraph import WorkGraph, Task
from aiida import orm
from aiida.calculations.arithmetic.add import ArithmeticAddCalculation

Expand All @@ -9,24 +9,23 @@ def test_error_handlers(add_code):
"""Test error handlers."""
from aiida.cmdline.utils.common import get_workchain_report

def handle_negative_sum(self, task_name: str, **kwargs):
def handle_negative_sum(task: Task):
"""Handle negative sum by resetting the task and changing the inputs.
self is the WorkGraph instance, thus we can access the tasks and the context.
"""
self.report("Run error handler: handle_negative_sum.")
task = self.get_task(task_name)
# modify task inputs
task.set(
{
"x": orm.Int(abs(task.inputs["x"].value)),
"y": orm.Int(abs(task.inputs["y"].value)),
}
)
self.update_task(task)
msg = "Run error handler: handle_negative_sum."
return msg

wg = WorkGraph("restart_graph")
wg.add_task(ArithmeticAddCalculation, name="add1")
wg.attach_error_handler(
wg.add_error_handler(
handle_negative_sum,
name="handle_negative_sum",
tasks={"add1": {"exit_codes": [410], "max_retries": 5, "kwargs": {}}},
Expand Down
13 changes: 5 additions & 8 deletions tests/test_python.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import pytest
from aiida_workgraph import WorkGraph, task
from aiida_workgraph import WorkGraph, task, Task
from typing import Any


Expand Down Expand Up @@ -500,17 +500,14 @@ def add(x: array, y: array) -> array:
return {"sum": sum, "exit_code": exit_code}
return {"sum": sum}

def handle_negative_sum(self, task_name: str):
def handle_negative_sum(task: Task):
"""Handle the failure code 410 of the `add`.
Simply make the inputs positive by taking the absolute value.
"""
self.report("Run error handler: handle_negative_sum.")
# load the task from the WorkGraph engine
task = self.get_task(task_name)
# modify task inputs

task.set({"x": abs(task.inputs["x"].value), "y": abs(task.inputs["y"].value)})

self.update_task(task)
return "Run error handler: handle_negative_sum."

wg = WorkGraph("test_PythonJob")
wg.add_task(
Expand All @@ -522,7 +519,7 @@ def handle_negative_sum(self, task_name: str):
code_label=python_executable_path,
)
# register error handler
wg.attach_error_handler(
wg.add_error_handler(
handle_negative_sum,
name="handle_negative_sum",
tasks={"add1": {"exit_codes": [410], "max_retries": 5}},
Expand Down

0 comments on commit df4e625

Please sign in to comment.