Skip to content

Commit

Permalink
Fail torch.load(weights=True) gracefully (#448)
Browse files Browse the repository at this point in the history
  • Loading branch information
akihironitta authored Sep 9, 2024
1 parent 546f1a2 commit a3b73c4
Show file tree
Hide file tree
Showing 2 changed files with 49 additions and 5 deletions.
22 changes: 21 additions & 1 deletion test/utils/test_io.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,10 @@
import shutil
import tempfile

import pytest

import torch_frame
from torch_frame import load, save
from torch_frame import TensorFrame, load, save
from torch_frame.config.text_embedder import TextEmbedderConfig
from torch_frame.config.text_tokenizer import TextTokenizerConfig
from torch_frame.datasets import FakeDataset
Expand Down Expand Up @@ -114,3 +116,21 @@ def test_save_load_tensor_frame():
tf, col_stats = load(path)
assert dataset.col_stats == col_stats
assert dataset.tensor_frame == tf


class UntrustedClass:
    """Empty placeholder type used as a ``col_stats`` value in the test below.

    The test expects loading a file containing an instance of this class to
    emit a 'Weights only load failed' warning.
    """


@pytest.mark.skipif(
    not torch_frame.typing.WITH_PT24,
    reason='Requires PyTorch 2.4',
)
def test_load_weights_only_gracefully(tmpdir):
    """Check that ``load`` warns and still succeeds on a non-allowlisted type.

    A :class:`TensorFrame` is saved together with ``col_stats`` holding an
    :class:`UntrustedClass` instance. Loading it is expected to emit a
    ``UserWarning`` matching 'Weights only load failed' and still return the
    stored ``col_stats``.
    """
    # Build the path once so that save and load are guaranteed to agree.
    path = tmpdir.join('tf.pt')
    save(
        tensor_frame=TensorFrame({}, {}),
        col_stats={'a': UntrustedClass()},
        path=path,
    )
    with pytest.warns(UserWarning, match='Weights only load failed'):
        _, col_stats = load(path)

    # The warning alone does not prove the fallback load worked; make sure
    # the pickled object actually came back.
    assert isinstance(col_stats['a'], UntrustedClass)
32 changes: 28 additions & 4 deletions torch_frame/utils/io.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,8 @@
from __future__ import annotations

import pickle
import re
import warnings
from typing import Any

import torch
Expand All @@ -13,7 +16,7 @@
)
from torch_frame.data.multi_tensor import _MultiTensor
from torch_frame.data.stats import StatType
from torch_frame.typing import WITH_PT24, TensorData
from torch_frame.typing import TensorData


def serialize_feat_dict(
Expand Down Expand Up @@ -96,9 +99,30 @@ def load(
tuple: A tuple of loaded :class:`TensorFrame` object and
optional :obj:`col_stats`.
"""
tf_dict, col_stats = torch.load(path, weights_only=WITH_PT24)
if torch_frame.typing.WITH_PT24:
try:
tf_dict, col_stats = torch.load(path, weights_only=True)
except pickle.UnpicklingError as e:
error_msg = str(e)
if "add_safe_globals" in error_msg:
warn_msg = ("Weights only load failed. Please file an issue "
"to make `torch.load(weights_only=True)` "
"compatible in your case.")
match = re.search(r'add_safe_globals\(.*?\)', error_msg)
if match is not None:
warnings.warn(f"{warn_msg} Please use "
f"`torch.serialization.{match.group()}` to "
f"allowlist this global.")
else:
warnings.warn(warn_msg)

tf_dict, col_stats = torch.load(path, weights_only=False)
else:
raise e
else:
tf_dict, col_stats = torch.load(path, weights_only=False)

tf_dict['feat_dict'] = deserialize_feat_dict(
tf_dict.pop('feat_serialized_dict'))
tensor_frame = TensorFrame(**tf_dict)
tensor_frame.to(device)
tensor_frame = TensorFrame(**tf_dict).to(device)
return tensor_frame, col_stats

0 comments on commit a3b73c4

Please sign in to comment.