From d55d547d44e80918c87837e145e9eecf8ad571c0 Mon Sep 17 00:00:00 2001
From: markus583
Date: Wed, 27 Dec 2023 09:26:41 +0000
Subject: [PATCH] fix collate length?!

---
 wtpsplit/train/evaluate.py |  2 +-
 wtpsplit/train/train.py    | 24 ++++++++++++++++++++----
 2 files changed, 21 insertions(+), 5 deletions(-)

diff --git a/wtpsplit/train/evaluate.py b/wtpsplit/train/evaluate.py
index d9b22b0e..d6475ba3 100644
--- a/wtpsplit/train/evaluate.py
+++ b/wtpsplit/train/evaluate.py
@@ -110,7 +110,7 @@ def evaluate_sentence(
         newline_labels[true_end_indices - 1] = 1
 
     if "xlm" in model.config.model_type:
-        tokens = tokenizer.tokenize(text)
+        tokens = tokenizer.tokenize(text, verbose=False)
         char_probs = token_to_char_probs(text, tokens, logits[:, positive_index], tokenizer, offsets_mapping)
     else:
         char_probs = logits[:, positive_index]
diff --git a/wtpsplit/train/train.py b/wtpsplit/train/train.py
index c4177774..6c25ea9d 100644
--- a/wtpsplit/train/train.py
+++ b/wtpsplit/train/train.py
@@ -10,6 +10,7 @@
 
 import numpy as np
 import torch
+from tqdm.auto import tqdm
 import wandb
 from datasets import load_dataset
 from datasets.download import DownloadConfig
@@ -34,7 +35,7 @@
 
 # TODO: double-check checkpointing and saving (also to txt)
 
-# os.environ["PJRT_DEVICE"] = "None"
+os.environ["PJRT_DEVICE"] = "None"
 
 
 class Model(nn.Module):
@@ -241,7 +242,8 @@ def collate_fn(batch, args, label_args, label_dict, tokenizer):
             # always include SEP
             if input_ids[-1] != tokenizer.sep_token_id:
                 # also insert PAD token as long as len < block_size
-                while len(input_ids) < args.block_size - 1:
+                while len(input_ids) <= args.block_size - 1:
+                    print("first", len(input_ids))
                     input_ids = input_ids + [tokenizer.pad_token_id]
                     labels = labels + [0]
                 input_ids = input_ids + [tokenizer.sep_token_id]
@@ -250,6 +252,16 @@
             start = np.random.randint(0, len(input_ids) - args.block_size)
             input_ids = input_ids[start : start + args.block_size]
             labels = labels[start : start + args.block_size]
+        elif len(input_ids) != args.block_size and args.use_subwords:
+            del input_ids[-1]
+            del labels[-1]
+            while len(input_ids) <= args.block_size - 1:
+                # insert pad token at second-to-last position
+                print("second", len(input_ids))
+                input_ids = input_ids + [tokenizer.pad_token_id]
+                labels = labels + [0]
+            input_ids = input_ids + [tokenizer.sep_token_id]
+            labels = labels + [0]
 
         input_ids = torch.tensor(input_ids[: args.block_size], dtype=torch.long)
         labels = torch.tensor(labels[: args.block_size], dtype=torch.long)
@@ -580,6 +592,10 @@ def maybe_pad(text):
                     num_proc=num_workers,
                     remove_columns=[args.text_column],
                 )
+        else:
+            # this is no longer used and would cause an error otherwise
+            with training_args.main_process_first():
+                dataset = dataset.remove_columns([args.text_column])
 
         if split == "train":
             with training_args.main_process_first():
@@ -596,7 +612,7 @@ def maybe_pad(text):
                     batched=True,
                     num_proc=num_workers,
                     # a bit hacky but oh well, only drop if sentence
-                    remove_columns=["ends_with_punctuation"]  # FIXME: needed for char-based args.text_column dropping
+                    remove_columns=["ends_with_punctuation"]
                     if args.text_column == "text"
                     else [],
                 )
@@ -649,7 +665,7 @@ def compute_metrics(trainer):
 
         model = trainer._wrap_model(trainer.model, training=False)
 
-        for lang_code, lang_data in eval_data.items():  # TODO: tqdm integration
+        for lang_code, lang_data in tqdm(eval_data.items(), desc="Evaluate!"):
            if args.include_languages is not None and lang_code not in args.include_languages:
                continue
 
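
Note: a minimal editorial sketch (not part of the commit) of the length arithmetic behind the
change from < to <= in collate_fn above. It assumes the simplified shape "pad in a loop, then
append a single SEP"; BLOCK_SIZE, PAD and SEP are stand-ins for args.block_size,
tokenizer.pad_token_id and tokenizer.sep_token_id.

BLOCK_SIZE = 8   # stand-in for args.block_size
PAD, SEP = 0, 2  # stand-ins for tokenizer.pad_token_id / tokenizer.sep_token_id

def pad_then_sep(input_ids, strict):
    # Pad while the loop condition holds, then append one SEP token,
    # mirroring the shape of the padding branches in collate_fn.
    ids = list(input_ids)
    while (len(ids) < BLOCK_SIZE - 1) if strict else (len(ids) <= BLOCK_SIZE - 1):
        ids.append(PAD)
    ids.append(SEP)
    return ids

print(len(pad_then_sep([5, 6, 7], strict=True)))   # 8: old "<" condition leaves SEP at index BLOCK_SIZE - 1
print(len(pad_then_sep([5, 6, 7], strict=False)))  # 9: new "<=" condition pads up to BLOCK_SIZE, so SEP sits
                                                   #    one past it before the final [: args.block_size] slice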