Skip to content

Commit

Permalink
added site_energy observable in MD
Browse files Browse the repository at this point in the history
  • Loading branch information
bowen-bd committed Oct 25, 2024
1 parent 11ff513 commit 1eedc8c
Show file tree
Hide file tree
Showing 2 changed files with 19 additions and 2 deletions.
15 changes: 13 additions & 2 deletions chgnet/model/dynamics.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,7 @@
class CHGNetCalculator(Calculator):
"""CHGNet Calculator for ASE applications."""

implemented_properties = ("energy", "forces", "stress", "magmoms")
implemented_properties = ("energy", "forces", "stress", "magmoms", "energies")

def __init__(
self,
Expand All @@ -61,6 +61,7 @@ def __init__(
check_cuda_mem: bool = False,
stress_weight: float | None = 1 / 160.21766208,
on_isolated_atoms: Literal["ignore", "warn", "error"] = "warn",
return_site_energies: bool = False,
**kwargs,
) -> None:
"""Provide a CHGNet instance to calculate various atomic properties using ASE.
Expand All @@ -80,6 +81,7 @@ def __init__(
on_isolated_atoms ('ignore' | 'warn' | 'error'): how to handle Structures
with isolated atoms.
Default = 'warn'
return_site_energies (bool): whether to return the energy of each atom
**kwargs: Passed to the Calculator parent class.
"""
super().__init__(**kwargs)
Expand All @@ -95,6 +97,7 @@ def __init__(
self.model = model.to(self.device)
self.model.graph_converter.set_isolated_atom_response(on_isolated_atoms)
self.stress_weight = stress_weight
self.return_site_energies = return_site_energies
print(f"CHGNet will run on {self.device}")

@classmethod
Expand Down Expand Up @@ -143,7 +146,10 @@ def calculate(
structure = AseAtomsAdaptor.get_structure(atoms)
graph = self.model.graph_converter(structure)
model_prediction = self.model.predict_graph(
graph.to(self.device), task="efsm", return_crystal_feas=True
graph.to(self.device),
task="efsm",
return_crystal_feas=True,
return_site_energies=self.return_site_energies,
)

# Convert Result
Expand All @@ -156,6 +162,8 @@ def calculate(
stress=model_prediction["s"] * self.stress_weight,
crystal_fea=model_prediction["crystal_fea"],
)
if self.return_site_energies:
self.results.update(energies=model_prediction["site_energies"])


class StructOptimizer:
Expand Down Expand Up @@ -430,6 +438,7 @@ def __init__(
crystal_feas_logfile: str | None = None,
append_trajectory: bool = False,
on_isolated_atoms: Literal["ignore", "warn", "error"] = "warn",
return_site_energies: bool = False,
use_device: str | None = None,
) -> None:
"""Initialize the MD class.
Expand Down Expand Up @@ -494,6 +503,7 @@ def __init__(
on_isolated_atoms ('ignore' | 'warn' | 'error'): how to handle Structures
with isolated atoms.
Default = 'warn'
return_site_energies (bool): whether to return the energy of each atom
use_device (str): the device for the MD run
Default = None
"""
Expand All @@ -517,6 +527,7 @@ def __init__(
model=model,
use_device=use_device,
on_isolated_atoms=on_isolated_atoms,
return_site_energies=return_site_energies,
)

if taut is None:
Expand Down
6 changes: 6 additions & 0 deletions tests/test_md.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
import numpy as np
import pytest
from ase import Atoms
from ase.io.trajectory import Trajectory
from ase.md.npt import NPT
from ase.md.nptberendsen import Inhomogeneous_NPTBerendsen
from ase.md.nvtberendsen import NVTBerendsen
Expand Down Expand Up @@ -75,6 +76,7 @@ def test_md_nvt_berendsen(
trajectory="md_out.traj",
logfile="md_out.log",
loginterval=10,
return_site_energies=True,
)
md.run(100)

Expand Down Expand Up @@ -102,6 +104,10 @@ def test_md_nvt_berendsen(
)
assert_allclose(logs, ref, rtol=2.1e-3, atol=1e-8)

traj = Trajectory("md_out.traj")
assert isinstance(traj[0].get_potential_energy(), float)
assert isinstance(traj[0].get_potential_energies(), np.ndarray)


def test_md_nve(tmp_path: Path, monkeypatch: MonkeyPatch):
monkeypatch.chdir(tmp_path) # run MD in temporary directory
Expand Down

0 comments on commit 1eedc8c

Please sign in to comment.