diff --git a/CHANGELOG.md b/CHANGELOG.md index 076d4d927..dfbe91943 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,5 +1,18 @@ # Changelog +## Unreleased + +### Added + +- New method `InverseHarmonicMeanInfluence`, implementation for the paper + `DataInf: Efficiently Estimating Data Influence in LoRA-tuned LLMs and + Diffusion Models` + [PR #582](https://github.com/aai-institute/pyDVL/pull/582) +- Add new backend implementations for influence computation + to account for block-diagonal approximations + [PR #582](https://github.com/aai-institute/pyDVL/pull/582) + + ## 0.9.2 - 🏗 Bug fixes, logging improvement ### Added diff --git a/docs/assets/pydvl.bib b/docs/assets/pydvl.bib index 724e75f20..ed4dc30d3 100644 --- a/docs/assets/pydvl.bib +++ b/docs/assets/pydvl.bib @@ -122,7 +122,8 @@ @inproceedings{george_fast_2018 publisher = {Curran Associates, Inc.}, url = {https://proceedings.neurips.cc/paper/2018/hash/48000647b315f6f00f913caa757a70b3-Abstract.html}, urldate = {2024-01-12}, - abstract = {Optimization algorithms that leverage gradient covariance information, such as variants of natural gradient descent (Amari, 1998), offer the prospect of yielding more effective descent directions. For models with many parameters, the covari- ance matrix they are based on becomes gigantic, making them inapplicable in their original form. This has motivated research into both simple diagonal approxima- tions and more sophisticated factored approximations such as KFAC (Heskes, 2000; Martens \& Grosse, 2015; Grosse \& Martens, 2016). In the present work we draw inspiration from both to propose a novel approximation that is provably better than KFAC and amendable to cheap partial updates. It consists in tracking a diagonal variance, not in parameter coordinates, but in a Kronecker-factored eigenbasis, in which the diagonal approximation is likely to be more effective. Experiments show improvements over KFAC in optimization speed for several deep network architectures.} + abstract = {Optimization algorithms that leverage gradient covariance information, such as variants of natural gradient descent (Amari, 1998), offer the prospect of yielding more effective descent directions. For models with many parameters, the covari- ance matrix they are based on becomes gigantic, making them inapplicable in their original form. This has motivated research into both simple diagonal approxima- tions and more sophisticated factored approximations such as KFAC (Heskes, 2000; Martens \& Grosse, 2015; Grosse \& Martens, 2016). In the present work we draw inspiration from both to propose a novel approximation that is provably better than KFAC and amendable to cheap partial updates. It consists in tracking a diagonal variance, not in parameter coordinates, but in a Kronecker-factored eigenbasis, in which the diagonal approximation is likely to be more effective. Experiments show improvements over KFAC in optimization speed for several deep network architectures.}, + keywords = {notion} } @inproceedings{ghorbani_data_2019, @@ -175,7 +176,8 @@ @inproceedings{hataya_nystrom_2023 urldate = {2024-02-26}, abstract = {The essential difficulty of gradient-based bilevel optimization using implicit differentiation is to estimate the inverse Hessian vector product with respect to neural network parameters. This paper proposes to tackle this problem by the Nyström method and the Woodbury matrix identity, exploiting the low-rankness of the Hessian. 
Compared to existing methods using iterative approximation, such as conjugate gradient and the Neumann series approximation, the proposed method avoids numerical instability and can be efficiently computed in matrix operations without iterations. As a result, the proposed method works stably in various tasks and is faster than iterative approximations. Throughout experiments including large-scale hyperparameter optimization and meta learning, we demonstrate that the Nyström method consistently achieves comparable or even superior performance to other approaches. The source code is available from https://github.com/moskomule/hypergrad.}, eventtitle = {International {{Conference}} on {{Artificial Intelligence}} and {{Statistics}}}, - langid = {english} + langid = {english}, + keywords = {notion} } @article{ji_breakdownfree_2017, @@ -292,6 +294,18 @@ @inproceedings{kwon_beta_2022 keywords = {notion} } +@inproceedings{kwon_datainf_2023, + title = {{{DataInf}}: {{Efficiently Estimating Data Influence}} in {{LoRA-tuned LLMs}} and {{Diffusion Models}}}, + shorttitle = {{{DataInf}}}, + author = {Kwon, Yongchan and Wu, Eric and Wu, Kevin and Zou, James}, + date = {2023-10-13}, + doi = {10.48550/arXiv.2310.00902}, + url = {https://openreview.net/forum?id=9m02ib92Wz}, + urldate = {2023-10-27}, + abstract = {Quantifying the impact of training data points is crucial for understanding the outputs of machine learning models and for improving the transparency of the AI pipeline. The influence function is a principled and popular data attribution method, but its computational cost often makes it challenging to use. This issue becomes more pronounced in the setting of large language models and text-to-image models. In this work, we propose DataInf, an efficient influence approximation method that is practical for large-scale generative AI models. Leveraging an easy-to-compute closed-form expression, DataInf outperforms existing influence computation algorithms in terms of computational and memory efficiency. Our theoretical analysis shows that DataInf is particularly well-suited for parameter-efficient fine-tuning techniques such as LoRA. Through systematic empirical evaluations, we show that DataInf accurately approximates influence scores and is orders of magnitude faster than existing methods. In applications to RoBERTa-large, Llama-2-13B-chat, and stable-diffusion-v1.5 models, DataInf effectively identifies the most influential fine-tuning examples better than other approximate influence scores. Moreover, it can help to identify which data points are mislabeled.}, + eventtitle = {The {{Twelfth International Conference}} on {{Learning Representations}}} +} + @inproceedings{kwon_dataoob_2023, title = {Data-{{OOB}}: {{Out-of-bag Estimate}} as a {{Simple}} and {{Efficient Data Value}}}, shorttitle = {Data-{{OOB}}}, @@ -303,7 +317,7 @@ @inproceedings{kwon_dataoob_2023 issn = {2640-3498}, url = {https://proceedings.mlr.press/v202/kwon23e.html}, urldate = {2023-09-06}, - abstract = {Data valuation is a powerful framework for providing statistical insights into which data are beneficial or detrimental to model training. Many Shapley-based data valuation methods have shown promising results in various downstream tasks, however, they are well known to be computationally challenging as it requires training a large number of models. As a result, it has been recognized as infeasible to apply to large datasets. 
To address this issue, we propose Data-OOB, a new data valuation method for a bagging model that utilizes the out-of-bag estimate. The proposed method is computationally efficient and can scale to millions of data by reusing trained weak learners. Specifically, Data-OOB takes less than 2.25 hours on a single CPU processor when there are \$10\^{}6\$ samples to evaluate and the input dimension is 100. Furthermore, Data-OOB has solid theoretical interpretations in that it identifies the same important data point as the infinitesimal jackknife influence function when two different points are compared. We conduct comprehensive experiments using 12 classification datasets, each with thousands of sample sizes. We demonstrate that the proposed method significantly outperforms existing state-of-the-art data valuation methods in identifying mislabeled data and finding a set of helpful (or harmful) data points, highlighting the potential for applying data values in real-world applications.}, + abstract = {Data valuation is a powerful framework for providing statistical insights into which data are beneficial or detrimental to model training. Many Shapley-based data valuation methods have shown promising results in various downstream tasks, however, they are well known to be computationally challenging as it requires training a large number of models. As a result, it has been recognized as infeasible to apply to large datasets. To address this issue, we propose Data-OOB, a new data valuation method for a bagging model that utilizes the out-of-bag estimate. The proposed method is computationally efficient and can scale to millions of data by reusing trained weak learners. Specifically, Data-OOB takes less than 2.25 hours on a single CPU processor when there are \$10\textasciicircum 6\$ samples to evaluate and the input dimension is 100. Furthermore, Data-OOB has solid theoretical interpretations in that it identifies the same important data point as the infinitesimal jackknife influence function when two different points are compared. We conduct comprehensive experiments using 12 classification datasets, each with thousands of sample sizes. We demonstrate that the proposed method significantly outperforms existing state-of-the-art data valuation methods in identifying mislabeled data and finding a set of helpful (or harmful) data points, highlighting the potential for applying data values in real-world applications.}, eventtitle = {International {{Conference}} on {{Machine Learning}}}, langid = {english}, keywords = {notion} diff --git a/docs/influence/influence_function_model.md b/docs/influence/influence_function_model.md index 0a424e918..131cce052 100644 --- a/docs/influence/influence_function_model.md +++ b/docs/influence/influence_function_model.md @@ -207,7 +207,132 @@ if_model = NystroemSketchInfluence( if_model.fit(train_loader) ``` +### Inverse Harmonic Mean + +This implementation replaces the inverse Hessian matrix in the influence computation +with an approximation of the inverse Gauss-Newton vector product and was +proposed in [@kwon_datainf_2023]. + +The approximation method comprises +the following steps: + +1. 
Replace the Hessian $H(\theta)$ with the Gauss-Newton matrix + $G(\theta)$: + + \begin{equation*} + G(\theta)=n^{-1} \sum_{i=1}^n \nabla_{\theta}\ell_i\nabla_{\theta}\ell_i^T + \end{equation*} + + which results in + + \begin{equation*} + \mathcal{I}(z_{t}, z) \approx \nabla_{\theta} \ell(z_{t}, \theta)^T + (G(\theta) + \lambda I_d)^{-1} + \nabla_{\theta} \ell(z, \theta) + \end{equation*} + +2. Simplify the problem by breaking it down into a block diagonal structure, + where each block $G_l(\theta)$ corresponds to the l-th block: + + \begin{equation*} + G_{l}(\theta) = n^{-1} \sum_{i=1}^n \nabla_{\theta_l} \ell_i + \nabla_{\theta_l} \ell_i^{T} + \lambda_l I_{d_l}, + \end{equation*} + + which leads to + + \begin{equation*} + \mathcal{I}(z_{t}, z) \approx \nabla_{\theta} \ell(z_{t}, \theta)^T + \operatorname{diag}(G_1(\theta)^{-1}, + \dots, G_L(\theta)^{-1}) + \nabla_{\theta} \ell(z, \theta) + \end{equation*} + +3. Substitute the arithmetic mean of the rank-$1$ updates in + $G_l(\theta)$, with the inverse harmonic mean $R_l(\theta)$ of the rank-1 + updates: + + \begin{align*} + G_l(\theta)^{-1} &= \left( n^{-1} \sum_{i=1}^n \nabla_{\theta_l} + \ell(z_i, \theta) \nabla_{\theta_l} + \ell(z_i, \theta)^{T} + + \lambda_l I_{d_l}\right)^{-1} \\\ + R_{l}(\theta)&= n^{-1} \sum_{i=1}^n \left( \nabla_{\theta_l} + \ell(z_i, \theta) \nabla_{\theta_l} \ell(z_i, \theta)^{T} + + \lambda_l I_{d_l} \right)^{-1} + \end{align*} + +4. Use the + + Sherman–Morrison formula + + to get an explicit representation of the inverses in the definition of + $R_l(\theta):$ + + \begin{align*} + R_l(\theta) &= n^{-1} \sum_{i=1}^n \left( \nabla_{\theta_l} \ell_i + \nabla_{\theta_l} \ell_i^{T} + + \lambda_l I_{d_l}\right)^{-1} \\\ + &= n^{-1} \sum_{i=1}^n \lambda_l^{-1} \left(I_{d_l} + - \frac{\nabla_{\theta_l} \ell_i \nabla_{\theta_l} + \ell_i^{T}}{\lambda_l + + \\|\nabla_{\theta_l} \ell_i\\|_2^2}\right) + , + \end{align*} + + which means application of $R_l(\theta)$ boils down to computing $n$ + rank-$1$ updates. + +```python +from pydvl.influence.torch import InverseHarmonicMeanInfluence, BlockMode + +if_model = InverseHarmonicMeanInfluence( + model, + loss, + regularization=1e-1, + block_structure=BlockMode.LAYER_WISE +) +if_model.fit(train_loader) +``` + +!!! Info + This implementation is capable of using a block-matrix approximation. The + blocking structure can be specified via the `block_structure` parameter. + The `block_structure` parameter can either be a + [BlockMode][pydvl.influence.torch.util.BlockMode] enum (which provides + layer-wise or parameter-wise blocking) or a custom block structure defined + by an ordered dictionary with the keys being the block identifiers (arbitrary + strings) and the values being lists of parameter names contained in the block. + ```python + block_structure = OrderedDict( + ( + ("custom_block1", ["0.weight", "1.bias"]), + ("custom_block2", ["1.weight", "0.bias"]), + ) + ) + ``` + If you would like to apply a block-specific regularization, you can provide a + dictionary with the block names as keys and the regularization values as values. + In this case, the specification must be complete, i.e. every block must have + a positive regularization value. + ```python + regularization = { + "custom_block1": 0.1, + "custom_block2": 0.2, + } + ``` + Accordingly, if you choose a layer-wise or parameter-wise structure + (by providing `BlockMode.LAYER_WISE` or `BlockMode.PARAMETER_WISE` for + `block_structure`) the keys must be the layer names or parameter names, + respectively. 
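+    For illustration, assuming a two-layer `torch.nn.Sequential` model whose
+    submodules are registered under the (hypothetical) names `"0"` and `"1"`,
+    a layer-wise specification could look like this:
+    ```python
+    if_model = InverseHarmonicMeanInfluence(
+        model,
+        loss,
+        regularization={"0": 0.1, "1": 0.2},
+        block_structure=BlockMode.LAYER_WISE,
+    )
+    if_model.fit(train_loader)
+    ```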
+ You can retrieve the block-wise influence information from the methods + with suffix `_by_block`. By default, `block_structure` is set to + `BlockMode.FULL` and in this case these methods will return a dictionary + with the empty string being the only key. + These implementations represent the calculation logic on in memory tensors. To scale up to large collection of data, we map these influence function models over these collections. For a detailed discussion see the documentation page [Scaling Computation](scaling_computation.md). + + diff --git a/docs/influence/scaling_computation.md b/docs/influence/scaling_computation.md index b8ffbe98f..32a2088ee 100644 --- a/docs/influence/scaling_computation.md +++ b/docs/influence/scaling_computation.md @@ -24,8 +24,7 @@ into memory. ```python from pydvl.influence import SequentialInfluenceCalculator from pydvl.influence.torch.util import ( - NestedTorchCatAggregator, - TorchNumpyConverter, + TorchNumpyConverter, NestedTorchCatAggregator, ) from pydvl.influence.torch import CgInfluence diff --git a/src/pydvl/influence/__init__.py b/src/pydvl/influence/__init__.py index 6065b7cf9..187c98de1 100644 --- a/src/pydvl/influence/__init__.py +++ b/src/pydvl/influence/__init__.py @@ -10,9 +10,9 @@ probably change. """ -from .base_influence_function_model import InfluenceMode from .influence_calculator import ( DaskInfluenceCalculator, DisableClientSingleThreadCheck, SequentialInfluenceCalculator, ) +from .types import InfluenceMode diff --git a/src/pydvl/influence/array.py b/src/pydvl/influence/array.py index 7e71050f9..7ad9a59f0 100644 --- a/src/pydvl/influence/array.py +++ b/src/pydvl/influence/array.py @@ -28,7 +28,7 @@ from zarr.storage import StoreLike from ..utils import log_duration -from .base_influence_function_model import TensorType +from .types import TensorType class NumpyConverter(Generic[TensorType], ABC): diff --git a/src/pydvl/influence/base_influence_function_model.py b/src/pydvl/influence/base_influence_function_model.py index 541fbedf0..058ef823b 100644 --- a/src/pydvl/influence/base_influence_function_model.py +++ b/src/pydvl/influence/base_influence_function_model.py @@ -1,26 +1,13 @@ from __future__ import annotations +import logging from abc import ABC, abstractmethod -from enum import Enum -from typing import Collection, Generic, Iterable, Optional, Type, TypeVar +from collections import OrderedDict +from functools import wraps +from typing import Generic, Iterable, Optional, Type, cast -__all__ = ["InfluenceMode"] - - -class InfluenceMode(str, Enum): - """ - Enum representation for the types of influence. - - Attributes: - Up: [Approximating the influence of a point] - [approximating-the-influence-of-a-point] - Perturbation: [Perturbation definition of the influence score] - [perturbation-definition-of-the-influence-score] - - """ - - Up = "up" - Perturbation = "perturbation" +from ..utils.progress import log_duration +from .types import BatchType, BlockMapperType, DataLoaderType, InfluenceMode, TensorType class UnsupportedInfluenceModeException(ValueError): @@ -42,15 +29,13 @@ def __init__(self, object_type: Type): class NotImplementedLayerRepresentationException(ValueError): def __init__(self, module_id: str): - message = f"Only Linear layers are supported, but found module {module_id} requiring grad." + message = ( + f"Only Linear layers are supported, but found module {module_id} " + f"requiring grad." + ) super().__init__(message) -"""Type variable for tensors, i.e. 
sequences of numbers""" -TensorType = TypeVar("TensorType", bound=Collection) -DataLoaderType = TypeVar("DataLoaderType", bound=Iterable) - - class InfluenceFunctionModel(Generic[TensorType, DataLoaderType], ABC): """ Generic abstract base class for computing influence related quantities. @@ -86,7 +71,36 @@ def fit(self, data: DataLoaderType) -> InfluenceFunctionModel: The fitted instance """ + @staticmethod + def fit_required(method): + """Decorator to enforce the fitted check""" + + @wraps(method) + def wrapper(self, *args, **kwargs): + if not self.is_fitted: + raise NotFittedException(type(self)) + return method(self, *args, **kwargs) + + return wrapper + def influence_factors(self, x: TensorType, y: TensorType) -> TensorType: + r""" + Computes the approximation of + + \[ H^{-1}\nabla_{\theta} \ell(y, f_{\theta}(x)) \] + + where the gradient is meant to be per sample of the batch $(x, y)$. + For all input tensors it is assumed, + that the first dimension is the batch dimension. + + Args: + x: model input to use in the gradient computations + y: label tensor to compute gradients + + Returns: + Tensor representing the element-wise inverse Hessian matrix vector products + + """ if not self.is_fitted: raise NotFittedException(type(self)) return self._influence_factors(x, y) @@ -117,8 +131,51 @@ def influences( y: Optional[TensorType] = None, mode: InfluenceMode = InfluenceMode.Up, ) -> TensorType: + r""" + Computes the approximation of + + \[ \langle H^{-1}\nabla_{\theta} \ell(y_{\text{test}}, + f_{\theta}(x_{\text{test}})), + \nabla_{\theta} \ell(y, f_{\theta}(x)) \rangle \] + + for the case of up-weighting influence, resp. + + \[ \langle H^{-1}\nabla_{\theta} \ell(y_{test}, f_{\theta}(x_{test})), + \nabla_{x} \nabla_{\theta} \ell(y, f_{\theta}(x)) \rangle \] + + for the perturbation type influence case. + + Args: + x_test: model input to use in the gradient computations + of $H^{-1}\nabla_{theta} \ell(y_{test}, f_{\theta}(x_{test}))$ + y_test: label tensor to compute gradients + x: optional model input to use in the gradient computations + $\nabla_{theta}\ell(y, f_{\theta}(x))$, + resp. $\nabla_{x}\nabla_{theta}\ell(y, f_{\theta}(x))$, + if None, use $x=x_{test}$ + y: optional label tensor to compute gradients + mode: enum value of [InfluenceMode] + [pydvl.influence.base_influence_function_model.InfluenceMode] + + Returns: + Tensor representing the element-wise scalar products for the provided batch + + """ if not self.is_fitted: raise NotFittedException(type(self)) + + if x is None and y is not None: + raise ValueError( + "Providing labels y, without providing model input x " + "is not supported" + ) + + if x is not None and y is None: + raise ValueError( + "Providing model input x, without providing labels y " + "is not supported" + ) + return self._influences(x_test, y_test, x, y, mode) @abstractmethod @@ -199,3 +256,268 @@ def influences_from_factors( Tensor representing the element-wise scalar products for the provided batch """ + + +class ComposableInfluence( + InfluenceFunctionModel, + Generic[TensorType, BatchType, DataLoaderType, BlockMapperType], + ABC, +): + """ + Generic abstract base class, that allow for block-wise computation of influence + quantities. Inherit from this base class for specific influence algorithms and + tensor frameworks. 
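+
+    A sketch of the intended usage with a concrete (hypothetical) subclass:
+
+    ```python
+    if_model = MyComposableInfluence(...)  # hypothetical concrete subclass
+    if_model.fit(train_loader)
+    # block-wise quantities are exposed via the *_by_block methods
+    factors = if_model.influence_factors_by_block(x_test, y_test)
+    values = if_model.influences_by_block(x_test, y_test, x, y)
+    ```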
+ """ + + block_mapper: BlockMapperType + + @property + def is_fitted(self): + try: + return self.block_mapper is not None + except AttributeError: + return False + + @log_duration(log_level=logging.INFO) + def fit(self, data: DataLoaderType) -> InfluenceFunctionModel: + """ + Fitting to provided data, by internally creating a block mapper instance from + it. + Args: + data: iterable of tensors + + Returns: + Fitted instance + """ + self.block_mapper = self._create_block_mapper(data) + return self + + @abstractmethod + def _create_block_mapper(self, data: DataLoaderType) -> BlockMapperType: + """ + Override this method to create a block mapper instance, that can be used + to compute block-wise influence quantities. + + Args: + data: iterable of tensors + + Returns: + BlockMapper instance + """ + pass + + @InfluenceFunctionModel.fit_required + def influences_by_block( + self, + x_test: TensorType, + y_test: TensorType, + x: Optional[TensorType] = None, + y: Optional[TensorType] = None, + mode: InfluenceMode = InfluenceMode.Up, + ) -> OrderedDict[str, TensorType]: + r""" + Compute the block-wise influence values for the provided data, i.e. an + approximation of + + \[ \langle H^{-1}\nabla_{theta} \ell(y_{\text{test}}, + f_{\theta}(x_{\text{test}})), + \nabla_{\theta} \ell(y, f_{\theta}(x)) \rangle \] + + for the case of up-weighting influence, resp. + + \[ \langle H^{-1}\nabla_{theta} \ell(y_{test}, f_{\theta}(x_{test})), + \nabla_{x} \nabla_{\theta} \ell(y, f_{\theta}(x)) \rangle \] + + for the perturbation type influence case. + + Args: + x_test: model input to use in the gradient computations + of the approximation of + $H^{-1}\nabla_{theta} \ell(y_{test}, f_{\theta}(x_{test}))$ + y_test: label tensor to compute gradients + x: optional model input to use in the gradient computations + $\nabla_{theta}\ell(y, f_{\theta}(x))$, + resp. $\nabla_{x}\nabla_{theta}\ell(y, f_{\theta}(x))$, + if None, use $x=x_{test}$ + y: optional label tensor to compute gradients + mode: enum value of [InfluenceMode] + [pydvl.influence.base_influence_function_model.InfluenceMode] + + Returns: + Ordered dictionary of tensors representing the element-wise scalar products + for the provided batch per block. + + """ + left_batch = self._create_batch(x_test, y_test) + + if x is None: + if y is not None: + raise ValueError( + "Providing labels y, without providing model input x " + "is not supported" + ) + right_batch = left_batch + else: + if y is None: + raise ValueError( + "Providing model input x, without providing labels y " + "is not supported" + ) + right_batch = self._create_batch(x, y) + + return self.block_mapper.interactions(left_batch, right_batch, mode) + + @InfluenceFunctionModel.fit_required + def influence_factors_by_block( + self, x: TensorType, y: TensorType + ) -> OrderedDict[str, TensorType]: + r""" + Compute the block-wise approximation of + + \[ H^{-1}\nabla_{\theta} \ell(y, f_{\theta}(x)) \] + + where the gradient is meant to be per sample of the batch $(x, y)$. + + Args: + x: model input to use in the gradient computations + y: label tensor to compute gradients + + Returns: + Ordered dictionary of tensors representing the element-wise + approximate inverse Hessian matrix vector products per block. 
+ + """ + return self.block_mapper.transformed_grads(self._create_batch(x, y)) + + @InfluenceFunctionModel.fit_required + def influences_from_factors_by_block( + self, + z_test_factors: OrderedDict[str, TensorType], + x: TensorType, + y: TensorType, + mode: InfluenceMode = InfluenceMode.Up, + ) -> OrderedDict[str, TensorType]: + r""" + Block-wise computation of + + \[ \langle z_{\text{test_factors}}, + \nabla_{\theta} \ell(y, f_{\theta}(x)) \rangle \] + + for the case of up-weighting influence, resp. + + \[ \langle z_{\text{test_factors}}, + \nabla_{x} \nabla_{\theta} \ell(y, f_{\theta}(x)) \rangle \] + + for the perturbation type influence case. The gradient is meant to be per sample + of the batch $(x, y)$. + + Args: + z_test_factors: pre-computed array, approximating + $H^{-1}\nabla_{\theta} \ell(y_{\text{test}}, + f_{\theta}(x_{\text{test}}))$ + x: model input to use in the gradient computations + $\nabla_{\theta}\ell(y, f_{\theta}(x))$, + resp. $\nabla_{x}\nabla_{\theta}\ell(y, f_{\theta}(x))$, + if None, use $x=x_{\text{test}}$ + y: label tensor to compute gradients + mode: enum value of [InfluenceMode] + [pydvl.influence.base_influence_function_model.InfluenceMode] + + Returns: + Ordered dictionary of tensors representing the element-wise scalar products + for the provided batch per block + + """ + return self.block_mapper.interactions_from_transformed_grads( + z_test_factors, self._create_batch(x, y), mode + ) + + def _influence_factors(self, x: TensorType, y: TensorType) -> TensorType: + transformed_grads = self.influence_factors_by_block(x, y) + transformed_grads = ( + self._flatten_trailing_dim(t) for t in transformed_grads.values() + ) + return cast(TensorType, self._concat(transformed_grads, dim=-1)) + + @abstractmethod + def _concat(self, tensors: Iterable[TensorType], dim: int): + """Implement this to concat tensors at a specified dimension""" + + @abstractmethod + def _flatten_trailing_dim(self, tensor: TensorType): + """Implement this to flatten all but the first dimension""" + + def _influences( + self, + x_test: TensorType, + y_test: TensorType, + x: Optional[TensorType] = None, + y: Optional[TensorType] = None, + mode: InfluenceMode = InfluenceMode.Up, + ) -> TensorType: + left_batch = self._create_batch(x_test, y_test) + + if x is None: + right_batch = None + elif y is None: + raise ValueError( + "Providing model input x, without providing labels y " + "is not supported" + ) + else: + right_batch = self._create_batch(x, y) + + tensors = self.block_mapper.generate_interactions(left_batch, right_batch, mode) + return cast(TensorType, sum(tensors)) + + @InfluenceFunctionModel.fit_required + def influences_from_factors( + self, + z_test_factors: TensorType, + x: TensorType, + y: TensorType, + mode: InfluenceMode = InfluenceMode.Up, + ) -> TensorType: + r""" + Computation of + + \[ \langle z_{\text{test_factors}}, + \nabla_{\theta} \ell(y, f_{\theta}(x)) \rangle \] + + for the case of up-weighting influence, resp. + + \[ \langle z_{\text{test_factors}}, + \nabla_{x} \nabla_{\theta} \ell(y, f_{\theta}(x)) \rangle \] + + for the perturbation type influence case. The gradient is meant to be per sample + of the batch $(x, y)$. + + Args: + z_test_factors: pre-computed array, approximating + $H^{-1}\nabla_{\theta} \ell(y_{\text{test}}, + f_{\theta}(x_{\text{test}}))$ + x: model input to use in the gradient computations + $\nabla_{\theta}\ell(y, f_{\theta}(x))$, + resp. 
$\nabla_{x}\nabla_{\theta}\ell(y, f_{\theta}(x))$, + if None, use $x=x_{\text{test}}$ + y: label tensor to compute gradients + mode: enum value of [InfluenceMode] + [pydvl.influence.base_influence_function_model.InfluenceMode] + + Returns: + Tensor representing the element-wise scalar products for the provided batch + + """ + tensors = self.block_mapper.generate_interactions_from_transformed_grads( + z_test_factors, + self._create_batch(x, y), + mode, + ) + return cast(TensorType, sum(tensors)) + + @staticmethod + @abstractmethod + def _create_batch(x: TensorType, y: TensorType) -> BatchType: + """Implement this method to provide the creation of a subtype of + [Batch][pydvl.influence.types.Batch] for a specific framework + """ diff --git a/src/pydvl/influence/influence_calculator.py b/src/pydvl/influence/influence_calculator.py index 7c48e8636..327a4137d 100644 --- a/src/pydvl/influence/influence_calculator.py +++ b/src/pydvl/influence/influence_calculator.py @@ -18,10 +18,9 @@ from .array import LazyChunkSequence, NestedLazyChunkSequence, NumpyConverter from .base_influence_function_model import ( InfluenceFunctionModel, - InfluenceMode, - TensorType, UnsupportedInfluenceModeException, ) +from .types import InfluenceMode, TensorType __all__ = [ "DaskInfluenceCalculator", diff --git a/src/pydvl/influence/torch/__init__.py b/src/pydvl/influence/torch/__init__.py index 3bbd9552c..9b2299d0b 100644 --- a/src/pydvl/influence/torch/__init__.py +++ b/src/pydvl/influence/torch/__init__.py @@ -3,7 +3,9 @@ CgInfluence, DirectInfluence, EkfacInfluence, + InverseHarmonicMeanInfluence, LissaInfluence, NystroemSketchInfluence, ) from .pre_conditioner import JacobiPreConditioner, NystroemPreConditioner +from .util import BlockMode diff --git a/src/pydvl/influence/torch/base.py b/src/pydvl/influence/torch/base.py new file mode 100644 index 000000000..65b6d4f8b --- /dev/null +++ b/src/pydvl/influence/torch/base.py @@ -0,0 +1,717 @@ +from __future__ import annotations + +from abc import ABC, abstractmethod +from collections import OrderedDict +from dataclasses import dataclass +from typing import Dict, Iterable, List, Optional, Tuple, TypeVar, Union, cast + +import torch +from torch.func import functional_call +from torch.utils.data import DataLoader + +from ..base_influence_function_model import ComposableInfluence +from ..types import ( + Batch, + BilinearForm, + BlockMapper, + GradientProvider, + Operator, + OperatorGradientComposition, + TensorType, +) +from .util import ( + BlockMode, + LossType, + ModelInfoMixin, + ModelParameterDictBuilder, + align_structure, + flatten_dimensions, +) + + +@dataclass(frozen=True) +class TorchBatch(Batch): + """ + A convenience class for handling batches of data. Validates, the alignment + of the first dimension (batch dimension) of the input and target tensor + + Attributes: + x: The input tensor that contains features or data points. + y: The target tensor that contains labels corresponding to the inputs. 
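+
+    Example:
+        A minimal sketch with hypothetical shapes:
+
+        ```python
+        import torch
+
+        batch = TorchBatch(torch.randn(8, 3), torch.randint(0, 2, (8,)))
+        x, y = batch  # TorchBatch is iterable
+        assert len(batch) == 8
+        # mismatched first (batch) dimensions raise a ValueError
+        ```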
+ + """ + + x: torch.Tensor + y: torch.Tensor + + def __iter__(self): + return iter((self.x, self.y)) + + def __post_init__(self): + if self.x.shape[0] != self.y.shape[0]: + raise ValueError( + f"The first dimension of x and y must be the same, " + f"got {self.x.shape[0]} and {self.y.shape[0]}" + ) + + def __len__(self): + return self.x.shape[0] + + def to(self, device: torch.device): + return TorchBatch(self.x.to(device), self.y.to(device)) + + +class TorchGradientProvider(GradientProvider[TorchBatch, torch.Tensor]): + r""" + Compute per-sample gradients of a function defined by + a [torch.nn.Module][torch.nn.Module] and a loss function using + [torch.func][torch.func]. + + Consider a function + + $$ \ell: \mathbb{R}^{d_1} \times \mathbb{R}^{d_2} \times \mathbb{R}^{n} \times + \mathbb{R}^{n}, \quad \ell(\omega_1, \omega_2, x, y) = + \operatorname{loss}(f(\omega_1, \omega_2; x), y) $$ + + e.g. a two layer neural network $f$ with a loss function, then this object should + compute the expressions: + + $$ \nabla_{\omega_{i}}\ell(\omega_1, \omega_2, x, y), + \nabla_{\omega_{i}}\nabla_{x}\ell(\omega_1, \omega_2, x, y), + \nabla_{\omega}\ell(\omega_1, \omega_2, x, y) \cdot v$$ + + """ + + def __init__( + self, + model: torch.nn.Module, + loss: LossType, + restrict_to: Optional[Dict[str, torch.nn.Parameter]], + ): + self.model = model + self.loss = loss + + if restrict_to is None: + restrict_to = ModelParameterDictBuilder(model).build_from_block_mode( + BlockMode.FULL + ) + + self.params_to_restrict_to = restrict_to + + def _compute_loss( + self, params: Dict[str, torch.Tensor], x: torch.Tensor, y: torch.Tensor + ) -> torch.Tensor: + outputs = functional_call(self.model, params, (x.unsqueeze(0).to(self.device),)) + return self.loss(outputs, y.unsqueeze(0)) + + def _grads(self, batch: TorchBatch) -> Dict[str, torch.Tensor]: + result: Dict[str, torch.Tensor] = torch.vmap( + torch.func.grad(self._compute_loss), in_dims=(None, 0, 0) + )(self.params_to_restrict_to, batch.x, batch.y) + return result + + def _mixed_grads(self, batch: TorchBatch) -> Dict[str, torch.Tensor]: + result: Dict[str, torch.Tensor] = torch.vmap( + torch.func.jacrev(torch.func.grad(self._compute_loss, argnums=1)), + in_dims=(None, 0, 0), + )(self.params_to_restrict_to, batch.x, batch.y) + return result + + def _jacobian_prod( + self, + batch: TorchBatch, + g: torch.Tensor, + ) -> torch.Tensor: + def single_jvp( + _g: torch.Tensor, + ): + return torch.func.jvp( + lambda p: torch.vmap(self._compute_loss, in_dims=(None, 0, 0))( + p, *batch + ), + (self.params_to_restrict_to,), + (align_structure(self.params_to_restrict_to, _g),), + )[1] + + return torch.func.vmap(single_jvp)(g) + + def to(self, device: torch.device): + self.model = self.model.to(device) + self.params_to_restrict_to = { + k: p.detach() + for k, p in self.model.named_parameters() + if k in self.params_to_restrict_to + } + return self + + @property + def device(self): + return next(self.model.parameters()).device + + @property + def dtype(self): + return next(self.model.parameters()).dtype + + @staticmethod + def _detach_dict(tensor_dict: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]: + return {k: g.detach() if g.requires_grad else g for k, g in tensor_dict.items()} + + def grads(self, batch: TorchBatch) -> Dict[str, torch.Tensor]: + r""" + Computes and returns a dictionary mapping parameter names to their respective + per-sample gradients. 
Given the example in the class docstring, this means + + $$ \text{result}[\omega_i] = \nabla_{\omega_{i}}\ell(\omega_1, \omega_2, + \text{batch.x}, \text{batch.y}), $$ + + where the first dimension of the resulting tensors is always considered to be + the batch dimension, so the shape of the resulting tensors are $(N, d_i)$, + where $N$ is the number of samples in the batch. + + Args: + batch: The batch of data for which to compute gradients. + + Returns: + A dictionary where keys are gradient identifiers and values are the + gradients computed per sample. + """ + gradient_dict = self._grads(batch.to(self.device)) + return self._detach_dict(gradient_dict) + + def mixed_grads(self, batch: TorchBatch) -> Dict[str, torch.Tensor]: + r""" + Computes and returns a dictionary mapping gradient names to their respective + per-sample mixed gradients. In this context, mixed gradients refer to computing + gradients with respect to the instance definition in addition to + compute derivatives with respect to the input batch. + Given the example in the class docstring, this means + + $$ \text{result}[\omega_i] = \nabla_{\omega_{i}}\nabla_{x}\ell(\omega_1, + \omega_2, \text{batch.x}, \text{batch.y}), $$ + + where the first dimension of the resulting tensors is always considered to be + the batch dimension and the last to be the non-batch input related derivatives. + So the shape of the resulting tensors are $(N, n, d_i)$, + where $N$ is the number of samples in the batch. + + Args: + batch: The batch of data for which to compute mixed gradients. + + Returns: + A dictionary where keys are gradient identifiers and values are the + mixed gradients computed per sample. + """ + gradient_dict = self._mixed_grads(batch.to(self.device)) + return self._detach_dict(gradient_dict) + + def jacobian_prod( + self, + batch: TorchBatch, + g: torch.Tensor, + ) -> torch.Tensor: + r""" + Computes the matrix-Jacobian product for the provided batch and input tensor. + Given the example in the class docstring, this means + + $$ (\nabla_{\omega_{1}}\ell(\omega_1, \omega_2, + \text{batch.x}, \text{batch.y}), + \nabla_{\omega_{2}}\ell(\omega_1, \omega_2, + \text{batch.x}, \text{batch.y})) \cdot g^T$$ + + where g must be a tensor of shape $(K, d_1+d_2)$, so the resulting tensor + is of shape $(N, K)$. + + Args: + batch: The batch of data for which to compute the Jacobian. + g: The tensor to be used in the matrix-Jacobian product + calculation. + + Returns: + The resulting tensor from the matrix-Jacobian product computation. + """ + result = self._jacobian_prod(batch.to(self.device), g.to(self.device)) + if result.requires_grad: + result = result.detach() + return result + + def flat_grads(self, batch: TorchBatch) -> torch.Tensor: + return flatten_dimensions( + self.grads(batch).values(), shape=(batch.x.shape[0], -1) + ) + + def flat_mixed_grads(self, batch: TorchBatch) -> torch.Tensor: + shape = (*batch.x.shape, -1) + return flatten_dimensions(self.mixed_grads(batch).values(), shape=shape) + + +class OperatorBilinearForm( + BilinearForm[torch.Tensor, TorchBatch, TorchGradientProvider] +): + r""" + Base class for bilinear forms based on an instance of + [TorchOperator][pydvl.influence.torch.operator.base.TorchOperator]. 
This means it + computes weighted inner products of the form: + + $$ \langle \operatorname{Op}(x), y \rangle $$ + + """ + + def __init__( + self, + operator: "TensorOperator", + ): + self.operator = operator + + def inner_prod( + self, left: torch.Tensor, right: Optional[torch.Tensor] + ) -> torch.Tensor: + r""" + Computes the weighted inner product of two vectors, i.e. + + $$ \langle \operatorname{Op}(\text{left}), \text{right} \rangle $$ + + Args: + left: The first tensor in the inner product computation. + right: The second tensor, optional; if not provided, the inner product will + use `left` tensor for both arguments. + + Returns: + A tensor representing the inner product. + """ + if right is None: + right = left + if left.shape[0] <= right.shape[0]: + return self._inner_product(left, right) + return self._inner_product(right, left).T + + def _inner_product(self, left: torch.Tensor, right: torch.Tensor) -> torch.Tensor: + left_result = self.operator.apply(left) + + if left_result.ndim == right.ndim and left.shape[-1] == right.shape[-1]: + return left_result @ right.T + + return torch.einsum("ia,j...a->ij...", left_result, right) + + +class DictBilinearForm(OperatorBilinearForm): + r""" + Base class for bilinear forms based on an instance of + [TorchOperator][pydvl.influence.torch.operator.base.TorchOperator]. This means it + computes weighted inner products of the form: + + $$ \langle \operatorname{Op}(x), y \rangle $$ + + """ + + def __init__( + self, + operator: "TensorDictOperator", + ): + super().__init__(operator) + + def grads_inner_prod( + self, + left: TorchBatch, + right: Optional[TorchBatch], + gradient_provider: TorchGradientProvider, + ) -> torch.Tensor: + r""" + Computes the gradient inner product of two batches of data, i.e. + + $$ \langle \nabla_{\omega}\ell(\omega, \text{left.x}, \text{left.y}), + \nabla_{\omega}\ell(\omega, \text{right.x}, \text{right.y}) \rangle_{B}$$ + + where $\nabla_{\omega}\ell(\omega, \cdot, \cdot)$ is represented by the + `gradient_provider` and the expression must be understood sample-wise. + + Args: + left: The first batch for gradient and inner product computation + right: The second batch for gradient and inner product computation, + optional; if not provided, the inner product will use the gradient + computed for `left` for both arguments. + gradient_provider: The gradient provider to compute the gradients. + + Returns: + A tensor representing the inner products of the per-sample gradients + """ + operator = cast(TensorDictOperator, self.operator) + left_grads = gradient_provider.grads(left) + if right is None: + right_grads = left_grads + else: + right_grads = gradient_provider.grads(right) + + left_batch_size, right_batch_size = next( + ( + (l.shape[0], r.shape[0]) + for r, l in zip(left_grads.values(), right_grads.values()) + ) + ) + + if left_batch_size <= right_batch_size: + left_grads = operator.apply_to_dict(left_grads) + tensor_pairs = zip(left_grads.values(), right_grads.values()) + else: + right_grads = operator.apply_to_dict(right_grads) + tensor_pairs = zip(left_grads.values(), right_grads.values()) + + tensors_to_reduce = ( + self._aggregate_grads(left, right) for left, right in tensor_pairs + ) + + return cast(torch.Tensor, sum(tensors_to_reduce)) + + def mixed_grads_inner_prod( + self, + left: TorchBatch, + right: TorchBatch, + gradient_provider: TorchGradientProvider, + ) -> torch.Tensor: + r""" + Computes the mixed gradient inner product of two batches of data, i.e. 
+ + $$ \langle \nabla_{\omega}\ell(\omega, \text{left.x}, \text{left.y}), + \nabla_{\omega}\nabla_{x}\ell(\omega, \text{right.x}, \text{right.y}) + \rangle_{B}$$ + + where $\nabla_{\omega}\ell(\omega, \cdot)$ and + $\nabla_{\omega}\nabla_{x}\ell(\omega, \cdot)$ are represented by the + `gradient_provider`. The expression must be understood sample-wise. + + Args: + left: The first batch for gradient and inner product computation + right: The second batch for gradient and inner product computation + gradient_provider: The gradient provider to compute the gradients. + + Returns: + A tensor representing the inner products of the mixed per-sample gradients + """ + operator = cast(TensorDictOperator, self.operator) + right_grads = gradient_provider.mixed_grads(right) + left_grads = gradient_provider.grads(left) + left_grads = operator.apply_to_dict(left_grads) + left_grads_views = (t.reshape(t.shape[0], -1) for t in left_grads.values()) + right_grads_views = ( + t.reshape(*right.x.shape, -1) for t in right_grads.values() + ) + tensor_pairs = zip(left_grads_views, right_grads_views) + tensors_to_reduce = ( + self._aggregate_mixed_grads(left, right) for left, right in tensor_pairs + ) + return cast(torch.Tensor, sum(tensors_to_reduce)) + + @staticmethod + def _aggregate_mixed_grads(left: torch.Tensor, right: torch.Tensor): + return torch.einsum("ik, j...k -> ij...", left, right) + + @staticmethod + def _aggregate_grads(left: torch.Tensor, right: torch.Tensor): + return torch.einsum("i..., j... -> ij", left, right) + + +OperatorBilinearFormType = TypeVar( + "OperatorBilinearFormType", bound=OperatorBilinearForm +) + + +class TensorOperator(Operator[torch.Tensor, OperatorBilinearForm], ABC): + """ + Abstract base class for operators that can be applied to instances of + [torch.Tensor][torch.Tensor]. + """ + + @property + @abstractmethod + def device(self): + pass + + @property + @abstractmethod + def dtype(self): + pass + + @abstractmethod + def to(self, device: torch.device): + pass + + def _validate_tensor_input(self, tensor: torch.Tensor) -> None: + if not (1 <= tensor.ndim <= 2): + raise ValueError( + f"Expected a 1 or 2 dimensional tensor, got {tensor.ndim} dimensions." + ) + if tensor.shape[-1] != self.input_size: + raise ValueError( + f"Expected the last dimension to be of size {self.input_size}." + ) + + def _apply(self, tensor: torch.Tensor) -> torch.Tensor: + + if tensor.ndim == 2 and tensor.shape[0] > 1: + return self._apply_to_mat(tensor.to(self.device)) + + return self._apply_to_vec(tensor.to(self.device)) + + @abstractmethod + def _apply_to_vec(self, vec: torch.Tensor) -> torch.Tensor: + """ + Applies the operator to a single vector. + Args: + vec: A single vector consistent to the operator, i.e. it's length + must be equal to the property `input_size`. + + Returns: + A single vector after applying the batch operation + """ + + def _apply_to_mat(self, mat: torch.Tensor) -> torch.Tensor: + """ + Applies the operator to a matrix. + Args: + mat: A matrix to apply the operator to. The last dimension is + assumed to be consistent to the operation, i.e. it must equal + to the property `input_size`. 
+ + Returns: + A matrix of shape $(N, \text{input_size})$, given the shape of mat is + $(N, \text{input_size})$ + + """ + return torch.func.vmap(self._apply_to_vec, in_dims=0, randomness="same")(mat) + + def as_bilinear_form(self) -> OperatorBilinearForm: + return OperatorBilinearForm(self) + + +class TensorDictOperator(TensorOperator, ABC): + """ + Abstract base class for operators that can be applied to instances of + [torch.Tensor][torch.Tensor] and compatible dictionaries mapping strings to tensors. + Input dictionaries must conform to the structure defined by the property + `input_dict_structure`. Useful for operators involving autograd functionality + to avoid intermediate flattening and concatenating of gradient inputs. + """ + + def apply_to_dict(self, mat: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]: + """ + Applies the operator to a dictionary of tensors, compatible to the structure + defined by the property `input_dict_structure`. + + Args: + mat: dictionary of tensors, whose keys and shapes match the property + `input_dict_structure`. + + Returns: + A dictionary of tensors after applying the operator + """ + + if not self._validate_mat_dict(mat): + raise ValueError( + f"Incompatible input structure, expected (excluding batch" + f"dimension): \n {self.input_dict_structure}" + ) + + return self._apply_to_dict(self._dict_to_device(mat)) + + def _dict_to_device(self, mat: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]: + return {k: v.to(self.device) for k, v in mat.items()} + + @property + @abstractmethod + def input_dict_structure(self) -> Dict[str, Tuple[int, ...]]: + """ + Implement this to expose the expected structure of the input tensor dict, i.e. + a dictionary of shapes (excluding the first batch dimension), in order + to validate the input tensor dicts. + """ + + @abstractmethod + def _apply_to_dict(self, mat: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]: + pass + + def _validate_mat_dict(self, mat: Dict[str, torch.Tensor]) -> bool: + for keys, val in mat.items(): + if val.shape[1:] != self.input_dict_structure[keys]: + return False + else: + return True + + def as_bilinear_form(self) -> DictBilinearForm: + return DictBilinearForm(self) + + +TorchOperatorType = TypeVar("TorchOperatorType", bound=TensorOperator) + + +class TorchOperatorGradientComposition( + OperatorGradientComposition[ + torch.Tensor, TorchBatch, TorchOperatorType, TorchGradientProvider + ] +): + """ + Representing a composable block that integrates an [TorchOperator] + [pydvl.influence.torch.operator.base.TorchOperator] and + a [TorchPerSampleGradientProvider] + [pydvl.influence.torch.operator.gradient_provider.TorchPerSampleGradientProvider] + + This block is designed to be flexible, handling different computational modes via + an abstract operator and gradient provider. + """ + + def __init__(self, op: TorchOperatorType, gp: TorchGradientProvider): + super().__init__(op, gp) + + def to(self, device: torch.device): + self.gp = self.gp.to(device) + self.op = self.op.to(device) + return self + + +class TorchBlockMapper( + BlockMapper[ + torch.Tensor, TorchBatch, TorchOperatorGradientComposition[TorchOperatorType] + ] +): + """ + Class for mapping operations across multiple compositional blocks represented by + instances of [TorchOperatorGradientComposition] + [pydvl.influence.torch.influence_function_model.TorchOperatorGradientComposition]. + + This class takes a dictionary of compositional blocks and applies their methods to + batches or tensors, and aggregates the results. 
+    """
+
+    def __init__(
+        self, composable_block_dict: OrderedDict[str, TorchOperatorGradientComposition]
+    ):
+        super().__init__(composable_block_dict)
+
+    def _split_to_blocks(
+        self, z: torch.Tensor, dim: int = -1
+    ) -> OrderedDict[str, torch.Tensor]:
+        block_sizes = [bi.op.input_size for bi in self.composable_block_dict.values()]
+
+        block_dict = OrderedDict(
+            zip(
+                list(self.composable_block_dict.keys()),
+                torch.split(z, block_sizes, dim=dim),
+            )
+        )
+        return block_dict
+
+    def to(self, device: torch.device):
+        self.composable_block_dict = OrderedDict(
+            [(k, bi.to(device)) for k, bi in self.composable_block_dict.items()]
+        )
+        return self
+
+
+class TorchComposableInfluence(
+    ComposableInfluence[
+        torch.Tensor, TorchBatch, DataLoader, TorchBlockMapper[TorchOperatorType]
+    ],
+    ModelInfoMixin,
+    ABC,
+):
+    """
+    Abstract base class that allows for block-wise computation of influence
+    quantities with the [torch][torch] framework.
+    Inherit from this base class for specific influence algorithms.
+    """
+
+    def __init__(
+        self,
+        model: torch.nn.Module,
+        block_structure: Union[BlockMode, OrderedDict[str, List[str]]] = BlockMode.FULL,
+        regularization: Optional[Union[float, Dict[str, Optional[float]]]] = None,
+    ):
+        parameter_dict_builder = ModelParameterDictBuilder(model)
+        if isinstance(block_structure, BlockMode):
+            self.parameter_dict = parameter_dict_builder.build_from_block_mode(
+                block_structure
+            )
+        else:
+            self.parameter_dict = parameter_dict_builder.build(block_structure)
+
+        self._regularization_dict = self._build_regularization_dict(regularization)
+
+        super().__init__(model)
+
+    def _concat(self, tensors: Iterable[torch.Tensor], dim: int):
+        return torch.cat(list(tensors), dim=dim)
+
+    def _flatten_trailing_dim(self, tensor: torch.Tensor):
+        return tensor.reshape((tensor.shape[0], -1))
+
+    @property
+    def block_names(self) -> List[str]:
+        return list(self.parameter_dict.keys())
+
+    @property
+    def n_parameters(self):
+        return sum(block.op.input_size for _, block in self.block_mapper.items())
+
+    @abstractmethod
+    def with_regularization(
+        self, regularization: Union[float, Dict[str, Optional[float]]]
+    ) -> TorchComposableInfluence:
+        pass
+
+    def _build_regularization_dict(
+        self, regularization: Optional[Union[float, Dict[str, Optional[float]]]]
+    ) -> Dict[str, Optional[float]]:
+        if regularization is None or isinstance(regularization, float):
+            return {
+                k: self._validate_regularization(k, regularization)
+                for k in self.block_names
+            }
+
+        if not set(regularization.keys()).issubset(set(self.block_names)):
+            raise ValueError(
+                f"The regularization must be a float or the keys of the "
+                f"regularization dictionary must match a subset of the "
+                f"block names: \n {self.block_names}.\n Not found in block names: \n"
+                f"{set(regularization.keys()).difference(set(self.block_names))}"
+            )
+        return {
+            k: self._validate_regularization(k, regularization.get(k, None))
+            for k in self.block_names
+        }
+
+    @staticmethod
+    def _validate_regularization(
+        block_name: str, value: Optional[float]
+    ) -> Optional[float]:
+        if isinstance(value, float) and value < 0.0:
+            raise ValueError(
+                f"The regularization for block '{block_name}' must be non-negative, "
+                f"but found {value=}"
+            )
+        return value
+
+    @abstractmethod
+    def _create_block(
+        self,
+        block_params: Dict[str, torch.nn.Parameter],
+        data: DataLoader,
+        regularization: Optional[float],
+    ) -> TorchOperatorGradientComposition:
+        pass
+
+    def _create_block_mapper(self, data: DataLoader) -> TorchBlockMapper:
+        
block_influence_dict = OrderedDict() + for k, p in self.parameter_dict.items(): + reg = self._regularization_dict.get(k, None) + reg = self._validate_regularization(k, reg) + block_influence_dict[k] = self._create_block(p, data, reg).to(self.device) + + return TorchBlockMapper(block_influence_dict) + + @staticmethod + def _create_batch(x: torch.Tensor, y: torch.Tensor) -> TorchBatch: + return TorchBatch(x, y) + + def to(self, device: torch.device): + self.model = self.model.to(device) + if hasattr(self, "block_mapper") and self.block_mapper is not None: + self.block_mapper = self.block_mapper.to(device) + return self diff --git a/src/pydvl/influence/torch/batch_operation.py b/src/pydvl/influence/torch/batch_operation.py new file mode 100644 index 000000000..908da3ef3 --- /dev/null +++ b/src/pydvl/influence/torch/batch_operation.py @@ -0,0 +1,598 @@ +r""" +This module contains abstractions and implementations for operations carried out on a +batch $b$. These operations are of the form + +$$ m(b) \cdot v$$, + +where $m(b)$ is a matrix defined by the data in the batch and $v$ is a vector or matrix. +These batch operations can be used to conveniently build aggregations or recursions +over sequence of batches, e.g. an average of the form + +$$ \frac{1}{|B|} \sum_{b in B}m(b)\cdot v$$, + +which is useful in the case that keeping $B$ in memory is not feasible. + +""" +from __future__ import annotations + +from abc import ABC, abstractmethod +from typing import Callable, Dict, Generator, Generic, List, Optional, Tuple, TypeVar + +import torch + +from .base import TorchBatch, TorchGradientProvider +from .functional import create_batch_hvp_function, create_batch_loss_function, hvp +from .util import LossType + + +class _ModelBasedBatchOperation(ABC): + r""" + Abstract base class to implement operations of the form + + $$ m(\text{model}, b) \cdot v $$ + + where model is a [torch.nn.Module][torch.nn.Module]. + + """ + + def __init__( + self, + model: torch.nn.Module, + restrict_to: Optional[Dict[str, torch.nn.Parameter]] = None, + ): + if restrict_to is None: + restrict_to = { + k: p.detach() for k, p in model.named_parameters() if p.requires_grad + } + self.params_to_restrict_to = restrict_to + self.model = model + + @property + def input_dict_structure(self) -> Dict[str, Tuple[int, ...]]: + return {k: p.shape for k, p in self.params_to_restrict_to.items()} + + @property + def device(self): + return next(self.model.parameters()).device + + @property + def dtype(self): + return next(self.model.parameters()).dtype + + @property + def input_size(self): + return sum(p.numel() for p in self.params_to_restrict_to.values()) + + def to(self, device: torch.device): + self.model = self.model.to(device) + self.params_to_restrict_to = { + k: p.detach() + for k, p in self.model.named_parameters() + if k in self.params_to_restrict_to + } + return self + + def apply_to_dict( + self, batch: TorchBatch, mat_dict: Dict[str, torch.Tensor] + ) -> Dict[str, torch.Tensor]: + + if mat_dict.keys() != self.params_to_restrict_to.keys(): + raise ValueError( + "The keys of the matrix dictionary must match the keys of the " + "parameters to restrict to." 
+ ) + + return self._apply_to_dict( + batch, {k: v.to(self.device) for k, v in mat_dict.items()} + ) + + def _has_batch_dim_dict(self, tensor_dict: Dict[str, torch.Tensor]): + batch_dim_flags = [ + tensor_dict[key].shape == val.shape + for key, val in self.params_to_restrict_to.items() + ] + if len(set(batch_dim_flags)) == 2: + raise ValueError("Existence of batch dim must be consistent") + return not all(batch_dim_flags) + + def _add_batch_dim(self, vec_dict: Dict[str, torch.Tensor]): + result = {} + for key, value in self.params_to_restrict_to.items(): + if value.shape == vec_dict[key].shape: + result[key] = vec_dict[key].unsqueeze(0) + else: + result[key] = vec_dict[key] + return result + + @abstractmethod + def _apply_to_dict( + self, batch: TorchBatch, mat_dict: Dict[str, torch.Tensor] + ) -> Dict[str, torch.Tensor]: + pass + + @abstractmethod + def _apply_to_vec(self, batch: TorchBatch, vec: torch.Tensor) -> torch.Tensor: + pass + + def apply(self, batch: TorchBatch, tensor: torch.Tensor): + """ + Applies the batch operation to a tensor. + Args: + batch: Batch of data for computation + tensor: A tensor consistent to the operation, i.e. it must be + at most 2-dim, and it's tailing dimension must + be equal to the property `input_size`. + + Returns: + A tensor after applying the batch operation + """ + + if not tensor.ndim <= 2: + raise ValueError( + f"The input tensor must be at most 2-dimensional, got {tensor.ndim}" + ) + + if tensor.shape[-1] != self.input_size: + raise ValueError( + "The last dimension of the input tensor must be equal to the " + "property `input_size`." + ) + + if tensor.ndim == 2: + return self._apply_to_mat(batch.to(self.device), tensor.to(self.device)) + return self._apply_to_vec(batch.to(self.device), tensor.to(self.device)) + + def _apply_to_mat(self, batch: TorchBatch, mat: torch.Tensor) -> torch.Tensor: + """ + Applies the batch operation to a matrix. + Args: + batch: Batch of data for computation + mat: A matrix to apply the batch operation to. The last dimension is + assumed to be consistent to the operation, i.e. it must equal + to the property `input_size`. + + Returns: + A matrix of shape $(N, \text{input_size})$, given the shape of mat is + $(N, \text{input_size})$ + + """ + return torch.func.vmap( + lambda _x, _y, m: self._apply_to_vec(TorchBatch(_x, _y), m), + in_dims=(None, None, 0), + randomness="same", + )(batch.x, batch.y, mat) + + +class HessianBatchOperation(_ModelBasedBatchOperation): + r""" + Given a model and loss function computes the Hessian vector or matrix product + with respect to the model parameters, i.e. + + \begin{align*} + &\nabla^2_{\theta} L(b;\theta) \cdot v \\\ + &L(b;\theta) = \left( \frac{1}{|b|} \sum_{(x,y) \in b} + \text{loss}(\text{model}(x; \theta), y)\right), + \end{align*} + + where model is a [torch.nn.Module][torch.nn.Module] and $v$ is a vector or matrix. + + Args: + model: The model. + loss: The loss function. + restrict_to: The parameters to restrict the second order differentiation to, + i.e. the corresponding sub-matrix of the Hessian. If None, the full Hessian + is used. Make sure the input matches the corrct dimension, i.e. the + last dimension must be equal to the property `input_size`. 
+ """ + + def __init__( + self, + model: torch.nn.Module, + loss: LossType, + restrict_to: Optional[Dict[str, torch.nn.Parameter]] = None, + ): + super().__init__(model, restrict_to=restrict_to) + self._batch_hvp = create_batch_hvp_function(model, loss, reverse_only=True) + self.loss = loss + + def _apply_to_vec(self, batch: TorchBatch, vec: torch.Tensor) -> torch.Tensor: + return self._batch_hvp(self.params_to_restrict_to, batch.x, batch.y, vec) + + def _apply_to_dict( + self, batch: TorchBatch, mat_dict: Dict[str, torch.Tensor] + ) -> Dict[str, torch.Tensor]: + + func = self._create_seq_func(*batch) + + if self._has_batch_dim_dict(mat_dict): + func = torch.func.vmap( + func, in_dims=tuple((0 for _ in self.params_to_restrict_to)) + ) + + result: Dict[str, torch.Tensor] = func(*mat_dict.values()) + return result + + def _create_seq_func(self, x: torch.Tensor, y: torch.Tensor): + def seq_func(*vec: torch.Tensor) -> Dict[str, torch.Tensor]: + return hvp( + lambda p: create_batch_loss_function(self.model, self.loss)(p, x, y), + self.params_to_restrict_to, + dict(zip(self.params_to_restrict_to.keys(), vec)), + reverse_only=True, + ) + + return seq_func + + +class GaussNewtonBatchOperation(_ModelBasedBatchOperation): + r""" + Given a model and loss function computes the Gauss-Newton vector or matrix product + with respect to the model parameters, i.e. + + \begin{align*} + G(\text{model}, \text{loss}, b, \theta) &\cdot v, \\\ + G(\text{model}, \text{loss}, b, \theta) &= + \frac{1}{|b|}\sum_{(x, y) \in b}\nabla_{\theta}\ell (x,y; \theta) + \nabla_{\theta}\ell (x,y; \theta)^t, \\\ + \ell(x,y; \theta) &= \text{loss}(\text{model}(x; \theta), y) + \end{align*} + + where model is a [torch.nn.Module][torch.nn.Module] and $v$ is a vector or matrix. + + Args: + model: The model. + loss: The loss function. + restrict_to: The parameters to restrict the differentiation to, + i.e. the corresponding sub-matrix of the Jacobian. If None, the full + Jacobian is used. Make sure the input matches the corrct dimension, i.e. the + last dimension must be equal to the property `input_size`. + """ + + def __init__( + self, + model: torch.nn.Module, + loss: LossType, + restrict_to: Optional[Dict[str, torch.nn.Parameter]] = None, + ): + super().__init__(model, restrict_to=restrict_to) + self.gradient_provider = TorchGradientProvider( + model, loss, self.params_to_restrict_to + ) + + def _apply_to_dict( + self, batch: TorchBatch, vec_dict: Dict[str, torch.Tensor] + ) -> Dict[str, torch.Tensor]: + vec_values = list(self._add_batch_dim(vec_dict).values()) + grads_dict = self.gradient_provider.grads(batch) + grads_values = list(self._add_batch_dim(grads_dict).values()) + gen_result = self._generate_rank_one_mvp(grads_values, vec_values) + return dict(zip(vec_dict.keys(), gen_result)) + + def _apply_to_vec(self, batch: TorchBatch, vec: torch.Tensor) -> torch.Tensor: + flat_grads = self.gradient_provider.flat_grads(batch) + return self._rank_one_mvp(flat_grads, vec) + + def _apply_to_mat(self, batch: TorchBatch, mat: torch.Tensor) -> torch.Tensor: + """ + Applies the batch operation to a matrix. + Args: + batch: Batch of data for computation + mat: A matrix to apply the batch operation to. The last dimension is + assumed to be consistent to the operation, i.e. it must equal + to the property `input_size`. 
+ + Returns: + A matrix of shape $(N, \text{input_size})$, given the shape of mat is + $(N, \text{input_size})$ + + """ + return self._apply_to_vec(batch, mat) + + def to(self, device: torch.device): + self.gradient_provider = self.gradient_provider.to(device) + return super().to(device) + + @staticmethod + def _rank_one_mvp(x: torch.Tensor, v: torch.Tensor) -> torch.Tensor: + r""" + Computes the matrix-vector product of xx^T and v for each row in X and V without + forming xx^T and sums the result. Here, X and V are matrices where each row + represents an individual vector. Effectively it is computing + + $$ V@( \frac{1}{N}\sum_i^N x[i]x[i]^T) $$ + + Args: + x: Matrix of vectors of size `(N, M)`. + v: Matrix of vectors of size `(B, M)` to be multiplied by the corresponding + $xx^T$. + + Returns: + A matrix of size `(B, N)` where each column is the result of xx^T v for + corresponding rows in x and v. + """ + if v.ndim == 1: + result = torch.einsum("ij,kj->ki", x, v.unsqueeze(0)) @ x + return result.squeeze() / x.shape[0] + return (torch.einsum("ij,kj->ki", x, v) @ x) / x.shape[0] + + @staticmethod + def _generate_rank_one_mvp( + x: List[torch.Tensor], v: List[torch.Tensor] + ) -> Generator[torch.Tensor, None, None]: + x_v_iterator = zip(x, v) + x_, v_ = next(x_v_iterator) + + nominator = torch.einsum("i..., k...->ki", x_, v_) + + for x_, v_ in x_v_iterator: + nominator += torch.einsum("i..., k...->ki", x_, v_) + + for x_, v_ in zip(x, v): + yield torch.einsum("ji, i... -> j...", nominator, x_) / x_.shape[0] + + +class InverseHarmonicMeanBatchOperation(_ModelBasedBatchOperation): + r""" + Given a model and loss function computes an approximation of the inverse + Gauss-Newton vector or matrix product. Viewing the damped Gauss-newton matrix + + \begin{align*} + G_{\lambda}(\text{model}, \text{loss}, b, \theta) &= + \frac{1}{|b|}\sum_{(x, y) \in b}\nabla_{\theta}\ell (x,y; \theta) + \nabla_{\theta}\ell (x,y; \theta)^t + \lambda \operatorname{I}, \\\ + \ell(x,y; \theta) &= \text{loss}(\text{model}(x; \theta), y) + \end{align*} + + as an arithmetic mean of the rank-$1$ updates, this operation replaces it with + the harmonic mean of the rank-$1$ updates, i.e. + + $$ \tilde{G}_{\lambda}(\text{model}, \text{loss}, b, \theta) = + \left(n \sum_{(x, y) \in b} \left( \nabla_{\theta}\ell (x,y; \theta) + \nabla_{\theta}\ell (x,y; \theta)^t + \lambda \operatorname{I}\right)^{-1} + \right)^{-1}$$ + + and computes + + $$ \tilde{G}_{\lambda}^{-1}(\text{model}, \text{loss}, b, \theta) + \cdot v.$$ + + where model is a [torch.nn.Module][torch.nn.Module] and $v$ is a vector or matrix. + In other words, it switches the order of summation and inversion, which resolves + to the `inverse harmonic mean` of the rank-$1$ updates. + + The inverses of the rank-$1$ updates are not calculated explicitly, + but instead a vectorized version of the + [Sherman–Morrison formula]( + https://en.wikipedia.org/wiki/Sherman%E2%80%93Morrison_formula) + is applied. + + For more information, + see [Inverse Harmonic Mean][inverse-harmonic-mean]. + + Args: + model: The model. + loss: The loss function. + restrict_to: The parameters to restrict the differentiation to, + i.e. the corresponding sub-matrix of the Jacobian. If None, the full + Jacobian is used. Make sure the input matches the corrct dimension, i.e. the + last dimension must be equal to the property `input_size`. 
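The einsum-based `_rank_one_mvp` above avoids ever materializing the outer products $x_i x_i^T$. A small standalone sanity check of that identity, with made-up sizes:

```python
import torch

def rank_one_mvp(x: torch.Tensor, v: torch.Tensor) -> torch.Tensor:
    # V @ ((1/N) sum_i x_i x_i^T) without forming the N outer products,
    # mirroring the einsum formulation above.
    return (torch.einsum("ij,kj->ki", x, v) @ x) / x.shape[0]

N, D, B = 6, 4, 3
x = torch.randn(N, D, dtype=torch.float64)
v = torch.randn(B, D, dtype=torch.float64)

G = x.t() @ x / N  # explicit (1/N) sum_i x_i x_i^T
assert torch.allclose(rank_one_mvp(x, v), v @ G)
```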
+ """ + + def __init__( + self, + model: torch.nn.Module, + loss: Callable[[torch.Tensor, torch.Tensor], torch.Tensor], + regularization: float, + restrict_to: Optional[Dict[str, torch.nn.Parameter]] = None, + ): + if regularization <= 0: + raise ValueError("regularization must be positive") + self.regularization = regularization + + super().__init__(model, restrict_to=restrict_to) + self.gradient_provider = TorchGradientProvider( + model, loss, self.params_to_restrict_to + ) + + @property + def regularization(self): + return self._regularization + + @regularization.setter + def regularization(self, value: float): + if value <= 0: + raise ValueError("regularization must be positive") + self._regularization = value + + def _apply_to_vec(self, batch: TorchBatch, vec: torch.Tensor) -> torch.Tensor: + grads = self.gradient_provider.flat_grads(batch) + if vec.ndim == 1: + input_vec = vec.unsqueeze(0) + else: + input_vec = vec + return self._inverse_rank_one_update(grads, input_vec, self.regularization) + + def _apply_to_mat(self, batch: TorchBatch, mat: torch.Tensor) -> torch.Tensor: + """ + Applies the batch operation to a matrix. + Args: + batch: Batch of data for computation + mat: A matrix to apply the batch operation to. The last dimension is + assumed to be consistent to the operation, i.e. it must equal + to the property `input_size`. + + Returns: + A matrix of shape $(N, \text{input_size})$, given the shape of mat is + $(N, \text{input_size})$ + + """ + return self._apply_to_vec(batch, mat) + + def to(self, device: torch.device): + super().to(device) + self.gradient_provider.params_to_restrict_to = self.params_to_restrict_to + return self + + def _apply_to_dict( + self, batch: TorchBatch, vec_dict: Dict[str, torch.Tensor] + ) -> Dict[str, torch.Tensor]: + vec_values = list(self._add_batch_dim(vec_dict).values()) + grads_dict = self.gradient_provider.grads(batch) + grads_values = list(self._add_batch_dim(grads_dict).values()) + gen_result = self._generate_inverse_rank_one_updates( + grads_values, vec_values, self.regularization + ) + return dict(zip(vec_dict.keys(), gen_result)) + + @staticmethod + def _inverse_rank_one_update( + x: torch.Tensor, v: torch.Tensor, regularization: float + ) -> torch.Tensor: + r""" + Performs an inverse-rank one update on x and v. More precisely, it computes + + $$ \sum_{i=1}^n \left(x[i]x[i]^t+\lambda \operatorname{I}\right)^{-1}v $$ + + where $\operatorname{I}$ is the identity matrix and $\lambda$ is positive + regularization parameter. The inverse matrices are not calculated explicitly, + but instead a vectorized version of the + [Sherman–Morrison formula]( + https://en.wikipedia.org/wiki/Sherman%E2%80%93Morrison_formula) + is applied. + + Args: + x: Input matrix used for the rank one expressions. First dimension is + assumed to be the batch dimension. + v: Matrix to multiply with. First dimension is + assumed to be the batch dimension. + regularization: Regularization parameter to make the rank-one expressions + invertible, must be positive. + + Returns: + Matrix of size $(D, M)$ for x having shape $(N, D)$ and v having shape + $(M, D)$. 
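The vectorized Sherman–Morrison expression described here can be checked against explicit matrix inverses on a tiny example. Note that the implementation normalizes by the number of rank-one terms, consistent with the analytical tests further down in this diff; sizes and the regularization value below are arbitrary.

```python
import torch

def inverse_rank_one_update(x, v, reg):
    # Vectorized Sherman-Morrison update: applies
    # (1/N) * sum_i (x_i x_i^T + reg * I)^{-1} to every row of v.
    numerator = torch.einsum("ij,kj->ki", x, v)
    denominator = x.shape[0] * (reg + torch.sum(x**2, dim=1))
    return (v - (numerator / denominator) @ x) / reg

N, D, B, reg = 5, 4, 3, 0.5
x = torch.randn(N, D, dtype=torch.float64)
v = torch.randn(B, D, dtype=torch.float64)

explicit = sum(
    torch.linalg.inv(torch.outer(xi, xi) + reg * torch.eye(D, dtype=torch.float64))
    for xi in x
) / N
assert torch.allclose(inverse_rank_one_update(x, v, reg), v @ explicit)
```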
+ """ + nominator = torch.einsum("ij,kj->ki", x, v) + denominator = x.shape[0] * (regularization + torch.sum(x**2, dim=1)) + return (v - (nominator / denominator) @ x) / regularization + + @staticmethod + def _generate_inverse_rank_one_updates( + x: List[torch.Tensor], v: List[torch.Tensor], regularization: float + ) -> Generator[torch.Tensor, None, None]: + + x_v_iterator = enumerate(zip(x, v)) + index, (x_, v_) = next(x_v_iterator) + + denominator = regularization + torch.sum(x_.view(x_.shape[0], -1) ** 2, dim=1) + nominator = torch.einsum("i..., k...->ki", x_, v_) + num_data_points = x_.shape[0] + + for k, (x_, v_) in x_v_iterator: + nominator += torch.einsum("i..., k...->ki", x_, v_) + denominator += torch.sum(x_.view(x_.shape[0], -1) ** 2, dim=1) + + denominator = num_data_points * denominator + + for x_, v_ in zip(x, v): + yield ( + v_ - torch.einsum("ji, i... -> j...", nominator / denominator, x_) + ) / regularization + + +BatchOperationType = TypeVar("BatchOperationType", bound=_ModelBasedBatchOperation) + + +class _TensorDictAveraging(ABC): + @abstractmethod + def __call__(self, tensor_dicts: Generator[Dict[str, torch.Tensor], None, None]): + pass + + +_TensorDictAveragingType = TypeVar( + "_TensorDictAveragingType", bound=_TensorDictAveraging +) + + +class _TensorAveraging(Generic[_TensorDictAveragingType], ABC): + @abstractmethod + def __call__(self, tensors: Generator[torch.Tensor, None, None]): + pass + + @abstractmethod + def as_dict_averaging(self) -> _TensorDictAveraging: + pass + + +TensorAveragingType = TypeVar("TensorAveragingType", bound=_TensorAveraging) + + +class _TensorDictChunkAveraging(_TensorDictAveraging): + def __call__(self, tensor_dicts: Generator[Dict[str, torch.Tensor], None, None]): + result = next(tensor_dicts) + n_chunks = 1.0 + for tensor_dict in tensor_dicts: + for key, tensor in tensor_dict.items(): + result[key] += tensor + n_chunks += 1.0 + return {k: t / n_chunks for k, t in result.items()} + + +class ChunkAveraging(_TensorAveraging[_TensorDictChunkAveraging]): + """ + Averages tensors, provided by a generator, and normalizes by the number + of tensors. + """ + + def __call__(self, tensors: Generator[torch.Tensor, None, None]): + result = next(tensors) + n_chunks = 1.0 + for tensor in tensors: + result += tensor + n_chunks += 1.0 + return result / n_chunks + + def as_dict_averaging(self) -> _TensorDictChunkAveraging: + return _TensorDictChunkAveraging() + + +class _TensorDictPointAveraging(_TensorDictAveraging): + def __init__(self, batch_dim: int = 0): + self.batch_dim = batch_dim + + def __call__(self, tensor_dicts: Generator[Dict[str, torch.Tensor], None, None]): + result = next(tensor_dicts) + n_points = next(iter(result.values())).shape[self.batch_dim] + for tensor_dict in tensor_dicts: + n_points_in_batch = next(iter(tensor_dict.values())).shape[self.batch_dim] + for key, tensor in tensor_dict.items(): + result[key] += n_points_in_batch * tensor + n_points += n_points_in_batch + return {k: t / float(n_points) for k, t in result.items()} + + +class PointAveraging(_TensorAveraging[_TensorDictPointAveraging]): + """ + Averages tensors provided by a generator. The averaging is weighted by + the number of points in each tensor and the final result is normalized by the + number of total points. + + Args: + batch_dim: Dimension to extract the number of points for the weighting. 
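The difference between the two averaging strategies above is easiest to see on made-up per-batch results: `ChunkAveraging` weights every batch equally, while `PointAveraging` weights each batch by the number of points it contains.

```python
import torch

batch_results = [torch.tensor([1.0, 1.0]), torch.tensor([4.0, 4.0])]
batch_sizes = [3, 1]  # hypothetical sizes of the batches the results came from

chunk_avg = sum(batch_results) / len(batch_results)
point_avg = sum(n * r for n, r in zip(batch_sizes, batch_results)) / sum(batch_sizes)

print(chunk_avg)  # tensor([2.5000, 2.5000])
print(point_avg)  # tensor([1.7500, 1.7500])
```

The point-weighted form is what makes an aggregated per-sample quantity independent of how the data loader happens to split the batches.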
+ + """ + + def __init__(self, batch_dim: int = 0): + self.batch_dim = batch_dim + + def __call__(self, tensors: Generator[torch.Tensor, None, None]): + result = next(tensors) + n_points = result.shape[self.batch_dim] + for tensor in tensors: + n_points_in_batch = tensor.shape[self.batch_dim] + result += n_points_in_batch * tensor + n_points += n_points_in_batch + return result / float(n_points) + + def as_dict_averaging(self) -> _TensorDictPointAveraging: + return _TensorDictPointAveraging(self.batch_dim) diff --git a/src/pydvl/influence/torch/functional.py b/src/pydvl/influence/torch/functional.py index 1028b6acd..ba5acdd3e 100644 --- a/src/pydvl/influence/torch/functional.py +++ b/src/pydvl/influence/torch/functional.py @@ -50,6 +50,8 @@ "LowRankProductRepresentation", "randomized_nystroem_approximation", "model_hessian_nystroem_approximation", + "create_batch_loss_function", + "hvp", ] @@ -632,6 +634,10 @@ def device(self) -> torch.device: else torch.device("cpu") ) + @property + def dtype(self) -> torch.dtype: + return self.projections.dtype + def to(self, device: torch.device): """ Move the representing tensors to a device diff --git a/src/pydvl/influence/torch/influence_function_model.py b/src/pydvl/influence/torch/influence_function_model.py index b3d608a23..35ab09501 100644 --- a/src/pydvl/influence/torch/influence_function_model.py +++ b/src/pydvl/influence/torch/influence_function_model.py @@ -8,21 +8,26 @@ import logging from abc import ABC, abstractmethod -from typing import Callable, Dict, List, Optional, Tuple +from collections import OrderedDict +from typing import Callable, Dict, List, Optional, Tuple, Union import torch from torch import nn as nn from torch.utils.data import DataLoader from tqdm.auto import tqdm -from pydvl.utils.progress import log_duration - +from ...utils.progress import log_duration +from .. import InfluenceMode from ..base_influence_function_model import ( InfluenceFunctionModel, - InfluenceMode, NotImplementedLayerRepresentationException, UnsupportedInfluenceModeException, ) +from .base import ( + TorchComposableInfluence, + TorchGradientProvider, + TorchOperatorGradientComposition, +) from .functional import ( LowRankProductRepresentation, create_batch_hvp_function, @@ -34,9 +39,12 @@ model_hessian_low_rank, model_hessian_nystroem_approximation, ) +from .operator import InverseHarmonicMeanOperator from .pre_conditioner import PreConditioner from .util import ( + BlockMode, EkfacRepresentation, + LossType, empirical_cross_entropy_loss_fn, flatten_dimensions, safe_torch_linalg_eigh, @@ -49,6 +57,7 @@ "ArnoldiInfluence", "EkfacInfluence", "NystroemSketchInfluence", + "InverseHarmonicMeanInfluence", ] logger = logging.getLogger(__name__) @@ -986,6 +995,7 @@ class ArnoldiInfluence(TorchInfluenceFunctionModel): Set this to False, if you can't afford to keep the full computation graph in memory. """ + low_rank_representation: LowRankProductRepresentation def __init__( @@ -1791,3 +1801,161 @@ def fit(self, data: DataLoader): self.model, self.loss, data, self.rank ) return self + + +class InverseHarmonicMeanInfluence( + TorchComposableInfluence[InverseHarmonicMeanOperator] +): + r""" + This implementation replaces the inverse Hessian matrix in the influence computation + with an approximation of the inverse Gauss-Newton vector product. 
+ + Viewing the damped Gauss-newton matrix + + \begin{align*} + G_{\lambda}(\theta) &= + \frac{1}{N}\sum_{i}^N\nabla_{\theta}\ell (x_i,y_i; \theta) + \nabla_{\theta}\ell (x_i, y_i; \theta)^t + \lambda \operatorname{I}, \\\ + \ell(x,y; \theta) &= \text{loss}(\text{model}(x; \theta), y) + \end{align*} + + as an arithmetic mean of the rank-$1$ updates, this implementation replaces it with + the harmonic mean of the rank-$1$ updates, i.e. + + $$ \tilde{G}_{\lambda}(\theta) = + \left(N \cdot \sum_{i=1}^N \left( \nabla_{\theta}\ell (x_i,y_i; \theta) + \nabla_{\theta}\ell (x_i,y_i; \theta)^t + + \lambda \operatorname{I}\right)^{-1} + \right)^{-1}$$ + + and uses the matrix + + $$ \tilde{G}_{\lambda}^{-1}(\theta)$$ + + instead of the inverse Hessian. + + In other words, it switches the order of summation and inversion, which resolves + to the `inverse harmonic mean` of the rank-$1$ updates. The results are averaged + over the batches provided by the data loader. + + The inverses of the rank-$1$ updates are not calculated explicitly, + but instead a vectorized version of the + [Sherman–Morrison formula]( + https://en.wikipedia.org/wiki/Sherman%E2%80%93Morrison_formula) + is applied. + + For more information, + see [Inverse Harmonic Mean][inverse-harmonic-mean]. + + Block-mode: + This implementation is capable of using a block-matrix approximation. The + blocking structure can be specified via the `block_structure` parameter. + The `block_structure` parameter can either be a + [BlockMode][pydvl.influence.torch.util.BlockMode] enum (which provides + layer-wise or parameter-wise blocking) or a custom block structure defined + by an ordered dictionary with the keys being the block identifiers (arbitrary + strings) and the values being lists of parameter names contained in the block. + + ```python + block_structure = OrderedDict( + ( + ("custom_block1", ["0.weight", "1.bias"]), + ("custom_block2", ["1.weight", "0.bias"]), + ) + ) + ``` + + If you would like to apply a block-specific regularization, you can provide a + dictionary with the block names as keys and the regularization values as values. + In this case, the specification must be complete, i.e. every block must have + a positive regularization value. + + ```python + regularization = { + "custom_block1": 0.1, + "custom_block2": 0.2, + } + ``` + Accordingly, if you choose a layer-wise or parameter-wise structure + (by providing `BlockMode.LAYER_WISE` or `BlockMode.PARAMETER_WISE` for + `block_structure`) the keys must be the layer names or parameter names, + respectively. + + You can retrieve the block-wise influence information from the methods + with suffix `_by_block`. By default, `block_structure` is set to + `BlockMode.FULL` and in this case these methods will return a dictionary + with the empty string being the only key. + + + Args: + model: The model. + loss: The loss function. + regularization: The regularization parameter. In case a dictionary is provided, + the keys must match the blocking structure. + block_structure: The blocking structure, either a pre-defined enum or a + custom block structure, see the information regarding block-mode. 
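A hypothetical end-to-end usage sketch, based on the constructor and `with_regularization` shown in this diff. The model, data, import paths and the layer block names (which here are assumed to follow `named_children` of an `nn.Sequential`) are illustrative assumptions, and the `fit` call relies on the usual pyDVL influence-model interface.

```python
import torch
from torch.utils.data import DataLoader, TensorDataset

from pydvl.influence.torch.influence_function_model import InverseHarmonicMeanInfluence
from pydvl.influence.torch.util import BlockMode

# Hypothetical model and regression data.
model = torch.nn.Sequential(
    torch.nn.Linear(5, 8), torch.nn.ReLU(), torch.nn.Linear(8, 1)
)
loss = torch.nn.functional.mse_loss
x, y = torch.randn(64, 5), torch.randn(64, 1)
train_loader = DataLoader(TensorDataset(x, y), batch_size=16)

# Layer-wise block-diagonal approximation with a single regularization value.
if_model = InverseHarmonicMeanInfluence(
    model, loss, regularization=0.1, block_structure=BlockMode.LAYER_WISE
)
if_model.fit(train_loader)

# Switch to block-specific regularization afterwards; for this sequential model
# the (assumed) block keys are the names of the layers holding parameters.
if_model = if_model.with_regularization({"0": 0.1, "2": 0.5})
```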
+ """ + + def __init__( + self, + model: torch.nn.Module, + loss: LossType, + regularization: Union[float, Dict[str, Optional[float]]], + block_structure: Union[BlockMode, OrderedDict[str, List[str]]] = BlockMode.FULL, + ): + super().__init__(model, block_structure, regularization=regularization) + self.loss = loss + + @property + def n_parameters(self): + return sum(block.op.input_size for _, block in self.block_mapper.items()) + + @property + def is_thread_safe(self) -> bool: + return False + + @staticmethod + def _validate_regularization( + block_name: str, value: Optional[float] + ) -> Optional[float]: + if value is None or value <= 0.0: + raise ValueError( + f"The regularization for block '{block_name}' must be a positive float," + f"but found {value=}" + ) + return value + + def _create_block( + self, + block_params: Dict[str, torch.nn.Parameter], + data: DataLoader, + regularization: Optional[float], + ) -> TorchOperatorGradientComposition: + assert regularization is not None + op = InverseHarmonicMeanOperator( + self.model, + self.loss, + data, + regularization, + restrict_to=block_params, + ) + gp = TorchGradientProvider(self.model, self.loss, restrict_to=block_params) + return TorchOperatorGradientComposition(op, gp) + + def with_regularization( + self, regularization: Union[float, Dict[str, Optional[float]]] + ) -> TorchComposableInfluence: + """ + Update the regularization parameter. + Args: + regularization: Either a positive float or a dictionary with the + block names as keys and the regularization values as values. + + Returns: + The modified instance + + """ + self._regularization_dict = self._build_regularization_dict(regularization) + for k, reg in self._regularization_dict.items(): + self.block_mapper.composable_block_dict[k].op.regularization = reg + return self diff --git a/src/pydvl/influence/torch/operator.py b/src/pydvl/influence/torch/operator.py new file mode 100644 index 000000000..2396a2efb --- /dev/null +++ b/src/pydvl/influence/torch/operator.py @@ -0,0 +1,256 @@ +from typing import Callable, Dict, Generic, Optional, Tuple + +import torch +from torch import nn as nn +from torch.utils.data import DataLoader + +from .base import TensorDictOperator, TorchBatch +from .batch_operation import ( + BatchOperationType, + ChunkAveraging, + GaussNewtonBatchOperation, + HessianBatchOperation, + InverseHarmonicMeanBatchOperation, + PointAveraging, + TensorAveragingType, +) + + +class _AveragingBatchOperator( + TensorDictOperator, Generic[BatchOperationType, TensorAveragingType] +): + """ + Class for aggregating batch operations over a dataset using a provided data loader + and aggregator. + + This class facilitates the application of a batch operation across multiple batches + of data, aggregating the results using a specified sequence aggregator. + + Attributes: + batch_operation: The batch operation to apply. + dataloader: The data loader providing batches of data. + averaging: The sequence aggregator to aggregate the results of the batch + operations. 
+ """ + + def __init__( + self, + batch_operation: BatchOperationType, + dataloader: DataLoader, + averager: TensorAveragingType, + ): + self.batch_operation = batch_operation + self.dataloader = dataloader + self.averaging = averager + + @property + def input_dict_structure(self) -> Dict[str, Tuple[int, ...]]: + return self.batch_operation.input_dict_structure + + def _apply_to_dict(self, mat: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]: + + tensor_dicts = ( + self.batch_operation.apply_to_dict(TorchBatch(x, y), mat) + for x, y in self.dataloader + ) + dict_averaging = self.averaging.as_dict_averaging() + result: Dict[str, torch.Tensor] = dict_averaging(tensor_dicts) + return result + + @property + def device(self): + return self.batch_operation.device + + @property + def dtype(self): + return self.batch_operation.dtype + + def to(self, device: torch.device): + self.batch_operation = self.batch_operation.to(device) + return self + + @property + def input_size(self): + return self.batch_operation.input_size + + def _apply_to_mat(self, mat: torch.Tensor) -> torch.Tensor: + return self._apply_to_vec(mat) + + def _apply_to_vec(self, vec: torch.Tensor) -> torch.Tensor: + tensors = ( + self.batch_operation.apply( + TorchBatch(x.to(self.device), y.to(self.device)), vec.to(self.device) + ) + for x, y in self.dataloader + ) + + return self.averaging(tensors) + + +class GaussNewtonOperator( + _AveragingBatchOperator[GaussNewtonBatchOperation, PointAveraging] +): + r""" + Given a model and loss function computes the Gauss-Newton vector or matrix product + with respect to the model parameters on a batch, i.e. + + \begin{align*} + G(\text{model}, \text{loss}, b, \theta) &\cdot v, \\\ + G(\text{model}, \text{loss}, b, \theta) &= + \frac{1}{|b|}\sum_{(x, y) \in b}\nabla_{\theta}\ell (x,y; \theta) + \nabla_{\theta}\ell (x,y; \theta)^t, \\\ + \ell(x,y; \theta) &= \text{loss}(\text{model}(x; \theta), y) + \end{align*} + + where model is a [torch.nn.Module][torch.nn.Module] and $v$ is a vector or matrix, + and average the results over the batches provided by the data loader. + + Args: + model: The model. + loss: The loss function. + dataloader: The data loader providing batches of data. + restrict_to: The parameters to restrict the differentiation to, + i.e. the corresponding sub-matrix of the Jacobian. If None, the full + Jacobian is used. Make sure the input matches the corrct dimension, i.e. the + last dimension must be equal to the property `input_size`. + """ + + def __init__( + self, + model: nn.Module, + loss: Callable[[torch.Tensor, torch.Tensor], torch.Tensor], + dataloader: DataLoader, + restrict_to: Optional[Dict[str, nn.Parameter]] = None, + ): + batch_op = GaussNewtonBatchOperation( + model, + loss, + restrict_to=restrict_to, + ) + averaging = PointAveraging() + super().__init__(batch_op, dataloader, averaging) + + +class HessianOperator(_AveragingBatchOperator[HessianBatchOperation, ChunkAveraging]): + r""" + Given a model and loss function computes the Hessian vector or matrix product + with respect to the model parameters for a given batch, i.e. + + \begin{align*} + &\nabla^2_{\theta} L(b;\theta) \cdot v \\\ + &L(b;\theta) = \left( \frac{1}{|b|} \sum_{(x,y) \in b} + \text{loss}(\text{model}(x; \theta), y)\right), + \end{align*} + + where model is a [torch.nn.Module][torch.nn.Module] and $v$ is a vector or matrix, + and average the results over the batches provided by the data loader. + + Args: + model: The model. + loss: The loss function. 
+ dataloader: The data loader providing batches of data. + restrict_to: The parameters to restrict the second order differentiation to, + i.e. the corresponding sub-matrix of the Hessian. If None, the full Hessian + is used. Make sure the input matches the corrct dimension, i.e. the + last dimension must be equal to the property `input_size`. + """ + + def __init__( + self, + model: nn.Module, + loss: Callable[[torch.Tensor, torch.Tensor], torch.Tensor], + dataloader: DataLoader, + restrict_to: Optional[Dict[str, nn.Parameter]] = None, + ): + batch_op = HessianBatchOperation(model, loss, restrict_to=restrict_to) + averaging = ChunkAveraging() + super().__init__(batch_op, dataloader, averaging) + + +class InverseHarmonicMeanOperator( + _AveragingBatchOperator[InverseHarmonicMeanBatchOperation, PointAveraging] +): + r""" + Given a model and loss function computes an approximation of the inverse + Gauss-Newton vector or matrix product per batch and averages the results. + + Viewing the damped Gauss-newton matrix + + \begin{align*} + G_{\lambda}(\text{model}, \text{loss}, b, \theta) &= + \frac{1}{|b|}\sum_{(x, y) \in b}\nabla_{\theta}\ell (x,y; \theta) + \nabla_{\theta}\ell (x,y; \theta)^t + \lambda \operatorname{I}, \\\ + \ell(x,y; \theta) &= \text{loss}(\text{model}(x; \theta), y) + \end{align*} + + as an arithmetic mean of the rank-$1$ updates, this operator replaces it with + the harmonic mean of the rank-$1$ updates, i.e. + + $$ \tilde{G}_{\lambda}(\text{model}, \text{loss}, b, \theta) = + \left(n \sum_{(x, y) \in b} \left( \nabla_{\theta}\ell (x,y; \theta) + \nabla_{\theta}\ell (x,y; \theta)^t + \lambda \operatorname{I}\right)^{-1} + \right)^{-1}$$ + + and computes + + $$ \tilde{G}_{\lambda}^{-1}(\text{model}, \text{loss}, b, \theta) + \cdot v.$$ + + for any given batch $b$, + where model is a [torch.nn.Module][torch.nn.Module] and $v$ is a vector or matrix. + + In other words, it switches the order of summation and inversion, which resolves + to the `inverse harmonic mean` of the rank-$1$ updates. The results are averaged + over the batches provided by the data loader. + + The inverses of the rank-$1$ updates are not calculated explicitly, + but instead a vectorized version of the + [Sherman–Morrison formula]( + https://en.wikipedia.org/wiki/Sherman%E2%80%93Morrison_formula) + is applied. + + For more information, + see [Inverse Harmonic Mean][inverse-harmonic-mean]. + + Args: + model: The model. + loss: The loss function. + dataloader: The data loader providing batches of data. + restrict_to: The parameters to restrict the differentiation to, + i.e. the corresponding sub-matrix of the Jacobian. If None, the full + Jacobian is used. Make sure the input matches the corrct dimension, i.e. the + last dimension must be equal to the property `input_size`. 
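A plain-torch sketch of the averaging these operators perform over a data loader: weighting each batch result by the number of points it contains reproduces the Gauss-Newton product of the full data, regardless of how the loader splits the batches. The per-sample gradients below are synthetic placeholders, not outputs of a model.

```python
import torch
from torch.utils.data import DataLoader, TensorDataset

N, D = 10, 4
grads = torch.randn(N, D, dtype=torch.float64)  # stand-in per-sample gradients
v = torch.randn(D, dtype=torch.float64)

def gn_product(g, vec):
    # ((1/n) sum_i g_i g_i^T) @ vec for the rows of g
    return g.t() @ (g @ vec) / g.shape[0]

loader = DataLoader(TensorDataset(grads), batch_size=3)  # uneven last batch
weighted = sum(b.shape[0] * gn_product(b, v) for (b,) in loader) / N
assert torch.allclose(weighted, gn_product(grads, v))
```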
+ """ + + def __init__( + self, + model: nn.Module, + loss: Callable[[torch.Tensor, torch.Tensor], torch.Tensor], + dataloader: DataLoader, + regularization: float, + restrict_to: Optional[Dict[str, nn.Parameter]] = None, + ): + if regularization <= 0: + raise ValueError("regularization must be positive") + + self._regularization = regularization + + batch_op = InverseHarmonicMeanBatchOperation( + model, + loss, + regularization, + restrict_to=restrict_to, + ) + averaging = PointAveraging() + super().__init__(batch_op, dataloader, averaging) + + @property + def regularization(self): + return self._regularization + + @regularization.setter + def regularization(self, value: float): + if value <= 0: + raise ValueError("regularization must be positive") + self._regularization = value + self.batch_operation.regularization = value diff --git a/src/pydvl/influence/torch/util.py b/src/pydvl/influence/torch/util.py index d157d5455..9700f047f 100644 --- a/src/pydvl/influence/torch/util.py +++ b/src/pydvl/influence/torch/util.py @@ -1,8 +1,14 @@ +from __future__ import annotations + import logging import math +import warnings +from collections import OrderedDict from dataclasses import dataclass +from enum import Enum from functools import partial from typing import ( + Callable, Collection, Dict, Iterable, @@ -42,11 +48,14 @@ "align_with_model", "flatten_dimensions", "TorchNumpyConverter", - "TorchCatAggregator", - "NestedTorchCatAggregator", "torch_dataset_to_dask_array", "EkfacRepresentation", "empirical_cross_entropy_loss_fn", + "LossType", + "ModelParameterDictBuilder", + "BlockMode", + "ModelInfoMixin", + "safe_torch_linalg_eigh", ] @@ -598,3 +607,162 @@ def __init__(self, original_exception: RuntimeError): f" Inspect the original exception message: \n{str(original_exception)}" ) super().__init__(err_msg) + + +LossType = Callable[[torch.Tensor, torch.Tensor], torch.Tensor] + + +class BlockMode(Enum): + """ + Enumeration for different modes of grouping model parameters. + + Attributes: + LAYER_WISE: Groups parameters by layers of the model. + PARAMETER_WISE: Groups parameters individually. + FULL: Groups all parameters together. + """ + + LAYER_WISE: str = "layer_wise" + PARAMETER_WISE: str = "parameter_wise" + FULL: str = "full" + + +@dataclass +class ModelParameterDictBuilder: + """ + A builder class for creating ordered dictionaries of model parameters based on + specified block modes or custom blocking structures. + + Attributes: + model: The neural network model. + detach: Whether to detach the parameters from the computation graph. + """ + + model: torch.nn.Module + detach: bool = True + + def _optional_detach(self, p: torch.nn.Parameter): + if self.detach: + return p.detach() + return p + + def _extract_parameter_by_name(self, name: str) -> torch.nn.Parameter: + for k, p in self.model.named_parameters(): + if k == name: + return p + else: + raise ValueError(f"Parameter {name} not found in the model.") + + def build( + self, block_structure: OrderedDict[str, List[str]] + ) -> Dict[str, Dict[str, torch.nn.Parameter]]: + """ + Builds an ordered dictionary of model parameters based on the specified block + structure represented by an ordered dictionary, where the keys are block + identifiers and the values are lists of model parameter names contained in + this block. + + Args: + block_structure: The block structure specifying how to group the parameters. 
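To illustrate the grouping modes, this is what layer-wise and parameter-wise blocks look like for a small sequential model, reconstructed here with plain `named_children` / `named_parameters` calls rather than the builder itself; the printed names apply only to this toy model.

```python
import torch

model = torch.nn.Sequential(
    torch.nn.Linear(4, 3), torch.nn.ReLU(), torch.nn.Linear(3, 1)
)

# Layer-wise grouping: one block per child module that has trainable parameters.
layer_wise = {
    name: [f"{name}.{p_name}" for p_name, p in module.named_parameters() if p.requires_grad]
    for name, module in model.named_children()
}
layer_wise = {k: v for k, v in layer_wise.items() if v}
print(layer_wise)
# {'0': ['0.weight', '0.bias'], '2': ['2.weight', '2.bias']}

# Parameter-wise grouping: one block per trainable parameter.
parameter_wise = {n: [n] for n, p in model.named_parameters() if p.requires_grad}
print(list(parameter_wise))
# ['0.weight', '0.bias', '2.weight', '2.bias']
```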
+ + Returns: + An ordered dictionary of ordered dictionaries, where the outer dictionary's + keys are block identifiers and the inner dictionaries map parameter names + to parameters. + """ + parameter_dict = {} + + for block_name, parameter_names in block_structure.items(): + inner_ordered_dict = {} + for parameter_name in parameter_names: + parameter = self._extract_parameter_by_name(parameter_name) + if parameter.requires_grad: + inner_ordered_dict[parameter_name] = self._optional_detach( + parameter + ) + else: + warnings.warn( + f"The parameter {parameter_name} from the block " + f"{block_name} is mark as not trainable in the model " + f"and will be excluded from the computation." + ) + parameter_dict[block_name] = inner_ordered_dict + + return parameter_dict + + def build_from_block_mode( + self, block_mode: BlockMode + ) -> Dict[str, Dict[str, torch.nn.Parameter]]: + """ + Builds an ordered dictionary of model parameters based on the specified block + mode or custom blocking structure represented by an ordered dictionary, where + the keys are block identifiers and the values are lists of model parameter names + contained in this block. + + Args: + block_mode: The block mode specifying how to group the parameters. + + Returns: + An ordered dictionary of ordered dictionaries, where the outer dictionary's + keys are block identifiers and the inner dictionaries map parameter names + to parameters. + """ + + block_mode_mapping = { + BlockMode.FULL: self._build_full, + BlockMode.PARAMETER_WISE: self._build_parameter_wise, + BlockMode.LAYER_WISE: self._build_layer_wise, + } + + parameter_dict_func = block_mode_mapping.get(block_mode, None) + + if parameter_dict_func is None: + raise ValueError(f"Unknown block mode {block_mode}.") + + return self.build(parameter_dict_func()) + + def _build_full(self): + parameter_dict = OrderedDict() + parameter_dict[""] = [ + n for n, p in self.model.named_parameters() if p.requires_grad + ] + return parameter_dict + + def _build_parameter_wise(self): + parameter_dict = OrderedDict() + for k, v in self.model.named_parameters(): + if v.requires_grad: + parameter_dict[k] = [k] + return parameter_dict + + def _build_layer_wise(self): + parameter_dict = OrderedDict() + for name, submodule in self.model.named_children(): + layer_parameter_names = [] + for param_name, param in submodule.named_parameters(): + if param.requires_grad: + layer_parameter_names.append(f"{name}.{param_name}") + if layer_parameter_names: + parameter_dict[name] = layer_parameter_names + return parameter_dict + + +class ModelInfoMixin: + """ + A mixin class for classes that contain information about a model. + """ + + def __init__(self, model: torch.nn.Module): + self.model = model + + @property + def device(self) -> torch.device: + return next(self.model.parameters()).device + + @property + def dtype(self) -> torch.dtype: + return next(self.model.parameters()).dtype + + @property + def n_parameters(self) -> int: + return sum(p.numel() for p in self.model.parameters() if p.requires_grad) diff --git a/src/pydvl/influence/types.py b/src/pydvl/influence/types.py new file mode 100644 index 000000000..8300768de --- /dev/null +++ b/src/pydvl/influence/types.py @@ -0,0 +1,643 @@ +""" +This module offers a set of generic types, which can be used to build modular and +flexible components for influence computation for different tensor frameworks. + + +Key components include: + +1. 
[GradientProvider][pydvl.influence.types.GradientProvider]: A generic + abstract base class designed to provide methods for computing per-sample + gradients and other related computations for given data batches. + +2. [BilinearForm][pydvl.influence.types.BilinearForm]: A generic abstract base class + for representing bilinear forms for computing inner products involving gradients. + +3. [Operator][pydvl.influence.types.Operator]: A generic abstract base class for + operators that can apply transformations to vectors and matrices and can be + represented as bilinear forms. + +4. [OperatorGradientComposition][pydvl.influence.types.OperatorGradientComposition]: A + generic abstract composition class that integrates an operator with a gradient + provider to compute interactions between batches of data. + +5. [BlockMapper][pydvl.influence.types.BlockMapper]: A generic abstract base class + for mapping operations across multiple compositional blocks, given by objects + of type + [OperatorGradientComposition][pydvl.influence.types.OperatorGradientComposition], + and aggregating the results. + +To see the usage of these types, see the implementation +[ComposableInfluence][pydvl.influence.base_influence_function_model.ComposableInfluence] +. Using these components allows the straightforward implementation of various +combinations of approximations of inverse Hessian applications +(or Gauss-Newton approximations), different blocking strategies +(e.g. layer-wise or block-wise) and different ways to +compute gradients. + +For the usage with a specific tensor framework, these types must be subclassed. An +example for [torch][torch] is provided in the module +[pydvl.influence.torch.base][pydvl.influence.torch.base] and the base class +[TorchComposableInfluence][pydvl.influence.torch.base.TorchComposableInfluence]. +""" + +from __future__ import annotations + +import collections +from abc import ABC, abstractmethod +from dataclasses import dataclass +from enum import Enum +from typing import ( + Collection, + Dict, + Generator, + Generic, + Iterable, + Optional, + OrderedDict, + TypeVar, + Union, + cast, +) + + +class InfluenceMode(str, Enum): + """ + Enum representation for the types of influence. + + Attributes: + Up: [Approximating the influence of a point] + [approximating-the-influence-of-a-point] + Perturbation: [Perturbation definition of the influence score] + [perturbation-definition-of-the-influence-score] + + """ + + Up = "up" + Perturbation = "perturbation" + + +"""Type variable for tensors, i.e. sequences of numbers""" +TensorType = TypeVar("TensorType", bound=Collection) +DataLoaderType = TypeVar("DataLoaderType", bound=Iterable) + + +@dataclass(frozen=True) +class Batch(Generic[TensorType]): + """ + Represents a batch of data containing features and labels. + + Attributes: + x: Represents the input features of the batch. + y: Represents the labels or targets associated with the input features. + """ + + x: TensorType + y: TensorType + + +BatchType = TypeVar("BatchType", bound=Batch) + + +class GradientProvider(Generic[BatchType, TensorType], ABC): + r""" + Provides an interface for calculating per-sample gradients and other related + computations for a given batch of data. + + This class must be subclassed with implementations for its abstract methods tailored + to specific gradient computation needs, e.g. using an autograd engine for + a model loss function. 
Consider a function + + $$ \ell: \mathbb{R}^{d_1} \times \mathbb{R}^{d_2} \times \mathbb{R}^{n} \times + \mathbb{R}^{n}, \quad \ell(\omega_1, \omega_2, x, y) = + \operatorname{loss}(f(\omega_1, \omega_2; x), y) $$ + + e.g. a two layer neural network $f$ with a loss function, then this object should + compute the expressions: + + $$ \nabla_{\omega_{i}}\ell(\omega_1, \omega_2, x, y), + \nabla_{\omega_{i}}\nabla_{x}\ell(\omega_1, \omega_2, x, y), + \nabla_{\omega}\ell(\omega_1, \omega_2, x, y) \cdot v$$ + + """ + + @abstractmethod + def jacobian_prod( + self, + batch: BatchType, + g: TensorType, + ) -> TensorType: + r""" + Computes the matrix-Jacobian product for the provided batch and input tensor. + Given the example in the class docstring, this means + + $$ (\nabla_{\omega_{1}}\ell(\omega_1, \omega_2, + \text{batch.x}, \text{batch.y}), + \nabla_{\omega_{2}}\ell(\omega_1, \omega_2, + \text{batch.x}, \text{batch.y})) \cdot g^T$$ + + where g must be a tensor of shape $(K, d_1+d_2)$, so the resulting tensor + is of shape $(N, K)$. + + Args: + batch: The batch of data for which to compute the Jacobian. + g: The tensor to be used in the matrix-Jacobian product + calculation. + + Returns: + The resulting tensor from the matrix-Jacobian product computation. + """ + + @abstractmethod + def flat_grads(self, batch: BatchType) -> TensorType: + r""" + Computes and returns the flat per-sample gradients for the provided batch. + Given the example in the class docstring, this means + + $$ (\nabla_{\omega_{1}}\ell(\omega_1, \omega_2, + \text{batch.x}, \text{batch.y}), + \nabla_{\omega_{2}}\ell(\omega_1, \omega_2, + \text{batch.x}, \text{batch.y}))$$ + + where the first dimension of the resulting tensor is always considered to be + the batch dimension, so the shape of the resulting tensor is $(N, d_1+d_2)$, + where $N$ is the number of samples in the batch. + + Args: + batch: The batch of data for which to compute the gradients. + + Returns: + A tensor containing the flat gradients computed per sample. + """ + + @abstractmethod + def flat_mixed_grads(self, batch: BatchType) -> TensorType: + r""" + Computes and returns the flat per-sample mixed gradients for the provided batch. + Given the example in the class docstring, this means + + $$ (\nabla_{\omega_1}\nabla_{x}\ell(\omega_1, + \omega_2, \text{batch.x}, \text{batch.y}), + \nabla_{\omega_1}\nabla_{x}\ell(\omega_1, + \omega_2, \text{batch.x}, \text{batch.y} ))$$ + + where the first dimension of the resulting tensor is always considered to be + the batch dimension and the last to be the non-batch input related derivatives. + So the shape of the resulting tensor is $(N, n, d_1 + d_2)$, + where $N$ is the number of samples in the batch. + + Args: + batch: The batch of data for which to compute the flat mixed gradients. + + Returns: + A tensor containing the flat mixed gradients computed per sample. + """ + + +GradientProviderType = TypeVar("GradientProviderType", bound=GradientProvider) + + +class BilinearForm(Generic[TensorType, BatchType, GradientProviderType], ABC): + """ + Abstract base class for bilinear forms, which facilitates the computation of inner + products involving gradients of batches of data. + """ + + @abstractmethod + def inner_prod(self, left: TensorType, right: Optional[TensorType]) -> TensorType: + r""" + Computes the inner product of two vectors, i.e. + + $$ \langle x, y \rangle_{B}$$ + + if we denote the bilinear-form by $\langle \cdot, \cdot \rangle_{B}$. 
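As a concrete illustration of the `flat_grads` contract, i.e. per-sample gradients flattened to shape $(N, d_1 + d_2)$, here is the standard `torch.func` per-sample-gradient recipe on a toy model; everything in it is illustrative and not the pyDVL implementation.

```python
import torch
from torch.func import functional_call, grad, vmap

model = torch.nn.Linear(3, 2)
loss_fn = torch.nn.functional.mse_loss
params = {k: p.detach() for k, p in model.named_parameters()}
x, y = torch.randn(8, 3), torch.randn(8, 2)

def sample_loss(p, xi, yi):
    # Loss of a single sample, evaluated functionally on the parameters p.
    return loss_fn(functional_call(model, p, (xi.unsqueeze(0),)), yi.unsqueeze(0))

# vmap over the batch dimension of x and y, keeping the parameters fixed.
per_sample_grads = vmap(grad(sample_loss), in_dims=(None, 0, 0))(params, x, y)
flat = torch.cat([g.reshape(x.shape[0], -1) for g in per_sample_grads.values()], dim=-1)
print(flat.shape)  # (8, 8): batch size times total number of parameters
```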
+ The implementations must take care of according vectorization to make + it applicable to the case, where `left` and `right` are not one-dimensional. + In this case, the trailing dimension of the `left` and `right` tensors are + considered for the computation of the inner product. For example, + if `left` is a tensor of shape $(N, D)$ and, `right` is of shape $(M,..., D)$, + then the result is of shape $(N,..., M)$ + + Args: + left: The first tensor in the inner product computation. + right: The second tensor, optional; if not provided, the inner product will + use `left` tensor for both arguments. + + Returns: + A tensor representing the inner product. + """ + + def grads_inner_prod( + self, + left: BatchType, + right: Optional[BatchType], + gradient_provider: GradientProviderType, + ) -> TensorType: + r""" + Computes the gradient inner product of two batches of data, i.e. + + $$ \langle \nabla_{\omega}\ell(\omega, \text{left.x}, \text{left.y}), + \nabla_{\omega}\ell(\omega, \text{right.x}, \text{right.y}) \rangle_{B}$$ + + where $\nabla_{\omega}\ell(\omega, \cdot, \cdot)$ is represented by the + `gradient_provider` and the expression must be understood sample-wise. + + Args: + left: The first batch for gradient and inner product computation + right: The second batch for gradient and inner product computation, + optional; if not provided, the inner product will use the gradient + computed for `left` for both arguments. + gradient_provider: The gradient provider to compute the gradients. + + Returns: + A tensor representing the inner products of the per-sample gradients + """ + left_grad = gradient_provider.flat_grads(left) + if right is None: + right_grad = left_grad + else: + right_grad = gradient_provider.flat_grads(right) + return self.inner_prod(left_grad, right_grad) + + def mixed_grads_inner_prod( + self, left: BatchType, right: BatchType, gradient_provider: GradientProviderType + ) -> TensorType: + r""" + Computes the mixed gradient inner product of two batches of data, i.e. + + $$ \langle \nabla_{\omega}\ell(\omega, \text{left.x}, \text{left.y}), + \nabla_{\omega}\nabla_{x}\ell(\omega, \text{right.x}, \text{right.y}) + \rangle_{B}$$ + + where $\nabla_{\omega}\ell(\omega, \cdot)$ and + $\nabla_{\omega}\nabla_{x}\ell(\omega, \cdot)$ are represented by the + `gradient_provider`. The expression must be understood sample-wise. + + Args: + left: The first batch for gradient and inner product computation + right: The second batch for gradient and inner product computation + gradient_provider: The gradient provider to compute the gradients. + + Returns: + A tensor representing the inner products of the mixed per-sample gradients + """ + left_grad = gradient_provider.flat_grads(left) + right_mixed_grad = gradient_provider.flat_mixed_grads(right) + return self.inner_prod(left_grad, right_mixed_grad) + + +BilinearFormType = TypeVar("BilinearFormType", bound=BilinearForm) + + +class Operator(Generic[TensorType, BilinearFormType], ABC): + """ + Abstract base class for operators, capable of applying transformations to + vectors and matrices, and can be represented as a bilinear form. + """ + + @property + @abstractmethod + def input_size(self) -> int: + """ + Abstract property to get the needed size for inputs to the operator + instance + + Returns: + An integer representing the input size. + """ + + @abstractmethod + def _validate_tensor_input(self, tensor: TensorType) -> None: + """ + Validates the input tensor for the operator. + + Args: + tensor: A tensor to validate. 
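For the shape conventions of `inner_prod`, here is a small sketch with an explicit symmetric positive definite weight matrix standing in for the operator; the sizes are made up.

```python
import torch

D, N, M = 4, 3, 5
A = torch.randn(D, D, dtype=torch.float64)
weight = A @ A.t() + torch.eye(D, dtype=torch.float64)  # SPD stand-in operator

left = torch.randn(N, D, dtype=torch.float64)
right = torch.randn(M, D, dtype=torch.float64)

# Weighted inner products over the trailing dimension: result has shape (N, M).
inner = torch.einsum("nd,de,me->nm", left, weight, right)
assert inner.shape == (N, M)
# entry (n, m) equals <left[n], weight @ right[m]>
assert torch.allclose(inner[1, 2], left[1] @ (weight @ right[2]))
```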
+ + Raises: + ValueError: If the tensor is invalid for the operator. + """ + + def apply(self, tensor: TensorType) -> TensorType: + """ + Applies the operator to a tensor. + + Args: + tensor: A tensor, whose tailing dimension must conform to the + operator's input size + + Returns: + A tensor representing the result of the operator application. + """ + self._validate_tensor_input(tensor) + return self._apply(tensor) + + @abstractmethod + def _apply(self, tensor: TensorType) -> TensorType: + """ + Applies the operator to a tensor. Implement this to handle + batched input. + + Args: + tensor: A tensor, whose tailing dimension must conform to the + operator's input size + + Returns: + A tensor representing the result of the operator application. + """ + + @abstractmethod + def as_bilinear_form(self) -> BilinearFormType: + r""" + Represents the operator as a bilinear form, i.e. the weighted inner product + + $$ \langle \operatorname{Op}(x), y \rangle$$ + + Returns: + An instance of type [BilinearForm][pydvl.influence.types.BilinearForm] + representing this operator. + """ + + +OperatorType = TypeVar("OperatorType", bound=Operator) + + +class OperatorGradientComposition( + Generic[TensorType, BatchType, OperatorType, GradientProviderType] +): + """ + Generic base class representing a composable block that integrates an operator and + a gradient provider to compute interactions between batches of data. + + This block is designed to be flexible, handling different computational modes via + an abstract operator and gradient provider. + + Attributes: + op: The operator used for transformations and influence computations. + gp: The gradient provider used for obtaining necessary gradients. + """ + + def __init__(self, op: OperatorType, gp: GradientProviderType): + self.gp = gp + self.op = op + + def interactions( + self, + left_batch: BatchType, + right_batch: Optional[BatchType], + mode: InfluenceMode, + ): + r""" + Computes the interaction between the gradients on two batches of data based on + the specified mode weighted by the operator action, + i.e. + + $$ \langle \operatorname{Op}(\nabla_{\omega}\ell(\omega, \text{left.x}, + \text{left.y})), + \nabla_{\omega}\ell(\omega, \text{right.x}, \text{right.y}) \rangle$$ + + for the case `InfluenceMode.Up` and + + $$ \langle \operatorname{Op}(\nabla_{\omega}\ell(\omega, \text{left.x}, + \text{left.y})), + \nabla_{\omega}\nabla_{x}\ell(\omega, \text{right.x}, \text{right.y}) \rangle $$ + + for the case `InfluenceMode.Perturbation`. + + Args: + left_batch: The left data batch for gradient computation. + right_batch: The right data batch for gradient computation. + mode: An instance of InfluenceMode determining the type of influence + computation. + + Returns: + The result of the influence computation as dictated by the mode. + """ + bilinear_form = self.op.as_bilinear_form() + if mode is InfluenceMode.Up: + return bilinear_form.grads_inner_prod(left_batch, right_batch, self.gp) + return bilinear_form.mixed_grads_inner_prod(left_batch, right_batch, self.gp) + + def transformed_grads(self, batch: BatchType): + r""" + Computes the gradients of a data batch, transformed by the operator application + , i.e. the expressions + + $$ \operatorname{Op}(\nabla_{\omega}\ell(\omega, \text{batch.x}, + \text{batch.y})) $$ + + Args: + batch: The data batch for gradient computation. + + Returns: + A tensor representing the application of the operator to the gradients. 
+ + """ + grads = self.gp.flat_grads(batch) + return self.op.apply(grads) + + def interactions_from_transformed_grads( + self, left_factors: TensorType, right_batch: BatchType, mode: InfluenceMode + ): + r""" + Computes the interaction between the transformed gradients on two batches of + data using pre-computed factors and a batch of data, + based on the specified mode. This means + + $$ \langle \text{left_factors}, + \nabla_{\omega}\ell(\omega, \text{right.x}, \text{right.y}) \rangle$$ + + for the case `InfluenceMode.Up` and + + $$ \langle \text{left_factors}, + \nabla_{\omega}\nabla_{x}\ell(\omega, \text{right.x}, \text{right.y}) \rangle $$ + + for the case `InfluenceMode.Perturbation`. + + Args: + left_factors: Pre-computed tensor factors from a left batch. + right_batch: The right data batch for influence computation. + mode: An instance of InfluenceMode determining the type of influence + computation. + + Returns: + The result of the interaction computation using the provided factors and + batch gradients. + """ + if mode is InfluenceMode.Up: + right_grads = self.gp.flat_grads(right_batch) + else: + right_grads = self.gp.flat_mixed_grads(right_batch) + return self.op.as_bilinear_form().inner_prod(left_factors, right_grads) + + +OperatorGradientCompositionType = TypeVar( + "OperatorGradientCompositionType", bound=OperatorGradientComposition +) + + +class BlockMapper(Generic[TensorType, BatchType, OperatorGradientCompositionType], ABC): + """ + Abstract base class for mapping operations across multiple compositional blocks. + + This class takes a dictionary of compositional blocks and applies their methods to + batches or tensors, and aggregates the results. + + Attributes: + composable_block_dict: A dictionary mapping string identifiers to + composable blocks which define operations like transformations and + interactions. + """ + + def __init__( + self, composable_block_dict: OrderedDict[str, OperatorGradientCompositionType] + ): + self.composable_block_dict = composable_block_dict + + def __getitem__(self, item: str): + return self.composable_block_dict[item] + + def items(self): + return self.composable_block_dict.items() + + def _to_ordered_dict( + self, tensor_generator: Generator[TensorType, None, None] + ) -> OrderedDict[str, TensorType]: + tensor_dict = collections.OrderedDict() + for k, t in zip(self.composable_block_dict.keys(), tensor_generator): + tensor_dict[k] = t + return tensor_dict + + @abstractmethod + def _split_to_blocks( + self, z: TensorType, dim: int = -1 + ) -> OrderedDict[str, TensorType]: + """Must be implemented in a way to preserve the ordering defined by the + `composable_block_dict` attribute""" + + def transformed_grads( + self, + batch: BatchType, + ) -> OrderedDict[str, TensorType]: + """ + Computes and returns the transformed gradients for a batch in dictionary + with the keys defined by the block names. + + Args: + batch: The batch of data for which to compute transformed gradients. + + Returns: + An ordered dictionary of transformed gradients by block. + """ + tensor_gen = self.generate_transformed_grads(batch) + return self._to_ordered_dict(tensor_gen) + + def interactions( + self, left_batch: BatchType, right_batch: BatchType, mode: InfluenceMode + ) -> OrderedDict[str, TensorType]: + """ + Computes interactions between two batches, aggregated by block, + based on a specified mode. + + Args: + left_batch: The left batch for interaction computation. + right_batch: The right batch for interaction computation. 
+ mode: The mode determining the type of interactions. + + Returns: + An ordered dictionary of gradient interactions by block. + """ + tensor_gen = self.generate_interactions(left_batch, right_batch, mode) + return self._to_ordered_dict(tensor_gen) + + def interactions_from_transformed_grads( + self, + left_factors: OrderedDict[str, TensorType], + right_batch: BatchType, + mode: InfluenceMode, + ) -> OrderedDict[str, TensorType]: + """ + Computes interactions from transformed gradients and a right batch, + aggregated by block and based on a mode. + + Args: + left_factors: Pre-computed factors as a tensor or an ordered dictionary of + tensors by block. If the input is a tensor, it is split into blocks + according to the ordering in the `composable_block_dict` attribute. + right_batch: The right batch for interaction computation. + mode: The mode determining the type of interactions. + + Returns: + An ordered dictionary of interactions from transformed gradients by block. + """ + tensor_gen = self.generate_interactions_from_transformed_grads( + left_factors, right_batch, mode + ) + return self._to_ordered_dict(tensor_gen) + + def generate_transformed_grads( + self, batch: BatchType + ) -> Generator[TensorType, None, None]: + """ + Generator that yields transformed gradients for a given batch, + processed by each block. + + Args: + batch: The batch of data for which to generate transformed gradients. + + Yields: + Transformed gradients for each block. + """ + for comp_block in self.composable_block_dict.values(): + yield comp_block.transformed_grads(batch) + + def generate_interactions( + self, + left_batch: BatchType, + right_batch: Optional[BatchType], + mode: InfluenceMode, + ) -> Generator[TensorType, None, None]: + """ + Generator that yields gradient interactions between two batches, processed by + each block based on a mode. + + Args: + left_batch: The left batch for interaction computation. + right_batch: The right batch for interaction computation. + mode: The mode determining the type of interactions. + + Yields: + TensorType: Gradient interactions for each block. + """ + for comp_block in self.composable_block_dict.values(): + yield comp_block.interactions(left_batch, right_batch, mode) + + def generate_interactions_from_transformed_grads( + self, + left_factors: Union[TensorType, OrderedDict[str, TensorType]], + right_batch: BatchType, + mode: InfluenceMode, + ) -> Generator[TensorType, None, None]: + """ + Generator that yields interactions computed from pre-computed factors and a + right batch, processed by each block based on a mode. + + Args: + left_factors: Pre-computed factors as a tensor or an ordered dictionary of + tensors by block. + right_batch: The right batch for interaction computation. + mode: The mode determining the type of interactions. + + Yields: + TensorType: Interactions for each block. 
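A sketch of the `_split_to_blocks` step that concrete block mappers must implement: a flat tensor is cut along its last dimension according to the per-block parameter counts, preserving the ordering of `composable_block_dict`. The block names and sizes below are made up.

```python
from collections import OrderedDict

import torch

block_sizes = OrderedDict([("layer1", 6), ("layer2", 4)])
z = torch.arange(10.0)

# Split along the last dimension into chunks matching the block sizes.
chunks = torch.split(z, list(block_sizes.values()), dim=-1)
blocks = OrderedDict(zip(block_sizes.keys(), chunks))
print({k: v.shape for k, v in blocks.items()})
# {'layer1': torch.Size([6]), 'layer2': torch.Size([4])}
```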
+ """ + if not isinstance(left_factors, dict): + left_factors_dict = self._split_to_blocks(left_factors) + else: + left_factors_dict = cast(OrderedDict[str, TensorType], left_factors) + for k, comp_block in self.composable_block_dict.items(): + yield comp_block.interactions_from_transformed_grads( + left_factors_dict[k], right_batch, mode + ) + + +BlockMapperType = TypeVar("BlockMapperType", bound=BlockMapper) diff --git a/tests/influence/test_influence_calculator.py b/tests/influence/test_influence_calculator.py index 854321f8f..70a29bf1a 100644 --- a/tests/influence/test_influence_calculator.py +++ b/tests/influence/test_influence_calculator.py @@ -1,5 +1,3 @@ -import uuid - import dask.array as da import numpy as np import pytest @@ -28,10 +26,6 @@ EkfacInfluence, ) from pydvl.influence.torch.influence_function_model import NystroemSketchInfluence -from pydvl.influence.torch.pre_conditioner import ( - JacobiPreConditioner, - NystroemPreConditioner, -) from pydvl.influence.torch.util import ( NestedTorchCatAggregator, TorchCatAggregator, diff --git a/tests/influence/torch/test_batch_operation.py b/tests/influence/torch/test_batch_operation.py new file mode 100644 index 000000000..b04f9b19b --- /dev/null +++ b/tests/influence/torch/test_batch_operation.py @@ -0,0 +1,391 @@ +from dataclasses import astuple + +import pytest +import torch + +from pydvl.influence.torch.base import TorchBatch +from pydvl.influence.torch.batch_operation import ( + GaussNewtonBatchOperation, + HessianBatchOperation, + InverseHarmonicMeanBatchOperation, +) +from pydvl.influence.torch.util import align_structure, flatten_dimensions + +from .test_util import model_data, test_parameters, torch + + +@pytest.mark.torch +class TestHessianBatchOperation: + @pytest.fixture(autouse=True) + def setup(self, model_data): + self.torch_model, self.x, self.y, self.vec, self.h_analytical = model_data + self.params = {k: p.detach() for k, p in self.torch_model.named_parameters()} + self.hessian_op = HessianBatchOperation( + self.torch_model, torch.nn.functional.mse_loss, restrict_to=self.params + ) + + @pytest.mark.parametrize( + "model_data, tol", + [(astuple(tp.model_params), 1e-5) for tp in test_parameters], + indirect=["model_data"], + ) + def test_analytical_comparison(self, model_data, tol, pytorch_seed): + hvp_autograd = self.hessian_op.apply(TorchBatch(self.x, self.y), self.vec) + hvp_autograd_dict = self.hessian_op.apply_to_dict( + TorchBatch(self.x, self.y), align_structure(self.params, self.vec) + ) + hvp_autograd_dict_flat = flatten_dimensions(hvp_autograd_dict.values()) + + assert torch.allclose(hvp_autograd, self.h_analytical @ self.vec, rtol=tol) + assert torch.allclose( + hvp_autograd_dict_flat, self.h_analytical @ self.vec, rtol=tol + ) + + @pytest.mark.parametrize( + "model_data, tol", + [(astuple(tp.model_params), 1e-5) for tp in test_parameters], + indirect=["model_data"], + ) + def test_flattening_commutation(self, model_data, tol, pytorch_seed): + batch_size = 10 + rand_mat_dict = { + k: torch.randn(batch_size, *t.shape) for k, t in self.params.items() + } + flat_rand_mat = flatten_dimensions( + rand_mat_dict.values(), shape=(batch_size, -1) + ) + hvp_autograd_mat_dict = self.hessian_op.apply_to_dict( + TorchBatch(self.x, self.y), rand_mat_dict + ) + op_then_flat = flatten_dimensions( + hvp_autograd_mat_dict.values(), shape=(batch_size, -1) + ) + flat_then_op_analytical = torch.einsum( + "ik, jk -> ji", self.h_analytical, flat_rand_mat + ) + + assert torch.allclose( + op_then_flat, + flat_then_op_analytical, 
+ atol=1e-5, + rtol=tol, + ) + assert torch.allclose( + self.hessian_op._apply_to_mat(TorchBatch(self.x, self.y), flat_rand_mat), + op_then_flat, + ) + + +@pytest.mark.torch +class TestGaussNewtonBatchOperation: + @pytest.fixture(autouse=True) + def setup(self, model_data): + self.torch_model, self.x, self.y, self.vec, _ = model_data + self.params = dict(self.torch_model.named_parameters()) + self.gn_op = GaussNewtonBatchOperation( + self.torch_model, torch.nn.functional.mse_loss, restrict_to=self.params + ) + self.out_features = self.torch_model(self.x).shape[1] + self.grad_analytical = self.compute_grad_analytical() + self.gn_mat_analytical = self.compute_gn_mat_analytical() + + def compute_grad_analytical(self): + y_pred = self.torch_model(self.x) + dl_dw = torch.vmap( + lambda r, s, t: 2 + / float(self.out_features) + * (s - t).view(-1, 1) + @ r.view(1, -1) + )(self.x, y_pred, self.y) + dl_db = torch.vmap(lambda s, t: 2 / float(self.out_features) * (s - t))( + y_pred, self.y + ) + return torch.cat([dl_dw.reshape(self.x.shape[0], -1), dl_db], dim=-1) + + def compute_gn_mat_analytical(self): + return ( + torch.sum( + torch.func.vmap(lambda t: t.unsqueeze(-1) * t.unsqueeze(-1).t())( + self.grad_analytical + ), + dim=0, + ) + / self.x.shape[0] + ) + + @pytest.mark.parametrize( + "model_data, tol", + [(astuple(tp.model_params), 1e-3) for tp in test_parameters], + indirect=["model_data"], + ) + def test_analytical_comparison(self, model_data, tol): + gn_autograd = self.gn_op.apply(TorchBatch(self.x, self.y), self.vec) + gn_autograd_dict = self.gn_op.apply_to_dict( + TorchBatch(self.x, self.y), align_structure(self.params, self.vec) + ) + gn_autograd_dict_flat = flatten_dimensions(gn_autograd_dict.values()) + analytical_vec = self.gn_mat_analytical @ self.vec + + assert torch.allclose(gn_autograd, analytical_vec, atol=1e-4, rtol=tol) + assert torch.allclose( + gn_autograd_dict_flat, analytical_vec, atol=1e-4, rtol=tol + ) + + @pytest.mark.parametrize( + "model_data, tol", + [(astuple(tp.model_params), 1e-3) for tp in test_parameters], + indirect=["model_data"], + ) + def test_flattening_commutation(self, model_data, tol): + batch_size = 10 + rand_mat_dict = { + k: torch.randn(batch_size, *t.shape) for k, t in self.params.items() + } + flat_rand_mat = flatten_dimensions( + rand_mat_dict.values(), shape=(batch_size, -1) + ) + gn_autograd_mat_dict = self.gn_op.apply_to_dict( + TorchBatch(self.x, self.y), rand_mat_dict + ) + + op_then_flat = flatten_dimensions( + gn_autograd_mat_dict.values(), shape=(batch_size, -1) + ) + flat_then_op = self.gn_op._apply_to_mat( + TorchBatch(self.x, self.y), flat_rand_mat + ) + + assert torch.allclose( + op_then_flat, + flat_then_op, + atol=1e-4, + rtol=tol, + ) + + flat_then_op_analytical = torch.einsum( + "ik, jk -> ji", self.gn_mat_analytical, flat_rand_mat + ) + + assert torch.allclose( + op_then_flat, + flat_then_op_analytical, + atol=1e-4, + rtol=1e-2, + ) + + +@pytest.mark.torch +class TestInverseHarmonicMeanBatchOperation: + @pytest.fixture(autouse=True) + def setup(self, model_data, reg): + self.torch_model, self.x, self.y, self.vec, _ = model_data + self.reg = reg + self.params = { + k: p.detach() + for k, p in self.torch_model.named_parameters() + if p.requires_grad + } + self.grad_analytical = self.compute_grad_analytical() + self.ihm_mat_analytical = self.compute_ihm_mat_analytical() + self.gn_op = InverseHarmonicMeanBatchOperation( + self.torch_model, + torch.nn.functional.mse_loss, + self.reg, + restrict_to=self.params, + ) + + def 
compute_grad_analytical(self): + y_pred = self.torch_model(self.x) + out_features = y_pred.shape[1] + dl_dw = torch.vmap( + lambda r, s, t: 2 + / float(out_features) + * (s - t).view(-1, 1) + @ r.view(1, -1) + )(self.x, y_pred, self.y) + dl_db = torch.vmap(lambda s, t: 2 / float(out_features) * (s - t))( + y_pred, self.y + ) + return torch.cat([dl_dw.reshape(self.x.shape[0], -1), dl_db], dim=-1) + + def compute_ihm_mat_analytical(self): + return ( + torch.sum( + torch.func.vmap( + lambda z: torch.linalg.inv( + z.unsqueeze(-1) * z.unsqueeze(-1).t() + + self.reg * torch.eye(len(z)) + ) + )(self.grad_analytical), + dim=0, + ) + / self.x.shape[0] + ) + + @pytest.mark.parametrize( + "model_data, tol", + [(astuple(tp.model_params), 1e-3) for tp in test_parameters], + indirect=["model_data"], + ) + @pytest.mark.parametrize("reg", [1.0, 10, 100]) + def test_analytical_comparison(self, model_data, tol, reg): + gn_autograd = self.gn_op.apply(TorchBatch(self.x, self.y), self.vec) + gn_autograd_dict = self.gn_op.apply_to_dict( + TorchBatch(self.x, self.y), align_structure(self.params, self.vec) + ) + gn_autograd_dict_flat = flatten_dimensions(gn_autograd_dict.values()) + analytical = self.ihm_mat_analytical @ self.vec + + assert torch.allclose(gn_autograd, analytical, atol=1e-4, rtol=tol) + assert torch.allclose(gn_autograd_dict_flat, analytical, atol=1e-4, rtol=tol) + + @pytest.mark.parametrize( + "model_data, tol", + [(astuple(tp.model_params), 1e-3) for tp in test_parameters], + indirect=["model_data"], + ) + @pytest.mark.parametrize("reg", [1.0, 10, 100]) + def test_flattening_commutation(self, model_data, tol, reg): + batch_size = 10 + rand_mat_dict = { + k: torch.randn(batch_size, *t.shape) for k, t in self.params.items() + } + flat_rand_mat = flatten_dimensions( + rand_mat_dict.values(), shape=(batch_size, -1) + ) + gn_autograd_mat_dict = self.gn_op.apply_to_dict( + TorchBatch(self.x, self.y), rand_mat_dict + ) + + op_then_flat = flatten_dimensions( + gn_autograd_mat_dict.values(), shape=(batch_size, -1) + ) + flat_then_op = self.gn_op._apply_to_mat( + TorchBatch(self.x, self.y), flat_rand_mat + ) + + assert torch.allclose( + op_then_flat, + flat_then_op, + atol=1e-4, + rtol=tol, + ) + + flat_then_op_analytical = torch.einsum( + "ik, jk -> ji", self.ihm_mat_analytical, flat_rand_mat + ) + + assert torch.allclose( + op_then_flat, + flat_then_op_analytical, + atol=1e-4, + rtol=1e-2, + ) + + +@pytest.mark.torch +@pytest.mark.parametrize( + "x_dim_0, x_dim_1, v_dim_0", + [(10, 1, 12), (3, 2, 5), (4, 5, 30), (6, 6, 6), (1, 7, 7)], +) +def test_rank_one_mvp(x_dim_0, x_dim_1, v_dim_0): + X = torch.randn(x_dim_0, x_dim_1) + V = torch.randn(v_dim_0, x_dim_1) + + expected = ( + (torch.vmap(lambda x: x.unsqueeze(-1) * x.unsqueeze(-1).t())(X) @ V.t()) + .sum(dim=0) + .t() + ) / x_dim_0 + + result = GaussNewtonBatchOperation._rank_one_mvp(X, V) + + assert result.shape == V.shape + assert torch.allclose(result, expected, atol=1e-5, rtol=1e-4) + + +@pytest.mark.torch +@pytest.mark.parametrize( + "x_dim_1", + [ + [(4, 2, 3), (5, 7), (5,)], + [(3, 6, 8, 9), (1, 2)], + [(1,)], + ], +) +@pytest.mark.parametrize( + "x_dim_0, v_dim_0", + [(10, 12), (3, 5), (4, 10), (6, 6), (1, 7)], +) +def test_generate_rank_one_mvp(x_dim_0, x_dim_1, v_dim_0): + x_list = [torch.randn(x_dim_0, *d) for d in x_dim_1] + v_list = [torch.randn(v_dim_0, *d) for d in x_dim_1] + + x = flatten_dimensions(x_list, shape=(x_dim_0, -1)) + v = flatten_dimensions(v_list, shape=(v_dim_0, -1)) + result = GaussNewtonBatchOperation._rank_one_mvp(x, v) 
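+    # Clarifying note (added): the block-wise generator variant below, once
+    # flattened, is expected to agree with the flat rank-one matrix-vector
+    # product computed above.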
+ + inverse_result = flatten_dimensions( + GaussNewtonBatchOperation._generate_rank_one_mvp(x_list, v_list), + shape=(v_dim_0, -1), + ) + + assert torch.allclose(result, inverse_result, atol=1e-5, rtol=1e-3) + + +@pytest.mark.torch +@pytest.mark.parametrize( + "x_dim_0, x_dim_1, v_dim_0", + [(10, 1, 12), (3, 2, 5), (4, 5, 10), (6, 6, 6), (1, 7, 7)], +) +@pytest.mark.parametrize("reg", [0.1, 100, 1.0, 10]) +def test_inverse_rank_one_update(x_dim_0, x_dim_1, v_dim_0, reg): + X = torch.randn(x_dim_0, x_dim_1) + V = torch.randn(v_dim_0, x_dim_1) + + inverse_result = torch.zeros_like(V) + + for x in X: + rank_one_matrix = x.unsqueeze(-1) * x.unsqueeze(-1).t() + inverse_result += torch.linalg.solve( + rank_one_matrix + reg * torch.eye(rank_one_matrix.shape[0]), V, left=False + ) + + inverse_result /= X.shape[0] + result = InverseHarmonicMeanBatchOperation._inverse_rank_one_update(X, V, reg) + + assert torch.allclose(result, inverse_result, atol=1e-5) + + +@pytest.mark.torch +@pytest.mark.parametrize( + "x_dim_1", + [ + [(4, 2, 3), (5, 7), (5,)], + [(3, 6, 8, 9), (1, 2)], + [(1,)], + ], +) +@pytest.mark.parametrize( + "x_dim_0, v_dim_0", + [(10, 12), (3, 5), (4, 10), (6, 6), (1, 7)], +) +@pytest.mark.parametrize("reg", [0.5, 100, 1.0, 10]) +def test_generate_inverse_rank_one_updates( + x_dim_0, x_dim_1, v_dim_0, reg, pytorch_seed +): + x_list = [torch.randn(x_dim_0, *d) for d in x_dim_1] + v_list = [torch.randn(v_dim_0, *d) for d in x_dim_1] + + x = flatten_dimensions(x_list, shape=(x_dim_0, -1)) + v = flatten_dimensions(v_list, shape=(v_dim_0, -1)) + result = InverseHarmonicMeanBatchOperation._inverse_rank_one_update(x, v, reg) + + inverse_result = flatten_dimensions( + InverseHarmonicMeanBatchOperation._generate_inverse_rank_one_updates( + x_list, v_list, reg + ), + shape=(v_dim_0, -1), + ) + + assert torch.allclose(result, inverse_result, atol=1e-5, rtol=1e-3) diff --git a/tests/influence/torch/test_gradient_provider.py b/tests/influence/torch/test_gradient_provider.py new file mode 100644 index 000000000..ebee2923d --- /dev/null +++ b/tests/influence/torch/test_gradient_provider.py @@ -0,0 +1,113 @@ +import numpy as np +import pytest +import torch + +from pydvl.influence.torch.base import TorchBatch, TorchGradientProvider + +from ..conftest import linear_mixed_second_derivative_analytical, linear_model +from .conftest import DATA_OUTPUT_NOISE, linear_mvp_model + + +class TestTorchPerSampleAutograd: + @pytest.mark.torch + @pytest.mark.parametrize( + "in_features, out_features, batch_size", + [(46, 6, 632), (50, 3, 120), (100, 5, 120), (25, 10, 550)], + ) + def test_per_sample_gradient(self, in_features, out_features, batch_size): + model = torch.nn.Linear(in_features, out_features) + loss = torch.nn.functional.mse_loss + + x = torch.randn(batch_size, in_features, requires_grad=True) + y = torch.randn(batch_size, out_features) + params = {k: p.detach() for k, p in model.named_parameters() if p.requires_grad} + + gp = TorchGradientProvider(model, loss, restrict_to=params) + gradients = gp.grads(TorchBatch(x, y)) + flat_gradients = gp.flat_grads(TorchBatch(x, y)) + + # Compute analytical gradients + y_pred = model(x) + dL_dw = torch.vmap( + lambda r, s, t: 2 + / float(out_features) + * (s - t).view(-1, 1) + @ r.view(1, -1) + )(x, y_pred, y) + dL_db = torch.vmap(lambda s, t: 2 / float(out_features) * (s - t))(y_pred, y) + + # Assert the gradient values for equality with analytical gradients + assert torch.allclose(gradients["weight"], dL_dw, atol=1e-5) + assert torch.allclose(gradients["bias"], 
dL_db, atol=1e-5) + assert torch.allclose( + flat_gradients, + torch.cat([dL_dw.reshape(batch_size, -1), dL_db], dim=-1), + atol=1e-5, + ) + + @pytest.mark.torch + @pytest.mark.parametrize( + "in_features, out_features, train_set_size", + [(46, 1, 1000), (50, 3, 100), (100, 5, 512), (25, 10, 734)], + ) + def test_mixed_derivatives(self, in_features, out_features, train_set_size): + A, b = linear_model((out_features, in_features), 5) + loss = torch.nn.functional.mse_loss + model = linear_mvp_model(A, b) + + data_model = lambda x: np.random.normal(x @ A.T + b, DATA_OUTPUT_NOISE) + train_x = np.random.uniform(size=[train_set_size, in_features]) + train_y = data_model(train_x) + + params = {k: p for k, p in model.named_parameters() if p.requires_grad} + + test_derivative = linear_mixed_second_derivative_analytical( + (A, b), + train_x, + train_y, + ) + + torch_train_x = torch.as_tensor(train_x) + torch_train_y = torch.as_tensor(train_y) + gp = TorchGradientProvider(model, loss, restrict_to=params) + flat_functorch_mixed_derivatives = gp.flat_mixed_grads( + TorchBatch(torch_train_x, torch_train_y) + ) + assert torch.allclose( + torch.as_tensor(test_derivative), + flat_functorch_mixed_derivatives.transpose(2, 1), + ) + + @pytest.mark.torch + @pytest.mark.parametrize( + "in_features, out_features, batch_size", + [(46, 1, 632), (50, 3, 120), (100, 5, 110), (25, 10, 500)], + ) + def test_matrix_jacobian_product( + self, in_features, out_features, batch_size, pytorch_seed + ): + model = torch.nn.Linear(in_features, out_features) + params = {k: p for k, p in model.named_parameters() if p.requires_grad} + + x = torch.randn(batch_size, in_features, requires_grad=True) + y = torch.randn(batch_size, out_features, requires_grad=True) + y_pred = model(x) + + gp = TorchGradientProvider( + model, torch.nn.functional.mse_loss, restrict_to=params + ) + + G = torch.randn((10, out_features * (in_features + 1))) + mjp = gp.jacobian_prod(TorchBatch(x, y), G) + + dL_dw = torch.vmap( + lambda r, s, t: 2 + / float(out_features) + * (s - t).view(-1, 1) + @ r.view(1, -1) + )(x, y_pred, y) + dL_db = torch.vmap(lambda s, t: 2 / float(out_features) * (s - t))(y_pred, y) + analytic_grads = torch.cat([dL_dw.reshape(dL_dw.shape[0], -1), dL_db], dim=1) + analytical_mjp = G @ analytic_grads.T + + assert torch.allclose(analytical_mjp, mjp, atol=1e-5, rtol=1e-3) diff --git a/tests/influence/torch/test_influence_model.py b/tests/influence/torch/test_influence_model.py index d2203a84e..6cc2ad0de 100644 --- a/tests/influence/torch/test_influence_model.py +++ b/tests/influence/torch/test_influence_model.py @@ -4,6 +4,7 @@ import numpy as np import pytest from numpy.typing import NDArray +from scipy.stats import pearsonr, spearmanr from pydvl.influence.base_influence_function_model import ( NotFittedException, @@ -14,6 +15,7 @@ CgInfluence, DirectInfluence, EkfacInfluence, + InverseHarmonicMeanInfluence, LissaInfluence, NystroemSketchInfluence, ) @@ -22,6 +24,7 @@ NystroemPreConditioner, PreConditioner, ) +from pydvl.influence.torch.util import BlockMode from tests.influence.torch.conftest import minimal_training torch = pytest.importorskip("torch") @@ -754,3 +757,55 @@ def test_influences_cg( .numpy() ) assert np.allclose(single_influence, direct_factors[0], atol=1e-6, rtol=1e-4) + + +composable_influence_factories = [InverseHarmonicMeanInfluence] + + +@pytest.mark.parametrize("composable_influence_factory", composable_influence_factories) +@pytest.mark.parametrize("block_mode", [mode for mode in BlockMode]) +@pytest.mark.torch +def 
test_composable_influence( + test_case: TestCase, + model_and_data: Tuple[ + torch.nn.Module, + Callable[[torch.Tensor, torch.Tensor], torch.Tensor], + torch.Tensor, + torch.Tensor, + torch.Tensor, + torch.Tensor, + ], + direct_influences, + direct_sym_influences, + device: torch.device, + block_mode, + composable_influence_factory, +): + model, loss, x_train, y_train, x_test, y_test = model_and_data + + train_dataloader = DataLoader( + TensorDataset(x_train, y_train), batch_size=test_case.batch_size + ) + + harmonic_mean_influence = composable_influence_factory( + model, loss, test_case.hessian_reg, block_structure=block_mode + ).to(device) + harmonic_mean_influence = harmonic_mean_influence.fit(train_dataloader) + harmonic_mean_influence_values = ( + harmonic_mean_influence.influences( + x_test, y_test, x_train, y_train, mode=test_case.mode + ) + .cpu() + .numpy() + ) + + threshold = 0.999 + flat_direct_influences = direct_influences.reshape(-1) + flat_harmonic_influences = harmonic_mean_influence_values.reshape(-1) + assert np.all( + pearsonr(flat_direct_influences, flat_harmonic_influences).statistic > threshold + ) + assert np.all( + spearmanr(flat_direct_influences, flat_harmonic_influences).statistic + > threshold + ) diff --git a/tests/influence/torch/test_util.py b/tests/influence/torch/test_util.py index c63a34253..a1b782a8c 100644 --- a/tests/influence/torch/test_util.py +++ b/tests/influence/torch/test_util.py @@ -17,6 +17,8 @@ lanzcos_low_rank_hessian_approx, ) from pydvl.influence.torch.util import ( + BlockMode, + ModelParameterDictBuilder, TorchLinalgEighException, TorchTensorContainerType, align_structure, @@ -318,3 +320,41 @@ def test_safe_torch_linalg_eigh(): def test_safe_torch_linalg_eigh_exception(): with pytest.raises(TorchLinalgEighException): safe_torch_linalg_eigh(torch.randn([53000, 53000])) + + +class TestModelParameterDictBuilder: + class SimpleModel(torch.nn.Module): + def __init__(self): + super().__init__() + self.fc1 = torch.nn.Linear(5, 10) + self.fc2 = torch.nn.Linear(10, 5) + self.fc1.weight.requires_grad = False + + @pytest.fixture + def model(self): + return TestModelParameterDictBuilder.SimpleModel() + + @pytest.mark.parametrize("block_mode", [mode for mode in BlockMode]) + def test_build(self, block_mode, model): + builder = ModelParameterDictBuilder( + model=model, + detach=True, + ) + param_dict = builder.build_from_block_mode(block_mode) + + if block_mode is BlockMode.FULL: + assert "" in param_dict + assert "fc1.weight" not in param_dict[""] + elif block_mode is BlockMode.PARAMETER_WISE: + assert "fc2.bias" in param_dict + assert len(param_dict["fc2.bias"]) > 0 + assert "fc1.weight" not in param_dict + elif block_mode is BlockMode.LAYER_WISE: + assert "fc2" in param_dict + assert "fc2.bias" in param_dict["fc2"] + assert "fc1.weight" not in param_dict["fc1"] + assert "fc1.bias" in param_dict["fc1"] + + assert all( + (not p.requires_grad for q in param_dict.values() for p in q.values()) + )