Skip to content

Commit

Permalink
v0.3.1
Browse files Browse the repository at this point in the history
  • Loading branch information
janosh committed Oct 23, 2023
1 parent 74a6a70 commit 9599fe8
Show file tree
Hide file tree
Showing 3 changed files with 7 additions and 7 deletions.
4 changes: 2 additions & 2 deletions chgnet/model/composition_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,7 +70,7 @@ def _assemble_graphs(self, graphs: list[CrystalGraph]):
assembled batch_graph that contains all information for model.
"""
composition_feas = []
for _graph_idx, graph in enumerate(graphs):
for graph in graphs:
composition_fea = torch.bincount(
graph.atomic_number - 1, minlength=self.max_num_elements
)
Expand Down Expand Up @@ -201,7 +201,7 @@ def initialize_from(self, dataset: str):
"""Initialize pre-fitted weights from a dataset."""
if dataset in ["MPtrj", "MPtrj_e"]:
self.initialize_from_MPtrj()
elif dataset in ["MPF"]:
elif dataset == "MPF":
self.initialize_from_MPF()
else:
raise NotImplementedError(f"{dataset=} not supported yet")
Expand Down
6 changes: 3 additions & 3 deletions chgnet/trainer/trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -151,7 +151,7 @@ def __init__(
eta_min=decay_fraction * learning_rate,
)
self.scheduler_type = "cos"
elif scheduler in ["CosRestartLR"]:
elif scheduler == "CosRestartLR":
scheduler_params = kwargs.pop(
"scheduler_params", {"decay_fraction": 1e-2, "T_0": 10, "T_mult": 2}
)
Expand Down Expand Up @@ -471,7 +471,7 @@ def get_best_model(self):
if self.best_model is None:
raise RuntimeError("the model needs to be trained first")
MAE = min(self.training_history["e"]["val"])
print(f"Best model has val {MAE = :.4}")
print(f"Best model has val {MAE =:.4}")
return self.best_model

@property
Expand Down Expand Up @@ -616,7 +616,7 @@ def __init__(
self.criterion = nn.MSELoss()
elif criterion in ["MAE", "mae", "l1"]:
self.criterion = nn.L1Loss()
elif criterion in ["Huber"]:
elif criterion == "Huber":
self.criterion = nn.HuberLoss(delta=delta)
else:
raise NotImplementedError
Expand Down
4 changes: 2 additions & 2 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta"

[project]
name = "chgnet"
version = "0.3.0"
version = "0.3.1"
description = "Pretrained Universal Neural Network Potential for Charge-informed Atomistic Modeling"
authors = [{ name = "Bowen Deng", email = "[email protected]" }]
requires-python = ">=3.9"
Expand Down Expand Up @@ -45,7 +45,7 @@ find = { include = ["chgnet*"], exclude = ["tests", "tests*"] }

[tool.setuptools.package-data]
"chgnet" = ["*.json"]
"chgnet.pretrained" = ["**/*"]
"chgnet.pretrained" = ["*", "**/*"]

[tool.ruff]
target-version = "py39"
Expand Down

0 comments on commit 9599fe8

Please sign in to comment.