diff --git a/examples/config_multilingual_nanoset.yaml b/examples/config_multilingual_nanoset.yaml index 599bff6c..48ae960c 100644 --- a/examples/config_multilingual_nanoset.yaml +++ b/examples/config_multilingual_nanoset.yaml @@ -1,62 +1,63 @@ checkpoints: - checkpoint_interval: 1000 + checkpoint_interval: 1000000 checkpoints_path: checkpoints/ checkpoints_path_is_shared_file_system: false resume_checkpoint_path: null save_initial_state: false data_stages: -- data: - dataset: - training_folder: datasets/c4-es/train - validation_folder: datasets/c4-es/validation - lang_to_ids: - es: 128002 - num_loading_workers: 1 - seed: 42 - name: General purpose training (Single dataset) - start_training_step: 1 -- data: - dataset: - training_folder: - - datasets/c4-es/train - - datasets/c4-en/train - - datasets/c4-fr/train - validation_folder: - - datasets/c4-es/validation - - datasets/c4-en/validation - - datasets/c4-fr/validation - lang_to_ids: - es: 128002 - en: 128003 - fr: 128004 - num_loading_workers: 1 - seed: 42 - name: Second purpose training (> 1 dataset) - start_training_step: 15 -- data: - dataset: - training_folder: - datasets/c4-es/train: 0.6 - datasets/c4-en/train: 0.3 - datasets/c4-fr/train: 0.1 - validation_folder: - - datasets/c4-es/validation - - datasets/c4-en/validation - - datasets/c4-fr/validation - lang_to_ids: - es: 128002 - en: 128003 - fr: 128004 - - num_loading_workers: 1 - seed: 42 - name: Third purpose training (Blended dataset) - start_training_step: 25 + - data: + dataset: + training_folder: + - datasets/c4-es/train + - datasets/c4-en/train + - datasets/c4-fr/train + validation_folder: + - datasets/c4-es/validation + - datasets/c4-en/validation + - datasets/c4-fr/validation + languages: + - es + - en + - fr + num_loading_workers: 1 + seed: 42 + name: General purpose training (Blended dataset) + start_training_step: 1 + - data: + dataset: + training_folder: + - datasets/c4-es/train + validation_folder: + - datasets/c4-es/validation + languages: + - es + num_loading_workers: 1 + seed: 42 + name: Second purpose training (Single dataset) + start_training_step: 1000 + - data: + dataset: + training_folder: + - datasets/c4-es/train + - datasets/c4-en/train + - datasets/c4-fr/train + validation_folder: + - datasets/c4-es/validation + - datasets/c4-en/validation + - datasets/c4-fr/validation + languages: + - es + - en + - fr + num_loading_workers: 1 + seed: 42 + name: Third purpose training (>1 dataset) + start_training_step: 2000 general: benchmark_csv_path: null consumed_train_samples: null ignore_sanity_checks: true - project: Nanoset + project: MultilingualV2 run: llama seed: 42 step: null @@ -75,12 +76,12 @@ model: bos_token_id: 1 eos_token_id: 2 hidden_act: silu - hidden_size: 512 + hidden_size: 4096 initializer_range: 0.02 - intermediate_size: 512 + intermediate_size: 14336 is_llama_config: true - max_position_embeddings: 1024 - num_hidden_layers: 2 + max_position_embeddings: 4096 + num_hidden_layers: 32 num_attention_heads: 32 num_key_value_heads: 8 pad_token_id: null @@ -89,7 +90,7 @@ model: rope_theta: 500000.0 rms_norm_eps: 1.0e-06 rope_scaling: null - tie_word_embeddings: true + tie_word_embeddings: false use_cache: true vocab_size: 128256 optimizer: @@ -112,11 +113,11 @@ optimizer: weight_decay: 0.01 zero_stage: 0 parallelism: - dp: 1 + dp: 2 expert_parallel_size: 1 pp: 1 pp_engine: 1f1b - tp: 1 + tp: 4 tp_linear_async_communication: false tp_mode: REDUCE_SCATTER profiler: null @@ -128,7 +129,7 @@ tokens: batch_accumulation_per_replica: 1 limit_test_batches: 0 
limit_val_batches: 10
-  micro_batch_size: 4
-  sequence_length: 1024
-  train_steps: 200
-  val_check_interval: -1
+  micro_batch_size: 3
+  sequence_length: 4096
+  train_steps: 500
+  val_check_interval: 100
diff --git a/examples/doremi/README.md b/examples/doremi/README.md
index 5a726bd1..dfc9ea40 100644
--- a/examples/doremi/README.md
+++ b/examples/doremi/README.md
@@ -87,3 +87,7 @@ For evaluation, we do uniform sampling on the test set to evaluate a 2.5B model
 - 2.5B llama trained using the optimized weights: https://huggingface.co/nanotron/doremi-llama-2.5b-optimized-weights
 and the dataset: https://huggingface.co/datasets/nanotron/the-pile-for-doremi
+
+#### Thoughts
+
+DoReMi is useful if you don't yet have an idea of what a good distribution for your training data would be, or if you want a quick way to find a better baseline than the uniform distribution before tuning the data mixture by hand. In my previous experiments, DoReMi matched the pretraining performance of the distribution used for the mamba training but couldn't outperform it. I suspect it doesn't work well when the differences are nuanced, i.e. when the gap between your best known distribution and a better one isn't significant.
diff --git a/examples/mamba/README.md b/examples/mamba/README.md
index 5c31d07f..8eefa9c2 100644
--- a/examples/mamba/README.md
+++ b/examples/mamba/README.md
@@ -18,6 +18,18 @@ pip install -r requirements.txt

 > https://wandb.ai/bouteille/test/reports/Mamba-loss--Vmlldzo2OTgwNDM5

+## Bug related to nanotron
+We encountered the following issue when running `train_mamba.sh`:
+```
+causal_conv1d_cuda.cpython-310-x86_64-linux-gnu.so: undefined symbol: _ZNK3c1017SymbolicShapeMeta18init_is_contiguousEv
+```
+It was solved by reinstalling the kernels:
+pip uninstall mamba-ssm
+pip install causal_conv1d==1.1.1
+pip install mamba-ssm --no-cache-dir
+See https://github.com/state-spaces/mamba/issues/169 for details.
+
+
 ## Credits
 Credits to the following repositories from which the code was adapted:
 - https://github.com/state-spaces/mamba
diff --git a/examples/mup/README.md b/examples/mup/README.md
index c86850ca..ed94c1fb 100644
--- a/examples/mup/README.md
+++ b/examples/mup/README.md
@@ -32,3 +32,8 @@ We trained a 350m model with spectral µTransfer and standard parametrization us
 Please check the directory [[./examples/mup/configs]](/examples/mup/configs) for the configurations we used to reproduce the experiments.

 ![LLaMA](./assets/llama.png)
+
+
+#### Thoughts
+
+We only ran spectral µTransfer experiments on an MLP-only setup [link] and a 300m LLaMA [link] (the experiment configs are linked in this README). However, when we tested it on 1B/8B models (iirc), the loss blew up for some reason. So we'd recommend trying µTransfer rather than spectral µTransfer.
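Not part of the patch, but to make the config change above easier to follow: the new `languages` list replaces `lang_to_ids`, and its order must match `training_folder`, because `MultilingualNanoset` now tags each sample with the *index* of the dataset it was drawn from (`lang_code`) instead of prepending a language token. Below is a minimal Python sketch of that contract; the paths and values are illustrative only.

```python
# Illustrative sketch (not part of the diff): how `languages`, the order of
# `training_folder`, and the `lang_code` emitted by MultilingualNanoset line up.
import torch

languages = ["es", "en", "fr"]          # config: data.dataset.languages
training_folder = [                     # config: data.dataset.training_folder
    "datasets/c4-es/train",
    "datasets/c4-en/train",
    "datasets/c4-fr/train",
]
# Enforced in MultilingualNanosetDatasetsArgs.__post_init__
assert len(languages) == len(training_folder)

# MultilingualNanoset.__getitem__ tags every sample with the index of the dataset it came from:
dataset_idx = 1  # e.g. a sample drawn from datasets/c4-en/train
sample = {"input_ids": torch.arange(10), "lang_code": torch.tensor(dataset_idx, dtype=torch.long)}

# During validation, that index is mapped back to a language name for reporting:
print(f"{languages[int(sample['lang_code'])]}_validation_loss")  # -> en_validation_loss
```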
diff --git a/run_train.py b/run_train.py
index 39cda23b..809d8d41 100644
--- a/run_train.py
+++ b/run_train.py
@@ -194,7 +194,6 @@ def get_dataloader_from_data_stage(
             sequence_length=trainer.sequence_length,
             token_size=token_size,
             train_split_num_samples=trainer.config.tokens.train_steps * trainer.global_batch_size,
-            dataset_tokens=data.dataset.dataset_tokens,
             random_seed=data.seed,
         )

@@ -209,6 +208,7 @@
             consumed_train_samples=consumed_train_samples,
             dataloader_num_workers=data.num_loading_workers,
             dataloader_drop_last=True,
+            is_multilingual=True,
         )

         return train_dataloader
@@ -241,7 +241,6 @@ def get_valid_dataloader_from_data_stage(
             dataset_folders=data.dataset.validation_folder,
             sequence_length=trainer.sequence_length,
             token_size=token_size,
-            dataset_tokens=data.dataset.dataset_tokens,
             is_valid=True,
             random_seed=data.seed,
         )
@@ -256,6 +255,8 @@
             micro_batch_size=trainer.micro_batch_size,
             dataloader_num_workers=data.num_loading_workers,
             dataloader_drop_last=True,
+            shuffle=True,
+            is_multilingual=True,
         )

         return valid_dataloader
@@ -315,7 +316,7 @@ def get_valid_dataloader(trainer: DistributedTrainer) -> Dict[str, DataLoader]:
         stage = cast(DatasetStageArgs, stage)

         log_rank(
-            f"[Validation Plan] Stage {stage.name} has {len(stage.data.dataset.validation_folder)} folders with samples in the validation set",
+            f"[Validation Plan] Stage {stage.name} has {len(stage.data.dataset.validation_folder)} folders with samples for the validation set",
             logger=logger,
             level=logging.INFO,
             rank=0,
@@ -324,8 +325,18 @@ def get_valid_dataloader(trainer: DistributedTrainer) -> Dict[str, DataLoader]:
         dataloader = (
             get_valid_dataloader_from_data_stage(trainer, stage.data)
             if stage_idx == 0
-            else lambda stage=stage: get_dataloader_from_data_stage(trainer, stage.data)
+            else lambda stage=stage: get_valid_dataloader_from_data_stage(trainer, stage.data)
        )
+        # TODO(tj.solergibert) As we create the valid dataloader again in every validation stage, we print the validation
+        # MultilingualNanoset info (number of samples, etc.) multiple times [UPDATE: ]. To solve that, we could get rid of these lambda
+        # funcs and directly create all dataloaders.
+        #
+        # These lambda funcs (used in training too) create the DataLoaders lazily in order to 1. Start training faster instead
+        # of creating multiple DataLoaders and 2. Consume less memory, as the lambda func is lighter than the DataLoader object with
+        # the Dataset, collator, etc.
+        # BUT 1. The Nanoset creation process is very fast and 2. Nanosets don't consume any memory at all until we start sampling
+        # from the Nanoset. Also, they later transform the DataLoader into an Iterator object, so it's impossible to retrieve
+        # the DataLoader object again to delete it (more comments in trainer.py).
         dataloaders[stage.name] = dataloader
     return dataloaders
diff --git a/src/nanotron/config/config.py b/src/nanotron/config/config.py
index dd2c157d..822be7a4 100644
--- a/src/nanotron/config/config.py
+++ b/src/nanotron/config/config.py
@@ -11,7 +11,12 @@
 from yaml.loader import SafeLoader

 from nanotron.config.lighteval_config import LightEvalConfig
-from nanotron.config.models_config import ExistingCheckpointInit, NanotronConfigs, RandomInit, SpectralMupInit
+from nanotron.config.models_config import (
+    ExistingCheckpointInit,
+    NanotronConfigs,
+    RandomInit,
+    SpectralMupInit,
+)
 from nanotron.config.parallelism_config import ParallelismArgs
 from nanotron.config.utils_config import (
     RecomputeGranularity,
@@ -111,7 +116,7 @@ def __post_init__(self):
 class MultilingualNanosetDatasetsArgs:
     training_folder: Union[str, dict, List[str]]
     validation_folder: Union[str, List[str]]
-    lang_to_ids: dict  # Mapping from the previously defined folders to tokens. Respect the order
+    languages: List[str]  # NOTE(tj.solergibert) Required for 1. Aggregating the results 2. Reporting to WANDB

     def __post_init__(self):
         if isinstance(self.training_folder, str):  # Case 1: 1 Dataset folder
@@ -125,20 +130,24 @@ def __post_init__(self):
             self.training_folder = list(tmp_training_folder.keys())
             self.dataset_weights = list(tmp_training_folder.values())

-        self.dataset_tokens = list(self.lang_to_ids.values())
+        assert len(self.training_folder) == len(
+            self.languages
+        ), f"The sizes of training_folder and languages mismatch ({len(self.training_folder)} vs {len(self.languages)})"
+
         assert len(self.training_folder) == len(
             self.validation_folder
         ), f"The sizes of training_folder and validation_folder mismatch ({len(self.training_folder)} vs {len(self.validation_folder)})"
-        assert len(self.training_folder) == len(
-            self.dataset_tokens
-        ), f"The sizes of training_folder and lang_to_ids mismatch ({len(self.training_folder)} vs {len(self.dataset_tokens)})"


 @dataclass
 class DataArgs:
     """Arguments related to the data and data files processing"""

-    dataset: Union[PretrainDatasetsArgs, NanosetDatasetsArgs, MultilingualNanosetDatasetsArgs]
+    dataset: Union[
+        PretrainDatasetsArgs,
+        NanosetDatasetsArgs,
+        MultilingualNanosetDatasetsArgs,
+    ]
     seed: Optional[int]
     num_loading_workers: Optional[int] = 1

@@ -405,6 +415,13 @@ def __post_init__(self):
                 for i in range(len(self.data_stages) - 1)
             ), "The stages are not sorted by start_training_step in increasing order"

+        # NOTE(tj.solergibert) As we are reporting the training & validation metrics together, we
+        # must comply with val_check_interval % iteration_step_info_interval = 0
+        if not self.tokens.val_check_interval % self.logging.iteration_step_info_interval == 0:
+            raise ValueError(
+                f"It is necessary to run the validation stage during a logging step. 
Validation interval: {self.tokens.val_check_interval}, Logging interval: {self.logging.iteration_step_info_interval}" + ) + # # if lighteval, we need tokenizer to be defined # if self.checkpoints.lighteval is not None: # assert self.tokenizer.tokenizer_name_or_path is not None @@ -427,7 +444,10 @@ def as_dict(self) -> dict: def get_config_from_dict( - config_dict: dict, config_class: Type = Config, skip_unused_config_keys: bool = False, skip_null_keys: bool = False + config_dict: dict, + config_class: Type = Config, + skip_unused_config_keys: bool = False, + skip_null_keys: bool = False, ): """Get a config object from a dictionary @@ -445,7 +465,7 @@ def get_config_from_dict( if skip_null_keys: logger.warning("Skip_null_keys set") config_dict = { - k: {kk: vv for kk, vv in v.items() if vv is not None} if isinstance(v, dict) else v + k: ({kk: vv for kk, vv in v.items() if vv is not None} if isinstance(v, dict) else v) for k, v in config_dict.items() if v is not None } diff --git a/src/nanotron/config/models_config.py b/src/nanotron/config/models_config.py index 257a2f72..72158651 100644 --- a/src/nanotron/config/models_config.py +++ b/src/nanotron/config/models_config.py @@ -170,7 +170,10 @@ def as_starcoder2(self) -> Starcoder2Config: if "_is_using_mup" in config: del config["_is_using_mup"] return Starcoder2Config( - grouped_query=True, num_kv_heads=self.num_attention_heads, use_rotary_embeddings=False, **config + grouped_query=True, + num_kv_heads=self.num_attention_heads, + use_rotary_embeddings=False, + **config, ) @property @@ -244,7 +247,10 @@ def as_starcoder2(self) -> Starcoder2Config: if "_is_using_mup" in config: del config["_is_using_mup"] return Starcoder2Config( - grouped_query=True, num_kv_heads=self.num_attention_heads, use_rotary_embeddings=False, **config + grouped_query=True, + num_kv_heads=self.num_attention_heads, + use_rotary_embeddings=False, + **config, ) @property diff --git a/src/nanotron/config/parallelism_config.py b/src/nanotron/config/parallelism_config.py index 5912425b..321ee045 100644 --- a/src/nanotron/config/parallelism_config.py +++ b/src/nanotron/config/parallelism_config.py @@ -23,6 +23,7 @@ class ParallelismArgs: pp_engine: Pipeline engine to use between "1f1b" and "afab" tp_mode: TP mode to use between "all_reduce" and "reduce_scatter": all_reduce is normal, reduce_scatter activate sequence parallelism tp_linear_async_communication: Whether to use async communication in TP linear layers + recompute_layer: Whether to recompute each Transformer layer to save memory. """ dp: int @@ -31,6 +32,7 @@ class ParallelismArgs: pp_engine: Optional[PipelineEngine] = None tp_mode: Optional[TensorParallelLinearMode] = None tp_linear_async_communication: Optional[bool] = None + recompute_layer: bool = False expert_parallel_size: int = 1 diff --git a/src/nanotron/data/collator.py b/src/nanotron/data/collator.py index 199527e1..fd217b1a 100644 --- a/src/nanotron/data/collator.py +++ b/src/nanotron/data/collator.py @@ -78,3 +78,76 @@ def __call__(self, examples: List[Dict[str, List[np.ndarray]]]) -> Dict[str, Uni ) return result + + +@dataclasses.dataclass +class MultilingualNanosetDataCollatorForCLM: + """ + Data collator used for causal language modeling with Nanosets dataset. + + - input_pp_rank: Discards last input id token + - output_pp_rank: Discards first label id token + - other pp ranks: Don't have data. Instead, we use `TensorPointer` to point to the rank having the data. 
+ """ + + sequence_length: int + input_pp_rank: int + output_pp_rank: int + parallel_context: ParallelContext + + def __call__(self, examples: List[Dict[str, List[np.ndarray]]]) -> Dict[str, Union[torch.Tensor, TensorPointer]]: + # Process the case when current rank doesn't require data. We return `TensorPointer` that points to ranks having the data. + current_pp_rank = dist.get_rank(self.parallel_context.pp_pg) + if current_pp_rank not in [ + self.input_pp_rank, + self.output_pp_rank, + ]: + assert all(len(example) == 0 for example in examples) + return { + "input_ids": TensorPointer(group_rank=self.input_pp_rank), + "input_mask": TensorPointer(group_rank=self.input_pp_rank), + "lang_code": TensorPointer(group_rank=self.input_pp_rank), + "label_ids": TensorPointer(group_rank=self.output_pp_rank), + "label_mask": TensorPointer(group_rank=self.output_pp_rank), + } + + # TODO @nouamanetazi: Is it better to have examples as np.array or torch.Tensor? + input_ids = torch.vstack([examples[i]["input_ids"] for i in range(len(examples))]) # (b, s) + lang_code = torch.vstack([examples[i]["lang_code"] for i in range(len(examples))]) # (b, 1) + batch_size, expanded_input_length = input_ids.shape + + result: Dict[str, Union[torch.LongTensor, TensorPointer]] = {} + + result["input_ids"] = TensorPointer(group_rank=self.input_pp_rank) + result["input_mask"] = TensorPointer(group_rank=self.input_pp_rank) + result["lang_code"] = TensorPointer(group_rank=self.input_pp_rank) + result["label_ids"] = TensorPointer(group_rank=self.output_pp_rank) + result["label_mask"] = TensorPointer(group_rank=self.output_pp_rank) + + assert ( + expanded_input_length == self.sequence_length + 1 + ), f"Samples should be of length {self.sequence_length + 1} (seq_len+1), but got {expanded_input_length}" + + # Process inputs: last token is the label + if current_pp_rank == self.input_pp_rank: + result["input_ids"] = input_ids[:, :-1] + result["input_mask"] = torch.ones((batch_size, self.sequence_length), dtype=torch.bool) + result["lang_code"] = lang_code + + # Process labels: shift them to the left + if current_pp_rank == self.output_pp_rank: + result["label_ids"] = input_ids[:, 1:] + result["label_mask"] = torch.ones((batch_size, self.sequence_length), dtype=torch.bool) + + if isinstance(result["input_ids"], torch.Tensor) and result["input_ids"].shape[-1] != self.sequence_length: + raise ValueError( + f"`labels` are incorrectly preprocessed. `labels` length is {result['input_ids'].shape[-1]}, but should be" + f" {self.sequence_length}." + ) + if isinstance(result["label_ids"], torch.Tensor) and result["label_ids"].shape[-1] != self.sequence_length: + raise ValueError( + f"`labels` are incorrectly preprocessed. `labels` length is {result['label_ids'].shape[-1]}, but should be" + f" {self.sequence_length}." 
+ ) + + return result diff --git a/src/nanotron/data/dataloader_builder.py b/src/nanotron/data/dataloader_builder.py index 9d3285f6..f9480029 100644 --- a/src/nanotron/data/dataloader_builder.py +++ b/src/nanotron/data/dataloader_builder.py @@ -1,6 +1,6 @@ import nanotron.distributed as dist from nanotron import logging -from nanotron.data.collator import NanosetDataCollatorForCLM +from nanotron.data.collator import MultilingualNanosetDataCollatorForCLM, NanosetDataCollatorForCLM from nanotron.dataloader import ( EmptyInfiniteDataset, get_dataloader_worker_init, @@ -20,9 +20,11 @@ def build_nanoset_dataloader( output_pp_rank: int, micro_batch_size: int, dataloader_num_workers: int, + is_multilingual: bool = False, consumed_train_samples: int = 0, dataloader_drop_last: bool = True, dataloader_pin_memory: bool = True, + shuffle: bool = False, ) -> DataLoader: # Case of ranks not requiring data. We give them a dummy dataset, then the collator will do his job @@ -39,6 +41,14 @@ def build_nanoset_dataloader( parallel_context=parallel_context, ) + if is_multilingual: + data_collator = MultilingualNanosetDataCollatorForCLM( + sequence_length=sequence_length, + input_pp_rank=input_pp_rank, + output_pp_rank=output_pp_rank, + parallel_context=parallel_context, + ) + # Compute size and rank of dataloader workers dp_ranks_size = parallel_context.dp_pg.size() dp_rank = parallel_context.dp_pg.rank() @@ -49,7 +59,7 @@ def build_nanoset_dataloader( dl_rank=dp_rank, drop_last=dataloader_drop_last, consumed_train_samples=consumed_train_samples, - shuffle=False, + shuffle=shuffle, ) return DataLoader( diff --git a/src/nanotron/data/multilingual_nanoset.py b/src/nanotron/data/multilingual_nanoset.py index 7af57448..84ef1946 100644 --- a/src/nanotron/data/multilingual_nanoset.py +++ b/src/nanotron/data/multilingual_nanoset.py @@ -30,7 +30,6 @@ def __init__( dataset_folders: List[str], sequence_length: int, token_size: int, - dataset_tokens: List[int], train_split_num_samples: int = None, is_valid: bool = False, dataset_weights: Union[List[float], None] = None, @@ -47,7 +46,6 @@ def __init__( self.sequence_length = sequence_length self.token_size = token_size self.train_split_num_samples = train_split_num_samples - self.dataset_tokens = dataset_tokens self.is_valid = is_valid self.random_seed = random_seed self.datatrove_datasets = [] @@ -107,7 +105,7 @@ def __getitem__(self, idx: int) -> Dict[str, np.ndarray]: dataset_sample = self.dataset_sample_index[idx] tokens = self.datatrove_datasets[dataset][dataset_sample] - tokens["input_ids"][0] = self.dataset_tokens[dataset] # Prepend language token + tokens["lang_code"] = torch.tensor(dataset, dtype=torch.long) return tokens @@ -120,7 +118,9 @@ def build_train_nanoset_index(self) -> np.ndarray: num_epochs = int(self.train_split_num_samples / samples_per_epoch) + 1 # Build the dataset indexes for 1 epoch dataset_index, dataset_sample_index = build_train_nanoset_index_helper( - n_samples=samples_per_epoch, weights=self.dataset_weights, dataset_sizes=self.dataset_lengths + n_samples=samples_per_epoch, + weights=self.dataset_weights, + dataset_sizes=self.dataset_lengths, ) # Shuffle the indexes the same way numpy_random_state = np.random.RandomState(self.random_seed) diff --git a/src/nanotron/distributed.py b/src/nanotron/distributed.py index 0156b1bb..0bc54f3e 100644 --- a/src/nanotron/distributed.py +++ b/src/nanotron/distributed.py @@ -52,10 +52,6 @@ def all_gather_into_tensor( # pylint: disable=function-redefined if group is None: group = 
dist.torch_dist.distributed_c10d._get_default_group() - assert ( - group.size() > 1 - ), "You should probably not call `all_gather_into_tensor` with a single rank, as it copies data over" - if torch_version_above_1_13: return dist.all_gather_into_tensor( output_tensor=output_tensor, input_tensor=input_tensor, group=group, async_op=async_op diff --git a/src/nanotron/models/gpt3_moe.py b/src/nanotron/models/gpt3_moe.py index 0e2add58..1915136c 100644 --- a/src/nanotron/models/gpt3_moe.py +++ b/src/nanotron/models/gpt3_moe.py @@ -218,6 +218,7 @@ def forward( self, input_ids: Union[torch.Tensor, TensorPointer], input_mask: Union[torch.Tensor, TensorPointer], + lang_code: Union[torch.Tensor, TensorPointer], # [batch_size, 1] TODO label_ids: Union[torch.Tensor, TensorPointer], label_mask: Union[torch.Tensor, TensorPointer], ) -> Dict[str, Union[torch.Tensor, TensorPointer]]: @@ -234,21 +235,22 @@ def forward( else TensorPointer(self.input_pp_rank) ), } - output = self.model( + model_output = self.model( input_ids=input_ids, input_mask=input_mask, aux_losses=aux_losses, ) - loss = self.loss( - sharded_logits=output["sharded_logits"], + outputs = self.loss( + sharded_logits=model_output["sharded_logits"], label_ids=label_ids, label_mask=label_mask, ) - if isinstance(output["aux_losses"], dict): - for key, value in output["aux_losses"].items(): - loss[key] = value - return loss + outputs["loss"] = torch.mean(outputs["sample_loss"]) + if isinstance(model_output["aux_losses"], dict): + for key, value in model_output["aux_losses"].items(): + outputs[key] = value + return outputs def get_block_compute_costs(self): """Computes the compute cost of each block in the model so that we can do a better job of load balancing.""" diff --git a/src/nanotron/models/llama.py b/src/nanotron/models/llama.py index 2411e5fa..382951fe 100644 --- a/src/nanotron/models/llama.py +++ b/src/nanotron/models/llama.py @@ -14,10 +14,11 @@ # limitations under the License. 
"""PyTorch LLaMa model.""" -from typing import Dict, Optional, Union +from typing import Dict, List, Optional, Union import torch from torch import nn +from torch.utils.checkpoint import CheckpointFunction from nanotron import distributed as dist from nanotron import logging @@ -592,11 +593,13 @@ def __init__( self.post_attention_layernorm = TritonRMSNorm(config.hidden_size, eps=config.rms_norm_eps) self.mlp = MLP(config=config, parallel_config=parallel_config, tp_pg=tp_pg) - def forward( + self.recompute_layer = parallel_config.recompute_layer + + def _core_forward( self, hidden_states: Union[torch.Tensor, TensorPointer], sequence_mask: Union[torch.Tensor, TensorPointer], - ) -> Dict[str, Union[torch.Tensor, TensorPointer]]: + ) -> List[Union[torch.Tensor, TensorPointer]]: residual = hidden_states hidden_states = self.input_layernorm(hidden_states) @@ -609,9 +612,29 @@ def forward( hidden_states = self.mlp(hidden_states=hidden_states)["hidden_states"] hidden_states = hidden_states + residual + return hidden_states, output["sequence_mask"] + + def _checkpointed_forward( + self, + hidden_states: torch.Tensor, + sequence_mask: torch.Tensor, + ) -> List[torch.Tensor]: + return CheckpointFunction.apply(self._core_forward, True, hidden_states, sequence_mask) + + def forward( + self, + hidden_states: Union[torch.Tensor, TensorPointer], + sequence_mask: Union[torch.Tensor, TensorPointer], + ) -> Dict[str, Union[torch.Tensor, TensorPointer]]: + + if self.recompute_layer and not isinstance(hidden_states, TensorPointer): + hidden_states, sequence_mask = self._checkpointed_forward(hidden_states, sequence_mask) + else: + hidden_states, sequence_mask = self._core_forward(hidden_states, sequence_mask) + return { "hidden_states": hidden_states, - "sequence_mask": output["sequence_mask"], + "sequence_mask": sequence_mask, } @@ -733,14 +756,20 @@ def forward( self, input_ids: Union[torch.Tensor, TensorPointer], # [batch_size, seq_length] input_mask: Union[torch.Tensor, TensorPointer], # [batch_size, seq_length] + lang_code: Union[torch.Tensor, TensorPointer], # [batch_size, 1] ): - return self.forward_with_hidden_states(input_ids=input_ids, input_mask=input_mask)[0] + return self.forward_with_hidden_states(input_ids=input_ids, input_mask=input_mask, lang_code=lang_code)[0] def forward_with_hidden_states( self, input_ids: Union[torch.Tensor, TensorPointer], # [batch_size, seq_length] input_mask: Union[torch.Tensor, TensorPointer], # [batch_size, seq_length] + lang_code: Union[torch.Tensor, TensorPointer], # [batch_size, 1] ): + # NOTE(tj.solergibert) I bring `lang_code` till the forward of LlamaModel. Remember that + # to use it in the different pipeline blocks you need to also set the module_input_keys & module_output_keys + # of the necessary `PipelineBlock`'s defined in the LlamaModel init! + # all tensors are optional as most ranks don't need anything from the dataloader. output = self.token_position_embeddings(input_ids=input_ids, input_mask=input_mask) @@ -801,7 +830,9 @@ def get_flops_per_sec(self, iteration_time_in_sec, sequence_length, global_batch @torch.jit.script def masked_mean(loss, label_mask, dtype): # type: (Tensor, Tensor, torch.dtype) -> Tensor - return (loss * label_mask).sum(dtype=dtype) / label_mask.sum() + return (loss * label_mask).sum(dim=1, dtype=dtype) / label_mask.sum( + dim=1 + ) # NOTE(tj.solergibert) Added dim=1 to return a tensor with shape [Batch size, 1] instead of [1] class Loss(nn.Module): @@ -818,14 +849,18 @@ def forward( # Megatron by defaults cast everything in fp32. 
`--f16-lm-cross-entropy` is an option you can use to keep current precision. # https://github.com/NVIDIA/Megatron-LM/blob/f267e6186eae1d6e2055b412b00e2e545a8e896a/megatron/model/gpt_model.py#L38 - loss = sharded_cross_entropy( + sample_loss = sharded_cross_entropy( sharded_logits, label_ids.transpose(0, 1).contiguous(), group=self.tp_pg, dtype=torch.float ).transpose(0, 1) # TODO @thomasw21: It's unclear what kind of normalization we want to do. - loss = masked_mean(loss, label_mask, dtype=torch.float) - # I think indexing causes a sync we don't actually want - # loss = loss[label_mask].sum() - return {"loss": loss} + sample_loss = masked_mean(sample_loss, label_mask, dtype=torch.float) + # NOTE(tj.solergibert) masked_mean returns a single scalar with the batch loss. We've changed it to compute the SAMPLE loss. + # We will continue using "loss" as the batch loss but we add "sample_loss" for the multilingual effort. + # WARN(tj.solergibert) Don't panic, the batch loss used to update the parameters is computed in `LlamaForTraining` + + # TODO @thomasw21: I think indexing causes a sync we don't actually want + # TODO @thomasw21: loss = loss[label_mask].sum() + return {"sample_loss": sample_loss} class LlamaForTraining(NanotronModel): @@ -847,7 +882,7 @@ def __init__( "label_ids", "label_mask", }, - module_output_keys={"loss"}, + module_output_keys={"sample_loss"}, ) self.parallel_context = parallel_context self.config = config @@ -857,19 +892,22 @@ def forward( self, input_ids: Union[torch.Tensor, TensorPointer], input_mask: Union[torch.Tensor, TensorPointer], + lang_code: Union[torch.Tensor, TensorPointer], label_ids: Union[torch.Tensor, TensorPointer], label_mask: Union[torch.Tensor, TensorPointer], ) -> Dict[str, Union[torch.Tensor, TensorPointer]]: sharded_logits = self.model( input_ids=input_ids, input_mask=input_mask, + lang_code=lang_code, ) - loss = self.loss( + outputs = self.loss( sharded_logits=sharded_logits, label_ids=label_ids, label_mask=label_mask, - )["loss"] - return {"loss": loss} + ) + outputs["loss"] = torch.mean(outputs["sample_loss"]) + return outputs @torch.no_grad() def init_model_randomly(self, config: Config): diff --git a/src/nanotron/models/starcoder2.py b/src/nanotron/models/starcoder2.py index 1f2eab7d..6636ffb5 100644 --- a/src/nanotron/models/starcoder2.py +++ b/src/nanotron/models/starcoder2.py @@ -32,6 +32,7 @@ from nanotron.config import ParallelismArgs, Starcoder2Config from nanotron.generation.generate_store import AttachableStore from nanotron.models import NanotronModel +from nanotron.models.moe import ParallelDroplessMLP from nanotron.nn.activations import ACT2FN from nanotron.nn.layer_norm import TritonLayerNorm from nanotron.parallel import ParallelContext @@ -56,7 +57,6 @@ from nanotron.parallel.tied_parameters import tie_parameters from nanotron.random import RandomStates, branch_random_state from nanotron.utils import checkpoint_method -from nanotron.models.moe import ParallelDroplessMLP, SparseMLP def pad_to_right(tensor, mask, new_tensor=None): @@ -1376,7 +1376,9 @@ def forward( @torch.jit.script def masked_mean(loss, label_mask, dtype): # type: (Tensor, Tensor, torch.dtype) -> Tensor - return (loss * label_mask).sum(dtype=dtype) / label_mask.sum() + return (loss * label_mask).sum(dim=1, dtype=dtype) / label_mask.sum( + dim=1 + ) # NOTE(tj.solergibert) Added dim=1 to return a tensor with shape [Batch size, 1] instead of [1] class Loss(nn.Module): @@ -1401,7 +1403,7 @@ def forward( loss = masked_mean(loss, label_mask, dtype=torch.float) # I think 
indexing causes a sync we don't actually want # loss = loss[label_mask].sum() - return {"loss": loss} + return {"sample_loss": loss} class Starcoder2ForTraining(NanotronModel): @@ -1428,7 +1430,7 @@ def __init__( "label_ids", "label_mask", }, - module_output_keys={"loss"}, + module_output_keys={"sample_loss"}, ) self.config: Starcoder2Config = config self.parallel_config = parallel_config @@ -1438,20 +1440,21 @@ def forward( self, input_ids: Union[torch.Tensor, TensorPointer], input_mask: Union[torch.Tensor, TensorPointer], + lang_code: Union[torch.Tensor, TensorPointer], # TODO label_ids: Union[torch.Tensor, TensorPointer], label_mask: Union[torch.Tensor, TensorPointer], - ) -> Union[torch.Tensor, TensorPointer]: + ) -> Dict[str, Union[torch.Tensor, TensorPointer]]: sharded_logits = self.model( input_ids=input_ids, input_mask=input_mask, ) - return { - "loss": self.loss( - sharded_logits=sharded_logits, - label_ids=label_ids, - label_mask=label_mask, - )["loss"] - } + outputs = self.loss( + sharded_logits=sharded_logits, + label_ids=label_ids, + label_mask=label_mask, + ) + outputs["loss"] = torch.mean(outputs["sample_loss"]) + return outputs def tie_custom_params(self) -> None: # find all params with names qkv.kv.weight and qkv.kv.bias in them @@ -1526,7 +1529,7 @@ def init_model_randomly(self, config): else: raise ValueError(f"Who the fuck is {param_name}?") elif isinstance(module, ParallelDroplessMLP): - if hasattr(module, 'bias'): + if hasattr(module, "bias"): module.bias.zero_() elif isinstance(module, TensorParallelRowLinear): if "weight" == param_name: diff --git a/src/nanotron/parallel/pipeline_parallel/engine.py b/src/nanotron/parallel/pipeline_parallel/engine.py index 91840a5e..b4f5a3c4 100644 --- a/src/nanotron/parallel/pipeline_parallel/engine.py +++ b/src/nanotron/parallel/pipeline_parallel/engine.py @@ -2,18 +2,24 @@ from typing import Dict, Iterable, Optional, Union import torch +from torch import nn as torch_nn +from torch.nn.parallel import DistributedDataParallel + from nanotron import distributed as dist from nanotron import logging from nanotron.distributed import ProcessGroup from nanotron.logging import log_rank from nanotron.optim.gradient_accumulator import GradientAccumulator from nanotron.parallel.data_parallel.utils import ddp_trigger_sync_in_bwd -from nanotron.parallel.pipeline_parallel.context_manager import attach_pipeline_state_to_model -from nanotron.parallel.pipeline_parallel.state import PipelineTrainBatchState +from nanotron.parallel.pipeline_parallel.context_manager import ( + attach_pipeline_state_to_model, +) +from nanotron.parallel.pipeline_parallel.state import ( + PipelineEvalBatchState, + PipelineTrainBatchState, +) from nanotron.parallel.pipeline_parallel.tensor_pointer import TensorPointer from nanotron.utils import ContextManagers -from torch import nn as torch_nn -from torch.nn.parallel import DistributedDataParallel logger = logging.get_logger(__name__) @@ -29,6 +35,7 @@ def forward( state: PipelineTrainBatchState, micro_batch: Dict[str, Union[torch.Tensor, TensorPointer]], model: torch_nn.Module, + is_validation: bool = False, ) -> Dict[str, Union[torch.Tensor, TensorPointer]]: # Increment the number of backwards state.nb_forwards += 1 @@ -48,16 +55,19 @@ def forward( output = {"loss": output} for k, v in output.items(): - if not isinstance(v, TensorPointer): + if not isinstance(v, TensorPointer) and k != "sample_loss": output[k] = v / self.nb_microbatches # the outputs are either # - token prediction loss ["loss"] + # - loss per sample (for 
validation), ["sample_loss"] -- does not require backpropagation # - auxiliary losses ["load_balancing_loss", "z_loss"] # that we need to backpropagate through, so register activations for loss_key, output_tensor in output.items(): - if not isinstance(output_tensor, TensorPointer): - assert output_tensor.requires_grad + if loss_key == "sample_loss": + continue + if not isinstance(output_tensor, TensorPointer) and not is_validation: + assert output_tensor.requires_grad, loss_key state.register_activation_requiring_backward(output_tensor) return output @@ -69,7 +79,10 @@ def _get_fwd_context(model: torch_nn.Module): return context def backward( - self, context: ContextManagers, state: PipelineTrainBatchState, grad_accumulator: Optional[GradientAccumulator] + self, + context: ContextManagers, + state: PipelineTrainBatchState, + grad_accumulator: Optional[GradientAccumulator], ): # Increment the number of backwards state.nb_backwards += 1 @@ -138,16 +151,23 @@ def validate_batch_iter( nb_microbatches: int, ) -> Iterable[Dict[str, Union[torch.Tensor, TensorPointer]]]: # Assign a new state for the current batch - state = PipelineTrainBatchState() # TODO: do i need state? + state = PipelineEvalBatchState() self.nb_microbatches = nb_microbatches outputs = [] + lang_codes = [] with attach_pipeline_state_to_model(model=model, pipeline_state=state): # All forward for micro_batch in batch: context = self._get_fwd_context(model=model) - output = self.forward(context=context, state=state, micro_batch=micro_batch, model=model) + output = self.forward( + context=context, + state=state, + micro_batch=micro_batch, + model=model, + is_validation=True, + ) # TODO @thomasw21: Somehow this needs to be done somewhere else to support interleaving. Somewhere right after a "stage" for _ in range(len(state.microbatches_activations_to_send)): send_activation = state.microbatches_activations_to_send.popleft() @@ -161,9 +181,13 @@ def validate_batch_iter( # Store the loss(es) for each microbatch if not isinstance(output["loss"], TensorPointer): output = {k: v.detach() for k, v in output.items()} - outputs.append(output) - return outputs + outputs.extend( + list(output["sample_loss"]) + ) # NOTE(tj.solergibert) Yes, it might look useless to do list + extend but it's necessary to split the output["sample_loss"] tensor into multiple tensors + lang_codes.extend(micro_batch["lang_code"].flatten().tolist()) + + return outputs, lang_codes class AllForwardAllBackwardPipelineEngine(PipelineEngine): diff --git a/src/nanotron/parallel/pipeline_parallel/state.py b/src/nanotron/parallel/pipeline_parallel/state.py index e07cc89a..f22d6571 100644 --- a/src/nanotron/parallel/pipeline_parallel/state.py +++ b/src/nanotron/parallel/pipeline_parallel/state.py @@ -4,6 +4,7 @@ from typing import List import torch + from nanotron import distributed as dist from nanotron import logging from nanotron.logging import log_rank @@ -203,6 +204,9 @@ class PipelineEvalBatchState(PipelineBatchState): microbatches_activations_to_recv = collections.deque() activations_buffer = collections.deque() + # Reinitialise counter + nb_forwards = 0 + def register_activation_requiring_backward(self, activation: torch.Tensor): pass diff --git a/src/nanotron/parallel/tensor_parallel/functional.py b/src/nanotron/parallel/tensor_parallel/functional.py index fdef48ac..ecccd33a 100644 --- a/src/nanotron/parallel/tensor_parallel/functional.py +++ b/src/nanotron/parallel/tensor_parallel/functional.py @@ -89,10 +89,10 @@ def forward( @staticmethod def backward(ctx, 
grad_output): - # Retreive tensors from the forward path. + # Retrieve tensors from the forward path. softmax, target_mask, masked_target_1d = ctx.saved_tensors - # All the inputs have softmax as thier gradient. + # All the inputs have softmax as their gradient. grad_input = softmax # For simplicity, work with the 2D gradient. sharded_hidden_size = softmax.size()[-1] @@ -387,8 +387,7 @@ def backward(ctx, grad_output): group = ctx.group use_bias = ctx.use_bias - handle_0: Optional[dist.Work] = None - handle_1: Optional[dist.Work] = None + handle: Optional[dist.Work] = None # TODO @thomasw21: gather along another dimension sharded_batch_size, *rest_size = grad_output.shape @@ -412,31 +411,69 @@ def backward(ctx, grad_output): # https://github.com/pytorch/pytorch/blob/c47cf9bc7f9e02f649ab4ed53fe4d35732c92ab6/torch/_refs/__init__.py#L2761 grad_output = grad_output.contiguous() - handle_0 = dist.all_gather_into_tensor(total_grad_output, grad_output, group=group, async_op=True) - - grad_tensor = grad_output.matmul(weight) + handle = dist.all_gather_into_tensor(total_grad_output, grad_output, group=group, async_op=True) - # wait for the first all_gather to finish before starting the second all_gather - if handle_0 is not None: - handle_0.wait() - - # TODO @thomasw21: gather along another dimension - sharded_batch_size, *rest_size = grad_tensor.shape + # total_grad_output: [b, s, h_out] + # weight: [h_out, h_in/n] + # total_grad_tensor: [b, s, h_in/n] + # grad_output: [b/n, s, h_out] + sharded_batch_size, *rest_size_grad_output = grad_output.shape + rest_size_grad_tensor = rest_size_grad_output[:-1] + [weight.shape[1]] if group.size() == 1: - total_grad_tensor = grad_tensor + total_grad_tensor = grad_output.matmul(weight) else: unsharded_batch_size = sharded_batch_size * group.size() - total_grad_tensor = torch.empty( unsharded_batch_size, - *rest_size, - device=grad_tensor.device, - dtype=grad_tensor.dtype, + *rest_size_grad_tensor, + device=grad_output.device, + dtype=grad_output.dtype, requires_grad=False, ) + before_shard_grad_tensor, same_device_shard_grad_tensor, after_shard_grad_tensor = torch.split( + total_grad_tensor, + split_size_or_sections=[ + sharded_batch_size * dist.get_rank(group), + sharded_batch_size, + sharded_batch_size * (group.size() - dist.get_rank(group) - 1), + ], + dim=0, + ) + # compute local shard + torch.mm( + input=grad_output.view(-1, grad_output.shape[-1]), + mat2=weight, + out=same_device_shard_grad_tensor.view(-1, weight.shape[1]), + ) - handle_1 = dist.all_gather_into_tensor(total_grad_tensor, grad_tensor, group=group, async_op=True) + if handle is not None: + handle.wait() + + before_shard_grad_output, _, after_shard_grad_output = torch.split( + total_grad_output, + split_size_or_sections=[ + sharded_batch_size * dist.get_rank(group), + sharded_batch_size, + sharded_batch_size * (group.size() - dist.get_rank(group) - 1), + ], + dim=0, + ) + + # before shard compute + if before_shard_grad_tensor.numel() > 0: + torch.mm( + input=before_shard_grad_output.view(-1, before_shard_grad_output.shape[-1]), + mat2=weight, + out=before_shard_grad_tensor.view(-1, weight.shape[1]), + ) + # after shard compute + if after_shard_grad_tensor.numel() > 0: + torch.mm( + input=after_shard_grad_output.view(-1, after_shard_grad_output.shape[-1]), + mat2=weight, + out=after_shard_grad_tensor.view(-1, weight.shape[1]), + ) # Convert the tensor shapes to 2D for execution compatibility tensor = tensor.contiguous() @@ -454,9 +491,6 @@ def backward(ctx, grad_output): grad_weight = 
total_grad_output.t().matmul(tensor)
         grad_bias = total_grad_output.sum(dim=0) if use_bias else None

-        if handle_1 is not None:
-            handle_1.wait()
-
         return total_grad_tensor, grad_weight, grad_bias, None, None
diff --git a/src/nanotron/serialize/metadata.py b/src/nanotron/serialize/metadata.py
index 0d8708f9..4bd36c19 100644
--- a/src/nanotron/serialize/metadata.py
+++ b/src/nanotron/serialize/metadata.py
@@ -46,6 +46,8 @@ class TrainingMetadata:
     last_stage_idx: Optional[int] = None
     data_stages: Optional[List[DataStageMetadata]] = None

+    last_validation_stage_idx: Optional[int] = None
+
     def __post_init__(self):
         # NOTE: this is a sanity check after loading a trained checkpoint
         total_consumed_samples_across_stages = sum(stage.consumed_train_samples for stage in self.data_stages)
diff --git a/src/nanotron/trainer.py b/src/nanotron/trainer.py
index 0b3306c5..91b9a29b 100644
--- a/src/nanotron/trainer.py
+++ b/src/nanotron/trainer.py
@@ -81,6 +81,7 @@
 from nanotron.sanity_checks import (
     after_optim_step_sanity_checks,
     after_tbi_sanity_checks,
+    assert_tensor_synced_across_pg,
     before_optim_step_sanity_checks,
     before_tbi_sanity_checks,
 )
@@ -233,7 +234,11 @@ def __init__(
             for stage in self.config.data_stages
         ]
         self.metadata: TrainingMetadata = TrainingMetadata(
-            consumed_train_samples=0, last_train_step=0, last_stage_idx=0, data_stages=data_stages
+            consumed_train_samples=0,
+            last_train_step=0,
+            last_stage_idx=0,
+            data_stages=data_stages,
+            last_validation_stage_idx=0,
         )

         # Setup tensorboard write and log writers on output rank
@@ -255,6 +260,8 @@ def __init__(
         self.limit_val_batches = self.config.tokens.limit_val_batches
         # NOTE: the dataloader currently in use for the current training stage
         self.current_dataloader: Optional[DataLoader] = None
+        # NOTE: the dataloader currently in use for the current validation stage
+        self.current_validation_dataloader: Optional[DataLoader] = None

         self.post_init()
@@ -302,6 +309,106 @@ def _print_training_plan(self):
         )
         log_rank(full_log_message, logger=logger, level=logging.INFO, rank=0)

+    def _prepare_dataloader_for_validation_stage(self, dataloaders: Union[List[DataLoader], DataLoader]):
+        # NOTE(tj.solergibert) Similar to _update_dataloader_based_on_training_stages BUT:
+        # 1. We call this function EVERY TIME we run the validation loop
+        # 2. Every time it returns a NEW validation iterator DataLoader. If you don't do this you'll consume the whole validation dataset
+        #    in the first iteration and subsequent validations will fail
+        # `dataloaders` are either torch DataLoaders (the very first stage) OR functions that we call later that provide torch DataLoaders (subsequent stages)
+        # From these torch DataLoader objects we then call `sanity_check_dataloader`, which returns an iterator.
+        # In short, `sanity_check_dataloader` just places the input tensors on the GPU when necessary (TensorPointers stay on the CPU)
+        #
+        # TBH, the for loop below is just for deleting the DataLoaders of previous stages, which is not so problematic. The important part is returning the
+        # DataLoader iterator every time we call this function for the current training stage, which is tracked during training.
+        #
+        # Also, keep in mind that if val_check_interval = 5 & data.start_training_step = 10 we will already run the evaluation with the SECOND data stage
+        # after training on it for just one iteration, so it might not be a good idea to schedule evals at the step where we switch data stages.
+        #
+        # NOTE(tj.solergibert) Further investigation should be done, but there is a strange behaviour when deleting the DataLoaders/lambda funcs. As they
+        # are converted into Iterators with `sanity_check_dataloader`, we can no longer access the DataLoader object to del the dataset (after the first stage,
+        # this function locally creates the DataLoader from the lambda func --> returns an Iterator)
+        #
+        # Also, when the gc deletes the first stage dataloader, all the `DatatroveFileDataset._f` are already None AND the `del` statements delete a copy of the
+        # object, not the object itself
+        #
+        # FINAL NOTE(tj.solergibert) I will open an issue in nanotron to check with them whether they are aware of these useless deletions
+        #
+        # TODO(tj.solergibert) Check the tuple case below
+        from collections.abc import Generator
+
+        if not hasattr(self.config, "data_stages") or self.config.data_stages is None:
+
+            if isinstance(dataloaders, tuple):  # TODO(tj.solergibert) Check this tuple case
+                dataloader = dataloaders[0]
+            else:
+                dataloader = dataloaders
+
+            self.current_validation_dataloader_length = min(len(dataloader), self.limit_val_batches)
+            self.current_validation_dataloader = sanity_check_dataloader(
+                dataloader=dataloader, parallel_context=self.parallel_context, config=self.config
+            )
+
+            return
+        elif isinstance(dataloaders, Generator):
+            # TODO(xrsrke): this is a hacky way to handle DoReMi's dataloader
+            # remove this in the next PR
+            self.current_validation_dataloader = dataloaders
+            return
+
+        assert len(dataloaders) > 0, "No dataloaders provided"
+        assert len(dataloaders) == len(
+            self.config.data_stages
+        ), "Number of dataloaders should match the number of dataset stages"
+
+        def clear_dataloader_from_memory(dataloader: DataLoader, stage_name: str, prev_stage_name: str):
+            import gc
+
+            log_rank(
+                f"[Validation Stage: {stage_name}] Clearing the previous validation stage's ({prev_stage_name}) dataloader and dataset from memory",
+                logger=logger,
+                level=logging.INFO,
+            )
+
+            # NOTE: Clear dataloader from memory
+            del dataloader.dataset
+            del dataloader.sampler
+            del dataloader.batch_sampler
+
+            gc.collect()
+
+        for stage_idx, stage in enumerate(self.config.data_stages):
+            if stage_idx < self.metadata.last_stage_idx:
+                continue
+            # NOTE(tj.solergibert) From this point stage_idx = self.metadata.last_stage_idx. We update self.metadata.last_stage_idx (which keeps track of the training stage)
+            # in each and every training step. 
+ + if ( + stage_idx is not self.metadata.last_validation_stage_idx + ): # When stage_idx (= self.metadata.last_stage_idx, the training stage index) is different than the last validation stage index + self.metadata.last_validation_stage_idx = stage_idx # Update validation stage index + # Delete previous stage DataLoader + prev_stage_name = self.config.data_stages[stage_idx - 1].name + prev_dataloader = dataloaders[prev_stage_name] + + if isinstance(prev_dataloader, DataLoader): + # NOTE: we don't need to clear dummy data generator from memory + clear_dataloader_from_memory( + prev_dataloader, stage_name=stage.name, prev_stage_name=prev_stage_name + ) + + self.metadata.last_validation_stage_idx = stage_idx # Update validation stage index + + # NOTE(tj.solergibert) Create AGAIN the DataLoader + dataloader = dataloaders[stage.name] + # NOTE: if a dataloader is lazy initialized, we need to call it to initialize it + dataloader = dataloader() if callable(dataloader) else dataloader + break + + self.current_validation_dataloader_length = min(len(dataloader), self.limit_val_batches) + self.current_validation_dataloader = sanity_check_dataloader( + dataloader=dataloader, parallel_context=self.parallel_context, config=self.config + ) # NOTE(tj.solergibert) Create a Iterator from the DataLoader + def _update_dataloader_based_on_training_stages(self, dataloaders: Union[List[DataLoader], DataLoader]): from collections.abc import Generator @@ -326,11 +433,11 @@ def _update_dataloader_based_on_training_stages(self, dataloaders: Union[List[Da self.config.data_stages ), "Number of dataloaders should match the number of dataset stages" - def clear_dataloader_from_memory(dataloader: DataLoader, stage_name: str): + def clear_dataloader_from_memory(dataloader: DataLoader, stage_name: str, prev_stage_name: str): import gc log_rank( - f"[Training Stage: {stage_name}] Clearing the previous training stage's dataloader and datasets from memory", + f"[Training Stage: {stage_name}] Clearing the previous training stage's ({prev_stage_name}) dataloader and datasets from memory", logger=logger, level=logging.INFO, ) @@ -367,7 +474,9 @@ def find_stage_idx_to_resume(): if isinstance(prev_dataloader, DataLoader): # NOTE: we don't need to clear dummy data generator from memory - clear_dataloader_from_memory(prev_dataloader, stage_name=stage.name) + clear_dataloader_from_memory( + prev_dataloader, stage_name=stage.name, prev_stage_name=prev_stage_name + ) self.metadata.last_stage_idx = stage_idx @@ -433,6 +542,19 @@ def train( # Training step outputs, loss_avg = self.training_step(dataloader=self.current_dataloader) + self.training_step_time = time.time() + + # Validation stage + if self.iteration_step % self.config.tokens.val_check_interval == 0: + self._prepare_dataloader_for_validation_stage(valid_dataloader_or_dls) + val_global_loss, val_lang_losses = self.validation_step( + dataloader=self.current_validation_dataloader + ) + self.validation_step_time = time.time() + else: + # NOTE(tj.solergibert) As we are reporting the training & validation metrics together, we + # must comply with val_check_interval % iteration_step_info_interval = 0 + val_global_loss, val_lang_losses = None, None # Training Logs # TODO(xrsrke): refactor using callbacks would be better @@ -443,7 +565,7 @@ def train( ].consumed_train_samples += self.global_batch_size if (self.iteration_step - 1) % self.config.logging.iteration_step_info_interval == 0: - self.train_step_logs(outputs=outputs, loss_avg=loss_avg) + self.train_step_logs(loss_avg=loss_avg, 
global_loss=val_global_loss, lang_losses=val_lang_losses) # Checkpoint if self.iteration_step % self.config.checkpoints.checkpoint_interval == 0: @@ -524,6 +646,8 @@ def training_step( # This is an average on only one data rank. loss_avg = {} for k in outputs[0].keys(): + if k == "sample_loss": + continue # sample loss is the individual losses, is already averaged as 'lm_loss' if k == "loss": loss_avg["lm_loss"] = torch.stack([output[k] for output in outputs]).sum() k = "lm_loss" @@ -556,22 +680,71 @@ def training_step( return outputs, loss_avg def validation_step(self, dataloader: Iterator[Dict[str, Union[torch.Tensor, TensorPointer]]]) -> Iterable[Dict]: - outputs = self.pipeline_engine.validate_batch_iter( + outputs, lang_codes = self.pipeline_engine.validate_batch_iter( model=self.model, - batch=(next(dataloader) for _ in range(self.limit_val_batches)), - nb_microbatches=self.limit_val_batches, + batch=(next(dataloader) for _ in range(self.current_validation_dataloader_length)), + nb_microbatches=self.current_validation_dataloader_length, ) - return outputs + + lang_losses = { + lang: [] for lang in self.config.data_stages[self.metadata.last_stage_idx].data.dataset.languages + } + lang_losses_list = list(lang_losses.keys()) + + # Compute losses + if isinstance(outputs[0], torch.Tensor): + # Multilingual losses + for loss, lang_code in zip(outputs, lang_codes): + lang_losses[lang_losses_list[lang_code]].append(loss) + # Global loss + global_loss_avg = torch.mean(torch.stack(outputs)) + # Sync multilingual losses across DP + for lang in lang_losses.keys(): + if not lang_losses[ + lang + ]: # If the list is empty --> Set local language loss to -1 to exclude it from the global computation + lang_losses[lang] = torch.tensor(-1, dtype=torch.float32) + else: # If we have at least 1 loss from a given language --> compute local language loss mean + lang_losses[lang] = torch.mean(torch.stack(lang_losses[lang])) + + # NOTE(tj.solergibert) We create a (DP SIZE, LANGS) tensor to aggregate ALL local losses across DP groups. 
+ # Then we compute the mean of each lang in each and every rank and finally copy back the result to the + # `lang_losses` dict for logging + lang_losses_tensor_out = torch.zeros( + (self.parallel_context.dp_pg.size(), len(lang_losses.keys())), dtype=torch.float, device="cuda" + ) # (DP SIZE, LANGS) + lang_losses_tensor_local = torch.stack(list(lang_losses.values())).unsqueeze(0) # (1, LANGS) + dist.all_gather_into_tensor(lang_losses_tensor_out, lang_losses_tensor_local, self.parallel_context.dp_pg) + mask = lang_losses_tensor_out != -1 + lang_losses_tensor_local = (lang_losses_tensor_out * mask).sum(dim=0) / mask.sum(dim=0) # (1, LANGS) + for idx, lang in enumerate(lang_losses.keys()): + lang_losses[lang] = lang_losses_tensor_local[idx] + + # Sync global losses across DP + dist.all_reduce(global_loss_avg, group=self.parallel_context.dp_pg, op=dist.ReduceOp.AVG) + + # TODO(tj.solergibert) Delete this testing assertions + for lang in lang_losses.keys(): + assert_tensor_synced_across_pg(tensor=lang_losses[lang], pg=self.parallel_context.dp_pg) + assert_tensor_synced_across_pg(tensor=global_loss_avg, pg=self.parallel_context.dp_pg) + + else: + global_loss_avg = None + lang_losses = None + + return global_loss_avg, lang_losses def train_step_logs( self, - outputs: Iterable[Dict[str, Union[torch.Tensor, TensorPointer]]], loss_avg: Optional[Dict[str, torch.Tensor]], + global_loss: Optional[torch.Tensor], + lang_losses: Optional[Dict[str, torch.Tensor]], ) -> None: # TODO @nouamanetazi: Megatron-LM seems to be using a barrier to report their interval time. Check if this is necessary. https://github.com/NouamaneTazi/Megatron-LM/blob/e241a96c3085b18e36c6cee1d68a8155de77b5a6/megatron/training.py#L607 dist.barrier() torch.cuda.synchronize() - elapsed_time_per_iteration_ms = (time.time() - self.iteration_start_time) * 1000 + # Training metrics + elapsed_time_per_iteration_ms = (self.training_step_time - self.iteration_start_time) * 1000 tokens_per_sec = ( self.global_batch_size * self.sequence_length / (elapsed_time_per_iteration_ms / 1000) ) # tokens_per_sec is calculated using sequence_length @@ -581,13 +754,27 @@ def train_step_logs( global_batch_size=self.global_batch_size, ) + # Validation metrics + if global_loss is not None: + validation_total_samples = self.current_validation_dataloader_length * self.micro_batch_size + validation_elapsed_time_per_iteration_ms = (self.validation_step_time - self.training_step_time) * 1000 + validation_tokens_per_sec = ( + validation_total_samples * self.sequence_length / (validation_elapsed_time_per_iteration_ms / 1000) + ) + + validation_model_tflops, validation_hardware_tflops = self.unwrapped_model.get_flops_per_sec( + iteration_time_in_sec=validation_elapsed_time_per_iteration_ms / 1000, + sequence_length=self.sequence_length, + global_batch_size=validation_total_samples, + ) + if dist.get_rank(self.parallel_context.world_pg) in self.logger_ranks: assert self.loggerwriter is not None, "loggerwriter should be defined on logger ranks" + # Training metrics lr = self.lr_scheduler.get_last_lr()[0] log_entries = [ - # LogItem("consumed_samples", self.consumed_train_samples, "human_format"), # , "12d"), LogItem( "consumed_tokens", self.metadata.consumed_train_samples * self.config.tokens.sequence_length, @@ -611,6 +798,46 @@ def train_step_logs( if self.config.optimizer.clip_grad is not None: log_entries.append(LogItem("grad_norm", self.grad_norm_unclipped.item(), "human_format")) # , ".3f")) + # Validation metrics + if global_loss is not None: + 
log_entries.extend( + [ + LogItem( + "validation_consumed_tokens", + validation_total_samples * self.sequence_length, + "human_format", + ), # , "12d"), + LogItem( + "validation_elapsed_time_per_iteration_ms", + validation_elapsed_time_per_iteration_ms, + "human_format", + ), # , ".1f"), + LogItem("validation_tokens_per_sec", validation_tokens_per_sec, "human_format"), # , "1.6E"), + LogItem( + "validation_tokens_per_sec_per_gpu", + validation_tokens_per_sec / self.parallel_context.world_pg.size(), + "human_format", + ), # , "1.6E"), + LogItem("validation_loss", global_loss.item(), "human_format"), # , "1.6E"), + LogItem( + "validation_model_tflops_per_gpu", validation_model_tflops / 3, "human_format" + ), # , ".2f"), # NOTE(tj.solergibert) Check llama.py --> def get_flops() --> model_flops for explanation of the / 3 factor + LogItem( + "validation_hardware_tflops_per_gpu", validation_hardware_tflops / 3, "human_format" + ), # , ".2f"), # NOTE(tj.solergibert) Check llama.py --> def get_flops() --> model_flops for explanation of the / 3 factor + ] + ) + + # NOTE Currently you have to log each lang metric one by one and then merge them manually in the same plot through the wandb UI. + # Example: https://community.wandb.ai/t/log-multiple-variables-at-the-same-plot/2474 + # GitHub complains: https://github.com/wandb/wandb/issues/3035 + log_entries.extend( + [ + LogItem(f"{lang}_validation_loss", loss.item(), "human_format") + for lang, loss in lang_losses.items() + ] + ) + # Log not too often the memory if self.iteration_step < 5 or (self.iteration_step - 1) % self.config.checkpoints.checkpoint_interval == 0: total, used, free = shutil.disk_usage("/") diff --git a/tests/test_tensor_parallel.py b/tests/test_tensor_parallel.py index 127ba2fa..f5dcaeb0 100644 --- a/tests/test_tensor_parallel.py +++ b/tests/test_tensor_parallel.py @@ -208,14 +208,19 @@ def _test_row_linear(parallel_context: ParallelContext, tp_mode: TensorParallelL random_input = torch.randn(batch_size, in_features, device="cuda") # synchronize random_input across tp dist.all_reduce(random_input, op=dist.ReduceOp.AVG, group=parallel_context.tp_pg) - + random_input.requires_grad = True # Row linear receives as input sharded input - random_sharded_input = random_input[ - :, - dist.get_rank(parallel_context.tp_pg) - * in_features_per_rank : (dist.get_rank(parallel_context.tp_pg) + 1) - * in_features_per_rank, - ] + random_sharded_input = ( + random_input[ + :, + dist.get_rank(parallel_context.tp_pg) + * in_features_per_rank : (dist.get_rank(parallel_context.tp_pg) + 1) + * in_features_per_rank, + ] + .detach() + .clone() + ) + random_sharded_input.requires_grad = True # Test that we get the same output after forward pass # TODO @kunhao: We may want to have our custom error type @@ -261,6 +266,16 @@ def _test_row_linear(parallel_context: ParallelContext, tp_mode: TensorParallelL else: assert row_linear.bias is None + torch.testing.assert_close( + random_sharded_input.grad, + random_input.grad[ + :, + dist.get_rank(parallel_context.tp_pg) + * in_features_per_rank : (dist.get_rank(parallel_context.tp_pg) + 1) + * in_features_per_rank, + ], + ) + parallel_context.destroy()
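The densest part of the trainer changes above is the per-language averaging in `validation_step`: each DP rank reports a local mean per language, uses `-1` as a sentinel for languages it saw no samples of, and the gathered `(DP_SIZE, LANGS)` tensor is reduced with a mask. Below is a self-contained sketch, not part of the patch, that simulates the `dist.all_gather_into_tensor` result with a hand-written tensor so the masking logic can be checked in isolation; the DP size and loss values are made up for illustration.

```python
# Standalone sketch (not part of the diff) of the per-language aggregation in validation_step.
# The dist.all_gather_into_tensor over the DP group is simulated by stacking one row per rank.
import torch

languages = ["es", "en", "fr"]

# Local per-language means, one row per DP rank; -1 marks "this rank saw no samples of this language".
gathered = torch.tensor(
    [
        [2.10, -1.00, 1.80],  # rank 0
        [2.30, 1.50, -1.00],  # rank 1
    ]
)  # shape (DP_SIZE, LANGS), analogous to lang_losses_tensor_out in the trainer

mask = gathered != -1
# Average each language only over the ranks that actually contributed a loss for it.
per_lang = (gathered * mask).sum(dim=0) / mask.sum(dim=0)

for lang, loss in zip(languages, per_lang):
    print(f"{lang}_validation_loss = {loss.item():.4f}")
# es -> mean(2.10, 2.30) = 2.2000, en -> 1.5000, fr -> 1.8000
# Caveat: if no rank saw a given language, mask.sum(dim=0) is 0 and the result is NaN,
# which mirrors the behaviour of the trainer code above.
```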