From 98443b7fa94e0aeafbdf9c0335854742b3def97e Mon Sep 17 00:00:00 2001 From: markus583 Date: Sun, 19 May 2024 16:08:31 +0000 Subject: [PATCH] finally fix indices --- wtpsplit/evaluation/intrinsic_pairwise.py | 17 ++++++++++++----- 1 file changed, 12 insertions(+), 5 deletions(-) diff --git a/wtpsplit/evaluation/intrinsic_pairwise.py b/wtpsplit/evaluation/intrinsic_pairwise.py index 457f5033..bf3a7fba 100644 --- a/wtpsplit/evaluation/intrinsic_pairwise.py +++ b/wtpsplit/evaluation/intrinsic_pairwise.py @@ -19,6 +19,7 @@ import wtpsplit.models # noqa: F401 from wtpsplit.evaluation import evaluate_mixture, get_labels, train_mixture, token_to_char_probs +from wtpsplit.evaluation.intrinsic_baselines import split_language_data from wtpsplit.extract import PyTorchWrapper from wtpsplit.extract_batched import extract_batched from wtpsplit.utils import Constants @@ -241,11 +242,10 @@ def load_or_compute_logits(args, model, eval_data, valid_data=None, save_str: st total_test_time = 0 # Initialize total test processing time start_time = time.time() - with h5py.File(logits_path, "a") as f, torch.no_grad(): + with h5py.File(logits_path, "w") as f, torch.no_grad(): for lang_code in Constants.LANGINFO.index: if args.include_langs is not None and lang_code not in args.include_langs: continue - print(f"Processing {lang_code}...") if lang_code not in f: lang_group = f.create_group(lang_code) @@ -254,8 +254,11 @@ def load_or_compute_logits(args, model, eval_data, valid_data=None, save_str: st # eval data for dataset_name, dataset in eval_data[lang_code]["sentence"].items(): - if args.skip_corrupted and "corrupted" in dataset_name and"ted2020" not in dataset_name: + if args.skip_corrupted and "corrupted" in dataset_name and "ted2020" not in dataset_name: continue + if "-" in lang_code and "canine" in args.model_path and "no-adapters" not in args.model_path: + # code-switched data: eval 2x + lang_code = lang_code.split("_")[1].lower() try: if args.adapter_path: model.model.load_adapter( @@ -377,6 +380,8 @@ def main(args): print(save_str) eval_data = torch.load(args.eval_data_path) + if "canine" in args.model_path and not "no-adapters" in args.model_path: + eval_data = split_language_data(eval_data) if args.valid_text_path is not None: valid_data = load_dataset("parquet", data_files=args.valid_text_path, split="train") else: @@ -530,7 +535,9 @@ def main(args): acc_t = np.mean(acc_t) if score_t else None acc_punct = np.mean(acc_punct) if score_punct else None threshold = np.mean(thresholds) - + u_indices.append(cur_u_indices["pred_indices"] if cur_u_indices["pred_indices"] else []) + true_indices.append(cur_u_indices["true_indices"] if cur_u_indices["true_indices"] else []) + length.append(cur_u_indices["length"]) results[lang_code][dataset_name] = { "u": score_u, @@ -596,7 +603,7 @@ def main(args): ), indent=4, ) - + if args.return_indices: json.dump( indices,