test: Add full-text search test cases (#36998)
/kind improvement

---------

Signed-off-by: zhuwenxing <[email protected]>
zhuwenxing authored Oct 23, 2024
1 parent 80d48f1 commit 3b024f9
Showing 5 changed files with 3,358 additions and 29 deletions.
7 changes: 6 additions & 1 deletion tests/python_client/base/client_base.py
@@ -18,7 +18,8 @@
 from common import common_type as ct
 from common.common_params import IndexPrams
 
-from pymilvus import ResourceGroupInfo, DataType
+from pymilvus import ResourceGroupInfo, DataType, utility
+import pymilvus
 
 
 class Base:
@@ -44,6 +45,7 @@ def teardown_class(self):
 
     def setup_method(self, method):
         log.info(("*" * 35) + " setup " + ("*" * 35))
+        log.info(f"pymilvus version: {pymilvus.__version__}")
         log.info("[setup_method] Start setup test case %s." % method.__name__)
         self._setup_objects()
 
@@ -144,6 +146,7 @@ def _connect(self, enable_milvus_client_api=False):
                 uri = cf.param_info.param_uri
             else:
                 uri = "http://" + cf.param_info.param_host + ":" + str(cf.param_info.param_port)
+            self.connection_wrap.connect(alias=DefaultConfig.DEFAULT_USING, uri=uri, token=cf.param_info.param_token)
             res, is_succ = self.connection_wrap.MilvusClient(uri=uri,
                                                              token=cf.param_info.param_token)
         else:
@@ -159,6 +162,8 @@
                                                          host=cf.param_info.param_host,
                                                          port=cf.param_info.param_port)
 
+        server_version = utility.get_server_version()
+        log.info(f"server version: {server_version}")
         return res
 
     def init_collection_wrap(self, name=None, schema=None, check_task=None, check_items=None,
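
Note: the new connection_wrap.connect() line is what lets the later utility.get_server_version() call succeed, because pymilvus utility helpers operate on the default ORM connection, which MilvusClient alone does not register. A minimal sketch of the same pattern outside the test harness, assuming a local Milvus at the default endpoint with an empty token (both are assumptions, not values from this commit):

from pymilvus import connections, utility

# Assumed endpoint/token for illustration only.
connections.connect(alias="default", uri="http://localhost:19530", token="")
print(utility.get_server_version())  # same call that _connect() logs in the diff above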
109 changes: 86 additions & 23 deletions tests/python_client/common/common_func.py
@@ -14,7 +14,6 @@
 from faker import Faker
 from pathlib import Path
 from minio import Minio
-from pymilvus import DataType, CollectionSchema
 from base.schema_wrapper import ApiCollectionSchemaWrapper, ApiFieldSchemaWrapper
 from common import common_type as ct
 from common.common_params import ExprCheckParams
@@ -24,6 +23,12 @@
 from collections import Counter
 import bm25s
 import jieba
+import re
+
+from pymilvus import CollectionSchema, DataType
+
+from bm25s.tokenization import Tokenizer
+
 fake = Faker()
@@ -76,23 +81,83 @@ def prepare_param_info(self, host, port, handler, replica_num, user, password, s
 param_info = ParamInfo()
 
 
-def analyze_documents(texts, language="en"):
-    stopwords = "en"
-    if language in ["en", "english"]:
-        stopwords = "en"
+def get_bm25_ground_truth(corpus, queries, top_k=100, language="en"):
+    """
+    Get the ground truth for BM25 search.
+    :param corpus: The corpus of documents
+    :param queries: The query string or list of query strings
+    :return: The ground truth for BM25 search
+    """
+
+    def remove_punctuation(text):
+        text = text.strip()
+        text = text.replace("\n", " ")
+        return re.sub(r'[^\w\s]', ' ', text)
+
+    # Tokenize the corpus
+    def jieba_split(text):
+        text_without_punctuation = remove_punctuation(text)
+        return jieba.lcut(text_without_punctuation)
+
+    stopwords = "english" if language in ["en", "english"] else [" "]
+    stemmer = None
     if language in ["zh", "cn", "chinese"]:
-        stopword = " "
-        new_texts = []
-        for doc in texts:
-            seg_list = jieba.cut(doc, cut_all=True)
-            new_texts.append(" ".join(seg_list))
-        texts = new_texts
-        stopwords = [stopword]
+        splitter = jieba_split
+        tokenizer = Tokenizer(
+            stemmer=stemmer, splitter=splitter, stopwords=stopwords
+        )
+    else:
+        tokenizer = Tokenizer(
+            stemmer=stemmer, stopwords=stopwords
+        )
+    corpus_tokens = tokenizer.tokenize(corpus, return_as="tuple")
+    retriever = bm25s.BM25()
+    retriever.index(corpus_tokens)
+    query_tokens = tokenizer.tokenize(queries, return_as="tuple")
+    results, scores = retriever.retrieve(query_tokens, corpus=corpus, k=top_k)
+    return results, scores
+
+
+def custom_tokenizer(language="en"):
+    def remove_punctuation(text):
+        text = text.strip()
+        text = text.replace("\n", " ")
+        return re.sub(r'[^\w\s]', ' ', text)
+
+    # Tokenize the corpus
+    def jieba_split(text):
+        text_without_punctuation = remove_punctuation(text)
+        return jieba.lcut(text_without_punctuation)
+
+    def blank_space_split(text):
+        text_without_punctuation = remove_punctuation(text)
+        return text_without_punctuation.split()
+
+    stopwords = [" "]
+    stemmer = None
+    if language in ["zh", "cn", "chinese"]:
+        splitter = jieba_split
+        tokenizer = Tokenizer(
+            stemmer=stemmer, splitter=splitter, stopwords=stopwords
+        )
+    else:
+        splitter = blank_space_split
+        tokenizer = Tokenizer(
+            stemmer=stemmer, splitter=splitter, stopwords=stopwords
+        )
+    return tokenizer
+
+
+def analyze_documents(texts, language="en"):
+    tokenizer = custom_tokenizer(language)
     # Start timing
     t0 = time.time()
 
     # Tokenize the corpus
-    tokenized = bm25s.tokenize(texts, lower=True, stopwords=stopwords)
+    tokenized = tokenizer.tokenize(texts, return_as="tuple")
     # log.info(f"Tokenized: {tokenized}")
     # Create a frequency counter
     freq = Counter()
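
Note: a hedged usage sketch for the get_bm25_ground_truth helper added above; the corpus and queries are invented for illustration, and only calls already present in the diff are used:

# Hypothetical inputs, not taken from the test data.
corpus = [
    "Milvus supports full-text search with BM25 scoring",
    "BM25 weighs term frequency against document length",
    "The quick brown fox jumps over the lazy dog",
]
queries = ["bm25 full-text search"]
results, scores = get_bm25_ground_truth(corpus, queries, top_k=2, language="en")
# results[0] holds the top-2 documents for the first query and scores[0] their
# BM25 scores; tests can compare these against Milvus full-text search output.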
@@ -112,25 +177,23 @@
 
     return word_freq
 
+
+def check_token_overlap(text_a, text_b, language="en"):
+    word_freq_a = analyze_documents([text_a], language)
+    word_freq_b = analyze_documents([text_b], language)
+    overlap = set(word_freq_a.keys()).intersection(set(word_freq_b.keys()))
+    return overlap, word_freq_a, word_freq_b
 
 
 def split_dataframes(df, fields, language="en"):
     df_copy = df.copy()
-    if language in ["zh", "cn", "chinese"]:
-        for col in fields:
-            new_texts = []
-            for doc in df[col]:
-                seg_list = jieba.cut(doc, cut_all=True)
-                new_texts.append(list(seg_list))
-            df_copy[col] = new_texts
-        return df_copy
+    tokenizer = custom_tokenizer(language)
     for col in fields:
         texts = df[col].to_list()
-        tokenized = bm25s.tokenize(texts, lower=True, stopwords="en")
+        tokenized = tokenizer.tokenize(texts, return_as="tuple")
         new_texts = []
         id_vocab_map = {id: word for word, id in tokenized.vocab.items()}
         for doc_ids in tokenized.ids:
             new_texts.append([id_vocab_map[token_id] for token_id in doc_ids])
 
         df_copy[col] = new_texts
     return df_copy
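
Note: analyze_documents and split_dataframes now share custom_tokenizer, so token statistics and tokenized dataframes stay consistent. A hedged example, assuming word_freq behaves like a collections.Counter (the diff builds it from Counter) and that pandas is available in the test environment:

import pandas as pd

texts = ["full text search test", "search quality matters", "full search coverage"]
word_freq = analyze_documents(texts, language="en")
top_tokens = [w for w, _ in word_freq.most_common(3)]  # frequent terms to query with

df = pd.DataFrame({"text": texts})
tokenized_df = split_dataframes(df, fields=["text"])  # each cell becomes a token list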

6 changes: 4 additions & 2 deletions tests/python_client/common/common_type.py
@@ -45,6 +45,7 @@
 float16_type = "FLOAT16_VECTOR"
 bfloat16_type = "BFLOAT16_VECTOR"
 sparse_vector = "SPARSE_FLOAT_VECTOR"
+text_sparse_vector = "TEXT_SPARSE_VECTOR"
 append_vector_type = [float16_type, bfloat16_type, sparse_vector]
 all_dense_vector_types = [float_type, float16_type, bfloat16_type]
 all_vector_data_types = [float_type, float16_type, bfloat16_type, sparse_vector]
@@ -254,7 +255,8 @@
 default_bin_flat_index = {"index_type": "BIN_FLAT", "params": {}, "metric_type": "JACCARD"}
 default_sparse_inverted_index = {"index_type": "SPARSE_INVERTED_INDEX", "metric_type": "IP",
                                  "params": {"drop_ratio_build": 0.2}}
-
+default_text_sparse_inverted_index = {"index_type": "SPARSE_INVERTED_INDEX", "metric_type": "BM25",
+                                      "params": {"drop_ratio_build": 0.2, "bm25_k1": 1.5, "bm25_b": 0.75}}
 default_search_params = {"params": default_all_search_params_params[2].copy()}
 default_search_ip_params = {"metric_type": "IP", "params": default_all_search_params_params[2].copy()}
 default_search_binary_params = {"metric_type": "JACCARD", "params": {"nprobe": 32}}
@@ -263,7 +265,7 @@
 default_diskann_index = {"index_type": "DISKANN", "metric_type": default_L0_metric, "params": {}}
 default_diskann_search_params = {"params": {"search_list": 30}}
 default_sparse_search_params = {"metric_type": "IP", "params": {"drop_ratio_search": "0.2"}}
-
+default_text_sparse_search_params = {"metric_type": "BM25", "params": {}}
 
 class CheckTasks:
     """ The name of the method used to check the result """
7 changes: 4 additions & 3 deletions tests/python_client/requirements.txt
@@ -27,8 +27,8 @@ pytest-parallel
 pytest-random-order
 
 # pymilvus
-pymilvus==2.5.0rc95
-pymilvus[bulk_writer]==2.5.0rc95
+pymilvus==2.5.0rc101
+pymilvus[bulk_writer]==2.5.0rc101
 
 # for customize config test
 python-benedict==0.24.3
@@ -62,9 +62,10 @@ fastparquet==2023.7.0
 # for bf16 datatype
 ml-dtypes==0.2.0
 
-# for text match
+# for full text search
 bm25s==0.2.0
 jieba==0.42.1
 
+
 # for perf test
 locust==2.25.0
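
Note: the pymilvus pin moves to 2.5.0rc101 to match the server-side BM25 features exercised here, and setup_method above now logs this version per test. A quick, hedged local sanity check (assumes the pinned requirements are installed):

import importlib.metadata as md
import pymilvus

print(pymilvus.__version__)          # expected: 2.5.0rc101
for pkg in ("bm25s", "jieba"):
    print(pkg, md.version(pkg))      # expected: 0.2.0 and 0.42.1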
[Diff for the fifth changed file not shown on this page.]
