From 919e73f17063aaaa00515a67ec3b3d17338f51bb Mon Sep 17 00:00:00 2001 From: Kristof Schroeder Date: Fri, 3 May 2024 12:16:53 +0200 Subject: [PATCH 1/2] Overwrite `to` method of `CgInfluence`, add `to` method to preconditioners, fix wrong device for indices array in block CG implementation --- .../influence/torch/influence_function_model.py | 9 ++++++++- src/pydvl/influence/torch/pre_conditioner.py | 17 +++++++++++++++++ 2 files changed, 25 insertions(+), 1 deletion(-) diff --git a/src/pydvl/influence/torch/influence_function_model.py b/src/pydvl/influence/torch/influence_function_model.py index 46a5fa16e..b4ec964cc 100644 --- a/src/pydvl/influence/torch/influence_function_model.py +++ b/src/pydvl/influence/torch/influence_function_model.py @@ -706,7 +706,9 @@ def mat_mat(x: torch.Tensor): R = (rhs - mat_mat(X)).T Z = R if self.pre_conditioner is None else self.pre_conditioner.solve(R) P, _, _ = torch.linalg.svd(Z, full_matrices=False) - active_indices = torch.as_tensor(list(range(X.shape[-1])), dtype=torch.long) + active_indices = torch.as_tensor( + list(range(X.shape[-1])), dtype=torch.long, device=self.model_device + ) maxiter = self.maxiter if self.maxiter is not None else len(rhs) * 10 y_norm = torch.linalg.norm(rhs, dim=1) @@ -758,6 +760,11 @@ def mat_mat(x: torch.Tensor): return X.T + def to(self, device: torch.device): + if self.pre_conditioner is not None: + self.pre_conditioner = self.pre_conditioner.to(device) + return super().to(device) + class LissaInfluence(TorchInfluenceFunctionModel): r""" diff --git a/src/pydvl/influence/torch/pre_conditioner.py b/src/pydvl/influence/torch/pre_conditioner.py index 4497d81c2..f42852c2c 100644 --- a/src/pydvl/influence/torch/pre_conditioner.py +++ b/src/pydvl/influence/torch/pre_conditioner.py @@ -1,3 +1,5 @@ +from __future__ import annotations + from abc import ABC, abstractmethod from typing import Callable, Optional @@ -70,6 +72,11 @@ def solve(self, rhs: torch.Tensor): def _solve(self, rhs: torch.Tensor): pass + 
@abstractmethod + def to(self, device: torch.device) -> PreConditioner: + """Implement this to move the (potentially fitted) preconditioner to a + specific device""" + class JacobiPreConditioner(PreConditioner): r""" @@ -141,6 +148,11 @@ def _solve(self, rhs: torch.Tensor): return rhs * inv_diag.unsqueeze(-1) + def to(self, device: torch.device) -> JacobiPreConditioner: + if self._diag is not None: + self._diag = self._diag.to(device) + return self + class NystroemPreConditioner(PreConditioner): r""" @@ -233,3 +245,8 @@ def _solve(self, rhs: torch.Tensor): result = result.squeeze() return result + + def to(self, device: torch.device) -> NystroemPreConditioner: + if self._low_rank_approx is not None: + self._low_rank_approx = self._low_rank_approx.to(device) + return self From 6375afe31d900698bb30ac42a69657d65fbdc31c Mon Sep 17 00:00:00 2001 From: Kristof Schroeder Date: Fri, 3 May 2024 12:24:26 +0200 Subject: [PATCH 2/2] Update CHANGELOG.md --- CHANGELOG.md | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/CHANGELOG.md b/CHANGELOG.md index 52bc910a4..e2d4bf923 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,5 +1,12 @@ # Changelog +## Unreleased + +### Fixed + +- Missing move to device of `preconditioner` in `CgInfluence` implementation + [PR #572](https://github.com/aai-institute/pyDVL/pull/572) + ## 0.9.1 - Bug fixes, logging improvement ### Fixed