From cdb8743937fccf99521313fa6f9a233fb6e1fc6b Mon Sep 17 00:00:00 2001 From: markus583 Date: Thu, 21 Dec 2023 16:55:11 +0000 Subject: [PATCH] directly use input_ids --- configs/xlmr_stratify_0.1_3layers.json | 2 +- requirements.txt | 3 ++- wtpsplit/train/train.py | 23 +++++++++++------------ 3 files changed, 14 insertions(+), 14 deletions(-) diff --git a/configs/xlmr_stratify_0.1_3layers.json b/configs/xlmr_stratify_0.1_3layers.json index 07c0bd98..cbece18f 100644 --- a/configs/xlmr_stratify_0.1_3layers.json +++ b/configs/xlmr_stratify_0.1_3layers.json @@ -13,7 +13,7 @@ "gradient_accumulation_steps": 1, "eval_accumulation_steps": 8, "dataloader_num_workers": 32, - "preprocessing_num_workers": 6, + "preprocessing_num_workers": 32, "learning_rate": 1e-4, "save_strategy": "steps", "fp16": false, diff --git a/requirements.txt b/requirements.txt index 1faca977..09ba6b3b 100644 --- a/requirements.txt +++ b/requirements.txt @@ -10,4 +10,5 @@ ersatz iso-639 scikit-learn==1.2.2 numpy==1.23.5 -pydantic \ No newline at end of file +pydantic +torchinfo \ No newline at end of file diff --git a/wtpsplit/train/train.py b/wtpsplit/train/train.py index b8285cb3..b70c3d0f 100644 --- a/wtpsplit/train/train.py +++ b/wtpsplit/train/train.py @@ -188,10 +188,10 @@ def collate_fn(batch, args, label_args, label_dict, tokenizer): for sample in batch: # subword-level if args.use_subwords: - input_ids = tokenizer.convert_tokens_to_ids(sample[args.text_column]) + input_ids = sample["input_ids"] # char-level else: - input_ids = [ord(c) for c in sample[args.text_column]] + input_ids = [ord(c) for c in sample["input_ids"]] lang = sample["lang"] while len(input_ids) < args.block_size + args.overflow_size: @@ -309,7 +309,7 @@ def main(): ) with training_args.main_process_first(): - print(summary(model, depth=3)) + print(summary(model, depth=10)) def prepare_dataset( num_workers=1, @@ -390,11 +390,11 @@ def drop_some_non_punctuation_samples(examples): ) def tokenize_texts(examples): - tokenized = tokenizer(examples[args.text_column], add_special_tokens=False) + # TODO: before, we used use_special_tokens=False --> check effect! + tokenized = tokenizer(examples[args.text_column]) # also add tokenized tokens in str format - tokenized["tokenized_text"] = [tokenizer.convert_ids_to_tokens(ids) for ids in tokenized["input_ids"]] - - return tokenized + # TODO: only use input_ids, no double conversion + return {"input_ids": tokenized["input_ids"]} # similar to group_texts in huggingface's run_clm.py / run_mlm.py: https://github.com/huggingface/transformers/blob/main/examples/pytorch/language-modeling/run_mlm.py def group_texts(examples): @@ -423,7 +423,7 @@ def maybe_pad(text): # only retain current_lang examples (all columns) lang_subwords = [ subwords - for subwords, lang in zip(examples["tokenized_text"], examples["lang"]) + for subwords, lang in zip(examples["input_ids"], examples["lang"]) if lang == current_lang ] @@ -497,7 +497,7 @@ def maybe_pad(text): all_langs.extend(block_langs) return { - args.text_column: all_input_blocks, + "input_ids": all_input_blocks, "block_lengths": all_input_block_lengths, "lang": all_langs, } @@ -514,19 +514,18 @@ def maybe_pad(text): tokenize_texts, batched=True, num_proc=num_workers, - # remove_columns=[args.text_column], + remove_columns=[args.text_column], ) if not args.one_sample_per_line: with training_args.main_process_first(): - # drop columns input_ids, attentionm dataset = dataset.map( group_texts, batched=True, num_proc=num_workers, # a bit hacky but oh well, only drop if sentence # TODO: clean this - remove_columns=["ends_with_punctuation", "input_ids", "attention_mask", "tokenized_text"] + remove_columns=["ends_with_punctuation"] if args.text_column == "text" else [], )