Skip to content

Commit

Permalink
only serialize the pythonjob data when launching
Browse files Browse the repository at this point in the history
  • Loading branch information
superstar54 committed Sep 13, 2024
1 parent 10d25b6 commit 49df54e
Show file tree
Hide file tree
Showing 2 changed files with 20 additions and 42 deletions.
59 changes: 17 additions & 42 deletions aiida_workgraph/tasks/pythonjob.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,26 +13,27 @@ class PythonJob(Task):

def update_from_dict(self, data: Dict[str, Any], **kwargs) -> "PythonJob":
"""Overwrite the update_from_dict method to handle the PythonJob data."""
self.deserialize_pythonjob_data(data)
self.function_kwargs = data.get("function_kwargs", [])
self.deserialize_pythonjob_data(data)
super().update_from_dict(data)

def to_dict(self) -> Dict[str, Any]:
data = super().to_dict()
data["function_kwargs"] = self.function_kwargs
self.serialize_pythonjob_data(data)
return data

def serialize_pythonjob_data(self, tdata: Dict[str, Any]):
@classmethod
def serialize_pythonjob_data(cls, tdata: Dict[str, Any]):
"""Serialize the properties for PythonJob."""

input_kwargs = tdata.get("function_kwargs", [])
for name in input_kwargs:
tdata["inputs"][name]["property"]["value"] = self.serialize_socket_data(
tdata["inputs"][name]["property"]["value"] = cls.serialize_socket_data(
tdata["inputs"][name]
)

def deserialize_pythonjob_data(self, tdata: Dict[str, Any]) -> None:
@classmethod
def deserialize_pythonjob_data(cls, tdata: Dict[str, Any]) -> None:
"""
Process the task data dictionary for a PythonJob.
It load the orignal Python data from the AiiDA Data node for the
Expand All @@ -50,32 +51,17 @@ def deserialize_pythonjob_data(self, tdata: Dict[str, Any]) -> None:
if name in tdata["inputs"]:
tdata["inputs"][name]["property"][
"value"
] = self.deserialize_socket_data(tdata["inputs"][name])
] = cls.deserialize_socket_data(tdata["inputs"][name])

def find_input_socket(self, name):
"""Find the output with the given name."""
if name in self.inputs:
return self.inputs[name]
return None

def serialize_socket_data(self, data: Dict[str, Any]) -> Any:
name = data["name"]
@classmethod
def serialize_socket_data(cls, data: Dict[str, Any]) -> Any:
if data.get("identifier", "Any").upper() == "WORKGRAPH.NAMESPACE":
if data["property"]["value"] is None:
return None
if isinstance(data["property"]["value"], dict):
serialized_result = {}
for key, value in data["property"]["value"].items():
full_name = f"{name}.{key}"
full_name_output = self.find_input_socket(full_name)
if (
full_name_output
and full_name_output.get("identifier", "Any").upper()
== "WORKGRAPH.NAMESPACE"
):
serialized_result[key] = self.serialize_socket_data(
full_name_output
)
else:
serialized_result[key] = general_serializer(value)
serialized_result[key] = general_serializer(value)
return serialized_result
else:
raise ValueError("Namespace socket should be a dictionary.")
Expand All @@ -84,27 +70,16 @@ def serialize_socket_data(self, data: Dict[str, Any]) -> Any:
return data["property"]["value"]
return general_serializer(data["property"]["value"])

def deserialize_socket_data(self, data: Dict[str, Any]) -> Any:
name = data["name"]
@classmethod
def deserialize_socket_data(cls, data: Dict[str, Any]) -> Any:
if data.get("identifier", "Any").upper() == "WORKGRAPH.NAMESPACE":
if isinstance(data["property"]["value"], dict):
deserialized_result = {}
for key, value in data["property"]["value"].items():
full_name = f"{name}.{key}"
full_name_output = self.find_input_socket(full_name)
if (
full_name_output
and full_name_output.get("identifier", "Any").upper()
== "WORKGRAPH.NAMESPACE"
):
deserialized_result[key] = self.deserialize_socket_data(
full_name_output
)
if isinstance(value, orm.Data):
deserialized_result[key] = value.value
else:
if isinstance(value, orm.Data):
deserialized_result[key] = value.value
else:
deserialized_result[key] = value
deserialized_result[key] = value
return deserialized_result
else:
raise ValueError("Namespace socket should be a dictionary.")
Expand Down
3 changes: 3 additions & 0 deletions aiida_workgraph/utils/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -431,9 +431,12 @@ def serialize_properties(wgdata):
So, if a function is used as input, we needt to serialize the function.
"""
from aiida_workgraph.orm.function_data import PickledLocalFunction
from aiida_workgraph.tasks.pythonjob import PythonJob
import inspect

for _, task in wgdata["tasks"].items():
if task["metadata"]["node_type"].upper() == "PYTHONJOB":
PythonJob.serialize_pythonjob_data(task)
for _, input in task["inputs"].items():
if input["property"] is None:
continue
Expand Down

0 comments on commit 49df54e

Please sign in to comment.