
Commit

test: use ml-dtypes lib to produce bf16 datatype (#33354)
Signed-off-by: zhuwenxing <[email protected]>
zhuwenxing authored May 24, 2024
1 parent 970bf18 commit ed883b3
Showing 3 changed files with 8 additions and 13 deletions.
tests/python_client/common/bulk_insert_data.py: 6 changes (3 additions, 3 deletions)
@@ -4,7 +4,7 @@
 import time

 import numpy as np
-import jax.numpy as jnp
+from ml_dtypes import bfloat16
 import pandas as pd
 import random
 from faker import Faker
@@ -128,9 +128,9 @@ def gen_bf16_vectors(num, dim, for_json=False):
         raw_vector = [random.random() for _ in range(dim)]
         raw_vectors.append(raw_vector)
         if for_json:
-            bf16_vector = np.array(jnp.array(raw_vector, dtype=jnp.bfloat16)).tolist()
+            bf16_vector = np.array(raw_vector, dtype=bfloat16).tolist()
         else:
-            bf16_vector = np.array(jnp.array(raw_vector, dtype=jnp.bfloat16)).view(np.uint8).tolist()
+            bf16_vector = np.array(raw_vector, dtype=bfloat16).view(np.uint8).tolist()
         bf16_vectors.append(bf16_vector)

     return raw_vectors, bf16_vectors
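The change above drops the jax.numpy round-trip because ml_dtypes registers bfloat16 directly as a NumPy dtype. A minimal standalone sketch of that conversion (the helper name make_bf16_vector is hypothetical, not part of the test suite):

# Illustrative sketch of the ml_dtypes-based conversion; make_bf16_vector is hypothetical.
import random

import numpy as np
from ml_dtypes import bfloat16


def make_bf16_vector(dim, for_json=False):
    raw_vector = [random.random() for _ in range(dim)]
    bf16 = np.array(raw_vector, dtype=bfloat16)        # rounds each float to bfloat16
    if for_json:
        return raw_vector, bf16.tolist()               # list form, as in the for_json branch
    return raw_vector, bf16.view(np.uint8).tolist()    # two uint8 values per element


raw, encoded = make_bf16_vector(8)
assert len(encoded) == 2 * len(raw)                    # bfloat16 is 2 bytes wide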
tests/python_client/common/common_func.py: 11 changes (4 additions, 7 deletions)
@@ -8,7 +8,7 @@
 from functools import singledispatch
 import numpy as np
 import pandas as pd
-import jax.numpy as jnp
+from ml_dtypes import bfloat16
 from sklearn import preprocessing
 from npy_append_array import NpyAppendArray
 from faker import Faker
@@ -20,7 +20,6 @@
 from utils.util_log import test_log as log
 from customize.milvus_operator import MilvusOperator
 import pickle
-import tensorflow as tf
 fake = Faker()
 """" Methods of processing data """

@@ -1070,14 +1069,12 @@ def gen_data_by_collection_field(field, nb=None, start=None):
         dim = field.params['dim']
         if nb is None:
             raw_vector = [random.random() for _ in range(dim)]
-            bf16_vector = jnp.array(raw_vector, dtype=jnp.bfloat16)
-            bf16_vector = np.array(bf16_vector).view(np.uint8).tolist()
+            bf16_vector = np.array(raw_vector, dtype=bfloat16).view(np.uint8).tolist()
             return bytes(bf16_vector)
         bf16_vectors = []
         for i in range(nb):
             raw_vector = [random.random() for _ in range(dim)]
-            bf16_vector = jnp.array(raw_vector, dtype=jnp.bfloat16)
-            bf16_vector = np.array(bf16_vector).view(np.uint8).tolist()
+            bf16_vector = np.array(raw_vector, dtype=bfloat16).view(np.uint8).tolist()
             bf16_vectors.append(bytes(bf16_vector))
         return bf16_vectors
     if data_type == DataType.FLOAT16_VECTOR:
@@ -2077,7 +2074,7 @@ def gen_bf16_vectors(num, dim):
     for _ in range(num):
         raw_vector = [random.random() for _ in range(dim)]
         raw_vectors.append(raw_vector)
-        bf16_vector = tf.cast(raw_vector, dtype=tf.bfloat16).numpy()
+        bf16_vector = np.array(raw_vector, dtype=bfloat16)
         bf16_vectors.append(bf16_vector)

     return raw_vectors, bf16_vectors
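Taken together, the hunks above produce bfloat16 data in two shapes: gen_data_by_collection_field packs each vector into bytes for insert payloads, while gen_bf16_vectors now returns plain bfloat16 NumPy arrays instead of tf.bfloat16 tensors. A hedged round-trip sketch, assuming only numpy and ml_dtypes (the np.frombuffer decode is illustrative and not part of the diff):

# Illustrative round-trip for the two encodings used above (not in the diff).
import random

import numpy as np
from ml_dtypes import bfloat16

dim = 4
raw_vector = [random.random() for _ in range(dim)]

# Pattern 1 (gen_data_by_collection_field): bfloat16 values packed into bytes.
encoded = bytes(np.array(raw_vector, dtype=bfloat16).view(np.uint8).tolist())
decoded = np.frombuffer(encoded, dtype=bfloat16)

# Pattern 2 (gen_bf16_vectors): keep the bfloat16 NumPy array itself.
bf16_array = np.array(raw_vector, dtype=bfloat16)

# bfloat16 keeps roughly 8 bits of mantissa, so compare loosely against the raw floats.
assert np.allclose(decoded.astype(np.float32), raw_vector, atol=1e-2)
assert bf16_array.dtype.itemsize == 2 and len(encoded) == 2 * dim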
tests/python_client/requirements.txt: 4 changes (1 addition, 3 deletions)
@@ -56,7 +56,5 @@ pyarrow==14.0.1
 fastparquet==2023.7.0

 # for bf16 datatype
-jax==0.4.13
-jaxlib==0.4.13
-tensorflow==2.13.1
+ml-dtypes==0.2.0
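The single ml-dtypes pin replaces jax, jaxlib and tensorflow for bf16 generation. A quick sanity-check sketch for the new dependency (assumed to run inside the test virtualenv; nothing here is part of the suite):

# Sanity-check sketch for the ml-dtypes dependency (not part of the test suite).
import ml_dtypes
import numpy as np

print(ml_dtypes.__version__)                       # requirements.txt pins 0.2.0
assert np.dtype(ml_dtypes.bfloat16).itemsize == 2  # 16-bit brain-float dtype is available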
