From 4273a304250466d609da60447f2452dcc464fb74 Mon Sep 17 00:00:00 2001 From: Dmitry Rogozhkin Date: Wed, 6 Nov 2024 11:28:34 -0800 Subject: [PATCH] Add safe_globals to resume training on PyTorch 2.6 Starting from version 2.4 PyTorch introduces a stricter check for the objects which can be loaded with torch.load(). Starting from version 2.6 loading with weights_only=True requires allowlisting of such objects. This commit adds allowlist of some numpy objects used to load model checkpoints. Usage is restricted by context manager. User can still additionall call torch.serialization.add_safe_globals() to add other objects into the safe globals list. Fixes: #34631 See: https://github.com/pytorch/pytorch/pull/137602 See: https://pytorch.org/docs/stable/notes/serialization.html#torch.serialization.add_safe_globals Signed-off-by: Dmitry Rogozhkin --- src/transformers/trainer.py | 20 +++++++++++++++++++- 1 file changed, 19 insertions(+), 1 deletion(-) diff --git a/src/transformers/trainer.py b/src/transformers/trainer.py index 129398e374be73..78d279c6123cc5 100755 --- a/src/transformers/trainer.py +++ b/src/transformers/trainer.py @@ -272,6 +272,23 @@ def _get_fsdp_ckpt_kwargs(): return {} +def safe_globals(): + # Starting from version 2.4 PyTorch introduces a stricter check for the objects which + # can be loaded with torch.load(). Starting from version 2.6 loading with weights_only=True + # requires allowlisting of such objects. + # See: https://github.com/pytorch/pytorch/pull/137602 + # See: https://pytorch.org/docs/stable/notes/serialization.html#torch.serialization.add_safe_globals + if version.parse(torch.__version__) < version.parse("2.4.0"): + return contextlib.nullcontext() + + allowlist = [np.core.multiarray._reconstruct, np.ndarray, np.dtype] + # numpy >1.25 defines numpy.dtypes.UInt32DType, but below works for + # all versions of numpy + allowlist += [type(np.dtype(np.uint32))] + + return torch.serialization.safe_globals(allowlist) + + if TYPE_CHECKING: import optuna @@ -3052,7 +3069,8 @@ def _load_rng_state(self, checkpoint): ) return - checkpoint_rng_state = torch.load(rng_file) + with safe_globals(): + checkpoint_rng_state = torch.load(rng_file) random.setstate(checkpoint_rng_state["python"]) np.random.set_state(checkpoint_rng_state["numpy"]) torch.random.set_rng_state(checkpoint_rng_state["cpu"])