fix issues with ddp/hydra and add tests (#796)
* fix issues with ddp/hydra and add tests

* remove load balancing for painn tests

* remove painn parameters when using hydra

* update test configs to be consistent with each other
misko authored Aug 8, 2024
1 parent bc1307f commit 44c71b3
Showing 17 changed files with 415 additions and 225 deletions.
12 changes: 12 additions & 0 deletions src/fairchem/core/models/equiformer_v2/equiformer_v2.py
@@ -61,6 +61,7 @@ class EquiformerV2(nn.Module, GraphModelMixin):
     Args:
         use_pbc (bool):         Use periodic boundary conditions
+        use_pbc_single (bool):  Process batch PBC graphs one at a time
         regress_forces (bool):  Compute forces
         otf_graph (bool):       Compute graph On The Fly (OTF)
         max_neighbors (int):    Maximum number of neighbors per atom
@@ -683,6 +684,12 @@ def no_weight_decay(self) -> set:

 @registry.register_model("equiformer_v2_backbone")
 class EquiformerV2Backbone(EquiformerV2, BackboneInterface):
+    def __init__(self, *args, **kwargs):
+        super().__init__(*args, **kwargs)
+        # TODO remove these once we deprecate/stop-inheriting EquiformerV2 class
+        self.energy_block = None
+        self.force_block = None
+
     @conditional_grad(torch.enable_grad())
     def forward(self, data: Batch) -> dict[str, torch.Tensor]:
         self.batch_size = len(data.natoms)
@@ -815,6 +822,7 @@ def forward(self, data: Batch) -> dict[str, torch.Tensor]:
 class EquiformerV2EnergyHead(nn.Module, HeadInterface):
     def __init__(self, backbone):
         super().__init__()
+
         self.avg_num_nodes = backbone.avg_num_nodes
         self.energy_block = FeedForwardNetwork(
             backbone.sphere_channels,
@@ -828,6 +836,8 @@ def __init__(self, backbone):
             backbone.use_grid_mlp,
             backbone.use_sep_s2_act,
         )
+        self.apply(backbone._init_weights)
+        self.apply(backbone._uniform_init_rad_func_linear_weights)

     def forward(self, data: Batch, emb: dict[str, torch.Tensor | GraphData]):
         node_energy = self.energy_block(emb["node_embedding"])
@@ -871,6 +881,8 @@ def __init__(self, backbone):
             backbone.use_sep_s2_act,
             alpha_drop=0.0,
         )
+        self.apply(backbone._init_weights)
+        self.apply(backbone._uniform_init_rad_func_linear_weights)

     def forward(self, data: Batch, emb: dict[str, torch.Tensor]):
         forces = self.force_block(
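Note: the pair of self.apply(...) calls added to each head re-run the backbone's weight initializers over the freshly constructed head modules, so heads come up under the same initialization scheme as the backbone. A minimal sketch of the pattern, using hypothetical TinyBackbone/TinyEnergyHead classes rather than the fairchem ones:

import torch.nn as nn

class TinyBackbone(nn.Module):
    def __init__(self):
        super().__init__()
        self.encoder = nn.Linear(8, 8)
        self.apply(self._init_weights)

    def _init_weights(self, m):
        # a single function applied to every submodule, as in EquiformerV2
        if isinstance(m, nn.Linear):
            nn.init.xavier_uniform_(m.weight)
            nn.init.zeros_(m.bias)

class TinyEnergyHead(nn.Module):
    def __init__(self, backbone):
        super().__init__()
        self.out = nn.Linear(8, 1)
        # reuse the backbone's initializer, as the heads in this diff do
        self.apply(backbone._init_weights)

head = TinyEnergyHead(TinyBackbone())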
4 changes: 2 additions & 2 deletions src/fairchem/core/models/escn/escn.py
@@ -530,7 +530,7 @@ def forward(self, data: Batch) -> dict[str, torch.Tensor]:
 class eSCNEnergyHead(nn.Module, HeadInterface):
     def __init__(self, backbone):
         super().__init__()
-
+        backbone.energy_block = None
         # Output blocks for energy and forces
         self.energy_block = EnergyBlock(
             backbone.sphere_channels_all, backbone.num_sphere_samples, backbone.act
@@ -550,7 +550,7 @@ def forward(
 class eSCNForceHead(nn.Module, HeadInterface):
     def __init__(self, backbone):
         super().__init__()
-
+        backbone.force_block = None
         self.force_block = ForceBlock(
             backbone.sphere_channels_all, backbone.num_sphere_samples, backbone.act
         )
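Note: a plausible reading of why each head now nulls out the matching block on the backbone (consistent with the commit title): DistributedDataParallel by default expects every registered parameter to receive a gradient, so an orphaned, never-called output block left on the backbone would trip its unused-parameter check. Minimal sketch with hypothetical classes, not the fairchem API:

import torch.nn as nn

class Backbone(nn.Module):
    def __init__(self):
        super().__init__()
        self.trunk = nn.Linear(4, 4)
        self.energy_block = nn.Linear(4, 1)  # only used in the non-hydra path

class EnergyHead(nn.Module):
    def __init__(self, backbone):
        super().__init__()
        backbone.energy_block = None  # drop the duplicate, as in the diff
        self.energy_block = nn.Linear(4, 1)

backbone = Backbone()
head = EnergyHead(backbone)
# A DDP wrapper around {backbone, head} now sees no parameters that never
# participate in backward, so it does not need find_unused_parameters=True.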
5 changes: 5 additions & 0 deletions src/fairchem/core/models/gemnet_oc/gemnet_oc.py
@@ -1431,6 +1431,9 @@ def __init__(
         self.direct_forces = backbone.direct_forces
         self.force_scaler = backbone.force_scaler

+        backbone.out_mlp_E = None
+        backbone.out_energy = None
+
         out_mlp_E = [
             Dense(
                 backbone.atom_emb.emb_size * (len(backbone.int_blocks) + 1),
@@ -1495,6 +1498,8 @@ def __init__(

         emb_size_edge = backbone.edge_emb.dense.linear.out_features
         if self.direct_forces:
+            backbone.out_mlp_F = None
+            backbone.out_forces = None
             out_mlp_F = [
                 Dense(
                     emb_size_edge * (len(backbone.int_blocks) + 1),
3 changes: 2 additions & 1 deletion src/fairchem/core/models/painn/painn.py
@@ -671,7 +671,7 @@ def forward(self, x, v):
 class PaiNNEnergyHead(nn.Module, HeadInterface):
     def __init__(self, backbone):
         super().__init__()
-
+        backbone.out_energy = None
         self.out_energy = nn.Sequential(
             nn.Linear(backbone.hidden_channels, backbone.hidden_channels // 2),
             ScaledSiLU(),
@@ -697,6 +697,7 @@ def __init__(self, backbone):
         self.direct_forces = backbone.direct_forces

         if self.direct_forces:
+            backbone.out_forces = None
             self.out_forces = PaiNNOutput(backbone.hidden_channels)

     def forward(
37 changes: 33 additions & 4 deletions tests/core/e2e/test_s2ef.py
@@ -210,7 +210,9 @@ def _run_main(


 class TestSmoke:
-    def smoke_test_train(self, input_yaml, tutorial_val_src, otf_norms=False):
+    def smoke_test_train(
+        self, input_yaml, tutorial_val_src, world_size, num_workers, otf_norms=False
+    ):
         with tempfile.TemporaryDirectory() as tempdirname:
             # first train a very simple model, checkpoint
             train_rundir = Path(tempdirname) / "train"
@@ -221,7 +223,12 @@ def smoke_test_train(self, input_yaml, tutorial_val_src, otf_norms=False):
                 rundir=str(train_rundir),
                 input_yaml=input_yaml,
                 update_dict_with={
-                    "optim": {"max_epochs": 2, "eval_every": 8, "batch_size": 5},
+                    "optim": {
+                        "max_epochs": 2,
+                        "eval_every": 8,
+                        "batch_size": 5,
+                        "num_workers": num_workers,
+                    },
                     "dataset": oc20_lmdb_train_and_val_from_paths(
                         train_src=str(tutorial_val_src),
                         val_src=str(tutorial_val_src),
@@ -231,6 +238,7 @@ def smoke_test_train(self, input_yaml, tutorial_val_src, otf_norms=False):
                 },
                 save_checkpoint_to=checkpoint_path,
                 save_predictions_to=training_predictions_filename,
+                world_size=world_size,
             )
             assert "train/energy_mae" in acc.Tags()["scalars"]
             assert "val/energy_mae" in acc.Tags()["scalars"]
@@ -313,10 +321,21 @@ def test_train_and_predict(
         configs,
         tutorial_val_src,
     ):
+        # test without ddp
         self.smoke_test_train(
             input_yaml=configs[model_name],
             tutorial_val_src=tutorial_val_src,
             otf_norms=otf_norms,
+            world_size=0,
+            num_workers=2,
         )
+        # test with ddp but no workers
+        self.smoke_test_train(
+            input_yaml=configs[model_name],
+            tutorial_val_src=tutorial_val_src,
+            otf_norms=otf_norms,
+            world_size=1,
+            num_workers=0,
+        )

     def test_use_pbc_single(self, configs, tutorial_val_src, torch_deterministic):
@@ -341,11 +360,21 @@ def test_use_pbc_single(self, configs, tutorial_val_src, torch_deterministic):
     @pytest.mark.parametrize(
         ("world_size", "ddp"),
         [
-            pytest.param(2, True),
+            pytest.param(
+                2,
+                True,
+            ),
             pytest.param(0, False),
         ],
     )
-    def test_ddp(self, world_size, ddp, configs, tutorial_val_src, torch_deterministic):
+    def test_ddp(
+        self,
+        world_size,
+        ddp,
+        configs,
+        tutorial_val_src,
+        torch_deterministic,
+    ):
         with tempfile.TemporaryDirectory() as tempdirname:
             tempdir = Path(tempdirname)
             extra_args = {"seed": 0}
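Note: the two smoke_test_train calls cover the combinations the commit title suggests were broken: a plain single-process run (world_size=0) with multiprocess data loading, and a one-rank DDP run (world_size=1) with in-process loading (num_workers=0). The same matrix expressed as a parametrized test, sketched with a hypothetical run_training stand-in for the repo's _run_main helper:

import pytest

def run_training(world_size: int, num_workers: int) -> str:
    # stand-in: a real version would launch training with the given
    # DDP world size and DataLoader worker count
    return f"world_size={world_size}, num_workers={num_workers}"

@pytest.mark.parametrize(
    ("world_size", "num_workers"),
    [
        (0, 2),  # no DDP, loader workers in subprocesses
        (1, 0),  # single-rank DDP, in-process data loading
    ],
)
def test_smoke_matrix(world_size, num_workers):
    assert run_training(world_size, num_workers)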
41 changes: 30 additions & 11 deletions tests/core/models/test_configs/test_dpp.yml
@@ -1,20 +1,39 @@
 trainer: forces

-task:
-  dataset: lmdb
-  type: regression
-  metric: mae
-  primary_metric: forces_mae
-  labels:
-    - potential energy
-  grad_input: atomic forces
-  train_on_free_atoms: True
-  eval_on_free_atoms: True
-  prediction_dtype: float32
+outputs:
+  energy:
+    shape: 1
+    level: system
+  forces:
+    irrep_dim: 1
+    level: atom
+    train_on_free_atoms: True
+    eval_on_free_atoms: True
+
+loss_functions:
+  - energy:
+      fn: mae
+      coefficient: 2
+  - forces:
+      fn: l2mae
+      coefficient: 100
+
+evaluation_metrics:
+  metrics:
+    energy:
+      - mae
+    forces:
+      - mae
+      - cosine_similarity
+      - magnitude_error
+    misc:
+      - energy_forces_within_threshold
+  primary_metric: forces_mae

 logger:
   name: tensorboard


 model:
   name: dimenetplusplus #_bbwheads
   hidden_channels: 4
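Note: the rewritten configs split the legacy task: block into three sections: outputs declares what the model predicts (energy per system, forces per atom), loss_functions attaches a loss type and weight to each output, and evaluation_metrics selects what gets reported, with forces_mae kept as the primary metric. Reading the coefficients above, the training objective should combine as 2 * mae(energy) + 100 * l2mae(forces); a sketch of that weighted sum (my reading, not fairchem's trainer code):

import torch
import torch.nn.functional as F

def l2mae(pred: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
    # mean over atoms of the L2 norm of the per-atom force error
    return (pred - target).norm(dim=-1).mean()

def total_loss(e_pred, e_true, f_pred, f_true):
    # coefficients mirror the YAML: energy 2 (mae), forces 100 (l2mae)
    return 2 * F.l1_loss(e_pred, e_true) + 100 * l2mae(f_pred, f_true)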
41 changes: 30 additions & 11 deletions tests/core/models/test_configs/test_dpp_hydra.yml
@@ -1,20 +1,39 @@
 trainer: forces

-task:
-  dataset: lmdb
-  type: regression
-  metric: mae
-  primary_metric: forces_mae
-  labels:
-    - potential energy
-  grad_input: atomic forces
-  train_on_free_atoms: True
-  eval_on_free_atoms: True
-  prediction_dtype: float32
+outputs:
+  energy:
+    shape: 1
+    level: system
+  forces:
+    irrep_dim: 1
+    level: atom
+    train_on_free_atoms: True
+    eval_on_free_atoms: True
+
+loss_functions:
+  - energy:
+      fn: mae
+      coefficient: 2
+  - forces:
+      fn: l2mae
+      coefficient: 100
+
+evaluation_metrics:
+  metrics:
+    energy:
+      - mae
+    forces:
+      - mae
+      - cosine_similarity
+      - magnitude_error
+    misc:
+      - energy_forces_within_threshold
+  primary_metric: forces_mae

 logger:
   name: tensorboard


 model:
   name: hydra
   backbone:
63 changes: 33 additions & 30 deletions tests/core/models/test_configs/test_equiformerv2_hydra.yml
@@ -1,7 +1,38 @@


 trainer: forces

+outputs:
+  energy:
+    shape: 1
+    level: system
+  forces:
+    irrep_dim: 1
+    level: atom
+    train_on_free_atoms: True
+    eval_on_free_atoms: True
+
+loss_functions:
+  - energy:
+      fn: mae
+      coefficient: 2
+  - forces:
+      fn: l2mae
+      coefficient: 100
+
+evaluation_metrics:
+  metrics:
+    energy:
+      - mae
+    forces:
+      - mae
+      - cosine_similarity
+      - magnitude_error
+    misc:
+      - energy_forces_within_threshold
+  primary_metric: forces_mae
+
+logger:
+  name: tensorboard
+
 model:
   name: hydra
   backbone:
@@ -53,34 +84,6 @@ model:
     forces:
       module: equiformer_v2_force_head

-dataset:
-  train:
-    src: tutorial_dset/s2ef/train_100/
-    normalize_labels: True
-    target_mean: -0.7554450631141663
-    target_std: 2.887317180633545
-    grad_target_mean: 0.0
-    grad_target_std: 2.887317180633545
-  val:
-    format: lmdb
-    src: tutorial_dset/s2ef/val_20/
-
-logger:
-  name: tensorboard
-
-task:
-  dataset: lmdb
-  type: regression
-  metric: mae
-  primary_metric: forces_mae
-  labels:
-    - potential energy
-  grad_input: atomic forces
-  train_on_free_atoms: True
-  eval_on_free_atoms: True
-  prediction_dtype: float32


 optim:
   batch_size: 5
   eval_batch_size: 2
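Note: with the dataset: and task: blocks gone, this config no longer pins dataset paths or normalization statistics; the e2e tests above inject the paths at run time via oc20_lmdb_train_and_val_from_paths and update_dict_with. A sketch of that injection pattern, with a hypothetical with_runtime_datasets helper:

import yaml

def with_runtime_datasets(config_path: str, train_src: str, val_src: str) -> dict:
    # load a model-only YAML config and graft dataset paths onto it,
    # mirroring how the tests build update_dict_with (field names assumed)
    with open(config_path) as f:
        config = yaml.safe_load(f)
    config["dataset"] = {
        "train": {"src": train_src, "format": "lmdb"},
        "val": {"src": val_src, "format": "lmdb"},
    }
    return config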