Merge pull request #2 from zhehuaichen/canary_speechllm1_speech_text_lhotse

include text-only into fprop of training and eval; TODO: text-only inference
pzelasko authored May 9, 2024
2 parents b87acb6 + 9b0ac1a commit d40a430
Showing 1 changed file with 58 additions and 39 deletions.
97 changes: 58 additions & 39 deletions nemo/collections/multimodal/speech_llm/models/modular_models.py
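Context for the diff below: the commit routes a mixed lhotse mini-batch through two forward paths. The batch split itself is outside this diff; the following is a hedged sketch (field names are assumptions) of how such a batch could be carved into the audio_batch / text_batch halves that forward() consumes, keying on a "text_" field-name prefix.

    import torch

    # Hedged sketch; the split is not shown in this diff and the field names
    # are assumptions. Text-only fields carry a "text_" prefix; everything
    # else belongs to the audio-text part.
    batch = {
        'audio_signal': torch.randn(2, 16000),
        'audio_signal_length': torch.tensor([16000, 12000]),
        'text_input_ids': torch.randint(0, 100, (2, 8)),
        'text_mask': torch.ones(2, 8, dtype=torch.long),
    }
    audio_batch = {k: v for k, v in batch.items() if not k.startswith('text_')}
    text_batch = {k: v for k, v in batch.items() if k.startswith('text_')}
    print(sorted(audio_batch), sorted(text_batch))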
@@ -359,6 +359,29 @@ def prepare_llm_input(self, audio_batch):
 
         return encoder_input, attention_mask, labels, loss_mask, encoder_length
 
+    def _gpt_forward(
+        self, input_ids, position_ids, encoder_input, attention_mask, labels, checkpoint_activations_all_layers
+    ):
+        """Forward pass of the GPT model."""
+        if self.mcore_gpt:
+            output = self.model(
+                input_ids=input_ids,
+                position_ids=position_ids,
+                decoder_input=encoder_input,
+                attention_mask=attention_mask,
+                labels=labels,
+            )
+        else:
+            output = self.model(
+                input_ids=input_ids,
+                position_ids=position_ids,
+                encoder_input=encoder_input,
+                attention_mask=attention_mask,
+                labels=labels,
+                checkpoint_activations_all_layers=checkpoint_activations_all_layers,
+            )
+        return output
+
     def forward(
         self, batch, checkpoint_activations_all_layers,
     ):
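Note: _gpt_forward only routes arguments; mcore GPT consumes the embedded sequence as decoder_input, while the legacy path takes encoder_input plus the activation-checkpointing flag. A minimal sketch (toy stub and names, not part of this commit) of the two mutually exclusive call modes that forward() uses below:

    import torch

    def gpt_forward_stub(input_ids=None, encoder_input=None):
        # Exactly one input mode is active per call: token ids (text-only)
        # or precomputed embeddings (audio prepended to text).
        assert (input_ids is None) != (encoder_input is None)
        if encoder_input is not None:
            return encoder_input.shape   # [batch, time, hidden]
        return input_ids.shape           # [batch, time]

    print(gpt_forward_stub(encoder_input=torch.randn(2, 10, 512)))     # audio-text path
    print(gpt_forward_stub(input_ids=torch.randint(0, 100, (2, 10))))  # text-only path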
@@ -370,40 +393,26 @@ def forward(
 
         output, loss_mask = None, None
 
+        multimodal_output = {}
         if audio_batch:
-            # TODO: handle the possibility that there might be no audio data in the mini-batch
+            if 'audio_ratio' in audio_batch:
+                self.log(
+                    'local_batch_size',
+                    audio_batch['audio_ratio'].shape[0],
+                    prog_bar=True,
+                    batch_size=1,
+                    rank_zero_only=False,
+                )
+
             encoder_input, attention_mask, labels, loss_mask, _ = self.prepare_llm_input(audio_batch)
-            if self.mcore_gpt:
-                output = self.model(
-                    input_ids=None,
-                    position_ids=None,
-                    decoder_input=encoder_input,
-                    attention_mask=attention_mask,
-                    labels=labels,
-                )
-            else:
-                output = self.model(
-                    input_ids=None,
-                    position_ids=None,
-                    encoder_input=encoder_input,
-                    attention_mask=attention_mask,
-                    labels=labels,
-                    checkpoint_activations_all_layers=checkpoint_activations_all_layers,
-                )
-
+            output = self._gpt_forward(
+                None, None, encoder_input, attention_mask, labels, checkpoint_activations_all_layers
+            )
+            multimodal_output['audio_text'] = (output, loss_mask)
         if text_batch:
-            pass # TODO: implement text-only mini-batch forward
+            input_ids = text_batch.text_input_ids[:, :-1]
+            labels = text_batch.text_input_ids[:, 1:]
+            attention_mask = self._create_attention_mask(input_ids)
+            loss_mask = text_batch.text_mask[:, 1:]
+            output = self._gpt_forward(
+                input_ids, None, None, attention_mask, labels, checkpoint_activations_all_layers
+            )
+            multimodal_output['text'] = (output, loss_mask)
         if not audio_batch and not text_batch:
            raise ValueError("No input data found for the model.")
 
-        return output, loss_mask
+        return multimodal_output
 
     def get_forward_output_only_func(self):
         def fwd_output_only_func(dataloader_iter, model):
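The text-only branch above applies the standard next-token shift before calling the GPT. A small sketch (hypothetical tensors) of the shift that text_input_ids[:, :-1] / [:, 1:] implement:

    import torch

    # Sketch (hypothetical tensors) of the next-token shift in the text-only
    # branch: the model reads tokens [t0..tN-1] and is scored against
    # [t1..tN]; the loss mask is shifted identically to stay aligned.
    text_input_ids = torch.tensor([[101, 7, 8, 9, 102]])
    text_mask = torch.ones_like(text_input_ids)

    input_ids = text_input_ids[:, :-1]  # model input: [101, 7, 8, 9]
    labels = text_input_ids[:, 1:]      # targets:     [7, 8, 9, 102]
    loss_mask = text_mask[:, 1:]        # aligned with the targets

    assert input_ids.shape == labels.shape == loss_mask.shape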
@@ -489,18 +498,26 @@ def fwd_output_and_loss_func(dataloader_iter, model, checkpoint_activations_all_
             if not self.mcore_gpt:
                 batch['checkpoint_activations_all_layers'] = checkpoint_activations_all_layers
 
-            output_tensor, loss_mask = self.forward(
+            multimodal_output = self.forward(
                 batch, checkpoint_activations_all_layers=checkpoint_activations_all_layers
             )
-            batch['loss_mask'] = loss_mask
 
-            def loss_func(output_tensor):
+            def loss_func(multimodal_output):
                 # Loss for a micro-batch (ub)
-                loss_for_ub = self.loss_func(batch['loss_mask'], batch['num_valid_tokens_in_ub'], output_tensor)
+                loss_for_ub = 0
+
+                for key, (output, loss_mask) in multimodal_output.items():
+                    cur_loss = self.loss_func(loss_mask, loss_mask.sum(), output)
+                    loss_for_ub += cur_loss
+                    self.log(
+                        f'{key}_loss', cur_loss.mean(), prog_bar=True, batch_size=1, rank_zero_only=False,
+                    )
+                    self.log(
+                        f'{key}_batch_size', loss_mask.shape[0], prog_bar=True, batch_size=1, rank_zero_only=False,
+                    )
+
                 cp_size = self.cfg.get('context_parallel_size', 1)
-                if self.cfg.data.get(
-                    "return_output_tensors", False
-                ): # TODO: need a better way to check if loss_func is returning more stuff than just loss... (@adithyare)
+                if self.cfg.data.get("return_output_tensors", False):
                     loss_for_ub, q_hs, d_hs, pos_cs, neg_cs, diff_cs = loss_for_ub
                     reduced_loss = average_losses_across_data_parallel_group([loss_for_ub])
                     pos_cs = average_losses_across_data_parallel_group([pos_cs])
@@ -540,7 +557,7 @@ def loss_func(output_tensor):
                 reduced_loss = average_losses_across_data_parallel_group([loss_for_ub])
                 return loss_for_ub * cp_size, {'avg': reduced_loss}
 
-            return output_tensor, loss_func
+            return multimodal_output, loss_func
 
         return fwd_output_and_loss_func
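With this change each modality contributes its own normalized loss, so one micro-batch can mix audio-text and text-only samples. A sketch of the accumulation pattern, with a hypothetical masked average standing in for self.loss_func(loss_mask, num_valid_tokens, output):

    import torch

    # Sketch of the per-modality accumulation in loss_func above; each
    # modality normalizes by its own number of valid (unmasked) tokens.
    def masked_mean(per_token_loss, loss_mask):
        return (per_token_loss * loss_mask).sum() / loss_mask.sum().clamp(min=1)

    multimodal_output = {
        'audio_text': (torch.rand(2, 6), torch.ones(2, 6)),  # (per-token loss, mask)
        'text': (torch.rand(2, 4), torch.ones(2, 4)),
    }

    loss_for_ub = 0.0
    for key, (output, loss_mask) in multimodal_output.items():
        cur_loss = masked_mean(output, loss_mask)
        loss_for_ub += cur_loss
        print(f'{key}_loss = {cur_loss.item():.4f}')
    print(f'total = {loss_for_ub.item():.4f}')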

@@ -1044,7 +1061,7 @@ def inference_step(self, dataloader_iter, mode):
         """
         Used for validation and test steps, added postprocessing after calling self.predict_step().
         """
-        # TODO: support text-only part of mini-batch
+        # Evaluation of multimodal data follows the same pattern as training except predict_step
         batch, batch_idx, dataloader_idx = next(dataloader_iter)
         data_cfg = self.cfg.data.validation_ds if mode == 'validation' else self.cfg.data.test_ds
         self._reconfigure_and_process_inference_batch(batch, data_cfg)
@@ -1132,7 +1149,7 @@ def predict_step(self, batch: dict, batch_idx: int, dataloader_idx: Optional[int
         """
         Used to get LLM predictions for validation and test steps based on the given inference config.
         """
-        # TODO: we expect only one modality in each batch of inference. In lhotse, can we specify a list of datasets which only have one modality either audio-text or text-only?
+        # TODO: support text-only part of mini-batch
+        # the following supports STT (audio-text) inference
 
         inference_config = self.get_inference_config()
         if inference_config is not None:
