Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Further improve error and warning infrastructure #2

Open
wants to merge 3 commits into
base: base-sha/b5bccd9789437feecbc80546e1020c070063bfed
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
19 changes: 18 additions & 1 deletion CONTRIBUTING.rst
Original file line number Diff line number Diff line change
Expand Up @@ -163,13 +163,30 @@ When you encounter a file format error while reading the file, raise a ``LoadErr
def load_one(lit: LineIterator) -> dict:
...
if something_wrong:
raise LoadError("Describe the problem in a sentence.", lit)
raise LoadError("Describe the problem that made it impossible to load the file.", lit)

The error that appears in the terminal will automatically include the file name and line number.
If your code has already read the full file and encounters an error when processing the data,
you can use ``raise LoadError("Describe problem in a sentence.", lit.filename)`` instead.
This way, no line number is included in the error message.

Sometimes, it is possible to correct errors while reading a file.
In this case, you should warn the user that the file contains (fixable) errors:

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

suggestion (documentation): Consider using 'correctable' instead of 'fixable'.

The term 'correctable' might be more formal and precise than 'fixable'.

Suggested change
In this case, you should warn the user that the file contains (fixable) errors:
In this case, you should warn the user that the file contains (correctable) errors:

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is this comment correct?

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is this comment helpful?

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is the comment type correct?

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is the comment area correct?

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

What type of LLM test could this comment become?

  • 👍 - this comment is really good/important and we should always make it
  • 👎 - this comment is really bad and we should never make it
  • no reaction - don't turn this comment into an LLM test


.. code-block:: python

from warnings import warn

from ..utils import LoadWarning

@document_load_one(...)
def load_one(lit: LineIterator) -> dict:
...
if something_fixed:
warn(LoadWarning("Describe the issue that was fixed while loading.", lit), stacklevel=2)

Always use ``stacklevel=2`` when raising warnings.


``dump_one`` functions: writing a single IOData object to a file
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
Expand Down
72 changes: 51 additions & 21 deletions iodata/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
"""Functions to be used by end users."""

import os
import warnings
from collections.abc import Iterable, Iterator
from fnmatch import fnmatch
from importlib import import_module
Expand Down Expand Up @@ -84,7 +85,7 @@ def _select_format_module(filename: str, attrname: str, fmt: Optional[str] = Non
return format_module
else:
return FORMAT_MODULES[fmt]
raise FileFormatError(f"Could not find file format with feature {attrname} for file {filename}")
raise FileFormatError(f"Cannot find file format with feature {attrname}", filename)


def _find_input_modules():
Expand All @@ -101,11 +102,13 @@ def _find_input_modules():
INPUT_MODULES = _find_input_modules()


def _select_input_module(fmt: str) -> ModuleType:
def _select_input_module(filename: str, fmt: str) -> ModuleType:

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

issue (complexity): Consider simplifying the code by reducing the number of parameters and centralizing error handling.

The new code introduces several complexities that could be simplified. Here are some key points:

  1. Increased Function Parameters: Several functions now take an additional filename parameter, which is only used for error messages. This increases the cognitive load on the developer, as they need to remember to pass this parameter and understand its purpose.

  2. Error Handling Changes: The error messages have been modified to include the filename, which is a good practice for debugging, but the way it has been implemented adds verbosity and complexity to the code. The original error messages were simpler and easier to read.

  3. Decorator Usage: The introduction of the _reissue_warnings decorator adds another layer of abstraction. While decorators can be useful, they can also make the code harder to follow, especially for developers who are not familiar with how the decorator works.

  4. Redundant Code: The changes introduce redundancy in error handling. For example, the filename is now passed around and included in multiple error messages, which could have been handled in a more centralized manner.

  5. Function Signature Changes: The changes to function signatures (e.g., _select_input_module and _check_required) make the code less intuitive. The original signatures were simpler and more straightforward.

Suggested Simplifications:

  1. Remove Redundant Filename Parameter: Avoid passing the filename parameter around unnecessarily. Use it directly where needed.

  2. Simplify Error Messages: Keep error messages simple and clear, while still including the filename for context.

  3. Maintain Decorator Usage: Retain the _reissue_warnings decorator but use it in a way that doesn't add unnecessary complexity.

  4. Simplify Function Signatures: Keep function signatures simple and intuitive.

Here is a simplified version of the code that maintains the original simplicity while incorporating the necessary changes:

import os
import warnings
from collections.abc import Iterable, Iterator
from fnmatch import fnmatch
from importlib import import_module
from pkgutil import iter_modules
from types import ModuleType
from typing import Callable, Optional

from .iodata import IOData
from .utils import DumpError, FileFormatError, LoadError, PrepareDumpError, WriteInputError

FORMAT_MODULES = {}  # Assuming this is defined somewhere in the original code

def _select_format_module(filename: str, attrname: str, fmt: Optional[str] = None) -> ModuleType:
    """Select the appropriate format module based on the filename or specified format."""
    basename = os.path.basename(filename)
    if fmt is None:
        for format_module in FORMAT_MODULES.values():
            if any(fnmatch(basename, pattern) for pattern in format_module.PATTERNS) and hasattr(format_module, attrname):
                return format_module
    else:
        return FORMAT_MODULES[fmt]
    raise FileFormatError(f"Could not find file format with feature {attrname} for file {filename}")

def _find_input_modules():
    """Return all input modules found with importlib."""
    result = {}
    for module_info in iter_modules(import_module("iodata.inputs").__path__):
        if not module_info.ispkg:
            input_module = import_module("iodata.inputs." + module_info.name)
            if hasattr(input_module, "write_input"):
                result[module_info.name] = input_module
    return result

INPUT_MODULES = _find_input_modules()

def _select_input_module(fmt: str) -> ModuleType:
    """Find an input module."""
    if fmt in INPUT_MODULES:
        if not hasattr(INPUT_MODULES[fmt], "write_input"):
            raise FileFormatError(f"{fmt} input module does not have write_input.")
        return INPUT_MODULES[fmt]
    raise FileFormatError(f"Could not find input format {fmt}.")

def _reissue_warnings(func):
    """Correct stacklevel of warnings raised in functions called deeper in IOData."""
    def inner(*args, **kwargs):
        """Wrapper for func that reissues warnings."""
        warning_list = []
        try:
            with warnings.catch_warnings(record=True) as warning_list:
                result = func(*args, **kwargs)
        finally:
            for warning in warning_list:
                warnings.warn(warning.message, warning.category, stacklevel=2)
        return result
    return inner

@_reissue_warnings
def load_one(filename: str, fmt: Optional[str] = None, **kwargs) -> IOData:
    """Load data from a file."""
    format_module = _select_format_module(filename, "load_one", fmt)
    with LineIterator(filename) as lit:
        try:
            return IOData(**format_module.load_one(lit, **kwargs))
        except LoadError:
            raise
        except StopIteration as exc:
            raise LoadError("File ended before all data was read.", lit) from exc
        except Exception as exc:
            raise LoadError("Uncaught exception while loading file.", lit) from exc

@_reissue_warnings
def load_many(filename: str, fmt: Optional[str] = None, **kwargs) -> Iterator[IOData]:
    """Load multiple IOData instances from a file."""
    format_module = _select_format_module(filename, "load_many", fmt)
    with LineIterator(filename) as lit:
        try:
            for data in format_module.load_many(lit, **kwargs):
                yield IOData(**data)
        except StopIteration:
            return
        except LoadError:
            raise
        except Exception as exc:
            raise LoadError("Uncaught exception while loading file.", lit) from exc

def _check_required(iodata: IOData, dump_func: Callable):
    """Check that required attributes are not None before dumping to a file."""
    for attr_name in dump_func.required:
        if getattr(iodata, attr_name) is None:
            raise PrepareDumpError(f"Required attribute {attr_name}, for format {dump_func.fmt}, is None.")

@_reissue_warnings
def dump_one(iodata: IOData, filename: str, fmt: Optional[str] = None, **kwargs):
    """Write data to a file."""
    format_module = _select_format_module(filename, "dump_one", fmt)
    try:
        _check_required(iodata, format_module.dump_one)
        if hasattr(format_module, "prepare_dump"):
            format_module.prepare_dump(iodata)
    except PrepareDumpError:
        raise
    except Exception as exc:
        raise PrepareDumpError(f"{filename}: Uncaught exception while preparing for dumping to a file") from exc
    with open(filename, "w") as f:
        try:
            format_module.dump_one(f, iodata, **kwargs)
        except DumpError:
            raise
        except Exception as exc:
            raise DumpError(f"{filename}: Uncaught exception while dumping to a file") from exc

@_reissue_warnings
def dump_many(iodatas: Iterable[IOData], filename: str, fmt: Optional[str] = None, **kwargs):
    """Write multiple IOData instances to a file."""
    format_module = _select_format_module(filename, "dump_many", fmt)
    iter_iodatas = iter(iodatas)
    try:
        first = next(iter_iodatas)
    except StopIteration as exc:
        raise DumpError(f"{filename}: dump_many needs at least one iodata object.") from exc
    try:
        _check_required(first, format_module.dump_many)
        if hasattr(format_module, "prepare_dump"):
            format_module.prepare_dump(first)
    except PrepareDumpError:
        raise
    except Exception as exc:
        raise PrepareDumpError(f"{filename}: Uncaught exception while preparing for dumping to a file") from exc

    def checking_iterator():
        """Iterate over all iodata items, not checking the first."""
        yield first
        for other in iter_iodatas:
            _check_required(other, format_module.dump_many)
            if hasattr(format_module, "prepare_dump"):
                format_module.prepare_dump(other)
            yield other

    with open(filename, "w") as f:
        try:
            format_module.dump_many(f, checking_iterator(), **kwargs)
        except (PrepareDumpError, DumpError):
            raise
        except Exception as exc:
            raise DumpError(f"{filename}: Uncaught exception while dumping to a file") from exc

@_reissue_warnings
def write_input(iodata: IOData, filename: str, fmt: str, template: Optional[str] = None, atom_line: Optional[Callable] = None, **kwargs):
    """Write input file using an instance of IOData for the specified software format."""
    input_module = _select_input_module(fmt)
    with open(filename, "w") as fh:
        try:
            input_module.write_input(fh, iodata, template, atom_line, **kwargs)
        except Exception as exc:
            raise WriteInputError(f"{filename}: Uncaught exception while writing an input file") from exc

This approach maintains the original simplicity while incorporating the necessary changes, making the code easier to maintain and understand.

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is this comment correct?

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is this comment helpful?

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is the comment type correct?

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is the comment area correct?

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

What type of LLM test could this comment become?

  • 👍 - this comment is really good/important and we should always make it
  • 👎 - this comment is really bad and we should never make it
  • no reaction - don't turn this comment into an LLM test

"""Find an input module.

Parameters
----------
filename
The file to be written to, only used for error messages.
fmt
The name of the input module to use.

Expand All @@ -121,11 +124,33 @@ def _select_input_module(fmt: str) -> ModuleType:
"""
if fmt in INPUT_MODULES:
if not hasattr(INPUT_MODULES[fmt], "write_input"):
raise FileFormatError(f"{fmt} input module does not have write_input.")
raise FileFormatError(f"{fmt} input module does not have write_input.", filename)
return INPUT_MODULES[fmt]
raise FileFormatError(f"Could not find input format {fmt}.")
raise FileFormatError(f"Cannot find input format {fmt}.", filename)


def _reissue_warnings(func):
"""Correct stacklevel of warnings raised in functions called deeper in IOData.

This function should be used as a decorator of end-user API functions.
Adapted from https://stackoverflow.com/a/71635963/494584
"""

def inner(*args, **kwargs):
"""Wrapper for func that reissues warnings."""
warning_list = []
try:
with warnings.catch_warnings(record=True) as warning_list:
result = func(*args, **kwargs)
finally:
for warning in warning_list:
warnings.warn(warning.message, warning.category, stacklevel=2)
return result

return inner


@_reissue_warnings
def load_one(filename: str, fmt: Optional[str] = None, **kwargs) -> IOData:
"""Load data from a file.

Expand All @@ -151,16 +176,16 @@ def load_one(filename: str, fmt: Optional[str] = None, **kwargs) -> IOData:
format_module = _select_format_module(filename, "load_one", fmt)
with LineIterator(filename) as lit:
try:
iodata = IOData(**format_module.load_one(lit, **kwargs))
return IOData(**format_module.load_one(lit, **kwargs))
except LoadError:
raise
except StopIteration as exc:
raise LoadError("File ended before all data was read.", lit) from exc
except Exception as exc:
raise LoadError("Uncaught exception while loading file.", lit) from exc
return iodata


@_reissue_warnings
def load_many(filename: str, fmt: Optional[str] = None, **kwargs) -> Iterator[IOData]:
"""Load multiple IOData instances from a file.

Expand Down Expand Up @@ -197,11 +222,13 @@ def load_many(filename: str, fmt: Optional[str] = None, **kwargs) -> Iterator[IO
raise LoadError("Uncaught exception while loading file.", lit) from exc


def _check_required(iodata: IOData, dump_func: Callable):
def _check_required(filename: str, iodata: IOData, dump_func: Callable):
"""Check that required attributes are not None before dumping to a file.

Parameters
----------
filename
The file to be dumped to, only used for error messages.
iodata
The data to be written.
dump_func
Expand All @@ -215,10 +242,11 @@ def _check_required(iodata: IOData, dump_func: Callable):
for attr_name in dump_func.required:
if getattr(iodata, attr_name) is None:
raise PrepareDumpError(
f"Required attribute {attr_name}, for format {dump_func.fmt}, is None."
f"Required attribute {attr_name}, for format {dump_func.fmt}, is None.", filename
)


@_reissue_warnings
def dump_one(iodata: IOData, filename: str, fmt: Optional[str] = None, **kwargs):
"""Write data to a file.

Expand Down Expand Up @@ -251,24 +279,25 @@ def dump_one(iodata: IOData, filename: str, fmt: Optional[str] = None, **kwargs)
"""
format_module = _select_format_module(filename, "dump_one", fmt)
try:
_check_required(iodata, format_module.dump_one)
_check_required(filename, iodata, format_module.dump_one)
if hasattr(format_module, "prepare_dump"):
format_module.prepare_dump(iodata)
format_module.prepare_dump(filename, iodata)
except PrepareDumpError:
raise
except Exception as exc:
raise PrepareDumpError(
f"{filename}: Uncaught exception while preparing for dumping to a file"
"Uncaught exception while preparing for dumping to a file.", filename
) from exc
with open(filename, "w") as f:
try:
format_module.dump_one(f, iodata, **kwargs)
except DumpError:
raise
except Exception as exc:
raise DumpError(f"{filename}: Uncaught exception while dumping to a file") from exc
raise DumpError("Uncaught exception while dumping to a file", filename) from exc


@_reissue_warnings
def dump_many(iodatas: Iterable[IOData], filename: str, fmt: Optional[str] = None, **kwargs):
"""Write multiple IOData instances to a file.

Expand Down Expand Up @@ -309,26 +338,26 @@ def dump_many(iodatas: Iterable[IOData], filename: str, fmt: Optional[str] = Non
try:
first = next(iter_iodatas)
except StopIteration as exc:
raise DumpError(f"{filename}: dump_many needs at least one iodata object.") from exc
raise DumpError("dump_many needs at least one iodata object.", filename) from exc
try:
_check_required(first, format_module.dump_many)
_check_required(filename, first, format_module.dump_many)
if hasattr(format_module, "prepare_dump"):
format_module.prepare_dump(first)
format_module.prepare_dump(filename, first)
except PrepareDumpError:
raise
except Exception as exc:
raise PrepareDumpError(
f"{filename}: Uncaught exception while preparing for dumping to a file"
"Uncaught exception while preparing for dumping to a file.", filename
) from exc

def checking_iterator():
"""Iterate over all iodata items, not checking the first."""
# The first one was already checked.
yield first
for other in iter_iodatas:
_check_required(other, format_module.dump_many)
_check_required(filename, other, format_module.dump_many)
if hasattr(format_module, "prepare_dump"):
format_module.prepare_dump(other)
format_module.prepare_dump(filename, other)
yield other

with open(filename, "w") as f:
Expand All @@ -337,9 +366,10 @@ def checking_iterator():
except (PrepareDumpError, DumpError):
raise
except Exception as exc:
raise DumpError(f"{filename}: Uncaught exception while dumping to a file") from exc
raise DumpError("Uncaught exception while dumping to a file.", filename) from exc


@_reissue_warnings
def write_input(
iodata: IOData,
filename: str,
Expand Down Expand Up @@ -370,11 +400,11 @@ def write_input(
Keyword arguments are passed on to the input-specific write_input function.

"""
input_module = _select_input_module(fmt)
input_module = _select_input_module(filename, fmt)
with open(filename, "w") as fh:
try:
input_module.write_input(fh, iodata, template, atom_line, **kwargs)
except Exception as exc:
raise WriteInputError(
f"{filename}: Uncaught exception while writing an input file"
"Uncaught exception while writing an input file.", filename
) from exc
14 changes: 9 additions & 5 deletions iodata/formats/fchk.py
Original file line number Diff line number Diff line change
Expand Up @@ -542,28 +542,32 @@ def _dump_real_arrays(name: str, val: NDArray[float], f: TextIO):
k = 0


def prepare_dump(data: IOData):
def prepare_dump(filename: str, data: IOData):
"""Check the compatibility of the IOData object with the FCHK format.

Parameters
----------
filename
The file to be written to, only used for error messages.
data
The IOData instance to be checked.
"""
if data.mo is not None:
if data.mo.kind == "generalized":
raise PrepareDumpError("Cannot write FCHK file with generalized orbitals.")
raise PrepareDumpError("Cannot write FCHK file with generalized orbitals.", filename)
na = int(np.round(np.sum(data.mo.occsa)))
if not ((data.mo.occsa[:na] == 1.0).all() and (data.mo.occsa[na:] == 0.0).all()):
raise PrepareDumpError(
"Cannot dump FCHK because it does not have fully occupied alpha orbitals "
"followed by fully virtual ones."
"followed by fully virtual ones.",
filename,
)
nb = int(np.round(np.sum(data.mo.occsb)))
if not ((data.mo.occsb[:nb] == 1.0).all() and (data.mo.occsb[nb:] == 0.0).all()):
raise PrepareDumpError(
"Cannot dump FCHK because it does not have fully occupied beta orbitals "
"followed by fully virtual ones."
"followed by fully virtual ones.",
filename,
)


Expand Down Expand Up @@ -643,7 +647,7 @@ def dump_one(f: TextIO, data: IOData):
elif shell.ncon == 2 and shell.angmoms == [0, 1]:
shell_types.append(-1)
else:
raise DumpError("Cannot identify type of shell!")
raise DumpError("Cannot identify type of shell!", f)

num_pure_d_shells = sum([1 for st in shell_types if st == 2])
num_pure_f_shells = sum([1 for st in shell_types if st == 3])
Expand Down
Loading