diff --git a/jitterbug/model/soap.py b/jitterbug/model/soap.py index fdf76ac..e0bdda4 100644 --- a/jitterbug/model/soap.py +++ b/jitterbug/model/soap.py @@ -161,7 +161,7 @@ def calculate(self, atoms=None, properties=('energy', 'forces', 'energies'), grad_outputs=torch.ones_like(pred_energy), )[0] # Derivative of the energy with respect to the descriptors for each center d_energy_d_center_d_pos = torch.einsum('ijkl,il->ijk', d_desc_d_pos, d_energy_d_desc) # Derivative for each center with respect to each atom - pred_forces = -d_energy_d_center_d_pos.sum(dim=0) * scale # Total effect on each center from each atom + pred_forces = -d_energy_d_center_d_pos.sum(dim=0) * scale # Total effect on each center from each atom # Store the results self.results['forces'] = pred_forces.detach().numpy()