diff --git a/src/fairchem/core/models/equiformer_v2/equiformer_v2.py b/src/fairchem/core/models/equiformer_v2/equiformer_v2.py index 44cee6430..f123b3ecd 100644 --- a/src/fairchem/core/models/equiformer_v2/equiformer_v2.py +++ b/src/fairchem/core/models/equiformer_v2/equiformer_v2.py @@ -607,9 +607,9 @@ def no_weight_decay(self) -> set: @registry.register_model("equiformer_v2_energy_head") class EquiformerV2EnergyHead(nn.Module, HeadInterface): - def __init__(self, backbone, agg_fn: str="sum"): + def __init__(self, backbone, reduce: str="sum"): super().__init__() - self.agg_fn = agg_fn + self.reduce = reduce self.avg_num_nodes = backbone.avg_num_nodes self.energy_block = FeedForwardNetwork( backbone.sphere_channels, @@ -635,10 +635,11 @@ def forward(self, data: Batch, emb: dict[str, torch.Tensor | GraphData]): device=node_energy.device, dtype=node_energy.dtype, ) + energy.index_add_(0, data.batch, node_energy.view(-1)) - if self.agg_fn == "sum": + if self.reduce == "sum": return {"energy": energy / self.avg_num_nodes} - elif self.agg_fn == "mean": + elif self.reduce == "mean": return {"energy": energy / data.natoms} else: raise ValueError(f"agg_fn can only be sum or mean, user provided: {self.agg_fn}")