Skip to content

Commit

Permalink
Merge remote-tracking branch 'origin/develop' into sequential-progres…
Browse files Browse the repository at this point in the history
…s-bar
  • Loading branch information
schroedk committed May 3, 2024
2 parents 0c81e53 + 0fc0553 commit 63da411
Show file tree
Hide file tree
Showing 7 changed files with 171 additions and 55 deletions.
15 changes: 15 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
@@ -1,5 +1,20 @@
# Changelog

## Unreleased

### Added

- Add a `device` fixture for `pytest`, which, depending on availability and
user input (`pytest --with-cuda`), resolves to a cuda device
[PR #574](https://github.com/aai-institute/pyDVL/pull/574)

### Fixed

- Fixed missing move of tensors to model device in `EkfacInfluence`
implementation [PR #570](https://github.com/aai-institute/pyDVL/pull/570)
- Fixed missing move of `preconditioner` to device in `CgInfluence`
implementation [PR #572](https://github.com/aai-institute/pyDVL/pull/572)

## 0.9.1 - Bug fixes, logging improvement

### Fixed
Expand Down
7 changes: 7 additions & 0 deletions CONTRIBUTING.md
Original file line number Diff line number Diff line change
Expand Up @@ -131,6 +131,13 @@ There are a few important arguments:
- `--slow-tests` enables running slow tests. See below for a description
of slow tests.

- `--with-cuda` sets the device fixture in [tests/influence/torch/conftest.py](
tests/influence/torch/conftest.py) to `cuda` if it is available.
Using this fixture within tests, you can run parts of your tests on a `cuda`
device. Be aware that you still have to take care of device usage manually
in each specific test: setting this flag does not result in running all
tests on a GPU.

### Markers

We use a few different markers to differentiate between tests and runs
Expand Down
26 changes: 18 additions & 8 deletions src/pydvl/influence/torch/influence_function_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -312,13 +312,13 @@ def influences_from_factors(
"""
if mode == InfluenceMode.Up:
return (
z_test_factors
z_test_factors.to(self.model_device)
@ self._loss_grad(x.to(self.model_device), y.to(self.model_device)).T
)
elif mode == InfluenceMode.Perturbation:
return torch.einsum(
"ia,j...a->ij...",
z_test_factors,
z_test_factors.to(self.model_device),
self._flat_loss_mixed_grad(
x.to(self.model_device), y.to(self.model_device)
),
Expand Down Expand Up @@ -715,7 +715,9 @@ def mat_mat(x: torch.Tensor):
R = (rhs - mat_mat(X)).T
Z = R if self.pre_conditioner is None else self.pre_conditioner.solve(R)
P, _, _ = torch.linalg.svd(Z, full_matrices=False)
active_indices = torch.as_tensor(list(range(X.shape[-1])), dtype=torch.long)
active_indices = torch.as_tensor(
list(range(X.shape[-1])), dtype=torch.long, device=self.model_device
)

maxiter = self.maxiter if self.maxiter is not None else len(rhs) * 10
y_norm = torch.linalg.norm(rhs, dim=1)
Expand Down Expand Up @@ -767,6 +769,11 @@ def mat_mat(x: torch.Tensor):

return X.T

# Move this influence model to `device`.
# If a pre-conditioner is set, its own `to(device)` is called first and the
# result is stored back on the instance; afterwards the call is forwarded to
# the parent class' `to`, whose return value is returned unchanged.
def to(self, device: torch.device):
if self.pre_conditioner is not None:
self.pre_conditioner = self.pre_conditioner.to(device)
return super().to(device)


class LissaInfluence(TorchInfluenceFunctionModel):
r"""
Expand Down Expand Up @@ -1204,7 +1211,7 @@ def _get_kfac_blocks(
data, disable=not self.progress, desc="K-FAC blocks - batch progress"
):
data_len += x.shape[0]
pred_y = self.model(x)
pred_y = self.model(x.to(self.model_device))
loss = empirical_cross_entropy_loss_fn(pred_y)
loss.backward()

Expand Down Expand Up @@ -1328,7 +1335,7 @@ def _update_diag(
data, disable=not self.progress, desc="Update Diagonal - batch progress"
):
data_len += x.shape[0]
pred_y = self.model(x)
pred_y = self.model(x.to(self.model_device))
loss = empirical_cross_entropy_loss_fn(pred_y)
loss.backward()

Expand Down Expand Up @@ -1535,7 +1542,10 @@ def influences_from_factors_by_layer(
influences = {}
for layer_id, layer_z_test in z_test_factors.items():
end_idx = start_idx + layer_z_test.shape[1]
influences[layer_id] = layer_z_test @ total_grad[:, start_idx:end_idx].T
influences[layer_id] = (
layer_z_test.to(self.model_device)
@ total_grad[:, start_idx:end_idx].T
)
start_idx = end_idx
return influences
elif mode == InfluenceMode.Perturbation:
Expand All @@ -1548,7 +1558,7 @@ def influences_from_factors_by_layer(
end_idx = start_idx + layer_z_test.shape[1]
influences[layer_id] = torch.einsum(
"ia,j...a->ij...",
layer_z_test,
layer_z_test.to(self.model_device),
total_mixed_grad[:, start_idx:end_idx],
)
start_idx = end_idx
Expand Down Expand Up @@ -1635,7 +1645,7 @@ def explore_hessian_regularization(
being dictionaries containing the influences for each layer of the model,
with the layer name as key.
"""
grad = self._loss_grad(x, y)
grad = self._loss_grad(x.to(self.model_device), y.to(self.model_device))
influences_by_reg_value = {}
for reg_value in regularization_values:
reg_factors = self._solve_hvp_by_layer(
Expand Down
17 changes: 17 additions & 0 deletions src/pydvl/influence/torch/pre_conditioner.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
from __future__ import annotations

from abc import ABC, abstractmethod
from typing import Callable, Optional

Expand Down Expand Up @@ -70,6 +72,11 @@ def solve(self, rhs: torch.Tensor):
def _solve(self, rhs: torch.Tensor):
pass

# Abstract hook: every concrete preconditioner has to provide its own `to`.
# It receives the target `torch.device` and is annotated to return a
# `PreConditioner`, so that callers can re-assign the result.
@abstractmethod
def to(self, device: torch.device) -> PreConditioner:
"""Implement this to move the (potentially fitted) preconditioner to a
specific device"""


class JacobiPreConditioner(PreConditioner):
r"""
Expand Down Expand Up @@ -141,6 +148,11 @@ def _solve(self, rhs: torch.Tensor):

return rhs * inv_diag.unsqueeze(-1)

# Move the preconditioner to `device`.
# `_diag` is only touched when it is not None; in that case it is replaced by
# the result of `_diag.to(device)`. The instance itself (not a copy) is
# returned.
def to(self, device: torch.device) -> JacobiPreConditioner:
if self._diag is not None:
self._diag = self._diag.to(device)
return self


class NystroemPreConditioner(PreConditioner):
r"""
Expand Down Expand Up @@ -233,3 +245,8 @@ def _solve(self, rhs: torch.Tensor):
result = result.squeeze()

return result

# Move the preconditioner to `device`.
# `_low_rank_approx` is only touched when it is not None; in that case it is
# replaced by the result of its own `to(device)`. The instance itself (not a
# copy) is returned.
def to(self, device: torch.device) -> NystroemPreConditioner:
if self._low_rank_approx is not None:
self._low_rank_approx = self._low_rank_approx.to(device)
return self
6 changes: 6 additions & 0 deletions tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,12 @@ def pytest_addoption(parser):
default=False,
help="Disable reporting. Verbose mode takes precedence.",
)
parser.addoption(
"--with-cuda",
action="store_true",
default=False,
help="Set device fixture to 'cuda' if available",
)


@pytest.fixture
Expand Down
12 changes: 12 additions & 0 deletions tests/influence/torch/conftest.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from typing import Tuple

import pytest
import torch
from numpy.typing import NDArray
from torch.optim import LBFGS
Expand Down Expand Up @@ -59,3 +60,14 @@ def closure():
def torch_linear_model_to_numpy(model: torch.nn.Linear) -> Tuple[NDArray, NDArray]:
    """Export the parameters of a linear layer as numpy arrays.

    As a side effect, the model is switched to evaluation mode.

    Args:
        model: Linear layer whose weight and bias are exported.

    Returns:
        Tuple `(weight, bias)` of numpy arrays.
    """
    model.eval()
    # `Tensor.numpy()` only works for CPU tensors that do not require grad.
    # Detach and move to the CPU first, so that models living on another
    # device (e.g. cuda) can be converted as well.
    weight = model.weight.detach().cpu().numpy()
    bias = model.bias.detach().cpu().numpy()
    return weight, bias


@pytest.fixture(scope="session")
def device(request):
    """Session-scoped fixture providing the torch device to use in tests.

    Resolves to the cuda device only if pytest was invoked with the
    `--with-cuda` option and cuda is available; otherwise the cpu device is
    returned.
    """
    # torch is already imported at module level, no local import needed.
    use_cuda = request.config.getoption("--with-cuda")
    if use_cuda and torch.cuda.is_available():
        return torch.device("cuda")
    return torch.device("cpu")
Loading

0 comments on commit 63da411

Please sign in to comment.