diff --git a/src/fairchem/core/models/escn/escn.py b/src/fairchem/core/models/escn/escn.py index 6eb95947a..0b17e34ea 100644 --- a/src/fairchem/core/models/escn/escn.py +++ b/src/fairchem/core/models/escn/escn.py @@ -537,9 +537,10 @@ def forward(self, data: Batch) -> dict[str, torch.Tensor]: @registry.register_model("escn_energy_head") class eSCNEnergyHead(nn.Module, HeadInterface): - def __init__(self, backbone): + def __init__(self, backbone, reduce = "sum"): super().__init__() backbone.energy_block = None + self.reduce = reduce # Output blocks for energy and forces self.energy_block = EnergyBlock( backbone.sphere_channels_all, backbone.num_sphere_samples, backbone.act @@ -551,8 +552,13 @@ def forward( node_energy = self.energy_block(emb["sphere_values"]) energy = torch.zeros(len(data.natoms), device=data.pos.device) energy.index_add_(0, data.batch, node_energy.view(-1)) - # Scale energy to help balance numerical precision w.r.t. forces - return {"energy": energy * 0.001} + if self.reduce == "sum": + # Scale energy to help balance numerical precision w.r.t. forces + return {"energy": energy * 0.001} + elif self.reduce == "mean": + return {"energy": energy / data.natoms} + else: + raise ValueError(f"reduce can only be sum or mean, user provided: {self.reduce}") @registry.register_model("escn_force_head")