Add safe_globals to resume training on PyTorch 2.6 #34632
Conversation
Force-pushed from 2cee855 to fa62472
@muellerzr, @SunMarc, @ArthurZucker: can you please help review this PR? See issue #34631 for details.
Nice! Thanks for adding this! Left a comment.
I am getting an error when running the tests against this PR.
@ydshieh: this might be due to the numpy version. `dtypes` was added in 1.25 according to https://numpy.org/doc/2.1/reference/routines.dtypes.html#module-numpy.dtypes. Locally I have 1.26.4; which version do you have? I will work on using the context manager, since there is alignment on that, and will also tune the list per numpy version.
On our CI runner, I get an older numpy version.
The numpy GLOBALs for dtypes that need to be allowlisted might need an if statement depending on whether the numpy version is < 1.25 or not; there's some documentation on this here: https://pytorch.org/docs/main/notes/serialization.html#troubleshooting-weights-only
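For illustration, a minimal sketch of that version gate (assuming `packaging` is available; this is not the exact PR code):

```python
import numpy as np
from packaging import version

# Base globals needed to unpickle numpy arrays from a checkpoint.
allowlist = [np.core.multiarray._reconstruct, np.ndarray, np.dtype]

# numpy.dtypes only exists for numpy >= 1.25, so gate the dtype class on it.
if version.parse(np.__version__) >= version.parse("1.25"):
    allowlist.append(np.dtypes.UInt32DType)
```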
cc @muellerzr if you can have a look as well!
src/transformers/trainer.py (Outdated)
```python
# See: https://pytorch.org/docs/stable/notes/serialization.html#torch.serialization.add_safe_globals
if version.parse(torch.__version__) >= version.parse("2.4.0"):
    torch.serialization.add_safe_globals(
        [np.core.multiarray._reconstruct, np.ndarray, np.dtype, np.dtypes.UInt32DType]
    )
```
We could have a `SAFE_TRANSFORMERS_GLOBAL` with these, no? That way people can easily update them. TBH I prefer the context manager, but I want to have as little duplication as possible!
I found that calling `torch.serialization.add_safe_globals()` still works to add additional safe globals. `SAFE_TRANSFORMERS_GLOBAL` can also be considered; let me know if you see the need.
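A minimal sketch of that user-side escape hatch, assuming PyTorch >= 2.4 (the custom class here is hypothetical):

```python
import torch.serialization

class MyCheckpointMeta:  # hypothetical user-defined class stored in a checkpoint
    pass

# Users can register their own globals on top of whatever the library allowlists.
torch.serialization.add_safe_globals([MyCheckpointMeta])
```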
Force-pushed from fa62472 to 276a3a0
```python
allowlist = [np.core.multiarray._reconstruct, np.ndarray, np.dtype]
# numpy >1.25 defines numpy.dtypes.UInt32DType, but below works for
# all versions of numpy
allowlist += [type(np.dtype(np.uint32))]
```
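A quick way to see why this trick works (a sketch; the printed class is what numpy >= 1.25 reports):

```python
import numpy as np

# Resolves the concrete dtype class without touching the numpy.dtypes module,
# so it works on numpy versions both before and after 1.25.
print(type(np.dtype(np.uint32)))  # <class 'numpy.dtypes.UInt32DType'> on numpy >= 1.25
```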
Should I add any other numpy dtypes to the list? As of now I spotted only `np.uint32` in the Transformers list as the one needed.
The only one I don't see from accelerate is `encode`; however, if things pass here without it, it's accelerate-specific and we don't need to worry about it.
Transformers tests did pass on my side without adding `encode`. This indeed seems accelerate-specific.
Force-pushed from 276a3a0 to 4273a30
Thanks! Just a documentation suggestion, but this all looks correct.
The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.
Force-pushed from 4273a30 to 468aa06
@muellerzr: done, added a link to the Accelerate PR.
LGTM! Just a nit.
Force-pushed from 468aa06 to 27c307f
@SunMarc: addressed, reused the approach from accelerate.
Force-pushed from 27c307f to dbb3112
src/transformers/trainer.py (Outdated)
```python
# See: https://github.com/pytorch/pytorch/pull/137602
# See: https://pytorch.org/docs/stable/notes/serialization.html#torch.serialization.add_safe_globals
# See: https://github.com/huggingface/accelerate/pull/3036
if version.parse(torch.__version__) < version.parse("2.4.0"):
```
Just a nit: should it be `"2.6.0"` here, or is `"2.4.0"` really necessary?
Switched to `version < 2.6.0a0`. Indeed, on switching to the context manager I overlooked that it was introduced later. Overall:

- `torch.serialization.add_safe_globals` appeared in PyTorch 2.4
- `torch.serialization.safe_globals` (context manager) appeared in 2.5
- PyTorch 2.6 flipped the default of `weights_only` in `torch.load` from `False` to `True`

Overall, it indeed does not make sense to have this code work for versions earlier than 2.6 unless we start calling `torch.load` with an explicit `weights_only=True`.
Hi! A tiny question: how do I get `2.6.0a0` installed? I know how to install the nightly, but it gets `dev202411xx` instead of `a0`.
Anyway, it's good to use `a0` here for now. Once 2.6 is released, we can change it to `2.6`.
> Hi! A tiny question: how to get 2.6.0a0 installed.

I am getting this by building from source, and `<2.6.0` does not work for me on my build, so `2.6.0a0` is my best effort to get the check working for my current build. I did not know that nightly builds get `dev202411xx`; I thought they also gave `a0`. I wonder whether the check will still work for nightly?
I checked: `<2.6.0a0` won't work with nightly. So I switched to a check I once spotted in code by Narsil. This should handle both cases, building from source and using the 2.6 nightly (I checked; it works for both on my side):

```python
if version.parse(torch.__version__).release < version.parse("2.6").release:
```
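A quick check of why comparing `.release` tuples handles both build flavors (the nightly version string here is illustrative):

```python
from packaging import version

# .release strips pre-release/dev suffixes, so both builds compare as (2, 6, 0).
print(version.parse("2.6.0a0").release)            # (2, 6, 0)
print(version.parse("2.6.0.dev20241112").release)  # (2, 6, 0), illustrative nightly
print(version.parse("2.6.0a0").release < version.parse("2.6").release)  # False
```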
Force-pushed from dbb3112 to 0505f2c
Thanks
Force-pushed from 0505f2c to 820ca4a
Thanks for fixing 🤗
Starting from version 2.4, PyTorch introduces a stricter check for the objects that can be loaded with `torch.load()`. Starting from version 2.6, loading with `weights_only=True` requires allowlisting such objects.

This commit adds an allowlist of some `numpy` objects used to load model checkpoints. Usage is restricted by a context manager. Users can still call `torch.serialization.add_safe_globals()` to add other objects to the safe globals list.

The Accelerate library also ran into the same problem and addressed it with PR 3036.
Fixes: #34631
See: pytorch/pytorch#137602
See: https://pytorch.org/docs/stable/notes/serialization.html#torch.serialization.add_safe_globals
See: huggingface/accelerate#3036
CC: @muellerzr @SunMarc
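Putting the pieces from this thread together, a minimal sketch of the resulting pattern (the helper name and load site are illustrative, not the exact `trainer.py` code):

```python
import numpy as np
import torch
from packaging import version

def _load_rng_state(path):  # hypothetical helper name
    if version.parse(torch.__version__).release >= version.parse("2.6").release:
        # PyTorch >= 2.6 defaults torch.load to weights_only=True, so the numpy
        # globals inside the checkpoint must be allowlisted for this load only.
        allowlist = [np.core.multiarray._reconstruct, np.ndarray, np.dtype,
                     type(np.dtype(np.uint32))]
        with torch.serialization.safe_globals(allowlist):
            return torch.load(path)
    return torch.load(path)
```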