diff --git a/wtpsplit/train/train.py b/wtpsplit/train/train.py index 93e19d60..88c591e1 100644 --- a/wtpsplit/train/train.py +++ b/wtpsplit/train/train.py @@ -412,7 +412,7 @@ def maybe_pad(text): if not args.use_subwords: lang_texts = [ maybe_pad(text) - for text, lang in zip(examples[args.text_column], examples["lang"]) + for text, lang in zip(examples["input_ids"], examples["lang"]) if lang == current_lang ] else: @@ -520,7 +520,7 @@ def maybe_pad(text): else: # this is no longer used and would cause an error otherwise with training_args.main_process_first(): - dataset = dataset.remove_columns([args.text_column]) + dataset = dataset.rename_column(args.text_column, "input_ids") logger.warning(f"Tokenized {split} dataset.") if split == "train":