From 33fd672924e82ee78b2777c64a0367378692fadf Mon Sep 17 00:00:00 2001
From: zzhhjjj
Date: Mon, 22 Apr 2024 14:39:51 +0000
Subject: [PATCH 01/35] readme

---
 examples/mamba/README.md | 12 ++++++++++++
 1 file changed, 12 insertions(+)

diff --git a/examples/mamba/README.md b/examples/mamba/README.md
index 5c31d07f..8eefa9c2 100644
--- a/examples/mamba/README.md
+++ b/examples/mamba/README.md
@@ -18,6 +18,18 @@ pip install -r requirements.txt
 
 > https://wandb.ai/bouteille/test/reports/Mamba-loss--Vmlldzo2OTgwNDM5
 
+## Bug related to nanotron
+Encountered the following issue when running train_mamba.sh:
+```
+causal_conv1d_cuda.cpython-310-x86_64-linux-gnu.so: undefined symbol: _ZNK3c1017SymbolicShapeMeta18init_is_contiguousEv
+```
+Solved this by doing:
+pip uninstall mamba-ssm
+pip install causal_conv1d==1.1.1
+pip install mamba-ssm --no-cache-dir
+https://github.com/state-spaces/mamba/issues/169
+
+
 ## Credits
 Credits to the following repositories from which the code was adapted:
 - https://github.com/state-spaces/mamba

From b1872e1ed8542d9aee8e659f2336289a89588062 Mon Sep 17 00:00:00 2001
From: Angel Gonzalez
Date: Tue, 7 May 2024 11:35:59 +0200
Subject: [PATCH 02/35] Adding checkpoint after training ends

---
 src/nanotron/config/config.py | 1 +
 src/nanotron/trainer.py       | 5 ++++-
 2 files changed, 5 insertions(+), 1 deletion(-)

diff --git a/src/nanotron/config/config.py b/src/nanotron/config/config.py
index d9946f26..e26fac75 100644
--- a/src/nanotron/config/config.py
+++ b/src/nanotron/config/config.py
@@ -129,6 +129,7 @@ class CheckpointsArgs:
     checkpoints_path: Path
     checkpoint_interval: int
     save_initial_state: Optional[bool] = False
+    save_final_state: Optional[bool] = False
     resume_checkpoint_path: Optional[Path] = None
     checkpoints_path_is_shared_file_system: Optional[bool] = False
 
diff --git a/src/nanotron/trainer.py b/src/nanotron/trainer.py
index 0eda00dc..70d023fb 100644
--- a/src/nanotron/trainer.py
+++ b/src/nanotron/trainer.py
@@ -442,7 +442,10 @@ def train(
             self.save_checkpoint()
 
         dist.barrier()  # let's wait for everyone before leaving
-
+
+        if self.config.checkpoints.save_final_state:
+            self.save_checkpoint()
+
         self.post_training()
 
     def training_step(

From e484d99db07bf0a69d35072fd11b500cb1722f45 Mon Sep 17 00:00:00 2001
From: Tiancheng Chen
Date: Tue, 14 May 2024 18:57:55 +0200
Subject: [PATCH 03/35] wip

---
 src/nanotron/config/parallelism_config.py |  2 ++
 src/nanotron/models/llama.py              | 32 +++++++++++++++++++----
 2 files changed, 29 insertions(+), 5 deletions(-)

diff --git a/src/nanotron/config/parallelism_config.py b/src/nanotron/config/parallelism_config.py
index 5912425b..321ee045 100644
--- a/src/nanotron/config/parallelism_config.py
+++ b/src/nanotron/config/parallelism_config.py
@@ -23,6 +23,7 @@ class ParallelismArgs:
         pp_engine: Pipeline engine to use between "1f1b" and "afab"
         tp_mode: TP mode to use between "all_reduce" and "reduce_scatter": all_reduce is normal, reduce_scatter activate sequence parallelism
         tp_linear_async_communication: Whether to use async communication in TP linear layers
+        recompute_layer: Whether to recompute each Transformer layer to save memory.
""" dp: int @@ -31,6 +32,7 @@ class ParallelismArgs: pp_engine: Optional[PipelineEngine] = None tp_mode: Optional[TensorParallelLinearMode] = None tp_linear_async_communication: Optional[bool] = None + recompute_layer: bool = False expert_parallel_size: int = 1 diff --git a/src/nanotron/models/llama.py b/src/nanotron/models/llama.py index 32aab9cd..a439768b 100644 --- a/src/nanotron/models/llama.py +++ b/src/nanotron/models/llama.py @@ -18,6 +18,7 @@ import torch from torch import nn +from torch.utils.checkpoint import CheckpointFunction from nanotron import distributed as dist from nanotron import logging @@ -617,12 +618,14 @@ def __init__( self.post_attention_layernorm = TritonRMSNorm(config.hidden_size, eps=config.rms_norm_eps) self.mlp = MLP(config=config, parallel_config=parallel_config, tp_pg=tp_pg) - - def forward( + + self.recompute_layer = parallel_config.recompute_layer + + def _core_forward( self, hidden_states: Union[torch.Tensor, TensorPointer], sequence_mask: Union[torch.Tensor, TensorPointer], - ) -> Dict[str, Union[torch.Tensor, TensorPointer]]: + ) -> List[Union[torch.Tensor, TensorPointer]]: residual = hidden_states hidden_states = self.input_layernorm(hidden_states) @@ -635,12 +638,31 @@ def forward( hidden_states = self.mlp(hidden_states=hidden_states)["hidden_states"] hidden_states = hidden_states + residual + return hidden_states, output["sequence_mask"] + + def _checkpointed_forward( + self, + hidden_states: torch.Tensor, + sequence_mask: torch.Tensor, + ) -> List[torch.Tensor]: + return CheckpointFunction.apply(self._core_forward, hidden_states, sequence_mask) + + def forward( + self, + hidden_states: Union[torch.Tensor, TensorPointer], + sequence_mask: Union[torch.Tensor, TensorPointer], + ) -> Dict[str, Union[torch.Tensor, TensorPointer]]: + + if self.recompute_layer: + hidden_states, sequence_mask = self._checkpointed_forward(hidden_states, sequence_mask) + else: + hidden_states, sequence_mask = self._core_forward(hidden_states, sequence_mask) + return { "hidden_states": hidden_states, - "sequence_mask": output["sequence_mask"], + "sequence_mask": sequence_mask, } - class Embedding(nn.Module, AttachableStore): def __init__(self, tp_pg: dist.ProcessGroup, config: LlamaConfig, parallel_config: Optional[ParallelismArgs]): super().__init__() From 7e15516cf282cc8b1f10b34e5334615f4e124c60 Mon Sep 17 00:00:00 2001 From: Tiancheng Chen Date: Tue, 14 May 2024 23:26:40 +0200 Subject: [PATCH 04/35] layer recompute --- src/nanotron/models/llama.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/nanotron/models/llama.py b/src/nanotron/models/llama.py index a439768b..cb1b4d86 100644 --- a/src/nanotron/models/llama.py +++ b/src/nanotron/models/llama.py @@ -645,7 +645,7 @@ def _checkpointed_forward( hidden_states: torch.Tensor, sequence_mask: torch.Tensor, ) -> List[torch.Tensor]: - return CheckpointFunction.apply(self._core_forward, hidden_states, sequence_mask) + return CheckpointFunction.apply(self._core_forward, True, hidden_states, sequence_mask) def forward( self, @@ -653,7 +653,7 @@ def forward( sequence_mask: Union[torch.Tensor, TensorPointer], ) -> Dict[str, Union[torch.Tensor, TensorPointer]]: - if self.recompute_layer: + if self.recompute_layer and not isinstance(hidden_states, TensorPointer): hidden_states, sequence_mask = self._checkpointed_forward(hidden_states, sequence_mask) else: hidden_states, sequence_mask = self._core_forward(hidden_states, sequence_mask) From 7dd5beb36cbfbcefa2e8a67ed6186908f0e77749 Mon Sep 17 00:00:00 2001 From: 
Tiancheng Chen Date: Thu, 16 May 2024 18:54:55 +0200 Subject: [PATCH 05/35] fix row parallel --- .../parallel/tensor_parallel/functional.py | 74 ++++++++++++++----- tests/test_tensor_parallel.py | 29 ++++++-- 2 files changed, 76 insertions(+), 27 deletions(-) diff --git a/src/nanotron/parallel/tensor_parallel/functional.py b/src/nanotron/parallel/tensor_parallel/functional.py index fdef48ac..1c4db5de 100644 --- a/src/nanotron/parallel/tensor_parallel/functional.py +++ b/src/nanotron/parallel/tensor_parallel/functional.py @@ -387,8 +387,7 @@ def backward(ctx, grad_output): group = ctx.group use_bias = ctx.use_bias - handle_0: Optional[dist.Work] = None - handle_1: Optional[dist.Work] = None + handle: Optional[dist.Work] = None # TODO @thomasw21: gather along another dimension sharded_batch_size, *rest_size = grad_output.shape @@ -412,31 +411,69 @@ def backward(ctx, grad_output): # https://github.com/pytorch/pytorch/blob/c47cf9bc7f9e02f649ab4ed53fe4d35732c92ab6/torch/_refs/__init__.py#L2761 grad_output = grad_output.contiguous() - handle_0 = dist.all_gather_into_tensor(total_grad_output, grad_output, group=group, async_op=True) - - grad_tensor = grad_output.matmul(weight) + handle = dist.all_gather_into_tensor(total_grad_output, grad_output, group=group, async_op=True) - # wait for the first all_gather to finish before starting the second all_gather - if handle_0 is not None: - handle_0.wait() - - # TODO @thomasw21: gather along another dimension - sharded_batch_size, *rest_size = grad_tensor.shape + # total_grad_output: [b, s, h_out] + # weight: [h_out, h_in/n] + # total_grad_tensor: [b, s, h_in/n] + # grad_output: [b/n, s, h_out] + sharded_batch_size, *rest_size_grad_output = grad_output.shape + rest_size_grad_tensor = rest_size_grad_output[:-1] + [weight.shape[1]] if group.size() == 1: - total_grad_tensor = grad_tensor + total_grad_tensor = grad_output.matmul(weight) else: unsharded_batch_size = sharded_batch_size * group.size() - total_grad_tensor = torch.empty( unsharded_batch_size, - *rest_size, - device=grad_tensor.device, - dtype=grad_tensor.dtype, + *rest_size_grad_tensor, + device=grad_output.device, + dtype=grad_output.dtype, requires_grad=False, ) + before_shard_grad_tensor, same_device_shard_grad_tensor, after_shard_grad_tensor = torch.split( + total_grad_tensor, + split_size_or_sections=[ + sharded_batch_size * dist.get_rank(group), + sharded_batch_size, + sharded_batch_size * (group.size() - dist.get_rank(group) - 1), + ], + dim=0, + ) + # compute local shard + torch.mm( + input=grad_output.view(-1, grad_output.shape[-1]), + mat2=weight, + out=same_device_shard_grad_tensor.view(-1, weight.shape[1]), + ) - handle_1 = dist.all_gather_into_tensor(total_grad_tensor, grad_tensor, group=group, async_op=True) + if handle is not None: + handle.wait() + + before_shard_grad_output, _, after_shard_grad_output = torch.split( + total_grad_output, + split_size_or_sections=[ + sharded_batch_size * dist.get_rank(group), + sharded_batch_size, + sharded_batch_size * (group.size() - dist.get_rank(group) - 1), + ], + dim=0, + ) + + # before shard compute + if before_shard_grad_tensor.numel() > 0: + torch.mm( + input=before_shard_grad_output.view(-1, before_shard_grad_output.shape[-1]), + mat2=weight, + out=before_shard_grad_tensor.view(-1, weight.shape[1]), + ) + # after shard compute + if after_shard_grad_tensor.numel() > 0: + torch.mm( + input=after_shard_grad_output.view(-1, after_shard_grad_output.shape[-1]), + mat2=weight, + out=after_shard_grad_tensor.view(-1, weight.shape[1]), + ) # 
Convert the tensor shapes to 2D for execution compatibility tensor = tensor.contiguous() @@ -454,9 +491,6 @@ def backward(ctx, grad_output): grad_weight = total_grad_output.t().matmul(tensor) grad_bias = total_grad_output.sum(dim=0) if use_bias else None - if handle_1 is not None: - handle_1.wait() - return total_grad_tensor, grad_weight, grad_bias, None, None diff --git a/tests/test_tensor_parallel.py b/tests/test_tensor_parallel.py index 127ba2fa..f5dcaeb0 100644 --- a/tests/test_tensor_parallel.py +++ b/tests/test_tensor_parallel.py @@ -208,14 +208,19 @@ def _test_row_linear(parallel_context: ParallelContext, tp_mode: TensorParallelL random_input = torch.randn(batch_size, in_features, device="cuda") # synchronize random_input across tp dist.all_reduce(random_input, op=dist.ReduceOp.AVG, group=parallel_context.tp_pg) - + random_input.requires_grad = True # Row linear receives as input sharded input - random_sharded_input = random_input[ - :, - dist.get_rank(parallel_context.tp_pg) - * in_features_per_rank : (dist.get_rank(parallel_context.tp_pg) + 1) - * in_features_per_rank, - ] + random_sharded_input = ( + random_input[ + :, + dist.get_rank(parallel_context.tp_pg) + * in_features_per_rank : (dist.get_rank(parallel_context.tp_pg) + 1) + * in_features_per_rank, + ] + .detach() + .clone() + ) + random_sharded_input.requires_grad = True # Test that we get the same output after forward pass # TODO @kunhao: We may want to have our custom error type @@ -261,6 +266,16 @@ def _test_row_linear(parallel_context: ParallelContext, tp_mode: TensorParallelL else: assert row_linear.bias is None + torch.testing.assert_close( + random_sharded_input.grad, + random_input.grad[ + :, + dist.get_rank(parallel_context.tp_pg) + * in_features_per_rank : (dist.get_rank(parallel_context.tp_pg) + 1) + * in_features_per_rank, + ], + ) + parallel_context.destroy() From bcf405d9af2028773d6d76cd4ff658540b87a3f1 Mon Sep 17 00:00:00 2001 From: AleHD Date: Thu, 27 Jun 2024 11:56:53 +0000 Subject: [PATCH 06/35] Implemented global memory buffer to reduce activation memory of differentiable distributed operations --- src/nanotron/config/config.py | 2 +- .../distributed_differentiable_primitives.py | 17 +++-------------- src/nanotron/parallel/utils.py | 19 +++++++++++++++++++ src/nanotron/utils.py | 17 +++++++++++++++++ 4 files changed, 40 insertions(+), 15 deletions(-) diff --git a/src/nanotron/config/config.py b/src/nanotron/config/config.py index d5b9976f..d72ea97f 100644 --- a/src/nanotron/config/config.py +++ b/src/nanotron/config/config.py @@ -111,7 +111,7 @@ def __post_init__(self): class DataArgs: """Arguments related to the data and data files processing""" - dataset: Union[PretrainDatasetsArgs, NanosetDatasetsArgs] + dataset: Optional[Union[PretrainDatasetsArgs, NanosetDatasetsArgs]] seed: Optional[int] num_loading_workers: Optional[int] = 1 diff --git a/src/nanotron/parallel/tensor_parallel/distributed_differentiable_primitives.py b/src/nanotron/parallel/tensor_parallel/distributed_differentiable_primitives.py index 873d77df..57a67c42 100644 --- a/src/nanotron/parallel/tensor_parallel/distributed_differentiable_primitives.py +++ b/src/nanotron/parallel/tensor_parallel/distributed_differentiable_primitives.py @@ -19,6 +19,7 @@ from nanotron import distributed as dist from nanotron.distributed import ProcessGroup +from nanotron.parallel.utils import MemoryBuffer class DifferentiableIdentity(torch.autograd.Function): @@ -67,13 +68,7 @@ def forward(ctx, tensor, group: Optional[ProcessGroup]): group = 
torch_dist.distributed_c10d._get_default_group() unsharded_batch_size = sharded_batch_size * group.size() - unsharded_tensor = torch.empty( - unsharded_batch_size, - *rest_size, - device=tensor.device, - dtype=tensor.dtype, - requires_grad=tensor.requires_grad, - ) + unsharded_tensor = MemoryBuffer().get("dist", (unsharded_batch_size, *rest_size), dtype=tensor.dtype) # `tensor` can sometimes not be contiguous # https://cs.github.com/pytorch/pytorch/blob/2b267fa7f28e18ca6ea1de4201d2541a40411457/torch/distributed/nn/functional.py#L317 @@ -108,13 +103,7 @@ def forward(ctx, tensor, group: Optional[ProcessGroup]): # https://cs.github.com/pytorch/pytorch/blob/2b267fa7f28e18ca6ea1de4201d2541a40411457/torch/distributed/nn/functional.py#L305 tensor = tensor.contiguous() - sharded_tensor = torch.empty( - unsharded_batch_size // group.size(), - *rest_size, - device=tensor.device, - dtype=tensor.dtype, - requires_grad=tensor.requires_grad, - ) + sharded_tensor = MemoryBuffer().get("dist", (unsharded_batch_size//group.size(), *rest_size), dtype=tensor.dtype) dist.reduce_scatter_tensor(sharded_tensor, tensor, group=group, op=dist.ReduceOp.SUM) return sharded_tensor diff --git a/src/nanotron/parallel/utils.py b/src/nanotron/parallel/utils.py index b9ac12ae..eb4e441d 100644 --- a/src/nanotron/parallel/utils.py +++ b/src/nanotron/parallel/utils.py @@ -1,13 +1,32 @@ import functools +import operator import os +import torch from torch import nn from nanotron import distributed as dist +from nanotron.utils import Singleton from nanotron.parallel import ParallelContext from nanotron.parallel.tied_parameters import get_tied_id_to_param +class MemoryBuffer(metaclass=Singleton): + """ + Global memory buffer to store intermediate activations that need not to be cached for the backward pass. + """ + + def __init__(self): + self.buffer = {} + + def get(self, name: str, shape: tuple[int], dtype: torch.dtype = torch.bfloat16) -> torch.Tensor: + required_numel = functools.reduce(operator.mul, shape, 1) + if (name, dtype) not in self.buffer or self.buffer[name, dtype].numel() < required_numel: + self.buffer[name, dtype] = torch.empty(required_numel, dtype=dtype, device=torch.cuda.current_device(), + requires_grad=False) + return self.buffer[name, dtype][:required_numel].view(shape) + + def assert_cuda_max_connections_set_to_1(func): flag_is_set_to_1 = None diff --git a/src/nanotron/utils.py b/src/nanotron/utils.py index 14fe1ca8..8065962b 100644 --- a/src/nanotron/utils.py +++ b/src/nanotron/utils.py @@ -15,6 +15,23 @@ from nanotron import distributed as dist +class Singleton(type): + """ + Singleton metaclass. + Create objects using this class as the metaclass to enable singleton behaviour. + For instance: + ``` + class Logger(metaclass=Singleton): + ... + ``` + """ + _instances = {} + def __call__(cls, *args, **kwargs): + if cls not in cls._instances: + cls._instances[cls] = super(Singleton, cls).__call__(*args, **kwargs) + return cls._instances[cls] + + class ContextManagers: """ Wrapper for `contextlib.ExitStack` which enters a collection of context managers. 
Adaptation of `ContextManagers` From ed1ca7d0b55b07696d1c622d713c793f3ca53e28 Mon Sep 17 00:00:00 2001 From: AleHD Date: Thu, 27 Jun 2024 14:29:14 +0000 Subject: [PATCH 07/35] GLU fusion --- src/nanotron/models/llama.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/src/nanotron/models/llama.py b/src/nanotron/models/llama.py index ca8894b9..d310fe2a 100644 --- a/src/nanotron/models/llama.py +++ b/src/nanotron/models/llama.py @@ -163,8 +163,7 @@ def __init__( bias=False, async_communication=tp_linear_async_communication and tp_mode is TensorParallelLinearMode.REDUCE_SCATTER, ) - # TODO @nouamane: why can't we torch.jit.script GLUActivation? - self.split_silu_mul = GLUActivation(config.hidden_act) + self.split_silu_mul = torch.compile(GLUActivation(config.hidden_act)) def forward(self, hidden_states): # [seq_length, batch_size, hidden_dim] merged_states = self.gate_up_proj(hidden_states) From 9b0de5be04afb9cac631399593aef8de6aa852a6 Mon Sep 17 00:00:00 2001 From: AleHD Date: Thu, 27 Jun 2024 14:42:42 +0000 Subject: [PATCH 08/35] precommit --- src/nanotron/models/llama.py | 2 +- .../distributed_differentiable_primitives.py | 4 +++- src/nanotron/parallel/utils.py | 7 ++++--- src/nanotron/utils.py | 8 +++++--- 4 files changed, 13 insertions(+), 8 deletions(-) diff --git a/src/nanotron/models/llama.py b/src/nanotron/models/llama.py index d310fe2a..3319b0ef 100644 --- a/src/nanotron/models/llama.py +++ b/src/nanotron/models/llama.py @@ -14,7 +14,7 @@ # limitations under the License. """PyTorch LLaMa model.""" -from typing import Dict, Optional, Union, List +from typing import Dict, Optional, Union import torch from torch import nn diff --git a/src/nanotron/parallel/tensor_parallel/distributed_differentiable_primitives.py b/src/nanotron/parallel/tensor_parallel/distributed_differentiable_primitives.py index 57a67c42..aa460cc6 100644 --- a/src/nanotron/parallel/tensor_parallel/distributed_differentiable_primitives.py +++ b/src/nanotron/parallel/tensor_parallel/distributed_differentiable_primitives.py @@ -103,7 +103,9 @@ def forward(ctx, tensor, group: Optional[ProcessGroup]): # https://cs.github.com/pytorch/pytorch/blob/2b267fa7f28e18ca6ea1de4201d2541a40411457/torch/distributed/nn/functional.py#L305 tensor = tensor.contiguous() - sharded_tensor = MemoryBuffer().get("dist", (unsharded_batch_size//group.size(), *rest_size), dtype=tensor.dtype) + sharded_tensor = MemoryBuffer().get( + "dist", (unsharded_batch_size // group.size(), *rest_size), dtype=tensor.dtype + ) dist.reduce_scatter_tensor(sharded_tensor, tensor, group=group, op=dist.ReduceOp.SUM) return sharded_tensor diff --git a/src/nanotron/parallel/utils.py b/src/nanotron/parallel/utils.py index eb4e441d..f694b0e6 100644 --- a/src/nanotron/parallel/utils.py +++ b/src/nanotron/parallel/utils.py @@ -6,9 +6,9 @@ from torch import nn from nanotron import distributed as dist -from nanotron.utils import Singleton from nanotron.parallel import ParallelContext from nanotron.parallel.tied_parameters import get_tied_id_to_param +from nanotron.utils import Singleton class MemoryBuffer(metaclass=Singleton): @@ -22,8 +22,9 @@ def __init__(self): def get(self, name: str, shape: tuple[int], dtype: torch.dtype = torch.bfloat16) -> torch.Tensor: required_numel = functools.reduce(operator.mul, shape, 1) if (name, dtype) not in self.buffer or self.buffer[name, dtype].numel() < required_numel: - self.buffer[name, dtype] = torch.empty(required_numel, dtype=dtype, device=torch.cuda.current_device(), - requires_grad=False) + self.buffer[name, 
dtype] = torch.empty(
+                required_numel, dtype=dtype, device=torch.cuda.current_device(), requires_grad=False
+            )
         return self.buffer[name, dtype][:required_numel].view(shape)
 
diff --git a/src/nanotron/utils.py b/src/nanotron/utils.py
index 8065962b..b3831801 100644
--- a/src/nanotron/utils.py
+++ b/src/nanotron/utils.py
@@ -1,11 +1,10 @@
 import functools
 import inspect
-import math
 import os
 import random
 import socket
 from contextlib import ExitStack, contextmanager
-from typing import Callable, ContextManager, List, Optional
+from typing import ContextManager, List, Optional
 
 import torch
 from packaging import version
@@ -25,7 +24,9 @@ class Logger(metaclass=Singleton):
         ...
     ```
     """
+
     _instances = {}
+
     def __call__(cls, *args, **kwargs):
         if cls not in cls._instances:
             cls._instances[cls] = super(Singleton, cls).__call__(*args, **kwargs)
@@ -69,7 +70,7 @@ def main_rank_first(group: dist.ProcessGroup):
 @contextmanager
 def local_ranks_zero_first(group: Optional[dist.ProcessGroup] = None):
     """Context manager that executes the code in the context with all the local rank zero of the group going first.
-    Usefull to run only once per node first (e.g. to create local files, etc)
+    Useful to run only once per node first (e.g. to create local files, etc)
     """
     is_main = int(os.environ.get("LOCAL_RANK", 0)) == 0
     if is_main:
@@ -140,6 +141,7 @@ def get_untyped_storage(tensor: torch.Tensor) -> torch.UntypedStorage:
     else:
         return tensor.storage().untyped()
 
+
 def tensor_from_untyped_storage(untyped_storage: torch.UntypedStorage, dtype: torch.dtype):
     # TODO @thomasw21: Figure out what's the best Pytorch way of building a tensor from a storage.
     device = untyped_storage.device

From ed5a11c291e1988e3a86d74a3fba99be9ed6f57f Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?X=CE=BBRI-U5?=
Date: Mon, 8 Jul 2024 17:05:47 +0700
Subject: [PATCH 09/35] Update README.md

---
 examples/doremi/README.md | 4 ++++
 1 file changed, 4 insertions(+)

diff --git a/examples/doremi/README.md b/examples/doremi/README.md
index 5a726bd1..dfc9ea40 100644
--- a/examples/doremi/README.md
+++ b/examples/doremi/README.md
@@ -87,3 +87,7 @@ For evaluation, we do uniform sampling on the test set to evaluate a 2.5B model
 - 2.5B llama trained using the optimized weights: https://huggingface.co/nanotron/doremi-llama-2.5b-optimized-weights
 
 and the dataset: https://huggingface.co/datasets/nanotron/the-pile-for-doremi
+
+#### Thoughts
+
+DoReMi is useful when you don't initially know what a good distribution for your training data would be, or when you want a quick way to find a better baseline than the uniform distribution before tuning the data distribution by hand. In my previous experiments, DoReMi matched the pretraining performance of the distribution used for Mamba training but couldn't outperform it. I suspect it doesn't work well when the differences are subtle, i.e. when the gap between your known best distribution and a better distribution isn't significant.
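To illustrate the pattern introduced by the memory-buffer patches above (the `Singleton` metaclass in `src/nanotron/utils.py` and the `MemoryBuffer` in `src/nanotron/parallel/utils.py`), here is a minimal, CPU-runnable sketch of how a process-wide scratch buffer avoids re-allocating the all-gather workspace on every call. It mirrors the patched code but drops the CUDA-specific `device=` argument; the buffer name and the final assertion are illustrative only, not part of the patches.

```python
import functools
import operator

import torch


class Singleton(type):
    """Metaclass: every instantiation of a class using it returns the same object."""

    _instances = {}

    def __call__(cls, *args, **kwargs):
        if cls not in cls._instances:
            cls._instances[cls] = super().__call__(*args, **kwargs)
        return cls._instances[cls]


class MemoryBuffer(metaclass=Singleton):
    """Reusable scratch space keyed by (name, dtype), grown only when a larger shape is requested."""

    def __init__(self):
        self.buffer = {}

    def get(self, name: str, shape: tuple, dtype: torch.dtype = torch.bfloat16) -> torch.Tensor:
        required_numel = functools.reduce(operator.mul, shape, 1)
        if (name, dtype) not in self.buffer or self.buffer[name, dtype].numel() < required_numel:
            # Allocate (or grow) the flat backing tensor once; later calls slice a view out of it.
            self.buffer[name, dtype] = torch.empty(required_numel, dtype=dtype, requires_grad=False)
        return self.buffer[name, dtype][:required_numel].view(shape)


# Two requests under the same name reuse the same storage instead of allocating twice.
a = MemoryBuffer().get("allgather", (4, 8))
b = MemoryBuffer().get("allgather", (2, 8))
assert a.data_ptr() == b.data_ptr()
```

The buffer is only ever grown, never freed, so repeated calls with the same name and dtype reuse a single allocation; this is acceptable here because, as the patch's docstring notes, the gathered activations do not need to be kept around for the backward pass.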
From d5cf7c42896645bad0b73c48641bf68085b62e0a Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?X=CE=BBRI-U5?=
Date: Mon, 8 Jul 2024 17:07:18 +0700
Subject: [PATCH 10/35] Update README.md

---
 examples/mup/README.md | 5 +++++
 1 file changed, 5 insertions(+)

diff --git a/examples/mup/README.md b/examples/mup/README.md
index c86850ca..ed94c1fb 100644
--- a/examples/mup/README.md
+++ b/examples/mup/README.md
@@ -32,3 +32,8 @@ We trained a 350m model with spectral µTransfer and standard parametrization us
 Please check the directory [[./examples/mup/configs]](/examples/mup/configs) for the configurations we used to reproduce the experiments.
 
 ![LLaMA](./assets/llama.png)
+
+
+#### Thoughts
+
+For Spectral MuP, we only tested it on an MLP [link] and a 300m LLaMA [link] (there are links to the experiment configs in the mup readme). However, when we tested it on 1B/8B models, iirc the loss blew up for some reason. So we'd recommend trying μTransfer, not spectral μTransfer.

From 803b6da3233a642a0ba7a62484310d1496db81dc Mon Sep 17 00:00:00 2001
From: AleHD
Date: Tue, 16 Jul 2024 11:39:32 +0200
Subject: [PATCH 11/35] Wrong backward fixed

---
 .../parallel/tensor_parallel/column_linear.py | 62 +++++++++++++++++++
 .../distributed_differentiable_primitives.py  | 27 +++++---
 .../parallel/tensor_parallel/functional.py    |  4 +-
 3 files changed, 85 insertions(+), 8 deletions(-)
 create mode 100644 src/nanotron/parallel/tensor_parallel/column_linear.py

diff --git a/src/nanotron/parallel/tensor_parallel/column_linear.py b/src/nanotron/parallel/tensor_parallel/column_linear.py
new file mode 100644
index 00000000..eaab5abe
--- /dev/null
+++ b/src/nanotron/parallel/tensor_parallel/column_linear.py
@@ -0,0 +1,62 @@
+from typing import Optional
+
+import torch
+from torch.nn import functional as F
+
+import nanotron.distributed as dist
+from nanotron.parallel.utils import MemoryBuffer
+
+
+class ColumnLinearContextParallel(torch.autograd.Function):
+    """
+    Column linear with memory_buffer for the allgather, context parallel
+    enabled (i.e. tp_mode = TensorParallelLinearMode.REDUCE_SCATTER) and
+    async communication disabled.
+    """
+    @staticmethod
+    def forward(ctx, input: torch.Tensor, weight: torch.Tensor,
+                bias: Optional[torch.Tensor], group: dist.ProcessGroup):
+
+        # Prepare context.
+        ctx.save_for_backward(input, weight, bias)
+        ctx.group = group
+
+        # Do allgather.
+        sharded_batch_size, *rest_size = input.shape
+        unsharded_batch_size = sharded_batch_size * group.size()
+        total_input = MemoryBuffer().get("allgather", (unsharded_batch_size, *rest_size), dtype=input.dtype)
+        dist.all_gather_into_tensor(total_input, input.contiguous(), group=group)
+
+        # Get linear output.
+        out = F.linear(total_input, weight, bias)
+        return out
+
+    @staticmethod
+    def backward(ctx, grad_output: torch.Tensor):
+        # Allgather the inputs again.
+        input, weight, bias = ctx.saved_tensors
+        group = ctx.group
+        sharded_batch_size, *rest_size = input.shape
+        total_input = sharded_batch_size * group.size()
+        unsharded_batch_size = sharded_batch_size * group.size()
+        total_input = MemoryBuffer().get("allgather", (unsharded_batch_size, *rest_size), dtype=input.dtype)
+        dist.all_gather_into_tensor(total_input, input.contiguous(), group=group)
+
+        # Get the grad_output and total_input on the correct views to be able to transpose them below.
+ grad_output = grad_output.contiguous() + assert grad_output.dim() == 3 + grad_output = grad_output.view(grad_output.size(0) * grad_output.size(1), grad_output.size(2)) + total_input = total_input.view(total_input.size(0) * total_input.size(1), total_input.size(2)) + + # Compute gradients. + grad_input = grad_output @ weight + sub_grad_input = torch.empty(input.size(), dtype=input.dtype, device=input.device, requires_grad=False) + dist.reduce_scatter_tensor(sub_grad_input, grad_input, group=group, op=dist.ReduceOp.SUM) + grad_weight = grad_output.T @ total_input + grad_bias = torch.sum(grad_output, dim=0) if bias is not None else None + + return sub_grad_input, grad_weight, grad_bias, None + +def column_linear_context_parallel(input: torch.Tensor, weight: torch.Tensor, bias: Optional[torch.Tensor], + group: dist.ProcessGroup): + return ColumnLinearContextParallel.apply(input, weight, bias, group) diff --git a/src/nanotron/parallel/tensor_parallel/distributed_differentiable_primitives.py b/src/nanotron/parallel/tensor_parallel/distributed_differentiable_primitives.py index aa460cc6..d66826e3 100644 --- a/src/nanotron/parallel/tensor_parallel/distributed_differentiable_primitives.py +++ b/src/nanotron/parallel/tensor_parallel/distributed_differentiable_primitives.py @@ -19,7 +19,6 @@ from nanotron import distributed as dist from nanotron.distributed import ProcessGroup -from nanotron.parallel.utils import MemoryBuffer class DifferentiableIdentity(torch.autograd.Function): @@ -68,7 +67,13 @@ def forward(ctx, tensor, group: Optional[ProcessGroup]): group = torch_dist.distributed_c10d._get_default_group() unsharded_batch_size = sharded_batch_size * group.size() - unsharded_tensor = MemoryBuffer().get("dist", (unsharded_batch_size, *rest_size), dtype=tensor.dtype) + unsharded_tensor = torch.empty( + unsharded_batch_size, + *rest_size, + device=tensor.device, + dtype=tensor.dtype, + requires_grad=tensor.requires_grad, + ) # `tensor` can sometimes not be contiguous # https://cs.github.com/pytorch/pytorch/blob/2b267fa7f28e18ca6ea1de4201d2541a40411457/torch/distributed/nn/functional.py#L317 @@ -79,8 +84,11 @@ def forward(ctx, tensor, group: Optional[ProcessGroup]): @staticmethod def backward(ctx, grad_output): + #print(f"{torch.distributed.get_rank()} grad_output: {grad_output}") group = ctx.group - return DifferentiableReduceScatterSum.apply(grad_output, group), None + out = DifferentiableReduceScatterSum.apply(grad_output, group) + #print(f"{torch.distributed.get_rank()} grad_grad: {out}") + return out, None, None class DifferentiableReduceScatterSum(torch.autograd.Function): @@ -103,8 +111,12 @@ def forward(ctx, tensor, group: Optional[ProcessGroup]): # https://cs.github.com/pytorch/pytorch/blob/2b267fa7f28e18ca6ea1de4201d2541a40411457/torch/distributed/nn/functional.py#L305 tensor = tensor.contiguous() - sharded_tensor = MemoryBuffer().get( - "dist", (unsharded_batch_size // group.size(), *rest_size), dtype=tensor.dtype + sharded_tensor = torch.empty( + unsharded_batch_size // group.size(), + *rest_size, + device=tensor.device, + dtype=tensor.dtype, + requires_grad=False, ) dist.reduce_scatter_tensor(sharded_tensor, tensor, group=group, op=dist.ReduceOp.SUM) return sharded_tensor @@ -112,7 +124,8 @@ def forward(ctx, tensor, group: Optional[ProcessGroup]): @staticmethod def backward(ctx, grad_output): group = ctx.group - return DifferentiableAllGather.apply(grad_output, group), None + #print(f"{torch.distributed.get_rank()} Calling AllGather because of backward of reducescatter") + return 
DifferentiableAllGather.apply(grad_output, group, False), None # ----------------- @@ -128,7 +141,7 @@ def differentiable_all_reduce_sum(tensor, group: Optional[ProcessGroup] = None): return DifferentiableAllReduceSum.apply(tensor, group) -def differentiable_all_gather(tensor, group: Optional[ProcessGroup] = None): +def differentiable_all_gather(tensor, group: Optional[ProcessGroup] = None) return DifferentiableAllGather.apply(tensor, group) diff --git a/src/nanotron/parallel/tensor_parallel/functional.py b/src/nanotron/parallel/tensor_parallel/functional.py index fdef48ac..b3602707 100644 --- a/src/nanotron/parallel/tensor_parallel/functional.py +++ b/src/nanotron/parallel/tensor_parallel/functional.py @@ -26,6 +26,7 @@ differentiable_reduce_scatter_sum, ) from nanotron.parallel.tensor_parallel.enum import TensorParallelLinearMode +from nanotron.parallel.tensor_parallel.column_linear import column_linear_context_parallel from nanotron.parallel.utils import assert_cuda_max_connections_set_to_1 @@ -352,7 +353,7 @@ def column_linear( if tp_mode is TensorParallelLinearMode.ALL_REDUCE: input = differentiable_identity(input, group=group) elif tp_mode is TensorParallelLinearMode.REDUCE_SCATTER: - input = differentiable_all_gather(input, group=group) + return column_linear_context_parallel(input, weight, bias, group) else: raise ValueError(f"Got unexpected mode: {tp_mode}.") @@ -473,6 +474,7 @@ def row_linear( out = F.linear(input, weight, bias) + #print("Calling row linear") if tp_mode is TensorParallelLinearMode.ALL_REDUCE: out = differentiable_all_reduce_sum(out, group=group) elif tp_mode is TensorParallelLinearMode.REDUCE_SCATTER: From 59bfb6b38c4487b04e425d8540b6e44b2a7fbcf9 Mon Sep 17 00:00:00 2001 From: AleHD Date: Tue, 16 Jul 2024 11:42:27 +0200 Subject: [PATCH 12/35] Removed useless prints --- .../tensor_parallel/distributed_differentiable_primitives.py | 3 --- src/nanotron/parallel/tensor_parallel/functional.py | 1 - 2 files changed, 4 deletions(-) diff --git a/src/nanotron/parallel/tensor_parallel/distributed_differentiable_primitives.py b/src/nanotron/parallel/tensor_parallel/distributed_differentiable_primitives.py index d66826e3..f1102908 100644 --- a/src/nanotron/parallel/tensor_parallel/distributed_differentiable_primitives.py +++ b/src/nanotron/parallel/tensor_parallel/distributed_differentiable_primitives.py @@ -84,10 +84,8 @@ def forward(ctx, tensor, group: Optional[ProcessGroup]): @staticmethod def backward(ctx, grad_output): - #print(f"{torch.distributed.get_rank()} grad_output: {grad_output}") group = ctx.group out = DifferentiableReduceScatterSum.apply(grad_output, group) - #print(f"{torch.distributed.get_rank()} grad_grad: {out}") return out, None, None @@ -124,7 +122,6 @@ def forward(ctx, tensor, group: Optional[ProcessGroup]): @staticmethod def backward(ctx, grad_output): group = ctx.group - #print(f"{torch.distributed.get_rank()} Calling AllGather because of backward of reducescatter") return DifferentiableAllGather.apply(grad_output, group, False), None diff --git a/src/nanotron/parallel/tensor_parallel/functional.py b/src/nanotron/parallel/tensor_parallel/functional.py index b3602707..cedbb219 100644 --- a/src/nanotron/parallel/tensor_parallel/functional.py +++ b/src/nanotron/parallel/tensor_parallel/functional.py @@ -474,7 +474,6 @@ def row_linear( out = F.linear(input, weight, bias) - #print("Calling row linear") if tp_mode is TensorParallelLinearMode.ALL_REDUCE: out = differentiable_all_reduce_sum(out, group=group) elif tp_mode is 
TensorParallelLinearMode.REDUCE_SCATTER: From 2c69e9ad2887b7e78c88c2db3209713542dad7e2 Mon Sep 17 00:00:00 2001 From: AleHD Date: Wed, 17 Jul 2024 10:01:44 +0000 Subject: [PATCH 13/35] Minor fixes --- .../distributed_differentiable_primitives.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/src/nanotron/parallel/tensor_parallel/distributed_differentiable_primitives.py b/src/nanotron/parallel/tensor_parallel/distributed_differentiable_primitives.py index f1102908..bd41347a 100644 --- a/src/nanotron/parallel/tensor_parallel/distributed_differentiable_primitives.py +++ b/src/nanotron/parallel/tensor_parallel/distributed_differentiable_primitives.py @@ -86,7 +86,7 @@ def forward(ctx, tensor, group: Optional[ProcessGroup]): def backward(ctx, grad_output): group = ctx.group out = DifferentiableReduceScatterSum.apply(grad_output, group) - return out, None, None + return out, None class DifferentiableReduceScatterSum(torch.autograd.Function): @@ -122,7 +122,7 @@ def forward(ctx, tensor, group: Optional[ProcessGroup]): @staticmethod def backward(ctx, grad_output): group = ctx.group - return DifferentiableAllGather.apply(grad_output, group, False), None + return DifferentiableAllGather.apply(grad_output, group), None # ----------------- @@ -138,7 +138,7 @@ def differentiable_all_reduce_sum(tensor, group: Optional[ProcessGroup] = None): return DifferentiableAllReduceSum.apply(tensor, group) -def differentiable_all_gather(tensor, group: Optional[ProcessGroup] = None) +def differentiable_all_gather(tensor, group: Optional[ProcessGroup] = None): return DifferentiableAllGather.apply(tensor, group) From 30439fdee7cac456be4a2c28798b42c931f7cf72 Mon Sep 17 00:00:00 2001 From: AleHD Date: Wed, 17 Jul 2024 11:16:40 +0000 Subject: [PATCH 14/35] precommit --- .../parallel/tensor_parallel/column_linear.py | 12 ++++++++---- src/nanotron/parallel/tensor_parallel/functional.py | 7 +++---- 2 files changed, 11 insertions(+), 8 deletions(-) diff --git a/src/nanotron/parallel/tensor_parallel/column_linear.py b/src/nanotron/parallel/tensor_parallel/column_linear.py index eaab5abe..21daba36 100644 --- a/src/nanotron/parallel/tensor_parallel/column_linear.py +++ b/src/nanotron/parallel/tensor_parallel/column_linear.py @@ -13,9 +13,11 @@ class ColumnLinearContextParallel(torch.autograd.Function): enabled (i.e. tp_mode = TensorParallelLinearMode.REDUCE_SCATTER) and async communication disabled. """ + @staticmethod - def forward(ctx, input: torch.Tensor, weight: torch.Tensor, - bias: Optional[torch.Tensor], group: dist.ProcessGroup): + def forward( + ctx, input: torch.Tensor, weight: torch.Tensor, bias: Optional[torch.Tensor], group: dist.ProcessGroup + ): # Prepare context. 
ctx.save_for_backward(input, weight, bias) @@ -57,6 +59,8 @@ def backward(ctx, grad_output: torch.Tensor): return sub_grad_input, grad_weight, grad_bias, None -def column_linear_context_parallel(input: torch.Tensor, weight: torch.Tensor, bias: Optional[torch.Tensor], - group: dist.ProcessGroup): + +def column_linear_context_parallel( + input: torch.Tensor, weight: torch.Tensor, bias: Optional[torch.Tensor], group: dist.ProcessGroup +): return ColumnLinearContextParallel.apply(input, weight, bias, group) diff --git a/src/nanotron/parallel/tensor_parallel/functional.py b/src/nanotron/parallel/tensor_parallel/functional.py index cedbb219..f4e9de30 100644 --- a/src/nanotron/parallel/tensor_parallel/functional.py +++ b/src/nanotron/parallel/tensor_parallel/functional.py @@ -19,14 +19,13 @@ from torch.nn import functional as F import nanotron.distributed as dist +from nanotron.parallel.tensor_parallel.column_linear import column_linear_context_parallel from nanotron.parallel.tensor_parallel.distributed_differentiable_primitives import ( - differentiable_all_gather, differentiable_all_reduce_sum, differentiable_identity, differentiable_reduce_scatter_sum, ) from nanotron.parallel.tensor_parallel.enum import TensorParallelLinearMode -from nanotron.parallel.tensor_parallel.column_linear import column_linear_context_parallel from nanotron.parallel.utils import assert_cuda_max_connections_set_to_1 @@ -90,10 +89,10 @@ def forward( @staticmethod def backward(ctx, grad_output): - # Retreive tensors from the forward path. + # Retrieve tensors from the forward path. softmax, target_mask, masked_target_1d = ctx.saved_tensors - # All the inputs have softmax as thier gradient. + # All the inputs have softmax as their gradient. grad_input = softmax # For simplicity, work with the 2D gradient. sharded_hidden_size = softmax.size()[-1] From 1e02a9ce9c9b564f4a4274ee62e7208e3d5d9df8 Mon Sep 17 00:00:00 2001 From: AleHD Date: Wed, 17 Jul 2024 12:47:20 +0000 Subject: [PATCH 15/35] Added tp_recompute_allgather option --- src/nanotron/config/parallelism_config.py | 2 + src/nanotron/models/llama.py | 3 ++ .../parallel/tensor_parallel/column_linear.py | 51 ++++++++++++------- .../parallel/tensor_parallel/functional.py | 4 +- src/nanotron/parallel/tensor_parallel/nn.py | 3 ++ 5 files changed, 44 insertions(+), 19 deletions(-) diff --git a/src/nanotron/config/parallelism_config.py b/src/nanotron/config/parallelism_config.py index 5912425b..e9a6f2a4 100644 --- a/src/nanotron/config/parallelism_config.py +++ b/src/nanotron/config/parallelism_config.py @@ -32,6 +32,8 @@ class ParallelismArgs: tp_mode: Optional[TensorParallelLinearMode] = None tp_linear_async_communication: Optional[bool] = None + tp_recompute_allgather: bool = False + expert_parallel_size: int = 1 def __post_init__(self): diff --git a/src/nanotron/models/llama.py b/src/nanotron/models/llama.py index 3319b0ef..a31ebec6 100644 --- a/src/nanotron/models/llama.py +++ b/src/nanotron/models/llama.py @@ -154,6 +154,7 @@ def __init__( bias=False, async_communication=tp_linear_async_communication, contiguous_chunks=gate_up_contiguous_chunks, + tp_recompute_allgather=parallel_config.tp_recompute_allgather, ) self.down_proj = TensorParallelRowLinear( config.intermediate_size, @@ -314,6 +315,7 @@ def __init__( bias=False, async_communication=tp_linear_async_communication, contiguous_chunks=qkv_contiguous_chunks, + tp_recompute_allgather=parallel_config.tp_recompute_allgather, ) # TODO(kunhao): We want to have only one version per device and not one version per layer. 
self.rotary_embedding = RotaryEmbedding( @@ -742,6 +744,7 @@ def __init__( # TODO @thomasw21: refactor so that we store that default in a single place. "mode": self.tp_mode, "async_communication": tp_linear_async_communication, + "tp_recompute_allgather": parallel_config.tp_recompute_allgather, }, module_input_keys={"x"}, module_output_keys={"logits"}, diff --git a/src/nanotron/parallel/tensor_parallel/column_linear.py b/src/nanotron/parallel/tensor_parallel/column_linear.py index 21daba36..2f743199 100644 --- a/src/nanotron/parallel/tensor_parallel/column_linear.py +++ b/src/nanotron/parallel/tensor_parallel/column_linear.py @@ -16,33 +16,47 @@ class ColumnLinearContextParallel(torch.autograd.Function): @staticmethod def forward( - ctx, input: torch.Tensor, weight: torch.Tensor, bias: Optional[torch.Tensor], group: dist.ProcessGroup + ctx, input: torch.Tensor, weight: torch.Tensor, bias: Optional[torch.Tensor], group: dist.ProcessGroup, + tp_recompute_allgather: bool ): - # Prepare context. - ctx.save_for_backward(input, weight, bias) - ctx.group = group - # Do allgather. sharded_batch_size, *rest_size = input.shape unsharded_batch_size = sharded_batch_size * group.size() - total_input = MemoryBuffer().get("allgather", (unsharded_batch_size, *rest_size), dtype=input.dtype) + if tp_recompute_allgather: + total_input = MemoryBuffer().get("allgather", (unsharded_batch_size, *rest_size), dtype=input.dtype) + else: + total_input = torch.empty(unsharded_batch_size, *rest_size, dtype=input.dtype, device=input.device) dist.all_gather_into_tensor(total_input, input.contiguous(), group=group) + # Prepare context. + ctx.group = group + ctx.tp_recompute_allgather = tp_recompute_allgather + ctx.input_size = input.shape + if tp_recompute_allgather: + ctx.save_for_backward(input, weight, bias) + else: + ctx.save_for_backward(total_input, weight, bias) + # Get linear output. out = F.linear(total_input, weight, bias) return out @staticmethod def backward(ctx, grad_output: torch.Tensor): - # Allgather the inputs again. - input, weight, bias = ctx.saved_tensors + # Either allgather the inputs again or get them from context. group = ctx.group - sharded_batch_size, *rest_size = input.shape - total_input = sharded_batch_size * group.size() - unsharded_batch_size = sharded_batch_size * group.size() - total_input = MemoryBuffer().get("allgather", (unsharded_batch_size, *rest_size), dtype=input.dtype) - dist.all_gather_into_tensor(total_input, input.contiguous(), group=group) + tp_recompute_allgather = ctx.tp_recompute_allgather + input_size = ctx.input_size + if tp_recompute_allgather: + input, weight, bias = ctx.saved_tensors + sharded_batch_size, *rest_size = input.shape + total_input = sharded_batch_size * group.size() + unsharded_batch_size = sharded_batch_size * group.size() + total_input = MemoryBuffer().get("allgather", (unsharded_batch_size, *rest_size), dtype=input.dtype) + dist.all_gather_into_tensor(total_input, input.contiguous(), group=group) + else: + total_input, weight, bias = ctx.saved_tensors # Get the grad_output and total_input on the correct views to be able to transpose them below. grad_output = grad_output.contiguous() @@ -51,16 +65,17 @@ def backward(ctx, grad_output: torch.Tensor): total_input = total_input.view(total_input.size(0) * total_input.size(1), total_input.size(2)) # Compute gradients. 
+ grad_weight = grad_output.T @ total_input grad_input = grad_output @ weight - sub_grad_input = torch.empty(input.size(), dtype=input.dtype, device=input.device, requires_grad=False) + sub_grad_input = torch.empty(input_size, dtype=total_input.dtype, device=total_input.device, requires_grad=False) dist.reduce_scatter_tensor(sub_grad_input, grad_input, group=group, op=dist.ReduceOp.SUM) - grad_weight = grad_output.T @ total_input grad_bias = torch.sum(grad_output, dim=0) if bias is not None else None - return sub_grad_input, grad_weight, grad_bias, None + return sub_grad_input, grad_weight, grad_bias, None, None def column_linear_context_parallel( - input: torch.Tensor, weight: torch.Tensor, bias: Optional[torch.Tensor], group: dist.ProcessGroup + input: torch.Tensor, weight: torch.Tensor, bias: Optional[torch.Tensor], group: dist.ProcessGroup, + tp_recompute_allgather: bool = False ): - return ColumnLinearContextParallel.apply(input, weight, bias, group) + return ColumnLinearContextParallel.apply(input, weight, bias, group, tp_recompute_allgather) diff --git a/src/nanotron/parallel/tensor_parallel/functional.py b/src/nanotron/parallel/tensor_parallel/functional.py index f4e9de30..c16ae492 100644 --- a/src/nanotron/parallel/tensor_parallel/functional.py +++ b/src/nanotron/parallel/tensor_parallel/functional.py @@ -21,6 +21,7 @@ import nanotron.distributed as dist from nanotron.parallel.tensor_parallel.column_linear import column_linear_context_parallel from nanotron.parallel.tensor_parallel.distributed_differentiable_primitives import ( + differentiable_all_gather, differentiable_all_reduce_sum, differentiable_identity, differentiable_reduce_scatter_sum, @@ -345,6 +346,7 @@ def column_linear( group: dist.ProcessGroup, tp_mode: TensorParallelLinearMode, async_communication: bool, + tp_recompute_allgather: bool = True ): if async_communication: return _ColumnLinearAsyncCommunication.apply(input, weight, bias, group, tp_mode) @@ -352,7 +354,7 @@ def column_linear( if tp_mode is TensorParallelLinearMode.ALL_REDUCE: input = differentiable_identity(input, group=group) elif tp_mode is TensorParallelLinearMode.REDUCE_SCATTER: - return column_linear_context_parallel(input, weight, bias, group) + return column_linear_context_parallel(input, weight, bias, group, tp_recompute_allgather) else: raise ValueError(f"Got unexpected mode: {tp_mode}.") diff --git a/src/nanotron/parallel/tensor_parallel/nn.py b/src/nanotron/parallel/tensor_parallel/nn.py index 40e89968..42ffc828 100644 --- a/src/nanotron/parallel/tensor_parallel/nn.py +++ b/src/nanotron/parallel/tensor_parallel/nn.py @@ -51,6 +51,7 @@ def __init__( dtype=None, async_communication: bool = False, contiguous_chunks: Optional[Tuple[int, ...]] = None, + tp_recompute_allgather: bool = False, ): self.pg = pg self.world_size = pg.size() @@ -59,6 +60,7 @@ def __init__( self.in_features = in_features self.out_features = out_features // self.world_size + self.tp_recompute_allgather = tp_recompute_allgather super().__init__( in_features=self.in_features, @@ -91,6 +93,7 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: group=self.pg, tp_mode=self.mode, async_communication=self.async_communication, + tp_recompute_allgather=self.tp_recompute_allgather, ) def extra_repr(self) -> str: From 9cc81bb6fe680b72cf6114f7258af0483886ada1 Mon Sep 17 00:00:00 2001 From: AleHD Date: Wed, 17 Jul 2024 13:39:14 +0000 Subject: [PATCH 16/35] Changed recompute default --- src/nanotron/config/parallelism_config.py | 2 +- .../parallel/tensor_parallel/column_linear.py | 19 
++++++++++++++----- .../parallel/tensor_parallel/functional.py | 3 +-- 3 files changed, 16 insertions(+), 8 deletions(-) diff --git a/src/nanotron/config/parallelism_config.py b/src/nanotron/config/parallelism_config.py index e9a6f2a4..cc5d406a 100644 --- a/src/nanotron/config/parallelism_config.py +++ b/src/nanotron/config/parallelism_config.py @@ -32,7 +32,7 @@ class ParallelismArgs: tp_mode: Optional[TensorParallelLinearMode] = None tp_linear_async_communication: Optional[bool] = None - tp_recompute_allgather: bool = False + tp_recompute_allgather: bool = True expert_parallel_size: int = 1 diff --git a/src/nanotron/parallel/tensor_parallel/column_linear.py b/src/nanotron/parallel/tensor_parallel/column_linear.py index 2f743199..880d5ff0 100644 --- a/src/nanotron/parallel/tensor_parallel/column_linear.py +++ b/src/nanotron/parallel/tensor_parallel/column_linear.py @@ -16,8 +16,12 @@ class ColumnLinearContextParallel(torch.autograd.Function): @staticmethod def forward( - ctx, input: torch.Tensor, weight: torch.Tensor, bias: Optional[torch.Tensor], group: dist.ProcessGroup, - tp_recompute_allgather: bool + ctx, + input: torch.Tensor, + weight: torch.Tensor, + bias: Optional[torch.Tensor], + group: dist.ProcessGroup, + tp_recompute_allgather: bool, ): # Do allgather. @@ -67,7 +71,9 @@ def backward(ctx, grad_output: torch.Tensor): # Compute gradients. grad_weight = grad_output.T @ total_input grad_input = grad_output @ weight - sub_grad_input = torch.empty(input_size, dtype=total_input.dtype, device=total_input.device, requires_grad=False) + sub_grad_input = torch.empty( + input_size, dtype=total_input.dtype, device=total_input.device, requires_grad=False + ) dist.reduce_scatter_tensor(sub_grad_input, grad_input, group=group, op=dist.ReduceOp.SUM) grad_bias = torch.sum(grad_output, dim=0) if bias is not None else None @@ -75,7 +81,10 @@ def backward(ctx, grad_output: torch.Tensor): def column_linear_context_parallel( - input: torch.Tensor, weight: torch.Tensor, bias: Optional[torch.Tensor], group: dist.ProcessGroup, - tp_recompute_allgather: bool = False + input: torch.Tensor, + weight: torch.Tensor, + bias: Optional[torch.Tensor], + group: dist.ProcessGroup, + tp_recompute_allgather: bool = True, ): return ColumnLinearContextParallel.apply(input, weight, bias, group, tp_recompute_allgather) diff --git a/src/nanotron/parallel/tensor_parallel/functional.py b/src/nanotron/parallel/tensor_parallel/functional.py index c16ae492..454cc447 100644 --- a/src/nanotron/parallel/tensor_parallel/functional.py +++ b/src/nanotron/parallel/tensor_parallel/functional.py @@ -21,7 +21,6 @@ import nanotron.distributed as dist from nanotron.parallel.tensor_parallel.column_linear import column_linear_context_parallel from nanotron.parallel.tensor_parallel.distributed_differentiable_primitives import ( - differentiable_all_gather, differentiable_all_reduce_sum, differentiable_identity, differentiable_reduce_scatter_sum, @@ -346,7 +345,7 @@ def column_linear( group: dist.ProcessGroup, tp_mode: TensorParallelLinearMode, async_communication: bool, - tp_recompute_allgather: bool = True + tp_recompute_allgather: bool = True, ): if async_communication: return _ColumnLinearAsyncCommunication.apply(input, weight, bias, group, tp_mode) From 956fbfd09a2f0c2358fcb90be395f97ffa79632e Mon Sep 17 00:00:00 2001 From: AleHD Date: Wed, 17 Jul 2024 13:40:34 +0000 Subject: [PATCH 17/35] Changed recompute default --- src/nanotron/parallel/tensor_parallel/nn.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git 
a/src/nanotron/parallel/tensor_parallel/nn.py b/src/nanotron/parallel/tensor_parallel/nn.py index 42ffc828..4c7325cd 100644 --- a/src/nanotron/parallel/tensor_parallel/nn.py +++ b/src/nanotron/parallel/tensor_parallel/nn.py @@ -51,7 +51,7 @@ def __init__( dtype=None, async_communication: bool = False, contiguous_chunks: Optional[Tuple[int, ...]] = None, - tp_recompute_allgather: bool = False, + tp_recompute_allgather: bool = True, ): self.pg = pg self.world_size = pg.size() From 9992f1c5919fd4038e85cd9b3fb1dd4faa81daf1 Mon Sep 17 00:00:00 2001 From: tj-solergibert Date: Wed, 17 Jul 2024 14:28:57 +0000 Subject: [PATCH 18/35] Little fixes --- src/nanotron/config/config.py | 4 ++-- tools/preprocess_data.py | 1 + 2 files changed, 3 insertions(+), 2 deletions(-) diff --git a/src/nanotron/config/config.py b/src/nanotron/config/config.py index fe194883..2e1a98cc 100644 --- a/src/nanotron/config/config.py +++ b/src/nanotron/config/config.py @@ -96,10 +96,10 @@ class NanosetDatasetsArgs: dataset_folder: Union[str, dict, List[str]] def __post_init__(self): - if isinstance(self.dataset_folder, str): # Case 1: 1 Dataset file + if isinstance(self.dataset_folder, str): # Case 1: 1 Dataset folder self.dataset_folder = [self.dataset_folder] self.dataset_weights = [1] - elif isinstance(self.dataset_folder, List): # Case 2: > 1 Dataset file + elif isinstance(self.dataset_folder, List): # Case 2: > 1 Dataset folder self.dataset_weights = None # Set to None so we consume all the samples randomly elif isinstance(self.dataset_folder, dict): # Case 3: dict with > 1 dataset_folder and weights tmp_dataset_folder = self.dataset_folder.copy() diff --git a/tools/preprocess_data.py b/tools/preprocess_data.py index 38db67f1..f3cdab70 100644 --- a/tools/preprocess_data.py +++ b/tools/preprocess_data.py @@ -95,6 +95,7 @@ def main(args): output_folder=args.output_folder, tokenizer_name_or_path=args.tokenizer_name_or_path, eos_token=args.eos_token, + shuffle=False, max_tokens_per_file=1e9, ), ], From b9e92017614e0326acab86b7665e0c8e8718bfc3 Mon Sep 17 00:00:00 2001 From: AleHD Date: Wed, 17 Jul 2024 15:21:04 +0000 Subject: [PATCH 19/35] Moved ColumnLinearNoAsync module for consistency --- .../parallel/tensor_parallel/column_linear.py | 90 ------------------- .../parallel/tensor_parallel/functional.py | 78 +++++++++++++++- 2 files changed, 76 insertions(+), 92 deletions(-) delete mode 100644 src/nanotron/parallel/tensor_parallel/column_linear.py diff --git a/src/nanotron/parallel/tensor_parallel/column_linear.py b/src/nanotron/parallel/tensor_parallel/column_linear.py deleted file mode 100644 index 880d5ff0..00000000 --- a/src/nanotron/parallel/tensor_parallel/column_linear.py +++ /dev/null @@ -1,90 +0,0 @@ -from typing import Optional - -import torch -from torch.nn import functional as F - -import nanotron.distributed as dist -from nanotron.parallel.utils import MemoryBuffer - - -class ColumnLinearContextParallel(torch.autograd.Function): - """ - Column linear with memory_buffer for the allgather, context parallel - enabled (i.e. tp_mode = TensorParallelLinearMode.REDUCE_SCATTER) and - async communication disabled. - """ - - @staticmethod - def forward( - ctx, - input: torch.Tensor, - weight: torch.Tensor, - bias: Optional[torch.Tensor], - group: dist.ProcessGroup, - tp_recompute_allgather: bool, - ): - - # Do allgather. 
- sharded_batch_size, *rest_size = input.shape - unsharded_batch_size = sharded_batch_size * group.size() - if tp_recompute_allgather: - total_input = MemoryBuffer().get("allgather", (unsharded_batch_size, *rest_size), dtype=input.dtype) - else: - total_input = torch.empty(unsharded_batch_size, *rest_size, dtype=input.dtype, device=input.device) - dist.all_gather_into_tensor(total_input, input.contiguous(), group=group) - - # Prepare context. - ctx.group = group - ctx.tp_recompute_allgather = tp_recompute_allgather - ctx.input_size = input.shape - if tp_recompute_allgather: - ctx.save_for_backward(input, weight, bias) - else: - ctx.save_for_backward(total_input, weight, bias) - - # Get linear output. - out = F.linear(total_input, weight, bias) - return out - - @staticmethod - def backward(ctx, grad_output: torch.Tensor): - # Either allgather the inputs again or get them from context. - group = ctx.group - tp_recompute_allgather = ctx.tp_recompute_allgather - input_size = ctx.input_size - if tp_recompute_allgather: - input, weight, bias = ctx.saved_tensors - sharded_batch_size, *rest_size = input.shape - total_input = sharded_batch_size * group.size() - unsharded_batch_size = sharded_batch_size * group.size() - total_input = MemoryBuffer().get("allgather", (unsharded_batch_size, *rest_size), dtype=input.dtype) - dist.all_gather_into_tensor(total_input, input.contiguous(), group=group) - else: - total_input, weight, bias = ctx.saved_tensors - - # Get the grad_output and total_input on the correct views to be able to transpose them below. - grad_output = grad_output.contiguous() - assert grad_output.dim() == 3 - grad_output = grad_output.view(grad_output.size(0) * grad_output.size(1), grad_output.size(2)) - total_input = total_input.view(total_input.size(0) * total_input.size(1), total_input.size(2)) - - # Compute gradients. - grad_weight = grad_output.T @ total_input - grad_input = grad_output @ weight - sub_grad_input = torch.empty( - input_size, dtype=total_input.dtype, device=total_input.device, requires_grad=False - ) - dist.reduce_scatter_tensor(sub_grad_input, grad_input, group=group, op=dist.ReduceOp.SUM) - grad_bias = torch.sum(grad_output, dim=0) if bias is not None else None - - return sub_grad_input, grad_weight, grad_bias, None, None - - -def column_linear_context_parallel( - input: torch.Tensor, - weight: torch.Tensor, - bias: Optional[torch.Tensor], - group: dist.ProcessGroup, - tp_recompute_allgather: bool = True, -): - return ColumnLinearContextParallel.apply(input, weight, bias, group, tp_recompute_allgather) diff --git a/src/nanotron/parallel/tensor_parallel/functional.py b/src/nanotron/parallel/tensor_parallel/functional.py index 454cc447..2b93fb02 100644 --- a/src/nanotron/parallel/tensor_parallel/functional.py +++ b/src/nanotron/parallel/tensor_parallel/functional.py @@ -19,7 +19,7 @@ from torch.nn import functional as F import nanotron.distributed as dist -from nanotron.parallel.tensor_parallel.column_linear import column_linear_context_parallel +from nanotron.parallel.utils import MemoryBuffer from nanotron.parallel.tensor_parallel.distributed_differentiable_primitives import ( differentiable_all_reduce_sum, differentiable_identity, @@ -338,6 +338,80 @@ def backward(ctx, grad_output): raise ValueError(f"Got unexpected mode: {tp_mode}.") +class _ColumnLinearContextParallelNoAsync(torch.autograd.Function): + """ + Column linear with memory_buffer for the allgather, context parallel + enabled (i.e. 
tp_mode = TensorParallelLinearMode.REDUCE_SCATTER) and + async communication disabled. + """ + + @staticmethod + def forward( + ctx, + input: torch.Tensor, + weight: torch.Tensor, + bias: Optional[torch.Tensor], + group: dist.ProcessGroup, + tp_recompute_allgather: bool, + ): + + # Do allgather. + sharded_batch_size, *rest_size = input.shape + unsharded_batch_size = sharded_batch_size * group.size() + if tp_recompute_allgather: + total_input = MemoryBuffer().get("allgather", (unsharded_batch_size, *rest_size), dtype=input.dtype) + else: + total_input = torch.empty(unsharded_batch_size, *rest_size, dtype=input.dtype, device=input.device) + dist.all_gather_into_tensor(total_input, input.contiguous(), group=group) + + # Prepare context. + ctx.group = group + ctx.tp_recompute_allgather = tp_recompute_allgather + ctx.input_size = input.shape + if tp_recompute_allgather: + ctx.save_for_backward(input, weight, bias) + else: + ctx.save_for_backward(total_input, weight, bias) + + # Get linear output. + out = F.linear(total_input, weight, bias) + return out + + @staticmethod + def backward(ctx, grad_output: torch.Tensor): + # Either allgather the inputs again or get them from context. + group = ctx.group + tp_recompute_allgather = ctx.tp_recompute_allgather + input_size = ctx.input_size + if tp_recompute_allgather: + input, weight, bias = ctx.saved_tensors + sharded_batch_size, *rest_size = input.shape + total_input = sharded_batch_size * group.size() + unsharded_batch_size = sharded_batch_size * group.size() + total_input = MemoryBuffer().get("allgather", (unsharded_batch_size, *rest_size), dtype=input.dtype) + dist.all_gather_into_tensor(total_input, input.contiguous(), group=group) + else: + total_input, weight, bias = ctx.saved_tensors + + # Get the grad_output and total_input on the correct views to be able to transpose them below. + grad_output = grad_output.contiguous() + assert grad_output.dim() == 3 + grad_output = grad_output.view(grad_output.size(0) * grad_output.size(1), grad_output.size(2)) + total_input = total_input.view(total_input.size(0) * total_input.size(1), total_input.size(2)) + + # Compute gradients. 
+ grad_weight = grad_output.T @ total_input + grad_input = grad_output @ weight + sub_grad_input = torch.empty( + input_size, dtype=total_input.dtype, device=total_input.device, requires_grad=False + ) + dist.reduce_scatter_tensor(sub_grad_input, grad_input, group=group, op=dist.ReduceOp.SUM) + grad_bias = torch.sum(grad_output, dim=0) if bias is not None else None + + return sub_grad_input, grad_weight, grad_bias, None, None + + + def column_linear( input: torch.Tensor, weight: torch.Tensor, @@ -353,7 +427,7 @@ def column_linear( if tp_mode is TensorParallelLinearMode.ALL_REDUCE: input = differentiable_identity(input, group=group) elif tp_mode is TensorParallelLinearMode.REDUCE_SCATTER: - return column_linear_context_parallel(input, weight, bias, group, tp_recompute_allgather) + return _ColumnLinearContextParallelNoAsync.apply(input, weight, bias, group, tp_recompute_allgather) else: raise ValueError(f"Got unexpected mode: {tp_mode}.") From 7cc6653c69c19cd42da05e6a8712159a146407e7 Mon Sep 17 00:00:00 2001 From: AleHD Date: Thu, 18 Jul 2024 11:16:42 +0000 Subject: [PATCH 20/35] memory efficient async linear --- .../parallel/tensor_parallel/functional.py | 48 ++++++------------- 1 file changed, 15 insertions(+), 33 deletions(-) diff --git a/src/nanotron/parallel/tensor_parallel/functional.py b/src/nanotron/parallel/tensor_parallel/functional.py index 47c0b5a1..1a82254e 100644 --- a/src/nanotron/parallel/tensor_parallel/functional.py +++ b/src/nanotron/parallel/tensor_parallel/functional.py @@ -149,14 +149,7 @@ def forward(ctx, tensor, weight, bias, group, tp_mode): group = dist.distributed_c10d._get_default_group() gathered_batch_size = sharded_batch_size * group.size() - gathered_tensor = torch.empty( - gathered_batch_size, - *intermediate_size, - hidden_size, - device=tensor.device, - dtype=tensor.dtype, - requires_grad=tensor.requires_grad, - ) + gathered_tensor = MemoryBuffer().get("allgather", (gathered_batch_size, *intermediate_size, hidden_size), dtype=tensor.dtype) handle = dist.all_gather_into_tensor(gathered_tensor, tensor, group=group, async_op=True) @@ -261,7 +254,7 @@ def backward(ctx, grad_output): use_bias = ctx.use_bias tp_mode = ctx.tp_mode - handle: Optional[dist.Work] = None + handle1: Optional[dist.Work] = None if tp_mode is TensorParallelLinearMode.REDUCE_SCATTER: # TODO @thomasw21: gather along another dimension sharded_batch_size, *rest_size = tensor.shape @@ -273,14 +266,8 @@ def backward(ctx, grad_output): else: unsharded_batch_size = sharded_batch_size * group.size() - unsharded_tensor = torch.empty( - unsharded_batch_size, - *rest_size, - device=tensor.device, - dtype=tensor.dtype, - requires_grad=False, - ) - handle = dist.all_gather_into_tensor(unsharded_tensor, tensor, group=group, async_op=True) + unsharded_tensor = MemoryBuffer().get("allgather", (unsharded_batch_size, *rest_size), dtype=tensor.dtype) + handle1 = dist.all_gather_into_tensor(unsharded_tensor, tensor, group=group, async_op=True) # Here we rely on CUDA_DEVICE_MAX_CONNECTIONS=1 to ensure that the # gather is scheduled before the tensor gradient computation total_tensor = unsharded_tensor @@ -289,9 +276,6 @@ def backward(ctx, grad_output): grad_tensor = grad_output.matmul(weight) - if handle is not None: - handle.wait() - # Doing gather + slicing during the NeMo forward pass can make this tensor # not be contiguous. 
PyTorch only checks if the tensor is contiguous, and only # clones it if it's not contiguous: @@ -303,7 +287,7 @@ def backward(ctx, grad_output): grad_output = grad_output.view(math.prod(grad_output_first_dims), grad_output_last_dim) total_tensor = total_tensor.view(math.prod(total_tensor_first_dims), total_tensor_last_dim) - handle: Optional[dist.Work] = None + handle2: Optional[dist.Work] = None if tp_mode is TensorParallelLinearMode.REDUCE_SCATTER: if group.size() == 1: sub_grad_tensor = grad_tensor @@ -312,23 +296,27 @@ def backward(ctx, grad_output): tensor.shape, dtype=grad_tensor.dtype, device=grad_tensor.device, requires_grad=False ) # reduce_scatter - handle = dist.reduce_scatter_tensor(sub_grad_tensor, grad_tensor, group=group, async_op=True) + handle2 = dist.reduce_scatter_tensor(sub_grad_tensor, grad_tensor, group=group, async_op=True) # Here we rely on CUDA_DEVICE_MAX_CONNECTIONS=1 to ensure that the # reduce scatter is scheduled before the weight gradient computation elif tp_mode is TensorParallelLinearMode.ALL_REDUCE: # Asynchronous all-reduce - handle = dist.all_reduce(grad_tensor, group=group, async_op=True) + handle2 = dist.all_reduce(grad_tensor, group=group, async_op=True) # Here we rely on CUDA_DEVICE_MAX_CONNECTIONS=1 to ensure that the # all-reduce is scheduled before the weight gradient computation else: raise ValueError() + grad_bias = grad_output.sum(dim=0) if use_bias else None + + if handle1 is not None: + handle1.wait() + # TODO @thomasw21: This sounds like we don't have the optimal physical layout grad_weight = grad_output.t().matmul(total_tensor) - grad_bias = grad_output.sum(dim=0) if use_bias else None - if handle is not None: - handle.wait() + if handle2 is not None: + handle2.wait() if tp_mode is TensorParallelLinearMode.REDUCE_SCATTER: return sub_grad_tensor, grad_weight, grad_bias, None, None @@ -472,13 +460,7 @@ def backward(ctx, grad_output): else: unsharded_batch_size = sharded_batch_size * group.size() - total_grad_output = torch.empty( - unsharded_batch_size, - *rest_size, - device=grad_output.device, - dtype=grad_output.dtype, - requires_grad=False, - ) + total_grad_output = MemoryBuffer().get("allgather2", (unsharded_batch_size, *rest_size), dtype=tensor.dtype) # Doing gather + slicing during the NeMo forward pass can make this tensor # not be contiguous. 
PyTorch only checks if the tensor is contiguous, and only From cb0f2609e357747a0a3001dd9b12c649f9e6eef7 Mon Sep 17 00:00:00 2001 From: AleHD Date: Thu, 18 Jul 2024 11:17:09 +0000 Subject: [PATCH 21/35] precommit --- .../parallel/tensor_parallel/functional.py | 16 ++++++++++------ 1 file changed, 10 insertions(+), 6 deletions(-) diff --git a/src/nanotron/parallel/tensor_parallel/functional.py b/src/nanotron/parallel/tensor_parallel/functional.py index 1a82254e..3821d544 100644 --- a/src/nanotron/parallel/tensor_parallel/functional.py +++ b/src/nanotron/parallel/tensor_parallel/functional.py @@ -19,14 +19,13 @@ from torch.nn import functional as F import nanotron.distributed as dist -from nanotron.parallel.utils import MemoryBuffer from nanotron.parallel.tensor_parallel.distributed_differentiable_primitives import ( differentiable_all_reduce_sum, differentiable_identity, differentiable_reduce_scatter_sum, ) from nanotron.parallel.tensor_parallel.enum import TensorParallelLinearMode -from nanotron.parallel.utils import assert_cuda_max_connections_set_to_1 +from nanotron.parallel.utils import MemoryBuffer, assert_cuda_max_connections_set_to_1 class _ShardedCrossEntropy(torch.autograd.Function): @@ -149,7 +148,9 @@ def forward(ctx, tensor, weight, bias, group, tp_mode): group = dist.distributed_c10d._get_default_group() gathered_batch_size = sharded_batch_size * group.size() - gathered_tensor = MemoryBuffer().get("allgather", (gathered_batch_size, *intermediate_size, hidden_size), dtype=tensor.dtype) + gathered_tensor = MemoryBuffer().get( + "allgather", (gathered_batch_size, *intermediate_size, hidden_size), dtype=tensor.dtype + ) handle = dist.all_gather_into_tensor(gathered_tensor, tensor, group=group, async_op=True) @@ -266,7 +267,9 @@ def backward(ctx, grad_output): else: unsharded_batch_size = sharded_batch_size * group.size() - unsharded_tensor = MemoryBuffer().get("allgather", (unsharded_batch_size, *rest_size), dtype=tensor.dtype) + unsharded_tensor = MemoryBuffer().get( + "allgather", (unsharded_batch_size, *rest_size), dtype=tensor.dtype + ) handle1 = dist.all_gather_into_tensor(unsharded_tensor, tensor, group=group, async_op=True) # Here we rely on CUDA_DEVICE_MAX_CONNECTIONS=1 to ensure that the # gather is scheduled before the tensor gradient computation @@ -399,7 +402,6 @@ def backward(ctx, grad_output: torch.Tensor): return sub_grad_input, grad_weight, grad_bias, None, None - def column_linear( input: torch.Tensor, weight: torch.Tensor, @@ -460,7 +462,9 @@ def backward(ctx, grad_output): else: unsharded_batch_size = sharded_batch_size * group.size() - total_grad_output = MemoryBuffer().get("allgather2", (unsharded_batch_size, *rest_size), dtype=tensor.dtype) + total_grad_output = MemoryBuffer().get( + "allgather2", (unsharded_batch_size, *rest_size), dtype=tensor.dtype + ) # Doing gather + slicing during the NeMo forward pass can make this tensor # not be contiguous. 
PyTorch only checks if the tensor is contiguous, and only From 6d85d038d52ffc06ec4f2ae4705deae3b05d25d8 Mon Sep 17 00:00:00 2001 From: AleHD Date: Thu, 18 Jul 2024 11:46:39 +0000 Subject: [PATCH 22/35] Added no_recompute_allgather mode to async --- .../parallel/tensor_parallel/functional.py | 36 +++++++++++++------ 1 file changed, 26 insertions(+), 10 deletions(-) diff --git a/src/nanotron/parallel/tensor_parallel/functional.py b/src/nanotron/parallel/tensor_parallel/functional.py index 3821d544..29480f9a 100644 --- a/src/nanotron/parallel/tensor_parallel/functional.py +++ b/src/nanotron/parallel/tensor_parallel/functional.py @@ -120,10 +120,12 @@ class _ColumnLinearAsyncCommunication(torch.autograd.Function): @staticmethod @assert_cuda_max_connections_set_to_1 - def forward(ctx, tensor, weight, bias, group, tp_mode): + def forward(ctx, tensor, weight, bias, group, tp_mode, tp_recompute_allgather): ctx.use_bias = bias is not None ctx.tp_mode = tp_mode ctx.group = group + ctx.tp_recompute_allgather = tp_recompute_allgather + ctx.tensor_shape = tensor.size() if tp_mode is TensorParallelLinearMode.ALL_REDUCE: gathered_tensor = tensor @@ -140,7 +142,7 @@ def forward(ctx, tensor, weight, bias, group, tp_mode): # `tensor` can sometimes not be contiguous # https://cs.github.com/pytorch/pytorch/blob/2b267fa7f28e18ca6ea1de4201d2541a40411457/torch/distributed/nn/functional.py#L317 tensor = tensor.contiguous() - ctx.save_for_backward(tensor, weight) + # ctx.save_for_backward(tensor, weight) # TODO @thomasw21: gather along another dimension sharded_batch_size, *intermediate_size, hidden_size = tensor.shape @@ -148,9 +150,19 @@ def forward(ctx, tensor, weight, bias, group, tp_mode): group = dist.distributed_c10d._get_default_group() gathered_batch_size = sharded_batch_size * group.size() - gathered_tensor = MemoryBuffer().get( - "allgather", (gathered_batch_size, *intermediate_size, hidden_size), dtype=tensor.dtype - ) + if tp_recompute_allgather: + gathered_tensor = MemoryBuffer().get( + "allgather", (gathered_batch_size, *intermediate_size, hidden_size), dtype=tensor.dtype + ) + else: + gathered_tensor = torch.empty( + gathered_batch_size, + *intermediate_size, + hidden_size, + device=tensor.device, + dtype=tensor.dtype, + requires_grad=False, + ) handle = dist.all_gather_into_tensor(gathered_tensor, tensor, group=group, async_op=True) @@ -198,6 +210,10 @@ def forward(ctx, tensor, weight, bias, group, tp_mode): # Wait communication handle.wait() + if tp_recompute_allgather: + ctx.save_for_backward(tensor, weight) + else: + ctx.save_for_backward(gathered_tensor, weight) # Compute all the other shards that are obtained from AllGather # weights: w0 w1 w2 w3 @@ -256,7 +272,7 @@ def backward(ctx, grad_output): tp_mode = ctx.tp_mode handle1: Optional[dist.Work] = None - if tp_mode is TensorParallelLinearMode.REDUCE_SCATTER: + if tp_mode is TensorParallelLinearMode.REDUCE_SCATTER and ctx.tp_recompute_allgather: # TODO @thomasw21: gather along another dimension sharded_batch_size, *rest_size = tensor.shape if group is None: @@ -296,7 +312,7 @@ def backward(ctx, grad_output): sub_grad_tensor = grad_tensor else: sub_grad_tensor = torch.empty( - tensor.shape, dtype=grad_tensor.dtype, device=grad_tensor.device, requires_grad=False + ctx.tensor_shape, dtype=grad_tensor.dtype, device=grad_tensor.device, requires_grad=False ) # reduce_scatter handle2 = dist.reduce_scatter_tensor(sub_grad_tensor, grad_tensor, group=group, async_op=True) @@ -322,9 +338,9 @@ def backward(ctx, grad_output): handle2.wait() if tp_mode is 
TensorParallelLinearMode.REDUCE_SCATTER: - return sub_grad_tensor, grad_weight, grad_bias, None, None + return sub_grad_tensor, grad_weight, grad_bias, None, None, None elif tp_mode is TensorParallelLinearMode.ALL_REDUCE: - return grad_tensor, grad_weight, grad_bias, None, None + return grad_tensor, grad_weight, grad_bias, None, None, None else: raise ValueError(f"Got unexpected mode: {tp_mode}.") @@ -412,7 +428,7 @@ def column_linear( tp_recompute_allgather: bool = True, ): if async_communication: - return _ColumnLinearAsyncCommunication.apply(input, weight, bias, group, tp_mode) + return _ColumnLinearAsyncCommunication.apply(input, weight, bias, group, tp_mode, tp_recompute_allgather) if tp_mode is TensorParallelLinearMode.ALL_REDUCE: input = differentiable_identity(input, group=group) From 2afd00769c7c5891341e2d7880492bfc80c524f6 Mon Sep 17 00:00:00 2001 From: AleHD Date: Tue, 23 Jul 2024 09:52:12 +0000 Subject: [PATCH 23/35] Fixed List not found --- src/nanotron/models/llama.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/nanotron/models/llama.py b/src/nanotron/models/llama.py index 7ee44390..49ea86e6 100644 --- a/src/nanotron/models/llama.py +++ b/src/nanotron/models/llama.py @@ -14,7 +14,7 @@ # limitations under the License. """PyTorch LLaMa model.""" -from typing import Dict, Optional, Union +from typing import Dict, Optional, Union, List import torch from torch import nn From 7e758db3068948178edd2151232e4abd7b2d5ffd Mon Sep 17 00:00:00 2001 From: AleHD Date: Tue, 23 Jul 2024 13:17:34 +0000 Subject: [PATCH 24/35] Fixed tp=1 case --- .../parallel/tensor_parallel/functional.py | 22 +++++++++++-------- 1 file changed, 13 insertions(+), 9 deletions(-) diff --git a/src/nanotron/parallel/tensor_parallel/functional.py b/src/nanotron/parallel/tensor_parallel/functional.py index 2b93fb02..22c8ca3c 100644 --- a/src/nanotron/parallel/tensor_parallel/functional.py +++ b/src/nanotron/parallel/tensor_parallel/functional.py @@ -358,11 +358,14 @@ def forward( # Do allgather. sharded_batch_size, *rest_size = input.shape unsharded_batch_size = sharded_batch_size * group.size() - if tp_recompute_allgather: + if group.size() == 1: + total_input = input.contiguous() + elif tp_recompute_allgather: total_input = MemoryBuffer().get("allgather", (unsharded_batch_size, *rest_size), dtype=input.dtype) + dist.all_gather_into_tensor(total_input, input.contiguous(), group=group) else: total_input = torch.empty(unsharded_batch_size, *rest_size, dtype=input.dtype, device=input.device) - dist.all_gather_into_tensor(total_input, input.contiguous(), group=group) + dist.all_gather_into_tensor(total_input, input.contiguous(), group=group) # Prepare context. ctx.group = group @@ -383,21 +386,22 @@ def backward(ctx, grad_output: torch.Tensor): group = ctx.group tp_recompute_allgather = ctx.tp_recompute_allgather input_size = ctx.input_size - if tp_recompute_allgather: + if group.size() == 1 or not tp_recompute_allgather: + total_input, weight, bias = ctx.saved_tensors + else: input, weight, bias = ctx.saved_tensors sharded_batch_size, *rest_size = input.shape total_input = sharded_batch_size * group.size() unsharded_batch_size = sharded_batch_size * group.size() total_input = MemoryBuffer().get("allgather", (unsharded_batch_size, *rest_size), dtype=input.dtype) dist.all_gather_into_tensor(total_input, input.contiguous(), group=group) - else: - total_input, weight, bias = ctx.saved_tensors - # Get the grad_output and total_input on the correct views to be able to transpose them below. 
+ # Convert the tensor shapes to 2D for execution compatibility grad_output = grad_output.contiguous() - assert grad_output.dim() == 3 - grad_output = grad_output.view(grad_output.size(0) * grad_output.size(1), grad_output.size(2)) - total_input = total_input.view(total_input.size(0) * total_input.size(1), total_input.size(2)) + grad_output_first_dims, grad_output_last_dim = grad_output.shape[:-1], grad_output.shape[-1] + total_tensor_first_dims, total_tensor_last_dim = total_tensor.shape[:-1], total_tensor.shape[-1] + grad_output = grad_output.view(math.prod(grad_output_first_dims), grad_output_last_dim) + total_tensor = total_tensor.view(math.prod(total_tensor_first_dims), total_tensor_last_dim) # Compute gradients. grad_weight = grad_output.T @ total_input From 41f11f01798b88538bbc31a793435a6572b088ea Mon Sep 17 00:00:00 2001 From: Tiancheng Chen Date: Thu, 16 May 2024 18:54:55 +0200 Subject: [PATCH 25/35] fix row parallel --- .../parallel/tensor_parallel/functional.py | 74 ++++++++++++++----- tests/test_tensor_parallel.py | 29 ++++++-- 2 files changed, 76 insertions(+), 27 deletions(-) diff --git a/src/nanotron/parallel/tensor_parallel/functional.py b/src/nanotron/parallel/tensor_parallel/functional.py index fdef48ac..1c4db5de 100644 --- a/src/nanotron/parallel/tensor_parallel/functional.py +++ b/src/nanotron/parallel/tensor_parallel/functional.py @@ -387,8 +387,7 @@ def backward(ctx, grad_output): group = ctx.group use_bias = ctx.use_bias - handle_0: Optional[dist.Work] = None - handle_1: Optional[dist.Work] = None + handle: Optional[dist.Work] = None # TODO @thomasw21: gather along another dimension sharded_batch_size, *rest_size = grad_output.shape @@ -412,31 +411,69 @@ def backward(ctx, grad_output): # https://github.com/pytorch/pytorch/blob/c47cf9bc7f9e02f649ab4ed53fe4d35732c92ab6/torch/_refs/__init__.py#L2761 grad_output = grad_output.contiguous() - handle_0 = dist.all_gather_into_tensor(total_grad_output, grad_output, group=group, async_op=True) - - grad_tensor = grad_output.matmul(weight) + handle = dist.all_gather_into_tensor(total_grad_output, grad_output, group=group, async_op=True) - # wait for the first all_gather to finish before starting the second all_gather - if handle_0 is not None: - handle_0.wait() - - # TODO @thomasw21: gather along another dimension - sharded_batch_size, *rest_size = grad_tensor.shape + # total_grad_output: [b, s, h_out] + # weight: [h_out, h_in/n] + # total_grad_tensor: [b, s, h_in/n] + # grad_output: [b/n, s, h_out] + sharded_batch_size, *rest_size_grad_output = grad_output.shape + rest_size_grad_tensor = rest_size_grad_output[:-1] + [weight.shape[1]] if group.size() == 1: - total_grad_tensor = grad_tensor + total_grad_tensor = grad_output.matmul(weight) else: unsharded_batch_size = sharded_batch_size * group.size() - total_grad_tensor = torch.empty( unsharded_batch_size, - *rest_size, - device=grad_tensor.device, - dtype=grad_tensor.dtype, + *rest_size_grad_tensor, + device=grad_output.device, + dtype=grad_output.dtype, requires_grad=False, ) + before_shard_grad_tensor, same_device_shard_grad_tensor, after_shard_grad_tensor = torch.split( + total_grad_tensor, + split_size_or_sections=[ + sharded_batch_size * dist.get_rank(group), + sharded_batch_size, + sharded_batch_size * (group.size() - dist.get_rank(group) - 1), + ], + dim=0, + ) + # compute local shard + torch.mm( + input=grad_output.view(-1, grad_output.shape[-1]), + mat2=weight, + out=same_device_shard_grad_tensor.view(-1, weight.shape[1]), + ) - handle_1 = 
dist.all_gather_into_tensor(total_grad_tensor, grad_tensor, group=group, async_op=True) + if handle is not None: + handle.wait() + + before_shard_grad_output, _, after_shard_grad_output = torch.split( + total_grad_output, + split_size_or_sections=[ + sharded_batch_size * dist.get_rank(group), + sharded_batch_size, + sharded_batch_size * (group.size() - dist.get_rank(group) - 1), + ], + dim=0, + ) + + # before shard compute + if before_shard_grad_tensor.numel() > 0: + torch.mm( + input=before_shard_grad_output.view(-1, before_shard_grad_output.shape[-1]), + mat2=weight, + out=before_shard_grad_tensor.view(-1, weight.shape[1]), + ) + # after shard compute + if after_shard_grad_tensor.numel() > 0: + torch.mm( + input=after_shard_grad_output.view(-1, after_shard_grad_output.shape[-1]), + mat2=weight, + out=after_shard_grad_tensor.view(-1, weight.shape[1]), + ) # Convert the tensor shapes to 2D for execution compatibility tensor = tensor.contiguous() @@ -454,9 +491,6 @@ def backward(ctx, grad_output): grad_weight = total_grad_output.t().matmul(tensor) grad_bias = total_grad_output.sum(dim=0) if use_bias else None - if handle_1 is not None: - handle_1.wait() - return total_grad_tensor, grad_weight, grad_bias, None, None diff --git a/tests/test_tensor_parallel.py b/tests/test_tensor_parallel.py index 127ba2fa..f5dcaeb0 100644 --- a/tests/test_tensor_parallel.py +++ b/tests/test_tensor_parallel.py @@ -208,14 +208,19 @@ def _test_row_linear(parallel_context: ParallelContext, tp_mode: TensorParallelL random_input = torch.randn(batch_size, in_features, device="cuda") # synchronize random_input across tp dist.all_reduce(random_input, op=dist.ReduceOp.AVG, group=parallel_context.tp_pg) - + random_input.requires_grad = True # Row linear receives as input sharded input - random_sharded_input = random_input[ - :, - dist.get_rank(parallel_context.tp_pg) - * in_features_per_rank : (dist.get_rank(parallel_context.tp_pg) + 1) - * in_features_per_rank, - ] + random_sharded_input = ( + random_input[ + :, + dist.get_rank(parallel_context.tp_pg) + * in_features_per_rank : (dist.get_rank(parallel_context.tp_pg) + 1) + * in_features_per_rank, + ] + .detach() + .clone() + ) + random_sharded_input.requires_grad = True # Test that we get the same output after forward pass # TODO @kunhao: We may want to have our custom error type @@ -261,6 +266,16 @@ def _test_row_linear(parallel_context: ParallelContext, tp_mode: TensorParallelL else: assert row_linear.bias is None + torch.testing.assert_close( + random_sharded_input.grad, + random_input.grad[ + :, + dist.get_rank(parallel_context.tp_pg) + * in_features_per_rank : (dist.get_rank(parallel_context.tp_pg) + 1) + * in_features_per_rank, + ], + ) + parallel_context.destroy() From 793bdf3b27275a0e8cb3d919fd10028066fafce6 Mon Sep 17 00:00:00 2001 From: ischlag Date: Mon, 29 Jul 2024 16:51:42 +0200 Subject: [PATCH 26/35] fix upstream bug --- src/nanotron/models/llama.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/nanotron/models/llama.py b/src/nanotron/models/llama.py index b9ec5deb..a440c8d0 100644 --- a/src/nanotron/models/llama.py +++ b/src/nanotron/models/llama.py @@ -14,7 +14,7 @@ # limitations under the License. 
"""PyTorch LLaMa model.""" -from typing import Dict, Optional, Union +from typing import Dict, Optional, Union, List import torch from torch import nn From cd84d4fa7bff4017a9b1653c4b68290c00eae649 Mon Sep 17 00:00:00 2001 From: AleHD Date: Tue, 30 Jul 2024 19:16:14 +0200 Subject: [PATCH 27/35] Fixed column parallel --- .../parallel/tensor_parallel/functional.py | 19 ++++++++++--------- 1 file changed, 10 insertions(+), 9 deletions(-) diff --git a/src/nanotron/parallel/tensor_parallel/functional.py b/src/nanotron/parallel/tensor_parallel/functional.py index 054d41e9..468855a5 100644 --- a/src/nanotron/parallel/tensor_parallel/functional.py +++ b/src/nanotron/parallel/tensor_parallel/functional.py @@ -19,14 +19,13 @@ from torch.nn import functional as F import nanotron.distributed as dist -from nanotron.parallel.utils import MemoryBuffer from nanotron.parallel.tensor_parallel.distributed_differentiable_primitives import ( differentiable_all_reduce_sum, differentiable_identity, differentiable_reduce_scatter_sum, ) from nanotron.parallel.tensor_parallel.enum import TensorParallelLinearMode -from nanotron.parallel.utils import assert_cuda_max_connections_set_to_1 +from nanotron.parallel.utils import MemoryBuffer, assert_cuda_max_connections_set_to_1 class _ShardedCrossEntropy(torch.autograd.Function): @@ -399,23 +398,25 @@ def backward(ctx, grad_output: torch.Tensor): # Convert the tensor shapes to 2D for execution compatibility grad_output = grad_output.contiguous() grad_output_first_dims, grad_output_last_dim = grad_output.shape[:-1], grad_output.shape[-1] - total_tensor_first_dims, total_tensor_last_dim = total_tensor.shape[:-1], total_tensor.shape[-1] + total_input_first_dims, total_input_last_dim = total_input.shape[:-1], total_input.shape[-1] grad_output = grad_output.view(math.prod(grad_output_first_dims), grad_output_last_dim) - total_tensor = total_tensor.view(math.prod(total_tensor_first_dims), total_tensor_last_dim) + total_input = total_input.view(math.prod(total_input_first_dims), total_input_last_dim) # Compute gradients. 
grad_weight = grad_output.T @ total_input grad_input = grad_output @ weight - sub_grad_input = torch.empty( - input_size, dtype=total_input.dtype, device=total_input.device, requires_grad=False - ) - dist.reduce_scatter_tensor(sub_grad_input, grad_input, group=group, op=dist.ReduceOp.SUM) + if group.size() == 1: + sub_grad_input = grad_input + else: + sub_grad_input = torch.empty( + input_size, dtype=total_input.dtype, device=total_input.device, requires_grad=False + ) + dist.reduce_scatter_tensor(sub_grad_input, grad_input, group=group, op=dist.ReduceOp.SUM) grad_bias = torch.sum(grad_output, dim=0) if bias is not None else None return sub_grad_input, grad_weight, grad_bias, None, None - def column_linear( input: torch.Tensor, weight: torch.Tensor, From d3db06acb2e3fe235ae512861ff64ee9fbc9ac11 Mon Sep 17 00:00:00 2001 From: AleHD Date: Tue, 30 Jul 2024 19:16:28 +0200 Subject: [PATCH 28/35] Added tp_recompute_allgather test --- tests/test_tensor_parallel.py | 22 ++++++++++++++++++---- 1 file changed, 18 insertions(+), 4 deletions(-) diff --git a/tests/test_tensor_parallel.py b/tests/test_tensor_parallel.py index f5dcaeb0..8e73973b 100644 --- a/tests/test_tensor_parallel.py +++ b/tests/test_tensor_parallel.py @@ -18,17 +18,30 @@ @pytest.mark.parametrize("tp,dp,pp", [pytest.param(i, 1, 1) for i in range(1, min(4, available_gpus()) + 1)]) @pytest.mark.parametrize("tp_mode", list(TensorParallelLinearMode)) @pytest.mark.parametrize("async_communication", [False, True]) +@pytest.mark.parametrize("tp_recompute_allgather", [False, True]) @rerun_if_address_is_in_use() -def test_column_linear(tp: int, dp: int, pp: int, tp_mode: TensorParallelLinearMode, async_communication: bool): +def test_column_linear( + tp: int, + dp: int, + pp: int, + tp_mode: TensorParallelLinearMode, + async_communication: bool, + tp_recompute_allgather: bool, +): if tp_mode is TensorParallelLinearMode.ALL_REDUCE and async_communication: pytest.skip("ALL_REDUCE mode does not support async communication") + if tp_mode is TensorParallelLinearMode.ALL_REDUCE and tp_recompute_allgather: + pytest.skip("ALL_REDUCE mode is unaffected by tp_recompute_allgather") init_distributed(tp=tp, dp=dp, pp=pp)(_test_column_linear)( - tp_mode=tp_mode, async_communication=async_communication + tp_mode=tp_mode, async_communication=async_communication, tp_recompute_allgather=tp_recompute_allgather ) def _test_column_linear( - parallel_context: ParallelContext, tp_mode: TensorParallelLinearMode, async_communication: bool + parallel_context: ParallelContext, + tp_mode: TensorParallelLinearMode, + async_communication: bool, + tp_recompute_allgather: bool, ): if async_communication: os.environ["CUDA_DEVICE_MAX_CONNECTIONS"] = "1" @@ -44,6 +57,7 @@ def _test_column_linear( mode=tp_mode, device="cuda", async_communication=async_communication, + tp_recompute_allgather=tp_recompute_allgather, ) # Un-sharded @@ -86,7 +100,7 @@ def _test_column_linear( random_input = sharded_random_input else: ValueError(f"Unsupported mode: {tp_mode}") - # It's important that `random_input` and `sharded_random_input` are two seperate tensors with seperate storage + # It's important that `random_input` and `sharded_random_input` are two separate tensors with separate storage sharded_random_input = sharded_random_input.clone() random_input.requires_grad = True sharded_random_input.requires_grad = True From 4c94b99a8cafd5f8dc2b7f208341a17a4d818234 Mon Sep 17 00:00:00 2001 From: AleHD Date: Tue, 30 Jul 2024 19:34:31 +0200 Subject: [PATCH 29/35] Added tp_recompute_allgather test 
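This mirrors the column-linear parametrization added in the previous patch: `test_row_linear` now also covers `tp_recompute_allgather`, skipping the ALL_REDUCE mode since it is unaffected. To exercise just these linear-layer tests locally, something along the following lines should work (a sketch only: it assumes a CUDA host with up to 4 GPUs, a checkout at the repo root, and the test requirements installed; the `-k` selector is illustrative, not taken from the series):

```python
# Hypothetical local runner for the parametrized TP linear tests.
import pytest

exit_code = pytest.main([
    "tests/test_tensor_parallel.py",      # file touched by this patch
    "-k", "column_linear or row_linear",  # select both parametrized tests
    "-v",
])
raise SystemExit(int(exit_code))
```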
--- tests/test_tensor_parallel.py | 23 ++++++++++++++++++++--- 1 file changed, 20 insertions(+), 3 deletions(-) diff --git a/tests/test_tensor_parallel.py b/tests/test_tensor_parallel.py index 8e73973b..16008eaa 100644 --- a/tests/test_tensor_parallel.py +++ b/tests/test_tensor_parallel.py @@ -164,15 +164,32 @@ def _test_column_linear( @pytest.mark.parametrize("tp,dp,pp", [pytest.param(i, 1, 1) for i in range(1, min(4, available_gpus()) + 1)]) @pytest.mark.parametrize("tp_mode", list(TensorParallelLinearMode)) @pytest.mark.parametrize("async_communication", [False, True]) +@pytest.mark.parametrize("tp_recompute_allgather", [False, True]) @rerun_if_address_is_in_use() -def test_row_linear(tp: int, dp: int, pp: int, tp_mode: TensorParallelLinearMode, async_communication: bool): +def test_row_linear( + tp: int, + dp: int, + pp: int, + tp_mode: TensorParallelLinearMode, + async_communication: bool, + tp_recompute_allgather: bool, +): if tp_mode is TensorParallelLinearMode.ALL_REDUCE and async_communication: pytest.skip("ALL_REDUCE mode does not support async communication") + if tp_mode is TensorParallelLinearMode.ALL_REDUCE and tp_recompute_allgather: + pytest.skip("ALL_REDUCE mode is not affected by tp_recompute_allgather") - init_distributed(tp=tp, dp=dp, pp=pp)(_test_row_linear)(tp_mode=tp_mode, async_communication=async_communication) + init_distributed(tp=tp, dp=dp, pp=pp)(_test_row_linear)( + tp_mode=tp_mode, async_communication=async_communication, tp_recompute_allgather=tp_recompute_allgather + ) -def _test_row_linear(parallel_context: ParallelContext, tp_mode: TensorParallelLinearMode, async_communication: bool): +def _test_row_linear( + parallel_context: ParallelContext, + tp_mode: TensorParallelLinearMode, + async_communication: bool, + tp_recompute_allgather: bool, +): if async_communication: os.environ["CUDA_DEVICE_MAX_CONNECTIONS"] = "1" out_features = 3 From 3c3561158eb053176b6f148a2366ea1aa56fdc7d Mon Sep 17 00:00:00 2001 From: Phuc Nguyen Date: Wed, 31 Jul 2024 17:15:34 +0000 Subject: [PATCH 30/35] change to correct config_nanoset.yaml path --- docs/nanoset.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/nanoset.md b/docs/nanoset.md index 9dce21b7..61393438 100644 --- a/docs/nanoset.md +++ b/docs/nanoset.md @@ -79,7 +79,7 @@ To work with `Nanosets`, we just need to configure 1 argument: Finally, to use the `Nanosets`, launch the training with [`run_train.py`](../run_train.py). 
```shell -torchrun --nproc-per-node 8 run_train.py --config configs/config_nanoset.yaml +torchrun --nproc-per-node 1 run_train.py --config examples/config_nanoset.yaml ``` ## Under the hood From 7daa186e84e03fa66fa83129d9b0acffb1a668ba Mon Sep 17 00:00:00 2001 From: AleHD Date: Fri, 2 Aug 2024 15:40:46 +0200 Subject: [PATCH 31/35] Minor restyling --- .../parallel/tensor_parallel/functional.py | 15 ++++++++------- 1 file changed, 8 insertions(+), 7 deletions(-) diff --git a/src/nanotron/parallel/tensor_parallel/functional.py b/src/nanotron/parallel/tensor_parallel/functional.py index 468855a5..1fb86cb5 100644 --- a/src/nanotron/parallel/tensor_parallel/functional.py +++ b/src/nanotron/parallel/tensor_parallel/functional.py @@ -115,7 +115,7 @@ def sharded_cross_entropy(sharded_logits, target, group: dist.ProcessGroup, dtyp return _ShardedCrossEntropy.apply(sharded_logits, target, group) -class _ColumnLinearAsyncCommunication(torch.autograd.Function): +class _ColumnLinearNoAsyncCommunicationReduceScatterMode(torch.autograd.Function): """Adapted from https://github.com/NVIDIA/Megatron-LM/blob/e6d7e09845590d0a36bc7f29eb28db974fb8da4e/megatron/core/tensor_parallel/layers.py#L215""" @staticmethod @@ -408,6 +408,9 @@ def backward(ctx, grad_output: torch.Tensor): if group.size() == 1: sub_grad_input = grad_input else: + # Seems that `reduce_scatter` need contiguous tensors: https://github.com/pytorch/pytorch/blob/2b267fa7f28e18ca6ea1de4201d2541a40411457/torch/distributed/nn/functional.py#L305 + # We set grad_input to be contiguous in case it isn't already. + grad_input = grad_input.contiguous() sub_grad_input = torch.empty( input_size, dtype=total_input.dtype, device=total_input.device, requires_grad=False ) @@ -427,16 +430,14 @@ def column_linear( tp_recompute_allgather: bool = True, ): if async_communication: - return _ColumnLinearAsyncCommunication.apply(input, weight, bias, group, tp_mode) + return _ColumnLinearNoAsyncCommunicationReduceScatterMode.apply(input, weight, bias, group, tp_mode) if tp_mode is TensorParallelLinearMode.ALL_REDUCE: input = differentiable_identity(input, group=group) - elif tp_mode is TensorParallelLinearMode.REDUCE_SCATTER: + return F.linear(input, weight, bias) + if tp_mode is TensorParallelLinearMode.REDUCE_SCATTER: return _ColumnLinearContextParallelNoAsync.apply(input, weight, bias, group, tp_recompute_allgather) - else: - raise ValueError(f"Got unexpected mode: {tp_mode}.") - - return F.linear(input, weight, bias) + raise ValueError(f"Got unexpected mode: {tp_mode}.") class _RowLinearAsyncCommunication(torch.autograd.Function): From 31c3c5ad0a845ff6318842c663223e9621586a3d Mon Sep 17 00:00:00 2001 From: AleHD Date: Fri, 2 Aug 2024 15:54:44 +0200 Subject: [PATCH 32/35] Fixed names --- src/nanotron/parallel/tensor_parallel/functional.py | 10 ++++++---- 1 file changed, 6 insertions(+), 4 deletions(-) diff --git a/src/nanotron/parallel/tensor_parallel/functional.py b/src/nanotron/parallel/tensor_parallel/functional.py index 1fb86cb5..7a88aec6 100644 --- a/src/nanotron/parallel/tensor_parallel/functional.py +++ b/src/nanotron/parallel/tensor_parallel/functional.py @@ -115,7 +115,7 @@ def sharded_cross_entropy(sharded_logits, target, group: dist.ProcessGroup, dtyp return _ShardedCrossEntropy.apply(sharded_logits, target, group) -class _ColumnLinearNoAsyncCommunicationReduceScatterMode(torch.autograd.Function): +class _ColumnLinearAsyncCommunication(torch.autograd.Function): """Adapted from 
https://github.com/NVIDIA/Megatron-LM/blob/e6d7e09845590d0a36bc7f29eb28db974fb8da4e/megatron/core/tensor_parallel/layers.py#L215""" @staticmethod @@ -337,7 +337,7 @@ def backward(ctx, grad_output): raise ValueError(f"Got unexpected mode: {tp_mode}.") -class _ColumnLinearContextParallelNoAsync(torch.autograd.Function): +class _ColumnLinearNoAsyncCommunicationReduceScatterMode(torch.autograd.Function): """ Column linear with memory_buffer for the allgather, context parallel enabled (i.e. tp_mode = TensorParallelLinearMode.REDUCE_SCATTER) and @@ -430,13 +430,15 @@ def column_linear( tp_recompute_allgather: bool = True, ): if async_communication: - return _ColumnLinearNoAsyncCommunicationReduceScatterMode.apply(input, weight, bias, group, tp_mode) + return _ColumnLinearAsyncCommunication.apply(input, weight, bias, group, tp_mode) if tp_mode is TensorParallelLinearMode.ALL_REDUCE: input = differentiable_identity(input, group=group) return F.linear(input, weight, bias) if tp_mode is TensorParallelLinearMode.REDUCE_SCATTER: - return _ColumnLinearContextParallelNoAsync.apply(input, weight, bias, group, tp_recompute_allgather) + return _ColumnLinearNoAsyncCommunicationReduceScatterMode.apply( + input, weight, bias, group, tp_recompute_allgather + ) raise ValueError(f"Got unexpected mode: {tp_mode}.") From 664c09aa48204b8a45756e15fac9cc6bf0b38ccf Mon Sep 17 00:00:00 2001 From: tj-solergibert Date: Wed, 4 Sep 2024 10:38:32 +0000 Subject: [PATCH 33/35] lr scheduler resume training with PP fix --- src/nanotron/serialize/main.py | 1 + src/nanotron/serialize/optimizer.py | 12 +++++------- src/nanotron/trainer.py | 9 +++++---- 3 files changed, 11 insertions(+), 11 deletions(-) diff --git a/src/nanotron/serialize/main.py b/src/nanotron/serialize/main.py index 286008ac..346ad573 100644 --- a/src/nanotron/serialize/main.py +++ b/src/nanotron/serialize/main.py @@ -236,6 +236,7 @@ def load( load_optimizer(optimizer=optimizer, parallel_context=parallel_context, root_folder=root_folder) load_lr_scheduler( lr_scheduler=lr_scheduler, + parallel_context=parallel_context, root_folder=root_folder, ) return checkpoint_metadata diff --git a/src/nanotron/serialize/optimizer.py b/src/nanotron/serialize/optimizer.py index 68a3b1a0..f11210da 100644 --- a/src/nanotron/serialize/optimizer.py +++ b/src/nanotron/serialize/optimizer.py @@ -30,9 +30,9 @@ def optimizer_filename(parallel_context: ParallelContext, is_zero: bool): return f"{ObjectType.OPTIMIZER.value}_pp-{dist.get_rank(parallel_context.pp_pg)}-of-{parallel_context.pp_pg.size()}_tp-{dist.get_rank(parallel_context.tp_pg)}-of-{parallel_context.tp_pg.size()}_exp-{dist.get_rank(parallel_context.expert_pg)}-of-{parallel_context.expert_parallel_size}.pt" -def lr_scheduler_filename(): +def lr_scheduler_filename(parallel_context: ParallelContext): """The lr_scheduler is the same for all processes.""" - return f"{ObjectType.LR_SCHEDULER.value}.pt" + return f"{ObjectType.LR_SCHEDULER.value}_pp-{dist.get_rank(parallel_context.pp_pg)}-of-{parallel_context.pp_pg.size()}_tp-{dist.get_rank(parallel_context.tp_pg)}-of-{parallel_context.tp_pg.size()}_exp-{dist.get_rank(parallel_context.expert_pg)}-of-{parallel_context.expert_parallel_size}.pt" def save_optimizer( @@ -109,9 +109,6 @@ def save_lr_scheduler( root_folder: Path, ): """Saves lr scheduler states""" - if dist.get_rank(parallel_context.world_pg) > 0: - # Only WORLD-RANK 0 saves the lr scheduler state - return root_folder = root_folder / "lr_scheduler" root_folder.mkdir(exist_ok=True, parents=True) @@ -119,7 +116,7 @@ def 
save_lr_scheduler( # We dump the optimizer state using `torch.save` torch.save( lr_scheduler.state_dict(), - root_folder / lr_scheduler_filename(), + root_folder / lr_scheduler_filename(parallel_context), ) @@ -313,9 +310,10 @@ def get_checkpoint_state_metadata(param_name: str, pp_rank: int, tp_rank: int) - def load_lr_scheduler( lr_scheduler, + parallel_context: ParallelContext, root_folder: Path, ): root_folder = root_folder / "lr_scheduler" - state_dict = torch.load(root_folder / lr_scheduler_filename()) + state_dict = torch.load(root_folder / lr_scheduler_filename(parallel_context)) lr_scheduler.load_state_dict(state_dict) diff --git a/src/nanotron/trainer.py b/src/nanotron/trainer.py index 214ea52f..21251a32 100644 --- a/src/nanotron/trainer.py +++ b/src/nanotron/trainer.py @@ -206,6 +206,7 @@ def __init__( if self.init_checkpoint_path is not None: load_lr_scheduler( lr_scheduler=self.lr_scheduler, + parallel_context=self.parallel_context, root_folder=self.init_checkpoint_path, ) @@ -443,10 +444,10 @@ def train( self.save_checkpoint() dist.barrier() # let's wait for everyone before leaving - + if self.config.checkpoints.save_final_state: self.save_checkpoint() - + self.post_training() def training_step( @@ -865,8 +866,8 @@ def save_checkpoint(self) -> Path: ), # We only save the weights on DP==0 should_save_optimizer=True, should_save_lr_scheduler=bool( - dist.get_rank(self.parallel_context.world_pg) == 0 - ), # We only save the lr_scheduler on world_rank==0 + dist.get_rank(self.parallel_context.dp_pg) == 0 + ), # We only save the lr_scheduler on DP==0 should_save_config=bool( dist.get_rank(self.parallel_context.world_pg) == 0 ), # We only save the config on world_rank==0 From 761d253c21d101216fcf396ced495f88ad01bdb9 Mon Sep 17 00:00:00 2001 From: Kyle Matoba <22180455+kylematoba@users.noreply.github.com> Date: Sat, 14 Sep 2024 22:45:47 +0200 Subject: [PATCH 34/35] work --- lion_pytorch/__init__.py | 1 + lion_pytorch/foreach.py | 95 ++++++++++++++++++++++++++++++++++ lion_pytorch/lion_pytorch.py | 97 +++++++++++++++++++++++++++++++++++ lion_pytorch/triton.py | 98 ++++++++++++++++++++++++++++++++++++ src/nanotron/helpers.py | 14 ++++-- 5 files changed, 300 insertions(+), 5 deletions(-) create mode 100644 lion_pytorch/__init__.py create mode 100644 lion_pytorch/foreach.py create mode 100644 lion_pytorch/lion_pytorch.py create mode 100644 lion_pytorch/triton.py diff --git a/lion_pytorch/__init__.py b/lion_pytorch/__init__.py new file mode 100644 index 00000000..b3a7799d --- /dev/null +++ b/lion_pytorch/__init__.py @@ -0,0 +1 @@ +from lion_pytorch.lion_pytorch import Lion diff --git a/lion_pytorch/foreach.py b/lion_pytorch/foreach.py new file mode 100644 index 00000000..50d60518 --- /dev/null +++ b/lion_pytorch/foreach.py @@ -0,0 +1,95 @@ +from __future__ import annotations +from typing import Tuple, Callable + +import torch +from torch.optim.optimizer import Optimizer + +# functions + +def exists(val): + return val is not None + +# class + +class Lion(Optimizer): + def __init__( + self, + params, + lr: float = 1e-4, + betas: Tuple[float, float] = (0.9, 0.99), + weight_decay: float = 0.0, + decoupled_weight_decay: bool = False + ): + assert lr > 0. + assert all([0. <= beta <= 1. 
for beta in betas]) + assert all([hasattr(torch, f'_foreach_{attr}_') for attr in ('mul', 'add', 'sign', 'lerp')]), 'this version of torch does not have the prerequisite foreach functions' + + self._init_lr = lr + self.decoupled_wd = decoupled_weight_decay + + defaults = dict( + lr = lr, + betas = betas, + weight_decay = weight_decay + ) + + super().__init__(params, defaults) + + @torch.no_grad() + def step( + self, + closure: Callable | None = None + ): + + loss = None + if exists(closure): + with torch.enable_grad(): + loss = closure() + + for group in self.param_groups: + + lr, wd, beta1, beta2, decoupled_wd, init_lr = group['lr'], group['weight_decay'], *group['betas'], self.decoupled_wd, self._init_lr + + # maybe decoupled weight decay + + if decoupled_wd: + wd /= init_lr + + # accumulate List[Tensor] for foreach inplace updates + + params = [] + grads = [] + exp_avgs = [] + + for p in filter(lambda p: exists(p.grad), group['params']): + + grad, state = p.grad, self.state[p] + + # init state - exponential moving average of gradient values + + if len(state) == 0: + state['exp_avg'] = torch.zeros_like(p) + + exp_avg = state['exp_avg'] + + params.append(p) + grads.append(grad) + exp_avgs.append(exp_avg) + + # stepweight decay + + torch._foreach_mul_(params, 1. - lr * wd) + + # weight update + + updates = [t.clone() for t in exp_avgs] + torch._foreach_lerp_(updates, grads, 1. - beta1) + torch._foreach_sign_(updates) + + torch._foreach_add_(params, updates, alpha = -lr) + + # decay momentum running average + + torch._foreach_lerp_(exp_avgs, grads, 1. - beta2) + + return loss diff --git a/lion_pytorch/lion_pytorch.py b/lion_pytorch/lion_pytorch.py new file mode 100644 index 00000000..b0d3a3f8 --- /dev/null +++ b/lion_pytorch/lion_pytorch.py @@ -0,0 +1,97 @@ +from __future__ import annotations +from typing import Tuple, Callable, Union + +import torch +from torch.optim.optimizer import Optimizer + + +def exists(val): + return val is not None + + +def update_fn(p, grad, exp_avg, lr, wd, beta1, beta2): + # stepweight decay + p.data.mul_(1. - lr * wd) + + # weight update + update = exp_avg.clone().mul_(beta1).add(grad, alpha=1.0 - beta1).sign_() + p.add_(update, alpha=-lr) + + # decay the momentum running average coefficient + exp_avg.mul_(beta2).add_(grad, alpha=1.0 - beta2) + + +class Lion(Optimizer): + def __init__( + self, + params, + lr: float = 1e-4, + betas: Tuple[float, float] = (0.9, 0.99), + weight_decay: float = 0.0, + use_triton: bool = False, + decoupled_weight_decay: bool = False, + ): + assert lr > 0. + assert all([0. <= beta <= 1. 
for beta in betas]) + + self._init_lr = lr + self.decoupled_wd = decoupled_weight_decay + + defaults = dict( + lr=lr, + betas=betas, + weight_decay=weight_decay + ) + + super().__init__(params, defaults) + self.update_fn = update_fn + + if use_triton: + from lion_pytorch.triton import update_fn as triton_update_fn + self.update_fn = triton_update_fn + + @torch.no_grad() + def step( + self, + closure: Union[Callable, None] = None + ): + + loss = None + if exists(closure): + with torch.enable_grad(): + loss = closure() + + for group in self.param_groups: + for p in filter(lambda p: exists(p.grad), group['params']): + + # grad, lr, wd, beta1, beta2, state, decoupled_wd, init_lr = p.grad, group['lr'], group['weight_decay'], *group['betas'], self.state[p], self.decoupled_wd, self._init_lr + grad = p.grad + lr = group['lr'] + wd = group['weight_decay'] + beta1, beta2 = group['betas'] + state= self.state[p] + decoupled_wd = self.decoupled_wd + init_lr = self._init_lr + + # maybe decoupled weight decay + + if decoupled_wd: + wd /= init_lr + + # init state - exponential moving average of gradient values + if len(state) == 0: + state['exp_avg'] = torch.zeros_like(p) + + exp_avg = state['exp_avg'] + + self.update_fn( + p, + grad, + exp_avg, + lr, + wd, + beta1, + beta2 + ) + + return loss diff --git a/lion_pytorch/triton.py b/lion_pytorch/triton.py new file mode 100644 index 00000000..1dd4696b --- /dev/null +++ b/lion_pytorch/triton.py @@ -0,0 +1,98 @@ +import torch + +try: + import triton + import triton.language as tl +except ImportError as e: + print('triton is not installed, please install by running `pip install triton>=2.2.0`') + exit() + +# triton cuda kernel + +@triton.autotune(configs = [ + triton.Config({'BLOCK_SIZE': 128}, num_warps = 4), + triton.Config({'BLOCK_SIZE': 1024}, num_warps = 8), +], key = ['n_elements'], restore_value=['p_ptr', 'exp_avg_ptr']) +@triton.jit +def update_fn_kernel( + p_ptr, + grad_ptr, + exp_avg_ptr, + lr, + wd, + beta1, + beta2, + n_elements, + BLOCK_SIZE: tl.constexpr, +): + pid = tl.program_id(axis = 0) + + block_start = pid * BLOCK_SIZE + offsets = block_start + tl.arange(0, BLOCK_SIZE) + + mask = offsets < n_elements + + # offsetted pointers + + offset_p_ptr = p_ptr + offsets + offset_grad_ptr = grad_ptr + offsets + offset_exp_avg_ptr = exp_avg_ptr + offsets + + # load + + p = tl.load(offset_p_ptr, mask = mask) + grad = tl.load(offset_grad_ptr, mask = mask) + exp_avg = tl.load(offset_exp_avg_ptr, mask = mask) + + # stepweight decay + + p = p * (1 - lr * wd) + + # diff between momentum running average and grad + + diff = exp_avg - grad + + # weight update + + update = diff * beta1 + grad + + # torch.sign + + can_update = update != 0 + update_sign = tl.where(update > 0, -lr, lr) + + p = p + update_sign * can_update + + # decay the momentum running average coefficient + + exp_avg = diff * beta2 + grad + + # store new params and momentum running average coefficient + + tl.store(offset_p_ptr, p, mask = mask) + tl.store(offset_exp_avg_ptr, exp_avg, mask = mask) + +def update_fn( + p: torch.Tensor, + grad: torch.Tensor, + exp_avg: torch.Tensor, + lr: float, + wd: float, + beta1: float, + beta2: float +): + assert all([t.is_cuda for t in (p, grad, exp_avg)]) + n_elements = p.numel() + + grid = lambda meta: (triton.cdiv(n_elements, meta['BLOCK_SIZE']),) + + update_fn_kernel[grid]( + p, + grad, + exp_avg, + lr, + wd, + beta1, + beta2, + n_elements + ) diff --git a/src/nanotron/helpers.py b/src/nanotron/helpers.py index a82f0294..548c352b 100644 --- 
a/src/nanotron/helpers.py +++ b/src/nanotron/helpers.py @@ -44,6 +44,8 @@ from nanotron.scaling.parametrization import LearningRateForSP, LearningRateForSpectralMup, ParametrizationMethod from nanotron.serialize.metadata import TrainingMetadata +from lion_pytorch import Lion + logger = logging.get_logger(__name__) @@ -327,9 +329,7 @@ def init_optimizer_and_grad_accumulator( # Basic optimizer builder def basic_optimizer_builder(named_param_groups): optimizer = None - if optimizer_args.optimizer_factory.name == "adamW": - def optimizer(param_groups): return torch.optim.AdamW( param_groups, @@ -339,16 +339,20 @@ def optimizer(param_groups): betas=(optimizer_args.optimizer_factory.adam_beta1, optimizer_args.optimizer_factory.adam_beta2), fused=optimizer_args.optimizer_factory.torch_adam_is_fused, ) - elif optimizer_args.optimizer_factory.name == "sgd": - def optimizer(param_groups): return torch.optim.SGD( param_groups, lr=optimizer_args.learning_rate_scheduler.learning_rate, weight_decay=optimizer_args.weight_decay, ) - + elif optimizer_args.optimizer_factory.name == "lion": + def optimizer(param_groups): + return Lion( + param_groups, + lr=optimizer_args.learning_rate_scheduler.learning_rate, + weight_decay=optimizer_args.weight_decay, + ) else: raise ValueError(f"Optimizer {optimizer_args.optimizer_factory.name} is not supported") From 7b7ead9a4b5378a2a286f59e326a9fa5e942182e Mon Sep 17 00:00:00 2001 From: Kyle Matoba <22180455+kylematoba@users.noreply.github.com> Date: Sat, 14 Sep 2024 22:58:38 +0200 Subject: [PATCH 35/35] Revert "work" This reverts commit 761d253c21d101216fcf396ced495f88ad01bdb9. --- lion_pytorch/__init__.py | 1 - lion_pytorch/foreach.py | 95 ---------------------------------- lion_pytorch/lion_pytorch.py | 97 ----------------------------------- lion_pytorch/triton.py | 98 ------------------------------------ src/nanotron/helpers.py | 14 ++---- 5 files changed, 5 insertions(+), 300 deletions(-) delete mode 100644 lion_pytorch/__init__.py delete mode 100644 lion_pytorch/foreach.py delete mode 100644 lion_pytorch/lion_pytorch.py delete mode 100644 lion_pytorch/triton.py diff --git a/lion_pytorch/__init__.py b/lion_pytorch/__init__.py deleted file mode 100644 index b3a7799d..00000000 --- a/lion_pytorch/__init__.py +++ /dev/null @@ -1 +0,0 @@ -from lion_pytorch.lion_pytorch import Lion diff --git a/lion_pytorch/foreach.py b/lion_pytorch/foreach.py deleted file mode 100644 index 50d60518..00000000 --- a/lion_pytorch/foreach.py +++ /dev/null @@ -1,95 +0,0 @@ -from __future__ import annotations -from typing import Tuple, Callable - -import torch -from torch.optim.optimizer import Optimizer - -# functions - -def exists(val): - return val is not None - -# class - -class Lion(Optimizer): - def __init__( - self, - params, - lr: float = 1e-4, - betas: Tuple[float, float] = (0.9, 0.99), - weight_decay: float = 0.0, - decoupled_weight_decay: bool = False - ): - assert lr > 0. - assert all([0. <= beta <= 1. 
for beta in betas]) - assert all([hasattr(torch, f'_foreach_{attr}_') for attr in ('mul', 'add', 'sign', 'lerp')]), 'this version of torch does not have the prerequisite foreach functions' - - self._init_lr = lr - self.decoupled_wd = decoupled_weight_decay - - defaults = dict( - lr = lr, - betas = betas, - weight_decay = weight_decay - ) - - super().__init__(params, defaults) - - @torch.no_grad() - def step( - self, - closure: Callable | None = None - ): - - loss = None - if exists(closure): - with torch.enable_grad(): - loss = closure() - - for group in self.param_groups: - - lr, wd, beta1, beta2, decoupled_wd, init_lr = group['lr'], group['weight_decay'], *group['betas'], self.decoupled_wd, self._init_lr - - # maybe decoupled weight decay - - if decoupled_wd: - wd /= init_lr - - # accumulate List[Tensor] for foreach inplace updates - - params = [] - grads = [] - exp_avgs = [] - - for p in filter(lambda p: exists(p.grad), group['params']): - - grad, state = p.grad, self.state[p] - - # init state - exponential moving average of gradient values - - if len(state) == 0: - state['exp_avg'] = torch.zeros_like(p) - - exp_avg = state['exp_avg'] - - params.append(p) - grads.append(grad) - exp_avgs.append(exp_avg) - - # stepweight decay - - torch._foreach_mul_(params, 1. - lr * wd) - - # weight update - - updates = [t.clone() for t in exp_avgs] - torch._foreach_lerp_(updates, grads, 1. - beta1) - torch._foreach_sign_(updates) - - torch._foreach_add_(params, updates, alpha = -lr) - - # decay momentum running average - - torch._foreach_lerp_(exp_avgs, grads, 1. - beta2) - - return loss diff --git a/lion_pytorch/lion_pytorch.py b/lion_pytorch/lion_pytorch.py deleted file mode 100644 index b0d3a3f8..00000000 --- a/lion_pytorch/lion_pytorch.py +++ /dev/null @@ -1,97 +0,0 @@ -from __future__ import annotations -from typing import Tuple, Callable, Union - -import torch -from torch.optim.optimizer import Optimizer - - -def exists(val): - return val is not None - - -def update_fn(p, grad, exp_avg, lr, wd, beta1, beta2): - # stepweight decay - p.data.mul_(1. - lr * wd) - - # weight update - update = exp_avg.clone().mul_(beta1).add(grad, alpha=1.0 - beta1).sign_() - p.add_(update, alpha=-lr) - - # decay the momentum running average coefficient - exp_avg.mul_(beta2).add_(grad, alpha=1.0 - beta2) - - -class Lion(Optimizer): - def __init__( - self, - params, - lr: float = 1e-4, - betas: Tuple[float, float] = (0.9, 0.99), - weight_decay: float = 0.0, - use_triton: bool = False, - decoupled_weight_decay: bool = False, - ): - assert lr > 0. - assert all([0. <= beta <= 1. 
for beta in betas]) - - self._init_lr = lr - self.decoupled_wd = decoupled_weight_decay - - defaults = dict( - lr=lr, - betas=betas, - weight_decay=weight_decay - ) - - super().__init__(params, defaults) - self.update_fn = update_fn - - if use_triton: - from lion_pytorch.triton import update_fn as triton_update_fn - self.update_fn = triton_update_fn - - @torch.no_grad() - def step( - self, - closure: Union[Callable, None] = None - ): - - loss = None - if exists(closure): - with torch.enable_grad(): - loss = closure() - - for group in self.param_groups: - for p in filter(lambda p: exists(p.grad), group['params']): - - # grad, lr, wd, beta1, beta2, state, decoupled_wd, init_lr = p.grad, group['lr'], group['weight_decay'], *group['betas'], self.state[p], self.decoupled_wd, self._init_lr - grad = p.grad - lr = group['lr'] - wd = group['weight_decay'] - beta1, beta2 = group['betas'] - state= self.state[p] - decoupled_wd = self.decoupled_wd - init_lr = self._init_lr - - # maybe decoupled weight decay - - if decoupled_wd: - wd /= init_lr - - # init state - exponential moving average of gradient values - if len(state) == 0: - state['exp_avg'] = torch.zeros_like(p) - - exp_avg = state['exp_avg'] - - self.update_fn( - p, - grad, - exp_avg, - lr, - wd, - beta1, - beta2 - ) - - return loss diff --git a/lion_pytorch/triton.py b/lion_pytorch/triton.py deleted file mode 100644 index 1dd4696b..00000000 --- a/lion_pytorch/triton.py +++ /dev/null @@ -1,98 +0,0 @@ -import torch - -try: - import triton - import triton.language as tl -except ImportError as e: - print('triton is not installed, please install by running `pip install triton>=2.2.0`') - exit() - -# triton cuda kernel - -@triton.autotune(configs = [ - triton.Config({'BLOCK_SIZE': 128}, num_warps = 4), - triton.Config({'BLOCK_SIZE': 1024}, num_warps = 8), -], key = ['n_elements'], restore_value=['p_ptr', 'exp_avg_ptr']) -@triton.jit -def update_fn_kernel( - p_ptr, - grad_ptr, - exp_avg_ptr, - lr, - wd, - beta1, - beta2, - n_elements, - BLOCK_SIZE: tl.constexpr, -): - pid = tl.program_id(axis = 0) - - block_start = pid * BLOCK_SIZE - offsets = block_start + tl.arange(0, BLOCK_SIZE) - - mask = offsets < n_elements - - # offsetted pointers - - offset_p_ptr = p_ptr + offsets - offset_grad_ptr = grad_ptr + offsets - offset_exp_avg_ptr = exp_avg_ptr + offsets - - # load - - p = tl.load(offset_p_ptr, mask = mask) - grad = tl.load(offset_grad_ptr, mask = mask) - exp_avg = tl.load(offset_exp_avg_ptr, mask = mask) - - # stepweight decay - - p = p * (1 - lr * wd) - - # diff between momentum running average and grad - - diff = exp_avg - grad - - # weight update - - update = diff * beta1 + grad - - # torch.sign - - can_update = update != 0 - update_sign = tl.where(update > 0, -lr, lr) - - p = p + update_sign * can_update - - # decay the momentum running average coefficient - - exp_avg = diff * beta2 + grad - - # store new params and momentum running average coefficient - - tl.store(offset_p_ptr, p, mask = mask) - tl.store(offset_exp_avg_ptr, exp_avg, mask = mask) - -def update_fn( - p: torch.Tensor, - grad: torch.Tensor, - exp_avg: torch.Tensor, - lr: float, - wd: float, - beta1: float, - beta2: float -): - assert all([t.is_cuda for t in (p, grad, exp_avg)]) - n_elements = p.numel() - - grid = lambda meta: (triton.cdiv(n_elements, meta['BLOCK_SIZE']),) - - update_fn_kernel[grid]( - p, - grad, - exp_avg, - lr, - wd, - beta1, - beta2, - n_elements - ) diff --git a/src/nanotron/helpers.py b/src/nanotron/helpers.py index 548c352b..a82f0294 100644 --- 
a/src/nanotron/helpers.py +++ b/src/nanotron/helpers.py @@ -44,8 +44,6 @@ from nanotron.scaling.parametrization import LearningRateForSP, LearningRateForSpectralMup, ParametrizationMethod from nanotron.serialize.metadata import TrainingMetadata -from lion_pytorch import Lion - logger = logging.get_logger(__name__) @@ -329,7 +327,9 @@ def init_optimizer_and_grad_accumulator( # Basic optimizer builder def basic_optimizer_builder(named_param_groups): optimizer = None + if optimizer_args.optimizer_factory.name == "adamW": + def optimizer(param_groups): return torch.optim.AdamW( param_groups, @@ -339,20 +339,16 @@ def optimizer(param_groups): betas=(optimizer_args.optimizer_factory.adam_beta1, optimizer_args.optimizer_factory.adam_beta2), fused=optimizer_args.optimizer_factory.torch_adam_is_fused, ) + elif optimizer_args.optimizer_factory.name == "sgd": + def optimizer(param_groups): return torch.optim.SGD( param_groups, lr=optimizer_args.learning_rate_scheduler.learning_rate, weight_decay=optimizer_args.weight_decay, ) - elif optimizer_args.optimizer_factory.name == "lion": - def optimizer(param_groups): - return Lion( - param_groups, - lr=optimizer_args.learning_rate_scheduler.learning_rate, - weight_decay=optimizer_args.weight_decay, - ) + else: raise ValueError(f"Optimizer {optimizer_args.optimizer_factory.name} is not supported")
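For reference on the last two patches, a minimal usage sketch of the vendored `Lion` optimizer as defined in `lion_pytorch/lion_pytorch.py` above. The toy model, data, and hyperparameters are placeholders, and the import only resolves while the `lion_pytorch/` package from PATCH 34 is checked out, since PATCH 35 removes it again:

```python
# Sketch only: exercise the vendored Lion optimizer on a throwaway model.
import torch
from torch import nn

from lion_pytorch import Lion  # package added in PATCH 34, removed by PATCH 35

model = nn.Linear(16, 4)
optimizer = Lion(model.parameters(), lr=1e-4, betas=(0.9, 0.99), weight_decay=0.1)

x, y = torch.randn(8, 16), torch.randn(8, 4)
for _ in range(3):
    optimizer.zero_grad()
    loss = nn.functional.mse_loss(model(x), y)
    loss.backward()
    # update_fn: multiplicative weight decay, then a sign() step on the
    # beta1-interpolation of the momentum buffer and the current gradient.
    optimizer.step()
```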