For a detailed description of the methodology, please refer to the paper, Studying Large Language Model Generalization with Influence Functions.
Kronfluence has been tested and is compatible with the following versions of Python and PyTorch:
- Python: Version 3.9 or later
- PyTorch: Version 2.1 or later
Kronfluence offers support for:
- Computing influence functions on selected PyTorch modules. Currently, we support `nn.Linear` and `nn.Conv2d`.
- Computing influence functions with several Hessian approximation strategies, including `identity`, `diagonal`, `kfac`, and `ekfac`.
- Computing pairwise and self-influence (with and without measurement) scores.
Note
If there are additional modules you would like to see supported, please submit an issue on our GitHub repository.
See the UCI Regression example for the complete workflow and an interactive tutorial.
Prepare Your Model and Dataset. Before computing influence scores, you need to prepare the trained model and dataset. You can use any framework to train the model (e.g., PyTorch Lightning or HuggingFace Trainer); you just need to prepare the final model parameters.
...
# Get the model with the trained parameters.
model = construct_model()
# Get the training dataset.
train_dataset = prepare_train_dataset()
# Get the query dataset (e.g., validation/test dataset).
query_dataset = prepare_query_dataset()
...
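For reference, these helpers might look like the following minimal sketch for a small image classifier; the architecture, checkpoint path, and random placeholder datasets are hypothetical and should be replaced with your own training artifacts:

```python
import torch
from torch import nn
from torch.utils.data import Dataset, TensorDataset


def construct_model() -> nn.Module:
    # Hypothetical architecture; replace with the model you actually trained.
    model = nn.Sequential(nn.Linear(784, 256), nn.ReLU(), nn.Linear(256, 10))
    # Load the final trained parameters (the checkpoint path is a placeholder).
    model.load_state_dict(torch.load("checkpoints/final_model.pt", map_location="cpu"))
    return model


def prepare_train_dataset() -> Dataset:
    # Placeholder random data; replace with the dataset used to train the model.
    return TensorDataset(torch.randn(512, 784), torch.randint(0, 10, (512,)))


def prepare_query_dataset() -> Dataset:
    # Placeholder random data; replace with your validation/test queries.
    return TensorDataset(torch.randn(64, 784), torch.randint(0, 10, (64,)))
```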
Define a Task.
To compute influence scores, you need to define a `Task` class.
This class contains information about the trained model and how influence scores will be computed:
(1) how to compute the training loss; (2) how to compute the measurable quantity (f(θ) in the paper; see Equation 5);
(3) which modules to use for influence function computations; and (4) whether the model uses an attention mask.
from typing import Any, Dict, List, Optional, Union
import torch
from torch import nn
from kronfluence.task import Task
class YourTask(Task):
    def compute_train_loss(
        self,
        batch: Any,
        model: nn.Module,
        sample: bool = False,
    ) -> torch.Tensor:
        # TODO: Complete this method.
        raise NotImplementedError

    def compute_measurement(
        self,
        batch: Any,
        model: nn.Module,
    ) -> torch.Tensor:
        # TODO: Complete this method.
        raise NotImplementedError

    def get_influence_tracked_modules(self) -> Optional[List[str]]:
        # TODO: [Optional] Complete this method.
        return None  # Compute influence scores on all available modules.

    def get_attention_mask(self, batch: Any) -> Optional[Union[Dict[str, torch.Tensor], torch.Tensor]]:
        # TODO: [Optional] Complete this method.
        return None  # Attention mask not used.
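For concreteness, here is one illustrative way to fill in the template for the simple classifier sketched earlier, assuming each batch is an `(inputs, labels)` tuple. The label sampling in `compute_train_loss` implements the true-Fisher convention (sampling labels from the model's predictions) discussed in the FAQ below:

```python
from typing import Any, Dict, List, Optional, Union

import torch
import torch.nn.functional as F
from torch import nn

from kronfluence.task import Task


class ClassificationTask(Task):
    def compute_train_loss(
        self,
        batch: Any,
        model: nn.Module,
        sample: bool = False,
    ) -> torch.Tensor:
        inputs, labels = batch
        logits = model(inputs)
        if not sample:
            # Use the actual labels and return the summed (not averaged) loss.
            return F.cross_entropy(logits, labels, reduction="sum")
        # Sample labels from the model's predictions (true Fisher).
        with torch.no_grad():
            sampled_labels = torch.multinomial(
                F.softmax(logits.detach(), dim=-1), num_samples=1
            ).flatten()
        return F.cross_entropy(logits, sampled_labels, reduction="sum")

    def compute_measurement(self, batch: Any, model: nn.Module) -> torch.Tensor:
        # Here the measurable quantity is simply the (summed) loss on the query batch.
        inputs, labels = batch
        logits = model(inputs)
        return F.cross_entropy(logits, labels, reduction="sum")

    def get_influence_tracked_modules(self) -> Optional[List[str]]:
        return None  # Track all supported modules.

    def get_attention_mask(self, batch: Any) -> Optional[Union[Dict[str, torch.Tensor], torch.Tensor]]:
        return None  # No attention mask for this model.
```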
Prepare Your Model for Influence Computations.
Kronfluence wraps all supported modules within the model with `TrackedModule`.
This wrapper will be used for computing the factors and influence scores. Once your model is ready and the task is defined,
prepare your model with:
from kronfluence.analyzer import prepare_model
...
task = YourTask()
model = prepare_model(model=model, task=task)
...
If you have specified module names in `Task.get_influence_tracked_modules`, `TrackedModule` will only be installed for those modules.
[Optional] Create a DDP or FSDP Module.
After calling `prepare_model`, you can wrap your model with DistributedDataParallel (DDP) or
FullyShardedDataParallel (FSDP). You can also wrap your model with `torch.compile`.
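For example, with a model that has already been passed through `prepare_model`, the wrapping order for DDP might look like the sketch below; the environment handling assumes a `torchrun`-style launch and is illustrative:

```python
import os

import torch
from torch.nn.parallel import DistributedDataParallel as DDP

# Assumes torchrun has set LOCAL_RANK and the process group is already initialized,
# and that `model` has already been wrapped by prepare_model above.
local_rank = int(os.environ["LOCAL_RANK"])
device = torch.device("cuda", local_rank)
model = model.to(device)
model = DDP(model, device_ids=[local_rank])
# model = torch.compile(model)  # Optional.
```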
Set up the Analyzer and Fit Factors.
Initialize the `Analyzer` and run `fit_all_factors` to compute all factors that aim to approximate the Hessian
(Gauss-Newton Hessian). The computed factors will be stored on disk.
from kronfluence.analyzer import Analyzer
from kronfluence.utils.dataset import DataLoaderKwargs
...
analyzer = Analyzer(analysis_name="model_with_seed0", model=model, task=task)
# [Optional] Set up the parameters for the DataLoader.
dataloader_kwargs = DataLoaderKwargs(num_workers=4, pin_memory=True)
analyzer.set_dataloader_kwargs(dataloader_kwargs)
# Compute all factors.
analyzer.fit_all_factors(factors_name="initial_factor", dataset=train_dataset)
...
Compute Influence Scores. Once the factors have been computed, you can compute pairwise and self-influence scores. When computing the scores, you can specify the factor name you would like to use.
...
scores = analyzer.compute_pairwise_scores(
scores_name="initial_score",
factors_name="initial_factor",
query_dataset=query_dataset,
train_dataset=train_dataset,
per_device_query_batch_size=1024,
)
...
You can organize all factors and scores for the specific model with `factors_name` and `scores_name`.
What should I do if my model does not have any nn.Linear or nn.Conv2d modules?
Currently, the implementation does not support influence computations for modules other than `nn.Linear` or `nn.Conv2d`.
Try rewriting the model so that it uses supported modules (as done for the `conv1d` module in the GPT-2 example).
Alternatively, you can create a subclass of `TrackedModule` to compute influence scores for your custom module.
If there are specific modules you would like to see supported, please submit an issue.
How should I write task.get_influence_tracked_modules?
We recommend using all supported modules for influence computations. However, if you would like to compute influence scores
on a subset of the modules (e.g., only on the MLP layers of a transformer, or only on the last layer),
inspect `model.named_modules()` to determine which modules to use. You can specify the list of module names you want to analyze.
Tip
`Analyzer.get_module_summary(model)` can be helpful in figuring out what modules to include.
Note
If the embedding layers for transformers are defined with `nn.Linear`, you must write your own
`task.get_influence_tracked_modules` to avoid influence computations on the embedding matrices (see the sketch after this note).
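As an illustration, a task that restricts influence computations to a transformer's MLP layers might be sketched as follows; the module names are hypothetical and should be read off your own `model.named_modules()`:

```python
from typing import List, Optional

from kronfluence.task import Task


class LanguageModelingTask(Task):
    # compute_train_loss and compute_measurement omitted for brevity.

    def get_influence_tracked_modules(self) -> Optional[List[str]]:
        # Hypothetical GPT-2-style module names; inspect model.named_modules() for the real ones.
        tracked = []
        for layer_idx in range(12):
            tracked.append(f"transformer.h.{layer_idx}.mlp.c_fc")
            tracked.append(f"transformer.h.{layer_idx}.mlp.c_proj")
        return tracked
```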
How should I implement Task.compute_train_loss? Implement the loss function used to train the model. Note that the function should return the summed loss (over the batch and, for sequence data, over tokens), not the mean loss.
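For example, for a causal language model whose batches contain `input_ids`, `attention_mask`, and `labels` (these key names are assumptions about your data pipeline), a summed training loss might be sketched as below; the `sample=True` branch, which samples labels from the model's predictions, is omitted here:

```python
from typing import Any

import torch
import torch.nn.functional as F
from torch import nn


def compute_train_loss(self, batch: Any, model: nn.Module, sample: bool = False) -> torch.Tensor:
    # Intended as a Task method; assumes a HuggingFace-style causal LM batch.
    logits = model(batch["input_ids"], attention_mask=batch["attention_mask"]).logits
    shift_logits = logits[..., :-1, :].contiguous()
    shift_labels = batch["labels"][..., 1:].contiguous()
    # reduction="sum" so the loss is summed over the batch and over tokens.
    return F.cross_entropy(
        shift_logits.view(-1, shift_logits.size(-1)),
        shift_labels.view(-1),
        ignore_index=-100,
        reduction="sum",
    )
```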
How should I implement Task.compute_measurement? It depends on the analysis you would like to perform. Influence functions approximate the effect of downweighting/upweighting a training data point on the query's measurable quantity. You can use the loss, margin (for classification), or conditional log-likelihood (for language modeling). Note that many influence function implementations use the loss by default.
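For instance, a margin-based measurement for classification (one common choice) could be sketched as follows, again assuming an `(inputs, labels)` batch:

```python
from typing import Any

import torch
from torch import nn


def compute_measurement(self, batch: Any, model: nn.Module) -> torch.Tensor:
    # Intended as a Task method. Margin: correct-class logit minus log-sum-exp of the rest.
    inputs, labels = batch
    logits = model(inputs)
    bindex = torch.arange(logits.shape[0], device=logits.device)
    logits_correct = logits[bindex, labels]
    cloned_logits = logits.clone()
    cloned_logits[bindex, labels] = float("-inf")
    margins = logits_correct - cloned_logits.logsumexp(dim=-1)
    return -margins.sum()
```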
I encounter TrackedModuleNotFoundError when using DDP or FSDP.
Make sure to call `prepare_model` before wrapping your model with DDP or FSDP. Calling `prepare_model`
on DDP modules can cause `TrackedModuleNotFoundError`.
My model uses supported modules, but influence scores are not computed. Kronfluence uses module hooks to compute factors and influence scores. For these to be tracked and computed, the model's forward pass should directly call the module.
import torch
from torch import nn
...
self.linear = nn.Linear(8, 1, bias=True)
...
def forward(self, x: torch.Tensor) -> torch.Tensor:
    x = self.linear(x)  # This works 😊
    x = self.linear.weight @ x + self.linear.bias  # This does not work 😞
Warning
The default arguments assume the module is used only once during the forward pass.
If your model shares parameters (e.g., the module is used in multiple places during the forward pass), set
`has_shared_parameters=True` in `FactorArguments`.
Why are there so many arguments?
Kronfluence was originally developed to compute influence scores on large-scale models, which is why `FactorArguments`
and `ScoreArguments` have many parameters to support these use cases. However, for most standard applications, the default argument values
should suffice. Feel free to use the default settings unless you have specific requirements that necessitate customization.
I get X error when fitting factors/computing scores. Please feel free to contact me by filing an issue or through email.
import torch
from kronfluence.arguments import FactorArguments
factor_args = FactorArguments(
strategy="ekfac", # Choose from "identity", "diagonal", "kfac", or "ekfac".
use_empirical_fisher=False,
amp_dtype=None,
amp_scale=2.0**16,
has_shared_parameters=False,
# Settings for covariance matrix fitting.
covariance_max_examples=100_000,
covariance_data_partitions=1,
covariance_module_partitions=1,
activation_covariance_dtype=torch.float32,
gradient_covariance_dtype=torch.float32,
# Settings for Eigendecomposition.
eigendecomposition_dtype=torch.float64,
# Settings for Lambda matrix fitting.
lambda_max_examples=100_000,
lambda_data_partitions=1,
lambda_module_partitions=1,
use_iterative_lambda_aggregation=False,
offload_activations_to_cpu=False,
per_sample_gradient_dtype=torch.float32,
lambda_dtype=torch.float32,
)
# You can pass in the arguments when fitting the factors.
analyzer.fit_all_factors(factors_name="initial_factor", dataset=train_dataset, factor_args=factor_args)
You can change:
- `strategy`: Selects the Hessian approximation strategy (`identity`, `diagonal`, `kfac`, or `ekfac`).
- `use_empirical_fisher`: Determines whether to use the empirical Fisher (using the actual labels from the batch) instead of the true Fisher (using labels sampled from the model's predictions). It is recommended to keep this `False`.
- `amp_dtype`: Selects the dtype for automatic mixed precision (AMP). Disables AMP if set to `None`.
- `amp_scale`: Sets the scale factor for automatic mixed precision (AMP).
- `has_shared_parameters`: Specifies whether shared parameters exist in the forward pass.
`kfac` and `ekfac` require computing the uncentered activation and pre-activation pseudo-gradient covariance matrices.
To fit covariance matrices, you can use `analyzer.fit_covariance_matrices`.
# Fitting covariance matrices.
analyzer.fit_covariance_matrices(factors_name="initial_factor", dataset=train_dataset, factor_args=factor_args)
# Loading covariance matrices.
covariance_matrices = analyzer.load_covariance_matrices(factors_name="initial_factor")
This step corresponds to Equation 16 in the paper. You can tune:
- `covariance_max_examples`: Controls the maximum number of data points for fitting covariance matrices. When set to `None`, Kronfluence computes covariance matrices for all data points.
- `covariance_data_partitions`: Number of data partitions to use for computing covariance matrices. For example, when `covariance_data_partitions=2`, the dataset is split into 2 chunks and covariance matrices are computed separately for each chunk. These chunked covariance matrices are later aggregated. This is useful with GPU preemption, as intermediate covariance matrices will be saved to disk. It is also helpful when using low precision.
- `covariance_module_partitions`: Number of module partitions to use for computing covariance matrices. For example, when `covariance_module_partitions=2`, the modules are split into 2 chunks and covariance matrices are computed separately for each chunk. This is useful when the available GPU memory is limited (e.g., the total covariance matrices cannot fit into GPU memory). However, this will require multiple iterations over the dataset and can be slow.
- `activation_covariance_dtype`: `dtype` for computing activation covariance matrices. You can also use `torch.bfloat16` or `torch.float16`.
- `gradient_covariance_dtype`: `dtype` for computing pre-activation pseudo-gradient covariance matrices. You can also use `torch.bfloat16` or `torch.float16`.
Dealing with OOMs. Here are some steps to fix Out of Memory (OOM) errors.
- Try reducing the `per_device_batch_size` when fitting covariance matrices.
- Try using lower precision for `activation_covariance_dtype` and `gradient_covariance_dtype`.
- Try setting `covariance_module_partitions > 1`.
After computing the covariance matrices, `kfac` and `ekfac` require performing eigendecomposition.
# Performing eigendecomposition.
analyzer.perform_eigendecomposition(factors_name="initial_factor", factor_args=factor_args)
# Loading eigendecomposition results (e.g., eigenvectors and eigenvalues).
eigen_factors = analyzer.load_eigendecomposition(factors_name="initial_factor")
This corresponds to Equation 18 in the paper. You can tune:
- `eigendecomposition_dtype`: `dtype` for performing eigendecomposition. You can also use `torch.float32`, but `torch.float64` is strongly recommended.
`ekfac` and `diagonal` require computing the Lambda (corrected-eigenvalue) matrices for all modules.
# Fitting Lambda matrices.
analyzer.fit_lambda_matrices(factors_name="initial_factor", dataset=train_dataset, factor_args=factor_args)
# Loading Lambda matrices.
lambda_matrices = analyzer.load_lambda_matrices(factors_name="initial_factor")
This corresponds to Equation 20 in the paper. You can tune:
- `lambda_max_examples`: Controls the maximum number of data points for fitting Lambda matrices.
- `lambda_data_partitions`: Number of data partitions to use for computing Lambda matrices.
- `lambda_module_partitions`: Number of module partitions to use for computing Lambda matrices.
- `offload_activations_to_cpu`: Computing the per-sample-gradient requires saving the intermediate activations in memory. You can set `offload_activations_to_cpu=True` to cache these activations on the CPU. This is helpful for dealing with OOMs, but will make the overall computation slower.
- `use_iterative_lambda_aggregation`: Whether to compute the Lambda matrices with for-loops instead of batched matrix multiplications. This is helpful for reducing peak GPU memory, as it avoids holding multiple copies of tensors with the same shape as the per-sample-gradient.
- `per_sample_gradient_dtype`: `dtype` for computing the per-sample-gradient. You can also use `torch.bfloat16` or `torch.float16`.
- `lambda_dtype`: `dtype` for computing Lambda matrices. You can also use `torch.bfloat16` or `torch.float16`.
Dealing with OOMs. Here are some steps to fix Out of Memory (OOM) errors.
- Try reducing the `per_device_batch_size` when fitting Lambda matrices.
- Try setting `use_iterative_lambda_aggregation=True` or `offload_activations_to_cpu=True`. (Try out `use_iterative_lambda_aggregation=True` first.)
- Try using lower precision for `per_sample_gradient_dtype` and `lambda_dtype`.
- Try using `lambda_module_partitions > 1`.
I get different factors each time I run the code.
This is expected, as we sample labels from the model's predictions when computing covariance and Lambda matrices.
Using `use_empirical_fisher=True` could make the process more deterministic. Moreover, different hardware might compute
different eigenvectors when performing eigendecomposition.
How should I select the batch size? You can use the largest possible batch size that avoids OOM errors. Typically, the batch size for fitting Lambda matrices should be smaller than that used for fitting covariance matrices. Furthermore, note that you should get similar results regardless of what batch size you use (unlike when training neural networks).
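Assuming the fitting methods accept a `per_device_batch_size` keyword (the parameter referenced in the OOM tips above), choosing the batch sizes explicitly might look like:

```python
# Illustrative choice: a larger batch for covariance fitting, a smaller one for Lambda fitting.
analyzer.fit_covariance_matrices(
    factors_name="initial_factor",
    dataset=train_dataset,
    per_device_batch_size=512,
    factor_args=factor_args,
)
analyzer.perform_eigendecomposition(factors_name="initial_factor", factor_args=factor_args)
analyzer.fit_lambda_matrices(
    factors_name="initial_factor",
    dataset=train_dataset,
    per_device_batch_size=128,
    factor_args=factor_args,
)
```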
import torch
from kronfluence.arguments import ScoreArguments
score_args = ScoreArguments(
damping_factor=1e-08,
amp_dtype=None,
offload_activations_to_cpu=False,
# More functionalities to compute influence scores.
data_partitions=1,
module_partitions=1,
compute_per_module_scores=False,
compute_per_token_scores=False,
use_measurement_for_self_influence=False,
aggregate_query_gradients=False,
aggregate_train_gradients=False,
# Configuration for query batching.
query_gradient_low_rank=None,
use_full_svd=False,
query_gradient_svd_dtype=torch.float32,
query_gradient_accumulation_steps=1,
# Configuration for dtype.
score_dtype=torch.float32,
per_sample_gradient_dtype=torch.float32,
precondition_dtype=torch.float32,
)
You can change:
- `damping_factor`: A damping factor for the damped inverse Hessian-vector product (iHVP). If `None`, uses a heuristic based on the mean eigenvalues (0.1 x mean eigenvalues), as done in this paper.
- `amp_dtype`: Selects the dtype for automatic mixed precision (AMP). Disables AMP if set to `None`.
- `offload_activations_to_cpu`: Whether to offload cached activations to the CPU.
- `data_partitions`: Number of data partitions for computing influence scores.
- `module_partitions`: Number of module partitions for computing influence scores.
- `compute_per_module_scores`: Whether to return per-module influence scores. Instead of summing influences across all modules, this keeps track of intermediate module-wise scores.
- `compute_per_token_scores`: Whether to return per-token influence scores. Only applicable to transformer-based models.
- `aggregate_query_gradients`: Whether to use the summed query gradient instead of per-sample query gradients.
- `aggregate_train_gradients`: Whether to use the summed training gradient instead of per-sample training gradients.
- `use_measurement_for_self_influence`: Whether to use the measurement (instead of the loss) when computing self-influence scores.
- `query_gradient_low_rank`: The rank for query batching (a low-rank approximation to the preconditioned query gradient; see Section 3.2.2). If `None`, no query batching will be used. (A sketch of enabling query batching follows this list.)
- `query_gradient_svd_dtype`: `dtype` for performing singular value decomposition (SVD) for query batching. You can also use `torch.float64`.
- `query_gradient_accumulation_steps`: Number of query gradients to accumulate over. For example, with `query_gradient_accumulation_steps=2` and `query_batch_size=16`, a total of 32 query gradients will be stored in memory when computing dot products with training gradients.
- `score_dtype`: `dtype` for computing influence scores. You can use `torch.bfloat16` or `torch.float16`.
- `per_sample_gradient_dtype`: `dtype` for computing the per-sample-gradient. You can use `torch.bfloat16` or `torch.float16`.
- `precondition_dtype`: `dtype` for performing preconditioning. You can use `torch.bfloat16` or `torch.float16`.
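As a concrete illustration of query batching (referenced in the `query_gradient_low_rank` item above), the sketch below enables a rank-64 approximation and accumulates query gradients; the specific values are illustrative, not recommendations:

```python
import torch

from kronfluence.arguments import ScoreArguments

# Enable query batching with a rank-64 approximation of the preconditioned query gradients.
score_args = ScoreArguments(
    query_gradient_low_rank=64,
    query_gradient_accumulation_steps=10,
    score_dtype=torch.bfloat16,
    per_sample_gradient_dtype=torch.bfloat16,
    precondition_dtype=torch.float32,
)

scores = analyzer.compute_pairwise_scores(
    scores_name="pairwise_low_rank",
    factors_name="ekfac",
    query_dataset=query_dataset,
    train_dataset=train_dataset,
    per_device_query_batch_size=16,
    per_device_train_batch_size=128,
    score_args=score_args,
)
```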
To compute pairwise influence scores (Equation 5 in the paper), you can run:
# Computing pairwise influence scores.
analyzer.compute_pairwise_scores(scores_name="pairwise", factors_name="ekfac", score_args=score_args)
# Loading pairwise influence scores.
scores = analyzer.load_pairwise_scores(scores_name="pairwise")
To compute self-influence scores (see Section 5.4 from this paper), you can run:
# Computing self-influence scores.
analyzer.compute_self_scores(scores_name="self", factors_name="ekfac", score_args=score_args)
# Loading self-influence scores.
scores = analyzer.load_self_scores(scores_name="self")
By default, self-influence score computations only use the loss function for gradient calculations.
In this case, the method returns a vector of size `len(train_dataset)`, where each value corresponds
to g_l^T ⋅ H^{-1} ⋅ g_l. Here, g_l denotes the gradient of the loss function with respect to the model parameters,
and H^{-1} represents the inverse Hessian matrix. If you want to use the measurement function instead of the loss function
for self-influence calculations, set `use_measurement_for_self_influence=True`. In this case, each value in the returned
vector will correspond to g_m^T ⋅ H^{-1} ⋅ g_l, where g_m is the gradient of the measurement function with respect to the model parameters.
Dealing with OOMs. Here are some steps to fix Out of Memory (OOM) errors.
- Try reducing the `per_device_query_batch_size` or `per_device_train_batch_size`.
- Try setting `offload_activations_to_cpu=True`.
- Try using lower precision for `per_sample_gradient_dtype` and `score_dtype`.
- Try using lower precision for `precondition_dtype`.
- Try setting `query_gradient_low_rank > 1`. The recommended values are `16`, `32`, `64`, `128`, and `256`. Note that query batching is only supported for computing pairwise influence scores, not self-influence scores.
- Try setting `module_partitions > 1`.
How should I choose a damping term?
When setting the damping term, both `1e-08` and `None` are reasonable choices. The optimal value may depend on your
specific workload. Another heuristic, suggested in this paper, is to use `10 * learning_rate * num_iterations` when the model
was trained using SGD with a momentum of 0.9. In practice, I have observed that the damping term does not significantly
affect the final results as long as it is not too large (e.g., `1e-01`). Feel free to experiment with different values within a
reasonable range to find what works best for your use case.
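For example, to fall back to the mean-eigenvalue heuristic described above instead of picking a fixed value:

```python
from kronfluence.arguments import ScoreArguments

# damping_factor=None uses the heuristic damping term (0.1 x mean eigenvalues).
score_args = ScoreArguments(damping_factor=None)
```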
Influence scores are very large in magnitude. Ideally, influence scores need to be divided by the total number of training data points. However, the code does not normalize the scores. If you would like, you can divide the scores by the total number of data points (or tokens for language modeling) used to train the model.
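If you prefer normalized scores, a minimal sketch (treating `scores` as a pairwise score tensor of shape `[num_queries, num_train]`) is:

```python
# Divide by the number of training data points so scores are on a per-example scale.
normalized_scores = scores / len(train_dataset)
```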
- Studying Large Language Model Generalization with Influence Functions. Roger Grosse, Juhan Bae, Cem Anil, et al. Tech Report, 2023.
- If Influence Functions are the Answer, Then What is the Question?. Juhan Bae, Nathan Ng, Alston Lo, Marzyeh Ghassemi, Roger Grosse. NeurIPS, 2022.
- TRAK: Attributing Model Behavior at Scale. Sung Min Park, Kristian Georgiev, Andrew Ilyas, Guillaume Leclerc, Aleksander Madry. ICML, 2023.
- Understanding Black-box Predictions via Influence Functions. Pang Wei Koh, Percy Liang. ICML, 2017.
- Optimizing Neural Networks with Kronecker-factored Approximate Curvature. James Martens, Roger Grosse. Tech Report, 2015.
- Fast Approximate Natural Gradient Descent in a Kronecker-factored Eigenbasis. Thomas George, César Laurent, Xavier Bouthillier, Nicolas Ballas, Pascal Vincent. NeurIPS, 2018.
- Training Data Attribution via Approximate Unrolled Differentiation. Juhan Bae, Wu Lin, Jonathan Lorraine, Roger Grosse. Preprint, 2024.