From 22424b64991641f6aaa2564fcfdd718ef819eefa Mon Sep 17 00:00:00 2001
From: Ye Cao
Date: Fri, 19 Apr 2024 15:41:10 +0800
Subject: [PATCH] Support the bfloat16 for pytorch modules. (#1872)

Fixes #1871

Signed-off-by: Ye Cao
---
 python/vineyard/contrib/ml/torch.py | 13 +++++++++++++
 1 file changed, 13 insertions(+)

diff --git a/python/vineyard/contrib/ml/torch.py b/python/vineyard/contrib/ml/torch.py
index 74c27dd9..2161028d 100644
--- a/python/vineyard/contrib/ml/torch.py
+++ b/python/vineyard/contrib/ml/torch.py
@@ -215,7 +215,20 @@ def put_torch_tensors(client, tensors) -> List[Union[ObjectID, ObjectMeta]]:
     blobs = client.create_remote_blob(blob_writers)
 
     metadatas = []
+    # Emit the precision-loss warning at most once per call, even when
+    # several tensors in the batch are bfloat16.
+    found_bfloat16 = False
     for tensor, size, blob in zip(tensors, sizes, blobs):
+        if tensor.dtype == torch.bfloat16:
+            if not found_bfloat16:
+                warnings.warn(
+                    "Important, bfloat16 is not supported by vineyard, "
+                    "converting to float16 instead, which may cause precision loss."
+                )
+                found_bfloat16 = True
+            # float16 (not float32) keeps the 2-byte element size of bfloat16,
+            # so the `sizes`/`blobs` pre-computed above still match the tensor.
+            # NOTE(review): float32 would be lossless but doubles the byte size.
+            tensor = tensor.to(torch.float16)
+
         value = tensor.numpy()
         meta = ObjectMeta()
         meta['typename'] = 'vineyard::Tensor<%s>' % normalize_cpptype(value.dtype)