Skip to content

Commit

Permalink
fix
Browse files Browse the repository at this point in the history
  • Loading branch information
markus583 committed May 3, 2024
1 parent 14f4283 commit a9f7fd9
Showing 1 changed file with 9 additions and 6 deletions.
15 changes: 9 additions & 6 deletions wtpsplit/evaluation/intrinsic_pairwise.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,7 @@ class Args:
keep_logits: bool = True
skip_corrupted: bool = True
skip_punct: bool = True
return_indices: bool = False

# k_mer-specific args
k: int = 2
Expand Down Expand Up @@ -461,16 +462,17 @@ def main(args):
# evaluate each pair
for i, k_mer in enumerate(sent_k_mers):
start, end = f[lang_code][dataset_name]["test_logit_lengths"][i]
single_score_t, single_score_punct, info = evaluate_mixture(
single_score_t, single_score_punct, info, t_indices, punct_indices = evaluate_mixture(
lang_code,
f[lang_code][dataset_name]["test_logits"][:][start:end],
list(k_mer),
args.return_indices,
*clf,
)
score_t.append(single_score_t)
score_punct.append(single_score_punct)
acc_t.append(info["info_newline"]["correct_pairwise"])
acc_punct.append(info["info_transformed"]["correct_pairwise"])
acc_t.append(info["info_newline"]["correct_pairwise"] if info["info_newline"] else None)
acc_punct.append(info["info_transformed"]["correct_pairwise"] if info["info_transformed"] else None)

clfs[lang_code][dataset_name] = clf

Expand Down Expand Up @@ -501,18 +503,19 @@ def main(args):
thresholds.append(threshold_adjusted)
else:
thresholds.append(args.threshold)
single_score_u, _, info = evaluate_mixture(
single_score_u, _, info, u_indices, _ = evaluate_mixture(
lang_code,
f[lang_code][dataset_name]["test_logits"][:][start:end],
list(k_mer),
args.return_indices,
*clf,
)
score_u.append(single_score_u)
acc_u.append(info["info_newline"]["correct_pairwise"])

score_u = np.mean(score_u)
score_t = np.mean(score_t) if score_t else None
score_punct = np.mean(score_punct) if score_punct else None
score_t = np.mean(score_t) if score_t and not args.skip_adaptation else None
score_punct = np.mean(score_punct) if score_punct and not (args.skip_punct or args.skip_adaptation) else None
acc_u = np.mean(acc_u)
acc_t = np.mean(acc_t) if score_t else None
acc_punct = np.mean(acc_punct) if score_punct else None
Expand Down

0 comments on commit a9f7fd9

Please sign in to comment.