Skip to content

Commit

Permalink
update intrinsic.py, add violin plot
Browse files Browse the repository at this point in the history
  • Loading branch information
markus583 committed Jan 12, 2024
1 parent 1fd7730 commit e72c358
Show file tree
Hide file tree
Showing 2 changed files with 170 additions and 4 deletions.
52 changes: 48 additions & 4 deletions wtpsplit/evaluation/intrinsic.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,7 @@ class Args:
stride: int = 64
batch_size: int = 32
include_langs: List[str] = None
threshold: float = 0.01


def process_logits(text, model, lang_code, args):
Expand Down Expand Up @@ -70,10 +71,12 @@ def process_logits(text, model, lang_code, args):


def load_or_compute_logits(args, model, eval_data, valid_data=None, max_n_train_sentences=10_000):
logits_path = Constants.CACHE_DIR / (model.config.mixture_name + "_logits_u0001.h5")
logits_path = Constants.CACHE_DIR / (
f"{args.model_path.split('/')[0]}_b{args.block_size}+s{args.stride}_logits_u{args.threshold}.h5"
)

# TODO: revert to "a"
with h5py.File(logits_path, "w") as f, torch.no_grad():
with h5py.File(logits_path, "a") as f, torch.no_grad():
for lang_code in Constants.LANGINFO.index:
if args.include_langs is not None and lang_code not in args.include_langs:
continue
Expand Down Expand Up @@ -127,6 +130,24 @@ def load_or_compute_logits(args, model, eval_data, valid_data=None, max_n_train_
return h5py.File(logits_path, "r")


def compute_statistics(values):
    """Summarize a list of (score, lang_code) pairs.

    Parameters
    ----------
    values : list[tuple[float, str]]
        One (score, language code) pair per evaluated language/dataset.

    Returns
    -------
    dict
        Keys: mean, median, std, min, min_lang, max, max_lang.
        Every value is None when *values* is empty.
    """
    if not values:
        # No languages evaluated: keep the schema, null the values.
        keys = ("mean", "median", "std", "min", "min_lang", "max", "max_lang")
        return dict.fromkeys(keys)

    scores, langs = zip(*values)
    lo = np.argmin(scores)
    hi = np.argmax(scores)
    stats = {
        "mean": np.mean(scores),
        "median": np.median(scores),
        "std": np.std(scores),
    }
    # Record the extreme scores together with the language that produced them.
    stats.update(min=scores[lo], min_lang=langs[lo], max=scores[hi], max_lang=langs[hi])
    return stats


if __name__ == "__main__":
(args,) = HfArgumentParser([Args]).parse_args_into_dataclasses()

Expand All @@ -145,6 +166,8 @@ def load_or_compute_logits(args, model, eval_data, valid_data=None, max_n_train_
# now, compute the intrinsic scores.
results = {}
clfs = {}
# Initialize lists to store scores for each metric across all languages
u_scores, t_scores, punct_scores = [], [], []

for lang_code, dsets in tqdm(eval_data.items()):
if args.include_langs is not None and lang_code not in args.include_langs:
Expand All @@ -169,6 +192,7 @@ def load_or_compute_logits(args, model, eval_data, valid_data=None, max_n_train_
print(clf)
print(np.argsort(clf[0].coef_[0])[:10], "...", np.argsort(clf[0].coef_[0])[-10:])
print(np.where(np.argsort(clf[0].coef_[0]) == 0)[0])

score_t, score_punct, _ = evaluate_mixture(
lang_code,
f[lang_code][dataset_name]["test_logits"][:],
Expand All @@ -194,20 +218,40 @@ def load_or_compute_logits(args, model, eval_data, valid_data=None, max_n_train_
# just for printing
score_t = score_t or 0.0
score_punct = score_punct or 0.0

u_scores.append((score_u, lang_code))
t_scores.append((score_t, lang_code))
punct_scores.append((score_punct, lang_code))
print(f"{lang_code} {dataset_name} {score_u:.3f} {score_t:.3f} {score_punct:.3f}")

# Compute statistics for each metric across all languages
results_avg = {
"u": compute_statistics(u_scores),
"t": compute_statistics(t_scores),
"punct": compute_statistics(punct_scores),
"include_langs": args.include_langs,
}

sio.dump(
clfs,
open(
Constants.CACHE_DIR / (model.config.mixture_name + ".skops"),
Constants.CACHE_DIR / (f"{args.model_path.split('/')[0]}_b{args.block_size}+s{args.stride}.skops"),
"wb",
),
)
json.dump(
results,
open(
Constants.CACHE_DIR / (model.config.mixture_name + "_intrinsic_results_u0005.json"),
Constants.CACHE_DIR
/ (f"{args.model_path.split('/')[0]}_b{args.block_size}+s{args.stride}_intrinsic_results_u{args.threshold}.json"),
"w",
),
indent=4,
)

# Write results_avg to JSON
json.dump(
results_avg,
open(Constants.CACHE_DIR / (f"{args.model_path.split('/')[0]}_b{args.block_size}+s{args.stride}_u{args.threshold}_AVG.json"), "w"),
indent=4,
)
122 changes: 122 additions & 0 deletions wtpsplit/summary_plot.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,122 @@
import pandas as pd
import plotly.graph_objects as go
import json

# Intrinsic-evaluation result JSONs to compare; each file becomes one violin
# per metric group. Paths are relative to the working directory.
FILES = [
    ".cache/xlmr-normal-v2_b512+s_64_intrinsic_results_u0.01.json",
    ".cache/xlmr-normal-v2_b512+s64_intrinsic_results_u0.01.json",
    ".cache/xlm-tokenv2_intrinsic_results_u001.json",
]
# Basename for the generated <NAME>.html and <NAME>.png outputs.
NAME = "test"


def darken_color(color, factor):
    """Scale the RGB channels of an RGBA color by *factor*.

    Parameters
    ----------
    color : tuple
        (r, g, b, a) with integer RGB channels in [0, 255]; the alpha
        channel is passed through unchanged.
    factor : float
        Multiplier applied to each RGB channel; values < 1 darken.

    Returns
    -------
    tuple
        (r, g, b, a) with each RGB channel clamped to [0, 255].
    """
    r, g, b, a = color
    # Clamp at both ends so a factor > 1 cannot yield an invalid channel > 255
    # (the original only clamped at 0).
    r = min(max(int(r * factor), 0), 255)
    g = min(max(int(g * factor), 0), 255)
    b = min(max(int(b * factor), 0), 255)
    return (r, g, b, a)


def plot_violin_from_json(files, name):
    """Draw grouped violin plots of u/t/punct scores from result JSON files.

    Parameters
    ----------
    files : list[str]
        Paths to intrinsic-results JSONs shaped like
        ``{lang: {dataset: {"u": float, "t": float, "punct": float, ...}}}``
        (assumed from the FILES constants above — TODO confirm against the
        files produced by intrinsic.py).
    name : str
        Basename used for the saved ``<name>.html`` and ``<name>.png``.

    Side effects: shows the figure interactively and writes the two files.
    """

    def file_label(path):
        # Basename with only the FINAL extension stripped. The previous
        # split(".")[0] truncated at the first dot, mangling names that
        # contain dots such as "..._u0.01.json" -> "..._u0".
        return path.split("/")[-1].rsplit(".", 1)[0]

    # Prepare data
    data = {"score": [], "metric": [], "file": [], "x": []}
    spacing = 1.0  # Space between groups of metrics
    violin_width = 0.3  # Width of each violin

    # Base colors for each metric (RGBA)
    base_colors = {"u": (0, 123, 255, 0.6), "t": (40, 167, 69, 0.6), "punct": (255, 193, 7, 0.6)}
    color_darkening_factor = 0.8  # Factor to darken color for each subsequent file

    # Compute x positions and prepare colors for each file within each metric group
    x_positions = {}
    colors = {}
    for i, metric in enumerate(["u", "t", "punct"]):
        x_positions[metric] = {}
        colors[metric] = {}
        base_color = base_colors[metric]
        for j, file in enumerate(files):
            # Evenly space the files inside each metric's slot on the x axis.
            x_positions[metric][file] = i * spacing + spacing / (len(files) + 1) * (j + 1)
            # Later files get progressively darker shades of the metric color.
            colors[metric][file] = "rgba" + str(darken_color(base_color, color_darkening_factor**j))

    # Flatten every (lang, dataset) score into one long table.
    for file in files:
        with open(file, "r") as f:
            content = json.load(f)
        for lang, scores in content.items():
            for dataset, values in scores.items():
                for metric in ["u", "t", "punct"]:
                    data["score"].append(values[metric])
                    data["metric"].append(metric)
                    data["file"].append(file_label(file))  # legend label
                    data["x"].append(x_positions[metric][file])  # Use computed x position

    # Convert to DataFrame
    df = pd.DataFrame(data)

    # Create violin plots: one trace per (metric, file) pair.
    fig = go.Figure()
    for metric in ["u", "t", "punct"]:
        metric_df = df[df["metric"] == metric]
        for file in files:
            file_name = file_label(file)
            file_df = metric_df[metric_df["file"] == file_name]
            if not file_df.empty:
                fig.add_trace(
                    go.Violin(
                        y=file_df["score"],
                        x=file_df["x"],
                        name=file_name if metric == "u" else "",  # Only show legend for 'u' to avoid duplicates
                        legendgroup=file_name,
                        line_color=colors[metric][file],
                        box_visible=True,
                        meanline_visible=True,
                        width=violin_width,
                    )
                )

    # Update layout
    # Calculate the center positions for each metric group so the tick labels
    # sit in the middle of each group of violins.
    center_positions = [sum(values.values()) / len(values) for key, values in x_positions.items()]
    fig.update_layout(
        title="Violin Plots of Scores by Metric",
        xaxis=dict(title="Metric", tickvals=center_positions, ticktext=["u", "t", "punct"]),
        yaxis_title="Scores",
        violingap=0,
        violingroupgap=0,
        violinmode="overlay",
        paper_bgcolor="white",
        plot_bgcolor="white",
        font=dict(color="black", size=14),
        title_x=0.5,
        legend_title_text="File",
        xaxis_showgrid=False,
        yaxis_showgrid=True,
        xaxis_zeroline=False,
        yaxis_zeroline=False,
        margin=dict(l=40, r=40, t=40, b=40),
    )

    # Update axes lines
    fig.update_xaxes(showline=True, linewidth=1, linecolor="gray")
    fig.update_yaxes(showline=True, linewidth=1, linecolor="gray")

    # Simplify and move legend to the top
    fig.update_layout(legend=dict(orientation="h", yanchor="bottom", y=1.02, xanchor="right", x=1))

    fig.show()

    # Save the figure as HTML
    html_filename = f"{name}.html"
    fig.write_html(html_filename)
    print(f"Plot saved as {html_filename}")

    # Save the figure as PNG (requires a plotly image-export backend, e.g. kaleido)
    png_filename = f"{name}.png"
    fig.write_image(png_filename)
    print(f"Plot saved as {png_filename}")


plot_violin_from_json(FILES, NAME)

0 comments on commit e72c358

Please sign in to comment.