diff --git a/aiida_workgraph/decorator.py b/aiida_workgraph/decorator.py index 1ad8dfd0..a0b72f07 100644 --- a/aiida_workgraph/decorator.py +++ b/aiida_workgraph/decorator.py @@ -6,7 +6,7 @@ from aiida.orm.nodes.process.workflow.workfunction import WorkFunctionNode from aiida.engine.processes.ports import PortNamespace from aiida_workgraph.task import Task -from aiida_workgraph.utils import build_executor +from aiida_workgraph.utils import build_callable import inspect task_types = { @@ -248,7 +248,7 @@ def build_task_from_AiiDA( # so I pickled the function here, but this is not necessary # we need to update the node_graph to support the path and name of the function tdata["identifier"] = tdata.pop("identifier", tdata["executor"].__name__) - tdata["executor"] = build_executor(executor) + tdata["executor"] = build_callable(executor) tdata["executor"]["type"] = tdata["metadata"]["task_type"] if tdata["metadata"]["task_type"].upper() in ["CALCFUNCTION", "WORKFUNCTION"]: outputs = ( @@ -514,7 +514,7 @@ def generate_tdata( "inputs": _inputs, "outputs": task_outputs, } - tdata["executor"] = build_executor(func) + tdata["executor"] = build_callable(func) if additional_data: tdata.update(additional_data) return tdata @@ -532,6 +532,7 @@ def decorator_task( properties: Optional[List[Tuple[str, str]]] = None, inputs: Optional[List[Tuple[str, str]]] = None, outputs: Optional[List[Tuple[str, str]]] = None, + error_handlers: Optional[List[Dict[str, Any]]] = None, catalog: str = "Others", ) -> Callable: """Generate a decorator that register a function as a task. @@ -566,6 +567,7 @@ def decorator(func): task_type, ) task = create_task(tdata) + task._error_handlers = error_handlers func.identifier = identifier func.task = func.node = task func.tdata = tdata @@ -678,6 +680,7 @@ def decorator(func): task_decorated, _ = build_pythonjob_task(func) func.identifier = "PythonJob" func.task = func.node = task_decorated + task_decorated._error_handlers = kwargs.get("error_handlers", []) return func diff --git a/aiida_workgraph/engine/utils.py b/aiida_workgraph/engine/utils.py index d2c60723..34a6f7c4 100644 --- a/aiida_workgraph/engine/utils.py +++ b/aiida_workgraph/engine/utils.py @@ -92,11 +92,11 @@ def prepare_for_python_task(task: dict, kwargs: dict, var_kwargs: dict) -> dict: metadata = kwargs.pop("metadata", {}) metadata.update({"call_link_label": task["name"]}) # get the source code of the function - function_name = task["executor"]["function_name"] + function_name = task["executor"]["name"] function_source_code = ( task["executor"]["import_statements"] + "\n" - + task["executor"]["function_source_code_without_decorator"] + + task["executor"]["source_code_without_decorator"] ) # outputs function_outputs = [ diff --git a/aiida_workgraph/engine/workgraph.py b/aiida_workgraph/engine/workgraph.py index 188d52b0..9c5728d4 100644 --- a/aiida_workgraph/engine/workgraph.py +++ b/aiida_workgraph/engine/workgraph.py @@ -498,14 +498,13 @@ def setup(self) -> None: def setup_ctx_workgraph(self, wgdata: t.Dict[str, t.Any]) -> None: """setup the workgraph in the context.""" - import cloudpickle as pickle self.ctx._tasks = wgdata["tasks"] self.ctx._links = wgdata["links"] self.ctx._connectivity = wgdata["connectivity"] self.ctx._ctrl_links = wgdata["ctrl_links"] self.ctx._workgraph = wgdata - self.ctx._error_handlers = pickle.loads(wgdata["error_handlers"]) + self.ctx._error_handlers = wgdata["error_handlers"] def read_wgdata_from_base(self) -> t.Dict[str, t.Any]: """Read workgraph data from base.extras.""" @@ -521,6 +520,7 @@ def read_wgdata_from_base(self) -> t.Dict[str, t.Any]: if isinstance(prop["value"], PickledLocalFunction): prop["value"] = prop["value"].value wgdata["error_handlers"] = deserialize_unsafe(wgdata["error_handlers"]) + print("error_handlers:", wgdata["error_handlers"]) wgdata["context"] = deserialize_unsafe(wgdata["context"]) return wgdata @@ -860,34 +860,45 @@ def are_childen_finished(self, name: str) -> tuple[bool, t.Any]: def run_error_handlers(self, task_name: str) -> None: """Run error handler.""" - from inspect import signature node = self.get_task_state_info(task_name, "process") if not node or not node.exit_status: return + # error_handlers from the task + for _, data in self.ctx._tasks[task_name]["error_handlers"].items(): + if node.exit_status in data.get("exit_codes", []): + handler = data["handler"] + self.run_error_handler(handler, data, task_name) + # error_handlers from the workgraph for _, data in self.ctx._error_handlers.items(): - if task_name in data["tasks"]: + if node.exit_code.status in data["tasks"].get(task_name, {}).get( + "exit_codes", [] + ): handler = data["handler"] - handler_sig = signature(handler) metadata = data["tasks"][task_name] - if node.exit_code.status in metadata.get("exit_codes", []): - self.report(f"Run error handler: {metadata}") - metadata.setdefault("retry", 0) - if metadata["retry"] < metadata["max_retries"]: - task = self.get_task(task_name) - try: - if "engine" in handler_sig.parameters: - msg = handler( - task, engine=self, **metadata.get("kwargs", {}) - ) - else: - msg = handler(task, **metadata.get("kwargs", {})) - self.update_task(task) - if msg: - self.report(msg) - metadata["retry"] += 1 - except Exception as e: - self.report(f"Error in running error handler: {e}") + self.run_error_handler(handler, metadata, task_name) + + def run_error_handler(self, handler: dict, metadata: dict, task_name: str) -> None: + from inspect import signature + from aiida_workgraph.utils import get_executor + + handler, _ = get_executor(handler) + handler_sig = signature(handler) + metadata.setdefault("retry", 0) + self.report(f"Run error handler: {handler.__name__}") + if metadata["retry"] < metadata["max_retries"]: + task = self.get_task(task_name) + try: + if "engine" in handler_sig.parameters: + msg = handler(task, engine=self, **metadata.get("kwargs", {})) + else: + msg = handler(task, **metadata.get("kwargs", {})) + self.update_task(task) + if msg: + self.report(msg) + metadata["retry"] += 1 + except Exception as e: + self.report(f"Error in running error handler: {e}") def is_workgraph_finished(self) -> bool: """Check if the workgraph is finished. diff --git a/aiida_workgraph/orm/function_data.py b/aiida_workgraph/orm/function_data.py index 5e5da4d0..397e52f7 100644 --- a/aiida_workgraph/orm/function_data.py +++ b/aiida_workgraph/orm/function_data.py @@ -27,18 +27,18 @@ def __str__(self): def metadata(self): """Return a dictionary of metadata.""" return { - "function_name": self.base.attributes.get("function_name"), + "name": self.base.attributes.get("name"), "import_statements": self.base.attributes.get("import_statements"), - "function_source_code": self.base.attributes.get("function_source_code"), - "function_source_code_without_decorator": self.base.attributes.get( - "function_source_code_without_decorator" + "source_code": self.base.attributes.get("source_code"), + "source_code_without_decorator": self.base.attributes.get( + "source_code_without_decorator" ), "type": "function", "is_pickle": True, } @classmethod - def build_executor(cls, func): + def build_callable(cls, func): """Return the executor for this node.""" import cloudpickle as pickle @@ -59,16 +59,14 @@ def set_attribute(self, value): serialized_data = self.inspect_function(value) # Store relevant metadata - self.base.attributes.set("function_name", serialized_data["function_name"]) + self.base.attributes.set("name", serialized_data["name"]) self.base.attributes.set( "import_statements", serialized_data["import_statements"] ) + self.base.attributes.set("source_code", serialized_data["source_code"]) self.base.attributes.set( - "function_source_code", serialized_data["function_source_code"] - ) - self.base.attributes.set( - "function_source_code_without_decorator", - serialized_data["function_source_code_without_decorator"], + "source_code_without_decorator", + serialized_data["source_code_without_decorator"], ) @classmethod @@ -108,9 +106,9 @@ def inspect_function(cls, func: Callable) -> Dict[str, Any]: function_source_code_without_decorator = "" import_statements = "" return { - "function_name": func.__name__, - "function_source_code": function_source_code, - "function_source_code_without_decorator": function_source_code_without_decorator, + "name": func.__name__, + "source_code": function_source_code, + "source_code_without_decorator": function_source_code_without_decorator, "import_statements": import_statements, } diff --git a/aiida_workgraph/task.py b/aiida_workgraph/task.py index 5b8b761e..aa316da3 100644 --- a/aiida_workgraph/task.py +++ b/aiida_workgraph/task.py @@ -25,6 +25,7 @@ class Task(GraphNode): property_pool = property_pool socket_pool = socket_pool is_aiida_component = False + _error_handlers = None def __init__( self, @@ -68,6 +69,7 @@ def to_dict(self) -> Dict[str, Any]: tdata["process"] = serialize(self.process) if self.process else serialize(None) tdata["metadata"]["pk"] = self.process.pk if self.process else None tdata["metadata"]["is_aiida_component"] = self.is_aiida_component + tdata["error_handlers"] = self.get_error_handlers() return tdata @@ -123,6 +125,7 @@ def from_dict(cls, data: Dict[str, Any], task_pool: Optional[Any] = None) -> "Ta if process and isinstance(process, str): process = deserialize_unsafe(process) task.process = process + task._error_handlers = data.get("error_handlers", []) return task @@ -130,6 +133,27 @@ def reset(self) -> None: self.process = None self.state = "PLANNED" + @property + def error_handlers(self) -> list: + return self.get_error_handlers() + + def get_error_handlers(self) -> list: + """Get the error handler function for this task.""" + from aiida_workgraph.utils import build_callable + + if self._error_handlers is None: + return {} + + handlers = {} + if isinstance(self._error_handlers, dict): + for handler in self._error_handlers.values(): + handler["handler"] = build_callable(handler["handler"]) + elif isinstance(self._error_handlers, list): + for handler in self._error_handlers: + handler["handler"] = build_callable(handler["handler"]) + handlers[handler["handler"]["name"]] = handler + return handlers + def _repr_mimebundle_(self, *args: Any, **kwargs: Any) -> any: if self._widget is None: diff --git a/aiida_workgraph/utils/__init__.py b/aiida_workgraph/utils/__init__.py index 4de090f8..6bc09370 100644 --- a/aiida_workgraph/utils/__init__.py +++ b/aiida_workgraph/utils/__init__.py @@ -5,7 +5,7 @@ from aiida.engine.runners import Runner -def build_executor(obj: Callable) -> Dict[str, Any]: +def build_callable(obj: Callable) -> Dict[str, Any]: """ Build the executor data from the callable. This will either serialize the callable using cloudpickle if it's a local or lambda function, or store its module and name @@ -26,7 +26,7 @@ def build_executor(obj: Callable) -> Dict[str, Any]: # Check if callable is nested (contains dots in __qualname__ after the first segment) if obj.__module__ == "__main__" or "." in obj.__qualname__.split(".", 1)[-1]: # Local or nested callable, so pickle the callable - executor = PickledFunction.build_executor(obj) + executor = PickledFunction.build_callable(obj) else: # Global callable (function/class), store its module and name for reference executor = { @@ -34,6 +34,8 @@ def build_executor(obj: Callable) -> Dict[str, Any]: "name": obj.__name__, "is_pickle": False, } + elif isinstance(obj, PickledFunction) or isinstance(obj, dict): + executor = obj else: raise TypeError("Provided object is not a callable function or class.") return executor diff --git a/aiida_workgraph/workgraph.py b/aiida_workgraph/workgraph.py index e1620d4b..942eb61b 100644 --- a/aiida_workgraph/workgraph.py +++ b/aiida_workgraph/workgraph.py @@ -176,8 +176,7 @@ def save_to_base(self, wgdata: Dict[str, Any]) -> None: saver.save() def to_dict(self, store_nodes=False) -> Dict[str, Any]: - import cloudpickle as pickle - from aiida_workgraph.utils import store_nodes_recursely + from aiida_workgraph.utils import store_nodes_recursely, build_callable wgdata = super().to_dict() # save the sequence and context @@ -198,7 +197,12 @@ def to_dict(self, store_nodes=False) -> Dict[str, Any]: "max_number_jobs": self.max_number_jobs, } ) - wgdata["error_handlers"] = pickle.dumps(self.error_handlers) + # save error handlers + wgdata["error_handlers"] = {} + for name, error_handler in self.error_handlers.items(): + print("error_handler:", error_handler) + error_handler["handler"] = build_callable(error_handler["handler"]) + wgdata["error_handlers"][name] = error_handler wgdata["tasks"] = wgdata.pop("nodes") if store_nodes: store_nodes_recursely(wgdata)