From aca7067e3dd2d4c11ddbf4e7f5180e8e02a69f33 Mon Sep 17 00:00:00 2001
From: Ovler
Date: Mon, 2 Dec 2024 03:49:01 -0500
Subject: [PATCH] feat: update transformation button and enhance plotting with
 transformed data handling

---
 src/ia_collection_analyzer/streamlit.py | 27 ++++++++++++++++---------
 1 file changed, 17 insertions(+), 10 deletions(-)

diff --git a/src/ia_collection_analyzer/streamlit.py b/src/ia_collection_analyzer/streamlit.py
index 66499f13..38dbb328 100644
--- a/src/ia_collection_analyzer/streamlit.py
+++ b/src/ia_collection_analyzer/streamlit.py
@@ -359,7 +359,7 @@ def transform_data():
     elif transform_type == "Numeric Bins":
         num_bins = st.number_input("Number of bins:", min_value=2, value=5)
 
-    if st.button("Preview Transformation"):
+    if st.button("Preview and Apply"):
         if transform_type == "Date Quarter":
             new_col = pd.to_datetime(filtered_pd[source_col]).dt.quarter
         elif transform_type == "Date Week":
@@ -435,7 +435,6 @@ def safe_map(x):
         st.write("Preview showing examples of each mapping:")
         st.write(preview_df.T)
 
-    if st.button("Apply Transformation"):
         st.session_state.transformed_data = {
             "source_col": source_col,
             "transform_type": transform_type,
@@ -447,7 +446,6 @@ def safe_map(x):
         )
 
         st.session_state.original_values[source_col] = preview_df["Original"]
-        st.rerun()
 
 
 @st.fragment
@@ -471,11 +469,20 @@ def plot_data():
     st.write("Plotting the data...")
     st.write(f"X-axis: {x_axis}, Y-axis: {y_axis}")
 
-    # if y_axis is hashable , plot
-    if isinstance(filtered_pd[y_axis].iloc[0], (int, float, np.int64, np.float64)):
-        # Create comprehensive aggregation table
+    # Create a working copy of the dataframe
+    plot_df = filtered_pd.copy()
+
+    # Replace data with transformed versions if available
+    for axis, col_name in [("x", x_axis), ("y", y_axis)]:
+        if col_name in st.session_state.transformed_columns:
+            if st.session_state.transformed_data["source_col"] == col_name:
+                plot_df[col_name] = st.session_state.transformed_data["new_col"]
+                st.write(f"Using transformed data for {axis}-axis")
+
+    # Continue with plotting logic using plot_df instead of filtered_pd
+    if isinstance(plot_df[y_axis].iloc[0], (int, float, np.int64, np.float64)):
         all_metrics = (
-            filtered_pd.groupby(x_axis)[y_axis]
+            plot_df.groupby(x_axis)[y_axis]
             .agg(
                 [
                     ("Count", "count"),
@@ -505,14 +512,14 @@
         st.write("Analyzing distribution across categories...")
 
         # Create mask for list and non-list values
-        is_list_mask = filtered_pd[y_axis].apply(lambda x: isinstance(x, list))
+        is_list_mask = plot_df[y_axis].apply(lambda x: isinstance(x, list))
 
         # Handle list values
-        list_data = filtered_pd[is_list_mask][[x_axis, y_axis]].copy()
+        list_data = plot_df[is_list_mask][[x_axis, y_axis]].copy()
         exploded_list = list_data.explode(y_axis)
 
         # Handle non-list values
-        non_list_data = filtered_pd[~is_list_mask][[x_axis, y_axis]]
+        non_list_data = plot_df[~is_list_mask][[x_axis, y_axis]]
 
         # Combine results efficiently
         expanded_df = pd.concat([exploded_list, non_list_data])
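
Note (not part of the patch itself): a minimal standalone sketch of the column-substitution step this diff adds to plot_data(), runnable outside Streamlit. The sample dataframe, the plain dict standing in for st.session_state, and the print calls are assumptions made for illustration only; the session-state keys mirror the ones used in the diff.

import pandas as pd

# Stand-in for st.session_state after a "Preview and Apply" transformation
# (hypothetical example values, assumed for illustration).
session_state = {
    "transformed_columns": {"year"},
    "transformed_data": {
        "source_col": "year",
        "new_col": pd.Series([1, 2, 3]),  # e.g. a "Date Quarter"-style result
    },
}

filtered_pd = pd.DataFrame({"year": [2001, 2002, 2003], "downloads": [10, 20, 30]})
x_axis, y_axis = "year", "downloads"

# Work on a copy so the filtered dataframe itself stays untouched.
plot_df = filtered_pd.copy()

# Swap in the transformed series for whichever axis was transformed,
# mirroring the loop the patch adds to plot_data().
for axis, col_name in [("x", x_axis), ("y", y_axis)]:
    if col_name in session_state["transformed_columns"]:
        if session_state["transformed_data"]["source_col"] == col_name:
            plot_df[col_name] = session_state["transformed_data"]["new_col"]
            print(f"Using transformed data for {axis}-axis")

print(plot_df)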