Skip to content

Commit

Permalink
Add safe_globals to resume training on PyTorch 2.6
Browse files Browse the repository at this point in the history
Starting with version 2.4, PyTorch introduces stricter checks on the objects that
can be loaded with torch.load(). Starting with version 2.6, loading with weights_only=True
requires such objects to be explicitly allowlisted.

Fixes: #34631
See: pytorch/pytorch#137602
See: https://pytorch.org/docs/stable/notes/serialization.html#torch.serialization.add_safe_globals
Signed-off-by: Dmitry Rogozhkin <[email protected]>
  • Loading branch information
dvrogozh committed Nov 6, 2024
1 parent 7bbc624 commit 2cee855
Showing 1 changed file with 9 additions and 0 deletions.
9 changes: 9 additions & 0 deletions src/transformers/trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -271,6 +271,14 @@ def _get_fsdp_ckpt_kwargs():
else:
return {}

def add_safe_globals():
    """Allowlist the NumPy objects needed to load Trainer RNG checkpoints with torch.load().

    Starting from version 2.4 PyTorch introduces a stricter check for the objects which
    can be loaded with torch.load(). Starting from version 2.6 loading with weights_only=True
    requires allowlisting of such objects.

    See: https://github.com/pytorch/pytorch/pull/137602
    See: https://pytorch.org/docs/stable/notes/serialization.html#torch.serialization.add_safe_globals
    """
    if version.parse(torch.__version__) >= version.parse("2.4.0"):
        # NumPy 2.0 renamed np.core to np._core; touching np.core on NumPy 2.x
        # emits a deprecation warning and the alias is slated for removal.
        if version.parse(np.__version__) >= version.parse("2.0.0"):
            np_core = np._core
        else:
            np_core = np.core
        allowlist = [np_core.multiarray._reconstruct, np.ndarray, np.dtype]
        # The np.dtypes module only exists since NumPy 1.25; guard so older
        # NumPy installs don't crash with AttributeError on import of this path.
        if hasattr(np, "dtypes"):
            allowlist.append(np.dtypes.UInt32DType)
        torch.serialization.add_safe_globals(allowlist)

if TYPE_CHECKING:
import optuna
Expand Down Expand Up @@ -2109,6 +2117,7 @@ def train(
raise ValueError(f"No valid checkpoint found in output directory ({args.output_dir})")

if resume_from_checkpoint is not None:
add_safe_globals()
if not is_sagemaker_mp_enabled() and not self.is_deepspeed_enabled and not self.is_fsdp_enabled:
self._load_from_checkpoint(resume_from_checkpoint)
# In case of repeating the find_executable_batch_size, set `self._train_batch_size` properly
Expand Down

0 comments on commit 2cee855

Please sign in to comment.