Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Overwrite to method of CgInfluence, add to method to precondito… #572

Merged
merged 3 commits into from
May 3, 2024
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 7 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
@@ -1,5 +1,12 @@
# Changelog

## Unreleased

### Fixed

- Missing move to device of `preconditioner` in `CgInfluence` implementation
[PR #572](https://github.com/aai-institute/pyDVL/pull/572)

## 0.9.1 - Bug fixes, logging improvement

### Fixed
Expand Down
9 changes: 8 additions & 1 deletion src/pydvl/influence/torch/influence_function_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -706,7 +706,9 @@ def mat_mat(x: torch.Tensor):
R = (rhs - mat_mat(X)).T
Z = R if self.pre_conditioner is None else self.pre_conditioner.solve(R)
P, _, _ = torch.linalg.svd(Z, full_matrices=False)
active_indices = torch.as_tensor(list(range(X.shape[-1])), dtype=torch.long)
active_indices = torch.as_tensor(
list(range(X.shape[-1])), dtype=torch.long, device=self.model_device
)

maxiter = self.maxiter if self.maxiter is not None else len(rhs) * 10
y_norm = torch.linalg.norm(rhs, dim=1)
Expand Down Expand Up @@ -758,6 +760,11 @@ def mat_mat(x: torch.Tensor):

return X.T

def to(self, device: torch.device):
if self.pre_conditioner is not None:
self.pre_conditioner = self.pre_conditioner.to(device)
return super().to(device)


class LissaInfluence(TorchInfluenceFunctionModel):
r"""
Expand Down
17 changes: 17 additions & 0 deletions src/pydvl/influence/torch/pre_conditioner.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
from __future__ import annotations

from abc import ABC, abstractmethod
from typing import Callable, Optional

Expand Down Expand Up @@ -70,6 +72,11 @@ def solve(self, rhs: torch.Tensor):
def _solve(self, rhs: torch.Tensor):
pass

@abstractmethod
def to(self, device: torch.device) -> PreConditioner:
"""Implement this to move the (potentially fitted) preconditioner to a
specific device"""


class JacobiPreConditioner(PreConditioner):
r"""
Expand Down Expand Up @@ -141,6 +148,11 @@ def _solve(self, rhs: torch.Tensor):

return rhs * inv_diag.unsqueeze(-1)

def to(self, device: torch.device) -> JacobiPreConditioner:
if self._diag is not None:
self._diag = self._diag.to(device)
return self


class NystroemPreConditioner(PreConditioner):
r"""
Expand Down Expand Up @@ -233,3 +245,8 @@ def _solve(self, rhs: torch.Tensor):
result = result.squeeze()

return result

def to(self, device: torch.device) -> NystroemPreConditioner:
if self._low_rank_approx is not None:
self._low_rank_approx = self._low_rank_approx.to(device)
return self
Loading