Add joint text/audio dataloading capability to speechllm
Signed-off-by: Piotr Żelasko <[email protected]>
pzelasko committed May 8, 2024
1 parent b30b99c commit b87acb6
Showing 3 changed files with 110 additions and 81 deletions.
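The commit lets one Lhotse dataloader serve mini-batches that mix audio cuts with text-only SFT examples; the dataset splits each mini-batch by example type, and the model routes the resulting sub-batches separately. Below is a minimal, self-contained sketch of that type-based split; it is not the repository's code, and ToyAudioCut and ToyTextExample are made-up stand-ins for lhotse's Cut and NeMo's NeMoSFTExample.

# Hypothetical illustration of splitting a mixed mini-batch by example type,
# mirroring the isinstance() filtering added to __getitem__ below.
from dataclasses import dataclass, field

@dataclass
class ToyAudioCut:      # stand-in for lhotse.cut.Cut
    id: str
    duration: float

@dataclass
class ToyTextExample:   # stand-in for NeMoSFTExample
    input_ids: list = field(default_factory=list)

def split_mixed_batch(examples):
    # Separate audio and text-only items so each can be collated on its own.
    audio = [e for e in examples if isinstance(e, ToyAudioCut)]
    text = [e for e in examples if isinstance(e, ToyTextExample)]
    return audio, text

mixed = [ToyAudioCut("utt1", 3.2), ToyTextExample([11, 12, 13]), ToyAudioCut("utt2", 1.7)]
audio_part, text_part = split_mixed_batch(mixed)
print(len(audio_part), len(text_part))  # -> 2 1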
@@ -145,6 +145,7 @@ def build_salm_dataloader(dataset, data_cfg, consumed_samples=0, is_predict=Fals
global_rank=parallel_state.get_data_parallel_rank(),
world_size=parallel_state.get_data_parallel_world_size(),
dataset=dataset,
tokenizer=dataset.text_processor.tokenizer,
)
else:
dls = []
123 changes: 69 additions & 54 deletions nemo/collections/multimodal/speech_llm/data/lhotse_dataset.py
@@ -1,10 +1,14 @@
import copy
import random
from dataclasses import dataclass
from typing import Optional

import numpy as np
import torch.utils.data
from lhotse.cut import Cut, CutSet
from lhotse.dataset.collation import collate_vectors as collate_vectors_lhotse

from nemo.collections.common.data.lhotse.text_adapters import NeMoSFTExample
from nemo.utils import logging


@@ -301,7 +305,6 @@ def __init__(
tokens_to_generate: int,
pad_to_max_length: bool,
max_seq_length: int,
noise_cuts: Optional = None,
canary_processor: Optional = None,
convert_canary_prompt_to_text: bool = False,
prepend_to_exist_question: Optional = None,
@@ -314,9 +317,6 @@ def __init__(
super().__init__()
self.text_processor = text_processor
self.load_audio = AudioSamples(fault_tolerant=True)
self.maybe_mix_noise = (
_identity if noise_cuts is None else CutMix(noise_cuts, pad_to_longest=False, random_mix_offset=True)
)
self.tokens_to_generate = tokens_to_generate
self.pad_to_max_length = pad_to_max_length
self.max_seq_length = max_seq_length
@@ -348,62 +348,77 @@ def _inject_random_context_into_question(self, cut, random_context_num=8, random
cut.question = context + cut.question
self.random_context = current_words

def __getitem__(self, cuts) -> dict[str, torch.Tensor | list[str] | dict]:
cuts = cuts.sort_by_duration()
cuts = self.maybe_mix_noise(cuts)
def __getitem__(self, all_cuts: CutSet) -> dict:
ans = {}

audio, audio_lens, cuts = self.load_audio(cuts)
# convert audio cuts to mini-batch
cuts = all_cuts.filter(lambda c: isinstance(c, Cut))
if cuts:
audio, audio_lens, cuts = self.load_audio(cuts)

return_batch = {}
audio_ratio = []
for id, cut in enumerate(cuts):
if hasattr(cut, "is_text_only") and cut.is_text_only:
audio_ratio.append(0.0)
else:
audio_ratio.append(1.0)

if self.canary_processor != None:
is_canary_tokens_augment = torch.rand(1) < self.canary_tokens_augment_ratio
_, _, _, _, canary_tokens, canary_token_lens = self.canary_processor.__getitem__(cuts)
return_batch = {}
audio_ratio = []
for id, cut in enumerate(cuts):
canary_text = self.canary_processor.tokenizer._tokenizer.ids_to_text(canary_tokens[id].tolist())
if audio_ratio[id] == 0.0:
assert hasattr(cut, "question")
elif self.prepend_to_exist_question and hasattr(cut, "question"):
cut.question = self.prepend_to_exist_question + cut.question
elif self.convert_canary_prompt_to_text:
cut.question = convert_canary_prompt_to_text(canary_text, is_canary_tokens_augment)
elif hasattr(cut, "question"):
pass
if hasattr(cut, "is_text_only") and cut.is_text_only:
audio_ratio.append(0.0)
else:
cut.question = self.question + ' ' + canary_text
metadata = []
for id, cut in enumerate(cuts):
self._inject_random_context_into_question(
cut, random_context_positive_percent=self.random_context_positive_percent
audio_ratio.append(1.0)

if self.canary_processor != None:
is_canary_tokens_augment = torch.rand(1) < self.canary_tokens_augment_ratio
_, _, _, _, canary_tokens, canary_token_lens = self.canary_processor.__getitem__(cuts)
for id, cut in enumerate(cuts):
canary_text = self.canary_processor.tokenizer._tokenizer.ids_to_text(canary_tokens[id].tolist())
if audio_ratio[id] == 0.0:
assert hasattr(cut, "question")
elif self.prepend_to_exist_question and hasattr(cut, "question"):
cut.question = self.prepend_to_exist_question + cut.question
elif self.convert_canary_prompt_to_text:
cut.question = convert_canary_prompt_to_text(canary_text, is_canary_tokens_augment)
elif hasattr(cut, "question"):
pass
else:
cut.question = self.question + ' ' + canary_text
metadata = []
for id, cut in enumerate(cuts):
self._inject_random_context_into_question(
cut, random_context_positive_percent=self.random_context_positive_percent
)
metadata.append({'audio_filepath': cut.id + '.wav'})

collated_text_data = collate_text_data(
cuts=cuts,
default_question=self.question,
text_processor=self.text_processor,
tokens_to_generate=self.tokens_to_generate,
pad_to_max_length=self.pad_to_max_length,
max_seq_length=self.max_seq_length,
)
metadata.append({'audio_filepath': cut.id + '.wav'})

collated_text_data = collate_text_data(
cuts=cuts,
default_question=self.question,
text_processor=self.text_processor,
tokens_to_generate=self.tokens_to_generate,
pad_to_max_length=self.pad_to_max_length,
max_seq_length=self.max_seq_length,
)
return_batch.update(
{
"sample_ids": list(cuts.ids),
"audio_signal": audio,
"audio_signal_length": audio_lens,
"audio_ratio": torch.FloatTensor(audio_ratio),
"metadata": metadata,
**collated_text_data,
}
)
return_batch.update(
{
"sample_ids": list(cuts.ids),
"audio_signal": audio,
"audio_signal_length": audio_lens,
"audio_ratio": torch.FloatTensor(audio_ratio),
"metadata": metadata,
**collated_text_data,
}
)
ans.update(return_batch)

# convert text examples to tensors
text_examples = all_cuts.filter(lambda c: isinstance(c, NeMoSFTExample))
if text_examples:
pad_id = self.text_processor.pad_id
text_minibatch = dict(
text_input_ids=collate_vectors_lhotse([e.input_ids for e in text_examples], padding_value=pad_id),
text_answer_ids=collate_vectors_lhotse([e.answer_ids for e in text_examples], padding_value=pad_id),
text_context_ids=collate_vectors_lhotse([e.context_ids for e in text_examples], padding_value=pad_id),
text_masks=collate_vectors_lhotse([e.mask for e in text_examples], padding_value=pad_id),
)
ans.update(text_minibatch)

return return_batch
return ans


def collate_text_data(
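With this change __getitem__ returns a single dict: when audio cuts are present it carries the audio keys (sample_ids, audio_signal, audio_signal_length, audio_ratio, metadata, plus the collated text data), and when text-only examples are present it adds text_input_ids, text_answer_ids, text_context_ids, and text_masks, each padded with the text processor's pad id. The snippet below is a rough, assumption-level stand-in for the padding collation that lhotse's collate_vectors performs; the token ids are made up.

# Sketch of right-padding collation for the text-only branch; pad_and_stack
# is a hypothetical substitute for lhotse.dataset.collation.collate_vectors.
import torch

def pad_and_stack(tensors, padding_value):
    """Right-pad 1-D tensors to a common length and stack into a 2-D batch."""
    max_len = max(t.size(0) for t in tensors)
    out = torch.full((len(tensors), max_len), padding_value, dtype=tensors[0].dtype)
    for i, t in enumerate(tensors):
        out[i, : t.size(0)] = t
    return out

pad_id = 0  # plays the role of text_processor.pad_id
answer_ids = [torch.tensor([5, 6, 7, 8]), torch.tensor([9, 10])]
print(pad_and_stack(answer_ids, padding_value=pad_id))
# tensor([[ 5,  6,  7,  8],
#         [ 9, 10,  0,  0]])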
67 changes: 40 additions & 27 deletions nemo/collections/multimodal/speech_llm/models/modular_models.py
@@ -360,38 +360,48 @@ def prepare_llm_input(self, audio_batch):
return encoder_input, attention_mask, labels, loss_mask, encoder_length

def forward(
self, audio_batch, checkpoint_activations_all_layers,
self, batch, checkpoint_activations_all_layers,
):
"""
Forward pass of the model. We prepend audio embeddings to the instruction and label text tokens as the LLM input.
"""
if 'audio_ratio' in audio_batch:
self.log(
'local_batch_size',
audio_batch['audio_ratio'].shape[0],
prog_bar=True,
batch_size=1,
rank_zero_only=False,
)
audio_batch = {k: v for k, v in batch.items() if not k.startswith("text_")}
text_batch = {k: v for k, v in batch.items() if k.startswith("text_")}

output, loss_mask = None, None

if audio_batch:
# TODO: handle the possibility that there might be no audio data in the mini-batch
if 'audio_ratio' in audio_batch:
self.log(
'local_batch_size',
audio_batch['audio_ratio'].shape[0],
prog_bar=True,
batch_size=1,
rank_zero_only=False,
)

encoder_input, attention_mask, labels, loss_mask, _ = self.prepare_llm_input(audio_batch)
if self.mcore_gpt:
output = self.model(
input_ids=None,
position_ids=None,
decoder_input=encoder_input,
attention_mask=attention_mask,
labels=labels,
)
else:
output = self.model(
input_ids=None,
position_ids=None,
encoder_input=encoder_input,
attention_mask=attention_mask,
labels=labels,
checkpoint_activations_all_layers=checkpoint_activations_all_layers,
)
encoder_input, attention_mask, labels, loss_mask, _ = self.prepare_llm_input(audio_batch)
if self.mcore_gpt:
output = self.model(
input_ids=None,
position_ids=None,
decoder_input=encoder_input,
attention_mask=attention_mask,
labels=labels,
)
else:
output = self.model(
input_ids=None,
position_ids=None,
encoder_input=encoder_input,
attention_mask=attention_mask,
labels=labels,
checkpoint_activations_all_layers=checkpoint_activations_all_layers,
)

if text_batch:
pass # TODO: implement text-only mini-batch forward

return output, loss_mask

@@ -1034,6 +1044,7 @@ def inference_step(self, dataloader_iter, mode):
"""
Used for validation and test steps, added postprocessing after calling self.predict_step().
"""
# TODO: support text-only part of mini-batch
batch, batch_idx, dataloader_idx = next(dataloader_iter)
data_cfg = self.cfg.data.validation_ds if mode == 'validation' else self.cfg.data.test_ds
self._reconfigure_and_process_inference_batch(batch, data_cfg)
@@ -1121,6 +1132,8 @@ def predict_step(self, batch: dict, batch_idx: int, dataloader_idx: Optional[int
"""
Used to get LLM predictions for validation and test steps based on the given inference config.
"""
# TODO: support text-only part of mini-batch

inference_config = self.get_inference_config()
if inference_config is not None:
# need to overwrite some configuration, make it immutable
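The reworked forward() first splits the incoming batch into an audio part and a text part based on the text_ key prefix, runs the existing audio path only when audio keys are present, and leaves the text-only path as a TODO; either sub-batch may be empty on a given step. A small, hypothetical demonstration of that prefix split with made-up tensors:

# Hypothetical mini-batch using the key names produced by the dataset above.
import torch

batch = {
    "audio_signal": torch.zeros(2, 16000),
    "audio_signal_length": torch.tensor([16000, 12000]),
    "audio_ratio": torch.tensor([1.0, 1.0]),
    "text_input_ids": torch.tensor([[5, 6, 7, 0]]),
    "text_masks": torch.tensor([[1, 1, 1, 0]]),
}

# Same prefix-based routing as the new forward():
audio_batch = {k: v for k, v in batch.items() if not k.startswith("text_")}
text_batch = {k: v for k, v in batch.items() if k.startswith("text_")}

assert set(audio_batch) == {"audio_signal", "audio_signal_length", "audio_ratio"}
assert set(text_batch) == {"text_input_ids", "text_masks"}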
