From a9f7fd98e9810bd2337be0c806e6651bf4bd23b5 Mon Sep 17 00:00:00 2001 From: markus583 Date: Fri, 3 May 2024 14:42:16 +0000 Subject: [PATCH] fix --- wtpsplit/evaluation/intrinsic_pairwise.py | 15 +++++++++------ 1 file changed, 9 insertions(+), 6 deletions(-) diff --git a/wtpsplit/evaluation/intrinsic_pairwise.py b/wtpsplit/evaluation/intrinsic_pairwise.py index cfe9278d..0f098e01 100644 --- a/wtpsplit/evaluation/intrinsic_pairwise.py +++ b/wtpsplit/evaluation/intrinsic_pairwise.py @@ -61,6 +61,7 @@ class Args: keep_logits: bool = True skip_corrupted: bool = True skip_punct: bool = True + return_indices: bool = False # k_mer-specific args k: int = 2 @@ -461,16 +462,17 @@ def main(args): # evaluate each pair for i, k_mer in enumerate(sent_k_mers): start, end = f[lang_code][dataset_name]["test_logit_lengths"][i] - single_score_t, single_score_punct, info = evaluate_mixture( + single_score_t, single_score_punct, info, t_indices, punct_indices = evaluate_mixture( lang_code, f[lang_code][dataset_name]["test_logits"][:][start:end], list(k_mer), + args.return_indices, *clf, ) score_t.append(single_score_t) score_punct.append(single_score_punct) - acc_t.append(info["info_newline"]["correct_pairwise"]) - acc_punct.append(info["info_transformed"]["correct_pairwise"]) + acc_t.append(info["info_newline"]["correct_pairwise"] if info["info_newline"] else None) + acc_punct.append(info["info_transformed"]["correct_pairwise"] if info["info_transformed"] else None) clfs[lang_code][dataset_name] = clf @@ -501,18 +503,19 @@ def main(args): thresholds.append(threshold_adjusted) else: thresholds.append(args.threshold) - single_score_u, _, info = evaluate_mixture( + single_score_u, _, info, u_indices, _ = evaluate_mixture( lang_code, f[lang_code][dataset_name]["test_logits"][:][start:end], list(k_mer), + args.return_indices, *clf, ) score_u.append(single_score_u) acc_u.append(info["info_newline"]["correct_pairwise"]) score_u = np.mean(score_u) - score_t = np.mean(score_t) if score_t else None - score_punct = np.mean(score_punct) if score_punct else None + score_t = np.mean(score_t) if score_t and not args.skip_adaptation else None + score_punct = np.mean(score_punct) if score_punct and not (args.skip_punct or args.skip_adaptation) else None acc_u = np.mean(acc_u) acc_t = np.mean(acc_t) if score_t else None acc_punct = np.mean(acc_punct) if score_punct else None