
Commit

Merge remote-tracking branch 'origin/develop' into fix/571-missing-move-preconditioner-cg
schroedk committed May 3, 2024
2 parents 6375afe + efa56e3 commit ae41cbf
Showing 6 changed files with 139 additions and 54 deletions.
8 changes: 8 additions & 0 deletions CHANGELOG.md
@@ -2,8 +2,16 @@
 
 ## Unreleased
 
+### Added
+
+- Add a device fixture for `pytest` which, depending on availability and
+  user input (`pytest --with-cuda`), resolves to a CUDA device
+  [PR #574](https://github.com/aai-institute/pyDVL/pull/574)
+
 ### Fixed
 
 - Fixed missing move of tensors to model device in `EkfacInfluence`
   implementation [PR #570](https://github.com/aai-institute/pyDVL/pull/570)
+- Fixed missing move of `preconditioner` to the model device in the
+  `CgInfluence` implementation [PR #572](https://github.com/aai-institute/pyDVL/pull/572)
 
7 changes: 7 additions & 0 deletions CONTRIBUTING.md
@@ -131,6 +131,13 @@ There are a few important arguments:
 - `--slow-tests` enables running slow tests. See below for a description
   of slow tests.
 
+- `--with-cuda` sets the device fixture in [tests/influence/torch/conftest.py](
+  tests/influence/torch/conftest.py) to `cuda` if it is available.
+  Using this fixture within tests, you can run parts of your tests on a `cuda`
+  device. Be aware that you still have to handle device placement yourself
+  inside each test: setting this flag does not run all tests on a GPU
+  (see the sketch below).
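For illustration, a test consuming this fixture might look like the following sketch; the test name and toy model are hypothetical, not part of the suite:

```python
import torch


def test_model_runs_on_device(device):
    # `device` is the session-scoped fixture from tests/influence/torch/conftest.py;
    # it resolves to cuda only when `pytest --with-cuda` is passed and CUDA is available.
    model = torch.nn.Linear(10, 2).to(device)  # hypothetical toy model
    x = torch.randn(8, 10, device=device)
    assert model(x).device.type == device.type
```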

### Markers

We use a few different markers to differentiate between tests and runs
17 changes: 10 additions & 7 deletions src/pydvl/influence/torch/influence_function_model.py
@@ -303,13 +303,13 @@ def influences_from_factors(
"""
if mode == InfluenceMode.Up:
return (
z_test_factors
z_test_factors.to(self.model_device)
@ self._loss_grad(x.to(self.model_device), y.to(self.model_device)).T
)
elif mode == InfluenceMode.Perturbation:
return torch.einsum(
"ia,j...a->ij...",
z_test_factors,
z_test_factors.to(self.model_device),
self._flat_loss_mixed_grad(
x.to(self.model_device), y.to(self.model_device)
),
@@ -1202,7 +1202,7 @@ def _get_kfac_blocks(
             data, disable=not self.progress, desc="K-FAC blocks - batch progress"
         ):
             data_len += x.shape[0]
-            pred_y = self.model(x)
+            pred_y = self.model(x.to(self.model_device))
             loss = empirical_cross_entropy_loss_fn(pred_y)
             loss.backward()

@@ -1326,7 +1326,7 @@ def _update_diag(
             data, disable=not self.progress, desc="Update Diagonal - batch progress"
         ):
             data_len += x.shape[0]
-            pred_y = self.model(x)
+            pred_y = self.model(x.to(self.model_device))
             loss = empirical_cross_entropy_loss_fn(pred_y)
             loss.backward()

@@ -1533,7 +1533,10 @@ def influences_from_factors_by_layer(
             influences = {}
             for layer_id, layer_z_test in z_test_factors.items():
                 end_idx = start_idx + layer_z_test.shape[1]
-                influences[layer_id] = layer_z_test @ total_grad[:, start_idx:end_idx].T
+                influences[layer_id] = (
+                    layer_z_test.to(self.model_device)
+                    @ total_grad[:, start_idx:end_idx].T
+                )
                 start_idx = end_idx
             return influences
         elif mode == InfluenceMode.Perturbation:
@@ -1546,7 +1549,7 @@
                 end_idx = start_idx + layer_z_test.shape[1]
                 influences[layer_id] = torch.einsum(
                     "ia,j...a->ij...",
-                    layer_z_test,
+                    layer_z_test.to(self.model_device),
                     total_mixed_grad[:, start_idx:end_idx],
                 )
                 start_idx = end_idx
@@ -1633,7 +1636,7 @@ def explore_hessian_regularization(
             being dictionaries containing the influences for each layer of the model,
             with the layer name as key.
         """
-        grad = self._loss_grad(x, y)
+        grad = self._loss_grad(x.to(self.model_device), y.to(self.model_device))
         influences_by_reg_value = {}
         for reg_value in regularization_values:
             reg_factors = self._solve_hvp_by_layer(
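The hunks above all apply the same fix: user-supplied tensors are moved onto the model's device before they touch model parameters. A minimal standalone sketch of that pattern, with an illustrative helper name that is not part of the pyDVL API:

```python
import torch


def to_model_device(tensor: torch.Tensor, model: torch.nn.Module) -> torch.Tensor:
    # Mirror the fix applied throughout influence_function_model.py:
    # place the tensor on the same device as the model's parameters.
    device = next(model.parameters()).device
    return tensor.to(device)
```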
6 changes: 6 additions & 0 deletions tests/conftest.py
@@ -48,6 +48,12 @@ def pytest_addoption(parser):
         default=False,
         help="Disable reporting. Verbose mode takes precedence.",
     )
+    parser.addoption(
+        "--with-cuda",
+        action="store_true",
+        default=False,
+        help="Set device fixture to 'cuda' if available",
+    )
 
 
 @pytest.fixture
12 changes: 12 additions & 0 deletions tests/influence/torch/conftest.py
@@ -1,5 +1,6 @@
 from typing import Tuple
 
+import pytest
 import torch
 from numpy.typing import NDArray
 from torch.optim import LBFGS
@@ -59,3 +60,14 @@ def closure():
 def torch_linear_model_to_numpy(model: torch.nn.Linear) -> Tuple[NDArray, NDArray]:
     model.eval()
     return model.weight.data.numpy(), model.bias.data.numpy()
+
+
+@pytest.fixture(scope="session")
+def device(request):
+    import torch
+
+    use_cuda = request.config.getoption("--with-cuda")
+    if use_cuda and torch.cuda.is_available():
+        return torch.device("cuda")
+    else:
+        return torch.device("cpu")