diff --git a/src/pydvl/influence/dask/influence_model.py b/src/pydvl/influence/dask/influence_model.py
index 30eacb148..93a232d5a 100644
--- a/src/pydvl/influence/dask/influence_model.py
+++ b/src/pydvl/influence/dask/influence_model.py
@@ -33,7 +33,7 @@ def __init__(
         self.from_numpy = from_numpy
         self.to_numpy = to_numpy
         self._num_parameters = influence_model.num_parameters
-        self.influence_model = influence_model.prepare_for_distributed()
+        self.influence_model = influence_model
         client = self._get_client()
         if client is not None:
             self.influence_model = client.scatter(influence_model, broadcast=True)
@@ -74,7 +74,7 @@ def _get_chunk_indices(

         return tuple(indices)

-    def factors(self, x: da.Array, y: da.Array) -> InverseHvpResult[da.Array]:
+    def factors(self, x: da.Array, y: da.Array) -> da.Array:
         """
         Compute the expression
         $$
@@ -96,13 +96,9 @@
         self._validate_aligned_chunking(x, y)
         self._validate_un_chunked(x)
         self._validate_un_chunked(y)
-        return InverseHvpResult(self._factors_without_info(x, y), {})

-    def _factors_without_info(self, x: da.Array, y: da.Array):
         def func(x_numpy: NDArray, y_numpy: NDArray, model: Influence):
-            factors, _ = model.factors(
-                self.from_numpy(x_numpy), self.from_numpy(y_numpy)
-            )
+            factors = model.factors(self.from_numpy(x_numpy), self.from_numpy(y_numpy))
             return self.to_numpy(factors)

         def block_func(x_block: da.Array, y_block: NDArray):
@@ -229,7 +225,7 @@ def values(
         x: Optional[da.Array] = None,
         y: Optional[da.Array] = None,
         influence_type: InfluenceType = InfluenceType.Up,
-    ) -> InverseHvpResult:
+    ) -> da.Array:
         """
         Compute approximation of
         $$
@@ -277,7 +273,7 @@ def func(
             y_numpy: NDArray,
             model: Influence,
         ):
-            values, _ = model.values(
+            values = model.values(
                 self.from_numpy(x_test_numpy),
                 self.from_numpy(y_test_numpy),
                 self.from_numpy(x_numpy),
@@ -287,7 +283,7 @@ def func(
             return self.to_numpy(values)

         resulting_shape = "ij" if influence_type is InfluenceType.Up else "ijk"
-        result = da.blockwise(
+        return da.blockwise(
             func,
             resulting_shape,
             x_test,
@@ -303,7 +299,6 @@ def func(
             dtype=x.dtype,
             align_arrays=True,
         )
-        return InverseHvpResult(result, {})

     @staticmethod
     def _get_client() -> Optional[distributed.Client]:
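# ---------------------------------------------------------------------------
# Usage sketch: with the InverseHvpResult container gone, DaskInfluence.factors
# and .values hand back plain dask arrays, so callers compute them directly
# instead of unpacking `.x`. The import path and the converter order below
# mirror the tests at the end of this diff; treat both as assumptions.
import dask.array as da
import numpy as np
import torch

from pydvl.influence.dask.influence_model import DaskInfluence  # assumed path

# `influence_model` is any fitted torch-based Influence instance (placeholder).
dask_inf = DaskInfluence(
    influence_model,
    lambda t: t.numpy(),  # torch -> numpy converter, as in the tests
    lambda t: torch.from_numpy(t),  # numpy -> torch converter
)

# Chunk along rows only, keeping feature dimensions un-chunked.
da_x = da.from_array(np.random.rand(8, 3), chunks=(4, 3))
da_y = da.from_array(np.random.rand(8, 1), chunks=(4, 1))

factors = dask_inf.factors(da_x, da_y)  # now a plain da.Array
result = factors.compute()  # previously: factors.x.compute()
# ---------------------------------------------------------------------------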
diff --git a/src/pydvl/influence/general.py b/src/pydvl/influence/general.py
index 2d0bdb118..1dafec185 100644
--- a/src/pydvl/influence/general.py
+++ b/src/pydvl/influence/general.py
@@ -84,15 +84,7 @@ def factors_gen() -> Generator[TensorType, None, None]:
         ):
             yield influence.factors(x_test, y_test)

-    info_dict = {}
-    tensor_list = []
-    for k, factors in enumerate(factors_gen()):
-        info_dict[k] = factors.info
-        tensor_list.append(factors.x)
-
-    values = cat(tensor_list)
-
-    return InverseHvpResult(values, info_dict)
+    return cat(list(factors_gen()))


 def compute_influences_up(
@@ -237,7 +229,7 @@ def compute_influences(
     hessian_regularization: float = 0.0,
     progress: bool = False,
     **kwargs: Any,
-) -> InverseHvpResult:  # type: ignore # ToDO fix typing
+) -> TensorType:  # type: ignore # ToDO fix typing
     r"""
     Calculates the influence of each input data point on the specified test points.

@@ -299,7 +291,7 @@ def values_gen() -> Generator[TensorType, None, None]:
         for x, y in maybe_progress(
             input_data, progress, desc="Batch input influence values"
         ):
-            yield influence_function(factors.x, x, y)
+            yield influence_function(factors, x, y)

     tensor_util: Type[TensorUtilities] = TensorUtilities.from_twice_differentiable(
         differentiable_model
@@ -307,4 +299,4 @@ def values_gen() -> Generator[TensorType, None, None]:
     cat = tensor_util.cat
     values = cat(list(values_gen()), dim=1)
-    return InverseHvpResult(values, factors.info)
+    return values
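# ---------------------------------------------------------------------------
# Usage sketch: compute_influences now returns the values tensor itself rather
# than a (values, info) pair. Import paths are assumptions based on the file
# locations in this diff; further kwargs (e.g. selecting the inversion method)
# are omitted for brevity.
import torch
from torch.utils.data import DataLoader, TensorDataset

from pydvl.influence.general import compute_influences  # path per this diff
from pydvl.influence.torch import TorchTwiceDifferentiable  # assumed path

model = torch.nn.Linear(3, 1)
loss = torch.nn.functional.mse_loss
train_loader = DataLoader(TensorDataset(torch.rand(32, 3), torch.rand(32, 1)), batch_size=8)
test_loader = DataLoader(TensorDataset(torch.rand(8, 3), torch.rand(8, 1)), batch_size=8)

influence_values = compute_influences(  # previously: influence_values, _ = ...
    TorchTwiceDifferentiable(model, loss),
    training_data=train_loader,
    test_data=test_loader,
    hessian_regularization=0.1,
)
# ---------------------------------------------------------------------------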
diff --git a/src/pydvl/influence/torch/influence_model.py b/src/pydvl/influence/torch/influence_model.py
index 4f9c2dc24..2d5d8b95f 100644
--- a/src/pydvl/influence/torch/influence_model.py
+++ b/src/pydvl/influence/torch/influence_model.py
@@ -79,7 +79,7 @@ def values(
         x: Optional[torch.Tensor] = None,
         y: Optional[torch.Tensor] = None,
         influence_type: InfluenceType = InfluenceType.Up,
-    ) -> InverseHvpResult:
+    ) -> torch.Tensor:

         if x is None:
@@ -116,19 +116,19 @@ def _non_symmetric_values(

         if influence_type is InfluenceType.Up:
             if x_test.shape[0] <= y.shape[0]:
-                factor, info = self.factors(x_test, y_test)
+                factor = self.factors(x_test, y_test)
                 values = self.up_weighting(factor, x, y)
             else:
-                factor, info = self.factors(x, y)
+                factor = self.factors(x, y)
                 values = self.up_weighting(factor, x_test, y_test)
         else:
-            factor, info = self.factors(x_test, y_test)
+            factor = self.factors(x_test, y_test)
             values = self.perturbation(factor, x, y)
-        return InverseHvpResult(values, info)
+        return values

     def _symmetric_values(
         self, x: torch.Tensor, y: torch.Tensor, influence_type: InfluenceType
-    ) -> InverseHvpResult[torch.Tensor]:
+    ) -> torch.Tensor:

         grad = self._loss_grad(x, y)
-        fac, info = self._solve_hvp(grad)
+        fac = self._solve_hvp(grad)
@@ -138,7 +138,7 @@ def _symmetric_values(
         else:
             values = self.perturbation(fac, x, y)

-        return InverseHvpResult(values, info)
+        return values

     def up_weighting(
         self,
@@ -167,14 +167,14 @@ def perturbation(
         ),
     )

-    def factors(self, x: torch.Tensor, y: torch.Tensor) -> InverseHvpResult:
+    def factors(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
         return self._solve_hvp(
             self._loss_grad(x.to(self.model_device), y.to(self.model_device))
         )

     @abstractmethod
-    def _solve_hvp(self, rhs: torch.Tensor) -> InverseHvpResult[torch.Tensor]:
+    def _solve_hvp(self, rhs: torch.Tensor) -> torch.Tensor:
         pass

@@ -195,7 +195,6 @@ def __init__(
         hessian_regularization: float,
         hessian: torch.Tensor = None,
         train_dataloader: DataLoader = None,
-        return_hessian_in_info: bool = False,
     ):
         if hessian is None and train_dataloader is None:
             raise ValueError(
@@ -203,7 +202,6 @@ def __init__(
             )

         super().__init__(model, loss)
-        self.return_hessian_in_info = return_hessian_in_info
         self.hessian_perturbation = hessian_regularization
         self.hessian = (
             hessian
@@ -211,26 +209,13 @@ def __init__(
             else get_hessian(model, loss, train_dataloader)
         )

-    def prepare_for_distributed(self) -> "Influence":
-        if self.return_hessian_in_info:
-            self.return_hessian_in_info = False
-            logger.warning(
-                f"Modified parameter `return_hessian_in_info` to `False`, "
-                f"to prepare for distributed computing"
-            )
-        return self
-
-    def _solve_hvp(self, rhs: torch.Tensor) -> InverseHvpResult:
-        result = torch.linalg.solve(
+    def _solve_hvp(self, rhs: torch.Tensor) -> torch.Tensor:
+        return torch.linalg.solve(
             self.hessian.to(self.model_device)
             + self.hessian_perturbation
             * torch.eye(self.num_parameters, device=self.model_device),
             rhs.T.to(self.model_device),
         ).T
-        info = {}
-        if self.return_hessian_in_info:
-            info["hessian"] = self.hessian
-        return InverseHvpResult(result, info)

     def to(self, device: torch.device):
         self.hessian = self.hessian.to(device)
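# ---------------------------------------------------------------------------
# Usage sketch: with `return_hessian_in_info` removed, the Hessian is read off
# the instance instead of an info dict. The class name `DirectInfluence` and
# its import path are assumptions; the constructor arguments follow the
# signature shown above.
import torch
from torch.utils.data import DataLoader, TensorDataset

from pydvl.influence.torch.influence_model import DirectInfluence  # assumed

model = torch.nn.Linear(3, 1)
loss = torch.nn.functional.mse_loss
train_loader = DataLoader(TensorDataset(torch.rand(32, 3), torch.rand(32, 1)), batch_size=8)

direct = DirectInfluence(
    model,
    loss,
    hessian_regularization=0.1,
    train_dataloader=train_loader,  # alternatively pass a precomputed hessian=
)

factors = direct.factors(torch.rand(8, 3), torch.rand(8, 1))  # plain torch.Tensor
hessian = direct.hessian  # previously only available via info["hessian"]
# ---------------------------------------------------------------------------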
@@ -266,7 +251,7 @@ def __init__(
         self.hessian_regularization = hessian_regularization
         self.train_dataloader = train_dataloader

-    def _solve_hvp(self, rhs: torch.Tensor) -> InverseHvpResult:
+    def _solve_hvp(self, rhs: torch.Tensor) -> torch.Tensor:

         if len(self.train_dataloader) == 0:
             raise ValueError("Training dataloader must not be empty.")
@@ -287,8 +272,7 @@ def _solve_hvp(self, rhs: torch.Tensor) -> InverseHvpResult:
                 maxiter=self.maxiter,
             )
             batch_cg[idx] = batch_result
-            info[f"batch_{idx}"] = batch_info
-        return InverseHvpResult(x=batch_cg, info=info)
+        return batch_cg

     def to(self, device: torch.device):
         self.model = self.model.to(device)
@@ -350,7 +334,7 @@ def __init__(
         self.dampen = dampen
         self.train_dataloader = train_dataloader

-    def _solve_hvp(self, rhs: torch.Tensor) -> InverseHvpResult:
+    def _solve_hvp(self, rhs: torch.Tensor) -> torch.Tensor:

         h_estimate = self.h0 if self.h0 is not None else torch.clone(rhs)
@@ -401,11 +385,7 @@ def lissa_step(
             f"Terminated Lissa with {max_residual*100:.2f} % max residual."
             f" Mean residual: {mean_residual*100:.5f} %"
         )
-        info = {
-            "max_perc_residual": max_residual * 100,
-            "mean_perc_residual": mean_residual * 100,
-        }
-        return InverseHvpResult(x=h_estimate / self.scale, info=info)
+        return h_estimate / self.scale


 class ArnoldiInfluence(TorchInfluence):
@@ -457,7 +437,6 @@ def __init__(
         tol: float = 1e-6,
         max_iter: Optional[int] = None,
         eigen_computation_on_gpu: bool = False,
-        return_low_rank_representation_in_info: bool = False,
     ):
         if low_rank_representation is None and train_dataloader is None:
             raise ValueError(
@@ -479,20 +458,8 @@ def __init__(

         super().__init__(model, loss)
         self.low_rank_representation = low_rank_representation.to(self.model_device)
-        self.return_low_rank_representation_in_info = (
-            return_low_rank_representation_in_info
-        )
         self.hessian_regularization = hessian_regularization

-    def prepare_for_distributed(self) -> "Influence":
-        if self.return_low_rank_representation_in_info:
-            self.return_low_rank_representation_in_info = False
-            logger.warning(
-                f"Modified parameter `return_low_rank_representation_in_info` to `False`, "
-                f"to prepare for distributed computing"
-            )
-        return self
-
     def _non_symmetric_values(
         self,
         x_test: torch.Tensor,
@@ -500,7 +467,7 @@ def _non_symmetric_values(
         x: Optional[torch.Tensor] = None,
         y: Optional[torch.Tensor] = None,
         influence_type: InfluenceType = InfluenceType.Up,
-    ) -> InverseHvpResult[torch.Tensor]:
+    ) -> torch.Tensor:

         if influence_type is InfluenceType.Up:
             mjp = matrix_jacobian_product(
@@ -517,16 +484,14 @@ def _non_symmetric_values(
             )
             values = torch.einsum("ij, ik -> jk", left, right)
         else:
-            factors, _ = self.factors(x_test, y_test)
+            factors = self.factors(x_test, y_test)
             values = self.perturbation(factors, x, y)
-        info = {}
-        if self.return_low_rank_representation_in_info:
-            info["low_rank_representation"] = self.low_rank_representation
-        return InverseHvpResult(values, info)
+
+        return values

     def _symmetric_values(
         self, x: torch.Tensor, y: torch.Tensor, influence_type: InfluenceType
-    ) -> InverseHvpResult[torch.Tensor]:
+    ) -> torch.Tensor:

         if influence_type is InfluenceType.Up:
             left = matrix_jacobian_product(
@@ -537,14 +502,12 @@ def _symmetric_values(
             )
             values = torch.einsum("ij, ik -> jk", left, right)
         else:
-            factors, _ = self.factors(x, y)
+            factors = self.factors(x, y)
             values = self.perturbation(factors, x, y)
-        info = {}
-        if self.return_low_rank_representation_in_info:
-            info["low_rank_representation"] = self.low_rank_representation
-        return InverseHvpResult(values, info)

-    def _solve_hvp(self, rhs: torch.Tensor) -> InverseHvpResult:
+        return values
+
+    def _solve_hvp(self, rhs: torch.Tensor) -> torch.Tensor:
         rhs_device = rhs.device if hasattr(rhs, "device") else torch.device("cpu")
         if rhs_device.type != self.low_rank_representation.device.type:
             raise RuntimeError(
@@ -564,15 +527,7 @@ def _solve_hvp(self, rhs: torch.Tensor) -> InverseHvpResult:
             @ (self.low_rank_representation.projections.t() @ rhs.t())
         )

-        if self.return_low_rank_representation_in_info:
-            info = {
-                "eigenvalues": self.low_rank_representation.eigen_vals,
-                "eigenvectors": self.low_rank_representation.projections,
-            }
-        else:
-            info = {}
-
-        return InverseHvpResult(x=result.t(), info=info)
+        return result.t()

     def to(self, device: torch.device):
         return ArnoldiInfluence(
@@ -592,7 +547,6 @@ def direct_factory(
         twice_differentiable.loss,
         train_dataloader=data_loader,
         hessian_regularization=hessian_regularization,
-        return_hessian_in_info=True,
         **kwargs,
     )
@@ -641,6 +595,5 @@ def arnoldi_factory(
         twice_differentiable.loss,
         train_dataloader=data_loader,
         hessian_regularization=hessian_regularization,
-        return_low_rank_representation_in_info=True,
         **kwargs,
     )
diff --git a/src/pydvl/influence/twice_differentiable.py b/src/pydvl/influence/twice_differentiable.py
index 47feda91d..a510b806c 100644
--- a/src/pydvl/influence/twice_differentiable.py
+++ b/src/pydvl/influence/twice_differentiable.py
@@ -291,12 +291,6 @@ class Influence(Generic[TensorType], ABC):
     def num_parameters(self):
         """Number of trainable parameters of the underlying model"""

-    def prepare_for_distributed(self) -> "Influence":
-        """Overwrite this method, in case the instance has to be modified, before being distributed.
-        Must be called explicitly.
-        """
-        return self
-
     @abstractmethod
     def up_weighting(
         self,
@@ -344,7 +338,7 @@ def perturbation(
         """

     @abstractmethod
-    def factors(self, x: TensorType, y: TensorType) -> InverseHvpResult[TensorType]:
+    def factors(self, x: TensorType, y: TensorType) -> TensorType:
         r"""
         Overwrite this method to implement the approximation of
         $$
@@ -357,9 +351,7 @@ def factors(self, x: TensorType, y: TensorType) -> InverseHvpResult[TensorType]:
             y: label tensor to compute gradients

         Returns:
-            Container object of type [InverseHvpResult][pydvl.influence.twice_differentiable.InverseHvpResult] with a
-            tensor representing the element-wise inverse Hessian matrix vector products for the provided batch and
-            an optional info structure about the inversion process.
+            Tensor representing the element-wise inverse Hessian matrix vector products
         """

@@ -371,7 +363,7 @@ def values(
         x: Optional[TensorType] = None,
         y: Optional[TensorType] = None,
         influence_type: InfluenceType = InfluenceType.Up,
-    ) -> InverseHvpResult[TensorType]:
+    ) -> TensorType:
         r"""
         Overwrite this method to implement the approximation of
         $$
@@ -392,8 +384,6 @@ def values(
             influence_type: enum value of [InfluenceType][pydvl.influence.twice_differentiable.InfluenceType]

         Returns:
-            Container object of type [InverseHvpResult][pydvl.influence.twice_differentiable.InverseHvpResult] with a
-            tensor representing the element-wise scalar products for the provided batch and
-            an optional info structure about the inversion process.
+            Tensor representing the element-wise scalar products for the provided batch
         """
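# ---------------------------------------------------------------------------
# Implementation sketch: the abstract contract is now tensor-in, tensor-out. A
# custom TorchInfluence subclass only has to return the (approximate)
# inverse-Hessian-vector product from _solve_hvp; the identity "inverse
# Hessian" below is purely illustrative, and the import path is an assumption.
import torch

from pydvl.influence.torch.influence_model import TorchInfluence  # assumed

class IdentityInfluence(TorchInfluence):
    """Pretends H = I, to illustrate the simplified _solve_hvp contract."""

    def _solve_hvp(self, rhs: torch.Tensor) -> torch.Tensor:
        # No InverseHvpResult wrapper: return H^{-1} @ rhs directly. Anything
        # previously stuffed into `info` (residuals, eigenvalues, the Hessian)
        # now belongs on the instance or in log output.
        return rhs

model = torch.nn.Linear(3, 1)
loss = torch.nn.functional.mse_loss
ident = IdentityInfluence(model, loss)
vals = ident.factors(torch.rand(4, 3), torch.rand(4, 1))  # plain torch.Tensor
# ---------------------------------------------------------------------------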
diff --git a/tests/influence/dask/test_influence_model.py b/tests/influence/dask/test_influence_model.py
index 647a5ee77..ede8be9c2 100644
--- a/tests/influence/dask/test_influence_model.py
+++ b/tests/influence/dask/test_influence_model.py
@@ -70,8 +70,8 @@ def test_dask_influence_factors(influence_model):
         influence_model, lambda t: t.numpy(), lambda t: torch.from_numpy(t)
     )
     dask_fac = dask_inf.factors(da_x, da_y)
-    dask_fac = dask_fac.x.compute(scheduler="processes")
-    torch_fac = influence_model.factors(t_x, t_y).x.numpy()
+    dask_fac = dask_fac.compute(scheduler="processes")
+    torch_fac = influence_model.factors(t_x, t_y).numpy()
     assert np.allclose(dask_fac, torch_fac, atol=1e-4)

@@ -86,8 +86,8 @@ def test_dask_influence_values(influence_model, influence_type):
     dask_fac = dask_inf.values(
         da_x_test, da_y_test, da_x, da_y, influence_type=influence_type
     )
-    dask_fac = dask_fac.x.compute(scheduler="processes")
+    dask_fac = dask_fac.compute(scheduler="processes")
     torch_fac = influence_model.values(
         t_x_test, t_y_test, t_x, t_y, influence_type=influence_type
-    ).x.numpy()
+    ).numpy()
     assert np.allclose(dask_fac, torch_fac, atol=1e-4)
diff --git a/tests/influence/test_influences.py b/tests/influence/test_influences.py
index fa5734066..6e728e19e 100644
--- a/tests/influence/test_influences.py
+++ b/tests/influence/test_influences.py
@@ -179,7 +179,7 @@ def model_and_data(
 @fixture
 def direct_influence(model_and_data, test_case: TestCase):
     model, train_dataloader, test_dataloader = model_and_data
-    direct_influence, _ = compute_influences(
+    direct_influence = compute_influences(
         model,
         training_data=train_dataloader,
         test_data=test_dataloader,
@@ -249,7 +249,7 @@ def test_influence_linear_model(
         batch_size=40,
     )

-    influence_values, _ = compute_influences(
+    influence_values = compute_influences(
         TorchTwiceDifferentiable(linear_layer, loss),
         training_data=train_data_loader,
         test_data=test_data_loader,
@@ -295,7 +295,7 @@ def test_influences_nn(
 ):
     model, train_dataloader, test_dataloader = model_and_data

-    approx_influences, _ = compute_influences(
+    approx_influences = compute_influences(
         model,
         training_data=train_dataloader,
         test_data=test_dataloader,
@@ -355,7 +355,7 @@ def test_influences_arnoldi(

     num_parameters = sum(p.numel() for p in model.model.parameters() if p.requires_grad)

-    low_rank_influence, _ = compute_influences(
+    low_rank_influence = compute_influences(
         model,
         training_data=train_dataloader,
         test_data=test_dataloader,
@@ -379,7 +379,7 @@ def test_influences_arnoldi(
         rank_estimate=num_parameters - 1,
     )

-    precomputed_low_rank_influence, _ = compute_influences(
+    precomputed_low_rank_influence = compute_influences(
         model,
         training_data=train_dataloader,
         test_data=test_dataloader,