Skip to content

Commit

Permalink
Add pin_memory support
Browse files Browse the repository at this point in the history
  • Loading branch information
akihironitta committed Aug 17, 2024
1 parent 59994ec commit e1b98de
Show file tree
Hide file tree
Showing 6 changed files with 65 additions and 2 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
### Added

- Added support for inferring `stype.categorical` from boolean columns in `utils.infer_series_stype` ([#421](https://github.com/pyg-team/pytorch-frame/pull/421))
- Added `pin_memory()` to `TensorFrame`, `MultiEmbeddingTensor`, and `MultiNestedTensor` ([#437](https://github.com/pyg-team/pytorch-frame/pull/437))

### Changed

Expand Down
14 changes: 14 additions & 0 deletions test/data/test_multi_embedding_tensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -476,3 +476,17 @@ def test_cat(device):
# case: list of non-MultiEmbeddingTensor should raise error
with pytest.raises(AssertionError):
MultiEmbeddingTensor.cat([object()], dim=0)


def test_pin_memory():
met, _ = get_fake_multi_embedding_tensor(
num_rows=2,
num_cols=3,
)
assert not met.is_pinned()
assert not met.values.is_pinned()
assert not met.offset.is_pinned()
met = met.pin_memory()
assert met.is_pinned()
assert met.values.is_pinned()
assert met.offset.is_pinned()
20 changes: 18 additions & 2 deletions test/data/test_multi_nested_tensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -87,7 +87,7 @@ def test_fillna_col():


@withCUDA
def test_multi_nested_tensor_basics(device):
def test_basics(device):
num_rows = 8
num_cols = 10
max_value = 100
Expand Down Expand Up @@ -317,7 +317,7 @@ def test_multi_nested_tensor_basics(device):
cloned_multi_nested_tensor)


def test_multi_nested_tensor_different_num_rows():
def test_different_num_rows():
tensor_mat = [
[torch.tensor([1, 2, 3]),
torch.tensor([4, 5])],
Expand All @@ -331,3 +331,19 @@ def test_multi_nested_tensor_different_num_rows():
match="The length of each row must be the same",
):
MultiNestedTensor.from_tensor_mat(tensor_mat)


def test_pin_memory():
num_rows = 10
num_cols = 3
tensor = MultiNestedTensor.from_tensor_mat(
[[torch.randn(random.randint(0, 10)) for _ in range(num_cols)]
for _ in range(num_rows)])

assert not tensor.is_pinned()
assert not tensor.values.is_pinned()
assert not tensor.offset.is_pinned()
tensor = tensor.pin_memory()
assert tensor.is_pinned()
assert tensor.values.is_pinned()
assert tensor.offset.is_pinned()
15 changes: 15 additions & 0 deletions test/data/test_tensor_frame.py
Original file line number Diff line number Diff line change
Expand Up @@ -230,3 +230,18 @@ def test_custom_tf_get_col_feat():
assert torch.equal(feat, feat_dict['numerical'][:, 0:1])
feat = tf.get_col_feat('num_2')
assert torch.equal(feat, feat_dict['numerical'][:, 1:2])


def test_pin_memory(get_fake_tensor_frame):
def assert_is_pinned(tf: TensorFrame, expected: bool) -> bool:
for value in tf.feat_dict.values():
if isinstance(value, dict):
for v in value.values():
assert v.is_pinned() is expected
else:
assert value.is_pinned() is expected

tf = get_fake_tensor_frame(10)
assert_is_pinned(tf, expected=False)
tf = tf.pin_memory()
assert_is_pinned(tf, expected=True)
6 changes: 6 additions & 0 deletions torch_frame/data/multi_tensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -93,6 +93,12 @@ def cpu(self, *args, **kwargs):
def cuda(self, *args, **kwargs):
return self._apply(lambda x: x.cuda(*args, **kwargs))

def pin_memory(self, *args, **kwargs):
return self._apply(lambda x: x.pin_memory(*args, **kwargs))

def is_pinned(self) -> bool:
return self.values.is_pinned()

# Helper Functions ########################################################

def _apply(self, fn: Callable[[Tensor], Tensor]) -> _MultiTensor:
Expand Down
11 changes: 11 additions & 0 deletions torch_frame/data/tensor_frame.py
Original file line number Diff line number Diff line change
Expand Up @@ -340,6 +340,17 @@ def fn(x):

return self._apply(fn)

def pin_memory(self, *args, **kwargs):
def fn(x):
if isinstance(x, dict):
for key in x:
x[key] = x[key].pin_memory(*args, **kwargs)
else:
x = x.pin_memory(*args, **kwargs)
return x

return self._apply(fn)

# Helper Functions ########################################################

def _apply(self, fn: Callable[[TensorData], TensorData]) -> TensorFrame:
Expand Down

0 comments on commit e1b98de

Please sign in to comment.