Skip to content

Commit

Permalink
test: add search with text match
Browse files Browse the repository at this point in the history
Signed-off-by: zhuwenxing <[email protected]>
  • Loading branch information
zhuwenxing committed Sep 25, 2024
1 parent 58baeee commit cd3e8b4
Showing 1 changed file with 144 additions and 0 deletions.
144 changes: 144 additions & 0 deletions tests/python_client/testcases/test_search.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,9 @@
import threading
import pytest
import pandas as pd
from faker import Faker
fake_en = Faker("en_US")
fake_zh = Faker("zh_CN")
pd.set_option("expand_frame_repr", False)


Expand Down Expand Up @@ -13113,3 +13116,144 @@ def test_search_normal_none_data_partition_key(self, is_flush, enable_dynamic_fi
"output_fields": [default_int64_field_name,
default_float_field_name]})

class TestSearchWithTextMatchFilter(TestcaseBase):
"""
******************************************************************
The following cases are used to test query text match
******************************************************************
"""

@pytest.mark.tags(CaseLabel.L0)
@pytest.mark.parametrize("enable_partition_key", [True, False])
@pytest.mark.parametrize("enable_inverted_index", [True, False])
@pytest.mark.parametrize("tokenizer", ["jieba", "default"])
def test_search_with_text_match_filter_normal(
self, tokenizer, enable_inverted_index, enable_partition_key
):
"""
target: test text match normal
method: 1. enable text match and insert data with varchar
2. get the most common words and query with text match
3. verify the result
expected: text match successfully and result is correct
"""
analyzer_params = {
"tokenizer": tokenizer,
}
dim = 128
fields = [
FieldSchema(name="id", dtype=DataType.INT64, is_primary=True),
FieldSchema(
name="word",
dtype=DataType.VARCHAR,
max_length=65535,
enable_match=True,
is_partition_key=enable_partition_key,
analyzer_params=analyzer_params,
),
FieldSchema(
name="sentence",
dtype=DataType.VARCHAR,
max_length=65535,
enable_match=True,
analyzer_params=analyzer_params,
),
FieldSchema(
name="paragraph",
dtype=DataType.VARCHAR,
max_length=65535,
enable_match=True,
analyzer_params=analyzer_params,
),
FieldSchema(
name="text",
dtype=DataType.VARCHAR,
max_length=65535,
enable_match=True,
analyzer_params=analyzer_params,
),
FieldSchema(name="emb", dtype=DataType.FLOAT_VECTOR, dim=dim),
]
schema = CollectionSchema(fields=fields, description="test collection")
data_size = 5000
collection_w = self.init_collection_wrap(
name=cf.gen_unique_str(prefix), schema=schema
)
fake = fake_en
if tokenizer == "jieba":
language = "zh"
fake = fake_zh
else:
language = "en"

data = [
{
"id": i,
"word": fake.word().lower(),
"sentence": fake.sentence().lower(),
"paragraph": fake.paragraph().lower(),
"text": fake.text().lower(),
"emb": [random.random() for _ in range(dim)],
}
for i in range(data_size)
]
df = pd.DataFrame(data)
log.info(f"dataframe\n{df}")
batch_size = 5000
for i in range(0, len(df), batch_size):
collection_w.insert(
data[i : i + batch_size]
if i + batch_size < len(df)
else data[i : len(df)]
)
collection_w.flush()
collection_w.create_index(
"emb",
{"index_type": "IVF_SQ8", "metric_type": "L2", "params": {"nlist": 64}},
)
if enable_inverted_index:
collection_w.create_index("word", {"index_type": "INVERTED"})
collection_w.load()
# analyze the croup
text_fields = ["word", "sentence", "paragraph", "text"]
wf_map = {}
for field in text_fields:
wf_map[field] = cf.analyze_documents(df[field].tolist(), language=language)
# query single field for one token
for field in text_fields:
token = wf_map[field].most_common()[0][0]
expr = f"TextMatch({field}, '{token}')"
log.info(f"expr: {expr}")
res_list, _ = collection_w.search(
data=[[random.random() for _ in range(dim)]],
anns_field="emb",
param={},
limit=100,
expr=expr, output_fields=["id", field])
for res in res_list:
assert len(res) > 0
log.info(f"res len {len(res)} res {res}")
for r in res:
r = r.to_dict()
assert token in r["entity"][field]

# query single field for multi-word
for field in text_fields:
# match top 10 most common words
top_10_tokens = []
for word, count in wf_map[field].most_common(10):
top_10_tokens.append(word)
string_of_top_10_words = " ".join(top_10_tokens)
expr = f"TextMatch({field}, '{string_of_top_10_words}')"
log.info(f"expr {expr}")
res_list, _ = collection_w.search(
data=[[random.random() for _ in range(dim)]],
anns_field="emb",
param={},
limit=100,
expr=expr, output_fields=["id", field])
for res in res_list:
log.info(f"res len {len(res)} res {res}")
for r in res:
r = r.to_dict()
assert any([token in r["entity"][field] for token in top_10_tokens])

0 comments on commit cd3e8b4

Please sign in to comment.