From 8b209bb7f9af5a82aec9a43aeb59ded5e4c65604 Mon Sep 17 00:00:00 2001 From: Akihiro Nitta Date: Sat, 17 Aug 2024 11:16:18 +0000 Subject: [PATCH] Add pin_memory support --- CHANGELOG.md | 1 + test/data/test_multi_embedding_tensor.py | 12 ++++++++++++ test/data/test_multi_nested_tensor.py | 18 ++++++++++++++++-- test/data/test_tensor_frame.py | 15 +++++++++++++++ torch_frame/data/multi_tensor.py | 6 ++++++ torch_frame/data/tensor_frame.py | 11 +++++++++++ 6 files changed, 61 insertions(+), 2 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 9bbd92324..88579e87a 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -8,6 +8,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). ### Added - Added support for inferring `stype.categorical` from boolean columns in `utils.infer_series_stype` ([#421](https://github.com/pyg-team/pytorch-frame/pull/421)) +- Added support for `pin_memory()` to `TensorFrame`, `MultiEmbeddingTensor`, and `MultiNestedTensor` ([#437](https://github.com/pyg-team/pytorch-frame/pull/437)) ### Changed diff --git a/test/data/test_multi_embedding_tensor.py b/test/data/test_multi_embedding_tensor.py index 2752bc0c2..5b67407a6 100644 --- a/test/data/test_multi_embedding_tensor.py +++ b/test/data/test_multi_embedding_tensor.py @@ -476,3 +476,15 @@ def test_cat(device): # case: list of non-MultiEmbeddingTensor should raise error with pytest.raises(AssertionError): MultiEmbeddingTensor.cat([object()], dim=0) + + +def test_pin_memory(): + met, _ = get_fake_multi_embedding_tensor( + num_rows=2, + num_cols=3, + ) + assert not met.values.is_pinned() + assert not met.offset.is_pinned() + met = met.pin_memory() + assert met.values.is_pinned() + assert met.offset.is_pinned() diff --git a/test/data/test_multi_nested_tensor.py b/test/data/test_multi_nested_tensor.py index 62594ac7e..166336d3e 100644 --- a/test/data/test_multi_nested_tensor.py +++ b/test/data/test_multi_nested_tensor.py @@ -87,7 +87,7 @@ def test_fillna_col(): 
@withCUDA -def test_multi_nested_tensor_basics(device): +def test_basics(device): num_rows = 8 num_cols = 10 max_value = 100 @@ -317,7 +317,7 @@ def test_multi_nested_tensor_basics(device): cloned_multi_nested_tensor) -def test_multi_nested_tensor_different_num_rows(): +def test_different_num_rows(): tensor_mat = [ [torch.tensor([1, 2, 3]), torch.tensor([4, 5])], @@ -331,3 +331,17 @@ def test_multi_nested_tensor_different_num_rows(): match="The length of each row must be the same", ): MultiNestedTensor.from_tensor_mat(tensor_mat) + + +def test_pin_memory(): num_rows = 10 num_cols = 3 tensor = MultiNestedTensor.from_tensor_mat( [[torch.randn(random.randint(0, 10)) for _ in range(num_cols)] for _ in range(num_rows)]) + assert not tensor.values.is_pinned() assert not tensor.offset.is_pinned() tensor = tensor.pin_memory() assert tensor.values.is_pinned() assert tensor.offset.is_pinned() diff --git a/test/data/test_tensor_frame.py b/test/data/test_tensor_frame.py index a713fde7f..7c719c055 100644 --- a/test/data/test_tensor_frame.py +++ b/test/data/test_tensor_frame.py @@ -230,3 +230,18 @@ def test_custom_tf_get_col_feat(): assert torch.equal(feat, feat_dict['numerical'][:, 0:1]) feat = tf.get_col_feat('num_2') assert torch.equal(feat, feat_dict['numerical'][:, 1:2]) + + +def test_pin_memory(get_fake_tensor_frame): def assert_is_pinned(tf: TensorFrame, expected: bool) -> None: for value in tf.feat_dict.values(): if isinstance(value, dict): for v in value.values(): assert v.is_pinned() is expected else: assert value.is_pinned() is expected + tf = get_fake_tensor_frame(10) assert_is_pinned(tf, expected=False) tf = tf.pin_memory() assert_is_pinned(tf, expected=True) diff --git a/torch_frame/data/multi_tensor.py b/torch_frame/data/multi_tensor.py index ba2dd3090..0d5c6ef80 100644 --- a/torch_frame/data/multi_tensor.py +++ b/torch_frame/data/multi_tensor.py @@ -93,6 +93,12 @@ def cpu(self, *args, **kwargs): def cuda(self, *args,
**kwargs): return self._apply(lambda x: x.cuda(*args, **kwargs)) + def pin_memory(self, *args, **kwargs): + return self._apply(lambda x: x.pin_memory(*args, **kwargs)) + + def is_pinned(self) -> bool: + return self.values.is_pinned() + # Helper Functions ######################################################## def _apply(self, fn: Callable[[Tensor], Tensor]) -> _MultiTensor: diff --git a/torch_frame/data/tensor_frame.py b/torch_frame/data/tensor_frame.py index 62d999f1b..11358def0 100644 --- a/torch_frame/data/tensor_frame.py +++ b/torch_frame/data/tensor_frame.py @@ -340,6 +340,17 @@ def fn(x): return self._apply(fn) + def pin_memory(self, *args, **kwargs): + def fn(x): + if isinstance(x, dict): + for key in x: + x[key] = x[key].pin_memory(*args, **kwargs) + else: + x = x.pin_memory(*args, **kwargs) + return x + + return self._apply(fn) + # Helper Functions ######################################################## def _apply(self, fn: Callable[[TensorData], TensorData]) -> TensorFrame: