diff --git a/wtpsplit/evaluation/intrinsic_pairwise.py b/wtpsplit/evaluation/intrinsic_pairwise.py index bf3a7fba..ed326157 100644 --- a/wtpsplit/evaluation/intrinsic_pairwise.py +++ b/wtpsplit/evaluation/intrinsic_pairwise.py @@ -242,7 +242,7 @@ def load_or_compute_logits(args, model, eval_data, valid_data=None, save_str: st total_test_time = 0 # Initialize total test processing time start_time = time.time() - with h5py.File(logits_path, "w") as f, torch.no_grad(): + with h5py.File(logits_path, "a") as f, torch.no_grad(): for lang_code in Constants.LANGINFO.index: if args.include_langs is not None and lang_code not in args.include_langs: continue @@ -525,6 +525,9 @@ def main(args): ) score_u.append(single_score_u) acc_u.append(info["info_newline"]["correct_pairwise"]) + u_indices.append(cur_u_indices["pred_indices"] if cur_u_indices["pred_indices"] else []) + true_indices.append(cur_u_indices["true_indices"] if cur_u_indices["true_indices"] else []) + length.append(cur_u_indices["length"]) score_u = np.mean(score_u) score_t = np.mean(score_t) if score_t and not args.skip_adaptation else None @@ -535,9 +538,7 @@ def main(args): acc_t = np.mean(acc_t) if score_t else None acc_punct = np.mean(acc_punct) if score_punct else None threshold = np.mean(thresholds) - u_indices.append(cur_u_indices["pred_indices"] if cur_u_indices["pred_indices"] else []) - true_indices.append(cur_u_indices["true_indices"] if cur_u_indices["true_indices"] else []) - length.append(cur_u_indices["length"]) + results[lang_code][dataset_name] = { "u": score_u,