Add error handlers to Task class
superstar54 committed Sep 13, 2024
1 parent e7b50b0 commit 89b05d4
Showing 7 changed files with 89 additions and 47 deletions.
9 changes: 6 additions & 3 deletions aiida_workgraph/decorator.py
@@ -6,7 +6,7 @@
from aiida.orm.nodes.process.workflow.workfunction import WorkFunctionNode
from aiida.engine.processes.ports import PortNamespace
from aiida_workgraph.task import Task
from aiida_workgraph.utils import build_executor
from aiida_workgraph.utils import build_callable
import inspect

task_types = {
@@ -248,7 +248,7 @@ def build_task_from_AiiDA(
# so I pickled the function here, but this is not necessary
# we need to update the node_graph to support the path and name of the function
tdata["identifier"] = tdata.pop("identifier", tdata["executor"].__name__)
tdata["executor"] = build_executor(executor)
tdata["executor"] = build_callable(executor)
tdata["executor"]["type"] = tdata["metadata"]["task_type"]
if tdata["metadata"]["task_type"].upper() in ["CALCFUNCTION", "WORKFUNCTION"]:
outputs = (
@@ -514,7 +514,7 @@ def generate_tdata(
"inputs": _inputs,
"outputs": task_outputs,
}
tdata["executor"] = build_executor(func)
tdata["executor"] = build_callable(func)
if additional_data:
tdata.update(additional_data)
return tdata
@@ -532,6 +532,7 @@ def decorator_task(
properties: Optional[List[Tuple[str, str]]] = None,
inputs: Optional[List[Tuple[str, str]]] = None,
outputs: Optional[List[Tuple[str, str]]] = None,
error_handlers: Optional[List[Dict[str, Any]]] = None,
catalog: str = "Others",
) -> Callable:
"""Generate a decorator that register a function as a task.
Expand Down Expand Up @@ -566,6 +567,7 @@ def decorator(func):
task_type,
)
task = create_task(tdata)
task._error_handlers = error_handlers
func.identifier = identifier
func.task = func.node = task
func.tdata = tdata
@@ -678,6 +680,7 @@ def decorator(func):
task_decorated, _ = build_pythonjob_task(func)
func.identifier = "PythonJob"
func.task = func.node = task_decorated
task_decorated._error_handlers = kwargs.get("error_handlers", [])

return func

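Example (not part of the diff): with this commit the task decorator accepts an error_handlers argument. Below is a minimal sketch of how it might be used, assuming the list-of-dicts shape read by the engine later in this commit (handler, exit_codes, max_retries, optional kwargs); the import path follows aiida-workgraph's usual API, and the handler body and exit code are hypothetical.

from aiida_workgraph import task  # standard aiida-workgraph import, assumed here


def handle_negative_x(task, **kwargs):
    """Hypothetical handler: retry with the absolute value of ``x``."""
    task.set({"x": abs(task.inputs["x"].value)})  # ``Task.set`` assumed
    return "x was negative; retrying with abs(x)"


@task(
    error_handlers=[
        {
            "handler": handle_negative_x,  # plain callable; wrapped by build_callable on export
            "exit_codes": [410],           # only run for these exit codes
            "max_retries": 3,              # give up after three attempts
        }
    ]
)
def add(x, y):
    return x + y
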
4 changes: 2 additions & 2 deletions aiida_workgraph/engine/utils.py
@@ -92,11 +92,11 @@ def prepare_for_python_task(task: dict, kwargs: dict, var_kwargs: dict) -> dict:
metadata = kwargs.pop("metadata", {})
metadata.update({"call_link_label": task["name"]})
# get the source code of the function
function_name = task["executor"]["function_name"]
function_name = task["executor"]["name"]
function_source_code = (
task["executor"]["import_statements"]
+ "\n"
+ task["executor"]["function_source_code_without_decorator"]
+ task["executor"]["source_code_without_decorator"]
)
# outputs
function_outputs = [
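Example (not part of the diff): the executor dict now uses the shorter key names. An illustrative dict (values are made up) and the same reassembly the engine performs above:

executor = {
    "name": "add",                                 # was "function_name"
    "import_statements": "from math import sqrt",
    "source_code_without_decorator": "def add(x, y):\n    return x + y\n",  # was "function_source_code_without_decorator"
}
function_name = executor["name"]
function_source_code = (
    executor["import_statements"]
    + "\n"
    + executor["source_code_without_decorator"]
)
print(function_name, function_source_code, sep="\n")
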
57 changes: 34 additions & 23 deletions aiida_workgraph/engine/workgraph.py
@@ -498,14 +498,13 @@ def setup(self) -> None:

def setup_ctx_workgraph(self, wgdata: t.Dict[str, t.Any]) -> None:
"""setup the workgraph in the context."""
import cloudpickle as pickle

self.ctx._tasks = wgdata["tasks"]
self.ctx._links = wgdata["links"]
self.ctx._connectivity = wgdata["connectivity"]
self.ctx._ctrl_links = wgdata["ctrl_links"]
self.ctx._workgraph = wgdata
self.ctx._error_handlers = pickle.loads(wgdata["error_handlers"])
self.ctx._error_handlers = wgdata["error_handlers"]

def read_wgdata_from_base(self) -> t.Dict[str, t.Any]:
"""Read workgraph data from base.extras."""
@@ -521,6 +520,7 @@ def read_wgdata_from_base(self) -> t.Dict[str, t.Any]:
if isinstance(prop["value"], PickledLocalFunction):
prop["value"] = prop["value"].value
wgdata["error_handlers"] = deserialize_unsafe(wgdata["error_handlers"])
print("error_handlers:", wgdata["error_handlers"])
wgdata["context"] = deserialize_unsafe(wgdata["context"])
return wgdata

@@ -860,34 +860,45 @@ def are_childen_finished(self, name: str) -> tuple[bool, t.Any]:

def run_error_handlers(self, task_name: str) -> None:
"""Run error handler."""
from inspect import signature

node = self.get_task_state_info(task_name, "process")
if not node or not node.exit_status:
return
# error_handlers from the task
for _, data in self.ctx._tasks[task_name]["error_handlers"].items():
if node.exit_status in data.get("exit_codes", []):
handler = data["handler"]
self.run_error_handler(handler, data, task_name)
# error_handlers from the workgraph
for _, data in self.ctx._error_handlers.items():
if task_name in data["tasks"]:
if node.exit_code.status in data["tasks"].get(task_name, {}).get(
"exit_codes", []
):
handler = data["handler"]
handler_sig = signature(handler)
metadata = data["tasks"][task_name]
if node.exit_code.status in metadata.get("exit_codes", []):
self.report(f"Run error handler: {metadata}")
metadata.setdefault("retry", 0)
if metadata["retry"] < metadata["max_retries"]:
task = self.get_task(task_name)
try:
if "engine" in handler_sig.parameters:
msg = handler(
task, engine=self, **metadata.get("kwargs", {})
)
else:
msg = handler(task, **metadata.get("kwargs", {}))
self.update_task(task)
if msg:
self.report(msg)
metadata["retry"] += 1
except Exception as e:
self.report(f"Error in running error handler: {e}")
self.run_error_handler(handler, metadata, task_name)

def run_error_handler(self, handler: dict, metadata: dict, task_name: str) -> None:
from inspect import signature
from aiida_workgraph.utils import get_executor

handler, _ = get_executor(handler)
handler_sig = signature(handler)
metadata.setdefault("retry", 0)
self.report(f"Run error handler: {handler.__name__}")
if metadata["retry"] < metadata["max_retries"]:
task = self.get_task(task_name)
try:
if "engine" in handler_sig.parameters:
msg = handler(task, engine=self, **metadata.get("kwargs", {}))
else:
msg = handler(task, **metadata.get("kwargs", {}))
self.update_task(task)
if msg:
self.report(msg)
metadata["retry"] += 1
except Exception as e:
self.report(f"Error in running error handler: {e}")

def is_workgraph_finished(self) -> bool:
"""Check if the workgraph is finished.
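Example (not part of the diff): run_error_handler inspects the handler's signature, so a handler can optionally accept the engine plus any extra kwargs stored with the handler entry, and a returned string is reported. A sketch of a compatible handler; the input name and the fix-up logic are hypothetical.

def restart_with_smaller_timestep(task, engine=None, factor=0.5):
    """Hypothetical handler: shrink a timestep input before the task is retried."""
    params = dict(task.inputs["parameters"].value)   # input socket name assumed
    params["timestep"] = params["timestep"] * factor
    task.set({"parameters": params})                 # ``Task.set`` assumed
    if engine is not None:                           # only passed if the signature names it
        engine.report(f"timestep reduced by a factor of {factor}")
    return "restarting with a smaller timestep"      # reported by the engine if truthy
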
26 changes: 12 additions & 14 deletions aiida_workgraph/orm/function_data.py
@@ -27,18 +27,18 @@ def __str__(self):
def metadata(self):
"""Return a dictionary of metadata."""
return {
"function_name": self.base.attributes.get("function_name"),
"name": self.base.attributes.get("name"),
"import_statements": self.base.attributes.get("import_statements"),
"function_source_code": self.base.attributes.get("function_source_code"),
"function_source_code_without_decorator": self.base.attributes.get(
"function_source_code_without_decorator"
"source_code": self.base.attributes.get("source_code"),
"source_code_without_decorator": self.base.attributes.get(
"source_code_without_decorator"
),
"type": "function",
"is_pickle": True,
}

@classmethod
def build_executor(cls, func):
def build_callable(cls, func):
"""Return the executor for this node."""
import cloudpickle as pickle

@@ -59,16 +59,14 @@ def set_attribute(self, value):
serialized_data = self.inspect_function(value)

# Store relevant metadata
self.base.attributes.set("function_name", serialized_data["function_name"])
self.base.attributes.set("name", serialized_data["name"])
self.base.attributes.set(
"import_statements", serialized_data["import_statements"]
)
self.base.attributes.set("source_code", serialized_data["source_code"])
self.base.attributes.set(
"function_source_code", serialized_data["function_source_code"]
)
self.base.attributes.set(
"function_source_code_without_decorator",
serialized_data["function_source_code_without_decorator"],
"source_code_without_decorator",
serialized_data["source_code_without_decorator"],
)

@classmethod
@@ -108,9 +106,9 @@ def inspect_function(cls, func: Callable) -> Dict[str, Any]:
function_source_code_without_decorator = ""
import_statements = ""
return {
"function_name": func.__name__,
"function_source_code": function_source_code,
"function_source_code_without_decorator": function_source_code_without_decorator,
"name": func.__name__,
"source_code": function_source_code,
"source_code_without_decorator": function_source_code_without_decorator,
"import_statements": import_statements,
}

24 changes: 24 additions & 0 deletions aiida_workgraph/task.py
@@ -25,6 +25,7 @@ class Task(GraphNode):
property_pool = property_pool
socket_pool = socket_pool
is_aiida_component = False
_error_handlers = None

def __init__(
self,
@@ -68,6 +69,7 @@ def to_dict(self) -> Dict[str, Any]:
tdata["process"] = serialize(self.process) if self.process else serialize(None)
tdata["metadata"]["pk"] = self.process.pk if self.process else None
tdata["metadata"]["is_aiida_component"] = self.is_aiida_component
tdata["error_handlers"] = self.get_error_handlers()

return tdata

@@ -123,13 +125,35 @@ def from_dict(cls, data: Dict[str, Any], task_pool: Optional[Any] = None) -> "Ta
if process and isinstance(process, str):
process = deserialize_unsafe(process)
task.process = process
task._error_handlers = data.get("error_handlers", [])

return task

def reset(self) -> None:
self.process = None
self.state = "PLANNED"

@property
def error_handlers(self) -> list:
return self.get_error_handlers()

def get_error_handlers(self) -> list:
"""Get the error handler function for this task."""
from aiida_workgraph.utils import build_callable

if self._error_handlers is None:
return {}

handlers = {}
if isinstance(self._error_handlers, dict):
for handler in self._error_handlers.values():
handler["handler"] = build_callable(handler["handler"])
elif isinstance(self._error_handlers, list):
for handler in self._error_handlers:
handler["handler"] = build_callable(handler["handler"])
handlers[handler["handler"]["name"]] = handler
return handlers

def _repr_mimebundle_(self, *args: Any, **kwargs: Any) -> any:

if self._widget is None:
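Example (not part of the diff): handlers attached to a task are normalised by get_error_handlers, which wraps each callable with build_callable and keys the entries by the callable's name. A small sketch using the list form declared in the decorator signature; the handler itself is hypothetical.

def handle_failure(task, **kwargs):
    """Hypothetical handler; only its name matters for this illustration."""
    return "retrying"


error_handlers = [
    {"handler": handle_failure, "exit_codes": [500], "max_retries": 2},
]
# After Task.get_error_handlers() the callable becomes a serialisable descriptor
# and the entries are keyed by handler name (abridged):
# {"handle_failure": {"handler": {"name": "handle_failure", ...},  # module reference or pickle
#                     "exit_codes": [500], "max_retries": 2}}
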
6 changes: 4 additions & 2 deletions aiida_workgraph/utils/__init__.py
@@ -5,7 +5,7 @@
from aiida.engine.runners import Runner


def build_executor(obj: Callable) -> Dict[str, Any]:
def build_callable(obj: Callable) -> Dict[str, Any]:
"""
Build the executor data from the callable. This will either serialize the callable
using cloudpickle if it's a local or lambda function, or store its module and name
@@ -26,14 +26,16 @@ def build_executor(obj: Callable) -> Dict[str, Any]:
# Check if callable is nested (contains dots in __qualname__ after the first segment)
if obj.__module__ == "__main__" or "." in obj.__qualname__.split(".", 1)[-1]:
# Local or nested callable, so pickle the callable
executor = PickledFunction.build_executor(obj)
executor = PickledFunction.build_callable(obj)
else:
# Global callable (function/class), store its module and name for reference
executor = {
"module": obj.__module__,
"name": obj.__name__,
"is_pickle": False,
}
elif isinstance(obj, PickledFunction) or isinstance(obj, dict):
executor = obj
else:
raise TypeError("Provided object is not a callable function or class.")
return executor
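Example (not part of the diff): the two branches of build_callable in practice; a module-level callable is stored by module and name, while anything defined in __main__ or nested is pickled via PickledFunction.build_callable.

import math

from aiida_workgraph.utils import build_callable

print(build_callable(math.sqrt))
# -> {"module": "math", "name": "sqrt", "is_pickle": False}


def outer():
    def inner(x):  # nested: __qualname__ is "outer.<locals>.inner"
        return x + 1

    return build_callable(inner)


print(outer()["name"])  # "inner", stored as a pickled descriptor rather than a module reference
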
10 changes: 7 additions & 3 deletions aiida_workgraph/workgraph.py
@@ -176,8 +176,7 @@ def save_to_base(self, wgdata: Dict[str, Any]) -> None:
saver.save()

def to_dict(self, store_nodes=False) -> Dict[str, Any]:
import cloudpickle as pickle
from aiida_workgraph.utils import store_nodes_recursely
from aiida_workgraph.utils import store_nodes_recursely, build_callable

wgdata = super().to_dict()
# save the sequence and context
@@ -198,7 +197,12 @@ def to_dict(self, store_nodes=False) -> Dict[str, Any]:
"max_number_jobs": self.max_number_jobs,
}
)
wgdata["error_handlers"] = pickle.dumps(self.error_handlers)
# save error handlers
wgdata["error_handlers"] = {}
for name, error_handler in self.error_handlers.items():
print("error_handler:", error_handler)
error_handler["handler"] = build_callable(error_handler["handler"])
wgdata["error_handlers"][name] = error_handler
wgdata["tasks"] = wgdata.pop("nodes")
if store_nodes:
store_nodes_recursely(wgdata)
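Example (not part of the diff): workgraph-level handlers are now serialised per handler name with their callables passed through build_callable instead of being cloudpickled as one blob. A sketch of the structure the engine's run_error_handlers expects (the tasks/exit_codes/max_retries/kwargs keys come from the engine code above); attaching the handler by writing into wg.error_handlers is an assumption about the WorkGraph API.

from aiida_workgraph import WorkGraph


def handle_scf_failure(task, engine=None, nbands_factor=1.2):
    """Hypothetical handler for a failing SCF task."""
    return f"increasing the number of bands by a factor of {nbands_factor}"


wg = WorkGraph("example")
wg.error_handlers["scf_handler"] = {   # attach API assumed
    "handler": handle_scf_failure,     # wrapped by build_callable in to_dict()
    "tasks": {
        "scf": {                       # name of a task inside this workgraph
            "exit_codes": [410],
            "max_retries": 5,
            "kwargs": {"nbands_factor": 1.5},
        }
    },
}
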
