Skip to content

Commit

Permalink
Support SELECT Func(*) FROM Table. (#946)
Browse files Browse the repository at this point in the history
This aims to improve the user experience of using a trained model for
inference in EvaDB, when the training involves a large number of
columns.

👋 Thanks for submitting a Pull Request to EvaDB!

🙌 We want to make contributing to EvaDB as easy and transparent as
possible. Here are a few tips to get you started:

- 🔍 Search existing EvaDB
[PRs](https://github.com/georgia-tech-db/eva/pulls) to see if a similar
PR already exists.
- 🔗 Link this PR to an EvaDB
[issue](https://github.com/georgia-tech-db/eva/issues) to help us
understand what bug fix or feature is being implemented.
- 📈 Provide before and after profiling results to help us quantify the
improvement your PR provides (if applicable).

👉 Please see our ✅ [Contributing
Guide](https://evadb.readthedocs.io/en/stable/source/contribute/index.html)
for more details.
  • Loading branch information
xzdandy authored Aug 26, 2023
1 parent c6a9125 commit 9e268b3
Show file tree
Hide file tree
Showing 7 changed files with 76 additions and 10 deletions.
6 changes: 6 additions & 0 deletions docs/source/reference/udfs/model-train.rst
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,12 @@ You can also simply give all other columns in `HomeRentals` as inputs and let th
SELECT PredictHouseRent(sqft, location) FROM HomeRentals;
You can also simply give all columns in `HomeRentals` as inputs for inference. The customized UDF with the underlying model can figure out the proper inference columns via the training columns.

.. code-block:: sql
SELECT PredictHouseRent(*) FROM HomeRentals;
Check out our `Integration Tests <https://github.com/georgia-tech-db/evadb/blob/master/test/integration_tests/test_model_train.py>`_ for a working example.


8 changes: 8 additions & 0 deletions evadb/binder/statement_binder.py
Original file line number Diff line number Diff line change
Expand Up @@ -272,6 +272,14 @@ def _bind_func_expr(self, node: FunctionExpression):
if node.name.upper() == str(UDFType.EXTRACT_OBJECT):
handle_bind_extract_object_function(node, self)
return

# Handle Func(*)
if (
len(node.children) == 1
and isinstance(node.children[0], TupleValueExpression)
and node.children[0].name == "*"
):
node.children = extend_star(self._binder_context)
# bind all the children
for child in node.children:
self.bind(child)
Expand Down
2 changes: 1 addition & 1 deletion evadb/parser/evadb.lark
Original file line number Diff line number Diff line change
Expand Up @@ -270,7 +270,7 @@ if_not_exists: IF NOT EXISTS
function_call: udf_function ->udf_function_call
| aggregate_windowed_function ->aggregate_function_call

udf_function: simple_id "(" function_args ")" dotted_id?
udf_function: simple_id "(" (STAR | function_args) ")" dotted_id?

aggregate_windowed_function: aggregate_function_name "(" function_arg ")"
| COUNT "(" (STAR | function_arg) ")"
Expand Down
3 changes: 3 additions & 0 deletions evadb/parser/lark_visitor/_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,9 @@ def udf_function(self, tree):
udf_args = None

for child in tree.children:
if isinstance(child, Token):
if child.value == "*":
udf_args = [TupleValueExpression(name="*")]
if isinstance(child, Tree):
if child.data == "simple_id":
udf_name = self.visit(child)
Expand Down
27 changes: 27 additions & 0 deletions test/binder/test_statement_binder.py
Original file line number Diff line number Diff line change
Expand Up @@ -142,6 +142,33 @@ def test_bind_explain_statement(self):
binder._bind_explain_statement(stmt)
mock_binder.assert_called_with(stmt.explainable_stmt)

def test_bind_func_expr_with_star(self):
    """Check that a lone ``*`` argument is replaced by column expressions.

    The binder context is mocked to report two (alias, column) pairs; the
    test expects ``bind`` to be invoked once per pair, each time with the
    matching ``TupleValueExpression``.
    """
    # NOTE: ``name=`` on a MagicMock only labels the mock; ``.name`` is
    # therefore an auto-created child mock, configured on the next line.
    star_func = MagicMock(
        name="func_expr", alias=Alias("func_expr"), output_col_aliases=[]
    )
    star_func.name.lower.return_value = "func_expr"
    star_func.children = [TupleValueExpression(name="*")]

    first_col = ("T", "col1")
    second_col = ("T", "col2")

    context = MagicMock()
    # The catalog lookup is mocked to find nothing; the test expects
    # _bind_func_expr to raise BinderError in this setup.
    context._catalog.return_value.get_udf_catalog_entry_by_name.return_value = None
    context._get_all_alias_and_col_name.return_value = [first_col, second_col]

    with patch.object(StatementBinder, "bind") as mock_binder:
        binder = StatementBinder(context)
        with self.assertRaises(BinderError):
            binder._bind_func_expr(star_func)

        # Kept outside the assertRaises block so the checks actually run
        # after the expected error. Unpacking enforces exactly two calls.
        first_call, second_call = mock_binder.call_args_list
        recorded = (first_call, second_call)
        expected = (first_col, second_col)
        for bind_call, (alias, column) in zip(recorded, expected):
            self.assertEqual(
                bind_call.args[0],
                TupleValueExpression(name=column, table_alias=alias),
            )

@patch("evadb.binder.statement_binder.load_udf_class_from_file")
def test_bind_func_expr(self, mock_load_udf_class_from_file):
# setup
Expand Down
10 changes: 1 addition & 9 deletions test/integration_tests/test_model_train.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,15 +66,7 @@ def test_ludwig_automl(self):
execute_query_fetch_all(self.evadb, create_predict_udf)

predict_query = """
SELECT PredictHouseRent(
number_of_rooms,
number_of_bathrooms,
sqft,
location,
days_on_market,
initial_price,
neighborhood
) FROM HomeRentals LIMIT 10;
SELECT PredictHouseRent(*) FROM HomeRentals LIMIT 10;
"""
result = execute_query_fetch_all(self.evadb, predict_query)
self.assertEqual(len(result.columns), 1)
Expand Down
30 changes: 30 additions & 0 deletions test/parser/test_parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -567,6 +567,36 @@ def test_select_statement_sample_class(self):
# sample_freq
self.assertEqual(select_stmt.from_table.sample_freq, ConstantValueExpression(5))

def test_select_udf_star(self):
    """Parse ``SELECT DemoUDF(*) ...`` and inspect the resulting statement.

    Verifies that the parser yields a single SELECT statement whose only
    target is a function expression with one ``*`` tuple-value child, and
    that the FROM clause carries the database and table names.
    """
    statements = Parser().parse("SELECT DemoUDF(*) FROM DemoDB.DemoTable")

    # Statement list shape and type.
    self.assertIsInstance(statements, list)
    self.assertEqual(len(statements), 1)
    self.assertEqual(statements[0].stmt_type, StatementType.SELECT)
    stmt = statements[0]

    # Target list: exactly one function call taking the star argument.
    targets = stmt.target_list
    self.assertIsNotNone(targets)
    self.assertEqual(len(targets), 1)
    udf_call = targets[0]
    self.assertEqual(udf_call.etype, ExpressionType.FUNCTION_EXPRESSION)
    self.assertEqual(len(udf_call.children), 1)
    star_arg = udf_call.children[0]
    self.assertEqual(star_arg.etype, ExpressionType.TUPLE_VALUE)
    self.assertEqual(star_arg.name, "*")

    # FROM clause.
    source = stmt.from_table
    self.assertIsNotNone(source)
    self.assertIsInstance(source, TableRef)
    self.assertEqual(source.table.table_name, "DemoTable")
    self.assertEqual(source.table.database_name, "DemoDB")

def test_table_ref(self):
"""Testing table info in TableRef
Class: TableInfo
Expand Down

0 comments on commit 9e268b3

Please sign in to comment.