Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Dropdown selectors for iteration flow #121

Merged
merged 6 commits into from
Aug 1, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
169 changes: 135 additions & 34 deletions admin_apps/journeys/iteration.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,10 +6,11 @@
import requests
import sqlglot
import streamlit as st
from snowflake.connector import SnowflakeConnection
from snowflake.connector import ProgrammingError, SnowflakeConnection
from streamlit.delta_generator import DeltaGenerator
from streamlit_monaco import st_monaco

from admin_apps.journeys.builder import get_available_databases, get_available_schemas
from admin_apps.shared_utils import (
GeneratorAppScreen,
SnowflakeStage,
Expand Down Expand Up @@ -38,6 +39,8 @@
SNOWFLAKE_USER,
)
from semantic_model_generator.snowflake_utils.snowflake_connector import (
fetch_stages_in_schema,
fetch_yaml_names_in_stage,
set_database,
set_schema,
)
Expand Down Expand Up @@ -368,19 +371,25 @@ def upload_handler(file_name: str) -> None:
else:
# If coming from the builder flow, we need to ask the user for the exact stage path to upload to.
st.markdown("Please enter the destination of your YAML file.")
with st.form("upload_form"):
stage_database = st.text_input("Stage database", value="")
stage_schema = st.text_input("Stage schema", value="")
stage_name = st.text_input("Stage name", value="")
new_name = st.text_input("File name (omit .yaml suffix)", value="")
stage_selector_container()
new_name = st.text_input("File name (omit .yaml suffix)", value="")

if st.button("Submit Upload"):
if (
not st.session_state["selected_iteration_database"]
or not st.session_state["selected_iteration_schema"]
or not st.session_state["selected_iteration_stage"]
or not new_name
):
st.error("Please fill in all fields.")
return

if st.form_submit_button("Submit Upload"):
st.session_state["snowflake_stage"] = SnowflakeStage(
stage_database=stage_database,
stage_schema=stage_schema,
stage_name=stage_name,
)
upload_handler(new_name)
st.session_state["snowflake_stage"] = SnowflakeStage(
stage_database=st.session_state["selected_iteration_database"],
stage_schema=st.session_state["selected_iteration_schema"],
stage_name=st.session_state["selected_iteration_stage"],
)
upload_handler(new_name)


def update_container(
Expand Down Expand Up @@ -480,33 +489,125 @@ def yaml_editor(yaml_str: str) -> None:
update_container(status_container, "editing", prefix=status_container_title)


@st.cache_resource(show_spinner=False)
def get_available_stages(schema: str) -> List[str]:
"""
Fetches the available stages from the Snowflake account.
Returns:
List[str]: A list of available stages.
"""
return fetch_stages_in_schema(get_snowflake_connection(), schema)


@st.cache_resource(show_spinner=False)
def get_yamls_from_stage(stage: str) -> List[str]:
"""
Fetches the YAML files from the specified stage.
Args:
stage (str): The name of the stage to fetch the YAML files from.
Returns:
List[str]: A list of YAML files in the specified stage.
"""
return fetch_yaml_names_in_stage(get_snowflake_connection(), stage)


def stage_selector_container() -> None:
"""
Common component that encapsulates db/schema/stage selection for the admin app.
When a db/schema/stage is selected, it is saved to the session state for reading elsewhere.
Returns: None
"""
available_schemas = []
available_stages = []

# First, retrieve all databases that the user has access to.
stage_database = st.selectbox(
"Stage database",
options=get_available_databases(),
index=None,
key="selected_iteration_database",
)
if stage_database:
# When a valid database is selected, fetch the available schemas in that database.
try:
set_database(get_snowflake_connection(), stage_database)
available_schemas = get_available_schemas(stage_database)
except (ValueError, ProgrammingError):
st.error("Insufficient permissions to read from the selected database.")
st.stop()

stage_schema = st.selectbox(
"Stage schema",
options=available_schemas,
index=None,
key="selected_iteration_schema",
)
if stage_schema:
# When a valid schema is selected, fetch the available stages in that schema.
try:
set_schema(get_snowflake_connection(), stage_schema)
available_stages = get_available_stages(stage_schema)
except (ValueError, ProgrammingError):
st.error("Insufficient permissions to read from the selected schema.")
st.stop()

st.selectbox(
"Stage name",
options=available_stages,
index=None,
key="selected_iteration_stage",
)


@st.experimental_dialog("Welcome to the Iteration app! 💬", width="large")
def set_up_requirements() -> None:
"""
Collects existing YAML location from the user so that we can download it.
"""
# Otherwise, we should collect the prebuilt YAML location from the user so that we can download it.
with st.form("download_yaml_requirements"):
st.markdown(
"Fill in the Snowflake stage details to download your existing YAML file."
)
# TODO: Make these dropdown selectors by fetching all dbs/schemas similar to table approach?
stage_database = st.text_input("Stage database", value="")
stage_schema = st.text_input("Stage schema", value="")
stage_name = st.text_input("Stage name", value="")
file_name = st.text_input("File name", value="<your_file>.yaml")
if st.form_submit_button("Submit"):
st.session_state["snowflake_stage"] = SnowflakeStage(
stage_database=stage_database,
stage_schema=stage_schema,
stage_name=stage_name,
st.markdown(
"Fill in the Snowflake stage details to download your existing YAML file."
)

stage_selector_container()

# Based on the currently selected stage, show a dropdown of YAML files for the user to pick from.
available_files = []
if (
"selected_iteration_stage" in st.session_state
and st.session_state["selected_iteration_stage"]
):
# When a valid stage is selected, fetch the available YAML files in that stage.
try:
available_files = get_yamls_from_stage(
st.session_state["selected_iteration_stage"]
)
st.session_state["account_name"] = SNOWFLAKE_ACCOUNT_LOCATOR
st.session_state["host_name"] = SNOWFLAKE_HOST
st.session_state["user_name"] = SNOWFLAKE_USER
st.session_state["file_name"] = file_name
st.session_state["page"] = GeneratorAppScreen.ITERATION
st.rerun()
except (ValueError, ProgrammingError):
st.error("Insufficient permissions to read from the selected stage.")
st.stop()

file_name = st.selectbox("File name", options=available_files, index=None)

if st.button(
"Submit",
disabled=not st.session_state["selected_iteration_database"]
or not st.session_state["selected_iteration_schema"]
or not st.session_state["selected_iteration_stage"]
or not file_name,
):
st.session_state["snowflake_stage"] = SnowflakeStage(
stage_database=st.session_state["selected_iteration_database"],
stage_schema=st.session_state["selected_iteration_schema"],
stage_name=st.session_state["selected_iteration_stage"],
)
st.session_state["account_name"] = SNOWFLAKE_ACCOUNT_LOCATOR
st.session_state["host_name"] = SNOWFLAKE_HOST
st.session_state["user_name"] = SNOWFLAKE_USER
st.session_state["file_name"] = file_name
st.session_state["page"] = GeneratorAppScreen.ITERATION
st.rerun()


SAVE_HELP = """Save changes to the active semantic model in this app. This is
Expand Down
37 changes: 37 additions & 0 deletions semantic_model_generator/snowflake_utils/snowflake_connector.py
Original file line number Diff line number Diff line change
Expand Up @@ -305,6 +305,43 @@ def fetch_tables_views_in_schema(
return results


def fetch_stages_in_schema(conn: SnowflakeConnection, schema_name: str) -> list[str]:
"""
Fetches all stages that the current user has access to in the current schema
Args:
conn: SnowflakeConnection to run the query
schema_name: The name of the schema to connect to.

Returns: a list of fully qualified stage names
"""

query = f"show stages in schema {schema_name};"
cursor = conn.cursor()
cursor.execute(query)
stages = cursor.fetchall()

return [f"{result[2]}.{result[3]}.{result[1]}" for result in stages]


def fetch_yaml_names_in_stage(conn: SnowflakeConnection, stage_name: str) -> list[str]:
"""
Fetches all yaml files that the current user has access to in the current stage
Args:
conn: SnowflakeConnection to run the query
stage_name: The fully qualified name of the stage to connect to.

Returns: a list of yaml file names
"""

query = f"list @{stage_name} pattern='.*\\.yaml';"
cursor = conn.cursor()
cursor.execute(query)
yaml_files = cursor.fetchall()

# The file name is prefixed with "@{stage_name}/", so we need to remove that prefix.
return [result[0].split("/")[-1] for result in yaml_files]


def get_valid_schemas_tables_columns_df(
conn: SnowflakeConnection,
table_schema: Optional[str] = None,
Expand Down
Loading