Skip to content

Commit

Permalink
Add safe_globals to resume training on PyTorch 2.6
Browse files Browse the repository at this point in the history
Starting from version 2.4 PyTorch introduces a stricter check for the objects which
can be loaded with torch.load(). Starting from version 2.6 loading with weights_only=True
requires allowlisting of such objects.

This commit adds allowlist of some numpy objects used to load model checkpoints.
Usage is restricted by context manager. User can still additionally call
torch.serialization.add_safe_globals() to add other objects into the safe globals list.

Accelerate library also stepped into same problem and addressed it with PR-3036.

Fixes: #34631
See: pytorch/pytorch#137602
See: https://pytorch.org/docs/stable/notes/serialization.html#torch.serialization.add_safe_globals
See: huggingface/accelerate#3036
Signed-off-by: Dmitry Rogozhkin <[email protected]>
  • Loading branch information
dvrogozh committed Nov 21, 2024
1 parent befbbf2 commit dbb3112
Showing 1 changed file with 21 additions and 1 deletion.
22 changes: 21 additions & 1 deletion src/transformers/trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -272,6 +272,25 @@ def _get_fsdp_ckpt_kwargs():
return {}


def safe_globals():
# Starting from version 2.4 PyTorch introduces a stricter check for the objects which
# can be loaded with torch.load(). Starting from version 2.6 loading with weights_only=True
# requires allowlisting of such objects.
# See: https://github.com/pytorch/pytorch/pull/137602
# See: https://pytorch.org/docs/stable/notes/serialization.html#torch.serialization.add_safe_globals
# See: https://github.com/huggingface/accelerate/pull/3036
if version.parse(torch.__version__) < version.parse("2.4.0"):
return contextlib.nullcontext()

np_core = np._core if version.parse(np.__version__) >= version.parse("2.0.0") else np.core
allowlist = [np_core.multiarray._reconstruct, np.ndarray, np.dtype]
# numpy >1.25 defines numpy.dtypes.UInt32DType, but below works for
# all versions of numpy
allowlist += [type(np.dtype(np.uint32))]

return torch.serialization.safe_globals(allowlist)


if TYPE_CHECKING:
import optuna

Expand Down Expand Up @@ -3052,7 +3071,8 @@ def _load_rng_state(self, checkpoint):
)
return

checkpoint_rng_state = torch.load(rng_file)
with safe_globals():
checkpoint_rng_state = torch.load(rng_file)
random.setstate(checkpoint_rng_state["python"])
np.random.set_state(checkpoint_rng_state["numpy"])
torch.random.set_rng_state(checkpoint_rng_state["cpu"])
Expand Down

0 comments on commit dbb3112

Please sign in to comment.