diff --git a/wtpsplit/evaluation/intrinsic.py b/wtpsplit/evaluation/intrinsic.py
index 8f0760a0..d10c2908 100644
--- a/wtpsplit/evaluation/intrinsic.py
+++ b/wtpsplit/evaluation/intrinsic.py
@@ -40,6 +40,7 @@ class Args:
     stride: int = 64
     batch_size: int = 32
     include_langs: List[str] = None
+    threshold: float = 0.01
 
 
 def process_logits(text, model, lang_code, args):
@@ -70,10 +71,12 @@ def process_logits(text, model, lang_code, args):
 
 
 def load_or_compute_logits(args, model, eval_data, valid_data=None, max_n_train_sentences=10_000):
-    logits_path = Constants.CACHE_DIR / (model.config.mixture_name + "_logits_u0001.h5")
+    logits_path = Constants.CACHE_DIR / (
+        f"{args.model_path.split('/')[0]}_b{args.block_size}+s{args.stride}_logits_u{args.threshold}.h5"
+    )
 
     # TODO: revert to "a"
-    with h5py.File(logits_path, "w") as f, torch.no_grad():
+    with h5py.File(logits_path, "a") as f, torch.no_grad():
         for lang_code in Constants.LANGINFO.index:
             if args.include_langs is not None and lang_code not in args.include_langs:
                 continue
@@ -127,6 +130,24 @@ def load_or_compute_logits(args, model, eval_data, valid_data=None, max_n_train_
     return h5py.File(logits_path, "r")
 
 
+def compute_statistics(values):
+    if not values:  # Check for empty values list
+        return {"mean": None, "median": None, "std": None, "min": None, "min_lang": None, "max": None, "max_lang": None}
+
+    scores, langs = zip(*values)  # Unpack scores and languages
+    min_index = np.argmin(scores)
+    max_index = np.argmax(scores)
+    return {
+        "mean": np.mean(scores),
+        "median": np.median(scores),
+        "std": np.std(scores),
+        "min": scores[min_index],
+        "min_lang": langs[min_index],
+        "max": scores[max_index],
+        "max_lang": langs[max_index]
+    }
+
+
 if __name__ == "__main__":
     (args,) = HfArgumentParser([Args]).parse_args_into_dataclasses()
 
@@ -145,6 +166,8 @@ def load_or_compute_logits(args, model, eval_data, valid_data=None, max_n_train_
     # now, compute the intrinsic scores.
     results = {}
     clfs = {}
+    # Initialize lists to store scores for each metric across all languages
+    u_scores, t_scores, punct_scores = [], [], []
     for lang_code, dsets in tqdm(eval_data.items()):
         if args.include_langs is not None and lang_code not in args.include_langs:
             continue
@@ -169,6 +192,7 @@ def load_or_compute_logits(args, model, eval_data, valid_data=None, max_n_train_
             print(clf)
             print(np.argsort(clf[0].coef_[0])[:10], "...", np.argsort(clf[0].coef_[0])[-10:])
             print(np.where(np.argsort(clf[0].coef_[0]) == 0)[0])
+
            score_t, score_punct, _ = evaluate_mixture(
                lang_code,
                f[lang_code][dataset_name]["test_logits"][:],
@@ -194,20 +218,40 @@
            # just for printing
            score_t = score_t or 0.0
            score_punct = score_punct or 0.0
+
+            u_scores.append((score_u, lang_code))
+            t_scores.append((score_t, lang_code))
+            punct_scores.append((score_punct, lang_code))
            print(f"{lang_code} {dataset_name} {score_u:.3f} {score_t:.3f} {score_punct:.3f}")
 
+    # Compute statistics for each metric across all languages
+    results_avg = {
+        "u": compute_statistics(u_scores),
+        "t": compute_statistics(t_scores),
+        "punct": compute_statistics(punct_scores),
+        "include_langs": args.include_langs,
+    }
+
     sio.dump(
         clfs,
         open(
-            Constants.CACHE_DIR / (model.config.mixture_name + ".skops"),
+            Constants.CACHE_DIR / (f"{args.model_path.split('/')[0]}_b{args.block_size}+s{args.stride}.skops"),
             "wb",
         ),
     )
     json.dump(
         results,
         open(
-            Constants.CACHE_DIR / (model.config.mixture_name + "_intrinsic_results_u0005.json"),
+            Constants.CACHE_DIR
+            / (f"{args.model_path.split('/')[0]}_b{args.block_size}+s{args.stride}_intrinsic_results_u{args.threshold}.json"),
             "w",
         ),
         indent=4,
     )
+
+    # Write results_avg to JSON
+    json.dump(
+        results_avg,
+        open(Constants.CACHE_DIR / (f"{args.model_path.split('/')[0]}_b{args.block_size}+s{args.stride}_u{args.threshold}_AVG.json"), "w"),
+        indent=4,
+    )
diff --git a/wtpsplit/summary_plot.py b/wtpsplit/summary_plot.py
new file mode 100644
index 00000000..f8f3438a
--- /dev/null
+++ b/wtpsplit/summary_plot.py
@@ -0,0 +1,122 @@
+import pandas as pd
+import plotly.graph_objects as go
+import json
+
+FILES = [
+    ".cache/xlmr-normal-v2_b512+s_64_intrinsic_results_u0.01.json",
+    ".cache/xlmr-normal-v2_b512+s64_intrinsic_results_u0.01.json",
+    ".cache/xlm-tokenv2_intrinsic_results_u001.json",
+]
+NAME = "test"
+
+
+def darken_color(color, factor):
+    """Darken a given RGB color by a specified factor."""
+    r, g, b, a = color
+    r = max(int(r * factor), 0)
+    g = max(int(g * factor), 0)
+    b = max(int(b * factor), 0)
+    return (r, g, b, a)
+
+
+def plot_violin_from_json(files, name):
+    # Prepare data
+    data = {"score": [], "metric": [], "file": [], "x": []}
+    spacing = 1.0  # Space between groups of metrics
+    violin_width = 0.3  # Width of each violin
+
+    # Base colors for each metric
+    base_colors = {"u": (0, 123, 255, 0.6), "t": (40, 167, 69, 0.6), "punct": (255, 193, 7, 0.6)}
+    color_darkening_factor = 0.8  # Factor to darken color for each subsequent file
+
+    # Compute x positions and prepare colors for each file within each metric group
+    x_positions = {}
+    colors = {}
+    for i, metric in enumerate(["u", "t", "punct"]):
+        x_positions[metric] = {}
+        colors[metric] = {}
+        base_color = base_colors[metric]
+        for j, file in enumerate(files):
+            x_positions[metric][file] = i * spacing + spacing / (len(files) + 1) * (j + 1)
+            colors[metric][file] = "rgba" + str(darken_color(base_color, color_darkening_factor**j))
+
+    for file in files:
+        with open(file, "r") as f:
+            content = json.load(f)
+        for lang, scores in content.items():
+            for dataset, values in scores.items():
+                for metric in ["u", "t", "punct"]:
+                    data["score"].append(values[metric])
+                    data["metric"].append(metric)
+                    data["file"].append(
+                        file.split("/")[-1].split(".")[0]
+                    )  # Use file base name without extension for legend
+                    data["x"].append(x_positions[metric][file])  # Use computed x position
+
+    # Convert to DataFrame
+    df = pd.DataFrame(data)
+
+    # Create violin plots
+    fig = go.Figure()
+    for metric in ["u", "t", "punct"]:
+        metric_df = df[df["metric"] == metric]
+        for file in files:
+            file_name = file.split("/")[-1].split(".")[0]
+            file_df = metric_df[metric_df["file"] == file_name]
+            if not file_df.empty:
+                fig.add_trace(
+                    go.Violin(
+                        y=file_df["score"],
+                        x=file_df["x"],
+                        name=file_name if metric == "u" else "",  # Only show legend for 'u' to avoid duplicates
+                        legendgroup=file_name,
+                        line_color=colors[metric][file],
+                        box_visible=True,
+                        meanline_visible=True,
+                        width=violin_width,
+                    )
+                )
+
+    # Update layout
+    # Calculate the center positions for each metric group
+    center_positions = [sum(values.values()) / len(values) for key, values in x_positions.items()]
+    fig.update_layout(
+        title="Violin Plots of Scores by Metric",
+        xaxis=dict(title="Metric", tickvals=center_positions, ticktext=["u", "t", "punct"]),
+        yaxis_title="Scores",
+        violingap=0,
+        violingroupgap=0,
+        violinmode="overlay",
+        paper_bgcolor="white",
+        plot_bgcolor="white",
+        font=dict(color="black", size=14),
+        title_x=0.5,
+        legend_title_text="File",
+        xaxis_showgrid=False,
+        yaxis_showgrid=True,
+        xaxis_zeroline=False,
+        yaxis_zeroline=False,
+        margin=dict(l=40, r=40, t=40, b=40),
+    )
+
+    # Update axes lines
+    fig.update_xaxes(showline=True, linewidth=1, linecolor="gray")
+    fig.update_yaxes(showline=True, linewidth=1, linecolor="gray")
+
+    # Simplify and move legend to the top
+    fig.update_layout(legend=dict(orientation="h", yanchor="bottom", y=1.02, xanchor="right", x=1))
+
+    fig.show()
+
+    # Save the figure as HTML
+    html_filename = f"{name}.html"
+    fig.write_html(html_filename)
+    print(f"Plot saved as {html_filename}")
+
+    # Save the figure as PNG
+    png_filename = f"{name}.png"
+    fig.write_image(png_filename)
+    print(f"Plot saved as {png_filename}")
+
+
+plot_violin_from_json(FILES, NAME)