From 40a6965d055d7c4849e4911365eb1b0b3be44737 Mon Sep 17 00:00:00 2001 From: Toon Verstraelen Date: Sun, 23 Jun 2024 11:23:14 +0200 Subject: [PATCH 1/4] Further improve error and warning infrastructure --- CONTRIBUTING.rst | 17 ++++ iodata/api.py | 70 ++++++++++----- iodata/formats/fchk.py | 14 +-- iodata/formats/json.py | 170 +++++++++++++++++++++--------------- iodata/formats/mol2.py | 13 +-- iodata/formats/molden.py | 50 ++++++++--- iodata/formats/molekel.py | 16 ++-- iodata/formats/pdb.py | 19 +++- iodata/formats/wfn.py | 16 ++-- iodata/formats/wfx.py | 22 +++-- iodata/test/data/water.mol2 | 15 ++++ iodata/test/test_iodata.py | 4 +- iodata/test/test_mol2.py | 11 ++- iodata/utils.py | 144 ++++++++++++++++++++---------- 14 files changed, 391 insertions(+), 190 deletions(-) create mode 100644 iodata/test/data/water.mol2 diff --git a/CONTRIBUTING.rst b/CONTRIBUTING.rst index ae802f3c..ffeefd74 100644 --- a/CONTRIBUTING.rst +++ b/CONTRIBUTING.rst @@ -170,6 +170,23 @@ If your code has already read the full file and encounters an error when process you can use ``raise LoadError("Describe problem in a sentence.", lit.filename)`` instead. This way, no line number is included in the error message. +Sometimes, it is possible to correct errors while reading a file. +In this case, you should warn the user that the file contains (fixable) errors: + +.. code-block:: python + + from warnings import warn + + from ..utils import LoadWarning + + @document_load_one(...) + def load_one(lit: LineIterator) -> dict: + ... + if something_fixed: + warn(LoadWarning("Describe the problem in a sentence.", lit), stacklevel=2) + +Always use ``stacklevel=2`` when raising warnings. + ``dump_one`` functions: writing a single IOData object to a file ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ diff --git a/iodata/api.py b/iodata/api.py index eaa2d855..2e1eeed6 100644 --- a/iodata/api.py +++ b/iodata/api.py @@ -19,6 +19,7 @@ """Functions to be used by end users.""" import os +import warnings from collections.abc import Iterable, Iterator from fnmatch import fnmatch from importlib import import_module @@ -84,7 +85,7 @@ def _select_format_module(filename: str, attrname: str, fmt: Optional[str] = Non return format_module else: return FORMAT_MODULES[fmt] - raise FileFormatError(f"Could not find file format with feature {attrname} for file {filename}") + raise FileFormatError(f"Cannot find file format with feature {attrname}", filename) def _find_input_modules(): @@ -101,11 +102,13 @@ def _find_input_modules(): INPUT_MODULES = _find_input_modules() -def _select_input_module(fmt: str) -> ModuleType: +def _select_input_module(filename: str, fmt: str) -> ModuleType: """Find an input module. Parameters ---------- + filename + The file to be written to, only used for error messages. fmt The name of the input module to use. @@ -121,11 +124,31 @@ def _select_input_module(fmt: str) -> ModuleType: """ if fmt in INPUT_MODULES: if not hasattr(INPUT_MODULES[fmt], "write_input"): - raise FileFormatError(f"{fmt} input module does not have write_input.") + raise FileFormatError(f"{fmt} input module does not have write_input.", filename) return INPUT_MODULES[fmt] - raise FileFormatError(f"Could not find input format {fmt}.") + raise FileFormatError(f"Cannot find input format {fmt}.", filename) + +def _reissue_warnings(func): + """Correct stacklevel of warnings raised in functions called deeper in IOData. + This function should be used as a decorator of end-user API functions. + Adapted from https://stackoverflow.com/a/71635963/494584 + """ + + def inner(*args, **kwargs): + try: + with warnings.catch_warnings(record=True) as warning_list: + result = func(*args, **kwargs) + finally: + for warning in warning_list: + warnings.warn(warning.message, warning.category, stacklevel=2) + return result + + return inner + + +@_reissue_warnings def load_one(filename: str, fmt: Optional[str] = None, **kwargs) -> IOData: """Load data from a file. @@ -151,16 +174,16 @@ def load_one(filename: str, fmt: Optional[str] = None, **kwargs) -> IOData: format_module = _select_format_module(filename, "load_one", fmt) with LineIterator(filename) as lit: try: - iodata = IOData(**format_module.load_one(lit, **kwargs)) + return IOData(**format_module.load_one(lit, **kwargs)) except LoadError: raise except StopIteration as exc: raise LoadError("File ended before all data was read.", lit) from exc except Exception as exc: raise LoadError("Uncaught exception while loading file.", lit) from exc - return iodata +@_reissue_warnings def load_many(filename: str, fmt: Optional[str] = None, **kwargs) -> Iterator[IOData]: """Load multiple IOData instances from a file. @@ -197,11 +220,13 @@ def load_many(filename: str, fmt: Optional[str] = None, **kwargs) -> Iterator[IO raise LoadError("Uncaught exception while loading file.", lit) from exc -def _check_required(iodata: IOData, dump_func: Callable): +def _check_required(filename: str, iodata: IOData, dump_func: Callable): """Check that required attributes are not None before dumping to a file. Parameters ---------- + filename + The file to be dumped to, only used for error messages. iodata The data to be written. dump_func @@ -215,10 +240,11 @@ def _check_required(iodata: IOData, dump_func: Callable): for attr_name in dump_func.required: if getattr(iodata, attr_name) is None: raise PrepareDumpError( - f"Required attribute {attr_name}, for format {dump_func.fmt}, is None." + f"Required attribute {attr_name}, for format {dump_func.fmt}, is None.", filename ) +@_reissue_warnings def dump_one(iodata: IOData, filename: str, fmt: Optional[str] = None, **kwargs): """Write data to a file. @@ -251,14 +277,14 @@ def dump_one(iodata: IOData, filename: str, fmt: Optional[str] = None, **kwargs) """ format_module = _select_format_module(filename, "dump_one", fmt) try: - _check_required(iodata, format_module.dump_one) + _check_required(filename, iodata, format_module.dump_one) if hasattr(format_module, "prepare_dump"): - format_module.prepare_dump(iodata) + format_module.prepare_dump(filename, iodata) except PrepareDumpError: raise except Exception as exc: raise PrepareDumpError( - f"{filename}: Uncaught exception while preparing for dumping to a file" + "Uncaught exception while preparing for dumping to a file.", filename ) from exc with open(filename, "w") as f: try: @@ -266,9 +292,10 @@ def dump_one(iodata: IOData, filename: str, fmt: Optional[str] = None, **kwargs) except DumpError: raise except Exception as exc: - raise DumpError(f"{filename}: Uncaught exception while dumping to a file") from exc + raise DumpError("Uncaught exception while dumping to a file", filename) from exc +@_reissue_warnings def dump_many(iodatas: Iterable[IOData], filename: str, fmt: Optional[str] = None, **kwargs): """Write multiple IOData instances to a file. @@ -309,16 +336,16 @@ def dump_many(iodatas: Iterable[IOData], filename: str, fmt: Optional[str] = Non try: first = next(iter_iodatas) except StopIteration as exc: - raise DumpError(f"{filename}: dump_many needs at least one iodata object.") from exc + raise DumpError("dump_many needs at least one iodata object.", filename) from exc try: - _check_required(first, format_module.dump_many) + _check_required(filename, first, format_module.dump_many) if hasattr(format_module, "prepare_dump"): - format_module.prepare_dump(first) + format_module.prepare_dump(filename, first) except PrepareDumpError: raise except Exception as exc: raise PrepareDumpError( - f"{filename}: Uncaught exception while preparing for dumping to a file" + "Uncaught exception while preparing for dumping to a file.", filename ) from exc def checking_iterator(): @@ -326,9 +353,9 @@ def checking_iterator(): # The first one was already checked. yield first for other in iter_iodatas: - _check_required(other, format_module.dump_many) + _check_required(filename, other, format_module.dump_many) if hasattr(format_module, "prepare_dump"): - format_module.prepare_dump(other) + format_module.prepare_dump(filename, other) yield other with open(filename, "w") as f: @@ -337,9 +364,10 @@ def checking_iterator(): except (PrepareDumpError, DumpError): raise except Exception as exc: - raise DumpError(f"{filename}: Uncaught exception while dumping to a file") from exc + raise DumpError("Uncaught exception while dumping to a file.", filename) from exc +@_reissue_warnings def write_input( iodata: IOData, filename: str, @@ -370,11 +398,11 @@ def write_input( Keyword arguments are passed on to the input-specific write_input function. """ - input_module = _select_input_module(fmt) + input_module = _select_input_module(filename, fmt) with open(filename, "w") as fh: try: input_module.write_input(fh, iodata, template, atom_line, **kwargs) except Exception as exc: raise WriteInputError( - f"{filename}: Uncaught exception while writing an input file" + "Uncaught exception while writing an input file.", filename ) from exc diff --git a/iodata/formats/fchk.py b/iodata/formats/fchk.py index eef58110..e35f91f5 100644 --- a/iodata/formats/fchk.py +++ b/iodata/formats/fchk.py @@ -542,28 +542,32 @@ def _dump_real_arrays(name: str, val: NDArray[float], f: TextIO): k = 0 -def prepare_dump(data: IOData): +def prepare_dump(filename: str, data: IOData): """Check the compatibility of the IOData object with the FCHK format. Parameters ---------- + filename + The file to be written to, only used for error messages. data The IOData instance to be checked. """ if data.mo is not None: if data.mo.kind == "generalized": - raise PrepareDumpError("Cannot write FCHK file with generalized orbitals.") + raise PrepareDumpError("Cannot write FCHK file with generalized orbitals.", filename) na = int(np.round(np.sum(data.mo.occsa))) if not ((data.mo.occsa[:na] == 1.0).all() and (data.mo.occsa[na:] == 0.0).all()): raise PrepareDumpError( "Cannot dump FCHK because it does not have fully occupied alpha orbitals " - "followed by fully virtual ones." + "followed by fully virtual ones.", + filename, ) nb = int(np.round(np.sum(data.mo.occsb))) if not ((data.mo.occsb[:nb] == 1.0).all() and (data.mo.occsb[nb:] == 0.0).all()): raise PrepareDumpError( "Cannot dump FCHK because it does not have fully occupied beta orbitals " - "followed by fully virtual ones." + "followed by fully virtual ones.", + filename, ) @@ -643,7 +647,7 @@ def dump_one(f: TextIO, data: IOData): elif shell.ncon == 2 and shell.angmoms == [0, 1]: shell_types.append(-1) else: - raise DumpError("Cannot identify type of shell!") + raise DumpError("Cannot identify type of shell!", f) num_pure_d_shells = sum([1 for st in shell_types if st == 2]) num_pure_f_shells = sum([1 for st in shell_types if st == 3]) diff --git a/iodata/formats/json.py b/iodata/formats/json.py index eca7e644..851873a8 100644 --- a/iodata/formats/json.py +++ b/iodata/formats/json.py @@ -571,7 +571,7 @@ from ..docstrings import document_dump_one, document_load_one from ..iodata import IOData from ..periodic import num2sym, sym2num -from ..utils import DumpError, LineIterator, LoadError, LoadWarning, PrepareDumpError +from ..utils import DumpError, DumpWarning, LineIterator, LoadError, LoadWarning, PrepareDumpError __all__ = [] @@ -640,9 +640,11 @@ def _parse_json(json_in: dict, lit: LineIterator) -> dict: if "schema_name" not in result: # Attempt to determine schema type, since some QCElemental files omit this warn( - f"{lit.filename}: QCSchema files should have a `schema_name` key." - "Attempting to determine schema type...", - LoadWarning, + LoadWarning( + "QCSchema files should have a `schema_name` key." + "Attempting to determine schema type...", + lit.filename, + ), stacklevel=2, ) # Geometry is required in any molecule schema @@ -662,9 +664,11 @@ def _parse_json(json_in: dict, lit: LineIterator) -> dict: raise LoadError("Could not determine `schema_name`.", lit.filename) if "schema_version" not in result: warn( - f"{lit.filename}: QCSchema files should have a `schema_version` key." - "Attempting to load without version number.", - LoadWarning, + LoadWarning( + "QCSchema files should have a `schema_version` key." + "Attempting to load without version number.", + lit.filename, + ), stacklevel=2, ) @@ -754,8 +758,7 @@ def _parse_topology_keys(mol: dict, lit: LineIterator) -> dict: for key in should_be_required_keys: if key not in mol: warn( - f"{lit.filename}: QCSchema files should have a '{key}' key.", - LoadWarning, + LoadWarning(f"QCSchema files should have a '{key}' key.", lit.filename), stacklevel=2, ) for key in topology_keys: @@ -778,10 +781,12 @@ def _parse_topology_keys(mol: dict, lit: LineIterator) -> dict: # Check for missing charge, warn that this is a required field if "molecular_charge" not in mol: warn( - "{}: Missing 'molecular_charge' key." - "Some QCSchema writers omit this key for default value 0.0," - "Ensure this value is correct.", - LoadWarning, + LoadWarning( + "Missing 'molecular_charge' key." + "Some QCSchema writers omit this key for default value 0.0," + "Ensure this value is correct.", + lit.filename, + ), stacklevel=2, ) formal_charge = 0.0 @@ -793,10 +798,12 @@ def _parse_topology_keys(mol: dict, lit: LineIterator) -> dict: # Check for missing mult, warn that this is a required field if "molecular_multiplicity" not in mol: warn( - "{}: Missing 'molecular_multiplicity' key." - "Some QCSchema writers omit this key for default value 1," - "Ensure this value is correct.", - LoadWarning, + LoadWarning( + "Missing 'molecular_multiplicity' key." + "Some QCSchema writers omit this key for default value 1," + "Ensure this value is correct.", + lit.filename, + ), stacklevel=2, ) topology_dict["spinpol"] = 0 @@ -817,9 +824,11 @@ def _parse_topology_keys(mol: dict, lit: LineIterator) -> dict: # Load atom masses to array, canonical weights assumed if masses not given if "masses" in mol and "mass_numbers" in mol: warn( - "{}: Both `masses` and `mass_numbers` given. " - "Both values will be written to `extra` dict.", - LoadWarning, + LoadWarning( + "Both `masses` and `mass_numbers` given. " + "Both values will be written to `extra` dict.", + lit.filename, + ), stacklevel=2, ) extra_dict["mass_numbers"] = np.array(mol["mass_numbers"]) @@ -930,9 +939,10 @@ def _version_check(result: dict, max_version: float, schema_name: str, lit: Line version = -1 if float(version) < 0 or float(version) > max_version: warn( - f"{lit.filename}: Unknown {schema_name} version {version}, " - "loading may produce invalid results", - LoadWarning, + LoadWarning( + f"Unknown {schema_name} version {version}, " "loading may produce invalid results", + lit.filename, + ), stacklevel=2, ) return version @@ -1073,8 +1083,7 @@ def _parse_input_keys(result: dict, lit: LineIterator) -> dict: for key in should_be_required_keys: if key not in result: warn( - f"{lit.filename}: QCSchema files should have a '{key}' key.", - LoadWarning, + LoadWarning(f"QCSchema files should have a '{key}' key.", lit.filename), stacklevel=2, ) for key in input_keys: @@ -1206,15 +1215,19 @@ def _parse_model(model: dict, lit: LineIterator) -> dict: # QCEngineRecords doesn't give an empty string for basis-free methods, omits req'd key instead if "basis" not in model: warn( - f"{lit.filename}: Model `basis` key should be given. Assuming basis-free method.", + LoadWarning( + "Model `basis` key should be given. Assuming basis-free method.", lit.filename + ), stacklevel=2, ) elif isinstance(model["basis"], str): if model["basis"] == "": warn( - f"{lit.filename}: QCSchema `basis` could not be read and will be omitted." - "Unless model is for a basis-free method, check input file.", - LoadWarning, + LoadWarning( + "QCSchema `basis` could not be read and will be omitted." + "Unless model is for a basis-free method, check input file.", + lit.filename, + ), stacklevel=2, ) else: @@ -1247,8 +1260,10 @@ def _parse_protocols(protocols: dict, lit: LineIterator) -> dict: """ if "wavefunction" not in protocols: warn( - "{}: Protocols `wavefunction` key not specified, no properties will be kept.", - LoadWarning, + LoadWarning( + "Protocols `wavefunction` key not specified, no properties will be kept.", + lit.filename, + ), stacklevel=2, ) wavefunction = "none" @@ -1256,8 +1271,7 @@ def _parse_protocols(protocols: dict, lit: LineIterator) -> dict: wavefunction = protocols["wavefunction"] if "stdout" not in protocols: warn( - "{}: Protocols `stdout` key not specified, stdout will be kept.", - LoadWarning, + LoadWarning("Protocols `stdout` key not specified, stdout will be kept.", lit.filename), stacklevel=2, ) keep_stdout = True @@ -1333,8 +1347,7 @@ def _parse_output_keys(result: dict, lit: LineIterator) -> dict: for key in should_be_required_keys: if key not in result: warn( - f"{lit.filename}: QCSchema files should have a '{key}' key.", - LoadWarning, + LoadWarning(f"QCSchema files should have a '{key}' key.", lit.filename), stacklevel=2, ) for key in output_keys: @@ -1433,20 +1446,24 @@ def _parse_provenance( return base_provenance -def prepare_dump(data: IOData): +def prepare_dump(filename: str, data: IOData): """Check the compatibility of the IOData object with QCScheme. Parameters ---------- + filename + The file to be written to, only used for error messages. data The IOData instance to be checked. """ if "schema_name" not in data.extra: - raise PrepareDumpError("Cannot write qcschema file without 'schema_name' defined.") + raise PrepareDumpError( + "Cannot write qcschema file without 'schema_name' defined.", filename + ) schema_name = data.extra["schema_name"] if schema_name == "qcschema_basis": - raise PrepareDumpError(f"{schema_name} not yet implemented in IOData.") + raise PrepareDumpError(f"{schema_name} not yet implemented in IOData.", filename) @document_dump_one( @@ -1459,27 +1476,30 @@ def dump_one(f: TextIO, data: IOData): schema_name = data.extra["schema_name"] if schema_name == "qcschema_molecule": - return_dict = _dump_qcschema_molecule(data) + return_dict = _dump_qcschema_molecule(f, data) elif schema_name == "qcschema_basis": raise NotImplementedError(f"{schema_name} not yet implemented in IOData.") - # return_dict = _dump_qcschema_basis(data) + # return_dict = _dump_qcschema_basis(f, data) elif schema_name == "qcschema_input": - return_dict = _dump_qcschema_input(data) + return_dict = _dump_qcschema_input(f, data) elif schema_name == "qcschema_output": - return_dict = _dump_qcschema_output(data) + return_dict = _dump_qcschema_output(f, data) else: raise DumpError( "'schema_name' must be one of 'qcschema_molecule', 'qcschema_basis'" - "'qcschema_input' or 'qcschema_output'." + "'qcschema_input' or 'qcschema_output'.", + f, ) json.dump(return_dict, f, indent=4) -def _dump_qcschema_molecule(data: IOData) -> dict: +def _dump_qcschema_molecule(f: TextIO, data: IOData) -> dict: """Dump relevant attributes from IOData to :ref:`qcschema_molecule `. Parameters ---------- + f + The file being written, used for error and warning messages only. data The IOData instance to dump to file. @@ -1493,16 +1513,18 @@ def _dump_qcschema_molecule(data: IOData) -> dict: # Gather required field data if data.atnums is None or data.atcoords is None: - raise DumpError("qcschema_molecule requires `atnums` and `atcoords` fields.") + raise DumpError("qcschema_molecule requires `atnums` and `atcoords` fields.", f) molecule_dict["symbols"] = [num2sym[num] for num in data.atnums] molecule_dict["geometry"] = list(data.atcoords.flatten()) # Should be required field data if data.charge is None or data.spinpol is None: warn( - "`charge` and `spinpol` should be given to write qcschema_molecule file:" - "QCSchema defaults to charge = 0 and multiplicity = 1 if no values given.", - LoadWarning, + DumpWarning( + "`charge` and `spinpol` should be given to write qcschema_molecule file:" + "QCSchema defaults to charge = 0 and multiplicity = 1 if no values given.", + f, + ), stacklevel=2, ) if data.charge is not None: @@ -1554,7 +1576,7 @@ def _dump_qcschema_molecule(data: IOData) -> dict: molecule_dict["fix_com"] = data.extra["molecule"]["fix_com"] if "fix_orientation" in data.extra["molecule"]: molecule_dict["fix_orientation"] = data.extra["molecule"]["fix_orientation"] - molecule_dict["provenance"] = _dump_provenance(data, "molecule") + molecule_dict["provenance"] = _dump_provenance(f, data, "molecule") if "id" in data.extra["molecule"]: molecule_dict["id"] = data.extra["molecule"]["id"] if "extras" in data.extra["molecule"]: @@ -1566,13 +1588,15 @@ def _dump_qcschema_molecule(data: IOData) -> dict: return molecule_dict -def _dump_provenance(data: IOData, source: str) -> Union[list[dict], dict]: +def _dump_provenance(f: TextIO, data: IOData, source: str) -> Union[list[dict], dict]: """Generate the :ref:`provenance ` information. This is used when dumping an IOData instance to QCSchema. Parameters ---------- + f + The file being written, used for error and warning messages only. data The IOData instance to dump to file. source @@ -1596,11 +1620,11 @@ def _dump_provenance(data: IOData, source: str) -> Union[list[dict], dict]: if isinstance(provenance, list): provenance.append(new_provenance) return provenance - raise DumpError("QCSchema provenance must be either a dict or list of dicts.") + raise DumpError("QCSchema provenance must be either a dict or list of dicts.", f) return new_provenance -def _dump_qcschema_input(data: IOData) -> dict: +def _dump_qcschema_input(f: TextIO, data: IOData) -> dict: """Dump relevant attributes from IOData to :ref:`qcschema_input `. Using this function requires keywords to be stored in two locations in the ``extra`` dict: @@ -1609,6 +1633,8 @@ def _dump_qcschema_input(data: IOData) -> dict: Parameters ---------- + f + The file being written, used for error and warning messages only. data The IOData instance to dump to file. @@ -1621,19 +1647,19 @@ def _dump_qcschema_input(data: IOData) -> dict: input_dict = {"schema_name": "qcschema_input", "schema_version": 2.0} # Gather required field data - input_dict["molecule"] = _dump_qcschema_molecule(data) + input_dict["molecule"] = _dump_qcschema_molecule(f, data) if "driver" not in data.extra["input"]: - raise DumpError("qcschema_input requires `driver` field in extra['input'].") + raise DumpError("qcschema_input requires `driver` field in extra['input'].", f) if data.extra["input"]["driver"] not in {"energy", "gradient", "hessian", "properties"}: raise DumpError( - "QCSchema driver must be one of `energy`, `gradient`, `hessian`, or `properties`" + "QCSchema driver must be one of `energy`, `gradient`, `hessian`, or `properties`", f ) input_dict["driver"] = data.extra["input"]["driver"] if "model" not in data.extra["input"]: - raise DumpError("qcschema_input requires `model` field in extra['input'].") + raise DumpError("qcschema_input requires `model` field in extra['input'].", f) input_dict["model"] = {} if data.lot is None: - raise DumpError("qcschema_input requires specifed `lot`.") + raise DumpError("qcschema_input requires specifed `lot`.", f) input_dict["model"]["method"] = data.lot if data.obasis_name is None and "basis" not in data.extra["input"]["model"]: input_dict["model"]["basis"] = "" @@ -1651,7 +1677,7 @@ def _dump_qcschema_input(data: IOData) -> dict: # Remove 'keep_' from protocols keys (added in IOData for readability) for keep in data.extra["input"]["protocols"]: input_dict["protocols"][keep[5:]] = data.extra["input"]["protocols"][keep] - input_dict["provenance"] = _dump_provenance(data, "input") + input_dict["provenance"] = _dump_provenance(f, data, "input") if "unparsed" in data.extra["input"]: for k in data.extra["input"]["unparsed"]: input_dict[k] = data.extra["input"]["unparsed"][k] @@ -1659,7 +1685,7 @@ def _dump_qcschema_input(data: IOData) -> dict: return input_dict -def _dump_qcschema_output(data: IOData) -> dict: +def _dump_qcschema_output(f: TextIO, data: IOData) -> dict: """Dump relevant attributes from IOData to :ref:`qcschema_output `. Using this function requires keywords to be stored in three locations in the ``extra`` dict: @@ -1668,6 +1694,8 @@ def _dump_qcschema_output(data: IOData) -> dict: Parameters ---------- + f + The file being written, used for error and warning messages only. data The IOData instance to dump to file. @@ -1681,39 +1709,41 @@ def _dump_qcschema_output(data: IOData) -> dict: # Gather required field data # Gather required field data - output_dict["molecule"] = _dump_qcschema_molecule(data) + output_dict["molecule"] = _dump_qcschema_molecule(f, data) if "driver" not in data.extra["input"]: - raise DumpError("qcschema_output requires `driver` field in extra['input'].") + raise DumpError("qcschema_output requires `driver` field in extra['input'].", f) if data.extra["input"]["driver"] not in {"energy", "gradient", "hessian", "properties"}: raise DumpError( - "QCSchema driver must be one of `energy`, `gradient`, `hessian`, or `properties`" + "QCSchema driver must be one of `energy`, `gradient`, `hessian`, or `properties`", f ) output_dict["driver"] = data.extra["input"]["driver"] if "model" not in data.extra["input"]: - raise DumpError("qcschema_output requires `model` field in extra['input'].") + raise DumpError("qcschema_output requires `model` field in extra['input'].", f) output_dict["model"] = {} if data.lot is None: - raise DumpError("qcschema_output requires specifed `lot`.") + raise DumpError("qcschema_output requires specifed `lot`.", f) output_dict["model"]["method"] = data.lot if data.obasis_name is None and "basis" not in data.extra["input"]["model"]: warn( - "No basis name given. QCSchema assumes this signifies a basis-free method; to" - "avoid this warning, specify `obasis_name` as an empty string.", - LoadWarning, + DumpWarning( + "No basis name given. QCSchema assumes this signifies a basis-free method; to" + "avoid this warning, specify `obasis_name` as an empty string.", + f, + ), stacklevel=2, ) if "basis" in data.extra["input"]["model"]: raise NotImplementedError("qcschema_basis is not yet supported in IOData.") output_dict["model"]["basis"] = data.obasis_name if "properties" not in data.extra["output"]: - raise DumpError("qcschema_output requires `properties` field in extra['output'].") + raise DumpError("qcschema_output requires `properties` field in extra['output'].", f) output_dict["properties"] = data.extra["output"]["properties"] if data.energy is not None: output_dict["properties"]["return_energy"] = data.energy if output_dict["driver"] == "energy": output_dict["return_result"] = data.energy if "return_result" not in output_dict and "return_result" not in data.extra["output"]: - raise DumpError("qcschema_output requires `return_result` field in extra['output'].") + raise DumpError("qcschema_output requires `return_result` field in extra['output'].", f) if "return_result" in data.extra["output"]: output_dict["return_result"] = data.extra["output"]["return_result"] if "keywords" in data.extra["input"]: @@ -1735,7 +1765,7 @@ def _dump_qcschema_output(data: IOData) -> dict: output_dict["stderr"] = data.extra["output"]["stdout"] if "wavefunction" in data.extra["output"]: output_dict["wavefunction"] = data.extra["output"]["wavefunction"] - output_dict["provenance"] = _dump_provenance(data, "input") + output_dict["provenance"] = _dump_provenance(f, data, "input") if "unparsed" in data.extra["input"]: for k in data.extra["input"]["unparsed"]: output_dict[k] = data.extra["input"]["unparsed"][k] diff --git a/iodata/formats/mol2.py b/iodata/formats/mol2.py index defd4d70..1dbb9cb6 100644 --- a/iodata/formats/mol2.py +++ b/iodata/formats/mol2.py @@ -24,6 +24,7 @@ from collections.abc import Iterator from typing import TextIO +from warnings import warn import numpy as np from numpy.typing import NDArray @@ -36,7 +37,7 @@ ) from ..iodata import IOData from ..periodic import bond2num, num2bond, num2sym, sym2num -from ..utils import LineIterator, LoadError, angstrom +from ..utils import LineIterator, LoadError, LoadWarning, angstrom __all__ = [] @@ -100,7 +101,7 @@ def _load_helper_atoms( atnum = sym2num.get(symbol, sym2num.get(symbol[0], None)) if atnum is None: atnum = 0 - lit.warn(f"Can not convert {words[1][:2]} to elements") + warn(LoadWarning(f"Cannot interpret element symbol {words[1][:2]}", lit), stacklevel=2) atnums[i] = atnum attypes.append(words[5]) atcoords[i] = [float(words[2]), float(words[3]), float(words[4])] @@ -131,11 +132,11 @@ def _load_helper_bonds(lit: LineIterator, nbonds: int) -> NDArray[int]: int(words[1]) - 1, int(words[2]) - 1, # convert mol2 bond type to integer - bond2num.get(words[3], bond2num["un"]), + bond2num.get(words[3]), ] - if bond is None: - bond = [0, 0, 0] - lit.warn(f"Something wrong in the bond section: {bond}") + if bond[-1] is None: + bond[-1] = bond2num["un"] + warn(LoadWarning(f"Cannot interpret bond type {words[3]}", lit), stacklevel=2) bonds[i] = bond return bonds diff --git a/iodata/formats/molden.py b/iodata/formats/molden.py index e3cc02ab..14528094 100644 --- a/iodata/formats/molden.py +++ b/iodata/formats/molden.py @@ -27,6 +27,7 @@ import copy from typing import TextIO, Union +from warnings import warn import attrs import numpy as np @@ -45,7 +46,7 @@ from ..orbitals import MolecularOrbitals from ..overlap import compute_overlap, gob_cart_normalization from ..periodic import num2sym, sym2num -from ..utils import DumpError, LineIterator, LoadError, PrepareDumpError, angstrom +from ..utils import DumpError, LineIterator, LoadError, LoadWarning, PrepareDumpError, angstrom __all__ = [] @@ -676,7 +677,10 @@ def _fix_molden_from_buggy_codes(result: dict, lit: LineIterator, norm_threshold # --- ORCA orca_obasis = _fix_obasis_orca(obasis) if _is_normalized_properly(orca_obasis, atcoords, coeffsa, coeffsb, norm_threshold): - lit.warn("Corrected for typical ORCA errors in Molden/MKL file.") + warn( + LoadWarning("Corrected for typical ORCA errors in Molden/MKL file.", lit.filename), + stacklevel=2, + ) result["obasis"] = orca_obasis return @@ -685,7 +689,10 @@ def _fix_molden_from_buggy_codes(result: dict, lit: LineIterator, norm_threshold if psi4_obasis is not None and _is_normalized_properly( psi4_obasis, atcoords, coeffsa, coeffsb, norm_threshold ): - lit.warn("Corrected for PSI4 < 1.0 errors in Molden/MKL file.") + warn( + LoadWarning("Corrected for PSI4 < 1.0 errors in Molden/MKL file.", lit.filename), + stacklevel=2, + ) result["obasis"] = psi4_obasis return @@ -694,7 +701,10 @@ def _fix_molden_from_buggy_codes(result: dict, lit: LineIterator, norm_threshold if turbom_obasis is not None and _is_normalized_properly( turbom_obasis, atcoords, coeffsa, coeffsb, norm_threshold ): - lit.warn("Corrected for Turbomole errors in Molden/MKL file.") + warn( + LoadWarning("Corrected for Turbomole errors in Molden/MKL file.", lit.filename), + stacklevel=2, + ) result["obasis"] = turbom_obasis return @@ -704,7 +714,10 @@ def _fix_molden_from_buggy_codes(result: dict, lit: LineIterator, norm_threshold coeffsa_cfour = coeffsa / cfour_coeff_correction[:, np.newaxis] coeffsb_cfour = None if coeffsb is None else coeffsb / cfour_coeff_correction[:, np.newaxis] if _is_normalized_properly(obasis, atcoords, coeffsa_cfour, coeffsb_cfour, norm_threshold): - lit.warn("Corrected for CFOUR 2.1 errors in Molden/MKL file.") + warn( + LoadWarning("Corrected for CFOUR 2.1 errors in Molden/MKL file.", lit.filename), + stacklevel=2, + ) result["obasis"] = obasis if result["mo"].kind == "restricted": result["mo"].coeffs[:] = coeffsa_cfour @@ -716,7 +729,12 @@ def _fix_molden_from_buggy_codes(result: dict, lit: LineIterator, norm_threshold # --- Renormalized contractions normed_obasis = _fix_obasis_normalize_contractions(obasis) if _is_normalized_properly(normed_obasis, atcoords, coeffsa, coeffsb, norm_threshold): - lit.warn("Corrected for unnormalized contractions in Molden/MKL file.") + warn( + LoadWarning( + "Corrected for unnormalized contractions in Molden/MKL file.", lit.filename + ), + stacklevel=2, + ) result["obasis"] = normed_obasis return @@ -728,7 +746,10 @@ def _fix_molden_from_buggy_codes(result: dict, lit: LineIterator, norm_threshold if _is_normalized_properly( normed_obasis, atcoords, coeffsa_psi4, coeffsb_psi4, norm_threshold ): - lit.warn("Corrected for PSI4 <= 1.3.2 errors in Molden/MKL file.") + warn( + LoadWarning("Corrected for PSI4 <= 1.3.2 errors in Molden/MKL file.", lit.filename), + stacklevel=2, + ) result["obasis"] = normed_obasis if result["mo"].kind == "restricted": result["mo"].coeffs[:] = coeffsa_psi4 @@ -747,22 +768,24 @@ def _fix_molden_from_buggy_codes(result: dict, lit: LineIterator, norm_threshold ) -def prepare_dump(data: IOData): +def prepare_dump(filename: str, data: IOData): """Check the compatibility of the IOData object with the Molden format. Parameters ---------- + filename + The file to be written to, only used for error messages. data The IOData instance to be checked. """ if data.mo is None: - raise PrepareDumpError("The Molden format requires molecular orbitals.") + raise PrepareDumpError("The Molden format requires molecular orbitals.", filename) if data.obasis is None: - raise PrepareDumpError("The Molden format requires an orbital basis set.") + raise PrepareDumpError("The Molden format requires an orbital basis set.", filename) if data.mo.occs_aminusb is not None: - raise PrepareDumpError("Cannot write Molden file when mo.occs_aminusb is set.") + raise PrepareDumpError("Cannot write Molden file when mo.occs_aminusb is set.", filename) if data.mo.kind == "generalized": - raise PrepareDumpError("Cannot write Molden file with generalized orbitals.") + raise PrepareDumpError("Cannot write Molden file with generalized orbitals.", filename) @document_dump_one("Molden", ["atcoords", "atnums", "mo", "obasis"], ["atcorenums", "title"]) @@ -800,7 +823,8 @@ def dump_one(f: TextIO, data: IOData): if kind != angmom_kinds[angmom]: raise DumpError( "Molden format does not support mixed pure+Cartesian functions for one " - "angular momentum." + "angular momentum.", + f, ) else: angmom_kinds[angmom] = kind diff --git a/iodata/formats/molekel.py b/iodata/formats/molekel.py index b1552e43..c6c3e353 100644 --- a/iodata/formats/molekel.py +++ b/iodata/formats/molekel.py @@ -261,22 +261,24 @@ def load_one(lit: LineIterator, norm_threshold: float = 1e-4) -> dict: return result -def prepare_dump(data: IOData): +def prepare_dump(filename: str, data: IOData): """Check the compatibility of the IOData object with the Molekel format. Parameters ---------- + filename + The file to be written to, only used for error messages. data The IOData instance to be checked. """ if data.mo is None: - raise PrepareDumpError("The Molekel format requires molecular orbitals.") + raise PrepareDumpError("The Molekel format requires molecular orbitals.", filename) if data.obasis is None: - raise PrepareDumpError("The Molekel format requires an orbital basis set.") + raise PrepareDumpError("The Molekel format requires an orbital basis set.", filename) if data.mo.occs_aminusb is not None: - raise PrepareDumpError("Cannot write Molekel file when mo.occs_aminusb is set.") + raise PrepareDumpError("Cannot write Molekel file when mo.occs_aminusb is set.", filename) if data.mo.kind == "generalized": - raise PrepareDumpError("Cannot write Molekel file with generalized orbitals.") + raise PrepareDumpError("Cannot write Molekel file with generalized orbitals.", filename) @document_dump_one("Molekel", ["atcoords", "atnums", "mo", "obasis"], ["atcharges"]) @@ -373,7 +375,7 @@ def _dump_helper_coeffs(f, data, spin=None): ener = data.mo.energiesb irreps = data.mo.irreps[norb:] if data.mo.irreps is not None else ["a1g"] * norb else: - raise DumpError("A spin must be specified") + raise DumpError("A spin must be specified", f) for j in range(0, norb, 5): en = " ".join([f" {e: ,.12f}" for e in ener[j : j + 5]]) @@ -399,7 +401,7 @@ def _dump_helper_occ(f, data, spin=None): norb = data.mo.norba occ = data.mo.occs else: - raise DumpError("A spin must be specified") + raise DumpError("A spin must be specified", f) for j in range(0, norb, 5): occs = " ".join([f" {o: ,.7f}" for o in occ[j : j + 5]]) diff --git a/iodata/formats/pdb.py b/iodata/formats/pdb.py index 4dd3954b..c998d2f5 100644 --- a/iodata/formats/pdb.py +++ b/iodata/formats/pdb.py @@ -25,6 +25,7 @@ from collections.abc import Iterator from typing import TextIO +from warnings import warn import numpy as np @@ -36,7 +37,7 @@ ) from ..iodata import IOData from ..periodic import bond2num, num2sym, sym2num -from ..utils import LineIterator, LoadError, angstrom +from ..utils import LineIterator, LoadError, LoadWarning, angstrom __all__ = [] @@ -101,10 +102,18 @@ def _parse_pdb_atom_line(line, lit): # If not present, guess it from position 13:16 (atom name) atname = line[12:16].strip() atnum = sym2num.get(atname, sym2num.get(atname[:2].title(), sym2num.get(atname[0], None))) - lit.warn("Using the atom name in the PDB file to guess the chemical element.") + warn( + LoadWarning("Using the atom name in the PDB file to guess the chemical element.", lit), + stacklevel=2, + ) if atnum is None: atnum = 0 - lit.warn(f"Failed to determine the atomic number. atname='{atname}' symbol='{symbol}'") + warn( + LoadWarning( + f"Failed to determine the atomic number. atname='{atname}' symbol='{symbol}'", lit + ), + stacklevel=2, + ) # atom name, residue name, chain id, & residue sequence number atname = line[12:16].strip() @@ -190,7 +199,9 @@ def load_one(lit: LineIterator) -> dict: if not molecule_found: raise LoadError("Molecule could not be read.", lit) if not end_reached: - lit.warn("The END is not found, but the parsed data is returned.") + warn( + LoadWarning("The END is not found, but the parsed data is returned.", lit), stacklevel=2 + ) # Data related to force fields atffparams = { diff --git a/iodata/formats/wfn.py b/iodata/formats/wfn.py index 9b2b5188..05fce829 100644 --- a/iodata/formats/wfn.py +++ b/iodata/formats/wfn.py @@ -496,25 +496,29 @@ def _dump_helper_section(f: TextIO, data: NDArray, fmt: str, skip: int, step: in DEFAULT_WFN_TTL = "WFN auto-generated by IOData" -def prepare_dump(data: IOData): +def prepare_dump(filename: str, data: IOData): """Check the compatibility of the IOData object with the WFN format. Parameters ---------- + filename + The file to be written to, only used for error messages. data The IOData instance to be checked. """ if data.mo is None: - raise PrepareDumpError("The WFN format requires molecular orbitals") + raise PrepareDumpError("The WFN format requires molecular orbitals", filename) if data.obasis is None: - raise PrepareDumpError("The WFN format requires an orbital basis set") + raise PrepareDumpError("The WFN format requires an orbital basis set", filename) if data.mo.kind == "generalized": - raise PrepareDumpError("Cannot write WFN file with generalized orbitals.") + raise PrepareDumpError("Cannot write WFN file with generalized orbitals.", filename) if data.mo.occs_aminusb is not None: - raise PrepareDumpError("Cannot write WFN file when mo.occs_aminusb is set.") + raise PrepareDumpError("Cannot write WFN file when mo.occs_aminusb is set.", filename) for shell in data.obasis.shells: if any(kind != "c" for kind in shell.kinds): - raise PrepareDumpError("The WFN format only supports Cartesian MolecularBasis.") + raise PrepareDumpError( + "The WFN format only supports Cartesian MolecularBasis.", filename + ) @document_dump_one( diff --git a/iodata/formats/wfx.py b/iodata/formats/wfx.py index 4fbbf67a..2a15501f 100644 --- a/iodata/formats/wfx.py +++ b/iodata/formats/wfx.py @@ -21,9 +21,9 @@ See http://aim.tkgristmill.com/wfxformat.html """ -import warnings from collections.abc import Iterator from typing import Optional, TextIO +from warnings import warn import numpy as np @@ -32,7 +32,7 @@ from ..iodata import IOData from ..orbitals import MolecularOrbitals from ..periodic import num2sym -from ..utils import LineIterator, LoadError, PrepareDumpError +from ..utils import LineIterator, LoadError, LoadWarning, PrepareDumpError from .wfn import CONVENTIONS, build_obasis, get_mocoeff_scales __all__ = [] @@ -142,7 +142,7 @@ def load_data_wfx(lit: LineIterator) -> dict: elif key in lbs_other: result[lbs_other[key]] = value else: - warnings.warn(f"Not recognized section label, skip {key}", stacklevel=2) + warn(LoadWarning(f"Not recognized section label, skip {key}", lit), stacklevel=2) # reshape some arrays result["atcoords"] = result["atcoords"].reshape(-1, 3) @@ -329,25 +329,29 @@ def load_one(lit: LineIterator) -> dict: } -def prepare_dump(data: IOData): +def prepare_dump(filename: str, data: IOData): """Check the compatibility of the IOData object with the WFX format. Parameters ---------- + filename + The file to be written to, only used for error messages. data The IOData instance to be checked. """ if data.mo is None: - raise PrepareDumpError("The WFX format requires molecular orbitals.") + raise PrepareDumpError("The WFX format requires molecular orbitals.", filename) if data.obasis is None: - raise PrepareDumpError("The WFX format requires an orbital basis set.") + raise PrepareDumpError("The WFX format requires an orbital basis set.", filename) if data.mo.kind == "generalized": - raise PrepareDumpError("Cannot write WFX file with generalized orbitals.") + raise PrepareDumpError("Cannot write WFX file with generalized orbitals.", filename) if data.mo.occs_aminusb is not None: - raise PrepareDumpError("Cannot write WFX file when mo.occs_aminusb is set.") + raise PrepareDumpError("Cannot write WFX file when mo.occs_aminusb is set.", filename) for shell in data.obasis.shells: if any(kind != "c" for kind in shell.kinds): - raise PrepareDumpError("The WFX format only supports Cartesian MolecularBasis.") + raise PrepareDumpError( + "The WFX format only supports Cartesian MolecularBasis.", filename + ) @document_dump_one( diff --git a/iodata/test/data/water.mol2 b/iodata/test/data/water.mol2 new file mode 100644 index 00000000..36a4b631 --- /dev/null +++ b/iodata/test/data/water.mol2 @@ -0,0 +1,15 @@ +@MOLECULE +water + 3 2 0 0 +SMALL +NO_CHARGES +*** +Some comments. blabla + +@ATOM + 1 H 0.7838 -0.4922 -0.0000 H 1 XXX 0.0000 + 2 O -0.0000 0.0620 -0.0000 O 1 XXX 0.0000 + 3 H -0.7838 -0.4922 -0.0000 H 1 XXX 0.0000 +@BOND + 1 1 2 123 + 2 1 3 456 diff --git a/iodata/test/test_iodata.py b/iodata/test/test_iodata.py index 6ac36036..4d6206e6 100644 --- a/iodata/test/test_iodata.py +++ b/iodata/test/test_iodata.py @@ -27,6 +27,7 @@ from ..api import IOData, load_one from ..overlap import compute_overlap +from ..utils import FileFormatError from .common import compute_1rdm @@ -57,7 +58,8 @@ def test_typecheck_raises(): def test_unknown_format(): - pytest.raises(ValueError, load_one, "foo.unknown_file_extension") + with pytest.raises(FileFormatError): + load_one("foo.unknown_file_extension") def test_dm_water_sto3g_hf(): diff --git a/iodata/test/test_mol2.py b/iodata/test/test_mol2.py index 879aa1e4..f9247ed9 100644 --- a/iodata/test/test_mol2.py +++ b/iodata/test/test_mol2.py @@ -26,7 +26,7 @@ from ..api import dump_many, dump_one, load_many, load_one from ..periodic import bond2num -from ..utils import LoadError, angstrom +from ..utils import LoadError, LoadWarning, angstrom from .common import truncated_file @@ -139,3 +139,12 @@ def test_load_dump_wrong_bond_num(tmpdir): dump_one(mol, fn_tmp) mol2 = load_one(fn_tmp) assert mol2.bonds[0][2] == bond2num["un"] + + +def test_load_water_bonds_warning(): + with ( + as_file(files("iodata.test.data").joinpath("water.mol2")) as fn_mol, + pytest.warns(LoadWarning), + ): + mol = load_one(fn_mol) + assert_equal(mol.bonds, [[0, 1, bond2num["un"]], [0, 2, bond2num["un"]]]) diff --git a/iodata/utils.py b/iodata/utils.py index c9ed1ca3..dfbc9b9d 100644 --- a/iodata/utils.py +++ b/iodata/utils.py @@ -18,7 +18,7 @@ # -- """Utility functions module.""" -import warnings +from pathlib import Path from typing import Optional, TextIO, Union import attrs @@ -30,13 +30,14 @@ from .attrutils import validate_shape __all__ = ( + "LineIterator", + "FileFormatError", "LoadError", - "LoadWarning", "DumpError", - "DumpWarning", "PrepareDumpError", "WriteInputError", - "LineIterator", + "LoadWarning", + "DumpWarning", "Cube", "set_four_index_element", "volume", @@ -105,80 +106,129 @@ def __next__(self): self.lineno += 1 return self.stack.pop() if self.stack else next(self.fh) - def warn(self, msg: str): - """Raise a warning while reading a file. - - Parameters - ---------- - msg - Message to raise alongside filename and line number. - - """ - warnings.warn(f"{self.filename}:{self.lineno} {msg}", LoadWarning, stacklevel=2) - def back(self, line): """Go back one line in the file and decrease the lineno attribute by one.""" self.stack.append(line) self.lineno -= 1 -class FileFormatError(ValueError): - """Raise when a file or input format cannot be identified.""" +def _interpret_file_lineno( + file: Optional[Union[str, Path, LineIterator, TextIO]] = None, lineno: Optional[int] = None +) -> tuple[Optional[str], Optional[int]]: + """Interpret the file and lineno arguments given to Error and Warning constructors. + Parameters + ---------- + file + Object to deduce the filename (and optionally line number) from. + lineno + Line number, if known and not (correctly) included in the file object. + + Returns + ------- + filename + The filename associated with the file object. + lineno + The line number. + """ + if isinstance(file, str): + return file, lineno + if isinstance(file, Path): + return str(file), lineno + if isinstance(file, LineIterator): + if lineno is None: + lineno = file.lineno + return file.filename, lineno + if isinstance(file, TextIO): + return file.name, lineno + if file is None: + if lineno is not None: + raise TypeError("A line number without a file is not supported.") + return None, None + raise TypeError(f"Types of file and lineno are not supported: {file}, {lineno}") + + +def _format_file_message(message: str, filename: Optional[str], lineno: Optional[int]) -> str: + """Format the message of an exception. + + Parameters + ---------- + message + The actual error or warning message, without filename or line number info. + filename + The filename to which the error or warning is related. + lineno + The line number associated with the error or warning. + + Returns + ------- + full_message + The error message formated with filename and line number info. + """ + if filename is None: + return message + if lineno is None: + return f"{message} ({filename})" + return f"{message} ({filename}:{lineno})" -class LoadError(Exception): - """Raised when an error is encountered while loading from a file.""" + +class BaseFileError(Exception): + """Base class for all errors related to loading or dumping files.""" def __init__( self, message, - file: Optional[Union[str, LineIterator, TextIO]] = None, + file: Optional[Union[str, Path, LineIterator, TextIO]] = None, lineno: Optional[int] = None, ): super().__init__(message) - # Get the extra info - self.filename = None - self.lineno = None - if isinstance(file, str): - self.filename = file - elif isinstance(file, LineIterator): - self.filename = file.filename - if lineno is None: - self.lineno = file.lineno - elif isinstance(file, TextIO): - self.filename = file.name + self.filename, self.lineno = _interpret_file_lineno(file, lineno) def __str__(self): - if self.filename is None: - location = "" - elif self.lineno is None: - location = f" ({self.filename})" - else: - location = f" ({self.filename}:{self.lineno})" - message = super().__str__() - return f"{message}{location}" + return _format_file_message(super().__str__(), self.filename, self.lineno) -class LoadWarning(Warning): - """Raised when incorrect content is encountered and fixed when loading from a file.""" +class FileFormatError(BaseFileError): + """Raise when a file or input format cannot be identified.""" -class DumpError(ValueError): - """Raised when an error is encountered while dumping to a file.""" +class LoadError(BaseFileError): + """Raised when an error is encountered while loading from a file.""" -class DumpWarning(Warning): - """Raised when an IOData object is made compatible with a format when dumping to a file.""" +class DumpError(BaseFileError): + """Raised when an error is encountered while dumping to a file.""" -class PrepareDumpError(ValueError): +class PrepareDumpError(BaseFileError): """Raised when an IOData object is incompatible with a format before dumping to a file.""" -class WriteInputError(ValueError): +class WriteInputError(BaseFileError): """Raised when an error is encountered while writing an input file.""" +class BaseFileWarning(Warning): + """Base class for all warnings related to loading or dumping files.""" + + def __init__( + self, + message, + file: Optional[Union[str, Path, LineIterator, TextIO]] = None, + lineno: Optional[int] = None, + ): + self.filename, self.lineno = _interpret_file_lineno(file, lineno) + super().__init__(_format_file_message(message, self.filename, self.lineno)) + + +class LoadWarning(BaseFileWarning): + """Raised when incorrect content is encountered and fixed when loading from a file.""" + + +class DumpWarning(BaseFileWarning): + """Raised when an IOData object is made compatible with a format when dumping to a file.""" + + @attrs.define class Cube: """The volumetric data from a cube (or similar) file.""" From a354980a3185bd224487650642a210a1f06780bb Mon Sep 17 00:00:00 2001 From: Toon Verstraelen Date: Sun, 23 Jun 2024 11:35:57 +0200 Subject: [PATCH 2/4] Deepsource fixes --- iodata/api.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/iodata/api.py b/iodata/api.py index 2e1eeed6..2669df2f 100644 --- a/iodata/api.py +++ b/iodata/api.py @@ -137,6 +137,8 @@ def _reissue_warnings(func): """ def inner(*args, **kwargs): + """Wrapper for func that reissues warnings.""" + warning_list = [] try: with warnings.catch_warnings(record=True) as warning_list: result = func(*args, **kwargs) From 8a8c50723c64d055b4734742803fa7493291fde3 Mon Sep 17 00:00:00 2001 From: Toon Verstraelen Date: Sun, 23 Jun 2024 11:45:45 +0200 Subject: [PATCH 3/4] AI-inspired improvements --- CONTRIBUTING.rst | 4 ++-- iodata/test/test_iodata.py | 2 +- iodata/test/test_mol2.py | 2 +- iodata/utils.py | 4 ++-- 4 files changed, 6 insertions(+), 6 deletions(-) diff --git a/CONTRIBUTING.rst b/CONTRIBUTING.rst index ffeefd74..94c1cf82 100644 --- a/CONTRIBUTING.rst +++ b/CONTRIBUTING.rst @@ -163,7 +163,7 @@ When you encounter a file format error while reading the file, raise a ``LoadErr def load_one(lit: LineIterator) -> dict: ... if something_wrong: - raise LoadError("Describe the problem in a sentence.", lit) + raise LoadError("Describe the problem that made it impossible to load the file.", lit) The error that appears in the terminal will automatically include the file name and line number. If your code has already read the full file and encounters an error when processing the data, @@ -183,7 +183,7 @@ In this case, you should warn the user that the file contains (fixable) errors: def load_one(lit: LineIterator) -> dict: ... if something_fixed: - warn(LoadWarning("Describe the problem in a sentence.", lit), stacklevel=2) + warn(LoadWarning("Describe the issue that was fixed while loading.", lit), stacklevel=2) Always use ``stacklevel=2`` when raising warnings. diff --git a/iodata/test/test_iodata.py b/iodata/test/test_iodata.py index 4d6206e6..ec637b18 100644 --- a/iodata/test/test_iodata.py +++ b/iodata/test/test_iodata.py @@ -58,7 +58,7 @@ def test_typecheck_raises(): def test_unknown_format(): - with pytest.raises(FileFormatError): + with pytest.raises(FileFormatError, match="Cannot find file format with feature"): load_one("foo.unknown_file_extension") diff --git a/iodata/test/test_mol2.py b/iodata/test/test_mol2.py index f9247ed9..0af7f16d 100644 --- a/iodata/test/test_mol2.py +++ b/iodata/test/test_mol2.py @@ -144,7 +144,7 @@ def test_load_dump_wrong_bond_num(tmpdir): def test_load_water_bonds_warning(): with ( as_file(files("iodata.test.data").joinpath("water.mol2")) as fn_mol, - pytest.warns(LoadWarning), + pytest.warns(LoadWarning, match="Cannot interpret bond type"), ): mol = load_one(fn_mol) assert_equal(mol.bonds, [[0, 1, bond2num["un"]], [0, 2, bond2num["un"]]]) diff --git a/iodata/utils.py b/iodata/utils.py index dfbc9b9d..def01638 100644 --- a/iodata/utils.py +++ b/iodata/utils.py @@ -217,8 +217,8 @@ def __init__( file: Optional[Union[str, Path, LineIterator, TextIO]] = None, lineno: Optional[int] = None, ): - self.filename, self.lineno = _interpret_file_lineno(file, lineno) - super().__init__(_format_file_message(message, self.filename, self.lineno)) + filename, lineno = _interpret_file_lineno(file, lineno) + super().__init__(_format_file_message(message, filename, lineno)) class LoadWarning(BaseFileWarning): From fd0f2ea78806b5f3517649e12e19ea07c5425660 Mon Sep 17 00:00:00 2001 From: Toon Verstraelen Date: Mon, 24 Jun 2024 07:55:33 +0200 Subject: [PATCH 4/4] Fix minor whitespace issues --- iodata/formats/json.py | 24 ++++++++++++------------ 1 file changed, 12 insertions(+), 12 deletions(-) diff --git a/iodata/formats/json.py b/iodata/formats/json.py index 851873a8..019ecd18 100644 --- a/iodata/formats/json.py +++ b/iodata/formats/json.py @@ -641,7 +641,7 @@ def _parse_json(json_in: dict, lit: LineIterator) -> dict: # Attempt to determine schema type, since some QCElemental files omit this warn( LoadWarning( - "QCSchema files should have a `schema_name` key." + "QCSchema files should have a `schema_name` key. " "Attempting to determine schema type...", lit.filename, ), @@ -665,7 +665,7 @@ def _parse_json(json_in: dict, lit: LineIterator) -> dict: if "schema_version" not in result: warn( LoadWarning( - "QCSchema files should have a `schema_version` key." + "QCSchema files should have a `schema_version` key. " "Attempting to load without version number.", lit.filename, ), @@ -782,8 +782,8 @@ def _parse_topology_keys(mol: dict, lit: LineIterator) -> dict: if "molecular_charge" not in mol: warn( LoadWarning( - "Missing 'molecular_charge' key." - "Some QCSchema writers omit this key for default value 0.0," + "Missing 'molecular_charge' key. " + "Some QCSchema writers omit this key for default value 0.0. " "Ensure this value is correct.", lit.filename, ), @@ -799,8 +799,8 @@ def _parse_topology_keys(mol: dict, lit: LineIterator) -> dict: if "molecular_multiplicity" not in mol: warn( LoadWarning( - "Missing 'molecular_multiplicity' key." - "Some QCSchema writers omit this key for default value 1," + "Missing 'molecular_multiplicity' key. " + "Some QCSchema writers omit this key for default value 1. " "Ensure this value is correct.", lit.filename, ), @@ -940,7 +940,7 @@ def _version_check(result: dict, max_version: float, schema_name: str, lit: Line if float(version) < 0 or float(version) > max_version: warn( LoadWarning( - f"Unknown {schema_name} version {version}, " "loading may produce invalid results", + f"Unknown {schema_name} version {version}, loading may produce invalid results", lit.filename, ), stacklevel=2, @@ -1224,7 +1224,7 @@ def _parse_model(model: dict, lit: LineIterator) -> dict: if model["basis"] == "": warn( LoadWarning( - "QCSchema `basis` could not be read and will be omitted." + "QCSchema `basis` could not be read and will be omitted. " "Unless model is for a basis-free method, check input file.", lit.filename, ), @@ -1486,7 +1486,7 @@ def dump_one(f: TextIO, data: IOData): return_dict = _dump_qcschema_output(f, data) else: raise DumpError( - "'schema_name' must be one of 'qcschema_molecule', 'qcschema_basis'" + "'schema_name' must be one of 'qcschema_molecule', 'qcschema_basis' " "'qcschema_input' or 'qcschema_output'.", f, ) @@ -1521,7 +1521,7 @@ def _dump_qcschema_molecule(f: TextIO, data: IOData) -> dict: if data.charge is None or data.spinpol is None: warn( DumpWarning( - "`charge` and `spinpol` should be given to write qcschema_molecule file:" + "`charge` and `spinpol` should be given to write qcschema_molecule file: " "QCSchema defaults to charge = 0 and multiplicity = 1 if no values given.", f, ), @@ -1726,8 +1726,8 @@ def _dump_qcschema_output(f: TextIO, data: IOData) -> dict: if data.obasis_name is None and "basis" not in data.extra["input"]["model"]: warn( DumpWarning( - "No basis name given. QCSchema assumes this signifies a basis-free method; to" - "avoid this warning, specify `obasis_name` as an empty string.", + "No basis name given. QCSchema assumes this signifies a basis-free method; " + "to avoid this warning, specify `obasis_name` as an empty string.", f, ), stacklevel=2,