From 1257d345bad03d53bffc251df17265d0f51ea748 Mon Sep 17 00:00:00 2001 From: Chris Nivera Date: Thu, 17 Oct 2024 10:21:57 -0700 Subject: [PATCH] autopopulate --- .../data_processing/data_types.py | 1 + semantic_model_generator/generate_model.py | 18 +++++++++++++++++- .../snowflake_utils/snowflake_connector.py | 18 ++++++++++++++++++ 3 files changed, 36 insertions(+), 1 deletion(-) diff --git a/semantic_model_generator/data_processing/data_types.py b/semantic_model_generator/data_processing/data_types.py index e4529332..b7e4ec0b 100644 --- a/semantic_model_generator/data_processing/data_types.py +++ b/semantic_model_generator/data_processing/data_types.py @@ -37,6 +37,7 @@ class Table: id_: int name: str columns: List[Column] + primary_key: Optional[list[str]] = None comment: Optional[str] = ( None # comment field's to save the table comment user specified on the table ) diff --git a/semantic_model_generator/generate_model.py b/semantic_model_generator/generate_model.py index 766bc0c6..dc3afba2 100644 --- a/semantic_model_generator/generate_model.py +++ b/semantic_model_generator/generate_model.py @@ -59,7 +59,7 @@ def _get_placeholder_joins() -> List[semantic_model_pb2.Relationship]: def _raw_table_to_semantic_context_table( - database: str, schema: str, raw_table: data_types.Table + database: str, schema: str, raw_table: data_types.Table, allow_joins: bool = False ) -> semantic_model_pb2.Table: """ Converts a raw table representation to a semantic model table in protobuf format. @@ -68,6 +68,7 @@ def _raw_table_to_semantic_context_table( database (str): The name of the database containing the table. schema (str): The name of the schema containing the table. raw_table (data_types.Table): The raw table object to be transformed. + allow_joins (bool): Whether joins are enabled in the semantic model. Returns: semantic_model_pb2.Table: A protobuf representation of the semantic table. 
@@ -146,6 +147,18 @@ def _raw_table_to_semantic_context_table( f"No valid columns found for table {raw_table.name}. Please verify that this table contains column's datatypes not in {OBJECT_DATATYPES}." ) + primary_key = None + if allow_joins: + # Populate the primary key field if we were able to retrieve one during raw table construction. + # If not, leave a placeholder for the user to fill out. + primary_key = semantic_model_pb2.PrimaryKey( + columns=( + raw_table.primary_key + if raw_table.primary_key + else [_PLACEHOLDER_COMMENT] + ) + ) + return semantic_model_pb2.Table( name=raw_table.name, base_table=semantic_model_pb2.FullyQualifiedTable( @@ -157,6 +170,7 @@ def _raw_table_to_semantic_context_table( dimensions=dimensions, time_dimensions=time_dimensions, measures=measures, + primary_key=primary_key, ) @@ -222,11 +236,13 @@ def raw_schema_to_semantic_context( ndv_per_column=n_sample_values, # number of sample values to pull per column. columns_df=valid_columns_df_this_table, max_workers=1, + allow_joins=allow_joins, ) table_object = _raw_table_to_semantic_context_table( database=fqn_table.database, schema=fqn_table.schema_name, raw_table=raw_table, + allow_joins=allow_joins, ) table_objects.append(table_object) # TODO(jhilgart): Call cortex model to generate a semantically friendly name here. 
diff --git a/semantic_model_generator/snowflake_utils/snowflake_connector.py b/semantic_model_generator/snowflake_utils/snowflake_connector.py index 815d6668..dbe24440 100644 --- a/semantic_model_generator/snowflake_utils/snowflake_connector.py +++ b/semantic_model_generator/snowflake_utils/snowflake_connector.py @@ -126,6 +126,18 @@ def _get_column_comment( return "" +def _get_table_primary_keys( + conn: SnowflakeConnection, schema_name: str, table_name: str +) -> list[str] | None: + query = f"show primary keys in table {schema_name}.{table_name};" + cursor = conn.cursor() + cursor.execute(query) + primary_keys = cursor.fetchall() + if primary_keys: + return [pk[3] for pk in primary_keys] + return None + + def get_table_representation( conn: SnowflakeConnection, schema_name: str, @@ -134,6 +146,7 @@ def get_table_representation( ndv_per_column: int, columns_df: pd.DataFrame, max_workers: int, + allow_joins: bool = False, ) -> Table: table_comment = _get_table_comment(conn, schema_name, table_name, columns_df) @@ -159,11 +172,16 @@ def _get_col(col_index: int, column_row: pd.Series) -> Column: index_and_column.append((col_index, column)) columns = [c for _, c in sorted(index_and_column, key=lambda x: x[0])] + primary_keys = ( + _get_table_primary_keys(conn, schema_name, table_name) if allow_joins else None + ) + return Table( id_=table_index, name=table_name, comment=table_comment, columns=columns, + primary_key=primary_keys, )