Skip to content

Commit

Permalink
Support cuda tensor in vineyard.
Browse files Browse the repository at this point in the history
Signed-off-by: Ye Cao <[email protected]>
  • Loading branch information
dashanji committed Nov 18, 2024
1 parent 0f78867 commit fb2d8ff
Showing 1 changed file with 10 additions and 1 deletion.
11 changes: 10 additions & 1 deletion python/vineyard/contrib/ml/torch.py
Original file line number Diff line number Diff line change
Expand Up @@ -145,7 +145,12 @@ def torch_tensor_builder(client, value, **kw):

meta['typename'] = 'vineyard::Tensor<%s>' % str(value.dtype)
meta['value_type_'] = str(value.dtype)
meta.add_member('buffer_', build_torch_buffer(client, value))
if value.is_cuda:
meta['device_'] = str(value.device)
value_in_cpu = value.to('cpu')
meta.add_member('buffer_', build_torch_buffer(client, value_in_cpu))
else:
meta.add_member('buffer_', build_torch_buffer(client, value))

return client.create_metadata(meta)

Expand All @@ -157,6 +162,7 @@ def torch_tensor_resolver(obj):
value_type = normalize_tensor_dtype(value_type_name)
shape = from_json(meta['shape_'])
order = from_json(meta.get('order_', 'C'))
device = meta.get('device_', 'cpu')

if np.prod(shape) == 0:
return torch.zeros(shape, dtype=value_type)
Expand All @@ -167,6 +173,9 @@ def torch_tensor_resolver(obj):

c_tensor = torch.frombuffer(buffer, dtype=value_type).reshape(shape)

Check warning on line 174 in python/vineyard/contrib/ml/torch.py

View workflow job for this annotation

GitHub Actions / Spell Check with Typos

Unknown word (frombuffer)
tensor = c_tensor if order == 'C' else c_tensor.contiguous()
if "cuda" in device:
cuda_device = torch.device(device)
tensor = c_tensor.to(cuda_device)

return tensor

Expand Down

0 comments on commit fb2d8ff

Please sign in to comment.