From 2214f4c99ea8781c3c97e24b24e464cfe7762fcf Mon Sep 17 00:00:00 2001
From: markus583
Date: Sun, 14 Jan 2024 13:09:14 +0000
Subject: [PATCH] add lowercase eval

---
 wtpsplit/evaluation/intrinsic_lowercase.py | 258 +++++++++++++++++++++
 wtpsplit/extract.py                        |   1 -
 wtpsplit/models.py                         |   4 +-
 wtpsplit/summary_plot.py                   |  24 +-
 wtpsplit/train/evaluate.py                 |   3 +
 wtpsplit/train/train.py                    |  22 ++
 6 files changed, 302 insertions(+), 10 deletions(-)
 create mode 100644 wtpsplit/evaluation/intrinsic_lowercase.py

diff --git a/wtpsplit/evaluation/intrinsic_lowercase.py b/wtpsplit/evaluation/intrinsic_lowercase.py
new file mode 100644
index 00000000..e90ecd0a
--- /dev/null
+++ b/wtpsplit/evaluation/intrinsic_lowercase.py
@@ -0,0 +1,258 @@
+import copy
+import json
+from dataclasses import dataclass
+from typing import List
+
+import h5py
+import skops.io as sio
+import torch
+from datasets import load_dataset
+from tqdm.auto import tqdm
+from transformers import AutoModelForTokenClassification, HfArgumentParser
+import numpy as np
+
+import wtpsplit.models  # noqa: F401
+from wtpsplit.evaluation import evaluate_mixture, get_labels, train_mixture, token_to_char_probs
+from wtpsplit.extract import PyTorchWrapper, extract
+from wtpsplit.utils import Constants
+
+
+@dataclass
+class Args:
+    model_path: str
+    # eval data in the format:
+    # {
+    #    "<lang_code>": {
+    #        "sentence": {
+    #            "<dataset_name>": {
+    #                "meta": {
+    #                    "train_data": ["train sentence 1", "train sentence 2"]
+    #                },
+    #                "data": ["test sentence 1", "test sentence 2"]
+    #            }
+    #        }
+    #    }
+    # }
+    eval_data_path: str = "data/eval.pth"
+    valid_text_path: str = None  # "data/sentence/valid.parquet"
+    device: str = "cpu"
+    block_size: int = 512
+    stride: int = 64
+    batch_size: int = 32
+    include_langs: List[str] = None
+    threshold: float = 0.01
+
+
+def process_logits(text, model, lang_code, args):
+    # lowercase the input before extracting logits; this is the point of the lowercase eval
+    text = text.lower()
+    logits, offsets_mapping, tokenizer = extract(
+        [text],
+        model,
+        lang_code=lang_code,
+        stride=args.stride,
+        block_size=args.block_size,
+        batch_size=args.batch_size,
+        pad_last_batch=True,
+        verbose=True,
+    )
+    logits = logits[0]
+    if offsets_mapping is not None:
+        offsets_mapping = offsets_mapping[0]
+
+    if "xlm" in model.config.model_type:
+        tokens = tokenizer.tokenize(text, verbose=False)
+
+        # use the vectorized function to convert token probabilities to character probabilities for the entire array
+        char_probs = token_to_char_probs(text, tokens, logits, tokenizer, offsets_mapping)
+
+        logits = char_probs
+
+    return logits
+
+
+def load_or_compute_logits(args, model, eval_data, valid_data=None, max_n_train_sentences=10_000):
+    logits_path = Constants.CACHE_DIR / (
+        f"{args.model_path.split('/')[0]}_L_b{args.block_size}+s{args.stride}_logits_u{args.threshold}.h5"
+    )
+
+    # TODO: revert to "a"
+    with h5py.File(logits_path, "w") as f, torch.no_grad():
+        for lang_code in Constants.LANGINFO.index:
+            if args.include_langs is not None and lang_code not in args.include_langs:
+                continue
+
+            print(f"Processing {lang_code}...")
+            if lang_code not in f:
+                lang_group = f.create_group(lang_code)
+            else:
+                lang_group = f[lang_code]
+
+            # valid data
+            if valid_data is not None and "valid" not in lang_group:
+                sentences = [sample["text"].strip() for sample in valid_data if sample["lang"] == lang_code]
+                assert len(sentences) > 0
+
+                separator = Constants.SEPARATORS[lang_code]
+                valid_text = separator.join(sentences)
+
+                valid_logits = process_logits(valid_text, model, lang_code, args)
+
+                lang_group.create_dataset("valid", data=valid_logits)
+
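+            # logits are cached in the HDF5 file above, grouped per language and
+            # dataset; note that mode "w" recomputes everything on each run, while
+            # reverting to "a" (see TODO above) would reuse previously stored groups.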
+            # eval data
+            for dataset_name, dataset in eval_data[lang_code]["sentence"].items():
+                if dataset_name not in lang_group:
+                    dset_group = lang_group.create_group(dataset_name)
+                else:
+                    dset_group = lang_group[dataset_name]
+
+                if "test_logits" not in dset_group:
+                    test_sentences = dataset["data"]
+                    test_text = Constants.SEPARATORS[lang_code].join(test_sentences)
+
+                    test_logits = process_logits(test_text, model, lang_code, args)
+                    test_labels = get_labels(lang_code, test_sentences, after_space=False)
+
+                    dset_group.create_dataset("test_logits", data=test_logits)
+                    dset_group.create_dataset("test_labels", data=test_labels)
+
+                train_sentences = dataset["meta"].get("train_data")
+                if train_sentences is not None and "train_logits" not in dset_group:
+                    train_sentences = train_sentences[:max_n_train_sentences]
+                    train_text = Constants.SEPARATORS[lang_code].join(train_sentences)
+
+                    train_logits = process_logits(train_text, model, lang_code, args)
+                    train_labels = get_labels(lang_code, train_sentences, after_space=False)
+
+                    dset_group.create_dataset("train_logits", data=train_logits)
+                    dset_group.create_dataset("train_labels", data=train_labels)
+
+    return h5py.File(logits_path, "r")
+
+
+def compute_statistics(values):
+    if not values:  # check for empty values list
+        return {"mean": None, "median": None, "std": None, "min": None, "min_lang": None, "max": None, "max_lang": None}
+
+    scores, langs = zip(*values)  # unpack scores and languages
+    min_index = np.argmin(scores)
+    max_index = np.argmax(scores)
+    return {
+        "mean": np.mean(scores),
+        "median": np.median(scores),
+        "std": np.std(scores),
+        "min": scores[min_index],
+        "min_lang": langs[min_index],
+        "max": scores[max_index],
+        "max_lang": langs[max_index],
+    }
+
+
+if __name__ == "__main__":
+    (args,) = HfArgumentParser([Args]).parse_args_into_dataclasses()
+
+    eval_data = torch.load(args.eval_data_path)
+    if args.valid_text_path is not None:
+        valid_data = load_dataset("parquet", data_files=args.valid_text_path, split="train")
+    else:
+        valid_data = None
+
+    print("Loading model...")
+    model = PyTorchWrapper(AutoModelForTokenClassification.from_pretrained(args.model_path).to(args.device))
+
+    # first, logits for everything.
+    f = load_or_compute_logits(args, model, eval_data, valid_data)
+
+    # now, compute the intrinsic scores.
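+    # Per dataset, three scores are computed (naming follows wtpsplit's intrinsic
+    # evaluation): "u" thresholds the newline probability at args.threshold,
+    # "t" uses the threshold fitted by train_mixture on the train sentences, and
+    # "punct" scores the trained punctuation-logit mixture classifier.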
+    results = {}
+    clfs = {}
+    # initialize lists to store scores for each metric across all languages
+    u_scores, t_scores, punct_scores = [], [], []
+
+    for lang_code, dsets in tqdm(eval_data.items()):
+        if args.include_langs is not None and lang_code not in args.include_langs:
+            continue
+
+        print(f"Predicting {lang_code}...")
+        results[lang_code] = {}
+        clfs[lang_code] = {}
+
+        for dataset_name, dataset in dsets["sentence"].items():
+            sentences = dataset["data"]
+
+            if "train_logits" in f[lang_code][dataset_name]:
+                feature_indices = None
+                clf = train_mixture(
+                    [lang_code],
+                    f[lang_code][dataset_name]["train_logits"][:],
+                    f[lang_code][dataset_name]["train_labels"][:],
+                    features=feature_indices,
+                )
+                if clf[0] is not None:
+                    print(clf)
+                    print(np.argsort(clf[0].coef_[0])[:10], "...", np.argsort(clf[0].coef_[0])[-10:])
+                    print(np.where(np.argsort(clf[0].coef_[0]) == 0)[0])
+
+                score_t, score_punct, _ = evaluate_mixture(
+                    lang_code,
+                    f[lang_code][dataset_name]["test_logits"][:],
+                    sentences,
+                    *clf,
+                )
+
+                clfs[lang_code][dataset_name] = clf
+
+                clf = list(copy.deepcopy(clf))
+                clf[-1] = args.threshold
+            else:
+                score_t = score_punct = None
+
+            score_u, _, _ = evaluate_mixture(lang_code, f[lang_code][dataset_name]["test_logits"][:], sentences, *clf)
+
+            results[lang_code][dataset_name] = {
+                "u": score_u,
+                "t": score_t,
+                "punct": score_punct,
+            }
+
+            # just for printing
+            score_t = score_t or 0.0
+            score_punct = score_punct or 0.0
+
+            u_scores.append((score_u, lang_code))
+            t_scores.append((score_t, lang_code))
+            punct_scores.append((score_punct, lang_code))
+            print(f"{lang_code} {dataset_name} {score_u:.3f} {score_t:.3f} {score_punct:.3f}")
+
+    # compute statistics for each metric across all languages
+    results_avg = {
+        "u": compute_statistics(u_scores),
+        "t": compute_statistics(t_scores),
+        "punct": compute_statistics(punct_scores),
+        "include_langs": args.include_langs,
+    }
+
+    sio.dump(
+        clfs,
+        open(
+            Constants.CACHE_DIR / (f"{args.model_path.split('/')[0]}_L_b{args.block_size}+s{args.stride}.skops"),
+            "wb",
+        ),
+    )
+    json.dump(
+        results,
+        open(
+            Constants.CACHE_DIR
+            / (f"{args.model_path.split('/')[0]}_L_b{args.block_size}+s{args.stride}_intrinsic_results_u{args.threshold}.json"),
+            "w",
+        ),
+        indent=4,
+    )
+
+    # write results_avg to JSON
+    json.dump(
+        results_avg,
+        open(Constants.CACHE_DIR / (f"{args.model_path.split('/')[0]}_L_b{args.block_size}+s{args.stride}_u{args.threshold}_AVG.json"), "w"),
+        indent=4,
+    )
diff --git a/wtpsplit/extract.py b/wtpsplit/extract.py
index 0d378ec9..242fda17 100644
--- a/wtpsplit/extract.py
+++ b/wtpsplit/extract.py
@@ -89,7 +89,6 @@ def extract(
         tokenizer.add_special_tokens({"additional_special_tokens": [AddedToken("\n")]})
         tokens = tokenizer(batch_of_texts, return_offsets_mapping=True, verbose=False)
         # remove CLS and SEP tokens, they are added later anyhow
-        old_batch_of_texts = batch_of_texts
         batch_of_texts = [text[1:-1] for text in tokens["input_ids"]]
         offset_mapping = [offset[1:-1] for offset in tokens["offset_mapping"]]
         cls_token_id = tokenizer.cls_token_id
diff --git a/wtpsplit/models.py b/wtpsplit/models.py
index cfbd9eca..ab1c6850 100644
--- a/wtpsplit/models.py
+++ b/wtpsplit/models.py
@@ -1291,7 +1291,7 @@ def get_extended_attention_mask(
     text = "This is a test\n sentence \n\n"
     tokenizer = AutoTokenizer.from_pretrained(model_str)
-    tokens = tokenizer(text, return_tensors="pt", add_special_tokens=False, pad_to_multiple_of=8, padding=True)
+    tokens = tokenizer(text, return_tensors="pt", add_special_tokens=False, pad_to_multiple_of=512, padding=True)
     from tokenizers import AddedToken
 
     tokenizer.add_special_tokens({"additional_special_tokens": [AddedToken("\n")]})
@@ -1300,5 +1300,5 @@ def get_extended_attention_mask(
     print(tokens)
 
     # forward pass
-    lookahead = 1
+    lookahead = 512
     print(backbone(**tokens, lookahead=lookahead))
diff --git a/wtpsplit/summary_plot.py b/wtpsplit/summary_plot.py
index fb2c624f..3a1d7813 100644
--- a/wtpsplit/summary_plot.py
+++ b/wtpsplit/summary_plot.py
@@ -3,13 +3,23 @@
 import json
 
 FILES = [
-    ".cache/xlmr-normal-v2_b128+s64_intrinsic_results_u0.01.json",
-    ".cache/xlmr-normal-v2_b256+s64_intrinsic_results_u0.01.json",
-    ".cache/xlmr-normal-v2_b512+s64_intrinsic_results_u0.01.json",
-    ".cache/xlmr-normal-p-v2_auxp0.3_n0.9_b256_s64_intrinsic_results_u0.01.json",
-    ".cache/xlm-normal-p-v2_auxp0.3_n0.9_b512_s64_intrinsic_results_u0.001.json",
-    ".cache/xlmr-normal-p-v2_n0.9_b512+s64_intrinsic_results_u0.01.json",
-    ".cache/xlmr-normal-p-v2-auxp0.1_b512+s64_intrinsic_results_u0.01.json",
+    # ".cache/xlmr-normal-p-v2_auxp0.3_n0.9_b64_s32_intrinsic_results_u0.01.json",
+    # ".cache/xlmr-normal-p-v2_auxp0.3_n0.9_b128_s64_intrinsic_results_u0.01.json",
+    # ".cache/xlmr-normal-p-v2_auxp0.3_n0.9_b256_s64_intrinsic_results_u0.01.json",
+    # ".cache/xlm-normal-p-v2_auxp0.3_n0.9_b512_s64_intrinsic_results_u0.01.json",
+    # ".cache/xlmr-normal-p-v2_n0.9_b256+s64_intrinsic_results_u0.01.json",
+    # ".cache/xlmr-normal-p-v2_n0.9_b512+s64_intrinsic_results_u0.01.json",
+    # ".cache/xlmr-normal-p-v2_n0.9_b128+s64_intrinsic_results_u0.01.json",
+    # ".cache/xlmr-normal-p-v2-auxp0.1_b256+s64_intrinsic_results_u0.01.json",
+    # ".cache/xlmr-normal-p-v2-auxp0.1_b512+s64_intrinsic_results_u0.01.json",
+    ".cache/xlmr-normal-p-v2_auxp0.2_b512+s64_intrinsic_results_u0.01.json",
+    ".cache/xlmr-normal-p-v2_auxp0.2_b512+s128_intrinsic_results_u0.01.json",
+    # ".cache/xlmr-normal-p-v2_auxp0.4_b256+s64_intrinsic_results_u0.01.json",
+    # ".cache/xlmr-normal-p-v2_auxp0.4_b512+s64_intrinsic_results_u0.01.json",
+    # ".cache/xlmr-normal-v2_b128+s64_intrinsic_results_u0.01.json",
+    # ".cache/xlmr-normal-v2_b256+s64_intrinsic_results_u0.01.json",
+    # ".cache/xlmr-normal-v2_b512+s64_intrinsic_results_u0.01.json",
+    # ".cache/xlmr-normal-9l-v2_b512+s64_intrinsic_results_u0.01.json",
     # "evaluation/evaluation_results/wtp-canine-s-3l-no-adapters_intrinsic_results.json",
     # "evaluation/evaluation_results/wtp-canine-s-3l_intrinsic_results.json",
 ]
diff --git a/wtpsplit/train/evaluate.py b/wtpsplit/train/evaluate.py
index 2cf22938..b0e50c64 100644
--- a/wtpsplit/train/evaluate.py
+++ b/wtpsplit/train/evaluate.py
@@ -65,6 +65,7 @@ def evaluate_sentence(
     batch_size,
     use_pysbd=False,
     positive_index=None,
+    # do_lowercase=False,
 ):
     if positive_index is None:
         positive_index = Constants.NEWLINE_INDEX
@@ -74,6 +75,8 @@ def evaluate_sentence(
     separator = Constants.SEPARATORS[lang_code]
     text = separator.join(sentences)
+    # if do_lowercase:
+    #     text = text.lower()
 
     logits, offsets_mapping, tokenizer = extract(
         [text],
diff --git a/wtpsplit/train/train.py b/wtpsplit/train/train.py
index 4349420b..a9139f17 100644
--- a/wtpsplit/train/train.py
+++ b/wtpsplit/train/train.py
@@ -551,6 +551,8 @@ def maybe_pad(text):
 def compute_metrics(trainer):
     metrics = {}
     avg_metrics = defaultdict(lambda: [])
+    # metrics_lower = {}
+    # avg_metrics_lower = defaultdict(lambda: [])
 
     model = trainer._wrap_model(trainer.model, training=False)
 
@@ -574,10 +576,30 @@ def compute_metrics(trainer):
                 avg_metrics[f"average_nonwhitespace_{dataset_name}_pr_auc"].append(score)
             else:
                 avg_metrics[f"average_whitespace_{dataset_name}_pr_auc"].append(score)
+        # for dataset_name, dataset in lang_data["sentence"].items():
+        #     score, _ = evaluate_sentence(
+        #         lang_code,
+        #         dataset["data"],
+        #         model,
+        #         stride=args.eval_stride,
+        #         block_size=args.block_size,
+        #         batch_size=training_args.per_device_eval_batch_size,
+        #         do_lowercase=True,
+        #     )
+        #     metrics_lower[f"lower_{lang_code}_{dataset_name}_pr_auc"] = score
+        #     avg_metrics_lower[f"lower_average_{dataset_name}_pr_auc"].append(score)
+        #     if lang_code in ["zh", "ja", "my", "km"]:
+        #         avg_metrics_lower[f"lower_average_nonwhitespace_{dataset_name}_pr_auc"].append(score)
+        #     else:
+        #         avg_metrics_lower[f"lower_average_whitespace_{dataset_name}_pr_auc"].append(score)
+
     for name, values in avg_metrics.items():
         if len(values) > 1:
             metrics[name] = np.mean(values)
+    # for name, values in avg_metrics_lower.items():
+    #     if len(values) > 1:
+    #         metrics_lower[name] = np.mean(values)
 
     return metrics
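
Usage note: the new script parses Args with HfArgumentParser, so each dataclass
field maps to a CLI flag (e.g. --model_path, --eval_data_path, --include_langs).
It can also be driven programmatically; a minimal sketch, assuming a trained
token-classification checkpoint in a local directory "xlmr-normal-v2" (a
hypothetical path, standing in for any checkpoint):

    import torch
    from transformers import AutoModelForTokenClassification

    from wtpsplit.evaluation.intrinsic_lowercase import Args, load_or_compute_logits
    from wtpsplit.extract import PyTorchWrapper

    # hypothetical checkpoint directory; any AutoModelForTokenClassification path works
    args = Args(model_path="xlmr-normal-v2")
    model = PyTorchWrapper(
        AutoModelForTokenClassification.from_pretrained(args.model_path).to(args.device)
    )
    eval_data = torch.load(args.eval_data_path)

    # returns a read-only h5py.File with per-language logits/labels, computed on
    # lowercased text and cached under Constants.CACHE_DIR
    f = load_or_compute_logits(args, model, eval_data)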