Skip to content

Commit

Permalink
Merge pull request #219 from Snowflake-Labs/tzayats/add-loading-eval-…
Browse files Browse the repository at this point in the history
…from-cache

[Customer Eval] Allow loading eval data from cache
  • Loading branch information
sfc-gh-tzayats authored Dec 10, 2024
2 parents 68c935f + c3b3709 commit c5547fc
Show file tree
Hide file tree
Showing 3 changed files with 76 additions and 39 deletions.
106 changes: 72 additions & 34 deletions journeys/evaluation.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,10 +13,10 @@
from app_utils.shared_utils import (
get_snowflake_connection,
schema_selector_container,
set_sit_query_tag,
table_selector_container,
validate_table_exist,
validate_table_schema,
set_sit_query_tag,
)
from semantic_model_generator.data_processing.proto_utils import proto_to_yaml
from semantic_model_generator.snowflake_utils.snowflake_connector import (
Expand Down Expand Up @@ -261,28 +261,32 @@ def result_comparisons() -> None:
status_text.text(
f"Analyst and Gold Results Compared ✅ (Time taken: {elapsed_time:.2f} seconds)"
)
# compute accuracy
st.session_state["eval_accuracy"] = (frame["CORRECT"].sum() / len(frame)) * 100
st.session_state["total_eval_frame"] = frame

visualize_eval_results(frame)

frame["TIMESTAMP"] = st.session_state["eval_timestamp"]
frame["EVAL_TABLE"] = st.session_state["eval_table"]
frame["EVAL_TABLE_HASH"] = st.session_state["eval_table_hash"]
frame["MODEL_HASH"] = st.session_state["semantic_model_hash"]
def write_eval_results(frame: pd.DataFrame) -> None:
frame_to_write = frame.copy()
frame_to_write["TIMESTAMP"] = st.session_state["eval_timestamp"]
frame_to_write["EVAL_TABLE"] = st.session_state["eval_table"]
frame_to_write["EVAL_TABLE_HASH"] = st.session_state["eval_table_hash"]
frame_to_write["MODEL_HASH"] = st.session_state["semantic_model_hash"]

# Save results to frame as string
frame["ANALYST_RESULT"] = frame["ANALYST_RESULT"].apply(
frame_to_write["ANALYST_RESULT"] = frame["ANALYST_RESULT"].apply(
lambda x: x.to_string(index=False) if isinstance(x, pd.DataFrame) else x
)
frame["GOLD_RESULT"] = frame["GOLD_RESULT"].apply(
frame_to_write["GOLD_RESULT"] = frame["GOLD_RESULT"].apply(
lambda x: x.to_string(index=False) if isinstance(x, pd.DataFrame) else x
)
frame["SEMANTIC_MODEL_STRING"] = st.session_state["working_yml"]
frame_to_write["SEMANTIC_MODEL_STRING"] = st.session_state["working_yml"]

frame = frame.reset_index()[list(RESULTS_TABLE_SCHEMA)]
frame_to_write = frame_to_write.reset_index()[list(RESULTS_TABLE_SCHEMA)]
write_pandas(
conn=get_snowflake_connection(),
df=frame,
table_name=st.session_state["selected_results_eval_table"],
df=frame_to_write,
table_name=st.session_state["results_eval_table"],
overwrite=False,
quote_identifiers=False,
auto_create_table=False,
Expand Down Expand Up @@ -413,6 +417,16 @@ def _get_content(
@st.experimental_dialog("Evaluation Tables", width="large")
def evaluation_data_dialog() -> None:
st.markdown("Please select an evaluation table.")
st.markdown("The evaluation table should have the following schema:")
eval_table_schema_explained = pd.DataFrame(
[
["ID", "VARCHAR", "Unique identifier for each row"],
["QUERY", "VARCHAR", "The query to be evaluated"],
["GOLD_SQL", "VARCHAR", "The expected SQL for the query"],
],
columns=["Column", "Type", "Description"],
)
st.dataframe(eval_table_schema_explained, hide_index=True)
table_selector_container(
db_selector={"key": "selected_eval_database", "label": "Evaluation database"},
schema_selector={"key": "selected_eval_schema", "label": "Evaluation schema"},
Expand Down Expand Up @@ -536,6 +550,9 @@ def evaluation_data_dialog() -> None:
st.session_state["results_eval_table"] = st.session_state[
"selected_results_eval_table"
]
# clear the results table if it exists
if "total_eval_frame" in st.session_state:
del st.session_state["total_eval_frame"]

st.rerun()

Expand All @@ -559,9 +576,13 @@ def clear_evaluation_data() -> None:

def evaluation_mode_show() -> None:

if st.button("Set Evaluation Tables", on_click=clear_evaluation_data):
if st.button("Select Evaluation Tables", on_click=clear_evaluation_data):
evaluation_data_dialog()

st.write(
"Welcome!🧪 In the evaluation mode you can evaluate your semantic model using pairs of golden queries/questions and their expected SQL statements. These pairs should be captured in an **Evaluation Table**. Accuracy metrics will be shown and the results will be stored in an **Evaluation Results Table**."
)

# TODO: find a less awkward way of specifying this.
if any(key not in st.session_state for key in ("eval_table", "results_eval_table")):
st.error("Please select evaluation tables.")
Expand All @@ -577,39 +598,56 @@ def evaluation_mode_show() -> None:
)
st.markdown("#### Evaluation Data Summary")
st.dataframe(summary_stats, hide_index=True)

if st.button("Run Evaluation"):
set_sit_query_tag(
get_snowflake_connection(),
vendor="",
action="evaluation_run",
)
run_evaluation()

if "total_eval_frame" in st.session_state:
current_hash = generate_hash(st.session_state["working_yml"])
model_changed_test = ("semantic_model_hash" in st.session_state) and (
current_hash != st.session_state["semantic_model_hash"]
)
if (
"validated" in st.session_state and not st.session_state["validated"]
) or model_changed_test:
st.error("Please validate your semantic model before evaluating.")
return
st.session_state["semantic_model_hash"] = current_hash
st.write("Running evaluation...")
st.session_state["eval_timestamp"] = time.strftime("%Y-%m-%d %H:%M:%S")
model_changed_test = current_hash != st.session_state["semantic_model_hash"]

evolution_run_summary = pd.DataFrame(
[
["Evaluation Table Hash", st.session_state["eval_table_hash"]],
["Semantic Model Hash", st.session_state["semantic_model_hash"]],
["Timestamp", st.session_state["eval_timestamp"]],
["Accuracy", f"{st.session_state['eval_accuracy']:.2f}%"],
],
columns=["Summary Statistic", "Value"],
)
st.markdown("#### Evaluation Run Summary")
if model_changed_test:
st.write("Model has changed since last evaluation run.")
st.markdown("#### Previous Evaluation Run Summary")
else:
st.markdown("#### Current Evaluation Run Summary")
st.dataframe(evolution_run_summary, hide_index=True)
visualize_eval_results(st.session_state["total_eval_frame"])

send_analyst_requests()
run_sql_queries()
result_comparisons()

def run_evaluation() -> None:
set_sit_query_tag(
get_snowflake_connection(),
vendor="",
action="evaluation_run",
)
current_hash = generate_hash(st.session_state["working_yml"])
model_changed_test = ("semantic_model_hash" in st.session_state) and (
current_hash != st.session_state["semantic_model_hash"]
)
if (
"validated" in st.session_state and not st.session_state["validated"]
) or model_changed_test:
st.error("Please validate your semantic model before evaluating.")
return
if "total_eval_frame" in st.session_state:
del st.session_state["total_eval_frame"]
st.session_state["semantic_model_hash"] = current_hash
st.write("Running evaluation...")
st.session_state["eval_timestamp"] = time.strftime("%Y-%m-%d %H:%M:%S")
send_analyst_requests()
run_sql_queries()
result_comparisons()
write_eval_results(st.session_state["total_eval_frame"])
st.write("Evaluation complete ✅")


@st.cache_resource(show_spinner=False)
Expand Down
8 changes: 4 additions & 4 deletions journeys/iteration.py
Original file line number Diff line number Diff line change
Expand Up @@ -677,10 +677,10 @@ def show() -> None:
return_home_button()
with mode:
st.session_state["app_mode"] = st.selectbox(
label="App Mode",
label_visibility="collapsed",
options=["Chat", "Evaluation", "Preview YAML"],
)
label="App Mode",
label_visibility="collapsed",
options=["Chat", "Evaluation", "Preview YAML"],
)
if "yaml" not in st.session_state:
# Only proceed to download the YAML from stage if we don't have one from the builder flow.
yaml = download_yaml(
Expand Down
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
import concurrent.futures
from collections import defaultdict
from contextlib import contextmanager
from textwrap import dedent
from typing import Any, Dict, Generator, List, Optional, TypeVar, Union

import pandas as pd
Expand Down

0 comments on commit c5547fc

Please sign in to comment.