From fafeb22b75079871e0038a138c3e71ffebd978d5 Mon Sep 17 00:00:00 2001 From: Slyne Deng Date: Thu, 11 Jul 2024 16:18:08 -0700 Subject: [PATCH] Cherry pick: LITA Integration (#9684) Signed-off-by: Slyne Deng Co-authored-by: Slyne Deng --- .github/workflows/cicd-main.yml | 24 +- .../multimodal_llm/neva/conf/lita_config.yaml | 242 +++++++ .../multimodal_llm/neva/conf/vita_config.yaml | 231 +++++++ ...va_to_neva.py => convert_llava_to_neva.py} | 142 +++- .../neva/eval/eval_video_rtl.py | 196 ++++++ .../multimodal_llm/neva/eval/eval_vqa.py | 207 ++++++ .../multimodal_llm/neva/neva_evaluation.py | 202 ++++-- .../multimodal/data/neva/conversation.py | 4 + .../multimodal/data/neva/neva_dataset.py | 105 ++- .../models/multimodal_llm/neva/neva_model.py | 175 ++++- nemo/collections/multimodal/parts/utils.py | 31 +- .../common/text_generation_strategy.py | 17 + .../modules/common/text_generation_utils.py | 95 ++- .../convert_dvc_dataset_for_evaluation.py | 160 +++++ .../convert_dvc_dataset_for_training.py | 322 +++++++++ .../convert_video_qa_dataset.py | 184 ++++++ .../generate_qa_data.py | 369 +++++++++++ .../prepare_youmakeup.py | 325 +++++++++ tutorials/multimodal/LITA Tutorial.ipynb | 621 ++++++++++++++++++ tutorials/multimodal/NeVA Tutorial.ipynb | 4 +- tutorials/multimodal/README.md | 1 + tutorials/multimodal/images/LITA_arch.png | Bin 0 -> 268131 bytes 22 files changed, 3547 insertions(+), 110 deletions(-) create mode 100644 examples/multimodal/multimodal_llm/neva/conf/lita_config.yaml create mode 100644 examples/multimodal/multimodal_llm/neva/conf/vita_config.yaml rename examples/multimodal/multimodal_llm/neva/{convert_hf_llava_to_neva.py => convert_llava_to_neva.py} (73%) create mode 100644 examples/multimodal/multimodal_llm/neva/eval/eval_video_rtl.py create mode 100644 examples/multimodal/multimodal_llm/neva/eval/eval_vqa.py create mode 100644 scripts/multimodal_dataset_conversion/convert_dvc_dataset_for_evaluation.py create mode 100644 scripts/multimodal_dataset_conversion/convert_dvc_dataset_for_training.py create mode 100644 scripts/multimodal_dataset_conversion/convert_video_qa_dataset.py create mode 100644 scripts/multimodal_dataset_conversion/generate_qa_data.py create mode 100644 scripts/multimodal_dataset_conversion/prepare_youmakeup.py create mode 100644 tutorials/multimodal/LITA Tutorial.ipynb create mode 100644 tutorials/multimodal/images/LITA_arch.png diff --git a/.github/workflows/cicd-main.yml b/.github/workflows/cicd-main.yml index d225ee3ab429..414516d81f18 100644 --- a/.github/workflows/cicd-main.yml +++ b/.github/workflows/cicd-main.yml @@ -179,7 +179,28 @@ jobs: rm -f /home/TestData/nlp/megatron_gpt/falcon-ci-hf/falcon_ci.nemo AFTER_SCRIPT: | rm -rf /home/TestData/nlp/megatron_gpt/falcon-ci-hf/model_weights - + + # L2: Community llava multimodal Checkpoints tests + L2_Community_vita_Checkpoints_tests_Llama3: + needs: [cicd-test-container-setup] + uses: ./.github/workflows/_test_template.yml + with: + RUNNER: self-hosted-azure + SCRIPT: | + export PYTHONPATH=/home/TestData/multimodal/video_neva/LLaVA:$PYTHONPATH + CUDA_VISIBLE_DEVICES=0 python examples/multimodal/multimodal_llm/neva/convert_llava_to_neva.py \ + --in-file /home/TestData/multimodal/video_neva/Llama-3-VILA1.5-8B/llm \ + --mm-projector-ckpt-dir /home/TestData/multimodal/video_neva/Llama-3-VILA1.5-8B/mm_projector \ + --mm-vision-tower /home/TestData/multimodal/video_neva/Llama-3-VILA1.5-8B/vision_tower \ + --tokenizer-model /home/TestData/multimodal/video_neva/vita-tokenizer/ \ + --config-file vita_config.yaml 
\ + --out-file=/home/TestData/multimodal/video_neva/llama3-ci-hf/llama3_ci.nemo \ + --model-type VITA \ + --conv-template llama_3 + AFTER_SCRIPT: | + rm -f /home/TestData/multimodal/video_neva/llama3-ci-hf/llama3_ci.nemo + rm -rf /home/TestData/multimodal/video_neva/llama3-ci-hf/model_weights + # this test is using a 7B model which is too large for GitHub CI # replace the model in this test with a toy model or move the test # to the nightly CI @@ -4437,6 +4458,7 @@ jobs: - L2_Community_LLM_Checkpoints_tests_Llama - L2_Community_LLM_Checkpoints_tests_StarCoder - L2_Community_LLM_Checkpoints_tests_Falcon + - L2_Community_vita_Checkpoints_tests_Llama3 #- OPTIONAL_L2_Community_LLM_Checkpoints_tests_Baichuan2 - ASR_dev_run_Speech_to_Text - ASR_dev_run_Speech_to_Text_WPE_-_CitriNet diff --git a/examples/multimodal/multimodal_llm/neva/conf/lita_config.yaml b/examples/multimodal/multimodal_llm/neva/conf/lita_config.yaml new file mode 100644 index 000000000000..591f528810fc --- /dev/null +++ b/examples/multimodal/multimodal_llm/neva/conf/lita_config.yaml @@ -0,0 +1,242 @@ +name: nemo_video_lita_neva +restore_from_path: null # used when starting from a .nemo file + +trainer: + devices: 1 + num_nodes: 1 + accelerator: gpu + precision: bf16 + logger: False # logger provided by exp_manager + enable_checkpointing: False + use_distributed_sampler: False + max_epochs: -1 # PTL default. In practice, max_steps will be reached first. + max_steps: 10000 # consumed_samples = global_step * micro_batch_size * data_parallel_size * accumulate_grad_batches + log_every_n_steps: 10 + val_check_interval: 100 + check_val_every_n_epoch: null + limit_val_batches: 50 + limit_test_batches: 500 + accumulate_grad_batches: 1 # do not modify, grad acc is automatic for training megatron models + gradient_clip_val: 1.0 + benchmark: False + enable_model_summary: False # default PTL callback for this does not support model parallelism, instead we log manually + +exp_manager: + explicit_log_dir: null + exp_dir: null + name: nemo_video_neva_lita + create_wandb_logger: True + wandb_logger_kwargs: + project: null + name: null + resume_if_exists: True + resume_ignore_no_checkpoint: True + resume_from_checkpoint: ${model.resume_from_checkpoint} + create_checkpoint_callback: True + checkpoint_callback_params: + monitor: val_loss + save_top_k: 5 + mode: min + always_save_nemo: False # saves nemo file during validation, not implemented for model parallel + save_nemo_on_train_end: True # not recommended when training large models on clusters with short time limits + filename: 'megatron_clip--{val_loss:.2f}-{step}-{consumed_samples}' + model_parallel_size: ${multiply:${model.tensor_model_parallel_size}, ${model.pipeline_model_parallel_size}} + ema: + enable: False + decay: 0.9999 + validate_original_weights: False + every_n_steps: 1 + cpu_offload: False + +model: + precision: ${trainer.precision} + + # specify micro_batch_size, global_batch_size, and model parallelism + # gradient accumulation will be done automatically based on data_parallel_size + + # Batch size guideline for different types of dataset + micro_batch_size: 1 # limited by GPU memory + global_batch_size: 2 # will use more micro batches to reach global batch size + + tensor_model_parallel_size: 1 # intra-layer model parallelism + pipeline_model_parallel_size: 1 # inter-layer model parallelism + context_parallel_size: 1 # kqv model parallelism + virtual_pipeline_model_parallel_size: null # interleaved pipeline + + restore_from_path: null # used in fine-tuning + + # Multimodal 
configs + mm_cfg: + llm: + from_pretrained: null #path to nemo checkpoint + freeze: False + model_type: llama_2 # `nvgpt` or `llama_2` supported + vision_encoder: + from_pretrained: "Lin-Chen/ShareGPT4V-13B_Pretrained_vit-large336-l12" # huggingface path or name + from_hf: True + crop_size: [336, 336] + patch_dim: 14 + hidden_size: 1024 # could be found from model but tricky in code + vision_select_layer: -2 # default to the last layer + class_token_length: 1 + freeze: True + lita: + lita_video_arch: 'temporal_all_resolution' # ['temporal_spatial_pool', 'temporal_spatial', 'temporal_all_resolution'] 'temporal_spatial_pool' is used in lita1.0 + visual_token_format: 'im_vid_start_end' # ["v1", "im_vid_start_end"] v1 means do nothing, im_vid_start_end means add image and video start and end tokens around spatial and temporal tokens + sample_frames: 4 # for lita 1.5 sample_frames are used for spatial tokens, and spatial tokens will no longer do pooling and instead, it will use full tokens + use_lita: True + pretrain_mm_mlp_adapter: null # path to pretrained mm adapter + mm_mlp_adapter_type: mlp2x_gelu # ['linear', 'mlp2x_gelu', 'mlp_downsample'] + use_im_start_end: False + + # ========LORA configs start======= + #peft: + # peft_scheme: "lora" + # restore_from_path: null + # lora_tuning: + # adapter_dim: 128 + # alpha: 256 + # target_modules: ['all'] + # adapter_dropout: 0.0 + # column_init_method: 'xavier' # IGNORED if linear_adapter is used, options: xavier, zero or normal + # row_init_method: 'zero' # IGNORED if linear_adapter is used, options: xavier, zero or normal + # layer_selection: null # selects in which layers to add lora adapters. e.g. [1,12] will add lora to layer 1 (lowest) and 12. null will apply adapters to all layers + # weight_tying: False + # position_embedding_strategy: null # used only when weight_tying is True + # =======LORA configs end======= + + # LLM configs + # use GPTModel from megatron.core + mcore_gpt: True + + # model architecture + encoder_seq_length: 4096 + max_position_embeddings: ${.encoder_seq_length} + position_embedding_type: rope + num_layers: 32 + hidden_size: 4096 + ffn_hidden_size: 11008 # Transformer FFN hidden size. Usually 4 * hidden_size. + num_attention_heads: 32 + init_method_std: 0.014 # Standard deviation of the zero mean normal distribution used for weight initialization.') + use_scaled_init_method: True # use scaled residuals initialization + hidden_dropout: 0.0 # Dropout probability for hidden state transformer. + attention_dropout: 0.0 # Dropout probability for attention + ffn_dropout: 0.0 # Dropout probability in the feed-forward layer. + kv_channels: null # Projection weights dimension in multi-head attention. Set to hidden_size // num_attention_heads if null + apply_query_key_layer_scaling: True # scale Q * K^T by 1 / layer-number. + normalization: 'rmsnorm' # Normalization layer to use. Options are 'layernorm', 'rmsnorm' + layernorm_epsilon: 1e-5 + do_layer_norm_weight_decay: False # True means weight decay on all params + make_vocab_size_divisible_by: 16 # Pad the vocab size to be divisible by this value for computation efficiency. + pre_process: True # add embedding + post_process: True # add pooler + persist_layer_norm: True # Use of persistent fused layer norm kernel. + bias: False # Whether to use bias terms in all weight matrices. 
+ activation: 'fast-swiglu' # Options ['gelu', 'geglu', 'swiglu', 'reglu', 'squared-relu', 'fast-geglu', 'fast-swiglu', 'fast-reglu'] + headscale: False # Whether to learn extra parameters that scale the output of the each self-attention head. + transformer_block_type: 'pre_ln' # Options ['pre_ln', 'post_ln', 'normformer'] + normalize_attention_scores: True # Whether to scale the output Q * K^T by 1 / sqrt(hidden_size_per_head). This arg is provided as a configuration option mostly for compatibility with models that have been weight-converted from HF. You almost always want to se this to True. + rotary_percentage: 1.0 # If using position_embedding_type=rope, then the per head dim is multiplied by this. + attention_type: 'multihead' # Attention type. Options ['multihead'] + share_embeddings_and_output_weights: False # Share embedding and output layer weights. + overlap_p2p_comm: False # Overlap p2p communication with computes. This argument is valid only when `virtual_pipeline_model_parallel_size` is larger than 1 + batch_p2p_comm: True # Batch consecutive inter-peer send/recv operations. This argument is valid only when `virtual_pipeline_model_parallel_size` is larger than 1 + seq_len_interpolation_factor: null # RoPE Interpolation factor for sequence length. This is used to build long-context models with RoPE ex: https://arxiv.org/abs/2306.15595. + num_query_groups: null # Number of query groups for group query attention. If None, normal attention is used. + + ## Activation Checkpointing + activations_checkpoint_granularity: null # 'selective' or 'full' + activations_checkpoint_method: null # 'uniform', 'block', not used with 'selective' + activations_checkpoint_num_layers: null # not used with 'selective' + num_micro_batches_with_partial_activation_checkpoints: null + activations_checkpoint_layers_per_pipeline: null + sequence_parallel: False + + # precision + native_amp_init_scale: 4294967296 # 2 ** 32 + native_amp_growth_interval: 1000 + hysteresis: 2 # Gradient scale hysteresis + fp32_residual_connection: False # Move residual connections to fp32 + fp16_lm_cross_entropy: False # Move the cross entropy unreduced loss calculation for lm head to fp16 + + # model fusions + masked_softmax_fusion: True # Use a kernel that fuses the attention softmax with it's mask. + bias_dropout_add_fusion: False # Use a kernel that fuses the bias addition, dropout and residual connection addition. + + use_cpu_initialization: False # Init weights on the CPU (slow for large models) + onnx_safe: False # Use work-arounds for known problems with Torch ONNX exporter. + gradient_accumulation_fusion: False # Fuse weight gradient accumulation to GEMMs. Only used with pipeline parallelism. + openai_gelu: False + bias_activation_fusion: False + megatron_legacy: False + + transformer_engine: True + fp8: False # enables fp8 in TransformerLayer forward + fp8_e4m3: False # sets fp8_format = recipe.Format.E4M3 + fp8_hybrid: False # sets fp8_format = recipe.Format.HYBRID + fp8_margin: 0 # scaling margin + fp8_interval: 1 # scaling update interval + fp8_amax_history_len: 1 # Number of steps for which amax history is recorded per tensor + fp8_amax_compute_algo: most_recent # 'most_recent' or 'max'. Algorithm for computing amax from history + use_emha: False # Use fused multi-head attention for large sequence-length. Note this is not yet supported. Please set to False. 
+ + # Megatron O2-style half-precision + megatron_amp_O2: False # Enable O2-level automatic mixed precision using main parameters + async_grad_allreduce: False + grad_allreduce_chunk_size_mb: 125 + grad_div_ar_fusion: True # Fuse grad division into torch.distributed.all_reduce + + # miscellaneous + seed: 1234 + resume_from_checkpoint: null # manually set the checkpoint file to load from + apex_transformer_log_level: 30 # Python logging level displays logs with severity greater than or equal to this + gradient_as_bucket_view: True # PyTorch DDP argument. Allocate gradients in a contiguous bucket to save memory (less fragmentation and buffer memory) + + tokenizer: + library: 'sentencepiece' + type: null + model: /ws/converted_nemo_model/tokenizer_1_5.model + vocab_file: null + merge_file: null + delimiter: null # only used for tabular tokenizer + sentencepiece_legacy: False # Legacy=True allows you to add special tokens to sentencepiece tokenizers. + additional_special_tokens: null # ["", "", "", "", "", ""] + + data: + packed_sequence: False + num_workers: 8 + dataloader_type: cyclic + data_path: null + lazy_preprocess: True + is_multimodal: True + media_type: video # currently supported: image or video + splice_single_frame: null # 'first', 'middle', 'last' will represent video as first / middle / last frame only, all other frames discarded. + num_frames: 256 # selects the number of frames to use from the video + sep_token_between_frames: False # TODO: allow usage of separator tokens between frames + sep_image_conv_front: False + image_token_len: 576 #lita 1.0 uses 256 + conv_template: v1 # check `nemo/collections/multimodal/data/neva/conversation.py` + image_folder: null + video_folder: null + image_aspect_ratio: 'pad' # lita 1.0 uses 'square' + + # Nsys profiling options + nsys_profile: + enabled: False + start_step: 10 # Global batch to start profiling + end_step: 10 # Global batch to end profiling + ranks: [ 0 ] # Global rank IDs to profile + gen_shape: False # Generate model and kernel details including input shapes + + optim: + name: fused_adam + lr: 2e-5 + weight_decay: 0. + betas: + - 0.9 + - 0.95 + sched: + name: CosineAnnealing + warmup_steps: 140 + constant_steps: 0 + min_lr: 2e-7 diff --git a/examples/multimodal/multimodal_llm/neva/conf/vita_config.yaml b/examples/multimodal/multimodal_llm/neva/conf/vita_config.yaml new file mode 100644 index 000000000000..7be99308a280 --- /dev/null +++ b/examples/multimodal/multimodal_llm/neva/conf/vita_config.yaml @@ -0,0 +1,231 @@ +name: nemo_video_lita_neva +restore_from_path: null # used when starting from a .nemo file + +trainer: + devices: 8 + num_nodes: 1 + accelerator: gpu + precision: bf16 + logger: False # logger provided by exp_manager + enable_checkpointing: False + use_distributed_sampler: False + max_epochs: -1 # PTL default. In practice, max_steps will be reached first. 
+ max_steps: 10000 # consumed_samples = global_step * micro_batch_size * data_parallel_size * accumulate_grad_batches + log_every_n_steps: 10 + val_check_interval: 100 + check_val_every_n_epoch: null + limit_val_batches: 50 + limit_test_batches: 500 + accumulate_grad_batches: 1 # do not modify, grad acc is automatic for training megatron models + gradient_clip_val: 1.0 + benchmark: False + enable_model_summary: False # default PTL callback for this does not support model parallelism, instead we log manually + +exp_manager: + explicit_log_dir: null + exp_dir: null + name: nemo_video_neva_lita + create_wandb_logger: True + wandb_logger_kwargs: + project: null + name: null + resume_if_exists: True + resume_ignore_no_checkpoint: True + resume_from_checkpoint: ${model.resume_from_checkpoint} + create_checkpoint_callback: True + checkpoint_callback_params: + monitor: val_loss + save_top_k: 5 + mode: min + always_save_nemo: False # saves nemo file during validation, not implemented for model parallel + save_nemo_on_train_end: True # not recommended when training large models on clusters with short time limits + filename: 'megatron_clip--{val_loss:.2f}-{step}-{consumed_samples}' + model_parallel_size: ${multiply:${model.tensor_model_parallel_size}, ${model.pipeline_model_parallel_size}} + ema: + enable: False + decay: 0.9999 + validate_original_weights: False + every_n_steps: 1 + cpu_offload: False + +model: + precision: ${trainer.precision} + + # specify micro_batch_size, global_batch_size, and model parallelism + # gradient accumulation will be done automatically based on data_parallel_size + + # Batch size guideline for different types of dataset + micro_batch_size: 1 # limited by GPU memory + global_batch_size: 128 # will use more micro batches to reach global batch size + + tensor_model_parallel_size: 1 # intra-layer model parallelism + pipeline_model_parallel_size: 1 # inter-layer model parallelism + context_parallel_size: 1 # kqv model parallelism + virtual_pipeline_model_parallel_size: null # interleaved pipeline + + restore_from_path: null # used in fine-tuning + + # Multimodal configs + mm_cfg: + llm: + from_pretrained: null #path to nemo checkpoint + freeze: False + model_type: vita + vision_encoder: + from_pretrained: null # path or name + model_type: null + from_hf: True + crop_size: [384, 384] + patch_dim: 14 + hidden_size: 1152 # could be found from model but tricky in code + vision_select_layer: -2 # default to the last layer + vision_select_feature: 'cls_patch' # default is patch + class_token_length: 1 + freeze: True + lita: + lita_video_arch: 'temporal_all_resolution' # ['temporal_spatial_pool', 'temporal_spatial', 'temporal_all_resolution'] + visual_token_format: 'im_vid_start_end' # ["v1", "im_vid_start_end"] v1 means do nothing, im_vid_start_end means add image and video start and end tokens around spatial and temporal tokens + sample_frames: 4 # for lita 1.5 sample_frames are used for spatial tokens, and spatial tokens will no longer do pooling and instead, it will use full tokens + use_lita: True + pretrain_mm_mlp_adapter: null # path to pretrained mm adapter + mm_mlp_adapter_type: mlp_downsample # ['linear', 'mlp2x_gelu', 'mlp_downsample'] + + use_im_start_end: False + + + # LLM configs + # use GPTModel from megatron.core + mcore_gpt: True + + # model architecture + encoder_seq_length: 8192 + max_position_embeddings: ${.encoder_seq_length} + position_embedding_type: rope + num_layers: 32 + hidden_size: 4096 + ffn_hidden_size: 14336 # Transformer FFN hidden size. 
Usually 4 * hidden_size. + num_attention_heads: 32 + init_method_std: 0.014 # Standard deviation of the zero mean normal distribution used for weight initialization.') + use_scaled_init_method: True # use scaled residuals initialization + hidden_dropout: 0.0 # Dropout probability for hidden state transformer. + attention_dropout: 0.0 # Dropout probability for attention + ffn_dropout: 0.0 # Dropout probability in the feed-forward layer. + kv_channels: null # Projection weights dimension in multi-head attention. Set to hidden_size // num_attention_heads if null + apply_query_key_layer_scaling: True # scale Q * K^T by 1 / layer-number. + normalization: 'rmsnorm' # Normalization layer to use. Options are 'layernorm', 'rmsnorm' + layernorm_epsilon: 1e-5 + do_layer_norm_weight_decay: False # True means weight decay on all params + make_vocab_size_divisible_by: 16 # Pad the vocab size to be divisible by this value for computation efficiency. + pre_process: True # add embedding + post_process: True # add pooler + persist_layer_norm: True # Use of persistent fused layer norm kernel. + bias: False # Whether to use bias terms in all weight matrices. + activation: 'fast-swiglu' # Options ['gelu', 'geglu', 'swiglu', 'reglu', 'squared-relu', 'fast-geglu', 'fast-swiglu', 'fast-reglu'] + headscale: False # Whether to learn extra parameters that scale the output of the each self-attention head. + transformer_block_type: 'pre_ln' # Options ['pre_ln', 'post_ln', 'normformer'] + normalize_attention_scores: True # Whether to scale the output Q * K^T by 1 / sqrt(hidden_size_per_head). This arg is provided as a configuration option mostly for compatibility with models that have been weight-converted from HF. You almost always want to se this to True. + rotary_percentage: 1.0 # If using position_embedding_type=rope, then the per head dim is multiplied by this. + rotary_base: 500000.0 # default is 10000 + attention_type: 'multihead' # Attention type. Options ['multihead'] + share_embeddings_and_output_weights: False # Share embedding and output layer weights. + overlap_p2p_comm: False # Overlap p2p communication with computes. This argument is valid only when `virtual_pipeline_model_parallel_size` is larger than 1 + batch_p2p_comm: True # Batch consecutive inter-peer send/recv operations. This argument is valid only when `virtual_pipeline_model_parallel_size` is larger than 1 + seq_len_interpolation_factor: null # RoPE Interpolation factor for sequence length. This is used to build long-context models with RoPE ex: https://arxiv.org/abs/2306.15595. + num_query_groups: 8 # Number of query groups for group query attention. If None, normal attention is used. + + ## Activation Checkpointing + activations_checkpoint_granularity: null # 'selective' or 'full' + activations_checkpoint_method: null # 'uniform', 'block', not used with 'selective' + activations_checkpoint_num_layers: null # not used with 'selective' + num_micro_batches_with_partial_activation_checkpoints: null + activations_checkpoint_layers_per_pipeline: null + sequence_parallel: False + + # precision + native_amp_init_scale: 4294967296 # 2 ** 32 + native_amp_growth_interval: 1000 + hysteresis: 2 # Gradient scale hysteresis + fp32_residual_connection: False # Move residual connections to fp32 + fp16_lm_cross_entropy: False # Move the cross entropy unreduced loss calculation for lm head to fp16 + + # model fusions + masked_softmax_fusion: True # Use a kernel that fuses the attention softmax with it's mask. 
+ bias_dropout_add_fusion: False # Use a kernel that fuses the bias addition, dropout and residual connection addition. + + use_cpu_initialization: False # Init weights on the CPU (slow for large models) + onnx_safe: False # Use work-arounds for known problems with Torch ONNX exporter. + gradient_accumulation_fusion: False # Fuse weight gradient accumulation to GEMMs. Only used with pipeline parallelism. + openai_gelu: False + bias_activation_fusion: False + megatron_legacy: False + + transformer_engine: True + fp8: False # enables fp8 in TransformerLayer forward + fp8_e4m3: False # sets fp8_format = recipe.Format.E4M3 + fp8_hybrid: False # sets fp8_format = recipe.Format.HYBRID + fp8_margin: 0 # scaling margin + fp8_interval: 1 # scaling update interval + fp8_amax_history_len: 1 # Number of steps for which amax history is recorded per tensor + fp8_amax_compute_algo: most_recent # 'most_recent' or 'max'. Algorithm for computing amax from history + use_emha: False # Use fused multi-head attention for large sequence-length. Note this is not yet supported. Please set to False. + + # Megatron O2-style half-precision + megatron_amp_O2: False # Enable O2-level automatic mixed precision using main parameters + async_grad_allreduce: False + grad_allreduce_chunk_size_mb: 125 + grad_div_ar_fusion: True # Fuse grad division into torch.distributed.all_reduce + + # miscellaneous + seed: 1234 + resume_from_checkpoint: null # manually set the checkpoint file to load from + apex_transformer_log_level: 30 # Python logging level displays logs with severity greater than or equal to this + gradient_as_bucket_view: True # PyTorch DDP argument. Allocate gradients in a contiguous bucket to save memory (less fragmentation and buffer memory) + + tokenizer: + library: 'huggingface' + type: /ws/converted_models/tokenizer # set huggingface tokenizer here; And check `LITA Tutorial.ipynb` for how to add time tokens to tokenizer + model: null # set sentencepiece model path here if tokenizer is sentencepiece + vocab_file: null + merge_file: null + delimiter: null # only used for tabular tokenizer + sentencepiece_legacy: False # Legacy=True allows you to add special tokens to sentencepiece tokenizers. + additional_special_tokens: null # ["", "", "", "", "", ""] + + data: + packed_sequence: False + num_workers: 8 + dataloader_type: cyclic + data_path: null + lazy_preprocess: True + is_multimodal: True + media_type: video # currently supported: image or video + splice_single_frame: null # 'first', 'middle', 'last' will represent video as first / middle / last frame only, all other frames discarded. + num_frames: 256 # selects the number of frames to use from the video + sep_token_between_frames: False # TODO: allow usage of separator tokens between frames + sep_image_conv_front: False + image_token_len: 784 # 28x28 + conv_template: llama_3 # check `nemo/collections/multimodal/data/neva/conversation.py` + image_folder: null + video_folder: null + image_aspect_ratio: 'pad' # in vila, it's `resize` + + # Nsys profiling options + nsys_profile: + enabled: False + start_step: 10 # Global batch to start profiling + end_step: 10 # Global batch to end profiling + ranks: [ 0 ] # Global rank IDs to profile + gen_shape: False # Generate model and kernel details including input shapes + + optim: + name: fused_adam + lr: 2e-5 + weight_decay: 0. 
+ betas: + - 0.9 + - 0.95 + sched: + name: CosineAnnealing + warmup_steps: 140 + constant_steps: 0 + min_lr: 2e-7 diff --git a/examples/multimodal/multimodal_llm/neva/convert_hf_llava_to_neva.py b/examples/multimodal/multimodal_llm/neva/convert_llava_to_neva.py similarity index 73% rename from examples/multimodal/multimodal_llm/neva/convert_hf_llava_to_neva.py rename to examples/multimodal/multimodal_llm/neva/convert_llava_to_neva.py index 2cbb4c2b3b82..d02b737c750a 100644 --- a/examples/multimodal/multimodal_llm/neva/convert_hf_llava_to_neva.py +++ b/examples/multimodal/multimodal_llm/neva/convert_llava_to_neva.py @@ -13,15 +13,22 @@ # limitations under the License. r""" -Script to convert HuggingFace LLaVA checkpoints into .nemo file. - Example to run this conversion script: - python convert_hf_llava_to_neva.py \ - --in-file \ - --out-file \ - --tokenizer-model \ - --conv-template llama_2 # nvgpt, llama_2, v1 (vicuna) +Script to convert LLaVA checkpoints into a .nemo file. +This script depends on the LLaVA GitHub project: +https://github.com/haotian-liu/LLaVA/tree/main + +If you want to convert a HuggingFace LLaVA checkpoint such as llava-hf/llava-1.5-7b-hf, +use `NeMo/scripts/checkpoint_converters/convert_llava_hf_to_nemo.py` instead. + +Example to run this conversion script: + python convert_llava_to_neva.py \ + --in-file \ + --out-file \ + --tokenizer-model \ + --conv-template llama_2 # nvgpt, llama_2, v1 (vicuna), llama_3 """ +import json import os from argparse import ArgumentParser from collections import OrderedDict @@ -31,6 +38,7 @@ from omegaconf import OmegaConf from pytorch_lightning.core.saving import _load_state as ptl_load_state from pytorch_lightning.trainer.trainer import Trainer +from safetensors import safe_open from transformers import LlamaTokenizer from nemo.collections.multimodal.models.multimodal_llm.neva.neva_model import MegatronNevaModel @@ -47,7 +55,11 @@ def get_args(): parser = ArgumentParser() parser.add_argument( - "--in-file", type=str, default=None, required=True, help="Path to Huggingface LLaMA checkpoints", + "--in-file", + type=str, + default=None, + required=True, + help="Path to LLaVA checkpoints", ) parser.add_argument("--out-file", type=str, default=None, required=True, help="Path to output .nemo file.") parser.add_argument( @@ -61,6 +73,16 @@ def get_args(): "--tokenizer-model", type=str, default=None, required=False, help="Path to sentencepiece tokenizer model."
) parser.add_argument("--precision", type=str, default="32", help="Model precision") + parser.add_argument("--config-file", type=str, default="llava_config.yaml") + parser.add_argument( + "--mm-projector-ckpt-dir", + type=str, + default=None, + help="Path to multimodal projector checkpoint directory \ + This will overlap the projector weights in in-file hf checkpoint", + ) + parser.add_argument("--mm-vision-tower", type=str, default=None) + parser.add_argument("--model-type", type=str, default=None) args = parser.parse_args() return args @@ -110,13 +132,32 @@ def load_model(cls, checkpoint, strict, **kwargs): def load_config(args, llava_config): - nemo_config = OmegaConf.load(os.path.join(os.path.dirname(__file__), 'conf/llava_config.yaml')).model + nemo_config = OmegaConf.load(os.path.join(os.path.dirname(__file__), 'conf', args.config_file)).model nemo_config.mm_cfg.mm_mlp_adapter_type = llava_config.get('mm_projector_type', 'linear') - nemo_config.mm_cfg.vision_encoder.from_pretrained = llava_config.get( - 'mm_vision_tower', 'openai/clip-vit-large-patch14' - ) - if '336' in nemo_config.mm_cfg.vision_encoder.from_pretrained: - nemo_config.data.image_token_len = 576 + + mm_vision_tower = llava_config.get('mm_vision_tower', 'openai/clip-vit-large-patch14') + + if args.mm_vision_tower is not None: + mm_vision_tower = args.mm_vision_tower + + nemo_config.mm_cfg.vision_encoder.from_pretrained = mm_vision_tower + if args.mm_vision_tower is not None: + config_file = os.path.join(args.mm_vision_tower, "config.json") + if os.path.exists(config_file): + with open(config_file, "r") as f: + vision_model_config = json.load(f) + nemo_config.mm_cfg.vision_encoder["model_type"] = vision_model_config.get("model_type", 'clip') + crop_size = vision_model_config.get("image_size", 224) + nemo_config.mm_cfg.vision_encoder.crop_size = [crop_size, crop_size] + else: + if '336' in mm_vision_tower: + nemo_config.data.image_token_len = 576 + nemo_config.mm_cfg.vision_encoder.crop_size = [336, 336] + else: + nemo_config.data.image_token_len = 256 + nemo_config.mm_cfg.vision_encoder.crop_size = [224, 224] + nemo_config.mm_cfg.vision_encoder.patch_dim = 14 + nemo_config.encoder_seq_length = llava_config['max_position_embeddings'] nemo_config.num_layers = int(llava_config['num_hidden_layers']) nemo_config.hidden_size = llava_config['hidden_size'] @@ -130,16 +171,34 @@ def load_config(args, llava_config): nemo_config.use_cpu_initialization = True nemo_config.activation = 'fast-swiglu' nemo_config.data.conv_template = args.conv_template - nemo_config.mm_cfg.model_type = args.conv_template + nemo_config.data.image_aspect_ratio = llava_config.get('image_aspect_ratio', 'square') + if args.model_type is None: + nemo_config.mm_cfg.model_type = args.conv_template + else: + nemo_config.mm_cfg.model_type = args.model_type if args.tokenizer_model is None: - nemo_config.tokenizer.model = llava_config['tokenizer_model'] + if 'tokenizer_model' in llava_config: + nemo_config.tokenizer.library = 'sentencepiece' + nemo_config.tokenizer.model = llava_config['tokenizer_model'] + else: + # Llama3 uses converted TikToken Tokenizer + tokenizer_dict = {'library': 'huggingface', 'type': args.in_file, 'use_fast': True, 'model': None} + nemo_config.tokenizer.update(tokenizer_dict) else: - nemo_config.tokenizer.model = args.tokenizer_model + # if tokenizer_model is directory + if os.path.isdir(args.tokenizer_model): + tokenizer_dict = {'library': 'huggingface', 'type': args.tokenizer_model, 'use_fast': True, 'model': None} + 
nemo_config.tokenizer.update(tokenizer_dict) + else: + nemo_config.tokenizer.library = 'sentencepiece' + nemo_config.tokenizer.model = args.tokenizer_model if llava_config['rope_scaling'] is not None: if llava_config['rope_scaling']['type'] == 'linear': nemo_config['seq_len_interpolation_factor'] = llava_config['rope_scaling']['factor'] else: raise ValueError("Only linear rope scaling type is supported now") + if llava_config.get('rope_theta', None): + nemo_config['rotary_base'] = llava_config['rope_theta'] base = 128 while llava_config['vocab_size'] % base != 0: @@ -152,16 +211,15 @@ def load_config(args, llava_config): def convert(args): logging.info(f"loading checkpoint {args.in_file}") model = LlavaLlamaForCausalLM.from_pretrained(args.in_file) - tokenizer = LlamaTokenizer.from_pretrained(args.in_file) hf_config = vars(model.config) - hf_config['tokenizer_model'] = str(tokenizer.vocab_file) - print(f"hf_config: {hf_config}") - print("named parameters:") + if os.path.exists(f'{args.in_file}/tokenizer.model'): + tokenizer = LlamaTokenizer.from_pretrained(args.in_file) + hf_config['tokenizer_model'] = str(tokenizer.vocab_file) + for name, param in model.named_parameters(): print(f"- {name}") nemo_config = load_config(args, hf_config) - print(nemo_config) if args.precision in ["32", "16"]: precision = int(float(args.precision)) @@ -179,7 +237,7 @@ def convert(args): scaler = None if precision in [16, '16', '16-mixed']: scaler = GradScaler( - init_scale=nemo_config.get('native_amp_init_scale', 2 ** 32), + init_scale=nemo_config.get('native_amp_init_scale', 2**32), growth_interval=nemo_config.get('native_amp_growth_interval', 1000), hysteresis=nemo_config.get('hysteresis', 2), ) @@ -235,10 +293,42 @@ def convert(args): for key in model.state_dict(): if 'mm_projector' in key: mm_projection_layer_suffix = key.split('mm_projector')[1] - checkpoint['state_dict'][ - f'{mm_projection_layer_base_name}{mm_projection_layer_suffix}' - ] = param_to_weights(model.state_dict()[key]) + checkpoint['state_dict'][f'{mm_projection_layer_base_name}{mm_projection_layer_suffix}'] = ( + param_to_weights(model.state_dict()[key]) + ) + # Replace or add the projection weights + proj_ckpt = None + if args.mm_projector_ckpt_dir is not None: + if os.path.exists(args.mm_projector_ckpt_dir): + ckpt_path = os.path.join(args.mm_projector_ckpt_dir, "mm_projector.bin") + if os.path.exists(ckpt_path): + proj_ckpt = torch.load(ckpt_path) + else: + ckpt_path = os.path.join(args.mm_projector_ckpt_dir, "model.safetensors") + proj_ckpt = {} + with safe_open(ckpt_path, framework="pt", device="cuda") as f: + for key in f.keys(): + new_key = key.replace("layers.", "mm_projector.") + proj_ckpt[new_key] = f.get_tensor(key) + else: + raise FileNotFoundError(f"mm_projector_ckpt_dir {args.mm_projector_ckpt_dir} does not exist.") + for key in proj_ckpt.keys(): + if 'mm_projector' in key: + mm_projection_layer_suffix = key.split('mm_projector')[1] + checkpoint['state_dict'][f'{mm_projection_layer_base_name}{mm_projection_layer_suffix}'] = ( + param_to_weights(proj_ckpt[key]) + ) + + proj_conf_file = open(os.path.join(args.mm_projector_ckpt_dir, "config.json")) + + proj_conf = json.load(proj_conf_file) + if proj_conf['mm_projector_type'] != nemo_config.mm_cfg.mm_mlp_adapter_type: + logging.warning( + f"Overriding mm_projector_type from {nemo_config.mm_cfg.mm_mlp_adapter_type} to {proj_conf['mm_projector_type']}" + ) + nemo_config.mm_cfg.mm_mlp_adapter_type = proj_conf['mm_projector_type'] + proj_conf_file.close() embed_weight = 
model.state_dict()[f'model.embed_tokens.weight'] if mcore_gpt: embed_weights_base_name = f'model.embedding.word_embeddings.weight' diff --git a/examples/multimodal/multimodal_llm/neva/eval/eval_video_rtl.py b/examples/multimodal/multimodal_llm/neva/eval/eval_video_rtl.py new file mode 100644 index 000000000000..3567cf431d87 --- /dev/null +++ b/examples/multimodal/multimodal_llm/neva/eval/eval_video_rtl.py @@ -0,0 +1,196 @@ +# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +""" +This script is used for evaluating RTL (Reasoning Temporal Localization) task. +It accepts one JSON file. The JSON file should have the following structure: +[ + { + "video": "rY7eLyJF31M_6.mp4", + "question_id": "rY7eLyJF31M_6_0", + "question": "When is \"Apply mascara , false lashes on the lashes \" depicted in the video? Convey your answer using start and end timestamps exclusively.", + "ref_answer": "<0> <53> Apply mascara , false lashes on the lashes ", + "duration": 102.002002002002, + "pred_answer": "<1> <53> Apply mascara , false lashes on the lashes ", + }, + { + "video": "rY7eLyJF31M_6.mp4", + "question_id": "rY7eLyJF31M_6_1", + "question": "When is \"Apply foundation on the face with a brush\" depicted in the video? Provide a response using only start and end timestamps.", + "ref_answer": "<56> <97> Apply foundation on the face with a brush", + "duration": 102.002002002002, + "pred_answer": "<50> <97> Apply foundation on the face with a brush", + }, +] + +The `xxx_answer` field should contain the start and end timestamps such as `<56>` and `<97>` of the event along with the sentence. +If not, the [0, duration] will be used as the predicted timestamps. + +USAGE: +python eval_rtl.py --input_file \ + --output_dir \ + --save_mid_result +""" +import argparse +import json +import os +import re +from collections import defaultdict + + +def iou(seg1, seg2): + """Compute the intersection over union (IoU) between two segments. + + Args: + seg1 (list): [start, end] + seg2 (list): [start, end] + + Returns: + float: IoU value + """ + assert seg1[1] >= seg1[0] and seg2[1] >= seg2[0] + + x1 = max(seg1[0], seg2[0]) + x2 = min(seg1[1], seg2[1]) + inter = max(x2 - x1, 0) + + len1 = max(seg1[1] - seg1[0], 0) + len2 = max(seg2[1] - seg2[0], 0) + + union = len1 + len2 - inter + + if union == 0: + return 0.0 + else: + return inter / union + + +def precision_func(thres): + """calculate the precision based on the threshold. + If the IoU value is greater than or equal to the threshold, \ + the precision is 1.0, otherwise 0.0. + + Args: + thres (float): threshold value [0.0, 1.0] + """ + + def precision(seg1, seg2): + return float(iou(seg1, seg2) >= thres) + + return precision + + +def parse_start_end_timestamps(outputs, duration, strict=False): + timestamp_pattern = '\<(?: (?: \d* \.? \d+ ) | (?: \d+ \.? 
) )\>' + rx = re.compile(timestamp_pattern, re.VERBOSE) + matches = list(rx.finditer(outputs)) + if strict: + assert len(list(matches)) >= 2, "cannot find timestamps" + elif len(list(matches)) < 2: + return outputs, [0, duration] + + prev_end = 0 + sentence = "" + timestamps = [] + for i in range(2): + m = matches[i] + start = m.start(0) + end = m.end(0) + timestamp = float(m.group(0)[1:-1]) + timestamp = min(max(timestamp, 0), duration) + timestamps.append(timestamp) + sentence += outputs[prev_end:start] + prev_end = end + sentence += outputs[prev_end:] + sentence = sentence.strip() + + return sentence, [min(timestamps), max(timestamps)] + + +def eval(pred_file, output_dir, save_mid_result=True): + """Evaluate the predictions against the ground truth. + + Args: + pred_file (str): path to the predictions JSON file + output_dir (str): path to the output directory, + where the `answers.json` and `metrics.json` result will be saved. + """ + metric_func = {'iou': iou, 'precision@0.5': precision_func(0.5)} + metrics = {} + for metric in metric_func: + metrics[metric] = defaultdict(list) + + with open(pred_file, 'r') as f: + pred_data = json.load(f) + + out_list = [] + for pred in pred_data: + assert "pred_answer" in pred, "pred_answer field is missing" + assert "ref_answer" in pred, "answer field is missing" + duration = pred['duration'] + pred_answer, pred_timestamps = parse_start_end_timestamps(pred['pred_answer'], duration, strict=False) + ref_answer, ref_timestamps = parse_start_end_timestamps(pred['ref_answer'], duration, strict=False) + + for metric in metric_func: + metrics[metric][pred['video']].append(metric_func[metric](pred_timestamps, ref_timestamps)) + + out_list.append( + { + 'video': pred['video'], + 'question_id': pred['question_id'], + 'question': pred['question'], + 'pred_answer': pred_answer, + 'ref_answer': ref_answer, + 'pred_timestamps': pred_timestamps, + 'ref_timestamps': ref_timestamps, + } + ) + # save result + os.makedirs(output_dir, exist_ok=True) + if save_mid_result: + output_file = os.path.join(output_dir, 'answers.json') + print(f"Saving intermediate result to {output_file}") + with open(output_file, 'w') as f: + json.dump(out_list, f, indent=2) + + final_result = {} + for metric in metrics: + values = [] + for vid in metrics[metric]: + # get single video metric value + cur_metric_values = metrics[metric][vid] + values.append(sum(cur_metric_values) / len(cur_metric_values)) + # get global average video metric value + values = sum(values) / len(values) + final_result[metric] = values + + print(final_result) + output_file = os.path.join(output_dir, 'metrics.json') + with open(output_file, 'w') as f: + json.dump(final_result, f, indent=2) + + +def main(): + parser = argparse.ArgumentParser(description="Evaluate the predictions against the ground truth") + parser.add_argument("--input_file", help="Path to the input JSON file", required=True) + parser.add_argument("--output_dir", help="Path to the output directory", required=True) + parser.add_argument("--save_mid_result", action="store_true", help="Save intermediate result") + args = parser.parse_args() + + eval(args.input_file, args.output_dir, args.save_mid_result) + + +if __name__ == "__main__": + main() diff --git a/examples/multimodal/multimodal_llm/neva/eval/eval_vqa.py b/examples/multimodal/multimodal_llm/neva/eval/eval_vqa.py new file mode 100644 index 000000000000..8929648a3f97 --- /dev/null +++ b/examples/multimodal/multimodal_llm/neva/eval/eval_vqa.py @@ -0,0 +1,207 @@ +# Copyright (c) 2024, NVIDIA 
CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +""" +This script is used for evaluating the Video Question Answering task by leveraging an LLM API as the judge. +It accepts one JSON file. The JSON file should have the following structure: +[ + { + "video": "YRvBOLRgZNc_2.mp4", + "question_id": "v_yVgL8sJQxYo_2_5", + "question": "What tools are used to apply foundation on the skin between <5s> and <60s>?", + "ref_answer": "A brush and blender.", + "duration": 102.002002002002, + "pred_answer": "A brush", + }, + { + "video": "yVgL8sJQxYo_2.mp4", # optional field + "question": "How long does the action of applying foundation take?", + "question_id": "v_yVgL8sJQxYo_2_5", + "ref_answer": "The action takes around 55 seconds (<60s> - <5s>).", + "duration": 102.002002002002, # optional field + "pred_answer": "This action takes around 50 seconds.", + } + + ... +] + +`video` and `duration` are two optional fields. If not provided, the script will ignore them. + +Note that a time token is represented as '<{}s>'.format(time_in_seconds), e.g. <5s>. + +For the external LLM API, we use `meta/llama3-70b-instruct` as an example. +You can go to https://build.nvidia.com/explore/discover to choose the one that fits your needs. +Note that the API for other models might differ slightly. + +You also need an `API_TOKEN` from here: https://build.nvidia.com/explore/discover#llama3-70b +Click `Get API Key` and save the key in the environment variable `API_TOKEN`. + +USAGE: +API_TOKEN= python eval_vqa.py --input_file --output_dir --save_mid_result +""" + +import argparse +import ast +import json +import os +import re + +import requests + + +def parse_args(): + parser = argparse.ArgumentParser(description="Evaluate Video Question Answering task.") + parser.add_argument("--input_file", type=str, required=True, help="Path to the prediction file.
json list file") + parser.add_argument("--output_dir", type=str, required=True, help="Path to the output directory.") + parser.add_argument("--save_mid_result", action="store_true", help="Whether to save the intermediate results.") + return parser.parse_args() + + +INVOKE_URL = "https://integrate.api.nvidia.com/v1/chat/completions" +# MODEL="mistralai/mixtral-8x22b-instruct-v0.1" # no `system` role +MODEL = "meta/llama3-70b-instruct" + + +def request_nvidia_api(messages): + API_TOKEN = os.getenv("API_TOKEN", "") # ADD NGC API TOKEN HERE + if not API_TOKEN: + raise ValueError("Please provide the API_TOKEN in the environment variable.") + headers = { + "Authorization": f"Bearer {API_TOKEN}", + "accept": "text/event-stream", + "content-type": "application/json", + } + payload = { + "model": MODEL, + "messages": messages, + "temperature": 0.5, + "top_p": 1.0, + "max_tokens": 2048, + "seed": 42, + "stream": True, + } + invoke_url = INVOKE_URL + response = requests.post(invoke_url, headers=headers, json=payload, stream=True) + output = "" + for line in response.iter_lines(): + if line == b'data: [DONE]': + break + if line: + res = json.loads(line.decode("utf-8").split("data: ")[1]) + if 'content' in res['choices'][0]['delta']: + output += res['choices'][0]['delta']['content'] + return output.lstrip().strip() + + +def convert_time_token(text): + # use regular expression to convert <12> <56> to <12s> <56s> + return re.sub(r'<(\d+)>', r'<\1s>', text) + + +def get_result(question, answer, pred, key, output_dir, save_mid_result=False): + messages = [ + { + "role": "system", + "content": "You are an intelligent chatbot designed for evaluating the correctness of generative outputs for question-answer pairs. " + "Your task is to compare the predicted answer with the correct answer and determine if they match meaningfully. Here's how you can accomplish the task:" + "------" + "##INSTRUCTIONS: " + "- Focus on the meaningful match between the predicted answer and the correct answer.\n" + "- Consider synonyms or paraphrases as valid matches.\n" + "- Evaluate the correctness of the prediction compared to the answer.", + }, + { + "role": "user", + "content": "Please evaluate the following video-based question-answer pair:\n\n" + f"Question: {question}\n" + f"Correct Answer: {answer}\n" + f"Predicted Answer: {pred}\n\n" + "Provide your evaluation only as a yes/no and score where the score is an integer value between 0 and 5, with 5 indicating the highest meaningful match. " + "Please generate the response in the form of a Python dictionary string with keys 'pred' and 'score', where value of 'pred' is a string of 'yes' or 'no' and value of 'score' is in INTEGER, not STRING." + "DO NOT PROVIDE ANY OTHER OUTPUT TEXT OR EXPLANATION. Only provide the Python dictionary string. 
" + "For example, your response should look like this: {'pred': 'yes', 'score': 4.8}.", + }, + ] + try: + response_message = request_nvidia_api(messages) + response_dict = ast.literal_eval(response_message) + except Exception as e: + print(f"Error processing file {key}: {e}") + return [] + qa_set = {"question": question, "ref_answer": answer, "pred_answer": pred} + result_qa_pair = [response_dict, qa_set] + if save_mid_result: + with open(f"{output_dir}/{key}.json", "w") as f: + json.dump(result_qa_pair, f) + return result_qa_pair + + +def main(): + args = parse_args() + input_file = args.input_file + output_dir = args.output_dir + save_mid_result = args.save_mid_result + with open(input_file, "r") as f: + data = json.load(f) + + tasks = [] + key = 0 + for item in data: + question = item["question"] + item["ref_answer"] = convert_time_token(item["ref_answer"]) + tasks.append((question, item["ref_answer"], item["pred_answer"], key, output_dir, save_mid_result)) + key += 1 + + # TODO: parallelize the requests + results = [] + while len(tasks) > 0: + task = tasks.pop() + key = task[3] + cur_result = get_result(*task) + if cur_result == []: + tasks.append(task) + continue + results.append((key, cur_result)) + + score_sum = count = yes_count = no_count = 0 + for key, result in results: + try: + count += 1 + score_sum += int(result[0]["score"]) + + if "yes" in result[0]["pred"].lower(): + yes_count += 1 + elif "no" in result[0]["pred"].lower(): + no_count += 1 + except Exception as e: + print(f"Error processing file {key}") + + average_score = score_sum / count + accuracy = yes_count / (yes_count + no_count) + result_file = os.path.join(output_dir, "metrics.json") + metrics = { + "average_score": average_score, + "accuracy": accuracy, + "no_count": no_count, + "yes_count": yes_count, + "model": MODEL, + } + print("Metrics: ", metrics) + with open(result_file, "w") as f: + json.dump(metrics, f, indent=2) + + +if __name__ == "__main__": + main() diff --git a/examples/multimodal/multimodal_llm/neva/neva_evaluation.py b/examples/multimodal/multimodal_llm/neva/neva_evaluation.py index dcc79029463c..75d8a907b796 100644 --- a/examples/multimodal/multimodal_llm/neva/neva_evaluation.py +++ b/examples/multimodal/multimodal_llm/neva/neva_evaluation.py @@ -15,7 +15,7 @@ import json import os import torch -from torch.utils.data import Dataset +from torch.utils.data import DataLoader, Dataset from nemo.collections.multimodal.parts.utils import create_neva_model_and_processor from nemo.collections.nlp.modules.common.transformer.text_generation import LengthParam, SamplingParam @@ -36,24 +36,109 @@ raise EnvironmentError("GPU is needed for the inference") -class RequestDataSet(Dataset): - def __init__(self, sentences): - super().__init__() - self.sentences = sentences - - def __len__( +class TemporalNevaDataset(Dataset): + def __init__( self, + prompt_dicts, + media_base_path, + media_token, + insert_media_token=None, + image_processor=None, + video_processor=None, + add_media_sep=False, ): - return len(self.sentences) + self.prompt_dicts = prompt_dicts + self.media_token = media_token + self.insert_media_token = insert_media_token + self.media_base_path = media_base_path + self.image_processor = image_processor + self.video_processor = video_processor + self.add_media_sep = add_media_sep + # [(media_name, [prompt_dict, prompt_dict, ...]), ...} + self.media_prompt_list = [] + self.group_by_media(media_token) + + def group_by_media(self, media_token): + """ + This function groups the prompt dicts by the 
media/video/image file name + """ + media_dict = {} + media = media_token.lstrip('<').rstrip('>') + for prompt_dict in self.prompt_dicts: + media_name = prompt_dict[media] # video or image file name + if media_name not in media_dict: + media_dict[media_name] = [] + media_dict[media_name].append(prompt_dict) + self.media_prompt_list = list(media_dict.items()) + + def __len__(self) -> int: + return len(self.media_prompt_list) + + def __getitem__(self, idx) -> dict: + """ + Return a list of prompt dicts for the idx-th media + For a single media file, only one media feature is returned + This would help improve performance as well as save GPU memory + """ + prompt_dict_list = self.media_prompt_list[idx][1] + cur_item = [] + cur_media_feature = None + for prompt_dict in prompt_dict_list: + if 'prompt' not in prompt_dict: + prompt_dict['prompt'] = prompt_dict['text'] if 'text' in prompt_dict else prompt_dict['question'] + if self.insert_media_token == 'left': + if self.add_media_sep: + prompt_dict['prompt'] = self.media_token + " \n" + prompt_dict['prompt'] + else: + prompt_dict['prompt'] = self.media_token + prompt_dict['prompt'] + elif self.insert_media_token == 'right': + if self.add_media_sep: + prompt_dict['prompt'] = prompt_dict['prompt'] + self.media_token + " \n" + else: + prompt_dict['prompt'] = prompt_dict['prompt'] + self.media_token + if 'image' in prompt_dict: + prompt_dict['image_path'] = prompt_dict['image'] + image_path = os.path.join(self.media_base_path, prompt_dict['image']) + if cur_media_feature is None: + cur_media_feature = ("image", self.image_processor(image_path)) + if 'video' in prompt_dict: + prompt_dict['video_path'] = prompt_dict['video'] + video_path = os.path.join(self.media_base_path, prompt_dict['video']) + if cur_media_feature is None: + cur_media_feature = ("video", self.video_processor(video_path)) + cur_item.append(prompt_dict) + return cur_media_feature, cur_item + - def __getitem__(self, idx): - return self.sentences[idx] +def collate_function(batch): + # do nothing + return batch + + +def do_inference(dataloader, model, length_params, sampling_params, cfg): + responses = [] + all_prompts = [] + for idx, batch_media_prompts in enumerate(dataloader): + if idx % 10 == 0: + print(f"Processed {idx} batch media") + for media_media_feature, prompts in batch_media_prompts: + media, media_feature = media_media_feature + all_prompts.extend(prompts.copy()) + for prompt in prompts: + prompt[media] = media_feature + cur_batch_responses = model.generate( + input_prompts=prompts, + length_params=length_params, + sampling_params=sampling_params, + inference_config=cfg, + ) + responses.extend(cur_batch_responses) + return responses, all_prompts @hydra_runner(config_path="conf", config_name="neva_inference") def main(cfg) -> None: model, image_processor, video_processor = create_neva_model_and_processor(cfg) - length_params: LengthParam = { "max_length": cfg.inference.tokens_to_generate, "min_length": cfg.inference.min_tokens_to_generate, @@ -71,35 +156,43 @@ def main(cfg) -> None: "end_strings": cfg.inference.end_strings, } - with open(cfg.prompt_file, 'r') as f: - lines = f.readlines() + prompt_dicts = [] + if cfg.prompt_file.endswith('.json'): + with open(cfg.prompt_file, 'r') as f: + prompt_dicts = json.load(f) + elif cfg.prompt_file.endswith('.jsonl'): + with open(cfg.prompt_file, 'r') as f: + lines = f.readlines() + for line in lines: + prompt_dicts.append(json.loads(line)) + else: + raise ValueError(f"Unsupported prompt file format: {cfg.prompt_file}") 
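+ # Each prompt dict is expected to provide a 'prompt' (or 'text'/'question') field plus an 'image' or 'video' file name; TemporalNevaDataset below groups the dicts by that media file so each image/video is decoded and preprocessed only once.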
media_type_token = cfg.inference.get("media_type", "image") media_token = f"<{media_type_token}>" insert_media_token = cfg.inference.get("insert_media_token", None) - final_prompts = [] - for line in lines: - prompt_dict = json.loads(line) - assert 'prompt' in prompt_dict or 'text' in prompt_dict - if 'prompt' not in prompt_dict: - prompt_dict['prompt'] = prompt_dict['text'] - if insert_media_token == 'left': - prompt_dict['prompt'] = media_token + prompt_dict['prompt'] - elif insert_media_token == 'right': - prompt_dict['prompt'] = prompt_dict['prompt'] + media_token - if 'image' in prompt_dict: - prompt_dict['image_path'] = prompt_dict['image'] - prompt_dict['image'] = image_processor(os.path.join(cfg.inference.media_base_path, prompt_dict['image'])) - if 'video' in prompt_dict: - prompt_dict['video_path'] = prompt_dict['video'] - prompt_dict['video'] = video_processor(os.path.join(cfg.inference.media_base_path, prompt_dict['video'])) - final_prompts.append(prompt_dict) - - responses = model.generate( - input_prompts=final_prompts, length_params=length_params, sampling_params=sampling_params, inference_config=cfg + dataset = TemporalNevaDataset( + prompt_dicts, + cfg.inference.media_base_path, + media_token, + insert_media_token, + image_processor, + video_processor, + cfg.get("add_media_sep", False), ) + num_workers = 2 + dataloader = DataLoader( + dataset, + batch_size=cfg.inference.get("batch_size", 1), + shuffle=False, + collate_fn=collate_function, + num_workers=num_workers, + persistent_workers=True, + ) + responses, final_prompts = do_inference(dataloader, model, length_params, sampling_params, cfg) + # =================== Start Quantization ==================== if HAVE_MODELOPT and cfg.quantization.enable == True: print(f"Using quantization algorithm: {cfg.quantization.algorithm}") @@ -113,21 +206,33 @@ def main(cfg) -> None: raise ValueError(f"Unsupported quantization algorithm: {cfg.quantization.algorithm}") def forward_loop(): - model.generate( - input_prompts=final_prompts, - length_params=length_params, - sampling_params=sampling_params, - inference_config=cfg, + num_samples = cfg.quantization.get("num_samples", 100) + if num_samples == -1: + cur_prompt_dicts = prompt_dicts + else: + cur_prompt_dicts = prompt_dicts[:num_samples] + cur_dataset = TemporalNevaDataset( + cur_prompt_dicts, + cfg.inference.media_base_path, + media_token, + insert_media_token, + image_processor, + video_processor, + cfg.get("add_media_sep", False), ) + cur_dataloader = DataLoader( + cur_dataset, + batch_size=cfg.inference.get("batch_size", 1), + shuffle=False, + collate_fn=collate_function, + num_workers=num_workers, + ) + _, _ = do_inference(cur_dataloader, model, length_params, sampling_params, cfg) mtq.quantize(model, mtq_config, forward_loop) - responses = model.generate( - input_prompts=final_prompts, - length_params=length_params, - sampling_params=sampling_params, - inference_config=cfg, - ) + responses, final_prompts = do_inference(dataloader, model, length_params, sampling_params, cfg) + # ============== Quantization End ========================= # PP middle stages do not yield any responses @@ -138,7 +243,7 @@ def forward_loop(): results = [] for response, prompt in zip(responses, final_prompts): prompt['full_text'] = response["clean_text"] - prompt['text'] = response["clean_response"] + prompt['pred_answer'] = response["clean_response"] prompt['model_id'] = cfg.neva_model_file if 'image_path' in prompt: prompt['image'] = prompt.pop('image_path') @@ -151,8 +256,11 @@ def forward_loop(): 
results.append(prompt) with open(cfg.output_file, 'w') as f: - for result in results: - f.write(json.dumps(result) + '\n') + if cfg.output_file.endswith('.json'): + json.dump(results, f, indent=2) + else: + for result in results: + f.write(json.dumps(result) + '\n') if __name__ == '__main__': diff --git a/nemo/collections/multimodal/data/neva/conversation.py b/nemo/collections/multimodal/data/neva/conversation.py index 10a6c9e7283d..2e110eebe9e6 100644 --- a/nemo/collections/multimodal/data/neva/conversation.py +++ b/nemo/collections/multimodal/data/neva/conversation.py @@ -34,6 +34,10 @@ DEFAULT_IM_START_TOKEN["llama_3"] = "<|reserved_special_token_4|>" DEFAULT_IM_END_TOKEN["llama_3"] = "<|reserved_special_token_5|>" +DEFAULT_VID_START_TOKEN = "" +DEFAULT_VID_END_TOKEN = "" +TIME_TOKEN_TEMPLATE = "" + class SeparatorStyle(Enum): """Different separator style.""" diff --git a/nemo/collections/multimodal/data/neva/neva_dataset.py b/nemo/collections/multimodal/data/neva/neva_dataset.py index 7eef677e13a8..b56c42fff274 100644 --- a/nemo/collections/multimodal/data/neva/neva_dataset.py +++ b/nemo/collections/multimodal/data/neva/neva_dataset.py @@ -34,11 +34,15 @@ import nemo.collections.multimodal.data.neva.conversation as conversation_lib from nemo.collections.multimodal.data.clip.augmentations.augmentations import image_transform from nemo.collections.multimodal.data.neva.conversation import ( + DEFAULT_BOS_TOKEN, + DEFAULT_EOS_TOKEN, DEFAULT_IM_END_TOKEN, DEFAULT_IM_START_TOKEN, DEFAULT_IMAGE_PATCH_TOKEN, DEFAULT_IMAGE_TOKEN, DEFAULT_LABELS_TOKEN, + DEFAULT_VID_END_TOKEN, + DEFAULT_VID_START_TOKEN, DEFAULT_VIDEO_TOKEN, ) from nemo.collections.nlp.modules.common.megatron.utils import get_ltor_masks_and_position_ids @@ -145,7 +149,7 @@ def open_video(self, file_name): cap = decord.VideoReader(f) return self.flatten_frames(cap) else: - decord.bridge.set_bridge("torch") + # decord.bridge.set_bridge("torch") cap = decord.VideoReader(os.path.join(self.video_folder, file_name)) return self.flatten_frames(cap) return None @@ -171,9 +175,7 @@ def flatten_frames(self, cap): else: num_frames = min(len(cap), self.data_cfg['num_frames']) indices = np.linspace(0, len(cap) - 1, num_frames, dtype=int) - frames = [] - frames = cap.get_batch(indices) - + frames = [Image.fromarray(cap[i].asnumpy()).convert('RGB') for i in indices] while len(frames) < self.data_cfg['num_frames']: frames.append(frames[-1]) return frames @@ -226,6 +228,25 @@ def tokenize( return result +def get_tokens_ids(tokenizer, tokens): + """ + Returns the token id for a given token. + + Parameters + ---------- + tokenizer : nemo tokenizer + A tokenizer to be used for tokenization. + tokens : list + A list of tokens to get the token id for. + + Returns + ------- + List + The token ids. + """ + return [tokenizer.token_to_id(token) for token in tokens] + + def preprocess_multimodal(sources: dict, multimodal_cfg: dict, cur_token_len: int, use_plain: bool = False) -> Dict: """ Preprocesses multimodal sources based on the provided configuration. 
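# Worked example (hypothetical numbers) of the visual-token budget reserved by the
# preprocess_multimodal changes below for a LITA video sample: a 256-patch vision
# encoder output, an 'mlp_downsample' projector, num_frames=128 and
# lita.sample_frames=4 with the 'temporal_all_resolution' architecture.
image_token_len = 256 // 4              # mlp_downsample folds 2x2 patch groups -> 64
num_frames = 128                        # one temporal token per frame
sample_frames = 4                       # frames kept at full spatial resolution
num_temporal_tokens = num_frames
num_spatial_tokens = sample_frames * image_token_len
num_tokens = num_temporal_tokens + num_spatial_tokens
assert num_tokens == 128 + 4 * 64 == 384  # placeholder tokens inserted per video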
@@ -259,13 +280,15 @@ def preprocess_multimodal(sources: dict, multimodal_cfg: dict, cur_token_len: in if not is_multimodal: return sources - num_patches = image_token_len + num_frames = multimodal_cfg['num_frames'] + # vila + if multimodal_cfg['mm_mlp_adapter_type'] == 'mlp_downsample': + image_token_len //= 4 + num_patches = image_token_len + # TO DO: to support multiple images if media_type == 'video': - num_patches *= multimodal_cfg['num_frames'] - - if multimodal_cfg['mm_mlp_adapter_type'] == 'mlp_downsample': - num_patches //= 4 + num_patches *= num_frames if multimodal_cfg['use_im_start_end']: replace_token = DEFAULT_IMAGE_PATCH_TOKEN[model_type] * num_patches @@ -273,6 +296,44 @@ def preprocess_multimodal(sources: dict, multimodal_cfg: dict, cur_token_len: in replace_token = DEFAULT_IMAGE_PATCH_TOKEN[model_type] * (num_patches - 2) replace_token = DEFAULT_IM_START_TOKEN[model_type] + replace_token + DEFAULT_IM_END_TOKEN[model_type] + if media_type == 'video' and multimodal_cfg.get("use_lita", False): + if not multimodal_cfg.get('lita', None): + raise ValueError("LITA config is missing") + lita_video_arch = multimodal_cfg['lita']['lita_video_arch'] + num_temporal_tokens, num_spatial_tokens = num_frames, 0 + if lita_video_arch == 'temporal_all_resolution': + sample_frames = min(multimodal_cfg['lita']['sample_frames'], num_frames) + # num_frames for temporal tokens, sample_frames * num_patches for spatial tokens + num_spatial_tokens = sample_frames * image_token_len + else: + # num_frames for temporal tokens and num_patches for spatial tokens + num_spatial_tokens = image_token_len + num_tokens = num_temporal_tokens + num_spatial_tokens + + visual_token_format = multimodal_cfg['lita'].get('visual_token_format', 'v1') + media_start = DEFAULT_IM_START_TOKEN[model_type] + media_end = DEFAULT_IM_END_TOKEN[model_type] + image_patch = DEFAULT_IMAGE_PATCH_TOKEN[model_type] + if visual_token_format == 'im_vid_start_end': + image_start, image_end = DEFAULT_IM_START_TOKEN[model_type], DEFAULT_IM_END_TOKEN[model_type] + vid_start, vid_end = DEFAULT_VID_START_TOKEN, DEFAULT_VID_END_TOKEN + if multimodal_cfg['use_im_start_end']: + replace_token_list = [image_start + image_patch * image_token_len + image_end] * sample_frames + replace_token_list += [vid_start + image_patch * num_temporal_tokens + vid_end] + replace_token = "".join(replace_token_list) + else: + replace_token_list = [image_start + image_patch * (image_token_len - 1) + image_end] + replace_token_list += [image_start + image_patch * image_token_len + image_end] * (sample_frames - 1) + replace_token_list += [vid_start + image_patch * (num_temporal_tokens - 1) + vid_end] + replace_token = "".join(replace_token_list) + replace_token = media_start + replace_token + media_end + else: + if multimodal_cfg['use_im_start_end']: + replace_token = image_patch * num_tokens + else: + replace_token = image_patch * (num_tokens - 2) + replace_token = media_start + replace_token + media_end + for source in sources: conversation = source['conversations'] if multimodal_cfg['sep_image_conv_front']: @@ -290,7 +351,6 @@ def preprocess_multimodal(sources: dict, multimodal_cfg: dict, cur_token_len: in conversation[0]['value'] = default_token for turn in conversation: turn["value"] = turn["value"].replace(default_token, replace_token) - return sources @@ -475,9 +535,13 @@ def preprocess_llama_2( ) # llama tricks - tokens[tokens == 32003] = 0 # DEFAULT_IMAGE_PATCH_TOKEN - tokens[tokens == 32006] = 1 # - tokens[tokens == 32007] = 2 # + # 32003, 32006, 32007 + 
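+    # Look these special tokens up by name via the tokenizer, so the mapping does not
+    # depend on the hard-coded llama-2 ids above and keeps working when a different or
+    # extended tokenizer is used.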
image_patch_token = DEFAULT_IMAGE_PATCH_TOKEN["llama_2"] + DEFAULT_TOKENS = [image_patch_token, DEFAULT_BOS_TOKEN, DEFAULT_EOS_TOKEN] + img_patch_id, bos_id, eos_id = get_tokens_ids(tokenizer, DEFAULT_TOKENS) + tokens[tokens == img_patch_id] = 0 # DEFAULT_IMAGE_PATCH_TOKEN + tokens[tokens == bos_id] = 1 # + tokens[tokens == eos_id] = 2 # labels = tokens.clone().detach() # Mask labels @@ -577,9 +641,14 @@ def preprocess_v1( ) # llama tricks - tokens[tokens == 32003] = 0 # DEFAULT_IMAGE_PATCH_TOKEN - tokens[tokens == 32006] = 1 # - tokens[tokens == 32007] = 2 # + # 32003, 32006, 32007 + image_patch_token = DEFAULT_IMAGE_PATCH_TOKEN["llama_2"] + DEFAULT_TOKENS = [image_patch_token, DEFAULT_BOS_TOKEN, DEFAULT_EOS_TOKEN] + img_patch_id, bos_id, eos_id = get_tokens_ids(tokenizer, DEFAULT_TOKENS) + tokens[tokens == img_patch_id] = 0 # DEFAULT_IMAGE_PATCH_TOKEN + tokens[tokens == bos_id] = 1 # + tokens[tokens == eos_id] = 2 # + # tokens = torch.concat((torch.tensor([[1]]), tokens), axis=1) #lita 1.5 legacy labels = tokens.clone().detach() # Mask labels @@ -977,7 +1046,7 @@ def __getitem__(self, i) -> Dict[str, torch.Tensor]: frames = self.video_loader.open_video(video_file) if frames is None: logging.warning(f"Video {video_file} could not be found!") - if isinstance(self.processor, CLIPImageProcessor): + if isinstance(self.processor, CLIPImageProcessor) or isinstance(self.processor, SiglipImageProcessor): # image processor from HF if self.multimodal_cfg['image_aspect_ratio'] == 'keep': max_hw, min_hw = max(frames.size), min(frames.size) @@ -1268,6 +1337,8 @@ def make_supervised_data_module(tokenizer, image_processor, model_cfg) -> Dict: context_length=model_cfg.encoder_seq_length, media_type=data_cfg.get('media_type', 'image'), num_frames=data_cfg.get('num_frames', -1), + use_lita=getattr(model_cfg.mm_cfg, 'use_lita', False), + lita=getattr(model_cfg.mm_cfg, 'lita', {}), mm_mlp_adapter_type=model_cfg.mm_cfg.get('mm_mlp_adapter_type', 'linear'), ), data_cfg=dict( diff --git a/nemo/collections/multimodal/models/multimodal_llm/neva/neva_model.py b/nemo/collections/multimodal/models/multimodal_llm/neva/neva_model.py index 376237e89ecc..92f13c28c287 100644 --- a/nemo/collections/multimodal/models/multimodal_llm/neva/neva_model.py +++ b/nemo/collections/multimodal/models/multimodal_llm/neva/neva_model.py @@ -17,9 +17,10 @@ from itertools import chain from typing import Any, Optional +import numpy as np import torch import torch.nn.functional as F -from einops import rearrange, repeat +from einops import rearrange, reduce, repeat from omegaconf.dictconfig import DictConfig from pkg_resources import packaging from pytorch_lightning.trainer.trainer import Trainer @@ -137,6 +138,7 @@ def init_vision( media_start_id, media_end_id, vision_select_layer=-1, + vision_select_feature="patch", class_token_length=1, use_im_start_end=False, ): @@ -147,6 +149,7 @@ def init_vision( self.class_token_length = class_token_length self.use_im_start_end = use_im_start_end self.vision_select_layer = vision_select_layer + self.vision_select_feature = vision_select_feature self.media = None self.set_accepted_adapter_types([MultimodalProjectorAdapterConfig._target_]) @@ -208,7 +211,10 @@ def encode_vision_x(self, vision_x: torch.Tensor): self.vision_encoder.backbone.transformer.return_select_layer = self.vision_select_layer vision_x = self.vision_encoder(vision_x) vision_x = rearrange(vision_x, "(b T F) v d -> b T F v d", b=b, T=T, F=F) - vision_x = vision_x[:, :, :, self.class_token_length :] + if self.vision_select_feature == 
"patch": + vision_x = vision_x[:, :, :, self.class_token_length :] + elif self.vision_select_feature != "cls_patch": + raise ValueError(f"Unsupported vision_select_feature {self.vision_select_feature}") assert self.is_adapter_available(), "Cannot find multimodal vision adapter!" vision_connector = self.get_adapter_module(AdapterName.MULTIMODAL_PROJECTOR_ADAPTER) vision_x = vision_connector(vision_x) @@ -273,6 +279,147 @@ def sharded_state_dict(self, prefix: str = '', sharded_offsets: tuple = (), **kw return sharded_state_dict +class LitaWordEmbeddingMixin(NevaWordEmbeddingMixin): + def init_lita( + self, + lita_video_arch: str, + visual_token_format: str = "v1", + use_media_start_end: bool = False, + sample_frames: int = 4, + ): + """_summary_ + + Args: + lita_video_arch (str): ['temporal_spatial_pool', 'temporal_spatial', 'temporal_all_resolution'] + visual_token_format (str, optional): default to 'v1', other option ["v1", "im_vid_start_end"] + v1: no video_start_id and video_end_id, video tokens are inserted between fast/slow (temporal/spatial) tokens + im_vid_start_end: video start and end tokens are inserted before and after temporal tokens + image start and end tokens are inserted before and after spatial tokens + use_media_start_end (bool, optional): + whether media start and media end is used in input_ids, Defaults to False. + Notice, when it is false, the media_start_id and media_end_id will play as an placeholder + input_ids = [..., media_start_id, t1, t2, t3...., media_end_id, ...] + use_media_start_end = False + we will replace the tokens including and between: [media_start_id, ... media_end_id] + use_media_start_end = True + we will replace the tokens between: (media_start_id, ... media_end_id) + num_frames (int, optional): number of frames to sample from the video, default to 4 + """ + self.lita_video_arch = lita_video_arch + self.visual_token_format = visual_token_format + self.use_media_start_end = use_media_start_end + self.sample_frames = sample_frames + + def add_lita_layer(self, media_features): + """_summary_ + + Args: + media_features (torch.Tensor): + feature after encoded by vision encoder + shape: Batch, T (number of images), S (num patches), H (hidden size) + Returns: + tokens (torch.Tensor): + shape: Batch, T + M, D (hidden size) + """ + + b, T, S, H = media_features.shape + tokens = media_features + if self.lita_video_arch == 'temporal_spatial_pool': + pool_size = 2 + h = w = int(np.sqrt(S)) + selected_frames = np.round(np.linspace(0, tokens.shape[1] - 1, pool_size * pool_size)).astype(int) + s_tokens = tokens[:, selected_frames, ...] + s_tokens = rearrange(s_tokens, 'b t (h w) d -> (b t) d h w', h=h, w=w) + s_tokens = F.avg_pool2d(s_tokens, kernel_size=pool_size) + s_tokens = rearrange(s_tokens, '(b t) d h w -> b (t h w) d', b=b) # B, M, D + t_tokens = reduce(tokens, 'b t s d -> b t d', 'mean') + # tokens = torch.cat([t_tokens, s_tokens], dim=1) # B, T + M, D + return t_tokens, s_tokens + elif self.lita_video_arch == 'temporal_spatial': + t_tokens = reduce(tokens, 'b t s d -> b t d', 'mean') + s_tokens = reduce(tokens, 'b t s d -> b s d', 'mean') + # tokens = torch.cat([t_tokens, s_tokens], dim=1) # B, T + M, D + return t_tokens, s_tokens + elif self.lita_video_arch == 'temporal_all_resolution': + idx = np.round(np.linspace(0, tokens.shape[1] - 1, self.sample_frames)).astype(int) + im_features = tokens[:, idx, ...] 
# B, num_frames, S, D
+            # im_tokens = im_features.view(b, -1, H)  # flatten to B, num_frames * S, D
+            im_tokens = im_features
+            vid_tokens = reduce(tokens, 'b t s d -> b t d', 'mean')
+            # note: unlike the other branches, the spatial (image) tokens are returned
+            # first here and the temporal (video) tokens second
+            return im_tokens, vid_tokens
+        else:
+            raise ValueError(f"Unknown video architecture: {self.lita_video_arch}")
+
+    def replace_media_embeddings(self, input_ids, inputs_embeds, media):
+        """Replace the media placeholder embeddings with the encoded LITA visual tokens.
+
+        Args:
+            input_ids (torch.Tensor): The input token ids [B, T]
+            inputs_embeds (torch.Tensor): The input embeddings [B, T, D]
+            media (torch.Tensor): Vision input
+                shape (B, T_img, F, C, H, W)
+        """
+        if input_ids.shape[1] == 1:
+            return inputs_embeds
+
+        if media is None:
+            return inputs_embeds
+        if type(media) is list:
+            raise NotImplementedError("dynamic numbers of videos are not supported yet; only a fixed number is")
+        # 1, 1, num_frames, 3, 224, 224
+        media_features = self.encode_vision_x(media)  # B T F S(eq) H(idden)
+        B, T, F, S, H = media_features.shape
+        assert T == 1, "multiple videos per sample not supported yet"
+        media_features = media_features.squeeze(1)
+        t_tokens, s_tokens = self.add_lita_layer(media_features)  # B, T, D & B, M, D
+        T = t_tokens.shape[1]
+        M = s_tokens.shape[1]
+        inputs_embeds = inputs_embeds.clone()
+        for idx, input_id in enumerate(input_ids):
+            media_start_position = torch.where(input_id == self.media_start_id)[0]
+            media_end_position = torch.where(input_id == self.media_end_id)[0]
+            if self.visual_token_format != 'im_vid_start_end':
+                assert len(media_start_position) == 1, "Only 1 video per sample supported"
+                assert len(media_end_position) == 1, "Only 1 video per sample supported"
+
+            media_start_position = media_start_position[0]
+            media_end_position = media_end_position[-1]
+            if self.use_media_start_end:
+                # replace the tokens between media_start_id and media_end_id
+                start, end = media_start_position + 1, media_end_position - 1
+            else:
+                # replace the tokens including and between media_start_id and media_end_id
+                start, end = media_start_position, media_end_position
+
+            if self.visual_token_format == 'v1':
+                t_token_start, t_token_end = start, start + T
+                s_token_start, s_token_end = start + T, start + T + M
+                assert s_token_end == end + 1, "Token replacement error"
+                inputs_embeds[idx, t_token_start:t_token_end] = t_tokens[idx]
+                inputs_embeds[idx, s_token_start:s_token_end] = s_tokens[idx]
+            elif self.visual_token_format == 'im_vid_start_end':  # v1.5 lita
+                if not self.use_media_start_end:
+                    # replace the media start and media end embedding with
+                    # img_start and vid_end token embedding
+                    inputs_embeds[idx, start] = inputs_embeds[idx, start + 1]
+                    inputs_embeds[idx, end] = inputs_embeds[idx, end - 1]
+                # TODO: optimize the code below
+                im_features, vid_features = t_tokens[idx], s_tokens[idx]
+                # im_features: (num_frames, S, D)
+                emb_start = start + 1  # skip the img_start token
+                num_frames, S, D = im_features.shape
+                for i in range(num_frames):
+                    inputs_embeds[idx, emb_start : emb_start + S] = im_features[i]
+                    emb_start = emb_start + S + 2  # skip the img_end token and img_start token
+                T = vid_features.shape[0]
+                inputs_embeds[idx, emb_start : emb_start + T] = vid_features
+                assert emb_start + T == end
+            else:
+                raise ValueError(f"Unsupported visual_token_format {self.visual_token_format}")
+        return inputs_embeds
+
+
 class NevaBaseModel:
     """
     Base class for a multimedia model integrating vision and language models.
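# Shape sketch (hypothetical sizes) for the 'temporal_spatial_pool' branch of
# LitaWordEmbeddingMixin.add_lita_layer above: every frame is mean-pooled into one
# temporal token, while 2x2=4 sampled frames are spatially average-pooled into the
# slow/spatial tokens.
import numpy as np
import torch
import torch.nn.functional as F
from einops import rearrange, reduce

b, T, S, H = 1, 8, 256, 32                 # batch, frames, patches per frame, hidden
tokens = torch.randn(b, T, S, H)

pool_size = 2
h = w = int(np.sqrt(S))                    # 16 x 16 patch grid
sel = np.round(np.linspace(0, T - 1, pool_size * pool_size)).astype(int)
s_tokens = rearrange(tokens[:, sel], 'b t (h w) d -> (b t) d h w', h=h, w=w)
s_tokens = F.avg_pool2d(s_tokens, kernel_size=pool_size)
s_tokens = rearrange(s_tokens, '(b t) d h w -> b (t h w) d', b=b)
t_tokens = reduce(tokens, 'b t s d -> b t d', 'mean')

assert t_tokens.shape == (b, T, H)               # 8 temporal tokens
assert s_tokens.shape == (b, 4 * (S // 4), H)    # 256 spatial tokens after 2x2 pooling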
@@ -307,12 +454,24 @@ def __init__( # Monkey patch embedding if kwargs.get("pre_process", True): - extend_instance(self.embedding.word_embeddings, NevaWordEmbeddingMixin) + if not mm_cfg.get("use_lita", False): + extend_instance(self.embedding.word_embeddings, NevaWordEmbeddingMixin) + else: + extend_instance(self.embedding.word_embeddings, LitaWordEmbeddingMixin) + lita_conf = mm_cfg.get('lita', {}) + self.embedding.word_embeddings.init_lita( + lita_video_arch=lita_conf.get('lita_video_arch', 'temporal_spatial_pool'), + visual_token_format=lita_conf.get('visual_token_format', 'v1'), + use_media_start_end=mm_cfg.get('use_im_start_end', False), # we need to make this clear + sample_frames=lita_conf.get('sample_frames', 4), + ) + self.embedding.word_embeddings.init_vision( vision_encoder, media_start_id, media_end_id, vision_select_layer=mm_cfg.vision_encoder.get("vision_select_layer", -2), + vision_select_feature=mm_cfg.vision_encoder.get("vision_select_feature", "patch"), class_token_length=mm_cfg.vision_encoder.get("class_token_length", 1), use_im_start_end=mm_cfg.get("use_im_start_end", False), ) @@ -320,7 +479,11 @@ def __init__( def create_vision_encoder_and_processor(self, mm_cfg): # Initialize vision encoder and freeze it if mm_cfg.vision_encoder.get("from_hf", False): - if "clip" in mm_cfg.vision_encoder.from_pretrained: + if ( + "clip" in mm_cfg.vision_encoder.from_pretrained + or "vit" in mm_cfg.vision_encoder.from_pretrained + or "clip" in mm_cfg.vision_encoder.get("model_type", "") + ): vision_encoder = CLIPVisionModel.from_pretrained( mm_cfg.vision_encoder.from_pretrained, torch_dtype=torch.bfloat16, @@ -330,7 +493,9 @@ def create_vision_encoder_and_processor(self, mm_cfg): for param in vision_encoder.parameters(): param.requires_grad = False vision_encoder = vision_encoder.eval() - elif "siglip" in mm_cfg.vision_encoder.from_pretrained: + elif "siglip" in mm_cfg.vision_encoder.from_pretrained or "siglip" in mm_cfg.vision_encoder.get( + "model_type", "" + ): vision_encoder = SiglipVisionModel.from_pretrained( mm_cfg.vision_encoder.from_pretrained, torch_dtype=torch.bfloat16, diff --git a/nemo/collections/multimodal/parts/utils.py b/nemo/collections/multimodal/parts/utils.py index 7eb72b38d0f0..232955817e16 100644 --- a/nemo/collections/multimodal/parts/utils.py +++ b/nemo/collections/multimodal/parts/utils.py @@ -451,7 +451,6 @@ def image_processor(maybe_image_path): def video_processor(maybe_video_path): if isinstance(maybe_video_path, str): - decord.bridge.set_bridge("torch") vr = decord.VideoReader(maybe_video_path) if neva_cfg.data.splice_single_frame == 'first': frames = [Image.fromarray(vr[0].asnumpy()).convert('RGB')] @@ -465,19 +464,23 @@ def video_processor(maybe_video_path): else: num_frames = min(len(vr), neva_cfg.data.num_frames) indices = np.linspace(0, len(vr) - 1, num_frames, dtype=int) - frames = vr.get_batch(indices) - + frames = [Image.fromarray(vr[i].asnumpy()).convert('RGB') for i in indices] while len(frames) < neva_cfg.data.num_frames: frames.append(frames[-1]) else: frames = maybe_video_path - if neva_cfg.mm_cfg.vision_encoder.from_hf: - processor = CLIPImageProcessor.from_pretrained( - neva_cfg.mm_cfg.vision_encoder.from_pretrained, torch_dtype=torch.bfloat16 - ) + if neva_cfg.mm_cfg.vision_encoder.get("from_hf", False): + if ( + "siglip" in neva_cfg.mm_cfg.vision_encoder.from_pretrained + or "siglip" in neva_cfg.mm_cfg.vision_encoder.get("model_type", "") + ): + processor = 
SiglipImageProcessor.from_pretrained(neva_cfg.mm_cfg.vision_encoder.from_pretrained) + else: + # for clip and vit model + processor = CLIPImageProcessor.from_pretrained(neva_cfg.mm_cfg.vision_encoder.from_pretrained) else: - processor = CLIPImageProcessor.from_pretrained("openai/clip-vit-large-patch14", torch_dtype=torch.bfloat16) + processor = CLIPImageProcessor.from_pretrained("openai/clip-vit-large-patch14") # support single video inference if neva_cfg.data.image_aspect_ratio == 'keep': @@ -503,7 +506,7 @@ def expand2square(pil_img, background_color): result.paste(pil_img, ((height - width) // 2, 0)) return result - frames = [expand2square(frame, tuple(int(x * 255) for x in self.processor.image_mean)) for frame in frames] + frames = [expand2square(frame, tuple(int(x * 255) for x in processor.image_mean)) for frame in frames] frames = processor.preprocess(frames, return_tensors='pt')['pixel_values'] else: frames = processor.preprocess(frames, return_tensors='pt')['pixel_values'] @@ -516,11 +519,17 @@ def expand2square(pil_img, background_color): def create_image_processor(mm_cfg): if mm_cfg.vision_encoder.get("from_hf", False): - if "clip" in mm_cfg.vision_encoder.from_pretrained: + if ( + "clip" in mm_cfg.vision_encoder.from_pretrained + or "vit" in mm_cfg.vision_encoder.from_pretrained + or "clip" in mm_cfg.vision_encoder.get("model_type", "") + ): image_processor = CLIPImageProcessor.from_pretrained( mm_cfg.vision_encoder.from_pretrained, torch_dtype=torch.bfloat16 ) - elif "siglip" in mm_cfg.vision_encoder.from_pretrained: + elif "siglip" in mm_cfg.vision_encoder.from_pretrained or "siglip" in mm_cfg.vision_encoder.get( + "model_type", "" + ): image_processor = SiglipImageProcessor.from_pretrained( mm_cfg.vision_encoder.from_pretrained, torch_dtype=torch.bfloat16 ) diff --git a/nemo/collections/nlp/modules/common/text_generation_strategy.py b/nemo/collections/nlp/modules/common/text_generation_strategy.py index 8f8fe313a5e3..3b57b3988310 100644 --- a/nemo/collections/nlp/modules/common/text_generation_strategy.py +++ b/nemo/collections/nlp/modules/common/text_generation_strategy.py @@ -584,6 +584,7 @@ def __init__(self, model): media_type=getattr(self.data_cfg, 'media_type', 'image'), num_frames=getattr(self.data_cfg, 'num_frames', 1), mm_mlp_adapter_type=getattr(self.cfg.mm_cfg, 'mm_mlp_adapter_type', 'linear'), + use_lita=getattr(self.cfg.mm_cfg, 'use_lita', False), ) if self.multimodal_cfg['crop_size'] is None: image_processor = CLIPImageProcessor.from_pretrained( @@ -605,6 +606,21 @@ def __init__(self, model): width_num_patches += 1 self.num_media_latents = height_num_patches * width_num_patches + # add config for lita + if self.multimodal_cfg['use_lita']: + if self.cfg.mm_cfg.get('lita'): + lita = { + 'lita_video_arch': getattr(self.cfg.mm_cfg.lita, 'lita_video_arch', 'temporal_spatial_pool'), + 'visual_token_format': getattr(self.cfg.mm_cfg.lita, 'visual_token_format', 'v1'), + 'sample_frames': getattr(self.cfg.mm_cfg.lita, 'sample_frames', 1), + } + self.multimodal_cfg['lita'] = lita + else: + self.multimodal_cfg['use_lita'] = False + raise Warning( + 'Use lita has been set True but Lita config not found in the config file' + 'LITA will be disabled for this run.' + ) def clip_max_len(self, maxlen: int) -> int: """clip the max len based on the LM model max sequence length""" @@ -687,6 +703,7 @@ def prepare_batch_at_step( # not using type2use. 
uncomment it if it is used # if type_ids is not None: # types2use = type_ids[:, context_length - 1].view(batch_size, -1) + media = None """Prepare batch for each of the inference steps""" attention_mask_repeat = None diff --git a/nemo/collections/nlp/modules/common/text_generation_utils.py b/nemo/collections/nlp/modules/common/text_generation_utils.py index cd02f5409679..1bd5b618de35 100644 --- a/nemo/collections/nlp/modules/common/text_generation_utils.py +++ b/nemo/collections/nlp/modules/common/text_generation_utils.py @@ -31,6 +31,8 @@ DEFAULT_IM_END_TOKEN, DEFAULT_IM_START_TOKEN, DEFAULT_IMAGE_PATCH_TOKEN, + DEFAULT_VID_END_TOKEN, + DEFAULT_VID_START_TOKEN, ) from nemo.collections.nlp.modules.common.megatron.utils import get_ltor_masks_and_position_ids from nemo.collections.nlp.modules.common.text_generation_strategy import model_inference_strategy_dispatcher @@ -144,7 +146,75 @@ def megatron_gpt_generate(model, inputs, tokenizer, length_params, sampling_para return output +def decode_time_tokens(tokenizer, text: str, duration: float, time_tokens: list[str], time_token_ids: list[int]): + """Decode the time tokens .... in the text to the actual time in seconds. + TO DO: to do time decoding on output ids instead of text + + Args: + text (str): _description_ + duration (float): the total length of the video in seconds + time_tokens (list[str]): list of time tokens [, , , ..] + time_token_ids (list[str]): list of time token ids [32004, 32005, ....] + """ + output_ids = tokenizer.text_to_ids(text) + num_time_tokens = len(time_token_ids) + # the original code is len(output_ids) - 1 + indices = [j for j in range(len(output_ids)) if output_ids[j] in time_token_ids] + last_processed = -1 + new_output_ids = [] + for j in range(len(indices)): + pred_seq = [int(output_ids[k]) for k in range(last_processed + 1, indices[j])] + new_output_ids.extend(pred_seq) + max_offset = num_time_tokens - 1 + time_token = tokenizer.ids_to_tokens([output_ids[indices[j]]])[0] + time_idx = time_tokens.index(time_token) + time = float(time_idx) * duration / max_offset + time = min(max(time, 0), duration) + time = round(time, 2) + # time_str = '<' + str(time) + '>' + time_str = '<%s>' % str(time) + new_output_ids.extend(tokenizer.text_to_ids(time_str)) + + last_processed = indices[j] + pred_seq = [int(x) for x in output_ids[last_processed + 1 :]] + new_output_ids.extend(pred_seq) + output_ids = new_output_ids + decoded_text = tokenizer.ids_to_text(output_ids) + return decoded_text + + +def encode_time_str(text: str, duration: float, num_time_tokens: int = 100, time_token_template: str = ""): + """ + Encode the common time expression to its time token expression + """ + + def time_to_string(time): + # time is normalized in [0, 1] + max_offset = float(num_time_tokens - 1) + time = int(np.round(max_offset * time)) + return time_token_template.format(t=time) + + def repl(match): + value = float(match.group(1)) / duration + return time_to_string(value) + f"" + + text = re.sub(r"<([\d.]{1,20})s>", repl, text) + text = re.sub(r"\s([\d.]{1,20})s[\s|\.|,|>]", repl, text) + text = re.sub(r"\s([\d.]{1,20}) seconds", repl, text) + text = re.sub(r"\s([\d.]{1,20}) second", repl, text) + + # This is to remove the timestamps from the text + text = re.sub(r"", "", text) + return text.strip() + + def megatron_neva_generate(model, prompt_dict_list, length_params, sampling_params, inference_config, **strategy_args): + use_lita = model.cfg.mm_cfg.get('use_lita', False) + if use_lita: + num_time_tokens = 
model.cfg.data.get('num_time_tokens', 100) + TIME_TOKEN_TEMPLATE = "" + time_tokens = [TIME_TOKEN_TEMPLATE.format(t=i) for i in range(num_time_tokens)] + time_token_ids = model.tokenizer.tokens_to_ids(time_tokens) model_type = model.cfg.mm_cfg.llm.get("model_type", "nvgpt") conv_template = model.cfg.data.get("conv_template", "nvgpt") @@ -152,6 +222,14 @@ def megatron_neva_generate(model, prompt_dict_list, length_params, sampling_para for idx, prompt_dict in enumerate(prompt_dict_list): # determine the media type in the prompt_dict media_type_token = inference_config.inference.get("media_type", "image") + if use_lita: + if prompt_dict.get("duration") is not None: + duration = prompt_dict.get("duration") + prompt_dict['prompt'] = encode_time_str( + prompt_dict['prompt'], duration, num_time_tokens, TIME_TOKEN_TEMPLATE + ) + else: + print("duration field is not in prompt file, skipping time encoding.") response = generate( model, inputs=prompt_dict.get('prompt'), @@ -184,7 +262,12 @@ def megatron_neva_generate(model, prompt_dict_list, length_params, sampling_para r'|', r'\|' ) ) - combined_pattern = re.compile(f'{pattern.pattern}|{pattern_nvgpt.pattern}') + + if use_lita: + pattern_lita = re.compile(rf'{DEFAULT_IM_START_TOKEN[model_type]}(.)+{DEFAULT_IM_END_TOKEN[model_type]}') + combined_pattern = re.compile(f'{pattern_lita.pattern}') + else: + combined_pattern = re.compile(f'{pattern.pattern}|{pattern_nvgpt.pattern}') clean_text = re.sub(combined_pattern, f"<{media_type_token}>", response['sentences'][0]) clean_response = clean_text @@ -204,10 +287,18 @@ def megatron_neva_generate(model, prompt_dict_list, length_params, sampling_para clean_response = clean_response.rsplit("[/INST] ", 1)[-1] elif conv_template == "llama_3": clean_response = clean_response.rsplit("assistant<|end_header_id|>\n\n", 1)[-1] - clean_response = clean_response.rstrip("<|eot_id|>") + clean_response = re.sub(r"(<\|eot_id\|>)+$", "", clean_response) elif conv_template == "v1": clean_response = clean_response.rsplit("ASSISTANT: ", 1)[-1] + if use_lita: + if prompt_dict.get("duration", None) is not None: + duration = prompt_dict.get("duration") + clean_response = decode_time_tokens( + model.tokenizer, clean_response, duration, time_tokens, time_token_ids + ) + else: + print("duration field is not in prompt file, skipping time decoding.") clean_response = clean_response.strip() response["clean_text"] = clean_text response["clean_response"] = clean_response diff --git a/scripts/multimodal_dataset_conversion/convert_dvc_dataset_for_evaluation.py b/scripts/multimodal_dataset_conversion/convert_dvc_dataset_for_evaluation.py new file mode 100644 index 000000000000..1427e0983b24 --- /dev/null +++ b/scripts/multimodal_dataset_conversion/convert_dvc_dataset_for_evaluation.py @@ -0,0 +1,160 @@ +# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + + +""" +This script is used to convert the DVC dataset to the format required by the model evaluation for RTL task. 
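# Hedged sketch of the time <-> token arithmetic used by encode_time_str and
# decode_time_tokens above, assuming the time-token template renders as "<t{t}>"
# with num_time_tokens discrete bins; the duration below is a hypothetical video length.
import numpy as np

num_time_tokens = 100
duration = 118.0  # seconds

def second_to_token(t_sec):
    # normalize to [0, 1], then quantize into one of num_time_tokens bins
    idx = int(np.round((num_time_tokens - 1) * (t_sec / duration)))
    return "<t{t}>".format(t=idx)

def token_to_second(token):
    idx = int(token[2:-1])
    return round(idx * duration / (num_time_tokens - 1), 2)

tok = second_to_token(58.0)                       # -> "<t49>"
assert tok == "<t49>"
assert abs(token_to_second(tok) - 58.0) <= duration / (num_time_tokens - 1)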
+The DVC dataset should have the below structure: +{ + "-4RXOT_UfpM_3": { # video_name is the unique video file name, extention is .mp4 + "duration": 118.01801801801803, + "timestamps": [ + [5, 58], + [66, 82], + [82, 96] + ], + "sentences": [ + "Apply eyeshadow on the lower area then crease with brush", + "Apply eyeshadow on the outer corner of eyes with brush", + "Apply eyeshadow on the outer half of eyes with brush", + ] + }, + ... +} + +The converted format will be as follows: +[ + { + "video": "-4RXOT_UfpM_3.mp4", + "question_id": "-4RXOT_UfpM_3_0", + "question": "When does \"Apply eyeshadow on the lower area then crease with brush\" happen in the video? Provide a response using only start and end timestamps.", + "ref_answer": "<5> <58> Apply eyeshadow on the lower area then crease with brush", + "duration": 118.01801801801803 + }, + { + "video": "-4RXOT_UfpM_3.mp4", + "question_id": "-4RXOT_UfpM_3_1", + "question": "When is \"Apply eyeshadow on the outer corner of eyes with brush\" depicted in the video? Convey your answer using start and end timestamps exclusively.", + "ref_answer": "<66> <82> Apply eyeshadow on the outer corner of eyes with brush", + "duration": 118.01801801801803 + }, + { + "video": "-4RXOT_UfpM_3.mp4", + "question_id": "-4RXOT_UfpM_3_2", + "question": "When does \"Apply eyeshadow on the outer half of eyes with brush\" happen in the video? Provide a response using only start and end timestamps.", + "ref_answer": "<82> <96> Apply eyeshadow on the outer half of eyes with brush", + "duration": 118.01801801801803 + }, + ..... +] + +For each sentence in the sentences list, we will generate one question for it and the answer will be the sentence itself with the timestamps. +USAGE: +python convert_dvc_dataset_for_evaluation.py --input --output_file --ratio + +""" + +import argparse +import json +import os +import random + + +class RTLConverter: + def __init__(self, input_file, output_file, sample_ratio, ext): + self.input_file = input_file + self.output_file = output_file + self.sample_ratio = sample_ratio + self.desc_prompts = [ + "When does \"%s\" happen in the video?", + "At what point in the video does \"%s\" happen?", + "When is \"%s\" depicted in the video?", + "At what time in the video does \"%s\" take place?", + ] + self.time_prompts = [ + "Answer the question only using start and end timestamps.", + "Provide a response using only start and end timestamps.", + "Convey your answer using start and end timestamps exclusively.", + ] + self.ext = ext + + def convert(self): + converted_data = [] + + # Load JSON data + with open(self.input_file, 'r') as file: + data = json.load(file) + + # Fix random seed for reproducibility + random.seed(42) + + # Randomly sample entries based on the sample ratio + vid_list = list(data.keys()) + sampled_vids = random.sample(vid_list, k=int(len(vid_list) * self.sample_ratio)) + + # Iterate through sampled entries + for vid in sampled_vids: + details = data[vid] + duration = details['duration'] + timestamps = details['timestamps'] + sentences = details['sentences'] + + # Iterate through sentences + for i, sentence in enumerate(sentences): + question_id = f"{vid}_{i}" + desc_prompt = random.choice(self.desc_prompts) + time_prompt = random.choice(self.time_prompts) + start_time, end_time = timestamps[i] + answer = f"<{start_time}> <{end_time}> {sentence}" + + # Construct question + question = (desc_prompt % sentence) + ' ' + time_prompt + + # Create entry in converted data + converted_data.append( + { + "video": vid + self.ext, + 
"question_id": question_id, + "question": question, + "ref_answer": answer, + "duration": duration, + } + ) + + # Ensure the output directory exists + os.makedirs(os.path.dirname(self.output_file), exist_ok=True) + + # Write converted data to output file + with open(self.output_file, 'w') as file: + json.dump(converted_data, file, indent=2) + + +def main(): + parser = argparse.ArgumentParser(description="Convert makeup QA JSON format") + parser.add_argument("--input", help="Input DVC JSON file", required=True) + parser.add_argument("--output_file", help="Output file", default="rtl_eval.json", required=True) + parser.add_argument("--ratio", help="Sampling ratio between 0 and 1", type=float, default=1.0, required=False) + parser.add_argument("--ext", help="Extension of the video files", default=".mp4", required=False) + args = parser.parse_args() + + if args.ratio < 0 or args.ratio > 1: + raise ValueError("Sampling ratio must be between 0 and 1") + + converter = RTLConverter(args.input, args.output_file, args.ratio, args.ext) + converter.convert() + + +if __name__ == "__main__": + main() diff --git a/scripts/multimodal_dataset_conversion/convert_dvc_dataset_for_training.py b/scripts/multimodal_dataset_conversion/convert_dvc_dataset_for_training.py new file mode 100644 index 000000000000..a80900e30004 --- /dev/null +++ b/scripts/multimodal_dataset_conversion/convert_dvc_dataset_for_training.py @@ -0,0 +1,322 @@ +# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + + +""" +This script is used to convert the DVC dataset to the format required by the model training script. +The DVC dataset should have the below structure: +{ + "1043215450": { # video_name is the unique video file name (the extension should be .mp4) + "duration": 125.0, + "timestamps": [ + [0, 5], + [3, 9] + ], + "sentences": [ # For custom caption or event localization task + "Here is your caption 1", + "Here is your caption 2", + ], + "events": [ # For custom event task + "Event 1", + "Event 2", + ] + }, + ... +} + +The converted dataset format is as follows: +[ + # 1st example: dense video captioning (custom event or custom caption task) + { + "id": "xxxx", + "video: "xxxx.mp4", + "conversations": + [ + {"from": "human", "value": "