Skip to content

Commit

Permalink
autopopulate
Browse files Browse the repository at this point in the history
  • Loading branch information
sfc-gh-cnivera committed Oct 18, 2024
1 parent 2f54899 commit 1257d34
Show file tree
Hide file tree
Showing 3 changed files with 36 additions and 1 deletion.
1 change: 1 addition & 0 deletions semantic_model_generator/data_processing/data_types.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@ class Table:
id_: int
name: str
columns: List[Column]
primary_key: Optional[list[str]] = None
comment: Optional[str] = (
None # comment field that stores the table comment the user specified on the table
)
Expand Down
18 changes: 17 additions & 1 deletion semantic_model_generator/generate_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,7 @@ def _get_placeholder_joins() -> List[semantic_model_pb2.Relationship]:


def _raw_table_to_semantic_context_table(
database: str, schema: str, raw_table: data_types.Table
database: str, schema: str, raw_table: data_types.Table, allow_joins: bool = False
) -> semantic_model_pb2.Table:
"""
Converts a raw table representation to a semantic model table in protobuf format.
Expand All @@ -68,6 +68,7 @@ def _raw_table_to_semantic_context_table(
database (str): The name of the database containing the table.
schema (str): The name of the schema containing the table.
raw_table (data_types.Table): The raw table object to be transformed.
allow_joins (bool): Whether joins are enabled in the semantic model.
Returns:
semantic_model_pb2.Table: A protobuf representation of the semantic table.
Expand Down Expand Up @@ -146,6 +147,18 @@ def _raw_table_to_semantic_context_table(
f"No valid columns found for table {raw_table.name}. Please verify that this table contains column's datatypes not in {OBJECT_DATATYPES}."
)

primary_key = None
if allow_joins:
# Populate the primary key field if we were able to retrieve one during raw table construction.
# If not, leave a placeholder for the user to fill out.
primary_key = semantic_model_pb2.PrimaryKey(
columns=(
raw_table.primary_key
if raw_table.primary_key
else [_PLACEHOLDER_COMMENT]
)
)

return semantic_model_pb2.Table(
name=raw_table.name,
base_table=semantic_model_pb2.FullyQualifiedTable(
Expand All @@ -157,6 +170,7 @@ def _raw_table_to_semantic_context_table(
dimensions=dimensions,
time_dimensions=time_dimensions,
measures=measures,
primary_key=primary_key,
)


Expand Down Expand Up @@ -222,11 +236,13 @@ def raw_schema_to_semantic_context(
ndv_per_column=n_sample_values, # number of sample values to pull per column.
columns_df=valid_columns_df_this_table,
max_workers=1,
allow_joins=allow_joins,
)
table_object = _raw_table_to_semantic_context_table(
database=fqn_table.database,
schema=fqn_table.schema_name,
raw_table=raw_table,
allow_joins=allow_joins,
)
table_objects.append(table_object)
# TODO(jhilgart): Call cortex model to generate a semantically friendly name here.
Expand Down
18 changes: 18 additions & 0 deletions semantic_model_generator/snowflake_utils/snowflake_connector.py
Original file line number Diff line number Diff line change
Expand Up @@ -126,6 +126,18 @@ def _get_column_comment(
return ""


def _get_table_primary_keys(
    conn: SnowflakeConnection, schema_name: str, table_name: str
) -> list[str] | None:
    """
    Retrieves the names of the primary key columns declared on a table.

    Runs `show primary keys in table <schema>.<table>` and collects the
    `column_name` field of every returned row (one row per key column).

    Args:
        conn (SnowflakeConnection): Connection used to run the query.
        schema_name (str): Schema qualifier, interpolated into the query as-is.
        table_name (str): Table name, interpolated into the query as-is.

    Returns:
        list[str] | None: The primary key column names in the order the rows
        were returned, or None when the query returns no rows.
    """
    # Position of `column_name` in the documented SHOW PRIMARY KEYS output
    # (created_on, database_name, schema_name, table_name, column_name, ...).
    # Position 3 is the table name, not the key column.
    fallback_column_idx = 4

    query = f"show primary keys in table {schema_name}.{table_name};"
    cursor = conn.cursor()
    try:
        cursor.execute(query)
        rows = cursor.fetchall()
        # Grab the result metadata before the cursor is released.
        description = cursor.description or []
    finally:
        cursor.close()

    if not rows:
        return None

    # Prefer locating the field by name so the lookup does not depend on the
    # exact column layout; the first item of each description entry is its name.
    header = [str(col[0]).lower() for col in description]
    column_idx = (
        header.index("column_name")
        if "column_name" in header
        else fallback_column_idx
    )
    return [row[column_idx] for row in rows]


def get_table_representation(
conn: SnowflakeConnection,
schema_name: str,
Expand All @@ -134,6 +146,7 @@ def get_table_representation(
ndv_per_column: int,
columns_df: pd.DataFrame,
max_workers: int,
allow_joins: bool = False,
) -> Table:
table_comment = _get_table_comment(conn, schema_name, table_name, columns_df)

Expand All @@ -159,11 +172,16 @@ def _get_col(col_index: int, column_row: pd.Series) -> Column:
index_and_column.append((col_index, column))
columns = [c for _, c in sorted(index_and_column, key=lambda x: x[0])]

primary_keys = (
_get_table_primary_keys(conn, schema_name, table_name) if allow_joins else None
)

return Table(
id_=table_index,
name=table_name,
comment=table_comment,
columns=columns,
primary_keys=primary_keys,
)


Expand Down

0 comments on commit 1257d34

Please sign in to comment.