Skip to content

Commit

Permalink
Add device move in influence_from_factors method in base class TorchI…
Browse files Browse the repository at this point in the history
…nfluenceFunctionModel
  • Loading branch information
schroedk committed May 3, 2024
1 parent a151422 commit 36ea3ba
Showing 1 changed file with 2 additions and 2 deletions.
4 changes: 2 additions & 2 deletions src/pydvl/influence/torch/influence_function_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -303,13 +303,13 @@ def influences_from_factors(
"""
if mode == InfluenceMode.Up:
return (
z_test_factors
z_test_factors.to(self.model_device)
@ self._loss_grad(x.to(self.model_device), y.to(self.model_device)).T
)
elif mode == InfluenceMode.Perturbation:
return torch.einsum(
"ia,j...a->ij...",
z_test_factors,
z_test_factors.to(self.model_device),
self._flat_loss_mixed_grad(
x.to(self.model_device), y.to(self.model_device)
),
Expand Down

0 comments on commit 36ea3ba

Please sign in to comment.