From 2b77d401542214fb65b3842997c3fffe2250c94a Mon Sep 17 00:00:00 2001 From: markus583 Date: Sat, 25 May 2024 06:00:12 +0000 Subject: [PATCH] fix xlmr 3l eval --- utils/clean_tweets.py | 70 ++++++++++++++++++++++++++++++++ wtpsplit/evaluation/intrinsic.py | 28 +++++++------ wtpsplit/extract.py | 22 +++------- 3 files changed, 92 insertions(+), 28 deletions(-) create mode 100644 utils/clean_tweets.py diff --git a/utils/clean_tweets.py b/utils/clean_tweets.py new file mode 100644 index 00000000..696a07bc --- /dev/null +++ b/utils/clean_tweets.py @@ -0,0 +1,70 @@ +import re +import torch + + +def remove_emojis_and_special_chars(text): + emoji_pattern = re.compile( + "[" + "\U0001f600-\U0001f64f" # emoticons + "\U0001f300-\U0001f5ff" # symbols & pictographs + "\U0001f680-\U0001f6ff" # transport & map symbols + "\U0001f700-\U0001f77f" # alchemical symbols + "\U0001f780-\U0001f7ff" # Geometric Shapes Extended + "\U0001f800-\U0001f8ff" # Supplemental Arrows-C + "\U0001f900-\U0001f9ff" # Supplemental Symbols and Pictographs + "\U0001fa00-\U0001fa6f" # Chess Symbols + "\U0001fa70-\U0001faff" # Symbols and Pictographs Extended-A + "\U00002702-\U000027b0" # Dingbats + "\U000024c2-\U0001f251" + "]+", + flags=re.UNICODE, + ) + text = emoji_pattern.sub(r"", text) + text = re.sub(r"[:;=Xx][\-oO\']*[\)\(\[\]DdPp3><\|\\\/]", "", text) + return text + + +def transform_data(data): + def pair_sentences(sequences): + paired_sequences = [] + for sequence in sequences: + processed_sequence = [] + for sentence in sequence: + words = sentence.strip().split() + filtered_words = [ + remove_emojis_and_special_chars(word) + for word in words + if not (word.startswith("http") or word.startswith("#") or word.startswith("@")) + ] + cleaned_sentence = " ".join(filtered_words) # fine for our langs. + if cleaned_sentence and len(cleaned_sentence.split()) > 0: + processed_sequence.append(cleaned_sentence.strip()) + if processed_sequence and len(processed_sequence) < 6: + paired_sequences.append(processed_sequence) + return paired_sequences + + transformed_data = {} + for lang_code, lang_data in data.items(): + if lang_code == "en-de": + continue + transformed_data[lang_code] = {} + for content_type, datasets in lang_data.items(): + if content_type != "sentence": + continue + transformed_data[lang_code] = {} + transformed_data[lang_code][content_type] = {} + for dataset_name, content in datasets.items(): + if "short" not in dataset_name: + continue + transformed_data[lang_code][content_type][dataset_name] = { + "meta": {"train_data": pair_sentences(content["meta"]["train_data"])}, + "data": pair_sentences(content["data"]), + } + + return transformed_data + + +data = torch.load("data/all_data_11_05-all.pth") + +transformed_data = transform_data(data) +torch.save(transformed_data, "data/all_data_11_05-short_proc_SM.pth") diff --git a/wtpsplit/evaluation/intrinsic.py b/wtpsplit/evaluation/intrinsic.py index 3045cc85..10e132bc 100644 --- a/wtpsplit/evaluation/intrinsic.py +++ b/wtpsplit/evaluation/intrinsic.py @@ -1,7 +1,7 @@ import copy import json from dataclasses import dataclass -from typing import List +from typing import List, Union import os import time import logging @@ -16,6 +16,7 @@ import adapters import wtpsplit.models # noqa: F401 +from wtpsplit.models import SubwordXLMConfig, SubwordXLMForTokenClassification from wtpsplit.evaluation import evaluate_mixture, get_labels, train_mixture, token_to_char_probs from wtpsplit.evaluation.intrinsic_baselines import split_language_data from wtpsplit.extract import PyTorchWrapper, extract @@ -62,6 +63,7 @@ class Args: return_indices: bool = True exclude_every_k: int = 10 save_suffix: str = "" + num_hidden_layers: Union[int, None] = None def process_logits(text, model, lang_code, args): @@ -178,21 +180,13 @@ def load_or_compute_logits(args, model, eval_data, valid_data=None, save_str: st continue if "nllb" in dataset_name: continue - if "-" in lang_code and "canine" in args.model_path and not "no-adapters" in args.model_path: + if "-" in lang_code and "canine" in args.model_path and "no-adapters" not in args.model_path: # code-switched data: eval 2x lang_code = lang_code.split("_")[1].lower() try: if args.adapter_path: if args.clf_from_scratch: model.model.classifier = torch.nn.Linear(model.model.classifier.in_features, 1) - # elif model.model.classifier.out_features == 2: - elif args.model_path == "xlm-roberta-base" or args.model_path == "xlm-roberta-large": - # we train XLM-R using our wrapper, needs to be adapted for adapters to be loaded - model.model.classifier = torch.nn.Linear( - model.model.classifier.in_features, - 1, # FIXME: hardcoded? - ) - model.model.__class__.__name__ = "SubwordXLMForTokenClassification" # if ( # any(code in lang_code for code in ["ceb", "jv", "mn", "yo"]) # and "ted2020" not in dataset_name @@ -338,7 +332,7 @@ def main(args): save_str = f"{save_model_path.replace('/','_')}_b{args.block_size}_s{args.stride}" eval_data = torch.load(args.eval_data_path) - if "canine" in args.model_path and not "no-adapters" in args.model_path: + if "canine" in args.model_path and "no-adapters" not in args.model_path: eval_data = split_language_data(eval_data) if args.valid_text_path is not None: valid_data = load_dataset("parquet", data_files=args.valid_text_path, split="train") @@ -347,7 +341,17 @@ def main(args): logger.warning("Loading model...") model_path = args.model_path - model = PyTorchWrapper(AutoModelForTokenClassification.from_pretrained(model_path).to(args.device)) + if args.model_path == "xlm-roberta-base" or args.model_path == "xlm-roberta-large": + config = SubwordXLMConfig.from_pretrained( + args.model_path, + num_hidden_layers=args.num_hidden_layers, + num_labels=1, + ) + model = PyTorchWrapper( + SubwordXLMForTokenClassification.from_pretrained(model_path, config=config).to(args.device) + ) + else: + model = PyTorchWrapper(AutoModelForTokenClassification.from_pretrained(model_path).to(args.device)) if args.adapter_path: model_type = model.model.config.model_type # adapters need xlm-roberta as model type. diff --git a/wtpsplit/extract.py b/wtpsplit/extract.py index e1bbd853..f6dd39a9 100644 --- a/wtpsplit/extract.py +++ b/wtpsplit/extract.py @@ -219,22 +219,12 @@ def extract( kwargs = {"language_ids": language_ids[: len(batch_attention_mask)]} if uses_lang_adapters else {} - if use_subwords and model.config.model_type == "xlm-roberta": - # TODO: generalize - import torch - with torch.no_grad(): - logits = model.model( - input_ids=torch.from_numpy(batch_input_ids).to(model.model.device), - attention_mask=torch.from_numpy(batch_attention_mask).to(model.model.device), - **kwargs, - )["logits"].cpu().numpy() - else: - logits = model( - input_ids=batch_input_ids if use_subwords else None, - hashed_ids=None if use_subwords else batch_input_hashes, - attention_mask=batch_attention_mask, - **kwargs, - )["logits"] + logits = model( + input_ids=batch_input_ids if use_subwords else None, + hashed_ids=None if use_subwords else batch_input_hashes, + attention_mask=batch_attention_mask, + **kwargs, + )["logits"] if use_subwords: logits = logits[:, 1:-1, :] # remove CLS and SEP tokens