diff --git a/src/aiida_quantumespresso/workflows/functions/get_xspectra_structures.py b/src/aiida_quantumespresso/workflows/functions/get_xspectra_structures.py index 407e62916..7ae29d2ff 100644 --- a/src/aiida_quantumespresso/workflows/functions/get_xspectra_structures.py +++ b/src/aiida_quantumespresso/workflows/functions/get_xspectra_structures.py @@ -12,6 +12,8 @@ import numpy as np import spglib +from aiida_quantumespresso.utils.hubbard import HubbardStructureData, HubbardUtils + @calcfunction def get_xspectra_structures(structure, **kwargs): # pylint: disable=too-many-statements @@ -45,6 +47,11 @@ def get_xspectra_structures(structure, **kwargs): # pylint: disable=too-many-st a molecule and not a periodic solid system. Required in order to instruct the CF to use Pymatgen rather than spglib to determine the symmetry. The CF will assume the structure to be a periodic solid if no input is given. + - use_element_types: a Bool object to indicate that symmetry analysis should consider all + sites of the same element to be equal and ignore any special Kind names + from the parent structure. For instance, use_element_types = True would + consider sites for Kinds 'Si' and 'Si1' to be equivalent if both are sites + containing silicon. Defaults to True. - spglib_settings: an optional Dict object containing overrides for the symmetry tolerance parameters used by spglib (symmprec, angle_tolerance). - pymatgen_settings: an optional Dict object containing overrides for the symmetry @@ -89,7 +96,6 @@ def get_xspectra_structures(structure, **kwargs): # pylint: disable=too-many-st else: standardize_structure = True if 'absorbing_elements_list' in unwrapped_kwargs.keys(): - elements_defined = True abs_elements_list = unwrapped_kwargs['absorbing_elements_list'].get_list() # confirm that the elements requested are actually in the input structure for req_element in abs_elements_list: @@ -99,8 +105,11 @@ def get_xspectra_structures(structure, **kwargs): # pylint: disable=too-many-st f' {elements_present}' ) else: - elements_defined = False - abs_elements_list = [Kind.symbol for Kind in structure.kinds] + abs_elements_list = [] + for kind in structure.kinds: + if kind.symbol not in abs_elements_list: + abs_elements_list.append(kind.symbol) + if 'is_molecule_input' in unwrapped_kwargs.keys(): is_molecule_input = unwrapped_kwargs['is_molecule_input'].value # If we are working with a molecule, check for pymatgen_settings @@ -118,6 +127,21 @@ def get_xspectra_structures(structure, **kwargs): # pylint: disable=too-many-st else: spglib_kwargs = {} + if 'use_element_types' in unwrapped_kwargs.keys(): + use_element_types = unwrapped_kwargs['use_element_types'].value + else: + use_element_types = True + + if isinstance(structure, HubbardStructureData): + is_hubbard_structure = True + if standardize_structure: + raise ValidationError( + 'Incoming structure set to be standardized, but hubbard data has been found. ' + 'Please set ``standardize_structure`` to false in ``**kwargs`` to preserve the hubbard data.' + ) + else: + is_hubbard_structure = False + output_params = {} result = {} @@ -203,56 +227,100 @@ def get_xspectra_structures(structure, **kwargs): # pylint: disable=too-many-st # Process a periodic system else: incoming_structure_tuple = structure_to_spglib_tuple(structure) - - symmetry_dataset = spglib.get_symmetry_dataset(incoming_structure_tuple[0], **spglib_kwargs) + spglib_tuple = incoming_structure_tuple[0] + types_order = spglib_tuple[-1] + kinds_information = incoming_structure_tuple[1] + kinds_list = incoming_structure_tuple[2] + + # We need a way to reliably convert type number into element, so we + # first create a mapping of assigned number to kind name then a mapping + # of assigned number to ``Kind`` + + type_name_mapping = {str(value): key for key, value in kinds_information.items()} + type_mapping_dict = {} + + for key, value in type_name_mapping.items(): + for kind in kinds_list: + if value == kind.name: + type_mapping_dict[key] = kind + + # By default, `structure_to_spglib_tuple` gives different + # ``Kinds`` of the same element a distinct atomic number by + # multiplying the normal atomic number by 1000, then adding + # 100 for each distinct duplicate. if we want to treat all sites + # of the same element as equal, then we must therefore briefly + # operate on a "cleaned" version of the structure tuple where this + # new label is reduced to its normal element number. + if use_element_types: + cleaned_structure_tuple = (spglib_tuple[0], spglib_tuple[1], []) + for i in spglib_tuple[2]: + if i >= 1000: + new_i = int(np.trunc(i / 1000)) + else: + new_i = i + cleaned_structure_tuple[2].append(new_i) + symmetry_dataset = spglib.get_symmetry_dataset(cleaned_structure_tuple, **spglib_kwargs) + else: + symmetry_dataset = spglib.get_symmetry_dataset(spglib_tuple, **spglib_kwargs) # if there is no symmetry to exploit, or no standardization is desired, then we just use # the input structure in the following steps. This is done to account for the case where # the user has submitted an improper crystal for calculation work and doesn't want it to # be changed. if symmetry_dataset['number'] in [1, 2] or not standardize_structure: - standardized_structure_node = spglib_tuple_to_structure(incoming_structure_tuple[0]) + standardized_structure_node = spglib_tuple_to_structure(spglib_tuple, kinds_information, kinds_list) structure_is_standardized = False else: # otherwise, we proceed with the standardized structure. - standardized_structure_tuple = spglib.standardize_cell(incoming_structure_tuple[0], **spglib_kwargs) - standardized_structure_node = spglib_tuple_to_structure(standardized_structure_tuple) + standardized_structure_tuple = spglib.standardize_cell(spglib_tuple, **spglib_kwargs) + standardized_structure_node = spglib_tuple_to_structure( + standardized_structure_tuple, kinds_information, kinds_list + ) # if we are standardizing the structure, then we need to update the symmetry # information for the standardized structure symmetry_dataset = spglib.get_symmetry_dataset(standardized_structure_tuple, **spglib_kwargs) structure_is_standardized = True equivalent_atoms_array = symmetry_dataset['equivalent_atoms'] - element_types = symmetry_dataset['std_types'] + + if structure_is_standardized: + element_types = symmetry_dataset['std_types'] + else: # convert the 'std_types' from standardized to primitive cell + # we generate the type-specific data on-the-fly since we need to + # know which type (and thus kind) *should* be at each site + # even if we "cleaned" the structure previously + non_cleaned_dataset = spglib.get_symmetry_dataset(spglib_tuple, **spglib_kwargs) + spglib_std_types = non_cleaned_dataset['std_types'] + spglib_map_to_prim = non_cleaned_dataset['mapping_to_primitive'] + spglib_std_map_to_prim = non_cleaned_dataset['std_mapping_to_primitive'] + + map_std_pos_to_type = {} + for position, atom_type in zip(spglib_std_map_to_prim, spglib_std_types): + map_std_pos_to_type[str(position)] = atom_type + primitive_types = [] + for position in spglib_map_to_prim: + atom_type = map_std_pos_to_type[str(position)] + primitive_types.append(atom_type) + element_types = primitive_types equivalency_dict = {} - index_counter = 0 - for symmetry_value, element_type in zip(equivalent_atoms_array, element_types): - if elements_defined: # only process the elements given in the list - if f'site_{symmetry_value}' in equivalency_dict: - equivalency_dict[f'site_{symmetry_value}']['equivalent_sites_list'].append(index_counter) - elif elements[element_type]['symbol'] not in abs_elements_list: - pass - else: - equivalency_dict[f'site_{symmetry_value}'] = { - 'symbol': elements[element_type]['symbol'], - 'site_index': symmetry_value, - 'equivalent_sites_list': [symmetry_value] - } - else: # process everything in the system + for index, symmetry_types in enumerate(zip(equivalent_atoms_array, element_types)): + symmetry_value, element_type = symmetry_types + if type_mapping_dict[str(element_type)].symbol in abs_elements_list: if f'site_{symmetry_value}' in equivalency_dict: - equivalency_dict[f'site_{symmetry_value}']['equivalent_sites_list'].append(index_counter) + equivalency_dict[f'site_{symmetry_value}']['equivalent_sites_list'].append(index) else: equivalency_dict[f'site_{symmetry_value}'] = { - 'symbol': elements[element_type]['symbol'], + 'kind_name': type_mapping_dict[str(element_type)].name, + 'symbol': type_mapping_dict[str(element_type)].symbol, 'site_index': symmetry_value, 'equivalent_sites_list': [symmetry_value] } - index_counter += 1 for value in equivalency_dict.values(): value['multiplicity'] = len(value['equivalent_sites_list']) output_params['equivalent_sites_data'] = equivalency_dict + output_params['spacegroup_number'] = symmetry_dataset['number'] output_params['international_symbol'] = symmetry_dataset['international'] @@ -274,9 +342,42 @@ def get_xspectra_structures(structure, **kwargs): # pylint: disable=too-many-st ase_structure = standardized_structure_node.get_ase() ase_supercell = ase_structure * multiples - new_supercell = StructureData(ase=ase_supercell) - result['supercell'] = new_supercell + # if there are hubbard data to apply, we re-construct + # the supercell site-by-site to keep the correct ordering + if is_hubbard_structure: + blank_supercell = StructureData(ase=ase_supercell) + new_supercell = StructureData() + new_supercell.set_cell(blank_supercell.cell) + num_extensions = np.product(multiples) + supercell_types_order = [] + # For each supercell extension, loop over each site. + # This way, the original pattern-ordering of sites is + # preserved. + for i in range(0, num_extensions): # pylint: disable=unused-variable + for type_number in types_order: + supercell_types_order.append(type_number) + + for site, type_number in zip(blank_supercell.sites, supercell_types_order): + kind_present = type_mapping_dict[str(type_number)] + if kind_present.name not in [kind.name for kind in new_supercell.kinds]: + new_supercell.append_kind(kind_present) + new_site = Site(kind_name=kind_present.name, position=site.position) + new_supercell.append_site(new_site) + else: # otherwise, simply re-construct the supercell with ASE + new_supercell = StructureData(ase=ase_supercell) + + if is_hubbard_structure: # Scale up the hubbard parameters to match and return the HubbardStructureData + # we can exploit the fact that `get_hubbard_for_supercell` will return a HubbardStructureData node + # with the same hubbard parameters in the case where the input structure and the supercell are the + # same (i.e. multiples == [1, 1, 1]) + hubbard_manip = HubbardUtils(structure) + new_hubbard_supercell = hubbard_manip.get_hubbard_for_supercell(new_supercell) + new_supercell = new_hubbard_supercell + supercell_hubbard_params = new_supercell.hubbard + result['supercell'] = new_supercell + else: + result['supercell'] = new_supercell output_params['supercell_factors'] = multiples output_params['supercell_num_sites'] = len(new_supercell.sites) output_params['supercell_cell_matrix'] = new_supercell.cell @@ -285,22 +386,28 @@ def get_xspectra_structures(structure, **kwargs): # pylint: disable=too-many-st for value in equivalency_dict.values(): target_site = value['site_index'] marked_structure = StructureData() - supercell_kinds = {kind.name: kind for kind in new_supercell.kinds} marked_structure.set_cell(new_supercell.cell) + new_kind_names = [kind.name for kind in new_supercell.kinds] for index, site in enumerate(new_supercell.sites): + kind_at_position = new_supercell.kinds[new_kind_names.index(site.kind_name)] if index == target_site: - absorbing_kind = Kind(name=abs_atom_marker, symbols=site.kind_name) + absorbing_kind = Kind(name=abs_atom_marker, symbols=kind_at_position.symbol) absorbing_site = Site(kind_name=absorbing_kind.name, position=site.position) marked_structure.append_kind(absorbing_kind) marked_structure.append_site(absorbing_site) else: - if site.kind_name not in [kind.name for kind in marked_structure.kinds]: - marked_structure.append_kind(supercell_kinds[site.kind_name]) + if kind_at_position.name not in [kind.name for kind in marked_structure.kinds]: + marked_structure.append_kind(kind_at_position) new_site = Site(kind_name=site.kind_name, position=site.position) marked_structure.append_site(new_site) - result[f'site_{target_site}_{value["symbol"]}'] = marked_structure + if is_hubbard_structure: + marked_hubbard_structure = HubbardStructureData.from_structure(marked_structure) + marked_hubbard_structure.hubbard = supercell_hubbard_params + result[f'site_{target_site}_{value["symbol"]}'] = marked_hubbard_structure + else: + result[f'site_{target_site}_{value["symbol"]}'] = marked_structure output_params['is_molecule_input'] = is_molecule_input result['output_parameters'] = orm.Dict(dict=output_params) diff --git a/src/aiida_quantumespresso/workflows/xspectra/core.py b/src/aiida_quantumespresso/workflows/xspectra/core.py index b21e5d24a..377ea378c 100644 --- a/src/aiida_quantumespresso/workflows/xspectra/core.py +++ b/src/aiida_quantumespresso/workflows/xspectra/core.py @@ -10,7 +10,7 @@ from aiida.common import AttributeDict from aiida.engine import ToContext, WorkChain, append_, if_ from aiida.orm.nodes.data.base import to_aiida_type -from aiida.plugins import CalculationFactory, WorkflowFactory +from aiida.plugins import CalculationFactory, DataFactory, WorkflowFactory import yaml from aiida_quantumespresso.calculations.functions.xspectra.get_powder_spectrum import get_powder_spectrum @@ -21,6 +21,7 @@ PwCalculation = CalculationFactory('quantumespresso.pw') PwBaseWorkChain = WorkflowFactory('quantumespresso.pw.base') XspectraBaseWorkChain = WorkflowFactory('quantumespresso.xspectra.base') +HubbardStructureData = DataFactory('quantumespresso.hubbard_structure') class XspectraCoreWorkChain(ProtocolMixin, WorkChain): @@ -101,7 +102,7 @@ def define(cls, spec): spec.inputs.validator = cls.validate_inputs spec.input( 'structure', - valid_type=orm.StructureData, + valid_type=(orm.StructureData, HubbardStructureData), help=( 'Structure to be used for calculation, with at least one site containing the `abs_atom_marker` ' 'as the kind label.' @@ -348,8 +349,8 @@ def get_builder_from_protocol( ) pw_inputs['pw']['parameters'] = pw_params - pw_args = (pw_code, structure, protocol) + scf = PwBaseWorkChain.get_builder_from_protocol(*pw_args, overrides=pw_inputs, options=options, **kwargs) scf.pop('clean_workdir', None) @@ -368,12 +369,16 @@ def get_builder_from_protocol( abs_atom_marker = inputs['abs_atom_marker'] xs_prod_parameters['INPUT_XSPECTRA']['xiabs'] = kinds_present.index(abs_atom_marker) + 1 if core_hole_pseudos: + abs_element_kinds = [] for kind in structure.kinds: if kind.name == abs_atom_marker: abs_element = kind.symbol - + for kind in structure.kinds: # run a second pass to check for multiple kinds of the same absorbing element + if kind.symbol == abs_element and kind.name != abs_atom_marker: + abs_element_kinds.append(kind.name) builder.scf.pw.pseudos[abs_atom_marker] = core_hole_pseudos[abs_atom_marker] - builder.scf.pw.pseudos[abs_element] = core_hole_pseudos[abs_element] + for kind_name in abs_element_kinds: + builder.scf.pw.pseudos[kind_name] = core_hole_pseudos[abs_element] builder.xs_prod.xspectra.code = xs_code builder.xs_prod.xspectra.parameters = orm.Dict(xs_prod_parameters) diff --git a/src/aiida_quantumespresso/workflows/xspectra/crystal.py b/src/aiida_quantumespresso/workflows/xspectra/crystal.py index 789d6ba77..1d56a50f9 100644 --- a/src/aiida_quantumespresso/workflows/xspectra/crystal.py +++ b/src/aiida_quantumespresso/workflows/xspectra/crystal.py @@ -11,6 +11,7 @@ from aiida_pseudo.data.pseudo import UpfData as aiida_pseudo_upf from aiida_quantumespresso.calculations.functions.xspectra.get_spectra_by_element import get_spectra_by_element +from aiida_quantumespresso.utils.hubbard import HubbardStructureData from aiida_quantumespresso.utils.mapping import prepare_process_inputs from aiida_quantumespresso.workflows.protocols.utils import ProtocolMixin, recursive_merge @@ -481,6 +482,10 @@ def get_xspectra_structures(self): for key, node in optional_cell_prep.items(): inputs[key] = node + if isinstance(self.inputs.structure, HubbardStructureData): + # This must be False in the case of HubbardStructureData, otherwise get_xspectra_structures will except + inputs['standardize_structure'] = orm.Bool(False) + if 'spglib_settings' in self.inputs: inputs['spglib_settings'] = self.inputs.spglib_settings @@ -527,8 +532,16 @@ def run_upf2plotcore(self): shell_inputs['code'] = self.inputs.upf2plotcore_code shell_inputs['nodes'] = {'upf': upf} - shell_inputs['arguments'] = ['upf'] - shell_inputs['metadata'] = {'call_link_label': f'upf2plotcore_{element}'} + shell_inputs['metadata'] = { + 'call_link_label': f'upf2plotcore_{element}', + 'options' : { + 'filename_stdin' : upf.filename, + 'resources' : { + 'num_machines' : 1, + 'num_mpiprocs_per_machine' : 1 + } + } + } future_shelljob = self.submit(ShellJob, **shell_inputs) self.report(f'Launching upf2plotcore.sh for {element}<{future_shelljob.pk}>') @@ -553,7 +566,7 @@ def inspect_upf2plotcore(self): if num_core_states == 0: return self.exit_codes.ERROR_NO_GIPAW_INFO_FOUND - def run_all_xspectra_core(self): + def run_all_xspectra_core(self): # pylint: disable=too-many-statements """Call all XspectraCoreWorkChains required to compute all requested spectra.""" structures_to_process = self.ctx.structures_to_process @@ -566,6 +579,7 @@ def run_all_xspectra_core(self): structure = structures_to_process[site] inputs.structure = structure abs_element = equivalent_sites_data[site]['symbol'] + abs_atom_kind = equivalent_sites_data[site]['kind_name'] if 'core_hole_treatments' in self.inputs: ch_treatments = self.inputs.core_hole_treatments.get_dict() @@ -585,7 +599,7 @@ def run_all_xspectra_core(self): scf_inputs = inputs.scf.pw scf_params = scf_inputs.parameters.get_dict() ch_inputs = XspectraCoreWorkChain.get_treatment_inputs(treatment=ch_treatment) - new_scf_params = recursive_merge(left=scf_params, right=ch_inputs) + new_scf_params = recursive_merge(left=ch_inputs, right=scf_params) # Set the absorbing species index (`xiabs`) for the xspectra.x input. new_xs_params = inputs.xs_prod.xspectra.parameters.get_dict() @@ -593,21 +607,45 @@ def run_all_xspectra_core(self): abs_species_index = kinds_present.index(abs_atom_marker) + 1 new_xs_params['INPUT_XSPECTRA']['xiabs'] = abs_species_index - # Set `starting_magnetization` if we are using an XCH approximation, using the - # absorbing species as a reasonable place for the unpaired electron. - # As a future note, we need to re-visit the core-hole treatment settings, in order - # to avoid the need for fudges like these. - if ch_treatment == 'xch_smear': - new_scf_params['SYSTEM'][f'starting_magnetization({abs_species_index})'] = 1 + # Set `starting_magnetization` if we are using an XCH approximation, using + # the absorbing species as a reasonable place for the unpaired electron. + # Alternatively, ensure the starting magnetic moment is a reasonable guess + # given the input parameters. (e.g. it conforms to an existing magnetic + # structure already defined for the system) + + # TODO: we need to re-visit the core-hole treatment settings, + # in order to avoid the need for fudges like these and set these at + # submission rather than inside the WorkChain itself. + if 'starting_magnetization' in new_scf_params['SYSTEM']: + inherited_mag = new_scf_params['SYSTEM']['starting_magnetization'][abs_atom_kind] + if ch_treatment not in ['xch_smear', 'xch_fixed']: + new_scf_params['SYSTEM']['starting_magnetization'][abs_atom_marker] = inherited_mag + else: # if there is meant to be an unpaired electron, give it to the absorbing atom. + if inherited_mag == 0: # set it to 1, if it would be neutral in the ground-state. + new_scf_params['SYSTEM']['starting_magnetization'][abs_atom_marker] = 1 + else: # assume that it takes the same magnetic configuration as the kind that it replaces. + new_scf_params['SYSTEM']['starting_magnetization'][abs_atom_marker] = inherited_mag + elif ch_treatment in ['xch_smear', 'xch_fixed']: + new_scf_params['SYSTEM']['starting_magnetization'] = {abs_atom_marker : 1} + + # remove any duplicates created from the "core_hole_treatments.yaml" defaults + for key in new_scf_params['SYSTEM'].keys(): + if 'starting_magnetization(' in key: + new_scf_params['SYSTEM'].pop(key, None) core_hole_pseudo = self.inputs.core_hole_pseudos[abs_element] + gipaw_pseudo = self.inputs.gipaw_pseudos[abs_element] inputs.scf.pw.pseudos[abs_atom_marker] = core_hole_pseudo - # In the case where the absorbing atom is the only one of its element in the - # structure, we avoid setting the GIPAW pseudo for it and remove the one . - if abs_element in kinds_present: - gipaw_pseudo = self.inputs.gipaw_pseudos[abs_element] - inputs.scf.pw.pseudos[abs_element] = gipaw_pseudo - else: + # Check how many instances of the absorbing element are present and assign + # each the GIPAW pseudo if they are not the absorbing atom itself. + abs_element_kinds = [] + for kind in structure.kinds: + if kind.symbol == abs_element and kind.name != abs_atom_marker: + abs_element_kinds.append(kind.name) + if len(abs_element_kinds) > 0: + for kind_name in abs_element_kinds: + scf_inputs['pseudos'][kind_name] = gipaw_pseudo + else: # if there is only one atom of the absorbing element, pop the GIPAW pseudo to avoid a crash scf_inputs['pseudos'].pop(abs_element, None) scf_inputs.parameters = orm.Dict(new_scf_params) @@ -620,7 +658,7 @@ def run_all_xspectra_core(self): xspectra_core_workchains[site] = future self.report(f'launched XspectraCoreWorkChain for {site}<{future.pk}>') - return ToContext(**xspectra_core_workchains) + return ToContext(**xspectra_core_workchains) # pylint: enable=too-many-statements def inspect_all_xspectra_core(self): """Check that all the XspectraCoreWorkChain sub-processes finished sucessfully.""" diff --git a/tests/workflows/functions/test_get_xspectra_structures.py b/tests/workflows/functions/test_get_xspectra_structures.py new file mode 100644 index 000000000..e192bbc44 --- /dev/null +++ b/tests/workflows/functions/test_get_xspectra_structures.py @@ -0,0 +1,100 @@ +# -*- coding: utf-8 -*- +"""Tests for the `get_marked_structure` class.""" +from aiida.orm import Bool, Dict +import pytest + +from aiida_quantumespresso.utils.hubbard import HubbardStructureData, HubbardUtils +from aiida_quantumespresso.workflows.functions.get_xspectra_structures import get_xspectra_structures + + +@pytest.fixture +def generate_hubbard(): + """Return a `Hubbard` instance.""" + + def _generate_hubbard(): + from aiida_quantumespresso.common.hubbard import Hubbard + return Hubbard.from_list([(0, '1s', 0, '1s', 5.0, (0, 0, 0), 'Ueff')]) + + return _generate_hubbard + + +@pytest.fixture +def generate_hubbard_structure(generate_structure): + """Return a `HubbardStructureData` instance.""" + + def _generate_hubbard_structure(): + from aiida_quantumespresso.common.hubbard import Hubbard + structure = generate_structure('silicon-kinds') + hp_list = [(0, '1s', 0, '1s', 5.0, (0, 0, 0), 'Ueff')] + hubbard = Hubbard.from_list(hp_list) + return HubbardStructureData.from_structure(structure=structure, hubbard=hubbard) + + return _generate_hubbard_structure + + +def test_base(generate_structure): + """Test the basic operation of get_xspectra_structures.""" + + c_si = generate_structure('silicon') + spglib_options = Dict({'symprec': 1.0e-3}) + inputs = {'structure': c_si, 'spglib_options': spglib_options} + result = get_xspectra_structures(**inputs) + assert len(result) == 4 + out_params = result['output_parameters'].get_dict() + assert out_params['spacegroup_number'] == 227 + assert out_params['supercell_num_sites'] == 64 + assert len(out_params['equivalent_sites_data']) == 1 + + +def test_use_element_types(generate_structure): + """Test the CF's `use_element_types` flag.""" + + c_si = generate_structure('silicon') + c_si_kinds = generate_structure('silicon-kinds') + spglib_options = Dict({'symprec': 1.0e-3}) + inputs_bare = {'structure': c_si, 'spglib_options': spglib_options} + inputs_kinds = { + 'structure': c_si_kinds, + 'use_element_types': Bool(False), + 'spglib_options': spglib_options, + 'standardize_structure': Bool(False) + } + + result_bare = get_xspectra_structures(**inputs_bare) + result_kinds = get_xspectra_structures(**inputs_kinds) + + inputs_element_types = { + 'structure': c_si_kinds, + 'spglib_options': spglib_options, + 'standardize_structure': Bool(False), + } + result_element_types = get_xspectra_structures(**inputs_element_types) + + assert 'site_1_Si' in result_kinds + assert 'site_1_Si' not in result_element_types + assert 'site_1_Si' not in result_bare + + +def test_hubbard(generate_structure): + """Test that the CalcFunction will pass Hubbard parameters to the output structures. + + The intent here is to confirm that simply using the `initialize_` methods to get + hubbard parameters will propogate to the resulting supercells and (crucially) + generate the correct hubbard card. + """ + + c_si = generate_structure() + c_si_hub = HubbardStructureData.from_structure(c_si) + c_si_hub.initialize_onsites_hubbard('Si', '1s', 0.0, 'Ueff', False) + + inputs = {'structure': c_si_hub, 'standardize_structure': Bool(False)} + result = get_xspectra_structures(**inputs) + + marked = result['site_0_Si'] + utils_marked = HubbardUtils(marked) + hub_card_lines = [i.strip() for i in utils_marked.get_hubbard_card().splitlines()] + out_params = result['output_parameters'].get_dict() + + assert out_params['supercell_num_sites'] == 54 + assert 'U\tX-1s\t0.0' in hub_card_lines + assert 'U\tSi-1s\t0.0' in hub_card_lines