diff --git a/aiida_workgraph/decorator.py b/aiida_workgraph/decorator.py index c5b9268d..750df1ba 100644 --- a/aiida_workgraph/decorator.py +++ b/aiida_workgraph/decorator.py @@ -245,8 +245,8 @@ def build_task_from_AiiDA( if not outputs else outputs ) - # get the source code of the function - tdata["executor"] = PickledFunction(executor).executor + # build executor from the function + tdata["executor"] = PickledFunction.build_executor(executor) # tdata["executor"]["type"] = tdata["task_type"] # print("kwargs: ", kwargs) # add built-in sockets @@ -495,9 +495,9 @@ def generate_tdata( "properties": properties, "inputs": _inputs, "outputs": task_outputs, - "executor": PickledFunction(func).executor, "catalog": catalog, } + tdata["executor"] = PickledFunction.build_executor(func) if additional_data: tdata.update(additional_data) return tdata diff --git a/aiida_workgraph/orm/function_data.py b/aiida_workgraph/orm/function_data.py index c63bc212..5e5da4d0 100644 --- a/aiida_workgraph/orm/function_data.py +++ b/aiida_workgraph/orm/function_data.py @@ -37,14 +37,18 @@ def metadata(self): "is_pickle": True, } - @property - def executor(self): + @classmethod + def build_executor(cls, func): """Return the executor for this node.""" - data = self.metadata - with self.base.repository.open(self.FILENAME, mode="rb") as f: - executor = f.read() - data["executor"] = executor - return data + import cloudpickle as pickle + + executor = { + "executor": pickle.dumps(func), + "type": "function", + "is_pickle": True, + } + executor.update(cls.inspect_function(func)) + return executor def set_attribute(self, value): """Set the contents of this node by pickling the provided function. @@ -52,7 +56,7 @@ def set_attribute(self, value): :param value: The Python function to pickle and store. """ # Serialize the function and extract metadata - serialized_data = self.serialize_function(value) + serialized_data = self.inspect_function(value) # Store relevant metadata self.base.attributes.set("function_name", serialized_data["function_name"]) @@ -68,7 +72,7 @@ def set_attribute(self, value): ) @classmethod - def serialize_function(cls, func: Callable) -> Dict[str, Any]: + def inspect_function(cls, func: Callable) -> Dict[str, Any]: """Serialize a function for storage or transmission.""" try: # we need save the source code explicitly, because in the case of jupyter notebook, @@ -99,7 +103,7 @@ def serialize_function(cls, func: Callable) -> Dict[str, Any]: for module, types in required_imports.items() ) except Exception as e: - print(f"Failed to serialize function {func.__name__}: {e}") + print(f"Failed to inspect function {func.__name__}: {e}") function_source_code = "" function_source_code_without_decorator = "" import_statements = ""