fix issues with ddp/hydra and add tests #796

Merged: 5 commits, Aug 8, 2024
12 changes: 12 additions & 0 deletions src/fairchem/core/models/equiformer_v2/equiformer_v2.py
@@ -61,6 +61,7 @@ class EquiformerV2(nn.Module, GraphModelMixin):

Args:
use_pbc (bool): Use periodic boundary conditions
use_pbc_single (bool): Process batch PBC graphs one at a time
regress_forces (bool): Compute forces
otf_graph (bool): Compute graph On The Fly (OTF)
max_neighbors (int): Maximum number of neighbors per atom
@@ -683,6 +684,12 @@ def no_weight_decay(self) -> set:

@registry.register_model("equiformer_v2_backbone")
class EquiformerV2Backbone(EquiformerV2, BackboneInterface):
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
# TODO remove these once we deprecate/stop-inheriting EquiformerV2 class
self.energy_block = None

[Collaborator review comment] can you add a comment to get rid of this when we stop inheriting from EquiformerV2?

self.force_block = None

@conditional_grad(torch.enable_grad())
def forward(self, data: Batch) -> dict[str, torch.Tensor]:
self.batch_size = len(data.natoms)
@@ -815,6 +822,7 @@ def forward(self, data: Batch) -> dict[str, torch.Tensor]:
class EquiformerV2EnergyHead(nn.Module, HeadInterface):
def __init__(self, backbone):
super().__init__()

self.avg_num_nodes = backbone.avg_num_nodes
self.energy_block = FeedForwardNetwork(
backbone.sphere_channels,
@@ -828,6 +836,8 @@ def __init__(self, backbone):
backbone.use_grid_mlp,
backbone.use_sep_s2_act,
)
self.apply(backbone._init_weights)
self.apply(backbone._uniform_init_rad_func_linear_weights)

def forward(self, data: Batch, emb: dict[str, torch.Tensor | GraphData]):
node_energy = self.energy_block(emb["node_embedding"])
@@ -871,6 +881,8 @@ def __init__(self, backbone):
backbone.use_sep_s2_act,
alpha_drop=0.0,
)
self.apply(backbone._init_weights)
self.apply(backbone._uniform_init_rad_func_linear_weights)

def forward(self, data: Batch, emb: dict[str, torch.Tensor]):
forces = self.force_block(
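For context, the equiformer_v2 hydra heads above follow one pattern: read hyperparameters off the backbone, build a fresh output block, and re-apply the backbone's weight-init helpers so the new modules are initialized the same way as the rest of the model. The sketch below is only an illustration of that pattern; the class name and the shape of the output block are assumptions, not the fairchem implementation.

import torch
from torch import nn

class ToyEnergyHead(nn.Module):
    """Illustrative head, not fairchem's EquiformerV2EnergyHead."""

    def __init__(self, backbone):
        super().__init__()
        # Reuse settings stored on the backbone instead of re-parsing a config.
        self.avg_num_nodes = backbone.avg_num_nodes
        self.energy_block = nn.Linear(backbone.sphere_channels, 1)
        # Re-run the backbone's init routines so this head's freshly created
        # modules are initialized consistently with the backbone's own blocks.
        self.apply(backbone._init_weights)
        self.apply(backbone._uniform_init_rad_func_linear_weights)

    def forward(self, data, emb: dict[str, torch.Tensor]) -> dict[str, torch.Tensor]:
        node_energy = self.energy_block(emb["node_embedding"])   # (num_nodes, 1)
        energy = node_energy.sum(dim=0) / self.avg_num_nodes     # system-level energy
        return {"energy": energy}
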
4 changes: 2 additions & 2 deletions src/fairchem/core/models/escn/escn.py
@@ -530,7 +530,7 @@ def forward(self, data: Batch) -> dict[str, torch.Tensor]:
class eSCNEnergyHead(nn.Module, HeadInterface):
def __init__(self, backbone):
super().__init__()

backbone.energy_block = None
# Output blocks for energy and forces
self.energy_block = EnergyBlock(
backbone.sphere_channels_all, backbone.num_sphere_samples, backbone.act
@@ -550,7 +550,7 @@ def forward(
class eSCNForceHead(nn.Module, HeadInterface):
def __init__(self, backbone):
super().__init__()

backbone.force_block = None
self.force_block = ForceBlock(
backbone.sphere_channels_all, backbone.num_sphere_samples, backbone.act
)
5 changes: 5 additions & 0 deletions src/fairchem/core/models/gemnet_oc/gemnet_oc.py
@@ -1431,6 +1431,9 @@ def __init__(
self.direct_forces = backbone.direct_forces
self.force_scaler = backbone.force_scaler

backbone.out_mlp_E = None
backbone.out_energy = None

out_mlp_E = [
Dense(
backbone.atom_emb.emb_size * (len(backbone.int_blocks) + 1),
@@ -1495,6 +1498,8 @@ def __init__(

emb_size_edge = backbone.edge_emb.dense.linear.out_features
if self.direct_forces:
backbone.out_mlp_F = None
backbone.out_forces = None
out_mlp_F = [
Dense(
emb_size_edge * (len(backbone.int_blocks) + 1),
3 changes: 2 additions & 1 deletion src/fairchem/core/models/painn/painn.py
@@ -671,7 +671,7 @@ def forward(self, x, v):
class PaiNNEnergyHead(nn.Module, HeadInterface):
def __init__(self, backbone):
super().__init__()

backbone.out_energy = None
self.out_energy = nn.Sequential(
nn.Linear(backbone.hidden_channels, backbone.hidden_channels // 2),
ScaledSiLU(),
@@ -697,6 +697,7 @@ def __init__(self, backbone):
self.direct_forces = backbone.direct_forces

if self.direct_forces:
backbone.out_forces = None
self.out_forces = PaiNNOutput(backbone.hidden_channels)

def forward(
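A likely motivation for the `backbone.energy_block = None`, `backbone.out_mlp_E = None`, and `backbone.out_forces = None` assignments in the eSCN, GemNet-OC, and PaiNN heads above: once a head owns an output block, the duplicate left on the backbone never receives gradients, and DistributedDataParallel fails by default when some parameters are not used in producing the loss. The snippet below is a self-contained illustration of that failure mode with made-up module names; it is not fairchem code.

from torch import nn

class ToyModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.used = nn.Linear(8, 1)
        self.stale = nn.Linear(8, 1)  # analogous to a backbone output block replaced by a head

    def forward(self, x):
        return self.used(x)           # self.stale never contributes to the loss

model = ToyModel()
# Wrapping this model in torch.nn.parallel.DistributedDataParallel across
# multiple ranks raises, roughly, "... your module has parameters that were
# not used in producing loss" unless find_unused_parameters=True is passed.
# Removing the unused module, as the heads above do, avoids the error and the
# wasted memory:
model.stale = None
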
37 changes: 33 additions & 4 deletions tests/core/e2e/test_s2ef.py
@@ -210,7 +210,9 @@ def _run_main(


class TestSmoke:
def smoke_test_train(self, input_yaml, tutorial_val_src, otf_norms=False):
def smoke_test_train(
self, input_yaml, tutorial_val_src, world_size, num_workers, otf_norms=False
):
with tempfile.TemporaryDirectory() as tempdirname:
# first train a very simple model, checkpoint
train_rundir = Path(tempdirname) / "train"
@@ -221,7 +223,12 @@ def smoke_test_train(self, input_yaml, tutorial_val_src, otf_norms=False):
rundir=str(train_rundir),
input_yaml=input_yaml,
update_dict_with={
"optim": {"max_epochs": 2, "eval_every": 8, "batch_size": 5},
"optim": {
"max_epochs": 2,
"eval_every": 8,
"batch_size": 5,
"num_workers": num_workers,
},
"dataset": oc20_lmdb_train_and_val_from_paths(
train_src=str(tutorial_val_src),
val_src=str(tutorial_val_src),
@@ -231,6 +238,7 @@ def smoke_test_train(self, input_yaml, tutorial_val_src, otf_norms=False):
},
save_checkpoint_to=checkpoint_path,
save_predictions_to=training_predictions_filename,
world_size=world_size,
)
assert "train/energy_mae" in acc.Tags()["scalars"]
assert "val/energy_mae" in acc.Tags()["scalars"]
@@ -313,10 +321,21 @@ def test_train_and_predict(
configs,
tutorial_val_src,
):
# test without ddp
self.smoke_test_train(
input_yaml=configs[model_name],
tutorial_val_src=tutorial_val_src,
otf_norms=otf_norms,
world_size=0,
num_workers=2,
)
# test with ddp but no workers
self.smoke_test_train(
input_yaml=configs[model_name],
tutorial_val_src=tutorial_val_src,
otf_norms=otf_norms,
world_size=1,
num_workers=0,
)

def test_use_pbc_single(self, configs, tutorial_val_src, torch_deterministic):
@@ -341,11 +360,21 @@ def test_use_pbc_single(self, configs, tutorial_val_src, torch_deterministic):
@pytest.mark.parametrize(
("world_size", "ddp"),
[
pytest.param(2, True),
pytest.param(
2,
True,
),
pytest.param(0, False),
],
)
def test_ddp(self, world_size, ddp, configs, tutorial_val_src, torch_deterministic):
def test_ddp(
self,
world_size,
ddp,
configs,
tutorial_val_src,
torch_deterministic,
):
with tempfile.TemporaryDirectory() as tempdirname:
tempdir = Path(tempdirname)
extra_args = {"seed": 0}
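The new world_size argument threaded through smoke_test_train above controls whether the run is launched under DDP. One common way such a knob is realized in tests is sketched below; the helper names and setup are assumptions for illustration and do not reflect the repository's _run_main internals.

import os
import torch.distributed as dist
import torch.multiprocessing as mp

def _train_one_rank(rank: int, world_size: int, num_workers: int) -> None:
    os.environ.setdefault("MASTER_ADDR", "127.0.0.1")
    os.environ.setdefault("MASTER_PORT", "29500")
    dist.init_process_group("gloo", rank=rank, world_size=world_size)
    # ... build the trainer with DataLoader(num_workers=num_workers),
    #     wrap the model in DistributedDataParallel, run a few steps ...
    dist.destroy_process_group()

def run_smoke_test(world_size: int, num_workers: int) -> None:
    if world_size == 0:
        # plain single-process training: no process group, no DDP wrapping
        ...
    else:
        # one process per rank; num_workers=0 exercises the DDP path without
        # any DataLoader worker processes in the mix
        mp.spawn(_train_one_rank, args=(world_size, num_workers), nprocs=world_size)
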
41 changes: 30 additions & 11 deletions tests/core/models/test_configs/test_dpp.yml
@@ -1,20 +1,39 @@
trainer: forces

task:
dataset: lmdb
type: regression
metric: mae
primary_metric: forces_mae
labels:
- potential energy
grad_input: atomic forces
train_on_free_atoms: True
eval_on_free_atoms: True
prediction_dtype: float32
outputs:
energy:
shape: 1
level: system
forces:
irrep_dim: 1
level: atom
train_on_free_atoms: True
eval_on_free_atoms: True

loss_functions:
- energy:
fn: mae
coefficient: 2
- forces:
fn: l2mae
coefficient: 100

evaluation_metrics:
metrics:
energy:
- mae
forces:
- mae
- cosine_similarity
- magnitude_error
misc:
- energy_forces_within_threshold
primary_metric: forces_mae

logger:
name: tensorboard


model:
name: dimenetplusplus #_bbwheads
hidden_channels: 4
41 changes: 30 additions & 11 deletions tests/core/models/test_configs/test_dpp_hydra.yml
@@ -1,20 +1,39 @@
trainer: forces

task:
dataset: lmdb
type: regression
metric: mae
primary_metric: forces_mae
labels:
- potential energy
grad_input: atomic forces
train_on_free_atoms: True
eval_on_free_atoms: True
prediction_dtype: float32
outputs:
energy:
shape: 1
level: system
forces:
irrep_dim: 1
level: atom
train_on_free_atoms: True
eval_on_free_atoms: True

loss_functions:
- energy:
fn: mae
coefficient: 2
- forces:
fn: l2mae
coefficient: 100

evaluation_metrics:
metrics:
energy:
- mae
forces:
- mae
- cosine_similarity
- magnitude_error
misc:
- energy_forces_within_threshold
primary_metric: forces_mae

logger:
name: tensorboard


model:
name: hydra
backbone:
63 changes: 33 additions & 30 deletions tests/core/models/test_configs/test_equiformerv2_hydra.yml
@@ -1,7 +1,38 @@


trainer: forces

outputs:
energy:
shape: 1
level: system
forces:
irrep_dim: 1
level: atom
train_on_free_atoms: True
eval_on_free_atoms: True

loss_functions:
- energy:
fn: mae
coefficient: 2
- forces:
fn: l2mae
coefficient: 100

evaluation_metrics:
metrics:
energy:
- mae
forces:
- mae
- cosine_similarity
- magnitude_error
misc:
- energy_forces_within_threshold
primary_metric: forces_mae

logger:
name: tensorboard

model:
name: hydra
backbone:
@@ -53,34 +84,6 @@ model:
forces:
module: equiformer_v2_force_head

dataset:
train:
src: tutorial_dset/s2ef/train_100/
normalize_labels: True
target_mean: -0.7554450631141663
target_std: 2.887317180633545
grad_target_mean: 0.0
grad_target_std: 2.887317180633545
val:
format: lmdb
src: tutorial_dset/s2ef/val_20/

logger:
name: tensorboard

task:
dataset: lmdb
type: regression
metric: mae
primary_metric: forces_mae
labels:
- potential energy
grad_input: atomic forces
train_on_free_atoms: True
eval_on_free_atoms: True
prediction_dtype: float32


optim:
batch_size: 5
eval_batch_size: 2