Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

jamba liger fused linear+xentropy #102

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
168 changes: 168 additions & 0 deletions src/liger_kernel/transformers/model/jamba.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,168 @@
from typing import Optional, Tuple, Union

import torch
from liger_kernel.transformers.fused_linear_cross_entropy import (
LigerFusedLinearCrossEntropyLoss,
)
from torch.nn import CrossEntropyLoss
from transformers.modeling_outputs import MoeCausalLMOutputWithPast
from transformers.models.jamba.modeling_jamba import (
_CONFIG_FOR_DOC,
JAMBA_INPUTS_DOCSTRING,
HybridMambaAttentionDynamicCache,
load_balancing_loss_func,
)
from transformers.utils import (
add_start_docstrings_to_model_forward,
replace_return_docstrings,
)


@add_start_docstrings_to_model_forward(JAMBA_INPUTS_DOCSTRING)
@replace_return_docstrings(
output_type=MoeCausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC
)
def lce_forward(
self,
input_ids: torch.LongTensor = None,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_values: Optional[HybridMambaAttentionDynamicCache] = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
labels: Optional[torch.LongTensor] = None,
use_cache: Optional[bool] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
output_router_logits: Optional[bool] = None,
return_dict: Optional[bool] = None,
cache_position: Optional[torch.LongTensor] = None,
num_logits_to_keep: Optional[Union[int, None]] = None,
) -> Union[Tuple, MoeCausalLMOutputWithPast]:
r"""
Args:
labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
(masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.

num_logits_to_keep (`int` or `None`, *optional*):
Calculate logits for the last `num_logits_to_keep` tokens. If `None`, calculate logits for all
`input_ids`. Only last token logits are needed for generation, and calculating them only for that token
can save memory, which becomes pretty significant for long sequences.

Returns:

Example:

```python
>>> from transformers import AutoTokenizer, JambaForCausalLM

>>> model = JambaForCausalLM.from_pretrained("ai21labs/Jamba-v0.1")
>>> tokenizer = AutoTokenizer.from_pretrained("ai21labs/Jamba-v0.1")

>>> prompt = "Hey, are you conscious? Can you talk to me?"
>>> inputs = tokenizer(prompt, return_tensors="pt")

>>> # Generate
>>> generate_ids = model.generate(inputs.input_ids, max_length=30)
>>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
"Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you."
```"""

output_attentions = (
output_attentions
if output_attentions is not None
else self.config.output_attentions
)
output_router_logits = (
output_router_logits
if output_router_logits is not None
else self.config.output_router_logits
)

output_hidden_states = (
output_hidden_states
if output_hidden_states is not None
else self.config.output_hidden_states
)
return_dict = (
return_dict if return_dict is not None else self.config.use_return_dict
)

# decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
outputs = self.model(
input_ids=input_ids,
attention_mask=attention_mask,
position_ids=position_ids,
past_key_values=past_key_values,
inputs_embeds=inputs_embeds,
use_cache=use_cache,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
output_router_logits=output_router_logits,
cache_position=cache_position,
return_dict=return_dict,
)

hidden_states = outputs[0]

loss = None
logits = None

if self.training:
shift_hidden_states = hidden_states[..., :-1, :].contiguous()
shift_labels = labels[..., 1:].contiguous()

# flatten tokens
shift_hidden_states = shift_hidden_states.view(-1, self.config.hidden_size)
shift_labels = shift_labels.view(-1)

lce = LigerFusedLinearCrossEntropyLoss()
loss = lce(self.lm_head.weight, shift_hidden_states, shift_labels)
else:
if num_logits_to_keep is None:
logits = self.lm_head(hidden_states)
else:
logits = self.lm_head(hidden_states[..., -num_logits_to_keep:, :])
logits = logits.float()

if labels is not None:
# Shift so that tokens < n predict n
shift_logits = logits[..., :-1, :].contiguous()
shift_labels = labels[..., 1:].contiguous()
# Flatten the tokens
loss_fct = CrossEntropyLoss()
shift_logits = shift_logits.view(-1, self.config.vocab_size)
shift_labels = shift_labels.view(-1)
# Enable model parallelism
shift_labels = shift_labels.to(shift_logits.device)
loss = loss_fct(shift_logits, shift_labels)

aux_loss = None
if output_router_logits:
aux_loss = load_balancing_loss_func(
outputs.router_logits if return_dict else outputs[-1],
self.num_experts,
self.num_experts_per_tok,
attention_mask,
)
if labels is not None:
loss += self.router_aux_loss_coef * aux_loss.to(
loss.device
) # make sure to reside in the same device

if not return_dict:
output = (logits,) + outputs[1:]
if output_router_logits:
output = (aux_loss,) + output
return (loss,) + output if loss is not None else output

return MoeCausalLMOutputWithPast(
loss=loss,
aux_loss=aux_loss,
logits=logits,
past_key_values=outputs.past_key_values,
hidden_states=outputs.hidden_states,
attentions=outputs.attentions,
router_logits=outputs.router_logits,
)
38 changes: 38 additions & 0 deletions src/liger_kernel/transformers/monkey_patch.py
Original file line number Diff line number Diff line change
Expand Up @@ -169,3 +169,41 @@ def apply_liger_kernel_to_qwen2(
modeling_qwen2.Qwen2ForCausalLM.forward = qwen2_lce_forward
if swiglu:
modeling_qwen2.Qwen2MLP = LigerSwiGLUMLP


def apply_liger_kernel_to_jamba(
rope: bool = True,
cross_entropy: bool = False,
fused_linear_cross_entropy: bool = True,
rms_norm: bool = True,
swiglu: bool = True,
) -> None:
"""
Apply Liger kernels to replace original implementation in HuggingFace Jamba models
to make GPU go burrr.

Args:
rope (bool): Whether to apply Liger's rotary position embedding. Default is True.
cross_entropy (bool): Whether to apply Liger's cross entropy loss. Default is False.
fused_linear_cross_entropy (bool):
Whether to apply Liger's fused lienar cross entropy loss. Default is True.
`cross_entropy` and `fused_linear_cross_entropy` cannot both be True.
If `fused_linear_cross_entropy` is True, the logits will not be materialized but more memory efficient.
rms_norm (bool): Whether to apply Liger's RMSNorm. Default is True.
geglu (bool): Whether to apply Liger's GeGLU MLP. Default is True.
"""
assert not (
cross_entropy and fused_linear_cross_entropy
), "cross_entropy and fused_linear_cross_entropy cannot both be True."

from transformers.models.jamba import modeling_jamba

if rope:
modeling_jamba.apply_rotary_pos_emb = liger_rotary_pos_emb
if rms_norm:
# https://github.com/huggingface/transformers/blob/v4.44.2/src/transformers/models/gemma/modeling_gemma.py#L109
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit: the comment is wrong

modeling_jamba.JambaRMSNorm = LigerRMSNorm
if cross_entropy:
modeling_jamba.CrossEntropyLoss = LigerCrossEntropyLoss
if swiglu:
modeling_jamba.JambaMLP = LigerSwiGLUMLP
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

where is lce_forward?

2 changes: 2 additions & 0 deletions src/liger_kernel/transformers/trainer_integration.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

from liger_kernel.transformers.monkey_patch import (
apply_liger_kernel_to_gemma,
apply_liger_kernel_to_jamba,
apply_liger_kernel_to_llama,
apply_liger_kernel_to_mistral,
apply_liger_kernel_to_mixtral,
Expand All @@ -12,6 +13,7 @@
# Model type corresponds to the keys defined in transformers/models/auto/modeling_auto.py
MODEL_TYPE_TO_APPLY_LIGER_FN = {
"gemma": apply_liger_kernel_to_gemma,
"jamba": apply_liger_kernel_to_jamba,
"llama": apply_liger_kernel_to_llama,
"mistral": apply_liger_kernel_to_mistral,
"mixtral": apply_liger_kernel_to_mixtral,
Expand Down