Refactor Influence class:
* return tensors instead of InverseHvpResult
* remove prepare_for_distributed from interface
schroedk committed Nov 23, 2023
1 parent 9b60a38 commit f546192
Showing 6 changed files with 48 additions and 118 deletions.
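
For users of the Influence interface, the practical effect is that factors() and values() now return the tensor directly rather than an InverseHvpResult that had to be unpacked. A minimal before/after sketch (illustrative only; constructing the influence model is assumed and not part of this diff):

    # before this commit: results came wrapped in InverseHvpResult
    factors, info = influence_model.factors(x_test, y_test)

    # after this commit: the tensor is returned directly
    factors = influence_model.factors(x_test, y_test)

The info dictionaries (Hessian, low-rank representation, residuals) are gone from the return path, which is also why prepare_for_distributed() is no longer needed before scattering a model to Dask workers.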
17 changes: 6 additions & 11 deletions src/pydvl/influence/dask/influence_model.py
@@ -33,7 +33,7 @@ def __init__(
         self.from_numpy = from_numpy
         self.to_numpy = to_numpy
         self._num_parameters = influence_model.num_parameters
-        self.influence_model = influence_model.prepare_for_distributed()
+        self.influence_model = influence_model
         client = self._get_client()
         if client is not None:
             self.influence_model = client.scatter(influence_model, broadcast=True)
@@ -74,7 +74,7 @@ def _get_chunk_indices(
 
         return tuple(indices)
 
-    def factors(self, x: da.Array, y: da.Array) -> InverseHvpResult[da.Array]:
+    def factors(self, x: da.Array, y: da.Array) -> da.Array:
         """
         Compute the expression
         $$
@@ -96,13 +96,9 @@ def factors(self, x: da.Array, y: da.Array) -> InverseHvpResult[da.Array]:
         self._validate_aligned_chunking(x, y)
         self._validate_un_chunked(x)
         self._validate_un_chunked(y)
-        return InverseHvpResult(self._factors_without_info(x, y), {})
-
-    def _factors_without_info(self, x: da.Array, y: da.Array):
         def func(x_numpy: NDArray, y_numpy: NDArray, model: Influence):
-            factors, _ = model.factors(
-                self.from_numpy(x_numpy), self.from_numpy(y_numpy)
-            )
+            factors = model.factors(self.from_numpy(x_numpy), self.from_numpy(y_numpy))
             return self.to_numpy(factors)
 
         def block_func(x_block: da.Array, y_block: NDArray):
@@ -229,7 +225,7 @@ def values(
         x: Optional[da.Array] = None,
         y: Optional[da.Array] = None,
         influence_type: InfluenceType = InfluenceType.Up,
-    ) -> InverseHvpResult:
+    ) -> da.Array:
         """
         Compute approximation of
         $$
@@ -277,7 +273,7 @@ def func(
             y_numpy: NDArray,
             model: Influence,
         ):
-            values, _ = model.values(
+            values = model.values(
                 self.from_numpy(x_test_numpy),
                 self.from_numpy(y_test_numpy),
                 self.from_numpy(x_numpy),
@@ -287,7 +283,7 @@ def func(
             return self.to_numpy(values)
 
         resulting_shape = "ij" if influence_type is InfluenceType.Up else "ijk"
-        result = da.blockwise(
+        return da.blockwise(
             func,
             resulting_shape,
             x_test,
@@ -303,7 +299,6 @@ def func(
             dtype=x.dtype,
             align_arrays=True,
         )
-        return InverseHvpResult(result, {})
 
     @staticmethod
     def _get_client() -> Optional[distributed.Client]:
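
The wrapper above relies on dask.array.blockwise to apply the influence model chunk by chunk. A self-contained sketch of that dispatch pattern, with a toy per-block function standing in for the scattered model (the shapes and the block_op name are assumptions for illustration, not pydvl code):

    import dask.array as da
    import numpy as np

    def block_op(x_block: np.ndarray, y_block: np.ndarray) -> np.ndarray:
        # stands in for: to_numpy(model.factors(from_numpy(x_block), from_numpy(y_block)))
        return x_block * y_block

    x = da.ones((1000, 10), chunks=(250, 10))
    y = da.ones((1000, 10), chunks=(250, 10))
    # one call per aligned chunk pair; matching indices "ij" keep the block structure
    result = da.blockwise(block_op, "ij", x, "ij", y, "ij", dtype=x.dtype)
    print(result.compute().shape)  # (1000, 10)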
16 changes: 4 additions & 12 deletions src/pydvl/influence/general.py
@@ -84,15 +84,7 @@ def factors_gen() -> Generator[TensorType, None, None]:
         ):
             yield influence.factors(x_test, y_test)
 
-    info_dict = {}
-    tensor_list = []
-    for k, factors in enumerate(factors_gen()):
-        info_dict[k] = factors.info
-        tensor_list.append(factors.x)
-
-    values = cat(tensor_list)
-
-    return InverseHvpResult(values, info_dict)
+    return cat(list(factors_gen()))
 
 
 def compute_influences_up(
@@ -237,7 +229,7 @@ def compute_influences(
     hessian_regularization: float = 0.0,
     progress: bool = False,
     **kwargs: Any,
-) -> InverseHvpResult:  # type: ignore # ToDO fix typing
+) -> TensorType:  # type: ignore # ToDO fix typing
     r"""
     Calculates the influence of each input data point on the specified test points.
@@ -299,12 +291,12 @@ def values_gen() -> Generator[TensorType, None, None]:
         for x, y in maybe_progress(
             input_data, progress, desc="Batch input influence values"
         ):
-            yield influence_function(factors.x, x, y)
+            yield influence_function(factors, x, y)
 
     tensor_util: Type[TensorUtilities] = TensorUtilities.from_twice_differentiable(
         differentiable_model
     )
     cat = tensor_util.cat
     values = cat(list(values_gen()), dim=1)
 
-    return InverseHvpResult(values, factors.info)
+    return values
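
The helpers in general.py now stream batches through a generator and concatenate plain tensors, with no info bookkeeping. A toy sketch of that pattern (torch.cat stands in for the backend-resolved cat from TensorUtilities; the batch arithmetic is a placeholder):

    import torch

    def factors_gen():
        for batch in torch.ones(10, 4).split(3):  # stands in for batches of test data
            yield batch * 2.0                     # stands in for influence.factors(...)

    values = torch.cat(list(factors_gen()))
    print(values.shape)  # torch.Size([10, 4])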
97 changes: 25 additions & 72 deletions src/pydvl/influence/torch/influence_model.py
@@ -79,7 +79,7 @@ def values(
         x: Optional[torch.Tensor] = None,
         y: Optional[torch.Tensor] = None,
         influence_type: InfluenceType = InfluenceType.Up,
-    ) -> InverseHvpResult:
+    ) -> torch.Tensor:
 
         if x is None:
 
@@ -116,19 +116,19 @@ def _non_symmetric_values(
 
         if influence_type is InfluenceType.Up:
             if x_test.shape[0] <= y.shape[0]:
-                factor, info = self.factors(x_test, y_test)
+                factor = self.factors(x_test, y_test)
                 values = self.up_weighting(factor, x, y)
             else:
-                factor, info = self.factors(x, y)
+                factor = self.factors(x, y)
                 values = self.up_weighting(factor, x_test, y_test)
         else:
-            factor, info = self.factors(x_test, y_test)
+            factor = self.factors(x_test, y_test)
             values = self.perturbation(factor, x, y)
-        return InverseHvpResult(values, info)
+        return values
 
     def _symmetric_values(
         self, x: torch.Tensor, y: torch.Tensor, influence_type: InfluenceType
-    ) -> InverseHvpResult[torch.Tensor]:
+    ) -> torch.Tensor:
 
         grad = self._loss_grad(x, y)
         fac, info = self._solve_hvp(grad)
@@ -138,7 +138,7 @@ def _symmetric_values(
         else:
             values = self.perturbation(fac, x, y)
 
-        return InverseHvpResult(values, info)
+        return values
 
     def up_weighting(
         self,
@@ -167,14 +167,14 @@ def perturbation(
             ),
         )
 
-    def factors(self, x: torch.Tensor, y: torch.Tensor) -> InverseHvpResult:
+    def factors(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
 
         return self._solve_hvp(
             self._loss_grad(x.to(self.model_device), y.to(self.model_device))
         )
 
     @abstractmethod
-    def _solve_hvp(self, rhs: torch.Tensor) -> InverseHvpResult[torch.Tensor]:
+    def _solve_hvp(self, rhs: torch.Tensor) -> torch.Tensor:
         pass
 
 
@@ -195,42 +195,27 @@ def __init__(
         hessian_regularization: float,
         hessian: torch.Tensor = None,
         train_dataloader: DataLoader = None,
-        return_hessian_in_info: bool = False,
     ):
         if hessian is None and train_dataloader is None:
             raise ValueError(
                 f"Either provide a pre-computed hessian or a data_loader to compute the hessian"
             )
 
         super().__init__(model, loss)
-        self.return_hessian_in_info = return_hessian_in_info
         self.hessian_perturbation = hessian_regularization
         self.hessian = (
             hessian
             if hessian is not None
             else get_hessian(model, loss, train_dataloader)
         )
 
-    def prepare_for_distributed(self) -> "Influence":
-        if self.return_hessian_in_info:
-            self.return_hessian_in_info = False
-            logger.warning(
-                f"Modified parameter `return_hessian_in_info` to `False`, "
-                f"to prepare for distributed computing"
-            )
-        return self
-
-    def _solve_hvp(self, rhs: torch.Tensor) -> InverseHvpResult:
-        result = torch.linalg.solve(
+    def _solve_hvp(self, rhs: torch.Tensor) -> torch.Tensor:
+        return torch.linalg.solve(
             self.hessian.to(self.model_device)
             + self.hessian_perturbation
             * torch.eye(self.num_parameters, device=self.model_device),
             rhs.T.to(self.model_device),
         ).T
-        info = {}
-        if self.return_hessian_in_info:
-            info["hessian"] = self.hessian
-        return InverseHvpResult(result, info)
 
     def to(self, device: torch.device):
         self.hessian = self.hessian.to(device)
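
DirectInfluence's _solve_hvp is now a single regularized linear solve. A self-contained sketch of the same computation with toy dimensions (the random SPD Hessian is an assumption for illustration, not pydvl code):

    import torch

    torch.manual_seed(0)
    n_params, n_points = 5, 3
    a = torch.randn(n_params, n_params)
    hessian = a @ a.T + n_params * torch.eye(n_params)  # SPD stand-in for the model Hessian
    regularization = 0.1
    rhs = torch.randn(n_points, n_params)  # one loss gradient per row

    # solve (H + reg*I) x = g for every right-hand side at once
    factors = torch.linalg.solve(
        hessian + regularization * torch.eye(n_params), rhs.T
    ).T
    print(factors.shape)  # torch.Size([3, 5])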
@@ -266,7 +251,7 @@ def __init__(
         self.hessian_regularization = hessian_regularization
         self.train_dataloader = train_dataloader
 
-    def _solve_hvp(self, rhs: torch.Tensor) -> InverseHvpResult:
+    def _solve_hvp(self, rhs: torch.Tensor) -> torch.Tensor:
         if len(self.train_dataloader) == 0:
             raise ValueError("Training dataloader must not be empty.")
 
@@ -287,8 +272,7 @@ def _solve_hvp(self, rhs: torch.Tensor) -> InverseHvpResult:
                 maxiter=self.maxiter,
             )
             batch_cg[idx] = batch_result
-            info[f"batch_{idx}"] = batch_info
-        return InverseHvpResult(x=batch_cg, info=info)
+        return batch_cg
 
     def to(self, device: torch.device):
         self.model = self.model.to(device)
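
Each batch result above comes from a conjugate-gradient solve of (H + reg*I) x = b. A minimal dense CG sketch (toy matrix; pydvl's solver works matrix-free through Hessian-vector products, so this illustrates the iteration only):

    import torch

    def cg(matmul, b, maxiter=100, rtol=1e-6):
        # standard conjugate gradient for a symmetric positive definite operator
        x = torch.zeros_like(b)
        r = b - matmul(x)
        p = r.clone()
        rs_old = r @ r
        for _ in range(maxiter):
            hp = matmul(p)
            alpha = rs_old / (p @ hp)
            x = x + alpha * p
            r = r - alpha * hp
            rs_new = r @ r
            if torch.sqrt(rs_new) < rtol * torch.linalg.norm(b):
                break
            p = r + (rs_new / rs_old) * p
            rs_old = rs_new
        return x

    torch.manual_seed(0)
    a = torch.randn(5, 5)
    h = a @ a.T + 5 * torch.eye(5)  # SPD stand-in for H + reg*I
    b = torch.randn(5)
    x = cg(lambda v: h @ v, b)
    print(torch.allclose(h @ x, b, atol=1e-4))  # True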
@@ -350,7 +334,7 @@ def __init__(
         self.dampen = dampen
         self.train_dataloader = train_dataloader
 
-    def _solve_hvp(self, rhs: torch.Tensor) -> InverseHvpResult:
+    def _solve_hvp(self, rhs: torch.Tensor) -> torch.Tensor:
 
         h_estimate = self.h0 if self.h0 is not None else torch.clone(rhs)
 
@@ -401,11 +385,7 @@ def lissa_step(
             f"Terminated Lissa with {max_residual*100:.2f} % max residual."
             f" Mean residual: {mean_residual*100:.5f} %"
         )
-        info = {
-            "max_perc_residual": max_residual * 100,
-            "mean_perc_residual": mean_residual * 100,
-        }
-        return InverseHvpResult(x=h_estimate / self.scale, info=info)
+        return h_estimate / self.scale
 
 
 class ArnoldiInfluence(TorchInfluence):
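
LiSSA estimates H^{-1} rhs with a fixed-point iteration of the form h <- rhs + ((1 - dampen)I - H/scale) h, whose limit divided by scale is the solve; the division in the return above is exactly that final step. A dense toy sketch under those assumptions (the explicit Hessian is for illustration; pydvl uses stochastic Hessian-vector products over dataloader batches):

    import torch

    torch.manual_seed(0)
    a = torch.randn(5, 5)
    h_mat = a @ a.T + 5 * torch.eye(5)  # SPD stand-in for the Hessian
    rhs = torch.randn(5)
    scale, dampen = 50.0, 0.0

    h_estimate = rhs.clone()
    for _ in range(2000):
        h_estimate = rhs + (1 - dampen) * h_estimate - (h_mat @ h_estimate) / scale

    lissa_solution = h_estimate / scale
    print(torch.allclose(lissa_solution, torch.linalg.solve(h_mat, rhs), atol=1e-4))  # True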
@@ -457,7 +437,6 @@ def __init__(
         tol: float = 1e-6,
         max_iter: Optional[int] = None,
         eigen_computation_on_gpu: bool = False,
-        return_low_rank_representation_in_info: bool = False,
     ):
         if low_rank_representation is None and train_dataloader is None:
             raise ValueError(
@@ -479,28 +458,16 @@ def __init__(
 
         super().__init__(model, loss)
         self.low_rank_representation = low_rank_representation.to(self.model_device)
-        self.return_low_rank_representation_in_info = (
-            return_low_rank_representation_in_info
-        )
         self.hessian_regularization = hessian_regularization
 
-    def prepare_for_distributed(self) -> "Influence":
-        if self.return_low_rank_representation_in_info:
-            self.return_low_rank_representation_in_info = False
-            logger.warning(
-                f"Modified parameter `return_low_rank_representation_in_info` to `False`, "
-                f"to prepare for distributed computing"
-            )
-        return self
-
     def _non_symmetric_values(
         self,
         x_test: torch.Tensor,
         y_test: Optional[torch.Tensor] = None,
         x: Optional[torch.Tensor] = None,
         y: Optional[torch.Tensor] = None,
         influence_type: InfluenceType = InfluenceType.Up,
-    ) -> InverseHvpResult[torch.Tensor]:
+    ) -> torch.Tensor:
 
         if influence_type is InfluenceType.Up:
             mjp = matrix_jacobian_product(
@@ -517,16 +484,14 @@ def _non_symmetric_values(
             )
             values = torch.einsum("ij, ik -> jk", left, right)
         else:
-            factors, _ = self.factors(x_test, y_test)
+            factors = self.factors(x_test, y_test)
             values = self.perturbation(factors, x, y)
-        info = {}
-        if self.return_low_rank_representation_in_info:
-            info["low_rank_representation"] = self.low_rank_representation
-        return InverseHvpResult(values, info)
+
+        return values
 
     def _symmetric_values(
         self, x: torch.Tensor, y: torch.Tensor, influence_type: InfluenceType
-    ) -> InverseHvpResult[torch.Tensor]:
+    ) -> torch.Tensor:
 
         if influence_type is InfluenceType.Up:
             left = matrix_jacobian_product(
@@ -537,14 +502,12 @@ def _symmetric_values(
             )
             values = torch.einsum("ij, ik -> jk", left, right)
         else:
-            factors, _ = self.factors(x, y)
+            factors = self.factors(x, y)
             values = self.perturbation(factors, x, y)
-        info = {}
-        if self.return_low_rank_representation_in_info:
-            info["low_rank_representation"] = self.low_rank_representation
-        return InverseHvpResult(values, info)
+
+        return values
 
-    def _solve_hvp(self, rhs: torch.Tensor) -> InverseHvpResult:
+    def _solve_hvp(self, rhs: torch.Tensor) -> torch.Tensor:
         rhs_device = rhs.device if hasattr(rhs, "device") else torch.device("cpu")
         if rhs_device.type != self.low_rank_representation.device.type:
             raise RuntimeError(
@@ -564,15 +527,7 @@ def _solve_hvp(self, rhs: torch.Tensor) -> InverseHvpResult:
             @ (self.low_rank_representation.projections.t() @ rhs.t())
         )
 
-        if self.return_low_rank_representation_in_info:
-            info = {
-                "eigenvalues": self.low_rank_representation.eigen_vals,
-                "eigenvectors": self.low_rank_representation.projections,
-            }
-        else:
-            info = {}
-
-        return InverseHvpResult(x=result.t(), info=info)
+        return result.t()
 
     def to(self, device: torch.device):
         return ArnoldiInfluence(
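
With the Arnoldi low-rank approximation H ≈ V diag(e) V^T, the regularized inverse HVP reduces to V diag(1/(e + reg)) V^T rhs^T, which is what the solve above computes. A toy sketch (random orthonormal V and synthetic eigenvalues are assumptions; in pydvl they come from low_rank_representation.projections and .eigen_vals):

    import torch

    torch.manual_seed(0)
    n_params, rank, n_points = 8, 4, 3
    v, _ = torch.linalg.qr(torch.randn(n_params, rank))  # orthonormal columns, stand-in for projections
    eigen_vals = torch.linspace(1.0, 4.0, rank)          # stand-in for eigen_vals
    regularization = 0.1
    rhs = torch.randn(n_points, n_params)

    inverse_regularized = 1.0 / (eigen_vals + regularization)
    result = v @ torch.diag(inverse_regularized) @ (v.T @ rhs.T)
    factors = result.T
    print(factors.shape)  # torch.Size([3, 8])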
@@ -592,7 +547,6 @@ def direct_factory(
         twice_differentiable.loss,
         train_dataloader=data_loader,
         hessian_regularization=hessian_regularization,
-        return_hessian_in_info=True,
         **kwargs,
     )

Expand Down Expand Up @@ -641,6 +595,5 @@ def arnoldi_factory(
twice_differentiable.loss,
train_dataloader=data_loader,
hessian_regularization=hessian_regularization,
return_low_rank_representation_in_info=True,
**kwargs,
)