diff --git a/src/liger_kernel/transformers/model/jamba.py b/src/liger_kernel/transformers/model/jamba.py new file mode 100644 index 000000000..a10b4b43e --- /dev/null +++ b/src/liger_kernel/transformers/model/jamba.py @@ -0,0 +1,168 @@ +from typing import Optional, Tuple, Union + +import torch +from liger_kernel.transformers.fused_linear_cross_entropy import ( + LigerFusedLinearCrossEntropyLoss, +) +from torch.nn import CrossEntropyLoss +from transformers.modeling_outputs import MoeCausalLMOutputWithPast +from transformers.models.jamba.modeling_jamba import ( + _CONFIG_FOR_DOC, + JAMBA_INPUTS_DOCSTRING, + HybridMambaAttentionDynamicCache, + load_balancing_loss_func, +) +from transformers.utils import ( + add_start_docstrings_to_model_forward, + replace_return_docstrings, +) + + +@add_start_docstrings_to_model_forward(JAMBA_INPUTS_DOCSTRING) +@replace_return_docstrings( + output_type=MoeCausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC +) +def lce_forward( + self, + input_ids: torch.LongTensor = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[HybridMambaAttentionDynamicCache] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + labels: Optional[torch.LongTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + output_router_logits: Optional[bool] = None, + return_dict: Optional[bool] = None, + cache_position: Optional[torch.LongTensor] = None, + num_logits_to_keep: Optional[Union[int, None]] = None, +) -> Union[Tuple, MoeCausalLMOutputWithPast]: + r""" + Args: + labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): + Labels for computing the masked language modeling loss. Indices should either be in `[0, ..., + config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored + (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`. + + num_logits_to_keep (`int` or `None`, *optional*): + Calculate logits for the last `num_logits_to_keep` tokens. If `None`, calculate logits for all + `input_ids`. Only last token logits are needed for generation, and calculating them only for that token + can save memory, which becomes pretty significant for long sequences. + + Returns: + + Example: + + ```python + >>> from transformers import AutoTokenizer, JambaForCausalLM + + >>> model = JambaForCausalLM.from_pretrained("ai21labs/Jamba-v0.1") + >>> tokenizer = AutoTokenizer.from_pretrained("ai21labs/Jamba-v0.1") + + >>> prompt = "Hey, are you conscious? Can you talk to me?" + >>> inputs = tokenizer(prompt, return_tensors="pt") + + >>> # Generate + >>> generate_ids = model.generate(inputs.input_ids, max_length=30) + >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0] + "Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you." 
+    ```"""
+
+    output_attentions = (
+        output_attentions
+        if output_attentions is not None
+        else self.config.output_attentions
+    )
+    output_router_logits = (
+        output_router_logits
+        if output_router_logits is not None
+        else self.config.output_router_logits
+    )
+
+    output_hidden_states = (
+        output_hidden_states
+        if output_hidden_states is not None
+        else self.config.output_hidden_states
+    )
+    return_dict = (
+        return_dict if return_dict is not None else self.config.use_return_dict
+    )
+
+    # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
+    outputs = self.model(
+        input_ids=input_ids,
+        attention_mask=attention_mask,
+        position_ids=position_ids,
+        past_key_values=past_key_values,
+        inputs_embeds=inputs_embeds,
+        use_cache=use_cache,
+        output_attentions=output_attentions,
+        output_hidden_states=output_hidden_states,
+        output_router_logits=output_router_logits,
+        cache_position=cache_position,
+        return_dict=return_dict,
+    )
+
+    hidden_states = outputs[0]
+
+    loss = None
+    logits = None
+
+    if self.training and labels is not None:
+        shift_hidden_states = hidden_states[..., :-1, :].contiguous()
+        shift_labels = labels[..., 1:].contiguous()
+
+        # flatten tokens
+        shift_hidden_states = shift_hidden_states.view(-1, self.config.hidden_size)
+        shift_labels = shift_labels.view(-1)
+
+        lce = LigerFusedLinearCrossEntropyLoss()
+        loss = lce(self.lm_head.weight, shift_hidden_states, shift_labels)
+    else:
+        if num_logits_to_keep is None:
+            logits = self.lm_head(hidden_states)
+        else:
+            logits = self.lm_head(hidden_states[..., -num_logits_to_keep:, :])
+        logits = logits.float()
+
+        if labels is not None:
+            # Shift so that tokens < n predict n
+            shift_logits = logits[..., :-1, :].contiguous()
+            shift_labels = labels[..., 1:].contiguous()
+            # Flatten the tokens
+            loss_fct = CrossEntropyLoss()
+            shift_logits = shift_logits.view(-1, self.config.vocab_size)
+            shift_labels = shift_labels.view(-1)
+            # Enable model parallelism
+            shift_labels = shift_labels.to(shift_logits.device)
+            loss = loss_fct(shift_logits, shift_labels)
+
+    aux_loss = None
+    if output_router_logits:
+        aux_loss = load_balancing_loss_func(
+            outputs.router_logits if return_dict else outputs[-1],
+            self.num_experts,
+            self.num_experts_per_tok,
+            attention_mask,
+        )
+        if labels is not None:
+            loss += self.router_aux_loss_coef * aux_loss.to(
+                loss.device
+            )  # make sure to reside in the same device
+
+    if not return_dict:
+        output = (logits,) + outputs[1:]
+        if output_router_logits:
+            output = (aux_loss,) + output
+        return (loss,) + output if loss is not None else output
+
+    return MoeCausalLMOutputWithPast(
+        loss=loss,
+        aux_loss=aux_loss,
+        logits=logits,
+        past_key_values=outputs.past_key_values,
+        hidden_states=outputs.hidden_states,
+        attentions=outputs.attentions,
+        router_logits=outputs.router_logits,
+    )
diff --git a/src/liger_kernel/transformers/monkey_patch.py b/src/liger_kernel/transformers/monkey_patch.py
index 375a6a28d..fbb2b147a 100644
--- a/src/liger_kernel/transformers/monkey_patch.py
+++ b/src/liger_kernel/transformers/monkey_patch.py
@@ -169,3 +169,44 @@ def apply_liger_kernel_to_qwen2(
         modeling_qwen2.Qwen2ForCausalLM.forward = qwen2_lce_forward
     if swiglu:
         modeling_qwen2.Qwen2MLP = LigerSwiGLUMLP
+
+
+def apply_liger_kernel_to_jamba(
+    rope: bool = True,
+    cross_entropy: bool = False,
+    fused_linear_cross_entropy: bool = True,
+    rms_norm: bool = True,
+    swiglu: bool = True,
+) -> None:
+    """
+    Apply Liger kernels to replace original implementation in HuggingFace Jamba models
+    to make GPU go burrr.
+
+    Args:
+        rope (bool): Whether to apply Liger's rotary position embedding. Default is True.
+        cross_entropy (bool): Whether to apply Liger's cross entropy loss. Default is False.
+        fused_linear_cross_entropy (bool):
+            Whether to apply Liger's fused linear cross entropy loss. Default is True.
+            `cross_entropy` and `fused_linear_cross_entropy` cannot both be True.
+            If `fused_linear_cross_entropy` is True, the logits will not be materialized, which is more memory efficient.
+        rms_norm (bool): Whether to apply Liger's RMSNorm. Default is True.
+        swiglu (bool): Whether to apply Liger's SwiGLU MLP. Default is True.
+    """
+    assert not (
+        cross_entropy and fused_linear_cross_entropy
+    ), "cross_entropy and fused_linear_cross_entropy cannot both be True."
+
+    from liger_kernel.transformers.model.jamba import lce_forward as jamba_lce_forward
+    from transformers.models.jamba import modeling_jamba
+
+    if rope:
+        modeling_jamba.apply_rotary_pos_emb = liger_rotary_pos_emb
+    if rms_norm:
+        # JambaRMSNorm is a LLaMA-style RMSNorm (no unit offset), so LigerRMSNorm's defaults apply directly
+        modeling_jamba.JambaRMSNorm = LigerRMSNorm
+    if cross_entropy:
+        modeling_jamba.CrossEntropyLoss = LigerCrossEntropyLoss
+    if fused_linear_cross_entropy:
+        modeling_jamba.JambaForCausalLM.forward = jamba_lce_forward
+    if swiglu:
+        modeling_jamba.JambaMLP = LigerSwiGLUMLP
diff --git a/src/liger_kernel/transformers/trainer_integration.py b/src/liger_kernel/transformers/trainer_integration.py
index 4caf03173..00726f077 100644
--- a/src/liger_kernel/transformers/trainer_integration.py
+++ b/src/liger_kernel/transformers/trainer_integration.py
@@ -2,6 +2,7 @@
 from liger_kernel.transformers.monkey_patch import (
     apply_liger_kernel_to_gemma,
+    apply_liger_kernel_to_jamba,
     apply_liger_kernel_to_llama,
     apply_liger_kernel_to_mistral,
     apply_liger_kernel_to_mixtral,
@@ -12,6 +13,7 @@
 # Model type corresponds to the keys defined in transformers/models/auto/modeling_auto.py
 MODEL_TYPE_TO_APPLY_LIGER_FN = {
     "gemma": apply_liger_kernel_to_gemma,
+    "jamba": apply_liger_kernel_to_jamba,
     "llama": apply_liger_kernel_to_llama,
     "mistral": apply_liger_kernel_to_mistral,
     "mixtral": apply_liger_kernel_to_mixtral,
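
Usage note (not part of the patch): a minimal sketch of how the new entry point is expected to be called, assuming the patch is applied before the model class is instantiated and reusing the checkpoint name from the docstring example above.

```python
# Minimal sketch: apply the Jamba monkey patch, then load the model as usual.
# Assumes the import path introduced by this diff (liger_kernel.transformers.monkey_patch).
from liger_kernel.transformers.monkey_patch import apply_liger_kernel_to_jamba
from transformers import AutoTokenizer, JambaForCausalLM

# Patch before instantiating the model so the swapped modules and forward are picked up.
apply_liger_kernel_to_jamba(
    rope=True,
    rms_norm=True,
    swiglu=True,
    fused_linear_cross_entropy=True,  # default; routes JambaForCausalLM.forward to lce_forward
)

model = JambaForCausalLM.from_pretrained("ai21labs/Jamba-v0.1")
tokenizer = AutoTokenizer.from_pretrained("ai21labs/Jamba-v0.1")
```

Note that the fused linear cross entropy path in `lce_forward` only triggers when the model is in training mode and `labels` are passed; inference falls back to the standard logits path, so generation behaves as before.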