From 78b3a407e5a70c01dde939bf9ee799c3eb5462f9 Mon Sep 17 00:00:00 2001 From: Toon Verstraelen Date: Fri, 21 Jun 2024 07:01:57 +0200 Subject: [PATCH 1/3] Clean up exceptions and split FileFormatError in LoadError and DumpError --- iodata/api.py | 91 ++++++++++++++++++++++++++----- iodata/formats/fchk.py | 6 +-- iodata/formats/gamess.py | 2 +- iodata/formats/gaussianinput.py | 2 +- iodata/formats/gromacs.py | 7 +-- iodata/formats/json.py | 96 ++++++++++++++++----------------- iodata/formats/mol2.py | 6 +-- iodata/formats/molden.py | 6 +-- iodata/formats/molekel.py | 6 +-- iodata/formats/mwfn.py | 18 +++---- iodata/formats/pdb.py | 8 +-- iodata/test/common.py | 6 +-- iodata/test/test_api.py | 4 +- iodata/test/test_cp2klog.py | 11 ++-- iodata/test/test_inputs.py | 4 +- iodata/test/test_json.py | 32 +++++------ iodata/test/test_mol2.py | 4 +- iodata/test/test_molden.py | 30 +++++------ iodata/test/test_pdb.py | 10 ++-- iodata/test/test_sdf.py | 6 +-- iodata/test/test_wfx.py | 8 +-- iodata/utils.py | 37 +++++++++---- 22 files changed, 239 insertions(+), 161 deletions(-) diff --git a/iodata/api.py b/iodata/api.py index 29529159..ad4bf840 100644 --- a/iodata/api.py +++ b/iodata/api.py @@ -27,7 +27,14 @@ from typing import Callable, Optional from .iodata import IOData -from .utils import FileFormatError, LineIterator, PrepareDumpError +from .utils import ( + DumpError, + FileFormatError, + LineIterator, + LoadError, + PrepareDumpError, + WriteInputError, +) __all__ = ["load_one", "load_many", "dump_one", "dump_many", "write_input"] @@ -54,7 +61,7 @@ def _select_format_module(filename: str, attrname: str, fmt: Optional[str] = Non filename The file to load or dump. attrname - The required atrtibute of the file format module. + The required attribute of the file format module. fmt The name of the file format module to use. When not given, it is guessed from the filename. 
@@ -63,6 +70,10 @@ def _select_format_module(filename: str, attrname: str, fmt: Optional[str] = Non ------- The module implementing the required file format. + Raises + ------ + FileFormatError + When no file format module can be found that has a member named ``attrname``. """ basename = os.path.basename(filename) if fmt is None: @@ -73,7 +84,7 @@ def _select_format_module(filename: str, attrname: str, fmt: Optional[str] = Non return format_module else: return FORMAT_MODULES[fmt] - raise ValueError(f"Could not find file format with feature {attrname} for file {filename}") + raise FileFormatError(f"Could not find file format with feature {attrname} for file {filename}") def _find_input_modules(): @@ -102,12 +113,17 @@ def _select_input_module(fmt: str) -> ModuleType: ------- The module implementing the required input format. + + Raises + ------ + FileFormatError + When the format ``fmt`` does not exist. """ if fmt in INPUT_MODULES: if not hasattr(INPUT_MODULES[fmt], "write_input"): - raise ValueError(f"{fmt} input module does not have write_input!") + raise FileFormatError(f"{fmt} input module does not have write_input.") return INPUT_MODULES[fmt] - raise ValueError(f"Could not find input format {fmt}!") + raise FileFormatError(f"Could not find input format {fmt}.") def load_one(filename: str, fmt: Optional[str] = None, **kwargs) -> IOData: @@ -136,8 +152,12 @@ def load_one(filename: str, fmt: Optional[str] = None, **kwargs) -> IOData: with LineIterator(filename) as lit: try: iodata = IOData(**format_module.load_one(lit, **kwargs)) + except LoadError: + raise except StopIteration: lit.error("File ended before all data was read.") + except Exception as exc: + raise LoadError(f"{filename}: Uncaught exception while loading file.") from exc return iodata @@ -171,6 +191,10 @@ def load_many(filename: str, fmt: Optional[str] = None, **kwargs) -> Iterator[IO yield IOData(**data) except StopIteration: return + except LoadError: + raise + except Exception as exc: + raise 
LoadError(f"{filename}: Uncaught exception while loading file.") from exc def _check_required(iodata: IOData, dump_func: Callable): @@ -216,17 +240,33 @@ def dump_one(iodata: IOData, filename: str, fmt: Optional[str] = None, **kwargs) Raises ------ + DumpError + When an error is encountered while dumping to a file. + If the output file already existed, it is (partially) overwritten. PrepareDumpError When the iodata object is not compatible with the file format, e.g. due to missing attributes, and not conversion is available or allowed to make it compatible. + If the output file already existed, it is not overwritten. """ format_module = _select_format_module(filename, "dump_one", fmt) - _check_required(iodata, format_module.dump_one) - if hasattr(format_module, "prepare_dump"): - format_module.prepare_dump(iodata) + try: + _check_required(iodata, format_module.dump_one) + if hasattr(format_module, "prepare_dump"): + format_module.prepare_dump(iodata) + except PrepareDumpError: + raise + except Exception as exc: + raise PrepareDumpError( + f"{filename}: Uncaught exception while preparing for dumping to a file" + ) from exc with open(filename, "w") as f: - format_module.dump_one(f, iodata, **kwargs) + try: + format_module.dump_one(f, iodata, **kwargs) + except DumpError: + raise + except Exception as exc: + raise DumpError(f"{filename}: Uncaught exception while dumping to a file") from exc def dump_many(iodatas: Iterable[IOData], filename: str, fmt: Optional[str] = None, **kwargs): @@ -249,10 +289,16 @@ def dump_many(iodatas: Iterable[IOData], filename: str, fmt: Optional[str] = Non Raises ------ + DumpError + When an error is encountered while dumping to a file. + If the output file already existed, it is (partially) overwritten. PrepareDumpError When the iodata object is not compatible with the file format, e.g. due to missing attributes, and not conversion is available or allowed to make it compatible. 
+ If the output file already existed, it is not overwritten when this error + is raised while processing the first IOData instance in the ``iodatas`` argument. + When the exception is raised in later iterations, any existing file is overwritten. """ format_module = _select_format_module(filename, "dump_many", fmt) @@ -262,9 +308,18 @@ def dump_many(iodatas: Iterable[IOData], filename: str, fmt: Optional[str] = Non iter_iodatas = iter(iodatas) try: first = next(iter_iodatas) - _check_required(first, format_module.dump_many) except StopIteration as exc: - raise FileFormatError("dump_many needs at least one iodata object.") from exc + raise DumpError(f"{filename}: dump_many needs at least one iodata object.") from exc + try: + _check_required(first, format_module.dump_many) + if hasattr(format_module, "prepare_dump"): + format_module.prepare_dump(first) + except PrepareDumpError: + raise + except Exception as exc: + raise PrepareDumpError( + f"{filename}: Uncaught exception while preparing for dumping to a file" + ) from exc def checking_iterator(): """Iterate over all iodata items, not checking the first.""" @@ -277,7 +332,12 @@ def checking_iterator(): yield other with open(filename, "w") as f: - format_module.dump_many(f, checking_iterator(), **kwargs) + try: + format_module.dump_many(f, checking_iterator(), **kwargs) + except (PrepareDumpError, DumpError): + raise + except Exception as exc: + raise DumpError(f"{filename}: Uncaught exception while dumping to a file") from exc def write_input( @@ -312,4 +372,9 @@ def write_input( """ input_module = _select_input_module(fmt) with open(filename, "w") as fh: - input_module.write_input(fh, iodata, template, atom_line, **kwargs) + try: + input_module.write_input(fh, iodata, template, atom_line, **kwargs) + except Exception as exc: + raise WriteInputError( + f"{filename}: Uncaught exception while writing an input file" + ) from exc diff --git a/iodata/formats/fchk.py b/iodata/formats/fchk.py index 96baa53d..63c70338 
100644 --- a/iodata/formats/fchk.py +++ b/iodata/formats/fchk.py @@ -29,7 +29,7 @@ from ..docstrings import document_dump_one, document_load_many, document_load_one from ..iodata import IOData from ..orbitals import MolecularOrbitals -from ..utils import LineIterator, PrepareDumpError, amu +from ..utils import DumpError, LineIterator, PrepareDumpError, amu __all__ = [] @@ -221,7 +221,7 @@ def load_one(lit: LineIterator) -> dict: if nalpha < 0 or nbeta < 0 or nalpha + nbeta <= 0: lit.error("The number of electrons is not positive.") if nalpha < nbeta: - raise ValueError(f"n_alpha={nalpha} < n_beta={nbeta} is not valid!") + lit.error(f"n_alpha={nalpha} < n_beta={nbeta} is not valid!") norba = fchk["Alpha Orbital Energies"].shape[0] mo_coeffs = np.copy(fchk["Alpha MO coefficients"].reshape(norba, nbasis).T) @@ -643,7 +643,7 @@ def dump_one(f: TextIO, data: IOData): elif shell.ncon == 2 and shell.angmoms == [0, 1]: shell_types.append(-1) else: - raise ValueError("Cannot identify type of shell!") + raise DumpError("Cannot identify type of shell!") num_pure_d_shells = sum([1 for st in shell_types if st == 2]) num_pure_f_shells = sum([1 for st in shell_types if st == 3]) diff --git a/iodata/formats/gamess.py b/iodata/formats/gamess.py index d87b0d41..820da5e7 100644 --- a/iodata/formats/gamess.py +++ b/iodata/formats/gamess.py @@ -37,7 +37,7 @@ def _read_data(lit: LineIterator) -> tuple[str, str, list[str]]: # The dat file only contains symmetry-unique atoms, so we would be incapable of # supporting non-C1 symmetry without significant additional coding. if symmetry != "C1": - raise NotImplementedError(f"Only C1 symmetry is supported. Got {symmetry}") + lit.error(f"Only C1 symmetry is supported. 
Got {symmetry}") symbols = [] line = True while line != " $END \n": diff --git a/iodata/formats/gaussianinput.py b/iodata/formats/gaussianinput.py index 29cc7a26..e2200e6a 100644 --- a/iodata/formats/gaussianinput.py +++ b/iodata/formats/gaussianinput.py @@ -68,7 +68,7 @@ def load_one(lit: LineIterator): if not contents: break if len(contents) != 4: - raise ValueError("No Cartesian Structure is detected") + lit.error("No Cartesian Structure is detected") numbers.append(sym2num[contents[0]]) coor = list(map(float, contents[1:])) coordinates.append(coor) diff --git a/iodata/formats/gromacs.py b/iodata/formats/gromacs.py index f41905a2..d7080bf4 100644 --- a/iodata/formats/gromacs.py +++ b/iodata/formats/gromacs.py @@ -42,10 +42,7 @@ def load_one(lit: LineIterator) -> dict: """Do not edit this docstring. It will be overwritten.""" while True: - try: - data = _helper_read_frame(lit) - except StopIteration: - break + data = _helper_read_frame(lit) title = data[0] time = data[1] resnums = np.array(data[2]) @@ -75,7 +72,7 @@ def load_many(lit: LineIterator) -> Iterator[dict]: try: while True: yield load_one(lit) - except OSError: + except StopIteration: return diff --git a/iodata/formats/json.py b/iodata/formats/json.py index 55539bb9..c3b03447 100644 --- a/iodata/formats/json.py +++ b/iodata/formats/json.py @@ -571,7 +571,7 @@ from ..docstrings import document_dump_one, document_load_one from ..iodata import IOData from ..periodic import num2sym, sym2num -from ..utils import FileFormatError, FileFormatWarning, LineIterator, PrepareDumpError +from ..utils import DumpError, LineIterator, LoadError, LoadWarning, PrepareDumpError __all__ = [] @@ -642,7 +642,7 @@ def _parse_json(json_in: dict, lit: LineIterator) -> dict: warn( f"{lit.filename}: QCSchema files should have a `schema_name` key." 
"Attempting to determine schema type...", - FileFormatWarning, + LoadWarning, stacklevel=2, ) # Geometry is required in any molecule schema @@ -650,7 +650,7 @@ def _parse_json(json_in: dict, lit: LineIterator) -> dict: schema_name = "qcschema_molecule" # Check if BSE file, which is too different elif "molssi_bse_schema" in result: - raise FileFormatError( + raise LoadError( f"{lit.filename}: IOData does not currently support MolSSI BSE Basis JSON." ) # Center_data is required in any basis schema @@ -659,12 +659,12 @@ def _parse_json(json_in: dict, lit: LineIterator) -> dict: elif "driver" in result: schema_name = "qcschema_output" if "return_result" in result else "qcschema_input" else: - raise FileFormatError(f"{lit.filename}: Could not determine `schema_name`.") + raise LoadError(f"{lit.filename}: Could not determine `schema_name`.") if "schema_version" not in result: warn( f"{lit.filename}: QCSchema files should have a `schema_version` key." "Attempting to load without version number.", - FileFormatWarning, + LoadWarning, stacklevel=2, ) @@ -676,9 +676,9 @@ def _parse_json(json_in: dict, lit: LineIterator) -> dict: return _load_qcschema_input(result, lit) if schema_name == "qcschema_output": return _load_qcschema_output(result, lit) - raise FileFormatError( - "{}: Invalid QCSchema type {}, should be one of `qcschema_molecule`, `qcschema_basis`," - "`qcschema_input`, or `qcschema_output".format(lit.filename, result["schema_name"]) + raise LoadError( + f"{lit.filename}: Invalid QCSchema type {result['schema_name']}, should be one of " + "`qcschema_molecule`, `qcschema_basis`, `qcschema_input`, or `qcschema_output`." 
) @@ -754,12 +754,12 @@ def _parse_topology_keys(mol: dict, lit: LineIterator) -> dict: if key not in mol: warn( f"{lit.filename}: QCSchema files should have a '{key}' key.", - FileFormatWarning, + LoadWarning, stacklevel=2, ) for key in topology_keys: if key not in mol: - raise FileFormatError(f"{lit.filename}: QCSchema topology requires '{key}' key") + raise LoadError(f"{lit.filename}: QCSchema topology requires '{key}' key") topology_dict = {} extra_dict = {} @@ -780,7 +780,7 @@ def _parse_topology_keys(mol: dict, lit: LineIterator) -> dict: "{}: Missing 'molecular_charge' key." "Some QCSchema writers omit this key for default value 0.0," "Ensure this value is correct.", - FileFormatWarning, + LoadWarning, stacklevel=2, ) formal_charge = 0.0 @@ -795,7 +795,7 @@ def _parse_topology_keys(mol: dict, lit: LineIterator) -> dict: "{}: Missing 'molecular_multiplicity' key." "Some QCSchema writers omit this key for default value 1," "Ensure this value is correct.", - FileFormatWarning, + LoadWarning, stacklevel=2, ) topology_dict["spinpol"] = 0 @@ -818,7 +818,7 @@ def _parse_topology_keys(mol: dict, lit: LineIterator) -> dict: warn( "{}: Both `masses` and `mass_numbers` given. 
" "Both values will be written to `extra` dict.", - FileFormatWarning, + LoadWarning, stacklevel=2, ) extra_dict["mass_numbers"] = np.array(mol["mass_numbers"]) @@ -931,7 +931,7 @@ def _version_check(result: dict, max_version: float, schema_name: str, lit: Line warn( f"{lit.filename}: Unknown {schema_name} version {version}, " "loading may produce invalid results", - FileFormatWarning, + LoadWarning, stacklevel=2, ) return version @@ -1038,7 +1038,7 @@ def _load_qcschema_input(result: dict, lit: LineIterator) -> dict: extra_dict["input"] = input_dict["extra"] if "molecule" not in result: - raise FileFormatError(f"{lit.filename}: QCSchema Input requires 'molecule' key") + raise LoadError(f"{lit.filename}: QCSchema Input requires 'molecule' key") molecule_dict = _parse_topology_keys(result["molecule"], lit) input_dict.update(molecule_dict) extra_dict["molecule"] = molecule_dict["extra"] @@ -1073,14 +1073,12 @@ def _parse_input_keys(result: dict, lit: LineIterator) -> dict: if key not in result: warn( f"{lit.filename}: QCSchema files should have a '{key}' key.", - FileFormatWarning, + LoadWarning, stacklevel=2, ) for key in input_keys: if key not in result: - raise FileFormatError( - f"{lit.filename}: QCSchema `qcschema_input` file requires '{key}' key" - ) + raise LoadError(f"{lit.filename}: QCSchema `qcschema_input` file requires '{key}' key") # Store all extra keys in extra_dict and gather at end input_dict = {} extra_dict = {} @@ -1164,7 +1162,7 @@ def _parse_driver(driver: str, lit: LineIterator) -> str: Raises ------ - FileFormatError + LoadError If driver is not one of {"energy", "gradient", "hessian", "properties"}. 
Notes @@ -1174,7 +1172,7 @@ def _parse_driver(driver: str, lit: LineIterator) -> str: """ if driver not in ["energy", "gradient", "hessian", "properties"]: - raise FileFormatError( + raise LoadError( f"{lit.filename}: QCSchema driver must be one of `energy`, `gradient`, `hessian`, " "or `properties`" ) @@ -1202,7 +1200,7 @@ def _parse_model(model: dict, lit: LineIterator) -> dict: extra_dict = {} if "method" not in model: - raise FileFormatError(f"{lit.filename}: QCSchema `model` requires a `method`") + raise LoadError(f"{lit.filename}: QCSchema `model` requires a `method`") model_dict["lot"] = model["method"] # QCEngineRecords doesn't give an empty string for basis-free methods, omits req'd key instead if "basis" not in model: @@ -1215,7 +1213,7 @@ def _parse_model(model: dict, lit: LineIterator) -> dict: warn( f"{lit.filename}: QCSchema `basis` could not be read and will be omitted." "Unless model is for a basis-free method, check input file.", - FileFormatWarning, + LoadWarning, stacklevel=2, ) else: @@ -1249,7 +1247,7 @@ def _parse_protocols(protocols: dict, lit: LineIterator) -> dict: if "wavefunction" not in protocols: warn( "{}: Protocols `wavefunction` key not specified, no properties will be kept.", - FileFormatWarning, + LoadWarning, stacklevel=2, ) wavefunction = "none" @@ -1258,7 +1256,7 @@ def _parse_protocols(protocols: dict, lit: LineIterator) -> dict: if "stdout" not in protocols: warn( "{}: Protocols `stdout` key not specified, stdout will be kept.", - FileFormatWarning, + LoadWarning, stacklevel=2, ) keep_stdout = True @@ -1266,10 +1264,10 @@ def _parse_protocols(protocols: dict, lit: LineIterator) -> dict: keep_stdout = protocols["stdout"] protocols_dict = {} if wavefunction not in {"all", "orbitals_and_eigenvalues", "return_results", "none"}: - raise FileFormatError(f"{lit.filename}: Invalid `protocols` `wavefunction` keyword.") + raise LoadError(f"{lit.filename}: Invalid `protocols` `wavefunction` keyword.") protocols_dict["keep_wavefunction"] 
= wavefunction if not isinstance(keep_stdout, bool): - raise FileFormatError("{}: `protocols` `stdout` option must be a boolean.") + raise LoadError("{}: `protocols` `stdout` option must be a boolean.") protocols_dict["keep_stdout"] = keep_stdout return protocols_dict @@ -1298,7 +1296,7 @@ def _load_qcschema_output(result: dict, lit: LineIterator) -> dict: extra_dict["output"] = output_dict["extra"] if "molecule" not in result: - raise FileFormatError(f"{lit.filename}: QCSchema Input requires 'molecule' key") + raise LoadError(f"{lit.filename}: QCSchema Input requires 'molecule' key") molecule_dict = _parse_topology_keys(result["molecule"], lit) output_dict.update(molecule_dict) extra_dict["molecule"] = molecule_dict["extra"] @@ -1335,14 +1333,12 @@ def _parse_output_keys(result: dict, lit: LineIterator) -> dict: if key not in result: warn( f"{lit.filename}: QCSchema files should have a '{key}' key.", - FileFormatWarning, + LoadWarning, stacklevel=2, ) for key in output_keys: if key not in result: - raise FileFormatError( - f"{lit.filename}: QCSchema `qcschema_output` file requires '{key}' key" - ) + raise LoadError(f"{lit.filename}: QCSchema `qcschema_output` file requires '{key}' key") # Store all extra keys in extra_dict and gather at end output_dict = {} @@ -1417,7 +1413,7 @@ def _parse_provenance( """ if isinstance(provenance, dict): if "creator" not in provenance: - raise FileFormatError(f"{lit.filename}: `{source}` provenance requires `creator` key") + raise LoadError(f"{lit.filename}: `{source}` provenance requires `creator` key") if append: base_provenance = [provenance] else: @@ -1425,10 +1421,10 @@ def _parse_provenance( elif isinstance(provenance, list): for prov in provenance: if "creator" not in prov: - raise FileFormatError("{}: `{}` provenance requires `creator` key") + raise LoadError("{}: `{}` provenance requires `creator` key") base_provenance = provenance else: - raise FileFormatError(f"{lit.filename}: Invalid `{source}` provenance type") + 
raise LoadError(f"{lit.filename}: Invalid `{source}` provenance type") if append: base_provenance.append( {"creator": "IOData", "version": __version__, "routine": "iodata.formats.json.load_one"} @@ -1471,7 +1467,7 @@ def dump_one(f: TextIO, data: IOData): elif schema_name == "qcschema_output": return_dict = _dump_qcschema_output(data) else: - raise FileFormatError( + raise DumpError( "'schema_name' must be one of 'qcschema_molecule', 'qcschema_basis'" "'qcschema_input' or 'qcschema_output'." ) @@ -1496,7 +1492,7 @@ def _dump_qcschema_molecule(data: IOData) -> dict: # Gather required field data if data.atnums is None or data.atcoords is None: - raise FileFormatError("qcschema_molecule requires `atnums` and `atcoords` fields.") + raise DumpError("qcschema_molecule requires `atnums` and `atcoords` fields.") molecule_dict["symbols"] = [num2sym[num] for num in data.atnums] molecule_dict["geometry"] = list(data.atcoords.flatten()) @@ -1505,7 +1501,7 @@ def _dump_qcschema_molecule(data: IOData) -> dict: warn( "`charge` and `spinpol` should be given to write qcschema_molecule file:" "QCSchema defaults to charge = 0 and multiplicity = 1 if no values given.", - FileFormatWarning, + LoadWarning, stacklevel=2, ) if data.charge is not None: @@ -1599,7 +1595,7 @@ def _dump_provenance(data: IOData, source: str) -> Union[list[dict], dict]: if isinstance(provenance, list): provenance.append(new_provenance) return provenance - raise FileFormatError("QCSchema provenance must be either a dict or list of dicts.") + raise DumpError("QCSchema provenance must be either a dict or list of dicts.") return new_provenance @@ -1626,17 +1622,17 @@ def _dump_qcschema_input(data: IOData) -> dict: # Gather required field data input_dict["molecule"] = _dump_qcschema_molecule(data) if "driver" not in data.extra["input"]: - raise FileFormatError("qcschema_input requires `driver` field in extra['input'].") + raise DumpError("qcschema_input requires `driver` field in extra['input'].") if 
data.extra["input"]["driver"] not in {"energy", "gradient", "hessian", "properties"}: - raise FileFormatError( + raise DumpError( "QCSchema driver must be one of `energy`, `gradient`, `hessian`, or `properties`" ) input_dict["driver"] = data.extra["input"]["driver"] if "model" not in data.extra["input"]: - raise FileFormatError("qcschema_input requires `model` field in extra['input'].") + raise DumpError("qcschema_input requires `model` field in extra['input'].") input_dict["model"] = {} if data.lot is None: - raise FileFormatError("qcschema_input requires specifed `lot`.") + raise DumpError("qcschema_input requires specifed `lot`.") input_dict["model"]["method"] = data.lot if data.obasis_name is None and "basis" not in data.extra["input"]["model"]: input_dict["model"]["basis"] = "" @@ -1686,37 +1682,37 @@ def _dump_qcschema_output(data: IOData) -> dict: # Gather required field data output_dict["molecule"] = _dump_qcschema_molecule(data) if "driver" not in data.extra["input"]: - raise FileFormatError("qcschema_output requires `driver` field in extra['input'].") + raise DumpError("qcschema_output requires `driver` field in extra['input'].") if data.extra["input"]["driver"] not in {"energy", "gradient", "hessian", "properties"}: - raise FileFormatError( + raise DumpError( "QCSchema driver must be one of `energy`, `gradient`, `hessian`, or `properties`" ) output_dict["driver"] = data.extra["input"]["driver"] if "model" not in data.extra["input"]: - raise FileFormatError("qcschema_output requires `model` field in extra['input'].") + raise DumpError("qcschema_output requires `model` field in extra['input'].") output_dict["model"] = {} if data.lot is None: - raise FileFormatError("qcschema_output requires specifed `lot`.") + raise DumpError("qcschema_output requires specifed `lot`.") output_dict["model"]["method"] = data.lot if data.obasis_name is None and "basis" not in data.extra["input"]["model"]: warn( "No basis name given. 
QCSchema assumes this signifies a basis-free method; to" "avoid this warning, specify `obasis_name` as an empty string.", - FileFormatWarning, + LoadWarning, stacklevel=2, ) if "basis" in data.extra["input"]["model"]: raise NotImplementedError("qcschema_basis is not yet supported in IOData.") output_dict["model"]["basis"] = data.obasis_name if "properties" not in data.extra["output"]: - raise FileFormatError("qcschema_output requires `properties` field in extra['output'].") + raise DumpError("qcschema_output requires `properties` field in extra['output'].") output_dict["properties"] = data.extra["output"]["properties"] if data.energy is not None: output_dict["properties"]["return_energy"] = data.energy if output_dict["driver"] == "energy": output_dict["return_result"] = data.energy if "return_result" not in output_dict and "return_result" not in data.extra["output"]: - raise FileFormatError("qcschema_output requires `return_result` field in extra['output'].") + raise DumpError("qcschema_output requires `return_result` field in extra['output'].") if "return_result" in data.extra["output"]: output_dict["return_result"] = data.extra["output"]["return_result"] if "keywords" in data.extra["input"]: diff --git a/iodata/formats/mol2.py b/iodata/formats/mol2.py index 6e2b7317..814e7a85 100644 --- a/iodata/formats/mol2.py +++ b/iodata/formats/mol2.py @@ -36,7 +36,7 @@ ) from ..iodata import IOData from ..periodic import bond2num, num2bond, num2sym, sym2num -from ..utils import LineIterator, angstrom +from ..utils import LineIterator, LoadError, angstrom __all__ = [] @@ -80,7 +80,7 @@ def load_one(lit: LineIterator) -> dict: bonds = _load_helper_bonds(lit, nbonds) result["bonds"] = bonds if not molecule_found: - raise lit.error("Molecule could not be read") + lit.error("Molecule could not be read") return result @@ -148,7 +148,7 @@ def load_many(lit: LineIterator) -> Iterator[dict]: try: while True: yield load_one(lit) - except OSError: + except (StopIteration, LoadError): 
return diff --git a/iodata/formats/molden.py b/iodata/formats/molden.py index 440ea63b..2a794ee0 100644 --- a/iodata/formats/molden.py +++ b/iodata/formats/molden.py @@ -45,7 +45,7 @@ from ..orbitals import MolecularOrbitals from ..overlap import compute_overlap, gob_cart_normalization from ..periodic import num2sym, sym2num -from ..utils import LineIterator, PrepareDumpError, angstrom +from ..utils import DumpError, LineIterator, PrepareDumpError, angstrom __all__ = [] @@ -663,7 +663,7 @@ def _fix_molden_from_buggy_codes(result: dict, lit: LineIterator, norm_threshold coeffsa = result["mo"].coeffsa coeffsb = result["mo"].coeffsb else: - raise ValueError("Molecular orbital kind={} not recognized".format(result["mo"].kind)) + lit.error("Molecular orbital kind={} not recognized".format(result["mo"].kind)) if _is_normalized_properly(obasis, atcoords, coeffsa, coeffsb, norm_threshold): # The file is good. No need to change obasis. @@ -793,7 +793,7 @@ def dump_one(f: TextIO, data: IOData): for angmom, kind in zip(shell.angmoms, shell.kinds): if angmom in angmom_kinds: if kind != angmom_kinds[angmom]: - raise OSError( + raise DumpError( "Molden format does not support mixed pure+Cartesian functions for one " "angular momentum." 
) diff --git a/iodata/formats/molekel.py b/iodata/formats/molekel.py index 1a2b919a..5ddce78a 100644 --- a/iodata/formats/molekel.py +++ b/iodata/formats/molekel.py @@ -32,7 +32,7 @@ from ..docstrings import document_dump_one, document_load_one from ..iodata import IOData from ..orbitals import MolecularOrbitals -from ..utils import LineIterator, PrepareDumpError, angstrom +from ..utils import DumpError, LineIterator, PrepareDumpError, angstrom from .molden import CONVENTIONS, _fix_molden_from_buggy_codes __all__ = [] @@ -370,7 +370,7 @@ def _dump_helper_coeffs(f, data, spin=None): ener = data.mo.energiesb irreps = data.mo.irreps[norb:] if data.mo.irreps is not None else ["a1g"] * norb else: - raise ValueError("A spin must be specified") + raise DumpError("A spin must be specified") for j in range(0, norb, 5): en = " ".join([f" {e: ,.12f}" for e in ener[j : j + 5]]) @@ -396,7 +396,7 @@ def _dump_helper_occ(f, data, spin=None): norb = data.mo.norba occ = data.mo.occs else: - raise ValueError("A spin must be specified") + raise DumpError("A spin must be specified") for j in range(0, norb, 5): occs = " ".join([f" {o: ,.7f}" for o in occ[j : j + 5]]) diff --git a/iodata/formats/mwfn.py b/iodata/formats/mwfn.py index 6ede364f..7b276859 100644 --- a/iodata/formats/mwfn.py +++ b/iodata/formats/mwfn.py @@ -24,7 +24,7 @@ from ..basis import HORTON2_CONVENTIONS, MolecularBasis, Shell from ..docstrings import document_load_one from ..orbitals import MolecularOrbitals -from ..utils import LineIterator, angstrom +from ..utils import LineIterator, LoadError, angstrom __all__ = [] @@ -301,8 +301,8 @@ def load_one(lit: LineIterator) -> dict: obasis = MolecularBasis(shells, CONVENTIONS, "L2") # check number of basis functions if obasis.nbasis != inp["Nbasis"]: - raise ValueError( - f"Number of basis in MolecularBasis is not equal to the 'Nbasis'. " + raise LoadError( + f"{lit.filename}: Number of basis in MolecularBasis is not equal to the 'Nbasis'. 
" f"{obasis.nbasis} != {inp['Nbasis']}" ) @@ -324,18 +324,18 @@ def load_one(lit: LineIterator) -> dict: ) # check number of electrons if mo.nelec != inp["Naelec"] + inp["Nbelec"]: - raise ValueError( - f"Number of electrons in MolecularOrbitals is not equal to the sum of " + raise LoadError( + f"{lit.filename}: Number of electrons in MolecularOrbitals is not equal to the sum of " f"'Naelec' and 'Nbelec'. {mo.nelec} != {inp['Naelec']} + {inp['Nbelec']}" ) if mo.occsa.sum() != inp["Naelec"]: - raise ValueError( - f"Number of alpha electrons in MolecularOrbitals is not equal to the " + raise LoadError( + f"{lit.filename}: Number of alpha electrons in MolecularOrbitals is not equal to the " f"'Naelec'. {mo.occsa.sum()} != {inp['Naelec']}" ) if mo.occsb.sum() != inp["Nbelec"]: - raise ValueError( - f"Number of beta electrons in MolecularOrbitals is not equal to the " + raise LoadError( + f"{lit.filename}: Number of beta electrons in MolecularOrbitals is not equal to the " f"'Nbelec'. {mo.occsb.sum()} != {inp['Nbelec']}" ) diff --git a/iodata/formats/pdb.py b/iodata/formats/pdb.py index 96611de9..aec22ef1 100644 --- a/iodata/formats/pdb.py +++ b/iodata/formats/pdb.py @@ -36,7 +36,7 @@ ) from ..iodata import IOData from ..periodic import bond2num, num2sym, sym2num -from ..utils import LineIterator, angstrom +from ..utils import LineIterator, LoadError, angstrom __all__ = [] @@ -188,9 +188,9 @@ def load_one(lit: LineIterator) -> dict: end_reached = True break if not molecule_found: - lit.error("Molecule could not be read!") + lit.error("Molecule could not be read.") if not end_reached: - lit.warn("The END is not found, but the parsed data is returned!") + lit.warn("The END is not found, but the parsed data is returned.") # Data related to force fields atffparams = { @@ -240,7 +240,7 @@ def load_many(lit: LineIterator) -> Iterator[dict]: try: while True: yield load_one(lit) - except OSError: + except (StopIteration, LoadError): return diff --git a/iodata/test/common.py 
b/iodata/test/common.py index 3a7a1c54..17a0b242 100644 --- a/iodata/test/common.py +++ b/iodata/test/common.py @@ -33,7 +33,7 @@ from ..iodata import IOData from ..orbitals import MolecularOrbitals from ..overlap import compute_overlap -from ..utils import FileFormatWarning +from ..utils import LoadWarning __all__ = [ "compute_mulliken_charges", @@ -173,7 +173,7 @@ def check_orthonormal(mo_coeffs: NDArray[float], ao_overlap: NDArray[float], ato def load_one_warning( filename: str, fmt: Optional[str] = None, match: Optional[str] = None, **kwargs ) -> IOData: - """Call load_one, catching expected FileFormatWarning. + """Call load_one, catching expected LoadWarning. Parameters ---------- @@ -196,7 +196,7 @@ def load_one_warning( with as_file(files("iodata.test.data").joinpath(filename)) as fn: if match is None: return load_one(str(fn), fmt, **kwargs) - with pytest.warns(FileFormatWarning, match=match): + with pytest.warns(LoadWarning, match=match): return load_one(str(fn), fmt, **kwargs) diff --git a/iodata/test/test_api.py b/iodata/test/test_api.py index 18de1124..0003a855 100644 --- a/iodata/test/test_api.py +++ b/iodata/test/test_api.py @@ -29,12 +29,12 @@ from ..api import dump_many, dump_one, load_many from ..iodata import IOData -from ..utils import FileFormatError, PrepareDumpError +from ..utils import DumpError, PrepareDumpError def test_empty_dump_many_no_file(tmpdir): path_xyz = os.path.join(tmpdir, "empty.xyz") - with pytest.raises(FileFormatError): + with pytest.raises(DumpError): dump_many([], path_xyz) assert not os.path.isfile(path_xyz) diff --git a/iodata/test/test_cp2klog.py b/iodata/test/test_cp2klog.py index 2f5ee6e1..043c43df 100644 --- a/iodata/test/test_cp2klog.py +++ b/iodata/test/test_cp2klog.py @@ -25,6 +25,7 @@ from ..api import load_one from ..overlap import compute_overlap +from ..utils import LoadError from .common import check_orthonormal, truncated_file @@ -207,18 +208,18 @@ def test_carbon_sc_pp_uncontracted(): def test_errors(tmpdir): 
source = files("iodata.test.data").joinpath("carbon_sc_pp_uncontracted.cp2k.out") with as_file(source) as fn_test: - with truncated_file(fn_test, 0, 0, tmpdir) as fn, pytest.raises(IOError): + with truncated_file(fn_test, 0, 0, tmpdir) as fn, pytest.raises(LoadError): load_one(fn) - with truncated_file(fn_test, 107, 10, tmpdir) as fn, pytest.raises(IOError): + with truncated_file(fn_test, 107, 10, tmpdir) as fn, pytest.raises(LoadError): load_one(fn) - with truncated_file(fn_test, 357, 10, tmpdir) as fn, pytest.raises(IOError): + with truncated_file(fn_test, 357, 10, tmpdir) as fn, pytest.raises(LoadError): load_one(fn) - with truncated_file(fn_test, 405, 10, tmpdir) as fn, pytest.raises(IOError): + with truncated_file(fn_test, 405, 10, tmpdir) as fn, pytest.raises(LoadError): load_one(fn) source = files("iodata.test.data").joinpath("carbon_gs_pp_uncontracted.cp2k.out") with ( as_file(source) as fn_test, truncated_file(fn_test, 456, 10, tmpdir) as fn, - pytest.raises(IOError), + pytest.raises(LoadError), ): load_one(fn) diff --git a/iodata/test/test_inputs.py b/iodata/test/test_inputs.py index 0366be5b..a8ccd1ab 100644 --- a/iodata/test/test_inputs.py +++ b/iodata/test/test_inputs.py @@ -27,7 +27,7 @@ from ..api import load_one, write_input from ..iodata import IOData from ..periodic import num2sym -from ..utils import FileFormatWarning, angstrom +from ..utils import LoadWarning, angstrom def check_load_input_and_compare(fname: str, fname_expected: str): @@ -205,7 +205,7 @@ def test_input_orca_from_molden(tmpdir): # load orca molden with ( as_file(files("iodata.test.data").joinpath("nh3_orca.molden")) as fn, - pytest.warns(FileFormatWarning), + pytest.warns(LoadWarning), ): mol = load_one(fn) # write input in a temporary file diff --git a/iodata/test/test_json.py b/iodata/test/test_json.py index 16c5b414..7745ffd2 100644 --- a/iodata/test/test_json.py +++ b/iodata/test/test_json.py @@ -26,7 +26,7 @@ import pytest from ..api import dump_one, load_one -from ..utils 
import FileFormatError, FileFormatWarning +from ..utils import LoadError, LoadWarning # Tests for qcschema_molecule # GEOMS: dict of str: NDArray(N, 3) @@ -65,7 +65,7 @@ def test_qcschema_molecule(filename, atnums, charge, spinpol, geometry, nwarn): if nwarn == 0: mol = load_one(str(qcschema_molecule)) else: - with pytest.warns(FileFormatWarning) as record: + with pytest.warns() as record: mol = load_one(str(qcschema_molecule)) assert len(record) == nwarn @@ -107,7 +107,7 @@ def test_molssi_qcschema_molecule(filename, atnums, charge, spinpol, nwarn): """Test qcschema_molecule parsing using MolSSI-sourced files.""" with ( as_file(files("iodata.test.data").joinpath(filename)) as qcschema_molecule, - pytest.warns(FileFormatWarning) as record, + pytest.warns(LoadWarning) as record, ): mol = load_one(str(qcschema_molecule)) @@ -137,7 +137,7 @@ def test_passthrough_qcschema_molecule(filename, unparsed_dict): """Test qcschema_molecule parsing for passthrough of unparsed keys.""" with ( as_file(files("iodata.test.data").joinpath(filename)) as qcschema_molecule, - pytest.warns(FileFormatWarning) as record, + pytest.warns(LoadWarning) as record, ): mol = load_one(str(qcschema_molecule)) @@ -174,7 +174,7 @@ def test_inout_qcschema_molecule(tmpdir, filename, nwarn): if nwarn == 0: mol = load_one(str(qcschema_molecule)) else: - with pytest.warns(FileFormatWarning) as record: + with pytest.warns(LoadWarning) as record: mol = load_one(str(qcschema_molecule)) assert len(record) == nwarn mol1 = json.loads(qcschema_molecule.read_bytes()) @@ -204,7 +204,7 @@ def test_inout_qcschema_molecule(tmpdir, filename, nwarn): def test_inout_molssi_qcschema_molecule(tmpdir, filename): """Test that loading and dumping qcschema_molecule files retains all relevant data.""" with as_file(files("iodata.test.data").joinpath(filename)) as qcschema_molecule: - with pytest.warns(FileFormatWarning) as record: + with pytest.warns(LoadWarning) as record: mol = load_one(str(qcschema_molecule)) mol1_preproc = 
json.loads(qcschema_molecule.read_bytes()) assert len(record) == 1 @@ -266,7 +266,7 @@ def test_ghost(tmpdir): def test_qcschema_input(filename, explicit_basis, lot, obasis_name, run_type, geometry): with as_file(files("iodata.test.data").joinpath(filename)) as qcschema_input: try: - with pytest.warns(FileFormatWarning): + with pytest.warns(LoadWarning): mol = load_one(str(qcschema_input)) assert mol.lot == lot if obasis_name: @@ -275,7 +275,7 @@ def test_qcschema_input(filename, explicit_basis, lot, obasis_name, run_type, ge assert mol.run_type == run_type np.testing.assert_allclose(mol.atcoords, geometry) # This will change if QCSchema Basis gets supported - except NotImplementedError: + except LoadError: assert explicit_basis @@ -293,7 +293,7 @@ def test_passthrough_qcschema_input(filename, unparsed_dict, location): """Test qcschema_molecule parsing for passthrough of unparsed keys.""" with ( as_file(files("iodata.test.data").joinpath(filename)) as qcschema_input, - pytest.warns(FileFormatWarning), + pytest.warns(LoadWarning), ): mol = load_one(str(qcschema_input)) @@ -315,10 +315,10 @@ def test_inout_qcschema_input(tmpdir, filename, nwarn): """Test that loading and dumping qcschema_molecule files retains all data.""" with as_file(files("iodata.test.data").joinpath(filename)) as qcschema_input: if nwarn == 0: - with pytest.warns(FileFormatWarning): + with pytest.warns(LoadWarning): mol = load_one(str(qcschema_input)) else: - with pytest.warns(FileFormatWarning) as record: + with pytest.warns(LoadWarning) as record: mol = load_one(str(qcschema_input)) assert len(record) == nwarn mol1 = json.loads(qcschema_input.read_bytes()) @@ -354,10 +354,10 @@ def test_inout_qcschema_input(tmpdir, filename, nwarn): def test_qcschema_output(filename, lot, obasis_name, run_type, nwarn): with as_file(files("iodata.test.data").joinpath(filename)) as qcschema_output: if nwarn == 0: - with pytest.warns(FileFormatWarning): + with pytest.warns(LoadWarning): mol = 
load_one(str(qcschema_output)) else: - with pytest.warns(FileFormatWarning) as record: + with pytest.warns(LoadWarning) as record: mol = load_one(str(qcschema_output)) assert len(record) == nwarn @@ -370,8 +370,8 @@ def test_qcschema_output(filename, lot, obasis_name, run_type, nwarn): # Some of these files have been manually validated, as reflected in the provenance # bad_mol_files: (filename, error) BAD_OUTPUT_FILES = [ - ("turbomole_water_energy_hf_output.json", FileFormatError), - ("turbomole_water_gradient_rimp2_output.json", FileFormatError), + ("turbomole_water_energy_hf_output.json", LoadError), + ("turbomole_water_gradient_rimp2_output.json", LoadError), ] @@ -395,7 +395,7 @@ def test_bad_qcschema_files(filename, error): def test_inout_qcschema_output(tmpdir, filename): """Test that loading and dumping qcschema_molecule files retains all data.""" with as_file(files("iodata.test.data").joinpath(filename)) as qcschema_input: - with pytest.warns(FileFormatWarning): + with pytest.warns(LoadWarning): mol = load_one(str(qcschema_input)) mol1 = json.loads(qcschema_input.read_bytes()) diff --git a/iodata/test/test_mol2.py b/iodata/test/test_mol2.py index ec1eb585..879aa1e4 100644 --- a/iodata/test/test_mol2.py +++ b/iodata/test/test_mol2.py @@ -26,7 +26,7 @@ from ..api import dump_many, dump_one, load_many, load_one from ..periodic import bond2num -from ..utils import angstrom +from ..utils import LoadError, angstrom from .common import truncated_file @@ -42,7 +42,7 @@ def test_mol2_formaterror(tmpdir): with ( as_file(files("iodata.test.data").joinpath("caffeine.mol2")) as fn_test, truncated_file(fn_test, 2, 0, tmpdir) as fn, - pytest.raises(IOError), + pytest.raises(LoadError), ): load_one(str(fn)) diff --git a/iodata/test/test_molden.py b/iodata/test/test_molden.py index 4a2ee816..3190c98d 100644 --- a/iodata/test/test_molden.py +++ b/iodata/test/test_molden.py @@ -32,7 +32,7 @@ from ..basis import convert_conventions from ..formats.molden import _load_low from 
..overlap import OVERLAP_CONVENTIONS, compute_overlap -from ..utils import FileFormatWarning, LineIterator, PrepareDumpError, angstrom +from ..utils import LineIterator, LoadWarning, PrepareDumpError, angstrom from .common import check_orthonormal, compare_mols, compute_mulliken_charges, create_generalized @@ -40,7 +40,7 @@ def test_load_molden_li2_orca(): with ( as_file(files("iodata.test.data").joinpath("li2.molden.input")) as fn_molden, - pytest.warns(FileFormatWarning) as record, + pytest.warns(LoadWarning) as record, ): mol = load_one(str(fn_molden)) assert len(record) == 1 @@ -82,7 +82,7 @@ def test_load_molden_li2_orca_huge_threshold(): def test_load_molden_h2o_orca(): with ( as_file(files("iodata.test.data").joinpath("h2o.molden.input")) as fn_molden, - pytest.warns(FileFormatWarning) as record, + pytest.warns(LoadWarning) as record, ): mol = load_one(str(fn_molden)) assert len(record) == 1 @@ -285,7 +285,7 @@ def test_load_molden_cfour(path, should_warn): with ExitStack() as stack: fn_molden = stack.enter_context(as_file(files("iodata.test.data").joinpath(path))) if should_warn: - stack.enter_context(pytest.warns(FileFormatWarning)) + stack.enter_context(pytest.warns(LoadWarning)) mol = load_one(str(fn_molden)) # Check normalization olp = compute_overlap(mol.obasis, mol.atcoords) @@ -298,7 +298,7 @@ def test_load_molden_nh3_orca(): # properly by altering normalization and sign conventions. with ( as_file(files("iodata.test.data").joinpath("nh3_orca.molden")) as fn_molden, - pytest.warns(FileFormatWarning) as record, + pytest.warns(LoadWarning) as record, ): mol = load_one(str(fn_molden)) assert len(record) == 1 @@ -320,7 +320,7 @@ def test_load_molden_nh3_psi4(): # properly by altering normalization conventions. 
with ( as_file(files("iodata.test.data").joinpath("nh3_psi4.molden")) as fn_molden, - pytest.warns(FileFormatWarning) as record, + pytest.warns(LoadWarning) as record, ): mol = load_one(str(fn_molden)) assert len(record) == 1 @@ -342,7 +342,7 @@ def test_load_molden_nh3_psi4_1(): # It should be read in properly by renormalizing the contractions. with ( as_file(files("iodata.test.data").joinpath("nh3_psi4_1.0.molden")) as fn_molden, - pytest.warns(FileFormatWarning) as record, + pytest.warns(LoadWarning) as record, ): mol = load_one(str(fn_molden)) assert len(record) == 1 @@ -366,7 +366,7 @@ def test_load_molden_high_am_psi4(case): # This is a special case because it contains higher angular momenta than # officially supported by the Molden format. Most virtual orbitals were removed. source = files("iodata.test.data").joinpath(f"psi4_{case}_cc_pvqz_pure.molden") - with as_file(source) as fn_molden, pytest.warns(FileFormatWarning) as record: + with as_file(source) as fn_molden, pytest.warns(LoadWarning) as record: mol = load_one(str(fn_molden)) assert len(record) == 1 assert "unnormalized" in record[0].message.args[0] @@ -388,7 +388,7 @@ def test_load_molden_high_am_orca(case): # This is a special case because it contains higher angular momenta than # officially supported by the Molden format. Most virtual orbitals were removed. source = files("iodata.test.data").joinpath(f"orca_{case}_cc_pvqz_pure.molden") - with as_file(source) as fn_molden, pytest.warns(FileFormatWarning) as record: + with as_file(source) as fn_molden, pytest.warns(LoadWarning) as record: mol = load_one(str(fn_molden)) assert len(record) == 1 assert "ORCA" in record[0].message.args[0] @@ -419,7 +419,7 @@ def test_load_molden_h2o_6_31g_d_cart_psi4(): # The file tested here is created with PSI4 1.3.2. It should be read in # properly after fixing for errors in AO normalization conventions. 
source = files("iodata.test.data").joinpath("h2o_psi4_1.3.2_6-31G_d_cart.molden") - with as_file(source) as fn_molden, pytest.warns(FileFormatWarning) as record: + with as_file(source) as fn_molden, pytest.warns(LoadWarning) as record: mol = load_one(str(fn_molden)) assert len(record) == 1 assert "PSI4 <= 1.3.2" in record[0].message.args[0] @@ -440,7 +440,7 @@ def test_load_molden_nh3_aug_cc_pvqz_cart_psi4(): # The file tested here is created with PSI4 1.3.2. It should be read in # properly after fixing for errors in AO normalization conventions. source = files("iodata.test.data").joinpath("nh3_psi4_1.3.2_aug_cc_pvqz_cart.molden") - with as_file(source) as fn_molden, pytest.warns(FileFormatWarning) as record: + with as_file(source) as fn_molden, pytest.warns(LoadWarning) as record: mol = load_one(str(fn_molden)) assert len(record) == 1 assert "PSI4 <= 1.3.2" in record[0].message.args[0] @@ -506,7 +506,7 @@ def test_load_molden_nh3_molpro2012(): def test_load_molden_neon_turbomole(): # The file tested here is created with Turbomole 7.1. 
source = files("iodata.test.data").joinpath("neon_turbomole_def2-qzvp.molden") - with as_file(source) as fn_molden, pytest.warns(FileFormatWarning) as record: + with as_file(source) as fn_molden, pytest.warns(LoadWarning) as record: mol = load_one(str(fn_molden)) assert len(record) == 1 assert "Turbomole" in record[0].message.args[0] @@ -524,7 +524,7 @@ def test_load_molden_nh3_turbomole(): # The file tested here is created with Turbomole 7.1 with ( as_file(files("iodata.test.data").joinpath("nh3_turbomole.molden")) as fn_molden, - pytest.warns(FileFormatWarning) as record, + pytest.warns(LoadWarning) as record, ): mol = load_one(str(fn_molden)) assert len(record) == 1 @@ -546,7 +546,7 @@ def test_load_molden_nh3_turbomole(): def test_load_molden_f(): with ( as_file(files("iodata.test.data").joinpath("F.molden")) as fn_molden, - pytest.warns(FileFormatWarning) as record, + pytest.warns(LoadWarning) as record, ): mol = load_one(str(fn_molden)) assert len(record) == 1 @@ -583,7 +583,7 @@ def test_load_dump_consistency(tmpdir, fn, match): if match is None: mol1 = load_one(str(file_name)) else: - with pytest.warns(FileFormatWarning, match=match): + with pytest.warns(LoadWarning, match=match): mol1 = load_one(str(file_name)) fn_tmp = os.path.join(tmpdir, "foo.bar") dump_one(mol1, fn_tmp, fmt="molden") diff --git a/iodata/test/test_pdb.py b/iodata/test/test_pdb.py index 2420f945..8ecf1cd2 100644 --- a/iodata/test/test_pdb.py +++ b/iodata/test/test_pdb.py @@ -26,7 +26,7 @@ from numpy.testing import assert_allclose, assert_equal from ..api import dump_many, dump_one, load_many, load_one -from ..utils import FileFormatWarning, angstrom +from ..utils import LoadWarning, angstrom @pytest.mark.parametrize("case", ["single", "single_model"]) @@ -41,7 +41,7 @@ def test_load_water_no_end(): # test pdb of water with ( as_file(files("iodata.test.data").joinpath("water_single_no_end.pdb")) as fn_pdb, - pytest.warns(FileFormatWarning, match="The END is not found"), + 
pytest.warns(LoadWarning, match="The END is not found"), ): mol = load_one(str(fn_pdb)) check_water(mol) @@ -77,7 +77,7 @@ def check_water(mol): def test_load_dump_consistency(fn_base, should_warn, tmpdir): with as_file(files("iodata.test.data").joinpath(fn_base)) as fn_pdb: if should_warn: - with pytest.warns(FileFormatWarning) as record: + with pytest.warns(LoadWarning) as record: mol0 = load_one(str(fn_pdb)) assert len(record) > 1 else: @@ -134,7 +134,7 @@ def test_load_peptide_2luv(): # test pdb of small peptide with ( as_file(files("iodata.test.data").joinpath("2luv.pdb")) as fn_pdb, - pytest.warns(FileFormatWarning) as record, + pytest.warns(LoadWarning) as record, ): mol = load_one(str(fn_pdb)) assert len(record) == 271 @@ -235,7 +235,7 @@ def test_load_ch5plus_bonds(): def test_indomethacin_dimer(): with ( as_file(files("iodata.test.data").joinpath("indomethacin-dimer.pdb")) as fn_pdb, - pytest.warns(FileFormatWarning) as record, + pytest.warns(LoadWarning) as record, ): mol = load_one(fn_pdb) assert len(record) == 82 diff --git a/iodata/test/test_sdf.py b/iodata/test/test_sdf.py index 9b4e590a..ed974fc1 100644 --- a/iodata/test/test_sdf.py +++ b/iodata/test/test_sdf.py @@ -25,7 +25,7 @@ from numpy.testing import assert_allclose, assert_equal from ..api import dump_many, dump_one, load_many, load_one -from ..utils import FileFormatError, angstrom +from ..utils import LoadError, angstrom from .common import truncated_file @@ -52,7 +52,7 @@ def test_sdf_formaterror(tmpdir): with ( as_file(files("iodata.test.data").joinpath("example.sdf")) as fn_test, truncated_file(fn_test, 36, 0, tmpdir) as fn, - pytest.raises(IOError), + pytest.raises(LoadError), ): load_one(str(fn)) @@ -127,6 +127,6 @@ def test_load_dump_many_consistency(tmpdir): def test_v2000_check(): with ( as_file(files("iodata.test.data").joinpath("molv3000.sdf")) as fn_sdf, - pytest.raises(FileFormatError), + pytest.raises(LoadError), ): load_one(fn_sdf) diff --git a/iodata/test/test_wfx.py 
b/iodata/test/test_wfx.py index b8298ca5..bc52dab6 100644 --- a/iodata/test/test_wfx.py +++ b/iodata/test/test_wfx.py @@ -29,7 +29,7 @@ from ..api import dump_one, load_one from ..formats.wfx import load_data_wfx, parse_wfx from ..overlap import compute_overlap -from ..utils import LineIterator, PrepareDumpError +from ..utils import LineIterator, LoadError, PrepareDumpError from .common import ( check_orthonormal, compare_mols, @@ -582,7 +582,7 @@ def test_parse_wfx_missing_tag_h2o(): with ( as_file(files("iodata.test.data").joinpath("water_sto3g_hf.wfx")) as fn_wfx, LineIterator(fn_wfx) as lit, - pytest.raises(IOError) as error, + pytest.raises(LoadError) as error, ): parse_wfx(lit, required_tags=[""]) assert str(error.value).endswith("Section is missing from loaded WFX data.") @@ -592,7 +592,7 @@ def test_load_data_wfx_h2o_error(): """Check that sections without a closing tag result in an exception.""" with ( as_file(files("iodata.test.data").joinpath("h2o_error.wfx")) as fn_wfx, - pytest.raises(IOError) as error, + pytest.raises(LoadError) as error, ): load_one(str(fn_wfx)) assert str(error.value).endswith( @@ -605,7 +605,7 @@ def test_load_truncated_h2o(tmpdir): with ( as_file(files("iodata.test.data").joinpath("water_sto3g_hf.wfx")) as fn_wfx, truncated_file(str(fn_wfx), 152, 0, tmpdir) as fn_truncated, - pytest.raises(IOError) as error, + pytest.raises(LoadError) as error, ): load_one(str(fn_truncated)) assert str(error.value).endswith( diff --git a/iodata/utils.py b/iodata/utils.py index a7984a4c..bfdaf2e2 100644 --- a/iodata/utils.py +++ b/iodata/utils.py @@ -29,9 +29,12 @@ from .attrutils import validate_shape __all__ = ( - "FileFormatError", - "FileFormatWarning", + "LoadError", + "LoadWarning", + "DumpError", + "DumpWarning", "PrepareDumpError", + "WriteInputError", "LineIterator", "Cube", "set_four_index_element", @@ -59,16 +62,32 @@ kjmol: float = 1e3 / spc.value("Avogadro constant") / spc.value("Hartree energy") -class FileFormatError(IOError): - 
"""Raised when incorrect content is encountered when loading files.""" +class FileFormatError(ValueError): + """Raised when a file or input format cannot be identified.""" -class FileFormatWarning(Warning): - """Raised when incorrect content is encountered and fixed when loading files.""" +class LoadError(ValueError): + """Raised when an error is encountered while loading from a file.""" + + +class LoadWarning(Warning): + """Raised when incorrect content is encountered and fixed when loading from a file.""" + + +class DumpError(ValueError): + """Raised when an error is encountered while dumping to a file.""" + + +class DumpWarning(Warning): + """Raised when an IOData object is made compatible with a format when dumping to a file.""" class PrepareDumpError(ValueError): - """Raised when an iodata object is not compatible with an output file format.""" + """Raised when an IOData object is incompatible with a format before dumping to a file.""" + + +class WriteInputError(ValueError): + """Raised when an error is encountered while writing an input file.""" class LineIterator: @@ -122,7 +141,7 @@ def error(self, msg: str): Message to raise alongside filename and line number. """ - raise FileFormatError(f"{self.filename}:{self.lineno} {msg}") + raise LoadError(f"{self.filename}:{self.lineno} {msg}") def warn(self, msg: str): """Raise a warning while reading a file. @@ -133,7 +152,7 @@ def warn(self, msg: str): Message to raise alongside filename and line number. 
""" - warnings.warn(f"{self.filename}:{self.lineno} {msg}", FileFormatWarning, stacklevel=2) + warnings.warn(f"{self.filename}:{self.lineno} {msg}", LoadWarning, stacklevel=2) def back(self, line): """Go back one line in the file and decrease the lineno attribute by one.""" From 96b296ac0b711dba295c9c74abe71cadd28c27ad Mon Sep 17 00:00:00 2001 From: Toon Verstraelen Date: Fri, 21 Jun 2024 07:16:22 +0200 Subject: [PATCH 2/3] Fix f-string --- iodata/formats/molden.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/iodata/formats/molden.py b/iodata/formats/molden.py index 2a794ee0..2cfc48bd 100644 --- a/iodata/formats/molden.py +++ b/iodata/formats/molden.py @@ -663,7 +663,7 @@ def _fix_molden_from_buggy_codes(result: dict, lit: LineIterator, norm_threshold coeffsa = result["mo"].coeffsa coeffsb = result["mo"].coeffsb else: - lit.error("Molecular orbital kind={} not recognized".format(result["mo"].kind)) + lit.error(f"Molecular orbital kind={result['mo'].kind} not recognized") if _is_normalized_properly(obasis, atcoords, coeffsa, coeffsb, norm_threshold): # The file is good. No need to change obasis. From cfbbceb5c8f5c29638f727d4479624a0bde8dd00 Mon Sep 17 00:00:00 2001 From: Toon Verstraelen Date: Fri, 21 Jun 2024 07:19:22 +0200 Subject: [PATCH 3/3] AI suggestions --- iodata/api.py | 2 +- iodata/test/test_json.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/iodata/api.py b/iodata/api.py index ad4bf840..26f3c22d 100644 --- a/iodata/api.py +++ b/iodata/api.py @@ -242,7 +242,7 @@ def dump_one(iodata: IOData, filename: str, fmt: Optional[str] = None, **kwargs) ------ DumpError When an error is encountered while dumping to a file. - If the output file already existed, it (partially) overwritten. + If the output file already existed, it is (partially) overwritten. PrepareDumpError When the iodata object is not compatible with the file format, e.g. 
due to missing attributes, and not conversion is available or allowed diff --git a/iodata/test/test_json.py b/iodata/test/test_json.py index 7745ffd2..3a86cb74 100644 --- a/iodata/test/test_json.py +++ b/iodata/test/test_json.py @@ -65,7 +65,7 @@ def test_qcschema_molecule(filename, atnums, charge, spinpol, geometry, nwarn): if nwarn == 0: mol = load_one(str(qcschema_molecule)) else: - with pytest.warns() as record: + with pytest.warns(LoadWarning) as record: mol = load_one(str(qcschema_molecule)) assert len(record) == nwarn