Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add CHGNet.version property #86

Merged
merged 8 commits into from
Oct 23, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 3 additions & 3 deletions .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -4,13 +4,13 @@ default_install_hook_types: [pre-commit, commit-msg]

repos:
- repo: https://github.com/astral-sh/ruff-pre-commit
rev: v0.0.292
rev: v0.1.1
hooks:
- id: ruff
args: [--fix]

- repo: https://github.com/psf/black
rev: 23.9.1
rev: 23.10.0
hooks:
- id: black-jupyter

Expand Down Expand Up @@ -49,7 +49,7 @@ repos:
- svelte

- repo: https://github.com/pre-commit/mirrors-eslint
rev: v8.51.0
rev: v8.52.0
hooks:
- id: eslint
types: [file]
Expand Down
4 changes: 2 additions & 2 deletions chgnet/model/dynamics.py
Original file line number Diff line number Diff line change
Expand Up @@ -222,7 +222,7 @@ def relax(
A dictionary with 'final_structure' and 'trajectory'.
"""
if isinstance(atoms, Structure):
atoms = AseAtomsAdaptor.get_atoms(atoms)
atoms = atoms.to_ase_atoms()

atoms.calc = self.calculator # assign model used to predict forces

Expand Down Expand Up @@ -432,7 +432,7 @@ def __init__(
self.ensemble = ensemble
self.thermostat = thermostat
if isinstance(atoms, (Structure, Molecule)):
atoms = AseAtomsAdaptor.get_atoms(atoms)
atoms = atoms.to_ase_atoms()

self.atoms = atoms
if isinstance(model, CHGNetCalculator):
Expand Down
77 changes: 46 additions & 31 deletions chgnet/model/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,8 @@
if TYPE_CHECKING:
from chgnet import PredTask

module_dir = os.path.dirname(os.path.abspath(__file__))


class CHGNet(nn.Module):
"""Crystal Hamiltonian Graph neural Network
Expand All @@ -38,8 +40,8 @@ def __init__(
bond_fea_dim: int = 64,
angle_fea_dim: int = 64,
composition_model: str | nn.Module = "MPtrj",
num_radial: int = 9,
num_angular: int = 9,
num_radial: int = 31,
num_angular: int = 31,
n_conv: int = 4,
atom_conv_hidden_dim: Sequence[int] | int = 64,
update_bond: bool = True,
Expand All @@ -48,19 +50,22 @@ def __init__(
angle_layer_hidden_dim: Sequence[int] | int = 0,
conv_dropout: float = 0,
read_out: str = "ave",
mlp_hidden_dims: Sequence[int] | int = (64, 64),
mlp_hidden_dims: Sequence[int] | int = (64, 64, 64),
mlp_dropout: float = 0,
mlp_first: bool = True,
is_intensive: bool = True,
non_linearity: Literal["silu", "relu", "tanh", "gelu"] = "silu",
atom_graph_cutoff: float = 5,
atom_graph_cutoff: float = 6,
bond_graph_cutoff: float = 3,
graph_converter_algorithm: Literal["legacy", "fast"] = "fast",
cutoff_coeff: int = 5,
cutoff_coeff: int = 8,
learnable_rbf: bool = True,
gMLP_norm: str | None = "layer",
readout_norm: str | None = "layer",
version: str | None = None,
**kwargs,
) -> None:
"""Initialize the CHGNet.
"""Initialize CHGNet.

Args:
atom_fea_dim (int): atom feature vector embedding dimension.
Expand Down Expand Up @@ -135,6 +140,11 @@ def __init__(
learnable_rbf (bool): whether to set the frequencies in rbf and Fourier
basis functions learnable.
Default = True
gMLP_norm (str): normalization layer to use in gate-MLP
Default = 'layer'
readout_norm (str): normalization layer to use before readout layer
Default = 'layer'
version (str): Pretrained checkpoint version.
**kwargs: Additional keyword arguments
"""
# Store model args for reconstruction
Expand All @@ -144,6 +154,8 @@ def __init__(
if k not in ["self", "__class__", "kwargs"]
}
self.model_args.update(kwargs)
if version:
self.model_args["version"] = version

super().__init__()
self.atom_fea_dim = atom_fea_dim
Expand Down Expand Up @@ -200,7 +212,6 @@ def __init__(

# Define convolutional layers
conv_norm = kwargs.pop("conv_norm", None)
gMLP_norm = kwargs.pop("gMLP_norm", None)
mlp_out_bias = kwargs.pop("mlp_out_bias", False)
atom_graph_layers = [
AtomConv(
Expand Down Expand Up @@ -261,9 +272,7 @@ def __init__(

# Define readout layer
self.site_wise = nn.Linear(atom_fea_dim, 1)
self.readout_norm = find_normalization(
name=kwargs.pop("readout_norm", None), dim=atom_fea_dim
)
self.readout_norm = find_normalization(readout_norm, dim=atom_fea_dim)
self.mlp_first = mlp_first
if mlp_first:
self.read_out_type = "sum"
Expand Down Expand Up @@ -306,19 +315,23 @@ def __init__(
f"parameters"
)

@property
def version(self) -> str | None:
    """Pretrained checkpoint version recorded in the model arguments.

    Returns:
        str | None: the value stored under the "version" key of
            self.model_args, or None when that key is not present.
    """
    stored_args = self.model_args
    return stored_args.get("version")

def forward(
self,
graphs: Sequence[CrystalGraph],
task: PredTask = "e",
return_site_energies: bool = False,
return_atom_feas: bool = False,
return_crystal_feas: bool = False,
) -> dict:
) -> dict[str, Tensor]:
"""Get prediction associated with input graphs
Args:
graphs (List): a list of CrystalGraphs
task (str): the prediction task
eg: 'e', 'em', 'ef', 'efs', 'efsm'
task (str): the prediction task. One of 'e', 'em', 'ef', 'efs', 'efsm'.
Default = 'e'
return_site_energies (bool): whether to return per-site energies,
only available if self.mlp_first == True
Expand Down Expand Up @@ -651,26 +664,28 @@ def from_file(cls, path, **kwargs):

@classmethod
def load(cls, model_name="0.3.0"):
"""Load pretrained CHGNet."""
current_dir = os.path.dirname(os.path.abspath(__file__))
if model_name == "0.3.0":
return cls.from_file(
os.path.join(
current_dir,
"../pretrained/0.3.0/chgnet_0.3.0_e29f68s314m37.pth.tar",
)
)
elif model_name == "0.2.0": # noqa: RET505
return cls.from_file(
os.path.join(
current_dir,
"../pretrained/0.2.0/chgnet_0.2.0_e30f77s348m32.pth.tar",
),
mlp_out_bias=True,
)
else:
"""Load pretrained CHGNet model.

Args:
model_name (str, optional): Defaults to "0.3.0".

Raises:
ValueError: On unknown model_name.
"""
checkpoint_path = {
"0.3.0": "../pretrained/0.3.0/chgnet_0.3.0_e29f68s314m37.pth.tar",
"0.2.0": "../pretrained/0.2.0/chgnet_0.2.0_e30f77s348m32.pth.tar",
}.get(model_name)

if checkpoint_path is None:
raise ValueError(f"Unknown {model_name=}")

return cls.from_file(
os.path.join(module_dir, checkpoint_path),
mlp_out_bias=model_name == "0.2.0",
version=model_name,
)


@dataclass
class BatchedGraph:
Expand Down
2 changes: 1 addition & 1 deletion chgnet/pretrained/0.2.0/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -70,5 +70,5 @@ trainer = Trainer(
| partition | Energy (meV/atom) | Force (meV/A) | stress (GPa) | magmom (muB) |
| ---------- | ----------------- | ------------- | ------------ | ------------ |
| Train | 22 | 59 | 0.246 | 0.030 |
| Validation | 20 | 75 | 0.350 | 0.033 |
| Validation | 30 | 75 | 0.350 | 0.033 |
| Test | 30 | 77 | 0.348 | 0.032 |
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@ find = { include = ["chgnet*"], exclude = ["tests", "tests*"] }

[tool.setuptools.package-data]
"chgnet" = ["*.json"]
"chgnet.pretrained" = ["*.tar"]
"chgnet.pretrained" = ["**/*"]

[tool.ruff]
target-version = "py39"
Expand Down
18 changes: 18 additions & 0 deletions tests/test_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -223,3 +223,21 @@ def test_as_to_from_dict() -> None:

model_3 = CHGNet(**to_dict["model_args"])
assert model_3.todict() == to_dict


def test_model_load(capsys: pytest.CaptureFixture) -> None:
    """Check CHGNet.load for the known checkpoints and for an unknown name.

    For each known checkpoint the loaded model must report the matching
    ``version``, print exactly one parameter-count line to stdout and
    write nothing to stderr. An unknown name must raise ValueError.
    """
    # (kwargs passed to CHGNet.load, expected version, expected stdout);
    # the first case passes no arguments to exercise the default.
    known_checkpoints = (
        ({}, "0.3.0", "CHGNet initialized with 412,525 parameters\n"),
        (
            {"model_name": "0.2.0"},
            "0.2.0",
            "CHGNet initialized with 400,438 parameters\n",
        ),
    )
    for load_kwargs, expected_version, expected_stdout in known_checkpoints:
        loaded = CHGNet.load(**load_kwargs)
        assert loaded.version == expected_version
        # readouterr() also resets the capture, so each iteration only sees
        # the output produced by its own load call.
        captured = capsys.readouterr()
        assert captured.out == expected_stdout
        assert captured.err == ""

    model_name = "0.1.0"  # invalid
    with pytest.raises(ValueError, match=f"Unknown {model_name=}"):
        CHGNet.load(model_name=model_name)