Commit
feat: update transformation button and enhance plotting with transformed data handling

Ovler-Young committed Dec 2, 2024
1 parent 705787d commit aca7067
Showing 1 changed file with 17 additions and 10 deletions.
27 changes: 17 additions & 10 deletions src/ia_collection_analyzer/streamlit.py
@@ -359,7 +359,7 @@ def transform_data():
elif transform_type == "Numeric Bins":
num_bins = st.number_input("Number of bins:", min_value=2, value=5)

if st.button("Preview Transformation"):
if st.button("Preview and Apply"):
if transform_type == "Date Quarter":
new_col = pd.to_datetime(filtered_pd[source_col]).dt.quarter
elif transform_type == "Date Week":
@@ -435,7 +435,6 @@ def safe_map(x):
st.write("Preview showing examples of each mapping:")
st.write(preview_df.T)

if st.button("Apply Transformation"):
st.session_state.transformed_data = {
"source_col": source_col,
"transform_type": transform_type,
@@ -447,7 +446,6 @@
)
st.session_state.original_values[source_col] = preview_df["Original"]

-st.rerun()


@st.fragment
@@ -471,11 +469,20 @@ def plot_data():
st.write("Plotting the data...")
st.write(f"X-axis: {x_axis}, Y-axis: {y_axis}")

-# if y_axis is hashable , plot
-if isinstance(filtered_pd[y_axis].iloc[0], (int, float, np.int64, np.float64)):
-    # Create comprehensive aggregation table
+# Create a working copy of the dataframe
+plot_df = filtered_pd.copy()
+
+# Replace data with transformed versions if available
+for axis, col_name in [("x", x_axis), ("y", y_axis)]:
+    if col_name in st.session_state.transformed_columns:
+        if st.session_state.transformed_data["source_col"] == col_name:
+            plot_df[col_name] = st.session_state.transformed_data["new_col"]
+            st.write(f"Using transformed data for {axis}-axis")
+
+# Continue with plotting logic using plot_df instead of filtered_pd
+if isinstance(plot_df[y_axis].iloc[0], (int, float, np.int64, np.float64)):
all_metrics = (
-filtered_pd.groupby(x_axis)[y_axis]
+plot_df.groupby(x_axis)[y_axis]
.agg(
[
("Count", "count"),
@@ -505,14 +512,14 @@
st.write("Analyzing distribution across categories...")

# Create mask for list and non-list values
-is_list_mask = filtered_pd[y_axis].apply(lambda x: isinstance(x, list))
+is_list_mask = plot_df[y_axis].apply(lambda x: isinstance(x, list))

# Handle list values
-list_data = filtered_pd[is_list_mask][[x_axis, y_axis]].copy()
+list_data = plot_df[is_list_mask][[x_axis, y_axis]].copy()
exploded_list = list_data.explode(y_axis)

# Handle non-list values
-non_list_data = filtered_pd[~is_list_mask][[x_axis, y_axis]]
+non_list_data = plot_df[~is_list_mask][[x_axis, y_axis]]

# Combine results efficiently
expanded_df = pd.concat([exploded_list, non_list_data])
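
A minimal sketch of the consolidated button flow from the first three hunks, using stand-in data and simplified names rather than the repo's actual variables: a single "Preview and Apply" button now shows the preview and stores the transformation in st.session_state in the same click, which is consistent with dropping the nested "Apply Transformation" button and the st.rerun() call.

# Sketch only: stand-in data and simplified session keys, not streamlit.py itself.
import pandas as pd
import streamlit as st

filtered_pd = pd.DataFrame({"date": ["2024-01-05", "2024-06-20", "2024-11-30"]})
source_col, transform_type = "date", "Date Quarter"

if st.button("Preview and Apply"):
    if transform_type == "Date Quarter":
        new_col = pd.to_datetime(filtered_pd[source_col]).dt.quarter

    preview_df = pd.DataFrame(
        {"Original": filtered_pd[source_col], "Transformed": new_col}
    )
    st.write("Preview showing examples of each mapping:")
    st.write(preview_df.T)

    # Apply in the same click instead of waiting for a second button press.
    st.session_state["transformed_data"] = {
        "source_col": source_col,
        "transform_type": transform_type,
        "new_col": new_col,
    }
    # Simplified bookkeeping; the app also tracks which columns were transformed.
    st.session_state["transformed_columns"] = [source_col]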
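And a companion sketch of the plot-side handling added in the last two hunks: transformed columns recorded in st.session_state are swapped into a working copy (plot_df) right before aggregation, so filtered_pd itself is never modified. The dataframe, column names, and the final chart call here are stand-ins, not the app's own.

# Sketch only: stand-in dataframe and session contents, not streamlit.py itself.
import pandas as pd
import streamlit as st

filtered_pd = pd.DataFrame(
    {"date": ["2024-01-05", "2024-06-20", "2024-11-30"], "downloads": [10, 25, 7]}
)
x_axis, y_axis = "date", "downloads"

# Pretend a "Date Quarter" transformation of the x column was applied earlier.
if "transformed_data" not in st.session_state:
    st.session_state["transformed_columns"] = ["date"]
    st.session_state["transformed_data"] = {
        "source_col": "date",
        "transform_type": "Date Quarter",
        "new_col": pd.to_datetime(filtered_pd["date"]).dt.quarter,
    }

# Work on a copy so the filtered dataframe is left untouched.
plot_df = filtered_pd.copy()

# Swap in the transformed series for whichever axis it belongs to.
for axis, col_name in [("x", x_axis), ("y", y_axis)]:
    if col_name in st.session_state.transformed_columns:
        if st.session_state.transformed_data["source_col"] == col_name:
            plot_df[col_name] = st.session_state.transformed_data["new_col"]
            st.write(f"Using transformed data for {axis}-axis")

# Downstream aggregation and plotting then read from plot_df, not filtered_pd.
st.bar_chart(plot_df.groupby(x_axis)[y_axis].sum())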
