add lowercase eval
markus583 committed Jan 14, 2024
1 parent 9ca7955 commit 2214f4c
Showing 6 changed files with 302 additions and 10 deletions.
258 changes: 258 additions & 0 deletions wtpsplit/evaluation/intrinsic_lowercase.py
@@ -0,0 +1,258 @@
import copy
import json
from dataclasses import dataclass
from typing import List

import h5py
import skops.io as sio
import torch
from datasets import load_dataset
from tqdm.auto import tqdm
from transformers import AutoModelForTokenClassification, HfArgumentParser
import numpy as np

import wtpsplit.models # noqa: F401
from wtpsplit.evaluation import evaluate_mixture, get_labels, train_mixture, token_to_char_probs
from wtpsplit.extract import PyTorchWrapper, extract
from wtpsplit.utils import Constants


@dataclass
class Args:
model_path: str
# eval data in the format:
# {
# "<lang_code>": {
# "sentence": {
# "<dataset_name>": {
# "meta": {
# "train_data": ["train sentence 1", "train sentence 2"]
# },
# "data": ["test sentence 1", "test sentence 2"]
# }
# }
# }
# }
eval_data_path: str = "data/eval.pth"
valid_text_path: str = None # "data/sentence/valid.parquet"
device: str = "cpu"
block_size: int = 512
stride: int = 64
batch_size: int = 32
include_langs: List[str] = None
threshold: float = 0.01


def process_logits(text, model, lang_code, args):
# Extract necessary data
text = text.lower()
logits, offsets_mapping, tokenizer = extract(
[text],
model,
lang_code=lang_code,
stride=args.stride,
block_size=args.block_size,
batch_size=args.batch_size,
pad_last_batch=True,
verbose=True,
)
logits = logits[0]
if offsets_mapping is not None:
offsets_mapping = offsets_mapping[0]

if "xlm" in model.config.model_type:
tokens = tokenizer.tokenize(text, verbose=False)

# Use the vectorized function to convert token probabilities to character probabilities for the entire array
char_probs = token_to_char_probs(text, tokens, logits, tokenizer, offsets_mapping)

logits = char_probs

return logits


def load_or_compute_logits(args, model, eval_data, valid_data=None, max_n_train_sentences=10_000):
logits_path = Constants.CACHE_DIR / (
f"{args.model_path.split('/')[0]}_L_b{args.block_size}+s{args.stride}_logits_u{args.threshold}.h5"
)

# TODO: revert to "a"
with h5py.File(logits_path, "w") as f, torch.no_grad():
for lang_code in Constants.LANGINFO.index:
if args.include_langs is not None and lang_code not in args.include_langs:
continue

print(f"Processing {lang_code}...")
if lang_code not in f:
lang_group = f.create_group(lang_code)
else:
lang_group = f[lang_code]

# valid data
if valid_data is not None and "valid" not in lang_group:
sentences = [sample["text"].strip() for sample in valid_data if sample["lang"] == lang_code]
assert len(sentences) > 0

separator = Constants.SEPARATORS[lang_code]
valid_text = separator.join(sentences)

valid_logits = process_logits(valid_text, model, lang_code, args)

lang_group.create_dataset("valid", data=valid_logits)

# eval data
for dataset_name, dataset in eval_data[lang_code]["sentence"].items():
if dataset_name not in lang_group:
dset_group = lang_group.create_group(dataset_name)
else:
dset_group = lang_group[dataset_name]

if "test_logits" not in dset_group:
test_sentences = dataset["data"]
test_text = Constants.SEPARATORS[lang_code].join(test_sentences)

test_logits = process_logits(test_text, model, lang_code, args)
test_labels = get_labels(lang_code, test_sentences, after_space=False)

dset_group.create_dataset("test_logits", data=test_logits)
dset_group.create_dataset("test_labels", data=test_labels)

train_sentences = dataset["meta"].get("train_data")
if train_sentences is not None and "train_logits" not in dset_group:
train_sentences = train_sentences[:max_n_train_sentences]
train_text = Constants.SEPARATORS[lang_code].join(train_sentences)

train_logits = process_logits(train_text, model, lang_code, args)
train_labels = get_labels(lang_code, train_sentences, after_space=False)

dset_group.create_dataset("train_logits", data=train_logits)
dset_group.create_dataset("train_labels", data=train_labels)

return h5py.File(logits_path, "r")


def compute_statistics(values):
if not values: # Check for empty values list
return {"mean": None, "median": None, "std": None, "min": None, "min_lang": None, "max": None, "max_lang": None}

scores, langs = zip(*values) # Unpack scores and languages
min_index = np.argmin(scores)
max_index = np.argmax(scores)
return {
"mean": np.mean(scores),
"median": np.median(scores),
"std": np.std(scores),
"min": scores[min_index],
"min_lang": langs[min_index],
"max": scores[max_index],
"max_lang": langs[max_index]
}


if __name__ == "__main__":
(args,) = HfArgumentParser([Args]).parse_args_into_dataclasses()

eval_data = torch.load(args.eval_data_path)
if args.valid_text_path is not None:
valid_data = load_dataset("parquet", data_files=args.valid_text_path, split="train")
else:
valid_data = None

print("Loading model...")
model = PyTorchWrapper(AutoModelForTokenClassification.from_pretrained(args.model_path).to(args.device))

# first, logits for everything.
f = load_or_compute_logits(args, model, eval_data, valid_data)

# now, compute the intrinsic scores.
results = {}
clfs = {}
# Initialize lists to store scores for each metric across all languages
u_scores, t_scores, punct_scores = [], [], []

for lang_code, dsets in tqdm(eval_data.items()):
if args.include_langs is not None and lang_code not in args.include_langs:
continue

print(f"Predicting {lang_code}...")
results[lang_code] = {}
clfs[lang_code] = {}

for dataset_name, dataset in dsets["sentence"].items():
sentences = dataset["data"]

if "train_logits" in f[lang_code][dataset_name]:
feature_indices = None
clf = train_mixture(
[lang_code],
f[lang_code][dataset_name]["train_logits"][:],
f[lang_code][dataset_name]["train_labels"][:],
features=feature_indices,
)
if clf[0] is not None:
print(clf)
print(np.argsort(clf[0].coef_[0])[:10], "...", np.argsort(clf[0].coef_[0])[-10:])
print(np.where(np.argsort(clf[0].coef_[0]) == 0)[0])

score_t, score_punct, _ = evaluate_mixture(
lang_code,
f[lang_code][dataset_name]["test_logits"][:],
sentences,
*clf,
)

clfs[lang_code][dataset_name] = clf

clf = list(copy.deepcopy(clf))
clf[-1] = args.threshold
else:
score_t = score_punct = None

score_u, _, _ = evaluate_mixture(lang_code, f[lang_code][dataset_name]["test_logits"][:], sentences, *clf)

results[lang_code][dataset_name] = {
"u": score_u,
"t": score_t,
"punct": score_punct,
}

# just for printing
score_t = score_t or 0.0
score_punct = score_punct or 0.0

u_scores.append((score_u, lang_code))
t_scores.append((score_t, lang_code))
punct_scores.append((score_punct, lang_code))
print(f"{lang_code} {dataset_name} {score_u:.3f} {score_t:.3f} {score_punct:.3f}")

# Compute statistics for each metric across all languages
results_avg = {
"u": compute_statistics(u_scores),
"t": compute_statistics(t_scores),
"punct": compute_statistics(punct_scores),
"include_langs": args.include_langs,
}

sio.dump(
clfs,
open(
Constants.CACHE_DIR / (f"{args.model_path.split('/')[0]}_L_b{args.block_size}+s{args.stride}.skops"),
"wb",
),
)
json.dump(
results,
open(
Constants.CACHE_DIR
/ (f"{args.model_path.split('/')[0]}_L_b{args.block_size}+s{args.stride}_intrinsic_results_u{args.threshold}.json"),
"w",
),
indent=4,
)

# Write results_avg to JSON
json.dump(
results_avg,
open(Constants.CACHE_DIR / (f"{args.model_path.split('/')[0]}_L_b{args.block_size}+s{args.stride}_u{args.threshold}_AVG.json"), "w"),
indent=4,
)
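
For XLM-R-type models, process_logits converts per-token logits to per-character probabilities via token_to_char_probs before the mixture evaluation. That helper's implementation is not part of this diff; below is a minimal, self-contained sketch of the assumed mapping (each token's logits are assigned to the last character of its offset span), for illustration only and not part of the commit:

import numpy as np

def token_to_char_probs_sketch(text, logits, offsets_mapping, fill_value=-10_000.0):
    """Illustrative sketch only; the real token_to_char_probs in
    wtpsplit.evaluation may differ in detail. Each token's logits go to the
    last character of its (start, end) span; uncovered characters keep
    fill_value."""
    char_logits = np.full((len(text), logits.shape[1]), fill_value, dtype=np.float32)
    for (start, end), token_logits in zip(offsets_mapping, logits):
        if end > start:  # skip empty/special-token spans
            char_logits[end - 1] = token_logits
    return char_logits

# Toy example: three tokens over "hi there", two output classes; values are made up.
toy_logits = np.array([[0.1, -2.0], [0.3, -1.5], [2.2, 0.4]], dtype=np.float32)
toy_offsets = [(0, 2), (3, 8), (0, 0)]  # last entry mimics a padded/special token
print(token_to_char_probs_sketch("hi there", toy_logits, toy_offsets).shape)  # (8, 2)
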
1 change: 0 additions & 1 deletion wtpsplit/extract.py
@@ -89,7 +89,6 @@ def extract(
tokenizer.add_special_tokens({"additional_special_tokens": [AddedToken("\n")]})
tokens = tokenizer(batch_of_texts, return_offsets_mapping=True, verbose=False)
# remove CLS and SEP tokens, they are added later anyhow
old_batch_of_texts = batch_of_texts
batch_of_texts = [text[1:-1] for text in tokens["input_ids"]]
offset_mapping = [offset[1:-1] for offset in tokens["offset_mapping"]]
cls_token_id = tokenizer.cls_token_id
4 changes: 2 additions & 2 deletions wtpsplit/models.py
@@ -1291,7 +1291,7 @@ def get_extended_attention_mask(
text = "This is a test\n sentence \n\n"
tokenizer = AutoTokenizer.from_pretrained(model_str)

tokens = tokenizer(text, return_tensors="pt", add_special_tokens=False, pad_to_multiple_of=8, padding=True)
tokens = tokenizer(text, return_tensors="pt", add_special_tokens=False, pad_to_multiple_of=512, padding=True)
from tokenizers import AddedToken

tokenizer.add_special_tokens({"additional_special_tokens": [AddedToken("\n")]})
@@ -1300,5 +1300,5 @@ def get_extended_attention_mask(
print(tokens)

# forward pass
lookahead = 1
lookahead = 512
print(backbone(**tokens, lookahead=lookahead))
24 changes: 17 additions & 7 deletions wtpsplit/summary_plot.py
@@ -3,13 +3,23 @@
import json

FILES = [
".cache/xlmr-normal-v2_b128+s64_intrinsic_results_u0.01.json",
".cache/xlmr-normal-v2_b256+s64_intrinsic_results_u0.01.json",
".cache/xlmr-normal-v2_b512+s64_intrinsic_results_u0.01.json",
".cache/xlmr-normal-p-v2_auxp0.3_n0.9_b256_s64_intrinsic_results_u0.01.json",
".cache/xlm-normal-p-v2_auxp0.3_n0.9_b512_s64_intrinsic_results_u0.001.json",
".cache/xlmr-normal-p-v2_n0.9_b512+s64_intrinsic_results_u0.01.json",
".cache/xlmr-normal-p-v2-auxp0.1_b512+s64_intrinsic_results_u0.01.json",
# ".cache/xlmr-normal-p-v2_auxp0.3_n0.9_b64_s32_intrinsic_results_u0.01.json",
# ".cache/xlmr-normal-p-v2_auxp0.3_n0.9_b128_s64_intrinsic_results_u0.01.json",
# ".cache/xlmr-normal-p-v2_auxp0.3_n0.9_b256_s64_intrinsic_results_u0.01.json",
# ".cache/xlm-normal-p-v2_auxp0.3_n0.9_b512_s64_intrinsic_results_u0.01.json",
# ".cache/xlmr-normal-p-v2_n0.9_b256+s64_intrinsic_results_u0.01.json",
# ".cache/xlmr-normal-p-v2_n0.9_b512+s64_intrinsic_results_u0.01.json",
# ".cache/xlmr-normal-p-v2_n0.9_b128+s64_intrinsic_results_u0.01.json",
# ".cache/xlmr-normal-p-v2-auxp0.1_b256+s64_intrinsic_results_u0.01.json",
# ".cache/xlmr-normal-p-v2-auxp0.1_b512+s64_intrinsic_results_u0.01.json",
".cache/xlmr-normal-p-v2_auxp0.2_b512+s64_intrinsic_results_u0.01.json",
".cache/xlmr-normal-p-v2_auxp0.2_b512+s128_intrinsic_results_u0.01.json",
# ".cache/xlmr-normal-p-v2_auxp0.4_b256+s64_intrinsic_results_u0.01.json",
# ".cache/xlmr-normal-p-v2_auxp0.4_b512+s64_intrinsic_results_u0.01.json",
# ".cache/xlmr-normal-v2_b128+s64_intrinsic_results_u0.01.json",
# ".cache/xlmr-normal-v2_b256+s64_intrinsic_results_u0.01.json",
# ".cache/xlmr-normal-v2_b512+s64_intrinsic_results_u0.01.json",
# ".cache/xlmr-normal-9l-v2_b512+s64_intrinsic_results_u0.01.json",
# "evaluation/evaluation_results/wtp-canine-s-3l-no-adapters_intrinsic_results.json",
# "evaluation/evaluation_results/wtp-canine-s-3l_intrinsic_results.json",
]
3 changes: 3 additions & 0 deletions wtpsplit/train/evaluate.py
@@ -65,6 +65,7 @@ def evaluate_sentence(
batch_size,
use_pysbd=False,
positive_index=None,
# do_lowercase=False,
):
if positive_index is None:
positive_index = Constants.NEWLINE_INDEX
@@ -74,6 +75,8 @@

separator = Constants.SEPARATORS[lang_code]
text = separator.join(sentences)
# if do_lowercase:
# text = text.lower()

logits, offsets_mapping, tokenizer = extract(
[text],
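
The commented-out do_lowercase parameter sketches how lowercase evaluation could later be toggled inside evaluate_sentence: lowercase the joined text before it is passed to extract, mirroring process_logits in intrinsic_lowercase.py. A standalone illustration of that behavior (the helper name is hypothetical, not part of the commit):

def join_for_eval(sentences, separator, do_lowercase=False):
    # Hypothetical helper, for illustration only: reproduces the effect the
    # commented-out do_lowercase branch would have on the evaluated text.
    text = separator.join(sentences)
    if do_lowercase:
        text = text.lower()
    return text

print(join_for_eval(["First Sentence.", "SECOND one."], " ", do_lowercase=True))
# -> "first sentence. second one."
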
22 changes: 22 additions & 0 deletions wtpsplit/train/train.py
@@ -551,6 +551,8 @@ def maybe_pad(text):
def compute_metrics(trainer):
metrics = {}
avg_metrics = defaultdict(lambda: [])
# metrics_lower = {}
# avg_metrics_lower = defaultdict(lambda: [])

model = trainer._wrap_model(trainer.model, training=False)

@@ -574,10 +576,30 @@ def compute_metrics(trainer):
avg_metrics[f"average_nonwhitespace_{dataset_name}_pr_auc"].append(score)
else:
avg_metrics[f"average_whitespace_{dataset_name}_pr_auc"].append(score)
# for dataset_name, dataset in lang_data["sentence"].items():
# score, _ = evaluate_sentence(
# lang_code,
# dataset["data"],
# model,
# stride=args.eval_stride,
# block_size=args.block_size,
# batch_size=training_args.per_device_eval_batch_size,
# do_lowercase=True,
# )
# metrics_lower[f"lower_{lang_code}_{dataset_name}_pr_auc"] = score
# avg_metrics_lower[f"lower_average_{dataset_name}_pr_auc"].append(score)
# if lang_code in ["zh", "ja", "my", "km"]:
# avg_metrics_lower[f"lower_average_nonwhitespace_{dataset_name}_pr_auc"].append(score)
# else:
# avg_metrics_lower[f"lower_average_whitespace_{dataset_name}_pr_auc"].append(score)


for name, values in avg_metrics.items():
if len(values) > 1:
metrics[name] = np.mean(values)
# for name, values in avg_metrics_lower.items():
# if len(values) > 1:
# metrics_lower[name] = np.mean(values)

return metrics

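
The commented-out block in compute_metrics mirrors the existing evaluation loop but would score each dataset with do_lowercase=True and aggregate the results under lower_* keys. A self-contained sketch of that aggregation, with hypothetical scores standing in for the evaluate_sentence calls (not part of the commit):

from collections import defaultdict
import numpy as np

metrics_lower = {}
avg_metrics_lower = defaultdict(list)

# Hypothetical per-(language, dataset) lowercase PR-AUC scores; values are made up.
scores = {("en", "ud"): 0.93, ("de", "ud"): 0.91, ("zh", "ud"): 0.88}

for (lang_code, dataset_name), score in scores.items():
    metrics_lower[f"lower_{lang_code}_{dataset_name}_pr_auc"] = score
    avg_metrics_lower[f"lower_average_{dataset_name}_pr_auc"].append(score)
    if lang_code in ["zh", "ja", "my", "km"]:
        avg_metrics_lower[f"lower_average_nonwhitespace_{dataset_name}_pr_auc"].append(score)
    else:
        avg_metrics_lower[f"lower_average_whitespace_{dataset_name}_pr_auc"].append(score)

for name, values in avg_metrics_lower.items():
    if len(values) > 1:
        metrics_lower[name] = np.mean(values)

print(metrics_lower)
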
