From f92c7a55007a506cb74bf71c15a18ba38e95f85f Mon Sep 17 00:00:00 2001 From: yukang Date: Thu, 15 Aug 2024 12:04:10 +0800 Subject: [PATCH] Add long-short sampler and long-video needle test (#160) --- llava/data/dataset.py | 10 +- llava/data/datasets_mixture.py | 16 + .../eval/vision_niah_vila/eval_vision_niah.py | 412 ++++++++++++++++++ .../produce_haystack_embedding.py | 96 ++++ .../produce_needle_embedding.py | 66 +++ .../zigzag_ring_attn/monkey_patch.py | 125 ++++++ .../zigzag_ring_attn/prepare_inputs.py | 53 +++ llava/mm_utils.py | 5 +- llava/train/args.py | 1 + llava/train/llava_trainer.py | 116 ++++- llava/train/train_llm_to_long.py | 380 ++++++++++++++++ .../merge_lora_weights_and_save_hf_model.py | 105 +++++ scripts/deepspeed_inference.yaml | 17 + scripts/v1_5/eval/needle.sh | 22 + scripts/v1_5/longvila/8b/4_extend_llm_256k.sh | 58 +++ scripts/v1_5/longvila/8b/4_extend_llm_64k.sh | 58 +++ .../v1_5/longvila/8b/5_long_sft_1024frames.sh | 59 +++ .../v1_5/longvila/8b/5_long_sft_128frames.sh | 59 +++ .../v1_5/longvila/8b/5_long_sft_256frames.sh | 59 +++ .../v1_5/longvila/8b/5_long_sft_512frames.sh | 59 +++ scripts/zero3_offload_inference.json | 21 + 21 files changed, 1779 insertions(+), 18 deletions(-) create mode 100644 llava/eval/vision_niah_vila/eval_vision_niah.py create mode 100644 llava/eval/vision_niah_vila/produce_haystack_embedding.py create mode 100644 llava/eval/vision_niah_vila/produce_needle_embedding.py create mode 100644 llava/eval/vision_niah_vila/zigzag_ring_attn/monkey_patch.py create mode 100644 llava/eval/vision_niah_vila/zigzag_ring_attn/prepare_inputs.py create mode 100644 llava/train/train_llm_to_long.py create mode 100644 llava/utils/merge_lora_weights_and_save_hf_model.py create mode 100644 scripts/deepspeed_inference.yaml create mode 100644 scripts/v1_5/eval/needle.sh create mode 100644 scripts/v1_5/longvila/8b/4_extend_llm_256k.sh create mode 100644 scripts/v1_5/longvila/8b/4_extend_llm_64k.sh create mode 100644 scripts/v1_5/longvila/8b/5_long_sft_1024frames.sh create mode 100644 scripts/v1_5/longvila/8b/5_long_sft_128frames.sh create mode 100644 scripts/v1_5/longvila/8b/5_long_sft_256frames.sh create mode 100644 scripts/v1_5/longvila/8b/5_long_sft_512frames.sh create mode 100644 scripts/zero3_offload_inference.json diff --git a/llava/data/dataset.py b/llava/data/dataset.py index 330e6bf0..3a7b32f9 100755 --- a/llava/data/dataset.py +++ b/llava/data/dataset.py @@ -2632,15 +2632,16 @@ def __call__(self, instances: Sequence[Dict]) -> Dict[str, torch.Tensor]: # TODO: Remove the hard coding of NUM_TOKENS_PER_IMAGE NUM_TOKENS_PER_IMAGE = 196 + if hasattr(self.data_args.image_processor, "crop_size"): + crop_size = self.data_args.image_processor.crop_size + else: + crop_size = self.data_args.image_processor.size # Init the padding sample seq_id = 0 while seq_id < len(input_ids): # Skip the samples without images - if len(images[seq_id]) == 0: - seq_id += 1 - continue - dummy_image = torch.ones_like(images[seq_id][:1]) + dummy_image = torch.ones((1, 3, crop_size["height"], crop_size["width"]), device=input_ids[seq_id].device) # dummy input_ids include one bos, one image token, and one eos dummy_input_ids = torch.zeros_like(input_ids[seq_id][:3]) dummy_input_ids[0] = self.tokenizer.bos_token_id @@ -2832,6 +2833,7 @@ def __call__(self, instances: Sequence[Dict]) -> Dict[str, torch.Tensor]: images=flat_batch_images, position_ids=position_ids, ) + return batch diff --git a/llava/data/datasets_mixture.py b/llava/data/datasets_mixture.py index d04441c1..a2814148 100755 --- 
a/llava/data/datasets_mixture.py +++ b/llava/data/datasets_mixture.py @@ -976,6 +976,22 @@ def register_datasets_mixtures(): ) add_dataset(shot2story_shotonly) + longvideo_sft = Dataset( + dataset_name="longvideo_sft", + dataset_type="torch", + data_path="/lustre/fsw/portfolios/nvr/projects/nvr_elm_llm/dataset/longvideo_sft/longvideo_sft.json", + image_path="/lustre/fsw/portfolios/nvr/projects/nvr_elm_llm/dataset/shot2story/Shot2Story/data/videos", + ) + add_dataset(longvideo_sft) + + longvideo_sft_deepseek = Dataset( + dataset_name="longvideo_sft_deepseek", + dataset_type="torch", + data_path="/lustre/fsw/portfolios/nvr/projects/nvr_elm_llm/dataset/longvideo_sft/longvideo_sft_deepseek.json", + image_path="/lustre/fsw/portfolios/nvr/projects/nvr_elm_llm/dataset/shot2story/Shot2Story/data/videos", + ) + add_dataset(longvideo_sft_deepseek) + sharegpt_video = Dataset( dataset_name="sharegpt_video", dataset_type="torch", diff --git a/llava/eval/vision_niah_vila/eval_vision_niah.py b/llava/eval/vision_niah_vila/eval_vision_niah.py new file mode 100644 index 00000000..080fa51b --- /dev/null +++ b/llava/eval/vision_niah_vila/eval_vision_niah.py @@ -0,0 +1,412 @@ +# Copyright 2024 NVIDIA CORPORATION & AFFILIATES +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +# SPDX-License-Identifier: Apache-2.0 +# This file is modified from https://github.com/EvolvingLMMs-Lab/LongVA + +import argparse +import gc +import glob +import json +import os +import random +import sys +from pathlib import Path + +import matplotlib.pyplot as plt +import numpy as np +import pandas as pd +import seaborn as sns +import torch +from accelerate import Accelerator +from datasets import load_dataset +from matplotlib.colors import LinearSegmentedColormap +from tqdm import tqdm +from transformers import AutoTokenizer, LlamaForCausalLM +from zigzag_ring_attn.monkey_patch import apply_zigzag_ring_attn_monkey_patch_llama +from zigzag_ring_attn.prepare_inputs import prepare_zigzag_ring_attn_inputs + +from llava.mm_utils import get_model_name_from_path +from llava.model.builder import load_pretrained_model + +apply_zigzag_ring_attn_monkey_patch_llama() + + +SEED = 24242424 +torch.manual_seed(SEED) +random.seed(SEED) +np.random.seed(SEED) + +prompt_templates = { + "mistral": {"preprompt": "[INST]", "postprompt": " [/INST]"}, + "vicuna": { + "preprompt": "A chat between a curious human and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the human's questions. USER:", + "postprompt": "ASSISTANT:", + }, + "llama_3": { + "preprompt": "<|begin_of_text|><|start_header_id|>system<|end_header_id|>\n\nYou are a helpful language and vision assistant. 
You are able to understand the visual content that the user provides, and assist the user with a variety of tasks using natural language.<|eot_id|><|start_header_id|>user<|end_header_id|>\n\n", + "postprompt": "<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n", + }, + "qwen2": { + "preprompt": "<|im_start|>system\nYou are a helpful assistant.<|im_end|>\n<|im_start|>user\n", + "postprompt": "<|im_end|>\n<|im_start|>assistant\n", + }, + "yi": { + "preprompt": "<|im_start|>system\nAnswer the questions.<|im_end|>\n<|im_start|>user\n", + "postprompt": "<|im_end|>\n<|im_start|>assistant\n", + }, +} +# \nAnswer the question using a single word or phrase. +# The color of the bottle cap is +# answer = "Yellow" + + +def safe_tokenize(tokenizer, text): + tokenized = tokenizer.encode(text, return_tensors="pt") + return tokenized + + +# answer = "more bet" +def eval_forward(accelerator, model, input_embeds, answer_embeds, pad_id, answer_ids, tokenizer): + # first append answer_embeds to input_embeds + prompt_length = input_embeds.shape[1] + labels_length = answer_embeds.shape[1] + input_embeds = torch.cat([input_embeds, answer_embeds], dim=1) + # second pad input_embeds to the multiple of accelerator.num_processes + pad_tensor = ( + torch.tensor( + [pad_id] * ((accelerator.num_processes * 2) - input_embeds.shape[1] % (accelerator.num_processes * 2)) + ) + .unsqueeze(0) + .unsqueeze(-1) + .expand(-1, -1, input_embeds.shape[-1]) + .to(accelerator.device) + ) + input_embeds = torch.cat([input_embeds, pad_tensor], dim=1) + position_ids = (torch.arange(input_embeds.shape[1]).unsqueeze(0).expand(input_embeds.shape[0], -1)).to( + accelerator.device + ) + accelerator.print(input_embeds.shape) + + prepared = prepare_zigzag_ring_attn_inputs( + input_embeds, + position_ids, + None, + accelerator.process_index, + accelerator.num_processes, + accelerator.device, + ) + local_input_embeds = prepared["local_input_ids"] + local_position_ids = prepared["local_position_ids"] + with torch.inference_mode(): + logits = model( + inputs_embeds=local_input_embeds, + position_ids=local_position_ids, + use_cache=False, + ).logits + pred = logits.argmax(dim=-1) + + # gather all logits using accelerator.gather + def undo_extract_local(gathered_value, world_size, dim=1): + value_chunks = gathered_value.chunk(2 * world_size, dim=dim) + reordered_chunks = [None] * (2 * world_size) + for i in range(world_size): + reordered_chunks[i] = value_chunks[i * 2] + reordered_chunks[2 * world_size - i - 1] = value_chunks[i * 2 + 1] + return torch.cat(reordered_chunks, dim=dim) + + correct = False + + gathered_logits = accelerator.gather(pred.squeeze(0)).unsqueeze(0) + # undo extract local on the gathered logits + pred = undo_extract_local(gathered_logits, accelerator.num_processes) + pred = pred[:, prompt_length - 1 : prompt_length + labels_length - 1] + # check if the logits are correct, extract argmax id + # compare the predicted_ids with the labels + correct = (pred == answer_ids.to(accelerator.device)).all() + if accelerator.is_main_process: + print( + "Predicted: ", + tokenizer.decode(pred.squeeze().tolist()), + "Answer: ", + tokenizer.decode(answer_ids.squeeze().tolist()), + ) + # print id as well + print( + "Predicted: ", + pred.squeeze().tolist(), + "Answer: ", + answer_ids.squeeze().tolist(), + ) + return int(correct) + + +def load_haystack(args, accelerator): + haystack_embeddings = torch.load(f"{args.haystack_dir}/video_embeddings.pt").to(torch.bfloat16) + return haystack_embeddings + + +def load_text_embeddings(str, 
tokenizer, model, accelerator, replace_double_newline=False): + token_ids = safe_tokenize(tokenizer, str) + + def replace_double_newline_func(token_ids): + double_newline_loc = (token_ids == 271).nonzero()[:, 1] + double_newline_loc += torch.arange(len(double_newline_loc)) + if len(double_newline_loc) > 0: + for loc in double_newline_loc: + token_ids = torch.cat([token_ids[:, :loc], torch.tensor([[198, 198]]), token_ids[:, loc + 1 :]], dim=1) + return token_ids + + if replace_double_newline: + token_ids = replace_double_newline_func(token_ids) + token_ids = token_ids.to(accelerator.device) + with torch.inference_mode(): + embeddings = model.model.embed_tokens(token_ids) + return embeddings.to(torch.bfloat16) + + +def get_model_name(model): + model_split = [name for name in model.split("/") if len(name) > 0] + model_name = f"{model_split[-2]}_{model_split[-1]}" + return model_name + + +def load_results(results_dir): + results = [] + if os.path.exists(results_dir): + for root, dirs, files in os.walk(results_dir): + for file in files: + if "json" in file: + print("file", file) + results.append(json.load(open(os.path.join(root, file)))) + else: + os.system("mkdir -p %s" % results_dir) + return results + + +def inference(args): + tokenizer = AutoTokenizer.from_pretrained( + os.path.join(args.model, "llm"), + model_max_length=sys.maxsize, + trust_remote_code=True, + ) + + tokenizer.pad_token = tokenizer.eos_token + + accelerator = Accelerator( + mixed_precision="bf16", + ) + kwargs = {"rope_theta": args.rope_theta} if args.rope_theta is not None else {} + if "qwen2" in args.model.lower() or "longva" in args.model.lower(): + model = Qwen2ForCausalLM_RingAttn.from_pretrained( + args.model, + torch_dtype=torch.bfloat16, + _attn_implementation="flash_attention_2", + device_map=accelerator.device, + **kwargs, + ) + else: + model = LlamaForCausalLM.from_pretrained( + os.path.join(args.model, "llm"), + torch_dtype=torch.bfloat16, + _attn_implementation="flash_attention_2", + device_map=accelerator.device, + **kwargs, + ) + tokenizer.pad_token = tokenizer.eos_token + # remember to remove + accelerator.print("Preparing Haystack...") + haystack_embeddings = load_haystack(args, accelerator) + assert ( + len(haystack_embeddings) >= args.max_frame_num + ), f"Haystack embeddings are not enough. Max frame {args.max_frame_num} is not found. Currently only {len(haystack_embeddings)} frames." 
+ haystack_embeddings = haystack_embeddings[: args.max_frame_num].to(accelerator.device) + prompt = prompt_templates[args.prompt_template] + preprompt_embeddings = load_text_embeddings( + prompt["preprompt"], tokenizer, model, accelerator, args.replace_double_newline + ) + postprompt_embeddings = load_text_embeddings( + prompt["postprompt"], tokenizer, model, accelerator, args.replace_double_newline + ) + + needle_dataset = load_dataset(args.needle_dataset)["test"] + answer_embedding_list = [] + answer_id_list = [] + needle_embedding_list = [] + question_embeding_list = [] + for index, instance in enumerate(needle_dataset): + answer = instance["answer"] + question = instance["question"] + needle_embedding_list.append( + torch.load(args.needle_embedding_dir + f"/{index}.pt", map_location="cpu") + .to(torch.bfloat16) + .to(accelerator.device) + ) + answer_embedding_list.append(load_text_embeddings(answer, tokenizer, model, accelerator)) + answer_id_list.append(safe_tokenize(tokenizer, answer)) + question_embeding_list.append(load_text_embeddings(question, tokenizer, model, accelerator)) + + accelerator.print("Starting Evaluation...") + model = accelerator.prepare(model) + model.gradient_checkpointing_enable() + + model_name = get_model_name(args.model) + results_dir = "results/%s" % model_name + all_accuries = load_results(results_dir) + + for num_frames in tqdm(range(args.min_frame_num, args.max_frame_num + 1, args.frame_interval)): + context_depths = [result["Frame Depth"] for result in all_accuries if result["Num. Frame"] == num_frames] + for depth in np.arange(0, 1 + args.depth_interval, args.depth_interval): + if round(depth * 100, -1) in context_depths: + print("Context %d, depth %d already done." % (num_frames, round(depth * 100, -1))) + continue + accuracies = [] + for question_embedding, needle_embedding, answer_embedding, answer_id in zip( + question_embeding_list, needle_embedding_list, answer_embedding_list, answer_id_list + ): + query_frame_idx = int(depth * num_frames) + input_frames = ( + torch.cat( + [ + haystack_embeddings[:query_frame_idx], + needle_embedding.unsqueeze(0), + haystack_embeddings[query_frame_idx:num_frames], + ], + dim=0, + ) + .view(-1, haystack_embeddings.shape[-1]) + .unsqueeze(0) + ) + input_emebds = torch.cat( + [preprompt_embeddings, input_frames, question_embedding, postprompt_embeddings], dim=1 + ) + correct = eval_forward( + accelerator, model, input_emebds, answer_embedding, tokenizer.pad_token_id, answer_id, tokenizer + ) + gc.collect() + torch.cuda.empty_cache() + if accelerator.is_main_process: + accuracies.append(correct) + if accelerator.is_main_process: + result = { + "Num. Frame": num_frames, + "Frame Depth": round(depth * 100, -1), + "Score": sum(accuracies) / len(accuracies), + } + accelerator.print(result) + all_accuries.append(result) + json.dump( + result, + open(os.path.join(results_dir, "frame_%d_depth_%d.json" % (num_frames, int(depth * 100))), "w"), + ) + + if accelerator.is_main_process: + model_name = args.model.split("/")[-1] + os.makedirs(f"{args.output_path}/{model_name}", exist_ok=True) + with open(f"{args.output_path}/{model_name}/all_accuracies.json", "w") as f: + json.dump(all_accuries, f, indent=4) + return all_accuries, accelerator + + +def plot(args, all_accuries): + df = pd.DataFrame(all_accuries) + cmap = LinearSegmentedColormap.from_list("custom_cmap", ["#F0496E", "#EBB839", "#9ad5b3"]) + + pivot_table = pd.pivot_table( + df, + values="Score", + index=["Frame Depth", "Num. 
Frame"], + aggfunc="mean", + ).reset_index() # This will aggregate + pivot_table = pivot_table.pivot(index="Frame Depth", columns="Num. Frame", values="Score") + # Create the heatmap with better aesthetics + plt.figure(figsize=(17.5, 8)) # Can adjust these dimensions as needed + ax = sns.heatmap( + pivot_table, + # annot=True, + fmt="g", + vmin=0, + vmax=1, + linecolor="white", + linewidths=1.5, + cmap=cmap, + cbar_kws={"label": "Score"}, + ) + + # Set the color bar label font size + cbar = ax.collections[0].colorbar + cbar.ax.yaxis.label.set_size(14) + cbar.ax.tick_params(labelsize=14) + + # Define the formatter function + def thousands_formatter(x, pos): + if x >= 1000: + return f"{x/1000:.1f}K" + return f"{x}" + + context_lengths = pivot_table.columns + formatted_context_lengths = [thousands_formatter(x, None) for x in context_lengths] + + # More aesthetics + plt.xlabel("Num. of Frames", fontsize=14) # X-axis label + plt.ylabel("Depth Percent", fontsize=14) # Y-axis label + plt.xticks( + ticks=[i + 0.5 for i in range(len(context_lengths))], labels=formatted_context_lengths, rotation=45, fontsize=14 + ) + # plt.xticks(rotation=45, fontsize=14) # Rotates the x-axis labels to prevent overlap + plt.yticks(rotation=0, fontsize=14) # Ensures the y-axis labels are horizontal + plt.tight_layout() # Fits everything neatly into the figure area + # save + model_name = args.model.split("/")[-1] + + plt.savefig(f"{args.output_path}/{model_name}/heatmap.png") + # calculate average accuracy + average_accuracy = df["Score"].mean() + print(f"Average Accuracy: {average_accuracy}") + # save as txt + with open(f"{args.output_path}/{model_name}/avg_accuracy.txt", "w") as f: + f.write(f"Average Accuracy: {average_accuracy}\n") + + +def main(args): + if args.plot_only: + # load all_accuracies from json + model_name = args.model.split("/")[-1] + with open(f"{args.output_path}/{model_name}/all_accuracies.json") as f: + all_accuracies = json.load(f) + plot(args, all_accuracies) + else: + all_accuracies, accelerator = inference(args) + if accelerator.is_main_process: + plot(args, all_accuracies) + + +if __name__ == "__main__": + args = argparse.ArgumentParser() + args.add_argument("--model", type=str, default="output/LLaVA-NeXT-Video-7B-32K") + args.add_argument("--max_frame_num", type=int, default=300) + args.add_argument("--needle_dataset", type=str, default="lmms-lab/v_niah_needles") + args.add_argument("--min_frame_num", type=int, default=20) + args.add_argument("--frame_interval", type=int, default=20) + args.add_argument("--output_path", type=str, default="vision_niah/niah_output") + args.add_argument("--depth_interval", type=float, default=0.1) + args.add_argument("--num_samples", type=int, default=1) + args.add_argument("--rope_theta", type=float, default=None) + args.add_argument("--haystack_dir", type=str, default="video_needle_haystack/data/haystack_embeddings") + args.add_argument("--needle_embedding_dir", type=str, default="vision_niah/data/needle_embeddings") + args.add_argument("--prompt_template", type=str) + args.add_argument("--replace_double_newline", action="store_true") + args.add_argument("--plot_only", action="store_true") + + main(args.parse_args()) diff --git a/llava/eval/vision_niah_vila/produce_haystack_embedding.py b/llava/eval/vision_niah_vila/produce_haystack_embedding.py new file mode 100644 index 00000000..bf512eb1 --- /dev/null +++ b/llava/eval/vision_niah_vila/produce_haystack_embedding.py @@ -0,0 +1,96 @@ +# Copyright 2024 NVIDIA CORPORATION & AFFILIATES +# +# Licensed under the 
Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +# SPDX-License-Identifier: Apache-2.0 +# This file is adopted from https://github.com/EvolvingLMMs-Lab/LongVA + + +import argparse +import math + +import numpy as np +import torch +from decord import VideoReader, cpu +from PIL import Image +from tqdm import tqdm + +from llava.mm_utils import get_model_name_from_path, process_images, tokenizer_image_token +from llava.model.builder import load_pretrained_model + + +def load_video_batches(video_path, batch_size): + vr = VideoReader(video_path, ctx=cpu(0)) + total_frame_num = len(vr) + fps = round(vr.get_avg_fps()) + frame_idx = [i for i in range(0, len(vr), fps)] + for start_idx in range(0, len(frame_idx), batch_size): + end_idx = min(start_idx + batch_size, total_frame_num) + frame_indices = frame_idx[start_idx:end_idx] + batch_frames = vr.get_batch(frame_indices).asnumpy() + yield batch_frames + + +def main(args): + video_path = args.video_path + model_path = args.model + model_name = get_model_name_from_path(model_path) + + tokenizer, model, image_processor, context_len = load_pretrained_model(model_path, model_name, None) + model.config.image_aspect_ratio = "pad" + model.config.mm_patch_merge_type = "flat" + # Process video in batches + batch_size = 32 + total_batches = (args.sampled_frames_num + batch_size - 1) // batch_size + image_feature_list = [] + if args.add_newline_token: + newline_token_embeddong = model.model.image_newline + with torch.inference_mode(): + for i, video_batch in tqdm( + enumerate(load_video_batches(video_path, batch_size)), total=total_batches, desc="Processing Video Batches" + ): + images = [Image.fromarray(frame).convert("RGB") for frame in video_batch] + processed_images = process_images(images, image_processor, model.config).half() + image_features = model.encode_images(processed_images) + print(image_features.shape) + if args.pooling_size != 0: + B, _, F = image_features.shape + + image_features_spatial = image_features.view(B, int(math.sqrt(_)), int(math.sqrt(_)), F).permute( + 0, 3, 1, 2 + ) # B, F, 24, 24 + image_features_spatial_pool = torch.nn.functional.avg_pool2d( + image_features_spatial, args.pooling_size, args.pooling_size + ) # B, F, 12, 12 + image_features = image_features_spatial_pool.flatten(2).transpose(1, 2).contiguous() # B, 144, F + if args.add_newline_token: + image_features = torch.cat( + [image_features, newline_token_embeddong.unsqueeze(0).expand(image_features.shape[0], 1, -1)], dim=1 + ) + image_feature_list.append(image_features.to(torch.bfloat16).to("cpu")) + if i > total_batches: + break + image_feature_list = torch.cat(image_feature_list, dim=0) + torch.save(image_feature_list, f"{args.output_dir}/video_embeddings.pt") + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument("--model", type=str, default="output/LLaVA-NeXT-Video-7B-Vicuna") + parser.add_argument("--video_path", type=str, default="/home/yukangc/movie.mp4") + parser.add_argument("--sampled_frames_num", type=int, default=7200) + 
parser.add_argument("--output_dir", type=str, default="video_needle_haystack/data/haystack_vicuna_embeddings") + parser.add_argument("--pooling_size", type=int, default=0) + parser.add_argument("--add_newline_token", action="store_true") + args = parser.parse_args() + main(args) diff --git a/llava/eval/vision_niah_vila/produce_needle_embedding.py b/llava/eval/vision_niah_vila/produce_needle_embedding.py new file mode 100644 index 00000000..8c7285b0 --- /dev/null +++ b/llava/eval/vision_niah_vila/produce_needle_embedding.py @@ -0,0 +1,66 @@ +# Copyright 2024 NVIDIA CORPORATION & AFFILIATES +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +# SPDX-License-Identifier: Apache-2.0 +# This file is adopted from https://github.com/EvolvingLMMs-Lab/LongVA + + +import argparse +import json +import math +from pathlib import Path + +import numpy as np +import torch +from datasets import load_dataset +from PIL import Image +from tqdm import tqdm + +from llava.mm_utils import get_model_name_from_path, process_images, tokenizer_image_token +from llava.model.builder import load_pretrained_model + + +def main(args): + model_path = args.model + model_name = get_model_name_from_path(model_path) + + tokenizer, model, image_processor, context_len = load_pretrained_model(args.model, model_name, None) + model.config.image_aspect_ratio = "pad" + model.config.mm_patch_merge_type = "flat" + dataset = load_dataset(args.needle_dataset)["test"] + for index, instance in enumerate(dataset): + image = instance["image"].convert("RGB") + image = process_images([image], image_processor, model.config).half() + image_features = model.encode_images(image) + if args.pooling_size != 0: + B, _, F = image_features.shape + image_features_spatial = image_features.view(B, int(math.sqrt(_)), int(math.sqrt(_)), F).permute( + 0, 3, 1, 2 + ) # B, F, 24, 24 + image_features_spatial_pool = torch.nn.functional.avg_pool2d( + image_features_spatial, args.pooling_size, args.pooling_size + ) # B, F, 12, 12 + image_features = image_features_spatial_pool.flatten(2).transpose(1, 2).contiguous() # B, 144, F + image_features = image_features.squeeze(0) + torch.save(image_features, f"{args.output_dir}/{index}.pt") + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument("--model", type=str, default="output/LLaVA-NeXT-Video-7B-Vicuna") + parser.add_argument("--needle_dataset", type=str, default="lmms-lab/v_niah_needles") + parser.add_argument("--output_dir", type=str, default="video_needle_haystack/data/needle_vicuna_embeddings") + parser.add_argument("--pooling_size", type=int, default=0) + args = parser.parse_args() + main(args) diff --git a/llava/eval/vision_niah_vila/zigzag_ring_attn/monkey_patch.py b/llava/eval/vision_niah_vila/zigzag_ring_attn/monkey_patch.py new file mode 100644 index 00000000..57dea659 --- /dev/null +++ b/llava/eval/vision_niah_vila/zigzag_ring_attn/monkey_patch.py @@ -0,0 +1,125 @@ +# Copyright 2024 NVIDIA CORPORATION & AFFILIATES +# +# Licensed under the Apache License, Version 2.0 (the 
"License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +# SPDX-License-Identifier: Apache-2.0 +# This file is adopted from https://github.com/EvolvingLMMs-Lab/LongVA + +import warnings +from typing import List, Optional, Tuple, Union + +import torch +import torch.utils.checkpoint +import transformers +from ring_flash_attn.zigzag_ring_flash_attn import zigzag_ring_flash_attn_func +from transformers.utils import is_flash_attn_greater_or_equal_2_10 + + +def new_flash_attn_forward( + self, + query_states, + key_states, + value_states, + attention_mask, + query_length, + dropout=0.0, + softmax_scale=None, + use_sliding_windows=False, + seqlens_in_batch=None, +): + if is_flash_attn_greater_or_equal_2_10(): + causal = self.is_causal + else: + causal = self.is_causal and query_length != 1 + + # Contains at least one padding token in the sequence + assert attention_mask is None + assert causal is True + assert use_sliding_windows is False + attn_output = zigzag_ring_flash_attn_func( + query_states, + key_states, + value_states, + dropout, + softmax_scale, + causal=causal, + ) + + return attn_output + + +def new_decoder_forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_value: Optional[Tuple[torch.Tensor]] = None, + output_attentions: Optional[bool] = False, + use_cache: Optional[bool] = False, + cache_position: Optional[torch.LongTensor] = None, + **kwargs, +) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]: + assert isinstance(self.self_attn, transformers.models.llama.modeling_llama.LlamaFlashAttention2) or isinstance( + self.self_attn, + transformers.models.mistral.modeling_mistral.MistralFlashAttention2, + ), "Please toggle on the Flash Attention 2 implementation when using zigzag ring attention monkey patch." + + if "padding_mask" in kwargs: + warnings.warn( + "Passing `padding_mask` is deprecated and will be removed in v4.37. 
Please make sure use `attention_mask` instead.`" + ) + + residual = hidden_states + + hidden_states = self.input_layernorm(hidden_states) + + # Self Attention + hidden_states, self_attn_weights, present_key_value = self.self_attn( + hidden_states=hidden_states, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_value=past_key_value, + output_attentions=output_attentions, + use_cache=use_cache, + cache_position=cache_position, + **kwargs, + ) + hidden_states = residual + hidden_states + + # Fully Connected + residual = hidden_states + hidden_states = self.post_attention_layernorm(hidden_states) + hidden_states = self.mlp(hidden_states) + hidden_states = residual + hidden_states + + outputs = (hidden_states,) + + if output_attentions: + outputs += (self_attn_weights,) + + if use_cache: + outputs += (present_key_value,) + + return outputs + + +def apply_zigzag_ring_attn_monkey_patch_llama(): + transformers.models.llama.modeling_llama.LlamaFlashAttention2._flash_attention_forward = new_flash_attn_forward + transformers.models.llama.modeling_llama.LlamaDecoderLayer.forward = new_decoder_forward + + +def apply_zigzag_ring_attn_monkey_patch_mistral(): + transformers.models.mistral.modeling_mistral.MistralFlashAttention2._flash_attention_forward = ( + new_flash_attn_forward + ) + transformers.models.mistral.modeling_mistral.MistralDecoderLayer.forward = new_decoder_forward diff --git a/llava/eval/vision_niah_vila/zigzag_ring_attn/prepare_inputs.py b/llava/eval/vision_niah_vila/zigzag_ring_attn/prepare_inputs.py new file mode 100644 index 00000000..54018edf --- /dev/null +++ b/llava/eval/vision_niah_vila/zigzag_ring_attn/prepare_inputs.py @@ -0,0 +1,53 @@ +# Copyright 2024 NVIDIA CORPORATION & AFFILIATES +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# +# SPDX-License-Identifier: Apache-2.0 +# This file is adopted from https://github.com/EvolvingLMMs-Lab/LongVA + +import torch + + +def extract_local(value, rank, world_size, device, dim=1): + value_chunks = value.chunk(2 * world_size, dim=dim) + local_value = torch.cat([value_chunks[rank], value_chunks[2 * world_size - rank - 1]], dim=dim) + return local_value.to(device) + + +def prepare_zigzag_ring_attn_inputs(input_ids, position_ids, target_ids, rank, world_size, device): + local_input_ids = extract_local( + input_ids, + rank, + world_size, + device, + ) + local_position_ids = extract_local( + position_ids, + rank, + world_size, + device, + ) + if target_ids is not None: + local_target_ids = extract_local( + target_ids, + rank, + world_size, + device, + ) + else: + local_target_ids = None + return { + "local_input_ids": local_input_ids, + "local_position_ids": local_position_ids, + "local_target_ids": local_target_ids, + } diff --git a/llava/mm_utils.py b/llava/mm_utils.py index af617c3e..87a3cb10 100755 --- a/llava/mm_utils.py +++ b/llava/mm_utils.py @@ -132,8 +132,9 @@ def get_frame_from_vcap_with_fps(vidcap, num_frames=10, max_fps=0.0, fps=None, f while success: if frame_count >= num_frames: - success, frame = vidcap.read() + # success, frame = vidcap.read() if count in frame_indices: + success, frame = vidcap.read() try: img = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB) im_pil = Image.fromarray(img) @@ -143,6 +144,8 @@ def get_frame_from_vcap_with_fps(vidcap, num_frames=10, max_fps=0.0, fps=None, f continue if len(images) >= num_frames: return images, num_frames + else: + success = vidcap.grab() count += 1 else: # Left padding frames if the video is not long enough diff --git a/llava/train/args.py b/llava/train/args.py index d256ffb7..8b888da8 100755 --- a/llava/train/args.py +++ b/llava/train/args.py @@ -90,6 +90,7 @@ class TrainingArguments(transformers.TrainingArguments): lora_llm: bool = False lora_vt: bool = False dpo: bool = False + longvila_sampler: bool = False dpo_beta: float = field(default=0.1) mm_projector_lr: Optional[float] = None group_by_modality_length: bool = field(default=False) diff --git a/llava/train/llava_trainer.py b/llava/train/llava_trainer.py index 5c40766f..ea436af4 100755 --- a/llava/train/llava_trainer.py +++ b/llava/train/llava_trainer.py @@ -162,7 +162,7 @@ def __init__( self.epoch = 0 self.drop_last = True # always True self.sp_degree = max(1, sp_degree) - assert batch_size % self.sp_degree == 0, "Batch size must be divisible by sequence parallelism degree." 
+ self.bs_divisible_by_sp = batch_size % self.sp_degree == 0 # Consider sequence parallelism if self.sp_degree > 1: # Sequence Parallelism is enabled @@ -218,7 +218,9 @@ def __iter__(self): self.total_size, ) - if self.sp_degree > 1: # Sequence Parallelism is enabled, to ensure the same behavior as data parallelism + if ( + self.sp_degree > 1 and self.bs_divisible_by_sp + ): # Sequence Parallelism is enabled, to ensure the same behavior as data parallelism dp_indices_dict = {} # {rank: indices_list} all_indices_dict = {} # {rank: all_indices} @@ -292,18 +294,104 @@ def __iter__(self): assert -1 not in all_indices return iter(all_indices) - # # (Qinghao): Implementation for validating accuracy of SP - # def __iter__(self): - # iterator = super().__iter__() - # indices = list(iterator) - # indices = indices[self.start_index :] - # return iter(indices) - # def __len__(self) -> int: - # return self.num_samples - self.start_index +class LongVILADistributedSampler(VILADistributedSampler): + """This class is implemented by Yukang Chen.""" + + def __iter__(self): + def batch_shuffle(indices): + batch_indices = list(range(indices[0] // self.batch_size, indices[-1] // self.batch_size + 1)) + random.shuffle(batch_indices) + indices_shuffled = [ + batch_indices[i // self.batch_size] * self.batch_size + index % self.batch_size + for i, index in enumerate(indices) + ] + return indices_shuffled + + indices = list(range(len(self.dataset))) + + # 1. split the full indices first (note: without drop last at this moment) + indices_list = [] + for i in range(len(self.org_sample_len_list)): + indices_list.append( + indices[sum(self.org_sample_len_list[:i]) : sum(self.org_sample_len_list[:i]) + self.total_samples[i]] + ) + + assert sum([len(indices) for indices in indices_list]) == self.total_size, ( + sum([len(indices) for indices in indices_list]), + self.total_size, + ) + + if self.sp_degree > 1: # Sequence Parallelism is enabled, to ensure the same behavior as data parallelism + dp_indices_dict = {} # {rank: indices_list} + all_indices_dict = {} # {rank: all_indices} + + for i in self.corresponding_ranks: + dp_indices_list = [] + for idx, indices in enumerate(indices_list): + dp_indices_list.append( + indices[i * self.per_replica_samples[idx] : (i + 1) * self.per_replica_samples[idx]] + ) + + random.seed(self.seed + self.epoch) + for indice in range(len(dp_indices_list)): + batch_shuffle(dp_indices_list[indice]) + + dp_indices_dict[i] = dp_indices_list.copy() + + for rank, dp_indices_list in dp_indices_dict.items(): + dp_indices_list = sorted(dp_indices_list, key=lambda x: -len(x)) + dp_all_indices = [-1] * self.num_samples + indices_available = list(range(self.num_samples)) - # def set_start_index(self, start_index: int) -> None: - # self.start_index = start_index + for indice in dp_indices_list: + + original_indices = range(len(indice)) + transformed_indices = [idx * len(indices_available) // len(indice) for idx in original_indices] + + mapped_indices = [indices_available[idx] for idx in transformed_indices] + # update indices_available + for idx in reversed(transformed_indices): + del indices_available[idx] + for i, idx in enumerate(mapped_indices): + dp_all_indices[idx] = indice[i] + + all_indices_dict[rank] = dp_all_indices + + # Interleaving Merge + merged_indices = [] + interleaved_indices = [] + for item_idx in range(len(all_indices_dict[self.corresponding_ranks[0]])): + for rank in self.corresponding_ranks: + interleaved_indices.append(all_indices_dict[rank][item_idx]) + 
merged_indices.append(interleaved_indices) + + all_indices = merged_indices[0] + else: + # let's first do subsample + for idx, indices in enumerate(indices_list): + indices_list[idx] = indices[ + self.rank * self.per_replica_samples[idx] : (self.rank + 1) * self.per_replica_samples[idx] + ] + + random.seed(self.seed + self.epoch) + for indice in range(len(indices_list)): + batch_shuffle(indices_list[indice]) + + indices_list = sorted(indices_list, key=lambda x: -len(x)) + all_indices = [-1] * self.num_samples + indices_available = list(range(self.num_samples)) + for indice in indices_list: + original_indices = range(len(indice)) + transformed_indices = [idx * len(indices_available) // len(indice) for idx in original_indices] + mapped_indices = [indices_available[idx] for idx in transformed_indices] + # update indices_available + for idx in reversed(transformed_indices): + del indices_available[idx] + for i, idx in enumerate(mapped_indices): + all_indices[idx] = indice[i] + assert -1 not in all_indices + return iter(all_indices) class LengthGroupedSampler(Sampler): @@ -501,6 +589,8 @@ def _get_train_sampler(self) -> Optional[torch.utils.data.Sampler]: seed = self.args.data_seed if self.args.data_seed is not None else self.args.seed num_replicas = self.args.world_size rank = self.args.process_index + longvila_sampler = self.args.longvila_sampler + sampler = LongVILADistributedSampler if longvila_sampler else VILADistributedSampler # # Consider sequence parallelism # sp_degree = self.args.seq_parallel_size @@ -510,7 +600,7 @@ def _get_train_sampler(self) -> Optional[torch.utils.data.Sampler]: # rank = PROCESS_GROUP_MANAGER.dp_rank # # rank = dist.get_rank() // sp_degree - return VILADistributedSampler( + return sampler( self.train_dataset, num_replicas=num_replicas, rank=rank, diff --git a/llava/train/train_llm_to_long.py b/llava/train/train_llm_to_long.py new file mode 100644 index 00000000..37dcda17 --- /dev/null +++ b/llava/train/train_llm_to_long.py @@ -0,0 +1,380 @@ +# Copyright 2024 NVIDIA CORPORATION & AFFILIATES +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# +# SPDX-License-Identifier: Apache-2.0 +# This file is modified from https://github.com/dvlab-research/LongLoRA + +import copy +import math +import os +from dataclasses import dataclass, field +from functools import partial +from typing import Dict, List, Optional, Sequence, Tuple, Union + +import datasets +import torch +import torch.distributed as dist +import transformers +from datasets import load_dataset, load_from_disk +from peft import LoraConfig, PeftModel, get_peft_model +from ring_flash_attn.zigzag_ring_flash_attn import zigzag_ring_flash_attn_func +from torch.distributed import barrier +from torch.utils.data import DataLoader, Dataset +from transformers import DataCollatorForLanguageModeling, LlamaForCausalLM, Trainer +from transformers.modeling_outputs import CausalLMOutputWithPast +from transformers.models.llama.modeling_llama import LlamaFlashAttention2 +from transformers.trainer_utils import seed_worker +from transformers.utils import is_datasets_available + + +def get_train_dataloader(self) -> DataLoader: + """ + Returns the training [`~torch.utils.data.DataLoader`]. + + Will use no sampler if `train_dataset` does not implement `__len__`, a random sampler (adapted to distributed + training if necessary) otherwise. + + Subclass and override this method if you want to inject some custom behavior. + """ + if self.train_dataset is None: + raise ValueError("Trainer: training requires a train_dataset.") + + train_dataset = self.train_dataset + data_collator = self.data_collator + if is_datasets_available() and isinstance(train_dataset, datasets.Dataset): + train_dataset = self._remove_unused_columns(train_dataset, description="training") + else: + data_collator = self._get_collator_with_removed_columns(data_collator, description="training") + + dataloader_params = { + "batch_size": self._train_batch_size, + "collate_fn": data_collator, + "num_workers": self.args.dataloader_num_workers, + "pin_memory": self.args.dataloader_pin_memory, + "persistent_workers": self.args.dataloader_persistent_workers, + } + + if not isinstance(train_dataset, torch.utils.data.IterableDataset): + dataloader_params["sampler"] = self._get_train_sampler() + dataloader_params["drop_last"] = self.args.dataloader_drop_last + dataloader_params["worker_init_fn"] = seed_worker + + return DataLoader(train_dataset, **dataloader_params) + + +Trainer.get_train_dataloader = get_train_dataloader + + +forward_llama_ori = copy.deepcopy(LlamaForCausalLM.forward) + + +def extract_local(value, rank, world_size, device, dim=1): + value_chunks = value.chunk(2 * world_size, dim=dim) + local_value = torch.cat([value_chunks[rank], value_chunks[2 * world_size - rank - 1]], dim=dim) + return local_value.to(device) + + +def forward_llama( + self, + input_ids: torch.LongTensor = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[List[torch.FloatTensor]] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + labels: Optional[torch.LongTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + seqlens_in_batch: Optional[torch.LongTensor] = None, +) -> Union[Tuple, CausalLMOutputWithPast]: + + seq_len = input_ids.shape[-1] + rank = dist.get_rank() + num_processes = dist.get_world_size() + input_ids = extract_local(input_ids, rank, num_processes, input_ids.device) + labels = extract_local(labels, rank, num_processes, 
labels.device)
+    position_ids = (
+        torch.arange(seq_len, device=input_ids.device, dtype=torch.long).unsqueeze(0).expand(input_ids.shape[0], -1)
+    )
+    position_ids = extract_local(position_ids, rank, num_processes, position_ids.device)
+
+    return forward_llama_ori(
+        self=self,
+        input_ids=input_ids,
+        attention_mask=attention_mask,
+        position_ids=position_ids,
+        past_key_values=past_key_values,
+        inputs_embeds=inputs_embeds,
+        labels=labels,
+        use_cache=use_cache,
+        output_attentions=output_attentions,
+        output_hidden_states=output_hidden_states,
+        return_dict=return_dict,
+        seqlens_in_batch=seqlens_in_batch,
+    )
+
+
+LlamaForCausalLM.forward = forward_llama
+
+
+def ring_flash_attention_forward(
+    self,
+    query_states,
+    key_states,
+    value_states,
+    attention_mask,
+    query_length,
+    dropout=0.0,
+    softmax_scale=None,
+    seqlens_in_batch=None,
+):
+    attn_output = zigzag_ring_flash_attn_func(
+        query_states, key_states, value_states, dropout, softmax_scale=softmax_scale, causal=self.is_causal
+    )
+    return attn_output
+
+
+LlamaFlashAttention2._flash_attention_forward = ring_flash_attention_forward
+
+
+def judge_dir(resume_dir):
+    is_checkpoint_dir = False
+    if not os.path.exists(resume_dir):
+        return False
+    for _dir in os.listdir(resume_dir):
+        if "checkpoint" in _dir:
+            is_checkpoint_dir = True
+        if "pth" in _dir:
+            is_checkpoint_dir = True
+    return is_checkpoint_dir
+
+
+IGNORE_INDEX = -100
+DEFAULT_PAD_TOKEN = "[PAD]"
+DEFAULT_EOS_TOKEN = ""
+DEFAULT_BOS_TOKEN = ""
+
+
+@dataclass
+class ModelArguments:
+    model_name_or_path: Optional[str] = field(default="EleutherAI/pythia-1.4b-deduped")
+    model_type: Optional[str] = field(default="llama")
+
+
+@dataclass
+class TrainingArguments(transformers.TrainingArguments):
+    cache_dir: Optional[str] = field(default=None)
+    optim: str = field(default="adamw_torch")
+    model_max_length: int = field(
+        default=8192 * 4,
+        metadata={"help": "Maximum sequence length. Sequences will be right padded (and possibly truncated)."},
+    )
+    data_max_length: int = field(
+        default=80000,
+        metadata={"help": "Maximum training data sequence length. Sequences will be right padded (and possibly truncated)."},
+    )
+    use_flash_attn: bool = field(
+        default=True,
+        metadata={"help": "Whether to use flash attention for training."},
+    )
+    low_rank_training: bool = field(
+        default=True,
+        metadata={"help": "Whether to use low rank adaptation (LoRA) for training."},
+    )
+    trainable_params: str = field(
+        default="embed,norm",
+        metadata={"help": "Additional trainable parameters except LoRA weights, if low rank training."},
+    )
+    resume_from_checkpoint: bool = field(
+        default=True,
+        metadata={"help": "Whether to resume from the latest checkpoint in the output directory, if one exists."},
+    )
+    scaling_type: str = field(
+        default="linear",
+        metadata={"help": "RoPE scaling type."},
+    )
+    scaling_factor: float = field(
+        default=1.0,
+        metadata={"help": "RoPE scaling factor."},
+    )
+    rope_theta: float = field(
+        default=500000.0,
+        metadata={"help": "RoPE theta (base frequency) used for context extension."},
+    )
+    data_file: str = field(default="linear", metadata={"help": "Path to the training data file (jsonl)."})
+    peft_model: str = field(default=None, metadata={"help": "Path to an existing PEFT (LoRA) model to resume from."})
+
+
+def smart_tokenizer_and_embedding_resize(
+    special_tokens_dict: Dict,
+    tokenizer: transformers.PreTrainedTokenizer,
+    model: transformers.PreTrainedModel,
+):
+    """Resize tokenizer and embedding.
+ + Note: This is the unoptimized version that may make your embedding size not be divisible by 64. + """ + num_new_tokens = tokenizer.add_special_tokens(special_tokens_dict) + print("len(tokenizer)", len(tokenizer)) + # model.resize_token_embeddings(len(tokenizer)) + + if num_new_tokens > 0: + input_embeddings = model.get_input_embeddings().weight.data + output_embeddings = model.get_output_embeddings().weight.data + + input_embeddings_avg = input_embeddings[:-num_new_tokens].mean(dim=0, keepdim=True) + output_embeddings_avg = output_embeddings[:-num_new_tokens].mean(dim=0, keepdim=True) + + input_embeddings[-num_new_tokens:] = input_embeddings_avg + output_embeddings[-num_new_tokens:] = output_embeddings_avg + + +def tokenize_fn(tokenizer, example): + context_length = tokenizer.model_max_length + outputs = tokenizer( + tokenizer.eos_token.join(example["text"]), + truncation=True, + return_tensors="pt", + pad_to_multiple_of=context_length, + padding=True, + ) + return {"input_ids": outputs["input_ids"].view(-1, context_length)} + + +def chunk_fn(tokenizer, example): + input_ids = torch.tensor(example["text"], dtype=torch.int64) + # world_size = 8 # dist.get_world_size() + # input_ids = input_ids.unsqueeze(0).repeat(world_size, 1, 1).permute(1, 0, 2).reshape(-1, input_ids.shape[-1]) + return {"input_ids": input_ids} + + +def train(): + parser = transformers.HfArgumentParser((ModelArguments, TrainingArguments)) + model_args, training_args = parser.parse_args_into_dataclasses() + + # Set RoPE scaling factor + config = transformers.AutoConfig.from_pretrained( + model_args.model_name_or_path, + cache_dir=training_args.cache_dir, + ) + + config.rope_theta = training_args.rope_theta + config.max_position_embeddings = training_args.model_max_length + + dataset = load_dataset("json", data_files=training_args.data_file, cache_dir=training_args.cache_dir) + dataset = dataset.map(partial(chunk_fn, None), batched=True, num_proc=1, remove_columns=["text"]) + + # Load model and tokenizer + model = LlamaForCausalLM.from_pretrained( + model_args.model_name_or_path, + config=config, + cache_dir=training_args.cache_dir, + torch_dtype=torch.bfloat16, + ) + + tokenizer = transformers.AutoTokenizer.from_pretrained( + model_args.model_name_or_path, + cache_dir=training_args.cache_dir, + model_max_length=training_args.data_max_length, + padding_side="right", + use_fast=True, + ) + + special_tokens_dict = dict() + if tokenizer.pad_token is None: + special_tokens_dict["pad_token"] = DEFAULT_PAD_TOKEN + if tokenizer.eos_token is None: + special_tokens_dict["eos_token"] = DEFAULT_EOS_TOKEN + if tokenizer.bos_token is None: + special_tokens_dict["bos_token"] = DEFAULT_BOS_TOKEN + + smart_tokenizer_and_embedding_resize( + special_tokens_dict=special_tokens_dict, + tokenizer=tokenizer, + model=model, + ) + + rank = int(os.environ.get("RANK", -1)) + if rank > 0: + barrier() + + # dataset = load_dataset("yaofu/slimpajama-per-source-length-upsample", cache_dir=training_args.cache_dir) + # dataset = dataset.map(partial(chunk_fn,tokenizer),batched=True, num_proc=16, remove_columns=["labels", "source"]) + # dataset = load_dataset('json', data_files=training_args.data_file, cache_dir=training_args.cache_dir) + # dataset = dataset.map(partial(chunk_fn,tokenizer),batched=True, num_proc=2, remove_columns=["text"]) + # from IPython import embed; embed() + + if rank == 0: + barrier() + + print(dataset) + + data_collator = DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm=False) + + if training_args.low_rank_training: + if 
training_args.peft_model is None: + if model_args.model_type == "gpt-neox": + # added `dense` to match with llama as the basic LoRA would only target 'query_key_value' + targets = ["query_key_value", "dense"] + else: + targets = ["q_proj", "k_proj", "v_proj", "o_proj"] + + config = LoraConfig( + r=8, + lora_alpha=16, + target_modules=targets, + lora_dropout=0, + bias="none", + modules_to_save=training_args.trainable_params.split(","), + task_type="CAUSAL_LM", + ) + model = get_peft_model(model, config) + # enable trainable params + # [p.requires_grad_() for n, p in model.named_parameters() if any([k in n for k in training_args.trainable_params.split(",")])] + else: + model = PeftModel.from_pretrained( + model, + training_args.peft_model, + torch_dtype=torch.bfloat16, + ) + for n, p in model.named_parameters(): + if "lora" in n or any([k in n for k in training_args.trainable_params.split(",")]): + if not "original_module" in n: + p.requires_grad_() + if p.requires_grad: + print(n) + + model.config.use_cache = False # required for gradient checkpointing + model.enable_input_require_grads() # required for gradient checkpointing + model.gradient_checkpointing_enable() # enable gradient checkpointing + trainer = Trainer( + model=model, + tokenizer=tokenizer, + args=training_args, + train_dataset=dataset["train"], + eval_dataset=None, + data_collator=data_collator, + ) + + if training_args.resume_from_checkpoint and judge_dir(training_args.output_dir): + trainer.train(resume_from_checkpoint=True) + else: + trainer.train() + trainer.save_state() + trainer.save_model(output_dir=training_args.output_dir) + + +if __name__ == "__main__": + train() diff --git a/llava/utils/merge_lora_weights_and_save_hf_model.py b/llava/utils/merge_lora_weights_and_save_hf_model.py new file mode 100644 index 00000000..e80a3f09 --- /dev/null +++ b/llava/utils/merge_lora_weights_and_save_hf_model.py @@ -0,0 +1,105 @@ +# Copyright 2024 NVIDIA CORPORATION & AFFILIATES +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# +# SPDX-License-Identifier: Apache-2.0 +# This file is modified from https://github.com/dvlab-research/LongLoRA + +import argparse +import os +from typing import Dict + +import torch +import transformers +from peft import PeftModel + + +def parse_config(): + parser = argparse.ArgumentParser(description="arg parser") + parser.add_argument("--base_model", type=str, default="/data/pretrained-models/llama-7b-hf") + parser.add_argument("--peft_model", type=str, default=None, help="") + parser.add_argument("--save_path", type=str, default=None, help="") + parser.add_argument("--cache_dir", type=str, default=None, help="./cache_dir") + parser.add_argument("--rope_theta", type=int, default=15300000, help="") + parser.add_argument("--max_position_embeddings", type=int, default=65536, help="") + args = parser.parse_args() + return args + + +def smart_tokenizer_and_embedding_resize( + special_tokens_dict: Dict, + tokenizer: transformers.PreTrainedTokenizer, + model: transformers.PreTrainedModel, +): + """Resize tokenizer and embedding. + + Note: This is the unoptimized version that may make your embedding size not be divisible by 64. + """ + num_new_tokens = tokenizer.add_special_tokens(special_tokens_dict) + model.resize_token_embeddings(len(tokenizer)) + + if num_new_tokens > 0: + input_embeddings = model.get_input_embeddings().weight.data + output_embeddings = model.get_output_embeddings().weight.data + + input_embeddings_avg = input_embeddings[:-num_new_tokens].mean(dim=0, keepdim=True) + output_embeddings_avg = output_embeddings[:-num_new_tokens].mean(dim=0, keepdim=True) + + input_embeddings[-num_new_tokens:] = input_embeddings_avg + output_embeddings[-num_new_tokens:] = output_embeddings_avg + + +def main(): + args = parse_config() + device = "cuda:0" + torch.cuda.set_device(device) + + print("base model", args.base_model) + print("peft model", args.peft_model) + + config = transformers.AutoConfig.from_pretrained( + args.base_model, + cache_dir=args.cache_dir, + ) + + config.rope_theta = args.rope_theta + config.max_position_embeddings = args.max_position_embeddings + config.model_max_length = args.max_position_embeddings + config.tokenizer_model_max_length = args.max_position_embeddings + + # Load model and tokenizer + model = transformers.AutoModelForCausalLM.from_pretrained( + args.base_model, + config=config, + cache_dir=args.cache_dir, + torch_dtype=torch.bfloat16, + device_map="auto", + ) + + tokenizer = transformers.AutoTokenizer.from_pretrained( + args.base_model, + ) + + model = PeftModel.from_pretrained( + model, + args.peft_model, + device_map="auto", + torch_dtype=torch.bfloat16, + ) + model = model.merge_and_unload() + model.save_pretrained(args.save_path) + tokenizer.save_pretrained(args.save_path) + + +if __name__ == "__main__": + main() diff --git a/scripts/deepspeed_inference.yaml b/scripts/deepspeed_inference.yaml new file mode 100644 index 00000000..1a5fe8ac --- /dev/null +++ b/scripts/deepspeed_inference.yaml @@ -0,0 +1,17 @@ +compute_environment: LOCAL_MACHINE +debug: false +deepspeed_config: + deepspeed_config_file: scripts/zero3_offload_inference.json + zero3_init_flag: false +distributed_type: DEEPSPEED +downcast_bf16: 'no' +machine_rank: 0 +main_training_function: main +num_machines: 1 +num_processes: 8 +rdzv_backend: static +same_network: true +tpu_env: [] +tpu_use_cluster: false +tpu_use_sudo: false +use_cpu: false diff --git a/scripts/v1_5/eval/needle.sh b/scripts/v1_5/eval/needle.sh new file mode 100644 index 00000000..e9b7fded --- /dev/null +++ 
diff --git a/scripts/v1_5/eval/needle.sh b/scripts/v1_5/eval/needle.sh
new file mode 100644
index 00000000..e9b7fded
--- /dev/null
+++ b/scripts/v1_5/eval/needle.sh
@@ -0,0 +1,22 @@
+MODEL_NAME=$1
+MODEL_PATH=$2
+prompt_template=${3:-"llama_3"}
+max_frame_num=${4:-500}
+frame_interval=${5:-50}
+
+eval_path=llava/eval/vision_niah_vila
+mkdir -p $eval_path/data/haystack_embeddings/$MODEL_NAME
+mkdir -p $eval_path/data/needle_embeddings/$MODEL_NAME
+python $eval_path/produce_haystack_embedding.py --model $MODEL_PATH --output_dir $eval_path/data/haystack_embeddings/$MODEL_NAME --sampled_frames_num $max_frame_num --pooling_size 0
+python $eval_path/produce_needle_embedding.py --model $MODEL_PATH --output_dir $eval_path/data/needle_embeddings/$MODEL_NAME --pooling_size 0 --needle_dataset LongVa/v_niah_needles
+
+accelerate launch --num_processes 8 --config_file scripts/deepspeed_inference.yaml --main_process_port 6000 $eval_path/eval_vision_niah.py \
+    --model $MODEL_PATH \
+    --needle_embedding_dir $eval_path/data/needle_embeddings/$MODEL_NAME \
+    --haystack_dir $eval_path/data/haystack_embeddings/$MODEL_NAME \
+    --needle_dataset lmms-lab/v_niah_needles \
+    --prompt_template $prompt_template \
+    --max_frame_num $max_frame_num \
+    --min_frame_num $frame_interval \
+    --frame_interval $frame_interval \
+    --depth_interval 0.2
diff --git a/scripts/v1_5/longvila/8b/4_extend_llm_256k.sh b/scripts/v1_5/longvila/8b/4_extend_llm_256k.sh
new file mode 100644
index 00000000..3657f170
--- /dev/null
+++ b/scripts/v1_5/longvila/8b/4_extend_llm_256k.sh
@@ -0,0 +1,58 @@
+#!/bin/bash
+
+master_addr=$(scontrol show hostnames "$SLURM_JOB_NODELIST" | head -n 1)
+export MASTER_ADDR=${master_addr:-"127.0.0.1"}
+export CURRENT_RANK=${SLURM_PROCID:-"0"}
+worker_list=$(scontrol show hostnames "$SLURM_JOB_NODELIST" | tr '\n' ' ')
+n_node=${SLURM_JOB_NUM_NODES:-1}
+
+echo "MASTER_ADDR="$MASTER_ADDR
+echo "JobID: $SLURM_JOB_ID | Full list: $worker_list"
+
+EXTENDED_64k_PATH=$1
+OUTPUT=$2
+DATA_FILE=$3 # /lustre/fs2/portfolios/nvr/users/yukangc/datasets/SlimPajama-encode-llama3/encode_words_llama3_256k_first10k.jsonl
+
+model_max_length=262144
+rope_theta=207112184
+
+mkdir -p $OUTPUT
+
+torchrun --nnodes=$n_node --nproc_per_node=8 --master_port=25001 \
+    --master_addr $MASTER_ADDR --node_rank=$CURRENT_RANK \
+    llava/train/train_llm_to_long.py \
+    --model_name_or_path $EXTENDED_64k_PATH/llm \
+    --bf16 True \
+    --data_file $DATA_FILE \
+    --output_dir $OUTPUT \
+    --cache_dir ./cache-256k \
+    --model_max_length $model_max_length \
+    --data_max_length $model_max_length \
+    --rope_theta $rope_theta \
+    --use_flash_attn True \
+    --low_rank_training True \
+    --max_steps 80 \
+    --per_device_train_batch_size 1 \
+    --per_device_eval_batch_size 2 \
+    --gradient_accumulation_steps 16 \
+    --evaluation_strategy "no" \
+    --save_strategy "steps" \
+    --save_steps 40 \
+    --save_total_limit 2 \
+    --learning_rate 2e-5 \
+    --weight_decay 0.0 \
+    --warmup_steps 2 \
+    --lr_scheduler_type "constant_with_warmup" \
+    --logging_steps 1 \
+    --deepspeed "./scripts/zero2.json" \
+    --tf32 True
+
+cp -r $EXTENDED_64k_PATH/vision_tower $OUTPUT
+cp $EXTENDED_64k_PATH/config.json $OUTPUT
+cp -r $EXTENDED_64k_PATH/mm_projector $OUTPUT
+
+python3 llava/utils/merge_lora_weights_and_save_hf_model.py --base_model $EXTENDED_64k_PATH/llm \
+    --peft_model $OUTPUT \
+    --save_path $OUTPUT/llm \
+    --rope_theta $rope_theta \
+    --max_position_embeddings $model_max_length
diff --git a/scripts/v1_5/longvila/8b/4_extend_llm_64k.sh b/scripts/v1_5/longvila/8b/4_extend_llm_64k.sh
new file mode 100644
index 00000000..8b44cc7b
--- /dev/null
+++ b/scripts/v1_5/longvila/8b/4_extend_llm_64k.sh
@@ -0,0 +1,58 @@
+#!/bin/bash
+
+master_addr=$(scontrol show hostnames "$SLURM_JOB_NODELIST" | head -n 1)
+export MASTER_ADDR=${master_addr:-"127.0.0.1"}
+export CURRENT_RANK=${SLURM_PROCID:-"0"}
+worker_list=$(scontrol show hostnames "$SLURM_JOB_NODELIST" | tr '\n' ' ')
+n_node=${SLURM_JOB_NUM_NODES:-1}
+
+echo "MASTER_ADDR="$MASTER_ADDR
+echo "JobID: $SLURM_JOB_ID | Full list: $worker_list"
+
+STAGE3_PATH=$1
+OUTPUT=$2
+DATA_FILE=$3 # /lustre/fs2/portfolios/nvr/users/yukangc/datasets/SlimPajama-encode-llama3/encode_words_llama3_64k_first10k.jsonl
+
+model_max_length=65536
+rope_theta=15300000
+
+mkdir -p $OUTPUT
+
+torchrun --nnodes=$n_node --nproc_per_node=8 --master_port=25001 \
+    --master_addr $MASTER_ADDR --node_rank=$CURRENT_RANK \
+    llava/train/train_llm_to_long.py \
+    --model_name_or_path $STAGE3_PATH/llm \
+    --bf16 True \
+    --data_file $DATA_FILE \
+    --output_dir $OUTPUT \
+    --cache_dir ./cache-64k \
+    --model_max_length $model_max_length \
+    --data_max_length $model_max_length \
+    --rope_theta $rope_theta \
+    --use_flash_attn True \
+    --low_rank_training True \
+    --max_steps 80 \
+    --per_device_train_batch_size 1 \
+    --per_device_eval_batch_size 2 \
+    --gradient_accumulation_steps 32 \
+    --evaluation_strategy "no" \
+    --save_strategy "steps" \
+    --save_steps 40 \
+    --save_total_limit 2 \
+    --learning_rate 2e-5 \
+    --weight_decay 0.0 \
+    --warmup_steps 2 \
+    --lr_scheduler_type "constant_with_warmup" \
+    --logging_steps 1 \
+    --deepspeed "./scripts/zero2.json" \
+    --tf32 True
+
+cp -r $STAGE3_PATH/vision_tower $OUTPUT
+cp $STAGE3_PATH/config.json $OUTPUT
+cp -r $STAGE3_PATH/mm_projector $OUTPUT
+
+python3 llava/utils/merge_lora_weights_and_save_hf_model.py --base_model $STAGE3_PATH/llm \
+    --peft_model $OUTPUT \
+    --save_path $OUTPUT/llm \
+    --rope_theta $rope_theta \
+    --max_position_embeddings $model_max_length
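Taken together, the two context-extension scripts above and the long-SFT scripts below form a sequential pipeline. The ordering sketch that follows uses placeholder checkpoint and data paths and picks the 256-frame SFT stage purely as an example; none of these names come from the patch itself.

# Illustrative end-to-end ordering (placeholder paths, not from this patch).
# 1) Extend the stage-3 LLM to a 64k context.
bash scripts/v1_5/longvila/8b/4_extend_llm_64k.sh \
    ckpts/stage3 ckpts/extend-64k data/encode_words_llama3_64k_first10k.jsonl
# 2) Optionally extend further to 256k (used by the 512- and 1024-frame SFT scripts).
bash scripts/v1_5/longvila/8b/4_extend_llm_256k.sh \
    ckpts/extend-64k ckpts/extend-256k data/encode_words_llama3_256k_first10k.jsonl
# 3) Long-video SFT on top of the extended model (pick one frame setting).
bash scripts/v1_5/longvila/8b/5_long_sft_256frames.sh ckpts/extend-64k ckpts/sft-256frames
# 4) Needle-in-a-haystack evaluation of the resulting model.
bash scripts/v1_5/eval/needle.sh sft-256frames ckpts/sft-256frames llama_3 500 50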
diff --git a/scripts/v1_5/longvila/8b/5_long_sft_1024frames.sh b/scripts/v1_5/longvila/8b/5_long_sft_1024frames.sh
new file mode 100644
index 00000000..add41aa6
--- /dev/null
+++ b/scripts/v1_5/longvila/8b/5_long_sft_1024frames.sh
@@ -0,0 +1,59 @@
+#!/bin/bash
+
+master_addr=$(scontrol show hostnames "$SLURM_JOB_NODELIST" | head -n 1)
+export MASTER_ADDR=${master_addr:-"127.0.0.1"}
+export CURRENT_RANK=${SLURM_PROCID:-"0"}
+worker_list=$(scontrol show hostnames "$SLURM_JOB_NODELIST" | tr '\n' ' ')
+n_node=${SLURM_JOB_NUM_NODES:-1}
+
+echo "MASTER_ADDR="$MASTER_ADDR
+echo "JobID: $SLURM_JOB_ID | Full list: $worker_list"
+
+n_node=$SLURM_JOB_NUM_NODES
+gradient_accumulation_steps=$((64 / n_node))
+EXTENDED_256k_PATH=$1
+OUTPUT=$2
+
+torchrun --nnodes=$n_node --nproc_per_node=8 --master_port=25001 \
+    --master_addr $MASTER_ADDR --node_rank=$CURRENT_RANK \
+    llava/train/train_longvila_hybrid.py \
+    --longvila_sampler True \
+    --deepspeed ./scripts/zero3.json \
+    --model_name_or_path $EXTENDED_256k_PATH \
+    --version llama_3 \
+    --data_mixture longvideo_sft_deepseek \
+    --vision_tower google/siglip-so400m-patch14-384 \
+    --mm_vision_select_feature cls_patch \
+    --mm_projector mlp_downsample \
+    --num_video_frames 1024 \
+    --tune_vision_tower True \
+    --tune_mm_projector True \
+    --tune_language_model True \
+    --mm_vision_select_layer -2 \
+    --mm_use_im_start_end False \
+    --mm_use_im_patch_token False \
+    --image_aspect_ratio resize \
+    --bf16 True \
+    --output_dir $OUTPUT \
+    --num_train_epochs 6 \
+    --per_device_train_batch_size 4 \
+    --per_device_eval_batch_size 4 \
+    --gradient_accumulation_steps $gradient_accumulation_steps \
+    --evaluation_strategy "no" \
+    --save_strategy "steps" \
+    --save_steps 50 \
+    --fps 2.0 \
+    --save_total_limit 10 \
+    --learning_rate 5e-5 \
+    --weight_decay 0. \
+    --warmup_ratio 0.03 \
+    --lr_scheduler_type "cosine" \
+    --seq_parallel_size 32 \
+    --logging_steps 1 \
+    --tf32 True \
+    --ddp_timeout 72000 \
+    --model_max_length 65536 \
+    --gradient_checkpointing True \
+    --dataloader_num_workers 4 \
+    --lazy_preprocess True \
+    --report_to wandb
diff --git a/scripts/v1_5/longvila/8b/5_long_sft_128frames.sh b/scripts/v1_5/longvila/8b/5_long_sft_128frames.sh
new file mode 100644
index 00000000..532956ad
--- /dev/null
+++ b/scripts/v1_5/longvila/8b/5_long_sft_128frames.sh
@@ -0,0 +1,59 @@
+#!/bin/bash
+
+master_addr=$(scontrol show hostnames "$SLURM_JOB_NODELIST" | head -n 1)
+export MASTER_ADDR=${master_addr:-"127.0.0.1"}
+export CURRENT_RANK=${SLURM_PROCID:-"0"}
+worker_list=$(scontrol show hostnames "$SLURM_JOB_NODELIST" | tr '\n' ' ')
+n_node=${SLURM_JOB_NUM_NODES:-1}
+
+echo "MASTER_ADDR="$MASTER_ADDR
+echo "JobID: $SLURM_JOB_ID | Full list: $worker_list"
+
+n_node=$SLURM_JOB_NUM_NODES
+gradient_accumulation_steps=$((32 / n_node))
+EXTENDED_64k_PATH=$1
+OUTPUT=$2
+
+torchrun --nnodes=$n_node --nproc_per_node=8 --master_port=25001 \
+    --master_addr $MASTER_ADDR --node_rank=$CURRENT_RANK \
+    llava/train/train_hybrid.py \
+    --longvila_sampler True \
+    --deepspeed ./scripts/zero3.json \
+    --model_name_or_path $EXTENDED_64k_PATH \
+    --version llama_3 \
+    --data_mixture longvideo_sft_deepseek \
+    --vision_tower google/siglip-so400m-patch14-384 \
+    --mm_vision_select_feature cls_patch \
+    --mm_projector mlp_downsample \
+    --num_video_frames 128 \
+    --tune_vision_tower True \
+    --tune_mm_projector True \
+    --tune_language_model True \
+    --mm_vision_select_layer -2 \
+    --mm_use_im_start_end False \
+    --mm_use_im_patch_token False \
+    --image_aspect_ratio resize \
+    --bf16 True \
+    --output_dir $OUTPUT \
+    --num_train_epochs 6 \
+    --per_device_train_batch_size 8 \
+    --per_device_eval_batch_size 4 \
+    --gradient_accumulation_steps $gradient_accumulation_steps \
+    --evaluation_strategy "no" \
+    --save_strategy "steps" \
+    --save_steps 50 \
+    --fps 2.0 \
+    --save_total_limit 10 \
+    --learning_rate 5e-5 \
+    --weight_decay 0. \
+    --warmup_ratio 0.03 \
+    --lr_scheduler_type "cosine" \
+    --seq_parallel_size 8 \
+    --logging_steps 1 \
+    --tf32 True \
+    --ddp_timeout 72000 \
+    --model_max_length 65536 \
+    --gradient_checkpointing True \
+    --dataloader_num_workers 4 \
+    --lazy_preprocess True \
+    --report_to wandb
diff --git a/scripts/v1_5/longvila/8b/5_long_sft_256frames.sh b/scripts/v1_5/longvila/8b/5_long_sft_256frames.sh
new file mode 100644
index 00000000..6fd58ea5
--- /dev/null
+++ b/scripts/v1_5/longvila/8b/5_long_sft_256frames.sh
@@ -0,0 +1,59 @@
+#!/bin/bash
+
+master_addr=$(scontrol show hostnames "$SLURM_JOB_NODELIST" | head -n 1)
+export MASTER_ADDR=${master_addr:-"127.0.0.1"}
+export CURRENT_RANK=${SLURM_PROCID:-"0"}
+worker_list=$(scontrol show hostnames "$SLURM_JOB_NODELIST" | tr '\n' ' ')
+n_node=${SLURM_JOB_NUM_NODES:-1}
+
+echo "MASTER_ADDR="$MASTER_ADDR
+echo "JobID: $SLURM_JOB_ID | Full list: $worker_list"
+
+n_node=$SLURM_JOB_NUM_NODES
+gradient_accumulation_steps=$((32 / n_node))
+EXTENDED_64k_PATH=$1
+OUTPUT=$2
+
+torchrun --nnodes=$n_node --nproc_per_node=8 --master_port=25001 \
+    --master_addr $MASTER_ADDR --node_rank=$CURRENT_RANK \
+    llava/train/train_longvila_hybrid.py \
+    --longvila_sampler True \
+    --deepspeed ./scripts/zero3.json \
+    --model_name_or_path $EXTENDED_64k_PATH \
+    --version llama_3 \
+    --data_mixture longvideo_sft_deepseek \
+    --vision_tower google/siglip-so400m-patch14-384 \
+    --mm_vision_select_feature cls_patch \
+    --mm_projector mlp_downsample \
+    --num_video_frames 256 \
+    --tune_vision_tower True \
+    --tune_mm_projector True \
+    --tune_language_model True \
+    --mm_vision_select_layer -2 \
+    --mm_use_im_start_end False \
+    --mm_use_im_patch_token False \
+    --image_aspect_ratio resize \
+    --bf16 True \
+    --output_dir $OUTPUT \
+    --num_train_epochs 6 \
+    --per_device_train_batch_size 8 \
+    --per_device_eval_batch_size 4 \
+    --gradient_accumulation_steps $gradient_accumulation_steps \
+    --evaluation_strategy "no" \
+    --save_strategy "steps" \
+    --save_steps 50 \
+    --fps 2.0 \
+    --save_total_limit 10 \
+    --learning_rate 5e-5 \
+    --weight_decay 0. \
+    --warmup_ratio 0.03 \
+    --lr_scheduler_type "cosine" \
+    --seq_parallel_size 16 \
+    --logging_steps 1 \
+    --tf32 True \
+    --ddp_timeout 72000 \
+    --model_max_length 65536 \
+    --gradient_checkpointing True \
+    --dataloader_num_workers 4 \
+    --lazy_preprocess True \
+    --report_to wandb
diff --git a/scripts/v1_5/longvila/8b/5_long_sft_512frames.sh b/scripts/v1_5/longvila/8b/5_long_sft_512frames.sh
new file mode 100644
index 00000000..0bc9e299
--- /dev/null
+++ b/scripts/v1_5/longvila/8b/5_long_sft_512frames.sh
@@ -0,0 +1,59 @@
+#!/bin/bash
+
+master_addr=$(scontrol show hostnames "$SLURM_JOB_NODELIST" | head -n 1)
+export MASTER_ADDR=${master_addr:-"127.0.0.1"}
+export CURRENT_RANK=${SLURM_PROCID:-"0"}
+worker_list=$(scontrol show hostnames "$SLURM_JOB_NODELIST" | tr '\n' ' ')
+n_node=${SLURM_JOB_NUM_NODES:-1}
+
+echo "MASTER_ADDR="$MASTER_ADDR
+echo "JobID: $SLURM_JOB_ID | Full list: $worker_list"
+
+n_node=$SLURM_JOB_NUM_NODES
+gradient_accumulation_steps=$((32 / n_node))
+EXTENDED_256k_PATH=$1
+OUTPUT=$2
+
+torchrun --nnodes=$n_node --nproc_per_node=8 --master_port=25001 \
+    --master_addr $MASTER_ADDR --node_rank=$CURRENT_RANK \
+    llava/train/train_longvila_hybrid.py \
+    --longvila_sampler True \
+    --deepspeed ./scripts/zero3.json \
+    --model_name_or_path $EXTENDED_256k_PATH \
+    --version llama_3 \
+    --data_mixture longvideo_sft_deepseek \
+    --vision_tower google/siglip-so400m-patch14-384 \
+    --mm_vision_select_feature cls_patch \
+    --mm_projector mlp_downsample \
+    --num_video_frames 512 \
+    --tune_vision_tower True \
+    --tune_mm_projector True \
+    --tune_language_model True \
+    --mm_vision_select_layer -2 \
+    --mm_use_im_start_end False \
+    --mm_use_im_patch_token False \
+    --image_aspect_ratio resize \
+    --bf16 True \
+    --output_dir $OUTPUT \
+    --num_train_epochs 6 \
+    --per_device_train_batch_size 8 \
+    --per_device_eval_batch_size 4 \
+    --gradient_accumulation_steps $gradient_accumulation_steps \
+    --evaluation_strategy "no" \
+    --save_strategy "steps" \
+    --save_steps 50 \
+    --fps 2.0 \
+    --save_total_limit 10 \
+    --learning_rate 5e-5 \
+    --weight_decay 0. \
+    --warmup_ratio 0.03 \
+    --lr_scheduler_type "cosine" \
+    --seq_parallel_size 32 \
+    --logging_steps 1 \
+    --tf32 True \
+    --ddp_timeout 72000 \
+    --model_max_length 65536 \
+    --gradient_checkpointing True \
+    --dataloader_num_workers 4 \
+    --lazy_preprocess True \
+    --report_to wandb
diff --git a/scripts/zero3_offload_inference.json b/scripts/zero3_offload_inference.json
new file mode 100644
index 00000000..5a0bbc39
--- /dev/null
+++ b/scripts/zero3_offload_inference.json
@@ -0,0 +1,21 @@
+{
+    "bf16": {
+        "enabled": "auto"
+    },
+    "fp16": {
+        "enabled": "auto"
+    },
+    "zero_optimization": {
+        "stage": 3,
+        "stage3_prefetch_bucket_size": 33554432,
+        "stage3_param_persistence_threshold": 4096,
+        "stage3_max_live_parameters": 33554432,
+        "offload_param": {
+            "device": "cpu",
+            "pin_memory": true
+        }
+    },
+    "train_batch_size": 8,
+    "train_micro_batch_size_per_gpu": 1,
+    "wall_clock_breakdown": false
+}
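When reusing scripts/zero3_offload_inference.json with a different GPU count, note that DeepSpeed expects train_batch_size to equal train_micro_batch_size_per_gpu times gradient_accumulation_steps times the world size, so the hard-coded value of 8 only matches the num_processes: 8 set in scripts/deepspeed_inference.yaml. The snippet below is a small sketch of that identity, not part of the patch; it assumes jq is installed and that gradient_accumulation_steps keeps its default of 1.

# Sketch of the DeepSpeed batch-size identity (assumes jq; not part of this patch).
WORLD_SIZE=8  # keep in sync with num_processes in scripts/deepspeed_inference.yaml
MICRO=$(jq '.train_micro_batch_size_per_gpu' scripts/zero3_offload_inference.json)
TOTAL=$(jq '.train_batch_size' scripts/zero3_offload_inference.json)
if [ "$TOTAL" -ne $((MICRO * WORLD_SIZE)) ]; then
    echo "train_batch_size ($TOTAL) != micro ($MICRO) x world_size ($WORLD_SIZE)" >&2
fi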