diff --git a/wtpsplit/train/train.py b/wtpsplit/train/train.py index 938c5db4..bac5ecfc 100644 --- a/wtpsplit/train/train.py +++ b/wtpsplit/train/train.py @@ -641,11 +641,12 @@ def maybe_pad(text): split="valid", ) logger.info(f"Valid dataset has {len(valid_dataset)} examples.") + train_dataset = prepare_dataset( num_workers=args.preprocessing_num_workers, include_languages=args.include_languages, shuffle=args.shuffle, - split="valid", + split="train", ) logger.info(f"Train dataset has {len(train_dataset)} examples.")