diff --git a/docs/nanoset.md b/docs/nanoset.md
index 9dce21b7..61393438 100644
--- a/docs/nanoset.md
+++ b/docs/nanoset.md
@@ -79,7 +79,7 @@ To work with `Nanosets`, we just need to configure 1 argument:
 Finally, to use the `Nanosets`, launch the training with [`run_train.py`](../run_train.py).
 
 ```shell
-torchrun --nproc-per-node 8 run_train.py --config configs/config_nanoset.yaml
+torchrun --nproc-per-node 1 run_train.py --config examples/config_nanoset.yaml
 ```
 
 ## Under the hood
diff --git a/src/nanotron/config/config.py b/src/nanotron/config/config.py
index 822be7a4..ad80b82a 100644
--- a/src/nanotron/config/config.py
+++ b/src/nanotron/config/config.py
@@ -101,12 +101,16 @@ class NanosetDatasetsArgs:
     dataset_folder: Union[str, dict, List[str]]
 
     def __post_init__(self):
-        if isinstance(self.dataset_folder, str):  # Case 1: 1 Dataset file
+        if isinstance(self.dataset_folder, str):  # Case 1: 1 Dataset folder
             self.dataset_folder = [self.dataset_folder]
             self.dataset_weights = [1]
-        elif isinstance(self.dataset_folder, List):  # Case 2: > 1 Dataset file
-            self.dataset_weights = None  # Set to None so we consume all the samples randomly
-        elif isinstance(self.dataset_folder, dict):  # Case 3: dict with > 1 dataset_folder and weights
+        elif isinstance(self.dataset_folder, List):  # Case 2: > 1 Dataset folder
+            self.dataset_weights = (
+                None  # Set to None so we consume all the samples randomly
+            )
+        elif isinstance(
+            self.dataset_folder, dict
+        ):  # Case 3: dict with > 1 dataset_folder and weights
             tmp_dataset_folder = self.dataset_folder.copy()
             self.dataset_folder = list(tmp_dataset_folder.keys())
             self.dataset_weights = list(tmp_dataset_folder.values())
@@ -116,7 +120,9 @@ def __post_init__(self):
 class MultilingualNanosetDatasetsArgs:
     training_folder: Union[str, dict, List[str]]
     validation_folder: Union[str, List[str]]
-    languages: List[str]  # NOTE(tj.solergibert) Required for 1. Aggregating the result 2. Reporting to WANDB
+    languages: List[
+        str
+    ]  # NOTE(tj.solergibert) Required for 1. Aggregating the result 2. Reporting to WANDB
 
     def __post_init__(self):
         if isinstance(self.training_folder, str):  # Case 1: 1 Dataset folder
@@ -124,8 +130,20 @@ def __post_init__(self):
             self.validation_folder = [self.validation_folder]
             self.dataset_weights = [1]
         elif isinstance(self.training_folder, List):  # Case 2: > 1 Dataset folder
-            self.dataset_weights = None  # Set to None so we consume all the samples randomly
-        elif isinstance(self.training_folder, dict):  # Case 3: dict with > 1 training_folder and weights
+            self.dataset_weights = (
+                None  # Set to None so we consume all the samples randomly
+            )
+        elif isinstance(
+            self.training_folder, dict
+        ):  # Case 3: dict with > 1 training_folder and weights
             tmp_training_folder = self.training_folder.copy()
             self.training_folder = list(tmp_training_folder.keys())
             self.dataset_weights = list(tmp_training_folder.values())
+
+        assert len(self.training_folder) == len(
+            self.languages
+        ), f"The sizes of training_folder and languages mismatch ({len(self.training_folder)} vs {len(self.languages)})"
+
+        assert len(self.training_folder) == len(
+            self.validation_folder
+        ), f"The sizes of training_folder and validation_folder mismatch ({len(self.training_folder)} vs {len(self.validation_folder)})"
@@ -167,7 +210,9 @@ class DatasetStageArgs:
 
     def __post_init__(self):
         if self.start_training_step < 0:
-            raise ValueError(f"training_steps should be a positive integer and not {self.start_training_step}")
+            raise ValueError(
+                f"training_steps should be a positive integer and not {self.start_training_step}"
+            )
 
 
 @dataclass
@@ -182,6 +227,7 @@ class CheckpointsArgs:
     checkpoints_path: Path
     checkpoint_interval: int
     save_initial_state: Optional[bool] = False
+    save_final_state: Optional[bool] = False
     resume_checkpoint_path: Optional[Path] = None
     checkpoints_path_is_shared_file_system: Optional[bool] = False
 
@@ -387,13 +433,19 @@ def __post_init__(self):
         if self.profiler is not None and self.profiler.profiler_export_path is not None:
             assert self.tokens.train_steps < 10
 
-        if self.optimizer is not None and self.optimizer.learning_rate_scheduler.lr_decay_steps is None:
+        if (
+            self.optimizer is not None
+            and self.optimizer.learning_rate_scheduler.lr_decay_steps is None
+        ):
             self.optimizer.learning_rate_scheduler.lr_decay_steps = (
-                self.tokens.train_steps - self.optimizer.learning_rate_scheduler.lr_warmup_steps
+                self.tokens.train_steps
+                - self.optimizer.learning_rate_scheduler.lr_warmup_steps
             )
 
         if self.data_stages is not None:
-            self.data_stages = sorted(self.data_stages, key=lambda stage: stage.start_training_step)
+            self.data_stages = sorted(
+                self.data_stages, key=lambda stage: stage.start_training_step
+            )
             names = [stage.name for stage in self.data_stages]
             training_steps = [stage.start_training_step for stage in self.data_stages]
             assert any(
@@ -402,7 +454,9 @@ def __post_init__(self):
 
         for stage in self.data_stages:
             if names.count(stage.name) > 1:
-                raise ValueError(f"Each stage should have unique names and not {names}")
+                raise ValueError(
+                    f"Each stage should have unique names and not {names}"
+                )
 
             if training_steps.count(stage.start_training_step) > 1:
                 raise ValueError(
@@ -411,13 +465,18 @@ def __post_init__(self):
 
             # NOTE: must order the stages by start_training_step from lowest to highest
             assert all(
-                self.data_stages[i].start_training_step < self.data_stages[i + 1].start_training_step
+                self.data_stages[i].start_training_step
+                < self.data_stages[i + 1].start_training_step
                 for i in range(len(self.data_stages) - 1)
             ), "The stages are not sorted by start_training_step in increasing order"
 
         # NOTE(tj.solergibert) As we are reporting the training & validation metrics together, we
        # must comply with val_check_interval % iteration_step_info_interval = 0
-        if not self.tokens.val_check_interval % self.logging.iteration_step_info_interval == 0:
+        if (
+            not self.tokens.val_check_interval
+            % self.logging.iteration_step_info_interval
+            == 0
+        ):
             raise ValueError(
                 f"It is necessary to run the validation stage during a logging step. Validation interval: {self.tokens.val_check_interval}, Logging interval: {self.logging.iteration_step_info_interval}"
             )
@@ -428,7 +498,11 @@ def __post_init__(self):
 
     @property
     def global_batch_size(self):
-        return self.tokens.micro_batch_size * self.tokens.batch_accumulation_per_replica * self.parallelism.dp
+        return (
+            self.tokens.micro_batch_size
+            * self.tokens.batch_accumulation_per_replica
+            * self.parallelism.dp
+        )
 
     def save_as_yaml(self, file_path: str):
         config_dict = serialize(self)
@@ -460,12 +534,18 @@ def get_config_from_dict(
     if skip_unused_config_keys:
         logger.warning("skip_unused_config_keys set")
         config_dict = {
-            field.name: config_dict[field.name] for field in fields(config_class) if field.name in config_dict
+            field.name: config_dict[field.name]
+            for field in fields(config_class)
+            if field.name in config_dict
         }
     if skip_null_keys:
         logger.warning("Skip_null_keys set")
         config_dict = {
-            k: ({kk: vv for kk, vv in v.items() if vv is not None} if isinstance(v, dict) else v)
+            k: (
+                {kk: vv for kk, vv in v.items() if vv is not None}
+                if isinstance(v, dict)
+                else v
+            )
             for k, v in config_dict.items()
             if v is not None
         }
diff --git a/src/nanotron/config/lighteval_config.py b/src/nanotron/config/lighteval_config.py
index a1b71070..3fbcac49 100644
--- a/src/nanotron/config/lighteval_config.py
+++ b/src/nanotron/config/lighteval_config.py
@@ -51,6 +51,7 @@ def __post_init__(self):
 class LightEvalTasksArgs:
     """Arguments related to tasks for LightEval"""
 
+    langs: Optional[str] = None
     tasks: Optional[str] = None
     custom_tasks: Optional[str] = None
     max_samples: Optional[int] = None
diff --git a/src/nanotron/config/parallelism_config.py b/src/nanotron/config/parallelism_config.py
index 321ee045..7f20ad99 100644
--- a/src/nanotron/config/parallelism_config.py
+++ b/src/nanotron/config/parallelism_config.py
@@ -34,6 +34,8 @@ class ParallelismArgs:
     tp_linear_async_communication: Optional[bool] = None
     recompute_layer: bool = False
 
+    tp_recompute_allgather: bool = True
+
     expert_parallel_size: int = 1
 
     def __post_init__(self):
diff --git a/src/nanotron/data/multilingual_nanoset.py b/src/nanotron/data/multilingual_nanoset.py
index 84ef1946..9a8a6c68 100644
--- a/src/nanotron/data/multilingual_nanoset.py
+++ b/src/nanotron/data/multilingual_nanoset.py
@@ -38,7 +38,9 @@ def __init__(
 
         # Checks
         if isinstance(dataset_folders, str):
-            warnings.warn("dataset_folders should be of type List[str] but str was provided. Converting to List[str]")
+            warnings.warn(
+                "dataset_folders should be of type List[str] but str was provided. 
Converting to List[str]" + ) dataset_folders = [dataset_folders] # Init @@ -63,7 +65,9 @@ def __init__( # Build Nanoset Index ## To build the index we need the length of each dataset - self.dataset_lengths = [len(datatrove_dataset) for datatrove_dataset in self.datatrove_datasets] + self.dataset_lengths = [ + len(datatrove_dataset) for datatrove_dataset in self.datatrove_datasets + ] ## Set dataset weights if ( dataset_weights is None @@ -76,10 +80,14 @@ def __init__( ), f"Specified {len(self.dataset_weights)} weights but {len(dataset_folders)} datasets were provided." ## Build dataset index and dataset sample index if is_valid: # Valid MultilingualNanoset - self.dataset_index, self.dataset_sample_index = build_valid_nanoset_index(self.dataset_lengths) + self.dataset_index, self.dataset_sample_index = build_valid_nanoset_index( + self.dataset_lengths + ) else: # Train MultilingualNanoset - self.dataset_index, self.dataset_sample_index = self.build_train_nanoset_index() + self.dataset_index, self.dataset_sample_index = ( + self.build_train_nanoset_index() + ) self.print_nanoset_info() @@ -129,7 +137,9 @@ def build_train_nanoset_index(self) -> np.ndarray: numpy_random_state.shuffle(dataset_sample_index) # Concatenate num_epochs the shuffled indexes dataset_index = np.concatenate([dataset_index for _ in range(num_epochs)]) - dataset_sample_index = np.concatenate([dataset_sample_index for _ in range(num_epochs)]) + dataset_sample_index = np.concatenate( + [dataset_sample_index for _ in range(num_epochs)] + ) # Just keep the necessary samples dataset_index = dataset_index[: self.train_split_num_samples] dataset_sample_index = dataset_sample_index[: self.train_split_num_samples] @@ -152,7 +162,9 @@ def print_nanoset_info(self): ) # Print samples from each dataset + weight - dataset_sample_count = count_dataset_indexes(self.dataset_index, len(self.dataset_folders)) + dataset_sample_count = count_dataset_indexes( + self.dataset_index, len(self.dataset_folders) + ) for index, sample_count in enumerate(dataset_sample_count): log_rank( f"> Total number of {'validation' if self.is_valid else 'training'} samples from the {self.dataset_folders[index]} dataset: {sample_count} ({round(normalize(dataset_sample_count).tolist()[index], 2)})", @@ -174,7 +186,9 @@ def build_train_nanoset_index_helper( """ # Create empty arrays for dataset indices and dataset sample indices dataset_index = np.empty((n_samples,), dtype="uint") - dataset_sample_index = np.empty((n_samples,), dtype="long") # Supports dataset with up to 2**64 samples + dataset_sample_index = np.empty( + (n_samples,), dtype="long" + ) # Supports dataset with up to 2**64 samples # Initialize buffer for number of samples used for each dataset current_samples = np.zeros((len(weights),), dtype="long") @@ -191,7 +205,9 @@ def build_train_nanoset_index_helper( # Assign the dataset index and update the sample index dataset_index[sample_idx] = max_error_index - dataset_sample_index[sample_idx] = current_samples[max_error_index] % dataset_sizes[max_error_index] + dataset_sample_index[sample_idx] = ( + current_samples[max_error_index] % dataset_sizes[max_error_index] + ) # Update the total samples for the selected dataset current_samples[max_error_index] += 1 @@ -211,4 +227,6 @@ def build_valid_nanoset_index(dataset_lengths: List[int]) -> np.ndarray: dataset_index.extend([i] * length) dataset_sample_index.extend(range(length)) - return np.array(dataset_index, dtype="uint"), np.array(dataset_sample_index, dtype="long") + return np.array(dataset_index, 
dtype="uint"), np.array(
+        dataset_sample_index, dtype="long"
+    )
diff --git a/src/nanotron/models/llama.py b/src/nanotron/models/llama.py
index 382951fe..f7e57d2a 100644
--- a/src/nanotron/models/llama.py
+++ b/src/nanotron/models/llama.py
@@ -155,6 +155,7 @@ def __init__(
             bias=False,
             async_communication=tp_linear_async_communication,
             contiguous_chunks=gate_up_contiguous_chunks,
+            tp_recompute_allgather=parallel_config.tp_recompute_allgather,
         )
         self.down_proj = TensorParallelRowLinear(
             config.intermediate_size,
@@ -164,8 +165,7 @@ def __init__(
             bias=False,
             async_communication=tp_linear_async_communication and tp_mode is TensorParallelLinearMode.REDUCE_SCATTER,
         )
-        # TODO @nouamane: why can't we torch.jit.script GLUActivation?
-        self.split_silu_mul = GLUActivation(config.hidden_act)
+        self.split_silu_mul = torch.compile(GLUActivation(config.hidden_act))
 
     def forward(self, hidden_states):  # [seq_length, batch_size, hidden_dim]
         merged_states = self.gate_up_proj(hidden_states)
@@ -302,6 +302,7 @@ def __init__(
             bias=False,
             async_communication=tp_linear_async_communication,
             contiguous_chunks=qkv_contiguous_chunks,
+            tp_recompute_allgather=parallel_config.tp_recompute_allgather,
         )
         # TODO(kunhao): We want to have only one version per device and not one version per layer.
         self.rotary_embedding = RotaryEmbedding(
@@ -739,6 +740,7 @@ def __init__(
                     # TODO @thomasw21: refactor so that we store that default in a single place.
                     "mode": self.tp_mode,
                     "async_communication": tp_linear_async_communication,
+                    "tp_recompute_allgather": parallel_config.tp_recompute_allgather,
                 },
                 module_input_keys={"x"},
                 module_output_keys={"logits"},
@@ -756,20 +758,21 @@ def forward(
         self,
         input_ids: Union[torch.Tensor, TensorPointer],  # [batch_size, seq_length]
         input_mask: Union[torch.Tensor, TensorPointer],  # [batch_size, seq_length]
         lang_code: Union[torch.Tensor, TensorPointer],  # [batch_size, 1]
     ):
         return self.forward_with_hidden_states(input_ids=input_ids, input_mask=input_mask, lang_code=lang_code)[0]
 
     def forward_with_hidden_states(
         self,
         input_ids: Union[torch.Tensor, TensorPointer],  # [batch_size, seq_length]
         input_mask: Union[torch.Tensor, TensorPointer],  # [batch_size, seq_length]
         lang_code: Union[torch.Tensor, TensorPointer],  # [batch_size, 1]
     ):
         # NOTE(tj.solergibert) I bring `lang_code` till the forward of LlamaModel. Remember that
         # to use it in the different pipeline blocks you need to also set the module_input_keys & module_output_keys
         # of the necessary `PipelineBlock`'s defined in the LlamaModel init!
+        # all tensors are optional as most ranks don't need anything from the dataloader.
 
         output = self.token_position_embeddings(input_ids=input_ids, input_mask=input_mask)
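(Note on the `torch.compile(GLUActivation(config.hidden_act))` change above: the compiled module only splits the fused gate/up projection output and applies the gated activation. A rough standalone sketch follows, assuming the usual first-half-gate / second-half-up layout and SiLU; `GLUActivation` itself lives in nanotron and is not reproduced here.)

```python
import torch
import torch.nn.functional as F


def split_silu_mul(merged_states: torch.Tensor) -> torch.Tensor:
    # gate_up_proj packs [gate, up] along the last dimension; the MLP keeps
    # silu(gate) * up before feeding the result to the down projection.
    gate, up = torch.chunk(merged_states, 2, dim=-1)
    return F.silu(gate) * up


# Same idea as wrapping the eager module in torch.compile.
compiled_split_silu_mul = torch.compile(split_silu_mul)
print(split_silu_mul(torch.randn(2, 3, 8)).shape)  # torch.Size([2, 3, 4])
```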
diff --git a/src/nanotron/parallel/tensor_parallel/distributed_differentiable_primitives.py b/src/nanotron/parallel/tensor_parallel/distributed_differentiable_primitives.py
index 873d77df..bd41347a 100644
--- a/src/nanotron/parallel/tensor_parallel/distributed_differentiable_primitives.py
+++ b/src/nanotron/parallel/tensor_parallel/distributed_differentiable_primitives.py
@@ -85,7 +85,8 @@ def forward(ctx, tensor, group: Optional[ProcessGroup]):
     @staticmethod
     def backward(ctx, grad_output):
         group = ctx.group
-        return DifferentiableReduceScatterSum.apply(grad_output, group), None
+        out = DifferentiableReduceScatterSum.apply(grad_output, group)
+        return out, None
 
 
 class DifferentiableReduceScatterSum(torch.autograd.Function):
@@ -113,7 +114,7 @@ def forward(ctx, tensor, group: Optional[ProcessGroup]):
             *rest_size,
             device=tensor.device,
             dtype=tensor.dtype,
-            requires_grad=tensor.requires_grad,
+            requires_grad=False,
         )
         dist.reduce_scatter_tensor(sharded_tensor, tensor, group=group, op=dist.ReduceOp.SUM)
         return sharded_tensor
diff --git a/src/nanotron/parallel/tensor_parallel/functional.py b/src/nanotron/parallel/tensor_parallel/functional.py
index ecccd33a..e2ee3a29 100644
--- a/src/nanotron/parallel/tensor_parallel/functional.py
+++ b/src/nanotron/parallel/tensor_parallel/functional.py
@@ -20,13 +20,12 @@
 import nanotron.distributed as dist
 from nanotron.parallel.tensor_parallel.distributed_differentiable_primitives import (
-    differentiable_all_gather,
     differentiable_all_reduce_sum,
     differentiable_identity,
     differentiable_reduce_scatter_sum,
 )
 from nanotron.parallel.tensor_parallel.enum import TensorParallelLinearMode
-from nanotron.parallel.utils import assert_cuda_max_connections_set_to_1
+from nanotron.parallel.utils import MemoryBuffer, assert_cuda_max_connections_set_to_1
 
 
 class _ShardedCrossEntropy(torch.autograd.Function):
@@ -121,10 +120,12 @@ class _ColumnLinearAsyncCommunication(torch.autograd.Function):
     @staticmethod
     @assert_cuda_max_connections_set_to_1
-    def forward(ctx, tensor, weight, bias, group, tp_mode):
+    def forward(ctx, tensor, weight, bias, group, tp_mode, tp_recompute_allgather):
         ctx.use_bias = bias is not None
         ctx.tp_mode = tp_mode
         ctx.group = group
+        ctx.tp_recompute_allgather = tp_recompute_allgather
+        ctx.tensor_shape = tensor.size()
 
         if tp_mode is TensorParallelLinearMode.ALL_REDUCE:
             gathered_tensor = tensor
@@ -141,7 +142,7 @@ def forward(ctx, tensor, weight, bias, 
group, tp_mode): # `tensor` can sometimes not be contiguous # https://cs.github.com/pytorch/pytorch/blob/2b267fa7f28e18ca6ea1de4201d2541a40411457/torch/distributed/nn/functional.py#L317 tensor = tensor.contiguous() - ctx.save_for_backward(tensor, weight) + # ctx.save_for_backward(tensor, weight) # TODO @thomasw21: gather along another dimension sharded_batch_size, *intermediate_size, hidden_size = tensor.shape @@ -149,14 +150,19 @@ def forward(ctx, tensor, weight, bias, group, tp_mode): group = dist.distributed_c10d._get_default_group() gathered_batch_size = sharded_batch_size * group.size() - gathered_tensor = torch.empty( - gathered_batch_size, - *intermediate_size, - hidden_size, - device=tensor.device, - dtype=tensor.dtype, - requires_grad=tensor.requires_grad, - ) + if tp_recompute_allgather: + gathered_tensor = MemoryBuffer().get( + "allgather", (gathered_batch_size, *intermediate_size, hidden_size), dtype=tensor.dtype + ) + else: + gathered_tensor = torch.empty( + gathered_batch_size, + *intermediate_size, + hidden_size, + device=tensor.device, + dtype=tensor.dtype, + requires_grad=False, + ) handle = dist.all_gather_into_tensor(gathered_tensor, tensor, group=group, async_op=True) @@ -204,6 +210,10 @@ def forward(ctx, tensor, weight, bias, group, tp_mode): # Wait communication handle.wait() + if tp_recompute_allgather: + ctx.save_for_backward(tensor, weight) + else: + ctx.save_for_backward(gathered_tensor, weight) # Compute all the other shards that are obtained from AllGather # weights: w0 w1 w2 w3 @@ -261,8 +271,8 @@ def backward(ctx, grad_output): use_bias = ctx.use_bias tp_mode = ctx.tp_mode - handle: Optional[dist.Work] = None - if tp_mode is TensorParallelLinearMode.REDUCE_SCATTER: + handle1: Optional[dist.Work] = None + if tp_mode is TensorParallelLinearMode.REDUCE_SCATTER and ctx.tp_recompute_allgather: # TODO @thomasw21: gather along another dimension sharded_batch_size, *rest_size = tensor.shape if group is None: @@ -273,14 +283,10 @@ def backward(ctx, grad_output): else: unsharded_batch_size = sharded_batch_size * group.size() - unsharded_tensor = torch.empty( - unsharded_batch_size, - *rest_size, - device=tensor.device, - dtype=tensor.dtype, - requires_grad=False, + unsharded_tensor = MemoryBuffer().get( + "allgather", (unsharded_batch_size, *rest_size), dtype=tensor.dtype ) - handle = dist.all_gather_into_tensor(unsharded_tensor, tensor, group=group, async_op=True) + handle1 = dist.all_gather_into_tensor(unsharded_tensor, tensor, group=group, async_op=True) # Here we rely on CUDA_DEVICE_MAX_CONNECTIONS=1 to ensure that the # gather is scheduled before the tensor gradient computation total_tensor = unsharded_tensor @@ -289,9 +295,6 @@ def backward(ctx, grad_output): grad_tensor = grad_output.matmul(weight) - if handle is not None: - handle.wait() - # Doing gather + slicing during the NeMo forward pass can make this tensor # not be contiguous. 
PyTorch only checks if the tensor is contiguous, and only # clones it if it's not contiguous: @@ -303,41 +306,128 @@ def backward(ctx, grad_output): grad_output = grad_output.view(math.prod(grad_output_first_dims), grad_output_last_dim) total_tensor = total_tensor.view(math.prod(total_tensor_first_dims), total_tensor_last_dim) - handle: Optional[dist.Work] = None + handle2: Optional[dist.Work] = None if tp_mode is TensorParallelLinearMode.REDUCE_SCATTER: if group.size() == 1: sub_grad_tensor = grad_tensor else: sub_grad_tensor = torch.empty( - tensor.shape, dtype=grad_tensor.dtype, device=grad_tensor.device, requires_grad=False + ctx.tensor_shape, dtype=grad_tensor.dtype, device=grad_tensor.device, requires_grad=False ) # reduce_scatter - handle = dist.reduce_scatter_tensor(sub_grad_tensor, grad_tensor, group=group, async_op=True) + handle2 = dist.reduce_scatter_tensor(sub_grad_tensor, grad_tensor, group=group, async_op=True) # Here we rely on CUDA_DEVICE_MAX_CONNECTIONS=1 to ensure that the # reduce scatter is scheduled before the weight gradient computation elif tp_mode is TensorParallelLinearMode.ALL_REDUCE: # Asynchronous all-reduce - handle = dist.all_reduce(grad_tensor, group=group, async_op=True) + handle2 = dist.all_reduce(grad_tensor, group=group, async_op=True) # Here we rely on CUDA_DEVICE_MAX_CONNECTIONS=1 to ensure that the # all-reduce is scheduled before the weight gradient computation else: raise ValueError() + grad_bias = grad_output.sum(dim=0) if use_bias else None + + if handle1 is not None: + handle1.wait() + # TODO @thomasw21: This sounds like we don't have the optimal physical layout grad_weight = grad_output.t().matmul(total_tensor) - grad_bias = grad_output.sum(dim=0) if use_bias else None - if handle is not None: - handle.wait() + if handle2 is not None: + handle2.wait() if tp_mode is TensorParallelLinearMode.REDUCE_SCATTER: - return sub_grad_tensor, grad_weight, grad_bias, None, None + return sub_grad_tensor, grad_weight, grad_bias, None, None, None elif tp_mode is TensorParallelLinearMode.ALL_REDUCE: - return grad_tensor, grad_weight, grad_bias, None, None + return grad_tensor, grad_weight, grad_bias, None, None, None else: raise ValueError(f"Got unexpected mode: {tp_mode}.") +class _ColumnLinearNoAsyncCommunicationReduceScatterMode(torch.autograd.Function): + """ + Column linear with memory_buffer for the allgather, context parallel + enabled (i.e. tp_mode = TensorParallelLinearMode.REDUCE_SCATTER) and + async communication disabled. + """ + + @staticmethod + def forward( + ctx, + input: torch.Tensor, + weight: torch.Tensor, + bias: Optional[torch.Tensor], + group: dist.ProcessGroup, + tp_recompute_allgather: bool, + ): + + # Do allgather. + sharded_batch_size, *rest_size = input.shape + unsharded_batch_size = sharded_batch_size * group.size() + if group.size() == 1: + total_input = input.contiguous() + elif tp_recompute_allgather: + total_input = MemoryBuffer().get("allgather", (unsharded_batch_size, *rest_size), dtype=input.dtype) + dist.all_gather_into_tensor(total_input, input.contiguous(), group=group) + else: + total_input = torch.empty(unsharded_batch_size, *rest_size, dtype=input.dtype, device=input.device) + dist.all_gather_into_tensor(total_input, input.contiguous(), group=group) + + # Prepare context. 
+ ctx.group = group + ctx.tp_recompute_allgather = tp_recompute_allgather + ctx.input_size = input.shape + if tp_recompute_allgather: + ctx.save_for_backward(input, weight, bias) + else: + ctx.save_for_backward(total_input, weight, bias) + + # Get linear output. + out = F.linear(total_input, weight, bias) + return out + + @staticmethod + def backward(ctx, grad_output: torch.Tensor): + # Either allgather the inputs again or get them from context. + group = ctx.group + tp_recompute_allgather = ctx.tp_recompute_allgather + input_size = ctx.input_size + if group.size() == 1 or not tp_recompute_allgather: + total_input, weight, bias = ctx.saved_tensors + else: + input, weight, bias = ctx.saved_tensors + sharded_batch_size, *rest_size = input.shape + total_input = sharded_batch_size * group.size() + unsharded_batch_size = sharded_batch_size * group.size() + total_input = MemoryBuffer().get("allgather", (unsharded_batch_size, *rest_size), dtype=input.dtype) + dist.all_gather_into_tensor(total_input, input.contiguous(), group=group) + + # Convert the tensor shapes to 2D for execution compatibility + grad_output = grad_output.contiguous() + grad_output_first_dims, grad_output_last_dim = grad_output.shape[:-1], grad_output.shape[-1] + total_input_first_dims, total_input_last_dim = total_input.shape[:-1], total_input.shape[-1] + grad_output = grad_output.view(math.prod(grad_output_first_dims), grad_output_last_dim) + total_input = total_input.view(math.prod(total_input_first_dims), total_input_last_dim) + + # Compute gradients. + grad_weight = grad_output.T @ total_input + grad_input = grad_output @ weight + if group.size() == 1: + sub_grad_input = grad_input + else: + # Seems that `reduce_scatter` need contiguous tensors: https://github.com/pytorch/pytorch/blob/2b267fa7f28e18ca6ea1de4201d2541a40411457/torch/distributed/nn/functional.py#L305 + # We set grad_input to be contiguous in case it isn't already. 
+ grad_input = grad_input.contiguous() + sub_grad_input = torch.empty( + input_size, dtype=total_input.dtype, device=total_input.device, requires_grad=False + ) + dist.reduce_scatter_tensor(sub_grad_input, grad_input, group=group, op=dist.ReduceOp.SUM) + grad_bias = torch.sum(grad_output, dim=0) if bias is not None else None + + return sub_grad_input, grad_weight, grad_bias, None, None + + def column_linear( input: torch.Tensor, weight: torch.Tensor, @@ -345,18 +435,19 @@ def column_linear( group: dist.ProcessGroup, tp_mode: TensorParallelLinearMode, async_communication: bool, + tp_recompute_allgather: bool = True, ): if async_communication: - return _ColumnLinearAsyncCommunication.apply(input, weight, bias, group, tp_mode) + return _ColumnLinearAsyncCommunication.apply(input, weight, bias, group, tp_mode, tp_recompute_allgather) if tp_mode is TensorParallelLinearMode.ALL_REDUCE: input = differentiable_identity(input, group=group) - elif tp_mode is TensorParallelLinearMode.REDUCE_SCATTER: - input = differentiable_all_gather(input, group=group) - else: - raise ValueError(f"Got unexpected mode: {tp_mode}.") - - return F.linear(input, weight, bias) + return F.linear(input, weight, bias) + if tp_mode is TensorParallelLinearMode.REDUCE_SCATTER: + return _ColumnLinearNoAsyncCommunicationReduceScatterMode.apply( + input, weight, bias, group, tp_recompute_allgather + ) + raise ValueError(f"Got unexpected mode: {tp_mode}.") class _RowLinearAsyncCommunication(torch.autograd.Function): @@ -397,12 +488,8 @@ def backward(ctx, grad_output): else: unsharded_batch_size = sharded_batch_size * group.size() - total_grad_output = torch.empty( - unsharded_batch_size, - *rest_size, - device=grad_output.device, - dtype=grad_output.dtype, - requires_grad=False, + total_grad_output = MemoryBuffer().get( + "allgather2", (unsharded_batch_size, *rest_size), dtype=tensor.dtype ) # Doing gather + slicing during the NeMo forward pass can make this tensor diff --git a/src/nanotron/parallel/tensor_parallel/nn.py b/src/nanotron/parallel/tensor_parallel/nn.py index 40e89968..4c7325cd 100644 --- a/src/nanotron/parallel/tensor_parallel/nn.py +++ b/src/nanotron/parallel/tensor_parallel/nn.py @@ -51,6 +51,7 @@ def __init__( dtype=None, async_communication: bool = False, contiguous_chunks: Optional[Tuple[int, ...]] = None, + tp_recompute_allgather: bool = True, ): self.pg = pg self.world_size = pg.size() @@ -59,6 +60,7 @@ def __init__( self.in_features = in_features self.out_features = out_features // self.world_size + self.tp_recompute_allgather = tp_recompute_allgather super().__init__( in_features=self.in_features, @@ -91,6 +93,7 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: group=self.pg, tp_mode=self.mode, async_communication=self.async_communication, + tp_recompute_allgather=self.tp_recompute_allgather, ) def extra_repr(self) -> str: diff --git a/src/nanotron/parallel/utils.py b/src/nanotron/parallel/utils.py index b9ac12ae..f694b0e6 100644 --- a/src/nanotron/parallel/utils.py +++ b/src/nanotron/parallel/utils.py @@ -1,11 +1,31 @@ import functools +import operator import os +import torch from torch import nn from nanotron import distributed as dist from nanotron.parallel import ParallelContext from nanotron.parallel.tied_parameters import get_tied_id_to_param +from nanotron.utils import Singleton + + +class MemoryBuffer(metaclass=Singleton): + """ + Global memory buffer to store intermediate activations that need not to be cached for the backward pass. 
+ """ + + def __init__(self): + self.buffer = {} + + def get(self, name: str, shape: tuple[int], dtype: torch.dtype = torch.bfloat16) -> torch.Tensor: + required_numel = functools.reduce(operator.mul, shape, 1) + if (name, dtype) not in self.buffer or self.buffer[name, dtype].numel() < required_numel: + self.buffer[name, dtype] = torch.empty( + required_numel, dtype=dtype, device=torch.cuda.current_device(), requires_grad=False + ) + return self.buffer[name, dtype][:required_numel].view(shape) def assert_cuda_max_connections_set_to_1(func): diff --git a/src/nanotron/serialize/main.py b/src/nanotron/serialize/main.py index 286008ac..346ad573 100644 --- a/src/nanotron/serialize/main.py +++ b/src/nanotron/serialize/main.py @@ -236,6 +236,7 @@ def load( load_optimizer(optimizer=optimizer, parallel_context=parallel_context, root_folder=root_folder) load_lr_scheduler( lr_scheduler=lr_scheduler, + parallel_context=parallel_context, root_folder=root_folder, ) return checkpoint_metadata diff --git a/src/nanotron/serialize/optimizer.py b/src/nanotron/serialize/optimizer.py index 68a3b1a0..f11210da 100644 --- a/src/nanotron/serialize/optimizer.py +++ b/src/nanotron/serialize/optimizer.py @@ -30,9 +30,9 @@ def optimizer_filename(parallel_context: ParallelContext, is_zero: bool): return f"{ObjectType.OPTIMIZER.value}_pp-{dist.get_rank(parallel_context.pp_pg)}-of-{parallel_context.pp_pg.size()}_tp-{dist.get_rank(parallel_context.tp_pg)}-of-{parallel_context.tp_pg.size()}_exp-{dist.get_rank(parallel_context.expert_pg)}-of-{parallel_context.expert_parallel_size}.pt" -def lr_scheduler_filename(): +def lr_scheduler_filename(parallel_context: ParallelContext): """The lr_scheduler is the same for all processes.""" - return f"{ObjectType.LR_SCHEDULER.value}.pt" + return f"{ObjectType.LR_SCHEDULER.value}_pp-{dist.get_rank(parallel_context.pp_pg)}-of-{parallel_context.pp_pg.size()}_tp-{dist.get_rank(parallel_context.tp_pg)}-of-{parallel_context.tp_pg.size()}_exp-{dist.get_rank(parallel_context.expert_pg)}-of-{parallel_context.expert_parallel_size}.pt" def save_optimizer( @@ -109,9 +109,6 @@ def save_lr_scheduler( root_folder: Path, ): """Saves lr scheduler states""" - if dist.get_rank(parallel_context.world_pg) > 0: - # Only WORLD-RANK 0 saves the lr scheduler state - return root_folder = root_folder / "lr_scheduler" root_folder.mkdir(exist_ok=True, parents=True) @@ -119,7 +116,7 @@ def save_lr_scheduler( # We dump the optimizer state using `torch.save` torch.save( lr_scheduler.state_dict(), - root_folder / lr_scheduler_filename(), + root_folder / lr_scheduler_filename(parallel_context), ) @@ -313,9 +310,10 @@ def get_checkpoint_state_metadata(param_name: str, pp_rank: int, tp_rank: int) - def load_lr_scheduler( lr_scheduler, + parallel_context: ParallelContext, root_folder: Path, ): root_folder = root_folder / "lr_scheduler" - state_dict = torch.load(root_folder / lr_scheduler_filename()) + state_dict = torch.load(root_folder / lr_scheduler_filename(parallel_context)) lr_scheduler.load_state_dict(state_dict) diff --git a/src/nanotron/trainer.py b/src/nanotron/trainer.py index 53f3708f..683744bd 100644 --- a/src/nanotron/trainer.py +++ b/src/nanotron/trainer.py @@ -211,6 +211,7 @@ def __init__( if self.init_checkpoint_path is not None: load_lr_scheduler( lr_scheduler=self.lr_scheduler, + parallel_context=self.parallel_context, root_folder=self.init_checkpoint_path, ) @@ -573,6 +574,9 @@ def train( dist.barrier() # let's wait for everyone before leaving + if self.config.checkpoints.save_final_state: + 
self.save_checkpoint() + self.post_training() def training_step( @@ -1107,8 +1111,8 @@ def save_checkpoint(self) -> Path: ), # We only save the weights on DP==0 should_save_optimizer=True, should_save_lr_scheduler=bool( - dist.get_rank(self.parallel_context.world_pg) == 0 - ), # We only save the lr_scheduler on world_rank==0 + dist.get_rank(self.parallel_context.dp_pg) == 0 + ), # We only save the lr_scheduler on DP==0 should_save_config=bool( dist.get_rank(self.parallel_context.world_pg) == 0 ), # We only save the config on world_rank==0 diff --git a/src/nanotron/utils.py b/src/nanotron/utils.py index 14fe1ca8..b3831801 100644 --- a/src/nanotron/utils.py +++ b/src/nanotron/utils.py @@ -1,11 +1,10 @@ import functools import inspect -import math import os import random import socket from contextlib import ExitStack, contextmanager -from typing import Callable, ContextManager, List, Optional +from typing import ContextManager, List, Optional import torch from packaging import version @@ -15,6 +14,25 @@ from nanotron import distributed as dist +class Singleton(type): + """ + Singleton metaclass. + Create objects using this class as the metaclass to enable singleton behaviour. + For instance: + ``` + class Logger(metaclass=Singleton): + ... + ``` + """ + + _instances = {} + + def __call__(cls, *args, **kwargs): + if cls not in cls._instances: + cls._instances[cls] = super(Singleton, cls).__call__(*args, **kwargs) + return cls._instances[cls] + + class ContextManagers: """ Wrapper for `contextlib.ExitStack` which enters a collection of context managers. Adaptation of `ContextManagers` @@ -52,7 +70,7 @@ def main_rank_first(group: dist.ProcessGroup): @contextmanager def local_ranks_zero_first(group: Optional[dist.ProcessGroup] = None): """Context manager that executes the code in the context with all the local rank zero of the group going first. - Usefull to run only once per node first (e.g. to create local files, etc) + Useful to run only once per node first (e.g. to create local files, etc) """ is_main = int(os.environ.get("LOCAL_RANK", 0)) == 0 if is_main: @@ -123,6 +141,7 @@ def get_untyped_storage(tensor: torch.Tensor) -> torch.UntypedStorage: else: return tensor.storage().untyped() + def tensor_from_untyped_storage(untyped_storage: torch.UntypedStorage, dtype: torch.dtype): # TODO @thomasw21: Figure out what's the best Pytorch way of building a tensor from a storage. 
device = untyped_storage.device diff --git a/tests/test_tensor_parallel.py b/tests/test_tensor_parallel.py index f5dcaeb0..16008eaa 100644 --- a/tests/test_tensor_parallel.py +++ b/tests/test_tensor_parallel.py @@ -18,17 +18,30 @@ @pytest.mark.parametrize("tp,dp,pp", [pytest.param(i, 1, 1) for i in range(1, min(4, available_gpus()) + 1)]) @pytest.mark.parametrize("tp_mode", list(TensorParallelLinearMode)) @pytest.mark.parametrize("async_communication", [False, True]) +@pytest.mark.parametrize("tp_recompute_allgather", [False, True]) @rerun_if_address_is_in_use() -def test_column_linear(tp: int, dp: int, pp: int, tp_mode: TensorParallelLinearMode, async_communication: bool): +def test_column_linear( + tp: int, + dp: int, + pp: int, + tp_mode: TensorParallelLinearMode, + async_communication: bool, + tp_recompute_allgather: bool, +): if tp_mode is TensorParallelLinearMode.ALL_REDUCE and async_communication: pytest.skip("ALL_REDUCE mode does not support async communication") + if tp_mode is TensorParallelLinearMode.ALL_REDUCE and tp_recompute_allgather: + pytest.skip("ALL_REDUCE mode is unaffected by tp_recompute_allgather") init_distributed(tp=tp, dp=dp, pp=pp)(_test_column_linear)( - tp_mode=tp_mode, async_communication=async_communication + tp_mode=tp_mode, async_communication=async_communication, tp_recompute_allgather=tp_recompute_allgather ) def _test_column_linear( - parallel_context: ParallelContext, tp_mode: TensorParallelLinearMode, async_communication: bool + parallel_context: ParallelContext, + tp_mode: TensorParallelLinearMode, + async_communication: bool, + tp_recompute_allgather: bool, ): if async_communication: os.environ["CUDA_DEVICE_MAX_CONNECTIONS"] = "1" @@ -44,6 +57,7 @@ def _test_column_linear( mode=tp_mode, device="cuda", async_communication=async_communication, + tp_recompute_allgather=tp_recompute_allgather, ) # Un-sharded @@ -86,7 +100,7 @@ def _test_column_linear( random_input = sharded_random_input else: ValueError(f"Unsupported mode: {tp_mode}") - # It's important that `random_input` and `sharded_random_input` are two seperate tensors with seperate storage + # It's important that `random_input` and `sharded_random_input` are two separate tensors with separate storage sharded_random_input = sharded_random_input.clone() random_input.requires_grad = True sharded_random_input.requires_grad = True @@ -150,15 +164,32 @@ def _test_column_linear( @pytest.mark.parametrize("tp,dp,pp", [pytest.param(i, 1, 1) for i in range(1, min(4, available_gpus()) + 1)]) @pytest.mark.parametrize("tp_mode", list(TensorParallelLinearMode)) @pytest.mark.parametrize("async_communication", [False, True]) +@pytest.mark.parametrize("tp_recompute_allgather", [False, True]) @rerun_if_address_is_in_use() -def test_row_linear(tp: int, dp: int, pp: int, tp_mode: TensorParallelLinearMode, async_communication: bool): +def test_row_linear( + tp: int, + dp: int, + pp: int, + tp_mode: TensorParallelLinearMode, + async_communication: bool, + tp_recompute_allgather: bool, +): if tp_mode is TensorParallelLinearMode.ALL_REDUCE and async_communication: pytest.skip("ALL_REDUCE mode does not support async communication") + if tp_mode is TensorParallelLinearMode.ALL_REDUCE and tp_recompute_allgather: + pytest.skip("ALL_REDUCE mode is not affected by tp_recompute_allgather") - init_distributed(tp=tp, dp=dp, pp=pp)(_test_row_linear)(tp_mode=tp_mode, async_communication=async_communication) + init_distributed(tp=tp, dp=dp, pp=pp)(_test_row_linear)( + tp_mode=tp_mode, async_communication=async_communication, 
tp_recompute_allgather=tp_recompute_allgather
+    )
 
 
-def _test_row_linear(parallel_context: ParallelContext, tp_mode: TensorParallelLinearMode, async_communication: bool):
+def _test_row_linear(
+    parallel_context: ParallelContext,
+    tp_mode: TensorParallelLinearMode,
+    async_communication: bool,
+    tp_recompute_allgather: bool,
+):
     if async_communication:
         os.environ["CUDA_DEVICE_MAX_CONNECTIONS"] = "1"
     out_features = 3
diff --git a/tools/preprocess_data.py b/tools/preprocess_data.py
index 8383ba38..23016eaf 100644
--- a/tools/preprocess_data.py
+++ b/tools/preprocess_data.py
@@ -112,5 +112,6 @@ def main(args):
             tokenizer_name_or_path=args.tokenizer_name_or_path,
             eos_token=args.eos_token,
+            shuffle=False,
             max_tokens_per_file=1e9,
         ),
     ],
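For reference, the `dataset_folder` handling patched in `NanosetDatasetsArgs.__post_init__` above accepts three shapes: a single folder, a list of folders, or a folder-to-weight dict. Below is a minimal standalone sketch of that normalization, using hypothetical folder paths; it mirrors the config logic rather than calling nanotron.

```python
from typing import Dict, List, Optional, Tuple, Union


def normalize_dataset_folder(
    dataset_folder: Union[str, Dict[str, float], List[str]]
) -> Tuple[List[str], Optional[list]]:
    """Return (folders, weights) the way NanosetDatasetsArgs.__post_init__ does."""
    if isinstance(dataset_folder, str):  # Case 1: 1 dataset folder
        return [dataset_folder], [1]
    if isinstance(dataset_folder, list):  # Case 2: > 1 dataset folder, no weights
        return dataset_folder, None  # None -> consume all the samples randomly
    if isinstance(dataset_folder, dict):  # Case 3: folder -> weight mapping
        return list(dataset_folder.keys()), list(dataset_folder.values())
    raise TypeError(f"Unsupported dataset_folder type: {type(dataset_folder)}")


# Hypothetical paths, for illustration only.
print(normalize_dataset_folder("datasets/c4-es"))
print(normalize_dataset_folder(["datasets/c4-es", "datasets/c4-en"]))
print(normalize_dataset_folder({"datasets/c4-es": 0.3, "datasets/c4-en": 0.7}))
```

With the dict form, the listed weights are later checked against the number of dataset folders, while the list form leaves `dataset_weights` as `None` so samples are drawn from all folders at random.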