diff --git a/src/pydvl/influence/torch/torch_differentiable.py b/src/pydvl/influence/torch/torch_differentiable.py index 2ce7688f3..099fe9905 100644 --- a/src/pydvl/influence/torch/torch_differentiable.py +++ b/src/pydvl/influence/torch/torch_differentiable.py @@ -568,7 +568,7 @@ def solve_batch_cg( if len(training_data) == 0: raise ValueError("Training dataloader must not be empty.") - total_grad_xy = torch.empty() + total_grad_xy = torch.empty(0) total_points = 0 for x, y in maybe_progress(training_data, progress, desc="Batch Train Gradients"):