Skip to content

Commit

Permalink
Add safe_globals to resume training on PyTorch 2.6
Browse files Browse the repository at this point in the history
Starting from version 2.4 PyTorch introduces a stricter check for the objects which
can be loaded with torch.load(). Starting from version 2.6 loading with weights_only=True
requires allowlisting of such objects.

This commit adds allowlist of some numpy objects used to load model checkpoints.
Usage is restricted by context manager. User can still additionall call
torch.serialization.add_safe_globals() to add other objects into the safe globals list.

Fixes: huggingface#34631
See: pytorch/pytorch#137602
See: https://pytorch.org/docs/stable/notes/serialization.html#torch.serialization.add_safe_globals
Signed-off-by: Dmitry Rogozhkin <[email protected]>
  • Loading branch information
dvrogozh committed Nov 19, 2024
1 parent befbbf2 commit 4273a30
Showing 1 changed file with 19 additions and 1 deletion.
20 changes: 19 additions & 1 deletion src/transformers/trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -272,6 +272,23 @@ def _get_fsdp_ckpt_kwargs():
return {}


def safe_globals():
# Starting from version 2.4 PyTorch introduces a stricter check for the objects which
# can be loaded with torch.load(). Starting from version 2.6 loading with weights_only=True
# requires allowlisting of such objects.
# See: https://github.com/pytorch/pytorch/pull/137602
# See: https://pytorch.org/docs/stable/notes/serialization.html#torch.serialization.add_safe_globals
if version.parse(torch.__version__) < version.parse("2.4.0"):
return contextlib.nullcontext()

allowlist = [np.core.multiarray._reconstruct, np.ndarray, np.dtype]
# numpy >1.25 defines numpy.dtypes.UInt32DType, but below works for
# all versions of numpy
allowlist += [type(np.dtype(np.uint32))]

return torch.serialization.safe_globals(allowlist)


if TYPE_CHECKING:
import optuna

Expand Down Expand Up @@ -3052,7 +3069,8 @@ def _load_rng_state(self, checkpoint):
)
return

checkpoint_rng_state = torch.load(rng_file)
with safe_globals():
checkpoint_rng_state = torch.load(rng_file)
random.setstate(checkpoint_rng_state["python"])
np.random.set_state(checkpoint_rng_state["numpy"])
torch.random.set_rng_state(checkpoint_rng_state["cpu"])
Expand Down

0 comments on commit 4273a30

Please sign in to comment.