Merge pull request #598 from aai-institute/feature/588-block-arnoldi
Feature/588 block arnoldi
schroedk authored Jun 20, 2024
2 parents d13af35 + 4f18ddd commit 361f5b5
Showing 10 changed files with 383 additions and 382 deletions.
11 changes: 11 additions & 0 deletions CHANGELOG.md
@@ -19,6 +19,9 @@
- Extend `NystroemSketchInfluence` with block-diagonal and Gauss-Newton
approximation
[PR #596](https://github.com/aai-institute/pyDVL/pull/596)
- Extend `ArnoldiInfluence` with block-diagonal and Gauss-Newton
approximation
[PR #598](https://github.com/aai-institute/pyDVL/pull/598)

## Fixed
- Replace `np.float_` with `np.float64` and `np.alltrue` with `np.all`,
@@ -42,6 +45,14 @@
to `regularization` and change the type annotation to allow
for block-wise regularization parameters
[PR #596](https://github.com/aai-institute/pyDVL/pull/596)
- Rename parameters of `ArnoldiInfluence`: `hessian_regularization` to
  `regularization` (with an updated type annotation) and `rank_estimate`
  to `rank`; see the migration sketch below
  [PR #598](https://github.com/aai-institute/pyDVL/pull/598)
- Remove obsolete functions `lanczos_low_rank_hessian_approximation` and
  `model_hessian_low_rank` from `influence.torch.functional`
  [PR #598](https://github.com/aai-institute/pyDVL/pull/598)
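
A minimal migration sketch for the rename (illustrative values; `model` and
`loss` stand for any supported torch model and loss function):

```python
from pydvl.influence.torch import ArnoldiInfluence

# up to pyDVL 0.9.2 (parameters now renamed):
# if_model = ArnoldiInfluence(
#     model, loss, hessian_regularization=0.0, rank_estimate=10
# )

# from this release on:
if_model = ArnoldiInfluence(model, loss, regularization=0.0, rank=10)
```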

## 0.9.2 - 🏗 Bug fixes, logging improvement

10 changes: 8 additions & 2 deletions docs/influence/influence_function_model.md
@@ -118,12 +118,18 @@ from pydvl.influence.torch import ArnoldiInfluence
if_model = ArnoldiInfluence(
    model,
    loss,
-    hessian_regularization=0.0,
-    rank_estimate=10,
+    regularization=0.0,
+    rank=10,
    tol=1e-6,
+    block_structure=BlockMode.FULL,
+    second_order_mode=SecondOrderMode.HESSIAN
)
if_model.fit(train_loader)
```
This implementation supports a block-diagonal approximation, see
[Block-diagonal approximation](#block-diagonal-approximation), and the
[Gauss-Newton approximation](#gauss-newton-approximation).
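
For illustration, a sketch combining the two (assuming `BlockMode` and
`SecondOrderMode` are importable from `pydvl.influence.torch.util` and expose
`LAYER_WISE` and `GAUSS_NEWTON` members; `model`, `loss` and `train_loader`
as above):

```python
from pydvl.influence.torch import ArnoldiInfluence
from pydvl.influence.torch.util import BlockMode, SecondOrderMode  # assumed path

if_model = ArnoldiInfluence(
    model,
    loss,
    regularization=0.0,
    rank=10,
    tol=1e-6,
    block_structure=BlockMode.LAYER_WISE,  # assumed member: one block per layer
    second_order_mode=SecondOrderMode.GAUSS_NEWTON,  # Gauss-Newton instead of Hessian
)
if_model.fit(train_loader)
```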

### Eigenvalue Corrected K-FAC

66 changes: 33 additions & 33 deletions notebooks/influence_wine.ipynb

Large diffs are not rendered by default.

98 changes: 94 additions & 4 deletions src/pydvl/influence/torch/base.py
@@ -3,7 +3,18 @@
from abc import ABC, abstractmethod
from collections import OrderedDict
from dataclasses import dataclass
-from typing import Dict, Iterable, List, Optional, Tuple, TypeVar, Union, cast
+from typing import (
+    TYPE_CHECKING,
+    Dict,
+    Generic,
+    Iterable,
+    List,
+    Optional,
+    Tuple,
+    TypeVar,
+    Union,
+    cast,
+)

import torch
from torch.func import functional_call
@@ -12,11 +23,14 @@
from ..base_influence_function_model import ComposableInfluence
from ..types import (
    Batch,
    BatchType,
    BilinearForm,
    BlockMapper,
    GradientProvider,
    GradientProviderType,
    Operator,
    OperatorGradientComposition,
    TensorType,
)
from .util import (
    BlockMode,
@@ -27,6 +41,9 @@
    flatten_dimensions,
)

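# Import used only for static type annotations; skipped at runtime.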
if TYPE_CHECKING:
    from .operator import LowRankOperator


@dataclass(frozen=True)
class TorchBatch(Batch):
@@ -244,7 +261,7 @@ def flat_mixed_grads(self, batch: TorchBatch) -> torch.Tensor:


class OperatorBilinearForm(
-    BilinearForm[torch.Tensor, TorchBatch, TorchGradientProvider]
+    BilinearForm[torch.Tensor, TorchBatch, TorchGradientProvider],
):
r"""
Base class for bilinear forms based on an instance of
@@ -257,7 +274,7 @@ class OperatorBilinearForm(

    def __init__(
        self,
-        operator: "TensorOperator",
+        operator: TorchOperatorType,
    ):
        self.operator = operator
@@ -406,6 +423,75 @@ def _aggregate_grads(left: torch.Tensor, right: torch.Tensor):
return torch.einsum("i..., j... -> ij", left, right)


class LowRankBilinearForm(OperatorBilinearForm):
    r"""
    Specialized bilinear form for operators of the type

    $$ \operatorname{Op}(b) = V D^{-1}V^Tb.$$

    It computes expressions of the form

    $$ \langle \operatorname{Op}(\nabla_{\theta} \ell(z, \theta)),
        \nabla_{\theta} \ell(z^{\prime}, \theta) \rangle =
        \langle V^T\nabla_{\theta} \ell(z, \theta),
        D^{-1}V^T\nabla_{\theta} \ell(z^{\prime}, \theta) \rangle$$

    in an efficient way using [torch.autograd][torch.autograd] functionality.
    """

    def __init__(self, operator: "LowRankOperator"):
        super().__init__(operator)

    def grads_inner_prod(
        self,
        left: TorchBatch,
        right: Optional[TorchBatch],
        gradient_provider: TorchGradientProvider,
    ) -> torch.Tensor:
        r"""
        Computes the gradient inner product of two batches of data, i.e.

        $$ \langle \nabla_{\omega}\ell(\omega, \text{left.x}, \text{left.y}),
            \nabla_{\omega}\ell(\omega, \text{right.x}, \text{right.y}) \rangle_{B}$$

        where $\nabla_{\omega}\ell(\omega, \cdot, \cdot)$ is represented by the
        `gradient_provider` and the expression must be understood sample-wise.

        Args:
            left: The first batch for gradient and inner product computation
            right: The second batch for gradient and inner product computation,
                optional; if not provided, the inner product will use the gradient
                computed for `left` for both arguments.
            gradient_provider: The gradient provider to compute the gradients.

        Returns:
            A tensor representing the inner products of the per-sample gradients
        """
        op = cast("LowRankOperator", self.operator)

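        # If the operator is configured for exact application, fall back to the
        # generic implementation instead of the low-rank shortcut below.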
        if op.exact:
            return super().grads_inner_prod(left, right, gradient_provider)

        V = op.low_rank_representation.projections
        D = op.low_rank_representation.eigen_vals.clone()
        regularization = op.regularization

        if regularization is not None:
            D += regularization

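        # Project the per-sample gradients onto the low-rank basis: row i of the
        # result holds <v_i, grad_j> for the i-th column v_i of V and sample j.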
        V_left = gradient_provider.jacobian_prod(left, V.t())
        D_inv = 1.0 / D

        if right is None:
            V_right = V_left
        else:
            V_right = gradient_provider.jacobian_prod(right, V.t())

        V_right = V_right * D_inv.unsqueeze(-1)

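        # Contract over the rank dimension: entry (j, k) is
        # <V^T grad_j, D^{-1} V^T grad_k>.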
return torch.einsum("ij, ik -> jk", V_left, V_right)


OperatorBilinearFormType = TypeVar(
    "OperatorBilinearFormType", bound=OperatorBilinearForm
)
@@ -653,7 +739,11 @@ def block_names(self) -> List[str]:

    @property
    def n_parameters(self):
-        return sum(block.op.input_size for _, block in self.block_mapper.items())
+        return sum(
+            param.numel()
+            for block in self.parameter_dict.values()
+            for param in block.values()
+        )

    @abstractmethod
    def with_regularization(