From dce5139e849a1fdc70b47ccb3155e6f28a0117ad Mon Sep 17 00:00:00 2001 From: Kristof Schroeder Date: Mon, 3 Jun 2024 21:32:32 +0200 Subject: [PATCH] Implementation of block-diagonal and Gauss-Newton approximation for LissaInfluence --- src/pydvl/influence/torch/batch_operation.py | 14 +- .../torch/influence_function_model.py | 140 ++++++++---------- src/pydvl/influence/torch/operator.py | 110 ++++++++++++++ tests/influence/torch/test_influence_model.py | 9 +- 4 files changed, 186 insertions(+), 87 deletions(-) diff --git a/src/pydvl/influence/torch/batch_operation.py b/src/pydvl/influence/torch/batch_operation.py index 908da3ef3..2e39ce8e6 100644 --- a/src/pydvl/influence/torch/batch_operation.py +++ b/src/pydvl/influence/torch/batch_operation.py @@ -60,7 +60,7 @@ def dtype(self): return next(self.model.parameters()).dtype @property - def input_size(self): + def input_size(self) -> int: return sum(p.numel() for p in self.params_to_restrict_to.values()) def to(self, device: torch.device): @@ -138,7 +138,7 @@ def apply(self, batch: TorchBatch, tensor: torch.Tensor): "property `input_size`." ) - if tensor.ndim == 2: + if tensor.ndim == 2 and tensor.shape[0] > 1: return self._apply_to_mat(batch.to(self.device), tensor.to(self.device)) return self._apply_to_vec(batch.to(self.device), tensor.to(self.device)) @@ -156,11 +156,14 @@ def _apply_to_mat(self, batch: TorchBatch, mat: torch.Tensor) -> torch.Tensor: $(N, \text{input_size})$ """ - return torch.func.vmap( + result = torch.func.vmap( lambda _x, _y, m: self._apply_to_vec(TorchBatch(_x, _y), m), in_dims=(None, None, 0), randomness="same", )(batch.x, batch.y, mat) + if result.requires_grad: + result = result.detach() + return result class HessianBatchOperation(_ModelBasedBatchOperation): @@ -196,7 +199,10 @@ def __init__( self.loss = loss def _apply_to_vec(self, batch: TorchBatch, vec: torch.Tensor) -> torch.Tensor: - return self._batch_hvp(self.params_to_restrict_to, batch.x, batch.y, vec) + result = self._batch_hvp(self.params_to_restrict_to, batch.x, batch.y, vec) + if result.requires_grad: + result = result.detach() + return result def _apply_to_dict( self, batch: TorchBatch, mat_dict: Dict[str, torch.Tensor] diff --git a/src/pydvl/influence/torch/influence_function_model.py b/src/pydvl/influence/torch/influence_function_model.py index c1c732beb..1b8167a8a 100644 --- a/src/pydvl/influence/torch/influence_function_model.py +++ b/src/pydvl/influence/torch/influence_function_model.py @@ -28,6 +28,11 @@ TorchGradientProvider, TorchOperatorGradientComposition, ) +from .batch_operation import ( + BatchOperationType, + GaussNewtonBatchOperation, + HessianBatchOperation, +) from .functional import ( LowRankProductRepresentation, create_batch_hvp_function, @@ -40,7 +45,7 @@ model_hessian_low_rank, model_hessian_nystroem_approximation, ) -from .operator import DirectSolveOperator, InverseHarmonicMeanOperator +from .operator import DirectSolveOperator, InverseHarmonicMeanOperator, LissaOperator from .pre_conditioner import PreConditioner from .util import ( BlockMode, @@ -825,7 +830,7 @@ def to(self, device: torch.device): return super().to(device) -class LissaInfluence(TorchInfluenceFunctionModel): +class LissaInfluence(TorchComposableInfluence[LissaOperator[BatchOperationType]]): r""" Uses LISSA, Linear time Stochastic Second-Order Algorithm, to iteratively approximate the inverse Hessian. More precisely, it finds x s.t. \(Hx = b\), @@ -844,12 +849,11 @@ class LissaInfluence(TorchInfluenceFunctionModel): this model's parameters. 
loss: A callable that takes the model's output and target as input and returns the scalar loss. - hessian_regularization: Optional regularization parameter added + regularization: Optional regularization parameter added to the Hessian-vector product for numerical stability. maxiter: Maximum number of iterations. dampen: Dampening factor, defaults to 0 for no dampening. scale: Scaling factor, defaults to 10. - h0: Initial guess for hvp. rtol: tolerance to use for early stopping progress: If True, display progress bars. warn_on_max_iteration: If True, logs a warning, if the desired tolerance is not @@ -861,104 +865,76 @@ def __init__( self, model: nn.Module, loss: Callable[[torch.Tensor, torch.Tensor], torch.Tensor], - hessian_regularization: float = 0.0, + regularization: Optional[Union[float, Dict[str, Optional[float]]]] = None, maxiter: int = 1000, dampen: float = 0.0, scale: float = 10.0, - h0: Optional[torch.Tensor] = None, rtol: float = 1e-4, progress: bool = False, warn_on_max_iteration: bool = True, + block_structure: Union[BlockMode, OrderedDict[str, List[str]]] = BlockMode.FULL, + second_order_mode: SecondOrderMode = SecondOrderMode.HESSIAN, ): - super().__init__(model, loss) - self.warn_on_max_iteration = warn_on_max_iteration + super().__init__(model, block_structure, regularization) self.maxiter = maxiter - self.hessian_regularization = hessian_regularization self.progress = progress self.rtol = rtol - self.h0 = h0 self.scale = scale self.dampen = dampen + self.loss = loss + self.second_order_mode = second_order_mode + self.warn_on_max_iteration = warn_on_max_iteration - train_dataloader: DataLoader + def with_regularization( + self, regularization: Union[float, Dict[str, Optional[float]]] + ) -> TorchComposableInfluence: + """ + Update the regularization parameter. + Args: + regularization: Either a positive float or a dictionary with the + block names as keys and the regularization values as values. - @property - def is_fitted(self): - try: - return self.train_dataloader is not None - except AttributeError: - return False + Returns: + The modified instance - @log_duration(log_level=logging.INFO) - def fit(self, data: DataLoader) -> LissaInfluence: - self.train_dataloader = data + """ + self._regularization_dict = self._build_regularization_dict(regularization) + for k, reg in self._regularization_dict.items(): + self.block_mapper.composable_block_dict[k].op.regularization = reg return self - @log_duration - def _solve_hvp(self, rhs: torch.Tensor) -> torch.Tensor: - h_estimate = self.h0 if self.h0 is not None else torch.clone(rhs) - - shuffled_training_data = DataLoader( - self.train_dataloader.dataset, - self.train_dataloader.batch_size, - shuffle=True, - ) - - def lissa_step( - h: torch.Tensor, reg_hvp: Callable[[torch.Tensor], torch.Tensor] - ) -> torch.Tensor: - """Given an estimate of the hessian inverse and the regularised hessian - vector product, it computes the next estimate. - - Args: - h: An estimate of the hessian inverse. - reg_hvp: Regularised hessian vector product. - - Returns: - The next estimate of the hessian inverse. 
-            """
-            return rhs + (1 - self.dampen) * h - reg_hvp(h) / self.scale
-
-        model_params = {
-            k: p.detach() for k, p in self.model.named_parameters() if p.requires_grad
-        }
-        b_hvp = torch.vmap(
-            create_batch_hvp_function(self.model, self.loss),
-            in_dims=(None, None, None, 0),
-        )
-        for k in tqdm(
-            range(self.maxiter), disable=not self.progress, desc="Lissa iteration"
-        ):
-            x, y = next(iter(shuffled_training_data))
-            x = x.to(self.model_device)
-            y = y.to(self.model_device)
-            reg_hvp = (
-                lambda v: b_hvp(model_params, x, y, v) + self.hessian_regularization * v
+    def _create_block(
+        self,
+        block_params: Dict[str, torch.nn.Parameter],
+        data: DataLoader,
+        regularization: Optional[float],
+    ) -> TorchOperatorGradientComposition:
+        gp = TorchGradientProvider(self.model, self.loss, restrict_to=block_params)
+        batch_op: Union[GaussNewtonBatchOperation, HessianBatchOperation]
+        if self.second_order_mode is SecondOrderMode.GAUSS_NEWTON:
+            batch_op = GaussNewtonBatchOperation(
+                self.model, self.loss, restrict_to=block_params
             )
-            residual = lissa_step(h_estimate, reg_hvp) - h_estimate
-            h_estimate += residual
-            if torch.isnan(h_estimate).any():
-                raise RuntimeError("NaNs in h_estimate. Increase scale or dampening.")
-            max_residual = torch.max(torch.abs(residual / h_estimate))
-            if max_residual < self.rtol:
-                mean_residual = torch.mean(torch.abs(residual / h_estimate))
-                logger.debug(
-                    f"Terminated Lissa after {k} iterations with "
-                    f"{max_residual*100:.2f} % max residual and"
-                    f" mean residual {mean_residual*100:.5f} %"
-                )
-                break
         else:
-            mean_residual = torch.mean(torch.abs(residual / h_estimate))
-            log_level = logging.WARNING if self.warn_on_max_iteration else logging.DEBUG
-            logger.log(
-                log_level,
-                f"Reached max number of iterations {self.maxiter} without "
-                f"achieving the desired tolerance {self.rtol}.\n "
-                f"Achieved max residual {max_residual*100:.2f} % and"
-                f" {mean_residual*100:.5f} % mean residual",
+            batch_op = HessianBatchOperation(
+                self.model, self.loss, restrict_to=block_params
             )
-        return h_estimate / self.scale
+        lissa_op = LissaOperator(
+            batch_op,
+            data,
+            regularization,
+            maxiter=self.maxiter,
+            dampen=self.dampen,
+            scale=self.scale,
+            rtol=self.rtol,
+            progress=self.progress,
+            warn_on_max_iteration=self.warn_on_max_iteration,
+        )
+        return TorchOperatorGradientComposition(lissa_op, gp)
+
+    @property
+    def is_thread_safe(self) -> bool:
+        return False
 
 
 class ArnoldiInfluence(TorchInfluenceFunctionModel):
diff --git a/src/pydvl/influence/torch/operator.py b/src/pydvl/influence/torch/operator.py
index 74aaec287..d0901bb30 100644
--- a/src/pydvl/influence/torch/operator.py
+++ b/src/pydvl/influence/torch/operator.py
@@ -1,8 +1,10 @@
+import logging
 from typing import Callable, Dict, Generic, Optional, Tuple
 
 import torch
 from torch import nn as nn
 from torch.utils.data import DataLoader
+from tqdm import tqdm
 
 from .base import TensorDictOperator, TensorOperator, TorchBatch
 from .batch_operation import (
@@ -15,6 +17,8 @@
     TensorAveragingType,
 )
 
+logger = logging.getLogger(__name__)
+
 
 class _AveragingBatchOperator(
     TensorDictOperator, Generic[BatchOperationType, TensorAveragingType]
@@ -306,3 +310,109 @@ def _apply_to_mat(self, mat: torch.Tensor) -> torch.Tensor:
     def input_size(self) -> int:
         result: int = self.matrix.shape[-1]
         return result
+
+
+class LissaOperator(TensorOperator, Generic[BatchOperationType]):
+    def __init__(
+        self,
+        batch_operation: BatchOperationType,
+        data: DataLoader,
+        regularization: Optional[float] = None,
+        maxiter: int = 1000,
+        dampen: float = 0.0,
+        scale: float = 10.0,
+        rtol: float = 1e-4,
+        progress: bool = False,
+        warn_on_max_iteration: bool = True,
+    ):
+
+        if regularization is not None and regularization < 0:
+            raise ValueError("regularization must be non-negative")
+
+        self.data = data
+        self.warn_on_max_iteration = warn_on_max_iteration
+        self.progress = progress
+        self.rtol = rtol
+        self.scale = scale
+        self.dampen = dampen
+        self.maxiter = maxiter
+        self.batch_operation = batch_operation
+        self._regularization = regularization
+
+    @property
+    def regularization(self):
+        return self._regularization
+
+    @regularization.setter
+    def regularization(self, value: float):
+        if value < 0:
+            raise ValueError("regularization must be non-negative")
+        self._regularization = value
+
+    @property
+    def device(self):
+        return self.batch_operation.device
+
+    @property
+    def dtype(self):
+        return self.batch_operation.dtype
+
+    def to(self, device: torch.device):
+        self.batch_operation = self.batch_operation.to(device)
+        return self
+
+    def _reg_apply(self, batch: TorchBatch, h: torch.Tensor):
+        result = self.batch_operation.apply(batch, h)
+        if self.regularization is not None:
+            result += self.regularization * h
+        return result
+
+    def _lissa_step(self, h: torch.Tensor, rhs: torch.Tensor, batch: TorchBatch):
+        result = rhs + (1 - self.dampen) * h - self._reg_apply(batch, h) / self.scale
+        if result.requires_grad:
+            result = result.detach()
+        return result
+
+    def _apply_to_vec(self, vec: torch.Tensor) -> torch.Tensor:
+        h_estimate = torch.clone(vec)
+        shuffled_training_data = DataLoader(
+            self.data.dataset,
+            self.data.batch_size,
+            shuffle=True,
+        )
+        for k in tqdm(
+            range(self.maxiter), disable=not self.progress, desc="Lissa iteration"
+        ):
+            x, y = next(iter(shuffled_training_data))
+
+            residual = self._lissa_step(h_estimate, vec, TorchBatch(x, y)) - h_estimate
+            h_estimate += residual
+            if torch.isnan(h_estimate).any():
+                raise RuntimeError("NaNs in h_estimate. Increase scale or dampening.")
+            max_residual = torch.max(torch.abs(residual / h_estimate))
+            if max_residual < self.rtol:
+                mean_residual = torch.mean(torch.abs(residual / h_estimate))
+                logger.debug(
+                    f"Terminated Lissa after {k} iterations with "
+                    f"{max_residual*100:.2f} % max residual and"
+                    f" mean residual {mean_residual*100:.5f} %"
+                )
+                break
+        else:
+            mean_residual = torch.mean(torch.abs(residual / h_estimate))
+            log_level = logging.WARNING if self.warn_on_max_iteration else logging.DEBUG
+            logger.log(
+                log_level,
+                f"Reached max number of iterations {self.maxiter} without "
+                f"achieving the desired tolerance {self.rtol}.\n "
+                f"Achieved max residual {max_residual*100:.2f} % and"
+                f" {mean_residual*100:.5f} % mean residual",
+            )
+        return h_estimate / self.scale
+
+    def _apply_to_mat(self, mat: torch.Tensor) -> torch.Tensor:
+        return self._apply_to_vec(mat)
+
+    @property
+    def input_size(self) -> int:
+        return self.batch_operation.input_size
diff --git a/tests/influence/torch/test_influence_model.py b/tests/influence/torch/test_influence_model.py
index b8266df58..39b4a74b3 100644
--- a/tests/influence/torch/test_influence_model.py
+++ b/tests/influence/torch/test_influence_model.py
@@ -487,7 +487,7 @@ def upper_quantile_equivalence(
         lambda model, loss, train_dataLoader, hessian_reg: LissaInfluence(
             model,
             loss,
-            hessian_regularization=hessian_reg,
+            regularization=hessian_reg,
             maxiter=150,
             scale=10000,
         ).fit(train_dataLoader),
@@ -830,6 +830,13 @@ def test_influences_cg(
         InverseHarmonicMeanInfluence,
         DirectInfluence,
         partial(DirectInfluence, second_order_mode=SecondOrderMode.GAUSS_NEWTON),
+        partial(LissaInfluence, maxiter=100, scale=10000),
+        partial(
+            LissaInfluence,
+            maxiter=150,
+            scale=10000,
+            second_order_mode=SecondOrderMode.GAUSS_NEWTON,
+        ),
     ]
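
Usage sketch (not part of the patch): the snippet below shows how the refactored LissaInfluence is meant to be called after this change. The constructor arguments (regularization, block_structure, second_order_mode), fit(), and the enum values come from the diff above; the toy model, the data, the influences() call, and the exact import path for BlockMode/SecondOrderMode are illustrative assumptions.

    # Hypothetical example; names and import paths outside this diff are assumptions.
    import torch
    from torch.utils.data import DataLoader, TensorDataset

    from pydvl.influence.torch import LissaInfluence
    from pydvl.influence.torch.util import BlockMode, SecondOrderMode  # assumed location

    model = torch.nn.Sequential(
        torch.nn.Linear(10, 10), torch.nn.ReLU(), torch.nn.Linear(10, 1)
    )
    loss = torch.nn.MSELoss()

    x, y = torch.randn(100, 10), torch.randn(100, 1)
    train_loader = DataLoader(TensorDataset(x, y), batch_size=20)

    if_model = LissaInfluence(
        model,
        loss,
        regularization=0.01,
        maxiter=500,
        scale=10.0,
        block_structure=BlockMode.FULL,  # or an OrderedDict mapping block names to parameter names
        second_order_mode=SecondOrderMode.GAUSS_NEWTON,  # Gauss-Newton instead of the full Hessian
    ).fit(train_loader)

    x_test, y_test = torch.randn(10, 10), torch.randn(10, 1)
    values = if_model.influences(x_test, y_test, x, y)  # influence of training points on test points

When a non-trivial block structure is used, regularization can also be a dict keyed by block name, and the with_regularization() method added above updates it on an already fitted model.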
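
For completeness, a rough sketch of what each block built by _create_block boils down to when the new LissaOperator is driven directly, reusing model, loss and train_loader from the sketch above. The batch-operation and operator constructors are taken from the diff; the public apply() entry point on the operator is an assumption about the TensorOperator base class.

    # Hypothetical example; assumes the TensorOperator base class exposes apply() for flat tensors.
    from pydvl.influence.torch.batch_operation import HessianBatchOperation
    from pydvl.influence.torch.operator import LissaOperator

    params = {k: p for k, p in model.named_parameters() if p.requires_grad}
    batch_op = HessianBatchOperation(model, loss, restrict_to=params)  # or GaussNewtonBatchOperation
    lissa_op = LissaOperator(
        batch_op,
        train_loader,
        regularization=0.01,
        maxiter=500,
        scale=10.0,
    )

    rhs = torch.randn(lissa_op.input_size)  # stands in for a flattened per-sample gradient
    ihvp = lissa_op.apply(rhs)              # approximately solves (H + regularization * I) v = rhs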