Skip to content

Commit

Permalink
Add node_value & get_node_value to TaskSocket (#299)
Browse files Browse the repository at this point in the history
Add `node_value` & `get_node_value` to `TaskSocket`. This is basically just syntactic sugar to directly access the associated
raw Python value of a `Socket` which has a orm.Data as its value. Therefore, we can avoid, e.g., `.value.value`.
  • Loading branch information
GeigerJ2 authored and agoscinski committed Sep 19, 2024
1 parent 9a3dc11 commit f47ea1f
Show file tree
Hide file tree
Showing 2 changed files with 55 additions and 0 deletions.
19 changes: 19 additions & 0 deletions aiida_workgraph/socket.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,8 @@
from typing import Any, Type

from aiida import orm
from node_graph.socket import NodeSocket

from aiida_workgraph.property import TaskProperty


Expand All @@ -10,6 +13,22 @@ class TaskSocket(NodeSocket):
# to override the default NodeProperty from node_graph
node_property = TaskProperty

@property
def node_value(self):
return self.get_node_value()

def get_node_value(self):
"""Obtain the actual Python `value` of the object attached to the Socket."""
if isinstance(self.value, orm.Data):
if hasattr(self.value, "value"):
return self.value.value
else:
raise ValueError(
"Data node does not have a value attribute. We do not know how to extract the raw Python value."
)
else:
return self.value


def build_socket_from_AiiDA(DataClass: Type[Any]) -> Type[TaskSocket]:
"""Create a socket class from AiiDA DataClass."""
Expand Down
36 changes: 36 additions & 0 deletions tests/test_socket.py
Original file line number Diff line number Diff line change
Expand Up @@ -87,3 +87,39 @@ def test(a, b=1, **kwargs):
test1 = test.node()
assert test1.inputs["kwargs"].link_limit == 1e6
assert test1.inputs["kwargs"].identifier == "workgraph.namespace"


@pytest.mark.parametrize(
"data_type, socket_value, node_value",
(
(None, None, None),
# Check that SocketAny works for int node, without providing type hint
(None, 1, 1),
(int, 1, 1),
(float, 1.0, 1.0),
(bool, True, True),
(str, "abc", "abc"),
(orm.Int, 1, 1),
(orm.Float, 1.0, 1.0),
(orm.Str, "abc", "abc"),
(orm.Bool, True, True),
),
)
def test_node_value(data_type, socket_value, node_value):

wg = WorkGraph()

def my_task(x: data_type):
pass

my_task1 = wg.add_task(my_task, name="my_task", x=socket_value)
socket = my_task1.inputs["x"]

socket_node_value = socket.get_node_value()
assert isinstance(socket_node_value, type(node_value))
assert socket_node_value == node_value

# Check that property also returns the correct results
socket_node_value = socket.node_value
assert isinstance(socket_node_value, type(node_value))
assert socket_node_value == node_value

0 comments on commit f47ea1f

Please sign in to comment.