diff --git a/scripts/export_to_onnx_charbert.py b/scripts/export_to_onnx_charbert.py index 1a94dc75..fbaa00b2 100644 --- a/scripts/export_to_onnx_charbert.py +++ b/scripts/export_to_onnx_charbert.py @@ -9,6 +9,7 @@ import wtpsplit # noqa import wtpsplit.models # noqa + @dataclass class Args: model_name_or_path: str = "benjamin/wtp-bert-mini" diff --git a/setup.py b/setup.py index 49514e9b..817d5361 100644 --- a/setup.py +++ b/setup.py @@ -17,7 +17,7 @@ "pandas>=1", "cached_property", # for Py37 "mosestokenizer", - "adapters" + "adapters", ], url="https://github.com/segment-any-text/wtpsplit", package_data={"wtpsplit": ["data/*"]}, diff --git a/wtpsplit/__init__.py b/wtpsplit/__init__.py index 2e291698..308566cc 100644 --- a/wtpsplit/__init__.py +++ b/wtpsplit/__init__.py @@ -23,6 +23,7 @@ warnings.simplefilter("default", DeprecationWarning) # show by default warnings.simplefilter("ignore", category=FutureWarning) # for tranformers + class WtP: def __init__( self, @@ -435,9 +436,7 @@ def __init__( self.use_lora = False - self.tokenizer = AutoTokenizer.from_pretrained( - "facebookAI/xlm-roberta-base" - ) + self.tokenizer = AutoTokenizer.from_pretrained("facebookAI/xlm-roberta-base") if isinstance(model_name_or_model, (str, Path)): model_name = str(model_name_or_model) @@ -498,12 +497,14 @@ def __init__( ) # LoRA LOADING # TODO: LoRA + ONNX ? - if (style_or_domain and not language) or (language and not style_or_domain): - raise ValueError("Please specify both language and style_or_domain!") - if style_or_domain and language: + if not lora_path: + if (style_or_domain and not language) or (language and not style_or_domain): + raise ValueError("Please specify both language and style_or_domain!") + if (style_or_domain and language) or lora_path: import adapters # noqa from adapters.models import MODEL_MIXIN_MAPPING # noqa from adapters.models.bert.mixin_bert import BertModelAdaptersMixin # noqa + # monkey patch mixin to avoid forking whole adapters library MODEL_MIXIN_MAPPING["SubwordXLMRobertaModel"] = BertModelAdaptersMixin model_type = self.model.model.config.model_type @@ -642,7 +643,7 @@ def newline_probability_fn(logits): batch_size=batch_size, pad_last_batch=pad_last_batch, verbose=verbose, - tokenizer=self.tokenizer + tokenizer=self.tokenizer, ) # convert token probabilities to character probabilities for the entire array diff --git a/wtpsplit/evaluation/adapt.py b/wtpsplit/evaluation/adapt.py index 1e5da05e..57c6f688 100644 --- a/wtpsplit/evaluation/adapt.py +++ b/wtpsplit/evaluation/adapt.py @@ -182,7 +182,7 @@ def load_or_compute_logits(args, model, eval_data, valid_data=None, save_str: st if args.adapter_path: if args.clf_from_scratch: model.model.classifier = torch.nn.Linear(model.model.classifier.in_features, 1) - + dataset_load_name = dataset_name model.model.load_adapter( args.adapter_path + "/" + dataset_load_name + "/" + lang_code, @@ -354,7 +354,7 @@ def main(args): clfs = {} if args.return_indices: indices = {} - + u_scores, t_scores, punct_scores = [], [], [] for lang_code, dsets in tqdm(eval_data.items()): diff --git a/wtpsplit/evaluation/evaluate_sepp_nlg_subtask1.py b/wtpsplit/evaluation/evaluate_sepp_nlg_subtask1.py index baca8648..aa7fec51 100644 --- a/wtpsplit/evaluation/evaluate_sepp_nlg_subtask1.py +++ b/wtpsplit/evaluation/evaluate_sepp_nlg_subtask1.py @@ -12,7 +12,7 @@ def evaluate_subtask1(splits, langs, prediction_dir: str, supervisions, include_ Mirrors the original SEPP-NLG 2021 Shared Task evaluation function https://sites.google.com/view/sentence-segmentation 
""" - + results = {} avg_holder = {} for supervision in supervisions: diff --git a/wtpsplit/evaluation/intrinsic_baselines_multilingual.py b/wtpsplit/evaluation/intrinsic_baselines_multilingual.py index e1091ded..8459974f 100644 --- a/wtpsplit/evaluation/intrinsic_baselines_multilingual.py +++ b/wtpsplit/evaluation/intrinsic_baselines_multilingual.py @@ -61,7 +61,7 @@ class Args: continue results[lang][dataset_name] = {} indices[lang][dataset_name] = {} - + if "-" in lang: # code-switched data: eval 2x lang_code = lang.split("_")[1].lower() diff --git a/wtpsplit/evaluation/intrinsic_pairwise.py b/wtpsplit/evaluation/intrinsic_pairwise.py index 46a92a24..2e747175 100644 --- a/wtpsplit/evaluation/intrinsic_pairwise.py +++ b/wtpsplit/evaluation/intrinsic_pairwise.py @@ -70,7 +70,6 @@ class Args: min_k_mer_length: int = 0 - def process_logits_k_mers(pairs, model, lang_code, block_size, batch_size, verbose=True) -> List[np.ndarray]: logits_list = [] n_tokens_list = [] diff --git a/wtpsplit/evaluation/legal_baselines.py b/wtpsplit/evaluation/legal_baselines.py index d1d9382d..88141d88 100644 --- a/wtpsplit/evaluation/legal_baselines.py +++ b/wtpsplit/evaluation/legal_baselines.py @@ -131,7 +131,7 @@ def load_or_compute_logits(args, eval_data, save_str: str = None): start_time = time.time() test_logits = get_law_preds(test_text, model, current_name, args) - end_time = time.time() + end_time = time.time() total_test_time += end_time - start_time if isinstance(test_sentences[0], list): test_logit_lengths = [] diff --git a/wtpsplit/evaluation/llm_sentence.py b/wtpsplit/evaluation/llm_sentence.py index 64f9cbe0..71c68194 100644 --- a/wtpsplit/evaluation/llm_sentence.py +++ b/wtpsplit/evaluation/llm_sentence.py @@ -246,11 +246,7 @@ def load_or_compute_logits(args, eval_data, save_str: str = None): if isinstance(test_sentences[0], list) or args.type == "pairs": if args.type == "pairs": all_pairs = generate_k_mers( - test_sentences, - k=2, - do_lowercase=False, - do_remove_punct=False, - sample_pct=0.5 + test_sentences, k=2, do_lowercase=False, do_remove_punct=False, sample_pct=0.5 ) test_sentences = all_pairs # list of lists: chunk each sublist @@ -568,10 +564,8 @@ def main(args): eval_data = torch.load(eval_data_path) - save_str = ( - f"{args.model.split('/')[-1]}_k{args.k}_s{args.n_shots}" - ).replace("/", "_") - + save_str = (f"{args.model.split('/')[-1]}_k{args.k}_s{args.n_shots}").replace("/", "_") + if args.max_n_test_sentences < sys.maxsize and args.max_n_test_sentences != -1: save_str += f"_n{args.max_n_test_sentences}" if args.max_n_test_sentences == -1: diff --git a/wtpsplit/evaluation/stat_tests/permutation_test_data.py b/wtpsplit/evaluation/stat_tests/permutation_test_data.py index c05c239d..58b52d12 100644 --- a/wtpsplit/evaluation/stat_tests/permutation_test_data.py +++ b/wtpsplit/evaluation/stat_tests/permutation_test_data.py @@ -63,7 +63,6 @@ if isinstance(data_list[0], int): data_list = [data_list] - raw_data[lang][dataset][model + "-" + model_type] = data_list true_indices = data[lang][dataset][model_type]["true_indices"] diff --git a/wtpsplit/extract_batched.py b/wtpsplit/extract_batched.py index 53020b43..ead162aa 100644 --- a/wtpsplit/extract_batched.py +++ b/wtpsplit/extract_batched.py @@ -110,12 +110,17 @@ def extract_batched( if use_subwords and model.config.model_type == "xlm-roberta": # TODO: generalize import torch + with torch.no_grad(): - logits = model.model( - input_ids=torch.from_numpy(input_ids).to(model.model.device), - 
attention_mask=torch.from_numpy(attention_mask).to(model.model.device), - **kwargs, - )["logits"].cpu().numpy() + logits = ( + model.model( + input_ids=torch.from_numpy(input_ids).to(model.model.device), + attention_mask=torch.from_numpy(attention_mask).to(model.model.device), + **kwargs, + )["logits"] + .cpu() + .numpy() + ) else: logits = model( input_ids=input_ids if use_subwords else None, diff --git a/wtpsplit/models.py b/wtpsplit/models.py index 90580a99..b68d0e85 100644 --- a/wtpsplit/models.py +++ b/wtpsplit/models.py @@ -1231,11 +1231,11 @@ def get_extended_attention_mask( """ # if not (attention_mask.dim() == 2 and config.is_decoder): - # show warning only if it won't be shown in `create_extended_attention_mask_for_decoder` - # if device is not None: - # warnings.warn( - # "The `device` argument is deprecated and will be removed in v5 of Transformers.", FutureWarning - # ) + # show warning only if it won't be shown in `create_extended_attention_mask_for_decoder` + # if device is not None: + # warnings.warn( + # "The `device` argument is deprecated and will be removed in v5 of Transformers.", FutureWarning + # ) # We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length] # ourselves in which case we just need to make it broadcastable to all heads. if attention_mask.dim() == 3: diff --git a/wtpsplit/train/adaptertrainer.py b/wtpsplit/train/adaptertrainer.py index e9f584e1..b506f4fd 100644 --- a/wtpsplit/train/adaptertrainer.py +++ b/wtpsplit/train/adaptertrainer.py @@ -78,6 +78,7 @@ TRAINING_ARGS_NAME = "training_args.bin" + class AdapterTrainer(Trainer): def __init__( self, @@ -484,11 +485,13 @@ def evaluation_loop( if all_inputs is not None: all_inputs = nested_truncate(all_inputs, num_samples) else: - xm.rendezvous("eval_metrics") + if is_torch_tpu_available(): + xm.rendezvous("eval_metrics") all_losses, all_preds, all_labels, all_inputs, num_samples = None, None, None, None, 0 # Metrics! 
- xm.rendezvous("eval_metrics") + if is_torch_tpu_available(): + xm.rendezvous("eval_metrics") # MODIFIED: always compute metrics if self.compute_metrics is not None: metrics = self.compute_metrics(self) diff --git a/wtpsplit/train/evaluate.py b/wtpsplit/train/evaluate.py index 6ae6dacc..9299f369 100644 --- a/wtpsplit/train/evaluate.py +++ b/wtpsplit/train/evaluate.py @@ -5,7 +5,7 @@ import pysbd import sklearn.metrics -from wtpsplit.evaluation.intrinsic_pairwise import generate_pairs, generate_k_mers, process_logits_k_mers +from wtpsplit.evaluation.intrinsic_pairwise import generate_k_mers, process_logits_k_mers from wtpsplit.extract import PyTorchWrapper, extract from wtpsplit.utils import Constants, sigmoid, corrupt, token_to_char_probs @@ -153,13 +153,12 @@ def evaluate_sentence_pairwise( accuracy_list = [] # get pairs of sentences (non-overlapping) - sampled_pairs = generate_pairs( + sampled_pairs = generate_k_mers( sentences=sentences, - pair_sample_pct=pair_sample_pct, - max_n_pairs=max_pairs, - min_pair_length=0, - do_lowercase=do_lowercase, - do_remove_punct=do_remove_punct, + k=2, + sample_pct=pair_sample_pct, + max_n_samples=max_pairs, + min_k_mer_length=0, ) # get logits for each pair diff --git a/wtpsplit/train/train.py b/wtpsplit/train/train.py index 6be21cfd..51c7ffeb 100644 --- a/wtpsplit/train/train.py +++ b/wtpsplit/train/train.py @@ -4,6 +4,7 @@ import random import shutil import sys + # import time from collections import Counter, defaultdict from dataclasses import dataclass @@ -14,14 +15,15 @@ import datasets import numpy as np import torch -import torch_xla.core.xla_model as xm import transformers from datasets import load_dataset + # from datasets.download import DownloadConfig from tokenizers import AddedToken from torchinfo import summary from tqdm.auto import tqdm from transformers import AutoTokenizer, HfArgumentParser, TrainingArguments, set_seed +from transformers.trainer import is_torch_tpu_available import wandb from wtpsplit.models import ( @@ -35,6 +37,7 @@ from wtpsplit.train.evaluate import evaluate_sentence from wtpsplit.train.trainer import Trainer from wtpsplit.train.utils import Model + # from wtpsplit.train.utils import cleanup_cache_files from wtpsplit.utils import Constants, LabelArgs, corrupt_training, get_label_dict, get_subword_label_dict @@ -197,12 +200,22 @@ def main(): else: (args, training_args, label_args) = parser.parse_args_into_dataclasses() wandb_name = None - if xm.xrt_world_size() == 4: - # ensure same batch size on TPUv3 and TPUv4 using same config.json - training_args.per_device_train_batch_size *= 2 + + if is_torch_tpu_available(): + import torch_xla.core.xla_model as xm + + world_size = xm.xrt_world_size() + if world_size == 4: + # ensure same batch size on TPUv3 and TPUv4 using same config.json + training_args.per_device_train_batch_size *= 2 + elif torch.cuda.is_available(): + world_size = torch.cuda.device_count() + else: + world_size = 1 + logger.warning(f"Per device train batch size: {training_args.per_device_train_batch_size}") logger.warning( - f"Total train batch size: {training_args.per_device_train_batch_size * training_args.gradient_accumulation_steps* xm.xrt_world_size()}" + f"Total train batch size: {training_args.per_device_train_batch_size * training_args.gradient_accumulation_steps * world_size}" ) setup_logging(training_args) diff --git a/wtpsplit/train/train_SM.py b/wtpsplit/train/train_SM.py index 87cfefa7..e5bc38ed 100644 --- a/wtpsplit/train/train_SM.py +++ b/wtpsplit/train/train_SM.py @@ -22,10 +22,11 @@ 
@dataclass class Args: block_size: int = 256 - num_layers: int = 12 # number of layers - lim_lookahead: bool = False # our "Lookahead" ablation - without_pretraining: bool = False # our "No pre-training" ablation - no_sm_corruption: bool = False # our "Only clean text" ablation + num_layers: int = 12 # number of layers + lim_lookahead: bool = False # our "Lookahead" ablation + without_pretraining: bool = False # our "No pre-training" ablation + no_sm_corruption: bool = False # our "Only clean text" ablation + # Parsing command line arguments or JSON config files as needed parser = HfArgumentParser([Args, TrainingArguments]) @@ -55,7 +56,7 @@ class Args: and all_data[lang_code]["sentence"]["ud"]["meta"]["train_data"] is not None ): train_data = all_data[lang_code]["sentence"]["ud"]["meta"]["train_data"] - + if len(train_data) < 10000: # some languages have an insufficient number of sentences to fill a single batch # this is just a quick way to upsample these so we don't run into problems later @@ -155,7 +156,7 @@ class Args: assert isinstance(tokenizer, transformers.PreTrainedTokenizerFast) if args.num_layers == 3 and args.without_pretraining: - # special case for one of our ablations, where we trim XLM-R (without any of our newline pretraining) to 3 layers + # special case for one of our ablations, where we trim XLM-R (without any of our newline pretraining) to 3 layers model = SubwordXLMForTokenClassification.from_pretrained( model_checkpoint, num_labels=1, diff --git a/wtpsplit/train/train_lora.py b/wtpsplit/train/train_lora.py index 06b399f0..b0e9154c 100644 --- a/wtpsplit/train/train_lora.py +++ b/wtpsplit/train/train_lora.py @@ -573,7 +573,7 @@ def compute_metrics(trainer): ) trainer.train(resume_from_checkpoint=training_args.resume_from_checkpoint) logger.warning(f"Finished training for {lang} {dataset_name}.") - + # only save trained module if training_args.local_rank == 0: if not os.path.exists(os.path.join(training_args.output_dir, dataset_name, lang)): diff --git a/wtpsplit/utils/download_spacy.py b/wtpsplit/utils/download_spacy.py index fc72ade6..bd1d1956 100644 --- a/wtpsplit/utils/download_spacy.py +++ b/wtpsplit/utils/download_spacy.py @@ -27,9 +27,11 @@ "multi": "xx_sent_ud_sm" } + def download_models(): for lang, model in SPACY_LANG_TO_DP_MODEL.items(): subprocess.run(["python3", "-m", "spacy", "download", model]) + if __name__ == "__main__": download_models()
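
The most behavior-relevant hunk above is the LoRA gating change in wtpsplit/__init__.py: a locally supplied lora_path now bypasses the requirement to pass both language and style_or_domain, while the original validation is kept for the hub-lookup path. Below is a minimal standalone sketch of that predicate; the helper name should_load_lora is hypothetical and only the three arguments visible in the hunk are modeled.

def should_load_lora(style_or_domain, language, lora_path):
    # Mirrors the condition introduced in wtpsplit/__init__.py:
    # without a local lora_path, language and style_or_domain must be
    # given together; with a lora_path, neither is required.
    if not lora_path:
        if (style_or_domain and not language) or (language and not style_or_domain):
            raise ValueError("Please specify both language and style_or_domain!")
    return bool((style_or_domain and language) or lora_path)

Under this check, should_load_lora(None, None, "my_adapter/") returns True, which is the shortcut the hunk enables, while should_load_lora("ud", None, None) still raises as before.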
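
Similarly, the wtpsplit/train/train.py hunk replaces the unconditional torch_xla import with a TPU-then-CUDA-then-single-process fallback when computing the effective world size. A runnable sketch of that detection order is shown below, assuming torch and transformers are installed; the import path for is_torch_tpu_available mirrors the patch, and the function name detect_world_size is made up for illustration.

import torch
from transformers.trainer import is_torch_tpu_available


def detect_world_size(per_device_train_batch_size):
    # Same fallback order as the patched wtpsplit/train/train.py:
    # TPU cores first, then CUDA device count, else a single process.
    if is_torch_tpu_available():
        import torch_xla.core.xla_model as xm  # imported lazily, only on TPU hosts

        world_size = xm.xrt_world_size()
        if world_size == 4:
            # keep the effective batch size identical on TPUv3 and TPUv4
            # when both runs share the same config.json
            per_device_train_batch_size *= 2
    elif torch.cuda.is_available():
        world_size = torch.cuda.device_count()
    else:
        world_size = 1
    return world_size, per_device_train_batch_size

The total train batch size logged afterwards is then per_device_train_batch_size * gradient_accumulation_steps * world_size, matching the updated logging call in the hunk.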