From 2eb684e94c020a594f875e70ca6ce5c2bba4cf54 Mon Sep 17 00:00:00 2001 From: Yoach Lacombe Date: Thu, 24 Oct 2024 12:13:44 +0200 Subject: [PATCH 01/11] Add cross-entropy example in the gradient accumulation docs --- .../usage_guides/gradient_accumulation.md | 190 ++++++++++++++++++ 1 file changed, 190 insertions(+) diff --git a/docs/source/usage_guides/gradient_accumulation.md b/docs/source/usage_guides/gradient_accumulation.md index 5c765b6df7a..83987dd7ef6 100644 --- a/docs/source/usage_guides/gradient_accumulation.md +++ b/docs/source/usage_guides/gradient_accumulation.md @@ -238,3 +238,193 @@ initial model weight is 0.00000 w/ accumulation, the final model weight is 2.04000 w/o accumulation, the final model weight is 2.04000 ``` + +## Gradient accumulation on training samples of variable size + +As was pointed out in this [blog-post](https://huggingface.co/blog/gradient_accumulation), which points out a common error that occurs when perfoming gradient accumulation on training samples of variable size: + +> [...] for gradient accumulation across token-level tasks like causal LM training, the correct loss should be computed by the **total loss across all batches in a gradient accumulation step** divided by the **total number of all non padding tokens in those batches**. This is not the same as the average of the per-batch loss values. + +In other words, some adjustements must be made on losses that operate on a token-level basis. + +### Skeleton code + +```python +from accelerate import Accelerator +import math +gradient_accumulation_steps = 2 +accelerator = Accelerator(gradient_accumulation_steps=gradient_accumulation_steps) +model, optimizer, training_dataloader, scheduler = accelerator.prepare( + model, optimizer, training_dataloader, scheduler +) + +training_iterator = iter(training_dataloader) +num_samples_in_epoch = len(training_dataloader) +remainder = num_samples_in_epoch % gradient_accumulation_steps +remainder = remainder if remainder != 0 else gradient_accumulation_steps +total_updates = math.ceil(num_samples_in_epoch / gradient_accumulation_steps) + + +total_batched_samples = 0 +for update_step in range(total_updates): + # In order to correctly the total number of non-padded tokens on which we'll compute the cross-entropy loss + # we need to pre-load the full local batch - i.e the next per_device_batch_size * accumulation_steps samples + batch_samples = [] + num_batches_in_step = gradient_accumulation_steps if update_step != (total_updates - 1) else remainder + for _ in range(num_batches_in_step): + batch_samples += [next(training_iterator)] + + # get local num items in batch + num_items_in_batch = sum([(batch["labels"].ne(-100)).sum() for batch in batch_samples]) + # to compute it correctly in a multi-device DDP training, we need to gather the total number of items in the full batch. + num_items_in_batch = accelerator.gather_for_metrics(num_items_in_batch).sum().item() + + for batch in batch_samples: + total_batched_samples += 1 + + # Since we performed prefetching, we need to manually set sync_gradients + if total_batched_samples % gradient_accumulation_steps != 0: + accelerator.gradient_state._set_sync_gradients(False) + else: + accelerator.gradient_state._set_sync_gradients(True) + + with accelerator.accumulate(model): + inputs, targets = batch + outputs = model(inputs) + loss = loss_function(outputs, targets) # the loss function shoud sum over samples rather than averaging + + # We multiply by num_processes because the DDP calculates the average gradient across all devices whereas dividing by num_items_in_batch already takes into account all devices + # Same reason for gradient_accumulation_steps, but this times it's Accelerate that calculate the average gradient across the accumulated steps + loss = (loss * gradient_accumulation_steps * accelerator.num_processes) / num_items_in_batch + + accelerator.backward(loss) + optimizer.step() + scheduler.step() + optimizer.zero_grad() +``` + +### Self-contained causal LM example + +```py +import torch +import copy +from accelerate import Accelerator +from accelerate.utils import set_seed +from accelerate.logging import get_logger +from torch.utils.data import Dataset, DataLoader +import math + +# seed +set_seed(0) +logger = get_logger(__name__) + +class MyDataset(Dataset): + def __init__(self, num_samples): + super().__init__() + self.len = num_samples + + def __getitem__(self, index): + input_ids = torch.arange(1, index+2, dtype=torch.float32) + labels = torch.remainder(input_ids, 2) + return {"input_ids": input_ids, "labels": labels} + + def __len__(self): + return self.len + +def collate_fn(features): + input_ids = torch.nn.utils.rnn.pad_sequence([f["input_ids"] for f in features], batch_first=True, padding_value=-100) + labels = torch.nn.utils.rnn.pad_sequence([f["labels"] for f in features], batch_first=True, padding_value=-100) + return {"input_ids": input_ids[..., None], "labels": labels[..., None]} + +# define toy inputs and labels +gradient_accumulation_steps = 2 +batch_size = 4 + +# define accelerator +accelerator = Accelerator(gradient_accumulation_steps=gradient_accumulation_steps) + +# define dataset and dataloader +# for this toy example, we'll compute gradient descent over one single global batch +dataset = MyDataset(batch_size*gradient_accumulation_steps*accelerator.num_processes) +dataloader = DataLoader(dataset, batch_size=batch_size, collate_fn=collate_fn) + +# define model, model_optimizer and loss function +model = torch.nn.Linear(1, 2, bias=False) +model_clone = copy.deepcopy(model) +criterion = torch.nn.CrossEntropyLoss(reduction="sum") # must sum over samples rather than averaging +model_optimizer = torch.optim.SGD(model.parameters(), lr=0.08) + + +logger.warning(f"initial model weight is {model.weight.detach().cpu().squeeze()}", main_process_only=True) +logger.warning(f"initial model clone weight is {model_clone.weight.detach().cpu().squeeze()}", main_process_only=True) + +# prepare artifacts - accelerator handles device placement and dataloader splitting +model, model_optimizer = accelerator.prepare(model, model_optimizer) +dataloader = accelerator.prepare_data_loader(dataloader, device_placement=True) +training_iterator = iter(dataloader) + +num_samples_in_epoch = len(dataloader) +remainder = num_samples_in_epoch % gradient_accumulation_steps +remainder = remainder if remainder != 0 else gradient_accumulation_steps +total_gradient_updates = math.ceil(num_samples_in_epoch / gradient_accumulation_steps) + +total_batched_samples = 0 +for update_step in range(total_gradient_updates): + # In order to correctly the total number of non-padded tokens on which we'll compute the cross-entropy loss + # we need to pre-load the full local batch - i.e the next per_device_batch_size * accumulation_steps samples + batch_samples = [] + num_batches_in_step = gradient_accumulation_steps if update_step != (total_gradient_updates - 1) else remainder + for _ in range(num_batches_in_step): + batch_samples += [next(training_iterator)] + + # get local num items in batch + local_num_items_in_batch = sum([(batch["labels"].ne(-100)).sum() for batch in batch_samples]) + logger.warning(f"Step {update_step} - Device {accelerator.process_index} - num items in the local batch {local_num_items_in_batch}", main_process_only=False) + + # to compute it correctly in a multi-device DDP training, we need to gather the total number of items in the full batch. + num_items_in_batch = accelerator.gather_for_metrics(local_num_items_in_batch).sum().item() + logger.warning(f"Total num items {num_items_in_batch}", main_process_only=True) + + for batch in batch_samples: + inputs, labels = batch["input_ids"], batch["labels"] + total_batched_samples += 1 + + with accelerator.accumulate(model): + # Since we performed prefetching, we need to manually set sync_gradients + if total_batched_samples % gradient_accumulation_steps != 0: + accelerator.gradient_state._set_sync_gradients(False) + else: + accelerator.gradient_state._set_sync_gradients(True) + + outputs = model(inputs) + loss = criterion(outputs.view(-1, 2), labels.view(-1).to(torch.int64)) + + # We multiply by num_processes because the DDP calculates the average gradient across all devices whereas dividing by num_items_in_batch already takes into account all devices + # Same reason for gradient_accumulation_steps, but this times it's Accelerate that calculate the average gradient across the accumulated steps + loss = (loss * gradient_accumulation_steps * accelerator.num_processes) / num_items_in_batch + accelerator.backward(loss) + model_optimizer.step() + model_optimizer.zero_grad() + + +logger.warning(f"Device {accelerator.process_index} - w/ accumulation, the final model weight is {accelerator.unwrap_model(model).weight.detach().cpu().squeeze()}", main_process_only=False) + +# We know do the same operation but on a single device and without gradient accumulation + +if accelerator.is_main_process: + # prepare one single entire batch + dataloader = DataLoader(dataset, batch_size=len(dataset), collate_fn=collate_fn) + full_batch_without_accum = next(iter(dataloader)) + total_inputs, total_labels = full_batch_without_accum["input_ids"], full_batch_without_accum["labels"] + model_clone_optimizer = torch.optim.SGD(model_clone.parameters(), lr=0.08) + + # train the cloned model + loss = torch.nn.CrossEntropyLoss(reduction="mean")(model_clone(total_inputs).view(-1, 2), total_labels.view(-1).to(torch.int64)) + model_clone_optimizer.zero_grad() + loss.backward() + model_clone_optimizer.step() + + # We should have the same final weights. + logger.warning(f"w/o accumulation, the final model weight is {model_clone.weight.detach().cpu().squeeze()}") + +``` \ No newline at end of file From 4d4ed80643b215eb3f66660f5dfd02566ae99e5f Mon Sep 17 00:00:00 2001 From: Yoach Lacombe Date: Thu, 24 Oct 2024 12:19:23 +0200 Subject: [PATCH 02/11] add example of logs --- .../usage_guides/gradient_accumulation.md | 22 +++++++++++++++++++ 1 file changed, 22 insertions(+) diff --git a/docs/source/usage_guides/gradient_accumulation.md b/docs/source/usage_guides/gradient_accumulation.md index 83987dd7ef6..87857f43bfc 100644 --- a/docs/source/usage_guides/gradient_accumulation.md +++ b/docs/source/usage_guides/gradient_accumulation.md @@ -427,4 +427,26 @@ if accelerator.is_main_process: # We should have the same final weights. logger.warning(f"w/o accumulation, the final model weight is {model_clone.weight.detach().cpu().squeeze()}") +``` + +Results on a single device: +``` +initial model weight is tensor([-0.0075, 0.5364]) +initial model clone weight is tensor([-0.0075, 0.5364]) +Step 0 - Device 0 - num items in the local batch 36 +Total num items 36 +Device 0 - w/ accumulation, the final model weight is tensor([0.0953, 0.4337]) +w/o accumulation, the final model weight is tensor([0.0953, 0.4337]) +``` + +Results on a two devices set-up: +``` +initial model weight is tensor([-0.0075, 0.5364]) +initial model clone weight is tensor([-0.0075, 0.5364]) +Step 0 - Device 0 - num items in the local batch 52 +Step 0 - Device 1 - num items in the local batch 84 +Total num items 136 +Device 1 - w/ accumulation, the final model weight is tensor([0.2117, 0.3172]) +Device 0 - w/ accumulation, the final model weight is tensor([0.2117, 0.3172]) +w/o accumulation, the final model weight is tensor([0.2117, 0.3172]) ``` \ No newline at end of file From 3b8c887238957629d53e9f99cd79dd90ed9f2270 Mon Sep 17 00:00:00 2001 From: Yoach Lacombe Date: Thu, 24 Oct 2024 12:25:03 +0200 Subject: [PATCH 03/11] correct skeleton code --- docs/source/usage_guides/gradient_accumulation.md | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/docs/source/usage_guides/gradient_accumulation.md b/docs/source/usage_guides/gradient_accumulation.md index 87857f43bfc..fbf0ad33016 100644 --- a/docs/source/usage_guides/gradient_accumulation.md +++ b/docs/source/usage_guides/gradient_accumulation.md @@ -282,13 +282,13 @@ for update_step in range(total_updates): for batch in batch_samples: total_batched_samples += 1 - # Since we performed prefetching, we need to manually set sync_gradients - if total_batched_samples % gradient_accumulation_steps != 0: - accelerator.gradient_state._set_sync_gradients(False) - else: - accelerator.gradient_state._set_sync_gradients(True) - with accelerator.accumulate(model): + # Since we performed prefetching, we need to manually set sync_gradients + if total_batched_samples % gradient_accumulation_steps != 0: + accelerator.gradient_state._set_sync_gradients(False) + else: + accelerator.gradient_state._set_sync_gradients(True) + inputs, targets = batch outputs = model(inputs) loss = loss_function(outputs, targets) # the loss function shoud sum over samples rather than averaging From c01827c19b15af46066ca7bbf32349328d9042b1 Mon Sep 17 00:00:00 2001 From: Yoach Lacombe Date: Thu, 24 Oct 2024 14:28:44 +0200 Subject: [PATCH 04/11] replace gather_for_metrics with gather --- docs/source/usage_guides/gradient_accumulation.md | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/docs/source/usage_guides/gradient_accumulation.md b/docs/source/usage_guides/gradient_accumulation.md index fbf0ad33016..6faf84db2d3 100644 --- a/docs/source/usage_guides/gradient_accumulation.md +++ b/docs/source/usage_guides/gradient_accumulation.md @@ -277,7 +277,7 @@ for update_step in range(total_updates): # get local num items in batch num_items_in_batch = sum([(batch["labels"].ne(-100)).sum() for batch in batch_samples]) # to compute it correctly in a multi-device DDP training, we need to gather the total number of items in the full batch. - num_items_in_batch = accelerator.gather_for_metrics(num_items_in_batch).sum().item() + num_items_in_batch = accelerator.gather(num_items_in_batch).sum().item() for batch in batch_samples: total_batched_samples += 1 @@ -382,7 +382,7 @@ for update_step in range(total_gradient_updates): logger.warning(f"Step {update_step} - Device {accelerator.process_index} - num items in the local batch {local_num_items_in_batch}", main_process_only=False) # to compute it correctly in a multi-device DDP training, we need to gather the total number of items in the full batch. - num_items_in_batch = accelerator.gather_for_metrics(local_num_items_in_batch).sum().item() + num_items_in_batch = accelerator.gather(local_num_items_in_batch).sum().item() logger.warning(f"Total num items {num_items_in_batch}", main_process_only=True) for batch in batch_samples: From 22cbf9c5c4906384fdd70a5562bc94f0ef2781c6 Mon Sep 17 00:00:00 2001 From: Yoach Lacombe Date: Thu, 24 Oct 2024 14:30:36 +0200 Subject: [PATCH 05/11] batch_size -> per_device_batch_size --- docs/source/usage_guides/gradient_accumulation.md | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/docs/source/usage_guides/gradient_accumulation.md b/docs/source/usage_guides/gradient_accumulation.md index 6faf84db2d3..2c180925b2d 100644 --- a/docs/source/usage_guides/gradient_accumulation.md +++ b/docs/source/usage_guides/gradient_accumulation.md @@ -187,11 +187,11 @@ set_seed(0) x = torch.tensor([1., 2., 3., 4., 5., 6., 7., 8.]) y = torch.tensor([2., 4., 6., 8., 10., 12., 14., 16.]) gradient_accumulation_steps = 4 -batch_size = len(x) // gradient_accumulation_steps +per_device_batch_size = len(x) // gradient_accumulation_steps # define dataset and dataloader dataset = TensorDataset(x, y) -dataloader = DataLoader(dataset, batch_size=batch_size) +dataloader = DataLoader(dataset, batch_size=per_device_batch_size) # define model, optimizer and loss function class SimpleLinearModel(torch.nn.Module): @@ -338,15 +338,15 @@ def collate_fn(features): # define toy inputs and labels gradient_accumulation_steps = 2 -batch_size = 4 +per_device_batch_size = 4 # define accelerator accelerator = Accelerator(gradient_accumulation_steps=gradient_accumulation_steps) # define dataset and dataloader # for this toy example, we'll compute gradient descent over one single global batch -dataset = MyDataset(batch_size*gradient_accumulation_steps*accelerator.num_processes) -dataloader = DataLoader(dataset, batch_size=batch_size, collate_fn=collate_fn) +dataset = MyDataset(per_device_batch_size*gradient_accumulation_steps*accelerator.num_processes) +dataloader = DataLoader(dataset, batch_size=per_device_batch_size, collate_fn=collate_fn) # define model, model_optimizer and loss function model = torch.nn.Linear(1, 2, bias=False) From 395c572dd30ba7c8d18c6a7cff506a567897d69e Mon Sep 17 00:00:00 2001 From: Yoach Lacombe Date: Thu, 24 Oct 2024 14:32:58 +0200 Subject: [PATCH 06/11] remove main_process_only=True --- docs/source/usage_guides/gradient_accumulation.md | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/docs/source/usage_guides/gradient_accumulation.md b/docs/source/usage_guides/gradient_accumulation.md index 2c180925b2d..82d60e79b48 100644 --- a/docs/source/usage_guides/gradient_accumulation.md +++ b/docs/source/usage_guides/gradient_accumulation.md @@ -355,8 +355,8 @@ criterion = torch.nn.CrossEntropyLoss(reduction="sum") # must sum over samples r model_optimizer = torch.optim.SGD(model.parameters(), lr=0.08) -logger.warning(f"initial model weight is {model.weight.detach().cpu().squeeze()}", main_process_only=True) -logger.warning(f"initial model clone weight is {model_clone.weight.detach().cpu().squeeze()}", main_process_only=True) +logger.warning(f"initial model weight is {model.weight.detach().cpu().squeeze()}") +logger.warning(f"initial model clone weight is {model_clone.weight.detach().cpu().squeeze()}") # prepare artifacts - accelerator handles device placement and dataloader splitting model, model_optimizer = accelerator.prepare(model, model_optimizer) @@ -383,7 +383,7 @@ for update_step in range(total_gradient_updates): # to compute it correctly in a multi-device DDP training, we need to gather the total number of items in the full batch. num_items_in_batch = accelerator.gather(local_num_items_in_batch).sum().item() - logger.warning(f"Total num items {num_items_in_batch}", main_process_only=True) + logger.warning(f"Total num items {num_items_in_batch}") for batch in batch_samples: inputs, labels = batch["input_ids"], batch["labels"] From 2e80bf038765f6f7312bbcb57c743588c382a5d9 Mon Sep 17 00:00:00 2001 From: Yoach Lacombe Date: Tue, 29 Oct 2024 16:48:30 +0100 Subject: [PATCH 07/11] add autoregressive example in examples/ --- ..._accumulation_for_autoregressive_models.py | 321 ++++++++++++++++++ 1 file changed, 321 insertions(+) create mode 100644 examples/by_feature/gradient_accumulation_for_autoregressive_models.py diff --git a/examples/by_feature/gradient_accumulation_for_autoregressive_models.py b/examples/by_feature/gradient_accumulation_for_autoregressive_models.py new file mode 100644 index 00000000000..a4cb53c35fb --- /dev/null +++ b/examples/by_feature/gradient_accumulation_for_autoregressive_models.py @@ -0,0 +1,321 @@ +# Copyright 2024 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import argparse +import contextlib +import math +import os + +import torch +from datasets import load_dataset +from torch.optim import AdamW +from torch.utils.data import DataLoader +from transformers import AutoModelForCausalLM, AutoTokenizer, get_constant_schedule, set_seed + +from accelerate import Accelerator, DistributedType + + +######################################################################## +# This is a fully working simple example to use Accelerate +# and perform gradient accumulation on samples of variable size +# +# This example trains a SmolLM base model on WikiText-2 v1 +# in any of the following settings (with the same script): +# - single CPU or single GPU +# - multi GPUS (using PyTorch distributed mode) +# - (multi) TPUs +# - fp16 (mixed-precision) or fp32 (normal precision) +# +# To run it in each of these various modes, follow the instructions +# in the readme for examples: +# https://github.com/huggingface/accelerate/tree/main/examples +# +######################################################################## + + +EVAL_BATCH_SIZE = 32 + + +def get_dataloaders(accelerator: Accelerator, batch_size: int = 16, max_training_samples=500): + """ + Creates a set of `DataLoader`s for the `Salesforce/wikitext` dataset, + using "HuggingFaceTB/SmolLM-360M" as the tokenizer. + + Args: + accelerator (`Accelerator`): + An `Accelerator` object + batch_size (`int`, *optional*): + The batch size for the train and validation DataLoaders. + """ + tokenizer = AutoTokenizer.from_pretrained("HuggingFaceTB/SmolLM-360M") + tokenizer.pad_token = tokenizer.eos_token + with accelerator.local_main_process_first(): + datasets = load_dataset("Salesforce/wikitext", "wikitext-2-v1") + datasets["train"] = datasets["train"].select(range(max_training_samples)) + + def tokenize_function(examples): + # max_length=None => use the model max length (it's actually the default) + outputs = tokenizer(examples["text"], truncation=True, max_length=None, return_attention_mask=False) + return outputs + + # Filter out empty texts + with accelerator.main_process_first(): + datasets = datasets.filter( + lambda x: len(x) > 0, + input_columns="text", + ) + + # Apply the method we just defined to all the examples in all the splits of the dataset + # starting with the main process first: + with accelerator.main_process_first(): + tokenized_datasets = datasets.map( + tokenize_function, + batched=True, + remove_columns=["text"], + ) + + # Filter out empty samples + with accelerator.main_process_first(): + tokenized_datasets = tokenized_datasets.filter( + lambda x: len(x) > 0, + input_columns="input_ids", + ) + + def collate_fn(examples): + # On TPU it's best to pad everything to the same length or training will be very slow. + max_length = 128 if accelerator.distributed_type == DistributedType.XLA else max([len(e["input_ids"]) for e in examples]) + # When using mixed precision we want round multiples of 8/16 + if accelerator.mixed_precision == "fp8": + pad_to_multiple_of = 16 + elif accelerator.mixed_precision != "no": + pad_to_multiple_of = 8 + else: + pad_to_multiple_of = None + + batch = tokenizer.pad( + examples, + padding="max_length", + max_length=max_length + 1, + pad_to_multiple_of=pad_to_multiple_of, + return_tensors="pt", + ) + + batch["labels"] = batch["input_ids"][:, 1:] + batch["input_ids"] = batch["input_ids"][:, :-1] + + batch["labels"] = torch.where(batch["labels"] == tokenizer.pad_token_id, -100, batch["labels"]) + + return batch + + # Instantiate dataloaders. + train_dataloader = DataLoader( + tokenized_datasets["train"], shuffle=False, collate_fn=collate_fn, batch_size=batch_size + ) + eval_dataloader = DataLoader( + tokenized_datasets["validation"], shuffle=False, collate_fn=collate_fn, batch_size=EVAL_BATCH_SIZE + ) + + return train_dataloader, eval_dataloader + + +# For testing only +if os.environ.get("TESTING_MOCKED_DATALOADERS", None) == "1": + from accelerate.test_utils.training import mocked_dataloaders + + get_dataloaders = mocked_dataloaders # noqa: F811 + + +def training_function(config, args): + # For testing only + if os.environ.get("TESTING_MOCKED_DATALOADERS", None) == "1": + config["num_epochs"] = 2 + + gradient_accumulation_steps = int(args.gradient_accumulation_steps) + # Initialize accelerator + if args.with_wandb_tracking: + accelerator = Accelerator( + cpu=args.cpu, mixed_precision=args.mixed_precision, gradient_accumulation_steps=gradient_accumulation_steps, + log_with="wandb" + ) + else: + accelerator = Accelerator( + cpu=args.cpu, mixed_precision=args.mixed_precision, gradient_accumulation_steps=gradient_accumulation_steps + ) + if accelerator.distributed_type == DistributedType.XLA and gradient_accumulation_steps > 1: + raise NotImplementedError( + "Gradient accumulation on TPUs is currently not supported. Pass `gradient_accumulation_steps=1`" + ) + # Sample hyper-parameters for learning rate, batch size, seed and a few other HPs + lr = config["lr"] + num_epochs = int(config["num_epochs"]) + seed = int(config["seed"]) + batch_size = int(config["batch_size"]) + max_grad_norm = config["max_grad_norm"] + + # We need to initialize the trackers we use, and also store our configuration + if args.with_wandb_tracking: + run = os.path.split(__file__)[-1].split(".")[0] + run_name = f"{accelerator.num_processes}GPU-grad{gradient_accumulation_steps}-bs{batch_size}" + accelerator.init_trackers(run, config, init_kwargs={"wandb": {"name": run_name}},) + + set_seed(seed) + train_dataloader, eval_dataloader = get_dataloaders(accelerator, batch_size) + # Instantiate the model (we build the model here so that the seed also control new weights initialization) + model = AutoModelForCausalLM.from_pretrained("HuggingFaceTB/SmolLM-360M") + + # We could avoid this line since the accelerator is set with `device_placement=True` (default value). + # Note that if you are placing tensors on devices manually, this line absolutely needs to be before the optimizer + # creation otherwise training will not work on TPU (`accelerate` will kindly throw an error to make us aware of that). + model = model.to(accelerator.device) + + # Instantiate optimizer + optimizer = AdamW(params=model.parameters(), lr=lr) + + # Instantiate scheduler + lr_scheduler = get_constant_schedule( + optimizer=optimizer, + ) + + # Prepare everything + # There is no specific order to remember, we just need to unpack the objects in the same order we gave them to the + # prepare method. + model, optimizer, train_dataloader, eval_dataloader, lr_scheduler = accelerator.prepare( + model, optimizer, train_dataloader, eval_dataloader, lr_scheduler + ) + + num_samples_in_epoch = len(train_dataloader) + remainder = num_samples_in_epoch % gradient_accumulation_steps + remainder = remainder if remainder != 0 else gradient_accumulation_steps + total_gradient_updates = math.ceil(num_samples_in_epoch / gradient_accumulation_steps) + + total_batched_samples = 0 + # Now we train the model + for epoch in range(num_epochs): + model.train() + training_iterator = iter(train_dataloader) + for update_step in range(total_gradient_updates): + # In order to correctly the total number of non-padded tokens on which we'll compute the cross-entropy loss + # we need to pre-load the full local batch - i.e the next per_device_batch_size * accumulation_steps samples + batch_samples = [] + num_batches_in_step = gradient_accumulation_steps if update_step != (total_gradient_updates - 1) else remainder + for _ in range(num_batches_in_step): + batch_samples += [next(training_iterator)] + # get local num items in batch + local_num_items_in_batch = sum([(batch["labels"].ne(-100)).sum() for batch in batch_samples]) + + # to compute it correctly in a multi-device DDP training, we need to gather the total number of items in the full batch. + num_items_in_batch = accelerator.gather(local_num_items_in_batch).sum().item() + losses = [] + for i, batch in enumerate(batch_samples): + # if we perform gradient accumulation in a multi-devices set-up, we want to avoid unecessary communications when accumulating + # cf: https://muellerzr.github.io/blog/gradient_accumulation.html + ctx = model.no_sync if (i < len(batch_samples) - 1 and accelerator.num_processes > 1) else contextlib.nullcontext + with ctx(): + total_batched_samples += 1 + + outputs = model(**batch, use_cache=False, num_items_in_batch=num_items_in_batch) + loss = outputs.loss + + # We multiply by num_processes because the DDP calculates the average gradient across all devices whereas dividing by num_items_in_batch already takes into account all devices + # Same reason for gradient_accumulation_steps, but this times it's Accelerate that calculate the average gradient across the accumulated steps + # Because the loss is already divided by `num_items_in_batch` in the `transformers` code, we don't need to do it again + loss = (loss * gradient_accumulation_steps * accelerator.num_processes) + accelerator.backward(loss) + losses.append(loss.detach()) + + # Sync gradients and perform optimization steps once every gradient_accumulation_steps + grad_norm = accelerator.clip_grad_norm_(model.parameters(), max_grad_norm) + optimizer.step() + lr_scheduler.step() + optimizer.zero_grad() + + losses = accelerator.gather(sum(losses)).sum().item() / (accelerator.num_processes * gradient_accumulation_steps) + + grad_norm = grad_norm.detach().item() if isinstance(grad_norm, torch.Tensor) else grad_norm + accelerator.print(f"epoch {epoch} - update step {update_step}:: grad norm: {grad_norm} ::train loss: {losses}") + if args.with_wandb_tracking: + accelerator.log( + { + "train/grad_norm": grad_norm, + "train/epoch": epoch, + "train/loss": losses, + }, + step=update_step + total_gradient_updates * epoch, + ) + model.eval() + losses = [] + for step, batch in enumerate(eval_dataloader): + with torch.no_grad(): + outputs = model(**batch, use_cache=False) + eval_loss = outputs.loss + losses.append(accelerator.gather_for_metrics(loss.repeat(EVAL_BATCH_SIZE))) + + losses = torch.cat(losses) + try: + eval_loss = torch.mean(losses) + perplexity = math.exp(eval_loss) + except OverflowError: + perplexity = float("inf") + + # Use accelerator.print to print only on the main process. + accelerator.print(f"epoch {epoch}:: eval perplexity: {perplexity} eval_loss: {eval_loss}") + if args.with_wandb_tracking: + accelerator.log( + { + "eval/perplexity": perplexity, + "eval/loss": eval_loss, + "eval/epoch": epoch, + }, + step=update_step + total_gradient_updates * epoch, + ) + accelerator.end_training() + + +def main(): + parser = argparse.ArgumentParser(description="Simple example of training script.") + parser.add_argument( + "--mixed_precision", + type=str, + default=None, + choices=["no", "fp16", "bf16", "fp8"], + help="Whether to use mixed precision. Choose" + "between fp16 and bf16 (bfloat16). Bf16 requires PyTorch >= 1.10." + "and an Nvidia Ampere GPU.", + ) + + parser.add_argument( + "--gradient_accumulation_steps", + type=int, + default=1, + help="The number of minibatches to be ran before gradients are accumulated.", + ) + parser.add_argument( + "--per_device_batch_size", + type=int, + default=16, + help="The number of minibatches to be ran before gradients are accumulated.", + ) + + parser.add_argument("--cpu", action="store_true", help="If passed, will train on the CPU.") + parser.add_argument( + "--with_wandb_tracking", + action="store_true", + help="Whether to load in wandb from the environment and use them for logging.", + ) + args = parser.parse_args() + config = {"lr": 2e-5, "num_epochs": 3, "seed": 42, "batch_size": args.per_device_batch_size, "max_grad_norm": 1.0} + training_function(config, args) + + +if __name__ == "__main__": + main() From 5e3e8118c64950c949e39e4ef8df6ac10702cd99 Mon Sep 17 00:00:00 2001 From: Yoach Lacombe <52246514+ylacombe@users.noreply.github.com> Date: Tue, 29 Oct 2024 16:49:02 +0100 Subject: [PATCH 08/11] Update docs/source/usage_guides/gradient_accumulation.md Co-authored-by: Marc Sun <57196510+SunMarc@users.noreply.github.com> --- docs/source/usage_guides/gradient_accumulation.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/source/usage_guides/gradient_accumulation.md b/docs/source/usage_guides/gradient_accumulation.md index 82d60e79b48..2e3e601542f 100644 --- a/docs/source/usage_guides/gradient_accumulation.md +++ b/docs/source/usage_guides/gradient_accumulation.md @@ -241,7 +241,7 @@ w/o accumulation, the final model weight is 2.04000 ## Gradient accumulation on training samples of variable size -As was pointed out in this [blog-post](https://huggingface.co/blog/gradient_accumulation), which points out a common error that occurs when perfoming gradient accumulation on training samples of variable size: +As was pointed out in this [blog-post](https://huggingface.co/blog/gradient_accumulation), which points out a common error that occurs when performing gradient accumulation on training samples of variable size: > [...] for gradient accumulation across token-level tasks like causal LM training, the correct loss should be computed by the **total loss across all batches in a gradient accumulation step** divided by the **total number of all non padding tokens in those batches**. This is not the same as the average of the per-batch loss values. From c56c780210cf2cf233db4c574fe52f997bda74c2 Mon Sep 17 00:00:00 2001 From: Yoach Lacombe Date: Tue, 29 Oct 2024 16:55:26 +0100 Subject: [PATCH 09/11] ruff format --- ..._accumulation_for_autoregressive_models.py | 40 ++++++++++++++----- 1 file changed, 30 insertions(+), 10 deletions(-) diff --git a/examples/by_feature/gradient_accumulation_for_autoregressive_models.py b/examples/by_feature/gradient_accumulation_for_autoregressive_models.py index a4cb53c35fb..5428fd974f7 100644 --- a/examples/by_feature/gradient_accumulation_for_autoregressive_models.py +++ b/examples/by_feature/gradient_accumulation_for_autoregressive_models.py @@ -93,7 +93,11 @@ def tokenize_function(examples): def collate_fn(examples): # On TPU it's best to pad everything to the same length or training will be very slow. - max_length = 128 if accelerator.distributed_type == DistributedType.XLA else max([len(e["input_ids"]) for e in examples]) + max_length = ( + 128 + if accelerator.distributed_type == DistributedType.XLA + else max([len(e["input_ids"]) for e in examples]) + ) # When using mixed precision we want round multiples of 8/16 if accelerator.mixed_precision == "fp8": pad_to_multiple_of = 16 @@ -144,9 +148,11 @@ def training_function(config, args): # Initialize accelerator if args.with_wandb_tracking: accelerator = Accelerator( - cpu=args.cpu, mixed_precision=args.mixed_precision, gradient_accumulation_steps=gradient_accumulation_steps, - log_with="wandb" - ) + cpu=args.cpu, + mixed_precision=args.mixed_precision, + gradient_accumulation_steps=gradient_accumulation_steps, + log_with="wandb", + ) else: accelerator = Accelerator( cpu=args.cpu, mixed_precision=args.mixed_precision, gradient_accumulation_steps=gradient_accumulation_steps @@ -166,7 +172,11 @@ def training_function(config, args): if args.with_wandb_tracking: run = os.path.split(__file__)[-1].split(".")[0] run_name = f"{accelerator.num_processes}GPU-grad{gradient_accumulation_steps}-bs{batch_size}" - accelerator.init_trackers(run, config, init_kwargs={"wandb": {"name": run_name}},) + accelerator.init_trackers( + run, + config, + init_kwargs={"wandb": {"name": run_name}}, + ) set_seed(seed) train_dataloader, eval_dataloader = get_dataloaders(accelerator, batch_size) @@ -207,7 +217,9 @@ def training_function(config, args): # In order to correctly the total number of non-padded tokens on which we'll compute the cross-entropy loss # we need to pre-load the full local batch - i.e the next per_device_batch_size * accumulation_steps samples batch_samples = [] - num_batches_in_step = gradient_accumulation_steps if update_step != (total_gradient_updates - 1) else remainder + num_batches_in_step = ( + gradient_accumulation_steps if update_step != (total_gradient_updates - 1) else remainder + ) for _ in range(num_batches_in_step): batch_samples += [next(training_iterator)] # get local num items in batch @@ -219,7 +231,11 @@ def training_function(config, args): for i, batch in enumerate(batch_samples): # if we perform gradient accumulation in a multi-devices set-up, we want to avoid unecessary communications when accumulating # cf: https://muellerzr.github.io/blog/gradient_accumulation.html - ctx = model.no_sync if (i < len(batch_samples) - 1 and accelerator.num_processes > 1) else contextlib.nullcontext + ctx = ( + model.no_sync + if (i < len(batch_samples) - 1 and accelerator.num_processes > 1) + else contextlib.nullcontext + ) with ctx(): total_batched_samples += 1 @@ -229,7 +245,7 @@ def training_function(config, args): # We multiply by num_processes because the DDP calculates the average gradient across all devices whereas dividing by num_items_in_batch already takes into account all devices # Same reason for gradient_accumulation_steps, but this times it's Accelerate that calculate the average gradient across the accumulated steps # Because the loss is already divided by `num_items_in_batch` in the `transformers` code, we don't need to do it again - loss = (loss * gradient_accumulation_steps * accelerator.num_processes) + loss = loss * gradient_accumulation_steps * accelerator.num_processes accelerator.backward(loss) losses.append(loss.detach()) @@ -239,10 +255,14 @@ def training_function(config, args): lr_scheduler.step() optimizer.zero_grad() - losses = accelerator.gather(sum(losses)).sum().item() / (accelerator.num_processes * gradient_accumulation_steps) + losses = accelerator.gather(sum(losses)).sum().item() / ( + accelerator.num_processes * gradient_accumulation_steps + ) grad_norm = grad_norm.detach().item() if isinstance(grad_norm, torch.Tensor) else grad_norm - accelerator.print(f"epoch {epoch} - update step {update_step}:: grad norm: {grad_norm} ::train loss: {losses}") + accelerator.print( + f"epoch {epoch} - update step {update_step}:: grad norm: {grad_norm} ::train loss: {losses}" + ) if args.with_wandb_tracking: accelerator.log( { From 80c720a9749520f0bbd9fe1151b9d92b87d52608 Mon Sep 17 00:00:00 2001 From: Yoach Lacombe Date: Tue, 29 Oct 2024 17:01:02 +0100 Subject: [PATCH 10/11] add grad accum test --- .../gradient_accumulation_for_autoregressive_models.py | 2 +- tests/test_examples.py | 5 +++++ 2 files changed, 6 insertions(+), 1 deletion(-) diff --git a/examples/by_feature/gradient_accumulation_for_autoregressive_models.py b/examples/by_feature/gradient_accumulation_for_autoregressive_models.py index 5428fd974f7..f4c7caddddd 100644 --- a/examples/by_feature/gradient_accumulation_for_autoregressive_models.py +++ b/examples/by_feature/gradient_accumulation_for_autoregressive_models.py @@ -322,7 +322,7 @@ def main(): parser.add_argument( "--per_device_batch_size", type=int, - default=16, + default=2, help="The number of minibatches to be ran before gradients are accumulated.", ) diff --git a/tests/test_examples.py b/tests/test_examples.py index a16dce52f92..18f3c68426c 100644 --- a/tests/test_examples.py +++ b/tests/test_examples.py @@ -54,6 +54,7 @@ "schedule_free.py", "tracking.py", "automatic_gradient_accumulation.py", + "gradient_accumulation_for_autoregressive_models.py", "fsdp_with_peak_mem_tracking.py", "deepspeed_with_config_support.py", "megatron_lm_gpt_pretraining.py", @@ -245,6 +246,10 @@ def test_gradient_accumulation(self): testargs = ["examples/by_feature/gradient_accumulation.py"] run_command(self.launch_args + testargs) + def test_gradient_accumulation_for_autoregressive_models(self): + testargs = ["examples/by_feature/gradient_accumulation_for_autoregressive_models.py"] + run_command(self.launch_args + testargs) + def test_local_sgd(self): testargs = ["examples/by_feature/local_sgd.py"] run_command(self.launch_args + testargs) From e5d2c50b08062cf45afeaf21998d5848fc82657a Mon Sep 17 00:00:00 2001 From: Yoach Lacombe Date: Tue, 29 Oct 2024 17:31:57 +0100 Subject: [PATCH 11/11] update docs --- .../usage_guides/gradient_accumulation.md | 66 ++++++++++++------- 1 file changed, 42 insertions(+), 24 deletions(-) diff --git a/docs/source/usage_guides/gradient_accumulation.md b/docs/source/usage_guides/gradient_accumulation.md index 2e3e601542f..13ce11b0a3c 100644 --- a/docs/source/usage_guides/gradient_accumulation.md +++ b/docs/source/usage_guides/gradient_accumulation.md @@ -252,6 +252,8 @@ In other words, some adjustements must be made on losses that operate on a token ```python from accelerate import Accelerator import math +import contextlib + gradient_accumulation_steps = 2 accelerator = Accelerator(gradient_accumulation_steps=gradient_accumulation_steps) model, optimizer, training_dataloader, scheduler = accelerator.prepare( @@ -279,16 +281,17 @@ for update_step in range(total_updates): # to compute it correctly in a multi-device DDP training, we need to gather the total number of items in the full batch. num_items_in_batch = accelerator.gather(num_items_in_batch).sum().item() - for batch in batch_samples: + for i, batch in enumerate(batch_samples): + # if we perform gradient accumulation in a multi-devices set-up, we want to avoid unecessary communications when accumulating + # cf: https://muellerzr.github.io/blog/gradient_accumulation.html + if (i < len(batch_samples) - 1 and accelerator.num_processes > 1): + ctx = model.no_sync + else: + ctx = contextlib.nullcontext + total_batched_samples += 1 - with accelerator.accumulate(model): - # Since we performed prefetching, we need to manually set sync_gradients - if total_batched_samples % gradient_accumulation_steps != 0: - accelerator.gradient_state._set_sync_gradients(False) - else: - accelerator.gradient_state._set_sync_gradients(True) - + with ctx(): inputs, targets = batch outputs = model(inputs) loss = loss_function(outputs, targets) # the loss function shoud sum over samples rather than averaging @@ -298,9 +301,11 @@ for update_step in range(total_updates): loss = (loss * gradient_accumulation_steps * accelerator.num_processes) / num_items_in_batch accelerator.backward(loss) - optimizer.step() - scheduler.step() - optimizer.zero_grad() + + # Sync gradients and perform optimization steps once every gradient_accumulation_steps + optimizer.step() + scheduler.step() + optimizer.zero_grad() ``` ### Self-contained causal LM example @@ -313,6 +318,7 @@ from accelerate.utils import set_seed from accelerate.logging import get_logger from torch.utils.data import Dataset, DataLoader import math +import contexlib # seed set_seed(0) @@ -385,16 +391,16 @@ for update_step in range(total_gradient_updates): num_items_in_batch = accelerator.gather(local_num_items_in_batch).sum().item() logger.warning(f"Total num items {num_items_in_batch}") - for batch in batch_samples: + for i, batch in enumerate(batch_samples): inputs, labels = batch["input_ids"], batch["labels"] total_batched_samples += 1 - - with accelerator.accumulate(model): - # Since we performed prefetching, we need to manually set sync_gradients - if total_batched_samples % gradient_accumulation_steps != 0: - accelerator.gradient_state._set_sync_gradients(False) - else: - accelerator.gradient_state._set_sync_gradients(True) + # if we perform gradient accumulation in a multi-devices set-up, we want to avoid unecessary communications when accumulating + # cf: https://muellerzr.github.io/blog/gradient_accumulation.html + if (i < len(batch_samples) - 1 and accelerator.num_processes > 1): + ctx = model.no_sync + else: + ctx = contextlib.nullcontext + with ctx(): outputs = model(inputs) loss = criterion(outputs.view(-1, 2), labels.view(-1).to(torch.int64)) @@ -403,8 +409,8 @@ for update_step in range(total_gradient_updates): # Same reason for gradient_accumulation_steps, but this times it's Accelerate that calculate the average gradient across the accumulated steps loss = (loss * gradient_accumulation_steps * accelerator.num_processes) / num_items_in_batch accelerator.backward(loss) - model_optimizer.step() - model_optimizer.zero_grad() + model_optimizer.step() + model_optimizer.zero_grad() logger.warning(f"Device {accelerator.process_index} - w/ accumulation, the final model weight is {accelerator.unwrap_model(model).weight.detach().cpu().squeeze()}", main_process_only=False) @@ -429,7 +435,7 @@ if accelerator.is_main_process: ``` -Results on a single device: +Results on a single device - gradient accumulation steps set to 1 and batch_size set to 8: ``` initial model weight is tensor([-0.0075, 0.5364]) initial model clone weight is tensor([-0.0075, 0.5364]) @@ -439,7 +445,7 @@ Device 0 - w/ accumulation, the final model weight is tensor([0.0953, 0.4337]) w/o accumulation, the final model weight is tensor([0.0953, 0.4337]) ``` -Results on a two devices set-up: +Results on a two devices set-up - gradient accumulation steps set to 2 and batch_size set to 4. ``` initial model weight is tensor([-0.0075, 0.5364]) initial model clone weight is tensor([-0.0075, 0.5364]) @@ -449,4 +455,16 @@ Total num items 136 Device 1 - w/ accumulation, the final model weight is tensor([0.2117, 0.3172]) Device 0 - w/ accumulation, the final model weight is tensor([0.2117, 0.3172]) w/o accumulation, the final model weight is tensor([0.2117, 0.3172]) -``` \ No newline at end of file +``` + +### To go further: + +Please find a complete example script on a real world training run in the examples folder at the path [`accelerate/examples/by_feature/gradient_accumulation_for_autoregressive_models.py`](https://github.com/huggingface/accelerate/blob/main/examples/by_feature/gradient_accumulation_for_autoregressive_models.py). + +Running it on several training configurations with constant global batch size equal to 32 gives the following graph: + +
+ +
+ +Note that the training losses are exactly the same up to training step 20. The small deviation after this training step occurs at the very end of the first epoch, because, by [default](https://huggingface.co/docs/accelerate/en/package_reference/torch_wrappers#accelerate.data_loader.prepare_data_loader.even_batches), the dataloader duplicates the samples at the beginning of the dataset when the total batch size doesn't exactly divide the dataset.