Skip to content

Commit

Permalink
change name to reduce
Browse files Browse the repository at this point in the history
  • Loading branch information
rayg1234 committed Aug 22, 2024
1 parent f78397e commit 37e061e
Showing 1 changed file with 5 additions and 4 deletions.
9 changes: 5 additions & 4 deletions src/fairchem/core/models/equiformer_v2/equiformer_v2.py
Original file line number Diff line number Diff line change
Expand Up @@ -607,9 +607,9 @@ def no_weight_decay(self) -> set:

@registry.register_model("equiformer_v2_energy_head")
class EquiformerV2EnergyHead(nn.Module, HeadInterface):
def __init__(self, backbone, agg_fn: str="sum"):
def __init__(self, backbone, reduce: str="sum"):
super().__init__()
self.agg_fn = agg_fn
self.reduce = reduce
self.avg_num_nodes = backbone.avg_num_nodes
self.energy_block = FeedForwardNetwork(
backbone.sphere_channels,
Expand All @@ -635,10 +635,11 @@ def forward(self, data: Batch, emb: dict[str, torch.Tensor | GraphData]):
device=node_energy.device,
dtype=node_energy.dtype,
)

energy.index_add_(0, data.batch, node_energy.view(-1))
if self.agg_fn == "sum":
if self.reduce == "sum":
return {"energy": energy / self.avg_num_nodes}
elif self.agg_fn == "mean":
elif self.reduce == "mean":
return {"energy": energy / data.natoms}
else:
raise ValueError(f"agg_fn can only be sum or mean, user provided: {self.agg_fn}")
Expand Down

0 comments on commit 37e061e

Please sign in to comment.