From 6394819780a295a1cabde9e2e2b3cef5489fc3e3 Mon Sep 17 00:00:00 2001 From: Julian Geiger Date: Fri, 13 Sep 2024 15:28:25 +0200 Subject: [PATCH] Add `node_value` and `get_node_value` to `TaskSocket` --- aiida_workgraph/socket.py | 42 +++++++++++------------------------- tests/test_socket.py | 45 +++++++++++++++++---------------------- 2 files changed, 32 insertions(+), 55 deletions(-) diff --git a/aiida_workgraph/socket.py b/aiida_workgraph/socket.py index b594e6e6..cdf7a015 100644 --- a/aiida_workgraph/socket.py +++ b/aiida_workgraph/socket.py @@ -1,5 +1,8 @@ from typing import Any, Type + +from aiida import orm from node_graph.socket import NodeSocket + from aiida_workgraph.property import TaskProperty @@ -12,39 +15,20 @@ class TaskSocket(NodeSocket): @property def node_value(self): - if not hasattr(self, '_node_value'): - self._node_value = self.get_node_value() - return self._node_value + return self.get_node_value() def get_node_value(self): - "Directly return or set the _node_value attribute." - - # _node_value was set before already - if hasattr(self, '_node_value'): - pass - - # If data associated with Socket is AiiDA ORM, return again its value - # Check for the nested case before, otherwise Socket value is matched - # first and the ORM instance returned - elif hasattr(self.value, "value"): - # TODO: One could also check for isinstance of AiiDA ORM, however, that adds another import, and not sure if - # TODO: it's really necessary here - - self._node_value = self.value.value - # TODO: Possibly check here for AttributeDict, for which we return the `get_dict` result directly - # TODO: However, with a specific `SocketAiiDADict` class, we could just overwrite the method there - - # If not, e.g. when Python base types are used, directly return the variable's value - elif hasattr(self, "value"): - self._node_value = self.value - - # TODO: This shouldn't even be a case? Shouldn't a `Socket` _always_ have an associated value, even if it's - # TODO: `None` on instantiation + """Obtain the actual Python `value` of the object attached to the Socket.""" + if isinstance(self.value, orm.Data): + if hasattr(self.value, "value"): + return self.value.value + else: + raise ValueError( + "Data node does not have a value attribute. We do not know how to extract the raw Python value." + ) else: - msg = f"Socket <{self}> does not have an asociated `value` attribute." - raise AttributeError(msg) + return self.value - return self._node_value def build_socket_from_AiiDA(DataClass: Type[Any]) -> Type[TaskSocket]: """Create a socket class from AiiDA DataClass.""" diff --git a/tests/test_socket.py b/tests/test_socket.py index e3d39765..7abfb8d7 100644 --- a/tests/test_socket.py +++ b/tests/test_socket.py @@ -94,22 +94,20 @@ def test_numpy_array(decorated_normal_add): @pytest.mark.parametrize( "data_type, socket_value, node_value", -( - (None, None, None), - # Check that SocketAny works for int node, without providing type hint - (None, 1, 1), - (int, 1, 1), - (float, 1., 1.), - (bool, True, True), - (str, "abc", "abc"), - # TODO: Wanted to instantiate the AiiDA ORM classes here, but that raises an - # TODO: aiida.common.exceptions.ConfigurationError, due to profile not being loaded - # TODO: which also isn't resolved by: `@pytest.mark.usefixtures("aiida_profile")` - (orm.Int, 1, 1), - (orm.Float, 1., 1.), - (orm.Str, 'abc', 'abc'), - (orm.Bool, True, True), -)) + ( + (None, None, None), + # Check that SocketAny works for int node, without providing type hint + (None, 1, 1), + (int, 1, 1), + (float, 1.0, 1.0), + (bool, True, True), + (str, "abc", "abc"), + (orm.Int, 1, 1), + (orm.Float, 1.0, 1.0), + (orm.Str, "abc", "abc"), + (orm.Bool, True, True), + ), +) def test_node_value(data_type, socket_value, node_value): wg = WorkGraph() @@ -118,18 +116,13 @@ def my_task(x: data_type): pass my_task1 = wg.add_task(my_task, name="my_task", x=socket_value) - socket = my_task1.inputs['x'] + socket = my_task1.inputs["x"] - # Private attribute is undefined (and shouldn't be called anyway) - with pytest.raises(AttributeError): - socket._node_value + socket_node_value = socket.get_node_value() + assert isinstance(socket_node_value, type(node_value)) + assert socket_node_value == node_value - # This should call the `get_node_value` method + # Check that property also returns the correct results socket_node_value = socket.node_value - assert isinstance(socket_node_value, type(node_value)) assert socket_node_value == node_value - - # Now the private attribute should be set, such that the `get_node_value` method doesn't have to be called again - assert isinstance(socket_node_value, type(socket._node_value)) - assert socket_node_value == socket._node_value