diff --git a/docs/source/checkpoint.rst b/docs/source/checkpoint.rst
index f7bc160fa98bd2..8559d8bd73663c 100644
--- a/docs/source/checkpoint.rst
+++ b/docs/source/checkpoint.rst
@@ -35,3 +35,6 @@ torch.utils.checkpoint
 .. autofunction:: checkpoint
 .. autofunction:: checkpoint_sequential
 .. autofunction:: set_checkpoint_debug_enabled
+.. autoclass:: CheckpointPolicy
+.. autoclass:: SelectiveCheckpointContext
+.. autofunction:: create_selective_checkpoint_contexts
diff --git a/test/dynamo/test_activation_checkpointing.py b/test/dynamo/test_activation_checkpointing.py
index 14851e51895b40..274e033028451a 100644
--- a/test/dynamo/test_activation_checkpointing.py
+++ b/test/dynamo/test_activation_checkpointing.py
@@ -19,7 +19,11 @@
 from torch.testing._internal.common_utils import IS_WINDOWS, skipIfRocm
 from torch.testing._internal.inductor_utils import HAS_CUDA
 from torch.testing._internal.two_tensor import TwoTensor
-from torch.utils.checkpoint import _pt2_selective_checkpoint_context_fn_gen, checkpoint
+from torch.utils.checkpoint import (
+    checkpoint,
+    CheckpointPolicy,
+    create_selective_checkpoint_contexts,
+)
 
 requires_cuda = unittest.skipUnless(HAS_CUDA, "requires cuda")
 requires_distributed = functools.partial(
@@ -105,8 +109,11 @@ def op_count(gm):
 
 
 def _get_custom_policy(no_recompute_list=None):
-    def _custom_policy(mode, func, *args, **kwargs):
-        return func in no_recompute_list
+    def _custom_policy(ctx, func, *args, **kwargs):
+        if func in no_recompute_list:
+            return CheckpointPolicy.MUST_SAVE
+        else:
+            return CheckpointPolicy.PREFER_RECOMPUTE
 
     return _custom_policy
 
@@ -530,7 +537,7 @@ def selective_checkpointing_context_fn():
             no_recompute_list = [
                 torch.ops.aten.mm.default,
             ]
-            return _pt2_selective_checkpoint_context_fn_gen(
+            return create_selective_checkpoint_contexts(
                 _get_custom_policy(no_recompute_list=no_recompute_list)
             )
 
@@ -580,7 +587,7 @@ def selective_checkpointing_context_fn():
             no_recompute_list = [
                 torch.ops.aten.mm.default,
             ]
-            return _pt2_selective_checkpoint_context_fn_gen(
+            return create_selective_checkpoint_contexts(
                 _get_custom_policy(no_recompute_list=no_recompute_list)
            )
 
@@ -650,7 +657,7 @@ def _custom_policy(mode, func, *args, **kwargs):
 
         def selective_checkpointing_context_fn():
             meta = {}
-            return _pt2_selective_checkpoint_context_fn_gen(_get_custom_policy(meta))
+            return create_selective_checkpoint_contexts(_get_custom_policy(meta))
 
         def gn(x, y):
             return torch.sigmoid(
@@ -698,7 +705,7 @@ def fn(x, y):
     )
     def test_compile_selective_checkpoint_partial_ctx_fn(self):
         def selective_checkpointing_context_fn(no_recompute_list):
-            return _pt2_selective_checkpoint_context_fn_gen(
+            return create_selective_checkpoint_contexts(
                 _get_custom_policy(no_recompute_list=no_recompute_list)
             )
 
@@ -751,7 +758,7 @@ def selective_checkpointing_context_fn():
                 torch.ops.aten.mm.default,
                 torch.ops.aten.sigmoid.default,
             ]
-            return _pt2_selective_checkpoint_context_fn_gen(
+            return create_selective_checkpoint_contexts(
                 _get_custom_policy(no_recompute_list=no_recompute_list),
             )
 
@@ -803,7 +810,7 @@ def selective_checkpointing_context_fn():
                 torch.ops.aten.mm.default,
                 torch.ops.aten.sigmoid.default,
             ]
-            return _pt2_selective_checkpoint_context_fn_gen(
+            return create_selective_checkpoint_contexts(
                 _get_custom_policy(no_recompute_list=no_recompute_list)
            )
 
@@ -854,7 +861,7 @@ def selective_checkpointing_context_fn():
             no_recompute_list = [
                 torch.ops.aten.sigmoid.default,
             ]
-            return _pt2_selective_checkpoint_context_fn_gen(
+            return create_selective_checkpoint_contexts(
                 _get_custom_policy(no_recompute_list=no_recompute_list)
             )
 
diff --git a/test/test_autograd.py b/test/test_autograd.py
index ce5b4234b8291b..812d05d6303f19 100644
--- a/test/test_autograd.py
+++ b/test/test_autograd.py
@@ -2,6 +2,7 @@
 
 import collections
 import contextlib
+import functools
 import gc
 import io
 import math
@@ -79,8 +80,14 @@
 )
 from torch.utils._mode_utils import no_dispatch
 from torch.utils._python_dispatch import TorchDispatchMode
-from torch.utils.checkpoint import checkpoint, checkpoint_sequential
+from torch.utils.checkpoint import (
+    checkpoint,
+    checkpoint_sequential,
+    CheckpointPolicy,
+    create_selective_checkpoint_contexts,
+)
 from torch.utils.cpp_extension import load_inline
+from torch.utils.flop_counter import FlopCounterMode
 from torch.utils.hooks import RemovableHandle  # noqa: TCH001
 
 
@@ -13187,6 +13194,297 @@ def fn2(x):
         self.assertEqual(counter[0], 1)
 
 
+class TestSelectiveActivationCheckpoint(TestCase):
+    @unittest.skipIf(not TEST_CUDA, "requires CUDA")
+    def test_flops_and_mem(self):
+        # From https://github.com/pytorch/pytorch/pull/126320
+        def get_act_mem(f):
+            out = f()
+            out.backward()
+            # Why do one forward and backward?
+            start_mem = torch.cuda.memory_stats()["requested_bytes.all.current"]
+            out = f()
+            cur_mem = torch.cuda.memory_stats()["requested_bytes.all.current"]
+            act_mem = (cur_mem - start_mem) / (1024 * 1024)
+            out.backward()
+            return act_mem
+
+        def get_bw_flops(f):
+            # Normalized so that a 512 square matmul returns 1
+            f().backward()
+            out = f()
+            # NB: FlopCounterMode is pushed onto the mode stack before CachedMode, so
+            # it will be able to observe whether an op is cached or not.
+            with FlopCounterMode(display=False) as mode:
+                out.backward()
+            return mode.get_total_flops() / (512**3 * 2)
+
+        x = torch.randn(512, 512, requires_grad=True, device="cuda")
+        y = torch.randn(512, 512, requires_grad=True, device="cuda")
+
+        def fn(x, y):
+            return torch.mm(x.cos(), y).sin().sum()
+
+        def fn_ac(x, y):
+            return checkpoint(fn, x, y, use_reentrant=False)
+
+        def fn_sac(x, y):
+            context_fn = functools.partial(
+                create_selective_checkpoint_contexts,
+                [
+                    torch.ops.aten.mm.default,
+                ],
+            )
+            out = checkpoint(fn, x, y, use_reentrant=False, context_fn=context_fn)
+            return out
+
+        def policy_fn(ctx, op, *args, **kwargs):
+            if op == torch.ops.aten.mm.default:
+                return CheckpointPolicy.MUST_SAVE
+            else:
+                return CheckpointPolicy.PREFER_RECOMPUTE
+
+        def fn_sac2(x, y):
+            context_fn = functools.partial(
+                create_selective_checkpoint_contexts,
+                policy_fn,
+            )
+            out = checkpoint(fn, x, y, use_reentrant=False, context_fn=context_fn)
+            return out
+
+        act_mem_noac = get_act_mem(lambda: fn(x, y))
+        bw_flops_noac = get_bw_flops(lambda: fn(x, y))
+
+        self.assertEqual(act_mem_noac, 2.0)
+        self.assertEqual(bw_flops_noac, 2.0)
+
+        act_mem_ac = get_act_mem(lambda: fn_ac(x, y))
+        bw_flops_ac = get_bw_flops(lambda: fn_ac(x, y))
+
+        self.assertEqual(act_mem_ac, 0.0)
+        self.assertEqual(bw_flops_ac, 3.0)
+
+        act_mem_sac = get_act_mem(lambda: fn_sac(x, y))
+        bw_flops_sac = get_bw_flops(lambda: fn_sac(x, y))
+
+        self.assertEqual(act_mem_sac, 1.0)
+        self.assertEqual(bw_flops_sac, 2.0)
+
+        act_mem_sac2 = get_act_mem(lambda: fn_sac2(x, y))
+        bw_flops_sac2 = get_bw_flops(lambda: fn_sac2(x, y))
+
+        self.assertEqual(act_mem_sac2, 1.0)
+        self.assertEqual(bw_flops_sac2, 2.0)
+
+    def test_bad_inputs(self):
+        bad_op_list1 = [2]
+
+        with self.assertRaisesRegex(
+            ValueError, "Expected op in `op_list` to be an OpOverload"
+        ):
+            create_selective_checkpoint_contexts(bad_op_list1)
+
+        bad_op_list2 = [torch.ops.aten.sin]
+
+        with self.assertRaisesRegex(
+            ValueError, "update the OpOverloadPacket to a specific OpOverload"
+        ):
+            create_selective_checkpoint_contexts(bad_op_list2)
+
+        with self.assertRaisesRegex(TypeError, "either a function or a list of ops."):
+            create_selective_checkpoint_contexts(2)
+
+    # Dynamo fails for various reasons:
+    # - some tests using custom op that does not implement Fake
+    # - dynamo is trying to trace into saved variable hooks unpack hook for some reason
+    @skipIfTorchDynamo("compile tested in test/dynamo/test_activation_checkpointing.py")
+    def test_policy_with_state(self):
+        # If I have a stateful callable, state is shared between the original
+        # forward and the recompute.
+        counters = []
+
+        class Policy:
+            def __init__(self):
+                self.counter = [0]
+                self.recompute_counter = [0]
+
+            def __call__(self, ctx, func, *args, **kwargs):
+                counter = self.recompute_counter if ctx.is_recompute else self.counter
+                counter[0] += 1
+                counters.append(counter[0])
+                if counter[0] == 1 and func is torch.ops.aten.mm.default:
+                    return CheckpointPolicy.MUST_SAVE
+                return CheckpointPolicy.PREFER_RECOMPUTE
+
+        def fn(x):
+            return x.sin().sin().sin()
+
+        x = torch.randn(3, requires_grad=True)
+        context_fn = functools.partial(
+            create_selective_checkpoint_contexts,
+            Policy(),
+            allow_cache_entry_mutation=True,
+        )
+        out = checkpoint(fn, x, use_reentrant=False, context_fn=context_fn)
+        out.sum().backward()
+        # 1. counter properly reset to 0 for the recompute
+        # 2. due to early-stop we do not recompute the final op
+        self.assertEqual(counters, [1, 2, 3, 1, 2])
+
+    @skipIfTorchDynamo("compile tested in test/dynamo/test_activation_checkpointing.py")
+    def test_storage_lifetime(self):
+        # The storage object saved by SAC survives as long as the graph is alive
+        # graph -> the saved variable hooks -> recompute_context -> storage
+        # However, we make sure to eagerly free any cached objects upon use.
+        def policy_fn(ctx, op, *args, **kwargs):
+            if op == torch.ops.aten.sin.default:
+                return CheckpointPolicy.MUST_SAVE
+            else:
+                return CheckpointPolicy.PREFER_RECOMPUTE
+
+        ref = None
+
+        # This hook fires after unpack is triggered.
+        def hook(x):
+            self.assertIsNone(ref())
+
+        def fn(x):
+            nonlocal ref
+            # IMPORTANT: the tensor object saved is the same exact tensor
+            # object as the output only when the tensor does not require grad.
+            # Detach, so we can conveniently access the reference here.
+            sin_out = x.detach().sin()
+            ref = weakref.ref(sin_out)
+
+            out = x.cos().exp()
+            out.register_hook(hook)
+            return out.cos()
+
+        x = torch.randn(3, requires_grad=True)
+        context_fn = functools.partial(create_selective_checkpoint_contexts, policy_fn)
+        out = checkpoint(fn, x, use_reentrant=False, context_fn=context_fn)
+        self.assertIsNotNone(ref())
+        out.sum().backward()
+        self.assertIsNone(ref())
+
+    @skipIfTorchDynamo("compile tested in test/dynamo/test_activation_checkpointing.py")
+    def test_version_counter(self):
+        def policy_fn(ctx, op, *args, **kwargs):
+            if op == torch.ops.aten.sin.default:
+                return CheckpointPolicy.MUST_SAVE
+            else:
+                return CheckpointPolicy.PREFER_RECOMPUTE
+
+        def fn(x):
+            return x.sin().mul_(2).cos().exp()
+
+        x = torch.randn(3, requires_grad=True)
+        context_fn = functools.partial(create_selective_checkpoint_contexts, policy_fn)
+        out = checkpoint(fn, x, use_reentrant=False, context_fn=context_fn)
+
+        # 1) Error because the output of sin is saved and mutated by mul_
+        with self.assertRaisesRegex(RuntimeError, "has been mutated"):
+            out.sum().backward()
+
+        x = torch.randn(3, requires_grad=True)
+        context_fn = functools.partial(
+            create_selective_checkpoint_contexts,
+            policy_fn,
+            allow_cache_entry_mutation=True,
+        )
+        out = checkpoint(fn, x, use_reentrant=False, context_fn=context_fn)
+
+        # 2) No longer should be an error because of allow_cache_entry_mutation
+        out.sum().backward()
+
+    @skipIfTorchDynamo("compile tested in test/dynamo/test_activation_checkpointing.py")
+    def test_function_with_more_than_one_output(self):
+        # maybe there is a more systematic way:
+        counter = [0]
+
+        def policy_fn(ctx, op, *args, **kwargs):
+            if op == torch.ops.aten.var_mean.correction:
+                counter[0] += 1
+                return CheckpointPolicy.MUST_SAVE
+            else:
+                return CheckpointPolicy.PREFER_RECOMPUTE
+
+        # var_mean has two outputs
+        def fn(x):
+            a, b = torch.var_mean(x)
+            return a * b
+
+        x = torch.randn(3, requires_grad=True)
+        context_fn = functools.partial(create_selective_checkpoint_contexts, policy_fn)
+        out = checkpoint(fn, x, use_reentrant=False, context_fn=context_fn)
+        x_grad = torch.autograd.grad(out.sum(), (x,))
+        x_grad_ref = torch.autograd.grad(fn(x).sum(), (x,))
+        self.assertEqual(x_grad, x_grad_ref)
+        self.assertEqual(counter[0], 2)
+
+    @skipIfTorchDynamo("compile tested in test/dynamo/test_activation_checkpointing.py")
+    def test_function_with_non_tensor_output(self):
+        # When SAC is enabled, the op is not computed a second time
+        with torch.library._scoped_library("mylib", "FRAGMENT") as lib:
+            counter = [0]
+
+            @torch.library.custom_op("mylib::sin_with_extra", mutates_args=())
+            def sin_with_extra(x: torch.Tensor) -> Tuple[torch.Tensor, int]:
+                counter[0] += 1
+                return x.sin(), 2
+
+            def setup_context(ctx, inputs, output) -> torch.Tensor:
+                (x,) = inputs
+                ctx.save_for_backward(x)
+
+            def backward(ctx, grad, _unused):
+                (x,) = ctx.saved_tensors
+                return grad * x.cos()
+
+            torch.library.register_autograd(
+                "mylib::sin_with_extra", backward, setup_context=setup_context
+            )
+
+            x = torch.randn(3, requires_grad=True)
+
+            def fn(x):
+                return (torch.ops.mylib.sin_with_extra(x)[0] * x.sin().exp()).sin()
+
+            ops_list = [torch.ops.mylib.sin_with_extra.default]
+
+            x = torch.randn(3, requires_grad=True)
+            context_fn = functools.partial(
+                create_selective_checkpoint_contexts, ops_list
+            )
+            out = checkpoint(fn, x, use_reentrant=False, context_fn=context_fn)
+            x_grad = torch.autograd.grad(out.sum(), (x,))
+            self.assertEqual(counter[0], 1)
+            x_grad_ref = torch.autograd.grad(fn(x).sum(), (x,))
+            self.assertEqual(x_grad, x_grad_ref)
+
+    @skipIfTorchDynamo("compile tested in test/dynamo/test_activation_checkpointing.py")
+    def test_can_only_trigger_recompute_once(self):
+        # We don't support this to avoid adding extra complexity for now.
+        # If there's a need, we could probably do some kind of use_count tracking.
+        # TODO: have a nice error message here.
+        def policy_fn(ctx, op, *args, **kwargs):
+            if op == torch.ops.aten.sin.default:
+                return CheckpointPolicy.MUST_SAVE
+            else:
+                return CheckpointPolicy.PREFER_RECOMPUTE
+
+        def fn(x):
+            return x.sin().cos().exp()
+
+        x = torch.randn(3, requires_grad=True)
+        context_fn = functools.partial(create_selective_checkpoint_contexts, policy_fn)
+        out = checkpoint(fn, x, use_reentrant=False, context_fn=context_fn)
+        out.sum().backward(retain_graph=True)
+
+        with self.assertRaisesRegex(RuntimeError, "Trying to backward an extra time"):
+            out.sum().backward(retain_graph=True)
+
+
 class TestAutogradMultipleDispatch(TestCase):
     def test_autograd_multiple_dispatch_registrations(self, device):
         t = torch.randn(3, 3, device=device, requires_grad=True)
diff --git a/torch/_higher_order_ops/wrap.py b/torch/_higher_order_ops/wrap.py
index 6d83a44e752a06..e7fe553387d1c8 100644
--- a/torch/_higher_order_ops/wrap.py
+++ b/torch/_higher_order_ops/wrap.py
@@ -1,15 +1,17 @@
 # mypy: allow-untyped-defs
 import inspect
+import itertools
 import logging
 
 import torch
 from torch._ops import HigherOrderOperator
-from torch.utils.checkpoint import checkpoint, uid
+from torch.utils.checkpoint import checkpoint
+
 import torch._dynamo.config
 
 log = logging.getLogger(__name__)
 
-
+uid = itertools.count(1)
 
 # Used for testing the HigherOrderOperator mechanism
 class Wrap(HigherOrderOperator):
diff --git a/torch/utils/checkpoint.py b/torch/utils/checkpoint.py
index 5cbfd1543cf423..0ab77478764867 100644
--- a/torch/utils/checkpoint.py
+++ b/torch/utils/checkpoint.py
@@ -5,18 +5,8 @@
 import warnings
 import weakref
 from collections import defaultdict
-from itertools import count
-from typing import (
-    Any,
-    Callable,
-    ContextManager,
-    DefaultDict,
-    Dict,
-    Iterable,
-    List,
-    Optional,
-    Tuple,
-)
+from typing import *  # noqa: F403
+import enum
 from weakref import ReferenceType
 
 import torch
@@ -39,6 +29,9 @@
     "set_checkpoint_early_stop",
     "DefaultDeviceType",
     "set_checkpoint_debug_enabled",
+    "CheckpointPolicy",
+    "SelectiveCheckpointContext",
+    "create_selective_checkpoint_contexts",
 ]
 
 _DEFAULT_DETERMINISM_MODE = "default"
@@ -1153,149 +1146,229 @@ def _is_compiling(func, args, kwargs):
         return False
 
 
-def _detach(x):
-    if isinstance(x, torch.Tensor):
-        return x.detach()
+class _VersionWrapper:
+    # Check that cached tensors are not mutated.
+    def __init__(self, val):
+        self.val: Union[torch.Tensor, Any] = val
+        self.version: Optional[int] = val._version if isinstance(val, torch.Tensor) else None
+
+    def get_val(self, allow_cache_entry_mutation):
+        if self.version is not None and not allow_cache_entry_mutation:
+            if self.val._version != self.version:
+                # Can we give user a stack trace of where the mutation happened?
+                raise RuntimeError(
+                    "Tensor cached during selective activation checkpoint has been mutated"
+                )
+        return self.val
+
+
+def _maybe_detach(x):
+    if isinstance(x, torch.Tensor) and x.requires_grad:
+        # NB: Ensure the original tensor object is saved when x does not require grad
+        with torch._C._SetExcludeDispatchKeyGuard(torch._C.DispatchKey.ADInplaceOrView, False):
+            # Ensure that view performed beneath autograd properly propagates
+            # version counter. TODO: Use reentrant_dispatch instead of
+            # manually manipulating dispatch keys. Using reentrant_dispatch
+            # would respect inference_mode, though that is not relevant for
+            # this case.
+            x = x.detach()
     return x
 
 
-uid = count(1)
+class SelectiveCheckpointContext:
+    """
+    Context passed to policy function during selective checkpointing.
+
+    This class is used to pass relevant metadata to the policy function during
+    selective checkpointing. The metadata includes whether the current invocation
+    of the policy function is during recomputation or not.
 
-# NOTE: torch.utils.checkpoint internal logic will call these two functions unknown number of times
-# (i.e. there could be _CachedTorchDispatchMode calls that doesn't map to a _CachingTorchDispatchMode call),
-# so we ignore these ops and just always recompute them.
-_ignored_ops = {
-    torch.ops.prim.device.default,
-    torch.ops.aten.detach.default,
-} | set(torch._subclasses.functional_tensor.FunctionalTensor.metadata_fns)
+    Example:
+        >>> # xdoctest: +SKIP(stub)
+        >>>
+        >>> def policy_fn(ctx, op, *args, **kwargs):
+        >>>     print(ctx.is_recompute)
+        >>>
+        >>> context_fn = functools.partial(create_selective_checkpoint_contexts, policy_fn)
+        >>>
+        >>> out = torch.utils.checkpoint.checkpoint(
+        >>>     fn, x, y,
+        >>>     use_reentrant=False,
+        >>>     context_fn=context_fn,
+        >>> )
+    """
+    def __init__(self, *, is_recompute):
+        self.is_recompute = is_recompute
 
 
-class _CachingTorchDispatchMode(TorchDispatchMode):
-    r"""
-    A :class:`TorchDispatchMode` to implement selective activation checkpointing
-    that's compatible with torch.compile. Used together with _CachedTorchDispatchMode.
+class CheckpointPolicy(enum.Enum):
     """
+    Enum for specifying the policy for checkpointing during backpropagation.
+
+    The following policies are supported:
+
+    - ``{MUST,PREFER}_SAVE``: The operation's output will be saved during the forward
+      pass and will not be recomputed during the backward pass
+    - ``{MUST,PREFER}_RECOMPUTE``: The operation's output will not be saved during the
+      forward pass and will be recomputed during the backward pass
+
+    Use ``MUST_*`` over ``PREFER_*`` to indicate that the policy should not be overridden
+    by other subsystems like `torch.compile`.
+
+    .. note::
+        A policy function that always returns ``PREFER_RECOMPUTE`` is
+        equivalent to vanilla checkpointing.
+
+        A policy function that returns ``PREFER_SAVE`` for every op is
+        NOT equivalent to not using checkpointing. Using such a policy would
+        save additional tensors not limited to ones that are actually needed for
+        gradient computation.
+    """
+    MUST_SAVE = 0
+    PREFER_SAVE = 1
+    MUST_RECOMPUTE = 2
+    PREFER_RECOMPUTE = 3
+
+
+class _CachingTorchDispatchMode(TorchDispatchMode):
+    # Used together with _CachedTorchDispatchMode to implement SAC.
     def __init__(self, policy_fn, storage):
         self.policy_fn = policy_fn
         self.storage = storage
 
-    def push_into_storage(self, out, func, args, kwargs):
-        out_detached = tree_map(_detach, out)
-        self.storage[func].append(out_detached)
+    def __torch_dispatch__(self, func, types, args=(), kwargs=None):
+        if func is torch.ops.aten.detach.default:
+            return func(*args, **kwargs)
+
+        kwargs = {} if kwargs is None else kwargs
+        policy = self.policy_fn(SelectiveCheckpointContext(is_recompute=False),
+                                func, *args, **kwargs)
+        is_compiling = _is_compiling(func, args, kwargs)
 
-    def _handle_compile_in_forward_ctx(self, should_not_recompute, func, args, kwargs):
-        if should_not_recompute:
+        if is_compiling and policy == CheckpointPolicy.MUST_SAVE:
             fx_traceback.current_meta["recompute"] = 0
-        # NOTE: Here we just store and reuse output of all ops, since in torch.compile mode
-        # we decide and handle recomputation in the partitioner.
+
         out = func(*args, **kwargs)
-        self.push_into_storage(out, func, args, kwargs)
-        return out
 
-    def __torch_dispatch__(self, func, types, args=(), kwargs=None):
-        if kwargs is None:
-            kwargs = {}
-        if func in _ignored_ops:
-            return func(*args, **kwargs)
-        should_not_recompute = self.policy_fn("forward", func, *args, **kwargs)
-        if _is_compiling(func, args, kwargs):
-            return self._handle_compile_in_forward_ctx(should_not_recompute, func, args, kwargs)
-        else:
-            if should_not_recompute:
-                out = func(*args, **kwargs)
-                self.push_into_storage(out, func, args, kwargs)
-            else:
-                out = func(*args, **kwargs)
-            return out
+        if policy in (CheckpointPolicy.MUST_SAVE, CheckpointPolicy.PREFER_SAVE) or is_compiling:
+            self.storage[func].append(tree_map(lambda x: _VersionWrapper(_maybe_detach(x)), out))
+        return out
 
 
 class _CachedTorchDispatchMode(TorchDispatchMode):
-    r"""
-    A :class:`TorchDispatchMode` to implement selective activation checkpointing
-    that's compatible with torch.compile. Used together with _CachingTorchDispatchMode.
-    """
-    def __init__(self, policy_fn, storage):
+    # Used together with _CachingTorchDispatchMode to implement SAC.
+    def __init__(self, policy_fn, storage, allow_cache_entry_mutation):
         self.policy_fn = policy_fn
         self.storage = storage
-
-    def pop_from_storage(self, func, args, kwargs):
-        assert func in self.storage
-        out = self.storage[func].pop(0)
-        return out
-
-    def _handle_compile_in_recompute_ctx(self, should_not_recompute, func, args, kwargs):
-        out = self.pop_from_storage(func, args, kwargs)
-        return out
+        self.allow_cache_entry_mutation = allow_cache_entry_mutation
 
     def __torch_dispatch__(self, func, types, args=(), kwargs=None):
-        if kwargs is None:
-            kwargs = {}
-        if func in _ignored_ops:
+        if func is torch.ops.aten.detach.default:
             return func(*args, **kwargs)
-        should_not_recompute = self.policy_fn("recompute", func, *args, **kwargs)
-        if _is_compiling(func, args, kwargs):
-            return self._handle_compile_in_recompute_ctx(should_not_recompute, func, args, kwargs)
+
+        kwargs = {} if kwargs is None else kwargs
+        policy = self.policy_fn(SelectiveCheckpointContext(is_recompute=True),
+                                func, *args, **kwargs)
+        is_compiling = _is_compiling(func, args, kwargs)
+
+        if policy in (CheckpointPolicy.MUST_SAVE, CheckpointPolicy.PREFER_SAVE) or is_compiling:
+            storage = self.storage.get(func)
+            if storage is None:
+                raise RuntimeError(f"{func} encountered during backward, but not found in storage")
+            if len(storage) == 0:
+                raise RuntimeError(
+                    "Trying to backward an extra time. You are only allowed to backward once "
+                    "on any region computed under selective activation checkpoint."
+                )
+            out = tree_map(lambda x: x.get_val(self.allow_cache_entry_mutation), storage.pop(0))
         else:
-            if should_not_recompute:
-                out = self.pop_from_storage(func, args, kwargs)
-            else:
-                out = func(*args, **kwargs)
-        return out
+            out = func(*args, **kwargs)
+        return out
+
 
-def _pt2_selective_checkpoint_context_fn_gen(policy_fn):
+def create_selective_checkpoint_contexts(policy_fn_or_list, allow_cache_entry_mutation=False):
     """
-    A helper function that generates a pair of contexts to be later passed into
-    `torch.utils.checkpoint` API to implment selective checkpointing.
+    Helper to avoid recomputing certain ops during activation checkpointing.
 
-    .. warning::
-        This is context_fn is intended for use with torch.compile only.
+    Use this with `torch.utils.checkpoint.checkpoint` to control which
+    operations are recomputed during the backward pass.
 
     Args:
-        policy_fn (Callable[[Callable, List[Any], Dict[str, Any]], bool]): Policy function
-            to decide whether a particular op should be recomputed in backward pass or not.
-            In eager mode:
-                If policy_fn(...) returns True, the op is guaranteed to NOT be recomputed.
-                If policy_fn(...) returns False, the op is guaranteed to be recomputed.
-            In torch.compile mode:
-                If policy_fn(...) returns True, the op is guaranteed to NOT be recomputed.
-                If policy_fn(...) returns False, the op may or may not be recomputed
-                (it's up to the partitioner to decide).
-
+        policy_fn_or_list (Callable or List):
+            - If a policy function is provided, it should accept a
+              :class:`SelectiveCheckpointContext`, the :class:`OpOverload`, args and
+              kwargs to the op, and return a :class:`CheckpointPolicy` enum value
+              indicating whether the execution of the op should be recomputed or not.
+            - If a list of operations is provided, it is equivalent to a policy
+              returning `CheckpointPolicy.MUST_SAVE` for the specified
+              operations and `CheckpointPolicy.PREFER_RECOMPUTE` for all other
+              operations.
+        allow_cache_entry_mutation (bool, optional): By default, an error is
+            raised if any tensors cached by selective activation checkpoint are
+            mutated in order to ensure correctness. If set to `True`, this check
+            is disabled.
     Returns:
-        A pair of generated contexts.
+        A tuple of two context managers.
 
     Example:
         >>> # xdoctest: +REQUIRES(LINUX)
+        >>> import functools
         >>>
-        >>> def get_custom_policy():
-        >>>     no_recompute_list = [
-        >>>         torch.ops.aten.mm.default,
-        >>>     ]
-        >>>     def custom_policy(mode, func, *args, **kwargs):
-        >>>         return func in no_recompute_list
-        >>>     return custom_policy
+        >>> x = torch.rand(10, 10, requires_grad=True)
+        >>> y = torch.rand(10, 10, requires_grad=True)
        >>>
-        >>> def selective_checkpointing_context_fn():
-        >>>     return _pt2_selective_checkpoint_context_fn_gen(get_custom_policy())
+        >>> ops_to_save = [
+        >>>     torch.ops.aten.mm.default,
+        >>> ]
         >>>
-        >>> def gn(x, y):
-        >>>     return torch.sigmoid(torch.matmul(torch.matmul(x, y), y)) * y
+        >>> def policy_fn(ctx, op, *args, **kwargs):
+        >>>     if op in ops_to_save:
+        >>>         return CheckpointPolicy.MUST_SAVE
+        >>>     else:
+        >>>         return CheckpointPolicy.PREFER_RECOMPUTE
         >>>
-        >>> def fn(x, y):
-        >>>     return torch.utils.checkpoint.checkpoint(
-        >>>         gn, x, y,
-        >>>         use_reentrant=False,
-        >>>         context_fn=selective_checkpointing_context_fn,
-        >>>     )
+        >>> context_fn = functools.partial(create_selective_checkpoint_contexts, policy_fn)
         >>>
-        >>> x = torch.randn(4, 4, requires_grad=True)
-        >>> y = torch.randn(4, 4, requires_grad=True)
+        >>> # or equivalently
+        >>> context_fn = functools.partial(create_selective_checkpoint_contexts, ops_to_save)
+        >>>
+        >>> def fn(x, y):
+        >>>     return torch.sigmoid(torch.matmul(torch.matmul(x, y), y)) * y
         >>>
-        >>> compiled_fn = torch.compile(fn)
+        >>> out = torch.utils.checkpoint.checkpoint(
+        >>>     fn, x, y,
+        >>>     use_reentrant=False,
+        >>>     context_fn=context_fn,
+        >>> )
     """
-    storage: Dict[Any, List[Any]] = defaultdict(list)
-    return _CachingTorchDispatchMode(policy_fn, storage), _CachedTorchDispatchMode(policy_fn, storage)
+    # NB: If grad_mode is disabled, checkpoint would not run forward under
+    # context_fn anyway, so proceed as usual.
+    if isinstance(policy_fn_or_list, list):
+        for op in policy_fn_or_list:
+            if not isinstance(op, torch._ops.OpOverload):
+                _extra_msg = (
+                    "Please update the OpOverloadPacket to a specific OpOverload. "
+                    "For example, if you have `torch.ops.aten.mm`, change it to `torch.ops.aten.mm.default`."
+                ) if isinstance(op, torch._ops.OpOverloadPacket) else ""
+                raise ValueError(
+                    f"Expected op in `op_list` to be an OpOverload but got: {op} "
+                    f"of type {type(op)}. {_extra_msg}"
+                )
+
+        def policy_fn(ctx, op, *args, **kwargs):
+            if op in policy_fn_or_list:
+                return CheckpointPolicy.MUST_SAVE
+            else:
+                return CheckpointPolicy.PREFER_RECOMPUTE
+    elif callable(policy_fn_or_list):
+        policy_fn = policy_fn_or_list
+    else:
+        raise TypeError("policy_fn_or_list must be either a function or a list of ops.")
+
+    storage: Dict[Any, List[Any]] = defaultdict(list)
+    return (
+        _CachingTorchDispatchMode(policy_fn, storage),
+        _CachedTorchDispatchMode(policy_fn, storage, allow_cache_entry_mutation),
+    )
 
 
 # NB: this helper wraps fn before calling checkpoint_impl. kwargs and
 # saving/restoring of global state is handled here.
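
Below is a brief usage sketch (not part of the patch itself) of the eager-mode API this diff introduces; it mirrors the docstring example above, with `torch.ops.aten.mm.default` chosen as the op to save:

    import functools
    import torch
    from torch.utils.checkpoint import (
        checkpoint,
        CheckpointPolicy,
        create_selective_checkpoint_contexts,
    )

    # Save matmul outputs during the forward pass; prefer recomputing everything else.
    def policy_fn(ctx, op, *args, **kwargs):
        if op is torch.ops.aten.mm.default:
            return CheckpointPolicy.MUST_SAVE
        return CheckpointPolicy.PREFER_RECOMPUTE

    def fn(x, y):
        return torch.sigmoid(torch.matmul(torch.matmul(x, y), y)) * y

    x = torch.randn(4, 4, requires_grad=True)
    y = torch.randn(4, 4, requires_grad=True)

    context_fn = functools.partial(create_selective_checkpoint_contexts, policy_fn)
    out = checkpoint(fn, x, y, use_reentrant=False, context_fn=context_fn)
    out.sum().backward()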