baselines
markus583 committed May 19, 2024
1 parent fd6716e commit 2f06d84
Showing 4 changed files with 203 additions and 40 deletions.
3 changes: 3 additions & 0 deletions wtpsplit/evaluation/__init__.py
@@ -365,6 +365,7 @@ def pysbd_sentencize(lang_code, text):
     "es": "es_core_news_sm",
     "sv": "sv_core_news_sm",
     "uk": "uk_core_news_sm",
+    "xx": "xx_sent_ud_sm",
 }


@@ -374,6 +375,7 @@ def spacy_sent_sentencize(lang_code, text):
     try:
         nlp = spacy.blank(lang_code)
         nlp.add_pipe("sentencizer")
+        nlp.max_length = 1_000_000_000

         if lang_code == "ja":
             # spacy uses SudachiPy for japanese, which has a length limit:
@@ -397,6 +399,7 @@ def spacy_dp_sentencize(lang_code, text):

     try:
         nlp = spacy.load(SPACY_LANG_TO_DP_MODEL[lang_code], disable=["ner"])
+        nlp.max_length = 1_000_000_000

         if lang_code == "ja":
             # spacy uses SudachiPy for japanese, which has a length limit:
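These changes map the "xx" code to spaCy's multilingual xx_sent_ud_sm model and raise nlp.max_length well past spaCy's default limit of 1,000,000 characters, so long concatenated evaluation texts are not rejected. A minimal sketch of the same idea, assuming xx_sent_ud_sm is installed; xx_sentencize is a hypothetical helper, not part of wtpsplit:

import spacy

def xx_sentencize(text):
    # Hypothetical helper: load the multilingual sentence model and raise the
    # length limit, mirroring the nlp.max_length change above.
    nlp = spacy.load("xx_sent_ud_sm")
    nlp.max_length = 1_000_000_000
    return [sent.text for sent in nlp(text).sents]

# e.g. xx_sentencize("Das ist ein Satz. This is another sentence.")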
47 changes: 24 additions & 23 deletions wtpsplit/evaluation/download_spacy.py
@@ -1,29 +1,30 @@
 import subprocess

 SPACY_LANG_TO_DP_MODEL = {
-    "ca": "ca_core_news_sm",
-    "zh": "zh_core_web_sm",
-    "hr": "hr_core_news_sm",
-    "da": "da_core_news_sm",
-    "nl": "nl_core_news_sm",
-    "en": "en_core_web_sm",
-    "fi": "fi_core_news_sm",
-    "fr": "fr_core_news_sm",
-    "de": "de_core_news_sm",
-    "el": "el_core_news_sm",
-    "it": "it_core_news_sm",
-    "ja": "ja_core_news_sm",
-    "ko": "ko_core_news_sm",
-    "lt": "lt_core_news_sm",
-    "mk": "mk_core_news_sm",
-    "nb": "nb_core_news_sm",
-    "pl": "pl_core_news_sm",
-    "pt": "pt_core_news_sm",
-    "ro": "ro_core_news_sm",
-    "ru": "ru_core_news_sm",
-    "es": "es_core_news_sm",
-    "sv": "sv_core_news_sm",
-    "uk": "uk_core_news_sm",
+    # "ca": "ca_core_news_sm",
+    # "zh": "zh_core_web_sm",
+    # "hr": "hr_core_news_sm",
+    # "da": "da_core_news_sm",
+    # "nl": "nl_core_news_sm",
+    # "en": "en_core_web_sm",
+    # "fi": "fi_core_news_sm",
+    # "fr": "fr_core_news_sm",
+    # "de": "de_core_news_sm",
+    # "el": "el_core_news_sm",
+    # "it": "it_core_news_sm",
+    # "ja": "ja_core_news_sm",
+    # "ko": "ko_core_news_sm",
+    # "lt": "lt_core_news_sm",
+    # "mk": "mk_core_news_sm",
+    # "nb": "nb_core_news_sm",
+    # "pl": "pl_core_news_sm",
+    # "pt": "pt_core_news_sm",
+    # "ro": "ro_core_news_sm",
+    # "ru": "ru_core_news_sm",
+    # "es": "es_core_news_sm",
+    # "sv": "sv_core_news_sm",
+    # "uk": "uk_core_news_sm",
+    "multi": "xx_sent_ud_sm"
 }

 def download_models():
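The body of download_models() is hidden behind the fold above. A minimal sketch of what it plausibly does with the edited table, assuming models are fetched through the spaCy CLI (the loop below is an assumption, not the repository's verified implementation):

import subprocess

def download_models():
    # Assumed behaviour: download every model left in SPACY_LANG_TO_DP_MODEL,
    # which after this commit is only the multilingual xx_sent_ud_sm.
    for model in SPACY_LANG_TO_DP_MODEL.values():
        subprocess.run(["python", "-m", "spacy", "download", model], check=True)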
36 changes: 19 additions & 17 deletions wtpsplit/evaluation/intrinsic_baselines.py
@@ -61,12 +61,12 @@ class Args:
         for dataset_name, dataset in lang_data["sentence"].items():
             if "nllb" in dataset_name:
                 continue
-            if "corrupted" in dataset_name and dataset_name != "ted2020-corrupted-asr":
-                print("SKIP: ", lang, dataset_name)
-                continue
-            if "legal" in dataset_name and not ("laws" in dataset_name or "judgements" in dataset_name):
-                print("SKIP: ", lang, dataset_name)
-                continue
+            # if "corrupted" in dataset_name and dataset_name != "ted2020-corrupted-asr":
+            #     print("SKIP: ", lang, dataset_name)
+            #     continue
+            # if "legal" in dataset_name and not ("laws" in dataset_name or "judgements" in dataset_name):
+            #     print("SKIP: ", lang, dataset_name)
+            #     continue
             if not dataset["data"]:
                 continue
             results[lang][dataset_name] = {}
@@ -79,10 +79,10 @@ class Args:

             for f, name in [
                 (punkt_sentencize, "punkt"),
-                (spacy_dp_sentencize, "spacy_dp"),
-                (spacy_sent_sentencize, "spacy_sent"),
-                (pysbd_sentencize, "pysbd"),
-                (ersatz_sentencize, "ersatz"),
+                # (spacy_dp_sentencize, "spacy_dp"),
+                # (spacy_sent_sentencize, "spacy_sent"),
+                # (pysbd_sentencize, "pysbd"),
+                # (ersatz_sentencize, "ersatz"),
             ]:
                 print(f"Running {name} on {dataset_name} in {lang_code}...")
                 indices[lang][dataset_name][name] = {}
@@ -146,14 +146,16 @@ class Args:
                         f1 = metrics[0]
                         metrics = metrics[1]
                         metrics["f1"] = f1
-                        indices[lang][dataset_name][name]["true_indices"] = metrics.pop("true_indices")
-                        indices[lang][dataset_name][name]["predicted_indices"] = metrics.pop("predicted_indices")
-                        indices[lang][dataset_name][name]["length"] = metrics.pop("length")
+                        print(f1)
+                        indices[lang][dataset_name][name]["true_indices"] = [metrics.pop("true_indices")]
+                        indices[lang][dataset_name][name]["predicted_indices"] = [metrics.pop("predicted_indices")]
+                        indices[lang][dataset_name][name]["length"] = [metrics.pop("length")]
                         results[lang][dataset_name][name] = metrics
-                except LanguageError:
-                    print("Language not supported for", name)
+                except LanguageError as e:
+                    # print("Language not supported for", name)
+                    # print(e)
                     results[lang][dataset_name][name] = None

-    json.dump(results, open(Constants.CACHE_DIR / "intrinsic_baselines.json", "w"), indent=4, default=int)
-    json.dump(indices, open(Constants.CACHE_DIR / "intrinsic_baselines_IDX.json", "w"), indent=4, default=int)
+    json.dump(results, open(Constants.CACHE_DIR / "intrinsic_baselines_punkt.json", "w"), indent=4, default=int)
+    json.dump(indices, open(Constants.CACHE_DIR / "intrinsic_baselines_punkt_IDX.json", "w"), indent=4, default=int)
     print(Constants.CACHE_DIR / "intrinsic_baselines.json")
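With this change only punkt runs here, each stored index list is wrapped in a single-element list, and output goes to intrinsic_baselines_punkt.json / intrinsic_baselines_punkt_IDX.json. A short sketch of reading the index file back under that layout ("en" and "ersatz-test" are placeholder keys, not guaranteed to exist):

import json

from wtpsplit.utils import Constants

with open(Constants.CACHE_DIR / "intrinsic_baselines_punkt_IDX.json") as f:
    idx = json.load(f)

# Placeholder language/dataset keys; unwrap the single-element lists with [0].
entry = idx["en"]["ersatz-test"]["punkt"]
true_indices = entry["true_indices"][0]
predicted_indices = entry["predicted_indices"][0]
print(len(true_indices), len(predicted_indices), entry["length"][0])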
157 changes: 157 additions & 0 deletions wtpsplit/evaluation/intrinsic_baselines_multilingual.py
@@ -0,0 +1,157 @@
import json
from dataclasses import dataclass
from typing import List

import torch
from tqdm import tqdm
from transformers import HfArgumentParser

from wtpsplit.evaluation import (
    LanguageError,
    ersatz_sentencize,
    evaluate_sentences,
    preprocess_sentence,
    punkt_sentencize,
    pysbd_sentencize,
    spacy_dp_sentencize,
    spacy_sent_sentencize,
)
from wtpsplit.utils import Constants

def split_language_data(eval_data):
    new_eval_data = {}

    for lang_code, lang_data in eval_data.items():
        if '-' in lang_code:
            lang1, lang2 = lang_code.split('-')
            new_lang1 = f"{lang_code}_{lang1.upper()}"
            new_lang2 = f"{lang_code}_{lang2.upper()}"

            # Adding the same content for both new language keys
            new_eval_data[new_lang1] = lang_data
            new_eval_data[new_lang2] = lang_data
        else:
            new_eval_data[lang_code] = lang_data

    return new_eval_data
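
# Illustrative only (not part of the commit): a hypothetical code-switched key such
# as "tr-de" is duplicated into "tr-de_TR" and "tr-de_DE", both pointing at the same
# data, so the corpus is evaluated once per involved language, e.g.
# >>> split_language_data({"tr-de": {"sentence": {}}})
# {'tr-de_TR': {'sentence': {}}, 'tr-de_DE': {'sentence': {}}}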


@dataclass
class Args:
    eval_data_path: str = "data/all_data_11_05-all.pth"
    include_langs: List[str] = None
    exclude_every_k: int = 10


if __name__ == "__main__":
    (args,) = HfArgumentParser([Args]).parse_args_into_dataclasses()

    eval_data = torch.load(args.eval_data_path)
    eval_data = split_language_data(eval_data)
    results = {}
    indices = {}

    for lang, lang_data in tqdm(eval_data.items()):
        if args.include_langs is not None and lang not in args.include_langs:
            continue

        results[lang] = {}
        indices[lang] = {}

        for dataset_name, dataset in lang_data["sentence"].items():
            if "nllb" in dataset_name:
                continue
            # if "corrupted" in dataset_name and dataset_name != "ted2020-corrupted-asr":
            #     print("SKIP: ", lang, dataset_name)
            #     continue
            # if "legal" in dataset_name and not ("laws" in dataset_name or "judgements" in dataset_name):
            #     print("SKIP: ", lang, dataset_name)
            #     continue
            if not dataset["data"]:
                continue
            results[lang][dataset_name] = {}
            indices[lang][dataset_name] = {}
            if "-" in lang:
                # code-switched data: eval 2x
                lang_code = lang.split("_")[1].lower()
            else:
                lang_code = lang

            for f, name in [
                (spacy_dp_sentencize, "spacy_dp"),
                (spacy_sent_sentencize, "spacy_sent"),
            ]:
                print(f"Running {name} on {dataset_name} in {lang_code}...")
                indices[lang][dataset_name][name] = {}
                if "lyrics" in dataset_name or "short" in dataset_name:
                    exclude_every_k = 0
                else:
                    exclude_every_k = args.exclude_every_k
                try:
                    if isinstance(dataset["data"][0], list):
                        all_sentences = [[preprocess_sentence(s) for s in doc] for doc in dataset["data"]]
                        metrics = []
                        for i, sentences in enumerate(all_sentences):
                            text = Constants.SEPARATORS[lang_code].join(sentences)
                            doc_metrics = {}
                            doc_metrics = evaluate_sentences(
                                lang_code, sentences, f("xx", text), return_indices=True, exclude_every_k=exclude_every_k
                            )
                            f1 = doc_metrics[0]
                            doc_metrics = doc_metrics[1]
                            doc_metrics["f1"] = f1
                            doc_metrics["length"] = [doc_metrics["length"]]
                            metrics.append(doc_metrics)
                        avg_results = {}
                        concat_indices = {}
                        for doc in metrics:
                            for key, value in doc.items():
                                if isinstance(value, (float, int)):
                                    # numeric
                                    if key not in avg_results:
                                        avg_results[key] = []

                                    avg_results[key].append(value)
                                elif isinstance(value, list):
                                    # concat
                                    if key not in concat_indices:
                                        concat_indices[key] = []
                                    if key == "length":
                                        concat_indices[key].append(value[0])
                                    else:
                                        concat_indices[key].append(value)

                        # avg
                        for key in list(avg_results):
                            if avg_results[key]:
                                avg_results[key] = sum(avg_results[key]) / len(avg_results[key])

                        # Store the results and indices
                        results[lang][dataset_name][name] = avg_results
                        indices[lang][dataset_name][name] = concat_indices
                    else:
                        sentences = [preprocess_sentence(s) for s in dataset["data"]]
                        text = Constants.SEPARATORS[lang_code].join(sentences)

                        metrics = evaluate_sentences(
                            lang_code,
                            sentences,
                            f("xx", text),
                            return_indices=True,
                            exclude_every_k=exclude_every_k,
                        )
                        f1 = metrics[0]
                        metrics = metrics[1]
                        metrics["f1"] = f1
                        print(f1)
                        indices[lang][dataset_name][name]["true_indices"] = [metrics.pop("true_indices")]
                        indices[lang][dataset_name][name]["predicted_indices"] = [metrics.pop("predicted_indices")]
                        indices[lang][dataset_name][name]["length"] = [metrics.pop("length")]
                        results[lang][dataset_name][name] = metrics
                except LanguageError as l:
                    print("Language not supported for", name, l)
                    results[lang][dataset_name][name] = None

    json.dump(results, open(Constants.CACHE_DIR / "intrinsic_baselines_multi.json", "w"), indent=4, default=int)
    json.dump(indices, open(Constants.CACHE_DIR / "intrinsic_baselines_multi_IDX.json", "w"), indent=4, default=int)
    print(Constants.CACHE_DIR / "intrinsic_baselines.json")
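The new script parses Args with HfArgumentParser, so each dataclass field becomes a command-line flag. A usage sketch under that assumption (the data path repeats the default above; the language codes are illustrative):

from transformers import HfArgumentParser

from wtpsplit.evaluation.intrinsic_baselines_multilingual import Args

# Equivalent CLI: python -m wtpsplit.evaluation.intrinsic_baselines_multilingual \
#     --eval_data_path data/all_data_11_05-all.pth --include_langs en de
(args,) = HfArgumentParser([Args]).parse_args_into_dataclasses(
    args=["--eval_data_path", "data/all_data_11_05-all.pth", "--include_langs", "en", "de"]
)
print(args.include_langs)  # ['en', 'de'] -> only these languages are evaluated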
