Support bfloat16 for PyTorch modules. (#1872)
Fixes #1871

Signed-off-by: Ye Cao <[email protected]>
dashanji authored Apr 19, 2024
1 parent 5d3c5f6 commit 22424b6
Showing 1 changed file with 10 additions and 0 deletions.
10 changes: 10 additions & 0 deletions python/vineyard/contrib/ml/torch.py
@@ -215,7 +215,17 @@ def put_torch_tensors(client, tensors) -> List[Union[ObjectID, ObjectMeta]]:
     blobs = client.create_remote_blob(blob_writers)

     metadatas = []
+    found_bfloat16 = False
     for tensor, size, blob in zip(tensors, sizes, blobs):
+        if tensor.dtype == torch.bfloat16:
+            if not found_bfloat16:
+                warnings.warn(
+                    "Important, bfloat16 is not supported by vineyard, "
+                    "converting to float16 instead, which may cause precision loss."
+                )
+                found_bfloat16 = True
+            tensor = tensor.to(torch.float16)
+
         value = tensor.numpy()
         meta = ObjectMeta()
         meta['typename'] = 'vineyard::Tensor<%s>' % normalize_cpptype(value.dtype)
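
For context: NumPy has no bfloat16 dtype, so the tensor.numpy() call in the loop above would fail for bfloat16 tensors; after this commit they are downcast to float16, with a one-time warning about the precision loss. Below is a minimal usage sketch of the resulting behavior. It assumes a running vineyardd instance reachable via vineyard.connect() and the torch_context() helper exported by this module; these entry points are assumptions for illustration, not part of this diff.

# Sketch only: assumes vineyardd is running and that client.put()/client.get()
# dispatch through put_torch_tensors() when used inside torch_context().
import torch
import vineyard
from vineyard.contrib.ml.torch import torch_context

client = vineyard.connect()  # connect via the default IPC socket

t = torch.ones(4, dtype=torch.bfloat16)
with torch_context():
    oid = client.put(t)  # warns once, stores the data as float16

with torch_context():
    back = client.get(oid)

print(back.dtype)  # torch.float16 -- bfloat16 does not round-trip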
