Skip to content

Commit

Permalink
Add node_value and get_node_value to TaskSocket
Browse files Browse the repository at this point in the history
  • Loading branch information
GeigerJ2 committed Sep 13, 2024
1 parent ce6b75b commit 6394819
Show file tree
Hide file tree
Showing 2 changed files with 32 additions and 55 deletions.
42 changes: 13 additions & 29 deletions aiida_workgraph/socket.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,8 @@
from typing import Any, Type

from aiida import orm
from node_graph.socket import NodeSocket

from aiida_workgraph.property import TaskProperty


Expand All @@ -12,39 +15,20 @@ class TaskSocket(NodeSocket):

@property
def node_value(self):
if not hasattr(self, '_node_value'):
self._node_value = self.get_node_value()
return self._node_value
return self.get_node_value()

def get_node_value(self):
"Directly return or set the _node_value attribute."

# _node_value was set before already
if hasattr(self, '_node_value'):
pass

# If data associated with Socket is AiiDA ORM, return again its value
# Check for the nested case before, otherwise Socket value is matched
# first and the ORM instance returned
elif hasattr(self.value, "value"):
# TODO: One could also check for isinstance of AiiDA ORM, however, that adds another import, and not sure if
# TODO: it's really necessary here

self._node_value = self.value.value
# TODO: Possibly check here for AttributeDict, for which we return the `get_dict` result directly
# TODO: However, with a specific `SocketAiiDADict` class, we could just overwrite the method there

# If not, e.g. when Python base types are used, directly return the variable's value
elif hasattr(self, "value"):
self._node_value = self.value

# TODO: This shouldn't even be a case? Shouldn't a `Socket` _always_ have an associated value, even if it's
# TODO: `None` on instantiation
"""Obtain the actual Python `value` of the object attached to the Socket."""
if isinstance(self.value, orm.Data):
if hasattr(self.value, "value"):
return self.value.value
else:
raise ValueError(
"Data node does not have a value attribute. We do not know how to extract the raw Python value."
)
else:
msg = f"Socket <{self}> does not have an asociated `value` attribute."
raise AttributeError(msg)
return self.value

return self._node_value

def build_socket_from_AiiDA(DataClass: Type[Any]) -> Type[TaskSocket]:
"""Create a socket class from AiiDA DataClass."""
Expand Down
45 changes: 19 additions & 26 deletions tests/test_socket.py
Original file line number Diff line number Diff line change
Expand Up @@ -94,22 +94,20 @@ def test_numpy_array(decorated_normal_add):

@pytest.mark.parametrize(
"data_type, socket_value, node_value",
(
(None, None, None),
# Check that SocketAny works for int node, without providing type hint
(None, 1, 1),
(int, 1, 1),
(float, 1., 1.),
(bool, True, True),
(str, "abc", "abc"),
# TODO: Wanted to instantiate the AiiDA ORM classes here, but that raises an
# TODO: aiida.common.exceptions.ConfigurationError, due to profile not being loaded
# TODO: which also isn't resolved by: `@pytest.mark.usefixtures("aiida_profile")`
(orm.Int, 1, 1),
(orm.Float, 1., 1.),
(orm.Str, 'abc', 'abc'),
(orm.Bool, True, True),
))
(
(None, None, None),
# Check that SocketAny works for int node, without providing type hint
(None, 1, 1),
(int, 1, 1),
(float, 1.0, 1.0),
(bool, True, True),
(str, "abc", "abc"),
(orm.Int, 1, 1),
(orm.Float, 1.0, 1.0),
(orm.Str, "abc", "abc"),
(orm.Bool, True, True),
),
)
def test_node_value(data_type, socket_value, node_value):

wg = WorkGraph()
Expand All @@ -118,18 +116,13 @@ def my_task(x: data_type):
pass

my_task1 = wg.add_task(my_task, name="my_task", x=socket_value)
socket = my_task1.inputs['x']
socket = my_task1.inputs["x"]

# Private attribute is undefined (and shouldn't be called anyway)
with pytest.raises(AttributeError):
socket._node_value
socket_node_value = socket.get_node_value()
assert isinstance(socket_node_value, type(node_value))
assert socket_node_value == node_value

# This should call the `get_node_value` method
# Check that property also returns the correct results
socket_node_value = socket.node_value

assert isinstance(socket_node_value, type(node_value))
assert socket_node_value == node_value

# Now the private attribute should be set, such that the `get_node_value` method doesn't have to be called again
assert isinstance(socket_node_value, type(socket._node_value))
assert socket_node_value == socket._node_value

0 comments on commit 6394819

Please sign in to comment.