Commit

directly use input_ids
markus583 committed Dec 21, 2023
1 parent 775d4a6 commit cdb8743
Showing 3 changed files with 14 additions and 14 deletions.
2 changes: 1 addition & 1 deletion configs/xlmr_stratify_0.1_3layers.json
@@ -13,7 +13,7 @@
"gradient_accumulation_steps": 1,
"eval_accumulation_steps": 8,
"dataloader_num_workers": 32,
"preprocessing_num_workers": 6,
"preprocessing_num_workers": 32,
"learning_rate": 1e-4,
"save_strategy": "steps",
"fp16": false,
3 changes: 2 additions & 1 deletion requirements.txt
@@ -10,4 +10,5 @@ ersatz
iso-639
scikit-learn==1.2.2
numpy==1.23.5
pydantic
pydantic
torchinfo
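
torchinfo is added because train.py prints a model summary (the depth goes from 3 to 10 in the hunk further down). A hedged usage sketch with a stand-in model; the real script summarizes the wtpsplit model itself:

```python
# Hedged sketch, not repo code: torchinfo's summary prints the module tree
# and parameter counts; depth controls how many levels of nesting are shown.
from torchinfo import summary
from transformers import AutoModel

model = AutoModel.from_pretrained("xlm-roberta-base")  # stand-in model
print(summary(model, depth=10))
```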
23 changes: 11 additions & 12 deletions wtpsplit/train/train.py
@@ -188,10 +188,10 @@ def collate_fn(batch, args, label_args, label_dict, tokenizer):
for sample in batch:
# subword-level
if args.use_subwords:
input_ids = tokenizer.convert_tokens_to_ids(sample[args.text_column])
input_ids = sample["input_ids"]
# char-level
else:
input_ids = [ord(c) for c in sample[args.text_column]]
input_ids = [ord(c) for c in sample["input_ids"]]
lang = sample["lang"]

while len(input_ids) < args.block_size + args.overflow_size:
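
This first train.py hunk is the heart of the commit: collate_fn no longer rebuilds ids from stored token strings but reads the input_ids produced during preprocessing. A hedged sketch of the round trip this removes, assuming an XLM-R tokenizer whose pieces map back to their original ids:

```python
# Hedged sketch, not repo code: the old pipeline stored token strings
# (tokenized_text) and re-converted them to ids in every collate_fn call.
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("xlm-roberta-base")

input_ids = tokenizer("Hello world.", add_special_tokens=False)["input_ids"]
tokens = tokenizer.convert_ids_to_tokens(input_ids)      # old: done in tokenize_texts
roundtripped = tokenizer.convert_tokens_to_ids(tokens)   # old: done again in collate_fn
assert roundtripped == input_ids  # the string detour is redundant, hence "directly use input_ids"
```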
@@ -309,7 +309,7 @@ def main():
)

with training_args.main_process_first():
print(summary(model, depth=3))
print(summary(model, depth=10))

def prepare_dataset(
num_workers=1,
@@ -390,11 +390,11 @@ def drop_some_non_punctuation_samples(examples):
)

def tokenize_texts(examples):
tokenized = tokenizer(examples[args.text_column], add_special_tokens=False)
# TODO: before, we used use_special_tokens=False --> check effect!
tokenized = tokenizer(examples[args.text_column])
# also add tokenized tokens in str format
tokenized["tokenized_text"] = [tokenizer.convert_ids_to_tokens(ids) for ids in tokenized["input_ids"]]

return tokenized
# TODO: only use input_ids, no double conversion
return {"input_ids": tokenized["input_ids"]}

# similar to group_texts in huggingface's run_clm.py / run_mlm.py: https://github.com/huggingface/transformers/blob/main/examples/pytorch/language-modeling/run_mlm.py
def group_texts(examples):
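
tokenize_texts now returns only input_ids; the TODOs in the hunk flag the open questions about special tokens and the dropped string conversion. A self-contained sketch of the slimmed-down mapper under datasets.map, with column names taken from the diff and the toy data made up:

```python
# Hedged sketch, not repo code: the reduced tokenize_texts under datasets.map.
from datasets import Dataset
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("xlm-roberta-base")

def tokenize_texts(examples):
    tokenized = tokenizer(examples["text"])          # special tokens kept, as in the new code
    return {"input_ids": tokenized["input_ids"]}     # no tokenized_text column anymore

ds = Dataset.from_dict({"text": ["Hello world. How are you?"], "lang": ["en"]})
ds = ds.map(tokenize_texts, batched=True, remove_columns=["text"])
print(ds.column_names)         # e.g. ['lang', 'input_ids']
print(ds[0]["input_ids"][:5])
```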
@@ -423,7 +423,7 @@ def maybe_pad(text):
# only retain current_lang examples (all columns)
lang_subwords = [
subwords
for subwords, lang in zip(examples["tokenized_text"], examples["lang"])
for subwords, lang in zip(examples["input_ids"], examples["lang"])
if lang == current_lang
]

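
group_texts now reads the ids from examples["input_ids"] instead of the dropped tokenized_text column; the per-language filter itself is unchanged. A tiny illustration with made-up data:

```python
# Hedged illustration of the zip-based language filter above; the ids are made up.
examples = {
    "input_ids": [[17, 81, 5], [23, 42], [7, 9, 11]],
    "lang": ["en", "de", "en"],
}
current_lang = "en"
lang_subwords = [
    subwords
    for subwords, lang in zip(examples["input_ids"], examples["lang"])
    if lang == current_lang
]
print(lang_subwords)  # [[17, 81, 5], [7, 9, 11]]
```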
@@ -497,7 +497,7 @@ def maybe_pad(text):
all_langs.extend(block_langs)

return {
args.text_column: all_input_blocks,
"input_ids": all_input_blocks,
"block_lengths": all_input_block_lengths,
"lang": all_langs,
}
@@ -514,19 +514,18 @@ def maybe_pad(text):
tokenize_texts,
batched=True,
num_proc=num_workers,
# remove_columns=[args.text_column],
remove_columns=[args.text_column],
)

if not args.one_sample_per_line:
with training_args.main_process_first():
# drop columns input_ids, attentionm
dataset = dataset.map(
group_texts,
batched=True,
num_proc=num_workers,
# a bit hacky but oh well, only drop if sentence
# TODO: clean this
remove_columns=["ends_with_punctuation", "input_ids", "attention_mask", "tokenized_text"]
remove_columns=["ends_with_punctuation"]
if args.text_column == "text"
else [],
)
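
With tokenize_texts emitting only input_ids and group_texts writing its blocks back to the same column, the remove_columns lists shrink accordingly: the text column is dropped after tokenization, and only ends_with_punctuation after grouping (in the default text-column case). A hedged end-to-end sketch of the two map stages, with a toy grouper standing in for the real group_texts:

```python
# Hedged sketch, not repo code: the column flow of the tokenize -> group pipeline
# after this commit. group_blocks is a toy stand-in for group_texts, which in the
# repo packs ids into fixed-size blocks per language.
from datasets import Dataset
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("xlm-roberta-base")

def tokenize_texts(examples):
    tokenized = tokenizer(examples["text"])
    return {"input_ids": tokenized["input_ids"]}

def group_blocks(examples):
    return {  # one "block" per sample instead of real block packing
        "input_ids": examples["input_ids"],
        "block_lengths": [[len(ids)] for ids in examples["input_ids"]],
        "lang": examples["lang"],
    }

ds = Dataset.from_dict({"text": ["Hello world.", "Wie geht's?"], "lang": ["en", "de"]})
ds = ds.map(tokenize_texts, batched=True, remove_columns=["text"])
ds = ds.map(group_blocks, batched=True)
print(ds.column_names)  # e.g. ['lang', 'input_ids', 'block_lengths']
```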
