Commit 002d8f9

Merge branch 'main' into vchen/neva-blend-data
xuanzic authored Aug 2, 2024
2 parents 9415477 + fdf07a9 commit 002d8f9
Showing 59 changed files with 1,354 additions and 450 deletions.
16 changes: 15 additions & 1 deletion .github/workflows/cicd-main.yml
@@ -125,6 +125,19 @@ jobs:
## - name: L2: Multimodal Imagen Train

# L2: Community LLM Checkpoints tests
  L2_Community_LLM_Checkpoints_tests_Bert:
    needs: [cicd-test-container-setup]
    uses: ./.github/workflows/_test_template.yml
    with:
      RUNNER: self-hosted-azure
      SCRIPT: |
        python scripts/checkpoint_converters/convert_bert_hf_to_nemo.py \
        --input_name_or_path /home/TestData/nlp/megatron_ir/sbert/hf_model/bert-base-uncased \
        --output_path /home/TestData/nlp/megatron_ir/sbert/sbert.nemo
      AFTER_SCRIPT: |
        rm -f /home/TestData/nlp/megatron_ir/sbert/sbert.nemo
        rm -rf /home/TestData/nlp/megatron_ir/sbert/model_weights

  L2_Community_LLM_Checkpoints_tests_Llama:
    needs: [cicd-test-container-setup]
    uses: ./.github/workflows/_test_template.yml
@@ -200,7 +213,7 @@ jobs:
AFTER_SCRIPT: |
rm -f /home/TestData/multimodal/video_neva/llama3-ci-hf/llama3_ci.nemo
rm -rf /home/TestData/multimodal/video_neva/llama3-ci-hf/model_weights
# this test is using a 7B model which is too large for GitHub CI
# replace the model in this test with a toy model or move the test
# to the nightly CI
@@ -4457,6 +4470,7 @@ jobs:
- cicd-test-container-setup
- L0_Unit_Tests_GPU
- L0_Unit_Tests_CPU
- L2_Community_LLM_Checkpoints_tests_Bert
- L2_Community_LLM_Checkpoints_tests_Llama
- L2_Community_LLM_Checkpoints_tests_StarCoder
- L2_Community_LLM_Checkpoints_tests_Falcon
@@ -153,6 +153,7 @@ model:
resume_from_checkpoint: null # manually set the checkpoint file to load from
apex_transformer_log_level: 30 # Python logging level displays logs with severity greater than or equal to this
gradient_as_bucket_view: True # PyTorch DDP argument. Allocate gradients in a contiguous bucket to save memory (less fragmentation and buffer memory)
ddp_overlap: False # True for using PyTorch DDP overlap.

optim:
name: fused_adam
@@ -26,10 +26,10 @@ def model_cfg_modifier(model_cfg):
model_cfg.precision = cfg.trainer.precision
model_cfg.ckpt_path = None
model_cfg.inductor = False
model_cfg.unet_config.from_pretrained = "/opt/nemo-aligner/checkpoints/sdxl/unet_nemo.ckpt"
model_cfg.unet_config.from_NeMo = True
model_cfg.first_stage_config.from_pretrained = "/opt/nemo-aligner/checkpoints/sdxl/vae_nemo.ckpt"
model_cfg.first_stage_config.from_NeMo = True
# model_cfg.unet_config.from_pretrained = "/opt/nemo-aligner/checkpoints/sdxl/unet_nemo.ckpt"
# model_cfg.unet_config.from_NeMo = True
# model_cfg.first_stage_config.from_pretrained = "/opt/nemo-aligner/checkpoints/sdxl/vae_nemo.ckpt"
# model_cfg.first_stage_config.from_NeMo = True
model_cfg.first_stage_config._target_ = 'nemo.collections.multimodal.models.text_to_image.stable_diffusion.ldm.autoencoder.AutoencoderKLInferenceWrapper'
# model_cfg.fsdp = True

78 changes: 54 additions & 24 deletions nemo/collections/asr/data/audio_to_text_lhotse_prompted.py
@@ -11,6 +11,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from dataclasses import dataclass
from typing import Callable, Sequence

import torch.utils.data
@@ -25,6 +26,26 @@
from nemo.collections.common.tokenizers.canary_tokenizer import CANARY_SPECIAL_TOKENIZER


@dataclass
class PromptedAudioToTextMiniBatch:
audio: torch.Tensor
audio_lens: torch.Tensor
transcript: torch.Tensor
transcript_lens: torch.Tensor
prompt: torch.Tensor
prompt_lens: torch.Tensor
prompted_transcript: torch.Tensor
prompted_transcript_lens: torch.Tensor

def get_decoder_inputs_outputs(self) -> tuple[torch.Tensor, torch.Tensor]:
"""
Returns the inputs and outputs of the transformer decoder for training.
The input is ``prompted_transcript`` (minus last token),
and the output is ``prompted_transcript`` (minus first token).
"""
return self.prompted_transcript[:, :-1], self.prompted_transcript[:, 1:]
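
To make the teacher-forcing convention concrete, here is a minimal, hypothetical sketch using made-up token ids; it relies only on the dataclass defined above, and the tensor contents are illustrative rather than produced by the dataset:

import torch

# Illustrative only: a tiny batch showing the decoder input/target shift.
dummy = torch.tensor([[1, 5, 6, 7, 8, 2]])  # e.g. [BOS, p1, p2, a1, a2, EOS]
batch = PromptedAudioToTextMiniBatch(
    audio=torch.zeros(1, 16000),
    audio_lens=torch.tensor([16000]),
    transcript=dummy[:, 3:-1],
    transcript_lens=torch.tensor([2]),
    prompt=dummy[:, :3],
    prompt_lens=torch.tensor([3]),
    prompted_transcript=dummy,
    prompted_transcript_lens=torch.tensor([6]),
)
input_ids, labels = batch.get_decoder_inputs_outputs()
assert torch.equal(input_ids, dummy[:, :-1])  # [BOS, p1, p2, a1, a2]
assert torch.equal(labels, dummy[:, 1:])      # [p1, p2, a1, a2, EOS]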


class PromptedAudioToTextLhotseDataset(torch.utils.data.Dataset):
"""
This dataset is based on :class:`~nemo.collections.asr.data.audio_to_text_lhotse.LhotseSpeechToTextBpeDataset`.
@@ -45,41 +66,46 @@ class PromptedAudioToTextLhotseDataset(torch.utils.data.Dataset):
def __init__(
self,
tokenizer: TokenizerSpec,
prompt_format_fn: Callable[[CutSet, TokenizerWrapper, bool], Sequence[Sequence[int]]],
inference: bool = False,
prompt_format_fn: Callable[[CutSet, TokenizerWrapper], Sequence[Sequence[int]]],
):
super().__init__()
self.tokenizer = TokenizerWrapper(tokenizer)
self.load_audio = AudioSamples(fault_tolerant=True)
self.padding_value = self.tokenizer._tokenizer.pad_id
self.prompt_format_fn = prompt_format_fn
self.inference = inference

def __getitem__(self, cuts: CutSet) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
def __getitem__(self, cuts: CutSet) -> PromptedAudioToTextMiniBatch:
audio, audio_lens, cuts = self.load_audio(cuts)

prompts_with_answers, prompts = self.prompt_format_fn(cuts, self.tokenizer, inference=self.inference)

prompts_with_answers = [torch.as_tensor(t) for t in prompts_with_answers]
prompts_with_answers_lens = torch.tensor([t.size(0) for t in prompts_with_answers], dtype=torch.long)
prompts_with_answers = collate_vectors(prompts_with_answers, padding_value=self.padding_value)

if self.inference:
prompts = [torch.as_tensor(t) for t in prompts]
prompts_lens = torch.tensor([t.size(0) for t in prompts], dtype=torch.long)
prompts = collate_vectors(prompts, padding_value=self.padding_value)
else:
prompts = None
prompts_lens = None
prompts_with_answers, prompts, answers = self.prompt_format_fn(cuts, self.tokenizer)

transcript, transcript_lens = self._collate_tokens(answers)
prompts_with_answers, prompts_with_answers_lens = self._collate_tokens(prompts_with_answers)
prompts, prompt_lens = self._collate_tokens(prompts)

return PromptedAudioToTextMiniBatch(
audio=audio,
audio_lens=audio_lens,
transcript=transcript,
transcript_lens=transcript_lens,
prompt=prompts,
prompt_lens=prompt_lens,
prompted_transcript=prompts_with_answers,
prompted_transcript_lens=prompts_with_answers_lens,
)

return audio, audio_lens, prompts_with_answers, prompts_with_answers_lens, prompts, prompts_lens
def _collate_tokens(self, tokens: list[list[int]]) -> tuple[torch.Tensor, torch.Tensor]:
tokens = [torch.as_tensor(t) for t in tokens]
token_lens = torch.tensor([t.size(0) for t in tokens], dtype=torch.long)
tokens = collate_vectors(tokens, padding_value=self.padding_value)
return tokens, token_lens
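
After the signature change, wiring this dataset up might look roughly like the sketch below. This is a hedged illustration, not taken from the diff: `canary_tokenizer` is a placeholder for a suitable tokenizer instance, and real training scripts configure this through NeMo configs and lhotse samplers.

# Hypothetical construction; `canary_tokenizer` is a placeholder tokenizer.
dataset = PromptedAudioToTextLhotseDataset(
    tokenizer=canary_tokenizer,
    prompt_format_fn=get_prompt_format_fn("canary"),
)
# Lhotse-style datasets receive a whole CutSet per __getitem__ call, so batches
# are formed by a lhotse sampler rather than by per-example indexing.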


# Mapping from a string name to a known prompt formatter function.
PROMPT_FORMAT_FNS = {}


def registered_prompt_format_fn(prompt_fn: Callable[[CutSet, TokenizerWrapper, bool], Sequence[Sequence[int]]]):
def registered_prompt_format_fn(prompt_fn: Callable[[CutSet, TokenizerWrapper], Sequence[Sequence[int]]]):
"""
Decorator for registering prompt functions under a name.
@@ -97,7 +123,7 @@ def registered_prompt_format_fn(prompt_fn: Callable[[CutSet, TokenizerWrapper, b
return prompt_fn
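
For illustration only, a hypothetical custom formatter following the updated contract (three token-id lists per batch of cuts, with the `inference` flag removed). The function name, token ids, and the assumption that the registry keys functions by their `__name__` are all made up for this sketch:

@registered_prompt_format_fn
def my_format(cuts: CutSet, tokenizer: TokenizerWrapper) -> tuple[list[list[int]], list[list[int]], list[list[int]]]:
    # Hypothetical formatter: fixed prompt ids plus placeholder answer ids per cut.
    prompts_with_answers, prompts, answers = [], [], []
    for _cut in cuts:
        prompt = [101, 102]   # placeholder prompt token ids
        answer = [7, 8, 9]    # placeholder answer token ids
        prompts.append(prompt)
        answers.append(answer)
        prompts_with_answers.append(prompt + answer)
    return prompts_with_answers, prompts, answers

# Assuming the registry keys formatters by function name:
fn = get_prompt_format_fn("my_format")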


def get_prompt_format_fn(name: str) -> Callable[[CutSet, TokenizerWrapper, bool], Sequence[Sequence[int]]]:
def get_prompt_format_fn(name: str) -> Callable[[CutSet, TokenizerWrapper], Sequence[Sequence[int]]]:
if name not in PROMPT_FORMAT_FNS:
raise ValueError(
f"Unknown prompt format function name: {name} " f"(must be one of: {list(PROMPT_FORMAT_FNS.keys())}"
@@ -107,8 +133,8 @@ def get_prompt_format_fn(name: str) -> Callable[[CutSet, TokenizerWrapper, bool]

@registered_prompt_format_fn
def canary(
cuts: CutSet, tokenizer: TokenizerWrapper, inference: bool = False
) -> tuple[list[torch.Tensor], list[torch.Tensor]]:
cuts: CutSet, tokenizer: TokenizerWrapper
) -> tuple[list[torch.Tensor], list[torch.Tensor], list[torch.Tensor]]:
"""
Prepend and append control tokens to the token sequence as per Canary format.
@@ -137,7 +163,7 @@ def canary(
), "To use 'canary' prompt format, you must use the CanaryTokenizer."
formatter = CanaryPromptFormatter(tokenizer._tokenizer)

prompts_with_answers, prompts = [], []
prompts_with_answers, prompts, answers = [], [], []
for cut in cuts:
if isinstance(cut, MixedCut):
cut = cut._first_non_padding_cut
@@ -180,8 +206,12 @@ def canary(
)
prompts_with_answers.append(encoded["input_ids"])
prompts.append(encoded["context_ids"])
assert (
encoded["answer_ids"][-1].item() == formatter.tokenizer.eos
), f"Expected the last token in answer_ids to be EOS, but we got {encoded['answer_ids']=}"
answers.append(encoded["answer_ids"][:-1]) # Strip Canary's EOS

return prompts_with_answers, prompts
return prompts_with_answers, prompts, answers
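
For clarity, the relationship between the three returned lists, assuming (as in NeMo's prompt formatter convention) that ``input_ids`` is ``context_ids`` followed by ``answer_ids``:

# Illustrative invariant, per cut i:
#   prompts_with_answers[i] == prompts[i] + answers[i] + [eos]
# because the trailing Canary EOS is stripped from each answer before returning.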


class ProbablyIncorrectLanguageKeyError(RuntimeError):