Skip to content

Commit

Permalink
remove equiv2_backbone_and_heads
Browse files Browse the repository at this point in the history
  • Loading branch information
misko committed Aug 15, 2024
1 parent 57763bc commit 7401629
Show file tree
Hide file tree
Showing 3 changed files with 0 additions and 119 deletions.
14 changes: 0 additions & 14 deletions src/fairchem/core/models/equiformer_v2/equiformer_v2.py
Original file line number Diff line number Diff line change
Expand Up @@ -673,17 +673,3 @@ def forward(self, data: Batch, emb: dict[str, torch.Tensor]):
if gp_utils.initialized():
forces = gp_utils.gather_from_model_parallel_region(forces, dim=0)
return {"forces": forces}


@registry.register_model("equiformer_v2_backbone_and_heads")
class EquiformerV2BackboneAndHeads(nn.Module):
def __init__(self, **kwargs):
super().__init__()
kwargs["model"] = "equiformer_v2_backbone"
heads = {"energy": {"module": "equiformer_v2_energy_head"}}
if "regress_forces" in kwargs and kwargs["regress_forces"]:
heads["forces"] = {"module": "equiformer_v2_force_head"}
self.model = HydraModel(backbone=kwargs, heads=heads)

def forward(self, data: Batch):
return self.model(data)
11 changes: 0 additions & 11 deletions tests/core/e2e/test_s2ef.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,9 +42,6 @@ def configs():
"painn": Path("tests/core/models/test_configs/test_painn.yml"),
"painn_hydra": Path("tests/core/models/test_configs/test_painn_hydra.yml"),
"equiformer_v2": Path("tests/core/models/test_configs/test_equiformerv2.yml"),
"equiformer_v2_backbone_and_heads": Path(
"tests/core/models/test_configs/test_equiformerv2_backbone_and_heads.yml"
),
"equiformer_v2_hydra": Path(
"tests/core/models/test_configs/test_equiformerv2_hydra.yml"
),
Expand Down Expand Up @@ -170,8 +167,6 @@ def smoke_test_train(
("equiformer_v2", True),
("equiformer_v2_hydra", False),
("equiformer_v2_hydra", True),
("equiformer_v2_backbone_and_heads", False),
("equiformer_v2_backbone_and_heads", True),
],
)
def test_train_and_predict(
Expand Down Expand Up @@ -372,12 +367,6 @@ class TestSmallDatasetOptim:
pytest.param("gemnet_oc", 0.41, 0.06, id="gemnet_oc"),
pytest.param("escn", 0.41, 0.06, id="escn"),
pytest.param("equiformer_v2", 0.41, 0.06, id="equiformer_v2"),
pytest.param(
"equiformer_v2_backbone_and_heads",
0.41,
0.06,
id="equiformer_v2_backbone_and_heads",
),
],
)
def test_train_optimization(
Expand Down

This file was deleted.

0 comments on commit 7401629

Please sign in to comment.