Skip to content

Commit

Permalink
fix xlmr 3l eval
Browse files Browse the repository at this point in the history
  • Loading branch information
markus583 committed May 25, 2024
1 parent 1a11667 commit 2b77d40
Show file tree
Hide file tree
Showing 3 changed files with 92 additions and 28 deletions.
70 changes: 70 additions & 0 deletions utils/clean_tweets.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,70 @@
import re
import torch


def remove_emojis_and_special_chars(text):
emoji_pattern = re.compile(
"["
"\U0001f600-\U0001f64f" # emoticons
"\U0001f300-\U0001f5ff" # symbols & pictographs
"\U0001f680-\U0001f6ff" # transport & map symbols
"\U0001f700-\U0001f77f" # alchemical symbols
"\U0001f780-\U0001f7ff" # Geometric Shapes Extended
"\U0001f800-\U0001f8ff" # Supplemental Arrows-C
"\U0001f900-\U0001f9ff" # Supplemental Symbols and Pictographs
"\U0001fa00-\U0001fa6f" # Chess Symbols
"\U0001fa70-\U0001faff" # Symbols and Pictographs Extended-A
"\U00002702-\U000027b0" # Dingbats
"\U000024c2-\U0001f251"
"]+",
flags=re.UNICODE,
)
text = emoji_pattern.sub(r"", text)
text = re.sub(r"[:;=Xx][\-oO\']*[\)\(\[\]DdPp3><\|\\\/]", "", text)
return text


def transform_data(data):
def pair_sentences(sequences):
paired_sequences = []
for sequence in sequences:
processed_sequence = []
for sentence in sequence:
words = sentence.strip().split()
filtered_words = [
remove_emojis_and_special_chars(word)
for word in words
if not (word.startswith("http") or word.startswith("#") or word.startswith("@"))
]
cleaned_sentence = " ".join(filtered_words) # fine for our langs.
if cleaned_sentence and len(cleaned_sentence.split()) > 0:
processed_sequence.append(cleaned_sentence.strip())
if processed_sequence and len(processed_sequence) < 6:
paired_sequences.append(processed_sequence)
return paired_sequences

transformed_data = {}
for lang_code, lang_data in data.items():
if lang_code == "en-de":
continue
transformed_data[lang_code] = {}
for content_type, datasets in lang_data.items():
if content_type != "sentence":
continue
transformed_data[lang_code] = {}
transformed_data[lang_code][content_type] = {}
for dataset_name, content in datasets.items():
if "short" not in dataset_name:
continue
transformed_data[lang_code][content_type][dataset_name] = {
"meta": {"train_data": pair_sentences(content["meta"]["train_data"])},
"data": pair_sentences(content["data"]),
}

return transformed_data


data = torch.load("data/all_data_11_05-all.pth")

transformed_data = transform_data(data)
torch.save(transformed_data, "data/all_data_11_05-short_proc_SM.pth")
28 changes: 16 additions & 12 deletions wtpsplit/evaluation/intrinsic.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import copy
import json
from dataclasses import dataclass
from typing import List
from typing import List, Union
import os
import time
import logging
Expand All @@ -16,6 +16,7 @@
import adapters

import wtpsplit.models # noqa: F401
from wtpsplit.models import SubwordXLMConfig, SubwordXLMForTokenClassification
from wtpsplit.evaluation import evaluate_mixture, get_labels, train_mixture, token_to_char_probs
from wtpsplit.evaluation.intrinsic_baselines import split_language_data
from wtpsplit.extract import PyTorchWrapper, extract
Expand Down Expand Up @@ -62,6 +63,7 @@ class Args:
return_indices: bool = True
exclude_every_k: int = 10
save_suffix: str = ""
num_hidden_layers: Union[int, None] = None


def process_logits(text, model, lang_code, args):
Expand Down Expand Up @@ -178,21 +180,13 @@ def load_or_compute_logits(args, model, eval_data, valid_data=None, save_str: st
continue
if "nllb" in dataset_name:
continue
if "-" in lang_code and "canine" in args.model_path and not "no-adapters" in args.model_path:
if "-" in lang_code and "canine" in args.model_path and "no-adapters" not in args.model_path:
# code-switched data: eval 2x
lang_code = lang_code.split("_")[1].lower()
try:
if args.adapter_path:
if args.clf_from_scratch:
model.model.classifier = torch.nn.Linear(model.model.classifier.in_features, 1)
# elif model.model.classifier.out_features == 2:
elif args.model_path == "xlm-roberta-base" or args.model_path == "xlm-roberta-large":
# we train XLM-R using our wrapper, needs to be adapted for adapters to be loaded
model.model.classifier = torch.nn.Linear(
model.model.classifier.in_features,
1, # FIXME: hardcoded?
)
model.model.__class__.__name__ = "SubwordXLMForTokenClassification"
# if (
# any(code in lang_code for code in ["ceb", "jv", "mn", "yo"])
# and "ted2020" not in dataset_name
Expand Down Expand Up @@ -338,7 +332,7 @@ def main(args):
save_str = f"{save_model_path.replace('/','_')}_b{args.block_size}_s{args.stride}"

eval_data = torch.load(args.eval_data_path)
if "canine" in args.model_path and not "no-adapters" in args.model_path:
if "canine" in args.model_path and "no-adapters" not in args.model_path:
eval_data = split_language_data(eval_data)
if args.valid_text_path is not None:
valid_data = load_dataset("parquet", data_files=args.valid_text_path, split="train")
Expand All @@ -347,7 +341,17 @@ def main(args):

logger.warning("Loading model...")
model_path = args.model_path
model = PyTorchWrapper(AutoModelForTokenClassification.from_pretrained(model_path).to(args.device))
if args.model_path == "xlm-roberta-base" or args.model_path == "xlm-roberta-large":
config = SubwordXLMConfig.from_pretrained(
args.model_path,
num_hidden_layers=args.num_hidden_layers,
num_labels=1,
)
model = PyTorchWrapper(
SubwordXLMForTokenClassification.from_pretrained(model_path, config=config).to(args.device)
)
else:
model = PyTorchWrapper(AutoModelForTokenClassification.from_pretrained(model_path).to(args.device))
if args.adapter_path:
model_type = model.model.config.model_type
# adapters need xlm-roberta as model type.
Expand Down
22 changes: 6 additions & 16 deletions wtpsplit/extract.py
Original file line number Diff line number Diff line change
Expand Up @@ -219,22 +219,12 @@ def extract(

kwargs = {"language_ids": language_ids[: len(batch_attention_mask)]} if uses_lang_adapters else {}

if use_subwords and model.config.model_type == "xlm-roberta":
# TODO: generalize
import torch
with torch.no_grad():
logits = model.model(
input_ids=torch.from_numpy(batch_input_ids).to(model.model.device),
attention_mask=torch.from_numpy(batch_attention_mask).to(model.model.device),
**kwargs,
)["logits"].cpu().numpy()
else:
logits = model(
input_ids=batch_input_ids if use_subwords else None,
hashed_ids=None if use_subwords else batch_input_hashes,
attention_mask=batch_attention_mask,
**kwargs,
)["logits"]
logits = model(
input_ids=batch_input_ids if use_subwords else None,
hashed_ids=None if use_subwords else batch_input_hashes,
attention_mask=batch_attention_mask,
**kwargs,
)["logits"]

if use_subwords:
logits = logits[:, 1:-1, :] # remove CLS and SEP tokens
Expand Down

0 comments on commit 2b77d40

Please sign in to comment.