diff --git a/CHANGELOG.md b/CHANGELOG.md
index 9bbd9232..91513a03 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -8,6 +8,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 ### Added
 
 - Added support for inferring `stype.categorical` from boolean columns in `utils.infer_series_stype` ([#421](https://github.com/pyg-team/pytorch-frame/pull/421))
+- Added `pin_memory()` to `TensorFrame`, `MultiEmbeddingTensor`, and `MultiNestedTensor` ([#437](https://github.com/pyg-team/pytorch-frame/pull/437))
 
 ### Changed
 
diff --git a/test/data/test_multi_embedding_tensor.py b/test/data/test_multi_embedding_tensor.py
index 2752bc0c..36774f62 100644
--- a/test/data/test_multi_embedding_tensor.py
+++ b/test/data/test_multi_embedding_tensor.py
@@ -476,3 +476,17 @@ def test_cat(device):
     # case: list of non-MultiEmbeddingTensor should raise error
     with pytest.raises(AssertionError):
         MultiEmbeddingTensor.cat([object()], dim=0)
+
+
+def test_pin_memory():
+    met, _ = get_fake_multi_embedding_tensor(
+        num_rows=2,
+        num_cols=3,
+    )
+    assert not met.is_pinned()
+    assert not met.values.is_pinned()
+    assert not met.offset.is_pinned()
+    met = met.pin_memory()
+    assert met.is_pinned()
+    assert met.values.is_pinned()
+    assert met.offset.is_pinned()
diff --git a/test/data/test_multi_nested_tensor.py b/test/data/test_multi_nested_tensor.py
index 62594ac7..cb394a91 100644
--- a/test/data/test_multi_nested_tensor.py
+++ b/test/data/test_multi_nested_tensor.py
@@ -87,7 +87,7 @@ def test_fillna_col():
 
 
 @withCUDA
-def test_multi_nested_tensor_basics(device):
+def test_basics(device):
     num_rows = 8
     num_cols = 10
     max_value = 100
@@ -317,7 +317,7 @@ def test_multi_nested_tensor_basics(device):
                             cloned_multi_nested_tensor)
 
 
-def test_multi_nested_tensor_different_num_rows():
+def test_different_num_rows():
     tensor_mat = [
         [torch.tensor([1, 2, 3]),
          torch.tensor([4, 5])],
@@ -331,3 +331,19 @@
         match="The length of each row must be the same",
     ):
         MultiNestedTensor.from_tensor_mat(tensor_mat)
+
+
+def test_pin_memory():
+    num_rows = 10
+    num_cols = 3
+    tensor = MultiNestedTensor.from_tensor_mat(
+        [[torch.randn(random.randint(0, 10)) for _ in range(num_cols)]
+         for _ in range(num_rows)])
+
+    assert not tensor.is_pinned()
+    assert not tensor.values.is_pinned()
+    assert not tensor.offset.is_pinned()
+    tensor = tensor.pin_memory()
+    assert tensor.is_pinned()
+    assert tensor.values.is_pinned()
+    assert tensor.offset.is_pinned()
diff --git a/test/data/test_tensor_frame.py b/test/data/test_tensor_frame.py
index a713fde7..7c719c05 100644
--- a/test/data/test_tensor_frame.py
+++ b/test/data/test_tensor_frame.py
@@ -230,3 +230,18 @@ def test_custom_tf_get_col_feat():
     assert torch.equal(feat, feat_dict['numerical'][:, 0:1])
     feat = tf.get_col_feat('num_2')
     assert torch.equal(feat, feat_dict['numerical'][:, 1:2])
+
+
+def test_pin_memory(get_fake_tensor_frame):
+    def assert_is_pinned(tf: TensorFrame, expected: bool) -> None:
+        for value in tf.feat_dict.values():
+            if isinstance(value, dict):
+                for v in value.values():
+                    assert v.is_pinned() is expected
+            else:
+                assert value.is_pinned() is expected
+
+    tf = get_fake_tensor_frame(10)
+    assert_is_pinned(tf, expected=False)
+    tf = tf.pin_memory()
+    assert_is_pinned(tf, expected=True)
diff --git a/torch_frame/data/multi_tensor.py b/torch_frame/data/multi_tensor.py
index ba2dd309..0d5c6ef8 100644
--- a/torch_frame/data/multi_tensor.py
+++ b/torch_frame/data/multi_tensor.py
@@ -93,6 +93,12 @@ def cpu(self, *args, **kwargs):
     def cuda(self, *args, **kwargs):
         return self._apply(lambda x: x.cuda(*args, **kwargs))
 
+    def pin_memory(self, *args, **kwargs):
+        return self._apply(lambda x: x.pin_memory(*args, **kwargs))
+
+    def is_pinned(self) -> bool:
+        return self.values.is_pinned()
+
     # Helper Functions ########################################################
 
     def _apply(self, fn: Callable[[Tensor], Tensor]) -> _MultiTensor:
diff --git a/torch_frame/data/tensor_frame.py b/torch_frame/data/tensor_frame.py
index 62d999f1..11358def 100644
--- a/torch_frame/data/tensor_frame.py
+++ b/torch_frame/data/tensor_frame.py
@@ -340,6 +340,17 @@ def fn(x):
 
         return self._apply(fn)
 
+    def pin_memory(self, *args, **kwargs):
+        def fn(x):
+            if isinstance(x, dict):
+                for key in x:
+                    x[key] = x[key].pin_memory(*args, **kwargs)
+            else:
+                x = x.pin_memory(*args, **kwargs)
+            return x
+
+        return self._apply(fn)
+
     # Helper Functions ########################################################
 
     def _apply(self, fn: Callable[[TensorData], TensorData]) -> TensorFrame:
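
For reviewers, a minimal usage sketch of the API this patch adds (not part of the diff itself). It assumes a CUDA-enabled PyTorch build, since pinning page-locked host memory requires CUDA, and it assumes `MultiNestedTensor` is importable from `torch_frame.data` as in the test files. Because `pin_memory()` is routed through the same `_apply` helper as `cpu()`/`cuda()`, a pinned multi-tensor can then be moved to the GPU with `non_blocking=True` for an asynchronous host-to-device copy.

```python
import torch

from torch_frame.data import MultiNestedTensor

# Build a small 2-row x 2-column MultiNestedTensor from a tensor matrix.
tensor = MultiNestedTensor.from_tensor_mat(
    [[torch.randn(3), torch.randn(2)],
     [torch.randn(1), torch.randn(4)]])

if torch.cuda.is_available():
    # Pin the underlying `values`/`offset` buffers in page-locked host memory.
    tensor = tensor.pin_memory()
    assert tensor.is_pinned()
    # Pinned host memory allows the copy to the GPU to run asynchronously.
    tensor = tensor.cuda(non_blocking=True)
```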