From 354774fdd6792116334ec450413aeebd50da0a1b Mon Sep 17 00:00:00 2001 From: Anes Benmerzoug Date: Mon, 23 Oct 2023 16:55:38 +0200 Subject: [PATCH] Fix bug in solve_batch_cg --- src/pydvl/influence/torch/torch_differentiable.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/pydvl/influence/torch/torch_differentiable.py b/src/pydvl/influence/torch/torch_differentiable.py index 2ce7688f3..099fe9905 100644 --- a/src/pydvl/influence/torch/torch_differentiable.py +++ b/src/pydvl/influence/torch/torch_differentiable.py @@ -568,7 +568,7 @@ def solve_batch_cg( if len(training_data) == 0: raise ValueError("Training dataloader must not be empty.") - total_grad_xy = torch.empty() + total_grad_xy = torch.empty(0) total_points = 0 for x, y in maybe_progress(training_data, progress, desc="Batch Train Gradients"):