Trigger weights_only=True by default for all compatible objects #3036

Merged: muellerzr merged 12 commits into main from migrate-saving-away-from-torch-load on Oct 10, 2024

muellerzr
Collaborator

What does this PR do?

Sets weights_only=True by default when loading with torch.load() for all compatible objects:

  • Model weights
  • Optimizer states
  • DataLoader states
  • Scheduler states

The only ones we can't handle right now are the random states, since loading the numpy random states still raises an exception even after we add the right classes for safe loading:

Traceback (most recent call last):
  File "/home/zach/work/accelerate/test.py", line 24, in <module>
    _ = load(p / "r2.pkl", weights_only=True)
  File "/home/zach/work/accelerate/src/accelerate/utils/other.py", line 245, in load
    loaded_obj = torch.load(f, map_location=map_location, **kwargs)
  File "/home/zach/miniconda3/envs/accelerate/lib/python3.10/site-packages/torch/serialization.py", line 1096, in load
    raise pickle.UnpicklingError(_get_wo_message(str(e))) from None
_pickle.UnpicklingError: Weights only load failed. Re-running `torch.load` with `weights_only` set to `False` will likely succeed, but it can result in arbitrary code execution. Do it only if you got the file from a trusted source.
 Please file an issue with the following so that we can make `weights_only=True` compatible with your use case: WeightsUnpickler error: Can only build Tensor, parameter or OrderedDict objects, but got <class 'numpy.dtypes.UInt32DType'>
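
The failure above comes from allowlist-based unpickling: any global that is not explicitly permitted is rejected instead of being imported. The following is a minimal stdlib-only sketch of that idea, not PyTorch's actual implementation. The names SAFE_GLOBALS, AllowlistUnpickler and safe_loads are hypothetical and exist only for illustration.

```python
import collections
import fractions
import io
import pickle

# Hypothetical allowlist: (module, name) pairs that the unpickler may resolve.
SAFE_GLOBALS = {("collections", "OrderedDict")}


class AllowlistUnpickler(pickle.Unpickler):
    """Unpickler that refuses to resolve any global not in SAFE_GLOBALS."""

    def find_class(self, module, name):
        if (module, name) in SAFE_GLOBALS:
            return super().find_class(module, name)
        raise pickle.UnpicklingError(f"global '{module}.{name}' is not allowlisted")


def safe_loads(data: bytes):
    """Deserialize `data`, resolving only allowlisted globals."""
    return AllowlistUnpickler(io.BytesIO(data)).load()


# An allowlisted container holding plain numbers round-trips fine.
state = collections.OrderedDict(step=3, lr=0.1)
restored = safe_loads(pickle.dumps(state))

# A type outside the allowlist is rejected rather than imported.
try:
    safe_loads(pickle.dumps(fractions.Fraction(1, 3)))
    rejected = False
except pickle.UnpicklingError:
    rejected = True
```

In this sketch, the numpy dtype class in the traceback plays the same role as fractions.Fraction: a global the unpickler was never told to trust.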

Who can review?

Anyone in the community is free to review the PR once the tests have passed. Feel free to tag
members/contributors who may be interested in your PR.

@BenjaminBossan

@muellerzr muellerzr requested a review from SunMarc August 23, 2024 15:05
Member

@BenjaminBossan BenjaminBossan left a comment

Thanks for tackling this.

Honestly, I'm a bit surprised that by just allowing these few numpy types, everything can be loaded fine (except for the rng). I think it's super important to ensure that this will not lead to breakage.

src/accelerate/utils/other.py
Comment on lines 218 to 227
TORCH_SAFE_GLOBALS = [
    # numpy arrays are just numbers, not objects, so we can reconstruct them safely
    np.core.multiarray._reconstruct,
    np.ndarray,
]

@mikaylagawarecki mikaylagawarecki Aug 23, 2024

I believe these are the GLOBALS needed to load numpy arrays (though probably only a subset of the dtypes would be needed): https://github.com/pytorch/pytorch/pull/124763/files#diff-0a602e09e0dd231bf8cf8813744ca89b6ebbbb92ef54d356379d85e51f0113f1R119-R168

However, we made an explicit decision not to allowlist the globals required to rebuild numpy arrays by default in torch.

Member

@BenjaminBossan BenjaminBossan left a comment

Thanks for the updates. A bunch of tests are failing. I added a suggestion that should fix some, but I'm not sure if it'll fix all.

Generally, if the tests indicate that this switch will work fine, I'm okay with merging. Still, I'm wondering whether there should at least be a fail-safe that lets users switch to weights_only=False, as otherwise there is no way for them to circumvent this, short of downgrading accelerate.
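
One possible shape for such a fail-safe is sketched below. It is only an illustration under stated assumptions, not what accelerate ships: the environment variable name EXAMPLE_FORCE_UNSAFE_LOAD is hypothetical, and the real torch.load is replaced by an injected `loader` callable so the sketch runs without torch installed.

```python
import os

# Hypothetical environment variable name, used only for this sketch.
ENV_FLAG = "EXAMPLE_FORCE_UNSAFE_LOAD"


def load(f, loader, environ=None, **kwargs):
    """Call `loader` with weights_only=True unless the caller opts out."""
    environ = os.environ if environ is None else environ
    if environ.get(ENV_FLAG, "").lower() in {"1", "true", "yes"}:
        # Global escape hatch: force the old, permissive behavior.
        kwargs["weights_only"] = False
    else:
        # Safe default; an explicit weights_only=... from the caller still wins.
        kwargs.setdefault("weights_only", True)
    return loader(f, **kwargs)


def fake_loader(f, **kwargs):
    # Stand-in for torch.load that just reports the flags it received.
    return kwargs


default_flags = load("ckpt.pt", fake_loader, environ={})
explicit_flags = load("ckpt.pt", fake_loader, environ={}, weights_only=False)
env_flags = load("ckpt.pt", fake_loader, environ={ENV_FLAG: "1"})
```

Using setdefault keeps the safe behavior as the default while leaving a per-call override, and the environment variable gives users a switch that needs no code change.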

github-actions bot commented Oct 7, 2024

This issue has been automatically marked as stale because it has not had recent activity. If you think this still needs to be addressed please comment on this thread.

Please note that issues that do not follow the contributing guidelines are likely to be ignored.

@muellerzr muellerzr force-pushed the migrate-saving-away-from-torch-load branch from d2a7f0a to 49e6556 Compare October 10, 2024 16:53
Member

@BenjaminBossan BenjaminBossan left a comment

Change LGTM, thanks.

There is a CI issue caused by a numpy import. It might be worth checking a few numpy versions (especially < 2.0 and >= 2.0) to be sure that all the numpy imports work.
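
A common way to keep such imports working across versions is to try the newer module path first and fall back to the older one. The helper below is a hypothetical sketch (import_first_available is not part of accelerate or numpy). The numpy module paths in the comment are an assumption, namely that the private core package is numpy._core on numpy >= 2.0 and numpy.core before that.

```python
import importlib


def import_first_available(candidates):
    """Return the first importable module from `candidates`, in order."""
    errors = []
    for name in candidates:
        try:
            return importlib.import_module(name)
        except ImportError as exc:
            errors.append(f"{name}: {exc}")
    raise ImportError("none of the candidates could be imported: " + "; ".join(errors))


# With numpy installed, the call might look like this (assumed module paths):
#   multiarray = import_first_available(
#       ["numpy._core.multiarray", "numpy.core.multiarray"]
#   )

# Stdlib demonstration so the sketch runs without numpy installed.
module = import_first_available(["_no_such_module_for_demo", "json"])
```

Running the same lookup in CI against one numpy release below 2.0 and one at or above 2.0 would catch the kind of import breakage mentioned above.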

@muellerzr muellerzr merged commit 6f79b63 into main Oct 10, 2024
27 of 28 checks passed
@muellerzr muellerzr deleted the migrate-saving-away-from-torch-load branch October 10, 2024 18:08
dvrogozh added a commit to dvrogozh/transformers that referenced this pull request Nov 20, 2024
Starting from version 2.4, PyTorch introduces a stricter check for the objects which
can be loaded with torch.load(). Starting from version 2.6, loading with weights_only=True
requires allowlisting of such objects.

This commit adds an allowlist of some numpy objects used to load model checkpoints.
Usage is restricted by a context manager. Users can still additionally call
torch.serialization.add_safe_globals() to add other objects into the safe globals list.

The Accelerate library also ran into the same problem and addressed it with PR-3036.

Fixes: huggingface#34631
See: pytorch/pytorch#137602
See: https://pytorch.org/docs/stable/notes/serialization.html#torch.serialization.add_safe_globals
See: huggingface/accelerate#3036
Signed-off-by: Dmitry Rogozhkin <[email protected]>
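
The "restricted by context manager" idea in the commit message can be sketched with the standard library alone. This is an illustration, not the transformers code: _SAFE_GLOBALS and temporary_safe_globals are hypothetical names, and the module-level set merely stands in for the safe globals list that torch.serialization.add_safe_globals() extends.

```python
import contextlib

# Hypothetical process-wide allowlist, standing in for torch's safe globals list.
_SAFE_GLOBALS = set()


@contextlib.contextmanager
def temporary_safe_globals(extra):
    """Allowlist `extra` only for the duration of the with-block."""
    added = set(extra) - _SAFE_GLOBALS
    _SAFE_GLOBALS.update(added)
    try:
        yield
    finally:
        # Remove only what this block added, so pre-existing entries survive.
        _SAFE_GLOBALS.difference_update(added)


_SAFE_GLOBALS.add("collections.OrderedDict")  # a permanent, user-added entry

with temporary_safe_globals(["numpy.ndarray", "collections.OrderedDict"]):
    inside = set(_SAFE_GLOBALS)

after = set(_SAFE_GLOBALS)
```

Scoping the extra entries to the with-block keeps the wider process strict, while entries a user added permanently are left untouched on exit.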
dvrogozh added a commit to dvrogozh/transformers that referenced this pull request Nov 21, 2024
dvrogozh added a commit to dvrogozh/transformers that referenced this pull request Nov 21, 2024
dvrogozh added a commit to dvrogozh/transformers that referenced this pull request Nov 22, 2024
dvrogozh added a commit to dvrogozh/transformers that referenced this pull request Nov 22, 2024
ArthurZucker pushed a commit to huggingface/transformers that referenced this pull request Nov 25, 2024