From 3bce1f4d93e042327fcb29bfd93ee8dd121f0e05 Mon Sep 17 00:00:00 2001
From: AleHC <36459138+AleHD@users.noreply.github.com>
Date: Sat, 19 Oct 2024 16:15:58 +0200
Subject: [PATCH] MoE converters (#5)

* Converters ready

* Added xglm transformers implementation

---------

Co-authored-by: Negar Foroutan
---
 examples/xglm/README.md                       |    5 +
 examples/xglm/convert_ntmoe2hf.py             |  140 +++
 examples/xglm/tests/test_moe.py               |  182 +++
 examples/xglm/transformers_impl/gating.py     |  149 +++
 examples/xglm/transformers_impl/xglm_model.py | 1119 +++++++++++++++++
 src/nanotron/models/moe.py                    |    2 +-
 6 files changed, 1596 insertions(+), 1 deletion(-)
 create mode 100644 examples/xglm/convert_ntmoe2hf.py
 create mode 100644 examples/xglm/tests/test_moe.py
 create mode 100644 examples/xglm/transformers_impl/gating.py
 create mode 100644 examples/xglm/transformers_impl/xglm_model.py

diff --git a/examples/xglm/README.md b/examples/xglm/README.md
index 8f62fc57..48447ac2 100644
--- a/examples/xglm/README.md
+++ b/examples/xglm/README.md
@@ -25,3 +25,8 @@ cd examples/xglm
 torchrun --nproc-per-node=1 convert_dense2moe.py --checkpoint-path=checkpoints/xglm-564M --save-path=$SCRATCH/checkpoints/xglm-8x564M --num-experts=8
 ```
 Note that this upcycling _drops_ the bias parameters of the MLP because the MegaBlocks implementation does not support bias parameters. While this is a limitation of the current implementation, the performance is quickly recovered after a few training steps.
+
+To convert the upcycled MoE checkpoint back to the HuggingFace format, use:
+```bash
+torchrun examples/xglm/convert_ntmoe2hf.py --checkpoint-path=$SCRATCH/checkpoints/xglm-8x564M --save-path=$SCRATCH/checkpoints/huggingface/xglm-8x564M
+```
diff --git a/examples/xglm/convert_ntmoe2hf.py b/examples/xglm/convert_ntmoe2hf.py
new file mode 100644
index 00000000..d971c96f
--- /dev/null
+++ b/examples/xglm/convert_ntmoe2hf.py
@@ -0,0 +1,140 @@
+"""
+Converts a nanotron MoE model to HuggingFace format.
+Command:
+    torchrun --nproc-per-node=1 convert_ntmoe2hf.py --checkpoint-path=nanotron_weights --save-path=hf_weights
+"""
+
+import warnings
+from argparse import ArgumentParser
+from pathlib import Path
+from typing import Optional
+
+import torch
+from transformers import AutoTokenizer
+from tqdm import tqdm
+
+from nanotron.config.models_config import GPT3MoEConfig
+from nanotron.models.gpt3_moe import GPT3MoEForTraining, GPT3MoEBlock
+from nanotron.models.moe import dMoE, SparseMLP, LearnedRouter
+
+from examples.xglm.convert_dense2moe import create_nt_moe_model
+from examples.xglm.convert_nt2hf import convert_attention
+from examples.xglm.convert_utils import convert_generic
+from examples.xglm.transformers_impl.xglm_model import XGLMForCausalLM, XGLMDecoderLayer, XGLMmoeConfig, XGLMSparseMoeBlock, XGLMMLP
+from examples.xglm.transformers_impl.gating import BasicGate
+
+
+def convert_config(config: GPT3MoEConfig) -> XGLMmoeConfig:
+    if config.embd_pdrop != config.resid_pdrop:
+        warnings.warn(
+            f"nanotron.embd_pdrop = {config.embd_pdrop} does not match "
+            f"nanotron.resid_pdrop = {config.resid_pdrop}. "
+            "The XGLM implementation needs these two values to be equal "
+            "for correct conversion."
+        )
+    if config.layer_norm_epsilon != 1e-5:
+        warnings.warn(f"nanotron.layer_norm_epsilon must be 1e-5, not {config.layer_norm_epsilon}")
+    if config.moe_z_loss_weight != 0:
+        warnings.warn("the transformers implementation does not support z loss")
+    assert not config.moe_glu, "The transformers implementation does not support GLU MLP layers"
+
+    return XGLMmoeConfig(
+        # Regular xglm config.
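+        # (These kwargs are a field-by-field rename of the nanotron GPT3MoE config onto the
+        # HF XGLM names; note that nanotron's `position_embedding_offset` is stored as the
+        # XGLM `decoder_start_token_id`.)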
+ activation_function=config.activation_function, + attention_dropout=config.attn_pdrop, + dropout=config.embd_pdrop, + eos_token_id=config.eos_token_id, + d_model=config.hidden_size, + ffn_dim=config.intermediate_size, + max_position_embeddings=config.max_position_embeddings, + attention_heads=config.num_attention_heads, + num_layers=config.num_hidden_layers, + vocab_size=config.vocab_size, + decoder_start_token_id=config.position_embedding_offset, + activation_dropout=config.act_pdrop, + scale_embedding=config.scale_embedding, + # Moe specifics. + num_local_experts=config.moe_num_experts, + num_experts_per_tok=config.num_experts_per_tok, + gate_type="linear", + gate_depth=1, + router_aux_loss_coef=config.moe_loss_weight, + ) + + +def convert_mlp(mlp_hf: XGLMMLP, mlp_nt: SparseMLP): + convert_generic(mlp_hf.fc1, mlp_nt.w1.module) + convert_generic(mlp_hf.fc2, mlp_nt.w2.module) + + +def convert_gate(gate_hf: BasicGate, gate_nt: LearnedRouter): + convert_generic(gate_hf.gate, gate_nt.layer) + + +def convert_ff(ff_hf: XGLMSparseMoeBlock, ff_nt: dMoE): + convert_gate(ff_hf.gate, ff_nt.gate) + int_size = ff_nt.config.intermediate_size + if len(ff_hf.experts) == 1: + assert ff_nt.experts.mlp.w1.module.weight.shape == (int_size*len(ff_hf.experts), ff_nt.config.hidden_size) + assert ff_nt.experts.mlp.w2.module.weight.shape == (ff_nt.config.hidden_size, int_size*len(ff_hf.experts)) + else: + assert ff_nt.experts.mlp.w1.module.weight.T.shape == (int_size*len(ff_hf.experts), ff_nt.config.hidden_size) + assert ff_nt.experts.mlp.w2.module.weight.shape == (int_size*len(ff_hf.experts), ff_nt.config.hidden_size) + + for i, expert_hf in enumerate(ff_hf.experts): + i0 = i*int_size + i1 = (i + 1)*int_size + with torch.no_grad(): + if len(ff_hf.experts) == 1: + expert_hf.fc1.weight.copy_(ff_nt.experts.mlp.w1.module.weight[i0:i1, :].clone()) + expert_hf.fc2.weight.copy_(ff_nt.experts.mlp.w2.module.weight[:, i0:i1].clone()) + else: + expert_hf.fc1.weight.copy_(ff_nt.experts.mlp.w1.module.weight.T[i0:i1, :].clone()) + expert_hf.fc2.weight.copy_(ff_nt.experts.mlp.w2.module.weight[i0:i1, :].T.clone()) + +def convert_decoder(block_hf: XGLMDecoderLayer, block_nt: GPT3MoEBlock): + convert_generic(block_hf.self_attn_layer_norm, block_nt.ln_1) + convert_attention(block_hf.self_attn, block_nt.attn) + convert_generic(block_hf.final_layer_norm, block_nt.ln_2) + convert_ff(block_hf.block_sparse_moe, block_nt.ff) + + +def convert(model_hf: XGLMForCausalLM, model_nt: GPT3MoEForTraining): + convert_generic(model_hf.model.embed_tokens, model_nt.model.token_embeddings.pp_block.token_embedding) + for layer_hf, layer_nt in tqdm(zip(model_hf.model.layers, model_nt.model.decoder), desc="Converting layers", + total=model_nt.config.num_hidden_layers): + convert_decoder(layer_hf, layer_nt.pp_block) + convert_generic(model_hf.model.layer_norm, model_nt.model.final_layer_norm.pp_block) + convert_generic(model_hf.lm_head, model_nt.model.lm_head.pp_block) + + +def main(checkpoint_path: Path, save_path: Path, tokenizer_name: Optional[str]): + # Load nanotron model. + model_nt = create_nt_moe_model(checkpoint_path=checkpoint_path) + + # Init huggingface model. + model_config_hf = convert_config(model_nt.config) + model_hf = XGLMForCausalLM._from_config(model_config_hf) + + # Copy weights, initialize tokenizer and save model. 
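+    # The tokenizer is optional: when `--tokenizer-name` is given it is saved next to the
+    # converted weights so the output directory can be loaded directly with `AutoTokenizer`.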
+    if tokenizer_name is not None:
+        tokenizer = AutoTokenizer.from_pretrained(tokenizer_name)
+        tokenizer.save_pretrained(save_path)
+    # Copy the nanotron weights into the freshly initialized HF model.
+    convert(model_hf, model_nt)
+    print("Saving...")
+    model_hf.save_pretrained(save_path)
+    print(f"Model saved to {save_path}")
+
+
+if __name__ == "__main__":
+    parser = ArgumentParser(description="Convert nanotron MoE weights to HuggingFace format")
+    parser.add_argument(
+        "--checkpoint-path", type=Path, default="checkpoints/xglm-7.5B", help="Path to the nanotron checkpoint"
+    )
+    parser.add_argument(
+        "--save-path", type=Path, default="facebook/xglm-7.5B", help="Path to save the huggingface model"
+    )
+    parser.add_argument("--tokenizer-name", type=str, default="facebook/xglm-7.5B")
+    args = parser.parse_args()
+    main(args.checkpoint_path, args.save_path, args.tokenizer_name)
diff --git a/examples/xglm/tests/test_moe.py b/examples/xglm/tests/test_moe.py
new file mode 100644
index 00000000..7536b84f
--- /dev/null
+++ b/examples/xglm/tests/test_moe.py
@@ -0,0 +1,182 @@
+import torch
+import pytest
+
+import nanotron
+from nanotron.config.parallelism_config import ParallelismArgs
+from nanotron.config.models_config import GPT3MoEConfig
+from nanotron.parallel import ParallelContext
+from nanotron.parallel.pipeline_parallel.tensor_pointer import TensorPointer
+from nanotron.trainer import mark_tied_parameters
+from nanotron.models.gpt3_moe import GPT3MoEBlock, GPT3MoEForTraining
+from nanotron.models.moe import LearnedRouter, dMoE
+
+from tests.helpers.utils import init_distributed
+
+from examples.xglm.convert_ntmoe2hf import convert_config, convert_gate, convert_ff, convert
+from examples.xglm.tests.test_implementation import almost_close
+from examples.xglm.transformers_impl.xglm_model import XGLMSparseMoeBlock, XGLMForCausalLM
+from examples.xglm.transformers_impl.gating import BasicGate
+
+
+MAX_SEQUENCE_LENGTH = 2048
+TEST_SEQUENCE_LENGTH = 128  # With a very long test sequence, precision errors become significant even when the implementation is correct.
+#TEST_SEQUENCE_LENGTH = MAX_SEQUENCE_LENGTH
+BATCH_SIZE = 4
+HIDDEN_SIZE = 1024
+#DTYPE = torch.bfloat16
+DTYPE = torch.float32
+TEXT = "Hello. This is a relatively long text. I will use this text to test the conversion scripts. Let's finish this text soon because I don't have much more to say. 
Final note:" + +CONFIG = GPT3MoEConfig( + attn_pdrop=0.0, + embd_pdrop=0.0, + resid_pdrop=0.0, + act_pdrop=0.0, + eos_token_id=2, + hidden_size=HIDDEN_SIZE, + intermediate_size=4096, + layer_norm_epsilon=1e-05, + max_position_embeddings=MAX_SEQUENCE_LENGTH, + num_attention_heads=16, + num_hidden_layers=24, + scale_attn_weights=True, + vocab_size=256008, + sinusoidal_position_embedding=True, + position_embedding_offset=2, + use_spda=DTYPE is not torch.bfloat16, + # vvv moe vvv + is_moe=True, + moe_num_experts=8, + num_experts_per_tok=2, + moe_loss_weight=0.01, + moe_z_loss_weight=0.0, + moe_glu=False, +) +PARALLEL_CONFIG = ParallelismArgs(dp=1, pp=1, tp=1, expert_parallel_size=1) #CONFIG.moe_num_experts) + + +@pytest.fixture +def hidden_states() -> torch.Tensor: + return torch.randn(TEST_SEQUENCE_LENGTH, BATCH_SIZE, HIDDEN_SIZE, dtype=DTYPE) + + +@pytest.fixture +def input_mask() -> torch.Tensor: + return torch.ones(BATCH_SIZE, TEST_SEQUENCE_LENGTH, dtype=torch.bool) + + +@pytest.fixture +def input_ids() -> torch.Tensor: + return torch.randint(0, CONFIG.vocab_size, (BATCH_SIZE, TEST_SEQUENCE_LENGTH)) + + +def _test_nt2hf_gate(parallel_context: ParallelContext, hidden_states: torch.Tensor): + hidden_states = hidden_states.cuda() + + config_hf = convert_config(CONFIG) + gate_nt = LearnedRouter(CONFIG).cuda().to(DTYPE) + gate_hf = BasicGate(config_hf).cuda().to(DTYPE) + convert_gate(gate_hf, gate_nt) + + router_logits_nt, _, _ = gate_nt(hidden_states.view(-1, HIDDEN_SIZE)) + router_logits_hf = gate_hf(hidden_states.permute(1, 0, 2).reshape(-1, HIDDEN_SIZE), "") + + router_logits_nt = router_logits_nt.view(TEST_SEQUENCE_LENGTH, BATCH_SIZE, -1) + router_logits_hf = router_logits_hf.view(BATCH_SIZE, TEST_SEQUENCE_LENGTH, -1).permute(1, 0, 2) + + assert router_logits_nt.size() == router_logits_hf.size() + torch.testing.assert_close(router_logits_nt, router_logits_hf) + + +def test_nt2hf_gate(hidden_states: torch.Tensor): + init_distributed(tp=1, dp=1, pp=1)(_test_nt2hf_gate)(hidden_states=hidden_states) + + +def _test_nt2hf_ff(parallel_context: ParallelContext, hidden_states: torch.Tensor, + num_experts: int, num_experts_per_tok: int): + hidden_states = hidden_states.cuda() + + config = {**vars(CONFIG)} + config.update({"moe_num_experts": num_experts, "num_experts_per_tok": num_experts_per_tok}) + config = GPT3MoEConfig(**config) + config_hf = convert_config(config) + ff_nt = dMoE(config, parallel_context, PARALLEL_CONFIG).cuda().to(DTYPE) + ff_hf = XGLMSparseMoeBlock(config_hf).cuda().to(DTYPE) + convert_ff(ff_hf, ff_nt) + + out_nt = ff_nt(hidden_states)["hidden_states"] + out_hf, _ = ff_hf(hidden_states.permute(1, 0, 2).contiguous(), "") + out_hf = out_hf.permute(1, 0, 2) + + assert out_nt.size() == out_hf.size() + almost_close(out_nt, out_hf, max_far=0.05, far_atol=0.003) + + +@pytest.mark.parametrize("num_experts,num_experts_per_tok", [(1, 1), (2, 1), (4, 1), (4, 2), (8, 1), (8, 2), (8, 4)]) +def test_nt2hf_ff(hidden_states: torch.Tensor, num_experts: int, num_experts_per_tok: int): + init_distributed(tp=1, dp=1, pp=1)(_test_nt2hf_ff)(hidden_states=hidden_states, num_experts=num_experts, num_experts_per_tok=num_experts_per_tok) + + +def _test_nt2hf_model(parallel_context: ParallelContext, input_ids: torch.Tensor, input_mask: torch.Tensor): + random_states = nanotron.random.RandomStates({"tp_synced": nanotron.random.get_current_random_state()}) + input_ids = input_ids.cuda() + input_mask = input_mask.cuda() + + # unfortunately, we can't use float64 with huggingface xglm. 
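+    # (For dtypes other than bf16/fp16 the nanotron config below also falls back to
+    # `use_spda=True`, presumably because the fused attention path needs half precision.)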
+ new_dtype = torch.float32 if DTYPE == torch.float64 else DTYPE + + # Get nanotron model. + config_nt = GPT3MoEConfig(**vars(CONFIG)) + if new_dtype not in {torch.bfloat16, torch.float16}: + config_nt.use_spda = True + model_nt = nanotron.models.build_model( + model_builder=lambda: GPT3MoEForTraining( + config=config_nt, + parallel_context=parallel_context, + parallel_config=None, + random_states=random_states, + ), + parallel_context=parallel_context, + dtype=new_dtype, + device="cuda", + ).eval() + mark_tied_parameters(model=model_nt, parallel_context=parallel_context) + + # Create empty model_hf and make conversion. + model_hf = XGLMForCausalLM(convert_config(config_nt)).cuda().to(new_dtype).eval() + convert(model_hf, model_nt) + + # Needed :/ + aux_losses = { + "load_balancing_loss": ( + torch.zeros(1, device=input_ids.device) + if not isinstance(input_ids, TensorPointer) + else TensorPointer(self.input_pp_rank) + ), + "z_loss": ( + torch.zeros(1, device=input_ids.device) + if not isinstance(input_ids, TensorPointer) + else TensorPointer(self.input_pp_rank) + ), + } + + # Get outputs and assert. + with torch.no_grad(): + out_nt = model_nt.model(input_ids, input_mask, aux_losses)["sharded_logits"].to(new_dtype) + del model_nt + torch.cuda.empty_cache() + out_hf = model_hf(input_ids=input_ids, attention_mask=input_mask, output_router_logits=False).logits.permute(1, 0, 2) + del model_hf + torch.cuda.empty_cache() + assert out_nt.size() == out_hf.size(), f"{out_nt.size()}, {out_hf.size()}" + return out_nt.cpu(), out_hf.cpu() + + +def _test_nt2hf_dummy_xglm(parallel_context: ParallelContext, input_ids: torch.Tensor, input_mask: torch.Tensor): + out_nt, out_hf = _test_nt2hf_model(parallel_context, input_ids, input_mask) + almost_close(out_nt, out_hf, max_far=0.01, far_atol=2.0) # We allow for less than 1% errors, but some of these are very large! + #torch.testing.assert_close(out_nt.bfloat16(), out_hf.bfloat16()) + + +def test_nt2hf_dummy_xglm(input_ids: torch.Tensor, input_mask: torch.Tensor): + init_distributed(tp=1, dp=1, pp=1)(_test_nt2hf_dummy_xglm)(input_ids=input_ids, input_mask=input_mask) diff --git a/examples/xglm/transformers_impl/gating.py b/examples/xglm/transformers_impl/gating.py new file mode 100644 index 00000000..efa0e357 --- /dev/null +++ b/examples/xglm/transformers_impl/gating.py @@ -0,0 +1,149 @@ +import torch +import torch.nn as nn +from torch.nn import functional as F +import math + +from abc import ABC, abstractmethod + + +class Gate(ABC): + def __init__(self, device): + super(Gate, self).__init__() + self.device = device + + @abstractmethod + def compute(self, x): + """ + Compute the output of the gate. + This method should be implemented by all subclasses. 
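+        Concrete subclasses (e.g. `TopKGate`) return the routing tensors consumed by the
+        MoE layer: a combine tensor, the selected expert indices and the gate scores.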
+ """ + pass + + +def init_x_embeddings(Xs, x_embedding_dim): + x2embeddings = nn.ParameterDict(dict()) + for x in Xs: + x_embedding = torch.empty(x_embedding_dim) + nn.init.normal_(x_embedding) + x2embeddings[str(x)] = nn.Parameter(x_embedding) + return x2embeddings + + +class BasicGate(nn.Module): + """One or two layer feedforward network as the Gate.""" + + def __init__(self, config) -> None: + super().__init__() + + self.hidden_dim = config.hidden_size + self.num_experts = config.num_local_experts + self.ffn_dim = config.ffn_dim + self.activation = nn.ReLU(self.ffn_dim) + + if config.gate_depth == 1: + self.gate = nn.Linear(self.hidden_dim, self.num_experts, bias=False) + elif config.gate_depth == 2: + self.gate = nn.Sequential( + nn.Linear(self.hidden_dim, self.ffn_dim), + self.activation, + nn.Linear(self.ffn_dim, self.num_experts, bias=False), + ) + else: + raise ValueError("Invalid gate_depth!") + + def forward(self, x, lang_name): + return self.gate(x) + + +class LanguageAwareGate(nn.Module): + """One or two layer feedforward network as the Gate.""" + + def __init__(self, config) -> None: + super().__init__() + + self.hidden_dim = config.hidden_size + self.num_experts = config.num_local_experts + self.ffn_dim = config.ffn_dim + self.activation = nn.ReLU(self.ffn_dim) + self.language_embedding_dim = ( + config.language_embedding_dim + if config.language_embedding_dim is not None + else config.hidden_size + ) + self.lang_embeddings = init_x_embeddings( + config.languages, self.language_embedding_dim + ) + + if config.gate_depth == 1: + self.gate = nn.Linear( + self.hidden_dim + self.language_embedding_dim, + self.num_experts, + bias=False, + ) + elif config.gate_depth == 2: + self.gate = nn.Sequential( + nn.Linear(self.hidden_dim, self.ffn_dim), + self.activation, + nn.Linear(self.ffn_dim, self.num_experts, bias=False), + ) + else: + raise ValueError("Invalid gate_depth!") + + def forward(self, x, lang_name): + # TODO x needs to be added to the language embedding (we need to pass the language as well) + lang_embedding = self.lang_embeddings[str(lang_name)] + lang_embedding.squeeze(0) + lang_embedding = lang_embedding.expand(x.shape[0], -1) + x = torch.cat((x, lang_embedding), dim=-1) + return self.gate(x) + + +class TopKGate(Gate): + def __init__(self, device, straight_through, k=1): + super(TopKGate, self).__init__(device) + self.k = k + self.device = device + self.straight_through = straight_through + + def compute(self, x): + if self.k > 1: + topk_gate_scores, indices = torch.topk(x, self.k) + topk_gate_scores = F.softmax( + topk_gate_scores, + dim=1, + dtype=torch.float, + ).type_as(x) + mask = F.one_hot(indices, x.shape[-1]).float() + mask_flat = mask.sum(dim=-1) + combine_tensor = ( + topk_gate_scores[..., None, None, None] + * mask_flat[..., None, None, None] + * F.one_hot(indices, x.shape[-1])[..., None, None] + ) + combine_tensor = combine_tensor.sum(1) + return combine_tensor, indices, topk_gate_scores + elif self.k == 1: + x = F.softmax(x, dim=-1) + topk_gate_scores, index = x.topk( + k=self.k, dim=-1 + ) # torch.nn.functional.softmax(x , dim=-1).topk(k=self.k, dim=-1) + if self.straight_through: + index_soft = F.softmax(x, dim=-1) + index = (index - index_soft).detach() + index_soft + index = index[:, 0] + topk_gate_scores, index = map( + lambda x: x.squeeze(dim=-1), (topk_gate_scores, index) + ) + else: + topk_gate_scores, index = map( + lambda x: x.squeeze(dim=-1), (topk_gate_scores, index) + ) + + mask = F.one_hot(index, x.shape[-1]).float() + mask_flat = 
mask.sum(dim=-1) + combine_tensor = ( + topk_gate_scores[..., None, None, None] + * mask_flat[..., None, None, None] + * F.one_hot(index, x.shape[-1])[..., None, None] + ) + return combine_tensor, index, topk_gate_scores diff --git a/examples/xglm/transformers_impl/xglm_model.py b/examples/xglm/transformers_impl/xglm_model.py new file mode 100644 index 00000000..aa80a7fe --- /dev/null +++ b/examples/xglm/transformers_impl/xglm_model.py @@ -0,0 +1,1119 @@ +from typing import List, Optional, Tuple, Union, Literal +import math +from dataclasses import dataclass + +import torch +import torch.nn as nn +from torch.nn import functional as F +from torch.nn import CrossEntropyLoss + +from transformers.models.xglm.modeling_xglm import ( + XGLMAttention, + XGLMPreTrainedModel, + XGLMSinusoidalPositionalEmbedding, +) +from transformers.models.xglm.configuration_xglm import XGLMConfig +from transformers.activations import ACT2FN +from transformers.modeling_attn_mask_utils import ( + _prepare_4d_attention_mask, + _prepare_4d_causal_attention_mask, +) + +# from transformers.modeling_outputs import CausalLMOutputWithCrossAttentions, BaseModelOutputWithPastAndCrossAttentions +from transformers.utils import ( + add_start_docstrings, + add_start_docstrings_to_model_forward, + add_code_sample_docstrings, + ModelOutput, +) +from examples.xglm.transformers_impl.gating import BasicGate, LanguageAwareGate + +from accelerate.logging import get_logger + +logger = get_logger(__name__) + +_CHECKPOINT_FOR_DOC = "facebook/xglm-564M" +_CONFIG_FOR_DOC = "XGLMConfig" + + +class XGLMmoeConfig(XGLMConfig): + def __init__(self, + num_local_experts: int = 1, + num_experts_per_tok: int = 1, + gate_type: Literal["linear", "lang_aware_linear"] = "linear", + gate_depth: Literal[1, 2] = 1, + router_aux_loss_coef: float = 1.0, + **kwargs): + + super().__init__(**kwargs) + self.num_local_experts = num_local_experts + self.num_experts_per_tok = num_experts_per_tok + self.gate_type = gate_type + self.gate_depth = gate_depth + self.router_aux_loss_coef = router_aux_loss_coef + + +@dataclass +class MoEModelOutputWithPast(ModelOutput): + """ + Base class for model's outputs, with potential hidden states and attentions. + + Args: + last_hidden_state (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`): + Sequence of hidden-states at the output of the last layer of the model. + past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`): + Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of shape + `(batch_size, num_heads, sequence_length, embed_size_per_head)`) and optionally if + `config.is_encoder_decoder=True` 2 additional tensors of shape `(batch_size, num_heads, + encoder_sequence_length, embed_size_per_head)`. + + Contains pre-computed hidden-states (key and values in the self-attention blocks and optionally if + `config.is_encoder_decoder=True` in the cross-attention blocks) that can be used (see `past_key_values` + input) to speed up sequential decoding. + hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): + Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, + + one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`. 
+ + Hidden-states of the model at the output of each layer plus the optional initial embedding outputs. + attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`): + Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length, + sequence_length)`. + + Attentions weights after the attention softmax, used to compute the weighted average in the self-attention + heads. + router_logits (`tuple(torch.FloatTensor)`, *optional*, returned when `output_router_probs=True` and `config.add_router_probs=True` is passed or when `config.output_router_probs=True`): + Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, sequence_length, num_experts)`. + + Raw router logtis (post-softmax) that are computed by MoE routers, these terms are used to compute the auxiliary + loss for Mixture of Experts models. + """ + + last_hidden_state: torch.FloatTensor = None + past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None + hidden_states: Optional[Tuple[torch.FloatTensor, ...]] = None + attentions: Optional[Tuple[torch.FloatTensor, ...]] = None + router_logits: Optional[Tuple[torch.FloatTensor]] = None + cross_attentions: Optional[Tuple[torch.FloatTensor, ...]] = None + + +@dataclass +class MoECausalLMOutputWithPast(ModelOutput): + """ + Base class for causal language model (or autoregressive) outputs as well as Mixture of Expert's router hidden + states terms, to train a MoE model. + + Args: + loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided): + Language modeling loss (for next-token prediction). + logits (`torch.FloatTensor` of shape `(batch_size, sequence_length, config.vocab_size)`): + Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax). + past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`): + Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of shape + `(batch_size, num_heads, sequence_length, embed_size_per_head)`) + + Contains pre-computed hidden-states (key and values in the self-attention blocks) that can be used (see + `past_key_values` input) to speed up sequential decoding. + hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): + Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, + + one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`. + + Hidden-states of the model at the output of each layer plus the optional initial embedding outputs. + attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`): + Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length, + sequence_length)`. + + Attentions weights after the attention softmax, used to compute the weighted average in the self-attention + heads. + z_loss (`torch.FloatTensor`, *optional*, returned when `labels` is provided): + z_loss for the sparse modules. + aux_loss (`torch.FloatTensor`, *optional*, returned when `labels` is provided): + aux_loss for the sparse modules. 
+ router_logits (`tuple(torch.FloatTensor)`, *optional*, returned when `output_router_logits=True` is passed or when `config.add_router_probs=True`): + Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, sequence_length, num_experts)`. + + Router logits of the encoder model, useful to compute the auxiliary loss and the z_loss for the sparse + modules. + """ + + loss: Optional[torch.FloatTensor] = None + logits: torch.FloatTensor = None + past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None + hidden_states: Optional[Tuple[torch.FloatTensor, ...]] = None + attentions: Optional[Tuple[torch.FloatTensor, ...]] = None + cross_attentions: Optional[Tuple[torch.FloatTensor, ...]] = None + z_loss: torch.FloatTensor = None + aux_loss: torch.FloatTensor = None + router_logits: Optional[Tuple[torch.FloatTensor]] = None + + +XGLM_START_DOCSTRING = r""" + This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the + library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads + etc.) + + This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass. + Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage + and behavior. + + Parameters: + config ([`XGLMConfig`]): + Model configuration class with all the parameters of the model. Initializing with a config file does not + load the weights associated with the model, only the configuration. Check out the + [`~PreTrainedModel.from_pretrained`] method to load the model weights. +""" +XGLM_INPUTS_DOCSTRING = r""" + Args: + input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`): + Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide + it. + + Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and + [`PreTrainedTokenizer.__call__`] for details. + + [What are input IDs?](../glossary#input-ids) + attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*): + Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`: + + - 1 for tokens that are **not masked**, + - 0 for tokens that are **masked**. + + [What are attention masks?](../glossary#attention-mask) + position_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): + Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0, + config.max_position_embeddings - 1]`. + + [What are position IDs?](../glossary#position-ids) + encoder_hidden_states (`torch.FloatTensor` of shape `(batch_size, encoder_sequence_length, hidden_size)`, *optional*): + Sequence of hidden-states at the output of the last layer of the encoder. Used in the cross-attention of + the decoder. + encoder_attention_mask (`torch.LongTensor` of shape `(batch_size, encoder_sequence_length)`, *optional*): + Mask to avoid performing cross-attention on padding tokens indices of encoder input_ids. Mask values + selected in `[0, 1]`: + + - 1 for tokens that are **not masked**, + - 0 for tokens that are **masked**. + + [What are attention masks?](../glossary#attention-mask) + head_mask (`torch.Tensor` of shape `(num_layers, attention_heads)`, *optional*): + Mask to nullify selected heads of the attention modules. 
Mask values selected in `[0, 1]`: + + - 1 indicates the head is **not masked**, + - 0 indicates the head is **masked**. + + cross_attn_head_mask (`torch.Tensor` of shape `(num_layers, attention_heads)`, *optional*): + Mask to nullify selected heads of the cross-attention modules. Mask values selected in `[0, 1]`: + + - 1 indicates the head is **not masked**, + - 0 indicates the head is **masked**. + + past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`): + Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of shape + `(batch_size, num_heads, sequence_length, embed_size_per_head)`) and 2 additional tensors of shape + `(batch_size, num_heads, encoder_sequence_length, embed_size_per_head)`. + + Contains pre-computed hidden-states (key and values in the self-attention blocks and in the cross-attention + blocks) that can be used (see `past_key_values` input) to speed up sequential decoding. + + If `past_key_values` are used, the user can optionally input only the last `decoder_input_ids` (those that + don't have their past key value states given to this model) of shape `(batch_size, 1)` instead of all + `decoder_input_ids` of shape `(batch_size, sequence_length)`. + inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*): + Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. + This is useful if you want more control over how to convert `input_ids` indices into associated vectors + than the model's internal embedding lookup matrix. + output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned + tensors for more detail. + output_hidden_states (`bool`, *optional*): + Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for + more detail. + return_dict (`bool`, *optional*): + Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. +""" + + +def load_balancing_loss_func( + gate_logits: torch.Tensor, + num_experts: torch.Tensor = None, + top_k=2, + attention_mask: Optional[torch.Tensor] = None, +) -> float: + r""" + Computes auxiliary load balancing loss as in Switch Transformer - implemented in Pytorch. + + See Switch Transformer (https://arxiv.org/abs/2101.03961) for more details. This function implements the loss + function presented in equations (4) - (6) of the paper. It aims at penalizing cases where the routing between + experts is too unbalanced. + + Args: + gate_logits (Union[`torch.Tensor`, Tuple[torch.Tensor]): + Logits from the `gate`, should be a tuple of model.config.num_hidden_layers tensors of + shape [batch_size X sequence_length, num_experts]. + attention_mask (`torch.Tensor`, None): + The attention_mask used in forward function + shape [batch_size X sequence_length] if not None. + num_experts (`int`, *optional*): + Number of experts + + Returns: + The auxiliary loss. 
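+        (The value is scaled by `num_experts`; with perfectly balanced routing it comes out
+        to roughly `top_k`, and it grows as routing becomes more unbalanced.)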
+ """ + if gate_logits is None or not isinstance(gate_logits, tuple): + return torch.zeros(1) + + if isinstance(gate_logits, tuple): + compute_device = gate_logits[0].device + concatenated_gate_logits = torch.cat( + [layer_gate.to(compute_device) for layer_gate in gate_logits], dim=0 + ) + + routing_weights = torch.nn.functional.softmax(concatenated_gate_logits, dim=-1) + + _, selected_experts = torch.topk(routing_weights, top_k, dim=-1) + + expert_mask = torch.nn.functional.one_hot(selected_experts, num_experts) + + if attention_mask is None: + # Compute the percentage of tokens routed to each experts + tokens_per_expert = torch.mean(expert_mask.float(), dim=0) + + # Compute the average probability of routing to these experts + router_prob_per_expert = torch.mean(routing_weights, dim=0) + else: + batch_size, sequence_length = attention_mask.shape + num_hidden_layers = concatenated_gate_logits.shape[0] // ( + batch_size * sequence_length + ) + + # Compute the mask that masks all padding tokens as 0 with the same shape of expert_mask + expert_attention_mask = ( + attention_mask[None, :, :, None, None] + .expand( + (num_hidden_layers, batch_size, sequence_length, top_k, num_experts) + ) + .reshape(-1, top_k, num_experts) + .to(compute_device) + ) + + # Compute the percentage of tokens routed to each experts + tokens_per_expert = torch.sum( + expert_mask.float() * expert_attention_mask, dim=0 + ) / torch.sum(expert_attention_mask, dim=0) + + # Compute the mask that masks all padding tokens as 0 with the same shape of tokens_per_expert + router_per_expert_attention_mask = ( + attention_mask[None, :, :, None] + .expand((num_hidden_layers, batch_size, sequence_length, num_experts)) + .reshape(-1, num_experts) + .to(compute_device) + ) + + # Compute the average probability of routing to these experts + router_prob_per_expert = torch.sum( + routing_weights * router_per_expert_attention_mask, dim=0 + ) / torch.sum(router_per_expert_attention_mask, dim=0) + + overall_loss = torch.sum(tokens_per_expert * router_prob_per_expert.unsqueeze(0)) + return overall_loss * num_experts + + +class XGLMMLP(nn.Module): + + def __init__(self, config: XGLMConfig): + super().__init__() + self.ffn_dim = config.ffn_dim + self.hidden_dim = config.d_model + + self.fc1 = nn.Linear(self.hidden_dim, self.ffn_dim, bias=False) + self.fc2 = nn.Linear(self.ffn_dim, self.hidden_dim, bias=False) + + self.activation_fn = ACT2FN[config.activation_function] + self.dropout = config.dropout + self.activation_dropout = config.activation_dropout + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + hidden_states = self.activation_fn(self.fc1(hidden_states)) + hidden_states = nn.functional.dropout( + hidden_states, p=self.activation_dropout, training=self.training + ) + hidden_states = self.fc2(hidden_states) + # TODO this dropout can be removed if we use the dropout in the XGLMDecoderLayer + hidden_states = nn.functional.dropout( + hidden_states, p=self.dropout, training=self.training + ) + return hidden_states + + def set_mlp_weights(self, fc1, fc2, std=0.02): + """Set the weights of the MLP to the given weights""" + + self.fc1.weight.data = fc1.weight.data.clone() + # norm1_pre = self.fc1.weight.data.norm(2) + self.fc1.weight.data.add_( + torch.normal(mean=0, std=std, size=self.fc1.weight.shape) + ) + # norm1_after = self.fc1.weight.data.norm(2) + + self.fc1.bias.data = fc1.bias.data.clone() + self.fc1.bias.data.add_(torch.normal(mean=0, std=std, size=self.fc1.bias.shape)) + + self.fc2.weight.data = 
fc2.weight.data.clone() + # norm2_pre = self.fc2.weight.data.norm(2) + self.fc2.weight.data.add_( + torch.normal(mean=0, std=std, size=self.fc2.weight.shape) + ) + # norm2_after = self.fc2.weight.data.norm(2) + self.fc2.bias.data = fc2.bias.data.clone() + self.fc2.bias.data.add_(torch.normal(mean=0, std=std, size=self.fc2.bias.shape)) + + # print("********") + # print(f"Norm1: {norm1_pre} -> {norm1_after}") + # print(f"Norm2: {norm2_pre} -> {norm2_after}") + # print("********") + + +class XGLMSparseMoeBlock(nn.Module): + """ + This implementation is + strictly equivalent to standard MoE with full capacity (no + dropped tokens). It's faster since it formulates MoE operations + in terms of block-sparse operations to accomodate imbalanced + assignments of tokens to experts, whereas standard MoE either + (1) drop tokens at the cost of reduced performance or (2) set + capacity factor to number of experts and thus waste computation + and memory on padding. + """ + + def __init__(self, config): + super().__init__() + self.hidden_dim = config.hidden_size + # self.ffn_dim = config.intermediate_size + self.num_experts = config.num_local_experts + self.top_k = config.num_experts_per_tok + + # gating network + if config.gate_type == "linear": + self.gate = BasicGate(config) + elif config.gate_type == "lang_aware_linear": + self.gate = LanguageAwareGate(config) + # self.gate = nn.Linear(self.hidden_dim, self.num_experts, bias=False) + + self.experts = nn.ModuleList([XGLMMLP(config) for _ in range(self.num_experts)]) + + def forward(self, hidden_states: torch.Tensor, lang_name: str) -> torch.Tensor: + """ """ + batch_size, sequence_length, hidden_dim = hidden_states.shape + hidden_states = hidden_states.view(-1, hidden_dim) + # router_logits: (batch * sequence_length, n_experts) + router_logits = self.gate(hidden_states, lang_name) + + routing_weights = F.softmax(router_logits, dim=1, dtype=torch.float) + if self.top_k == 1: + routing_weights, selected_experts = routing_weights.max(dim=-1, keepdim=True) + else: + routing_weights, selected_experts = torch.topk( + routing_weights, self.top_k, dim=-1 + ) + routing_weights /= routing_weights.sum(dim=-1, keepdim=True) + # we cast back to the input dtype + routing_weights = routing_weights.to(hidden_states.dtype) + + final_hidden_states = torch.zeros( + (batch_size * sequence_length, hidden_dim), + dtype=hidden_states.dtype, + device=hidden_states.device, + ) + + # One hot encode the selected experts to create an expert mask + # this will be used to easily index which expert is going to be sollicitated + expert_mask = torch.nn.functional.one_hot( + selected_experts, num_classes=self.num_experts + ).permute(2, 1, 0) + + # Loop over all available experts in the model and perform the computation on each expert + for expert_idx in range(self.num_experts): + expert_layer = self.experts[expert_idx] + idx, top_x = torch.where(expert_mask[expert_idx]) + + if top_x.shape[0] == 0: + continue + + # in torch it is faster to index using lists than torch tensors + top_x_list = top_x.tolist() + idx_list = idx.tolist() + + # Index the correct hidden states and compute the expert hidden state for + # the current expert. 
We need to make sure to multiply the output hidden + # states by `routing_weights` on the corresponding tokens (top-1 and top-2) + current_state = hidden_states[None, top_x_list].reshape(-1, hidden_dim) + current_hidden_states = ( + expert_layer(current_state) + * routing_weights[top_x_list, idx_list, None] + ) + + # However `index_add_` only support torch tensors for indexing so we'll use + # the `top_x` tensor here. + final_hidden_states.index_add_( + 0, top_x, current_hidden_states.to(hidden_states.dtype) + ) + final_hidden_states = final_hidden_states.reshape( + batch_size, sequence_length, hidden_dim + ) + return final_hidden_states, router_logits + + def set_mlp_weights(self, fc1, fc2, std=0.02): + """Set the expert (MLP) weights by given weights""" + for i in range(self.num_experts): + self.experts[i].set_mlp_weights(fc1, fc2, std) + + +class XGLMDecoderLayer(nn.Module): + def __init__(self, config: XGLMConfig): + super().__init__() + self.embed_dim = config.d_model + + self.self_attn = XGLMAttention( + embed_dim=self.embed_dim, + num_heads=config.attention_heads, + dropout=config.attention_dropout, + is_decoder=True, + ) + self.dropout = config.dropout + # self.activation_fn = ACT2FN[config.activation_function] + # self.activation_dropout = config.activation_dropout + + if config.add_cross_attention: + self.encoder_attn = XGLMAttention( + embed_dim=self.embed_dim, + num_heads=config.attention_heads, + dropout=config.attention_dropout, + is_decoder=True, + ) + self.encoder_attn_layer_norm = nn.LayerNorm(self.embed_dim) + + self.block_sparse_moe = XGLMSparseMoeBlock(config) + self.self_attn_layer_norm = nn.LayerNorm(self.embed_dim) + self.final_layer_norm = nn.LayerNorm(self.embed_dim) + + # Copied from transformers.models.mbart.modeling_mbart.MBartDecoderLayer.forward + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + encoder_hidden_states: Optional[torch.Tensor] = None, + encoder_attention_mask: Optional[torch.Tensor] = None, + layer_head_mask: Optional[torch.Tensor] = None, + cross_attn_layer_head_mask: Optional[torch.Tensor] = None, + past_key_value: Optional[Tuple[torch.Tensor]] = None, + output_attentions: Optional[bool] = False, + output_router_logits: Optional[bool] = False, + use_cache: Optional[bool] = True, + lang_name: Optional[str] = None, + ) -> torch.Tensor: + """ + Args: + hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)` + attention_mask (`torch.FloatTensor`): attention mask of size + `(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values. + encoder_hidden_states (`torch.FloatTensor`): + cross attention input to the layer of shape `(batch, seq_len, embed_dim)` + encoder_attention_mask (`torch.FloatTensor`): encoder attention mask of size + `(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values. + layer_head_mask (`torch.FloatTensor`): mask for attention heads in a given layer of size + `(encoder_attention_heads,)`. + cross_attn_layer_head_mask (`torch.FloatTensor`): mask for cross-attention heads in a given layer of + size `(decoder_attention_heads,)`. + past_key_value (`Tuple(torch.FloatTensor)`): cached past key and value projection states + output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under + returned tensors for more detail. 
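+            output_router_logits (`bool`, *optional*):
+                Whether or not to return the router logits of the sparse MoE block.
+            use_cache (`bool`, *optional*):
+                Whether or not to return the present key/value states (used to speed up decoding).
+            lang_name (`str`, *optional*):
+                Language tag forwarded to the MoE gate; only used by `LanguageAwareGate`.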
+ """ + residual = hidden_states + hidden_states = self.self_attn_layer_norm(hidden_states) + + # Self Attention + # decoder uni-directional self-attention cached key/values tuple is at positions 1,2 + self_attn_past_key_value = ( + past_key_value[:2] if past_key_value is not None else None + ) + # add present self-attn cache to positions 1,2 of present_key_value tuple + hidden_states, self_attn_weights, present_key_value = self.self_attn( + hidden_states=hidden_states, + past_key_value=self_attn_past_key_value, + attention_mask=attention_mask, + layer_head_mask=layer_head_mask, + output_attentions=output_attentions, + ) + hidden_states = nn.functional.dropout( + hidden_states, p=self.dropout, training=self.training + ) + hidden_states = residual + hidden_states + + # Cross-Attention Block + cross_attn_present_key_value = None + cross_attn_weights = None + if encoder_hidden_states is not None: + residual = hidden_states + hidden_states = self.encoder_attn_layer_norm(hidden_states) + + # cross_attn cached key/values tuple is at positions 3,4 of present_key_value tuple + cross_attn_past_key_value = ( + past_key_value[-2:] if past_key_value is not None else None + ) + hidden_states, cross_attn_weights, cross_attn_present_key_value = ( + self.encoder_attn( + hidden_states=hidden_states, + key_value_states=encoder_hidden_states, + attention_mask=encoder_attention_mask, + layer_head_mask=cross_attn_layer_head_mask, + past_key_value=cross_attn_past_key_value, + output_attentions=output_attentions, + ) + ) + hidden_states = nn.functional.dropout( + hidden_states, p=self.dropout, training=self.training + ) + hidden_states = residual + hidden_states + + # add cross-attn to positions 3,4 of present_key_value tuple + present_key_value = present_key_value + cross_attn_present_key_value + + # Fully Connected + residual = hidden_states + hidden_states = self.final_layer_norm(hidden_states) + + hidden_states, router_logits = self.block_sparse_moe(hidden_states, lang_name) + # hidden_states = self.activation_fn(self.fc1(hidden_states)) + # hidden_states = nn.functional.dropout(hidden_states, p=self.activation_dropout, training=self.training) + # hidden_states = self.fc2(hidden_states) + # hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training) + hidden_states = residual + hidden_states + + outputs = (hidden_states,) + + if output_attentions: + outputs += (self_attn_weights, cross_attn_weights) + + if use_cache: + outputs += (present_key_value,) + + if output_router_logits: + outputs += (router_logits,) + + return outputs + + +@add_start_docstrings( + "The bare XGLM Model transformer outputting raw hidden-states without any specific head on top.", + XGLM_START_DOCSTRING, +) +class XGLMModel(XGLMPreTrainedModel): + """ + Transformer decoder consisting of *config.num_layers* layers. 
Each layer is a [`XGLMDecoderLayer`] + + Args: + config: XGLMConfig + embed_tokens (nn.Embedding): output embedding + """ + + def __init__(self, config: XGLMConfig, embed_tokens: Optional[nn.Embedding] = None): + super().__init__(config) + self.dropout = config.dropout + self.layerdrop = config.layerdrop + self.padding_idx = config.pad_token_id + self.max_target_positions = config.max_position_embeddings + self.embed_scale = math.sqrt(config.d_model) if config.scale_embedding else 1.0 + + if embed_tokens is not None: + self.embed_tokens = embed_tokens + else: + self.embed_tokens = nn.Embedding( + config.vocab_size, config.d_model, self.padding_idx + ) + + self.embed_positions = XGLMSinusoidalPositionalEmbedding( + config.max_position_embeddings, + config.d_model, + config.pad_token_id, + ) + self.layers = nn.ModuleList( + [XGLMDecoderLayer(config) for _ in range(config.num_layers)] + ) + self.layer_norm = nn.LayerNorm(config.d_model) + + self.gradient_checkpointing = False + # Initialize weights and apply final processing + self.post_init() + + def get_input_embeddings(self): + return self.embed_tokens + + def set_input_embeddings(self, value): + self.embed_tokens = value + + @add_start_docstrings_to_model_forward(XGLM_INPUTS_DOCSTRING) + @add_code_sample_docstrings( + checkpoint=_CHECKPOINT_FOR_DOC, + output_type=MoEModelOutputWithPast, + config_class=_CONFIG_FOR_DOC, + ) + def forward( + self, + input_ids: Optional[torch.Tensor] = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.Tensor] = None, + encoder_hidden_states: Optional[torch.Tensor] = None, + encoder_attention_mask: Optional[torch.Tensor] = None, + head_mask: Optional[torch.Tensor] = None, + cross_attn_head_mask: Optional[torch.Tensor] = None, + past_key_values: Optional[List[torch.FloatTensor]] = None, + inputs_embeds: Optional[torch.Tensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + output_router_logits: Optional[bool] = None, + return_dict: Optional[bool] = None, + lang_name: str = None, + ) -> Union[Tuple[torch.Tensor], MoEModelOutputWithPast]: + output_attentions = ( + output_attentions + if output_attentions is not None + else self.config.output_attentions + ) + output_router_logits = ( + output_router_logits + if output_router_logits is not None + else self.config.output_router_logits + ) + + output_hidden_states = ( + output_hidden_states + if output_hidden_states is not None + else self.config.output_hidden_states + ) + use_cache = use_cache if use_cache is not None else self.config.use_cache + return_dict = ( + return_dict if return_dict is not None else self.config.use_return_dict + ) + + # retrieve input_ids and inputs_embeds + if input_ids is not None and inputs_embeds is not None: + raise ValueError( + "You cannot specify both input_ids and inputs_embeds at the same time" + ) + elif input_ids is not None: + self.warn_if_padding_and_no_attention_mask(input_ids, attention_mask) + input_shape = input_ids.size() + input_ids = input_ids.view(-1, input_shape[-1]) + elif inputs_embeds is not None: + input_shape = inputs_embeds.size()[:-1] + else: + raise ValueError("You have to specify either input_ids or inputs_embeds") + + past_key_values_length = ( + past_key_values[0][0].shape[2] if past_key_values is not None else 0 + ) + + if position_ids is None: + position_ids = torch.arange( + past_key_values_length, + input_shape[-1] + past_key_values_length, + dtype=torch.long, + device=( + 
input_ids.device if input_ids is not None else inputs_embeds.device + ), + ) + position_ids = position_ids.unsqueeze(0) + + if inputs_embeds is None: + inputs_embeds = self.embed_tokens(input_ids) * self.embed_scale + + attention_mask = _prepare_4d_causal_attention_mask( + attention_mask, input_shape, inputs_embeds, past_key_values_length + ) + + # expand encoder attention mask + if encoder_hidden_states is not None and encoder_attention_mask is not None: + # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] + encoder_attention_mask = _prepare_4d_attention_mask( + encoder_attention_mask, inputs_embeds.dtype, tgt_len=input_shape[-1] + ) + + hidden_states = inputs_embeds + self.embed_positions( + position_ids, past_key_values_length + ) + hidden_states = nn.functional.dropout( + hidden_states, p=float(self.dropout), training=self.training + ) + + if self.gradient_checkpointing and self.training: + if use_cache: + logger.warning_once( + "`use_cache = True` is incompatible with gradient checkpointing`. Setting `use_cache =" + " False`..." + ) + use_cache = False + + # decoder layers + all_hidden_states = () if output_hidden_states else None + all_self_attns = () if output_attentions else None + all_router_logits = () if output_router_logits else None + all_cross_attentions = ( + () if (output_attentions and encoder_hidden_states is not None) else None + ) + next_decoder_cache = () if use_cache else None + + # check if head_mask/cross_attn_head_mask has a correct number of layers specified if desired + for attn_mask, mask_name in zip( + [head_mask, cross_attn_head_mask], ["head_mask", "cross_attn_head_mask"] + ): + if attn_mask is not None: + if attn_mask.size()[0] != len(self.layers): + raise ValueError( + f"The `{mask_name}` should be specified for {len(self.layers)} layers, but it is for" + f" {head_mask.size()[0]}." 
+ ) + for idx, decoder_layer in enumerate(self.layers): + # add LayerDrop (see https://arxiv.org/abs/1909.11556 for description) + if output_hidden_states: + all_hidden_states += (hidden_states,) + if self.training: + dropout_probability = torch.rand([]) + if dropout_probability < self.layerdrop: + continue + + past_key_value = ( + past_key_values[idx] if past_key_values is not None else None + ) + + if self.gradient_checkpointing and self.training: + layer_outputs = self._gradient_checkpointing_func( + decoder_layer.__call__, + hidden_states, + attention_mask, + encoder_hidden_states, + encoder_attention_mask, + head_mask[idx] if head_mask is not None else None, + ( + cross_attn_head_mask[idx] + if cross_attn_head_mask is not None + else None + ), + None, + output_attentions, + output_router_logits, + use_cache, + ) + else: + layer_outputs = decoder_layer( + hidden_states, + attention_mask=attention_mask, + encoder_hidden_states=encoder_hidden_states, + encoder_attention_mask=encoder_attention_mask, + layer_head_mask=(head_mask[idx] if head_mask is not None else None), + cross_attn_layer_head_mask=( + cross_attn_head_mask[idx] + if cross_attn_head_mask is not None + else None + ), + past_key_value=past_key_value, + output_attentions=output_attentions, + use_cache=use_cache, + output_router_logits=output_router_logits, + lang_name=lang_name, + ) + hidden_states = layer_outputs[0] + + if use_cache: + next_decoder_cache += (layer_outputs[3 if output_attentions else 1],) + + if output_attentions: + all_self_attns += (layer_outputs[1],) + + if encoder_hidden_states is not None: + all_cross_attentions += (layer_outputs[2],) + + if output_router_logits: + all_router_logits += (layer_outputs[-1],) + + hidden_states = self.layer_norm(hidden_states) + + # add hidden states from the last decoder layer + if output_hidden_states: + all_hidden_states += (hidden_states,) + + next_cache = next_decoder_cache if use_cache else None + if not return_dict: + return tuple( + v + for v in [ + hidden_states, + next_cache, + all_hidden_states, + all_self_attns, + all_cross_attentions, + ] + if v is not None + ) + return MoEModelOutputWithPast( + last_hidden_state=hidden_states, + past_key_values=next_cache, + hidden_states=all_hidden_states, + attentions=all_self_attns, + cross_attentions=all_cross_attentions, + router_logits=all_router_logits, + ) + + +@add_start_docstrings( + """ + The XGLM Model transformer with a language modeling head on top (linear layer with weights tied to the input + embeddings). 
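+    In this MoE variant, each decoder layer replaces the dense feed-forward block with an
+    `XGLMSparseMoeBlock`, so the model can additionally return per-layer router logits and an
+    auxiliary load-balancing loss.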
+ """, + XGLM_START_DOCSTRING, +) +class XGLMForCausalLM(XGLMPreTrainedModel): + base_model_prefix = "model" + _tied_weights_keys = ["lm_head.weight"] + + def __init__(self, config): + super().__init__(config) + self.model = XGLMModel(config) + self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False) + + self.router_aux_loss_coef = config.router_aux_loss_coef + self.num_local_experts = config.num_local_experts + self.num_experts_per_tok = config.num_experts_per_tok + + # Initialize weights and apply final processing + self.post_init() + + def get_input_embeddings(self): + return self.model.embed_tokens + + def set_input_embeddings(self, value): + self.model.embed_tokens = value + + def get_output_embeddings(self): + return self.lm_head + + def set_output_embeddings(self, new_embeddings): + self.lm_head = new_embeddings + + @add_start_docstrings_to_model_forward(XGLM_INPUTS_DOCSTRING) + @add_code_sample_docstrings( + checkpoint=_CHECKPOINT_FOR_DOC, + output_type=MoECausalLMOutputWithPast, + config_class=_CONFIG_FOR_DOC, + ) + def forward( + self, + input_ids: Optional[torch.Tensor] = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.Tensor] = None, + encoder_hidden_states: Optional[torch.Tensor] = None, + encoder_attention_mask: Optional[torch.Tensor] = None, + head_mask: Optional[torch.Tensor] = None, + cross_attn_head_mask: Optional[torch.Tensor] = None, + past_key_values: Optional[List[torch.FloatTensor]] = None, + inputs_embeds: Optional[torch.Tensor] = None, + labels: Optional[torch.Tensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + output_router_logits: Optional[bool] = None, + return_dict: Optional[bool] = None, + lang_name=None, + ) -> Union[Tuple[torch.Tensor], MoECausalLMOutputWithPast]: + r""" + labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): + Labels for computing the masked language modeling loss. Indices should either be in `[0, ..., + config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored + (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`. 
+ """ + + output_attentions = ( + output_attentions + if output_attentions is not None + else self.config.output_attentions + ) + output_router_logits = ( + output_router_logits + if output_router_logits is not None + else self.config.output_router_logits + ) + + output_hidden_states = ( + output_hidden_states + if output_hidden_states is not None + else self.config.output_hidden_states + ) + return_dict = ( + return_dict if return_dict is not None else self.config.use_return_dict + ) + + # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn) + outputs = self.model( + input_ids=input_ids, + attention_mask=attention_mask, + position_ids=position_ids, + encoder_hidden_states=encoder_hidden_states, + encoder_attention_mask=encoder_attention_mask, + head_mask=head_mask, + cross_attn_head_mask=cross_attn_head_mask, + past_key_values=past_key_values, + inputs_embeds=inputs_embeds, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + output_router_logits=output_router_logits, + return_dict=return_dict, + lang_name=lang_name, + ) + + logits = self.lm_head(outputs[0]) + + loss = None + if labels is not None: + # shift labels and add a pad token to the end + shift_labels = labels.new_zeros(labels.shape) + shift_labels[:, :-1] = labels[:, 1:].clone() + shift_labels[:, -1] = self.config.pad_token_id + + loss_fct = CrossEntropyLoss() + loss = loss_fct( + logits.view(-1, self.config.vocab_size), shift_labels.view(-1) + ) + + aux_loss = None + if output_router_logits: + aux_loss = load_balancing_loss_func( + outputs.router_logits if return_dict else outputs[-1], + self.num_local_experts, + self.num_experts_per_tok, + attention_mask, + ) + if labels is not None: + loss += self.router_aux_loss_coef * aux_loss.to( + loss.device + ) # make sure to reside in the same device + + if not return_dict: + output = (logits,) + outputs[1:] + return (loss,) + output if loss is not None else output + + # return CausalLMOutputWithCrossAttentions( + return MoECausalLMOutputWithPast( + loss=loss, + aux_loss=aux_loss, + logits=logits, + past_key_values=outputs.past_key_values, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + cross_attentions=outputs.cross_attentions, + router_logits=outputs.router_logits, + ) + + def prepare_inputs_for_generation( + self, + input_ids, + past_key_values=None, + attention_mask=None, + use_cache=None, + output_router_logits=False, + **kwargs, + ): + if past_key_values is not None: + past_length = past_key_values[0][0].shape[2] + + # Some generation methods already pass only the last input ID + if input_ids.shape[1] > past_length: + remove_prefix_length = past_length + else: + # Default to old behavior: keep only final ID + remove_prefix_length = input_ids.shape[1] - 1 + + input_ids = input_ids[:, remove_prefix_length:] + + position_ids = kwargs.get("position_ids", None) + if attention_mask is not None and position_ids is None: + # create position_ids on the fly for batch generation + position_ids = attention_mask.long().cumsum(-1) - 1 + position_ids.masked_fill_(attention_mask == 0, 1) + if past_key_values: + position_ids = position_ids[:, -input_ids.shape[1] :] + else: + position_ids = None + # if model is used as a decoder in encoder-decoder model, the decoder attention mask is created on the fly + if attention_mask is None: + attention_mask = input_ids.new_ones(input_ids.shape) + # first step, decoder_cached_states are empty + return { + "input_ids": input_ids, # encoder_outputs is 
defined. input_ids not needed
+            "attention_mask": attention_mask,
+            "position_ids": position_ids,
+            "past_key_values": past_key_values,
+            "use_cache": use_cache,
+            "output_router_logits": output_router_logits,
+        }
+
+    @staticmethod
+    def _reorder_cache(past_key_values, beam_idx):
+        reordered_past = ()
+        for layer_past in past_key_values:
+            reordered_past += (
+                tuple(
+                    past_state.index_select(0, beam_idx.to(past_state.device))
+                    for past_state in layer_past
+                ),
+            )
+        return reordered_past
+
+
+def set_moe_layers(model: nn.Module, config: XGLMConfig):
+    """Replaces each of the model's original decoder layers with an MoE layer and sets the expert (MLP) weights from the original MLP weights."""
+    new_layers = []
+    for i in range(config.num_layers):
+        layer = model.model.layers[i]
+        moe_layer = XGLMDecoderLayer(config)
+        moe_layer.block_sparse_moe.set_mlp_weights(
+            layer.fc1, layer.fc2, config.expert_init_std
+        )
+        new_layers.append(moe_layer)
+    model.model.layers = nn.ModuleList(new_layers)
+
+
+def freeze_parameters(model, args):
+    """Freezes all non-MoE parameters (except lm_head) if requested."""
+    if args.freeze_non_moe_params:
+        for name, param in model.named_parameters():
+            name = name.lower()
+            if "expert" not in name and "gate" not in name and "lm_head" not in name:
+                param.requires_grad = False
+            else:
+                param.requires_grad = True
+
+        logger.info("All non-MoE parameters except lm_head are frozen!")
+
+    model.lm_head.weight.requires_grad = True
+    # model.lm_head.bias.requires_grad = True
+
+    # for name, param in model.named_parameters():
+    #     if "block_sparse_moe" not in name:
+    #         print(f"{name}: {param.requires_grad}")
+
+
+def copy_model_weights(pretrained_model: nn.Module, moe_model: XGLMForCausalLM, config: XGLMConfig):
+    """Copies the pretrained dense model's weights into the MoE model and initializes every expert from the original MLP weights."""
+    state_dicts = pretrained_model.state_dict()
+    moe_model.load_state_dict(state_dicts, strict=False)
+
+    # new_layers = []
+    for i in range(config.num_layers):
+        layer = pretrained_model.model.layers[i]
+        moe_model.model.layers[i].block_sparse_moe.set_mlp_weights(
+            layer.fc1, layer.fc2, config.init_std
+        )
+        # moe_layer = XGLMDecoderLayer(config)
+        # moe_layer.block_sparse_moe.set_mlp_weights(layer.fc1, layer.fc2, config.init_std)
+        # new_layers.append(moe_layer)
+    # moe_model.model.layers = nn.ModuleList(new_layers)
diff --git a/src/nanotron/models/moe.py b/src/nanotron/models/moe.py
index f7bb07bd..98add57d 100644
--- a/src/nanotron/models/moe.py
+++ b/src/nanotron/models/moe.py
@@ -718,4 +718,4 @@ def forward(self, x, topo):
     x1 = self.sdd(x, self.w1.module.weight, topo)
     x2 = self.sdd(x, self.w3.module.weight, topo)
     x = stk.ops.mul(act_fn(x1, self.act), x2)
-    return self.dsd(x, self.w2.module.weight)
\ No newline at end of file
+    return self.dsd(x, self.w2.module.weight)
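Once a checkpoint has been converted, a short script along the following lines can serve as a smoke test. This is only a sketch: the checkpoint path, prompt and tokenizer are placeholders, and it assumes the custom `XGLMmoeConfig` / `XGLMForCausalLM` classes from `examples/xglm/transformers_impl/xglm_model.py` are importable.

```python
# Smoke test for a converted MoE checkpoint (sketch; paths are placeholders).
import torch
from transformers import AutoTokenizer

from examples.xglm.transformers_impl.xglm_model import XGLMForCausalLM, XGLMmoeConfig

save_path = "checkpoints/huggingface/xglm-8x564M"  # wherever convert_ntmoe2hf.py wrote to

# Load the MoE config explicitly so the extra fields (num_local_experts, ...) are kept.
config = XGLMmoeConfig.from_pretrained(save_path)
model = XGLMForCausalLM.from_pretrained(save_path, config=config).eval()
tokenizer = AutoTokenizer.from_pretrained(save_path)  # only present if --tokenizer-name was given

inputs = tokenizer("Hello, my name is", return_tensors="pt")
with torch.no_grad():
    out = model(**inputs, output_router_logits=True)

print(out.logits.shape)        # (batch, seq_len, vocab_size)
print(len(out.router_logits))  # one router-logit tensor per decoder layer
```

For a stricter, numerical comparison against the original nanotron weights, the tests added in `examples/xglm/tests/test_moe.py` compare gate, feed-forward and full-model outputs directly.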