pyproject.toml down-pin pymatgen>=2023.10.11
drop black for ruff-format
codespell check filenames
janosh committed Oct 30, 2023
1 parent 7c8139d commit 3673a84
Showing 13 changed files with 47 additions and 39 deletions.
9 changes: 3 additions & 6 deletions .pre-commit-config.yaml
@@ -4,15 +4,11 @@ default_install_hook_types: [pre-commit, commit-msg]

repos:
- repo: https://github.com/astral-sh/ruff-pre-commit
rev: v0.1.1
rev: v0.1.3
hooks:
- id: ruff
args: [--fix]

- repo: https://github.com/psf/black
rev: 23.10.0
hooks:
- id: black-jupyter
- id: ruff-format

- repo: https://github.com/pre-commit/pre-commit-hooks
rev: v4.5.0
@@ -31,6 +27,7 @@ repos:
hooks:
- id: codespell
stages: [commit, commit-msg]
args: [--check-filenames]

- repo: https://github.com/kynan/nbstripout
rev: 0.6.1
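Putting the two hunks together, the relevant parts of the resulting .pre-commit-config.yaml plausibly read as below; the codespell repo URL and rev are not visible in this diff and are placeholders:

```yaml
repos:
  - repo: https://github.com/astral-sh/ruff-pre-commit
    rev: v0.1.3
    hooks:
      - id: ruff
        args: [--fix]
      - id: ruff-format # replaces the dropped psf/black black-jupyter hook

  - repo: https://github.com/codespell-project/codespell # assumed URL, not shown in this diff
    rev: v2.2.6 # placeholder rev, not shown in this diff
    hooks:
      - id: codespell
        stages: [commit, commit-msg]
        args: [--check-filenames] # newly added: also spell-check file names
```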
9 changes: 6 additions & 3 deletions chgnet/data/dataset.py
@@ -125,7 +125,8 @@ def __getitem__(self, idx: int) -> tuple[CrystalGraph, dict]:

return crystal_graph, targets

# Omit structures with isolated atoms. Return another randomly selected structure
# Omit structures with isolated atoms. Return another randomly selected
# structure
except Exception:
struct = self.structures[graph_id]
self.failed_graph_id[graph_id] = struct.composition.formula
@@ -491,7 +492,8 @@ def __init__(
Args:
data (str | dict): file path or dir name that contain all the JSONs
graph_converter (CrystalGraphConverter): Converts pymatgen.core.Structure to graph
graph_converter (CrystalGraphConverter): Converts pymatgen.core.Structure
to CrystalGraph object.
targets ("ef" | "efs" | "efm" | "efsm"): The training targets.
Default = "efsm"
energy_key (str, optional): the key of energy in the labels.
@@ -575,7 +577,8 @@ def __getitem__(self, idx):
targets["m"] = torch.abs(torch.tensor(mag, dtype=datatype))
return crystal_graph, targets

# Omit structures with isolated atoms. Return another randomly selected structure
# Omit structures with isolated atoms. Return another randomly selected
# structure
except Exception:
structure = Structure.from_dict(self.data[mp_id][graph_id]["structure"])
self.failed_graph_id[graph_id] = structure.composition.formula
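Both dataset hunks reflow the same comment about the retry logic: if graph construction fails (e.g. a structure containing an isolated atom), the sample is recorded in failed_graph_id and another randomly selected structure is returned instead. A self-contained sketch of that pattern, with `build_graph` and the toy data standing in for the real converter and structures:

```python
import random


def build_graph(structure: dict) -> dict:
    """Stand-in for graph construction; fails for structures flagged as having isolated atoms."""
    if structure.get("isolated_atom"):
        raise ValueError("isolated atom with no neighbors inside the cutoff radius")
    return {"graph_of": structure["formula"]}


class ToyGraphDataset:
    """Illustrative dataset that records failures and resamples, as in the hunks above."""

    def __init__(self, structures: list[dict]) -> None:
        self.structures = structures
        self.failed_graph_id: dict[int, str] = {}  # index -> formula of the failed structure

    def __len__(self) -> int:
        return len(self.structures)

    def __getitem__(self, idx: int) -> dict:
        try:
            return build_graph(self.structures[idx])
        except Exception:
            # Omit the structure with isolated atoms, record it, and return another
            # randomly selected structure (the comment reflowed in the diff above).
            self.failed_graph_id[idx] = self.structures[idx]["formula"]
            return self[random.randrange(len(self.structures))]


data = ToyGraphDataset(
    [{"formula": "LiMnO2"}, {"formula": "He1", "isolated_atom": True}, {"formula": "NaCl"}]
)
print(data[1], data.failed_graph_id)  # resampled graph plus the recorded failure
```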
7 changes: 4 additions & 3 deletions chgnet/graph/crystalgraph.py
@@ -30,8 +30,8 @@ def __init__(
) -> None:
"""Initialize the crystal graph.
Attention! This data class is not intended to be created manually. CrystalGraph should
be returned by a CrystalGraphConverter
Attention! This data class is not intended to be created manually. CrystalGraph
should be returned by a CrystalGraphConverter
Args:
atomic_number (Tensor): the atomic numbers of atoms in the structure
@@ -92,7 +92,8 @@ def __init__(
self.composition = composition
if len(directed2undirected) != 2 * len(undirected2directed):
raise ValueError(
f"{graph_id} number of directed indices != 2 * number of undirected indices!"
f"{graph_id} number of directed indices ({len(directed2undirected)}) !="
f" 2 * number of undirected indices ({2 * len(undirected2directed)})!"
)

def to(self, device: str = "cpu") -> CrystalGraph:
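The expanded error message above encodes a basic invariant of the graph representation: every undirected edge is stored as two directed edges, so `len(directed2undirected)` must equal `2 * len(undirected2directed)`. A toy illustration with plain lists (not the actual CrystalGraph tensors):

```python
# Two undirected bonds A-B and B-C become four directed edges.
directed_edges = [("A", "B"), ("B", "A"), ("B", "C"), ("C", "B")]

# Map each directed edge to the undirected bond it belongs to ...
directed2undirected = [0, 0, 1, 1]
# ... and keep one representative directed edge per undirected bond.
undirected2directed = [0, 2]

assert len(directed2undirected) == 2 * len(undirected2directed)
```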
4 changes: 2 additions & 2 deletions chgnet/graph/graph.py
@@ -89,8 +89,8 @@ def __eq__(self, other: object) -> bool:
other (DirectedEdge): another DirectedEdge to compare to
Returns:
bool: True if other is the same directed edge, or if other is the directed edge
with reverse direction of self, else False.
bool: True if other is the same directed edge, or if other is the directed
edge with reverse direction of self, else False.
"""
self_img = (self.info or {}).get("image")
other_img = (other.info or {}).get("image")
4 changes: 2 additions & 2 deletions chgnet/model/basis.py
@@ -172,8 +172,8 @@ def __init__(self, cutoff: float = 5, cutoff_coeff: float = 5) -> None:
Default = 5
cutoff_coeff (float): the strength of soft-Cutoff
0 will disable the cutoff, returning 1 at every r
for positive numbers > 0, the smaller cutoff_coeff is, the faster this function
decays. Default = 5.
for positive numbers > 0, the smaller cutoff_coeff is, the faster this
function decays. Default = 5.
"""
super().__init__()
self.cutoff = cutoff
6 changes: 4 additions & 2 deletions chgnet/model/composition_model.py
@@ -46,7 +46,8 @@ def _get_energy(self, composition_feas: Tensor) -> Tensor:
"""Predict the energy given composition encoding.
Args:
composition_feas: batched atom feature matrix [batch_size, total_num_elements].
composition_feas: batched atom feature matrix of shape
[batch_size, total_num_elements].
Returns:
prediction associated with each composition [batchsize].
@@ -111,7 +112,8 @@ def _get_energy(self, composition_feas: Tensor) -> Tensor:
"""Predict the energy given composition encoding.
Args:
composition_feas: batched atom feature matrix [batch_size, total_num_elements].
composition_feas: batched atom feature matrix of shape
[batch_size, total_num_elements].
Returns:
prediction associated with each composition [batchsize].
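The clarified composition_feas docstrings pin down the shapes: input [batch_size, total_num_elements], output one energy per composition. A minimal sketch of such a mapping, assuming a plain linear layer rather than the actual AtomRef/composition model internals:

```python
import torch
from torch import Tensor, nn


class ToyCompositionModel(nn.Module):
    """Maps element-count features [batch_size, total_num_elements] to energies [batch_size]."""

    def __init__(self, total_num_elements: int = 94) -> None:
        super().__init__()
        # One learnable reference energy per element, as in a simple composition fit.
        self.fc = nn.Linear(total_num_elements, 1, bias=False)

    def _get_energy(self, composition_feas: Tensor) -> Tensor:
        return self.fc(composition_feas).squeeze(-1)  # -> [batch_size]


model = ToyCompositionModel()
composition_feas = torch.rand(8, 94)              # batch of 8 composition encodings
print(model._get_energy(composition_feas).shape)  # torch.Size([8])
```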
12 changes: 6 additions & 6 deletions chgnet/model/dynamics.py
@@ -424,8 +424,8 @@ def __init__(
Default = None
loginterval (int): write to log file every interval steps
Default = 1
crystal_feas_logfile (str): open this file for recording crystal features during MD
Default = None
crystal_feas_logfile (str): open this file for recording crystal features
during MD. Default = None
append_trajectory (bool): Whether to append to prev trajectory.
If false, previous trajectory gets overwritten
Default = False
@@ -541,8 +541,8 @@ def __init__(
bulk_modulus_au = eos.get_bulk_modulus(unit="eV/A^3")
compressibility_au = eos.get_compressibility(unit="A^3/eV")
print(
f"Done bulk modulus calculation: "
f"k = {round(bulk_modulus, 3)}GPa, {round(bulk_modulus_au, 3)}eV/A^3"
f"Completed bulk modulus calculation: "
f"k = {bulk_modulus:.3}GPa, {bulk_modulus_au:.3}eV/A^3"
)
except Exception:
bulk_modulus_au = 2 / 160.2176
@@ -667,8 +667,8 @@ def upper_triangular_cell(self, verbose: bool | None = False):
while ASE's canonical description is lower-triangular cell.
Args:
verbose (bool): Whether to notify user about upper-triangular cell transformation.
Default = False
verbose (bool): Whether to notify user about upper-triangular cell
transformation. Default = False
"""
if not NPT._isuppertriangular(self.atoms.get_cell()):
a, b, c, alpha, beta, gamma = self.atoms.cell.cellpar()
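A small behavioral note on the bulk-modulus hunk: `round(x, 3)` keeps 3 decimal places, whereas the new `{x:.3}` format spec keeps 3 significant digits (switching to scientific notation for large magnitudes). Illustrative values only:

```python
k = 2.0456789  # made-up number, not a real bulk modulus

print(round(k, 3))  # 2.046 -> 3 decimal places (old formatting)
print(f"{k:.3}")    # 2.05  -> 3 significant digits (new formatting)
print(f"{k:.3f}")   # 2.046 -> fixed-point with 3 decimal places
```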
7 changes: 5 additions & 2 deletions chgnet/model/encoders.py
@@ -14,7 +14,8 @@ def __init__(self, atom_feature_dim: int, max_num_elements: int = 94) -> None:
Args:
atom_feature_dim (int): dimension of atomic embedding.
max_num_elements (int): maximum number of elements in the dataset. Default = 94
max_num_elements (int): maximum number of elements in the dataset.
Default = 94
"""
super().__init__()
self.embedding = nn.Embedding(max_num_elements, atom_feature_dim)
@@ -32,7 +33,9 @@ def forward(self, atomic_numbers: Tensor) -> Tensor:


class BondEncoder(nn.Module):
"""Encode a chemical bond given the position of two atoms using Gaussian Distance."""
"""Encode a chemical bond given the positions of two atoms using Gaussian
distance.
"""

def __init__(
self,
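The reworded BondEncoder docstring describes encoding a bond from the two atom positions via a Gaussian expansion of their distance. A generic sketch of that idea; the number of basis functions, the centers, and the width are arbitrary choices, not chgnet's actual values:

```python
import torch


def gaussian_expand(
    dist: torch.Tensor, n_basis: int = 8, r_max: float = 5.0, width: float = 0.5
) -> torch.Tensor:
    """Expand scalar distances [n_bonds] into Gaussian features [n_bonds, n_basis]."""
    centers = torch.linspace(0.0, r_max, n_basis)  # evenly spaced Gaussian centers
    return torch.exp(-((dist[:, None] - centers[None, :]) ** 2) / width**2)


distances = torch.tensor([1.2, 2.5, 4.8])   # made-up bond lengths in Angstrom
print(gaussian_expand(distances).shape)     # torch.Size([3, 8])
```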
10 changes: 6 additions & 4 deletions chgnet/model/model.py
@@ -534,8 +534,8 @@ def predict_structure(
"""Predict from pymatgen.core.Structure.
Args:
structure (Structure | Sequence[Structure]): structure or a list of structures
to predict.
structure (Structure | Sequence[Structure]): structure or a list of
structures to predict.
task (str): can be 'e' 'ef', 'em', 'efs', 'efsm'
Default = "efsm"
return_site_energies (bool): whether to return per-site energies.
@@ -552,7 +552,8 @@ def predict_structure(
e (Tensor) : energy of structures float in eV/atom
f (Tensor) : force on atoms [num_atoms, 3] in eV/A
s (Tensor) : stress of structure [3, 3] in GPa
m (Tensor) : magnetic moments of sites [num_atoms, 3] in Bohr magneton mu_B
m (Tensor) : magnetic moments of sites [num_atoms, 3] in Bohr
magneton mu_B
"""
if self.graph_converter is None:
raise ValueError("graph_converter cannot be None!")
@@ -598,7 +599,8 @@ def predict_graph(
e (Tensor) : energy of structures float in eV/atom
f (Tensor) : force on atoms [num_atoms, 3] in eV/A
s (Tensor) : stress of structure [3, 3] in GPa
m (Tensor) : magnetic moments of sites [num_atoms, 3] in Bohr magneton mu_B
m (Tensor) : magnetic moments of sites [num_atoms, 3] in Bohr
magneton mu_B
"""
if not isinstance(graph, (CrystalGraph, Sequence)):
raise ValueError(
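For context on the predict_structure docstring touched above, a hedged usage sketch; CHGNet.load() and the e/f/s/m result keys are taken from the rest of the library and the docstring, not from this diff, so treat them as assumptions:

```python
from pymatgen.core import Lattice, Structure

from chgnet.model import CHGNet

chgnet = CHGNet.load()  # assumed entry point for the pretrained model (not part of this diff)

# Any pymatgen Structure works; rock-salt NaCl as a stand-in.
structure = Structure(Lattice.cubic(5.69), ["Na", "Cl"], [[0, 0, 0], [0.5, 0.5, 0.5]])

pred = chgnet.predict_structure(structure, task="efsm", return_site_energies=False)
print(pred["e"])  # energy in eV/atom
print(pred["f"])  # forces on atoms in eV/A
print(pred["s"])  # stress in GPa
print(pred["m"])  # magnetic moments in mu_B
```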
7 changes: 4 additions & 3 deletions chgnet/trainer/trainer.py
@@ -63,8 +63,8 @@ def __init__(
Default = 0.1
mag_loss_ratio (float): magmom loss ratio in loss function
Default = 0.1
optimizer (str): optimizer to update model. Can be "Adam", "SGD", "AdamW", "RAdam"
Default = 'Adam'
optimizer (str): optimizer to update model. Can be "Adam", "SGD", "AdamW",
"RAdam". Default = 'Adam'
scheduler (str): learning rate scheduler. Can be "CosLR", "ExponentialLR",
"CosRestartLR". Default = 'CosLR'
criterion (str): loss function criterion. Can be "MSE", "Huber", "MAE"
@@ -216,7 +216,8 @@ def train(
Default = None
save_dir (str): the dir name to save the trained weights
Default = None
save_test_result (bool): whether to save the test set prediction in a json file
save_test_result (bool): Whether to save the test set prediction in a JSON
file. Default = False
train_composition_model (bool): whether to train the composition model
(AtomRef), this is suggested when the fine-tuning dataset has large
elemental energy shift from the pretrained CHGNet, which typically comes
3 changes: 1 addition & 2 deletions pyproject.toml
@@ -15,7 +15,7 @@ dependencies = [
"cython>=0.29.26",
"numpy>=1.21.6",
"nvidia-ml-py3>=7.352.0",
"pymatgen",
"pymatgen>=2023.10.11",
"torch>=1.11.0",
]
classifiers = [
@@ -49,7 +49,6 @@ find = { include = ["chgnet*"], exclude = ["tests", "tests*"] }

[tool.ruff]
target-version = "py39"
line-length = 95
include = ["**/pyproject.toml", "*.ipynb", "*.py", "*.pyi"]
select = [
"B", # flake8-bugbear
5 changes: 2 additions & 3 deletions tests/test_crystal_graph.py
@@ -345,9 +345,8 @@ def test_crystal_graph_stability_fast():
def test_crystal_graph_repr():
graph = converter_legacy(structure)
assert (
repr(graph)
== "CrystalGraph(composition='Li2 Mn2 O4', atom_graph_cutoff=5, bond_graph_cutoff=3, "
"n_atoms=8, atom_graph_len=384, bond_graph_len=744)"
repr(graph) == "CrystalGraph(composition='Li2 Mn2 O4', atom_graph_cutoff=5, "
"bond_graph_cutoff=3, n_atoms=8, atom_graph_len=384, bond_graph_len=744)"
)


3 changes: 2 additions & 1 deletion tests/test_dataset.py
@@ -82,5 +82,6 @@ def test_structure_data_inconsistent_length():

assert (
str(exc.value)
== f"Inconsistent number of structures and labels: {len(structures)=}, {len(forces)=}"
== f"Inconsistent number of structures and labels: {len(structures)=}, "
f"{len(forces)=}"
)
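The reflowed assertion relies on Python's self-documenting f-string specifier, `{expr=}`, which prints both the expression and its value and produces the `len(structures)=...` text in the error message. A quick illustration:

```python
structures = ["s1", "s2", "s3"]
forces = [[0.0, 0.0, 0.0]]

msg = f"Inconsistent number of structures and labels: {len(structures)=}, {len(forces)=}"
print(msg)  # Inconsistent number of structures and labels: len(structures)=3, len(forces)=1
```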
