Commit dce5139

Implementation of block-diagonal and Gauss-Newton approximation for LissaInfluence

schroedk committed Jun 5, 2024
1 parent dff2d40 commit dce5139
Showing 4 changed files with 186 additions and 87 deletions.
14 changes: 10 additions & 4 deletions src/pydvl/influence/torch/batch_operation.py
@@ -60,7 +60,7 @@ def dtype(self):
         return next(self.model.parameters()).dtype
 
     @property
-    def input_size(self):
+    def input_size(self) -> int:
         return sum(p.numel() for p in self.params_to_restrict_to.values())
 
     def to(self, device: torch.device):
@@ -138,7 +138,7 @@ def apply(self, batch: TorchBatch, tensor: torch.Tensor):
                 "property `input_size`."
             )
 
-        if tensor.ndim == 2:
+        if tensor.ndim == 2 and tensor.shape[0] > 1:
            return self._apply_to_mat(batch.to(self.device), tensor.to(self.device))
        return self._apply_to_vec(batch.to(self.device), tensor.to(self.device))
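
With the change above, a 2-d tensor with a single row now takes the vector path rather than the matrix path. A minimal sketch of the new dispatch condition (plain torch, outside pydvl):

    import torch

    t = torch.randn(1, 4)
    # A (1, n) tensor no longer counts as a matrix for dispatch purposes:
    goes_to_mat = t.ndim == 2 and t.shape[0] > 1
    print(goes_to_mat)  # False -> handled by _apply_to_vec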

@@ -156,11 +156,14 @@ def _apply_to_mat(self, batch: TorchBatch, mat: torch.Tensor) -> torch.Tensor:
             $(N, \text{input_size})$
         """
-        return torch.func.vmap(
+        result = torch.func.vmap(
             lambda _x, _y, m: self._apply_to_vec(TorchBatch(_x, _y), m),
             in_dims=(None, None, 0),
             randomness="same",
         )(batch.x, batch.y, mat)
+        if result.requires_grad:
+            result = result.detach()
+        return result
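
`_apply_to_mat` now binds the vmap result to a name so it can be detached before returning. For readers unfamiliar with the pattern, a self-contained sketch of how `torch.func.vmap` batches a per-vector function over the rows of a matrix (toy stand-in, not pydvl's `_apply_to_vec`):

    import torch

    H = torch.tensor([[2.0, 0.0], [0.0, 3.0]])

    def apply_to_vec(v: torch.Tensor) -> torch.Tensor:
        # Toy per-vector operation: multiply by a fixed matrix.
        return H @ v

    mat = torch.eye(2)
    # vmap maps the function over the leading dimension of `mat`,
    # mirroring how _apply_to_mat batches _apply_to_vec over rows.
    result = torch.func.vmap(apply_to_vec, in_dims=0)(mat)
    print(result)  # row i equals H @ mat[i]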


class HessianBatchOperation(_ModelBasedBatchOperation):
@@ -196,7 +199,10 @@ def __init__(
         self.loss = loss
 
     def _apply_to_vec(self, batch: TorchBatch, vec: torch.Tensor) -> torch.Tensor:
-        return self._batch_hvp(self.params_to_restrict_to, batch.x, batch.y, vec)
+        result = self._batch_hvp(self.params_to_restrict_to, batch.x, batch.y, vec)
+        if result.requires_grad:
+            result = result.detach()
+        return result
 
     def _apply_to_dict(
         self, batch: TorchBatch, mat_dict: Dict[str, torch.Tensor]
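
The new detach guard exists because a double-backward Hessian-vector product can return a tensor that still carries an autograd graph; detaching frees that graph. A self-contained illustration (plain autograd, not pydvl's `_batch_hvp`):

    import torch

    def hvp(f, x, v):
        # Double-backward HVP; create_graph keeps the graph alive,
        # so the result itself may require grad.
        (g,) = torch.autograd.grad(f(x), x, create_graph=True)
        (hv,) = torch.autograd.grad(g, x, grad_outputs=v, create_graph=True)
        return hv

    x = torch.tensor([1.0, 2.0], requires_grad=True)
    v = torch.tensor([1.0, 0.0])
    result = hvp(lambda t: (t**3).sum(), x, v)
    print(result.requires_grad)  # True: the graph is still attached
    result = result.detach()     # same guard as in the diff above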
140 changes: 58 additions & 82 deletions src/pydvl/influence/torch/influence_function_model.py
@@ -28,6 +28,11 @@
     TorchGradientProvider,
     TorchOperatorGradientComposition,
 )
+from .batch_operation import (
+    BatchOperationType,
+    GaussNewtonBatchOperation,
+    HessianBatchOperation,
+)
 from .functional import (
     LowRankProductRepresentation,
     create_batch_hvp_function,
@@ -40,7 +45,7 @@
     model_hessian_low_rank,
     model_hessian_nystroem_approximation,
 )
-from .operator import DirectSolveOperator, InverseHarmonicMeanOperator
+from .operator import DirectSolveOperator, InverseHarmonicMeanOperator, LissaOperator
 from .pre_conditioner import PreConditioner
 from .util import (
     BlockMode,
@@ -825,7 +830,7 @@ def to(self, device: torch.device):
         return super().to(device)
 
 
-class LissaInfluence(TorchInfluenceFunctionModel):
+class LissaInfluence(TorchComposableInfluence[LissaOperator[BatchOperationType]]):
     r"""
     Uses LISSA, Linear time Stochastic Second-Order Algorithm, to iteratively
     approximate the inverse Hessian. More precisely, it finds x s.t. \(Hx = b\),
@@ -844,12 +849,11 @@ class LissaInfluence(TorchInfluenceFunctionModel):
             this model's parameters.
         loss: A callable that takes the model's output and target as input and returns
             the scalar loss.
-        hessian_regularization: Optional regularization parameter added
+        regularization: Optional regularization parameter added
             to the Hessian-vector product for numerical stability.
         maxiter: Maximum number of iterations.
         dampen: Dampening factor, defaults to 0 for no dampening.
         scale: Scaling factor, defaults to 10.
-        h0: Initial guess for hvp.
         rtol: tolerance to use for early stopping
         progress: If True, display progress bars.
         warn_on_max_iteration: If True, logs a warning, if the desired tolerance is not
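
For intuition: the class iterates h_{k+1} = b + (1 - dampen) * h_k - (H + λI) h_k / scale, whose fixed point, divided by scale, is (H + λI)^{-1} b. A toy verification with an explicit SPD matrix (illustration only, not pydvl code; convergence requires the eigenvalues of H/scale to lie in (0, 2)):

    import torch

    H = torch.tensor([[2.0, 0.5], [0.5, 1.5]])  # small SPD stand-in for the Hessian
    b = torch.tensor([1.0, -1.0])
    scale, dampen, reg = 10.0, 0.0, 0.0

    h = b.clone()
    for _ in range(1000):
        # LISSA step: h <- b + (1 - dampen) * h - (H h + reg * h) / scale
        h = b + (1 - dampen) * h - (H @ h + reg * h) / scale

    x = h / scale
    print(torch.allclose(x, torch.linalg.solve(H, b), atol=1e-5))  # True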
@@ -861,104 +865,76 @@ def __init__(
         self,
         model: nn.Module,
         loss: Callable[[torch.Tensor, torch.Tensor], torch.Tensor],
-        hessian_regularization: float = 0.0,
+        regularization: Optional[Union[float, Dict[str, Optional[float]]]] = None,
         maxiter: int = 1000,
         dampen: float = 0.0,
         scale: float = 10.0,
-        h0: Optional[torch.Tensor] = None,
         rtol: float = 1e-4,
         progress: bool = False,
         warn_on_max_iteration: bool = True,
+        block_structure: Union[BlockMode, OrderedDict[str, List[str]]] = BlockMode.FULL,
+        second_order_mode: SecondOrderMode = SecondOrderMode.HESSIAN,
     ):
-        super().__init__(model, loss)
-        self.warn_on_max_iteration = warn_on_max_iteration
+        super().__init__(model, block_structure, regularization)
         self.maxiter = maxiter
-        self.hessian_regularization = hessian_regularization
         self.progress = progress
         self.rtol = rtol
-        self.h0 = h0
         self.scale = scale
         self.dampen = dampen
+        self.loss = loss
+        self.second_order_mode = second_order_mode
+        self.warn_on_max_iteration = warn_on_max_iteration
 
-    train_dataloader: DataLoader
-
-    @property
-    def is_fitted(self):
-        try:
-            return self.train_dataloader is not None
-        except AttributeError:
-            return False
-
-    @log_duration(log_level=logging.INFO)
-    def fit(self, data: DataLoader) -> LissaInfluence:
-        self.train_dataloader = data
-        return self
-
-    @log_duration
-    def _solve_hvp(self, rhs: torch.Tensor) -> torch.Tensor:
-        h_estimate = self.h0 if self.h0 is not None else torch.clone(rhs)
-
-        shuffled_training_data = DataLoader(
-            self.train_dataloader.dataset,
-            self.train_dataloader.batch_size,
-            shuffle=True,
-        )
-
-        def lissa_step(
-            h: torch.Tensor, reg_hvp: Callable[[torch.Tensor], torch.Tensor]
-        ) -> torch.Tensor:
-            """Given an estimate of the hessian inverse and the regularised hessian
-            vector product, it computes the next estimate.
-
-            Args:
-                h: An estimate of the hessian inverse.
-                reg_hvp: Regularised hessian vector product.
-
-            Returns:
-                The next estimate of the hessian inverse.
-            """
-            return rhs + (1 - self.dampen) * h - reg_hvp(h) / self.scale
-
-        model_params = {
-            k: p.detach() for k, p in self.model.named_parameters() if p.requires_grad
-        }
-        b_hvp = torch.vmap(
-            create_batch_hvp_function(self.model, self.loss),
-            in_dims=(None, None, None, 0),
-        )
-        for k in tqdm(
-            range(self.maxiter), disable=not self.progress, desc="Lissa iteration"
-        ):
-            x, y = next(iter(shuffled_training_data))
-            x = x.to(self.model_device)
-            y = y.to(self.model_device)
-            reg_hvp = (
-                lambda v: b_hvp(model_params, x, y, v) + self.hessian_regularization * v
-            )
-            residual = lissa_step(h_estimate, reg_hvp) - h_estimate
-            h_estimate += residual
-            if torch.isnan(h_estimate).any():
-                raise RuntimeError("NaNs in h_estimate. Increase scale or dampening.")
-            max_residual = torch.max(torch.abs(residual / h_estimate))
-            if max_residual < self.rtol:
-                mean_residual = torch.mean(torch.abs(residual / h_estimate))
-                logger.debug(
-                    f"Terminated Lissa after {k} iterations with "
-                    f"{max_residual*100:.2f} % max residual and"
-                    f" mean residual {mean_residual*100:.5f} %"
-                )
-                break
-        else:
-            mean_residual = torch.mean(torch.abs(residual / h_estimate))
-            log_level = logging.WARNING if self.warn_on_max_iteration else logging.DEBUG
-            logger.log(
-                log_level,
-                f"Reached max number of iterations {self.maxiter} without "
-                f"achieving the desired tolerance {self.rtol}.\n "
-                f"Achieved max residual {max_residual*100:.2f} % and"
-                f" {mean_residual*100:.5f} % mean residual",
-            )
-        return h_estimate / self.scale
+    def with_regularization(
+        self, regularization: Union[float, Dict[str, Optional[float]]]
+    ) -> TorchComposableInfluence:
+        """
+        Update the regularization parameter.
+
+        Args:
+            regularization: Either a positive float or a dictionary with the
+                block names as keys and the regularization values as values.
+
+        Returns:
+            The modified instance
+        """
+        self._regularization_dict = self._build_regularization_dict(regularization)
+        for k, reg in self._regularization_dict.items():
+            self.block_mapper.composable_block_dict[k].op.regularization = reg
+        return self
+
+    def _create_block(
+        self,
+        block_params: Dict[str, torch.nn.Parameter],
+        data: DataLoader,
+        regularization: Optional[float],
+    ) -> TorchOperatorGradientComposition:
+        gp = TorchGradientProvider(self.model, self.loss, restrict_to=block_params)
+        batch_op: Union[GaussNewtonBatchOperation, HessianBatchOperation]
+        if self.second_order_mode is SecondOrderMode.GAUSS_NEWTON:
+            batch_op = GaussNewtonBatchOperation(
+                self.model, self.loss, restrict_to=block_params
+            )
+        else:
+            batch_op = HessianBatchOperation(
+                self.model, self.loss, restrict_to=block_params
+            )
+        lissa_op = LissaOperator(
+            batch_op,
+            data,
+            regularization,
+            maxiter=self.maxiter,
+            dampen=self.dampen,
+            scale=self.scale,
+            rtol=self.rtol,
+            progress=self.progress,
+            warn_on_max_iteration=self.warn_on_max_iteration,
+        )
+        return TorchOperatorGradientComposition(lissa_op, gp)
+
+    @property
+    def is_thread_safe(self) -> bool:
+        return False
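
A hypothetical end-to-end use of the refactored class. The import paths and the enum members `BlockMode.LAYER_WISE` / `SecondOrderMode.GAUSS_NEWTON` are assumptions about this revision of pydvl, not verified API:

    import torch
    from torch.utils.data import DataLoader, TensorDataset

    from pydvl.influence import SecondOrderMode          # assumed import path
    from pydvl.influence.torch import LissaInfluence
    from pydvl.influence.torch.util import BlockMode     # assumed import path

    model = torch.nn.Linear(5, 1)
    loss = torch.nn.functional.mse_loss
    data = DataLoader(
        TensorDataset(torch.randn(32, 5), torch.randn(32, 1)), batch_size=8
    )

    if_model = LissaInfluence(
        model,
        loss,
        regularization=0.01,                             # float or per-block dict
        maxiter=150,
        scale=10_000,
        block_structure=BlockMode.LAYER_WISE,            # block-diagonal treatment
        second_order_mode=SecondOrderMode.GAUSS_NEWTON,  # GN instead of the Hessian
    ).fit(data)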


class ArnoldiInfluence(TorchInfluenceFunctionModel):
110 changes: 110 additions & 0 deletions src/pydvl/influence/torch/operator.py
@@ -1,8 +1,10 @@
+import logging
 from typing import Callable, Dict, Generic, Optional, Tuple
 
 import torch
 from torch import nn as nn
 from torch.utils.data import DataLoader
+from tqdm import tqdm
 
 from .base import TensorDictOperator, TensorOperator, TorchBatch
 from .batch_operation import (
@@ -15,6 +17,8 @@
     TensorAveragingType,
 )
 
+logger = logging.getLogger(__name__)
+
 
 class _AveragingBatchOperator(
     TensorDictOperator, Generic[BatchOperationType, TensorAveragingType]
@@ -306,3 +310,109 @@ def _apply_to_mat(self, mat: torch.Tensor) -> torch.Tensor:
     def input_size(self) -> int:
         result: int = self.matrix.shape[-1]
         return result
+
+
+class LissaOperator(TensorOperator, Generic[BatchOperationType]):
+    def __init__(
+        self,
+        batch_operation: BatchOperationType,
+        data: DataLoader,
+        regularization: Optional[float] = None,
+        maxiter: int = 1000,
+        dampen: float = 0.0,
+        scale: float = 10.0,
+        rtol: float = 1e-4,
+        progress: bool = False,
+        warn_on_max_iteration: bool = True,
+    ):
+        if regularization is not None and regularization < 0:
+            raise ValueError("regularization must be non-negative")
+
+        self.data = data
+        self.warn_on_max_iteration = warn_on_max_iteration
+        self.progress = progress
+        self.rtol = rtol
+        self.scale = scale
+        self.dampen = dampen
+        self.maxiter = maxiter
+        self.batch_operation = batch_operation
+        self._regularization = regularization
+
+    @property
+    def regularization(self):
+        return self._regularization
+
+    @regularization.setter
+    def regularization(self, value: float):
+        if value < 0:
+            raise ValueError("regularization must be non-negative")
+        self._regularization = value
+
+    @property
+    def device(self):
+        return self.batch_operation.device
+
+    @property
+    def dtype(self):
+        return self.batch_operation.dtype
+
+    def to(self, device: torch.device):
+        self.batch_operation = self.batch_operation.to(device)
+        return self
+
+    def _reg_apply(self, batch: TorchBatch, h: torch.Tensor):
+        result = self.batch_operation.apply(batch, h)
+        if self.regularization is not None:
+            result += self.regularization * h
+        return result
+
+    def _lissa_step(self, h: torch.Tensor, rhs: torch.Tensor, batch: TorchBatch):
+        result = rhs + (1 - self.dampen) * h - self._reg_apply(batch, h) / self.scale
+        if result.requires_grad:
+            result = result.detach()
+        return result
+
+    def _apply_to_vec(self, vec: torch.Tensor) -> torch.Tensor:
+        h_estimate = torch.clone(vec)
+        shuffled_training_data = DataLoader(
+            self.data.dataset,
+            self.data.batch_size,
+            shuffle=True,
+        )
+        for k in tqdm(
+            range(self.maxiter), disable=not self.progress, desc="Lissa iteration"
+        ):
+            x, y = next(iter(shuffled_training_data))
+
+            residual = self._lissa_step(h_estimate, vec, TorchBatch(x, y)) - h_estimate
+            h_estimate += residual
+            if torch.isnan(h_estimate).any():
+                raise RuntimeError("NaNs in h_estimate. Increase scale or dampening.")
+            max_residual = torch.max(torch.abs(residual / h_estimate))
+            if max_residual < self.rtol:
+                mean_residual = torch.mean(torch.abs(residual / h_estimate))
+                logger.debug(
+                    f"Terminated Lissa after {k} iterations with "
+                    f"{max_residual*100:.2f} % max residual and"
+                    f" mean residual {mean_residual*100:.5f} %"
+                )
+                break
+        else:
+            mean_residual = torch.mean(torch.abs(residual / h_estimate))
+            log_level = logging.WARNING if self.warn_on_max_iteration else logging.DEBUG
+            logger.log(
+                log_level,
+                f"Reached max number of iterations {self.maxiter} without "
+                f"achieving the desired tolerance {self.rtol}.\n "
+                f"Achieved max residual {max_residual*100:.2f} % and"
+                f" {mean_residual*100:.5f} % mean residual",
+            )
+        return h_estimate / self.scale
+
+    def _apply_to_mat(self, mat: torch.Tensor) -> torch.Tensor:
+        return self._apply_to_vec(mat)
+
+    @property
+    def input_size(self) -> int:
+        return self.batch_operation.input_size
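
One detail of `_apply_to_vec` above: `next(iter(shuffled_training_data))` re-creates the iterator on every step, so with `shuffle=True` each LISSA iteration draws an independently shuffled batch. A standalone sketch of that sampling pattern:

    import torch
    from torch.utils.data import DataLoader, TensorDataset

    data = DataLoader(
        TensorDataset(torch.arange(8, dtype=torch.float32)),
        batch_size=2,
        shuffle=True,
    )
    for step in range(3):
        (batch,) = next(iter(data))  # fresh iterator => freshly shuffled batch
        print(step, batch)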
9 changes: 8 additions & 1 deletion tests/influence/torch/test_influence_model.py
@@ -487,7 +487,7 @@ def upper_quantile_equivalence(
         lambda model, loss, train_dataLoader, hessian_reg: LissaInfluence(
             model,
             loss,
-            hessian_regularization=hessian_reg,
+            regularization=hessian_reg,
             maxiter=150,
             scale=10000,
         ).fit(train_dataLoader),
@@ -830,6 +830,13 @@ def test_influences_cg(
         InverseHarmonicMeanInfluence,
         DirectInfluence,
         partial(DirectInfluence, second_order_mode=SecondOrderMode.GAUSS_NEWTON),
+        partial(LissaInfluence, maxiter=100, scale=10000),
+        partial(
+            LissaInfluence,
+            maxiter=150,
+            scale=10000,
+            second_order_mode=SecondOrderMode.GAUSS_NEWTON,
+        ),
     ]

