Skip to content

Commit

Permalink
Use node graph (#3)
Browse files Browse the repository at this point in the history
- use `node-graph` package instead of `scinode`.
- add `WorkTree.load()` to load a WorkTree from the AiiDA `WorkTree` process.
- support continuing a finished worktree.
  • Loading branch information
superstar54 authored Oct 9, 2023
1 parent b0a64c5 commit bb147c7
Show file tree
Hide file tree
Showing 27 changed files with 681 additions and 124 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/ci.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ jobs:
- name: Set up Python
uses: actions/setup-python@v2
with:
python-version: '3.8'
python-version: '3.10'

- name: Install Python dependencies
run: pip install -e .[pre-commit,tests]
Expand Down
3 changes: 2 additions & 1 deletion aiida_worktree/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
from .decorator import node, build_node
from .worktree import WorkTree
from .node import Node
from .decorator import node, build_node

__version__ = "0.0.1"
37 changes: 27 additions & 10 deletions aiida_worktree/decorator.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from typing import Any
from scinode.utils.node import get_executor
from aiida_worktree.utils import get_executor
from aiida.engine.processes.functions import calcfunction, workfunction
from aiida.engine.processes.calcjobs import CalcJob
from aiida.engine.processes.workchains import WorkChain
Expand Down Expand Up @@ -27,7 +27,9 @@ def add_input_recursive(inputs, port, prefix=None):

def build_node(ndata):
"""Register a node from a AiiDA component."""
from scinode.utils.decorator import create_node
from node_graph.decorator import create_node
from aiida_worktree.node import Node
import cloudpickle as pickle

path, executor_name, = ndata.pop(
"path"
Expand All @@ -44,16 +46,26 @@ def build_node(ndata):
inputs = []
outputs = []
spec = executor.spec()
for key, port in spec.inputs.ports.items():
for _key, port in spec.inputs.ports.items():
add_input_recursive(inputs, port)
kwargs = [input[1] for input in inputs]
for key, port in spec.outputs.ports.items():
for _key, port in spec.outputs.ports.items():
outputs.append(["General", port.name])
# print("kwargs: ", kwargs)
ndata["node_class"] = Node
ndata["kwargs"] = kwargs
ndata["inputs"] = inputs
ndata["outputs"] = outputs
ndata["identifier"] = ndata.pop("identifier", ndata["executor"]["name"])
# TODO In order to reload the WorkTree from process, "is_pickle" should be True
# so I pickled the function here, but this is not necessary
# we need to update the node_graph to support the path and name of the function
executor = {
"executor": pickle.dumps(executor),
"type": "function",
"is_pickle": True,
}
ndata["executor"] = executor
node = create_node(ndata)
return node

Expand All @@ -68,7 +80,7 @@ def decorator_node(
catalog="Others",
executor_type="function",
):
"""Generate a decorator that register a function as a SciNode node.
"""Generate a decorator that register a function as a node.
Attributes:
indentifier (str): node identifier
Expand All @@ -79,13 +91,15 @@ def decorator_node(
inputs (list): node inputs
outputs (list): node outputs
"""
from aiida_worktree.node import Node

properties = properties or []
inputs = inputs or []
outputs = outputs or [["General", "result"]]

def decorator(func):
import cloudpickle as pickle
from scinode.utils.decorator import generate_input_sockets, create_node
from node_graph.decorator import generate_input_sockets, create_node

nonlocal identifier

Expand All @@ -94,10 +108,9 @@ def decorator(func):
# use cloudpickle to serialize function
executor = {
"executor": pickle.dumps(func),
"type": executor_type,
"type": "function",
"is_pickle": True,
}
#
# Get the args and kwargs of the function
args, kwargs, var_args, var_kwargs, _inputs = generate_input_sockets(
func, inputs, properties
Expand All @@ -113,6 +126,7 @@ def decorator(func):
else:
node_type = "Normal"
ndata = {
"node_class": Node,
"identifier": identifier,
"node_type": node_type,
"args": args,
Expand Down Expand Up @@ -151,13 +165,15 @@ def decorator_node_group(
inputs (list): node inputs
outputs (list): node outputs
"""
from aiida_worktree.node import Node

properties = properties or []
inputs = inputs or []
outputs = outputs or []

def decorator(func):
import cloudpickle as pickle
from scinode.utils.decorator import generate_input_sockets, create_node
from node_graph.decorator import generate_input_sockets, create_node

nonlocal identifier, inputs, outputs

Expand All @@ -179,11 +195,12 @@ def decorator(func):
# inputs = [[nt.nodes[input[0]].inputs[input[1]].identifier, input[2]] for input in group_inputs]
# outputs = [[nt.nodes[output[0]].outputs[output[1]].identifier, output[2]] for output in group_outputs]
# node_inputs = [["General", input[2]] for input in inputs]
node_outputs = [["General", output[2]] for output in outputs]
node_outputs = [["General", output[1]] for output in outputs]
# print(node_inputs, node_outputs)
#
node_type = "worktree"
ndata = {
"node_class": Node,
"identifier": identifier,
"args": args,
"kwargs": kwargs,
Expand Down
105 changes: 57 additions & 48 deletions aiida_worktree/engine/worktree.py
Original file line number Diff line number Diff line change
Expand Up @@ -265,7 +265,7 @@ def _do_step(self) -> t.Any:
result: t.Any = None

try:
self.launch_worktree()
self.run_worktree()
except _PropagateReturn as exception:
finished, result = True, exception.exit_code
else:
Expand Down Expand Up @@ -367,7 +367,7 @@ def _on_awaitable_finished(self, awaitable: Awaitable) -> None:
self.resume()

def setup(self):
from scinode.utils.nt_analysis import ConnectivityAnalysis
from node_graph.analysis import ConnectivityAnalysis
from aiida_worktree.utils import build_node_link

self.ctx.new_data = dict()
Expand All @@ -392,25 +392,23 @@ def setup(self):
self.ctx.ctrl_links = ntdata["ctrl_links"]
self.ctx.worktree = ntdata
print("init")
# init
for _name, node in self.ctx.nodes.items():
node["state"] = "CREATED"
node["process"] = None
#
nc = ConnectivityAnalysis(ntdata)
self.ctx.connectivity = nc.build_connectivity()
self.ctx.msgs = []
self.node.set_process_label(f"WorkTree: {self.ctx.worktree['name']}")
# while worktree
if self.ctx.worktree["is_while"]:
if self.ctx.worktree["worktree_type"].upper() == "WHILE":
should_run = self.check_while_conditions()
if not should_run:
self.set_node_state(self.ctx.nodes.keys(), "SKIPPED")
# for worktree
if self.ctx.worktree["is_for"]:
if self.ctx.worktree["worktree_type"].upper() == "FOR":
should_run = self.check_for_conditions()
if not should_run:
self.set_node_state(self.ctx.nodes.keys(), "SKIPPED")
# init node results
self.set_node_results()

def init_ctx(self, datas):
from aiida_worktree.utils import update_nested_dict
Expand All @@ -420,13 +418,49 @@ def init_ctx(self, datas):
key = key.replace("__", ".")
update_nested_dict(self.ctx, key, value)

def launch_worktree(self):
print("launch_worktree: ")
def set_node_results(self):
for _, node in self.ctx.nodes.items():
if node.get("process"):
if isinstance(node["process"], str):
node["process"] = orm.load_node(node["process"])
self.set_node_result(node)
self.set_node_result(node)

def set_node_result(self, node):
name = node["name"]
print(f"set node result: {name}")
if node.get("process"):
print(f"set node result: {name} process")
state = node["process"].process_state.value.upper()
if state == "FINISHED":
node["state"] = state
if node["metadata"]["node_type"] == "worktree":
# expose the outputs of nodetree
node["results"] = getattr(
node["process"].outputs, "group_outputs", None
)
# self.ctx.new_data[name] = outputs
else:
node["results"] = node["process"].outputs
# self.ctx.new_data[name] = node["results"]
self.ctx.nodes[name]["state"] = "FINISHED"
self.node_to_ctx(name)
print(f"Node: {name} finished.")
elif state == "EXCEPTED":
node["state"] = state
node["results"] = node["process"].outputs
# self.ctx.new_data[name] = node["results"]
self.ctx.nodes[name]["state"] = "FAILED"
# set child state to FAILED
self.set_node_state(self.ctx.connectivity["child_node"][name], "FAILED")
print(f"Node: {name} failed.")
else:
print(f"set node result: None")
node["results"] = None

def run_worktree(self):
print("run_worktree: ")
self.report("Lanch worktree.")
if len(self.ctx.worktree["starts"]) > 0:
self.run_nodes(self.ctx.worktree["starts"])
self.ctx.worktree["starts"] = []
return
node_to_run = []
for name, node in self.ctx.nodes.items():
# update node state
Expand Down Expand Up @@ -458,40 +492,14 @@ def is_worktree_finished(self):
]
and node["state"] == "RUNNING"
):
if node.get("process"):
state = node["process"].process_state.value.upper()
print(node["name"], state)
if state == "FINISHED":
node["state"] = state
if node["metadata"]["node_type"] == "worktree":
# expose the outputs of nodetree
node["results"] = getattr(
node["process"].outputs, "group_outputs", None
)
# self.ctx.new_data[name] = outputs
else:
node["results"] = node["process"].outputs
# self.ctx.new_data[name] = node["results"]
self.ctx.nodes[name]["state"] = "FINISHED"
self.node_to_ctx(name)
print(f"Node: {name} finished.")
elif state == "EXCEPTED":
node["state"] = state
node["results"] = node["process"].outputs
# self.ctx.new_data[name] = node["results"]
self.ctx.nodes[name]["state"] = "FAILED"
# set child state to FAILED
self.set_node_state(
self.ctx.connectivity["child_node"][name], "FAILED"
)
print(f"Node: {name} failed.")
self.set_node_result(node)
if node["state"] in ["RUNNING", "CREATED", "READY"]:
is_finished = False
if is_finished:
if self.ctx.worktree["is_while"]:
if self.ctx.worktree["worktree_type"].upper() == "WHILE":
should_run = self.check_while_conditions()
is_finished = not should_run
if self.ctx.worktree["is_for"]:
if self.ctx.worktree["worktree_type"].upper() == "FOR":
should_run = self.check_for_conditions()
is_finished = not should_run
return is_finished
Expand Down Expand Up @@ -743,6 +751,7 @@ def update_ctx_variable(self, value):
def node_to_ctx(self, name):
from aiida_worktree.utils import update_nested_dict

print("node to ctx: ", name)
items = self.ctx.nodes[name]["to_ctx"]
for item in items:
update_nested_dict(
Expand Down Expand Up @@ -835,23 +844,23 @@ def finalize(self):
from aiida_worktree.utils import get_nested_dict

# expose group outputs
print("finalize")
group_outputs = {}
print("group outputs: ", self.ctx.worktree["metadata"]["group_outputs"])
for output in self.ctx.worktree["metadata"]["group_outputs"]:
print("output: ", output)
if output[0] == "ctx":
group_outputs[output[2]] = get_nested_dict(self.ctx, output[1])
node_name, socket_name = output[0].split(".")
if node_name == "ctx":
group_outputs[output[1]] = get_nested_dict(self.ctx, socket_name)
else:
group_outputs[output[2]] = self.ctx.nodes[output[0]]["results"][
group_outputs[output[1]] = self.ctx.nodes[node_name]["results"][
output[1]
]
self.out("group_outputs", group_outputs)
self.out("new_data", self.ctx.new_data)
self.report("Finalize")
print(f"Finalize worktree {self.ctx.worktree['name']}")
for name, node in self.ctx.nodes.items():
if node["state"] == "FAILED":
print(f" Node {name} failed.")
return self.exit_codes.NODE_FAILED
print(f"Finalize worktree {self.ctx.worktree['name']}\n")
# check if all nodes are finished with nonzero exit code
29 changes: 29 additions & 0 deletions aiida_worktree/node.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,29 @@
from node_graph.node import Node as GraphNode


class Node(GraphNode):
"""Represent a Node in the AiiDA WorkTree.
The class extends from node_graph.node.Node and add new
attributes to it.
"""

socket_entry = "aiida_worktree.socket"
property_entry = "aiida_worktree.property"

def __init__(self, **kwargs):
"""
Initialize a Node instance.
"""
super().__init__(**kwargs)
self.to_ctx = []
self.wait = []
self.process = None

def to_dict(self):
ndata = super().to_dict()
ndata["to_ctx"] = self.to_ctx
ndata["wait"] = self.wait
ndata["process"] = self.process.uuid if self.process else None

return ndata
2 changes: 1 addition & 1 deletion aiida_worktree/nodes/builtin.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from scinode.core.node import Node
from aiida_worktree.node import Node
from aiida_worktree.executors.builtin import GatherWorkChain


Expand Down
9 changes: 5 additions & 4 deletions aiida_worktree/nodes/qe.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
from scinode.core.node import Node
from aiida_worktree.node import Node
from aiida import orm


class AiiDAKpoint(Node):
Expand All @@ -10,8 +11,8 @@ class AiiDAKpoint(Node):
kwargs = ["mesh", "offset"]

def create_properties(self):
self.properties.new("IntVector", "mesh", default=[1, 1, 1], size=3)
self.properties.new("IntVector", "offset", default=[0, 0, 0], size=3)
self.properties.new("AiiDAIntVector", "mesh", default=[1, 1, 1], size=3)
self.properties.new("AiiDAIntVector", "offset", default=[0, 0, 0], size=3)

def create_sockets(self):
self.outputs.new("General", "Kpoint")
Expand Down Expand Up @@ -59,7 +60,7 @@ class AiiDAPWPseudo(Node):

def create_properties(self):
self.properties.new(
"String", "psuedo_familay", default="SSSP/1.2/PBEsol/efficiency"
"AiiDAString", "psuedo_familay", default="SSSP/1.2/PBEsol/efficiency"
)

def create_sockets(self):
Expand Down
Loading

0 comments on commit bb147c7

Please sign in to comment.