From 7aa3940716084dd29f5c1cf3f68397db2c5c20e8 Mon Sep 17 00:00:00 2001
From: AleHD
Date: Wed, 26 Jun 2024 12:38:55 +0000
Subject: [PATCH 01/51] XGLM work in progress: Causal Attention and Positional
 Embeddings work

---
 examples/xglm/__init__.py                  |   0
 examples/xglm/convert_hf2nt.py             |  28 ++
 examples/xglm/tests/test_attn.py           |  74 +++++
 examples/xglm/tests/test_implementation.py |  90 ++++++
 src/nanotron/config/models_config.py       |  36 +++
 src/nanotron/models/gpt3.py                | 358 +++++++++++++++++++++
 6 files changed, 586 insertions(+)
 create mode 100644 examples/xglm/__init__.py
 create mode 100644 examples/xglm/convert_hf2nt.py
 create mode 100644 examples/xglm/tests/test_attn.py
 create mode 100644 examples/xglm/tests/test_implementation.py
 create mode 100644 src/nanotron/models/gpt3.py

diff --git a/examples/xglm/__init__.py b/examples/xglm/__init__.py
new file mode 100644
index 00000000..e69de29b
diff --git a/examples/xglm/convert_hf2nt.py b/examples/xglm/convert_hf2nt.py
new file mode 100644
index 00000000..e008f859
--- /dev/null
+++ b/examples/xglm/convert_hf2nt.py
@@ -0,0 +1,28 @@
+import torch
+
+from transformers.models.xglm.modeling_xglm import XGLMAttention
+from nanotron.models.gpt3 import CausalSelfAttention
+
+
+def convert_attention(attn_nt: CausalSelfAttention, attn_hf: XGLMAttention):
+    q_ws = torch.chunk(attn_hf.q_proj.weight, attn_hf.num_heads)
+    k_ws = torch.chunk(attn_hf.k_proj.weight, attn_hf.num_heads)
+    v_ws = torch.chunk(attn_hf.v_proj.weight, attn_hf.num_heads)
+
+    q_bs = torch.chunk(attn_hf.q_proj.bias, attn_hf.num_heads)
+    k_bs = torch.chunk(attn_hf.k_proj.bias, attn_hf.num_heads)
+    v_bs = torch.chunk(attn_hf.v_proj.bias, attn_hf.num_heads)
+
+    qkv_w = []
+    qkv_b = []
+    for q_w, k_w, v_w, q_b, k_b, v_b in zip(q_ws, k_ws, v_ws, q_bs, k_bs, v_bs):
+        qkv_w += [q_w, k_w, v_w]
+        qkv_b += [q_b, k_b, v_b]
+    qkv_w = torch.cat(qkv_w)
+    qkv_b = torch.cat(qkv_b)
+
+    with torch.no_grad():
+        attn_nt.query_key_value.weight.data = qkv_w.clone()
+        attn_nt.query_key_value.bias.data = qkv_b.clone()
+        attn_nt.dense.weight.data = attn_hf.out_proj.weight.clone()
+        attn_nt.dense.bias.data = attn_hf.out_proj.bias.clone()
diff --git a/examples/xglm/tests/test_attn.py b/examples/xglm/tests/test_attn.py
new file mode 100644
index 00000000..2fcdb3a8
--- /dev/null
+++ b/examples/xglm/tests/test_attn.py
@@ -0,0 +1,74 @@
+import torch
+from torch.nn import functional as F
+#torch.Size([4, 2048, 16, 64]), torch.Size([2048, 4, 1024])
+
+# inputs = (batchsize * qlen, heads, head_dim)
+# outputs = (batchsize*qlen, heads, head_dim)
+def sdpa(query, key, value, batchsize: int):
+    def reshape(tensor):  # output = (batchsize, heads, qlen, head_dim)
+        return tensor.view(batchsize, qlen, heads, head_dim).permute(0, 2, 1, 3)
+
+    batchsize_x_qlen, heads, head_dim = query.size()
+    qlen = batchsize_x_qlen//batchsize
+    out = F.scaled_dot_product_attention(reshape(query), reshape(key), reshape(value), is_causal=True)  # (b,h,q,d)
+    return out.permute(0, 2, 1, 3).reshape(batchsize*qlen, heads, head_dim)
+
+
+# inputs = (batchsize * qlen, heads, head_dim)
+# outputs = (batchsize*qlen, heads, head_dim)
+def fa(query_states, key_states, value_states, batchsize: int):
+    from flash_attn.flash_attn_interface import flash_attn_varlen_func
+
+    batchsize_x_qlen, heads, head_dim = query_states.size()
+    qlen = batchsize_x_qlen//batchsize
+
+    q_sequence_mask = torch.ones(batchsize, qlen, dtype=torch.bool, device="cuda")
+    kv_sequence_mask = torch.ones(batchsize, qlen, dtype=torch.bool, device="cuda")
+
+    # TODO @thomasw21: Compute once, instead of computing for each layers.
+    cu_seqlens_q = torch.zeros((q_sequence_mask.shape[0] + 1), dtype=torch.int32, device=query_states.device)
+    cu_seqlens_k = torch.zeros((kv_sequence_mask.shape[0] + 1), dtype=torch.int32, device=query_states.device)
+    torch.cumsum(q_sequence_mask.sum(-1, dtype=torch.int32), dim=0, dtype=torch.int32, out=cu_seqlens_q[1:])
+    torch.cumsum(kv_sequence_mask.sum(-1, dtype=torch.int32), dim=0, dtype=torch.int32, out=cu_seqlens_k[1:])
+
+    # TODO(kunhao): flash attn's causal means that the query can only attend to the keys before it. This is not
+    # what we want if we are using kv cache. This is a hack as we always have q_length == 1 when using kv cache.
+    causal = False if q_sequence_mask.shape[1] == 1 else True
+    attn_output = flash_attn_varlen_func(
+        q=query_states,
+        k=key_states,
+        v=value_states,
+        cu_seqlens_q=cu_seqlens_q,
+        cu_seqlens_k=cu_seqlens_k,
+        max_seqlen_q=q_sequence_mask.shape[1],
+        max_seqlen_k=kv_sequence_mask.shape[1],
+        dropout_p=0.0,
+        softmax_scale=None,  # defaults to 1/sqrt(d_qk)
+        causal=causal,
+        window_size=(-1, -1),
+        return_attn_probs=False,
+    )
+    return attn_output
+
+
+def main():
+    batchsize = 5
+    qlen = 6
+    heads = 2
+    head_dim = 16
+
+    query = torch.randn(batchsize*qlen, heads, head_dim, device="cuda", dtype=torch.bfloat16)
+    key = torch.randn(batchsize*qlen, heads, head_dim, device="cuda", dtype=torch.bfloat16)
+    value = torch.randn(batchsize*qlen, heads, head_dim, device="cuda", dtype=torch.bfloat16)
+
+    out_pt = sdpa(query, key, value, batchsize)
+    out_fa = fa(query, key, value, batchsize)
+
+    assert out_pt.size() == out_fa.size()
+
+    torch.testing.assert_close(out_pt, out_fa)
+
+
+
+if __name__ == "__main__":
+    main()
diff --git a/examples/xglm/tests/test_implementation.py b/examples/xglm/tests/test_implementation.py
new file mode 100644
index 00000000..10f0302a
--- /dev/null
+++ b/examples/xglm/tests/test_implementation.py
@@ -0,0 +1,90 @@
+import numpy as np
+import torch
+import pytest
+
+from transformers.models.xglm.modeling_xglm import XGLMAttention, XGLMSinusoidalPositionalEmbedding
+
+from nanotron.config.models_config import GPT3Config
+from nanotron.models.gpt3 import CausalSelfAttention, PositionEmbedding
+from nanotron.parallel import ParallelContext
+
+from tests.helpers.utils import init_distributed
+
+from examples.xglm.convert_hf2nt import convert_attention
+
+
+SEQUENCE_LENGTH = 2048
+BATCH_SIZE = 4
+HIDDEN_SIZE = 1024
+DTYPE = torch.float64
+
+CONFIG = GPT3Config(
+    attn_pdrop=0.0,
+    embd_pdrop=0.0,
+    resid_pdrop=0.0,
+    eos_token_id=2,
+    hidden_size=HIDDEN_SIZE,
+    intermediate_size=4096,
+    layer_norm_epsilon=1e-05,
+    max_position_embeddings=SEQUENCE_LENGTH,
+    num_attention_heads=16,
+    num_hidden_layers=24,
+    scale_attn_weights=True,
+    vocab_size=256008,
+    sinusoidal_position_embedding=True,
+    position_embedding_offset=2,
+    use_spda=True
+)
+
+
+@pytest.fixture
+def hidden_states() -> torch.Tensor:
+    return torch.randn(SEQUENCE_LENGTH, BATCH_SIZE, HIDDEN_SIZE,
+                       dtype=DTYPE)
+
+
+@pytest.fixture
+def input_mask() -> torch.Tensor:
+    return torch.ones(BATCH_SIZE, SEQUENCE_LENGTH, dtype=torch.bool)
+
+
+def _test_attention(parallel_context: ParallelContext, hidden_states: torch.Tensor, sequence_mask: torch.Tensor):
+    hidden_states = hidden_states.cuda()
+    sequence_mask = sequence_mask.cuda()
+
+    attn_nt = CausalSelfAttention(CONFIG, None, parallel_context.tp_pg, 0).cuda().eval().to(DTYPE)
+    attn_hf = XGLMAttention(CONFIG.hidden_size, CONFIG.num_attention_heads, CONFIG.attn_pdrop).cuda().eval().to(DTYPE)
+    assert sum(map(torch.numel, attn_nt.parameters())) == sum(map(torch.numel, attn_hf.parameters()))
+
+    # Build xglm mask.
+    mask = torch.ones(SEQUENCE_LENGTH, SEQUENCE_LENGTH, dtype=torch.bool, device="cuda").tril(diagonal=0)
+    mask = torch.where(mask, 0.0, -np.inf).to(DTYPE)
+    mask = mask.repeat(BATCH_SIZE, 1, 1).unsqueeze(1)
+
+    convert_attention(attn_nt, attn_hf)
+    out_nt = attn_nt(hidden_states, sequence_mask)["hidden_states"]
+    out_hf = attn_hf(hidden_states.permute(1, 0, 2), attention_mask=mask)[0].permute(1, 0, 2)
+    assert out_nt.size() == out_hf.size(), f"{out_nt.size()}, {out_hf.size()}"
+    torch.testing.assert_close(out_nt, out_hf)
+
+
+def test_attention(hidden_states: torch.Tensor, input_mask: torch.Tensor):
+    init_distributed(tp=1, dp=1, pp=1)(_test_attention)(hidden_states=hidden_states, sequence_mask=input_mask)
+
+
+def _test_position_embeddings(parallel_context: ParallelContext):
+    position_ids = torch.arange(SEQUENCE_LENGTH, device="cuda").unsqueeze(0)  # shape = (1, SEQUENCE_LENGTH)
+
+    emb_nt = PositionEmbedding(parallel_context.tp_pg, CONFIG, None).cuda()
+    emb_hf = XGLMSinusoidalPositionalEmbedding(SEQUENCE_LENGTH, HIDDEN_SIZE).cuda()
+
+    assert emb_nt.position_embedding.weight.size() == emb_hf.weights.size()
+    torch.testing.assert_close(emb_nt.position_embedding.weight, emb_hf.weights)
+
+    out_nt = emb_nt(position_ids)["position_embeds"]
+    out_hf = emb_hf(position_ids).permute(1, 0, 2)
+    assert out_nt.size() == out_hf.size(), f"{out_nt.size()}, {out_hf.size()}"
+    torch.testing.assert_close(out_nt, out_hf)
+
+def test_position_embeddings():
+    init_distributed(tp=1, dp=1, pp=1)(_test_position_embeddings)()
diff --git a/src/nanotron/config/models_config.py b/src/nanotron/config/models_config.py
index 57225243..5b8ac999 100644
--- a/src/nanotron/config/models_config.py
+++ b/src/nanotron/config/models_config.py
@@ -133,4 +133,40 @@ def n_inner(self):
         return self.intermediate_size
 
 
+@dataclass
+class GPT3Config:
+    """Configuration for a GPT3 model"""
+
+    activation_function: str = "gelu"
+    attn_pdrop: float = 0.1
+    embd_pdrop: float = 0.1
+    eos_token_id: int = 49152
+    hidden_size: int = 2048
+    intermediate_size: Optional[int] = None
+    layer_norm_epsilon: float = 1e-05
+    max_position_embeddings: int = 4096
+    num_attention_heads: int = 16
+    num_hidden_layers: int = 24
+    resid_pdrop: float = 0.1
+    scale_attention_softmax_in_fp32: bool = True
+    scale_attn_weights: bool = True
+    vocab_size: int = 49280
+    sinusoidal_position_embedding: bool = True
+    position_embedding_offset: int = 2
+    use_spda: bool = False
+
+    def as_starcoder2(self) -> Starcoder2Config:
+        config = dict(**vars(self))
+        del config["sinusoidal_position_embedding"]
+        del config["use_spda"]
+        del config["position_embedding_offset"]
+        return Starcoder2Config(
+            grouped_query=True,
+            num_kv_heads=self.num_attention_heads,
+            use_rotary_embeddings=False,
+            **config
+        )
+
+
 NanotronConfigs = Union[LlamaConfig, Starcoder2Config, Any]
 
diff --git a/src/nanotron/models/gpt3.py b/src/nanotron/models/gpt3.py
new file mode 100644
index 00000000..8cea58c4
--- /dev/null
+++ b/src/nanotron/models/gpt3.py
@@ -0,0 +1,358 @@
+"""PyTorch GPT-3 model."""
+
+import math
+from typing import Optional
+
+import torch
+from torch import nn
+from torch.nn import functional as F
+
+from nanotron import distributed as dist
+from nanotron.parallel import ParallelContext
+from nanotron.config import Config, GPT3Config, ParallelismArgs
+from nanotron.generation.generate_store import AttachableStore
+from nanotron.models.starcoder2 import MLP as
Starcoder2MLP +from nanotron.models.starcoder2 import CoreAttention as Starcoder2CoreAttention +from nanotron.models.starcoder2 import CausalSelfGQA +from nanotron.random import RandomStates, branch_random_state +from nanotron.parallel.pipeline_parallel.tensor_pointer import TensorPointer +from nanotron.parallel.tensor_parallel.enum import TensorParallelLinearMode +from nanotron.parallel.tensor_parallel.nn import TensorParallelEmbedding +from nanotron.parallel.tied_parameters import tie_parameters + +# NOTES: +# - tie head_weight with embeddings I think. + +# TODO: +# - class GPT3Config: config lol +# - check that attention (i.e. nanotron.attn vs xglm.self_attn) is the same. +# - from starcoder import Embedding +# - class PositionEmbedding: my sinusoidal embedding extends from TensorParallelEmbedding +# - class GPTBLock: very similar to starcoder2 but make it so it support non-GQA or MQA +# - from starcoder import Loss + + +class CoreAttention(Starcoder2CoreAttention): + def __init__(self, config: GPT3Config, parallel_config: Optional[ParallelismArgs], layer_idx: int): + super().__init__(config.as_starcoder2(), parallel_config, layer_idx) + self.gpt3config = config + + def forward(self, + query_states: torch.Tensor, # [batch_size * q_length, q_heads, inner_dim] + key_states: torch.Tensor, # [batch_size * kv_length, kv_heads, inner_dim] + value_states: torch.Tensor, # [batch_size * kv_length, kv_heads, inner_dim] + q_sequence_mask: torch.Tensor, # torch.BoolTensor [batch_size, q_length] (can be broadcasted to that size) + kv_sequence_mask: torch.Tensor, # torch.BoolTensor [batch_size, kv_length] (can be broadcasted to that size) + ): + + if self.gpt3config.use_spda: + assert torch.all(q_sequence_mask) + assert torch.all(kv_sequence_mask) + + batch_size, q_length = q_sequence_mask.size() + kv_length = kv_sequence_mask.size(1) + _, q_heads, head_dim = query_states.size() + kv_heads = key_states.size(1) + + attention_output = F.scaled_dot_product_attention( + query_states.view(batch_size, q_length, q_heads, head_dim).permute(0, 2, 1, 3), + key_states.view(batch_size, kv_length, kv_heads, head_dim).permute(0, 2, 1, 3), + value_states.view(batch_size, kv_length, kv_heads, head_dim).permute(0, 2, 1, 3), + dropout_p=self.dropout if self.training else 0.0, + is_causal=True, + ) # [batch, q_length, q_heads, head_dim] + attention_output = attention_output.permute(0, 2, 1, 3) + attention_output = attention_output.reshape(batch_size*q_length, q_heads, head_dim) + return attention_output + + assert query_states.dtype in {torch.bfloat16, torch.float16} + return super().forward(query_states, key_states, value_states, q_sequence_mask, kv_sequence_mask) + + +class CausalSelfAttention(CausalSelfGQA): + def __init__( + self, + config: GPT3Config, + parallel_config: Optional[ParallelismArgs], + tp_pg: dist.ProcessGroup, + layer_idx: int, + ): + super().__init__(config.as_starcoder2(), parallel_config, tp_pg, layer_idx) + self.maybe_rotary = lambda q, k, **_: (q, k) # Overwrite possible rotary with identity. + self.attention = CoreAttention(config, parallel_config=parallel_config, layer_idx=layer_idx) # Use our custom CoreAttention. + + +class MLP(Starcoder2MLP): + def __init__( + self, + config: GPT3Config, + parallel_config: Optional[ParallelismArgs], + tp_pg: dist.ProcessGroup, + ): + # TODO: GPT3Config -> Starcoder2Config. 
+ super().__init__(config, parallel_config, tp_pg) + self.dropout = nn.Dropout(p=config.dropout) # TODO: correct config.dropout name + + def forward(self, hidden_states): # [seq_length, batch_size, hidden_dim] + hidden_states = self.c_fc(hidden_states) + hidden_states = self.act(hidden_states) + with branch_random_state( + self.random_states, "tp_synced", enabled=self.tp_mode is TensorParallelLinearMode.ALL_REDUCE + ): + hidden_states = self.dropout(input=hidden_states) + hidden_states = self.c_proj(hidden_states) + return {"hidden_states": hidden_states} + + +class GPTBlock(nn.Module): + def __init__( + self, + config: GPT3Config, + parallel_config: Optional[ParallelismArgs], + tp_pg: dist.ProcessGroup, + random_states: RandomStates, + layer_idx: int, + ): + super(GPTBlock, self).__init__() + self.ln_1 = TritonLayerNorm(config.hidden_size, eps=config.layer_norm_epsilon) + self.attn = CausalSelfAttention( + config=config, + parallel_config=parallel_config, + tp_pg=tp_pg, + layer_idx=layer_idx + ) + self.attn_dropout = config.attn_pdrop + + self.ln_2 = TritonLayerNorm(config.hidden_size, eps=config.layer_norm_epsilon) + self.ff = MLP(config=config, parallel_config=parallel_config, tp_pg=tp_pg) + self.ff_dropout = config.resid_pdrop + + self.random_states = random_states + self.tp_mode = parallel_config.tp_mode if parallel_config is not None else TensorParallelLinearMode.ALL_REDUCE + + def forward( + self, + hidden_states: torch.Tensor | TensorPointer, + sequence_mask: torch.Tensor | TensorPointer, + ) -> dict[str, torch.Tensor | TensorPointer]: + + residual = hidden_states + hidden_states = self.ln_1(hidden_states) + output = self.attn(hidden_states=hidden_states, sequence_mask=sequence_mask) + hidden_states = output["hidden_states"] + + if self.training: + with branch_random_state( + self.random_states, "tp_synced", enabled=self.tp_mode is TensorParallelLinearMode.ALL_REDUCE + ): + hidden_states = dropout_add_fused_train(hidden_states, residual=residual, prob=self.attn_dropout) + else: + # No need for random state context manager + # TODO: add dropout scaling? + hidden_states = hidden_states + residual + + residual = hidden_states + hidden_states = self.ln_2(hidden_states) + hidden_states = self.ff(hidden_states=hidden_states)["hidden_states"] + + if self.training: + with branch_random_state( + self.random_states, "tp_synced", enabled=self.tp_mode is TensorParallelLinearMode.ALL_REDUCE + ): + hidden_states = dropout_add_fused_train(hidden_states, residual=residual, prob=self.ff_dropout) + else: + # No need for random state context manager + # TODO: add dropout scaling? 
+ hidden_states = hidden_states + residual + + return { + "hidden_states": hidden_states, + "sequence_mask": output["sequence_mask"], + } + + +class PositionEmbedding(nn.Module, AttachableStore): + def __init__(self, tp_pg: dist.ProcessGroup, config: GPT3Config, parallel_config: Optional[ParallelismArgs]): + super().__init__() + + self.config = config + if (config.max_position_embeddings + config.position_embedding_offset) % tp_pg.size() == 0: + dummy_pos = 0 + else: + dummy_pos = tp_pg.size() - ((config.max_position_embeddings + config.position_embedding_offset) % k) + true_max_size = config.max_position_embeddings + config.position_embedding_offset + dummy_pos + + if config.sinusoidal_position_embedding: + weight = self._make_weights(tp_pg, true_max_size, config.hidden_size) + else: + weight = None + + position_embedding = TensorParallelEmbedding( + num_embeddings=true_max_size, + embedding_dim=config.hidden_size, + pg=tp_pg, + mode=parallel_config.tp_mode if parallel_config is not None else TensorParallelLinearMode.ALL_REDUCE, + _weight=weight + ) + self.pg = tp_pg + + # Sinusoidal position embeddings are usually not trainable. + # We adjust that by setting the module self.position_embedding without gradient. + if config.sinusoidal_position_embedding: + with torch.no_grad(): + self.position_embedding = position_embedding.requires_grad_(False) + else: + self.position_embedding = position_embedding + + def forward(self, position_ids: torch.Tensor): # [batch_size, seq_length] + position_ids = position_ids.transpose(0, 1) + position_embeds = self.position_embedding(position_ids + self.config.position_embedding_offset) + return {"position_embeds": position_embeds} + + def _make_weights(self, tp_pg: dist.ProcessGroup, num_embeddings: int, + embedding_dim: int) -> torch.Tensor: + rank = dist.get_rank(group=tp_pg) + tp_size = tp_pg.size() + + assert 0 <= rank < tp_size + assert num_embeddings % tp_size == 0 + assert embedding_dim % 2 == 0 + block_size = num_embeddings//tp_size + + half_dim = embedding_dim//2 + emb = math.log(10_000)/(half_dim - 1) + emb = torch.exp(torch.arange(half_dim, dtype=torch.int64).float() * -emb) + emb = (rank*block_size + torch.arange(num_embeddings, dtype=torch.int64).float().unsqueeze(1)) * emb.unsqueeze(0) + emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1).view(block_size, embedding_dim) + return emb + + +class GPT3Model(nn.Module): + def __init__( + self, + config: GPT3Config, + parallel_context: ParallelContext, + parallel_config: Optional[ParallelismArgs], + random_states: RandomStates, + ): + super().__init__() + + # Declare all the nodes + self.p2p = P2P(parallel_context.pp_pg, device=torch.device("cuda")) + self.random_states = random_states + self.tp_mode = parallel_config.tp_mode if parallel_config is not None else TensorParallelLinearMode.ALL_REDUCE + + self.token_embeddings = PipelineBlock( + p2p=self.p2p, + module_builder=Embedding, + module_kwargs={ + "tp_pg": parallel_context.tp_pg, + "config": config, + "parallel_config": parallel_config, + }, + module_input_keys={"input_ids"}, + module_output_keys={"input_embeds"}, + ) + self.position_embeddings = PipelineBlock( + p2p=self.p2p, + module_builder=PositionEmbedding, + module_kwargs={ + "tp_pg": parallel_context.tp_pg, + "config": config, + "parallel_config": parallel_config, + }, + module_input_keys={"position_ids"}, + module_output_keys={"position_embeds"}, + ) + + self.embeds_dropout = PipelineBlock( + p2p=self.p2p, + module_builder=nn.Dropout, + module_kwargs={"p": config.embd_pdrop}, + 
module_input_keys={"input"}, + module_output_keys={"hidden_states"}, + ) + + self.decoder = nn.ModuleList( + [ + PipelineBlock( + p2p=self.p2p, + module_builder=GPTBlock, + module_kwargs={ + "config": config, + "parallel_config": parallel_config, + "tp_pg": parallel_context.tp_pg, + "random_states": random_states, + "layer_idx": layer_idx, + }, + module_input_keys={"hidden_states", "sequence_mask"}, + module_output_keys={"hidden_states", "sequence_mask"}, + ) + for layer_idx in range(config.num_hidden_layers) + ] + ) + + self.final_layer_norm = PipelineBlock( + p2p=self.p2p, + module_builder=TritonLayerNorm, + module_kwargs={"normalized_shape": config.hidden_size, "eps": config.layer_norm_epsilon}, + module_input_keys={"input"}, + module_output_keys={"hidden_states"}, + ) + + self.lm_head = PipelineBlock( + p2p=self.p2p, + # Understand that this means that we return sharded logits that are going to need to be gathered + module_builder=TensorParallelColumnLinear, + module_kwargs={ + "in_features": config.hidden_size, + "out_features": config.vocab_size, + "pg": parallel_context.tp_pg, + "bias": False, + # TODO: refactor so that we store that default in a single place. + "mode": self.tp_mode, + "async_communication": parallel_config.tp_linear_async_communication + if parallel_config is not None + else False, + }, + module_input_keys={"x"}, + module_output_keys={"logits"}, + ) + + self.cast_to_fp32 = PipelineBlock( + p2p=self.p2p, + module_builder=lambda: lambda x: x.float(), + module_kwargs={}, + module_input_keys={"x"}, + module_output_keys={"output"}, + ) + + + def forward( + self, + input_ids: torch.Tensor | TensorPointer, # [batch_size, seq_length] + input_mask: torch.Tensor | TensorPointer, # [batch_size, seq_length] + ): + # all tensors are optional as most ranks don't need anything from the dataloader. 
+ + position_ids = torch.arange(input_ids.size(1), device="cuda").repeat(input_ids.size(0)).view(*input_ids.size()) + input_embeds = self.token_embeddings(input_ids=input_ids)["input_embeds"] + position_embeds = self.position_embeds(position_ids=position_ids)["position_embeds"] + hidden_states = input_embeds + position_embeds + + with branch_random_state( + self.random_states, "tp_synced", enabled=self.tp_mode == TensorParallelLinearMode.ALL_REDUCE + ): + hidden_states = self.embeds_dropout(input=hidden_states)["hidden_states"] + + hidden_encoder_states = {"hidden_states": hidden_states, "sequence_mask": input_mask} + for encoder_block in self.decoder: + hidden_encoder_states = encoder_block(**hidden_encoder_states) + + hidden_states = self.final_layer_norm(input=hidden_encoder_states["hidden_states"])["hidden_states"] + + sharded_logits = self.lm_head(x=hidden_states)["logits"] + + fp32_sharded_logits = self.cast_to_fp32(x=sharded_logits)["output"] + + return fp32_sharded_logits From 78dd53cdfdb467961edd1a56b04d8426fd2819df Mon Sep 17 00:00:00 2001 From: AleHD Date: Wed, 26 Jun 2024 17:24:53 +0000 Subject: [PATCH 02/51] WIP: GPT arch almost done, hf->nt converters working perfectly for non-distributed inference --- examples/xglm/convert_hf2nt.py | 70 +++++++- examples/xglm/tests/test_attn.py | 74 --------- examples/xglm/tests/test_implementation.py | 135 +++++++++++++-- src/nanotron/config/models_config.py | 4 + src/nanotron/models/gpt3.py | 184 ++++++++++----------- 5 files changed, 287 insertions(+), 180 deletions(-) delete mode 100644 examples/xglm/tests/test_attn.py diff --git a/examples/xglm/convert_hf2nt.py b/examples/xglm/convert_hf2nt.py index e008f859..6e6ddff1 100644 --- a/examples/xglm/convert_hf2nt.py +++ b/examples/xglm/convert_hf2nt.py @@ -1,7 +1,44 @@ import torch +from torch import nn -from transformers.models.xglm.modeling_xglm import XGLMAttention -from nanotron.models.gpt3 import CausalSelfAttention +from transformers.models.xglm.modeling_xglm import XGLMAttention, XGLMConfig, XGLMDecoderLayer, XGLMForCausalLM +from nanotron.models.gpt3 import CausalSelfAttention, GPTBlock, MLP, GPT3ForTraining +from nanotron.config.models_config import GPT3Config + + +def convert_config(config: XGLMConfig) -> GPT3Config: + # TODOs: + # dropout=0.1, + # layerdrop=0.0, + # init_std=0.02, + # use_cache=True, + # decoder_start_token_id=2, + # pad_token_id=1, + # bos_token_id=0, + + # TODO: when going gpt3->xglm: + # - assert layernorm is 1e-05 + return GPT3Config( + activation_function=config.activation_function, + attn_pdrop=config.attention_dropout, + embd_pdrop=0.0, # TODO + eos_token_id=config.eos_token_id, + hidden_size=config.d_model, + intermediate_size=config.ffn_dim, + layer_norm_epsilon=1e-05, + max_position_embeddings=config.max_position_embeddings, + num_attention_heads=config.attention_heads, + num_hidden_layers=config.num_layers, + resid_pdrop=0.0, # TODO + scale_attention_softmax_in_fp32=True, + scale_attn_weights=True, + vocab_size=config.vocab_size, + sinusoidal_position_embedding=True, + position_embedding_offset=2, + use_spda=False, + act_pdrop=config.activation_dropout, + scale_embedding=config.scale_embedding, + ) def convert_attention(attn_nt: CausalSelfAttention, attn_hf: XGLMAttention): @@ -26,3 +63,32 @@ def convert_attention(attn_nt: CausalSelfAttention, attn_hf: XGLMAttention): attn_nt.query_key_value.bias.data = qkv_b.clone() attn_nt.dense.weight.data = attn_hf.out_proj.weight.clone() attn_nt.dense.bias.data = attn_hf.out_proj.bias.clone() + + +def 
convert_generic(module1: nn.Module, module2: nn.Module): + names1 = {name for name, _ in module1.named_parameters()} + names2 = {name for name, _ in module2.named_parameters()} + assert names1 == names2, f"{names1} != {names2}" + params2 = dict(module2.named_parameters()) + for name, param in module1.named_parameters(): + param.data = params2[name].clone() + + +def convert_mlp(mlp_nt: MLP, block_hf: XGLMDecoderLayer): + convert_generic(mlp_nt.c_fc, block_hf.fc1) + convert_generic(mlp_nt.c_proj, block_hf.fc2) + + +def convert_decoder(block_nt: GPTBlock, block_hf: XGLMDecoderLayer): + convert_generic(block_nt.ln_1, block_hf.self_attn_layer_norm) + convert_attention(block_nt.attn, block_hf.self_attn) + convert_generic(block_nt.ln_2, block_hf.final_layer_norm) + convert_mlp(block_nt.ff, block_hf) + + +def convert(model_nt: GPT3ForTraining, model_hf: XGLMForCausalLM): + convert_generic(model_nt.model.token_embeddings.pp_block.token_embedding, model_hf.model.embed_tokens) + for layer_nt, layer_hf in zip(model_nt.model.decoder, model_hf.model.layers): + convert_decoder(layer_nt.pp_block, layer_hf) + convert_generic(model_nt.model.final_layer_norm.pp_block, model_hf.model.layer_norm) + convert_generic(model_nt.model.lm_head.pp_block, model_hf.lm_head) diff --git a/examples/xglm/tests/test_attn.py b/examples/xglm/tests/test_attn.py deleted file mode 100644 index 2fcdb3a8..00000000 --- a/examples/xglm/tests/test_attn.py +++ /dev/null @@ -1,74 +0,0 @@ -import torch -from torch.nn import functional as F -#torch.Size([4, 2048, 16, 64]), torch.Size([2048, 4, 1024]) - -# inputs = (batchsize * qlen, heads, head_dim) -# outputs = (batchsize*qlen, heads, head_dim) -def sdpa(query, key, value, batchsize: int): - def reshape(tensor): # output = (batchsize, heads, qlen, head_dim) - return tensor.view(batchsize, qlen, heads, head_dim).permute(0, 2, 1, 3) - - batchsize_x_qlen, heads, head_dim = query.size() - qlen = batchsize_x_qlen//batchsize - out = F.scaled_dot_product_attention(reshape(query), reshape(key), reshape(value), is_causal=True) # (b,h,q,d) - return out.permute(0, 2, 1, 3).reshape(batchsize*qlen, heads, head_dim) - - -# inputs = (batchsize * qlen, heads, head_dim) -# outputs = (batchsize*qlen, heads, head_dim) -def fa(query_states, key_states, value_states, batchsize: int): - from flash_attn.flash_attn_interface import flash_attn_varlen_func - - batchsize_x_qlen, heads, head_dim = query_states.size() - qlen = batchsize_x_qlen//batchsize - - q_sequence_mask = torch.ones(batchsize, qlen, dtype=torch.bool, device="cuda") - kv_sequence_mask = torch.ones(batchsize, qlen, dtype=torch.bool, device="cuda") - - # TODO @thomasw21: Compute once, instead of computing for each layers. - cu_seqlens_q = torch.zeros((q_sequence_mask.shape[0] + 1), dtype=torch.int32, device=query_states.device) - cu_seqlens_k = torch.zeros((kv_sequence_mask.shape[0] + 1), dtype=torch.int32, device=query_states.device) - torch.cumsum(q_sequence_mask.sum(-1, dtype=torch.int32), dim=0, dtype=torch.int32, out=cu_seqlens_q[1:]) - torch.cumsum(kv_sequence_mask.sum(-1, dtype=torch.int32), dim=0, dtype=torch.int32, out=cu_seqlens_k[1:]) - - # TODO(kunhao): flash attn's causal means that the query can only attend to the keys before it. This is not - # what we want if we are using kv cache. This is a hack as we always have q_length == 1 when using kv cache. 
- causal = False if q_sequence_mask.shape[1] == 1 else True - attn_output = flash_attn_varlen_func( - q=query_states, - k=key_states, - v=value_states, - cu_seqlens_q=cu_seqlens_q, - cu_seqlens_k=cu_seqlens_k, - max_seqlen_q=q_sequence_mask.shape[1], - max_seqlen_k=kv_sequence_mask.shape[1], - dropout_p=0.0, - softmax_scale=None, # defaults to 1/sqrt(d_qk) - causal=causal, - window_size=(-1, -1), - return_attn_probs=False, - ) - return attn_output - - -def main(): - batchsize = 5 - qlen = 6 - heads = 2 - head_dim = 16 - - query = torch.randn(batchsize*qlen, heads, head_dim, device="cuda", dtype=torch.bfloat16) - key = torch.randn(batchsize*qlen, heads, head_dim, device="cuda", dtype=torch.bfloat16) - value = torch.randn(batchsize*qlen, heads, head_dim, device="cuda", dtype=torch.bfloat16) - - out_pt = sdpa(query, key, value, batchsize) - out_fa = fa(query, key, value, batchsize) - - assert out_pt.size() == out_fa.size() - - torch.testing.assert_close(out_pt, out_fa) - - - -if __name__ == "__main__": - main() diff --git a/examples/xglm/tests/test_implementation.py b/examples/xglm/tests/test_implementation.py index 10f0302a..3636415b 100644 --- a/examples/xglm/tests/test_implementation.py +++ b/examples/xglm/tests/test_implementation.py @@ -1,27 +1,33 @@ +from typing import Optional + import numpy as np import torch import pytest -from transformers.models.xglm.modeling_xglm import XGLMAttention, XGLMSinusoidalPositionalEmbedding +from transformers import XGLMTokenizer +from transformers.models.xglm.modeling_xglm import XGLMConfig, XGLMAttention, XGLMSinusoidalPositionalEmbedding, XGLMDecoderLayer, XGLMForCausalLM +import nanotron from nanotron.config.models_config import GPT3Config -from nanotron.models.gpt3 import CausalSelfAttention, PositionEmbedding +from nanotron.models.gpt3 import GPT3ForTraining, CausalSelfAttention, PositionEmbedding, GPTBlock from nanotron.parallel import ParallelContext from tests.helpers.utils import init_distributed -from examples.xglm.convert_hf2nt import convert_attention +from examples.xglm.convert_hf2nt import convert_attention, convert_config, convert_decoder, convert SEQUENCE_LENGTH = 2048 BATCH_SIZE = 4 HIDDEN_SIZE = 1024 -DTYPE = torch.float64 +DTYPE = torch.bfloat16 +TEXT = "Hello. This is a relatively long text. I will use this text to test the conversion scripts. Let's finish this text soon because I don't have much more to say. Final note:" CONFIG = GPT3Config( attn_pdrop=0.0, embd_pdrop=0.0, resid_pdrop=0.0, + act_pdrop=0.0, eos_token_id=2, hidden_size=HIDDEN_SIZE, intermediate_size=4096, @@ -42,11 +48,22 @@ def hidden_states() -> torch.Tensor: return torch.randn(SEQUENCE_LENGTH, BATCH_SIZE, HIDDEN_SIZE, dtype=DTYPE) - @pytest.fixture def input_mask() -> torch.Tensor: return torch.ones(BATCH_SIZE, SEQUENCE_LENGTH, dtype=torch.bool) +@pytest.fixture +def input_ids() -> torch.Tensor: + return torch.randint(0, CONFIG.vocab_size, (BATCH_SIZE, SEQUENCE_LENGTH)) + + +def attention_mask() -> torch.Tensor: + # XGLM causal attention mask. 
+ mask = torch.ones(SEQUENCE_LENGTH, SEQUENCE_LENGTH, dtype=torch.bool, device="cuda").tril(diagonal=0) + mask = torch.where(mask, 0.0, -np.inf).to(DTYPE) + mask = mask.repeat(BATCH_SIZE, 1, 1).unsqueeze(1) + return mask + def _test_attention(parallel_context: ParallelContext, hidden_states: torch.Tensor, sequence_mask: torch.Tensor): hidden_states = hidden_states.cuda() @@ -56,14 +73,9 @@ def _test_attention(parallel_context: ParallelContext, hidden_states: torch.Tens attn_hf = XGLMAttention(CONFIG.hidden_size, CONFIG.num_attention_heads, CONFIG.attn_pdrop).cuda().eval().to(DTYPE) assert sum(map(torch.numel, attn_nt.parameters())) == sum(map(torch.numel, attn_hf.parameters())) - # Build xglm mask. - mask = torch.ones(SEQUENCE_LENGTH, SEQUENCE_LENGTH, dtype=torch.bool, device="cuda").tril(diagonal=0) - mask = torch.where(mask, 0.0, -np.inf).to(DTYPE) - mask = mask.repeat(BATCH_SIZE, 1, 1).unsqueeze(1) - convert_attention(attn_nt, attn_hf) out_nt = attn_nt(hidden_states, sequence_mask)["hidden_states"] - out_hf = attn_hf(hidden_states.permute(1, 0, 2), attention_mask=mask)[0].permute(1, 0, 2) + out_hf = attn_hf(hidden_states.permute(1, 0, 2), attention_mask=attention_mask())[0].permute(1, 0, 2) assert out_nt.size() == out_hf.size(), f"{out_nt.size()}, {out_hf.size()}" torch.testing.assert_close(out_nt, out_hf) @@ -88,3 +100,104 @@ def _test_position_embeddings(parallel_context: ParallelContext): def test_position_embeddings(): init_distributed(tp=1, dp=1, pp=1)(_test_position_embeddings)() + + +def _test_decoder(parallel_context: ParallelContext, hidden_states: torch.Tensor, sequence_mask: torch.Tensor): + random_states = nanotron.random.RandomStates({"tp_synced": nanotron.random.get_current_random_state()}) + hidden_states = hidden_states.cuda() + sequence_mask = sequence_mask.cuda() + + config_hf = XGLMConfig() + decoder_hf = XGLMDecoderLayer(config_hf).cuda().to(DTYPE).eval() + config_nt = convert_config(config_hf) + if DTYPE not in {torch.bfloat16, torch.float16}: + config_nt.use_spda = True + decoder_nt = GPTBlock(config_nt, None, parallel_context.tp_pg, random_states, 0).cuda().to(DTYPE).eval() + + convert_decoder(decoder_nt, decoder_hf) + + out_nt = decoder_nt(hidden_states, sequence_mask)["hidden_states"] + out_hf = decoder_hf(hidden_states.permute(1, 0, 2), attention_mask=attention_mask())[0].permute(1, 0, 2) + + assert out_nt.size() == out_hf.size(), f"{out_nt.size()}, {out_hf.size()}" + torch.testing.assert_close(out_nt, out_hf) + + +def test_decoder(hidden_states: torch.Tensor, input_mask: torch.Tensor): + init_distributed(tp=1, dp=1, pp=1)(_test_decoder)(hidden_states=hidden_states, sequence_mask=input_mask) + + +def _test_model(model_hf: Optional[XGLMForCausalLM], parallel_context: ParallelContext, + input_ids: torch.Tensor, input_mask: torch.Tensor): + random_states = nanotron.random.RandomStates({"tp_synced": nanotron.random.get_current_random_state()}) + input_ids = input_ids.cuda() + input_mask = input_mask.cuda() + + # Get hf model. + if model_hf is None: + config_hf = XGLMConfig() + model_hf = XGLMForCausalLM(config_hf).cuda().to(DTYPE).eval() + else: + model_hf = model_hf.cuda().to(DTYPE).eval() + config_hf = model_hf.config + + # Get nanotron model and make the conversion. 
+ config_nt = convert_config(config_hf) + if DTYPE not in {torch.bfloat16, torch.float16}: + config_nt.use_spda = True + model_nt = nanotron.models.build_model( + model_builder=lambda: GPT3ForTraining( + config=config_nt, + parallel_context=parallel_context, + parallel_config=None, + random_states=random_states, + ), + parallel_context=parallel_context, + dtype=DTYPE, + device="cuda", + ).eval() + convert(model_nt, model_hf) + + print("Parameter count (M):", sum(map(torch.numel, model_hf.parameters()))/1000/1000) + + # Get outputs and assert. + with torch.no_grad(): + out_nt = model_nt.model(input_ids, input_mask).to(DTYPE) + del model_nt + torch.cuda.empty_cache() + out_hf = model_hf(input_ids=input_ids, attention_mask=input_mask).logits.permute(1, 0, 2) + del model_hf + torch.cuda.empty_cache() + assert out_nt.size() == out_hf.size(), f"{out_nt.size()}, {out_hf.size()}" + torch.testing.assert_close(out_nt.cpu(), out_hf.cpu()) + +def _test_dummy_xglm(parallel_context: ParallelContext, input_ids: torch.Tensor, input_mask: torch.Tensor): + _test_model(None, parallel_context, input_ids, input_mask) + + +def test_dummy_xglm(input_ids: torch.Tensor, input_mask: torch.Tensor): + init_distributed(tp=1, dp=1, pp=1)(_test_dummy_xglm)(input_ids=input_ids, input_mask=input_mask) + + +def _test_xglm7B(parallel_context: ParallelContext): + tok = XGLMTokenizer.from_pretrained("facebook/xglm-7.5B") + tokenized = tok(TEXT) + model_hf = XGLMForCausalLM.from_pretrained("facebook/xglm-7.5B") + _test_model(model_hf, parallel_context, + torch.tensor([tokenized["input_ids"]]), torch.tensor([tokenized["attention_mask"]])) + + +def test_xglm7B(): + init_distributed(tp=1, dp=1, pp=1)(_test_xglm7B)() + + +def _test_xglm500M(parallel_context: ParallelContext): + tok = XGLMTokenizer.from_pretrained("facebook/xglm-564M") + tokenized = tok(TEXT) + model_hf = XGLMForCausalLM.from_pretrained("facebook/xglm-564M") + _test_model(model_hf, parallel_context, + torch.tensor([tokenized["input_ids"]]), torch.tensor([tokenized["attention_mask"]])) + + +def test_xglm500M(): + init_distributed(tp=1, dp=1, pp=1)(_test_xglm500M)() diff --git a/src/nanotron/config/models_config.py b/src/nanotron/config/models_config.py index 5b8ac999..12bac0fb 100644 --- a/src/nanotron/config/models_config.py +++ b/src/nanotron/config/models_config.py @@ -154,12 +154,16 @@ class GPT3Config: sinusoidal_position_embedding: bool = True position_embedding_offset: int = 2 use_spda: bool = False + act_pdrop: float = 0.0 + scale_embedding: bool = True def as_starcoder2(self) -> Starcoder2Config: config = dict(**vars(self)) del config["sinusoidal_position_embedding"] del config["use_spda"] del config["position_embedding_offset"] + del config["act_pdrop"] + del config["scale_embedding"] return Starcoder2Config( grouped_query=True, num_kv_heads=self.num_attention_heads, diff --git a/src/nanotron/models/gpt3.py b/src/nanotron/models/gpt3.py index 8cea58c4..99f6ea85 100644 --- a/src/nanotron/models/gpt3.py +++ b/src/nanotron/models/gpt3.py @@ -2,6 +2,7 @@ import math from typing import Optional +from contextlib import contextmanager import torch from torch import nn @@ -9,11 +10,15 @@ from nanotron import distributed as dist from nanotron.parallel import ParallelContext -from nanotron.config import Config, GPT3Config, ParallelismArgs +from nanotron.config import Config, GPT3Config, ParallelismArgs, Starcoder2Config from nanotron.generation.generate_store import AttachableStore +from nanotron.models import starcoder2 +from nanotron.nn.layer_norm import 
TritonLayerNorm from nanotron.models.starcoder2 import MLP as Starcoder2MLP +from nanotron.parallel.pipeline_parallel.block import PipelineBlock from nanotron.models.starcoder2 import CoreAttention as Starcoder2CoreAttention -from nanotron.models.starcoder2 import CausalSelfGQA +from nanotron.models.starcoder2 import GPTBlock as Starcoder2GPTBlock +from nanotron.models.starcoder2 import CausalSelfGQA, Starcoder2ForTraining, GPTModel from nanotron.random import RandomStates, branch_random_state from nanotron.parallel.pipeline_parallel.tensor_pointer import TensorPointer from nanotron.parallel.tensor_parallel.enum import TensorParallelLinearMode @@ -28,10 +33,55 @@ # - check that attention (i.e. nanotron.attn vs xglm.self_attn) is the same. # - from starcoder import Embedding # - class PositionEmbedding: my sinusoidal embedding extends from TensorParallelEmbedding -# - class GPTBLock: very similar to starcoder2 but make it so it support non-GQA or MQA +# - class GPTBlock: very similar to starcoder2 but make it so it support non-GQA or MQA # - from starcoder import Loss +@contextmanager +def replace_coreattention(gpt3config: GPT3Config): + orig = starcoder2.CoreAttention + try: + def create_core_attention(config: Starcoder2Config, parallel_config: Optional[ParallelismArgs], layer_idx: int): + return CoreAttention(gpt3config, parallel_config, layer_idx) + starcoder2.CoreAttention = create_core_attention + yield + finally: + starcoder2.CoreAttention = orig + + +@contextmanager +def replace_decoder(gpt3config: GPT3Config): + orig = starcoder2.PipelineBlock + try: + def create_pp_block(module_builder, module_kwargs, **kwargs): + if module_builder is Starcoder2GPTBlock: + # Starcoder2's GPT module is trying to instantiate a Starcoder2 GPTBlock. + # Let's return a PipelineBlock with a GPT3Block instead. + # This also requires to replace starcoders2's config with gpt3's config. + module_kwargs["config"] = gpt3config + return orig(module_builder=GPTBlock, module_kwargs=module_kwargs, **kwargs) + # Else, they are setting up other modules, which we also want unchanged. 
+ return orig(module_builder=module_builder, module_kwargs=module_kwargs, **kwargs) + + starcoder2.PipelineBlock = create_pp_block + yield + finally: + starcoder2.PipelineBlock = orig + + +@contextmanager +def replace_gpt3model(gpt3config: GPT3Config): + orig = starcoder2.GPTModel + try: + def create_gptmodel(config: Starcoder2Config, parallel_context: ParallelContext, + parallel_config: Optional[ParallelismArgs], random_states: RandomStates): + return GPT3Model(gpt3config, parallel_context, parallel_config, random_states) + starcoder2.GPTModel = create_gptmodel + yield + finally: + starcoder2.GPTModel = orig + + class CoreAttention(Starcoder2CoreAttention): def __init__(self, config: GPT3Config, parallel_config: Optional[ParallelismArgs], layer_idx: int): super().__init__(config.as_starcoder2(), parallel_config, layer_idx) @@ -63,7 +113,7 @@ def forward(self, ) # [batch, q_length, q_heads, head_dim] attention_output = attention_output.permute(0, 2, 1, 3) attention_output = attention_output.reshape(batch_size*q_length, q_heads, head_dim) - return attention_output + return attention_output.contiguous() assert query_states.dtype in {torch.bfloat16, torch.float16} return super().forward(query_states, key_states, value_states, q_sequence_mask, kv_sequence_mask) @@ -77,9 +127,10 @@ def __init__( tp_pg: dist.ProcessGroup, layer_idx: int, ): - super().__init__(config.as_starcoder2(), parallel_config, tp_pg, layer_idx) + with replace_coreattention(config): + super().__init__(config.as_starcoder2(), parallel_config, tp_pg, layer_idx) self.maybe_rotary = lambda q, k, **_: (q, k) # Overwrite possible rotary with identity. - self.attention = CoreAttention(config, parallel_config=parallel_config, layer_idx=layer_idx) # Use our custom CoreAttention. + #self.attention = CoreAttention(config, parallel_config=parallel_config, layer_idx=layer_idx) # Use our custom CoreAttention. class MLP(Starcoder2MLP): @@ -88,10 +139,12 @@ def __init__( config: GPT3Config, parallel_config: Optional[ParallelismArgs], tp_pg: dist.ProcessGroup, + random_states: RandomStates ): - # TODO: GPT3Config -> Starcoder2Config. 
- super().__init__(config, parallel_config, tp_pg) - self.dropout = nn.Dropout(p=config.dropout) # TODO: correct config.dropout name + super().__init__(config.as_starcoder2(), parallel_config, tp_pg) + self.dropout = nn.Dropout(p=config.act_pdrop) + self.random_states = random_states + self.tp_mode = parallel_config.tp_mode if parallel_config is not None else TensorParallelLinearMode.ALL_REDUCE def forward(self, hidden_states): # [seq_length, batch_size, hidden_dim] hidden_states = self.c_fc(hidden_states) @@ -113,6 +166,7 @@ def __init__( random_states: RandomStates, layer_idx: int, ): + #print("New gpt block created :D") super(GPTBlock, self).__init__() self.ln_1 = TritonLayerNorm(config.hidden_size, eps=config.layer_norm_epsilon) self.attn = CausalSelfAttention( @@ -124,7 +178,7 @@ def __init__( self.attn_dropout = config.attn_pdrop self.ln_2 = TritonLayerNorm(config.hidden_size, eps=config.layer_norm_epsilon) - self.ff = MLP(config=config, parallel_config=parallel_config, tp_pg=tp_pg) + self.ff = MLP(config=config, parallel_config=parallel_config, tp_pg=tp_pg, random_states=random_states) self.ff_dropout = config.resid_pdrop self.random_states = random_states @@ -138,8 +192,10 @@ def forward( residual = hidden_states hidden_states = self.ln_1(hidden_states) + #hidden_states = torch.arange(hidden_states.numel()).to(hidden_states.device).to(hidden_states.dtype).view(hidden_states.size()) output = self.attn(hidden_states=hidden_states, sequence_mask=sequence_mask) hidden_states = output["hidden_states"] + #return {"hidden_states": hidden_states, "sequence_mask": sequence_mask} if self.training: with branch_random_state( @@ -227,7 +283,7 @@ def _make_weights(self, tp_pg: dist.ProcessGroup, num_embeddings: int, return emb -class GPT3Model(nn.Module): +class GPT3Model(GPTModel): def __init__( self, config: GPT3Config, @@ -235,24 +291,9 @@ def __init__( parallel_config: Optional[ParallelismArgs], random_states: RandomStates, ): - super().__init__() + with replace_decoder(config): + super().__init__(config.as_starcoder2(), parallel_context, parallel_config, random_states) - # Declare all the nodes - self.p2p = P2P(parallel_context.pp_pg, device=torch.device("cuda")) - self.random_states = random_states - self.tp_mode = parallel_config.tp_mode if parallel_config is not None else TensorParallelLinearMode.ALL_REDUCE - - self.token_embeddings = PipelineBlock( - p2p=self.p2p, - module_builder=Embedding, - module_kwargs={ - "tp_pg": parallel_context.tp_pg, - "config": config, - "parallel_config": parallel_config, - }, - module_input_keys={"input_ids"}, - module_output_keys={"input_embeds"}, - ) self.position_embeddings = PipelineBlock( p2p=self.p2p, module_builder=PositionEmbedding, @@ -264,69 +305,7 @@ def __init__( module_input_keys={"position_ids"}, module_output_keys={"position_embeds"}, ) - - self.embeds_dropout = PipelineBlock( - p2p=self.p2p, - module_builder=nn.Dropout, - module_kwargs={"p": config.embd_pdrop}, - module_input_keys={"input"}, - module_output_keys={"hidden_states"}, - ) - - self.decoder = nn.ModuleList( - [ - PipelineBlock( - p2p=self.p2p, - module_builder=GPTBlock, - module_kwargs={ - "config": config, - "parallel_config": parallel_config, - "tp_pg": parallel_context.tp_pg, - "random_states": random_states, - "layer_idx": layer_idx, - }, - module_input_keys={"hidden_states", "sequence_mask"}, - module_output_keys={"hidden_states", "sequence_mask"}, - ) - for layer_idx in range(config.num_hidden_layers) - ] - ) - - self.final_layer_norm = PipelineBlock( - p2p=self.p2p, - 
module_builder=TritonLayerNorm, - module_kwargs={"normalized_shape": config.hidden_size, "eps": config.layer_norm_epsilon}, - module_input_keys={"input"}, - module_output_keys={"hidden_states"}, - ) - - self.lm_head = PipelineBlock( - p2p=self.p2p, - # Understand that this means that we return sharded logits that are going to need to be gathered - module_builder=TensorParallelColumnLinear, - module_kwargs={ - "in_features": config.hidden_size, - "out_features": config.vocab_size, - "pg": parallel_context.tp_pg, - "bias": False, - # TODO: refactor so that we store that default in a single place. - "mode": self.tp_mode, - "async_communication": parallel_config.tp_linear_async_communication - if parallel_config is not None - else False, - }, - module_input_keys={"x"}, - module_output_keys={"logits"}, - ) - - self.cast_to_fp32 = PipelineBlock( - p2p=self.p2p, - module_builder=lambda: lambda x: x.float(), - module_kwargs={}, - module_input_keys={"x"}, - module_output_keys={"output"}, - ) - + self.embed_scale = config.hidden_size**0.5 if config.scale_embedding else 1.0 def forward( self, @@ -335,9 +314,9 @@ def forward( ): # all tensors are optional as most ranks don't need anything from the dataloader. + input_embeds = self.token_embeddings(input_ids=input_ids, input_mask=input_mask)["input_embeds"]*self.embed_scale position_ids = torch.arange(input_ids.size(1), device="cuda").repeat(input_ids.size(0)).view(*input_ids.size()) - input_embeds = self.token_embeddings(input_ids=input_ids)["input_embeds"] - position_embeds = self.position_embeds(position_ids=position_ids)["position_embeds"] + position_embeds = self.position_embeddings(position_ids=position_ids)["position_embeds"] hidden_states = input_embeds + position_embeds with branch_random_state( @@ -348,6 +327,7 @@ def forward( hidden_encoder_states = {"hidden_states": hidden_states, "sequence_mask": input_mask} for encoder_block in self.decoder: hidden_encoder_states = encoder_block(**hidden_encoder_states) + #return hidden_encoder_states["hidden_states"] hidden_states = self.final_layer_norm(input=hidden_encoder_states["hidden_states"])["hidden_states"] @@ -356,3 +336,21 @@ def forward( fp32_sharded_logits = self.cast_to_fp32(x=sharded_logits)["output"] return fp32_sharded_logits + + +# TODO: maybe reimplement: +# - tie_custom_params +# - get_embeddings_lm_head_tied_names +# - get_block_compute_costs +# - get_flops_per_sec +class GPT3ForTraining(Starcoder2ForTraining): + def __init__( + self, + config: GPT3Config, + parallel_context: ParallelContext, + parallel_config: Optional[ParallelismArgs], + random_states: RandomStates, + ): + with replace_gpt3model(config): + super().__init__(config.as_starcoder2(), parallel_context, parallel_config, random_states) + From a74c71ad3a56a33501153c7bd00f4418d4ef1cb6 Mon Sep 17 00:00:00 2001 From: AleHD Date: Tue, 9 Jul 2024 16:46:55 +0200 Subject: [PATCH 03/51] Added hf2nt frontend + tested training --- examples/xglm/README.md | 13 +++ examples/xglm/convert_hf2nt.py | 86 ++++++++++++++-- examples/xglm/example_config.yaml | 98 +++++++++++++++++++ src/nanotron/config/models_config.py | 6 +- src/nanotron/models/gpt3.py | 23 +---- .../optimizer_from_gradient_accumulator.py | 3 +- src/nanotron/trainer.py | 2 + 7 files changed, 199 insertions(+), 32 deletions(-) create mode 100644 examples/xglm/README.md create mode 100644 examples/xglm/example_config.yaml diff --git a/examples/xglm/README.md b/examples/xglm/README.md new file mode 100644 index 00000000..abc50f95 --- /dev/null +++ b/examples/xglm/README.md @@ 
-0,0 +1,13 @@
+# How to use XGLM?
+
+1. First, make sure to convert the weights from huggingface, for instance:
+   ```
+   torchrun --nproc-per-node=1 examples/xglm/convert_hf2nt.py --checkpoint-path=facebook/xglm-564M --save-path=$SCRATCH/checkpoints/xglm-564M
+   ```
+
+1. Now you are ready to use XGLM.
+   Make sure you use a .yaml configuration with proper GPT3 config and then run for instance:
+   ```
+   torchrun --nproc-per-node=4 run_train.py --config-file=examples/xglm/example_config.yaml
+   ```
+   If you use this configuration file make sure to modify at least the loading path in `model.init_method.path`.
diff --git a/examples/xglm/convert_hf2nt.py b/examples/xglm/convert_hf2nt.py
index 6e6ddff1..9db5ed93 100644
--- a/examples/xglm/convert_hf2nt.py
+++ b/examples/xglm/convert_hf2nt.py
@@ -1,27 +1,42 @@
+"""
+Converts a HF model to nanotron format
+Command:
+    torchrun --nproc-per-node=1 convert_hf2nt.py --checkpoint-path=hf_weights --save-path=nanotron_weights
+"""
+
+import json
+import warnings
+import dataclasses
+from argparse import ArgumentParser
+from pathlib import Path
+
 import torch
 from torch import nn
-
 from transformers.models.xglm.modeling_xglm import XGLMAttention, XGLMConfig, XGLMDecoderLayer, XGLMForCausalLM
+
+import nanotron
 from nanotron.models.gpt3 import CausalSelfAttention, GPTBlock, MLP, GPT3ForTraining
 from nanotron.config.models_config import GPT3Config
+from nanotron.trainer import mark_tied_parameters
+
 
 def convert_config(config: XGLMConfig) -> GPT3Config:
     # TODOs:
-    # dropout=0.1,
     # layerdrop=0.0,
     # init_std=0.02,
     # use_cache=True,
-    # decoder_start_token_id=2,
     # pad_token_id=1,
     # bos_token_id=0,
-
-    # TODO: when going gpt3->xglm:
-    # - assert layernorm is 1e-05
+    if config.dropout != config.attention_dropout:
+        warnings.warn(f"huggingface.dropout = {config.dropout} does not match with "
+                      f"huggingface.attention_dropout = {config.attention_dropout}. 
" + "Nanotron implementation needs these two values to be equal " + "for correct conversion.") return GPT3Config( activation_function=config.activation_function, attn_pdrop=config.attention_dropout, - embd_pdrop=0.0, # TODO + embd_pdrop=config.dropout, eos_token_id=config.eos_token_id, hidden_size=config.d_model, intermediate_size=config.ffn_dim, @@ -29,12 +44,12 @@ def convert_config(config: XGLMConfig) -> GPT3Config: max_position_embeddings=config.max_position_embeddings, num_attention_heads=config.attention_heads, num_hidden_layers=config.num_layers, - resid_pdrop=0.0, # TODO + resid_pdrop=config.dropout, scale_attention_softmax_in_fp32=True, scale_attn_weights=True, vocab_size=config.vocab_size, sinusoidal_position_embedding=True, - position_embedding_offset=2, + position_embedding_offset=config.decoder_start_token_id, use_spda=False, act_pdrop=config.activation_dropout, scale_embedding=config.scale_embedding, @@ -92,3 +107,56 @@ def convert(model_nt: GPT3ForTraining, model_hf: XGLMForCausalLM): convert_decoder(layer_nt.pp_block, layer_hf) convert_generic(model_nt.model.final_layer_norm.pp_block, model_hf.model.layer_norm) convert_generic(model_nt.model.lm_head.pp_block, model_hf.lm_head) + + +def create_nt_model(model_config: GPT3Config, device: torch.device = torch.device("cuda"), + dtype: torch.dtype = torch.bfloat16) -> GPT3ForTraining: + + parallel_config = nanotron.config.ParallelismArgs(dp=1, pp=1, tp=1) + parallel_context = nanotron.parallel.ParallelContext( + data_parallel_size=parallel_config.dp, + pipeline_parallel_size=parallel_config.pp, + tensor_parallel_size=parallel_config.tp, + ) + #random_states = nanotron.random.RandomStates({"tp_synced": nanotron.random.get_current_random_state()}) + model_nt = nanotron.models.build_model( + model_builder=lambda: GPT3ForTraining( + config=model_config, + parallel_context=parallel_context, + parallel_config=parallel_config, + random_states=None, + ), + parallel_context=parallel_context, + dtype=dtype, + device=device, + ) + mark_tied_parameters(model=model_nt, parallel_context=parallel_context) + return model_nt + + +def main(hf_path: str, save_path: Path): + # Load hf. + print("Loading hf...") + model_hf = XGLMForCausalLM.from_pretrained(hf_path) + + # Init nanotron. + print("Initializing nt...") + config_nt = convert_config(model_hf.config) + model_nt = create_nt_model(config_nt) + + # Copy weights and save model. 
+ print("Copying weights...") + convert(model_nt, model_hf) + nanotron.serialize.save_weights(model=model_nt, parallel_context=model_nt.parallel_context, + root_folder=save_path) + with open(save_path/"model_config.json", "w+") as f: + json.dump(dataclasses.asdict(config_nt), f) + print(f"Model saved to {save_path}") + + +if __name__ == "__main__": + parser = ArgumentParser(description="Convert HF weights to nanotron format") + parser.add_argument("--checkpoint-path", default="facebook/xglm-7.5B", help="Name or path to the huggingface checkpoint") + parser.add_argument("--save-path", type=Path, default="checkpoints/xglm-7.5B", help="Path to save the nanotron model") + args = parser.parse_args() + main(args.checkpoint_path, args.save_path) diff --git a/examples/xglm/example_config.yaml b/examples/xglm/example_config.yaml new file mode 100644 index 00000000..2d7e9926 --- /dev/null +++ b/examples/xglm/example_config.yaml @@ -0,0 +1,98 @@ +checkpoints: + checkpoint_interval: 1000 + checkpoints_path: checkpoints/xglm + checkpoints_path_is_shared_file_system: false + resume_checkpoint_path: null + save_initial_state: false +data_stages: +- data: + dataset: + dataset_overwrite_cache: false + dataset_processing_num_proc_per_process: 64 + hf_dataset_config_name: null + hf_dataset_or_datasets: DKYoon/SlimPajama-6B + hf_dataset_splits: train + text_column_name: text + num_loading_workers: 1 + seed: 42 + name: Finetuning + start_training_step: 1 +general: + benchmark_csv_path: null + consumed_train_samples: null + ignore_sanity_checks: true + project: xglm-test + run: xglm-dp4tp1pp1 + seed: 42 + step: null +lighteval: null +logging: + iteration_step_info_interval: 1 + log_level: info + log_level_replica: info +model: + ddp_bucket_cap_mb: 25 + dtype: bfloat16 + init_method: + path: /capstor/scratch/cscs/ahernnde/checkpoints/xglm-564M + make_vocab_size_divisible_by: 1 + model_config: + activation_function: gelu + attn_pdrop: 0.1 + embd_pdrop: 0.1 + scale_embedding: true + eos_token_id: 2 + hidden_size: 1024 + intermediate_size: 4096 + layer_norm_epsilon: 0.00001 + max_position_embeddings: 2048 + num_attention_heads: 16 + num_hidden_layers: 24 + resid_pdrop: 0.1 + scale_attention_softmax_in_fp32: true + scale_attn_weights: true + vocab_size: 256008 + sinusoidal_position_embedding: true + position_embedding_offset: 2 + use_spda: false + act_pdrop: 0.0 +optimizer: + accumulate_grad_in_fp32: true + clip_grad: 1.0 + learning_rate_scheduler: + learning_rate: 0.0003 + lr_decay_starting_step: null + lr_decay_steps: 900 + lr_decay_style: cosine + lr_warmup_steps: 100 + lr_warmup_style: linear + min_decay_lr: 1.0e-04 + optimizer_factory: + adam_beta1: 0.9 + adam_beta2: 0.95 + adam_eps: 1.0e-08 + name: adamW + torch_adam_is_fused: true + weight_decay: 0.01 + zero_stage: 0 +parallelism: + dp: 4 + expert_parallel_size: 1 + pp: 1 + pp_engine: 1f1b + tp: 1 + tp_linear_async_communication: false + tp_mode: REDUCE_SCATTER +profiler: null +tokenizer: + tokenizer_max_length: null + tokenizer_name_or_path: facebook/xglm-564M + tokenizer_revision: null +tokens: + batch_accumulation_per_replica: 4 + limit_test_batches: 0 + limit_val_batches: 0 + micro_batch_size: 8 + sequence_length: 2048 + train_steps: 1000 + val_check_interval: -1 diff --git a/src/nanotron/config/models_config.py b/src/nanotron/config/models_config.py index 12bac0fb..6c568e80 100644 --- a/src/nanotron/config/models_config.py +++ b/src/nanotron/config/models_config.py @@ -164,6 +164,8 @@ def as_starcoder2(self) -> Starcoder2Config: del 
config["position_embedding_offset"] del config["act_pdrop"] del config["scale_embedding"] + if "_is_using_mup" in config: + del config["_is_using_mup"] return Starcoder2Config( grouped_query=True, num_kv_heads=self.num_attention_heads, @@ -171,6 +173,4 @@ def as_starcoder2(self) -> Starcoder2Config: **config ) - -NanotronConfigs = Union[LlamaConfig, Starcoder2Config, Any] - +NanotronConfigs = LlamaConfig | Starcoder2Config | GPT3Config diff --git a/src/nanotron/models/gpt3.py b/src/nanotron/models/gpt3.py index 99f6ea85..33661c8b 100644 --- a/src/nanotron/models/gpt3.py +++ b/src/nanotron/models/gpt3.py @@ -18,24 +18,13 @@ from nanotron.parallel.pipeline_parallel.block import PipelineBlock from nanotron.models.starcoder2 import CoreAttention as Starcoder2CoreAttention from nanotron.models.starcoder2 import GPTBlock as Starcoder2GPTBlock -from nanotron.models.starcoder2 import CausalSelfGQA, Starcoder2ForTraining, GPTModel +from nanotron.models.starcoder2 import CausalSelfGQA, Starcoder2ForTraining, GPTModel, dropout_add_fused_train from nanotron.random import RandomStates, branch_random_state from nanotron.parallel.pipeline_parallel.tensor_pointer import TensorPointer from nanotron.parallel.tensor_parallel.enum import TensorParallelLinearMode from nanotron.parallel.tensor_parallel.nn import TensorParallelEmbedding from nanotron.parallel.tied_parameters import tie_parameters -# NOTES: -# - tie head_weight with embeddings I think. - -# TODO: -# - class GPT3Config: config lol -# - check that attention (i.e. nanotron.attn vs xglm.self_attn) is the same. -# - from starcoder import Embedding -# - class PositionEmbedding: my sinusoidal embedding extends from TensorParallelEmbedding -# - class GPTBlock: very similar to starcoder2 but make it so it support non-GQA or MQA -# - from starcoder import Loss - @contextmanager def replace_coreattention(gpt3config: GPT3Config): @@ -130,7 +119,6 @@ def __init__( with replace_coreattention(config): super().__init__(config.as_starcoder2(), parallel_config, tp_pg, layer_idx) self.maybe_rotary = lambda q, k, **_: (q, k) # Overwrite possible rotary with identity. - #self.attention = CoreAttention(config, parallel_config=parallel_config, layer_idx=layer_idx) # Use our custom CoreAttention. class MLP(Starcoder2MLP): @@ -204,7 +192,6 @@ def forward( hidden_states = dropout_add_fused_train(hidden_states, residual=residual, prob=self.attn_dropout) else: # No need for random state context manager - # TODO: add dropout scaling? hidden_states = hidden_states + residual residual = hidden_states @@ -218,7 +205,6 @@ def forward( hidden_states = dropout_add_fused_train(hidden_states, residual=residual, prob=self.ff_dropout) else: # No need for random state context manager - # TODO: add dropout scaling? 
hidden_states = hidden_states + residual return { @@ -235,7 +221,7 @@ def __init__(self, tp_pg: dist.ProcessGroup, config: GPT3Config, parallel_config if (config.max_position_embeddings + config.position_embedding_offset) % tp_pg.size() == 0: dummy_pos = 0 else: - dummy_pos = tp_pg.size() - ((config.max_position_embeddings + config.position_embedding_offset) % k) + dummy_pos = tp_pg.size() - ((config.max_position_embeddings + config.position_embedding_offset) % tp_pg.size()) true_max_size = config.max_position_embeddings + config.position_embedding_offset + dummy_pos if config.sinusoidal_position_embedding: @@ -278,7 +264,7 @@ def _make_weights(self, tp_pg: dist.ProcessGroup, num_embeddings: int, half_dim = embedding_dim//2 emb = math.log(10_000)/(half_dim - 1) emb = torch.exp(torch.arange(half_dim, dtype=torch.int64).float() * -emb) - emb = (rank*block_size + torch.arange(num_embeddings, dtype=torch.int64).float().unsqueeze(1)) * emb.unsqueeze(0) + emb = (rank*block_size + torch.arange(block_size, dtype=torch.int64).float().unsqueeze(1)) * emb.unsqueeze(0) emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1).view(block_size, embedding_dim) return emb @@ -315,6 +301,7 @@ def forward( # all tensors are optional as most ranks don't need anything from the dataloader. input_embeds = self.token_embeddings(input_ids=input_ids, input_mask=input_mask)["input_embeds"]*self.embed_scale + # TODO: position_ids could be cached. position_ids = torch.arange(input_ids.size(1), device="cuda").repeat(input_ids.size(0)).view(*input_ids.size()) position_embeds = self.position_embeddings(position_ids=position_ids)["position_embeds"] hidden_states = input_embeds + position_embeds @@ -339,8 +326,6 @@ def forward( # TODO: maybe reimplement: -# - tie_custom_params -# - get_embeddings_lm_head_tied_names # - get_block_compute_costs # - get_flops_per_sec class GPT3ForTraining(Starcoder2ForTraining): diff --git a/src/nanotron/optim/optimizer_from_gradient_accumulator.py b/src/nanotron/optim/optimizer_from_gradient_accumulator.py index 01be7cb5..9883c720 100644 --- a/src/nanotron/optim/optimizer_from_gradient_accumulator.py +++ b/src/nanotron/optim/optimizer_from_gradient_accumulator.py @@ -38,7 +38,8 @@ def __init__( **{k: v for k, v in named_param_group.items() if k != "named_params"}, "named_params": [ (name, gradient_accumulator.get_parameter_for_optimizer(name)) - for name, _ in named_param_group["named_params"] + for name, param in named_param_group["named_params"] + if param.requires_grad ], } for named_param_group in named_param_groups diff --git a/src/nanotron/trainer.py b/src/nanotron/trainer.py index b6752f38..f01caa3e 100644 --- a/src/nanotron/trainer.py +++ b/src/nanotron/trainer.py @@ -58,6 +58,7 @@ from nanotron.models.base import check_model_has_grad from nanotron.models.llama import LlamaForTraining, RotaryEmbedding from nanotron.models.starcoder2 import Starcoder2ForTraining +from nanotron.models.gpt3 import GPT3ForTraining from nanotron.optim.clip_grads import clip_grad_norm from nanotron.parallel import ParallelContext from nanotron.parallel.data_parallel.utils import sync_gradients_across_dp @@ -103,6 +104,7 @@ CONFIG_TO_MODEL_CLASS = { "LlamaConfig": LlamaForTraining, "Starcoder2Config": Starcoder2ForTraining, + "GPT3Config": GPT3ForTraining, } try: From 04eaef956a091bbcc40e5c2ef140aad7b577f003 Mon Sep 17 00:00:00 2001 From: AleHD Date: Thu, 11 Jul 2024 13:38:52 +0200 Subject: [PATCH 04/51] Added nt2hf conversion + tests :) --- examples/xglm/README.md | 5 + examples/xglm/convert_hf2nt.py | 38 
+---- examples/xglm/convert_nt2hf.py | 126 +++++++++++++++ examples/xglm/convert_utils.py | 59 +++++++ examples/xglm/tests/test_implementation.py | 177 +++++++++++++++++---- src/nanotron/config/models_config.py | 4 + src/nanotron/models/gpt3.py | 2 +- 7 files changed, 347 insertions(+), 64 deletions(-) create mode 100644 examples/xglm/convert_nt2hf.py create mode 100644 examples/xglm/convert_utils.py diff --git a/examples/xglm/README.md b/examples/xglm/README.md index abc50f95..22765f52 100644 --- a/examples/xglm/README.md +++ b/examples/xglm/README.md @@ -11,3 +11,8 @@ torchrun --nproc-per-node=4 run_train.py --config-file=examples/xglm/example_config.yaml ``` If you use this configuration file make sure to modify at least the loading path in `model.init_method.path`. + +1. If you want to convert your finetuned checkpoint back to huggingface use: + ``` + torchrun --nproc-per-node=1 examples/xglm/convert_nt2hf.py --checkpoint-path=checkpoints/xglm --save-path=$SCRATCH/checkpoints/huggingface/xglm-564M --tokenizer-name=facebook/xglm-564M + ``` diff --git a/examples/xglm/convert_hf2nt.py b/examples/xglm/convert_hf2nt.py index 9db5ed93..0efcceca 100644 --- a/examples/xglm/convert_hf2nt.py +++ b/examples/xglm/convert_hf2nt.py @@ -18,11 +18,11 @@ from nanotron.models.gpt3 import CausalSelfAttention, GPTBlock, MLP, GPT3ForTraining from nanotron.config.models_config import GPT3Config from nanotron.trainer import mark_tied_parameters - +from examples.xglm.convert_utils import convert_generic, create_nt_model def convert_config(config: XGLMConfig) -> GPT3Config: - # TODOs: + # These settings seem to be unused: # layerdrop=0.0, # init_std=0.02, # use_cache=True, @@ -80,15 +80,6 @@ def convert_attention(attn_nt: CausalSelfAttention, attn_hf: XGLMAttention): attn_nt.dense.bias.data = attn_hf.out_proj.bias.clone() -def convert_generic(module1: nn.Module, module2: nn.Module): - names1 = {name for name, _ in module1.named_parameters()} - names2 = {name for name, _ in module2.named_parameters()} - assert names1 == names2, f"{names1} != {names2}" - params2 = dict(module2.named_parameters()) - for name, param in module1.named_parameters(): - param.data = params2[name].clone() - - def convert_mlp(mlp_nt: MLP, block_hf: XGLMDecoderLayer): convert_generic(mlp_nt.c_fc, block_hf.fc1) convert_generic(mlp_nt.c_proj, block_hf.fc2) @@ -109,31 +100,6 @@ def convert(model_nt: GPT3ForTraining, model_hf: XGLMForCausalLM): convert_generic(model_nt.model.lm_head.pp_block, model_hf.lm_head) -def create_nt_model(model_config: GPT3Config, device: torch.device = torch.device("cuda"), - dtype: torch.dtype = torch.bfloat16) -> GPT3ForTraining: - - parallel_config = nanotron.config.ParallelismArgs(dp=1, pp=1, tp=1) - parallel_context = nanotron.parallel.ParallelContext( - data_parallel_size=parallel_config.dp, - pipeline_parallel_size=parallel_config.pp, - tensor_parallel_size=parallel_config.tp, - ) - #random_states = nanotron.random.RandomStates({"tp_synced": nanotron.random.get_current_random_state()}) - model_nt = nanotron.models.build_model( - model_builder=lambda: GPT3ForTraining( - config=model_config, - parallel_context=parallel_context, - parallel_config=parallel_config, - random_states=None, - ), - parallel_context=parallel_context, - dtype=dtype, - device=device, - ) - mark_tied_parameters(model=model_nt, parallel_context=parallel_context) - return model_nt - - def main(hf_path: str, save_path: Path): # Load hf.
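Both converters hinge on the same assumption about how the fused `query_key_value` projection is laid out in nanotron: per attention head, the rows are packed as `[q_i, k_i, v_i]`. A small self-contained sketch of that round trip with toy sizes (the interleave mirrors `convert_hf2nt.convert_attention`, the split mirrors the `convert_nt2hf.convert_attention` introduced below):

```python
import torch

heads, head_dim = 4, 8           # toy sizes, not the real XGLM dimensions
hidden = heads * head_dim

# Per-head projections as HF stores them (q_proj / k_proj / v_proj weights).
q = torch.randn(hidden, hidden)
k = torch.randn(hidden, hidden)
v = torch.randn(hidden, hidden)

# hf -> nt: interleave head-sized chunks as [q_0, k_0, v_0, q_1, k_1, v_1, ...].
qkv = torch.cat([t for trio in zip(q.chunk(heads), k.chunk(heads), v.chunk(heads)) for t in trio])

# nt -> hf: split by head_dim and de-interleave (the i % 3 pattern used in convert_nt2hf).
chunks = qkv.split(head_dim)
q2 = torch.cat([c for i, c in enumerate(chunks) if i % 3 == 0])
k2 = torch.cat([c for i, c in enumerate(chunks) if i % 3 == 1])
v2 = torch.cat([c for i, c in enumerate(chunks) if i % 3 == 2])

assert torch.equal(q, q2) and torch.equal(k, k2) and torch.equal(v, v2)
```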
print("Loading hf...") diff --git a/examples/xglm/convert_nt2hf.py b/examples/xglm/convert_nt2hf.py new file mode 100644 index 00000000..422695a1 --- /dev/null +++ b/examples/xglm/convert_nt2hf.py @@ -0,0 +1,126 @@ +""" +Converts a nanotron model to HF format +Command: + torchrun --nproc-per-node=1 convert_nt2hf.py --checkpoint-path=nanotron_weights --save-path=hf_weights +""" + +from argparse import ArgumentParser +from typing import Optional +from pathlib import Path + +import torch +from transformers import AutoTokenizer +from transformers.models.xglm.modeling_xglm import XGLMAttention, XGLMConfig, XGLMDecoderLayer, XGLMForCausalLM + +from nanotron.config.models_config import GPT3Config +from nanotron.models.gpt3 import CausalSelfAttention, GPTBlock, MLP, GPT3ForTraining +from examples.xglm.convert_utils import convert_generic, create_nt_model + + +def convert_config(config: GPT3Config) -> XGLMConfig: + if config.embd_pdrop != config.resid_pdrop: + warnings.warn(f"nanotron.embd_pdrop = {config.embd_pdrop} does not match with " + f"nanotron.resid_pdrop = {config.resid_pdrop}. " + "XGLM implementation needs these two values to be equal " + "for correct conversion.") + if config.layer_norm_epsilon != 1e-5: + warnings.warn(f"nanotron.layer_norm_epsilon must be 1e-5, not {config.layer_norm_epsilon}") + return XGLMConfig( + activation_function=config.activation_function, + attention_dropout=config.attn_pdrop, + dropout=config.embd_pdrop, + eos_token_id=config.eos_token_id, + d_model=config.hidden_size, + ffn_dim=config.intermediate_size, + max_position_embeddings=config.max_position_embeddings, + attention_heads=config.num_attention_heads, + num_layers=config.num_hidden_layers, + vocab_size=config.vocab_size, + decoder_start_token_id=config.position_embedding_offset, + activation_dropout=config.act_pdrop, + scale_embedding=config.scale_embedding, + ) + + +def convert_attention(attn_hf: XGLMAttention, attn_nt: XGLMAttention): + qs_w = [] + ks_w = [] + vs_w = [] + qs_b = [] + ks_b = [] + vs_b = [] + + head_dim = attn_hf.head_dim + qkv_ws = list(attn_nt.query_key_value.weight.split(head_dim)) + qkv_bs = list(attn_nt.query_key_value.bias.split(head_dim)) + for i, (w, b) in enumerate(zip(qkv_ws, qkv_bs)): + if i % 3 == 0: + qs_w.append(w) + qs_b.append(b) + elif i % 3 == 1: + ks_w.append(w) + ks_b.append(b) + else: + vs_w.append(w) + vs_b.append(b) + + q_w = torch.cat(qs_w) + k_w = torch.cat(ks_w) + v_w = torch.cat(vs_w) + q_b = torch.cat(qs_b) + k_b = torch.cat(ks_b) + v_b = torch.cat(vs_b) + + with torch.no_grad(): + attn_hf.q_proj.weight.data = q_w.clone() + attn_hf.k_proj.weight.data = k_w.clone() + attn_hf.v_proj.weight.data = v_w.clone() + attn_hf.q_proj.bias.data = q_b.clone() + attn_hf.k_proj.bias.data = k_b.clone() + attn_hf.v_proj.bias.data = v_b.clone() + + attn_hf.out_proj.weight.data = attn_nt.dense.weight.data.clone() + attn_hf.out_proj.bias.data = attn_nt.dense.bias.data.clone() + + +def convert_decoder(block_hf: XGLMDecoderLayer, block_nt: GPTBlock): + convert_generic(block_hf.self_attn_layer_norm, block_nt.ln_1) + convert_attention(block_hf.self_attn, block_nt.attn) + convert_generic(block_hf.final_layer_norm, block_nt.ln_2) + convert_generic(block_hf.fc1, block_nt.ff.c_fc) + convert_generic(block_hf.fc2, block_nt.ff.c_proj) + + +def convert(model_hf: XGLMForCausalLM, model_nt: GPT3ForTraining): + convert_generic(model_hf.model.embed_tokens, model_nt.model.token_embeddings.pp_block.token_embedding) + for layer_hf, layer_nt in zip(model_hf.model.layers, model_nt.model.decoder): + 
convert_decoder(layer_hf, layer_nt.pp_block) + convert_generic(model_hf.model.layer_norm, model_nt.model.final_layer_norm.pp_block) + convert_generic(model_hf.lm_head, model_nt.model.lm_head.pp_block) + + +def main(checkpoint_path: Path, save_path: Path, tokenizer_name: Optional[str]): + # Load nanotron model. + model_nt = create_nt_model(checkpoint_path=checkpoint_path) + + # Init huggingface model. + model_config_hf = convert_config(model_nt.config) + model_hf = XGLMForCausalLM._from_config(model_config_hf) + + # Copy weights, initialize tokenizer and save model. + if tokenizer_name is not None: + tokenizer = AutoTokenizer.from_pretrained(tokenizer_name) + tokenizer.save_pretrained(save_path) + convert(model_hf, model_nt) + model_hf.save_pretrained(save_path) + print(f"Model saved to {save_path}") + + +if __name__ == "__main__": + parser = ArgumentParser(description="Convert HF weights to nanotron format") + parser.add_argument("--checkpoint-path", type=Path, default="checkpoints/xglm-7.5B", help="Path to the nanotron checkpoint") + parser.add_argument("--save-path", type=Path, default="facebook/xglm-7.5B", help="Path to save the huggingface model") + parser.add_argument("--tokenizer-name", type=str, default="facebook/xglm-7.5B") + args = parser.parse_args() + main(args.checkpoint_path, args.save_path, args.tokenizer_name) + diff --git a/examples/xglm/convert_utils.py b/examples/xglm/convert_utils.py new file mode 100644 index 00000000..88a731a1 --- /dev/null +++ b/examples/xglm/convert_utils.py @@ -0,0 +1,59 @@ +import json +from pathlib import Path +from typing import Optional + +import torch +from torch import nn + +import nanotron +from nanotron.models.gpt3 import GPT3ForTraining +from nanotron.config.models_config import GPT3Config +from nanotron.trainer import mark_tied_parameters + + +def convert_generic(module1: nn.Module, module2: nn.Module): + names1 = {name for name, _ in module1.named_parameters()} + names2 = {name for name, _ in module2.named_parameters()} + assert names1 == names2, f"{names1} != {names2}" + params2 = dict(module2.named_parameters()) + for name, param in module1.named_parameters(): + param.data = params2[name].clone() + + +def create_nt_model( + model_config: Optional[GPT3Config] = None, + device: torch.device = torch.device("cuda"), + dtype: torch.dtype = torch.bfloat16, + checkpoint_path: Optional[Path] = None + ): + + if model_config is None: + assert checkpoint_path is not None + with open(checkpoint_path / "model_config.json") as f: + model_config = GPT3Config(**json.load(f)) + + parallel_config = nanotron.config.ParallelismArgs(dp=1, pp=1, tp=1) + parallel_context = nanotron.parallel.ParallelContext( + data_parallel_size=parallel_config.dp, + pipeline_parallel_size=parallel_config.pp, + tensor_parallel_size=parallel_config.tp, + ) + model_nt = nanotron.models.build_model( + model_builder=lambda: GPT3ForTraining( + config=model_config, + parallel_context=parallel_context, + parallel_config=parallel_config, + random_states=None, + ), + parallel_context=parallel_context, + dtype=dtype, + device=device, + ) + mark_tied_parameters(model=model_nt, parallel_context=parallel_context) + + if checkpoint_path is not None: + nanotron.serialize.load_weights( + model=model_nt, parallel_context=parallel_context, root_folder=checkpoint_path + ) + + return model_nt diff --git a/examples/xglm/tests/test_implementation.py b/examples/xglm/tests/test_implementation.py index 3636415b..d9dc0f85 100644 --- a/examples/xglm/tests/test_implementation.py +++ 
b/examples/xglm/tests/test_implementation.py @@ -8,6 +8,7 @@ from transformers.models.xglm.modeling_xglm import XGLMConfig, XGLMAttention, XGLMSinusoidalPositionalEmbedding, XGLMDecoderLayer, XGLMForCausalLM import nanotron +from nanotron.trainer import mark_tied_parameters from nanotron.config.models_config import GPT3Config from nanotron.models.gpt3 import GPT3ForTraining, CausalSelfAttention, PositionEmbedding, GPTBlock from nanotron.parallel import ParallelContext @@ -15,12 +16,17 @@ from tests.helpers.utils import init_distributed from examples.xglm.convert_hf2nt import convert_attention, convert_config, convert_decoder, convert +from examples.xglm.convert_nt2hf import convert_attention as convert_attention_nt2hf +from examples.xglm.convert_nt2hf import convert_config as convert_config_nt2hf +from examples.xglm.convert_nt2hf import convert_decoder as convert_decoder_nt2hf +from examples.xglm.convert_nt2hf import convert as convert_nt2hf -SEQUENCE_LENGTH = 2048 +MAX_SEQUENCE_LENGTH = 2048 +TEST_SEQUENCE_LENGTH = 128 # If we test with a very large sequence length, precision errors get more significant independent of the correct implementation. BATCH_SIZE = 4 HIDDEN_SIZE = 1024 -DTYPE = torch.bfloat16 +DTYPE = torch.float64 TEXT = "Hello. This is a relatively long text. I will use this text to test the conversion scripts. Let's finish this text soon because I don't have much more to say. Final note:" CONFIG = GPT3Config( @@ -32,7 +38,7 @@ hidden_size=HIDDEN_SIZE, intermediate_size=4096, layer_norm_epsilon=1e-05, - max_position_embeddings=SEQUENCE_LENGTH, + max_position_embeddings=MAX_SEQUENCE_LENGTH, num_attention_heads=16, num_hidden_layers=24, scale_attn_weights=True, @@ -45,25 +51,39 @@ @pytest.fixture def hidden_states() -> torch.Tensor: - return torch.randn(SEQUENCE_LENGTH, BATCH_SIZE, HIDDEN_SIZE, + return torch.randn(TEST_SEQUENCE_LENGTH, BATCH_SIZE, HIDDEN_SIZE, dtype=DTYPE) @pytest.fixture def input_mask() -> torch.Tensor: - return torch.ones(BATCH_SIZE, SEQUENCE_LENGTH, dtype=torch.bool) + return torch.ones(BATCH_SIZE, TEST_SEQUENCE_LENGTH, dtype=torch.bool) @pytest.fixture def input_ids() -> torch.Tensor: - return torch.randint(0, CONFIG.vocab_size, (BATCH_SIZE, SEQUENCE_LENGTH)) + return torch.randint(0, CONFIG.vocab_size, (BATCH_SIZE, TEST_SEQUENCE_LENGTH)) + + +def almost_close(t1: torch.Tensor, t2: torch.Tensor, atol: float = 1e-5, rtol: float = 0.016, + max_far: float = 0.0, far_atol: float = 0.01): + very_close = torch.abs(t1 - t2) <= atol + rtol*torch.abs(t2) + not_very_close = ~very_close + + if torch.all(very_close): + return + assert torch.mean(not_very_close.float()) <= max_far, f"not very close found: {100*torch.mean(not_very_close.float()):.1f}%" + assert torch.all(torch.abs(t1[not_very_close] - t2[not_very_close]) <= far_atol), f"Worse deviation found: {torch.max(torch.abs(t1 - t2)):.4f}" def attention_mask() -> torch.Tensor: # XGLM causal attention mask. 
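As an aside on the `almost_close` helper added above: `max_far` bounds the fraction of elements allowed to miss the tight `atol`/`rtol` check, and `far_atol` caps how far those outliers may stray. A tiny numeric illustration (the import path is an assumption):

```python
import torch
from examples.xglm.tests.test_implementation import almost_close  # assumed import path

t1 = torch.zeros(100)
t2 = torch.zeros(100)
t2[:3] = 0.008  # 3% of elements deviate by 0.008

almost_close(t1, t2, max_far=0.05, far_atol=0.01)  # passes: <= 5% outliers, all within 0.01
# torch.testing.assert_close(t1, t2) would raise here, since its default tolerances are much tighter.
```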
- mask = torch.ones(SEQUENCE_LENGTH, SEQUENCE_LENGTH, dtype=torch.bool, device="cuda").tril(diagonal=0) + mask = torch.ones(TEST_SEQUENCE_LENGTH, TEST_SEQUENCE_LENGTH, dtype=torch.bool, device="cuda").tril(diagonal=0) mask = torch.where(mask, 0.0, -np.inf).to(DTYPE) mask = mask.repeat(BATCH_SIZE, 1, 1).unsqueeze(1) return mask +## +# FROM HERE DOWN (until next comment), all tests are hf->nt +## def _test_attention(parallel_context: ParallelContext, hidden_states: torch.Tensor, sequence_mask: torch.Tensor): hidden_states = hidden_states.cuda() @@ -85,10 +105,10 @@ def test_attention(hidden_states: torch.Tensor, input_mask: torch.Tensor): def _test_position_embeddings(parallel_context: ParallelContext): - position_ids = torch.arange(SEQUENCE_LENGTH, device="cuda").unsqueeze(0) # shape = (1, SEQUENCE_LENGTH) + position_ids = torch.arange(TEST_SEQUENCE_LENGTH, device="cuda").unsqueeze(0) # shape = (1, TEST_SEQUENCE_LENGTH) emb_nt = PositionEmbedding(parallel_context.tp_pg, CONFIG, None).cuda() - emb_hf = XGLMSinusoidalPositionalEmbedding(SEQUENCE_LENGTH, HIDDEN_SIZE).cuda() + emb_hf = XGLMSinusoidalPositionalEmbedding(MAX_SEQUENCE_LENGTH, HIDDEN_SIZE).cuda() assert emb_nt.position_embedding.weight.size() == emb_hf.weights.size() torch.testing.assert_close(emb_nt.position_embedding.weight, emb_hf.weights) @@ -120,7 +140,7 @@ def _test_decoder(parallel_context: ParallelContext, hidden_states: torch.Tensor out_hf = decoder_hf(hidden_states.permute(1, 0, 2), attention_mask=attention_mask())[0].permute(1, 0, 2) assert out_nt.size() == out_hf.size(), f"{out_nt.size()}, {out_hf.size()}" - torch.testing.assert_close(out_nt, out_hf) + torch.testing.assert_close(out_nt.bfloat16(), out_hf.bfloat16()) # We cast to bf16 to get more relaxed constraints. def test_decoder(hidden_states: torch.Tensor, input_mask: torch.Tensor): @@ -129,21 +149,25 @@ def test_decoder(hidden_states: torch.Tensor, input_mask: torch.Tensor): def _test_model(model_hf: Optional[XGLMForCausalLM], parallel_context: ParallelContext, input_ids: torch.Tensor, input_mask: torch.Tensor): + random_states = nanotron.random.RandomStates({"tp_synced": nanotron.random.get_current_random_state()}) input_ids = input_ids.cuda() input_mask = input_mask.cuda() + # unfortunately, we can't use float64 with huggingface xglm. + new_dtype = torch.float32 if DTYPE == torch.float64 else DTYPE + # Get hf model. if model_hf is None: config_hf = XGLMConfig() - model_hf = XGLMForCausalLM(config_hf).cuda().to(DTYPE).eval() + model_hf = XGLMForCausalLM(config_hf).cuda().to(new_dtype).eval() else: - model_hf = model_hf.cuda().to(DTYPE).eval() + model_hf = model_hf.cuda().to(new_dtype).eval() config_hf = model_hf.config # Get nanotron model and make the conversion. config_nt = convert_config(config_hf) - if DTYPE not in {torch.bfloat16, torch.float16}: + if new_dtype not in {torch.bfloat16, torch.float16}: config_nt.use_spda = True model_nt = nanotron.models.build_model( model_builder=lambda: GPT3ForTraining( @@ -153,7 +177,7 @@ def _test_model(model_hf: Optional[XGLMForCausalLM], parallel_context: ParallelC random_states=random_states, ), parallel_context=parallel_context, - dtype=DTYPE, + dtype=new_dtype, device="cuda", ).eval() convert(model_nt, model_hf) @@ -162,42 +186,141 @@ def _test_model(model_hf: Optional[XGLMForCausalLM], parallel_context: ParallelC # Get outputs and assert. 
with torch.no_grad(): - out_nt = model_nt.model(input_ids, input_mask).to(DTYPE) + out_nt = model_nt.model(input_ids, input_mask).to(new_dtype) del model_nt torch.cuda.empty_cache() out_hf = model_hf(input_ids=input_ids, attention_mask=input_mask).logits.permute(1, 0, 2) del model_hf torch.cuda.empty_cache() assert out_nt.size() == out_hf.size(), f"{out_nt.size()}, {out_hf.size()}" - torch.testing.assert_close(out_nt.cpu(), out_hf.cpu()) + return out_nt.cpu(), out_hf.cpu() + def _test_dummy_xglm(parallel_context: ParallelContext, input_ids: torch.Tensor, input_mask: torch.Tensor): - _test_model(None, parallel_context, input_ids, input_mask) + out_nt, out_hf = _test_model(None, parallel_context, input_ids, input_mask) + almost_close(out_nt, out_hf, max_far=0.05) def test_dummy_xglm(input_ids: torch.Tensor, input_mask: torch.Tensor): init_distributed(tp=1, dp=1, pp=1)(_test_dummy_xglm)(input_ids=input_ids, input_mask=input_mask) +def _test_xglm500M(parallel_context: ParallelContext): + tok = XGLMTokenizer.from_pretrained("facebook/xglm-564M") + tokenized = tok(TEXT) + model_hf = XGLMForCausalLM.from_pretrained("facebook/xglm-564M") + out_nt, out_hf = _test_model(model_hf, parallel_context, + torch.tensor([tokenized["input_ids"]]), torch.tensor([tokenized["attention_mask"]])) + almost_close(out_nt, out_hf, max_far=0.1, far_atol=0.05) + + +def test_xglm500M(): + init_distributed(tp=1, dp=1, pp=1)(_test_xglm500M)() + + def _test_xglm7B(parallel_context: ParallelContext): tok = XGLMTokenizer.from_pretrained("facebook/xglm-7.5B") tokenized = tok(TEXT) model_hf = XGLMForCausalLM.from_pretrained("facebook/xglm-7.5B") - _test_model(model_hf, parallel_context, - torch.tensor([tokenized["input_ids"]]), torch.tensor([tokenized["attention_mask"]])) + out_nt, out_hf = _test_model(model_hf, parallel_context, + torch.tensor([tokenized["input_ids"]]), torch.tensor([tokenized["attention_mask"]])) + almost_close(out_nt, out_hf, max_far=0.15, far_atol=0.1) def test_xglm7B(): init_distributed(tp=1, dp=1, pp=1)(_test_xglm7B)() -def _test_xglm500M(parallel_context: ParallelContext): - tok = XGLMTokenizer.from_pretrained("facebook/xglm-564M") - tokenized = tok(TEXT) - model_hf = XGLMForCausalLM.from_pretrained("facebook/xglm-564M") - _test_model(model_hf, parallel_context, - torch.tensor([tokenized["input_ids"]]), torch.tensor([tokenized["attention_mask"]])) +## +# From here down we test nt->hf converters +## +def _test_nt2hf_attention(parallel_context: ParallelContext, hidden_states: torch.Tensor, sequence_mask: torch.Tensor): + hidden_states = hidden_states.cuda() + sequence_mask = sequence_mask.cuda() -def test_xglm500M(): - init_distributed(tp=1, dp=1, pp=1)(_test_xglm500M)() + attn_nt = CausalSelfAttention(CONFIG, None, parallel_context.tp_pg, 0).cuda().eval().to(DTYPE) + attn_hf = XGLMAttention(CONFIG.hidden_size, CONFIG.num_attention_heads, CONFIG.attn_pdrop).cuda().eval().to(DTYPE) + assert sum(map(torch.numel, attn_nt.parameters())) == sum(map(torch.numel, attn_hf.parameters())) + + convert_attention_nt2hf(attn_hf, attn_nt) + out_nt = attn_nt(hidden_states, sequence_mask)["hidden_states"] + out_hf = attn_hf(hidden_states.permute(1, 0, 2), attention_mask=attention_mask())[0].permute(1, 0, 2) + assert out_nt.size() == out_hf.size(), f"{out_nt.size()}, {out_hf.size()}" + torch.testing.assert_close(out_nt, out_hf) + + +def test_nt2hf_attention(hidden_states: torch.Tensor, input_mask: torch.Tensor): + init_distributed(tp=1, dp=1, pp=1)(_test_nt2hf_attention)(hidden_states=hidden_states, 
sequence_mask=input_mask) + + +def _test_nt2hf_decoder(parallel_context: ParallelContext, hidden_states: torch.Tensor, sequence_mask: torch.Tensor): + random_states = nanotron.random.RandomStates({"tp_synced": nanotron.random.get_current_random_state()}) + hidden_states = hidden_states.cuda() + sequence_mask = sequence_mask.cuda() + + config_hf = convert_config_nt2hf(CONFIG) + decoder_nt = GPTBlock(CONFIG, None, parallel_context.tp_pg, random_states, 0).cuda().to(DTYPE).eval() + decoder_hf = XGLMDecoderLayer(config_hf).cuda().to(DTYPE).eval() + + convert_decoder_nt2hf(decoder_hf, decoder_nt) + + out_nt = decoder_nt(hidden_states, sequence_mask)["hidden_states"] + out_hf = decoder_hf(hidden_states.permute(1, 0, 2), attention_mask=attention_mask())[0].permute(1, 0, 2) + + assert out_nt.size() == out_hf.size(), f"{out_nt.size()}, {out_hf.size()}" + torch.testing.assert_close(out_nt.bfloat16(), out_hf.bfloat16()) # We cast to bf16 to get more relaxed constraints. + + +def test_nt2hf_decoder(hidden_states: torch.Tensor, input_mask: torch.Tensor): + init_distributed(tp=1, dp=1, pp=1)(_test_nt2hf_decoder)(hidden_states=hidden_states, sequence_mask=input_mask) + + +def _test_nt2hf_model(parallel_context: ParallelContext, input_ids: torch.Tensor, input_mask: torch.Tensor): + random_states = nanotron.random.RandomStates({"tp_synced": nanotron.random.get_current_random_state()}) + input_ids = input_ids.cuda() + input_mask = input_mask.cuda() + + # unfortunately, we can't use float64 with huggingface xglm. + new_dtype = torch.float32 if DTYPE == torch.float64 else DTYPE + + # Get nanotron model. + config_nt = GPT3Config(**vars(CONFIG)) + if new_dtype not in {torch.bfloat16, torch.float16}: + config_nt.use_spda = True + model_nt = nanotron.models.build_model( + model_builder=lambda: GPT3ForTraining( + config=config_nt, + parallel_context=parallel_context, + parallel_config=None, + random_states=random_states, + ), + parallel_context=parallel_context, + dtype=new_dtype, + device="cuda", + ).eval() + mark_tied_parameters(model=model_nt, parallel_context=parallel_context) + + # Create empty model_hf and make conversion. + model_hf = XGLMForCausalLM(convert_config_nt2hf(config_nt)).cuda().to(new_dtype).eval() + convert_nt2hf(model_hf, model_nt) + + # Get outputs and assert. 
+ with torch.no_grad(): + out_nt = model_nt.model(input_ids, input_mask).to(new_dtype) + del model_nt + torch.cuda.empty_cache() + out_hf = model_hf(input_ids=input_ids, attention_mask=input_mask).logits.permute(1, 0, 2) + del model_hf + torch.cuda.empty_cache() + assert out_nt.size() == out_hf.size(), f"{out_nt.size()}, {out_hf.size()}" + return out_nt.cpu(), out_hf.cpu() + + +def _test_nt2hf_dummy_xglm(parallel_context: ParallelContext, input_ids: torch.Tensor, input_mask: torch.Tensor): + out_nt, out_hf = _test_nt2hf_model(parallel_context, input_ids, input_mask) + almost_close(out_nt, out_hf, max_far=0.01, far_atol=0.02) + + +def test_nt2hf_dummy_xglm(input_ids: torch.Tensor, input_mask: torch.Tensor): + init_distributed(tp=1, dp=1, pp=1)(_test_nt2hf_dummy_xglm)(input_ids=input_ids, input_mask=input_mask) diff --git a/src/nanotron/config/models_config.py b/src/nanotron/config/models_config.py index 6c568e80..80f956d1 100644 --- a/src/nanotron/config/models_config.py +++ b/src/nanotron/config/models_config.py @@ -173,4 +173,8 @@ def as_starcoder2(self) -> Starcoder2Config: **config ) + @property + def n_inner(self): + return self.intermediate_size + NanotronConfigs = LlamaConfig | Starcoder2Config | GPT3Config diff --git a/src/nanotron/models/gpt3.py b/src/nanotron/models/gpt3.py index 33661c8b..7d4e6f82 100644 --- a/src/nanotron/models/gpt3.py +++ b/src/nanotron/models/gpt3.py @@ -338,4 +338,4 @@ def __init__( ): with replace_gpt3model(config): super().__init__(config.as_starcoder2(), parallel_context, parallel_config, random_states) - + self.config = config From 138da5ff5a5a9c34ac6191149dfdc83603b08e20 Mon Sep 17 00:00:00 2001 From: AleHD Date: Thu, 11 Jul 2024 14:32:44 +0200 Subject: [PATCH 05/51] precommit --- examples/xglm/convert_hf2nt.py | 33 ++++---- examples/xglm/convert_nt2hf.py | 28 ++++--- examples/xglm/convert_utils.py | 21 +++-- examples/xglm/tests/test_implementation.py | 89 ++++++++++++++-------- src/nanotron/config/models_config.py | 8 +- src/nanotron/models/gpt3.py | 85 ++++++++++++--------- src/nanotron/trainer.py | 2 +- 7 files changed, 154 insertions(+), 112 deletions(-) diff --git a/examples/xglm/convert_hf2nt.py b/examples/xglm/convert_hf2nt.py index 0efcceca..c18a1ab8 100644 --- a/examples/xglm/convert_hf2nt.py +++ b/examples/xglm/convert_hf2nt.py @@ -4,20 +4,18 @@ torchrun --nproc-per-node=1 convert_hf2nt.py --checkpoint-path=hf_weights --save-path=nanotron_weights """ +import dataclasses import json import warnings -import dataclasses from argparse import ArgumentParser from pathlib import Path +import nanotron import torch -from torch import nn +from nanotron.config.models_config import GPT3Config +from nanotron.models.gpt3 import MLP, CausalSelfAttention, GPT3ForTraining, GPTBlock from transformers.models.xglm.modeling_xglm import XGLMAttention, XGLMConfig, XGLMDecoderLayer, XGLMForCausalLM -import nanotron -from nanotron.models.gpt3 import CausalSelfAttention, GPTBlock, MLP, GPT3ForTraining -from nanotron.config.models_config import GPT3Config -from nanotron.trainer import mark_tied_parameters from examples.xglm.convert_utils import convert_generic, create_nt_model @@ -29,10 +27,12 @@ def convert_config(config: XGLMConfig) -> GPT3Config: # pad_token_id=1, # bos_token_id=0, if config.dropout != config.attention_dropout: - warnings.warn(f"huggingface.dropout = {config.dropout} does not match with " - f"huggingface.attention_dropout = {config.attention_dropout}. 
" - "Nanotron implementation needs these two values to be equal " - "for correct conversion.") + warnings.warn( + f"huggingface.dropout = {config.dropout} does not match with " + f"huggingface.attention_dropout = {config.attention_dropout}. " + "Nanotron implementation needs these two values to be equal " + "for correct conversion." + ) return GPT3Config( activation_function=config.activation_function, attn_pdrop=config.attention_dropout, @@ -113,16 +113,19 @@ def main(hf_path: str, save_path: Path): # Copy weights and save model. print("Copying weights...") convert(model_nt, model_hf) - nanotron.serialize.save_weights(model=model_nt, parallel_context=model_nt.parallel_context, - root_folder=save_path) - with open(save_path/"model_config.json", "w+") as f: + nanotron.serialize.save_weights(model=model_nt, parallel_context=model_nt.parallel_context, root_folder=save_path) + with open(save_path / "model_config.json", "w+") as f: json.dump(dataclasses.asdict(config_nt), f) print(f"Model saved to {save_path}") if __name__ == "__main__": parser = ArgumentParser(description="Convert HF weights to nanotron format") - parser.add_argument("--checkpoint-path", default="facebook/xglm-7.5B", help="Name or path to the huggingface checkpoint") - parser.add_argument("--save-path", type=Path, default="checkpoints/xglm-7.5B", help="Path to save the nanotron model") + parser.add_argument( + "--checkpoint-path", default="facebook/xglm-7.5B", help="Name or path to the huggingface checkpoint" + ) + parser.add_argument( + "--save-path", type=Path, default="checkpoints/xglm-7.5B", help="Path to save the nanotron model" + ) args = parser.parse_args() main(args.checkpoint_path, args.save_path) diff --git a/examples/xglm/convert_nt2hf.py b/examples/xglm/convert_nt2hf.py index 422695a1..81816aa9 100644 --- a/examples/xglm/convert_nt2hf.py +++ b/examples/xglm/convert_nt2hf.py @@ -4,25 +4,28 @@ torchrun --nproc-per-node=1 convert_nt2hf.py --checkpoint-path=nanotron_weights --save-path=hf_weights """ +import warnings from argparse import ArgumentParser -from typing import Optional from pathlib import Path +from typing import Optional import torch +from nanotron.config.models_config import GPT3Config +from nanotron.models.gpt3 import GPT3ForTraining, GPTBlock from transformers import AutoTokenizer from transformers.models.xglm.modeling_xglm import XGLMAttention, XGLMConfig, XGLMDecoderLayer, XGLMForCausalLM -from nanotron.config.models_config import GPT3Config -from nanotron.models.gpt3 import CausalSelfAttention, GPTBlock, MLP, GPT3ForTraining from examples.xglm.convert_utils import convert_generic, create_nt_model def convert_config(config: GPT3Config) -> XGLMConfig: if config.embd_pdrop != config.resid_pdrop: - warnings.warn(f"nanotron.embd_pdrop = {config.embd_pdrop} does not match with " - f"nanotron.resid_pdrop = {config.resid_pdrop}. " - "XGLM implementation needs these two values to be equal " - "for correct conversion.") + warnings.warn( + f"nanotron.embd_pdrop = {config.embd_pdrop} does not match with " + f"nanotron.resid_pdrop = {config.resid_pdrop}. " + "XGLM implementation needs these two values to be equal " + "for correct conversion." 
+ ) if config.layer_norm_epsilon != 1e-5: warnings.warn(f"nanotron.layer_norm_epsilon must be 1e-5, not {config.layer_norm_epsilon}") return XGLMConfig( @@ -70,7 +73,7 @@ def convert_attention(attn_hf: XGLMAttention, attn_nt: XGLMAttention): q_b = torch.cat(qs_b) k_b = torch.cat(ks_b) v_b = torch.cat(vs_b) - + with torch.no_grad(): attn_hf.q_proj.weight.data = q_w.clone() attn_hf.k_proj.weight.data = k_w.clone() @@ -118,9 +121,12 @@ def main(checkpoint_path: Path, save_path: Path, tokenizer_name: Optional[str]): if __name__ == "__main__": parser = ArgumentParser(description="Convert HF weights to nanotron format") - parser.add_argument("--checkpoint-path", type=Path, default="checkpoints/xglm-7.5B", help="Path to the nanotron checkpoint") - parser.add_argument("--save-path", type=Path, default="facebook/xglm-7.5B", help="Path to save the huggingface model") + parser.add_argument( + "--checkpoint-path", type=Path, default="checkpoints/xglm-7.5B", help="Path to the nanotron checkpoint" + ) + parser.add_argument( + "--save-path", type=Path, default="facebook/xglm-7.5B", help="Path to save the huggingface model" + ) parser.add_argument("--tokenizer-name", type=str, default="facebook/xglm-7.5B") args = parser.parse_args() main(args.checkpoint_path, args.save_path, args.tokenizer_name) - diff --git a/examples/xglm/convert_utils.py b/examples/xglm/convert_utils.py index 88a731a1..75d67782 100644 --- a/examples/xglm/convert_utils.py +++ b/examples/xglm/convert_utils.py @@ -2,13 +2,12 @@ from pathlib import Path from typing import Optional -import torch -from torch import nn - import nanotron -from nanotron.models.gpt3 import GPT3ForTraining +import torch from nanotron.config.models_config import GPT3Config +from nanotron.models.gpt3 import GPT3ForTraining from nanotron.trainer import mark_tied_parameters +from torch import nn def convert_generic(module1: nn.Module, module2: nn.Module): @@ -21,11 +20,11 @@ def convert_generic(module1: nn.Module, module2: nn.Module): def create_nt_model( - model_config: Optional[GPT3Config] = None, - device: torch.device = torch.device("cuda"), - dtype: torch.dtype = torch.bfloat16, - checkpoint_path: Optional[Path] = None - ): + model_config: Optional[GPT3Config] = None, + device: torch.device = torch.device("cuda"), + dtype: torch.dtype = torch.bfloat16, + checkpoint_path: Optional[Path] = None, +): if model_config is None: assert checkpoint_path is not None @@ -52,8 +51,6 @@ def create_nt_model( mark_tied_parameters(model=model_nt, parallel_context=parallel_context) if checkpoint_path is not None: - nanotron.serialize.load_weights( - model=model_nt, parallel_context=parallel_context, root_folder=checkpoint_path - ) + nanotron.serialize.load_weights(model=model_nt, parallel_context=parallel_context, root_folder=checkpoint_path) return model_nt diff --git a/examples/xglm/tests/test_implementation.py b/examples/xglm/tests/test_implementation.py index d9dc0f85..a25d7881 100644 --- a/examples/xglm/tests/test_implementation.py +++ b/examples/xglm/tests/test_implementation.py @@ -1,29 +1,31 @@ from typing import Optional +import nanotron import numpy as np -import torch import pytest - -from transformers import XGLMTokenizer -from transformers.models.xglm.modeling_xglm import XGLMConfig, XGLMAttention, XGLMSinusoidalPositionalEmbedding, XGLMDecoderLayer, XGLMForCausalLM - -import nanotron -from nanotron.trainer import mark_tied_parameters +import torch from nanotron.config.models_config import GPT3Config -from nanotron.models.gpt3 import GPT3ForTraining, 
CausalSelfAttention, PositionEmbedding, GPTBlock +from nanotron.models.gpt3 import CausalSelfAttention, GPT3ForTraining, GPTBlock, PositionEmbedding from nanotron.parallel import ParallelContext +from nanotron.trainer import mark_tied_parameters +from transformers import XGLMTokenizer +from transformers.models.xglm.modeling_xglm import ( + XGLMAttention, + XGLMConfig, + XGLMDecoderLayer, + XGLMForCausalLM, + XGLMSinusoidalPositionalEmbedding, +) -from tests.helpers.utils import init_distributed - -from examples.xglm.convert_hf2nt import convert_attention, convert_config, convert_decoder, convert +from examples.xglm.convert_hf2nt import convert, convert_attention, convert_config, convert_decoder +from examples.xglm.convert_nt2hf import convert as convert_nt2hf from examples.xglm.convert_nt2hf import convert_attention as convert_attention_nt2hf from examples.xglm.convert_nt2hf import convert_config as convert_config_nt2hf from examples.xglm.convert_nt2hf import convert_decoder as convert_decoder_nt2hf -from examples.xglm.convert_nt2hf import convert as convert_nt2hf - +from tests.helpers.utils import init_distributed MAX_SEQUENCE_LENGTH = 2048 -TEST_SEQUENCE_LENGTH = 128 # If we test with a very large sequence length, precision errors get more significant independent of the correct implementation. +TEST_SEQUENCE_LENGTH = 128 # If we test with a very large sequence length, precision errors get more significant independent of the correct implementation. BATCH_SIZE = 4 HIDDEN_SIZE = 1024 DTYPE = torch.float64 @@ -45,33 +47,44 @@ vocab_size=256008, sinusoidal_position_embedding=True, position_embedding_offset=2, - use_spda=True + use_spda=True, ) @pytest.fixture def hidden_states() -> torch.Tensor: - return torch.randn(TEST_SEQUENCE_LENGTH, BATCH_SIZE, HIDDEN_SIZE, - dtype=DTYPE) + return torch.randn(TEST_SEQUENCE_LENGTH, BATCH_SIZE, HIDDEN_SIZE, dtype=DTYPE) + @pytest.fixture def input_mask() -> torch.Tensor: return torch.ones(BATCH_SIZE, TEST_SEQUENCE_LENGTH, dtype=torch.bool) + @pytest.fixture def input_ids() -> torch.Tensor: return torch.randint(0, CONFIG.vocab_size, (BATCH_SIZE, TEST_SEQUENCE_LENGTH)) -def almost_close(t1: torch.Tensor, t2: torch.Tensor, atol: float = 1e-5, rtol: float = 0.016, - max_far: float = 0.0, far_atol: float = 0.01): - very_close = torch.abs(t1 - t2) <= atol + rtol*torch.abs(t2) +def almost_close( + t1: torch.Tensor, + t2: torch.Tensor, + atol: float = 1e-5, + rtol: float = 0.016, + max_far: float = 0.0, + far_atol: float = 0.01, +): + very_close = torch.abs(t1 - t2) <= atol + rtol * torch.abs(t2) not_very_close = ~very_close if torch.all(very_close): return - assert torch.mean(not_very_close.float()) <= max_far, f"not very close found: {100*torch.mean(not_very_close.float()):.1f}%" - assert torch.all(torch.abs(t1[not_very_close] - t2[not_very_close]) <= far_atol), f"Worse deviation found: {torch.max(torch.abs(t1 - t2)):.4f}" + assert ( + torch.mean(not_very_close.float()) <= max_far + ), f"not very close found: {100*torch.mean(not_very_close.float()):.1f}%" + assert torch.all( + torch.abs(t1[not_very_close] - t2[not_very_close]) <= far_atol + ), f"Worse deviation found: {torch.max(torch.abs(t1 - t2)):.4f}" def attention_mask() -> torch.Tensor: @@ -81,10 +94,12 @@ def attention_mask() -> torch.Tensor: mask = mask.repeat(BATCH_SIZE, 1, 1).unsqueeze(1) return mask + ## # FROM HERE DOWN (until next comment), all tests are hf->nt ## + def _test_attention(parallel_context: ParallelContext, hidden_states: torch.Tensor, sequence_mask: torch.Tensor): hidden_states = 
hidden_states.cuda() sequence_mask = sequence_mask.cuda() @@ -118,6 +133,7 @@ def _test_position_embeddings(parallel_context: ParallelContext): assert out_nt.size() == out_hf.size(), f"{out_nt.size()}, {out_hf.size()}" torch.testing.assert_close(out_nt, out_hf) + def test_position_embeddings(): init_distributed(tp=1, dp=1, pp=1)(_test_position_embeddings)() @@ -140,15 +156,21 @@ def _test_decoder(parallel_context: ParallelContext, hidden_states: torch.Tensor out_hf = decoder_hf(hidden_states.permute(1, 0, 2), attention_mask=attention_mask())[0].permute(1, 0, 2) assert out_nt.size() == out_hf.size(), f"{out_nt.size()}, {out_hf.size()}" - torch.testing.assert_close(out_nt.bfloat16(), out_hf.bfloat16()) # We cast to bf16 to get more relaxed constraints. + torch.testing.assert_close( + out_nt.bfloat16(), out_hf.bfloat16() + ) # We cast to bf16 to get more relaxed constraints. def test_decoder(hidden_states: torch.Tensor, input_mask: torch.Tensor): init_distributed(tp=1, dp=1, pp=1)(_test_decoder)(hidden_states=hidden_states, sequence_mask=input_mask) -def _test_model(model_hf: Optional[XGLMForCausalLM], parallel_context: ParallelContext, - input_ids: torch.Tensor, input_mask: torch.Tensor): +def _test_model( + model_hf: Optional[XGLMForCausalLM], + parallel_context: ParallelContext, + input_ids: torch.Tensor, + input_mask: torch.Tensor, +): random_states = nanotron.random.RandomStates({"tp_synced": nanotron.random.get_current_random_state()}) input_ids = input_ids.cuda() @@ -182,7 +204,7 @@ def _test_model(model_hf: Optional[XGLMForCausalLM], parallel_context: ParallelC ).eval() convert(model_nt, model_hf) - print("Parameter count (M):", sum(map(torch.numel, model_hf.parameters()))/1000/1000) + print("Parameter count (M):", sum(map(torch.numel, model_hf.parameters())) / 1000 / 1000) # Get outputs and assert. with torch.no_grad(): @@ -209,8 +231,9 @@ def _test_xglm500M(parallel_context: ParallelContext): tok = XGLMTokenizer.from_pretrained("facebook/xglm-564M") tokenized = tok(TEXT) model_hf = XGLMForCausalLM.from_pretrained("facebook/xglm-564M") - out_nt, out_hf = _test_model(model_hf, parallel_context, - torch.tensor([tokenized["input_ids"]]), torch.tensor([tokenized["attention_mask"]])) + out_nt, out_hf = _test_model( + model_hf, parallel_context, torch.tensor([tokenized["input_ids"]]), torch.tensor([tokenized["attention_mask"]]) + ) almost_close(out_nt, out_hf, max_far=0.1, far_atol=0.05) @@ -222,8 +245,9 @@ def _test_xglm7B(parallel_context: ParallelContext): tok = XGLMTokenizer.from_pretrained("facebook/xglm-7.5B") tokenized = tok(TEXT) model_hf = XGLMForCausalLM.from_pretrained("facebook/xglm-7.5B") - out_nt, out_hf = _test_model(model_hf, parallel_context, - torch.tensor([tokenized["input_ids"]]), torch.tensor([tokenized["attention_mask"]])) + out_nt, out_hf = _test_model( + model_hf, parallel_context, torch.tensor([tokenized["input_ids"]]), torch.tensor([tokenized["attention_mask"]]) + ) almost_close(out_nt, out_hf, max_far=0.15, far_atol=0.1) @@ -235,6 +259,7 @@ def test_xglm7B(): # From here down we test nt->hf converters ## + def _test_nt2hf_attention(parallel_context: ParallelContext, hidden_states: torch.Tensor, sequence_mask: torch.Tensor): hidden_states = hidden_states.cuda() sequence_mask = sequence_mask.cuda() @@ -269,7 +294,9 @@ def _test_nt2hf_decoder(parallel_context: ParallelContext, hidden_states: torch. 
out_hf = decoder_hf(hidden_states.permute(1, 0, 2), attention_mask=attention_mask())[0].permute(1, 0, 2) assert out_nt.size() == out_hf.size(), f"{out_nt.size()}, {out_hf.size()}" - torch.testing.assert_close(out_nt.bfloat16(), out_hf.bfloat16()) # We cast to bf16 to get more relaxed constraints. + torch.testing.assert_close( + out_nt.bfloat16(), out_hf.bfloat16() + ) # We cast to bf16 to get more relaxed constraints. def test_nt2hf_decoder(hidden_states: torch.Tensor, input_mask: torch.Tensor): diff --git a/src/nanotron/config/models_config.py b/src/nanotron/config/models_config.py index 80f956d1..37593a54 100644 --- a/src/nanotron/config/models_config.py +++ b/src/nanotron/config/models_config.py @@ -1,6 +1,6 @@ from dataclasses import dataclass, field from pathlib import Path -from typing import Any, List, Optional, Union +from typing import List, Optional @dataclass @@ -167,14 +167,12 @@ def as_starcoder2(self) -> Starcoder2Config: if "_is_using_mup" in config: del config["_is_using_mup"] return Starcoder2Config( - grouped_query=True, - num_kv_heads=self.num_attention_heads, - use_rotary_embeddings=False, - **config + grouped_query=True, num_kv_heads=self.num_attention_heads, use_rotary_embeddings=False, **config ) @property def n_inner(self): return self.intermediate_size + NanotronConfigs = LlamaConfig | Starcoder2Config | GPT3Config diff --git a/src/nanotron/models/gpt3.py b/src/nanotron/models/gpt3.py index 7d4e6f82..25e5f78b 100644 --- a/src/nanotron/models/gpt3.py +++ b/src/nanotron/models/gpt3.py @@ -1,37 +1,40 @@ """PyTorch GPT-3 model.""" import math -from typing import Optional from contextlib import contextmanager +from typing import Optional import torch from torch import nn from torch.nn import functional as F from nanotron import distributed as dist -from nanotron.parallel import ParallelContext -from nanotron.config import Config, GPT3Config, ParallelismArgs, Starcoder2Config +from nanotron.config import GPT3Config, ParallelismArgs, Starcoder2Config from nanotron.generation.generate_store import AttachableStore from nanotron.models import starcoder2 -from nanotron.nn.layer_norm import TritonLayerNorm from nanotron.models.starcoder2 import MLP as Starcoder2MLP -from nanotron.parallel.pipeline_parallel.block import PipelineBlock +from nanotron.models.starcoder2 import CausalSelfGQA, GPTModel, Starcoder2ForTraining, dropout_add_fused_train from nanotron.models.starcoder2 import CoreAttention as Starcoder2CoreAttention from nanotron.models.starcoder2 import GPTBlock as Starcoder2GPTBlock -from nanotron.models.starcoder2 import CausalSelfGQA, Starcoder2ForTraining, GPTModel, dropout_add_fused_train -from nanotron.random import RandomStates, branch_random_state +from nanotron.nn.layer_norm import TritonLayerNorm +from nanotron.parallel import ParallelContext +from nanotron.parallel.pipeline_parallel.block import PipelineBlock from nanotron.parallel.pipeline_parallel.tensor_pointer import TensorPointer from nanotron.parallel.tensor_parallel.enum import TensorParallelLinearMode from nanotron.parallel.tensor_parallel.nn import TensorParallelEmbedding -from nanotron.parallel.tied_parameters import tie_parameters +from nanotron.random import RandomStates, branch_random_state @contextmanager def replace_coreattention(gpt3config: GPT3Config): orig = starcoder2.CoreAttention try: - def create_core_attention(config: Starcoder2Config, parallel_config: Optional[ParallelismArgs], layer_idx: int): + + def create_core_attention( + config: Starcoder2Config, parallel_config: 
Optional[ParallelismArgs], layer_idx: int + ): return CoreAttention(gpt3config, parallel_config, layer_idx) + starcoder2.CoreAttention = create_core_attention yield finally: @@ -42,6 +45,7 @@ def create_core_attention(config: Starcoder2Config, parallel_config: Optional[Pa def replace_decoder(gpt3config: GPT3Config): orig = starcoder2.PipelineBlock try: + def create_pp_block(module_builder, module_kwargs, **kwargs): if module_builder is Starcoder2GPTBlock: # Starcoder2's GPT module is trying to instantiate a Starcoder2 GPTBlock. @@ -62,9 +66,15 @@ def create_pp_block(module_builder, module_kwargs, **kwargs): def replace_gpt3model(gpt3config: GPT3Config): orig = starcoder2.GPTModel try: - def create_gptmodel(config: Starcoder2Config, parallel_context: ParallelContext, - parallel_config: Optional[ParallelismArgs], random_states: RandomStates): + + def create_gptmodel( + config: Starcoder2Config, + parallel_context: ParallelContext, + parallel_config: Optional[ParallelismArgs], + random_states: RandomStates, + ): return GPT3Model(gpt3config, parallel_context, parallel_config, random_states) + starcoder2.GPTModel = create_gptmodel yield finally: @@ -76,7 +86,8 @@ def __init__(self, config: GPT3Config, parallel_config: Optional[ParallelismArgs super().__init__(config.as_starcoder2(), parallel_config, layer_idx) self.gpt3config = config - def forward(self, + def forward( + self, query_states: torch.Tensor, # [batch_size * q_length, q_heads, inner_dim] key_states: torch.Tensor, # [batch_size * kv_length, kv_heads, inner_dim] value_states: torch.Tensor, # [batch_size * kv_length, kv_heads, inner_dim] @@ -101,7 +112,7 @@ def forward(self, is_causal=True, ) # [batch, q_length, q_heads, head_dim] attention_output = attention_output.permute(0, 2, 1, 3) - attention_output = attention_output.reshape(batch_size*q_length, q_heads, head_dim) + attention_output = attention_output.reshape(batch_size * q_length, q_heads, head_dim) return attention_output.contiguous() assert query_states.dtype in {torch.bfloat16, torch.float16} @@ -127,7 +138,7 @@ def __init__( config: GPT3Config, parallel_config: Optional[ParallelismArgs], tp_pg: dist.ProcessGroup, - random_states: RandomStates + random_states: RandomStates, ): super().__init__(config.as_starcoder2(), parallel_config, tp_pg) self.dropout = nn.Dropout(p=config.act_pdrop) @@ -154,14 +165,11 @@ def __init__( random_states: RandomStates, layer_idx: int, ): - #print("New gpt block created :D") + # print("New gpt block created :D") super(GPTBlock, self).__init__() self.ln_1 = TritonLayerNorm(config.hidden_size, eps=config.layer_norm_epsilon) self.attn = CausalSelfAttention( - config=config, - parallel_config=parallel_config, - tp_pg=tp_pg, - layer_idx=layer_idx + config=config, parallel_config=parallel_config, tp_pg=tp_pg, layer_idx=layer_idx ) self.attn_dropout = config.attn_pdrop @@ -180,10 +188,10 @@ def forward( residual = hidden_states hidden_states = self.ln_1(hidden_states) - #hidden_states = torch.arange(hidden_states.numel()).to(hidden_states.device).to(hidden_states.dtype).view(hidden_states.size()) + # hidden_states = torch.arange(hidden_states.numel()).to(hidden_states.device).to(hidden_states.dtype).view(hidden_states.size()) output = self.attn(hidden_states=hidden_states, sequence_mask=sequence_mask) hidden_states = output["hidden_states"] - #return {"hidden_states": hidden_states, "sequence_mask": sequence_mask} + # return {"hidden_states": hidden_states, "sequence_mask": sequence_mask} if self.training: with branch_random_state( @@ -221,7 +229,9 @@ 
def __init__(self, tp_pg: dist.ProcessGroup, config: GPT3Config, parallel_config if (config.max_position_embeddings + config.position_embedding_offset) % tp_pg.size() == 0: dummy_pos = 0 else: - dummy_pos = tp_pg.size() - ((config.max_position_embeddings + config.position_embedding_offset) % tp_pg.size()) + dummy_pos = tp_pg.size() - ( + (config.max_position_embeddings + config.position_embedding_offset) % tp_pg.size() + ) true_max_size = config.max_position_embeddings + config.position_embedding_offset + dummy_pos if config.sinusoidal_position_embedding: @@ -234,7 +244,7 @@ def __init__(self, tp_pg: dist.ProcessGroup, config: GPT3Config, parallel_config embedding_dim=config.hidden_size, pg=tp_pg, mode=parallel_config.tp_mode if parallel_config is not None else TensorParallelLinearMode.ALL_REDUCE, - _weight=weight + _weight=weight, ) self.pg = tp_pg @@ -251,32 +261,31 @@ def forward(self, position_ids: torch.Tensor): # [batch_size, seq_length] position_embeds = self.position_embedding(position_ids + self.config.position_embedding_offset) return {"position_embeds": position_embeds} - def _make_weights(self, tp_pg: dist.ProcessGroup, num_embeddings: int, - embedding_dim: int) -> torch.Tensor: + def _make_weights(self, tp_pg: dist.ProcessGroup, num_embeddings: int, embedding_dim: int) -> torch.Tensor: rank = dist.get_rank(group=tp_pg) tp_size = tp_pg.size() assert 0 <= rank < tp_size assert num_embeddings % tp_size == 0 assert embedding_dim % 2 == 0 - block_size = num_embeddings//tp_size + block_size = num_embeddings // tp_size - half_dim = embedding_dim//2 - emb = math.log(10_000)/(half_dim - 1) + half_dim = embedding_dim // 2 + emb = math.log(10_000) / (half_dim - 1) emb = torch.exp(torch.arange(half_dim, dtype=torch.int64).float() * -emb) - emb = (rank*block_size + torch.arange(block_size, dtype=torch.int64).float().unsqueeze(1)) * emb.unsqueeze(0) + emb = (rank * block_size + torch.arange(block_size, dtype=torch.int64).float().unsqueeze(1)) * emb.unsqueeze(0) emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1).view(block_size, embedding_dim) return emb class GPT3Model(GPTModel): def __init__( - self, - config: GPT3Config, - parallel_context: ParallelContext, - parallel_config: Optional[ParallelismArgs], - random_states: RandomStates, - ): + self, + config: GPT3Config, + parallel_context: ParallelContext, + parallel_config: Optional[ParallelismArgs], + random_states: RandomStates, + ): with replace_decoder(config): super().__init__(config.as_starcoder2(), parallel_context, parallel_config, random_states) @@ -300,7 +309,9 @@ def forward( ): # all tensors are optional as most ranks don't need anything from the dataloader. - input_embeds = self.token_embeddings(input_ids=input_ids, input_mask=input_mask)["input_embeds"]*self.embed_scale + input_embeds = ( + self.token_embeddings(input_ids=input_ids, input_mask=input_mask)["input_embeds"] * self.embed_scale + ) # TODO: position_ids could be cached. 
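One way to address the TODO above is to precompute the position ids once as a non-persistent buffer and slice it per forward; a hedged sketch of that idea (names are illustrative, not part of the patch):

```python
import torch
from torch import nn


class CachedPositionIds(nn.Module):
    """Illustrative: build [0, 1, ..., max_len - 1] once and reuse it on every step."""

    def __init__(self, max_position_embeddings: int):
        super().__init__()
        ids = torch.arange(max_position_embeddings, dtype=torch.long)
        self.register_buffer("ids", ids, persistent=False)  # follows .to()/.cuda(), not checkpointed

    def forward(self, input_ids: torch.Tensor) -> torch.Tensor:
        batch_size, seq_length = input_ids.shape
        # Same result as arange(seq_length).repeat(batch_size).view(batch_size, seq_length),
        # but without rebuilding the tensor on every call.
        return self.ids[:seq_length].unsqueeze(0).expand(batch_size, seq_length)
```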
position_ids = torch.arange(input_ids.size(1), device="cuda").repeat(input_ids.size(0)).view(*input_ids.size()) position_embeds = self.position_embeddings(position_ids=position_ids)["position_embeds"] @@ -314,7 +325,7 @@ def forward( hidden_encoder_states = {"hidden_states": hidden_states, "sequence_mask": input_mask} for encoder_block in self.decoder: hidden_encoder_states = encoder_block(**hidden_encoder_states) - #return hidden_encoder_states["hidden_states"] + # return hidden_encoder_states["hidden_states"] hidden_states = self.final_layer_norm(input=hidden_encoder_states["hidden_states"])["hidden_states"] diff --git a/src/nanotron/trainer.py b/src/nanotron/trainer.py index f01caa3e..bc81e326 100644 --- a/src/nanotron/trainer.py +++ b/src/nanotron/trainer.py @@ -56,9 +56,9 @@ ) from nanotron.models import NanotronModel, build_model from nanotron.models.base import check_model_has_grad +from nanotron.models.gpt3 import GPT3ForTraining from nanotron.models.llama import LlamaForTraining, RotaryEmbedding from nanotron.models.starcoder2 import Starcoder2ForTraining -from nanotron.models.gpt3 import GPT3ForTraining from nanotron.optim.clip_grads import clip_grad_norm from nanotron.parallel import ParallelContext from nanotron.parallel.data_parallel.utils import sync_gradients_across_dp From 0485fd64dc4cf8b68eaa963bfd586de4e6c4ac67 Mon Sep 17 00:00:00 2001 From: tj-solergibert Date: Tue, 16 Jul 2024 11:45:28 +0000 Subject: [PATCH 06/51] Added MultilingualNanoset Config --- src/nanotron/config/config.py | 21 +++++++++++++++++++++ 1 file changed, 21 insertions(+) diff --git a/src/nanotron/config/config.py b/src/nanotron/config/config.py index 05b49955..bfd20227 100644 --- a/src/nanotron/config/config.py +++ b/src/nanotron/config/config.py @@ -107,6 +107,27 @@ def __post_init__(self): self.dataset_weights = list(tmp_dataset_folder.values()) +@dataclass +class MultilingualNanosetDatasetsArgs: + dataset_folder: Union[str, dict, List[str]] + dataset_tokens: List[ + int + ] # Set token for each language previously defined. 
We use a List and not a dict because this way we support specifyng weights (dict) or not (List[str]) + + def __post_init__(self): + if isinstance(self.dataset_folder, str): # Case 1: 1 Dataset file + self.dataset_folder = [self.dataset_folder] + self.dataset_weights = [1] + elif isinstance(self.dataset_folder, List): # Case 2: > 1 Dataset file + self.dataset_weights = None # Set to None so we consume all the samples randomly + elif isinstance(self.dataset_folder, dict): # Case 3: dict with > 1 dataset_folder and weights + tmp_dataset_folder = self.dataset_folder.copy() + self.dataset_folder = list(tmp_dataset_folder.keys()) + self.dataset_weights = list(tmp_dataset_folder.values()) + + assert len(self.dataset_folder) == len(self.dataset_tokens) + + @dataclass class DataArgs: """Arguments related to the data and data files processing""" From 539832ade4914ac92bc8c66b55dbf031f3195ec6 Mon Sep 17 00:00:00 2001 From: tj-solergibert Date: Tue, 16 Jul 2024 11:48:51 +0000 Subject: [PATCH 07/51] Added MultilingualNanoset --- run_train.py | 125 +++++++++++- src/nanotron/data/multilingual_nanoset.py | 221 ++++++++++++++++++++++ 2 files changed, 343 insertions(+), 3 deletions(-) create mode 100644 src/nanotron/data/multilingual_nanoset.py diff --git a/run_train.py b/run_train.py index 021d955d..649784ca 100644 --- a/run_train.py +++ b/run_train.py @@ -12,7 +12,13 @@ import numpy as np from nanotron import logging -from nanotron.config import DataArgs, DatasetStageArgs, NanosetDatasetsArgs, PretrainDatasetsArgs +from nanotron.config import ( + DataArgs, + DatasetStageArgs, + MultilingualNanosetDatasetsArgs, + NanosetDatasetsArgs, + PretrainDatasetsArgs, +) from nanotron.data.dataloader_builder import build_nanoset_dataloader from nanotron.dataloader import ( clm_process, @@ -171,6 +177,40 @@ def get_dataloader_from_data_stage( dataloader_drop_last=True, ) + return train_dataloader + # Case 4: MultilingualNanosets + elif isinstance(data.dataset, MultilingualNanosetDatasetsArgs): + # Get tokenizer cardinality + tokenizer = AutoTokenizer.from_pretrained(trainer.config.tokenizer.tokenizer_name_or_path) + token_size = 4 if len(tokenizer) > np.iinfo(np.uint16).max + 1 else 2 + del tokenizer + # Create Nanoset + from nanotron.data.multilingual_nanoset import MultilingualNanoset + + with main_rank_first(trainer.parallel_context.world_pg): + train_dataset = MultilingualNanoset( + dataset_folders=data.dataset.dataset_folder, + dataset_weights=data.dataset.dataset_weights, + sequence_length=trainer.sequence_length, + token_size=token_size, + train_split_num_samples=trainer.config.tokens.train_steps * trainer.global_batch_size, + valid_split_num_samples=trainer.config.tokens.limit_val_batches * trainer.global_batch_size, + random_seed=data.seed, + ) + + # Prepare dataloader + train_dataloader = build_nanoset_dataloader( + train_dataset, + trainer.sequence_length, + parallel_context=trainer.parallel_context, + input_pp_rank=input_pp_rank, + output_pp_rank=output_pp_rank, + micro_batch_size=trainer.micro_batch_size, + consumed_train_samples=consumed_train_samples, + dataloader_num_workers=data.num_loading_workers, + dataloader_drop_last=True, + ) + return train_dataloader else: raise ValueError(f"Unhandled case of `self.config.data.dataset`. 
Got: {data.dataset}") @@ -178,6 +218,57 @@ def get_dataloader_from_data_stage( return dataloader +def get_valid_dataloader_from_data_stage( + trainer: DistributedTrainer, + data: DataArgs, + valid_split_num_samples: int, + # consumed_train_samples: int, We will never use this because in each valid iteration we consume all the samples +): + + # First, we need to know which ranks to feed the dataloader to + input_pp_rank, output_pp_rank = get_input_output_pp_ranks(model=trainer.model) + + # Only support Validation with MultilingualNanosets + if isinstance(data.dataset, NanosetDatasetsArgs): + # Get tokenizer cardinality + tokenizer = AutoTokenizer.from_pretrained(trainer.config.tokenizer.tokenizer_name_or_path) + token_size = 4 if len(tokenizer) > np.iinfo(np.uint16).max + 1 else 2 + del tokenizer + # Create Multilingual Nanoset + from nanotron.data.multilingual_nanoset import MultilingualNanoset + + with main_rank_first(trainer.parallel_context.world_pg): + valid_dataset = MultilingualNanoset( + dataset_folders=data.dataset.dataset_folder, + dataset_weights=data.dataset.dataset_weights, + sequence_length=trainer.sequence_length, + token_size=token_size, + train_split_num_samples=trainer.config.tokens.train_steps * trainer.global_batch_size, + valid_split_num_samples=valid_split_num_samples, + is_valid=True, + random_seed=data.seed, + ) + + # Prepare dataloader + valid_dataloader = build_nanoset_dataloader( + valid_dataset, + trainer.sequence_length, + parallel_context=trainer.parallel_context, + input_pp_rank=input_pp_rank, + output_pp_rank=output_pp_rank, + micro_batch_size=trainer.micro_batch_size, + consumed_train_samples=0, + dataloader_num_workers=data.num_loading_workers, + dataloader_drop_last=True, + ) + + return valid_dataloader + else: + raise ValueError( + f"Unhandled case of `self.config.data.dataset`. Got: {data.dataset}. 
Validation is currently just supported for MultilingualNanoset" + ) + + def get_dataloader(trainer: DistributedTrainer) -> Dict[str, DataLoader]: dataloaders = {} @@ -219,6 +310,33 @@ def get_dataloader(trainer: DistributedTrainer) -> Dict[str, DataLoader]: return dataloaders +def get_valid_dataloader(trainer: DistributedTrainer) -> Dict[str, DataLoader]: + dataloaders = {} + + for stage_idx, stage in enumerate(trainer.config.data_stages): + # NOTE: we only create the dataloader for the first stage, + # then we lazy initialize the dataloader for the other stages + stage = cast(DatasetStageArgs, stage) + valid_split_num_samples = trainer.config.tokens.limit_val_batches * trainer.global_batch_size + + log_rank( + f"[Training Plan] Stage {stage.name} has {valid_split_num_samples} samples in the validation set", + logger=logger, + level=logging.INFO, + rank=0, + ) + + dataloader = ( + get_valid_dataloader_from_data_stage(trainer, stage.data, valid_split_num_samples=valid_split_num_samples) + if stage_idx == 0 + else lambda stage=stage: get_dataloader_from_data_stage( + trainer, stage.data, valid_split_num_samples=valid_split_num_samples + ) + ) + dataloaders[stage.name] = dataloader + return dataloaders + + def get_args(): parser = argparse.ArgumentParser() parser.add_argument("--config-file", type=str, required=True, help="Path to the YAML or python config file") @@ -231,7 +349,8 @@ def get_args(): # Load trainer and data trainer = DistributedTrainer(config_file) - dataloader = get_dataloader(trainer) + train_dataloader = get_dataloader(trainer) + valid_dataloader = get_valid_dataloader(trainer) # Train - trainer.train(dataloader) + trainer.train(train_dataloader, valid_dataloader) diff --git a/src/nanotron/data/multilingual_nanoset.py b/src/nanotron/data/multilingual_nanoset.py new file mode 100644 index 00000000..40e06b87 --- /dev/null +++ b/src/nanotron/data/multilingual_nanoset.py @@ -0,0 +1,221 @@ +import os +import warnings +from typing import Dict, List, Tuple, Union + +import numpy as np +import torch +from datatrove.utils.dataset import DatatroveFolderDataset +from nanotron import logging +from nanotron.data.utils import count_dataset_indexes, normalize +from nanotron.logging import log_rank +from numba import jit + +logger = logging.get_logger(__name__) + + +class MultilingualNanoset(torch.utils.data.Dataset): + """ + The Nanoset dataset + + Args: + dataset_folders (List[str]): List of folders with tokenized datasets + dataset_weights (Union[List[float], None]): List with the weights for weighted datasets. If None, consume all samples from all datasets without weighting. Weights are normalized in __init__ + sequence_length (int): Sequence length of the built samples + token_size (int): Number of bytes for the tokens stored in the processed dataset files. 2 for vocab sizes < 65535, 4 otherwise + train_split_num_samples (int): Number of samples the dataset needs. It's the training steps * global batch size + """ + + def __init__( + self, + dataset_folders: List[str], + sequence_length: int, + token_size: int, + train_split_num_samples: int, + valid_split_num_samples: int, + is_valid: bool = False, + dataset_weights: Union[List[float], None] = None, + random_seed: int = 1234, + ) -> None: + + # Checks + if isinstance(dataset_folders, str): + warnings.warn("dataset_folders should be of type List[str] but str was provided. 
Converting to List[str]") + dataset_folders = [dataset_folders] + + # Init + self.dataset_folders = dataset_folders + self.sequence_length = sequence_length + self.token_size = token_size + self.train_split_num_samples = train_split_num_samples + self.valid_split_num_samples = valid_split_num_samples + self.is_valid = is_valid + self.random_seed = random_seed + self.datatrove_datasets = [] + for dataset_folder in self.dataset_folders: + self.datatrove_datasets.append( + DatatroveFolderDataset( + folder_path=dataset_folder, + filename_pattern=os.path.join(dataset_folder, "*.ds"), + seq_len=sequence_length, + recursive=False, + token_size=token_size, + shuffle=True, + ) + ) + + # Build Nanoset Index + ## To build the index we need the length of each dataset + self.dataset_lengths = [len(datatrove_dataset) for datatrove_dataset in self.datatrove_datasets] + ## Set dataset weights + if ( + dataset_weights is None + ): # Case of training with > 1 datasets without weighting them: Consume both datasets entirely on each epoch + self.dataset_weights = normalize(self.dataset_lengths) + else: + self.dataset_weights = normalize(dataset_weights) + assert len(dataset_folders) == len( + self.dataset_weights + ), f"Specified {len(self.dataset_weights)} weights but {len(dataset_folders)} datasets were provided." + ## Build dataset index and dataset sample index + ### Split dataset_lengths into train_dataset_lenghts & valid_dataset_lenghts + self.valid_dataset_lenghts = self.dataset_weights * valid_split_num_samples + # Assert that we have sufficient samples to build the valid split + for ds_index in range(len(self.dataset_lengths)): + assert ( + self.valid_dataset_lenghts[ds_index] > self.dataset_lengths[ds_index] + ), f"Trying to build validation dataset with {self.valid_dataset_lenghts[ds_index]} samples but {dataset_folders[ds_index]} just have {self.dataset_lengths[ds_index]} samples." 
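+        # Each dataset contributes its first train_dataset_lenghts[i] samples to the training split and
+        # its last valid_dataset_lenghts[i] samples to the validation split (the validation samples start
+        # at offset train_dataset_lenghts[i]; see split_dataset_offsets below).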
+ self.train_dataset_lenghts = [ + a - b for a, b in zip(self.dataset_lengths, self.valid_dataset_lenghts) + ] # Subtract the valid samples from the training dataset + + if is_valid: # Valid MultilingualNanoset + self.split_num_samples = valid_split_num_samples + self.split_samples_per_epoch = valid_split_num_samples + self.num_epochs = 1 + self.split_dataset_lenghts = self.valid_dataset_lenghts + self.split_dataset_offsets = self.train_dataset_lenghts + + else: # Train MultilingualNanoset + self.split_num_samples = train_split_num_samples + self.split_samples_per_epoch = sum(self.train_dataset_lenghts) + self.num_epochs = int(self.split_num_samples / self.split_samples_per_epoch) + 1 + self.split_dataset_lenghts = self.train_dataset_lenghts + self.split_dataset_offsets = [ + 0 for _ in range(len(self.dataset_lengths)) + ] # For training there is NO offset + + self.dataset_index, self.dataset_sample_index = self.build_nanoset_index() + + self.print_nanoset_info() + + def __len__(self) -> int: + """ + Returns: + int: The number of samples of the Nanoset + """ + + return len(self.dataset_index) + + def __getitem__(self, idx: int) -> Dict[str, np.ndarray]: + """ + Returns sequence_length + 1 tokens from the memmap dataset + + Args: + idx (int): The index into the dataset + + Returns: + Dict[str, torch.LongTensor]: The input ids wrapped in a dictionary + """ + dataset = self.dataset_index[idx] + dataset_sample = self.dataset_sample_index[idx] + + return self.datatrove_datasets[dataset][dataset_sample] + + def build_nanoset_index(self) -> np.ndarray: + """ + Build dataset index and dataset sample index + """ + # Build the dataset indexes for 1 epoch + dataset_index, dataset_sample_index = build_nanoset_index_helper( + n_samples=self.split_samples_per_epoch, + weights=self.dataset_weights, + dataset_sizes=self.split_dataset_lengths, + offsets=self.split_dataset_offsets, + ) + # Shuffle the indexes the same way + numpy_random_state = np.random.RandomState(self.random_seed) + numpy_random_state.shuffle(dataset_index) + numpy_random_state = np.random.RandomState(self.random_seed) + numpy_random_state.shuffle(dataset_sample_index) + # Concatenate num_epochs the shuffled indexes + dataset_index = np.concatenate([dataset_index for _ in range(self.num_epochs)]) + dataset_sample_index = np.concatenate([dataset_sample_index for _ in range(self.num_epochs)]) + # Just keep the necessary samples + dataset_index = dataset_index[: self.split_num_samples] + dataset_sample_index = dataset_sample_index[: self.split_num_samples] + + return dataset_index, dataset_sample_index + + def print_nanoset_info(self): + + log_rank( + f"> [{'Validation' if self.is_valid else 'Training'} dataset] Total number of samples: {len(self)}", + logger=logger, + level=logging.INFO, + rank=0, + ) + log_rank( + f"> [{'Validation' if self.is_valid else 'Training'} dataset] Total number of tokens: {len(self) * self.sequence_length}", + logger=logger, + level=logging.INFO, + rank=0, + ) + + # Print samples from each dataset + weight + dataset_sample_count = count_dataset_indexes(self.dataset_index, len(self.dataset_folders)) + for index, sample_count in enumerate(dataset_sample_count): + log_rank( + f"> Total number of {'validation' if self.is_valid else 'training'} samples from the {self.dataset_folders[index]} dataset: {sample_count} ({round(normalize(dataset_sample_count).tolist()[index], 2)})", + logger=logger, + level=logging.INFO, + rank=0, + ) + + +@jit(nopython=True, cache=True) +def build_nanoset_index_helper( + n_samples: int, 
weights: np.ndarray, dataset_sizes: List[int], offsets: List[int] +) -> Tuple[np.ndarray, np.ndarray]: + """ + Given multiple datasets and a weighting array, build samples indexes + such that it follows those weights. + For train and valid splits we split each dataset_folder in train (first part) and valid splits. We set the offsets to the train lengths + for generating the valid split + """ + # Create empty arrays for dataset indices and dataset sample indices + dataset_index = np.empty((n_samples,), dtype="uint") + dataset_sample_index = np.empty((n_samples,), dtype="long") # Supports dataset with up to 2**64 samples + + # Initialize buffer for number of samples used for each dataset + current_samples = np.zeros((len(weights),), dtype="long") + + # Iterate over all samples + for sample_idx in range(n_samples): + + # Convert sample index to float for comparison against weights + sample_idx_float = max(sample_idx, 1.0) + + # Find the dataset with the highest error + errors = weights * sample_idx_float - current_samples + max_error_index = np.argmax(errors) + + # Assign the dataset index and update the sample index + dataset_index[sample_idx] = max_error_index + dataset_sample_index[sample_idx] = ( + current_samples[max_error_index] % dataset_sizes[max_error_index] + ) + offsets[max_error_index] + + # Update the total samples for the selected dataset + current_samples[max_error_index] += 1 + + return dataset_index, dataset_sample_index From d9f06703d49762b261075467981a066bc01f9249 Mon Sep 17 00:00:00 2001 From: tj-solergibert Date: Tue, 16 Jul 2024 12:25:17 +0000 Subject: [PATCH 08/51] Added Language token --- examples/config_multilingual_nanoset.yaml | 120 ++++++++++++++++++++++ src/nanotron/data/multilingual_nanoset.py | 7 +- 2 files changed, 126 insertions(+), 1 deletion(-) create mode 100644 examples/config_multilingual_nanoset.yaml diff --git a/examples/config_multilingual_nanoset.yaml b/examples/config_multilingual_nanoset.yaml new file mode 100644 index 00000000..00ae6570 --- /dev/null +++ b/examples/config_multilingual_nanoset.yaml @@ -0,0 +1,120 @@ +checkpoints: + checkpoint_interval: 1000 + checkpoints_path: checkpoints/ + checkpoints_path_is_shared_file_system: false + resume_checkpoint_path: null + save_initial_state: false +data_stages: +- data: + dataset: + dataset_folder: datasets/c4-es/tokenized + dataset_tokens: + - 15 + num_loading_workers: 1 + seed: 42 + name: General purpose training (Single dataset) + start_training_step: 1 +- data: + dataset: + dataset_folder: + - datasets/SlimPajama-6B/tokenized + - datasets/c4-es/tokenized + dataset_tokens: + - 16 + - 15 + num_loading_workers: 1 + seed: 42 + name: Second purpose training (> 1 dataset) + start_training_step: 15 +- data: + dataset: + dataset_folder: + datasets/SlimPajama-6B/tokenized: 0.8 + datasets/c4-es/tokenized: 0.2 + dataset_tokens: + - 16 + - 15 + num_loading_workers: 1 + seed: 42 + name: Third purpose training (Blended dataset) + start_training_step: 25 +general: + benchmark_csv_path: null + consumed_train_samples: null + ignore_sanity_checks: true + project: Nanoset + run: llama + seed: 42 + step: null +lighteval: null +logging: + iteration_step_info_interval: 1 + log_level: info + log_level_replica: info +model: + ddp_bucket_cap_mb: 25 + dtype: bfloat16 + init_method: + std: 0.025 + make_vocab_size_divisible_by: 1 + model_config: + bos_token_id: 1 + eos_token_id: 2 + hidden_act: silu + hidden_size: 4096 + initializer_range: 0.02 + intermediate_size: 11008 + is_llama_config: true + max_position_embeddings: 4096 
+ num_hidden_layers: 32 + num_attention_heads: 32 + num_key_value_heads: 8 + pad_token_id: null + pretraining_tp: 1 + rope_interleaved: false + rope_theta: 500000.0 + rms_norm_eps: 1.0e-06 + rope_scaling: null + tie_word_embeddings: true + use_cache: true + vocab_size: 128256 +optimizer: + accumulate_grad_in_fp32: true + clip_grad: 1.0 + learning_rate_scheduler: + learning_rate: 0.0003 + lr_decay_starting_step: null + lr_decay_steps: 98 + lr_decay_style: cosine + lr_warmup_steps: 2 + lr_warmup_style: linear + min_decay_lr: 1.0e-05 + optimizer_factory: + adam_beta1: 0.9 + adam_beta2: 0.95 + adam_eps: 1.0e-08 + name: adamW + torch_adam_is_fused: true + weight_decay: 0.01 + zero_stage: 0 +parallelism: + dp: 1 + expert_parallel_size: 1 + pp: 1 + pp_engine: 1f1b + tp: 1 + tp_linear_async_communication: false + tp_mode: REDUCE_SCATTER +profiler: null +tokenizer: + tokenizer_max_length: null + tokenizer_name_or_path: meta-llama/Meta-Llama-3-8B + tokenizer_revision: null +tokens: + batch_accumulation_per_replica: 1 + limit_test_batches: 0 + limit_val_batches: 10 + micro_batch_size: 2 + sequence_length: 1024 + train_steps: 200 + val_check_interval: -1 diff --git a/src/nanotron/data/multilingual_nanoset.py b/src/nanotron/data/multilingual_nanoset.py index 40e06b87..6526659d 100644 --- a/src/nanotron/data/multilingual_nanoset.py +++ b/src/nanotron/data/multilingual_nanoset.py @@ -32,6 +32,7 @@ def __init__( token_size: int, train_split_num_samples: int, valid_split_num_samples: int, + dataset_tokens: List[int], is_valid: bool = False, dataset_weights: Union[List[float], None] = None, random_seed: int = 1234, @@ -48,6 +49,7 @@ def __init__( self.token_size = token_size self.train_split_num_samples = train_split_num_samples self.valid_split_num_samples = valid_split_num_samples + self.dataset_tokens = dataset_tokens self.is_valid = is_valid self.random_seed = random_seed self.datatrove_datasets = [] @@ -129,7 +131,10 @@ def __getitem__(self, idx: int) -> Dict[str, np.ndarray]: dataset = self.dataset_index[idx] dataset_sample = self.dataset_sample_index[idx] - return self.datatrove_datasets[dataset][dataset_sample] + tokens = self.datatrove_datasets[dataset][dataset_sample] + tokens[0] = self.dataset_tokens[dataset] # Prepend language token + + return tokens def build_nanoset_index(self) -> np.ndarray: """ From efe87209103382f004a33ddfd940df75c0deef89 Mon Sep 17 00:00:00 2001 From: tj-solergibert Date: Tue, 16 Jul 2024 12:51:42 +0000 Subject: [PATCH 09/51] Forgot the trainer ups --- src/nanotron/trainer.py | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/src/nanotron/trainer.py b/src/nanotron/trainer.py index bc81e326..3f4c5189 100644 --- a/src/nanotron/trainer.py +++ b/src/nanotron/trainer.py @@ -393,7 +393,10 @@ def find_stage_idx_to_resume(): def train( self, - dataloader_or_dls: Dict[ + train_dataloader_or_dls: Dict[ + str, Union[Iterator[Dict[str, Union[torch.Tensor, TensorPointer]]], Tuple[Iterator, ...]] + ], + valid_dataloader_or_dls: Dict[ str, Union[Iterator[Dict[str, Union[torch.Tensor, TensorPointer]]], Tuple[Iterator, ...]] ], **kwargs, @@ -424,7 +427,7 @@ def train( prof.step() self.iteration_start_time = time.time() - self._update_dataloader_based_on_training_stages(dataloader_or_dls) + self._update_dataloader_based_on_training_stages(train_dataloader_or_dls) # Training step outputs, loss_avg = self.training_step(dataloader=self.current_dataloader) From 25ad39b2b25fe80c380065dba9e211dba31ed11e Mon Sep 17 00:00:00 2001 From: tj-solergibert Date: Tue, 16 Jul 2024 
14:12:57 +0000 Subject: [PATCH 10/51] Fix minor errors. Everything works --- run_train.py | 6 ++++-- src/nanotron/config/config.py | 2 +- src/nanotron/data/multilingual_nanoset.py | 11 +++++++---- 3 files changed, 12 insertions(+), 7 deletions(-) diff --git a/run_train.py b/run_train.py index 649784ca..9b77da77 100644 --- a/run_train.py +++ b/run_train.py @@ -195,6 +195,7 @@ def get_dataloader_from_data_stage( token_size=token_size, train_split_num_samples=trainer.config.tokens.train_steps * trainer.global_batch_size, valid_split_num_samples=trainer.config.tokens.limit_val_batches * trainer.global_batch_size, + dataset_tokens=data.dataset.dataset_tokens, random_seed=data.seed, ) @@ -229,7 +230,7 @@ def get_valid_dataloader_from_data_stage( input_pp_rank, output_pp_rank = get_input_output_pp_ranks(model=trainer.model) # Only support Validation with MultilingualNanosets - if isinstance(data.dataset, NanosetDatasetsArgs): + if isinstance(data.dataset, MultilingualNanosetDatasetsArgs): # Get tokenizer cardinality tokenizer = AutoTokenizer.from_pretrained(trainer.config.tokenizer.tokenizer_name_or_path) token_size = 4 if len(tokenizer) > np.iinfo(np.uint16).max + 1 else 2 @@ -245,6 +246,7 @@ def get_valid_dataloader_from_data_stage( token_size=token_size, train_split_num_samples=trainer.config.tokens.train_steps * trainer.global_batch_size, valid_split_num_samples=valid_split_num_samples, + dataset_tokens=data.dataset.dataset_tokens, is_valid=True, random_seed=data.seed, ) @@ -320,7 +322,7 @@ def get_valid_dataloader(trainer: DistributedTrainer) -> Dict[str, DataLoader]: valid_split_num_samples = trainer.config.tokens.limit_val_batches * trainer.global_batch_size log_rank( - f"[Training Plan] Stage {stage.name} has {valid_split_num_samples} samples in the validation set", + f"[Validation Plan] Stage {stage.name} has {valid_split_num_samples} samples in the validation set", logger=logger, level=logging.INFO, rank=0, diff --git a/src/nanotron/config/config.py b/src/nanotron/config/config.py index bfd20227..924a2cdf 100644 --- a/src/nanotron/config/config.py +++ b/src/nanotron/config/config.py @@ -132,7 +132,7 @@ def __post_init__(self): class DataArgs: """Arguments related to the data and data files processing""" - dataset: Union[PretrainDatasetsArgs, NanosetDatasetsArgs] + dataset: Union[PretrainDatasetsArgs, NanosetDatasetsArgs, MultilingualNanosetDatasetsArgs] seed: Optional[int] num_loading_workers: Optional[int] = 1 diff --git a/src/nanotron/data/multilingual_nanoset.py b/src/nanotron/data/multilingual_nanoset.py index 6526659d..cd8be195 100644 --- a/src/nanotron/data/multilingual_nanoset.py +++ b/src/nanotron/data/multilingual_nanoset.py @@ -1,5 +1,6 @@ import os import warnings +from math import ceil from typing import Dict, List, Tuple, Union import numpy as np @@ -80,11 +81,13 @@ def __init__( ), f"Specified {len(self.dataset_weights)} weights but {len(dataset_folders)} datasets were provided." 
## Build dataset index and dataset sample index ### Split dataset_lengths into train_dataset_lenghts & valid_dataset_lenghts - self.valid_dataset_lenghts = self.dataset_weights * valid_split_num_samples + self.valid_dataset_lenghts = [ + ceil(weight * valid_split_num_samples) for weight in self.dataset_weights + ] # Better not tu use numpy so we don't get overflow issues # Assert that we have sufficient samples to build the valid split for ds_index in range(len(self.dataset_lengths)): assert ( - self.valid_dataset_lenghts[ds_index] > self.dataset_lengths[ds_index] + self.dataset_lengths[ds_index] > self.valid_dataset_lenghts[ds_index] ), f"Trying to build validation dataset with {self.valid_dataset_lenghts[ds_index]} samples but {dataset_folders[ds_index]} just have {self.dataset_lengths[ds_index]} samples." self.train_dataset_lenghts = [ a - b for a, b in zip(self.dataset_lengths, self.valid_dataset_lenghts) @@ -132,7 +135,7 @@ def __getitem__(self, idx: int) -> Dict[str, np.ndarray]: dataset_sample = self.dataset_sample_index[idx] tokens = self.datatrove_datasets[dataset][dataset_sample] - tokens[0] = self.dataset_tokens[dataset] # Prepend language token + tokens["input_ids"][0] = self.dataset_tokens[dataset] # Prepend language token return tokens @@ -144,7 +147,7 @@ def build_nanoset_index(self) -> np.ndarray: dataset_index, dataset_sample_index = build_nanoset_index_helper( n_samples=self.split_samples_per_epoch, weights=self.dataset_weights, - dataset_sizes=self.split_dataset_lengths, + dataset_sizes=self.split_dataset_lenghts, offsets=self.split_dataset_offsets, ) # Shuffle the indexes the same way From d91f9e1e8b67ffa51a14fff9bb0e408c02920631 Mon Sep 17 00:00:00 2001 From: tj-solergibert Date: Tue, 16 Jul 2024 14:13:33 +0000 Subject: [PATCH 11/51] Updated config file with GPT2 tokenized datasets in RCP --- examples/config_multilingual_nanoset.yaml | 14 +++++++------- 1 file changed, 7 insertions(+), 7 deletions(-) diff --git a/examples/config_multilingual_nanoset.yaml b/examples/config_multilingual_nanoset.yaml index 00ae6570..3c4476a0 100644 --- a/examples/config_multilingual_nanoset.yaml +++ b/examples/config_multilingual_nanoset.yaml @@ -7,7 +7,7 @@ checkpoints: data_stages: - data: dataset: - dataset_folder: datasets/c4-es/tokenized + dataset_folder: /mloscratch/homes/solergib/nanotrove/nanotron/datasets/c4-es/tokenized dataset_tokens: - 15 num_loading_workers: 1 @@ -17,8 +17,8 @@ data_stages: - data: dataset: dataset_folder: - - datasets/SlimPajama-6B/tokenized - - datasets/c4-es/tokenized + - /mloscratch/homes/solergib/nanotrove/nanotron/datasets/SlimPajama-6B/tokenized + - /mloscratch/homes/solergib/nanotrove/nanotron/datasets/c4-es/tokenized dataset_tokens: - 16 - 15 @@ -29,8 +29,8 @@ data_stages: - data: dataset: dataset_folder: - datasets/SlimPajama-6B/tokenized: 0.8 - datasets/c4-es/tokenized: 0.2 + /mloscratch/homes/solergib/nanotrove/nanotron/datasets/SlimPajama-6B/tokenized: 0.8 + /mloscratch/homes/solergib/nanotrove/nanotron/datasets/c4-es/tokenized: 0.2 dataset_tokens: - 16 - 15 @@ -65,7 +65,7 @@ model: initializer_range: 0.02 intermediate_size: 11008 is_llama_config: true - max_position_embeddings: 4096 + max_position_embeddings: 1024 num_hidden_layers: 32 num_attention_heads: 32 num_key_value_heads: 8 @@ -108,7 +108,7 @@ parallelism: profiler: null tokenizer: tokenizer_max_length: null - tokenizer_name_or_path: meta-llama/Meta-Llama-3-8B + tokenizer_name_or_path: gpt2 tokenizer_revision: null tokens: batch_accumulation_per_replica: 1 From 
d0c14e38054cb9bef16d75940e2ee076cde26bea Mon Sep 17 00:00:00 2001 From: tj-solergibert Date: Wed, 17 Jul 2024 10:13:57 +0000 Subject: [PATCH 12/51] Before lunch --- run_train.py | 13 +--- src/nanotron/config/config.py | 6 +- src/nanotron/data/multilingual_nanoset.py | 76 +++++++++-------------- 3 files changed, 37 insertions(+), 58 deletions(-) diff --git a/run_train.py b/run_train.py index 9b77da77..57e0ec25 100644 --- a/run_train.py +++ b/run_train.py @@ -194,7 +194,6 @@ def get_dataloader_from_data_stage( sequence_length=trainer.sequence_length, token_size=token_size, train_split_num_samples=trainer.config.tokens.train_steps * trainer.global_batch_size, - valid_split_num_samples=trainer.config.tokens.limit_val_batches * trainer.global_batch_size, dataset_tokens=data.dataset.dataset_tokens, random_seed=data.seed, ) @@ -222,7 +221,6 @@ def get_dataloader_from_data_stage( def get_valid_dataloader_from_data_stage( trainer: DistributedTrainer, data: DataArgs, - valid_split_num_samples: int, # consumed_train_samples: int, We will never use this because in each valid iteration we consume all the samples ): @@ -245,7 +243,6 @@ def get_valid_dataloader_from_data_stage( sequence_length=trainer.sequence_length, token_size=token_size, train_split_num_samples=trainer.config.tokens.train_steps * trainer.global_batch_size, - valid_split_num_samples=valid_split_num_samples, dataset_tokens=data.dataset.dataset_tokens, is_valid=True, random_seed=data.seed, @@ -259,7 +256,6 @@ def get_valid_dataloader_from_data_stage( input_pp_rank=input_pp_rank, output_pp_rank=output_pp_rank, micro_batch_size=trainer.micro_batch_size, - consumed_train_samples=0, dataloader_num_workers=data.num_loading_workers, dataloader_drop_last=True, ) @@ -319,21 +315,18 @@ def get_valid_dataloader(trainer: DistributedTrainer) -> Dict[str, DataLoader]: # NOTE: we only create the dataloader for the first stage, # then we lazy initialize the dataloader for the other stages stage = cast(DatasetStageArgs, stage) - valid_split_num_samples = trainer.config.tokens.limit_val_batches * trainer.global_batch_size log_rank( - f"[Validation Plan] Stage {stage.name} has {valid_split_num_samples} samples in the validation set", + f"[Validation Plan] Stage {stage.name} has {len(stage.data.dataset.validation_folder)} folders with samples in the validation set", logger=logger, level=logging.INFO, rank=0, ) dataloader = ( - get_valid_dataloader_from_data_stage(trainer, stage.data, valid_split_num_samples=valid_split_num_samples) + get_valid_dataloader_from_data_stage(trainer, stage.data) if stage_idx == 0 - else lambda stage=stage: get_dataloader_from_data_stage( - trainer, stage.data, valid_split_num_samples=valid_split_num_samples - ) + else lambda stage=stage: get_dataloader_from_data_stage(trainer, stage.data) ) dataloaders[stage.name] = dataloader return dataloaders diff --git a/src/nanotron/config/config.py b/src/nanotron/config/config.py index 924a2cdf..fb3e49dd 100644 --- a/src/nanotron/config/config.py +++ b/src/nanotron/config/config.py @@ -109,7 +109,8 @@ def __post_init__(self): @dataclass class MultilingualNanosetDatasetsArgs: - dataset_folder: Union[str, dict, List[str]] + training_folder: Union[str, dict, List[str]] + validation_folder: Union[str, dict, List[str]] dataset_tokens: List[ int ] # Set token for each language previously defined. 
We use a List and not a dict because this way we support specifyng weights (dict) or not (List[str]) @@ -125,7 +126,8 @@ def __post_init__(self): self.dataset_folder = list(tmp_dataset_folder.keys()) self.dataset_weights = list(tmp_dataset_folder.values()) - assert len(self.dataset_folder) == len(self.dataset_tokens) + assert len(self.training_folder) == len(self.validation_folder) + assert len(self.training_folder) == len(self.dataset_tokens) @dataclass diff --git a/src/nanotron/data/multilingual_nanoset.py b/src/nanotron/data/multilingual_nanoset.py index cd8be195..f634fd98 100644 --- a/src/nanotron/data/multilingual_nanoset.py +++ b/src/nanotron/data/multilingual_nanoset.py @@ -1,6 +1,5 @@ import os import warnings -from math import ceil from typing import Dict, List, Tuple, Union import numpy as np @@ -32,7 +31,6 @@ def __init__( sequence_length: int, token_size: int, train_split_num_samples: int, - valid_split_num_samples: int, dataset_tokens: List[int], is_valid: bool = False, dataset_weights: Union[List[float], None] = None, @@ -49,7 +47,6 @@ def __init__( self.sequence_length = sequence_length self.token_size = token_size self.train_split_num_samples = train_split_num_samples - self.valid_split_num_samples = valid_split_num_samples self.dataset_tokens = dataset_tokens self.is_valid = is_valid self.random_seed = random_seed @@ -80,36 +77,11 @@ def __init__( self.dataset_weights ), f"Specified {len(self.dataset_weights)} weights but {len(dataset_folders)} datasets were provided." ## Build dataset index and dataset sample index - ### Split dataset_lengths into train_dataset_lenghts & valid_dataset_lenghts - self.valid_dataset_lenghts = [ - ceil(weight * valid_split_num_samples) for weight in self.dataset_weights - ] # Better not tu use numpy so we don't get overflow issues - # Assert that we have sufficient samples to build the valid split - for ds_index in range(len(self.dataset_lengths)): - assert ( - self.dataset_lengths[ds_index] > self.valid_dataset_lenghts[ds_index] - ), f"Trying to build validation dataset with {self.valid_dataset_lenghts[ds_index]} samples but {dataset_folders[ds_index]} just have {self.dataset_lengths[ds_index]} samples." 
- self.train_dataset_lenghts = [ - a - b for a, b in zip(self.dataset_lengths, self.valid_dataset_lenghts) - ] # Subtract the valid samples from the training dataset - if is_valid: # Valid MultilingualNanoset - self.split_num_samples = valid_split_num_samples - self.split_samples_per_epoch = valid_split_num_samples - self.num_epochs = 1 - self.split_dataset_lenghts = self.valid_dataset_lenghts - self.split_dataset_offsets = self.train_dataset_lenghts + self.dataset_index, self.dataset_sample_index = self.build_valid_nanoset_index(self.dataset_lengths) else: # Train MultilingualNanoset - self.split_num_samples = train_split_num_samples - self.split_samples_per_epoch = sum(self.train_dataset_lenghts) - self.num_epochs = int(self.split_num_samples / self.split_samples_per_epoch) + 1 - self.split_dataset_lenghts = self.train_dataset_lenghts - self.split_dataset_offsets = [ - 0 for _ in range(len(self.dataset_lengths)) - ] # For training there is NO offset - - self.dataset_index, self.dataset_sample_index = self.build_nanoset_index() + self.dataset_index, self.dataset_sample_index = self.build_train_nanoset_index() self.print_nanoset_info() @@ -139,16 +111,16 @@ def __getitem__(self, idx: int) -> Dict[str, np.ndarray]: return tokens - def build_nanoset_index(self) -> np.ndarray: + def build_train_nanoset_index(self) -> np.ndarray: """ - Build dataset index and dataset sample index + Build train dataset index and dataset sample index """ + # Compute samples per epoch and number of epochs + samples_per_epoch = sum(self.dataset_lengths) + num_epochs = int(self.train_split_num_samples / samples_per_epoch) + 1 # Build the dataset indexes for 1 epoch - dataset_index, dataset_sample_index = build_nanoset_index_helper( - n_samples=self.split_samples_per_epoch, - weights=self.dataset_weights, - dataset_sizes=self.split_dataset_lenghts, - offsets=self.split_dataset_offsets, + dataset_index, dataset_sample_index = build_train_nanoset_index_helper( + n_samples=samples_per_epoch, weights=self.dataset_weights, dataset_sizes=self.dataset_lengths ) # Shuffle the indexes the same way numpy_random_state = np.random.RandomState(self.random_seed) @@ -156,14 +128,28 @@ def build_nanoset_index(self) -> np.ndarray: numpy_random_state = np.random.RandomState(self.random_seed) numpy_random_state.shuffle(dataset_sample_index) # Concatenate num_epochs the shuffled indexes - dataset_index = np.concatenate([dataset_index for _ in range(self.num_epochs)]) - dataset_sample_index = np.concatenate([dataset_sample_index for _ in range(self.num_epochs)]) + dataset_index = np.concatenate([dataset_index for _ in range(num_epochs)]) + dataset_sample_index = np.concatenate([dataset_sample_index for _ in range(num_epochs)]) # Just keep the necessary samples - dataset_index = dataset_index[: self.split_num_samples] - dataset_sample_index = dataset_sample_index[: self.split_num_samples] + dataset_index = dataset_index[: self.train_split_num_samples] + dataset_sample_index = dataset_sample_index[: self.train_split_num_samples] return dataset_index, dataset_sample_index + @jit(nopython=True, cache=True) + def build_valid_nanoset_index(dataset_lengths: List[int]) -> np.ndarray: + """ + Build valid dataset index and dataset sample index + """ + dataset_index = [] + dataset_sample_index = [] + + for i, length in enumerate(dataset_lengths): + dataset_index.extend([i] * length) + dataset_sample_index.extend(range(length)) + + return np.array(dataset_index, dtype="uint"), np.array(dataset_sample_index, dtype="long") + def 
print_nanoset_info(self): log_rank( @@ -191,8 +177,8 @@ def print_nanoset_info(self): @jit(nopython=True, cache=True) -def build_nanoset_index_helper( - n_samples: int, weights: np.ndarray, dataset_sizes: List[int], offsets: List[int] +def build_train_nanoset_index_helper( + n_samples: int, weights: np.ndarray, dataset_sizes: List[int] ) -> Tuple[np.ndarray, np.ndarray]: """ Given multiple datasets and a weighting array, build samples indexes @@ -219,9 +205,7 @@ def build_nanoset_index_helper( # Assign the dataset index and update the sample index dataset_index[sample_idx] = max_error_index - dataset_sample_index[sample_idx] = ( - current_samples[max_error_index] % dataset_sizes[max_error_index] - ) + offsets[max_error_index] + dataset_sample_index[sample_idx] = current_samples[max_error_index] % dataset_sizes[max_error_index] # Update the total samples for the selected dataset current_samples[max_error_index] += 1 From 9cfc5ea954505d880ffe19580ef8e60b4c8acd70 Mon Sep 17 00:00:00 2001 From: tj-solergibert Date: Wed, 17 Jul 2024 14:10:03 +0000 Subject: [PATCH 13/51] After lunch --- examples/config_multilingual_nanoset.yaml | 42 +++++++++++++++-------- run_train.py | 6 ++-- src/nanotron/config/config.py | 21 ++++++------ src/nanotron/data/multilingual_nanoset.py | 33 +++++++++--------- tools/preprocess_data.py | 5 ++- 5 files changed, 61 insertions(+), 46 deletions(-) diff --git a/examples/config_multilingual_nanoset.yaml b/examples/config_multilingual_nanoset.yaml index 3c4476a0..238f8269 100644 --- a/examples/config_multilingual_nanoset.yaml +++ b/examples/config_multilingual_nanoset.yaml @@ -7,7 +7,8 @@ checkpoints: data_stages: - data: dataset: - dataset_folder: /mloscratch/homes/solergib/nanotrove/nanotron/datasets/c4-es/tokenized + training_folder: datasets/c4-es/train + validation_folder: datasets/c4-es/validation dataset_tokens: - 15 num_loading_workers: 1 @@ -16,24 +17,37 @@ data_stages: start_training_step: 1 - data: dataset: - dataset_folder: - - /mloscratch/homes/solergib/nanotrove/nanotron/datasets/SlimPajama-6B/tokenized - - /mloscratch/homes/solergib/nanotrove/nanotron/datasets/c4-es/tokenized + training_folder: + - datasets/c4-es/train + - datasets/c4-en/train + - datasets/c4-fr/train + validation_folder: + - datasets/c4-es/validation + - datasets/c4-en/validation + - datasets/c4-fr/validation dataset_tokens: - - 16 - 15 + - 16 + - 17 num_loading_workers: 1 seed: 42 name: Second purpose training (> 1 dataset) start_training_step: 15 - data: dataset: - dataset_folder: - /mloscratch/homes/solergib/nanotrove/nanotron/datasets/SlimPajama-6B/tokenized: 0.8 - /mloscratch/homes/solergib/nanotrove/nanotron/datasets/c4-es/tokenized: 0.2 + training_folder: + datasets/c4-es/train: 0.6 + datasets/c4-en/train: 0.3 + datasets/c4-fr/train: 0.1 + validation_folder: + - datasets/c4-es/validation + - datasets/c4-en/validation + - datasets/c4-fr/validation dataset_tokens: - - 16 - 15 + - 16 + - 17 + num_loading_workers: 1 seed: 42 name: Third purpose training (Blended dataset) @@ -61,12 +75,12 @@ model: bos_token_id: 1 eos_token_id: 2 hidden_act: silu - hidden_size: 4096 + hidden_size: 512 initializer_range: 0.02 - intermediate_size: 11008 + intermediate_size: 512 is_llama_config: true max_position_embeddings: 1024 - num_hidden_layers: 32 + num_hidden_layers: 2 num_attention_heads: 32 num_key_value_heads: 8 pad_token_id: null @@ -108,13 +122,13 @@ parallelism: profiler: null tokenizer: tokenizer_max_length: null - tokenizer_name_or_path: gpt2 + tokenizer_name_or_path: meta-llama/Meta-Llama-3-8B 
tokenizer_revision: null tokens: batch_accumulation_per_replica: 1 limit_test_batches: 0 limit_val_batches: 10 - micro_batch_size: 2 + micro_batch_size: 4 sequence_length: 1024 train_steps: 200 val_check_interval: -1 diff --git a/run_train.py b/run_train.py index 57e0ec25..39cda23b 100644 --- a/run_train.py +++ b/run_train.py @@ -189,7 +189,7 @@ def get_dataloader_from_data_stage( with main_rank_first(trainer.parallel_context.world_pg): train_dataset = MultilingualNanoset( - dataset_folders=data.dataset.dataset_folder, + dataset_folders=data.dataset.training_folder, dataset_weights=data.dataset.dataset_weights, sequence_length=trainer.sequence_length, token_size=token_size, @@ -238,11 +238,9 @@ def get_valid_dataloader_from_data_stage( with main_rank_first(trainer.parallel_context.world_pg): valid_dataset = MultilingualNanoset( - dataset_folders=data.dataset.dataset_folder, - dataset_weights=data.dataset.dataset_weights, + dataset_folders=data.dataset.validation_folder, sequence_length=trainer.sequence_length, token_size=token_size, - train_split_num_samples=trainer.config.tokens.train_steps * trainer.global_batch_size, dataset_tokens=data.dataset.dataset_tokens, is_valid=True, random_seed=data.seed, diff --git a/src/nanotron/config/config.py b/src/nanotron/config/config.py index fb3e49dd..ce61a249 100644 --- a/src/nanotron/config/config.py +++ b/src/nanotron/config/config.py @@ -110,21 +110,20 @@ def __post_init__(self): @dataclass class MultilingualNanosetDatasetsArgs: training_folder: Union[str, dict, List[str]] - validation_folder: Union[str, dict, List[str]] - dataset_tokens: List[ - int - ] # Set token for each language previously defined. We use a List and not a dict because this way we support specifyng weights (dict) or not (List[str]) + validation_folder: Union[str, List[str]] + dataset_tokens: List[int] # Set token for each language previously defined def __post_init__(self): - if isinstance(self.dataset_folder, str): # Case 1: 1 Dataset file - self.dataset_folder = [self.dataset_folder] + if isinstance(self.training_folder, str): # Case 1: 1 Dataset folder + self.training_folder = [self.training_folder] + self.validation_folder = [self.validation_folder] self.dataset_weights = [1] - elif isinstance(self.dataset_folder, List): # Case 2: > 1 Dataset file + elif isinstance(self.training_folder, List): # Case 2: > 1 Dataset folder self.dataset_weights = None # Set to None so we consume all the samples randomly - elif isinstance(self.dataset_folder, dict): # Case 3: dict with > 1 dataset_folder and weights - tmp_dataset_folder = self.dataset_folder.copy() - self.dataset_folder = list(tmp_dataset_folder.keys()) - self.dataset_weights = list(tmp_dataset_folder.values()) + elif isinstance(self.training_folder, dict): # Case 3: dict with > 1 training_folder and weights + tmp_training_folder = self.training_folder.copy() + self.training_folder = list(tmp_training_folder.keys()) + self.dataset_weights = list(tmp_training_folder.values()) assert len(self.training_folder) == len(self.validation_folder) assert len(self.training_folder) == len(self.dataset_tokens) diff --git a/src/nanotron/data/multilingual_nanoset.py b/src/nanotron/data/multilingual_nanoset.py index f634fd98..7af57448 100644 --- a/src/nanotron/data/multilingual_nanoset.py +++ b/src/nanotron/data/multilingual_nanoset.py @@ -30,8 +30,8 @@ def __init__( dataset_folders: List[str], sequence_length: int, token_size: int, - train_split_num_samples: int, dataset_tokens: List[int], + train_split_num_samples: int = None, is_valid: 
bool = False, dataset_weights: Union[List[float], None] = None, random_seed: int = 1234, @@ -78,7 +78,7 @@ def __init__( ), f"Specified {len(self.dataset_weights)} weights but {len(dataset_folders)} datasets were provided." ## Build dataset index and dataset sample index if is_valid: # Valid MultilingualNanoset - self.dataset_index, self.dataset_sample_index = self.build_valid_nanoset_index(self.dataset_lengths) + self.dataset_index, self.dataset_sample_index = build_valid_nanoset_index(self.dataset_lengths) else: # Train MultilingualNanoset self.dataset_index, self.dataset_sample_index = self.build_train_nanoset_index() @@ -136,20 +136,6 @@ def build_train_nanoset_index(self) -> np.ndarray: return dataset_index, dataset_sample_index - @jit(nopython=True, cache=True) - def build_valid_nanoset_index(dataset_lengths: List[int]) -> np.ndarray: - """ - Build valid dataset index and dataset sample index - """ - dataset_index = [] - dataset_sample_index = [] - - for i, length in enumerate(dataset_lengths): - dataset_index.extend([i] * length) - dataset_sample_index.extend(range(length)) - - return np.array(dataset_index, dtype="uint"), np.array(dataset_sample_index, dtype="long") - def print_nanoset_info(self): log_rank( @@ -211,3 +197,18 @@ def build_train_nanoset_index_helper( current_samples[max_error_index] += 1 return dataset_index, dataset_sample_index + + +@jit(nopython=True, cache=True) +def build_valid_nanoset_index(dataset_lengths: List[int]) -> np.ndarray: + """ + Build valid dataset index and dataset sample index + """ + dataset_index = [] + dataset_sample_index = [] + + for i, length in enumerate(dataset_lengths): + dataset_index.extend([i] * length) + dataset_sample_index.extend(range(length)) + + return np.array(dataset_index, dtype="uint"), np.array(dataset_sample_index, dtype="long") diff --git a/tools/preprocess_data.py b/tools/preprocess_data.py index c668aa58..8383ba38 100644 --- a/tools/preprocess_data.py +++ b/tools/preprocess_data.py @@ -98,7 +98,9 @@ def main(args): dataset_options={"split": args.split}, ) elif args.readers == "parquet": - datatrove_reader = ParquetReader(data_folder=args.dataset, text_key=args.column, glob_pattern=args.glob_pattern) + datatrove_reader = ParquetReader( + data_folder=args.dataset, text_key=args.column, glob_pattern=args.glob_pattern + ) else: datatrove_reader = JsonlReader(data_folder=args.dataset, text_key=args.column, glob_pattern=args.glob_pattern) @@ -107,6 +109,7 @@ def main(args): datatrove_reader, DocumentTokenizer( output_folder=args.output_folder, + shuffle=False, tokenizer_name_or_path=args.tokenizer_name_or_path, eos_token=args.eos_token, max_tokens_per_file=1e9, From eed7bce10712a9137eec78ac9c3b6c609fcb28d5 Mon Sep 17 00:00:00 2001 From: tj-solergibert Date: Thu, 18 Jul 2024 10:48:00 +0000 Subject: [PATCH 14/51] Ready --- examples/config_multilingual_nanoset.yaml | 20 ++++++++++---------- src/nanotron/config/config.py | 11 ++++++++--- 2 files changed, 18 insertions(+), 13 deletions(-) diff --git a/examples/config_multilingual_nanoset.yaml b/examples/config_multilingual_nanoset.yaml index 238f8269..599bff6c 100644 --- a/examples/config_multilingual_nanoset.yaml +++ b/examples/config_multilingual_nanoset.yaml @@ -9,8 +9,8 @@ data_stages: dataset: training_folder: datasets/c4-es/train validation_folder: datasets/c4-es/validation - dataset_tokens: - - 15 + lang_to_ids: + es: 128002 num_loading_workers: 1 seed: 42 name: General purpose training (Single dataset) @@ -25,10 +25,10 @@ data_stages: - datasets/c4-es/validation - 
datasets/c4-en/validation - datasets/c4-fr/validation - dataset_tokens: - - 15 - - 16 - - 17 + lang_to_ids: + es: 128002 + en: 128003 + fr: 128004 num_loading_workers: 1 seed: 42 name: Second purpose training (> 1 dataset) @@ -43,10 +43,10 @@ data_stages: - datasets/c4-es/validation - datasets/c4-en/validation - datasets/c4-fr/validation - dataset_tokens: - - 15 - - 16 - - 17 + lang_to_ids: + es: 128002 + en: 128003 + fr: 128004 num_loading_workers: 1 seed: 42 diff --git a/src/nanotron/config/config.py b/src/nanotron/config/config.py index ce61a249..dd2c157d 100644 --- a/src/nanotron/config/config.py +++ b/src/nanotron/config/config.py @@ -111,7 +111,7 @@ def __post_init__(self): class MultilingualNanosetDatasetsArgs: training_folder: Union[str, dict, List[str]] validation_folder: Union[str, List[str]] - dataset_tokens: List[int] # Set token for each language previously defined + lang_to_ids: dict # Mapping from the previously defined folders to tokens. Respect the order def __post_init__(self): if isinstance(self.training_folder, str): # Case 1: 1 Dataset folder @@ -125,8 +125,13 @@ def __post_init__(self): self.training_folder = list(tmp_training_folder.keys()) self.dataset_weights = list(tmp_training_folder.values()) - assert len(self.training_folder) == len(self.validation_folder) - assert len(self.training_folder) == len(self.dataset_tokens) + self.dataset_tokens = list(self.lang_to_ids.values()) + assert len(self.training_folder) == len( + self.validation_folder + ), f"The sizes of training_folder and validation_folder mismatch ({len(self.training_folder)} vs {len(self.validation_folder)})" + assert len(self.training_folder) == len( + self.dataset_tokens + ), f"The sizes of training_folder and lang_to_ids mismatch ({len(self.training_folder)} vs {len(self.dataset_tokens)})" @dataclass From 7a932f89a8844d15b03b8fff050dbec53f552565 Mon Sep 17 00:00:00 2001 From: Alexander Hagele Date: Fri, 2 Aug 2024 10:59:44 +0200 Subject: [PATCH 15/51] start documenting moe setup --- moe.md | 65 ++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 65 insertions(+) create mode 100644 moe.md diff --git a/moe.md b/moe.md new file mode 100644 index 00000000..46bca04f --- /dev/null +++ b/moe.md @@ -0,0 +1,65 @@ +# MoE Env Setup + +TL;DR: need to install megablocks for MoEs, which depends on triton; cannot install triton inside the docker image because it requires a CUDA-capable GPU, which is not available in the build environment. therefore install triton from source inside a venv in the container, then install megablocks + + +```Dockerfile +FROM nvcr.io/nvidia/pytorch:24.04-py3 + +ENV DEBIAN_FRONTEND=noninteractive + +RUN apt-get update && apt-get install -y python3.10-venv && apt-get clean && rm -rf /var/lib/apt/lists/* + +# Update flash-attn. +RUN pip install --upgrade --no-build-isolation flash-attn==2.5.8 + +# Install the rest of dependencies. 
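+# NOTE: triton and megablocks are intentionally not installed here, since building triton
+# needs a CUDA-capable GPU that is not available at image build time (see the venv steps below).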
+RUN pip install \ + datasets \ + transformers \ + wandb \ + dacite \ + pyyaml \ + numpy \ + packaging \ + safetensors \ + tqdm + +``` + + +after image is built, create env `~/.edf/nanotron-moe.toml` with content (adapt to wherever the image is stored) +``` +image = "/capstor/scratch/cscs/$USER/container-images/nanotron-moe/nanotron-moe-v1.0.sqsh" + +mounts = ["/capstor", "/users", "/store"] +workdir = "/users/$USER/" +writable = true + +[annotations] +com.hooks.aws_ofi_nccl.enabled = "true" +com.hooks.aws_ofi_nccl.variant = "cuda12" + +[env] +FI_CXI_DISABLE_HOST_REGISTER = "1" +FI_MR_CACHE_MONITOR = "userfaultfd" +NCCL_DEBUG = "INFO" +``` + +TODO: make image available on the cluster in /store + + + +in a running container (`srun --reservation=todi --environment=nanotron-moe --container-workdir=$PWD --pty bash`) +```bash +cd $SCRATCH/$USER/nanotron-multilingual # or wherever you want the venv +mkdir multilingual-venv && cd multilingual-venv +python -m venv --system-site-packages ./moe-venv +source ./moe-venv/bin/activate +git clone https://github.com/triton-lang/triton.git; \ + cd triton; \ + pip install ninja cmake wheel; # build-time dependencies \ + pip install -e python; cd .. +pip install megablocks==0.5.1 +``` + From f08a05e3e4546b94106e20b9f0a2b13e66eaf7a4 Mon Sep 17 00:00:00 2001 From: Alexander Hagele Date: Fri, 2 Aug 2024 11:08:46 +0200 Subject: [PATCH 16/51] base moe file --- src/nanotron/models/moe.py | 688 +++++++++++++++++++++++++++++++++++++ 1 file changed, 688 insertions(+) create mode 100644 src/nanotron/models/moe.py diff --git a/src/nanotron/models/moe.py b/src/nanotron/models/moe.py new file mode 100644 index 00000000..416ba5e3 --- /dev/null +++ b/src/nanotron/models/moe.py @@ -0,0 +1,688 @@ +""" MoEs Blocks to replace MLPs in Transformers. """ + +import warnings +from functools import partial +from typing import Optional, Tuple + +import numpy as np +import stk +import torch +import torch.nn.functional as F +from megablocks.layers import weight_parallel as wp +from megablocks.layers.activation_fn import act_fn +from torch import nn + +from nanotron import distributed as dist +from nanotron import logging +from nanotron.config import LlamaConfig as Config +from nanotron.config import ParallelismArgs +from nanotron.nn.activations import ACT2FN +from nanotron.parallel.context import ParallelContext +from nanotron.parallel.sharded_parameters import ( + SplitConfig, + mark_all_parameters_in_module_as_sharded, +) +from nanotron.parallel.tensor_parallel.enum import TensorParallelLinearMode +from nanotron.parallel.tensor_parallel.nn import ( + TensorParallelColumnLinear, + TensorParallelRowLinear, +) + +try: + import megablocks.ops as ops + from megablocks.layers.all_to_all import all_to_all +except ImportError: + warnings.warn("Please install megablocks to use MoEs: `pip install megablocks`") + + +logger = logging.get_logger(__name__) + + +def log_mean(x, dim): + return torch.logsumexp(x, dim=dim) - torch.log(torch.tensor(x.shape[dim], dtype=torch.float32)) + + +def load_balancing_loss(router_logits, tokens_per_expert, config: Config) -> torch.Tensor: + """Computes auxiliary load balancing loss as in Switch Transformer. + + See Switch Transformer (https://arxiv.org/abs/2101.03961). This function + implements the loss function presented in equations (4) - (6). It aims to + penalize those cases where the routing between experts is unbalanced. + + Args: + logits: logits assigned to each expert per token. Shape: + [batch_size * sequence_length, num_experts]. 
+ tokens_per_expert: [num_selected_experts] + + config: Config + + Returns: + The auxiliary loss. + """ + # tokens = batch_size * sequence_length + num_hidden_layers = config.num_hidden_layers + moe_num_experts = config.moe_num_experts + moe_loss_weight = config.moe_loss_weight + num_experts_per_token = config.num_experts_per_tok + + # Verify the shape of the tokens_per_expert and expert_scores tensors. + assert tokens_per_expert.ndim == 1 and tokens_per_expert.numel() == moe_num_experts + + tokens = router_logits.shape[0] + assert router_logits.ndim == 2 and router_logits.shape[1] == moe_num_experts + + # compute router probability per expert in log space for numerical stability + logprobs = F.log_softmax(router_logits, dim=-1) + # take mean probability over batch + # shape [num_experts] + logprobs = log_mean(logprobs, dim=0) + expert_scores = torch.exp(logprobs) + + tokens_per_expert = tokens_per_expert.to(expert_scores.dtype) + + # Calculate the total scale across all factors. + # loss_weight * num_experts / (num_layers * tokens * top_k) + scale_numerator = moe_num_experts * moe_loss_weight + scale_denominator = num_hidden_layers * tokens * num_experts_per_token + scale = scale_numerator / scale_denominator + return scale * torch.dot(tokens_per_expert, expert_scores) + + +def router_z_loss(router_logits, config: Config) -> torch.Tensor: + """ + The router z-loss was introduced in ST-MoE + (https://arxiv.org/abs/2202.08906). It encourages router logits to remain + small in an effort to improve stability. + + Args: + router_logits: [batch_size * sequence_length, num_experts] + router logits + config: Config + + Returns: + Scalar router z-loss. + """ + num_hidden_layers = config.num_hidden_layers + moe_num_experts = config.moe_num_experts + + tokens = router_logits.shape[0] + assert router_logits.ndim == 2 and router_logits.shape[1] == moe_num_experts + + z_loss_weight = config.moe_z_loss_weight + + log_z = torch.logsumexp(router_logits, dim=-1) + z_loss = log_z**2 + + scale_numerator = z_loss_weight + scale_denominator = num_hidden_layers * tokens * moe_num_experts + scale = scale_numerator / scale_denominator + + return scale * z_loss.sum(dim=0) + + +class dMoE(torch.nn.Module): + def __init__( + self, + config: Config, + parallel_context: "ParallelContext", + parallel_config: Optional[ParallelismArgs], + ): + super().__init__() + self.config = config + self.tp_mode = parallel_config.tp_mode if parallel_config is not None else TensorParallelLinearMode.ALL_REDUCE + if self.tp_mode == TensorParallelLinearMode.REDUCE_SCATTER: + logging.warn_once( + logger=logger, + msg="TensorParallelLinearMode.REDUCE_SCATTER is still experimental for MoEs. Use at your own risk.", + rank=0, + ) + + # Token router. + self.gate = LearnedRouter(config) + + # Expert computation helper. + self.experts = ParallelDroplessMLP( + config, + use_bias=False, + parallel_context=parallel_context, + parallel_config=parallel_config, + ) + + def forward(self, hidden_states: torch.Tensor): + """ + Args: + x: input tensor of shape [sequence_length, batch_size, hidden_size] + """ + # Compute the expert scores and assignments. + # TODO: support sequence parallelism + batch_size, sequence_length, _ = hidden_states.size() + x = hidden_states.view(-1, self.config.hidden_size) + router_logits, expert_weights, top_experts = self.gate(x) + + # Compute the experts. 
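+        # Shapes: router_logits is [tokens, num_experts]; expert_weights and top_experts are
+        # [tokens, num_experts_per_tok], where tokens = batch_size * sequence_length.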
+ x, lbl_loss, z_loss = self.experts(x, router_logits, expert_weights, top_experts) + return { + "hidden_states": x.reshape(batch_size, sequence_length, -1), + "load_balancing_loss": lbl_loss, + "z_loss": z_loss, + } + + +# Adapted from megablocks.layers.router.LearnedRouter +class LearnedRouter(torch.nn.Module): + def __init__(self, config: Config): + super().__init__() + self.layer = torch.nn.Linear(config.hidden_size, config.moe_num_experts, bias=False) + self.config = config + + def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + router_logits = self.layer(x) # (batch * sequence_length, n_experts) + scores = F.softmax(router_logits, dim=-1, dtype=torch.float32) # TODO: fuse? + + if self.config.num_experts_per_tok == 1: + expert_weights, expert_indices = scores.max(dim=-1, keepdim=True) + else: + expert_weights, expert_indices = torch.topk(scores, self.config.num_experts_per_tok, dim=-1) + # IMPORTANT step to normalize, otherwise weights are very low + expert_weights = expert_weights / torch.norm( + expert_weights, + p=1, + dim=-1, + keepdim=True, + ) + return router_logits, expert_weights, expert_indices.int() + + +# Adapted from megablocks.layers.mlp.ParallelDroplessMLP +class ParallelDroplessMLP(torch.nn.Module): + def __init__( + self, + config: Config, + use_bias: bool, + parallel_context: "ParallelContext", + parallel_config: Optional[ParallelismArgs], + ): + super().__init__() + self.config = config + self.use_bias = use_bias + + self.expert_pg_size = parallel_context.expert_pg.size() + self.expert_parallel_group = parallel_context.expert_pg + + self.hidden_sharding_degree = self.expert_pg_size // min(self.expert_pg_size, self.config.moe_num_experts) + self.experts_per_rank = self.config.moe_num_experts // min(self.expert_pg_size, self.config.moe_num_experts) + + self.num_experts = config.moe_num_experts + self.num_experts_per_tok = self.config.num_experts_per_tok + + # Calculate the number of bits needed to represent the expert indices + # so that we can pass it to radix sort. + self.sort_end_bit = max(int(np.ceil(np.log2(self.num_experts))), 1) + + if use_bias: + self.bias = torch.nn.Parameter(torch.empty(config.hidden_size)) + + # Select the forward function for the operating mode. + self.forward_fn = self.parallel_forward_once if self.expert_pg_size > 1 else self.forward_once + + self.blocking = 128 + + if self.experts_per_rank == 1: + self.mlp = MLP( + config=config, + parallel_config=parallel_config, + tp_pg=parallel_context.tp_pg, + ) + else: + self.mlp = SparseGLU( + config=config, + parallel_config=parallel_config, + parallel_context=parallel_context, + ) + + max_column_index = (self.config.intermediate_size * self.num_experts) // self.blocking + self.transpose_sort_end_bit = max(int(np.ceil(np.log2(max_column_index))), 1) + + def indices_and_bins(self, top_expert): + # Sort the expert ids to produce the scatter/gather + # indices for the permutation. + top_expert = top_expert.int() + bin_ids, indices = ops.sort(top_expert, self.sort_end_bit) + tokens_per_expert = ops.histogram(top_expert, self.num_experts) + + # Calculate the bin bounds for the sorted tokens. + bins = inclusive_cumsum(tokens_per_expert, 0) + return indices, bin_ids, bins, tokens_per_expert + + def indices_and_padded_bins(self, top_experts): + # Sort the expert ids to produce the scatter/gather + # indices for the permutation. 
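+ # sort_end_bit = ceil(log2(num_experts)), so the radix sort only has to inspect the bits
+ # an expert id can actually occupy.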
+ bin_ids, indices = ops.sort(top_experts, self.sort_end_bit) + + # Histogram the expert ids to identify the number of + # tokens routed to each expert. + tokens_per_expert = ops.histogram(top_experts, self.num_experts) + + # Round the token counts up to the block size used in + # the matrix muliplications. Calculate the starting + # position of each bin. + padded_tokens_per_expert = ops.round_up(tokens_per_expert, self.blocking) + padded_bins = inclusive_cumsum(padded_tokens_per_expert, 0) + + # Calculate the bin bounds for the sorted tokens. + bins = inclusive_cumsum(tokens_per_expert, 0) + return indices, bin_ids, bins, padded_bins, tokens_per_expert + + def forward_once(self, x, expert_weights, top_experts): # TODO: sparse + with torch.no_grad(): + ( + indices, + bin_ids, + bins, + padded_bins, + tokens_per_expert, + ) = self.indices_and_padded_bins(top_experts) + + # Route the tokens for MoE computation. + x = ops.padded_gather(x, indices, bin_ids, bins, padded_bins, self.num_experts_per_tok) + + with torch.no_grad(): + topo = self.topology(x, padded_bins) + + x = self.mlp(x, topo) + + # Un-route the data for the MoE output. + x = ops.padded_scatter( + x, + indices, + bin_ids, + expert_weights, + bins, + padded_bins, + self.num_experts_per_tok, + -1, + ) + return x, tokens_per_expert + + def parallel_forward_once(self, x, expert_weights, top_experts): + with torch.no_grad(): + indices, bin_ids, bins, tokens_per_expert = self.indices_and_bins(top_experts) + repeated_tokens_per_expert = ops.repeat(tokens_per_expert, (self.hidden_sharding_degree,)) + parallel_tokens_per_expert = torch.empty_like(repeated_tokens_per_expert) + tpe_handle = torch.distributed.all_to_all_single( + parallel_tokens_per_expert, + repeated_tokens_per_expert, + group=self.expert_parallel_group, + async_op=True, + ) + + x = ops.gather(x, indices, bin_ids, bins, self.num_experts_per_tok) + + # Compute the number of tokens that will be received from each + # device and permute the input data across the devices. + with torch.no_grad(): + tpe_handle.wait() + + # Reshape to [expert_pg_size, num_experts_per_rank]. + repeated_tokens_per_expert = repeated_tokens_per_expert.view(self.expert_pg_size, self.experts_per_rank) + parallel_tokens_per_expert = parallel_tokens_per_expert.view(self.expert_pg_size, self.experts_per_rank) + + send_counts = repeated_tokens_per_expert.cpu().sum(dim=-1) + parallel_tokens_per_expert_cpu = parallel_tokens_per_expert.cpu() + recv_counts = parallel_tokens_per_expert_cpu.sum(dim=-1) + + # Convert the send/recv counts to lists. + send_counts = send_counts.tolist() + recv_counts = recv_counts.tolist() + tokens_received = sum(recv_counts) + + x = ops.repeat(x, (self.hidden_sharding_degree, 1)) + + # Start the cross-device permutation asynchronously so we can + # overlap communication with computation. + parallel_x, parallel_x_handle = all_to_all( + x, recv_counts, send_counts, self.expert_parallel_group, async_op=True + ) + + with torch.no_grad(): + replicate_bins = inclusive_cumsum(parallel_tokens_per_expert.flatten(), 0) + + # Construct the expert indices for the permuted tokens. 
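+ # An arange over num_experts * hidden_sharding_degree taken modulo experts_per_rank gives
+ # the local expert id that each incoming group of tokens is routed to on this rank.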
+ parallel_top_expert = torch.remainder( + torch.arange( + self.num_experts * self.hidden_sharding_degree, + dtype=torch.int32, + device=indices.device, + ), + self.experts_per_rank, + ) + parallel_top_expert = ops.replicate( + parallel_top_expert.unsqueeze(dim=0), replicate_bins, tokens_received + ).flatten() + + parallel_bin_ids, parallel_indices = ops.sort(parallel_top_expert, self.sort_end_bit) + + # Calculate the bins boundaries from the token counts. + parallel_tokens_per_expert = parallel_tokens_per_expert.sum(dim=0, dtype=torch.int) + parallel_bins = inclusive_cumsum(parallel_tokens_per_expert, 0) + + # Locally permute the tokens and perform the expert computation. + # Block to make sure that the cross-device permutation is complete. + parallel_x_handle.wait() + parallel_x = self.permute_and_compute( + parallel_x, + parallel_tokens_per_expert, + parallel_indices, + parallel_bin_ids, + None, # expert_weights + parallel_bins, + num_experts_per_tok=1, + ) + + # Un-permute the tokens across the devices. + x, _ = all_to_all(parallel_x, send_counts, recv_counts, self.expert_parallel_group) + + # Reduce along the hidden sharding to get the final outputs. + shape = (self.hidden_sharding_degree, -1, self.config.hidden_size) + x = ops.sum(x.view(shape), dim=0) + + # Un-permute locally to setup for the next series of operations. + x = ops.scatter( + x, + indices, + bin_ids, + expert_weights, + bins, + self.num_experts_per_tok, + ) + return x, tokens_per_expert.flatten() + + def forward(self, x, router_logits, expert_weights, top_experts): + """ + Args: + x: input tensor of shape [sequence_length, batch_size, hidden_size] + router_logits: tensor of shape [sequence_length * batch_size, n_experts] + expert_weights: tensor of shape [sequence_length * batch_size, num_experts_per_tok] + top_experts: tensor of shape [sequence_length * batch_size, num_experts_per_tok] + """ + # Compute the experts. + x, tokens_per_expert = self.forward_fn(x, expert_weights.flatten(), top_experts.flatten()) + if self.training: + lbl_loss = load_balancing_loss(router_logits, tokens_per_expert, self.config) + z_loss = router_z_loss(router_logits, self.config) + else: + lbl_loss = torch.zeros(1, device=x.device) + z_loss = torch.zeros(1, device=x.device) + + if self.use_bias: + return x + self.bias + return x, lbl_loss, z_loss + + def permute_and_compute( + self, + x, + tokens_per_expert, + indices, + bin_ids, + expert_weights, + bins, + num_experts_per_tok, + ): + # Round the token counts up to the block size used in the matrix + # multiplication. Calculate the starting position of each bin. + padded_tokens_per_expert = ops.round_up(tokens_per_expert, self.blocking) + padded_bins = inclusive_cumsum(padded_tokens_per_expert, 0) + + # Route the tokens for MoE computation. + x = ops.padded_gather(x, indices, bin_ids, bins, padded_bins, num_experts_per_tok) + + # Perform the expert computation. + with torch.no_grad(): + topo = self.topology(x, padded_bins) + x = self.mlp(x, topo) + + # Un-route the data for the MoE output. 
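+ # padded_scatter reverses the padded_gather above; when expert_weights is provided, each
+ # expert's contribution is weighted by its routing score before summing over the top-k experts.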
+ return ops.padded_scatter(x, indices, bin_ids, expert_weights, bins, padded_bins, num_experts_per_tok) + + def sparse_transpose(self, size, row_indices, column_indices, offsets): + block_columns = size[1] // self.blocking + _, gather_indices = ops.sort(column_indices.int(), self.transpose_sort_end_bit) + column_indices_t = row_indices.gather(0, gather_indices.long()) + block_offsets_t = gather_indices.int() + + zero = torch.zeros((1,), dtype=torch.int32, device=row_indices.device) + nnz_per_column = ops.histogram(column_indices, block_columns) + nnz_per_column = ops.inclusive_cumsum(nnz_per_column, 0) + offsets_t = torch.cat([zero, nnz_per_column]) + return column_indices_t, offsets_t, block_offsets_t + + def topology(self, x, padded_bins): + padded_tokens, _ = x.size() + assert padded_tokens % self.blocking == 0 + assert self.config.intermediate_size % self.blocking == 0 + + # Offsets for the sparse matrix. All rows have the + # same number of nonzero blocks dictated by the + # dimensionality of a single expert. + block_rows = padded_tokens // self.blocking + blocks_per_row = self.config.intermediate_size // self.blocking + offsets = torch.arange( + 0, + block_rows * blocks_per_row + 1, + blocks_per_row, + dtype=torch.int32, + device=x.device, + ) + + # Indices for the sparse matrix. The indices for + # the intermediate matrix are dynamic depending + # on the mapping of tokens to experts. + column_indices = ops.topology(padded_bins, self.blocking, block_rows, blocks_per_row) + + # TODO(tgale): This is unused. Remove the need for this in stk. + # For now, use meta init to save the device memory. + data = torch.empty( + column_indices.numel(), + self.blocking, + self.blocking, + dtype=x.dtype, + device="meta", + ) + shape = (padded_tokens, self.config.intermediate_size * self.experts_per_rank) + row_indices = stk.ops.row_indices(shape, data, offsets, column_indices) + column_indices_t, offsets_t, block_offsets_t = self.sparse_transpose( + shape, row_indices, column_indices, offsets + ) + return stk.Matrix( + shape, + data, + row_indices, + column_indices, + offsets, + column_indices_t, + offsets_t, + block_offsets_t, + ) + + +class ScaleGradient(torch.autograd.Function): + @staticmethod + @torch.cuda.amp.custom_fwd + def forward(ctx, x, scale): + ctx.scale = scale + return x + + @staticmethod + @torch.cuda.amp.custom_bwd + def backward(ctx, grad): + return grad * ctx.scale, None + + +scale_gradient = ScaleGradient.apply + + +class ExpertParallel(nn.Module): + """ + ExpertParallel serves to scale the gradients of the expert weights because unlike DP the gradients are not averaged across the expert parallel group. 
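+ To compensate, the wrapped module's gradients are scaled by 1 / expert_parallel_size during the forward pass.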
+ """ + + def __init__(self, module, expert_parallel_size: int): + super().__init__() + self.module = module + self.expert_parallel_size = expert_parallel_size + + def forward(self, *args, **kwargs): + self.scale_gradients() + return self.module(*args, **kwargs) + + def scale_gradients(self): + scale_gradient(self.module, 1 / self.expert_parallel_size) + + +class SparseMLP(nn.Module): + def __init__( + self, + config: Config, + parallel_config: Optional[ParallelismArgs], + parallel_context: "ParallelContext", + ): + super().__init__() + + self.expert_pg_size = parallel_config.expert_parallel_size if parallel_config is not None else 1 + self.experts_per_rank = config.moe_num_experts // min(self.expert_pg_size, config.moe_num_experts) + self.tp_pg = parallel_context.tp_pg + + self.w1 = ExpertParallel( + nn.Linear( + config.hidden_size, + config.intermediate_size * self.experts_per_rank // self.tp_pg.size(), + bias=False, + ), + expert_parallel_size=self.expert_pg_size, + ) + self.w2 = ExpertParallel( + nn.Linear( + config.hidden_size, + config.intermediate_size * self.experts_per_rank // self.tp_pg.size(), + bias=False, + ), + expert_parallel_size=self.expert_pg_size, + ) + + if self.tp_pg.size() == 1: + self.w1.module.weight.data = self.w1.module.weight.data.T.contiguous() + + # TODO @nouamane: jit + self.act = ACT2FN[config.hidden_act] + self.sdd = partial(wp.sdd_nt, group=self.tp_pg) if self.tp_pg.size() > 1 else stk.ops.sdd + self.dsd = partial(wp.dsd_nn, group=self.tp_pg) if self.tp_pg.size() > 1 else stk.ops.dsd + + def forward(self, x, topo): + self.w1.scale_gradients(), self.w2.scale_gradients() + x = self.sdd(x.contiguous(), self.w1.module.weight, topo) + activation_fn_out = act_fn(x, self.act) + return self.dsd(activation_fn_out, self.w2.module.weight) + + +class MLP(nn.Module): + def __init__( + self, + config: Config, + parallel_config: Optional[ParallelismArgs], + tp_pg: dist.ProcessGroup, + ): + super().__init__() + + tp_mode = parallel_config.tp_mode if parallel_config is not None else TensorParallelLinearMode.ALL_REDUCE + tp_linear_async_communication = ( + parallel_config.tp_linear_async_communication if parallel_config is not None else False + ) + + self.expert_pg_size = parallel_config.expert_parallel_size + self.experts_per_rank = config.moe_num_experts // min(self.expert_pg_size, config.moe_num_experts) + + assert self.experts_per_rank == 1, "moe.MLP only supports 1 expert per rank, otherwise use moe.SparseMLP" + + self.w1 = ExpertParallel( + TensorParallelColumnLinear( + config.hidden_size, + config.intermediate_size * self.experts_per_rank, + pg=tp_pg, + mode=tp_mode, + bias=False, + async_communication=tp_linear_async_communication, + ), + expert_parallel_size=self.expert_pg_size, + ) + + self.w2 = ExpertParallel( + TensorParallelRowLinear( + config.intermediate_size * self.experts_per_rank, + config.hidden_size, + pg=tp_pg, + mode=tp_mode, + bias=False, + async_communication=tp_linear_async_communication + and tp_mode is TensorParallelLinearMode.REDUCE_SCATTER, + ), + expert_parallel_size=self.expert_pg_size, + ) + + self.w3 = ExpertParallel( + TensorParallelColumnLinear( + config.hidden_size, + config.intermediate_size * self.experts_per_rank, + pg=tp_pg, + mode=tp_mode, + bias=False, + async_communication=tp_linear_async_communication, + ), + expert_parallel_size=self.expert_pg_size, + ) + # TODO @nouamane: jit + self.act = ACT2FN[config.hidden_act] + + def forward(self, hidden_states, topo): # [seq_length, batch_size, hidden_dim] + merged_states = 
self.w1(hidden_states) + hidden_states = self.w2(self.act(merged_states) * self.w3(hidden_states)) + return hidden_states + + +def inclusive_cumsum(x, dim): + scalar = ops.inclusive_cumsum(x, dim) + return scalar.view(1) if not len(scalar.size()) else scalar + + +class SparseGLU(SparseMLP): + def __init__( + self, + config: Config, + parallel_config: Optional[ParallelismArgs], + parallel_context: "ParallelContext", + ): + super().__init__(config, parallel_config, parallel_context) + self.w3 = ExpertParallel( + nn.Linear( + config.hidden_size, + config.intermediate_size * self.experts_per_rank // self.tp_pg.size(), + bias=False, + ), + expert_parallel_size=self.expert_pg_size, + ) + if self.tp_pg.size() == 1: + self.w3.module.weight.data = self.w3.module.weight.data.T.contiguous() + + mark_all_parameters_in_module_as_sharded( + self, + pg=parallel_context.tp_and_expert_pg, + split_config=SplitConfig(split_dim=0), + ) + + def forward(self, x, topo): + # We need to scale gradients manually since we don't call the linears forward + self.w1.scale_gradients(), self.w2.scale_gradients(), self.w3.scale_gradients() + x = x.contiguous() + x1 = self.sdd(x, self.w1.module.weight, topo) + x2 = self.sdd(x, self.w3.module.weight, topo) + x = stk.ops.mul(act_fn(x1, self.act), x2) + return self.dsd(x, self.w2.module.weight) \ No newline at end of file From fa06c0de5a31581d4a1c0d3d48e8700aa50c89da Mon Sep 17 00:00:00 2001 From: Alexander Hagele Date: Mon, 5 Aug 2024 11:04:41 +0200 Subject: [PATCH 17/51] add todo --- src/nanotron/models/moe.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/src/nanotron/models/moe.py b/src/nanotron/models/moe.py index 416ba5e3..10f00194 100644 --- a/src/nanotron/models/moe.py +++ b/src/nanotron/models/moe.py @@ -1,5 +1,7 @@ """ MoEs Blocks to replace MLPs in Transformers. 
""" +# TODO: implement gpt3 style MLP (currently it's only SwiGLU with 3 weight matrices) + import warnings from functools import partial from typing import Optional, Tuple From a9dba53a14a9b149800e04efc4bd1065d9267791 Mon Sep 17 00:00:00 2001 From: Alexander Hagele Date: Mon, 5 Aug 2024 11:05:26 +0200 Subject: [PATCH 18/51] gpt3_moe basis --- src/nanotron/config/models_config.py | 54 +++++- src/nanotron/models/gpt3_moe.py | 274 +++++++++++++++++++++++++++ src/nanotron/trainer.py | 2 + 3 files changed, 329 insertions(+), 1 deletion(-) create mode 100644 src/nanotron/models/gpt3_moe.py diff --git a/src/nanotron/config/models_config.py b/src/nanotron/config/models_config.py index af7db5cc..0dfbbf44 100644 --- a/src/nanotron/config/models_config.py +++ b/src/nanotron/config/models_config.py @@ -177,5 +177,57 @@ def as_starcoder2(self) -> Starcoder2Config: def n_inner(self): return self.intermediate_size +@dataclass +class GPT3MoEConfig: + """Configuration for a GPT3 __MoE__ model""" + + activation_function: str = "gelu" + attn_pdrop: float = 0.1 + embd_pdrop: float = 0.1 + eos_token_id: int = 49152 + hidden_size: int = 2048 + intermediate_size: Optional[int] = None + layer_norm_epsilon: float = 1e-05 + max_position_embeddings: int = 4096 + num_attention_heads: int = 16 + num_hidden_layers: int = 24 + resid_pdrop: float = 0.1 + scale_attention_softmax_in_fp32: bool = True + scale_attn_weights: bool = True + vocab_size: int = 49280 + sinusoidal_position_embedding: bool = True + position_embedding_offset: int = 2 + use_spda: bool = False + act_pdrop: float = 0.0 + scale_embedding: bool = True + # MoE specific + is_moe: bool = True + moe_num_experts: int = 1 + num_experts_per_tok: int = 1 + moe_loss_weight: float = 0.01 + moe_z_loss_weight: float = 0.001 + + + def as_gpt3(self) -> GPT3Config: + config = dict(**vars(self)) + + # Moe + del config["is_moe"] + del config["moe_num_experts"] + del config["num_experts_per_tok"] + del config["moe_loss_weight"] + del config["moe_z_loss_weight"] + + if "_is_using_mup" in config: + del config["_is_using_mup"] + return GPT3Config( + **config + ) + + @property + def n_inner(self): + return self.intermediate_size + + -NanotronConfigs = LlamaConfig | Starcoder2Config | GPT3Config +NanotronConfigs = LlamaConfig | Starcoder2Config | GPT3Config | GPT3MoEConfig diff --git a/src/nanotron/models/gpt3_moe.py b/src/nanotron/models/gpt3_moe.py new file mode 100644 index 00000000..6fdb4669 --- /dev/null +++ b/src/nanotron/models/gpt3_moe.py @@ -0,0 +1,274 @@ +"""PyTorch GPT-3 MoE model.""" + +import math +from contextlib import contextmanager +from typing import Dict, Optional, Union + +import torch +from torch import nn +from torch.nn import functional as F + +from nanotron import distributed as dist +from nanotron.config import GPT3MoEConfig, ParallelismArgs, GPT3Config +from nanotron.generation.generate_store import AttachableStore +from nanotron.models import gpt3 +from nanotron.models.moe import ( + dMoE, +) +from nanotron.models.gpt3 import CausalSelfAttention, GPTModel, PositionEmbedding, dropout_add_fused_train, GPT3ForTraining +from nanotron.models.gpt3 import GPTBlock as GPT3Block +from nanotron.nn.layer_norm import TritonLayerNorm +from nanotron.parallel import ParallelContext +from nanotron.parallel.pipeline_parallel.block import PipelineBlock +from nanotron.parallel.pipeline_parallel.tensor_pointer import TensorPointer +from nanotron.parallel.tensor_parallel.enum import TensorParallelLinearMode +from nanotron.parallel.tensor_parallel.nn import 
TensorParallelColumnLinear, TensorParallelEmbedding +from nanotron.random import RandomStates, branch_random_state + + +@contextmanager +def replace_decoder(gpt3config: GPT3MoEConfig): + orig = gpt3.PipelineBlock + try: + + def create_pp_block(module_builder, module_kwargs, **kwargs): + if module_builder is GPT3Block: + # GPT3's GPT module is trying to instantiate a GPT3 GPTBlock. + # Let's return a PipelineBlock with a GPT3Block instead. + # This also requires to replace starcoders2's config with gpt3's config. + module_kwargs["config"] = gpt3config + return orig(module_builder=GPTBlock, module_kwargs=module_kwargs, **kwargs) + # Else, they are setting up other modules, which we also want unchanged. + return orig(module_builder=module_builder, module_kwargs=module_kwargs, **kwargs) + + gpt3.PipelineBlock = create_pp_block + yield + finally: + gpt3.PipelineBlock = orig + + +@contextmanager +def replace_gpt3model(gpt3moeconfig: GPT3MoEConfig): + orig = gpt3.GPTModel + try: + + def create_gptmodel( + config: GPT3Config, + parallel_context: ParallelContext, + parallel_config: Optional[ParallelismArgs], + random_states: RandomStates, + ): + return GPT3MoEModel(gpt3moeconfig, parallel_context, parallel_config, random_states) + + gpt3.GPTModel = create_gptmodel + yield + finally: + gpt3.GPTModel = orig + +class GPTBlock(nn.Module): + def __init__( + self, + config: GPT3MoEConfig, + parallel_config: Optional[ParallelismArgs], + parallel_context: ParallelContext, + tp_pg: dist.ProcessGroup, + random_states: RandomStates, + layer_idx: int, + ): + super(GPTBlock, self).__init__() + self.ln_1 = TritonLayerNorm(config.hidden_size, eps=config.layer_norm_epsilon) + self.attn = CausalSelfAttention( + config=config, parallel_config=parallel_config, tp_pg=tp_pg, layer_idx=layer_idx + ) + self.attn_dropout = config.attn_pdrop + + self.ln_2 = TritonLayerNorm(config.hidden_size, eps=config.layer_norm_epsilon) + + self.ff = dMoE( + config=config, + parallel_config=parallel_config, + parallel_context=parallel_context, + ) + self.ff_dropout = config.resid_pdrop + self.random_states = random_states + self.tp_mode = parallel_config.tp_mode if parallel_config is not None else TensorParallelLinearMode.ALL_REDUCE + + + def forward( + self, + hidden_states: torch.Tensor | TensorPointer, + sequence_mask: torch.Tensor | TensorPointer, + aux_losses: Dict[str, Union[torch.Tensor, TensorPointer]], + ) -> dict[str, torch.Tensor | TensorPointer]: + + residual = hidden_states + hidden_states = self.ln_1(hidden_states) + # hidden_states = torch.arange(hidden_states.numel()).to(hidden_states.device).to(hidden_states.dtype).view(hidden_states.size()) + output = self.attn(hidden_states=hidden_states, sequence_mask=sequence_mask) + hidden_states = output["hidden_states"] + # return {"hidden_states": hidden_states, "sequence_mask": sequence_mask} + + if self.training: + with branch_random_state( + self.random_states, "tp_synced", enabled=self.tp_mode is TensorParallelLinearMode.ALL_REDUCE + ): + hidden_states = dropout_add_fused_train(hidden_states, residual=residual, prob=self.attn_dropout) + else: + # No need for random state context manager + hidden_states = hidden_states + residual + + residual = hidden_states + hidden_states = self.ln_2(hidden_states) + mlp_output = self.ff(hidden_states=hidden_states) + hidden_states = mlp_output["hidden_states"] + + for key, value in mlp_output.items(): + if key != "hidden_states": + aux_losses[key] = aux_losses[key] + value + + if self.training: + with branch_random_state( + 
self.random_states, "tp_synced", enabled=self.tp_mode is TensorParallelLinearMode.ALL_REDUCE + ): + hidden_states = dropout_add_fused_train(hidden_states, residual=residual, prob=self.ff_dropout) + else: + # No need for random state context manager + hidden_states = hidden_states + residual + + return { + "hidden_states": hidden_states, + "sequence_mask": output["sequence_mask"], + "aux_losses": aux_losses + } + +class GPT3MoEModel(GPTModel): + def __init__( + self, + config: GPT3MoEConfig, + parallel_context: ParallelContext, + parallel_config: Optional[ParallelismArgs], + random_states: RandomStates, + ): + with replace_decoder(config): + super().__init__(config.as_gpt3(), parallel_context, parallel_config, random_states) + + # need to adapt the decoder list because we pass the aux_losses around + self.decoder = nn.ModuleList( + [ + PipelineBlock( + p2p=self.p2p, + module_builder=GPTBlock, + module_kwargs={ + "config": config, + "parallel_config": parallel_config, + "tp_pg": parallel_context.tp_pg, + "random_states": random_states, + "parallel_context": parallel_context, + "layer_idx": layer_idx, + }, + module_input_keys={"hidden_states", "sequence_mask", "aux_losses"}, + module_output_keys={"hidden_states", "sequence_mask", "aux_losses"}, + ) + for layer_idx in range(config.num_hidden_layers) + ] + ) + def forward( + self, + input_ids: torch.Tensor | TensorPointer, # [batch_size, seq_length] + input_mask: torch.Tensor | TensorPointer, # [batch_size, seq_length] + aux_losses: Dict[str, Union[torch.Tensor, TensorPointer]], + ): + # all tensors are optional as most ranks don't need anything from the dataloader. + + input_embeds = ( + self.token_embeddings(input_ids=input_ids, input_mask=input_mask)["input_embeds"] * self.embed_scale + ) + # TODO: position_ids could be cached. 
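+ # position_ids has shape [batch_size, seq_length], each row running 0..seq_length-1.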
+ position_ids = torch.arange(input_ids.size(1), device="cuda").repeat(input_ids.size(0)).view(*input_ids.size()) + position_embeds = self.position_embeddings(position_ids=position_ids)["position_embeds"] + hidden_states = input_embeds + position_embeds + + with branch_random_state( + self.random_states, "tp_synced", enabled=self.tp_mode == TensorParallelLinearMode.ALL_REDUCE + ): + hidden_states = self.embeds_dropout(input=hidden_states)["hidden_states"] + + hidden_encoder_states = {"hidden_states": hidden_states, "sequence_mask": input_mask, "aux_losses": aux_losses} + for encoder_block in self.decoder: + hidden_encoder_states = encoder_block(**hidden_encoder_states) + # return hidden_encoder_states["hidden_states"] + + hidden_states = self.final_layer_norm(input=hidden_encoder_states["hidden_states"])["hidden_states"] + + sharded_logits = self.lm_head(x=hidden_states)["logits"] + + fp32_sharded_logits = self.cast_to_fp32(x=sharded_logits)["output"] + + return fp32_sharded_logits, hidden_encoder_states["aux_losses"] + + +# TODO: maybe reimplement: +# - get_block_compute_costs +# - get_flops_per_sec +class GPT3MoEForTraining(GPT3ForTraining): + def __init__( + self, + config: GPT3MoEConfig, + parallel_context: ParallelContext, + parallel_config: Optional[ParallelismArgs], + random_states: RandomStates, + ): + with replace_gpt3model(config): + super().__init__(config.as_gpt3(), parallel_context, parallel_config, random_states) + self.config = config + + def forward( + self, + input_ids: Union[torch.Tensor, TensorPointer], + input_mask: Union[torch.Tensor, TensorPointer], + label_ids: Union[torch.Tensor, TensorPointer], + label_mask: Union[torch.Tensor, TensorPointer], + ) -> Dict[str, Union[torch.Tensor, TensorPointer]]: + # aux_losses are used for load balancing in case of MoEs + aux_losses = { + "load_balancing_loss": ( + torch.zeros(1, device=input_ids.device) + if not isinstance(input_ids, TensorPointer) + else TensorPointer(self.input_pp_rank) + ), + "z_loss": ( + torch.zeros(1, device=input_ids.device) + if not isinstance(input_ids, TensorPointer) + else TensorPointer(self.input_pp_rank) + ), + } + output = self.model( + input_ids=input_ids, + input_mask=input_mask, + aux_losses=aux_losses, + ) + loss = self.loss( + sharded_logits=output["sharded_logits"], + label_ids=label_ids, + label_mask=label_mask, + ) + + if isinstance(output['aux_losses'], dict): + for key, value in output["aux_losses"].items(): + loss[key] = value + return loss + + # TODO: adapt with MoE costs + def get_block_compute_costs(self): + """Computes the compute cost of each block in the model so that we can do a better job of load balancing.""" + model_config = self.config + d_ff = model_config.n_inner if model_config.intermediate_size is not None else 4 * model_config.hidden_size + d_qkv = model_config.hidden_size // model_config.num_attention_heads + block_compute_costs = { + # CausalSelfAttention (qkv proj + attn out) + MLP + GPTBlock: 4 * model_config.num_attention_heads * d_qkv * model_config.hidden_size + + 2 * d_ff * model_config.hidden_size, + # This is the last lm_head + TensorParallelColumnLinear: model_config.vocab_size * model_config.hidden_size, + } + return block_compute_costs \ No newline at end of file diff --git a/src/nanotron/trainer.py b/src/nanotron/trainer.py index 3f4c5189..af16e39c 100644 --- a/src/nanotron/trainer.py +++ b/src/nanotron/trainer.py @@ -57,6 +57,7 @@ from nanotron.models import NanotronModel, build_model from nanotron.models.base import check_model_has_grad from 
nanotron.models.gpt3 import GPT3ForTraining +from nanotron.models.gpt3_moe import GPT3MoEForTraining from nanotron.models.llama import LlamaForTraining, RotaryEmbedding from nanotron.models.starcoder2 import Starcoder2ForTraining from nanotron.optim.clip_grads import clip_grad_norm @@ -105,6 +106,7 @@ "LlamaConfig": LlamaForTraining, "Starcoder2Config": Starcoder2ForTraining, "GPT3Config": GPT3ForTraining, + "GPT3MoEConfig": GPT3MoEForTraining, } try: From 2efffb846488989e7a69fdedf65412083854dd32 Mon Sep 17 00:00:00 2001 From: Alexander Hagele Date: Mon, 5 Aug 2024 11:05:48 +0200 Subject: [PATCH 19/51] add nn.linear to init for moe router --- src/nanotron/scaling/parametrization.py | 1 + 1 file changed, 1 insertion(+) diff --git a/src/nanotron/scaling/parametrization.py b/src/nanotron/scaling/parametrization.py index e6241651..380ad460 100644 --- a/src/nanotron/scaling/parametrization.py +++ b/src/nanotron/scaling/parametrization.py @@ -34,6 +34,7 @@ class StandardParametrizator(Parametrizator): def __init__(self, config: ModelArgs): super().__init__(config) self.MODULE_TO_PARAMETRIZE = { + nn.Linear: self._parametrize_column_linear, TensorParallelColumnLinear: self._parametrize_column_linear, TensorParallelRowLinear: self._parametrize_row_linear, TritonRMSNorm: self._parametrize_layer_norm, From 57c58b7ac60831921911506cbe17ff6cedafd06b Mon Sep 17 00:00:00 2001 From: Alexander Hagele Date: Mon, 5 Aug 2024 11:06:12 +0200 Subject: [PATCH 20/51] changes to pipeline for backward through aux losses --- .../parallel/pipeline_parallel/engine.py | 32 +++++++++++-------- 1 file changed, 19 insertions(+), 13 deletions(-) diff --git a/src/nanotron/parallel/pipeline_parallel/engine.py b/src/nanotron/parallel/pipeline_parallel/engine.py index ca9df312..91840a5e 100644 --- a/src/nanotron/parallel/pipeline_parallel/engine.py +++ b/src/nanotron/parallel/pipeline_parallel/engine.py @@ -47,14 +47,18 @@ def forward( if not isinstance(output, dict): output = {"loss": output} - # We normalize our loss - if not isinstance(output["loss"], TensorPointer): - output["loss"] = output["loss"] / self.nb_microbatches - - # Add output as activations that require backward pass - if not isinstance(output["loss"], TensorPointer): - assert output["loss"].requires_grad - state.register_activation_requiring_backward(output["loss"]) + for k, v in output.items(): + if not isinstance(v, TensorPointer): + output[k] = v / self.nb_microbatches + + # the outputs are either + # - token prediction loss ["loss"] + # - auxiliary losses ["load_balancing_loss", "z_loss"] + # that we need to backpropagate through, so register activations + for loss_key, output_tensor in output.items(): + if not isinstance(output_tensor, TensorPointer): + assert output_tensor.requires_grad + state.register_activation_requiring_backward(output_tensor) return output @staticmethod @@ -154,7 +158,7 @@ def validate_batch_iter( if not isinstance(output, dict): output = {"loss": output} - # Store the loss for each microbatch + # Store the loss(es) for each microbatch if not isinstance(output["loss"], TensorPointer): output = {k: v.detach() for k, v in output.items()} outputs.append(output) @@ -269,8 +273,9 @@ def train_batch_iter( send_activation() # Store the loss for each microbatch - if not isinstance(output["loss"], TensorPointer): - output = {k: v.detach() for k, v in output.items()} + for k, v in output.items(): + if not isinstance(v, TensorPointer): + output[k] = v.detach() outputs.append(output) for micro_batch in batch: @@ -282,8 +287,9 @@ def 
train_batch_iter( output = {"loss": output} # Store the loss for each microbatch - if not isinstance(output["loss"], TensorPointer): - output = {k: v.detach() for k, v in output.items()} + for k, v in output.items(): + if not isinstance(v, TensorPointer): + output[k] = v.detach() outputs.append(output) # One backward From 3967beeb65392a9dd9ae66751ea1ce791bea7a6d Mon Sep 17 00:00:00 2001 From: Alexander Hagele Date: Mon, 5 Aug 2024 11:50:53 +0200 Subject: [PATCH 21/51] correct block costs and flops --- src/nanotron/models/gpt3_moe.py | 121 ++++++++++++++++++++++++++++++-- 1 file changed, 114 insertions(+), 7 deletions(-) diff --git a/src/nanotron/models/gpt3_moe.py b/src/nanotron/models/gpt3_moe.py index 6fdb4669..2dc997bc 100644 --- a/src/nanotron/models/gpt3_moe.py +++ b/src/nanotron/models/gpt3_moe.py @@ -207,9 +207,6 @@ def forward( return fp32_sharded_logits, hidden_encoder_states["aux_losses"] -# TODO: maybe reimplement: -# - get_block_compute_costs -# - get_flops_per_sec class GPT3MoEForTraining(GPT3ForTraining): def __init__( self, @@ -258,17 +255,127 @@ def forward( loss[key] = value return loss - # TODO: adapt with MoE costs def get_block_compute_costs(self): """Computes the compute cost of each block in the model so that we can do a better job of load balancing.""" model_config = self.config d_ff = model_config.n_inner if model_config.intermediate_size is not None else 4 * model_config.hidden_size d_qkv = model_config.hidden_size // model_config.num_attention_heads + # active experts + routing + mlp_cost = 2 * d_ff * model_config.hidden_size * model_config.num_experts_per_tok \ + + model_config.hidden_size * model_config.moe_num_experts + att_cost = 4 * model_config.num_attention_heads * d_qkv * model_config.hidden_size block_compute_costs = { # CausalSelfAttention (qkv proj + attn out) + MLP - GPTBlock: 4 * model_config.num_attention_heads * d_qkv * model_config.hidden_size - + 2 * d_ff * model_config.hidden_size, + GPTBlock: att_cost + mlp_cost, # This is the last lm_head TensorParallelColumnLinear: model_config.vocab_size * model_config.hidden_size, } - return block_compute_costs \ No newline at end of file + return block_compute_costs + + def get_flops_per_sec(self, iteration_time_in_sec, sequence_length, global_batch_size): + """Get flops per second for a given model""" + world_size = self.parallel_context.world_pg.size() + model_flops, hardware_flops = get_flops( + num_layers=self.config.num_hidden_layers, + hidden_size=self.config.hidden_size, + num_heads=self.config.num_attention_heads, + vocab_size=self.config.vocab_size, + ffn_hidden_size=self.config.n_inner if self.config.n_inner is not None else 4 * self.config.hidden_size, + seq_len=sequence_length, + batch_size=global_batch_size, + kv_channels=None, + glu_activation=False, + num_experts=self.config.moe_num_experts, + num_experts_per_tok=self.config.num_experts_per_tok, + ) + model_flops_per_s = model_flops / (iteration_time_in_sec * world_size * 1e12) + hardware_flops_per_s = hardware_flops / (iteration_time_in_sec * world_size * 1e12) + return model_flops_per_s, hardware_flops_per_s + + +def get_flops( + num_layers, + hidden_size, + num_heads, + vocab_size, + seq_len, + kv_channels=None, + ffn_hidden_size=None, + batch_size=1, + glu_activation=False, + num_experts=1, + num_experts_per_tok=1, +): + """Counts flops in an decoder-only model + Args: + num_layers: number of decoder layers + hidden_size: hidden size of the model + num_heads: number of heads in the model + kv_channels: hidden size of the key and value 
heads + ffn_hidden_size: hidden size of the FFN + vocab_size: size of the vocabulary + seq_len: sequence length of the decoder + batch_size: batch size + glu_activation: Whether to use GLU activation in FFN. Check T5 v1.1 for more info. + num_experts_per_tok: number of experts per token in the MoE layer + Returns: + model_flops: flops in the model (should be independent of the hardware and model implementation) + hardware_flops: flops in the hardware (actual flops performed on the hardware). Check 6.3 in https://arxiv.org/pdf/2205.05198.pdf + """ + + if kv_channels is None: + assert hidden_size % num_heads == 0 + kv_channels = hidden_size // num_heads + if ffn_hidden_size is None: + ffn_hidden_size = 4 * hidden_size + + # In the following we mark the reduced dimension with parentheses + # decoder + # self attention (MQA) + ## q projection + decoder_q_proj_flops_fwd = 2 * num_layers * batch_size * seq_len * (hidden_size) * num_heads * kv_channels + ## kv projection, shared across heads + decoder_kv_proj_flops_fwd = 2 * num_layers * batch_size * seq_len * (hidden_size) * 2 * kv_channels + ## qk logits + decoder_qk_logits_flops_fwd = 2 * num_layers * batch_size * num_heads * seq_len * (kv_channels) * seq_len + ### SWA (sliding window attention / local attention) + # window_size = 4096 + # decoder_qk_logits_flops_fwd = 2 * num_layers * batch_size * num_heads * seq_len * (kv_channels) * window_size + ## v logits + decoder_v_logits_flops_fwd = 2 * num_layers * batch_size * num_heads * seq_len * (seq_len) * kv_channels + # decoder_v_logits_flops_fwd = 2 * num_layers * batch_size * num_heads * seq_len * (window_size) * kv_channels + ## attn out + decoder_attn_out_flops_fwd = 2 * num_layers * batch_size * num_heads * seq_len * (kv_channels) * hidden_size + # FF + ## 1st layer + decoder_ffn_1_flops_fwd = 2 * num_layers * batch_size * seq_len * (hidden_size) * ffn_hidden_size + if glu_activation: + # 3 matmuls instead of 2 in FFN + # ref. 
https://arxiv.org/pdf/2002.05202.pdf + # Used for example in T5 v1.1 + decoder_ffn_1_flops_fwd = 4 * num_layers * batch_size * seq_len * (hidden_size) * ffn_hidden_size + ## 2nd layer + decoder_ffn_2_flops_fwd = 2 * num_layers * batch_size * seq_len * (ffn_hidden_size) * hidden_size + # MoE router + decoder_ffn_router_flops_fwd = 2 * num_layers * batch_size * seq_len * (hidden_size) * num_experts + + decoder_flops_fwd = ( + decoder_q_proj_flops_fwd + + decoder_kv_proj_flops_fwd + + decoder_qk_logits_flops_fwd + + decoder_v_logits_flops_fwd + + decoder_attn_out_flops_fwd + + decoder_ffn_1_flops_fwd * num_experts_per_tok + + decoder_ffn_2_flops_fwd * num_experts_per_tok + + decoder_ffn_router_flops_fwd + ) + + # lm head + lm_head_flops_fwd = 2 * batch_size * seq_len * (hidden_size) * vocab_size + + # the bwd pass requires double the flops in case of matmuls to calculate the gradients with respect to + # both input and weight tensors + model_flops = 3 * (decoder_flops_fwd + lm_head_flops_fwd) # 1 for fwd + 2 for bwd + + hardware_flops = model_flops # TODO @nouamanetazi: This is a placeholder for now + return model_flops, hardware_flops From bcb94cc57a2fee947c800a5d09dab02aec12bf8b Mon Sep 17 00:00:00 2001 From: Alexander Hagele Date: Mon, 5 Aug 2024 12:11:39 +0200 Subject: [PATCH 22/51] case of dict in pipelineblock --- .../parallel/pipeline_parallel/block.py | 19 +++++++++++++++++++ 1 file changed, 19 insertions(+) diff --git a/src/nanotron/parallel/pipeline_parallel/block.py b/src/nanotron/parallel/pipeline_parallel/block.py index 150172f5..57bd8e0a 100644 --- a/src/nanotron/parallel/pipeline_parallel/block.py +++ b/src/nanotron/parallel/pipeline_parallel/block.py @@ -81,6 +81,25 @@ def forward(self, **kwargs): if isinstance(tensor, TensorPointer): # Current rank is neither the rank holding the data nor the rank responsible for computing block continue + elif isinstance(tensor, dict): + for k, v in tensor.items(): + if isinstance(v, torch.Tensor): + # We need to send the tensor to the rank that actually runs the compute + if self.pipeline_state is not None: + send_to_pipeline_state_buffer( + v, + to_rank=self.rank, + p2p=self.p2p, + pipeline_state=self.pipeline_state, + ) + continue + + if v.requires_grad is True: + raise ValueError( + f"Pipeline engine is None and tensor requires grad. Tried sending a tensor to {self.rank}. Usually that means that your model is pipeline sharded and you haven't chosen a specific pipeline engine." 
+ ) + + batch_send_recv.add_send(tensor=v, to_rank=self.rank) else: assert isinstance(tensor, torch.Tensor) # We need to send the tensor to the rank that actually runs the compute From 91acdc03f7b6435caacaddf678931c5220160c9e Mon Sep 17 00:00:00 2001 From: Alexander Hagele Date: Mon, 5 Aug 2024 12:38:44 +0200 Subject: [PATCH 23/51] option for GLU or normal MLP --- src/nanotron/config/models_config.py | 1 + src/nanotron/models/moe.py | 59 +++++++++++++++++++++------- 2 files changed, 46 insertions(+), 14 deletions(-) diff --git a/src/nanotron/config/models_config.py b/src/nanotron/config/models_config.py index 0dfbbf44..075419dd 100644 --- a/src/nanotron/config/models_config.py +++ b/src/nanotron/config/models_config.py @@ -206,6 +206,7 @@ class GPT3MoEConfig: num_experts_per_tok: int = 1 moe_loss_weight: float = 0.01 moe_z_loss_weight: float = 0.001 + moe_glu: bool = False def as_gpt3(self) -> GPT3Config: diff --git a/src/nanotron/models/moe.py b/src/nanotron/models/moe.py index 10f00194..6ef9f0f9 100644 --- a/src/nanotron/models/moe.py +++ b/src/nanotron/models/moe.py @@ -230,17 +230,31 @@ def __init__( self.blocking = 128 if self.experts_per_rank == 1: - self.mlp = MLP( - config=config, - parallel_config=parallel_config, - tp_pg=parallel_context.tp_pg, - ) + if config.moe_glu: + self.mlp = GLU( + config=config, + parallel_config=parallel_config, + tp_pg=parallel_context.tp_pg, + ) + else: + self.mlp = MLP( + config=config, + parallel_config=parallel_config, + tp_pg=parallel_context.tp_pg, + ) else: - self.mlp = SparseGLU( - config=config, - parallel_config=parallel_config, - parallel_context=parallel_context, - ) + if config.moe_glu: + self.mlp = SparseGLU( + config=config, + parallel_config=parallel_config, + parallel_context=parallel_context, + ) + else: + self.mlp = SparseMLP( + config=config, + parallel_config=parallel_config, + parallel_context=parallel_context, + ) max_column_index = (self.config.intermediate_size * self.num_experts) // self.blocking self.transpose_sort_end_bit = max(int(np.ceil(np.log2(max_column_index))), 1) @@ -630,6 +644,26 @@ def __init__( expert_parallel_size=self.expert_pg_size, ) + # TODO @nouamane: jit + self.act = ACT2FN[config.hidden_act] + + def forward(self, hidden_states, topo): # [seq_length, batch_size, hidden_dim] + merged_states = self.w1(hidden_states) + hidden_states = self.w2(self.act(merged_states)) + return hidden_states + +class GLU(MLP): + def __init__( + self, + config: Config, + parallel_config: Optional[ParallelismArgs], + tp_pg: dist.ProcessGroup, + ): + super().__init__(config, parallel_config, tp_pg) + tp_mode = parallel_config.tp_mode if parallel_config is not None else TensorParallelLinearMode.ALL_REDUCE + tp_linear_async_communication = ( + parallel_config.tp_linear_async_communication if parallel_config is not None else False + ) self.w3 = ExpertParallel( TensorParallelColumnLinear( config.hidden_size, @@ -641,15 +675,12 @@ def __init__( ), expert_parallel_size=self.expert_pg_size, ) - # TODO @nouamane: jit - self.act = ACT2FN[config.hidden_act] - def forward(self, hidden_states, topo): # [seq_length, batch_size, hidden_dim] + def forward(self, x, topo): merged_states = self.w1(hidden_states) hidden_states = self.w2(self.act(merged_states) * self.w3(hidden_states)) return hidden_states - def inclusive_cumsum(x, dim): scalar = ops.inclusive_cumsum(x, dim) return scalar.view(1) if not len(scalar.size()) else scalar From df3befc58a92ab8a7385105ea4cc278f2f532802 Mon Sep 17 00:00:00 2001 From: Alex Hagele Date: Tue, 6 Aug 2024 15:59:04 
+0000 Subject: [PATCH 24/51] init of linear layer in starcoder --- src/nanotron/config/models_config.py | 33 ++++++++++-- src/nanotron/models/gpt3_moe.py | 75 +++++++++++++--------------- src/nanotron/models/starcoder2.py | 7 +++ 3 files changed, 72 insertions(+), 43 deletions(-) diff --git a/src/nanotron/config/models_config.py b/src/nanotron/config/models_config.py index 075419dd..257a2f72 100644 --- a/src/nanotron/config/models_config.py +++ b/src/nanotron/config/models_config.py @@ -177,6 +177,7 @@ def as_starcoder2(self) -> Starcoder2Config: def n_inner(self): return self.intermediate_size + @dataclass class GPT3MoEConfig: """Configuration for a GPT3 __MoE__ model""" @@ -208,27 +209,51 @@ class GPT3MoEConfig: moe_z_loss_weight: float = 0.001 moe_glu: bool = False - def as_gpt3(self) -> GPT3Config: config = dict(**vars(self)) - # Moe + # Moe + del config["is_moe"] + del config["moe_num_experts"] + del config["num_experts_per_tok"] + del config["moe_loss_weight"] + del config["moe_z_loss_weight"] + del config["moe_glu"] + + if "_is_using_mup" in config: + del config["_is_using_mup"] + return GPT3Config(**config) + + def as_starcoder2(self) -> Starcoder2Config: + # same as gpt3 conversion above + config = dict(**vars(self)) + del config["sinusoidal_position_embedding"] + del config["use_spda"] + del config["position_embedding_offset"] + del config["act_pdrop"] + del config["scale_embedding"] + + # Moe del config["is_moe"] del config["moe_num_experts"] del config["num_experts_per_tok"] del config["moe_loss_weight"] del config["moe_z_loss_weight"] + del config["moe_glu"] if "_is_using_mup" in config: del config["_is_using_mup"] - return GPT3Config( - **config + return Starcoder2Config( + grouped_query=True, num_kv_heads=self.num_attention_heads, use_rotary_embeddings=False, **config ) @property def n_inner(self): return self.intermediate_size + @property + def hidden_act(self): + return self.activation_function NanotronConfigs = LlamaConfig | Starcoder2Config | GPT3Config | GPT3MoEConfig diff --git a/src/nanotron/models/gpt3_moe.py b/src/nanotron/models/gpt3_moe.py index 2dc997bc..0e2add58 100644 --- a/src/nanotron/models/gpt3_moe.py +++ b/src/nanotron/models/gpt3_moe.py @@ -1,33 +1,30 @@ """PyTorch GPT-3 MoE model.""" -import math from contextlib import contextmanager from typing import Dict, Optional, Union import torch from torch import nn -from torch.nn import functional as F from nanotron import distributed as dist -from nanotron.config import GPT3MoEConfig, ParallelismArgs, GPT3Config -from nanotron.generation.generate_store import AttachableStore +from nanotron.config import GPT3Config, GPT3MoEConfig, ParallelismArgs from nanotron.models import gpt3 +from nanotron.models.gpt3 import CausalSelfAttention, GPT3ForTraining, GPT3Model, dropout_add_fused_train +from nanotron.models.gpt3 import GPTBlock as GPT3Block from nanotron.models.moe import ( dMoE, ) -from nanotron.models.gpt3 import CausalSelfAttention, GPTModel, PositionEmbedding, dropout_add_fused_train, GPT3ForTraining -from nanotron.models.gpt3 import GPTBlock as GPT3Block from nanotron.nn.layer_norm import TritonLayerNorm from nanotron.parallel import ParallelContext from nanotron.parallel.pipeline_parallel.block import PipelineBlock from nanotron.parallel.pipeline_parallel.tensor_pointer import TensorPointer from nanotron.parallel.tensor_parallel.enum import TensorParallelLinearMode -from nanotron.parallel.tensor_parallel.nn import TensorParallelColumnLinear, TensorParallelEmbedding +from nanotron.parallel.tensor_parallel.nn 
import TensorParallelColumnLinear from nanotron.random import RandomStates, branch_random_state @contextmanager -def replace_decoder(gpt3config: GPT3MoEConfig): +def replace_moe_decoder(gpt3config: GPT3MoEConfig): orig = gpt3.PipelineBlock try: @@ -37,7 +34,7 @@ def create_pp_block(module_builder, module_kwargs, **kwargs): # Let's return a PipelineBlock with a GPT3Block instead. # This also requires to replace starcoders2's config with gpt3's config. module_kwargs["config"] = gpt3config - return orig(module_builder=GPTBlock, module_kwargs=module_kwargs, **kwargs) + return orig(module_builder=GPT3MoEBlock, module_kwargs=module_kwargs, **kwargs) # Else, they are setting up other modules, which we also want unchanged. return orig(module_builder=module_builder, module_kwargs=module_kwargs, **kwargs) @@ -48,11 +45,11 @@ def create_pp_block(module_builder, module_kwargs, **kwargs): @contextmanager -def replace_gpt3model(gpt3moeconfig: GPT3MoEConfig): - orig = gpt3.GPTModel +def replace_gpt3_moe_model(gpt3moeconfig: GPT3MoEConfig): + orig = gpt3.GPT3Model try: - def create_gptmodel( + def create_moe_model( config: GPT3Config, parallel_context: ParallelContext, parallel_config: Optional[ParallelismArgs], @@ -60,12 +57,13 @@ def create_gptmodel( ): return GPT3MoEModel(gpt3moeconfig, parallel_context, parallel_config, random_states) - gpt3.GPTModel = create_gptmodel + gpt3.GPT3Model = create_moe_model yield finally: - gpt3.GPTModel = orig + gpt3.GPT3Model = orig -class GPTBlock(nn.Module): + +class GPT3MoEBlock(nn.Module): def __init__( self, config: GPT3MoEConfig, @@ -75,7 +73,7 @@ def __init__( random_states: RandomStates, layer_idx: int, ): - super(GPTBlock, self).__init__() + super(GPT3MoEBlock, self).__init__() self.ln_1 = TritonLayerNorm(config.hidden_size, eps=config.layer_norm_epsilon) self.attn = CausalSelfAttention( config=config, parallel_config=parallel_config, tp_pg=tp_pg, layer_idx=layer_idx @@ -83,17 +81,16 @@ def __init__( self.attn_dropout = config.attn_pdrop self.ln_2 = TritonLayerNorm(config.hidden_size, eps=config.layer_norm_epsilon) - + self.ff = dMoE( - config=config, - parallel_config=parallel_config, - parallel_context=parallel_context, + config=config, + parallel_config=parallel_config, + parallel_context=parallel_context, ) self.ff_dropout = config.resid_pdrop self.random_states = random_states self.tp_mode = parallel_config.tp_mode if parallel_config is not None else TensorParallelLinearMode.ALL_REDUCE - def forward( self, hidden_states: torch.Tensor | TensorPointer, @@ -135,13 +132,10 @@ def forward( # No need for random state context manager hidden_states = hidden_states + residual - return { - "hidden_states": hidden_states, - "sequence_mask": output["sequence_mask"], - "aux_losses": aux_losses - } + return {"hidden_states": hidden_states, "sequence_mask": output["sequence_mask"], "aux_losses": aux_losses} + -class GPT3MoEModel(GPTModel): +class GPT3MoEModel(GPT3Model): def __init__( self, config: GPT3MoEConfig, @@ -149,15 +143,15 @@ def __init__( parallel_config: Optional[ParallelismArgs], random_states: RandomStates, ): - with replace_decoder(config): + with replace_moe_decoder(config): super().__init__(config.as_gpt3(), parallel_context, parallel_config, random_states) - + # need to adapt the decoder list because we pass the aux_losses around self.decoder = nn.ModuleList( [ PipelineBlock( p2p=self.p2p, - module_builder=GPTBlock, + module_builder=GPT3MoEBlock, module_kwargs={ "config": config, "parallel_config": parallel_config, @@ -172,6 +166,7 @@ def __init__( for 
layer_idx in range(config.num_hidden_layers) ] ) + def forward( self, input_ids: torch.Tensor | TensorPointer, # [batch_size, seq_length] @@ -204,7 +199,7 @@ def forward( fp32_sharded_logits = self.cast_to_fp32(x=sharded_logits)["output"] - return fp32_sharded_logits, hidden_encoder_states["aux_losses"] + return {"sharded_logits": fp32_sharded_logits, "aux_losses": hidden_encoder_states["aux_losses"]} class GPT3MoEForTraining(GPT3ForTraining): @@ -215,7 +210,7 @@ def __init__( parallel_config: Optional[ParallelismArgs], random_states: RandomStates, ): - with replace_gpt3model(config): + with replace_gpt3_moe_model(config): super().__init__(config.as_gpt3(), parallel_context, parallel_config, random_states) self.config = config @@ -249,29 +244,31 @@ def forward( label_ids=label_ids, label_mask=label_mask, ) - - if isinstance(output['aux_losses'], dict): + + if isinstance(output["aux_losses"], dict): for key, value in output["aux_losses"].items(): loss[key] = value return loss - + def get_block_compute_costs(self): """Computes the compute cost of each block in the model so that we can do a better job of load balancing.""" model_config = self.config d_ff = model_config.n_inner if model_config.intermediate_size is not None else 4 * model_config.hidden_size d_qkv = model_config.hidden_size // model_config.num_attention_heads # active experts + routing - mlp_cost = 2 * d_ff * model_config.hidden_size * model_config.num_experts_per_tok \ + mlp_cost = ( + 2 * d_ff * model_config.hidden_size * model_config.num_experts_per_tok + model_config.hidden_size * model_config.moe_num_experts + ) att_cost = 4 * model_config.num_attention_heads * d_qkv * model_config.hidden_size block_compute_costs = { # CausalSelfAttention (qkv proj + attn out) + MLP - GPTBlock: att_cost + mlp_cost, + GPT3MoEBlock: att_cost + mlp_cost, # This is the last lm_head TensorParallelColumnLinear: model_config.vocab_size * model_config.hidden_size, } return block_compute_costs - + def get_flops_per_sec(self, iteration_time_in_sec, sequence_length, global_batch_size): """Get flops per second for a given model""" world_size = self.parallel_context.world_pg.size() @@ -291,7 +288,7 @@ def get_flops_per_sec(self, iteration_time_in_sec, sequence_length, global_batch model_flops_per_s = model_flops / (iteration_time_in_sec * world_size * 1e12) hardware_flops_per_s = hardware_flops / (iteration_time_in_sec * world_size * 1e12) return model_flops_per_s, hardware_flops_per_s - + def get_flops( num_layers, diff --git a/src/nanotron/models/starcoder2.py b/src/nanotron/models/starcoder2.py index 7100351d..b05a67bb 100644 --- a/src/nanotron/models/starcoder2.py +++ b/src/nanotron/models/starcoder2.py @@ -1517,6 +1517,13 @@ def init_model_randomly(self, config): module.bias.zero_() else: raise ValueError(f"Who the fuck is {param_name}?") + elif isinstance(module, nn.Linear): + if "weight" == param_name: + nn.init.normal_(module.weight, mean=0.0, std=std) + elif "bias" == param_name: + module.bias.zero_() + else: + raise ValueError(f"Who the fuck is {param_name}?") elif isinstance(module, TensorParallelRowLinear): if "weight" == param_name: nn.init.normal_(module.weight, mean=0.0, std=sigma / math.sqrt(2 * num_layers)) From 6edce837e19f378fb1e5fd9887aa039c48ddfa19 Mon Sep 17 00:00:00 2001 From: Alexander Hagele Date: Thu, 8 Aug 2024 14:49:56 +0200 Subject: [PATCH 25/51] potential bug in pipeline block --- .../parallel/pipeline_parallel/block.py | 24 ++++++++++++++++++- 1 file changed, 23 insertions(+), 1 deletion(-) diff --git 
a/src/nanotron/parallel/pipeline_parallel/block.py b/src/nanotron/parallel/pipeline_parallel/block.py index 57bd8e0a..4e8cfeb5 100644 --- a/src/nanotron/parallel/pipeline_parallel/block.py +++ b/src/nanotron/parallel/pipeline_parallel/block.py @@ -93,7 +93,6 @@ def forward(self, **kwargs): pipeline_state=self.pipeline_state, ) continue - if v.requires_grad is True: raise ValueError( f"Pipeline engine is None and tensor requires grad. Tried sending a tensor to {self.rank}. Usually that means that your model is pipeline sharded and you haven't chosen a specific pipeline engine." @@ -152,6 +151,29 @@ def forward(self, **kwargs): # We don't store result in a buffer recv_id = batch_send_recv.add_recv(from_rank=tensor.group_rank) name_to_recv_id[name] = recv_id + elif isinstance(tensor, dict): + new_kwargs[name] = {} + for k, v in tensor.items(): + # the same as above just looped over the dict + if isinstance(v, TensorPointer): + if isinstance(self.pipeline_state, PipelineTrainBatchState): + for _ in range(len(self.pipeline_state.microbatches_activations_to_send)): + send_activation = self.pipeline_state.microbatches_activations_to_send.popleft() + # Execute + send_activation() + + if self.pipeline_state is not None: + new_kwargs[name][k] = recv_from_pipeline_state_buffer( + from_rank=tensor.group_rank, + p2p=self.p2p, + pipeline_state=self.pipeline_state, + ) + continue + # We don't store result in a buffer + recv_id = batch_send_recv.add_recv(from_rank=tensor.group_rank) + name_to_recv_id[name] = recv_id + else: + new_kwargs[name][k] = v else: new_kwargs[name] = tensor From 7425167856b816aee20d5087eee3ded9cbe4d27e Mon Sep 17 00:00:00 2001 From: AleHD Date: Wed, 26 Jun 2024 12:38:55 +0000 Subject: [PATCH 26/51] XGLM work in progress: Causal Attention and Positional Embeddings work --- examples/xglm/__init__.py | 0 examples/xglm/convert_hf2nt.py | 28 ++ examples/xglm/tests/test_attn.py | 74 +++++ examples/xglm/tests/test_implementation.py | 90 ++++++ src/nanotron/config/models_config.py | 36 +++ src/nanotron/models/gpt3.py | 358 +++++++++++++++++++++ 6 files changed, 586 insertions(+) create mode 100644 examples/xglm/__init__.py create mode 100644 examples/xglm/convert_hf2nt.py create mode 100644 examples/xglm/tests/test_attn.py create mode 100644 examples/xglm/tests/test_implementation.py create mode 100644 src/nanotron/models/gpt3.py diff --git a/examples/xglm/__init__.py b/examples/xglm/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/examples/xglm/convert_hf2nt.py b/examples/xglm/convert_hf2nt.py new file mode 100644 index 00000000..e008f859 --- /dev/null +++ b/examples/xglm/convert_hf2nt.py @@ -0,0 +1,28 @@ +import torch + +from transformers.models.xglm.modeling_xglm import XGLMAttention +from nanotron.models.gpt3 import CausalSelfAttention + + +def convert_attention(attn_nt: CausalSelfAttention, attn_hf: XGLMAttention): + q_ws = torch.chunk(attn_hf.q_proj.weight, attn_hf.num_heads) + k_ws = torch.chunk(attn_hf.k_proj.weight, attn_hf.num_heads) + v_ws = torch.chunk(attn_hf.v_proj.weight, attn_hf.num_heads) + + q_bs = torch.chunk(attn_hf.q_proj.bias, attn_hf.num_heads) + k_bs = torch.chunk(attn_hf.k_proj.bias, attn_hf.num_heads) + v_bs = torch.chunk(attn_hf.v_proj.bias, attn_hf.num_heads) + + qkv_w = [] + qkv_b = [] + for q_w, k_w, v_w, q_b, k_b, v_b in zip(q_ws, k_ws, v_ws, q_bs, k_bs, v_bs): + qkv_w += [q_w, k_w, v_w] + qkv_b += [q_b, k_b, v_b] + qkv_w = torch.cat(qkv_w) + qkv_b = torch.cat(qkv_b) + + with torch.no_grad(): + attn_nt.query_key_value.weight.data = 
qkv_w.clone() + attn_nt.query_key_value.bias.data = qkv_b.clone() + attn_nt.dense.weight.data = attn_hf.out_proj.weight.clone() + attn_nt.dense.bias.data = attn_hf.out_proj.bias.clone() diff --git a/examples/xglm/tests/test_attn.py b/examples/xglm/tests/test_attn.py new file mode 100644 index 00000000..2fcdb3a8 --- /dev/null +++ b/examples/xglm/tests/test_attn.py @@ -0,0 +1,74 @@ +import torch +from torch.nn import functional as F +#torch.Size([4, 2048, 16, 64]), torch.Size([2048, 4, 1024]) + +# inputs = (batchsize * qlen, heads, head_dim) +# outputs = (batchsize*qlen, heads, head_dim) +def sdpa(query, key, value, batchsize: int): + def reshape(tensor): # output = (batchsize, heads, qlen, head_dim) + return tensor.view(batchsize, qlen, heads, head_dim).permute(0, 2, 1, 3) + + batchsize_x_qlen, heads, head_dim = query.size() + qlen = batchsize_x_qlen//batchsize + out = F.scaled_dot_product_attention(reshape(query), reshape(key), reshape(value), is_causal=True) # (b,h,q,d) + return out.permute(0, 2, 1, 3).reshape(batchsize*qlen, heads, head_dim) + + +# inputs = (batchsize * qlen, heads, head_dim) +# outputs = (batchsize*qlen, heads, head_dim) +def fa(query_states, key_states, value_states, batchsize: int): + from flash_attn.flash_attn_interface import flash_attn_varlen_func + + batchsize_x_qlen, heads, head_dim = query_states.size() + qlen = batchsize_x_qlen//batchsize + + q_sequence_mask = torch.ones(batchsize, qlen, dtype=torch.bool, device="cuda") + kv_sequence_mask = torch.ones(batchsize, qlen, dtype=torch.bool, device="cuda") + + # TODO @thomasw21: Compute once, instead of computing for each layers. + cu_seqlens_q = torch.zeros((q_sequence_mask.shape[0] + 1), dtype=torch.int32, device=query_states.device) + cu_seqlens_k = torch.zeros((kv_sequence_mask.shape[0] + 1), dtype=torch.int32, device=query_states.device) + torch.cumsum(q_sequence_mask.sum(-1, dtype=torch.int32), dim=0, dtype=torch.int32, out=cu_seqlens_q[1:]) + torch.cumsum(kv_sequence_mask.sum(-1, dtype=torch.int32), dim=0, dtype=torch.int32, out=cu_seqlens_k[1:]) + + # TODO(kunhao): flash attn's causal means that the query can only attend to the keys before it. This is not + # what we want if we are using kv cache. This is a hack as we always have q_length == 1 when using kv cache. 
+ causal = False if q_sequence_mask.shape[1] == 1 else True + attn_output = flash_attn_varlen_func( + q=query_states, + k=key_states, + v=value_states, + cu_seqlens_q=cu_seqlens_q, + cu_seqlens_k=cu_seqlens_k, + max_seqlen_q=q_sequence_mask.shape[1], + max_seqlen_k=kv_sequence_mask.shape[1], + dropout_p=0.0, + softmax_scale=None, # defaults to 1/sqrt(d_qk) + causal=causal, + window_size=(-1, -1), + return_attn_probs=False, + ) + return attn_output + + +def main(): + batchsize = 5 + qlen = 6 + heads = 2 + head_dim = 16 + + query = torch.randn(batchsize*qlen, heads, head_dim, device="cuda", dtype=torch.bfloat16) + key = torch.randn(batchsize*qlen, heads, head_dim, device="cuda", dtype=torch.bfloat16) + value = torch.randn(batchsize*qlen, heads, head_dim, device="cuda", dtype=torch.bfloat16) + + out_pt = sdpa(query, key, value, batchsize) + out_fa = fa(query, key, value, batchsize) + + assert out_pt.size() == out_fa.size() + + torch.testing.assert_close(out_pt, out_fa) + + + +if __name__ == "__main__": + main() diff --git a/examples/xglm/tests/test_implementation.py b/examples/xglm/tests/test_implementation.py new file mode 100644 index 00000000..10f0302a --- /dev/null +++ b/examples/xglm/tests/test_implementation.py @@ -0,0 +1,90 @@ +import numpy as np +import torch +import pytest + +from transformers.models.xglm.modeling_xglm import XGLMAttention, XGLMSinusoidalPositionalEmbedding + +from nanotron.config.models_config import GPT3Config +from nanotron.models.gpt3 import CausalSelfAttention, PositionEmbedding +from nanotron.parallel import ParallelContext + +from tests.helpers.utils import init_distributed + +from examples.xglm.convert_hf2nt import convert_attention + + +SEQUENCE_LENGTH = 2048 +BATCH_SIZE = 4 +HIDDEN_SIZE = 1024 +DTYPE = torch.float64 + +CONFIG = GPT3Config( + attn_pdrop=0.0, + embd_pdrop=0.0, + resid_pdrop=0.0, + eos_token_id=2, + hidden_size=HIDDEN_SIZE, + intermediate_size=4096, + layer_norm_epsilon=1e-05, + max_position_embeddings=SEQUENCE_LENGTH, + num_attention_heads=16, + num_hidden_layers=24, + scale_attn_weights=True, + vocab_size=256008, + sinusoidal_position_embedding=True, + position_embedding_offset=2, + use_spda=True +) + + +@pytest.fixture +def hidden_states() -> torch.Tensor: + return torch.randn(SEQUENCE_LENGTH, BATCH_SIZE, HIDDEN_SIZE, + dtype=DTYPE) + + +@pytest.fixture +def input_mask() -> torch.Tensor: + return torch.ones(BATCH_SIZE, SEQUENCE_LENGTH, dtype=torch.bool) + + +def _test_attention(parallel_context: ParallelContext, hidden_states: torch.Tensor, sequence_mask: torch.Tensor): + hidden_states = hidden_states.cuda() + sequence_mask = sequence_mask.cuda() + + attn_nt = CausalSelfAttention(CONFIG, None, parallel_context.tp_pg, 0).cuda().eval().to(DTYPE) + attn_hf = XGLMAttention(CONFIG.hidden_size, CONFIG.num_attention_heads, CONFIG.attn_pdrop).cuda().eval().to(DTYPE) + assert sum(map(torch.numel, attn_nt.parameters())) == sum(map(torch.numel, attn_hf.parameters())) + + # Build xglm mask. 
+ mask = torch.ones(SEQUENCE_LENGTH, SEQUENCE_LENGTH, dtype=torch.bool, device="cuda").tril(diagonal=0) + mask = torch.where(mask, 0.0, -np.inf).to(DTYPE) + mask = mask.repeat(BATCH_SIZE, 1, 1).unsqueeze(1) + + convert_attention(attn_nt, attn_hf) + out_nt = attn_nt(hidden_states, sequence_mask)["hidden_states"] + out_hf = attn_hf(hidden_states.permute(1, 0, 2), attention_mask=mask)[0].permute(1, 0, 2) + assert out_nt.size() == out_hf.size(), f"{out_nt.size()}, {out_hf.size()}" + torch.testing.assert_close(out_nt, out_hf) + + +def test_attention(hidden_states: torch.Tensor, input_mask: torch.Tensor): + init_distributed(tp=1, dp=1, pp=1)(_test_attention)(hidden_states=hidden_states, sequence_mask=input_mask) + + +def _test_position_embeddings(parallel_context: ParallelContext): + position_ids = torch.arange(SEQUENCE_LENGTH, device="cuda").unsqueeze(0) # shape = (1, SEQUENCE_LENGTH) + + emb_nt = PositionEmbedding(parallel_context.tp_pg, CONFIG, None).cuda() + emb_hf = XGLMSinusoidalPositionalEmbedding(SEQUENCE_LENGTH, HIDDEN_SIZE).cuda() + + assert emb_nt.position_embedding.weight.size() == emb_hf.weights.size() + torch.testing.assert_close(emb_nt.position_embedding.weight, emb_hf.weights) + + out_nt = emb_nt(position_ids)["position_embeds"] + out_hf = emb_hf(position_ids).permute(1, 0, 2) + assert out_nt.size() == out_hf.size(), f"{out_nt.size()}, {out_hf.size()}" + torch.testing.assert_close(out_nt, out_hf) + +def test_position_embeddings(): + init_distributed(tp=1, dp=1, pp=1)(_test_position_embeddings)() diff --git a/src/nanotron/config/models_config.py b/src/nanotron/config/models_config.py index 2630e1d6..f214b357 100644 --- a/src/nanotron/config/models_config.py +++ b/src/nanotron/config/models_config.py @@ -136,4 +136,40 @@ def n_inner(self): return self.intermediate_size +@dataclass +class GPT3Config: + """Configuration for a GPT3 model""" + + activation_function: str = "gelu" + attn_pdrop: float = 0.1 + embd_pdrop: float = 0.1 + eos_token_id: int = 49152 + hidden_size: int = 2048 + intermediate_size: Optional[int] = None + layer_norm_epsilon: float = 1e-05 + max_position_embeddings: int = 4096 + num_attention_heads: int = 16 + num_hidden_layers: int = 24 + resid_pdrop: float = 0.1 + scale_attention_softmax_in_fp32: bool = True + scale_attn_weights: bool = True + vocab_size: int = 49280 + sinusoidal_position_embedding: bool = True + position_embedding_offset: int = 2 + use_spda: bool = False + + def as_starcoder2(self) -> Starcoder2Config: + config = dict(**vars(self)) + del config["sinusoidal_position_embedding"] + del config["use_spda"] + del config["position_embedding_offset"] + return Starcoder2Config( + grouped_query=True, + num_kv_heads=self.num_attention_heads, + use_rotary_embeddings=False, + **config + ) + + NanotronConfigs = Union[LlamaConfig, Starcoder2Config, Any] + diff --git a/src/nanotron/models/gpt3.py b/src/nanotron/models/gpt3.py new file mode 100644 index 00000000..8cea58c4 --- /dev/null +++ b/src/nanotron/models/gpt3.py @@ -0,0 +1,358 @@ +"""PyTorch GPT-3 model.""" + +import math +from typing import Optional + +import torch +from torch import nn +from torch.nn import functional as F + +from nanotron import distributed as dist +from nanotron.parallel import ParallelContext +from nanotron.config import Config, GPT3Config, ParallelismArgs +from nanotron.generation.generate_store import AttachableStore +from nanotron.models.starcoder2 import MLP as Starcoder2MLP +from nanotron.models.starcoder2 import CoreAttention as Starcoder2CoreAttention +from 
nanotron.models.starcoder2 import CausalSelfGQA +from nanotron.random import RandomStates, branch_random_state +from nanotron.parallel.pipeline_parallel.tensor_pointer import TensorPointer +from nanotron.parallel.tensor_parallel.enum import TensorParallelLinearMode +from nanotron.parallel.tensor_parallel.nn import TensorParallelEmbedding +from nanotron.parallel.tied_parameters import tie_parameters + +# NOTES: +# - tie head_weight with embeddings I think. + +# TODO: +# - class GPT3Config: config lol +# - check that attention (i.e. nanotron.attn vs xglm.self_attn) is the same. +# - from starcoder import Embedding +# - class PositionEmbedding: my sinusoidal embedding extends from TensorParallelEmbedding +# - class GPTBLock: very similar to starcoder2 but make it so it support non-GQA or MQA +# - from starcoder import Loss + + +class CoreAttention(Starcoder2CoreAttention): + def __init__(self, config: GPT3Config, parallel_config: Optional[ParallelismArgs], layer_idx: int): + super().__init__(config.as_starcoder2(), parallel_config, layer_idx) + self.gpt3config = config + + def forward(self, + query_states: torch.Tensor, # [batch_size * q_length, q_heads, inner_dim] + key_states: torch.Tensor, # [batch_size * kv_length, kv_heads, inner_dim] + value_states: torch.Tensor, # [batch_size * kv_length, kv_heads, inner_dim] + q_sequence_mask: torch.Tensor, # torch.BoolTensor [batch_size, q_length] (can be broadcasted to that size) + kv_sequence_mask: torch.Tensor, # torch.BoolTensor [batch_size, kv_length] (can be broadcasted to that size) + ): + + if self.gpt3config.use_spda: + assert torch.all(q_sequence_mask) + assert torch.all(kv_sequence_mask) + + batch_size, q_length = q_sequence_mask.size() + kv_length = kv_sequence_mask.size(1) + _, q_heads, head_dim = query_states.size() + kv_heads = key_states.size(1) + + attention_output = F.scaled_dot_product_attention( + query_states.view(batch_size, q_length, q_heads, head_dim).permute(0, 2, 1, 3), + key_states.view(batch_size, kv_length, kv_heads, head_dim).permute(0, 2, 1, 3), + value_states.view(batch_size, kv_length, kv_heads, head_dim).permute(0, 2, 1, 3), + dropout_p=self.dropout if self.training else 0.0, + is_causal=True, + ) # [batch, q_length, q_heads, head_dim] + attention_output = attention_output.permute(0, 2, 1, 3) + attention_output = attention_output.reshape(batch_size*q_length, q_heads, head_dim) + return attention_output + + assert query_states.dtype in {torch.bfloat16, torch.float16} + return super().forward(query_states, key_states, value_states, q_sequence_mask, kv_sequence_mask) + + +class CausalSelfAttention(CausalSelfGQA): + def __init__( + self, + config: GPT3Config, + parallel_config: Optional[ParallelismArgs], + tp_pg: dist.ProcessGroup, + layer_idx: int, + ): + super().__init__(config.as_starcoder2(), parallel_config, tp_pg, layer_idx) + self.maybe_rotary = lambda q, k, **_: (q, k) # Overwrite possible rotary with identity. + self.attention = CoreAttention(config, parallel_config=parallel_config, layer_idx=layer_idx) # Use our custom CoreAttention. + + +class MLP(Starcoder2MLP): + def __init__( + self, + config: GPT3Config, + parallel_config: Optional[ParallelismArgs], + tp_pg: dist.ProcessGroup, + ): + # TODO: GPT3Config -> Starcoder2Config. 
+ super().__init__(config, parallel_config, tp_pg) + self.dropout = nn.Dropout(p=config.dropout) # TODO: correct config.dropout name + + def forward(self, hidden_states): # [seq_length, batch_size, hidden_dim] + hidden_states = self.c_fc(hidden_states) + hidden_states = self.act(hidden_states) + with branch_random_state( + self.random_states, "tp_synced", enabled=self.tp_mode is TensorParallelLinearMode.ALL_REDUCE + ): + hidden_states = self.dropout(input=hidden_states) + hidden_states = self.c_proj(hidden_states) + return {"hidden_states": hidden_states} + + +class GPTBlock(nn.Module): + def __init__( + self, + config: GPT3Config, + parallel_config: Optional[ParallelismArgs], + tp_pg: dist.ProcessGroup, + random_states: RandomStates, + layer_idx: int, + ): + super(GPTBlock, self).__init__() + self.ln_1 = TritonLayerNorm(config.hidden_size, eps=config.layer_norm_epsilon) + self.attn = CausalSelfAttention( + config=config, + parallel_config=parallel_config, + tp_pg=tp_pg, + layer_idx=layer_idx + ) + self.attn_dropout = config.attn_pdrop + + self.ln_2 = TritonLayerNorm(config.hidden_size, eps=config.layer_norm_epsilon) + self.ff = MLP(config=config, parallel_config=parallel_config, tp_pg=tp_pg) + self.ff_dropout = config.resid_pdrop + + self.random_states = random_states + self.tp_mode = parallel_config.tp_mode if parallel_config is not None else TensorParallelLinearMode.ALL_REDUCE + + def forward( + self, + hidden_states: torch.Tensor | TensorPointer, + sequence_mask: torch.Tensor | TensorPointer, + ) -> dict[str, torch.Tensor | TensorPointer]: + + residual = hidden_states + hidden_states = self.ln_1(hidden_states) + output = self.attn(hidden_states=hidden_states, sequence_mask=sequence_mask) + hidden_states = output["hidden_states"] + + if self.training: + with branch_random_state( + self.random_states, "tp_synced", enabled=self.tp_mode is TensorParallelLinearMode.ALL_REDUCE + ): + hidden_states = dropout_add_fused_train(hidden_states, residual=residual, prob=self.attn_dropout) + else: + # No need for random state context manager + # TODO: add dropout scaling? + hidden_states = hidden_states + residual + + residual = hidden_states + hidden_states = self.ln_2(hidden_states) + hidden_states = self.ff(hidden_states=hidden_states)["hidden_states"] + + if self.training: + with branch_random_state( + self.random_states, "tp_synced", enabled=self.tp_mode is TensorParallelLinearMode.ALL_REDUCE + ): + hidden_states = dropout_add_fused_train(hidden_states, residual=residual, prob=self.ff_dropout) + else: + # No need for random state context manager + # TODO: add dropout scaling? 
+ hidden_states = hidden_states + residual + + return { + "hidden_states": hidden_states, + "sequence_mask": output["sequence_mask"], + } + + +class PositionEmbedding(nn.Module, AttachableStore): + def __init__(self, tp_pg: dist.ProcessGroup, config: GPT3Config, parallel_config: Optional[ParallelismArgs]): + super().__init__() + + self.config = config + if (config.max_position_embeddings + config.position_embedding_offset) % tp_pg.size() == 0: + dummy_pos = 0 + else: + dummy_pos = tp_pg.size() - ((config.max_position_embeddings + config.position_embedding_offset) % k) + true_max_size = config.max_position_embeddings + config.position_embedding_offset + dummy_pos + + if config.sinusoidal_position_embedding: + weight = self._make_weights(tp_pg, true_max_size, config.hidden_size) + else: + weight = None + + position_embedding = TensorParallelEmbedding( + num_embeddings=true_max_size, + embedding_dim=config.hidden_size, + pg=tp_pg, + mode=parallel_config.tp_mode if parallel_config is not None else TensorParallelLinearMode.ALL_REDUCE, + _weight=weight + ) + self.pg = tp_pg + + # Sinusoidal position embeddings are usually not trainable. + # We adjust that by setting the module self.position_embedding without gradient. + if config.sinusoidal_position_embedding: + with torch.no_grad(): + self.position_embedding = position_embedding.requires_grad_(False) + else: + self.position_embedding = position_embedding + + def forward(self, position_ids: torch.Tensor): # [batch_size, seq_length] + position_ids = position_ids.transpose(0, 1) + position_embeds = self.position_embedding(position_ids + self.config.position_embedding_offset) + return {"position_embeds": position_embeds} + + def _make_weights(self, tp_pg: dist.ProcessGroup, num_embeddings: int, + embedding_dim: int) -> torch.Tensor: + rank = dist.get_rank(group=tp_pg) + tp_size = tp_pg.size() + + assert 0 <= rank < tp_size + assert num_embeddings % tp_size == 0 + assert embedding_dim % 2 == 0 + block_size = num_embeddings//tp_size + + half_dim = embedding_dim//2 + emb = math.log(10_000)/(half_dim - 1) + emb = torch.exp(torch.arange(half_dim, dtype=torch.int64).float() * -emb) + emb = (rank*block_size + torch.arange(num_embeddings, dtype=torch.int64).float().unsqueeze(1)) * emb.unsqueeze(0) + emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1).view(block_size, embedding_dim) + return emb + + +class GPT3Model(nn.Module): + def __init__( + self, + config: GPT3Config, + parallel_context: ParallelContext, + parallel_config: Optional[ParallelismArgs], + random_states: RandomStates, + ): + super().__init__() + + # Declare all the nodes + self.p2p = P2P(parallel_context.pp_pg, device=torch.device("cuda")) + self.random_states = random_states + self.tp_mode = parallel_config.tp_mode if parallel_config is not None else TensorParallelLinearMode.ALL_REDUCE + + self.token_embeddings = PipelineBlock( + p2p=self.p2p, + module_builder=Embedding, + module_kwargs={ + "tp_pg": parallel_context.tp_pg, + "config": config, + "parallel_config": parallel_config, + }, + module_input_keys={"input_ids"}, + module_output_keys={"input_embeds"}, + ) + self.position_embeddings = PipelineBlock( + p2p=self.p2p, + module_builder=PositionEmbedding, + module_kwargs={ + "tp_pg": parallel_context.tp_pg, + "config": config, + "parallel_config": parallel_config, + }, + module_input_keys={"position_ids"}, + module_output_keys={"position_embeds"}, + ) + + self.embeds_dropout = PipelineBlock( + p2p=self.p2p, + module_builder=nn.Dropout, + module_kwargs={"p": config.embd_pdrop}, + 
module_input_keys={"input"}, + module_output_keys={"hidden_states"}, + ) + + self.decoder = nn.ModuleList( + [ + PipelineBlock( + p2p=self.p2p, + module_builder=GPTBlock, + module_kwargs={ + "config": config, + "parallel_config": parallel_config, + "tp_pg": parallel_context.tp_pg, + "random_states": random_states, + "layer_idx": layer_idx, + }, + module_input_keys={"hidden_states", "sequence_mask"}, + module_output_keys={"hidden_states", "sequence_mask"}, + ) + for layer_idx in range(config.num_hidden_layers) + ] + ) + + self.final_layer_norm = PipelineBlock( + p2p=self.p2p, + module_builder=TritonLayerNorm, + module_kwargs={"normalized_shape": config.hidden_size, "eps": config.layer_norm_epsilon}, + module_input_keys={"input"}, + module_output_keys={"hidden_states"}, + ) + + self.lm_head = PipelineBlock( + p2p=self.p2p, + # Understand that this means that we return sharded logits that are going to need to be gathered + module_builder=TensorParallelColumnLinear, + module_kwargs={ + "in_features": config.hidden_size, + "out_features": config.vocab_size, + "pg": parallel_context.tp_pg, + "bias": False, + # TODO: refactor so that we store that default in a single place. + "mode": self.tp_mode, + "async_communication": parallel_config.tp_linear_async_communication + if parallel_config is not None + else False, + }, + module_input_keys={"x"}, + module_output_keys={"logits"}, + ) + + self.cast_to_fp32 = PipelineBlock( + p2p=self.p2p, + module_builder=lambda: lambda x: x.float(), + module_kwargs={}, + module_input_keys={"x"}, + module_output_keys={"output"}, + ) + + + def forward( + self, + input_ids: torch.Tensor | TensorPointer, # [batch_size, seq_length] + input_mask: torch.Tensor | TensorPointer, # [batch_size, seq_length] + ): + # all tensors are optional as most ranks don't need anything from the dataloader. 
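+        # Note: position_ids below ends up as a (batch_size, seq_length) tensor whose rows are 0..seq_length-1;
+        # PositionEmbedding later shifts it by config.position_embedding_offset before the embedding lookup.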
+ + position_ids = torch.arange(input_ids.size(1), device="cuda").repeat(input_ids.size(0)).view(*input_ids.size()) + input_embeds = self.token_embeddings(input_ids=input_ids)["input_embeds"] + position_embeds = self.position_embeds(position_ids=position_ids)["position_embeds"] + hidden_states = input_embeds + position_embeds + + with branch_random_state( + self.random_states, "tp_synced", enabled=self.tp_mode == TensorParallelLinearMode.ALL_REDUCE + ): + hidden_states = self.embeds_dropout(input=hidden_states)["hidden_states"] + + hidden_encoder_states = {"hidden_states": hidden_states, "sequence_mask": input_mask} + for encoder_block in self.decoder: + hidden_encoder_states = encoder_block(**hidden_encoder_states) + + hidden_states = self.final_layer_norm(input=hidden_encoder_states["hidden_states"])["hidden_states"] + + sharded_logits = self.lm_head(x=hidden_states)["logits"] + + fp32_sharded_logits = self.cast_to_fp32(x=sharded_logits)["output"] + + return fp32_sharded_logits From 42695ba79110f34c16e4bb8bdf20ecdc74312fc9 Mon Sep 17 00:00:00 2001 From: AleHD Date: Wed, 26 Jun 2024 17:24:53 +0000 Subject: [PATCH 27/51] WIP: GPT arch almost done, hf->nt converters working perfectly for non-distributed inference --- examples/xglm/convert_hf2nt.py | 70 +++++++- examples/xglm/tests/test_attn.py | 74 --------- examples/xglm/tests/test_implementation.py | 135 +++++++++++++-- src/nanotron/config/models_config.py | 4 + src/nanotron/models/gpt3.py | 184 ++++++++++----------- 5 files changed, 287 insertions(+), 180 deletions(-) delete mode 100644 examples/xglm/tests/test_attn.py diff --git a/examples/xglm/convert_hf2nt.py b/examples/xglm/convert_hf2nt.py index e008f859..6e6ddff1 100644 --- a/examples/xglm/convert_hf2nt.py +++ b/examples/xglm/convert_hf2nt.py @@ -1,7 +1,44 @@ import torch +from torch import nn -from transformers.models.xglm.modeling_xglm import XGLMAttention -from nanotron.models.gpt3 import CausalSelfAttention +from transformers.models.xglm.modeling_xglm import XGLMAttention, XGLMConfig, XGLMDecoderLayer, XGLMForCausalLM +from nanotron.models.gpt3 import CausalSelfAttention, GPTBlock, MLP, GPT3ForTraining +from nanotron.config.models_config import GPT3Config + + +def convert_config(config: XGLMConfig) -> GPT3Config: + # TODOs: + # dropout=0.1, + # layerdrop=0.0, + # init_std=0.02, + # use_cache=True, + # decoder_start_token_id=2, + # pad_token_id=1, + # bos_token_id=0, + + # TODO: when going gpt3->xglm: + # - assert layernorm is 1e-05 + return GPT3Config( + activation_function=config.activation_function, + attn_pdrop=config.attention_dropout, + embd_pdrop=0.0, # TODO + eos_token_id=config.eos_token_id, + hidden_size=config.d_model, + intermediate_size=config.ffn_dim, + layer_norm_epsilon=1e-05, + max_position_embeddings=config.max_position_embeddings, + num_attention_heads=config.attention_heads, + num_hidden_layers=config.num_layers, + resid_pdrop=0.0, # TODO + scale_attention_softmax_in_fp32=True, + scale_attn_weights=True, + vocab_size=config.vocab_size, + sinusoidal_position_embedding=True, + position_embedding_offset=2, + use_spda=False, + act_pdrop=config.activation_dropout, + scale_embedding=config.scale_embedding, + ) def convert_attention(attn_nt: CausalSelfAttention, attn_hf: XGLMAttention): @@ -26,3 +63,32 @@ def convert_attention(attn_nt: CausalSelfAttention, attn_hf: XGLMAttention): attn_nt.query_key_value.bias.data = qkv_b.clone() attn_nt.dense.weight.data = attn_hf.out_proj.weight.clone() attn_nt.dense.bias.data = attn_hf.out_proj.bias.clone() + + +def 
convert_generic(module1: nn.Module, module2: nn.Module): + names1 = {name for name, _ in module1.named_parameters()} + names2 = {name for name, _ in module2.named_parameters()} + assert names1 == names2, f"{names1} != {names2}" + params2 = dict(module2.named_parameters()) + for name, param in module1.named_parameters(): + param.data = params2[name].clone() + + +def convert_mlp(mlp_nt: MLP, block_hf: XGLMDecoderLayer): + convert_generic(mlp_nt.c_fc, block_hf.fc1) + convert_generic(mlp_nt.c_proj, block_hf.fc2) + + +def convert_decoder(block_nt: GPTBlock, block_hf: XGLMDecoderLayer): + convert_generic(block_nt.ln_1, block_hf.self_attn_layer_norm) + convert_attention(block_nt.attn, block_hf.self_attn) + convert_generic(block_nt.ln_2, block_hf.final_layer_norm) + convert_mlp(block_nt.ff, block_hf) + + +def convert(model_nt: GPT3ForTraining, model_hf: XGLMForCausalLM): + convert_generic(model_nt.model.token_embeddings.pp_block.token_embedding, model_hf.model.embed_tokens) + for layer_nt, layer_hf in zip(model_nt.model.decoder, model_hf.model.layers): + convert_decoder(layer_nt.pp_block, layer_hf) + convert_generic(model_nt.model.final_layer_norm.pp_block, model_hf.model.layer_norm) + convert_generic(model_nt.model.lm_head.pp_block, model_hf.lm_head) diff --git a/examples/xglm/tests/test_attn.py b/examples/xglm/tests/test_attn.py deleted file mode 100644 index 2fcdb3a8..00000000 --- a/examples/xglm/tests/test_attn.py +++ /dev/null @@ -1,74 +0,0 @@ -import torch -from torch.nn import functional as F -#torch.Size([4, 2048, 16, 64]), torch.Size([2048, 4, 1024]) - -# inputs = (batchsize * qlen, heads, head_dim) -# outputs = (batchsize*qlen, heads, head_dim) -def sdpa(query, key, value, batchsize: int): - def reshape(tensor): # output = (batchsize, heads, qlen, head_dim) - return tensor.view(batchsize, qlen, heads, head_dim).permute(0, 2, 1, 3) - - batchsize_x_qlen, heads, head_dim = query.size() - qlen = batchsize_x_qlen//batchsize - out = F.scaled_dot_product_attention(reshape(query), reshape(key), reshape(value), is_causal=True) # (b,h,q,d) - return out.permute(0, 2, 1, 3).reshape(batchsize*qlen, heads, head_dim) - - -# inputs = (batchsize * qlen, heads, head_dim) -# outputs = (batchsize*qlen, heads, head_dim) -def fa(query_states, key_states, value_states, batchsize: int): - from flash_attn.flash_attn_interface import flash_attn_varlen_func - - batchsize_x_qlen, heads, head_dim = query_states.size() - qlen = batchsize_x_qlen//batchsize - - q_sequence_mask = torch.ones(batchsize, qlen, dtype=torch.bool, device="cuda") - kv_sequence_mask = torch.ones(batchsize, qlen, dtype=torch.bool, device="cuda") - - # TODO @thomasw21: Compute once, instead of computing for each layers. - cu_seqlens_q = torch.zeros((q_sequence_mask.shape[0] + 1), dtype=torch.int32, device=query_states.device) - cu_seqlens_k = torch.zeros((kv_sequence_mask.shape[0] + 1), dtype=torch.int32, device=query_states.device) - torch.cumsum(q_sequence_mask.sum(-1, dtype=torch.int32), dim=0, dtype=torch.int32, out=cu_seqlens_q[1:]) - torch.cumsum(kv_sequence_mask.sum(-1, dtype=torch.int32), dim=0, dtype=torch.int32, out=cu_seqlens_k[1:]) - - # TODO(kunhao): flash attn's causal means that the query can only attend to the keys before it. This is not - # what we want if we are using kv cache. This is a hack as we always have q_length == 1 when using kv cache. 
- causal = False if q_sequence_mask.shape[1] == 1 else True - attn_output = flash_attn_varlen_func( - q=query_states, - k=key_states, - v=value_states, - cu_seqlens_q=cu_seqlens_q, - cu_seqlens_k=cu_seqlens_k, - max_seqlen_q=q_sequence_mask.shape[1], - max_seqlen_k=kv_sequence_mask.shape[1], - dropout_p=0.0, - softmax_scale=None, # defaults to 1/sqrt(d_qk) - causal=causal, - window_size=(-1, -1), - return_attn_probs=False, - ) - return attn_output - - -def main(): - batchsize = 5 - qlen = 6 - heads = 2 - head_dim = 16 - - query = torch.randn(batchsize*qlen, heads, head_dim, device="cuda", dtype=torch.bfloat16) - key = torch.randn(batchsize*qlen, heads, head_dim, device="cuda", dtype=torch.bfloat16) - value = torch.randn(batchsize*qlen, heads, head_dim, device="cuda", dtype=torch.bfloat16) - - out_pt = sdpa(query, key, value, batchsize) - out_fa = fa(query, key, value, batchsize) - - assert out_pt.size() == out_fa.size() - - torch.testing.assert_close(out_pt, out_fa) - - - -if __name__ == "__main__": - main() diff --git a/examples/xglm/tests/test_implementation.py b/examples/xglm/tests/test_implementation.py index 10f0302a..3636415b 100644 --- a/examples/xglm/tests/test_implementation.py +++ b/examples/xglm/tests/test_implementation.py @@ -1,27 +1,33 @@ +from typing import Optional + import numpy as np import torch import pytest -from transformers.models.xglm.modeling_xglm import XGLMAttention, XGLMSinusoidalPositionalEmbedding +from transformers import XGLMTokenizer +from transformers.models.xglm.modeling_xglm import XGLMConfig, XGLMAttention, XGLMSinusoidalPositionalEmbedding, XGLMDecoderLayer, XGLMForCausalLM +import nanotron from nanotron.config.models_config import GPT3Config -from nanotron.models.gpt3 import CausalSelfAttention, PositionEmbedding +from nanotron.models.gpt3 import GPT3ForTraining, CausalSelfAttention, PositionEmbedding, GPTBlock from nanotron.parallel import ParallelContext from tests.helpers.utils import init_distributed -from examples.xglm.convert_hf2nt import convert_attention +from examples.xglm.convert_hf2nt import convert_attention, convert_config, convert_decoder, convert SEQUENCE_LENGTH = 2048 BATCH_SIZE = 4 HIDDEN_SIZE = 1024 -DTYPE = torch.float64 +DTYPE = torch.bfloat16 +TEXT = "Hello. This is a relatively long text. I will use this text to test the conversion scripts. Let's finish this text soon because I don't have much more to say. Final note:" CONFIG = GPT3Config( attn_pdrop=0.0, embd_pdrop=0.0, resid_pdrop=0.0, + act_pdrop=0.0, eos_token_id=2, hidden_size=HIDDEN_SIZE, intermediate_size=4096, @@ -42,11 +48,22 @@ def hidden_states() -> torch.Tensor: return torch.randn(SEQUENCE_LENGTH, BATCH_SIZE, HIDDEN_SIZE, dtype=DTYPE) - @pytest.fixture def input_mask() -> torch.Tensor: return torch.ones(BATCH_SIZE, SEQUENCE_LENGTH, dtype=torch.bool) +@pytest.fixture +def input_ids() -> torch.Tensor: + return torch.randint(0, CONFIG.vocab_size, (BATCH_SIZE, SEQUENCE_LENGTH)) + + +def attention_mask() -> torch.Tensor: + # XGLM causal attention mask. 
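+    # Additive mask of shape (BATCH_SIZE, 1, SEQUENCE_LENGTH, SEQUENCE_LENGTH):
+    # allowed (lower-triangular) positions hold 0.0 and future positions hold -inf,
+    # which is the additive format the HF XGLM layers consume.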
+ mask = torch.ones(SEQUENCE_LENGTH, SEQUENCE_LENGTH, dtype=torch.bool, device="cuda").tril(diagonal=0) + mask = torch.where(mask, 0.0, -np.inf).to(DTYPE) + mask = mask.repeat(BATCH_SIZE, 1, 1).unsqueeze(1) + return mask + def _test_attention(parallel_context: ParallelContext, hidden_states: torch.Tensor, sequence_mask: torch.Tensor): hidden_states = hidden_states.cuda() @@ -56,14 +73,9 @@ def _test_attention(parallel_context: ParallelContext, hidden_states: torch.Tens attn_hf = XGLMAttention(CONFIG.hidden_size, CONFIG.num_attention_heads, CONFIG.attn_pdrop).cuda().eval().to(DTYPE) assert sum(map(torch.numel, attn_nt.parameters())) == sum(map(torch.numel, attn_hf.parameters())) - # Build xglm mask. - mask = torch.ones(SEQUENCE_LENGTH, SEQUENCE_LENGTH, dtype=torch.bool, device="cuda").tril(diagonal=0) - mask = torch.where(mask, 0.0, -np.inf).to(DTYPE) - mask = mask.repeat(BATCH_SIZE, 1, 1).unsqueeze(1) - convert_attention(attn_nt, attn_hf) out_nt = attn_nt(hidden_states, sequence_mask)["hidden_states"] - out_hf = attn_hf(hidden_states.permute(1, 0, 2), attention_mask=mask)[0].permute(1, 0, 2) + out_hf = attn_hf(hidden_states.permute(1, 0, 2), attention_mask=attention_mask())[0].permute(1, 0, 2) assert out_nt.size() == out_hf.size(), f"{out_nt.size()}, {out_hf.size()}" torch.testing.assert_close(out_nt, out_hf) @@ -88,3 +100,104 @@ def _test_position_embeddings(parallel_context: ParallelContext): def test_position_embeddings(): init_distributed(tp=1, dp=1, pp=1)(_test_position_embeddings)() + + +def _test_decoder(parallel_context: ParallelContext, hidden_states: torch.Tensor, sequence_mask: torch.Tensor): + random_states = nanotron.random.RandomStates({"tp_synced": nanotron.random.get_current_random_state()}) + hidden_states = hidden_states.cuda() + sequence_mask = sequence_mask.cuda() + + config_hf = XGLMConfig() + decoder_hf = XGLMDecoderLayer(config_hf).cuda().to(DTYPE).eval() + config_nt = convert_config(config_hf) + if DTYPE not in {torch.bfloat16, torch.float16}: + config_nt.use_spda = True + decoder_nt = GPTBlock(config_nt, None, parallel_context.tp_pg, random_states, 0).cuda().to(DTYPE).eval() + + convert_decoder(decoder_nt, decoder_hf) + + out_nt = decoder_nt(hidden_states, sequence_mask)["hidden_states"] + out_hf = decoder_hf(hidden_states.permute(1, 0, 2), attention_mask=attention_mask())[0].permute(1, 0, 2) + + assert out_nt.size() == out_hf.size(), f"{out_nt.size()}, {out_hf.size()}" + torch.testing.assert_close(out_nt, out_hf) + + +def test_decoder(hidden_states: torch.Tensor, input_mask: torch.Tensor): + init_distributed(tp=1, dp=1, pp=1)(_test_decoder)(hidden_states=hidden_states, sequence_mask=input_mask) + + +def _test_model(model_hf: Optional[XGLMForCausalLM], parallel_context: ParallelContext, + input_ids: torch.Tensor, input_mask: torch.Tensor): + random_states = nanotron.random.RandomStates({"tp_synced": nanotron.random.get_current_random_state()}) + input_ids = input_ids.cuda() + input_mask = input_mask.cuda() + + # Get hf model. + if model_hf is None: + config_hf = XGLMConfig() + model_hf = XGLMForCausalLM(config_hf).cuda().to(DTYPE).eval() + else: + model_hf = model_hf.cuda().to(DTYPE).eval() + config_hf = model_hf.config + + # Get nanotron model and make the conversion. 
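+    # use_spda is forced below for full-precision dtypes because the flash-attention
+    # path in CoreAttention only accepts bfloat16/float16 inputs.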
+ config_nt = convert_config(config_hf) + if DTYPE not in {torch.bfloat16, torch.float16}: + config_nt.use_spda = True + model_nt = nanotron.models.build_model( + model_builder=lambda: GPT3ForTraining( + config=config_nt, + parallel_context=parallel_context, + parallel_config=None, + random_states=random_states, + ), + parallel_context=parallel_context, + dtype=DTYPE, + device="cuda", + ).eval() + convert(model_nt, model_hf) + + print("Parameter count (M):", sum(map(torch.numel, model_hf.parameters()))/1000/1000) + + # Get outputs and assert. + with torch.no_grad(): + out_nt = model_nt.model(input_ids, input_mask).to(DTYPE) + del model_nt + torch.cuda.empty_cache() + out_hf = model_hf(input_ids=input_ids, attention_mask=input_mask).logits.permute(1, 0, 2) + del model_hf + torch.cuda.empty_cache() + assert out_nt.size() == out_hf.size(), f"{out_nt.size()}, {out_hf.size()}" + torch.testing.assert_close(out_nt.cpu(), out_hf.cpu()) + +def _test_dummy_xglm(parallel_context: ParallelContext, input_ids: torch.Tensor, input_mask: torch.Tensor): + _test_model(None, parallel_context, input_ids, input_mask) + + +def test_dummy_xglm(input_ids: torch.Tensor, input_mask: torch.Tensor): + init_distributed(tp=1, dp=1, pp=1)(_test_dummy_xglm)(input_ids=input_ids, input_mask=input_mask) + + +def _test_xglm7B(parallel_context: ParallelContext): + tok = XGLMTokenizer.from_pretrained("facebook/xglm-7.5B") + tokenized = tok(TEXT) + model_hf = XGLMForCausalLM.from_pretrained("facebook/xglm-7.5B") + _test_model(model_hf, parallel_context, + torch.tensor([tokenized["input_ids"]]), torch.tensor([tokenized["attention_mask"]])) + + +def test_xglm7B(): + init_distributed(tp=1, dp=1, pp=1)(_test_xglm7B)() + + +def _test_xglm500M(parallel_context: ParallelContext): + tok = XGLMTokenizer.from_pretrained("facebook/xglm-564M") + tokenized = tok(TEXT) + model_hf = XGLMForCausalLM.from_pretrained("facebook/xglm-564M") + _test_model(model_hf, parallel_context, + torch.tensor([tokenized["input_ids"]]), torch.tensor([tokenized["attention_mask"]])) + + +def test_xglm500M(): + init_distributed(tp=1, dp=1, pp=1)(_test_xglm500M)() diff --git a/src/nanotron/config/models_config.py b/src/nanotron/config/models_config.py index f214b357..56d6411f 100644 --- a/src/nanotron/config/models_config.py +++ b/src/nanotron/config/models_config.py @@ -157,12 +157,16 @@ class GPT3Config: sinusoidal_position_embedding: bool = True position_embedding_offset: int = 2 use_spda: bool = False + act_pdrop: float = 0.0 + scale_embedding: bool = True def as_starcoder2(self) -> Starcoder2Config: config = dict(**vars(self)) del config["sinusoidal_position_embedding"] del config["use_spda"] del config["position_embedding_offset"] + del config["act_pdrop"] + del config["scale_embedding"] return Starcoder2Config( grouped_query=True, num_kv_heads=self.num_attention_heads, diff --git a/src/nanotron/models/gpt3.py b/src/nanotron/models/gpt3.py index 8cea58c4..99f6ea85 100644 --- a/src/nanotron/models/gpt3.py +++ b/src/nanotron/models/gpt3.py @@ -2,6 +2,7 @@ import math from typing import Optional +from contextlib import contextmanager import torch from torch import nn @@ -9,11 +10,15 @@ from nanotron import distributed as dist from nanotron.parallel import ParallelContext -from nanotron.config import Config, GPT3Config, ParallelismArgs +from nanotron.config import Config, GPT3Config, ParallelismArgs, Starcoder2Config from nanotron.generation.generate_store import AttachableStore +from nanotron.models import starcoder2 +from nanotron.nn.layer_norm import 
TritonLayerNorm from nanotron.models.starcoder2 import MLP as Starcoder2MLP +from nanotron.parallel.pipeline_parallel.block import PipelineBlock from nanotron.models.starcoder2 import CoreAttention as Starcoder2CoreAttention -from nanotron.models.starcoder2 import CausalSelfGQA +from nanotron.models.starcoder2 import GPTBlock as Starcoder2GPTBlock +from nanotron.models.starcoder2 import CausalSelfGQA, Starcoder2ForTraining, GPTModel from nanotron.random import RandomStates, branch_random_state from nanotron.parallel.pipeline_parallel.tensor_pointer import TensorPointer from nanotron.parallel.tensor_parallel.enum import TensorParallelLinearMode @@ -28,10 +33,55 @@ # - check that attention (i.e. nanotron.attn vs xglm.self_attn) is the same. # - from starcoder import Embedding # - class PositionEmbedding: my sinusoidal embedding extends from TensorParallelEmbedding -# - class GPTBLock: very similar to starcoder2 but make it so it support non-GQA or MQA +# - class GPTBlock: very similar to starcoder2 but make it so it support non-GQA or MQA # - from starcoder import Loss +@contextmanager +def replace_coreattention(gpt3config: GPT3Config): + orig = starcoder2.CoreAttention + try: + def create_core_attention(config: Starcoder2Config, parallel_config: Optional[ParallelismArgs], layer_idx: int): + return CoreAttention(gpt3config, parallel_config, layer_idx) + starcoder2.CoreAttention = create_core_attention + yield + finally: + starcoder2.CoreAttention = orig + + +@contextmanager +def replace_decoder(gpt3config: GPT3Config): + orig = starcoder2.PipelineBlock + try: + def create_pp_block(module_builder, module_kwargs, **kwargs): + if module_builder is Starcoder2GPTBlock: + # Starcoder2's GPT module is trying to instantiate a Starcoder2 GPTBlock. + # Let's return a PipelineBlock with a GPT3Block instead. + # This also requires to replace starcoders2's config with gpt3's config. + module_kwargs["config"] = gpt3config + return orig(module_builder=GPTBlock, module_kwargs=module_kwargs, **kwargs) + # Else, they are setting up other modules, which we also want unchanged. 
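+            # (e.g. the embedding, final layer norm, lm_head and cast_to_fp32 PipelineBlocks
+            # built by the parent GPTModel constructor.)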
+ return orig(module_builder=module_builder, module_kwargs=module_kwargs, **kwargs) + + starcoder2.PipelineBlock = create_pp_block + yield + finally: + starcoder2.PipelineBlock = orig + + +@contextmanager +def replace_gpt3model(gpt3config: GPT3Config): + orig = starcoder2.GPTModel + try: + def create_gptmodel(config: Starcoder2Config, parallel_context: ParallelContext, + parallel_config: Optional[ParallelismArgs], random_states: RandomStates): + return GPT3Model(gpt3config, parallel_context, parallel_config, random_states) + starcoder2.GPTModel = create_gptmodel + yield + finally: + starcoder2.GPTModel = orig + + class CoreAttention(Starcoder2CoreAttention): def __init__(self, config: GPT3Config, parallel_config: Optional[ParallelismArgs], layer_idx: int): super().__init__(config.as_starcoder2(), parallel_config, layer_idx) @@ -63,7 +113,7 @@ def forward(self, ) # [batch, q_length, q_heads, head_dim] attention_output = attention_output.permute(0, 2, 1, 3) attention_output = attention_output.reshape(batch_size*q_length, q_heads, head_dim) - return attention_output + return attention_output.contiguous() assert query_states.dtype in {torch.bfloat16, torch.float16} return super().forward(query_states, key_states, value_states, q_sequence_mask, kv_sequence_mask) @@ -77,9 +127,10 @@ def __init__( tp_pg: dist.ProcessGroup, layer_idx: int, ): - super().__init__(config.as_starcoder2(), parallel_config, tp_pg, layer_idx) + with replace_coreattention(config): + super().__init__(config.as_starcoder2(), parallel_config, tp_pg, layer_idx) self.maybe_rotary = lambda q, k, **_: (q, k) # Overwrite possible rotary with identity. - self.attention = CoreAttention(config, parallel_config=parallel_config, layer_idx=layer_idx) # Use our custom CoreAttention. + #self.attention = CoreAttention(config, parallel_config=parallel_config, layer_idx=layer_idx) # Use our custom CoreAttention. class MLP(Starcoder2MLP): @@ -88,10 +139,12 @@ def __init__( config: GPT3Config, parallel_config: Optional[ParallelismArgs], tp_pg: dist.ProcessGroup, + random_states: RandomStates ): - # TODO: GPT3Config -> Starcoder2Config. 
- super().__init__(config, parallel_config, tp_pg) - self.dropout = nn.Dropout(p=config.dropout) # TODO: correct config.dropout name + super().__init__(config.as_starcoder2(), parallel_config, tp_pg) + self.dropout = nn.Dropout(p=config.act_pdrop) + self.random_states = random_states + self.tp_mode = parallel_config.tp_mode if parallel_config is not None else TensorParallelLinearMode.ALL_REDUCE def forward(self, hidden_states): # [seq_length, batch_size, hidden_dim] hidden_states = self.c_fc(hidden_states) @@ -113,6 +166,7 @@ def __init__( random_states: RandomStates, layer_idx: int, ): + #print("New gpt block created :D") super(GPTBlock, self).__init__() self.ln_1 = TritonLayerNorm(config.hidden_size, eps=config.layer_norm_epsilon) self.attn = CausalSelfAttention( @@ -124,7 +178,7 @@ def __init__( self.attn_dropout = config.attn_pdrop self.ln_2 = TritonLayerNorm(config.hidden_size, eps=config.layer_norm_epsilon) - self.ff = MLP(config=config, parallel_config=parallel_config, tp_pg=tp_pg) + self.ff = MLP(config=config, parallel_config=parallel_config, tp_pg=tp_pg, random_states=random_states) self.ff_dropout = config.resid_pdrop self.random_states = random_states @@ -138,8 +192,10 @@ def forward( residual = hidden_states hidden_states = self.ln_1(hidden_states) + #hidden_states = torch.arange(hidden_states.numel()).to(hidden_states.device).to(hidden_states.dtype).view(hidden_states.size()) output = self.attn(hidden_states=hidden_states, sequence_mask=sequence_mask) hidden_states = output["hidden_states"] + #return {"hidden_states": hidden_states, "sequence_mask": sequence_mask} if self.training: with branch_random_state( @@ -227,7 +283,7 @@ def _make_weights(self, tp_pg: dist.ProcessGroup, num_embeddings: int, return emb -class GPT3Model(nn.Module): +class GPT3Model(GPTModel): def __init__( self, config: GPT3Config, @@ -235,24 +291,9 @@ def __init__( parallel_config: Optional[ParallelismArgs], random_states: RandomStates, ): - super().__init__() + with replace_decoder(config): + super().__init__(config.as_starcoder2(), parallel_context, parallel_config, random_states) - # Declare all the nodes - self.p2p = P2P(parallel_context.pp_pg, device=torch.device("cuda")) - self.random_states = random_states - self.tp_mode = parallel_config.tp_mode if parallel_config is not None else TensorParallelLinearMode.ALL_REDUCE - - self.token_embeddings = PipelineBlock( - p2p=self.p2p, - module_builder=Embedding, - module_kwargs={ - "tp_pg": parallel_context.tp_pg, - "config": config, - "parallel_config": parallel_config, - }, - module_input_keys={"input_ids"}, - module_output_keys={"input_embeds"}, - ) self.position_embeddings = PipelineBlock( p2p=self.p2p, module_builder=PositionEmbedding, @@ -264,69 +305,7 @@ def __init__( module_input_keys={"position_ids"}, module_output_keys={"position_embeds"}, ) - - self.embeds_dropout = PipelineBlock( - p2p=self.p2p, - module_builder=nn.Dropout, - module_kwargs={"p": config.embd_pdrop}, - module_input_keys={"input"}, - module_output_keys={"hidden_states"}, - ) - - self.decoder = nn.ModuleList( - [ - PipelineBlock( - p2p=self.p2p, - module_builder=GPTBlock, - module_kwargs={ - "config": config, - "parallel_config": parallel_config, - "tp_pg": parallel_context.tp_pg, - "random_states": random_states, - "layer_idx": layer_idx, - }, - module_input_keys={"hidden_states", "sequence_mask"}, - module_output_keys={"hidden_states", "sequence_mask"}, - ) - for layer_idx in range(config.num_hidden_layers) - ] - ) - - self.final_layer_norm = PipelineBlock( - p2p=self.p2p, - 
module_builder=TritonLayerNorm, - module_kwargs={"normalized_shape": config.hidden_size, "eps": config.layer_norm_epsilon}, - module_input_keys={"input"}, - module_output_keys={"hidden_states"}, - ) - - self.lm_head = PipelineBlock( - p2p=self.p2p, - # Understand that this means that we return sharded logits that are going to need to be gathered - module_builder=TensorParallelColumnLinear, - module_kwargs={ - "in_features": config.hidden_size, - "out_features": config.vocab_size, - "pg": parallel_context.tp_pg, - "bias": False, - # TODO: refactor so that we store that default in a single place. - "mode": self.tp_mode, - "async_communication": parallel_config.tp_linear_async_communication - if parallel_config is not None - else False, - }, - module_input_keys={"x"}, - module_output_keys={"logits"}, - ) - - self.cast_to_fp32 = PipelineBlock( - p2p=self.p2p, - module_builder=lambda: lambda x: x.float(), - module_kwargs={}, - module_input_keys={"x"}, - module_output_keys={"output"}, - ) - + self.embed_scale = config.hidden_size**0.5 if config.scale_embedding else 1.0 def forward( self, @@ -335,9 +314,9 @@ def forward( ): # all tensors are optional as most ranks don't need anything from the dataloader. + input_embeds = self.token_embeddings(input_ids=input_ids, input_mask=input_mask)["input_embeds"]*self.embed_scale position_ids = torch.arange(input_ids.size(1), device="cuda").repeat(input_ids.size(0)).view(*input_ids.size()) - input_embeds = self.token_embeddings(input_ids=input_ids)["input_embeds"] - position_embeds = self.position_embeds(position_ids=position_ids)["position_embeds"] + position_embeds = self.position_embeddings(position_ids=position_ids)["position_embeds"] hidden_states = input_embeds + position_embeds with branch_random_state( @@ -348,6 +327,7 @@ def forward( hidden_encoder_states = {"hidden_states": hidden_states, "sequence_mask": input_mask} for encoder_block in self.decoder: hidden_encoder_states = encoder_block(**hidden_encoder_states) + #return hidden_encoder_states["hidden_states"] hidden_states = self.final_layer_norm(input=hidden_encoder_states["hidden_states"])["hidden_states"] @@ -356,3 +336,21 @@ def forward( fp32_sharded_logits = self.cast_to_fp32(x=sharded_logits)["output"] return fp32_sharded_logits + + +# TODO: maybe reimplement: +# - tie_custom_params +# - get_embeddings_lm_head_tied_names +# - get_block_compute_costs +# - get_flops_per_sec +class GPT3ForTraining(Starcoder2ForTraining): + def __init__( + self, + config: GPT3Config, + parallel_context: ParallelContext, + parallel_config: Optional[ParallelismArgs], + random_states: RandomStates, + ): + with replace_gpt3model(config): + super().__init__(config.as_starcoder2(), parallel_context, parallel_config, random_states) + From 6294aad23daeff0af49b63303c13aebfe0fc6954 Mon Sep 17 00:00:00 2001 From: AleHD Date: Tue, 9 Jul 2024 16:46:55 +0200 Subject: [PATCH 28/51] Added hf2nt frontend + tested training --- examples/xglm/README.md | 13 +++ examples/xglm/convert_hf2nt.py | 86 ++++++++++++++-- examples/xglm/example_config.yaml | 98 +++++++++++++++++++ src/nanotron/config/models_config.py | 6 +- src/nanotron/models/gpt3.py | 23 +---- .../optimizer_from_gradient_accumulator.py | 3 +- src/nanotron/trainer.py | 2 + 7 files changed, 199 insertions(+), 32 deletions(-) create mode 100644 examples/xglm/README.md create mode 100644 examples/xglm/example_config.yaml diff --git a/examples/xglm/README.md b/examples/xglm/README.md new file mode 100644 index 00000000..abc50f95 --- /dev/null +++ b/examples/xglm/README.md @@ 
-0,0 +1,13 @@ +# How to use XGLM? + +1. First, make sure to convert the weights from huggingface, for instance: + ``` + torchrun --nproc-per-node=1 examples/xglm/convert_hf2nt.py --checkpoint-path=facebook/xglm-564M --save-path=$SCRATCH/checkpoints/xglm-564M + ``` + +1. Now you are ready to use XGLM. + Make sure you use a .yaml configuration with proper GPT3 config and then run for instance: + ``` + torchrun --nproc-per-node=4 run_train.py --config-file=examples/xglm/example_config.yaml + ``` + If you use this configuration file make sure to modify at least the loading path in `model.init_method.path`. diff --git a/examples/xglm/convert_hf2nt.py b/examples/xglm/convert_hf2nt.py index 6e6ddff1..9db5ed93 100644 --- a/examples/xglm/convert_hf2nt.py +++ b/examples/xglm/convert_hf2nt.py @@ -1,27 +1,42 @@ +""" +Converts a HF model to nanotron format +Command: + torchrun --nproc-per-node=1 convert_hf2nt.py --checkpoint-path=hf_weights --save-path=nanotron_weights +""" + +import json +import warnings +import dataclasses +from argparse import ArgumentParser +from pathlib import Path + import torch from torch import nn - from transformers.models.xglm.modeling_xglm import XGLMAttention, XGLMConfig, XGLMDecoderLayer, XGLMForCausalLM + +import nanotron from nanotron.models.gpt3 import CausalSelfAttention, GPTBlock, MLP, GPT3ForTraining from nanotron.config.models_config import GPT3Config +from nanotron.trainer import mark_tied_parameters + def convert_config(config: XGLMConfig) -> GPT3Config: # TODOs: - # dropout=0.1, # layerdrop=0.0, # init_std=0.02, # use_cache=True, - # decoder_start_token_id=2, # pad_token_id=1, # bos_token_id=0, - - # TODO: when going gpt3->xglm: - # - assert layernorm is 1e-05 + if config.dropout != config.attention_dropout: + warnings.warn(f"huggingface.dropout = {config.dropout} does not match with " + f"huggingface.attention_dropout = {config.attention_dropout}. 
" + "Nanotron implementation needs these two values to be equal " + "for correct conversion.") return GPT3Config( activation_function=config.activation_function, attn_pdrop=config.attention_dropout, - embd_pdrop=0.0, # TODO + embd_pdrop=config.dropout, eos_token_id=config.eos_token_id, hidden_size=config.d_model, intermediate_size=config.ffn_dim, @@ -29,12 +44,12 @@ def convert_config(config: XGLMConfig) -> GPT3Config: max_position_embeddings=config.max_position_embeddings, num_attention_heads=config.attention_heads, num_hidden_layers=config.num_layers, - resid_pdrop=0.0, # TODO + resid_pdrop=config.dropout, scale_attention_softmax_in_fp32=True, scale_attn_weights=True, vocab_size=config.vocab_size, sinusoidal_position_embedding=True, - position_embedding_offset=2, + position_embedding_offset=config.decoder_start_token_id, use_spda=False, act_pdrop=config.activation_dropout, scale_embedding=config.scale_embedding, @@ -92,3 +107,56 @@ def convert(model_nt: GPT3ForTraining, model_hf: XGLMForCausalLM): convert_decoder(layer_nt.pp_block, layer_hf) convert_generic(model_nt.model.final_layer_norm.pp_block, model_hf.model.layer_norm) convert_generic(model_nt.model.lm_head.pp_block, model_hf.lm_head) + + +def create_nt_model(model_config: GPT3Config, device: torch.device = torch.device("cuda"), + dtype: torch.dtype = torch.bfloat16) -> GPT3ForTraining: + + parallel_config = nanotron.config.ParallelismArgs(dp=1, pp=1, tp=1) + parallel_context = nanotron.parallel.ParallelContext( + data_parallel_size=parallel_config.dp, + pipeline_parallel_size=parallel_config.pp, + tensor_parallel_size=parallel_config.tp, + ) + #random_states = nanotron.random.RandomStates({"tp_synced": nanotron.random.get_current_random_state()}) + model_nt = nanotron.models.build_model( + model_builder=lambda: GPT3ForTraining( + config=model_config, + parallel_context=parallel_context, + parallel_config=parallel_config, + random_states=None, + ), + parallel_context=parallel_context, + dtype=dtype, + device=device, + ) + mark_tied_parameters(model=model_nt, parallel_context=parallel_context) + return model_nt + + +def main(hf_path: str, save_path: Path): + # Load hf. + print("Loading hf...") + model_hf = XGLMForCausalLM.from_pretrained(hf_path) + + # Init nanotron. + print("Initializing nt...") + config_nt = convert_config(model_hf.config) + model_nt = create_nt_model(config_nt) + + # Copy weights and save model. 
+ print("Copying weights...") + convert(model_nt, model_hf) + nanotron.serialize.save_weights(model=model_nt, parallel_context=model_nt.parallel_context, + root_folder=save_path) + with open(save_path/"model_config.json", "w+") as f: + json.dump(dataclasses.asdict(config_nt), f) + print(f"Model saved to {save_path}") + + +if __name__ == "__main__": + parser = ArgumentParser(description="Convert HF weights to nanotron format") + parser.add_argument("--checkpoint-path", default="facebook/xglm-7.5B", help="Name or path to the huggingface checkpoint") + parser.add_argument("--save-path", type=Path, default="checkpoints/xglm-7.5B", help="Path to save the nanotron model") + args = parser.parse_args() + main(args.checkpoint_path, args.save_path) diff --git a/examples/xglm/example_config.yaml b/examples/xglm/example_config.yaml new file mode 100644 index 00000000..2d7e9926 --- /dev/null +++ b/examples/xglm/example_config.yaml @@ -0,0 +1,98 @@ +checkpoints: + checkpoint_interval: 1000 + checkpoints_path: checkpoints/xglm + checkpoints_path_is_shared_file_system: false + resume_checkpoint_path: null + save_initial_state: false +data_stages: +- data: + dataset: + dataset_overwrite_cache: false + dataset_processing_num_proc_per_process: 64 + hf_dataset_config_name: null + hf_dataset_or_datasets: DKYoon/SlimPajama-6B + hf_dataset_splits: train + text_column_name: text + num_loading_workers: 1 + seed: 42 + name: Finetuning + start_training_step: 1 +general: + benchmark_csv_path: null + consumed_train_samples: null + ignore_sanity_checks: true + project: xglm-test + run: xglm-dp4tp1pp1 + seed: 42 + step: null +lighteval: null +logging: + iteration_step_info_interval: 1 + log_level: info + log_level_replica: info +model: + ddp_bucket_cap_mb: 25 + dtype: bfloat16 + init_method: + path: /capstor/scratch/cscs/ahernnde/checkpoints/xglm-564M + make_vocab_size_divisible_by: 1 + model_config: + activation_function: gelu + attn_pdrop: 0.1 + embd_pdrop: 0.1 + scale_embedding: true + eos_token_id: 2 + hidden_size: 1024 + intermediate_size: 4096 + layer_norm_epsilon: 0.00001 + max_position_embeddings: 2048 + num_attention_heads: 16 + num_hidden_layers: 24 + resid_pdrop: 0.1 + scale_attention_softmax_in_fp32: true + scale_attn_weights: true + vocab_size: 256008 + sinusoidal_position_embedding: true + position_embedding_offset: 2 + use_spda: false + act_pdrop: 0.0 +optimizer: + accumulate_grad_in_fp32: true + clip_grad: 1.0 + learning_rate_scheduler: + learning_rate: 0.0003 + lr_decay_starting_step: null + lr_decay_steps: 900 + lr_decay_style: cosine + lr_warmup_steps: 100 + lr_warmup_style: linear + min_decay_lr: 1.0e-04 + optimizer_factory: + adam_beta1: 0.9 + adam_beta2: 0.95 + adam_eps: 1.0e-08 + name: adamW + torch_adam_is_fused: true + weight_decay: 0.01 + zero_stage: 0 +parallelism: + dp: 4 + expert_parallel_size: 1 + pp: 1 + pp_engine: 1f1b + tp: 1 + tp_linear_async_communication: false + tp_mode: REDUCE_SCATTER +profiler: null +tokenizer: + tokenizer_max_length: null + tokenizer_name_or_path: facebook/xglm-564M + tokenizer_revision: null +tokens: + batch_accumulation_per_replica: 4 + limit_test_batches: 0 + limit_val_batches: 0 + micro_batch_size: 8 + sequence_length: 2048 + train_steps: 1000 + val_check_interval: -1 diff --git a/src/nanotron/config/models_config.py b/src/nanotron/config/models_config.py index 56d6411f..20a92126 100644 --- a/src/nanotron/config/models_config.py +++ b/src/nanotron/config/models_config.py @@ -167,6 +167,8 @@ def as_starcoder2(self) -> Starcoder2Config: del 
config["position_embedding_offset"] del config["act_pdrop"] del config["scale_embedding"] + if "_is_using_mup" in config: + del config["_is_using_mup"] return Starcoder2Config( grouped_query=True, num_kv_heads=self.num_attention_heads, @@ -174,6 +176,4 @@ def as_starcoder2(self) -> Starcoder2Config: **config ) - -NanotronConfigs = Union[LlamaConfig, Starcoder2Config, Any] - +NanotronConfigs = LlamaConfig | Starcoder2Config | GPT3Config diff --git a/src/nanotron/models/gpt3.py b/src/nanotron/models/gpt3.py index 99f6ea85..33661c8b 100644 --- a/src/nanotron/models/gpt3.py +++ b/src/nanotron/models/gpt3.py @@ -18,24 +18,13 @@ from nanotron.parallel.pipeline_parallel.block import PipelineBlock from nanotron.models.starcoder2 import CoreAttention as Starcoder2CoreAttention from nanotron.models.starcoder2 import GPTBlock as Starcoder2GPTBlock -from nanotron.models.starcoder2 import CausalSelfGQA, Starcoder2ForTraining, GPTModel +from nanotron.models.starcoder2 import CausalSelfGQA, Starcoder2ForTraining, GPTModel, dropout_add_fused_train from nanotron.random import RandomStates, branch_random_state from nanotron.parallel.pipeline_parallel.tensor_pointer import TensorPointer from nanotron.parallel.tensor_parallel.enum import TensorParallelLinearMode from nanotron.parallel.tensor_parallel.nn import TensorParallelEmbedding from nanotron.parallel.tied_parameters import tie_parameters -# NOTES: -# - tie head_weight with embeddings I think. - -# TODO: -# - class GPT3Config: config lol -# - check that attention (i.e. nanotron.attn vs xglm.self_attn) is the same. -# - from starcoder import Embedding -# - class PositionEmbedding: my sinusoidal embedding extends from TensorParallelEmbedding -# - class GPTBlock: very similar to starcoder2 but make it so it support non-GQA or MQA -# - from starcoder import Loss - @contextmanager def replace_coreattention(gpt3config: GPT3Config): @@ -130,7 +119,6 @@ def __init__( with replace_coreattention(config): super().__init__(config.as_starcoder2(), parallel_config, tp_pg, layer_idx) self.maybe_rotary = lambda q, k, **_: (q, k) # Overwrite possible rotary with identity. - #self.attention = CoreAttention(config, parallel_config=parallel_config, layer_idx=layer_idx) # Use our custom CoreAttention. class MLP(Starcoder2MLP): @@ -204,7 +192,6 @@ def forward( hidden_states = dropout_add_fused_train(hidden_states, residual=residual, prob=self.attn_dropout) else: # No need for random state context manager - # TODO: add dropout scaling? hidden_states = hidden_states + residual residual = hidden_states @@ -218,7 +205,6 @@ def forward( hidden_states = dropout_add_fused_train(hidden_states, residual=residual, prob=self.ff_dropout) else: # No need for random state context manager - # TODO: add dropout scaling? 
hidden_states = hidden_states + residual return { @@ -235,7 +221,7 @@ def __init__(self, tp_pg: dist.ProcessGroup, config: GPT3Config, parallel_config if (config.max_position_embeddings + config.position_embedding_offset) % tp_pg.size() == 0: dummy_pos = 0 else: - dummy_pos = tp_pg.size() - ((config.max_position_embeddings + config.position_embedding_offset) % k) + dummy_pos = tp_pg.size() - ((config.max_position_embeddings + config.position_embedding_offset) % tp_pg.size()) true_max_size = config.max_position_embeddings + config.position_embedding_offset + dummy_pos if config.sinusoidal_position_embedding: @@ -278,7 +264,7 @@ def _make_weights(self, tp_pg: dist.ProcessGroup, num_embeddings: int, half_dim = embedding_dim//2 emb = math.log(10_000)/(half_dim - 1) emb = torch.exp(torch.arange(half_dim, dtype=torch.int64).float() * -emb) - emb = (rank*block_size + torch.arange(num_embeddings, dtype=torch.int64).float().unsqueeze(1)) * emb.unsqueeze(0) + emb = (rank*block_size + torch.arange(block_size, dtype=torch.int64).float().unsqueeze(1)) * emb.unsqueeze(0) emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1).view(block_size, embedding_dim) return emb @@ -315,6 +301,7 @@ def forward( # all tensors are optional as most ranks don't need anything from the dataloader. input_embeds = self.token_embeddings(input_ids=input_ids, input_mask=input_mask)["input_embeds"]*self.embed_scale + # TODO: position_ids could be cached. position_ids = torch.arange(input_ids.size(1), device="cuda").repeat(input_ids.size(0)).view(*input_ids.size()) position_embeds = self.position_embeddings(position_ids=position_ids)["position_embeds"] hidden_states = input_embeds + position_embeds @@ -339,8 +326,6 @@ def forward( # TODO: maybe reimplement: -# - tie_custom_params -# - get_embeddings_lm_head_tied_names # - get_block_compute_costs # - get_flops_per_sec class GPT3ForTraining(Starcoder2ForTraining): diff --git a/src/nanotron/optim/optimizer_from_gradient_accumulator.py b/src/nanotron/optim/optimizer_from_gradient_accumulator.py index 01be7cb5..9883c720 100644 --- a/src/nanotron/optim/optimizer_from_gradient_accumulator.py +++ b/src/nanotron/optim/optimizer_from_gradient_accumulator.py @@ -38,7 +38,8 @@ def __init__( **{k: v for k, v in named_param_group.items() if k != "named_params"}, "named_params": [ (name, gradient_accumulator.get_parameter_for_optimizer(name)) - for name, _ in named_param_group["named_params"] + for name, param in named_param_group["named_params"] + if param.requires_grad ], } for named_param_group in named_param_groups diff --git a/src/nanotron/trainer.py b/src/nanotron/trainer.py index b6752f38..f01caa3e 100644 --- a/src/nanotron/trainer.py +++ b/src/nanotron/trainer.py @@ -58,6 +58,7 @@ from nanotron.models.base import check_model_has_grad from nanotron.models.llama import LlamaForTraining, RotaryEmbedding from nanotron.models.starcoder2 import Starcoder2ForTraining +from nanotron.models.gpt3 import GPT3ForTraining from nanotron.optim.clip_grads import clip_grad_norm from nanotron.parallel import ParallelContext from nanotron.parallel.data_parallel.utils import sync_gradients_across_dp @@ -103,6 +104,7 @@ CONFIG_TO_MODEL_CLASS = { "LlamaConfig": LlamaForTraining, "Starcoder2Config": Starcoder2ForTraining, + "GPT3Config": GPT3ForTraining, } try: From b469ee90de19ea17fd92468ed72dcaba63a949fe Mon Sep 17 00:00:00 2001 From: AleHD Date: Thu, 11 Jul 2024 13:38:52 +0200 Subject: [PATCH 29/51] Added nt2hf conversion + tests :) --- examples/xglm/README.md | 5 + examples/xglm/convert_hf2nt.py | 38 
+---- examples/xglm/convert_nt2hf.py | 126 +++++++++++++++ examples/xglm/convert_utils.py | 59 +++++++ examples/xglm/tests/test_implementation.py | 177 +++++++++++++++++---- src/nanotron/config/models_config.py | 4 + src/nanotron/models/gpt3.py | 2 +- 7 files changed, 347 insertions(+), 64 deletions(-) create mode 100644 examples/xglm/convert_nt2hf.py create mode 100644 examples/xglm/convert_utils.py diff --git a/examples/xglm/README.md b/examples/xglm/README.md index abc50f95..22765f52 100644 --- a/examples/xglm/README.md +++ b/examples/xglm/README.md @@ -11,3 +11,8 @@ torchrun --nproc-per-node=4 run_train.py --config-file=examples/xglm/example_config.yaml ``` If you use this configuration file make sure to modify at least the loading path in `model.init_method.path`. + +1. If you want to convert your finetuned checkpoint back to huggingface use: + ``` + torchrun --nproc-per-node=1 examples/xglm/convert_nt2hf.py --checkpoint-path=checpoints/xglm --save-path=$SCRATCH/checkpoints/huggingface/xglm-564M --tokenizer-name=facebook/xglm-564M + ``` diff --git a/examples/xglm/convert_hf2nt.py b/examples/xglm/convert_hf2nt.py index 9db5ed93..0efcceca 100644 --- a/examples/xglm/convert_hf2nt.py +++ b/examples/xglm/convert_hf2nt.py @@ -18,11 +18,11 @@ from nanotron.models.gpt3 import CausalSelfAttention, GPTBlock, MLP, GPT3ForTraining from nanotron.config.models_config import GPT3Config from nanotron.trainer import mark_tied_parameters - +from examples.xglm.convert_utils import convert_generic, create_nt_model def convert_config(config: XGLMConfig) -> GPT3Config: - # TODOs: + # These settings seem to be unused: # layerdrop=0.0, # init_std=0.02, # use_cache=True, @@ -80,15 +80,6 @@ def convert_attention(attn_nt: CausalSelfAttention, attn_hf: XGLMAttention): attn_nt.dense.bias.data = attn_hf.out_proj.bias.clone() -def convert_generic(module1: nn.Module, module2: nn.Module): - names1 = {name for name, _ in module1.named_parameters()} - names2 = {name for name, _ in module2.named_parameters()} - assert names1 == names2, f"{names1} != {names2}" - params2 = dict(module2.named_parameters()) - for name, param in module1.named_parameters(): - param.data = params2[name].clone() - - def convert_mlp(mlp_nt: MLP, block_hf: XGLMDecoderLayer): convert_generic(mlp_nt.c_fc, block_hf.fc1) convert_generic(mlp_nt.c_proj, block_hf.fc2) @@ -109,31 +100,6 @@ def convert(model_nt: GPT3ForTraining, model_hf: XGLMForCausalLM): convert_generic(model_nt.model.lm_head.pp_block, model_hf.lm_head) -def create_nt_model(model_config: GPT3Config, device: torch.device = torch.device("cuda"), - dtype: torch.dtype = torch.bfloat16) -> GPT3ForTraining: - - parallel_config = nanotron.config.ParallelismArgs(dp=1, pp=1, tp=1) - parallel_context = nanotron.parallel.ParallelContext( - data_parallel_size=parallel_config.dp, - pipeline_parallel_size=parallel_config.pp, - tensor_parallel_size=parallel_config.tp, - ) - #random_states = nanotron.random.RandomStates({"tp_synced": nanotron.random.get_current_random_state()}) - model_nt = nanotron.models.build_model( - model_builder=lambda: GPT3ForTraining( - config=model_config, - parallel_context=parallel_context, - parallel_config=parallel_config, - random_states=None, - ), - parallel_context=parallel_context, - dtype=dtype, - device=device, - ) - mark_tied_parameters(model=model_nt, parallel_context=parallel_context) - return model_nt - - def main(hf_path: str, save_path: Path): # Load hf. 
print("Loading hf...") diff --git a/examples/xglm/convert_nt2hf.py b/examples/xglm/convert_nt2hf.py new file mode 100644 index 00000000..422695a1 --- /dev/null +++ b/examples/xglm/convert_nt2hf.py @@ -0,0 +1,126 @@ +""" +Converts a nanotron model to HF format +Command: + torchrun --nproc-per-node=1 convert_nt2hf.py --checkpoint-path=nanotron_weights --save-path=hf_weights +""" + +from argparse import ArgumentParser +from typing import Optional +from pathlib import Path + +import torch +from transformers import AutoTokenizer +from transformers.models.xglm.modeling_xglm import XGLMAttention, XGLMConfig, XGLMDecoderLayer, XGLMForCausalLM + +from nanotron.config.models_config import GPT3Config +from nanotron.models.gpt3 import CausalSelfAttention, GPTBlock, MLP, GPT3ForTraining +from examples.xglm.convert_utils import convert_generic, create_nt_model + + +def convert_config(config: GPT3Config) -> XGLMConfig: + if config.embd_pdrop != config.resid_pdrop: + warnings.warn(f"nanotron.embd_pdrop = {config.embd_pdrop} does not match with " + f"nanotron.resid_pdrop = {config.resid_pdrop}. " + "XGLM implementation needs these two values to be equal " + "for correct conversion.") + if config.layer_norm_epsilon != 1e-5: + warnings.warn(f"nanotron.layer_norm_epsilon must be 1e-5, not {config.layer_norm_epsilon}") + return XGLMConfig( + activation_function=config.activation_function, + attention_dropout=config.attn_pdrop, + dropout=config.embd_pdrop, + eos_token_id=config.eos_token_id, + d_model=config.hidden_size, + ffn_dim=config.intermediate_size, + max_position_embeddings=config.max_position_embeddings, + attention_heads=config.num_attention_heads, + num_layers=config.num_hidden_layers, + vocab_size=config.vocab_size, + decoder_start_token_id=config.position_embedding_offset, + activation_dropout=config.act_pdrop, + scale_embedding=config.scale_embedding, + ) + + +def convert_attention(attn_hf: XGLMAttention, attn_nt: XGLMAttention): + qs_w = [] + ks_w = [] + vs_w = [] + qs_b = [] + ks_b = [] + vs_b = [] + + head_dim = attn_hf.head_dim + qkv_ws = list(attn_nt.query_key_value.weight.split(head_dim)) + qkv_bs = list(attn_nt.query_key_value.bias.split(head_dim)) + for i, (w, b) in enumerate(zip(qkv_ws, qkv_bs)): + if i % 3 == 0: + qs_w.append(w) + qs_b.append(b) + elif i % 3 == 1: + ks_w.append(w) + ks_b.append(b) + else: + vs_w.append(w) + vs_b.append(b) + + q_w = torch.cat(qs_w) + k_w = torch.cat(ks_w) + v_w = torch.cat(vs_w) + q_b = torch.cat(qs_b) + k_b = torch.cat(ks_b) + v_b = torch.cat(vs_b) + + with torch.no_grad(): + attn_hf.q_proj.weight.data = q_w.clone() + attn_hf.k_proj.weight.data = k_w.clone() + attn_hf.v_proj.weight.data = v_w.clone() + attn_hf.q_proj.bias.data = q_b.clone() + attn_hf.k_proj.bias.data = k_b.clone() + attn_hf.v_proj.bias.data = v_b.clone() + + attn_hf.out_proj.weight.data = attn_nt.dense.weight.data.clone() + attn_hf.out_proj.bias.data = attn_nt.dense.bias.data.clone() + + +def convert_decoder(block_hf: XGLMDecoderLayer, block_nt: GPTBlock): + convert_generic(block_hf.self_attn_layer_norm, block_nt.ln_1) + convert_attention(block_hf.self_attn, block_nt.attn) + convert_generic(block_hf.final_layer_norm, block_nt.ln_2) + convert_generic(block_hf.fc1, block_nt.ff.c_fc) + convert_generic(block_hf.fc2, block_nt.ff.c_proj) + + +def convert(model_hf: XGLMForCausalLM, model_nt: GPT3ForTraining): + convert_generic(model_hf.model.embed_tokens, model_nt.model.token_embeddings.pp_block.token_embedding) + for layer_hf, layer_nt in zip(model_hf.model.layers, model_nt.model.decoder): + 
convert_decoder(layer_hf, layer_nt.pp_block) + convert_generic(model_hf.model.layer_norm, model_nt.model.final_layer_norm.pp_block) + convert_generic(model_hf.lm_head, model_nt.model.lm_head.pp_block) + + +def main(checkpoint_path: Path, save_path: Path, tokenizer_name: Optional[str]): + # Load nanotron model. + model_nt = create_nt_model(checkpoint_path=checkpoint_path) + + # Init huggingface model. + model_config_hf = convert_config(model_nt.config) + model_hf = XGLMForCausalLM._from_config(model_config_hf) + + # Copy weights, initialize tokenizer and save model. + if tokenizer_name is not None: + tokenizer = AutoTokenizer.from_pretrained(tokenizer_name) + tokenizer.save_pretrained(save_path) + convert(model_hf, model_nt) + model_hf.save_pretrained(save_path) + print(f"Model saved to {save_path}") + + +if __name__ == "__main__": + parser = ArgumentParser(description="Convert HF weights to nanotron format") + parser.add_argument("--checkpoint-path", type=Path, default="checkpoints/xglm-7.5B", help="Path to the nanotron checkpoint") + parser.add_argument("--save-path", type=Path, default="facebook/xglm-7.5B", help="Path to save the huggingface model") + parser.add_argument("--tokenizer-name", type=str, default="facebook/xglm-7.5B") + args = parser.parse_args() + main(args.checkpoint_path, args.save_path, args.tokenizer_name) + diff --git a/examples/xglm/convert_utils.py b/examples/xglm/convert_utils.py new file mode 100644 index 00000000..88a731a1 --- /dev/null +++ b/examples/xglm/convert_utils.py @@ -0,0 +1,59 @@ +import json +from pathlib import Path +from typing import Optional + +import torch +from torch import nn + +import nanotron +from nanotron.models.gpt3 import GPT3ForTraining +from nanotron.config.models_config import GPT3Config +from nanotron.trainer import mark_tied_parameters + + +def convert_generic(module1: nn.Module, module2: nn.Module): + names1 = {name for name, _ in module1.named_parameters()} + names2 = {name for name, _ in module2.named_parameters()} + assert names1 == names2, f"{names1} != {names2}" + params2 = dict(module2.named_parameters()) + for name, param in module1.named_parameters(): + param.data = params2[name].clone() + + +def create_nt_model( + model_config: Optional[GPT3Config] = None, + device: torch.device = torch.device("cuda"), + dtype: torch.dtype = torch.bfloat16, + checkpoint_path: Optional[Path] = None + ): + + if model_config is None: + assert checkpoint_path is not None + with open(checkpoint_path / "model_config.json") as f: + model_config = GPT3Config(**json.load(f)) + + parallel_config = nanotron.config.ParallelismArgs(dp=1, pp=1, tp=1) + parallel_context = nanotron.parallel.ParallelContext( + data_parallel_size=parallel_config.dp, + pipeline_parallel_size=parallel_config.pp, + tensor_parallel_size=parallel_config.tp, + ) + model_nt = nanotron.models.build_model( + model_builder=lambda: GPT3ForTraining( + config=model_config, + parallel_context=parallel_context, + parallel_config=parallel_config, + random_states=None, + ), + parallel_context=parallel_context, + dtype=dtype, + device=device, + ) + mark_tied_parameters(model=model_nt, parallel_context=parallel_context) + + if checkpoint_path is not None: + nanotron.serialize.load_weights( + model=model_nt, parallel_context=parallel_context, root_folder=checkpoint_path + ) + + return model_nt diff --git a/examples/xglm/tests/test_implementation.py b/examples/xglm/tests/test_implementation.py index 3636415b..d9dc0f85 100644 --- a/examples/xglm/tests/test_implementation.py +++ 
b/examples/xglm/tests/test_implementation.py @@ -8,6 +8,7 @@ from transformers.models.xglm.modeling_xglm import XGLMConfig, XGLMAttention, XGLMSinusoidalPositionalEmbedding, XGLMDecoderLayer, XGLMForCausalLM import nanotron +from nanotron.trainer import mark_tied_parameters from nanotron.config.models_config import GPT3Config from nanotron.models.gpt3 import GPT3ForTraining, CausalSelfAttention, PositionEmbedding, GPTBlock from nanotron.parallel import ParallelContext @@ -15,12 +16,17 @@ from tests.helpers.utils import init_distributed from examples.xglm.convert_hf2nt import convert_attention, convert_config, convert_decoder, convert +from examples.xglm.convert_nt2hf import convert_attention as convert_attention_nt2hf +from examples.xglm.convert_nt2hf import convert_config as convert_config_nt2hf +from examples.xglm.convert_nt2hf import convert_decoder as convert_decoder_nt2hf +from examples.xglm.convert_nt2hf import convert as convert_nt2hf -SEQUENCE_LENGTH = 2048 +MAX_SEQUENCE_LENGTH = 2048 +TEST_SEQUENCE_LENGTH = 128 # If we test with a very large sequence length, precision errors get more significant independent of the correct implementation. BATCH_SIZE = 4 HIDDEN_SIZE = 1024 -DTYPE = torch.bfloat16 +DTYPE = torch.float64 TEXT = "Hello. This is a relatively long text. I will use this text to test the conversion scripts. Let's finish this text soon because I don't have much more to say. Final note:" CONFIG = GPT3Config( @@ -32,7 +38,7 @@ hidden_size=HIDDEN_SIZE, intermediate_size=4096, layer_norm_epsilon=1e-05, - max_position_embeddings=SEQUENCE_LENGTH, + max_position_embeddings=MAX_SEQUENCE_LENGTH, num_attention_heads=16, num_hidden_layers=24, scale_attn_weights=True, @@ -45,25 +51,39 @@ @pytest.fixture def hidden_states() -> torch.Tensor: - return torch.randn(SEQUENCE_LENGTH, BATCH_SIZE, HIDDEN_SIZE, + return torch.randn(TEST_SEQUENCE_LENGTH, BATCH_SIZE, HIDDEN_SIZE, dtype=DTYPE) @pytest.fixture def input_mask() -> torch.Tensor: - return torch.ones(BATCH_SIZE, SEQUENCE_LENGTH, dtype=torch.bool) + return torch.ones(BATCH_SIZE, TEST_SEQUENCE_LENGTH, dtype=torch.bool) @pytest.fixture def input_ids() -> torch.Tensor: - return torch.randint(0, CONFIG.vocab_size, (BATCH_SIZE, SEQUENCE_LENGTH)) + return torch.randint(0, CONFIG.vocab_size, (BATCH_SIZE, TEST_SEQUENCE_LENGTH)) + + +def almost_close(t1: torch.Tensor, t2: torch.Tensor, atol: float = 1e-5, rtol: float = 0.016, + max_far: float = 0.0, far_atol: float = 0.01): + very_close = torch.abs(t1 - t2) <= atol + rtol*torch.abs(t2) + not_very_close = ~very_close + + if torch.all(very_close): + return + assert torch.mean(not_very_close.float()) <= max_far, f"not very close found: {100*torch.mean(not_very_close.float()):.1f}%" + assert torch.all(torch.abs(t1[not_very_close] - t2[not_very_close]) <= far_atol), f"Worse deviation found: {torch.max(torch.abs(t1 - t2)):.4f}" def attention_mask() -> torch.Tensor: # XGLM causal attention mask. 
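# Note on masking conventions: the HF XGLM layers take an additive float mask of shape
# (batch, 1, q_len, kv_len) with 0.0 on allowed positions and -inf elsewhere, while the
# nanotron modules take the boolean `sequence_mask` and apply causality internally
# (is_causal=True); both encode the same causal structure, so the outputs stay comparable.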
- mask = torch.ones(SEQUENCE_LENGTH, SEQUENCE_LENGTH, dtype=torch.bool, device="cuda").tril(diagonal=0) + mask = torch.ones(TEST_SEQUENCE_LENGTH, TEST_SEQUENCE_LENGTH, dtype=torch.bool, device="cuda").tril(diagonal=0) mask = torch.where(mask, 0.0, -np.inf).to(DTYPE) mask = mask.repeat(BATCH_SIZE, 1, 1).unsqueeze(1) return mask +## +# FROM HERE DOWN (until next comment), all tests are hf->nt +## def _test_attention(parallel_context: ParallelContext, hidden_states: torch.Tensor, sequence_mask: torch.Tensor): hidden_states = hidden_states.cuda() @@ -85,10 +105,10 @@ def test_attention(hidden_states: torch.Tensor, input_mask: torch.Tensor): def _test_position_embeddings(parallel_context: ParallelContext): - position_ids = torch.arange(SEQUENCE_LENGTH, device="cuda").unsqueeze(0) # shape = (1, SEQUENCE_LENGTH) + position_ids = torch.arange(TEST_SEQUENCE_LENGTH, device="cuda").unsqueeze(0) # shape = (1, TEST_SEQUENCE_LENGTH) emb_nt = PositionEmbedding(parallel_context.tp_pg, CONFIG, None).cuda() - emb_hf = XGLMSinusoidalPositionalEmbedding(SEQUENCE_LENGTH, HIDDEN_SIZE).cuda() + emb_hf = XGLMSinusoidalPositionalEmbedding(MAX_SEQUENCE_LENGTH, HIDDEN_SIZE).cuda() assert emb_nt.position_embedding.weight.size() == emb_hf.weights.size() torch.testing.assert_close(emb_nt.position_embedding.weight, emb_hf.weights) @@ -120,7 +140,7 @@ def _test_decoder(parallel_context: ParallelContext, hidden_states: torch.Tensor out_hf = decoder_hf(hidden_states.permute(1, 0, 2), attention_mask=attention_mask())[0].permute(1, 0, 2) assert out_nt.size() == out_hf.size(), f"{out_nt.size()}, {out_hf.size()}" - torch.testing.assert_close(out_nt, out_hf) + torch.testing.assert_close(out_nt.bfloat16(), out_hf.bfloat16()) # We cast to bf16 to get more relaxed constraints. def test_decoder(hidden_states: torch.Tensor, input_mask: torch.Tensor): @@ -129,21 +149,25 @@ def test_decoder(hidden_states: torch.Tensor, input_mask: torch.Tensor): def _test_model(model_hf: Optional[XGLMForCausalLM], parallel_context: ParallelContext, input_ids: torch.Tensor, input_mask: torch.Tensor): + random_states = nanotron.random.RandomStates({"tp_synced": nanotron.random.get_current_random_state()}) input_ids = input_ids.cuda() input_mask = input_mask.cuda() + # unfortunately, we can't use float64 with huggingface xglm. + new_dtype = torch.float32 if DTYPE == torch.float64 else DTYPE + # Get hf model. if model_hf is None: config_hf = XGLMConfig() - model_hf = XGLMForCausalLM(config_hf).cuda().to(DTYPE).eval() + model_hf = XGLMForCausalLM(config_hf).cuda().to(new_dtype).eval() else: - model_hf = model_hf.cuda().to(DTYPE).eval() + model_hf = model_hf.cuda().to(new_dtype).eval() config_hf = model_hf.config # Get nanotron model and make the conversion. config_nt = convert_config(config_hf) - if DTYPE not in {torch.bfloat16, torch.float16}: + if new_dtype not in {torch.bfloat16, torch.float16}: config_nt.use_spda = True model_nt = nanotron.models.build_model( model_builder=lambda: GPT3ForTraining( @@ -153,7 +177,7 @@ def _test_model(model_hf: Optional[XGLMForCausalLM], parallel_context: ParallelC random_states=random_states, ), parallel_context=parallel_context, - dtype=DTYPE, + dtype=new_dtype, device="cuda", ).eval() convert(model_nt, model_hf) @@ -162,42 +186,141 @@ def _test_model(model_hf: Optional[XGLMForCausalLM], parallel_context: ParallelC # Get outputs and assert. 
with torch.no_grad(): - out_nt = model_nt.model(input_ids, input_mask).to(DTYPE) + out_nt = model_nt.model(input_ids, input_mask).to(new_dtype) del model_nt torch.cuda.empty_cache() out_hf = model_hf(input_ids=input_ids, attention_mask=input_mask).logits.permute(1, 0, 2) del model_hf torch.cuda.empty_cache() assert out_nt.size() == out_hf.size(), f"{out_nt.size()}, {out_hf.size()}" - torch.testing.assert_close(out_nt.cpu(), out_hf.cpu()) + return out_nt.cpu(), out_hf.cpu() + def _test_dummy_xglm(parallel_context: ParallelContext, input_ids: torch.Tensor, input_mask: torch.Tensor): - _test_model(None, parallel_context, input_ids, input_mask) + out_nt, out_hf = _test_model(None, parallel_context, input_ids, input_mask) + almost_close(out_nt, out_hf, max_far=0.05) def test_dummy_xglm(input_ids: torch.Tensor, input_mask: torch.Tensor): init_distributed(tp=1, dp=1, pp=1)(_test_dummy_xglm)(input_ids=input_ids, input_mask=input_mask) +def _test_xglm500M(parallel_context: ParallelContext): + tok = XGLMTokenizer.from_pretrained("facebook/xglm-564M") + tokenized = tok(TEXT) + model_hf = XGLMForCausalLM.from_pretrained("facebook/xglm-564M") + out_nt, out_hf = _test_model(model_hf, parallel_context, + torch.tensor([tokenized["input_ids"]]), torch.tensor([tokenized["attention_mask"]])) + almost_close(out_nt, out_hf, max_far=0.1, far_atol=0.05) + + +def test_xglm500M(): + init_distributed(tp=1, dp=1, pp=1)(_test_xglm500M)() + + def _test_xglm7B(parallel_context: ParallelContext): tok = XGLMTokenizer.from_pretrained("facebook/xglm-7.5B") tokenized = tok(TEXT) model_hf = XGLMForCausalLM.from_pretrained("facebook/xglm-7.5B") - _test_model(model_hf, parallel_context, - torch.tensor([tokenized["input_ids"]]), torch.tensor([tokenized["attention_mask"]])) + out_nt, out_hf = _test_model(model_hf, parallel_context, + torch.tensor([tokenized["input_ids"]]), torch.tensor([tokenized["attention_mask"]])) + almost_close(out_nt, out_hf, max_far=0.15, far_atol=0.1) def test_xglm7B(): init_distributed(tp=1, dp=1, pp=1)(_test_xglm7B)() -def _test_xglm500M(parallel_context: ParallelContext): - tok = XGLMTokenizer.from_pretrained("facebook/xglm-564M") - tokenized = tok(TEXT) - model_hf = XGLMForCausalLM.from_pretrained("facebook/xglm-564M") - _test_model(model_hf, parallel_context, - torch.tensor([tokenized["input_ids"]]), torch.tensor([tokenized["attention_mask"]])) +## +# From here down we test nt->hf converters +## +def _test_nt2hf_attention(parallel_context: ParallelContext, hidden_states: torch.Tensor, sequence_mask: torch.Tensor): + hidden_states = hidden_states.cuda() + sequence_mask = sequence_mask.cuda() -def test_xglm500M(): - init_distributed(tp=1, dp=1, pp=1)(_test_xglm500M)() + attn_nt = CausalSelfAttention(CONFIG, None, parallel_context.tp_pg, 0).cuda().eval().to(DTYPE) + attn_hf = XGLMAttention(CONFIG.hidden_size, CONFIG.num_attention_heads, CONFIG.attn_pdrop).cuda().eval().to(DTYPE) + assert sum(map(torch.numel, attn_nt.parameters())) == sum(map(torch.numel, attn_hf.parameters())) + + convert_attention_nt2hf(attn_hf, attn_nt) + out_nt = attn_nt(hidden_states, sequence_mask)["hidden_states"] + out_hf = attn_hf(hidden_states.permute(1, 0, 2), attention_mask=attention_mask())[0].permute(1, 0, 2) + assert out_nt.size() == out_hf.size(), f"{out_nt.size()}, {out_hf.size()}" + torch.testing.assert_close(out_nt, out_hf) + + +def test_nt2hf_attention(hidden_states: torch.Tensor, input_mask: torch.Tensor): + init_distributed(tp=1, dp=1, pp=1)(_test_nt2hf_attention)(hidden_states=hidden_states, 
sequence_mask=input_mask) + + +def _test_nt2hf_decoder(parallel_context: ParallelContext, hidden_states: torch.Tensor, sequence_mask: torch.Tensor): + random_states = nanotron.random.RandomStates({"tp_synced": nanotron.random.get_current_random_state()}) + hidden_states = hidden_states.cuda() + sequence_mask = sequence_mask.cuda() + + config_hf = convert_config_nt2hf(CONFIG) + decoder_nt = GPTBlock(CONFIG, None, parallel_context.tp_pg, random_states, 0).cuda().to(DTYPE).eval() + decoder_hf = XGLMDecoderLayer(config_hf).cuda().to(DTYPE).eval() + + convert_decoder_nt2hf(decoder_hf, decoder_nt) + + out_nt = decoder_nt(hidden_states, sequence_mask)["hidden_states"] + out_hf = decoder_hf(hidden_states.permute(1, 0, 2), attention_mask=attention_mask())[0].permute(1, 0, 2) + + assert out_nt.size() == out_hf.size(), f"{out_nt.size()}, {out_hf.size()}" + torch.testing.assert_close(out_nt.bfloat16(), out_hf.bfloat16()) # We cast to bf16 to get more relaxed constraints. + + +def test_nt2hf_decoder(hidden_states: torch.Tensor, input_mask: torch.Tensor): + init_distributed(tp=1, dp=1, pp=1)(_test_nt2hf_decoder)(hidden_states=hidden_states, sequence_mask=input_mask) + + +def _test_nt2hf_model(parallel_context: ParallelContext, input_ids: torch.Tensor, input_mask: torch.Tensor): + random_states = nanotron.random.RandomStates({"tp_synced": nanotron.random.get_current_random_state()}) + input_ids = input_ids.cuda() + input_mask = input_mask.cuda() + + # unfortunately, we can't use float64 with huggingface xglm. + new_dtype = torch.float32 if DTYPE == torch.float64 else DTYPE + + # Get nanotron model. + config_nt = GPT3Config(**vars(CONFIG)) + if new_dtype not in {torch.bfloat16, torch.float16}: + config_nt.use_spda = True + model_nt = nanotron.models.build_model( + model_builder=lambda: GPT3ForTraining( + config=config_nt, + parallel_context=parallel_context, + parallel_config=None, + random_states=random_states, + ), + parallel_context=parallel_context, + dtype=new_dtype, + device="cuda", + ).eval() + mark_tied_parameters(model=model_nt, parallel_context=parallel_context) + + # Create empty model_hf and make conversion. + model_hf = XGLMForCausalLM(convert_config_nt2hf(config_nt)).cuda().to(new_dtype).eval() + convert_nt2hf(model_hf, model_nt) + + # Get outputs and assert. 
+ with torch.no_grad(): + out_nt = model_nt.model(input_ids, input_mask).to(new_dtype) + del model_nt + torch.cuda.empty_cache() + out_hf = model_hf(input_ids=input_ids, attention_mask=input_mask).logits.permute(1, 0, 2) + del model_hf + torch.cuda.empty_cache() + assert out_nt.size() == out_hf.size(), f"{out_nt.size()}, {out_hf.size()}" + return out_nt.cpu(), out_hf.cpu() + + +def _test_nt2hf_dummy_xglm(parallel_context: ParallelContext, input_ids: torch.Tensor, input_mask: torch.Tensor): + out_nt, out_hf = _test_nt2hf_model(parallel_context, input_ids, input_mask) + almost_close(out_nt, out_hf, max_far=0.01, far_atol=0.02) + + +def test_nt2hf_dummy_xglm(input_ids: torch.Tensor, input_mask: torch.Tensor): + init_distributed(tp=1, dp=1, pp=1)(_test_nt2hf_dummy_xglm)(input_ids=input_ids, input_mask=input_mask) diff --git a/src/nanotron/config/models_config.py b/src/nanotron/config/models_config.py index 20a92126..5f59a439 100644 --- a/src/nanotron/config/models_config.py +++ b/src/nanotron/config/models_config.py @@ -176,4 +176,8 @@ def as_starcoder2(self) -> Starcoder2Config: **config ) + @property + def n_inner(self): + return self.intermediate_size + NanotronConfigs = LlamaConfig | Starcoder2Config | GPT3Config diff --git a/src/nanotron/models/gpt3.py b/src/nanotron/models/gpt3.py index 33661c8b..7d4e6f82 100644 --- a/src/nanotron/models/gpt3.py +++ b/src/nanotron/models/gpt3.py @@ -338,4 +338,4 @@ def __init__( ): with replace_gpt3model(config): super().__init__(config.as_starcoder2(), parallel_context, parallel_config, random_states) - + self.config = config From 1b19ca28aa74531683870f9edf5226e02e884ac1 Mon Sep 17 00:00:00 2001 From: AleHD Date: Thu, 11 Jul 2024 14:32:44 +0200 Subject: [PATCH 30/51] precommit --- examples/xglm/convert_hf2nt.py | 33 ++++---- examples/xglm/convert_nt2hf.py | 28 ++++--- examples/xglm/convert_utils.py | 21 +++-- examples/xglm/tests/test_implementation.py | 89 ++++++++++++++-------- src/nanotron/config/models_config.py | 8 +- src/nanotron/models/gpt3.py | 85 ++++++++++++--------- src/nanotron/trainer.py | 2 +- 7 files changed, 154 insertions(+), 112 deletions(-) diff --git a/examples/xglm/convert_hf2nt.py b/examples/xglm/convert_hf2nt.py index 0efcceca..c18a1ab8 100644 --- a/examples/xglm/convert_hf2nt.py +++ b/examples/xglm/convert_hf2nt.py @@ -4,20 +4,18 @@ torchrun --nproc-per-node=1 convert_hf2nt.py --checkpoint-path=hf_weights --save-path=nanotron_weights """ +import dataclasses import json import warnings -import dataclasses from argparse import ArgumentParser from pathlib import Path +import nanotron import torch -from torch import nn +from nanotron.config.models_config import GPT3Config +from nanotron.models.gpt3 import MLP, CausalSelfAttention, GPT3ForTraining, GPTBlock from transformers.models.xglm.modeling_xglm import XGLMAttention, XGLMConfig, XGLMDecoderLayer, XGLMForCausalLM -import nanotron -from nanotron.models.gpt3 import CausalSelfAttention, GPTBlock, MLP, GPT3ForTraining -from nanotron.config.models_config import GPT3Config -from nanotron.trainer import mark_tied_parameters from examples.xglm.convert_utils import convert_generic, create_nt_model @@ -29,10 +27,12 @@ def convert_config(config: XGLMConfig) -> GPT3Config: # pad_token_id=1, # bos_token_id=0, if config.dropout != config.attention_dropout: - warnings.warn(f"huggingface.dropout = {config.dropout} does not match with " - f"huggingface.attention_dropout = {config.attention_dropout}. 
" - "Nanotron implementation needs these two values to be equal " - "for correct conversion.") + warnings.warn( + f"huggingface.dropout = {config.dropout} does not match with " + f"huggingface.attention_dropout = {config.attention_dropout}. " + "Nanotron implementation needs these two values to be equal " + "for correct conversion." + ) return GPT3Config( activation_function=config.activation_function, attn_pdrop=config.attention_dropout, @@ -113,16 +113,19 @@ def main(hf_path: str, save_path: Path): # Copy weights and save model. print("Copying weights...") convert(model_nt, model_hf) - nanotron.serialize.save_weights(model=model_nt, parallel_context=model_nt.parallel_context, - root_folder=save_path) - with open(save_path/"model_config.json", "w+") as f: + nanotron.serialize.save_weights(model=model_nt, parallel_context=model_nt.parallel_context, root_folder=save_path) + with open(save_path / "model_config.json", "w+") as f: json.dump(dataclasses.asdict(config_nt), f) print(f"Model saved to {save_path}") if __name__ == "__main__": parser = ArgumentParser(description="Convert HF weights to nanotron format") - parser.add_argument("--checkpoint-path", default="facebook/xglm-7.5B", help="Name or path to the huggingface checkpoint") - parser.add_argument("--save-path", type=Path, default="checkpoints/xglm-7.5B", help="Path to save the nanotron model") + parser.add_argument( + "--checkpoint-path", default="facebook/xglm-7.5B", help="Name or path to the huggingface checkpoint" + ) + parser.add_argument( + "--save-path", type=Path, default="checkpoints/xglm-7.5B", help="Path to save the nanotron model" + ) args = parser.parse_args() main(args.checkpoint_path, args.save_path) diff --git a/examples/xglm/convert_nt2hf.py b/examples/xglm/convert_nt2hf.py index 422695a1..81816aa9 100644 --- a/examples/xglm/convert_nt2hf.py +++ b/examples/xglm/convert_nt2hf.py @@ -4,25 +4,28 @@ torchrun --nproc-per-node=1 convert_nt2hf.py --checkpoint-path=nanotron_weights --save-path=hf_weights """ +import warnings from argparse import ArgumentParser -from typing import Optional from pathlib import Path +from typing import Optional import torch +from nanotron.config.models_config import GPT3Config +from nanotron.models.gpt3 import GPT3ForTraining, GPTBlock from transformers import AutoTokenizer from transformers.models.xglm.modeling_xglm import XGLMAttention, XGLMConfig, XGLMDecoderLayer, XGLMForCausalLM -from nanotron.config.models_config import GPT3Config -from nanotron.models.gpt3 import CausalSelfAttention, GPTBlock, MLP, GPT3ForTraining from examples.xglm.convert_utils import convert_generic, create_nt_model def convert_config(config: GPT3Config) -> XGLMConfig: if config.embd_pdrop != config.resid_pdrop: - warnings.warn(f"nanotron.embd_pdrop = {config.embd_pdrop} does not match with " - f"nanotron.resid_pdrop = {config.resid_pdrop}. " - "XGLM implementation needs these two values to be equal " - "for correct conversion.") + warnings.warn( + f"nanotron.embd_pdrop = {config.embd_pdrop} does not match with " + f"nanotron.resid_pdrop = {config.resid_pdrop}. " + "XGLM implementation needs these two values to be equal " + "for correct conversion." 
+ ) if config.layer_norm_epsilon != 1e-5: warnings.warn(f"nanotron.layer_norm_epsilon must be 1e-5, not {config.layer_norm_epsilon}") return XGLMConfig( @@ -70,7 +73,7 @@ def convert_attention(attn_hf: XGLMAttention, attn_nt: XGLMAttention): q_b = torch.cat(qs_b) k_b = torch.cat(ks_b) v_b = torch.cat(vs_b) - + with torch.no_grad(): attn_hf.q_proj.weight.data = q_w.clone() attn_hf.k_proj.weight.data = k_w.clone() @@ -118,9 +121,12 @@ def main(checkpoint_path: Path, save_path: Path, tokenizer_name: Optional[str]): if __name__ == "__main__": parser = ArgumentParser(description="Convert HF weights to nanotron format") - parser.add_argument("--checkpoint-path", type=Path, default="checkpoints/xglm-7.5B", help="Path to the nanotron checkpoint") - parser.add_argument("--save-path", type=Path, default="facebook/xglm-7.5B", help="Path to save the huggingface model") + parser.add_argument( + "--checkpoint-path", type=Path, default="checkpoints/xglm-7.5B", help="Path to the nanotron checkpoint" + ) + parser.add_argument( + "--save-path", type=Path, default="facebook/xglm-7.5B", help="Path to save the huggingface model" + ) parser.add_argument("--tokenizer-name", type=str, default="facebook/xglm-7.5B") args = parser.parse_args() main(args.checkpoint_path, args.save_path, args.tokenizer_name) - diff --git a/examples/xglm/convert_utils.py b/examples/xglm/convert_utils.py index 88a731a1..75d67782 100644 --- a/examples/xglm/convert_utils.py +++ b/examples/xglm/convert_utils.py @@ -2,13 +2,12 @@ from pathlib import Path from typing import Optional -import torch -from torch import nn - import nanotron -from nanotron.models.gpt3 import GPT3ForTraining +import torch from nanotron.config.models_config import GPT3Config +from nanotron.models.gpt3 import GPT3ForTraining from nanotron.trainer import mark_tied_parameters +from torch import nn def convert_generic(module1: nn.Module, module2: nn.Module): @@ -21,11 +20,11 @@ def convert_generic(module1: nn.Module, module2: nn.Module): def create_nt_model( - model_config: Optional[GPT3Config] = None, - device: torch.device = torch.device("cuda"), - dtype: torch.dtype = torch.bfloat16, - checkpoint_path: Optional[Path] = None - ): + model_config: Optional[GPT3Config] = None, + device: torch.device = torch.device("cuda"), + dtype: torch.dtype = torch.bfloat16, + checkpoint_path: Optional[Path] = None, +): if model_config is None: assert checkpoint_path is not None @@ -52,8 +51,6 @@ def create_nt_model( mark_tied_parameters(model=model_nt, parallel_context=parallel_context) if checkpoint_path is not None: - nanotron.serialize.load_weights( - model=model_nt, parallel_context=parallel_context, root_folder=checkpoint_path - ) + nanotron.serialize.load_weights(model=model_nt, parallel_context=parallel_context, root_folder=checkpoint_path) return model_nt diff --git a/examples/xglm/tests/test_implementation.py b/examples/xglm/tests/test_implementation.py index d9dc0f85..a25d7881 100644 --- a/examples/xglm/tests/test_implementation.py +++ b/examples/xglm/tests/test_implementation.py @@ -1,29 +1,31 @@ from typing import Optional +import nanotron import numpy as np -import torch import pytest - -from transformers import XGLMTokenizer -from transformers.models.xglm.modeling_xglm import XGLMConfig, XGLMAttention, XGLMSinusoidalPositionalEmbedding, XGLMDecoderLayer, XGLMForCausalLM - -import nanotron -from nanotron.trainer import mark_tied_parameters +import torch from nanotron.config.models_config import GPT3Config -from nanotron.models.gpt3 import GPT3ForTraining, 
CausalSelfAttention, PositionEmbedding, GPTBlock +from nanotron.models.gpt3 import CausalSelfAttention, GPT3ForTraining, GPTBlock, PositionEmbedding from nanotron.parallel import ParallelContext +from nanotron.trainer import mark_tied_parameters +from transformers import XGLMTokenizer +from transformers.models.xglm.modeling_xglm import ( + XGLMAttention, + XGLMConfig, + XGLMDecoderLayer, + XGLMForCausalLM, + XGLMSinusoidalPositionalEmbedding, +) -from tests.helpers.utils import init_distributed - -from examples.xglm.convert_hf2nt import convert_attention, convert_config, convert_decoder, convert +from examples.xglm.convert_hf2nt import convert, convert_attention, convert_config, convert_decoder +from examples.xglm.convert_nt2hf import convert as convert_nt2hf from examples.xglm.convert_nt2hf import convert_attention as convert_attention_nt2hf from examples.xglm.convert_nt2hf import convert_config as convert_config_nt2hf from examples.xglm.convert_nt2hf import convert_decoder as convert_decoder_nt2hf -from examples.xglm.convert_nt2hf import convert as convert_nt2hf - +from tests.helpers.utils import init_distributed MAX_SEQUENCE_LENGTH = 2048 -TEST_SEQUENCE_LENGTH = 128 # If we test with a very large sequence length, precision errors get more significant independent of the correct implementation. +TEST_SEQUENCE_LENGTH = 128 # If we test with a very large sequence length, precision errors get more significant independent of the correct implementation. BATCH_SIZE = 4 HIDDEN_SIZE = 1024 DTYPE = torch.float64 @@ -45,33 +47,44 @@ vocab_size=256008, sinusoidal_position_embedding=True, position_embedding_offset=2, - use_spda=True + use_spda=True, ) @pytest.fixture def hidden_states() -> torch.Tensor: - return torch.randn(TEST_SEQUENCE_LENGTH, BATCH_SIZE, HIDDEN_SIZE, - dtype=DTYPE) + return torch.randn(TEST_SEQUENCE_LENGTH, BATCH_SIZE, HIDDEN_SIZE, dtype=DTYPE) + @pytest.fixture def input_mask() -> torch.Tensor: return torch.ones(BATCH_SIZE, TEST_SEQUENCE_LENGTH, dtype=torch.bool) + @pytest.fixture def input_ids() -> torch.Tensor: return torch.randint(0, CONFIG.vocab_size, (BATCH_SIZE, TEST_SEQUENCE_LENGTH)) -def almost_close(t1: torch.Tensor, t2: torch.Tensor, atol: float = 1e-5, rtol: float = 0.016, - max_far: float = 0.0, far_atol: float = 0.01): - very_close = torch.abs(t1 - t2) <= atol + rtol*torch.abs(t2) +def almost_close( + t1: torch.Tensor, + t2: torch.Tensor, + atol: float = 1e-5, + rtol: float = 0.016, + max_far: float = 0.0, + far_atol: float = 0.01, +): + very_close = torch.abs(t1 - t2) <= atol + rtol * torch.abs(t2) not_very_close = ~very_close if torch.all(very_close): return - assert torch.mean(not_very_close.float()) <= max_far, f"not very close found: {100*torch.mean(not_very_close.float()):.1f}%" - assert torch.all(torch.abs(t1[not_very_close] - t2[not_very_close]) <= far_atol), f"Worse deviation found: {torch.max(torch.abs(t1 - t2)):.4f}" + assert ( + torch.mean(not_very_close.float()) <= max_far + ), f"not very close found: {100*torch.mean(not_very_close.float()):.1f}%" + assert torch.all( + torch.abs(t1[not_very_close] - t2[not_very_close]) <= far_atol + ), f"Worse deviation found: {torch.max(torch.abs(t1 - t2)):.4f}" def attention_mask() -> torch.Tensor: @@ -81,10 +94,12 @@ def attention_mask() -> torch.Tensor: mask = mask.repeat(BATCH_SIZE, 1, 1).unsqueeze(1) return mask + ## # FROM HERE DOWN (until next comment), all tests are hf->nt ## + def _test_attention(parallel_context: ParallelContext, hidden_states: torch.Tensor, sequence_mask: torch.Tensor): hidden_states = 
hidden_states.cuda() sequence_mask = sequence_mask.cuda() @@ -118,6 +133,7 @@ def _test_position_embeddings(parallel_context: ParallelContext): assert out_nt.size() == out_hf.size(), f"{out_nt.size()}, {out_hf.size()}" torch.testing.assert_close(out_nt, out_hf) + def test_position_embeddings(): init_distributed(tp=1, dp=1, pp=1)(_test_position_embeddings)() @@ -140,15 +156,21 @@ def _test_decoder(parallel_context: ParallelContext, hidden_states: torch.Tensor out_hf = decoder_hf(hidden_states.permute(1, 0, 2), attention_mask=attention_mask())[0].permute(1, 0, 2) assert out_nt.size() == out_hf.size(), f"{out_nt.size()}, {out_hf.size()}" - torch.testing.assert_close(out_nt.bfloat16(), out_hf.bfloat16()) # We cast to bf16 to get more relaxed constraints. + torch.testing.assert_close( + out_nt.bfloat16(), out_hf.bfloat16() + ) # We cast to bf16 to get more relaxed constraints. def test_decoder(hidden_states: torch.Tensor, input_mask: torch.Tensor): init_distributed(tp=1, dp=1, pp=1)(_test_decoder)(hidden_states=hidden_states, sequence_mask=input_mask) -def _test_model(model_hf: Optional[XGLMForCausalLM], parallel_context: ParallelContext, - input_ids: torch.Tensor, input_mask: torch.Tensor): +def _test_model( + model_hf: Optional[XGLMForCausalLM], + parallel_context: ParallelContext, + input_ids: torch.Tensor, + input_mask: torch.Tensor, +): random_states = nanotron.random.RandomStates({"tp_synced": nanotron.random.get_current_random_state()}) input_ids = input_ids.cuda() @@ -182,7 +204,7 @@ def _test_model(model_hf: Optional[XGLMForCausalLM], parallel_context: ParallelC ).eval() convert(model_nt, model_hf) - print("Parameter count (M):", sum(map(torch.numel, model_hf.parameters()))/1000/1000) + print("Parameter count (M):", sum(map(torch.numel, model_hf.parameters())) / 1000 / 1000) # Get outputs and assert. with torch.no_grad(): @@ -209,8 +231,9 @@ def _test_xglm500M(parallel_context: ParallelContext): tok = XGLMTokenizer.from_pretrained("facebook/xglm-564M") tokenized = tok(TEXT) model_hf = XGLMForCausalLM.from_pretrained("facebook/xglm-564M") - out_nt, out_hf = _test_model(model_hf, parallel_context, - torch.tensor([tokenized["input_ids"]]), torch.tensor([tokenized["attention_mask"]])) + out_nt, out_hf = _test_model( + model_hf, parallel_context, torch.tensor([tokenized["input_ids"]]), torch.tensor([tokenized["attention_mask"]]) + ) almost_close(out_nt, out_hf, max_far=0.1, far_atol=0.05) @@ -222,8 +245,9 @@ def _test_xglm7B(parallel_context: ParallelContext): tok = XGLMTokenizer.from_pretrained("facebook/xglm-7.5B") tokenized = tok(TEXT) model_hf = XGLMForCausalLM.from_pretrained("facebook/xglm-7.5B") - out_nt, out_hf = _test_model(model_hf, parallel_context, - torch.tensor([tokenized["input_ids"]]), torch.tensor([tokenized["attention_mask"]])) + out_nt, out_hf = _test_model( + model_hf, parallel_context, torch.tensor([tokenized["input_ids"]]), torch.tensor([tokenized["attention_mask"]]) + ) almost_close(out_nt, out_hf, max_far=0.15, far_atol=0.1) @@ -235,6 +259,7 @@ def test_xglm7B(): # From here down we test nt->hf converters ## + def _test_nt2hf_attention(parallel_context: ParallelContext, hidden_states: torch.Tensor, sequence_mask: torch.Tensor): hidden_states = hidden_states.cuda() sequence_mask = sequence_mask.cuda() @@ -269,7 +294,9 @@ def _test_nt2hf_decoder(parallel_context: ParallelContext, hidden_states: torch. 
out_hf = decoder_hf(hidden_states.permute(1, 0, 2), attention_mask=attention_mask())[0].permute(1, 0, 2) assert out_nt.size() == out_hf.size(), f"{out_nt.size()}, {out_hf.size()}" - torch.testing.assert_close(out_nt.bfloat16(), out_hf.bfloat16()) # We cast to bf16 to get more relaxed constraints. + torch.testing.assert_close( + out_nt.bfloat16(), out_hf.bfloat16() + ) # We cast to bf16 to get more relaxed constraints. def test_nt2hf_decoder(hidden_states: torch.Tensor, input_mask: torch.Tensor): diff --git a/src/nanotron/config/models_config.py b/src/nanotron/config/models_config.py index 5f59a439..af7db5cc 100644 --- a/src/nanotron/config/models_config.py +++ b/src/nanotron/config/models_config.py @@ -1,6 +1,6 @@ from dataclasses import dataclass, field from pathlib import Path -from typing import Any, List, Optional, Union +from typing import List, Optional @dataclass @@ -170,14 +170,12 @@ def as_starcoder2(self) -> Starcoder2Config: if "_is_using_mup" in config: del config["_is_using_mup"] return Starcoder2Config( - grouped_query=True, - num_kv_heads=self.num_attention_heads, - use_rotary_embeddings=False, - **config + grouped_query=True, num_kv_heads=self.num_attention_heads, use_rotary_embeddings=False, **config ) @property def n_inner(self): return self.intermediate_size + NanotronConfigs = LlamaConfig | Starcoder2Config | GPT3Config diff --git a/src/nanotron/models/gpt3.py b/src/nanotron/models/gpt3.py index 7d4e6f82..25e5f78b 100644 --- a/src/nanotron/models/gpt3.py +++ b/src/nanotron/models/gpt3.py @@ -1,37 +1,40 @@ """PyTorch GPT-3 model.""" import math -from typing import Optional from contextlib import contextmanager +from typing import Optional import torch from torch import nn from torch.nn import functional as F from nanotron import distributed as dist -from nanotron.parallel import ParallelContext -from nanotron.config import Config, GPT3Config, ParallelismArgs, Starcoder2Config +from nanotron.config import GPT3Config, ParallelismArgs, Starcoder2Config from nanotron.generation.generate_store import AttachableStore from nanotron.models import starcoder2 -from nanotron.nn.layer_norm import TritonLayerNorm from nanotron.models.starcoder2 import MLP as Starcoder2MLP -from nanotron.parallel.pipeline_parallel.block import PipelineBlock +from nanotron.models.starcoder2 import CausalSelfGQA, GPTModel, Starcoder2ForTraining, dropout_add_fused_train from nanotron.models.starcoder2 import CoreAttention as Starcoder2CoreAttention from nanotron.models.starcoder2 import GPTBlock as Starcoder2GPTBlock -from nanotron.models.starcoder2 import CausalSelfGQA, Starcoder2ForTraining, GPTModel, dropout_add_fused_train -from nanotron.random import RandomStates, branch_random_state +from nanotron.nn.layer_norm import TritonLayerNorm +from nanotron.parallel import ParallelContext +from nanotron.parallel.pipeline_parallel.block import PipelineBlock from nanotron.parallel.pipeline_parallel.tensor_pointer import TensorPointer from nanotron.parallel.tensor_parallel.enum import TensorParallelLinearMode from nanotron.parallel.tensor_parallel.nn import TensorParallelEmbedding -from nanotron.parallel.tied_parameters import tie_parameters +from nanotron.random import RandomStates, branch_random_state @contextmanager def replace_coreattention(gpt3config: GPT3Config): orig = starcoder2.CoreAttention try: - def create_core_attention(config: Starcoder2Config, parallel_config: Optional[ParallelismArgs], layer_idx: int): + + def create_core_attention( + config: Starcoder2Config, parallel_config: 
Optional[ParallelismArgs], layer_idx: int + ): return CoreAttention(gpt3config, parallel_config, layer_idx) + starcoder2.CoreAttention = create_core_attention yield finally: @@ -42,6 +45,7 @@ def create_core_attention(config: Starcoder2Config, parallel_config: Optional[Pa def replace_decoder(gpt3config: GPT3Config): orig = starcoder2.PipelineBlock try: + def create_pp_block(module_builder, module_kwargs, **kwargs): if module_builder is Starcoder2GPTBlock: # Starcoder2's GPT module is trying to instantiate a Starcoder2 GPTBlock. @@ -62,9 +66,15 @@ def create_pp_block(module_builder, module_kwargs, **kwargs): def replace_gpt3model(gpt3config: GPT3Config): orig = starcoder2.GPTModel try: - def create_gptmodel(config: Starcoder2Config, parallel_context: ParallelContext, - parallel_config: Optional[ParallelismArgs], random_states: RandomStates): + + def create_gptmodel( + config: Starcoder2Config, + parallel_context: ParallelContext, + parallel_config: Optional[ParallelismArgs], + random_states: RandomStates, + ): return GPT3Model(gpt3config, parallel_context, parallel_config, random_states) + starcoder2.GPTModel = create_gptmodel yield finally: @@ -76,7 +86,8 @@ def __init__(self, config: GPT3Config, parallel_config: Optional[ParallelismArgs super().__init__(config.as_starcoder2(), parallel_config, layer_idx) self.gpt3config = config - def forward(self, + def forward( + self, query_states: torch.Tensor, # [batch_size * q_length, q_heads, inner_dim] key_states: torch.Tensor, # [batch_size * kv_length, kv_heads, inner_dim] value_states: torch.Tensor, # [batch_size * kv_length, kv_heads, inner_dim] @@ -101,7 +112,7 @@ def forward(self, is_causal=True, ) # [batch, q_length, q_heads, head_dim] attention_output = attention_output.permute(0, 2, 1, 3) - attention_output = attention_output.reshape(batch_size*q_length, q_heads, head_dim) + attention_output = attention_output.reshape(batch_size * q_length, q_heads, head_dim) return attention_output.contiguous() assert query_states.dtype in {torch.bfloat16, torch.float16} @@ -127,7 +138,7 @@ def __init__( config: GPT3Config, parallel_config: Optional[ParallelismArgs], tp_pg: dist.ProcessGroup, - random_states: RandomStates + random_states: RandomStates, ): super().__init__(config.as_starcoder2(), parallel_config, tp_pg) self.dropout = nn.Dropout(p=config.act_pdrop) @@ -154,14 +165,11 @@ def __init__( random_states: RandomStates, layer_idx: int, ): - #print("New gpt block created :D") + # print("New gpt block created :D") super(GPTBlock, self).__init__() self.ln_1 = TritonLayerNorm(config.hidden_size, eps=config.layer_norm_epsilon) self.attn = CausalSelfAttention( - config=config, - parallel_config=parallel_config, - tp_pg=tp_pg, - layer_idx=layer_idx + config=config, parallel_config=parallel_config, tp_pg=tp_pg, layer_idx=layer_idx ) self.attn_dropout = config.attn_pdrop @@ -180,10 +188,10 @@ def forward( residual = hidden_states hidden_states = self.ln_1(hidden_states) - #hidden_states = torch.arange(hidden_states.numel()).to(hidden_states.device).to(hidden_states.dtype).view(hidden_states.size()) + # hidden_states = torch.arange(hidden_states.numel()).to(hidden_states.device).to(hidden_states.dtype).view(hidden_states.size()) output = self.attn(hidden_states=hidden_states, sequence_mask=sequence_mask) hidden_states = output["hidden_states"] - #return {"hidden_states": hidden_states, "sequence_mask": sequence_mask} + # return {"hidden_states": hidden_states, "sequence_mask": sequence_mask} if self.training: with branch_random_state( @@ -221,7 +229,9 @@ 
def __init__(self, tp_pg: dist.ProcessGroup, config: GPT3Config, parallel_config if (config.max_position_embeddings + config.position_embedding_offset) % tp_pg.size() == 0: dummy_pos = 0 else: - dummy_pos = tp_pg.size() - ((config.max_position_embeddings + config.position_embedding_offset) % tp_pg.size()) + dummy_pos = tp_pg.size() - ( + (config.max_position_embeddings + config.position_embedding_offset) % tp_pg.size() + ) true_max_size = config.max_position_embeddings + config.position_embedding_offset + dummy_pos if config.sinusoidal_position_embedding: @@ -234,7 +244,7 @@ def __init__(self, tp_pg: dist.ProcessGroup, config: GPT3Config, parallel_config embedding_dim=config.hidden_size, pg=tp_pg, mode=parallel_config.tp_mode if parallel_config is not None else TensorParallelLinearMode.ALL_REDUCE, - _weight=weight + _weight=weight, ) self.pg = tp_pg @@ -251,32 +261,31 @@ def forward(self, position_ids: torch.Tensor): # [batch_size, seq_length] position_embeds = self.position_embedding(position_ids + self.config.position_embedding_offset) return {"position_embeds": position_embeds} - def _make_weights(self, tp_pg: dist.ProcessGroup, num_embeddings: int, - embedding_dim: int) -> torch.Tensor: + def _make_weights(self, tp_pg: dist.ProcessGroup, num_embeddings: int, embedding_dim: int) -> torch.Tensor: rank = dist.get_rank(group=tp_pg) tp_size = tp_pg.size() assert 0 <= rank < tp_size assert num_embeddings % tp_size == 0 assert embedding_dim % 2 == 0 - block_size = num_embeddings//tp_size + block_size = num_embeddings // tp_size - half_dim = embedding_dim//2 - emb = math.log(10_000)/(half_dim - 1) + half_dim = embedding_dim // 2 + emb = math.log(10_000) / (half_dim - 1) emb = torch.exp(torch.arange(half_dim, dtype=torch.int64).float() * -emb) - emb = (rank*block_size + torch.arange(block_size, dtype=torch.int64).float().unsqueeze(1)) * emb.unsqueeze(0) + emb = (rank * block_size + torch.arange(block_size, dtype=torch.int64).float().unsqueeze(1)) * emb.unsqueeze(0) emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1).view(block_size, embedding_dim) return emb class GPT3Model(GPTModel): def __init__( - self, - config: GPT3Config, - parallel_context: ParallelContext, - parallel_config: Optional[ParallelismArgs], - random_states: RandomStates, - ): + self, + config: GPT3Config, + parallel_context: ParallelContext, + parallel_config: Optional[ParallelismArgs], + random_states: RandomStates, + ): with replace_decoder(config): super().__init__(config.as_starcoder2(), parallel_context, parallel_config, random_states) @@ -300,7 +309,9 @@ def forward( ): # all tensors are optional as most ranks don't need anything from the dataloader. - input_embeds = self.token_embeddings(input_ids=input_ids, input_mask=input_mask)["input_embeds"]*self.embed_scale + input_embeds = ( + self.token_embeddings(input_ids=input_ids, input_mask=input_mask)["input_embeds"] * self.embed_scale + ) # TODO: position_ids could be cached. 
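# The arange/repeat/view below builds a (batch_size, seq_length) grid of positions
# 0..seq_length-1; an equivalent (illustrative) form is
#   torch.arange(input_ids.size(1), device=input_ids.device).unsqueeze(0).expand(input_ids.size(0), -1)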
position_ids = torch.arange(input_ids.size(1), device="cuda").repeat(input_ids.size(0)).view(*input_ids.size()) position_embeds = self.position_embeddings(position_ids=position_ids)["position_embeds"] @@ -314,7 +325,7 @@ def forward( hidden_encoder_states = {"hidden_states": hidden_states, "sequence_mask": input_mask} for encoder_block in self.decoder: hidden_encoder_states = encoder_block(**hidden_encoder_states) - #return hidden_encoder_states["hidden_states"] + # return hidden_encoder_states["hidden_states"] hidden_states = self.final_layer_norm(input=hidden_encoder_states["hidden_states"])["hidden_states"] diff --git a/src/nanotron/trainer.py b/src/nanotron/trainer.py index f01caa3e..bc81e326 100644 --- a/src/nanotron/trainer.py +++ b/src/nanotron/trainer.py @@ -56,9 +56,9 @@ ) from nanotron.models import NanotronModel, build_model from nanotron.models.base import check_model_has_grad +from nanotron.models.gpt3 import GPT3ForTraining from nanotron.models.llama import LlamaForTraining, RotaryEmbedding from nanotron.models.starcoder2 import Starcoder2ForTraining -from nanotron.models.gpt3 import GPT3ForTraining from nanotron.optim.clip_grads import clip_grad_norm from nanotron.parallel import ParallelContext from nanotron.parallel.data_parallel.utils import sync_gradients_across_dp From c1fabacfb95a54424ed551945998f1e777d4afb0 Mon Sep 17 00:00:00 2001 From: tj-solergibert Date: Tue, 16 Jul 2024 11:45:28 +0000 Subject: [PATCH 31/51] Added MultilingualNanoset Config --- src/nanotron/config/config.py | 21 +++++++++++++++++++++ 1 file changed, 21 insertions(+) diff --git a/src/nanotron/config/config.py b/src/nanotron/config/config.py index 05b49955..bfd20227 100644 --- a/src/nanotron/config/config.py +++ b/src/nanotron/config/config.py @@ -107,6 +107,27 @@ def __post_init__(self): self.dataset_weights = list(tmp_dataset_folder.values()) +@dataclass +class MultilingualNanosetDatasetsArgs: + dataset_folder: Union[str, dict, List[str]] + dataset_tokens: List[ + int + ] # Set token for each language previously defined. 
We use a List and not a dict because this way we support specifyng weights (dict) or not (List[str]) + + def __post_init__(self): + if isinstance(self.dataset_folder, str): # Case 1: 1 Dataset file + self.dataset_folder = [self.dataset_folder] + self.dataset_weights = [1] + elif isinstance(self.dataset_folder, List): # Case 2: > 1 Dataset file + self.dataset_weights = None # Set to None so we consume all the samples randomly + elif isinstance(self.dataset_folder, dict): # Case 3: dict with > 1 dataset_folder and weights + tmp_dataset_folder = self.dataset_folder.copy() + self.dataset_folder = list(tmp_dataset_folder.keys()) + self.dataset_weights = list(tmp_dataset_folder.values()) + + assert len(self.dataset_folder) == len(self.dataset_tokens) + + @dataclass class DataArgs: """Arguments related to the data and data files processing""" From 086b50dfc720988eaa80089bb8ffc9385a0d34be Mon Sep 17 00:00:00 2001 From: tj-solergibert Date: Tue, 16 Jul 2024 11:48:51 +0000 Subject: [PATCH 32/51] Added MultilingualNanoset --- run_train.py | 125 +++++++++++- src/nanotron/data/multilingual_nanoset.py | 221 ++++++++++++++++++++++ 2 files changed, 343 insertions(+), 3 deletions(-) create mode 100644 src/nanotron/data/multilingual_nanoset.py diff --git a/run_train.py b/run_train.py index 021d955d..649784ca 100644 --- a/run_train.py +++ b/run_train.py @@ -12,7 +12,13 @@ import numpy as np from nanotron import logging -from nanotron.config import DataArgs, DatasetStageArgs, NanosetDatasetsArgs, PretrainDatasetsArgs +from nanotron.config import ( + DataArgs, + DatasetStageArgs, + MultilingualNanosetDatasetsArgs, + NanosetDatasetsArgs, + PretrainDatasetsArgs, +) from nanotron.data.dataloader_builder import build_nanoset_dataloader from nanotron.dataloader import ( clm_process, @@ -171,6 +177,40 @@ def get_dataloader_from_data_stage( dataloader_drop_last=True, ) + return train_dataloader + # Case 4: MultilingualNanosets + elif isinstance(data.dataset, MultilingualNanosetDatasetsArgs): + # Get tokenizer cardinality + tokenizer = AutoTokenizer.from_pretrained(trainer.config.tokenizer.tokenizer_name_or_path) + token_size = 4 if len(tokenizer) > np.iinfo(np.uint16).max + 1 else 2 + del tokenizer + # Create Nanoset + from nanotron.data.multilingual_nanoset import MultilingualNanoset + + with main_rank_first(trainer.parallel_context.world_pg): + train_dataset = MultilingualNanoset( + dataset_folders=data.dataset.dataset_folder, + dataset_weights=data.dataset.dataset_weights, + sequence_length=trainer.sequence_length, + token_size=token_size, + train_split_num_samples=trainer.config.tokens.train_steps * trainer.global_batch_size, + valid_split_num_samples=trainer.config.tokens.limit_val_batches * trainer.global_batch_size, + random_seed=data.seed, + ) + + # Prepare dataloader + train_dataloader = build_nanoset_dataloader( + train_dataset, + trainer.sequence_length, + parallel_context=trainer.parallel_context, + input_pp_rank=input_pp_rank, + output_pp_rank=output_pp_rank, + micro_batch_size=trainer.micro_batch_size, + consumed_train_samples=consumed_train_samples, + dataloader_num_workers=data.num_loading_workers, + dataloader_drop_last=True, + ) + return train_dataloader else: raise ValueError(f"Unhandled case of `self.config.data.dataset`. 
Got: {data.dataset}") @@ -178,6 +218,57 @@ def get_dataloader_from_data_stage( return dataloader +def get_valid_dataloader_from_data_stage( + trainer: DistributedTrainer, + data: DataArgs, + valid_split_num_samples: int, + # consumed_train_samples: int, We will never use this because in each valid iteration we consume all the samples +): + + # First, we need to know which ranks to feed the dataloader to + input_pp_rank, output_pp_rank = get_input_output_pp_ranks(model=trainer.model) + + # Only support Validation with MultilingualNanosets + if isinstance(data.dataset, NanosetDatasetsArgs): + # Get tokenizer cardinality + tokenizer = AutoTokenizer.from_pretrained(trainer.config.tokenizer.tokenizer_name_or_path) + token_size = 4 if len(tokenizer) > np.iinfo(np.uint16).max + 1 else 2 + del tokenizer + # Create Multilingual Nanoset + from nanotron.data.multilingual_nanoset import MultilingualNanoset + + with main_rank_first(trainer.parallel_context.world_pg): + valid_dataset = MultilingualNanoset( + dataset_folders=data.dataset.dataset_folder, + dataset_weights=data.dataset.dataset_weights, + sequence_length=trainer.sequence_length, + token_size=token_size, + train_split_num_samples=trainer.config.tokens.train_steps * trainer.global_batch_size, + valid_split_num_samples=valid_split_num_samples, + is_valid=True, + random_seed=data.seed, + ) + + # Prepare dataloader + valid_dataloader = build_nanoset_dataloader( + valid_dataset, + trainer.sequence_length, + parallel_context=trainer.parallel_context, + input_pp_rank=input_pp_rank, + output_pp_rank=output_pp_rank, + micro_batch_size=trainer.micro_batch_size, + consumed_train_samples=0, + dataloader_num_workers=data.num_loading_workers, + dataloader_drop_last=True, + ) + + return valid_dataloader + else: + raise ValueError( + f"Unhandled case of `self.config.data.dataset`. Got: {data.dataset}. 
Validation is currently just supported for MultilingualNanoset" + ) + + def get_dataloader(trainer: DistributedTrainer) -> Dict[str, DataLoader]: dataloaders = {} @@ -219,6 +310,33 @@ def get_dataloader(trainer: DistributedTrainer) -> Dict[str, DataLoader]: return dataloaders +def get_valid_dataloader(trainer: DistributedTrainer) -> Dict[str, DataLoader]: + dataloaders = {} + + for stage_idx, stage in enumerate(trainer.config.data_stages): + # NOTE: we only create the dataloader for the first stage, + # then we lazy initialize the dataloader for the other stages + stage = cast(DatasetStageArgs, stage) + valid_split_num_samples = trainer.config.tokens.limit_val_batches * trainer.global_batch_size + + log_rank( + f"[Training Plan] Stage {stage.name} has {valid_split_num_samples} samples in the validation set", + logger=logger, + level=logging.INFO, + rank=0, + ) + + dataloader = ( + get_valid_dataloader_from_data_stage(trainer, stage.data, valid_split_num_samples=valid_split_num_samples) + if stage_idx == 0 + else lambda stage=stage: get_dataloader_from_data_stage( + trainer, stage.data, valid_split_num_samples=valid_split_num_samples + ) + ) + dataloaders[stage.name] = dataloader + return dataloaders + + def get_args(): parser = argparse.ArgumentParser() parser.add_argument("--config-file", type=str, required=True, help="Path to the YAML or python config file") @@ -231,7 +349,8 @@ def get_args(): # Load trainer and data trainer = DistributedTrainer(config_file) - dataloader = get_dataloader(trainer) + train_dataloader = get_dataloader(trainer) + valid_dataloader = get_valid_dataloader(trainer) # Train - trainer.train(dataloader) + trainer.train(train_dataloader, valid_dataloader) diff --git a/src/nanotron/data/multilingual_nanoset.py b/src/nanotron/data/multilingual_nanoset.py new file mode 100644 index 00000000..40e06b87 --- /dev/null +++ b/src/nanotron/data/multilingual_nanoset.py @@ -0,0 +1,221 @@ +import os +import warnings +from typing import Dict, List, Tuple, Union + +import numpy as np +import torch +from datatrove.utils.dataset import DatatroveFolderDataset +from nanotron import logging +from nanotron.data.utils import count_dataset_indexes, normalize +from nanotron.logging import log_rank +from numba import jit + +logger = logging.get_logger(__name__) + + +class MultilingualNanoset(torch.utils.data.Dataset): + """ + The Nanoset dataset + + Args: + dataset_folders (List[str]): List of folders with tokenized datasets + dataset_weights (Union[List[float], None]): List with the weights for weighted datasets. If None, consume all samples from all datasets without weighting. Weights are normalized in __init__ + sequence_length (int): Sequence length of the built samples + token_size (int): Number of bytes for the tokens stored in the processed dataset files. 2 for vocab sizes < 65535, 4 otherwise + train_split_num_samples (int): Number of samples the dataset needs. It's the training steps * global batch size + """ + + def __init__( + self, + dataset_folders: List[str], + sequence_length: int, + token_size: int, + train_split_num_samples: int, + valid_split_num_samples: int, + is_valid: bool = False, + dataset_weights: Union[List[float], None] = None, + random_seed: int = 1234, + ) -> None: + + # Checks + if isinstance(dataset_folders, str): + warnings.warn("dataset_folders should be of type List[str] but str was provided. 
Converting to List[str]") + dataset_folders = [dataset_folders] + + # Init + self.dataset_folders = dataset_folders + self.sequence_length = sequence_length + self.token_size = token_size + self.train_split_num_samples = train_split_num_samples + self.valid_split_num_samples = valid_split_num_samples + self.is_valid = is_valid + self.random_seed = random_seed + self.datatrove_datasets = [] + for dataset_folder in self.dataset_folders: + self.datatrove_datasets.append( + DatatroveFolderDataset( + folder_path=dataset_folder, + filename_pattern=os.path.join(dataset_folder, "*.ds"), + seq_len=sequence_length, + recursive=False, + token_size=token_size, + shuffle=True, + ) + ) + + # Build Nanoset Index + ## To build the index we need the length of each dataset + self.dataset_lengths = [len(datatrove_dataset) for datatrove_dataset in self.datatrove_datasets] + ## Set dataset weights + if ( + dataset_weights is None + ): # Case of training with > 1 datasets without weighting them: Consume both datasets entirely on each epoch + self.dataset_weights = normalize(self.dataset_lengths) + else: + self.dataset_weights = normalize(dataset_weights) + assert len(dataset_folders) == len( + self.dataset_weights + ), f"Specified {len(self.dataset_weights)} weights but {len(dataset_folders)} datasets were provided." + ## Build dataset index and dataset sample index + ### Split dataset_lengths into train_dataset_lenghts & valid_dataset_lenghts + self.valid_dataset_lenghts = self.dataset_weights * valid_split_num_samples + # Assert that we have sufficient samples to build the valid split + for ds_index in range(len(self.dataset_lengths)): + assert ( + self.valid_dataset_lenghts[ds_index] > self.dataset_lengths[ds_index] + ), f"Trying to build validation dataset with {self.valid_dataset_lenghts[ds_index]} samples but {dataset_folders[ds_index]} just have {self.dataset_lengths[ds_index]} samples." 
+ self.train_dataset_lenghts = [ + a - b for a, b in zip(self.dataset_lengths, self.valid_dataset_lenghts) + ] # Subtract the valid samples from the training dataset + + if is_valid: # Valid MultilingualNanoset + self.split_num_samples = valid_split_num_samples + self.split_samples_per_epoch = valid_split_num_samples + self.num_epochs = 1 + self.split_dataset_lenghts = self.valid_dataset_lenghts + self.split_dataset_offsets = self.train_dataset_lenghts + + else: # Train MultilingualNanoset + self.split_num_samples = train_split_num_samples + self.split_samples_per_epoch = sum(self.train_dataset_lenghts) + self.num_epochs = int(self.split_num_samples / self.split_samples_per_epoch) + 1 + self.split_dataset_lenghts = self.train_dataset_lenghts + self.split_dataset_offsets = [ + 0 for _ in range(len(self.dataset_lengths)) + ] # For training there is NO offset + + self.dataset_index, self.dataset_sample_index = self.build_nanoset_index() + + self.print_nanoset_info() + + def __len__(self) -> int: + """ + Returns: + int: The number of samples of the Nanoset + """ + + return len(self.dataset_index) + + def __getitem__(self, idx: int) -> Dict[str, np.ndarray]: + """ + Returns sequence_length + 1 tokens from the memmap dataset + + Args: + idx (int): The index into the dataset + + Returns: + Dict[str, torch.LongTensor]: The input ids wrapped in a dictionary + """ + dataset = self.dataset_index[idx] + dataset_sample = self.dataset_sample_index[idx] + + return self.datatrove_datasets[dataset][dataset_sample] + + def build_nanoset_index(self) -> np.ndarray: + """ + Build dataset index and dataset sample index + """ + # Build the dataset indexes for 1 epoch + dataset_index, dataset_sample_index = build_nanoset_index_helper( + n_samples=self.split_samples_per_epoch, + weights=self.dataset_weights, + dataset_sizes=self.split_dataset_lengths, + offsets=self.split_dataset_offsets, + ) + # Shuffle the indexes the same way + numpy_random_state = np.random.RandomState(self.random_seed) + numpy_random_state.shuffle(dataset_index) + numpy_random_state = np.random.RandomState(self.random_seed) + numpy_random_state.shuffle(dataset_sample_index) + # Concatenate num_epochs the shuffled indexes + dataset_index = np.concatenate([dataset_index for _ in range(self.num_epochs)]) + dataset_sample_index = np.concatenate([dataset_sample_index for _ in range(self.num_epochs)]) + # Just keep the necessary samples + dataset_index = dataset_index[: self.split_num_samples] + dataset_sample_index = dataset_sample_index[: self.split_num_samples] + + return dataset_index, dataset_sample_index + + def print_nanoset_info(self): + + log_rank( + f"> [{'Validation' if self.is_valid else 'Training'} dataset] Total number of samples: {len(self)}", + logger=logger, + level=logging.INFO, + rank=0, + ) + log_rank( + f"> [{'Validation' if self.is_valid else 'Training'} dataset] Total number of tokens: {len(self) * self.sequence_length}", + logger=logger, + level=logging.INFO, + rank=0, + ) + + # Print samples from each dataset + weight + dataset_sample_count = count_dataset_indexes(self.dataset_index, len(self.dataset_folders)) + for index, sample_count in enumerate(dataset_sample_count): + log_rank( + f"> Total number of {'validation' if self.is_valid else 'training'} samples from the {self.dataset_folders[index]} dataset: {sample_count} ({round(normalize(dataset_sample_count).tolist()[index], 2)})", + logger=logger, + level=logging.INFO, + rank=0, + ) + + +@jit(nopython=True, cache=True) +def build_nanoset_index_helper( + n_samples: int, 
weights: np.ndarray, dataset_sizes: List[int], offsets: List[int] +) -> Tuple[np.ndarray, np.ndarray]: + """ + Given multiple datasets and a weighting array, build samples indexes + such that it follows those weights. + For train and valid splits we split each dataset_folder in train (first part) and valid splits. We set the offsets to the train lengths + for generating the valid split + """ + # Create empty arrays for dataset indices and dataset sample indices + dataset_index = np.empty((n_samples,), dtype="uint") + dataset_sample_index = np.empty((n_samples,), dtype="long") # Supports dataset with up to 2**64 samples + + # Initialize buffer for number of samples used for each dataset + current_samples = np.zeros((len(weights),), dtype="long") + + # Iterate over all samples + for sample_idx in range(n_samples): + + # Convert sample index to float for comparison against weights + sample_idx_float = max(sample_idx, 1.0) + + # Find the dataset with the highest error + errors = weights * sample_idx_float - current_samples + max_error_index = np.argmax(errors) + + # Assign the dataset index and update the sample index + dataset_index[sample_idx] = max_error_index + dataset_sample_index[sample_idx] = ( + current_samples[max_error_index] % dataset_sizes[max_error_index] + ) + offsets[max_error_index] + + # Update the total samples for the selected dataset + current_samples[max_error_index] += 1 + + return dataset_index, dataset_sample_index From 1fe74457503da8fff820ccde7a6c4172d6026e56 Mon Sep 17 00:00:00 2001 From: tj-solergibert Date: Tue, 16 Jul 2024 12:25:17 +0000 Subject: [PATCH 33/51] Added Language token --- examples/config_multilingual_nanoset.yaml | 120 ++++++++++++++++++++++ src/nanotron/data/multilingual_nanoset.py | 7 +- 2 files changed, 126 insertions(+), 1 deletion(-) create mode 100644 examples/config_multilingual_nanoset.yaml diff --git a/examples/config_multilingual_nanoset.yaml b/examples/config_multilingual_nanoset.yaml new file mode 100644 index 00000000..00ae6570 --- /dev/null +++ b/examples/config_multilingual_nanoset.yaml @@ -0,0 +1,120 @@ +checkpoints: + checkpoint_interval: 1000 + checkpoints_path: checkpoints/ + checkpoints_path_is_shared_file_system: false + resume_checkpoint_path: null + save_initial_state: false +data_stages: +- data: + dataset: + dataset_folder: datasets/c4-es/tokenized + dataset_tokens: + - 15 + num_loading_workers: 1 + seed: 42 + name: General purpose training (Single dataset) + start_training_step: 1 +- data: + dataset: + dataset_folder: + - datasets/SlimPajama-6B/tokenized + - datasets/c4-es/tokenized + dataset_tokens: + - 16 + - 15 + num_loading_workers: 1 + seed: 42 + name: Second purpose training (> 1 dataset) + start_training_step: 15 +- data: + dataset: + dataset_folder: + datasets/SlimPajama-6B/tokenized: 0.8 + datasets/c4-es/tokenized: 0.2 + dataset_tokens: + - 16 + - 15 + num_loading_workers: 1 + seed: 42 + name: Third purpose training (Blended dataset) + start_training_step: 25 +general: + benchmark_csv_path: null + consumed_train_samples: null + ignore_sanity_checks: true + project: Nanoset + run: llama + seed: 42 + step: null +lighteval: null +logging: + iteration_step_info_interval: 1 + log_level: info + log_level_replica: info +model: + ddp_bucket_cap_mb: 25 + dtype: bfloat16 + init_method: + std: 0.025 + make_vocab_size_divisible_by: 1 + model_config: + bos_token_id: 1 + eos_token_id: 2 + hidden_act: silu + hidden_size: 4096 + initializer_range: 0.02 + intermediate_size: 11008 + is_llama_config: true + max_position_embeddings: 4096 
+ num_hidden_layers: 32 + num_attention_heads: 32 + num_key_value_heads: 8 + pad_token_id: null + pretraining_tp: 1 + rope_interleaved: false + rope_theta: 500000.0 + rms_norm_eps: 1.0e-06 + rope_scaling: null + tie_word_embeddings: true + use_cache: true + vocab_size: 128256 +optimizer: + accumulate_grad_in_fp32: true + clip_grad: 1.0 + learning_rate_scheduler: + learning_rate: 0.0003 + lr_decay_starting_step: null + lr_decay_steps: 98 + lr_decay_style: cosine + lr_warmup_steps: 2 + lr_warmup_style: linear + min_decay_lr: 1.0e-05 + optimizer_factory: + adam_beta1: 0.9 + adam_beta2: 0.95 + adam_eps: 1.0e-08 + name: adamW + torch_adam_is_fused: true + weight_decay: 0.01 + zero_stage: 0 +parallelism: + dp: 1 + expert_parallel_size: 1 + pp: 1 + pp_engine: 1f1b + tp: 1 + tp_linear_async_communication: false + tp_mode: REDUCE_SCATTER +profiler: null +tokenizer: + tokenizer_max_length: null + tokenizer_name_or_path: meta-llama/Meta-Llama-3-8B + tokenizer_revision: null +tokens: + batch_accumulation_per_replica: 1 + limit_test_batches: 0 + limit_val_batches: 10 + micro_batch_size: 2 + sequence_length: 1024 + train_steps: 200 + val_check_interval: -1 diff --git a/src/nanotron/data/multilingual_nanoset.py b/src/nanotron/data/multilingual_nanoset.py index 40e06b87..6526659d 100644 --- a/src/nanotron/data/multilingual_nanoset.py +++ b/src/nanotron/data/multilingual_nanoset.py @@ -32,6 +32,7 @@ def __init__( token_size: int, train_split_num_samples: int, valid_split_num_samples: int, + dataset_tokens: List[int], is_valid: bool = False, dataset_weights: Union[List[float], None] = None, random_seed: int = 1234, @@ -48,6 +49,7 @@ def __init__( self.token_size = token_size self.train_split_num_samples = train_split_num_samples self.valid_split_num_samples = valid_split_num_samples + self.dataset_tokens = dataset_tokens self.is_valid = is_valid self.random_seed = random_seed self.datatrove_datasets = [] @@ -129,7 +131,10 @@ def __getitem__(self, idx: int) -> Dict[str, np.ndarray]: dataset = self.dataset_index[idx] dataset_sample = self.dataset_sample_index[idx] - return self.datatrove_datasets[dataset][dataset_sample] + tokens = self.datatrove_datasets[dataset][dataset_sample] + tokens[0] = self.dataset_tokens[dataset] # Prepend language token + + return tokens def build_nanoset_index(self) -> np.ndarray: """ From fb6631a5bc30bbc7a3bc087df95e279f25c80153 Mon Sep 17 00:00:00 2001 From: tj-solergibert Date: Tue, 16 Jul 2024 12:51:42 +0000 Subject: [PATCH 34/51] Forgot the trainer ups --- src/nanotron/trainer.py | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/src/nanotron/trainer.py b/src/nanotron/trainer.py index bc81e326..3f4c5189 100644 --- a/src/nanotron/trainer.py +++ b/src/nanotron/trainer.py @@ -393,7 +393,10 @@ def find_stage_idx_to_resume(): def train( self, - dataloader_or_dls: Dict[ + train_dataloader_or_dls: Dict[ + str, Union[Iterator[Dict[str, Union[torch.Tensor, TensorPointer]]], Tuple[Iterator, ...]] + ], + valid_dataloader_or_dls: Dict[ str, Union[Iterator[Dict[str, Union[torch.Tensor, TensorPointer]]], Tuple[Iterator, ...]] ], **kwargs, @@ -424,7 +427,7 @@ def train( prof.step() self.iteration_start_time = time.time() - self._update_dataloader_based_on_training_stages(dataloader_or_dls) + self._update_dataloader_based_on_training_stages(train_dataloader_or_dls) # Training step outputs, loss_avg = self.training_step(dataloader=self.current_dataloader) From a6eb1bd550ae1596b98996d3cf6553071d98f18b Mon Sep 17 00:00:00 2001 From: tj-solergibert Date: Tue, 16 Jul 2024 
14:12:57 +0000 Subject: [PATCH 35/51] Fix minor errors. Everything works --- run_train.py | 6 ++++-- src/nanotron/config/config.py | 2 +- src/nanotron/data/multilingual_nanoset.py | 11 +++++++---- 3 files changed, 12 insertions(+), 7 deletions(-) diff --git a/run_train.py b/run_train.py index 649784ca..9b77da77 100644 --- a/run_train.py +++ b/run_train.py @@ -195,6 +195,7 @@ def get_dataloader_from_data_stage( token_size=token_size, train_split_num_samples=trainer.config.tokens.train_steps * trainer.global_batch_size, valid_split_num_samples=trainer.config.tokens.limit_val_batches * trainer.global_batch_size, + dataset_tokens=data.dataset.dataset_tokens, random_seed=data.seed, ) @@ -229,7 +230,7 @@ def get_valid_dataloader_from_data_stage( input_pp_rank, output_pp_rank = get_input_output_pp_ranks(model=trainer.model) # Only support Validation with MultilingualNanosets - if isinstance(data.dataset, NanosetDatasetsArgs): + if isinstance(data.dataset, MultilingualNanosetDatasetsArgs): # Get tokenizer cardinality tokenizer = AutoTokenizer.from_pretrained(trainer.config.tokenizer.tokenizer_name_or_path) token_size = 4 if len(tokenizer) > np.iinfo(np.uint16).max + 1 else 2 @@ -245,6 +246,7 @@ def get_valid_dataloader_from_data_stage( token_size=token_size, train_split_num_samples=trainer.config.tokens.train_steps * trainer.global_batch_size, valid_split_num_samples=valid_split_num_samples, + dataset_tokens=data.dataset.dataset_tokens, is_valid=True, random_seed=data.seed, ) @@ -320,7 +322,7 @@ def get_valid_dataloader(trainer: DistributedTrainer) -> Dict[str, DataLoader]: valid_split_num_samples = trainer.config.tokens.limit_val_batches * trainer.global_batch_size log_rank( - f"[Training Plan] Stage {stage.name} has {valid_split_num_samples} samples in the validation set", + f"[Validation Plan] Stage {stage.name} has {valid_split_num_samples} samples in the validation set", logger=logger, level=logging.INFO, rank=0, diff --git a/src/nanotron/config/config.py b/src/nanotron/config/config.py index bfd20227..924a2cdf 100644 --- a/src/nanotron/config/config.py +++ b/src/nanotron/config/config.py @@ -132,7 +132,7 @@ def __post_init__(self): class DataArgs: """Arguments related to the data and data files processing""" - dataset: Union[PretrainDatasetsArgs, NanosetDatasetsArgs] + dataset: Union[PretrainDatasetsArgs, NanosetDatasetsArgs, MultilingualNanosetDatasetsArgs] seed: Optional[int] num_loading_workers: Optional[int] = 1 diff --git a/src/nanotron/data/multilingual_nanoset.py b/src/nanotron/data/multilingual_nanoset.py index 6526659d..cd8be195 100644 --- a/src/nanotron/data/multilingual_nanoset.py +++ b/src/nanotron/data/multilingual_nanoset.py @@ -1,5 +1,6 @@ import os import warnings +from math import ceil from typing import Dict, List, Tuple, Union import numpy as np @@ -80,11 +81,13 @@ def __init__( ), f"Specified {len(self.dataset_weights)} weights but {len(dataset_folders)} datasets were provided." 
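The hunk below swaps the numpy product for a per-dataset `ceil`, so every validation shard receives at least its weighted share of the requested samples. A rough sketch with made-up numbers (`valid_lengths` is an illustrative name, not the patch's identifier):

    from math import ceil

    valid_split_num_samples = 24           # e.g. limit_val_batches * global_batch_size
    weights = [0.6, 0.3, 0.1]              # normalized dataset weights (illustrative)
    valid_lengths = [ceil(w * valid_split_num_samples) for w in weights]
    # -> [15, 8, 3]; rounding up guarantees the requested total is covered.
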
## Build dataset index and dataset sample index ### Split dataset_lengths into train_dataset_lenghts & valid_dataset_lenghts - self.valid_dataset_lenghts = self.dataset_weights * valid_split_num_samples + self.valid_dataset_lenghts = [ + ceil(weight * valid_split_num_samples) for weight in self.dataset_weights + ] # Better not tu use numpy so we don't get overflow issues # Assert that we have sufficient samples to build the valid split for ds_index in range(len(self.dataset_lengths)): assert ( - self.valid_dataset_lenghts[ds_index] > self.dataset_lengths[ds_index] + self.dataset_lengths[ds_index] > self.valid_dataset_lenghts[ds_index] ), f"Trying to build validation dataset with {self.valid_dataset_lenghts[ds_index]} samples but {dataset_folders[ds_index]} just have {self.dataset_lengths[ds_index]} samples." self.train_dataset_lenghts = [ a - b for a, b in zip(self.dataset_lengths, self.valid_dataset_lenghts) @@ -132,7 +135,7 @@ def __getitem__(self, idx: int) -> Dict[str, np.ndarray]: dataset_sample = self.dataset_sample_index[idx] tokens = self.datatrove_datasets[dataset][dataset_sample] - tokens[0] = self.dataset_tokens[dataset] # Prepend language token + tokens["input_ids"][0] = self.dataset_tokens[dataset] # Prepend language token return tokens @@ -144,7 +147,7 @@ def build_nanoset_index(self) -> np.ndarray: dataset_index, dataset_sample_index = build_nanoset_index_helper( n_samples=self.split_samples_per_epoch, weights=self.dataset_weights, - dataset_sizes=self.split_dataset_lengths, + dataset_sizes=self.split_dataset_lenghts, offsets=self.split_dataset_offsets, ) # Shuffle the indexes the same way From ef3fac4627cb04d1611be03b4050add5975d2b80 Mon Sep 17 00:00:00 2001 From: tj-solergibert Date: Tue, 16 Jul 2024 14:13:33 +0000 Subject: [PATCH 36/51] Updated config file with GPT2 tokenized datasets in RCP --- examples/config_multilingual_nanoset.yaml | 14 +++++++------- 1 file changed, 7 insertions(+), 7 deletions(-) diff --git a/examples/config_multilingual_nanoset.yaml b/examples/config_multilingual_nanoset.yaml index 00ae6570..3c4476a0 100644 --- a/examples/config_multilingual_nanoset.yaml +++ b/examples/config_multilingual_nanoset.yaml @@ -7,7 +7,7 @@ checkpoints: data_stages: - data: dataset: - dataset_folder: datasets/c4-es/tokenized + dataset_folder: /mloscratch/homes/solergib/nanotrove/nanotron/datasets/c4-es/tokenized dataset_tokens: - 15 num_loading_workers: 1 @@ -17,8 +17,8 @@ data_stages: - data: dataset: dataset_folder: - - datasets/SlimPajama-6B/tokenized - - datasets/c4-es/tokenized + - /mloscratch/homes/solergib/nanotrove/nanotron/datasets/SlimPajama-6B/tokenized + - /mloscratch/homes/solergib/nanotrove/nanotron/datasets/c4-es/tokenized dataset_tokens: - 16 - 15 @@ -29,8 +29,8 @@ data_stages: - data: dataset: dataset_folder: - datasets/SlimPajama-6B/tokenized: 0.8 - datasets/c4-es/tokenized: 0.2 + /mloscratch/homes/solergib/nanotrove/nanotron/datasets/SlimPajama-6B/tokenized: 0.8 + /mloscratch/homes/solergib/nanotrove/nanotron/datasets/c4-es/tokenized: 0.2 dataset_tokens: - 16 - 15 @@ -65,7 +65,7 @@ model: initializer_range: 0.02 intermediate_size: 11008 is_llama_config: true - max_position_embeddings: 4096 + max_position_embeddings: 1024 num_hidden_layers: 32 num_attention_heads: 32 num_key_value_heads: 8 @@ -108,7 +108,7 @@ parallelism: profiler: null tokenizer: tokenizer_max_length: null - tokenizer_name_or_path: meta-llama/Meta-Llama-3-8B + tokenizer_name_or_path: gpt2 tokenizer_revision: null tokens: batch_accumulation_per_replica: 1 From 
49294f135cab0124608aae5d41f91e29ac427512 Mon Sep 17 00:00:00 2001 From: tj-solergibert Date: Wed, 17 Jul 2024 10:13:57 +0000 Subject: [PATCH 37/51] Before lunch --- run_train.py | 13 +--- src/nanotron/config/config.py | 6 +- src/nanotron/data/multilingual_nanoset.py | 76 +++++++++-------------- 3 files changed, 37 insertions(+), 58 deletions(-) diff --git a/run_train.py b/run_train.py index 9b77da77..57e0ec25 100644 --- a/run_train.py +++ b/run_train.py @@ -194,7 +194,6 @@ def get_dataloader_from_data_stage( sequence_length=trainer.sequence_length, token_size=token_size, train_split_num_samples=trainer.config.tokens.train_steps * trainer.global_batch_size, - valid_split_num_samples=trainer.config.tokens.limit_val_batches * trainer.global_batch_size, dataset_tokens=data.dataset.dataset_tokens, random_seed=data.seed, ) @@ -222,7 +221,6 @@ def get_dataloader_from_data_stage( def get_valid_dataloader_from_data_stage( trainer: DistributedTrainer, data: DataArgs, - valid_split_num_samples: int, # consumed_train_samples: int, We will never use this because in each valid iteration we consume all the samples ): @@ -245,7 +243,6 @@ def get_valid_dataloader_from_data_stage( sequence_length=trainer.sequence_length, token_size=token_size, train_split_num_samples=trainer.config.tokens.train_steps * trainer.global_batch_size, - valid_split_num_samples=valid_split_num_samples, dataset_tokens=data.dataset.dataset_tokens, is_valid=True, random_seed=data.seed, @@ -259,7 +256,6 @@ def get_valid_dataloader_from_data_stage( input_pp_rank=input_pp_rank, output_pp_rank=output_pp_rank, micro_batch_size=trainer.micro_batch_size, - consumed_train_samples=0, dataloader_num_workers=data.num_loading_workers, dataloader_drop_last=True, ) @@ -319,21 +315,18 @@ def get_valid_dataloader(trainer: DistributedTrainer) -> Dict[str, DataLoader]: # NOTE: we only create the dataloader for the first stage, # then we lazy initialize the dataloader for the other stages stage = cast(DatasetStageArgs, stage) - valid_split_num_samples = trainer.config.tokens.limit_val_batches * trainer.global_batch_size log_rank( - f"[Validation Plan] Stage {stage.name} has {valid_split_num_samples} samples in the validation set", + f"[Validation Plan] Stage {stage.name} has {len(stage.data.dataset.validation_folder)} folders with samples in the validation set", logger=logger, level=logging.INFO, rank=0, ) dataloader = ( - get_valid_dataloader_from_data_stage(trainer, stage.data, valid_split_num_samples=valid_split_num_samples) + get_valid_dataloader_from_data_stage(trainer, stage.data) if stage_idx == 0 - else lambda stage=stage: get_dataloader_from_data_stage( - trainer, stage.data, valid_split_num_samples=valid_split_num_samples - ) + else lambda stage=stage: get_dataloader_from_data_stage(trainer, stage.data) ) dataloaders[stage.name] = dataloader return dataloaders diff --git a/src/nanotron/config/config.py b/src/nanotron/config/config.py index 924a2cdf..fb3e49dd 100644 --- a/src/nanotron/config/config.py +++ b/src/nanotron/config/config.py @@ -109,7 +109,8 @@ def __post_init__(self): @dataclass class MultilingualNanosetDatasetsArgs: - dataset_folder: Union[str, dict, List[str]] + training_folder: Union[str, dict, List[str]] + validation_folder: Union[str, dict, List[str]] dataset_tokens: List[ int ] # Set token for each language previously defined. 
We use a List and not a dict because this way we support specifyng weights (dict) or not (List[str]) @@ -125,7 +126,8 @@ def __post_init__(self): self.dataset_folder = list(tmp_dataset_folder.keys()) self.dataset_weights = list(tmp_dataset_folder.values()) - assert len(self.dataset_folder) == len(self.dataset_tokens) + assert len(self.training_folder) == len(self.validation_folder) + assert len(self.training_folder) == len(self.dataset_tokens) @dataclass diff --git a/src/nanotron/data/multilingual_nanoset.py b/src/nanotron/data/multilingual_nanoset.py index cd8be195..f634fd98 100644 --- a/src/nanotron/data/multilingual_nanoset.py +++ b/src/nanotron/data/multilingual_nanoset.py @@ -1,6 +1,5 @@ import os import warnings -from math import ceil from typing import Dict, List, Tuple, Union import numpy as np @@ -32,7 +31,6 @@ def __init__( sequence_length: int, token_size: int, train_split_num_samples: int, - valid_split_num_samples: int, dataset_tokens: List[int], is_valid: bool = False, dataset_weights: Union[List[float], None] = None, @@ -49,7 +47,6 @@ def __init__( self.sequence_length = sequence_length self.token_size = token_size self.train_split_num_samples = train_split_num_samples - self.valid_split_num_samples = valid_split_num_samples self.dataset_tokens = dataset_tokens self.is_valid = is_valid self.random_seed = random_seed @@ -80,36 +77,11 @@ def __init__( self.dataset_weights ), f"Specified {len(self.dataset_weights)} weights but {len(dataset_folders)} datasets were provided." ## Build dataset index and dataset sample index - ### Split dataset_lengths into train_dataset_lenghts & valid_dataset_lenghts - self.valid_dataset_lenghts = [ - ceil(weight * valid_split_num_samples) for weight in self.dataset_weights - ] # Better not tu use numpy so we don't get overflow issues - # Assert that we have sufficient samples to build the valid split - for ds_index in range(len(self.dataset_lengths)): - assert ( - self.dataset_lengths[ds_index] > self.valid_dataset_lenghts[ds_index] - ), f"Trying to build validation dataset with {self.valid_dataset_lenghts[ds_index]} samples but {dataset_folders[ds_index]} just have {self.dataset_lengths[ds_index]} samples." 
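Since validation folders are now listed explicitly, the carve-out logic removed here is replaced further below by `build_valid_nanoset_index`, which simply enumerates every sample of every validation folder in order. Roughly, with toy lengths only:

    # Two validation folders holding 3 and 2 samples respectively (illustrative).
    dataset_lengths = [3, 2]
    dataset_index = [i for i, n in enumerate(dataset_lengths) for _ in range(n)]   # [0, 0, 0, 1, 1]
    dataset_sample_index = [j for n in dataset_lengths for j in range(n)]          # [0, 1, 2, 0, 1]
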
- self.train_dataset_lenghts = [ - a - b for a, b in zip(self.dataset_lengths, self.valid_dataset_lenghts) - ] # Subtract the valid samples from the training dataset - if is_valid: # Valid MultilingualNanoset - self.split_num_samples = valid_split_num_samples - self.split_samples_per_epoch = valid_split_num_samples - self.num_epochs = 1 - self.split_dataset_lenghts = self.valid_dataset_lenghts - self.split_dataset_offsets = self.train_dataset_lenghts + self.dataset_index, self.dataset_sample_index = self.build_valid_nanoset_index(self.dataset_lengths) else: # Train MultilingualNanoset - self.split_num_samples = train_split_num_samples - self.split_samples_per_epoch = sum(self.train_dataset_lenghts) - self.num_epochs = int(self.split_num_samples / self.split_samples_per_epoch) + 1 - self.split_dataset_lenghts = self.train_dataset_lenghts - self.split_dataset_offsets = [ - 0 for _ in range(len(self.dataset_lengths)) - ] # For training there is NO offset - - self.dataset_index, self.dataset_sample_index = self.build_nanoset_index() + self.dataset_index, self.dataset_sample_index = self.build_train_nanoset_index() self.print_nanoset_info() @@ -139,16 +111,16 @@ def __getitem__(self, idx: int) -> Dict[str, np.ndarray]: return tokens - def build_nanoset_index(self) -> np.ndarray: + def build_train_nanoset_index(self) -> np.ndarray: """ - Build dataset index and dataset sample index + Build train dataset index and dataset sample index """ + # Compute samples per epoch and number of epochs + samples_per_epoch = sum(self.dataset_lengths) + num_epochs = int(self.train_split_num_samples / samples_per_epoch) + 1 # Build the dataset indexes for 1 epoch - dataset_index, dataset_sample_index = build_nanoset_index_helper( - n_samples=self.split_samples_per_epoch, - weights=self.dataset_weights, - dataset_sizes=self.split_dataset_lenghts, - offsets=self.split_dataset_offsets, + dataset_index, dataset_sample_index = build_train_nanoset_index_helper( + n_samples=samples_per_epoch, weights=self.dataset_weights, dataset_sizes=self.dataset_lengths ) # Shuffle the indexes the same way numpy_random_state = np.random.RandomState(self.random_seed) @@ -156,14 +128,28 @@ def build_nanoset_index(self) -> np.ndarray: numpy_random_state = np.random.RandomState(self.random_seed) numpy_random_state.shuffle(dataset_sample_index) # Concatenate num_epochs the shuffled indexes - dataset_index = np.concatenate([dataset_index for _ in range(self.num_epochs)]) - dataset_sample_index = np.concatenate([dataset_sample_index for _ in range(self.num_epochs)]) + dataset_index = np.concatenate([dataset_index for _ in range(num_epochs)]) + dataset_sample_index = np.concatenate([dataset_sample_index for _ in range(num_epochs)]) # Just keep the necessary samples - dataset_index = dataset_index[: self.split_num_samples] - dataset_sample_index = dataset_sample_index[: self.split_num_samples] + dataset_index = dataset_index[: self.train_split_num_samples] + dataset_sample_index = dataset_sample_index[: self.train_split_num_samples] return dataset_index, dataset_sample_index + @jit(nopython=True, cache=True) + def build_valid_nanoset_index(dataset_lengths: List[int]) -> np.ndarray: + """ + Build valid dataset index and dataset sample index + """ + dataset_index = [] + dataset_sample_index = [] + + for i, length in enumerate(dataset_lengths): + dataset_index.extend([i] * length) + dataset_sample_index.extend(range(length)) + + return np.array(dataset_index, dtype="uint"), np.array(dataset_sample_index, dtype="long") + def 
print_nanoset_info(self): log_rank( @@ -191,8 +177,8 @@ def print_nanoset_info(self): @jit(nopython=True, cache=True) -def build_nanoset_index_helper( - n_samples: int, weights: np.ndarray, dataset_sizes: List[int], offsets: List[int] +def build_train_nanoset_index_helper( + n_samples: int, weights: np.ndarray, dataset_sizes: List[int] ) -> Tuple[np.ndarray, np.ndarray]: """ Given multiple datasets and a weighting array, build samples indexes @@ -219,9 +205,7 @@ def build_nanoset_index_helper( # Assign the dataset index and update the sample index dataset_index[sample_idx] = max_error_index - dataset_sample_index[sample_idx] = ( - current_samples[max_error_index] % dataset_sizes[max_error_index] - ) + offsets[max_error_index] + dataset_sample_index[sample_idx] = current_samples[max_error_index] % dataset_sizes[max_error_index] # Update the total samples for the selected dataset current_samples[max_error_index] += 1 From 8a80e5aee5e3b61bb016af96a1ebe578c836b730 Mon Sep 17 00:00:00 2001 From: tj-solergibert Date: Wed, 17 Jul 2024 14:10:03 +0000 Subject: [PATCH 38/51] After lunch --- examples/config_multilingual_nanoset.yaml | 42 +++++++++++++++-------- run_train.py | 6 ++-- src/nanotron/config/config.py | 21 ++++++------ src/nanotron/data/multilingual_nanoset.py | 33 +++++++++--------- tools/preprocess_data.py | 5 ++- 5 files changed, 61 insertions(+), 46 deletions(-) diff --git a/examples/config_multilingual_nanoset.yaml b/examples/config_multilingual_nanoset.yaml index 3c4476a0..238f8269 100644 --- a/examples/config_multilingual_nanoset.yaml +++ b/examples/config_multilingual_nanoset.yaml @@ -7,7 +7,8 @@ checkpoints: data_stages: - data: dataset: - dataset_folder: /mloscratch/homes/solergib/nanotrove/nanotron/datasets/c4-es/tokenized + training_folder: datasets/c4-es/train + validation_folder: datasets/c4-es/validation dataset_tokens: - 15 num_loading_workers: 1 @@ -16,24 +17,37 @@ data_stages: start_training_step: 1 - data: dataset: - dataset_folder: - - /mloscratch/homes/solergib/nanotrove/nanotron/datasets/SlimPajama-6B/tokenized - - /mloscratch/homes/solergib/nanotrove/nanotron/datasets/c4-es/tokenized + training_folder: + - datasets/c4-es/train + - datasets/c4-en/train + - datasets/c4-fr/train + validation_folder: + - datasets/c4-es/validation + - datasets/c4-en/validation + - datasets/c4-fr/validation dataset_tokens: - - 16 - 15 + - 16 + - 17 num_loading_workers: 1 seed: 42 name: Second purpose training (> 1 dataset) start_training_step: 15 - data: dataset: - dataset_folder: - /mloscratch/homes/solergib/nanotrove/nanotron/datasets/SlimPajama-6B/tokenized: 0.8 - /mloscratch/homes/solergib/nanotrove/nanotron/datasets/c4-es/tokenized: 0.2 + training_folder: + datasets/c4-es/train: 0.6 + datasets/c4-en/train: 0.3 + datasets/c4-fr/train: 0.1 + validation_folder: + - datasets/c4-es/validation + - datasets/c4-en/validation + - datasets/c4-fr/validation dataset_tokens: - - 16 - 15 + - 16 + - 17 + num_loading_workers: 1 seed: 42 name: Third purpose training (Blended dataset) @@ -61,12 +75,12 @@ model: bos_token_id: 1 eos_token_id: 2 hidden_act: silu - hidden_size: 4096 + hidden_size: 512 initializer_range: 0.02 - intermediate_size: 11008 + intermediate_size: 512 is_llama_config: true max_position_embeddings: 1024 - num_hidden_layers: 32 + num_hidden_layers: 2 num_attention_heads: 32 num_key_value_heads: 8 pad_token_id: null @@ -108,13 +122,13 @@ parallelism: profiler: null tokenizer: tokenizer_max_length: null - tokenizer_name_or_path: gpt2 + tokenizer_name_or_path: meta-llama/Meta-Llama-3-8B 
tokenizer_revision: null tokens: batch_accumulation_per_replica: 1 limit_test_batches: 0 limit_val_batches: 10 - micro_batch_size: 2 + micro_batch_size: 4 sequence_length: 1024 train_steps: 200 val_check_interval: -1 diff --git a/run_train.py b/run_train.py index 57e0ec25..39cda23b 100644 --- a/run_train.py +++ b/run_train.py @@ -189,7 +189,7 @@ def get_dataloader_from_data_stage( with main_rank_first(trainer.parallel_context.world_pg): train_dataset = MultilingualNanoset( - dataset_folders=data.dataset.dataset_folder, + dataset_folders=data.dataset.training_folder, dataset_weights=data.dataset.dataset_weights, sequence_length=trainer.sequence_length, token_size=token_size, @@ -238,11 +238,9 @@ def get_valid_dataloader_from_data_stage( with main_rank_first(trainer.parallel_context.world_pg): valid_dataset = MultilingualNanoset( - dataset_folders=data.dataset.dataset_folder, - dataset_weights=data.dataset.dataset_weights, + dataset_folders=data.dataset.validation_folder, sequence_length=trainer.sequence_length, token_size=token_size, - train_split_num_samples=trainer.config.tokens.train_steps * trainer.global_batch_size, dataset_tokens=data.dataset.dataset_tokens, is_valid=True, random_seed=data.seed, diff --git a/src/nanotron/config/config.py b/src/nanotron/config/config.py index fb3e49dd..ce61a249 100644 --- a/src/nanotron/config/config.py +++ b/src/nanotron/config/config.py @@ -110,21 +110,20 @@ def __post_init__(self): @dataclass class MultilingualNanosetDatasetsArgs: training_folder: Union[str, dict, List[str]] - validation_folder: Union[str, dict, List[str]] - dataset_tokens: List[ - int - ] # Set token for each language previously defined. We use a List and not a dict because this way we support specifyng weights (dict) or not (List[str]) + validation_folder: Union[str, List[str]] + dataset_tokens: List[int] # Set token for each language previously defined def __post_init__(self): - if isinstance(self.dataset_folder, str): # Case 1: 1 Dataset file - self.dataset_folder = [self.dataset_folder] + if isinstance(self.training_folder, str): # Case 1: 1 Dataset folder + self.training_folder = [self.training_folder] + self.validation_folder = [self.validation_folder] self.dataset_weights = [1] - elif isinstance(self.dataset_folder, List): # Case 2: > 1 Dataset file + elif isinstance(self.training_folder, List): # Case 2: > 1 Dataset folder self.dataset_weights = None # Set to None so we consume all the samples randomly - elif isinstance(self.dataset_folder, dict): # Case 3: dict with > 1 dataset_folder and weights - tmp_dataset_folder = self.dataset_folder.copy() - self.dataset_folder = list(tmp_dataset_folder.keys()) - self.dataset_weights = list(tmp_dataset_folder.values()) + elif isinstance(self.training_folder, dict): # Case 3: dict with > 1 training_folder and weights + tmp_training_folder = self.training_folder.copy() + self.training_folder = list(tmp_training_folder.keys()) + self.dataset_weights = list(tmp_training_folder.values()) assert len(self.training_folder) == len(self.validation_folder) assert len(self.training_folder) == len(self.dataset_tokens) diff --git a/src/nanotron/data/multilingual_nanoset.py b/src/nanotron/data/multilingual_nanoset.py index f634fd98..7af57448 100644 --- a/src/nanotron/data/multilingual_nanoset.py +++ b/src/nanotron/data/multilingual_nanoset.py @@ -30,8 +30,8 @@ def __init__( dataset_folders: List[str], sequence_length: int, token_size: int, - train_split_num_samples: int, dataset_tokens: List[int], + train_split_num_samples: int = None, is_valid: 
bool = False, dataset_weights: Union[List[float], None] = None, random_seed: int = 1234, @@ -78,7 +78,7 @@ def __init__( ), f"Specified {len(self.dataset_weights)} weights but {len(dataset_folders)} datasets were provided." ## Build dataset index and dataset sample index if is_valid: # Valid MultilingualNanoset - self.dataset_index, self.dataset_sample_index = self.build_valid_nanoset_index(self.dataset_lengths) + self.dataset_index, self.dataset_sample_index = build_valid_nanoset_index(self.dataset_lengths) else: # Train MultilingualNanoset self.dataset_index, self.dataset_sample_index = self.build_train_nanoset_index() @@ -136,20 +136,6 @@ def build_train_nanoset_index(self) -> np.ndarray: return dataset_index, dataset_sample_index - @jit(nopython=True, cache=True) - def build_valid_nanoset_index(dataset_lengths: List[int]) -> np.ndarray: - """ - Build valid dataset index and dataset sample index - """ - dataset_index = [] - dataset_sample_index = [] - - for i, length in enumerate(dataset_lengths): - dataset_index.extend([i] * length) - dataset_sample_index.extend(range(length)) - - return np.array(dataset_index, dtype="uint"), np.array(dataset_sample_index, dtype="long") - def print_nanoset_info(self): log_rank( @@ -211,3 +197,18 @@ def build_train_nanoset_index_helper( current_samples[max_error_index] += 1 return dataset_index, dataset_sample_index + + +@jit(nopython=True, cache=True) +def build_valid_nanoset_index(dataset_lengths: List[int]) -> np.ndarray: + """ + Build valid dataset index and dataset sample index + """ + dataset_index = [] + dataset_sample_index = [] + + for i, length in enumerate(dataset_lengths): + dataset_index.extend([i] * length) + dataset_sample_index.extend(range(length)) + + return np.array(dataset_index, dtype="uint"), np.array(dataset_sample_index, dtype="long") diff --git a/tools/preprocess_data.py b/tools/preprocess_data.py index c668aa58..8383ba38 100644 --- a/tools/preprocess_data.py +++ b/tools/preprocess_data.py @@ -98,7 +98,9 @@ def main(args): dataset_options={"split": args.split}, ) elif args.readers == "parquet": - datatrove_reader = ParquetReader(data_folder=args.dataset, text_key=args.column, glob_pattern=args.glob_pattern) + datatrove_reader = ParquetReader( + data_folder=args.dataset, text_key=args.column, glob_pattern=args.glob_pattern + ) else: datatrove_reader = JsonlReader(data_folder=args.dataset, text_key=args.column, glob_pattern=args.glob_pattern) @@ -107,6 +109,7 @@ def main(args): datatrove_reader, DocumentTokenizer( output_folder=args.output_folder, + shuffle=False, tokenizer_name_or_path=args.tokenizer_name_or_path, eos_token=args.eos_token, max_tokens_per_file=1e9, From 0fa19716a2770297ae1dbce860bd5f7a36008799 Mon Sep 17 00:00:00 2001 From: tj-solergibert Date: Thu, 18 Jul 2024 10:48:00 +0000 Subject: [PATCH 39/51] Ready --- examples/config_multilingual_nanoset.yaml | 20 ++++++++++---------- src/nanotron/config/config.py | 11 ++++++++--- 2 files changed, 18 insertions(+), 13 deletions(-) diff --git a/examples/config_multilingual_nanoset.yaml b/examples/config_multilingual_nanoset.yaml index 238f8269..599bff6c 100644 --- a/examples/config_multilingual_nanoset.yaml +++ b/examples/config_multilingual_nanoset.yaml @@ -9,8 +9,8 @@ data_stages: dataset: training_folder: datasets/c4-es/train validation_folder: datasets/c4-es/validation - dataset_tokens: - - 15 + lang_to_ids: + es: 128002 num_loading_workers: 1 seed: 42 name: General purpose training (Single dataset) @@ -25,10 +25,10 @@ data_stages: - datasets/c4-es/validation - 
datasets/c4-en/validation - datasets/c4-fr/validation - dataset_tokens: - - 15 - - 16 - - 17 + lang_to_ids: + es: 128002 + en: 128003 + fr: 128004 num_loading_workers: 1 seed: 42 name: Second purpose training (> 1 dataset) @@ -43,10 +43,10 @@ data_stages: - datasets/c4-es/validation - datasets/c4-en/validation - datasets/c4-fr/validation - dataset_tokens: - - 15 - - 16 - - 17 + lang_to_ids: + es: 128002 + en: 128003 + fr: 128004 num_loading_workers: 1 seed: 42 diff --git a/src/nanotron/config/config.py b/src/nanotron/config/config.py index ce61a249..dd2c157d 100644 --- a/src/nanotron/config/config.py +++ b/src/nanotron/config/config.py @@ -111,7 +111,7 @@ def __post_init__(self): class MultilingualNanosetDatasetsArgs: training_folder: Union[str, dict, List[str]] validation_folder: Union[str, List[str]] - dataset_tokens: List[int] # Set token for each language previously defined + lang_to_ids: dict # Mapping from the previously defined folders to tokens. Respect the order def __post_init__(self): if isinstance(self.training_folder, str): # Case 1: 1 Dataset folder @@ -125,8 +125,13 @@ def __post_init__(self): self.training_folder = list(tmp_training_folder.keys()) self.dataset_weights = list(tmp_training_folder.values()) - assert len(self.training_folder) == len(self.validation_folder) - assert len(self.training_folder) == len(self.dataset_tokens) + self.dataset_tokens = list(self.lang_to_ids.values()) + assert len(self.training_folder) == len( + self.validation_folder + ), f"The sizes of training_folder and validation_folder mismatch ({len(self.training_folder)} vs {len(self.validation_folder)})" + assert len(self.training_folder) == len( + self.dataset_tokens + ), f"The sizes of training_folder and lang_to_ids mismatch ({len(self.training_folder)} vs {len(self.dataset_tokens)})" @dataclass From 8b68126fc0f8826c55867ca5cfb9d90c44f38524 Mon Sep 17 00:00:00 2001 From: Antoni-Joan Solergibert <74564958+TJ-Solergibert@users.noreply.github.com> Date: Thu, 15 Aug 2024 09:42:01 +0200 Subject: [PATCH 40/51] Add multilingual validation (#3) Add multilingual validation step. 
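The reporting side of this patch groups the per-sample losses by the `lang_code` each batch now carries, keyed by the stage's `languages` list. A minimal sketch of that aggregation, assuming a `sample_loss` tensor of shape (batch,); `aggregate_per_language` is an illustrative helper name, not the trainer's API:

    import torch
    from collections import defaultdict
    from typing import Dict, List

    def aggregate_per_language(sample_loss: torch.Tensor, lang_code: torch.Tensor, languages: List[str]) -> Dict[str, float]:
        # sample_loss: (batch,) per-sample losses; lang_code: (batch, 1) indices into `languages`.
        sums, counts = defaultdict(float), defaultdict(int)
        for loss, code in zip(sample_loss.tolist(), lang_code.flatten().tolist()):
            sums[code] += loss
            counts[code] += 1
        return {languages[code]: sums[code] / counts[code] for code in sums}
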
--- examples/config_multilingual_nanoset.yaml | 77 +++--- run_train.py | 19 +- src/nanotron/config/config.py | 17 +- src/nanotron/data/collator.py | 73 +++++ src/nanotron/data/dataloader_builder.py | 14 +- src/nanotron/data/multilingual_nanoset.py | 4 +- src/nanotron/distributed.py | 4 - src/nanotron/models/llama.py | 37 ++- .../parallel/pipeline_parallel/engine.py | 25 +- .../parallel/pipeline_parallel/state.py | 4 + src/nanotron/serialize/metadata.py | 2 + src/nanotron/trainer.py | 249 +++++++++++++++++- 12 files changed, 438 insertions(+), 87 deletions(-) diff --git a/examples/config_multilingual_nanoset.yaml b/examples/config_multilingual_nanoset.yaml index 599bff6c..cc66cd70 100644 --- a/examples/config_multilingual_nanoset.yaml +++ b/examples/config_multilingual_nanoset.yaml @@ -1,5 +1,5 @@ checkpoints: - checkpoint_interval: 1000 + checkpoint_interval: 1000000 checkpoints_path: checkpoints/ checkpoints_path_is_shared_file_system: false resume_checkpoint_path: null @@ -7,56 +7,57 @@ checkpoints: data_stages: - data: dataset: - training_folder: datasets/c4-es/train - validation_folder: datasets/c4-es/validation - lang_to_ids: - es: 128002 + training_folder: + - datasets/c4-es/train + - datasets/c4-en/train + - datasets/c4-fr/train + validation_folder: + - datasets/c4-es/validation + - datasets/c4-en/validation + - datasets/c4-fr/validation + languages: + - es + - en + - fr num_loading_workers: 1 seed: 42 - name: General purpose training (Single dataset) + name: General purpose training (Blended dataset) start_training_step: 1 - data: dataset: training_folder: - datasets/c4-es/train - - datasets/c4-en/train - - datasets/c4-fr/train validation_folder: - datasets/c4-es/validation - - datasets/c4-en/validation - - datasets/c4-fr/validation - lang_to_ids: - es: 128002 - en: 128003 - fr: 128004 + languages: + - es num_loading_workers: 1 seed: 42 - name: Second purpose training (> 1 dataset) - start_training_step: 15 + name: Second purpose training (Single dataset) + start_training_step: 1000 - data: dataset: training_folder: - datasets/c4-es/train: 0.6 - datasets/c4-en/train: 0.3 - datasets/c4-fr/train: 0.1 + - datasets/c4-es/train + - datasets/c4-en/train + - datasets/c4-fr/train validation_folder: - datasets/c4-es/validation - datasets/c4-en/validation - datasets/c4-fr/validation - lang_to_ids: - es: 128002 - en: 128003 - fr: 128004 - + languages: + - es + - en + - fr num_loading_workers: 1 seed: 42 - name: Third purpose training (Blended dataset) - start_training_step: 25 + name: Third purpose training (>1 dataset) + start_training_step: 2000 general: benchmark_csv_path: null consumed_train_samples: null ignore_sanity_checks: true - project: Nanoset + project: MultilingualV2 run: llama seed: 42 step: null @@ -75,12 +76,12 @@ model: bos_token_id: 1 eos_token_id: 2 hidden_act: silu - hidden_size: 512 + hidden_size: 4096 initializer_range: 0.02 - intermediate_size: 512 + intermediate_size: 14336 is_llama_config: true - max_position_embeddings: 1024 - num_hidden_layers: 2 + max_position_embeddings: 4096 + num_hidden_layers: 32 num_attention_heads: 32 num_key_value_heads: 8 pad_token_id: null @@ -89,7 +90,7 @@ model: rope_theta: 500000.0 rms_norm_eps: 1.0e-06 rope_scaling: null - tie_word_embeddings: true + tie_word_embeddings: false use_cache: true vocab_size: 128256 optimizer: @@ -112,11 +113,11 @@ optimizer: weight_decay: 0.01 zero_stage: 0 parallelism: - dp: 1 + dp: 2 expert_parallel_size: 1 pp: 1 pp_engine: 1f1b - tp: 1 + tp: 4 tp_linear_async_communication: false tp_mode: REDUCE_SCATTER 
profiler: null @@ -128,7 +129,7 @@ tokens: batch_accumulation_per_replica: 1 limit_test_batches: 0 limit_val_batches: 10 - micro_batch_size: 4 - sequence_length: 1024 - train_steps: 200 - val_check_interval: -1 + micro_batch_size: 3 + sequence_length: 4096 + train_steps: 500 + val_check_interval: 100 diff --git a/run_train.py b/run_train.py index 39cda23b..809d8d41 100644 --- a/run_train.py +++ b/run_train.py @@ -194,7 +194,6 @@ def get_dataloader_from_data_stage( sequence_length=trainer.sequence_length, token_size=token_size, train_split_num_samples=trainer.config.tokens.train_steps * trainer.global_batch_size, - dataset_tokens=data.dataset.dataset_tokens, random_seed=data.seed, ) @@ -209,6 +208,7 @@ def get_dataloader_from_data_stage( consumed_train_samples=consumed_train_samples, dataloader_num_workers=data.num_loading_workers, dataloader_drop_last=True, + is_multilingual=True, ) return train_dataloader @@ -241,7 +241,6 @@ def get_valid_dataloader_from_data_stage( dataset_folders=data.dataset.validation_folder, sequence_length=trainer.sequence_length, token_size=token_size, - dataset_tokens=data.dataset.dataset_tokens, is_valid=True, random_seed=data.seed, ) @@ -256,6 +255,8 @@ def get_valid_dataloader_from_data_stage( micro_batch_size=trainer.micro_batch_size, dataloader_num_workers=data.num_loading_workers, dataloader_drop_last=True, + shuffle=True, + is_multilingual=True, ) return valid_dataloader @@ -315,7 +316,7 @@ def get_valid_dataloader(trainer: DistributedTrainer) -> Dict[str, DataLoader]: stage = cast(DatasetStageArgs, stage) log_rank( - f"[Validation Plan] Stage {stage.name} has {len(stage.data.dataset.validation_folder)} folders with samples in the validation set", + f"[Validation Plan] Stage {stage.name} has {len(stage.data.dataset.validation_folder)} folders with samples for the validation set", logger=logger, level=logging.INFO, rank=0, @@ -324,8 +325,18 @@ def get_valid_dataloader(trainer: DistributedTrainer) -> Dict[str, DataLoader]: dataloader = ( get_valid_dataloader_from_data_stage(trainer, stage.data) if stage_idx == 0 - else lambda stage=stage: get_dataloader_from_data_stage(trainer, stage.data) + else lambda stage=stage: get_valid_dataloader_from_data_stage(trainer, stage.data) ) + # TODO(tj.solergibert) As we are creating again the valid dataloader in every validation stage, we print multiple times + # the validation MultilingualNanoset info (Number of samples, etc.) [UPDATE: ]. In order to solve that, we could get rid of this lambda + # funcs and directly create all dataloaders. + # + # This lambda functs (Used in training too) are for creating the DataLoaders lazyly FOR 1. Start training faster instead + # of creating multiple DataLoaders 2. Consume less memory as the lambda func is lighter that the DataLoader object with + # the Dataset, collator, etc. + # BUT 1. The Nanoset creation process is very fast and 2. Nanosets doesn't consume any memory at all till we start sampling + # from the Nanoset. 
Also they later transform the DataLoader into a Iterator object so it's impossible to retrieve + # the DataLoader object again to delete it (More comments in trainer.py) dataloaders[stage.name] = dataloader return dataloaders diff --git a/src/nanotron/config/config.py b/src/nanotron/config/config.py index dd2c157d..b3c755a5 100644 --- a/src/nanotron/config/config.py +++ b/src/nanotron/config/config.py @@ -111,7 +111,7 @@ def __post_init__(self): class MultilingualNanosetDatasetsArgs: training_folder: Union[str, dict, List[str]] validation_folder: Union[str, List[str]] - lang_to_ids: dict # Mapping from the previously defined folders to tokens. Respect the order + languages: List[str] # NOTE(tj.solergibert) Required for 1. Aggregating the result 2. Reporting to WANDB def __post_init__(self): if isinstance(self.training_folder, str): # Case 1: 1 Dataset folder @@ -125,13 +125,13 @@ def __post_init__(self): self.training_folder = list(tmp_training_folder.keys()) self.dataset_weights = list(tmp_training_folder.values()) - self.dataset_tokens = list(self.lang_to_ids.values()) + assert len(self.training_folder) == len( + self.languages + ), f"The sizes of training_folder and languages mismatch ({len(self.training_folder)} vs {len(self.languages)})" + assert len(self.training_folder) == len( self.validation_folder ), f"The sizes of training_folder and validation_folder mismatch ({len(self.training_folder)} vs {len(self.validation_folder)})" - assert len(self.training_folder) == len( - self.dataset_tokens - ), f"The sizes of training_folder and lang_to_ids mismatch ({len(self.training_folder)} vs {len(self.dataset_tokens)})" @dataclass @@ -405,6 +405,13 @@ def __post_init__(self): for i in range(len(self.data_stages) - 1) ), "The stages are not sorted by start_training_step in increasing order" + # NOTE(tj.solergibert) As we are reporting the training & validation metrics together, we + # must comply with val_check_interval % iteration_step_info_interval = 0 + if not self.tokens.val_check_interval % self.logging.iteration_step_info_interval == 0: + raise ValueError( + f"It is necessary to run the validation stage during a logging step. Validation interval: {self.tokens.val_check_interval}, Logging interval: {self.logging.iteration_step_info_interval}" + ) + # # if lighteval, we need tokenizer to be defined # if self.checkpoints.lighteval is not None: # assert self.tokenizer.tokenizer_name_or_path is not None diff --git a/src/nanotron/data/collator.py b/src/nanotron/data/collator.py index 199527e1..fd217b1a 100644 --- a/src/nanotron/data/collator.py +++ b/src/nanotron/data/collator.py @@ -78,3 +78,76 @@ def __call__(self, examples: List[Dict[str, List[np.ndarray]]]) -> Dict[str, Uni ) return result + + +@dataclasses.dataclass +class MultilingualNanosetDataCollatorForCLM: + """ + Data collator used for causal language modeling with Nanosets dataset. + + - input_pp_rank: Discards last input id token + - output_pp_rank: Discards first label id token + - other pp ranks: Don't have data. Instead, we use `TensorPointer` to point to the rank having the data. + """ + + sequence_length: int + input_pp_rank: int + output_pp_rank: int + parallel_context: ParallelContext + + def __call__(self, examples: List[Dict[str, List[np.ndarray]]]) -> Dict[str, Union[torch.Tensor, TensorPointer]]: + # Process the case when current rank doesn't require data. We return `TensorPointer` that points to ranks having the data. 
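For reference, the shift this collator applies to each multilingual sample looks like this (toy tensors, sequence_length = 4):

    import torch

    tokens = torch.tensor([[11, 12, 13, 14, 15]])   # (batch=1, sequence_length + 1)
    input_ids = tokens[:, :-1]                      # kept on the input PP rank
    label_ids = tokens[:, 1:]                       # kept on the output PP rank
    lang_code = torch.tensor([[2]])                 # (batch, 1) index into the stage's `languages`
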
+ current_pp_rank = dist.get_rank(self.parallel_context.pp_pg) + if current_pp_rank not in [ + self.input_pp_rank, + self.output_pp_rank, + ]: + assert all(len(example) == 0 for example in examples) + return { + "input_ids": TensorPointer(group_rank=self.input_pp_rank), + "input_mask": TensorPointer(group_rank=self.input_pp_rank), + "lang_code": TensorPointer(group_rank=self.input_pp_rank), + "label_ids": TensorPointer(group_rank=self.output_pp_rank), + "label_mask": TensorPointer(group_rank=self.output_pp_rank), + } + + # TODO @nouamanetazi: Is it better to have examples as np.array or torch.Tensor? + input_ids = torch.vstack([examples[i]["input_ids"] for i in range(len(examples))]) # (b, s) + lang_code = torch.vstack([examples[i]["lang_code"] for i in range(len(examples))]) # (b, 1) + batch_size, expanded_input_length = input_ids.shape + + result: Dict[str, Union[torch.LongTensor, TensorPointer]] = {} + + result["input_ids"] = TensorPointer(group_rank=self.input_pp_rank) + result["input_mask"] = TensorPointer(group_rank=self.input_pp_rank) + result["lang_code"] = TensorPointer(group_rank=self.input_pp_rank) + result["label_ids"] = TensorPointer(group_rank=self.output_pp_rank) + result["label_mask"] = TensorPointer(group_rank=self.output_pp_rank) + + assert ( + expanded_input_length == self.sequence_length + 1 + ), f"Samples should be of length {self.sequence_length + 1} (seq_len+1), but got {expanded_input_length}" + + # Process inputs: last token is the label + if current_pp_rank == self.input_pp_rank: + result["input_ids"] = input_ids[:, :-1] + result["input_mask"] = torch.ones((batch_size, self.sequence_length), dtype=torch.bool) + result["lang_code"] = lang_code + + # Process labels: shift them to the left + if current_pp_rank == self.output_pp_rank: + result["label_ids"] = input_ids[:, 1:] + result["label_mask"] = torch.ones((batch_size, self.sequence_length), dtype=torch.bool) + + if isinstance(result["input_ids"], torch.Tensor) and result["input_ids"].shape[-1] != self.sequence_length: + raise ValueError( + f"`labels` are incorrectly preprocessed. `labels` length is {result['input_ids'].shape[-1]}, but should be" + f" {self.sequence_length}." + ) + if isinstance(result["label_ids"], torch.Tensor) and result["label_ids"].shape[-1] != self.sequence_length: + raise ValueError( + f"`labels` are incorrectly preprocessed. `labels` length is {result['label_ids'].shape[-1]}, but should be" + f" {self.sequence_length}." + ) + + return result diff --git a/src/nanotron/data/dataloader_builder.py b/src/nanotron/data/dataloader_builder.py index 9d3285f6..f9480029 100644 --- a/src/nanotron/data/dataloader_builder.py +++ b/src/nanotron/data/dataloader_builder.py @@ -1,6 +1,6 @@ import nanotron.distributed as dist from nanotron import logging -from nanotron.data.collator import NanosetDataCollatorForCLM +from nanotron.data.collator import MultilingualNanosetDataCollatorForCLM, NanosetDataCollatorForCLM from nanotron.dataloader import ( EmptyInfiniteDataset, get_dataloader_worker_init, @@ -20,9 +20,11 @@ def build_nanoset_dataloader( output_pp_rank: int, micro_batch_size: int, dataloader_num_workers: int, + is_multilingual: bool = False, consumed_train_samples: int = 0, dataloader_drop_last: bool = True, dataloader_pin_memory: bool = True, + shuffle: bool = False, ) -> DataLoader: # Case of ranks not requiring data. 
We give them a dummy dataset, then the collator will do his job @@ -39,6 +41,14 @@ def build_nanoset_dataloader( parallel_context=parallel_context, ) + if is_multilingual: + data_collator = MultilingualNanosetDataCollatorForCLM( + sequence_length=sequence_length, + input_pp_rank=input_pp_rank, + output_pp_rank=output_pp_rank, + parallel_context=parallel_context, + ) + # Compute size and rank of dataloader workers dp_ranks_size = parallel_context.dp_pg.size() dp_rank = parallel_context.dp_pg.rank() @@ -49,7 +59,7 @@ def build_nanoset_dataloader( dl_rank=dp_rank, drop_last=dataloader_drop_last, consumed_train_samples=consumed_train_samples, - shuffle=False, + shuffle=shuffle, ) return DataLoader( diff --git a/src/nanotron/data/multilingual_nanoset.py b/src/nanotron/data/multilingual_nanoset.py index 7af57448..8eec5549 100644 --- a/src/nanotron/data/multilingual_nanoset.py +++ b/src/nanotron/data/multilingual_nanoset.py @@ -30,7 +30,6 @@ def __init__( dataset_folders: List[str], sequence_length: int, token_size: int, - dataset_tokens: List[int], train_split_num_samples: int = None, is_valid: bool = False, dataset_weights: Union[List[float], None] = None, @@ -47,7 +46,6 @@ def __init__( self.sequence_length = sequence_length self.token_size = token_size self.train_split_num_samples = train_split_num_samples - self.dataset_tokens = dataset_tokens self.is_valid = is_valid self.random_seed = random_seed self.datatrove_datasets = [] @@ -107,7 +105,7 @@ def __getitem__(self, idx: int) -> Dict[str, np.ndarray]: dataset_sample = self.dataset_sample_index[idx] tokens = self.datatrove_datasets[dataset][dataset_sample] - tokens["input_ids"][0] = self.dataset_tokens[dataset] # Prepend language token + tokens["lang_code"] = torch.tensor(dataset, dtype=torch.long) return tokens diff --git a/src/nanotron/distributed.py b/src/nanotron/distributed.py index 0156b1bb..0bc54f3e 100644 --- a/src/nanotron/distributed.py +++ b/src/nanotron/distributed.py @@ -52,10 +52,6 @@ def all_gather_into_tensor( # pylint: disable=function-redefined if group is None: group = dist.torch_dist.distributed_c10d._get_default_group() - assert ( - group.size() > 1 - ), "You should probably not call `all_gather_into_tensor` with a single rank, as it copies data over" - if torch_version_above_1_13: return dist.all_gather_into_tensor( output_tensor=output_tensor, input_tensor=input_tensor, group=group, async_op=async_op diff --git a/src/nanotron/models/llama.py b/src/nanotron/models/llama.py index a440c8d0..82895025 100644 --- a/src/nanotron/models/llama.py +++ b/src/nanotron/models/llama.py @@ -755,14 +755,20 @@ def forward( self, input_ids: Union[torch.Tensor, TensorPointer], # [batch_size, seq_length] input_mask: Union[torch.Tensor, TensorPointer], # [batch_size, seq_length] + lang_code: Union[torch.Tensor, TensorPointer], # [batch_size, 1] ): - return self.forward_with_hidden_states(input_ids=input_ids, input_mask=input_mask)[0] + return self.forward_with_hidden_states(input_ids=input_ids, input_mask=input_mask, lang_code=lang_code)[0] def forward_with_hidden_states( self, input_ids: Union[torch.Tensor, TensorPointer], # [batch_size, seq_length] input_mask: Union[torch.Tensor, TensorPointer], # [batch_size, seq_length] + lang_code: Union[torch.Tensor, TensorPointer], # [batch_size, 1] ): + # NOTE(tj.solergibert) I bring `lang_code` till the forward of LlamaModel. 
Remember that + # to use it in the different pipeline blocks you need to also set the module_input_keys & module_output_keys + # of the necessary `PipelineBlock`'s defined in the LlamaModel init! + # all tensors are optional as most ranks don't need anything from the dataloader. output = self.token_position_embeddings(input_ids=input_ids, input_mask=input_mask) @@ -823,7 +829,9 @@ def get_flops_per_sec(self, iteration_time_in_sec, sequence_length, global_batch @torch.jit.script def masked_mean(loss, label_mask, dtype): # type: (Tensor, Tensor, torch.dtype) -> Tensor - return (loss * label_mask).sum(dtype=dtype) / label_mask.sum() + return (loss * label_mask).sum(dim=1, dtype=dtype) / label_mask.sum( + dim=1 + ) # NOTE(tj.solergibert) Added dim=1 to return a tensor with shape [Batch size, 1] instead of [1] class Loss(nn.Module): @@ -840,14 +848,18 @@ def forward( # Megatron by defaults cast everything in fp32. `--f16-lm-cross-entropy` is an option you can use to keep current precision. # https://github.com/NVIDIA/Megatron-LM/blob/f267e6186eae1d6e2055b412b00e2e545a8e896a/megatron/model/gpt_model.py#L38 - loss = sharded_cross_entropy( + sample_loss = sharded_cross_entropy( sharded_logits, label_ids.transpose(0, 1).contiguous(), group=self.tp_pg, dtype=torch.float ).transpose(0, 1) # TODO @thomasw21: It's unclear what kind of normalization we want to do. - loss = masked_mean(loss, label_mask, dtype=torch.float) - # I think indexing causes a sync we don't actually want - # loss = loss[label_mask].sum() - return {"loss": loss} + sample_loss = masked_mean(sample_loss, label_mask, dtype=torch.float) + # NOTE(tj.solergibert) masked_mean returns a single scalar with the batch loss. We've changed it to compute the SAMPLE loss. + # We will continue using "loss" as the batch loss but we add "sample_loss" for the multilingual effort. 
+ # WARN(tj.solergibert) Don't panic, the batch loss used to update the parameters is computed in `LlamaForTraining` + + # TODO @thomasw21: I think indexing causes a sync we don't actually want + # TODO @thomasw21: loss = loss[label_mask].sum() + return {"sample_loss": sample_loss} class LlamaForTraining(NanotronModel): @@ -869,7 +881,7 @@ def __init__( "label_ids", "label_mask", }, - module_output_keys={"loss"}, + module_output_keys={"sample_loss"}, ) self.parallel_context = parallel_context self.config = config @@ -879,19 +891,22 @@ def forward( self, input_ids: Union[torch.Tensor, TensorPointer], input_mask: Union[torch.Tensor, TensorPointer], + lang_code: Union[torch.Tensor, TensorPointer], label_ids: Union[torch.Tensor, TensorPointer], label_mask: Union[torch.Tensor, TensorPointer], ) -> Dict[str, Union[torch.Tensor, TensorPointer]]: sharded_logits = self.model( input_ids=input_ids, input_mask=input_mask, + lang_code=lang_code, ) - loss = self.loss( + outputs = self.loss( sharded_logits=sharded_logits, label_ids=label_ids, label_mask=label_mask, - )["loss"] - return {"loss": loss} + ) + outputs["loss"] = torch.mean(outputs["sample_loss"]) + return outputs @torch.no_grad() def init_model_randomly(self, config: Config): diff --git a/src/nanotron/parallel/pipeline_parallel/engine.py b/src/nanotron/parallel/pipeline_parallel/engine.py index ca9df312..9b548e35 100644 --- a/src/nanotron/parallel/pipeline_parallel/engine.py +++ b/src/nanotron/parallel/pipeline_parallel/engine.py @@ -2,6 +2,9 @@ from typing import Dict, Iterable, Optional, Union import torch +from torch import nn as torch_nn +from torch.nn.parallel import DistributedDataParallel + from nanotron import distributed as dist from nanotron import logging from nanotron.distributed import ProcessGroup @@ -9,11 +12,9 @@ from nanotron.optim.gradient_accumulator import GradientAccumulator from nanotron.parallel.data_parallel.utils import ddp_trigger_sync_in_bwd from nanotron.parallel.pipeline_parallel.context_manager import attach_pipeline_state_to_model -from nanotron.parallel.pipeline_parallel.state import PipelineTrainBatchState +from nanotron.parallel.pipeline_parallel.state import PipelineEvalBatchState, PipelineTrainBatchState from nanotron.parallel.pipeline_parallel.tensor_pointer import TensorPointer from nanotron.utils import ContextManagers -from torch import nn as torch_nn -from torch.nn.parallel import DistributedDataParallel logger = logging.get_logger(__name__) @@ -29,6 +30,7 @@ def forward( state: PipelineTrainBatchState, micro_batch: Dict[str, Union[torch.Tensor, TensorPointer]], model: torch_nn.Module, + is_validation: bool = False, ) -> Dict[str, Union[torch.Tensor, TensorPointer]]: # Increment the number of backwards state.nb_forwards += 1 @@ -52,7 +54,7 @@ def forward( output["loss"] = output["loss"] / self.nb_microbatches # Add output as activations that require backward pass - if not isinstance(output["loss"], TensorPointer): + if not isinstance(output["loss"], TensorPointer) and not is_validation: assert output["loss"].requires_grad state.register_activation_requiring_backward(output["loss"]) return output @@ -134,16 +136,19 @@ def validate_batch_iter( nb_microbatches: int, ) -> Iterable[Dict[str, Union[torch.Tensor, TensorPointer]]]: # Assign a new state for the current batch - state = PipelineTrainBatchState() # TODO: do i need state? 
+ state = PipelineEvalBatchState() self.nb_microbatches = nb_microbatches outputs = [] + lang_codes = [] with attach_pipeline_state_to_model(model=model, pipeline_state=state): # All forward for micro_batch in batch: context = self._get_fwd_context(model=model) - output = self.forward(context=context, state=state, micro_batch=micro_batch, model=model) + output = self.forward( + context=context, state=state, micro_batch=micro_batch, model=model, is_validation=True + ) # TODO @thomasw21: Somehow this needs to be done somewhere else to support interleaving. Somewhere right after a "stage" for _ in range(len(state.microbatches_activations_to_send)): send_activation = state.microbatches_activations_to_send.popleft() @@ -157,9 +162,13 @@ def validate_batch_iter( # Store the loss for each microbatch if not isinstance(output["loss"], TensorPointer): output = {k: v.detach() for k, v in output.items()} - outputs.append(output) - return outputs + outputs.extend( + list(output["sample_loss"]) + ) # NOTE(tj.solergibert) Yes, it might look useless to do list + extend but it's necessary to split the output["sample_loss"] tensor into multiple tensors + lang_codes.extend(micro_batch["lang_code"].flatten().tolist()) + + return outputs, lang_codes class AllForwardAllBackwardPipelineEngine(PipelineEngine): diff --git a/src/nanotron/parallel/pipeline_parallel/state.py b/src/nanotron/parallel/pipeline_parallel/state.py index e07cc89a..f22d6571 100644 --- a/src/nanotron/parallel/pipeline_parallel/state.py +++ b/src/nanotron/parallel/pipeline_parallel/state.py @@ -4,6 +4,7 @@ from typing import List import torch + from nanotron import distributed as dist from nanotron import logging from nanotron.logging import log_rank @@ -203,6 +204,9 @@ class PipelineEvalBatchState(PipelineBatchState): microbatches_activations_to_recv = collections.deque() activations_buffer = collections.deque() + # Reinitialise counter + nb_forwards = 0 + def register_activation_requiring_backward(self, activation: torch.Tensor): pass diff --git a/src/nanotron/serialize/metadata.py b/src/nanotron/serialize/metadata.py index 0d8708f9..4bd36c19 100644 --- a/src/nanotron/serialize/metadata.py +++ b/src/nanotron/serialize/metadata.py @@ -46,6 +46,8 @@ class TrainingMetadata: last_stage_idx: Optional[int] = None data_stages: Optional[List[DataStageMetadata]] = None + last_validation_stage_idx: Optional[int] = None + def __post_init__(self): # NOTE: this is a sanity check after loading a trained checkpoint total_consumed_samples_across_stages = sum(stage.consumed_train_samples for stage in self.data_stages) diff --git a/src/nanotron/trainer.py b/src/nanotron/trainer.py index 3f4c5189..a17f9849 100644 --- a/src/nanotron/trainer.py +++ b/src/nanotron/trainer.py @@ -80,6 +80,7 @@ from nanotron.sanity_checks import ( after_optim_step_sanity_checks, after_tbi_sanity_checks, + assert_tensor_synced_across_pg, before_optim_step_sanity_checks, before_tbi_sanity_checks, ) @@ -231,7 +232,11 @@ def __init__( for stage in self.config.data_stages ] self.metadata: TrainingMetadata = TrainingMetadata( - consumed_train_samples=0, last_train_step=0, last_stage_idx=0, data_stages=data_stages + consumed_train_samples=0, + last_train_step=0, + last_stage_idx=0, + data_stages=data_stages, + last_validation_stage_idx=0, ) # Setup tensorboard write and log writers on output rank @@ -253,6 +258,8 @@ def __init__( self.limit_val_batches = self.config.tokens.limit_val_batches # NOTE: the dataloader currently in use for the current training stage self.current_dataloader: 
Optional[DataLoader] = None + # NOTE: the dataloader currently in use for the current validation stage + self.current_validation_dataloader: Optional[DataLoader] = None self.post_init() @@ -300,6 +307,106 @@ def _print_training_plan(self): ) log_rank(full_log_message, logger=logger, level=logging.INFO, rank=0) + def _prepare_dataloader_for_validation_stage(self, dataloaders: Union[List[DataLoader], DataLoader]): + # NOTE(tj.solergibert) Similar to _update_dataloader_based_on_training_stages BUT: + # 1. We call this function EVERY TIME we run the validation loop + # 2. Every time it returns a NEW validation iterator DataLoader. If you don't do this you'll consume the whole validation dataset + # in the first iteration and subsequent validations will fail + # `dataloaders` are either torch DataLoaders (the very first stage) OR functions that we call later that provide torch DataLoaders (subsequent stages) + # From this torch DataLoaders objects we then call `sanity_check_dataloader` that will return a iterator. + # In short, `sanity_check_dataloader` just places the input tensors in the GPU when necessary (TensorPointers stay in the CPU) + # + # TBH, the for loop below it's just for deleting the DataLoaders of previous stages, which is not so problematic. The important part is returning the + # DataLoader iterator every time we call this function from the current training stage, which is tracked during training + # + # Also, keep in mind that if val_check_interval = 5 & data.start_training_step = 10 we will already perform the evaluation with the SECOND data stage + # after just training for the current iteration, so it might not be a good idea to set evals during the stage in which we change of data stage + # + # NOTE(tj.solergibert) Further investigation should be done, but there is a extrange behaiviour when deleting the DataLoaders////lambda functs. 
As they + # are converted into Iterators with `sanity_check_dataloader` we can't access anymore the DataLoader object to del the dataset (After first stage, + # in this function we locally create the DataLoder from the lambda func --> Return Iterator) + # + # Also when the gc deletes the first stage dataloader, all the `DatatroveFileDataset._f` are already None AND the `del` thing are deleting a copy of the + # object, not the object itself + # + # FINAL NOTE(tj.solergibert) I will open a Issue in nanotron to check with them if they are aware of this useless deletitions + # + # TODO(tj.solergibert) Check the tuple case below + from collections.abc import Generator + + if not hasattr(self.config, "data_stages") or self.config.data_stages is None: + + if isinstance(dataloaders, tuple): # TODO(tj.solergibert) Check this tuple case + dataloader = dataloaders[0] + else: + dataloader = dataloaders + + self.current_validation_dataloader_lenght = len(dataloader) + self.current_validation_dataloader = sanity_check_dataloader( + dataloader=dataloader, parallel_context=self.parallel_context, config=self.config + ) + + return + elif isinstance(dataloaders, Generator): + # TODO(xrsrke): this is a hacky way to handle DoReMi's dataloader + # remove this in the next PR + self.current_validation_dataloader = dataloaders + return + + assert len(dataloaders) > 0, "No dataloaders provided" + assert len(dataloaders) == len( + self.config.data_stages + ), "Number of dataloaders should match the number of dataset stages" + + def clear_dataloader_from_memory(dataloader: DataLoader, stage_name: str, prev_stage_name: str): + import gc + + log_rank( + f"[Validation Stage: {stage_name}] Clearing the previous validation stage's ({prev_stage_name}) dataloader and dataset from memory", + logger=logger, + level=logging.INFO, + ) + + # NOTE: Clear dataloader from memory + del dataloader.dataset + del dataloader.sampler + del dataloader.batch_sampler + + gc.collect() + + for stage_idx, stage in enumerate(self.config.data_stages): + if stage_idx < self.metadata.last_stage_idx: + continue + # NOTE(tj.solergibert) From this point stage_idx = self.metadata.last_stage_idx. We update self.metadata.last_stage_idx (which keeps track of the training stage) + # in each and every training step. 
+ + if ( + stage_idx is not self.metadata.last_validation_stage_idx + ): # When stage_idx (= self.metadata.last_stage_idx, the training stage index) is different than the last validation stage index + self.metadata.last_validation_stage_idx = stage_idx # Update validation stage index + # Delete previous stage DataLoader + prev_stage_name = self.config.data_stages[stage_idx - 1].name + prev_dataloader = dataloaders[prev_stage_name] + + if isinstance(prev_dataloader, DataLoader): + # NOTE: we don't need to clear dummy data generator from memory + clear_dataloader_from_memory( + prev_dataloader, stage_name=stage.name, prev_stage_name=prev_stage_name + ) + + self.metadata.last_validation_stage_idx = stage_idx # Update validation stage index + + # NOTE(tj.solergibert) Create AGAIN the DataLoader + dataloader = dataloaders[stage.name] + # NOTE: if a dataloader is lazy initialized, we need to call it to initialize it + dataloader = dataloader() if callable(dataloader) else dataloader + break + + self.current_validation_dataloader_lenght = len(dataloader) + self.current_validation_dataloader = sanity_check_dataloader( + dataloader=dataloader, parallel_context=self.parallel_context, config=self.config + ) # NOTE(tj.solergibert) Create a Iterator from the DataLoader + def _update_dataloader_based_on_training_stages(self, dataloaders: Union[List[DataLoader], DataLoader]): from collections.abc import Generator @@ -324,11 +431,11 @@ def _update_dataloader_based_on_training_stages(self, dataloaders: Union[List[Da self.config.data_stages ), "Number of dataloaders should match the number of dataset stages" - def clear_dataloader_from_memory(dataloader: DataLoader, stage_name: str): + def clear_dataloader_from_memory(dataloader: DataLoader, stage_name: str, prev_stage_name: str): import gc log_rank( - f"[Training Stage: {stage_name}] Clearing the previous training stage's dataloader and datasets from memory", + f"[Training Stage: {stage_name}] Clearing the previous training stage's ({prev_stage_name}) dataloader and datasets from memory", logger=logger, level=logging.INFO, ) @@ -365,7 +472,9 @@ def find_stage_idx_to_resume(): if isinstance(prev_dataloader, DataLoader): # NOTE: we don't need to clear dummy data generator from memory - clear_dataloader_from_memory(prev_dataloader, stage_name=stage.name) + clear_dataloader_from_memory( + prev_dataloader, stage_name=stage.name, prev_stage_name=prev_stage_name + ) self.metadata.last_stage_idx = stage_idx @@ -431,6 +540,19 @@ def train( # Training step outputs, loss_avg = self.training_step(dataloader=self.current_dataloader) + self.training_step_time = time.time() + + # Validation stage + if self.iteration_step % self.config.tokens.val_check_interval == 0: + self._prepare_dataloader_for_validation_stage(valid_dataloader_or_dls) + val_global_loss, val_lang_losses = self.validation_step( + dataloader=self.current_validation_dataloader + ) + self.validation_step_time = time.time() + else: + # NOTE(tj.solergibert) As we are reporting the training & validation metrics together, we + # must comply with val_check_interval % iteration_step_info_interval = 0 + val_global_loss, val_lang_losses = None, None # Training Logs # TODO(xrsrke): refactor using callbacks would be better @@ -441,7 +563,7 @@ def train( ].consumed_train_samples += self.global_batch_size if (self.iteration_step - 1) % self.config.logging.iteration_step_info_interval == 0: - self.train_step_logs(outputs=outputs, loss_avg=loss_avg) + self.train_step_logs(loss_avg=loss_avg, global_loss=val_global_loss, 
lang_losses=val_lang_losses) # Checkpoint if self.iteration_step % self.config.checkpoints.checkpoint_interval == 0: @@ -546,22 +668,71 @@ def training_step( return outputs, loss_avg def validation_step(self, dataloader: Iterator[Dict[str, Union[torch.Tensor, TensorPointer]]]) -> Iterable[Dict]: - outputs = self.pipeline_engine.validate_batch_iter( + outputs, lang_codes = self.pipeline_engine.validate_batch_iter( model=self.model, - batch=(next(dataloader) for _ in range(self.limit_val_batches)), - nb_microbatches=self.limit_val_batches, + batch=(next(dataloader) for _ in range(self.current_validation_dataloader_lenght)), + nb_microbatches=self.current_validation_dataloader_lenght, ) - return outputs + + lang_losses = { + lang: [] for lang in self.config.data_stages[self.metadata.last_stage_idx].data.dataset.languages + } + lang_losses_list = list(lang_losses.keys()) + + # Compute losses + if isinstance(outputs[0], torch.Tensor): + # Multilingual losses + for loss, lang_code in zip(outputs, lang_codes): + lang_losses[lang_losses_list[lang_code]].append(loss) + # Global loss + global_loss_avg = torch.mean(torch.stack(outputs)) + # Sync multilingual losses across DP + for lang in lang_losses.keys(): + if not lang_losses[ + lang + ]: # If the list is empty --> Set local language loss to -1 to exclude it from the global computation + lang_losses[lang] = torch.tensor(-1, dtype=torch.float32) + else: # If we have at least 1 loss from a given language --> compute local language loss mean + lang_losses[lang] = torch.mean(torch.stack(lang_losses[lang])) + + # NOTE(tj.solergibert) We create a (DP SIZE, LANGS) tensor to aggregate ALL local losses across DP groups. + # Then we compute the mean of each lang in each and every rank and finally copy back the result to the + # `lang_losses` dict for logging + lang_losses_tensor_out = torch.zeros( + (self.parallel_context.dp_pg.size(), len(lang_losses.keys())), dtype=torch.float, device="cuda" + ) # (DP SIZE, LANGS) + lang_losses_tensor_local = torch.stack(list(lang_losses.values())).unsqueeze(0) # (1, LANGS) + dist.all_gather_into_tensor(lang_losses_tensor_out, lang_losses_tensor_local, self.parallel_context.dp_pg) + mask = lang_losses_tensor_out != -1 + lang_losses_tensor_local = (lang_losses_tensor_out * mask).sum(dim=0) / mask.sum(dim=0) # (1, LANGS) + for idx, lang in enumerate(lang_losses.keys()): + lang_losses[lang] = lang_losses_tensor_local[idx] + + # Sync global losses across DP + dist.all_reduce(global_loss_avg, group=self.parallel_context.dp_pg, op=dist.ReduceOp.AVG) + + # TODO(tj.solergibert) Delete this testing assertions + for lang in lang_losses.keys(): + assert_tensor_synced_across_pg(tensor=lang_losses[lang], pg=self.parallel_context.dp_pg) + assert_tensor_synced_across_pg(tensor=global_loss_avg, pg=self.parallel_context.dp_pg) + + else: + global_loss_avg = None + lang_losses = None + + return global_loss_avg, lang_losses def train_step_logs( self, - outputs: Iterable[Dict[str, Union[torch.Tensor, TensorPointer]]], loss_avg: Optional[torch.Tensor], + global_loss: torch.Tensor, + lang_losses: torch.Tensor, ) -> None: # TODO @nouamanetazi: Megatron-LM seems to be using a barrier to report their interval time. Check if this is necessary. 
https://github.com/NouamaneTazi/Megatron-LM/blob/e241a96c3085b18e36c6cee1d68a8155de77b5a6/megatron/training.py#L607 dist.barrier() torch.cuda.synchronize() - elapsed_time_per_iteration_ms = (time.time() - self.iteration_start_time) * 1000 + # Training metrics + elapsed_time_per_iteration_ms = (self.training_step_time - self.iteration_start_time) * 1000 tokens_per_sec = ( self.global_batch_size * self.sequence_length / (elapsed_time_per_iteration_ms / 1000) ) # tokens_per_sec is calculated using sequence_length @@ -571,13 +742,27 @@ def train_step_logs( global_batch_size=self.global_batch_size, ) + # Validation metrics + if global_loss is not None: + validation_total_samples = self.current_validation_dataloader_lenght * self.micro_batch_size + validation_elapsed_time_per_iteration_ms = (self.validation_step_time - self.training_step_time) * 1000 + validation_tokens_per_sec = ( + validation_total_samples * self.sequence_length / (validation_elapsed_time_per_iteration_ms / 1000) + ) + + validation_model_tflops, validation_hardware_tflops = self.unwrapped_model.get_flops_per_sec( + iteration_time_in_sec=validation_elapsed_time_per_iteration_ms / 1000, + sequence_length=self.sequence_length, + global_batch_size=validation_total_samples, + ) + if dist.get_rank(self.parallel_context.world_pg) in self.logger_ranks: assert self.loggerwriter is not None, "loggerwriter should be defined on logger ranks" + # Training metrics lr = self.lr_scheduler.get_last_lr()[0] log_entries = [ - # LogItem("consumed_samples", self.consumed_train_samples, "human_format"), # , "12d"), LogItem( "consumed_tokens", self.metadata.consumed_train_samples * self.config.tokens.sequence_length, @@ -598,6 +783,46 @@ def train_step_logs( if self.config.optimizer.clip_grad is not None: log_entries.append(LogItem("grad_norm", self.grad_norm_unclipped.item(), "human_format")) # , ".3f")) + # Validation metrics + if global_loss is not None: + log_entries.extend( + [ + LogItem( + "validation_consumed_tokens", + validation_total_samples * self.sequence_length, + "human_format", + ), # , "12d"), + LogItem( + "validation_elapsed_time_per_iteration_ms", + validation_elapsed_time_per_iteration_ms, + "human_format", + ), # , ".1f"), + LogItem("validation_tokens_per_sec", validation_tokens_per_sec, "human_format"), # , "1.6E"), + LogItem( + "validation_tokens_per_sec_per_gpu", + validation_tokens_per_sec / self.parallel_context.world_pg.size(), + "human_format", + ), # , "1.6E"), + LogItem("validation_loss", global_loss.item(), "human_format"), # , "1.6E"), + LogItem( + "validation_model_tflops_per_gpu", validation_model_tflops / 3, "human_format" + ), # , ".2f"), # NOTE(tj.solergibert) Check llama.py --> def get_flops() --> model_flops for explanation of the / 3 factor + LogItem( + "validation_hardware_tflops_per_gpu", validation_hardware_tflops / 3, "human_format" + ), # , ".2f"), # NOTE(tj.solergibert) Check llama.py --> def get_flops() --> model_flops for explanation of the / 3 factor + ] + ) + + # NOTE Currently you have to log each lang metric one by one and then merge them manually in the same plot through the wandb UI. 
+ # Example: https://community.wandb.ai/t/log-multiple-variables-at-the-same-plot/2474 + # GitHub complains: https://github.com/wandb/wandb/issues/3035 + log_entries.extend( + [ + LogItem(f"{lang}_validation_loss", loss.item(), "human_format") + for lang, loss in lang_losses.items() + ] + ) + # Log not too often the memory if self.iteration_step < 5 or (self.iteration_step - 1) % self.config.checkpoints.checkpoint_interval == 0: total, used, free = shutil.disk_usage("/") From d08c949aac15479592c1c657e717b1027e060e02 Mon Sep 17 00:00:00 2001 From: Alex Hagele Date: Thu, 15 Aug 2024 14:40:00 +0000 Subject: [PATCH 41/51] correct logging of all losses --- src/nanotron/trainer.py | 27 +++++++++++++++++++-------- 1 file changed, 19 insertions(+), 8 deletions(-) diff --git a/src/nanotron/trainer.py b/src/nanotron/trainer.py index af16e39c..0b3306c5 100644 --- a/src/nanotron/trainer.py +++ b/src/nanotron/trainer.py @@ -455,7 +455,7 @@ def train( def training_step( self, dataloader: Iterator[Dict[str, Union[torch.Tensor, TensorPointer]]] - ) -> Tuple[Iterable[Dict], Optional[torch.Tensor]]: + ) -> Tuple[Iterable[Dict], Optional[Dict[str, torch.Tensor]]]: before_tbi_sanity_checks(self.config, self.parallel_context, self.unwrapped_model, self.grad_accumulator) if self.iteration_step < 5: @@ -522,11 +522,19 @@ def training_step( # Compute DP average loss and overlap with optimizer step if isinstance(outputs[0]["loss"], torch.Tensor): # This is an average on only one data rank. - loss_avg = torch.stack( - [output["loss"] for output in outputs] - ).sum() # already divided by n_micro_batches_per_batch - # sync loss across DP - handle = dist.all_reduce(loss_avg, group=self.parallel_context.dp_pg, async_op=True, op=dist.ReduceOp.AVG) + loss_avg = {} + for k in outputs[0].keys(): + if k == "loss": + loss_avg["lm_loss"] = torch.stack([output[k] for output in outputs]).sum() + k = "lm_loss" + else: + loss_avg[k] = torch.stack([output[k] for output in outputs]).mean() + handle = dist.all_reduce( + loss_avg[k], + group=self.parallel_context.dp_pg, + async_op=True, + op=dist.ReduceOp.AVG, + ) else: loss_avg = None handle = None @@ -558,7 +566,7 @@ def validation_step(self, dataloader: Iterator[Dict[str, Union[torch.Tensor, Ten def train_step_logs( self, outputs: Iterable[Dict[str, Union[torch.Tensor, TensorPointer]]], - loss_avg: Optional[torch.Tensor], + loss_avg: Optional[Dict[str, torch.Tensor]], ) -> None: # TODO @nouamanetazi: Megatron-LM seems to be using a barrier to report their interval time. Check if this is necessary. 
https://github.com/NouamaneTazi/Megatron-LM/blob/e241a96c3085b18e36c6cee1d68a8155de77b5a6/megatron/training.py#L607 dist.barrier() @@ -591,12 +599,15 @@ def train_step_logs( "tokens_per_sec_per_gpu", tokens_per_sec / self.parallel_context.world_pg.size(), "human_format" ), # , "1.6E"), LogItem("global_batch_size", self.global_batch_size, "human_format"), # , "5d"), - LogItem("lm_loss", loss_avg.item(), "human_format"), # , "1.6E"), LogItem("lr", lr, "human_format"), # , ".3E"), LogItem("model_tflops_per_gpu", model_tflops, "human_format"), # , ".2f"), LogItem("hardware_tflops_per_gpu", hardware_tflops, "human_format"), # , ".2f"), ] + if loss_avg is not None: + for k, v in loss_avg.items(): + log_entries.append(LogItem(k, v.item(), "human_format")) + if self.config.optimizer.clip_grad is not None: log_entries.append(LogItem("grad_norm", self.grad_norm_unclipped.item(), "human_format")) # , ".3f")) From d14315f8a6d4663fae1cc916db3ece5cda9ecd6d Mon Sep 17 00:00:00 2001 From: Alex Hagele Date: Fri, 16 Aug 2024 10:04:53 +0000 Subject: [PATCH 42/51] minor bug fix when using bias --- src/nanotron/models/moe.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/nanotron/models/moe.py b/src/nanotron/models/moe.py index 6ef9f0f9..f7bb07bd 100644 --- a/src/nanotron/models/moe.py +++ b/src/nanotron/models/moe.py @@ -431,7 +431,7 @@ def forward(self, x, router_logits, expert_weights, top_experts): z_loss = torch.zeros(1, device=x.device) if self.use_bias: - return x + self.bias + x = x + self.bias return x, lbl_loss, z_loss def permute_and_compute( From 5dc67fe1acb864f3c771776f216b5c69be83ad9d Mon Sep 17 00:00:00 2001 From: Alex Hagele Date: Fri, 16 Aug 2024 10:05:13 +0000 Subject: [PATCH 43/51] bias init in case of use for moe --- src/nanotron/models/starcoder2.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/src/nanotron/models/starcoder2.py b/src/nanotron/models/starcoder2.py index b05a67bb..1f2eab7d 100644 --- a/src/nanotron/models/starcoder2.py +++ b/src/nanotron/models/starcoder2.py @@ -56,6 +56,7 @@ from nanotron.parallel.tied_parameters import tie_parameters from nanotron.random import RandomStates, branch_random_state from nanotron.utils import checkpoint_method +from nanotron.models.moe import ParallelDroplessMLP, SparseMLP def pad_to_right(tensor, mask, new_tensor=None): @@ -1524,6 +1525,9 @@ def init_model_randomly(self, config): module.bias.zero_() else: raise ValueError(f"Who the fuck is {param_name}?") + elif isinstance(module, ParallelDroplessMLP): + if hasattr(module, 'bias'): + module.bias.zero_() elif isinstance(module, TensorParallelRowLinear): if "weight" == param_name: nn.init.normal_(module.weight, mean=0.0, std=sigma / math.sqrt(2 * num_layers)) From fad34976cd8683560603f8dff9ec27129318a549 Mon Sep 17 00:00:00 2001 From: Alex Hagele Date: Fri, 16 Aug 2024 10:06:41 +0000 Subject: [PATCH 44/51] sparse upcycling converter --- examples/xglm/README.md | 31 +++-- examples/xglm/convert_dense2moe.py | 179 +++++++++++++++++++++++++++++ 2 files changed, 199 insertions(+), 11 deletions(-) create mode 100644 examples/xglm/convert_dense2moe.py diff --git a/examples/xglm/README.md b/examples/xglm/README.md index 22765f52..8f62fc57 100644 --- a/examples/xglm/README.md +++ b/examples/xglm/README.md @@ -1,18 +1,27 @@ # How to use XGLM? 1. 
First, make sure to convert the weights from huggingface, for instance: - ``` - torchrun --nproc-per-node=1 examples/xglm/convert_hf2nt.py --checkpoint-path=facebook/xglm-564M --save-path=$SCRATCH/checkpoints/xglm-564M - ``` +```bash +torchrun --nproc-per-node=1 examples/xglm/convert_hf2nt.py --checkpoint-path=facebook/xglm-564M --save-path=$SCRATCH/checkpoints/xglm-564M +``` -1. Now you are ready to use XGLM. +2. Now you are ready to use XGLM. Make sure you use a .yaml configuration with proper GPT3 config and then run for instance: - ``` - torchrun --nproc-per-node=4 run_train.py --config-file=examples/xglm/example_config.yaml - ``` +```bash +torchrun --nproc-per-node=4 run_train.py --config-file=examples/xglm/example_config.yaml +``` If you use this configuration file make sure to modify at least the loading path in `model.init_method.path`. -1. If you want to convert your finetuned checkpoint back to huggingface use: - ``` - torchrun --nproc-per-node=1 examples/xglm/convert_nt2hf.py --checkpoint-path=checpoints/xglm --save-path=$SCRATCH/checkpoints/huggingface/xglm-564M --tokenizer-name=facebook/xglm-564M - ``` +3. If you want to convert your finetuned checkpoint back to huggingface use: +```bash +torchrun --nproc-per-node=1 examples/xglm/convert_nt2hf.py --checkpoint-path=checkpoints/xglm --save-path=$SCRATCH/checkpoints/huggingface/xglm-564M --tokenizer-name=facebook/xglm-564M +``` + +## Sparse Upcycling + +To create a sparse model from a dense model, you can use the `convert_dense2moe.py` script that goes from a GPT3 Nanotron model to a GPT3 MoE Nanotron model. For instance: +```bash +cd examples/xglm +torchrun --nproc-per-node=1 convert_dense2moe.py --checkpoint-path=checkpoints/xglm-564M --save-path=$SCRATCH/checkpoints/xglm-8x564M --num-experts=8 +``` +Note that this upcycling _drops_ the bias parameters of the MLP because the MegaBlocks implementation does not support bias parameters. While this is a limitation of the current implementation, the performance is quickly recovered after a few training steps. 
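
The quick recovery is expected: immediately after upcycling, every expert is an identical copy of the dense MLP, so whatever the gate decides, the mixture reproduces the dense block's output (up to the dropped biases and the gate normalization). The sketch below is a plain-PyTorch illustration of that property; it is not the nanotron/MegaBlocks code path that follows, and it assumes gelu activations, no MLP biases, and top-k gate weights renormalized to sum to one.

```python
import torch
import torch.nn.functional as F
from torch import nn

torch.manual_seed(0)
hidden, inter, num_experts, top_k = 16, 64, 4, 2

# Dense MLP (no biases, mirroring what the converter keeps).
fc1 = nn.Linear(hidden, inter, bias=False)
fc2 = nn.Linear(inter, hidden, bias=False)

def dense(x):
    return fc2(F.gelu(fc1(x)))

# Upcycling: every expert starts as an exact copy of the dense weights.
experts = [(fc1.weight.detach().clone(), fc2.weight.detach().clone())
           for _ in range(num_experts)]
router = nn.Linear(hidden, num_experts, bias=False)  # gate is freshly initialized

def moe_forward(x):
    # Plain top-k token-routing loop (illustration only, not the MegaBlocks kernel).
    probs = F.softmax(router(x), dim=-1)
    weights, idx = probs.topk(top_k, dim=-1)
    weights = weights / weights.sum(-1, keepdim=True)  # assumed renormalization
    out = torch.zeros_like(x)
    for e, (w1, w2) in enumerate(experts):
        tok, slot = (idx == e).nonzero(as_tuple=True)  # tokens routed to expert e
        if tok.numel():
            out[tok] += weights[tok, slot, None] * (F.gelu(x[tok] @ w1.T) @ w2.T)
    return out

with torch.no_grad():
    x = torch.randn(32, hidden)
    # Identical experts + gate weights summing to one -> MoE output == dense output.
    torch.testing.assert_close(moe_forward(x), dense(x))
```

If the gate scores were not renormalized, the mixture would only produce gate-scaled versions of the dense outputs; together with the dropped biases, this is consistent with the short recovery period mentioned above.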
diff --git a/examples/xglm/convert_dense2moe.py b/examples/xglm/convert_dense2moe.py new file mode 100644 index 00000000..fa4d9af7 --- /dev/null +++ b/examples/xglm/convert_dense2moe.py @@ -0,0 +1,179 @@ +""" +Converts a nanotron model to HF format +Command: + torchrun --nproc-per-node=1 convert_dense2moe.py --checkpoint-path=nanotron_weights --save-path=nanotron_moe_weights +""" + +import dataclasses +import json +import warnings +from argparse import ArgumentParser +from pathlib import Path +from typing import Optional + +from torch import nn +import torch +import nanotron +from nanotron.config.models_config import GPT3Config, GPT3MoEConfig +from nanotron.models.gpt3 import GPT3ForTraining, GPTBlock +from nanotron.models.gpt3_moe import GPT3MoEForTraining, GPT3MoEBlock +from nanotron.trainer import mark_tied_parameters + +from convert_utils import convert_generic, create_nt_model + + +def convert_config(config: GPT3Config, num_experts=8) -> GPT3MoEConfig: + return GPT3MoEConfig( + **config.__dict__, + is_moe=True, + moe_num_experts=num_experts, + num_experts_per_tok=min(2, num_experts), # arbitrarily chosen + moe_loss_weight=0.01, # arbitrarily chosen + moe_z_loss_weight=0.001, # arbitrarily chosen + moe_glu=False, + ) + + +def convert_dense_to_moe(ff_moe: nn.Module, dense_ff: nn.Module, num_experts: int): + with torch.no_grad(): + # only copy the weight matrix and repeat it n_expert times + weight_1 = dense_ff.c_fc.weight.clone() + if num_experts == 1: + ff_moe.experts.mlp.w1.module.weight.data = weight_1.contiguous() + else: + # [intermediate_size, hidden_size] -> [hidden_size, intermediate_size * n_experts] + weight_1 = weight_1.T + ff_moe.experts.mlp.w1.module.weight.data = weight_1.repeat(1, num_experts) + + weight_2 = dense_ff.c_proj.weight.clone() + if num_experts == 1: # just a specific case for 1 expert + ff_moe.experts.mlp.w2.module.weight.data = weight_2.contiguous() + else: + # [hidden_size, intermediate_size] -> [intermediate_size * n_experts, hidden_size] + weight_2 = weight_2.T + ff_moe.experts.mlp.w2.module.weight.data = weight_2.repeat(num_experts, 1) + + # # -- could add bias only for 2nd layer, because that works with the MegaBlocks MoE implementation + # # -- but won't make a big difference? 
+ # ff_moe.experts.bias.copy_(dense_ff.c_proj.bias) + + # init gating randomly + nn.init.normal_(ff_moe.gate.layer.weight, mean=0.0, std=0.02) + + +def convert_decoder(block_moe: GPT3MoEBlock, block_nt: GPTBlock, num_experts: int): + convert_generic(block_moe.ln_1, block_nt.ln_1) + convert_generic(block_moe.attn, block_nt.attn) + convert_generic(block_moe.ln_2, block_nt.ln_2) + convert_dense_to_moe(block_moe.ff, block_nt.ff, num_experts) + + +def convert( + model_moe: GPT3MoEForTraining, model_dense: GPT3ForTraining, num_experts: int +): + convert_generic( + model_moe.model.token_embeddings.pp_block.token_embedding, + model_dense.model.token_embeddings.pp_block.token_embedding, + ) + for layer_moe, layer_nt in zip(model_moe.model.decoder, model_dense.model.decoder): + convert_decoder(layer_moe.pp_block, layer_nt.pp_block, num_experts) + convert_generic( + model_moe.model.final_layer_norm.pp_block, + model_dense.model.final_layer_norm.pp_block, + ) + convert_generic( + model_moe.model.lm_head.pp_block, model_dense.model.lm_head.pp_block + ) + + +def create_nt_moe_model( + model_config: Optional[GPT3Config] = None, + device: torch.device = torch.device("cuda"), + dtype: torch.dtype = torch.bfloat16, + checkpoint_path: Optional[Path] = None, +): + + if model_config is None: + assert checkpoint_path is not None + with open(checkpoint_path / "model_config.json") as f: + model_config = GPT3MoEConfig(**json.load(f)) + + parallel_config = nanotron.config.ParallelismArgs(dp=1, pp=1, tp=1) + parallel_context = nanotron.parallel.ParallelContext( + data_parallel_size=parallel_config.dp, + pipeline_parallel_size=parallel_config.pp, + tensor_parallel_size=parallel_config.tp, + ) + model_nt = nanotron.models.build_model( + model_builder=lambda: GPT3MoEForTraining( + config=model_config, + parallel_context=parallel_context, + parallel_config=parallel_config, + random_states=None, + ), + parallel_context=parallel_context, + dtype=dtype, + device=device, + ) + mark_tied_parameters(model=model_nt, parallel_context=parallel_context) + + if checkpoint_path is not None: + nanotron.serialize.load_weights( + model=model_nt, + parallel_context=parallel_context, + root_folder=checkpoint_path, + ) + + return model_nt + + +def main( + checkpoint_path: Path, + save_path: Path, + num_experts: int, +): + # Load nanotron model. + model_dense = create_nt_model(checkpoint_path=checkpoint_path) + + # Init moe model. 
+ model_config_moe = convert_config(model_dense.config, num_experts) + model_moe = create_nt_moe_model(model_config=model_config_moe) + + convert(model_moe, model_dense, num_experts) + nanotron.serialize.save_weights( + model=model_moe, + parallel_context=model_moe.parallel_context, + root_folder=save_path, + ) + with open(save_path / "model_config.json", "w+") as f: + json.dump(dataclasses.asdict(model_config_moe), f) + print(f"Model saved to {save_path}") + + +if __name__ == "__main__": + # fix all random seeds + torch.manual_seed(0) + torch.cuda.manual_seed(0) + torch.cuda.manual_seed_all(0) + torch.backends.cudnn.deterministic = True + parser = ArgumentParser(description="Convert dense weights to moe format") + parser.add_argument( + "--checkpoint-path", + type=Path, + default="checkpoints/xglm-7.5B", + help="Path to the nanotron dense checkpoint", + ) + parser.add_argument( + "--save-path", + type=Path, + default="checkpoints/xglm-moe-7.5B", + help="Path to save the nanotron moe model", + ) + parser.add_argument( + "--num-experts", + type=int, + default=8, + help="Number of experts in the MoE model (duplicates of MLP layer)", + ) + args = parser.parse_args() + main(args.checkpoint_path, args.save_path, args.num_experts) From f31a1a32df5d70b9bc12a242b237746e2ba24964 Mon Sep 17 00:00:00 2001 From: Alex Hagele Date: Fri, 16 Aug 2024 14:37:46 +0000 Subject: [PATCH 45/51] add example config --- examples/xglm/example_config_moe.yaml | 113 ++++++++++++++++++++++++++ 1 file changed, 113 insertions(+) create mode 100644 examples/xglm/example_config_moe.yaml diff --git a/examples/xglm/example_config_moe.yaml b/examples/xglm/example_config_moe.yaml new file mode 100644 index 00000000..aa0d739c --- /dev/null +++ b/examples/xglm/example_config_moe.yaml @@ -0,0 +1,113 @@ +checkpoints: + checkpoint_interval: 1000 + checkpoints_path: checkpoints-xglm-moe-8x564M + checkpoints_path_is_shared_file_system: false + resume_checkpoint_path: checkpoints-xglm-moe-8x564M + save_initial_state: false +data_stages: + - data: + dataset: + training_folder: + - datasets/c4-es/train + - datasets/c4-en/train + - datasets/c4-fr/train + validation_folder: + - datasets/c4-es/validation + - datasets/c4-en/validation + - datasets/c4-fr/validation + languages: + - es + - en + - fr + num_loading_workers: 1 + seed: 42 + name: General purpose training (Blended dataset) + start_training_step: 1 +general: + benchmark_csv_path: null + consumed_train_samples: null + ignore_sanity_checks: true + project: multilingual-moe-init + run: xglm-topk2-8x564M + seed: 42 + step: null +lighteval: null +logging: + iteration_step_info_interval: 1 + log_level: info + log_level_replica: info +model: + ddp_bucket_cap_mb: 25 + dtype: bfloat16 + init_method: + # for random init: + # std: 0.02 + # or upcycling after using the converter scripts: + path: /mloscratch/homes/haegele/swissai/nanotron-multilingual/xglm/xglm-moe + make_vocab_size_divisible_by: 1 + model_config: + activation_function: gelu + attn_pdrop: 0.0 + embd_pdrop: 0.0 + scale_embedding: true + eos_token_id: 2 + hidden_size: 1024 + intermediate_size: 4096 + layer_norm_epsilon: 0.00001 + max_position_embeddings: 2048 + num_attention_heads: 16 + num_hidden_layers: 24 + resid_pdrop: 0.0 + scale_attention_softmax_in_fp32: true + scale_attn_weights: true + vocab_size: 256008 + sinusoidal_position_embedding: true + position_embedding_offset: 2 + use_spda: false + act_pdrop: 0.0 + is_moe: true + moe_num_experts: 8 + num_experts_per_tok: 2 + moe_loss_weight: 0.01 + moe_z_loss_weight: 0.001 + 
moe_glu: false +optimizer: + accumulate_grad_in_fp32: true + clip_grad: 1.0 + learning_rate_scheduler: + learning_rate: 0.001 + lr_decay_starting_step: 2000 + lr_decay_steps: 500 + lr_decay_style: 1-sqrt + lr_warmup_steps: 100 + lr_warmup_style: linear + min_decay_lr: 0 + optimizer_factory: + adam_beta1: 0.9 + adam_beta2: 0.95 + adam_eps: 1.0e-08 + name: adamW + torch_adam_is_fused: true + weight_decay: 0.1 + zero_stage: 0 +parallelism: + dp: 3 + expert_parallel_size: 1 + pp: 1 + pp_engine: 1f1b + tp: 1 + tp_linear_async_communication: false + tp_mode: REDUCE_SCATTER +profiler: null +tokenizer: + tokenizer_max_length: null + tokenizer_name_or_path: facebook/xglm-564M + tokenizer_revision: null +tokens: + batch_accumulation_per_replica: 5 + limit_test_batches: 0 + limit_val_batches: 10 + micro_batch_size: 4 # fits on one 80GB A100 for 8x564M + sequence_length: 2048 + train_steps: 2500 + val_check_interval: -1 From a45dc35681b86f1dad03d49eaae618173c890e57 Mon Sep 17 00:00:00 2001 From: Alexander Hagele Date: Wed, 4 Sep 2024 09:59:33 +0200 Subject: [PATCH 46/51] small fixes --- src/nanotron/trainer.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/nanotron/trainer.py b/src/nanotron/trainer.py index 91b9a29b..53f3708f 100644 --- a/src/nanotron/trainer.py +++ b/src/nanotron/trainer.py @@ -692,7 +692,7 @@ def validation_step(self, dataloader: Iterator[Dict[str, Union[torch.Tensor, Ten lang_losses_list = list(lang_losses.keys()) # Compute losses - if isinstance(outputs[0], torch.Tensor): + if len(outputs) > 0 and isinstance(outputs[0], torch.Tensor): # Multilingual losses for loss, lang_code in zip(outputs, lang_codes): lang_losses[lang_losses_list[lang_code]].append(loss) @@ -703,7 +703,7 @@ def validation_step(self, dataloader: Iterator[Dict[str, Union[torch.Tensor, Ten if not lang_losses[ lang ]: # If the list is empty --> Set local language loss to -1 to exclude it from the global computation - lang_losses[lang] = torch.tensor(-1, dtype=torch.float32) + lang_losses[lang] = torch.tensor(-1, dtype=torch.float32, device="cuda") else: # If we have at least 1 loss from a given language --> compute local language loss mean lang_losses[lang] = torch.mean(torch.stack(lang_losses[lang])) From bb768bb0ae1cb8f12550719c5c89a3c1f54b321c Mon Sep 17 00:00:00 2001 From: Alexander Hagele Date: Fri, 6 Sep 2024 09:50:13 +0200 Subject: [PATCH 47/51] fix for eval --- src/nanotron/models/gpt3_moe.py | 16 ++++++++++------ 1 file changed, 10 insertions(+), 6 deletions(-) diff --git a/src/nanotron/models/gpt3_moe.py b/src/nanotron/models/gpt3_moe.py index 1915136c..06a624ac 100644 --- a/src/nanotron/models/gpt3_moe.py +++ b/src/nanotron/models/gpt3_moe.py @@ -95,7 +95,7 @@ def forward( self, hidden_states: torch.Tensor | TensorPointer, sequence_mask: torch.Tensor | TensorPointer, - aux_losses: Dict[str, Union[torch.Tensor, TensorPointer]], + aux_losses: Dict[str, Union[torch.Tensor, TensorPointer]] = None, ) -> dict[str, torch.Tensor | TensorPointer]: residual = hidden_states @@ -119,9 +119,10 @@ def forward( mlp_output = self.ff(hidden_states=hidden_states) hidden_states = mlp_output["hidden_states"] - for key, value in mlp_output.items(): - if key != "hidden_states": - aux_losses[key] = aux_losses[key] + value + if aux_losses is not None: + for key, value in mlp_output.items(): + if key != "hidden_states": + aux_losses[key] = aux_losses[key] + value if self.training: with branch_random_state( @@ -171,7 +172,7 @@ def forward( self, input_ids: torch.Tensor | TensorPointer, # [batch_size, 
seq_length] input_mask: torch.Tensor | TensorPointer, # [batch_size, seq_length] - aux_losses: Dict[str, Union[torch.Tensor, TensorPointer]], + aux_losses: Optional[Dict[str, Union[torch.Tensor, TensorPointer]]] = None, ): # all tensors are optional as most ranks don't need anything from the dataloader. @@ -199,7 +200,10 @@ def forward( fp32_sharded_logits = self.cast_to_fp32(x=sharded_logits)["output"] - return {"sharded_logits": fp32_sharded_logits, "aux_losses": hidden_encoder_states["aux_losses"]} + if aux_losses is not None: + return {"sharded_logits": fp32_sharded_logits, "aux_losses": hidden_encoder_states["aux_losses"]} + else: + return fp32_sharded_logits class GPT3MoEForTraining(GPT3ForTraining): From f8e30b4cc9e16030bedf52d7154d2a614f7ac185 Mon Sep 17 00:00:00 2001 From: Alexander Hagele Date: Fri, 6 Sep 2024 09:58:37 +0200 Subject: [PATCH 48/51] lighteval fix for multiling --- src/nanotron/config/lighteval_config.py | 33 +++++++++++++++++++++++++ 1 file changed, 33 insertions(+) diff --git a/src/nanotron/config/lighteval_config.py b/src/nanotron/config/lighteval_config.py index b5f12059..a1b71070 100644 --- a/src/nanotron/config/lighteval_config.py +++ b/src/nanotron/config/lighteval_config.py @@ -60,6 +60,8 @@ class LightEvalTasksArgs: multichoice_continuations_start_space: Optional[bool] = None no_multichoice_continuations_start_space: Optional[bool] = None + langs: Optional[str] = "en" + @dataclass class LightEvalWandbLoggerConfig: @@ -91,3 +93,34 @@ class LightEvalConfig: tasks: Optional[LightEvalTasksArgs] = None logging: Optional[LightEvalLoggingArgs] = None wandb: Optional[LightEvalWandbLoggerConfig] = None + + +# batch_size: 16 +# checkpoints_path: null +# generation: null +# logging: +# hub_repo_details: null +# hub_repo_results: null +# hub_repo_tensorboard: null +# local_output_path: /capstor/scratch/cscs/$USER/multilingual_data_mixture/eval_results +# push_details_to_hub: false +# push_results_to_hub: false +# push_results_to_tensorboard: true +# tensorboard_metric_prefix: eval_ +# wandb: null +# parallelism: +# dp: 1 +# pp: 1 +# pp_engine: 1f1b +# tp: 1 +# tp_linear_async_communication: false +# tp_mode: ALL_REDUCE +# tasks: +# custom_tasks: /capstor/scratch/cscs/$USER/lighteval-multilingual/src/lighteval/community_tasks/multilingual/configs/multilingual.py +# dataset_loading_processes: 1 +# max_samples: 10 +# multichoice_continuations_start_space: null +# no_multichoice_continuations_start_space: null +# num_fewshot_seeds: 1 +# tasks: x_nli +# langs: en,fr \ No newline at end of file From ac27adaeb160b9803ac927c1f81f4d0a8880154c Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Alex=20H=C3=A4gele?= <56194263+haeggee@users.noreply.github.com> Date: Fri, 6 Sep 2024 17:26:56 +0200 Subject: [PATCH 49/51] Update moe.md --- moe.md | 45 +++++++++++++++++++-------------------------- 1 file changed, 19 insertions(+), 26 deletions(-) diff --git a/moe.md b/moe.md index 46bca04f..7c2bb672 100644 --- a/moe.md +++ b/moe.md @@ -1,18 +1,23 @@ # MoE Env Setup -TL;DR: need to install megablocks for MoEs, which depends on triton; cannot install triton inside the docker image because it requires a CUDA-capable GPU, which is not available in the build environment. therefore install triton from source inside a venv in the container, then install megablocks +TL;DR: need to install megablocks for MoEs. just use the environment `/store/swissai/a06/containers/nanotron_moe/nanotron_moe.toml` :) +The setup is documented in that folder on the cluster. 
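
For instance, an interactive shell inside that environment can be started with something along these lines (reservation and partition flags omitted; whether you pass the EDF file by path or copy it to `~/.edf/` and refer to it by name depends on your cluster setup):

```bash
srun --environment=/store/swissai/a06/containers/nanotron_moe/nanotron_moe.toml --container-workdir=$PWD --pty bash
```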
The setup is: ```Dockerfile -FROM nvcr.io/nvidia/pytorch:24.04-py3 +FROM nvcr.io/nvidia/pytorch:24.05-py3 -ENV DEBIAN_FRONTEND=noninteractive +# setup +RUN apt-get update && apt-get install -y \ + python3-pip \ + python3-venv \ + git tmux htop nvtop \ + && apt-get clean && rm -rf /var/lib/apt/lists/* -RUN apt-get update && apt-get install -y python3.10-venv && apt-get clean && rm -rf /var/lib/apt/lists/* +RUN pip install --upgrade pip setuptools==69.5.1 # Update flash-attn. RUN pip install --upgrade --no-build-isolation flash-attn==2.5.8 - # Install the rest of dependencies. RUN pip install \ datasets \ @@ -23,14 +28,20 @@ RUN pip install \ numpy \ packaging \ safetensors \ + sentencepiece \ tqdm -``` +WORKDIR /workspace +RUN git clone https://github.com/swiss-ai/nanotron.git +WORKDIR /workspace/nanotron +RUN pip install -e .[nanosets] +RUN pip install megablocks==0.5.1 stanford-stk==0.7.1 --no-deps +``` -after image is built, create env `~/.edf/nanotron-moe.toml` with content (adapt to wherever the image is stored) +The env `nanotron-moe.toml` with content: ``` -image = "/capstor/scratch/cscs/$USER/container-images/nanotron-moe/nanotron-moe-v1.0.sqsh" +image = "/store/swissai/a06/containers/nanotron_moe/nanotron_moe.sqsh" mounts = ["/capstor", "/users", "/store"] workdir = "/users/$USER/" @@ -45,21 +56,3 @@ FI_CXI_DISABLE_HOST_REGISTER = "1" FI_MR_CACHE_MONITOR = "userfaultfd" NCCL_DEBUG = "INFO" ``` - -TODO: make image available on the cluster in /store - - - -in a running container (`srun --reservation=todi --environment=nanotron-moe --container-workdir=$PWD --pty bash`) -```bash -cd $SCRATCH/$USER/nanotron-multilingual # or wherever you want the venv -mkdir multilingual-venv && cd multilingual-venv -python -m venv --system-site-packages ./moe-venv -source ./moe-venv/bin/activate -git clone https://github.com/triton-lang/triton.git; \ - cd triton; \ - pip install ninja cmake wheel; # build-time dependencies \ - pip install -e python; cd .. 
-pip install megablocks==0.5.1 -``` - From 328b8c2ae023f949b122f4d66ee7c2f6cd812854 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Alex=20H=C3=A4gele?= <56194263+haeggee@users.noreply.github.com> Date: Fri, 6 Sep 2024 17:37:21 +0200 Subject: [PATCH 50/51] Update moe.md --- moe.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/moe.md b/moe.md index 7c2bb672..bb67501e 100644 --- a/moe.md +++ b/moe.md @@ -44,7 +44,7 @@ The env `nanotron-moe.toml` with content: image = "/store/swissai/a06/containers/nanotron_moe/nanotron_moe.sqsh" mounts = ["/capstor", "/users", "/store"] -workdir = "/users/$USER/" +workdir = "/workspace/nanotron" writable = true [annotations] From 3bce1f4d93e042327fcb29bfd93ee8dd121f0e05 Mon Sep 17 00:00:00 2001 From: AleHC <36459138+AleHD@users.noreply.github.com> Date: Sat, 19 Oct 2024 16:15:58 +0200 Subject: [PATCH 51/51] Moe converters (#5) * Converters ready * Added xglm transformers implementation --------- Co-authored-by: Negar Foroutan --- examples/xglm/README.md | 5 + examples/xglm/convert_ntmoe2hf.py | 140 +++ examples/xglm/tests/test_moe.py | 182 +++ examples/xglm/transformers_impl/gating.py | 149 +++ examples/xglm/transformers_impl/xglm_model.py | 1119 +++++++++++++++++ src/nanotron/models/moe.py | 2 +- 6 files changed, 1596 insertions(+), 1 deletion(-) create mode 100644 examples/xglm/convert_ntmoe2hf.py create mode 100644 examples/xglm/tests/test_moe.py create mode 100644 examples/xglm/transformers_impl/gating.py create mode 100644 examples/xglm/transformers_impl/xglm_model.py diff --git a/examples/xglm/README.md b/examples/xglm/README.md index 8f62fc57..48447ac2 100644 --- a/examples/xglm/README.md +++ b/examples/xglm/README.md @@ -25,3 +25,8 @@ cd examples/xglm torchrun --nproc-per-node=1 convert_dense2moe.py --checkpoint-path=checkpoints/xglm-564M --save-path=$SCRATCH/checkpoints/xglm-8x564M --num-experts=8 ``` Note that this upcycling _drops_ the bias parameters of the MLP because the MegaBlocks implementation does not support bias parameters. While this is a limitation of the current implementation, the performance is quickly recovered after a few training steps. 
+ +To save back to huggingface format use +```bash +torchrun examples/xglm/convert_ntmoe2hf.py --checkpoint-path=$SCRATCH/checkpoints/xglm-8x564M --save-path=$SCRATCH/checkpoints/huggingface/xglm-8x56fM +``` diff --git a/examples/xglm/convert_ntmoe2hf.py b/examples/xglm/convert_ntmoe2hf.py new file mode 100644 index 00000000..d971c96f --- /dev/null +++ b/examples/xglm/convert_ntmoe2hf.py @@ -0,0 +1,140 @@ +""" +Converts a nanotron moe model to HF format +Command: + torchrun --nproc-per-node=1 convert_nt2hf.py --checkpoint-path=nanotron_weights --save-path=hf_weights +""" + +import warnings +from argparse import ArgumentParser +from pathlib import Path +from typing import Optional + +import torch +from transformers import AutoTokenizer +from tqdm import tqdm + +from nanotron.config.models_config import GPT3MoEConfig +from nanotron.models.gpt3_moe import GPT3MoEForTraining, GPT3MoEBlock +from nanotron.models.moe import dMoE, SparseMLP, LearnedRouter + +from examples.xglm.convert_dense2moe import create_nt_moe_model +from examples.xglm.convert_nt2hf import convert_attention +from examples.xglm.convert_utils import convert_generic +from examples.xglm.transformers_impl.xglm_model import XGLMForCausalLM, XGLMDecoderLayer, XGLMmoeConfig, XGLMSparseMoeBlock, XGLMMLP +from examples.xglm.transformers_impl.gating import BasicGate + + +def convert_config(config: GPT3MoEConfig) -> XGLMmoeConfig: + if config.embd_pdrop != config.resid_pdrop: + warnings.warn( + f"nanotron.embd_pdrop = {config.embd_pdrop} does not match with " + f"nanotron.resid_pdrop = {config.resid_pdrop}. " + "XGLM implementation needs these two values to be equal " + "for correct conversion." + ) + if config.layer_norm_epsilon != 1e-5: + warnings.warn(f"nanotron.layer_norm_epsilon must be 1e-5, not {config.layer_norm_epsilon}") + if config.moe_z_loss_weight != 0: + warnings.warn(f"transformer implementation does not support z loss") + assert not config.moe_glu, "Transformer implementation does not support glu MLP layers" + + return XGLMmoeConfig( + # Regular xglm config. + activation_function=config.activation_function, + attention_dropout=config.attn_pdrop, + dropout=config.embd_pdrop, + eos_token_id=config.eos_token_id, + d_model=config.hidden_size, + ffn_dim=config.intermediate_size, + max_position_embeddings=config.max_position_embeddings, + attention_heads=config.num_attention_heads, + num_layers=config.num_hidden_layers, + vocab_size=config.vocab_size, + decoder_start_token_id=config.position_embedding_offset, + activation_dropout=config.act_pdrop, + scale_embedding=config.scale_embedding, + # Moe specifics. 
+ num_local_experts=config.moe_num_experts, + num_experts_per_tok=config.num_experts_per_tok, + gate_type="linear", + gate_depth=1, + router_aux_loss_coef=config.moe_loss_weight, + ) + + +def convert_mlp(mlp_hf: XGLMMLP, mlp_nt: SparseMLP): + convert_generic(mlp_hf.fc1, mlp_nt.w1.module) + convert_generic(mlp_hf.fc2, mlp_nt.w2.module) + + +def convert_gate(gate_hf: BasicGate, gate_nt: LearnedRouter): + convert_generic(gate_hf.gate, gate_nt.layer) + + +def convert_ff(ff_hf: XGLMSparseMoeBlock, ff_nt: dMoE): + convert_gate(ff_hf.gate, ff_nt.gate) + int_size = ff_nt.config.intermediate_size + if len(ff_hf.experts) == 1: + assert ff_nt.experts.mlp.w1.module.weight.shape == (int_size*len(ff_hf.experts), ff_nt.config.hidden_size) + assert ff_nt.experts.mlp.w2.module.weight.shape == (ff_nt.config.hidden_size, int_size*len(ff_hf.experts)) + else: + assert ff_nt.experts.mlp.w1.module.weight.T.shape == (int_size*len(ff_hf.experts), ff_nt.config.hidden_size) + assert ff_nt.experts.mlp.w2.module.weight.shape == (int_size*len(ff_hf.experts), ff_nt.config.hidden_size) + + for i, expert_hf in enumerate(ff_hf.experts): + i0 = i*int_size + i1 = (i + 1)*int_size + with torch.no_grad(): + if len(ff_hf.experts) == 1: + expert_hf.fc1.weight.copy_(ff_nt.experts.mlp.w1.module.weight[i0:i1, :].clone()) + expert_hf.fc2.weight.copy_(ff_nt.experts.mlp.w2.module.weight[:, i0:i1].clone()) + else: + expert_hf.fc1.weight.copy_(ff_nt.experts.mlp.w1.module.weight.T[i0:i1, :].clone()) + expert_hf.fc2.weight.copy_(ff_nt.experts.mlp.w2.module.weight[i0:i1, :].T.clone()) + +def convert_decoder(block_hf: XGLMDecoderLayer, block_nt: GPT3MoEBlock): + convert_generic(block_hf.self_attn_layer_norm, block_nt.ln_1) + convert_attention(block_hf.self_attn, block_nt.attn) + convert_generic(block_hf.final_layer_norm, block_nt.ln_2) + convert_ff(block_hf.block_sparse_moe, block_nt.ff) + + +def convert(model_hf: XGLMForCausalLM, model_nt: GPT3MoEForTraining): + convert_generic(model_hf.model.embed_tokens, model_nt.model.token_embeddings.pp_block.token_embedding) + for layer_hf, layer_nt in tqdm(zip(model_hf.model.layers, model_nt.model.decoder), desc="Converting layers", + total=model_nt.config.num_hidden_layers): + convert_decoder(layer_hf, layer_nt.pp_block) + convert_generic(model_hf.model.layer_norm, model_nt.model.final_layer_norm.pp_block) + convert_generic(model_hf.lm_head, model_nt.model.lm_head.pp_block) + + +def main(checkpoint_path: Path, save_path: Path, tokenizer_name: Optional[str]): + # Load nanotron model. + model_nt = create_nt_moe_model(checkpoint_path=checkpoint_path) + + # Init huggingface model. + model_config_hf = convert_config(model_nt.config) + model_hf = XGLMForCausalLM._from_config(model_config_hf) + + # Copy weights, initialize tokenizer and save model. 
+    if tokenizer_name is not None:
+        tokenizer = AutoTokenizer.from_pretrained(tokenizer_name)
+        tokenizer.save_pretrained(save_path)
+    convert(model_hf, model_nt)
+    print("Saving...")
+    model_hf.save_pretrained(save_path)
+    print(f"Model saved to {save_path}")
+
+
+if __name__ == "__main__":
+    parser = ArgumentParser(description="Convert nanotron MoE weights to huggingface format")
+    parser.add_argument(
+        "--checkpoint-path", type=Path, default="checkpoints/xglm-7.5B", help="Path to the nanotron checkpoint"
+    )
+    parser.add_argument(
+        "--save-path", type=Path, default="facebook/xglm-7.5B", help="Path to save the huggingface model"
+    )
+    parser.add_argument("--tokenizer-name", type=str, default="facebook/xglm-7.5B")
+    args = parser.parse_args()
+    main(args.checkpoint_path, args.save_path, args.tokenizer_name)
diff --git a/examples/xglm/tests/test_moe.py b/examples/xglm/tests/test_moe.py
new file mode 100644
index 00000000..7536b84f
--- /dev/null
+++ b/examples/xglm/tests/test_moe.py
@@ -0,0 +1,182 @@
+import torch
+import pytest
+
+import nanotron
+from nanotron.config.parallelism_config import ParallelismArgs
+from nanotron.config.models_config import GPT3MoEConfig
+from nanotron.parallel import ParallelContext
+from nanotron.parallel.pipeline_parallel.tensor_pointer import TensorPointer
+from nanotron.trainer import mark_tied_parameters
+from nanotron.models.gpt3_moe import GPT3MoEBlock, GPT3MoEForTraining
+from nanotron.models.moe import LearnedRouter, dMoE
+
+from tests.helpers.utils import init_distributed
+
+from examples.xglm.convert_ntmoe2hf import convert_config, convert_gate, convert_ff, convert
+from examples.xglm.tests.test_implementation import almost_close
+from examples.xglm.transformers_impl.xglm_model import XGLMSparseMoeBlock, XGLMForCausalLM
+from examples.xglm.transformers_impl.gating import BasicGate
+
+
+MAX_SEQUENCE_LENGTH = 2048
+TEST_SEQUENCE_LENGTH = 128  # With very long sequences, precision errors become more significant regardless of whether the implementation is correct.
+#TEST_SEQUENCE_LENGTH = MAX_SEQUENCE_LENGTH
+BATCH_SIZE = 4
+HIDDEN_SIZE = 1024
+#DTYPE = torch.bfloat16
+DTYPE = torch.float32
+TEXT = "Hello. This is a relatively long text. I will use this text to test the conversion scripts. Let's finish this text soon because I don't have much more to say.
Final note:" + +CONFIG = GPT3MoEConfig( + attn_pdrop=0.0, + embd_pdrop=0.0, + resid_pdrop=0.0, + act_pdrop=0.0, + eos_token_id=2, + hidden_size=HIDDEN_SIZE, + intermediate_size=4096, + layer_norm_epsilon=1e-05, + max_position_embeddings=MAX_SEQUENCE_LENGTH, + num_attention_heads=16, + num_hidden_layers=24, + scale_attn_weights=True, + vocab_size=256008, + sinusoidal_position_embedding=True, + position_embedding_offset=2, + use_spda=DTYPE is not torch.bfloat16, + # vvv moe vvv + is_moe=True, + moe_num_experts=8, + num_experts_per_tok=2, + moe_loss_weight=0.01, + moe_z_loss_weight=0.0, + moe_glu=False, +) +PARALLEL_CONFIG = ParallelismArgs(dp=1, pp=1, tp=1, expert_parallel_size=1) #CONFIG.moe_num_experts) + + +@pytest.fixture +def hidden_states() -> torch.Tensor: + return torch.randn(TEST_SEQUENCE_LENGTH, BATCH_SIZE, HIDDEN_SIZE, dtype=DTYPE) + + +@pytest.fixture +def input_mask() -> torch.Tensor: + return torch.ones(BATCH_SIZE, TEST_SEQUENCE_LENGTH, dtype=torch.bool) + + +@pytest.fixture +def input_ids() -> torch.Tensor: + return torch.randint(0, CONFIG.vocab_size, (BATCH_SIZE, TEST_SEQUENCE_LENGTH)) + + +def _test_nt2hf_gate(parallel_context: ParallelContext, hidden_states: torch.Tensor): + hidden_states = hidden_states.cuda() + + config_hf = convert_config(CONFIG) + gate_nt = LearnedRouter(CONFIG).cuda().to(DTYPE) + gate_hf = BasicGate(config_hf).cuda().to(DTYPE) + convert_gate(gate_hf, gate_nt) + + router_logits_nt, _, _ = gate_nt(hidden_states.view(-1, HIDDEN_SIZE)) + router_logits_hf = gate_hf(hidden_states.permute(1, 0, 2).reshape(-1, HIDDEN_SIZE), "") + + router_logits_nt = router_logits_nt.view(TEST_SEQUENCE_LENGTH, BATCH_SIZE, -1) + router_logits_hf = router_logits_hf.view(BATCH_SIZE, TEST_SEQUENCE_LENGTH, -1).permute(1, 0, 2) + + assert router_logits_nt.size() == router_logits_hf.size() + torch.testing.assert_close(router_logits_nt, router_logits_hf) + + +def test_nt2hf_gate(hidden_states: torch.Tensor): + init_distributed(tp=1, dp=1, pp=1)(_test_nt2hf_gate)(hidden_states=hidden_states) + + +def _test_nt2hf_ff(parallel_context: ParallelContext, hidden_states: torch.Tensor, + num_experts: int, num_experts_per_tok: int): + hidden_states = hidden_states.cuda() + + config = {**vars(CONFIG)} + config.update({"moe_num_experts": num_experts, "num_experts_per_tok": num_experts_per_tok}) + config = GPT3MoEConfig(**config) + config_hf = convert_config(config) + ff_nt = dMoE(config, parallel_context, PARALLEL_CONFIG).cuda().to(DTYPE) + ff_hf = XGLMSparseMoeBlock(config_hf).cuda().to(DTYPE) + convert_ff(ff_hf, ff_nt) + + out_nt = ff_nt(hidden_states)["hidden_states"] + out_hf, _ = ff_hf(hidden_states.permute(1, 0, 2).contiguous(), "") + out_hf = out_hf.permute(1, 0, 2) + + assert out_nt.size() == out_hf.size() + almost_close(out_nt, out_hf, max_far=0.05, far_atol=0.003) + + +@pytest.mark.parametrize("num_experts,num_experts_per_tok", [(1, 1), (2, 1), (4, 1), (4, 2), (8, 1), (8, 2), (8, 4)]) +def test_nt2hf_ff(hidden_states: torch.Tensor, num_experts: int, num_experts_per_tok: int): + init_distributed(tp=1, dp=1, pp=1)(_test_nt2hf_ff)(hidden_states=hidden_states, num_experts=num_experts, num_experts_per_tok=num_experts_per_tok) + + +def _test_nt2hf_model(parallel_context: ParallelContext, input_ids: torch.Tensor, input_mask: torch.Tensor): + random_states = nanotron.random.RandomStates({"tp_synced": nanotron.random.get_current_random_state()}) + input_ids = input_ids.cuda() + input_mask = input_mask.cuda() + + # unfortunately, we can't use float64 with huggingface xglm. 
+
+    new_dtype = torch.float32 if DTYPE == torch.float64 else DTYPE
+
+    # Get nanotron model.
+    config_nt = GPT3MoEConfig(**vars(CONFIG))
+    if new_dtype not in {torch.bfloat16, torch.float16}:
+        config_nt.use_spda = True
+    model_nt = nanotron.models.build_model(
+        model_builder=lambda: GPT3MoEForTraining(
+            config=config_nt,
+            parallel_context=parallel_context,
+            parallel_config=None,
+            random_states=random_states,
+        ),
+        parallel_context=parallel_context,
+        dtype=new_dtype,
+        device="cuda",
+    ).eval()
+    mark_tied_parameters(model=model_nt, parallel_context=parallel_context)
+
+    # Create empty model_hf and make conversion.
+    model_hf = XGLMForCausalLM(convert_config(config_nt)).cuda().to(new_dtype).eval()
+    convert(model_hf, model_nt)
+
+    # Auxiliary loss accumulators expected by the MoE forward pass.
+    # In this single-pipeline-stage test, input_ids lives on pp rank 0.
+    input_pp_rank = 0
+    aux_losses = {
+        "load_balancing_loss": (
+            torch.zeros(1, device=input_ids.device)
+            if not isinstance(input_ids, TensorPointer)
+            else TensorPointer(input_pp_rank)
+        ),
+        "z_loss": (
+            torch.zeros(1, device=input_ids.device)
+            if not isinstance(input_ids, TensorPointer)
+            else TensorPointer(input_pp_rank)
+        ),
+    }
+
+    # Get outputs and assert.
+    with torch.no_grad():
+        out_nt = model_nt.model(input_ids, input_mask, aux_losses)["sharded_logits"].to(new_dtype)
+        del model_nt
+        torch.cuda.empty_cache()
+        out_hf = model_hf(input_ids=input_ids, attention_mask=input_mask, output_router_logits=False).logits.permute(1, 0, 2)
+        del model_hf
+        torch.cuda.empty_cache()
+    assert out_nt.size() == out_hf.size(), f"{out_nt.size()}, {out_hf.size()}"
+    return out_nt.cpu(), out_hf.cpu()
+
+
+def _test_nt2hf_dummy_xglm(parallel_context: ParallelContext, input_ids: torch.Tensor, input_mask: torch.Tensor):
+    out_nt, out_hf = _test_nt2hf_model(parallel_context, input_ids, input_mask)
+    almost_close(out_nt, out_hf, max_far=0.01, far_atol=2.0)  # Allow fewer than 1% of entries to differ, although some of those differences can be large.
+    #torch.testing.assert_close(out_nt.bfloat16(), out_hf.bfloat16())
+
+
+def test_nt2hf_dummy_xglm(input_ids: torch.Tensor, input_mask: torch.Tensor):
+    init_distributed(tp=1, dp=1, pp=1)(_test_nt2hf_dummy_xglm)(input_ids=input_ids, input_mask=input_mask)
diff --git a/examples/xglm/transformers_impl/gating.py b/examples/xglm/transformers_impl/gating.py
new file mode 100644
index 00000000..efa0e357
--- /dev/null
+++ b/examples/xglm/transformers_impl/gating.py
@@ -0,0 +1,149 @@
+import torch
+import torch.nn as nn
+from torch.nn import functional as F
+import math
+
+from abc import ABC, abstractmethod
+
+
+class Gate(ABC):
+    def __init__(self, device):
+        super(Gate, self).__init__()
+        self.device = device
+
+    @abstractmethod
+    def compute(self, x):
+        """
+        Compute the output of the gate.
+        This method should be implemented by all subclasses.
+ """ + pass + + +def init_x_embeddings(Xs, x_embedding_dim): + x2embeddings = nn.ParameterDict(dict()) + for x in Xs: + x_embedding = torch.empty(x_embedding_dim) + nn.init.normal_(x_embedding) + x2embeddings[str(x)] = nn.Parameter(x_embedding) + return x2embeddings + + +class BasicGate(nn.Module): + """One or two layer feedforward network as the Gate.""" + + def __init__(self, config) -> None: + super().__init__() + + self.hidden_dim = config.hidden_size + self.num_experts = config.num_local_experts + self.ffn_dim = config.ffn_dim + self.activation = nn.ReLU(self.ffn_dim) + + if config.gate_depth == 1: + self.gate = nn.Linear(self.hidden_dim, self.num_experts, bias=False) + elif config.gate_depth == 2: + self.gate = nn.Sequential( + nn.Linear(self.hidden_dim, self.ffn_dim), + self.activation, + nn.Linear(self.ffn_dim, self.num_experts, bias=False), + ) + else: + raise ValueError("Invalid gate_depth!") + + def forward(self, x, lang_name): + return self.gate(x) + + +class LanguageAwareGate(nn.Module): + """One or two layer feedforward network as the Gate.""" + + def __init__(self, config) -> None: + super().__init__() + + self.hidden_dim = config.hidden_size + self.num_experts = config.num_local_experts + self.ffn_dim = config.ffn_dim + self.activation = nn.ReLU(self.ffn_dim) + self.language_embedding_dim = ( + config.language_embedding_dim + if config.language_embedding_dim is not None + else config.hidden_size + ) + self.lang_embeddings = init_x_embeddings( + config.languages, self.language_embedding_dim + ) + + if config.gate_depth == 1: + self.gate = nn.Linear( + self.hidden_dim + self.language_embedding_dim, + self.num_experts, + bias=False, + ) + elif config.gate_depth == 2: + self.gate = nn.Sequential( + nn.Linear(self.hidden_dim, self.ffn_dim), + self.activation, + nn.Linear(self.ffn_dim, self.num_experts, bias=False), + ) + else: + raise ValueError("Invalid gate_depth!") + + def forward(self, x, lang_name): + # TODO x needs to be added to the language embedding (we need to pass the language as well) + lang_embedding = self.lang_embeddings[str(lang_name)] + lang_embedding.squeeze(0) + lang_embedding = lang_embedding.expand(x.shape[0], -1) + x = torch.cat((x, lang_embedding), dim=-1) + return self.gate(x) + + +class TopKGate(Gate): + def __init__(self, device, straight_through, k=1): + super(TopKGate, self).__init__(device) + self.k = k + self.device = device + self.straight_through = straight_through + + def compute(self, x): + if self.k > 1: + topk_gate_scores, indices = torch.topk(x, self.k) + topk_gate_scores = F.softmax( + topk_gate_scores, + dim=1, + dtype=torch.float, + ).type_as(x) + mask = F.one_hot(indices, x.shape[-1]).float() + mask_flat = mask.sum(dim=-1) + combine_tensor = ( + topk_gate_scores[..., None, None, None] + * mask_flat[..., None, None, None] + * F.one_hot(indices, x.shape[-1])[..., None, None] + ) + combine_tensor = combine_tensor.sum(1) + return combine_tensor, indices, topk_gate_scores + elif self.k == 1: + x = F.softmax(x, dim=-1) + topk_gate_scores, index = x.topk( + k=self.k, dim=-1 + ) # torch.nn.functional.softmax(x , dim=-1).topk(k=self.k, dim=-1) + if self.straight_through: + index_soft = F.softmax(x, dim=-1) + index = (index - index_soft).detach() + index_soft + index = index[:, 0] + topk_gate_scores, index = map( + lambda x: x.squeeze(dim=-1), (topk_gate_scores, index) + ) + else: + topk_gate_scores, index = map( + lambda x: x.squeeze(dim=-1), (topk_gate_scores, index) + ) + + mask = F.one_hot(index, x.shape[-1]).float() + mask_flat = 
mask.sum(dim=-1) + combine_tensor = ( + topk_gate_scores[..., None, None, None] + * mask_flat[..., None, None, None] + * F.one_hot(index, x.shape[-1])[..., None, None] + ) + return combine_tensor, index, topk_gate_scores diff --git a/examples/xglm/transformers_impl/xglm_model.py b/examples/xglm/transformers_impl/xglm_model.py new file mode 100644 index 00000000..aa80a7fe --- /dev/null +++ b/examples/xglm/transformers_impl/xglm_model.py @@ -0,0 +1,1119 @@ +from typing import List, Optional, Tuple, Union, Literal +import math +from dataclasses import dataclass + +import torch +import torch.nn as nn +from torch.nn import functional as F +from torch.nn import CrossEntropyLoss + +from transformers.models.xglm.modeling_xglm import ( + XGLMAttention, + XGLMPreTrainedModel, + XGLMSinusoidalPositionalEmbedding, +) +from transformers.models.xglm.configuration_xglm import XGLMConfig +from transformers.activations import ACT2FN +from transformers.modeling_attn_mask_utils import ( + _prepare_4d_attention_mask, + _prepare_4d_causal_attention_mask, +) + +# from transformers.modeling_outputs import CausalLMOutputWithCrossAttentions, BaseModelOutputWithPastAndCrossAttentions +from transformers.utils import ( + add_start_docstrings, + add_start_docstrings_to_model_forward, + add_code_sample_docstrings, + ModelOutput, +) +from examples.xglm.transformers_impl.gating import BasicGate, LanguageAwareGate + +from accelerate.logging import get_logger + +logger = get_logger(__name__) + +_CHECKPOINT_FOR_DOC = "facebook/xglm-564M" +_CONFIG_FOR_DOC = "XGLMConfig" + + +class XGLMmoeConfig(XGLMConfig): + def __init__(self, + num_local_experts: int = 1, + num_experts_per_tok: int = 1, + gate_type: Literal["linear", "lang_aware_linear"] = "linear", + gate_depth: Literal[1, 2] = 1, + router_aux_loss_coef: float = 1.0, + **kwargs): + + super().__init__(**kwargs) + self.num_local_experts = num_local_experts + self.num_experts_per_tok = num_experts_per_tok + self.gate_type = gate_type + self.gate_depth = gate_depth + self.router_aux_loss_coef = router_aux_loss_coef + + +@dataclass +class MoEModelOutputWithPast(ModelOutput): + """ + Base class for model's outputs, with potential hidden states and attentions. + + Args: + last_hidden_state (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`): + Sequence of hidden-states at the output of the last layer of the model. + past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`): + Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of shape + `(batch_size, num_heads, sequence_length, embed_size_per_head)`) and optionally if + `config.is_encoder_decoder=True` 2 additional tensors of shape `(batch_size, num_heads, + encoder_sequence_length, embed_size_per_head)`. + + Contains pre-computed hidden-states (key and values in the self-attention blocks and optionally if + `config.is_encoder_decoder=True` in the cross-attention blocks) that can be used (see `past_key_values` + input) to speed up sequential decoding. + hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): + Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, + + one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`. 
+ + Hidden-states of the model at the output of each layer plus the optional initial embedding outputs. + attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`): + Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length, + sequence_length)`. + + Attentions weights after the attention softmax, used to compute the weighted average in the self-attention + heads. + router_logits (`tuple(torch.FloatTensor)`, *optional*, returned when `output_router_probs=True` and `config.add_router_probs=True` is passed or when `config.output_router_probs=True`): + Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, sequence_length, num_experts)`. + + Raw router logtis (post-softmax) that are computed by MoE routers, these terms are used to compute the auxiliary + loss for Mixture of Experts models. + """ + + last_hidden_state: torch.FloatTensor = None + past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None + hidden_states: Optional[Tuple[torch.FloatTensor, ...]] = None + attentions: Optional[Tuple[torch.FloatTensor, ...]] = None + router_logits: Optional[Tuple[torch.FloatTensor]] = None + cross_attentions: Optional[Tuple[torch.FloatTensor, ...]] = None + + +@dataclass +class MoECausalLMOutputWithPast(ModelOutput): + """ + Base class for causal language model (or autoregressive) outputs as well as Mixture of Expert's router hidden + states terms, to train a MoE model. + + Args: + loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided): + Language modeling loss (for next-token prediction). + logits (`torch.FloatTensor` of shape `(batch_size, sequence_length, config.vocab_size)`): + Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax). + past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`): + Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of shape + `(batch_size, num_heads, sequence_length, embed_size_per_head)`) + + Contains pre-computed hidden-states (key and values in the self-attention blocks) that can be used (see + `past_key_values` input) to speed up sequential decoding. + hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): + Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, + + one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`. + + Hidden-states of the model at the output of each layer plus the optional initial embedding outputs. + attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`): + Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length, + sequence_length)`. + + Attentions weights after the attention softmax, used to compute the weighted average in the self-attention + heads. + z_loss (`torch.FloatTensor`, *optional*, returned when `labels` is provided): + z_loss for the sparse modules. + aux_loss (`torch.FloatTensor`, *optional*, returned when `labels` is provided): + aux_loss for the sparse modules. 
+ router_logits (`tuple(torch.FloatTensor)`, *optional*, returned when `output_router_logits=True` is passed or when `config.add_router_probs=True`): + Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, sequence_length, num_experts)`. + + Router logits of the encoder model, useful to compute the auxiliary loss and the z_loss for the sparse + modules. + """ + + loss: Optional[torch.FloatTensor] = None + logits: torch.FloatTensor = None + past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None + hidden_states: Optional[Tuple[torch.FloatTensor, ...]] = None + attentions: Optional[Tuple[torch.FloatTensor, ...]] = None + cross_attentions: Optional[Tuple[torch.FloatTensor, ...]] = None + z_loss: torch.FloatTensor = None + aux_loss: torch.FloatTensor = None + router_logits: Optional[Tuple[torch.FloatTensor]] = None + + +XGLM_START_DOCSTRING = r""" + This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the + library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads + etc.) + + This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass. + Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage + and behavior. + + Parameters: + config ([`XGLMConfig`]): + Model configuration class with all the parameters of the model. Initializing with a config file does not + load the weights associated with the model, only the configuration. Check out the + [`~PreTrainedModel.from_pretrained`] method to load the model weights. +""" +XGLM_INPUTS_DOCSTRING = r""" + Args: + input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`): + Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide + it. + + Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and + [`PreTrainedTokenizer.__call__`] for details. + + [What are input IDs?](../glossary#input-ids) + attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*): + Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`: + + - 1 for tokens that are **not masked**, + - 0 for tokens that are **masked**. + + [What are attention masks?](../glossary#attention-mask) + position_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): + Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0, + config.max_position_embeddings - 1]`. + + [What are position IDs?](../glossary#position-ids) + encoder_hidden_states (`torch.FloatTensor` of shape `(batch_size, encoder_sequence_length, hidden_size)`, *optional*): + Sequence of hidden-states at the output of the last layer of the encoder. Used in the cross-attention of + the decoder. + encoder_attention_mask (`torch.LongTensor` of shape `(batch_size, encoder_sequence_length)`, *optional*): + Mask to avoid performing cross-attention on padding tokens indices of encoder input_ids. Mask values + selected in `[0, 1]`: + + - 1 for tokens that are **not masked**, + - 0 for tokens that are **masked**. + + [What are attention masks?](../glossary#attention-mask) + head_mask (`torch.Tensor` of shape `(num_layers, attention_heads)`, *optional*): + Mask to nullify selected heads of the attention modules. 
Mask values selected in `[0, 1]`: + + - 1 indicates the head is **not masked**, + - 0 indicates the head is **masked**. + + cross_attn_head_mask (`torch.Tensor` of shape `(num_layers, attention_heads)`, *optional*): + Mask to nullify selected heads of the cross-attention modules. Mask values selected in `[0, 1]`: + + - 1 indicates the head is **not masked**, + - 0 indicates the head is **masked**. + + past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`): + Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of shape + `(batch_size, num_heads, sequence_length, embed_size_per_head)`) and 2 additional tensors of shape + `(batch_size, num_heads, encoder_sequence_length, embed_size_per_head)`. + + Contains pre-computed hidden-states (key and values in the self-attention blocks and in the cross-attention + blocks) that can be used (see `past_key_values` input) to speed up sequential decoding. + + If `past_key_values` are used, the user can optionally input only the last `decoder_input_ids` (those that + don't have their past key value states given to this model) of shape `(batch_size, 1)` instead of all + `decoder_input_ids` of shape `(batch_size, sequence_length)`. + inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*): + Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. + This is useful if you want more control over how to convert `input_ids` indices into associated vectors + than the model's internal embedding lookup matrix. + output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned + tensors for more detail. + output_hidden_states (`bool`, *optional*): + Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for + more detail. + return_dict (`bool`, *optional*): + Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. +""" + + +def load_balancing_loss_func( + gate_logits: torch.Tensor, + num_experts: torch.Tensor = None, + top_k=2, + attention_mask: Optional[torch.Tensor] = None, +) -> float: + r""" + Computes auxiliary load balancing loss as in Switch Transformer - implemented in Pytorch. + + See Switch Transformer (https://arxiv.org/abs/2101.03961) for more details. This function implements the loss + function presented in equations (4) - (6) of the paper. It aims at penalizing cases where the routing between + experts is too unbalanced. + + Args: + gate_logits (Union[`torch.Tensor`, Tuple[torch.Tensor]): + Logits from the `gate`, should be a tuple of model.config.num_hidden_layers tensors of + shape [batch_size X sequence_length, num_experts]. + attention_mask (`torch.Tensor`, None): + The attention_mask used in forward function + shape [batch_size X sequence_length] if not None. + num_experts (`int`, *optional*): + Number of experts + + Returns: + The auxiliary loss. 
+ """ + if gate_logits is None or not isinstance(gate_logits, tuple): + return torch.zeros(1) + + if isinstance(gate_logits, tuple): + compute_device = gate_logits[0].device + concatenated_gate_logits = torch.cat( + [layer_gate.to(compute_device) for layer_gate in gate_logits], dim=0 + ) + + routing_weights = torch.nn.functional.softmax(concatenated_gate_logits, dim=-1) + + _, selected_experts = torch.topk(routing_weights, top_k, dim=-1) + + expert_mask = torch.nn.functional.one_hot(selected_experts, num_experts) + + if attention_mask is None: + # Compute the percentage of tokens routed to each experts + tokens_per_expert = torch.mean(expert_mask.float(), dim=0) + + # Compute the average probability of routing to these experts + router_prob_per_expert = torch.mean(routing_weights, dim=0) + else: + batch_size, sequence_length = attention_mask.shape + num_hidden_layers = concatenated_gate_logits.shape[0] // ( + batch_size * sequence_length + ) + + # Compute the mask that masks all padding tokens as 0 with the same shape of expert_mask + expert_attention_mask = ( + attention_mask[None, :, :, None, None] + .expand( + (num_hidden_layers, batch_size, sequence_length, top_k, num_experts) + ) + .reshape(-1, top_k, num_experts) + .to(compute_device) + ) + + # Compute the percentage of tokens routed to each experts + tokens_per_expert = torch.sum( + expert_mask.float() * expert_attention_mask, dim=0 + ) / torch.sum(expert_attention_mask, dim=0) + + # Compute the mask that masks all padding tokens as 0 with the same shape of tokens_per_expert + router_per_expert_attention_mask = ( + attention_mask[None, :, :, None] + .expand((num_hidden_layers, batch_size, sequence_length, num_experts)) + .reshape(-1, num_experts) + .to(compute_device) + ) + + # Compute the average probability of routing to these experts + router_prob_per_expert = torch.sum( + routing_weights * router_per_expert_attention_mask, dim=0 + ) / torch.sum(router_per_expert_attention_mask, dim=0) + + overall_loss = torch.sum(tokens_per_expert * router_prob_per_expert.unsqueeze(0)) + return overall_loss * num_experts + + +class XGLMMLP(nn.Module): + + def __init__(self, config: XGLMConfig): + super().__init__() + self.ffn_dim = config.ffn_dim + self.hidden_dim = config.d_model + + self.fc1 = nn.Linear(self.hidden_dim, self.ffn_dim, bias=False) + self.fc2 = nn.Linear(self.ffn_dim, self.hidden_dim, bias=False) + + self.activation_fn = ACT2FN[config.activation_function] + self.dropout = config.dropout + self.activation_dropout = config.activation_dropout + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + hidden_states = self.activation_fn(self.fc1(hidden_states)) + hidden_states = nn.functional.dropout( + hidden_states, p=self.activation_dropout, training=self.training + ) + hidden_states = self.fc2(hidden_states) + # TODO this dropout can be removed if we use the dropout in the XGLMDecoderLayer + hidden_states = nn.functional.dropout( + hidden_states, p=self.dropout, training=self.training + ) + return hidden_states + + def set_mlp_weights(self, fc1, fc2, std=0.02): + """Set the weights of the MLP to the given weights""" + + self.fc1.weight.data = fc1.weight.data.clone() + # norm1_pre = self.fc1.weight.data.norm(2) + self.fc1.weight.data.add_( + torch.normal(mean=0, std=std, size=self.fc1.weight.shape) + ) + # norm1_after = self.fc1.weight.data.norm(2) + + self.fc1.bias.data = fc1.bias.data.clone() + self.fc1.bias.data.add_(torch.normal(mean=0, std=std, size=self.fc1.bias.shape)) + + self.fc2.weight.data = 
fc2.weight.data.clone() + # norm2_pre = self.fc2.weight.data.norm(2) + self.fc2.weight.data.add_( + torch.normal(mean=0, std=std, size=self.fc2.weight.shape) + ) + # norm2_after = self.fc2.weight.data.norm(2) + self.fc2.bias.data = fc2.bias.data.clone() + self.fc2.bias.data.add_(torch.normal(mean=0, std=std, size=self.fc2.bias.shape)) + + # print("********") + # print(f"Norm1: {norm1_pre} -> {norm1_after}") + # print(f"Norm2: {norm2_pre} -> {norm2_after}") + # print("********") + + +class XGLMSparseMoeBlock(nn.Module): + """ + This implementation is + strictly equivalent to standard MoE with full capacity (no + dropped tokens). It's faster since it formulates MoE operations + in terms of block-sparse operations to accomodate imbalanced + assignments of tokens to experts, whereas standard MoE either + (1) drop tokens at the cost of reduced performance or (2) set + capacity factor to number of experts and thus waste computation + and memory on padding. + """ + + def __init__(self, config): + super().__init__() + self.hidden_dim = config.hidden_size + # self.ffn_dim = config.intermediate_size + self.num_experts = config.num_local_experts + self.top_k = config.num_experts_per_tok + + # gating network + if config.gate_type == "linear": + self.gate = BasicGate(config) + elif config.gate_type == "lang_aware_linear": + self.gate = LanguageAwareGate(config) + # self.gate = nn.Linear(self.hidden_dim, self.num_experts, bias=False) + + self.experts = nn.ModuleList([XGLMMLP(config) for _ in range(self.num_experts)]) + + def forward(self, hidden_states: torch.Tensor, lang_name: str) -> torch.Tensor: + """ """ + batch_size, sequence_length, hidden_dim = hidden_states.shape + hidden_states = hidden_states.view(-1, hidden_dim) + # router_logits: (batch * sequence_length, n_experts) + router_logits = self.gate(hidden_states, lang_name) + + routing_weights = F.softmax(router_logits, dim=1, dtype=torch.float) + if self.top_k == 1: + routing_weights, selected_experts = routing_weights.max(dim=-1, keepdim=True) + else: + routing_weights, selected_experts = torch.topk( + routing_weights, self.top_k, dim=-1 + ) + routing_weights /= routing_weights.sum(dim=-1, keepdim=True) + # we cast back to the input dtype + routing_weights = routing_weights.to(hidden_states.dtype) + + final_hidden_states = torch.zeros( + (batch_size * sequence_length, hidden_dim), + dtype=hidden_states.dtype, + device=hidden_states.device, + ) + + # One hot encode the selected experts to create an expert mask + # this will be used to easily index which expert is going to be sollicitated + expert_mask = torch.nn.functional.one_hot( + selected_experts, num_classes=self.num_experts + ).permute(2, 1, 0) + + # Loop over all available experts in the model and perform the computation on each expert + for expert_idx in range(self.num_experts): + expert_layer = self.experts[expert_idx] + idx, top_x = torch.where(expert_mask[expert_idx]) + + if top_x.shape[0] == 0: + continue + + # in torch it is faster to index using lists than torch tensors + top_x_list = top_x.tolist() + idx_list = idx.tolist() + + # Index the correct hidden states and compute the expert hidden state for + # the current expert. 
We need to make sure to multiply the output hidden + # states by `routing_weights` on the corresponding tokens (top-1 and top-2) + current_state = hidden_states[None, top_x_list].reshape(-1, hidden_dim) + current_hidden_states = ( + expert_layer(current_state) + * routing_weights[top_x_list, idx_list, None] + ) + + # However `index_add_` only support torch tensors for indexing so we'll use + # the `top_x` tensor here. + final_hidden_states.index_add_( + 0, top_x, current_hidden_states.to(hidden_states.dtype) + ) + final_hidden_states = final_hidden_states.reshape( + batch_size, sequence_length, hidden_dim + ) + return final_hidden_states, router_logits + + def set_mlp_weights(self, fc1, fc2, std=0.02): + """Set the expert (MLP) weights by given weights""" + for i in range(self.num_experts): + self.experts[i].set_mlp_weights(fc1, fc2, std) + + +class XGLMDecoderLayer(nn.Module): + def __init__(self, config: XGLMConfig): + super().__init__() + self.embed_dim = config.d_model + + self.self_attn = XGLMAttention( + embed_dim=self.embed_dim, + num_heads=config.attention_heads, + dropout=config.attention_dropout, + is_decoder=True, + ) + self.dropout = config.dropout + # self.activation_fn = ACT2FN[config.activation_function] + # self.activation_dropout = config.activation_dropout + + if config.add_cross_attention: + self.encoder_attn = XGLMAttention( + embed_dim=self.embed_dim, + num_heads=config.attention_heads, + dropout=config.attention_dropout, + is_decoder=True, + ) + self.encoder_attn_layer_norm = nn.LayerNorm(self.embed_dim) + + self.block_sparse_moe = XGLMSparseMoeBlock(config) + self.self_attn_layer_norm = nn.LayerNorm(self.embed_dim) + self.final_layer_norm = nn.LayerNorm(self.embed_dim) + + # Copied from transformers.models.mbart.modeling_mbart.MBartDecoderLayer.forward + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + encoder_hidden_states: Optional[torch.Tensor] = None, + encoder_attention_mask: Optional[torch.Tensor] = None, + layer_head_mask: Optional[torch.Tensor] = None, + cross_attn_layer_head_mask: Optional[torch.Tensor] = None, + past_key_value: Optional[Tuple[torch.Tensor]] = None, + output_attentions: Optional[bool] = False, + output_router_logits: Optional[bool] = False, + use_cache: Optional[bool] = True, + lang_name: Optional[str] = None, + ) -> torch.Tensor: + """ + Args: + hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)` + attention_mask (`torch.FloatTensor`): attention mask of size + `(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values. + encoder_hidden_states (`torch.FloatTensor`): + cross attention input to the layer of shape `(batch, seq_len, embed_dim)` + encoder_attention_mask (`torch.FloatTensor`): encoder attention mask of size + `(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values. + layer_head_mask (`torch.FloatTensor`): mask for attention heads in a given layer of size + `(encoder_attention_heads,)`. + cross_attn_layer_head_mask (`torch.FloatTensor`): mask for cross-attention heads in a given layer of + size `(decoder_attention_heads,)`. + past_key_value (`Tuple(torch.FloatTensor)`): cached past key and value projection states + output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under + returned tensors for more detail. 
+ """ + residual = hidden_states + hidden_states = self.self_attn_layer_norm(hidden_states) + + # Self Attention + # decoder uni-directional self-attention cached key/values tuple is at positions 1,2 + self_attn_past_key_value = ( + past_key_value[:2] if past_key_value is not None else None + ) + # add present self-attn cache to positions 1,2 of present_key_value tuple + hidden_states, self_attn_weights, present_key_value = self.self_attn( + hidden_states=hidden_states, + past_key_value=self_attn_past_key_value, + attention_mask=attention_mask, + layer_head_mask=layer_head_mask, + output_attentions=output_attentions, + ) + hidden_states = nn.functional.dropout( + hidden_states, p=self.dropout, training=self.training + ) + hidden_states = residual + hidden_states + + # Cross-Attention Block + cross_attn_present_key_value = None + cross_attn_weights = None + if encoder_hidden_states is not None: + residual = hidden_states + hidden_states = self.encoder_attn_layer_norm(hidden_states) + + # cross_attn cached key/values tuple is at positions 3,4 of present_key_value tuple + cross_attn_past_key_value = ( + past_key_value[-2:] if past_key_value is not None else None + ) + hidden_states, cross_attn_weights, cross_attn_present_key_value = ( + self.encoder_attn( + hidden_states=hidden_states, + key_value_states=encoder_hidden_states, + attention_mask=encoder_attention_mask, + layer_head_mask=cross_attn_layer_head_mask, + past_key_value=cross_attn_past_key_value, + output_attentions=output_attentions, + ) + ) + hidden_states = nn.functional.dropout( + hidden_states, p=self.dropout, training=self.training + ) + hidden_states = residual + hidden_states + + # add cross-attn to positions 3,4 of present_key_value tuple + present_key_value = present_key_value + cross_attn_present_key_value + + # Fully Connected + residual = hidden_states + hidden_states = self.final_layer_norm(hidden_states) + + hidden_states, router_logits = self.block_sparse_moe(hidden_states, lang_name) + # hidden_states = self.activation_fn(self.fc1(hidden_states)) + # hidden_states = nn.functional.dropout(hidden_states, p=self.activation_dropout, training=self.training) + # hidden_states = self.fc2(hidden_states) + # hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training) + hidden_states = residual + hidden_states + + outputs = (hidden_states,) + + if output_attentions: + outputs += (self_attn_weights, cross_attn_weights) + + if use_cache: + outputs += (present_key_value,) + + if output_router_logits: + outputs += (router_logits,) + + return outputs + + +@add_start_docstrings( + "The bare XGLM Model transformer outputting raw hidden-states without any specific head on top.", + XGLM_START_DOCSTRING, +) +class XGLMModel(XGLMPreTrainedModel): + """ + Transformer decoder consisting of *config.num_layers* layers. 
Each layer is a [`XGLMDecoderLayer`] + + Args: + config: XGLMConfig + embed_tokens (nn.Embedding): output embedding + """ + + def __init__(self, config: XGLMConfig, embed_tokens: Optional[nn.Embedding] = None): + super().__init__(config) + self.dropout = config.dropout + self.layerdrop = config.layerdrop + self.padding_idx = config.pad_token_id + self.max_target_positions = config.max_position_embeddings + self.embed_scale = math.sqrt(config.d_model) if config.scale_embedding else 1.0 + + if embed_tokens is not None: + self.embed_tokens = embed_tokens + else: + self.embed_tokens = nn.Embedding( + config.vocab_size, config.d_model, self.padding_idx + ) + + self.embed_positions = XGLMSinusoidalPositionalEmbedding( + config.max_position_embeddings, + config.d_model, + config.pad_token_id, + ) + self.layers = nn.ModuleList( + [XGLMDecoderLayer(config) for _ in range(config.num_layers)] + ) + self.layer_norm = nn.LayerNorm(config.d_model) + + self.gradient_checkpointing = False + # Initialize weights and apply final processing + self.post_init() + + def get_input_embeddings(self): + return self.embed_tokens + + def set_input_embeddings(self, value): + self.embed_tokens = value + + @add_start_docstrings_to_model_forward(XGLM_INPUTS_DOCSTRING) + @add_code_sample_docstrings( + checkpoint=_CHECKPOINT_FOR_DOC, + output_type=MoEModelOutputWithPast, + config_class=_CONFIG_FOR_DOC, + ) + def forward( + self, + input_ids: Optional[torch.Tensor] = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.Tensor] = None, + encoder_hidden_states: Optional[torch.Tensor] = None, + encoder_attention_mask: Optional[torch.Tensor] = None, + head_mask: Optional[torch.Tensor] = None, + cross_attn_head_mask: Optional[torch.Tensor] = None, + past_key_values: Optional[List[torch.FloatTensor]] = None, + inputs_embeds: Optional[torch.Tensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + output_router_logits: Optional[bool] = None, + return_dict: Optional[bool] = None, + lang_name: str = None, + ) -> Union[Tuple[torch.Tensor], MoEModelOutputWithPast]: + output_attentions = ( + output_attentions + if output_attentions is not None + else self.config.output_attentions + ) + output_router_logits = ( + output_router_logits + if output_router_logits is not None + else self.config.output_router_logits + ) + + output_hidden_states = ( + output_hidden_states + if output_hidden_states is not None + else self.config.output_hidden_states + ) + use_cache = use_cache if use_cache is not None else self.config.use_cache + return_dict = ( + return_dict if return_dict is not None else self.config.use_return_dict + ) + + # retrieve input_ids and inputs_embeds + if input_ids is not None and inputs_embeds is not None: + raise ValueError( + "You cannot specify both input_ids and inputs_embeds at the same time" + ) + elif input_ids is not None: + self.warn_if_padding_and_no_attention_mask(input_ids, attention_mask) + input_shape = input_ids.size() + input_ids = input_ids.view(-1, input_shape[-1]) + elif inputs_embeds is not None: + input_shape = inputs_embeds.size()[:-1] + else: + raise ValueError("You have to specify either input_ids or inputs_embeds") + + past_key_values_length = ( + past_key_values[0][0].shape[2] if past_key_values is not None else 0 + ) + + if position_ids is None: + position_ids = torch.arange( + past_key_values_length, + input_shape[-1] + past_key_values_length, + dtype=torch.long, + device=( + 
input_ids.device if input_ids is not None else inputs_embeds.device + ), + ) + position_ids = position_ids.unsqueeze(0) + + if inputs_embeds is None: + inputs_embeds = self.embed_tokens(input_ids) * self.embed_scale + + attention_mask = _prepare_4d_causal_attention_mask( + attention_mask, input_shape, inputs_embeds, past_key_values_length + ) + + # expand encoder attention mask + if encoder_hidden_states is not None and encoder_attention_mask is not None: + # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] + encoder_attention_mask = _prepare_4d_attention_mask( + encoder_attention_mask, inputs_embeds.dtype, tgt_len=input_shape[-1] + ) + + hidden_states = inputs_embeds + self.embed_positions( + position_ids, past_key_values_length + ) + hidden_states = nn.functional.dropout( + hidden_states, p=float(self.dropout), training=self.training + ) + + if self.gradient_checkpointing and self.training: + if use_cache: + logger.warning_once( + "`use_cache = True` is incompatible with gradient checkpointing`. Setting `use_cache =" + " False`..." + ) + use_cache = False + + # decoder layers + all_hidden_states = () if output_hidden_states else None + all_self_attns = () if output_attentions else None + all_router_logits = () if output_router_logits else None + all_cross_attentions = ( + () if (output_attentions and encoder_hidden_states is not None) else None + ) + next_decoder_cache = () if use_cache else None + + # check if head_mask/cross_attn_head_mask has a correct number of layers specified if desired + for attn_mask, mask_name in zip( + [head_mask, cross_attn_head_mask], ["head_mask", "cross_attn_head_mask"] + ): + if attn_mask is not None: + if attn_mask.size()[0] != len(self.layers): + raise ValueError( + f"The `{mask_name}` should be specified for {len(self.layers)} layers, but it is for" + f" {head_mask.size()[0]}." 
+ ) + for idx, decoder_layer in enumerate(self.layers): + # add LayerDrop (see https://arxiv.org/abs/1909.11556 for description) + if output_hidden_states: + all_hidden_states += (hidden_states,) + if self.training: + dropout_probability = torch.rand([]) + if dropout_probability < self.layerdrop: + continue + + past_key_value = ( + past_key_values[idx] if past_key_values is not None else None + ) + + if self.gradient_checkpointing and self.training: + layer_outputs = self._gradient_checkpointing_func( + decoder_layer.__call__, + hidden_states, + attention_mask, + encoder_hidden_states, + encoder_attention_mask, + head_mask[idx] if head_mask is not None else None, + ( + cross_attn_head_mask[idx] + if cross_attn_head_mask is not None + else None + ), + None, + output_attentions, + output_router_logits, + use_cache, + ) + else: + layer_outputs = decoder_layer( + hidden_states, + attention_mask=attention_mask, + encoder_hidden_states=encoder_hidden_states, + encoder_attention_mask=encoder_attention_mask, + layer_head_mask=(head_mask[idx] if head_mask is not None else None), + cross_attn_layer_head_mask=( + cross_attn_head_mask[idx] + if cross_attn_head_mask is not None + else None + ), + past_key_value=past_key_value, + output_attentions=output_attentions, + use_cache=use_cache, + output_router_logits=output_router_logits, + lang_name=lang_name, + ) + hidden_states = layer_outputs[0] + + if use_cache: + next_decoder_cache += (layer_outputs[3 if output_attentions else 1],) + + if output_attentions: + all_self_attns += (layer_outputs[1],) + + if encoder_hidden_states is not None: + all_cross_attentions += (layer_outputs[2],) + + if output_router_logits: + all_router_logits += (layer_outputs[-1],) + + hidden_states = self.layer_norm(hidden_states) + + # add hidden states from the last decoder layer + if output_hidden_states: + all_hidden_states += (hidden_states,) + + next_cache = next_decoder_cache if use_cache else None + if not return_dict: + return tuple( + v + for v in [ + hidden_states, + next_cache, + all_hidden_states, + all_self_attns, + all_cross_attentions, + ] + if v is not None + ) + return MoEModelOutputWithPast( + last_hidden_state=hidden_states, + past_key_values=next_cache, + hidden_states=all_hidden_states, + attentions=all_self_attns, + cross_attentions=all_cross_attentions, + router_logits=all_router_logits, + ) + + +@add_start_docstrings( + """ + The XGLM Model transformer with a language modeling head on top (linear layer with weights tied to the input + embeddings). 
+ """, + XGLM_START_DOCSTRING, +) +class XGLMForCausalLM(XGLMPreTrainedModel): + base_model_prefix = "model" + _tied_weights_keys = ["lm_head.weight"] + + def __init__(self, config): + super().__init__(config) + self.model = XGLMModel(config) + self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False) + + self.router_aux_loss_coef = config.router_aux_loss_coef + self.num_local_experts = config.num_local_experts + self.num_experts_per_tok = config.num_experts_per_tok + + # Initialize weights and apply final processing + self.post_init() + + def get_input_embeddings(self): + return self.model.embed_tokens + + def set_input_embeddings(self, value): + self.model.embed_tokens = value + + def get_output_embeddings(self): + return self.lm_head + + def set_output_embeddings(self, new_embeddings): + self.lm_head = new_embeddings + + @add_start_docstrings_to_model_forward(XGLM_INPUTS_DOCSTRING) + @add_code_sample_docstrings( + checkpoint=_CHECKPOINT_FOR_DOC, + output_type=MoECausalLMOutputWithPast, + config_class=_CONFIG_FOR_DOC, + ) + def forward( + self, + input_ids: Optional[torch.Tensor] = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.Tensor] = None, + encoder_hidden_states: Optional[torch.Tensor] = None, + encoder_attention_mask: Optional[torch.Tensor] = None, + head_mask: Optional[torch.Tensor] = None, + cross_attn_head_mask: Optional[torch.Tensor] = None, + past_key_values: Optional[List[torch.FloatTensor]] = None, + inputs_embeds: Optional[torch.Tensor] = None, + labels: Optional[torch.Tensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + output_router_logits: Optional[bool] = None, + return_dict: Optional[bool] = None, + lang_name=None, + ) -> Union[Tuple[torch.Tensor], MoECausalLMOutputWithPast]: + r""" + labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): + Labels for computing the masked language modeling loss. Indices should either be in `[0, ..., + config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored + (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`. 
+ """ + + output_attentions = ( + output_attentions + if output_attentions is not None + else self.config.output_attentions + ) + output_router_logits = ( + output_router_logits + if output_router_logits is not None + else self.config.output_router_logits + ) + + output_hidden_states = ( + output_hidden_states + if output_hidden_states is not None + else self.config.output_hidden_states + ) + return_dict = ( + return_dict if return_dict is not None else self.config.use_return_dict + ) + + # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn) + outputs = self.model( + input_ids=input_ids, + attention_mask=attention_mask, + position_ids=position_ids, + encoder_hidden_states=encoder_hidden_states, + encoder_attention_mask=encoder_attention_mask, + head_mask=head_mask, + cross_attn_head_mask=cross_attn_head_mask, + past_key_values=past_key_values, + inputs_embeds=inputs_embeds, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + output_router_logits=output_router_logits, + return_dict=return_dict, + lang_name=lang_name, + ) + + logits = self.lm_head(outputs[0]) + + loss = None + if labels is not None: + # shift labels and add a pad token to the end + shift_labels = labels.new_zeros(labels.shape) + shift_labels[:, :-1] = labels[:, 1:].clone() + shift_labels[:, -1] = self.config.pad_token_id + + loss_fct = CrossEntropyLoss() + loss = loss_fct( + logits.view(-1, self.config.vocab_size), shift_labels.view(-1) + ) + + aux_loss = None + if output_router_logits: + aux_loss = load_balancing_loss_func( + outputs.router_logits if return_dict else outputs[-1], + self.num_local_experts, + self.num_experts_per_tok, + attention_mask, + ) + if labels is not None: + loss += self.router_aux_loss_coef * aux_loss.to( + loss.device + ) # make sure to reside in the same device + + if not return_dict: + output = (logits,) + outputs[1:] + return (loss,) + output if loss is not None else output + + # return CausalLMOutputWithCrossAttentions( + return MoECausalLMOutputWithPast( + loss=loss, + aux_loss=aux_loss, + logits=logits, + past_key_values=outputs.past_key_values, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + cross_attentions=outputs.cross_attentions, + router_logits=outputs.router_logits, + ) + + def prepare_inputs_for_generation( + self, + input_ids, + past_key_values=None, + attention_mask=None, + use_cache=None, + output_router_logits=False, + **kwargs, + ): + if past_key_values is not None: + past_length = past_key_values[0][0].shape[2] + + # Some generation methods already pass only the last input ID + if input_ids.shape[1] > past_length: + remove_prefix_length = past_length + else: + # Default to old behavior: keep only final ID + remove_prefix_length = input_ids.shape[1] - 1 + + input_ids = input_ids[:, remove_prefix_length:] + + position_ids = kwargs.get("position_ids", None) + if attention_mask is not None and position_ids is None: + # create position_ids on the fly for batch generation + position_ids = attention_mask.long().cumsum(-1) - 1 + position_ids.masked_fill_(attention_mask == 0, 1) + if past_key_values: + position_ids = position_ids[:, -input_ids.shape[1] :] + else: + position_ids = None + # if model is used as a decoder in encoder-decoder model, the decoder attention mask is created on the fly + if attention_mask is None: + attention_mask = input_ids.new_ones(input_ids.shape) + # first step, decoder_cached_states are empty + return { + "input_ids": input_ids, # encoder_outputs is 
defined. input_ids not needed
+            "attention_mask": attention_mask,
+            "position_ids": position_ids,
+            "past_key_values": past_key_values,
+            "use_cache": use_cache,
+            "output_router_logits": output_router_logits,
+        }
+
+    @staticmethod
+    def _reorder_cache(past_key_values, beam_idx):
+        reordered_past = ()
+        for layer_past in past_key_values:
+            reordered_past += (
+                tuple(
+                    past_state.index_select(0, beam_idx.to(past_state.device))
+                    for past_state in layer_past
+                ),
+            )
+        return reordered_past
+
+
+def set_moe_layers(model: nn.Module, config: XGLMConfig):
+    """Replaces each of the model's original decoder layers with an MoE layer, and sets the expert (MLP) weights from the original MLP weights."""
+    new_layers = []
+    for i in range(config.num_layers):
+        layer = model.model.layers[i]
+        moe_layer = XGLMDecoderLayer(config)
+        moe_layer.block_sparse_moe.set_mlp_weights(
+            layer.fc1, layer.fc2, config.expert_init_std
+        )
+        new_layers.append(moe_layer)
+    model.model.layers = nn.ModuleList(new_layers)
+
+
+def freeze_parameters(model, args):
+    """Freezes non-MoE parameters if `args.freeze_non_moe_params` is set."""
+    if args.freeze_non_moe_params:
+        for name, param in model.named_parameters():
+            name = name.lower()
+            if "expert" not in name and "gate" not in name and "lm_head" not in name:
+                param.requires_grad = False
+            else:
+                param.requires_grad = True
+
+        logger.info("non-MoE parameters are frozen (experts, gates and lm_head stay trainable)!")
+
+        model.lm_head.weight.requires_grad = True
+        # model.lm_head.bias.requires_grad = True
+
+    # for name, param in model.named_parameters():
+    #     if "block_sparse_moe" not in name:
+    #         print(f"{name}: {param.requires_grad}")
+
+
+def copy_model_weights(pretrained_model: nn.Module, moe_model: XGLMForCausalLM, config: XGLMConfig):
+    """Copies the pretrained model's weights into the MoE model and initializes each expert (MLP) from the original MLP weights."""
+    state_dicts = pretrained_model.state_dict()
+    moe_model.load_state_dict(state_dicts, strict=False)
+
+    # new_layers = []
+    for i in range(config.num_layers):
+        layer = pretrained_model.model.layers[i]
+        moe_model.model.layers[i].block_sparse_moe.set_mlp_weights(
+            layer.fc1, layer.fc2, config.init_std
+        )
+        # moe_layer = XGLMDecoderLayer(config)
+        # moe_layer.block_sparse_moe.set_mlp_weights(layer.fc1, layer.fc2, config.init_std)
+        # new_layers.append(moe_layer)
+    # moe_model.model.layers = nn.ModuleList(new_layers)
diff --git a/src/nanotron/models/moe.py b/src/nanotron/models/moe.py
index f7bb07bd..98add57d 100644
--- a/src/nanotron/models/moe.py
+++ b/src/nanotron/models/moe.py
@@ -718,4 +718,4 @@ def forward(self, x, topo):
         x1 = self.sdd(x, self.w1.module.weight, topo)
         x2 = self.sdd(x, self.w3.module.weight, topo)
         x = stk.ops.mul(act_fn(x1, self.act), x2)
-        return self.dsd(x, self.w2.module.weight)
\ No newline at end of file
+        return self.dsd(x, self.w2.module.weight)
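
For a quick end-to-end sanity check of the MoE classes added above, something along the following lines should work. This is a minimal sketch rather than part of the patch: it assumes the repository root is on `PYTHONPATH`, that `transformers` and `accelerate` are installed, and the model sizes below are illustrative placeholders, not real XGLM-MoE settings.

```python
# Smoke test for the MoE XGLM implementation (illustrative sizes, CPU-only).
import torch

from examples.xglm.transformers_impl.xglm_model import XGLMForCausalLM, XGLMmoeConfig

config = XGLMmoeConfig(
    d_model=512, ffn_dim=2048, num_layers=2, attention_heads=8, vocab_size=2048,
    num_local_experts=4, num_experts_per_tok=2,
    gate_type="linear", gate_depth=1, router_aux_loss_coef=0.01,
)
model = XGLMForCausalLM(config).eval()

input_ids = torch.randint(0, config.vocab_size, (2, 16))
with torch.no_grad():
    out = model(input_ids=input_ids, output_router_logits=True)

print(out.logits.shape)        # (2, 16, vocab_size)
print(len(out.router_logits))  # one router-logit tensor per decoder layer
```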