
safe_globals are needed to resume training on upcoming PyTorch 2.6 #34631

Closed
dvrogozh opened this issue Nov 6, 2024 · 4 comments · Fixed by #34632
Labels
dependencies Pull requests that update a dependency file PyTorch Anything PyTorch

Comments

@dvrogozh
Contributor

dvrogozh commented Nov 6, 2024

PyTorch 2.6 flips the default of torch.load() to weights_only=True (done via pytorch/pytorch#137602). With this change, some tests in Hugging Face Transformers start to fail. I did not test everything, but at least these are affected:

  • tests/trainer/test_trainer.py::TrainerIntegrationTest::test_auto_batch_size_with_resume_from_checkpoint
  • tests/trainer/test_trainer.py::TrainerIntegrationTest::test_can_resume_training
  • tests/trainer/test_trainer.py::TrainerIntegrationTest::test_compare_trainer_and_checkpoint_args_logging
  • tests/trainer/test_trainer.py::TrainerIntegrationTest::test_resume_training_with_frozen_params
  • tests/trainer/test_trainer.py::TrainerIntegrationTest::test_resume_training_with_gradient_accumulation
  • tests/trainer/test_trainer.py::TrainerIntegrationTest::test_resume_training_with_safe_checkpoint
  • tests/trainer/test_trainer.py::TrainerIntegrationTest::test_resume_training_with_shard_checkpoint

What's the right way to handle this in Hugging Face Transformers? Should Transformers maintain an internal allowlist of safe globals? Should the Transformers API be extended to accept an externally supplied list of safe globals? Or is this the end user's responsibility, with such a list maintained by higher-level scripts?

See the log for one of the tests below. It can be reproduced on a single-card system with an NVIDIA A10 or Intel PVC:

$ python3 -m pytest --pspec tests/trainer/test_trainer.py::TrainerIntegrationTest::test_can_resume_training
=========================================================== test session starts ===========================================================
platform linux -- Python 3.10.12, pytest-7.4.4, pluggy-1.5.0
rootdir: /home/dvrogozh/git/huggingface/transformers
configfile: pyproject.toml
plugins: hypothesis-6.111.1, subtests-0.13.1, rich-0.1.1, dash-2.17.1, xdist-3.6.1, pspec-0.0.4, timeout-2.3.1
collected 1 item

tests/trainer/test_trainer.py
Trainer Integration Test
 ✗ can resume training
                                                                                                                                    [100%]

================================================================ FAILURES =================================================================
_____________________________________________ TrainerIntegrationTest.test_can_resume_training _____________________________________________

self = <tests.trainer.test_trainer.TrainerIntegrationTest testMethod=test_can_resume_training>

    @require_torch_up_to_2_accelerators
    def test_can_resume_training(self):
        # This test will fail for more than 2 GPUs since the batch size will get bigger and with the number of
        # save_steps, the checkpoint will resume training at epoch 2 or more (so the data seen by the model
        # won't be the same since the training dataloader is shuffled).

        with tempfile.TemporaryDirectory() as tmpdir:
            kwargs = {
                "output_dir": tmpdir,
                "train_len": 128,
                "save_steps": 5,
                "learning_rate": 0.1,
                "logging_steps": 5,
            }
            trainer = get_regression_trainer(**kwargs)
            trainer.train()
            (a, b) = trainer.model.a.item(), trainer.model.b.item()
            state = dataclasses.asdict(trainer.state)

            checkpoint = os.path.join(tmpdir, "checkpoint-5")

            # Reinitialize trainer
            trainer = get_regression_trainer(**kwargs)

>           trainer.train(resume_from_checkpoint=checkpoint)

tests/trainer/test_trainer.py:2610:
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
src/transformers/trainer.py:2141: in train
    return inner_training_loop(
src/transformers/trainer.py:2470: in _inner_training_loop
    self._load_rng_state(resume_from_checkpoint)
src/transformers/trainer.py:3051: in _load_rng_state
    checkpoint_rng_state = torch.load(rng_file)
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _

f = '/tmp/tmpfuilg_pf/checkpoint-5/rng_state.pth', map_location = None, pickle_module = None, weights_only = True, mmap = False
pickle_load_args = {'encoding': 'utf-8'}, _get_wo_message = <function load.<locals>._get_wo_message at 0x7f2190293640>, skip_data = False
weights_only_not_set = True, true_values = ['1', 'y', 'yes', 'true'], force_weights_only_load = False

...

                        except pickle.UnpicklingError as e:
>                           raise pickle.UnpicklingError(_get_wo_message(str(e))) from None
E                           _pickle.UnpicklingError: Weights only load failed. This file can still be loaded, to do so you have two options, do those steps only if you trust the source of the checkpoint.
E                               (1) Re-running `torch.load` with `weights_only` set to `False` will likely succeed, but it can result in arbitrary code execution. Do it only if you got the file from a trusted source.
E                               (2) Alternatively, to load with `weights_only=True` please check the recommended steps in the following error message.
E                               WeightsUnpickler error: Unsupported global: GLOBAL numpy.core.multiarray._reconstruct was not an allowed global by default. Please use `torch.serialization.add_safe_globals([_reconstruct])` or the `torch.serialization.safe_globals([_reconstruct])` context manager to allowlist this global if you trust this class/function.
E
E                           Check the documentation of torch.load to learn more about types accepted by default with weights_only https://pytorch.org/docs/stable/generated/torch.load.html.

../../pytorch/pytorch/torch/serialization.py:1444: UnpicklingError
---------------------------------------------------------- Captured stdout call -----------------------------------------------------------
{'loss': 8.9328, 'grad_norm': 6.47317361831665, 'learning_rate': 0.08958333333333335, 'epoch': 0.31}
{'loss': 8.4215, 'grad_norm': 4.27539587020874, 'learning_rate': 0.07916666666666666, 'epoch': 0.62}
{'loss': 4.469, 'grad_norm': 2.845299005508423, 'learning_rate': 0.06875, 'epoch': 0.94}
{'loss': 2.398, 'grad_norm': 3.027919292449951, 'learning_rate': 0.05833333333333334, 'epoch': 1.25}
{'loss': 2.1414, 'grad_norm': 3.2882485389709473, 'learning_rate': 0.04791666666666667, 'epoch': 1.56}
{'loss': 1.237, 'grad_norm': 1.5751118659973145, 'learning_rate': 0.037500000000000006, 'epoch': 1.88}
{'loss': 0.7723, 'grad_norm': 1.4625849723815918, 'learning_rate': 0.027083333333333334, 'epoch': 2.19}
{'loss': 0.5407, 'grad_norm': 1.1405667066574097, 'learning_rate': 0.016666666666666666, 'epoch': 2.5}
{'loss': 0.3799, 'grad_norm': 1.1864064931869507, 'learning_rate': 0.00625, 'epoch': 2.81}
{'train_runtime': 0.497, 'train_samples_per_second': 772.666, 'train_steps_per_second': 96.583, 'train_loss': 3.0737542832891145, 'epoch': 3.0}
---------------------------------------------------------- Captured stderr call -----------------------------------------------------------
  0%|          | 0/48 [00:00<?, ?it/s]Could not estimate the number of tokens of the input, floating-point operations will not be computed
100%|██████████| 48/48 [00:00<00:00, 96.67it/s]
  0%|          | 0/48 [00:00<?, ?it/s]
========================================================= short test summary info =========================================================
FAILED tests/trainer/test_trainer.py::TrainerIntegrationTest::test_can_resume_training - _pickle.UnpicklingError: Weights only load failed. This file can still be loaded, to do so you have two options, do those steps only if you trust the source of the checkpoint.
============================================================ 1 failed in 3.32s ============================================================

CC: @muellerzr @SunMarc

@dvrogozh dvrogozh changed the title safe_globals are needed to resume from training on upcoming PyTorch 2.6 safe_globals are needed to resume training on upcoming PyTorch 2.6 Nov 6, 2024
dvrogozh added a commit to dvrogozh/transformers that referenced this issue Nov 6, 2024
Starting from version 2.4 PyTorch introduces a stricter check for the objects which
can be loaded with torch.load(). Starting from version 2.6 loading with weights_only=True
requires allowlisting of such objects.

Fixes: huggingface#34631
See: pytorch/pytorch#137602
See: https://pytorch.org/docs/stable/notes/serialization.html#torch.serialization.add_safe_globals
Signed-off-by: Dmitry Rogozhkin <[email protected]>
@dvrogozh
Contributor Author

dvrogozh commented Nov 6, 2024

See #34632 for potential fix.

@dvrogozh
Contributor Author

dvrogozh commented Nov 8, 2024

@muellerzr @SunMarc : could you please take a look at the issue and PR #34632?

@LysandreJik
Member

cc @ydshieh, could you take a look as well please? Thanks!

@LysandreJik LysandreJik added PyTorch Anything PyTorch dependencies Pull requests that update a dependency file labels Nov 15, 2024
@ydshieh
Collaborator

ydshieh commented Nov 15, 2024

Thank you @dvrogozh for opening this issue! Confirmed the issue is reproducible with torch 2.6 (nightly). Will check PR #34632

dvrogozh added a commit to dvrogozh/transformers that referenced this issue Nov 16, 2024
Starting from version 2.4 PyTorch introduces a stricter check for the objects which
can be loaded with torch.load(). Starting from version 2.6 loading with weights_only=True
requires allowlisting of such objects.

This commit adds an allowlist of some numpy objects used to load model checkpoints.
Usage is restricted by a context manager. Users can still additionally call
torch.serialization.add_safe_globals() to add other objects to the safe globals list.

Fixes: huggingface#34631
See: pytorch/pytorch#137602
See: https://pytorch.org/docs/stable/notes/serialization.html#torch.serialization.add_safe_globals
Signed-off-by: Dmitry Rogozhkin <[email protected]>
dvrogozh added a commit to dvrogozh/transformers that referenced this issue Nov 20, 2024
Starting from version 2.4 PyTorch introduces a stricter check for the objects which
can be loaded with torch.load(). Starting from version 2.6 loading with weights_only=True
requires allowlisting of such objects.

This commit adds an allowlist of some numpy objects used to load model checkpoints.
Usage is restricted by a context manager. Users can still additionally call
torch.serialization.add_safe_globals() to add other objects to the safe globals list.

The Accelerate library ran into the same problem and addressed it with PR-3036.

Fixes: huggingface#34631
See: pytorch/pytorch#137602
See: https://pytorch.org/docs/stable/notes/serialization.html#torch.serialization.add_safe_globals
See: huggingface/accelerate#3036
Signed-off-by: Dmitry Rogozhkin <[email protected]>