Skip to content

Commit

Permalink
Add test cases of array contains (#27889)
Browse files Browse the repository at this point in the history
Signed-off-by: nico <[email protected]>
  • Loading branch information
NicoYuan1986 authored Oct 24, 2023
1 parent b9d5ef3 commit ec99eb1
Show file tree
Hide file tree
Showing 3 changed files with 198 additions and 60 deletions.
41 changes: 9 additions & 32 deletions tests/python_client/common/common_func.py
Original file line number Diff line number Diff line change
Expand Up @@ -1315,54 +1315,31 @@ def index_to_dict(index):


def assert_json_contains(expr, list_data):
    """Compute the row indices expected to match a json/array ``contains`` filter.

    The expression has the shape ``[not ]<prefix>(<field>, <literal>)`` where
    ``<prefix>`` is one of ``json_contains`` / ``array_contains`` and their
    ``_all`` / ``_any`` variants (all-lower or all-upper case).

    :param expr: filter expression string, optionally prefixed with ``"not "``
    :param list_data: list of rows; each row is the list stored in the field
    :return: sorted list of row indices (positions in ``list_data``) that the
             expression is expected to select; an empty list (or every index
             when negated) if the prefix is unknown, in which case a warning
             is logged
    """
    import ast  # local import: only this helper needs it

    def _freeze(value):
        # Lists are unhashable; turn them (recursively) into tuples so that
        # nested-list elements can take part in set operations.
        if isinstance(value, list):
            return tuple(_freeze(item) for item in value)
        return value

    opposite = expr.startswith("not ")
    if opposite:
        expr = expr[len("not "):]

    expr_prefix = expr.split('(', 1)[0]
    # The second argument is a Python-compatible literal (number, string or
    # list); literal_eval parses it without executing arbitrary code.
    exp_ids = ast.literal_eval(expr.split(', ', 1)[1].split(')', 1)[0])

    contains = ("json_contains", "JSON_CONTAINS", "array_contains", "ARRAY_CONTAINS")
    contains_all = ("json_contains_all", "JSON_CONTAINS_ALL", "array_contains_all", "ARRAY_CONTAINS_ALL")
    contains_any = ("json_contains_any", "JSON_CONTAINS_ANY", "array_contains_any", "ARRAY_CONTAINS_ANY")

    result_ids = []
    if expr_prefix in contains:
        result_ids = [i for i, row in enumerate(list_data) if exp_ids in row]
    elif expr_prefix in contains_all:
        wanted = {_freeze(element) for element in exp_ids}
        result_ids = [i for i, row in enumerate(list_data)
                      if wanted.issubset({_freeze(element) for element in row})]
    elif expr_prefix in contains_any:
        wanted = {_freeze(element) for element in exp_ids}
        result_ids = [i for i, row in enumerate(list_data)
                      if wanted & {_freeze(element) for element in row}]
    else:
        log.warning("unknown expr: %s" % expr)

    if opposite:
        # Negation selects every row that the positive expression did not.
        matched = set(result_ids)
        result_ids = [i for i in range(len(list_data)) if i not in matched]
    return result_ids


Expand Down
79 changes: 51 additions & 28 deletions tests/python_client/testcases/test_query.py
Original file line number Diff line number Diff line change
Expand Up @@ -573,16 +573,16 @@ def test_query_expr_non_constant_array_term(self):
collection_w.query(term_expr, check_task=CheckTasks.err_res, check_items=error)

@pytest.mark.tags(CaseLabel.L1)
@pytest.mark.parametrize("expr_prefix", ["json_contains", "JSON_CONTAINS"])
@pytest.mark.parametrize("expr_prefix", ["json_contains", "JSON_CONTAINS",
"array_contains", "ARRAY_CONTAINS"])
def test_query_expr_json_contains(self, enable_dynamic_field, expr_prefix):
"""
target: test query with expression using json_contains
method: query with expression using json_contains
expected: succeed
"""
# 1. initialize with data
collection_w = self.init_collection_general(
prefix, enable_dynamic_field=enable_dynamic_field)[0]
collection_w = self.init_collection_general(prefix, enable_dynamic_field=enable_dynamic_field)[0]

# 2. insert data
array = cf.gen_default_rows_data()
Expand All @@ -608,8 +608,7 @@ def test_query_expr_list_json_contains(self, expr_prefix):
expected: succeed
"""
# 1. initialize with data
collection_w = self.init_collection_general(
prefix, enable_dynamic_field=True)[0]
collection_w = self.init_collection_general(prefix, enable_dynamic_field=True)[0]

# 2. insert data
limit = ct.default_nb // 4
Expand Down Expand Up @@ -656,7 +655,8 @@ def test_query_expr_json_contains_combined_with_normal(self, enable_dynamic_fiel
assert len(res) == limit // 2

@pytest.mark.tags(CaseLabel.L1)
@pytest.mark.parametrize("expr_prefix", ["json_contains_all", "JSON_CONTAINS_ALL"])
@pytest.mark.parametrize("expr_prefix", ["json_contains_all", "JSON_CONTAINS_ALL",
"array_contains_all", "ARRAY_CONTAINS_ALL"])
def test_query_expr_all_datatype_json_contains_all(self, enable_dynamic_field, expr_prefix):
"""
target: test query with expression using json_contains
Expand Down Expand Up @@ -865,7 +865,8 @@ def test_query_expr_all_datatype_json_contains_any(self, enable_dynamic_field, e
assert len(res) == ct.default_nb // 2

@pytest.mark.tags(CaseLabel.L2)
@pytest.mark.parametrize("expr_prefix", ["json_contains_any", "JSON_CONTAINS_ANY"])
@pytest.mark.parametrize("expr_prefix", ["json_contains_any", "JSON_CONTAINS_ANY",
"array_contains_any", "ARRAY_CONTAINS_ANY"])
def test_query_expr_list_all_datatype_json_contains_any(self, expr_prefix):
"""
target: test query with expression using json_contains_any
Expand Down Expand Up @@ -1019,49 +1020,71 @@ def test_query_expr_json_contains_pagination(self, enable_dynamic_field, expr_pr
assert len(res) == limit - offset

@pytest.mark.tags(CaseLabel.L2)
@pytest.mark.parametrize("array_length", ["ARRAY_LENGTH", "array_length"])
@pytest.mark.parametrize("op", ["==", "!="])
def test_query_expr_array_length(self, array_length, op, enable_dynamic_field):
    """
    target: test query with expression using array_length
    method: query with expression using array_length
            array_length only support == , !=
    expected: succeed
    """
    # 1. create a collection
    schema = cf.gen_array_collection_schema()
    collection_w = self.init_collection_wrap(schema=schema, enable_dynamic_field=enable_dynamic_field)

    # 2. insert data, overriding the float array column with arrays whose
    #    lengths are drawn from 50..53 so both operators select some rows
    data = cf.gen_array_dataframe_data()
    length = [random.randint(50, 53) for _ in range(ct.default_nb)]
    data[ct.default_float_array_field_name] = \
        [[np.float32(j) for j in range(n)] for n in length]
    collection_w.insert(data)

    # 3. load and query
    collection_w.create_index(ct.default_float_vec_field_name, ct.default_flat_index)
    collection_w.load()
    target_length = 51
    expression = f"{array_length}({ct.default_float_array_field_name}) {op} {target_length}"
    res = collection_w.query(expression)[0]

    # 4. check: compute the expected rows directly from the generated lengths
    #    instead of rewriting and eval()-ing the expression string
    if op == "==":
        filter_ids = [i for i, n in enumerate(length) if n == target_length]
    else:
        filter_ids = [i for i, n in enumerate(length) if n != target_length]
    assert len(res) == len(filter_ids)

@pytest.mark.tags(CaseLabel.L2)
@pytest.mark.parametrize("op", [">", "<=", "+ 1 =="])
def test_query_expr_invalid_array_length(self, op):
    """
    target: verify array_length rejects unsupported operators
    method: query with array_length combined with an operator other
            than == or !=
    expected: raise error
    """
    # 1. create a collection with the array schema
    array_schema = cf.gen_array_collection_schema()
    collection_w = self.init_collection_wrap(schema=array_schema)

    # 2. insert the default array dataframe
    collection_w.insert(cf.gen_array_dataframe_data())

    # 3. index, load, and run the query that is expected to fail
    collection_w.create_index(ct.default_float_vec_field_name, ct.default_flat_index)
    collection_w.load()
    expression = f"array_length({ct.default_float_array_field_name}) {op} 51"
    expected_error = {ct.err_code: 65535,
                      ct.err_msg: "cannot parse expression: %s, error %s "
                                  "is not supported" % (expression, op)}
    collection_w.query(expression, check_task=CheckTasks.err_res,
                       check_items=expected_error)

@pytest.mark.tags(CaseLabel.L1)
def test_query_expr_empty_without_limit(self):
"""
Expand Down
138 changes: 138 additions & 0 deletions tests/python_client/testcases/test_search.py
Original file line number Diff line number Diff line change
Expand Up @@ -9135,6 +9135,144 @@ def test_search_expression_json_contains_combined_with_normal(self, enable_dynam
check_items={"nq": default_nq,
"limit": limit // 2})

@pytest.mark.tags(CaseLabel.L1)
@pytest.mark.parametrize("expr_prefix", ["array_contains", "ARRAY_CONTAINS"])
def test_search_expr_array_contains(self, expr_prefix):
    """
    target: test search with expression using array_contains
    method: search with an array_contains filter on the string array field
    expected: returned ids equal the ids computed locally
    """
    # 1. create a collection
    array_schema = cf.gen_array_collection_schema()
    collection_w = self.init_collection_wrap(schema=array_schema)

    # 2. insert data: row i holds the strings of i, i+1 and i+2
    string_arrays = [list(map(str, range(i, i + 3))) for i in range(ct.default_nb)]
    frame = cf.gen_array_dataframe_data()
    frame[ct.default_string_array_field_name] = string_arrays
    collection_w.insert(frame)
    collection_w.create_index(ct.default_float_vec_field_name, {})

    # 3. search and compare with the locally computed expectation
    collection_w.load()
    expression = f"{expr_prefix}({ct.default_string_array_field_name}, '1000')"
    search_res = collection_w.search(vectors[:default_nq], default_search_field, {},
                                     limit=ct.default_nb, expr=expression)[0]
    expected = cf.assert_json_contains(expression, string_arrays)
    first_query_hits = search_res[0]
    assert set(first_query_hits.ids) == set(expected)

@pytest.mark.tags(CaseLabel.L1)
@pytest.mark.parametrize("expr_prefix", ["array_contains", "ARRAY_CONTAINS"])
def test_search_expr_not_array_contains(self, expr_prefix):
    """
    target: test search with a negated array_contains expression
    method: search with "not array_contains(...)" on the string array field
    expected: returned ids equal the ids computed locally
    """
    # 1. create a collection
    array_schema = cf.gen_array_collection_schema()
    collection_w = self.init_collection_wrap(schema=array_schema)

    # 2. insert data: row i holds the strings of i, i+1 and i+2
    string_arrays = [list(map(str, range(i, i + 3))) for i in range(ct.default_nb)]
    frame = cf.gen_array_dataframe_data()
    frame[ct.default_string_array_field_name] = string_arrays
    collection_w.insert(frame)
    collection_w.create_index(ct.default_float_vec_field_name, {})

    # 3. search with the negated filter and compare
    collection_w.load()
    expression = f"not {expr_prefix}({ct.default_string_array_field_name}, '1000')"
    search_res = collection_w.search(vectors[:default_nq], default_search_field, {},
                                     limit=ct.default_nb, expr=expression)[0]
    expected = cf.assert_json_contains(expression, string_arrays)
    first_query_hits = search_res[0]
    assert set(first_query_hits.ids) == set(expected)

@pytest.mark.tags(CaseLabel.L1)
@pytest.mark.parametrize("expr_prefix", ["array_contains_all", "ARRAY_CONTAINS_ALL"])
def test_search_expr_array_contains_all(self, expr_prefix):
    """
    target: test search with expression using array_contains_all
    method: search with an array_contains_all filter on the string array field
    expected: returned ids equal the ids computed locally
    """
    # 1. create a collection
    array_schema = cf.gen_array_collection_schema()
    collection_w = self.init_collection_wrap(schema=array_schema)

    # 2. insert data: row i holds the strings of i, i+1 and i+2
    string_arrays = [list(map(str, range(i, i + 3))) for i in range(ct.default_nb)]
    frame = cf.gen_array_dataframe_data()
    frame[ct.default_string_array_field_name] = string_arrays
    collection_w.insert(frame)
    collection_w.create_index(ct.default_float_vec_field_name, {})

    # 3. search and compare with the locally computed expectation
    collection_w.load()
    expression = f"{expr_prefix}({ct.default_string_array_field_name}, ['1000'])"
    search_res = collection_w.search(vectors[:default_nq], default_search_field, {},
                                     limit=ct.default_nb, expr=expression)[0]
    expected = cf.assert_json_contains(expression, string_arrays)
    first_query_hits = search_res[0]
    assert set(first_query_hits.ids) == set(expected)

@pytest.mark.tags(CaseLabel.L1)
@pytest.mark.parametrize("expr_prefix", ["array_contains_any", "ARRAY_CONTAINS_ANY",
                                         "not array_contains_any", "not ARRAY_CONTAINS_ANY"])
def test_search_expr_array_contains_any(self, expr_prefix):
    """
    target: test search with expression using array_contains_any
    method: search with a plain or negated array_contains_any filter on
            the string array field
    expected: returned ids equal the ids computed locally
    """
    # 1. create a collection
    array_schema = cf.gen_array_collection_schema()
    collection_w = self.init_collection_wrap(schema=array_schema)

    # 2. insert data: row i holds the strings of i, i+1 and i+2
    string_arrays = [list(map(str, range(i, i + 3))) for i in range(ct.default_nb)]
    frame = cf.gen_array_dataframe_data()
    frame[ct.default_string_array_field_name] = string_arrays
    collection_w.insert(frame)
    collection_w.create_index(ct.default_float_vec_field_name, {})

    # 3. search and compare with the locally computed expectation
    collection_w.load()
    expression = f"{expr_prefix}({ct.default_string_array_field_name}, ['1000'])"
    search_res = collection_w.search(vectors[:default_nq], default_search_field, {},
                                     limit=ct.default_nb, expr=expression)[0]
    expected = cf.assert_json_contains(expression, string_arrays)
    first_query_hits = search_res[0]
    assert set(first_query_hits.ids) == set(expected)

@pytest.mark.tags(CaseLabel.L2)
@pytest.mark.parametrize("expr_prefix", ["array_contains_all", "ARRAY_CONTAINS_ALL",
                                         "array_contains_any", "ARRAY_CONTAINS_ANY"])
def test_search_expr_array_contains_invalid(self, expr_prefix):
    """
    target: test search with an invalid array_contains_all/any expression
    method: search with array_contains_all/any(a, b) where b is not a list
    expected: report error
    """
    # 1. create a collection
    array_schema = cf.gen_array_collection_schema()
    collection_w = self.init_collection_wrap(schema=array_schema)

    # 2. insert the default array dataframe
    collection_w.insert(cf.gen_array_dataframe_data())
    collection_w.create_index(ct.default_float_vec_field_name, {})

    # 3. search with a scalar second argument, which should be rejected
    collection_w.load()
    expression = f"{expr_prefix}({ct.default_string_array_field_name}, '1000')"
    # NOTE(review): the expected message names "contains_any" even for the
    # contains_all prefixes - confirm this matches the server's wording.
    expected_error = {ct.err_code: 65535,
                      ct.err_msg: "failed to create query plan: cannot parse "
                                  "expression: %s, error: contains_any operation "
                                  "element must be an array" % expression}
    collection_w.search(vectors[:default_nq], default_search_field, {},
                        limit=ct.default_nb, expr=expression,
                        check_task=CheckTasks.err_res,
                        check_items=expected_error)


class TestSearchIterator(TestcaseBase):
""" Test case of search iterator """
Expand Down

0 comments on commit ec99eb1

Please sign in to comment.