diff --git a/devel/CHANGELOG/index.html b/devel/CHANGELOG/index.html index 9126e88d5..c47fa0d36 100644 --- a/devel/CHANGELOG/index.html +++ b/devel/CHANGELOG/index.html @@ -2725,6 +2725,8 @@
data_names
in ValuationResult.zeros()
+ PR #443AntitheticPermutationSampler
@@ -2979,11 +2981,11 @@ def compute_influence_factors(
model: TwiceDifferentiable,
training_data: DataLoaderType,
test_data: DataLoaderType,
@@ -3007,26 +3008,27 @@ ) # type:ignore
try:
- # if provided input_data implements __len__, pre-allocate the result tensor to reduce memory consumption
- resulting_shape = (len(test_data), model.num_params) # type:ignore
- rhs = cat_gen(
- test_grads(), resulting_shape, model # type:ignore
- ) # type:ignore
- except Exception as e:
- logger.warning(
- f"Failed to pre-allocate result tensor: {e}\n"
- f"Evaluate all resulting tensor and concatenate"
- )
- rhs = cat(list(test_grads()))
-
- return solve_hvp(
- inversion_method,
- model,
- training_data,
- rhs,
- hessian_perturbation=hessian_perturbation,
- **kwargs,
- )
+ # in case input_data is a torch DataLoader created from a Dataset,
+ # we can pre-allocate the result tensor to reduce memory consumption
+ resulting_shape = (len(test_data.dataset), model.num_params) # type:ignore
+ rhs = cat_gen(
+ test_grads(), resulting_shape, model # type:ignore
+ ) # type:ignore
+ except Exception as e:
+ logger.warning(
+ f"Failed to pre-allocate result tensor: {e}\n"
+ f"Evaluate all resulting tensor and concatenate"
+ )
+ rhs = cat(list(test_grads()))
+
+ return solve_hvp(
+ inversion_method,
+ model,
+ training_data,
+ rhs,
+ hessian_perturbation=hessian_perturbation,
+ **kwargs,
+ )
src/pydvl/influence/general.py