[checkpoint] Clean up selective activation checkpoint and make public (pytorch#125795)

Related doc: https://docs.google.com/document/d/1BKyizkZPdri9mHqdDOLAUpkI7SbbKfLHRFVVpK9ZWqo/edit

Memory considerations:
- As with the existing SAC, cached values are cleared upon first use.
- We error if the user tries to backward a second time through a region that was forwarded with SAC enabled (a minimal sketch follows this list).
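A minimal sketch of that behavior, using the public names added in this PR (`CheckpointPolicy`, `create_selective_checkpoint_contexts`); the policy, `fn`, and tensor shapes are illustrative only:

```python
import torch
from torch.utils.checkpoint import (
    checkpoint,
    CheckpointPolicy,
    create_selective_checkpoint_contexts,
)

def policy_fn(ctx, op, *args, **kwargs):
    # Save matmuls; prefer to recompute everything else.
    if op == torch.ops.aten.mm.default:
        return CheckpointPolicy.MUST_SAVE
    return CheckpointPolicy.PREFER_RECOMPUTE

def fn(x, w):
    return torch.sigmoid(torch.mm(x, w)).sum()

x = torch.randn(4, 4, requires_grad=True)
w = torch.randn(4, 4, requires_grad=True)

out = checkpoint(
    fn, x, w,
    use_reentrant=False,
    context_fn=lambda: create_selective_checkpoint_contexts(policy_fn),
)
out.backward(retain_graph=True)  # first backward uses (and clears) the SAC cache
# out.backward()  # a second backward through this region is expected to error
```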

In-place:
- We use version counting to detect whether any cached tensor has been mutated, and error if it has. In-place operations that do not mutate cached tensors are allowed.
- `allow_cache_entry_mutation=True` can be passed to disable this check (useful in the case of auto AC where the user cleverly also saves the output of the in-place op); see the sketch below.
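A hedged sketch of the `allow_cache_entry_mutation=True` escape hatch; the `save_everything` policy and the `sin`/`relu_` toy function are illustrative, the flag's behavior is as described in the bullet above:

```python
import torch
from torch.utils.checkpoint import (
    checkpoint,
    CheckpointPolicy,
    create_selective_checkpoint_contexts,
)

def save_everything(ctx, op, *args, **kwargs):
    # Also saves the output of the in-place op, as in the auto-AC case above.
    return CheckpointPolicy.MUST_SAVE

def fn(x):
    y = x.sin()
    y.relu_()  # mutates a tensor that the policy also cached
    return y.sum()

x = torch.randn(8, requires_grad=True)
out = checkpoint(
    fn, x,
    use_reentrant=False,
    context_fn=lambda: create_selective_checkpoint_contexts(
        save_everything,
        allow_cache_entry_mutation=True,  # without this, the mutation check errors
    ),
)
out.backward()
```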

Randomness, views
- Currently in this PR, we don't do anything special for randomness or views; the author of the policy function is expected to handle them properly. (Would it be beneficial to error? We either want to save all random tensors or recompute all of them.) An illustrative policy sketch follows below.
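One way a policy author might handle randomness is to force-save the outputs of random ops so forward and recompute agree; the particular op list below is an assumption for illustration, not an exhaustive set:

```python
import torch
from torch.utils.checkpoint import CheckpointPolicy, create_selective_checkpoint_contexts

# Illustrative, non-exhaustive list of random ops to always save.
_RANDOM_OPS = {
    torch.ops.aten.rand.default,
    torch.ops.aten.randn.default,
    torch.ops.aten.native_dropout.default,
}

def random_safe_policy(ctx, op, *args, **kwargs):
    if op in _RANDOM_OPS:
        return CheckpointPolicy.MUST_SAVE  # never replay RNG during recompute
    return CheckpointPolicy.PREFER_RECOMPUTE

context_fn = lambda: create_selective_checkpoint_contexts(random_safe_policy)
```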

Tensor object preservation
- We guarantee that if a tensor does not require grad and it is saved, then what you get out is the same tensor object. If the tensor does require grad, we must detach to avoid creating a reference cycle. This is a nice guarantee for nested tensors, which care about the object identity of the offsets tensor.

Policy function
- Enum values are `{MUST,PREFER}_{SAVE,RECOMPUTE}` (bikeshed welcome). The alternative was `{SAVE,RECOMPUTE}_{NON_,}OVERRIDABLE`. The former was preferred because it is clearer that two clashing `MUST` policies should error, whereas it is ambiguous whether two stacked `NON_OVERRIDABLE` policies should silently ignore each other or error.
- On how the Enum is used today: there is actually NO API to stack SAC policies yet. In the near term, the Enum only matters for the compiler. Stacking SAC policies would be useful if someone wants to implement something like simple FSDP, but it is not perfect: with a policy of `PREFER_SAVE` you actually save more than autograd would normally save (this would be fixed by AC v3).
- The number of times we call the policy_fn is documented as part of the public API. We call the policy function for all ops except detach, because detach is itself called a different number of times by AC between forward and recompute.
- The policy function can be a stateful object (we do NOT make separate copies of this object for forward/recompute; the user is expected to handle that via `is_recompute`, see below).
- The policy function's signature takes a ctx object as its first argument. The ctx object encapsulates info that may be useful to the user; it currently only holds `is_recompute`. Adding this indirection gives us flexibility to add more attributes later if necessary. A combined sketch of the policy API follows this list.
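A combined sketch of the policy API described above: a stateful policy object keys its state on `ctx.is_recompute` so forward and recompute make the same decisions, and is passed via `functools.partial` as in the tests below. The `EveryOtherMM` class and the toy `fn` are illustrative assumptions:

```python
import functools
import torch
from torch.utils.checkpoint import (
    checkpoint,
    CheckpointPolicy,
    create_selective_checkpoint_contexts,
)

class EveryOtherMM:
    """Stateful policy: save every second mm. The same object is used for both
    forward and recompute, so counters are keyed on ctx.is_recompute to keep
    the two passes consistent."""

    def __init__(self):
        self.counts = {False: 0, True: 0}

    def __call__(self, ctx, op, *args, **kwargs):
        if op != torch.ops.aten.mm.default:
            return CheckpointPolicy.PREFER_RECOMPUTE
        i = self.counts[ctx.is_recompute]
        self.counts[ctx.is_recompute] += 1
        if i % 2 == 0:
            return CheckpointPolicy.MUST_SAVE
        return CheckpointPolicy.PREFER_RECOMPUTE

def fn(x, w):
    return torch.sigmoid(torch.mm(torch.mm(x, w), w)).sum()

x = torch.randn(4, 4, requires_grad=True)
w = torch.randn(4, 4, requires_grad=True)

context_fn = functools.partial(create_selective_checkpoint_contexts, EveryOtherMM())
out = checkpoint(fn, x, w, use_reentrant=False, context_fn=context_fn)
out.backward()
```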

"bc-breaking" for existing users of the private API:
- Existing policy functions must now change their return value to use the Enum.
- Existing calls to `_pt2_selective_checkpoint_context_fn_gen` must be updated to the new name, `create_selective_checkpoint_contexts` (as in the test changes below). The way you use the API remains the same. It would've been nice to do something different (e.g. not require the user to go through functools.partial), but this was the easiest to support under compile (unclear whether that should actually be a constraint). A migration sketch follows below.
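A migration sketch; the `no_recompute_list` policy mirrors the test changes in the diff below, and the old call is shown only in comments:

```python
import functools
import torch
from torch.utils.checkpoint import CheckpointPolicy, create_selective_checkpoint_contexts

# Before (private API):
#   def policy(mode, func, *args, **kwargs):
#       return func in no_recompute_list  # bool
#   context_fn = functools.partial(_pt2_selective_checkpoint_context_fn_gen, policy)

# After (public API): same shape, but return CheckpointPolicy members.
no_recompute_list = [torch.ops.aten.mm.default]

def policy(ctx, func, *args, **kwargs):
    if func in no_recompute_list:
        return CheckpointPolicy.MUST_SAVE        # was: True
    return CheckpointPolicy.PREFER_RECOMPUTE     # was: False

context_fn = functools.partial(create_selective_checkpoint_contexts, policy)
```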

Pull Request resolved: pytorch#125795
Approved by: https://github.com/Chillee, https://github.com/fmassa
soulitzer authored and pytorchmergebot committed Jun 12, 2024
1 parent 25b7537 commit c472cec
Showing 5 changed files with 510 additions and 127 deletions.
3 changes: 3 additions & 0 deletions docs/source/checkpoint.rst
@@ -35,3 +35,6 @@ torch.utils.checkpoint
 .. autofunction:: checkpoint
 .. autofunction:: checkpoint_sequential
 .. autofunction:: set_checkpoint_debug_enabled
+.. autoclass:: CheckpointPolicy
+.. autoclass:: SelectiveCheckpointContext
+.. autofunction:: create_selective_checkpoint_contexts
27 changes: 17 additions & 10 deletions test/dynamo/test_activation_checkpointing.py
@@ -19,7 +19,11 @@
 from torch.testing._internal.common_utils import IS_WINDOWS, skipIfRocm
 from torch.testing._internal.inductor_utils import HAS_CUDA
 from torch.testing._internal.two_tensor import TwoTensor
-from torch.utils.checkpoint import _pt2_selective_checkpoint_context_fn_gen, checkpoint
+from torch.utils.checkpoint import (
+    checkpoint,
+    CheckpointPolicy,
+    create_selective_checkpoint_contexts,
+)
 
 requires_cuda = unittest.skipUnless(HAS_CUDA, "requires cuda")
 requires_distributed = functools.partial(
@@ -105,8 +109,11 @@ def op_count(gm):
 
 
 def _get_custom_policy(no_recompute_list=None):
-    def _custom_policy(mode, func, *args, **kwargs):
-        return func in no_recompute_list
+    def _custom_policy(ctx, func, *args, **kwargs):
+        if func in no_recompute_list:
+            return CheckpointPolicy.MUST_SAVE
+        else:
+            return CheckpointPolicy.PREFER_RECOMPUTE
 
     return _custom_policy
 
@@ -530,7 +537,7 @@ def selective_checkpointing_context_fn():
 no_recompute_list = [
     torch.ops.aten.mm.default,
 ]
-return _pt2_selective_checkpoint_context_fn_gen(
+return create_selective_checkpoint_contexts(
     _get_custom_policy(no_recompute_list=no_recompute_list)
 )
 
@@ -580,7 +587,7 @@ def selective_checkpointing_context_fn():
 no_recompute_list = [
     torch.ops.aten.mm.default,
 ]
-return _pt2_selective_checkpoint_context_fn_gen(
+return create_selective_checkpoint_contexts(
     _get_custom_policy(no_recompute_list=no_recompute_list)
 )
 
@@ -650,7 +657,7 @@ def _custom_policy(mode, func, *args, **kwargs):
 
 def selective_checkpointing_context_fn():
     meta = {}
-    return _pt2_selective_checkpoint_context_fn_gen(_get_custom_policy(meta))
+    return create_selective_checkpoint_contexts(_get_custom_policy(meta))
 
 def gn(x, y):
     return torch.sigmoid(
@@ -698,7 +705,7 @@ def fn(x, y):
 )
 def test_compile_selective_checkpoint_partial_ctx_fn(self):
     def selective_checkpointing_context_fn(no_recompute_list):
-        return _pt2_selective_checkpoint_context_fn_gen(
+        return create_selective_checkpoint_contexts(
             _get_custom_policy(no_recompute_list=no_recompute_list)
         )
 
@@ -751,7 +758,7 @@ def selective_checkpointing_context_fn():
     torch.ops.aten.mm.default,
     torch.ops.aten.sigmoid.default,
 ]
-return _pt2_selective_checkpoint_context_fn_gen(
+return create_selective_checkpoint_contexts(
     _get_custom_policy(no_recompute_list=no_recompute_list),
 )
 
@@ -803,7 +810,7 @@ def selective_checkpointing_context_fn():
     torch.ops.aten.mm.default,
     torch.ops.aten.sigmoid.default,
 ]
-return _pt2_selective_checkpoint_context_fn_gen(
+return create_selective_checkpoint_contexts(
     _get_custom_policy(no_recompute_list=no_recompute_list)
 )
 
@@ -854,7 +861,7 @@ def selective_checkpointing_context_fn():
 no_recompute_list = [
     torch.ops.aten.sigmoid.default,
 ]
-return _pt2_selective_checkpoint_context_fn_gen(
+return create_selective_checkpoint_contexts(
     _get_custom_policy(no_recompute_list=no_recompute_list)
 )
 