diff --git a/src/nanotron/models/llama_sft.py b/src/nanotron/models/llama_sft.py
index f66555f9..2cd12eb7 100644
--- a/src/nanotron/models/llama_sft.py
+++ b/src/nanotron/models/llama_sft.py
@@ -26,7 +26,6 @@
 from nanotron.config.models_config import RandomInit, SpectralMupInit
 from nanotron.generation.generate_store import AttachableStore
 from nanotron.kernels.rope import liger_rotary_pos_emb
-from nanotron.kernels.swiglu import LigerSiLUMulFunction
 from nanotron.logging import log_rank
 from nanotron.models import NanotronModel
 from nanotron.nn.activations import ACT2FN
@@ -157,24 +156,19 @@ def __init__(
         tp_linear_async_communication = (
             parallel_config.tp_linear_async_communication if parallel_config is not None else False
         )
-        self.gate_proj = TensorParallelColumnLinear(
-            config.hidden_size,
-            config.intermediate_size,
-            pg=tp_pg,
-            mode=tp_mode,
-            bias=False,
-            async_communication=tp_linear_async_communication,
+        gate_up_contiguous_chunks = (
+            config.intermediate_size,  # shape of gate_linear
+            config.intermediate_size,  # shape of up_linear
         )
-
-        self.up_proj = TensorParallelColumnLinear(
+        self.gate_up_proj = TensorParallelColumnLinear(
             config.hidden_size,
-            config.intermediate_size,
+            2 * config.intermediate_size,
             pg=tp_pg,
             mode=tp_mode,
             bias=False,
             async_communication=tp_linear_async_communication,
+            contiguous_chunks=gate_up_contiguous_chunks,
         )
-
         self.down_proj = TensorParallelRowLinear(
             config.intermediate_size,
             config.hidden_size,
@@ -183,10 +177,13 @@ def __init__(
             bias=False,
             async_communication=tp_linear_async_communication and tp_mode is TensorParallelLinearMode.REDUCE_SCATTER,
         )
+        # TODO @nouamane: why can't we torch.jit.script GLUActivation?
+        self.split_silu_mul = GLUActivation(config.hidden_act)
 
     def forward(self, hidden_states):  # [seq_length, batch_size, hidden_dim]
-
-        return self.down_proj(LigerSiLUMulFunction.apply(self.gate_proj(hidden_states), self.up_proj(hidden_states)))
+        merged_states = self.gate_up_proj(hidden_states)
+        hidden_states = self.down_proj(self.split_silu_mul(merged_states))
+        return hidden_states
 
 
 class CausalSelfAttention(nn.Module, AttachableStore):
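
The diff replaces the fused `LigerSiLUMulFunction` kernel call with a single merged `gate_up_proj` column-parallel matmul followed by `GLUActivation`, which splits the merged output back into its gate and up halves before the element-wise SwiGLU. Below is a minimal, standalone sketch of that split-activate-multiply step, under stated assumptions: the `SplitSiluMul` class name and the hard-coded SiLU are illustrative (nanotron's `GLUActivation` looks the activation up from `config.hidden_act` via `ACT2FN`), and the sketch omits tensor parallelism entirely.

```python
import torch
from torch import nn


class SplitSiluMul(nn.Module):
    """Sketch of the split_silu_mul step: undo the gate/up fusion, then SwiGLU.

    gate_up_proj produces [..., 2 * intermediate_size]; the first half stands in
    for the old gate_proj output, the second half for the old up_proj output.
    """

    def forward(self, merged_states: torch.Tensor) -> torch.Tensor:
        # Split the last dimension into the gate and up halves.
        gate_states, up_states = torch.chunk(merged_states, 2, dim=-1)
        # silu(gate) * up, which down_proj then maps back to hidden_size.
        return nn.functional.silu(gate_states) * up_states


if __name__ == "__main__":
    # Shapes follow the forward() comment: [seq_length, batch_size, 2 * intermediate_size].
    merged = torch.randn(16, 2, 2 * 64)
    assert SplitSiluMul()(merged).shape == (16, 2, 64)
```

Merging the two projections keeps the math identical while issuing one larger GEMM instead of two, and the `contiguous_chunks` argument keeps each logical sub-matrix contiguous so the merged weight can still be sharded and checkpointed per chunk under tensor parallelism.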