Skip to content

Commit

Permalink
test: add coo format sparse vector in restful test (milvus-io#33677)
Browse files Browse the repository at this point in the history
* add coo format sparse vector
* search data and insert data in the same sparse format or a different
format

Signed-off-by: zhuwenxing <[email protected]>
  • Loading branch information
zhuwenxing authored Jun 6, 2024
1 parent 27cc9f2 commit 9c2e325
Show file tree
Hide file tree
Showing 2 changed files with 35 additions and 7 deletions.
28 changes: 23 additions & 5 deletions tests/restful_client_v2/testcases/test_vector_operations.py
Original file line number Diff line number Diff line change
Expand Up @@ -838,8 +838,9 @@ def test_search_vector_with_float_vector_datatype(self, nb, dim, insert_round, a
@pytest.mark.parametrize("nb", [3000])
@pytest.mark.parametrize("dim", [128])
@pytest.mark.parametrize("groupingField", ['user_id', None])
@pytest.mark.parametrize("sparse_format", ['dok', 'coo'])
def test_search_vector_with_sparse_float_vector_datatype(self, nb, dim, insert_round, auto_id,
is_partition_key, enable_dynamic_schema, groupingField):
is_partition_key, enable_dynamic_schema, groupingField, sparse_format):
"""
Insert a vector with a simple payload
"""
Expand Down Expand Up @@ -879,15 +880,15 @@ def test_search_vector_with_sparse_float_vector_datatype(self, nb, dim, insert_r
"user_id": idx%100,
"word_count": j,
"book_describe": f"book_{idx}",
"sparse_float_vector": gen_vector(datatype="SparseFloatVector", dim=dim),
"sparse_float_vector": gen_vector(datatype="SparseFloatVector", dim=dim, sparse_format=sparse_format),
}
else:
tmp = {
"book_id": idx,
"user_id": idx%100,
"word_count": j,
"book_describe": f"book_{idx}",
"sparse_float_vector": gen_vector(datatype="SparseFloatVector", dim=dim),
"sparse_float_vector": gen_vector(datatype="SparseFloatVector", dim=dim, sparse_format=sparse_format),
}
if enable_dynamic_schema:
tmp.update({f"dynamic_field_{i}": i})
Expand All @@ -902,7 +903,7 @@ def test_search_vector_with_sparse_float_vector_datatype(self, nb, dim, insert_r
# search data
payload = {
"collectionName": name,
"data": [gen_vector(datatype="SparseFloatVector", dim=dim)],
"data": [gen_vector(datatype="SparseFloatVector", dim=dim, sparse_format="dok")],
"filter": "word_count > 100",
"outputFields": ["*"],
"searchParams": {
Expand All @@ -918,7 +919,24 @@ def test_search_vector_with_sparse_float_vector_datatype(self, nb, dim, insert_r
rsp = self.vector_client.vector_search(payload)
assert rsp['code'] == 0


# search data
payload = {
"collectionName": name,
"data": [gen_vector(datatype="SparseFloatVector", dim=dim, sparse_format="coo")],
"filter": "word_count > 100",
"outputFields": ["*"],
"searchParams": {
"metricType": "IP",
"params": {
"drop_ratio_search": "0.2",
}
},
"limit": 500,
}
if groupingField:
payload["groupingField"] = groupingField
rsp = self.vector_client.vector_search(payload)
assert rsp['code'] == 0

@pytest.mark.parametrize("insert_round", [2])
@pytest.mark.parametrize("auto_id", [True])
Expand Down
14 changes: 12 additions & 2 deletions tests/restful_client_v2/utils/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -197,12 +197,22 @@ def gen_bf16_vectors(num, dim):
return raw_vectors, bf16_vectors


def gen_vector(datatype="float_vector", dim=128, binary_data=False):
def gen_vector(datatype="float_vector", dim=128, binary_data=False, sparse_format='dok'):
value = None
if datatype == "FloatVector":
return preprocessing.normalize([np.array([random.random() for i in range(dim)])])[0].tolist()
if datatype == "SparseFloatVector":
return {d: rng.random() for d in random.sample(range(dim), random.randint(20, 30))}
if sparse_format == 'dok':
return {d: rng.random() for d in random.sample(range(dim), random.randint(20, 30))}
elif sparse_format == 'coo':
data = {d: rng.random() for d in random.sample(range(dim), random.randint(20, 30))}
coo_data = {
"indices": list(data.keys()),
"values": list(data.values())
}
return coo_data
else:
raise Exception(f"unsupported sparse format: {sparse_format}")
if datatype == "BinaryVector":
value = gen_binary_vectors(1, dim)[1][0]
if datatype == "Float16Vector":
Expand Down

0 comments on commit 9c2e325

Please sign in to comment.