Skip to content

Commit

Permalink
Address comments and handle renamed columns in aggregate expressions
Browse files Browse the repository at this point in the history
  • Loading branch information
sfc-gh-nsehrawat committed Jul 3, 2024
1 parent 5ade8aa commit 7bc8896
Show file tree
Hide file tree
Showing 2 changed files with 98 additions and 58 deletions.
98 changes: 40 additions & 58 deletions semantic_model_generator/data_processing/cte_utils.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
# TODO: support in quoted columns are not well tested. Need to add more tests.
# TODO: Add tests for quoted columns, which are not well tested today.

import copy
from typing import Dict, List, Optional
Expand Down Expand Up @@ -46,7 +46,7 @@ def is_aggregation_expr(col: semantic_model_pb2.Column) -> bool:
Raises:
ValueError: if expr is not parsable, or if aggregation expressions in non-measure columns.
"""
parsed = sqlglot.parse_one(col.expr, dialect="snowflake")
parsed = sqlglot.parse_one(col.expr, dialect=Snowflake)
agg_func = list(parsed.find_all(sqlglot.expressions.AggFunc))
window = list(parsed.find_all(sqlglot.expressions.Window))
# We've confirmed window functions cannot appear inside aggregate functions
Expand All @@ -62,7 +62,7 @@ def is_aggregation_expr(col: semantic_model_pb2.Column) -> bool:
def _is_physical_table_column(col: semantic_model_pb2.Column) -> bool:
"""Returns whether the column refers to a single raw table column."""
try:
parsed = sqlglot.parse_one(col.expr, dialect="snowflake")
parsed = sqlglot.parse_one(col.expr, dialect=Snowflake)
return isinstance(parsed, sqlglot.expressions.Column)
except Exception as ex:
logger.warning(
Expand All @@ -75,22 +75,6 @@ def _is_identifier_quoted(col_name: str) -> bool:
return '"' in col_name


def _standardize_sf_identifier(col_name: str) -> str:
# If the name is quoted, remove quotes if all cap letter and no spaces within quotes; else return the origin.
if _is_identifier_quoted(col_name):
col_name_stripped = col_name.strip('"')
if col_name_stripped.isupper() and " " not in col_name:
# Return the non-quoted and lower case version.
return col_name_stripped.lower()
else:
# Return the original quoted version.
return col_name

else:
# For non-quoted columns, return the lower case version.
return col_name.lower()


def remove_ltable_cte(sql_w_ltable_cte: str) -> str:
"""Given a sql with prefix'd logical table conversion CTE,
return:
Expand Down Expand Up @@ -172,10 +156,10 @@ def get_all_physical_column_references(
sum(foo) -> [foo]
"""
try:
parsed = sqlglot.parse_one(column.expr, dialect="snowflake")
parsed = sqlglot.parse_one(column.expr, dialect=Snowflake)
col_names = set()
for col in parsed.find_all(sqlglot.expressions.Column):
# TODO(renee): update to use _standardize_sf_identifier to handle quoted columns.
# TODO(renee): Handle quoted columns.
col_name = col.name.lower()
if col.this.quoted:
col_name = col.name
Expand All @@ -185,55 +169,53 @@ def get_all_physical_column_references(
raise ValueError(f"Failed to parse sql expression: {column.expr}. Error: {ex}")


def get_all_column_references_in_table(
def direct_mapping_logical_columns(
table: semantic_model_pb2.Table,
) -> Dict[str, Optional[semantic_model_pb2.Column]]:
"""Returns all physical table columns referenced in the table.
Maps the raw table column name to the semantic context Column that directly references it.
If the raw table column name has no exact semantic context Column, it'll be mapped to None.
) -> list[semantic_model_pb2.Column]:
"""
all_col_references: Dict[str, Optional[semantic_model_pb2.Column]] = {}
for col in table.columns:
col_references = get_all_physical_column_references(column=col)
if _is_physical_table_column(col):
column = col
else:
column = None
for reference in col_references:
if reference in all_col_references:
all_col_references[reference] = all_col_references[reference] or column
else:
all_col_references[reference] = column
return all_col_references
Returns a list of logical columns that map 1:1 to an underlying physical column
(i.e. logical table's expression is simply the physical column name) in this table.
"""
ret: list[semantic_model_pb2.Column] = []
for c in table.columns:
if _is_physical_table_column(c):
ret.append(c)
return ret


def _enrich_column_in_expr_with_aggregation(
table: semantic_model_pb2.Table,
) -> semantic_model_pb2.Table:
"""
Append column mentioned in expr with aggregation, if not listed explicitly in table.columns.
Expands the logical columns of 'table' to include columns mentioned in a logical columns
with an aggregate expression. E.g. for a logical column called CPC with expr sum(cost) / sum(clicks),
adds logical columns for "cost" and "clicks", in not present.
"""
tbl = copy.deepcopy(table)
col_name_to_obj = get_all_column_references_in_table(tbl)
direct_mapping_lcols = [
c.name.lower() for c in direct_mapping_logical_columns(table)
]
cols_to_append = set()
for col in tbl.columns:
for col in table.columns:
if not is_aggregation_expr(col):
continue
for col_name in get_all_physical_column_references(col):
if col_name_to_obj[col_name] is None:
cols_to_append.add(col_name)

for col_name_to_append in cols_to_append:
dest_col = semantic_model_pb2.Column(
name=col_name_to_append, expr=col_name_to_append
)
# Only append when the name is not used; otherwise will cause compilation issue.
if _standardize_sf_identifier(dest_col.name) not in [
_standardize_sf_identifier(col.name) for col in tbl.columns
]:
tbl.columns.append(dest_col)
return tbl
for pcol in get_all_physical_column_references(col):
# If the physical column doesn't have a direct mapping logical column
# with the same name, then we need to add a new logical column for it.
if pcol not in direct_mapping_lcols:
cols_to_append.add(pcol)

original_cols = {col.name.lower(): col.expr for col in table.columns}
ret = copy.deepcopy(table)
for c in cols_to_append:
if c in original_cols:
logger.warning(
f"Not adding a logical column for physical column {c} in table {table.name}, "
f"since this logical column already exists with expression {original_cols[c]}"
)
else:
new_col = semantic_model_pb2.Column(name=c, expr=c)
ret.columns.append(new_col)
return ret


def _generate_non_agg_cte(table: semantic_model_pb2.Table) -> Optional[str]:
Expand Down
58 changes: 58 additions & 0 deletions semantic_model_generator/tests/cte_utils_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -232,6 +232,32 @@ def get_test_table_col_format_w_agg_only() -> semantic_model_pb2.Table:
)


def get_test_table_col_format_agg_and_renaming() -> semantic_model_pb2.Table:
return semantic_model_pb2.Table(
name="t1",
base_table=semantic_model_pb2.FullyQualifiedTable(
database="db", schema="sc", table="t1"
),
columns=[
semantic_model_pb2.Column(
name="cost",
kind=semantic_model_pb2.ColumnKind.measure,
expr="cst",
),
semantic_model_pb2.Column(
name="clicks",
kind=semantic_model_pb2.ColumnKind.measure,
expr="clcks",
),
semantic_model_pb2.Column(
name="cpc",
kind=semantic_model_pb2.ColumnKind.measure,
expr="sum(cst) / sum(clcks)",
),
],
)


class SemanticModelTest(TestCase):
def test_convert_to_column_format(self) -> None:
"""
Expand Down Expand Up @@ -371,6 +397,38 @@ def test_enrich_column_in_expr_with_aggregation(self) -> None:
)
assert got == want

def test_enrich_column_in_expr_with_aggregation_and_renaming(self) -> None:
tbl = get_test_table_col_format_agg_and_renaming()
got = [c for c in _enrich_column_in_expr_with_aggregation(tbl).columns]
want = [
semantic_model_pb2.Column(
name="cost",
kind=semantic_model_pb2.ColumnKind.measure,
expr="cst",
),
semantic_model_pb2.Column(
name="clicks",
kind=semantic_model_pb2.ColumnKind.measure,
expr="clcks",
),
semantic_model_pb2.Column(
name="cpc",
kind=semantic_model_pb2.ColumnKind.measure,
expr="sum(cst) / sum(clcks)",
),
semantic_model_pb2.Column(
name="clcks",
expr="clcks",
),
semantic_model_pb2.Column(
name="cst",
expr="cst",
),
]
got.sort(key=lambda c: c.name.lower())
want.sort(key=lambda c: c.name.lower())
assert got == want

def test_expand_all_logical_tables_as_ctes(self) -> None:
vq = "SELECT * FROM __t2"
ctx = get_test_ctx_col_format()
Expand Down

0 comments on commit 7bc8896

Please sign in to comment.