diff --git a/examples/config_multilingual_nanoset.yaml b/examples/config_multilingual_nanoset.yaml index cc66cd70..48ae960c 100644 --- a/examples/config_multilingual_nanoset.yaml +++ b/examples/config_multilingual_nanoset.yaml @@ -5,54 +5,54 @@ checkpoints: resume_checkpoint_path: null save_initial_state: false data_stages: -- data: - dataset: - training_folder: - - datasets/c4-es/train - - datasets/c4-en/train - - datasets/c4-fr/train - validation_folder: - - datasets/c4-es/validation - - datasets/c4-en/validation - - datasets/c4-fr/validation - languages: - - es - - en - - fr - num_loading_workers: 1 - seed: 42 - name: General purpose training (Blended dataset) - start_training_step: 1 -- data: - dataset: - training_folder: - - datasets/c4-es/train - validation_folder: - - datasets/c4-es/validation - languages: - - es - num_loading_workers: 1 - seed: 42 - name: Second purpose training (Single dataset) - start_training_step: 1000 -- data: - dataset: - training_folder: - - datasets/c4-es/train - - datasets/c4-en/train - - datasets/c4-fr/train - validation_folder: - - datasets/c4-es/validation - - datasets/c4-en/validation - - datasets/c4-fr/validation - languages: - - es - - en - - fr - num_loading_workers: 1 - seed: 42 - name: Third purpose training (>1 dataset) - start_training_step: 2000 + - data: + dataset: + training_folder: + - datasets/c4-es/train + - datasets/c4-en/train + - datasets/c4-fr/train + validation_folder: + - datasets/c4-es/validation + - datasets/c4-en/validation + - datasets/c4-fr/validation + languages: + - es + - en + - fr + num_loading_workers: 1 + seed: 42 + name: General purpose training (Blended dataset) + start_training_step: 1 + - data: + dataset: + training_folder: + - datasets/c4-es/train + validation_folder: + - datasets/c4-es/validation + languages: + - es + num_loading_workers: 1 + seed: 42 + name: Second purpose training (Single dataset) + start_training_step: 1000 + - data: + dataset: + training_folder: + - datasets/c4-es/train + - datasets/c4-en/train + - datasets/c4-fr/train + validation_folder: + - datasets/c4-es/validation + - datasets/c4-en/validation + - datasets/c4-fr/validation + languages: + - es + - en + - fr + num_loading_workers: 1 + seed: 42 + name: Third purpose training (>1 dataset) + start_training_step: 2000 general: benchmark_csv_path: null consumed_train_samples: null diff --git a/examples/xglm/README.md b/examples/xglm/README.md index 22765f52..48447ac2 100644 --- a/examples/xglm/README.md +++ b/examples/xglm/README.md @@ -1,18 +1,32 @@ # How to use XGLM? 1. First, make sure to convert the weights from huggingface, for instance: - ``` - torchrun --nproc-per-node=1 examples/xglm/convert_hf2nt.py --checkpoint-path=facebook/xglm-564M --save-path=$SCRATCH/checkpoints/xglm-564M - ``` +```bash +torchrun --nproc-per-node=1 examples/xglm/convert_hf2nt.py --checkpoint-path=facebook/xglm-564M --save-path=$SCRATCH/checkpoints/xglm-564M +``` -1. Now you are ready to use XGLM. +2. Now you are ready to use XGLM. Make sure you use a .yaml configuration with proper GPT3 config and then run for instance: - ``` - torchrun --nproc-per-node=4 run_train.py --config-file=examples/xglm/example_config.yaml - ``` +```bash +torchrun --nproc-per-node=4 run_train.py --config-file=examples/xglm/example_config.yaml +``` If you use this configuration file make sure to modify at least the loading path in `model.init_method.path`. -1. 
If you want to convert your finetuned checkpoint back to huggingface use:
-    ```
-    torchrun --nproc-per-node=1 examples/xglm/convert_nt2hf.py --checkpoint-path=checpoints/xglm --save-path=$SCRATCH/checkpoints/huggingface/xglm-564M --tokenizer-name=facebook/xglm-564M
-    ```
+3. If you want to convert your finetuned checkpoint back to huggingface, use:
+```bash
+torchrun --nproc-per-node=1 examples/xglm/convert_nt2hf.py --checkpoint-path=checkpoints/xglm --save-path=$SCRATCH/checkpoints/huggingface/xglm-564M --tokenizer-name=facebook/xglm-564M
+```
+
+## Sparse Upcycling
+
+To create a sparse (MoE) model from a dense model, use the `convert_dense2moe.py` script, which converts a GPT3 Nanotron model into a GPT3 MoE Nanotron model. For instance:
+```bash
+cd examples/xglm
+torchrun --nproc-per-node=1 convert_dense2moe.py --checkpoint-path=checkpoints/xglm-564M --save-path=$SCRATCH/checkpoints/xglm-8x564M --num-experts=8
+```
+Note that this upcycling _drops_ the bias parameters of the MLP because the MegaBlocks implementation does not support bias parameters. While this is a limitation of the current implementation, the performance lost by dropping the biases is quickly recovered after a few training steps.
+
+To save the upcycled model back to huggingface format, use:
+```bash
+torchrun --nproc-per-node=1 examples/xglm/convert_ntmoe2hf.py --checkpoint-path=$SCRATCH/checkpoints/xglm-8x564M --save-path=$SCRATCH/checkpoints/huggingface/xglm-8x564M
+```
diff --git a/examples/xglm/convert_dense2moe.py b/examples/xglm/convert_dense2moe.py
new file mode 100644
index 00000000..fa4d9af7
--- /dev/null
+++ b/examples/xglm/convert_dense2moe.py
@@ -0,0 +1,179 @@
+"""
+Converts a dense nanotron GPT3 model into a nanotron GPT3 MoE model (sparse upcycling).
+Command:
+    torchrun --nproc-per-node=1 convert_dense2moe.py --checkpoint-path=nanotron_weights --save-path=nanotron_moe_weights
+"""
+
+import dataclasses
+import json
+from argparse import ArgumentParser
+from pathlib import Path
+from typing import Optional
+
+import torch
+from torch import nn
+
+import nanotron
+from nanotron.config.models_config import GPT3Config, GPT3MoEConfig
+from nanotron.models.gpt3 import GPT3ForTraining, GPTBlock
+from nanotron.models.gpt3_moe import GPT3MoEForTraining, GPT3MoEBlock
+from nanotron.trainer import mark_tied_parameters
+
+from convert_utils import convert_generic, create_nt_model
+
+
+def convert_config(config: GPT3Config, num_experts=8) -> GPT3MoEConfig:
+    return GPT3MoEConfig(
+        **config.__dict__,
+        is_moe=True,
+        moe_num_experts=num_experts,
+        num_experts_per_tok=min(2, num_experts),  # arbitrarily chosen
+        moe_loss_weight=0.01,  # arbitrarily chosen
+        moe_z_loss_weight=0.001,  # arbitrarily chosen
+        moe_glu=False,
+    )
+
+
+def convert_dense_to_moe(ff_moe: nn.Module, dense_ff: nn.Module, num_experts: int):
+    with torch.no_grad():
+        # only copy the weight matrix and repeat it n_expert times
+        weight_1 = dense_ff.c_fc.weight.clone()
+        if num_experts == 1:
+            ff_moe.experts.mlp.w1.module.weight.data = weight_1.contiguous()
+        else:
+            # [intermediate_size, hidden_size] -> [hidden_size, intermediate_size * n_experts]
+            weight_1 = weight_1.T
+            ff_moe.experts.mlp.w1.module.weight.data = weight_1.repeat(1, num_experts)
+
+        weight_2 = dense_ff.c_proj.weight.clone()
+        if num_experts == 1:  # just a specific case for 1 expert
+            ff_moe.experts.mlp.w2.module.weight.data = weight_2.contiguous()
+        else:
+            # [hidden_size, intermediate_size] -> [intermediate_size * n_experts, hidden_size]
+            weight_2 = weight_2.T
+            ff_moe.experts.mlp.w2.module.weight.data = weight_2.repeat(num_experts, 1)
+
+        # # -- could add bias only for 2nd layer,
because that works with the MegaBlocks MoE implementation + # # -- but won't make a big difference? + # ff_moe.experts.bias.copy_(dense_ff.c_proj.bias) + + # init gating randomly + nn.init.normal_(ff_moe.gate.layer.weight, mean=0.0, std=0.02) + + +def convert_decoder(block_moe: GPT3MoEBlock, block_nt: GPTBlock, num_experts: int): + convert_generic(block_moe.ln_1, block_nt.ln_1) + convert_generic(block_moe.attn, block_nt.attn) + convert_generic(block_moe.ln_2, block_nt.ln_2) + convert_dense_to_moe(block_moe.ff, block_nt.ff, num_experts) + + +def convert( + model_moe: GPT3MoEForTraining, model_dense: GPT3ForTraining, num_experts: int +): + convert_generic( + model_moe.model.token_embeddings.pp_block.token_embedding, + model_dense.model.token_embeddings.pp_block.token_embedding, + ) + for layer_moe, layer_nt in zip(model_moe.model.decoder, model_dense.model.decoder): + convert_decoder(layer_moe.pp_block, layer_nt.pp_block, num_experts) + convert_generic( + model_moe.model.final_layer_norm.pp_block, + model_dense.model.final_layer_norm.pp_block, + ) + convert_generic( + model_moe.model.lm_head.pp_block, model_dense.model.lm_head.pp_block + ) + + +def create_nt_moe_model( + model_config: Optional[GPT3Config] = None, + device: torch.device = torch.device("cuda"), + dtype: torch.dtype = torch.bfloat16, + checkpoint_path: Optional[Path] = None, +): + + if model_config is None: + assert checkpoint_path is not None + with open(checkpoint_path / "model_config.json") as f: + model_config = GPT3MoEConfig(**json.load(f)) + + parallel_config = nanotron.config.ParallelismArgs(dp=1, pp=1, tp=1) + parallel_context = nanotron.parallel.ParallelContext( + data_parallel_size=parallel_config.dp, + pipeline_parallel_size=parallel_config.pp, + tensor_parallel_size=parallel_config.tp, + ) + model_nt = nanotron.models.build_model( + model_builder=lambda: GPT3MoEForTraining( + config=model_config, + parallel_context=parallel_context, + parallel_config=parallel_config, + random_states=None, + ), + parallel_context=parallel_context, + dtype=dtype, + device=device, + ) + mark_tied_parameters(model=model_nt, parallel_context=parallel_context) + + if checkpoint_path is not None: + nanotron.serialize.load_weights( + model=model_nt, + parallel_context=parallel_context, + root_folder=checkpoint_path, + ) + + return model_nt + + +def main( + checkpoint_path: Path, + save_path: Path, + num_experts: int, +): + # Load nanotron model. + model_dense = create_nt_model(checkpoint_path=checkpoint_path) + + # Init moe model. 
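+    # Sketch of the upcycling flow implemented below: convert_config copies every
+    # dense GPT3 field and adds the MoE-specific ones (number of experts, top-k,
+    # loss weights), convert() tiles the dense MLP weights across all experts and
+    # randomly initializes the router gate, and the result is saved as a regular
+    # nanotron checkpoint.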
+    model_config_moe = convert_config(model_dense.config, num_experts)
+    model_moe = create_nt_moe_model(model_config=model_config_moe)
+
+    convert(model_moe, model_dense, num_experts)
+    nanotron.serialize.save_weights(
+        model=model_moe,
+        parallel_context=model_moe.parallel_context,
+        root_folder=save_path,
+    )
+    with open(save_path / "model_config.json", "w+") as f:
+        json.dump(dataclasses.asdict(model_config_moe), f)
+    print(f"Model saved to {save_path}")
+
+
+if __name__ == "__main__":
+    # fix all random seeds
+    torch.manual_seed(0)
+    torch.cuda.manual_seed(0)
+    torch.cuda.manual_seed_all(0)
+    torch.backends.cudnn.deterministic = True
+    parser = ArgumentParser(description="Convert dense weights to moe format")
+    parser.add_argument(
+        "--checkpoint-path",
+        type=Path,
+        default="checkpoints/xglm-7.5B",
+        help="Path to the nanotron dense checkpoint",
+    )
+    parser.add_argument(
+        "--save-path",
+        type=Path,
+        default="checkpoints/xglm-moe-7.5B",
+        help="Path to save the nanotron moe model",
+    )
+    parser.add_argument(
+        "--num-experts",
+        type=int,
+        default=8,
+        help="Number of experts in the MoE model (duplicates of MLP layer)",
+    )
+    args = parser.parse_args()
+    main(args.checkpoint_path, args.save_path, args.num_experts)
diff --git a/examples/xglm/convert_ntmoe2hf.py b/examples/xglm/convert_ntmoe2hf.py
new file mode 100644
index 00000000..d971c96f
--- /dev/null
+++ b/examples/xglm/convert_ntmoe2hf.py
@@ -0,0 +1,140 @@
+"""
+Converts a nanotron moe model to HF format
+Command:
+    torchrun --nproc-per-node=1 convert_ntmoe2hf.py --checkpoint-path=nanotron_weights --save-path=hf_weights
+"""
+
+import warnings
+from argparse import ArgumentParser
+from pathlib import Path
+from typing import Optional
+
+import torch
+from transformers import AutoTokenizer
+from tqdm import tqdm
+
+from nanotron.config.models_config import GPT3MoEConfig
+from nanotron.models.gpt3_moe import GPT3MoEForTraining, GPT3MoEBlock
+from nanotron.models.moe import dMoE, SparseMLP, LearnedRouter
+
+from examples.xglm.convert_dense2moe import create_nt_moe_model
+from examples.xglm.convert_nt2hf import convert_attention
+from examples.xglm.convert_utils import convert_generic
+from examples.xglm.transformers_impl.xglm_model import XGLMForCausalLM, XGLMDecoderLayer, XGLMmoeConfig, XGLMSparseMoeBlock, XGLMMLP
+from examples.xglm.transformers_impl.gating import BasicGate
+
+
+def convert_config(config: GPT3MoEConfig) -> XGLMmoeConfig:
+    if config.embd_pdrop != config.resid_pdrop:
+        warnings.warn(
+            f"nanotron.embd_pdrop = {config.embd_pdrop} does not match with "
+            f"nanotron.resid_pdrop = {config.resid_pdrop}. "
+            "XGLM implementation needs these two values to be equal "
+            "for correct conversion."
+        )
+    if config.layer_norm_epsilon != 1e-5:
+        warnings.warn(f"nanotron.layer_norm_epsilon must be 1e-5, not {config.layer_norm_epsilon}")
+    if config.moe_z_loss_weight != 0:
+        warnings.warn("the transformers implementation does not support z-loss; moe_z_loss_weight will be ignored")
+    assert not config.moe_glu, "Transformer implementation does not support glu MLP layers"
+
+    return XGLMmoeConfig(
+        # Regular xglm config.
+ activation_function=config.activation_function, + attention_dropout=config.attn_pdrop, + dropout=config.embd_pdrop, + eos_token_id=config.eos_token_id, + d_model=config.hidden_size, + ffn_dim=config.intermediate_size, + max_position_embeddings=config.max_position_embeddings, + attention_heads=config.num_attention_heads, + num_layers=config.num_hidden_layers, + vocab_size=config.vocab_size, + decoder_start_token_id=config.position_embedding_offset, + activation_dropout=config.act_pdrop, + scale_embedding=config.scale_embedding, + # Moe specifics. + num_local_experts=config.moe_num_experts, + num_experts_per_tok=config.num_experts_per_tok, + gate_type="linear", + gate_depth=1, + router_aux_loss_coef=config.moe_loss_weight, + ) + + +def convert_mlp(mlp_hf: XGLMMLP, mlp_nt: SparseMLP): + convert_generic(mlp_hf.fc1, mlp_nt.w1.module) + convert_generic(mlp_hf.fc2, mlp_nt.w2.module) + + +def convert_gate(gate_hf: BasicGate, gate_nt: LearnedRouter): + convert_generic(gate_hf.gate, gate_nt.layer) + + +def convert_ff(ff_hf: XGLMSparseMoeBlock, ff_nt: dMoE): + convert_gate(ff_hf.gate, ff_nt.gate) + int_size = ff_nt.config.intermediate_size + if len(ff_hf.experts) == 1: + assert ff_nt.experts.mlp.w1.module.weight.shape == (int_size*len(ff_hf.experts), ff_nt.config.hidden_size) + assert ff_nt.experts.mlp.w2.module.weight.shape == (ff_nt.config.hidden_size, int_size*len(ff_hf.experts)) + else: + assert ff_nt.experts.mlp.w1.module.weight.T.shape == (int_size*len(ff_hf.experts), ff_nt.config.hidden_size) + assert ff_nt.experts.mlp.w2.module.weight.shape == (int_size*len(ff_hf.experts), ff_nt.config.hidden_size) + + for i, expert_hf in enumerate(ff_hf.experts): + i0 = i*int_size + i1 = (i + 1)*int_size + with torch.no_grad(): + if len(ff_hf.experts) == 1: + expert_hf.fc1.weight.copy_(ff_nt.experts.mlp.w1.module.weight[i0:i1, :].clone()) + expert_hf.fc2.weight.copy_(ff_nt.experts.mlp.w2.module.weight[:, i0:i1].clone()) + else: + expert_hf.fc1.weight.copy_(ff_nt.experts.mlp.w1.module.weight.T[i0:i1, :].clone()) + expert_hf.fc2.weight.copy_(ff_nt.experts.mlp.w2.module.weight[i0:i1, :].T.clone()) + +def convert_decoder(block_hf: XGLMDecoderLayer, block_nt: GPT3MoEBlock): + convert_generic(block_hf.self_attn_layer_norm, block_nt.ln_1) + convert_attention(block_hf.self_attn, block_nt.attn) + convert_generic(block_hf.final_layer_norm, block_nt.ln_2) + convert_ff(block_hf.block_sparse_moe, block_nt.ff) + + +def convert(model_hf: XGLMForCausalLM, model_nt: GPT3MoEForTraining): + convert_generic(model_hf.model.embed_tokens, model_nt.model.token_embeddings.pp_block.token_embedding) + for layer_hf, layer_nt in tqdm(zip(model_hf.model.layers, model_nt.model.decoder), desc="Converting layers", + total=model_nt.config.num_hidden_layers): + convert_decoder(layer_hf, layer_nt.pp_block) + convert_generic(model_hf.model.layer_norm, model_nt.model.final_layer_norm.pp_block) + convert_generic(model_hf.lm_head, model_nt.model.lm_head.pp_block) + + +def main(checkpoint_path: Path, save_path: Path, tokenizer_name: Optional[str]): + # Load nanotron model. + model_nt = create_nt_moe_model(checkpoint_path=checkpoint_path) + + # Init huggingface model. + model_config_hf = convert_config(model_nt.config) + model_hf = XGLMForCausalLM._from_config(model_config_hf) + + # Copy weights, initialize tokenizer and save model. 
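+    # convert() walks the decoder layers, copying attention and layer norms
+    # directly and splitting the fused nanotron expert matrices
+    # (w1: [hidden, n_experts * ffn], w2: [n_experts * ffn, hidden]) into one
+    # XGLMMLP (fc1/fc2) per HF expert.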
+ if tokenizer_name is not None: + tokenizer = AutoTokenizer.from_pretrained(tokenizer_name) + tokenizer.save_pretrained(save_path) + states = torch.randn(4, 1, 1024) + convert(model_hf, model_nt), states.cuda().bfloat16() + print("Saving...") + model_hf.save_pretrained(save_path) + print(f"Model saved to {save_path}") + + +if __name__ == "__main__": + parser = ArgumentParser(description="Convert HF weights to nanotron format") + parser.add_argument( + "--checkpoint-path", type=Path, default="checkpoints/xglm-7.5B", help="Path to the nanotron checkpoint" + ) + parser.add_argument( + "--save-path", type=Path, default="facebook/xglm-7.5B", help="Path to save the huggingface model" + ) + parser.add_argument("--tokenizer-name", type=str, default="facebook/xglm-7.5B") + args = parser.parse_args() + ret = main(args.checkpoint_path, args.save_path, args.tokenizer_name) diff --git a/examples/xglm/example_config_moe.yaml b/examples/xglm/example_config_moe.yaml new file mode 100644 index 00000000..aa0d739c --- /dev/null +++ b/examples/xglm/example_config_moe.yaml @@ -0,0 +1,113 @@ +checkpoints: + checkpoint_interval: 1000 + checkpoints_path: checkpoints-xglm-moe-8x564M + checkpoints_path_is_shared_file_system: false + resume_checkpoint_path: checkpoints-xglm-moe-8x564M + save_initial_state: false +data_stages: + - data: + dataset: + training_folder: + - datasets/c4-es/train + - datasets/c4-en/train + - datasets/c4-fr/train + validation_folder: + - datasets/c4-es/validation + - datasets/c4-en/validation + - datasets/c4-fr/validation + languages: + - es + - en + - fr + num_loading_workers: 1 + seed: 42 + name: General purpose training (Blended dataset) + start_training_step: 1 +general: + benchmark_csv_path: null + consumed_train_samples: null + ignore_sanity_checks: true + project: multilingual-moe-init + run: xglm-topk2-8x564M + seed: 42 + step: null +lighteval: null +logging: + iteration_step_info_interval: 1 + log_level: info + log_level_replica: info +model: + ddp_bucket_cap_mb: 25 + dtype: bfloat16 + init_method: + # for random init: + # std: 0.02 + # or upcycling after using the converter scripts: + path: /mloscratch/homes/haegele/swissai/nanotron-multilingual/xglm/xglm-moe + make_vocab_size_divisible_by: 1 + model_config: + activation_function: gelu + attn_pdrop: 0.0 + embd_pdrop: 0.0 + scale_embedding: true + eos_token_id: 2 + hidden_size: 1024 + intermediate_size: 4096 + layer_norm_epsilon: 0.00001 + max_position_embeddings: 2048 + num_attention_heads: 16 + num_hidden_layers: 24 + resid_pdrop: 0.0 + scale_attention_softmax_in_fp32: true + scale_attn_weights: true + vocab_size: 256008 + sinusoidal_position_embedding: true + position_embedding_offset: 2 + use_spda: false + act_pdrop: 0.0 + is_moe: true + moe_num_experts: 8 + num_experts_per_tok: 2 + moe_loss_weight: 0.01 + moe_z_loss_weight: 0.001 + moe_glu: false +optimizer: + accumulate_grad_in_fp32: true + clip_grad: 1.0 + learning_rate_scheduler: + learning_rate: 0.001 + lr_decay_starting_step: 2000 + lr_decay_steps: 500 + lr_decay_style: 1-sqrt + lr_warmup_steps: 100 + lr_warmup_style: linear + min_decay_lr: 0 + optimizer_factory: + adam_beta1: 0.9 + adam_beta2: 0.95 + adam_eps: 1.0e-08 + name: adamW + torch_adam_is_fused: true + weight_decay: 0.1 + zero_stage: 0 +parallelism: + dp: 3 + expert_parallel_size: 1 + pp: 1 + pp_engine: 1f1b + tp: 1 + tp_linear_async_communication: false + tp_mode: REDUCE_SCATTER +profiler: null +tokenizer: + tokenizer_max_length: null + tokenizer_name_or_path: facebook/xglm-564M + tokenizer_revision: null 
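+# Note: with dp=3 (see parallelism above), micro_batch_size=4 and
+# batch_accumulation_per_replica=5 below, each optimizer step consumes
+# 3 * 4 * 5 = 60 sequences of 2048 tokens (~123k tokens).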
+tokens: + batch_accumulation_per_replica: 5 + limit_test_batches: 0 + limit_val_batches: 10 + micro_batch_size: 4 # fits on one 80GB A100 for 8x564M + sequence_length: 2048 + train_steps: 2500 + val_check_interval: -1 diff --git a/examples/xglm/tests/test_moe.py b/examples/xglm/tests/test_moe.py new file mode 100644 index 00000000..7536b84f --- /dev/null +++ b/examples/xglm/tests/test_moe.py @@ -0,0 +1,182 @@ +import torch +import pytest + +import nanotron +from nanotron.config.parallelism_config import ParallelismArgs +from nanotron.config.models_config import GPT3MoEConfig +from nanotron.parallel import ParallelContext +from nanotron.parallel.pipeline_parallel.tensor_pointer import TensorPointer +from nanotron.trainer import mark_tied_parameters +from nanotron.models.gpt3_moe import GPT3MoEBlock, GPT3MoEForTraining +from nanotron.models.moe import LearnedRouter, dMoE + +from tests.helpers.utils import init_distributed + +from examples.xglm.convert_ntmoe2hf import convert_config, convert_gate, convert_ff, convert +from examples.xglm.tests.test_implementation import almost_close +from examples.xglm.transformers_impl.xglm_model import XGLMSparseMoeBlock, XGLMForCausalLM +from examples.xglm.transformers_impl.gating import BasicGate + + +MAX_SEQUENCE_LENGTH = 2048 +TEST_SEQUENCE_LENGTH = 128 # If we test with a very large sequence length, precision errors get more significant independent of the correct implementation. +#TEST_SEQUENCE_LENGTH = MAX_SEQUENCE_LENGTH +BATCH_SIZE = 4 +HIDDEN_SIZE = 1024 +#DTYPE = torch.bfloat16 +DTYPE = torch.float32 +TEXT = "Hello. This is a relatively long text. I will use this text to test the conversion scripts. Let's finish this text soon because I don't have much more to say. Final note:" + +CONFIG = GPT3MoEConfig( + attn_pdrop=0.0, + embd_pdrop=0.0, + resid_pdrop=0.0, + act_pdrop=0.0, + eos_token_id=2, + hidden_size=HIDDEN_SIZE, + intermediate_size=4096, + layer_norm_epsilon=1e-05, + max_position_embeddings=MAX_SEQUENCE_LENGTH, + num_attention_heads=16, + num_hidden_layers=24, + scale_attn_weights=True, + vocab_size=256008, + sinusoidal_position_embedding=True, + position_embedding_offset=2, + use_spda=DTYPE is not torch.bfloat16, + # vvv moe vvv + is_moe=True, + moe_num_experts=8, + num_experts_per_tok=2, + moe_loss_weight=0.01, + moe_z_loss_weight=0.0, + moe_glu=False, +) +PARALLEL_CONFIG = ParallelismArgs(dp=1, pp=1, tp=1, expert_parallel_size=1) #CONFIG.moe_num_experts) + + +@pytest.fixture +def hidden_states() -> torch.Tensor: + return torch.randn(TEST_SEQUENCE_LENGTH, BATCH_SIZE, HIDDEN_SIZE, dtype=DTYPE) + + +@pytest.fixture +def input_mask() -> torch.Tensor: + return torch.ones(BATCH_SIZE, TEST_SEQUENCE_LENGTH, dtype=torch.bool) + + +@pytest.fixture +def input_ids() -> torch.Tensor: + return torch.randint(0, CONFIG.vocab_size, (BATCH_SIZE, TEST_SEQUENCE_LENGTH)) + + +def _test_nt2hf_gate(parallel_context: ParallelContext, hidden_states: torch.Tensor): + hidden_states = hidden_states.cuda() + + config_hf = convert_config(CONFIG) + gate_nt = LearnedRouter(CONFIG).cuda().to(DTYPE) + gate_hf = BasicGate(config_hf).cuda().to(DTYPE) + convert_gate(gate_hf, gate_nt) + + router_logits_nt, _, _ = gate_nt(hidden_states.view(-1, HIDDEN_SIZE)) + router_logits_hf = gate_hf(hidden_states.permute(1, 0, 2).reshape(-1, HIDDEN_SIZE), "") + + router_logits_nt = router_logits_nt.view(TEST_SEQUENCE_LENGTH, BATCH_SIZE, -1) + router_logits_hf = router_logits_hf.view(BATCH_SIZE, TEST_SEQUENCE_LENGTH, -1).permute(1, 0, 2) + + assert router_logits_nt.size() == 
router_logits_hf.size() + torch.testing.assert_close(router_logits_nt, router_logits_hf) + + +def test_nt2hf_gate(hidden_states: torch.Tensor): + init_distributed(tp=1, dp=1, pp=1)(_test_nt2hf_gate)(hidden_states=hidden_states) + + +def _test_nt2hf_ff(parallel_context: ParallelContext, hidden_states: torch.Tensor, + num_experts: int, num_experts_per_tok: int): + hidden_states = hidden_states.cuda() + + config = {**vars(CONFIG)} + config.update({"moe_num_experts": num_experts, "num_experts_per_tok": num_experts_per_tok}) + config = GPT3MoEConfig(**config) + config_hf = convert_config(config) + ff_nt = dMoE(config, parallel_context, PARALLEL_CONFIG).cuda().to(DTYPE) + ff_hf = XGLMSparseMoeBlock(config_hf).cuda().to(DTYPE) + convert_ff(ff_hf, ff_nt) + + out_nt = ff_nt(hidden_states)["hidden_states"] + out_hf, _ = ff_hf(hidden_states.permute(1, 0, 2).contiguous(), "") + out_hf = out_hf.permute(1, 0, 2) + + assert out_nt.size() == out_hf.size() + almost_close(out_nt, out_hf, max_far=0.05, far_atol=0.003) + + +@pytest.mark.parametrize("num_experts,num_experts_per_tok", [(1, 1), (2, 1), (4, 1), (4, 2), (8, 1), (8, 2), (8, 4)]) +def test_nt2hf_ff(hidden_states: torch.Tensor, num_experts: int, num_experts_per_tok: int): + init_distributed(tp=1, dp=1, pp=1)(_test_nt2hf_ff)(hidden_states=hidden_states, num_experts=num_experts, num_experts_per_tok=num_experts_per_tok) + + +def _test_nt2hf_model(parallel_context: ParallelContext, input_ids: torch.Tensor, input_mask: torch.Tensor): + random_states = nanotron.random.RandomStates({"tp_synced": nanotron.random.get_current_random_state()}) + input_ids = input_ids.cuda() + input_mask = input_mask.cuda() + + # unfortunately, we can't use float64 with huggingface xglm. + new_dtype = torch.float32 if DTYPE == torch.float64 else DTYPE + + # Get nanotron model. + config_nt = GPT3MoEConfig(**vars(CONFIG)) + if new_dtype not in {torch.bfloat16, torch.float16}: + config_nt.use_spda = True + model_nt = nanotron.models.build_model( + model_builder=lambda: GPT3MoEForTraining( + config=config_nt, + parallel_context=parallel_context, + parallel_config=None, + random_states=random_states, + ), + parallel_context=parallel_context, + dtype=new_dtype, + device="cuda", + ).eval() + mark_tied_parameters(model=model_nt, parallel_context=parallel_context) + + # Create empty model_hf and make conversion. + model_hf = XGLMForCausalLM(convert_config(config_nt)).cuda().to(new_dtype).eval() + convert(model_hf, model_nt) + + # Needed :/ + aux_losses = { + "load_balancing_loss": ( + torch.zeros(1, device=input_ids.device) + if not isinstance(input_ids, TensorPointer) + else TensorPointer(self.input_pp_rank) + ), + "z_loss": ( + torch.zeros(1, device=input_ids.device) + if not isinstance(input_ids, TensorPointer) + else TensorPointer(self.input_pp_rank) + ), + } + + # Get outputs and assert. 
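+    # Run the two models sequentially and free each one before loading the next
+    # to keep peak GPU memory low; callers compare the returned logits with a
+    # tolerance (almost_close) rather than exact equality.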
+ with torch.no_grad(): + out_nt = model_nt.model(input_ids, input_mask, aux_losses)["sharded_logits"].to(new_dtype) + del model_nt + torch.cuda.empty_cache() + out_hf = model_hf(input_ids=input_ids, attention_mask=input_mask, output_router_logits=False).logits.permute(1, 0, 2) + del model_hf + torch.cuda.empty_cache() + assert out_nt.size() == out_hf.size(), f"{out_nt.size()}, {out_hf.size()}" + return out_nt.cpu(), out_hf.cpu() + + +def _test_nt2hf_dummy_xglm(parallel_context: ParallelContext, input_ids: torch.Tensor, input_mask: torch.Tensor): + out_nt, out_hf = _test_nt2hf_model(parallel_context, input_ids, input_mask) + almost_close(out_nt, out_hf, max_far=0.01, far_atol=2.0) # We allow for less than 1% errors, but some of these are very large! + #torch.testing.assert_close(out_nt.bfloat16(), out_hf.bfloat16()) + + +def test_nt2hf_dummy_xglm(input_ids: torch.Tensor, input_mask: torch.Tensor): + init_distributed(tp=1, dp=1, pp=1)(_test_nt2hf_dummy_xglm)(input_ids=input_ids, input_mask=input_mask) diff --git a/examples/xglm/transformers_impl/gating.py b/examples/xglm/transformers_impl/gating.py new file mode 100644 index 00000000..efa0e357 --- /dev/null +++ b/examples/xglm/transformers_impl/gating.py @@ -0,0 +1,149 @@ +import torch +import torch.nn as nn +from torch.nn import functional as F +import math + +from abc import ABC, abstractmethod + + +class Gate(ABC): + def __init__(self, device): + super(Gate, self).__init__() + self.device = device + + @abstractmethod + def compute(self, x): + """ + Compute the output of the gate. + This method should be implemented by all subclasses. + """ + pass + + +def init_x_embeddings(Xs, x_embedding_dim): + x2embeddings = nn.ParameterDict(dict()) + for x in Xs: + x_embedding = torch.empty(x_embedding_dim) + nn.init.normal_(x_embedding) + x2embeddings[str(x)] = nn.Parameter(x_embedding) + return x2embeddings + + +class BasicGate(nn.Module): + """One or two layer feedforward network as the Gate.""" + + def __init__(self, config) -> None: + super().__init__() + + self.hidden_dim = config.hidden_size + self.num_experts = config.num_local_experts + self.ffn_dim = config.ffn_dim + self.activation = nn.ReLU(self.ffn_dim) + + if config.gate_depth == 1: + self.gate = nn.Linear(self.hidden_dim, self.num_experts, bias=False) + elif config.gate_depth == 2: + self.gate = nn.Sequential( + nn.Linear(self.hidden_dim, self.ffn_dim), + self.activation, + nn.Linear(self.ffn_dim, self.num_experts, bias=False), + ) + else: + raise ValueError("Invalid gate_depth!") + + def forward(self, x, lang_name): + return self.gate(x) + + +class LanguageAwareGate(nn.Module): + """One or two layer feedforward network as the Gate.""" + + def __init__(self, config) -> None: + super().__init__() + + self.hidden_dim = config.hidden_size + self.num_experts = config.num_local_experts + self.ffn_dim = config.ffn_dim + self.activation = nn.ReLU(self.ffn_dim) + self.language_embedding_dim = ( + config.language_embedding_dim + if config.language_embedding_dim is not None + else config.hidden_size + ) + self.lang_embeddings = init_x_embeddings( + config.languages, self.language_embedding_dim + ) + + if config.gate_depth == 1: + self.gate = nn.Linear( + self.hidden_dim + self.language_embedding_dim, + self.num_experts, + bias=False, + ) + elif config.gate_depth == 2: + self.gate = nn.Sequential( + nn.Linear(self.hidden_dim, self.ffn_dim), + self.activation, + nn.Linear(self.ffn_dim, self.num_experts, bias=False), + ) + else: + raise ValueError("Invalid gate_depth!") + + def forward(self, x, 
lang_name): + # TODO x needs to be added to the language embedding (we need to pass the language as well) + lang_embedding = self.lang_embeddings[str(lang_name)] + lang_embedding.squeeze(0) + lang_embedding = lang_embedding.expand(x.shape[0], -1) + x = torch.cat((x, lang_embedding), dim=-1) + return self.gate(x) + + +class TopKGate(Gate): + def __init__(self, device, straight_through, k=1): + super(TopKGate, self).__init__(device) + self.k = k + self.device = device + self.straight_through = straight_through + + def compute(self, x): + if self.k > 1: + topk_gate_scores, indices = torch.topk(x, self.k) + topk_gate_scores = F.softmax( + topk_gate_scores, + dim=1, + dtype=torch.float, + ).type_as(x) + mask = F.one_hot(indices, x.shape[-1]).float() + mask_flat = mask.sum(dim=-1) + combine_tensor = ( + topk_gate_scores[..., None, None, None] + * mask_flat[..., None, None, None] + * F.one_hot(indices, x.shape[-1])[..., None, None] + ) + combine_tensor = combine_tensor.sum(1) + return combine_tensor, indices, topk_gate_scores + elif self.k == 1: + x = F.softmax(x, dim=-1) + topk_gate_scores, index = x.topk( + k=self.k, dim=-1 + ) # torch.nn.functional.softmax(x , dim=-1).topk(k=self.k, dim=-1) + if self.straight_through: + index_soft = F.softmax(x, dim=-1) + index = (index - index_soft).detach() + index_soft + index = index[:, 0] + topk_gate_scores, index = map( + lambda x: x.squeeze(dim=-1), (topk_gate_scores, index) + ) + else: + topk_gate_scores, index = map( + lambda x: x.squeeze(dim=-1), (topk_gate_scores, index) + ) + + mask = F.one_hot(index, x.shape[-1]).float() + mask_flat = mask.sum(dim=-1) + combine_tensor = ( + topk_gate_scores[..., None, None, None] + * mask_flat[..., None, None, None] + * F.one_hot(index, x.shape[-1])[..., None, None] + ) + return combine_tensor, index, topk_gate_scores diff --git a/examples/xglm/transformers_impl/xglm_model.py b/examples/xglm/transformers_impl/xglm_model.py new file mode 100644 index 00000000..aa80a7fe --- /dev/null +++ b/examples/xglm/transformers_impl/xglm_model.py @@ -0,0 +1,1119 @@ +from typing import List, Optional, Tuple, Union, Literal +import math +from dataclasses import dataclass + +import torch +import torch.nn as nn +from torch.nn import functional as F +from torch.nn import CrossEntropyLoss + +from transformers.models.xglm.modeling_xglm import ( + XGLMAttention, + XGLMPreTrainedModel, + XGLMSinusoidalPositionalEmbedding, +) +from transformers.models.xglm.configuration_xglm import XGLMConfig +from transformers.activations import ACT2FN +from transformers.modeling_attn_mask_utils import ( + _prepare_4d_attention_mask, + _prepare_4d_causal_attention_mask, +) + +# from transformers.modeling_outputs import CausalLMOutputWithCrossAttentions, BaseModelOutputWithPastAndCrossAttentions +from transformers.utils import ( + add_start_docstrings, + add_start_docstrings_to_model_forward, + add_code_sample_docstrings, + ModelOutput, +) +from examples.xglm.transformers_impl.gating import BasicGate, LanguageAwareGate + +from accelerate.logging import get_logger + +logger = get_logger(__name__) + +_CHECKPOINT_FOR_DOC = "facebook/xglm-564M" +_CONFIG_FOR_DOC = "XGLMConfig" + + +class XGLMmoeConfig(XGLMConfig): + def __init__(self, + num_local_experts: int = 1, + num_experts_per_tok: int = 1, + gate_type: Literal["linear", "lang_aware_linear"] = "linear", + gate_depth: Literal[1, 2] = 1, + router_aux_loss_coef: float = 1.0, + **kwargs): + + super().__init__(**kwargs) + self.num_local_experts = num_local_experts + self.num_experts_per_tok = 
num_experts_per_tok + self.gate_type = gate_type + self.gate_depth = gate_depth + self.router_aux_loss_coef = router_aux_loss_coef + + +@dataclass +class MoEModelOutputWithPast(ModelOutput): + """ + Base class for model's outputs, with potential hidden states and attentions. + + Args: + last_hidden_state (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`): + Sequence of hidden-states at the output of the last layer of the model. + past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`): + Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of shape + `(batch_size, num_heads, sequence_length, embed_size_per_head)`) and optionally if + `config.is_encoder_decoder=True` 2 additional tensors of shape `(batch_size, num_heads, + encoder_sequence_length, embed_size_per_head)`. + + Contains pre-computed hidden-states (key and values in the self-attention blocks and optionally if + `config.is_encoder_decoder=True` in the cross-attention blocks) that can be used (see `past_key_values` + input) to speed up sequential decoding. + hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): + Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, + + one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`. + + Hidden-states of the model at the output of each layer plus the optional initial embedding outputs. + attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`): + Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length, + sequence_length)`. + + Attentions weights after the attention softmax, used to compute the weighted average in the self-attention + heads. + router_logits (`tuple(torch.FloatTensor)`, *optional*, returned when `output_router_probs=True` and `config.add_router_probs=True` is passed or when `config.output_router_probs=True`): + Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, sequence_length, num_experts)`. + + Raw router logtis (post-softmax) that are computed by MoE routers, these terms are used to compute the auxiliary + loss for Mixture of Experts models. + """ + + last_hidden_state: torch.FloatTensor = None + past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None + hidden_states: Optional[Tuple[torch.FloatTensor, ...]] = None + attentions: Optional[Tuple[torch.FloatTensor, ...]] = None + router_logits: Optional[Tuple[torch.FloatTensor]] = None + cross_attentions: Optional[Tuple[torch.FloatTensor, ...]] = None + + +@dataclass +class MoECausalLMOutputWithPast(ModelOutput): + """ + Base class for causal language model (or autoregressive) outputs as well as Mixture of Expert's router hidden + states terms, to train a MoE model. + + Args: + loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided): + Language modeling loss (for next-token prediction). + logits (`torch.FloatTensor` of shape `(batch_size, sequence_length, config.vocab_size)`): + Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax). 
+ past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`): + Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of shape + `(batch_size, num_heads, sequence_length, embed_size_per_head)`) + + Contains pre-computed hidden-states (key and values in the self-attention blocks) that can be used (see + `past_key_values` input) to speed up sequential decoding. + hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): + Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, + + one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`. + + Hidden-states of the model at the output of each layer plus the optional initial embedding outputs. + attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`): + Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length, + sequence_length)`. + + Attentions weights after the attention softmax, used to compute the weighted average in the self-attention + heads. + z_loss (`torch.FloatTensor`, *optional*, returned when `labels` is provided): + z_loss for the sparse modules. + aux_loss (`torch.FloatTensor`, *optional*, returned when `labels` is provided): + aux_loss for the sparse modules. + router_logits (`tuple(torch.FloatTensor)`, *optional*, returned when `output_router_logits=True` is passed or when `config.add_router_probs=True`): + Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, sequence_length, num_experts)`. + + Router logits of the encoder model, useful to compute the auxiliary loss and the z_loss for the sparse + modules. + """ + + loss: Optional[torch.FloatTensor] = None + logits: torch.FloatTensor = None + past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None + hidden_states: Optional[Tuple[torch.FloatTensor, ...]] = None + attentions: Optional[Tuple[torch.FloatTensor, ...]] = None + cross_attentions: Optional[Tuple[torch.FloatTensor, ...]] = None + z_loss: torch.FloatTensor = None + aux_loss: torch.FloatTensor = None + router_logits: Optional[Tuple[torch.FloatTensor]] = None + + +XGLM_START_DOCSTRING = r""" + This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the + library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads + etc.) + + This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass. + Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage + and behavior. + + Parameters: + config ([`XGLMConfig`]): + Model configuration class with all the parameters of the model. Initializing with a config file does not + load the weights associated with the model, only the configuration. Check out the + [`~PreTrainedModel.from_pretrained`] method to load the model weights. +""" +XGLM_INPUTS_DOCSTRING = r""" + Args: + input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`): + Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide + it. + + Indices can be obtained using [`AutoTokenizer`]. 
See [`PreTrainedTokenizer.encode`] and + [`PreTrainedTokenizer.__call__`] for details. + + [What are input IDs?](../glossary#input-ids) + attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*): + Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`: + + - 1 for tokens that are **not masked**, + - 0 for tokens that are **masked**. + + [What are attention masks?](../glossary#attention-mask) + position_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): + Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0, + config.max_position_embeddings - 1]`. + + [What are position IDs?](../glossary#position-ids) + encoder_hidden_states (`torch.FloatTensor` of shape `(batch_size, encoder_sequence_length, hidden_size)`, *optional*): + Sequence of hidden-states at the output of the last layer of the encoder. Used in the cross-attention of + the decoder. + encoder_attention_mask (`torch.LongTensor` of shape `(batch_size, encoder_sequence_length)`, *optional*): + Mask to avoid performing cross-attention on padding tokens indices of encoder input_ids. Mask values + selected in `[0, 1]`: + + - 1 for tokens that are **not masked**, + - 0 for tokens that are **masked**. + + [What are attention masks?](../glossary#attention-mask) + head_mask (`torch.Tensor` of shape `(num_layers, attention_heads)`, *optional*): + Mask to nullify selected heads of the attention modules. Mask values selected in `[0, 1]`: + + - 1 indicates the head is **not masked**, + - 0 indicates the head is **masked**. + + cross_attn_head_mask (`torch.Tensor` of shape `(num_layers, attention_heads)`, *optional*): + Mask to nullify selected heads of the cross-attention modules. Mask values selected in `[0, 1]`: + + - 1 indicates the head is **not masked**, + - 0 indicates the head is **masked**. + + past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`): + Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of shape + `(batch_size, num_heads, sequence_length, embed_size_per_head)`) and 2 additional tensors of shape + `(batch_size, num_heads, encoder_sequence_length, embed_size_per_head)`. + + Contains pre-computed hidden-states (key and values in the self-attention blocks and in the cross-attention + blocks) that can be used (see `past_key_values` input) to speed up sequential decoding. + + If `past_key_values` are used, the user can optionally input only the last `decoder_input_ids` (those that + don't have their past key value states given to this model) of shape `(batch_size, 1)` instead of all + `decoder_input_ids` of shape `(batch_size, sequence_length)`. + inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*): + Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. + This is useful if you want more control over how to convert `input_ids` indices into associated vectors + than the model's internal embedding lookup matrix. + output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned + tensors for more detail. + output_hidden_states (`bool`, *optional*): + Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for + more detail. 
+ return_dict (`bool`, *optional*): + Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. +""" + + +def load_balancing_loss_func( + gate_logits: torch.Tensor, + num_experts: torch.Tensor = None, + top_k=2, + attention_mask: Optional[torch.Tensor] = None, +) -> float: + r""" + Computes auxiliary load balancing loss as in Switch Transformer - implemented in Pytorch. + + See Switch Transformer (https://arxiv.org/abs/2101.03961) for more details. This function implements the loss + function presented in equations (4) - (6) of the paper. It aims at penalizing cases where the routing between + experts is too unbalanced. + + Args: + gate_logits (Union[`torch.Tensor`, Tuple[torch.Tensor]): + Logits from the `gate`, should be a tuple of model.config.num_hidden_layers tensors of + shape [batch_size X sequence_length, num_experts]. + attention_mask (`torch.Tensor`, None): + The attention_mask used in forward function + shape [batch_size X sequence_length] if not None. + num_experts (`int`, *optional*): + Number of experts + + Returns: + The auxiliary loss. + """ + if gate_logits is None or not isinstance(gate_logits, tuple): + return torch.zeros(1) + + if isinstance(gate_logits, tuple): + compute_device = gate_logits[0].device + concatenated_gate_logits = torch.cat( + [layer_gate.to(compute_device) for layer_gate in gate_logits], dim=0 + ) + + routing_weights = torch.nn.functional.softmax(concatenated_gate_logits, dim=-1) + + _, selected_experts = torch.topk(routing_weights, top_k, dim=-1) + + expert_mask = torch.nn.functional.one_hot(selected_experts, num_experts) + + if attention_mask is None: + # Compute the percentage of tokens routed to each experts + tokens_per_expert = torch.mean(expert_mask.float(), dim=0) + + # Compute the average probability of routing to these experts + router_prob_per_expert = torch.mean(routing_weights, dim=0) + else: + batch_size, sequence_length = attention_mask.shape + num_hidden_layers = concatenated_gate_logits.shape[0] // ( + batch_size * sequence_length + ) + + # Compute the mask that masks all padding tokens as 0 with the same shape of expert_mask + expert_attention_mask = ( + attention_mask[None, :, :, None, None] + .expand( + (num_hidden_layers, batch_size, sequence_length, top_k, num_experts) + ) + .reshape(-1, top_k, num_experts) + .to(compute_device) + ) + + # Compute the percentage of tokens routed to each experts + tokens_per_expert = torch.sum( + expert_mask.float() * expert_attention_mask, dim=0 + ) / torch.sum(expert_attention_mask, dim=0) + + # Compute the mask that masks all padding tokens as 0 with the same shape of tokens_per_expert + router_per_expert_attention_mask = ( + attention_mask[None, :, :, None] + .expand((num_hidden_layers, batch_size, sequence_length, num_experts)) + .reshape(-1, num_experts) + .to(compute_device) + ) + + # Compute the average probability of routing to these experts + router_prob_per_expert = torch.sum( + routing_weights * router_per_expert_attention_mask, dim=0 + ) / torch.sum(router_per_expert_attention_mask, dim=0) + + overall_loss = torch.sum(tokens_per_expert * router_prob_per_expert.unsqueeze(0)) + return overall_loss * num_experts + + +class XGLMMLP(nn.Module): + + def __init__(self, config: XGLMConfig): + super().__init__() + self.ffn_dim = config.ffn_dim + self.hidden_dim = config.d_model + + self.fc1 = nn.Linear(self.hidden_dim, self.ffn_dim, bias=False) + self.fc2 = nn.Linear(self.ffn_dim, self.hidden_dim, bias=False) + + self.activation_fn = ACT2FN[config.activation_function] + 
self.dropout = config.dropout + self.activation_dropout = config.activation_dropout + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + hidden_states = self.activation_fn(self.fc1(hidden_states)) + hidden_states = nn.functional.dropout( + hidden_states, p=self.activation_dropout, training=self.training + ) + hidden_states = self.fc2(hidden_states) + # TODO this dropout can be removed if we use the dropout in the XGLMDecoderLayer + hidden_states = nn.functional.dropout( + hidden_states, p=self.dropout, training=self.training + ) + return hidden_states + + def set_mlp_weights(self, fc1, fc2, std=0.02): + """Set the weights of the MLP to the given weights""" + + self.fc1.weight.data = fc1.weight.data.clone() + # norm1_pre = self.fc1.weight.data.norm(2) + self.fc1.weight.data.add_( + torch.normal(mean=0, std=std, size=self.fc1.weight.shape) + ) + # norm1_after = self.fc1.weight.data.norm(2) + + self.fc1.bias.data = fc1.bias.data.clone() + self.fc1.bias.data.add_(torch.normal(mean=0, std=std, size=self.fc1.bias.shape)) + + self.fc2.weight.data = fc2.weight.data.clone() + # norm2_pre = self.fc2.weight.data.norm(2) + self.fc2.weight.data.add_( + torch.normal(mean=0, std=std, size=self.fc2.weight.shape) + ) + # norm2_after = self.fc2.weight.data.norm(2) + self.fc2.bias.data = fc2.bias.data.clone() + self.fc2.bias.data.add_(torch.normal(mean=0, std=std, size=self.fc2.bias.shape)) + + # print("********") + # print(f"Norm1: {norm1_pre} -> {norm1_after}") + # print(f"Norm2: {norm2_pre} -> {norm2_after}") + # print("********") + + +class XGLMSparseMoeBlock(nn.Module): + """ + This implementation is + strictly equivalent to standard MoE with full capacity (no + dropped tokens). It's faster since it formulates MoE operations + in terms of block-sparse operations to accomodate imbalanced + assignments of tokens to experts, whereas standard MoE either + (1) drop tokens at the cost of reduced performance or (2) set + capacity factor to number of experts and thus waste computation + and memory on padding. 
+ """ + + def __init__(self, config): + super().__init__() + self.hidden_dim = config.hidden_size + # self.ffn_dim = config.intermediate_size + self.num_experts = config.num_local_experts + self.top_k = config.num_experts_per_tok + + # gating network + if config.gate_type == "linear": + self.gate = BasicGate(config) + elif config.gate_type == "lang_aware_linear": + self.gate = LanguageAwareGate(config) + # self.gate = nn.Linear(self.hidden_dim, self.num_experts, bias=False) + + self.experts = nn.ModuleList([XGLMMLP(config) for _ in range(self.num_experts)]) + + def forward(self, hidden_states: torch.Tensor, lang_name: str) -> torch.Tensor: + """ """ + batch_size, sequence_length, hidden_dim = hidden_states.shape + hidden_states = hidden_states.view(-1, hidden_dim) + # router_logits: (batch * sequence_length, n_experts) + router_logits = self.gate(hidden_states, lang_name) + + routing_weights = F.softmax(router_logits, dim=1, dtype=torch.float) + if self.top_k == 1: + routing_weights, selected_experts = routing_weights.max(dim=-1, keepdim=True) + else: + routing_weights, selected_experts = torch.topk( + routing_weights, self.top_k, dim=-1 + ) + routing_weights /= routing_weights.sum(dim=-1, keepdim=True) + # we cast back to the input dtype + routing_weights = routing_weights.to(hidden_states.dtype) + + final_hidden_states = torch.zeros( + (batch_size * sequence_length, hidden_dim), + dtype=hidden_states.dtype, + device=hidden_states.device, + ) + + # One hot encode the selected experts to create an expert mask + # this will be used to easily index which expert is going to be sollicitated + expert_mask = torch.nn.functional.one_hot( + selected_experts, num_classes=self.num_experts + ).permute(2, 1, 0) + + # Loop over all available experts in the model and perform the computation on each expert + for expert_idx in range(self.num_experts): + expert_layer = self.experts[expert_idx] + idx, top_x = torch.where(expert_mask[expert_idx]) + + if top_x.shape[0] == 0: + continue + + # in torch it is faster to index using lists than torch tensors + top_x_list = top_x.tolist() + idx_list = idx.tolist() + + # Index the correct hidden states and compute the expert hidden state for + # the current expert. We need to make sure to multiply the output hidden + # states by `routing_weights` on the corresponding tokens (top-1 and top-2) + current_state = hidden_states[None, top_x_list].reshape(-1, hidden_dim) + current_hidden_states = ( + expert_layer(current_state) + * routing_weights[top_x_list, idx_list, None] + ) + + # However `index_add_` only support torch tensors for indexing so we'll use + # the `top_x` tensor here. 
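+            # index_add_ scatter-adds each expert's weighted outputs back into
+            # the flat (batch * seq_len, hidden) buffer at row indices top_x.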
+ final_hidden_states.index_add_( + 0, top_x, current_hidden_states.to(hidden_states.dtype) + ) + final_hidden_states = final_hidden_states.reshape( + batch_size, sequence_length, hidden_dim + ) + return final_hidden_states, router_logits + + def set_mlp_weights(self, fc1, fc2, std=0.02): + """Set the expert (MLP) weights by given weights""" + for i in range(self.num_experts): + self.experts[i].set_mlp_weights(fc1, fc2, std) + + +class XGLMDecoderLayer(nn.Module): + def __init__(self, config: XGLMConfig): + super().__init__() + self.embed_dim = config.d_model + + self.self_attn = XGLMAttention( + embed_dim=self.embed_dim, + num_heads=config.attention_heads, + dropout=config.attention_dropout, + is_decoder=True, + ) + self.dropout = config.dropout + # self.activation_fn = ACT2FN[config.activation_function] + # self.activation_dropout = config.activation_dropout + + if config.add_cross_attention: + self.encoder_attn = XGLMAttention( + embed_dim=self.embed_dim, + num_heads=config.attention_heads, + dropout=config.attention_dropout, + is_decoder=True, + ) + self.encoder_attn_layer_norm = nn.LayerNorm(self.embed_dim) + + self.block_sparse_moe = XGLMSparseMoeBlock(config) + self.self_attn_layer_norm = nn.LayerNorm(self.embed_dim) + self.final_layer_norm = nn.LayerNorm(self.embed_dim) + + # Copied from transformers.models.mbart.modeling_mbart.MBartDecoderLayer.forward + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + encoder_hidden_states: Optional[torch.Tensor] = None, + encoder_attention_mask: Optional[torch.Tensor] = None, + layer_head_mask: Optional[torch.Tensor] = None, + cross_attn_layer_head_mask: Optional[torch.Tensor] = None, + past_key_value: Optional[Tuple[torch.Tensor]] = None, + output_attentions: Optional[bool] = False, + output_router_logits: Optional[bool] = False, + use_cache: Optional[bool] = True, + lang_name: Optional[str] = None, + ) -> torch.Tensor: + """ + Args: + hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)` + attention_mask (`torch.FloatTensor`): attention mask of size + `(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values. + encoder_hidden_states (`torch.FloatTensor`): + cross attention input to the layer of shape `(batch, seq_len, embed_dim)` + encoder_attention_mask (`torch.FloatTensor`): encoder attention mask of size + `(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values. + layer_head_mask (`torch.FloatTensor`): mask for attention heads in a given layer of size + `(encoder_attention_heads,)`. + cross_attn_layer_head_mask (`torch.FloatTensor`): mask for cross-attention heads in a given layer of + size `(decoder_attention_heads,)`. + past_key_value (`Tuple(torch.FloatTensor)`): cached past key and value projection states + output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under + returned tensors for more detail. 
+ """ + residual = hidden_states + hidden_states = self.self_attn_layer_norm(hidden_states) + + # Self Attention + # decoder uni-directional self-attention cached key/values tuple is at positions 1,2 + self_attn_past_key_value = ( + past_key_value[:2] if past_key_value is not None else None + ) + # add present self-attn cache to positions 1,2 of present_key_value tuple + hidden_states, self_attn_weights, present_key_value = self.self_attn( + hidden_states=hidden_states, + past_key_value=self_attn_past_key_value, + attention_mask=attention_mask, + layer_head_mask=layer_head_mask, + output_attentions=output_attentions, + ) + hidden_states = nn.functional.dropout( + hidden_states, p=self.dropout, training=self.training + ) + hidden_states = residual + hidden_states + + # Cross-Attention Block + cross_attn_present_key_value = None + cross_attn_weights = None + if encoder_hidden_states is not None: + residual = hidden_states + hidden_states = self.encoder_attn_layer_norm(hidden_states) + + # cross_attn cached key/values tuple is at positions 3,4 of present_key_value tuple + cross_attn_past_key_value = ( + past_key_value[-2:] if past_key_value is not None else None + ) + hidden_states, cross_attn_weights, cross_attn_present_key_value = ( + self.encoder_attn( + hidden_states=hidden_states, + key_value_states=encoder_hidden_states, + attention_mask=encoder_attention_mask, + layer_head_mask=cross_attn_layer_head_mask, + past_key_value=cross_attn_past_key_value, + output_attentions=output_attentions, + ) + ) + hidden_states = nn.functional.dropout( + hidden_states, p=self.dropout, training=self.training + ) + hidden_states = residual + hidden_states + + # add cross-attn to positions 3,4 of present_key_value tuple + present_key_value = present_key_value + cross_attn_present_key_value + + # Fully Connected + residual = hidden_states + hidden_states = self.final_layer_norm(hidden_states) + + hidden_states, router_logits = self.block_sparse_moe(hidden_states, lang_name) + # hidden_states = self.activation_fn(self.fc1(hidden_states)) + # hidden_states = nn.functional.dropout(hidden_states, p=self.activation_dropout, training=self.training) + # hidden_states = self.fc2(hidden_states) + # hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training) + hidden_states = residual + hidden_states + + outputs = (hidden_states,) + + if output_attentions: + outputs += (self_attn_weights, cross_attn_weights) + + if use_cache: + outputs += (present_key_value,) + + if output_router_logits: + outputs += (router_logits,) + + return outputs + + +@add_start_docstrings( + "The bare XGLM Model transformer outputting raw hidden-states without any specific head on top.", + XGLM_START_DOCSTRING, +) +class XGLMModel(XGLMPreTrainedModel): + """ + Transformer decoder consisting of *config.num_layers* layers. 
Each layer is a [`XGLMDecoderLayer`] + + Args: + config: XGLMConfig + embed_tokens (nn.Embedding): output embedding + """ + + def __init__(self, config: XGLMConfig, embed_tokens: Optional[nn.Embedding] = None): + super().__init__(config) + self.dropout = config.dropout + self.layerdrop = config.layerdrop + self.padding_idx = config.pad_token_id + self.max_target_positions = config.max_position_embeddings + self.embed_scale = math.sqrt(config.d_model) if config.scale_embedding else 1.0 + + if embed_tokens is not None: + self.embed_tokens = embed_tokens + else: + self.embed_tokens = nn.Embedding( + config.vocab_size, config.d_model, self.padding_idx + ) + + self.embed_positions = XGLMSinusoidalPositionalEmbedding( + config.max_position_embeddings, + config.d_model, + config.pad_token_id, + ) + self.layers = nn.ModuleList( + [XGLMDecoderLayer(config) for _ in range(config.num_layers)] + ) + self.layer_norm = nn.LayerNorm(config.d_model) + + self.gradient_checkpointing = False + # Initialize weights and apply final processing + self.post_init() + + def get_input_embeddings(self): + return self.embed_tokens + + def set_input_embeddings(self, value): + self.embed_tokens = value + + @add_start_docstrings_to_model_forward(XGLM_INPUTS_DOCSTRING) + @add_code_sample_docstrings( + checkpoint=_CHECKPOINT_FOR_DOC, + output_type=MoEModelOutputWithPast, + config_class=_CONFIG_FOR_DOC, + ) + def forward( + self, + input_ids: Optional[torch.Tensor] = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.Tensor] = None, + encoder_hidden_states: Optional[torch.Tensor] = None, + encoder_attention_mask: Optional[torch.Tensor] = None, + head_mask: Optional[torch.Tensor] = None, + cross_attn_head_mask: Optional[torch.Tensor] = None, + past_key_values: Optional[List[torch.FloatTensor]] = None, + inputs_embeds: Optional[torch.Tensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + output_router_logits: Optional[bool] = None, + return_dict: Optional[bool] = None, + lang_name: str = None, + ) -> Union[Tuple[torch.Tensor], MoEModelOutputWithPast]: + output_attentions = ( + output_attentions + if output_attentions is not None + else self.config.output_attentions + ) + output_router_logits = ( + output_router_logits + if output_router_logits is not None + else self.config.output_router_logits + ) + + output_hidden_states = ( + output_hidden_states + if output_hidden_states is not None + else self.config.output_hidden_states + ) + use_cache = use_cache if use_cache is not None else self.config.use_cache + return_dict = ( + return_dict if return_dict is not None else self.config.use_return_dict + ) + + # retrieve input_ids and inputs_embeds + if input_ids is not None and inputs_embeds is not None: + raise ValueError( + "You cannot specify both input_ids and inputs_embeds at the same time" + ) + elif input_ids is not None: + self.warn_if_padding_and_no_attention_mask(input_ids, attention_mask) + input_shape = input_ids.size() + input_ids = input_ids.view(-1, input_shape[-1]) + elif inputs_embeds is not None: + input_shape = inputs_embeds.size()[:-1] + else: + raise ValueError("You have to specify either input_ids or inputs_embeds") + + past_key_values_length = ( + past_key_values[0][0].shape[2] if past_key_values is not None else 0 + ) + + if position_ids is None: + position_ids = torch.arange( + past_key_values_length, + input_shape[-1] + past_key_values_length, + dtype=torch.long, + device=( + 
input_ids.device if input_ids is not None else inputs_embeds.device + ), + ) + position_ids = position_ids.unsqueeze(0) + + if inputs_embeds is None: + inputs_embeds = self.embed_tokens(input_ids) * self.embed_scale + + attention_mask = _prepare_4d_causal_attention_mask( + attention_mask, input_shape, inputs_embeds, past_key_values_length + ) + + # expand encoder attention mask + if encoder_hidden_states is not None and encoder_attention_mask is not None: + # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] + encoder_attention_mask = _prepare_4d_attention_mask( + encoder_attention_mask, inputs_embeds.dtype, tgt_len=input_shape[-1] + ) + + hidden_states = inputs_embeds + self.embed_positions( + position_ids, past_key_values_length + ) + hidden_states = nn.functional.dropout( + hidden_states, p=float(self.dropout), training=self.training + ) + + if self.gradient_checkpointing and self.training: + if use_cache: + logger.warning_once( + "`use_cache = True` is incompatible with gradient checkpointing`. Setting `use_cache =" + " False`..." + ) + use_cache = False + + # decoder layers + all_hidden_states = () if output_hidden_states else None + all_self_attns = () if output_attentions else None + all_router_logits = () if output_router_logits else None + all_cross_attentions = ( + () if (output_attentions and encoder_hidden_states is not None) else None + ) + next_decoder_cache = () if use_cache else None + + # check if head_mask/cross_attn_head_mask has a correct number of layers specified if desired + for attn_mask, mask_name in zip( + [head_mask, cross_attn_head_mask], ["head_mask", "cross_attn_head_mask"] + ): + if attn_mask is not None: + if attn_mask.size()[0] != len(self.layers): + raise ValueError( + f"The `{mask_name}` should be specified for {len(self.layers)} layers, but it is for" + f" {head_mask.size()[0]}." 
+ ) + for idx, decoder_layer in enumerate(self.layers): + # add LayerDrop (see https://arxiv.org/abs/1909.11556 for description) + if output_hidden_states: + all_hidden_states += (hidden_states,) + if self.training: + dropout_probability = torch.rand([]) + if dropout_probability < self.layerdrop: + continue + + past_key_value = ( + past_key_values[idx] if past_key_values is not None else None + ) + + if self.gradient_checkpointing and self.training: + layer_outputs = self._gradient_checkpointing_func( + decoder_layer.__call__, + hidden_states, + attention_mask, + encoder_hidden_states, + encoder_attention_mask, + head_mask[idx] if head_mask is not None else None, + ( + cross_attn_head_mask[idx] + if cross_attn_head_mask is not None + else None + ), + None, + output_attentions, + output_router_logits, + use_cache, + ) + else: + layer_outputs = decoder_layer( + hidden_states, + attention_mask=attention_mask, + encoder_hidden_states=encoder_hidden_states, + encoder_attention_mask=encoder_attention_mask, + layer_head_mask=(head_mask[idx] if head_mask is not None else None), + cross_attn_layer_head_mask=( + cross_attn_head_mask[idx] + if cross_attn_head_mask is not None + else None + ), + past_key_value=past_key_value, + output_attentions=output_attentions, + use_cache=use_cache, + output_router_logits=output_router_logits, + lang_name=lang_name, + ) + hidden_states = layer_outputs[0] + + if use_cache: + next_decoder_cache += (layer_outputs[3 if output_attentions else 1],) + + if output_attentions: + all_self_attns += (layer_outputs[1],) + + if encoder_hidden_states is not None: + all_cross_attentions += (layer_outputs[2],) + + if output_router_logits: + all_router_logits += (layer_outputs[-1],) + + hidden_states = self.layer_norm(hidden_states) + + # add hidden states from the last decoder layer + if output_hidden_states: + all_hidden_states += (hidden_states,) + + next_cache = next_decoder_cache if use_cache else None + if not return_dict: + return tuple( + v + for v in [ + hidden_states, + next_cache, + all_hidden_states, + all_self_attns, + all_cross_attentions, + ] + if v is not None + ) + return MoEModelOutputWithPast( + last_hidden_state=hidden_states, + past_key_values=next_cache, + hidden_states=all_hidden_states, + attentions=all_self_attns, + cross_attentions=all_cross_attentions, + router_logits=all_router_logits, + ) + + +@add_start_docstrings( + """ + The XGLM Model transformer with a language modeling head on top (linear layer with weights tied to the input + embeddings). 
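+
+    This MoE variant additionally returns per-layer router logits and an auxiliary load-balancing loss
+    (see `MoECausalLMOutputWithPast`).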
+ """, + XGLM_START_DOCSTRING, +) +class XGLMForCausalLM(XGLMPreTrainedModel): + base_model_prefix = "model" + _tied_weights_keys = ["lm_head.weight"] + + def __init__(self, config): + super().__init__(config) + self.model = XGLMModel(config) + self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False) + + self.router_aux_loss_coef = config.router_aux_loss_coef + self.num_local_experts = config.num_local_experts + self.num_experts_per_tok = config.num_experts_per_tok + + # Initialize weights and apply final processing + self.post_init() + + def get_input_embeddings(self): + return self.model.embed_tokens + + def set_input_embeddings(self, value): + self.model.embed_tokens = value + + def get_output_embeddings(self): + return self.lm_head + + def set_output_embeddings(self, new_embeddings): + self.lm_head = new_embeddings + + @add_start_docstrings_to_model_forward(XGLM_INPUTS_DOCSTRING) + @add_code_sample_docstrings( + checkpoint=_CHECKPOINT_FOR_DOC, + output_type=MoECausalLMOutputWithPast, + config_class=_CONFIG_FOR_DOC, + ) + def forward( + self, + input_ids: Optional[torch.Tensor] = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.Tensor] = None, + encoder_hidden_states: Optional[torch.Tensor] = None, + encoder_attention_mask: Optional[torch.Tensor] = None, + head_mask: Optional[torch.Tensor] = None, + cross_attn_head_mask: Optional[torch.Tensor] = None, + past_key_values: Optional[List[torch.FloatTensor]] = None, + inputs_embeds: Optional[torch.Tensor] = None, + labels: Optional[torch.Tensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + output_router_logits: Optional[bool] = None, + return_dict: Optional[bool] = None, + lang_name=None, + ) -> Union[Tuple[torch.Tensor], MoECausalLMOutputWithPast]: + r""" + labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): + Labels for computing the masked language modeling loss. Indices should either be in `[0, ..., + config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored + (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`. 
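+        output_router_logits (`bool`, *optional*):
+            Whether or not to return the router logits of all MoE layers and to add the auxiliary load-balancing
+            loss (scaled by `router_aux_loss_coef`) to the language modeling loss when `labels` are provided.
+        lang_name (`str`, *optional*):
+            Language identifier forwarded to the MoE blocks (assumed here to be used for language-aware routing or
+            per-language bookkeeping).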
+ """ + + output_attentions = ( + output_attentions + if output_attentions is not None + else self.config.output_attentions + ) + output_router_logits = ( + output_router_logits + if output_router_logits is not None + else self.config.output_router_logits + ) + + output_hidden_states = ( + output_hidden_states + if output_hidden_states is not None + else self.config.output_hidden_states + ) + return_dict = ( + return_dict if return_dict is not None else self.config.use_return_dict + ) + + # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn) + outputs = self.model( + input_ids=input_ids, + attention_mask=attention_mask, + position_ids=position_ids, + encoder_hidden_states=encoder_hidden_states, + encoder_attention_mask=encoder_attention_mask, + head_mask=head_mask, + cross_attn_head_mask=cross_attn_head_mask, + past_key_values=past_key_values, + inputs_embeds=inputs_embeds, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + output_router_logits=output_router_logits, + return_dict=return_dict, + lang_name=lang_name, + ) + + logits = self.lm_head(outputs[0]) + + loss = None + if labels is not None: + # shift labels and add a pad token to the end + shift_labels = labels.new_zeros(labels.shape) + shift_labels[:, :-1] = labels[:, 1:].clone() + shift_labels[:, -1] = self.config.pad_token_id + + loss_fct = CrossEntropyLoss() + loss = loss_fct( + logits.view(-1, self.config.vocab_size), shift_labels.view(-1) + ) + + aux_loss = None + if output_router_logits: + aux_loss = load_balancing_loss_func( + outputs.router_logits if return_dict else outputs[-1], + self.num_local_experts, + self.num_experts_per_tok, + attention_mask, + ) + if labels is not None: + loss += self.router_aux_loss_coef * aux_loss.to( + loss.device + ) # make sure to reside in the same device + + if not return_dict: + output = (logits,) + outputs[1:] + return (loss,) + output if loss is not None else output + + # return CausalLMOutputWithCrossAttentions( + return MoECausalLMOutputWithPast( + loss=loss, + aux_loss=aux_loss, + logits=logits, + past_key_values=outputs.past_key_values, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + cross_attentions=outputs.cross_attentions, + router_logits=outputs.router_logits, + ) + + def prepare_inputs_for_generation( + self, + input_ids, + past_key_values=None, + attention_mask=None, + use_cache=None, + output_router_logits=False, + **kwargs, + ): + if past_key_values is not None: + past_length = past_key_values[0][0].shape[2] + + # Some generation methods already pass only the last input ID + if input_ids.shape[1] > past_length: + remove_prefix_length = past_length + else: + # Default to old behavior: keep only final ID + remove_prefix_length = input_ids.shape[1] - 1 + + input_ids = input_ids[:, remove_prefix_length:] + + position_ids = kwargs.get("position_ids", None) + if attention_mask is not None and position_ids is None: + # create position_ids on the fly for batch generation + position_ids = attention_mask.long().cumsum(-1) - 1 + position_ids.masked_fill_(attention_mask == 0, 1) + if past_key_values: + position_ids = position_ids[:, -input_ids.shape[1] :] + else: + position_ids = None + # if model is used as a decoder in encoder-decoder model, the decoder attention mask is created on the fly + if attention_mask is None: + attention_mask = input_ids.new_ones(input_ids.shape) + # first step, decoder_cached_states are empty + return { + "input_ids": input_ids, # encoder_outputs is 
defined. input_ids not needed
+            "attention_mask": attention_mask,
+            "position_ids": position_ids,
+            "past_key_values": past_key_values,
+            "use_cache": use_cache,
+            "output_router_logits": output_router_logits,
+        }
+
+    @staticmethod
+    def _reorder_cache(past_key_values, beam_idx):
+        reordered_past = ()
+        for layer_past in past_key_values:
+            reordered_past += (
+                tuple(
+                    past_state.index_select(0, beam_idx.to(past_state.device))
+                    for past_state in layer_past
+                ),
+            )
+        return reordered_past
+
+
+def set_moe_layers(model: nn.Module, config: XGLMConfig):
+    """Replaces each of the model's decoder layers with an MoE layer and sets the expert (MLP) weights to the original MLP weights."""
+    new_layers = []
+    for i in range(config.num_layers):
+        layer = model.model.layers[i]
+        moe_layer = XGLMDecoderLayer(config)
+        moe_layer.block_sparse_moe.set_mlp_weights(
+            layer.fc1, layer.fc2, config.expert_init_std
+        )
+        new_layers.append(moe_layer)
+    model.model.layers = nn.ModuleList(new_layers)
+
+
+def freeze_parameters(model, args):
+    """Freezes the non-MoE parameters if requested via `args.freeze_non_moe_params`."""
+    if args.freeze_non_moe_params:
+        for name, param in model.named_parameters():
+            name = name.lower()
+            if "expert" not in name and "gate" not in name and "lm_head" not in name:
+                param.requires_grad = False
+            else:
+                param.requires_grad = True
+
+        logger.info("non-MoE parameters are frozen (experts, gates and lm_head remain trainable)!")
+
+        model.lm_head.weight.requires_grad = True
+        # model.lm_head.bias.requires_grad = True
+
+        # for name, param in model.named_parameters():
+        #     if "block_sparse_moe" not in name:
+        #         print(f"{name}: {param.requires_grad}")
+
+
+def copy_model_weights(pretrained_model: nn.Module, moe_model: XGLMForCausalLM, config: XGLMConfig):
+    """Copies the pretrained dense weights into the MoE model and sets the expert (MLP) weights to the original MLP weights."""
+    state_dicts = pretrained_model.state_dict()
+    moe_model.load_state_dict(state_dicts, strict=False)
+
+    # new_layers = []
+    for i in range(config.num_layers):
+        layer = pretrained_model.model.layers[i]
+        moe_model.model.layers[i].block_sparse_moe.set_mlp_weights(
+            layer.fc1, layer.fc2, config.init_std
+        )
+        # moe_layer = XGLMDecoderLayer(config)
+        # moe_layer.block_sparse_moe.set_mlp_weights(layer.fc1, layer.fc2, config.init_std)
+        # new_layers.append(moe_layer)
+        # moe_model.model.layers = nn.ModuleList(new_layers)
diff --git a/moe.md b/moe.md
new file mode 100644
index 00000000..bb67501e
--- /dev/null
+++ b/moe.md
@@ -0,0 +1,58 @@
+# MoE Env Setup
+
+TL;DR: you need to install megablocks for MoEs; just use the environment `/store/swissai/a06/containers/nanotron_moe/nanotron_moe.toml` :)
+
+The setup is documented in that folder on the cluster. The container image was built as follows:
+
+```Dockerfile
+FROM nvcr.io/nvidia/pytorch:24.05-py3
+
+# setup
+RUN apt-get update && apt-get install -y \
+    python3-pip \
+    python3-venv \
+    git tmux htop nvtop \
+    && apt-get clean && rm -rf /var/lib/apt/lists/*
+
+RUN pip install --upgrade pip setuptools==69.5.1
+
+# Update flash-attn.
+RUN pip install --upgrade --no-build-isolation flash-attn==2.5.8
+# Install the rest of the dependencies.
+RUN pip install \ + datasets \ + transformers \ + wandb \ + dacite \ + pyyaml \ + numpy \ + packaging \ + safetensors \ + sentencepiece \ + tqdm + +WORKDIR /workspace +RUN git clone https://github.com/swiss-ai/nanotron.git +WORKDIR /workspace/nanotron +RUN pip install -e .[nanosets] + +RUN pip install megablocks==0.5.1 stanford-stk==0.7.1 --no-deps +``` + +The env `nanotron-moe.toml` with content: +``` +image = "/store/swissai/a06/containers/nanotron_moe/nanotron_moe.sqsh" + +mounts = ["/capstor", "/users", "/store"] +workdir = "/workspace/nanotron" +writable = true + +[annotations] +com.hooks.aws_ofi_nccl.enabled = "true" +com.hooks.aws_ofi_nccl.variant = "cuda12" + +[env] +FI_CXI_DISABLE_HOST_REGISTER = "1" +FI_MR_CACHE_MONITOR = "userfaultfd" +NCCL_DEBUG = "INFO" +``` diff --git a/src/nanotron/config/config.py b/src/nanotron/config/config.py index d2b39441..ad80b82a 100644 --- a/src/nanotron/config/config.py +++ b/src/nanotron/config/config.py @@ -11,7 +11,12 @@ from yaml.loader import SafeLoader from nanotron.config.lighteval_config import LightEvalConfig -from nanotron.config.models_config import ExistingCheckpointInit, NanotronConfigs, RandomInit, SpectralMupInit +from nanotron.config.models_config import ( + ExistingCheckpointInit, + NanotronConfigs, + RandomInit, + SpectralMupInit, +) from nanotron.config.parallelism_config import ParallelismArgs from nanotron.config.utils_config import ( RecomputeGranularity, @@ -100,8 +105,12 @@ def __post_init__(self): self.dataset_folder = [self.dataset_folder] self.dataset_weights = [1] elif isinstance(self.dataset_folder, List): # Case 2: > 1 Dataset folder - self.dataset_weights = None # Set to None so we consume all the samples randomly - elif isinstance(self.dataset_folder, dict): # Case 3: dict with > 1 dataset_folder and weights + self.dataset_weights = ( + None # Set to None so we consume all the samples randomly + ) + elif isinstance( + self.dataset_folder, dict + ): # Case 3: dict with > 1 dataset_folder and weights tmp_dataset_folder = self.dataset_folder.copy() self.dataset_folder = list(tmp_dataset_folder.keys()) self.dataset_weights = list(tmp_dataset_folder.values()) @@ -111,7 +120,9 @@ def __post_init__(self): class MultilingualNanosetDatasetsArgs: training_folder: Union[str, dict, List[str]] validation_folder: Union[str, List[str]] - languages: List[str] # NOTE(tj.solergibert) Required for 1. Aggregating the result 2. Reporting to WANDB + languages: List[ + str + ] # NOTE(tj.solergibert) Required for 1. Aggregating the result 2. 
Reporting to WANDB def __post_init__(self): if isinstance(self.training_folder, str): # Case 1: 1 Dataset folder @@ -119,8 +130,45 @@ def __post_init__(self): self.validation_folder = [self.validation_folder] self.dataset_weights = [1] elif isinstance(self.training_folder, List): # Case 2: > 1 Dataset folder - self.dataset_weights = None # Set to None so we consume all the samples randomly - elif isinstance(self.training_folder, dict): # Case 3: dict with > 1 training_folder and weights + self.dataset_weights = ( + None # Set to None so we consume all the samples randomly + ) + elif isinstance( + self.training_folder, dict + ): # Case 3: dict with > 1 training_folder and weights + tmp_training_folder = self.training_folder.copy() + self.training_folder = list(tmp_training_folder.keys()) + self.dataset_weights = list(tmp_training_folder.values()) + + assert len(self.training_folder) == len( + self.languages + ), f"The sizes of training_folder and languages mismatch ({len(self.training_folder)} vs {len(self.languages)})" + + assert len(self.training_folder) == len( + self.validation_folder + ), f"The sizes of training_folder and validation_folder mismatch ({len(self.training_folder)} vs {len(self.validation_folder)})" + + +@dataclass +class MultilingualNanosetDatasetsArgs: + training_folder: Union[str, dict, List[str]] + validation_folder: Union[str, List[str]] + languages: List[ + str + ] # NOTE(tj.solergibert) Required for 1. Aggregating the result 2. Reporting to WANDB + + def __post_init__(self): + if isinstance(self.training_folder, str): # Case 1: 1 Dataset folder + self.training_folder = [self.training_folder] + self.validation_folder = [self.validation_folder] + self.dataset_weights = [1] + elif isinstance(self.training_folder, List): # Case 2: > 1 Dataset folder + self.dataset_weights = ( + None # Set to None so we consume all the samples randomly + ) + elif isinstance( + self.training_folder, dict + ): # Case 3: dict with > 1 training_folder and weights tmp_training_folder = self.training_folder.copy() self.training_folder = list(tmp_training_folder.keys()) self.dataset_weights = list(tmp_training_folder.values()) @@ -138,7 +186,12 @@ def __post_init__(self): class DataArgs: """Arguments related to the data and data files processing""" - dataset: Union[PretrainDatasetsArgs, NanosetDatasetsArgs, MultilingualNanosetDatasetsArgs] + dataset: Union[ + PretrainDatasetsArgs, + NanosetDatasetsArgs, + MultilingualNanosetDatasetsArgs, + MultilingualNanosetDatasetsArgs, + ] seed: Optional[int] num_loading_workers: Optional[int] = 1 @@ -157,7 +210,9 @@ class DatasetStageArgs: def __post_init__(self): if self.start_training_step < 0: - raise ValueError(f"training_steps should be a positive integer and not {self.start_training_step}") + raise ValueError( + f"training_steps should be a positive integer and not {self.start_training_step}" + ) @dataclass @@ -378,13 +433,19 @@ def __post_init__(self): if self.profiler is not None and self.profiler.profiler_export_path is not None: assert self.tokens.train_steps < 10 - if self.optimizer is not None and self.optimizer.learning_rate_scheduler.lr_decay_steps is None: + if ( + self.optimizer is not None + and self.optimizer.learning_rate_scheduler.lr_decay_steps is None + ): self.optimizer.learning_rate_scheduler.lr_decay_steps = ( - self.tokens.train_steps - self.optimizer.learning_rate_scheduler.lr_warmup_steps + self.tokens.train_steps + - self.optimizer.learning_rate_scheduler.lr_warmup_steps ) if self.data_stages is not None: - self.data_stages = 
sorted(self.data_stages, key=lambda stage: stage.start_training_step) + self.data_stages = sorted( + self.data_stages, key=lambda stage: stage.start_training_step + ) names = [stage.name for stage in self.data_stages] training_steps = [stage.start_training_step for stage in self.data_stages] assert any( @@ -393,7 +454,9 @@ def __post_init__(self): for stage in self.data_stages: if names.count(stage.name) > 1: - raise ValueError(f"Each stage should have unique names and not {names}") + raise ValueError( + f"Each stage should have unique names and not {names}" + ) if training_steps.count(stage.start_training_step) > 1: raise ValueError( @@ -402,13 +465,29 @@ def __post_init__(self): # NOTE: must order the stages by start_training_step from lowest to highest assert all( - self.data_stages[i].start_training_step < self.data_stages[i + 1].start_training_step + self.data_stages[i].start_training_step + < self.data_stages[i + 1].start_training_step for i in range(len(self.data_stages) - 1) ), "The stages are not sorted by start_training_step in increasing order" # NOTE(tj.solergibert) As we are reporting the training & validation metrics together, we # must comply with val_check_interval % iteration_step_info_interval = 0 - if not self.tokens.val_check_interval % self.logging.iteration_step_info_interval == 0: + if ( + not self.tokens.val_check_interval + % self.logging.iteration_step_info_interval + == 0 + ): + raise ValueError( + f"It is necessary to run the validation stage during a logging step. Validation interval: {self.tokens.val_check_interval}, Logging interval: {self.logging.iteration_step_info_interval}" + ) + + # NOTE(tj.solergibert) As we are reporting the training & validation metrics together, we + # must comply with val_check_interval % iteration_step_info_interval = 0 + if ( + not self.tokens.val_check_interval + % self.logging.iteration_step_info_interval + == 0 + ): raise ValueError( f"It is necessary to run the validation stage during a logging step. 
Validation interval: {self.tokens.val_check_interval}, Logging interval: {self.logging.iteration_step_info_interval}" ) @@ -419,7 +498,11 @@ def __post_init__(self): @property def global_batch_size(self): - return self.tokens.micro_batch_size * self.tokens.batch_accumulation_per_replica * self.parallelism.dp + return ( + self.tokens.micro_batch_size + * self.tokens.batch_accumulation_per_replica + * self.parallelism.dp + ) def save_as_yaml(self, file_path: str): config_dict = serialize(self) @@ -435,7 +518,10 @@ def as_dict(self) -> dict: def get_config_from_dict( - config_dict: dict, config_class: Type = Config, skip_unused_config_keys: bool = False, skip_null_keys: bool = False + config_dict: dict, + config_class: Type = Config, + skip_unused_config_keys: bool = False, + skip_null_keys: bool = False, ): """Get a config object from a dictionary @@ -448,12 +534,18 @@ def get_config_from_dict( if skip_unused_config_keys: logger.warning("skip_unused_config_keys set") config_dict = { - field.name: config_dict[field.name] for field in fields(config_class) if field.name in config_dict + field.name: config_dict[field.name] + for field in fields(config_class) + if field.name in config_dict } if skip_null_keys: logger.warning("Skip_null_keys set") config_dict = { - k: {kk: vv for kk, vv in v.items() if vv is not None} if isinstance(v, dict) else v + k: ( + {kk: vv for kk, vv in v.items() if vv is not None} + if isinstance(v, dict) + else v + ) for k, v in config_dict.items() if v is not None } diff --git a/src/nanotron/config/lighteval_config.py b/src/nanotron/config/lighteval_config.py index 208091c9..3fbcac49 100644 --- a/src/nanotron/config/lighteval_config.py +++ b/src/nanotron/config/lighteval_config.py @@ -61,6 +61,8 @@ class LightEvalTasksArgs: multichoice_continuations_start_space: Optional[bool] = None no_multichoice_continuations_start_space: Optional[bool] = None + langs: Optional[str] = "en" + @dataclass class LightEvalWandbLoggerConfig: @@ -92,3 +94,34 @@ class LightEvalConfig: tasks: Optional[LightEvalTasksArgs] = None logging: Optional[LightEvalLoggingArgs] = None wandb: Optional[LightEvalWandbLoggerConfig] = None + + +# batch_size: 16 +# checkpoints_path: null +# generation: null +# logging: +# hub_repo_details: null +# hub_repo_results: null +# hub_repo_tensorboard: null +# local_output_path: /capstor/scratch/cscs/$USER/multilingual_data_mixture/eval_results +# push_details_to_hub: false +# push_results_to_hub: false +# push_results_to_tensorboard: true +# tensorboard_metric_prefix: eval_ +# wandb: null +# parallelism: +# dp: 1 +# pp: 1 +# pp_engine: 1f1b +# tp: 1 +# tp_linear_async_communication: false +# tp_mode: ALL_REDUCE +# tasks: +# custom_tasks: /capstor/scratch/cscs/$USER/lighteval-multilingual/src/lighteval/community_tasks/multilingual/configs/multilingual.py +# dataset_loading_processes: 1 +# max_samples: 10 +# multichoice_continuations_start_space: null +# no_multichoice_continuations_start_space: null +# num_fewshot_seeds: 1 +# tasks: x_nli +# langs: en,fr \ No newline at end of file diff --git a/src/nanotron/config/models_config.py b/src/nanotron/config/models_config.py index af7db5cc..72158651 100644 --- a/src/nanotron/config/models_config.py +++ b/src/nanotron/config/models_config.py @@ -170,7 +170,10 @@ def as_starcoder2(self) -> Starcoder2Config: if "_is_using_mup" in config: del config["_is_using_mup"] return Starcoder2Config( - grouped_query=True, num_kv_heads=self.num_attention_heads, use_rotary_embeddings=False, **config + grouped_query=True, + 
num_kv_heads=self.num_attention_heads, + use_rotary_embeddings=False, + **config, ) @property @@ -178,4 +181,85 @@ def n_inner(self): return self.intermediate_size -NanotronConfigs = LlamaConfig | Starcoder2Config | GPT3Config +@dataclass +class GPT3MoEConfig: + """Configuration for a GPT3 __MoE__ model""" + + activation_function: str = "gelu" + attn_pdrop: float = 0.1 + embd_pdrop: float = 0.1 + eos_token_id: int = 49152 + hidden_size: int = 2048 + intermediate_size: Optional[int] = None + layer_norm_epsilon: float = 1e-05 + max_position_embeddings: int = 4096 + num_attention_heads: int = 16 + num_hidden_layers: int = 24 + resid_pdrop: float = 0.1 + scale_attention_softmax_in_fp32: bool = True + scale_attn_weights: bool = True + vocab_size: int = 49280 + sinusoidal_position_embedding: bool = True + position_embedding_offset: int = 2 + use_spda: bool = False + act_pdrop: float = 0.0 + scale_embedding: bool = True + # MoE specific + is_moe: bool = True + moe_num_experts: int = 1 + num_experts_per_tok: int = 1 + moe_loss_weight: float = 0.01 + moe_z_loss_weight: float = 0.001 + moe_glu: bool = False + + def as_gpt3(self) -> GPT3Config: + config = dict(**vars(self)) + + # Moe + del config["is_moe"] + del config["moe_num_experts"] + del config["num_experts_per_tok"] + del config["moe_loss_weight"] + del config["moe_z_loss_weight"] + del config["moe_glu"] + + if "_is_using_mup" in config: + del config["_is_using_mup"] + return GPT3Config(**config) + + def as_starcoder2(self) -> Starcoder2Config: + # same as gpt3 conversion above + config = dict(**vars(self)) + del config["sinusoidal_position_embedding"] + del config["use_spda"] + del config["position_embedding_offset"] + del config["act_pdrop"] + del config["scale_embedding"] + + # Moe + del config["is_moe"] + del config["moe_num_experts"] + del config["num_experts_per_tok"] + del config["moe_loss_weight"] + del config["moe_z_loss_weight"] + del config["moe_glu"] + + if "_is_using_mup" in config: + del config["_is_using_mup"] + return Starcoder2Config( + grouped_query=True, + num_kv_heads=self.num_attention_heads, + use_rotary_embeddings=False, + **config, + ) + + @property + def n_inner(self): + return self.intermediate_size + + @property + def hidden_act(self): + return self.activation_function + + +NanotronConfigs = LlamaConfig | Starcoder2Config | GPT3Config | GPT3MoEConfig diff --git a/src/nanotron/data/multilingual_nanoset.py b/src/nanotron/data/multilingual_nanoset.py index 8eec5549..9a8a6c68 100644 --- a/src/nanotron/data/multilingual_nanoset.py +++ b/src/nanotron/data/multilingual_nanoset.py @@ -38,7 +38,9 @@ def __init__( # Checks if isinstance(dataset_folders, str): - warnings.warn("dataset_folders should be of type List[str] but str was provided. Converting to List[str]") + warnings.warn( + "dataset_folders should be of type List[str] but str was provided. Converting to List[str]" + ) dataset_folders = [dataset_folders] # Init @@ -63,7 +65,9 @@ def __init__( # Build Nanoset Index ## To build the index we need the length of each dataset - self.dataset_lengths = [len(datatrove_dataset) for datatrove_dataset in self.datatrove_datasets] + self.dataset_lengths = [ + len(datatrove_dataset) for datatrove_dataset in self.datatrove_datasets + ] ## Set dataset weights if ( dataset_weights is None @@ -76,10 +80,14 @@ def __init__( ), f"Specified {len(self.dataset_weights)} weights but {len(dataset_folders)} datasets were provided." 
## Build dataset index and dataset sample index if is_valid: # Valid MultilingualNanoset - self.dataset_index, self.dataset_sample_index = build_valid_nanoset_index(self.dataset_lengths) + self.dataset_index, self.dataset_sample_index = build_valid_nanoset_index( + self.dataset_lengths + ) else: # Train MultilingualNanoset - self.dataset_index, self.dataset_sample_index = self.build_train_nanoset_index() + self.dataset_index, self.dataset_sample_index = ( + self.build_train_nanoset_index() + ) self.print_nanoset_info() @@ -118,7 +126,9 @@ def build_train_nanoset_index(self) -> np.ndarray: num_epochs = int(self.train_split_num_samples / samples_per_epoch) + 1 # Build the dataset indexes for 1 epoch dataset_index, dataset_sample_index = build_train_nanoset_index_helper( - n_samples=samples_per_epoch, weights=self.dataset_weights, dataset_sizes=self.dataset_lengths + n_samples=samples_per_epoch, + weights=self.dataset_weights, + dataset_sizes=self.dataset_lengths, ) # Shuffle the indexes the same way numpy_random_state = np.random.RandomState(self.random_seed) @@ -127,7 +137,9 @@ def build_train_nanoset_index(self) -> np.ndarray: numpy_random_state.shuffle(dataset_sample_index) # Concatenate num_epochs the shuffled indexes dataset_index = np.concatenate([dataset_index for _ in range(num_epochs)]) - dataset_sample_index = np.concatenate([dataset_sample_index for _ in range(num_epochs)]) + dataset_sample_index = np.concatenate( + [dataset_sample_index for _ in range(num_epochs)] + ) # Just keep the necessary samples dataset_index = dataset_index[: self.train_split_num_samples] dataset_sample_index = dataset_sample_index[: self.train_split_num_samples] @@ -150,7 +162,9 @@ def print_nanoset_info(self): ) # Print samples from each dataset + weight - dataset_sample_count = count_dataset_indexes(self.dataset_index, len(self.dataset_folders)) + dataset_sample_count = count_dataset_indexes( + self.dataset_index, len(self.dataset_folders) + ) for index, sample_count in enumerate(dataset_sample_count): log_rank( f"> Total number of {'validation' if self.is_valid else 'training'} samples from the {self.dataset_folders[index]} dataset: {sample_count} ({round(normalize(dataset_sample_count).tolist()[index], 2)})", @@ -172,7 +186,9 @@ def build_train_nanoset_index_helper( """ # Create empty arrays for dataset indices and dataset sample indices dataset_index = np.empty((n_samples,), dtype="uint") - dataset_sample_index = np.empty((n_samples,), dtype="long") # Supports dataset with up to 2**64 samples + dataset_sample_index = np.empty( + (n_samples,), dtype="long" + ) # Supports dataset with up to 2**64 samples # Initialize buffer for number of samples used for each dataset current_samples = np.zeros((len(weights),), dtype="long") @@ -189,7 +205,9 @@ def build_train_nanoset_index_helper( # Assign the dataset index and update the sample index dataset_index[sample_idx] = max_error_index - dataset_sample_index[sample_idx] = current_samples[max_error_index] % dataset_sizes[max_error_index] + dataset_sample_index[sample_idx] = ( + current_samples[max_error_index] % dataset_sizes[max_error_index] + ) # Update the total samples for the selected dataset current_samples[max_error_index] += 1 @@ -209,4 +227,6 @@ def build_valid_nanoset_index(dataset_lengths: List[int]) -> np.ndarray: dataset_index.extend([i] * length) dataset_sample_index.extend(range(length)) - return np.array(dataset_index, dtype="uint"), np.array(dataset_sample_index, dtype="long") + return np.array(dataset_index, dtype="uint"), np.array( + 
dataset_sample_index, dtype="long"
+    )
diff --git a/src/nanotron/models/gpt3_moe.py b/src/nanotron/models/gpt3_moe.py
new file mode 100644
index 00000000..06a624ac
--- /dev/null
+++ b/src/nanotron/models/gpt3_moe.py
@@ -0,0 +1,384 @@
+"""PyTorch GPT-3 MoE model."""
+
+from contextlib import contextmanager
+from typing import Dict, Optional, Union
+
+import torch
+from torch import nn
+
+from nanotron import distributed as dist
+from nanotron.config import GPT3Config, GPT3MoEConfig, ParallelismArgs
+from nanotron.models import gpt3
+from nanotron.models.gpt3 import CausalSelfAttention, GPT3ForTraining, GPT3Model, dropout_add_fused_train
+from nanotron.models.gpt3 import GPTBlock as GPT3Block
+from nanotron.models.moe import (
+    dMoE,
+)
+from nanotron.nn.layer_norm import TritonLayerNorm
+from nanotron.parallel import ParallelContext
+from nanotron.parallel.pipeline_parallel.block import PipelineBlock
+from nanotron.parallel.pipeline_parallel.tensor_pointer import TensorPointer
+from nanotron.parallel.tensor_parallel.enum import TensorParallelLinearMode
+from nanotron.parallel.tensor_parallel.nn import TensorParallelColumnLinear
+from nanotron.random import RandomStates, branch_random_state
+
+
+@contextmanager
+def replace_moe_decoder(gpt3config: GPT3MoEConfig):
+    orig = gpt3.PipelineBlock
+    try:
+
+        def create_pp_block(module_builder, module_kwargs, **kwargs):
+            if module_builder is GPT3Block:
+                # GPT3Model is trying to instantiate a dense GPT3 GPTBlock.
+                # Let's return a PipelineBlock with a GPT3MoEBlock instead.
+                # This also requires replacing the dense GPT3 config with the MoE config.
+                module_kwargs["config"] = gpt3config
+                return orig(module_builder=GPT3MoEBlock, module_kwargs=module_kwargs, **kwargs)
+            # Else, they are setting up other modules, which we also want unchanged.
+ return orig(module_builder=module_builder, module_kwargs=module_kwargs, **kwargs) + + gpt3.PipelineBlock = create_pp_block + yield + finally: + gpt3.PipelineBlock = orig + + +@contextmanager +def replace_gpt3_moe_model(gpt3moeconfig: GPT3MoEConfig): + orig = gpt3.GPT3Model + try: + + def create_moe_model( + config: GPT3Config, + parallel_context: ParallelContext, + parallel_config: Optional[ParallelismArgs], + random_states: RandomStates, + ): + return GPT3MoEModel(gpt3moeconfig, parallel_context, parallel_config, random_states) + + gpt3.GPT3Model = create_moe_model + yield + finally: + gpt3.GPT3Model = orig + + +class GPT3MoEBlock(nn.Module): + def __init__( + self, + config: GPT3MoEConfig, + parallel_config: Optional[ParallelismArgs], + parallel_context: ParallelContext, + tp_pg: dist.ProcessGroup, + random_states: RandomStates, + layer_idx: int, + ): + super(GPT3MoEBlock, self).__init__() + self.ln_1 = TritonLayerNorm(config.hidden_size, eps=config.layer_norm_epsilon) + self.attn = CausalSelfAttention( + config=config, parallel_config=parallel_config, tp_pg=tp_pg, layer_idx=layer_idx + ) + self.attn_dropout = config.attn_pdrop + + self.ln_2 = TritonLayerNorm(config.hidden_size, eps=config.layer_norm_epsilon) + + self.ff = dMoE( + config=config, + parallel_config=parallel_config, + parallel_context=parallel_context, + ) + self.ff_dropout = config.resid_pdrop + self.random_states = random_states + self.tp_mode = parallel_config.tp_mode if parallel_config is not None else TensorParallelLinearMode.ALL_REDUCE + + def forward( + self, + hidden_states: torch.Tensor | TensorPointer, + sequence_mask: torch.Tensor | TensorPointer, + aux_losses: Dict[str, Union[torch.Tensor, TensorPointer]] = None, + ) -> dict[str, torch.Tensor | TensorPointer]: + + residual = hidden_states + hidden_states = self.ln_1(hidden_states) + # hidden_states = torch.arange(hidden_states.numel()).to(hidden_states.device).to(hidden_states.dtype).view(hidden_states.size()) + output = self.attn(hidden_states=hidden_states, sequence_mask=sequence_mask) + hidden_states = output["hidden_states"] + # return {"hidden_states": hidden_states, "sequence_mask": sequence_mask} + + if self.training: + with branch_random_state( + self.random_states, "tp_synced", enabled=self.tp_mode is TensorParallelLinearMode.ALL_REDUCE + ): + hidden_states = dropout_add_fused_train(hidden_states, residual=residual, prob=self.attn_dropout) + else: + # No need for random state context manager + hidden_states = hidden_states + residual + + residual = hidden_states + hidden_states = self.ln_2(hidden_states) + mlp_output = self.ff(hidden_states=hidden_states) + hidden_states = mlp_output["hidden_states"] + + if aux_losses is not None: + for key, value in mlp_output.items(): + if key != "hidden_states": + aux_losses[key] = aux_losses[key] + value + + if self.training: + with branch_random_state( + self.random_states, "tp_synced", enabled=self.tp_mode is TensorParallelLinearMode.ALL_REDUCE + ): + hidden_states = dropout_add_fused_train(hidden_states, residual=residual, prob=self.ff_dropout) + else: + # No need for random state context manager + hidden_states = hidden_states + residual + + return {"hidden_states": hidden_states, "sequence_mask": output["sequence_mask"], "aux_losses": aux_losses} + + +class GPT3MoEModel(GPT3Model): + def __init__( + self, + config: GPT3MoEConfig, + parallel_context: ParallelContext, + parallel_config: Optional[ParallelismArgs], + random_states: RandomStates, + ): + with replace_moe_decoder(config): + 
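+            # The context manager above patches gpt3.PipelineBlock so that the parent constructor builds
+            # GPT3MoEBlock decoder layers (with the MoE config) instead of dense GPT3Block layers.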
super().__init__(config.as_gpt3(), parallel_context, parallel_config, random_states) + + # need to adapt the decoder list because we pass the aux_losses around + self.decoder = nn.ModuleList( + [ + PipelineBlock( + p2p=self.p2p, + module_builder=GPT3MoEBlock, + module_kwargs={ + "config": config, + "parallel_config": parallel_config, + "tp_pg": parallel_context.tp_pg, + "random_states": random_states, + "parallel_context": parallel_context, + "layer_idx": layer_idx, + }, + module_input_keys={"hidden_states", "sequence_mask", "aux_losses"}, + module_output_keys={"hidden_states", "sequence_mask", "aux_losses"}, + ) + for layer_idx in range(config.num_hidden_layers) + ] + ) + + def forward( + self, + input_ids: torch.Tensor | TensorPointer, # [batch_size, seq_length] + input_mask: torch.Tensor | TensorPointer, # [batch_size, seq_length] + aux_losses: Optional[Dict[str, Union[torch.Tensor, TensorPointer]]] = None, + ): + # all tensors are optional as most ranks don't need anything from the dataloader. + + input_embeds = ( + self.token_embeddings(input_ids=input_ids, input_mask=input_mask)["input_embeds"] * self.embed_scale + ) + # TODO: position_ids could be cached. + position_ids = torch.arange(input_ids.size(1), device="cuda").repeat(input_ids.size(0)).view(*input_ids.size()) + position_embeds = self.position_embeddings(position_ids=position_ids)["position_embeds"] + hidden_states = input_embeds + position_embeds + + with branch_random_state( + self.random_states, "tp_synced", enabled=self.tp_mode == TensorParallelLinearMode.ALL_REDUCE + ): + hidden_states = self.embeds_dropout(input=hidden_states)["hidden_states"] + + hidden_encoder_states = {"hidden_states": hidden_states, "sequence_mask": input_mask, "aux_losses": aux_losses} + for encoder_block in self.decoder: + hidden_encoder_states = encoder_block(**hidden_encoder_states) + # return hidden_encoder_states["hidden_states"] + + hidden_states = self.final_layer_norm(input=hidden_encoder_states["hidden_states"])["hidden_states"] + + sharded_logits = self.lm_head(x=hidden_states)["logits"] + + fp32_sharded_logits = self.cast_to_fp32(x=sharded_logits)["output"] + + if aux_losses is not None: + return {"sharded_logits": fp32_sharded_logits, "aux_losses": hidden_encoder_states["aux_losses"]} + else: + return fp32_sharded_logits + + +class GPT3MoEForTraining(GPT3ForTraining): + def __init__( + self, + config: GPT3MoEConfig, + parallel_context: ParallelContext, + parallel_config: Optional[ParallelismArgs], + random_states: RandomStates, + ): + with replace_gpt3_moe_model(config): + super().__init__(config.as_gpt3(), parallel_context, parallel_config, random_states) + self.config = config + + def forward( + self, + input_ids: Union[torch.Tensor, TensorPointer], + input_mask: Union[torch.Tensor, TensorPointer], + lang_code: Union[torch.Tensor, TensorPointer], # [batch_size, 1] TODO + label_ids: Union[torch.Tensor, TensorPointer], + label_mask: Union[torch.Tensor, TensorPointer], + ) -> Dict[str, Union[torch.Tensor, TensorPointer]]: + # aux_losses are used for load balancing in case of MoEs + aux_losses = { + "load_balancing_loss": ( + torch.zeros(1, device=input_ids.device) + if not isinstance(input_ids, TensorPointer) + else TensorPointer(self.input_pp_rank) + ), + "z_loss": ( + torch.zeros(1, device=input_ids.device) + if not isinstance(input_ids, TensorPointer) + else TensorPointer(self.input_pp_rank) + ), + } + model_output = self.model( + input_ids=input_ids, + input_mask=input_mask, + aux_losses=aux_losses, + ) + outputs = self.loss( + 
sharded_logits=model_output["sharded_logits"], + label_ids=label_ids, + label_mask=label_mask, + ) + + outputs["loss"] = torch.mean(outputs["sample_loss"]) + if isinstance(model_output["aux_losses"], dict): + for key, value in model_output["aux_losses"].items(): + outputs[key] = value + return outputs + + def get_block_compute_costs(self): + """Computes the compute cost of each block in the model so that we can do a better job of load balancing.""" + model_config = self.config + d_ff = model_config.n_inner if model_config.intermediate_size is not None else 4 * model_config.hidden_size + d_qkv = model_config.hidden_size // model_config.num_attention_heads + # active experts + routing + mlp_cost = ( + 2 * d_ff * model_config.hidden_size * model_config.num_experts_per_tok + + model_config.hidden_size * model_config.moe_num_experts + ) + att_cost = 4 * model_config.num_attention_heads * d_qkv * model_config.hidden_size + block_compute_costs = { + # CausalSelfAttention (qkv proj + attn out) + MLP + GPT3MoEBlock: att_cost + mlp_cost, + # This is the last lm_head + TensorParallelColumnLinear: model_config.vocab_size * model_config.hidden_size, + } + return block_compute_costs + + def get_flops_per_sec(self, iteration_time_in_sec, sequence_length, global_batch_size): + """Get flops per second for a given model""" + world_size = self.parallel_context.world_pg.size() + model_flops, hardware_flops = get_flops( + num_layers=self.config.num_hidden_layers, + hidden_size=self.config.hidden_size, + num_heads=self.config.num_attention_heads, + vocab_size=self.config.vocab_size, + ffn_hidden_size=self.config.n_inner if self.config.n_inner is not None else 4 * self.config.hidden_size, + seq_len=sequence_length, + batch_size=global_batch_size, + kv_channels=None, + glu_activation=False, + num_experts=self.config.moe_num_experts, + num_experts_per_tok=self.config.num_experts_per_tok, + ) + model_flops_per_s = model_flops / (iteration_time_in_sec * world_size * 1e12) + hardware_flops_per_s = hardware_flops / (iteration_time_in_sec * world_size * 1e12) + return model_flops_per_s, hardware_flops_per_s + + +def get_flops( + num_layers, + hidden_size, + num_heads, + vocab_size, + seq_len, + kv_channels=None, + ffn_hidden_size=None, + batch_size=1, + glu_activation=False, + num_experts=1, + num_experts_per_tok=1, +): + """Counts flops in an decoder-only model + Args: + num_layers: number of decoder layers + hidden_size: hidden size of the model + num_heads: number of heads in the model + kv_channels: hidden size of the key and value heads + ffn_hidden_size: hidden size of the FFN + vocab_size: size of the vocabulary + seq_len: sequence length of the decoder + batch_size: batch size + glu_activation: Whether to use GLU activation in FFN. Check T5 v1.1 for more info. + num_experts_per_tok: number of experts per token in the MoE layer + Returns: + model_flops: flops in the model (should be independent of the hardware and model implementation) + hardware_flops: flops in the hardware (actual flops performed on the hardware). 
Check 6.3 in https://arxiv.org/pdf/2205.05198.pdf + """ + + if kv_channels is None: + assert hidden_size % num_heads == 0 + kv_channels = hidden_size // num_heads + if ffn_hidden_size is None: + ffn_hidden_size = 4 * hidden_size + + # In the following we mark the reduced dimension with parentheses + # decoder + # self attention (MQA) + ## q projection + decoder_q_proj_flops_fwd = 2 * num_layers * batch_size * seq_len * (hidden_size) * num_heads * kv_channels + ## kv projection, shared across heads + decoder_kv_proj_flops_fwd = 2 * num_layers * batch_size * seq_len * (hidden_size) * 2 * kv_channels + ## qk logits + decoder_qk_logits_flops_fwd = 2 * num_layers * batch_size * num_heads * seq_len * (kv_channels) * seq_len + ### SWA (sliding window attention / local attention) + # window_size = 4096 + # decoder_qk_logits_flops_fwd = 2 * num_layers * batch_size * num_heads * seq_len * (kv_channels) * window_size + ## v logits + decoder_v_logits_flops_fwd = 2 * num_layers * batch_size * num_heads * seq_len * (seq_len) * kv_channels + # decoder_v_logits_flops_fwd = 2 * num_layers * batch_size * num_heads * seq_len * (window_size) * kv_channels + ## attn out + decoder_attn_out_flops_fwd = 2 * num_layers * batch_size * num_heads * seq_len * (kv_channels) * hidden_size + # FF + ## 1st layer + decoder_ffn_1_flops_fwd = 2 * num_layers * batch_size * seq_len * (hidden_size) * ffn_hidden_size + if glu_activation: + # 3 matmuls instead of 2 in FFN + # ref. https://arxiv.org/pdf/2002.05202.pdf + # Used for example in T5 v1.1 + decoder_ffn_1_flops_fwd = 4 * num_layers * batch_size * seq_len * (hidden_size) * ffn_hidden_size + ## 2nd layer + decoder_ffn_2_flops_fwd = 2 * num_layers * batch_size * seq_len * (ffn_hidden_size) * hidden_size + # MoE router + decoder_ffn_router_flops_fwd = 2 * num_layers * batch_size * seq_len * (hidden_size) * num_experts + + decoder_flops_fwd = ( + decoder_q_proj_flops_fwd + + decoder_kv_proj_flops_fwd + + decoder_qk_logits_flops_fwd + + decoder_v_logits_flops_fwd + + decoder_attn_out_flops_fwd + + decoder_ffn_1_flops_fwd * num_experts_per_tok + + decoder_ffn_2_flops_fwd * num_experts_per_tok + + decoder_ffn_router_flops_fwd + ) + + # lm head + lm_head_flops_fwd = 2 * batch_size * seq_len * (hidden_size) * vocab_size + + # the bwd pass requires double the flops in case of matmuls to calculate the gradients with respect to + # both input and weight tensors + model_flops = 3 * (decoder_flops_fwd + lm_head_flops_fwd) # 1 for fwd + 2 for bwd + + hardware_flops = model_flops # TODO @nouamanetazi: This is a placeholder for now + return model_flops, hardware_flops diff --git a/src/nanotron/models/llama.py b/src/nanotron/models/llama.py index 2c6ddc01..f7e57d2a 100644 --- a/src/nanotron/models/llama.py +++ b/src/nanotron/models/llama.py @@ -14,7 +14,7 @@ # limitations under the License. 
"""PyTorch LLaMa model.""" -from typing import Dict, Optional, Union, List +from typing import Dict, List, Optional, Union import torch from torch import nn @@ -593,9 +593,9 @@ def __init__( self.post_attention_layernorm = TritonRMSNorm(config.hidden_size, eps=config.rms_norm_eps) self.mlp = MLP(config=config, parallel_config=parallel_config, tp_pg=tp_pg) - + self.recompute_layer = parallel_config.recompute_layer - + def _core_forward( self, hidden_states: Union[torch.Tensor, TensorPointer], @@ -614,12 +614,12 @@ def _core_forward( hidden_states = hidden_states + residual return hidden_states, output["sequence_mask"] - + def _checkpointed_forward( self, hidden_states: torch.Tensor, sequence_mask: torch.Tensor, - ) -> List[torch.Tensor]: + ) -> List[torch.Tensor]: return CheckpointFunction.apply(self._core_forward, True, hidden_states, sequence_mask) def forward( @@ -627,7 +627,7 @@ def forward( hidden_states: Union[torch.Tensor, TensorPointer], sequence_mask: Union[torch.Tensor, TensorPointer], ) -> Dict[str, Union[torch.Tensor, TensorPointer]]: - + if self.recompute_layer and not isinstance(hidden_states, TensorPointer): hidden_states, sequence_mask = self._checkpointed_forward(hidden_states, sequence_mask) else: @@ -638,6 +638,7 @@ def forward( "sequence_mask": sequence_mask, } + class Embedding(nn.Module, AttachableStore): def __init__(self, tp_pg: dist.ProcessGroup, config: LlamaConfig, parallel_config: Optional[ParallelismArgs]): super().__init__() @@ -758,19 +759,26 @@ def forward( input_ids: Union[torch.Tensor, TensorPointer], # [batch_size, seq_length] input_mask: Union[torch.Tensor, TensorPointer], # [batch_size, seq_length] lang_code: Union[torch.Tensor, TensorPointer]=None, # [batch_size, 1] + lang_code: Union[torch.Tensor, TensorPointer], # [batch_size, 1] ): return self.forward_with_hidden_states(input_ids=input_ids, input_mask=input_mask, lang_code=lang_code)[0] + return self.forward_with_hidden_states(input_ids=input_ids, input_mask=input_mask, lang_code=lang_code)[0] def forward_with_hidden_states( self, input_ids: Union[torch.Tensor, TensorPointer], # [batch_size, seq_length] input_mask: Union[torch.Tensor, TensorPointer], # [batch_size, seq_length] lang_code: Union[torch.Tensor, TensorPointer], # [batch_size, 1] + lang_code: Union[torch.Tensor, TensorPointer], # [batch_size, 1] ): # NOTE(tj.solergibert) I bring `lang_code` till the forward of LlamaModel. Remember that # to use it in the different pipeline blocks you need to also set the module_input_keys & module_output_keys # of the necessary `PipelineBlock`'s defined in the LlamaModel init! + # NOTE(tj.solergibert) I bring `lang_code` till the forward of LlamaModel. Remember that + # to use it in the different pipeline blocks you need to also set the module_input_keys & module_output_keys + # of the necessary `PipelineBlock`'s defined in the LlamaModel init! + # all tensors are optional as most ranks don't need anything from the dataloader. output = self.token_position_embeddings(input_ids=input_ids, input_mask=input_mask) @@ -850,6 +858,7 @@ def forward( # Megatron by defaults cast everything in fp32. `--f16-lm-cross-entropy` is an option you can use to keep current precision. 
# https://github.com/NVIDIA/Megatron-LM/blob/f267e6186eae1d6e2055b412b00e2e545a8e896a/megatron/model/gpt_model.py#L38 + sample_loss = sharded_cross_entropy( sample_loss = sharded_cross_entropy( sharded_logits, label_ids.transpose(0, 1).contiguous(), group=self.tp_pg, dtype=torch.float ).transpose(0, 1) @@ -862,6 +871,14 @@ def forward( # TODO @thomasw21: I think indexing causes a sync we don't actually want # TODO @thomasw21: loss = loss[label_mask].sum() return {"sample_loss": sample_loss} + sample_loss = masked_mean(sample_loss, label_mask, dtype=torch.float) + # NOTE(tj.solergibert) masked_mean returns a single scalar with the batch loss. We've changed it to compute the SAMPLE loss. + # We will continue using "loss" as the batch loss but we add "sample_loss" for the multilingual effort. + # WARN(tj.solergibert) Don't panic, the batch loss used to update the parameters is computed in `LlamaForTraining` + + # TODO @thomasw21: I think indexing causes a sync we don't actually want + # TODO @thomasw21: loss = loss[label_mask].sum() + return {"sample_loss": sample_loss} class LlamaForTraining(NanotronModel): diff --git a/src/nanotron/models/moe.py b/src/nanotron/models/moe.py new file mode 100644 index 00000000..98add57d --- /dev/null +++ b/src/nanotron/models/moe.py @@ -0,0 +1,721 @@ +""" MoEs Blocks to replace MLPs in Transformers. """ + +# TODO: implement gpt3 style MLP (currently it's only SwiGLU with 3 weight matrices) + +import warnings +from functools import partial +from typing import Optional, Tuple + +import numpy as np +import stk +import torch +import torch.nn.functional as F +from megablocks.layers import weight_parallel as wp +from megablocks.layers.activation_fn import act_fn +from torch import nn + +from nanotron import distributed as dist +from nanotron import logging +from nanotron.config import LlamaConfig as Config +from nanotron.config import ParallelismArgs +from nanotron.nn.activations import ACT2FN +from nanotron.parallel.context import ParallelContext +from nanotron.parallel.sharded_parameters import ( + SplitConfig, + mark_all_parameters_in_module_as_sharded, +) +from nanotron.parallel.tensor_parallel.enum import TensorParallelLinearMode +from nanotron.parallel.tensor_parallel.nn import ( + TensorParallelColumnLinear, + TensorParallelRowLinear, +) + +try: + import megablocks.ops as ops + from megablocks.layers.all_to_all import all_to_all +except ImportError: + warnings.warn("Please install megablocks to use MoEs: `pip install megablocks`") + + +logger = logging.get_logger(__name__) + + +def log_mean(x, dim): + return torch.logsumexp(x, dim=dim) - torch.log(torch.tensor(x.shape[dim], dtype=torch.float32)) + + +def load_balancing_loss(router_logits, tokens_per_expert, config: Config) -> torch.Tensor: + """Computes auxiliary load balancing loss as in Switch Transformer. + + See Switch Transformer (https://arxiv.org/abs/2101.03961). This function + implements the loss function presented in equations (4) - (6). It aims to + penalize those cases where the routing between experts is unbalanced. + + Args: + logits: logits assigned to each expert per token. Shape: + [batch_size * sequence_length, num_experts]. + tokens_per_expert: [num_selected_experts] + + config: Config + + Returns: + The auxiliary loss. 
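+
+        Concretely, matching the implementation below, the returned scalar is
+            (moe_num_experts * moe_loss_weight) / (num_hidden_layers * tokens * num_experts_per_tok)
+            * dot(tokens_per_expert, mean_router_probs),
+        where mean_router_probs is the softmax of the router logits averaged over all tokens
+        (computed in log space for numerical stability).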
+ """ + # tokens = batch_size * sequence_length + num_hidden_layers = config.num_hidden_layers + moe_num_experts = config.moe_num_experts + moe_loss_weight = config.moe_loss_weight + num_experts_per_token = config.num_experts_per_tok + + # Verify the shape of the tokens_per_expert and expert_scores tensors. + assert tokens_per_expert.ndim == 1 and tokens_per_expert.numel() == moe_num_experts + + tokens = router_logits.shape[0] + assert router_logits.ndim == 2 and router_logits.shape[1] == moe_num_experts + + # compute router probability per expert in log space for numerical stability + logprobs = F.log_softmax(router_logits, dim=-1) + # take mean probability over batch + # shape [num_experts] + logprobs = log_mean(logprobs, dim=0) + expert_scores = torch.exp(logprobs) + + tokens_per_expert = tokens_per_expert.to(expert_scores.dtype) + + # Calculate the total scale across all factors. + # loss_weight * num_experts / (num_layers * tokens * top_k) + scale_numerator = moe_num_experts * moe_loss_weight + scale_denominator = num_hidden_layers * tokens * num_experts_per_token + scale = scale_numerator / scale_denominator + return scale * torch.dot(tokens_per_expert, expert_scores) + + +def router_z_loss(router_logits, config: Config) -> torch.Tensor: + """ + The router z-loss was introduced in ST-MoE + (https://arxiv.org/abs/2202.08906). It encourages router logits to remain + small in an effort to improve stability. + + Args: + router_logits: [batch_size * sequence_length, num_experts] + router logits + config: Config + + Returns: + Scalar router z-loss. + """ + num_hidden_layers = config.num_hidden_layers + moe_num_experts = config.moe_num_experts + + tokens = router_logits.shape[0] + assert router_logits.ndim == 2 and router_logits.shape[1] == moe_num_experts + + z_loss_weight = config.moe_z_loss_weight + + log_z = torch.logsumexp(router_logits, dim=-1) + z_loss = log_z**2 + + scale_numerator = z_loss_weight + scale_denominator = num_hidden_layers * tokens * moe_num_experts + scale = scale_numerator / scale_denominator + + return scale * z_loss.sum(dim=0) + + +class dMoE(torch.nn.Module): + def __init__( + self, + config: Config, + parallel_context: "ParallelContext", + parallel_config: Optional[ParallelismArgs], + ): + super().__init__() + self.config = config + self.tp_mode = parallel_config.tp_mode if parallel_config is not None else TensorParallelLinearMode.ALL_REDUCE + if self.tp_mode == TensorParallelLinearMode.REDUCE_SCATTER: + logging.warn_once( + logger=logger, + msg="TensorParallelLinearMode.REDUCE_SCATTER is still experimental for MoEs. Use at your own risk.", + rank=0, + ) + + # Token router. + self.gate = LearnedRouter(config) + + # Expert computation helper. + self.experts = ParallelDroplessMLP( + config, + use_bias=False, + parallel_context=parallel_context, + parallel_config=parallel_config, + ) + + def forward(self, hidden_states: torch.Tensor): + """ + Args: + x: input tensor of shape [sequence_length, batch_size, hidden_size] + """ + # Compute the expert scores and assignments. + # TODO: support sequence parallelism + batch_size, sequence_length, _ = hidden_states.size() + x = hidden_states.view(-1, self.config.hidden_size) + router_logits, expert_weights, top_experts = self.gate(x) + + # Compute the experts. 
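+        # The expert layer returns the recombined per-token outputs (already weighted by the router scores)
+        # together with the auxiliary load-balancing loss and the router z-loss.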
+ x, lbl_loss, z_loss = self.experts(x, router_logits, expert_weights, top_experts) + return { + "hidden_states": x.reshape(batch_size, sequence_length, -1), + "load_balancing_loss": lbl_loss, + "z_loss": z_loss, + } + + +# Adapted from megablocks.layers.router.LearnedRouter +class LearnedRouter(torch.nn.Module): + def __init__(self, config: Config): + super().__init__() + self.layer = torch.nn.Linear(config.hidden_size, config.moe_num_experts, bias=False) + self.config = config + + def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + router_logits = self.layer(x) # (batch * sequence_length, n_experts) + scores = F.softmax(router_logits, dim=-1, dtype=torch.float32) # TODO: fuse? + + if self.config.num_experts_per_tok == 1: + expert_weights, expert_indices = scores.max(dim=-1, keepdim=True) + else: + expert_weights, expert_indices = torch.topk(scores, self.config.num_experts_per_tok, dim=-1) + # IMPORTANT step to normalize, otherwise weights are very low + expert_weights = expert_weights / torch.norm( + expert_weights, + p=1, + dim=-1, + keepdim=True, + ) + return router_logits, expert_weights, expert_indices.int() + + +# Adapted from megablocks.layers.mlp.ParallelDroplessMLP +class ParallelDroplessMLP(torch.nn.Module): + def __init__( + self, + config: Config, + use_bias: bool, + parallel_context: "ParallelContext", + parallel_config: Optional[ParallelismArgs], + ): + super().__init__() + self.config = config + self.use_bias = use_bias + + self.expert_pg_size = parallel_context.expert_pg.size() + self.expert_parallel_group = parallel_context.expert_pg + + self.hidden_sharding_degree = self.expert_pg_size // min(self.expert_pg_size, self.config.moe_num_experts) + self.experts_per_rank = self.config.moe_num_experts // min(self.expert_pg_size, self.config.moe_num_experts) + + self.num_experts = config.moe_num_experts + self.num_experts_per_tok = self.config.num_experts_per_tok + + # Calculate the number of bits needed to represent the expert indices + # so that we can pass it to radix sort. + self.sort_end_bit = max(int(np.ceil(np.log2(self.num_experts))), 1) + + if use_bias: + self.bias = torch.nn.Parameter(torch.empty(config.hidden_size)) + + # Select the forward function for the operating mode. + self.forward_fn = self.parallel_forward_once if self.expert_pg_size > 1 else self.forward_once + + self.blocking = 128 + + if self.experts_per_rank == 1: + if config.moe_glu: + self.mlp = GLU( + config=config, + parallel_config=parallel_config, + tp_pg=parallel_context.tp_pg, + ) + else: + self.mlp = MLP( + config=config, + parallel_config=parallel_config, + tp_pg=parallel_context.tp_pg, + ) + else: + if config.moe_glu: + self.mlp = SparseGLU( + config=config, + parallel_config=parallel_config, + parallel_context=parallel_context, + ) + else: + self.mlp = SparseMLP( + config=config, + parallel_config=parallel_config, + parallel_context=parallel_context, + ) + + max_column_index = (self.config.intermediate_size * self.num_experts) // self.blocking + self.transpose_sort_end_bit = max(int(np.ceil(np.log2(max_column_index))), 1) + + def indices_and_bins(self, top_expert): + # Sort the expert ids to produce the scatter/gather + # indices for the permutation. + top_expert = top_expert.int() + bin_ids, indices = ops.sort(top_expert, self.sort_end_bit) + tokens_per_expert = ops.histogram(top_expert, self.num_experts) + + # Calculate the bin bounds for the sorted tokens. 
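+        # `bins[e]` is the cumulative number of tokens assigned to experts 0..e, i.e. the end
+        # offset of expert e's slice once the tokens are laid out in sorted order.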
+ bins = inclusive_cumsum(tokens_per_expert, 0) + return indices, bin_ids, bins, tokens_per_expert + + def indices_and_padded_bins(self, top_experts): + # Sort the expert ids to produce the scatter/gather + # indices for the permutation. + bin_ids, indices = ops.sort(top_experts, self.sort_end_bit) + + # Histogram the expert ids to identify the number of + # tokens routed to each expert. + tokens_per_expert = ops.histogram(top_experts, self.num_experts) + + # Round the token counts up to the block size used in + # the matrix muliplications. Calculate the starting + # position of each bin. + padded_tokens_per_expert = ops.round_up(tokens_per_expert, self.blocking) + padded_bins = inclusive_cumsum(padded_tokens_per_expert, 0) + + # Calculate the bin bounds for the sorted tokens. + bins = inclusive_cumsum(tokens_per_expert, 0) + return indices, bin_ids, bins, padded_bins, tokens_per_expert + + def forward_once(self, x, expert_weights, top_experts): # TODO: sparse + with torch.no_grad(): + ( + indices, + bin_ids, + bins, + padded_bins, + tokens_per_expert, + ) = self.indices_and_padded_bins(top_experts) + + # Route the tokens for MoE computation. + x = ops.padded_gather(x, indices, bin_ids, bins, padded_bins, self.num_experts_per_tok) + + with torch.no_grad(): + topo = self.topology(x, padded_bins) + + x = self.mlp(x, topo) + + # Un-route the data for the MoE output. + x = ops.padded_scatter( + x, + indices, + bin_ids, + expert_weights, + bins, + padded_bins, + self.num_experts_per_tok, + -1, + ) + return x, tokens_per_expert + + def parallel_forward_once(self, x, expert_weights, top_experts): + with torch.no_grad(): + indices, bin_ids, bins, tokens_per_expert = self.indices_and_bins(top_experts) + repeated_tokens_per_expert = ops.repeat(tokens_per_expert, (self.hidden_sharding_degree,)) + parallel_tokens_per_expert = torch.empty_like(repeated_tokens_per_expert) + tpe_handle = torch.distributed.all_to_all_single( + parallel_tokens_per_expert, + repeated_tokens_per_expert, + group=self.expert_parallel_group, + async_op=True, + ) + + x = ops.gather(x, indices, bin_ids, bins, self.num_experts_per_tok) + + # Compute the number of tokens that will be received from each + # device and permute the input data across the devices. + with torch.no_grad(): + tpe_handle.wait() + + # Reshape to [expert_pg_size, num_experts_per_rank]. + repeated_tokens_per_expert = repeated_tokens_per_expert.view(self.expert_pg_size, self.experts_per_rank) + parallel_tokens_per_expert = parallel_tokens_per_expert.view(self.expert_pg_size, self.experts_per_rank) + + send_counts = repeated_tokens_per_expert.cpu().sum(dim=-1) + parallel_tokens_per_expert_cpu = parallel_tokens_per_expert.cpu() + recv_counts = parallel_tokens_per_expert_cpu.sum(dim=-1) + + # Convert the send/recv counts to lists. + send_counts = send_counts.tolist() + recv_counts = recv_counts.tolist() + tokens_received = sum(recv_counts) + + x = ops.repeat(x, (self.hidden_sharding_degree, 1)) + + # Start the cross-device permutation asynchronously so we can + # overlap communication with computation. + parallel_x, parallel_x_handle = all_to_all( + x, recv_counts, send_counts, self.expert_parallel_group, async_op=True + ) + + with torch.no_grad(): + replicate_bins = inclusive_cumsum(parallel_tokens_per_expert.flatten(), 0) + + # Construct the expert indices for the permuted tokens. 
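+            # Each rank only hosts `experts_per_rank` experts, so global expert ids are folded
+            # into the local range with a modulo before the received tokens are sorted again.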
+ parallel_top_expert = torch.remainder( + torch.arange( + self.num_experts * self.hidden_sharding_degree, + dtype=torch.int32, + device=indices.device, + ), + self.experts_per_rank, + ) + parallel_top_expert = ops.replicate( + parallel_top_expert.unsqueeze(dim=0), replicate_bins, tokens_received + ).flatten() + + parallel_bin_ids, parallel_indices = ops.sort(parallel_top_expert, self.sort_end_bit) + + # Calculate the bins boundaries from the token counts. + parallel_tokens_per_expert = parallel_tokens_per_expert.sum(dim=0, dtype=torch.int) + parallel_bins = inclusive_cumsum(parallel_tokens_per_expert, 0) + + # Locally permute the tokens and perform the expert computation. + # Block to make sure that the cross-device permutation is complete. + parallel_x_handle.wait() + parallel_x = self.permute_and_compute( + parallel_x, + parallel_tokens_per_expert, + parallel_indices, + parallel_bin_ids, + None, # expert_weights + parallel_bins, + num_experts_per_tok=1, + ) + + # Un-permute the tokens across the devices. + x, _ = all_to_all(parallel_x, send_counts, recv_counts, self.expert_parallel_group) + + # Reduce along the hidden sharding to get the final outputs. + shape = (self.hidden_sharding_degree, -1, self.config.hidden_size) + x = ops.sum(x.view(shape), dim=0) + + # Un-permute locally to setup for the next series of operations. + x = ops.scatter( + x, + indices, + bin_ids, + expert_weights, + bins, + self.num_experts_per_tok, + ) + return x, tokens_per_expert.flatten() + + def forward(self, x, router_logits, expert_weights, top_experts): + """ + Args: + x: input tensor of shape [sequence_length, batch_size, hidden_size] + router_logits: tensor of shape [sequence_length * batch_size, n_experts] + expert_weights: tensor of shape [sequence_length * batch_size, num_experts_per_tok] + top_experts: tensor of shape [sequence_length * batch_size, num_experts_per_tok] + """ + # Compute the experts. + x, tokens_per_expert = self.forward_fn(x, expert_weights.flatten(), top_experts.flatten()) + if self.training: + lbl_loss = load_balancing_loss(router_logits, tokens_per_expert, self.config) + z_loss = router_z_loss(router_logits, self.config) + else: + lbl_loss = torch.zeros(1, device=x.device) + z_loss = torch.zeros(1, device=x.device) + + if self.use_bias: + x = x + self.bias + return x, lbl_loss, z_loss + + def permute_and_compute( + self, + x, + tokens_per_expert, + indices, + bin_ids, + expert_weights, + bins, + num_experts_per_tok, + ): + # Round the token counts up to the block size used in the matrix + # multiplication. Calculate the starting position of each bin. + padded_tokens_per_expert = ops.round_up(tokens_per_expert, self.blocking) + padded_bins = inclusive_cumsum(padded_tokens_per_expert, 0) + + # Route the tokens for MoE computation. + x = ops.padded_gather(x, indices, bin_ids, bins, padded_bins, num_experts_per_tok) + + # Perform the expert computation. + with torch.no_grad(): + topo = self.topology(x, padded_bins) + x = self.mlp(x, topo) + + # Un-route the data for the MoE output. 
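+        # `padded_scatter` reverses the `padded_gather` above; in `parallel_forward_once` this is
+        # called with expert_weights=None and the routing weights are applied later by `ops.scatter`.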
+ return ops.padded_scatter(x, indices, bin_ids, expert_weights, bins, padded_bins, num_experts_per_tok) + + def sparse_transpose(self, size, row_indices, column_indices, offsets): + block_columns = size[1] // self.blocking + _, gather_indices = ops.sort(column_indices.int(), self.transpose_sort_end_bit) + column_indices_t = row_indices.gather(0, gather_indices.long()) + block_offsets_t = gather_indices.int() + + zero = torch.zeros((1,), dtype=torch.int32, device=row_indices.device) + nnz_per_column = ops.histogram(column_indices, block_columns) + nnz_per_column = ops.inclusive_cumsum(nnz_per_column, 0) + offsets_t = torch.cat([zero, nnz_per_column]) + return column_indices_t, offsets_t, block_offsets_t + + def topology(self, x, padded_bins): + padded_tokens, _ = x.size() + assert padded_tokens % self.blocking == 0 + assert self.config.intermediate_size % self.blocking == 0 + + # Offsets for the sparse matrix. All rows have the + # same number of nonzero blocks dictated by the + # dimensionality of a single expert. + block_rows = padded_tokens // self.blocking + blocks_per_row = self.config.intermediate_size // self.blocking + offsets = torch.arange( + 0, + block_rows * blocks_per_row + 1, + blocks_per_row, + dtype=torch.int32, + device=x.device, + ) + + # Indices for the sparse matrix. The indices for + # the intermediate matrix are dynamic depending + # on the mapping of tokens to experts. + column_indices = ops.topology(padded_bins, self.blocking, block_rows, blocks_per_row) + + # TODO(tgale): This is unused. Remove the need for this in stk. + # For now, use meta init to save the device memory. + data = torch.empty( + column_indices.numel(), + self.blocking, + self.blocking, + dtype=x.dtype, + device="meta", + ) + shape = (padded_tokens, self.config.intermediate_size * self.experts_per_rank) + row_indices = stk.ops.row_indices(shape, data, offsets, column_indices) + column_indices_t, offsets_t, block_offsets_t = self.sparse_transpose( + shape, row_indices, column_indices, offsets + ) + return stk.Matrix( + shape, + data, + row_indices, + column_indices, + offsets, + column_indices_t, + offsets_t, + block_offsets_t, + ) + + +class ScaleGradient(torch.autograd.Function): + @staticmethod + @torch.cuda.amp.custom_fwd + def forward(ctx, x, scale): + ctx.scale = scale + return x + + @staticmethod + @torch.cuda.amp.custom_bwd + def backward(ctx, grad): + return grad * ctx.scale, None + + +scale_gradient = ScaleGradient.apply + + +class ExpertParallel(nn.Module): + """ + ExpertParallel serves to scale the gradients of the expert weights because unlike DP the gradients are not averaged across the expert parallel group. 
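+    The wrapper calls `scale_gradients` (a 1 / expert_parallel_size gradient scale) right
+    before running the wrapped module's forward pass.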
+ """ + + def __init__(self, module, expert_parallel_size: int): + super().__init__() + self.module = module + self.expert_parallel_size = expert_parallel_size + + def forward(self, *args, **kwargs): + self.scale_gradients() + return self.module(*args, **kwargs) + + def scale_gradients(self): + scale_gradient(self.module, 1 / self.expert_parallel_size) + + +class SparseMLP(nn.Module): + def __init__( + self, + config: Config, + parallel_config: Optional[ParallelismArgs], + parallel_context: "ParallelContext", + ): + super().__init__() + + self.expert_pg_size = parallel_config.expert_parallel_size if parallel_config is not None else 1 + self.experts_per_rank = config.moe_num_experts // min(self.expert_pg_size, config.moe_num_experts) + self.tp_pg = parallel_context.tp_pg + + self.w1 = ExpertParallel( + nn.Linear( + config.hidden_size, + config.intermediate_size * self.experts_per_rank // self.tp_pg.size(), + bias=False, + ), + expert_parallel_size=self.expert_pg_size, + ) + self.w2 = ExpertParallel( + nn.Linear( + config.hidden_size, + config.intermediate_size * self.experts_per_rank // self.tp_pg.size(), + bias=False, + ), + expert_parallel_size=self.expert_pg_size, + ) + + if self.tp_pg.size() == 1: + self.w1.module.weight.data = self.w1.module.weight.data.T.contiguous() + + # TODO @nouamane: jit + self.act = ACT2FN[config.hidden_act] + self.sdd = partial(wp.sdd_nt, group=self.tp_pg) if self.tp_pg.size() > 1 else stk.ops.sdd + self.dsd = partial(wp.dsd_nn, group=self.tp_pg) if self.tp_pg.size() > 1 else stk.ops.dsd + + def forward(self, x, topo): + self.w1.scale_gradients(), self.w2.scale_gradients() + x = self.sdd(x.contiguous(), self.w1.module.weight, topo) + activation_fn_out = act_fn(x, self.act) + return self.dsd(activation_fn_out, self.w2.module.weight) + + +class MLP(nn.Module): + def __init__( + self, + config: Config, + parallel_config: Optional[ParallelismArgs], + tp_pg: dist.ProcessGroup, + ): + super().__init__() + + tp_mode = parallel_config.tp_mode if parallel_config is not None else TensorParallelLinearMode.ALL_REDUCE + tp_linear_async_communication = ( + parallel_config.tp_linear_async_communication if parallel_config is not None else False + ) + + self.expert_pg_size = parallel_config.expert_parallel_size + self.experts_per_rank = config.moe_num_experts // min(self.expert_pg_size, config.moe_num_experts) + + assert self.experts_per_rank == 1, "moe.MLP only supports 1 expert per rank, otherwise use moe.SparseMLP" + + self.w1 = ExpertParallel( + TensorParallelColumnLinear( + config.hidden_size, + config.intermediate_size * self.experts_per_rank, + pg=tp_pg, + mode=tp_mode, + bias=False, + async_communication=tp_linear_async_communication, + ), + expert_parallel_size=self.expert_pg_size, + ) + + self.w2 = ExpertParallel( + TensorParallelRowLinear( + config.intermediate_size * self.experts_per_rank, + config.hidden_size, + pg=tp_pg, + mode=tp_mode, + bias=False, + async_communication=tp_linear_async_communication + and tp_mode is TensorParallelLinearMode.REDUCE_SCATTER, + ), + expert_parallel_size=self.expert_pg_size, + ) + + # TODO @nouamane: jit + self.act = ACT2FN[config.hidden_act] + + def forward(self, hidden_states, topo): # [seq_length, batch_size, hidden_dim] + merged_states = self.w1(hidden_states) + hidden_states = self.w2(self.act(merged_states)) + return hidden_states + +class GLU(MLP): + def __init__( + self, + config: Config, + parallel_config: Optional[ParallelismArgs], + tp_pg: dist.ProcessGroup, + ): + super().__init__(config, parallel_config, tp_pg) + 
tp_mode = parallel_config.tp_mode if parallel_config is not None else TensorParallelLinearMode.ALL_REDUCE
+        tp_linear_async_communication = (
+            parallel_config.tp_linear_async_communication if parallel_config is not None else False
+        )
+        self.w3 = ExpertParallel(
+            TensorParallelColumnLinear(
+                config.hidden_size,
+                config.intermediate_size * self.experts_per_rank,
+                pg=tp_pg,
+                mode=tp_mode,
+                bias=False,
+                async_communication=tp_linear_async_communication,
+            ),
+            expert_parallel_size=self.expert_pg_size,
+        )
+
+    def forward(self, x, topo):
+        merged_states = self.w1(x)
+        hidden_states = self.w2(self.act(merged_states) * self.w3(x))
+        return hidden_states
+
+def inclusive_cumsum(x, dim):
+    scalar = ops.inclusive_cumsum(x, dim)
+    return scalar.view(1) if not len(scalar.size()) else scalar
+
+
+class SparseGLU(SparseMLP):
+    def __init__(
+        self,
+        config: Config,
+        parallel_config: Optional[ParallelismArgs],
+        parallel_context: "ParallelContext",
+    ):
+        super().__init__(config, parallel_config, parallel_context)
+        self.w3 = ExpertParallel(
+            nn.Linear(
+                config.hidden_size,
+                config.intermediate_size * self.experts_per_rank // self.tp_pg.size(),
+                bias=False,
+            ),
+            expert_parallel_size=self.expert_pg_size,
+        )
+        if self.tp_pg.size() == 1:
+            self.w3.module.weight.data = self.w3.module.weight.data.T.contiguous()
+
+        mark_all_parameters_in_module_as_sharded(
+            self,
+            pg=parallel_context.tp_and_expert_pg,
+            split_config=SplitConfig(split_dim=0),
+        )
+
+    def forward(self, x, topo):
+        # We need to scale gradients manually since we don't call the linears forward
+        self.w1.scale_gradients(), self.w2.scale_gradients(), self.w3.scale_gradients()
+        x = x.contiguous()
+        x1 = self.sdd(x, self.w1.module.weight, topo)
+        x2 = self.sdd(x, self.w3.module.weight, topo)
+        x = stk.ops.mul(act_fn(x1, self.act), x2)
+        return self.dsd(x, self.w2.module.weight)
diff --git a/src/nanotron/models/starcoder2.py b/src/nanotron/models/starcoder2.py
index 7100351d..6636ffb5 100644
--- a/src/nanotron/models/starcoder2.py
+++ b/src/nanotron/models/starcoder2.py
@@ -32,6 +32,7 @@
 from nanotron.config import ParallelismArgs, Starcoder2Config
 from nanotron.generation.generate_store import AttachableStore
 from nanotron.models import NanotronModel
+from nanotron.models.moe import ParallelDroplessMLP
 from nanotron.nn.activations import ACT2FN
 from nanotron.nn.layer_norm import TritonLayerNorm
 from nanotron.parallel import ParallelContext
@@ -1375,7 +1376,9 @@ def forward(
 @torch.jit.script
 def masked_mean(loss, label_mask, dtype):
     # type: (Tensor, Tensor, torch.dtype) -> Tensor
-    return (loss * label_mask).sum(dtype=dtype) / label_mask.sum()
+    return (loss * label_mask).sum(dim=1, dtype=dtype) / label_mask.sum(
+        dim=1
+    )  # NOTE(tj.solergibert) Added dim=1 to return a tensor with shape [Batch size, 1] instead of [1]


 class Loss(nn.Module):
@@ -1400,7 +1403,7 @@ def forward(
         loss = masked_mean(loss, label_mask, dtype=torch.float)
         # I think indexing causes a sync we don't actually want
         # loss = loss[label_mask].sum()
-        return {"loss": loss}
+        return {"sample_loss": loss}


 class Starcoder2ForTraining(NanotronModel):
@@ -1427,7 +1430,7 @@ def __init__(
                 "label_ids",
                 "label_mask",
             },
-            module_output_keys={"loss"},
+            module_output_keys={"sample_loss"},
         )
         self.config: Starcoder2Config = config
         self.parallel_config = parallel_config
@@ -1437,20 +1440,21 @@ def forward(
         self,
         input_ids: Union[torch.Tensor, TensorPointer],
         input_mask: Union[torch.Tensor, TensorPointer],
+        lang_code: Union[torch.Tensor,
TensorPointer], # TODO label_ids: Union[torch.Tensor, TensorPointer], label_mask: Union[torch.Tensor, TensorPointer], - ) -> Union[torch.Tensor, TensorPointer]: + ) -> Dict[str, Union[torch.Tensor, TensorPointer]]: sharded_logits = self.model( input_ids=input_ids, input_mask=input_mask, ) - return { - "loss": self.loss( - sharded_logits=sharded_logits, - label_ids=label_ids, - label_mask=label_mask, - )["loss"] - } + outputs = self.loss( + sharded_logits=sharded_logits, + label_ids=label_ids, + label_mask=label_mask, + ) + outputs["loss"] = torch.mean(outputs["sample_loss"]) + return outputs def tie_custom_params(self) -> None: # find all params with names qkv.kv.weight and qkv.kv.bias in them @@ -1517,6 +1521,16 @@ def init_model_randomly(self, config): module.bias.zero_() else: raise ValueError(f"Who the fuck is {param_name}?") + elif isinstance(module, nn.Linear): + if "weight" == param_name: + nn.init.normal_(module.weight, mean=0.0, std=std) + elif "bias" == param_name: + module.bias.zero_() + else: + raise ValueError(f"Who the fuck is {param_name}?") + elif isinstance(module, ParallelDroplessMLP): + if hasattr(module, "bias"): + module.bias.zero_() elif isinstance(module, TensorParallelRowLinear): if "weight" == param_name: nn.init.normal_(module.weight, mean=0.0, std=sigma / math.sqrt(2 * num_layers)) diff --git a/src/nanotron/parallel/pipeline_parallel/block.py b/src/nanotron/parallel/pipeline_parallel/block.py index 150172f5..4e8cfeb5 100644 --- a/src/nanotron/parallel/pipeline_parallel/block.py +++ b/src/nanotron/parallel/pipeline_parallel/block.py @@ -81,6 +81,24 @@ def forward(self, **kwargs): if isinstance(tensor, TensorPointer): # Current rank is neither the rank holding the data nor the rank responsible for computing block continue + elif isinstance(tensor, dict): + for k, v in tensor.items(): + if isinstance(v, torch.Tensor): + # We need to send the tensor to the rank that actually runs the compute + if self.pipeline_state is not None: + send_to_pipeline_state_buffer( + v, + to_rank=self.rank, + p2p=self.p2p, + pipeline_state=self.pipeline_state, + ) + continue + if v.requires_grad is True: + raise ValueError( + f"Pipeline engine is None and tensor requires grad. Tried sending a tensor to {self.rank}. Usually that means that your model is pipeline sharded and you haven't chosen a specific pipeline engine." 
+ ) + + batch_send_recv.add_send(tensor=v, to_rank=self.rank) else: assert isinstance(tensor, torch.Tensor) # We need to send the tensor to the rank that actually runs the compute @@ -133,6 +151,29 @@ def forward(self, **kwargs): # We don't store result in a buffer recv_id = batch_send_recv.add_recv(from_rank=tensor.group_rank) name_to_recv_id[name] = recv_id + elif isinstance(tensor, dict): + new_kwargs[name] = {} + for k, v in tensor.items(): + # the same as above just looped over the dict + if isinstance(v, TensorPointer): + if isinstance(self.pipeline_state, PipelineTrainBatchState): + for _ in range(len(self.pipeline_state.microbatches_activations_to_send)): + send_activation = self.pipeline_state.microbatches_activations_to_send.popleft() + # Execute + send_activation() + + if self.pipeline_state is not None: + new_kwargs[name][k] = recv_from_pipeline_state_buffer( + from_rank=tensor.group_rank, + p2p=self.p2p, + pipeline_state=self.pipeline_state, + ) + continue + # We don't store result in a buffer + recv_id = batch_send_recv.add_recv(from_rank=tensor.group_rank) + name_to_recv_id[name] = recv_id + else: + new_kwargs[name][k] = v else: new_kwargs[name] = tensor diff --git a/src/nanotron/parallel/pipeline_parallel/engine.py b/src/nanotron/parallel/pipeline_parallel/engine.py index 9b548e35..b4f5a3c4 100644 --- a/src/nanotron/parallel/pipeline_parallel/engine.py +++ b/src/nanotron/parallel/pipeline_parallel/engine.py @@ -11,8 +11,13 @@ from nanotron.logging import log_rank from nanotron.optim.gradient_accumulator import GradientAccumulator from nanotron.parallel.data_parallel.utils import ddp_trigger_sync_in_bwd -from nanotron.parallel.pipeline_parallel.context_manager import attach_pipeline_state_to_model -from nanotron.parallel.pipeline_parallel.state import PipelineEvalBatchState, PipelineTrainBatchState +from nanotron.parallel.pipeline_parallel.context_manager import ( + attach_pipeline_state_to_model, +) +from nanotron.parallel.pipeline_parallel.state import ( + PipelineEvalBatchState, + PipelineTrainBatchState, +) from nanotron.parallel.pipeline_parallel.tensor_pointer import TensorPointer from nanotron.utils import ContextManagers @@ -49,14 +54,21 @@ def forward( if not isinstance(output, dict): output = {"loss": output} - # We normalize our loss - if not isinstance(output["loss"], TensorPointer): - output["loss"] = output["loss"] / self.nb_microbatches - - # Add output as activations that require backward pass - if not isinstance(output["loss"], TensorPointer) and not is_validation: - assert output["loss"].requires_grad - state.register_activation_requiring_backward(output["loss"]) + for k, v in output.items(): + if not isinstance(v, TensorPointer) and k != "sample_loss": + output[k] = v / self.nb_microbatches + + # the outputs are either + # - token prediction loss ["loss"] + # - loss per sample (for validation), ["sample_loss"] -- does not require backpropagation + # - auxiliary losses ["load_balancing_loss", "z_loss"] + # that we need to backpropagate through, so register activations + for loss_key, output_tensor in output.items(): + if loss_key == "sample_loss": + continue + if not isinstance(output_tensor, TensorPointer) and not is_validation: + assert output_tensor.requires_grad, loss_key + state.register_activation_requiring_backward(output_tensor) return output @staticmethod @@ -67,7 +79,10 @@ def _get_fwd_context(model: torch_nn.Module): return context def backward( - self, context: ContextManagers, state: PipelineTrainBatchState, grad_accumulator: 
Optional[GradientAccumulator] + self, + context: ContextManagers, + state: PipelineTrainBatchState, + grad_accumulator: Optional[GradientAccumulator], ): # Increment the number of backwards state.nb_backwards += 1 @@ -147,7 +162,11 @@ def validate_batch_iter( for micro_batch in batch: context = self._get_fwd_context(model=model) output = self.forward( - context=context, state=state, micro_batch=micro_batch, model=model, is_validation=True + context=context, + state=state, + micro_batch=micro_batch, + model=model, + is_validation=True, ) # TODO @thomasw21: Somehow this needs to be done somewhere else to support interleaving. Somewhere right after a "stage" for _ in range(len(state.microbatches_activations_to_send)): @@ -159,7 +178,7 @@ def validate_batch_iter( if not isinstance(output, dict): output = {"loss": output} - # Store the loss for each microbatch + # Store the loss(es) for each microbatch if not isinstance(output["loss"], TensorPointer): output = {k: v.detach() for k, v in output.items()} @@ -278,8 +297,9 @@ def train_batch_iter( send_activation() # Store the loss for each microbatch - if not isinstance(output["loss"], TensorPointer): - output = {k: v.detach() for k, v in output.items()} + for k, v in output.items(): + if not isinstance(v, TensorPointer): + output[k] = v.detach() outputs.append(output) for micro_batch in batch: @@ -291,8 +311,9 @@ def train_batch_iter( output = {"loss": output} # Store the loss for each microbatch - if not isinstance(output["loss"], TensorPointer): - output = {k: v.detach() for k, v in output.items()} + for k, v in output.items(): + if not isinstance(v, TensorPointer): + output[k] = v.detach() outputs.append(output) # One backward diff --git a/src/nanotron/scaling/parametrization.py b/src/nanotron/scaling/parametrization.py index e6241651..380ad460 100644 --- a/src/nanotron/scaling/parametrization.py +++ b/src/nanotron/scaling/parametrization.py @@ -34,6 +34,7 @@ class StandardParametrizator(Parametrizator): def __init__(self, config: ModelArgs): super().__init__(config) self.MODULE_TO_PARAMETRIZE = { + nn.Linear: self._parametrize_column_linear, TensorParallelColumnLinear: self._parametrize_column_linear, TensorParallelRowLinear: self._parametrize_row_linear, TritonRMSNorm: self._parametrize_layer_norm, diff --git a/src/nanotron/trainer.py b/src/nanotron/trainer.py index 47d83fbe..683744bd 100644 --- a/src/nanotron/trainer.py +++ b/src/nanotron/trainer.py @@ -57,6 +57,7 @@ from nanotron.models import NanotronModel, build_model from nanotron.models.base import check_model_has_grad from nanotron.models.gpt3 import GPT3ForTraining +from nanotron.models.gpt3_moe import GPT3MoEForTraining from nanotron.models.llama import LlamaForTraining, RotaryEmbedding from nanotron.models.starcoder2 import Starcoder2ForTraining from nanotron.optim.clip_grads import clip_grad_norm @@ -106,6 +107,7 @@ "LlamaConfig": LlamaForTraining, "Starcoder2Config": Starcoder2ForTraining, "GPT3Config": GPT3ForTraining, + "GPT3MoEConfig": GPT3MoEForTraining, } try: @@ -342,7 +344,7 @@ def _prepare_dataloader_for_validation_stage(self, dataloaders: Union[List[DataL else: dataloader = dataloaders - self.current_validation_dataloader_lenght = len(dataloader) + self.current_validation_dataloader_length = min(len(dataloader), self.limit_val_batches) self.current_validation_dataloader = sanity_check_dataloader( dataloader=dataloader, parallel_context=self.parallel_context, config=self.config ) @@ -403,7 +405,7 @@ def clear_dataloader_from_memory(dataloader: DataLoader, stage_name: str, 
prev_s dataloader = dataloader() if callable(dataloader) else dataloader break - self.current_validation_dataloader_lenght = len(dataloader) + self.current_validation_dataloader_length = min(len(dataloader), self.limit_val_batches) self.current_validation_dataloader = sanity_check_dataloader( dataloader=dataloader, parallel_context=self.parallel_context, config=self.config ) # NOTE(tj.solergibert) Create a Iterator from the DataLoader @@ -579,7 +581,7 @@ def train( def training_step( self, dataloader: Iterator[Dict[str, Union[torch.Tensor, TensorPointer]]] - ) -> Tuple[Iterable[Dict], Optional[torch.Tensor]]: + ) -> Tuple[Iterable[Dict], Optional[Dict[str, torch.Tensor]]]: before_tbi_sanity_checks(self.config, self.parallel_context, self.unwrapped_model, self.grad_accumulator) if self.iteration_step < 5: @@ -646,11 +648,21 @@ def training_step( # Compute DP average loss and overlap with optimizer step if isinstance(outputs[0]["loss"], torch.Tensor): # This is an average on only one data rank. - loss_avg = torch.stack( - [output["loss"] for output in outputs] - ).sum() # already divided by n_micro_batches_per_batch - # sync loss across DP - handle = dist.all_reduce(loss_avg, group=self.parallel_context.dp_pg, async_op=True, op=dist.ReduceOp.AVG) + loss_avg = {} + for k in outputs[0].keys(): + if k == "sample_loss": + continue # sample loss is the individual losses, is already averaged as 'lm_loss' + if k == "loss": + loss_avg["lm_loss"] = torch.stack([output[k] for output in outputs]).sum() + k = "lm_loss" + else: + loss_avg[k] = torch.stack([output[k] for output in outputs]).mean() + handle = dist.all_reduce( + loss_avg[k], + group=self.parallel_context.dp_pg, + async_op=True, + op=dist.ReduceOp.AVG, + ) else: loss_avg = None handle = None @@ -674,8 +686,8 @@ def training_step( def validation_step(self, dataloader: Iterator[Dict[str, Union[torch.Tensor, TensorPointer]]]) -> Iterable[Dict]: outputs, lang_codes = self.pipeline_engine.validate_batch_iter( model=self.model, - batch=(next(dataloader) for _ in range(self.current_validation_dataloader_lenght)), - nb_microbatches=self.current_validation_dataloader_lenght, + batch=(next(dataloader) for _ in range(self.current_validation_dataloader_length)), + nb_microbatches=self.current_validation_dataloader_length, ) lang_losses = { @@ -684,7 +696,7 @@ def validation_step(self, dataloader: Iterator[Dict[str, Union[torch.Tensor, Ten lang_losses_list = list(lang_losses.keys()) # Compute losses - if isinstance(outputs[0], torch.Tensor): + if len(outputs) > 0 and isinstance(outputs[0], torch.Tensor): # Multilingual losses for loss, lang_code in zip(outputs, lang_codes): lang_losses[lang_losses_list[lang_code]].append(loss) @@ -728,9 +740,9 @@ def validation_step(self, dataloader: Iterator[Dict[str, Union[torch.Tensor, Ten def train_step_logs( self, - loss_avg: Optional[torch.Tensor], - global_loss: torch.Tensor, - lang_losses: torch.Tensor, + loss_avg: Optional[Dict[str, torch.Tensor]], + global_loss: Optional[torch.Tensor], + lang_losses: Optional[Dict[str, torch.Tensor]], ) -> None: # TODO @nouamanetazi: Megatron-LM seems to be using a barrier to report their interval time. Check if this is necessary. 
https://github.com/NouamaneTazi/Megatron-LM/blob/e241a96c3085b18e36c6cee1d68a8155de77b5a6/megatron/training.py#L607 dist.barrier() @@ -748,7 +760,7 @@ def train_step_logs( # Validation metrics if global_loss is not None: - validation_total_samples = self.current_validation_dataloader_lenght * self.micro_batch_size + validation_total_samples = self.current_validation_dataloader_length * self.micro_batch_size validation_elapsed_time_per_iteration_ms = (self.validation_step_time - self.training_step_time) * 1000 validation_tokens_per_sec = ( validation_total_samples * self.sequence_length / (validation_elapsed_time_per_iteration_ms / 1000) @@ -778,12 +790,15 @@ def train_step_logs( "tokens_per_sec_per_gpu", tokens_per_sec / self.parallel_context.world_pg.size(), "human_format" ), # , "1.6E"), LogItem("global_batch_size", self.global_batch_size, "human_format"), # , "5d"), - LogItem("lm_loss", loss_avg.item(), "human_format"), # , "1.6E"), LogItem("lr", lr, "human_format"), # , ".3E"), LogItem("model_tflops_per_gpu", model_tflops, "human_format"), # , ".2f"), LogItem("hardware_tflops_per_gpu", hardware_tflops, "human_format"), # , ".2f"), ] + if loss_avg is not None: + for k, v in loss_avg.items(): + log_entries.append(LogItem(k, v.item(), "human_format")) + if self.config.optimizer.clip_grad is not None: log_entries.append(LogItem("grad_norm", self.grad_norm_unclipped.item(), "human_format")) # , ".3f"))