diff --git a/src/fairchem/core/models/painn/painn.py b/src/fairchem/core/models/painn/painn.py index 33425e8d8..935ecc5a9 100644 --- a/src/fairchem/core/models/painn/painn.py +++ b/src/fairchem/core/models/painn/painn.py @@ -671,7 +671,7 @@ def forward(self, x, v): class PaiNNEnergyHead(nn.Module, HeadInterface): def __init__(self, backbone): super().__init__() - + backbone.out_energy = None self.out_energy = nn.Sequential( nn.Linear(backbone.hidden_channels, backbone.hidden_channels // 2), ScaledSiLU(), @@ -697,6 +697,7 @@ def __init__(self, backbone): self.direct_forces = backbone.direct_forces if self.direct_forces: + backbone.out_forces = None self.out_forces = PaiNNOutput(backbone.hidden_channels) def forward(