Skip to content

Commit

Permalink
Add long-short sampler and long-video needle test (NVlabs#160)
Browse files Browse the repository at this point in the history
  • Loading branch information
yukang2017 authored Aug 15, 2024
1 parent 4956922 commit f92c7a5
Show file tree
Hide file tree
Showing 21 changed files with 1,779 additions and 18 deletions.
10 changes: 6 additions & 4 deletions llava/data/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -2632,15 +2632,16 @@ def __call__(self, instances: Sequence[Dict]) -> Dict[str, torch.Tensor]:

# TODO: Remove the hard coding of NUM_TOKENS_PER_IMAGE
NUM_TOKENS_PER_IMAGE = 196
if hasattr(self.data_args.image_processor, "crop_size"):
crop_size = self.data_args.image_processor.crop_size
else:
crop_size = self.data_args.image_processor.size

# Init the padding sample
seq_id = 0
while seq_id < len(input_ids):
# Skip the samples without images
if len(images[seq_id]) == 0:
seq_id += 1
continue
dummy_image = torch.ones_like(images[seq_id][:1])
dummy_image = torch.ones((1, 3, crop_size["height"], crop_size["width"]), device=input_ids[seq_id].device)
# dummy input_ids include one bos, one image token, and one eos
dummy_input_ids = torch.zeros_like(input_ids[seq_id][:3])
dummy_input_ids[0] = self.tokenizer.bos_token_id
Expand Down Expand Up @@ -2832,6 +2833,7 @@ def __call__(self, instances: Sequence[Dict]) -> Dict[str, torch.Tensor]:
images=flat_batch_images,
position_ids=position_ids,
)

return batch


Expand Down
16 changes: 16 additions & 0 deletions llava/data/datasets_mixture.py
Original file line number Diff line number Diff line change
Expand Up @@ -976,6 +976,22 @@ def register_datasets_mixtures():
)
add_dataset(shot2story_shotonly)

# "longvideo_sft" dataset entry: data_path points at longvideo_sft.json, while
# image_path points at the Shot2Story video directory (the same tree used by
# the shot2story entries), so no separate video copy is referenced here.
# NOTE(review): both paths are absolute cluster paths — confirm they exist on
# the machine running training.
longvideo_sft = Dataset(
dataset_name="longvideo_sft",
dataset_type="torch",
data_path="/lustre/fsw/portfolios/nvr/projects/nvr_elm_llm/dataset/longvideo_sft/longvideo_sft.json",
image_path="/lustre/fsw/portfolios/nvr/projects/nvr_elm_llm/dataset/shot2story/Shot2Story/data/videos",
)
# Hand the entry to add_dataset so it is available under its dataset_name
# (presumably a name-keyed registry — verify against add_dataset).
add_dataset(longvideo_sft)

# "longvideo_sft_deepseek" dataset entry: same Shot2Story video directory as
# "longvideo_sft", but data_path points at longvideo_sft_deepseek.json.
# NOTE(review): the "_deepseek" suffix presumably marks a different annotation
# source for the same videos — confirm with the dataset owner.
longvideo_sft_deepseek = Dataset(
dataset_name="longvideo_sft_deepseek",
dataset_type="torch",
data_path="/lustre/fsw/portfolios/nvr/projects/nvr_elm_llm/dataset/longvideo_sft/longvideo_sft_deepseek.json",
image_path="/lustre/fsw/portfolios/nvr/projects/nvr_elm_llm/dataset/shot2story/Shot2Story/data/videos",
)
# Hand the entry to add_dataset so it is available under its dataset_name
# (presumably a name-keyed registry — verify against add_dataset).
add_dataset(longvideo_sft_deepseek)

sharegpt_video = Dataset(
dataset_name="sharegpt_video",
dataset_type="torch",
Expand Down
Loading

0 comments on commit f92c7a5

Please sign in to comment.