From f4969eda1c9b16584812f6efb6415bb2bbb9a1b9 Mon Sep 17 00:00:00 2001 From: superstar54 Date: Mon, 27 Nov 2023 10:19:58 +0100 Subject: [PATCH] add `set_from_protocol` for node --- aiida_worktree/node.py | 9 +++++++++ aiida_worktree/utils.py | 10 ++++++++++ tests/test_protocol.py | 27 +++++++++++++++++++++++++++ 3 files changed, 46 insertions(+) create mode 100644 tests/test_protocol.py diff --git a/aiida_worktree/node.py b/aiida_worktree/node.py index 292809a0..a8d648fd 100644 --- a/aiida_worktree/node.py +++ b/aiida_worktree/node.py @@ -27,3 +27,12 @@ def to_dict(self): ndata["process"] = self.process.uuid if self.process else None return ndata + + def set_from_protocol(self, *args, **kwargs): + """For node support protocol, set the node from protocol data.""" + from aiida_worktree.utils import get_executor, get_dict_from_builder + + executor = get_executor(self.get_executor())[0] + builder = executor.get_builder_from_protocol(*args, **kwargs) + data = get_dict_from_builder(builder) + self.set(data) diff --git a/aiida_worktree/utils.py b/aiida_worktree/utils.py index 9e23f0fd..bedc91dc 100644 --- a/aiida_worktree/utils.py +++ b/aiida_worktree/utils.py @@ -128,6 +128,16 @@ def build_node_link(ntdata): from_socket["links"].append(link) +def get_dict_from_builder(builder): + """Transform builder to pure dict.""" + from aiida.engine.processes.builder import ProcessBuilderNamespace + + if isinstance(builder, ProcessBuilderNamespace): + return {k: get_dict_from_builder(v) for k, v in builder.items()} + else: + return builder + + if __name__ == "__main__": d = { "base": { diff --git a/tests/test_protocol.py b/tests/test_protocol.py new file mode 100644 index 00000000..809d7f7f --- /dev/null +++ b/tests/test_protocol.py @@ -0,0 +1,27 @@ +import aiida +import numpy as np + +aiida.load_profile() + + +def test_pw_relax_workchain(structure_si): + """Run simple calcfunction.""" + from aiida_worktree import build_node, WorkTree + from aiida import orm + + # register node + pw_relax_node = build_node( + {"path": "aiida_quantumespresso.workflows.pw.relax.PwRelaxWorkChain"} + ) + code = orm.load_code("pw-7.2@localhost") + wt = WorkTree("test_pw_relax") + pw_relax1 = wt.nodes.new(pw_relax_node, name="pw_relax1") + pw_relax1.set_from_protocol( + code, structure_si, protocol="fast", pseudo_family="SSSP/1.2/PBEsol/efficiency" + ) + wt.submit(wait=True, timeout=200) + assert wt.state == "FINISHED" + # print(wt.nodes["pw_relax1"].node.outputs.output_parameters["energy"]) + assert np.isclose( + wt.nodes["pw_relax1"].node.outputs.output_parameters["energy"], -308.46262827125 + )