diff --git a/admin_apps/journeys/iteration.py b/admin_apps/journeys/iteration.py index 00632d55..474f46c4 100644 --- a/admin_apps/journeys/iteration.py +++ b/admin_apps/journeys/iteration.py @@ -6,10 +6,11 @@ import requests import sqlglot import streamlit as st -from snowflake.connector import SnowflakeConnection +from snowflake.connector import ProgrammingError, SnowflakeConnection from streamlit.delta_generator import DeltaGenerator from streamlit_monaco import st_monaco +from admin_apps.journeys.builder import get_available_databases, get_available_schemas from admin_apps.shared_utils import ( GeneratorAppScreen, SnowflakeStage, @@ -38,6 +39,8 @@ SNOWFLAKE_USER, ) from semantic_model_generator.snowflake_utils.snowflake_connector import ( + fetch_stages_in_schema, + fetch_yaml_names_in_stage, set_database, set_schema, ) @@ -368,19 +371,25 @@ def upload_handler(file_name: str) -> None: else: # If coming from the builder flow, we need to ask the user for the exact stage path to upload to. st.markdown("Please enter the destination of your YAML file.") - with st.form("upload_form"): - stage_database = st.text_input("Stage database", value="") - stage_schema = st.text_input("Stage schema", value="") - stage_name = st.text_input("Stage name", value="") - new_name = st.text_input("File name (omit .yaml suffix)", value="") + stage_selector_container() + new_name = st.text_input("File name (omit .yaml suffix)", value="") + + if st.button("Submit Upload"): + if ( + not st.session_state["selected_iteration_database"] + or not st.session_state["selected_iteration_schema"] + or not st.session_state["selected_iteration_stage"] + or not new_name + ): + st.error("Please fill in all fields.") + return - if st.form_submit_button("Submit Upload"): - st.session_state["snowflake_stage"] = SnowflakeStage( - stage_database=stage_database, - stage_schema=stage_schema, - stage_name=stage_name, - ) - upload_handler(new_name) + st.session_state["snowflake_stage"] = SnowflakeStage( + stage_database=st.session_state["selected_iteration_database"], + stage_schema=st.session_state["selected_iteration_schema"], + stage_name=st.session_state["selected_iteration_stage"], + ) + upload_handler(new_name) def update_container( @@ -480,33 +489,125 @@ def yaml_editor(yaml_str: str) -> None: update_container(status_container, "editing", prefix=status_container_title) +@st.cache_resource(show_spinner=False) +def get_available_stages(schema: str) -> List[str]: + """ + Fetches the available stages from the Snowflake account. + + Returns: + List[str]: A list of available stages. + """ + return fetch_stages_in_schema(get_snowflake_connection(), schema) + + +@st.cache_resource(show_spinner=False) +def get_yamls_from_stage(stage: str) -> List[str]: + """ + Fetches the YAML files from the specified stage. + + Args: + stage (str): The name of the stage to fetch the YAML files from. + + Returns: + List[str]: A list of YAML files in the specified stage. + """ + return fetch_yaml_names_in_stage(get_snowflake_connection(), stage) + + +def stage_selector_container() -> None: + """ + Common component that encapsulates db/schema/stage selection for the admin app. + When a db/schema/stage is selected, it is saved to the session state for reading elsewhere. + Returns: None + """ + available_schemas = [] + available_stages = [] + + # First, retrieve all databases that the user has access to. + stage_database = st.selectbox( + "Stage database", + options=get_available_databases(), + index=None, + key="selected_iteration_database", + ) + if stage_database: + # When a valid database is selected, fetch the available schemas in that database. + try: + set_database(get_snowflake_connection(), stage_database) + available_schemas = get_available_schemas(stage_database) + except (ValueError, ProgrammingError): + st.error("Insufficient permissions to read from the selected database.") + st.stop() + + stage_schema = st.selectbox( + "Stage schema", + options=available_schemas, + index=None, + key="selected_iteration_schema", + ) + if stage_schema: + # When a valid schema is selected, fetch the available stages in that schema. + try: + set_schema(get_snowflake_connection(), stage_schema) + available_stages = get_available_stages(stage_schema) + except (ValueError, ProgrammingError): + st.error("Insufficient permissions to read from the selected schema.") + st.stop() + + st.selectbox( + "Stage name", + options=available_stages, + index=None, + key="selected_iteration_stage", + ) + + @st.experimental_dialog("Welcome to the Iteration app! 💬", width="large") def set_up_requirements() -> None: """ Collects existing YAML location from the user so that we can download it. """ - # Otherwise, we should collect the prebuilt YAML location from the user so that we can download it. - with st.form("download_yaml_requirements"): - st.markdown( - "Fill in the Snowflake stage details to download your existing YAML file." - ) - # TODO: Make these dropdown selectors by fetching all dbs/schemas similar to table approach? - stage_database = st.text_input("Stage database", value="") - stage_schema = st.text_input("Stage schema", value="") - stage_name = st.text_input("Stage name", value="") - file_name = st.text_input("File name", value=".yaml") - if st.form_submit_button("Submit"): - st.session_state["snowflake_stage"] = SnowflakeStage( - stage_database=stage_database, - stage_schema=stage_schema, - stage_name=stage_name, + st.markdown( + "Fill in the Snowflake stage details to download your existing YAML file." + ) + + stage_selector_container() + + # Based on the currently selected stage, show a dropdown of YAML files for the user to pick from. + available_files = [] + if ( + "selected_iteration_stage" in st.session_state + and st.session_state["selected_iteration_stage"] + ): + # When a valid stage is selected, fetch the available YAML files in that stage. + try: + available_files = get_yamls_from_stage( + st.session_state["selected_iteration_stage"] ) - st.session_state["account_name"] = SNOWFLAKE_ACCOUNT_LOCATOR - st.session_state["host_name"] = SNOWFLAKE_HOST - st.session_state["user_name"] = SNOWFLAKE_USER - st.session_state["file_name"] = file_name - st.session_state["page"] = GeneratorAppScreen.ITERATION - st.rerun() + except (ValueError, ProgrammingError): + st.error("Insufficient permissions to read from the selected stage.") + st.stop() + + file_name = st.selectbox("File name", options=available_files, index=None) + + if st.button( + "Submit", + disabled=not st.session_state["selected_iteration_database"] + or not st.session_state["selected_iteration_schema"] + or not st.session_state["selected_iteration_stage"] + or not file_name, + ): + st.session_state["snowflake_stage"] = SnowflakeStage( + stage_database=st.session_state["selected_iteration_database"], + stage_schema=st.session_state["selected_iteration_schema"], + stage_name=st.session_state["selected_iteration_stage"], + ) + st.session_state["account_name"] = SNOWFLAKE_ACCOUNT_LOCATOR + st.session_state["host_name"] = SNOWFLAKE_HOST + st.session_state["user_name"] = SNOWFLAKE_USER + st.session_state["file_name"] = file_name + st.session_state["page"] = GeneratorAppScreen.ITERATION + st.rerun() SAVE_HELP = """Save changes to the active semantic model in this app. This is diff --git a/semantic_model_generator/snowflake_utils/snowflake_connector.py b/semantic_model_generator/snowflake_utils/snowflake_connector.py index 683b0eae..c0d0e84b 100644 --- a/semantic_model_generator/snowflake_utils/snowflake_connector.py +++ b/semantic_model_generator/snowflake_utils/snowflake_connector.py @@ -305,6 +305,43 @@ def fetch_tables_views_in_schema( return results +def fetch_stages_in_schema(conn: SnowflakeConnection, schema_name: str) -> list[str]: + """ + Fetches all stages that the current user has access to in the current schema + Args: + conn: SnowflakeConnection to run the query + schema_name: The name of the schema to connect to. + + Returns: a list of fully qualified stage names + """ + + query = f"show stages in schema {schema_name};" + cursor = conn.cursor() + cursor.execute(query) + stages = cursor.fetchall() + + return [f"{result[2]}.{result[3]}.{result[1]}" for result in stages] + + +def fetch_yaml_names_in_stage(conn: SnowflakeConnection, stage_name: str) -> list[str]: + """ + Fetches all yaml files that the current user has access to in the current stage + Args: + conn: SnowflakeConnection to run the query + stage_name: The fully qualified name of the stage to connect to. + + Returns: a list of yaml file names + """ + + query = f"list @{stage_name} pattern='.*\\.yaml';" + cursor = conn.cursor() + cursor.execute(query) + yaml_files = cursor.fetchall() + + # The file name is prefixed with "@{stage_name}/", so we need to remove that prefix. + return [result[0].split("/")[-1] for result in yaml_files] + + def get_valid_schemas_tables_columns_df( conn: SnowflakeConnection, table_schema: Optional[str] = None,