Skip to content

Commit

Permalink
Merge pull request #352 from tovrstra/prepare-unrestricted
Browse files Browse the repository at this point in the history
Complete prepare_dump API and apply to occs_aminusb + cleanups
  • Loading branch information
tovrstra authored Jun 28, 2024
2 parents e4efb2e + 565cfdf commit 03bf3e6
Show file tree
Hide file tree
Showing 15 changed files with 597 additions and 40 deletions.
4 changes: 2 additions & 2 deletions iodata/__main__.py
Original file line number Diff line number Diff line change
Expand Up @@ -113,9 +113,9 @@ def convert(infn, outfn, many, infmt, outfmt):
"""
if many:
dump_many((data for data in load_many(infn, infmt)), outfn, outfmt)
dump_many(load_many(infn, fmt=infmt), outfn, fmt=outfmt)
else:
dump_one(load_one(infn, infmt), outfn, outfmt)
dump_one(load_one(infn, fmt=infmt), outfn, fmt=outfmt)


def main():
Expand Down
56 changes: 44 additions & 12 deletions iodata/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -247,7 +247,14 @@ def _check_required(filename: str, iodata: IOData, dump_func: Callable):


@_reissue_warnings
def dump_one(iodata: IOData, filename: str, fmt: Optional[str] = None, **kwargs):
def dump_one(
iodata: IOData,
filename: str,
*,
fmt: Optional[str] = None,
allow_changes: bool = False,
**kwargs,
):
"""Write data to a file.
This routine uses the extension or prefix of the filename to determine
Expand All @@ -263,25 +270,35 @@ def dump_one(iodata: IOData, filename: str, fmt: Optional[str] = None, **kwargs)
fmt
The name of the file format module to use. When not given, it is guessed
from the filename.
allow_changes
Whether conversion of the IOData object to a compatible form is allowed or not.
**kwargs
Keyword arguments are passed on to the format-specific dump_one function.
Returns
-------
data
The given ``IOData`` object or a shallow copy with some new attributes if converted.
Raises
------
DumpError
When an error is encountered while dumping to a file.
If the output file already existed, it is (partially) overwritten.
PrepareDumpError
When the iodata object is not compatible with the file format,
e.g. due to missing attributes, and not conversion is available or allowed
When the ``IOData`` object is not compatible with the file format,
e.g. due to missing attributes, and no conversion is available or allowed
to make it compatible.
If the output file already existed, it is not overwritten.
PrepareDumpWarning
When the ``IOData`` object is not compatible with the file format,
but it was converted to fix the compatibility issue.
"""
format_module = _select_format_module(filename, "dump_one", fmt)
try:
_check_required(filename, iodata, format_module.dump_one)
if hasattr(format_module, "prepare_dump"):
format_module.prepare_dump(filename, iodata)
iodata = format_module.prepare_dump(iodata, allow_changes, filename)
except PrepareDumpError:
raise
except Exception as exc:
Expand All @@ -295,10 +312,18 @@ def dump_one(iodata: IOData, filename: str, fmt: Optional[str] = None, **kwargs)
raise
except Exception as exc:
raise DumpError("Uncaught exception while dumping to a file", filename) from exc
return iodata


@_reissue_warnings
def dump_many(iodatas: Iterable[IOData], filename: str, fmt: Optional[str] = None, **kwargs):
def dump_many(
iodatas: Iterable[IOData],
filename: str,
*,
fmt: Optional[str] = None,
allow_changes: bool = False,
**kwargs,
):
"""Write multiple IOData instances to a file.
This routine uses the extension or prefix of the filename to determine
Expand All @@ -313,6 +338,8 @@ def dump_many(iodatas: Iterable[IOData], filename: str, fmt: Optional[str] = Non
The file to write the data to.
fmt
The name of the file format module to use.
allow_changes
Whether conversion of the IOData object to a compatible form is allowed or not.
**kwargs
Keyword arguments are passed on to the format-specific dump_many function.
Expand All @@ -322,12 +349,15 @@ def dump_many(iodatas: Iterable[IOData], filename: str, fmt: Optional[str] = Non
When an error is encountered while dumping to a file.
If the output file already existed, it (partially) overwritten.
PrepareDumpError
When the iodata object is not compatible with the file format,
e.g. due to missing attributes, and not conversion is available or allowed
When an ``IOData`` object is not compatible with the file format,
e.g. due to missing attributes, and no conversion is available or allowed
to make it compatible.
If the output file already existed, it is not overwritten when this error
is raised while processing the first IOData instance in the ``iodatas`` argument.
is raised while processing the first ``IOData`` instance in the ``iodatas`` argument.
When the exception is raised in later iterations, any existing file is overwritten.
PrepareDumpWarning
When an ``IOData`` object is not compatible with the file format,
but it was converted to fix the compatibility issue.
"""
format_module = _select_format_module(filename, "dump_many", fmt)

Expand All @@ -342,7 +372,7 @@ def dump_many(iodatas: Iterable[IOData], filename: str, fmt: Optional[str] = Non
try:
_check_required(filename, first, format_module.dump_many)
if hasattr(format_module, "prepare_dump"):
format_module.prepare_dump(filename, first)
first = format_module.prepare_dump(first, allow_changes, filename)
except PrepareDumpError:
raise
except Exception as exc:
Expand All @@ -356,9 +386,11 @@ def checking_iterator():
yield first
for other in iter_iodatas:
_check_required(filename, other, format_module.dump_many)
if hasattr(format_module, "prepare_dump"):
format_module.prepare_dump(filename, other)
yield other
yield (
format_module.prepare_dump(other, allow_changes, filename)
if hasattr(format_module, "prepare_dump")
else other
)

with open(filename, "w") as f:
try:
Expand Down
18 changes: 16 additions & 2 deletions iodata/formats/fchk.py
Original file line number Diff line number Diff line change
Expand Up @@ -542,15 +542,28 @@ def _dump_real_arrays(name: str, val: NDArray[float], f: TextIO):
k = 0


def prepare_dump(filename: str, data: IOData):
def prepare_dump(data: IOData, allow_changes: bool, filename: str) -> IOData:
"""Check the compatibility of the IOData object with the FCHK format.
Parameters
----------
data
The IOData instance to be checked.
allow_changes
Whether conversion of the IOData object to a compatible form is allowed or not.
(not relevant for FCHK, present for API consistency)
filename
The file to be written to, only used for error messages.
Returns
-------
data
The IOData instance to be checked.
The given ``IOData`` object.
Raises
------
PrepareDumpError
If the given ``IOData`` instance is not compatible with the WFN format.
"""
if data.mo is not None:
if data.mo.kind == "generalized":
Expand All @@ -569,6 +582,7 @@ def prepare_dump(filename: str, data: IOData):
"followed by fully virtual ones.",
filename,
)
return data


@document_dump_one(
Expand Down
17 changes: 15 additions & 2 deletions iodata/formats/json.py
Original file line number Diff line number Diff line change
Expand Up @@ -1446,16 +1446,28 @@ def _parse_provenance(
return base_provenance


def prepare_dump(filename: str, data: IOData):
def prepare_dump(data: IOData, allow_changes: bool, filename: str) -> IOData:
"""Check the compatibility of the IOData object with QCScheme.
Parameters
----------
data
The IOData instance to be checked.
allow_changes
Whether conversion of the IOData object to a compatible form is allowed or not.
(not relevant for QCSchema JSON, present for API consistency)
filename
The file to be written to, only used for error messages.
Returns
-------
data
The IOData instance to be checked.
The given ``IOData`` object.
Raises
------
PrepareDumpError
If the given ``IOData`` instance is not compatible with the WFN format.
"""
if "schema_name" not in data.extra:
raise PrepareDumpError(
Expand All @@ -1464,6 +1476,7 @@ def prepare_dump(filename: str, data: IOData):
schema_name = data.extra["schema_name"]
if schema_name == "qcschema_basis":
raise PrepareDumpError(f"{schema_name} not yet implemented in IOData.", filename)
return data


@document_dump_one(
Expand Down
22 changes: 18 additions & 4 deletions iodata/formats/molden.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,7 @@
from ..orbitals import MolecularOrbitals
from ..overlap import compute_overlap, gob_cart_normalization
from ..periodic import num2sym, sym2num
from ..prepare import prepare_unrestricted_aminusb
from ..utils import DumpError, LineIterator, LoadError, LoadWarning, PrepareDumpError, angstrom

__all__ = []
Expand Down Expand Up @@ -768,24 +769,37 @@ def _fix_molden_from_buggy_codes(result: dict, lit: LineIterator, norm_threshold
)


def prepare_dump(filename: str, data: IOData):
def prepare_dump(data: IOData, allow_changes: bool, filename: str) -> IOData:
"""Check the compatibility of the IOData object with the Molden format.
Parameters
----------
data
The IOData instance to be checked.
allow_changes
Whether conversion of the IOData object to a compatible form is allowed or not.
filename
The file to be written to, only used for error messages.
Returns
-------
data
The IOData instance to be checked.
The given ``IOData`` object or a shallow copy with some new attributes.
Raises
------
PrepareDumpError
If the given ``IOData`` instance is not compatible with the WFN format.
PrepareDumpWarning
If the a converted ``IOData`` instance is returned.
"""
if data.mo is None:
raise PrepareDumpError("The Molden format requires molecular orbitals.", filename)
if data.obasis is None:
raise PrepareDumpError("The Molden format requires an orbital basis set.", filename)
if data.mo.occs_aminusb is not None:
raise PrepareDumpError("Cannot write Molden file when mo.occs_aminusb is set.", filename)
if data.mo.kind == "generalized":
raise PrepareDumpError("Cannot write Molden file with generalized orbitals.", filename)
return prepare_unrestricted_aminusb(data, allow_changes, filename, "Molden")


@document_dump_one("Molden", ["atcoords", "atnums", "mo", "obasis"], ["atcorenums", "title"])
Expand Down
36 changes: 30 additions & 6 deletions iodata/formats/molekel.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
"""

from typing import TextIO
from warnings import warn

import numpy as np
from numpy.typing import NDArray
Expand All @@ -32,7 +33,8 @@
from ..docstrings import document_dump_one, document_load_one
from ..iodata import IOData
from ..orbitals import MolecularOrbitals
from ..utils import DumpError, LineIterator, LoadError, PrepareDumpError, angstrom
from ..prepare import prepare_unrestricted_aminusb
from ..utils import DumpError, LineIterator, LoadError, LoadWarning, PrepareDumpError, angstrom
from .molden import CONVENTIONS, _fix_molden_from_buggy_codes

__all__ = []
Expand Down Expand Up @@ -235,7 +237,16 @@ def load_one(lit: LineIterator, norm_threshold: float = 1e-4) -> dict:
)
nalpha = int(np.round(occsa.sum()))
nbeta = int(np.round(occsb.sum()))
assert abs(spinpol - abs(nalpha - nbeta)) < 1e-7
if abs(spinpol - abs(nalpha - nbeta)) > 1e-7:
warn(
LoadWarning(
f"The spin polarization ({spinpol}) is inconsistent with the"
f"difference between alpha and beta occupation numbers ({nalpha} - {nbeta}). "
"The spin polarization will be rederived from the occupation numbers.",
lit,
),
stacklevel=2,
)
assert nelec == nalpha + nbeta
assert coeffsa.shape == coeffsb.shape
assert energiesa.shape == energiesb.shape
Expand All @@ -261,24 +272,37 @@ def load_one(lit: LineIterator, norm_threshold: float = 1e-4) -> dict:
return result


def prepare_dump(filename: str, data: IOData):
def prepare_dump(data: IOData, allow_changes: bool, filename: str) -> IOData:
"""Check the compatibility of the IOData object with the Molekel format.
Parameters
----------
data
The IOData instance to be checked.
allow_changes
Whether conversion of the IOData object to a compatible form is allowed or not.
filename
The file to be written to, only used for error messages.
Returns
-------
data
The IOData instance to be checked.
The given ``IOData`` object or a shallow copy with some new attributes.
Raises
------
PrepareDumpError
If the given ``IOData`` instance is not compatible with the WFN format.
PrepareDumpWarning
If the a converted ``IOData`` instance is returned.
"""
if data.mo is None:
raise PrepareDumpError("The Molekel format requires molecular orbitals.", filename)
if data.obasis is None:
raise PrepareDumpError("The Molekel format requires an orbital basis set.", filename)
if data.mo.occs_aminusb is not None:
raise PrepareDumpError("Cannot write Molekel file when mo.occs_aminusb is set.", filename)
if data.mo.kind == "generalized":
raise PrepareDumpError("Cannot write Molekel file with generalized orbitals.", filename)
return prepare_unrestricted_aminusb(data, allow_changes, filename, "Molekel")


@document_dump_one("Molekel", ["atcoords", "atnums", "mo", "obasis"], ["atcharges"])
Expand Down
22 changes: 18 additions & 4 deletions iodata/formats/wfn.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@
from ..orbitals import MolecularOrbitals
from ..overlap import gob_cart_normalization
from ..periodic import num2sym, sym2num
from ..prepare import prepare_unrestricted_aminusb
from ..utils import LineIterator, LoadError, PrepareDumpError

__all__ = []
Expand Down Expand Up @@ -496,29 +497,42 @@ def _dump_helper_section(f: TextIO, data: NDArray, fmt: str, skip: int, step: in
DEFAULT_WFN_TTL = "WFN auto-generated by IOData"


def prepare_dump(filename: str, data: IOData):
def prepare_dump(data: IOData, allow_changes: bool, filename: str) -> IOData:
"""Check the compatibility of the IOData object with the WFN format.
Parameters
----------
data
The IOData instance to be checked.
allow_changes
Whether conversion of the IOData object to a compatible form is allowed or not.
filename
The file to be written to, only used for error messages.
Returns
-------
data
The IOData instance to be checked.
The given ``IOData`` object or a shallow copy with some new attributes.
Raises
------
PrepareDumpError
If the given ``IOData`` instance is not compatible with the WFN format.
PrepareDumpWarning
If the a converted ``IOData`` instance is returned.
"""
if data.mo is None:
raise PrepareDumpError("The WFN format requires molecular orbitals", filename)
if data.obasis is None:
raise PrepareDumpError("The WFN format requires an orbital basis set", filename)
if data.mo.kind == "generalized":
raise PrepareDumpError("Cannot write WFN file with generalized orbitals.", filename)
if data.mo.occs_aminusb is not None:
raise PrepareDumpError("Cannot write WFN file when mo.occs_aminusb is set.", filename)
for shell in data.obasis.shells:
if any(kind != "c" for kind in shell.kinds):
raise PrepareDumpError(
"The WFN format only supports Cartesian MolecularBasis.", filename
)
return prepare_unrestricted_aminusb(data, allow_changes, filename, "WFN")


@document_dump_one(
Expand Down
Loading

0 comments on commit 03bf3e6

Please sign in to comment.