diff --git a/examples/multimodal/speech_llm/conf/bestow/modular_audio_gpt_config_cross_llama_lhotse.yaml b/examples/multimodal/speech_llm/conf/bestow/modular_audio_gpt_config_cross_llama_lhotse.yaml new file mode 100644 index 0000000000000..6145a1a4c462a --- /dev/null +++ b/examples/multimodal/speech_llm/conf/bestow/modular_audio_gpt_config_cross_llama_lhotse.yaml @@ -0,0 +1,329 @@ +# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +name: megatron_audio_gpt_bestow_lhotse + +trainer: + devices: 1 + accelerator: gpu + num_nodes: 1 + precision: 16 + logger: False # logger provided by exp_manager + enable_checkpointing: False + use_distributed_sampler: False + max_epochs: 9999 + max_steps: 1000000 # consumed_samples = global_step * micro_batch_size * data_parallel_size * accumulate_grad_batches + limit_train_batches : 1000 + log_every_n_steps: 10 # frequency with which training steps are logged + val_check_interval: 1000 # If is an int n > 1, will run val every n training steps, if a float 0.0 - 1.0 will run val every epoch fraction, e.g. 0.25 will run val every quarter epoch + gradient_clip_val: 1.0 + accumulate_grad_batches: 1 + +model_target: nemo.collections.multimodal.speech_llm.models.modular_models.CrossAttendModularAudioGPTModel + +exp_manager: + # explicit_log_dir: null + exp_dir: null + name: ${name} + create_wandb_logger: False + wandb_logger_kwargs: + project: null + name: null + resume_if_exists: True + resume_ignore_no_checkpoint: True + create_checkpoint_callback: True + checkpoint_callback_params: + monitor: validation_${model.data.validation_ds.metric.name} + save_top_k: 1 + mode: min + save_nemo_on_train_end: True + filename: '${name}--{${exp_manager.checkpoint_callback_params.monitor}:.3f}-{step}-{epoch}' + model_parallel_size: ${model.tensor_model_parallel_size} + always_save_nemo: False + save_best_model: True + create_early_stopping_callback: False + early_stopping_callback_params: + monitor: "val_loss" + mode: "min" + min_delta: 0.001 + patience: 10 + verbose: True + strict: False # Should be False to avoid a runtime error where EarlyStopping says monitor is unavailable, which sometimes happens with resumed training. + + +model: + seed: 1234 + tensor_model_parallel_size: 1 # intra-layer model parallelism + pipeline_model_parallel_size: 1 # inter-layer model parallelism + + pretrained_audio_model: stt_en_fastconformer_transducer_large + freeze_llm: True + freeze_audio_encoder: False + freeze_modality_adapter: False + load_audio_encoder: True + + ## Legacy batch_size configuration + # When used with lhotse, the batch composition is decided by dataloader configs + # and batch size here is only used for deciding gradient accumulation. + # gradient accumulation = global_batch_size / micro_batch_size / data_parallel_size + # where data_parallel_size = num_nodes * num_gpus / TP_size + global_batch_size: 128 + micro_batch_size: 4 + restore_from_path: ??? # Path to an existing .nemo model you wish to add new tasks to or run inference with + resume_from_checkpoint: null # The path to a checkpoint file to continue the training, restores the whole state including the epoch, step, LR schedulers, apex, etc. + save_nemo_on_validation_end: False # Saves an inference ready .nemo file every time a checkpoint is saved during training. + sync_batch_comm: False + megatron_amp_O2: False + + ## Sequence Parallelism + # Makes tensor parallelism more memory efficient for LLMs (20B+) by parallelizing layer norms and dropout sequentially + # See Reducing Activation Recomputation in Large Transformer Models: https://arxiv.org/abs/2205.05198 for more details. + sequence_parallel: False + + ## Activation Checkpoint + activations_checkpoint_granularity: null # 'selective' or 'full' + activations_checkpoint_method: null # 'uniform', 'block', not used with 'selective' + # 'uniform' divides the total number of transformer layers and checkpoints the input activation + # of each chunk at the specified granularity + # 'block' checkpoints the specified number of layers per pipeline stage at the specified granularity + activations_checkpoint_num_layers: null # not used with 'selective' + activations_checkpoint_layers_per_pipeline: null + answer_only_loss: True + gradient_as_bucket_view: False + + hidden_dropout: 0.0 + attention_dropout: 0.0 + ffn_dropout: 0.0 + + # use_am_tokenizer: True + # override_vocab_size: 1024 + + peft: + peft_scheme: "lora" # can be either lora, adapter, ia3 or ptuning + restore_from_path: null + + # Used for adapter peft training + adapter_tuning: + type: 'parallel_adapter' # this should be either 'parallel_adapter' or 'linear_adapter' + adapter_dim: 32 + adapter_dropout: 0.0 + norm_position: 'pre' # This can be set to 'pre', 'post' or null, 'pre' is normally what is used. + column_init_method: 'xavier' # IGNORED if linear_adapter is used, options: xavier, zero or normal + row_init_method: 'zero' # IGNORED if linear_adapter is used, options: xavier, zero or normal + norm_type: 'mixedfusedlayernorm' # IGNORED if layer_adapter is used, options are ['layernorm', 'mixedfusedlayernorm'] + layer_selection: null # selects in which layers to add adapters, e.g. [1,12] will add adapters to layer 1 (lowest) and 12. null will apply adapters to all layers + weight_tying: False + position_embedding_strategy: null # used only when weight_tying is True + + lora_tuning: + target_modules: ['attention_qkv','attention_dense','mlp_fc1','mlp_fc2'] # this can either be 'attention_qkv','attention_dense','mlp_fc1','mlp_fc2', attention (qkv & dense), mlp (fc1 & fc2) + adapter_dim: 32 + alpha: ${model.peft.lora_tuning.adapter_dim} + adapter_dropout: 0.0 + column_init_method: 'xavier' # IGNORED if linear_adapter is used, options: xavier, zero or normal + row_init_method: 'zero' # IGNORED if linear_adapter is used, options: xavier, zero or normal + layer_selection: null # selects in which layers to add lora adapters. e.g. [1,12] will add lora to layer 1 (lowest) and 12. null will apply adapters to all layers + weight_tying: False + position_embedding_strategy: null # used only when weight_tying is True + + # Used for p-tuning peft training + p_tuning: + virtual_tokens: 10 # The number of virtual tokens the prompt encoder should add at the start of the sequence + bottleneck_dim: 1024 # the size of the prompt encoder mlp bottleneck + embedding_dim: 1024 # the size of the prompt encoder embeddings + init_std: 0.023 + + perception: + target: nemo.collections.multimodal.speech_llm.modules.perception_modules.AudioPerceptionModule + use_multi_layer_feat: false + xattn: + target: nemo.collections.multimodal.speech_llm.modules.perception_modules.TransformerCrossAttention + num_attention_heads: 8 + attn_score_dropout: 0.1 + attn_layer_dropout: 0.1 + ffn_dropout: 0.1 + hidden_act: "relu" + pre_ln: true + pre_ln_final_layer_norm: true + + multi_layer_feat: + layer_idx_list: [0,16] # layer indices to extract features from + aggregator: + mode: "cat" # ways to combine features from different layers, choices=['cat','sum','mean', 'max', 'min'], default to concat ('cat') + pooling: "avg" # ways to pool features if they have different temporal lengths and align_mode=min, choices=['mean', 'max', 'min'] + align_mode: "min" # if features have different temporal lengths, set `min` to pool to the shortest length or `max` to repeat to the longest. + + modality_adapter: + _target_: nemo.collections.asr.modules.ConformerEncoder + feat_in: 1024 + feat_out: -1 # you may set it if you need different output size other than the default d_model + n_layers: 2 + d_model: 512 + + # Sub-sampling parameters + subsampling: dw_striding # vggnet, striding, stacking or stacking_norm, dw_striding + subsampling_factor: 8 # must be power of 2 for striding and vggnet + subsampling_conv_channels: 256 # set to -1 to make it equal to the d_model + causal_downsampling: false + + # Reduction parameters: Can be used to add another subsampling layer at a given position. + # Having a 2x reduction will speedup the training and inference speech while keeping similar WER. + # Adding it at the end will give the best WER while adding it at the beginning will give the best speedup. + reduction: null # pooling, striding, or null + reduction_position: null # Encoder block index or -1 for subsampling at the end of encoder + reduction_factor: 1 + + # Feed forward module's params + ff_expansion_factor: 4 + + # Multi-headed Attention Module's params + self_attention_model: rel_pos # rel_pos or abs_pos + n_heads: 8 # may need to be lower for smaller d_models + # [left, right] specifies the number of steps to be seen from left and right of each step in self-attention + att_context_size: [-1, -1] # -1 means unlimited context + att_context_style: regular # regular or chunked_limited + xscaling: true # scales up the input embeddings by sqrt(d_model) + untie_biases: true # unties the biases of the TransformerXL layers + pos_emb_max_len: 5000 + + # Convolution module's params + conv_kernel_size: 9 + conv_norm_type: 'batch_norm' # batch_norm or layer_norm or groupnormN (N specifies the number of groups) + # conv_context_size can be"causal" or a list of two integers while conv_context_size[0]+conv_context_size[1]+1==conv_kernel_size + # null means [(kernel_size-1)//2, (kernel_size-1)//2], and 'causal' means [(kernel_size-1), 0] + conv_context_size: null + + ### regularization + dropout: 0.1 # The dropout used in most of the Conformer Modules + dropout_pre_encoder: 0.1 # The dropout used before the encoder + dropout_emb: 0.0 # The dropout used for embeddings + dropout_att: 0.1 # The dropout for multi-headed attention modules + + # set to non-zero to enable stochastic depth + stochastic_depth_drop_prob: 0.0 + stochastic_depth_mode: linear # linear or uniform + stochastic_depth_start_layer: 1 + + spec_augment: + _target_: nemo.collections.asr.modules.SpectrogramAugmentation + freq_masks: 2 # set to zero to disable it + time_masks: 10 # set to zero to disable it + freq_width: 27 + time_width: 0.05 + + # the following are read from the pretrained AM: + # output_dim: null + # encoder: null + # preprocessor: null + + data: + end_string: "[EOG]" + train_ds: + # Example of how to specify paths to multiple datasets + # manifest_filepath: + # - /path/to/squad.jsonl + # - /path/to/mnli.jsonl + # - /path/to/boolq.jsonl + # Example of how each dataset is formatted + # {'audio_filepath': 'audio1.wav', 'offset': 0.0, 'duration': 12.3, 'question': 'transcribe this audio', 'answer': 'I have a dream...'} + # the 'answer' field can also be 'text', and a default 'question' field is added if missing in manigests, so as to work with ASR manifests + global_batch_size: ${model.global_batch_size} + micro_batch_size: ${model.micro_batch_size} + shuffle: True + num_workers: 0 + pin_memory: True + max_seq_length: 2048 + min_seq_length: 1 + drop_last: True + # Notably, the data weights are controlled by either bucketing_weights + # or concat_sampling_probabilities depending on the dataset type (tar and + # non-tar). + # See audio_text_qa_dataset.py for details. + concat_sampling_probabilities: null # When providing a list of datasets, this arg defines the sampling probabilities from each dataset when strategy='random' + context_key: 'context' + answer_key: 'answer' + add_eos: True + # add_eos: False + end_string: ${model.data.end_string} + add_sep: False + add_bos: False + separate_prompt_and_response_with_newline: False + truncation_field: "context" # Options: ['context', 'answer'] + index_mapping_dir: null # Path to a directory to write index mapping files. + prompt_template: "[INST]\n<>\nPlease answer the following based on the previous speech feature.\n<>\n\n{context}[/INST] {answer}" + # ASR configs + sample_rate: 16000 #${model.audio_encoder.preprocessor.sample_rate} + max_duration: 24 # it is set for LibriSpeech, you may need to update it for your dataset + min_duration: 0.1 + # tarred datasets + is_tarred: false + tarred_audio_filepaths: null + shuffle_n: 2048 + # bucketing params + bucketing_strategy: "fully_randomized" + bucketing_batch_size: null + use_lhotse: True + text_field : "text" + batch_duration : 80 # 0 + quadratic_duration : 30 + num_buckets : 30 + buffer_size : 10000 + shuffle_buffer_size : 10000 + duration_bins: null + + validation_ds: + global_batch_size: ${model.global_batch_size} + micro_batch_size: ${model.micro_batch_size} + shuffle: False + num_workers: 0 + pin_memory: True + max_seq_length: 2048 + min_seq_length: 1 + drop_last: False + context_key: ${model.data.train_ds.context_key} + answer_key: ${model.data.train_ds.answer_key} + add_eos: ${model.data.train_ds.add_eos} + end_string: ${model.data.end_string} + add_sep: ${model.data.train_ds.add_sep} + add_bos: ${model.data.train_ds.add_bos} + separate_prompt_and_response_with_newline: ${model.data.train_ds.separate_prompt_and_response_with_newline} + write_predictions_to_file: False + output_file_path_prefix: null # Prefix of the file to write predictions to. + truncation_field: "context" # Options: ['context', 'answer'] + index_mapping_dir: null # Path to a directory to write index mapping files. + prompt_template: ${model.data.train_ds.prompt_template} # fstring to use for assistant prompt. Example: "Q: {input}\nA: {output}" + tokens_to_generate: 128 + # ASR configs + sample_rate: 16000 #${model.audio_encoder.preprocessor.sample_rate} + + log_every_n_steps: 10 + metric: + name: "wer" # Name of the evaluation metric to use. Options: ['exact_string_match', 'loss'] + average: null # Average the metric over the dataset. Options: ['macro', 'micro']. Works only for 'F1', 'accuracy' etc. Refer to torchmetrics for metrics where this is supported. + num_classes: null + + optim: + name: fused_adam + lr: 1e-4 + weight_decay: 0.01 + betas: + - 0.9 + - 0.98 + sched: + name: CosineAnnealing + warmup_steps: 50 + min_lr: 0.0 # min_lr must be 0.0 for prompt learning when pipeline parallel > 1 + constant_steps: 0 # Constant steps should also be 0 when min_lr=0 + monitor: val_loss + reduce_on_plateau: false diff --git a/examples/multimodal/speech_llm/conf/modular_audio_gpt_config_eval.yaml b/examples/multimodal/speech_llm/conf/modular_audio_gpt_config_eval.yaml index e2ef61a8046d8..62b9030b47082 100644 --- a/examples/multimodal/speech_llm/conf/modular_audio_gpt_config_eval.yaml +++ b/examples/multimodal/speech_llm/conf/modular_audio_gpt_config_eval.yaml @@ -81,7 +81,6 @@ model: data: test_ds: - manifest_filepath: ??? # Path to a list of JSONL files corresponding to the source data. Data format is identical to train_ds. names: null # Names of the corresponding datasets used to log metrics. global_batch_size: 1 micro_batch_size: 1 diff --git a/examples/multimodal/speech_llm/conf/salm/modular_audio_gpt_config_llama_lhotse.yaml b/examples/multimodal/speech_llm/conf/salm/modular_audio_gpt_config_llama_lhotse.yaml new file mode 100644 index 0000000000000..cc848562f70e4 --- /dev/null +++ b/examples/multimodal/speech_llm/conf/salm/modular_audio_gpt_config_llama_lhotse.yaml @@ -0,0 +1,317 @@ +# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +name: megatron_audio_gpt_salm_lhotse + +trainer: + devices: 1 + accelerator: gpu + num_nodes: 1 + precision: 16 + logger: False # logger provided by exp_manager + enable_checkpointing: False + use_distributed_sampler: False + max_epochs: 9999 + max_steps: 1000000 # consumed_samples = global_step * micro_batch_size * data_parallel_size * accumulate_grad_batches + limit_train_batches : 1000 + log_every_n_steps: 10 # frequency with which training steps are logged + val_check_interval: 1000 # If is an int n > 1, will run val every n training steps, if a float 0.0 - 1.0 will run val every epoch fraction, e.g. 0.25 will run val every quarter epoch + gradient_clip_val: 1.0 + accumulate_grad_batches: 1 + +exp_manager: + # explicit_log_dir: null + exp_dir: null + name: ${name} + create_wandb_logger: False + wandb_logger_kwargs: + project: null + name: null + resume_if_exists: True + resume_ignore_no_checkpoint: True + create_checkpoint_callback: True + checkpoint_callback_params: + monitor: validation_${model.data.validation_ds.metric.name} + save_top_k: 1 + mode: min + save_nemo_on_train_end: True + filename: '${name}--{${exp_manager.checkpoint_callback_params.monitor}:.3f}-{step}-{epoch}' + model_parallel_size: ${model.tensor_model_parallel_size} + always_save_nemo: False + save_best_model: True + create_early_stopping_callback: False + early_stopping_callback_params: + monitor: "val_loss" + mode: "min" + min_delta: 0.001 + patience: 10 + verbose: True + strict: False # Should be False to avoid a runtime error where EarlyStopping says monitor is unavailable, which sometimes happens with resumed training. + + +model: + seed: 1234 + tensor_model_parallel_size: 1 # intra-layer model parallelism + pipeline_model_parallel_size: 1 # inter-layer model parallelism + + pretrained_audio_model: stt_en_fastconformer_transducer_large + freeze_llm: True + freeze_audio_encoder: False + freeze_modality_adapter: False + load_audio_encoder: True + + ## Legacy batch_size configuration + # When used with lhotse, the batch composition is decided by dataloader configs + # and batch size here is only used for deciding gradient accumulation. + # gradient accumulation = global_batch_size / micro_batch_size / data_parallel_size + # where data_parallel_size = num_nodes * num_gpus / TP_size + global_batch_size: 128 + micro_batch_size: 4 + restore_from_path: ??? # Path to an existing .nemo model you wish to add new tasks to or run inference with + resume_from_checkpoint: null # The path to a checkpoint file to continue the training, restores the whole state including the epoch, step, LR schedulers, apex, etc. + save_nemo_on_validation_end: False # Saves an inference ready .nemo file every time a checkpoint is saved during training. + sync_batch_comm: False + megatron_amp_O2: False + + ## Sequence Parallelism + # Makes tensor parallelism more memory efficient for LLMs (20B+) by parallelizing layer norms and dropout sequentially + # See Reducing Activation Recomputation in Large Transformer Models: https://arxiv.org/abs/2205.05198 for more details. + sequence_parallel: False + + ## Activation Checkpoint + activations_checkpoint_granularity: null # 'selective' or 'full' + activations_checkpoint_method: null # 'uniform', 'block', not used with 'selective' + # 'uniform' divides the total number of transformer layers and checkpoints the input activation + # of each chunk at the specified granularity + # 'block' checkpoints the specified number of layers per pipeline stage at the specified granularity + activations_checkpoint_num_layers: null # not used with 'selective' + activations_checkpoint_layers_per_pipeline: null + answer_only_loss: True + gradient_as_bucket_view: False + + hidden_dropout: 0.0 + attention_dropout: 0.0 + ffn_dropout: 0.0 + + # use_am_tokenizer: True + # override_vocab_size: 1024 + + peft: + peft_scheme: "lora" # can be either lora, adapter, ia3 or ptuning + restore_from_path: null + + # Used for adapter peft training + adapter_tuning: + type: 'parallel_adapter' # this should be either 'parallel_adapter' or 'linear_adapter' + adapter_dim: 32 + adapter_dropout: 0.0 + norm_position: 'pre' # This can be set to 'pre', 'post' or null, 'pre' is normally what is used. + column_init_method: 'xavier' # IGNORED if linear_adapter is used, options: xavier, zero or normal + row_init_method: 'zero' # IGNORED if linear_adapter is used, options: xavier, zero or normal + norm_type: 'mixedfusedlayernorm' # IGNORED if layer_adapter is used, options are ['layernorm', 'mixedfusedlayernorm'] + layer_selection: null # selects in which layers to add adapters, e.g. [1,12] will add adapters to layer 1 (lowest) and 12. null will apply adapters to all layers + weight_tying: False + position_embedding_strategy: null # used only when weight_tying is True + + lora_tuning: + target_modules: ['attention_qkv','attention_dense','mlp_fc1','mlp_fc2'] # this can either be 'attention_qkv','attention_dense','mlp_fc1','mlp_fc2', attention (qkv & dense), mlp (fc1 & fc2) + adapter_dim: 32 + alpha: ${model.peft.lora_tuning.adapter_dim} + adapter_dropout: 0.0 + column_init_method: 'xavier' # IGNORED if linear_adapter is used, options: xavier, zero or normal + row_init_method: 'zero' # IGNORED if linear_adapter is used, options: xavier, zero or normal + layer_selection: null # selects in which layers to add lora adapters. e.g. [1,12] will add lora to layer 1 (lowest) and 12. null will apply adapters to all layers + weight_tying: False + position_embedding_strategy: null # used only when weight_tying is True + + # Used for p-tuning peft training + p_tuning: + virtual_tokens: 10 # The number of virtual tokens the prompt encoder should add at the start of the sequence + bottleneck_dim: 1024 # the size of the prompt encoder mlp bottleneck + embedding_dim: 1024 # the size of the prompt encoder embeddings + init_std: 0.023 + + perception: + target: nemo.collections.multimodal.speech_llm.modules.perception_modules.AudioPerceptionModule + use_multi_layer_feat: false + multi_layer_feat: + layer_idx_list: [0,16] # layer indices to extract features from + aggregator: + mode: "cat" # ways to combine features from different layers, choices=['cat','sum','mean', 'max', 'min'], default to concat ('cat') + pooling: "avg" # ways to pool features if they have different temporal lengths and align_mode=min, choices=['mean', 'max', 'min'] + align_mode: "min" # if features have different temporal lengths, set `min` to pool to the shortest length or `max` to repeat to the longest. + + modality_adapter: + _target_: nemo.collections.asr.modules.ConformerEncoder + feat_in: 1024 + feat_out: -1 # you may set it if you need different output size other than the default d_model + n_layers: 2 + d_model: 512 + + # Sub-sampling parameters + subsampling: dw_striding # vggnet, striding, stacking or stacking_norm, dw_striding + subsampling_factor: 8 # must be power of 2 for striding and vggnet + subsampling_conv_channels: 256 # set to -1 to make it equal to the d_model + causal_downsampling: false + + # Reduction parameters: Can be used to add another subsampling layer at a given position. + # Having a 2x reduction will speedup the training and inference speech while keeping similar WER. + # Adding it at the end will give the best WER while adding it at the beginning will give the best speedup. + reduction: null # pooling, striding, or null + reduction_position: null # Encoder block index or -1 for subsampling at the end of encoder + reduction_factor: 1 + + # Feed forward module's params + ff_expansion_factor: 4 + + # Multi-headed Attention Module's params + self_attention_model: rel_pos # rel_pos or abs_pos + n_heads: 8 # may need to be lower for smaller d_models + # [left, right] specifies the number of steps to be seen from left and right of each step in self-attention + att_context_size: [-1, -1] # -1 means unlimited context + att_context_style: regular # regular or chunked_limited + xscaling: true # scales up the input embeddings by sqrt(d_model) + untie_biases: true # unties the biases of the TransformerXL layers + pos_emb_max_len: 5000 + + # Convolution module's params + conv_kernel_size: 9 + conv_norm_type: 'batch_norm' # batch_norm or layer_norm or groupnormN (N specifies the number of groups) + # conv_context_size can be"causal" or a list of two integers while conv_context_size[0]+conv_context_size[1]+1==conv_kernel_size + # null means [(kernel_size-1)//2, (kernel_size-1)//2], and 'causal' means [(kernel_size-1), 0] + conv_context_size: null + + ### regularization + dropout: 0.1 # The dropout used in most of the Conformer Modules + dropout_pre_encoder: 0.1 # The dropout used before the encoder + dropout_emb: 0.0 # The dropout used for embeddings + dropout_att: 0.1 # The dropout for multi-headed attention modules + + # set to non-zero to enable stochastic depth + stochastic_depth_drop_prob: 0.0 + stochastic_depth_mode: linear # linear or uniform + stochastic_depth_start_layer: 1 + + spec_augment: + _target_: nemo.collections.asr.modules.SpectrogramAugmentation + freq_masks: 2 # set to zero to disable it + time_masks: 10 # set to zero to disable it + freq_width: 27 + time_width: 0.05 + + # the following are read from the pretrained AM: + # output_dim: null + # encoder: null + # preprocessor: null + + data: + end_string: "[EOG]" + train_ds: + # Example of how to specify paths to multiple datasets + # manifest_filepath: + # - /path/to/squad.jsonl + # - /path/to/mnli.jsonl + # - /path/to/boolq.jsonl + # Example of how each dataset is formatted + # {'audio_filepath': 'audio1.wav', 'offset': 0.0, 'duration': 12.3, 'question': 'transcribe this audio', 'answer': 'I have a dream...'} + # the 'answer' field can also be 'text', and a default 'question' field is added if missing in manigests, so as to work with ASR manifests + global_batch_size: ${model.global_batch_size} + micro_batch_size: ${model.micro_batch_size} + shuffle: True + num_workers: 0 + pin_memory: True + max_seq_length: 2048 + min_seq_length: 1 + drop_last: True + # Notably, the data weights are controlled by either bucketing_weights + # or concat_sampling_probabilities depending on the dataset type (tar and + # non-tar). + # See audio_text_qa_dataset.py for details. + concat_sampling_probabilities: null # When providing a list of datasets, this arg defines the sampling probabilities from each dataset when strategy='random' + context_key: 'context' + answer_key: 'answer' + add_eos: True + # add_eos: False + end_string: ${model.data.end_string} + add_sep: False + add_bos: False + separate_prompt_and_response_with_newline: False + truncation_field: "context" # Options: ['context', 'answer'] + index_mapping_dir: null # Path to a directory to write index mapping files. + prompt_template: "[INST]\n<>\nPlease answer the following based on the previous speech feature.\n<>\n\n{context}[/INST] {answer}" + # ASR configs + sample_rate: 16000 #${model.audio_encoder.preprocessor.sample_rate} + max_duration: 24 # it is set for LibriSpeech, you may need to update it for your dataset + min_duration: 0.1 + # tarred datasets + is_tarred: false + tarred_audio_filepaths: null + shuffle_n: 2048 + # bucketing params + bucketing_strategy: "fully_randomized" + bucketing_batch_size: null + use_lhotse: True + text_field : "text" + batch_duration : 80 # 0 + quadratic_duration : 30 + num_buckets : 30 + buffer_size : 10000 + shuffle_buffer_size : 10000 + duration_bins: null + + validation_ds: + global_batch_size: ${model.global_batch_size} + micro_batch_size: ${model.micro_batch_size} + shuffle: False + num_workers: 0 + pin_memory: True + max_seq_length: 2048 + min_seq_length: 1 + drop_last: False + context_key: ${model.data.train_ds.context_key} + answer_key: ${model.data.train_ds.answer_key} + add_eos: ${model.data.train_ds.add_eos} + end_string: ${model.data.end_string} + add_sep: ${model.data.train_ds.add_sep} + add_bos: ${model.data.train_ds.add_bos} + separate_prompt_and_response_with_newline: ${model.data.train_ds.separate_prompt_and_response_with_newline} + write_predictions_to_file: False + output_file_path_prefix: null # Prefix of the file to write predictions to. + truncation_field: "context" # Options: ['context', 'answer'] + index_mapping_dir: null # Path to a directory to write index mapping files. + prompt_template: ${model.data.train_ds.prompt_template} # fstring to use for assistant prompt. Example: "Q: {input}\nA: {output}" + tokens_to_generate: 128 + # ASR configs + sample_rate: 16000 #${model.audio_encoder.preprocessor.sample_rate} + + log_every_n_steps: 10 + metric: + name: "wer" # Name of the evaluation metric to use. Options: ['exact_string_match', 'loss'] + average: null # Average the metric over the dataset. Options: ['macro', 'micro']. Works only for 'F1', 'accuracy' etc. Refer to torchmetrics for metrics where this is supported. + num_classes: null + + optim: + name: fused_adam + lr: 1e-4 + weight_decay: 0.01 + betas: + - 0.9 + - 0.98 + sched: + name: CosineAnnealing + warmup_steps: 50 + min_lr: 0.0 # min_lr must be 0.0 for prompt learning when pipeline parallel > 1 + constant_steps: 0 # Constant steps should also be 0 when min_lr=0 + monitor: val_loss + reduce_on_plateau: false diff --git a/examples/multimodal/speech_llm/conf/salm/modular_audio_t5_config.yaml b/examples/multimodal/speech_llm/conf/salm/modular_audio_t5_config.yaml new file mode 100644 index 0000000000000..a76de9e312e22 --- /dev/null +++ b/examples/multimodal/speech_llm/conf/salm/modular_audio_t5_config.yaml @@ -0,0 +1,334 @@ +# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +name: megatron_audio_t5_salm_lhotse + +trainer: + devices: 1 + accelerator: gpu + num_nodes: 1 + precision: bf16 + logger: False # logger provided by exp_manager + enable_checkpointing: False + use_distributed_sampler: False + max_epochs: 9999 + max_steps: 1000000 # consumed_samples = global_step * micro_batch_size * data_parallel_size * accumulate_grad_batches + limit_train_batches : 1000 + log_every_n_steps: 10 # frequency with which training steps are logged + val_check_interval: 1.0 # If is an int n > 1, will run val every n training steps, if a float 0.0 - 1.0 will run val every epoch fraction, e.g. 0.25 will run val every quarter epoch + gradient_clip_val: 1.0 + accumulate_grad_batches: 1 + +model_target: nemo.collections.multimodal.speech_llm.models.modular_t5_models.ModularizedAudioT5Model +exp_manager: + # explicit_log_dir: null + exp_dir: null + name: ${name} + create_wandb_logger: False + wandb_logger_kwargs: + project: null + name: null + resume_if_exists: True + resume_ignore_no_checkpoint: True + create_checkpoint_callback: True + checkpoint_callback_params: + monitor: validation_${model.data.validation_ds.metric.name} + save_top_k: 1 + mode: min + save_nemo_on_train_end: True + filename: '${name}--{${exp_manager.checkpoint_callback_params.monitor}:.3f}-{step}-{epoch}' + model_parallel_size: ${model.tensor_model_parallel_size} + always_save_nemo: False + save_best_model: True + create_early_stopping_callback: False + early_stopping_callback_params: + monitor: "val_loss" + mode: "min" + min_delta: 0.001 + patience: 10 + verbose: True + strict: False # Should be False to avoid a runtime error where EarlyStopping says monitor is unavailable, which sometimes happens with resumed training. + + +model: + virtual_prompt_style: 'no-prompts' # make cls happy + seed: 1234 + tensor_model_parallel_size: 1 # intra-layer model parallelism + pipeline_model_parallel_size: 1 # inter-layer model parallelism + + pretrained_audio_model: stt_en_fastconformer_transducer_large + freeze_llm: True + freeze_audio_encoder: False + freeze_modality_adapter: False + load_audio_encoder: True + + global_batch_size: 128 + micro_batch_size: 4 + language_model_path: ??? # Path to an existing .nemo model you wish to add new tasks to or run inference with + resume_from_checkpoint: null # The path to a checkpoint file to continue the training, restores the whole state including the epoch, step, LR schedulers, apex, etc. + save_nemo_on_validation_end: False # Saves an inference ready .nemo file every time a checkpoint is saved during training. + sync_batch_comm: False + megatron_amp_O2: False + + ## Sequence Parallelism + # Makes tensor parallelism more memory efficient for LLMs (20B+) by parallelizing layer norms and dropout sequentially + # See Reducing Activation Recomputation in Large Transformer Models: https://arxiv.org/abs/2205.05198 for more details. + sequence_parallel: False + + ## Activation Checkpoint + activations_checkpoint_granularity: null # 'selective' or 'full' + activations_checkpoint_method: null # 'uniform', 'block', not used with 'selective' + # 'uniform' divides the total number of transformer layers and checkpoints the input activation + # of each chunk at the specified granularity + # 'block' checkpoints the specified number of layers per pipeline stage at the specified granularity + activations_checkpoint_num_layers: null # not used with 'selective' + activations_checkpoint_layers_per_pipeline: null + answer_only_loss: True + gradient_as_bucket_view: False + + hidden_dropout: 0.0 + attention_dropout: 0.0 + ffn_dropout: 0.0 + + # use_am_tokenizer: True + # override_vocab_size: 1024 + + lora_tuning: + kqv_adapter_dim: 128 + kv_adapter_dim: 64 + q_adapter_dim: 32 + adapter_dropout: 0.0 + column_init_method: 'xavier' # IGNORED if linear_adapter is used, options: xavier, zero or normal + row_init_method: 'zero' # IGNORED if linear_adapter is used, options: xavier, zero or normal + + peft: + peft_scheme: "adapter" # can be either adapter,ia3, or ptuning + restore_from_path: null + + # Used for adapter peft training + adapter_tuning: + type: 'parallel_adapter' # this should be either 'parallel_adapter' or 'linear_adapter' + adapter_dim: 32 + adapter_dropout: 0.0 + norm_position: 'pre' # This can be set to 'pre' or 'post', 'pre' is normally what is used. + column_init_method: 'xavier' # IGNORED if linear_adapter is used, options: xavier, zero or normal + row_init_method: 'zero' # IGNORED if linear_adapter is used, options: xavier, zero or normal + norm_type: 'mixedfusedlayernorm' # IGNORED if layer_adapter is used, options are ['layernorm', 'mixedfusedlayernorm'] + + # Used for p-tuning peft training + p_tuning: + virtual_tokens: 10 # The number of virtual tokens the prompt encoder should add at the start of the sequence + bottleneck_dim: 1024 # the size of the prompt encoder mlp bottleneck + embedding_dim: 1024 # the size of the prompt encoder embeddings + init_std: 0.023 + + perception: + target: nemo.collections.multimodal.speech_llm.modules.perception_modules.AudioPerceptionModule + use_multi_layer_feat: false + + modality_adapter: + _target_: nemo.collections.asr.modules.ConformerEncoder + feat_in: 1024 + feat_out: -1 # you may set it if you need different output size other than the default d_model + n_layers: 2 + d_model: 512 + + # Sub-sampling parameters + subsampling: dw_striding # vggnet, striding, stacking or stacking_norm, dw_striding + subsampling_factor: 8 # must be power of 2 for striding and vggnet + subsampling_conv_channels: 256 # set to -1 to make it equal to the d_model + causal_downsampling: false + + # Reduction parameters: Can be used to add another subsampling layer at a given position. + # Having a 2x reduction will speedup the training and inference speech while keeping similar WER. + # Adding it at the end will give the best WER while adding it at the beginning will give the best speedup. + reduction: null # pooling, striding, or null + reduction_position: null # Encoder block index or -1 for subsampling at the end of encoder + reduction_factor: 1 + + # Feed forward module's params + ff_expansion_factor: 4 + + # Multi-headed Attention Module's params + self_attention_model: rel_pos # rel_pos or abs_pos + n_heads: 8 # may need to be lower for smaller d_models + # [left, right] specifies the number of steps to be seen from left and right of each step in self-attention + att_context_size: [-1, -1] # -1 means unlimited context + att_context_style: regular # regular or chunked_limited + xscaling: true # scales up the input embeddings by sqrt(d_model) + untie_biases: true # unties the biases of the TransformerXL layers + pos_emb_max_len: 5000 + + # Convolution module's params + conv_kernel_size: 9 + conv_norm_type: 'batch_norm' # batch_norm or layer_norm or groupnormN (N specifies the number of groups) + # conv_context_size can be"causal" or a list of two integers while conv_context_size[0]+conv_context_size[1]+1==conv_kernel_size + # null means [(kernel_size-1)//2, (kernel_size-1)//2], and 'causal' means [(kernel_size-1), 0] + conv_context_size: null + + ### regularization + dropout: 0.1 # The dropout used in most of the Conformer Modules + dropout_pre_encoder: 0.1 # The dropout used before the encoder + dropout_emb: 0.0 # The dropout used for embeddings + dropout_att: 0.1 # The dropout for multi-headed attention modules + + # set to non-zero to enable stochastic depth + stochastic_depth_drop_prob: 0.0 + stochastic_depth_mode: linear # linear or uniform + stochastic_depth_start_layer: 1 + + spec_augment: + _target_: nemo.collections.asr.modules.SpectrogramAugmentation + freq_masks: 2 # set to zero to disable it + time_masks: 10 # set to zero to disable it + freq_width: 27 + time_width: 0.05 + + # the following are read from the pretrained AM: + # output_dim: null + # encoder: null + # preprocessor: null + + data: + train_ds: + # Example of how to specify paths to multiple datasets + # manifest_filepath: + # - /path/to/squad.jsonl + # - /path/to/mnli.jsonl + # - /path/to/boolq.jsonl + # Example of how each dataset is formatted + # {'audio_filepath': 'audio1.wav', 'offset': 0.0, 'duration': 12.3, 'question': 'transcribe this audio', 'answer': 'I have a dream...'} + # the 'answer' field can also be 'text', and a default 'question' field is added if missing in manigests, so as to work with ASR manifests + global_batch_size: ${model.global_batch_size} + micro_batch_size: ${model.micro_batch_size} + shuffle: True + num_workers: 0 + pin_memory: True + max_seq_length: 2048 + min_seq_length: 1 + drop_last: True + # Notably, the data weights are controlled by either bucketing_weights + # or concat_sampling_probabilities depending on the dataset type (tar and + # non-tar). + # See audio_text_qa_dataset.py for details. + concat_sampling_probabilities: null # When providing a list of datasets, this arg defines the sampling probabilities from each dataset when strategy='random' + context_key: 'context' + answer_key: 'answer' + add_eos: True + # add_eos: False + add_sep: True + add_bos: False + separate_prompt_and_response_with_newline: False + truncation_field: "context" # Options: ['context', 'answer'] + index_mapping_dir: null # Path to a directory to write index mapping files. + prompt_template: "Q: {context}\nA: {answer}" # fstring to use for assistant prompt. Example: "Q: {input}\nA: {output}" + # ASR configs + sample_rate: 16000 #${model.audio_encoder.preprocessor.sample_rate} + max_duration: 24 # it is set for LibriSpeech, you may need to update it for your dataset + min_duration: 0.1 + # tarred datasets + is_tarred: false + tarred_audio_filepaths: null + shuffle_n: 2048 + # bucketing params + bucketing_strategy: "fully_randomized" + bucketing_batch_size: null + # sample_alpha: 0.1 + use_lhotse: True + text_field : "text" + batch_duration : 80 # 0 + quadratic_duration : 30 + max_open_streams: 50 + num_buckets : 30 + buffer_size : 10000 + shuffle_buffer_size : 10000 + duration_bins: [2.92,3.474,3.924,4.335,4.728,5.11,5.487,5.872,6.288,6.696,7.128,7.62,8.208,8.934,9.883,10.56,11.22,11.88,12.51,13.05,13.59,14.13,14.64,15.17875,15.81,16.54,17.37,18.241,19.18] + # sample_alpha: 0.1 + + validation_ds: + global_batch_size: ${model.global_batch_size} + micro_batch_size: ${model.micro_batch_size} + shuffle: False + num_workers: 0 + pin_memory: True + max_seq_length: 2048 + min_seq_length: 1 + drop_last: False + context_key: ${model.data.train_ds.context_key} + answer_key: ${model.data.train_ds.answer_key} + add_eos: ${model.data.train_ds.add_eos} + add_sep: ${model.data.train_ds.add_sep} + add_bos: ${model.data.train_ds.add_bos} + separate_prompt_and_response_with_newline: ${model.data.train_ds.separate_prompt_and_response_with_newline} + write_predictions_to_file: False + output_file_path_prefix: null # Prefix of the file to write predictions to. + truncation_field: "context" # Options: ['context', 'answer'] + index_mapping_dir: null # Path to a directory to write index mapping files. + prompt_template: ${model.data.train_ds.prompt_template} # fstring to use for assistant prompt. Example: "Q: {input}\nA: {output}" + tokens_to_generate: 128 + # ASR configs + sample_rate: 16000 #${model.audio_encoder.preprocessor.sample_rate} + + log_every_n_steps: 1 + metric: + name: "wer" # Name of the evaluation metric to use. Options: ['exact_string_match', 'loss'] + average: null # Average the metric over the dataset. Options: ['macro', 'micro']. Works only for 'F1', 'accuracy' etc. Refer to torchmetrics for metrics where this is supported. + num_classes: null + + # make model init happy + num_workers: 0 + # test_ds: + # manifest_filepath: null # Path to a list of JSONL files corresponding to the source data. Data format is identical to train_ds. + # names: null # Names of the corresponding datasets used to log metrics. + # global_batch_size: ${model.global_batch_size} + # micro_batch_size: ${model.micro_batch_size} + # shuffle: False + # num_workers: 4 + # pin_memory: True + # max_seq_length: 2048 + # min_seq_length: 1 + # drop_last: False + # context_key: 'input' + # label_key: 'output' + # add_eos: ${model.data.train_ds.add_eos} + # add_sep: ${model.data.train_ds.add_sep} + # add_bos: ${model.data.train_ds.add_bos} + # separate_prompt_and_response_with_newline: ${model.data.train_ds.separate_prompt_and_response_with_newline} + # write_predictions_to_file: False + # output_file_path_prefix: null # Prefix of the file to write predictions to. + # truncation_field: "context" # Options: ['context', 'answer'] + # index_mapping_dir: null # Path to a directory to write index mapping files. + # prompt_template: ${model.data.train_ds.prompt_template} + # # ASR configs + # sample_rate: 16000 #${model.audio_encoder.preprocessor.sample_rate} + + # metric: + # name: "loss" # Name of the evaluation metric to use. Options: ['exact_string_match', 'loss'] + # average: null # Average the metric over the dataset. Options: ['macro', 'micro']. Works only for 'F1', 'accuracy' etc. Refer to torchmetrics for metrics where this is supported. + # num_classes: null + + optim: + name: fused_adam + lr: 1e-4 + weight_decay: 0.01 + betas: + - 0.9 + - 0.98 + sched: + name: CosineAnnealing + warmup_steps: 50 + min_lr: 0.0 # min_lr must be 0.0 for prompt learning when pipeline parallel > 1 + constant_steps: 0 # Constant steps should also be 0 when min_lr=0 + monitor: val_loss + reduce_on_plateau: false diff --git a/examples/multimodal/speech_llm/modular_audio_gpt_train.py b/examples/multimodal/speech_llm/modular_audio_gpt_train.py index 04bff37e7a3f7..ad8aacef2af2b 100644 --- a/examples/multimodal/speech_llm/modular_audio_gpt_train.py +++ b/examples/multimodal/speech_llm/modular_audio_gpt_train.py @@ -18,7 +18,7 @@ from nemo.collections.multimodal.speech_llm.models.modular_models import ModularAudioGPTModel from nemo.collections.nlp.parts.megatron_trainer_builder import MegatronLMPPTrainerBuilder from nemo.core.config import hydra_runner -from nemo.utils import logging +from nemo.utils import logging, model_utils from nemo.utils.exp_manager import exp_manager mp.set_start_method("spawn", force=True) @@ -61,7 +61,11 @@ def main(cfg) -> None: # update resume from checkpoint found by exp_manager logging.info(f'Resuming training from checkpoint: {trainer.ckpt_path}') - model = ModularAudioGPTModel.restore_from_pretrained_models(cfg, trainer=trainer) + if hasattr(cfg, 'model_target'): + imported_cls = model_utils.import_class_by_path(cfg.model_target) + else: + imported_cls = ModularAudioGPTModel + model = imported_cls.restore_from_pretrained_models(cfg, trainer=trainer) trainer.fit(model) diff --git a/nemo/collections/multimodal/speech_llm/data/audio_text_dataset.py b/nemo/collections/multimodal/speech_llm/data/audio_text_dataset.py index 7d0ee6afbfa2b..94d2cd50a240b 100644 --- a/nemo/collections/multimodal/speech_llm/data/audio_text_dataset.py +++ b/nemo/collections/multimodal/speech_llm/data/audio_text_dataset.py @@ -32,6 +32,8 @@ from nemo.collections.asr.parts.utils.audio_utils import ChannelSelectorType from nemo.collections.common.parts.preprocessing import collections from nemo.collections.multimodal.speech_llm.parts.utils.data_utils import ( + TextProcessing, + build_loss_mask, ceil_to_nearest, get_num_samples_from_files, maybe_cast_to_list, @@ -90,19 +92,6 @@ def _audio_collate_fn(audio_signals, audio_lengths): return audio_signals_padded, audio_lengths -def _build_loss_mask(processed_example: Dict, answer_only_loss: bool = True): - """Pad input_ids in batch to max batch length while building loss mask""" - # function copied from nemo/collections/nlp/data/language_modelling/megatron/gpt_sft_dataset.py - input_ids = processed_example['input_ids'] - answer_start_idx = processed_example['answer_start_idx'] - if answer_only_loss: - loss_mask = [float(idx >= answer_start_idx) for idx in range(len(input_ids))] - else: - loss_mask = [1.0] * len(input_ids) - - return loss_mask - - def _collate_item(item: Union[torch.Tensor, np.ndarray, List], max_length: int, pad_id: int = 0): # function copied from nemo/collections/nlp/data/language_modelling/megatron/gpt_sft_dataset.py item = maybe_cast_to_list(item) @@ -132,7 +121,7 @@ def _speechllm_audio_text_collate_fn( context_lengths = torch.LongTensor([item['context_length'] for item in batch]) answers = [item['answer_ids'] for item in batch] - loss_mask = [_build_loss_mask(item)[1:] for item in batch] + loss_mask = [build_loss_mask(item)[1:] for item in batch] max_length = max([len(x) for x in input_ids]) + tokens_to_generate # increase max length to nearest multiple of 4 or 8 @@ -205,197 +194,6 @@ def _speechllm_multi_audio_text_collate_fn( return batch -class TextProcessing(object): - """ - Text processing pipeline for AudioTextDataset and TarredAudioTextDataset. - This class is adapted from the one used in nemo/collections/nlp/data/language_modeling/megatron/gpt_sft_dataset.py - """ - - def __init__( - self, - tokenizer: 'nemo.collections.common.tokenizers.TokenizerSpec', - max_seq_length: int = 1024, - min_seq_length: int = 1, - add_bos: bool = False, - add_eos: bool = True, - add_sep: bool = False, - sep_id: Optional[int] = None, - seed: int = 1234, - separate_prompt_and_response_with_newline: bool = False, - answer_only_loss: bool = True, - truncation_field: str = "answer", - pad_to_max_length: bool = False, # (@adithyare) allows for much faster training especially in PEFT settings. - prompt_template: str = None, - virtual_tokens: int = 0, - tokens_to_generate: int = 0, - context_key: str = 'context', - answer_key: str = 'answer', - end_string: Optional[str] = None, - sample_alpha: Optional[float] = None, - audio_locator: Optional[str] = None, - ): - self.context_key = context_key - self.answer_key = answer_key - self.tokenizer = tokenizer - self.max_seq_length = max_seq_length - self.min_seq_length = min_seq_length - self.seed = seed - self.separate_prompt_and_response_with_newline = separate_prompt_and_response_with_newline - self.answer_only_loss = answer_only_loss - self.truncation_field = truncation_field - self.pad_to_max_length = pad_to_max_length - self.prompt_template = prompt_template - self.virtual_tokens = virtual_tokens - self.tokens_to_generate = tokens_to_generate - self.add_bos = add_bos - self.add_eos = add_eos - self.add_sep = add_sep - self.end_string = end_string - self.sample_alpha = sample_alpha - self.audio_locator = audio_locator - - if add_bos and hasattr(tokenizer, "bos_id") and tokenizer.bos_id > 0: - self.bos_id = tokenizer.bos_id - else: - self.bos_id = None - - if add_eos and hasattr(tokenizer, "eos_id") and tokenizer.eos_id > 0: - self.eos_id = tokenizer.eos_id - else: - self.eos_id = None - - if hasattr(tokenizer, "pad_id") and tokenizer.pad_id > 0: - self.pad_id = tokenizer.pad_id - else: - self.pad_id = self.eos_id if self.eos_id is not None else 0 - - self.sep_id = sep_id if add_sep else None - - if self.prompt_template is not None: - # When providing things like newlines in the prompt template via the CLI, they are escaped. This line unescapes them. - self.prompt_template = self.prompt_template.encode('utf-8').decode('unicode_escape') - assert self.truncation_field in ["answer", "context"] - - def _process_example(self, context: str, output: str): - """ - Create an example by concatenating text and answer. - Truncation is carried out when needed, but it is performed only on the prompt side. - BOS, EOS, and SEP, are added if specified. - - function copied from nemo/collections/nlp/data/language_modelling/megatron/gpt_sft_dataset.py - """ - if self.prompt_template is not None: - if self.context_key not in self.prompt_template or self.answer_key not in self.prompt_template: - if "input" in self.prompt_template and "output" in self.prompt_template: - logging.warning( - f"Using 'input' and 'output' as context and answer keys, since given ones ({self.context_key}, {self.answer_key}) are not found in the prompt template: {self.prompt_template}.", - mode=logging_mode.ONCE, - ) - self.context_key = "input" - self.answer_key = "output" - assert f'{{{self.context_key}}}' in self.prompt_template - assert f'{{{self.answer_key}}}' in self.prompt_template - # Make sure that '{output}' always occurs at the end of the prompt template string - assert self.prompt_template.index(f'{{{self.answer_key}}}') == len(self.prompt_template) - len( - f'{{{self.answer_key}}}' - ) - # Get the context by replacing only the input - original_context = context - context = ( - self.prompt_template.replace(f'{{{self.context_key}}}', context) - .replace(f'{{{self.answer_key}}}', '') - .strip(' ') - ) - # Replace the input and output placeholders with the actual input and output - text = self.prompt_template.replace(f'{{{self.context_key}}}', original_context).replace( - f'{{{self.answer_key}}}', output - ) - - elif self.separate_prompt_and_response_with_newline: - text = context + '\n' + output - else: - text = context + ' ' + output - - if self.virtual_tokens: - # (@adithyare) we are going to insert "pad/eos" tokens in the beginning of the text and context - # these pad/eos tokens are placeholders for virtual tokens - pre_pad = [self.tokenizer.eos_id] * self.virtual_tokens - else: - pre_pad = [] - answer_text = text[len(context) :] - answer_ids = pre_pad + self.tokenizer.text_to_ids(answer_text, self.sample_alpha) - if self.end_string: - answer_ids += self.tokenizer.text_to_ids(self.end_string) - - if self.audio_locator is None: - # signle audio case - context_ids = self.tokenizer.text_to_ids(context) - context_start_idx = [0] - else: - # multiple audio case - context_ids = [] - context_start_idx = [] - for context_seg in context.split(self.audio_locator): - context_start_idx.append(len(context_ids)) - context_ids.extend(self.tokenizer.text_to_ids(context_seg)) - context_ids = pre_pad + context_ids - context_start_idx = [x + len(pre_pad) for x in context_start_idx] - - # for the long context cases, collate_fn includes self.tokens_to_generate for padding - total_ids = len(context_ids) + max(len(answer_ids), self.tokens_to_generate) - if self.add_bos: - total_ids += 1 - if self.add_sep: - total_ids += 1 - # Only training need to consider eos token - if self.add_eos and self.tokens_to_generate == 0: - total_ids += 1 - - # If the total number of token is greater than the max, we will try to truncate the answer - if total_ids > self.max_seq_length: - truncation_length = total_ids - self.max_seq_length - if self.truncation_field == "answer": - answer_ids = answer_ids[: -min(truncation_length, len(answer_ids))] - elif self.truncation_field == "context": - context_ids = context_ids[: -min(truncation_length, len(context_ids))] - - input_ids = context_ids - answer_start_idx = len(input_ids) - - # Adds bos token in the start - if self.add_bos: - context_ids = [self.tokenizer.bos_id] + context_ids - input_ids = [self.tokenizer.bos_id] + input_ids - answer_start_idx += 1 - - # Adds sep token between text/prompt and answer - if self.add_sep: - context_ids = context_ids + [self.sep_id] - input_ids = input_ids + [self.sep_id] - answer_start_idx += 1 - - input_ids = input_ids + answer_ids - - # Only training need to consider eos token - if self.add_eos and self.tokens_to_generate == 0: - input_ids = input_ids + [self.tokenizer.eos_id] - - if len(input_ids) > self.max_seq_length: - logging.warning(f'Input ids length {len(input_ids)} exceed max sequence length {self.max_seq_length}') - input_ids = input_ids[: self.max_seq_length] - - processed_example = { - 'input_ids': input_ids, - 'answer_start_idx': answer_start_idx, - 'context_ids': context_ids, - 'context_length': len(context_ids), - 'answer_ids': answer_ids, - 'context_start_idx': context_start_idx, - } - - return processed_example - - class AudioTextDataset(TextProcessing, Dataset): """ Dataset that loads tensors via a json file containing paths to audio files, transcripts, and durations (in seconds). diff --git a/nemo/collections/multimodal/speech_llm/data/build_dataset.py b/nemo/collections/multimodal/speech_llm/data/build_dataset.py new file mode 100644 index 0000000000000..b042386cea3b3 --- /dev/null +++ b/nemo/collections/multimodal/speech_llm/data/build_dataset.py @@ -0,0 +1,229 @@ +# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import copy +from pathlib import Path + +import torch +from megatron.core import parallel_state +from omegaconf.omegaconf import OmegaConf + +from nemo.collections.asr.parts.preprocessing.perturb import process_augmentations +from nemo.collections.common.data.lhotse import get_lhotse_dataloader_from_config +from nemo.collections.multimodal.speech_llm.data.audio_text_dataset import ( + get_audio_text_dataset_from_config, + get_tarred_audio_text_dataset_from_config, +) +from nemo.collections.multimodal.speech_llm.data.lhotse_dataset import LhotseAudioQuestionAnswerDataset +from nemo.collections.multimodal.speech_llm.parts.utils.data_utils import TextProcessing +from nemo.collections.nlp.data.language_modeling.megatron.blendable_dataset import BlendableDataset +from nemo.collections.nlp.data.language_modeling.megatron.megatron_batch_samplers import ( + MegatronPretrainingBatchSampler, +) +from nemo.utils import logging + + +def build_speechllm_dataset(model_instance, data_cfg, is_train): + if 'augmentor' in data_cfg: + augmentor = process_augmentations( + data_cfg['augmentor'], global_rank=model_instance.global_rank, world_size=model_instance.world_size + ) + else: + augmentor = None + + # Check dataset max_seq_legnth and max_position_embeddings size + if ( + model_instance.cfg.get('position_embedding_type', None) in [None, 'learned_absolute'] + and data_cfg.max_seq_length > model_instance.cfg.max_position_embeddings + ): + logging.warning( + f"Set dataset max_seq_length to max_position_embeddings {model_instance.cfg.max_position_embeddings} if using learned_absolute position embedding" + ) + data_cfg.max_seq_length = model_instance.cfg.max_position_embeddings + + # Notably, the data weights are controlled by either bucketing_weights + # or concat_sampling_probabilities depending on the dataset type. + if data_cfg.get("use_lhotse"): + tp = TextProcessing( + model_instance.tokenizer, + max_seq_length=data_cfg["max_seq_length"], + min_seq_length=data_cfg["min_seq_length"], + add_bos=data_cfg.get('add_bos', False), + add_eos=data_cfg.get('add_eos', False), + add_sep=data_cfg.get('add_sep', False), + sep_id=model_instance.sep_id, + seed=data_cfg.get('seed', 1234), + separate_prompt_and_response_with_newline=data_cfg.get('separate_prompt_and_response_with_newline', True), + answer_only_loss=model_instance.cfg.get('answer_only_loss', True), + truncation_field=data_cfg.get('truncation_field', 'context'), + pad_to_max_length=data_cfg.get('pad_to_max_length', False), + prompt_template=data_cfg.get('prompt_template', None), + virtual_tokens=model_instance.virtual_tokens, + tokens_to_generate=data_cfg.get( + 'tokens_to_generate', 0 + ), # used at inference time to allocate tensor positions for tokens that will be generated by inf procedure. + context_key=data_cfg.get('context_key', 'context'), + answer_key=data_cfg.get('answer_key', 'answer'), + end_string=data_cfg.get('end_string', None), + sample_alpha=data_cfg.get('sample_alpha', None), + ) + return LhotseAudioQuestionAnswerDataset( + tp, + default_context="answer the question according to the previous audio", + tokens_to_generate=data_cfg.get('tokens_to_generate', 0), + pad_to_max_length=data_cfg.get('pad_to_max_length', False), + max_seq_length=data_cfg["max_seq_length"], + context_key=data_cfg.get('context_key', "context"), + default_context_key=data_cfg.get('default_context_key', "default_context"), + ) + + # Notably, the data weights are controlled by either bucketing_weights + # or concat_sampling_probabilities depending on the dataset type. + if data_cfg.get('is_tarred', False): + return get_tarred_audio_text_dataset_from_config( + config=data_cfg, + tokenizer=model_instance.tokenizer, + augmentor=augmentor, + sep_id=model_instance.sep_id, + answer_only_loss=model_instance.cfg.get('answer_only_loss', True), + virtual_tokens=model_instance.virtual_tokens, + global_rank=parallel_state.get_data_parallel_rank(), + world_size=parallel_state.get_data_parallel_world_size(), + ) + else: + return get_audio_text_dataset_from_config( + manifest_filepath=data_cfg.manifest_filepath, + config=data_cfg, + tokenizer=model_instance.tokenizer, + augmentor=augmentor, + is_train=is_train, + sep_id=model_instance.sep_id, + answer_only_loss=model_instance.cfg.get('answer_only_loss', True), + virtual_tokens=model_instance.virtual_tokens, + ) + + +def build_speechllm_dataloader(dataset, data_cfg, consumed_samples=0, is_predict=False, is_eval=False): + """Buld dataloader given an input dataset.""" + if data_cfg.get("use_lhotse"): + if is_eval == False and is_predict == False: + return get_lhotse_dataloader_from_config( + data_cfg, + global_rank=parallel_state.get_data_parallel_rank(), + world_size=parallel_state.get_data_parallel_world_size(), + dataset=dataset, + ) + # for eval, we need to create separate dataset so as to report splitted numbers + else: + dls = [] + if hasattr(data_cfg, 'manifest_filepath'): + manifest_filepath = data_cfg.manifest_filepath + for cur_manifest_filepath in manifest_filepath: + conf = copy.deepcopy(data_cfg) + conf['manifest_filepath'] = cur_manifest_filepath + dls.append( + get_lhotse_dataloader_from_config( + conf, + global_rank=parallel_state.get_data_parallel_rank(), + world_size=parallel_state.get_data_parallel_world_size(), + dataset=dataset, + ) + ) + else: + input_cfg = data_cfg.input_cfg + if isinstance(input_cfg, (str, Path)): + # Resolve /path/to/input_cfg.yaml into config contents if needed. + input_cfg = OmegaConf.load(input_cfg) + assert len(input_cfg) == 1, "Only one dataset with multiple manifest paths is supported for eval" + data_cfg.input_cfg = input_cfg + # for getting names + manifest_filepath = [ic.manifest_filepath for ic in input_cfg[0].input_cfg] + for cur_input_cfg in input_cfg[0].input_cfg: + conf = copy.deepcopy(data_cfg) + conf.input_cfg[0].input_cfg = [cur_input_cfg] + dls.append( + get_lhotse_dataloader_from_config( + conf, + global_rank=parallel_state.get_data_parallel_rank(), + world_size=parallel_state.get_data_parallel_world_size(), + dataset=dataset, + ) + ) + + if 'names' not in data_cfg: + names = [] + for cur_manifest_filepath in manifest_filepath: + names.append(Path(cur_manifest_filepath).stem) + OmegaConf.update(data_cfg, 'names', names, force_add=True) + logging.info(f'Update dataset names as {names}') + return dls + + logging.info(f'Building dataloader with consumed samples: {consumed_samples}') + if isinstance(dataset, BlendableDataset): + collate_fn = dataset.datasets[0].collate_fn + elif hasattr(dataset, 'collate_fn'): + collate_fn = dataset.collate_fn + elif hasattr(dataset.datasets[0], 'collate_fn'): + # support datasets that are lists of entries + collate_fn = dataset.datasets[0].collate_fn + else: + # support datasets that are lists of lists + collate_fn = dataset.datasets[0].datasets[0].collate_fn + + if isinstance(dataset, torch.utils.data.IterableDataset): + data_parallel_size = parallel_state.get_data_parallel_world_size() + num_micro_batches = data_cfg.global_batch_size // (data_cfg.micro_batch_size * data_parallel_size) + global_batch_size_on_this_data_parallel_rank = num_micro_batches * data_cfg.micro_batch_size + + dataloader = torch.utils.data.DataLoader( + dataset, + collate_fn=collate_fn, + shuffle=False, + batch_size=global_batch_size_on_this_data_parallel_rank, + drop_last=True, + num_workers=data_cfg.num_workers, + pin_memory=data_cfg.pin_memory, + ) + return dataloader + + if is_predict: + # MegatronPretrainingBatchSampler doesn't work with trainer.predict() + dataloader = torch.utils.data.DataLoader( + dataset, + collate_fn=collate_fn, + batch_size=data_cfg.micro_batch_size, + num_workers=data_cfg.num_workers, + pin_memory=data_cfg.pin_memory, + ) + return dataloader + + batch_sampler = MegatronPretrainingBatchSampler( + total_samples=len(dataset), + consumed_samples=consumed_samples, + micro_batch_size=data_cfg.micro_batch_size, + global_batch_size=data_cfg.global_batch_size, + data_parallel_rank=parallel_state.get_data_parallel_rank(), + data_parallel_size=parallel_state.get_data_parallel_world_size(), + drop_last=data_cfg.drop_last, + pad_samples_to_global_batch_size=not data_cfg.drop_last, + ) + + dataloader = torch.utils.data.DataLoader( + dataset, + batch_sampler=batch_sampler, + collate_fn=collate_fn, + num_workers=data_cfg.num_workers, + pin_memory=data_cfg.pin_memory, + persistent_workers=True if data_cfg.num_workers > 0 else False, + ) + return dataloader diff --git a/nemo/collections/multimodal/speech_llm/data/lhotse_dataset.py b/nemo/collections/multimodal/speech_llm/data/lhotse_dataset.py new file mode 100644 index 0000000000000..d3e70343d5071 --- /dev/null +++ b/nemo/collections/multimodal/speech_llm/data/lhotse_dataset.py @@ -0,0 +1,166 @@ +import torch.utils.data +from lhotse.dataset import AudioSamples +from lhotse.dataset.collation import collate_vectors as collate_vectors_lhotse + +from nemo.collections.multimodal.speech_llm.parts.utils.data_utils import ( + TextProcessing, + build_loss_mask, + ceil_to_nearest, +) + + +def collate_vectors(items, max_length: int, padding_value): + vectors = collate_vectors_lhotse(items, padding_value=padding_value) + if max_length > vectors.size(1): + vectors = torch.cat( + [vectors, padding_value * torch.ones(vectors.size(0), max_length - vectors.size(1), dtype=vectors.dtype)], + dim=1, + ) + if items[0].shape[0] < 1: + vectors = vectors.long() + return vectors + + +class LhotseAudioQuestionAnswerDataset(torch.utils.data.Dataset): + """ + This dataset is based on Lhotse ASR dataset from ``audio_to_text_lhotse.py`` + and ``TarredAudioQuestionAnswerDataset`` from ``audio_text_qa_dataset.py``. + + Unlike native NeMo datasets, Lhotse dataset defines only the mapping from + a CutSet (meta-data) to a mini-batch with PyTorch tensors. + Specifically, it performs tokenization, I/O, augmentation, and feature extraction (if any). + Managing data, sampling, de-duplication across workers/nodes etc. is all handled + by Lhotse samplers instead. + + Args: + text_processor: TextProcessing object + default_context: Default question to use if no question is provided + tokens_to_generate: Number of tokens to generate during inference + pad_to_max_length: Whether to pad the input to the max sequence length. If False, will pad to the max length of the current batch. + max_seq_length: Maximum sequence length for each dataset examples. Examples will either be truncated to fit this length or dropped if they cannot be truncated. + context_key: Key to use for the context in your JSONL file + default_context_key: Key to use for the default context in lhotse yaml + """ + + def __init__( + self, + text_processor: TextProcessing, + default_context: str, + tokens_to_generate: int, + pad_to_max_length: bool, + max_seq_length: int, + context_key: str = "context", + default_context_key: str = "default_context", + ): + super().__init__() + self.text_processor = text_processor + self.load_audio = AudioSamples(fault_tolerant=True) + self.tokens_to_generate = tokens_to_generate + self.pad_to_max_length = pad_to_max_length + self.max_seq_length = max_seq_length + + self.default_context = default_context + self.context_key = context_key + self.default_context_key = default_context_key + + def __getitem__(self, cuts) -> dict[str, torch.Tensor | list[str] | dict]: + cuts = cuts.sort_by_duration() + + audio, audio_lens, cuts = self.load_audio(cuts) + + return_batch = {} + audio_ratio = [] + for id, cut in enumerate(cuts): + audio_ratio.append(1.0) + + for _, cut in enumerate(cuts): + if hasattr(cut, self.context_key): + cut.context = getattr(cut, self.context_key) + elif hasattr(cut, self.default_context_key): + cut.context = getattr(cut, self.default_context_key) + else: + cut.context = self.default_context + + metadata = [] + for id, cut in enumerate(cuts): + metadata.append({'audio_filepath': cut.id + '.wav'}) + + collated_text_data = collate_text_data( + cuts=cuts, + default_context=self.default_context, + text_processor=self.text_processor, + tokens_to_generate=self.tokens_to_generate, + pad_to_max_length=self.pad_to_max_length, + max_seq_length=self.max_seq_length, + ) + return_batch.update( + { + "sample_ids": list(cuts.ids), + "audio_signal": audio, + "audio_signal_length": audio_lens, + "audio_ratio": torch.FloatTensor(audio_ratio), + "metadata": metadata, + **collated_text_data, + } + ) + + return return_batch + + +def collate_text_data( + cuts, + default_context: str, + text_processor: TextProcessing, + tokens_to_generate: int, + pad_to_max_length: bool, + max_seq_length: int, +) -> dict: + """Perform text collation equivalent to nemo/collections/multimodal/data/audio_text_qa_dataset.py:121""" + batch_size = len(cuts) + pad_id = text_processor.pad_id + examples = [ + { + k: torch.as_tensor(v) + for k, v in text_processor._process_example( + context=cut.context, + output=cut.supervisions[0].text, + ).items() + } + for cut in cuts + ] + fields = as_dict(examples) + + def get_max_len(input_list): + return max([len(x) for x in input_list]) + + max_length = tokens_to_generate + max( + get_max_len(fields["input_ids"]), get_max_len(fields["context_ids"]), get_max_len(fields["answer_ids"]) + ) + # increase max length to nearest multiple of 4 or 8 + if pad_to_max_length: + max_length = max_seq_length + else: + max_length = min(max_seq_length, ceil_to_nearest(max_length, 8)) + + all_tokens = collate_vectors(fields["input_ids"], max_length=max_length, padding_value=pad_id) + full_lengths = torch.LongTensor([len(item) for item in fields["input_ids"]]) + + assert max_length <= max_seq_length, f"{max_length=} <= {max_seq_length=}" + + return { + "tokens": all_tokens[:, :-1], + "tokens_length": full_lengths - 1, + "labels": all_tokens[:, 1:], + "loss_mask": collate_vectors( + [torch.as_tensor(build_loss_mask(item)) for item in examples], max_length=max_length, padding_value=0 + )[:, 1:], + "position_ids": torch.arange(max_length, dtype=torch.long).repeat(batch_size, 1), + "contexts": collate_vectors(fields["context_ids"], max_length=max_length, padding_value=pad_id), + "context_lengths": torch.LongTensor([len(seq) for seq in fields["context_ids"]]), + "answers": collate_vectors(fields["answer_ids"], max_length=max_length, padding_value=pad_id), + "max_length": torch.LongTensor([max_length] * batch_size), + } + + +def as_dict(arg: list[dict]) -> dict[str, list]: + return {k: [item[k] for item in arg] for k in arg[0].keys()} diff --git a/nemo/collections/multimodal/speech_llm/models/modular_models.py b/nemo/collections/multimodal/speech_llm/models/modular_models.py index 39bc37c33e56e..cce74e7b6a1df 100644 --- a/nemo/collections/multimodal/speech_llm/models/modular_models.py +++ b/nemo/collections/multimodal/speech_llm/models/modular_models.py @@ -29,12 +29,11 @@ from nemo.collections.asr.models import ASRModel, EncDecSpeakerLabelModel from nemo.collections.asr.parts.mixins.transcription import move_to_device -from nemo.collections.asr.parts.preprocessing.perturb import process_augmentations from nemo.collections.asr.parts.utils.eval_utils import remove_punctuations from nemo.collections.common.metrics import MetricStringToTorchMetric, TextMetricsSet -from nemo.collections.multimodal.speech_llm.data.audio_text_dataset import ( - get_audio_text_dataset_from_config, - get_tarred_audio_text_dataset_from_config, +from nemo.collections.multimodal.speech_llm.data.build_dataset import ( + build_speechllm_dataloader, + build_speechllm_dataset, ) from nemo.collections.multimodal.speech_llm.modules.common.audio_text_generation_utils import generate from nemo.collections.multimodal.speech_llm.modules.perception_modules import ( @@ -43,10 +42,6 @@ ) from nemo.collections.multimodal.speech_llm.parts.mixins.adapter_mixin import SpeechLLMAdapterMixin from nemo.collections.multimodal.speech_llm.parts.utils.data_utils import get_nested_dict_value -from nemo.collections.nlp.data.language_modeling.megatron.blendable_dataset import BlendableDataset -from nemo.collections.nlp.data.language_modeling.megatron.megatron_batch_samplers import ( - MegatronPretrainingBatchSampler, -) from nemo.collections.nlp.models.language_modeling.megatron_gpt_model import MegatronGPTModel from nemo.collections.nlp.models.language_modeling.megatron_gpt_sft_model import MegatronGPTSFTModel from nemo.collections.nlp.modules.common.megatron.utils import ( @@ -59,7 +54,7 @@ from nemo.core.classes import ModelPT from nemo.core.classes.common import PretrainedModelInfo from nemo.core.classes.mixins import adapter_mixins -from nemo.utils import AppState, logging +from nemo.utils import AppState, logging, model_utils from nemo.utils.model_utils import inject_model_parallel_rank try: @@ -88,15 +83,24 @@ class ModularAudioGPTModel(SpeechLLMAdapterMixin, MegatronGPTSFTModel): """Modularized speech GPT model.""" + def setup_perception_modules(self, cfg): + if 'target' in cfg.perception: + imported_cls = model_utils.import_class_by_path(cfg.perception.target) + self.perception = imported_cls(cfg=cfg.perception) + else: + self.perception = ( + AudioPerceptionModule(cfg=cfg.perception) + if "encoders" not in cfg.perception + else MultiAudioPerceptionModule(cfg=cfg.perception) + ) + def __init__(self, cfg: DictConfig, trainer: Trainer): self.cfg = cfg super().__init__(cfg, trainer) + # handle the case where the batch size from dynamic bucketting is not divisible in lhotse + self.enforce_divisible_batch = False + self.setup_perception_modules(cfg) - self.perception = ( - AudioPerceptionModule(cfg=cfg.perception) - if "encoders" not in cfg.perception - else MultiAudioPerceptionModule(cfg=cfg.perception) - ) # print out params in more details self.summarize(max_depth=2) @@ -121,11 +125,14 @@ def setup_optimizer_param_groups(self): Override parent method to setup optimizer groups for training/freezing different parts of the model. """ known_groups = [] - if self.cfg.get('freeze_llm', True): - for param in self.model.parameters(): - param.requires_grad = False + self.unfreeze() + freeze_llm = self.cfg.get('freeze_llm', True) + if freeze_llm: known_groups.append('model.') + for param in self.model.parameters(): + param.requires_grad = not freeze_llm + if self.cfg.get('freeze_audio_encoder', False): # freeze speaker model if there is any if self.cfg.perception.get("speaker_model", None) is not None: @@ -362,6 +369,15 @@ def forward( """ Forward pass of the model. We prepend audio embeddings to the instruction and label text tokens as the LLM input. """ + if 'audio_ratio' in audio_batch: + self.log( + 'local_batch_size', + audio_batch['audio_ratio'].shape[0], + prog_bar=True, + batch_size=1, + rank_zero_only=False, + ) + encoder_input, attention_mask, labels, loss_mask, _ = self.prepare_llm_input(audio_batch) if self.mcore_gpt: output = self.model( @@ -523,109 +539,10 @@ def loss_func(output_tensor): return fwd_output_and_loss_func def _build_dataset(self, data_cfg, is_train=True): - if 'augmentor' in data_cfg: - augmentor = process_augmentations( - data_cfg['augmentor'], global_rank=self.global_rank, world_size=self.world_size - ) - else: - augmentor = None + return build_speechllm_dataset(self, data_cfg, is_train) - # Check dataset max_seq_legnth and max_position_embeddings size - if ( - self.cfg.get('position_embedding_type', None) in [None, 'learned_absolute'] - and data_cfg.max_seq_length > self.cfg.max_position_embeddings - ): - logging.warning( - f"Set dataset max_seq_length to max_position_embeddings {self.cfg.max_position_embeddings} if using learned_absolute position embedding" - ) - data_cfg.max_seq_length = self.cfg.max_position_embeddings - - # Notably, the data weights are controlled by either bucketing_weights - # or concat_sampling_probabilities depending on the dataset type. - if data_cfg.get('is_tarred', False): - return get_tarred_audio_text_dataset_from_config( - config=data_cfg, - tokenizer=self.tokenizer, - augmentor=augmentor, - sep_id=self.sep_id, - answer_only_loss=self.cfg.get('answer_only_loss', True), - virtual_tokens=self.virtual_tokens, - global_rank=parallel_state.get_data_parallel_rank(), - world_size=parallel_state.get_data_parallel_world_size(), - ) - else: - return get_audio_text_dataset_from_config( - manifest_filepath=data_cfg.manifest_filepath, - config=data_cfg, - tokenizer=self.tokenizer, - augmentor=augmentor, - is_train=is_train, - sep_id=self.sep_id, - answer_only_loss=self.cfg.get('answer_only_loss', True), - virtual_tokens=self.virtual_tokens, - ) - - def build_data_loader(self, dataset, data_cfg, consumed_samples=0, is_predict=False): - """Buld dataloader given an input dataset.""" - logging.info(f'Building dataloader with consumed samples: {consumed_samples}') - if isinstance(dataset, BlendableDataset): - collate_fn = dataset.datasets[0].collate_fn - elif hasattr(dataset, 'collate_fn'): - collate_fn = dataset.collate_fn - elif hasattr(dataset.datasets[0], 'collate_fn'): - # support datasets that are lists of entries - collate_fn = dataset.datasets[0].collate_fn - else: - # support datasets that are lists of lists - collate_fn = dataset.datasets[0].datasets[0].collate_fn - - if isinstance(dataset, torch.utils.data.IterableDataset): - data_parallel_size = parallel_state.get_data_parallel_world_size() - num_micro_batches = data_cfg.global_batch_size // (data_cfg.micro_batch_size * data_parallel_size) - global_batch_size_on_this_data_parallel_rank = num_micro_batches * data_cfg.micro_batch_size - - dataloader = torch.utils.data.DataLoader( - dataset, - collate_fn=collate_fn, - shuffle=False, - batch_size=global_batch_size_on_this_data_parallel_rank, - drop_last=True, - num_workers=data_cfg.num_workers, - pin_memory=data_cfg.pin_memory, - ) - return dataloader - - if is_predict: - # MegatronPretrainingBatchSampler doesn't work with trainer.predict() - dataloader = torch.utils.data.DataLoader( - dataset, - collate_fn=collate_fn, - batch_size=data_cfg.micro_batch_size, - num_workers=data_cfg.num_workers, - pin_memory=data_cfg.pin_memory, - ) - return dataloader - - batch_sampler = MegatronPretrainingBatchSampler( - total_samples=len(dataset), - consumed_samples=consumed_samples, - micro_batch_size=data_cfg.micro_batch_size, - global_batch_size=data_cfg.global_batch_size, - data_parallel_rank=parallel_state.get_data_parallel_rank(), - data_parallel_size=parallel_state.get_data_parallel_world_size(), - drop_last=data_cfg.drop_last, - pad_samples_to_global_batch_size=not data_cfg.drop_last, - ) - - dataloader = torch.utils.data.DataLoader( - dataset, - batch_sampler=batch_sampler, - collate_fn=collate_fn, - num_workers=data_cfg.num_workers, - pin_memory=data_cfg.pin_memory, - persistent_workers=True if data_cfg.num_workers > 0 else False, - ) - return dataloader + def build_data_loader(self, dataset, data_cfg, consumed_samples=0, is_predict=False, is_eval=False): + return build_speechllm_dataloader(dataset, data_cfg, consumed_samples, is_predict=is_predict, is_eval=is_eval) @classmethod def _modify_audio_encoder_config(cls, gpt_cfg, audio_cfg, speaker_cfg=None): @@ -789,6 +706,7 @@ def get_audio_encoder_models_and_configs(cls, cfg): def load_pretrained_audio_weights( cls, cfg, model, audio_model, speaker_model: Optional[EncDecSpeakerLabelModel] = None ): + model.perception.tokenizer = audio_model.tokenizer use_multi_encoder = cfg.model.perception.get("encoders", None) is not None if not use_multi_encoder: if cfg.model.perception.get("use_multi_layer_feat", False): @@ -932,7 +850,9 @@ def merge_inference_cfg( trainer=trainer, return_config=True, ) - + # overwrite pretrained_audio_model if there + if hasattr(cfg.model, "pretrained_audio_model"): + model_cfg.pretrained_audio_model = cfg.model.pretrained_audio_model if hasattr(model_cfg, 'peft') and model_cfg.peft.peft_scheme not in [None, 'none']: # before PEFT migrates to distributed ckpt, eval must use same TP/PP as training for p in ['tensor_model_parallel_size', 'pipeline_model_parallel_size']: @@ -966,11 +886,12 @@ def load_adapters_for_inference(cls, cfg: DictConfig, model_cfg: DictConfig, mod if cfg.model.peft.restore_from_path: if '\\' in cfg.model.peft.restore_from_path: cfg.model.peft.restore_from_path = cfg.model.peft.restore_from_path.replace('\\', '') - if "peft" in model_cfg: + if "peft" in model_cfg and 'peft_scheme' in model_cfg.peft: peft_cfg_cls = PEFT_CONFIG_MAP[model_cfg.peft.peft_scheme] model.load_adapters(cfg.model.peft.restore_from_path, peft_cfg_cls(model_cfg), map_location="cpu") else: - model.load_state_dict(torch.load(cfg.model.peft.restore_from_path), strict=False) + torch_state_dict = torch.load(cfg.model.peft.restore_from_path)['state_dict'] + model.load_state_dict(torch_state_dict, strict=False) elif cfg.model.peft.restore_from_ckpt.checkpoint_dir and cfg.model.peft.restore_from_ckpt.checkpoint_name: checkpoint_path = os.path.join( cfg.model.peft.restore_from_ckpt.checkpoint_dir, cfg.model.peft.restore_from_ckpt.checkpoint_name @@ -1486,9 +1407,9 @@ def write_predictions_to_file(self, outputs, output_file_path_prefix, output_dir def setup_eval_dataloader(self, datasets, data_cfg): dataloaders = [] if not isinstance(datasets, list): - return self.build_data_loader(dataset=datasets, data_cfg=data_cfg, consumed_samples=0) + return self.build_data_loader(dataset=datasets, data_cfg=data_cfg, consumed_samples=0, is_eval=True) for dataset in datasets: - eval_dl = self.build_data_loader(dataset=dataset, data_cfg=data_cfg, consumed_samples=0) + eval_dl = self.build_data_loader(dataset=dataset, data_cfg=data_cfg, consumed_samples=0, is_eval=True) dataloaders.append(eval_dl) return dataloaders @@ -1517,8 +1438,6 @@ def maybe_build_test(self): logging.info('Building test datasets...') # Wrap this in a list since the general finetuning parent class supports multi-validation. self._test_ds = self._build_dataset(self.cfg.data.test_ds, is_train=False) - lengths = [len(x) for x in self._test_ds] - logging.info(f'Length of test datasets: {lengths}, total: {sum(lengths)}') return def maybe_setup_test(self): @@ -1532,8 +1451,6 @@ def build_train_valid_test_datasets(self, stage): logging.info('Building validation datasets.') # Wrap this in a list since the general finetuning parent class supports multi-validation. self._validation_ds = self._build_dataset(self.cfg.data.validation_ds, is_train=False) - lengths = [len(x) for x in self._validation_ds] - logging.info(f'Length of validation datasets: {lengths}, total: {sum(lengths)}') if stage != 'validate': self.maybe_build_test() @@ -1542,7 +1459,6 @@ def build_train_valid_test_datasets(self, stage): return logging.info('Building training datasets.') self._train_ds = self._build_dataset(self.cfg.data.train_ds) - logging.info(f'Length training datasets: {len(self._train_ds)}') @classmethod def list_available_models(cls) -> Optional[PretrainedModelInfo]: @@ -1561,3 +1477,76 @@ def list_available_models(cls) -> Optional[PretrainedModelInfo]: ) results.append(model) return results + + +class CrossAttendModularAudioGPTModel(ModularAudioGPTModel): + """Modularized speech GPT model.""" + + def prepare_llm_input(self, audio_batch): + + input_signal = audio_batch['audio_signal'] + input_signal_length = audio_batch['audio_signal_length'] + + input_ids, input_length, labels, loss_mask = ( + audio_batch['tokens'], + audio_batch['tokens_length'], + audio_batch['labels'], + audio_batch['loss_mask'], + ) + + num_audios = audio_batch.get("num_audios", None) + if num_audios is not None: + raise ValueError("num_audios is not supported.") + + if self.cfg.get('megatron_amp_O2', False): + base_module = self.model.module + else: + base_module = self.model + lm_embedding = ( + base_module.language_model.embedding if hasattr(base_module, 'language_model') else base_module.embedding + ) + # [b, t, c] + encoded, encoded_len = self.perception( + input_signal=input_signal, + input_signal_length=input_signal_length, + processed_signal=None, + processed_signal_length=None, + ) + input_embeds = self._get_text_embeddings(input_ids, None).transpose(0, 1) + encoder_input, extra_outputs = self.perception_cross_attn( + encoded, encoded_len, input_embeds, input_lengths=input_length, return_mems=True + ) + # TODO: need separate speech and text methods for inference + if 'audio_ratio' in audio_batch: + audio_ratio = audio_batch['audio_ratio'][..., None, None] + encoder_input = encoder_input * audio_ratio + input_embeds * (1 - audio_ratio) + if 'alpha_xattn' in extra_outputs: + alpha_xattn = extra_outputs['alpha_xattn'] + self.log( + 'alpha_xattn', + alpha_xattn.mean(), + prog_bar=True, + batch_size=1, + rank_zero_only=True, + ) + attention_mask = self._create_attention_mask(encoder_input) + + if not hasattr(lm_embedding, 'transpose_batch_sequence') or lm_embedding.transpose_batch_sequence: + encoder_input = encoder_input.transpose(0, 1).contiguous() + if self.cfg.get("sequence_parallel", False): + encoder_input = tensor_parallel.mappings.scatter_to_sequence_parallel_region(encoder_input) + return encoder_input, attention_mask, labels, loss_mask, (encoded, encoded_len, extra_outputs) + + def setup_perception_modules(self, cfg): + super().setup_perception_modules(cfg) + imported_cls = model_utils.import_class_by_path(cfg.perception.xattn.target) + self.perception_cross_attn = imported_cls(cfg=cfg.perception) + + def state_dict(self, destination=None, prefix=None, keep_vars=False): + if self.setup_complete: + return_state_dict = super().state_dict(destination=destination, prefix=prefix, keep_vars=keep_vars) + state_dict = self.perception_cross_attn.state_dict(prefix="perception_cross_attn.") + return_state_dict.update(state_dict) + return return_state_dict + else: + return super().state_dict(destination=destination, prefix=prefix, keep_vars=keep_vars) diff --git a/nemo/collections/multimodal/speech_llm/models/modular_t5_models.py b/nemo/collections/multimodal/speech_llm/models/modular_t5_models.py new file mode 100644 index 0000000000000..a96ee823e1977 --- /dev/null +++ b/nemo/collections/multimodal/speech_llm/models/modular_t5_models.py @@ -0,0 +1,1367 @@ +# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import copy +import itertools +import json +import os +from functools import partial +from typing import Any, Optional, Union + +import sacrebleu +import torch +from omegaconf import ListConfig +from omegaconf.dictconfig import DictConfig +from omegaconf.omegaconf import OmegaConf, open_dict +from pytorch_lightning.trainer.trainer import Trainer + +from nemo.collections.asr.models import ASRModel, SpeechEncDecSelfSupervisedModel +from nemo.collections.asr.parts.mixins.transcription import move_to_device +from nemo.collections.common.metrics import MetricStringToTorchMetric, TextMetricsSet +from nemo.collections.multimodal.speech_llm.data.build_dataset import ( + build_speechllm_dataloader, + build_speechllm_dataset, +) +from nemo.collections.multimodal.speech_llm.modules.perception_modules import ( + AudioPerceptionModule, + MultiAudioPerceptionModule, +) +from nemo.collections.nlp.models.language_modeling.megatron_t5_adapter_model import MegatronT5LoraModel +from nemo.collections.nlp.models.language_modeling.megatron_t5_sft_model import MegatronT5SFTModel +from nemo.collections.nlp.models.nlp_model import NLPModel +from nemo.collections.nlp.modules.common.megatron.utils import ( + average_losses_across_data_parallel_group, + build_position_ids, + get_iterator_k_split, +) +from nemo.collections.nlp.parts.nlp_overrides import NLPSaveRestoreConnector +from nemo.collections.nlp.parts.utils_funcs import get_last_rank +from nemo.core.classes.mixins import adapter_mixins +from nemo.utils import AppState, logging, model_utils + +try: + from apex.transformer.pipeline_parallel.utils import ( + _reconfigure_microbatch_calculator, + get_current_global_batch_size, + get_micro_batch_size, + get_num_microbatches, + ) + + HAVE_APEX = True +except (ImportError, ModuleNotFoundError): + HAVE_APEX = False +from nemo.collections.nlp.models.language_modeling.megatron_t5_model import MegatronT5Model + +try: + from megatron.core import parallel_state, tensor_parallel + from megatron.core.pipeline_parallel.schedules import get_forward_backward_func + + HAVE_MEGATRON_CORE = True + +except (ImportError, ModuleNotFoundError): + HAVE_MEGATRON_CORE = False + + +__all__ = ["ModularizedAudioT5Model"] + + +default_inference_config = {'tokens_to_generate': 30} + + +class ModularizedAudioT5Model(MegatronT5LoraModel): + """Modularized speech GPT model.""" + + def setup_perception_modules(self, cfg): + if 'target' in cfg.perception: + imported_cls = model_utils.import_class_by_path(cfg.perception.target) + self.perception = imported_cls(cfg=cfg.perception) + else: + self.perception = ( + AudioPerceptionModule(cfg=cfg.perception) + if "encoders" not in cfg.perception + else MultiAudioPerceptionModule(cfg=cfg.perception) + ) + + def __init__(self, cfg: DictConfig, trainer: Trainer): + self.cfg = cfg + super().__init__(cfg, trainer) + self.val_metric, self.val_metric_name = self.setup_metric(self.cfg.data.validation_ds) + self.val_metric = torch.nn.ModuleList(self.val_metric) + if hasattr(self.cfg.data, "test_ds"): + self.test_metric, self.test_metric_name = self.setup_metric(self.cfg.data.test_ds) + self.test_metric = torch.nn.ModuleList(self.test_metric) + # Used other keys from metadata to calulate metrics + if hasattr(self.cfg.data, "test_ds") and hasattr(self.cfg.data.test_ds, "metric"): + self.test_metric_label_key = self.cfg.data.test_ds.metric.get('label_key', 'labels') + if hasattr(self.cfg.data, "validation_ds") and hasattr(self.cfg.data.validation_ds, "metric"): + self.val_metric_label_key = self.cfg.data.validation_ds.metric.get('label_key', 'labels') + self.setup_perception_modules(cfg) + self.setup_optimizer_param_groups() + # self.configure_optimizers() + self.summarize(max_depth=3) + # follow gpt + self.setup_complete = False + self.sep_id = cfg.get('sep_id', self.tokenizer.bos_id) + self.virtual_tokens = 0 + self.model = self.frozen_model.enc_dec_model + + def load_frozen_model(self, cfg, trainer): + self.megatron_amp_O2 = cfg.get('megatron_amp_O2', False) + t5_cfg_base = MegatronT5Model.restore_from(cfg.get('language_model_path'), trainer=trainer, return_config=True) + # use the incoming cfg updated by _modify_config + t5_cfg = copy.deepcopy(cfg) + t5_cfg.target = t5_cfg_base.target + self.frozen_model = MegatronT5Model.restore_from( + cfg.get('language_model_path'), + trainer=trainer, + override_config_path=t5_cfg, + save_restore_connector=NLPSaveRestoreConnector(), + ) + logging.info(f"self.frozen_model.cfg: {self.frozen_model.cfg}") + + def init_model(self, cfg: DictConfig, trainer: Trainer): + self.cfg = cfg + + self.load_frozen_model(cfg, trainer) + self.prompt_encoder = None + if self.frozen_model.tokenizer is not None: + self.tokenizer = self.frozen_model.tokenizer + + if hasattr(self.frozen_model.cfg, "encoder") and hasattr(self.frozen_model.cfg, "decoder"): + self.hidden_size = ( + self.frozen_model.cfg.encoder.hidden_size + ) # Encoder and decoder need to have the same hidden size and we check for this in the frozen enc-dec model. + else: + self.hidden_size = self.frozen_model.cfg.hidden_size + + # Handle this when moving GPT prompt learning to the base class. + self.word_embeddings = self.frozen_model.enc_dec_model.encoder_embedding.word_embeddings + + self._reduced_loss_buffer = [] + self._inference_config = None + + self.tokenizer.legacy = cfg.get('legacy_tokenizer', False) + self.bos_id = self.tokenizer.bos_id + self.decoder_seq_length = cfg.get('decoder_seq_length', 40) + + # make sure the default pytorch lightning gradient clipping in the basemodel + self.grad_clip_pl_default = False # make distributed_fused_adam happy + self.lowest_val_loss = None + self.prompt_encoder = None + + self.enable_autocast = ( + True if (not self.megatron_amp_O2) and (self.autocast_dtype in [torch.float16, torch.bfloat16]) else False + ) + + def parameters(self): + # override the same method in MegatronGPT model to include parameters ouside of LM + all_names = [] + all_params = [] + for name, param in self.named_parameters(recurse=True): + all_names.append(name) + all_params.append(param) + + if isinstance(self.frozen_model, list): + for module in self.frozen_model: + for name, param in module.named_parameters(recurse=True): + all_names.append(name) + all_params.append(param) + + return itertools.chain(all_params) + + def setup_optimizer_param_groups(self): + """ + ModelPT override. Optimizer will get self._optimizer_param_groups. + Makes two optimizer param groups, one for the frozen model params + and one for the prompt-table/prompt-encoder params. The learning + rate for the frozen model's params will always be zero effectively + freezing the model's params but still allowing for the needed gradients + to be passed around in pipeline parallel models. The prompt-encoder + and/or prompt table will use the learning rate set by the user. + """ + self.unfreeze() + known_groups = [] + if self.cfg.get('freeze_llm', True): + for param in self.frozen_model.parameters(): + param.requires_grad = False + known_groups.append('model.') + else: + if self.cfg.get('freeze_encoder', False): + for param in self.frozen_model.enc_dec_model.enc_dec_model.encoder.parameters(): + param.requires_grad = False + known_groups.append('enc_dec_model.encoder.') + if self.cfg.get('freeze_decoder', False): + for param in self.frozen_model.enc_dec_model.enc_dec_model.decoder.parameters(): + param.requires_grad = False + known_groups.append('enc_dec_model.decoder.') + if self.cfg.get('freeze_word_emb', False): + names = [ + 'encoder_embedding', + 'encoder_relative_position_embedding', + 'decoder_relative_position_embedding', + 'decoder_embedding', + ] + for pname in names: + for param in getattr(self.frozen_model.enc_dec_model, pname).parameters(): + param.requires_grad = False + known_groups.append('enc_dec_model.word_embeddings.') + known_groups.append('enc_dec_model.relative_position_embedding.') + if self.cfg.get('freeze_modality_adapter', False): + self.perception.modality_adapter.freeze() + known_groups.append('modality_adapter.') + if self.cfg.get('freeze_audio_encoder', False): + self.perception.encoder.freeze() + known_groups.append('audio_encoder.') + + opt_params = [] + for _, module in self.named_modules(): + if isinstance(module, adapter_mixins.AdapterModuleMixin) and module.is_adapter_available(): + module.set_enabled_adapters(enabled=True) + module.unfreeze_enabled_adapters() # selectively unfreeze the adapter modules. + opt_params += [p for p in module.parameters()] + + param_groups = [] + if "optim_param_groups" in self.cfg: + param_groups_cfg = self.cfg.optim_param_groups + for group, group_cfg in param_groups_cfg.items(): + module = getattr(self, group, None) + if module is None: + raise ValueError(f"{group} not found in model.") + elif hasattr(module, "parameters"): + known_groups.append(f"{group}.") + new_group = {"params": module.parameters()} + for k, v in group_cfg.items(): + new_group[k] = v + param_groups.append(new_group) + else: + raise ValueError(f"{group} does not have parameters.") + + for n, p in self.named_parameters(): + is_unknown = True + for group in known_groups: + if n.startswith(group): + is_unknown = False + if is_unknown: + opt_params.append(p) + + param_groups = [{"params": opt_params}] + param_groups + + self._optimizer_param_groups = param_groups + logging.info(f"Optimizer groups set:\n{self.summarize(max_depth=2)}") + + def inject_perception_input(self, encoded, encoded_len, input_ids, input_length): + def _concat_embs(embs1, emb1_lens, embs2, emb2_lens): + concat_emb = [] + concat_len = [] + for emb1, emb1_len, emb2, emb2_len in zip(embs1, emb1_lens, embs2, emb2_lens): + if self.cfg.get('ignore_dummy_audio', False) and emb1_len <= 1: # TODO: ignore the dummy audio emb + new_len = emb2_len + new_emb = emb2[:emb2_len] + else: + new_len = emb1_len + emb2_len + new_emb = torch.concat([emb1[:emb1_len], emb2[:emb2_len]], axis=0) + padded_new_emb = torch.zeros(emb1.shape[0] + emb2.shape[0], emb1.shape[-1], device=emb1.device) + padded_new_emb[:new_len, ...] = new_emb + concat_emb.append(padded_new_emb) + concat_len.append(new_len) + concat_emb = torch.stack(concat_emb, dim=0) + concat_len = torch.stack(concat_len, dim=0) + return concat_emb, concat_len + + # [b, t, c] + lm_embedding = self.frozen_model.enc_dec_model.encoder_embedding + input_embeds = lm_embedding.word_embeddings(input_ids) + if self.cfg.audio_prompt_first: + encoder_input, encoder_length = _concat_embs(encoded, encoded_len, input_embeds, input_length) + else: # more streaming friendly + encoder_input, encoder_length = _concat_embs(input_embeds, input_length, encoded, encoded_len) + + b = encoder_input.shape[0] + max_len = encoder_input.shape[1] + + # Using causal attention mask for whole input + # TODO(zhehuai): use prefixlm instead for the audio embeddings + attention_mask = torch.tril(torch.ones((b, max_len, max_len), device=encoder_input.device)).view( + b, 1, max_len, max_len + ) + # Convert attention mask from float to bool + attention_mask = attention_mask < 0.5 + position_ids = build_position_ids(encoder_input[:, :, 0]) + + # Add position embeddings + if hasattr(lm_embedding, "position_embeddings"): + position_embeddings = lm_embedding.position_embeddings(position_ids) + encoder_input = encoder_input + position_embeddings + else: + pass + encoder_max_length = encoder_input.shape[1] + if lm_embedding.transpose_batch_sequence: + encoder_input = encoder_input.contiguous() + if self.cfg.get("sequence_parallel", False): + encoder_input = tensor_parallel.mappings.scatter_to_sequence_parallel_region(encoder_input) + return encoder_input, attention_mask, encoder_length, position_ids, encoder_max_length + + def _shift_labels_by_emb_len(self, labels, label_lens, emb_lens, max_len, pad_token=0): + shifted_labels = [] + for label, label_len, emb_len in zip(labels, label_lens, emb_lens): + shifted_label = torch.full([max_len], pad_token, device=label.device) + shifted_label[emb_len : emb_len + label_len] = label[:label_len] + shifted_labels.append(shifted_label) + shifted_labels = torch.stack(shifted_labels, dim=0) + return shifted_labels + + def _get_text_embeddings(self, text_tokens, position_ids): + lm_embedding = self.frozen_model.enc_dec_model.encoder_embedding + text_embeddings = lm_embedding.word_embeddings(text_tokens) # (batch_size, seq_len, hidden_size) + if hasattr(lm_embedding, 'position_embeddings'): + position_embeddings = lm_embedding.position_embeddings(position_ids) + text_embeddings = text_embeddings + position_embeddings + return text_embeddings + + def prepare_llm_input(self, audio_batch): + + input_signal = audio_batch['audio_signal'] + input_signal_length = audio_batch['audio_signal_length'] + + input_ids, input_length, labels, loss_mask = ( + audio_batch['contexts'], + audio_batch['context_lengths'], + audio_batch['labels'], + audio_batch['loss_mask'], + ) + + # [b, t, c] + encoded, encoded_len = self.perception( + input_signal=input_signal, + input_signal_length=input_signal_length, + processed_signal=None, + processed_signal_length=None, + ) + encoder_input, attention_mask, encoder_length, _, encoder_max_length = self.inject_perception_input( + encoded, encoded_len, input_ids, input_length + ) + # generate encoder_mask from encoder_length + enc_mask = torch.arange(encoder_input.shape[1], device=encoder_input.device)[None, :] < encoder_length[:, None] + return encoder_input, attention_mask, enc_mask + + def forward( + self, + audio_batch, + checkpoint_activations_all_layers, + ): + """Forward pass of the model. + + We prepend audio embeddings to the instruction and label text tokens + as the LLM input. + """ + if 'audio_ratio' in audio_batch: + self.log( + 'audio_ratio', audio_batch['audio_ratio'].mean(), prog_bar=True, batch_size=1, rank_zero_only=False + ) + self.log( + 'local_batch_size', + audio_batch['audio_ratio'].shape[0], + prog_bar=True, + batch_size=1, + rank_zero_only=False, + ) + + encoder_input, attention_mask, enc_mask = self.prepare_llm_input(audio_batch) + # enc_input = speech and text prompt + # dec_input and label = text output label + b = audio_batch['answers'].shape[0] + device = audio_batch['answers'].device + dec_input = audio_batch['masked_answer_ids'] if 'masked_answer_ids' in audio_batch else audio_batch['answers'] + dec_input = torch.cat([torch.full([b, 1], self.bos_id, device=device), dec_input[:, :-1]], dim=-1) + labels = audio_batch['answers'] + dec_mask = (dec_input != self.tokenizer.pad_id).long().contiguous() + output = self.frozen_model.enc_dec_model( + enc_input_ids=None, + enc_attn_mask=enc_mask, + dec_input_ids=dec_input, + dec_attn_mask=dec_mask, + token_type_ids=None, + labels=labels, + output_enc_hidden_only=False, + enc_input=encoder_input, + ) + loss_mask = dec_mask + return output, loss_mask + + def get_forward_output_only_func(self): + def fwd_output_only_func(dataloader_iter, model): + batch = next(dataloader_iter) + extra_arg = {} + # take the batch produced by prepare_batch_at_step + ( + _, + input_embeddings, + attention_mask, + _, + set_inference_key_value_memory, + inference_max_sequence_len, + ) = batch + if attention_mask is not None: + attention_mask = attention_mask.cuda() + attention_mask = attention_mask[0:1] + extra_arg['set_inference_key_value_memory'] = set_inference_key_value_memory[0].item() + extra_arg['inference_max_sequence_len'] = inference_max_sequence_len[0].item() + output_tensor = model( + input_ids=None, + position_ids=None, + encoder_input=input_embeddings, + attention_mask=attention_mask, + **extra_arg, + ) + + if isinstance(output_tensor, tuple): + output_tensor = output_tensor[1] # get logits only + + def id_func(output_tensor): + return output_tensor, {'logits': output_tensor} + + return output_tensor, id_func + + return fwd_output_only_func + + def get_forward_output_and_loss_func(self, validation_step=False): + def fwd_output_and_loss_func(dataloader_iter, model, checkpoint_activations_all_layers=None): + batch = next(dataloader_iter) + batch = {key: val.cuda(non_blocking=True) for key, val in batch.items()} + output_tensor, loss_mask = self.forward( + batch, checkpoint_activations_all_layers=checkpoint_activations_all_layers + ) + + def loss_func(output_tensor): + # Loss for a micro-batch (ub) + if 'audio_ratio' in batch: + text_loss_weight = self.cfg.get('text_loss_weight', 1.0) + audio_ratio = batch['audio_ratio'] + scaled_loss_mask = loss_mask * torch.unsqueeze( + (1 * audio_ratio + text_loss_weight * (1 - audio_ratio)), 1 + ) + loss_for_ub = self.loss_func(scaled_loss_mask, output_tensor) + else: + loss_for_ub = self.loss_func(loss_mask, output_tensor) + if validation_step and not self.cfg.data.get('validation_drop_last', True): + num_valid_tokens_in_ub = batch['loss_mask'].sum() + if loss_for_ub.isnan(): + assert batch['loss_mask'].count_nonzero() == 0, 'Got NaN loss with non-empty input' + loss_sum_for_ub = torch.zeros_like(num_valid_tokens_in_ub) + else: + loss_sum_for_ub = num_valid_tokens_in_ub * loss_for_ub + + loss_sum_and_ub_size_all_gpu = torch.cat( + [ + loss_sum_for_ub.clone().detach().view(1), + torch.tensor([num_valid_tokens_in_ub]).cuda().clone().detach(), + ] + ) + # Could potentially reduce num_valid_samples_in_microbatch and use that to aggregate instead of len(self._validation_ds) + torch.distributed.all_reduce( + loss_sum_and_ub_size_all_gpu, group=parallel_state.get_data_parallel_group() + ) + return loss_for_ub, {'loss_sum_and_ub_size': loss_sum_and_ub_size_all_gpu} + else: + reduced_loss = average_losses_across_data_parallel_group([loss_for_ub]) + return loss_for_ub, {'avg': reduced_loss} + + return output_tensor, loss_func + + return fwd_output_and_loss_func + + def _build_dataset(self, data_cfg, is_train=True): + return build_speechllm_dataset(self, data_cfg, is_train) + + def build_data_loader(self, dataset, data_cfg, consumed_samples=0, is_eval=False): + return build_speechllm_dataloader(dataset, data_cfg, consumed_samples, is_eval=is_eval) + + @classmethod + def _modify_config(cls, gpt_cfg, cfg, audio_cfg, add_cfg_to_tree=False): + """ + This function modifies the original gpt pre-training config (gpt_cfg) with attributes from the finetuning config (cfg). + The `add_cfg_to_tree` arg adds `cfg` to the top of the yaml tree which is needed for all `hparams.yaml` files when passed as an arg to `load_from_checkpoint()`. + """ + OmegaConf.set_struct(gpt_cfg, True) + OmegaConf.resolve(cfg) + with open_dict(gpt_cfg): + if 'vocab_file' in cfg.model: + gpt_cfg.tokenizer.vocab_file = cfg.model.vocab_file + gpt_cfg.legacy_tokenizer = cfg.model.get('legacy_tokenizer', False) + gpt_cfg.audio_prompt_first = cfg.model.get('audio_prompt_first', True) + gpt_cfg.ignore_dummy_audio = cfg.model.get('ignore_dummy_audio', False) + gpt_cfg.freeze_llm = cfg.model.get('freeze_llm', True) + gpt_cfg.freeze_word_emb = cfg.model.get('freeze_word_emb', False) + gpt_cfg.freeze_encoder = cfg.model.get('freeze_encoder', False) + gpt_cfg.freeze_decoder = cfg.model.get('freeze_decoder', False) + gpt_cfg.text_loss_weight = cfg.model.get('text_loss_weight', 1.0) + gpt_cfg.freeze_audio_encoder = cfg.model.get('freeze_audio_encoder', False) + gpt_cfg.freeze_modality_adapter = cfg.model.get('freeze_modality_adapter', False) + gpt_cfg.megatron_amp_O2 = cfg.model.get('megatron_amp_O2', False) + gpt_cfg.micro_batch_size = cfg.model.data.train_ds.micro_batch_size + gpt_cfg.global_batch_size = cfg.model.data.train_ds.global_batch_size + gpt_cfg.sequence_parallel = cfg.model.get("sequence_parallel", False) + gpt_cfg.tensor_model_parallel_size = cfg.model.get( + "tensor_model_parallel_size", + gpt_cfg.tensor_model_parallel_size if hasattr(gpt_cfg, "tensor_model_parallel_size") else 1, + ) + gpt_cfg.activations_checkpoint_granularity = cfg.model.get("activations_checkpoint_granularity", None) + gpt_cfg.activations_checkpoint_num_layers = cfg.model.get("activations_checkpoint_num_layers", None) + gpt_cfg.activations_checkpoint_method = cfg.model.get("activations_checkpoint_method", None) + gpt_cfg.data = cfg.model.data + gpt_cfg.optim = cfg.model.optim + gpt_cfg.precision = cfg.trainer.precision + gpt_cfg.answer_only_loss = cfg.model.answer_only_loss + gpt_cfg.language_model_path = cfg.model.language_model_path + gpt_cfg.resume_from_checkpoint = cfg.model.resume_from_checkpoint + gpt_cfg.save_nemo_on_validation_end = cfg.model.save_nemo_on_validation_end + gpt_cfg.gradient_as_bucket_view = cfg.model.gradient_as_bucket_view + # set dropout + hidden_dropout = cfg.model.get('hidden_dropout', 0.0) + attention_dropout = cfg.model.get('attention_dropout', 0.0) + ffn_dropout = cfg.model.get('ffn_dropout', 0.0) + gpt_cfg.encoder.hidden_dropout = hidden_dropout + gpt_cfg.decoder.hidden_dropout = hidden_dropout + gpt_cfg.encoder.attention_dropout = attention_dropout + gpt_cfg.decoder.attention_dropout = attention_dropout + gpt_cfg.encoder.ffn_dropout = ffn_dropout + gpt_cfg.decoder.ffn_dropout = ffn_dropout + if hasattr(gpt_cfg, 'embedding_dropout'): + gpt_cfg.embedding_dropout = hidden_dropout + # set label_smoothing + if hasattr(gpt_cfg, 'label_smoothing'): + gpt_cfg.label_smoothing = cfg.model.get('label_smoothing', gpt_cfg.label_smoothing) + gpt_cfg.virtual_prompt_style = cfg.model.virtual_prompt_style + gpt_cfg.lora_tuning = cfg.model.lora_tuning + # for AudioGPTLoRAModel + gpt_cfg.target = f"{cls.__module__}.{cls.__name__}" + gpt_cfg.perception = cfg.model.perception + gpt_cfg.pretrained_audio_model = cfg.model.get('pretrained_audio_model', None) + gpt_cfg.perception.preprocessor = audio_cfg.preprocessor + gpt_cfg.perception.encoder = audio_cfg.encoder + modality_adapter_cfg = gpt_cfg.perception.modality_adapter + modality_adapter_cfg.feat_in = audio_cfg.encoder.d_model + gpt_cfg.perception.output_dim = gpt_cfg.encoder.hidden_size + override_vocab_size = cfg.model.get('override_vocab_size', None) + if override_vocab_size is not None: + gpt_cfg.override_vocab_size = override_vocab_size + if not hasattr(gpt_cfg, 'tokenizer'): + gpt_cfg.tokenizer = gpt_cfg.decoder_tokenizer + # This is needed when modifying a hparam file directly to load `.ckpt` files. + # This is not needed to modify the cfg in `.nemo` files. + if add_cfg_to_tree: + OmegaConf.resolve(gpt_cfg) + gpt_cfg.cfg = gpt_cfg + + return gpt_cfg + + @classmethod + def load_audio_model(cls, pretrained_audio_model): + try: + if pretrained_audio_model.endswith('.nemo'): + logging.info(f'Loading pretrained audio model from local file: {pretrained_audio_model}') + audio_model = ASRModel.restore_from(pretrained_audio_model, map_location='cpu') + else: + logging.info(f'Loading pretrained audio model from NGC: {pretrained_audio_model}') + audio_model = ASRModel.from_pretrained(pretrained_audio_model, map_location='cpu') + except: + logging.info(f'Fail in loading it with ASRModel. Try again with SpeechEncDecSelfSupervisedModel.') + if pretrained_audio_model.endswith('.nemo'): + logging.info(f'Loading pretrained audio model from local file: {pretrained_audio_model}') + audio_model = SpeechEncDecSelfSupervisedModel.restore_from(pretrained_audio_model, map_location='cpu') + else: + logging.info(f'Loading pretrained audio model from NGC: {pretrained_audio_model}') + audio_model = SpeechEncDecSelfSupervisedModel.from_pretrained( + pretrained_audio_model, map_location='cpu' + ) + return audio_model + + @classmethod + def restore_from_pretrained_models( + cls, + cfg: Optional[Union[OmegaConf, str]] = None, + trainer: Optional[Trainer] = None, + ): + if not cfg.model.pretrained_audio_model: + raise RuntimeError("PEFT training needs a pretrained audio model present.") + + if not cfg.model.language_model_path: + raise RuntimeError("PEFT training needs a trained base model present.") + + base_model_save_restore_connector = NLPSaveRestoreConnector() + if os.path.isdir(cfg.model.language_model_path): + base_model_save_restore_connector.model_extracted_dir = cfg.model.language_model_path + base_model_cfg = cls.restore_from( + restore_path=cfg.model.language_model_path, + trainer=trainer, + return_config=True, + save_restore_connector=base_model_save_restore_connector, + ) + audio_model = cls.load_audio_model(cfg.model.pretrained_audio_model) + + model_cfg = cls._modify_config(base_model_cfg, cfg, audio_model.cfg, add_cfg_to_tree=False) + + # load llm + model = cls.restore_from( + restore_path=cfg.model.language_model_path, + trainer=trainer, + override_config_path=model_cfg, + strict=False, + ) + # load am + model.perception.tokenizer = audio_model.tokenizer + if cfg.model.get('load_audio_encoder', True): + model.perception.encoder.load_state_dict( + audio_model.encoder.state_dict(), strict='adapter' not in cfg.model.perception + ) + logging.info(f'Loaded pretrained audio model from {cfg.model.pretrained_audio_model}') + else: + logging.info(f'Not load pretrained audio model from {cfg.model.pretrained_audio_model}') + if cfg.model.get('use_am_tokenizer', False): + model.tokenizer = audio_model.tokenizer + logging.info(f'Use AM tokenizer: {audio_model.tokenizer}') + if 'inference' in cfg: + inference_cfg = OmegaConf.to_container(cfg.inference, resolve=True) + model.set_inference_config(inference_cfg) + return model + + def _build_vocab(self): + """ + Manipulate vocabulary (e.g., pad vocabulary for increased performance)/ + """ + if self._cfg.get('override_vocab_size', None) is not None: + self.padded_vocab_size = self._cfg.override_vocab_size + else: + self.padded_vocab_size = self._vocab_size_with_padding( + orig_vocab_size=self.tokenizer.vocab_size, + make_vocab_size_divisible_by=self._cfg.get('make_vocab_size_divisible_by', 128), + tensor_model_parallel_size=self._cfg.get('tensor_model_parallel_size', 1), + ) + + def state_dict(self, destination=None, prefix=None, keep_vars=False): + if self.setup_complete: + # save adapter + return_state_dict = super().state_dict(destination, prefix, keep_vars) + # save perception + if not self.cfg.get('freeze_audio_encoder', False): + perception_state_dict = self.perception.state_dict(prefix="perception.") + return_state_dict.update(perception_state_dict) + # store llm if not freezing it + if not self.cfg.get('freeze_llm', True): + llm_state_dict = self.frozen_model.state_dict(prefix="frozen_model.") + return_state_dict.update(llm_state_dict) + else: + return_state_dict = self.frozen_model.state_dict(prefix="frozen_model.") + return return_state_dict + + def load_state_dict(self, state_dict, strict: bool = True): + """ + Loads a state_dict expecting the state_dict to contain key,values + only for the adapter parameters. + """ + if self.setup_complete: + # load adapters + super().load_state_dict(state_dict, strict) + # load perception + print(f"loading state_dict {self.setup_complete}: {state_dict.keys()}") + super(NLPModel, self).load_state_dict(state_dict, strict=False) + else: + if len([i for i in state_dict.keys() if 'lora' in i]) > 0: + # load adapters + super().load_state_dict(state_dict, strict) + # load frozen llm and maybe perception model + print(f"loading state_dict {self.setup_complete}: {state_dict.keys()}") + super(NLPModel, self).load_state_dict(state_dict, strict=False) + + def build_train_valid_test_datasets(self, stage): + if stage != 'test': + logging.info('Building GPT SFT validation datasets.') + # Wrap this in a list since the general finetuning parent class supports multi-validation. + self._validation_ds = self._build_dataset(self.cfg.data.validation_ds, is_train=False) + + if stage != 'validate': + if hasattr(self.cfg.data, 'test_ds'): + logging.info('Building GPT SFT test datasets.') + # Wrap this in a list since the general finetuning parent class supports multi-validation. + self._test_ds = self._build_dataset(self.cfg.data.test_ds, is_train=False) + + if stage == 'validate' or stage == 'test': + return + logging.info('Building GPT SFT traing datasets.') + self._train_ds = self._build_dataset(self.cfg.data.train_ds) + + def setup_training_data(self, training_data_config=None): + return + + def setup_validation_data(self, validation_data_config=None): + return + + def setup_test_data(self, test_data_config=None): + return + + def setup_training_dataloader(self): + if hasattr(self, '_train_ds'): + consumed_samples = self.compute_consumed_samples(0) + self._train_dl = self.build_data_loader( + dataset=self._train_ds, + data_cfg=self.cfg.data.train_ds, + consumed_samples=consumed_samples, + ) + + def setup(self, stage=None): + self.init_consumed_samples = 0 + + if stage == 'predict': + return + + # If the user wants to manually override train and validation dataloaders before calling `.fit()` + if self._train_dl is not None and self._validation_dl is not None: + return + self.build_train_valid_test_datasets(stage=stage) + if hasattr(self, '_train_ds'): + self.setup_training_dataloader() + if hasattr(self, '_validation_ds'): + self._validation_dl = self.setup_eval_dataloader(self._validation_ds, self.cfg.data.validation_ds) + if hasattr(self.cfg.data, 'test_ds'): + self._test_dl = self.setup_eval_dataloader(self._test_ds, self.cfg.data.test_ds) + + # when using pipeline model parallel the final stage need to initialize word embeddings + if parallel_state.get_pipeline_model_parallel_world_size() > 1: + if isinstance(self.frozen_model, list): + for i, module in enumerate(self.frozen_model): + parallel_state.set_virtual_pipeline_model_parallel_rank(i) + module.sync_initial_word_embeddings() + parallel_state.set_virtual_pipeline_model_parallel_rank(0) + else: + self.frozen_model.sync_initial_word_embeddings() + + if self.cfg.get('transformer_engine', False): + self.setup_transformer_engine_tp_groups() + self.setup_complete = True + + @property + def _metrics_require_string2category_map(self): + return set(["f1", "accuracy", "average_precision"]) + + def setup_metric(self, data_cfg): + metric_name = "exact_string_match" + if not hasattr(data_cfg, "metric"): + metric = MetricStringToTorchMetric["exact_string_match"] + else: + if not hasattr(data_cfg.metric, "name"): + raise ValueError("Metric name is not provided in the metric config.") + if data_cfg.metric.name == "loss": + return None, "loss" + if data_cfg.metric.name not in MetricStringToTorchMetric: + raise KeyError( + f"{data_cfg.metric.name} is not supported. List of supported metrics: {MetricStringToTorchMetric.keys()}" + ) + if data_cfg.metric.name in self._metrics_require_string2category_map: + if data_cfg.metric.average is None: + raise ValueError( + f"{data_cfg.metric.name} requires specifying whether you want to compute a micro or macro average. Found None." + ) + if ( + data_cfg.metric.get('labels_are_strings', False) + and data_cfg.metric.name in self._metrics_require_string2category_map + ): + if data_cfg.metric.num_classes is None: + raise ValueError( + "Number of classes is not provided in the metric section within the data config. " + f"Please provide the number of classes in the data config to use the {data_cfg.metric.name} metric." + ) + if data_cfg.metric.get('class_labels', None) is None or not isinstance( + data_cfg.metric.get('class_labels', None), ListConfig + ): + raise ValueError( + "Class labels are not provided properly in the metric section witnin the data config. " + f"Please provide the class labels as a list of strings in the data config to use the {data_cfg.metric.name} metric." + ) + if len(data_cfg.metric.get('class_labels', None)) != data_cfg.metric.num_classes: + raise ValueError( + f"Number of class labels {len(data_cfg.metric.get('class_labels', None))} does not match `num_classes` : {data_cfg.metric.num_classes}" + ) + + metric_name = data_cfg.metric.name + metric_cls = MetricStringToTorchMetric[metric_name] + if metric_name not in TextMetricsSet: + metric = [metric_cls(**data_cfg.metric)] + else: + metric = [metric_cls()] + return metric, metric_name + + # Override the parent batch reconfiguring logic. + def _reconfigure_and_process_inference_batch(self, batch, data_cfg): + global_batch_size_per_gpu = batch['tokens'].size(0) + # This should happen only on the last batch of the dataset. + if ( + global_batch_size_per_gpu + != get_current_global_batch_size() // parallel_state.get_data_parallel_world_size() + ): + # NOTE: This is reconfiguring to make sure there is no grad-acc for validation batches. + if ( + global_batch_size_per_gpu + != data_cfg.global_batch_size // parallel_state.get_data_parallel_world_size() + ): + app_state = AppState() + _reconfigure_microbatch_calculator( + rank=app_state.global_rank, + rampup_batch_size=None, + global_batch_size=global_batch_size_per_gpu * parallel_state.get_data_parallel_world_size(), + micro_batch_size=global_batch_size_per_gpu, + data_parallel_size=parallel_state.get_data_parallel_world_size(), + ) + # NOTE: need to explicitly handle resetting for multi-validation + else: + app_state = AppState() + _reconfigure_microbatch_calculator( + rank=app_state.global_rank, + rampup_batch_size=None, + global_batch_size=data_cfg.global_batch_size, + micro_batch_size=data_cfg.micro_batch_size, + data_parallel_size=parallel_state.get_data_parallel_world_size(), + ) + + def validation_step(self, dataloader_iter, inference=False): + return self.inference_step(dataloader_iter, 'validation') + + def _validation_step_internal( + self, dataloader_iter, batch_idx, dataloader_idx=0, inference=False, result_mode='validation' + ): + """ + Our dataloaders produce a micro-batch and then we fetch + a number of microbatches depending on the global batch size and model parallel size + from the dataloader to produce a list of microbatches. + The list of microbatches is then piped through the pipeline using megatron-core fwd/bwd functions. + """ + mode = self.training + self.eval() + loss = self.fwd_bwd_step(dataloader_iter, 0, True) + self.train(mode=mode) + self.frozen_model.eval() + + if result_mode == 'validation': + if type(self._validation_dl) == list and len(self._validation_dl) > 1: + self.validation_step_outputs[dataloader_idx].append(loss) + else: + self.validation_step_outputs.append(loss) + else: + if type(self.trainer.test_dataloaders) == list and len(self.trainer.test_dataloaders) > 1: + self.test_step_outputs[dataloader_idx].append(loss) + else: + self.test_step_outputs.append(loss) + return loss + + def inference_step(self, dataloader_iter, mode, dataloader_idx=0): + batch, batch_idx, dataloader_idx = next(dataloader_iter) + data_cfg = self.cfg.data.validation_ds if mode == 'validation' else self.cfg.data.test_ds + self._reconfigure_and_process_inference_batch(batch, data_cfg) + # Meta data from dataset + metadata = batch.get('metadata', [{}] * len(batch['tokens'])) + loss = self._validation_step_internal(itertools.chain([batch]), batch_idx, dataloader_idx, result_mode=mode) + + # We need _inference_config to get generation params + # add_BOS and tokens_to_generate are set in dataset + if self.get_inference_config() is None: + logging.warning(f'inference_config is not set. Use default: {default_inference_config}') + self.set_inference_config(inference_config=default_inference_config) + self._inference_config['add_BOS'] = data_cfg.add_bos + self._inference_config['tokens_to_generate'] = data_cfg.get('tokens_to_generate') + + output = self.predict_step(batch, batch_idx, dataloader_idx) + + inputs_text = [self.tokenizer.ids_to_text(c.tolist()) for c in batch['contexts']] + labels_text = [self.tokenizer.ids_to_text(a.tolist()) for a in batch['answers']] + preds_text = output['preds_text'] + if data_cfg.get("log_every_n_steps", None) is not None: + if batch_idx % data_cfg.log_every_n_steps == 0: + logging.info(f"Input: `{inputs_text[0]}`") + logging.info(f"Label: `{labels_text[0]}`") + logging.info(f"Pred: `{preds_text[0]}`") + + outputs = { + 'loss': loss, + 'preds': preds_text, # [str] + 'labels': labels_text, # [str] + 'inputs': inputs_text, # [str] + 'metadata': metadata, # [dict] + } + + if mode == 'validation': + if type(self.trainer.val_dataloaders) == list and len(self.trainer.val_dataloaders) > 1: + # super().validation_step appends just loss to self.validation_step_outputs, replace the last appended loss with the outputs dict + self.validation_step_outputs[dataloader_idx][-1] = outputs + else: + # super().validation_step appends just loss to self.validation_step_outputs, replace the last appended loss with the outputs dict + self.validation_step_outputs[-1] = outputs + else: + if type(self.trainer.test_dataloaders) == list and len(self.trainer.test_dataloaders) > 1: + self.test_step_outputs[dataloader_idx][-1] = outputs + else: + self.test_step_outputs[-1] = outputs + return outputs + + def predict_step(self, batch: Any, batch_idx: int, dataloader_idx: int = 0) -> Any: + + batch = move_to_device(batch, device=self.device) + encoder_input, attention_mask, enc_mask = self.prepare_llm_input(batch) + # enc_input = speech and text prompt + # dec_input and label = text output label + predicted_token_ids, log_probs = self.frozen_model.decode( + tokens_enc=None, + enc_mask=enc_mask, + num_tokens_to_generate=self._inference_config['tokens_to_generate'], + encoder_input=encoder_input, + tokenizer=self.tokenizer, + bos_id=self.bos_id, + ) + + # Special ids to text function to handle stripping and special tokens with sentencepiece tokenizers. + input_text = batch['contexts'] + preds_text = MegatronT5SFTModel.ids_to_text(predicted_token_ids, self.tokenizer) + input_text = MegatronT5SFTModel.ids_to_text(input_text, self.tokenizer) + labels = batch['answers'] + + if labels is not None: + labels_text = MegatronT5SFTModel.ids_to_text(labels, self.tokenizer) + else: + labels_text = [None] * len(preds_text) + + return { + 'input_text': input_text, + 'preds_text': preds_text, + 'labels_text': labels_text, + } + + def on_test_epoch_end(self): + _ = self.inference_epoch_end(self.test_step_outputs, 'test', self.cfg.data.test_ds) + # Commenting as on_test_epoch_end was a no-op in PTL 1.9 + # return super().on_test_epoch_end() + + def on_validation_epoch_end(self): + _ = self.inference_epoch_end(self.validation_step_outputs, 'validation', self.cfg.data.validation_ds) + # Commenting as on_validation_epoch_end was a no-op in PTL 1.9 + # return super().on_validation_epoch_end() + + def inference_epoch_end(self, outputs, mode, data_cfg): + # Parent class will handle logging of the loss. + if not outputs: + # Handle case where no metrics. This can break checkpoint save/load. + app_state = AppState() + monitor_mode = app_state.checkpoint_callback_params.mode + assert monitor_mode in ['min', 'max'] + averaged_metric = 0.0 if monitor_mode == 'max' else 1e2 + logging.warning(f"No outputs to log for {mode} epoch") + return torch.Tensor([1e2]), torch.Tensor([averaged_metric]) + + if isinstance(outputs[0], dict): + outputs = [outputs] + + averaged_loss = [] + averaged_metric = [] + # Log metrics for each provided validation/test dataset. + for dataloader_idx, output in enumerate(outputs): + if len(output) == 0: + logging.warning(f"Empty output for dataloader_idx: {dataloader_idx}") + continue + # Expand on_validation_epoch_end from parent class MegatronGPTModel as on_validation_epoch_end doesnt take outputs arg + loss_vals = [x['loss'] for x in output] + if parallel_state.is_pipeline_last_stage(): + # only the last pipeline parallel stages return loss with their batch size + if self.cfg.data.get('validation_drop_last', True): + loss = torch.stack(loss_vals).mean() + else: + # Compute the avg loss by total_loss across all samples / total number of samples + total_loss_and_total_samples = torch.vstack(loss_vals).sum(axis=0) + avg_loss = total_loss_and_total_samples[0] / total_loss_and_total_samples[1] + loss = avg_loss.type(torch.float32).cuda() + else: + loss = torch.tensor(0.0, dtype=torch.float32).cuda() + + # we can only log on one rank if it is rank zero so we broadcast from last rank + torch.distributed.broadcast(loss, get_last_rank()) + + self.log('val_loss', loss, prog_bar=True, rank_zero_only=True, batch_size=1, sync_dist=True) + + # Determine the key used to log the loss based on the user provided name of the dataset or the dataloader index. + loss_log_key = self._determine_log_key(data_cfg, dataloader_idx, "loss", mode) + self.log(loss_log_key, loss, batch_size=1) + averaged_loss.append(loss) + + # Gather the outputs object from all data parallel ranks since we are using the DistributedSampler which splits data across DDP ranks. + gathered_outputs = [None for _ in range(parallel_state.get_data_parallel_world_size())] + torch.distributed.all_gather_object( + gathered_outputs, + [ + {'preds': x['preds'], 'labels': x['labels'], 'inputs': x['inputs'], 'metadata': x['metadata']} + for x in output + ], + group=parallel_state.get_data_parallel_group(), + ) + + # Remove duplicate examples due to distributed sampler. + inp_label_set = set() + deduplicated_outputs = { + 'preds': [], + 'labels': [], + 'inputs': [], + 'metadata': [], + } + total_size = 0 + for rank in range(0, parallel_state.get_data_parallel_world_size()): + for batch in gathered_outputs[rank]: + for pred, label, input, metadata in zip( + batch['preds'], batch['labels'], batch['inputs'], batch['metadata'] + ): + key = input + label + total_size += 1 + dedup = data_cfg.get('deduplicate', True) + if (not dedup) or key not in inp_label_set: + inp_label_set.add(key) + deduplicated_outputs['preds'].append(pred) + deduplicated_outputs['labels'].append(label) + deduplicated_outputs['inputs'].append(input) + deduplicated_outputs['metadata'].append(metadata) + + # Compute metric score + metric_name = self.val_metric_name if mode == 'validation' else self.test_metric_name + metric_label_key = self.val_metric_label_key if mode == 'validation' else self.test_metric_label_key + if metric_name != 'loss': + metric_log_key = self._determine_log_key(data_cfg, dataloader_idx, metric_name, mode) + metric_fn = self.val_metric[0] if mode == 'validation' else self.test_metric[0] + if metric_label_key in deduplicated_outputs['metadata'][0]: + labels = [m[metric_label_key] for m in deduplicated_outputs['metadata']] + else: + labels = deduplicated_outputs['labels'] + + # sacrebleu.corpus_bleu is commonly used which does not share + # the same interface as other metrics. We handle it separately. + if metric_name == 'bleu': + metric_result = torch.Tensor( + [sacrebleu.corpus_bleu(deduplicated_outputs['preds'], [labels]).score] + ).to(self.device) + else: + for pred, label in zip(deduplicated_outputs['preds'], labels): + _ = metric_fn(pred, label) + + metric_result = metric_fn.compute() + + if metric_name == 'rouge': + for k, v in metric_result.items(): + if 'fmeasure' in k: + self.log(metric_log_key + f'_{k}', v.item(), sync_dist=True) + logging.info(f"{mode} {metric_name} {k}: {v.item()}") + metric_result = metric_result['rouge1_fmeasure'] + else: + self.log(metric_log_key, metric_result.item(), sync_dist=True) + logging.info(f"{mode} {metric_name}: {metric_result.item()}") + + metric_fn.reset() + averaged_metric.append(metric_result) + + # Write predictions to file + if self.global_rank == 0 and data_cfg.get("write_predictions_to_file", False): + logging.info( + f"Total deduplicated inference data size: {total_size} to {len(deduplicated_outputs['inputs'])}" + ) + + # Check if the user provided a prefix path to the file(s) they want to write. + if not hasattr(data_cfg, "output_file_path_prefix") or data_cfg.output_file_path_prefix is None: + raise ValueError( + f"Cannot write predictions to file when output_file_path_prefix is not set or present in the yaml config file." + ) + filename_log_key = self._determine_log_key(data_cfg, dataloader_idx, None, mode) + output_dir = data_cfg.get("output_dir", "./") + self.write_predictions_to_file( + deduplicated_outputs, f"{data_cfg.output_file_path_prefix}_{filename_log_key}", output_dir + ) + + torch.distributed.barrier(group=parallel_state.get_data_parallel_group()) + outputs[dataloader_idx].clear() # free memory + + # Logging of the averaged metrics: + averaged_loss = sum(averaged_loss) / len(averaged_loss) + averaged_metric = sum(averaged_metric) / len(averaged_metric) if len(averaged_metric) > 0 else None + + # Handle case where metrics can be nan or inf. This can break checkpoint save/load. + if averaged_metric is not None and (torch.isinf(averaged_metric) or torch.isnan(averaged_metric)): + app_state = AppState() + monitor_mode = app_state.checkpoint_callback_params.mode + assert monitor_mode in ['min', 'max'] + averaged_metric = 0.0 if monitor_mode == 'max' else 1e5 + + if mode == 'validation': + self.log("validation_loss", averaged_loss, batch_size=1, sync_dist=True) + if averaged_metric is not None: + self.log(f"validation_{self.val_metric_name}", averaged_metric, sync_dist=True) + elif mode == 'test': + self.log("test_loss", averaged_loss, batch_size=1, sync_dist=True) + if averaged_metric is not None: + self.log(f"test_{self.test_metric_name}", averaged_metric, sync_dist=True) + + # Merge the functionality of previous on_inference_epoch_end() within inference_epoch_end() func here + app_state = AppState() + # TODO(zhehuai): add _restore_sequence_parallelism_args after sync to HEAD + if hasattr(self, "_train_ds"): + _reconfigure_microbatch_calculator( + rank=app_state.global_rank, + rampup_batch_size=None, + global_batch_size=self.cfg.data.train_ds.global_batch_size, + micro_batch_size=self.cfg.data.train_ds.micro_batch_size, + data_parallel_size=parallel_state.get_data_parallel_world_size(), + ) + # When running `trainer.validate()`, the training dataset is not available. + else: + logging.warning('No training data found, reconfiguring microbatches based on validation batch sizes.') + _reconfigure_microbatch_calculator( + rank=app_state.global_rank, + rampup_batch_size=None, + global_batch_size=data_cfg.global_batch_size, + micro_batch_size=data_cfg.micro_batch_size, + data_parallel_size=parallel_state.get_data_parallel_world_size(), + ) + + return averaged_loss, averaged_metric + + # consistent with speech models + def write_predictions_to_file(self, outputs, output_file_path_prefix, output_dir): + os.makedirs(output_dir, exist_ok=True) + output_file_path = output_file_path_prefix + "_inputs_preds_labels.jsonl" + output_file_path = os.path.join(output_dir, output_file_path) + with open(output_file_path, "w") as f_json: + assert ( + len(outputs['inputs']) == len(outputs['preds']) == len(outputs['labels']) == len(outputs['metadata']) + ) + for i, p, l, m in zip(outputs['inputs'], outputs['preds'], outputs['labels'], outputs['metadata']): + json_string = {'input': i, 'pred_text': p, 'text': l} + for k, v in m.items(): + if k not in json_string: + json_string[k] = v + f_json.write(json.dumps(json_string) + '\n') + + logging.info(f'Predictions saved to {output_file_path}') + + def setup_eval_dataloader(self, datasets, data_cfg): + dataloaders = [] + if not isinstance(datasets, list): + return self.build_data_loader(dataset=datasets, data_cfg=data_cfg, consumed_samples=0, is_eval=True) + for dataset in datasets: + eval_dl = self.build_data_loader(dataset=dataset, data_cfg=data_cfg, consumed_samples=0, is_eval=True) + dataloaders.append(eval_dl) + return dataloaders + + def fwd_bwd_step(self, dataloader_iter, batch_idx, forward_only): + batch = next(dataloader_iter) + # Pass only torch.Tensor to prevent errors when process get_iterator_k_split() + batch = {k: v for k, v in batch.items() if isinstance(v, torch.Tensor)} + _, seq_length = batch['tokens'].shape + # handle the case where the batch size from dynamic bucketting is not divisible in lhotse + data_iter = get_iterator_k_split(batch, get_num_microbatches(), enforce_divisible_batch=False) + + # handle asynchronous grad reduction + no_sync_func = None + grad_sync_func = None + param_sync_func = None + if not forward_only and self.with_distributed_adam: + no_sync_func = partial( + self._optimizer.no_sync, + greedy_grad_copy=self.megatron_amp_O2, + ) + grad_sync_func = self.reduce_overlap_gradients + param_sync_func = self.sync_overlap_parameters + + self.model.config.no_sync_func = no_sync_func + self.model.config.grad_sync_func = grad_sync_func + self.model.config.param_sync_func = param_sync_func + + fwd_bwd_function = get_forward_backward_func() + + dec_seq_length = batch['answers'].shape[1] + + losses_reduced_per_micro_batch = fwd_bwd_function( + forward_step_func=self.get_forward_output_and_loss_func(), + data_iterator=data_iter, + model=[self.model], + num_microbatches=get_num_microbatches(), + forward_only=forward_only, + seq_length=seq_length, + micro_batch_size=get_micro_batch_size(), + decoder_seq_length=dec_seq_length, + ) + + # only the last stages of the pipeline return losses + if losses_reduced_per_micro_batch: + if (not forward_only) or self.cfg.data.get('validation_drop_last', True): + # average loss across micro batches + loss_tensors_list = [loss_reduced['avg'] for loss_reduced in losses_reduced_per_micro_batch] + loss_tensor = torch.concat(loss_tensors_list) + loss_mean = loss_tensor.mean() + else: + # Get the total loss since micro batches sizes are not uniform + loss_sum_tensors_list = [ + loss_sum['loss_sum_and_ub_size'] + for loss_sum in losses_reduced_per_micro_batch + if loss_sum['loss_sum_and_ub_size'][1] > 0 + ] + loss_sum = ( + torch.vstack(loss_sum_tensors_list).sum(axis=0) + if len(loss_sum_tensors_list) > 0 + else torch.tensor([0.0, 0.0]).cuda() + ) + return loss_sum + else: + # we're not on the last pipeline stage so no losses + if forward_only: + loss_mean = [] + else: + loss_mean = torch.tensor(0.0).cuda() + + return loss_mean + + def loss_func(self, loss_mask, output_tensor): + losses = output_tensor.float() + loss_mask = loss_mask.view(-1).float() + loss = torch.sum(losses.view(-1) * loss_mask) / loss_mask.sum() # sequence level nll + return loss + + def _determine_log_key(self, data_config, dataloader_idx, metric_name, mode): + # Function that determines whether to log based on the user provided name of the dataset or the dataloader index. + base_key = f"{mode}_{metric_name}_" if metric_name is not None else f"{mode}_" + # If the user provided names for each validation/test dataset, use those. + if hasattr(data_config, "names") and data_config.names is not None: + # With only a single validation/test dataset, the name is not a list. + if not isinstance(data_config.names, ListConfig): + name = data_config.names + else: + name = data_config.names[dataloader_idx] + return base_key + name + else: + return base_key + f"dataloader{dataloader_idx}" + + def test_step(self, dataloader_iter, dataloader_idx=0): + return self.inference_step(dataloader_iter, 'test') + + def training_step(self, dataloader_iter): + batch, batch_idx, dataloader_idx = next(dataloader_iter) + return super().training_step(itertools.chain([batch]), batch_idx=batch_idx) + + def setup_mcore_distributed_parallel(self): + """Set up mcore distributed data parallel called by configure_ddp in nlp_overrides.""" + if self.with_distributed_adam and self.use_mcore_dist_optim: + raise ValueError("T5 does not support both distributed adam and mcore distributed data parallel.") + + +class DecoderTextPromptModularizedAudioT5Model(ModularizedAudioT5Model): + """Modularized speech GPT model.""" + + def prepare_llm_input(self, audio_batch): + + input_signal = audio_batch['audio_signal'] + input_signal_length = audio_batch['audio_signal_length'] + + # [b, t, c] + encoded, encoded_len = self.perception( + input_signal=input_signal, + input_signal_length=input_signal_length, + processed_signal=None, + processed_signal_length=None, + ) + encoder_input, attention_mask, encoder_length = encoded, None, encoded_len + # generate encoder_mask from encoder_length + enc_mask = torch.arange(encoder_input.shape[1], device=encoder_input.device)[None, :] < encoder_length[:, None] + return encoder_input, attention_mask, enc_mask + + def forward( + self, + audio_batch, + checkpoint_activations_all_layers, + ): + """Forward pass of the model. + + We prepend audio embeddings to the instruction and label text tokens + as the LLM input. + """ + if 'audio_ratio' in audio_batch: + self.log( + 'local_batch_size', + audio_batch['audio_ratio'].shape[0], + prog_bar=True, + batch_size=1, + rank_zero_only=False, + ) + + encoder_input, _, enc_mask = self.prepare_llm_input(audio_batch) + # enc_input = speech prompt + # dec_input and label = text prompt and text output label + dec_input = audio_batch['tokens'] + labels = audio_batch['labels'] + dec_mask = (dec_input != self.tokenizer.eos_id) * (dec_input != self.tokenizer.pad_id).long().contiguous() + output = self.frozen_model.enc_dec_model( + enc_input_ids=None, + enc_attn_mask=enc_mask, + dec_input_ids=dec_input, + dec_attn_mask=dec_mask, + token_type_ids=None, + labels=labels, + output_enc_hidden_only=False, + enc_input=encoder_input, + ) + loss_mask = audio_batch['loss_mask'] + return output, loss_mask + + def predict_step(self, batch: Any, batch_idx: int, dataloader_idx: int = 0) -> Any: + + batch = move_to_device(batch, device=self.device) + encoder_input, _, enc_mask = self.prepare_llm_input(batch) + # enc_input = speech prompt + # dec_input and label = text prompt and text output label + + predicted_token_ids, log_probs = self.frozen_model.decode( + tokens_enc=None, + enc_mask=enc_mask, + num_tokens_to_generate=self._inference_config['tokens_to_generate'], + encoder_input=encoder_input, + tokenizer=self.tokenizer, + bos_id=self.bos_id, + predicted_tokens_dec=torch.cat( + [ + batch['contexts'], + torch.full_like(batch['contexts'][:, :1], self.sep_id, device=batch['contexts'].device), + ], + dim=1, + ), + ) + predicted_token_ids = predicted_token_ids[:, batch['contexts'].shape[1] + 1 :] + + # Special ids to text function to handle stripping and special tokens with sentencepiece tokenizers. + input_text = batch['contexts'] + preds_text = MegatronT5SFTModel.ids_to_text(predicted_token_ids, self.tokenizer) + input_text = MegatronT5SFTModel.ids_to_text(input_text, self.tokenizer) + labels = batch['answers'] + + if labels is not None: + labels_text = MegatronT5SFTModel.ids_to_text(labels, self.tokenizer) + else: + labels_text = [None] * len(preds_text) + + return { + 'input_text': input_text, + 'preds_text': preds_text, + 'labels_text': labels_text, + } + + def _build_dataset(self, data_cfg, is_train=True): + # this is crucial so as to tell the decoder when to start generate answer after context and paddings + assert data_cfg.add_sep == True + return super()._build_dataset(data_cfg, is_train) diff --git a/nemo/collections/multimodal/speech_llm/modules/common/audio_text_generation_strategy.py b/nemo/collections/multimodal/speech_llm/modules/common/audio_text_generation_strategy.py index 0cd48502bb84f..763e03b699cd2 100644 --- a/nemo/collections/multimodal/speech_llm/modules/common/audio_text_generation_strategy.py +++ b/nemo/collections/multimodal/speech_llm/modules/common/audio_text_generation_strategy.py @@ -18,7 +18,7 @@ import nemo.collections.nlp.modules.common.text_generation_strategy as text_generation_strategy from nemo.collections.multimodal.speech_llm.parts.utils.data_utils import shift_tokens_by_multi_audios - +from nemo.collections.nlp.modules.common.megatron.utils import build_position_ids # the text representation of eos_id, it applies for all tokenizers END_OF_SEQ = '<|endoftext|>' @@ -166,10 +166,121 @@ def end_of_generation_condition( return torch.tensor(conditions, dtype=torch.bool, device=tokens.device) +class CrossAttendAudioToTextGenerationStrategy(AudioToTextGenerationStrategy): + def init_batch( + self, + context_tokens: torch.Tensor, + context_lengths: torch.Tensor, + audio_signal: torch.Tensor, + audio_length: torch.Tensor, + compute_attention_mask: bool, + num_audios: Optional[torch.Tensor] = None, + context_start_idx: Optional[List[List[int]]] = None, + ): + """initialize the batch data before the inference steps.""" + # Move to GPU. + batch = { + 'audio_signal': audio_signal, + 'audio_signal_length': audio_length, + 'tokens': context_tokens, + 'tokens_length': context_lengths, + 'labels': context_tokens, + 'loss_mask': None, + } + if self.model.perception.cfg.get('combine_return', True): + ( + encoder_input, + self.attention_mask, + context_tokens, + _, + (speech_encoded, speech_encoded_len, extra_outputs), + ) = self.model.prepare_llm_input(batch) + self.position_ids = build_position_ids(encoder_input[:, :, 0].transpose(0, 1)) + self.extra_outputs = extra_outputs + return ( + context_tokens, + (encoder_input, speech_encoded, speech_encoded_len), + torch.zeros_like(context_lengths), + ) + else: + ( + encoder_input, + self.attention_mask, + context_tokens, + _, + (speech_encoded, speech_encoded_len, llm_encoded_len, extra_outputs), + ) = self.model.prepare_llm_input(batch) + self.position_ids = build_position_ids(encoder_input[:, :, 0].transpose(0, 1)) + self.extra_outputs = extra_outputs + return context_tokens, (encoder_input, speech_encoded, speech_encoded_len), llm_encoded_len + + def prepare_batch_at_step( + self, + tokens: torch.Tensor, + input_embeddings: Tuple[torch.Tensor, torch.Tensor, torch.Tensor], + maxlen: int, + micro_batch_size: int, + step: int, + context_lengths: torch.Tensor, + curr_context_length: int, + compute_attention_mask: bool, + ) -> Tuple[List[torch.Tensor], List[int]]: + # types2use = None + self.input_embeds_hidden = self.extra_outputs.get('input_embeds_hidden', None) + input_embeddings, speech_encoded, speech_encoded_len = input_embeddings + if step == 0: + # Allocate memory for the entire context. + set_inference_key_value_memory = True + tokens2use = tokens[:, :curr_context_length] + positions2use = self.position_ids[:, :curr_context_length] + embeddings2use = input_embeddings[:curr_context_length] + else: + # Set this to false so the memory is not reallocated. + set_inference_key_value_memory = False + tokens2use = tokens[:, curr_context_length - 1].view(micro_batch_size, -1) + positions2use = self.position_ids[:, curr_context_length - 1].view(micro_batch_size, -1) + embeddings2use = self.model._get_text_embeddings(tokens2use, positions2use).transpose(0, 1) + started = context_lengths <= curr_context_length + # for seq started, first get embeddings2use, and then run cross attend, after that replace embeddings2use with the cross attended embed + # use speech_encoded; rerun cross attend + # [1, b, d] + decoder_mems_list = self.extra_outputs.get('decoder_mems_list', None) + if decoder_mems_list is not None: + decoder_mems_list = decoder_mems_list[:, :, : curr_context_length - 1] + # need to use audio_ratio field if to support text-only decoding + embeddings2use, self.extra_outputs = self.model.perception_cross_attn( + speech_encoded, + speech_encoded_len, + embeddings2use, + input_lengths=tokens2use.squeeze(-1) != self.model.tokenizer.eos_id, + decoder_mems_list=decoder_mems_list, + return_mems=True, + ) + self.input_embeds_hidden = self.extra_outputs.get('input_embeds_hidden', None) + embeddings2use = switch( + input_embeddings[curr_context_length - 1].unsqueeze(0), embeddings2use.transpose(0, 1), started + ) + + """Prepare batch for each of the inference steps""" + setkey_value_array = torch.tensor( + [set_inference_key_value_memory] * micro_batch_size, device=torch.cuda.current_device() + ) + len_array = torch.tensor([maxlen] * micro_batch_size, device=torch.cuda.current_device()) + + batch = [tokens2use, embeddings2use, self.attention_mask, positions2use, setkey_value_array, len_array] + tensor_shape = [tokens2use.shape[1], micro_batch_size, self.model.cfg.hidden_size] + return batch, tensor_shape + + def model_inference_strategy_dispatcher(model, **args): - from nemo.collections.multimodal.speech_llm.models.modular_models import ModularAudioGPTModel + from nemo.collections.multimodal.speech_llm.models.modular_models import ( + CrossAttendModularAudioGPTModel, + ModularAudioGPTModel, + ) - if isinstance(model, ModularAudioGPTModel): + if isinstance(model, CrossAttendModularAudioGPTModel): + return CrossAttendAudioToTextGenerationStrategy(model, **args) + elif isinstance(model, ModularAudioGPTModel): return AudioToTextGenerationStrategy(model, **args) else: return text_generation_strategy.model_inference_strategy_dispatcher(model, **args) diff --git a/nemo/collections/multimodal/speech_llm/modules/modality_adapters.py b/nemo/collections/multimodal/speech_llm/modules/modality_adapters.py index 408231adcc6d9..9138845c73bdf 100644 --- a/nemo/collections/multimodal/speech_llm/modules/modality_adapters.py +++ b/nemo/collections/multimodal/speech_llm/modules/modality_adapters.py @@ -132,3 +132,15 @@ def forward(self, audio_signal, length=None): outputs = self.mlp(outputs) outputs_len = torch.div(length, self.pooling_factor, rounding_mode='floor') return outputs.transpose(1, 2), outputs_len + + +class IdentityConnectors(NeuralModule, Exportable, AccessMixin): + def __init__( + self, + *args, + **kwargs, + ): + super().__init__() + + def forward(self, audio_signal, length=None, *args, **kwargs): + return audio_signal, length diff --git a/nemo/collections/multimodal/speech_llm/modules/perception_modules.py b/nemo/collections/multimodal/speech_llm/modules/perception_modules.py index 2f0565982941d..a42c7d06cba0a 100644 --- a/nemo/collections/multimodal/speech_llm/modules/perception_modules.py +++ b/nemo/collections/multimodal/speech_llm/modules/perception_modules.py @@ -23,12 +23,12 @@ from nemo.collections.asr.models import EncDecSpeakerLabelModel from nemo.collections.asr.modules.conformer_encoder import ConformerEncoder, ConformerMultiLayerFeatureExtractor from nemo.collections.multimodal.speech_llm.parts.utils.data_utils import align_feat_seq_list +from nemo.collections.nlp.modules.common.transformer.transformer_decoders import TransformerDecoder from nemo.core.classes import Exportable, NeuralModule from nemo.core.classes.common import typecheck from nemo.core.neural_types import AcousticEncodedRepresentation, AudioSignal, LengthsType, NeuralType, SpectrogramType from nemo.utils.decorators import experimental - __all__ = ["AudioPerceptionModule", "MultiAudioPerceptionModule"] @@ -70,6 +70,7 @@ def output_types(self): def __init__(self, cfg: DictConfig): super().__init__() # Initialize components + self.cfg = cfg self.preprocessor = self.from_config_dict(cfg.preprocessor) self.encoder = self.from_config_dict(cfg.encoder) @@ -429,3 +430,76 @@ def forward( # b, c, t -> b, t, c encoded = self.proj(encoded.transpose(1, 2)) return encoded, encoded_len + + +def lens_to_mask(lens, max_length): + batch_size = lens.shape[0] + mask = torch.arange(max_length).repeat(batch_size, 1).to(lens.device) < lens[:, None] + return mask + + +class TransformerCrossAttention(NeuralModule, Exportable): + """Transformer module for cross-attention between speech and text embeddings. + The module allows optional projection from the input embeddings to a lower dimension before feeding them to the transformer. + Args: + cfg: DictConfig, configuration object for the module which should include: + xattn: DictConfig, configuration object for the transformer decoder + """ + + def __init__(self, cfg: DictConfig, *args, **kwargs): + super().__init__() + xformer_num_layers = cfg.xattn.get('xformer_num_layers', 2) + xformer_dims = cfg.xattn.get('xformer_dims', cfg.output_dim) + self.cfg = cfg + cross_attn_cfg = cfg.xattn + if xformer_dims != cfg.output_dim: + self.input_proj1 = nn.Linear(cfg.output_dim, xformer_dims) + self.input_proj2 = nn.Linear(cfg.output_dim, xformer_dims) + self.output_proj = nn.Linear(xformer_dims, cfg.output_dim) + else: + self.input_proj1 = nn.Identity() + self.input_proj2 = nn.Identity() + self.output_proj = nn.Identity() + # causal attention decoder by default + self.xattn_decoder = TransformerDecoder( + hidden_size=xformer_dims, + num_layers=xformer_num_layers, + inner_size=1 * xformer_dims, + num_attention_heads=cross_attn_cfg.num_attention_heads, + ffn_dropout=cross_attn_cfg.ffn_dropout, + attn_score_dropout=cross_attn_cfg.attn_score_dropout, + attn_layer_dropout=cross_attn_cfg.attn_layer_dropout, + hidden_act=cross_attn_cfg.hidden_act, + pre_ln=cross_attn_cfg.pre_ln, + pre_ln_final_layer_norm=cross_attn_cfg.pre_ln_final_layer_norm, + ) + + def forward( + self, + encoder_states, + encoded_len, + input_embeds, + input_lengths, + decoder_mems_list=None, + return_mems=False, + ): + assert input_embeds.shape[-1] == encoder_states.shape[-1] + enc_mask = lens_to_mask(encoded_len, encoder_states.shape[1]).to(encoder_states.dtype) + dec_mask = lens_to_mask(input_lengths, input_embeds.shape[1]).to(input_lengths.dtype) + y = self.xattn_decoder( + decoder_states=self.input_proj1(input_embeds), + decoder_mask=dec_mask, + encoder_states=self.input_proj2(encoder_states), + encoder_mask=enc_mask, + decoder_mems_list=decoder_mems_list, + return_mems=return_mems, + return_mems_as_list=False, + ) + if return_mems: + extra_outpus = {'decoder_mems_list': y} + y = y[-1][:, -input_embeds.shape[1] :] + else: + extra_outpus = {} + y = self.output_proj(y) + input_embeds + assert y.shape == input_embeds.shape + return y, extra_outpus diff --git a/nemo/collections/multimodal/speech_llm/parts/utils/data_utils.py b/nemo/collections/multimodal/speech_llm/parts/utils/data_utils.py index 92a3548f9337f..d638281950b46 100644 --- a/nemo/collections/multimodal/speech_llm/parts/utils/data_utils.py +++ b/nemo/collections/multimodal/speech_llm/parts/utils/data_utils.py @@ -16,6 +16,7 @@ import numpy as np import torch +from nemo.utils import logging, logging_mode def maybe_cast_to_list(x): @@ -155,3 +156,227 @@ def align_feat_seq_list( new_seq_list.append(new_seq) new_seq_len_list.append(new_seq_len) return new_seq_list, new_seq_len_list + + +def build_loss_mask(processed_example: dict, answer_only_loss: bool = True): + """Pad input_ids in batch to max batch length while building loss mask""" + # function copied from nemo/collections/nlp/data/language_modelling/megatron/gpt_sft_dataset.py + input_ids = processed_example['input_ids'] + answer_start_idx = processed_example['answer_start_idx'] + if answer_only_loss: + loss_mask = [float(idx >= answer_start_idx) for idx in range(len(input_ids))] + else: + loss_mask = [1.0] * len(input_ids) + + return loss_mask + + +class TextProcessing: + """ + Text processing pipeline for speech_llm data loader. + This class is adapted from the one used in nemo/collections/nlp/data/language_modeling/megatron/gpt_sft_dataset.py + The class follows the interface of _process_example which takes in a context and an output + and processes them into a formatted training example. + + Args: + tokenizer: text tokenizer object + max_seq_length (int): maximum sequence length for each dataset examples. Examples will either be truncated to fit this length or dropped if they cannot be truncated. + min_seq_length (int): min length of each data example in the dataset. Data examples will be dropped if they do not meet the min length requirements. + add_bos (bool): Whether to add a beginning of sentence token to each data example + add_eos (bool): Whether to add an end of sentence token to each data example + add_sep (bool): Whether to add a separation token to each data example (goes between prompt and answer) + sep_id (int): The id of the separation token + separate_prompt_and_response_with_newline (bool): Whether to separate the prompt and response with a newline character + answer_only_loss (bool): Whether to compute the loss only on the answer part of the input + truncation_field (str): Field to use for truncation. (Options: "answer", "context"). Field to be used for truncation if the combined length exceeds the max sequence length. + pad_to_max_length (bool): Whether to pad the input to the max sequence length. If False, will pad to the max length of the current batch. + prompt_template (str): Prompt template to inject via an fstring. Formatted like Q: {input}\n\nA: {output} + virtual_tokens (int): Number of virtual tokens to add to the beginning of the input + tokens_to_generate (int): Number of tokens to generate during inference + context_key (str): Key to use for the context in your JSONL file + answer_key (str): Key to use for the label in your JSONL file + end_string (Optional[str]): If not None, add this string to the end of the answer. + sample_alpha (Optional[float]): For SPE subword sampling + input_text_mask_ratio (Optional[float]): If not None, will mask the input text at this ratio. + """ + + def __init__( + self, + tokenizer: 'nemo.collections.common.tokenizers.TokenizerSpec', + max_seq_length: int = 1024, + min_seq_length: int = 1, + add_bos: bool = False, + add_eos: bool = True, + add_sep: bool = False, + sep_id: Optional[int] = None, + seed: int = 1234, + separate_prompt_and_response_with_newline: bool = False, + answer_only_loss: bool = True, + truncation_field: str = "answer", + pad_to_max_length: bool = False, # (@adithyare) allows for much faster training especially in PEFT settings. + prompt_template: str = None, + virtual_tokens: int = 0, + tokens_to_generate: int = 0, + context_key: str = 'context', + answer_key: str = 'answer', + end_string: Optional[str] = None, + sample_alpha: Optional[float] = None, + audio_locator: Optional[str] = None, + ): + self.context_key = context_key + self.answer_key = answer_key + self.tokenizer = tokenizer + self.max_seq_length = max_seq_length + self.min_seq_length = min_seq_length + self.seed = seed + self.separate_prompt_and_response_with_newline = separate_prompt_and_response_with_newline + self.answer_only_loss = answer_only_loss + self.truncation_field = truncation_field + self.pad_to_max_length = pad_to_max_length + self.prompt_template = prompt_template + self.virtual_tokens = virtual_tokens + self.tokens_to_generate = tokens_to_generate + self.add_bos = add_bos + self.add_eos = add_eos + self.add_sep = add_sep + self.end_string = end_string + self.sample_alpha = sample_alpha + self.audio_locator = audio_locator + + if add_bos and hasattr(tokenizer, "bos_id") and tokenizer.bos_id > 0: + self.bos_id = tokenizer.bos_id + else: + self.bos_id = None + + if add_eos and hasattr(tokenizer, "eos_id") and tokenizer.eos_id > 0: + self.eos_id = tokenizer.eos_id + else: + self.eos_id = None + + if hasattr(tokenizer, "pad_id") and tokenizer.pad_id > 0: + self.pad_id = tokenizer.pad_id + else: + self.pad_id = self.eos_id if self.eos_id is not None else 0 + + self.sep_id = sep_id if add_sep else None + + if self.prompt_template is not None: + # When providing things like newlines in the prompt template via the CLI, they are escaped. This line unescapes them. + self.prompt_template = self.prompt_template.encode('utf-8').decode('unicode_escape') + assert self.truncation_field in ["answer", "context"] + + def _process_example(self, context: str, output: str): + """ + Create an example by concatenating text and answer. + Truncation is carried out when needed, but it is performed only on the prompt side. + BOS, EOS, and SEP, are added if specified. + + function copied from nemo/collections/nlp/data/language_modelling/megatron/gpt_sft_dataset.py + """ + if self.prompt_template is not None: + if self.context_key not in self.prompt_template or self.answer_key not in self.prompt_template: + if "input" in self.prompt_template and "output" in self.prompt_template: + logging.warning( + f"Using 'input' and 'output' as context and answer keys, since given ones ({self.context_key}, {self.answer_key}) are not found in the prompt template: {self.prompt_template}.", + mode=logging_mode.ONCE, + ) + self.context_key = "input" + self.answer_key = "output" + assert f'{{{self.context_key}}}' in self.prompt_template + assert f'{{{self.answer_key}}}' in self.prompt_template + # Make sure that '{output}' always occurs at the end of the prompt template string + assert self.prompt_template.index(f'{{{self.answer_key}}}') == len(self.prompt_template) - len( + f'{{{self.answer_key}}}' + ) + # Get the context by replacing only the input + original_context = context + context = ( + self.prompt_template.replace(f'{{{self.context_key}}}', context) + .replace(f'{{{self.answer_key}}}', '') + .strip(' ') + ) + # Replace the input and output placeholders with the actual input and output + text = self.prompt_template.replace(f'{{{self.context_key}}}', original_context).replace( + f'{{{self.answer_key}}}', output + ) + + elif self.separate_prompt_and_response_with_newline: + text = context + '\n' + output + else: + text = context + ' ' + output + + if self.virtual_tokens: + # (@adithyare) we are going to insert "pad/eos" tokens in the beginning of the text and context + # these pad/eos tokens are placeholders for virtual tokens + pre_pad = [self.tokenizer.eos_id] * self.virtual_tokens + else: + pre_pad = [] + answer_text = text[len(context) :] + answer_ids = pre_pad + self.tokenizer.text_to_ids(answer_text, self.sample_alpha) + if self.end_string: + answer_ids += self.tokenizer.text_to_ids(self.end_string) + + if self.audio_locator is None: + # signle audio case + context_ids = self.tokenizer.text_to_ids(context) + context_start_idx = [0] + else: + # multiple audio case + context_ids = [] + context_start_idx = [] + for context_seg in context.split(self.audio_locator): + context_start_idx.append(len(context_ids)) + context_ids.extend(self.tokenizer.text_to_ids(context_seg)) + context_ids = pre_pad + context_ids + context_start_idx = [x + len(pre_pad) for x in context_start_idx] + + # for the long context cases, collate_fn includes self.tokens_to_generate for padding + total_ids = len(context_ids) + max(len(answer_ids), self.tokens_to_generate) + if self.add_bos: + total_ids += 1 + if self.add_sep: + total_ids += 1 + if self.add_eos: + total_ids += 1 + + # If the total number of token is greater than the max, we will try to truncate the answer + if total_ids > self.max_seq_length: + truncation_length = total_ids - self.max_seq_length + answer_ids = answer_ids[: -min(truncation_length, len(answer_ids))] + context_ids = context_ids[: -min(truncation_length, len(context_ids))] + + input_ids = context_ids + answer_start_idx = len(input_ids) + + # Adds bos token in the start + if self.add_bos: + context_ids = [self.bos_id] + context_ids + input_ids = [self.bos_id] + input_ids + answer_start_idx += 1 + + # Adds sep token between text/prompt and answer + if self.add_sep: + context_ids = context_ids + [self.sep_id] + input_ids = input_ids + [self.sep_id] + answer_start_idx += 1 + + input_ids = input_ids + answer_ids + + if self.add_eos: + input_ids = input_ids + [self.tokenizer.eos_id] + answer_ids = answer_ids + [self.tokenizer.eos_id] + + if len(input_ids) > self.max_seq_length: + logging.warning(f'Input ids length {len(input_ids)} exceed max sequence length {self.max_seq_length}') + input_ids = input_ids[: self.max_seq_length] + + processed_example = { + 'input_ids': (input_ids), + 'answer_start_idx': (answer_start_idx), + 'context_ids': (context_ids), + 'context_length': len(context_ids), + 'answer_ids': (answer_ids), + 'context_start_idx': context_start_idx, + } + + return processed_example diff --git a/nemo/collections/nlp/models/language_modeling/megatron_base_model.py b/nemo/collections/nlp/models/language_modeling/megatron_base_model.py index b2594731d1777..29f3e8905f913 100644 --- a/nemo/collections/nlp/models/language_modeling/megatron_base_model.py +++ b/nemo/collections/nlp/models/language_modeling/megatron_base_model.py @@ -421,7 +421,7 @@ def _build_tokenizer(self): legacy = True if self._cfg.tokenizer.library == 'sentencepiece' else False self.tokenizer = get_nmt_tokenizer( library=self._cfg.tokenizer.library, - model_name=self._cfg.tokenizer.type, + model_name=self._cfg.tokenizer.get("type", None), tokenizer_model=self.register_artifact("tokenizer.model", self._cfg.tokenizer.get('model', None)), vocab_file=self.register_artifact("tokenizer.vocab_file", self._cfg.tokenizer.get('vocab_file', None)), merges_file=self.register_artifact("tokenizer.merge_file", self._cfg.tokenizer.get('merge_file', None)), diff --git a/nemo/collections/nlp/models/language_modeling/megatron_base_prompt_learning_model.py b/nemo/collections/nlp/models/language_modeling/megatron_base_prompt_learning_model.py index 4d4cc09d0751e..d151925635ab6 100644 --- a/nemo/collections/nlp/models/language_modeling/megatron_base_prompt_learning_model.py +++ b/nemo/collections/nlp/models/language_modeling/megatron_base_prompt_learning_model.py @@ -63,26 +63,29 @@ class MegatronBasePromptLearningModel(MegatronBaseModel, TextGeneration): """ - Model class for prompt-tuning or p-tuning a pretrained Megatron model. + Model class for prompt-tuning or p-tuning a pretrained Megatron model. Prompt Tuning initalizes virtual prompt embeddings directly from a copy of certain token embeddings from the the pretrained model's vocabulary - and directly tunes these embedding weights. The token embeddings used in - initalization are specified by the user in the config file. The model can - be prompt-tuned for multiple tasks at once. virtual prompts are stored in a - prompt table and can be added or deleted without disrupting virtual prompts - for other tasks. + and directly tunes these embedding weights. The token embeddings used in + initalization are specified by the user in the config file. The model can + be prompt-tuned for multiple tasks at once. virtual prompts are stored in a + prompt table and can be added or deleted without disrupting virtual prompts + for other tasks. P-tuning initializes an LSTM encoder model that generates virtual prompt embeddings for every task. Each task shares the same encoder. After ptuning is compelete, the learned virtual prompts can be saved to the prompt table - using add_ptuned_prompts_to_prompt_table(). Thus, if a user wants to add a - new virtual prompt via p-tuning, they do not need to retrain on all previous + using add_ptuned_prompts_to_prompt_table(). Thus, if a user wants to add a + new virtual prompt via p-tuning, they do not need to retrain on all previous tasks. This gives p-tuning the same task flexiblity as prompt-tuning. """ def __init__(self, cfg: DictConfig, trainer: Trainer): super().__init__(cfg, trainer) + self.init_model(cfg, trainer) + + def init_model(self, cfg: DictConfig, trainer: Trainer): self.config: ModelParallelConfig = self.model_parallel_config @@ -156,10 +159,10 @@ def __init__(self, cfg: DictConfig, trainer: Trainer): def load_task_templates(self, task_templates): """ - Takes in the task template portion of the config and turns - it into a table where each task's prompt template and - the number of virtual tokens to insert in a given part of - the prompt template are specified. + Takes in the task template portion of the config and turns + it into a table where each task's prompt template and + the number of virtual tokens to insert in a given part of + the prompt template are specified. """ self.task_templates = {} self.task_id_num_to_name = {} @@ -215,18 +218,17 @@ def init_prompt_encoder(self): ) def freeze_existing_word_embeddings(self): - """Freeze params of existing virtual prompts that should not be tuned further - """ + """Freeze params of existing virtual prompts that should not be tuned further""" # Make sure word embeddings are frozen for params in self.word_embeddings.parameters(): params.requires_grad = False def state_dict(self): """ - Custom state dict that only contains prompt table and prompt encoder parameters. - No frozen model parameters are stored in the state dict. Prompt encoder parameters + Custom state dict that only contains prompt table and prompt encoder parameters. + No frozen model parameters are stored in the state dict. Prompt encoder parameters are only in state dict for intermediate checkpoints saved during training. Final - nemo checkpoints at the end of training will contain prompt table parameters only. + nemo checkpoints at the end of training will contain prompt table parameters only. """ state_dict_ = {} @@ -241,7 +243,7 @@ def state_dict(self): def load_state_dict(self, state_dict, strict: bool = True): """ Custom load state dict method that only loads prompt table and prompt encoder - parameters. Matching load method for this class' custom state dict method. + parameters. Matching load method for this class' custom state dict method. """ if self.first_stage_of_pipeline(): if self.virtual_prompt_source == VirtualPromptSource.PROMPT_ENCODER: @@ -253,7 +255,7 @@ def load_state_dict(self, state_dict, strict: bool = True): def setup_optimizer_param_groups(self): """ - ModelPT override. Optimizer will get self._optimizer_param_groups. + ModelPT override. Optimizer will get self._optimizer_param_groups. Only want virtual prompt params to be passed to the optimizer. """ ## Freeze frozen model @@ -272,8 +274,8 @@ def setup_optimizer_param_groups(self): def embed_input(self, input_ids: Tensor, taskname_ids: Tensor, use_cached_reps: bool): """ - Replaces the virtual tokens in the input_ids with embeddings - calculated from either the 'prompt_table' or 'prompt_encoder'. + Replaces the virtual tokens in the input_ids with embeddings + calculated from either the 'prompt_table' or 'prompt_encoder'. The virtual token placeholders have token_ids listed in `self.pseudo_token_ids`. @@ -422,7 +424,7 @@ def load_frozen_model(self, cfg, trainer): def get_pseudo_tokens(num_virtual_tokens): """ Takes in an integer and returns a list of strings where each string - is a numbered virtual token placeholder. If + is a numbered virtual token placeholder. If num_virtual_tokens = 3, then this function returns: ["", "", ""] @@ -430,7 +432,7 @@ def get_pseudo_tokens(num_virtual_tokens): Args: num_virtual_tokens: (int) Number of virtual token strings you want to make - returns a list of string. + returns a list of string. """ pseudo_tokens = [ diff --git a/nemo/collections/nlp/models/language_modeling/megatron_gpt_sft_model.py b/nemo/collections/nlp/models/language_modeling/megatron_gpt_sft_model.py index 44a08e163c914..28bcbf22ac33e 100644 --- a/nemo/collections/nlp/models/language_modeling/megatron_gpt_sft_model.py +++ b/nemo/collections/nlp/models/language_modeling/megatron_gpt_sft_model.py @@ -100,6 +100,7 @@ def __init__(self, cfg: DictConfig, trainer: Trainer): self.virtual_tokens = 0 self.init_global_step = 0 + self.enforce_divisible_batch = True # used for gradient accumulation def setup_metric(self, data_cfg): metric_name = "exact_string_match" @@ -356,7 +357,7 @@ def fwd_bwd_step(self, dataloader_iter, forward_only, first_val_step=None): # Pass only torch.Tensor to prevent errors when process get_iterator_k_split() batch = {k: v for k, v in batch.items() if isinstance(v, (torch.Tensor, list))} _, seq_length = batch['tokens'].shape - data_iter = get_iterator_k_split(batch, get_num_microbatches()) + data_iter = get_iterator_k_split(batch, get_num_microbatches(), self.enforce_divisible_batch) if log_token_counts: self.log('seq_length_padded', seq_length, prog_bar=True, batch_size=1) diff --git a/nemo/collections/nlp/models/language_modeling/megatron_lm_encoder_decoder_model.py b/nemo/collections/nlp/models/language_modeling/megatron_lm_encoder_decoder_model.py index 90c6a40b1d403..8fe215bcc9af6 100644 --- a/nemo/collections/nlp/models/language_modeling/megatron_lm_encoder_decoder_model.py +++ b/nemo/collections/nlp/models/language_modeling/megatron_lm_encoder_decoder_model.py @@ -1206,6 +1206,10 @@ def dummy(): global_batch_per_gpu = tokens_enc.size(0) device = tokens_enc.device encoder_seq_length = tokens_enc.size(1) + elif encoder_input is not None: + global_batch_per_gpu = encoder_input.size(0) + device = encoder_input.device + encoder_seq_length = encoder_input.size(1) else: global_batch_per_gpu = enc_output.size(0) device = enc_output.device diff --git a/nemo/collections/nlp/modules/common/megatron/utils.py b/nemo/collections/nlp/modules/common/megatron/utils.py index 75c50146bfabd..5aaac6755601f 100644 --- a/nemo/collections/nlp/modules/common/megatron/utils.py +++ b/nemo/collections/nlp/modules/common/megatron/utils.py @@ -15,11 +15,10 @@ """Utilities for models.""" import itertools import math -from typing import Dict, Iterator, List, Tuple, Union +from typing import Dict, Iterator, List, Optional, Tuple, Union import torch import torch.nn as nn - from torch import Tensor from nemo.utils import logging, logging_mode @@ -413,16 +412,19 @@ def get_all_params_for_weight_decay_optimization( return tuple(filter(lambda g: len(g['params']) > 0, param_groups)) -def split_list(inputs, num_chunks): +def split_list(inputs, num_chunks, enforce_divisible_batch: Optional[bool] = True): """ Split a list into equal sized chunks """ chunk_size = len(inputs) // num_chunks - assert len(inputs) % chunk_size == 0, "Issue with batch size configuration!" + if enforce_divisible_batch: + assert len(inputs) % chunk_size == 0, "Issue with batch size configuration!" return [inputs[i : i + chunk_size] for i in range(0, len(inputs), chunk_size)] -def get_iterator_k_split(batch: Union[Dict, List[torch.Tensor]], num_microbatches: int) -> Iterator: +def get_iterator_k_split( + batch: Union[Dict, List[torch.Tensor]], num_microbatches: int, enforce_divisible_batch: Optional[bool] = True +) -> Iterator: """ Split a batch into k microbatches, where the batch size is divisible by k. Batch could be a dictionary of tensors or a list of tensors. A dictionary batch could also have items of List type, @@ -442,8 +444,13 @@ def get_iterator_k_split(batch: Union[Dict, List[torch.Tensor]], num_microbatche # Split tensor items items = list(tensor_items.items()) - assert items[0][1].shape[0] % num_microbatches == 0, "Issue with batch size configuration!" + if enforce_divisible_batch: + assert items[0][1].shape[0] % num_microbatches == 0, "Issue with batch size configuration!" split_batch = [torch.tensor_split(item[1], num_microbatches, dim=0) for item in items] + # handle the case where the batch size from dynamic bucketting is not divisible + if items[0][1].shape[0] % num_microbatches != 0: + chunk_size = split_batch[0][-1].shape[0] + split_batch = [[j[:chunk_size] for j in i] for i in split_batch] if len(list_items) == 0: # Only have tensor items @@ -453,7 +460,10 @@ def get_iterator_k_split(batch: Union[Dict, List[torch.Tensor]], num_microbatche else: # Split list items list_items = list(list_items.items()) - split_list_batch = [split_list(item[1], num_microbatches) for item in list_items] + split_list_batch = [ + split_list(item[1], num_microbatches, enforce_divisible_batch=enforce_divisible_batch) + for item in list_items + ] # Merge tensor and list items all_keys = [item[0] for item in items] + [item[0] for item in list_items] all_split_batch = split_batch + split_list_batch