diff --git a/evadb/executor/create_index_executor.py b/evadb/executor/create_index_executor.py index 31421358a7..4c5e46a7b9 100644 --- a/evadb/executor/create_index_executor.py +++ b/evadb/executor/create_index_executor.py @@ -102,10 +102,11 @@ def _create_evadb_index(self): ][0] metadata_col_names = [include_col.name for include_col in self.include_list] - metadata_column_catalog_entries = [col for col in feat_tb_catalog_entry.columns - if col.name in metadata_col_names] - - + metadata_column_catalog_entries = [ + col + for col in feat_tb_catalog_entry.columns + if col.name in metadata_col_names + ] if function_expression is not None: feat_col_name = function_expression.output_objs[0].name @@ -147,12 +148,20 @@ def _create_evadb_index(self): input_batch.drop_column_alias() feat = input_batch.column_as_numpy_array(feat_col_name) - metadata = {metadata_col_name: input_batch.column_as_numpy_array(metadata_col_name) for metadata_col_name in metadata_col_names} + metadata = { + metadata_col_name: input_batch.column_as_numpy_array( + metadata_col_name + ) + for metadata_col_name in metadata_col_names + } row_num = input_batch.column_as_numpy_array(ROW_NUM_COLUMN) for i in range(len(input_batch)): row_feat = feat[i].reshape(1, -1) - row_metadata = {metadata_col_name: metadata[metadata_col_name][i] for metadata_col_name in metadata_col_names} + row_metadata = { + metadata_col_name: metadata[metadata_col_name][i] + for metadata_col_name in metadata_col_names + } # Create new index if not exists. if index is None: diff --git a/evadb/optimizer/operators.py b/evadb/optimizer/operators.py index 1880f11877..b51e45db79 100644 --- a/evadb/optimizer/operators.py +++ b/evadb/optimizer/operators.py @@ -24,7 +24,6 @@ from evadb.catalog.models.table_catalog import TableCatalogEntry from evadb.catalog.models.utils import IndexCatalogEntry from evadb.expression.abstract_expression import AbstractExpression -from evadb.expression.comparison_expression import ComparisonExpression from evadb.expression.constant_value_expression import ConstantValueExpression from evadb.expression.function_expression import FunctionExpression from evadb.parser.alias import Alias @@ -1126,7 +1125,7 @@ def table_ref(self): @property def col_list(self): return self._col_list - + @property def include_list(self): return self._include_list @@ -1239,7 +1238,7 @@ def __init__( index: IndexCatalogEntry, limit_count: ConstantValueExpression, search_query_expr: FunctionExpression, - filter_expr: ComparisonExpression = None, + filter_expr: AbstractExpression = None, children: List = None, ): super().__init__(OperatorType.LOGICAL_VECTOR_INDEX_SCAN, children) @@ -1259,7 +1258,7 @@ def limit_count(self): @property def search_query_expr(self): return self._search_query_expr - + @property def filter_expr(self): return self._filter_expr diff --git a/evadb/optimizer/rules/rules.py b/evadb/optimizer/rules/rules.py index e00af9bcaa..3c80196bb7 100644 --- a/evadb/optimizer/rules/rules.py +++ b/evadb/optimizer/rules/rules.py @@ -489,9 +489,9 @@ def apply(self, before: LogicalJoin, context: OptimizerContext): class CombineWhereSimilarityOrderByAndLimitToFilteredVectorIndexScan(Rule): """ - This rule currently rewrites Where + Order By + Limit to a filtered vector + This rule currently rewrites Where + Order By + Limit to a filtered vector index scan. Because vector index only works for similarity search, the rule will - only be applied when the Order By is on Similarity expression. + only be applied when the Order By is on Similarity expression. Limit(10) | @@ -510,15 +510,18 @@ def __init__(self): orderby_pattern.append_child(where_pattern) pattern.append_child(orderby_pattern) super().__init__( - RuleType.COMBINE_WHERE_SIMILARITY_ORDERBY_AND_LIMIT_TO_FILTERED_VECTOR_INDEX_SCAN, pattern + RuleType.COMBINE_WHERE_SIMILARITY_ORDERBY_AND_LIMIT_TO_FILTERED_VECTOR_INDEX_SCAN, + pattern, ) # Entries populate after rule eligibility validation. self._index_catalog_entry = None self._query_func_expr = None - + def promise(self): - return Promise.COMBINE_WHERE_SIMILARITY_ORDERBY_AND_LIMIT_TO_FILTERED_VECTOR_INDEX_SCAN + return ( + Promise.COMBINE_WHERE_SIMILARITY_ORDERBY_AND_LIMIT_TO_FILTERED_VECTOR_INDEX_SCAN + ) def check(self, before: LogicalLimit, context: OptimizerContext): return True @@ -597,10 +600,7 @@ def apply(self, before: LogicalLimit, context: OptimizerContext): # Construct the Vector index scan plan (with filter condition). vector_index_scan_node = LogicalVectorIndexScan( - index_catalog_entry, - limit_node.limit_count, - query_func_expr, - filter_expr + index_catalog_entry, limit_node.limit_count, query_func_expr, filter_expr ) for child in orderby_node.children: vector_index_scan_node.append_child(child) @@ -1421,7 +1421,7 @@ def apply(self, before: LogicalVectorIndexScan, context: OptimizerContext): before.index, before.limit_count, before.search_query_expr, - before.filter_expr + before.filter_expr, ) for child in before.children: after.append_child(child) diff --git a/evadb/parser/create_index_statement.py b/evadb/parser/create_index_statement.py index 55d4ea4d74..5040c60ebe 100644 --- a/evadb/parser/create_index_statement.py +++ b/evadb/parser/create_index_statement.py @@ -40,7 +40,7 @@ def __init__( self._if_not_exists = if_not_exists self._table_ref = table_ref self._col_list = col_list - self._include_list = include_list; + self._include_list = include_list self._vector_store_type = vector_store_type self._project_expr_list = project_expr_list @@ -70,7 +70,9 @@ def traverse_create_function_expression_str(expr): print_str += f" ({traverse_create_function_expression_str(function_expr)})" if len(self.include_list) > 0: - print_str += f" INCLUDE {','.join([col_def.name for col_def in self.include_list])}" + print_str += ( + f" INCLUDE {','.join([col_def.name for col_def in self.include_list])}" + ) print_str += f" USING {self._vector_store_type};" return print_str @@ -89,7 +91,7 @@ def table_ref(self): @property def col_list(self): return self._col_list - + @property def include_list(self): return self._include_list diff --git a/evadb/parser/lark_visitor/_create_statements.py b/evadb/parser/lark_visitor/_create_statements.py index 23cbcee5ce..0e6d279e75 100644 --- a/evadb/parser/lark_visitor/_create_statements.py +++ b/evadb/parser/lark_visitor/_create_statements.py @@ -274,7 +274,7 @@ def create_index(self, tree): index_elem = [index_elem] else: project_expr_list += index_elem - + if include_elem: project_expr_list += include_elem # Add tv_expr for projected columns. diff --git a/evadb/plan_nodes/create_index_plan.py b/evadb/plan_nodes/create_index_plan.py index 0368259839..5e807f60e7 100644 --- a/evadb/plan_nodes/create_index_plan.py +++ b/evadb/plan_nodes/create_index_plan.py @@ -60,7 +60,7 @@ def table_ref(self): @property def col_list(self): return self._col_list - + @property def include_list(self): return self._include_list diff --git a/evadb/plan_nodes/vector_index_scan_plan.py b/evadb/plan_nodes/vector_index_scan_plan.py index b391c15545..6e2bf4ff3b 100644 --- a/evadb/plan_nodes/vector_index_scan_plan.py +++ b/evadb/plan_nodes/vector_index_scan_plan.py @@ -13,7 +13,7 @@ # See the License for the specific language governing permissions and # limitations under the License. from evadb.catalog.models.utils import IndexCatalogEntry -from evadb.expression.comparison_expression import ComparisonExpression +from evadb.expression.abstract_expression import AbstractExpression from evadb.expression.constant_value_expression import ConstantValueExpression from evadb.expression.function_expression import FunctionExpression from evadb.plan_nodes.abstract_plan import AbstractPlan @@ -37,7 +37,7 @@ def __init__( index: IndexCatalogEntry, limit_count: ConstantValueExpression, search_query_expr: FunctionExpression, - filter_expr: ComparisonExpression = None + filter_expr: AbstractExpression = None, ): super().__init__(PlanOprType.VECTOR_INDEX_SCAN) self._index = index @@ -56,7 +56,7 @@ def limit_count(self): @property def search_query_expr(self): return self._search_query_expr - + @property def filter_expr(self): return self._filter_expr diff --git a/evadb/third_party/vector_stores/milvus.py b/evadb/third_party/vector_stores/milvus.py index 4d752d97d8..817858039f 100644 --- a/evadb/third_party/vector_stores/milvus.py +++ b/evadb/third_party/vector_stores/milvus.py @@ -84,13 +84,13 @@ def etype_to_milvus_symbol(etype: ExpressionType): elif etype == ExpressionType.LOGICAL_NOT: return "not" elif etype == ExpressionType.ARITHMETIC_ADD: - return '+' + return "+" elif etype == ExpressionType.ARITHMETIC_DIVIDE: - return '/' + return "/" elif etype == ExpressionType.ARITHMETIC_MULTIPLY: - return '*' + return "*" elif etype == ExpressionType.ARITHMETIC_SUBTRACT: - return '-' + return "-" def expression_to_milvus_expr(expr: AbstractExpression): @@ -107,7 +107,7 @@ def expression_to_milvus_expr(expr: AbstractExpression): return expr.value elif isinstance(expr, TupleValueExpression): return expr.name - + class MilvusVectorStore(VectorStore): def __init__(self, index_name: str, **kwargs) -> None: diff --git a/evadb/third_party/vector_stores/types.py b/evadb/third_party/vector_stores/types.py index a87ae5143f..c9da9bc4a1 100644 --- a/evadb/third_party/vector_stores/types.py +++ b/evadb/third_party/vector_stores/types.py @@ -14,8 +14,7 @@ # limitations under the License. from dataclasses import dataclass from typing import List, Dict - -from evadb.expression.comparison_expression import ComparisonExpression +from evadb.expression.abstract_expression import AbstractExpression @dataclass @@ -29,7 +28,7 @@ class FeaturePayload: class VectorIndexQuery: embedding: List[float] top_k: int - filter_expr_str: ComparisonExpression + filter_expr_str: AbstractExpression @dataclass