Merge pull request #596 from aai-institute/feature/590-nystroem-block-diagonal

Feature/590 nystroem block diagonal
schroedk authored Jun 14, 2024
2 parents 989f9e8 + cbc09cb commit debb822
Showing 12 changed files with 294 additions and 102 deletions.
7 changes: 7 additions & 0 deletions CHANGELOG.md
@@ -16,6 +16,9 @@
[PR #591](https://github.com/aai-institute/pyDVL/pull/591)
- Extend `LissaInfluence` with block-diagonal and Gauss-Newton approximation
[PR #593](https://github.com/aai-institute/pyDVL/pull/593)
- Extend `NystroemSketchInfluence` with block-diagonal and Gauss-Newton
approximation
[PR #596](https://github.com/aai-institute/pyDVL/pull/596)

## Changed

@@ -30,6 +33,10 @@
[PR #593](https://github.com/aai-institute/pyDVL/pull/593)
- Remove parameter `h0` from init of `LissaInfluence`
[PR #593](https://github.com/aai-institute/pyDVL/pull/593)
- Rename parameter `hessian_regularization` of `NystroemSketchInfluence`
to `regularization` and change the type annotation to allow
for block-wise regularization parameters
[PR #596](https://github.com/aai-institute/pyDVL/pull/596)

## 0.9.2 - 🏗 Bug fixes, logging improvement

4 changes: 4 additions & 0 deletions docs/getting-started/methods.md
@@ -63,3 +63,7 @@ We currently implement the following methods:

- [**Nyström Influence**][pydvl.influence.torch.NystroemSketchInfluence], based
on the ideas in [@hataya_nystrom_2023] for bi-level optimization.

- [**Inverse-harmonic-mean Influence**][pydvl.influence.torch.InverseHarmonicMeanInfluence]
[@kwon_datainf_2023].

8 changes: 7 additions & 1 deletion docs/influence/influence_function_model.md
@@ -206,10 +206,16 @@ if_model = NystroemSketchInfluence(
model,
loss,
rank=10,
hessian_regularization=0.0,
regularization=0.0,
block_structure=BlockMode.FULL,
second_order_mode=SecondOrderMode.HESSIAN
)
if_model.fit(train_loader)
```
This implementation supports a block-diagonal approximation, see
[Block-diagonal approximation](#block-diagonal-approximation), as well as the
[Gauss-Newton approximation](#gauss-newton-approximation); both are sketched
in the example below.
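
A hedged sketch of these options, following the snippet above with the same
implicit imports and definitions (`BlockMode.LAYER_WISE` is assumed to be
available alongside `BlockMode.FULL`; under a block-wise structure,
`regularization` may also be a dict keyed by block name, and the names in the
comment are hypothetical):

```python
if_model = NystroemSketchInfluence(
    model,
    loss,
    rank=10,
    # Alternatively a per-block dict, e.g. under BlockMode.LAYER_WISE
    # (hypothetical keys): regularization={"linear1": 0.1, "linear2": 0.2},
    regularization=0.1,
    block_structure=BlockMode.LAYER_WISE,
    second_order_mode=SecondOrderMode.GAUSS_NEWTON,
)
if_model.fit(train_loader)
```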

### Inverse Harmonic Mean

Expand Down
64 changes: 32 additions & 32 deletions notebooks/influence_wine.ipynb

Large diffs are not rendered by default.

2 changes: 1 addition & 1 deletion src/pydvl/influence/torch/base.py
@@ -443,7 +443,7 @@ def _validate_tensor_input(self, tensor: torch.Tensor) -> None:

def _apply(self, tensor: torch.Tensor) -> torch.Tensor:

if tensor.ndim == 2 and tensor.shape[0] > 1:
if tensor.ndim == 2:
return self._apply_to_mat(tensor.to(self.device))

return self._apply_to_vec(tensor.to(self.device))
2 changes: 1 addition & 1 deletion src/pydvl/influence/torch/batch_operation.py
@@ -136,7 +136,7 @@ def apply(self, batch: TorchBatch, tensor: torch.Tensor):
"property `input_size`."
)

if tensor.ndim == 2 and tensor.shape[0] > 1:
if tensor.ndim == 2:
return self._apply_to_mat(batch.to(self.device), tensor.to(self.device))
return self._apply_to_vec(batch.to(self.device), tensor.to(self.device))

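
Both hunks above (`base.py` and `batch_operation.py`) relax the same dispatch
guard: a 2-D tensor with a single row, i.e. of shape `(1, d)`, previously fell
through to the vector path and is now routed to the matrix path. A minimal toy
sketch of the new rule (illustration only, not the pyDVL API):

```python
import torch

def dispatch(tensor: torch.Tensor) -> str:
    # Any 2-D tensor now takes the matrix path, including single-row
    # matrices of shape (1, d); only 1-D tensors take the vector path.
    if tensor.ndim == 2:
        return "matrix path"
    return "vector path"

assert dispatch(torch.zeros(5)) == "vector path"     # 1-D vector
assert dispatch(torch.zeros(1, 5)) == "matrix path"  # (1, d): previously vector path
assert dispatch(torch.zeros(3, 5)) == "matrix path"
```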
51 changes: 48 additions & 3 deletions src/pydvl/influence/torch/functional.py
@@ -28,7 +28,7 @@
import logging
from dataclasses import dataclass
from functools import partial
from typing import Callable, Dict, Optional, Tuple, Union
from typing import TYPE_CHECKING, Callable, Dict, Optional, Tuple, Union

import torch
from scipy.sparse.linalg import ArpackNoConvergence
@@ -37,15 +37,16 @@
from torch.utils.data import DataLoader

from .util import (
BlockMode,
ModelParameterDictBuilder,
align_structure,
align_with_model,
flatten_dimensions,
get_model_parameters,
to_model_device,
)

if TYPE_CHECKING:
from .base import TensorOperator

__all__ = [
"create_hvp_function",
"hessian",
@@ -1048,3 +1049,47 @@ def model_hessian_mat_mat_prod(x: torch.Tensor):
shift_func=shift_func,
mat_vec_device=device,
)


def operator_nystroem_approximation(
operator: "TensorOperator",
rank: int,
shift_func: Optional[Callable[[torch.Tensor], torch.Tensor]] = None,
):
r"""
Given an operator (representing a symmetric positive definite
matrix $A$ ), computes a random Nyström low rank approximation of
$A$ in factored form, i.e.
$$ A_{\text{nys}} = (A \Omega)(\Omega^T A \Omega)^{\dagger}(A \Omega)^T
= U \Sigma U^T $$
where $\Omega$ is a standard normal random matrix.
Args:
operator: the operator to approximate
rank: rank of the approximation
shift_func: optional function for computing the stabilizing shift in the
construction of the randomized nystroem approximation, defaults to
$$ \sqrt{\operatorname{\text{input_dim}}} \cdot
\varepsilon(\operatorname{\text{input_type}}) \cdot \|A\Omega\|_2,$$
where $\varepsilon(\operatorname{\text{input_type}})$ is the value of the
machine precision corresponding to the data type.
Returns:
object containing, $U$ and $\Sigma$
"""

def mat_mat_prod(x: torch.Tensor):
return operator.apply(x.t()).t()

return randomized_nystroem_approximation(
mat_mat_prod,
operator.input_size,
rank,
operator.dtype,
shift_func=shift_func,
mat_vec_device=operator.device,
)
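
A hedged usage sketch of the new function (assumptions: `HessianOperator` can
be constructed as in `_create_block` further down in this diff, with
`restrict_to` defaulting to all parameters; `model`, `loss` and `train_loader`
are placeholders):

```python
from pydvl.influence.torch.functional import operator_nystroem_approximation
from pydvl.influence.torch.operator import HessianOperator

# Sketch the model Hessian to rank 10. The result holds the factors of
# A_nys = U Sigma U^T: `projections` for U and `eigen_vals` for diag(Sigma).
op = HessianOperator(model, loss, train_loader)
low_rank = operator_nystroem_approximation(op, rank=10)
print(low_rank.projections.shape)  # (input_size, 10)
print(low_rank.eigen_vals.shape)   # (10,)
```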
89 changes: 54 additions & 35 deletions src/pydvl/influence/torch/influence_function_model.py
@@ -44,8 +44,16 @@
hessian,
model_hessian_low_rank,
model_hessian_nystroem_approximation,
operator_nystroem_approximation,
)
from .operator import (
DirectSolveOperator,
GaussNewtonOperator,
HessianOperator,
InverseHarmonicMeanOperator,
LissaOperator,
LowRankOperator,
)
from .operator import DirectSolveOperator, InverseHarmonicMeanOperator, LissaOperator
from .pre_conditioner import PreConditioner
from .util import (
BlockMode,
@@ -1679,7 +1687,7 @@ def to(self, device: torch.device):
return super().to(device)


class NystroemSketchInfluence(TorchInfluenceFunctionModel):
class NystroemSketchInfluence(TorchComposableInfluence[LowRankOperator]):
r"""
Given a model and training data, it uses a low-rank approximation of the Hessian
(derived via random projection Nyström approximation) in combination with
@@ -1703,58 +1711,69 @@ class NystroemSketchInfluence(TorchInfluenceFunctionModel):
this model's parameters.
loss: A callable that takes the model's output and target as input and returns
the scalar loss.
hessian_regularization: Optional regularization parameter added
regularization: Optional regularization parameter added
to the Hessian-vector product for numerical stability.
rank: rank of the low-rank approximation
"""

low_rank_representation: LowRankProductRepresentation

def __init__(
self,
model: torch.nn.Module,
loss: Callable[[torch.Tensor, torch.Tensor], torch.Tensor],
hessian_regularization: float,
regularization: Union[float, Dict[str, float]],
rank: int,
block_structure: Union[BlockMode, OrderedDict[str, List[str]]] = BlockMode.FULL,
second_order_mode: SecondOrderMode = SecondOrderMode.HESSIAN,
):
super().__init__(model, loss)
self.hessian_regularization = hessian_regularization
super().__init__(
model,
block_structure,
regularization=cast(
Union[float, Dict[str, Optional[float]]], regularization
),
)
self.second_order_mode = second_order_mode
self.rank = rank
self.loss = loss

    def _solve_hvp(self, rhs: torch.Tensor) -> torch.Tensor:
        regularized_eigenvalues = (
            self.low_rank_representation.eigen_vals + self.hessian_regularization
        )

        proj_rhs = self.low_rank_representation.projections.t() @ rhs.t()
        inverse_regularized_eigenvalues = 1.0 / regularized_eigenvalues
        result = self.low_rank_representation.projections @ (
            proj_rhs * inverse_regularized_eigenvalues.unsqueeze(-1)
        )

        if self.hessian_regularization > 0.0:
            result += (
                1.0
                / self.hessian_regularization
                * (rhs.t() - self.low_rank_representation.projections @ proj_rhs)
            )

        return result.t()

    def with_regularization(
        self, regularization: Union[float, Dict[str, Optional[float]]]
    ) -> TorchComposableInfluence:
        self._regularization_dict = self._build_regularization_dict(regularization)
        for k, reg in self._regularization_dict.items():
            self.block_mapper.composable_block_dict[k].op.regularization = reg
        return self

    def _create_block(
        self,
        block_params: Dict[str, torch.nn.Parameter],
        data: DataLoader,
        regularization: Optional[float],
    ) -> TorchOperatorGradientComposition:

        assert regularization is not None
        regularization = cast(float, regularization)

        op: Union[HessianOperator, GaussNewtonOperator]

        if self.second_order_mode is SecondOrderMode.HESSIAN:
            op = HessianOperator(self.model, self.loss, data, restrict_to=block_params)
        elif self.second_order_mode is SecondOrderMode.GAUSS_NEWTON:
            op = GaussNewtonOperator(
                self.model, self.loss, data, restrict_to=block_params
            )
        else:
            raise ValueError(f"Unsupported second order mode: {self.second_order_mode}")

        low_rank_repr = operator_nystroem_approximation(op, self.rank)
        low_rank_op = LowRankOperator(low_rank_repr, regularization)
        gp = TorchGradientProvider(self.model, self.loss, restrict_to=block_params)
        return TorchOperatorGradientComposition(low_rank_op, gp)

@property
def is_fitted(self):
try:
return self.low_rank_representation is not None
except AttributeError:
return False

@log_duration(log_level=logging.INFO)
def fit(self, data: DataLoader):
self.low_rank_representation = model_hessian_nystroem_approximation(
self.model, self.loss, data, self.rank
)
return self
def is_thread_safe(self) -> bool:
return False
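
As a hedged usage note: after fitting, `with_regularization` updates the
regularization of the already computed sketch without refitting. The block
names in the dict variant are hypothetical and must match the chosen
`block_structure`:

```python
# Scalar update, applied to every block:
if_model = if_model.with_regularization(0.5)

# Per-block update, e.g. under BlockMode.LAYER_WISE (hypothetical keys):
if_model = if_model.with_regularization({"linear1": 0.1, "linear2": 0.2})
```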


class InverseHarmonicMeanInfluence(
94 changes: 94 additions & 0 deletions src/pydvl/influence/torch/operator.py
@@ -16,6 +16,7 @@
PointAveraging,
TensorAveragingType,
)
from .functional import LowRankProductRepresentation

logger = logging.getLogger(__name__)

@@ -471,3 +472,96 @@ def _apply_to_mat(self, mat: torch.Tensor) -> torch.Tensor:
@property
def input_size(self) -> int:
return self.batch_operation.input_size


class LowRankOperator(TensorOperator):
r"""
Given a low rank representation of a matrix
$$ A = V D V^T$$
with a diagonal matrix $D$ and an optional regularization parameter $\lambda$,
computes
$$ (V D V^T+\lambda I)^{-1}b$$.
Depending on the value of the `exact` flag, the inverse action is computed exactly
using the [Sherman–Morrison–Woodbury formula]
(https://en.wikipedia.org/wiki/Woodbury_matrix_identity). If `exact` is set to
`False`, the inverse action is approximated by
$$ V^T(D+\lambda I)^{-1}Vb$$
Args:
"""

def __init__(
self,
low_rank_representation: LowRankProductRepresentation,
regularization: float,
exact: bool = True,
):

if exact and (regularization is None or regularization <= 0):
raise ValueError("regularization must be positive when exact=True")
elif regularization is not None and regularization < 0:
raise ValueError("regularization must be non-negative")

self._regularization = regularization
self._exact = exact
self._low_rank_representation = low_rank_representation

@property
def exact(self):
return self._exact

@property
def regularization(self):
return self._regularization

@regularization.setter
def regularization(self, value: float):
if value < 0:
raise ValueError("regularization must be non-negative")
self._regularization = value

@property
def device(self):
return self._low_rank_representation.device

@property
def dtype(self):
return self._low_rank_representation.dtype

def to(self, device: torch.device):
self._low_rank_representation = self._low_rank_representation.to(device)
return self

def _apply_to_vec(self, vec: torch.Tensor) -> torch.Tensor:

if vec.ndim == 1:
return self._apply_to_mat(vec.unsqueeze(0)).squeeze()

return self._apply_to_mat(vec)

def _apply_to_mat(self, mat: torch.Tensor) -> torch.Tensor:

D = self._low_rank_representation.eigen_vals.clone()
V = self._low_rank_representation.projections

if self.regularization is not None:
D += self.regularization

V_t_mat = V.t() @ mat.t()
D_inv = 1.0 / D
result = V @ (V_t_mat * D_inv.unsqueeze(-1))

if self._exact:
result += 1.0 / self.regularization * (mat.t() - V @ V_t_mat)

return result.t()

@property
def input_size(self) -> int:
result: int = self._low_rank_representation.projections.shape[0]
return result
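
A hedged numerical sketch of the `exact=True` branch (standalone PyTorch, not
using pyDVL classes), checking the Woodbury-based inverse action against a
dense solve:

```python
import torch

torch.manual_seed(0)
dim, rank, lam = 8, 3, 0.5

# Low-rank factors: V with orthonormal columns, positive eigenvalues D.
V, _ = torch.linalg.qr(torch.randn(dim, rank))
D = torch.rand(rank) + 1.0
A = V @ torch.diag(D) @ V.t()
b = torch.randn(dim)

# Woodbury-based inverse action, as in LowRankOperator with exact=True:
# (V D V^T + lam I)^{-1} b = V (D + lam)^{-1} V^T b + (1/lam) (b - V V^T b)
Vtb = V.t() @ b
x = V @ (Vtb / (D + lam)) + (b - V @ Vtb) / lam

x_dense = torch.linalg.solve(A + lam * torch.eye(dim), b)
assert torch.allclose(x, x_dense, atol=1e-5)
```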
2 changes: 1 addition & 1 deletion tests/influence/test_influence_calculator.py
@@ -85,7 +85,7 @@ def influence_model(model_and_data, test_case, influence_factory):
model,
loss,
rank=5,
hessian_regularization=hessian_reg,
regularization=hessian_reg,
).fit(train_dataLoader),
],
ids=["cg", "direct", "arnoldi", "nystroem-sketch"],