From 5b3733139e89d7522cda5c6657577915f43e570e Mon Sep 17 00:00:00 2001 From: markus583 Date: Sat, 23 Dec 2023 11:04:06 +0000 Subject: [PATCH] fix tokenization?, eval during training --- configs/xlmr_stratify_0.1_3layers.json | 4 +- wtpsplit/configs.py | 2 + wtpsplit/evaluation/intrinsic.py | 7 +- wtpsplit/extract.py | 92 +++++++++++++++------- wtpsplit/models.py | 7 +- wtpsplit/train/evaluate.py | 45 ++++++++++- wtpsplit/train/train.py | 101 ++++++++++++++++++++----- wtpsplit/utils.py | 27 ++++++- 8 files changed, 224 insertions(+), 61 deletions(-) diff --git a/configs/xlmr_stratify_0.1_3layers.json b/configs/xlmr_stratify_0.1_3layers.json index cbece18f..fc40f4c0 100644 --- a/configs/xlmr_stratify_0.1_3layers.json +++ b/configs/xlmr_stratify_0.1_3layers.json @@ -12,14 +12,14 @@ "per_device_eval_batch_size": 32, "gradient_accumulation_steps": 1, "eval_accumulation_steps": 8, - "dataloader_num_workers": 32, + "dataloader_num_workers": 4, "preprocessing_num_workers": 32, "learning_rate": 1e-4, "save_strategy": "steps", "fp16": false, "max_steps": 2000000, "save_steps": 100000, - "eval_steps": 50000000000, + "eval_steps": 5000, "logging_steps": 50, "report_to": "wandb", "is_decoder": false, diff --git a/wtpsplit/configs.py b/wtpsplit/configs.py index 656c5a38..6b44d512 100644 --- a/wtpsplit/configs.py +++ b/wtpsplit/configs.py @@ -45,12 +45,14 @@ class SubwordXLMConfig(XLMRobertaConfig): XLMRobertaConfig: Base class. """ model_type = "xlm-token" + mixture_name = "xlm-token" def __init__( self, **kwargs, ): super().__init__(**kwargs) + self.mixture_name = "xlm-token" AutoConfig.register("bert-char", BertCharConfig) diff --git a/wtpsplit/evaluation/intrinsic.py b/wtpsplit/evaluation/intrinsic.py index f4a6aab1..3f9ce8c5 100644 --- a/wtpsplit/evaluation/intrinsic.py +++ b/wtpsplit/evaluation/intrinsic.py @@ -34,7 +34,7 @@ class Args: # } eval_data_path: str = "data/eval.pth" valid_text_path: str = None#"data/sentence/valid.parquet" - device: str = "xla:1" + device: str = "cpu" block_size: int = 512 stride: int = 64 batch_size: int = 32 @@ -44,7 +44,8 @@ class Args: def load_or_compute_logits(args, model, eval_data, valid_data=None, max_n_train_sentences=10_000): logits_path = Constants.CACHE_DIR / (model.config.mixture_name + "_logits.h5") - with h5py.File(logits_path, "a") as f, torch.no_grad(): + # TODO: revert to "a" + with h5py.File(logits_path, "w") as f, torch.no_grad(): for lang_code in Constants.LANGINFO.index: if args.include_langs is not None and lang_code not in args.include_langs: continue @@ -152,6 +153,7 @@ def load_or_compute_logits(args, model, eval_data, valid_data=None, max_n_train_ if "train_logits" in f[lang_code][dataset_name]: feature_indices = None + # TODO: tokenize here clf = train_mixture( [lang_code], f[lang_code][dataset_name]["train_logits"][:], @@ -159,6 +161,7 @@ def load_or_compute_logits(args, model, eval_data, valid_data=None, max_n_train_ features=feature_indices, ) + # TODO: tokenize here, too score_t, score_punct, _ = evaluate_mixture( lang_code, f[lang_code][dataset_name]["test_logits"][:], diff --git a/wtpsplit/extract.py b/wtpsplit/extract.py index 15339b2a..5ee39ede 100644 --- a/wtpsplit/extract.py +++ b/wtpsplit/extract.py @@ -3,6 +3,8 @@ import numpy as np from tqdm.auto import tqdm +from transformers import AutoTokenizer +from tokenizers import AddedToken from wtpsplit.utils import Constants, hash_encode @@ -37,7 +39,7 @@ def __getattr__(self, name): assert hasattr(self, "model") return getattr(self.model, name) - def __call__(self, hashed_ids, attention_mask, language_ids=None): + def __call__(self, input_ids, hashed_ids, attention_mask, language_ids=None): try: import torch except ImportError: @@ -46,7 +48,8 @@ def __call__(self, hashed_ids, attention_mask, language_ids=None): with torch.no_grad(): logits = ( self.model( - hashed_ids=torch.from_numpy(hashed_ids).to(self.model.device), + input_ids=torch.from_numpy(input_ids).to(self.model.device) if input_ids is not None else None, + hashed_ids=torch.from_numpy(hashed_ids).to(self.model.device) if hashed_ids is not None else None, attention_mask=torch.from_numpy(attention_mask).to(self.model.device), language_ids=torch.from_numpy(language_ids).to(self.model.device) if language_ids is not None @@ -76,6 +79,20 @@ def extract( ad 1.: text is sliced into partially overlapping chunks by moving forward by a `stride` parameter (think conv1d). """ + if "xlm" in model.config.model_type: + use_subwords = True + tokenizer = AutoTokenizer.from_pretrained(model.config.base_model) + tokenizer.add_special_tokens({"additional_special_tokens": [AddedToken("\n")]}) + tokens = tokenizer(batch_of_texts, return_offsets_mapping=True) + # remove CLS and SEP tokens, they are added later anyhow + batch_of_texts = [text[1:-1] for text in tokens["input_ids"]] + offset_mapping = [offset[1:-1] for offset in tokens["offset_mapping"]] + cls_token_id = tokenizer.cls_token_id + sep_token_id = tokenizer.sep_token_id + pad_token_id = tokenizer.pad_token_id + else: + pad_token_id = 0 + use_subwords = False text_lengths = [len(text) for text in batch_of_texts] # reduce block size if possible @@ -84,44 +101,56 @@ def extract( # make sure block_size is a multiple of downsampling rate downsampling_rate = getattr(model.config, "downsampling_rate", 1) block_size = math.ceil(block_size / downsampling_rate) * downsampling_rate + actual_block_size = block_size - 2 if use_subwords else block_size # account for CLS and SEP tokens # total number of forward passes - num_chunks = sum(math.ceil(max(length - block_size, 0) / stride) + 1 for length in text_lengths) + num_chunks = sum(math.ceil(max(length - actual_block_size, 0) / stride) + 1 for length in text_lengths) # preallocate a buffer for all input hashes & attention masks - input_hashes = np.zeros((num_chunks, block_size, model.config.num_hash_functions), dtype=np.int64) + if not use_subwords: + input_hashes = np.zeros((num_chunks, block_size, model.config.num_hash_functions), dtype=np.int64) + else: + input_ids = np.zeros((num_chunks, block_size), dtype=np.int64) attention_mask = np.zeros((num_chunks, block_size), dtype=np.float32) # locs keep track of the location of every chunk with a 3-tuple (text_idx, char_start, char_end) that indexes # back into the batch_of_texts locs = np.zeros((num_chunks, 3), dtype=np.int32) - # this is equivalent to (but faster than) np.array([ord(c) for c in "".join(batch_of_texts)]) - codec = "utf-32-le" if sys.byteorder == "little" else "utf-32-be" - ordinals = np.frombuffer(bytearray("".join(batch_of_texts), encoding=codec), dtype=np.int32) - - # hash encode all ids - flat_hashed_ids = hash_encode(ordinals, - num_hashes=model.config.num_hash_functions, - num_buckets=model.config.num_hash_buckets) + if not use_subwords: + # this is equivalent to (but faster than) np.array([ord(c) for c in "".join(batch_of_texts)]) + codec = "utf-32-le" if sys.byteorder == "little" else "utf-32-be" + ordinals = np.frombuffer(bytearray("".join(batch_of_texts), encoding=codec), dtype=np.int32) + # hash encode all ids + flat_hashed_ids = hash_encode(ordinals, + num_hashes=model.config.num_hash_functions, + num_buckets=model.config.num_hash_buckets) + # note that ordinals and flat_hashed_ids have the same length offset = 0 current_chunk = 0 - + + + # create chunks for i in range(len(batch_of_texts)): for j in range(0, text_lengths[i], stride): # for every chunk, assign input hashes, attention mask and loc - start, end = j, j + block_size + start, end = j, j + actual_block_size done = False if end >= text_lengths[i]: end = text_lengths[i] - start = max(end - block_size, 0) + start = max(end - actual_block_size, 0) done = True - input_hashes[current_chunk, : end - start] = flat_hashed_ids[offset + start : offset + end] - attention_mask[current_chunk, : end - start] = 1 + if not use_subwords: + input_hashes[current_chunk, : end - start] = flat_hashed_ids[offset + start : offset + end] + attention_mask[current_chunk, : end - start] = 1 + else: + chunk = [cls_token_id] + batch_of_texts[i][start:end] + [sep_token_id] + input_ids[current_chunk, :len(chunk)] = chunk + attention_mask[current_chunk, :len(chunk)] = 1 + locs[current_chunk, :] = [i, start, end] - current_chunk += 1 if done: @@ -130,7 +159,7 @@ def extract( offset += text_lengths[i] assert current_chunk == num_chunks - n_batches = math.ceil(len(input_hashes) / batch_size) + n_batches = math.ceil(len(attention_mask) / batch_size) # containers for the final logits all_logits = [ @@ -163,21 +192,30 @@ def extract( # forward passes through all chunks for batch_idx in tqdm(range(n_batches), disable=not verbose): - start, end = batch_idx * batch_size, min(len(input_hashes), (batch_idx + 1) * batch_size) + start, end = batch_idx * batch_size, min(len(attention_mask), (batch_idx + 1) * batch_size) - batch_input_hashes = input_hashes[start:end] + if not use_subwords: + batch_input_hashes = input_hashes[start:end] + else: + batch_input_ids = input_ids[start:end] batch_attention_mask = attention_mask[start:end] - if len(batch_input_hashes) < batch_size and pad_last_batch: - n_missing = batch_size - len(batch_input_hashes) + if len(batch_attention_mask) < batch_size and pad_last_batch: + n_missing = batch_size - len(batch_attention_mask) - batch_input_hashes = np.pad(batch_input_hashes, ((0, n_missing), (0, 0), (0, 0))) + if not use_subwords: + batch_input_hashes = np.pad(batch_input_hashes, ((0, n_missing), (0, 0), (0, 0))) + else: + # Pad with the specific pad_token_id for the tokenizer + batch_input_ids = np.pad(batch_input_ids, ((0, n_missing), (0, 0)), constant_values=pad_token_id) batch_attention_mask = np.pad(batch_attention_mask, ((0, n_missing), (0, 0))) + - kwargs = {"language_ids": language_ids[: len(batch_input_hashes)]} if uses_lang_adapters else {} + kwargs = {"language_ids": language_ids[: len(batch_attention_mask)]} if uses_lang_adapters else {} logits = model( - hashed_ids=batch_input_hashes, + input_ids=batch_input_ids if use_subwords else None, + hashed_ids=batch_input_hashes if not use_subwords else None, attention_mask=batch_attention_mask, **kwargs, )["logits"] @@ -190,4 +228,4 @@ def extract( # so far, logits are summed, so we average them here all_logits = [(logits / counts[:, None]).astype(np.float16) for logits, counts in zip(all_logits, all_counts)] - return all_logits + return all_logits, offset_mapping if use_subwords else None, tokenizer if use_subwords else None diff --git a/wtpsplit/models.py b/wtpsplit/models.py index cfbaa57f..ee731b21 100644 --- a/wtpsplit/models.py +++ b/wtpsplit/models.py @@ -989,8 +989,9 @@ def forward( labels: Optional[torch.LongTensor] = None, output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, - return_dict: Optional[bool] = None, + hashed_ids: Optional[torch.Tensor] = None, language_ids=None, + return_dict: Optional[bool] = None, ) -> Union[Tuple[torch.Tensor], TokenClassifierOutput]: return super().forward( input_ids, @@ -1030,7 +1031,7 @@ def forward( text = "This is a test\n sentence \n\n" tokenizer = AutoTokenizer.from_pretrained(model_str) - tokens = tokenizer(text, return_tensors="pt") + tokens = tokenizer(text, return_tensors="pt", add_special_tokens=False) from tokenizers import AddedToken tokenizer.add_special_tokens({"additional_special_tokens": [AddedToken("\n")]}) print(tokenizer.tokenize(text)) @@ -1038,4 +1039,4 @@ def forward( print(tokens) # forward pass print(backbone(**tokens)) - + \ No newline at end of file diff --git a/wtpsplit/train/evaluate.py b/wtpsplit/train/evaluate.py index 4fd6bbe4..0f2e07cd 100644 --- a/wtpsplit/train/evaluate.py +++ b/wtpsplit/train/evaluate.py @@ -47,6 +47,33 @@ def get_metrics(labels, preds): return metrics, info +def get_token_spans(tokenizer: object, offsets_mapping: list, tokens: list): + token_spans = [] + for idx, token in enumerate(tokens): + # Skip special tokens like [CLS], [SEP] + if idx >= len(offsets_mapping): + continue + if token in [tokenizer.cls_token, tokenizer.sep_token, tokenizer.pad_token]: + continue + + char_start, char_end = offsets_mapping[idx] + token_spans.append((char_start, char_end)) + + return token_spans + +def token_to_char_probs(text: str, tokens: list, token_probs: np.ndarray, tokenizer, offsets_mapping): + char_probs = np.zeros(len(text)) + token_spans = get_token_spans(tokenizer, offsets_mapping, tokens) + + for (start, end), prob in zip(token_spans, token_probs): + # assign the token's prob to the last char of the token + # Ensure the end index does not exceed the length of the text + if end >= len(text): + print(f"Adjusting end index from {end} to {len(text)} for token '{text[start:end]}'") + end = len(text) - 1 + char_probs[end] = prob + + return char_probs def evaluate_sentence( lang_code, @@ -67,7 +94,7 @@ def evaluate_sentence( separator = Constants.SEPARATORS[lang_code] text = separator.join(sentences) - logits = extract( + logits, offsets_mapping, tokenizer = extract( [text], PyTorchWrapper(model.backbone), lang_code=lang_code, @@ -75,13 +102,22 @@ def evaluate_sentence( block_size=block_size, batch_size=batch_size, verbose=True, - )[0] + ) + logits, offsets_mapping = logits[0], offsets_mapping[0] true_end_indices = np.cumsum(np.array([len(s) for s in sentences])) + np.arange(len(sentences)) * len(separator) newline_labels = np.zeros(len(text)) newline_labels[true_end_indices - 1] = 1 - - metrics, info = get_metrics(newline_labels, logits[:, positive_index]) + + print("newline_labels", newline_labels.shape) + + if "xlm" in model.config.model_type: + tokens = tokenizer.tokenize(text) + char_probs = token_to_char_probs(text, tokens, logits[:, positive_index], tokenizer, offsets_mapping) + else: + char_probs = logits[:, positive_index] + print("char probs", char_probs.shape) + metrics, info = get_metrics(newline_labels, char_probs) info["newline_labels"] = newline_labels @@ -94,3 +130,4 @@ def evaluate_sentence( info["newline_probs_pysbd"] = newline_probs_pysbd return metrics["pr_auc"], info + diff --git a/wtpsplit/train/train.py b/wtpsplit/train/train.py index b70c3d0f..7240be5b 100644 --- a/wtpsplit/train/train.py +++ b/wtpsplit/train/train.py @@ -30,8 +30,6 @@ from wtpsplit.train.trainer import Trainer from wtpsplit.utils import Constants, LabelArgs, corrupt, get_label_dict, get_subword_label_dict -# TODO: use seed from training args (also add default value) - # TODO: set logger (see ScaLearn?) # TODO: double-check checkpointing and saving (also to txt) @@ -209,18 +207,63 @@ def collate_fn(batch, args, label_args, label_dict, tokenizer): min_length=args.block_size, tokenizer=tokenizer if args.use_subwords else None, ) + + if input_ids[0] != tokenizer.cls_token_id: + print(input_ids) + print(len(input_ids)) + print(tokenizer.cls_token_id) + raise ValueError("CLS token not first token") + if input_ids[-1] != tokenizer.sep_token_id: + print(input_ids) + print(len(input_ids)) + print(tokenizer.sep_token_id) + raise ValueError("SEP token not last token") if len(input_ids) > args.block_size: - start = np.random.randint(0, len(input_ids) - args.block_size) - input_ids = input_ids[start : start + args.block_size] - labels = labels[start : start + args.block_size] + if tokenizer: + # always include CLS + start = np.random.randint(0, len(input_ids) - args.block_size) + if start != 0: + # this removes the CLS token + # -1 removes the SEP token, for sure + input_ids = [tokenizer.cls_token_id] + input_ids[start : start + args.block_size - 2] + labels = [0] + labels[start : start + args.block_size - 2] + else: + input_ids = input_ids[start : start + args.block_size - 1] + labels = labels[start : start + args.block_size - 1] + # always include SEP + if input_ids[-1] != tokenizer.sep_token_id: + input_ids = input_ids + [tokenizer.sep_token_id] + labels = labels + [0] + else: + start = np.random.randint(0, len(input_ids) - args.block_size) + input_ids = input_ids[start : start + args.block_size] + labels = labels[start : start + args.block_size] input_ids = torch.tensor(input_ids[: args.block_size], dtype=torch.long) labels = torch.tensor(labels[: args.block_size], dtype=torch.long) + if input_ids[-1] != tokenizer.sep_token_id: + print(input_ids) + print(tokenizer.sep_token_id) + print(labels) + raise ValueError("SEP token not last token") + if input_ids[0] != tokenizer.cls_token_id: + print(input_ids) + print(tokenizer.cls_token_id) + print(labels) + raise ValueError("CLS token not first token") + if (input_ids == tokenizer.cls_token_id).sum() != 1: + print(input_ids) + print(tokenizer.cls_token_id) + print(labels) + raise ValueError("CLS token not unique") position_ids = torch.arange(len(input_ids), dtype=torch.long) label_weights = torch.ones(args.block_size, dtype=torch.float32) - attention_mask = (input_ids != 0).to(torch.float32) + if tokenizer: + attention_mask = (input_ids != tokenizer.pad_token_id).to(torch.float32) + else: + attention_mask = (input_ids != 0).to(torch.float32) all_input_ids.append(input_ids) all_label_weights.append(label_weights) @@ -233,7 +276,7 @@ def collate_fn(batch, args, label_args, label_dict, tokenizer): out = { "input_ids": torch.stack(all_input_ids, 0), "attention_mask": torch.stack(all_attention_masks, 0), - "position_ids": torch.stack(all_position_ids, 0), + "position_ids": torch.stack(all_position_ids, 0) if not args.use_subwords else None, # safer "language_ids": torch.tensor(all_language_ids, dtype=torch.long), "label_weights": torch.stack(all_label_weights, 0), "labels": torch.stack(all_labels, 0), @@ -251,6 +294,8 @@ def main(): else: (args, training_args, label_args) = parser.parse_args_into_dataclasses() wandb_name = None + + set_seed(training_args.seed) num_labels = Constants.AUX_OFFSET + ((1 + len(Constants.PUNCTUATION_CHARS)) if args.do_auxiliary_training else 0) if args.use_subwords: @@ -261,6 +306,7 @@ def main(): num_labels=num_labels, ) backbone = SubwordXLMForTokenClassification(config) + else: config = SubwordXLMConfig.from_pretrained( args.model_name_or_path, @@ -269,6 +315,7 @@ def main(): ) backbone = SubwordXLMForTokenClassification(config) + backbone.config.base_model = args.model_name_or_path tokenizer = AutoTokenizer.from_pretrained(args.model_name_or_path) tokenizer.add_special_tokens({"additional_special_tokens": [AddedToken("\n")]}) @@ -309,7 +356,9 @@ def main(): ) with training_args.main_process_first(): - print(summary(model, depth=10)) + print(summary(model, depth=4)) + # also save base model + # backbone.push_to_hub("markus583/xlm-token-untrained", private=True) def prepare_dataset( num_workers=1, @@ -388,13 +437,14 @@ def drop_some_non_punctuation_samples(examples): batch_size=1_000_000, num_proc=num_workers, ) + def tokenize_texts(examples): - # TODO: before, we used use_special_tokens=False --> check effect! - tokenized = tokenizer(examples[args.text_column]) - # also add tokenized tokens in str format - # TODO: only use input_ids, no double conversion - return {"input_ids": tokenized["input_ids"]} + # do not return CLS and SEP token here + # there should only be 1 of these per block later, not multiple + # we still can't use return_special_tokens=False since we need the \n token later for the labels + tokenized = tokenizer(examples[args.text_column], verbose=False) + return {"input_ids": [example[1:-1] for example in tokenized["input_ids"]]} # similar to group_texts in huggingface's run_clm.py / run_mlm.py: https://github.com/huggingface/transformers/blob/main/examples/pytorch/language-modeling/run_mlm.py def group_texts(examples): @@ -516,15 +566,22 @@ def maybe_pad(text): num_proc=num_workers, remove_columns=[args.text_column], ) + + if split == "train": + with training_args.main_process_first(): + for root, dirs, files in os.walk(os.environ.get("HF_DATASETS_CACHE")): + for file in files: + if file.startswith("m_c4-test-train"): + print(f"Removing {os.path.join(root, file)}") + os.remove(os.path.join(root, file)) if not args.one_sample_per_line: with training_args.main_process_first(): dataset = dataset.map( group_texts, batched=True, - num_proc=num_workers, + num_proc=1, # a bit hacky but oh well, only drop if sentence - # TODO: clean this remove_columns=["ends_with_punctuation"] if args.text_column == "text" else [], @@ -532,22 +589,23 @@ def maybe_pad(text): return dataset - train_dataset = prepare_dataset( + valid_dataset = prepare_dataset( num_workers=args.preprocessing_num_workers, include_languages=args.include_languages, - shuffle=args.shuffle, - split="train", + shuffle=False, + split="valid", ) - valid_dataset = prepare_dataset( + train_dataset = prepare_dataset( num_workers=args.preprocessing_num_workers, include_languages=args.include_languages, - shuffle=False, + shuffle=args.shuffle, split="valid", ) # print some samples from the dataset - for index in random.sample(range(len(train_dataset)), 10): + for index in random.sample(range(len(train_dataset)), 5): print(f"Sample {index} of the training set: {train_dataset[index]}.") + print(tokenizer.decode(train_dataset[index]["input_ids"])) print() # dataset we use is in cached now @@ -604,6 +662,7 @@ def compute_metrics(trainer): for file in glob(os.path.join(os.path.dirname(__file__), "*.py")): wandb.save(os.path.abspath(file), policy="now") + # TODO: check tokenized mapping; UNKs? label_dict = get_subword_label_dict(label_args, tokenizer) if args.use_subwords else get_label_dict(label_args) # needed in the trainer diff --git a/wtpsplit/utils.py b/wtpsplit/utils.py index cd2982c9..3650b090 100644 --- a/wtpsplit/utils.py +++ b/wtpsplit/utils.py @@ -79,17 +79,23 @@ def get_label_dict(label_args): def get_subword_label_dict(label_args, tokenizer): label_dict = {} + n_unks = 0 # Map auxiliary characters to token IDs with labels for i, c in enumerate(label_args.auxiliary_chars): token_id = tokenizer.convert_tokens_to_ids(c) label_dict[token_id] = 1 + Constants.AUX_OFFSET + i - print(f"auxiliary character {c} has token ID {token_id} and label {label_dict[token_id]}") + print(f"auxiliary character {c} has token ID {token_id} and label {label_dict[token_id]}, decoded: {tokenizer.decode([token_id])}") + if token_id == tokenizer.unk_token_id: + n_unks += 1 + + print(f"found {n_unks} UNK tokens in auxiliary characters") # Map newline characters to token IDs with labels for c in label_args.newline_chars: token_id = tokenizer.convert_tokens_to_ids(c) label_dict[token_id] = 1 + Constants.NEWLINE_INDEX - print(f"newline character {c} has token ID {token_id} and label {label_dict[token_id]}") + print(f"newline character {c} has token ID {token_id} and label {label_dict[token_id]}, decoded:") + print(r"{}".format(tokenizer.decode([token_id]))) return label_dict @@ -152,8 +158,17 @@ def corrupt( try: i = next(index for index, label in enumerate(labels) if label != 0) except StopIteration: + if tokenizer is not None: + input_ids = [tokenizer.cls_token_id] + input_ids + [tokenizer.sep_token_id] + # Extend block_ids for the added CLS and SEP tokens + block_ids = [block_ids[0]] + block_ids + [block_ids[-1]] + # labels for CLS and SEP tokens are 0 (none) + labels = [0] + labels + [0] return input_ids, block_ids, labels + if tokenizer: + # account for CLS and SEP token, added later + min_length = min_length - 2 if min_length is not None else None while True: if min_length is not None and len(input_ids) <= min_length: break @@ -193,6 +208,14 @@ def corrupt( except StopIteration: break + # Add CLS and SEP tokens after the corruption process + if tokenizer is not None: + input_ids = [tokenizer.cls_token_id] + input_ids + [tokenizer.sep_token_id] + # Extend block_ids for the added CLS and SEP tokens + block_ids = [block_ids[0]] + block_ids + [block_ids[-1]] + # labels for CLS and SEP tokens are 0 (none) + labels = [0] + labels + [0] + return input_ids, block_ids, labels