From 2f06d8484d1754bdf7853e56adcf875b8e630d2f Mon Sep 17 00:00:00 2001 From: markus583 Date: Sun, 19 May 2024 06:06:19 +0000 Subject: [PATCH] baselines --- wtpsplit/evaluation/__init__.py | 3 + wtpsplit/evaluation/download_spacy.py | 47 +++--- wtpsplit/evaluation/intrinsic_baselines.py | 36 ++-- .../intrinsic_baselines_multilingual.py | 157 ++++++++++++++++++ 4 files changed, 203 insertions(+), 40 deletions(-) create mode 100644 wtpsplit/evaluation/intrinsic_baselines_multilingual.py diff --git a/wtpsplit/evaluation/__init__.py b/wtpsplit/evaluation/__init__.py index 37349fc7..5a0c7a7f 100644 --- a/wtpsplit/evaluation/__init__.py +++ b/wtpsplit/evaluation/__init__.py @@ -365,6 +365,7 @@ def pysbd_sentencize(lang_code, text): "es": "es_core_news_sm", "sv": "sv_core_news_sm", "uk": "uk_core_news_sm", + "xx": "xx_sent_ud_sm", } @@ -374,6 +375,7 @@ def spacy_sent_sentencize(lang_code, text): try: nlp = spacy.blank(lang_code) nlp.add_pipe("sentencizer") + nlp.max_length = 1_000_000_000 if lang_code == "ja": # spacy uses SudachiPy for japanese, which has a length limit: @@ -397,6 +399,7 @@ def spacy_dp_sentencize(lang_code, text): try: nlp = spacy.load(SPACY_LANG_TO_DP_MODEL[lang_code], disable=["ner"]) + nlp.max_length = 1_000_000_000 if lang_code == "ja": # spacy uses SudachiPy for japanese, which has a length limit: diff --git a/wtpsplit/evaluation/download_spacy.py b/wtpsplit/evaluation/download_spacy.py index 0c1dfc31..fc72ade6 100644 --- a/wtpsplit/evaluation/download_spacy.py +++ b/wtpsplit/evaluation/download_spacy.py @@ -1,29 +1,30 @@ import subprocess SPACY_LANG_TO_DP_MODEL = { - "ca": "ca_core_news_sm", - "zh": "zh_core_web_sm", - "hr": "hr_core_news_sm", - "da": "da_core_news_sm", - "nl": "nl_core_news_sm", - "en": "en_core_web_sm", - "fi": "fi_core_news_sm", - "fr": "fr_core_news_sm", - "de": "de_core_news_sm", - "el": "el_core_news_sm", - "it": "it_core_news_sm", - "ja": "ja_core_news_sm", - "ko": "ko_core_news_sm", - "lt": "lt_core_news_sm", - "mk": "mk_core_news_sm", - "nb": "nb_core_news_sm", - "pl": "pl_core_news_sm", - "pt": "pt_core_news_sm", - "ro": "ro_core_news_sm", - "ru": "ru_core_news_sm", - "es": "es_core_news_sm", - "sv": "sv_core_news_sm", - "uk": "uk_core_news_sm", + # "ca": "ca_core_news_sm", + # "zh": "zh_core_web_sm", + # "hr": "hr_core_news_sm", + # "da": "da_core_news_sm", + # "nl": "nl_core_news_sm", + # "en": "en_core_web_sm", + # "fi": "fi_core_news_sm", + # "fr": "fr_core_news_sm", + # "de": "de_core_news_sm", + # "el": "el_core_news_sm", + # "it": "it_core_news_sm", + # "ja": "ja_core_news_sm", + # "ko": "ko_core_news_sm", + # "lt": "lt_core_news_sm", + # "mk": "mk_core_news_sm", + # "nb": "nb_core_news_sm", + # "pl": "pl_core_news_sm", + # "pt": "pt_core_news_sm", + # "ro": "ro_core_news_sm", + # "ru": "ru_core_news_sm", + # "es": "es_core_news_sm", + # "sv": "sv_core_news_sm", + # "uk": "uk_core_news_sm", + "multi": "xx_sent_ud_sm" } def download_models(): diff --git a/wtpsplit/evaluation/intrinsic_baselines.py b/wtpsplit/evaluation/intrinsic_baselines.py index 3384ecf9..ce0132c2 100644 --- a/wtpsplit/evaluation/intrinsic_baselines.py +++ b/wtpsplit/evaluation/intrinsic_baselines.py @@ -61,12 +61,12 @@ class Args: for dataset_name, dataset in lang_data["sentence"].items(): if "nllb" in dataset_name: continue - if "corrupted" in dataset_name and dataset_name != "ted2020-corrupted-asr": - print("SKIP: ", lang, dataset_name) - continue - if "legal" in dataset_name and not ("laws" in dataset_name or "judgements" in dataset_name): - print("SKIP: ", lang, 
dataset_name) - continue + # if "corrupted" in dataset_name and dataset_name != "ted2020-corrupted-asr": + # print("SKIP: ", lang, dataset_name) + # continue + # if "legal" in dataset_name and not ("laws" in dataset_name or "judgements" in dataset_name): + # print("SKIP: ", lang, dataset_name) + # continue if not dataset["data"]: continue results[lang][dataset_name] = {} @@ -79,10 +79,10 @@ class Args: for f, name in [ (punkt_sentencize, "punkt"), - (spacy_dp_sentencize, "spacy_dp"), - (spacy_sent_sentencize, "spacy_sent"), - (pysbd_sentencize, "pysbd"), - (ersatz_sentencize, "ersatz"), + # (spacy_dp_sentencize, "spacy_dp"), + # (spacy_sent_sentencize, "spacy_sent"), + # (pysbd_sentencize, "pysbd"), + # (ersatz_sentencize, "ersatz"), ]: print(f"Running {name} on {dataset_name} in {lang_code}...") indices[lang][dataset_name][name] = {} @@ -146,14 +146,16 @@ class Args: f1 = metrics[0] metrics = metrics[1] metrics["f1"] = f1 - indices[lang][dataset_name][name]["true_indices"] = metrics.pop("true_indices") - indices[lang][dataset_name][name]["predicted_indices"] = metrics.pop("predicted_indices") - indices[lang][dataset_name][name]["length"] = metrics.pop("length") + print(f1) + indices[lang][dataset_name][name]["true_indices"] = [metrics.pop("true_indices")] + indices[lang][dataset_name][name]["predicted_indices"] =[ metrics.pop("predicted_indices")] + indices[lang][dataset_name][name]["length"] = [metrics.pop("length")] results[lang][dataset_name][name] = metrics - except LanguageError: - print("Language not supported for", name) + except LanguageError as e: + # print("Language not supported for", name) + # print(e) results[lang][dataset_name][name] = None - json.dump(results, open(Constants.CACHE_DIR / "intrinsic_baselines.json", "w"), indent=4, default=int) - json.dump(indices, open(Constants.CACHE_DIR / "intrinsic_baselines_IDX.json", "w"), indent=4, default=int) + json.dump(results, open(Constants.CACHE_DIR / "intrinsic_baselines_punkt.json", "w"), indent=4, default=int) + json.dump(indices, open(Constants.CACHE_DIR / "intrinsic_baselines_punkt_IDX.json", "w"), indent=4, default=int) print(Constants.CACHE_DIR / "intrinsic_baselines.json") diff --git a/wtpsplit/evaluation/intrinsic_baselines_multilingual.py b/wtpsplit/evaluation/intrinsic_baselines_multilingual.py new file mode 100644 index 00000000..20a30a90 --- /dev/null +++ b/wtpsplit/evaluation/intrinsic_baselines_multilingual.py @@ -0,0 +1,157 @@ +import json +from dataclasses import dataclass +from typing import List + +import torch +from tqdm import tqdm +from transformers import HfArgumentParser + +from wtpsplit.evaluation import ( + LanguageError, + ersatz_sentencize, + evaluate_sentences, + preprocess_sentence, + punkt_sentencize, + pysbd_sentencize, + spacy_dp_sentencize, + spacy_sent_sentencize, +) +from wtpsplit.utils import Constants + +def split_language_data(eval_data): + new_eval_data = {} + + for lang_code, lang_data in eval_data.items(): + if '-' in lang_code: + lang1, lang2 = lang_code.split('-') + new_lang1 = f"{lang_code}_{lang1.upper()}" + new_lang2 = f"{lang_code}_{lang2.upper()}" + + # Adding the same content for both new language keys + new_eval_data[new_lang1] = lang_data + new_eval_data[new_lang2] = lang_data + else: + new_eval_data[lang_code] = lang_data + + return new_eval_data + + +@dataclass +class Args: + eval_data_path: str = "data/all_data_11_05-all.pth" + include_langs: List[str] = None + exclude_every_k: int = 10 + + +if __name__ == "__main__": + (args,) = 
HfArgumentParser([Args]).parse_args_into_dataclasses() + + eval_data = torch.load(args.eval_data_path) + eval_data = split_language_data(eval_data) + results = {} + indices = {} + + for lang, lang_data in tqdm(eval_data.items()): + if args.include_langs is not None and lang not in args.include_langs: + continue + + results[lang] = {} + indices[lang] = {} + + for dataset_name, dataset in lang_data["sentence"].items(): + if "nllb" in dataset_name: + continue + # if "corrupted" in dataset_name and dataset_name != "ted2020-corrupted-asr": + # print("SKIP: ", lang, dataset_name) + # continue + # if "legal" in dataset_name and not ("laws" in dataset_name or "judgements" in dataset_name): + # print("SKIP: ", lang, dataset_name) + # continue + if not dataset["data"]: + continue + results[lang][dataset_name] = {} + indices[lang][dataset_name] = {} + if "-" in lang: + # code-switched data: eval 2x + lang_code = lang.split("_")[1].lower() + else: + lang_code = lang + + for f, name in [ + (spacy_dp_sentencize, "spacy_dp"), + (spacy_sent_sentencize, "spacy_sent"), + ]: + print(f"Running {name} on {dataset_name} in {lang_code}...") + indices[lang][dataset_name][name] = {} + if "lyrics" in dataset_name or "short" in dataset_name: + exclude_every_k = 0 + else: + exclude_every_k = args.exclude_every_k + try: + if isinstance(dataset["data"][0], list): + all_sentences = [[preprocess_sentence(s) for s in doc] for doc in dataset["data"]] + metrics = [] + for i, sentences in enumerate(all_sentences): + text = Constants.SEPARATORS[lang_code].join(sentences) + doc_metrics = {} + doc_metrics = evaluate_sentences( + lang_code, sentences, f("xx", text), return_indices=True, exclude_every_k=exclude_every_k + ) + f1 = doc_metrics[0] + doc_metrics = doc_metrics[1] + doc_metrics["f1"] = f1 + doc_metrics["length"] = [doc_metrics["length"]] + metrics.append(doc_metrics) + avg_results = {} + concat_indices = {} + for doc in metrics: + for key, value in doc.items(): + if isinstance(value, (float, int)): + # numeric + if key not in avg_results: + avg_results[key] = [] + + avg_results[key].append(value) + elif isinstance(value, list): + # concat + if key not in concat_indices: + concat_indices[key] = [] + if key == "length": + concat_indices[key].append(value[0]) + else: + concat_indices[key].append(value) + + # avg + for key in list(avg_results): + if avg_results[key]: + avg_results[key] = sum(avg_results[key]) / len(avg_results[key]) + + # Store the results and indices + results[lang][dataset_name][name] = avg_results + indices[lang][dataset_name][name] = concat_indices + else: + sentences = [preprocess_sentence(s) for s in dataset["data"]] + text = Constants.SEPARATORS[lang_code].join(sentences) + + metrics = evaluate_sentences( + lang_code, + sentences, + f("xx", text), + return_indices=True, + exclude_every_k=exclude_every_k, + ) + f1 = metrics[0] + metrics = metrics[1] + metrics["f1"] = f1 + print(f1) + indices[lang][dataset_name][name]["true_indices"] = [metrics.pop("true_indices")] + indices[lang][dataset_name][name]["predicted_indices"] =[ metrics.pop("predicted_indices")] + indices[lang][dataset_name][name]["length"] = [metrics.pop("length")] + results[lang][dataset_name][name] = metrics + except LanguageError as l: + print("Language not supported for", name, l) + results[lang][dataset_name][name] = None + + json.dump(results, open(Constants.CACHE_DIR / "intrinsic_baselines_multi.json", "w"), indent=4, default=int) + json.dump(indices, open(Constants.CACHE_DIR / "intrinsic_baselines_multi_IDX.json", "w"), indent=4, 
default=int)
+    print(Constants.CACHE_DIR / "intrinsic_baselines_multi.json")
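
Note (not part of the patch): a minimal sketch of what the new split_language_data helper does. A code-switched corpus keyed as "xx-yy" is duplicated under "xx-yy_XX" and "xx-yy_YY", so the per-dataset loop can later recover the single evaluation language via lang.split("_")[1].lower() while the original code-switched dataset key stays visible in the results. The import path assumes this patch has been applied and wtpsplit is importable.

# Sketch: expanding code-switched language keys (assumes the patch is applied).
from wtpsplit.evaluation.intrinsic_baselines_multilingual import split_language_data

eval_data = {
    "en-de": {"sentence": {}},  # code-switched corpus: duplicated under two keys
    "fr": {"sentence": {}},     # monolingual corpus: kept as-is
}

expanded = split_language_data(eval_data)
print(sorted(expanded))  # ['en-de_DE', 'en-de_EN', 'fr']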