diff --git a/python/vineyard/contrib/ml/torch.py b/python/vineyard/contrib/ml/torch.py index 74c27dd9..2161028d 100644 --- a/python/vineyard/contrib/ml/torch.py +++ b/python/vineyard/contrib/ml/torch.py @@ -215,7 +215,17 @@ def put_torch_tensors(client, tensors) -> List[Union[ObjectID, ObjectMeta]]: blobs = client.create_remote_blob(blob_writers) metadatas = [] + found_bfloat16 = False for tensor, size, blob in zip(tensors, sizes, blobs): + if tensor.dtype == torch.bfloat16: + if not found_bfloat16: + warnings.warn( + "Important, bfloat16 is not supported by vineyard, " + "converting to float16 instead, which may cause precision loss." + ) + found_bfloat16 = True + tensor = tensor.to(torch.float16) + value = tensor.numpy() meta = ObjectMeta() meta['typename'] = 'vineyard::Tensor<%s>' % normalize_cpptype(value.dtype)