refactor deepspeed dataloader prepare logic (#2238)
* refactor deepspeed dataloader prepare logic

Co-Authored-By: Stas Bekman <[email protected]>

* address comments and fix issues

Co-Authored-By: Stas Bekman <[email protected]>

* further refactor

* add test

* rename test

---------

Co-authored-by: Stas Bekman <[email protected]>
pacman100 and stas00 authored Dec 19, 2023 · 1 parent b052839 · commit a03c361
Showing 2 changed files with 68 additions and 26 deletions.
48 changes: 24 additions & 24 deletions src/accelerate/accelerator.py
@@ -1414,38 +1414,38 @@ def _prepare_deepspeed(self, *args):
         deepspeed_plugin = self.state.deepspeed_plugin
 
         is_dataloader_present = any(isinstance(obj, torch.utils.data.DataLoader) for obj in args)
-        if deepspeed_plugin.deepspeed_config["train_micro_batch_size_per_gpu"] == "auto" or is_dataloader_present:
-            result = [
-                self._prepare_one(obj, first_pass=True) if isinstance(obj, torch.utils.data.DataLoader) else obj
-                for obj in args
-            ]
+        result = [
+            self._prepare_one(obj, first_pass=True) if isinstance(obj, torch.utils.data.DataLoader) else obj
+            for obj in args
+        ]
 
-            batch_sizes = [obj.batch_size for obj in args if hasattr(obj, "batch_size")]
-            if self.split_batches:
-                batch_sizes = [batch_size // self.num_processes for batch_size in batch_sizes]
+        if deepspeed_plugin.deepspeed_config["train_micro_batch_size_per_gpu"] == "auto":
+            if is_dataloader_present:
+                batch_sizes = [obj.batch_size for obj in args if hasattr(obj, "batch_size")]
+                if any(bs is None for bs in batch_sizes):
+                    raise ValueError(
+                        "At least one of the dataloaders passed to `accelerate.prepare()` has `None` as batch size. "
+                        "Please set an integer value in `train_micro_batch_size_per_gpu` in the deepspeed config file "
+                        "or assign integer value to `AcceleratorState().deepspeed_plugin.deepspeed_config['train_micro_batch_size_per_gpu']`."
+                    )
+                if self.split_batches:
+                    batch_sizes = [batch_size // self.num_processes for batch_size in batch_sizes]
 
-            if any(bs is None for bs in batch_sizes):
-                raise ValueError(
-                    "At least one of the dataloaders passed to `accelerate.prepare()` has `None` as batch size. "
-                    "Please set an integer value in `train_micro_batch_size_per_gpu` in the deepspeed config file "
-                    "or assign integer value to `AcceleratorState().deepspeed_plugin.deepspeed_config['train_micro_batch_size_per_gpu']`."
-                )
-            if len(batch_sizes) == 0:
+                batch_size_per_device = min(batch_sizes) if deepspeed_plugin.is_train_batch_min else max(batch_sizes)
+                if len(batch_sizes) > 1:
+                    logger.info(
+                        "Since you passed both train and evaluation dataloader, `is_train_batch_min` (here "
+                        f"{deepspeed_plugin.is_train_batch_min} will decide the `train_batch_size` ({batch_size_per_device})."
+                    )
+            else:
                 raise ValueError(
-                    "When using DeepSpeed `accelerate.prepare()` requires you to pass at least one of training or evaluation dataloaders "
+                    "When using DeepSpeed, `accelerate.prepare()` requires you to pass at least one of training or evaluation dataloaders "
+                    "with `batch_size` attribute returning an integer value "
                     "or alternatively set an integer value in `train_micro_batch_size_per_gpu` in the deepspeed config file "
                     "or assign integer value to `AcceleratorState().deepspeed_plugin.deepspeed_config['train_micro_batch_size_per_gpu']`."
                 )
-
-            batch_size_per_device = min(batch_sizes) if deepspeed_plugin.is_train_batch_min else max(batch_sizes)
-            if len(batch_sizes) > 1:
-                logger.info(
-                    "Since you passed both train and evaluation dataloader, `is_train_batch_min` (here "
-                    f"{deepspeed_plugin.is_train_batch_min} will decide the `train_batch_size` ({batch_size_per_device})."
-                )
         else:
             batch_size_per_device = deepspeed_plugin.deepspeed_config["train_micro_batch_size_per_gpu"]
-            result = [obj for obj in args]
 
         # handle `gradient_accumulation_steps` when the value is `auto`
         deepspeed_plugin.fill_match(
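In plain terms, the refactor above now always runs dataloaders through `_prepare_one` and only derives a batch size from them when the DeepSpeed config leaves `train_micro_batch_size_per_gpu` as `"auto"`. Note also that the `None` check now runs before the `split_batches` division, so a batch-sampler-style dataloader fails with the descriptive `ValueError` rather than a `TypeError` from `None // num_processes`. Below is a minimal standalone sketch of the resulting decision tree; the function and parameter names are illustrative, not Accelerate's API:

# Illustrative sketch of the refactored batch-size resolution, not Accelerate's actual API.
from typing import Optional, Sequence

def resolve_micro_batch_size(
    ds_config: dict,
    batch_sizes: Sequence[Optional[int]],  # one entry per dataloader passed to prepare()
    split_batches: bool = False,
    num_processes: int = 1,
    is_train_batch_min: bool = True,
) -> int:
    if ds_config["train_micro_batch_size_per_gpu"] != "auto":
        # An explicit integer in the DeepSpeed config wins; dataloaders are not consulted.
        return ds_config["train_micro_batch_size_per_gpu"]
    if not batch_sizes:
        raise ValueError("pass at least one dataloader, or set train_micro_batch_size_per_gpu")
    if any(bs is None for bs in batch_sizes):
        # e.g. a DataLoader built with batch_sampler=... exposes batch_size=None
        raise ValueError("a dataloader has batch_size=None; set train_micro_batch_size_per_gpu")
    sizes = list(batch_sizes)
    if split_batches:
        sizes = [bs // num_processes for bs in sizes]
    # With both train and eval dataloaders present, is_train_batch_min picks min() vs max().
    return min(sizes) if is_train_batch_min else max(sizes)

# Example: "auto" with train (32) and eval (64) dataloaders resolves to 32 when is_train_batch_min=True.
assert resolve_micro_batch_size({"train_micro_batch_size_per_gpu": "auto"}, [32, 64]) == 32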
46 changes: 44 additions & 2 deletions tests/deepspeed/test_deepspeed.py
@@ -23,7 +23,7 @@
 
 import torch
 from parameterized import parameterized
-from torch.utils.data import DataLoader
+from torch.utils.data import BatchSampler, DataLoader, RandomSampler, SequentialSampler
 from transformers import AutoModel, AutoModelForCausalLM, get_scheduler
 from transformers.testing_utils import mockenv_context
 from transformers.trainer_utils import set_seed
@@ -337,7 +337,8 @@ def test_prepare_deepspeed(self, optim_type, scheduler_type):
         with self.assertRaises(ValueError) as cm:
             model, optimizer, lr_scheduler = accelerator.prepare(model, optimizer, lr_scheduler)
         self.assertTrue(
-            "When using DeepSpeed `accelerate.prepare()` requires you to pass at least one of training or evaluation dataloaders "
+            "When using DeepSpeed, `accelerate.prepare()` requires you to pass at least one of training or evaluation dataloaders "
+            "with `batch_size` attribute returning an integer value "
             "or alternatively set an integer value in `train_micro_batch_size_per_gpu` in the deepspeed config file "
             "or assign integer value to `AcceleratorState().deepspeed_plugin.deepspeed_config['train_micro_batch_size_per_gpu']`."
             in str(cm.exception)
@@ -506,6 +507,47 @@ def _lr_scheduler_callable(optimizer):
             model, dummy_optimizer, train_dataloader, eval_dataloader, dummy_lr_scheduler
         )
 
+    def test_dataloader_with_batch_sampler(self):
+        deepspeed_plugin = DeepSpeedPlugin(
+            gradient_accumulation_steps=1,
+            gradient_clipping=1.0,
+            zero_stage=2,
+            offload_optimizer_device="cpu",
+            offload_param_device="cpu",
+            zero3_save_16bit_model=False,
+            zero3_init_flag=False,
+        )
+        with mockenv_context(**self.dist_env):
+            accelerator = Accelerator(mixed_precision="fp16", deepspeed_plugin=deepspeed_plugin)
+
+            train_set = RegressionDataset(length=80)
+            eval_set = RegressionDataset(length=20)
+            train_dataloader = DataLoader(
+                train_set, batch_sampler=BatchSampler(RandomSampler(train_set), batch_size=10, drop_last=False)
+            )
+            eval_dataloader = DataLoader(
+                eval_set, batch_sampler=BatchSampler(SequentialSampler(eval_set), batch_size=10, drop_last=False)
+            )
+            model = AutoModel.from_pretrained(GPT2_TINY)
+            optimizer = torch.optim.AdamW(model.parameters(), lr=5e-5)
+            lr_scheduler = get_scheduler(
+                name="linear",
+                optimizer=optimizer,
+                num_warmup_steps=0,
+                num_training_steps=1000,
+            )
+
+            with self.assertRaises(ValueError) as cm:
+                model, optimizer, train_dataloader, eval_dataloader, lr_scheduler = accelerator.prepare(
+                    model, optimizer, train_dataloader, eval_dataloader, lr_scheduler
+                )
+            self.assertTrue(
+                "At least one of the dataloaders passed to `accelerate.prepare()` has `None` as batch size. "
+                "Please set an integer value in `train_micro_batch_size_per_gpu` in the deepspeed config file "
+                "or assign integer value to `AcceleratorState().deepspeed_plugin.deepspeed_config['train_micro_batch_size_per_gpu']`."
+                in str(cm.exception)
+            )
+
     def test_save_checkpoints(self):
         deepspeed_plugin = DeepSpeedPlugin(
             hf_ds_config=self.ds_config_file[ZERO3],
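The new test leans on a PyTorch detail worth spelling out: a `DataLoader` constructed with `batch_sampler=` exposes its `batch_size` attribute as `None`, which is exactly what trips the new `ValueError` when `train_micro_batch_size_per_gpu` is `"auto"`. A quick standalone illustration (not part of the commit):

# A DataLoader built around a batch_sampler reports batch_size=None.
from torch.utils.data import BatchSampler, DataLoader, SequentialSampler

data = list(range(20))
loader = DataLoader(data, batch_sampler=BatchSampler(SequentialSampler(data), batch_size=10, drop_last=False))
print(loader.batch_size)  # None: PyTorch leaves batch_size unset when a batch_sampler drives batching

As the error message itself suggests, the user-side fix is to pin the value explicitly, either in the DeepSpeed config file or via `AcceleratorState().deepspeed_plugin.deepspeed_config["train_micro_batch_size_per_gpu"] = 10` (10 here simply matching the sampler's batch size).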
