From 27133e1e4b433c07f7e423b66bd1eb9845dc948c Mon Sep 17 00:00:00 2001 From: tj-solergibert Date: Mon, 22 Jul 2024 16:39:47 +0000 Subject: [PATCH 1/8] Just in case --- examples/config_multilingual_nanoset.yaml | 2 +- run_train.py | 4 +- src/nanotron/config/config.py | 1 + src/nanotron/models/llama.py | 30 ++-- .../parallel/pipeline_parallel/engine.py | 27 +++- src/nanotron/trainer.py | 143 ++++++++++++++++-- 6 files changed, 173 insertions(+), 34 deletions(-) diff --git a/examples/config_multilingual_nanoset.yaml b/examples/config_multilingual_nanoset.yaml index 599bff6c..33f9db41 100644 --- a/examples/config_multilingual_nanoset.yaml +++ b/examples/config_multilingual_nanoset.yaml @@ -131,4 +131,4 @@ tokens: micro_batch_size: 4 sequence_length: 1024 train_steps: 200 - val_check_interval: -1 + val_check_interval: 3 diff --git a/run_train.py b/run_train.py index 39cda23b..ed9b5607 100644 --- a/run_train.py +++ b/run_train.py @@ -238,10 +238,10 @@ def get_valid_dataloader_from_data_stage( with main_rank_first(trainer.parallel_context.world_pg): valid_dataset = MultilingualNanoset( - dataset_folders=data.dataset.validation_folder, + dataset_folders=data.dataset.validation_folder, # TODO Just 1 folder sequence_length=trainer.sequence_length, token_size=token_size, - dataset_tokens=data.dataset.dataset_tokens, + dataset_tokens=data.dataset.dataset_tokens, # TODO Just 1 lang is_valid=True, random_seed=data.seed, ) diff --git a/src/nanotron/config/config.py b/src/nanotron/config/config.py index dd2c157d..e5ea3ec1 100644 --- a/src/nanotron/config/config.py +++ b/src/nanotron/config/config.py @@ -125,6 +125,7 @@ def __post_init__(self): self.training_folder = list(tmp_training_folder.keys()) self.dataset_weights = list(tmp_training_folder.values()) + self.ids_to_lang = {v: k for k, v in self.lang_to_ids.items()} self.dataset_tokens = list(self.lang_to_ids.values()) assert len(self.training_folder) == len( self.validation_folder diff --git a/src/nanotron/models/llama.py b/src/nanotron/models/llama.py index 2411e5fa..7ae34dd5 100644 --- a/src/nanotron/models/llama.py +++ b/src/nanotron/models/llama.py @@ -801,7 +801,14 @@ def get_flops_per_sec(self, iteration_time_in_sec, sequence_length, global_batch @torch.jit.script def masked_mean(loss, label_mask, dtype): # type: (Tensor, Tensor, torch.dtype) -> Tensor - return (loss * label_mask).sum(dtype=dtype) / label_mask.sum() + return (loss * label_mask).sum(dim=1, dtype=dtype) / label_mask.sum( + dim=1 + ) # TODO esto de entrada da float/float = float + + +# TODO la loss de cada uno !!!! ((loss * label_mask).sum(dim=1, dtype=dtype) / label_mask.sum(dim=1)) +# Y pasa el assert close!! +# assert_close(((loss * label_mask).sum(dtype=dtype) / label_mask.sum()), torch.mean((loss * label_mask).sum(dim=1, dtype=dtype) / label_mask.sum(dim=1))) class Loss(nn.Module): @@ -818,14 +825,16 @@ def forward( # Megatron by defaults cast everything in fp32. `--f16-lm-cross-entropy` is an option you can use to keep current precision. # https://github.com/NVIDIA/Megatron-LM/blob/f267e6186eae1d6e2055b412b00e2e545a8e896a/megatron/model/gpt_model.py#L38 - loss = sharded_cross_entropy( + sample_loss = sharded_cross_entropy( sharded_logits, label_ids.transpose(0, 1).contiguous(), group=self.tp_pg, dtype=torch.float ).transpose(0, 1) # TODO @thomasw21: It's unclear what kind of normalization we want to do. 
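A standalone sketch, not part of the patch: the Spanish TODOs above note that the new masked_mean returns float/float = float per sample and that an assert_close against the original scalar loss passes. That equality holds when every sample has the same number of unmasked tokens; with a ragged mask the two quantities are a sample-weighted versus a token-weighted mean and generally differ. Shapes and values below are made up.

    import torch
    from torch.testing import assert_close

    loss = torch.rand(4, 16)                    # [micro_batch_size, sequence_length]
    mask = torch.ones(4, 16, dtype=torch.bool)  # same number of valid tokens per sample

    # New behaviour: one loss value per sample (dim=1 reduction).
    per_sample = (loss * mask).sum(dim=1, dtype=torch.float) / mask.sum(dim=1)
    # Old behaviour: a single scalar over the whole micro-batch.
    scalar = (loss * mask).sum(dtype=torch.float) / mask.sum()

    # Equal per-sample token counts -> the mean of sample losses matches the scalar loss.
    assert_close(per_sample.mean(), scalar)

    # Ragged mask -> sample-weighted vs. token-weighted mean, generally not equal.
    mask[0, 8:] = False
    per_sample = (loss * mask).sum(dim=1, dtype=torch.float) / mask.sum(dim=1)
    scalar = (loss * mask).sum(dtype=torch.float) / mask.sum()
    print(torch.allclose(per_sample.mean(), scalar))  # usually False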
- loss = masked_mean(loss, label_mask, dtype=torch.float) - # I think indexing causes a sync we don't actually want - # loss = loss[label_mask].sum() - return {"loss": loss} + sample_loss = masked_mean(sample_loss, label_mask, dtype=torch.float) + # NOTE @tj.solergibert: masked_mean returns a single scalar with the batch loss. We've changed it to compute the SAMPLE loss. + # We will continue using "loss" as the batch loss but we add "sample_loss" for the multilingual effort. + # TODO @thomasw21: I think indexing causes a sync we don't actually want + # TODO @thomasw21: loss = loss[label_mask].sum() + return {"sample_loss": sample_loss} class LlamaForTraining(NanotronModel): @@ -847,7 +856,7 @@ def __init__( "label_ids", "label_mask", }, - module_output_keys={"loss"}, + module_output_keys={"sample_loss"}, ) self.parallel_context = parallel_context self.config = config @@ -864,12 +873,13 @@ def forward( input_ids=input_ids, input_mask=input_mask, ) - loss = self.loss( + outputs = self.loss( sharded_logits=sharded_logits, label_ids=label_ids, label_mask=label_mask, - )["loss"] - return {"loss": loss} + ) + outputs["loss"] = torch.mean(outputs["sample_loss"]) + return outputs @torch.no_grad() def init_model_randomly(self, config: Config): diff --git a/src/nanotron/parallel/pipeline_parallel/engine.py b/src/nanotron/parallel/pipeline_parallel/engine.py index ca9df312..bf690bd0 100644 --- a/src/nanotron/parallel/pipeline_parallel/engine.py +++ b/src/nanotron/parallel/pipeline_parallel/engine.py @@ -2,6 +2,9 @@ from typing import Dict, Iterable, Optional, Union import torch +from torch import nn as torch_nn +from torch.nn.parallel import DistributedDataParallel + from nanotron import distributed as dist from nanotron import logging from nanotron.distributed import ProcessGroup @@ -12,8 +15,6 @@ from nanotron.parallel.pipeline_parallel.state import PipelineTrainBatchState from nanotron.parallel.pipeline_parallel.tensor_pointer import TensorPointer from nanotron.utils import ContextManagers -from torch import nn as torch_nn -from torch.nn.parallel import DistributedDataParallel logger = logging.get_logger(__name__) @@ -29,6 +30,7 @@ def forward( state: PipelineTrainBatchState, micro_batch: Dict[str, Union[torch.Tensor, TensorPointer]], model: torch_nn.Module, + is_validation: bool = False, ) -> Dict[str, Union[torch.Tensor, TensorPointer]]: # Increment the number of backwards state.nb_forwards += 1 @@ -52,7 +54,7 @@ def forward( output["loss"] = output["loss"] / self.nb_microbatches # Add output as activations that require backward pass - if not isinstance(output["loss"], TensorPointer): + if not isinstance(output["loss"], TensorPointer) and not is_validation: assert output["loss"].requires_grad state.register_activation_requiring_backward(output["loss"]) return output @@ -138,12 +140,15 @@ def validate_batch_iter( self.nb_microbatches = nb_microbatches outputs = [] + lang_ids = [] with attach_pipeline_state_to_model(model=model, pipeline_state=state): # All forward for micro_batch in batch: context = self._get_fwd_context(model=model) - output = self.forward(context=context, state=state, micro_batch=micro_batch, model=model) + output = self.forward( + context=context, state=state, micro_batch=micro_batch, model=model, is_validation=True + ) # TODO @thomasw21: Somehow this needs to be done somewhere else to support interleaving. 
Somewhere right after a "stage" for _ in range(len(state.microbatches_activations_to_send)): send_activation = state.microbatches_activations_to_send.popleft() @@ -151,15 +156,23 @@ def validate_batch_iter( send_activation() # We make `output` a dict + # TODO convert to dict other items returned by the model (MoE aux loss for example) + # But in next if statement be careful if we return other items in all of the pp processes + # This conversion to dicts is kind of useless as the model already returns a dict with loss key. Maybe the PP ranks return TensorPointer Objects? if not isinstance(output, dict): output = {"loss": output} # Store the loss for each microbatch if not isinstance(output["loss"], TensorPointer): output = {k: v.detach() for k, v in output.items()} - outputs.append(output) - - return outputs + # TODO ver este output que es y tambien ver en outputs como se guarda. Donde se have la media? En el training step lol + # Aqui deberiamos segregar por languagues porque es el unico punto en el que tenemos la languague!! O al menos "etiquetarla" o acumularla por language + # 1. Hacemos dict con key para cada idioma 2. cada key tiene una lista donde append los tensors 3. en valid step hacemos lo del stack y allreduces + # Finalmente: Aqui metemos solo el lang ids, en trainer.py acumularemos los resultados y tal. + outputs.extend(list(output["sample_loss"])) # TODO flatten?????? o extend?????? + lang_ids.extend(micro_batch["input_ids"][:, 0].tolist()) # TODO esto deberia se un extend???? + + return outputs, lang_ids class AllForwardAllBackwardPipelineEngine(PipelineEngine): diff --git a/src/nanotron/trainer.py b/src/nanotron/trainer.py index 3f4c5189..583068cd 100644 --- a/src/nanotron/trainer.py +++ b/src/nanotron/trainer.py @@ -300,7 +300,9 @@ def _print_training_plan(self): ) log_rank(full_log_message, logger=logger, level=logging.INFO, rank=0) - def _update_dataloader_based_on_training_stages(self, dataloaders: Union[List[DataLoader], DataLoader]): + def _update_dataloader_based_on_training_stages( + self, dataloaders: Union[List[DataLoader], DataLoader], is_validation: bool = False + ): from collections.abc import Generator if not hasattr(self.config, "data_stages") or self.config.data_stages is None: @@ -309,9 +311,16 @@ def _update_dataloader_based_on_training_stages(self, dataloaders: Union[List[Da dataloader = dataloaders[0] else: dataloader = dataloaders - self.current_dataloader = sanity_check_dataloader( - dataloader=dataloader, parallel_context=self.parallel_context, config=self.config - ) + + if is_validation: + self.current_validation_dataloader_lenght = len(dataloader) + self.current_validation_dataloader = sanity_check_dataloader( + dataloader=dataloader, parallel_context=self.parallel_context, config=self.config + ) + else: + self.current_dataloader = sanity_check_dataloader( + dataloader=dataloader, parallel_context=self.parallel_context, config=self.config + ) return elif isinstance(dataloaders, Generator): # TODO(xrsrke): this is a hacky way to handle DoReMi's dataloader @@ -328,7 +337,7 @@ def clear_dataloader_from_memory(dataloader: DataLoader, stage_name: str): import gc log_rank( - f"[Training Stage: {stage_name}] Clearing the previous training stage's dataloader and datasets from memory", + f"[{'Validation' if is_validation else 'Training'} Stage: {stage_name}] Clearing the previous {'validation' if is_validation else 'training'} stage's dataloader and datasets from memory", logger=logger, level=logging.INFO, ) @@ -369,7 +378,7 @@ def 
find_stage_idx_to_resume(): self.metadata.last_stage_idx = stage_idx - if is_resume_from_training: + if is_resume_from_training and not is_validation: remaining_train_steps = compute_remain_train_steps_of_a_data_stage_from_ckp( stage, self.config, self.metadata ) @@ -387,9 +396,15 @@ def find_stage_idx_to_resume(): break if dataloader is not None: - self.current_dataloader = sanity_check_dataloader( - dataloader=dataloader, parallel_context=self.parallel_context, config=self.config - ) + if is_validation: + self.current_validation_dataloader_lenght = len(dataloader) + self.current_validation_dataloader = sanity_check_dataloader( + dataloader=dataloader, parallel_context=self.parallel_context, config=self.config + ) + else: + self.current_dataloader = sanity_check_dataloader( + dataloader=dataloader, parallel_context=self.parallel_context, config=self.config + ) def train( self, @@ -428,9 +443,23 @@ def train( self.iteration_start_time = time.time() self._update_dataloader_based_on_training_stages(train_dataloader_or_dls) + self._update_dataloader_based_on_training_stages(valid_dataloader_or_dls, is_validation=True) # Training step outputs, loss_avg = self.training_step(dataloader=self.current_dataloader) + self.training_step_time = time.time() + + # Validation step + # TODO A ver, en este loop solo se lleva a cabo una training iteration pero claro hay un porron de validation iteration... mmmmm + # Tal vez deberiamos mover esto a otro lugar? Es decir, aqui se have un training step pero hacemos varios validation steps + # Lo podemos dejar aqui solamente que las metricas de throughput y tokens consumidos se tendrian que revisar + # Porque actualmente utilizan la global batch size, que es correcta ya que es la que tiene cada training step pero claro, + # Cada validation es mucho mas largo que un training step + # Puede que el len valid dataloader de el numero de valid batches por lo que con eso y la batch size podemos tirar millas + if self.iteration_step % self.config.tokens.val_check_interval == 0: + global_loss, lang_losses = self.validation_step(dataloader=self.current_validation_dataloader) + self.validation_step_time = time.time() + self.validation_step_logs(global_loss, lang_losses) # Training Logs # TODO(xrsrke): refactor using callbacks would be better @@ -546,12 +575,36 @@ def training_step( return outputs, loss_avg def validation_step(self, dataloader: Iterator[Dict[str, Union[torch.Tensor, TensorPointer]]]) -> Iterable[Dict]: - outputs = self.pipeline_engine.validate_batch_iter( + outputs, lang_ids = self.pipeline_engine.validate_batch_iter( model=self.model, - batch=(next(dataloader) for _ in range(self.limit_val_batches)), - nb_microbatches=self.limit_val_batches, + batch=(next(dataloader) for _ in range(self.current_validation_dataloader_lenght)), + nb_microbatches=self.current_validation_dataloader_lenght, ) - return outputs + + lang_losses = { + lang: [] for lang in self.config.data_stages[self.metadata.last_stage_idx].data.dataset.lang_to_ids.keys() + } + # Compute losses + if isinstance(outputs[0], torch.Tensor): + # Multilingual losses + for loss, lang_id in zip(outputs, lang_ids): + lang_losses[ + self.config.data_stages[self.metadata.last_stage_idx].data.dataset.ids_to_lang[lang_id] + ].append(loss) + # Global loss + global_loss_avg = torch.stack(outputs).sum() + # Sync losses across DP + for lang in lang_losses.keys(): + lang_losses[lang] = torch.stack(lang_losses[lang]).sum() + dist.all_reduce( + lang_losses[lang], group=self.parallel_context.dp_pg, op=dist.ReduceOp.AVG 
+ ) # TODO Estas averages dan enormes porque debe de hacer el average con un solo valor!!!!!!!! Debe de set loss per batch o asi no? Sino meter en el outputs de arriba coger el "loss" y comparar a mano vamos... + dist.all_reduce(global_loss_avg, group=self.parallel_context.dp_pg, op=dist.ReduceOp.AVG) + else: + global_loss_avg = None + lang_losses = None + + return global_loss_avg, lang_losses def train_step_logs( self, @@ -561,7 +614,7 @@ def train_step_logs( # TODO @nouamanetazi: Megatron-LM seems to be using a barrier to report their interval time. Check if this is necessary. https://github.com/NouamaneTazi/Megatron-LM/blob/e241a96c3085b18e36c6cee1d68a8155de77b5a6/megatron/training.py#L607 dist.barrier() torch.cuda.synchronize() - elapsed_time_per_iteration_ms = (time.time() - self.iteration_start_time) * 1000 + elapsed_time_per_iteration_ms = (self.training_step_time - self.iteration_start_time) * 1000 tokens_per_sec = ( self.global_batch_size * self.sequence_length / (elapsed_time_per_iteration_ms / 1000) ) # tokens_per_sec is calculated using sequence_length @@ -641,6 +694,68 @@ def train_step_logs( else: exit(0) + def validation_step_logs( + self, + global_loss: torch.Tensor, + lang_losses: torch.Tensor, + ) -> None: + # TODO @nouamanetazi: Megatron-LM seems to be using a barrier to report their interval time. Check if this is necessary. https://github.com/NouamaneTazi/Megatron-LM/blob/e241a96c3085b18e36c6cee1d68a8155de77b5a6/megatron/training.py#L607 + dist.barrier() + torch.cuda.synchronize() + total_validation_samples = self.current_validation_dataloader_lenght * self.micro_batch_size + elapsed_time_per_iteration_ms = (self.validation_step_time - self.training_step_time) * 1000 + tokens_per_sec = ( + total_validation_samples * self.sequence_length / (elapsed_time_per_iteration_ms / 1000) + ) # tokens_per_sec is calculated using sequence_length + # TODO para el valid ojo con cambiar global_batch_size = len dataloader * mbs + model_tflops, hardware_tflops = self.unwrapped_model.get_flops_per_sec( + iteration_time_in_sec=elapsed_time_per_iteration_ms / 1000, + sequence_length=self.sequence_length, + global_batch_size=total_validation_samples, # TODO con esto de la global batch size yo la pondria a 1 y multiplicaba por el numero de batches + ) + + if dist.get_rank(self.parallel_context.world_pg) in self.logger_ranks: + assert self.loggerwriter is not None, "loggerwriter should be defined on logger ranks" + + log_entries = [ + # LogItem("consumed_samples", self.consumed_train_samples, "human_format"), # , "12d"), + LogItem( + "validation_consumed_tokens", + self.metadata.consumed_train_samples * self.config.tokens.sequence_length, + "human_format", + ), # , "12d"), + LogItem( + "validation_elapsed_time_per_iteration_ms", elapsed_time_per_iteration_ms, "human_format" + ), # , ".1f"), + LogItem("validation_tokens_per_sec", tokens_per_sec, "human_format"), # , "1.6E"), + LogItem( + "validation_tokens_per_sec_per_gpu", + tokens_per_sec / self.parallel_context.world_pg.size(), + "human_format", + ), # , "1.6E"), + LogItem("validation_loss", global_loss.item(), "human_format"), # , "1.6E"), + LogItem("validation_model_tflops_per_gpu", model_tflops, "human_format"), # , ".2f"), + LogItem("validation_hardware_tflops_per_gpu", hardware_tflops, "human_format"), # , ".2f"), + ] + + # NOTE Currently you have to log each lang metric one by one and then merge them manually in the same plot through the wandb UI. 
+ # Example: https://community.wandb.ai/t/log-multiple-variables-at-the-same-plot/2474 + # GitHub complains: https://github.com/wandb/wandb/issues/3035 + log_entries.extend( + [LogItem(f"{lang}_validation_loss", loss.item(), "human_format") for lang, loss in lang_losses.items()] + ) + + # NOTE: only one rank writes to wandb + if dist.get_rank(self.parallel_context.world_pg) == self.logger_ranks[0] and wandb is not None: + wandb.log( + { + **{log_item.tag: log_item.scalar_value for log_item in log_entries}, + "iteration_step": self.iteration_step, + } + ) + + self.loggerwriter.add_scalars_from_list(log_entries, self.iteration_step) + def init_model(self) -> Union[NanotronModel, DistributedDataParallel]: """Initialize the model and load weights from checkpoint if needed.""" # TODO: add max_position_embeddings From 5c09e11a1a1df814b5edccb1ba6ac0a026897d04 Mon Sep 17 00:00:00 2001 From: tj-solergibert Date: Tue, 23 Jul 2024 16:01:09 +0000 Subject: [PATCH 2/8] just in case --- run_train.py | 3 +- src/nanotron/data/dataloader_builder.py | 3 +- .../parallel/pipeline_parallel/engine.py | 4 +- .../parallel/pipeline_parallel/state.py | 4 + src/nanotron/serialize/metadata.py | 2 + src/nanotron/trainer.py | 171 ++++++++++++++---- 6 files changed, 148 insertions(+), 39 deletions(-) diff --git a/run_train.py b/run_train.py index ed9b5607..80c7a426 100644 --- a/run_train.py +++ b/run_train.py @@ -256,6 +256,7 @@ def get_valid_dataloader_from_data_stage( micro_batch_size=trainer.micro_batch_size, dataloader_num_workers=data.num_loading_workers, dataloader_drop_last=True, + shuffle=True, ) return valid_dataloader @@ -315,7 +316,7 @@ def get_valid_dataloader(trainer: DistributedTrainer) -> Dict[str, DataLoader]: stage = cast(DatasetStageArgs, stage) log_rank( - f"[Validation Plan] Stage {stage.name} has {len(stage.data.dataset.validation_folder)} folders with samples in the validation set", + f"[Validation Plan] Stage {stage.name} has {len(stage.data.dataset.validation_folder)} folders with samples for the validation set", logger=logger, level=logging.INFO, rank=0, diff --git a/src/nanotron/data/dataloader_builder.py b/src/nanotron/data/dataloader_builder.py index 9d3285f6..b8bfb303 100644 --- a/src/nanotron/data/dataloader_builder.py +++ b/src/nanotron/data/dataloader_builder.py @@ -23,6 +23,7 @@ def build_nanoset_dataloader( consumed_train_samples: int = 0, dataloader_drop_last: bool = True, dataloader_pin_memory: bool = True, + shuffle: bool = False, ) -> DataLoader: # Case of ranks not requiring data. 
We give them a dummy dataset, then the collator will do his job @@ -49,7 +50,7 @@ def build_nanoset_dataloader( dl_rank=dp_rank, drop_last=dataloader_drop_last, consumed_train_samples=consumed_train_samples, - shuffle=False, + shuffle=shuffle, ) return DataLoader( diff --git a/src/nanotron/parallel/pipeline_parallel/engine.py b/src/nanotron/parallel/pipeline_parallel/engine.py index bf690bd0..bc6dc5b5 100644 --- a/src/nanotron/parallel/pipeline_parallel/engine.py +++ b/src/nanotron/parallel/pipeline_parallel/engine.py @@ -12,7 +12,7 @@ from nanotron.optim.gradient_accumulator import GradientAccumulator from nanotron.parallel.data_parallel.utils import ddp_trigger_sync_in_bwd from nanotron.parallel.pipeline_parallel.context_manager import attach_pipeline_state_to_model -from nanotron.parallel.pipeline_parallel.state import PipelineTrainBatchState +from nanotron.parallel.pipeline_parallel.state import PipelineEvalBatchState, PipelineTrainBatchState from nanotron.parallel.pipeline_parallel.tensor_pointer import TensorPointer from nanotron.utils import ContextManagers @@ -136,7 +136,7 @@ def validate_batch_iter( nb_microbatches: int, ) -> Iterable[Dict[str, Union[torch.Tensor, TensorPointer]]]: # Assign a new state for the current batch - state = PipelineTrainBatchState() # TODO: do i need state? + state = PipelineEvalBatchState() # PipelineTrainBatchState() # TODO: do i need state? self.nb_microbatches = nb_microbatches outputs = [] diff --git a/src/nanotron/parallel/pipeline_parallel/state.py b/src/nanotron/parallel/pipeline_parallel/state.py index e07cc89a..f22d6571 100644 --- a/src/nanotron/parallel/pipeline_parallel/state.py +++ b/src/nanotron/parallel/pipeline_parallel/state.py @@ -4,6 +4,7 @@ from typing import List import torch + from nanotron import distributed as dist from nanotron import logging from nanotron.logging import log_rank @@ -203,6 +204,9 @@ class PipelineEvalBatchState(PipelineBatchState): microbatches_activations_to_recv = collections.deque() activations_buffer = collections.deque() + # Reinitialise counter + nb_forwards = 0 + def register_activation_requiring_backward(self, activation: torch.Tensor): pass diff --git a/src/nanotron/serialize/metadata.py b/src/nanotron/serialize/metadata.py index 0d8708f9..4bd36c19 100644 --- a/src/nanotron/serialize/metadata.py +++ b/src/nanotron/serialize/metadata.py @@ -46,6 +46,8 @@ class TrainingMetadata: last_stage_idx: Optional[int] = None data_stages: Optional[List[DataStageMetadata]] = None + last_validation_stage_idx: Optional[int] = None + def __post_init__(self): # NOTE: this is a sanity check after loading a trained checkpoint total_consumed_samples_across_stages = sum(stage.consumed_train_samples for stage in self.data_stages) diff --git a/src/nanotron/trainer.py b/src/nanotron/trainer.py index 583068cd..b1cc36ad 100644 --- a/src/nanotron/trainer.py +++ b/src/nanotron/trainer.py @@ -231,7 +231,11 @@ def __init__( for stage in self.config.data_stages ] self.metadata: TrainingMetadata = TrainingMetadata( - consumed_train_samples=0, last_train_step=0, last_stage_idx=0, data_stages=data_stages + consumed_train_samples=0, + last_train_step=0, + last_stage_idx=0, + data_stages=data_stages, + last_validation_stage_idx=0, ) # Setup tensorboard write and log writers on output rank @@ -253,6 +257,8 @@ def __init__( self.limit_val_batches = self.config.tokens.limit_val_batches # NOTE: the dataloader currently in use for the current training stage self.current_dataloader: Optional[DataLoader] = None + # NOTE: the dataloader currently 
in use for the current validation stage + self.current_validation_dataloader: Optional[DataLoader] = None self.post_init() @@ -300,9 +306,108 @@ def _print_training_plan(self): ) log_rank(full_log_message, logger=logger, level=logging.INFO, rank=0) - def _update_dataloader_based_on_training_stages( - self, dataloaders: Union[List[DataLoader], DataLoader], is_validation: bool = False - ): + def _prepare_dataloader_for_validation_stage(self, dataloaders: Union[List[DataLoader], DataLoader]): + # NOTE(tj.solergibert) Similar to _update_dataloader_based_on_training_stages BUT: + # 1. We call this function EVERY TIME we run the validation loop + # 2. Every time it returns a NEW validation iterator DataLoader. If you don't do this you'll consume the whole validation dataset + # in the first iteration and subsequent validations will fail + # TODO(tj.solergibert) Delete previous DataLoaders from memory like we do with training DataLoaders + # TODO(tj.solergibert) Check the tuple case below + from collections.abc import Generator + + if not hasattr(self.config, "data_stages") or self.config.data_stages is None: + + if isinstance(dataloaders, tuple): # TODO(tj.solergibert) Check this tuple case + dataloader = dataloaders[0] + else: + dataloader = dataloaders + + self.current_validation_dataloader_lenght = len(dataloader) + self.current_validation_dataloader = sanity_check_dataloader( + dataloader=dataloader, parallel_context=self.parallel_context, config=self.config + ) + + return + elif isinstance(dataloaders, Generator): + # TODO(xrsrke): this is a hacky way to handle DoReMi's dataloader + # remove this in the next PR + self.current_validation_dataloader = dataloaders + return + + assert len(dataloaders) > 0, "No dataloaders provided" + assert len(dataloaders) == len( + self.config.data_stages + ), "Number of dataloaders should match the number of dataset stages" + + def clear_dataloader_from_memory(dataloader: DataLoader, stage_name: str): + import gc + + log_rank( + f"[Validation Stage: {stage_name}] Clearing the previous validation stage's dataloader and dataset from memory", + logger=logger, + level=logging.INFO, + ) + + # NOTE: Clear dataloader from memory + del dataloader.dataset + del dataloader.sampler + del dataloader.batch_sampler + + gc.collect() + + dataloader = None + + for stage_idx, stage in enumerate(self.config.data_stages): + if stage_idx < self.metadata.last_stage_idx: + continue + + if ( + stage_idx is not self.metadata.last_validation_stage_idx + and self.metadata.last_validation_stage_idx is not None + ): + self.metadata.last_validation_stage_idx = stage_idx # Update validation stage index + # Si cambiamos de stage borramo el antiguo + # En ambos casos recrear el que toca !!! + # TODO Aqui nos quedamos!!! Tenemos que borrar el anterior dataloader cuando sea necesario y hacer el sanity del current dataloader SIEMPRE + stage = cast(DatasetStageArgs, stage) + print( + stage.name + ) # TODO como actualizamos el last stage index en el training aqui estamos mirando el dataloader de la siguiente iteracion que mal por dios!!!!! + + log_rank( + f"Ese print bueno {stage.name}", + logger=logger, + level=logging.INFO, + rank=0, + ) + # self.metadata.last_stage_idx = stage_idx + """ + if self.current_validation_dataloader is not None: # TODO Si hay algun dataloader ya lo eliminamos. Igualmente creamos de nuevo. 
Bueno el dataloader como tal ya esta creado, solo hay que devolver el sanity check raro + prev_stage_name = self.config.data_stages[stage_idx - 1].name + prev_dataloader = dataloaders[prev_stage_name] + + if isinstance(prev_dataloader, DataLoader): + # NOTE: we don't need to clear dummy data generator from memory + clear_dataloader_from_memory(prev_dataloader, stage_name=stage.name) + """ + log_rank( + f"Preparing validation DataLoader from stage {stage.name}", + logger=logger, + level=logging.INFO, + rank=0, + ) + + dataloader = dataloaders[stage.name] + # NOTE: if a dataloader is lazy initialized, we need to call it to initialize it + dataloader = dataloader() if callable(dataloader) else dataloader + break + + self.current_validation_dataloader_lenght = 200 # TODO len(dataloader) + self.current_validation_dataloader = sanity_check_dataloader( + dataloader=dataloader, parallel_context=self.parallel_context, config=self.config + ) + + def _update_dataloader_based_on_training_stages(self, dataloaders: Union[List[DataLoader], DataLoader]): from collections.abc import Generator if not hasattr(self.config, "data_stages") or self.config.data_stages is None: @@ -311,16 +416,9 @@ def _update_dataloader_based_on_training_stages( dataloader = dataloaders[0] else: dataloader = dataloaders - - if is_validation: - self.current_validation_dataloader_lenght = len(dataloader) - self.current_validation_dataloader = sanity_check_dataloader( - dataloader=dataloader, parallel_context=self.parallel_context, config=self.config - ) - else: - self.current_dataloader = sanity_check_dataloader( - dataloader=dataloader, parallel_context=self.parallel_context, config=self.config - ) + self.current_dataloader = sanity_check_dataloader( + dataloader=dataloader, parallel_context=self.parallel_context, config=self.config + ) return elif isinstance(dataloaders, Generator): # TODO(xrsrke): this is a hacky way to handle DoReMi's dataloader @@ -337,7 +435,7 @@ def clear_dataloader_from_memory(dataloader: DataLoader, stage_name: str): import gc log_rank( - f"[{'Validation' if is_validation else 'Training'} Stage: {stage_name}] Clearing the previous {'validation' if is_validation else 'training'} stage's dataloader and datasets from memory", + f"[Training Stage: {stage_name}] Clearing the previous training stage's dataloader and datasets from memory", logger=logger, level=logging.INFO, ) @@ -360,7 +458,7 @@ def find_stage_idx_to_resume(): stage_idx_to_resume = find_stage_idx_to_resume() - for stage_idx, stage in enumerate(self.config.data_stages): + for stage_idx, stage in enumerate(self.config.data_stages): # TODO check metadatalaststageindex init if stage_idx < self.metadata.last_stage_idx: continue @@ -378,7 +476,7 @@ def find_stage_idx_to_resume(): self.metadata.last_stage_idx = stage_idx - if is_resume_from_training and not is_validation: + if is_resume_from_training: remaining_train_steps = compute_remain_train_steps_of_a_data_stage_from_ckp( stage, self.config, self.metadata ) @@ -396,15 +494,9 @@ def find_stage_idx_to_resume(): break if dataloader is not None: - if is_validation: - self.current_validation_dataloader_lenght = len(dataloader) - self.current_validation_dataloader = sanity_check_dataloader( - dataloader=dataloader, parallel_context=self.parallel_context, config=self.config - ) - else: - self.current_dataloader = sanity_check_dataloader( - dataloader=dataloader, parallel_context=self.parallel_context, config=self.config - ) + self.current_dataloader = sanity_check_dataloader( + dataloader=dataloader, 
parallel_context=self.parallel_context, config=self.config + ) def train( self, @@ -443,7 +535,6 @@ def train( self.iteration_start_time = time.time() self._update_dataloader_based_on_training_stages(train_dataloader_or_dls) - self._update_dataloader_based_on_training_stages(valid_dataloader_or_dls, is_validation=True) # Training step outputs, loss_avg = self.training_step(dataloader=self.current_dataloader) @@ -457,9 +548,18 @@ def train( # Cada validation es mucho mas largo que un training step # Puede que el len valid dataloader de el numero de valid batches por lo que con eso y la batch size podemos tirar millas if self.iteration_step % self.config.tokens.val_check_interval == 0: - global_loss, lang_losses = self.validation_step(dataloader=self.current_validation_dataloader) + log_rank( + f"KOMO???? {self.iteration_step}", + logger=logger, + level=logging.INFO, + rank=0, + ) + self._prepare_dataloader_for_validation_stage(valid_dataloader_or_dls) + val_global_loss, val_lang_losses = self.validation_step( + dataloader=self.current_validation_dataloader + ) self.validation_step_time = time.time() - self.validation_step_logs(global_loss, lang_losses) + self.validation_step_logs(val_global_loss, val_lang_losses) # Training Logs # TODO(xrsrke): refactor using callbacks would be better @@ -592,10 +692,10 @@ def validation_step(self, dataloader: Iterator[Dict[str, Union[torch.Tensor, Ten self.config.data_stages[self.metadata.last_stage_idx].data.dataset.ids_to_lang[lang_id] ].append(loss) # Global loss - global_loss_avg = torch.stack(outputs).sum() + global_loss_avg = torch.mean(torch.stack(outputs)) # Sync losses across DP for lang in lang_losses.keys(): - lang_losses[lang] = torch.stack(lang_losses[lang]).sum() + lang_losses[lang] = torch.mean(torch.stack(lang_losses[lang])) dist.all_reduce( lang_losses[lang], group=self.parallel_context.dp_pg, op=dist.ReduceOp.AVG ) # TODO Estas averages dan enormes porque debe de hacer el average con un solo valor!!!!!!!! Debe de set loss per batch o asi no? Sino meter en el outputs de arriba coger el "loss" y comparar a mano vamos... @@ -630,7 +730,6 @@ def train_step_logs( lr = self.lr_scheduler.get_last_lr()[0] log_entries = [ - # LogItem("consumed_samples", self.consumed_train_samples, "human_format"), # , "12d"), LogItem( "consumed_tokens", self.metadata.consumed_train_samples * self.config.tokens.sequence_length, @@ -718,7 +817,6 @@ def validation_step_logs( assert self.loggerwriter is not None, "loggerwriter should be defined on logger ranks" log_entries = [ - # LogItem("consumed_samples", self.consumed_train_samples, "human_format"), # , "12d"), LogItem( "validation_consumed_tokens", self.metadata.consumed_train_samples * self.config.tokens.sequence_length, @@ -734,7 +832,7 @@ def validation_step_logs( "human_format", ), # , "1.6E"), LogItem("validation_loss", global_loss.item(), "human_format"), # , "1.6E"), - LogItem("validation_model_tflops_per_gpu", model_tflops, "human_format"), # , ".2f"), + LogItem("validation_model_tflops_per_gpu", model_tflops / 3, "human_format"), # , ".2f"), LogItem("validation_hardware_tflops_per_gpu", hardware_tflops, "human_format"), # , ".2f"), ] @@ -746,12 +844,15 @@ def validation_step_logs( ) # NOTE: only one rank writes to wandb + # NOTE(tj.solergibert) By default wandb.log performs a step in the x-axis every time. 
+ # Set commit=False to log values with the next wandb.log with the training logs if dist.get_rank(self.parallel_context.world_pg) == self.logger_ranks[0] and wandb is not None: wandb.log( { **{log_item.tag: log_item.scalar_value for log_item in log_entries}, "iteration_step": self.iteration_step, - } + }, + commit=False, ) self.loggerwriter.add_scalars_from_list(log_entries, self.iteration_step) From 94d6c2a9931cf735366be9b356a01270e657a9a1 Mon Sep 17 00:00:00 2001 From: tj-solergibert Date: Wed, 24 Jul 2024 16:18:36 +0000 Subject: [PATCH 3/8] This looks good --- examples/config_multilingual_nanoset.yaml | 60 +++++----- run_train.py | 9 +- src/nanotron/models/llama.py | 7 +- .../parallel/pipeline_parallel/engine.py | 16 +-- src/nanotron/trainer.py | 103 +++++++++--------- 5 files changed, 99 insertions(+), 96 deletions(-) diff --git a/examples/config_multilingual_nanoset.yaml b/examples/config_multilingual_nanoset.yaml index 33f9db41..5573a224 100644 --- a/examples/config_multilingual_nanoset.yaml +++ b/examples/config_multilingual_nanoset.yaml @@ -1,5 +1,5 @@ checkpoints: - checkpoint_interval: 1000 + checkpoint_interval: 1000000 checkpoints_path: checkpoints/ checkpoints_path_is_shared_file_system: false resume_checkpoint_path: null @@ -7,38 +7,40 @@ checkpoints: data_stages: - data: dataset: - training_folder: datasets/c4-es/train - validation_folder: datasets/c4-es/validation + training_folder: + datasets/c4-es/train: 0.85 + datasets/c4-en/train: 0.05 + datasets/c4-fr/train: 0.1 + validation_folder: + - datasets/c4-es/validation + - datasets/c4-en/validation + - datasets/c4-fr/validation lang_to_ids: es: 128002 + en: 128003 + fr: 128004 num_loading_workers: 1 seed: 42 - name: General purpose training (Single dataset) + name: General purpose training (Blended dataset) start_training_step: 1 - data: dataset: training_folder: - datasets/c4-es/train - - datasets/c4-en/train - - datasets/c4-fr/train validation_folder: - datasets/c4-es/validation - - datasets/c4-en/validation - - datasets/c4-fr/validation lang_to_ids: es: 128002 - en: 128003 - fr: 128004 num_loading_workers: 1 seed: 42 - name: Second purpose training (> 1 dataset) - start_training_step: 15 + name: Second purpose training (Single dataset) + start_training_step: 100 - data: dataset: training_folder: - datasets/c4-es/train: 0.6 - datasets/c4-en/train: 0.3 - datasets/c4-fr/train: 0.1 + - datasets/c4-es/train + - datasets/c4-en/train + - datasets/c4-fr/train validation_folder: - datasets/c4-es/validation - datasets/c4-en/validation @@ -50,13 +52,13 @@ data_stages: num_loading_workers: 1 seed: 42 - name: Third purpose training (Blended dataset) - start_training_step: 25 + name: Third purpose training (>1 dataset) + start_training_step: 200 general: benchmark_csv_path: null consumed_train_samples: null ignore_sanity_checks: true - project: Nanoset + project: Multilingual run: llama seed: 42 step: null @@ -75,12 +77,12 @@ model: bos_token_id: 1 eos_token_id: 2 hidden_act: silu - hidden_size: 512 + hidden_size: 4096 initializer_range: 0.02 - intermediate_size: 512 + intermediate_size: 14336 is_llama_config: true - max_position_embeddings: 1024 - num_hidden_layers: 2 + max_position_embeddings: 4096 + num_hidden_layers: 32 num_attention_heads: 32 num_key_value_heads: 8 pad_token_id: null @@ -89,7 +91,7 @@ model: rope_theta: 500000.0 rms_norm_eps: 1.0e-06 rope_scaling: null - tie_word_embeddings: true + tie_word_embeddings: false use_cache: true vocab_size: 128256 optimizer: @@ -116,19 +118,19 @@ parallelism: expert_parallel_size: 1 
pp: 1 pp_engine: 1f1b - tp: 1 + tp: 4 tp_linear_async_communication: false tp_mode: REDUCE_SCATTER profiler: null tokenizer: tokenizer_max_length: null - tokenizer_name_or_path: meta-llama/Meta-Llama-3-8B + tokenizer_name_or_path: /mloscratch/homes/solergib/models/Meta-Llama-3-8B tokenizer_revision: null tokens: batch_accumulation_per_replica: 1 limit_test_batches: 0 limit_val_batches: 10 - micro_batch_size: 4 - sequence_length: 1024 - train_steps: 200 - val_check_interval: 3 + micro_batch_size: 3 + sequence_length: 4096 + train_steps: 800 + val_check_interval: 50 diff --git a/run_train.py b/run_train.py index 80c7a426..2ddff5ad 100644 --- a/run_train.py +++ b/run_train.py @@ -325,8 +325,15 @@ def get_valid_dataloader(trainer: DistributedTrainer) -> Dict[str, DataLoader]: dataloader = ( get_valid_dataloader_from_data_stage(trainer, stage.data) if stage_idx == 0 - else lambda stage=stage: get_dataloader_from_data_stage(trainer, stage.data) + else lambda stage=stage: get_valid_dataloader_from_data_stage(trainer, stage.data) ) + # TODO(tj.solergibert) As we are creating again the valid dataloader in every validation stage, we print multiple times + # the validation MultilingualNanoset info (Number of samples, etc.) [UPDATE: ]. In order to solve that, we could get rid of this lambda + # funcs and directly create all dataloaders. + # + # This lambda functs (Used in training too) are for creating the DataLoaders lazyly FOR 1. Start training faster instead of creating multiple DataLoaders + # 2. Consume less memory as the lambda func is lighter that the DataLoader object with the Dataset, collator, etc. + # BUT 1. The Nanoset creation process is very fast and 2. Nanosets doesn't consume any memory at all till we start sampling from the Nanoset dataloaders[stage.name] = dataloader return dataloaders diff --git a/src/nanotron/models/llama.py b/src/nanotron/models/llama.py index 7ae34dd5..133442af 100644 --- a/src/nanotron/models/llama.py +++ b/src/nanotron/models/llama.py @@ -803,12 +803,7 @@ def masked_mean(loss, label_mask, dtype): # type: (Tensor, Tensor, torch.dtype) -> Tensor return (loss * label_mask).sum(dim=1, dtype=dtype) / label_mask.sum( dim=1 - ) # TODO esto de entrada da float/float = float - - -# TODO la loss de cada uno !!!! ((loss * label_mask).sum(dim=1, dtype=dtype) / label_mask.sum(dim=1)) -# Y pasa el assert close!! -# assert_close(((loss * label_mask).sum(dtype=dtype) / label_mask.sum()), torch.mean((loss * label_mask).sum(dim=1, dtype=dtype) / label_mask.sum(dim=1))) + ) # NOTE(tj.solergibert) Added dim=1 to return a tensor with shape [Batch size, 1] instead of [1] class Loss(nn.Module): diff --git a/src/nanotron/parallel/pipeline_parallel/engine.py b/src/nanotron/parallel/pipeline_parallel/engine.py index bc6dc5b5..549ef5eb 100644 --- a/src/nanotron/parallel/pipeline_parallel/engine.py +++ b/src/nanotron/parallel/pipeline_parallel/engine.py @@ -136,7 +136,7 @@ def validate_batch_iter( nb_microbatches: int, ) -> Iterable[Dict[str, Union[torch.Tensor, TensorPointer]]]: # Assign a new state for the current batch - state = PipelineEvalBatchState() # PipelineTrainBatchState() # TODO: do i need state? 
+ state = PipelineEvalBatchState() self.nb_microbatches = nb_microbatches outputs = [] @@ -156,21 +156,17 @@ def validate_batch_iter( send_activation() # We make `output` a dict - # TODO convert to dict other items returned by the model (MoE aux loss for example) - # But in next if statement be careful if we return other items in all of the pp processes - # This conversion to dicts is kind of useless as the model already returns a dict with loss key. Maybe the PP ranks return TensorPointer Objects? if not isinstance(output, dict): output = {"loss": output} # Store the loss for each microbatch if not isinstance(output["loss"], TensorPointer): output = {k: v.detach() for k, v in output.items()} - # TODO ver este output que es y tambien ver en outputs como se guarda. Donde se have la media? En el training step lol - # Aqui deberiamos segregar por languagues porque es el unico punto en el que tenemos la languague!! O al menos "etiquetarla" o acumularla por language - # 1. Hacemos dict con key para cada idioma 2. cada key tiene una lista donde append los tensors 3. en valid step hacemos lo del stack y allreduces - # Finalmente: Aqui metemos solo el lang ids, en trainer.py acumularemos los resultados y tal. - outputs.extend(list(output["sample_loss"])) # TODO flatten?????? o extend?????? - lang_ids.extend(micro_batch["input_ids"][:, 0].tolist()) # TODO esto deberia se un extend???? + + outputs.extend( + list(output["sample_loss"]) + ) # NOTE(tj.solergibert) Yes, it might look useless to do list + extend but it's necessary to split the output["sample_loss"] tensor into multiple tensors + lang_ids.extend(micro_batch["input_ids"][:, 0].tolist()) return outputs, lang_ids diff --git a/src/nanotron/trainer.py b/src/nanotron/trainer.py index b1cc36ad..f720446a 100644 --- a/src/nanotron/trainer.py +++ b/src/nanotron/trainer.py @@ -311,7 +311,25 @@ def _prepare_dataloader_for_validation_stage(self, dataloaders: Union[List[DataL # 1. We call this function EVERY TIME we run the validation loop # 2. Every time it returns a NEW validation iterator DataLoader. If you don't do this you'll consume the whole validation dataset # in the first iteration and subsequent validations will fail - # TODO(tj.solergibert) Delete previous DataLoaders from memory like we do with training DataLoaders + # `dataloaders` are either torch DataLoaders (the very first stage) OR functions that we call later that provide torch DataLoaders (subsequent stages) + # From this torch DataLoaders objects we then call `sanity_check_dataloader` that will return a iterator. + # In short, `sanity_check_dataloader` just places the input tensors in the GPU when necessary (TensorPointers stay in the CPU) + # + # TBH, the for loop below it's just for deleting the DataLoaders of previous stages, which is not so problematic. The important part is returning the + # DataLoader iterator every time we call this function from the current training stage, which is tracked during training + # + # Also, keep in mind that if val_check_interval = 5 & data.start_training_step = 10 we will already perform the evaluation with the SECOND data stage + # after just training for the current iteration, so it might not be a good idea to set evals during the stage in which we change of data stage + # + # NOTE(tj.solergibert) Further investigation should be done, but there is a extrange behaiviour when deleting the DataLoaders////lambda functs. 
As they + # are converted into Iterators with `sanity_check_dataloader` we can't access anymore the DataLoader object to del the dataset (After first stage, + # in this function we locally create the DataLoder from the lambda func --> Return Iterator) + # + # Also when the gc deletes the first stage dataloader, all the `DatatroveFileDataset._f` are already None AND the `del` thing are deleting a copy of the + # object, not the object itself + # + # FINAL NOTE(tj.solergibert) I will open a Issue in nanotron to check with them if they are aware of this useless deletitions + # # TODO(tj.solergibert) Check the tuple case below from collections.abc import Generator @@ -339,11 +357,11 @@ def _prepare_dataloader_for_validation_stage(self, dataloaders: Union[List[DataL self.config.data_stages ), "Number of dataloaders should match the number of dataset stages" - def clear_dataloader_from_memory(dataloader: DataLoader, stage_name: str): + def clear_dataloader_from_memory(dataloader: DataLoader, stage_name: str, prev_stage_name: str): import gc log_rank( - f"[Validation Stage: {stage_name}] Clearing the previous validation stage's dataloader and dataset from memory", + f"[Validation Stage: {stage_name}] Clearing the previous validation stage's ({prev_stage_name}) dataloader and dataset from memory", logger=logger, level=logging.INFO, ) @@ -355,57 +373,38 @@ def clear_dataloader_from_memory(dataloader: DataLoader, stage_name: str): gc.collect() - dataloader = None - for stage_idx, stage in enumerate(self.config.data_stages): if stage_idx < self.metadata.last_stage_idx: continue + # NOTE(tj.solergibert) From this point stage_idx = self.metadata.last_stage_idx. We update self.metadata.last_stage_idx (which keeps track of the training stage) + # in each and every training step. if ( stage_idx is not self.metadata.last_validation_stage_idx - and self.metadata.last_validation_stage_idx is not None - ): + ): # When stage_idx (= self.metadata.last_stage_idx, the training stage index) is different than the last validation stage index self.metadata.last_validation_stage_idx = stage_idx # Update validation stage index - # Si cambiamos de stage borramo el antiguo - # En ambos casos recrear el que toca !!! - # TODO Aqui nos quedamos!!! Tenemos que borrar el anterior dataloader cuando sea necesario y hacer el sanity del current dataloader SIEMPRE - stage = cast(DatasetStageArgs, stage) - print( - stage.name - ) # TODO como actualizamos el last stage index en el training aqui estamos mirando el dataloader de la siguiente iteracion que mal por dios!!!!! - - log_rank( - f"Ese print bueno {stage.name}", - logger=logger, - level=logging.INFO, - rank=0, - ) - # self.metadata.last_stage_idx = stage_idx - """ - if self.current_validation_dataloader is not None: # TODO Si hay algun dataloader ya lo eliminamos. Igualmente creamos de nuevo. 
Bueno el dataloader como tal ya esta creado, solo hay que devolver el sanity check raro + # Delete previous stage DataLoader prev_stage_name = self.config.data_stages[stage_idx - 1].name prev_dataloader = dataloaders[prev_stage_name] - if isinstance(prev_dataloader, DataLoader): - # NOTE: we don't need to clear dummy data generator from memory - clear_dataloader_from_memory(prev_dataloader, stage_name=stage.name) - """ - log_rank( - f"Preparing validation DataLoader from stage {stage.name}", - logger=logger, - level=logging.INFO, - rank=0, - ) + if isinstance(prev_dataloader, DataLoader): + # NOTE: we don't need to clear dummy data generator from memory + clear_dataloader_from_memory( + prev_dataloader, stage_name=stage.name, prev_stage_name=prev_stage_name + ) + + self.metadata.last_validation_stage_idx = stage_idx # Update validation stage index + # NOTE(tj.solergibert) Create AGAIN the DataLoader dataloader = dataloaders[stage.name] # NOTE: if a dataloader is lazy initialized, we need to call it to initialize it dataloader = dataloader() if callable(dataloader) else dataloader break - self.current_validation_dataloader_lenght = 200 # TODO len(dataloader) + self.current_validation_dataloader_lenght = len(dataloader) self.current_validation_dataloader = sanity_check_dataloader( dataloader=dataloader, parallel_context=self.parallel_context, config=self.config - ) + ) # NOTE(tj.solergibert) Create a Iterator from the DataLoader def _update_dataloader_based_on_training_stages(self, dataloaders: Union[List[DataLoader], DataLoader]): from collections.abc import Generator @@ -431,11 +430,11 @@ def _update_dataloader_based_on_training_stages(self, dataloaders: Union[List[Da self.config.data_stages ), "Number of dataloaders should match the number of dataset stages" - def clear_dataloader_from_memory(dataloader: DataLoader, stage_name: str): + def clear_dataloader_from_memory(dataloader: DataLoader, stage_name: str, prev_stage_name: str): import gc log_rank( - f"[Training Stage: {stage_name}] Clearing the previous training stage's dataloader and datasets from memory", + f"[Training Stage: {stage_name}] Clearing the previous training stage's ({prev_stage_name}) dataloader and datasets from memory", logger=logger, level=logging.INFO, ) @@ -472,7 +471,9 @@ def find_stage_idx_to_resume(): if isinstance(prev_dataloader, DataLoader): # NOTE: we don't need to clear dummy data generator from memory - clear_dataloader_from_memory(prev_dataloader, stage_name=stage.name) + clear_dataloader_from_memory( + prev_dataloader, stage_name=stage.name, prev_stage_name=prev_stage_name + ) self.metadata.last_stage_idx = stage_idx @@ -548,18 +549,14 @@ def train( # Cada validation es mucho mas largo que un training step # Puede que el len valid dataloader de el numero de valid batches por lo que con eso y la batch size podemos tirar millas if self.iteration_step % self.config.tokens.val_check_interval == 0: - log_rank( - f"KOMO???? 
{self.iteration_step}", - logger=logger, - level=logging.INFO, - rank=0, - ) self._prepare_dataloader_for_validation_stage(valid_dataloader_or_dls) val_global_loss, val_lang_losses = self.validation_step( dataloader=self.current_validation_dataloader ) self.validation_step_time = time.time() - self.validation_step_logs(val_global_loss, val_lang_losses) + self.validation_step_logs( + val_global_loss, val_lang_losses + ) # TODO(tj.solergibert) Check what happens when val_check_interval % iteration_step_info_interval != 0 # Training Logs # TODO(xrsrke): refactor using callbacks would be better @@ -684,6 +681,14 @@ def validation_step(self, dataloader: Iterator[Dict[str, Union[torch.Tensor, Ten lang_losses = { lang: [] for lang in self.config.data_stages[self.metadata.last_stage_idx].data.dataset.lang_to_ids.keys() } + # WARNING(tj.solergibert) This mechanism will fail in the following [corner] case: + # If the lang_losses dict for a given lang IS EMPTY aka in the validation step in a Data Parallel Group + # we have 0 SAMPLES of a given lang, lang_losses[lang] will be a empty python list so the toch.stack call + # will fail with "stack expects a non-empty TensorList". I've tried setting this lang_losses[lang] to torch.empty + # but of course it doesn't works as we then do the average across the DP group. + # We will fix this issue in the future if we encounter this problem again. + # A bit of inspo https://blog.speechmatics.com/Sparse-All-Reduce-Part-1 + # Compute losses if isinstance(outputs[0], torch.Tensor): # Multilingual losses @@ -696,9 +701,7 @@ def validation_step(self, dataloader: Iterator[Dict[str, Union[torch.Tensor, Ten # Sync losses across DP for lang in lang_losses.keys(): lang_losses[lang] = torch.mean(torch.stack(lang_losses[lang])) - dist.all_reduce( - lang_losses[lang], group=self.parallel_context.dp_pg, op=dist.ReduceOp.AVG - ) # TODO Estas averages dan enormes porque debe de hacer el average con un solo valor!!!!!!!! Debe de set loss per batch o asi no? Sino meter en el outputs de arriba coger el "loss" y comparar a mano vamos... + dist.all_reduce(lang_losses[lang], group=self.parallel_context.dp_pg, op=dist.ReduceOp.AVG) dist.all_reduce(global_loss_avg, group=self.parallel_context.dp_pg, op=dist.ReduceOp.AVG) else: global_loss_avg = None @@ -833,7 +836,7 @@ def validation_step_logs( ), # , "1.6E"), LogItem("validation_loss", global_loss.item(), "human_format"), # , "1.6E"), LogItem("validation_model_tflops_per_gpu", model_tflops / 3, "human_format"), # , ".2f"), - LogItem("validation_hardware_tflops_per_gpu", hardware_tflops, "human_format"), # , ".2f"), + LogItem("validation_hardware_tflops_per_gpu", hardware_tflops / 3, "human_format"), # , ".2f"), ] # NOTE Currently you have to log each lang metric one by one and then merge them manually in the same plot through the wandb UI. 
From 5cccf16711e296caa517fb6619ae0dbd5d7ede75 Mon Sep 17 00:00:00 2001 From: tj-solergibert Date: Wed, 24 Jul 2024 16:47:54 +0000 Subject: [PATCH 4/8] This looks better --- run_train.py | 13 ++-- src/nanotron/models/llama.py | 4 +- src/nanotron/trainer.py | 142 ++++++++++++++++------------------- 3 files changed, 76 insertions(+), 83 deletions(-) diff --git a/run_train.py b/run_train.py index 2ddff5ad..a51caf59 100644 --- a/run_train.py +++ b/run_train.py @@ -238,10 +238,10 @@ def get_valid_dataloader_from_data_stage( with main_rank_first(trainer.parallel_context.world_pg): valid_dataset = MultilingualNanoset( - dataset_folders=data.dataset.validation_folder, # TODO Just 1 folder + dataset_folders=data.dataset.validation_folder, sequence_length=trainer.sequence_length, token_size=token_size, - dataset_tokens=data.dataset.dataset_tokens, # TODO Just 1 lang + dataset_tokens=data.dataset.dataset_tokens, is_valid=True, random_seed=data.seed, ) @@ -331,9 +331,12 @@ def get_valid_dataloader(trainer: DistributedTrainer) -> Dict[str, DataLoader]: # the validation MultilingualNanoset info (Number of samples, etc.) [UPDATE: ]. In order to solve that, we could get rid of this lambda # funcs and directly create all dataloaders. # - # This lambda functs (Used in training too) are for creating the DataLoaders lazyly FOR 1. Start training faster instead of creating multiple DataLoaders - # 2. Consume less memory as the lambda func is lighter that the DataLoader object with the Dataset, collator, etc. - # BUT 1. The Nanoset creation process is very fast and 2. Nanosets doesn't consume any memory at all till we start sampling from the Nanoset + # This lambda functs (Used in training too) are for creating the DataLoaders lazyly FOR 1. Start training faster instead + # of creating multiple DataLoaders 2. Consume less memory as the lambda func is lighter that the DataLoader object with + # the Dataset, collator, etc. + # BUT 1. The Nanoset creation process is very fast and 2. Nanosets doesn't consume any memory at all till we start sampling + # from the Nanoset. Also they later transform the DataLoader into a Iterator object so it's impossible to retrieve + # the DataLoader object again to delete it (More comments in trainer.py) dataloaders[stage.name] = dataloader return dataloaders diff --git a/src/nanotron/models/llama.py b/src/nanotron/models/llama.py index 133442af..8c4125b7 100644 --- a/src/nanotron/models/llama.py +++ b/src/nanotron/models/llama.py @@ -825,8 +825,10 @@ def forward( ).transpose(0, 1) # TODO @thomasw21: It's unclear what kind of normalization we want to do. sample_loss = masked_mean(sample_loss, label_mask, dtype=torch.float) - # NOTE @tj.solergibert: masked_mean returns a single scalar with the batch loss. We've changed it to compute the SAMPLE loss. + # NOTE(tj.solergibert) masked_mean returns a single scalar with the batch loss. We've changed it to compute the SAMPLE loss. # We will continue using "loss" as the batch loss but we add "sample_loss" for the multilingual effort. 
+ # WARN(tj.solergibert) Don't panic, the batch loss used to update the parameters is computed in `LlamaForTraining` + # TODO @thomasw21: I think indexing causes a sync we don't actually want # TODO @thomasw21: loss = loss[label_mask].sum() return {"sample_loss": sample_loss} diff --git a/src/nanotron/trainer.py b/src/nanotron/trainer.py index f720446a..c327f508 100644 --- a/src/nanotron/trainer.py +++ b/src/nanotron/trainer.py @@ -457,7 +457,7 @@ def find_stage_idx_to_resume(): stage_idx_to_resume = find_stage_idx_to_resume() - for stage_idx, stage in enumerate(self.config.data_stages): # TODO check metadatalaststageindex init + for stage_idx, stage in enumerate(self.config.data_stages): if stage_idx < self.metadata.last_stage_idx: continue @@ -541,22 +541,17 @@ def train( outputs, loss_avg = self.training_step(dataloader=self.current_dataloader) self.training_step_time = time.time() - # Validation step - # TODO A ver, en este loop solo se lleva a cabo una training iteration pero claro hay un porron de validation iteration... mmmmm - # Tal vez deberiamos mover esto a otro lugar? Es decir, aqui se have un training step pero hacemos varios validation steps - # Lo podemos dejar aqui solamente que las metricas de throughput y tokens consumidos se tendrian que revisar - # Porque actualmente utilizan la global batch size, que es correcta ya que es la que tiene cada training step pero claro, - # Cada validation es mucho mas largo que un training step - # Puede que el len valid dataloader de el numero de valid batches por lo que con eso y la batch size podemos tirar millas + # Validation stage if self.iteration_step % self.config.tokens.val_check_interval == 0: self._prepare_dataloader_for_validation_stage(valid_dataloader_or_dls) val_global_loss, val_lang_losses = self.validation_step( dataloader=self.current_validation_dataloader ) self.validation_step_time = time.time() - self.validation_step_logs( - val_global_loss, val_lang_losses - ) # TODO(tj.solergibert) Check what happens when val_check_interval % iteration_step_info_interval != 0 + else: + # NOTE(tj.solergibert) As we are reporting the training & validation metrics together, we + # must comply with val_check_interval % iteration_step_info_interval = 0 + val_global_loss, val_lang_losses = None, None # Training Logs # TODO(xrsrke): refactor using callbacks would be better @@ -567,7 +562,9 @@ def train( ].consumed_train_samples += self.global_batch_size if (self.iteration_step - 1) % self.config.logging.iteration_step_info_interval == 0: - self.train_step_logs(outputs=outputs, loss_avg=loss_avg) + self.train_step_logs( + outputs=outputs, loss_avg=loss_avg, global_loss=val_global_loss, lang_losses=val_lang_losses + ) # Checkpoint if self.iteration_step % self.config.checkpoints.checkpoint_interval == 0: @@ -711,12 +708,14 @@ def validation_step(self, dataloader: Iterator[Dict[str, Union[torch.Tensor, Ten def train_step_logs( self, - outputs: Iterable[Dict[str, Union[torch.Tensor, TensorPointer]]], loss_avg: Optional[torch.Tensor], + global_loss: torch.Tensor, + lang_losses: torch.Tensor, ) -> None: # TODO @nouamanetazi: Megatron-LM seems to be using a barrier to report their interval time. Check if this is necessary. 
https://github.com/NouamaneTazi/Megatron-LM/blob/e241a96c3085b18e36c6cee1d68a8155de77b5a6/megatron/training.py#L607 dist.barrier() torch.cuda.synchronize() + # Training metrics elapsed_time_per_iteration_ms = (self.training_step_time - self.iteration_start_time) * 1000 tokens_per_sec = ( self.global_batch_size * self.sequence_length / (elapsed_time_per_iteration_ms / 1000) @@ -727,9 +726,24 @@ def train_step_logs( global_batch_size=self.global_batch_size, ) + # Validation metrics + if global_loss is not None: + validation_total_samples = self.current_validation_dataloader_lenght * self.micro_batch_size + validation_elapsed_time_per_iteration_ms = (self.validation_step_time - self.training_step_time) * 1000 + validation_tokens_per_sec = ( + validation_total_samples * self.sequence_length / (validation_elapsed_time_per_iteration_ms / 1000) + ) # tokens_per_sec is calculated using sequence_length + + validation_model_tflops, validation_hardware_tflops = self.unwrapped_model.get_flops_per_sec( + iteration_time_in_sec=validation_elapsed_time_per_iteration_ms / 1000, + sequence_length=self.sequence_length, + global_batch_size=validation_total_samples, # TODO con esto de la global batch size yo la pondria a 1 y multiplicaba por el numero de batches + ) + if dist.get_rank(self.parallel_context.world_pg) in self.logger_ranks: assert self.loggerwriter is not None, "loggerwriter should be defined on logger ranks" + # Training metrics lr = self.lr_scheduler.get_last_lr()[0] log_entries = [ @@ -753,6 +767,44 @@ def train_step_logs( if self.config.optimizer.clip_grad is not None: log_entries.append(LogItem("grad_norm", self.grad_norm_unclipped.item(), "human_format")) # , ".3f")) + # Validation metrics + if global_loss is not None: + log_entries = [ + LogItem( + "validation_consumed_tokens", + self.metadata.consumed_train_samples * self.config.tokens.sequence_length, + "human_format", + ), # , "12d"), + LogItem( + "validation_elapsed_time_per_iteration_ms", + validation_elapsed_time_per_iteration_ms, + "human_format", + ), # , ".1f"), + LogItem("validation_tokens_per_sec", validation_tokens_per_sec, "human_format"), # , "1.6E"), + LogItem( + "validation_tokens_per_sec_per_gpu", + validation_tokens_per_sec / self.parallel_context.world_pg.size(), + "human_format", + ), # , "1.6E"), + LogItem("validation_loss", global_loss.item(), "human_format"), # , "1.6E"), + LogItem( + "validation_model_tflops_per_gpu", validation_model_tflops / 3, "human_format" + ), # , ".2f"), + LogItem( + "validation_hardware_tflops_per_gpu", validation_hardware_tflops / 3, "human_format" + ), # , ".2f"), + ] + + # NOTE Currently you have to log each lang metric one by one and then merge them manually in the same plot through the wandb UI. + # Example: https://community.wandb.ai/t/log-multiple-variables-at-the-same-plot/2474 + # GitHub complains: https://github.com/wandb/wandb/issues/3035 + log_entries.extend( + [ + LogItem(f"{lang}_validation_loss", loss.item(), "human_format") + for lang, loss in lang_losses.items() + ] + ) + # Log not too often the memory if self.iteration_step < 5 or (self.iteration_step - 1) % self.config.checkpoints.checkpoint_interval == 0: total, used, free = shutil.disk_usage("/") @@ -796,70 +848,6 @@ def train_step_logs( else: exit(0) - def validation_step_logs( - self, - global_loss: torch.Tensor, - lang_losses: torch.Tensor, - ) -> None: - # TODO @nouamanetazi: Megatron-LM seems to be using a barrier to report their interval time. Check if this is necessary. 
https://github.com/NouamaneTazi/Megatron-LM/blob/e241a96c3085b18e36c6cee1d68a8155de77b5a6/megatron/training.py#L607 - dist.barrier() - torch.cuda.synchronize() - total_validation_samples = self.current_validation_dataloader_lenght * self.micro_batch_size - elapsed_time_per_iteration_ms = (self.validation_step_time - self.training_step_time) * 1000 - tokens_per_sec = ( - total_validation_samples * self.sequence_length / (elapsed_time_per_iteration_ms / 1000) - ) # tokens_per_sec is calculated using sequence_length - # TODO para el valid ojo con cambiar global_batch_size = len dataloader * mbs - model_tflops, hardware_tflops = self.unwrapped_model.get_flops_per_sec( - iteration_time_in_sec=elapsed_time_per_iteration_ms / 1000, - sequence_length=self.sequence_length, - global_batch_size=total_validation_samples, # TODO con esto de la global batch size yo la pondria a 1 y multiplicaba por el numero de batches - ) - - if dist.get_rank(self.parallel_context.world_pg) in self.logger_ranks: - assert self.loggerwriter is not None, "loggerwriter should be defined on logger ranks" - - log_entries = [ - LogItem( - "validation_consumed_tokens", - self.metadata.consumed_train_samples * self.config.tokens.sequence_length, - "human_format", - ), # , "12d"), - LogItem( - "validation_elapsed_time_per_iteration_ms", elapsed_time_per_iteration_ms, "human_format" - ), # , ".1f"), - LogItem("validation_tokens_per_sec", tokens_per_sec, "human_format"), # , "1.6E"), - LogItem( - "validation_tokens_per_sec_per_gpu", - tokens_per_sec / self.parallel_context.world_pg.size(), - "human_format", - ), # , "1.6E"), - LogItem("validation_loss", global_loss.item(), "human_format"), # , "1.6E"), - LogItem("validation_model_tflops_per_gpu", model_tflops / 3, "human_format"), # , ".2f"), - LogItem("validation_hardware_tflops_per_gpu", hardware_tflops / 3, "human_format"), # , ".2f"), - ] - - # NOTE Currently you have to log each lang metric one by one and then merge them manually in the same plot through the wandb UI. - # Example: https://community.wandb.ai/t/log-multiple-variables-at-the-same-plot/2474 - # GitHub complains: https://github.com/wandb/wandb/issues/3035 - log_entries.extend( - [LogItem(f"{lang}_validation_loss", loss.item(), "human_format") for lang, loss in lang_losses.items()] - ) - - # NOTE: only one rank writes to wandb - # NOTE(tj.solergibert) By default wandb.log performs a step in the x-axis every time. 
- # Set commit=False to log values with the next wandb.log with the training logs - if dist.get_rank(self.parallel_context.world_pg) == self.logger_ranks[0] and wandb is not None: - wandb.log( - { - **{log_item.tag: log_item.scalar_value for log_item in log_entries}, - "iteration_step": self.iteration_step, - }, - commit=False, - ) - - self.loggerwriter.add_scalars_from_list(log_entries, self.iteration_step) - def init_model(self) -> Union[NanotronModel, DistributedDataParallel]: """Initialize the model and load weights from checkpoint if needed.""" # TODO: add max_position_embeddings From d75038dad4ba8786344f06725b24b213059f9b97 Mon Sep 17 00:00:00 2001 From: tj-solergibert Date: Wed, 24 Jul 2024 19:21:22 +0000 Subject: [PATCH 5/8] last fixes --- src/nanotron/config/config.py | 7 +++++ src/nanotron/trainer.py | 56 +++++++++++++++++------------------ 2 files changed, 35 insertions(+), 28 deletions(-) diff --git a/src/nanotron/config/config.py b/src/nanotron/config/config.py index e5ea3ec1..80229ca2 100644 --- a/src/nanotron/config/config.py +++ b/src/nanotron/config/config.py @@ -406,6 +406,13 @@ def __post_init__(self): for i in range(len(self.data_stages) - 1) ), "The stages are not sorted by start_training_step in increasing order" + # NOTE(tj.solergibert) As we are reporting the training & validation metrics together, we + # must comply with val_check_interval % iteration_step_info_interval = 0 + if not self.tokens.val_check_interval % self.logging.iteration_step_info_interval == 0: + raise ValueError( + f"It is necessary to run the validation stage during a logging step. Validation interval: {self.tokens.val_check_interval}, Logging interval: {self.logging.iteration_step_info_interval}" + ) + # # if lighteval, we need tokenizer to be defined # if self.checkpoints.lighteval is not None: # assert self.tokenizer.tokenizer_name_or_path is not None diff --git a/src/nanotron/trainer.py b/src/nanotron/trainer.py index c327f508..c6ca734c 100644 --- a/src/nanotron/trainer.py +++ b/src/nanotron/trainer.py @@ -562,9 +562,7 @@ def train( ].consumed_train_samples += self.global_batch_size if (self.iteration_step - 1) % self.config.logging.iteration_step_info_interval == 0: - self.train_step_logs( - outputs=outputs, loss_avg=loss_avg, global_loss=val_global_loss, lang_losses=val_lang_losses - ) + self.train_step_logs(loss_avg=loss_avg, global_loss=val_global_loss, lang_losses=val_lang_losses) # Checkpoint if self.iteration_step % self.config.checkpoints.checkpoint_interval == 0: @@ -769,31 +767,33 @@ def train_step_logs( # Validation metrics if global_loss is not None: - log_entries = [ - LogItem( - "validation_consumed_tokens", - self.metadata.consumed_train_samples * self.config.tokens.sequence_length, - "human_format", - ), # , "12d"), - LogItem( - "validation_elapsed_time_per_iteration_ms", - validation_elapsed_time_per_iteration_ms, - "human_format", - ), # , ".1f"), - LogItem("validation_tokens_per_sec", validation_tokens_per_sec, "human_format"), # , "1.6E"), - LogItem( - "validation_tokens_per_sec_per_gpu", - validation_tokens_per_sec / self.parallel_context.world_pg.size(), - "human_format", - ), # , "1.6E"), - LogItem("validation_loss", global_loss.item(), "human_format"), # , "1.6E"), - LogItem( - "validation_model_tflops_per_gpu", validation_model_tflops / 3, "human_format" - ), # , ".2f"), - LogItem( - "validation_hardware_tflops_per_gpu", validation_hardware_tflops / 3, "human_format" - ), # , ".2f"), - ] + log_entries.extend( + [ + LogItem( + "validation_consumed_tokens", + 
validation_total_samples * self.sequence_length, + "human_format", + ), # , "12d"), + LogItem( + "validation_elapsed_time_per_iteration_ms", + validation_elapsed_time_per_iteration_ms, + "human_format", + ), # , ".1f"), + LogItem("validation_tokens_per_sec", validation_tokens_per_sec, "human_format"), # , "1.6E"), + LogItem( + "validation_tokens_per_sec_per_gpu", + validation_tokens_per_sec / self.parallel_context.world_pg.size(), + "human_format", + ), # , "1.6E"), + LogItem("validation_loss", global_loss.item(), "human_format"), # , "1.6E"), + LogItem( + "validation_model_tflops_per_gpu", validation_model_tflops / 3, "human_format" + ), # , ".2f"), # NOTE(tj.solergibert) Check llama.py --> def get_flops() --> model_flops for explanation of the / 3 factor + LogItem( + "validation_hardware_tflops_per_gpu", validation_hardware_tflops / 3, "human_format" + ), # , ".2f"), # NOTE(tj.solergibert) Check llama.py --> def get_flops() --> model_flops for explanation of the / 3 factor + ] + ) # NOTE Currently you have to log each lang metric one by one and then merge them manually in the same plot through the wandb UI. # Example: https://community.wandb.ai/t/log-multiple-variables-at-the-same-plot/2474 From ab1dd835ba34d2bc5651ff78ebc60bdf050164aa Mon Sep 17 00:00:00 2001 From: tj-solergibert Date: Wed, 24 Jul 2024 19:26:42 +0000 Subject: [PATCH 6/8] Fixed tokenizer config --- examples/config_multilingual_nanoset.yaml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/examples/config_multilingual_nanoset.yaml b/examples/config_multilingual_nanoset.yaml index 5573a224..596e5e32 100644 --- a/examples/config_multilingual_nanoset.yaml +++ b/examples/config_multilingual_nanoset.yaml @@ -124,7 +124,7 @@ parallelism: profiler: null tokenizer: tokenizer_max_length: null - tokenizer_name_or_path: /mloscratch/homes/solergib/models/Meta-Llama-3-8B + tokenizer_name_or_path: meta-llama/Meta-Llama-3-8B tokenizer_revision: null tokens: batch_accumulation_per_replica: 1 From 2d911544cb915c2cccc264971e3f4ec285ebd27e Mon Sep 17 00:00:00 2001 From: tj-solergibert Date: Wed, 24 Jul 2024 19:29:51 +0000 Subject: [PATCH 7/8] deleted comments --- src/nanotron/trainer.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/nanotron/trainer.py b/src/nanotron/trainer.py index c6ca734c..62fe6bcc 100644 --- a/src/nanotron/trainer.py +++ b/src/nanotron/trainer.py @@ -730,12 +730,12 @@ def train_step_logs( validation_elapsed_time_per_iteration_ms = (self.validation_step_time - self.training_step_time) * 1000 validation_tokens_per_sec = ( validation_total_samples * self.sequence_length / (validation_elapsed_time_per_iteration_ms / 1000) - ) # tokens_per_sec is calculated using sequence_length + ) validation_model_tflops, validation_hardware_tflops = self.unwrapped_model.get_flops_per_sec( iteration_time_in_sec=validation_elapsed_time_per_iteration_ms / 1000, sequence_length=self.sequence_length, - global_batch_size=validation_total_samples, # TODO con esto de la global batch size yo la pondria a 1 y multiplicaba por el numero de batches + global_batch_size=validation_total_samples, ) if dist.get_rank(self.parallel_context.world_pg) in self.logger_ranks: From ce068fd5a9d1d29805dedd9e3493fafd883ab847 Mon Sep 17 00:00:00 2001 From: tj-solergibert Date: Wed, 7 Aug 2024 19:44:23 +0000 Subject: [PATCH 8/8] Last fixes --- examples/config_multilingual_nanoset.yaml | 39 +++++----- run_train.py | 4 +- src/nanotron/config/config.py | 11 ++- src/nanotron/data/collator.py | 73 +++++++++++++++++++ 
src/nanotron/data/dataloader_builder.py | 11 ++- src/nanotron/data/multilingual_nanoset.py | 4 +- src/nanotron/distributed.py | 4 - src/nanotron/models/llama.py | 10 ++- .../parallel/pipeline_parallel/engine.py | 6 +- src/nanotron/trainer.py | 50 +++++++++---- 10 files changed, 156 insertions(+), 56 deletions(-) diff --git a/examples/config_multilingual_nanoset.yaml b/examples/config_multilingual_nanoset.yaml index 596e5e32..cc66cd70 100644 --- a/examples/config_multilingual_nanoset.yaml +++ b/examples/config_multilingual_nanoset.yaml @@ -8,17 +8,17 @@ data_stages: - data: dataset: training_folder: - datasets/c4-es/train: 0.85 - datasets/c4-en/train: 0.05 - datasets/c4-fr/train: 0.1 + - datasets/c4-es/train + - datasets/c4-en/train + - datasets/c4-fr/train validation_folder: - datasets/c4-es/validation - datasets/c4-en/validation - datasets/c4-fr/validation - lang_to_ids: - es: 128002 - en: 128003 - fr: 128004 + languages: + - es + - en + - fr num_loading_workers: 1 seed: 42 name: General purpose training (Blended dataset) @@ -29,12 +29,12 @@ data_stages: - datasets/c4-es/train validation_folder: - datasets/c4-es/validation - lang_to_ids: - es: 128002 + languages: + - es num_loading_workers: 1 seed: 42 name: Second purpose training (Single dataset) - start_training_step: 100 + start_training_step: 1000 - data: dataset: training_folder: @@ -45,20 +45,19 @@ data_stages: - datasets/c4-es/validation - datasets/c4-en/validation - datasets/c4-fr/validation - lang_to_ids: - es: 128002 - en: 128003 - fr: 128004 - + languages: + - es + - en + - fr num_loading_workers: 1 seed: 42 name: Third purpose training (>1 dataset) - start_training_step: 200 + start_training_step: 2000 general: benchmark_csv_path: null consumed_train_samples: null ignore_sanity_checks: true - project: Multilingual + project: MultilingualV2 run: llama seed: 42 step: null @@ -114,7 +113,7 @@ optimizer: weight_decay: 0.01 zero_stage: 0 parallelism: - dp: 1 + dp: 2 expert_parallel_size: 1 pp: 1 pp_engine: 1f1b @@ -132,5 +131,5 @@ tokens: limit_val_batches: 10 micro_batch_size: 3 sequence_length: 4096 - train_steps: 800 - val_check_interval: 50 + train_steps: 500 + val_check_interval: 100 diff --git a/run_train.py b/run_train.py index a51caf59..809d8d41 100644 --- a/run_train.py +++ b/run_train.py @@ -194,7 +194,6 @@ def get_dataloader_from_data_stage( sequence_length=trainer.sequence_length, token_size=token_size, train_split_num_samples=trainer.config.tokens.train_steps * trainer.global_batch_size, - dataset_tokens=data.dataset.dataset_tokens, random_seed=data.seed, ) @@ -209,6 +208,7 @@ def get_dataloader_from_data_stage( consumed_train_samples=consumed_train_samples, dataloader_num_workers=data.num_loading_workers, dataloader_drop_last=True, + is_multilingual=True, ) return train_dataloader @@ -241,7 +241,6 @@ def get_valid_dataloader_from_data_stage( dataset_folders=data.dataset.validation_folder, sequence_length=trainer.sequence_length, token_size=token_size, - dataset_tokens=data.dataset.dataset_tokens, is_valid=True, random_seed=data.seed, ) @@ -257,6 +256,7 @@ def get_valid_dataloader_from_data_stage( dataloader_num_workers=data.num_loading_workers, dataloader_drop_last=True, shuffle=True, + is_multilingual=True, ) return valid_dataloader diff --git a/src/nanotron/config/config.py b/src/nanotron/config/config.py index 80229ca2..b3c755a5 100644 --- a/src/nanotron/config/config.py +++ b/src/nanotron/config/config.py @@ -111,7 +111,7 @@ def __post_init__(self): class MultilingualNanosetDatasetsArgs: training_folder: 
Union[str, dict, List[str]] validation_folder: Union[str, List[str]] - lang_to_ids: dict # Mapping from the previously defined folders to tokens. Respect the order + languages: List[str] # NOTE(tj.solergibert) Required for 1. Aggregating the result 2. Reporting to WANDB def __post_init__(self): if isinstance(self.training_folder, str): # Case 1: 1 Dataset folder @@ -125,14 +125,13 @@ def __post_init__(self): self.training_folder = list(tmp_training_folder.keys()) self.dataset_weights = list(tmp_training_folder.values()) - self.ids_to_lang = {v: k for k, v in self.lang_to_ids.items()} - self.dataset_tokens = list(self.lang_to_ids.values()) + assert len(self.training_folder) == len( + self.languages + ), f"The sizes of training_folder and languages mismatch ({len(self.training_folder)} vs {len(self.languages)})" + assert len(self.training_folder) == len( self.validation_folder ), f"The sizes of training_folder and validation_folder mismatch ({len(self.training_folder)} vs {len(self.validation_folder)})" - assert len(self.training_folder) == len( - self.dataset_tokens - ), f"The sizes of training_folder and lang_to_ids mismatch ({len(self.training_folder)} vs {len(self.dataset_tokens)})" @dataclass diff --git a/src/nanotron/data/collator.py b/src/nanotron/data/collator.py index 199527e1..fd217b1a 100644 --- a/src/nanotron/data/collator.py +++ b/src/nanotron/data/collator.py @@ -78,3 +78,76 @@ def __call__(self, examples: List[Dict[str, List[np.ndarray]]]) -> Dict[str, Uni ) return result + + +@dataclasses.dataclass +class MultilingualNanosetDataCollatorForCLM: + """ + Data collator used for causal language modeling with Nanosets dataset. + + - input_pp_rank: Discards last input id token + - output_pp_rank: Discards first label id token + - other pp ranks: Don't have data. Instead, we use `TensorPointer` to point to the rank having the data. + """ + + sequence_length: int + input_pp_rank: int + output_pp_rank: int + parallel_context: ParallelContext + + def __call__(self, examples: List[Dict[str, List[np.ndarray]]]) -> Dict[str, Union[torch.Tensor, TensorPointer]]: + # Process the case when current rank doesn't require data. We return `TensorPointer` that points to ranks having the data. + current_pp_rank = dist.get_rank(self.parallel_context.pp_pg) + if current_pp_rank not in [ + self.input_pp_rank, + self.output_pp_rank, + ]: + assert all(len(example) == 0 for example in examples) + return { + "input_ids": TensorPointer(group_rank=self.input_pp_rank), + "input_mask": TensorPointer(group_rank=self.input_pp_rank), + "lang_code": TensorPointer(group_rank=self.input_pp_rank), + "label_ids": TensorPointer(group_rank=self.output_pp_rank), + "label_mask": TensorPointer(group_rank=self.output_pp_rank), + } + + # TODO @nouamanetazi: Is it better to have examples as np.array or torch.Tensor? 
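+        # Each example is expected to carry "input_ids" of length sequence_length + 1 plus a scalar "lang_code"
+        # (the dataset index set in MultilingualNanoset.__getitem__), so the vstacks below yield shapes
+        # (batch_size, sequence_length + 1) and (batch_size, 1) respectively.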
+ input_ids = torch.vstack([examples[i]["input_ids"] for i in range(len(examples))]) # (b, s) + lang_code = torch.vstack([examples[i]["lang_code"] for i in range(len(examples))]) # (b, 1) + batch_size, expanded_input_length = input_ids.shape + + result: Dict[str, Union[torch.LongTensor, TensorPointer]] = {} + + result["input_ids"] = TensorPointer(group_rank=self.input_pp_rank) + result["input_mask"] = TensorPointer(group_rank=self.input_pp_rank) + result["lang_code"] = TensorPointer(group_rank=self.input_pp_rank) + result["label_ids"] = TensorPointer(group_rank=self.output_pp_rank) + result["label_mask"] = TensorPointer(group_rank=self.output_pp_rank) + + assert ( + expanded_input_length == self.sequence_length + 1 + ), f"Samples should be of length {self.sequence_length + 1} (seq_len+1), but got {expanded_input_length}" + + # Process inputs: last token is the label + if current_pp_rank == self.input_pp_rank: + result["input_ids"] = input_ids[:, :-1] + result["input_mask"] = torch.ones((batch_size, self.sequence_length), dtype=torch.bool) + result["lang_code"] = lang_code + + # Process labels: shift them to the left + if current_pp_rank == self.output_pp_rank: + result["label_ids"] = input_ids[:, 1:] + result["label_mask"] = torch.ones((batch_size, self.sequence_length), dtype=torch.bool) + + if isinstance(result["input_ids"], torch.Tensor) and result["input_ids"].shape[-1] != self.sequence_length: + raise ValueError( + f"`labels` are incorrectly preprocessed. `labels` length is {result['input_ids'].shape[-1]}, but should be" + f" {self.sequence_length}." + ) + if isinstance(result["label_ids"], torch.Tensor) and result["label_ids"].shape[-1] != self.sequence_length: + raise ValueError( + f"`labels` are incorrectly preprocessed. `labels` length is {result['label_ids'].shape[-1]}, but should be" + f" {self.sequence_length}." 
+ ) + + return result diff --git a/src/nanotron/data/dataloader_builder.py b/src/nanotron/data/dataloader_builder.py index b8bfb303..f9480029 100644 --- a/src/nanotron/data/dataloader_builder.py +++ b/src/nanotron/data/dataloader_builder.py @@ -1,6 +1,6 @@ import nanotron.distributed as dist from nanotron import logging -from nanotron.data.collator import NanosetDataCollatorForCLM +from nanotron.data.collator import MultilingualNanosetDataCollatorForCLM, NanosetDataCollatorForCLM from nanotron.dataloader import ( EmptyInfiniteDataset, get_dataloader_worker_init, @@ -20,6 +20,7 @@ def build_nanoset_dataloader( output_pp_rank: int, micro_batch_size: int, dataloader_num_workers: int, + is_multilingual: bool = False, consumed_train_samples: int = 0, dataloader_drop_last: bool = True, dataloader_pin_memory: bool = True, @@ -40,6 +41,14 @@ def build_nanoset_dataloader( parallel_context=parallel_context, ) + if is_multilingual: + data_collator = MultilingualNanosetDataCollatorForCLM( + sequence_length=sequence_length, + input_pp_rank=input_pp_rank, + output_pp_rank=output_pp_rank, + parallel_context=parallel_context, + ) + # Compute size and rank of dataloader workers dp_ranks_size = parallel_context.dp_pg.size() dp_rank = parallel_context.dp_pg.rank() diff --git a/src/nanotron/data/multilingual_nanoset.py b/src/nanotron/data/multilingual_nanoset.py index 7af57448..8eec5549 100644 --- a/src/nanotron/data/multilingual_nanoset.py +++ b/src/nanotron/data/multilingual_nanoset.py @@ -30,7 +30,6 @@ def __init__( dataset_folders: List[str], sequence_length: int, token_size: int, - dataset_tokens: List[int], train_split_num_samples: int = None, is_valid: bool = False, dataset_weights: Union[List[float], None] = None, @@ -47,7 +46,6 @@ def __init__( self.sequence_length = sequence_length self.token_size = token_size self.train_split_num_samples = train_split_num_samples - self.dataset_tokens = dataset_tokens self.is_valid = is_valid self.random_seed = random_seed self.datatrove_datasets = [] @@ -107,7 +105,7 @@ def __getitem__(self, idx: int) -> Dict[str, np.ndarray]: dataset_sample = self.dataset_sample_index[idx] tokens = self.datatrove_datasets[dataset][dataset_sample] - tokens["input_ids"][0] = self.dataset_tokens[dataset] # Prepend language token + tokens["lang_code"] = torch.tensor(dataset, dtype=torch.long) return tokens diff --git a/src/nanotron/distributed.py b/src/nanotron/distributed.py index 0156b1bb..0bc54f3e 100644 --- a/src/nanotron/distributed.py +++ b/src/nanotron/distributed.py @@ -52,10 +52,6 @@ def all_gather_into_tensor( # pylint: disable=function-redefined if group is None: group = dist.torch_dist.distributed_c10d._get_default_group() - assert ( - group.size() > 1 - ), "You should probably not call `all_gather_into_tensor` with a single rank, as it copies data over" - if torch_version_above_1_13: return dist.all_gather_into_tensor( output_tensor=output_tensor, input_tensor=input_tensor, group=group, async_op=async_op diff --git a/src/nanotron/models/llama.py b/src/nanotron/models/llama.py index 8c4125b7..ec1b38c0 100644 --- a/src/nanotron/models/llama.py +++ b/src/nanotron/models/llama.py @@ -733,14 +733,20 @@ def forward( self, input_ids: Union[torch.Tensor, TensorPointer], # [batch_size, seq_length] input_mask: Union[torch.Tensor, TensorPointer], # [batch_size, seq_length] + lang_code: Union[torch.Tensor, TensorPointer], # [batch_size, 1] ): - return self.forward_with_hidden_states(input_ids=input_ids, input_mask=input_mask)[0] + return 
self.forward_with_hidden_states(input_ids=input_ids, input_mask=input_mask, lang_code=lang_code)[0] def forward_with_hidden_states( self, input_ids: Union[torch.Tensor, TensorPointer], # [batch_size, seq_length] input_mask: Union[torch.Tensor, TensorPointer], # [batch_size, seq_length] + lang_code: Union[torch.Tensor, TensorPointer], # [batch_size, 1] ): + # NOTE(tj.solergibert) I bring `lang_code` till the forward of LlamaModel. Remember that + # to use it in the different pipeline blocks you need to also set the module_input_keys & module_output_keys + # of the necessary `PipelineBlock`'s defined in the LlamaModel init! + # all tensors are optional as most ranks don't need anything from the dataloader. output = self.token_position_embeddings(input_ids=input_ids, input_mask=input_mask) @@ -863,12 +869,14 @@ def forward( self, input_ids: Union[torch.Tensor, TensorPointer], input_mask: Union[torch.Tensor, TensorPointer], + lang_code: Union[torch.Tensor, TensorPointer], label_ids: Union[torch.Tensor, TensorPointer], label_mask: Union[torch.Tensor, TensorPointer], ) -> Dict[str, Union[torch.Tensor, TensorPointer]]: sharded_logits = self.model( input_ids=input_ids, input_mask=input_mask, + lang_code=lang_code, ) outputs = self.loss( sharded_logits=sharded_logits, diff --git a/src/nanotron/parallel/pipeline_parallel/engine.py b/src/nanotron/parallel/pipeline_parallel/engine.py index 549ef5eb..9b548e35 100644 --- a/src/nanotron/parallel/pipeline_parallel/engine.py +++ b/src/nanotron/parallel/pipeline_parallel/engine.py @@ -140,7 +140,7 @@ def validate_batch_iter( self.nb_microbatches = nb_microbatches outputs = [] - lang_ids = [] + lang_codes = [] with attach_pipeline_state_to_model(model=model, pipeline_state=state): # All forward @@ -166,9 +166,9 @@ def validate_batch_iter( outputs.extend( list(output["sample_loss"]) ) # NOTE(tj.solergibert) Yes, it might look useless to do list + extend but it's necessary to split the output["sample_loss"] tensor into multiple tensors - lang_ids.extend(micro_batch["input_ids"][:, 0].tolist()) + lang_codes.extend(micro_batch["lang_code"].flatten().tolist()) - return outputs, lang_ids + return outputs, lang_codes class AllForwardAllBackwardPipelineEngine(PipelineEngine): diff --git a/src/nanotron/trainer.py b/src/nanotron/trainer.py index 62fe6bcc..a17f9849 100644 --- a/src/nanotron/trainer.py +++ b/src/nanotron/trainer.py @@ -80,6 +80,7 @@ from nanotron.sanity_checks import ( after_optim_step_sanity_checks, after_tbi_sanity_checks, + assert_tensor_synced_across_pg, before_optim_step_sanity_checks, before_tbi_sanity_checks, ) @@ -667,37 +668,54 @@ def training_step( return outputs, loss_avg def validation_step(self, dataloader: Iterator[Dict[str, Union[torch.Tensor, TensorPointer]]]) -> Iterable[Dict]: - outputs, lang_ids = self.pipeline_engine.validate_batch_iter( + outputs, lang_codes = self.pipeline_engine.validate_batch_iter( model=self.model, batch=(next(dataloader) for _ in range(self.current_validation_dataloader_lenght)), nb_microbatches=self.current_validation_dataloader_lenght, ) lang_losses = { - lang: [] for lang in self.config.data_stages[self.metadata.last_stage_idx].data.dataset.lang_to_ids.keys() + lang: [] for lang in self.config.data_stages[self.metadata.last_stage_idx].data.dataset.languages } - # WARNING(tj.solergibert) This mechanism will fail in the following [corner] case: - # If the lang_losses dict for a given lang IS EMPTY aka in the validation step in a Data Parallel Group - # we have 0 SAMPLES of a given lang, lang_losses[lang] will 
be a empty python list so the toch.stack call - # will fail with "stack expects a non-empty TensorList". I've tried setting this lang_losses[lang] to torch.empty - # but of course it doesn't works as we then do the average across the DP group. - # We will fix this issue in the future if we encounter this problem again. - # A bit of inspo https://blog.speechmatics.com/Sparse-All-Reduce-Part-1 + lang_losses_list = list(lang_losses.keys()) # Compute losses if isinstance(outputs[0], torch.Tensor): # Multilingual losses - for loss, lang_id in zip(outputs, lang_ids): - lang_losses[ - self.config.data_stages[self.metadata.last_stage_idx].data.dataset.ids_to_lang[lang_id] - ].append(loss) + for loss, lang_code in zip(outputs, lang_codes): + lang_losses[lang_losses_list[lang_code]].append(loss) # Global loss global_loss_avg = torch.mean(torch.stack(outputs)) - # Sync losses across DP + # Sync multilingual losses across DP for lang in lang_losses.keys(): - lang_losses[lang] = torch.mean(torch.stack(lang_losses[lang])) - dist.all_reduce(lang_losses[lang], group=self.parallel_context.dp_pg, op=dist.ReduceOp.AVG) + if not lang_losses[ + lang + ]: # If the list is empty --> Set local language loss to -1 to exclude it from the global computation + lang_losses[lang] = torch.tensor(-1, dtype=torch.float32) + else: # If we have at least 1 loss from a given language --> compute local language loss mean + lang_losses[lang] = torch.mean(torch.stack(lang_losses[lang])) + + # NOTE(tj.solergibert) We create a (DP SIZE, LANGS) tensor to aggregate ALL local losses across DP groups. + # Then we compute the mean of each lang in each and every rank and finally copy back the result to the + # `lang_losses` dict for logging + lang_losses_tensor_out = torch.zeros( + (self.parallel_context.dp_pg.size(), len(lang_losses.keys())), dtype=torch.float, device="cuda" + ) # (DP SIZE, LANGS) + lang_losses_tensor_local = torch.stack(list(lang_losses.values())).unsqueeze(0) # (1, LANGS) + dist.all_gather_into_tensor(lang_losses_tensor_out, lang_losses_tensor_local, self.parallel_context.dp_pg) + mask = lang_losses_tensor_out != -1 + lang_losses_tensor_local = (lang_losses_tensor_out * mask).sum(dim=0) / mask.sum(dim=0) # (1, LANGS) + for idx, lang in enumerate(lang_losses.keys()): + lang_losses[lang] = lang_losses_tensor_local[idx] + + # Sync global losses across DP dist.all_reduce(global_loss_avg, group=self.parallel_context.dp_pg, op=dist.ReduceOp.AVG) + + # TODO(tj.solergibert) Delete this testing assertions + for lang in lang_losses.keys(): + assert_tensor_synced_across_pg(tensor=lang_losses[lang], pg=self.parallel_context.dp_pg) + assert_tensor_synced_across_pg(tensor=global_loss_avg, pg=self.parallel_context.dp_pg) + else: global_loss_avg = None lang_losses = None
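
# A minimal, standalone sketch of the masked cross-DP aggregation used in `validation_step` above.
# It runs in a single process, so the `dist.all_gather_into_tensor` step is simulated by stacking
# per-rank rows; the helper name is illustrative only. Ranks that saw no samples for a language
# report -1.0, and the mask keeps those entries out of the cross-rank mean.
import torch


def aggregate_lang_losses(gathered: torch.Tensor) -> torch.Tensor:
    """gathered: (dp_size, n_langs), with -1.0 wherever a rank had no samples for that language."""
    mask = gathered != -1
    # If no rank saw a given language, mask.sum(dim=0) is 0 and that entry becomes NaN,
    # exactly as in the trainer code above.
    return (gathered * mask).sum(dim=0) / mask.sum(dim=0)  # (n_langs,)


gathered = torch.tensor([[2.0, -1.0, 1.0],   # DP rank 0: no samples for the 2nd language
                         [4.0, 3.0, -1.0]])  # DP rank 1: no samples for the 3rd language
print(aggregate_lang_losses(gathered))  # tensor([3., 3., 1.])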