From a2259bd5522106fb848d07964062aabd0cbba38a Mon Sep 17 00:00:00 2001
From: Luca Weihs
Date: Mon, 25 Nov 2024 09:51:51 -0800
Subject: [PATCH] Allowing dataloader prefetching to work with iterable datasets.

---
 src/transformers/trainer.py   | 147 +++++++++++++++++-----------------
 tests/trainer/test_trainer.py |  94 ++++++++++++++++++++++
 2 files changed, 167 insertions(+), 74 deletions(-)

diff --git a/src/transformers/trainer.py b/src/transformers/trainer.py
index 3fd067edfc5b06..3b42efe8d9252d 100755
--- a/src/transformers/trainer.py
+++ b/src/transformers/trainer.py
@@ -928,16 +928,53 @@ def _get_collator_with_removed_columns(
         )
         return remove_columns_collator
 
-    def _get_train_sampler(self) -> Optional[torch.utils.data.Sampler]:
-        if self.train_dataset is None or not has_length(self.train_dataset):
+    def _get_dataloader(
+        self,
+        dataset: Dataset,
+        get_sampler_func: Callable[[Dataset], Optional[torch.utils.data.Sampler]],
+        batch_size: int,
+        description: str,
+        seed_workers: bool,
+    ) -> DataLoader:
+        data_collator = self.data_collator
+
+        if is_datasets_available() and isinstance(dataset, datasets.Dataset):
+            dataset = self._remove_unused_columns(dataset, description=description)
+        else:
+            data_collator = self._get_collator_with_removed_columns(data_collator, description=description)
+
+        dataloader_params = {
+            "dataset": dataset,
+            "batch_size": batch_size,
+            "collate_fn": data_collator,
+            "num_workers": self.args.dataloader_num_workers,
+            "pin_memory": self.args.dataloader_pin_memory,
+            "persistent_workers": self.args.dataloader_persistent_workers,
+            "prefetch_factor": self.args.dataloader_prefetch_factor,
+        }
+
+        if seed_workers:
+            dataloader_params["worker_init_fn"] = seed_worker
+
+        if not isinstance(dataset, torch.utils.data.IterableDataset):
+            dataloader_params["sampler"] = get_sampler_func(dataset)
+            dataloader_params["drop_last"] = self.args.dataloader_drop_last
+
+        # accelerator.free_memory() will destroy the references, so
+        # we need to return the unprepared version here in case someone
+        # needs them. We prepare in the get_{train/eval/test}_dataloader functions.
+        return DataLoader(**dataloader_params)
+
+    def _get_train_sampler(self, train_dataset: Dataset) -> Optional[torch.utils.data.Sampler]:
+        if train_dataset is None or not has_length(train_dataset):
             return None
 
         # Build the sampler.
         if self.args.group_by_length:
-            if is_datasets_available() and isinstance(self.train_dataset, datasets.Dataset):
+            if is_datasets_available() and isinstance(train_dataset, datasets.Dataset):
                 lengths = (
-                    self.train_dataset[self.args.length_column_name]
-                    if self.args.length_column_name in self.train_dataset.column_names
+                    train_dataset[self.args.length_column_name]
+                    if self.args.length_column_name in train_dataset.column_names
                     else None
                 )
             else:
@@ -947,13 +984,13 @@ def _get_train_sampler(self) -> Optional[torch.utils.data.Sampler]:
             )
             return LengthGroupedSampler(
                 self.args.train_batch_size * self.args.gradient_accumulation_steps,
-                dataset=self.train_dataset,
+                dataset=train_dataset,
                 lengths=lengths,
                 model_input_name=model_input_name,
             )
 
         else:
-            return RandomSampler(self.train_dataset)
+            return RandomSampler(train_dataset)
 
     def get_train_dataloader(self) -> DataLoader:
         """
@@ -967,28 +1004,15 @@ def get_train_dataloader(self) -> DataLoader:
         if self.train_dataset is None:
             raise ValueError("Trainer: training requires a train_dataset.")
 
-        train_dataset = self.train_dataset
-        data_collator = self.data_collator
-        if is_datasets_available() and isinstance(train_dataset, datasets.Dataset):
-            train_dataset = self._remove_unused_columns(train_dataset, description="training")
-        else:
-            data_collator = self._get_collator_with_removed_columns(data_collator, description="training")
-
-        dataloader_params = {
-            "batch_size": self._train_batch_size,
-            "collate_fn": data_collator,
-            "num_workers": self.args.dataloader_num_workers,
-            "pin_memory": self.args.dataloader_pin_memory,
-            "persistent_workers": self.args.dataloader_persistent_workers,
-        }
-
-        if not isinstance(train_dataset, torch.utils.data.IterableDataset):
-            dataloader_params["sampler"] = self._get_train_sampler()
-            dataloader_params["drop_last"] = self.args.dataloader_drop_last
-            dataloader_params["worker_init_fn"] = seed_worker
-            dataloader_params["prefetch_factor"] = self.args.dataloader_prefetch_factor
-
-        return self.accelerator.prepare(DataLoader(train_dataset, **dataloader_params))
+        return self.accelerator.prepare(
+            self._get_dataloader(
+                dataset=self.train_dataset,
+                get_sampler_func=self._get_train_sampler,
+                batch_size=self._train_batch_size,
+                description="training",
+                seed_workers=not isinstance(self.train_dataset, torch.utils.data.IterableDataset),
+            )
+        )
 
     def _get_eval_sampler(self, eval_dataset: Dataset) -> Optional[torch.utils.data.Sampler]:
         if eval_dataset is None or not has_length(eval_dataset):
@@ -1063,34 +1087,22 @@ def get_eval_dataloader(self, eval_dataset: Optional[Union[str, Dataset]] = None
             if eval_dataset is not None
             else self.eval_dataset
         )
-        data_collator = self.data_collator
-
-        if is_datasets_available() and isinstance(eval_dataset, datasets.Dataset):
-            eval_dataset = self._remove_unused_columns(eval_dataset, description="evaluation")
-        else:
-            data_collator = self._get_collator_with_removed_columns(data_collator, description="evaluation")
-
-        dataloader_params = {
-            "batch_size": self.args.eval_batch_size,
-            "collate_fn": data_collator,
-            "num_workers": self.args.dataloader_num_workers,
-            "pin_memory": self.args.dataloader_pin_memory,
-            "persistent_workers": self.args.dataloader_persistent_workers,
-        }
 
-        if not isinstance(eval_dataset, torch.utils.data.IterableDataset):
-            dataloader_params["sampler"] = self._get_eval_sampler(eval_dataset)
-            dataloader_params["drop_last"] = self.args.dataloader_drop_last
-            dataloader_params["prefetch_factor"] = self.args.dataloader_prefetch_factor
+        eval_dataloader = self._get_dataloader(
+            dataset=eval_dataset,
+            get_sampler_func=self._get_eval_sampler,
+            batch_size=self.args.eval_batch_size,
+            description="evaluation",
+            seed_workers=False,
+        )
 
         # accelerator.free_memory() will destroy the references, so
         # we need to store the non-prepared version
-        eval_dataloader = DataLoader(eval_dataset, **dataloader_params)
         if self.args.dataloader_persistent_workers:
-            if hasattr(self, "_eval_dataloaders"):
-                self._eval_dataloaders[dataloader_key] = eval_dataloader
-            else:
-                self._eval_dataloaders = {dataloader_key: eval_dataloader}
+            if not hasattr(self, "_eval_dataloaders"):
+                self._eval_dataloaders = {}
+
+            self._eval_dataloaders[dataloader_key] = eval_dataloader
 
         return self.accelerator.prepare(eval_dataloader)
 
@@ -1105,28 +1117,15 @@ def get_test_dataloader(self, test_dataset: Dataset) -> DataLoader:
                 The test dataset to use. If it is a [`~datasets.Dataset`], columns not accepted by the
                 `model.forward()` method are automatically removed. It must implement `__len__`.
         """
-        data_collator = self.data_collator
-
-        if is_datasets_available() and isinstance(test_dataset, datasets.Dataset):
-            test_dataset = self._remove_unused_columns(test_dataset, description="test")
-        else:
-            data_collator = self._get_collator_with_removed_columns(data_collator, description="test")
-
-        dataloader_params = {
-            "batch_size": self.args.eval_batch_size,
-            "collate_fn": data_collator,
-            "num_workers": self.args.dataloader_num_workers,
-            "pin_memory": self.args.dataloader_pin_memory,
-            "persistent_workers": self.args.dataloader_persistent_workers,
-        }
-
-        if not isinstance(test_dataset, torch.utils.data.IterableDataset):
-            dataloader_params["sampler"] = self._get_eval_sampler(test_dataset)
-            dataloader_params["drop_last"] = self.args.dataloader_drop_last
-            dataloader_params["prefetch_factor"] = self.args.dataloader_prefetch_factor
-
-        # We use the same batch_size as for eval.
-        return self.accelerator.prepare(DataLoader(test_dataset, **dataloader_params))
+        return self.accelerator.prepare(
+            self._get_dataloader(
+                dataset=test_dataset,
+                get_sampler_func=self._get_eval_sampler,
+                batch_size=self.args.eval_batch_size,
+                description="test",
+                seed_workers=False,
+            )
+        )
 
     def create_optimizer_and_scheduler(self, num_training_steps: int):
         """
diff --git a/tests/trainer/test_trainer.py b/tests/trainer/test_trainer.py
index 5658372fa71308..735f876aa8c50d 100644
--- a/tests/trainer/test_trainer.py
+++ b/tests/trainer/test_trainer.py
@@ -18,12 +18,15 @@
 import importlib
 import json
 import math
+import multiprocessing
 import os
+import queue
 import random
 import re
 import subprocess
 import sys
 import tempfile
+import time
 import unittest
 from functools import partial
 from itertools import product
@@ -342,6 +345,20 @@ def __init__(self, a=0, b=0, double_output=False, random_torch=True, **kwargs):
         self.hidden_size = 1
 
 
+class QueueIterableDataset(IterableDataset):
+    def __init__(self, q):
+        self.q = q
+
+    def __iter__(self):
+        while True:
+            try:
+                item = self.q.get_nowait()
+                print(item)
+                yield {"label_ids": item}
+            except queue.Empty:
+                break
+
+
 if is_torch_available():
 
     class SampleIterableDataset(IterableDataset):
@@ -1571,6 +1588,83 @@ def test_get_eval_dataloader_with_persistent_workers(self):
         self.assertEqual(first_dataloader, first_dataloader_repeated)
         self.assertEqual(second_dataloader, second_dataloader_repeated)
 
+    def test_dataloader_prefetch_iterable_occurs(self):
+        """Test that prefetching works correctly with iterable datasets."""
+
+        # Create a multiprocessing queue and fill it with 5 elements,
+        # 4 of these should be fetched ahead of time after getting the first batch
+        q = multiprocessing.Queue()
+        for i in range(5):
+            q.put(torch.tensor([i]))
+
+        # Create dataset and model
+        dataset = QueueIterableDataset(q)
+        model = RegressionModel()
+
+        # Configure trainer with 2 workers and prefetch_factor=1
+        training_args = TrainingArguments(
+            output_dir="./test-prefetch",
+            per_device_train_batch_size=1,
+            dataloader_num_workers=2,
+            dataloader_prefetch_factor=1,
+            max_steps=4,
+            report_to="none",
+        )
+
+        trainer = Trainer(
+            model=model,
+            args=training_args,
+            train_dataset=dataset,
+        )
+
+        # Get the dataloader and fetch one batch
+        dataloader = trainer.get_train_dataloader()
+        next(iter(dataloader))
+        time.sleep(1)  # Wait for workers to prefetch the next batch
+
+        # The queue should have 1 remaining element after the first batch is fetched
+        self.assertTrue(q.get_nowait().item() == 4)
+        with self.assertRaises(queue.Empty):
+            q.get_nowait()
+
+    def test_dataloader_prefetch_iterable_will_exhaust_dataset(self):
+        """
+        The same as test_dataloader_prefetch_iterable_occurs but now
+        we confirm that setting a higher prefetch factor will result in the
+        queue dataset being empty after the first batch is fetched.
+ """ + q = multiprocessing.Queue() + for i in range(6): # 6 elements rather than 5 + q.put(torch.tensor([i])) + + # Create dataset and model + dataset = QueueIterableDataset(q) + model = RegressionModel() + + training_args = TrainingArguments( + output_dir="./test-prefetch", + per_device_train_batch_size=1, + dataloader_num_workers=2, + dataloader_prefetch_factor=3, # 3 rather than 1 above + max_steps=4, + report_to="none", + ) + + trainer = Trainer( + model=model, + args=training_args, + train_dataset=dataset, + ) + + # Get the dataloader and fetch one batch + dataloader = trainer.get_train_dataloader() + next(iter(dataloader)) + time.sleep(1) # Wait for workers to prefetch the next batch + + # The queue should now be empty since the workers prefetched all remaining elements + with self.assertRaises(queue.Empty): + q.get_nowait() + @require_liger_kernel def test_use_liger_kernel_patching(self): # Ensure any monkey patching is cleaned up for subsequent tests