diff --git a/examples/custom-dataloader/config_custom_dl.yaml b/examples/config_no_validation_tiny_llama.yaml similarity index 85% rename from examples/custom-dataloader/config_custom_dl.yaml rename to examples/config_no_validation_tiny_llama.yaml index 970e7407..75210489 100644 --- a/examples/custom-dataloader/config_custom_dl.yaml +++ b/examples/config_no_validation_tiny_llama.yaml @@ -6,7 +6,13 @@ checkpoints: save_initial_state: false data_stages: - data: - dataset: null # Custom dataloader will be used + dataset: + dataset_overwrite_cache: false + dataset_processing_num_proc_per_process: 1 + hf_dataset_config_name: null + hf_dataset_or_datasets: stas/openwebtext-10k + hf_dataset_splits: train + text_column_name: text num_loading_workers: 1 seed: 42 name: Stable Training Stage @@ -33,12 +39,12 @@ general: step: null lighteval: null logging: - iteration_step_info_interval: 1 + iteration_step_info_interval: 10 log_level: info log_level_replica: info model: ddp_bucket_cap_mb: 25 - dtype: bfloat16 + dtype: float32 init_method: std: 0.025 make_vocab_size_divisible_by: 1 @@ -81,7 +87,7 @@ optimizer: weight_decay: 0.01 zero_stage: 0 parallelism: - dp: 2 + dp: 1 expert_parallel_size: 1 pp: 1 pp_engine: 1f1b @@ -96,8 +102,8 @@ tokenizer: tokens: batch_accumulation_per_replica: 1 limit_test_batches: 0 - limit_val_batches: 0 + limit_val_batches: 5 micro_batch_size: 2 sequence_length: 256 - train_steps: 15 - val_check_interval: -1 + train_steps: 200 + val_check_interval: 2 diff --git a/examples/config_validation_tiny_llama.yaml b/examples/config_validation_tiny_llama.yaml new file mode 100644 index 00000000..5a205b50 --- /dev/null +++ b/examples/config_validation_tiny_llama.yaml @@ -0,0 +1,158 @@ +checkpoints: + checkpoint_interval: 10 + checkpoints_path: checkpoints + checkpoints_path_is_shared_file_system: false + resume_checkpoint_path: null + save_initial_state: false +data_stages: +- data: + dataset: + dataset_overwrite_cache: false + dataset_processing_num_proc_per_process: 1 + hf_dataset_config_name: null + hf_dataset_or_datasets: stas/openwebtext-10k + hf_dataset_splits: train + text_column_name: text + num_loading_workers: 1 + seed: 42 + name: Training1 + start_training_step: 1 +- data: + dataset: + dataset_overwrite_cache: false + dataset_processing_num_proc_per_process: 1 + hf_dataset_config_name: null + hf_dataset_or_datasets: stas/openwebtext-10k + hf_dataset_splits: train + text_column_name: text + num_loading_workers: 1 + seed: 42 + name: Training2 + start_training_step: 12 +- data: + dataset: + dataset_overwrite_cache: false + dataset_processing_num_proc_per_process: 1 + hf_dataset_config_name: null + hf_dataset_or_datasets: stas/openwebtext-10k + hf_dataset_splits: train + text_column_name: text + num_loading_workers: 1 + seed: 42 + name: Training3 + start_training_step: 50 +valid_data_stages: +- data: + dataset: + dataset_overwrite_cache: false + dataset_processing_num_proc_per_process: 1 + hf_dataset_config_name: null + hf_dataset_or_datasets: stas/oscar-en-10k + hf_dataset_splits: train + text_column_name: text + num_loading_workers: 1 + seed: 42 + name: Valid1 + start_training_step: 1 +- data: + dataset: + dataset_overwrite_cache: false + dataset_processing_num_proc_per_process: 1 + hf_dataset_config_name: null + hf_dataset_or_datasets: stas/oscar-en-10k + hf_dataset_splits: train + text_column_name: text + num_loading_workers: 1 + seed: 42 + name: Valid2 + start_training_step: 26 +- data: + dataset: + dataset_overwrite_cache: false + dataset_processing_num_proc_per_process: 1 
+ hf_dataset_config_name: null + hf_dataset_or_datasets: stas/oscar-en-10k + hf_dataset_splits: train + text_column_name: text + num_loading_workers: 1 + seed: 42 + name: Valid3 + start_training_step: 38 +general: + benchmark_csv_path: null + consumed_train_samples: null + ignore_sanity_checks: true + project: debug + run: tiny_llama_%date_%jobid + seed: 42 + step: null +lighteval: null +logging: + iteration_step_info_interval: 10 + log_level: info + log_level_replica: info +model: + ddp_bucket_cap_mb: 25 + dtype: float32 + init_method: + std: 0.025 + make_vocab_size_divisible_by: 1 + model_config: + bos_token_id: 1 + eos_token_id: 2 + hidden_act: silu + hidden_size: 16 + initializer_range: 0.02 + intermediate_size: 64 + is_llama_config: true + max_position_embeddings: 256 + num_attention_heads: 4 + num_hidden_layers: 2 + num_key_value_heads: 4 + pad_token_id: null + pretraining_tp: 1 + rms_norm_eps: 1.0e-05 + rope_scaling: null + tie_word_embeddings: true + use_cache: true + vocab_size: 256 +optimizer: + accumulate_grad_in_fp32: true + clip_grad: 1.0 + learning_rate_scheduler: + learning_rate: 0.0003 + lr_decay_starting_step: null + lr_decay_steps: 13 + lr_decay_style: cosine + lr_warmup_steps: 2 + lr_warmup_style: linear + min_decay_lr: 1.0e-05 + optimizer_factory: + adam_beta1: 0.9 + adam_beta2: 0.95 + adam_eps: 1.0e-08 + name: adamW + torch_adam_is_fused: true + weight_decay: 0.01 + zero_stage: 0 +parallelism: + dp: 1 + expert_parallel_size: 1 + pp: 1 + pp_engine: 1f1b + tp: 1 + tp_linear_async_communication: true + tp_mode: REDUCE_SCATTER +profiler: null +tokenizer: + tokenizer_max_length: null + tokenizer_name_or_path: robot-test/dummy-tokenizer-wordlevel + tokenizer_revision: null +tokens: + batch_accumulation_per_replica: 1 + limit_test_batches: 0 + limit_val_batches: 5 + micro_batch_size: 2 + sequence_length: 256 + train_steps: 100 + val_check_interval: 1 diff --git a/examples/config_validation_tiny_llama_resume30.yaml b/examples/config_validation_tiny_llama_resume30.yaml new file mode 100644 index 00000000..f5096f7c --- /dev/null +++ b/examples/config_validation_tiny_llama_resume30.yaml @@ -0,0 +1,158 @@ +checkpoints: + checkpoint_interval: 10 + checkpoints_path: checkpoints + checkpoints_path_is_shared_file_system: false + resume_checkpoint_path: /home/kylematoba/cscs/nanotron/checkpoints/30 + save_initial_state: false +data_stages: +- data: + dataset: + dataset_overwrite_cache: false + dataset_processing_num_proc_per_process: 1 + hf_dataset_config_name: null + hf_dataset_or_datasets: stas/openwebtext-10k + hf_dataset_splits: train + text_column_name: text + num_loading_workers: 1 + seed: 42 + name: Training1 + start_training_step: 1 +- data: + dataset: + dataset_overwrite_cache: false + dataset_processing_num_proc_per_process: 1 + hf_dataset_config_name: null + hf_dataset_or_datasets: stas/openwebtext-10k + hf_dataset_splits: train + text_column_name: text + num_loading_workers: 1 + seed: 42 + name: Training2 + start_training_step: 12 +- data: + dataset: + dataset_overwrite_cache: false + dataset_processing_num_proc_per_process: 1 + hf_dataset_config_name: null + hf_dataset_or_datasets: stas/openwebtext-10k + hf_dataset_splits: train + text_column_name: text + num_loading_workers: 1 + seed: 42 + name: Training3 + start_training_step: 50 +valid_data_stages: +- data: + dataset: + dataset_overwrite_cache: false + dataset_processing_num_proc_per_process: 1 + hf_dataset_config_name: null + hf_dataset_or_datasets: stas/oscar-en-10k + hf_dataset_splits: train + text_column_name: text 
+ num_loading_workers: 1 + seed: 42 + name: Valid1 + start_training_step: 1 +- data: + dataset: + dataset_overwrite_cache: false + dataset_processing_num_proc_per_process: 1 + hf_dataset_config_name: null + hf_dataset_or_datasets: stas/oscar-en-10k + hf_dataset_splits: train + text_column_name: text + num_loading_workers: 1 + seed: 42 + name: Valid2 + start_training_step: 26 +- data: + dataset: + dataset_overwrite_cache: false + dataset_processing_num_proc_per_process: 1 + hf_dataset_config_name: null + hf_dataset_or_datasets: stas/oscar-en-10k + hf_dataset_splits: train + text_column_name: text + num_loading_workers: 1 + seed: 42 + name: Valid3 + start_training_step: 38 +general: + benchmark_csv_path: null + consumed_train_samples: null + ignore_sanity_checks: true + project: debug + run: tiny_llama_%date_%jobid + seed: 42 + step: null +lighteval: null +logging: + iteration_step_info_interval: 10 + log_level: info + log_level_replica: info +model: + ddp_bucket_cap_mb: 25 + dtype: float32 + init_method: + std: 0.025 + make_vocab_size_divisible_by: 1 + model_config: + bos_token_id: 1 + eos_token_id: 2 + hidden_act: silu + hidden_size: 16 + initializer_range: 0.02 + intermediate_size: 64 + is_llama_config: true + max_position_embeddings: 256 + num_attention_heads: 4 + num_hidden_layers: 2 + num_key_value_heads: 4 + pad_token_id: null + pretraining_tp: 1 + rms_norm_eps: 1.0e-05 + rope_scaling: null + tie_word_embeddings: true + use_cache: true + vocab_size: 256 +optimizer: + accumulate_grad_in_fp32: true + clip_grad: 1.0 + learning_rate_scheduler: + learning_rate: 0.0003 + lr_decay_starting_step: null + lr_decay_steps: 13 + lr_decay_style: cosine + lr_warmup_steps: 2 + lr_warmup_style: linear + min_decay_lr: 1.0e-05 + optimizer_factory: + adam_beta1: 0.9 + adam_beta2: 0.95 + adam_eps: 1.0e-08 + name: adamW + torch_adam_is_fused: true + weight_decay: 0.01 + zero_stage: 0 +parallelism: + dp: 1 + expert_parallel_size: 1 + pp: 1 + pp_engine: 1f1b + tp: 1 + tp_linear_async_communication: true + tp_mode: REDUCE_SCATTER +profiler: null +tokenizer: + tokenizer_max_length: null + tokenizer_name_or_path: robot-test/dummy-tokenizer-wordlevel + tokenizer_revision: null +tokens: + batch_accumulation_per_replica: 1 + limit_test_batches: 0 + limit_val_batches: 5 + micro_batch_size: 2 + sequence_length: 256 + train_steps: 100 + val_check_interval: 1 diff --git a/examples/custom-dataloader/README.md b/examples/custom-dataloader/README.md deleted file mode 100644 index 9ded4b3a..00000000 --- a/examples/custom-dataloader/README.md +++ /dev/null @@ -1,39 +0,0 @@ -# Use a custom dataloader with Nanotron - -This example shows how to use a custom dataloader with Nanotron. We will use a simple dataloader that loads a random tokenized dataset and feeds it to a Nanotron model. -https://github.com/huggingface/nanotron/blob/2e21db0db46a40bedbd03714616dd0ae4ea75914/examples/custom-dataloader/run_train.py#L72-L84 - -`DataCollatorForCLM` is a custom data collator that takes a list of input_ids and returns a dictionary with the input_ids and the labels on the ranks which need it. For example `input_ids` are only needed in the first PP rank, while `labels` are needed in the last PP rank. 
- -And to test it out, you should fix your config to have: (example: [config_custom_dl.yaml](config_custom_dl.yaml)) -```yaml -- data: - dataset: null # Custom dataloader will be used - num_loading_workers: 1 - seed: 42 - name: Stable Training Stage - start_training_step: 1 -``` - -To try it out you can run the following command: - -```bash -export CUDA_DEVICE_MAX_CONNECTIONS=1 # important for some distributed operations -torchrun --nproc_per_node=2 examples/custom-dataloader/run_train.py --config-file examples/custom-dataloader/config_custom_dl.yaml -``` - -## Troubleshooting - -### `return torch.embedding(weight, input, padding_idx, scale_grad_by_freq, sparse)` -``` - File "/fsx/nouamane/projects/nanotron/src/nanotron/parallel/tensor_parallel/nn.py", line 284, in forward - out = super().forward(masked_input) - File "/fsx/nouamane/miniconda/envs/2-1-cu121/lib/python3.10/site-packages/torch/nn/modules/sparse.py", line 162, in forward - return F.embedding( - File "/fsx/nouamane/miniconda/envs/2-1-cu121/lib/python3.10/site-packages/torch/nn/functional.py", line 2233, in embedding - return torch.embedding(weight, input, padding_idx, scale_grad_by_freq, sparse) -RuntimeError: CUDA error: device-side assert triggered -Compile with `TORCH_USE_CUDA_DSA` to enable device-side assertions. -``` - -If you encounter an error with `torch.embedding`, it's probable you're feeding a token which is bigger than the model's vocabulary size. Check your model's vocab size and tokenizer diff --git a/examples/custom-dataloader/run_train.py b/examples/custom-dataloader/run_train.py deleted file mode 100644 index e1995381..00000000 --- a/examples/custom-dataloader/run_train.py +++ /dev/null @@ -1,222 +0,0 @@ -""" -Nanotron training script example using a custom dataloader. - -Usage: -``` -export CUDA_DEVICE_MAX_CONNECTIONS=1 # important for some distributed operations -torchrun --nproc_per_node=2 examples/custom-dataloader/run_train.py --config-file examples/custom-dataloader/config_custom_dl.yaml -``` -""" -import argparse -from typing import Dict, cast - -import datasets -import numpy as np -from nanotron import logging -from nanotron.config import ( - DataArgs, - DatasetStageArgs, - PretrainDatasetsArgs, -) -from nanotron.dataloader import ( - DataCollatorForCLM, - clm_process, - get_dataloader_worker_init, - get_datasets, - get_train_dataloader, -) -from nanotron.helpers import ( - compute_remain_train_steps_of_a_data_stage_from_ckp, - get_consumed_train_samples_of_a_data_stage_from_ckp, -) -from nanotron.logging import log_rank -from nanotron.parallel.pipeline_parallel.utils import get_input_output_pp_ranks -from nanotron.trainer import DistributedTrainer -from nanotron.utils import main_rank_first -from torch.utils.data import DataLoader - -try: - from huggingface_hub import __version__ as hf_hub_version - from transformers import AutoTokenizer - from transformers import __version__ as tf_version -except ImportError: - hf_hub_version = None - tf_version = None - -logger = logging.get_logger(__name__) - - -def get_dataloader_from_data_stage( - trainer: DistributedTrainer, - data: DataArgs, - consumed_train_samples: int, - num_remaining_train_steps: int, -): - """ - Returns a dataloader for a given data stage. - - data: The data configuration for the current stage. - consumed_train_samples: The number of samples consumed by the model in the this stage (each stage starts from zero). - num_remaining_train_steps: The number of remaining training steps for this stage. 
- """ - assert consumed_train_samples >= 0, "consumed_train_samples should be greater than 0" - assert num_remaining_train_steps >= 0, "num_remaining_train_steps should be greater than 0" - - # First, we need to know which ranks to feed the dataloader to - input_pp_rank, output_pp_rank = get_input_output_pp_ranks(model=trainer.model) - - # Case 1: custom data generator - if data.dataset is None: - log_rank("Using custom data generator", logger=logger, level=logging.INFO, rank=0) - - ########################################################################################################### - # This can be replaced with your own tokenized data generator - ########################################################################################################### - train_dataset = datasets.Dataset.from_dict( - { - "input_ids": np.random.randint( - 0, - trainer.config.model.model_config.vocab_size, - (trainer.global_batch_size * num_remaining_train_steps, trainer.sequence_length + 1), - ), - } - ) - ########################################################################################################### - - data_collator = DataCollatorForCLM( - sequence_length=trainer.sequence_length, - input_pp_rank=input_pp_rank, - output_pp_rank=output_pp_rank, - parallel_context=trainer.parallel_context, - ) - - return DataLoader( - train_dataset, - batch_size=trainer.micro_batch_size, - collate_fn=data_collator, - drop_last=True, - num_workers=0, - pin_memory=True, - worker_init_fn=get_dataloader_worker_init(dp_rank=trainer.parallel_context.dp_pg.rank()), - ) - - # Case 2: HuggingFace datasets - elif isinstance(data.dataset, PretrainDatasetsArgs): - log_rank("Using `datasets` library", logger=logger, level=logging.INFO, rank=0) - tokenizer_path = trainer.config.tokenizer.tokenizer_name_or_path - log_rank( - f"Loading tokenizer from {tokenizer_path} and transformers/hf_hub versions {tf_version, hf_hub_version}", - logger=logger, - level=logging.INFO, - rank=0, - ) - - # We need to the 1st device to process dataset and cache it, then other devices load from cache - with main_rank_first(trainer.parallel_context.world_pg): - # We load the raw dataset - raw_dataset = get_datasets( - hf_dataset_or_datasets=data.dataset.hf_dataset_or_datasets, - hf_dataset_config_name=data.dataset.hf_dataset_config_name, - splits=data.dataset.hf_dataset_splits, - )["train"] - - tokenizer = AutoTokenizer.from_pretrained(tokenizer_path) - tokenizer.pad_token = tokenizer.eos_token - tokenizer.padding_side = "left" - - # We apply the Causal Language Modeling preprocessing - train_dataset = clm_process( - raw_dataset=raw_dataset, - tokenizer=tokenizer, - text_column_name=data.dataset.text_column_name, - dataset_processing_num_proc_per_process=data.dataset.dataset_processing_num_proc_per_process, - dataset_overwrite_cache=data.dataset.dataset_overwrite_cache, - sequence_length=trainer.sequence_length, - ) - - # We load the processed dataset on the ranks requiring it - dataloader = get_train_dataloader( - train_dataset=train_dataset, - sequence_length=trainer.sequence_length, - parallel_context=trainer.parallel_context, - input_pp_rank=input_pp_rank, - output_pp_rank=output_pp_rank, - micro_batch_size=trainer.micro_batch_size, - consumed_train_samples=consumed_train_samples, - dataloader_num_workers=data.num_loading_workers, - seed_worker=data.seed, - dataloader_drop_last=True, - ) - - # Check if we have enough samples for train_steps - total_tokens_dataset = len(dataloader.dataset) * trainer.sequence_length - 
num_tokens_needed_for_training = ( - num_remaining_train_steps * trainer.global_batch_size * trainer.sequence_length - ) - assert num_tokens_needed_for_training <= total_tokens_dataset, ( - f"Dataset is too small for steps ({total_tokens_dataset} < {num_tokens_needed_for_training}), " - f"Try train_steps<={len(dataloader.dataset) // trainer.global_batch_size + trainer.iteration_step}" - ) - else: - raise ValueError(f"Unhandled case of `self.config.data.dataset`. Got: {data.dataset}") - - return dataloader - - -def get_dataloader(trainer: DistributedTrainer) -> Dict[str, DataLoader]: - dataloaders = {} - - for stage_idx, stage in enumerate(trainer.config.data_stages): - # NOTE: we only create the dataloader for the first stage, - # then we lazy initialize the dataloader for the other stages - stage = cast(DatasetStageArgs, stage) - consumed_train_samples = get_consumed_train_samples_of_a_data_stage_from_ckp(stage, trainer.metadata) - assert ( - consumed_train_samples is not None - ), f"Cannot find consumed_train_samples for stage {stage.start_training_step} in the checkpoint" - - num_remaining_train_steps = compute_remain_train_steps_of_a_data_stage_from_ckp( - stage, trainer.config, trainer.metadata - ) - log_rank( - f"[Training Plan] Stage {stage.name} has {num_remaining_train_steps} remaining training steps and has consumed {consumed_train_samples} samples", - logger=logger, - level=logging.INFO, - rank=0, - ) - - dataloader = ( - get_dataloader_from_data_stage( - trainer, - stage.data, - consumed_train_samples=consumed_train_samples, - num_remaining_train_steps=num_remaining_train_steps, - ) - if stage_idx == 0 - else lambda stage=stage: get_dataloader_from_data_stage( - trainer, - stage.data, - consumed_train_samples=consumed_train_samples, - num_remaining_train_steps=num_remaining_train_steps, - ) - ) - dataloaders[stage.name] = dataloader - return dataloaders - - -def get_args(): - parser = argparse.ArgumentParser() - parser.add_argument("--config-file", type=str, required=True, help="Path to the YAML or python config file") - return parser.parse_args() - - -if __name__ == "__main__": - args = get_args() - config_file = args.config_file - - # Load trainer and data - trainer = DistributedTrainer(config_file) - dataloader = get_dataloader(trainer) - - # Train - trainer.train(dataloader) diff --git a/reconcile_rotary_embeddings.py b/reconcile_rotary_embeddings.py new file mode 100644 index 00000000..0bbe195c --- /dev/null +++ b/reconcile_rotary_embeddings.py @@ -0,0 +1,91 @@ +import torch +import torchtune +import flash_attn +import flash_attn.layers.rotary + + +class RotaryEmbeddingKyleLikeFA(torch.nn.Module): + """ + Has the same function signature as FA, for interleaved=True and separate q, kv. + seqlen_offset = 0 + Does not operate inplace, but that's fine for how it's used in Nanotron. 
+ """ + def __init__(self, dim: int, base: float): + super().__init__() + self.dim = dim + self.base = float(base) + + self.max_seq_len = None + self.rpe = None + + def forward(self, q, kv): + bs, q_len, n_heads, _ = q.shape + assert self.dim == _ + + assert (bs, q_len, 2, n_heads, self.dim) == kv.shape + + if (self.rpe is None) or (self.max_seq_len != q_len): + self.max_seq_len = q_len + self.rpe = torchtune.modules.RotaryPositionalEmbeddings(dim=self.dim, + max_seq_len=self.max_seq_len, + base=self.base).to(device) + q_out = self.rpe(q) + kv_out = torch.stack((self.rpe(kv[:, :, 0]), kv[:, :, 1]), 2) + return q_out, kv_out + + + +if __name__ == "__main__": + device = torch.device(0) + theta = 10000 + + batch_size = 3 + dim_qk = 4 + q_len = 256 + kv_len = 256 + n_heads = 4 + + max_seq_len = max(q_len, kv_len) + + print(max_seq_len) + + + query_states = torch.rand(batch_size, q_len, n_heads, dim_qk, device=device) + key_value_states = torch.rand(batch_size, kv_len, 2, n_heads, dim_qk, device=device).contiguous() + + + interleaved = True + # interleaved = False + re1 = flash_attn.layers.rotary.RotaryEmbedding(dim=dim_qk, interleaved=interleaved, base=theta).to(device) + re2 = torchtune.modules.RotaryPositionalEmbeddings(dim=dim_qk, max_seq_len=max_seq_len, base=theta).to(device) + re3 = RotaryEmbeddingKyleLikeFA(dim=dim_qk, base=theta).to(device) + + + + print(key_value_states[:, :, 0].shape) + + out2 = re2(query_states) + out3 = re2(key_value_states[:, :, 0]) + # out4 = re2(key_value_states[:, :, 1]) + + out_eq = re3(query_states, kv=key_value_states) + + # torch.testing.assert_close(out2, query_states) + out1 = re1(query_states, kv=key_value_states) + + torch.testing.assert_close(out_eq[0], out1[0]) + torch.testing.assert_close(out_eq[1], out1[1]) + + + # Do this second, since the computation is inplace + torch.testing.assert_close(out1[0], query_states) + + test = torch.stack((out3, key_value_states[:, :, 1]), 2) + torch.testing.assert_close(out1[1], test) + # torch.testing.assert_close(out1[1][:, :, 0], out3) + + + torch.testing.assert_close(out1[0], out2) + + print("done") + diff --git a/run_train.py b/run_train.py index 021d955d..8447e391 100644 --- a/run_train.py +++ b/run_train.py @@ -10,6 +10,9 @@ import argparse from typing import Dict, cast +import warnings +# warnings.filterwarnings("error") + import numpy as np from nanotron import logging from nanotron.config import DataArgs, DatasetStageArgs, NanosetDatasetsArgs, PretrainDatasetsArgs @@ -178,44 +181,51 @@ def get_dataloader_from_data_stage( return dataloader -def get_dataloader(trainer: DistributedTrainer) -> Dict[str, DataLoader]: +def get_dataloader(trainer: DistributedTrainer, + data_stages_fieldname: str, + metadata_fieldname: str) -> Dict[str, DataLoader]: dataloaders = {} + data_stages = getattr(trainer.config, data_stages_fieldname) + metadata = getattr(trainer, metadata_fieldname) + + if data_stages and metadata: + for stage_idx, stage in enumerate(data_stages): + # NOTE: we only create the dataloader for the first stage, + # then we lazy initialize the dataloader for the other stages + stage = cast(DatasetStageArgs, stage) + consumed_train_samples = get_consumed_train_samples_of_a_data_stage_from_ckp(stage, metadata) + assert ( + consumed_train_samples is not None + ), f"Cannot find consumed_train_samples for stage {stage.start_training_step} in the checkpoint" - for stage_idx, stage in enumerate(trainer.config.data_stages): - # NOTE: we only create the dataloader for the first stage, - # then we lazy initialize the 
dataloader for the other stages - stage = cast(DatasetStageArgs, stage) - consumed_train_samples = get_consumed_train_samples_of_a_data_stage_from_ckp(stage, trainer.metadata) - assert ( - consumed_train_samples is not None - ), f"Cannot find consumed_train_samples for stage {stage.start_training_step} in the checkpoint" - - num_remaining_train_steps = compute_remain_train_steps_of_a_data_stage_from_ckp( - stage, trainer.config, trainer.metadata - ) - log_rank( - f"[Training Plan] Stage {stage.name} has {num_remaining_train_steps} remaining training steps and has consumed {consumed_train_samples} samples", - logger=logger, - level=logging.INFO, - rank=0, - ) - - dataloader = ( - get_dataloader_from_data_stage( - trainer, - stage.data, - consumed_train_samples=consumed_train_samples, - num_remaining_train_steps=num_remaining_train_steps, + num_remaining_train_steps = compute_remain_train_steps_of_a_data_stage_from_ckp( + stage, data_stages, trainer.config.tokens, metadata ) - if stage_idx == 0 - else lambda stage=stage: get_dataloader_from_data_stage( - trainer, - stage.data, - consumed_train_samples=consumed_train_samples, - num_remaining_train_steps=num_remaining_train_steps, + log_rank( + f"[{data_stages_fieldname} Plan] Stage {stage.name} has {num_remaining_train_steps} remaining training steps and has consumed {consumed_train_samples} samples", + logger=logger, + level=logging.INFO, + rank=0, ) - ) - dataloaders[stage.name] = dataloader + + dataloader = ( + get_dataloader_from_data_stage( + trainer, + stage.data, + consumed_train_samples=consumed_train_samples, + num_remaining_train_steps=num_remaining_train_steps, + ) + if stage_idx == 0 + else lambda stage=stage: get_dataloader_from_data_stage( + trainer, + stage.data, + consumed_train_samples=consumed_train_samples, + num_remaining_train_steps=num_remaining_train_steps, + ) + ) + dataloaders[stage.name] = dataloader + else: + dataloaders = None return dataloaders @@ -231,7 +241,9 @@ def get_args(): # Load trainer and data trainer = DistributedTrainer(config_file) - dataloader = get_dataloader(trainer) + dataloader_valid = get_dataloader(trainer, "valid_data_stages", "valid_metadata") + dataloader_train = get_dataloader(trainer, "data_stages", "metadata") # Train - trainer.train(dataloader) + trainer.train(dataloader_train, dataloader_valid) + print("Done") diff --git a/src/nanotron/config/config.py b/src/nanotron/config/config.py index adc1eafd..82cca48f 100644 --- a/src/nanotron/config/config.py +++ b/src/nanotron/config/config.py @@ -338,6 +338,7 @@ class Config: tokens: Optional[TokensArgs] = None optimizer: Optional[OptimizerArgs] = None data_stages: Optional[List[DatasetStageArgs]] = None + valid_data_stages: Optional[List[DatasetStageArgs]] = None profiler: Optional[ProfilerArgs] = None lighteval: Optional[LightEvalConfig] = None diff --git a/src/nanotron/constants.py b/src/nanotron/constants.py index 580bd99d..b31142de 100644 --- a/src/nanotron/constants.py +++ b/src/nanotron/constants.py @@ -2,7 +2,7 @@ from packaging.version import Version, parse -CHECKPOINT_VERSION = Version("1.4") +CHECKPOINT_VERSION = Version("1.5") PY_VERSION = parse(platform.python_version()) diff --git a/src/nanotron/helpers.py b/src/nanotron/helpers.py index a82f0294..705d46d1 100644 --- a/src/nanotron/helpers.py +++ b/src/nanotron/helpers.py @@ -679,19 +679,22 @@ def log_throughput( def compute_remain_train_steps_of_a_data_stage_from_ckp( - stage: DatasetStageArgs, config: Config, metadata: TrainingMetadata + stage: DatasetStageArgs, + data_stages, 
+ tokens, + metadata: TrainingMetadata ) -> int: def is_last_stage(): - sorted_stages = sorted(config.data_stages, key=lambda x: x.start_training_step) + sorted_stages = sorted(data_stages, key=lambda x: x.start_training_step) return sorted_stages[-1].start_training_step == stage.start_training_step def is_resume_from_training(): return metadata.last_train_step > 0 if is_last_stage() is True: - total_train_steps = config.tokens.train_steps + total_train_steps = tokens.train_steps else: - next_stage = next((s for s in config.data_stages if s.start_training_step > stage.start_training_step), None) + next_stage = next((s for s in data_stages if s.start_training_step > stage.start_training_step), None) total_train_steps = next_stage.start_training_step if metadata.last_train_step > stage.start_training_step: @@ -708,7 +711,8 @@ def get_consumed_train_samples_of_a_data_stage_from_ckp( stage: DatasetStageArgs, metadata: TrainingMetadata ) -> Optional[int]: start_training_step = stage.start_training_step - return next( + out = next( (s.consumed_train_samples for s in metadata.data_stages if s.start_training_step == start_training_step), None, ) + return out diff --git a/src/nanotron/models/llama.py b/src/nanotron/models/llama.py index 28a2e30f..3206bce0 100644 --- a/src/nanotron/models/llama.py +++ b/src/nanotron/models/llama.py @@ -44,6 +44,9 @@ from nanotron.scaling.parametrization import SpectralMupParametrizator, StandardParametrizator from nanotron.utils import checkpoint_method +use_flash_attn = False +# use_flash_attn = True + logger = logging.get_logger(__name__) @@ -165,7 +168,12 @@ def __init__( bias=False, async_communication=tp_linear_async_communication and tp_mode is TensorParallelLinearMode.REDUCE_SCATTER, ) - self.split_silu_mul = torch.compile(GLUActivation(config.hidden_act)) + # do_compile = True + do_compile = False + # self.split_silu_mul = torch.compile(GLUActivation(config.hidden_act)) + self.split_silu_mul = GLUActivation(config.hidden_act) + if do_compile: + self.split_silu_mul = torch.compile(self.split_silu_mul) def forward(self, hidden_states): # [seq_length, batch_size, hidden_dim] merged_states = self.gate_up_proj(hidden_states) @@ -193,23 +201,32 @@ def forward( key_states: torch.Tensor, # [batch_size, kv_length, n_local_kv_heads, inner_dim] value_states: torch.Tensor, # [batch_size, kv_length, n_local_kv_heads, inner_dim] ): - from flash_attn.flash_attn_interface import flash_attn_func - - # NOTE: this scale is for µTransfer, - # in SP, we use sqrt(1/d_h) - softmax_scale = 1 / query_states.shape[-1] if self.is_using_mup else None - # For now we are assuming that we use causual mask. No magic here - causal = True - attn_output = flash_attn_func( - q=query_states, - k=key_states, - v=value_states, - dropout_p=0.0, - softmax_scale=softmax_scale, - causal=causal, - return_attn_probs=False, - ) - + if use_flash_attn: + from flash_attn.flash_attn_interface import flash_attn_func + + # NOTE: this scale is for µTransfer, + # in SP, we use sqrt(1/d_h) + softmax_scale = 1 / query_states.shape[-1] if self.is_using_mup else None + # For now we are assuming that we use causual mask. 
No magic here + causal = True + attn_output = flash_attn_func( + q=query_states, + k=key_states, + v=value_states, + dropout_p=0.0, + softmax_scale=softmax_scale, + causal=causal, + return_attn_probs=False, + ) + else: + assert not self.is_using_mup, "have not tested this" + attn_output = torch.nn.functional.scaled_dot_product_attention( + query_states.permute(0, 2, 1, 3), + key_states.permute(0, 2, 1, 3), + value_states.permute(0, 2, 1, 3), + dropout_p=0.0, + is_causal=True, + ).permute(0, 2, 1, 3) return attn_output @@ -241,6 +258,39 @@ def pad_to_right(tensor, mask, new_tensor=None): return new_tensor, right_padded_mask +class RotaryEmbeddingKyleLikeFA(torch.nn.Module): + """ + Has the same function signature as FA, for interleaved=True and separate q, kv. + seqlen_offset = 0 + Does not operate inplace, but that's fine for how it's used in Nanotron. + """ + def __init__(self, dim: int, base: float): + super().__init__() + self.dim = dim + self.base = float(base) + + self.max_seq_len = None + self.rpe = None + + from torchtune.modules import RotaryPositionalEmbeddings + self.rpe_builder = RotaryPositionalEmbeddings + + def forward(self, q, kv): + bs, q_len, n_heads, _ = q.shape + assert self.dim == _ + + assert (bs, q_len, 2, n_heads, self.dim) == kv.shape + + if (self.rpe is None) or (self.max_seq_len != q_len): + self.max_seq_len = q_len + self.rpe = self.rpe_builder(dim=self.dim, + max_seq_len=self.max_seq_len, + base=self.base).to(q.device) + q_out = self.rpe(q) + kv_out = torch.stack((self.rpe(kv[:, :, 0]), kv[:, :, 1]), 2) + return q_out, kv_out + + class CausalSelfAttention(nn.Module, AttachableStore): def __init__( self, @@ -249,8 +299,6 @@ def __init__( tp_pg: dist.ProcessGroup, layer_idx: int, ): - from flash_attn.layers.rotary import RotaryEmbedding as FlashRotaryEmbedding - super().__init__() # Tensor parallel considerations: We split tensors along head dimension assert ( @@ -310,11 +358,15 @@ def __init__( end=config.max_position_embeddings, theta=config.rope_theta, ) - # NOTE: Only supported for training (TODO(fmom): position_ids not supported yet) - self.flash_rotary_embedding = FlashRotaryEmbedding( - dim=self.d_qk, interleaved=config.rope_interleaved, base=config.rope_theta - ) + if use_flash_attn: + from flash_attn.layers.rotary import RotaryEmbedding as FlashRotaryEmbedding + self.flash_rotary_embedding = FlashRotaryEmbedding( + dim=self.d_qk, interleaved=config.rope_interleaved, base=config.rope_theta + ) + else: + assert config.rope_interleaved, "this case not yet tested" + self.flash_rotary_embedding = RotaryEmbeddingKyleLikeFA(dim=self.d_qk, base=config.rope_theta) self.o_proj = TensorParallelRowLinear( config.num_attention_heads * self.d_qk, @@ -340,12 +392,6 @@ def forward( hidden_states, # [seq_length, batch_size, hidden_size] sequence_mask, # [batch_size, seq_length] ): - from flash_attn import bert_padding - from flash_attn.flash_attn_interface import ( - flash_attn_varlen_func, - flash_attn_with_kvcache, - ) - qkv_states = self.qkv_proj( hidden_states ) # [seq_length, batch_size, n_local_q_heads * d_qk + 2 * n_local_kv_heads * d_qk] @@ -397,6 +443,8 @@ def forward( key_states = self.rotary_embedding(key_states, position_ids=position_ids) if "key" not in store: + from flash_attn import bert_padding + from flash_attn.flash_attn_interface import flash_attn_varlen_func # First inference iteration (Prefill) # TODO @nouamane: support custom masking # assert that [ False, False, False, False, True, True, True, True, True, True] is accepted @@ -457,6 +505,8 @@ def 
forward( pad_to_right(value_states, sequence_mask, new_tensor=v_cache) else: + from flash_attn.flash_attn_interface import flash_attn_with_kvcache + # Pull pre-computed key/value states # Subsequent inference iterations (q_length=1) k_cache = store["key"] diff --git a/src/nanotron/optim/gradient_accumulator.py b/src/nanotron/optim/gradient_accumulator.py index 2e940744..4dc85ca0 100644 --- a/src/nanotron/optim/gradient_accumulator.py +++ b/src/nanotron/optim/gradient_accumulator.py @@ -181,7 +181,7 @@ def build_grad_buffers( if not param.requires_grad: continue - assert param.dtype != torch.float, f"Expected {name} not to be float" + # assert param.dtype != torch.float, f"Expected {name} not to be float" assert param.is_contiguous(), f"Expected {name} to be contiguous" next_offset = offset + param.numel() * element_size diff --git a/src/nanotron/optim/zero.py b/src/nanotron/optim/zero.py index cb61c8b7..9f9b1fec 100644 --- a/src/nanotron/optim/zero.py +++ b/src/nanotron/optim/zero.py @@ -6,7 +6,10 @@ import numpy as np import torch.optim -from functorch.dim import tree_map +try: + from functorch.dim import tree_map +except: + from torch.utils._pytree import tree_map from torch import nn from tqdm import tqdm diff --git a/src/nanotron/parallel/pipeline_parallel/engine.py b/src/nanotron/parallel/pipeline_parallel/engine.py index ca9df312..bb05055f 100644 --- a/src/nanotron/parallel/pipeline_parallel/engine.py +++ b/src/nanotron/parallel/pipeline_parallel/engine.py @@ -9,7 +9,7 @@ from nanotron.optim.gradient_accumulator import GradientAccumulator from nanotron.parallel.data_parallel.utils import ddp_trigger_sync_in_bwd from nanotron.parallel.pipeline_parallel.context_manager import attach_pipeline_state_to_model -from nanotron.parallel.pipeline_parallel.state import PipelineTrainBatchState +from nanotron.parallel.pipeline_parallel.state import PipelineTrainBatchState, PipelineEvalBatchState from nanotron.parallel.pipeline_parallel.tensor_pointer import TensorPointer from nanotron.utils import ContextManagers from torch import nn as torch_nn @@ -53,7 +53,7 @@ def forward( # Add output as activations that require backward pass if not isinstance(output["loss"], TensorPointer): - assert output["loss"].requires_grad + # assert output["loss"].requires_grad state.register_activation_requiring_backward(output["loss"]) return output @@ -134,9 +134,9 @@ def validate_batch_iter( nb_microbatches: int, ) -> Iterable[Dict[str, Union[torch.Tensor, TensorPointer]]]: # Assign a new state for the current batch - state = PipelineTrainBatchState() # TODO: do i need state? + # state = PipelineTrainBatchState() # TODO: do i need state? 
+ state = PipelineEvalBatchState() self.nb_microbatches = nb_microbatches - outputs = [] with attach_pipeline_state_to_model(model=model, pipeline_state=state): diff --git a/src/nanotron/parallel/pipeline_parallel/state.py b/src/nanotron/parallel/pipeline_parallel/state.py index e07cc89a..87fabc19 100644 --- a/src/nanotron/parallel/pipeline_parallel/state.py +++ b/src/nanotron/parallel/pipeline_parallel/state.py @@ -202,6 +202,8 @@ class PipelineEvalBatchState(PipelineBatchState): microbatches_activations_to_send = collections.deque() microbatches_activations_to_recv = collections.deque() activations_buffer = collections.deque() + # Reinitialise counter + nb_forwards = 0 def register_activation_requiring_backward(self, activation: torch.Tensor): pass diff --git a/src/nanotron/serialize/main.py b/src/nanotron/serialize/main.py index 346ad573..816f044d 100644 --- a/src/nanotron/serialize/main.py +++ b/src/nanotron/serialize/main.py @@ -54,6 +54,7 @@ def save( lr_scheduler: torch.optim.lr_scheduler.LRScheduler, parallel_context: ParallelContext, training_metadata: TrainingMetadata, + valid_metadata: TrainingMetadata, root_folder: Path, should_save_config: bool = True, should_save_model: bool = True, @@ -62,6 +63,7 @@ def save( sanity_checks: bool = True, ) -> None: assert isinstance(training_metadata, TrainingMetadata) + assert (valid_metadata is None) or isinstance(valid_metadata, TrainingMetadata) try: if should_save_config: @@ -118,7 +120,10 @@ def save( ) raise e - save_meta(root_folder=root_folder, parallel_context=parallel_context, training_metadata=training_metadata) + save_meta(root_folder=root_folder, + parallel_context=parallel_context, + training_metadata=training_metadata, + valid_metadata=valid_metadata) # TODO @thomas21: sanity check, not sure whether that needs to happen at testing or now (depends how much it costs) ### diff --git a/src/nanotron/serialize/metadata.py b/src/nanotron/serialize/metadata.py index 0d8708f9..6c55d5e6 100644 --- a/src/nanotron/serialize/metadata.py +++ b/src/nanotron/serialize/metadata.py @@ -63,8 +63,8 @@ class CheckpointMetadata: version: Version tp: int dp: int - metas: TrainingMetadata - custom_metas: Optional[Dict[str, Any]] = None + train_meta: TrainingMetadata + valid_meta: Optional[TrainingMetadata] @dataclasses.dataclass @@ -125,8 +125,12 @@ def to_list(list_: Union[List, Tuple], type_hooks: Dict[Type, Callable[[Any], An return list_.__class__((process_type(elt, type_hooks=type_hooks) for elt in list_)) -def save_meta(parallel_context: ParallelContext, root_folder: Path, training_metadata: TrainingMetadata): +def save_meta(parallel_context: ParallelContext, + root_folder: Path, + training_metadata: TrainingMetadata, + valid_metadata: TrainingMetadata): assert isinstance(training_metadata, TrainingMetadata) + assert (valid_metadata is None) or isinstance(valid_metadata, TrainingMetadata) if dist.get_rank(parallel_context.world_pg) != 0: return @@ -136,7 +140,8 @@ def save_meta(parallel_context: ParallelContext, root_folder: Path, training_met version=CHECKPOINT_VERSION, tp=parallel_context.tp_pg.size(), dp=parallel_context.dp_pg.size(), - metas=training_metadata, + train_meta=training_metadata, + valid_meta=valid_metadata ) # There are some types that require manual casting in order to work correctly. 
diff --git a/src/nanotron/serialize/optimizer.py b/src/nanotron/serialize/optimizer.py index f11210da..fe0aa2e3 100644 --- a/src/nanotron/serialize/optimizer.py +++ b/src/nanotron/serialize/optimizer.py @@ -282,7 +282,7 @@ def get_checkpoint_state_metadata(param_name: str, pp_rank: int, tp_rank: int) - state_dict = torch.load( root_folder / optimizer_filename(parallel_context, is_zero=optimizer.inherit_from(optim.ZeroDistributedOptimizer)), - map_location=map_location, + map_location=map_location, weights_only=True ) if isinstance(optimizer, ZeroDistributedOptimizer): @@ -315,5 +315,5 @@ def load_lr_scheduler( ): root_folder = root_folder / "lr_scheduler" - state_dict = torch.load(root_folder / lr_scheduler_filename(parallel_context)) + state_dict = torch.load(root_folder / lr_scheduler_filename(parallel_context), weights_only=True) lr_scheduler.load_state_dict(state_dict) diff --git a/src/nanotron/trainer.py b/src/nanotron/trainer.py index 21251a32..3bec6c70 100644 --- a/src/nanotron/trainer.py +++ b/src/nanotron/trainer.py @@ -3,7 +3,11 @@ import os import shutil import time +import gc + from dataclasses import asdict +from collections.abc import Generator + from pathlib import Path from pprint import pformat from typing import ( @@ -93,6 +97,8 @@ ) from nanotron.serialize.metadata import DataStageMetadata, TrainingMetadata from nanotron.serialize.optimizer import load_optimizer +dataloader_arg = Dict[str, Union[Iterator[Dict[str, Union[torch.Tensor, TensorPointer]]], Tuple[Iterator, ...]]] + logger = logging.get_logger(__name__) @@ -188,6 +194,12 @@ def __init__( optimizer_args=self.config.optimizer, parallel_context=self.parallel_context, ) + # Init learning rate scheduler + self.lr_scheduler = lr_scheduler_builder( + optimizer=self.optimizer, + lr_scheduler_args=self.config.optimizer.learning_rate_scheduler, + total_training_steps=self.config.tokens.train_steps, + ) if self.init_checkpoint_path is not None: load_optimizer( optimizer=self.optimizer, @@ -197,27 +209,32 @@ def __init__( model=self.model, ) - # Init learning rate scheduler - self.lr_scheduler = lr_scheduler_builder( - optimizer=self.optimizer, - lr_scheduler_args=self.config.optimizer.learning_rate_scheduler, - total_training_steps=self.config.tokens.train_steps, - ) if self.init_checkpoint_path is not None: load_lr_scheduler( lr_scheduler=self.lr_scheduler, parallel_context=self.parallel_context, root_folder=self.init_checkpoint_path, ) + # Update optimizer learning rate because otherwise it is set to zero in the first iteration. 
+ param_groups = self.optimizer.get_base_optimizer().param_groups + last_lrs = self.lr_scheduler.get_last_lr() + assert len(param_groups) == len(last_lrs) + for group, last_lr in zip(param_groups, last_lrs): + assert "lr" in group + group["lr"] = last_lr # Define iteration start state if self.init_checkpoint_path is not None: checkpoint_metadata = load_meta( parallel_context=self.parallel_context, root_folder=self.init_checkpoint_path ) - assert isinstance(checkpoint_metadata.metas, TrainingMetadata) + assert isinstance(checkpoint_metadata.train_meta, TrainingMetadata) + assert (checkpoint_metadata.valid_meta is None) or isinstance(checkpoint_metadata.valid_meta, TrainingMetadata) + log_rank(str(checkpoint_metadata), logger=logger, level=logging.INFO, rank=0) - self.metadata: TrainingMetadata = checkpoint_metadata.metas + self.metadata: TrainingMetadata = checkpoint_metadata.train_meta + self.valid_metadata: TrainingMetadata = checkpoint_metadata.valid_meta + # NOTE: we should not change data stages assert ( self.config.tokens.train_steps > self.metadata.last_train_step @@ -232,7 +249,18 @@ def __init__( self.metadata: TrainingMetadata = TrainingMetadata( consumed_train_samples=0, last_train_step=0, last_stage_idx=0, data_stages=data_stages ) - + if self.config.valid_data_stages: + valid_data_stages = [ + DataStageMetadata( + name=stage.name, start_training_step=stage.start_training_step, consumed_train_samples=0 + ) + for stage in self.config.valid_data_stages + ] + self.valid_metadata: TrainingMetadata = TrainingMetadata( + consumed_train_samples=0, last_train_step=0, last_stage_idx=0, data_stages=valid_data_stages + ) + else: + self.valid_metadata = None # Setup tensorboard write and log writers on output rank self.logger_ranks = self.parallel_context.get_global_rank( ep_rank=0, pp_rank=self.unwrapped_model.output_pp_rank, dp_rank=0, tp_rank=0 @@ -250,8 +278,11 @@ def __init__( self.sequence_length = self.config.tokens.sequence_length self.iteration_step = self.metadata.last_train_step self.limit_val_batches = self.config.tokens.limit_val_batches + self.val_check_interval = self.config.tokens.val_check_interval + # NOTE: the dataloader currently in use for the current training stage self.current_dataloader: Optional[DataLoader] = None + self.current_valid_dataloader: Optional[DataLoader] = None self.post_init() @@ -263,7 +294,6 @@ def post_init(self): def pre_training(self, *args, **kwargs): self._print_training_plan() - metadata: TrainingMetadata = self.metadata log_rank( @@ -299,9 +329,20 @@ def _print_training_plan(self): ) log_rank(full_log_message, logger=logger, level=logging.INFO, rank=0) - def _update_dataloader_based_on_training_stages(self, dataloaders: Union[List[DataLoader], DataLoader]): - from collections.abc import Generator + def _clear_dataloader_from_memory(self, dataloader: DataLoader, stage_name: str): + log_rank( + f"[Stage: {stage_name}] Clearing the previous training stage's dataloader and datasets from memory", + logger=logger, + level=logging.INFO, + ) + # NOTE: Clear dataloader from memory + del dataloader.dataset + del dataloader.sampler + del dataloader.batch_sampler + + gc.collect() + def _update_train_dataloader_based_on_training_stages(self, dataloaders: Union[List[DataLoader], DataLoader]): if not hasattr(self.config, "data_stages") or self.config.data_stages is None: if self.current_dataloader is None: if isinstance(dataloaders, tuple): @@ -319,25 +360,7 @@ def _update_dataloader_based_on_training_stages(self, dataloaders: Union[List[Da return assert 
len(dataloaders) > 0, "No dataloaders provided" - assert len(dataloaders) == len( - self.config.data_stages - ), "Number of dataloaders should match the number of dataset stages" - - def clear_dataloader_from_memory(dataloader: DataLoader, stage_name: str): - import gc - - log_rank( - f"[Training Stage: {stage_name}] Clearing the previous training stage's dataloader and datasets from memory", - logger=logger, - level=logging.INFO, - ) - - # NOTE: Clear dataloader from memory - del dataloader.dataset - del dataloader.sampler - del dataloader.batch_sampler - - gc.collect() + assert len(dataloaders) == len(self.config.data_stages), "Number of dataloaders should match the number of dataset stages" dataloader = None @@ -364,17 +387,16 @@ def find_stage_idx_to_resume(): if isinstance(prev_dataloader, DataLoader): # NOTE: we don't need to clear dummy data generator from memory - clear_dataloader_from_memory(prev_dataloader, stage_name=stage.name) + self._clear_dataloader_from_memory(prev_dataloader, stage_name=stage.name) self.metadata.last_stage_idx = stage_idx - if is_resume_from_training: remaining_train_steps = compute_remain_train_steps_of_a_data_stage_from_ckp( - stage, self.config, self.metadata + stage, self.config.data_stages, self.config.tokens, self.metadata ) consumed_train_steps = get_consumed_train_samples_of_a_data_stage_from_ckp(stage, self.metadata) log_rank( - f"Resuming training from stage {stage.name}, it has trained for {consumed_train_steps} samples and has {remaining_train_steps} remaining train steps", + f"[Train] Resuming training from stage {stage.name}, it has trained for {consumed_train_steps} samples and has {remaining_train_steps} remaining train steps", logger=logger, level=logging.INFO, rank=0, @@ -390,20 +412,82 @@ def find_stage_idx_to_resume(): dataloader=dataloader, parallel_context=self.parallel_context, config=self.config ) + def _update_valid_dataloader_based_on_training_stages(self, dataloaders: Union[List[DataLoader], DataLoader]): + if not hasattr(self.config, "valid_data_stages") or self.config.valid_data_stages is None: + if self.current_valid_dataloader is None: + if isinstance(dataloaders, tuple): + dataloader = dataloaders[0] + else: + dataloader = dataloaders + self.current_valid_dataloader = sanity_check_dataloader( + dataloader=dataloader, parallel_context=self.parallel_context, config=self.config + ) + return + + assert len(dataloaders) > 0, "No dataloaders provided" + assert len(dataloaders) == len(self.config.valid_data_stages), "Number of dataloaders should match the number of dataset stages" + + dataloader = None + + def find_stage_idx_to_resume(): + reversed_data_stages = sorted(self.config.valid_data_stages, key=lambda x: x.start_training_step, reverse=True) + for idx, stage in enumerate(reversed_data_stages): + if self.iteration_step >= stage.start_training_step: + return len(self.config.valid_data_stages) - idx - 1 + return None + + stage_idx_to_resume = find_stage_idx_to_resume() + + for stage_idx, stage in enumerate(self.config.valid_data_stages): + if stage_idx < self.valid_metadata.last_stage_idx: + continue + + stage = cast(DatasetStageArgs, stage) + + is_resume_from_training = self.current_valid_dataloader is None and stage_idx_to_resume == stage_idx + if (stage.start_training_step == self.iteration_step) or is_resume_from_training: + if self.current_valid_dataloader is not None: + prev_stage_name = self.config.valid_data_stages[stage_idx - 1].name + prev_dataloader = dataloaders[prev_stage_name] + if isinstance(prev_dataloader, 
DataLoader): + # NOTE: we don't need to clear dummy data generator from memory + self._clear_dataloader_from_memory(prev_dataloader, stage_name=stage.name) + + self.valid_metadata.last_stage_idx = stage_idx + + if is_resume_from_training: + remaining_train_steps = compute_remain_train_steps_of_a_data_stage_from_ckp( + stage, self.config.valid_data_stages, self.config.tokens, self.valid_metadata + ) + consumed_train_steps = get_consumed_train_samples_of_a_data_stage_from_ckp(stage, self.valid_metadata) + log_rank( + f"[Valid] Resuming training from stage {stage.name}, it has trained for {consumed_train_steps} samples and has {remaining_train_steps} remaining train steps", + logger=logger, + level=logging.INFO, + rank=0, + ) + # print(f"{self.iteration_step } -> changing dataloader to '{stage.name}'") + dataloader = dataloaders[stage.name] + dataloader = dataloader() if callable(dataloader) else dataloader + break + + if dataloader is not None: + self.current_valid_dataloader = sanity_check_dataloader( + dataloader=dataloader, parallel_context=self.parallel_context, config=self.config + ) + def train( self, - dataloader_or_dls: Dict[ - str, Union[Iterator[Dict[str, Union[torch.Tensor, TensorPointer]]], Tuple[Iterator, ...]] - ], + dataloader_train: dataloader_arg, + dataloader_valid: dataloader_arg, **kwargs, ) -> None: self.pre_training(**kwargs) - + skip_validation = (dataloader_valid == {}) if self.config.checkpoints.save_initial_state and self.init_checkpoint_path is None: self.save_checkpoint() self.pipeline_engine: PipelineEngine = self.config.parallelism.pp_engine - self.pipeline_engine.nb_microbatches = self.n_micro_batches_per_batch # TODO @nouamanetazi: refactor this @@ -423,7 +507,15 @@ def train( prof.step() self.iteration_start_time = time.time() - self._update_dataloader_based_on_training_stages(dataloader_or_dls) + self._update_train_dataloader_based_on_training_stages(dataloader_train) + # + # if 31 == self.iteration_step: + # model = getattr(self.model, "model") + # if False: + # model.lm_head.pp_block.weight[:3, :3] + # out = next(self.current_dataloader) + # print(out["input_ids"][:, :25]) + # print(out["label_ids"][:, :25]) # Training step outputs, loss_avg = self.training_step(dataloader=self.current_dataloader) @@ -432,13 +524,20 @@ def train( # TODO(xrsrke): refactor using callbacks would be better self.metadata.consumed_train_samples += self.global_batch_size self.metadata.last_train_step = self.iteration_step - self.metadata.data_stages[ - self.metadata.last_stage_idx - ].consumed_train_samples += self.global_batch_size + self.metadata.data_stages[self.metadata.last_stage_idx].consumed_train_samples += self.global_batch_size if (self.iteration_step - 1) % self.config.logging.iteration_step_info_interval == 0: self.train_step_logs(outputs=outputs, loss_avg=loss_avg) + if (not skip_validation) and (self.iteration_step - 1) % self.val_check_interval == 0: + self._update_valid_dataloader_based_on_training_stages(dataloader_valid) + valid_outputs, valid_loss_avg = self.validation_step(dataloader=self.current_valid_dataloader) + + self.valid_metadata.consumed_train_samples += self.global_batch_size + self.valid_metadata.last_train_step = self.iteration_step + self.valid_metadata.data_stages[self.valid_metadata.last_stage_idx].consumed_train_samples += self.global_batch_size + + self.valid_step_logs(outputs=valid_outputs, loss_avg=valid_loss_avg) # Checkpoint if self.iteration_step % self.config.checkpoints.checkpoint_interval == 0: self.save_checkpoint() @@ -541,7 +640,6 
@@ def training_step( handle.wait() self.post_train_step() - return outputs, loss_avg def validation_step(self, dataloader: Iterator[Dict[str, Union[torch.Tensor, TensorPointer]]]) -> Iterable[Dict]: @@ -550,7 +648,36 @@ def validation_step(self, dataloader: Iterator[Dict[str, Union[torch.Tensor, Ten batch=(next(dataloader) for _ in range(self.limit_val_batches)), nb_microbatches=self.limit_val_batches, ) - return outputs + + if isinstance(outputs[0]["loss"], torch.Tensor): + loss_avg = torch.stack([_["loss"] for _ in outputs]).sum() + handle = dist.all_reduce(loss_avg, + group=self.parallel_context.dp_pg, + async_op=True, + op=dist.ReduceOp.AVG) + else: + loss_avg = None + handle = None + + if handle is not None: + handle.wait() + return outputs, loss_avg + + def valid_step_logs(self, + outputs: Iterable[Dict[str, Union[torch.Tensor, TensorPointer]]], + loss_avg: Optional[torch.Tensor], + ) -> None: + log_entries = [LogItem("validation_loss_avg", loss_avg, "human_format")] + self.loggerwriter.add_scalars_from_list(log_entries, self.iteration_step) + + # NOTE: only one rank writes to wandb + if dist.get_rank(self.parallel_context.world_pg) == self.logger_ranks[0] and wandb is not None: + wandb.log( + { + **{log_item.tag: log_item.scalar_value for log_item in log_entries}, + "iteration_step": self.iteration_step, + } + ) def train_step_logs( self, @@ -572,7 +699,6 @@ def train_step_logs( if dist.get_rank(self.parallel_context.world_pg) in self.logger_ranks: assert self.loggerwriter is not None, "loggerwriter should be defined on logger ranks" - lr = self.lr_scheduler.get_last_lr()[0] log_entries = [ @@ -620,9 +746,9 @@ def train_step_logs( { **{log_item.tag: log_item.scalar_value for log_item in log_entries}, "iteration_step": self.iteration_step, - } + }, + step=self.iteration_step ) - self.loggerwriter.add_scalars_from_list(log_entries, self.iteration_step) # Nanotron Benchmark mode: we log the throughput and exit @@ -874,6 +1000,7 @@ def save_checkpoint(self) -> Path: parallel_context=self.parallel_context, root_folder=checkpoint_path, training_metadata=self.metadata, + valid_metadata=self.valid_metadata, config=self.config, ) save_random_states(
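
For reference alongside `reconcile_rotary_embeddings.py`, a standalone sanity check of the `scaled_dot_product_attention` fallback added in `src/nanotron/models/llama.py` could look like the sketch below. It assumes the default `1/sqrt(d)` softmax scale (i.e. the non-µTransfer path) and only verifies the `[batch, seq, heads, dim]` vs `[batch, heads, seq, dim]` permutes around the SDPA call:

```python
import math
import torch

# Nanotron keeps attention tensors as [batch, seq, heads, dim]; SDPA expects
# [batch, heads, seq, dim], hence the permutes before and after the call.
batch, seq, heads, dim = 3, 256, 4, 4
q = torch.rand(batch, seq, heads, dim)
k = torch.rand(batch, seq, heads, dim)
v = torch.rand(batch, seq, heads, dim)

out_sdpa = torch.nn.functional.scaled_dot_product_attention(
    q.permute(0, 2, 1, 3), k.permute(0, 2, 1, 3), v.permute(0, 2, 1, 3),
    dropout_p=0.0, is_causal=True,
).permute(0, 2, 1, 3)

# Naive reference: causal softmax(QK^T / sqrt(d)) V.
scores = torch.einsum("bqhd,bkhd->bhqk", q, k) / math.sqrt(dim)
causal_mask = torch.triu(torch.ones(seq, seq, dtype=torch.bool), diagonal=1)
scores = scores.masked_fill(causal_mask, float("-inf"))
out_ref = torch.einsum("bhqk,bkhd->bqhd", scores.softmax(dim=-1), v)

torch.testing.assert_close(out_sdpa, out_ref, rtol=1e-4, atol=1e-4)
print("sdpa fallback matches naive causal attention")
```
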