Skip to content

Commit

Permalink
fixed model rerun on model changes + added eval_hash for easy select
Browse files (browse the repository at this point in the history)
  • Loading branch information
sfc-gh-tzayats committed Dec 13, 2024
1 parent ee90cca commit 7389906
Showing 1 changed file with 60 additions and 15 deletions.
75 changes: 60 additions & 15 deletions journeys/evaluation.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,9 @@

import pandas as pd
import snowflake.snowpark._internal.utils as snowpark_utils
import sqlglot
import streamlit as st
import yaml
from loguru import logger
from snowflake.connector.pandas_tools import write_pandas

Expand All @@ -16,6 +18,7 @@
schema_selector_container,
set_sit_query_tag,
table_selector_container,
update_last_validated_model,
validate_table_exist,
validate_table_schema,
)
Expand All @@ -26,6 +29,7 @@
fetch_table,
get_table_hash,
)
from semantic_model_generator.validate_model import validate

EVALUATION_TABLE_SCHEMA = {
"ID": "VARCHAR",
Expand All @@ -46,6 +50,7 @@
"MODEL_HASH": "VARCHAR",
"SEMANTIC_MODEL_STRING": "VARCHAR",
"EVAL_TABLE": "VARCHAR",
"EVAL_HASH": "VARCHAR",
}

LLM_JUDGE_PROMPT_TEMPLATE = """\
Expand Down Expand Up @@ -85,13 +90,27 @@ def visualize_eval_results(frame: pd.DataFrame) -> None:

col1, col2 = st.columns(2)

try:
analyst_sql = sqlglot.parse_one(row["ANALYST_SQL"], dialect="snowflake")
analyst_sql = analyst_sql.sql(dialect="snowflake", pretty=True)
except Exception as e:
logger.warning(f"Error parsing analyst SQL: {e} for {row_id}")
analyst_sql = row["ANALYST_SQL"]

try:
gold_sql = sqlglot.parse_one(row["GOLD_SQL"], dialect="snowflake")
gold_sql = gold_sql.sql(dialect="snowflake", pretty=True)
except Exception as e:
logger.warning(f"Error parsing gold SQL: {e} for {row_id}")
gold_sql = row["GOLD_SQL"]

with col1:
st.write("Analyst SQL")
st.code(row["ANALYST_SQL"], language="sql")
st.code(analyst_sql, language="sql")

with col2:
st.write("Golden SQL")
st.code(row["GOLD_SQL"], language="sql")
st.code(gold_sql, language="sql")

col1, col2 = st.columns(2)
with col1:
Expand Down Expand Up @@ -269,6 +288,7 @@ def result_comparisons() -> None:
def write_eval_results(frame: pd.DataFrame) -> None:
frame_to_write = frame.copy()
frame_to_write["TIMESTAMP"] = st.session_state["eval_timestamp"]
frame_to_write["EVAL_HASH"] = st.session_state["eval_hash"]
frame_to_write["EVAL_TABLE"] = st.session_state["eval_table"]
frame_to_write["EVAL_TABLE_HASH"] = st.session_state["eval_table_hash"]
frame_to_write["MODEL_HASH"] = st.session_state["semantic_model_hash"]
Expand Down Expand Up @@ -550,14 +570,12 @@ def evaluation_data_dialog() -> None:
st.session_state["results_eval_table"] = st.session_state[
"selected_results_eval_table"
]
# clear the results table if it exists
if "total_eval_frame" in st.session_state:
del st.session_state["total_eval_frame"]
clear_evaluation_data()

st.rerun()


def clear_evaluation_data() -> None:
def clear_evaluation_selection() -> None:
session_states = (
"selected_eval_database",
"selected_eval_schema",
Expand All @@ -574,9 +592,21 @@ def clear_evaluation_data() -> None:
del st.session_state[feature]


def clear_evaluation_data() -> None:
    """Remove stored evaluation results from the Streamlit session state.

    Each of the result keys listed below is deleted from
    ``st.session_state`` when it is present; keys that are absent are
    left alone, so calling this with nothing stored is a no-op.
    """
    result_keys = (
        "total_eval_frame",
        "eval_accuracy",
        "analyst_results_frame",
        "query_results_frame",
    )
    # Decide what to drop first, then delete, so the membership checks and
    # the deletions are kept as two separate passes.
    present_keys = [key for key in result_keys if key in st.session_state]
    for key in present_keys:
        del st.session_state[key]


def evaluation_mode_show() -> None:

if st.button("Select Evaluation Tables", on_click=clear_evaluation_data):
if st.button("Select Evaluation Tables", on_click=clear_evaluation_selection):
evaluation_data_dialog()

st.write(
Expand Down Expand Up @@ -609,13 +639,14 @@ def evaluation_mode_show() -> None:
[
["Evaluation Table Hash", st.session_state["eval_table_hash"]],
["Semantic Model Hash", st.session_state["semantic_model_hash"]],
["Evaluation Run Hash", st.session_state["eval_hash"]],
["Timestamp", st.session_state["eval_timestamp"]],
["Accuracy", f"{st.session_state['eval_accuracy']:.2f}%"],
],
columns=["Summary Statistic", "Value"],
)
if model_changed_test:
st.write("Model has changed since last evaluation run.")
st.warning("Model has changed since last evaluation run.")
st.markdown("#### Previous Evaluation Run Summary")
else:
st.markdown("#### Current Evaluation Run Summary")
Expand All @@ -630,19 +661,33 @@ def run_evaluation() -> None:
action="evaluation_run",
)
current_hash = generate_hash(st.session_state["working_yml"])
model_changed_test = ("semantic_model_hash" in st.session_state) and (
model_changed_test = ("semantic_model_hash" not in st.session_state) or (
current_hash != st.session_state["semantic_model_hash"]
)
if (
"validated" in st.session_state and not st.session_state["validated"]
) or model_changed_test:
st.error("Please validate your semantic model before evaluating.")
placeholder = st.empty()

if not model_changed_test and "total_eval_frame" in st.session_state:
placeholder.write("Model has not changed since last evaluation run.")
return
if "total_eval_frame" in st.session_state:
del st.session_state["total_eval_frame"]

if not st.session_state.validated or model_changed_test:
placeholder.write("Validating model...")
try:
# try loading the yaml
_ = yaml.safe_load(st.session_state["working_yml"])
# try validating the yaml using analyst
validate(st.session_state["working_yml"], get_snowflake_connection())
st.session_state.validated = True
update_last_validated_model()
except Exception as e:
placeholder.error(f"Could not validate model ❌ with error: {e}")
return
placeholder.write("Model validated ✅")
clear_evaluation_data()
st.session_state["semantic_model_hash"] = current_hash
st.write("Running evaluation...")
st.session_state["eval_timestamp"] = time.strftime("%Y-%m-%d %H:%M:%S")
st.session_state["eval_hash"] = generate_hash(st.session_state["eval_timestamp"])
send_analyst_requests()
run_sql_queries()
result_comparisons()
Expand Down

0 comments on commit 7389906

Please sign in to comment.