Skip to content

Commit

Permalink
Add ase_filter keyword to StructOptimizer.relax() (#102)
Browse files Browse the repository at this point in the history
* add ase_filter keyword to StructOptimizer.relax()

* test_relaxation parametrize ase_filter to test FrechetCellFilter and ExpCellFilter

* also link ExpCellFilter issue in doc str
  • Loading branch information
janosh authored Dec 7, 2023
1 parent ff636fa commit 1ead8bd
Show file tree
Hide file tree
Showing 2 changed files with 15 additions and 6 deletions.
12 changes: 9 additions & 3 deletions chgnet/model/dynamics.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
import torch
from ase import Atoms, units
from ase.calculators.calculator import Calculator, all_changes, all_properties
from ase.filters import FrechetCellFilter
from ase.filters import Filter, FrechetCellFilter
from ase.md.npt import NPT
from ase.md.nptberendsen import Inhomogeneous_NPTBerendsen, NPTBerendsen
from ase.md.nvtberendsen import NVTBerendsen
Expand Down Expand Up @@ -211,6 +211,7 @@ def relax(
fmax: float | None = 0.1,
steps: int | None = 500,
relax_cell: bool | None = True,
ase_filter: Filter = FrechetCellFilter,
save_path: str | None = None,
loginterval: int | None = 1,
crystal_feas_save_path: str | None = None,
Expand All @@ -227,6 +228,11 @@ def relax(
Default = 500
relax_cell (bool | None): Whether to relax the cell as well.
Default = True
ase_filter (ase.filters.Filter): The filter to apply to the atoms object
for relaxation. Default = FrechetCellFilter
Used to default to ExpCellFilter but was removed due to bug reported in
https://gitlab.com/ase/ase/-/issues/1321 and fixed in
https://gitlab.com/ase/ase/-/merge_requests/3024.
save_path (str | None): The path to save the trajectory.
Default = None
loginterval (int | None): Interval for logging trajectory and crystal feas
Expand Down Expand Up @@ -255,7 +261,7 @@ def relax(
cry_obs = CrystalFeasObserver(atoms)

if relax_cell:
atoms = FrechetCellFilter(atoms)
atoms = ase_filter(atoms)
optimizer = self.optimizer_class(atoms, **kwargs)
optimizer.attach(obs, interval=loginterval)

Expand All @@ -271,7 +277,7 @@ def relax(
if crystal_feas_save_path:
cry_obs.save(crystal_feas_save_path)

if isinstance(atoms, FrechetCellFilter):
if isinstance(atoms, Filter):
atoms = atoms.atoms
struct = AseAtomsAdaptor.get_structure(atoms)
for key in struct.site_properties:
Expand Down
9 changes: 6 additions & 3 deletions tests/test_relaxation.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@

import pytest
import torch
from ase.filters import ExpCellFilter, Filter, FrechetCellFilter
from pymatgen.core import Structure
from pytest import approx, mark, param

Expand All @@ -15,8 +16,10 @@
structure = Structure.from_file(f"{ROOT}/examples/mp-18767-LiMnO2.cif")


@pytest.mark.parametrize("algorithm", ["legacy", "fast"])
def test_relaxation(algorithm: Literal["legacy", "fast"]):
@pytest.mark.parametrize(
"algorithm, ase_filter", [("legacy", FrechetCellFilter), ("fast", ExpCellFilter)]
)
def test_relaxation(algorithm: Literal["legacy", "fast"], ase_filter: Filter) -> None:
chgnet = CHGNet.load()
converter = CrystalGraphConverter(
atom_graph_cutoff=6, bond_graph_cutoff=3, algorithm=algorithm
Expand All @@ -25,7 +28,7 @@ def test_relaxation(algorithm: Literal["legacy", "fast"]):

chgnet.graph_converter = converter
relaxer = StructOptimizer(model=chgnet)
result = relaxer.relax(structure, verbose=True)
result = relaxer.relax(structure, verbose=True, ase_filter=ase_filter)
assert list(result) == ["final_structure", "trajectory"]

traj = result["trajectory"]
Expand Down

0 comments on commit 1ead8bd

Please sign in to comment.