From b1872e1ed8542d9aee8e659f2336289a89588062 Mon Sep 17 00:00:00 2001 From: Angel Gonzalez Date: Tue, 7 May 2024 11:35:59 +0200 Subject: [PATCH 01/44] Adding checkpoint after traning ends --- src/nanotron/config/config.py | 1 + src/nanotron/trainer.py | 5 ++++- 2 files changed, 5 insertions(+), 1 deletion(-) diff --git a/src/nanotron/config/config.py b/src/nanotron/config/config.py index d9946f26..e26fac75 100644 --- a/src/nanotron/config/config.py +++ b/src/nanotron/config/config.py @@ -129,6 +129,7 @@ class CheckpointsArgs: checkpoints_path: Path checkpoint_interval: int save_initial_state: Optional[bool] = False + save_final_state: Optional[bool] = False resume_checkpoint_path: Optional[Path] = None checkpoints_path_is_shared_file_system: Optional[bool] = False diff --git a/src/nanotron/trainer.py b/src/nanotron/trainer.py index 0eda00dc..70d023fb 100644 --- a/src/nanotron/trainer.py +++ b/src/nanotron/trainer.py @@ -442,7 +442,10 @@ def train( self.save_checkpoint() dist.barrier() # let's wait for everyone before leaving - + + if self.config.checkpoints.save_final_state: + self.save_checkpoint() + self.post_training() def training_step( From bcf405d9af2028773d6d76cd4ff658540b87a3f1 Mon Sep 17 00:00:00 2001 From: AleHD Date: Thu, 27 Jun 2024 11:56:53 +0000 Subject: [PATCH 02/44] Implemented global memory buffer to reduce activation memory of differentiable distributed operations --- src/nanotron/config/config.py | 2 +- .../distributed_differentiable_primitives.py | 17 +++-------------- src/nanotron/parallel/utils.py | 19 +++++++++++++++++++ src/nanotron/utils.py | 17 +++++++++++++++++ 4 files changed, 40 insertions(+), 15 deletions(-) diff --git a/src/nanotron/config/config.py b/src/nanotron/config/config.py index d5b9976f..d72ea97f 100644 --- a/src/nanotron/config/config.py +++ b/src/nanotron/config/config.py @@ -111,7 +111,7 @@ def __post_init__(self): class DataArgs: """Arguments related to the data and data files processing""" - dataset: Union[PretrainDatasetsArgs, NanosetDatasetsArgs] + dataset: Optional[Union[PretrainDatasetsArgs, NanosetDatasetsArgs]] seed: Optional[int] num_loading_workers: Optional[int] = 1 diff --git a/src/nanotron/parallel/tensor_parallel/distributed_differentiable_primitives.py b/src/nanotron/parallel/tensor_parallel/distributed_differentiable_primitives.py index 873d77df..57a67c42 100644 --- a/src/nanotron/parallel/tensor_parallel/distributed_differentiable_primitives.py +++ b/src/nanotron/parallel/tensor_parallel/distributed_differentiable_primitives.py @@ -19,6 +19,7 @@ from nanotron import distributed as dist from nanotron.distributed import ProcessGroup +from nanotron.parallel.utils import MemoryBuffer class DifferentiableIdentity(torch.autograd.Function): @@ -67,13 +68,7 @@ def forward(ctx, tensor, group: Optional[ProcessGroup]): group = torch_dist.distributed_c10d._get_default_group() unsharded_batch_size = sharded_batch_size * group.size() - unsharded_tensor = torch.empty( - unsharded_batch_size, - *rest_size, - device=tensor.device, - dtype=tensor.dtype, - requires_grad=tensor.requires_grad, - ) + unsharded_tensor = MemoryBuffer().get("dist", (unsharded_batch_size, *rest_size), dtype=tensor.dtype) # `tensor` can sometimes not be contiguous # https://cs.github.com/pytorch/pytorch/blob/2b267fa7f28e18ca6ea1de4201d2541a40411457/torch/distributed/nn/functional.py#L317 @@ -108,13 +103,7 @@ def forward(ctx, tensor, group: Optional[ProcessGroup]): # 
https://cs.github.com/pytorch/pytorch/blob/2b267fa7f28e18ca6ea1de4201d2541a40411457/torch/distributed/nn/functional.py#L305 tensor = tensor.contiguous() - sharded_tensor = torch.empty( - unsharded_batch_size // group.size(), - *rest_size, - device=tensor.device, - dtype=tensor.dtype, - requires_grad=tensor.requires_grad, - ) + sharded_tensor = MemoryBuffer().get("dist", (unsharded_batch_size//group.size(), *rest_size), dtype=tensor.dtype) dist.reduce_scatter_tensor(sharded_tensor, tensor, group=group, op=dist.ReduceOp.SUM) return sharded_tensor diff --git a/src/nanotron/parallel/utils.py b/src/nanotron/parallel/utils.py index b9ac12ae..eb4e441d 100644 --- a/src/nanotron/parallel/utils.py +++ b/src/nanotron/parallel/utils.py @@ -1,13 +1,32 @@ import functools +import operator import os +import torch from torch import nn from nanotron import distributed as dist +from nanotron.utils import Singleton from nanotron.parallel import ParallelContext from nanotron.parallel.tied_parameters import get_tied_id_to_param +class MemoryBuffer(metaclass=Singleton): + """ + Global memory buffer to store intermediate activations that need not to be cached for the backward pass. + """ + + def __init__(self): + self.buffer = {} + + def get(self, name: str, shape: tuple[int], dtype: torch.dtype = torch.bfloat16) -> torch.Tensor: + required_numel = functools.reduce(operator.mul, shape, 1) + if (name, dtype) not in self.buffer or self.buffer[name, dtype].numel() < required_numel: + self.buffer[name, dtype] = torch.empty(required_numel, dtype=dtype, device=torch.cuda.current_device(), + requires_grad=False) + return self.buffer[name, dtype][:required_numel].view(shape) + + def assert_cuda_max_connections_set_to_1(func): flag_is_set_to_1 = None diff --git a/src/nanotron/utils.py b/src/nanotron/utils.py index 14fe1ca8..8065962b 100644 --- a/src/nanotron/utils.py +++ b/src/nanotron/utils.py @@ -15,6 +15,23 @@ from nanotron import distributed as dist +class Singleton(type): + """ + Singleton metaclass. + Create objects using this class as the metaclass to enable singleton behaviour. + For instance: + ``` + class Logger(metaclass=Singleton): + ... + ``` + """ + _instances = {} + def __call__(cls, *args, **kwargs): + if cls not in cls._instances: + cls._instances[cls] = super(Singleton, cls).__call__(*args, **kwargs) + return cls._instances[cls] + + class ContextManagers: """ Wrapper for `contextlib.ExitStack` which enters a collection of context managers. Adaptation of `ContextManagers` From ed1ca7d0b55b07696d1c622d713c793f3ca53e28 Mon Sep 17 00:00:00 2001 From: AleHD Date: Thu, 27 Jun 2024 14:29:14 +0000 Subject: [PATCH 03/44] GLU fusion --- src/nanotron/models/llama.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/src/nanotron/models/llama.py b/src/nanotron/models/llama.py index ca8894b9..d310fe2a 100644 --- a/src/nanotron/models/llama.py +++ b/src/nanotron/models/llama.py @@ -163,8 +163,7 @@ def __init__( bias=False, async_communication=tp_linear_async_communication and tp_mode is TensorParallelLinearMode.REDUCE_SCATTER, ) - # TODO @nouamane: why can't we torch.jit.script GLUActivation? 
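A minimal usage sketch of the MemoryBuffer/Singleton pair introduced in PATCH 02 above (the shapes are illustrative, and a CUDA device is assumed because get() allocates on torch.cuda.current_device()): since the Singleton metaclass makes every MemoryBuffer() call return the same object, repeated get() calls under the same (name, dtype) key reuse one flat allocation that only grows when a larger numel is requested.

    import torch
    from nanotron.parallel.utils import MemoryBuffer

    # Singleton metaclass: every construction returns the same per-process instance.
    assert MemoryBuffer() is MemoryBuffer()

    # First request allocates a flat bfloat16 buffer large enough for 4*1024*8192 elements ...
    x = MemoryBuffer().get("allgather", (4, 1024, 8192), dtype=torch.bfloat16)
    # ... later requests with the same key and an equal-or-smaller numel return a view of that storage.
    y = MemoryBuffer().get("allgather", (2, 1024, 8192), dtype=torch.bfloat16)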
- self.split_silu_mul = GLUActivation(config.hidden_act) + self.split_silu_mul = torch.compile(GLUActivation(config.hidden_act)) def forward(self, hidden_states): # [seq_length, batch_size, hidden_dim] merged_states = self.gate_up_proj(hidden_states) From 9b0de5be04afb9cac631399593aef8de6aa852a6 Mon Sep 17 00:00:00 2001 From: AleHD Date: Thu, 27 Jun 2024 14:42:42 +0000 Subject: [PATCH 04/44] precommit --- src/nanotron/models/llama.py | 2 +- .../distributed_differentiable_primitives.py | 4 +++- src/nanotron/parallel/utils.py | 7 ++++--- src/nanotron/utils.py | 8 +++++--- 4 files changed, 13 insertions(+), 8 deletions(-) diff --git a/src/nanotron/models/llama.py b/src/nanotron/models/llama.py index d310fe2a..3319b0ef 100644 --- a/src/nanotron/models/llama.py +++ b/src/nanotron/models/llama.py @@ -14,7 +14,7 @@ # limitations under the License. """PyTorch LLaMa model.""" -from typing import Dict, Optional, Union, List +from typing import Dict, Optional, Union import torch from torch import nn diff --git a/src/nanotron/parallel/tensor_parallel/distributed_differentiable_primitives.py b/src/nanotron/parallel/tensor_parallel/distributed_differentiable_primitives.py index 57a67c42..aa460cc6 100644 --- a/src/nanotron/parallel/tensor_parallel/distributed_differentiable_primitives.py +++ b/src/nanotron/parallel/tensor_parallel/distributed_differentiable_primitives.py @@ -103,7 +103,9 @@ def forward(ctx, tensor, group: Optional[ProcessGroup]): # https://cs.github.com/pytorch/pytorch/blob/2b267fa7f28e18ca6ea1de4201d2541a40411457/torch/distributed/nn/functional.py#L305 tensor = tensor.contiguous() - sharded_tensor = MemoryBuffer().get("dist", (unsharded_batch_size//group.size(), *rest_size), dtype=tensor.dtype) + sharded_tensor = MemoryBuffer().get( + "dist", (unsharded_batch_size // group.size(), *rest_size), dtype=tensor.dtype + ) dist.reduce_scatter_tensor(sharded_tensor, tensor, group=group, op=dist.ReduceOp.SUM) return sharded_tensor diff --git a/src/nanotron/parallel/utils.py b/src/nanotron/parallel/utils.py index eb4e441d..f694b0e6 100644 --- a/src/nanotron/parallel/utils.py +++ b/src/nanotron/parallel/utils.py @@ -6,9 +6,9 @@ from torch import nn from nanotron import distributed as dist -from nanotron.utils import Singleton from nanotron.parallel import ParallelContext from nanotron.parallel.tied_parameters import get_tied_id_to_param +from nanotron.utils import Singleton class MemoryBuffer(metaclass=Singleton): @@ -22,8 +22,9 @@ def __init__(self): def get(self, name: str, shape: tuple[int], dtype: torch.dtype = torch.bfloat16) -> torch.Tensor: required_numel = functools.reduce(operator.mul, shape, 1) if (name, dtype) not in self.buffer or self.buffer[name, dtype].numel() < required_numel: - self.buffer[name, dtype] = torch.empty(required_numel, dtype=dtype, device=torch.cuda.current_device(), - requires_grad=False) + self.buffer[name, dtype] = torch.empty( + required_numel, dtype=dtype, device=torch.cuda.current_device(), requires_grad=False + ) return self.buffer[name, dtype][:required_numel].view(shape) diff --git a/src/nanotron/utils.py b/src/nanotron/utils.py index 8065962b..b3831801 100644 --- a/src/nanotron/utils.py +++ b/src/nanotron/utils.py @@ -1,11 +1,10 @@ import functools import inspect -import math import os import random import socket from contextlib import ExitStack, contextmanager -from typing import Callable, ContextManager, List, Optional +from typing import ContextManager, List, Optional import torch from packaging import version @@ -25,7 +24,9 @@ class 
Logger(metaclass=Singleton): ... ``` """ + _instances = {} + def __call__(cls, *args, **kwargs): if cls not in cls._instances: cls._instances[cls] = super(Singleton, cls).__call__(*args, **kwargs) @@ -69,7 +70,7 @@ def main_rank_first(group: dist.ProcessGroup): @contextmanager def local_ranks_zero_first(group: Optional[dist.ProcessGroup] = None): """Context manager that executes the code in the context with all the local rank zero of the group going first. - Usefull to run only once per node first (e.g. to create local files, etc) + Useful to run only once per node first (e.g. to create local files, etc) """ is_main = int(os.environ.get("LOCAL_RANK", 0)) == 0 if is_main: @@ -140,6 +141,7 @@ def get_untyped_storage(tensor: torch.Tensor) -> torch.UntypedStorage: else: return tensor.storage().untyped() + def tensor_from_untyped_storage(untyped_storage: torch.UntypedStorage, dtype: torch.dtype): # TODO @thomasw21: Figure out what's the best Pytorch way of building a tensor from a storage. device = untyped_storage.device From 803b6da3233a642a0ba7a62484310d1496db81dc Mon Sep 17 00:00:00 2001 From: AleHD Date: Tue, 16 Jul 2024 11:39:32 +0200 Subject: [PATCH 05/44] Wrong backward fixed --- .../parallel/tensor_parallel/column_linear.py | 62 +++++++++++++++++++ .../distributed_differentiable_primitives.py | 27 +++++--- .../parallel/tensor_parallel/functional.py | 4 +- 3 files changed, 85 insertions(+), 8 deletions(-) create mode 100644 src/nanotron/parallel/tensor_parallel/column_linear.py diff --git a/src/nanotron/parallel/tensor_parallel/column_linear.py b/src/nanotron/parallel/tensor_parallel/column_linear.py new file mode 100644 index 00000000..eaab5abe --- /dev/null +++ b/src/nanotron/parallel/tensor_parallel/column_linear.py @@ -0,0 +1,62 @@ +from typing import Optional + +import torch +from torch.nn import functional as F + +import nanotron.distributed as dist +from nanotron.parallel.utils import MemoryBuffer + + +class ColumnLinearContextParallel(torch.autograd.Function): + """ + Column linear with memory_buffer for the allgather, context parallel + enabled (i.e. tp_mode = TensorParallelLinearMode.REDUCE_SCATTER) and + async communication disabled. + """ + @staticmethod + def forward(ctx, input: torch.Tensor, weight: torch.Tensor, + bias: Optional[torch.Tensor], group: dist.ProcessGroup): + + # Prepare context. + ctx.save_for_backward(input, weight, bias) + ctx.group = group + + # Do allgather. + sharded_batch_size, *rest_size = input.shape + unsharded_batch_size = sharded_batch_size * group.size() + total_input = MemoryBuffer().get("allgather", (unsharded_batch_size, *rest_size), dtype=input.dtype) + dist.all_gather_into_tensor(total_input, input.contiguous(), group=group) + + # Get linear output. + out = F.linear(total_input, weight, bias) + return out + + @staticmethod + def backward(ctx, grad_output: torch.Tensor): + # Allgather the inputs again. + input, weight, bias = ctx.saved_tensors + group = ctx.group + sharded_batch_size, *rest_size = input.shape + total_input = sharded_batch_size * group.size() + unsharded_batch_size = sharded_batch_size * group.size() + total_input = MemoryBuffer().get("allgather", (unsharded_batch_size, *rest_size), dtype=input.dtype) + dist.all_gather_into_tensor(total_input, input.contiguous(), group=group) + + # Get the grad_output and total_input on the correct views to be able to transpose them below. 
+ grad_output = grad_output.contiguous() + assert grad_output.dim() == 3 + grad_output = grad_output.view(grad_output.size(0) * grad_output.size(1), grad_output.size(2)) + total_input = total_input.view(total_input.size(0) * total_input.size(1), total_input.size(2)) + + # Compute gradients. + grad_input = grad_output @ weight + sub_grad_input = torch.empty(input.size(), dtype=input.dtype, device=input.device, requires_grad=False) + dist.reduce_scatter_tensor(sub_grad_input, grad_input, group=group, op=dist.ReduceOp.SUM) + grad_weight = grad_output.T @ total_input + grad_bias = torch.sum(grad_output, dim=0) if bias is not None else None + + return sub_grad_input, grad_weight, grad_bias, None + +def column_linear_context_parallel(input: torch.Tensor, weight: torch.Tensor, bias: Optional[torch.Tensor], + group: dist.ProcessGroup): + return ColumnLinearContextParallel.apply(input, weight, bias, group) diff --git a/src/nanotron/parallel/tensor_parallel/distributed_differentiable_primitives.py b/src/nanotron/parallel/tensor_parallel/distributed_differentiable_primitives.py index aa460cc6..d66826e3 100644 --- a/src/nanotron/parallel/tensor_parallel/distributed_differentiable_primitives.py +++ b/src/nanotron/parallel/tensor_parallel/distributed_differentiable_primitives.py @@ -19,7 +19,6 @@ from nanotron import distributed as dist from nanotron.distributed import ProcessGroup -from nanotron.parallel.utils import MemoryBuffer class DifferentiableIdentity(torch.autograd.Function): @@ -68,7 +67,13 @@ def forward(ctx, tensor, group: Optional[ProcessGroup]): group = torch_dist.distributed_c10d._get_default_group() unsharded_batch_size = sharded_batch_size * group.size() - unsharded_tensor = MemoryBuffer().get("dist", (unsharded_batch_size, *rest_size), dtype=tensor.dtype) + unsharded_tensor = torch.empty( + unsharded_batch_size, + *rest_size, + device=tensor.device, + dtype=tensor.dtype, + requires_grad=tensor.requires_grad, + ) # `tensor` can sometimes not be contiguous # https://cs.github.com/pytorch/pytorch/blob/2b267fa7f28e18ca6ea1de4201d2541a40411457/torch/distributed/nn/functional.py#L317 @@ -79,8 +84,11 @@ def forward(ctx, tensor, group: Optional[ProcessGroup]): @staticmethod def backward(ctx, grad_output): + #print(f"{torch.distributed.get_rank()} grad_output: {grad_output}") group = ctx.group - return DifferentiableReduceScatterSum.apply(grad_output, group), None + out = DifferentiableReduceScatterSum.apply(grad_output, group) + #print(f"{torch.distributed.get_rank()} grad_grad: {out}") + return out, None, None class DifferentiableReduceScatterSum(torch.autograd.Function): @@ -103,8 +111,12 @@ def forward(ctx, tensor, group: Optional[ProcessGroup]): # https://cs.github.com/pytorch/pytorch/blob/2b267fa7f28e18ca6ea1de4201d2541a40411457/torch/distributed/nn/functional.py#L305 tensor = tensor.contiguous() - sharded_tensor = MemoryBuffer().get( - "dist", (unsharded_batch_size // group.size(), *rest_size), dtype=tensor.dtype + sharded_tensor = torch.empty( + unsharded_batch_size // group.size(), + *rest_size, + device=tensor.device, + dtype=tensor.dtype, + requires_grad=False, ) dist.reduce_scatter_tensor(sharded_tensor, tensor, group=group, op=dist.ReduceOp.SUM) return sharded_tensor @@ -112,7 +124,8 @@ def forward(ctx, tensor, group: Optional[ProcessGroup]): @staticmethod def backward(ctx, grad_output): group = ctx.group - return DifferentiableAllGather.apply(grad_output, group), None + #print(f"{torch.distributed.get_rank()} Calling AllGather because of backward of reducescatter") + return 
DifferentiableAllGather.apply(grad_output, group, False), None # ----------------- @@ -128,7 +141,7 @@ def differentiable_all_reduce_sum(tensor, group: Optional[ProcessGroup] = None): return DifferentiableAllReduceSum.apply(tensor, group) -def differentiable_all_gather(tensor, group: Optional[ProcessGroup] = None): +def differentiable_all_gather(tensor, group: Optional[ProcessGroup] = None) return DifferentiableAllGather.apply(tensor, group) diff --git a/src/nanotron/parallel/tensor_parallel/functional.py b/src/nanotron/parallel/tensor_parallel/functional.py index fdef48ac..b3602707 100644 --- a/src/nanotron/parallel/tensor_parallel/functional.py +++ b/src/nanotron/parallel/tensor_parallel/functional.py @@ -26,6 +26,7 @@ differentiable_reduce_scatter_sum, ) from nanotron.parallel.tensor_parallel.enum import TensorParallelLinearMode +from nanotron.parallel.tensor_parallel.column_linear import column_linear_context_parallel from nanotron.parallel.utils import assert_cuda_max_connections_set_to_1 @@ -352,7 +353,7 @@ def column_linear( if tp_mode is TensorParallelLinearMode.ALL_REDUCE: input = differentiable_identity(input, group=group) elif tp_mode is TensorParallelLinearMode.REDUCE_SCATTER: - input = differentiable_all_gather(input, group=group) + return column_linear_context_parallel(input, weight, bias, group) else: raise ValueError(f"Got unexpected mode: {tp_mode}.") @@ -473,6 +474,7 @@ def row_linear( out = F.linear(input, weight, bias) + #print("Calling row linear") if tp_mode is TensorParallelLinearMode.ALL_REDUCE: out = differentiable_all_reduce_sum(out, group=group) elif tp_mode is TensorParallelLinearMode.REDUCE_SCATTER: From 59bfb6b38c4487b04e425d8540b6e44b2a7fbcf9 Mon Sep 17 00:00:00 2001 From: AleHD Date: Tue, 16 Jul 2024 11:42:27 +0200 Subject: [PATCH 06/44] Removed useless prints --- .../tensor_parallel/distributed_differentiable_primitives.py | 3 --- src/nanotron/parallel/tensor_parallel/functional.py | 1 - 2 files changed, 4 deletions(-) diff --git a/src/nanotron/parallel/tensor_parallel/distributed_differentiable_primitives.py b/src/nanotron/parallel/tensor_parallel/distributed_differentiable_primitives.py index d66826e3..f1102908 100644 --- a/src/nanotron/parallel/tensor_parallel/distributed_differentiable_primitives.py +++ b/src/nanotron/parallel/tensor_parallel/distributed_differentiable_primitives.py @@ -84,10 +84,8 @@ def forward(ctx, tensor, group: Optional[ProcessGroup]): @staticmethod def backward(ctx, grad_output): - #print(f"{torch.distributed.get_rank()} grad_output: {grad_output}") group = ctx.group out = DifferentiableReduceScatterSum.apply(grad_output, group) - #print(f"{torch.distributed.get_rank()} grad_grad: {out}") return out, None, None @@ -124,7 +122,6 @@ def forward(ctx, tensor, group: Optional[ProcessGroup]): @staticmethod def backward(ctx, grad_output): group = ctx.group - #print(f"{torch.distributed.get_rank()} Calling AllGather because of backward of reducescatter") return DifferentiableAllGather.apply(grad_output, group, False), None diff --git a/src/nanotron/parallel/tensor_parallel/functional.py b/src/nanotron/parallel/tensor_parallel/functional.py index b3602707..cedbb219 100644 --- a/src/nanotron/parallel/tensor_parallel/functional.py +++ b/src/nanotron/parallel/tensor_parallel/functional.py @@ -474,7 +474,6 @@ def row_linear( out = F.linear(input, weight, bias) - #print("Calling row linear") if tp_mode is TensorParallelLinearMode.ALL_REDUCE: out = differentiable_all_reduce_sum(out, group=group) elif tp_mode is 
TensorParallelLinearMode.REDUCE_SCATTER: From 2c69e9ad2887b7e78c88c2db3209713542dad7e2 Mon Sep 17 00:00:00 2001 From: AleHD Date: Wed, 17 Jul 2024 10:01:44 +0000 Subject: [PATCH 07/44] Minor fixes --- .../distributed_differentiable_primitives.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/src/nanotron/parallel/tensor_parallel/distributed_differentiable_primitives.py b/src/nanotron/parallel/tensor_parallel/distributed_differentiable_primitives.py index f1102908..bd41347a 100644 --- a/src/nanotron/parallel/tensor_parallel/distributed_differentiable_primitives.py +++ b/src/nanotron/parallel/tensor_parallel/distributed_differentiable_primitives.py @@ -86,7 +86,7 @@ def forward(ctx, tensor, group: Optional[ProcessGroup]): def backward(ctx, grad_output): group = ctx.group out = DifferentiableReduceScatterSum.apply(grad_output, group) - return out, None, None + return out, None class DifferentiableReduceScatterSum(torch.autograd.Function): @@ -122,7 +122,7 @@ def forward(ctx, tensor, group: Optional[ProcessGroup]): @staticmethod def backward(ctx, grad_output): group = ctx.group - return DifferentiableAllGather.apply(grad_output, group, False), None + return DifferentiableAllGather.apply(grad_output, group), None # ----------------- @@ -138,7 +138,7 @@ def differentiable_all_reduce_sum(tensor, group: Optional[ProcessGroup] = None): return DifferentiableAllReduceSum.apply(tensor, group) -def differentiable_all_gather(tensor, group: Optional[ProcessGroup] = None) +def differentiable_all_gather(tensor, group: Optional[ProcessGroup] = None): return DifferentiableAllGather.apply(tensor, group) From 30439fdee7cac456be4a2c28798b42c931f7cf72 Mon Sep 17 00:00:00 2001 From: AleHD Date: Wed, 17 Jul 2024 11:16:40 +0000 Subject: [PATCH 08/44] precommit --- .../parallel/tensor_parallel/column_linear.py | 12 ++++++++---- src/nanotron/parallel/tensor_parallel/functional.py | 7 +++---- 2 files changed, 11 insertions(+), 8 deletions(-) diff --git a/src/nanotron/parallel/tensor_parallel/column_linear.py b/src/nanotron/parallel/tensor_parallel/column_linear.py index eaab5abe..21daba36 100644 --- a/src/nanotron/parallel/tensor_parallel/column_linear.py +++ b/src/nanotron/parallel/tensor_parallel/column_linear.py @@ -13,9 +13,11 @@ class ColumnLinearContextParallel(torch.autograd.Function): enabled (i.e. tp_mode = TensorParallelLinearMode.REDUCE_SCATTER) and async communication disabled. """ + @staticmethod - def forward(ctx, input: torch.Tensor, weight: torch.Tensor, - bias: Optional[torch.Tensor], group: dist.ProcessGroup): + def forward( + ctx, input: torch.Tensor, weight: torch.Tensor, bias: Optional[torch.Tensor], group: dist.ProcessGroup + ): # Prepare context. 
ctx.save_for_backward(input, weight, bias) @@ -57,6 +59,8 @@ def backward(ctx, grad_output: torch.Tensor): return sub_grad_input, grad_weight, grad_bias, None -def column_linear_context_parallel(input: torch.Tensor, weight: torch.Tensor, bias: Optional[torch.Tensor], - group: dist.ProcessGroup): + +def column_linear_context_parallel( + input: torch.Tensor, weight: torch.Tensor, bias: Optional[torch.Tensor], group: dist.ProcessGroup +): return ColumnLinearContextParallel.apply(input, weight, bias, group) diff --git a/src/nanotron/parallel/tensor_parallel/functional.py b/src/nanotron/parallel/tensor_parallel/functional.py index cedbb219..f4e9de30 100644 --- a/src/nanotron/parallel/tensor_parallel/functional.py +++ b/src/nanotron/parallel/tensor_parallel/functional.py @@ -19,14 +19,13 @@ from torch.nn import functional as F import nanotron.distributed as dist +from nanotron.parallel.tensor_parallel.column_linear import column_linear_context_parallel from nanotron.parallel.tensor_parallel.distributed_differentiable_primitives import ( - differentiable_all_gather, differentiable_all_reduce_sum, differentiable_identity, differentiable_reduce_scatter_sum, ) from nanotron.parallel.tensor_parallel.enum import TensorParallelLinearMode -from nanotron.parallel.tensor_parallel.column_linear import column_linear_context_parallel from nanotron.parallel.utils import assert_cuda_max_connections_set_to_1 @@ -90,10 +89,10 @@ def forward( @staticmethod def backward(ctx, grad_output): - # Retreive tensors from the forward path. + # Retrieve tensors from the forward path. softmax, target_mask, masked_target_1d = ctx.saved_tensors - # All the inputs have softmax as thier gradient. + # All the inputs have softmax as their gradient. grad_input = softmax # For simplicity, work with the 2D gradient. sharded_hidden_size = softmax.size()[-1] From 1e02a9ce9c9b564f4a4274ee62e7208e3d5d9df8 Mon Sep 17 00:00:00 2001 From: AleHD Date: Wed, 17 Jul 2024 12:47:20 +0000 Subject: [PATCH 09/44] Added tp_recompute_allgather option --- src/nanotron/config/parallelism_config.py | 2 + src/nanotron/models/llama.py | 3 ++ .../parallel/tensor_parallel/column_linear.py | 51 ++++++++++++------- .../parallel/tensor_parallel/functional.py | 4 +- src/nanotron/parallel/tensor_parallel/nn.py | 3 ++ 5 files changed, 44 insertions(+), 19 deletions(-) diff --git a/src/nanotron/config/parallelism_config.py b/src/nanotron/config/parallelism_config.py index 5912425b..e9a6f2a4 100644 --- a/src/nanotron/config/parallelism_config.py +++ b/src/nanotron/config/parallelism_config.py @@ -32,6 +32,8 @@ class ParallelismArgs: tp_mode: Optional[TensorParallelLinearMode] = None tp_linear_async_communication: Optional[bool] = None + tp_recompute_allgather: bool = False + expert_parallel_size: int = 1 def __post_init__(self): diff --git a/src/nanotron/models/llama.py b/src/nanotron/models/llama.py index 3319b0ef..a31ebec6 100644 --- a/src/nanotron/models/llama.py +++ b/src/nanotron/models/llama.py @@ -154,6 +154,7 @@ def __init__( bias=False, async_communication=tp_linear_async_communication, contiguous_chunks=gate_up_contiguous_chunks, + tp_recompute_allgather=parallel_config.tp_recompute_allgather, ) self.down_proj = TensorParallelRowLinear( config.intermediate_size, @@ -314,6 +315,7 @@ def __init__( bias=False, async_communication=tp_linear_async_communication, contiguous_chunks=qkv_contiguous_chunks, + tp_recompute_allgather=parallel_config.tp_recompute_allgather, ) # TODO(kunhao): We want to have only one version per device and not one version per layer. 
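The tp_recompute_allgather flag added to ParallelismArgs above is what toggles this behaviour from a training config. A hedged Python sketch of setting it (only tp_mode, tp_linear_async_communication, tp_recompute_allgather and expert_parallel_size are visible in this hunk; the dp/pp/tp fields are assumed from the rest of ParallelismArgs):

    from nanotron.config.parallelism_config import ParallelismArgs
    from nanotron.parallel.tensor_parallel.enum import TensorParallelLinearMode

    parallel_config = ParallelismArgs(
        dp=2,  # assumed field: data-parallel size
        pp=1,  # assumed field: pipeline-parallel size
        tp=4,  # assumed field: tensor-parallel size
        tp_mode=TensorParallelLinearMode.REDUCE_SCATTER,
        tp_linear_async_communication=False,
        # Re-run the all-gather in backward instead of caching the gathered input,
        # trading one extra collective for lower activation memory.
        tp_recompute_allgather=True,
    )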
self.rotary_embedding = RotaryEmbedding( @@ -742,6 +744,7 @@ def __init__( # TODO @thomasw21: refactor so that we store that default in a single place. "mode": self.tp_mode, "async_communication": tp_linear_async_communication, + "tp_recompute_allgather": parallel_config.tp_recompute_allgather, }, module_input_keys={"x"}, module_output_keys={"logits"}, diff --git a/src/nanotron/parallel/tensor_parallel/column_linear.py b/src/nanotron/parallel/tensor_parallel/column_linear.py index 21daba36..2f743199 100644 --- a/src/nanotron/parallel/tensor_parallel/column_linear.py +++ b/src/nanotron/parallel/tensor_parallel/column_linear.py @@ -16,33 +16,47 @@ class ColumnLinearContextParallel(torch.autograd.Function): @staticmethod def forward( - ctx, input: torch.Tensor, weight: torch.Tensor, bias: Optional[torch.Tensor], group: dist.ProcessGroup + ctx, input: torch.Tensor, weight: torch.Tensor, bias: Optional[torch.Tensor], group: dist.ProcessGroup, + tp_recompute_allgather: bool ): - # Prepare context. - ctx.save_for_backward(input, weight, bias) - ctx.group = group - # Do allgather. sharded_batch_size, *rest_size = input.shape unsharded_batch_size = sharded_batch_size * group.size() - total_input = MemoryBuffer().get("allgather", (unsharded_batch_size, *rest_size), dtype=input.dtype) + if tp_recompute_allgather: + total_input = MemoryBuffer().get("allgather", (unsharded_batch_size, *rest_size), dtype=input.dtype) + else: + total_input = torch.empty(unsharded_batch_size, *rest_size, dtype=input.dtype, device=input.device) dist.all_gather_into_tensor(total_input, input.contiguous(), group=group) + # Prepare context. + ctx.group = group + ctx.tp_recompute_allgather = tp_recompute_allgather + ctx.input_size = input.shape + if tp_recompute_allgather: + ctx.save_for_backward(input, weight, bias) + else: + ctx.save_for_backward(total_input, weight, bias) + # Get linear output. out = F.linear(total_input, weight, bias) return out @staticmethod def backward(ctx, grad_output: torch.Tensor): - # Allgather the inputs again. - input, weight, bias = ctx.saved_tensors + # Either allgather the inputs again or get them from context. group = ctx.group - sharded_batch_size, *rest_size = input.shape - total_input = sharded_batch_size * group.size() - unsharded_batch_size = sharded_batch_size * group.size() - total_input = MemoryBuffer().get("allgather", (unsharded_batch_size, *rest_size), dtype=input.dtype) - dist.all_gather_into_tensor(total_input, input.contiguous(), group=group) + tp_recompute_allgather = ctx.tp_recompute_allgather + input_size = ctx.input_size + if tp_recompute_allgather: + input, weight, bias = ctx.saved_tensors + sharded_batch_size, *rest_size = input.shape + total_input = sharded_batch_size * group.size() + unsharded_batch_size = sharded_batch_size * group.size() + total_input = MemoryBuffer().get("allgather", (unsharded_batch_size, *rest_size), dtype=input.dtype) + dist.all_gather_into_tensor(total_input, input.contiguous(), group=group) + else: + total_input, weight, bias = ctx.saved_tensors # Get the grad_output and total_input on the correct views to be able to transpose them below. grad_output = grad_output.contiguous() @@ -51,16 +65,17 @@ def backward(ctx, grad_output: torch.Tensor): total_input = total_input.view(total_input.size(0) * total_input.size(1), total_input.size(2)) # Compute gradients. 
+ grad_weight = grad_output.T @ total_input grad_input = grad_output @ weight - sub_grad_input = torch.empty(input.size(), dtype=input.dtype, device=input.device, requires_grad=False) + sub_grad_input = torch.empty(input_size, dtype=total_input.dtype, device=total_input.device, requires_grad=False) dist.reduce_scatter_tensor(sub_grad_input, grad_input, group=group, op=dist.ReduceOp.SUM) - grad_weight = grad_output.T @ total_input grad_bias = torch.sum(grad_output, dim=0) if bias is not None else None - return sub_grad_input, grad_weight, grad_bias, None + return sub_grad_input, grad_weight, grad_bias, None, None def column_linear_context_parallel( - input: torch.Tensor, weight: torch.Tensor, bias: Optional[torch.Tensor], group: dist.ProcessGroup + input: torch.Tensor, weight: torch.Tensor, bias: Optional[torch.Tensor], group: dist.ProcessGroup, + tp_recompute_allgather: bool = False ): - return ColumnLinearContextParallel.apply(input, weight, bias, group) + return ColumnLinearContextParallel.apply(input, weight, bias, group, tp_recompute_allgather) diff --git a/src/nanotron/parallel/tensor_parallel/functional.py b/src/nanotron/parallel/tensor_parallel/functional.py index f4e9de30..c16ae492 100644 --- a/src/nanotron/parallel/tensor_parallel/functional.py +++ b/src/nanotron/parallel/tensor_parallel/functional.py @@ -21,6 +21,7 @@ import nanotron.distributed as dist from nanotron.parallel.tensor_parallel.column_linear import column_linear_context_parallel from nanotron.parallel.tensor_parallel.distributed_differentiable_primitives import ( + differentiable_all_gather, differentiable_all_reduce_sum, differentiable_identity, differentiable_reduce_scatter_sum, @@ -345,6 +346,7 @@ def column_linear( group: dist.ProcessGroup, tp_mode: TensorParallelLinearMode, async_communication: bool, + tp_recompute_allgather: bool = True ): if async_communication: return _ColumnLinearAsyncCommunication.apply(input, weight, bias, group, tp_mode) @@ -352,7 +354,7 @@ def column_linear( if tp_mode is TensorParallelLinearMode.ALL_REDUCE: input = differentiable_identity(input, group=group) elif tp_mode is TensorParallelLinearMode.REDUCE_SCATTER: - return column_linear_context_parallel(input, weight, bias, group) + return column_linear_context_parallel(input, weight, bias, group, tp_recompute_allgather) else: raise ValueError(f"Got unexpected mode: {tp_mode}.") diff --git a/src/nanotron/parallel/tensor_parallel/nn.py b/src/nanotron/parallel/tensor_parallel/nn.py index 40e89968..42ffc828 100644 --- a/src/nanotron/parallel/tensor_parallel/nn.py +++ b/src/nanotron/parallel/tensor_parallel/nn.py @@ -51,6 +51,7 @@ def __init__( dtype=None, async_communication: bool = False, contiguous_chunks: Optional[Tuple[int, ...]] = None, + tp_recompute_allgather: bool = False, ): self.pg = pg self.world_size = pg.size() @@ -59,6 +60,7 @@ def __init__( self.in_features = in_features self.out_features = out_features // self.world_size + self.tp_recompute_allgather = tp_recompute_allgather super().__init__( in_features=self.in_features, @@ -91,6 +93,7 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: group=self.pg, tp_mode=self.mode, async_communication=self.async_communication, + tp_recompute_allgather=self.tp_recompute_allgather, ) def extra_repr(self) -> str: From 9cc81bb6fe680b72cf6114f7258af0483886ada1 Mon Sep 17 00:00:00 2001 From: AleHD Date: Wed, 17 Jul 2024 13:39:14 +0000 Subject: [PATCH 10/44] Changed recompute default --- src/nanotron/config/parallelism_config.py | 2 +- .../parallel/tensor_parallel/column_linear.py | 19 
++++++++++++++----- .../parallel/tensor_parallel/functional.py | 3 +-- 3 files changed, 16 insertions(+), 8 deletions(-) diff --git a/src/nanotron/config/parallelism_config.py b/src/nanotron/config/parallelism_config.py index e9a6f2a4..cc5d406a 100644 --- a/src/nanotron/config/parallelism_config.py +++ b/src/nanotron/config/parallelism_config.py @@ -32,7 +32,7 @@ class ParallelismArgs: tp_mode: Optional[TensorParallelLinearMode] = None tp_linear_async_communication: Optional[bool] = None - tp_recompute_allgather: bool = False + tp_recompute_allgather: bool = True expert_parallel_size: int = 1 diff --git a/src/nanotron/parallel/tensor_parallel/column_linear.py b/src/nanotron/parallel/tensor_parallel/column_linear.py index 2f743199..880d5ff0 100644 --- a/src/nanotron/parallel/tensor_parallel/column_linear.py +++ b/src/nanotron/parallel/tensor_parallel/column_linear.py @@ -16,8 +16,12 @@ class ColumnLinearContextParallel(torch.autograd.Function): @staticmethod def forward( - ctx, input: torch.Tensor, weight: torch.Tensor, bias: Optional[torch.Tensor], group: dist.ProcessGroup, - tp_recompute_allgather: bool + ctx, + input: torch.Tensor, + weight: torch.Tensor, + bias: Optional[torch.Tensor], + group: dist.ProcessGroup, + tp_recompute_allgather: bool, ): # Do allgather. @@ -67,7 +71,9 @@ def backward(ctx, grad_output: torch.Tensor): # Compute gradients. grad_weight = grad_output.T @ total_input grad_input = grad_output @ weight - sub_grad_input = torch.empty(input_size, dtype=total_input.dtype, device=total_input.device, requires_grad=False) + sub_grad_input = torch.empty( + input_size, dtype=total_input.dtype, device=total_input.device, requires_grad=False + ) dist.reduce_scatter_tensor(sub_grad_input, grad_input, group=group, op=dist.ReduceOp.SUM) grad_bias = torch.sum(grad_output, dim=0) if bias is not None else None @@ -75,7 +81,10 @@ def backward(ctx, grad_output: torch.Tensor): def column_linear_context_parallel( - input: torch.Tensor, weight: torch.Tensor, bias: Optional[torch.Tensor], group: dist.ProcessGroup, - tp_recompute_allgather: bool = False + input: torch.Tensor, + weight: torch.Tensor, + bias: Optional[torch.Tensor], + group: dist.ProcessGroup, + tp_recompute_allgather: bool = True, ): return ColumnLinearContextParallel.apply(input, weight, bias, group, tp_recompute_allgather) diff --git a/src/nanotron/parallel/tensor_parallel/functional.py b/src/nanotron/parallel/tensor_parallel/functional.py index c16ae492..454cc447 100644 --- a/src/nanotron/parallel/tensor_parallel/functional.py +++ b/src/nanotron/parallel/tensor_parallel/functional.py @@ -21,7 +21,6 @@ import nanotron.distributed as dist from nanotron.parallel.tensor_parallel.column_linear import column_linear_context_parallel from nanotron.parallel.tensor_parallel.distributed_differentiable_primitives import ( - differentiable_all_gather, differentiable_all_reduce_sum, differentiable_identity, differentiable_reduce_scatter_sum, @@ -346,7 +345,7 @@ def column_linear( group: dist.ProcessGroup, tp_mode: TensorParallelLinearMode, async_communication: bool, - tp_recompute_allgather: bool = True + tp_recompute_allgather: bool = True, ): if async_communication: return _ColumnLinearAsyncCommunication.apply(input, weight, bias, group, tp_mode) From 956fbfd09a2f0c2358fcb90be395f97ffa79632e Mon Sep 17 00:00:00 2001 From: AleHD Date: Wed, 17 Jul 2024 13:40:34 +0000 Subject: [PATCH 11/44] Changed recompute default --- src/nanotron/parallel/tensor_parallel/nn.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git 
a/src/nanotron/parallel/tensor_parallel/nn.py b/src/nanotron/parallel/tensor_parallel/nn.py index 42ffc828..4c7325cd 100644 --- a/src/nanotron/parallel/tensor_parallel/nn.py +++ b/src/nanotron/parallel/tensor_parallel/nn.py @@ -51,7 +51,7 @@ def __init__( dtype=None, async_communication: bool = False, contiguous_chunks: Optional[Tuple[int, ...]] = None, - tp_recompute_allgather: bool = False, + tp_recompute_allgather: bool = True, ): self.pg = pg self.world_size = pg.size() From 9992f1c5919fd4038e85cd9b3fb1dd4faa81daf1 Mon Sep 17 00:00:00 2001 From: tj-solergibert Date: Wed, 17 Jul 2024 14:28:57 +0000 Subject: [PATCH 12/44] Little fixes --- src/nanotron/config/config.py | 4 ++-- tools/preprocess_data.py | 1 + 2 files changed, 3 insertions(+), 2 deletions(-) diff --git a/src/nanotron/config/config.py b/src/nanotron/config/config.py index fe194883..2e1a98cc 100644 --- a/src/nanotron/config/config.py +++ b/src/nanotron/config/config.py @@ -96,10 +96,10 @@ class NanosetDatasetsArgs: dataset_folder: Union[str, dict, List[str]] def __post_init__(self): - if isinstance(self.dataset_folder, str): # Case 1: 1 Dataset file + if isinstance(self.dataset_folder, str): # Case 1: 1 Dataset folder self.dataset_folder = [self.dataset_folder] self.dataset_weights = [1] - elif isinstance(self.dataset_folder, List): # Case 2: > 1 Dataset file + elif isinstance(self.dataset_folder, List): # Case 2: > 1 Dataset folder self.dataset_weights = None # Set to None so we consume all the samples randomly elif isinstance(self.dataset_folder, dict): # Case 3: dict with > 1 dataset_folder and weights tmp_dataset_folder = self.dataset_folder.copy() diff --git a/tools/preprocess_data.py b/tools/preprocess_data.py index 38db67f1..f3cdab70 100644 --- a/tools/preprocess_data.py +++ b/tools/preprocess_data.py @@ -95,6 +95,7 @@ def main(args): output_folder=args.output_folder, tokenizer_name_or_path=args.tokenizer_name_or_path, eos_token=args.eos_token, + shuffle=False, max_tokens_per_file=1e9, ), ], From b9e92017614e0326acab86b7665e0c8e8718bfc3 Mon Sep 17 00:00:00 2001 From: AleHD Date: Wed, 17 Jul 2024 15:21:04 +0000 Subject: [PATCH 13/44] Moved ColumnLinearNoAsync module for consistency --- .../parallel/tensor_parallel/column_linear.py | 90 ------------------- .../parallel/tensor_parallel/functional.py | 78 +++++++++++++++- 2 files changed, 76 insertions(+), 92 deletions(-) delete mode 100644 src/nanotron/parallel/tensor_parallel/column_linear.py diff --git a/src/nanotron/parallel/tensor_parallel/column_linear.py b/src/nanotron/parallel/tensor_parallel/column_linear.py deleted file mode 100644 index 880d5ff0..00000000 --- a/src/nanotron/parallel/tensor_parallel/column_linear.py +++ /dev/null @@ -1,90 +0,0 @@ -from typing import Optional - -import torch -from torch.nn import functional as F - -import nanotron.distributed as dist -from nanotron.parallel.utils import MemoryBuffer - - -class ColumnLinearContextParallel(torch.autograd.Function): - """ - Column linear with memory_buffer for the allgather, context parallel - enabled (i.e. tp_mode = TensorParallelLinearMode.REDUCE_SCATTER) and - async communication disabled. - """ - - @staticmethod - def forward( - ctx, - input: torch.Tensor, - weight: torch.Tensor, - bias: Optional[torch.Tensor], - group: dist.ProcessGroup, - tp_recompute_allgather: bool, - ): - - # Do allgather. 
- sharded_batch_size, *rest_size = input.shape - unsharded_batch_size = sharded_batch_size * group.size() - if tp_recompute_allgather: - total_input = MemoryBuffer().get("allgather", (unsharded_batch_size, *rest_size), dtype=input.dtype) - else: - total_input = torch.empty(unsharded_batch_size, *rest_size, dtype=input.dtype, device=input.device) - dist.all_gather_into_tensor(total_input, input.contiguous(), group=group) - - # Prepare context. - ctx.group = group - ctx.tp_recompute_allgather = tp_recompute_allgather - ctx.input_size = input.shape - if tp_recompute_allgather: - ctx.save_for_backward(input, weight, bias) - else: - ctx.save_for_backward(total_input, weight, bias) - - # Get linear output. - out = F.linear(total_input, weight, bias) - return out - - @staticmethod - def backward(ctx, grad_output: torch.Tensor): - # Either allgather the inputs again or get them from context. - group = ctx.group - tp_recompute_allgather = ctx.tp_recompute_allgather - input_size = ctx.input_size - if tp_recompute_allgather: - input, weight, bias = ctx.saved_tensors - sharded_batch_size, *rest_size = input.shape - total_input = sharded_batch_size * group.size() - unsharded_batch_size = sharded_batch_size * group.size() - total_input = MemoryBuffer().get("allgather", (unsharded_batch_size, *rest_size), dtype=input.dtype) - dist.all_gather_into_tensor(total_input, input.contiguous(), group=group) - else: - total_input, weight, bias = ctx.saved_tensors - - # Get the grad_output and total_input on the correct views to be able to transpose them below. - grad_output = grad_output.contiguous() - assert grad_output.dim() == 3 - grad_output = grad_output.view(grad_output.size(0) * grad_output.size(1), grad_output.size(2)) - total_input = total_input.view(total_input.size(0) * total_input.size(1), total_input.size(2)) - - # Compute gradients. - grad_weight = grad_output.T @ total_input - grad_input = grad_output @ weight - sub_grad_input = torch.empty( - input_size, dtype=total_input.dtype, device=total_input.device, requires_grad=False - ) - dist.reduce_scatter_tensor(sub_grad_input, grad_input, group=group, op=dist.ReduceOp.SUM) - grad_bias = torch.sum(grad_output, dim=0) if bias is not None else None - - return sub_grad_input, grad_weight, grad_bias, None, None - - -def column_linear_context_parallel( - input: torch.Tensor, - weight: torch.Tensor, - bias: Optional[torch.Tensor], - group: dist.ProcessGroup, - tp_recompute_allgather: bool = True, -): - return ColumnLinearContextParallel.apply(input, weight, bias, group, tp_recompute_allgather) diff --git a/src/nanotron/parallel/tensor_parallel/functional.py b/src/nanotron/parallel/tensor_parallel/functional.py index 454cc447..2b93fb02 100644 --- a/src/nanotron/parallel/tensor_parallel/functional.py +++ b/src/nanotron/parallel/tensor_parallel/functional.py @@ -19,7 +19,7 @@ from torch.nn import functional as F import nanotron.distributed as dist -from nanotron.parallel.tensor_parallel.column_linear import column_linear_context_parallel +from nanotron.parallel.utils import MemoryBuffer from nanotron.parallel.tensor_parallel.distributed_differentiable_primitives import ( differentiable_all_reduce_sum, differentiable_identity, @@ -338,6 +338,80 @@ def backward(ctx, grad_output): raise ValueError(f"Got unexpected mode: {tp_mode}.") +class _ColumnLinearContextParallelNoAsync(torch.autograd.Function): + """ + Column linear with memory_buffer for the allgather, context parallel + enabled (i.e. 
tp_mode = TensorParallelLinearMode.REDUCE_SCATTER) and + async communication disabled. + """ + + @staticmethod + def forward( + ctx, + input: torch.Tensor, + weight: torch.Tensor, + bias: Optional[torch.Tensor], + group: dist.ProcessGroup, + tp_recompute_allgather: bool, + ): + + # Do allgather. + sharded_batch_size, *rest_size = input.shape + unsharded_batch_size = sharded_batch_size * group.size() + if tp_recompute_allgather: + total_input = MemoryBuffer().get("allgather", (unsharded_batch_size, *rest_size), dtype=input.dtype) + else: + total_input = torch.empty(unsharded_batch_size, *rest_size, dtype=input.dtype, device=input.device) + dist.all_gather_into_tensor(total_input, input.contiguous(), group=group) + + # Prepare context. + ctx.group = group + ctx.tp_recompute_allgather = tp_recompute_allgather + ctx.input_size = input.shape + if tp_recompute_allgather: + ctx.save_for_backward(input, weight, bias) + else: + ctx.save_for_backward(total_input, weight, bias) + + # Get linear output. + out = F.linear(total_input, weight, bias) + return out + + @staticmethod + def backward(ctx, grad_output: torch.Tensor): + # Either allgather the inputs again or get them from context. + group = ctx.group + tp_recompute_allgather = ctx.tp_recompute_allgather + input_size = ctx.input_size + if tp_recompute_allgather: + input, weight, bias = ctx.saved_tensors + sharded_batch_size, *rest_size = input.shape + total_input = sharded_batch_size * group.size() + unsharded_batch_size = sharded_batch_size * group.size() + total_input = MemoryBuffer().get("allgather", (unsharded_batch_size, *rest_size), dtype=input.dtype) + dist.all_gather_into_tensor(total_input, input.contiguous(), group=group) + else: + total_input, weight, bias = ctx.saved_tensors + + # Get the grad_output and total_input on the correct views to be able to transpose them below. + grad_output = grad_output.contiguous() + assert grad_output.dim() == 3 + grad_output = grad_output.view(grad_output.size(0) * grad_output.size(1), grad_output.size(2)) + total_input = total_input.view(total_input.size(0) * total_input.size(1), total_input.size(2)) + + # Compute gradients. 
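For reference, the gradient algebra that the lines below implement, writing X for the locally sharded input, X_g = AllGather(X) for the gathered input and out = X_g @ W.T + b:

    grad_W = grad_out.T @ X_g
    grad_b = grad_out.sum(dim=0)
    grad_X = ReduceScatter(grad_out @ W)

The reduce-scatter on grad_X is the adjoint of the forward all-gather over the sharded batch dimension, so no separate all-reduce is needed in REDUCE_SCATTER mode.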
+ grad_weight = grad_output.T @ total_input + grad_input = grad_output @ weight + sub_grad_input = torch.empty( + input_size, dtype=total_input.dtype, device=total_input.device, requires_grad=False + ) + dist.reduce_scatter_tensor(sub_grad_input, grad_input, group=group, op=dist.ReduceOp.SUM) + grad_bias = torch.sum(grad_output, dim=0) if bias is not None else None + + return sub_grad_input, grad_weight, grad_bias, None, None + + + def column_linear( input: torch.Tensor, weight: torch.Tensor, @@ -353,7 +427,7 @@ def column_linear( if tp_mode is TensorParallelLinearMode.ALL_REDUCE: input = differentiable_identity(input, group=group) elif tp_mode is TensorParallelLinearMode.REDUCE_SCATTER: - return column_linear_context_parallel(input, weight, bias, group, tp_recompute_allgather) + return _ColumnLinearContextParallelNoAsync.apply(input, weight, bias, group, tp_recompute_allgather) else: raise ValueError(f"Got unexpected mode: {tp_mode}.") From 7cc6653c69c19cd42da05e6a8712159a146407e7 Mon Sep 17 00:00:00 2001 From: AleHD Date: Thu, 18 Jul 2024 11:16:42 +0000 Subject: [PATCH 14/44] memory efficient async linear --- .../parallel/tensor_parallel/functional.py | 48 ++++++------------- 1 file changed, 15 insertions(+), 33 deletions(-) diff --git a/src/nanotron/parallel/tensor_parallel/functional.py b/src/nanotron/parallel/tensor_parallel/functional.py index 47c0b5a1..1a82254e 100644 --- a/src/nanotron/parallel/tensor_parallel/functional.py +++ b/src/nanotron/parallel/tensor_parallel/functional.py @@ -149,14 +149,7 @@ def forward(ctx, tensor, weight, bias, group, tp_mode): group = dist.distributed_c10d._get_default_group() gathered_batch_size = sharded_batch_size * group.size() - gathered_tensor = torch.empty( - gathered_batch_size, - *intermediate_size, - hidden_size, - device=tensor.device, - dtype=tensor.dtype, - requires_grad=tensor.requires_grad, - ) + gathered_tensor = MemoryBuffer().get("allgather", (gathered_batch_size, *intermediate_size, hidden_size), dtype=tensor.dtype) handle = dist.all_gather_into_tensor(gathered_tensor, tensor, group=group, async_op=True) @@ -261,7 +254,7 @@ def backward(ctx, grad_output): use_bias = ctx.use_bias tp_mode = ctx.tp_mode - handle: Optional[dist.Work] = None + handle1: Optional[dist.Work] = None if tp_mode is TensorParallelLinearMode.REDUCE_SCATTER: # TODO @thomasw21: gather along another dimension sharded_batch_size, *rest_size = tensor.shape @@ -273,14 +266,8 @@ def backward(ctx, grad_output): else: unsharded_batch_size = sharded_batch_size * group.size() - unsharded_tensor = torch.empty( - unsharded_batch_size, - *rest_size, - device=tensor.device, - dtype=tensor.dtype, - requires_grad=False, - ) - handle = dist.all_gather_into_tensor(unsharded_tensor, tensor, group=group, async_op=True) + unsharded_tensor = MemoryBuffer().get("allgather", (unsharded_batch_size, *rest_size), dtype=tensor.dtype) + handle1 = dist.all_gather_into_tensor(unsharded_tensor, tensor, group=group, async_op=True) # Here we rely on CUDA_DEVICE_MAX_CONNECTIONS=1 to ensure that the # gather is scheduled before the tensor gradient computation total_tensor = unsharded_tensor @@ -289,9 +276,6 @@ def backward(ctx, grad_output): grad_tensor = grad_output.matmul(weight) - if handle is not None: - handle.wait() - # Doing gather + slicing during the NeMo forward pass can make this tensor # not be contiguous. 
PyTorch only checks if the tensor is contiguous, and only # clones it if it's not contiguous: @@ -303,7 +287,7 @@ def backward(ctx, grad_output): grad_output = grad_output.view(math.prod(grad_output_first_dims), grad_output_last_dim) total_tensor = total_tensor.view(math.prod(total_tensor_first_dims), total_tensor_last_dim) - handle: Optional[dist.Work] = None + handle2: Optional[dist.Work] = None if tp_mode is TensorParallelLinearMode.REDUCE_SCATTER: if group.size() == 1: sub_grad_tensor = grad_tensor @@ -312,23 +296,27 @@ def backward(ctx, grad_output): tensor.shape, dtype=grad_tensor.dtype, device=grad_tensor.device, requires_grad=False ) # reduce_scatter - handle = dist.reduce_scatter_tensor(sub_grad_tensor, grad_tensor, group=group, async_op=True) + handle2 = dist.reduce_scatter_tensor(sub_grad_tensor, grad_tensor, group=group, async_op=True) # Here we rely on CUDA_DEVICE_MAX_CONNECTIONS=1 to ensure that the # reduce scatter is scheduled before the weight gradient computation elif tp_mode is TensorParallelLinearMode.ALL_REDUCE: # Asynchronous all-reduce - handle = dist.all_reduce(grad_tensor, group=group, async_op=True) + handle2 = dist.all_reduce(grad_tensor, group=group, async_op=True) # Here we rely on CUDA_DEVICE_MAX_CONNECTIONS=1 to ensure that the # all-reduce is scheduled before the weight gradient computation else: raise ValueError() + grad_bias = grad_output.sum(dim=0) if use_bias else None + + if handle1 is not None: + handle1.wait() + # TODO @thomasw21: This sounds like we don't have the optimal physical layout grad_weight = grad_output.t().matmul(total_tensor) - grad_bias = grad_output.sum(dim=0) if use_bias else None - if handle is not None: - handle.wait() + if handle2 is not None: + handle2.wait() if tp_mode is TensorParallelLinearMode.REDUCE_SCATTER: return sub_grad_tensor, grad_weight, grad_bias, None, None @@ -472,13 +460,7 @@ def backward(ctx, grad_output): else: unsharded_batch_size = sharded_batch_size * group.size() - total_grad_output = torch.empty( - unsharded_batch_size, - *rest_size, - device=grad_output.device, - dtype=grad_output.dtype, - requires_grad=False, - ) + total_grad_output = MemoryBuffer().get("allgather2", (unsharded_batch_size, *rest_size), dtype=tensor.dtype) # Doing gather + slicing during the NeMo forward pass can make this tensor # not be contiguous. 
PyTorch only checks if the tensor is contiguous, and only From cb0f2609e357747a0a3001dd9b12c649f9e6eef7 Mon Sep 17 00:00:00 2001 From: AleHD Date: Thu, 18 Jul 2024 11:17:09 +0000 Subject: [PATCH 15/44] precommit --- .../parallel/tensor_parallel/functional.py | 16 ++++++++++------ 1 file changed, 10 insertions(+), 6 deletions(-) diff --git a/src/nanotron/parallel/tensor_parallel/functional.py b/src/nanotron/parallel/tensor_parallel/functional.py index 1a82254e..3821d544 100644 --- a/src/nanotron/parallel/tensor_parallel/functional.py +++ b/src/nanotron/parallel/tensor_parallel/functional.py @@ -19,14 +19,13 @@ from torch.nn import functional as F import nanotron.distributed as dist -from nanotron.parallel.utils import MemoryBuffer from nanotron.parallel.tensor_parallel.distributed_differentiable_primitives import ( differentiable_all_reduce_sum, differentiable_identity, differentiable_reduce_scatter_sum, ) from nanotron.parallel.tensor_parallel.enum import TensorParallelLinearMode -from nanotron.parallel.utils import assert_cuda_max_connections_set_to_1 +from nanotron.parallel.utils import MemoryBuffer, assert_cuda_max_connections_set_to_1 class _ShardedCrossEntropy(torch.autograd.Function): @@ -149,7 +148,9 @@ def forward(ctx, tensor, weight, bias, group, tp_mode): group = dist.distributed_c10d._get_default_group() gathered_batch_size = sharded_batch_size * group.size() - gathered_tensor = MemoryBuffer().get("allgather", (gathered_batch_size, *intermediate_size, hidden_size), dtype=tensor.dtype) + gathered_tensor = MemoryBuffer().get( + "allgather", (gathered_batch_size, *intermediate_size, hidden_size), dtype=tensor.dtype + ) handle = dist.all_gather_into_tensor(gathered_tensor, tensor, group=group, async_op=True) @@ -266,7 +267,9 @@ def backward(ctx, grad_output): else: unsharded_batch_size = sharded_batch_size * group.size() - unsharded_tensor = MemoryBuffer().get("allgather", (unsharded_batch_size, *rest_size), dtype=tensor.dtype) + unsharded_tensor = MemoryBuffer().get( + "allgather", (unsharded_batch_size, *rest_size), dtype=tensor.dtype + ) handle1 = dist.all_gather_into_tensor(unsharded_tensor, tensor, group=group, async_op=True) # Here we rely on CUDA_DEVICE_MAX_CONNECTIONS=1 to ensure that the # gather is scheduled before the tensor gradient computation @@ -399,7 +402,6 @@ def backward(ctx, grad_output: torch.Tensor): return sub_grad_input, grad_weight, grad_bias, None, None - def column_linear( input: torch.Tensor, weight: torch.Tensor, @@ -460,7 +462,9 @@ def backward(ctx, grad_output): else: unsharded_batch_size = sharded_batch_size * group.size() - total_grad_output = MemoryBuffer().get("allgather2", (unsharded_batch_size, *rest_size), dtype=tensor.dtype) + total_grad_output = MemoryBuffer().get( + "allgather2", (unsharded_batch_size, *rest_size), dtype=tensor.dtype + ) # Doing gather + slicing during the NeMo forward pass can make this tensor # not be contiguous. 
PyTorch only checks if the tensor is contiguous, and only From 6d85d038d52ffc06ec4f2ae4705deae3b05d25d8 Mon Sep 17 00:00:00 2001 From: AleHD Date: Thu, 18 Jul 2024 11:46:39 +0000 Subject: [PATCH 16/44] Added no_recompute_allgather mode to async --- .../parallel/tensor_parallel/functional.py | 36 +++++++++++++------ 1 file changed, 26 insertions(+), 10 deletions(-) diff --git a/src/nanotron/parallel/tensor_parallel/functional.py b/src/nanotron/parallel/tensor_parallel/functional.py index 3821d544..29480f9a 100644 --- a/src/nanotron/parallel/tensor_parallel/functional.py +++ b/src/nanotron/parallel/tensor_parallel/functional.py @@ -120,10 +120,12 @@ class _ColumnLinearAsyncCommunication(torch.autograd.Function): @staticmethod @assert_cuda_max_connections_set_to_1 - def forward(ctx, tensor, weight, bias, group, tp_mode): + def forward(ctx, tensor, weight, bias, group, tp_mode, tp_recompute_allgather): ctx.use_bias = bias is not None ctx.tp_mode = tp_mode ctx.group = group + ctx.tp_recompute_allgather = tp_recompute_allgather + ctx.tensor_shape = tensor.size() if tp_mode is TensorParallelLinearMode.ALL_REDUCE: gathered_tensor = tensor @@ -140,7 +142,7 @@ def forward(ctx, tensor, weight, bias, group, tp_mode): # `tensor` can sometimes not be contiguous # https://cs.github.com/pytorch/pytorch/blob/2b267fa7f28e18ca6ea1de4201d2541a40411457/torch/distributed/nn/functional.py#L317 tensor = tensor.contiguous() - ctx.save_for_backward(tensor, weight) + # ctx.save_for_backward(tensor, weight) # TODO @thomasw21: gather along another dimension sharded_batch_size, *intermediate_size, hidden_size = tensor.shape @@ -148,9 +150,19 @@ def forward(ctx, tensor, weight, bias, group, tp_mode): group = dist.distributed_c10d._get_default_group() gathered_batch_size = sharded_batch_size * group.size() - gathered_tensor = MemoryBuffer().get( - "allgather", (gathered_batch_size, *intermediate_size, hidden_size), dtype=tensor.dtype - ) + if tp_recompute_allgather: + gathered_tensor = MemoryBuffer().get( + "allgather", (gathered_batch_size, *intermediate_size, hidden_size), dtype=tensor.dtype + ) + else: + gathered_tensor = torch.empty( + gathered_batch_size, + *intermediate_size, + hidden_size, + device=tensor.device, + dtype=tensor.dtype, + requires_grad=False, + ) handle = dist.all_gather_into_tensor(gathered_tensor, tensor, group=group, async_op=True) @@ -198,6 +210,10 @@ def forward(ctx, tensor, weight, bias, group, tp_mode): # Wait communication handle.wait() + if tp_recompute_allgather: + ctx.save_for_backward(tensor, weight) + else: + ctx.save_for_backward(gathered_tensor, weight) # Compute all the other shards that are obtained from AllGather # weights: w0 w1 w2 w3 @@ -256,7 +272,7 @@ def backward(ctx, grad_output): tp_mode = ctx.tp_mode handle1: Optional[dist.Work] = None - if tp_mode is TensorParallelLinearMode.REDUCE_SCATTER: + if tp_mode is TensorParallelLinearMode.REDUCE_SCATTER and ctx.tp_recompute_allgather: # TODO @thomasw21: gather along another dimension sharded_batch_size, *rest_size = tensor.shape if group is None: @@ -296,7 +312,7 @@ def backward(ctx, grad_output): sub_grad_tensor = grad_tensor else: sub_grad_tensor = torch.empty( - tensor.shape, dtype=grad_tensor.dtype, device=grad_tensor.device, requires_grad=False + ctx.tensor_shape, dtype=grad_tensor.dtype, device=grad_tensor.device, requires_grad=False ) # reduce_scatter handle2 = dist.reduce_scatter_tensor(sub_grad_tensor, grad_tensor, group=group, async_op=True) @@ -322,9 +338,9 @@ def backward(ctx, grad_output): handle2.wait() if tp_mode is 
TensorParallelLinearMode.REDUCE_SCATTER: - return sub_grad_tensor, grad_weight, grad_bias, None, None + return sub_grad_tensor, grad_weight, grad_bias, None, None, None elif tp_mode is TensorParallelLinearMode.ALL_REDUCE: - return grad_tensor, grad_weight, grad_bias, None, None + return grad_tensor, grad_weight, grad_bias, None, None, None else: raise ValueError(f"Got unexpected mode: {tp_mode}.") @@ -412,7 +428,7 @@ def column_linear( tp_recompute_allgather: bool = True, ): if async_communication: - return _ColumnLinearAsyncCommunication.apply(input, weight, bias, group, tp_mode) + return _ColumnLinearAsyncCommunication.apply(input, weight, bias, group, tp_mode, tp_recompute_allgather) if tp_mode is TensorParallelLinearMode.ALL_REDUCE: input = differentiable_identity(input, group=group) From 2afd00769c7c5891341e2d7880492bfc80c524f6 Mon Sep 17 00:00:00 2001 From: AleHD Date: Tue, 23 Jul 2024 09:52:12 +0000 Subject: [PATCH 17/44] Fixed List not found --- src/nanotron/models/llama.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/nanotron/models/llama.py b/src/nanotron/models/llama.py index 7ee44390..49ea86e6 100644 --- a/src/nanotron/models/llama.py +++ b/src/nanotron/models/llama.py @@ -14,7 +14,7 @@ # limitations under the License. """PyTorch LLaMa model.""" -from typing import Dict, Optional, Union +from typing import Dict, Optional, Union, List import torch from torch import nn From 7e758db3068948178edd2151232e4abd7b2d5ffd Mon Sep 17 00:00:00 2001 From: AleHD Date: Tue, 23 Jul 2024 13:17:34 +0000 Subject: [PATCH 18/44] Fixed tp=1 case --- .../parallel/tensor_parallel/functional.py | 22 +++++++++++-------- 1 file changed, 13 insertions(+), 9 deletions(-) diff --git a/src/nanotron/parallel/tensor_parallel/functional.py b/src/nanotron/parallel/tensor_parallel/functional.py index 2b93fb02..22c8ca3c 100644 --- a/src/nanotron/parallel/tensor_parallel/functional.py +++ b/src/nanotron/parallel/tensor_parallel/functional.py @@ -358,11 +358,14 @@ def forward( # Do allgather. sharded_batch_size, *rest_size = input.shape unsharded_batch_size = sharded_batch_size * group.size() - if tp_recompute_allgather: + if group.size() == 1: + total_input = input.contiguous() + elif tp_recompute_allgather: total_input = MemoryBuffer().get("allgather", (unsharded_batch_size, *rest_size), dtype=input.dtype) + dist.all_gather_into_tensor(total_input, input.contiguous(), group=group) else: total_input = torch.empty(unsharded_batch_size, *rest_size, dtype=input.dtype, device=input.device) - dist.all_gather_into_tensor(total_input, input.contiguous(), group=group) + dist.all_gather_into_tensor(total_input, input.contiguous(), group=group) # Prepare context. ctx.group = group @@ -383,21 +386,22 @@ def backward(ctx, grad_output: torch.Tensor): group = ctx.group tp_recompute_allgather = ctx.tp_recompute_allgather input_size = ctx.input_size - if tp_recompute_allgather: + if group.size() == 1 or not tp_recompute_allgather: + total_input, weight, bias = ctx.saved_tensors + else: input, weight, bias = ctx.saved_tensors sharded_batch_size, *rest_size = input.shape total_input = sharded_batch_size * group.size() unsharded_batch_size = sharded_batch_size * group.size() total_input = MemoryBuffer().get("allgather", (unsharded_batch_size, *rest_size), dtype=input.dtype) dist.all_gather_into_tensor(total_input, input.contiguous(), group=group) - else: - total_input, weight, bias = ctx.saved_tensors - # Get the grad_output and total_input on the correct views to be able to transpose them below. 
+ # Convert the tensor shapes to 2D for execution compatibility grad_output = grad_output.contiguous() - assert grad_output.dim() == 3 - grad_output = grad_output.view(grad_output.size(0) * grad_output.size(1), grad_output.size(2)) - total_input = total_input.view(total_input.size(0) * total_input.size(1), total_input.size(2)) + grad_output_first_dims, grad_output_last_dim = grad_output.shape[:-1], grad_output.shape[-1] + total_tensor_first_dims, total_tensor_last_dim = total_tensor.shape[:-1], total_tensor.shape[-1] + grad_output = grad_output.view(math.prod(grad_output_first_dims), grad_output_last_dim) + total_tensor = total_tensor.view(math.prod(total_tensor_first_dims), total_tensor_last_dim) # Compute gradients. grad_weight = grad_output.T @ total_input From cd84d4fa7bff4017a9b1653c4b68290c00eae649 Mon Sep 17 00:00:00 2001 From: AleHD Date: Tue, 30 Jul 2024 19:16:14 +0200 Subject: [PATCH 19/44] Fixed column parallel --- .../parallel/tensor_parallel/functional.py | 19 ++++++++++--------- 1 file changed, 10 insertions(+), 9 deletions(-) diff --git a/src/nanotron/parallel/tensor_parallel/functional.py b/src/nanotron/parallel/tensor_parallel/functional.py index 054d41e9..468855a5 100644 --- a/src/nanotron/parallel/tensor_parallel/functional.py +++ b/src/nanotron/parallel/tensor_parallel/functional.py @@ -19,14 +19,13 @@ from torch.nn import functional as F import nanotron.distributed as dist -from nanotron.parallel.utils import MemoryBuffer from nanotron.parallel.tensor_parallel.distributed_differentiable_primitives import ( differentiable_all_reduce_sum, differentiable_identity, differentiable_reduce_scatter_sum, ) from nanotron.parallel.tensor_parallel.enum import TensorParallelLinearMode -from nanotron.parallel.utils import assert_cuda_max_connections_set_to_1 +from nanotron.parallel.utils import MemoryBuffer, assert_cuda_max_connections_set_to_1 class _ShardedCrossEntropy(torch.autograd.Function): @@ -399,23 +398,25 @@ def backward(ctx, grad_output: torch.Tensor): # Convert the tensor shapes to 2D for execution compatibility grad_output = grad_output.contiguous() grad_output_first_dims, grad_output_last_dim = grad_output.shape[:-1], grad_output.shape[-1] - total_tensor_first_dims, total_tensor_last_dim = total_tensor.shape[:-1], total_tensor.shape[-1] + total_input_first_dims, total_input_last_dim = total_input.shape[:-1], total_input.shape[-1] grad_output = grad_output.view(math.prod(grad_output_first_dims), grad_output_last_dim) - total_tensor = total_tensor.view(math.prod(total_tensor_first_dims), total_tensor_last_dim) + total_input = total_input.view(math.prod(total_input_first_dims), total_input_last_dim) # Compute gradients. 
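        # grad_weight needs the full (gathered) input, while grad_input is reduce-scattered back to per-rank shards below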
grad_weight = grad_output.T @ total_input grad_input = grad_output @ weight - sub_grad_input = torch.empty( - input_size, dtype=total_input.dtype, device=total_input.device, requires_grad=False - ) - dist.reduce_scatter_tensor(sub_grad_input, grad_input, group=group, op=dist.ReduceOp.SUM) + if group.size() == 1: + sub_grad_input = grad_input + else: + sub_grad_input = torch.empty( + input_size, dtype=total_input.dtype, device=total_input.device, requires_grad=False + ) + dist.reduce_scatter_tensor(sub_grad_input, grad_input, group=group, op=dist.ReduceOp.SUM) grad_bias = torch.sum(grad_output, dim=0) if bias is not None else None return sub_grad_input, grad_weight, grad_bias, None, None - def column_linear( input: torch.Tensor, weight: torch.Tensor, From d3db06acb2e3fe235ae512861ff64ee9fbc9ac11 Mon Sep 17 00:00:00 2001 From: AleHD Date: Tue, 30 Jul 2024 19:16:28 +0200 Subject: [PATCH 20/44] Added tp_recompute_allgather test --- tests/test_tensor_parallel.py | 22 ++++++++++++++++++---- 1 file changed, 18 insertions(+), 4 deletions(-) diff --git a/tests/test_tensor_parallel.py b/tests/test_tensor_parallel.py index f5dcaeb0..8e73973b 100644 --- a/tests/test_tensor_parallel.py +++ b/tests/test_tensor_parallel.py @@ -18,17 +18,30 @@ @pytest.mark.parametrize("tp,dp,pp", [pytest.param(i, 1, 1) for i in range(1, min(4, available_gpus()) + 1)]) @pytest.mark.parametrize("tp_mode", list(TensorParallelLinearMode)) @pytest.mark.parametrize("async_communication", [False, True]) +@pytest.mark.parametrize("tp_recompute_allgather", [False, True]) @rerun_if_address_is_in_use() -def test_column_linear(tp: int, dp: int, pp: int, tp_mode: TensorParallelLinearMode, async_communication: bool): +def test_column_linear( + tp: int, + dp: int, + pp: int, + tp_mode: TensorParallelLinearMode, + async_communication: bool, + tp_recompute_allgather: bool, +): if tp_mode is TensorParallelLinearMode.ALL_REDUCE and async_communication: pytest.skip("ALL_REDUCE mode does not support async communication") + if tp_mode is TensorParallelLinearMode.ALL_REDUCE and tp_recompute_allgather: + pytest.skip("ALL_REDUCE mode is unaffected by tp_recompute_allgather") init_distributed(tp=tp, dp=dp, pp=pp)(_test_column_linear)( - tp_mode=tp_mode, async_communication=async_communication + tp_mode=tp_mode, async_communication=async_communication, tp_recompute_allgather=tp_recompute_allgather ) def _test_column_linear( - parallel_context: ParallelContext, tp_mode: TensorParallelLinearMode, async_communication: bool + parallel_context: ParallelContext, + tp_mode: TensorParallelLinearMode, + async_communication: bool, + tp_recompute_allgather: bool, ): if async_communication: os.environ["CUDA_DEVICE_MAX_CONNECTIONS"] = "1" @@ -44,6 +57,7 @@ def _test_column_linear( mode=tp_mode, device="cuda", async_communication=async_communication, + tp_recompute_allgather=tp_recompute_allgather, ) # Un-sharded @@ -86,7 +100,7 @@ def _test_column_linear( random_input = sharded_random_input else: ValueError(f"Unsupported mode: {tp_mode}") - # It's important that `random_input` and `sharded_random_input` are two seperate tensors with seperate storage + # It's important that `random_input` and `sharded_random_input` are two separate tensors with separate storage sharded_random_input = sharded_random_input.clone() random_input.requires_grad = True sharded_random_input.requires_grad = True From 4c94b99a8cafd5f8dc2b7f208341a17a4d818234 Mon Sep 17 00:00:00 2001 From: AleHD Date: Tue, 30 Jul 2024 19:34:31 +0200 Subject: [PATCH 21/44] Added tp_recompute_allgather test 
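
The parametrization added here covers both memory strategies of the REDUCE_SCATTER
code path. As a rough, illustrative sketch of the trade-off behind
`tp_recompute_allgather` (the names below are not the actual nanotron API): with
`tp_recompute_allgather=True` only the local shard is saved and the all-gather is
redone during backward, trading an extra collective for lower activation memory;
with `False` the gathered tensor stays alive between forward and backward.

```python
# Minimal sketch, assuming an initialized torch.distributed process group.
# It is not the nanotron implementation, only the idea the tests exercise.
import torch
import torch.distributed as dist


def gather_shard(shard: torch.Tensor, group) -> torch.Tensor:
    full = torch.empty(
        shard.size(0) * group.size(), *shard.shape[1:], dtype=shard.dtype, device=shard.device
    )
    dist.all_gather_into_tensor(full, shard.contiguous(), group=group)
    return full


def tp_linear_forward(shard, weight, group, tp_recompute_allgather: bool):
    full = gather_shard(shard, group)
    out = full @ weight.t()
    # What gets kept for backward decides the activation-memory footprint:
    # the small shard (re-gather later) or the big gathered tensor.
    saved = shard if tp_recompute_allgather else full
    return out, saved
```
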
--- tests/test_tensor_parallel.py | 23 ++++++++++++++++++++--- 1 file changed, 20 insertions(+), 3 deletions(-) diff --git a/tests/test_tensor_parallel.py b/tests/test_tensor_parallel.py index 8e73973b..16008eaa 100644 --- a/tests/test_tensor_parallel.py +++ b/tests/test_tensor_parallel.py @@ -164,15 +164,32 @@ def _test_column_linear( @pytest.mark.parametrize("tp,dp,pp", [pytest.param(i, 1, 1) for i in range(1, min(4, available_gpus()) + 1)]) @pytest.mark.parametrize("tp_mode", list(TensorParallelLinearMode)) @pytest.mark.parametrize("async_communication", [False, True]) +@pytest.mark.parametrize("tp_recompute_allgather", [False, True]) @rerun_if_address_is_in_use() -def test_row_linear(tp: int, dp: int, pp: int, tp_mode: TensorParallelLinearMode, async_communication: bool): +def test_row_linear( + tp: int, + dp: int, + pp: int, + tp_mode: TensorParallelLinearMode, + async_communication: bool, + tp_recompute_allgather: bool, +): if tp_mode is TensorParallelLinearMode.ALL_REDUCE and async_communication: pytest.skip("ALL_REDUCE mode does not support async communication") + if tp_mode is TensorParallelLinearMode.ALL_REDUCE and tp_recompute_allgather: + pytest.skip("ALL_REDUCE mode is not affected by tp_recompute_allgather") - init_distributed(tp=tp, dp=dp, pp=pp)(_test_row_linear)(tp_mode=tp_mode, async_communication=async_communication) + init_distributed(tp=tp, dp=dp, pp=pp)(_test_row_linear)( + tp_mode=tp_mode, async_communication=async_communication, tp_recompute_allgather=tp_recompute_allgather + ) -def _test_row_linear(parallel_context: ParallelContext, tp_mode: TensorParallelLinearMode, async_communication: bool): +def _test_row_linear( + parallel_context: ParallelContext, + tp_mode: TensorParallelLinearMode, + async_communication: bool, + tp_recompute_allgather: bool, +): if async_communication: os.environ["CUDA_DEVICE_MAX_CONNECTIONS"] = "1" out_features = 3 From 3c3561158eb053176b6f148a2366ea1aa56fdc7d Mon Sep 17 00:00:00 2001 From: Phuc Nguyen Date: Wed, 31 Jul 2024 17:15:34 +0000 Subject: [PATCH 22/44] change to correct config_nanoset.yaml path --- docs/nanoset.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/nanoset.md b/docs/nanoset.md index 9dce21b7..61393438 100644 --- a/docs/nanoset.md +++ b/docs/nanoset.md @@ -79,7 +79,7 @@ To work with `Nanosets`, we just need to configure 1 argument: Finally, to use the `Nanosets`, launch the training with [`run_train.py`](../run_train.py). 
```shell -torchrun --nproc-per-node 8 run_train.py --config configs/config_nanoset.yaml +torchrun --nproc-per-node 1 run_train.py --config examples/config_nanoset.yaml ``` ## Under the hood From 7daa186e84e03fa66fa83129d9b0acffb1a668ba Mon Sep 17 00:00:00 2001 From: AleHD Date: Fri, 2 Aug 2024 15:40:46 +0200 Subject: [PATCH 23/44] Minor restyling --- .../parallel/tensor_parallel/functional.py | 15 ++++++++------- 1 file changed, 8 insertions(+), 7 deletions(-) diff --git a/src/nanotron/parallel/tensor_parallel/functional.py b/src/nanotron/parallel/tensor_parallel/functional.py index 468855a5..1fb86cb5 100644 --- a/src/nanotron/parallel/tensor_parallel/functional.py +++ b/src/nanotron/parallel/tensor_parallel/functional.py @@ -115,7 +115,7 @@ def sharded_cross_entropy(sharded_logits, target, group: dist.ProcessGroup, dtyp return _ShardedCrossEntropy.apply(sharded_logits, target, group) -class _ColumnLinearAsyncCommunication(torch.autograd.Function): +class _ColumnLinearNoAsyncCommunicationReduceScatterMode(torch.autograd.Function): """Adapted from https://github.com/NVIDIA/Megatron-LM/blob/e6d7e09845590d0a36bc7f29eb28db974fb8da4e/megatron/core/tensor_parallel/layers.py#L215""" @staticmethod @@ -408,6 +408,9 @@ def backward(ctx, grad_output: torch.Tensor): if group.size() == 1: sub_grad_input = grad_input else: + # Seems that `reduce_scatter` need contiguous tensors: https://github.com/pytorch/pytorch/blob/2b267fa7f28e18ca6ea1de4201d2541a40411457/torch/distributed/nn/functional.py#L305 + # We set grad_input to be contiguous in case it isn't already. + grad_input = grad_input.contiguous() sub_grad_input = torch.empty( input_size, dtype=total_input.dtype, device=total_input.device, requires_grad=False ) @@ -427,16 +430,14 @@ def column_linear( tp_recompute_allgather: bool = True, ): if async_communication: - return _ColumnLinearAsyncCommunication.apply(input, weight, bias, group, tp_mode) + return _ColumnLinearNoAsyncCommunicationReduceScatterMode.apply(input, weight, bias, group, tp_mode) if tp_mode is TensorParallelLinearMode.ALL_REDUCE: input = differentiable_identity(input, group=group) - elif tp_mode is TensorParallelLinearMode.REDUCE_SCATTER: + return F.linear(input, weight, bias) + if tp_mode is TensorParallelLinearMode.REDUCE_SCATTER: return _ColumnLinearContextParallelNoAsync.apply(input, weight, bias, group, tp_recompute_allgather) - else: - raise ValueError(f"Got unexpected mode: {tp_mode}.") - - return F.linear(input, weight, bias) + raise ValueError(f"Got unexpected mode: {tp_mode}.") class _RowLinearAsyncCommunication(torch.autograd.Function): From 31c3c5ad0a845ff6318842c663223e9621586a3d Mon Sep 17 00:00:00 2001 From: AleHD Date: Fri, 2 Aug 2024 15:54:44 +0200 Subject: [PATCH 24/44] Fixed names --- src/nanotron/parallel/tensor_parallel/functional.py | 10 ++++++---- 1 file changed, 6 insertions(+), 4 deletions(-) diff --git a/src/nanotron/parallel/tensor_parallel/functional.py b/src/nanotron/parallel/tensor_parallel/functional.py index 1fb86cb5..7a88aec6 100644 --- a/src/nanotron/parallel/tensor_parallel/functional.py +++ b/src/nanotron/parallel/tensor_parallel/functional.py @@ -115,7 +115,7 @@ def sharded_cross_entropy(sharded_logits, target, group: dist.ProcessGroup, dtyp return _ShardedCrossEntropy.apply(sharded_logits, target, group) -class _ColumnLinearNoAsyncCommunicationReduceScatterMode(torch.autograd.Function): +class _ColumnLinearAsyncCommunication(torch.autograd.Function): """Adapted from 
https://github.com/NVIDIA/Megatron-LM/blob/e6d7e09845590d0a36bc7f29eb28db974fb8da4e/megatron/core/tensor_parallel/layers.py#L215""" @staticmethod @@ -337,7 +337,7 @@ def backward(ctx, grad_output): raise ValueError(f"Got unexpected mode: {tp_mode}.") -class _ColumnLinearContextParallelNoAsync(torch.autograd.Function): +class _ColumnLinearNoAsyncCommunicationReduceScatterMode(torch.autograd.Function): """ Column linear with memory_buffer for the allgather, context parallel enabled (i.e. tp_mode = TensorParallelLinearMode.REDUCE_SCATTER) and @@ -430,13 +430,15 @@ def column_linear( tp_recompute_allgather: bool = True, ): if async_communication: - return _ColumnLinearNoAsyncCommunicationReduceScatterMode.apply(input, weight, bias, group, tp_mode) + return _ColumnLinearAsyncCommunication.apply(input, weight, bias, group, tp_mode) if tp_mode is TensorParallelLinearMode.ALL_REDUCE: input = differentiable_identity(input, group=group) return F.linear(input, weight, bias) if tp_mode is TensorParallelLinearMode.REDUCE_SCATTER: - return _ColumnLinearContextParallelNoAsync.apply(input, weight, bias, group, tp_recompute_allgather) + return _ColumnLinearNoAsyncCommunicationReduceScatterMode.apply( + input, weight, bias, group, tp_recompute_allgather + ) raise ValueError(f"Got unexpected mode: {tp_mode}.") From 664c09aa48204b8a45756e15fac9cc6bf0b38ccf Mon Sep 17 00:00:00 2001 From: tj-solergibert Date: Wed, 4 Sep 2024 10:38:32 +0000 Subject: [PATCH 25/44] lr scheduler resume training with PP fix --- src/nanotron/serialize/main.py | 1 + src/nanotron/serialize/optimizer.py | 12 +++++------- src/nanotron/trainer.py | 9 +++++---- 3 files changed, 11 insertions(+), 11 deletions(-) diff --git a/src/nanotron/serialize/main.py b/src/nanotron/serialize/main.py index 286008ac..346ad573 100644 --- a/src/nanotron/serialize/main.py +++ b/src/nanotron/serialize/main.py @@ -236,6 +236,7 @@ def load( load_optimizer(optimizer=optimizer, parallel_context=parallel_context, root_folder=root_folder) load_lr_scheduler( lr_scheduler=lr_scheduler, + parallel_context=parallel_context, root_folder=root_folder, ) return checkpoint_metadata diff --git a/src/nanotron/serialize/optimizer.py b/src/nanotron/serialize/optimizer.py index 68a3b1a0..f11210da 100644 --- a/src/nanotron/serialize/optimizer.py +++ b/src/nanotron/serialize/optimizer.py @@ -30,9 +30,9 @@ def optimizer_filename(parallel_context: ParallelContext, is_zero: bool): return f"{ObjectType.OPTIMIZER.value}_pp-{dist.get_rank(parallel_context.pp_pg)}-of-{parallel_context.pp_pg.size()}_tp-{dist.get_rank(parallel_context.tp_pg)}-of-{parallel_context.tp_pg.size()}_exp-{dist.get_rank(parallel_context.expert_pg)}-of-{parallel_context.expert_parallel_size}.pt" -def lr_scheduler_filename(): +def lr_scheduler_filename(parallel_context: ParallelContext): """The lr_scheduler is the same for all processes.""" - return f"{ObjectType.LR_SCHEDULER.value}.pt" + return f"{ObjectType.LR_SCHEDULER.value}_pp-{dist.get_rank(parallel_context.pp_pg)}-of-{parallel_context.pp_pg.size()}_tp-{dist.get_rank(parallel_context.tp_pg)}-of-{parallel_context.tp_pg.size()}_exp-{dist.get_rank(parallel_context.expert_pg)}-of-{parallel_context.expert_parallel_size}.pt" def save_optimizer( @@ -109,9 +109,6 @@ def save_lr_scheduler( root_folder: Path, ): """Saves lr scheduler states""" - if dist.get_rank(parallel_context.world_pg) > 0: - # Only WORLD-RANK 0 saves the lr scheduler state - return root_folder = root_folder / "lr_scheduler" root_folder.mkdir(exist_ok=True, parents=True) @@ -119,7 +116,7 @@ def 
save_lr_scheduler( # We dump the optimizer state using `torch.save` torch.save( lr_scheduler.state_dict(), - root_folder / lr_scheduler_filename(), + root_folder / lr_scheduler_filename(parallel_context), ) @@ -313,9 +310,10 @@ def get_checkpoint_state_metadata(param_name: str, pp_rank: int, tp_rank: int) - def load_lr_scheduler( lr_scheduler, + parallel_context: ParallelContext, root_folder: Path, ): root_folder = root_folder / "lr_scheduler" - state_dict = torch.load(root_folder / lr_scheduler_filename()) + state_dict = torch.load(root_folder / lr_scheduler_filename(parallel_context)) lr_scheduler.load_state_dict(state_dict) diff --git a/src/nanotron/trainer.py b/src/nanotron/trainer.py index 214ea52f..21251a32 100644 --- a/src/nanotron/trainer.py +++ b/src/nanotron/trainer.py @@ -206,6 +206,7 @@ def __init__( if self.init_checkpoint_path is not None: load_lr_scheduler( lr_scheduler=self.lr_scheduler, + parallel_context=self.parallel_context, root_folder=self.init_checkpoint_path, ) @@ -443,10 +444,10 @@ def train( self.save_checkpoint() dist.barrier() # let's wait for everyone before leaving - + if self.config.checkpoints.save_final_state: self.save_checkpoint() - + self.post_training() def training_step( @@ -865,8 +866,8 @@ def save_checkpoint(self) -> Path: ), # We only save the weights on DP==0 should_save_optimizer=True, should_save_lr_scheduler=bool( - dist.get_rank(self.parallel_context.world_pg) == 0 - ), # We only save the lr_scheduler on world_rank==0 + dist.get_rank(self.parallel_context.dp_pg) == 0 + ), # We only save the lr_scheduler on DP==0 should_save_config=bool( dist.get_rank(self.parallel_context.world_pg) == 0 ), # We only save the config on world_rank==0 From 761d253c21d101216fcf396ced495f88ad01bdb9 Mon Sep 17 00:00:00 2001 From: Kyle Matoba <22180455+kylematoba@users.noreply.github.com> Date: Sat, 14 Sep 2024 22:45:47 +0200 Subject: [PATCH 26/44] work --- lion_pytorch/__init__.py | 1 + lion_pytorch/foreach.py | 95 ++++++++++++++++++++++++++++++++++ lion_pytorch/lion_pytorch.py | 97 +++++++++++++++++++++++++++++++++++ lion_pytorch/triton.py | 98 ++++++++++++++++++++++++++++++++++++ src/nanotron/helpers.py | 14 ++++-- 5 files changed, 300 insertions(+), 5 deletions(-) create mode 100644 lion_pytorch/__init__.py create mode 100644 lion_pytorch/foreach.py create mode 100644 lion_pytorch/lion_pytorch.py create mode 100644 lion_pytorch/triton.py diff --git a/lion_pytorch/__init__.py b/lion_pytorch/__init__.py new file mode 100644 index 00000000..b3a7799d --- /dev/null +++ b/lion_pytorch/__init__.py @@ -0,0 +1 @@ +from lion_pytorch.lion_pytorch import Lion diff --git a/lion_pytorch/foreach.py b/lion_pytorch/foreach.py new file mode 100644 index 00000000..50d60518 --- /dev/null +++ b/lion_pytorch/foreach.py @@ -0,0 +1,95 @@ +from __future__ import annotations +from typing import Tuple, Callable + +import torch +from torch.optim.optimizer import Optimizer + +# functions + +def exists(val): + return val is not None + +# class + +class Lion(Optimizer): + def __init__( + self, + params, + lr: float = 1e-4, + betas: Tuple[float, float] = (0.9, 0.99), + weight_decay: float = 0.0, + decoupled_weight_decay: bool = False + ): + assert lr > 0. + assert all([0. <= beta <= 1. 
for beta in betas]) + assert all([hasattr(torch, f'_foreach_{attr}_') for attr in ('mul', 'add', 'sign', 'lerp')]), 'this version of torch does not have the prerequisite foreach functions' + + self._init_lr = lr + self.decoupled_wd = decoupled_weight_decay + + defaults = dict( + lr = lr, + betas = betas, + weight_decay = weight_decay + ) + + super().__init__(params, defaults) + + @torch.no_grad() + def step( + self, + closure: Callable | None = None + ): + + loss = None + if exists(closure): + with torch.enable_grad(): + loss = closure() + + for group in self.param_groups: + + lr, wd, beta1, beta2, decoupled_wd, init_lr = group['lr'], group['weight_decay'], *group['betas'], self.decoupled_wd, self._init_lr + + # maybe decoupled weight decay + + if decoupled_wd: + wd /= init_lr + + # accumulate List[Tensor] for foreach inplace updates + + params = [] + grads = [] + exp_avgs = [] + + for p in filter(lambda p: exists(p.grad), group['params']): + + grad, state = p.grad, self.state[p] + + # init state - exponential moving average of gradient values + + if len(state) == 0: + state['exp_avg'] = torch.zeros_like(p) + + exp_avg = state['exp_avg'] + + params.append(p) + grads.append(grad) + exp_avgs.append(exp_avg) + + # stepweight decay + + torch._foreach_mul_(params, 1. - lr * wd) + + # weight update + + updates = [t.clone() for t in exp_avgs] + torch._foreach_lerp_(updates, grads, 1. - beta1) + torch._foreach_sign_(updates) + + torch._foreach_add_(params, updates, alpha = -lr) + + # decay momentum running average + + torch._foreach_lerp_(exp_avgs, grads, 1. - beta2) + + return loss diff --git a/lion_pytorch/lion_pytorch.py b/lion_pytorch/lion_pytorch.py new file mode 100644 index 00000000..b0d3a3f8 --- /dev/null +++ b/lion_pytorch/lion_pytorch.py @@ -0,0 +1,97 @@ +from __future__ import annotations +from typing import Tuple, Callable, Union + +import torch +from torch.optim.optimizer import Optimizer + + +def exists(val): + return val is not None + + +def update_fn(p, grad, exp_avg, lr, wd, beta1, beta2): + # stepweight decay + p.data.mul_(1. - lr * wd) + + # weight update + update = exp_avg.clone().mul_(beta1).add(grad, alpha=1.0 - beta1).sign_() + p.add_(update, alpha=-lr) + + # decay the momentum running average coefficient + exp_avg.mul_(beta2).add_(grad, alpha=1.0 - beta2) + + +class Lion(Optimizer): + def __init__( + self, + params, + lr: float = 1e-4, + betas: Tuple[float, float] = (0.9, 0.99), + weight_decay: float = 0.0, + use_triton: bool = False, + decoupled_weight_decay: bool = False, + ): + assert lr > 0. + assert all([0. <= beta <= 1. 
for beta in betas]) + + self._init_lr = lr + self.decoupled_wd = decoupled_weight_decay + + defaults = dict( + lr=lr, + betas=betas, + weight_decay=weight_decay + ) + + super().__init__(params, defaults) + self.update_fn = update_fn + + if use_triton: + from lion_pytorch.triton import update_fn as triton_update_fn + self.update_fn = triton_update_fn + + @torch.no_grad() + def step( + self, + closure: Union[Callable, None] = None + ): + + loss = None + if exists(closure): + with torch.enable_grad(): + loss = closure() + + for group in self.param_groups: + for p in filter(lambda p: exists(p.grad), group['params']): + + # grad, lr, wd, beta1, beta2, state, decoupled_wd, init_lr = p.grad, group['lr'], group['weight_decay'], *group['betas'], self.state[p], self.decoupled_wd, self._init_lr + grad = p.grad + lr = group['lr'] + wd = group['weight_decay'] + beta1, beta2 = group['betas'] + state= self.state[p] + decoupled_wd = self.decoupled_wd + init_lr = self._init_lr + + # maybe decoupled weight decay + + if decoupled_wd: + wd /= init_lr + + # init state - exponential moving average of gradient values + if len(state) == 0: + state['exp_avg'] = torch.zeros_like(p) + + exp_avg = state['exp_avg'] + + self.update_fn( + p, + grad, + exp_avg, + lr, + wd, + beta1, + beta2 + ) + + return loss diff --git a/lion_pytorch/triton.py b/lion_pytorch/triton.py new file mode 100644 index 00000000..1dd4696b --- /dev/null +++ b/lion_pytorch/triton.py @@ -0,0 +1,98 @@ +import torch + +try: + import triton + import triton.language as tl +except ImportError as e: + print('triton is not installed, please install by running `pip install triton>=2.2.0`') + exit() + +# triton cuda kernel + +@triton.autotune(configs = [ + triton.Config({'BLOCK_SIZE': 128}, num_warps = 4), + triton.Config({'BLOCK_SIZE': 1024}, num_warps = 8), +], key = ['n_elements'], restore_value=['p_ptr', 'exp_avg_ptr']) +@triton.jit +def update_fn_kernel( + p_ptr, + grad_ptr, + exp_avg_ptr, + lr, + wd, + beta1, + beta2, + n_elements, + BLOCK_SIZE: tl.constexpr, +): + pid = tl.program_id(axis = 0) + + block_start = pid * BLOCK_SIZE + offsets = block_start + tl.arange(0, BLOCK_SIZE) + + mask = offsets < n_elements + + # offsetted pointers + + offset_p_ptr = p_ptr + offsets + offset_grad_ptr = grad_ptr + offsets + offset_exp_avg_ptr = exp_avg_ptr + offsets + + # load + + p = tl.load(offset_p_ptr, mask = mask) + grad = tl.load(offset_grad_ptr, mask = mask) + exp_avg = tl.load(offset_exp_avg_ptr, mask = mask) + + # stepweight decay + + p = p * (1 - lr * wd) + + # diff between momentum running average and grad + + diff = exp_avg - grad + + # weight update + + update = diff * beta1 + grad + + # torch.sign + + can_update = update != 0 + update_sign = tl.where(update > 0, -lr, lr) + + p = p + update_sign * can_update + + # decay the momentum running average coefficient + + exp_avg = diff * beta2 + grad + + # store new params and momentum running average coefficient + + tl.store(offset_p_ptr, p, mask = mask) + tl.store(offset_exp_avg_ptr, exp_avg, mask = mask) + +def update_fn( + p: torch.Tensor, + grad: torch.Tensor, + exp_avg: torch.Tensor, + lr: float, + wd: float, + beta1: float, + beta2: float +): + assert all([t.is_cuda for t in (p, grad, exp_avg)]) + n_elements = p.numel() + + grid = lambda meta: (triton.cdiv(n_elements, meta['BLOCK_SIZE']),) + + update_fn_kernel[grid]( + p, + grad, + exp_avg, + lr, + wd, + beta1, + beta2, + n_elements + ) diff --git a/src/nanotron/helpers.py b/src/nanotron/helpers.py index a82f0294..548c352b 100644 --- 
a/src/nanotron/helpers.py +++ b/src/nanotron/helpers.py @@ -44,6 +44,8 @@ from nanotron.scaling.parametrization import LearningRateForSP, LearningRateForSpectralMup, ParametrizationMethod from nanotron.serialize.metadata import TrainingMetadata +from lion_pytorch import Lion + logger = logging.get_logger(__name__) @@ -327,9 +329,7 @@ def init_optimizer_and_grad_accumulator( # Basic optimizer builder def basic_optimizer_builder(named_param_groups): optimizer = None - if optimizer_args.optimizer_factory.name == "adamW": - def optimizer(param_groups): return torch.optim.AdamW( param_groups, @@ -339,16 +339,20 @@ def optimizer(param_groups): betas=(optimizer_args.optimizer_factory.adam_beta1, optimizer_args.optimizer_factory.adam_beta2), fused=optimizer_args.optimizer_factory.torch_adam_is_fused, ) - elif optimizer_args.optimizer_factory.name == "sgd": - def optimizer(param_groups): return torch.optim.SGD( param_groups, lr=optimizer_args.learning_rate_scheduler.learning_rate, weight_decay=optimizer_args.weight_decay, ) - + elif optimizer_args.optimizer_factory.name == "lion": + def optimizer(param_groups): + return Lion( + param_groups, + lr=optimizer_args.learning_rate_scheduler.learning_rate, + weight_decay=optimizer_args.weight_decay, + ) else: raise ValueError(f"Optimizer {optimizer_args.optimizer_factory.name} is not supported") From 7b7ead9a4b5378a2a286f59e326a9fa5e942182e Mon Sep 17 00:00:00 2001 From: Kyle Matoba <22180455+kylematoba@users.noreply.github.com> Date: Sat, 14 Sep 2024 22:58:38 +0200 Subject: [PATCH 27/44] Revert "work" This reverts commit 761d253c21d101216fcf396ced495f88ad01bdb9. --- lion_pytorch/__init__.py | 1 - lion_pytorch/foreach.py | 95 ---------------------------------- lion_pytorch/lion_pytorch.py | 97 ----------------------------------- lion_pytorch/triton.py | 98 ------------------------------------ src/nanotron/helpers.py | 14 ++---- 5 files changed, 5 insertions(+), 300 deletions(-) delete mode 100644 lion_pytorch/__init__.py delete mode 100644 lion_pytorch/foreach.py delete mode 100644 lion_pytorch/lion_pytorch.py delete mode 100644 lion_pytorch/triton.py diff --git a/lion_pytorch/__init__.py b/lion_pytorch/__init__.py deleted file mode 100644 index b3a7799d..00000000 --- a/lion_pytorch/__init__.py +++ /dev/null @@ -1 +0,0 @@ -from lion_pytorch.lion_pytorch import Lion diff --git a/lion_pytorch/foreach.py b/lion_pytorch/foreach.py deleted file mode 100644 index 50d60518..00000000 --- a/lion_pytorch/foreach.py +++ /dev/null @@ -1,95 +0,0 @@ -from __future__ import annotations -from typing import Tuple, Callable - -import torch -from torch.optim.optimizer import Optimizer - -# functions - -def exists(val): - return val is not None - -# class - -class Lion(Optimizer): - def __init__( - self, - params, - lr: float = 1e-4, - betas: Tuple[float, float] = (0.9, 0.99), - weight_decay: float = 0.0, - decoupled_weight_decay: bool = False - ): - assert lr > 0. - assert all([0. <= beta <= 1. 
for beta in betas]) - assert all([hasattr(torch, f'_foreach_{attr}_') for attr in ('mul', 'add', 'sign', 'lerp')]), 'this version of torch does not have the prerequisite foreach functions' - - self._init_lr = lr - self.decoupled_wd = decoupled_weight_decay - - defaults = dict( - lr = lr, - betas = betas, - weight_decay = weight_decay - ) - - super().__init__(params, defaults) - - @torch.no_grad() - def step( - self, - closure: Callable | None = None - ): - - loss = None - if exists(closure): - with torch.enable_grad(): - loss = closure() - - for group in self.param_groups: - - lr, wd, beta1, beta2, decoupled_wd, init_lr = group['lr'], group['weight_decay'], *group['betas'], self.decoupled_wd, self._init_lr - - # maybe decoupled weight decay - - if decoupled_wd: - wd /= init_lr - - # accumulate List[Tensor] for foreach inplace updates - - params = [] - grads = [] - exp_avgs = [] - - for p in filter(lambda p: exists(p.grad), group['params']): - - grad, state = p.grad, self.state[p] - - # init state - exponential moving average of gradient values - - if len(state) == 0: - state['exp_avg'] = torch.zeros_like(p) - - exp_avg = state['exp_avg'] - - params.append(p) - grads.append(grad) - exp_avgs.append(exp_avg) - - # stepweight decay - - torch._foreach_mul_(params, 1. - lr * wd) - - # weight update - - updates = [t.clone() for t in exp_avgs] - torch._foreach_lerp_(updates, grads, 1. - beta1) - torch._foreach_sign_(updates) - - torch._foreach_add_(params, updates, alpha = -lr) - - # decay momentum running average - - torch._foreach_lerp_(exp_avgs, grads, 1. - beta2) - - return loss diff --git a/lion_pytorch/lion_pytorch.py b/lion_pytorch/lion_pytorch.py deleted file mode 100644 index b0d3a3f8..00000000 --- a/lion_pytorch/lion_pytorch.py +++ /dev/null @@ -1,97 +0,0 @@ -from __future__ import annotations -from typing import Tuple, Callable, Union - -import torch -from torch.optim.optimizer import Optimizer - - -def exists(val): - return val is not None - - -def update_fn(p, grad, exp_avg, lr, wd, beta1, beta2): - # stepweight decay - p.data.mul_(1. - lr * wd) - - # weight update - update = exp_avg.clone().mul_(beta1).add(grad, alpha=1.0 - beta1).sign_() - p.add_(update, alpha=-lr) - - # decay the momentum running average coefficient - exp_avg.mul_(beta2).add_(grad, alpha=1.0 - beta2) - - -class Lion(Optimizer): - def __init__( - self, - params, - lr: float = 1e-4, - betas: Tuple[float, float] = (0.9, 0.99), - weight_decay: float = 0.0, - use_triton: bool = False, - decoupled_weight_decay: bool = False, - ): - assert lr > 0. - assert all([0. <= beta <= 1. 
for beta in betas]) - - self._init_lr = lr - self.decoupled_wd = decoupled_weight_decay - - defaults = dict( - lr=lr, - betas=betas, - weight_decay=weight_decay - ) - - super().__init__(params, defaults) - self.update_fn = update_fn - - if use_triton: - from lion_pytorch.triton import update_fn as triton_update_fn - self.update_fn = triton_update_fn - - @torch.no_grad() - def step( - self, - closure: Union[Callable, None] = None - ): - - loss = None - if exists(closure): - with torch.enable_grad(): - loss = closure() - - for group in self.param_groups: - for p in filter(lambda p: exists(p.grad), group['params']): - - # grad, lr, wd, beta1, beta2, state, decoupled_wd, init_lr = p.grad, group['lr'], group['weight_decay'], *group['betas'], self.state[p], self.decoupled_wd, self._init_lr - grad = p.grad - lr = group['lr'] - wd = group['weight_decay'] - beta1, beta2 = group['betas'] - state= self.state[p] - decoupled_wd = self.decoupled_wd - init_lr = self._init_lr - - # maybe decoupled weight decay - - if decoupled_wd: - wd /= init_lr - - # init state - exponential moving average of gradient values - if len(state) == 0: - state['exp_avg'] = torch.zeros_like(p) - - exp_avg = state['exp_avg'] - - self.update_fn( - p, - grad, - exp_avg, - lr, - wd, - beta1, - beta2 - ) - - return loss diff --git a/lion_pytorch/triton.py b/lion_pytorch/triton.py deleted file mode 100644 index 1dd4696b..00000000 --- a/lion_pytorch/triton.py +++ /dev/null @@ -1,98 +0,0 @@ -import torch - -try: - import triton - import triton.language as tl -except ImportError as e: - print('triton is not installed, please install by running `pip install triton>=2.2.0`') - exit() - -# triton cuda kernel - -@triton.autotune(configs = [ - triton.Config({'BLOCK_SIZE': 128}, num_warps = 4), - triton.Config({'BLOCK_SIZE': 1024}, num_warps = 8), -], key = ['n_elements'], restore_value=['p_ptr', 'exp_avg_ptr']) -@triton.jit -def update_fn_kernel( - p_ptr, - grad_ptr, - exp_avg_ptr, - lr, - wd, - beta1, - beta2, - n_elements, - BLOCK_SIZE: tl.constexpr, -): - pid = tl.program_id(axis = 0) - - block_start = pid * BLOCK_SIZE - offsets = block_start + tl.arange(0, BLOCK_SIZE) - - mask = offsets < n_elements - - # offsetted pointers - - offset_p_ptr = p_ptr + offsets - offset_grad_ptr = grad_ptr + offsets - offset_exp_avg_ptr = exp_avg_ptr + offsets - - # load - - p = tl.load(offset_p_ptr, mask = mask) - grad = tl.load(offset_grad_ptr, mask = mask) - exp_avg = tl.load(offset_exp_avg_ptr, mask = mask) - - # stepweight decay - - p = p * (1 - lr * wd) - - # diff between momentum running average and grad - - diff = exp_avg - grad - - # weight update - - update = diff * beta1 + grad - - # torch.sign - - can_update = update != 0 - update_sign = tl.where(update > 0, -lr, lr) - - p = p + update_sign * can_update - - # decay the momentum running average coefficient - - exp_avg = diff * beta2 + grad - - # store new params and momentum running average coefficient - - tl.store(offset_p_ptr, p, mask = mask) - tl.store(offset_exp_avg_ptr, exp_avg, mask = mask) - -def update_fn( - p: torch.Tensor, - grad: torch.Tensor, - exp_avg: torch.Tensor, - lr: float, - wd: float, - beta1: float, - beta2: float -): - assert all([t.is_cuda for t in (p, grad, exp_avg)]) - n_elements = p.numel() - - grid = lambda meta: (triton.cdiv(n_elements, meta['BLOCK_SIZE']),) - - update_fn_kernel[grid]( - p, - grad, - exp_avg, - lr, - wd, - beta1, - beta2, - n_elements - ) diff --git a/src/nanotron/helpers.py b/src/nanotron/helpers.py index 548c352b..a82f0294 100644 --- 
a/src/nanotron/helpers.py +++ b/src/nanotron/helpers.py @@ -44,8 +44,6 @@ from nanotron.scaling.parametrization import LearningRateForSP, LearningRateForSpectralMup, ParametrizationMethod from nanotron.serialize.metadata import TrainingMetadata -from lion_pytorch import Lion - logger = logging.get_logger(__name__) @@ -329,7 +327,9 @@ def init_optimizer_and_grad_accumulator( # Basic optimizer builder def basic_optimizer_builder(named_param_groups): optimizer = None + if optimizer_args.optimizer_factory.name == "adamW": + def optimizer(param_groups): return torch.optim.AdamW( param_groups, @@ -339,20 +339,16 @@ def optimizer(param_groups): betas=(optimizer_args.optimizer_factory.adam_beta1, optimizer_args.optimizer_factory.adam_beta2), fused=optimizer_args.optimizer_factory.torch_adam_is_fused, ) + elif optimizer_args.optimizer_factory.name == "sgd": + def optimizer(param_groups): return torch.optim.SGD( param_groups, lr=optimizer_args.learning_rate_scheduler.learning_rate, weight_decay=optimizer_args.weight_decay, ) - elif optimizer_args.optimizer_factory.name == "lion": - def optimizer(param_groups): - return Lion( - param_groups, - lr=optimizer_args.learning_rate_scheduler.learning_rate, - weight_decay=optimizer_args.weight_decay, - ) + else: raise ValueError(f"Optimizer {optimizer_args.optimizer_factory.name} is not supported") From d6ef5eafe4ccaa8f92d97406599ead2542997847 Mon Sep 17 00:00:00 2001 From: AleHD Date: Wed, 26 Jun 2024 12:38:55 +0000 Subject: [PATCH 28/44] XGLM work in progress: Causal Attention and Positional Embeddings work --- examples/xglm/__init__.py | 0 examples/xglm/convert_hf2nt.py | 28 ++ examples/xglm/tests/test_attn.py | 74 +++++ examples/xglm/tests/test_implementation.py | 90 ++++++ src/nanotron/config/models_config.py | 36 +++ src/nanotron/models/gpt3.py | 358 +++++++++++++++++++++ 6 files changed, 586 insertions(+) create mode 100644 examples/xglm/__init__.py create mode 100644 examples/xglm/convert_hf2nt.py create mode 100644 examples/xglm/tests/test_attn.py create mode 100644 examples/xglm/tests/test_implementation.py create mode 100644 src/nanotron/models/gpt3.py diff --git a/examples/xglm/__init__.py b/examples/xglm/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/examples/xglm/convert_hf2nt.py b/examples/xglm/convert_hf2nt.py new file mode 100644 index 00000000..e008f859 --- /dev/null +++ b/examples/xglm/convert_hf2nt.py @@ -0,0 +1,28 @@ +import torch + +from transformers.models.xglm.modeling_xglm import XGLMAttention +from nanotron.models.gpt3 import CausalSelfAttention + + +def convert_attention(attn_nt: CausalSelfAttention, attn_hf: XGLMAttention): + q_ws = torch.chunk(attn_hf.q_proj.weight, attn_hf.num_heads) + k_ws = torch.chunk(attn_hf.k_proj.weight, attn_hf.num_heads) + v_ws = torch.chunk(attn_hf.v_proj.weight, attn_hf.num_heads) + + q_bs = torch.chunk(attn_hf.q_proj.bias, attn_hf.num_heads) + k_bs = torch.chunk(attn_hf.k_proj.bias, attn_hf.num_heads) + v_bs = torch.chunk(attn_hf.v_proj.bias, attn_hf.num_heads) + + qkv_w = [] + qkv_b = [] + for q_w, k_w, v_w, q_b, k_b, v_b in zip(q_ws, k_ws, v_ws, q_bs, k_bs, v_bs): + qkv_w += [q_w, k_w, v_w] + qkv_b += [q_b, k_b, v_b] + qkv_w = torch.cat(qkv_w) + qkv_b = torch.cat(qkv_b) + + with torch.no_grad(): + attn_nt.query_key_value.weight.data = qkv_w.clone() + attn_nt.query_key_value.bias.data = qkv_b.clone() + attn_nt.dense.weight.data = attn_hf.out_proj.weight.clone() + attn_nt.dense.bias.data = attn_hf.out_proj.bias.clone() diff --git a/examples/xglm/tests/test_attn.py 
b/examples/xglm/tests/test_attn.py new file mode 100644 index 00000000..2fcdb3a8 --- /dev/null +++ b/examples/xglm/tests/test_attn.py @@ -0,0 +1,74 @@ +import torch +from torch.nn import functional as F +#torch.Size([4, 2048, 16, 64]), torch.Size([2048, 4, 1024]) + +# inputs = (batchsize * qlen, heads, head_dim) +# outputs = (batchsize*qlen, heads, head_dim) +def sdpa(query, key, value, batchsize: int): + def reshape(tensor): # output = (batchsize, heads, qlen, head_dim) + return tensor.view(batchsize, qlen, heads, head_dim).permute(0, 2, 1, 3) + + batchsize_x_qlen, heads, head_dim = query.size() + qlen = batchsize_x_qlen//batchsize + out = F.scaled_dot_product_attention(reshape(query), reshape(key), reshape(value), is_causal=True) # (b,h,q,d) + return out.permute(0, 2, 1, 3).reshape(batchsize*qlen, heads, head_dim) + + +# inputs = (batchsize * qlen, heads, head_dim) +# outputs = (batchsize*qlen, heads, head_dim) +def fa(query_states, key_states, value_states, batchsize: int): + from flash_attn.flash_attn_interface import flash_attn_varlen_func + + batchsize_x_qlen, heads, head_dim = query_states.size() + qlen = batchsize_x_qlen//batchsize + + q_sequence_mask = torch.ones(batchsize, qlen, dtype=torch.bool, device="cuda") + kv_sequence_mask = torch.ones(batchsize, qlen, dtype=torch.bool, device="cuda") + + # TODO @thomasw21: Compute once, instead of computing for each layers. + cu_seqlens_q = torch.zeros((q_sequence_mask.shape[0] + 1), dtype=torch.int32, device=query_states.device) + cu_seqlens_k = torch.zeros((kv_sequence_mask.shape[0] + 1), dtype=torch.int32, device=query_states.device) + torch.cumsum(q_sequence_mask.sum(-1, dtype=torch.int32), dim=0, dtype=torch.int32, out=cu_seqlens_q[1:]) + torch.cumsum(kv_sequence_mask.sum(-1, dtype=torch.int32), dim=0, dtype=torch.int32, out=cu_seqlens_k[1:]) + + # TODO(kunhao): flash attn's causal means that the query can only attend to the keys before it. This is not + # what we want if we are using kv cache. This is a hack as we always have q_length == 1 when using kv cache. 
+ causal = False if q_sequence_mask.shape[1] == 1 else True + attn_output = flash_attn_varlen_func( + q=query_states, + k=key_states, + v=value_states, + cu_seqlens_q=cu_seqlens_q, + cu_seqlens_k=cu_seqlens_k, + max_seqlen_q=q_sequence_mask.shape[1], + max_seqlen_k=kv_sequence_mask.shape[1], + dropout_p=0.0, + softmax_scale=None, # defaults to 1/sqrt(d_qk) + causal=causal, + window_size=(-1, -1), + return_attn_probs=False, + ) + return attn_output + + +def main(): + batchsize = 5 + qlen = 6 + heads = 2 + head_dim = 16 + + query = torch.randn(batchsize*qlen, heads, head_dim, device="cuda", dtype=torch.bfloat16) + key = torch.randn(batchsize*qlen, heads, head_dim, device="cuda", dtype=torch.bfloat16) + value = torch.randn(batchsize*qlen, heads, head_dim, device="cuda", dtype=torch.bfloat16) + + out_pt = sdpa(query, key, value, batchsize) + out_fa = fa(query, key, value, batchsize) + + assert out_pt.size() == out_fa.size() + + torch.testing.assert_close(out_pt, out_fa) + + + +if __name__ == "__main__": + main() diff --git a/examples/xglm/tests/test_implementation.py b/examples/xglm/tests/test_implementation.py new file mode 100644 index 00000000..10f0302a --- /dev/null +++ b/examples/xglm/tests/test_implementation.py @@ -0,0 +1,90 @@ +import numpy as np +import torch +import pytest + +from transformers.models.xglm.modeling_xglm import XGLMAttention, XGLMSinusoidalPositionalEmbedding + +from nanotron.config.models_config import GPT3Config +from nanotron.models.gpt3 import CausalSelfAttention, PositionEmbedding +from nanotron.parallel import ParallelContext + +from tests.helpers.utils import init_distributed + +from examples.xglm.convert_hf2nt import convert_attention + + +SEQUENCE_LENGTH = 2048 +BATCH_SIZE = 4 +HIDDEN_SIZE = 1024 +DTYPE = torch.float64 + +CONFIG = GPT3Config( + attn_pdrop=0.0, + embd_pdrop=0.0, + resid_pdrop=0.0, + eos_token_id=2, + hidden_size=HIDDEN_SIZE, + intermediate_size=4096, + layer_norm_epsilon=1e-05, + max_position_embeddings=SEQUENCE_LENGTH, + num_attention_heads=16, + num_hidden_layers=24, + scale_attn_weights=True, + vocab_size=256008, + sinusoidal_position_embedding=True, + position_embedding_offset=2, + use_spda=True +) + + +@pytest.fixture +def hidden_states() -> torch.Tensor: + return torch.randn(SEQUENCE_LENGTH, BATCH_SIZE, HIDDEN_SIZE, + dtype=DTYPE) + + +@pytest.fixture +def input_mask() -> torch.Tensor: + return torch.ones(BATCH_SIZE, SEQUENCE_LENGTH, dtype=torch.bool) + + +def _test_attention(parallel_context: ParallelContext, hidden_states: torch.Tensor, sequence_mask: torch.Tensor): + hidden_states = hidden_states.cuda() + sequence_mask = sequence_mask.cuda() + + attn_nt = CausalSelfAttention(CONFIG, None, parallel_context.tp_pg, 0).cuda().eval().to(DTYPE) + attn_hf = XGLMAttention(CONFIG.hidden_size, CONFIG.num_attention_heads, CONFIG.attn_pdrop).cuda().eval().to(DTYPE) + assert sum(map(torch.numel, attn_nt.parameters())) == sum(map(torch.numel, attn_hf.parameters())) + + # Build xglm mask. 
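+    # HF XGLM expects an additive float mask of shape (batch_size, 1, q_len, kv_len):
+    # 0.0 where attention is allowed and -inf on masked (future) positions.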
+ mask = torch.ones(SEQUENCE_LENGTH, SEQUENCE_LENGTH, dtype=torch.bool, device="cuda").tril(diagonal=0) + mask = torch.where(mask, 0.0, -np.inf).to(DTYPE) + mask = mask.repeat(BATCH_SIZE, 1, 1).unsqueeze(1) + + convert_attention(attn_nt, attn_hf) + out_nt = attn_nt(hidden_states, sequence_mask)["hidden_states"] + out_hf = attn_hf(hidden_states.permute(1, 0, 2), attention_mask=mask)[0].permute(1, 0, 2) + assert out_nt.size() == out_hf.size(), f"{out_nt.size()}, {out_hf.size()}" + torch.testing.assert_close(out_nt, out_hf) + + +def test_attention(hidden_states: torch.Tensor, input_mask: torch.Tensor): + init_distributed(tp=1, dp=1, pp=1)(_test_attention)(hidden_states=hidden_states, sequence_mask=input_mask) + + +def _test_position_embeddings(parallel_context: ParallelContext): + position_ids = torch.arange(SEQUENCE_LENGTH, device="cuda").unsqueeze(0) # shape = (1, SEQUENCE_LENGTH) + + emb_nt = PositionEmbedding(parallel_context.tp_pg, CONFIG, None).cuda() + emb_hf = XGLMSinusoidalPositionalEmbedding(SEQUENCE_LENGTH, HIDDEN_SIZE).cuda() + + assert emb_nt.position_embedding.weight.size() == emb_hf.weights.size() + torch.testing.assert_close(emb_nt.position_embedding.weight, emb_hf.weights) + + out_nt = emb_nt(position_ids)["position_embeds"] + out_hf = emb_hf(position_ids).permute(1, 0, 2) + assert out_nt.size() == out_hf.size(), f"{out_nt.size()}, {out_hf.size()}" + torch.testing.assert_close(out_nt, out_hf) + +def test_position_embeddings(): + init_distributed(tp=1, dp=1, pp=1)(_test_position_embeddings)() diff --git a/src/nanotron/config/models_config.py b/src/nanotron/config/models_config.py index 2630e1d6..f214b357 100644 --- a/src/nanotron/config/models_config.py +++ b/src/nanotron/config/models_config.py @@ -136,4 +136,40 @@ def n_inner(self): return self.intermediate_size +@dataclass +class GPT3Config: + """Configuration for a GPT3 model""" + + activation_function: str = "gelu" + attn_pdrop: float = 0.1 + embd_pdrop: float = 0.1 + eos_token_id: int = 49152 + hidden_size: int = 2048 + intermediate_size: Optional[int] = None + layer_norm_epsilon: float = 1e-05 + max_position_embeddings: int = 4096 + num_attention_heads: int = 16 + num_hidden_layers: int = 24 + resid_pdrop: float = 0.1 + scale_attention_softmax_in_fp32: bool = True + scale_attn_weights: bool = True + vocab_size: int = 49280 + sinusoidal_position_embedding: bool = True + position_embedding_offset: int = 2 + use_spda: bool = False + + def as_starcoder2(self) -> Starcoder2Config: + config = dict(**vars(self)) + del config["sinusoidal_position_embedding"] + del config["use_spda"] + del config["position_embedding_offset"] + return Starcoder2Config( + grouped_query=True, + num_kv_heads=self.num_attention_heads, + use_rotary_embeddings=False, + **config + ) + + NanotronConfigs = Union[LlamaConfig, Starcoder2Config, Any] + diff --git a/src/nanotron/models/gpt3.py b/src/nanotron/models/gpt3.py new file mode 100644 index 00000000..8cea58c4 --- /dev/null +++ b/src/nanotron/models/gpt3.py @@ -0,0 +1,358 @@ +"""PyTorch GPT-3 model.""" + +import math +from typing import Optional + +import torch +from torch import nn +from torch.nn import functional as F + +from nanotron import distributed as dist +from nanotron.parallel import ParallelContext +from nanotron.config import Config, GPT3Config, ParallelismArgs +from nanotron.generation.generate_store import AttachableStore +from nanotron.models.starcoder2 import MLP as Starcoder2MLP +from nanotron.models.starcoder2 import CoreAttention as Starcoder2CoreAttention +from 
nanotron.models.starcoder2 import CausalSelfGQA +from nanotron.random import RandomStates, branch_random_state +from nanotron.parallel.pipeline_parallel.tensor_pointer import TensorPointer +from nanotron.parallel.tensor_parallel.enum import TensorParallelLinearMode +from nanotron.parallel.tensor_parallel.nn import TensorParallelEmbedding +from nanotron.parallel.tied_parameters import tie_parameters + +# NOTES: +# - tie head_weight with embeddings I think. + +# TODO: +# - class GPT3Config: config lol +# - check that attention (i.e. nanotron.attn vs xglm.self_attn) is the same. +# - from starcoder import Embedding +# - class PositionEmbedding: my sinusoidal embedding extends from TensorParallelEmbedding +# - class GPTBLock: very similar to starcoder2 but make it so it support non-GQA or MQA +# - from starcoder import Loss + + +class CoreAttention(Starcoder2CoreAttention): + def __init__(self, config: GPT3Config, parallel_config: Optional[ParallelismArgs], layer_idx: int): + super().__init__(config.as_starcoder2(), parallel_config, layer_idx) + self.gpt3config = config + + def forward(self, + query_states: torch.Tensor, # [batch_size * q_length, q_heads, inner_dim] + key_states: torch.Tensor, # [batch_size * kv_length, kv_heads, inner_dim] + value_states: torch.Tensor, # [batch_size * kv_length, kv_heads, inner_dim] + q_sequence_mask: torch.Tensor, # torch.BoolTensor [batch_size, q_length] (can be broadcasted to that size) + kv_sequence_mask: torch.Tensor, # torch.BoolTensor [batch_size, kv_length] (can be broadcasted to that size) + ): + + if self.gpt3config.use_spda: + assert torch.all(q_sequence_mask) + assert torch.all(kv_sequence_mask) + + batch_size, q_length = q_sequence_mask.size() + kv_length = kv_sequence_mask.size(1) + _, q_heads, head_dim = query_states.size() + kv_heads = key_states.size(1) + + attention_output = F.scaled_dot_product_attention( + query_states.view(batch_size, q_length, q_heads, head_dim).permute(0, 2, 1, 3), + key_states.view(batch_size, kv_length, kv_heads, head_dim).permute(0, 2, 1, 3), + value_states.view(batch_size, kv_length, kv_heads, head_dim).permute(0, 2, 1, 3), + dropout_p=self.dropout if self.training else 0.0, + is_causal=True, + ) # [batch, q_length, q_heads, head_dim] + attention_output = attention_output.permute(0, 2, 1, 3) + attention_output = attention_output.reshape(batch_size*q_length, q_heads, head_dim) + return attention_output + + assert query_states.dtype in {torch.bfloat16, torch.float16} + return super().forward(query_states, key_states, value_states, q_sequence_mask, kv_sequence_mask) + + +class CausalSelfAttention(CausalSelfGQA): + def __init__( + self, + config: GPT3Config, + parallel_config: Optional[ParallelismArgs], + tp_pg: dist.ProcessGroup, + layer_idx: int, + ): + super().__init__(config.as_starcoder2(), parallel_config, tp_pg, layer_idx) + self.maybe_rotary = lambda q, k, **_: (q, k) # Overwrite possible rotary with identity. + self.attention = CoreAttention(config, parallel_config=parallel_config, layer_idx=layer_idx) # Use our custom CoreAttention. + + +class MLP(Starcoder2MLP): + def __init__( + self, + config: GPT3Config, + parallel_config: Optional[ParallelismArgs], + tp_pg: dist.ProcessGroup, + ): + # TODO: GPT3Config -> Starcoder2Config. 
+ super().__init__(config, parallel_config, tp_pg) + self.dropout = nn.Dropout(p=config.dropout) # TODO: correct config.dropout name + + def forward(self, hidden_states): # [seq_length, batch_size, hidden_dim] + hidden_states = self.c_fc(hidden_states) + hidden_states = self.act(hidden_states) + with branch_random_state( + self.random_states, "tp_synced", enabled=self.tp_mode is TensorParallelLinearMode.ALL_REDUCE + ): + hidden_states = self.dropout(input=hidden_states) + hidden_states = self.c_proj(hidden_states) + return {"hidden_states": hidden_states} + + +class GPTBlock(nn.Module): + def __init__( + self, + config: GPT3Config, + parallel_config: Optional[ParallelismArgs], + tp_pg: dist.ProcessGroup, + random_states: RandomStates, + layer_idx: int, + ): + super(GPTBlock, self).__init__() + self.ln_1 = TritonLayerNorm(config.hidden_size, eps=config.layer_norm_epsilon) + self.attn = CausalSelfAttention( + config=config, + parallel_config=parallel_config, + tp_pg=tp_pg, + layer_idx=layer_idx + ) + self.attn_dropout = config.attn_pdrop + + self.ln_2 = TritonLayerNorm(config.hidden_size, eps=config.layer_norm_epsilon) + self.ff = MLP(config=config, parallel_config=parallel_config, tp_pg=tp_pg) + self.ff_dropout = config.resid_pdrop + + self.random_states = random_states + self.tp_mode = parallel_config.tp_mode if parallel_config is not None else TensorParallelLinearMode.ALL_REDUCE + + def forward( + self, + hidden_states: torch.Tensor | TensorPointer, + sequence_mask: torch.Tensor | TensorPointer, + ) -> dict[str, torch.Tensor | TensorPointer]: + + residual = hidden_states + hidden_states = self.ln_1(hidden_states) + output = self.attn(hidden_states=hidden_states, sequence_mask=sequence_mask) + hidden_states = output["hidden_states"] + + if self.training: + with branch_random_state( + self.random_states, "tp_synced", enabled=self.tp_mode is TensorParallelLinearMode.ALL_REDUCE + ): + hidden_states = dropout_add_fused_train(hidden_states, residual=residual, prob=self.attn_dropout) + else: + # No need for random state context manager + # TODO: add dropout scaling? + hidden_states = hidden_states + residual + + residual = hidden_states + hidden_states = self.ln_2(hidden_states) + hidden_states = self.ff(hidden_states=hidden_states)["hidden_states"] + + if self.training: + with branch_random_state( + self.random_states, "tp_synced", enabled=self.tp_mode is TensorParallelLinearMode.ALL_REDUCE + ): + hidden_states = dropout_add_fused_train(hidden_states, residual=residual, prob=self.ff_dropout) + else: + # No need for random state context manager + # TODO: add dropout scaling? 
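+            # eval path: no dropout is applied here, so the bare residual add matches dropout acting as identity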
+ hidden_states = hidden_states + residual + + return { + "hidden_states": hidden_states, + "sequence_mask": output["sequence_mask"], + } + + +class PositionEmbedding(nn.Module, AttachableStore): + def __init__(self, tp_pg: dist.ProcessGroup, config: GPT3Config, parallel_config: Optional[ParallelismArgs]): + super().__init__() + + self.config = config + if (config.max_position_embeddings + config.position_embedding_offset) % tp_pg.size() == 0: + dummy_pos = 0 + else: + dummy_pos = tp_pg.size() - ((config.max_position_embeddings + config.position_embedding_offset) % k) + true_max_size = config.max_position_embeddings + config.position_embedding_offset + dummy_pos + + if config.sinusoidal_position_embedding: + weight = self._make_weights(tp_pg, true_max_size, config.hidden_size) + else: + weight = None + + position_embedding = TensorParallelEmbedding( + num_embeddings=true_max_size, + embedding_dim=config.hidden_size, + pg=tp_pg, + mode=parallel_config.tp_mode if parallel_config is not None else TensorParallelLinearMode.ALL_REDUCE, + _weight=weight + ) + self.pg = tp_pg + + # Sinusoidal position embeddings are usually not trainable. + # We adjust that by setting the module self.position_embedding without gradient. + if config.sinusoidal_position_embedding: + with torch.no_grad(): + self.position_embedding = position_embedding.requires_grad_(False) + else: + self.position_embedding = position_embedding + + def forward(self, position_ids: torch.Tensor): # [batch_size, seq_length] + position_ids = position_ids.transpose(0, 1) + position_embeds = self.position_embedding(position_ids + self.config.position_embedding_offset) + return {"position_embeds": position_embeds} + + def _make_weights(self, tp_pg: dist.ProcessGroup, num_embeddings: int, + embedding_dim: int) -> torch.Tensor: + rank = dist.get_rank(group=tp_pg) + tp_size = tp_pg.size() + + assert 0 <= rank < tp_size + assert num_embeddings % tp_size == 0 + assert embedding_dim % 2 == 0 + block_size = num_embeddings//tp_size + + half_dim = embedding_dim//2 + emb = math.log(10_000)/(half_dim - 1) + emb = torch.exp(torch.arange(half_dim, dtype=torch.int64).float() * -emb) + emb = (rank*block_size + torch.arange(num_embeddings, dtype=torch.int64).float().unsqueeze(1)) * emb.unsqueeze(0) + emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1).view(block_size, embedding_dim) + return emb + + +class GPT3Model(nn.Module): + def __init__( + self, + config: GPT3Config, + parallel_context: ParallelContext, + parallel_config: Optional[ParallelismArgs], + random_states: RandomStates, + ): + super().__init__() + + # Declare all the nodes + self.p2p = P2P(parallel_context.pp_pg, device=torch.device("cuda")) + self.random_states = random_states + self.tp_mode = parallel_config.tp_mode if parallel_config is not None else TensorParallelLinearMode.ALL_REDUCE + + self.token_embeddings = PipelineBlock( + p2p=self.p2p, + module_builder=Embedding, + module_kwargs={ + "tp_pg": parallel_context.tp_pg, + "config": config, + "parallel_config": parallel_config, + }, + module_input_keys={"input_ids"}, + module_output_keys={"input_embeds"}, + ) + self.position_embeddings = PipelineBlock( + p2p=self.p2p, + module_builder=PositionEmbedding, + module_kwargs={ + "tp_pg": parallel_context.tp_pg, + "config": config, + "parallel_config": parallel_config, + }, + module_input_keys={"position_ids"}, + module_output_keys={"position_embeds"}, + ) + + self.embeds_dropout = PipelineBlock( + p2p=self.p2p, + module_builder=nn.Dropout, + module_kwargs={"p": config.embd_pdrop}, + 
module_input_keys={"input"}, + module_output_keys={"hidden_states"}, + ) + + self.decoder = nn.ModuleList( + [ + PipelineBlock( + p2p=self.p2p, + module_builder=GPTBlock, + module_kwargs={ + "config": config, + "parallel_config": parallel_config, + "tp_pg": parallel_context.tp_pg, + "random_states": random_states, + "layer_idx": layer_idx, + }, + module_input_keys={"hidden_states", "sequence_mask"}, + module_output_keys={"hidden_states", "sequence_mask"}, + ) + for layer_idx in range(config.num_hidden_layers) + ] + ) + + self.final_layer_norm = PipelineBlock( + p2p=self.p2p, + module_builder=TritonLayerNorm, + module_kwargs={"normalized_shape": config.hidden_size, "eps": config.layer_norm_epsilon}, + module_input_keys={"input"}, + module_output_keys={"hidden_states"}, + ) + + self.lm_head = PipelineBlock( + p2p=self.p2p, + # Understand that this means that we return sharded logits that are going to need to be gathered + module_builder=TensorParallelColumnLinear, + module_kwargs={ + "in_features": config.hidden_size, + "out_features": config.vocab_size, + "pg": parallel_context.tp_pg, + "bias": False, + # TODO: refactor so that we store that default in a single place. + "mode": self.tp_mode, + "async_communication": parallel_config.tp_linear_async_communication + if parallel_config is not None + else False, + }, + module_input_keys={"x"}, + module_output_keys={"logits"}, + ) + + self.cast_to_fp32 = PipelineBlock( + p2p=self.p2p, + module_builder=lambda: lambda x: x.float(), + module_kwargs={}, + module_input_keys={"x"}, + module_output_keys={"output"}, + ) + + + def forward( + self, + input_ids: torch.Tensor | TensorPointer, # [batch_size, seq_length] + input_mask: torch.Tensor | TensorPointer, # [batch_size, seq_length] + ): + # all tensors are optional as most ranks don't need anything from the dataloader. 
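+        # Position ids are rebuilt on the fly as a [batch_size, seq_length] tensor;
+        # PositionEmbedding.forward adds position_embedding_offset before the lookup.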
+ + position_ids = torch.arange(input_ids.size(1), device="cuda").repeat(input_ids.size(0)).view(*input_ids.size()) + input_embeds = self.token_embeddings(input_ids=input_ids)["input_embeds"] + position_embeds = self.position_embeds(position_ids=position_ids)["position_embeds"] + hidden_states = input_embeds + position_embeds + + with branch_random_state( + self.random_states, "tp_synced", enabled=self.tp_mode == TensorParallelLinearMode.ALL_REDUCE + ): + hidden_states = self.embeds_dropout(input=hidden_states)["hidden_states"] + + hidden_encoder_states = {"hidden_states": hidden_states, "sequence_mask": input_mask} + for encoder_block in self.decoder: + hidden_encoder_states = encoder_block(**hidden_encoder_states) + + hidden_states = self.final_layer_norm(input=hidden_encoder_states["hidden_states"])["hidden_states"] + + sharded_logits = self.lm_head(x=hidden_states)["logits"] + + fp32_sharded_logits = self.cast_to_fp32(x=sharded_logits)["output"] + + return fp32_sharded_logits From aceac6100f43301a5cac3c456dacfe9a6a95ca1e Mon Sep 17 00:00:00 2001 From: AleHD Date: Wed, 26 Jun 2024 17:24:53 +0000 Subject: [PATCH 29/44] WIP: GPT arch almost done, hf->nt converters working perfectly for non-distributed inference --- examples/xglm/convert_hf2nt.py | 70 +++++++- examples/xglm/tests/test_attn.py | 74 --------- examples/xglm/tests/test_implementation.py | 135 +++++++++++++-- src/nanotron/config/models_config.py | 4 + src/nanotron/models/gpt3.py | 184 ++++++++++----------- 5 files changed, 287 insertions(+), 180 deletions(-) delete mode 100644 examples/xglm/tests/test_attn.py diff --git a/examples/xglm/convert_hf2nt.py b/examples/xglm/convert_hf2nt.py index e008f859..6e6ddff1 100644 --- a/examples/xglm/convert_hf2nt.py +++ b/examples/xglm/convert_hf2nt.py @@ -1,7 +1,44 @@ import torch +from torch import nn -from transformers.models.xglm.modeling_xglm import XGLMAttention -from nanotron.models.gpt3 import CausalSelfAttention +from transformers.models.xglm.modeling_xglm import XGLMAttention, XGLMConfig, XGLMDecoderLayer, XGLMForCausalLM +from nanotron.models.gpt3 import CausalSelfAttention, GPTBlock, MLP, GPT3ForTraining +from nanotron.config.models_config import GPT3Config + + +def convert_config(config: XGLMConfig) -> GPT3Config: + # TODOs: + # dropout=0.1, + # layerdrop=0.0, + # init_std=0.02, + # use_cache=True, + # decoder_start_token_id=2, + # pad_token_id=1, + # bos_token_id=0, + + # TODO: when going gpt3->xglm: + # - assert layernorm is 1e-05 + return GPT3Config( + activation_function=config.activation_function, + attn_pdrop=config.attention_dropout, + embd_pdrop=0.0, # TODO + eos_token_id=config.eos_token_id, + hidden_size=config.d_model, + intermediate_size=config.ffn_dim, + layer_norm_epsilon=1e-05, + max_position_embeddings=config.max_position_embeddings, + num_attention_heads=config.attention_heads, + num_hidden_layers=config.num_layers, + resid_pdrop=0.0, # TODO + scale_attention_softmax_in_fp32=True, + scale_attn_weights=True, + vocab_size=config.vocab_size, + sinusoidal_position_embedding=True, + position_embedding_offset=2, + use_spda=False, + act_pdrop=config.activation_dropout, + scale_embedding=config.scale_embedding, + ) def convert_attention(attn_nt: CausalSelfAttention, attn_hf: XGLMAttention): @@ -26,3 +63,32 @@ def convert_attention(attn_nt: CausalSelfAttention, attn_hf: XGLMAttention): attn_nt.query_key_value.bias.data = qkv_b.clone() attn_nt.dense.weight.data = attn_hf.out_proj.weight.clone() attn_nt.dense.bias.data = attn_hf.out_proj.bias.clone() + + +def 
convert_generic(module1: nn.Module, module2: nn.Module): + names1 = {name for name, _ in module1.named_parameters()} + names2 = {name for name, _ in module2.named_parameters()} + assert names1 == names2, f"{names1} != {names2}" + params2 = dict(module2.named_parameters()) + for name, param in module1.named_parameters(): + param.data = params2[name].clone() + + +def convert_mlp(mlp_nt: MLP, block_hf: XGLMDecoderLayer): + convert_generic(mlp_nt.c_fc, block_hf.fc1) + convert_generic(mlp_nt.c_proj, block_hf.fc2) + + +def convert_decoder(block_nt: GPTBlock, block_hf: XGLMDecoderLayer): + convert_generic(block_nt.ln_1, block_hf.self_attn_layer_norm) + convert_attention(block_nt.attn, block_hf.self_attn) + convert_generic(block_nt.ln_2, block_hf.final_layer_norm) + convert_mlp(block_nt.ff, block_hf) + + +def convert(model_nt: GPT3ForTraining, model_hf: XGLMForCausalLM): + convert_generic(model_nt.model.token_embeddings.pp_block.token_embedding, model_hf.model.embed_tokens) + for layer_nt, layer_hf in zip(model_nt.model.decoder, model_hf.model.layers): + convert_decoder(layer_nt.pp_block, layer_hf) + convert_generic(model_nt.model.final_layer_norm.pp_block, model_hf.model.layer_norm) + convert_generic(model_nt.model.lm_head.pp_block, model_hf.lm_head) diff --git a/examples/xglm/tests/test_attn.py b/examples/xglm/tests/test_attn.py deleted file mode 100644 index 2fcdb3a8..00000000 --- a/examples/xglm/tests/test_attn.py +++ /dev/null @@ -1,74 +0,0 @@ -import torch -from torch.nn import functional as F -#torch.Size([4, 2048, 16, 64]), torch.Size([2048, 4, 1024]) - -# inputs = (batchsize * qlen, heads, head_dim) -# outputs = (batchsize*qlen, heads, head_dim) -def sdpa(query, key, value, batchsize: int): - def reshape(tensor): # output = (batchsize, heads, qlen, head_dim) - return tensor.view(batchsize, qlen, heads, head_dim).permute(0, 2, 1, 3) - - batchsize_x_qlen, heads, head_dim = query.size() - qlen = batchsize_x_qlen//batchsize - out = F.scaled_dot_product_attention(reshape(query), reshape(key), reshape(value), is_causal=True) # (b,h,q,d) - return out.permute(0, 2, 1, 3).reshape(batchsize*qlen, heads, head_dim) - - -# inputs = (batchsize * qlen, heads, head_dim) -# outputs = (batchsize*qlen, heads, head_dim) -def fa(query_states, key_states, value_states, batchsize: int): - from flash_attn.flash_attn_interface import flash_attn_varlen_func - - batchsize_x_qlen, heads, head_dim = query_states.size() - qlen = batchsize_x_qlen//batchsize - - q_sequence_mask = torch.ones(batchsize, qlen, dtype=torch.bool, device="cuda") - kv_sequence_mask = torch.ones(batchsize, qlen, dtype=torch.bool, device="cuda") - - # TODO @thomasw21: Compute once, instead of computing for each layers. - cu_seqlens_q = torch.zeros((q_sequence_mask.shape[0] + 1), dtype=torch.int32, device=query_states.device) - cu_seqlens_k = torch.zeros((kv_sequence_mask.shape[0] + 1), dtype=torch.int32, device=query_states.device) - torch.cumsum(q_sequence_mask.sum(-1, dtype=torch.int32), dim=0, dtype=torch.int32, out=cu_seqlens_q[1:]) - torch.cumsum(kv_sequence_mask.sum(-1, dtype=torch.int32), dim=0, dtype=torch.int32, out=cu_seqlens_k[1:]) - - # TODO(kunhao): flash attn's causal means that the query can only attend to the keys before it. This is not - # what we want if we are using kv cache. This is a hack as we always have q_length == 1 when using kv cache. 
- causal = False if q_sequence_mask.shape[1] == 1 else True - attn_output = flash_attn_varlen_func( - q=query_states, - k=key_states, - v=value_states, - cu_seqlens_q=cu_seqlens_q, - cu_seqlens_k=cu_seqlens_k, - max_seqlen_q=q_sequence_mask.shape[1], - max_seqlen_k=kv_sequence_mask.shape[1], - dropout_p=0.0, - softmax_scale=None, # defaults to 1/sqrt(d_qk) - causal=causal, - window_size=(-1, -1), - return_attn_probs=False, - ) - return attn_output - - -def main(): - batchsize = 5 - qlen = 6 - heads = 2 - head_dim = 16 - - query = torch.randn(batchsize*qlen, heads, head_dim, device="cuda", dtype=torch.bfloat16) - key = torch.randn(batchsize*qlen, heads, head_dim, device="cuda", dtype=torch.bfloat16) - value = torch.randn(batchsize*qlen, heads, head_dim, device="cuda", dtype=torch.bfloat16) - - out_pt = sdpa(query, key, value, batchsize) - out_fa = fa(query, key, value, batchsize) - - assert out_pt.size() == out_fa.size() - - torch.testing.assert_close(out_pt, out_fa) - - - -if __name__ == "__main__": - main() diff --git a/examples/xglm/tests/test_implementation.py b/examples/xglm/tests/test_implementation.py index 10f0302a..3636415b 100644 --- a/examples/xglm/tests/test_implementation.py +++ b/examples/xglm/tests/test_implementation.py @@ -1,27 +1,33 @@ +from typing import Optional + import numpy as np import torch import pytest -from transformers.models.xglm.modeling_xglm import XGLMAttention, XGLMSinusoidalPositionalEmbedding +from transformers import XGLMTokenizer +from transformers.models.xglm.modeling_xglm import XGLMConfig, XGLMAttention, XGLMSinusoidalPositionalEmbedding, XGLMDecoderLayer, XGLMForCausalLM +import nanotron from nanotron.config.models_config import GPT3Config -from nanotron.models.gpt3 import CausalSelfAttention, PositionEmbedding +from nanotron.models.gpt3 import GPT3ForTraining, CausalSelfAttention, PositionEmbedding, GPTBlock from nanotron.parallel import ParallelContext from tests.helpers.utils import init_distributed -from examples.xglm.convert_hf2nt import convert_attention +from examples.xglm.convert_hf2nt import convert_attention, convert_config, convert_decoder, convert SEQUENCE_LENGTH = 2048 BATCH_SIZE = 4 HIDDEN_SIZE = 1024 -DTYPE = torch.float64 +DTYPE = torch.bfloat16 +TEXT = "Hello. This is a relatively long text. I will use this text to test the conversion scripts. Let's finish this text soon because I don't have much more to say. Final note:" CONFIG = GPT3Config( attn_pdrop=0.0, embd_pdrop=0.0, resid_pdrop=0.0, + act_pdrop=0.0, eos_token_id=2, hidden_size=HIDDEN_SIZE, intermediate_size=4096, @@ -42,11 +48,22 @@ def hidden_states() -> torch.Tensor: return torch.randn(SEQUENCE_LENGTH, BATCH_SIZE, HIDDEN_SIZE, dtype=DTYPE) - @pytest.fixture def input_mask() -> torch.Tensor: return torch.ones(BATCH_SIZE, SEQUENCE_LENGTH, dtype=torch.bool) +@pytest.fixture +def input_ids() -> torch.Tensor: + return torch.randint(0, CONFIG.vocab_size, (BATCH_SIZE, SEQUENCE_LENGTH)) + + +def attention_mask() -> torch.Tensor: + # XGLM causal attention mask. 
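+    # Lower-triangular boolean mask converted into an additive mask: 0.0 where attention
+    # is allowed, -inf elsewhere, broadcast to the [batch, 1, seq, seq] shape HF expects.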
+ mask = torch.ones(SEQUENCE_LENGTH, SEQUENCE_LENGTH, dtype=torch.bool, device="cuda").tril(diagonal=0) + mask = torch.where(mask, 0.0, -np.inf).to(DTYPE) + mask = mask.repeat(BATCH_SIZE, 1, 1).unsqueeze(1) + return mask + def _test_attention(parallel_context: ParallelContext, hidden_states: torch.Tensor, sequence_mask: torch.Tensor): hidden_states = hidden_states.cuda() @@ -56,14 +73,9 @@ def _test_attention(parallel_context: ParallelContext, hidden_states: torch.Tens attn_hf = XGLMAttention(CONFIG.hidden_size, CONFIG.num_attention_heads, CONFIG.attn_pdrop).cuda().eval().to(DTYPE) assert sum(map(torch.numel, attn_nt.parameters())) == sum(map(torch.numel, attn_hf.parameters())) - # Build xglm mask. - mask = torch.ones(SEQUENCE_LENGTH, SEQUENCE_LENGTH, dtype=torch.bool, device="cuda").tril(diagonal=0) - mask = torch.where(mask, 0.0, -np.inf).to(DTYPE) - mask = mask.repeat(BATCH_SIZE, 1, 1).unsqueeze(1) - convert_attention(attn_nt, attn_hf) out_nt = attn_nt(hidden_states, sequence_mask)["hidden_states"] - out_hf = attn_hf(hidden_states.permute(1, 0, 2), attention_mask=mask)[0].permute(1, 0, 2) + out_hf = attn_hf(hidden_states.permute(1, 0, 2), attention_mask=attention_mask())[0].permute(1, 0, 2) assert out_nt.size() == out_hf.size(), f"{out_nt.size()}, {out_hf.size()}" torch.testing.assert_close(out_nt, out_hf) @@ -88,3 +100,104 @@ def _test_position_embeddings(parallel_context: ParallelContext): def test_position_embeddings(): init_distributed(tp=1, dp=1, pp=1)(_test_position_embeddings)() + + +def _test_decoder(parallel_context: ParallelContext, hidden_states: torch.Tensor, sequence_mask: torch.Tensor): + random_states = nanotron.random.RandomStates({"tp_synced": nanotron.random.get_current_random_state()}) + hidden_states = hidden_states.cuda() + sequence_mask = sequence_mask.cuda() + + config_hf = XGLMConfig() + decoder_hf = XGLMDecoderLayer(config_hf).cuda().to(DTYPE).eval() + config_nt = convert_config(config_hf) + if DTYPE not in {torch.bfloat16, torch.float16}: + config_nt.use_spda = True + decoder_nt = GPTBlock(config_nt, None, parallel_context.tp_pg, random_states, 0).cuda().to(DTYPE).eval() + + convert_decoder(decoder_nt, decoder_hf) + + out_nt = decoder_nt(hidden_states, sequence_mask)["hidden_states"] + out_hf = decoder_hf(hidden_states.permute(1, 0, 2), attention_mask=attention_mask())[0].permute(1, 0, 2) + + assert out_nt.size() == out_hf.size(), f"{out_nt.size()}, {out_hf.size()}" + torch.testing.assert_close(out_nt, out_hf) + + +def test_decoder(hidden_states: torch.Tensor, input_mask: torch.Tensor): + init_distributed(tp=1, dp=1, pp=1)(_test_decoder)(hidden_states=hidden_states, sequence_mask=input_mask) + + +def _test_model(model_hf: Optional[XGLMForCausalLM], parallel_context: ParallelContext, + input_ids: torch.Tensor, input_mask: torch.Tensor): + random_states = nanotron.random.RandomStates({"tp_synced": nanotron.random.get_current_random_state()}) + input_ids = input_ids.cuda() + input_mask = input_mask.cuda() + + # Get hf model. + if model_hf is None: + config_hf = XGLMConfig() + model_hf = XGLMForCausalLM(config_hf).cuda().to(DTYPE).eval() + else: + model_hf = model_hf.cuda().to(DTYPE).eval() + config_hf = model_hf.config + + # Get nanotron model and make the conversion. 
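+    # convert_config maps the HF XGLMConfig onto nanotron's GPT3Config; flash-attn only
+    # supports fp16/bf16, so the SDPA fallback is enabled for any other dtype.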
+ config_nt = convert_config(config_hf) + if DTYPE not in {torch.bfloat16, torch.float16}: + config_nt.use_spda = True + model_nt = nanotron.models.build_model( + model_builder=lambda: GPT3ForTraining( + config=config_nt, + parallel_context=parallel_context, + parallel_config=None, + random_states=random_states, + ), + parallel_context=parallel_context, + dtype=DTYPE, + device="cuda", + ).eval() + convert(model_nt, model_hf) + + print("Parameter count (M):", sum(map(torch.numel, model_hf.parameters()))/1000/1000) + + # Get outputs and assert. + with torch.no_grad(): + out_nt = model_nt.model(input_ids, input_mask).to(DTYPE) + del model_nt + torch.cuda.empty_cache() + out_hf = model_hf(input_ids=input_ids, attention_mask=input_mask).logits.permute(1, 0, 2) + del model_hf + torch.cuda.empty_cache() + assert out_nt.size() == out_hf.size(), f"{out_nt.size()}, {out_hf.size()}" + torch.testing.assert_close(out_nt.cpu(), out_hf.cpu()) + +def _test_dummy_xglm(parallel_context: ParallelContext, input_ids: torch.Tensor, input_mask: torch.Tensor): + _test_model(None, parallel_context, input_ids, input_mask) + + +def test_dummy_xglm(input_ids: torch.Tensor, input_mask: torch.Tensor): + init_distributed(tp=1, dp=1, pp=1)(_test_dummy_xglm)(input_ids=input_ids, input_mask=input_mask) + + +def _test_xglm7B(parallel_context: ParallelContext): + tok = XGLMTokenizer.from_pretrained("facebook/xglm-7.5B") + tokenized = tok(TEXT) + model_hf = XGLMForCausalLM.from_pretrained("facebook/xglm-7.5B") + _test_model(model_hf, parallel_context, + torch.tensor([tokenized["input_ids"]]), torch.tensor([tokenized["attention_mask"]])) + + +def test_xglm7B(): + init_distributed(tp=1, dp=1, pp=1)(_test_xglm7B)() + + +def _test_xglm500M(parallel_context: ParallelContext): + tok = XGLMTokenizer.from_pretrained("facebook/xglm-564M") + tokenized = tok(TEXT) + model_hf = XGLMForCausalLM.from_pretrained("facebook/xglm-564M") + _test_model(model_hf, parallel_context, + torch.tensor([tokenized["input_ids"]]), torch.tensor([tokenized["attention_mask"]])) + + +def test_xglm500M(): + init_distributed(tp=1, dp=1, pp=1)(_test_xglm500M)() diff --git a/src/nanotron/config/models_config.py b/src/nanotron/config/models_config.py index f214b357..56d6411f 100644 --- a/src/nanotron/config/models_config.py +++ b/src/nanotron/config/models_config.py @@ -157,12 +157,16 @@ class GPT3Config: sinusoidal_position_embedding: bool = True position_embedding_offset: int = 2 use_spda: bool = False + act_pdrop: float = 0.0 + scale_embedding: bool = True def as_starcoder2(self) -> Starcoder2Config: config = dict(**vars(self)) del config["sinusoidal_position_embedding"] del config["use_spda"] del config["position_embedding_offset"] + del config["act_pdrop"] + del config["scale_embedding"] return Starcoder2Config( grouped_query=True, num_kv_heads=self.num_attention_heads, diff --git a/src/nanotron/models/gpt3.py b/src/nanotron/models/gpt3.py index 8cea58c4..99f6ea85 100644 --- a/src/nanotron/models/gpt3.py +++ b/src/nanotron/models/gpt3.py @@ -2,6 +2,7 @@ import math from typing import Optional +from contextlib import contextmanager import torch from torch import nn @@ -9,11 +10,15 @@ from nanotron import distributed as dist from nanotron.parallel import ParallelContext -from nanotron.config import Config, GPT3Config, ParallelismArgs +from nanotron.config import Config, GPT3Config, ParallelismArgs, Starcoder2Config from nanotron.generation.generate_store import AttachableStore +from nanotron.models import starcoder2 +from nanotron.nn.layer_norm import 
TritonLayerNorm from nanotron.models.starcoder2 import MLP as Starcoder2MLP +from nanotron.parallel.pipeline_parallel.block import PipelineBlock from nanotron.models.starcoder2 import CoreAttention as Starcoder2CoreAttention -from nanotron.models.starcoder2 import CausalSelfGQA +from nanotron.models.starcoder2 import GPTBlock as Starcoder2GPTBlock +from nanotron.models.starcoder2 import CausalSelfGQA, Starcoder2ForTraining, GPTModel from nanotron.random import RandomStates, branch_random_state from nanotron.parallel.pipeline_parallel.tensor_pointer import TensorPointer from nanotron.parallel.tensor_parallel.enum import TensorParallelLinearMode @@ -28,10 +33,55 @@ # - check that attention (i.e. nanotron.attn vs xglm.self_attn) is the same. # - from starcoder import Embedding # - class PositionEmbedding: my sinusoidal embedding extends from TensorParallelEmbedding -# - class GPTBLock: very similar to starcoder2 but make it so it support non-GQA or MQA +# - class GPTBlock: very similar to starcoder2 but make it so it support non-GQA or MQA # - from starcoder import Loss +@contextmanager +def replace_coreattention(gpt3config: GPT3Config): + orig = starcoder2.CoreAttention + try: + def create_core_attention(config: Starcoder2Config, parallel_config: Optional[ParallelismArgs], layer_idx: int): + return CoreAttention(gpt3config, parallel_config, layer_idx) + starcoder2.CoreAttention = create_core_attention + yield + finally: + starcoder2.CoreAttention = orig + + +@contextmanager +def replace_decoder(gpt3config: GPT3Config): + orig = starcoder2.PipelineBlock + try: + def create_pp_block(module_builder, module_kwargs, **kwargs): + if module_builder is Starcoder2GPTBlock: + # Starcoder2's GPT module is trying to instantiate a Starcoder2 GPTBlock. + # Let's return a PipelineBlock with a GPT3Block instead. + # This also requires to replace starcoders2's config with gpt3's config. + module_kwargs["config"] = gpt3config + return orig(module_builder=GPTBlock, module_kwargs=module_kwargs, **kwargs) + # Else, they are setting up other modules, which we also want unchanged. 
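+            # (e.g. the embedding, final layer norm and lm_head pipeline blocks are
+            # passed through to the original PipelineBlock untouched.)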
+ return orig(module_builder=module_builder, module_kwargs=module_kwargs, **kwargs) + + starcoder2.PipelineBlock = create_pp_block + yield + finally: + starcoder2.PipelineBlock = orig + + +@contextmanager +def replace_gpt3model(gpt3config: GPT3Config): + orig = starcoder2.GPTModel + try: + def create_gptmodel(config: Starcoder2Config, parallel_context: ParallelContext, + parallel_config: Optional[ParallelismArgs], random_states: RandomStates): + return GPT3Model(gpt3config, parallel_context, parallel_config, random_states) + starcoder2.GPTModel = create_gptmodel + yield + finally: + starcoder2.GPTModel = orig + + class CoreAttention(Starcoder2CoreAttention): def __init__(self, config: GPT3Config, parallel_config: Optional[ParallelismArgs], layer_idx: int): super().__init__(config.as_starcoder2(), parallel_config, layer_idx) @@ -63,7 +113,7 @@ def forward(self, ) # [batch, q_length, q_heads, head_dim] attention_output = attention_output.permute(0, 2, 1, 3) attention_output = attention_output.reshape(batch_size*q_length, q_heads, head_dim) - return attention_output + return attention_output.contiguous() assert query_states.dtype in {torch.bfloat16, torch.float16} return super().forward(query_states, key_states, value_states, q_sequence_mask, kv_sequence_mask) @@ -77,9 +127,10 @@ def __init__( tp_pg: dist.ProcessGroup, layer_idx: int, ): - super().__init__(config.as_starcoder2(), parallel_config, tp_pg, layer_idx) + with replace_coreattention(config): + super().__init__(config.as_starcoder2(), parallel_config, tp_pg, layer_idx) self.maybe_rotary = lambda q, k, **_: (q, k) # Overwrite possible rotary with identity. - self.attention = CoreAttention(config, parallel_config=parallel_config, layer_idx=layer_idx) # Use our custom CoreAttention. + #self.attention = CoreAttention(config, parallel_config=parallel_config, layer_idx=layer_idx) # Use our custom CoreAttention. class MLP(Starcoder2MLP): @@ -88,10 +139,12 @@ def __init__( config: GPT3Config, parallel_config: Optional[ParallelismArgs], tp_pg: dist.ProcessGroup, + random_states: RandomStates ): - # TODO: GPT3Config -> Starcoder2Config. 
- super().__init__(config, parallel_config, tp_pg) - self.dropout = nn.Dropout(p=config.dropout) # TODO: correct config.dropout name + super().__init__(config.as_starcoder2(), parallel_config, tp_pg) + self.dropout = nn.Dropout(p=config.act_pdrop) + self.random_states = random_states + self.tp_mode = parallel_config.tp_mode if parallel_config is not None else TensorParallelLinearMode.ALL_REDUCE def forward(self, hidden_states): # [seq_length, batch_size, hidden_dim] hidden_states = self.c_fc(hidden_states) @@ -113,6 +166,7 @@ def __init__( random_states: RandomStates, layer_idx: int, ): + #print("New gpt block created :D") super(GPTBlock, self).__init__() self.ln_1 = TritonLayerNorm(config.hidden_size, eps=config.layer_norm_epsilon) self.attn = CausalSelfAttention( @@ -124,7 +178,7 @@ def __init__( self.attn_dropout = config.attn_pdrop self.ln_2 = TritonLayerNorm(config.hidden_size, eps=config.layer_norm_epsilon) - self.ff = MLP(config=config, parallel_config=parallel_config, tp_pg=tp_pg) + self.ff = MLP(config=config, parallel_config=parallel_config, tp_pg=tp_pg, random_states=random_states) self.ff_dropout = config.resid_pdrop self.random_states = random_states @@ -138,8 +192,10 @@ def forward( residual = hidden_states hidden_states = self.ln_1(hidden_states) + #hidden_states = torch.arange(hidden_states.numel()).to(hidden_states.device).to(hidden_states.dtype).view(hidden_states.size()) output = self.attn(hidden_states=hidden_states, sequence_mask=sequence_mask) hidden_states = output["hidden_states"] + #return {"hidden_states": hidden_states, "sequence_mask": sequence_mask} if self.training: with branch_random_state( @@ -227,7 +283,7 @@ def _make_weights(self, tp_pg: dist.ProcessGroup, num_embeddings: int, return emb -class GPT3Model(nn.Module): +class GPT3Model(GPTModel): def __init__( self, config: GPT3Config, @@ -235,24 +291,9 @@ def __init__( parallel_config: Optional[ParallelismArgs], random_states: RandomStates, ): - super().__init__() + with replace_decoder(config): + super().__init__(config.as_starcoder2(), parallel_context, parallel_config, random_states) - # Declare all the nodes - self.p2p = P2P(parallel_context.pp_pg, device=torch.device("cuda")) - self.random_states = random_states - self.tp_mode = parallel_config.tp_mode if parallel_config is not None else TensorParallelLinearMode.ALL_REDUCE - - self.token_embeddings = PipelineBlock( - p2p=self.p2p, - module_builder=Embedding, - module_kwargs={ - "tp_pg": parallel_context.tp_pg, - "config": config, - "parallel_config": parallel_config, - }, - module_input_keys={"input_ids"}, - module_output_keys={"input_embeds"}, - ) self.position_embeddings = PipelineBlock( p2p=self.p2p, module_builder=PositionEmbedding, @@ -264,69 +305,7 @@ def __init__( module_input_keys={"position_ids"}, module_output_keys={"position_embeds"}, ) - - self.embeds_dropout = PipelineBlock( - p2p=self.p2p, - module_builder=nn.Dropout, - module_kwargs={"p": config.embd_pdrop}, - module_input_keys={"input"}, - module_output_keys={"hidden_states"}, - ) - - self.decoder = nn.ModuleList( - [ - PipelineBlock( - p2p=self.p2p, - module_builder=GPTBlock, - module_kwargs={ - "config": config, - "parallel_config": parallel_config, - "tp_pg": parallel_context.tp_pg, - "random_states": random_states, - "layer_idx": layer_idx, - }, - module_input_keys={"hidden_states", "sequence_mask"}, - module_output_keys={"hidden_states", "sequence_mask"}, - ) - for layer_idx in range(config.num_hidden_layers) - ] - ) - - self.final_layer_norm = PipelineBlock( - p2p=self.p2p, - 
module_builder=TritonLayerNorm, - module_kwargs={"normalized_shape": config.hidden_size, "eps": config.layer_norm_epsilon}, - module_input_keys={"input"}, - module_output_keys={"hidden_states"}, - ) - - self.lm_head = PipelineBlock( - p2p=self.p2p, - # Understand that this means that we return sharded logits that are going to need to be gathered - module_builder=TensorParallelColumnLinear, - module_kwargs={ - "in_features": config.hidden_size, - "out_features": config.vocab_size, - "pg": parallel_context.tp_pg, - "bias": False, - # TODO: refactor so that we store that default in a single place. - "mode": self.tp_mode, - "async_communication": parallel_config.tp_linear_async_communication - if parallel_config is not None - else False, - }, - module_input_keys={"x"}, - module_output_keys={"logits"}, - ) - - self.cast_to_fp32 = PipelineBlock( - p2p=self.p2p, - module_builder=lambda: lambda x: x.float(), - module_kwargs={}, - module_input_keys={"x"}, - module_output_keys={"output"}, - ) - + self.embed_scale = config.hidden_size**0.5 if config.scale_embedding else 1.0 def forward( self, @@ -335,9 +314,9 @@ def forward( ): # all tensors are optional as most ranks don't need anything from the dataloader. + input_embeds = self.token_embeddings(input_ids=input_ids, input_mask=input_mask)["input_embeds"]*self.embed_scale position_ids = torch.arange(input_ids.size(1), device="cuda").repeat(input_ids.size(0)).view(*input_ids.size()) - input_embeds = self.token_embeddings(input_ids=input_ids)["input_embeds"] - position_embeds = self.position_embeds(position_ids=position_ids)["position_embeds"] + position_embeds = self.position_embeddings(position_ids=position_ids)["position_embeds"] hidden_states = input_embeds + position_embeds with branch_random_state( @@ -348,6 +327,7 @@ def forward( hidden_encoder_states = {"hidden_states": hidden_states, "sequence_mask": input_mask} for encoder_block in self.decoder: hidden_encoder_states = encoder_block(**hidden_encoder_states) + #return hidden_encoder_states["hidden_states"] hidden_states = self.final_layer_norm(input=hidden_encoder_states["hidden_states"])["hidden_states"] @@ -356,3 +336,21 @@ def forward( fp32_sharded_logits = self.cast_to_fp32(x=sharded_logits)["output"] return fp32_sharded_logits + + +# TODO: maybe reimplement: +# - tie_custom_params +# - get_embeddings_lm_head_tied_names +# - get_block_compute_costs +# - get_flops_per_sec +class GPT3ForTraining(Starcoder2ForTraining): + def __init__( + self, + config: GPT3Config, + parallel_context: ParallelContext, + parallel_config: Optional[ParallelismArgs], + random_states: RandomStates, + ): + with replace_gpt3model(config): + super().__init__(config.as_starcoder2(), parallel_context, parallel_config, random_states) + From fc9d062f65123aaf10779be3b6600394c990355f Mon Sep 17 00:00:00 2001 From: AleHD Date: Tue, 9 Jul 2024 16:46:55 +0200 Subject: [PATCH 30/44] Added hf2nt frontend + tested training --- examples/xglm/README.md | 13 +++ examples/xglm/convert_hf2nt.py | 86 ++++++++++++++-- examples/xglm/example_config.yaml | 98 +++++++++++++++++++ src/nanotron/config/models_config.py | 6 +- src/nanotron/models/gpt3.py | 23 +---- .../optimizer_from_gradient_accumulator.py | 3 +- src/nanotron/trainer.py | 2 + 7 files changed, 199 insertions(+), 32 deletions(-) create mode 100644 examples/xglm/README.md create mode 100644 examples/xglm/example_config.yaml diff --git a/examples/xglm/README.md b/examples/xglm/README.md new file mode 100644 index 00000000..abc50f95 --- /dev/null +++ b/examples/xglm/README.md @@ 
-0,0 +1,13 @@ +# How to use XGLM? + +1. First, make sure to convert the weights from huggingface, for instance: + ``` + torchrun --nproc-per-node=1 examples/xglm/convert_hf2nt.py --checkpoint-path=facebook/xglm-564M --save-path=$SCRATCH/checkpoints/xglm-564M + ``` + +1. Now you are ready to use XGLM. + Make sure you use a .yaml configuration with proper GPT3 config and then run for instance: + ``` + torchrun --nproc-per-node=4 run_train.py --config-file=examples/xglm/example_config.yaml + ``` + If you use this configuration file make sure to modify at least the loading path in `model.init_method.path`. diff --git a/examples/xglm/convert_hf2nt.py b/examples/xglm/convert_hf2nt.py index 6e6ddff1..9db5ed93 100644 --- a/examples/xglm/convert_hf2nt.py +++ b/examples/xglm/convert_hf2nt.py @@ -1,27 +1,42 @@ +""" +Converts a HF model to nanotron format +Command: + torchrun --nproc-per-node=1 convert_hf2nt.py --checkpoint-path=hf_weights --save-path=nanotron_weights +""" + +import json +import warnings +import dataclasses +from argparse import ArgumentParser +from pathlib import Path + import torch from torch import nn - from transformers.models.xglm.modeling_xglm import XGLMAttention, XGLMConfig, XGLMDecoderLayer, XGLMForCausalLM + +import nanotron from nanotron.models.gpt3 import CausalSelfAttention, GPTBlock, MLP, GPT3ForTraining from nanotron.config.models_config import GPT3Config +from nanotron.trainer import mark_tied_parameters + def convert_config(config: XGLMConfig) -> GPT3Config: # TODOs: - # dropout=0.1, # layerdrop=0.0, # init_std=0.02, # use_cache=True, - # decoder_start_token_id=2, # pad_token_id=1, # bos_token_id=0, - - # TODO: when going gpt3->xglm: - # - assert layernorm is 1e-05 + if config.dropout != config.attention_dropout: + warnings.warn(f"huggingface.dropout = {config.dropout} does not match with " + f"huggingface.attention_dropout = {config.attention_dropout}. 
" + "Nanotron implementation needs these two values to be equal " + "for correct conversion.") return GPT3Config( activation_function=config.activation_function, attn_pdrop=config.attention_dropout, - embd_pdrop=0.0, # TODO + embd_pdrop=config.dropout, eos_token_id=config.eos_token_id, hidden_size=config.d_model, intermediate_size=config.ffn_dim, @@ -29,12 +44,12 @@ def convert_config(config: XGLMConfig) -> GPT3Config: max_position_embeddings=config.max_position_embeddings, num_attention_heads=config.attention_heads, num_hidden_layers=config.num_layers, - resid_pdrop=0.0, # TODO + resid_pdrop=config.dropout, scale_attention_softmax_in_fp32=True, scale_attn_weights=True, vocab_size=config.vocab_size, sinusoidal_position_embedding=True, - position_embedding_offset=2, + position_embedding_offset=config.decoder_start_token_id, use_spda=False, act_pdrop=config.activation_dropout, scale_embedding=config.scale_embedding, @@ -92,3 +107,56 @@ def convert(model_nt: GPT3ForTraining, model_hf: XGLMForCausalLM): convert_decoder(layer_nt.pp_block, layer_hf) convert_generic(model_nt.model.final_layer_norm.pp_block, model_hf.model.layer_norm) convert_generic(model_nt.model.lm_head.pp_block, model_hf.lm_head) + + +def create_nt_model(model_config: GPT3Config, device: torch.device = torch.device("cuda"), + dtype: torch.dtype = torch.bfloat16) -> GPT3ForTraining: + + parallel_config = nanotron.config.ParallelismArgs(dp=1, pp=1, tp=1) + parallel_context = nanotron.parallel.ParallelContext( + data_parallel_size=parallel_config.dp, + pipeline_parallel_size=parallel_config.pp, + tensor_parallel_size=parallel_config.tp, + ) + #random_states = nanotron.random.RandomStates({"tp_synced": nanotron.random.get_current_random_state()}) + model_nt = nanotron.models.build_model( + model_builder=lambda: GPT3ForTraining( + config=model_config, + parallel_context=parallel_context, + parallel_config=parallel_config, + random_states=None, + ), + parallel_context=parallel_context, + dtype=dtype, + device=device, + ) + mark_tied_parameters(model=model_nt, parallel_context=parallel_context) + return model_nt + + +def main(hf_path: str, save_path: Path): + # Load hf. + print("Loading hf...") + model_hf = XGLMForCausalLM.from_pretrained(hf_path) + + # Init nanotron. + print("Initializing nt...") + config_nt = convert_config(model_hf.config) + model_nt = create_nt_model(config_nt) + + # Copy weights and save model. 
+ print("Copying weights...") + convert(model_nt, model_hf) + nanotron.serialize.save_weights(model=model_nt, parallel_context=model_nt.parallel_context, + root_folder=save_path) + with open(save_path/"model_config.json", "w+") as f: + json.dump(dataclasses.asdict(config_nt), f) + print(f"Model saved to {save_path}") + + +if __name__ == "__main__": + parser = ArgumentParser(description="Convert HF weights to nanotron format") + parser.add_argument("--checkpoint-path", default="facebook/xglm-7.5B", help="Name or path to the huggingface checkpoint") + parser.add_argument("--save-path", type=Path, default="checkpoints/xglm-7.5B", help="Path to save the nanotron model") + args = parser.parse_args() + main(args.checkpoint_path, args.save_path) diff --git a/examples/xglm/example_config.yaml b/examples/xglm/example_config.yaml new file mode 100644 index 00000000..2d7e9926 --- /dev/null +++ b/examples/xglm/example_config.yaml @@ -0,0 +1,98 @@ +checkpoints: + checkpoint_interval: 1000 + checkpoints_path: checkpoints/xglm + checkpoints_path_is_shared_file_system: false + resume_checkpoint_path: null + save_initial_state: false +data_stages: +- data: + dataset: + dataset_overwrite_cache: false + dataset_processing_num_proc_per_process: 64 + hf_dataset_config_name: null + hf_dataset_or_datasets: DKYoon/SlimPajama-6B + hf_dataset_splits: train + text_column_name: text + num_loading_workers: 1 + seed: 42 + name: Finetuning + start_training_step: 1 +general: + benchmark_csv_path: null + consumed_train_samples: null + ignore_sanity_checks: true + project: xglm-test + run: xglm-dp4tp1pp1 + seed: 42 + step: null +lighteval: null +logging: + iteration_step_info_interval: 1 + log_level: info + log_level_replica: info +model: + ddp_bucket_cap_mb: 25 + dtype: bfloat16 + init_method: + path: /capstor/scratch/cscs/ahernnde/checkpoints/xglm-564M + make_vocab_size_divisible_by: 1 + model_config: + activation_function: gelu + attn_pdrop: 0.1 + embd_pdrop: 0.1 + scale_embedding: true + eos_token_id: 2 + hidden_size: 1024 + intermediate_size: 4096 + layer_norm_epsilon: 0.00001 + max_position_embeddings: 2048 + num_attention_heads: 16 + num_hidden_layers: 24 + resid_pdrop: 0.1 + scale_attention_softmax_in_fp32: true + scale_attn_weights: true + vocab_size: 256008 + sinusoidal_position_embedding: true + position_embedding_offset: 2 + use_spda: false + act_pdrop: 0.0 +optimizer: + accumulate_grad_in_fp32: true + clip_grad: 1.0 + learning_rate_scheduler: + learning_rate: 0.0003 + lr_decay_starting_step: null + lr_decay_steps: 900 + lr_decay_style: cosine + lr_warmup_steps: 100 + lr_warmup_style: linear + min_decay_lr: 1.0e-04 + optimizer_factory: + adam_beta1: 0.9 + adam_beta2: 0.95 + adam_eps: 1.0e-08 + name: adamW + torch_adam_is_fused: true + weight_decay: 0.01 + zero_stage: 0 +parallelism: + dp: 4 + expert_parallel_size: 1 + pp: 1 + pp_engine: 1f1b + tp: 1 + tp_linear_async_communication: false + tp_mode: REDUCE_SCATTER +profiler: null +tokenizer: + tokenizer_max_length: null + tokenizer_name_or_path: facebook/xglm-564M + tokenizer_revision: null +tokens: + batch_accumulation_per_replica: 4 + limit_test_batches: 0 + limit_val_batches: 0 + micro_batch_size: 8 + sequence_length: 2048 + train_steps: 1000 + val_check_interval: -1 diff --git a/src/nanotron/config/models_config.py b/src/nanotron/config/models_config.py index 56d6411f..20a92126 100644 --- a/src/nanotron/config/models_config.py +++ b/src/nanotron/config/models_config.py @@ -167,6 +167,8 @@ def as_starcoder2(self) -> Starcoder2Config: del 
config["position_embedding_offset"] del config["act_pdrop"] del config["scale_embedding"] + if "_is_using_mup" in config: + del config["_is_using_mup"] return Starcoder2Config( grouped_query=True, num_kv_heads=self.num_attention_heads, @@ -174,6 +176,4 @@ def as_starcoder2(self) -> Starcoder2Config: **config ) - -NanotronConfigs = Union[LlamaConfig, Starcoder2Config, Any] - +NanotronConfigs = LlamaConfig | Starcoder2Config | GPT3Config diff --git a/src/nanotron/models/gpt3.py b/src/nanotron/models/gpt3.py index 99f6ea85..33661c8b 100644 --- a/src/nanotron/models/gpt3.py +++ b/src/nanotron/models/gpt3.py @@ -18,24 +18,13 @@ from nanotron.parallel.pipeline_parallel.block import PipelineBlock from nanotron.models.starcoder2 import CoreAttention as Starcoder2CoreAttention from nanotron.models.starcoder2 import GPTBlock as Starcoder2GPTBlock -from nanotron.models.starcoder2 import CausalSelfGQA, Starcoder2ForTraining, GPTModel +from nanotron.models.starcoder2 import CausalSelfGQA, Starcoder2ForTraining, GPTModel, dropout_add_fused_train from nanotron.random import RandomStates, branch_random_state from nanotron.parallel.pipeline_parallel.tensor_pointer import TensorPointer from nanotron.parallel.tensor_parallel.enum import TensorParallelLinearMode from nanotron.parallel.tensor_parallel.nn import TensorParallelEmbedding from nanotron.parallel.tied_parameters import tie_parameters -# NOTES: -# - tie head_weight with embeddings I think. - -# TODO: -# - class GPT3Config: config lol -# - check that attention (i.e. nanotron.attn vs xglm.self_attn) is the same. -# - from starcoder import Embedding -# - class PositionEmbedding: my sinusoidal embedding extends from TensorParallelEmbedding -# - class GPTBlock: very similar to starcoder2 but make it so it support non-GQA or MQA -# - from starcoder import Loss - @contextmanager def replace_coreattention(gpt3config: GPT3Config): @@ -130,7 +119,6 @@ def __init__( with replace_coreattention(config): super().__init__(config.as_starcoder2(), parallel_config, tp_pg, layer_idx) self.maybe_rotary = lambda q, k, **_: (q, k) # Overwrite possible rotary with identity. - #self.attention = CoreAttention(config, parallel_config=parallel_config, layer_idx=layer_idx) # Use our custom CoreAttention. class MLP(Starcoder2MLP): @@ -204,7 +192,6 @@ def forward( hidden_states = dropout_add_fused_train(hidden_states, residual=residual, prob=self.attn_dropout) else: # No need for random state context manager - # TODO: add dropout scaling? hidden_states = hidden_states + residual residual = hidden_states @@ -218,7 +205,6 @@ def forward( hidden_states = dropout_add_fused_train(hidden_states, residual=residual, prob=self.ff_dropout) else: # No need for random state context manager - # TODO: add dropout scaling? 
hidden_states = hidden_states + residual return { @@ -235,7 +221,7 @@ def __init__(self, tp_pg: dist.ProcessGroup, config: GPT3Config, parallel_config if (config.max_position_embeddings + config.position_embedding_offset) % tp_pg.size() == 0: dummy_pos = 0 else: - dummy_pos = tp_pg.size() - ((config.max_position_embeddings + config.position_embedding_offset) % k) + dummy_pos = tp_pg.size() - ((config.max_position_embeddings + config.position_embedding_offset) % tp_pg.size()) true_max_size = config.max_position_embeddings + config.position_embedding_offset + dummy_pos if config.sinusoidal_position_embedding: @@ -278,7 +264,7 @@ def _make_weights(self, tp_pg: dist.ProcessGroup, num_embeddings: int, half_dim = embedding_dim//2 emb = math.log(10_000)/(half_dim - 1) emb = torch.exp(torch.arange(half_dim, dtype=torch.int64).float() * -emb) - emb = (rank*block_size + torch.arange(num_embeddings, dtype=torch.int64).float().unsqueeze(1)) * emb.unsqueeze(0) + emb = (rank*block_size + torch.arange(block_size, dtype=torch.int64).float().unsqueeze(1)) * emb.unsqueeze(0) emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1).view(block_size, embedding_dim) return emb @@ -315,6 +301,7 @@ def forward( # all tensors are optional as most ranks don't need anything from the dataloader. input_embeds = self.token_embeddings(input_ids=input_ids, input_mask=input_mask)["input_embeds"]*self.embed_scale + # TODO: position_ids could be cached. position_ids = torch.arange(input_ids.size(1), device="cuda").repeat(input_ids.size(0)).view(*input_ids.size()) position_embeds = self.position_embeddings(position_ids=position_ids)["position_embeds"] hidden_states = input_embeds + position_embeds @@ -339,8 +326,6 @@ def forward( # TODO: maybe reimplement: -# - tie_custom_params -# - get_embeddings_lm_head_tied_names # - get_block_compute_costs # - get_flops_per_sec class GPT3ForTraining(Starcoder2ForTraining): diff --git a/src/nanotron/optim/optimizer_from_gradient_accumulator.py b/src/nanotron/optim/optimizer_from_gradient_accumulator.py index 01be7cb5..9883c720 100644 --- a/src/nanotron/optim/optimizer_from_gradient_accumulator.py +++ b/src/nanotron/optim/optimizer_from_gradient_accumulator.py @@ -38,7 +38,8 @@ def __init__( **{k: v for k, v in named_param_group.items() if k != "named_params"}, "named_params": [ (name, gradient_accumulator.get_parameter_for_optimizer(name)) - for name, _ in named_param_group["named_params"] + for name, param in named_param_group["named_params"] + if param.requires_grad ], } for named_param_group in named_param_groups diff --git a/src/nanotron/trainer.py b/src/nanotron/trainer.py index 21251a32..f8022c52 100644 --- a/src/nanotron/trainer.py +++ b/src/nanotron/trainer.py @@ -58,6 +58,7 @@ from nanotron.models.base import check_model_has_grad from nanotron.models.llama import LlamaForTraining, RotaryEmbedding from nanotron.models.starcoder2 import Starcoder2ForTraining +from nanotron.models.gpt3 import GPT3ForTraining from nanotron.optim.clip_grads import clip_grad_norm from nanotron.parallel import ParallelContext from nanotron.parallel.data_parallel.utils import sync_gradients_across_dp @@ -103,6 +104,7 @@ CONFIG_TO_MODEL_CLASS = { "LlamaConfig": LlamaForTraining, "Starcoder2Config": Starcoder2ForTraining, + "GPT3Config": GPT3ForTraining, } try: From abdf9c7733ec0010261a2af0acb35f69bb3bfc6b Mon Sep 17 00:00:00 2001 From: AleHD Date: Thu, 11 Jul 2024 13:38:52 +0200 Subject: [PATCH 31/44] Added nt2hf conversion + tests :) --- examples/xglm/README.md | 5 + examples/xglm/convert_hf2nt.py | 38 
+---- examples/xglm/convert_nt2hf.py | 126 +++++++++++++++ examples/xglm/convert_utils.py | 59 +++++++ examples/xglm/tests/test_implementation.py | 177 +++++++++++++++++---- src/nanotron/config/models_config.py | 4 + src/nanotron/models/gpt3.py | 2 +- 7 files changed, 347 insertions(+), 64 deletions(-) create mode 100644 examples/xglm/convert_nt2hf.py create mode 100644 examples/xglm/convert_utils.py diff --git a/examples/xglm/README.md b/examples/xglm/README.md index abc50f95..22765f52 100644 --- a/examples/xglm/README.md +++ b/examples/xglm/README.md @@ -11,3 +11,8 @@ torchrun --nproc-per-node=4 run_train.py --config-file=examples/xglm/example_config.yaml ``` If you use this configuration file make sure to modify at least the loading path in `model.init_method.path`. + +1. If you want to convert your finetuned checkpoint back to huggingface use: + ``` + torchrun --nproc-per-node=1 examples/xglm/convert_nt2hf.py --checkpoint-path=checpoints/xglm --save-path=$SCRATCH/checkpoints/huggingface/xglm-564M --tokenizer-name=facebook/xglm-564M + ``` diff --git a/examples/xglm/convert_hf2nt.py b/examples/xglm/convert_hf2nt.py index 9db5ed93..0efcceca 100644 --- a/examples/xglm/convert_hf2nt.py +++ b/examples/xglm/convert_hf2nt.py @@ -18,11 +18,11 @@ from nanotron.models.gpt3 import CausalSelfAttention, GPTBlock, MLP, GPT3ForTraining from nanotron.config.models_config import GPT3Config from nanotron.trainer import mark_tied_parameters - +from examples.xglm.convert_utils import convert_generic, create_nt_model def convert_config(config: XGLMConfig) -> GPT3Config: - # TODOs: + # These settings seem to be unused: # layerdrop=0.0, # init_std=0.02, # use_cache=True, @@ -80,15 +80,6 @@ def convert_attention(attn_nt: CausalSelfAttention, attn_hf: XGLMAttention): attn_nt.dense.bias.data = attn_hf.out_proj.bias.clone() -def convert_generic(module1: nn.Module, module2: nn.Module): - names1 = {name for name, _ in module1.named_parameters()} - names2 = {name for name, _ in module2.named_parameters()} - assert names1 == names2, f"{names1} != {names2}" - params2 = dict(module2.named_parameters()) - for name, param in module1.named_parameters(): - param.data = params2[name].clone() - - def convert_mlp(mlp_nt: MLP, block_hf: XGLMDecoderLayer): convert_generic(mlp_nt.c_fc, block_hf.fc1) convert_generic(mlp_nt.c_proj, block_hf.fc2) @@ -109,31 +100,6 @@ def convert(model_nt: GPT3ForTraining, model_hf: XGLMForCausalLM): convert_generic(model_nt.model.lm_head.pp_block, model_hf.lm_head) -def create_nt_model(model_config: GPT3Config, device: torch.device = torch.device("cuda"), - dtype: torch.dtype = torch.bfloat16) -> GPT3ForTraining: - - parallel_config = nanotron.config.ParallelismArgs(dp=1, pp=1, tp=1) - parallel_context = nanotron.parallel.ParallelContext( - data_parallel_size=parallel_config.dp, - pipeline_parallel_size=parallel_config.pp, - tensor_parallel_size=parallel_config.tp, - ) - #random_states = nanotron.random.RandomStates({"tp_synced": nanotron.random.get_current_random_state()}) - model_nt = nanotron.models.build_model( - model_builder=lambda: GPT3ForTraining( - config=model_config, - parallel_context=parallel_context, - parallel_config=parallel_config, - random_states=None, - ), - parallel_context=parallel_context, - dtype=dtype, - device=device, - ) - mark_tied_parameters(model=model_nt, parallel_context=parallel_context) - return model_nt - - def main(hf_path: str, save_path: Path): # Load hf. 
print("Loading hf...") diff --git a/examples/xglm/convert_nt2hf.py b/examples/xglm/convert_nt2hf.py new file mode 100644 index 00000000..422695a1 --- /dev/null +++ b/examples/xglm/convert_nt2hf.py @@ -0,0 +1,126 @@ +""" +Converts a nanotron model to HF format +Command: + torchrun --nproc-per-node=1 convert_nt2hf.py --checkpoint-path=nanotron_weights --save-path=hf_weights +""" + +from argparse import ArgumentParser +from typing import Optional +from pathlib import Path + +import torch +from transformers import AutoTokenizer +from transformers.models.xglm.modeling_xglm import XGLMAttention, XGLMConfig, XGLMDecoderLayer, XGLMForCausalLM + +from nanotron.config.models_config import GPT3Config +from nanotron.models.gpt3 import CausalSelfAttention, GPTBlock, MLP, GPT3ForTraining +from examples.xglm.convert_utils import convert_generic, create_nt_model + + +def convert_config(config: GPT3Config) -> XGLMConfig: + if config.embd_pdrop != config.resid_pdrop: + warnings.warn(f"nanotron.embd_pdrop = {config.embd_pdrop} does not match with " + f"nanotron.resid_pdrop = {config.resid_pdrop}. " + "XGLM implementation needs these two values to be equal " + "for correct conversion.") + if config.layer_norm_epsilon != 1e-5: + warnings.warn(f"nanotron.layer_norm_epsilon must be 1e-5, not {config.layer_norm_epsilon}") + return XGLMConfig( + activation_function=config.activation_function, + attention_dropout=config.attn_pdrop, + dropout=config.embd_pdrop, + eos_token_id=config.eos_token_id, + d_model=config.hidden_size, + ffn_dim=config.intermediate_size, + max_position_embeddings=config.max_position_embeddings, + attention_heads=config.num_attention_heads, + num_layers=config.num_hidden_layers, + vocab_size=config.vocab_size, + decoder_start_token_id=config.position_embedding_offset, + activation_dropout=config.act_pdrop, + scale_embedding=config.scale_embedding, + ) + + +def convert_attention(attn_hf: XGLMAttention, attn_nt: XGLMAttention): + qs_w = [] + ks_w = [] + vs_w = [] + qs_b = [] + ks_b = [] + vs_b = [] + + head_dim = attn_hf.head_dim + qkv_ws = list(attn_nt.query_key_value.weight.split(head_dim)) + qkv_bs = list(attn_nt.query_key_value.bias.split(head_dim)) + for i, (w, b) in enumerate(zip(qkv_ws, qkv_bs)): + if i % 3 == 0: + qs_w.append(w) + qs_b.append(b) + elif i % 3 == 1: + ks_w.append(w) + ks_b.append(b) + else: + vs_w.append(w) + vs_b.append(b) + + q_w = torch.cat(qs_w) + k_w = torch.cat(ks_w) + v_w = torch.cat(vs_w) + q_b = torch.cat(qs_b) + k_b = torch.cat(ks_b) + v_b = torch.cat(vs_b) + + with torch.no_grad(): + attn_hf.q_proj.weight.data = q_w.clone() + attn_hf.k_proj.weight.data = k_w.clone() + attn_hf.v_proj.weight.data = v_w.clone() + attn_hf.q_proj.bias.data = q_b.clone() + attn_hf.k_proj.bias.data = k_b.clone() + attn_hf.v_proj.bias.data = v_b.clone() + + attn_hf.out_proj.weight.data = attn_nt.dense.weight.data.clone() + attn_hf.out_proj.bias.data = attn_nt.dense.bias.data.clone() + + +def convert_decoder(block_hf: XGLMDecoderLayer, block_nt: GPTBlock): + convert_generic(block_hf.self_attn_layer_norm, block_nt.ln_1) + convert_attention(block_hf.self_attn, block_nt.attn) + convert_generic(block_hf.final_layer_norm, block_nt.ln_2) + convert_generic(block_hf.fc1, block_nt.ff.c_fc) + convert_generic(block_hf.fc2, block_nt.ff.c_proj) + + +def convert(model_hf: XGLMForCausalLM, model_nt: GPT3ForTraining): + convert_generic(model_hf.model.embed_tokens, model_nt.model.token_embeddings.pp_block.token_embedding) + for layer_hf, layer_nt in zip(model_hf.model.layers, model_nt.model.decoder): + 
convert_decoder(layer_hf, layer_nt.pp_block) + convert_generic(model_hf.model.layer_norm, model_nt.model.final_layer_norm.pp_block) + convert_generic(model_hf.lm_head, model_nt.model.lm_head.pp_block) + + +def main(checkpoint_path: Path, save_path: Path, tokenizer_name: Optional[str]): + # Load nanotron model. + model_nt = create_nt_model(checkpoint_path=checkpoint_path) + + # Init huggingface model. + model_config_hf = convert_config(model_nt.config) + model_hf = XGLMForCausalLM._from_config(model_config_hf) + + # Copy weights, initialize tokenizer and save model. + if tokenizer_name is not None: + tokenizer = AutoTokenizer.from_pretrained(tokenizer_name) + tokenizer.save_pretrained(save_path) + convert(model_hf, model_nt) + model_hf.save_pretrained(save_path) + print(f"Model saved to {save_path}") + + +if __name__ == "__main__": + parser = ArgumentParser(description="Convert HF weights to nanotron format") + parser.add_argument("--checkpoint-path", type=Path, default="checkpoints/xglm-7.5B", help="Path to the nanotron checkpoint") + parser.add_argument("--save-path", type=Path, default="facebook/xglm-7.5B", help="Path to save the huggingface model") + parser.add_argument("--tokenizer-name", type=str, default="facebook/xglm-7.5B") + args = parser.parse_args() + main(args.checkpoint_path, args.save_path, args.tokenizer_name) + diff --git a/examples/xglm/convert_utils.py b/examples/xglm/convert_utils.py new file mode 100644 index 00000000..88a731a1 --- /dev/null +++ b/examples/xglm/convert_utils.py @@ -0,0 +1,59 @@ +import json +from pathlib import Path +from typing import Optional + +import torch +from torch import nn + +import nanotron +from nanotron.models.gpt3 import GPT3ForTraining +from nanotron.config.models_config import GPT3Config +from nanotron.trainer import mark_tied_parameters + + +def convert_generic(module1: nn.Module, module2: nn.Module): + names1 = {name for name, _ in module1.named_parameters()} + names2 = {name for name, _ in module2.named_parameters()} + assert names1 == names2, f"{names1} != {names2}" + params2 = dict(module2.named_parameters()) + for name, param in module1.named_parameters(): + param.data = params2[name].clone() + + +def create_nt_model( + model_config: Optional[GPT3Config] = None, + device: torch.device = torch.device("cuda"), + dtype: torch.dtype = torch.bfloat16, + checkpoint_path: Optional[Path] = None + ): + + if model_config is None: + assert checkpoint_path is not None + with open(checkpoint_path / "model_config.json") as f: + model_config = GPT3Config(**json.load(f)) + + parallel_config = nanotron.config.ParallelismArgs(dp=1, pp=1, tp=1) + parallel_context = nanotron.parallel.ParallelContext( + data_parallel_size=parallel_config.dp, + pipeline_parallel_size=parallel_config.pp, + tensor_parallel_size=parallel_config.tp, + ) + model_nt = nanotron.models.build_model( + model_builder=lambda: GPT3ForTraining( + config=model_config, + parallel_context=parallel_context, + parallel_config=parallel_config, + random_states=None, + ), + parallel_context=parallel_context, + dtype=dtype, + device=device, + ) + mark_tied_parameters(model=model_nt, parallel_context=parallel_context) + + if checkpoint_path is not None: + nanotron.serialize.load_weights( + model=model_nt, parallel_context=parallel_context, root_folder=checkpoint_path + ) + + return model_nt diff --git a/examples/xglm/tests/test_implementation.py b/examples/xglm/tests/test_implementation.py index 3636415b..d9dc0f85 100644 --- a/examples/xglm/tests/test_implementation.py +++ 
b/examples/xglm/tests/test_implementation.py @@ -8,6 +8,7 @@ from transformers.models.xglm.modeling_xglm import XGLMConfig, XGLMAttention, XGLMSinusoidalPositionalEmbedding, XGLMDecoderLayer, XGLMForCausalLM import nanotron +from nanotron.trainer import mark_tied_parameters from nanotron.config.models_config import GPT3Config from nanotron.models.gpt3 import GPT3ForTraining, CausalSelfAttention, PositionEmbedding, GPTBlock from nanotron.parallel import ParallelContext @@ -15,12 +16,17 @@ from tests.helpers.utils import init_distributed from examples.xglm.convert_hf2nt import convert_attention, convert_config, convert_decoder, convert +from examples.xglm.convert_nt2hf import convert_attention as convert_attention_nt2hf +from examples.xglm.convert_nt2hf import convert_config as convert_config_nt2hf +from examples.xglm.convert_nt2hf import convert_decoder as convert_decoder_nt2hf +from examples.xglm.convert_nt2hf import convert as convert_nt2hf -SEQUENCE_LENGTH = 2048 +MAX_SEQUENCE_LENGTH = 2048 +TEST_SEQUENCE_LENGTH = 128 # If we test with a very large sequence length, precision errors get more significant independent of the correct implementation. BATCH_SIZE = 4 HIDDEN_SIZE = 1024 -DTYPE = torch.bfloat16 +DTYPE = torch.float64 TEXT = "Hello. This is a relatively long text. I will use this text to test the conversion scripts. Let's finish this text soon because I don't have much more to say. Final note:" CONFIG = GPT3Config( @@ -32,7 +38,7 @@ hidden_size=HIDDEN_SIZE, intermediate_size=4096, layer_norm_epsilon=1e-05, - max_position_embeddings=SEQUENCE_LENGTH, + max_position_embeddings=MAX_SEQUENCE_LENGTH, num_attention_heads=16, num_hidden_layers=24, scale_attn_weights=True, @@ -45,25 +51,39 @@ @pytest.fixture def hidden_states() -> torch.Tensor: - return torch.randn(SEQUENCE_LENGTH, BATCH_SIZE, HIDDEN_SIZE, + return torch.randn(TEST_SEQUENCE_LENGTH, BATCH_SIZE, HIDDEN_SIZE, dtype=DTYPE) @pytest.fixture def input_mask() -> torch.Tensor: - return torch.ones(BATCH_SIZE, SEQUENCE_LENGTH, dtype=torch.bool) + return torch.ones(BATCH_SIZE, TEST_SEQUENCE_LENGTH, dtype=torch.bool) @pytest.fixture def input_ids() -> torch.Tensor: - return torch.randint(0, CONFIG.vocab_size, (BATCH_SIZE, SEQUENCE_LENGTH)) + return torch.randint(0, CONFIG.vocab_size, (BATCH_SIZE, TEST_SEQUENCE_LENGTH)) + + +def almost_close(t1: torch.Tensor, t2: torch.Tensor, atol: float = 1e-5, rtol: float = 0.016, + max_far: float = 0.0, far_atol: float = 0.01): + very_close = torch.abs(t1 - t2) <= atol + rtol*torch.abs(t2) + not_very_close = ~very_close + + if torch.all(very_close): + return + assert torch.mean(not_very_close.float()) <= max_far, f"not very close found: {100*torch.mean(not_very_close.float()):.1f}%" + assert torch.all(torch.abs(t1[not_very_close] - t2[not_very_close]) <= far_atol), f"Worse deviation found: {torch.max(torch.abs(t1 - t2)):.4f}" def attention_mask() -> torch.Tensor: # XGLM causal attention mask. 
- mask = torch.ones(SEQUENCE_LENGTH, SEQUENCE_LENGTH, dtype=torch.bool, device="cuda").tril(diagonal=0) + mask = torch.ones(TEST_SEQUENCE_LENGTH, TEST_SEQUENCE_LENGTH, dtype=torch.bool, device="cuda").tril(diagonal=0) mask = torch.where(mask, 0.0, -np.inf).to(DTYPE) mask = mask.repeat(BATCH_SIZE, 1, 1).unsqueeze(1) return mask +## +# FROM HERE DOWN (until next comment), all tests are hf->nt +## def _test_attention(parallel_context: ParallelContext, hidden_states: torch.Tensor, sequence_mask: torch.Tensor): hidden_states = hidden_states.cuda() @@ -85,10 +105,10 @@ def test_attention(hidden_states: torch.Tensor, input_mask: torch.Tensor): def _test_position_embeddings(parallel_context: ParallelContext): - position_ids = torch.arange(SEQUENCE_LENGTH, device="cuda").unsqueeze(0) # shape = (1, SEQUENCE_LENGTH) + position_ids = torch.arange(TEST_SEQUENCE_LENGTH, device="cuda").unsqueeze(0) # shape = (1, TEST_SEQUENCE_LENGTH) emb_nt = PositionEmbedding(parallel_context.tp_pg, CONFIG, None).cuda() - emb_hf = XGLMSinusoidalPositionalEmbedding(SEQUENCE_LENGTH, HIDDEN_SIZE).cuda() + emb_hf = XGLMSinusoidalPositionalEmbedding(MAX_SEQUENCE_LENGTH, HIDDEN_SIZE).cuda() assert emb_nt.position_embedding.weight.size() == emb_hf.weights.size() torch.testing.assert_close(emb_nt.position_embedding.weight, emb_hf.weights) @@ -120,7 +140,7 @@ def _test_decoder(parallel_context: ParallelContext, hidden_states: torch.Tensor out_hf = decoder_hf(hidden_states.permute(1, 0, 2), attention_mask=attention_mask())[0].permute(1, 0, 2) assert out_nt.size() == out_hf.size(), f"{out_nt.size()}, {out_hf.size()}" - torch.testing.assert_close(out_nt, out_hf) + torch.testing.assert_close(out_nt.bfloat16(), out_hf.bfloat16()) # We cast to bf16 to get more relaxed constraints. def test_decoder(hidden_states: torch.Tensor, input_mask: torch.Tensor): @@ -129,21 +149,25 @@ def test_decoder(hidden_states: torch.Tensor, input_mask: torch.Tensor): def _test_model(model_hf: Optional[XGLMForCausalLM], parallel_context: ParallelContext, input_ids: torch.Tensor, input_mask: torch.Tensor): + random_states = nanotron.random.RandomStates({"tp_synced": nanotron.random.get_current_random_state()}) input_ids = input_ids.cuda() input_mask = input_mask.cuda() + # unfortunately, we can't use float64 with huggingface xglm. + new_dtype = torch.float32 if DTYPE == torch.float64 else DTYPE + # Get hf model. if model_hf is None: config_hf = XGLMConfig() - model_hf = XGLMForCausalLM(config_hf).cuda().to(DTYPE).eval() + model_hf = XGLMForCausalLM(config_hf).cuda().to(new_dtype).eval() else: - model_hf = model_hf.cuda().to(DTYPE).eval() + model_hf = model_hf.cuda().to(new_dtype).eval() config_hf = model_hf.config # Get nanotron model and make the conversion. config_nt = convert_config(config_hf) - if DTYPE not in {torch.bfloat16, torch.float16}: + if new_dtype not in {torch.bfloat16, torch.float16}: config_nt.use_spda = True model_nt = nanotron.models.build_model( model_builder=lambda: GPT3ForTraining( @@ -153,7 +177,7 @@ def _test_model(model_hf: Optional[XGLMForCausalLM], parallel_context: ParallelC random_states=random_states, ), parallel_context=parallel_context, - dtype=DTYPE, + dtype=new_dtype, device="cuda", ).eval() convert(model_nt, model_hf) @@ -162,42 +186,141 @@ def _test_model(model_hf: Optional[XGLMForCausalLM], parallel_context: ParallelC # Get outputs and assert. 
with torch.no_grad(): - out_nt = model_nt.model(input_ids, input_mask).to(DTYPE) + out_nt = model_nt.model(input_ids, input_mask).to(new_dtype) del model_nt torch.cuda.empty_cache() out_hf = model_hf(input_ids=input_ids, attention_mask=input_mask).logits.permute(1, 0, 2) del model_hf torch.cuda.empty_cache() assert out_nt.size() == out_hf.size(), f"{out_nt.size()}, {out_hf.size()}" - torch.testing.assert_close(out_nt.cpu(), out_hf.cpu()) + return out_nt.cpu(), out_hf.cpu() + def _test_dummy_xglm(parallel_context: ParallelContext, input_ids: torch.Tensor, input_mask: torch.Tensor): - _test_model(None, parallel_context, input_ids, input_mask) + out_nt, out_hf = _test_model(None, parallel_context, input_ids, input_mask) + almost_close(out_nt, out_hf, max_far=0.05) def test_dummy_xglm(input_ids: torch.Tensor, input_mask: torch.Tensor): init_distributed(tp=1, dp=1, pp=1)(_test_dummy_xglm)(input_ids=input_ids, input_mask=input_mask) +def _test_xglm500M(parallel_context: ParallelContext): + tok = XGLMTokenizer.from_pretrained("facebook/xglm-564M") + tokenized = tok(TEXT) + model_hf = XGLMForCausalLM.from_pretrained("facebook/xglm-564M") + out_nt, out_hf = _test_model(model_hf, parallel_context, + torch.tensor([tokenized["input_ids"]]), torch.tensor([tokenized["attention_mask"]])) + almost_close(out_nt, out_hf, max_far=0.1, far_atol=0.05) + + +def test_xglm500M(): + init_distributed(tp=1, dp=1, pp=1)(_test_xglm500M)() + + def _test_xglm7B(parallel_context: ParallelContext): tok = XGLMTokenizer.from_pretrained("facebook/xglm-7.5B") tokenized = tok(TEXT) model_hf = XGLMForCausalLM.from_pretrained("facebook/xglm-7.5B") - _test_model(model_hf, parallel_context, - torch.tensor([tokenized["input_ids"]]), torch.tensor([tokenized["attention_mask"]])) + out_nt, out_hf = _test_model(model_hf, parallel_context, + torch.tensor([tokenized["input_ids"]]), torch.tensor([tokenized["attention_mask"]])) + almost_close(out_nt, out_hf, max_far=0.15, far_atol=0.1) def test_xglm7B(): init_distributed(tp=1, dp=1, pp=1)(_test_xglm7B)() -def _test_xglm500M(parallel_context: ParallelContext): - tok = XGLMTokenizer.from_pretrained("facebook/xglm-564M") - tokenized = tok(TEXT) - model_hf = XGLMForCausalLM.from_pretrained("facebook/xglm-564M") - _test_model(model_hf, parallel_context, - torch.tensor([tokenized["input_ids"]]), torch.tensor([tokenized["attention_mask"]])) +## +# From here down we test nt->hf converters +## +def _test_nt2hf_attention(parallel_context: ParallelContext, hidden_states: torch.Tensor, sequence_mask: torch.Tensor): + hidden_states = hidden_states.cuda() + sequence_mask = sequence_mask.cuda() -def test_xglm500M(): - init_distributed(tp=1, dp=1, pp=1)(_test_xglm500M)() + attn_nt = CausalSelfAttention(CONFIG, None, parallel_context.tp_pg, 0).cuda().eval().to(DTYPE) + attn_hf = XGLMAttention(CONFIG.hidden_size, CONFIG.num_attention_heads, CONFIG.attn_pdrop).cuda().eval().to(DTYPE) + assert sum(map(torch.numel, attn_nt.parameters())) == sum(map(torch.numel, attn_hf.parameters())) + + convert_attention_nt2hf(attn_hf, attn_nt) + out_nt = attn_nt(hidden_states, sequence_mask)["hidden_states"] + out_hf = attn_hf(hidden_states.permute(1, 0, 2), attention_mask=attention_mask())[0].permute(1, 0, 2) + assert out_nt.size() == out_hf.size(), f"{out_nt.size()}, {out_hf.size()}" + torch.testing.assert_close(out_nt, out_hf) + + +def test_nt2hf_attention(hidden_states: torch.Tensor, input_mask: torch.Tensor): + init_distributed(tp=1, dp=1, pp=1)(_test_nt2hf_attention)(hidden_states=hidden_states, 
sequence_mask=input_mask) + + +def _test_nt2hf_decoder(parallel_context: ParallelContext, hidden_states: torch.Tensor, sequence_mask: torch.Tensor): + random_states = nanotron.random.RandomStates({"tp_synced": nanotron.random.get_current_random_state()}) + hidden_states = hidden_states.cuda() + sequence_mask = sequence_mask.cuda() + + config_hf = convert_config_nt2hf(CONFIG) + decoder_nt = GPTBlock(CONFIG, None, parallel_context.tp_pg, random_states, 0).cuda().to(DTYPE).eval() + decoder_hf = XGLMDecoderLayer(config_hf).cuda().to(DTYPE).eval() + + convert_decoder_nt2hf(decoder_hf, decoder_nt) + + out_nt = decoder_nt(hidden_states, sequence_mask)["hidden_states"] + out_hf = decoder_hf(hidden_states.permute(1, 0, 2), attention_mask=attention_mask())[0].permute(1, 0, 2) + + assert out_nt.size() == out_hf.size(), f"{out_nt.size()}, {out_hf.size()}" + torch.testing.assert_close(out_nt.bfloat16(), out_hf.bfloat16()) # We cast to bf16 to get more relaxed constraints. + + +def test_nt2hf_decoder(hidden_states: torch.Tensor, input_mask: torch.Tensor): + init_distributed(tp=1, dp=1, pp=1)(_test_nt2hf_decoder)(hidden_states=hidden_states, sequence_mask=input_mask) + + +def _test_nt2hf_model(parallel_context: ParallelContext, input_ids: torch.Tensor, input_mask: torch.Tensor): + random_states = nanotron.random.RandomStates({"tp_synced": nanotron.random.get_current_random_state()}) + input_ids = input_ids.cuda() + input_mask = input_mask.cuda() + + # unfortunately, we can't use float64 with huggingface xglm. + new_dtype = torch.float32 if DTYPE == torch.float64 else DTYPE + + # Get nanotron model. + config_nt = GPT3Config(**vars(CONFIG)) + if new_dtype not in {torch.bfloat16, torch.float16}: + config_nt.use_spda = True + model_nt = nanotron.models.build_model( + model_builder=lambda: GPT3ForTraining( + config=config_nt, + parallel_context=parallel_context, + parallel_config=None, + random_states=random_states, + ), + parallel_context=parallel_context, + dtype=new_dtype, + device="cuda", + ).eval() + mark_tied_parameters(model=model_nt, parallel_context=parallel_context) + + # Create empty model_hf and make conversion. + model_hf = XGLMForCausalLM(convert_config_nt2hf(config_nt)).cuda().to(new_dtype).eval() + convert_nt2hf(model_hf, model_nt) + + # Get outputs and assert. 
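+    # nanotron's model returns logits shaped (seq_len, batch, vocab), so the HF logits below are
+    # permuted from (batch, seq_len, vocab) to the same layout before the shapes are compared.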
+ with torch.no_grad(): + out_nt = model_nt.model(input_ids, input_mask).to(new_dtype) + del model_nt + torch.cuda.empty_cache() + out_hf = model_hf(input_ids=input_ids, attention_mask=input_mask).logits.permute(1, 0, 2) + del model_hf + torch.cuda.empty_cache() + assert out_nt.size() == out_hf.size(), f"{out_nt.size()}, {out_hf.size()}" + return out_nt.cpu(), out_hf.cpu() + + +def _test_nt2hf_dummy_xglm(parallel_context: ParallelContext, input_ids: torch.Tensor, input_mask: torch.Tensor): + out_nt, out_hf = _test_nt2hf_model(parallel_context, input_ids, input_mask) + almost_close(out_nt, out_hf, max_far=0.01, far_atol=0.02) + + +def test_nt2hf_dummy_xglm(input_ids: torch.Tensor, input_mask: torch.Tensor): + init_distributed(tp=1, dp=1, pp=1)(_test_nt2hf_dummy_xglm)(input_ids=input_ids, input_mask=input_mask) diff --git a/src/nanotron/config/models_config.py b/src/nanotron/config/models_config.py index 20a92126..5f59a439 100644 --- a/src/nanotron/config/models_config.py +++ b/src/nanotron/config/models_config.py @@ -176,4 +176,8 @@ def as_starcoder2(self) -> Starcoder2Config: **config ) + @property + def n_inner(self): + return self.intermediate_size + NanotronConfigs = LlamaConfig | Starcoder2Config | GPT3Config diff --git a/src/nanotron/models/gpt3.py b/src/nanotron/models/gpt3.py index 33661c8b..7d4e6f82 100644 --- a/src/nanotron/models/gpt3.py +++ b/src/nanotron/models/gpt3.py @@ -338,4 +338,4 @@ def __init__( ): with replace_gpt3model(config): super().__init__(config.as_starcoder2(), parallel_context, parallel_config, random_states) - + self.config = config From 819fdd54392b87083194ee369e24f33db2b4e6de Mon Sep 17 00:00:00 2001 From: AleHD Date: Thu, 11 Jul 2024 14:32:44 +0200 Subject: [PATCH 32/44] precommit --- examples/xglm/convert_hf2nt.py | 33 ++++---- examples/xglm/convert_nt2hf.py | 28 ++++--- examples/xglm/convert_utils.py | 21 +++-- examples/xglm/tests/test_implementation.py | 89 ++++++++++++++-------- src/nanotron/config/models_config.py | 8 +- src/nanotron/models/gpt3.py | 85 ++++++++++++--------- src/nanotron/trainer.py | 2 +- 7 files changed, 154 insertions(+), 112 deletions(-) diff --git a/examples/xglm/convert_hf2nt.py b/examples/xglm/convert_hf2nt.py index 0efcceca..c18a1ab8 100644 --- a/examples/xglm/convert_hf2nt.py +++ b/examples/xglm/convert_hf2nt.py @@ -4,20 +4,18 @@ torchrun --nproc-per-node=1 convert_hf2nt.py --checkpoint-path=hf_weights --save-path=nanotron_weights """ +import dataclasses import json import warnings -import dataclasses from argparse import ArgumentParser from pathlib import Path +import nanotron import torch -from torch import nn +from nanotron.config.models_config import GPT3Config +from nanotron.models.gpt3 import MLP, CausalSelfAttention, GPT3ForTraining, GPTBlock from transformers.models.xglm.modeling_xglm import XGLMAttention, XGLMConfig, XGLMDecoderLayer, XGLMForCausalLM -import nanotron -from nanotron.models.gpt3 import CausalSelfAttention, GPTBlock, MLP, GPT3ForTraining -from nanotron.config.models_config import GPT3Config -from nanotron.trainer import mark_tied_parameters from examples.xglm.convert_utils import convert_generic, create_nt_model @@ -29,10 +27,12 @@ def convert_config(config: XGLMConfig) -> GPT3Config: # pad_token_id=1, # bos_token_id=0, if config.dropout != config.attention_dropout: - warnings.warn(f"huggingface.dropout = {config.dropout} does not match with " - f"huggingface.attention_dropout = {config.attention_dropout}. 
" - "Nanotron implementation needs these two values to be equal " - "for correct conversion.") + warnings.warn( + f"huggingface.dropout = {config.dropout} does not match with " + f"huggingface.attention_dropout = {config.attention_dropout}. " + "Nanotron implementation needs these two values to be equal " + "for correct conversion." + ) return GPT3Config( activation_function=config.activation_function, attn_pdrop=config.attention_dropout, @@ -113,16 +113,19 @@ def main(hf_path: str, save_path: Path): # Copy weights and save model. print("Copying weights...") convert(model_nt, model_hf) - nanotron.serialize.save_weights(model=model_nt, parallel_context=model_nt.parallel_context, - root_folder=save_path) - with open(save_path/"model_config.json", "w+") as f: + nanotron.serialize.save_weights(model=model_nt, parallel_context=model_nt.parallel_context, root_folder=save_path) + with open(save_path / "model_config.json", "w+") as f: json.dump(dataclasses.asdict(config_nt), f) print(f"Model saved to {save_path}") if __name__ == "__main__": parser = ArgumentParser(description="Convert HF weights to nanotron format") - parser.add_argument("--checkpoint-path", default="facebook/xglm-7.5B", help="Name or path to the huggingface checkpoint") - parser.add_argument("--save-path", type=Path, default="checkpoints/xglm-7.5B", help="Path to save the nanotron model") + parser.add_argument( + "--checkpoint-path", default="facebook/xglm-7.5B", help="Name or path to the huggingface checkpoint" + ) + parser.add_argument( + "--save-path", type=Path, default="checkpoints/xglm-7.5B", help="Path to save the nanotron model" + ) args = parser.parse_args() main(args.checkpoint_path, args.save_path) diff --git a/examples/xglm/convert_nt2hf.py b/examples/xglm/convert_nt2hf.py index 422695a1..81816aa9 100644 --- a/examples/xglm/convert_nt2hf.py +++ b/examples/xglm/convert_nt2hf.py @@ -4,25 +4,28 @@ torchrun --nproc-per-node=1 convert_nt2hf.py --checkpoint-path=nanotron_weights --save-path=hf_weights """ +import warnings from argparse import ArgumentParser -from typing import Optional from pathlib import Path +from typing import Optional import torch +from nanotron.config.models_config import GPT3Config +from nanotron.models.gpt3 import GPT3ForTraining, GPTBlock from transformers import AutoTokenizer from transformers.models.xglm.modeling_xglm import XGLMAttention, XGLMConfig, XGLMDecoderLayer, XGLMForCausalLM -from nanotron.config.models_config import GPT3Config -from nanotron.models.gpt3 import CausalSelfAttention, GPTBlock, MLP, GPT3ForTraining from examples.xglm.convert_utils import convert_generic, create_nt_model def convert_config(config: GPT3Config) -> XGLMConfig: if config.embd_pdrop != config.resid_pdrop: - warnings.warn(f"nanotron.embd_pdrop = {config.embd_pdrop} does not match with " - f"nanotron.resid_pdrop = {config.resid_pdrop}. " - "XGLM implementation needs these two values to be equal " - "for correct conversion.") + warnings.warn( + f"nanotron.embd_pdrop = {config.embd_pdrop} does not match with " + f"nanotron.resid_pdrop = {config.resid_pdrop}. " + "XGLM implementation needs these two values to be equal " + "for correct conversion." 
+ ) if config.layer_norm_epsilon != 1e-5: warnings.warn(f"nanotron.layer_norm_epsilon must be 1e-5, not {config.layer_norm_epsilon}") return XGLMConfig( @@ -70,7 +73,7 @@ def convert_attention(attn_hf: XGLMAttention, attn_nt: XGLMAttention): q_b = torch.cat(qs_b) k_b = torch.cat(ks_b) v_b = torch.cat(vs_b) - + with torch.no_grad(): attn_hf.q_proj.weight.data = q_w.clone() attn_hf.k_proj.weight.data = k_w.clone() @@ -118,9 +121,12 @@ def main(checkpoint_path: Path, save_path: Path, tokenizer_name: Optional[str]): if __name__ == "__main__": parser = ArgumentParser(description="Convert HF weights to nanotron format") - parser.add_argument("--checkpoint-path", type=Path, default="checkpoints/xglm-7.5B", help="Path to the nanotron checkpoint") - parser.add_argument("--save-path", type=Path, default="facebook/xglm-7.5B", help="Path to save the huggingface model") + parser.add_argument( + "--checkpoint-path", type=Path, default="checkpoints/xglm-7.5B", help="Path to the nanotron checkpoint" + ) + parser.add_argument( + "--save-path", type=Path, default="facebook/xglm-7.5B", help="Path to save the huggingface model" + ) parser.add_argument("--tokenizer-name", type=str, default="facebook/xglm-7.5B") args = parser.parse_args() main(args.checkpoint_path, args.save_path, args.tokenizer_name) - diff --git a/examples/xglm/convert_utils.py b/examples/xglm/convert_utils.py index 88a731a1..75d67782 100644 --- a/examples/xglm/convert_utils.py +++ b/examples/xglm/convert_utils.py @@ -2,13 +2,12 @@ from pathlib import Path from typing import Optional -import torch -from torch import nn - import nanotron -from nanotron.models.gpt3 import GPT3ForTraining +import torch from nanotron.config.models_config import GPT3Config +from nanotron.models.gpt3 import GPT3ForTraining from nanotron.trainer import mark_tied_parameters +from torch import nn def convert_generic(module1: nn.Module, module2: nn.Module): @@ -21,11 +20,11 @@ def convert_generic(module1: nn.Module, module2: nn.Module): def create_nt_model( - model_config: Optional[GPT3Config] = None, - device: torch.device = torch.device("cuda"), - dtype: torch.dtype = torch.bfloat16, - checkpoint_path: Optional[Path] = None - ): + model_config: Optional[GPT3Config] = None, + device: torch.device = torch.device("cuda"), + dtype: torch.dtype = torch.bfloat16, + checkpoint_path: Optional[Path] = None, +): if model_config is None: assert checkpoint_path is not None @@ -52,8 +51,6 @@ def create_nt_model( mark_tied_parameters(model=model_nt, parallel_context=parallel_context) if checkpoint_path is not None: - nanotron.serialize.load_weights( - model=model_nt, parallel_context=parallel_context, root_folder=checkpoint_path - ) + nanotron.serialize.load_weights(model=model_nt, parallel_context=parallel_context, root_folder=checkpoint_path) return model_nt diff --git a/examples/xglm/tests/test_implementation.py b/examples/xglm/tests/test_implementation.py index d9dc0f85..a25d7881 100644 --- a/examples/xglm/tests/test_implementation.py +++ b/examples/xglm/tests/test_implementation.py @@ -1,29 +1,31 @@ from typing import Optional +import nanotron import numpy as np -import torch import pytest - -from transformers import XGLMTokenizer -from transformers.models.xglm.modeling_xglm import XGLMConfig, XGLMAttention, XGLMSinusoidalPositionalEmbedding, XGLMDecoderLayer, XGLMForCausalLM - -import nanotron -from nanotron.trainer import mark_tied_parameters +import torch from nanotron.config.models_config import GPT3Config -from nanotron.models.gpt3 import GPT3ForTraining, 
CausalSelfAttention, PositionEmbedding, GPTBlock +from nanotron.models.gpt3 import CausalSelfAttention, GPT3ForTraining, GPTBlock, PositionEmbedding from nanotron.parallel import ParallelContext +from nanotron.trainer import mark_tied_parameters +from transformers import XGLMTokenizer +from transformers.models.xglm.modeling_xglm import ( + XGLMAttention, + XGLMConfig, + XGLMDecoderLayer, + XGLMForCausalLM, + XGLMSinusoidalPositionalEmbedding, +) -from tests.helpers.utils import init_distributed - -from examples.xglm.convert_hf2nt import convert_attention, convert_config, convert_decoder, convert +from examples.xglm.convert_hf2nt import convert, convert_attention, convert_config, convert_decoder +from examples.xglm.convert_nt2hf import convert as convert_nt2hf from examples.xglm.convert_nt2hf import convert_attention as convert_attention_nt2hf from examples.xglm.convert_nt2hf import convert_config as convert_config_nt2hf from examples.xglm.convert_nt2hf import convert_decoder as convert_decoder_nt2hf -from examples.xglm.convert_nt2hf import convert as convert_nt2hf - +from tests.helpers.utils import init_distributed MAX_SEQUENCE_LENGTH = 2048 -TEST_SEQUENCE_LENGTH = 128 # If we test with a very large sequence length, precision errors get more significant independent of the correct implementation. +TEST_SEQUENCE_LENGTH = 128 # If we test with a very large sequence length, precision errors get more significant independent of the correct implementation. BATCH_SIZE = 4 HIDDEN_SIZE = 1024 DTYPE = torch.float64 @@ -45,33 +47,44 @@ vocab_size=256008, sinusoidal_position_embedding=True, position_embedding_offset=2, - use_spda=True + use_spda=True, ) @pytest.fixture def hidden_states() -> torch.Tensor: - return torch.randn(TEST_SEQUENCE_LENGTH, BATCH_SIZE, HIDDEN_SIZE, - dtype=DTYPE) + return torch.randn(TEST_SEQUENCE_LENGTH, BATCH_SIZE, HIDDEN_SIZE, dtype=DTYPE) + @pytest.fixture def input_mask() -> torch.Tensor: return torch.ones(BATCH_SIZE, TEST_SEQUENCE_LENGTH, dtype=torch.bool) + @pytest.fixture def input_ids() -> torch.Tensor: return torch.randint(0, CONFIG.vocab_size, (BATCH_SIZE, TEST_SEQUENCE_LENGTH)) -def almost_close(t1: torch.Tensor, t2: torch.Tensor, atol: float = 1e-5, rtol: float = 0.016, - max_far: float = 0.0, far_atol: float = 0.01): - very_close = torch.abs(t1 - t2) <= atol + rtol*torch.abs(t2) +def almost_close( + t1: torch.Tensor, + t2: torch.Tensor, + atol: float = 1e-5, + rtol: float = 0.016, + max_far: float = 0.0, + far_atol: float = 0.01, +): + very_close = torch.abs(t1 - t2) <= atol + rtol * torch.abs(t2) not_very_close = ~very_close if torch.all(very_close): return - assert torch.mean(not_very_close.float()) <= max_far, f"not very close found: {100*torch.mean(not_very_close.float()):.1f}%" - assert torch.all(torch.abs(t1[not_very_close] - t2[not_very_close]) <= far_atol), f"Worse deviation found: {torch.max(torch.abs(t1 - t2)):.4f}" + assert ( + torch.mean(not_very_close.float()) <= max_far + ), f"not very close found: {100*torch.mean(not_very_close.float()):.1f}%" + assert torch.all( + torch.abs(t1[not_very_close] - t2[not_very_close]) <= far_atol + ), f"Worse deviation found: {torch.max(torch.abs(t1 - t2)):.4f}" def attention_mask() -> torch.Tensor: @@ -81,10 +94,12 @@ def attention_mask() -> torch.Tensor: mask = mask.repeat(BATCH_SIZE, 1, 1).unsqueeze(1) return mask + ## # FROM HERE DOWN (until next comment), all tests are hf->nt ## + def _test_attention(parallel_context: ParallelContext, hidden_states: torch.Tensor, sequence_mask: torch.Tensor): hidden_states = 
hidden_states.cuda() sequence_mask = sequence_mask.cuda() @@ -118,6 +133,7 @@ def _test_position_embeddings(parallel_context: ParallelContext): assert out_nt.size() == out_hf.size(), f"{out_nt.size()}, {out_hf.size()}" torch.testing.assert_close(out_nt, out_hf) + def test_position_embeddings(): init_distributed(tp=1, dp=1, pp=1)(_test_position_embeddings)() @@ -140,15 +156,21 @@ def _test_decoder(parallel_context: ParallelContext, hidden_states: torch.Tensor out_hf = decoder_hf(hidden_states.permute(1, 0, 2), attention_mask=attention_mask())[0].permute(1, 0, 2) assert out_nt.size() == out_hf.size(), f"{out_nt.size()}, {out_hf.size()}" - torch.testing.assert_close(out_nt.bfloat16(), out_hf.bfloat16()) # We cast to bf16 to get more relaxed constraints. + torch.testing.assert_close( + out_nt.bfloat16(), out_hf.bfloat16() + ) # We cast to bf16 to get more relaxed constraints. def test_decoder(hidden_states: torch.Tensor, input_mask: torch.Tensor): init_distributed(tp=1, dp=1, pp=1)(_test_decoder)(hidden_states=hidden_states, sequence_mask=input_mask) -def _test_model(model_hf: Optional[XGLMForCausalLM], parallel_context: ParallelContext, - input_ids: torch.Tensor, input_mask: torch.Tensor): +def _test_model( + model_hf: Optional[XGLMForCausalLM], + parallel_context: ParallelContext, + input_ids: torch.Tensor, + input_mask: torch.Tensor, +): random_states = nanotron.random.RandomStates({"tp_synced": nanotron.random.get_current_random_state()}) input_ids = input_ids.cuda() @@ -182,7 +204,7 @@ def _test_model(model_hf: Optional[XGLMForCausalLM], parallel_context: ParallelC ).eval() convert(model_nt, model_hf) - print("Parameter count (M):", sum(map(torch.numel, model_hf.parameters()))/1000/1000) + print("Parameter count (M):", sum(map(torch.numel, model_hf.parameters())) / 1000 / 1000) # Get outputs and assert. with torch.no_grad(): @@ -209,8 +231,9 @@ def _test_xglm500M(parallel_context: ParallelContext): tok = XGLMTokenizer.from_pretrained("facebook/xglm-564M") tokenized = tok(TEXT) model_hf = XGLMForCausalLM.from_pretrained("facebook/xglm-564M") - out_nt, out_hf = _test_model(model_hf, parallel_context, - torch.tensor([tokenized["input_ids"]]), torch.tensor([tokenized["attention_mask"]])) + out_nt, out_hf = _test_model( + model_hf, parallel_context, torch.tensor([tokenized["input_ids"]]), torch.tensor([tokenized["attention_mask"]]) + ) almost_close(out_nt, out_hf, max_far=0.1, far_atol=0.05) @@ -222,8 +245,9 @@ def _test_xglm7B(parallel_context: ParallelContext): tok = XGLMTokenizer.from_pretrained("facebook/xglm-7.5B") tokenized = tok(TEXT) model_hf = XGLMForCausalLM.from_pretrained("facebook/xglm-7.5B") - out_nt, out_hf = _test_model(model_hf, parallel_context, - torch.tensor([tokenized["input_ids"]]), torch.tensor([tokenized["attention_mask"]])) + out_nt, out_hf = _test_model( + model_hf, parallel_context, torch.tensor([tokenized["input_ids"]]), torch.tensor([tokenized["attention_mask"]]) + ) almost_close(out_nt, out_hf, max_far=0.15, far_atol=0.1) @@ -235,6 +259,7 @@ def test_xglm7B(): # From here down we test nt->hf converters ## + def _test_nt2hf_attention(parallel_context: ParallelContext, hidden_states: torch.Tensor, sequence_mask: torch.Tensor): hidden_states = hidden_states.cuda() sequence_mask = sequence_mask.cuda() @@ -269,7 +294,9 @@ def _test_nt2hf_decoder(parallel_context: ParallelContext, hidden_states: torch. 
out_hf = decoder_hf(hidden_states.permute(1, 0, 2), attention_mask=attention_mask())[0].permute(1, 0, 2) assert out_nt.size() == out_hf.size(), f"{out_nt.size()}, {out_hf.size()}" - torch.testing.assert_close(out_nt.bfloat16(), out_hf.bfloat16()) # We cast to bf16 to get more relaxed constraints. + torch.testing.assert_close( + out_nt.bfloat16(), out_hf.bfloat16() + ) # We cast to bf16 to get more relaxed constraints. def test_nt2hf_decoder(hidden_states: torch.Tensor, input_mask: torch.Tensor): diff --git a/src/nanotron/config/models_config.py b/src/nanotron/config/models_config.py index 5f59a439..af7db5cc 100644 --- a/src/nanotron/config/models_config.py +++ b/src/nanotron/config/models_config.py @@ -1,6 +1,6 @@ from dataclasses import dataclass, field from pathlib import Path -from typing import Any, List, Optional, Union +from typing import List, Optional @dataclass @@ -170,14 +170,12 @@ def as_starcoder2(self) -> Starcoder2Config: if "_is_using_mup" in config: del config["_is_using_mup"] return Starcoder2Config( - grouped_query=True, - num_kv_heads=self.num_attention_heads, - use_rotary_embeddings=False, - **config + grouped_query=True, num_kv_heads=self.num_attention_heads, use_rotary_embeddings=False, **config ) @property def n_inner(self): return self.intermediate_size + NanotronConfigs = LlamaConfig | Starcoder2Config | GPT3Config diff --git a/src/nanotron/models/gpt3.py b/src/nanotron/models/gpt3.py index 7d4e6f82..25e5f78b 100644 --- a/src/nanotron/models/gpt3.py +++ b/src/nanotron/models/gpt3.py @@ -1,37 +1,40 @@ """PyTorch GPT-3 model.""" import math -from typing import Optional from contextlib import contextmanager +from typing import Optional import torch from torch import nn from torch.nn import functional as F from nanotron import distributed as dist -from nanotron.parallel import ParallelContext -from nanotron.config import Config, GPT3Config, ParallelismArgs, Starcoder2Config +from nanotron.config import GPT3Config, ParallelismArgs, Starcoder2Config from nanotron.generation.generate_store import AttachableStore from nanotron.models import starcoder2 -from nanotron.nn.layer_norm import TritonLayerNorm from nanotron.models.starcoder2 import MLP as Starcoder2MLP -from nanotron.parallel.pipeline_parallel.block import PipelineBlock +from nanotron.models.starcoder2 import CausalSelfGQA, GPTModel, Starcoder2ForTraining, dropout_add_fused_train from nanotron.models.starcoder2 import CoreAttention as Starcoder2CoreAttention from nanotron.models.starcoder2 import GPTBlock as Starcoder2GPTBlock -from nanotron.models.starcoder2 import CausalSelfGQA, Starcoder2ForTraining, GPTModel, dropout_add_fused_train -from nanotron.random import RandomStates, branch_random_state +from nanotron.nn.layer_norm import TritonLayerNorm +from nanotron.parallel import ParallelContext +from nanotron.parallel.pipeline_parallel.block import PipelineBlock from nanotron.parallel.pipeline_parallel.tensor_pointer import TensorPointer from nanotron.parallel.tensor_parallel.enum import TensorParallelLinearMode from nanotron.parallel.tensor_parallel.nn import TensorParallelEmbedding -from nanotron.parallel.tied_parameters import tie_parameters +from nanotron.random import RandomStates, branch_random_state @contextmanager def replace_coreattention(gpt3config: GPT3Config): orig = starcoder2.CoreAttention try: - def create_core_attention(config: Starcoder2Config, parallel_config: Optional[ParallelismArgs], layer_idx: int): + + def create_core_attention( + config: Starcoder2Config, parallel_config: 
Optional[ParallelismArgs], layer_idx: int + ): return CoreAttention(gpt3config, parallel_config, layer_idx) + starcoder2.CoreAttention = create_core_attention yield finally: @@ -42,6 +45,7 @@ def create_core_attention(config: Starcoder2Config, parallel_config: Optional[Pa def replace_decoder(gpt3config: GPT3Config): orig = starcoder2.PipelineBlock try: + def create_pp_block(module_builder, module_kwargs, **kwargs): if module_builder is Starcoder2GPTBlock: # Starcoder2's GPT module is trying to instantiate a Starcoder2 GPTBlock. @@ -62,9 +66,15 @@ def create_pp_block(module_builder, module_kwargs, **kwargs): def replace_gpt3model(gpt3config: GPT3Config): orig = starcoder2.GPTModel try: - def create_gptmodel(config: Starcoder2Config, parallel_context: ParallelContext, - parallel_config: Optional[ParallelismArgs], random_states: RandomStates): + + def create_gptmodel( + config: Starcoder2Config, + parallel_context: ParallelContext, + parallel_config: Optional[ParallelismArgs], + random_states: RandomStates, + ): return GPT3Model(gpt3config, parallel_context, parallel_config, random_states) + starcoder2.GPTModel = create_gptmodel yield finally: @@ -76,7 +86,8 @@ def __init__(self, config: GPT3Config, parallel_config: Optional[ParallelismArgs super().__init__(config.as_starcoder2(), parallel_config, layer_idx) self.gpt3config = config - def forward(self, + def forward( + self, query_states: torch.Tensor, # [batch_size * q_length, q_heads, inner_dim] key_states: torch.Tensor, # [batch_size * kv_length, kv_heads, inner_dim] value_states: torch.Tensor, # [batch_size * kv_length, kv_heads, inner_dim] @@ -101,7 +112,7 @@ def forward(self, is_causal=True, ) # [batch, q_length, q_heads, head_dim] attention_output = attention_output.permute(0, 2, 1, 3) - attention_output = attention_output.reshape(batch_size*q_length, q_heads, head_dim) + attention_output = attention_output.reshape(batch_size * q_length, q_heads, head_dim) return attention_output.contiguous() assert query_states.dtype in {torch.bfloat16, torch.float16} @@ -127,7 +138,7 @@ def __init__( config: GPT3Config, parallel_config: Optional[ParallelismArgs], tp_pg: dist.ProcessGroup, - random_states: RandomStates + random_states: RandomStates, ): super().__init__(config.as_starcoder2(), parallel_config, tp_pg) self.dropout = nn.Dropout(p=config.act_pdrop) @@ -154,14 +165,11 @@ def __init__( random_states: RandomStates, layer_idx: int, ): - #print("New gpt block created :D") + # print("New gpt block created :D") super(GPTBlock, self).__init__() self.ln_1 = TritonLayerNorm(config.hidden_size, eps=config.layer_norm_epsilon) self.attn = CausalSelfAttention( - config=config, - parallel_config=parallel_config, - tp_pg=tp_pg, - layer_idx=layer_idx + config=config, parallel_config=parallel_config, tp_pg=tp_pg, layer_idx=layer_idx ) self.attn_dropout = config.attn_pdrop @@ -180,10 +188,10 @@ def forward( residual = hidden_states hidden_states = self.ln_1(hidden_states) - #hidden_states = torch.arange(hidden_states.numel()).to(hidden_states.device).to(hidden_states.dtype).view(hidden_states.size()) + # hidden_states = torch.arange(hidden_states.numel()).to(hidden_states.device).to(hidden_states.dtype).view(hidden_states.size()) output = self.attn(hidden_states=hidden_states, sequence_mask=sequence_mask) hidden_states = output["hidden_states"] - #return {"hidden_states": hidden_states, "sequence_mask": sequence_mask} + # return {"hidden_states": hidden_states, "sequence_mask": sequence_mask} if self.training: with branch_random_state( @@ -221,7 +229,9 @@ 
def __init__(self, tp_pg: dist.ProcessGroup, config: GPT3Config, parallel_config if (config.max_position_embeddings + config.position_embedding_offset) % tp_pg.size() == 0: dummy_pos = 0 else: - dummy_pos = tp_pg.size() - ((config.max_position_embeddings + config.position_embedding_offset) % tp_pg.size()) + dummy_pos = tp_pg.size() - ( + (config.max_position_embeddings + config.position_embedding_offset) % tp_pg.size() + ) true_max_size = config.max_position_embeddings + config.position_embedding_offset + dummy_pos if config.sinusoidal_position_embedding: @@ -234,7 +244,7 @@ def __init__(self, tp_pg: dist.ProcessGroup, config: GPT3Config, parallel_config embedding_dim=config.hidden_size, pg=tp_pg, mode=parallel_config.tp_mode if parallel_config is not None else TensorParallelLinearMode.ALL_REDUCE, - _weight=weight + _weight=weight, ) self.pg = tp_pg @@ -251,32 +261,31 @@ def forward(self, position_ids: torch.Tensor): # [batch_size, seq_length] position_embeds = self.position_embedding(position_ids + self.config.position_embedding_offset) return {"position_embeds": position_embeds} - def _make_weights(self, tp_pg: dist.ProcessGroup, num_embeddings: int, - embedding_dim: int) -> torch.Tensor: + def _make_weights(self, tp_pg: dist.ProcessGroup, num_embeddings: int, embedding_dim: int) -> torch.Tensor: rank = dist.get_rank(group=tp_pg) tp_size = tp_pg.size() assert 0 <= rank < tp_size assert num_embeddings % tp_size == 0 assert embedding_dim % 2 == 0 - block_size = num_embeddings//tp_size + block_size = num_embeddings // tp_size - half_dim = embedding_dim//2 - emb = math.log(10_000)/(half_dim - 1) + half_dim = embedding_dim // 2 + emb = math.log(10_000) / (half_dim - 1) emb = torch.exp(torch.arange(half_dim, dtype=torch.int64).float() * -emb) - emb = (rank*block_size + torch.arange(block_size, dtype=torch.int64).float().unsqueeze(1)) * emb.unsqueeze(0) + emb = (rank * block_size + torch.arange(block_size, dtype=torch.int64).float().unsqueeze(1)) * emb.unsqueeze(0) emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1).view(block_size, embedding_dim) return emb class GPT3Model(GPTModel): def __init__( - self, - config: GPT3Config, - parallel_context: ParallelContext, - parallel_config: Optional[ParallelismArgs], - random_states: RandomStates, - ): + self, + config: GPT3Config, + parallel_context: ParallelContext, + parallel_config: Optional[ParallelismArgs], + random_states: RandomStates, + ): with replace_decoder(config): super().__init__(config.as_starcoder2(), parallel_context, parallel_config, random_states) @@ -300,7 +309,9 @@ def forward( ): # all tensors are optional as most ranks don't need anything from the dataloader. - input_embeds = self.token_embeddings(input_ids=input_ids, input_mask=input_mask)["input_embeds"]*self.embed_scale + input_embeds = ( + self.token_embeddings(input_ids=input_ids, input_mask=input_mask)["input_embeds"] * self.embed_scale + ) # TODO: position_ids could be cached. 
position_ids = torch.arange(input_ids.size(1), device="cuda").repeat(input_ids.size(0)).view(*input_ids.size()) position_embeds = self.position_embeddings(position_ids=position_ids)["position_embeds"] @@ -314,7 +325,7 @@ def forward( hidden_encoder_states = {"hidden_states": hidden_states, "sequence_mask": input_mask} for encoder_block in self.decoder: hidden_encoder_states = encoder_block(**hidden_encoder_states) - #return hidden_encoder_states["hidden_states"] + # return hidden_encoder_states["hidden_states"] hidden_states = self.final_layer_norm(input=hidden_encoder_states["hidden_states"])["hidden_states"] diff --git a/src/nanotron/trainer.py b/src/nanotron/trainer.py index f8022c52..e8fbb8cc 100644 --- a/src/nanotron/trainer.py +++ b/src/nanotron/trainer.py @@ -56,9 +56,9 @@ ) from nanotron.models import NanotronModel, build_model from nanotron.models.base import check_model_has_grad +from nanotron.models.gpt3 import GPT3ForTraining from nanotron.models.llama import LlamaForTraining, RotaryEmbedding from nanotron.models.starcoder2 import Starcoder2ForTraining -from nanotron.models.gpt3 import GPT3ForTraining from nanotron.optim.clip_grads import clip_grad_norm from nanotron.parallel import ParallelContext from nanotron.parallel.data_parallel.utils import sync_gradients_across_dp From 0914292ebc6f6d7d257aa90e557d6340bf8c356f Mon Sep 17 00:00:00 2001 From: tj-solergibert Date: Tue, 16 Jul 2024 11:45:28 +0000 Subject: [PATCH 33/44] Added MultilingualNanoset Config --- src/nanotron/config/config.py | 21 +++++++++++++++++++++ 1 file changed, 21 insertions(+) diff --git a/src/nanotron/config/config.py b/src/nanotron/config/config.py index adc1eafd..b30dea8e 100644 --- a/src/nanotron/config/config.py +++ b/src/nanotron/config/config.py @@ -107,6 +107,27 @@ def __post_init__(self): self.dataset_weights = list(tmp_dataset_folder.values()) +@dataclass +class MultilingualNanosetDatasetsArgs: + dataset_folder: Union[str, dict, List[str]] + dataset_tokens: List[ + int + ] # Set token for each language previously defined. 
We use a List and not a dict because this way we support specifyng weights (dict) or not (List[str]) + + def __post_init__(self): + if isinstance(self.dataset_folder, str): # Case 1: 1 Dataset file + self.dataset_folder = [self.dataset_folder] + self.dataset_weights = [1] + elif isinstance(self.dataset_folder, List): # Case 2: > 1 Dataset file + self.dataset_weights = None # Set to None so we consume all the samples randomly + elif isinstance(self.dataset_folder, dict): # Case 3: dict with > 1 dataset_folder and weights + tmp_dataset_folder = self.dataset_folder.copy() + self.dataset_folder = list(tmp_dataset_folder.keys()) + self.dataset_weights = list(tmp_dataset_folder.values()) + + assert len(self.dataset_folder) == len(self.dataset_tokens) + + @dataclass class DataArgs: """Arguments related to the data and data files processing""" From 659a0a0b40ac069166c72ef9985a1bf19bfe4df4 Mon Sep 17 00:00:00 2001 From: tj-solergibert Date: Tue, 16 Jul 2024 11:48:51 +0000 Subject: [PATCH 34/44] Added MultilingualNanoset --- run_train.py | 125 +++++++++++- src/nanotron/data/multilingual_nanoset.py | 221 ++++++++++++++++++++++ 2 files changed, 343 insertions(+), 3 deletions(-) create mode 100644 src/nanotron/data/multilingual_nanoset.py diff --git a/run_train.py b/run_train.py index 021d955d..649784ca 100644 --- a/run_train.py +++ b/run_train.py @@ -12,7 +12,13 @@ import numpy as np from nanotron import logging -from nanotron.config import DataArgs, DatasetStageArgs, NanosetDatasetsArgs, PretrainDatasetsArgs +from nanotron.config import ( + DataArgs, + DatasetStageArgs, + MultilingualNanosetDatasetsArgs, + NanosetDatasetsArgs, + PretrainDatasetsArgs, +) from nanotron.data.dataloader_builder import build_nanoset_dataloader from nanotron.dataloader import ( clm_process, @@ -171,6 +177,40 @@ def get_dataloader_from_data_stage( dataloader_drop_last=True, ) + return train_dataloader + # Case 4: MultilingualNanosets + elif isinstance(data.dataset, MultilingualNanosetDatasetsArgs): + # Get tokenizer cardinality + tokenizer = AutoTokenizer.from_pretrained(trainer.config.tokenizer.tokenizer_name_or_path) + token_size = 4 if len(tokenizer) > np.iinfo(np.uint16).max + 1 else 2 + del tokenizer + # Create Nanoset + from nanotron.data.multilingual_nanoset import MultilingualNanoset + + with main_rank_first(trainer.parallel_context.world_pg): + train_dataset = MultilingualNanoset( + dataset_folders=data.dataset.dataset_folder, + dataset_weights=data.dataset.dataset_weights, + sequence_length=trainer.sequence_length, + token_size=token_size, + train_split_num_samples=trainer.config.tokens.train_steps * trainer.global_batch_size, + valid_split_num_samples=trainer.config.tokens.limit_val_batches * trainer.global_batch_size, + random_seed=data.seed, + ) + + # Prepare dataloader + train_dataloader = build_nanoset_dataloader( + train_dataset, + trainer.sequence_length, + parallel_context=trainer.parallel_context, + input_pp_rank=input_pp_rank, + output_pp_rank=output_pp_rank, + micro_batch_size=trainer.micro_batch_size, + consumed_train_samples=consumed_train_samples, + dataloader_num_workers=data.num_loading_workers, + dataloader_drop_last=True, + ) + return train_dataloader else: raise ValueError(f"Unhandled case of `self.config.data.dataset`. 
Got: {data.dataset}") @@ -178,6 +218,57 @@ def get_dataloader_from_data_stage( return dataloader +def get_valid_dataloader_from_data_stage( + trainer: DistributedTrainer, + data: DataArgs, + valid_split_num_samples: int, + # consumed_train_samples: int, We will never use this because in each valid iteration we consume all the samples +): + + # First, we need to know which ranks to feed the dataloader to + input_pp_rank, output_pp_rank = get_input_output_pp_ranks(model=trainer.model) + + # Only support Validation with MultilingualNanosets + if isinstance(data.dataset, NanosetDatasetsArgs): + # Get tokenizer cardinality + tokenizer = AutoTokenizer.from_pretrained(trainer.config.tokenizer.tokenizer_name_or_path) + token_size = 4 if len(tokenizer) > np.iinfo(np.uint16).max + 1 else 2 + del tokenizer + # Create Multilingual Nanoset + from nanotron.data.multilingual_nanoset import MultilingualNanoset + + with main_rank_first(trainer.parallel_context.world_pg): + valid_dataset = MultilingualNanoset( + dataset_folders=data.dataset.dataset_folder, + dataset_weights=data.dataset.dataset_weights, + sequence_length=trainer.sequence_length, + token_size=token_size, + train_split_num_samples=trainer.config.tokens.train_steps * trainer.global_batch_size, + valid_split_num_samples=valid_split_num_samples, + is_valid=True, + random_seed=data.seed, + ) + + # Prepare dataloader + valid_dataloader = build_nanoset_dataloader( + valid_dataset, + trainer.sequence_length, + parallel_context=trainer.parallel_context, + input_pp_rank=input_pp_rank, + output_pp_rank=output_pp_rank, + micro_batch_size=trainer.micro_batch_size, + consumed_train_samples=0, + dataloader_num_workers=data.num_loading_workers, + dataloader_drop_last=True, + ) + + return valid_dataloader + else: + raise ValueError( + f"Unhandled case of `self.config.data.dataset`. Got: {data.dataset}. 
Validation is currently just supported for MultilingualNanoset" + ) + + def get_dataloader(trainer: DistributedTrainer) -> Dict[str, DataLoader]: dataloaders = {} @@ -219,6 +310,33 @@ def get_dataloader(trainer: DistributedTrainer) -> Dict[str, DataLoader]: return dataloaders +def get_valid_dataloader(trainer: DistributedTrainer) -> Dict[str, DataLoader]: + dataloaders = {} + + for stage_idx, stage in enumerate(trainer.config.data_stages): + # NOTE: we only create the dataloader for the first stage, + # then we lazy initialize the dataloader for the other stages + stage = cast(DatasetStageArgs, stage) + valid_split_num_samples = trainer.config.tokens.limit_val_batches * trainer.global_batch_size + + log_rank( + f"[Training Plan] Stage {stage.name} has {valid_split_num_samples} samples in the validation set", + logger=logger, + level=logging.INFO, + rank=0, + ) + + dataloader = ( + get_valid_dataloader_from_data_stage(trainer, stage.data, valid_split_num_samples=valid_split_num_samples) + if stage_idx == 0 + else lambda stage=stage: get_dataloader_from_data_stage( + trainer, stage.data, valid_split_num_samples=valid_split_num_samples + ) + ) + dataloaders[stage.name] = dataloader + return dataloaders + + def get_args(): parser = argparse.ArgumentParser() parser.add_argument("--config-file", type=str, required=True, help="Path to the YAML or python config file") @@ -231,7 +349,8 @@ def get_args(): # Load trainer and data trainer = DistributedTrainer(config_file) - dataloader = get_dataloader(trainer) + train_dataloader = get_dataloader(trainer) + valid_dataloader = get_valid_dataloader(trainer) # Train - trainer.train(dataloader) + trainer.train(train_dataloader, valid_dataloader) diff --git a/src/nanotron/data/multilingual_nanoset.py b/src/nanotron/data/multilingual_nanoset.py new file mode 100644 index 00000000..40e06b87 --- /dev/null +++ b/src/nanotron/data/multilingual_nanoset.py @@ -0,0 +1,221 @@ +import os +import warnings +from typing import Dict, List, Tuple, Union + +import numpy as np +import torch +from datatrove.utils.dataset import DatatroveFolderDataset +from nanotron import logging +from nanotron.data.utils import count_dataset_indexes, normalize +from nanotron.logging import log_rank +from numba import jit + +logger = logging.get_logger(__name__) + + +class MultilingualNanoset(torch.utils.data.Dataset): + """ + The Nanoset dataset + + Args: + dataset_folders (List[str]): List of folders with tokenized datasets + dataset_weights (Union[List[float], None]): List with the weights for weighted datasets. If None, consume all samples from all datasets without weighting. Weights are normalized in __init__ + sequence_length (int): Sequence length of the built samples + token_size (int): Number of bytes for the tokens stored in the processed dataset files. 2 for vocab sizes < 65535, 4 otherwise + train_split_num_samples (int): Number of samples the dataset needs. It's the training steps * global batch size + """ + + def __init__( + self, + dataset_folders: List[str], + sequence_length: int, + token_size: int, + train_split_num_samples: int, + valid_split_num_samples: int, + is_valid: bool = False, + dataset_weights: Union[List[float], None] = None, + random_seed: int = 1234, + ) -> None: + + # Checks + if isinstance(dataset_folders, str): + warnings.warn("dataset_folders should be of type List[str] but str was provided. 
Converting to List[str]") + dataset_folders = [dataset_folders] + + # Init + self.dataset_folders = dataset_folders + self.sequence_length = sequence_length + self.token_size = token_size + self.train_split_num_samples = train_split_num_samples + self.valid_split_num_samples = valid_split_num_samples + self.is_valid = is_valid + self.random_seed = random_seed + self.datatrove_datasets = [] + for dataset_folder in self.dataset_folders: + self.datatrove_datasets.append( + DatatroveFolderDataset( + folder_path=dataset_folder, + filename_pattern=os.path.join(dataset_folder, "*.ds"), + seq_len=sequence_length, + recursive=False, + token_size=token_size, + shuffle=True, + ) + ) + + # Build Nanoset Index + ## To build the index we need the length of each dataset + self.dataset_lengths = [len(datatrove_dataset) for datatrove_dataset in self.datatrove_datasets] + ## Set dataset weights + if ( + dataset_weights is None + ): # Case of training with > 1 datasets without weighting them: Consume both datasets entirely on each epoch + self.dataset_weights = normalize(self.dataset_lengths) + else: + self.dataset_weights = normalize(dataset_weights) + assert len(dataset_folders) == len( + self.dataset_weights + ), f"Specified {len(self.dataset_weights)} weights but {len(dataset_folders)} datasets were provided." + ## Build dataset index and dataset sample index + ### Split dataset_lengths into train_dataset_lenghts & valid_dataset_lenghts + self.valid_dataset_lenghts = self.dataset_weights * valid_split_num_samples + # Assert that we have sufficient samples to build the valid split + for ds_index in range(len(self.dataset_lengths)): + assert ( + self.valid_dataset_lenghts[ds_index] > self.dataset_lengths[ds_index] + ), f"Trying to build validation dataset with {self.valid_dataset_lenghts[ds_index]} samples but {dataset_folders[ds_index]} just have {self.dataset_lengths[ds_index]} samples." 
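+        # Each dataset folder is split contiguously: the first train_dataset_lenghts[i] samples feed the
+        # training split, while the tail valid_dataset_lenghts[i] samples are reserved for validation.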
+ self.train_dataset_lenghts = [ + a - b for a, b in zip(self.dataset_lengths, self.valid_dataset_lenghts) + ] # Subtract the valid samples from the training dataset + + if is_valid: # Valid MultilingualNanoset + self.split_num_samples = valid_split_num_samples + self.split_samples_per_epoch = valid_split_num_samples + self.num_epochs = 1 + self.split_dataset_lenghts = self.valid_dataset_lenghts + self.split_dataset_offsets = self.train_dataset_lenghts + + else: # Train MultilingualNanoset + self.split_num_samples = train_split_num_samples + self.split_samples_per_epoch = sum(self.train_dataset_lenghts) + self.num_epochs = int(self.split_num_samples / self.split_samples_per_epoch) + 1 + self.split_dataset_lenghts = self.train_dataset_lenghts + self.split_dataset_offsets = [ + 0 for _ in range(len(self.dataset_lengths)) + ] # For training there is NO offset + + self.dataset_index, self.dataset_sample_index = self.build_nanoset_index() + + self.print_nanoset_info() + + def __len__(self) -> int: + """ + Returns: + int: The number of samples of the Nanoset + """ + + return len(self.dataset_index) + + def __getitem__(self, idx: int) -> Dict[str, np.ndarray]: + """ + Returns sequence_length + 1 tokens from the memmap dataset + + Args: + idx (int): The index into the dataset + + Returns: + Dict[str, torch.LongTensor]: The input ids wrapped in a dictionary + """ + dataset = self.dataset_index[idx] + dataset_sample = self.dataset_sample_index[idx] + + return self.datatrove_datasets[dataset][dataset_sample] + + def build_nanoset_index(self) -> np.ndarray: + """ + Build dataset index and dataset sample index + """ + # Build the dataset indexes for 1 epoch + dataset_index, dataset_sample_index = build_nanoset_index_helper( + n_samples=self.split_samples_per_epoch, + weights=self.dataset_weights, + dataset_sizes=self.split_dataset_lengths, + offsets=self.split_dataset_offsets, + ) + # Shuffle the indexes the same way + numpy_random_state = np.random.RandomState(self.random_seed) + numpy_random_state.shuffle(dataset_index) + numpy_random_state = np.random.RandomState(self.random_seed) + numpy_random_state.shuffle(dataset_sample_index) + # Concatenate num_epochs the shuffled indexes + dataset_index = np.concatenate([dataset_index for _ in range(self.num_epochs)]) + dataset_sample_index = np.concatenate([dataset_sample_index for _ in range(self.num_epochs)]) + # Just keep the necessary samples + dataset_index = dataset_index[: self.split_num_samples] + dataset_sample_index = dataset_sample_index[: self.split_num_samples] + + return dataset_index, dataset_sample_index + + def print_nanoset_info(self): + + log_rank( + f"> [{'Validation' if self.is_valid else 'Training'} dataset] Total number of samples: {len(self)}", + logger=logger, + level=logging.INFO, + rank=0, + ) + log_rank( + f"> [{'Validation' if self.is_valid else 'Training'} dataset] Total number of tokens: {len(self) * self.sequence_length}", + logger=logger, + level=logging.INFO, + rank=0, + ) + + # Print samples from each dataset + weight + dataset_sample_count = count_dataset_indexes(self.dataset_index, len(self.dataset_folders)) + for index, sample_count in enumerate(dataset_sample_count): + log_rank( + f"> Total number of {'validation' if self.is_valid else 'training'} samples from the {self.dataset_folders[index]} dataset: {sample_count} ({round(normalize(dataset_sample_count).tolist()[index], 2)})", + logger=logger, + level=logging.INFO, + rank=0, + ) + + +@jit(nopython=True, cache=True) +def build_nanoset_index_helper( + n_samples: int, 
weights: np.ndarray, dataset_sizes: List[int], offsets: List[int] +) -> Tuple[np.ndarray, np.ndarray]: + """ + Given multiple datasets and a weighting array, build samples indexes + such that it follows those weights. + For train and valid splits we split each dataset_folder in train (first part) and valid splits. We set the offsets to the train lengths + for generating the valid split + """ + # Create empty arrays for dataset indices and dataset sample indices + dataset_index = np.empty((n_samples,), dtype="uint") + dataset_sample_index = np.empty((n_samples,), dtype="long") # Supports dataset with up to 2**64 samples + + # Initialize buffer for number of samples used for each dataset + current_samples = np.zeros((len(weights),), dtype="long") + + # Iterate over all samples + for sample_idx in range(n_samples): + + # Convert sample index to float for comparison against weights + sample_idx_float = max(sample_idx, 1.0) + + # Find the dataset with the highest error + errors = weights * sample_idx_float - current_samples + max_error_index = np.argmax(errors) + + # Assign the dataset index and update the sample index + dataset_index[sample_idx] = max_error_index + dataset_sample_index[sample_idx] = ( + current_samples[max_error_index] % dataset_sizes[max_error_index] + ) + offsets[max_error_index] + + # Update the total samples for the selected dataset + current_samples[max_error_index] += 1 + + return dataset_index, dataset_sample_index From b7fa97d8115c5c5d5dcaf2e031c2779a9102bdde Mon Sep 17 00:00:00 2001 From: tj-solergibert Date: Tue, 16 Jul 2024 12:25:17 +0000 Subject: [PATCH 35/44] Added Language token --- examples/config_multilingual_nanoset.yaml | 120 ++++++++++++++++++++++ src/nanotron/data/multilingual_nanoset.py | 7 +- 2 files changed, 126 insertions(+), 1 deletion(-) create mode 100644 examples/config_multilingual_nanoset.yaml diff --git a/examples/config_multilingual_nanoset.yaml b/examples/config_multilingual_nanoset.yaml new file mode 100644 index 00000000..00ae6570 --- /dev/null +++ b/examples/config_multilingual_nanoset.yaml @@ -0,0 +1,120 @@ +checkpoints: + checkpoint_interval: 1000 + checkpoints_path: checkpoints/ + checkpoints_path_is_shared_file_system: false + resume_checkpoint_path: null + save_initial_state: false +data_stages: +- data: + dataset: + dataset_folder: datasets/c4-es/tokenized + dataset_tokens: + - 15 + num_loading_workers: 1 + seed: 42 + name: General purpose training (Single dataset) + start_training_step: 1 +- data: + dataset: + dataset_folder: + - datasets/SlimPajama-6B/tokenized + - datasets/c4-es/tokenized + dataset_tokens: + - 16 + - 15 + num_loading_workers: 1 + seed: 42 + name: Second purpose training (> 1 dataset) + start_training_step: 15 +- data: + dataset: + dataset_folder: + datasets/SlimPajama-6B/tokenized: 0.8 + datasets/c4-es/tokenized: 0.2 + dataset_tokens: + - 16 + - 15 + num_loading_workers: 1 + seed: 42 + name: Third purpose training (Blended dataset) + start_training_step: 25 +general: + benchmark_csv_path: null + consumed_train_samples: null + ignore_sanity_checks: true + project: Nanoset + run: llama + seed: 42 + step: null +lighteval: null +logging: + iteration_step_info_interval: 1 + log_level: info + log_level_replica: info +model: + ddp_bucket_cap_mb: 25 + dtype: bfloat16 + init_method: + std: 0.025 + make_vocab_size_divisible_by: 1 + model_config: + bos_token_id: 1 + eos_token_id: 2 + hidden_act: silu + hidden_size: 4096 + initializer_range: 0.02 + intermediate_size: 11008 + is_llama_config: true + max_position_embeddings: 4096 
+ num_hidden_layers: 32 + num_attention_heads: 32 + num_key_value_heads: 8 + pad_token_id: null + pretraining_tp: 1 + rope_interleaved: false + rope_theta: 500000.0 + rms_norm_eps: 1.0e-06 + rope_scaling: null + tie_word_embeddings: true + use_cache: true + vocab_size: 128256 +optimizer: + accumulate_grad_in_fp32: true + clip_grad: 1.0 + learning_rate_scheduler: + learning_rate: 0.0003 + lr_decay_starting_step: null + lr_decay_steps: 98 + lr_decay_style: cosine + lr_warmup_steps: 2 + lr_warmup_style: linear + min_decay_lr: 1.0e-05 + optimizer_factory: + adam_beta1: 0.9 + adam_beta2: 0.95 + adam_eps: 1.0e-08 + name: adamW + torch_adam_is_fused: true + weight_decay: 0.01 + zero_stage: 0 +parallelism: + dp: 1 + expert_parallel_size: 1 + pp: 1 + pp_engine: 1f1b + tp: 1 + tp_linear_async_communication: false + tp_mode: REDUCE_SCATTER +profiler: null +tokenizer: + tokenizer_max_length: null + tokenizer_name_or_path: meta-llama/Meta-Llama-3-8B + tokenizer_revision: null +tokens: + batch_accumulation_per_replica: 1 + limit_test_batches: 0 + limit_val_batches: 10 + micro_batch_size: 2 + sequence_length: 1024 + train_steps: 200 + val_check_interval: -1 diff --git a/src/nanotron/data/multilingual_nanoset.py b/src/nanotron/data/multilingual_nanoset.py index 40e06b87..6526659d 100644 --- a/src/nanotron/data/multilingual_nanoset.py +++ b/src/nanotron/data/multilingual_nanoset.py @@ -32,6 +32,7 @@ def __init__( token_size: int, train_split_num_samples: int, valid_split_num_samples: int, + dataset_tokens: List[int], is_valid: bool = False, dataset_weights: Union[List[float], None] = None, random_seed: int = 1234, @@ -48,6 +49,7 @@ def __init__( self.token_size = token_size self.train_split_num_samples = train_split_num_samples self.valid_split_num_samples = valid_split_num_samples + self.dataset_tokens = dataset_tokens self.is_valid = is_valid self.random_seed = random_seed self.datatrove_datasets = [] @@ -129,7 +131,10 @@ def __getitem__(self, idx: int) -> Dict[str, np.ndarray]: dataset = self.dataset_index[idx] dataset_sample = self.dataset_sample_index[idx] - return self.datatrove_datasets[dataset][dataset_sample] + tokens = self.datatrove_datasets[dataset][dataset_sample] + tokens[0] = self.dataset_tokens[dataset] # Prepend language token + + return tokens def build_nanoset_index(self) -> np.ndarray: """ From 5a42c743718d4e4a36486919b4cd8eeccf56e320 Mon Sep 17 00:00:00 2001 From: tj-solergibert Date: Tue, 16 Jul 2024 12:51:42 +0000 Subject: [PATCH 36/44] Forgot the trainer ups --- src/nanotron/trainer.py | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/src/nanotron/trainer.py b/src/nanotron/trainer.py index e8fbb8cc..61c0aabc 100644 --- a/src/nanotron/trainer.py +++ b/src/nanotron/trainer.py @@ -394,7 +394,10 @@ def find_stage_idx_to_resume(): def train( self, - dataloader_or_dls: Dict[ + train_dataloader_or_dls: Dict[ + str, Union[Iterator[Dict[str, Union[torch.Tensor, TensorPointer]]], Tuple[Iterator, ...]] + ], + valid_dataloader_or_dls: Dict[ str, Union[Iterator[Dict[str, Union[torch.Tensor, TensorPointer]]], Tuple[Iterator, ...]] ], **kwargs, @@ -425,7 +428,7 @@ def train( prof.step() self.iteration_start_time = time.time() - self._update_dataloader_based_on_training_stages(dataloader_or_dls) + self._update_dataloader_based_on_training_stages(train_dataloader_or_dls) # Training step outputs, loss_avg = self.training_step(dataloader=self.current_dataloader) From 85be6a70e7d380a47153a6669cf441ed219ca18d Mon Sep 17 00:00:00 2001 From: tj-solergibert Date: Tue, 16 Jul 2024 
14:12:57 +0000 Subject: [PATCH 37/44] Fix minor errors. Everything works --- run_train.py | 6 ++++-- src/nanotron/config/config.py | 2 +- src/nanotron/data/multilingual_nanoset.py | 11 +++++++---- 3 files changed, 12 insertions(+), 7 deletions(-) diff --git a/run_train.py b/run_train.py index 649784ca..9b77da77 100644 --- a/run_train.py +++ b/run_train.py @@ -195,6 +195,7 @@ def get_dataloader_from_data_stage( token_size=token_size, train_split_num_samples=trainer.config.tokens.train_steps * trainer.global_batch_size, valid_split_num_samples=trainer.config.tokens.limit_val_batches * trainer.global_batch_size, + dataset_tokens=data.dataset.dataset_tokens, random_seed=data.seed, ) @@ -229,7 +230,7 @@ def get_valid_dataloader_from_data_stage( input_pp_rank, output_pp_rank = get_input_output_pp_ranks(model=trainer.model) # Only support Validation with MultilingualNanosets - if isinstance(data.dataset, NanosetDatasetsArgs): + if isinstance(data.dataset, MultilingualNanosetDatasetsArgs): # Get tokenizer cardinality tokenizer = AutoTokenizer.from_pretrained(trainer.config.tokenizer.tokenizer_name_or_path) token_size = 4 if len(tokenizer) > np.iinfo(np.uint16).max + 1 else 2 @@ -245,6 +246,7 @@ def get_valid_dataloader_from_data_stage( token_size=token_size, train_split_num_samples=trainer.config.tokens.train_steps * trainer.global_batch_size, valid_split_num_samples=valid_split_num_samples, + dataset_tokens=data.dataset.dataset_tokens, is_valid=True, random_seed=data.seed, ) @@ -320,7 +322,7 @@ def get_valid_dataloader(trainer: DistributedTrainer) -> Dict[str, DataLoader]: valid_split_num_samples = trainer.config.tokens.limit_val_batches * trainer.global_batch_size log_rank( - f"[Training Plan] Stage {stage.name} has {valid_split_num_samples} samples in the validation set", + f"[Validation Plan] Stage {stage.name} has {valid_split_num_samples} samples in the validation set", logger=logger, level=logging.INFO, rank=0, diff --git a/src/nanotron/config/config.py b/src/nanotron/config/config.py index b30dea8e..0083a4a6 100644 --- a/src/nanotron/config/config.py +++ b/src/nanotron/config/config.py @@ -132,7 +132,7 @@ def __post_init__(self): class DataArgs: """Arguments related to the data and data files processing""" - dataset: Optional[Union[PretrainDatasetsArgs, NanosetDatasetsArgs]] + dataset: Union[PretrainDatasetsArgs, NanosetDatasetsArgs, MultilingualNanosetDatasetsArgs] seed: Optional[int] num_loading_workers: Optional[int] = 1 diff --git a/src/nanotron/data/multilingual_nanoset.py b/src/nanotron/data/multilingual_nanoset.py index 6526659d..cd8be195 100644 --- a/src/nanotron/data/multilingual_nanoset.py +++ b/src/nanotron/data/multilingual_nanoset.py @@ -1,5 +1,6 @@ import os import warnings +from math import ceil from typing import Dict, List, Tuple, Union import numpy as np @@ -80,11 +81,13 @@ def __init__( ), f"Specified {len(self.dataset_weights)} weights but {len(dataset_folders)} datasets were provided." 
## Build dataset index and dataset sample index ### Split dataset_lengths into train_dataset_lenghts & valid_dataset_lenghts - self.valid_dataset_lenghts = self.dataset_weights * valid_split_num_samples + self.valid_dataset_lenghts = [ + ceil(weight * valid_split_num_samples) for weight in self.dataset_weights + ] # Better not tu use numpy so we don't get overflow issues # Assert that we have sufficient samples to build the valid split for ds_index in range(len(self.dataset_lengths)): assert ( - self.valid_dataset_lenghts[ds_index] > self.dataset_lengths[ds_index] + self.dataset_lengths[ds_index] > self.valid_dataset_lenghts[ds_index] ), f"Trying to build validation dataset with {self.valid_dataset_lenghts[ds_index]} samples but {dataset_folders[ds_index]} just have {self.dataset_lengths[ds_index]} samples." self.train_dataset_lenghts = [ a - b for a, b in zip(self.dataset_lengths, self.valid_dataset_lenghts) @@ -132,7 +135,7 @@ def __getitem__(self, idx: int) -> Dict[str, np.ndarray]: dataset_sample = self.dataset_sample_index[idx] tokens = self.datatrove_datasets[dataset][dataset_sample] - tokens[0] = self.dataset_tokens[dataset] # Prepend language token + tokens["input_ids"][0] = self.dataset_tokens[dataset] # Prepend language token return tokens @@ -144,7 +147,7 @@ def build_nanoset_index(self) -> np.ndarray: dataset_index, dataset_sample_index = build_nanoset_index_helper( n_samples=self.split_samples_per_epoch, weights=self.dataset_weights, - dataset_sizes=self.split_dataset_lengths, + dataset_sizes=self.split_dataset_lenghts, offsets=self.split_dataset_offsets, ) # Shuffle the indexes the same way From 526b929049cea216296c11eb06e7fea8fe1fd9c1 Mon Sep 17 00:00:00 2001 From: tj-solergibert Date: Tue, 16 Jul 2024 14:13:33 +0000 Subject: [PATCH 38/44] Updated config file with GPT2 tokenized datasets in RCP --- examples/config_multilingual_nanoset.yaml | 14 +++++++------- 1 file changed, 7 insertions(+), 7 deletions(-) diff --git a/examples/config_multilingual_nanoset.yaml b/examples/config_multilingual_nanoset.yaml index 00ae6570..3c4476a0 100644 --- a/examples/config_multilingual_nanoset.yaml +++ b/examples/config_multilingual_nanoset.yaml @@ -7,7 +7,7 @@ checkpoints: data_stages: - data: dataset: - dataset_folder: datasets/c4-es/tokenized + dataset_folder: /mloscratch/homes/solergib/nanotrove/nanotron/datasets/c4-es/tokenized dataset_tokens: - 15 num_loading_workers: 1 @@ -17,8 +17,8 @@ data_stages: - data: dataset: dataset_folder: - - datasets/SlimPajama-6B/tokenized - - datasets/c4-es/tokenized + - /mloscratch/homes/solergib/nanotrove/nanotron/datasets/SlimPajama-6B/tokenized + - /mloscratch/homes/solergib/nanotrove/nanotron/datasets/c4-es/tokenized dataset_tokens: - 16 - 15 @@ -29,8 +29,8 @@ data_stages: - data: dataset: dataset_folder: - datasets/SlimPajama-6B/tokenized: 0.8 - datasets/c4-es/tokenized: 0.2 + /mloscratch/homes/solergib/nanotrove/nanotron/datasets/SlimPajama-6B/tokenized: 0.8 + /mloscratch/homes/solergib/nanotrove/nanotron/datasets/c4-es/tokenized: 0.2 dataset_tokens: - 16 - 15 @@ -65,7 +65,7 @@ model: initializer_range: 0.02 intermediate_size: 11008 is_llama_config: true - max_position_embeddings: 4096 + max_position_embeddings: 1024 num_hidden_layers: 32 num_attention_heads: 32 num_key_value_heads: 8 @@ -108,7 +108,7 @@ parallelism: profiler: null tokenizer: tokenizer_max_length: null - tokenizer_name_or_path: meta-llama/Meta-Llama-3-8B + tokenizer_name_or_path: gpt2 tokenizer_revision: null tokens: batch_accumulation_per_replica: 1 From 
d247c5531cf20b88db7e7cfc58f92f0cec01fd80 Mon Sep 17 00:00:00 2001 From: tj-solergibert Date: Wed, 17 Jul 2024 10:13:57 +0000 Subject: [PATCH 39/44] Before lunch --- run_train.py | 13 +--- src/nanotron/config/config.py | 6 +- src/nanotron/data/multilingual_nanoset.py | 76 +++++++++-------------- 3 files changed, 37 insertions(+), 58 deletions(-) diff --git a/run_train.py b/run_train.py index 9b77da77..57e0ec25 100644 --- a/run_train.py +++ b/run_train.py @@ -194,7 +194,6 @@ def get_dataloader_from_data_stage( sequence_length=trainer.sequence_length, token_size=token_size, train_split_num_samples=trainer.config.tokens.train_steps * trainer.global_batch_size, - valid_split_num_samples=trainer.config.tokens.limit_val_batches * trainer.global_batch_size, dataset_tokens=data.dataset.dataset_tokens, random_seed=data.seed, ) @@ -222,7 +221,6 @@ def get_dataloader_from_data_stage( def get_valid_dataloader_from_data_stage( trainer: DistributedTrainer, data: DataArgs, - valid_split_num_samples: int, # consumed_train_samples: int, We will never use this because in each valid iteration we consume all the samples ): @@ -245,7 +243,6 @@ def get_valid_dataloader_from_data_stage( sequence_length=trainer.sequence_length, token_size=token_size, train_split_num_samples=trainer.config.tokens.train_steps * trainer.global_batch_size, - valid_split_num_samples=valid_split_num_samples, dataset_tokens=data.dataset.dataset_tokens, is_valid=True, random_seed=data.seed, @@ -259,7 +256,6 @@ def get_valid_dataloader_from_data_stage( input_pp_rank=input_pp_rank, output_pp_rank=output_pp_rank, micro_batch_size=trainer.micro_batch_size, - consumed_train_samples=0, dataloader_num_workers=data.num_loading_workers, dataloader_drop_last=True, ) @@ -319,21 +315,18 @@ def get_valid_dataloader(trainer: DistributedTrainer) -> Dict[str, DataLoader]: # NOTE: we only create the dataloader for the first stage, # then we lazy initialize the dataloader for the other stages stage = cast(DatasetStageArgs, stage) - valid_split_num_samples = trainer.config.tokens.limit_val_batches * trainer.global_batch_size log_rank( - f"[Validation Plan] Stage {stage.name} has {valid_split_num_samples} samples in the validation set", + f"[Validation Plan] Stage {stage.name} has {len(stage.data.dataset.validation_folder)} folders with samples in the validation set", logger=logger, level=logging.INFO, rank=0, ) dataloader = ( - get_valid_dataloader_from_data_stage(trainer, stage.data, valid_split_num_samples=valid_split_num_samples) + get_valid_dataloader_from_data_stage(trainer, stage.data) if stage_idx == 0 - else lambda stage=stage: get_dataloader_from_data_stage( - trainer, stage.data, valid_split_num_samples=valid_split_num_samples - ) + else lambda stage=stage: get_dataloader_from_data_stage(trainer, stage.data) ) dataloaders[stage.name] = dataloader return dataloaders diff --git a/src/nanotron/config/config.py b/src/nanotron/config/config.py index 0083a4a6..eb44792c 100644 --- a/src/nanotron/config/config.py +++ b/src/nanotron/config/config.py @@ -109,7 +109,8 @@ def __post_init__(self): @dataclass class MultilingualNanosetDatasetsArgs: - dataset_folder: Union[str, dict, List[str]] + training_folder: Union[str, dict, List[str]] + validation_folder: Union[str, dict, List[str]] dataset_tokens: List[ int ] # Set token for each language previously defined. 
We use a List and not a dict because this way we support specifyng weights (dict) or not (List[str]) @@ -125,7 +126,8 @@ def __post_init__(self): self.dataset_folder = list(tmp_dataset_folder.keys()) self.dataset_weights = list(tmp_dataset_folder.values()) - assert len(self.dataset_folder) == len(self.dataset_tokens) + assert len(self.training_folder) == len(self.validation_folder) + assert len(self.training_folder) == len(self.dataset_tokens) @dataclass diff --git a/src/nanotron/data/multilingual_nanoset.py b/src/nanotron/data/multilingual_nanoset.py index cd8be195..f634fd98 100644 --- a/src/nanotron/data/multilingual_nanoset.py +++ b/src/nanotron/data/multilingual_nanoset.py @@ -1,6 +1,5 @@ import os import warnings -from math import ceil from typing import Dict, List, Tuple, Union import numpy as np @@ -32,7 +31,6 @@ def __init__( sequence_length: int, token_size: int, train_split_num_samples: int, - valid_split_num_samples: int, dataset_tokens: List[int], is_valid: bool = False, dataset_weights: Union[List[float], None] = None, @@ -49,7 +47,6 @@ def __init__( self.sequence_length = sequence_length self.token_size = token_size self.train_split_num_samples = train_split_num_samples - self.valid_split_num_samples = valid_split_num_samples self.dataset_tokens = dataset_tokens self.is_valid = is_valid self.random_seed = random_seed @@ -80,36 +77,11 @@ def __init__( self.dataset_weights ), f"Specified {len(self.dataset_weights)} weights but {len(dataset_folders)} datasets were provided." ## Build dataset index and dataset sample index - ### Split dataset_lengths into train_dataset_lenghts & valid_dataset_lenghts - self.valid_dataset_lenghts = [ - ceil(weight * valid_split_num_samples) for weight in self.dataset_weights - ] # Better not tu use numpy so we don't get overflow issues - # Assert that we have sufficient samples to build the valid split - for ds_index in range(len(self.dataset_lengths)): - assert ( - self.dataset_lengths[ds_index] > self.valid_dataset_lenghts[ds_index] - ), f"Trying to build validation dataset with {self.valid_dataset_lenghts[ds_index]} samples but {dataset_folders[ds_index]} just have {self.dataset_lengths[ds_index]} samples." 
- self.train_dataset_lenghts = [ - a - b for a, b in zip(self.dataset_lengths, self.valid_dataset_lenghts) - ] # Subtract the valid samples from the training dataset - if is_valid: # Valid MultilingualNanoset - self.split_num_samples = valid_split_num_samples - self.split_samples_per_epoch = valid_split_num_samples - self.num_epochs = 1 - self.split_dataset_lenghts = self.valid_dataset_lenghts - self.split_dataset_offsets = self.train_dataset_lenghts + self.dataset_index, self.dataset_sample_index = self.build_valid_nanoset_index(self.dataset_lengths) else: # Train MultilingualNanoset - self.split_num_samples = train_split_num_samples - self.split_samples_per_epoch = sum(self.train_dataset_lenghts) - self.num_epochs = int(self.split_num_samples / self.split_samples_per_epoch) + 1 - self.split_dataset_lenghts = self.train_dataset_lenghts - self.split_dataset_offsets = [ - 0 for _ in range(len(self.dataset_lengths)) - ] # For training there is NO offset - - self.dataset_index, self.dataset_sample_index = self.build_nanoset_index() + self.dataset_index, self.dataset_sample_index = self.build_train_nanoset_index() self.print_nanoset_info() @@ -139,16 +111,16 @@ def __getitem__(self, idx: int) -> Dict[str, np.ndarray]: return tokens - def build_nanoset_index(self) -> np.ndarray: + def build_train_nanoset_index(self) -> np.ndarray: """ - Build dataset index and dataset sample index + Build train dataset index and dataset sample index """ + # Compute samples per epoch and number of epochs + samples_per_epoch = sum(self.dataset_lengths) + num_epochs = int(self.train_split_num_samples / samples_per_epoch) + 1 # Build the dataset indexes for 1 epoch - dataset_index, dataset_sample_index = build_nanoset_index_helper( - n_samples=self.split_samples_per_epoch, - weights=self.dataset_weights, - dataset_sizes=self.split_dataset_lenghts, - offsets=self.split_dataset_offsets, + dataset_index, dataset_sample_index = build_train_nanoset_index_helper( + n_samples=samples_per_epoch, weights=self.dataset_weights, dataset_sizes=self.dataset_lengths ) # Shuffle the indexes the same way numpy_random_state = np.random.RandomState(self.random_seed) @@ -156,14 +128,28 @@ def build_nanoset_index(self) -> np.ndarray: numpy_random_state = np.random.RandomState(self.random_seed) numpy_random_state.shuffle(dataset_sample_index) # Concatenate num_epochs the shuffled indexes - dataset_index = np.concatenate([dataset_index for _ in range(self.num_epochs)]) - dataset_sample_index = np.concatenate([dataset_sample_index for _ in range(self.num_epochs)]) + dataset_index = np.concatenate([dataset_index for _ in range(num_epochs)]) + dataset_sample_index = np.concatenate([dataset_sample_index for _ in range(num_epochs)]) # Just keep the necessary samples - dataset_index = dataset_index[: self.split_num_samples] - dataset_sample_index = dataset_sample_index[: self.split_num_samples] + dataset_index = dataset_index[: self.train_split_num_samples] + dataset_sample_index = dataset_sample_index[: self.train_split_num_samples] return dataset_index, dataset_sample_index + @jit(nopython=True, cache=True) + def build_valid_nanoset_index(dataset_lengths: List[int]) -> np.ndarray: + """ + Build valid dataset index and dataset sample index + """ + dataset_index = [] + dataset_sample_index = [] + + for i, length in enumerate(dataset_lengths): + dataset_index.extend([i] * length) + dataset_sample_index.extend(range(length)) + + return np.array(dataset_index, dtype="uint"), np.array(dataset_sample_index, dtype="long") + def 
print_nanoset_info(self): log_rank( @@ -191,8 +177,8 @@ def print_nanoset_info(self): @jit(nopython=True, cache=True) -def build_nanoset_index_helper( - n_samples: int, weights: np.ndarray, dataset_sizes: List[int], offsets: List[int] +def build_train_nanoset_index_helper( + n_samples: int, weights: np.ndarray, dataset_sizes: List[int] ) -> Tuple[np.ndarray, np.ndarray]: """ Given multiple datasets and a weighting array, build samples indexes @@ -219,9 +205,7 @@ def build_nanoset_index_helper( # Assign the dataset index and update the sample index dataset_index[sample_idx] = max_error_index - dataset_sample_index[sample_idx] = ( - current_samples[max_error_index] % dataset_sizes[max_error_index] - ) + offsets[max_error_index] + dataset_sample_index[sample_idx] = current_samples[max_error_index] % dataset_sizes[max_error_index] # Update the total samples for the selected dataset current_samples[max_error_index] += 1 From f7d72dffb1b5fe48360e43214ee129c897b2f8d4 Mon Sep 17 00:00:00 2001 From: tj-solergibert Date: Wed, 17 Jul 2024 14:10:03 +0000 Subject: [PATCH 40/44] After lunch --- examples/config_multilingual_nanoset.yaml | 42 +++++++++++++++-------- run_train.py | 6 ++-- src/nanotron/config/config.py | 21 ++++++------ src/nanotron/data/multilingual_nanoset.py | 33 +++++++++--------- tools/preprocess_data.py | 5 ++- 5 files changed, 61 insertions(+), 46 deletions(-) diff --git a/examples/config_multilingual_nanoset.yaml b/examples/config_multilingual_nanoset.yaml index 3c4476a0..238f8269 100644 --- a/examples/config_multilingual_nanoset.yaml +++ b/examples/config_multilingual_nanoset.yaml @@ -7,7 +7,8 @@ checkpoints: data_stages: - data: dataset: - dataset_folder: /mloscratch/homes/solergib/nanotrove/nanotron/datasets/c4-es/tokenized + training_folder: datasets/c4-es/train + validation_folder: datasets/c4-es/validation dataset_tokens: - 15 num_loading_workers: 1 @@ -16,24 +17,37 @@ data_stages: start_training_step: 1 - data: dataset: - dataset_folder: - - /mloscratch/homes/solergib/nanotrove/nanotron/datasets/SlimPajama-6B/tokenized - - /mloscratch/homes/solergib/nanotrove/nanotron/datasets/c4-es/tokenized + training_folder: + - datasets/c4-es/train + - datasets/c4-en/train + - datasets/c4-fr/train + validation_folder: + - datasets/c4-es/validation + - datasets/c4-en/validation + - datasets/c4-fr/validation dataset_tokens: - - 16 - 15 + - 16 + - 17 num_loading_workers: 1 seed: 42 name: Second purpose training (> 1 dataset) start_training_step: 15 - data: dataset: - dataset_folder: - /mloscratch/homes/solergib/nanotrove/nanotron/datasets/SlimPajama-6B/tokenized: 0.8 - /mloscratch/homes/solergib/nanotrove/nanotron/datasets/c4-es/tokenized: 0.2 + training_folder: + datasets/c4-es/train: 0.6 + datasets/c4-en/train: 0.3 + datasets/c4-fr/train: 0.1 + validation_folder: + - datasets/c4-es/validation + - datasets/c4-en/validation + - datasets/c4-fr/validation dataset_tokens: - - 16 - 15 + - 16 + - 17 + num_loading_workers: 1 seed: 42 name: Third purpose training (Blended dataset) @@ -61,12 +75,12 @@ model: bos_token_id: 1 eos_token_id: 2 hidden_act: silu - hidden_size: 4096 + hidden_size: 512 initializer_range: 0.02 - intermediate_size: 11008 + intermediate_size: 512 is_llama_config: true max_position_embeddings: 1024 - num_hidden_layers: 32 + num_hidden_layers: 2 num_attention_heads: 32 num_key_value_heads: 8 pad_token_id: null @@ -108,13 +122,13 @@ parallelism: profiler: null tokenizer: tokenizer_max_length: null - tokenizer_name_or_path: gpt2 + tokenizer_name_or_path: meta-llama/Meta-Llama-3-8B 
tokenizer_revision: null tokens: batch_accumulation_per_replica: 1 limit_test_batches: 0 limit_val_batches: 10 - micro_batch_size: 2 + micro_batch_size: 4 sequence_length: 1024 train_steps: 200 val_check_interval: -1 diff --git a/run_train.py b/run_train.py index 57e0ec25..39cda23b 100644 --- a/run_train.py +++ b/run_train.py @@ -189,7 +189,7 @@ def get_dataloader_from_data_stage( with main_rank_first(trainer.parallel_context.world_pg): train_dataset = MultilingualNanoset( - dataset_folders=data.dataset.dataset_folder, + dataset_folders=data.dataset.training_folder, dataset_weights=data.dataset.dataset_weights, sequence_length=trainer.sequence_length, token_size=token_size, @@ -238,11 +238,9 @@ def get_valid_dataloader_from_data_stage( with main_rank_first(trainer.parallel_context.world_pg): valid_dataset = MultilingualNanoset( - dataset_folders=data.dataset.dataset_folder, - dataset_weights=data.dataset.dataset_weights, + dataset_folders=data.dataset.validation_folder, sequence_length=trainer.sequence_length, token_size=token_size, - train_split_num_samples=trainer.config.tokens.train_steps * trainer.global_batch_size, dataset_tokens=data.dataset.dataset_tokens, is_valid=True, random_seed=data.seed, diff --git a/src/nanotron/config/config.py b/src/nanotron/config/config.py index eb44792c..f1881faa 100644 --- a/src/nanotron/config/config.py +++ b/src/nanotron/config/config.py @@ -110,21 +110,20 @@ def __post_init__(self): @dataclass class MultilingualNanosetDatasetsArgs: training_folder: Union[str, dict, List[str]] - validation_folder: Union[str, dict, List[str]] - dataset_tokens: List[ - int - ] # Set token for each language previously defined. We use a List and not a dict because this way we support specifyng weights (dict) or not (List[str]) + validation_folder: Union[str, List[str]] + dataset_tokens: List[int] # Set token for each language previously defined def __post_init__(self): - if isinstance(self.dataset_folder, str): # Case 1: 1 Dataset file - self.dataset_folder = [self.dataset_folder] + if isinstance(self.training_folder, str): # Case 1: 1 Dataset folder + self.training_folder = [self.training_folder] + self.validation_folder = [self.validation_folder] self.dataset_weights = [1] - elif isinstance(self.dataset_folder, List): # Case 2: > 1 Dataset file + elif isinstance(self.training_folder, List): # Case 2: > 1 Dataset folder self.dataset_weights = None # Set to None so we consume all the samples randomly - elif isinstance(self.dataset_folder, dict): # Case 3: dict with > 1 dataset_folder and weights - tmp_dataset_folder = self.dataset_folder.copy() - self.dataset_folder = list(tmp_dataset_folder.keys()) - self.dataset_weights = list(tmp_dataset_folder.values()) + elif isinstance(self.training_folder, dict): # Case 3: dict with > 1 training_folder and weights + tmp_training_folder = self.training_folder.copy() + self.training_folder = list(tmp_training_folder.keys()) + self.dataset_weights = list(tmp_training_folder.values()) assert len(self.training_folder) == len(self.validation_folder) assert len(self.training_folder) == len(self.dataset_tokens) diff --git a/src/nanotron/data/multilingual_nanoset.py b/src/nanotron/data/multilingual_nanoset.py index f634fd98..7af57448 100644 --- a/src/nanotron/data/multilingual_nanoset.py +++ b/src/nanotron/data/multilingual_nanoset.py @@ -30,8 +30,8 @@ def __init__( dataset_folders: List[str], sequence_length: int, token_size: int, - train_split_num_samples: int, dataset_tokens: List[int], + train_split_num_samples: int = None, is_valid: 
bool = False, dataset_weights: Union[List[float], None] = None, random_seed: int = 1234, @@ -78,7 +78,7 @@ def __init__( ), f"Specified {len(self.dataset_weights)} weights but {len(dataset_folders)} datasets were provided." ## Build dataset index and dataset sample index if is_valid: # Valid MultilingualNanoset - self.dataset_index, self.dataset_sample_index = self.build_valid_nanoset_index(self.dataset_lengths) + self.dataset_index, self.dataset_sample_index = build_valid_nanoset_index(self.dataset_lengths) else: # Train MultilingualNanoset self.dataset_index, self.dataset_sample_index = self.build_train_nanoset_index() @@ -136,20 +136,6 @@ def build_train_nanoset_index(self) -> np.ndarray: return dataset_index, dataset_sample_index - @jit(nopython=True, cache=True) - def build_valid_nanoset_index(dataset_lengths: List[int]) -> np.ndarray: - """ - Build valid dataset index and dataset sample index - """ - dataset_index = [] - dataset_sample_index = [] - - for i, length in enumerate(dataset_lengths): - dataset_index.extend([i] * length) - dataset_sample_index.extend(range(length)) - - return np.array(dataset_index, dtype="uint"), np.array(dataset_sample_index, dtype="long") - def print_nanoset_info(self): log_rank( @@ -211,3 +197,18 @@ def build_train_nanoset_index_helper( current_samples[max_error_index] += 1 return dataset_index, dataset_sample_index + + +@jit(nopython=True, cache=True) +def build_valid_nanoset_index(dataset_lengths: List[int]) -> np.ndarray: + """ + Build valid dataset index and dataset sample index + """ + dataset_index = [] + dataset_sample_index = [] + + for i, length in enumerate(dataset_lengths): + dataset_index.extend([i] * length) + dataset_sample_index.extend(range(length)) + + return np.array(dataset_index, dtype="uint"), np.array(dataset_sample_index, dtype="long") diff --git a/tools/preprocess_data.py b/tools/preprocess_data.py index dc42a3c0..23016eaf 100644 --- a/tools/preprocess_data.py +++ b/tools/preprocess_data.py @@ -98,7 +98,9 @@ def main(args): dataset_options={"split": args.split}, ) elif args.readers == "parquet": - datatrove_reader = ParquetReader(data_folder=args.dataset, text_key=args.column, glob_pattern=args.glob_pattern) + datatrove_reader = ParquetReader( + data_folder=args.dataset, text_key=args.column, glob_pattern=args.glob_pattern + ) else: datatrove_reader = JsonlReader(data_folder=args.dataset, text_key=args.column, glob_pattern=args.glob_pattern) @@ -107,6 +109,7 @@ def main(args): datatrove_reader, DocumentTokenizer( output_folder=args.output_folder, + shuffle=False, tokenizer_name_or_path=args.tokenizer_name_or_path, eos_token=args.eos_token, shuffle=False, From f1afcfa9effeac92914d3b918cfdd3329d667fac Mon Sep 17 00:00:00 2001 From: tj-solergibert Date: Thu, 18 Jul 2024 10:48:00 +0000 Subject: [PATCH 41/44] Ready --- examples/config_multilingual_nanoset.yaml | 20 ++++++++++---------- src/nanotron/config/config.py | 11 ++++++++--- 2 files changed, 18 insertions(+), 13 deletions(-) diff --git a/examples/config_multilingual_nanoset.yaml b/examples/config_multilingual_nanoset.yaml index 238f8269..599bff6c 100644 --- a/examples/config_multilingual_nanoset.yaml +++ b/examples/config_multilingual_nanoset.yaml @@ -9,8 +9,8 @@ data_stages: dataset: training_folder: datasets/c4-es/train validation_folder: datasets/c4-es/validation - dataset_tokens: - - 15 + lang_to_ids: + es: 128002 num_loading_workers: 1 seed: 42 name: General purpose training (Single dataset) @@ -25,10 +25,10 @@ data_stages: - datasets/c4-es/validation - 
datasets/c4-en/validation - datasets/c4-fr/validation - dataset_tokens: - - 15 - - 16 - - 17 + lang_to_ids: + es: 128002 + en: 128003 + fr: 128004 num_loading_workers: 1 seed: 42 name: Second purpose training (> 1 dataset) @@ -43,10 +43,10 @@ data_stages: - datasets/c4-es/validation - datasets/c4-en/validation - datasets/c4-fr/validation - dataset_tokens: - - 15 - - 16 - - 17 + lang_to_ids: + es: 128002 + en: 128003 + fr: 128004 num_loading_workers: 1 seed: 42 diff --git a/src/nanotron/config/config.py b/src/nanotron/config/config.py index f1881faa..d90f13fb 100644 --- a/src/nanotron/config/config.py +++ b/src/nanotron/config/config.py @@ -111,7 +111,7 @@ def __post_init__(self): class MultilingualNanosetDatasetsArgs: training_folder: Union[str, dict, List[str]] validation_folder: Union[str, List[str]] - dataset_tokens: List[int] # Set token for each language previously defined + lang_to_ids: dict # Mapping from the previously defined folders to tokens. Respect the order def __post_init__(self): if isinstance(self.training_folder, str): # Case 1: 1 Dataset folder @@ -125,8 +125,13 @@ def __post_init__(self): self.training_folder = list(tmp_training_folder.keys()) self.dataset_weights = list(tmp_training_folder.values()) - assert len(self.training_folder) == len(self.validation_folder) - assert len(self.training_folder) == len(self.dataset_tokens) + self.dataset_tokens = list(self.lang_to_ids.values()) + assert len(self.training_folder) == len( + self.validation_folder + ), f"The sizes of training_folder and validation_folder mismatch ({len(self.training_folder)} vs {len(self.validation_folder)})" + assert len(self.training_folder) == len( + self.dataset_tokens + ), f"The sizes of training_folder and lang_to_ids mismatch ({len(self.training_folder)} vs {len(self.dataset_tokens)})" @dataclass From c65b349ba5d40f37d0e037f0a3762ac7cb6e7f28 Mon Sep 17 00:00:00 2001 From: Antoni-Joan Solergibert <74564958+TJ-Solergibert@users.noreply.github.com> Date: Thu, 15 Aug 2024 09:42:01 +0200 Subject: [PATCH 42/44] Add multilingual validation (#3) Add multilingual validation step. 
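Reviewer note: the essence of this patch is that the loss module now returns a per-sample loss (masked_mean reduced over the sequence dimension only) and the validation engine returns those sample losses together with each sample's lang_code, so the trainer can report one loss per language on top of the global loss. The snippet below is a minimal, self-contained sketch of that aggregation step only; it is illustrative rather than the trainer code itself, and the helper name `aggregate_validation_losses` is invented for this example.

```python
# Minimal sketch (illustrative, not the actual trainer code) of turning the
# (sample_losses, lang_codes) pair returned by validate_batch_iter() into
# per-language and global validation losses.
from typing import Dict, List

import torch


def aggregate_validation_losses(
    sample_losses: List[torch.Tensor],  # one scalar loss per validation sample
    lang_codes: List[int],              # dataset/language index of each sample
    languages: List[str],               # e.g. ["es", "en", "fr"] from the config
) -> Dict[str, float]:
    losses = torch.stack(sample_losses)                 # [num_samples]
    codes = torch.tensor(lang_codes, dtype=torch.long)  # [num_samples]
    metrics = {}
    for code, lang in enumerate(languages):
        mask = codes == code
        # A language may have no samples if limit_val_batches truncates the run
        metrics[f"val_loss_{lang}"] = losses[mask].mean().item() if mask.any() else float("nan")
    metrics["val_loss_global"] = losses.mean().item()
    return metrics


if __name__ == "__main__":
    # Toy example: 5 validation samples drawn from es/en/fr
    print(
        aggregate_validation_losses(
            [torch.tensor(2.1), torch.tensor(1.9), torch.tensor(2.5), torch.tensor(2.0), torch.tensor(2.2)],
            [0, 0, 1, 2, 1],
            ["es", "en", "fr"],
        )
    )
```

In a real run each data-parallel rank only sees its own shard of the validation set, so these per-rank results still need to be reduced or gathered across the DP group before logging; that step is intentionally left out of the sketch.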
--- examples/config_multilingual_nanoset.yaml | 77 +++--- run_train.py | 19 +- src/nanotron/config/config.py | 17 +- src/nanotron/data/collator.py | 73 +++++ src/nanotron/data/dataloader_builder.py | 14 +- src/nanotron/data/multilingual_nanoset.py | 4 +- src/nanotron/distributed.py | 4 - src/nanotron/models/llama.py | 37 ++- .../parallel/pipeline_parallel/engine.py | 25 +- .../parallel/pipeline_parallel/state.py | 4 + src/nanotron/serialize/metadata.py | 2 + src/nanotron/trainer.py | 249 +++++++++++++++++- 12 files changed, 438 insertions(+), 87 deletions(-) diff --git a/examples/config_multilingual_nanoset.yaml b/examples/config_multilingual_nanoset.yaml index 599bff6c..cc66cd70 100644 --- a/examples/config_multilingual_nanoset.yaml +++ b/examples/config_multilingual_nanoset.yaml @@ -1,5 +1,5 @@ checkpoints: - checkpoint_interval: 1000 + checkpoint_interval: 1000000 checkpoints_path: checkpoints/ checkpoints_path_is_shared_file_system: false resume_checkpoint_path: null @@ -7,56 +7,57 @@ checkpoints: data_stages: - data: dataset: - training_folder: datasets/c4-es/train - validation_folder: datasets/c4-es/validation - lang_to_ids: - es: 128002 + training_folder: + - datasets/c4-es/train + - datasets/c4-en/train + - datasets/c4-fr/train + validation_folder: + - datasets/c4-es/validation + - datasets/c4-en/validation + - datasets/c4-fr/validation + languages: + - es + - en + - fr num_loading_workers: 1 seed: 42 - name: General purpose training (Single dataset) + name: General purpose training (Blended dataset) start_training_step: 1 - data: dataset: training_folder: - datasets/c4-es/train - - datasets/c4-en/train - - datasets/c4-fr/train validation_folder: - datasets/c4-es/validation - - datasets/c4-en/validation - - datasets/c4-fr/validation - lang_to_ids: - es: 128002 - en: 128003 - fr: 128004 + languages: + - es num_loading_workers: 1 seed: 42 - name: Second purpose training (> 1 dataset) - start_training_step: 15 + name: Second purpose training (Single dataset) + start_training_step: 1000 - data: dataset: training_folder: - datasets/c4-es/train: 0.6 - datasets/c4-en/train: 0.3 - datasets/c4-fr/train: 0.1 + - datasets/c4-es/train + - datasets/c4-en/train + - datasets/c4-fr/train validation_folder: - datasets/c4-es/validation - datasets/c4-en/validation - datasets/c4-fr/validation - lang_to_ids: - es: 128002 - en: 128003 - fr: 128004 - + languages: + - es + - en + - fr num_loading_workers: 1 seed: 42 - name: Third purpose training (Blended dataset) - start_training_step: 25 + name: Third purpose training (>1 dataset) + start_training_step: 2000 general: benchmark_csv_path: null consumed_train_samples: null ignore_sanity_checks: true - project: Nanoset + project: MultilingualV2 run: llama seed: 42 step: null @@ -75,12 +76,12 @@ model: bos_token_id: 1 eos_token_id: 2 hidden_act: silu - hidden_size: 512 + hidden_size: 4096 initializer_range: 0.02 - intermediate_size: 512 + intermediate_size: 14336 is_llama_config: true - max_position_embeddings: 1024 - num_hidden_layers: 2 + max_position_embeddings: 4096 + num_hidden_layers: 32 num_attention_heads: 32 num_key_value_heads: 8 pad_token_id: null @@ -89,7 +90,7 @@ model: rope_theta: 500000.0 rms_norm_eps: 1.0e-06 rope_scaling: null - tie_word_embeddings: true + tie_word_embeddings: false use_cache: true vocab_size: 128256 optimizer: @@ -112,11 +113,11 @@ optimizer: weight_decay: 0.01 zero_stage: 0 parallelism: - dp: 1 + dp: 2 expert_parallel_size: 1 pp: 1 pp_engine: 1f1b - tp: 1 + tp: 4 tp_linear_async_communication: false tp_mode: REDUCE_SCATTER 
profiler: null @@ -128,7 +129,7 @@ tokens: batch_accumulation_per_replica: 1 limit_test_batches: 0 limit_val_batches: 10 - micro_batch_size: 4 - sequence_length: 1024 - train_steps: 200 - val_check_interval: -1 + micro_batch_size: 3 + sequence_length: 4096 + train_steps: 500 + val_check_interval: 100 diff --git a/run_train.py b/run_train.py index 39cda23b..809d8d41 100644 --- a/run_train.py +++ b/run_train.py @@ -194,7 +194,6 @@ def get_dataloader_from_data_stage( sequence_length=trainer.sequence_length, token_size=token_size, train_split_num_samples=trainer.config.tokens.train_steps * trainer.global_batch_size, - dataset_tokens=data.dataset.dataset_tokens, random_seed=data.seed, ) @@ -209,6 +208,7 @@ def get_dataloader_from_data_stage( consumed_train_samples=consumed_train_samples, dataloader_num_workers=data.num_loading_workers, dataloader_drop_last=True, + is_multilingual=True, ) return train_dataloader @@ -241,7 +241,6 @@ def get_valid_dataloader_from_data_stage( dataset_folders=data.dataset.validation_folder, sequence_length=trainer.sequence_length, token_size=token_size, - dataset_tokens=data.dataset.dataset_tokens, is_valid=True, random_seed=data.seed, ) @@ -256,6 +255,8 @@ def get_valid_dataloader_from_data_stage( micro_batch_size=trainer.micro_batch_size, dataloader_num_workers=data.num_loading_workers, dataloader_drop_last=True, + shuffle=True, + is_multilingual=True, ) return valid_dataloader @@ -315,7 +316,7 @@ def get_valid_dataloader(trainer: DistributedTrainer) -> Dict[str, DataLoader]: stage = cast(DatasetStageArgs, stage) log_rank( - f"[Validation Plan] Stage {stage.name} has {len(stage.data.dataset.validation_folder)} folders with samples in the validation set", + f"[Validation Plan] Stage {stage.name} has {len(stage.data.dataset.validation_folder)} folders with samples for the validation set", logger=logger, level=logging.INFO, rank=0, @@ -324,8 +325,18 @@ def get_valid_dataloader(trainer: DistributedTrainer) -> Dict[str, DataLoader]: dataloader = ( get_valid_dataloader_from_data_stage(trainer, stage.data) if stage_idx == 0 - else lambda stage=stage: get_dataloader_from_data_stage(trainer, stage.data) + else lambda stage=stage: get_valid_dataloader_from_data_stage(trainer, stage.data) ) + # TODO(tj.solergibert) As we are creating again the valid dataloader in every validation stage, we print multiple times + # the validation MultilingualNanoset info (Number of samples, etc.) [UPDATE: ]. In order to solve that, we could get rid of this lambda + # funcs and directly create all dataloaders. + # + # This lambda functs (Used in training too) are for creating the DataLoaders lazyly FOR 1. Start training faster instead + # of creating multiple DataLoaders 2. Consume less memory as the lambda func is lighter that the DataLoader object with + # the Dataset, collator, etc. + # BUT 1. The Nanoset creation process is very fast and 2. Nanosets doesn't consume any memory at all till we start sampling + # from the Nanoset. 
Also they later transform the DataLoader into a Iterator object so it's impossible to retrieve + # the DataLoader object again to delete it (More comments in trainer.py) dataloaders[stage.name] = dataloader return dataloaders diff --git a/src/nanotron/config/config.py b/src/nanotron/config/config.py index d90f13fb..d2b39441 100644 --- a/src/nanotron/config/config.py +++ b/src/nanotron/config/config.py @@ -111,7 +111,7 @@ def __post_init__(self): class MultilingualNanosetDatasetsArgs: training_folder: Union[str, dict, List[str]] validation_folder: Union[str, List[str]] - lang_to_ids: dict # Mapping from the previously defined folders to tokens. Respect the order + languages: List[str] # NOTE(tj.solergibert) Required for 1. Aggregating the result 2. Reporting to WANDB def __post_init__(self): if isinstance(self.training_folder, str): # Case 1: 1 Dataset folder @@ -125,13 +125,13 @@ def __post_init__(self): self.training_folder = list(tmp_training_folder.keys()) self.dataset_weights = list(tmp_training_folder.values()) - self.dataset_tokens = list(self.lang_to_ids.values()) + assert len(self.training_folder) == len( + self.languages + ), f"The sizes of training_folder and languages mismatch ({len(self.training_folder)} vs {len(self.languages)})" + assert len(self.training_folder) == len( self.validation_folder ), f"The sizes of training_folder and validation_folder mismatch ({len(self.training_folder)} vs {len(self.validation_folder)})" - assert len(self.training_folder) == len( - self.dataset_tokens - ), f"The sizes of training_folder and lang_to_ids mismatch ({len(self.training_folder)} vs {len(self.dataset_tokens)})" @dataclass @@ -406,6 +406,13 @@ def __post_init__(self): for i in range(len(self.data_stages) - 1) ), "The stages are not sorted by start_training_step in increasing order" + # NOTE(tj.solergibert) As we are reporting the training & validation metrics together, we + # must comply with val_check_interval % iteration_step_info_interval = 0 + if not self.tokens.val_check_interval % self.logging.iteration_step_info_interval == 0: + raise ValueError( + f"It is necessary to run the validation stage during a logging step. Validation interval: {self.tokens.val_check_interval}, Logging interval: {self.logging.iteration_step_info_interval}" + ) + # # if lighteval, we need tokenizer to be defined # if self.checkpoints.lighteval is not None: # assert self.tokenizer.tokenizer_name_or_path is not None diff --git a/src/nanotron/data/collator.py b/src/nanotron/data/collator.py index 199527e1..fd217b1a 100644 --- a/src/nanotron/data/collator.py +++ b/src/nanotron/data/collator.py @@ -78,3 +78,76 @@ def __call__(self, examples: List[Dict[str, List[np.ndarray]]]) -> Dict[str, Uni ) return result + + +@dataclasses.dataclass +class MultilingualNanosetDataCollatorForCLM: + """ + Data collator used for causal language modeling with Nanosets dataset. + + - input_pp_rank: Discards last input id token + - output_pp_rank: Discards first label id token + - other pp ranks: Don't have data. Instead, we use `TensorPointer` to point to the rank having the data. + """ + + sequence_length: int + input_pp_rank: int + output_pp_rank: int + parallel_context: ParallelContext + + def __call__(self, examples: List[Dict[str, List[np.ndarray]]]) -> Dict[str, Union[torch.Tensor, TensorPointer]]: + # Process the case when current rank doesn't require data. We return `TensorPointer` that points to ranks having the data. 
+ current_pp_rank = dist.get_rank(self.parallel_context.pp_pg) + if current_pp_rank not in [ + self.input_pp_rank, + self.output_pp_rank, + ]: + assert all(len(example) == 0 for example in examples) + return { + "input_ids": TensorPointer(group_rank=self.input_pp_rank), + "input_mask": TensorPointer(group_rank=self.input_pp_rank), + "lang_code": TensorPointer(group_rank=self.input_pp_rank), + "label_ids": TensorPointer(group_rank=self.output_pp_rank), + "label_mask": TensorPointer(group_rank=self.output_pp_rank), + } + + # TODO @nouamanetazi: Is it better to have examples as np.array or torch.Tensor? + input_ids = torch.vstack([examples[i]["input_ids"] for i in range(len(examples))]) # (b, s) + lang_code = torch.vstack([examples[i]["lang_code"] for i in range(len(examples))]) # (b, 1) + batch_size, expanded_input_length = input_ids.shape + + result: Dict[str, Union[torch.LongTensor, TensorPointer]] = {} + + result["input_ids"] = TensorPointer(group_rank=self.input_pp_rank) + result["input_mask"] = TensorPointer(group_rank=self.input_pp_rank) + result["lang_code"] = TensorPointer(group_rank=self.input_pp_rank) + result["label_ids"] = TensorPointer(group_rank=self.output_pp_rank) + result["label_mask"] = TensorPointer(group_rank=self.output_pp_rank) + + assert ( + expanded_input_length == self.sequence_length + 1 + ), f"Samples should be of length {self.sequence_length + 1} (seq_len+1), but got {expanded_input_length}" + + # Process inputs: last token is the label + if current_pp_rank == self.input_pp_rank: + result["input_ids"] = input_ids[:, :-1] + result["input_mask"] = torch.ones((batch_size, self.sequence_length), dtype=torch.bool) + result["lang_code"] = lang_code + + # Process labels: shift them to the left + if current_pp_rank == self.output_pp_rank: + result["label_ids"] = input_ids[:, 1:] + result["label_mask"] = torch.ones((batch_size, self.sequence_length), dtype=torch.bool) + + if isinstance(result["input_ids"], torch.Tensor) and result["input_ids"].shape[-1] != self.sequence_length: + raise ValueError( + f"`labels` are incorrectly preprocessed. `labels` length is {result['input_ids'].shape[-1]}, but should be" + f" {self.sequence_length}." + ) + if isinstance(result["label_ids"], torch.Tensor) and result["label_ids"].shape[-1] != self.sequence_length: + raise ValueError( + f"`labels` are incorrectly preprocessed. `labels` length is {result['label_ids'].shape[-1]}, but should be" + f" {self.sequence_length}." + ) + + return result diff --git a/src/nanotron/data/dataloader_builder.py b/src/nanotron/data/dataloader_builder.py index 9d3285f6..f9480029 100644 --- a/src/nanotron/data/dataloader_builder.py +++ b/src/nanotron/data/dataloader_builder.py @@ -1,6 +1,6 @@ import nanotron.distributed as dist from nanotron import logging -from nanotron.data.collator import NanosetDataCollatorForCLM +from nanotron.data.collator import MultilingualNanosetDataCollatorForCLM, NanosetDataCollatorForCLM from nanotron.dataloader import ( EmptyInfiniteDataset, get_dataloader_worker_init, @@ -20,9 +20,11 @@ def build_nanoset_dataloader( output_pp_rank: int, micro_batch_size: int, dataloader_num_workers: int, + is_multilingual: bool = False, consumed_train_samples: int = 0, dataloader_drop_last: bool = True, dataloader_pin_memory: bool = True, + shuffle: bool = False, ) -> DataLoader: # Case of ranks not requiring data. 
We give them a dummy dataset, then the collator will do his job @@ -39,6 +41,14 @@ def build_nanoset_dataloader( parallel_context=parallel_context, ) + if is_multilingual: + data_collator = MultilingualNanosetDataCollatorForCLM( + sequence_length=sequence_length, + input_pp_rank=input_pp_rank, + output_pp_rank=output_pp_rank, + parallel_context=parallel_context, + ) + # Compute size and rank of dataloader workers dp_ranks_size = parallel_context.dp_pg.size() dp_rank = parallel_context.dp_pg.rank() @@ -49,7 +59,7 @@ def build_nanoset_dataloader( dl_rank=dp_rank, drop_last=dataloader_drop_last, consumed_train_samples=consumed_train_samples, - shuffle=False, + shuffle=shuffle, ) return DataLoader( diff --git a/src/nanotron/data/multilingual_nanoset.py b/src/nanotron/data/multilingual_nanoset.py index 7af57448..8eec5549 100644 --- a/src/nanotron/data/multilingual_nanoset.py +++ b/src/nanotron/data/multilingual_nanoset.py @@ -30,7 +30,6 @@ def __init__( dataset_folders: List[str], sequence_length: int, token_size: int, - dataset_tokens: List[int], train_split_num_samples: int = None, is_valid: bool = False, dataset_weights: Union[List[float], None] = None, @@ -47,7 +46,6 @@ def __init__( self.sequence_length = sequence_length self.token_size = token_size self.train_split_num_samples = train_split_num_samples - self.dataset_tokens = dataset_tokens self.is_valid = is_valid self.random_seed = random_seed self.datatrove_datasets = [] @@ -107,7 +105,7 @@ def __getitem__(self, idx: int) -> Dict[str, np.ndarray]: dataset_sample = self.dataset_sample_index[idx] tokens = self.datatrove_datasets[dataset][dataset_sample] - tokens["input_ids"][0] = self.dataset_tokens[dataset] # Prepend language token + tokens["lang_code"] = torch.tensor(dataset, dtype=torch.long) return tokens diff --git a/src/nanotron/distributed.py b/src/nanotron/distributed.py index 0156b1bb..0bc54f3e 100644 --- a/src/nanotron/distributed.py +++ b/src/nanotron/distributed.py @@ -52,10 +52,6 @@ def all_gather_into_tensor( # pylint: disable=function-redefined if group is None: group = dist.torch_dist.distributed_c10d._get_default_group() - assert ( - group.size() > 1 - ), "You should probably not call `all_gather_into_tensor` with a single rank, as it copies data over" - if torch_version_above_1_13: return dist.all_gather_into_tensor( output_tensor=output_tensor, input_tensor=input_tensor, group=group, async_op=async_op diff --git a/src/nanotron/models/llama.py b/src/nanotron/models/llama.py index 28a2e30f..ecb26fd2 100644 --- a/src/nanotron/models/llama.py +++ b/src/nanotron/models/llama.py @@ -757,14 +757,20 @@ def forward( self, input_ids: Union[torch.Tensor, TensorPointer], # [batch_size, seq_length] input_mask: Union[torch.Tensor, TensorPointer], # [batch_size, seq_length] + lang_code: Union[torch.Tensor, TensorPointer], # [batch_size, 1] ): - return self.forward_with_hidden_states(input_ids=input_ids, input_mask=input_mask)[0] + return self.forward_with_hidden_states(input_ids=input_ids, input_mask=input_mask, lang_code=lang_code)[0] def forward_with_hidden_states( self, input_ids: Union[torch.Tensor, TensorPointer], # [batch_size, seq_length] input_mask: Union[torch.Tensor, TensorPointer], # [batch_size, seq_length] + lang_code: Union[torch.Tensor, TensorPointer], # [batch_size, 1] ): + # NOTE(tj.solergibert) I bring `lang_code` till the forward of LlamaModel. 
Remember that + # to use it in the different pipeline blocks you need to also set the module_input_keys & module_output_keys + # of the necessary `PipelineBlock`'s defined in the LlamaModel init! + # all tensors are optional as most ranks don't need anything from the dataloader. output = self.token_position_embeddings(input_ids=input_ids, input_mask=input_mask) @@ -825,7 +831,9 @@ def get_flops_per_sec(self, iteration_time_in_sec, sequence_length, global_batch @torch.jit.script def masked_mean(loss, label_mask, dtype): # type: (Tensor, Tensor, torch.dtype) -> Tensor - return (loss * label_mask).sum(dtype=dtype) / label_mask.sum() + return (loss * label_mask).sum(dim=1, dtype=dtype) / label_mask.sum( + dim=1 + ) # NOTE(tj.solergibert) Added dim=1 to return a tensor with shape [Batch size, 1] instead of [1] class Loss(nn.Module): @@ -842,14 +850,18 @@ def forward( # Megatron by defaults cast everything in fp32. `--f16-lm-cross-entropy` is an option you can use to keep current precision. # https://github.com/NVIDIA/Megatron-LM/blob/f267e6186eae1d6e2055b412b00e2e545a8e896a/megatron/model/gpt_model.py#L38 - loss = sharded_cross_entropy( + sample_loss = sharded_cross_entropy( sharded_logits, label_ids.transpose(0, 1).contiguous(), group=self.tp_pg, dtype=torch.float ).transpose(0, 1) # TODO @thomasw21: It's unclear what kind of normalization we want to do. - loss = masked_mean(loss, label_mask, dtype=torch.float) - # I think indexing causes a sync we don't actually want - # loss = loss[label_mask].sum() - return {"loss": loss} + sample_loss = masked_mean(sample_loss, label_mask, dtype=torch.float) + # NOTE(tj.solergibert) masked_mean returns a single scalar with the batch loss. We've changed it to compute the SAMPLE loss. + # We will continue using "loss" as the batch loss but we add "sample_loss" for the multilingual effort. 
+ # WARN(tj.solergibert) Don't panic, the batch loss used to update the parameters is computed in `LlamaForTraining` + + # TODO @thomasw21: I think indexing causes a sync we don't actually want + # TODO @thomasw21: loss = loss[label_mask].sum() + return {"sample_loss": sample_loss} class LlamaForTraining(NanotronModel): @@ -871,7 +883,7 @@ def __init__( "label_ids", "label_mask", }, - module_output_keys={"loss"}, + module_output_keys={"sample_loss"}, ) self.parallel_context = parallel_context self.config = config @@ -881,19 +893,22 @@ def forward( self, input_ids: Union[torch.Tensor, TensorPointer], input_mask: Union[torch.Tensor, TensorPointer], + lang_code: Union[torch.Tensor, TensorPointer], label_ids: Union[torch.Tensor, TensorPointer], label_mask: Union[torch.Tensor, TensorPointer], ) -> Dict[str, Union[torch.Tensor, TensorPointer]]: sharded_logits = self.model( input_ids=input_ids, input_mask=input_mask, + lang_code=lang_code, ) - loss = self.loss( + outputs = self.loss( sharded_logits=sharded_logits, label_ids=label_ids, label_mask=label_mask, - )["loss"] - return {"loss": loss} + ) + outputs["loss"] = torch.mean(outputs["sample_loss"]) + return outputs @torch.no_grad() def init_model_randomly(self, config: Config): diff --git a/src/nanotron/parallel/pipeline_parallel/engine.py b/src/nanotron/parallel/pipeline_parallel/engine.py index ca9df312..9b548e35 100644 --- a/src/nanotron/parallel/pipeline_parallel/engine.py +++ b/src/nanotron/parallel/pipeline_parallel/engine.py @@ -2,6 +2,9 @@ from typing import Dict, Iterable, Optional, Union import torch +from torch import nn as torch_nn +from torch.nn.parallel import DistributedDataParallel + from nanotron import distributed as dist from nanotron import logging from nanotron.distributed import ProcessGroup @@ -9,11 +12,9 @@ from nanotron.optim.gradient_accumulator import GradientAccumulator from nanotron.parallel.data_parallel.utils import ddp_trigger_sync_in_bwd from nanotron.parallel.pipeline_parallel.context_manager import attach_pipeline_state_to_model -from nanotron.parallel.pipeline_parallel.state import PipelineTrainBatchState +from nanotron.parallel.pipeline_parallel.state import PipelineEvalBatchState, PipelineTrainBatchState from nanotron.parallel.pipeline_parallel.tensor_pointer import TensorPointer from nanotron.utils import ContextManagers -from torch import nn as torch_nn -from torch.nn.parallel import DistributedDataParallel logger = logging.get_logger(__name__) @@ -29,6 +30,7 @@ def forward( state: PipelineTrainBatchState, micro_batch: Dict[str, Union[torch.Tensor, TensorPointer]], model: torch_nn.Module, + is_validation: bool = False, ) -> Dict[str, Union[torch.Tensor, TensorPointer]]: # Increment the number of backwards state.nb_forwards += 1 @@ -52,7 +54,7 @@ def forward( output["loss"] = output["loss"] / self.nb_microbatches # Add output as activations that require backward pass - if not isinstance(output["loss"], TensorPointer): + if not isinstance(output["loss"], TensorPointer) and not is_validation: assert output["loss"].requires_grad state.register_activation_requiring_backward(output["loss"]) return output @@ -134,16 +136,19 @@ def validate_batch_iter( nb_microbatches: int, ) -> Iterable[Dict[str, Union[torch.Tensor, TensorPointer]]]: # Assign a new state for the current batch - state = PipelineTrainBatchState() # TODO: do i need state? 
+ state = PipelineEvalBatchState() self.nb_microbatches = nb_microbatches outputs = [] + lang_codes = [] with attach_pipeline_state_to_model(model=model, pipeline_state=state): # All forward for micro_batch in batch: context = self._get_fwd_context(model=model) - output = self.forward(context=context, state=state, micro_batch=micro_batch, model=model) + output = self.forward( + context=context, state=state, micro_batch=micro_batch, model=model, is_validation=True + ) # TODO @thomasw21: Somehow this needs to be done somewhere else to support interleaving. Somewhere right after a "stage" for _ in range(len(state.microbatches_activations_to_send)): send_activation = state.microbatches_activations_to_send.popleft() @@ -157,9 +162,13 @@ def validate_batch_iter( # Store the loss for each microbatch if not isinstance(output["loss"], TensorPointer): output = {k: v.detach() for k, v in output.items()} - outputs.append(output) - return outputs + outputs.extend( + list(output["sample_loss"]) + ) # NOTE(tj.solergibert) Yes, it might look useless to do list + extend but it's necessary to split the output["sample_loss"] tensor into multiple tensors + lang_codes.extend(micro_batch["lang_code"].flatten().tolist()) + + return outputs, lang_codes class AllForwardAllBackwardPipelineEngine(PipelineEngine): diff --git a/src/nanotron/parallel/pipeline_parallel/state.py b/src/nanotron/parallel/pipeline_parallel/state.py index e07cc89a..f22d6571 100644 --- a/src/nanotron/parallel/pipeline_parallel/state.py +++ b/src/nanotron/parallel/pipeline_parallel/state.py @@ -4,6 +4,7 @@ from typing import List import torch + from nanotron import distributed as dist from nanotron import logging from nanotron.logging import log_rank @@ -203,6 +204,9 @@ class PipelineEvalBatchState(PipelineBatchState): microbatches_activations_to_recv = collections.deque() activations_buffer = collections.deque() + # Reinitialise counter + nb_forwards = 0 + def register_activation_requiring_backward(self, activation: torch.Tensor): pass diff --git a/src/nanotron/serialize/metadata.py b/src/nanotron/serialize/metadata.py index 0d8708f9..4bd36c19 100644 --- a/src/nanotron/serialize/metadata.py +++ b/src/nanotron/serialize/metadata.py @@ -46,6 +46,8 @@ class TrainingMetadata: last_stage_idx: Optional[int] = None data_stages: Optional[List[DataStageMetadata]] = None + last_validation_stage_idx: Optional[int] = None + def __post_init__(self): # NOTE: this is a sanity check after loading a trained checkpoint total_consumed_samples_across_stages = sum(stage.consumed_train_samples for stage in self.data_stages) diff --git a/src/nanotron/trainer.py b/src/nanotron/trainer.py index 61c0aabc..25c4d315 100644 --- a/src/nanotron/trainer.py +++ b/src/nanotron/trainer.py @@ -80,6 +80,7 @@ from nanotron.sanity_checks import ( after_optim_step_sanity_checks, after_tbi_sanity_checks, + assert_tensor_synced_across_pg, before_optim_step_sanity_checks, before_tbi_sanity_checks, ) @@ -232,7 +233,11 @@ def __init__( for stage in self.config.data_stages ] self.metadata: TrainingMetadata = TrainingMetadata( - consumed_train_samples=0, last_train_step=0, last_stage_idx=0, data_stages=data_stages + consumed_train_samples=0, + last_train_step=0, + last_stage_idx=0, + data_stages=data_stages, + last_validation_stage_idx=0, ) # Setup tensorboard write and log writers on output rank @@ -254,6 +259,8 @@ def __init__( self.limit_val_batches = self.config.tokens.limit_val_batches # NOTE: the dataloader currently in use for the current training stage self.current_dataloader: 
Optional[DataLoader] = None + # NOTE: the dataloader currently in use for the current validation stage + self.current_validation_dataloader: Optional[DataLoader] = None self.post_init() @@ -301,6 +308,106 @@ def _print_training_plan(self): ) log_rank(full_log_message, logger=logger, level=logging.INFO, rank=0) + def _prepare_dataloader_for_validation_stage(self, dataloaders: Union[List[DataLoader], DataLoader]): + # NOTE(tj.solergibert) Similar to _update_dataloader_based_on_training_stages BUT: + # 1. We call this function EVERY TIME we run the validation loop + # 2. Every time it returns a NEW validation iterator DataLoader. If you don't do this you'll consume the whole validation dataset + # in the first iteration and subsequent validations will fail + # `dataloaders` are either torch DataLoaders (the very first stage) OR functions that we call later that provide torch DataLoaders (subsequent stages) + # From this torch DataLoaders objects we then call `sanity_check_dataloader` that will return a iterator. + # In short, `sanity_check_dataloader` just places the input tensors in the GPU when necessary (TensorPointers stay in the CPU) + # + # TBH, the for loop below it's just for deleting the DataLoaders of previous stages, which is not so problematic. The important part is returning the + # DataLoader iterator every time we call this function from the current training stage, which is tracked during training + # + # Also, keep in mind that if val_check_interval = 5 & data.start_training_step = 10 we will already perform the evaluation with the SECOND data stage + # after just training for the current iteration, so it might not be a good idea to set evals during the stage in which we change of data stage + # + # NOTE(tj.solergibert) Further investigation should be done, but there is a extrange behaiviour when deleting the DataLoaders////lambda functs. 
As they + # are converted into Iterators with `sanity_check_dataloader` we can't access anymore the DataLoader object to del the dataset (After first stage, + # in this function we locally create the DataLoder from the lambda func --> Return Iterator) + # + # Also when the gc deletes the first stage dataloader, all the `DatatroveFileDataset._f` are already None AND the `del` thing are deleting a copy of the + # object, not the object itself + # + # FINAL NOTE(tj.solergibert) I will open a Issue in nanotron to check with them if they are aware of this useless deletitions + # + # TODO(tj.solergibert) Check the tuple case below + from collections.abc import Generator + + if not hasattr(self.config, "data_stages") or self.config.data_stages is None: + + if isinstance(dataloaders, tuple): # TODO(tj.solergibert) Check this tuple case + dataloader = dataloaders[0] + else: + dataloader = dataloaders + + self.current_validation_dataloader_lenght = len(dataloader) + self.current_validation_dataloader = sanity_check_dataloader( + dataloader=dataloader, parallel_context=self.parallel_context, config=self.config + ) + + return + elif isinstance(dataloaders, Generator): + # TODO(xrsrke): this is a hacky way to handle DoReMi's dataloader + # remove this in the next PR + self.current_validation_dataloader = dataloaders + return + + assert len(dataloaders) > 0, "No dataloaders provided" + assert len(dataloaders) == len( + self.config.data_stages + ), "Number of dataloaders should match the number of dataset stages" + + def clear_dataloader_from_memory(dataloader: DataLoader, stage_name: str, prev_stage_name: str): + import gc + + log_rank( + f"[Validation Stage: {stage_name}] Clearing the previous validation stage's ({prev_stage_name}) dataloader and dataset from memory", + logger=logger, + level=logging.INFO, + ) + + # NOTE: Clear dataloader from memory + del dataloader.dataset + del dataloader.sampler + del dataloader.batch_sampler + + gc.collect() + + for stage_idx, stage in enumerate(self.config.data_stages): + if stage_idx < self.metadata.last_stage_idx: + continue + # NOTE(tj.solergibert) From this point stage_idx = self.metadata.last_stage_idx. We update self.metadata.last_stage_idx (which keeps track of the training stage) + # in each and every training step. 
+ + if ( + stage_idx is not self.metadata.last_validation_stage_idx + ): # When stage_idx (= self.metadata.last_stage_idx, the training stage index) is different than the last validation stage index + self.metadata.last_validation_stage_idx = stage_idx # Update validation stage index + # Delete previous stage DataLoader + prev_stage_name = self.config.data_stages[stage_idx - 1].name + prev_dataloader = dataloaders[prev_stage_name] + + if isinstance(prev_dataloader, DataLoader): + # NOTE: we don't need to clear dummy data generator from memory + clear_dataloader_from_memory( + prev_dataloader, stage_name=stage.name, prev_stage_name=prev_stage_name + ) + + self.metadata.last_validation_stage_idx = stage_idx # Update validation stage index + + # NOTE(tj.solergibert) Create AGAIN the DataLoader + dataloader = dataloaders[stage.name] + # NOTE: if a dataloader is lazy initialized, we need to call it to initialize it + dataloader = dataloader() if callable(dataloader) else dataloader + break + + self.current_validation_dataloader_lenght = len(dataloader) + self.current_validation_dataloader = sanity_check_dataloader( + dataloader=dataloader, parallel_context=self.parallel_context, config=self.config + ) # NOTE(tj.solergibert) Create a Iterator from the DataLoader + def _update_dataloader_based_on_training_stages(self, dataloaders: Union[List[DataLoader], DataLoader]): from collections.abc import Generator @@ -325,11 +432,11 @@ def _update_dataloader_based_on_training_stages(self, dataloaders: Union[List[Da self.config.data_stages ), "Number of dataloaders should match the number of dataset stages" - def clear_dataloader_from_memory(dataloader: DataLoader, stage_name: str): + def clear_dataloader_from_memory(dataloader: DataLoader, stage_name: str, prev_stage_name: str): import gc log_rank( - f"[Training Stage: {stage_name}] Clearing the previous training stage's dataloader and datasets from memory", + f"[Training Stage: {stage_name}] Clearing the previous training stage's ({prev_stage_name}) dataloader and datasets from memory", logger=logger, level=logging.INFO, ) @@ -366,7 +473,9 @@ def find_stage_idx_to_resume(): if isinstance(prev_dataloader, DataLoader): # NOTE: we don't need to clear dummy data generator from memory - clear_dataloader_from_memory(prev_dataloader, stage_name=stage.name) + clear_dataloader_from_memory( + prev_dataloader, stage_name=stage.name, prev_stage_name=prev_stage_name + ) self.metadata.last_stage_idx = stage_idx @@ -432,6 +541,19 @@ def train( # Training step outputs, loss_avg = self.training_step(dataloader=self.current_dataloader) + self.training_step_time = time.time() + + # Validation stage + if self.iteration_step % self.config.tokens.val_check_interval == 0: + self._prepare_dataloader_for_validation_stage(valid_dataloader_or_dls) + val_global_loss, val_lang_losses = self.validation_step( + dataloader=self.current_validation_dataloader + ) + self.validation_step_time = time.time() + else: + # NOTE(tj.solergibert) As we are reporting the training & validation metrics together, we + # must comply with val_check_interval % iteration_step_info_interval = 0 + val_global_loss, val_lang_losses = None, None # Training Logs # TODO(xrsrke): refactor using callbacks would be better @@ -442,7 +564,7 @@ def train( ].consumed_train_samples += self.global_batch_size if (self.iteration_step - 1) % self.config.logging.iteration_step_info_interval == 0: - self.train_step_logs(outputs=outputs, loss_avg=loss_avg) + self.train_step_logs(loss_avg=loss_avg, global_loss=val_global_loss, 
lang_losses=val_lang_losses) # Checkpoint if self.iteration_step % self.config.checkpoints.checkpoint_interval == 0: @@ -550,22 +672,71 @@ def training_step( return outputs, loss_avg def validation_step(self, dataloader: Iterator[Dict[str, Union[torch.Tensor, TensorPointer]]]) -> Iterable[Dict]: - outputs = self.pipeline_engine.validate_batch_iter( + outputs, lang_codes = self.pipeline_engine.validate_batch_iter( model=self.model, - batch=(next(dataloader) for _ in range(self.limit_val_batches)), - nb_microbatches=self.limit_val_batches, + batch=(next(dataloader) for _ in range(self.current_validation_dataloader_lenght)), + nb_microbatches=self.current_validation_dataloader_lenght, ) - return outputs + + lang_losses = { + lang: [] for lang in self.config.data_stages[self.metadata.last_stage_idx].data.dataset.languages + } + lang_losses_list = list(lang_losses.keys()) + + # Compute losses + if isinstance(outputs[0], torch.Tensor): + # Multilingual losses + for loss, lang_code in zip(outputs, lang_codes): + lang_losses[lang_losses_list[lang_code]].append(loss) + # Global loss + global_loss_avg = torch.mean(torch.stack(outputs)) + # Sync multilingual losses across DP + for lang in lang_losses.keys(): + if not lang_losses[ + lang + ]: # If the list is empty --> Set local language loss to -1 to exclude it from the global computation + lang_losses[lang] = torch.tensor(-1, dtype=torch.float32) + else: # If we have at least 1 loss from a given language --> compute local language loss mean + lang_losses[lang] = torch.mean(torch.stack(lang_losses[lang])) + + # NOTE(tj.solergibert) We create a (DP SIZE, LANGS) tensor to aggregate ALL local losses across DP groups. + # Then we compute the mean of each lang in each and every rank and finally copy back the result to the + # `lang_losses` dict for logging + lang_losses_tensor_out = torch.zeros( + (self.parallel_context.dp_pg.size(), len(lang_losses.keys())), dtype=torch.float, device="cuda" + ) # (DP SIZE, LANGS) + lang_losses_tensor_local = torch.stack(list(lang_losses.values())).unsqueeze(0) # (1, LANGS) + dist.all_gather_into_tensor(lang_losses_tensor_out, lang_losses_tensor_local, self.parallel_context.dp_pg) + mask = lang_losses_tensor_out != -1 + lang_losses_tensor_local = (lang_losses_tensor_out * mask).sum(dim=0) / mask.sum(dim=0) # (1, LANGS) + for idx, lang in enumerate(lang_losses.keys()): + lang_losses[lang] = lang_losses_tensor_local[idx] + + # Sync global losses across DP + dist.all_reduce(global_loss_avg, group=self.parallel_context.dp_pg, op=dist.ReduceOp.AVG) + + # TODO(tj.solergibert) Delete this testing assertions + for lang in lang_losses.keys(): + assert_tensor_synced_across_pg(tensor=lang_losses[lang], pg=self.parallel_context.dp_pg) + assert_tensor_synced_across_pg(tensor=global_loss_avg, pg=self.parallel_context.dp_pg) + + else: + global_loss_avg = None + lang_losses = None + + return global_loss_avg, lang_losses def train_step_logs( self, - outputs: Iterable[Dict[str, Union[torch.Tensor, TensorPointer]]], loss_avg: Optional[torch.Tensor], + global_loss: torch.Tensor, + lang_losses: torch.Tensor, ) -> None: # TODO @nouamanetazi: Megatron-LM seems to be using a barrier to report their interval time. Check if this is necessary. 
https://github.com/NouamaneTazi/Megatron-LM/blob/e241a96c3085b18e36c6cee1d68a8155de77b5a6/megatron/training.py#L607 dist.barrier() torch.cuda.synchronize() - elapsed_time_per_iteration_ms = (time.time() - self.iteration_start_time) * 1000 + # Training metrics + elapsed_time_per_iteration_ms = (self.training_step_time - self.iteration_start_time) * 1000 tokens_per_sec = ( self.global_batch_size * self.sequence_length / (elapsed_time_per_iteration_ms / 1000) ) # tokens_per_sec is calculated using sequence_length @@ -575,13 +746,27 @@ def train_step_logs( global_batch_size=self.global_batch_size, ) + # Validation metrics + if global_loss is not None: + validation_total_samples = self.current_validation_dataloader_lenght * self.micro_batch_size + validation_elapsed_time_per_iteration_ms = (self.validation_step_time - self.training_step_time) * 1000 + validation_tokens_per_sec = ( + validation_total_samples * self.sequence_length / (validation_elapsed_time_per_iteration_ms / 1000) + ) + + validation_model_tflops, validation_hardware_tflops = self.unwrapped_model.get_flops_per_sec( + iteration_time_in_sec=validation_elapsed_time_per_iteration_ms / 1000, + sequence_length=self.sequence_length, + global_batch_size=validation_total_samples, + ) + if dist.get_rank(self.parallel_context.world_pg) in self.logger_ranks: assert self.loggerwriter is not None, "loggerwriter should be defined on logger ranks" + # Training metrics lr = self.lr_scheduler.get_last_lr()[0] log_entries = [ - # LogItem("consumed_samples", self.consumed_train_samples, "human_format"), # , "12d"), LogItem( "consumed_tokens", self.metadata.consumed_train_samples * self.config.tokens.sequence_length, @@ -602,6 +787,46 @@ def train_step_logs( if self.config.optimizer.clip_grad is not None: log_entries.append(LogItem("grad_norm", self.grad_norm_unclipped.item(), "human_format")) # , ".3f")) + # Validation metrics + if global_loss is not None: + log_entries.extend( + [ + LogItem( + "validation_consumed_tokens", + validation_total_samples * self.sequence_length, + "human_format", + ), # , "12d"), + LogItem( + "validation_elapsed_time_per_iteration_ms", + validation_elapsed_time_per_iteration_ms, + "human_format", + ), # , ".1f"), + LogItem("validation_tokens_per_sec", validation_tokens_per_sec, "human_format"), # , "1.6E"), + LogItem( + "validation_tokens_per_sec_per_gpu", + validation_tokens_per_sec / self.parallel_context.world_pg.size(), + "human_format", + ), # , "1.6E"), + LogItem("validation_loss", global_loss.item(), "human_format"), # , "1.6E"), + LogItem( + "validation_model_tflops_per_gpu", validation_model_tflops / 3, "human_format" + ), # , ".2f"), # NOTE(tj.solergibert) Check llama.py --> def get_flops() --> model_flops for explanation of the / 3 factor + LogItem( + "validation_hardware_tflops_per_gpu", validation_hardware_tflops / 3, "human_format" + ), # , ".2f"), # NOTE(tj.solergibert) Check llama.py --> def get_flops() --> model_flops for explanation of the / 3 factor + ] + ) + + # NOTE Currently you have to log each lang metric one by one and then merge them manually in the same plot through the wandb UI. 
+ # Example: https://community.wandb.ai/t/log-multiple-variables-at-the-same-plot/2474 + # GitHub complains: https://github.com/wandb/wandb/issues/3035 + log_entries.extend( + [ + LogItem(f"{lang}_validation_loss", loss.item(), "human_format") + for lang, loss in lang_losses.items() + ] + ) + # Log not too often the memory if self.iteration_step < 5 or (self.iteration_step - 1) % self.config.checkpoints.checkpoint_interval == 0: total, used, free = shutil.disk_usage("/") From 94f3da5a51781a3984ea5fc0e81923af3ce56020 Mon Sep 17 00:00:00 2001 From: Negar Foroutan Eghlidi Date: Thu, 22 Aug 2024 15:17:58 +0200 Subject: [PATCH 43/44] Fix a device issue. --- src/nanotron/trainer.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/nanotron/trainer.py b/src/nanotron/trainer.py index 25c4d315..47d83fbe 100644 --- a/src/nanotron/trainer.py +++ b/src/nanotron/trainer.py @@ -695,7 +695,7 @@ def validation_step(self, dataloader: Iterator[Dict[str, Union[torch.Tensor, Ten if not lang_losses[ lang ]: # If the list is empty --> Set local language loss to -1 to exclude it from the global computation - lang_losses[lang] = torch.tensor(-1, dtype=torch.float32) + lang_losses[lang] = torch.tensor(-1, dtype=torch.float32, device="cuda") else: # If we have at least 1 loss from a given language --> compute local language loss mean lang_losses[lang] = torch.mean(torch.stack(lang_losses[lang])) From 4d6c1d318ab7e16024b10b868d056cdeaf7a3c75 Mon Sep 17 00:00:00 2001 From: Negar Foroutan Eghlidi Date: Sun, 8 Sep 2024 11:44:09 +0200 Subject: [PATCH 44/44] Make it compatible with multilingual-lighteval. --- src/nanotron/config/lighteval_config.py | 1 + src/nanotron/models/llama.py | 2 +- 2 files changed, 2 insertions(+), 1 deletion(-) diff --git a/src/nanotron/config/lighteval_config.py b/src/nanotron/config/lighteval_config.py index b5f12059..208091c9 100644 --- a/src/nanotron/config/lighteval_config.py +++ b/src/nanotron/config/lighteval_config.py @@ -51,6 +51,7 @@ def __post_init__(self): class LightEvalTasksArgs: """Arguments related to tasks for LightEval""" + langs: Optional[str] = None tasks: Optional[str] = None custom_tasks: Optional[str] = None max_samples: Optional[int] = None diff --git a/src/nanotron/models/llama.py b/src/nanotron/models/llama.py index ecb26fd2..2c6ddc01 100644 --- a/src/nanotron/models/llama.py +++ b/src/nanotron/models/llama.py @@ -757,7 +757,7 @@ def forward( self, input_ids: Union[torch.Tensor, TensorPointer], # [batch_size, seq_length] input_mask: Union[torch.Tensor, TensorPointer], # [batch_size, seq_length] - lang_code: Union[torch.Tensor, TensorPointer], # [batch_size, 1] + lang_code: Union[torch.Tensor, TensorPointer]=None, # [batch_size, 1] ): return self.forward_with_hidden_states(input_ids=input_ids, input_mask=input_mask, lang_code=lang_code)[0]
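A note on the validation dataloader handling in the trainer changes above: `_prepare_dataloader_for_validation_stage` deliberately rebuilds the iterator on every validation pass because a consumed iterator silently yields nothing the second time around. The snippet below is a standalone toy sketch of that failure mode (plain `torch.utils.data.DataLoader`, no nanotron code), not part of the patches.

```
# Toy reproduction of the exhaustion problem the NOTE comments describe: iterating a
# DataLoader iterator a second time yields nothing, so a fresh iterator is needed per pass.
from torch.utils.data import DataLoader

loader = DataLoader(list(range(8)), batch_size=4)

it = iter(loader)
print([b.tolist() for b in it])  # first validation pass: [[0, 1, 2, 3], [4, 5, 6, 7]]
print([b.tolist() for b in it])  # same iterator, second pass: [] (already consumed)

it = iter(loader)                # re-creating the iterator restores a full pass
print([b.tolist() for b in it])  # [[0, 1, 2, 3], [4, 5, 6, 7]]
```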
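Relatedly, `validation_step` feeds the pipeline engine a generator expression that pulls exactly `nb_microbatches` items from that iterator. A minimal, nanotron-free illustration of the pattern:

```
# The bounded generator draws a fixed number of items and leaves the source iterator
# positioned right after them, which is how the validation loop sizes its microbatches.
data = iter(range(10))      # stands in for the validation dataloader iterator
nb_microbatches = 4

batch = (next(data) for _ in range(nb_microbatches))
print(list(batch))          # [0, 1, 2, 3]
print(next(data))           # 4 -- the source iterator resumes where the generator stopped
```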
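Finally, the per-language validation losses are synced across data-parallel ranks with a sentinel: a rank that saw no samples for a language reports -1, and a masked mean over the gathered (DP_SIZE, LANGS) tensor ignores those entries (which is also why PATCH 43 moves the sentinel tensor to CUDA before the all-gather). Below is a minimal sketch of just the masked-mean step on an already-gathered CPU tensor; `aggregate_lang_losses` is an illustrative name rather than a trainer method, and the `clamp` guard is an extra safety net for the case where no rank contributed.

```
# Masked mean over gathered per-language losses, ignoring the -1 sentinel entries.
import torch

def aggregate_lang_losses(gathered: torch.Tensor) -> torch.Tensor:
    """gathered: (DP_SIZE, LANGS), where -1 means 'no samples for this language on this rank'."""
    mask = gathered != -1                  # True where a rank reported a real loss
    summed = (gathered * mask).sum(dim=0)  # sum of real losses per language
    counts = mask.sum(dim=0).clamp(min=1)  # avoid 0-division if a language had no contributions
    return summed / counts                 # per-language mean over contributing ranks only

# 2 DP ranks, 3 languages; rank 1 had no batches for the last language.
gathered = torch.tensor([[1.0, 2.0, 3.0],
                         [2.0, 4.0, -1.0]])
print(aggregate_lang_losses(gathered))     # tensor([1.5000, 3.0000, 3.0000])
```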