Skip to content

Commit

Permalink
Merge pull request #349 from tovrstra/informative-warnings-errors
Browse files Browse the repository at this point in the history
Further improve error and warning infrastructure
  • Loading branch information
tovrstra authored Jun 24, 2024
2 parents b98bb35 + fd0f2ea commit e4efb2e
Show file tree
Hide file tree
Showing 14 changed files with 395 additions and 192 deletions.
19 changes: 18 additions & 1 deletion CONTRIBUTING.rst
Original file line number Diff line number Diff line change
Expand Up @@ -163,13 +163,30 @@ When you encounter a file format error while reading the file, raise a ``LoadErr
def load_one(lit: LineIterator) -> dict:
...
if something_wrong:
raise LoadError("Describe the problem in a sentence.", lit)
raise LoadError("Describe the problem that made it impossible to load the file.", lit)
The error that appears in the terminal will automatically include the file name and line number.
If your code has already read the full file and encounters an error when processing the data,
you can use ``raise LoadError("Describe problem in a sentence.", lit.filename)`` instead.
This way, no line number is included in the error message.

Sometimes, it is possible to correct errors while reading a file.
In this case, you should warn the user that the file contains (fixable) errors:

.. code-block:: python
from warnings import warn
from ..utils import LoadWarning
@document_load_one(...)
def load_one(lit: LineIterator) -> dict:
...
if something_fixed:
warn(LoadWarning("Describe the issue that was fixed while loading.", lit), stacklevel=2)
Always use ``stacklevel=2`` when raising warnings.


``dump_one`` functions: writing a single IOData object to a file
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
Expand Down
72 changes: 51 additions & 21 deletions iodata/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
"""Functions to be used by end users."""

import os
import warnings
from collections.abc import Iterable, Iterator
from fnmatch import fnmatch
from importlib import import_module
Expand Down Expand Up @@ -84,7 +85,7 @@ def _select_format_module(filename: str, attrname: str, fmt: Optional[str] = Non
return format_module
else:
return FORMAT_MODULES[fmt]
raise FileFormatError(f"Could not find file format with feature {attrname} for file {filename}")
raise FileFormatError(f"Cannot find file format with feature {attrname}", filename)


def _find_input_modules():
Expand All @@ -101,11 +102,13 @@ def _find_input_modules():
INPUT_MODULES = _find_input_modules()


def _select_input_module(fmt: str) -> ModuleType:
def _select_input_module(filename: str, fmt: str) -> ModuleType:
"""Find an input module.
Parameters
----------
filename
The file to be written to, only used for error messages.
fmt
The name of the input module to use.
Expand All @@ -121,11 +124,33 @@ def _select_input_module(fmt: str) -> ModuleType:
"""
if fmt in INPUT_MODULES:
if not hasattr(INPUT_MODULES[fmt], "write_input"):
raise FileFormatError(f"{fmt} input module does not have write_input.")
raise FileFormatError(f"{fmt} input module does not have write_input.", filename)
return INPUT_MODULES[fmt]
raise FileFormatError(f"Could not find input format {fmt}.")
raise FileFormatError(f"Cannot find input format {fmt}.", filename)


def _reissue_warnings(func):
"""Correct stacklevel of warnings raised in functions called deeper in IOData.
This function should be used as a decorator of end-user API functions.
Adapted from https://stackoverflow.com/a/71635963/494584
"""

def inner(*args, **kwargs):
"""Wrapper for func that reissues warnings."""
warning_list = []
try:
with warnings.catch_warnings(record=True) as warning_list:
result = func(*args, **kwargs)
finally:
for warning in warning_list:
warnings.warn(warning.message, warning.category, stacklevel=2)
return result

return inner


@_reissue_warnings
def load_one(filename: str, fmt: Optional[str] = None, **kwargs) -> IOData:
"""Load data from a file.
Expand All @@ -151,16 +176,16 @@ def load_one(filename: str, fmt: Optional[str] = None, **kwargs) -> IOData:
format_module = _select_format_module(filename, "load_one", fmt)
with LineIterator(filename) as lit:
try:
iodata = IOData(**format_module.load_one(lit, **kwargs))
return IOData(**format_module.load_one(lit, **kwargs))
except LoadError:
raise
except StopIteration as exc:
raise LoadError("File ended before all data was read.", lit) from exc
except Exception as exc:
raise LoadError("Uncaught exception while loading file.", lit) from exc
return iodata


@_reissue_warnings
def load_many(filename: str, fmt: Optional[str] = None, **kwargs) -> Iterator[IOData]:
"""Load multiple IOData instances from a file.
Expand Down Expand Up @@ -197,11 +222,13 @@ def load_many(filename: str, fmt: Optional[str] = None, **kwargs) -> Iterator[IO
raise LoadError("Uncaught exception while loading file.", lit) from exc


def _check_required(iodata: IOData, dump_func: Callable):
def _check_required(filename: str, iodata: IOData, dump_func: Callable):
"""Check that required attributes are not None before dumping to a file.
Parameters
----------
filename
The file to be dumped to, only used for error messages.
iodata
The data to be written.
dump_func
Expand All @@ -215,10 +242,11 @@ def _check_required(iodata: IOData, dump_func: Callable):
for attr_name in dump_func.required:
if getattr(iodata, attr_name) is None:
raise PrepareDumpError(
f"Required attribute {attr_name}, for format {dump_func.fmt}, is None."
f"Required attribute {attr_name}, for format {dump_func.fmt}, is None.", filename
)


@_reissue_warnings
def dump_one(iodata: IOData, filename: str, fmt: Optional[str] = None, **kwargs):
"""Write data to a file.
Expand Down Expand Up @@ -251,24 +279,25 @@ def dump_one(iodata: IOData, filename: str, fmt: Optional[str] = None, **kwargs)
"""
format_module = _select_format_module(filename, "dump_one", fmt)
try:
_check_required(iodata, format_module.dump_one)
_check_required(filename, iodata, format_module.dump_one)
if hasattr(format_module, "prepare_dump"):
format_module.prepare_dump(iodata)
format_module.prepare_dump(filename, iodata)
except PrepareDumpError:
raise
except Exception as exc:
raise PrepareDumpError(
f"{filename}: Uncaught exception while preparing for dumping to a file"
"Uncaught exception while preparing for dumping to a file.", filename
) from exc
with open(filename, "w") as f:
try:
format_module.dump_one(f, iodata, **kwargs)
except DumpError:
raise
except Exception as exc:
raise DumpError(f"{filename}: Uncaught exception while dumping to a file") from exc
raise DumpError("Uncaught exception while dumping to a file", filename) from exc


@_reissue_warnings
def dump_many(iodatas: Iterable[IOData], filename: str, fmt: Optional[str] = None, **kwargs):
"""Write multiple IOData instances to a file.
Expand Down Expand Up @@ -309,26 +338,26 @@ def dump_many(iodatas: Iterable[IOData], filename: str, fmt: Optional[str] = Non
try:
first = next(iter_iodatas)
except StopIteration as exc:
raise DumpError(f"{filename}: dump_many needs at least one iodata object.") from exc
raise DumpError("dump_many needs at least one iodata object.", filename) from exc
try:
_check_required(first, format_module.dump_many)
_check_required(filename, first, format_module.dump_many)
if hasattr(format_module, "prepare_dump"):
format_module.prepare_dump(first)
format_module.prepare_dump(filename, first)
except PrepareDumpError:
raise
except Exception as exc:
raise PrepareDumpError(
f"{filename}: Uncaught exception while preparing for dumping to a file"
"Uncaught exception while preparing for dumping to a file.", filename
) from exc

def checking_iterator():
"""Iterate over all iodata items, not checking the first."""
# The first one was already checked.
yield first
for other in iter_iodatas:
_check_required(other, format_module.dump_many)
_check_required(filename, other, format_module.dump_many)
if hasattr(format_module, "prepare_dump"):
format_module.prepare_dump(other)
format_module.prepare_dump(filename, other)
yield other

with open(filename, "w") as f:
Expand All @@ -337,9 +366,10 @@ def checking_iterator():
except (PrepareDumpError, DumpError):
raise
except Exception as exc:
raise DumpError(f"{filename}: Uncaught exception while dumping to a file") from exc
raise DumpError("Uncaught exception while dumping to a file.", filename) from exc


@_reissue_warnings
def write_input(
iodata: IOData,
filename: str,
Expand Down Expand Up @@ -370,11 +400,11 @@ def write_input(
Keyword arguments are passed on to the input-specific write_input function.
"""
input_module = _select_input_module(fmt)
input_module = _select_input_module(filename, fmt)
with open(filename, "w") as fh:
try:
input_module.write_input(fh, iodata, template, atom_line, **kwargs)
except Exception as exc:
raise WriteInputError(
f"{filename}: Uncaught exception while writing an input file"
"Uncaught exception while writing an input file.", filename
) from exc
14 changes: 9 additions & 5 deletions iodata/formats/fchk.py
Original file line number Diff line number Diff line change
Expand Up @@ -542,28 +542,32 @@ def _dump_real_arrays(name: str, val: NDArray[float], f: TextIO):
k = 0


def prepare_dump(data: IOData):
def prepare_dump(filename: str, data: IOData):
"""Check the compatibility of the IOData object with the FCHK format.
Parameters
----------
filename
The file to be written to, only used for error messages.
data
The IOData instance to be checked.
"""
if data.mo is not None:
if data.mo.kind == "generalized":
raise PrepareDumpError("Cannot write FCHK file with generalized orbitals.")
raise PrepareDumpError("Cannot write FCHK file with generalized orbitals.", filename)
na = int(np.round(np.sum(data.mo.occsa)))
if not ((data.mo.occsa[:na] == 1.0).all() and (data.mo.occsa[na:] == 0.0).all()):
raise PrepareDumpError(
"Cannot dump FCHK because it does not have fully occupied alpha orbitals "
"followed by fully virtual ones."
"followed by fully virtual ones.",
filename,
)
nb = int(np.round(np.sum(data.mo.occsb)))
if not ((data.mo.occsb[:nb] == 1.0).all() and (data.mo.occsb[nb:] == 0.0).all()):
raise PrepareDumpError(
"Cannot dump FCHK because it does not have fully occupied beta orbitals "
"followed by fully virtual ones."
"followed by fully virtual ones.",
filename,
)


Expand Down Expand Up @@ -643,7 +647,7 @@ def dump_one(f: TextIO, data: IOData):
elif shell.ncon == 2 and shell.angmoms == [0, 1]:
shell_types.append(-1)
else:
raise DumpError("Cannot identify type of shell!")
raise DumpError("Cannot identify type of shell!", f)

num_pure_d_shells = sum([1 for st in shell_types if st == 2])
num_pure_f_shells = sum([1 for st in shell_types if st == 3])
Expand Down
Loading

0 comments on commit e4efb2e

Please sign in to comment.