Skip to content

Commit

Permalink
Make Logic Operators Case Insensitve (#1352)
Browse files Browse the repository at this point in the history
- Fix the following queries:

```
SELECT * FROM postgres_data.home_rentals where neighborhood='downtown' and number_of_rooms=2;
```

- Improve the error message: Instead of throwing arbitrary mask error,
now we raise `Unsupported Logical Operator: ...`.
  • Loading branch information
xzdandy authored Nov 14, 2023
1 parent 1c78b22 commit 995f7fe
Show file tree
Hide file tree
Showing 3 changed files with 78 additions and 3 deletions.
7 changes: 5 additions & 2 deletions evadb/parser/lark_visitor/_expressions.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
from evadb.expression.comparison_expression import ComparisonExpression
from evadb.expression.constant_value_expression import ConstantValueExpression
from evadb.expression.logical_expression import LogicalExpression
from evadb.utils.generic_utils import string_comparison_case_insensitive


##################################################################
Expand Down Expand Up @@ -101,10 +102,12 @@ def comparison_operator(self, tree):
def logical_operator(self, tree):
op = str(tree.children[0])

if op == "OR":
if string_comparison_case_insensitive(op, "OR"):
return ExpressionType.LOGICAL_OR
elif op == "AND":
elif string_comparison_case_insensitive(op, "AND"):
return ExpressionType.LOGICAL_AND
else:
raise NotImplementedError("Unsupported logical operator: {}".format(op))

def expressions_with_defaults(self, tree):
expr_list = []
Expand Down
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,7 +80,7 @@ def read(path, encoding="utf-8"):
"sentence-transformers",
"protobuf",
"bs4",
"openai>=0.27.4", # CHATGPT
"openai==0.28", # CHATGPT
"gpt4all", # PRIVATE GPT
"sentencepiece", # TRANSFORMERS
]
Expand Down
72 changes: 72 additions & 0 deletions test/unit_tests/parser/test_parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
from evadb.expression.comparison_expression import ComparisonExpression
from evadb.expression.constant_value_expression import ConstantValueExpression
from evadb.expression.function_expression import FunctionExpression
from evadb.expression.logical_expression import LogicalExpression
from evadb.expression.tuple_value_expression import TupleValueExpression
from evadb.parser.alias import Alias
from evadb.parser.create_function_statement import CreateFunctionStatement
Expand Down Expand Up @@ -531,6 +532,77 @@ def test_select_statement_class(self):
self.assertEqual(select_stmt_new.from_table, select_stmt.from_table)
self.assertEqual(str(select_stmt_new), str(select_stmt))

def test_select_statement_where_class(self):
"""
Unit test for logical operators in the where clause.
"""

def _verify_select_statement(evadb_statement_list):
self.assertIsInstance(evadb_statement_list, list)
self.assertEqual(len(evadb_statement_list), 1)
self.assertEqual(evadb_statement_list[0].stmt_type, StatementType.SELECT)

select_stmt = evadb_statement_list[0]

# target list
self.assertIsNotNone(select_stmt.target_list)
self.assertEqual(len(select_stmt.target_list), 2)
self.assertEqual(
select_stmt.target_list[0].etype, ExpressionType.TUPLE_VALUE
)
self.assertEqual(select_stmt.target_list[0].name, "CLASS")
self.assertEqual(
select_stmt.target_list[1].etype, ExpressionType.TUPLE_VALUE
)
self.assertEqual(select_stmt.target_list[1].name, "REDNESS")

# from table
self.assertIsNotNone(select_stmt.from_table)
self.assertIsInstance(select_stmt.from_table, TableRef)
self.assertEqual(select_stmt.from_table.table.table_name, "TAIPAI")

# where clause
self.assertIsNotNone(select_stmt.where_clause)
self.assertIsInstance(select_stmt.where_clause, LogicalExpression)
self.assertEqual(select_stmt.where_clause.etype, ExpressionType.LOGICAL_AND)
self.assertEqual(len(select_stmt.where_clause.children), 2)
left = select_stmt.where_clause.children[0]
right = select_stmt.where_clause.children[1]
self.assertEqual(left.etype, ExpressionType.COMPARE_EQUAL)
self.assertEqual(right.etype, ExpressionType.COMPARE_LESSER)

self.assertEqual(len(left.children), 2)
self.assertEqual(left.children[0].etype, ExpressionType.TUPLE_VALUE)
self.assertEqual(left.children[0].name, "CLASS")
self.assertEqual(left.children[1].etype, ExpressionType.CONSTANT_VALUE)
self.assertEqual(left.children[1].value, "VAN")

self.assertEqual(len(right.children), 2)
self.assertEqual(right.children[0].etype, ExpressionType.TUPLE_VALUE)
self.assertEqual(right.children[0].name, "REDNESS")
self.assertEqual(right.children[1].etype, ExpressionType.CONSTANT_VALUE)
self.assertEqual(right.children[1].value, 400)

parser = Parser()
select_query = (
"SELECT CLASS, REDNESS FROM TAIPAI WHERE CLASS = 'VAN' AND REDNESS < 400;"
)
_verify_select_statement(parser.parse(select_query))

# Case insensitive test
select_query = (
"select CLASS, REDNESS from TAIPAI where CLASS = 'VAN' and REDNESS < 400;"
)
_verify_select_statement(parser.parse(select_query))

# Unsupported logical operator
select_query = (
"SELECT CLASS, REDNESS FROM TAIPAI WHERE CLASS = 'VAN' XOR REDNESS < 400;"
)
with self.assertRaises(NotImplementedError) as cm:
parser.parse(select_query)
self.assertEqual(str(cm.exception), "Unsupported logical operator: XOR")

def test_select_statement_groupby_class(self):
"""Testing sample frequency"""

Expand Down

0 comments on commit 995f7fe

Please sign in to comment.