Allowing dataloader prefetching to work with iterable datasets #34925

Open · wants to merge 1 commit into base: main
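This PR routes train, eval, and test dataloader construction through a single _get_dataloader helper so that dataloader_prefetch_factor (together with the other shared DataLoader arguments) is also applied when the dataset is a torch.utils.data.IterableDataset. Previously each getter only set prefetch_factor inside its non-iterable branch, so prefetching was effectively disabled for iterable datasets.

A minimal PyTorch-level sketch of the configuration this enables follows. The CountingStream class and the literal values are illustrative placeholders, not code from this PR; the comments map each argument to the TrainingArguments field that _get_dataloader forwards.

    import torch
    from torch.utils.data import DataLoader, IterableDataset

    class CountingStream(IterableDataset):
        # Hypothetical toy stream standing in for any user-provided iterable dataset.
        def __iter__(self):
            for i in range(8):
                yield {"label_ids": torch.tensor([i])}

    if __name__ == "__main__":
        loader = DataLoader(
            CountingStream(),
            batch_size=1,              # the Trainer passes self._train_batch_size / args.eval_batch_size here
            num_workers=2,             # args.dataloader_num_workers
            prefetch_factor=2,         # args.dataloader_prefetch_factor; dropped for iterable datasets before this change
            pin_memory=True,           # args.dataloader_pin_memory
            persistent_workers=False,  # args.dataloader_persistent_workers
        )
        # Each worker iterates its own copy of the stream unless the dataset shards
        # itself via torch.utils.data.get_worker_info(), so this loop yields 16
        # batches (8 per worker), with workers fetching ahead of consumption.
        for batch in loader:
            print(batch["label_ids"])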
src/transformers/trainer.py (147 changes: 73 additions & 74 deletions)
@@ -928,16 +928,53 @@ def _get_collator_with_removed_columns(
)
return remove_columns_collator

def _get_train_sampler(self) -> Optional[torch.utils.data.Sampler]:
if self.train_dataset is None or not has_length(self.train_dataset):
def _get_dataloader(
self,
dataset: Dataset,
get_sampler_func: Callable[[Dataset], Optional[torch.utils.data.Sampler]],
batch_size: int,
description: str,
seed_workers: bool,
) -> DataLoader:
data_collator = self.data_collator

if is_datasets_available() and isinstance(dataset, datasets.Dataset):
dataset = self._remove_unused_columns(dataset, description=description)
else:
data_collator = self._get_collator_with_removed_columns(data_collator, description=description)

dataloader_params = {
"dataset": dataset,
"batch_size": batch_size,
"collate_fn": data_collator,
"num_workers": self.args.dataloader_num_workers,
"pin_memory": self.args.dataloader_pin_memory,
"persistent_workers": self.args.dataloader_persistent_workers,
"prefetch_factor": self.args.dataloader_prefetch_factor,
}

if seed_workers:
dataloader_params["worker_init_fn"] = seed_worker

if not isinstance(dataset, torch.utils.data.IterableDataset):
dataloader_params["sampler"] = get_sampler_func(dataset)
dataloader_params["drop_last"] = self.args.dataloader_drop_last

# accelerator.free_memory() will destroy the references, so
# we return the unprepared DataLoader here in case callers need to keep
# a reference. Preparation happens in the get_{train/eval/test}_dataloader functions.
return DataLoader(**dataloader_params)

def _get_train_sampler(self, train_dataset: Dataset) -> Optional[torch.utils.data.Sampler]:
if train_dataset is None or not has_length(train_dataset):
return None

# Build the sampler.
if self.args.group_by_length:
if is_datasets_available() and isinstance(self.train_dataset, datasets.Dataset):
if is_datasets_available() and isinstance(train_dataset, datasets.Dataset):
lengths = (
self.train_dataset[self.args.length_column_name]
if self.args.length_column_name in self.train_dataset.column_names
train_dataset[self.args.length_column_name]
if self.args.length_column_name in train_dataset.column_names
else None
)
else:
@@ -947,13 +984,13 @@ def _get_train_sampler(self) -> Optional[torch.utils.data.Sampler]:
)
return LengthGroupedSampler(
self.args.train_batch_size * self.args.gradient_accumulation_steps,
dataset=self.train_dataset,
dataset=train_dataset,
lengths=lengths,
model_input_name=model_input_name,
)

else:
return RandomSampler(self.train_dataset)
return RandomSampler(train_dataset)

def get_train_dataloader(self) -> DataLoader:
"""
@@ -967,28 +1004,15 @@ def get_train_dataloader(self) -> DataLoader:
if self.train_dataset is None:
raise ValueError("Trainer: training requires a train_dataset.")

train_dataset = self.train_dataset
data_collator = self.data_collator
if is_datasets_available() and isinstance(train_dataset, datasets.Dataset):
train_dataset = self._remove_unused_columns(train_dataset, description="training")
else:
data_collator = self._get_collator_with_removed_columns(data_collator, description="training")

dataloader_params = {
"batch_size": self._train_batch_size,
"collate_fn": data_collator,
"num_workers": self.args.dataloader_num_workers,
"pin_memory": self.args.dataloader_pin_memory,
"persistent_workers": self.args.dataloader_persistent_workers,
}

if not isinstance(train_dataset, torch.utils.data.IterableDataset):
dataloader_params["sampler"] = self._get_train_sampler()
dataloader_params["drop_last"] = self.args.dataloader_drop_last
dataloader_params["worker_init_fn"] = seed_worker
dataloader_params["prefetch_factor"] = self.args.dataloader_prefetch_factor

return self.accelerator.prepare(DataLoader(train_dataset, **dataloader_params))
return self.accelerator.prepare(
self._get_dataloader(
dataset=self.train_dataset,
get_sampler_func=self._get_train_sampler,
batch_size=self._train_batch_size,
description="training",
seed_workers=not isinstance(self.train_dataset, torch.utils.data.IterableDataset),
)
)

def _get_eval_sampler(self, eval_dataset: Dataset) -> Optional[torch.utils.data.Sampler]:
if eval_dataset is None or not has_length(eval_dataset):
@@ -1063,34 +1087,22 @@ def get_eval_dataloader(self, eval_dataset: Optional[Union[str, Dataset]] = None
if eval_dataset is not None
else self.eval_dataset
)
data_collator = self.data_collator

if is_datasets_available() and isinstance(eval_dataset, datasets.Dataset):
eval_dataset = self._remove_unused_columns(eval_dataset, description="evaluation")
else:
data_collator = self._get_collator_with_removed_columns(data_collator, description="evaluation")

dataloader_params = {
"batch_size": self.args.eval_batch_size,
"collate_fn": data_collator,
"num_workers": self.args.dataloader_num_workers,
"pin_memory": self.args.dataloader_pin_memory,
"persistent_workers": self.args.dataloader_persistent_workers,
}

if not isinstance(eval_dataset, torch.utils.data.IterableDataset):
dataloader_params["sampler"] = self._get_eval_sampler(eval_dataset)
dataloader_params["drop_last"] = self.args.dataloader_drop_last
dataloader_params["prefetch_factor"] = self.args.dataloader_prefetch_factor
eval_dataloader = self._get_dataloader(
dataset=eval_dataset,
get_sampler_func=self._get_eval_sampler,
batch_size=self.args.eval_batch_size,
description="evaluation",
seed_workers=False,
)

# accelerator.free_memory() will destroy the references, so
# we need to store the non-prepared version
eval_dataloader = DataLoader(eval_dataset, **dataloader_params)
if self.args.dataloader_persistent_workers:
if hasattr(self, "_eval_dataloaders"):
self._eval_dataloaders[dataloader_key] = eval_dataloader
else:
self._eval_dataloaders = {dataloader_key: eval_dataloader}
if not hasattr(self, "_eval_dataloaders"):
self._eval_dataloaders = {}

self._eval_dataloaders[dataloader_key] = eval_dataloader

return self.accelerator.prepare(eval_dataloader)

@@ -1105,28 +1117,15 @@ def get_test_dataloader(self, test_dataset: Dataset) -> DataLoader:
The test dataset to use. If it is a [`~datasets.Dataset`], columns not accepted by the
`model.forward()` method are automatically removed. It must implement `__len__`.
"""
data_collator = self.data_collator

if is_datasets_available() and isinstance(test_dataset, datasets.Dataset):
test_dataset = self._remove_unused_columns(test_dataset, description="test")
else:
data_collator = self._get_collator_with_removed_columns(data_collator, description="test")

dataloader_params = {
"batch_size": self.args.eval_batch_size,
"collate_fn": data_collator,
"num_workers": self.args.dataloader_num_workers,
"pin_memory": self.args.dataloader_pin_memory,
"persistent_workers": self.args.dataloader_persistent_workers,
}

if not isinstance(test_dataset, torch.utils.data.IterableDataset):
dataloader_params["sampler"] = self._get_eval_sampler(test_dataset)
dataloader_params["drop_last"] = self.args.dataloader_drop_last
dataloader_params["prefetch_factor"] = self.args.dataloader_prefetch_factor

# We use the same batch_size as for eval.
return self.accelerator.prepare(DataLoader(test_dataset, **dataloader_params))
return self.accelerator.prepare(
self._get_dataloader(
dataset=test_dataset,
get_sampler_func=self._get_eval_sampler,
batch_size=self.args.eval_batch_size,
description="test",
seed_workers=False,
)
)

def create_optimizer_and_scheduler(self, num_training_steps: int):
"""
tests/trainer/test_trainer.py (94 changes: 94 additions & 0 deletions)
@@ -18,12 +18,15 @@
import importlib
import json
import math
import multiprocessing
import os
import queue
import random
import re
import subprocess
import sys
import tempfile
import time
import unittest
from functools import partial
from itertools import product
@@ -342,6 +345,20 @@ def __init__(self, a=0, b=0, double_output=False, random_torch=True, **kwargs):
self.hidden_size = 1


class QueueIterableDataset(IterableDataset):
def __init__(self, q):
self.q = q

def __iter__(self):
while True:
try:
item = self.q.get_nowait()
print(item)
yield {"label_ids": item}
except queue.Empty:
break


if is_torch_available():

class SampleIterableDataset(IterableDataset):
@@ -1571,6 +1588,83 @@ def test_get_eval_dataloader_with_persistent_workers(self):
self.assertEqual(first_dataloader, first_dataloader_repeated)
self.assertEqual(second_dataloader, second_dataloader_repeated)

def test_dataloader_prefetch_iterable_occurs(self):
"""Test that prefetching works correctly with iterable datasets."""

# Create a multiprocessing queue and fill it with 5 elements;
# 4 of them (the first batch plus the prefetched batches) should have been
# pulled off the queue once the first batch has been retrieved
q = multiprocessing.Queue()
for i in range(5):
q.put(torch.tensor([i]))

# Create dataset and model
dataset = QueueIterableDataset(q)
model = RegressionModel()

# Configure trainer with 2 workers and prefetch_factor=1
training_args = TrainingArguments(
output_dir="./test-prefetch",
per_device_train_batch_size=1,
dataloader_num_workers=2,
dataloader_prefetch_factor=1,
max_steps=4,
report_to="none",
)

trainer = Trainer(
model=model,
args=training_args,
train_dataset=dataset,
)

# Get the dataloader and fetch one batch
dataloader = trainer.get_train_dataloader()
next(iter(dataloader))
time.sleep(1) # Wait for workers to prefetch the next batch

# The queue should have 1 remaining element after the first batch is fetched
self.assertTrue(q.get_nowait().item() == 4)
with self.assertRaises(queue.Empty):
q.get_nowait()

def test_dataloader_prefetch_iterable_will_exhaust_dataset(self):
"""
Same as test_dataloader_prefetch_iterable_occurs, but confirms that a
higher prefetch factor leaves the queue-backed dataset empty after the
first batch is fetched.
"""
q = multiprocessing.Queue()
for i in range(6): # 6 elements rather than 5
q.put(torch.tensor([i]))

# Create dataset and model
dataset = QueueIterableDataset(q)
model = RegressionModel()

training_args = TrainingArguments(
output_dir="./test-prefetch",
per_device_train_batch_size=1,
dataloader_num_workers=2,
dataloader_prefetch_factor=3, # 3 rather than 1 above
max_steps=4,
report_to="none",
)

trainer = Trainer(
model=model,
args=training_args,
train_dataset=dataset,
)

# Get the dataloader and fetch one batch
dataloader = trainer.get_train_dataloader()
next(iter(dataloader))
time.sleep(1) # Wait for workers to prefetch the next batch

# The queue should now be empty since the workers prefetched all remaining elements
with self.assertRaises(queue.Empty):
q.get_nowait()

@require_liger_kernel
def test_use_liger_kernel_patching(self):
# Ensure any monkey patching is cleaned up for subsequent tests
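The two added tests can be exercised in isolation with standard pytest keyword selection (this invocation is a usage note, not something the PR adds):

    pytest tests/trainer/test_trainer.py -k "dataloader_prefetch_iterable"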