diff --git a/journeys/builder.py b/journeys/builder.py index 857fd33d..f980d233 100644 --- a/journeys/builder.py +++ b/journeys/builder.py @@ -117,8 +117,8 @@ def table_selector_dialog() -> None: st.markdown("
", unsafe_allow_html=True) experimental_features = st.checkbox( - "Enable experimental features (optional)", - help="Checking this box will enable generation of experimental features in the semantic model. If enabling this setting, please ensure that you have the proper parameters set on your Snowflake account. Some features (e.g. joins) are currently in Private Preview and available only to select accounts. Reach out to your account team for access.", + "Enable joins (optional)", + help="Checking this box will enable you to add/edit join paths in your semantic model. If enabling this setting, please ensure that you have the proper parameters set on your Snowflake account. Reach out to your account team for access.", ) st.session_state["experimental_features"] = experimental_features diff --git a/journeys/iteration.py b/journeys/iteration.py index abb698de..ad355320 100644 --- a/journeys/iteration.py +++ b/journeys/iteration.py @@ -629,8 +629,8 @@ def set_up_requirements() -> None: file_name = st.selectbox("File name", options=available_files, index=None) experimental_features = st.checkbox( - "Enable experimental features (optional)", - help="Checking this box will enable generation of experimental features in the semantic model. If enabling this setting, please ensure that you have the proper parameters set on your Snowflake account. Some features (e.g. joins) are currently in Private Preview and available only to select accounts. Reach out to your account team for access.", + "Enable joins (optional)", + help="Checking this box will enable you to add/edit join paths in your semantic model. If enabling this setting, please ensure that you have the proper parameters set on your Snowflake account. Reach out to your account team for access.", ) if st.button( @@ -703,7 +703,9 @@ def show() -> None: return_home_button() if "yaml" not in st.session_state: # Only proceed to download the YAML from stage if we don't have one from the builder flow. - yaml = download_yaml(st.session_state.file_name, st.session_state.snowflake_stage.stage_name) + yaml = download_yaml( + st.session_state.file_name, st.session_state.snowflake_stage.stage_name + ) st.session_state["yaml"] = yaml st.session_state["semantic_model"] = yaml_to_semantic_model(yaml) if "last_saved_yaml" not in st.session_state: diff --git a/journeys/joins.py b/journeys/joins.py index a6e0e700..2d09d519 100644 --- a/journeys/joins.py +++ b/journeys/joins.py @@ -3,7 +3,14 @@ import streamlit as st from streamlit_extras.row import row +from app_utils.shared_utils import get_snowflake_connection +from semantic_model_generator.data_processing.cte_utils import ( + fully_qualified_table_name, +) from semantic_model_generator.protos import semantic_model_pb2 +from semantic_model_generator.snowflake_utils.snowflake_connector import ( + get_table_primary_keys, +) SUPPORTED_JOIN_TYPES = [ join_type @@ -167,7 +174,6 @@ def relationship_builder( @st.experimental_dialog("Join Builder", width="large") def joins_dialog() -> None: - if "builder_joins" not in st.session_state: # Making a copy of the original relationships list so we can modify freely without affecting the original. st.session_state.builder_joins = st.session_state.semantic_model.relationships[ @@ -210,6 +216,41 @@ def joins_dialog() -> None: ) return + # Populate primary key information for each table in a join relationship. + left_table_object = next( + ( + table + for table in st.session_state.semantic_model.tables + if table.name == relationship.left_table + ) + ) + right_table_object = next( + ( + table + for table in st.session_state.semantic_model.tables + if table.name == relationship.right_table + ) + ) + + with st.spinner("Fetching primary keys..."): + if not left_table_object.primary_key.columns: + primary_keys = get_table_primary_keys( + get_snowflake_connection(), + table_fqn=fully_qualified_table_name( + left_table_object.base_table + ), + ) + left_table_object.primary_key.columns.extend(primary_keys or [""]) + + if not right_table_object.primary_key.columns: + primary_keys = get_table_primary_keys( + get_snowflake_connection(), + table_fqn=fully_qualified_table_name( + right_table_object.base_table + ), + ) + right_table_object.primary_key.columns.extend(primary_keys or [""]) + del st.session_state.semantic_model.relationships[:] st.session_state.semantic_model.relationships.extend( st.session_state.builder_joins diff --git a/partner/looker.py b/partner/looker.py index 903a62e1..fe874670 100644 --- a/partner/looker.py +++ b/partner/looker.py @@ -232,8 +232,8 @@ def set_looker_semantic() -> None: sample_values = input_sample_value_num() experimental_features = st.checkbox( - "Enable experimental features (optional)", - help="Checking this box will enable generation of experimental features in the semantic model. If enabling this setting, please ensure that you have the proper parameters set on your Snowflake account. Some features (e.g. joins) are currently in Private Preview and available only to select accounts. Reach out to your account team for access.", + "Enable joins (optional)", + help="Checking this box will enable you to add/edit join paths in your semantic model. If enabling this setting, please ensure that you have the proper parameters set on your Snowflake account. Reach out to your account team for access.", ) if st.button("Continue", type="primary"): diff --git a/semantic_model_generator/snowflake_utils/snowflake_connector.py b/semantic_model_generator/snowflake_utils/snowflake_connector.py index 815d6668..76b695ef 100644 --- a/semantic_model_generator/snowflake_utils/snowflake_connector.py +++ b/semantic_model_generator/snowflake_utils/snowflake_connector.py @@ -126,6 +126,19 @@ def _get_column_comment( return "" +def get_table_primary_keys( + conn: SnowflakeConnection, + table_fqn: str, +) -> list[str] | None: + query = f"show primary keys in table {table_fqn};" + cursor = conn.cursor() + cursor.execute(query) + primary_keys = cursor.fetchall() + if primary_keys: + return [pk[3] for pk in primary_keys] + return None + + def get_table_representation( conn: SnowflakeConnection, schema_name: str,