diff --git a/tests/python_client/testcases/test_search.py b/tests/python_client/testcases/test_search.py
index 5f2470accf7b1..13ac49264da05 100644
--- a/tests/python_client/testcases/test_search.py
+++ b/tests/python_client/testcases/test_search.py
@@ -24,6 +24,9 @@
 import threading
 import pytest
 import pandas as pd
+from faker import Faker
+fake_en = Faker("en_US")
+fake_zh = Faker("zh_CN")
 pd.set_option("expand_frame_repr", False)
 
 
@@ -13113,3 +13116,144 @@ def test_search_normal_none_data_partition_key(self, is_flush, enable_dynamic_fi
                                          "output_fields": [default_int64_field_name, default_float_field_name]})
 
 
+class TestSearchWithTextMatchFilter(TestcaseBase):
+    """
+    ******************************************************************
+      The following cases are used to test search with text match filter
+    ******************************************************************
+    """
+
+    @pytest.mark.tags(CaseLabel.L0)
+    @pytest.mark.parametrize("enable_partition_key", [True, False])
+    @pytest.mark.parametrize("enable_inverted_index", [True, False])
+    @pytest.mark.parametrize("tokenizer", ["jieba", "default"])
+    def test_search_with_text_match_filter_normal(
+        self, tokenizer, enable_inverted_index, enable_partition_key
+    ):
+        """
+        target: test text match normal
+        method: 1. enable text match and insert data with varchar fields
+                2. get the most common words and search with a text match filter
+                3. verify the result
+        expected: text match succeeds and the result is correct
+        """
+        analyzer_params = {
+            "tokenizer": tokenizer,
+        }
+        dim = 128
+        fields = [
+            FieldSchema(name="id", dtype=DataType.INT64, is_primary=True),
+            FieldSchema(
+                name="word",
+                dtype=DataType.VARCHAR,
+                max_length=65535,
+                enable_match=True,
+                is_partition_key=enable_partition_key,
+                analyzer_params=analyzer_params,
+            ),
+            FieldSchema(
+                name="sentence",
+                dtype=DataType.VARCHAR,
+                max_length=65535,
+                enable_match=True,
+                analyzer_params=analyzer_params,
+            ),
+            FieldSchema(
+                name="paragraph",
+                dtype=DataType.VARCHAR,
+                max_length=65535,
+                enable_match=True,
+                analyzer_params=analyzer_params,
+            ),
+            FieldSchema(
+                name="text",
+                dtype=DataType.VARCHAR,
+                max_length=65535,
+                enable_match=True,
+                analyzer_params=analyzer_params,
+            ),
+            FieldSchema(name="emb", dtype=DataType.FLOAT_VECTOR, dim=dim),
+        ]
+        schema = CollectionSchema(fields=fields, description="test collection")
+        data_size = 5000
+        collection_w = self.init_collection_wrap(
+            name=cf.gen_unique_str(prefix), schema=schema
+        )
+        fake = fake_en
+        if tokenizer == "jieba":
+            language = "zh"
+            fake = fake_zh
+        else:
+            language = "en"
+
+        data = [
+            {
+                "id": i,
+                "word": fake.word().lower(),
+                "sentence": fake.sentence().lower(),
+                "paragraph": fake.paragraph().lower(),
+                "text": fake.text().lower(),
+                "emb": [random.random() for _ in range(dim)],
+            }
+            for i in range(data_size)
+        ]
+        df = pd.DataFrame(data)
+        log.info(f"dataframe\n{df}")
+        batch_size = 5000
+        for i in range(0, len(df), batch_size):
+            collection_w.insert(
+                data[i : i + batch_size]
+                if i + batch_size < len(df)
+                else data[i : len(df)]
+            )
+            collection_w.flush()
+        collection_w.create_index(
+            "emb",
+            {"index_type": "IVF_SQ8", "metric_type": "L2", "params": {"nlist": 64}},
+        )
+        if enable_inverted_index:
+            collection_w.create_index("word", {"index_type": "INVERTED"})
+        collection_w.load()
+        # analyze the corpus to get the word frequency of each text field
+        text_fields = ["word", "sentence", "paragraph", "text"]
+        wf_map = {}
+        for field in text_fields:
+            wf_map[field] = cf.analyze_documents(df[field].tolist(), language=language)
+        # query single field for one token
+        for field in text_fields:
+            token = wf_map[field].most_common()[0][0]
+            expr = f"TextMatch({field}, '{token}')"
+            log.info(f"expr: {expr}")
+            res_list, _ = collection_w.search(
+                data=[[random.random() for _ in range(dim)]],
+                anns_field="emb",
+                param={},
+                limit=100,
+                expr=expr, output_fields=["id", field])
+            for res in res_list:
+                assert len(res) > 0
+                log.info(f"res len {len(res)} res {res}")
+                for r in res:
+                    r = r.to_dict()
+                    assert token in r["entity"][field]
+
+        # query single field for multiple tokens
+        for field in text_fields:
+            # match the top 10 most common words
+            top_10_tokens = []
+            for word, count in wf_map[field].most_common(10):
+                top_10_tokens.append(word)
+            string_of_top_10_words = " ".join(top_10_tokens)
+            expr = f"TextMatch({field}, '{string_of_top_10_words}')"
+            log.info(f"expr: {expr}")
+            res_list, _ = collection_w.search(
+                data=[[random.random() for _ in range(dim)]],
+                anns_field="emb",
+                param={},
+                limit=100,
+                expr=expr, output_fields=["id", field])
+            for res in res_list:
+                log.info(f"res len {len(res)} res {res}")
+                for r in res:
+                    r = r.to_dict()
+                    assert any([token in r["entity"][field] for token in top_10_tokens])
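
Note: the cf.analyze_documents helper used in the test returns a word-frequency map whose most_common() output drives the TextMatch expressions. Below is a minimal sketch of that behaviour, assuming regex tokenization for English and jieba segmentation for Chinese; the helper's actual implementation may differ, and analyze_documents_sketch is a hypothetical name used only for illustration.

    from collections import Counter
    import re


    def analyze_documents_sketch(texts, language="en"):
        # Hypothetical stand-in for cf.analyze_documents: returns a Counter of
        # token frequencies so that .most_common() works as used in the test.
        tokens = []
        if language == "zh":
            import jieba  # assumed dependency, matching the "jieba" tokenizer case
            for text in texts:
                tokens.extend(t.strip() for t in jieba.cut(text) if t.strip())
        else:
            for text in texts:
                tokens.extend(re.findall(r"\w+", text.lower()))
        return Counter(tokens)


    # Usage mirroring the test: take the most common token of a field as the
    # TextMatch query term.
    wf = analyze_documents_sketch(["the quick brown fox", "the lazy dog"], language="en")
    token = wf.most_common()[0][0]  # -> "the"

With a Counter-like object in hand, wf_map[field].most_common(10) in the test yields the ten highest-frequency tokens, which are space-joined into a single TextMatch query string for the multi-token case.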