From 28d10f50ee69018a29528c0a9142a20eeb221350 Mon Sep 17 00:00:00 2001 From: Peter N O Gillespie Date: Wed, 10 Jul 2024 14:59:01 +0000 Subject: [PATCH] `XspectraCrystalWorkChain`: Correct Error Handling Changes requested for PR #1028 (2) Corrects error handling behaviour for `validate_inputs` to use `return` rather than `raise`. Also fixes minor formatting errors in `symmetry_data` `help` entry and removes unnecessary uses of `+`. --- .../workflows/xspectra/crystal.py | 70 +++++++++---------- 1 file changed, 33 insertions(+), 37 deletions(-) diff --git a/src/aiida_quantumespresso/workflows/xspectra/crystal.py b/src/aiida_quantumespresso/workflows/xspectra/crystal.py index d440afff..76c866fc 100644 --- a/src/aiida_quantumespresso/workflows/xspectra/crystal.py +++ b/src/aiida_quantumespresso/workflows/xspectra/crystal.py @@ -4,7 +4,7 @@ Uses QuantumESPRESSO pw.x and xspectra.x. """ from aiida import orm -from aiida.common import AttributeDict, ValidationError +from aiida.common import AttributeDict from aiida.engine import ToContext, WorkChain, if_ from aiida.orm import UpfData as aiida_core_upf from aiida.plugins import CalculationFactory, DataFactory, WorkflowFactory @@ -180,10 +180,10 @@ def define(cls, spec): required=False, help=( 'Input namespace to define equivalent sites and spacegroup number for the system. If defined, will ' - + 'skip symmetry analysis and structure standardization. Use *only* if symmetry data are known' - + 'for certain. Requires ``spacegroup_number`` (Int) and ``equivalent_sites_data`` (Dict) to be' - + 'defined separately. All keys in `equivalent_sites_data` must be formatted as "site_".' - + 'See docstring of `get_xspectra_structures` for more information about inputs.' + 'skip symmetry analysis and structure standardization. Use *only* if symmetry data are known ' + 'for certain. Requires ``spacegroup_number`` (Int) and ``equivalent_sites_data`` (Dict) to be ' + 'defined separately. All keys in `equivalent_sites_data` must be formatted as "site_". ' + 'See docstring of `get_xspectra_structures` for more information about inputs.' ) ) spec.inputs.validator = cls.validate_inputs @@ -382,9 +382,8 @@ def get_builder_from_protocol( # pylint: disable=too-many-statements return builder - # pylint: disable=too-many-statements @staticmethod - def validate_inputs(inputs, _): + def validate_inputs(inputs, _): # pylint: disable=too-many-return-statements """Validate the inputs before launching the WorkChain.""" structure = inputs['structure'] kinds_present = [kind.name for kind in structure.kinds] @@ -396,58 +395,57 @@ def validate_inputs(inputs, _): if element not in elements_present: extra_elements.append(element) if len(extra_elements) > 0: - raise ValidationError( + return ( f'Some elements in ``elements_list`` {extra_elements} do not exist in the' f' structure provided {elements_present}.' ) abs_atom_marker = inputs['abs_atom_marker'].value if abs_atom_marker in kinds_present: - raise ValidationError( + return ( f'The marker given for the absorbing atom ("{abs_atom_marker}") matches an existing Kind in the ' f'input structure ({kinds_present}).' ) if not inputs['core']['get_powder_spectrum'].value: - raise ValidationError( + return ( 'The ``get_powder_spectrum`` input for the XspectraCoreWorkChain namespace must be ``True``.' ) if 'upf2plotcore_code' not in inputs and 'core_wfc_data' not in inputs: - raise ValidationError( + return ( 'Neither a ``Code`` node for upf2plotcore.sh or a set of ``core_wfc_data`` were provided.' ) if 'core_wfc_data' in inputs: core_wfc_data_list = sorted(inputs['core_wfc_data'].keys()) if core_wfc_data_list != absorbing_elements_list: - raise ValidationError( + return ( f'The ``core_wfc_data`` provided ({core_wfc_data_list}) does not match the list of' f' absorbing elements ({absorbing_elements_list})' ) - else: - empty_core_wfc_data = [] - for key, value in inputs['core_wfc_data'].items(): - header_line = value.get_content()[:40] - try: - num_core_states = int(header_line.split(' ')[5]) - except Exception as exc: - raise ValidationError( - 'The core wavefunction data file is not of the correct format' - ) from exc - if num_core_states == 0: - empty_core_wfc_data.append(key) - if len(empty_core_wfc_data) > 0: - raise ValidationError( - f'The ``core_wfc_data`` provided for elements {empty_core_wfc_data} do not contain ' - 'any wavefunction data.' - ) + empty_core_wfc_data = [] + for key, value in inputs['core_wfc_data'].items(): + header_line = value.get_content()[:40] + try: + num_core_states = int(header_line.split(' ')[5]) + except: # pylint: disable=bare-except + return ( + 'The core wavefunction data file is not of the correct format' + ) # pylint: enable=bare-except + if num_core_states == 0: + empty_core_wfc_data.append(key) + if len(empty_core_wfc_data) > 0: + return ( + f'The ``core_wfc_data`` provided for elements {empty_core_wfc_data} do not contain ' + 'any wavefunction data.' + ) if 'symmetry_data' in inputs: spacegroup_number = inputs['symmetry_data']['spacegroup_number'].value equivalent_sites_data = inputs['symmetry_data']['equivalent_sites_data'].get_dict() if spacegroup_number <= 0 or spacegroup_number >= 231: - raise ValidationError( + return ( f'Input spacegroup number ({spacegroup_number}) outside of valid range (1-230).' ) @@ -466,25 +464,23 @@ def validate_inputs(inputs, _): elif value['symbol'] not in input_elements: input_elements.append(value['symbol']) if value['site_index'] < 0 or value['site_index'] >= len(structure.sites): - raise ValidationError( + return ( f'The site index for {site_label} ({value["site_index"]}) is outside the range of ' + f'sites within the structure (0-{len(structure.sites) -1}).' ) if len(invalid_entries) != 0: - raise ValidationError( + return ( f'The required keys ({required_keys}) were not found in the following entries: {invalid_entries}' ) sorted_input_elements = sorted(input_elements) if sorted_input_elements != absorbing_elements_list: - raise ValidationError( - f'Elements defined for sites in `equivalent_sites_data` ({sorted_input_elements}) do not match the' - + f' list of absorbing elements ({absorbing_elements_list})' - ) + return (f'Elements defined for sites in `equivalent_sites_data` ({sorted_input_elements}) ' + f'do not match the list of absorbing elements ({absorbing_elements_list})') - # pylint: enable=too-many-statements + # pylint: enable=too-many-return-statements def setup(self): """Set required context variables.""" if 'core_wfc_data' in self.inputs.keys():