Skip to content

Commit

Permalink
Refactor: Add the _prepare_yaml method to AbstractCode (#6565)
Browse files Browse the repository at this point in the history
As specified in the docs, the `Data` class implements an `export()` method. For actually exporting `Data` nodes, subclasses should implement a `_prepare_XXX` method, where `XXX` is the desired export file format. When running `verdi data <orm-data-type> export`, the available data formats for exporting are dynamically determined based on the implemented `_prepare_XXX` methods. The `Code` classes didn't follow this specification (likely because `Code` wasn't historically a subclass of `Data`), but instead a custom implementation was used for `verdi code export` in `cmd_code.py`.

With the goal of increasing consistency, this PR moves the code of this custom implementation to the new `_prepare_yaml` method of the `AbstractCode` class (`_prepare_yml` is also added, as the previous default file extension was `.yml`, but it simply calls `_prepare_yaml`). The `export` function in `cmd_code.py` now instead calls the `data_export` function, as is done when exporting other classes derived from `Data`. Thus, exporting a `Code` via the CLI through `verdi code export` remains unchanged.

Lastly, for the `PortableCode` class, the `_prepare_yaml` method is overridden to temporarily attach the `filepath_files` attribute to the instance, as this field is defined via the `pydantic` model, but not actually saved as an attribute (instead, the contents of the given folder are added to the `NodeRepository` via `put_object_from_tree`). Without temporarily attaching the attribute, this would otherwise lead to an exception in `test_data.py`, which tests the exporting of the derived `Data` nodes, because the export tries to access the `filepath_files` field on the instance via `getattr`. In addition, the files contained in the `NodeRepository` of the `PortableCode` are also dumped to disk, such that upon exporting the code configuration to a YAML file, the `PortableCode` _can_ actually be fully re-created from this file and the additional directory.
  • Loading branch information
GeigerJ2 authored Oct 7, 2024
1 parent c7c289d commit 98ffc33
Show file tree
Hide file tree
Showing 12 changed files with 177 additions and 60 deletions.
38 changes: 24 additions & 14 deletions src/aiida/cmdline/commands/cmd_code.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,12 +14,13 @@

import click

from aiida.cmdline.commands.cmd_data.cmd_export import data_export
from aiida.cmdline.commands.cmd_verdi import verdi
from aiida.cmdline.groups.dynamic import DynamicEntryPointCommandGroup
from aiida.cmdline.params import arguments, options, types
from aiida.cmdline.params.options.commands import code as options_code
from aiida.cmdline.utils import echo, echo_tabulate
from aiida.cmdline.utils.common import generate_validate_output_file
from aiida.cmdline.utils.common import validate_output_filename
from aiida.cmdline.utils.decorators import with_dbenv
from aiida.common import exceptions

Expand Down Expand Up @@ -241,28 +242,37 @@ def show(code):
@options.SORT()
@with_dbenv()
def export(code, output_file, overwrite, sort):
"""Export code to a yaml file."""
"""Export code to a yaml file. If no output file is given, default name is created based on the code label."""

import yaml
other_args = {'sort': sort}

code_data = {}
fileformat = 'yaml'

for key in code.Model.model_fields.keys():
value = getattr(code, key).label if key == 'computer' else getattr(code, key)

# If the attribute is not set, for example ``with_mpi`` do not export it, because the YAML won't be valid for
# use in ``verdi code create`` since ``None`` is not a valid value on the CLI.
if value is not None:
code_data[key] = str(value)
if output_file is None:
output_file = pathlib.Path(f'{code.full_label}.{fileformat}')

try:
output_file = generate_validate_output_file(
output_file=output_file, entity_label=code.label, overwrite=overwrite, appendix=f'@{code_data["computer"]}'
# In principle, output file validation is also done in the `data_export` function. However, the
# `validate_output_filename` function is applied here, as well, as it is also used in the `Computer` export, and
# as `Computer` is not derived from `Data`, it cannot be exported by `data_export`, so
# `validate_output_filename` cannot be removed in favor of the validation done in `data_export`.
validate_output_filename(
output_file=output_file,
overwrite=overwrite,
)
except (FileExistsError, IsADirectoryError) as exception:
raise click.BadParameter(str(exception), param_hint='OUTPUT_FILE') from exception

output_file.write_text(yaml.dump(code_data, sort_keys=sort))
try:
data_export(
node=code,
output_fname=output_file,
fileformat=fileformat,
other_args=other_args,
overwrite=overwrite,
)
except Exception as exception:
echo.echo_critical(f'Error in the `data_export` function: {exception}')

echo.echo_success(f'Code<{code.pk}> {code.label} exported to file `{output_file}`.')

Expand Down
14 changes: 7 additions & 7 deletions src/aiida/cmdline/commands/cmd_computer.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@
from aiida.cmdline.params import arguments, options
from aiida.cmdline.params.options.commands import computer as options_computer
from aiida.cmdline.utils import echo, echo_tabulate
from aiida.cmdline.utils.common import generate_validate_output_file
from aiida.cmdline.utils.common import validate_output_filename
from aiida.cmdline.utils.decorators import with_dbenv
from aiida.common.exceptions import EntryPointError, ValidationError
from aiida.plugins.entry_point import get_entry_point_names
Expand Down Expand Up @@ -766,10 +766,10 @@ def computer_export_setup(computer, output_file, overwrite, sort):
'append_text': computer.get_append_text(),
}

if output_file is None:
output_file = pathlib.Path(f'{computer.label}-setup.yaml')
try:
output_file = generate_validate_output_file(
output_file=output_file, entity_label=computer.label, overwrite=overwrite, appendix='-setup'
)
validate_output_filename(output_file=output_file, overwrite=overwrite)
except (FileExistsError, IsADirectoryError) as exception:
raise click.BadParameter(str(exception), param_hint='OUTPUT_FILE') from exception

Expand Down Expand Up @@ -804,10 +804,10 @@ def computer_export_config(computer, output_file, user, overwrite, sort):
' because computer has not been configured yet.'
)
else:
if output_file is None:
output_file = pathlib.Path(f'{computer.label}-config.yaml')
try:
output_file = generate_validate_output_file(
output_file=output_file, entity_label=computer.label, overwrite=overwrite, appendix='-config'
)
validate_output_filename(output_file=output_file, overwrite=overwrite)
except (FileExistsError, IsADirectoryError) as exception:
raise click.BadParameter(str(exception), param_hint='OUTPUT_FILE') from exception

Expand Down
14 changes: 8 additions & 6 deletions src/aiida/cmdline/utils/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -486,13 +486,17 @@ def build_entries(ports):
echo.echo(style('\nExit codes that invalidate the cache are marked in bold red.\n', italic=True))


def generate_validate_output_file(
output_file: Path | None, entity_label: str, appendix: str = '', overwrite: bool = False
def validate_output_filename(
output_file: Path | str,
overwrite: bool = False,
):
"""Generate default output filename for `Code`/`Computer` export and validate."""
"""Validate output filename."""

if output_file is None:
output_file = Path(f'{entity_label}{appendix}.yml')
raise TypeError('Output filename must be passed for validation.')

if isinstance(output_file, str):
output_file = Path(output_file)

if output_file.is_dir():
raise IsADirectoryError(
Expand All @@ -501,5 +505,3 @@ def generate_validate_output_file(

if output_file.is_file() and not overwrite:
raise FileExistsError(f'File `{output_file}` already exists, use `--overwrite` to overwrite.')

return output_file
21 changes: 21 additions & 0 deletions src/aiida/orm/nodes/data/code/abstract.py
Original file line number Diff line number Diff line change
Expand Up @@ -381,3 +381,24 @@ def get_builder(self) -> 'ProcessBuilder':
builder.code = self

return builder

def _prepare_yaml(self, *args, **kwargs):
    """Serialize the fields of this code into a YAML document.

    Every field declared on ``self.Model`` is read from the instance and stored as a string. For the
    ``computer`` field, the label of the computer is stored instead of the object itself.

    :param args: Positional arguments; not used by this implementation.
    :param kwargs: Only ``sort`` is read (default ``False``); it is passed to ``yaml.dump`` as ``sort_keys``.
    :return: Tuple of the YAML content, dumped with ``encoding='utf-8'``, and an empty dictionary.
    """
    import yaml

    sort_keys = kwargs.get('sort', False)
    exported = {}

    for field_name in self.Model.model_fields:
        field_value = getattr(self, field_name)
        if field_name == 'computer':
            field_value = field_value.label

        # Unset attributes (``with_mpi``, for example) are left out entirely, so that the resulting
        # YAML file does not contain any null values.
        if field_value is None:
            continue

        exported[field_name] = str(field_value)

    serialized = yaml.dump(exported, sort_keys=sort_keys, encoding='utf-8')
    return serialized, {}

def _prepare_yml(self, *args, **kwargs):
"""Also allow for export as .yml"""
return self._prepare_yaml(*args, **kwargs)
35 changes: 34 additions & 1 deletion src/aiida/orm/nodes/data/code/portable.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,8 @@

from __future__ import annotations

import contextlib
import logging
import pathlib
import typing as t

Expand All @@ -34,6 +36,7 @@
from .legacy import Code

__all__ = ('PortableCode',)
_LOGGER = logging.getLogger(__name__)


class PortableCode(Code):
Expand Down Expand Up @@ -71,7 +74,7 @@ def validate_filepath_files(cls, value: str) -> pathlib.Path:
raise ValueError(f'The filepath `{value}` is not a directory.')
return filepath

def __init__(self, filepath_executable: str, filepath_files: pathlib.Path, **kwargs):
def __init__(self, filepath_executable: str, filepath_files: pathlib.Path | str, **kwargs):
"""Construct a new instance.
.. note:: If the files necessary for this code are not all located in a single directory or the directory
Expand Down Expand Up @@ -177,3 +180,33 @@ def filepath_executable(self, value: str) -> None:
raise ValueError('The `filepath_executable` should not be absolute.')

self.base.attributes.set(self._KEY_ATTRIBUTE_FILEPATH_EXECUTABLE, value)

def _prepare_yaml(self, *args, **kwargs):
    """Export the code to YAML and collect the files stored in its repository.

    The target folder for the repository files is ``<current working directory>/<label>``. While the
    parent implementation runs, ``filepath_files`` is set on the instance to that folder; it is removed
    again before returning, also when an exception is raised.

    .. note:: As a side effect, the parent directories of all target file paths are created on disk.

    :param args: Positional arguments forwarded to the parent ``_prepare_yaml``.
    :param kwargs: Keyword arguments forwarded to the parent ``_prepare_yaml``.
    :return: Tuple of the YAML content returned by the parent implementation and a dictionary that maps
        each full target file path (as a string) to the content of the corresponding repository
        object, read in binary mode.
    """
    try:
        target = pathlib.Path.cwd() / f'{self.label}'

        # Temporarily expose ``filepath_files`` as an attribute so that the parent implementation can
        # read it from the instance.
        self.filepath_files = str(target)
        result = super()._prepare_yaml(*args, **kwargs)[0]

        extra_files = {}
        node_repository = self.base.repository

        # Logic taken from the `copy_tree` method of the `Repository` class and adapted to collect the
        # full target file paths and the binary file contents in the `extra_files` dictionary.
        for root, _dirnames, filenames in node_repository.walk():
            for filename in filenames:
                rel_output_file_path = root.relative_to('.') / filename
                full_output_file_path = target / rel_output_file_path
                full_output_file_path.parent.mkdir(exist_ok=True, parents=True)

                extra_files[str(full_output_file_path)] = node_repository.get_object_content(
                    str(rel_output_file_path), mode='rb'
                )
        _LOGGER.warning(f'Repository files for PortableCode <{self.pk}> dumped to folder `{target}`.')

    finally:
        # The attribute may not exist if an exception was raised before it was set.
        with contextlib.suppress(AttributeError):
            del self.filepath_files

    return result, extra_files
11 changes: 7 additions & 4 deletions tests/cmdline/commands/test_code.py
Original file line number Diff line number Diff line change
Expand Up @@ -345,7 +345,7 @@ def test_code_export_default_filename(run_cli_command, aiida_code_installed):
options = [str(code.pk)]
run_cli_command(cmd_code.export, options)

assert pathlib.Path('code@localhost.yml').is_file()
assert pathlib.Path('code@localhost.yaml').is_file()


@pytest.mark.parametrize('non_interactive_editor', ('vim -cwq',), indirect=True)
Expand Down Expand Up @@ -461,10 +461,13 @@ def test_code_setup_local_duplicate_full_label_interactive(run_cli_command, non_


@pytest.mark.usefixtures('aiida_profile_clean')
def test_code_setup_local_duplicate_full_label_non_interactive(run_cli_command):
def test_code_setup_local_duplicate_full_label_non_interactive(run_cli_command, tmp_path):
"""Test ``verdi code setup`` for a local code in non-interactive mode specifying an existing full label."""
label = 'some-label'
code = PortableCode(filepath_executable='bash', filepath_files=pathlib.Path('/bin/bash'))
tmp_bin_dir = tmp_path / 'bin'
tmp_bin_dir.mkdir()
(tmp_bin_dir / 'bash').touch()
code = PortableCode(filepath_executable='bash', filepath_files=tmp_bin_dir)
code.label = label
code.base.repository.put_object_from_filelike(io.BytesIO(b''), 'bash')
code.store()
Expand All @@ -477,7 +480,7 @@ def test_code_setup_local_duplicate_full_label_non_interactive(run_cli_command):
'-P',
'core.arithmetic.add',
'--store-in-db',
'--code-folder=/bin',
f'--code-folder={tmp_bin_dir}',
'--code-rel-path=bash',
'--label',
label,
Expand Down
18 changes: 9 additions & 9 deletions tests/cmdline/commands/test_computer.py
Original file line number Diff line number Diff line change
Expand Up @@ -525,7 +525,7 @@ def test_computer_export_setup(self, tmp_path, file_regression, sort_option):
comp = self.comp_builder.new()
comp.store()

exported_setup_filename = tmp_path / 'computer-setup.yml'
exported_setup_filename = tmp_path / 'computer-setup.yaml'

# Successfull write behavior
result = self.cli_runner(computer_export_setup, [comp.label, exported_setup_filename, sort_option])
Expand All @@ -534,9 +534,9 @@ def test_computer_export_setup(self, tmp_path, file_regression, sort_option):

# file regression check
content = exported_setup_filename.read_text()
file_regression.check(content, extension='.yml')
file_regression.check(content, extension='.yaml')

# verifying correctness by comparing internal and loaded yml object
# verifying correctness by comparing internal and loaded yaml object
configure_setup_data = yaml.safe_load(exported_setup_filename.read_text())
assert configure_setup_data == self.comp_builder.get_computer_spec(
comp
Expand All @@ -550,7 +550,7 @@ def test_computer_export_setup_overwrite(self, tmp_path):
comp = self.comp_builder.new()
comp.store()

exported_setup_filename = tmp_path / 'computer-setup.yml'
exported_setup_filename = tmp_path / 'computer-setup.yaml'
# Check that export fails if the file already exists
exported_setup_filename.touch()
result = self.cli_runner(computer_export_setup, [comp.label, exported_setup_filename], raises=True)
Expand Down Expand Up @@ -581,7 +581,7 @@ def test_computer_export_setup_default_filename(self):
comp = self.comp_builder.new()
comp.store()

exported_setup_filename = f'{comp_label}-setup.yml'
exported_setup_filename = f'{comp_label}-setup.yaml'

self.cli_runner(computer_export_setup, [comp.label])
assert pathlib.Path(exported_setup_filename).is_file()
Expand All @@ -593,7 +593,7 @@ def test_computer_export_config(self, tmp_path):
comp = self.comp_builder.new()
comp.store()

exported_config_filename = tmp_path / 'computer-configure.yml'
exported_config_filename = tmp_path / 'computer-configure.yaml'

# We have not configured the computer yet so it should exit with an critical error
result = self.cli_runner(computer_export_config, [comp.label, exported_config_filename], raises=True)
Expand All @@ -613,7 +613,7 @@ def test_computer_export_config(self, tmp_path):
content = exported_config_filename.read_text()
assert content.startswith('safe_interval: 0.0')

# verifying correctness by comparing internal and loaded yml object
# verifying correctness by comparing internal and loaded yaml object
configure_config_data = yaml.safe_load(exported_config_filename.read_text())
assert (
configure_config_data == comp.get_configuration()
Expand Down Expand Up @@ -641,7 +641,7 @@ def test_computer_export_config_overwrite(self, tmp_path):
comp.store()
comp.configure(safe_interval=0.0)

exported_config_filename = tmp_path / 'computer-configure.yml'
exported_config_filename = tmp_path / 'computer-configure.yaml'

# Create directory with the same name and check that command fails
exported_config_filename.mkdir()
Expand Down Expand Up @@ -673,7 +673,7 @@ def test_computer_export_config_default_filename(self):
comp.store()
comp.configure(safe_interval=0.0)

exported_config_filename = f'{comp_label}-config.yml'
exported_config_filename = f'{comp_label}-config.yaml'

self.cli_runner(computer_export_config, [comp.label])
assert pathlib.Path(exported_config_filename).is_file()
Expand Down
31 changes: 13 additions & 18 deletions tests/cmdline/utils/test_common.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@

import pytest
from aiida.cmdline.utils import common
from aiida.cmdline.utils.common import generate_validate_output_file
from aiida.cmdline.utils.common import validate_output_filename
from aiida.common import LinkType
from aiida.engine import Process, calcfunction
from aiida.orm import CalcFunctionNode, CalculationNode, WorkflowNode
Expand Down Expand Up @@ -95,35 +95,30 @@ def test_with_docstring():


@pytest.mark.usefixtures('chdir_tmp_path')
def test_generate_validate_output():
def test_validate_output_filename():
test_entity_label = 'test_code'
test_appendix = '@test_computer'
fileformat = 'yaml'

expected_output_file = Path(f'{test_entity_label}{test_appendix}.yml')
expected_output_file = Path(f'{test_entity_label}{test_appendix}.{fileformat}')

# Test default label creation
obtained_output_file = generate_validate_output_file(
output_file=None, entity_label=test_entity_label, appendix=test_appendix
)
assert expected_output_file == obtained_output_file, 'Filenames differ'
# Test failure if no actual file to be validated is passed
with pytest.raises(TypeError, match='.*passed for validation.'):
validate_output_filename(output_file=None)

# Test failure if file exists, but overwrite False
expected_output_file.touch()
with pytest.raises(FileExistsError, match='.*use `--overwrite` to overwrite.'):
generate_validate_output_file(
output_file=None, entity_label=test_entity_label, appendix=test_appendix, overwrite=False
)
validate_output_filename(output_file=expected_output_file, overwrite=False)

# Test that overwrite does the job
obtained_output_file = generate_validate_output_file(
output_file=None, entity_label=test_entity_label, appendix=test_appendix, overwrite=True
)
assert expected_output_file == obtained_output_file, 'Overwrite unsuccessful'
# Test that overwrite does the job -> No exception raised
validate_output_filename(output_file=expected_output_file, overwrite=True)
expected_output_file.unlink()

# Test failure if directory exists
expected_output_file.mkdir()
with pytest.raises(IsADirectoryError, match='A directory with the name.*'):
generate_validate_output_file(
output_file=None, entity_label=test_entity_label, appendix=test_appendix, overwrite=False
validate_output_filename(
output_file=expected_output_file,
overwrite=False,
)
Loading

0 comments on commit 98ffc33

Please sign in to comment.