diff --git a/semantic_model_generator/data_processing/cte_utils.py b/semantic_model_generator/data_processing/cte_utils.py index 66d4b840..8de7418e 100644 --- a/semantic_model_generator/data_processing/cte_utils.py +++ b/semantic_model_generator/data_processing/cte_utils.py @@ -93,7 +93,8 @@ def remove_ltable_cte(sql_w_ltable_cte: str, table_names: list[str]) -> str: raise ValueError("Analyst queries must contain the logical CTE.") table_names_lower = [table_name.lower() for table_name in table_names] - # Iterate through all CTEs. If a CTE starts with the logical table prefix and matches a table name, remove it. + # Iterate through all CTEs, and filter out logical table CTEs. + # This is done by checking if the CTE alias starts with the logical table prefix and if the alias is in a table in the semantic model. non_logical_cte = [ cte for cte in with_.expressions diff --git a/semantic_model_generator/data_processing/cte_utils_test.py b/semantic_model_generator/data_processing/cte_utils_test.py index 9d90aa88..e3bc80ba 100644 --- a/semantic_model_generator/data_processing/cte_utils_test.py +++ b/semantic_model_generator/data_processing/cte_utils_test.py @@ -26,7 +26,7 @@ def test_does_not_remove_non_logical_cte(self) -> None: query = ( "WITH __other_table AS (SELECT * FROM table1) SELECT * FROM __other_table" ) - table_names = ["logical_table"] + table_names = ["LOGICAL_TABLE"] expected_query = ( "WITH __other_table AS ( SELECT * FROM table1 ) SELECT * FROM __other_table" ) @@ -38,13 +38,11 @@ def test_does_not_remove_non_logical_cte(self) -> None: def test_mixed_ctes(self) -> None: """ - Testing that in a query with a mixture of CTEs for logical tables and other tables, only the logical table CTEs are removed. - Returns: - + Given a query containing a mixture of CTEs, only the logical table CTEs should be removed. """ - query = "WITH __logical_table AS (SELECT * FROM table1), __other_table AS (SELECT * FROM table2) SELECT * FROM __logical_table" - table_names = ["logical_table"] - expected_query = "WITH __other_table AS ( SELECT * FROM table2 ) SELECT * FROM __logical_table" + query = "WITH __logical_table AS (SELECT * FROM table1), __other_table AS (SELECT * FROM table2), __custom_table AS (SELECT * FROM table3) SELECT * FROM __logical_table" + table_names = ["LOGICAL_TABLE"] + expected_query = "WITH __other_table AS ( SELECT * FROM table2 ), __custom_table AS ( SELECT * FROM table3 ) SELECT * FROM __logical_table" actual_output = remove_ltable_cte(query, table_names=table_names) actual_output = re.sub(r"\s+", " ", actual_output) @@ -56,7 +54,7 @@ def test_throws_value_error_without_cte(self) -> None: Testing that an error is thrown if there is no CTE in the query. """ query = "SELECT * FROM table1" - table_names = ["logical_table"] + table_names = ["LOGICAL_TABLE"] with pytest.raises(ValueError): remove_ltable_cte(query, table_names=table_names) @@ -66,7 +64,7 @@ def test_throws_value_error_if_first_cte_not_logical_table(self) -> None: Testing that an error is thrown if the first CTE is not a logical table. """ query = "WITH random_alias AS (SELECT * FROM table1), __logical_table AS (SELECT * FROM table2) SELECT * FROM __logical_table" - table_names = ["logical_table"] + table_names = ["LOGICAL_TABLE"] with pytest.raises(ValueError): remove_ltable_cte(query, table_names=table_names)