From 9b0ac1a76b02caff2233b1373070bc0ef2dd602d Mon Sep 17 00:00:00 2001
From: zhehuaichen
Date: Wed, 8 May 2024 19:47:35 -0700
Subject: [PATCH] Include text-only data in the fprop of training and eval;
 TODO: support text-only predict

Signed-off-by: zhehuaichen
---
 .../speech_llm/models/modular_models.py      | 97 +++++++++++--------
 1 file changed, 58 insertions(+), 39 deletions(-)

diff --git a/nemo/collections/multimodal/speech_llm/models/modular_models.py b/nemo/collections/multimodal/speech_llm/models/modular_models.py
index 8ee548d432c8..9b4c3c611d5e 100644
--- a/nemo/collections/multimodal/speech_llm/models/modular_models.py
+++ b/nemo/collections/multimodal/speech_llm/models/modular_models.py
@@ -359,6 +359,29 @@ def prepare_llm_input(self, audio_batch):
 
         return encoder_input, attention_mask, labels, loss_mask, encoder_length
 
+    def _gpt_forward(
+        self, input_ids, position_ids, encoder_input, attention_mask, labels, checkpoint_activations_all_layers
+    ):
+        """Forward pass of the underlying GPT model, covering both the mcore and non-mcore paths."""
+        if self.mcore_gpt:
+            output = self.model(
+                input_ids=input_ids,
+                position_ids=position_ids,
+                decoder_input=encoder_input,
+                attention_mask=attention_mask,
+                labels=labels,
+            )
+        else:
+            output = self.model(
+                input_ids=input_ids,
+                position_ids=position_ids,
+                encoder_input=encoder_input,
+                attention_mask=attention_mask,
+                labels=labels,
+                checkpoint_activations_all_layers=checkpoint_activations_all_layers,
+            )
+        return output
+
     def forward(
         self, batch, checkpoint_activations_all_layers,
     ):
@@ -370,40 +393,26 @@ def forward(
         output, loss_mask = None, None
+        multimodal_output = {}
         if audio_batch:
-            # TODO: handle the possibility that there might be no audio data in the mini-batch
-            if 'audio_ratio' in audio_batch:
-                self.log(
-                    'local_batch_size',
-                    audio_batch['audio_ratio'].shape[0],
-                    prog_bar=True,
-                    batch_size=1,
-                    rank_zero_only=False,
-                )
-
             encoder_input, attention_mask, labels, loss_mask, _ = self.prepare_llm_input(audio_batch)
-            if self.mcore_gpt:
-                output = self.model(
-                    input_ids=None,
-                    position_ids=None,
-                    decoder_input=encoder_input,
-                    attention_mask=attention_mask,
-                    labels=labels,
-                )
-            else:
-                output = self.model(
-                    input_ids=None,
-                    position_ids=None,
-                    encoder_input=encoder_input,
-                    attention_mask=attention_mask,
-                    labels=labels,
-                    checkpoint_activations_all_layers=checkpoint_activations_all_layers,
-                )
-
+            output = self._gpt_forward(
+                None, None, encoder_input, attention_mask, labels, checkpoint_activations_all_layers
+            )
+            multimodal_output['audio_text'] = (output, loss_mask)
         if text_batch:
-            pass  # TODO: implement text-only mini-batch forward
+            input_ids = text_batch.text_input_ids[:, :-1]
+            labels = text_batch.text_input_ids[:, 1:]
+            attention_mask = self._create_attention_mask(input_ids)
+            loss_mask = text_batch.text_mask[:, 1:]
+            output = self._gpt_forward(
+                input_ids, None, None, attention_mask, labels, checkpoint_activations_all_layers
+            )
+            multimodal_output['text'] = (output, loss_mask)
+        if not audio_batch and not text_batch:
+            raise ValueError("No input data found for the model.")
 
-        return output, loss_mask
+        return multimodal_output
 
     def get_forward_output_only_func(self):
         def fwd_output_only_func(dataloader_iter, model):
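Note for reviewers: the text-only branch added to forward() builds its LLM inputs with a
standard next-token shift. Below is a minimal, standalone sketch of that indexing; the
tensor values are invented, and text_input_ids / text_mask stand in for the fields
consumed from text_batch.

    import torch

    # Invented stand-ins for text_batch.text_input_ids / text_batch.text_mask;
    # real batches come from the lhotse text dataloader.
    text_input_ids = torch.tensor([[101, 7, 8, 9, 102]])  # (batch, seq_len)
    text_mask = torch.tensor([[0, 1, 1, 1, 1]])           # 1 = token counts toward the loss

    # The same shift as in the text-only branch: predict token t+1 from tokens up to t.
    input_ids = text_input_ids[:, :-1]  # model input:   [[101, 7, 8, 9]]
    labels = text_input_ids[:, 1:]      # target tokens: [[7, 8, 9, 102]]
    loss_mask = text_mask[:, 1:]        # mask aligned with the shifted labels
    assert input_ids.shape == labels.shape == loss_mask.shape

The branch passes position_ids as None, matching the audio-text path above, and builds
the attention mask from the shifted inputs via self._create_attention_mask(input_ids).
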
@@ -489,18 +498,26 @@ def fwd_output_and_loss_func(dataloader_iter, model, checkpoint_activations_all_
             if not self.mcore_gpt:
                 batch['checkpoint_activations_all_layers'] = checkpoint_activations_all_layers
 
-            output_tensor, loss_mask = self.forward(
+            multimodal_output = self.forward(
                 batch, checkpoint_activations_all_layers=checkpoint_activations_all_layers
             )
-            batch['loss_mask'] = loss_mask
 
-            def loss_func(output_tensor):
+            def loss_func(multimodal_output):
                 # Loss for a micro-batch (ub)
-                loss_for_ub = self.loss_func(batch['loss_mask'], batch['num_valid_tokens_in_ub'], output_tensor)
+                loss_for_ub = 0
+
+                for key, (output, loss_mask) in multimodal_output.items():
+                    cur_loss = self.loss_func(loss_mask, loss_mask.sum(), output)
+                    loss_for_ub += cur_loss
+                    self.log(
+                        f'{key}_loss', cur_loss.mean(), prog_bar=True, batch_size=1, rank_zero_only=False,
+                    )
+                    self.log(
+                        f'{key}_batch_size', loss_mask.shape[0], prog_bar=True, batch_size=1, rank_zero_only=False,
+                    )
+
                 cp_size = self.cfg.get('context_parallel_size', 1)
-                if self.cfg.data.get(
-                    "return_output_tensors", False
-                ):  # TODO: need a better way to check if loss_func is returning more stuff than just loss... (@adithyare)
+                if self.cfg.data.get("return_output_tensors", False):
                     loss_for_ub, q_hs, d_hs, pos_cs, neg_cs, diff_cs = loss_for_ub
                     reduced_loss = average_losses_across_data_parallel_group([loss_for_ub])
                     pos_cs = average_losses_across_data_parallel_group([pos_cs])
@@ -540,7 +557,7 @@ def loss_func(output_tensor):
                 reduced_loss = average_losses_across_data_parallel_group([loss_for_ub])
                 return loss_for_ub * cp_size, {'avg': reduced_loss}
 
-            return output_tensor, loss_func
+            return multimodal_output, loss_func
 
         return fwd_output_and_loss_func
 
@@ -1044,7 +1061,7 @@ def inference_step(self, dataloader_iter, mode):
         """
         Used for validation and test steps, added postprocessing after calling self.predict_step().
         """
-        # TODO: support text-only part of mini-batch
+        # Evaluation of multimodal data follows the same pattern as training, except for predict_step
         batch, batch_idx, dataloader_idx = next(dataloader_iter)
         data_cfg = self.cfg.data.validation_ds if mode == 'validation' else self.cfg.data.test_ds
        self._reconfigure_and_process_inference_batch(batch, data_cfg)
@@ -1132,7 +1149,9 @@ def predict_step(self, batch: dict, batch_idx: int, dataloader_idx: Optional[int
        """
        Used to get LLM predictions for validation and test steps based on the given inference config.
        """
+        # TODO: we expect only one modality in each inference batch. In lhotse, can we specify a list of datasets that each contain a single modality, either audio-text or text-only?
        # TODO: support text-only part of mini-batch
+        # The following supports STT (audio-text) inference.
        inference_config = self.get_inference_config()
        if inference_config is not None:
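
Note for reviewers: in the new loss_func, each modality is normalized by its own
valid-token count (loss_mask.sum()) rather than the previous global
batch['num_valid_tokens_in_ub'], and the per-modality losses are then summed. A
standalone sketch of that aggregation follows; masked_token_loss is a hypothetical
stand-in for self.loss_func, which roughly averages the per-token losses over the
tokens selected by the mask, and the tensors are invented for illustration.

    import torch

    def masked_token_loss(loss_mask, num_valid_tokens, per_token_loss):
        # Hypothetical stand-in for self.loss_func: average the per-token losses
        # over the tokens selected by loss_mask.
        return torch.sum(per_token_loss * loss_mask) / num_valid_tokens

    # Toy per-modality outputs mirroring multimodal_output in the patch:
    # each entry is (per-token loss tensor, loss mask), both shaped (batch, seq).
    multimodal_output = {
        'audio_text': (torch.rand(2, 6), torch.ones(2, 6)),
        'text': (torch.rand(4, 5), torch.ones(4, 5)),
    }

    loss_for_ub = 0
    for key, (output, loss_mask) in multimodal_output.items():
        cur_loss = masked_token_loss(loss_mask, loss_mask.sum(), output)
        loss_for_ub += cur_loss  # modality losses are summed, as in the patch
        print(f"{key}_loss: {cur_loss.item():.4f}  {key}_batch_size: {loss_mask.shape[0]}")

    print(f"micro-batch loss: {loss_for_ub.item():.4f}")

One consequence of summing the two normalized losses is that the audio-text and
text-only parts of a mini-batch contribute equally to the micro-batch loss, regardless
of how many valid tokens each modality contributes.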