Skip to content

Commit

Permalink
Merge pull request #29 from microsoft/hanyang/restore-deep-calc
Browse files Browse the repository at this point in the history
fix: restore DeepCalculator for compatibility
  • Loading branch information
ZeroKnighting authored Dec 2, 2024
2 parents 548c40f + 52065f2 commit e036feb
Show file tree
Hide file tree
Showing 3 changed files with 106 additions and 3 deletions.
3 changes: 2 additions & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,8 @@ dependencies = [
"azure-identity",
"mp-api",
"emmet-core<0.84",
"pydantic==2.9.2"
"pydantic==2.9.2",
"deprecated"
]

[project.optional-dependencies]
Expand Down
4 changes: 2 additions & 2 deletions src/mattersim/forcefield/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
# -*- coding: utf-8 -*-
from .potential import MatterSimCalculator, Potential
from .potential import DeepCalculator, MatterSimCalculator, Potential

__all__ = ["MatterSimCalculator", "Potential"]
__all__ = ["MatterSimCalculator", "Potential", "DeepCalculator"]
102 changes: 102 additions & 0 deletions src/mattersim/forcefield/potential.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
from ase.calculators.calculator import Calculator
from ase.constraints import full_3x3_to_voigt_6_stress
from ase.units import GPa
from deprecated import deprecated
from torch.optim import Adam
from torch.optim.lr_scheduler import ReduceLROnPlateau, StepLR
from torch_ema import ExponentialMovingAverage
Expand Down Expand Up @@ -990,6 +991,107 @@ def batch_to_dict(graph_batch, model_type="m3gnet", device="cuda"):
return input


@deprecated(version="1.0.0", reason="Please use MatterSimCalculator instead.")
class DeepCalculator(Calculator):
"""
Deep calculator based on ase Calculator
"""

implemented_properties = ["energy", "free_energy", "forces", "stress"]

def __init__(
self,
potential: Potential,
args_dict: dict = {},
compute_stress: bool = True,
stress_weight: float = 1.0,
device: str = "cuda" if torch.cuda.is_available() else "cpu",
**kwargs,
):
"""
Args:
potential (Potential): m3gnet.models.Potential
compute_stress (bool): whether to calculate the stress
stress_weight (float): the stress weight.
**kwargs:
"""
super().__init__(**kwargs)
self.potential = potential
self.compute_stress = compute_stress
self.stress_weight = stress_weight
self.args_dict = args_dict
self.device = device

def calculate(
self,
atoms: Optional[Atoms] = None,
properties: Optional[list] = None,
system_changes: Optional[list] = None,
):
"""
Args:
atoms (ase.Atoms): ase Atoms object
properties (list): list of properties to calculate
system_changes (list): monitor which properties of atoms were
changed for new calculation. If not, the previous calculation
results will be loaded.
Returns:
"""

all_changes = [
"positions",
"numbers",
"cell",
"pbc",
"initial_charges",
"initial_magmoms",
]

properties = properties or ["energy"]
system_changes = system_changes or all_changes
super().calculate(
atoms=atoms, properties=properties, system_changes=system_changes
)

self.args_dict["batch_size"] = 1
self.args_dict["only_inference"] = 1
dataloader = build_dataloader(
[atoms], model_type=self.potential.model_name, **self.args_dict
)
for graph_batch in dataloader:
# Resemble input dictionary
if (
self.potential.model_name == "graphormer"
or self.potential.model_name == "geomformer"
):
raise NotImplementedError
else:
graph_batch = graph_batch.to(self.device)
input = batch_to_dict(graph_batch)

result = self.potential.forward(
input, include_forces=True, include_stresses=self.compute_stress
)
if (
self.potential.model_name == "graphormer"
or self.potential.model_name == "geomformer"
):
raise NotImplementedError
else:
self.results.update(
energy=result["total_energy"].detach().cpu().numpy()[0],
free_energy=result["total_energy"].detach().cpu().numpy()[0],
forces=result["forces"].detach().cpu().numpy(),
)
if self.compute_stress:
self.results.update(
stress=self.stress_weight
* full_3x3_to_voigt_6_stress(
result["stresses"].detach().cpu().numpy()[0]
)
)


class MatterSimCalculator(Calculator):
"""
Deep calculator based on ase Calculator
Expand Down

0 comments on commit e036feb

Please sign in to comment.