From 73899063c2702191475e7e385c8fcd3f070b213e Mon Sep 17 00:00:00 2001 From: Tom Zayats Date: Fri, 13 Dec 2024 09:42:40 -0800 Subject: [PATCH] fixed model rerun on model changes + added eval_hash for easy select --- journeys/evaluation.py | 75 +++++++++++++++++++++++++++++++++--------- 1 file changed, 60 insertions(+), 15 deletions(-) diff --git a/journeys/evaluation.py b/journeys/evaluation.py index 00eb7c7c..a51456a1 100644 --- a/journeys/evaluation.py +++ b/journeys/evaluation.py @@ -6,7 +6,9 @@ import pandas as pd import snowflake.snowpark._internal.utils as snowpark_utils +import sqlglot import streamlit as st +import yaml from loguru import logger from snowflake.connector.pandas_tools import write_pandas @@ -16,6 +18,7 @@ schema_selector_container, set_sit_query_tag, table_selector_container, + update_last_validated_model, validate_table_exist, validate_table_schema, ) @@ -26,6 +29,7 @@ fetch_table, get_table_hash, ) +from semantic_model_generator.validate_model import validate EVALUATION_TABLE_SCHEMA = { "ID": "VARCHAR", @@ -46,6 +50,7 @@ "MODEL_HASH": "VARCHAR", "SEMANTIC_MODEL_STRING": "VARCHAR", "EVAL_TABLE": "VARCHAR", + "EVAL_HASH": "VARCHAR", } LLM_JUDGE_PROMPT_TEMPLATE = """\ @@ -85,13 +90,27 @@ def visualize_eval_results(frame: pd.DataFrame) -> None: col1, col2 = st.columns(2) + try: + analyst_sql = sqlglot.parse_one(row["ANALYST_SQL"], dialect="snowflake") + analyst_sql = analyst_sql.sql(dialect="snowflake", pretty=True) + except Exception as e: + logger.warning(f"Error parsing analyst SQL: {e} for {row_id}") + analyst_sql = row["ANALYST_SQL"] + + try: + gold_sql = sqlglot.parse_one(row["GOLD_SQL"], dialect="snowflake") + gold_sql = gold_sql.sql(dialect="snowflake", pretty=True) + except Exception as e: + logger.warning(f"Error parsing gold SQL: {e} for {row_id}") + gold_sql = row["GOLD_SQL"] + with col1: st.write("Analyst SQL") - st.code(row["ANALYST_SQL"], language="sql") + st.code(analyst_sql, language="sql") with col2: st.write("Golden SQL") - st.code(row["GOLD_SQL"], language="sql") + st.code(gold_sql, language="sql") col1, col2 = st.columns(2) with col1: @@ -269,6 +288,7 @@ def result_comparisons() -> None: def write_eval_results(frame: pd.DataFrame) -> None: frame_to_write = frame.copy() frame_to_write["TIMESTAMP"] = st.session_state["eval_timestamp"] + frame_to_write["EVAL_HASH"] = st.session_state["eval_hash"] frame_to_write["EVAL_TABLE"] = st.session_state["eval_table"] frame_to_write["EVAL_TABLE_HASH"] = st.session_state["eval_table_hash"] frame_to_write["MODEL_HASH"] = st.session_state["semantic_model_hash"] @@ -550,14 +570,12 @@ def evaluation_data_dialog() -> None: st.session_state["results_eval_table"] = st.session_state[ "selected_results_eval_table" ] - # clear the results table if it exists - if "total_eval_frame" in st.session_state: - del st.session_state["total_eval_frame"] + clear_evaluation_data() st.rerun() -def clear_evaluation_data() -> None: +def clear_evaluation_selection() -> None: session_states = ( "selected_eval_database", "selected_eval_schema", @@ -574,9 +592,21 @@ def clear_evaluation_data() -> None: del st.session_state[feature] +def clear_evaluation_data() -> None: + session_states = ( + "total_eval_frame", + "eval_accuracy", + "analyst_results_frame", + "query_results_frame", + ) + for feature in session_states: + if feature in st.session_state: + del st.session_state[feature] + + def evaluation_mode_show() -> None: - if st.button("Select Evaluation Tables", on_click=clear_evaluation_data): + if st.button("Select Evaluation Tables", on_click=clear_evaluation_selection): evaluation_data_dialog() st.write( @@ -609,13 +639,14 @@ def evaluation_mode_show() -> None: [ ["Evaluation Table Hash", st.session_state["eval_table_hash"]], ["Semantic Model Hash", st.session_state["semantic_model_hash"]], + ["Evaluation Run Hash", st.session_state["eval_hash"]], ["Timestamp", st.session_state["eval_timestamp"]], ["Accuracy", f"{st.session_state['eval_accuracy']:.2f}%"], ], columns=["Summary Statistic", "Value"], ) if model_changed_test: - st.write("Model has changed since last evaluation run.") + st.warning("Model has changed since last evaluation run.") st.markdown("#### Previous Evaluation Run Summary") else: st.markdown("#### Current Evaluation Run Summary") @@ -630,19 +661,33 @@ def run_evaluation() -> None: action="evaluation_run", ) current_hash = generate_hash(st.session_state["working_yml"]) - model_changed_test = ("semantic_model_hash" in st.session_state) and ( + model_changed_test = ("semantic_model_hash" not in st.session_state) or ( current_hash != st.session_state["semantic_model_hash"] ) - if ( - "validated" in st.session_state and not st.session_state["validated"] - ) or model_changed_test: - st.error("Please validate your semantic model before evaluating.") + placeholder = st.empty() + + if not model_changed_test and "total_eval_frame" in st.session_state: + placeholder.write("Model has not changed since last evaluation run.") return - if "total_eval_frame" in st.session_state: - del st.session_state["total_eval_frame"] + + if not st.session_state.validated or model_changed_test: + placeholder.write("Validating model...") + try: + # try loading the yaml + _ = yaml.safe_load(st.session_state["working_yml"]) + # try validating the yaml using analyst + validate(st.session_state["working_yml"], get_snowflake_connection()) + st.session_state.validated = True + update_last_validated_model() + except Exception as e: + placeholder.error(f"Could not validate model ❌ with error: {e}") + return + placeholder.write("Model validated ✅") + clear_evaluation_data() st.session_state["semantic_model_hash"] = current_hash st.write("Running evaluation...") st.session_state["eval_timestamp"] = time.strftime("%Y-%m-%d %H:%M:%S") + st.session_state["eval_hash"] = generate_hash(st.session_state["eval_timestamp"]) send_analyst_requests() run_sql_queries() result_comparisons()