
Commit 099930e

Rename parameter use_block_cg -> solve_simultaneously to not get confused
with block-diagonal approximation

schroedk committed Jun 20, 2024
1 parent 66e4ffe commit 099930e
Showing 6 changed files with 35 additions and 30 deletions.
3 changes: 2 additions & 1 deletion CHANGELOG.md
@@ -58,7 +58,8 @@
 [PR #598](https://github.com/aai-institute/pyDVL/pull/598)
 - Renaming of parameters of `CgInfluence`,
   `hessian_regularization` -> `regularization` (modify type annotation),
-  `pre_conditioner` -> `preconditioner`
+  `pre_conditioner` -> `preconditioner`,
+  `use_block_cg` -> `solve_simultaneously`
   [PR #601](https://github.com/aai-institute/pyDVL/pull/601)
 - Remove parameter `x0` from `CgInfluence`
   [PR #601](https://github.com/aai-institute/pyDVL/pull/601)
38 changes: 23 additions & 15 deletions docs/influence/influence_function_model.md
@@ -23,37 +23,45 @@ gradient method, defined in [@ji_breakdownfree_2017], which solves several
 right hand sides simultaneously.
 
 Optionally, the user can provide a pre-conditioner to improve convergence, such
-as a [Jacobi pre-conditioner
-][pydvl.influence.torch.pre_conditioner.JacobiPreConditioner], which
+as a [Jacobi preconditioner
+][pydvl.influence.torch.preconditioner.JacobiPreconditioner], which
 is a simple [diagonal pre-conditioner](
 https://en.wikipedia.org/wiki/Preconditioner#Jacobi_(or_diagonal)_preconditioner)
 based on Hutchinson's diagonal estimator [@bekas_estimator_2007],
-or a [Nyström approximation based pre-conditioner
-][pydvl.influence.torch.pre_conditioner.NystroemPreConditioner],
+or a [Nyström approximation based preconditioner
+][pydvl.influence.torch.preconditioner.NystroemPreconditioner],
 described in [@frangella_randomized_2023].
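The idea behind the Jacobi preconditioner can be sketched in a few lines: estimate the diagonal of the (implicit) matrix with Hutchinson's estimator, then rescale residuals by its reciprocal. Below is an illustrative NumPy sketch, not pyDVL's implementation; the helper name `hutchinson_diagonal` is made up for the example.

```python
import numpy as np

rng = np.random.default_rng(0)

def hutchinson_diagonal(matvec, dim, n_samples=500):
    """Estimate diag(A) using only matrix-vector products:
    diag(A) is approximately the mean of v * (A v) over Rademacher v."""
    est = np.zeros(dim)
    for _ in range(n_samples):
        v = rng.choice([-1.0, 1.0], size=dim)
        est += v * matvec(v)
    return est / n_samples

A = np.array([[4.0, 0.1, 0.0],
              [0.1, 2.0, 0.1],
              [0.0, 0.1, 0.5]])
d = hutchinson_diagonal(lambda v: A @ v, dim=3)

# A Jacobi (diagonal) preconditioner simply rescales by the reciprocal diagonal:
precondition = lambda r: r / d
```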

 ```python
-from pydvl.influence.torch import CgInfluence
+from pydvl.influence.torch import CgInfluence, BlockMode, SecondOrderMode
+from pydvl.influence.torch.preconditioner import NystroemPreconditioner
 
 if_model = CgInfluence(
     model,
     loss,
-    hessian_regularization=0.0,
+    regularization=0.0,
     rtol=1e-7,
     atol=1e-7,
     maxiter=None,
-    use_block_cg=True,
-    pre_conditioner=NystroemPreconditioner(rank=10)
+    solve_simultaneously=True,
+    preconditioner=NystroemPreconditioner(rank=10),
+    block_structure=BlockMode.FULL,
+    second_order_mode=SecondOrderMode.HESSIAN
 )
 if_model.fit(train_loader)
 ```

-The additional optional parameters `rtol`, `atol`, `maxiter`, `use_block_cg` and
-`pre_conditioner` are respectively, the relative
+The additional optional parameters `rtol`, `atol`, `maxiter`,
+`solve_simultaneously` and `preconditioner` are, respectively, the relative
 tolerance, the absolute tolerance, the maximum number of iterations,
-a flag indicating whether to use block variant of cg and an optional
-pre-conditioner.
+a flag indicating whether to use a variant of cg that solves the system
+for several right hand sides simultaneously, and an optional preconditioner.
+
+This implementation can use a block-diagonal approximation, see
+[Block-diagonal approximation](#block-diagonal-approximation), and can handle
+the [Gauss-Newton approximation](#gauss-newton-approximation).
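To illustrate what `solve_simultaneously` refers to: instead of running conjugate gradient once per right hand side, the iterations can be carried out on all right hand sides at once. The following is a minimal NumPy sketch of such a batched scheme; the actual implementation follows the block method of [@ji_breakdownfree_2017], and the function name here is hypothetical.

```python
import numpy as np

def cg_multi_rhs(A, B, tol=1e-10, maxiter=100):
    """Solve A X = B for all columns of B simultaneously by running
    standard CG updates on every right hand side in a single loop."""
    X = np.zeros_like(B)
    R = B - A @ X                      # one residual column per RHS
    P = R.copy()
    rs = (R * R).sum(axis=0)
    for _ in range(maxiter):
        if np.all(np.sqrt(rs) < tol):  # all systems converged
            break
        AP = A @ P
        alpha = rs / (P * AP).sum(axis=0)
        X = X + P * alpha              # broadcast: one step size per column
        R = R - AP * alpha
        rs_new = (R * R).sum(axis=0)
        P = R + P * (rs_new / rs)
        rs = rs_new
    return X

A = np.array([[4.0, 1.0, 0.0],
              [1.0, 3.0, 1.0],
              [0.0, 1.0, 2.0]])       # symmetric positive definite
B = np.eye(3)                         # three right hand sides at once
X = cg_multi_rhs(A, B)
```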


### Linear time Stochastic Second-Order Approximation (LiSSA)
@@ -78,7 +86,7 @@ from pydvl.influence.torch import LissaInfluence, BlockMode, SecondOrderMode
 if_model = LissaInfluence(
     model,
     loss,
-    regularization=0.0
+    regularization=0.0,
     maxiter=1000,
     dampen=0.0,
     scale=10.0,
@@ -114,7 +122,7 @@ the Hessian and \(V\) contains the corresponding eigenvectors. See also
[@schioppa_scaling_2022].
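The way such a factorization is used downstream can be sketched directly: once \(V\) and the eigenvalues are in hand, an inverse-Hessian-vector product reduces to a projection, a diagonal rescaling, and a back-projection. A small NumPy illustration follows, using an exact eigendecomposition of a toy matrix rather than an Arnoldi iteration, with a made-up `reg` term:

```python
import numpy as np

H = np.array([[3.0, 1.0, 0.0],
              [1.0, 2.0, 0.5],
              [0.0, 0.5, 1.0]])       # toy symmetric "Hessian"
eigvals, eigvecs = np.linalg.eigh(H)   # ascending eigenvalues
k = 2                                  # keep the k largest eigenpairs
lam, V = eigvals[-k:], eigvecs[:, -k:]
reg = 0.1                              # illustrative regularization term

def ihvp(v):
    """Approximate (H + reg*I)^{-1} v within the span of V:
    project onto V, rescale by 1/(lambda + reg), map back."""
    return V @ ((V.T @ v) / (lam + reg))

x = ihvp(np.array([1.0, 0.0, 0.0]))
```

Inside the retained eigenspace this is exact; components of \(v\) orthogonal to \(V\) are simply dropped.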

 ```python
-from pydvl.influence.torch import ArnoldiInfluence
+from pydvl.influence.torch import ArnoldiInfluence, BlockMode, SecondOrderMode
 if_model = ArnoldiInfluence(
     model,
     loss,
@@ -207,7 +215,7 @@ see also [@hataya_nystrom_2023] and [@frangella_randomized_2023]. The essential
parameter is the rank of the approximation.

 ```python
-from pydvl.influence.torch import NystroemSketchInfluence
+from pydvl.influence.torch import NystroemSketchInfluence, BlockMode, SecondOrderMode
 if_model = NystroemSketchInfluence(
     model,
     loss,
2 changes: 1 addition & 1 deletion notebooks/influence_wine.ipynb
@@ -750,7 +750,7 @@
 " F.cross_entropy,\n",
 " regularization=0.1,\n",
 " progress=True,\n",
-" use_block_cg=True,\n",
+" solve_simultaneously=True,\n",
 " preconditioner=NystroemPreconditioner(rank=5),\n",
 ")\n",
 "cg_influence_model = cg_influence_model.fit(training_data_loader)\n",
16 changes: 6 additions & 10 deletions src/pydvl/influence/torch/influence_function_model.py
@@ -458,14 +458,10 @@ class CgInfluence(TorchComposableInfluence[CgOperator]):
     maxiter: Maximum number of iterations. If None, defaults to 10*len(b).
     progress: If True, display progress bars for computing in the non-block mode
         (use_block_cg=False).
-    precompute_grad: If True, the full data gradient is precomputed and kept
-        in memory, which can speed up the hessian vector product computation.
-        Set this to False, if you can't afford to keep the full computation graph
-        in memory.
-    pre_conditioner: Optional pre-conditioner to improve convergence of conjugate
+    preconditioner: Optional preconditioner to improve convergence of conjugate
         gradient method
-    use_block_cg: If True, use block variant of conjugate gradient method, which
-        solves several right hand sides simultaneously
+    solve_simultaneously: If True, use a variant of the conjugate gradient method
+        that solves for several right hand sides simultaneously.
     warn_on_max_iteration: If True, logs a warning, if the desired tolerance is not
         achieved within `maxiter` iterations. If False, the log level for this
         information is `logging.DEBUG`
@@ -485,15 +481,15 @@ def __init__(
         progress: bool = False,
         precompute_grad: bool = False,
         preconditioner: Optional[Preconditioner] = None,
-        use_block_cg: bool = False,
+        solve_simultaneously: bool = False,
         warn_on_max_iteration: bool = True,
         block_structure: Union[BlockMode, OrderedDict[str, List[str]]] = BlockMode.FULL,
         second_order_mode: SecondOrderMode = SecondOrderMode.HESSIAN,
     ):
         super().__init__(model, block_structure, regularization)
         self.loss = loss
         self.warn_on_max_iteration = warn_on_max_iteration
-        self.use_block_cg = use_block_cg
+        self.solve_simultaneously = solve_simultaneously
         self.preconditioner = preconditioner
         self.precompute_grad = precompute_grad
         self.progress = progress
@@ -547,7 +543,7 @@ def _create_block(
             maxiter=self.maxiter,
             progress=self.progress,
             preconditioner=preconditioner,
-            use_block_cg=self.use_block_cg,
+            use_block_cg=self.solve_simultaneously,
             warn_on_max_iteration=self.warn_on_max_iteration,
         )
         gp = TorchGradientProvider(self.model, self.loss, restrict_to=block_params)
2 changes: 1 addition & 1 deletion tests/influence/test_influence_calculator.py
@@ -71,7 +71,7 @@ def influence_model(model_and_data, test_case, influence_factory):
             model,
             loss,
             hessian_reg,
-            use_block_cg=True,
+            solve_simultaneously=True,
         ).fit(train_dataLoader),
         lambda model, loss, train_dataLoader, hessian_reg: DirectInfluence(
             model, loss, hessian_reg
4 changes: 2 additions & 2 deletions tests/influence/torch/test_influence_model.py
@@ -398,7 +398,7 @@ def direct_influences_from_factors(
             loss,
             regularization=hessian_reg,
             preconditioner=NystroemPreconditioner(10),
-            use_block_cg=True,
+            solve_simultaneously=True,
         ).fit(train_dataLoader),
         1e-4,
     ],
@@ -776,7 +776,7 @@ def test_influences_cg(
         test_case.hessian_reg,
         maxiter=5,
         preconditioner=preconditioner,
-        use_block_cg=use_block_cg,
+        solve_simultaneously=use_block_cg,
     )
influence_model = influence_model.fit(train_dataloader)

