Skip to content

Commit

Permalink
add set_from_protocol for node
Browse files Browse the repository at this point in the history
  • Loading branch information
superstar54 committed Nov 27, 2023
1 parent fca7717 commit f4969ed
Show file tree
Hide file tree
Showing 3 changed files with 46 additions and 0 deletions.
9 changes: 9 additions & 0 deletions aiida_worktree/node.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,3 +27,12 @@ def to_dict(self):
ndata["process"] = self.process.uuid if self.process else None

return ndata

def set_from_protocol(self, *args, **kwargs):
"""For node support protocol, set the node from protocol data."""
from aiida_worktree.utils import get_executor, get_dict_from_builder

executor = get_executor(self.get_executor())[0]
builder = executor.get_builder_from_protocol(*args, **kwargs)
data = get_dict_from_builder(builder)
self.set(data)
10 changes: 10 additions & 0 deletions aiida_worktree/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -128,6 +128,16 @@ def build_node_link(ntdata):
from_socket["links"].append(link)


def get_dict_from_builder(builder):
"""Transform builder to pure dict."""
from aiida.engine.processes.builder import ProcessBuilderNamespace

if isinstance(builder, ProcessBuilderNamespace):
return {k: get_dict_from_builder(v) for k, v in builder.items()}
else:
return builder


if __name__ == "__main__":
d = {
"base": {
Expand Down
27 changes: 27 additions & 0 deletions tests/test_protocol.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
import aiida
import numpy as np

aiida.load_profile()


def test_pw_relax_workchain(structure_si):
"""Run simple calcfunction."""
from aiida_worktree import build_node, WorkTree
from aiida import orm

# register node
pw_relax_node = build_node(
{"path": "aiida_quantumespresso.workflows.pw.relax.PwRelaxWorkChain"}
)
code = orm.load_code("pw-7.2@localhost")
wt = WorkTree("test_pw_relax")
pw_relax1 = wt.nodes.new(pw_relax_node, name="pw_relax1")
pw_relax1.set_from_protocol(
code, structure_si, protocol="fast", pseudo_family="SSSP/1.2/PBEsol/efficiency"
)
wt.submit(wait=True, timeout=200)
assert wt.state == "FINISHED"
# print(wt.nodes["pw_relax1"].node.outputs.output_parameters["energy"])
assert np.isclose(
wt.nodes["pw_relax1"].node.outputs.output_parameters["energy"], -308.46262827125
)

0 comments on commit f4969ed

Please sign in to comment.