
Commit

Revert "Remove transformer_replace" (NVlabs#161)
zhijian-liu authored Aug 14, 2024
1 parent 4acbe97 commit 4956922
Showing 19 changed files with 15,463 additions and 74 deletions.
1 change: 1 addition & 0 deletions environment_setup.sh
@@ -28,4 +28,5 @@ pip install -e ".[eval]"
# Install HF's Transformers
pip install git+https://github.com/huggingface/[email protected]
site_pkg_path=$(python -c 'import site; print(site.getsitepackages()[0])')
cp -rv ./llava/train/transformers_replace/* $site_pkg_path/transformers/
cp -rv ./llava/train/deepspeed_replace/* $site_pkg_path/deepspeed/
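
The two cp -rv lines above overwrite files inside the installed transformers and deepspeed packages with the repository's patched copies. A rough Python equivalent of the transformers copy step, shown only to illustrate the mechanism (the source directory mirrors the script above; everything else is illustrative and not part of the commit):

import shutil
import site
from pathlib import Path

# Locate the active site-packages directory, like the script's $site_pkg_path.
site_pkg_path = Path(site.getsitepackages()[0])

# Patched sources to install; mirrors ./llava/train/transformers_replace in the script.
patched_root = Path("llava/train/transformers_replace")

# Copy every patched file over the installed transformers package (cp -rv equivalent).
for src in patched_root.rglob("*"):
    if src.is_file():
        dst = site_pkg_path / "transformers" / src.relative_to(patched_root)
        dst.parent.mkdir(parents=True, exist_ok=True)
        shutil.copy2(src, dst)
        print(f"{src} -> {dst}")
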
106 changes: 58 additions & 48 deletions llava/model/language_model/llava_llama.py
@@ -15,40 +15,28 @@
# This file is modified from https://github.com/haotian-liu/LLaVA/


import inspect
import os
from functools import partial
from typing import List, Optional, Tuple, Union
from unittest.mock import patch

import torch
from torch.nn import functional as F
from transformers import AutoConfig, AutoModel, PretrainedConfig, PreTrainedModel
from transformers.modeling_outputs import CausalLMOutputWithPast

from llava.utils import distributed as dist

from ...train.utils import calculate_loss_weight
from ..configuration_llava import LlavaConfig
from ..llava_arch import LlavaMetaForCausalLM, LlavaMetaModel


def _get_unpad_data(attention_mask, _seqlens_in_batch, *args, **kwargs):
indices = torch.nonzero(attention_mask.flatten(), as_tuple=False).flatten()
max_seqlen_in_batch = _seqlens_in_batch.max().item()
cu_seqlens = F.pad(torch.cumsum(_seqlens_in_batch, dim=0, dtype=torch.int32), (1, 0))
return (indices, cu_seqlens, max_seqlen_in_batch)
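
For reference, a small standalone run of the helper above on toy tensors (illustrative values only, assuming the imports and _get_unpad_data defined in this file are in scope): the token indices come from the attention mask, while the cumulative and maximum sequence lengths come from the precomputed _seqlens_in_batch.

import torch
from torch.nn import functional as F  # needed by _get_unpad_data if run standalone

# A 2x4 padded batch holding two packed sequences of lengths 3 and 2.
attention_mask = torch.tensor([[1, 1, 1, 0],
                               [1, 1, 0, 0]])
seqlens_in_batch = torch.tensor([3, 2], dtype=torch.int32)

indices, cu_seqlens, max_seqlen = _get_unpad_data(attention_mask, seqlens_in_batch)
print(indices)     # tensor([0, 1, 2, 4, 5]) -- flat positions of the real (non-pad) tokens
print(cu_seqlens)  # tensor([0, 3, 5], dtype=torch.int32) -- cumulative sequence boundaries
print(max_seqlen)  # 3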


class LlavaLlamaConfig(LlavaConfig):
model_type = "llava_llama"


# FIXME we will follow the convention to add a new class for CausalLM in the future
## FIXME we will follow the convention to add a new class for CausalLM in the future
class LlavaLlamaModel(LlavaMetaModel, LlavaMetaForCausalLM, PreTrainedModel):
config_class = LlavaLlamaConfig
main_input_name = "input_embeds"
supports_gradient_checkpointing = True
_supports_flash_attn_2 = True

def __init__(self, config: LlavaLlamaConfig = None, *args, **kwargs) -> None:
super().__init__(config)
@@ -104,15 +92,16 @@ def forward(
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_values: Optional[List[torch.FloatTensor]] = None,
seqlens_in_batch: Optional[torch.LongTensor] = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
labels: Optional[torch.LongTensor] = None,
packing: bool = True,
seqlens_in_batch: Optional[torch.LongTensor] = None,
use_cache: Optional[bool] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
dpo_forward: bool = False,
**kwargs,
) -> Union[Tuple, CausalLMOutputWithPast]:
self.freezed_module_patch()

if inputs_embeds is None:
(
input_ids,
@@ -125,50 +114,71 @@
input_ids, position_ids, attention_mask, past_key_values, labels, images
)

if packing and self.training and not dpo_forward:
support_packing = "seqlens_in_batch" in inspect.signature(self.llm.forward).parameters

if self.training and support_packing and not dpo_forward:
(
_,
new_position_ids,
new_attention_mask,
_,
new_inputs_embeds,
new_labels,
sorted_seqlens_in_batch,
) = self.repack_multimodal_data(
input_ids,
position_ids,
attention_mask,
_,
past_key_values,
inputs_embeds,
labels,
_seqlens_in_batch,
) = self.repack_multimodal_data(
input_ids, position_ids, attention_mask, past_key_values, inputs_embeds, labels
)
if _seqlens_in_batch is None:
_seqlens_in_batch = seqlens_in_batch

wraps = partial(_get_unpad_data, _seqlens_in_batch=_seqlens_in_batch)
with patch(self.llm.__module__ + "._get_unpad_data", wraps=wraps) as m:
outputs = self.llm(
inputs_embeds=inputs_embeds,
attention_mask=attention_mask,
position_ids=position_ids,
labels=labels,
**kwargs,
)
if 0 in attention_mask and m.call_count == 0:
raise ValueError("The forward function of the model should call `_get_unpad_data` function.")
if sorted_seqlens_in_batch is None:
sorted_seqlens_in_batch = seqlens_in_batch
new_input_ids = None
past_key_values = None
else:
outputs = self.llm(
input_ids=input_ids,
attention_mask=attention_mask,
position_ids=position_ids,
new_attention_mask = attention_mask
new_position_ids = position_ids
new_inputs_embeds = inputs_embeds
new_labels = labels
sorted_seqlens_in_batch = attention_mask.sum(-1).int()
new_input_ids = input_ids

if support_packing:
outputs = self.llm.forward(
input_ids=new_input_ids,
attention_mask=new_attention_mask,
position_ids=new_position_ids,
past_key_values=past_key_values,
inputs_embeds=inputs_embeds,
labels=labels,
**kwargs,
inputs_embeds=new_inputs_embeds,
labels=new_labels,
use_cache=use_cache,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
seqlens_in_batch=sorted_seqlens_in_batch,
)
else:
outputs = self.llm.forward(
input_ids=new_input_ids,
attention_mask=new_attention_mask,
position_ids=new_position_ids,
past_key_values=past_key_values,
inputs_embeds=new_inputs_embeds,
labels=new_labels,
use_cache=use_cache,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
)

# Loss rescale for SP & DP loss match
if dist.size() > 1:
loss_weight = calculate_loss_weight(labels)
outputs.loss = outputs.loss * loss_weight
loss_weight = calculate_loss_weight(new_labels)
outputs.loss = outputs.loss * loss_weight

if dpo_forward:
return outputs.logits, labels
return outputs.logits, new_labels
return outputs
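
Both variants of this forward pass shown in the hunk revolve around sequence packing with flash attention: the inspect.signature check tests whether the wrapped LLM's forward accepts a seqlens_in_batch argument directly (presumably the behavior added by the patched transformers sources copied in environment_setup.sh), while the unittest.mock.patch variant temporarily replaces the module-level _get_unpad_data so unpadding uses the precomputed per-sequence lengths, and checks call_count to make sure the patched helper was actually reached. Below is a minimal, self-contained sketch of that dispatch-and-patch pattern; the names llm_forward and _packed_unpad_data and all values are made up for illustration and are not part of the repository.

import inspect
from functools import partial
from unittest.mock import patch

# Stand-in for the library-internal helper that the attention code would normally call.
def _get_unpad_data(attention_mask):
    return "lengths recomputed from the attention mask"

# Stand-in for the wrapped LLM's forward; it resolves _get_unpad_data at call time,
# so swapping the module attribute changes which implementation it runs.
def llm_forward(attention_mask):
    return _get_unpad_data(attention_mask)

# The dispatch check from the hunk: does forward() accept seqlens_in_batch directly?
support_packing = "seqlens_in_batch" in inspect.signature(llm_forward).parameters
print(support_packing)  # False -- this toy forward does not, so patching is the alternative

# Override that ignores the mask-derived computation and uses precomputed lengths.
def _packed_unpad_data(attention_mask, _seqlens_in_batch):
    return f"precomputed lengths: {_seqlens_in_batch}"

wraps = partial(_packed_unpad_data, _seqlens_in_batch=[3, 2])
with patch(__name__ + "._get_unpad_data", wraps=wraps) as m:
    out = llm_forward(attention_mask=[1, 1, 1, 0, 1, 1, 0, 0])
    # Mirrors the diff's safety check: the forward pass must have gone through the patch.
    assert m.call_count == 1
print(out)  # precomputed lengths: [3, 2]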


