Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

improve: ase try to get virials from different sources #660

Open
wants to merge 5 commits into
base: devel
Choose a base branch
from
Open
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
29 changes: 21 additions & 8 deletions dpdata/plugins/ase.py
Original file line number Diff line number Diff line change
Expand Up @@ -93,13 +93,21 @@
"energies": np.array([energies]),
"forces": np.array([forces]),
}
try:
stress = atoms.get_stress(False)
except PropertyNotImplementedError:
pass
else:
virials = np.array([-atoms.get_volume() * stress])

# try to get virials from different sources
virials = atoms.info.get("virial")
if virials is None:
virials = atoms.info.get("virials")
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It's unclear why it has two different keys.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It's a fallback strategy as users may use either virial or virials.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Documentation is necessary if users are expected to do something

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I try to search in the code base but find no place to add comment for this.

I don't think we need add extra documents for this as dpdata is supposed to be able to find virial for user automatically. My patch doesn't introduce any compatibility issue, it just make the ase plugin more robust to find virial.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It finds data from non-standard keys, which needs documentation to avoid unexpected behaviors.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I just add a comment in docstring.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It's not rendered correctly.
image

if virials is None:
try:
stress = atoms.get_stress(False)
except PropertyNotImplementedError:
pass

Check warning on line 105 in dpdata/plugins/ase.py

View check run for this annotation

Codecov / codecov/patch

dpdata/plugins/ase.py#L104-L105

Added lines #L104 - L105 were not covered by tests
Comment on lines +109 to +110
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Handle PropertyNotImplementedError more gracefully to inform the user.

-            except PropertyNotImplementedError:
-                pass
+            except PropertyNotImplementedError as e:
+                logging.warning(f"Failed to compute stress due to: {str(e)}")

Committable suggestion

‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation.

Suggested change
except PropertyNotImplementedError:
pass
except PropertyNotImplementedError as e:
logging.warning(f"Failed to compute stress due to: {str(e)}")

else:
virials = np.array([-atoms.get_volume() * stress])
if virials is not None:
Comment on lines +102 to +113
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Consider handling the PropertyNotImplementedError for stress calculation more gracefully.

The current implementation passes silently if PropertyNotImplementedError is raised when trying to get stress (lines 104-105). This could lead to virials being None without any indication of why, which might confuse users. Consider logging a warning or providing a fallback mechanism.

-            except PropertyNotImplementedError:
-                pass
+            except PropertyNotImplementedError as e:
+                logging.warning(f"Failed to compute stress due to: {str(e)}")

Committable suggestion

‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation.

Suggested change
# try to get virials from different sources
virials = atoms.info.get("virial")
if virials is None:
virials = atoms.info.get("virials")
if virials is None:
try:
stress = atoms.get_stress(False)
except PropertyNotImplementedError:
pass
else:
virials = np.array([-atoms.get_volume() * stress])
if virials is not None:
# try to get virials from different sources
virials = atoms.info.get("virial")
if virials is None:
virials = atoms.info.get("virials")
if virials is None:
try:
stress = atoms.get_stress(False)
except PropertyNotImplementedError as e:
logging.warning(f"Failed to compute stress due to: {str(e)}")
else:
virials = np.array([-atoms.get_volume() * stress])
if virials is not None:

info_dict["virials"] = virials

return info_dict

def from_multi_systems(
Expand Down Expand Up @@ -159,13 +167,18 @@
return structures

def to_labeled_system(self, data, *args, **kwargs):
"""Convert System to ASE Atoms object."""
"""Convert System to ASE Atoms object.

Note that this method will try to load virials from the following sources:
- atoms.info['virial']
- atoms.info['virials']
- converted from stress tensor
"""
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Update the method documentation to reflect the new virial retrieval logic.

The documentation in the to_labeled_system method (lines 170-176) mentions the sources from which virials are loaded, but it does not reflect the new logic added in from_labeled_system for obtaining virials from different sources and calculating them from stress. Consider updating the documentation to maintain consistency and clarity.

-        - converted from stress tensor
+        - converted from stress tensor if other sources are unavailable

Committable suggestion

‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation.

Suggested change
"""Convert System to ASE Atoms object.
Note that this method will try to load virials from the following sources:
- atoms.info['virial']
- atoms.info['virials']
- converted from stress tensor
"""
"""Convert System to ASE Atoms object.
Note that this method will try to load virials from the following sources:
- atoms.info['virial']
- atoms.info['virials']
- converted from stress tensor if other sources are unavailable
"""

from ase import Atoms
from ase.calculators.singlepoint import SinglePointCalculator

structures = []
species = [data["atom_names"][tt] for tt in data["atom_types"]]

for ii in range(data["coords"].shape[0]):
structure = Atoms(
symbols=species,
Expand Down