Skip to content

Commit

Permalink
Formatting
Browse files (browse the repository at this point in the history)
  • Loading branch information
RichardZhangRZ committed Nov 18, 2023
1 parent ba10889 commit d404ac2
Show file tree
Hide file tree
Showing 9 changed files with 45 additions and 36 deletions.
21 changes: 15 additions & 6 deletions evadb/executor/create_index_executor.py
Original file line number Diff line number Diff line change
Expand Up @@ -102,10 +102,11 @@ def _create_evadb_index(self):
][0]

metadata_col_names = [include_col.name for include_col in self.include_list]
metadata_column_catalog_entries = [col for col in feat_tb_catalog_entry.columns
if col.name in metadata_col_names]


metadata_column_catalog_entries = [
col
for col in feat_tb_catalog_entry.columns
if col.name in metadata_col_names
]

if function_expression is not None:
feat_col_name = function_expression.output_objs[0].name
Expand Down Expand Up @@ -147,12 +148,20 @@ def _create_evadb_index(self):
input_batch.drop_column_alias()
feat = input_batch.column_as_numpy_array(feat_col_name)

metadata = {metadata_col_name: input_batch.column_as_numpy_array(metadata_col_name) for metadata_col_name in metadata_col_names}
metadata = {
metadata_col_name: input_batch.column_as_numpy_array(
metadata_col_name
)
for metadata_col_name in metadata_col_names
}
row_num = input_batch.column_as_numpy_array(ROW_NUM_COLUMN)

for i in range(len(input_batch)):
row_feat = feat[i].reshape(1, -1)
row_metadata = {metadata_col_name: metadata[metadata_col_name][i] for metadata_col_name in metadata_col_names}
row_metadata = {
metadata_col_name: metadata[metadata_col_name][i]
for metadata_col_name in metadata_col_names
}

# Create new index if not exists.
if index is None:
Expand Down
7 changes: 3 additions & 4 deletions evadb/optimizer/operators.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,6 @@
from evadb.catalog.models.table_catalog import TableCatalogEntry
from evadb.catalog.models.utils import IndexCatalogEntry
from evadb.expression.abstract_expression import AbstractExpression
from evadb.expression.comparison_expression import ComparisonExpression
from evadb.expression.constant_value_expression import ConstantValueExpression
from evadb.expression.function_expression import FunctionExpression
from evadb.parser.alias import Alias
Expand Down Expand Up @@ -1126,7 +1125,7 @@ def table_ref(self):
@property
def col_list(self):
return self._col_list

@property
def include_list(self):
return self._include_list
Expand Down Expand Up @@ -1239,7 +1238,7 @@ def __init__(
index: IndexCatalogEntry,
limit_count: ConstantValueExpression,
search_query_expr: FunctionExpression,
filter_expr: ComparisonExpression = None,
filter_expr: AbstractExpression = None,
children: List = None,
):
super().__init__(OperatorType.LOGICAL_VECTOR_INDEX_SCAN, children)
Expand All @@ -1259,7 +1258,7 @@ def limit_count(self):
@property
def search_query_expr(self):
return self._search_query_expr

@property
def filter_expr(self):
return self._filter_expr
Expand Down
20 changes: 10 additions & 10 deletions evadb/optimizer/rules/rules.py
Original file line number Diff line number Diff line change
Expand Up @@ -489,9 +489,9 @@ def apply(self, before: LogicalJoin, context: OptimizerContext):

class CombineWhereSimilarityOrderByAndLimitToFilteredVectorIndexScan(Rule):
"""
This rule currently rewrites Where + Order By + Limit to a filtered vector
This rule currently rewrites Where + Order By + Limit to a filtered vector
index scan. Because vector index only works for similarity search, the rule will
only be applied when the Order By is on Similarity expression.
only be applied when the Order By is on Similarity expression.
Limit(10)
|
Expand All @@ -510,15 +510,18 @@ def __init__(self):
orderby_pattern.append_child(where_pattern)
pattern.append_child(orderby_pattern)
super().__init__(
RuleType.COMBINE_WHERE_SIMILARITY_ORDERBY_AND_LIMIT_TO_FILTERED_VECTOR_INDEX_SCAN, pattern
RuleType.COMBINE_WHERE_SIMILARITY_ORDERBY_AND_LIMIT_TO_FILTERED_VECTOR_INDEX_SCAN,
pattern,
)

# Entries populate after rule eligibility validation.
self._index_catalog_entry = None
self._query_func_expr = None

def promise(self):
return Promise.COMBINE_WHERE_SIMILARITY_ORDERBY_AND_LIMIT_TO_FILTERED_VECTOR_INDEX_SCAN
return (
Promise.COMBINE_WHERE_SIMILARITY_ORDERBY_AND_LIMIT_TO_FILTERED_VECTOR_INDEX_SCAN
)

def check(self, before: LogicalLimit, context: OptimizerContext):
return True
Expand Down Expand Up @@ -597,10 +600,7 @@ def apply(self, before: LogicalLimit, context: OptimizerContext):

# Construct the Vector index scan plan (with filter condition).
vector_index_scan_node = LogicalVectorIndexScan(
index_catalog_entry,
limit_node.limit_count,
query_func_expr,
filter_expr
index_catalog_entry, limit_node.limit_count, query_func_expr, filter_expr
)
for child in orderby_node.children:
vector_index_scan_node.append_child(child)
Expand Down Expand Up @@ -1421,7 +1421,7 @@ def apply(self, before: LogicalVectorIndexScan, context: OptimizerContext):
before.index,
before.limit_count,
before.search_query_expr,
before.filter_expr
before.filter_expr,
)
for child in before.children:
after.append_child(child)
Expand Down
8 changes: 5 additions & 3 deletions evadb/parser/create_index_statement.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ def __init__(
self._if_not_exists = if_not_exists
self._table_ref = table_ref
self._col_list = col_list
self._include_list = include_list;
self._include_list = include_list
self._vector_store_type = vector_store_type
self._project_expr_list = project_expr_list

Expand Down Expand Up @@ -70,7 +70,9 @@ def traverse_create_function_expression_str(expr):

print_str += f" ({traverse_create_function_expression_str(function_expr)})"
if len(self.include_list) > 0:
print_str += f" INCLUDE {','.join([col_def.name for col_def in self.include_list])}"
print_str += (
f" INCLUDE {','.join([col_def.name for col_def in self.include_list])}"
)
print_str += f" USING {self._vector_store_type};"
return print_str

Expand All @@ -89,7 +91,7 @@ def table_ref(self):
@property
def col_list(self):
return self._col_list

@property
def include_list(self):
return self._include_list
Expand Down
2 changes: 1 addition & 1 deletion evadb/parser/lark_visitor/_create_statements.py
Original file line number Diff line number Diff line change
Expand Up @@ -274,7 +274,7 @@ def create_index(self, tree):
index_elem = [index_elem]
else:
project_expr_list += index_elem

if include_elem:
project_expr_list += include_elem
# Add tv_expr for projected columns.
Expand Down
2 changes: 1 addition & 1 deletion evadb/plan_nodes/create_index_plan.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,7 @@ def table_ref(self):
@property
def col_list(self):
return self._col_list

@property
def include_list(self):
return self._include_list
Expand Down
6 changes: 3 additions & 3 deletions evadb/plan_nodes/vector_index_scan_plan.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from evadb.catalog.models.utils import IndexCatalogEntry
from evadb.expression.comparison_expression import ComparisonExpression
from evadb.expression.abstract_expression import AbstractExpression
from evadb.expression.constant_value_expression import ConstantValueExpression
from evadb.expression.function_expression import FunctionExpression
from evadb.plan_nodes.abstract_plan import AbstractPlan
Expand All @@ -37,7 +37,7 @@ def __init__(
index: IndexCatalogEntry,
limit_count: ConstantValueExpression,
search_query_expr: FunctionExpression,
filter_expr: ComparisonExpression = None
filter_expr: AbstractExpression = None,
):
super().__init__(PlanOprType.VECTOR_INDEX_SCAN)
self._index = index
Expand All @@ -56,7 +56,7 @@ def limit_count(self):
@property
def search_query_expr(self):
return self._search_query_expr

@property
def filter_expr(self):
return self._filter_expr
Expand Down
10 changes: 5 additions & 5 deletions evadb/third_party/vector_stores/milvus.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,13 +84,13 @@ def etype_to_milvus_symbol(etype: ExpressionType):
elif etype == ExpressionType.LOGICAL_NOT:
return "not"
elif etype == ExpressionType.ARITHMETIC_ADD:
return '+'
return "+"
elif etype == ExpressionType.ARITHMETIC_DIVIDE:
return '/'
return "/"
elif etype == ExpressionType.ARITHMETIC_MULTIPLY:
return '*'
return "*"
elif etype == ExpressionType.ARITHMETIC_SUBTRACT:
return '-'
return "-"


def expression_to_milvus_expr(expr: AbstractExpression):
Expand All @@ -107,7 +107,7 @@ def expression_to_milvus_expr(expr: AbstractExpression):
return expr.value
elif isinstance(expr, TupleValueExpression):
return expr.name


class MilvusVectorStore(VectorStore):
def __init__(self, index_name: str, **kwargs) -> None:
Expand Down
5 changes: 2 additions & 3 deletions evadb/third_party/vector_stores/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,8 +14,7 @@
# limitations under the License.
from dataclasses import dataclass
from typing import List, Dict

from evadb.expression.comparison_expression import ComparisonExpression
from evadb.expression.abstract_expression import AbstractExpression


@dataclass
Expand All @@ -29,7 +28,7 @@ class FeaturePayload:
class VectorIndexQuery:
embedding: List[float]
top_k: int
filter_expr_str: ComparisonExpression
filter_expr_str: AbstractExpression


@dataclass
Expand Down

0 comments on commit d404ac2

Please sign in to comment.