diff --git a/aiida_workgraph/socket.py b/aiida_workgraph/socket.py index 0211c4b9..f1b201a5 100644 --- a/aiida_workgraph/socket.py +++ b/aiida_workgraph/socket.py @@ -1,5 +1,8 @@ from typing import Any, Type + +from aiida import orm from node_graph.socket import NodeSocket + from aiida_workgraph.property import TaskProperty @@ -10,6 +13,22 @@ class TaskSocket(NodeSocket): # to override the default NodeProperty from node_graph node_property = TaskProperty + @property + def node_value(self): + return self.get_node_value() + + def get_node_value(self): + """Obtain the actual Python `value` of the object attached to the Socket.""" + if isinstance(self.value, orm.Data): + if hasattr(self.value, "value"): + return self.value.value + else: + raise ValueError( + "Data node does not have a value attribute. We do not know how to extract the raw Python value." + ) + else: + return self.value + def build_socket_from_AiiDA(DataClass: Type[Any]) -> Type[TaskSocket]: """Create a socket class from AiiDA DataClass.""" diff --git a/tests/test_socket.py b/tests/test_socket.py index 43f15fd1..dfcbbd20 100644 --- a/tests/test_socket.py +++ b/tests/test_socket.py @@ -87,3 +87,39 @@ def test(a, b=1, **kwargs): test1 = test.node() assert test1.inputs["kwargs"].link_limit == 1e6 assert test1.inputs["kwargs"].identifier == "workgraph.namespace" + + +@pytest.mark.parametrize( + "data_type, socket_value, node_value", + ( + (None, None, None), + # Check that SocketAny works for int node, without providing type hint + (None, 1, 1), + (int, 1, 1), + (float, 1.0, 1.0), + (bool, True, True), + (str, "abc", "abc"), + (orm.Int, 1, 1), + (orm.Float, 1.0, 1.0), + (orm.Str, "abc", "abc"), + (orm.Bool, True, True), + ), +) +def test_node_value(data_type, socket_value, node_value): + + wg = WorkGraph() + + def my_task(x: data_type): + pass + + my_task1 = wg.add_task(my_task, name="my_task", x=socket_value) + socket = my_task1.inputs["x"] + + socket_node_value = socket.get_node_value() + assert isinstance(socket_node_value, type(node_value)) + assert socket_node_value == node_value + + # Check that property also returns the correct results + socket_node_value = socket.node_value + assert isinstance(socket_node_value, type(node_value)) + assert socket_node_value == node_value