Fixed warning and added safe globals #423

Merged
merged 7 commits into from
Aug 24, 2024


NeelKondapalli
Contributor

@NeelKondapalli NeelKondapalli commented Jul 31, 2024

Fixes #422 by adding the weights_only=True argument to torch.load in io.py. This addresses the warning about loading arbitrary pickled data. The types stype and StatType were added to the safe globals list.

By: Neel Kondapalli

Member

@akihironitta akihironitta left a comment

Thanks for quickly working on this issue! Looks like the CI is failing. Let me know if you need any help!

@akihironitta akihironitta marked this pull request as draft July 31, 2024 17:11
@akihironitta akihironitta changed the title from "Issue 422: Fixed warning and added safe globals" to "Fixed warning and added safe globals" Jul 31, 2024
@NeelKondapalli
Contributor Author

The fix works on a newer PyTorch version; I will investigate a fix for v2.2. The problem is telling PyTorch that the custom classes are safe for unpickling.

@NeelKondapalli
Contributor Author

NeelKondapalli commented Aug 1, 2024

@akihironitta It seems that the warning in #422 is not reproduced with Torch v2.2.0 in the CI test logs. The warning exists because unpickling can execute unsafe code. Torch v2.4 is the only version where I was able to reproduce it. Torch v2.4 also contains the solution: adding the custom classes stype and StatType to the safe list via the torch.serialization.add_safe_globals method. This method does not exist in v2.2, which is why CI failed. What is the best next step? Thanks!

pytorch/pytorch#129239 is where the warning was introduced.

@akihironitta
Member

Hey @NeelKondapalli, thanks again for this PR!

To keep supporting older PyTorch versions, we could try having a flag like this https://github.com/pyg-team/pytorch_geometric/blob/8c849a482c3cf2326c1f493e79d04169b26dfb0b/torch_geometric/typing.py#L12:

if WITH_PT24:
    torch.serialization.add_safe_globals(...)

If feasible, we'd like to enable weights_only, but do you know if it's possible? (I haven't had a look at this new PyTorch API, but I will this weekend!)
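
For illustration, the version flag suggested above could be sketched roughly as follows. This is a minimal sketch, not the code that was merged: `parse_major_minor`, `with_pt24`, and `register_safe_globals` are hypothetical helper names, and the `serialization` argument stands in for `torch.serialization` so the logic can be read without PyTorch installed.

```python
import re


def parse_major_minor(version: str) -> tuple:
    """Extract (major, minor) from a version string such as '2.4.0+cu121'."""
    match = re.match(r"(\d+)\.(\d+)", version)
    if match is None:
        raise ValueError(f"unrecognized version string: {version!r}")
    return int(match.group(1)), int(match.group(2))


def with_pt24(version: str) -> bool:
    """Hypothetical flag: True when the version is 2.4 or newer."""
    return parse_major_minor(version) >= (2, 4)


def register_safe_globals(serialization, safe_types, version: str) -> bool:
    """Register safe_types only when the given version supports it.

    `serialization` is assumed to expose add_safe_globals(list), as
    torch.serialization does from v2.4 according to the discussion above.
    Returns True if the types were registered, False if skipped.
    """
    if with_pt24(version):
        serialization.add_safe_globals(list(safe_types))
        return True
    return False


# Assumed real-world call (requires PyTorch and the project's own types):
# register_safe_globals(torch.serialization, [stype, StatType], torch.__version__)
```

On versions older than 2.4 the helper skips registration and returns False, so the caller still has to decide how to load files there.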

@NeelKondapalli
Contributor Author

NeelKondapalli commented Aug 1, 2024

Hi @akihironitta

Adding a version specific flag like that is doable.

From what I can see, weights_only strictly unpickles only "tensors, primitive types, dictionaries and any types added via torch.serialization.add_safe_globals()".

Since torch.serialization.add_safe_globals is not available in older torch versions, turning on weights_only there allows only tensors, primitive types, and dictionaries, so stype and StatType cannot be loaded during unpickling. Errors will therefore be raised on older torch versions, because stype and StatType are not among the allowed types and cannot be added to the "safe" list.

The only way I can think of at the moment to handle this on older torch versions is to use a custom unpickler when loading the model (instead of Torch's internal load method), configured to accept only stype and StatType in addition to primitive types, tensors, and dictionaries. However, this does not sound like an efficient or necessary solution.
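
For readers unfamiliar with the custom-unpickler idea mentioned above, here is a minimal sketch using only the standard library. It is not what the PR implemented. `ALLOWED_GLOBALS`, `AllowListUnpickler`, and `restricted_loads` are hypothetical names, and `collections.OrderedDict` is only a stand-in for project types such as stype and StatType.

```python
import collections
import fractions
import io
import pickle

# Hypothetical allow-list of (module, name) pairs. In the PR's setting the
# entries would name stype and StatType; OrderedDict is a stand-in so the
# sketch runs with the standard library alone.
ALLOWED_GLOBALS = {("collections", "OrderedDict")}


class AllowListUnpickler(pickle.Unpickler):
    """Unpickler that resolves only globals listed in ALLOWED_GLOBALS."""

    def find_class(self, module, name):
        if (module, name) in ALLOWED_GLOBALS:
            return super().find_class(module, name)
        raise pickle.UnpicklingError(f"global {module}.{name} is not allowed")


def restricted_loads(data: bytes):
    """Deserialize bytes, rejecting any global outside the allow-list."""
    return AllowListUnpickler(io.BytesIO(data)).load()


# An allow-listed type round-trips normally.
allowed = restricted_loads(pickle.dumps(collections.OrderedDict(a=1)))

# A type outside the allow-list is rejected before it is constructed.
try:
    restricted_loads(pickle.dumps(fractions.Fraction(1, 2)))
    blocked = False
except pickle.UnpicklingError:
    blocked = True
```

Primitive values, lists, and plain dictionaries never reach `find_class`, so they load without being listed.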

Member

@akihironitta akihironitta left a comment

Sorry for the delay. Having a look today.

Member

@akihironitta akihironitta left a comment

Looks great! Thanks again :)

I feel your 4ce8e5c is the right thing to do. Whenever possible, we should set weights_only=True, but we can disable it in other cases.
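
The policy described above (enable weights_only where possible, disable it otherwise) might look something like the sketch below. It is an illustration only and does not claim to reproduce commit 4ce8e5c: `load_with_fallback` and its `with_pt24` argument are hypothetical, `load_fn` stands in for `torch.load`, and it assumes the installed PyTorch accepts the `weights_only` keyword.

```python
def load_with_fallback(load_fn, path, with_pt24: bool, **kwargs):
    """Call load_fn with weights_only enabled only when it is usable.

    load_fn stands in for torch.load. with_pt24 is a hypothetical flag that
    is True when custom types can be registered as safe globals, so strict
    loading will not reject them.
    """
    return load_fn(path, weights_only=with_pt24, **kwargs)


# Assumed real-world call (requires PyTorch):
# state = load_with_fallback(torch.load, "model.pt", WITH_PT24, map_location="cpu")
```

Taking the flag as a parameter keeps the version check in one place and lets the helper be tested with a fake loader.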

@akihironitta akihironitta marked this pull request as ready for review August 24, 2024 15:56
@akihironitta akihironitta merged commit 3710d2f into pyg-team:master Aug 24, 2024
14 checks passed
@akihironitta
Member

@NeelKondapalli This is great! Thank you again for making your first contribution to PyTorch Frame! 🚀

@neelkkondapalli

Thank you! Awesome we could work it out.


Successfully merging this pull request may close these issues.

Use torch.load(weights_only=True) in the codebase
3 participants