Skip to content

Commit

Permalink
Drop row_id column when CREATE UDF xxx FROM (SELECT *)
Browse files Browse the repository at this point in the history
And add unittest case for binder
  • Loading branch information
xzdandy committed Aug 22, 2023
1 parent 5969e02 commit 4401071
Show file tree
Hide file tree
Showing 4 changed files with 104 additions and 16 deletions.
16 changes: 16 additions & 0 deletions evadb/binder/binder_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
is_video_table,
)
from evadb.catalog.models.table_catalog import TableCatalogEntry
from evadb.catalog.sql_config import IDENTIFIER_COLUMN

if TYPE_CHECKING:
from evadb.binder.statement_binder_context import StatementBinderContext
Expand Down Expand Up @@ -257,3 +258,18 @@ def get_column_definition_from_select_target_list(
)
)
return binded_col_list


def drop_row_id_from_target_list(
target_list: List[AbstractExpression],
) -> List[AbstractExpression]:
"""
This function is intended to be used by CREATE UDF FROM (SELECT * FROM ...) and CREATE TABLE AS SELECT * FROM ... to exclude the row_id column.
"""
filtered_list = []
for expr in target_list:
if isinstance(expr, TupleValueExpression):
if expr.name == IDENTIFIER_COLUMN:
continue
filtered_list.append(expr)
return filtered_list
5 changes: 5 additions & 0 deletions evadb/binder/statement_binder.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
check_column_name_is_string,
check_groupby_pattern,
check_table_object_is_groupable,
drop_row_id_from_target_list,
extend_star,
get_column_definition_from_select_target_list,
handle_bind_extract_object_function,
Expand Down Expand Up @@ -75,6 +76,10 @@ def _bind_explain_statement(self, node: ExplainStatement):
def _bind_create_udf_statement(self, node: CreateUDFStatement):
if node.query is not None:
self.bind(node.query)
# Drop the automatically generated _row_id column
node.query.target_list = drop_row_id_from_target_list(
node.query.target_list
)
all_column_list = get_column_definition_from_select_target_list(
node.query.target_list
)
Expand Down
82 changes: 81 additions & 1 deletion test/binder/test_statement_binder.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,8 +19,11 @@
from evadb.binder.statement_binder import StatementBinder
from evadb.binder.statement_binder_context import StatementBinderContext
from evadb.catalog.catalog_type import ColumnType, NdArrayType
from evadb.catalog.models.utils import ColumnCatalogEntry
from evadb.catalog.sql_config import IDENTIFIER_COLUMN
from evadb.expression.tuple_value_expression import TupleValueExpression
from evadb.parser.alias import Alias
from evadb.parser.create_statement import ColumnDefinition


class StatementBinderTests(unittest.TestCase):
Expand Down Expand Up @@ -176,7 +179,6 @@ def test_bind_func_expr(self, mock_load_udf_class_from_file):
udf_obj.impl_file_path, udf_obj.name
)
self.assertEqual(func_expr.output_objs, [obj1])
print(str(func_expr.alias))
self.assertEqual(
func_expr.alias,
Alias("func_expr", ["out1"]),
Expand Down Expand Up @@ -307,3 +309,81 @@ def test_bind_create_index(self):
binder._bind_create_index_statement(create_index_statement)
col.array_dimensions = [1, 10]
binder._bind_create_index_statement(create_index_statement)

def test_bind_create_udf_should_raise(self):
with patch.object(StatementBinder, "bind"):
create_udf_statement = MagicMock()
create_udf_statement.query.target_list = []
create_udf_statement.metadata = []
binder = StatementBinder(StatementBinderContext(MagicMock()))
with self.assertRaises(AssertionError):
binder._bind_create_udf_statement(create_udf_statement)

def test_bind_create_udf_should_drop_row_id(self):
with patch.object(StatementBinder, "bind"):
create_udf_statement = MagicMock()
row_id_col_obj = ColumnCatalogEntry(
name=IDENTIFIER_COLUMN,
type=MagicMock(),
array_type=MagicMock(),
array_dimensions=MagicMock(),
)
input_col_obj = ColumnCatalogEntry(
name="input_column",
type=MagicMock(),
array_type=MagicMock(),
array_dimensions=MagicMock(),
)
output_col_obj = ColumnCatalogEntry(
name="predict_column",
type=MagicMock(),
array_type=MagicMock(),
array_dimensions=MagicMock(),
)
create_udf_statement.query.target_list = [
TupleValueExpression(
name=IDENTIFIER_COLUMN, table_alias="a", col_object=row_id_col_obj
),
TupleValueExpression(
name="input_column", table_alias="a", col_object=input_col_obj
),
TupleValueExpression(
name="predict_column", table_alias="a", col_object=output_col_obj
),
]
create_udf_statement.metadata = [("predict", "predict_column")]
binder = StatementBinder(StatementBinderContext(MagicMock()))
binder._bind_create_udf_statement(create_udf_statement)

self.assertEqual(
create_udf_statement.query.target_list,
[
TupleValueExpression(
name="input_column", table_alias="a", col_object=input_col_obj
),
TupleValueExpression(
name="predict_column",
table_alias="a",
col_object=output_col_obj,
),
],
)

expected_inputs = [
ColumnDefinition(
"input_column",
input_col_obj.type,
input_col_obj.array_type,
input_col_obj.array_dimensions,
)
]
expected_outputs = [
ColumnDefinition(
"predict_column_predictions",
output_col_obj.type,
output_col_obj.array_type,
output_col_obj.array_dimensions,
)
]
self.assertEqual(create_udf_statement.inputs, expected_inputs)
self.assertEqual(create_udf_statement.outputs, expected_outputs)
17 changes: 2 additions & 15 deletions test/integration_tests/test_model_train.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,22 +56,9 @@ def tearDownClass(cls):

@ludwig_skip_marker
def test_ludwig_automl(self):
select_query = """
SELECT
number_of_rooms,
number_of_bathrooms,
sqft,
location,
days_on_market,
initial_price,
neighborhood,
rental_price
FROM HomeRentals
"""

create_predict_udf = f"""
create_predict_udf = """
CREATE UDF IF NOT EXISTS PredictHouseRent FROM
({select_query})
( SELECT * FROM HomeRentals )
TYPE Ludwig
'predict' 'rental_price'
'time_limit' 120;
Expand Down

0 comments on commit 4401071

Please sign in to comment.