
Merge branch 'main' into tfogal/potential-single-gpu-nsys-fix
akoumpa authored Nov 25, 2024
2 parents ecc8a1f + 8f779ba commit 1ba7eed
Showing 127 changed files with 4,272 additions and 1,042 deletions.
128 changes: 128 additions & 0 deletions .github/workflows/cicd-main.yml
@@ -2109,6 +2109,121 @@ jobs:
# }
# }

L2_Megatron_LM_To_NeMo_Conversion:
needs: [cicd-test-container-setup]
uses: ./.github/workflows/_test_template.yml
if: contains(fromJSON(needs.cicd-test-container-setup.outputs.test_to_run), 'L2_Megatron_LM_To_NeMo_Conversion') || needs.cicd-test-container-setup.outputs.all == 'true'
with:
RUNNER: self-hosted-azure
SCRIPT: |
CUDA_DEVICE_MAX_CONNECTIONS=1 torchrun --nproc_per_node=1 Megatron-LM/pretrain_gpt.py \
--mock-data \
--distributed-timeout-minutes 60 \
--use-mcore-models \
--no-mmap-bin-files \
--untie-embeddings-and-output-weights \
--disable-bias-linear \
--train-samples 80 \
--init-method-std 0.014 \
--position-embedding-type rope \
--rotary-base 1000000 \
--rotary-percent 1.0 \
--squared-relu \
--num-layers 4 \
--hidden-size 384 \
--num-attention-heads 8 \
--group-query-attention \
--num-query-groups 8 \
--ffn-hidden-size 1536 \
--kv-channels 128 \
--normalization RMSNorm \
--attention-dropout 0.0 \
--hidden-dropout 0.0 \
--exit-duration-in-mins 5750 \
--tensor-model-parallel-size 1 \
--pipeline-model-parallel-size 1 \
--seq-length 8192 \
--max-position-embeddings 8192 \
--micro-batch-size 1 \
--global-batch-size 8 \
--lr 6e-4 \
--min-lr 6e-6 \
--weight-decay 0.1 \
--clip-grad 1.0 \
--lr-decay-style cosine \
--log-interval 1 \
--eval-iters 1 \
--eval-interval 10 \
--tokenizer-type GPT2BPETokenizer \
--tokenizer-model /home/TestData/nlp/gpt2_tokenizer \
--vocab-file /home/TestData/nlp/gpt2_tokenizer/vocab.json \
--merge-file /home/TestData/nlp/gpt2_tokenizer/merges.txt \
--save /tmp/mlm_conversion_ckpt \
--save-interval 10 \
--ckpt-format torch_dist \
--ckpt-fully-parallel-save \
--ckpt-fully-parallel-load \
--async-save \
--ckpt-assume-constant-structure \
--timing-log-option minmax \
--log-params-norm \
--log-num-zeros-in-grad \
--log-throughput \
--bf16 \
--adam-beta1 0.9 \
--adam-beta2 0.95 \
--use-distributed-optimizer \
--overlap-grad-reduce \
--overlap-param-gather \
--manual-gc \
--num-workers 2
python examples/nlp/language_modeling/megatron_gpt_pretraining.py \
model.data.data_impl=mock \
model.data.data_prefix=[] \
model.skip_train=True \
model.transformer_engine=True \
model.use_flash_attention=False \
model.normalization=rmsnorm \
model.num_layers=4 \
model.hidden_size=384 \
model.ffn_hidden_size=1536 \
model.num_attention_heads=8 \
model.num_query_groups=8 \
model.bias=False \
model.bias_activation_fusion=False \
model.bias_dropout_add_fusion=True \
model.masked_softmax_fusion=True \
model.encoder_seq_length=8192 \
model.max_position_embeddings=8192 \
model.data.seq_length=8192 \
model.activation=squared-relu \
model.transformer_block_type=True \
model.micro_batch_size=1 \
model.global_batch_size=8 \
++model.rotary_base=1000000 \
model.rotary_percentage=1.0 \
model.apply_query_key_layer_scaling=False \
++model.group_query_attention=True \
model.apply_rope_fusion=True \
model.kv_channels=128 \
++model.bert_binary_head=True \
++model.position_embedding_type=rope \
++model.add_position_embedding=True \
trainer.limit_val_batches=1 \
exp_manager.exp_dir=/tmp/nemo_conversion_ckpt
python -m torch.distributed.launch --nproc_per_node=1 examples/nlp/language_modeling/megatron_ckpt_to_nemo.py \
--checkpoint_folder /tmp/mlm_conversion_ckpt \
--checkpoint_name iter_0000010 \
--nemo_file_path /tmp/mlm_to_nemo_test.nemo \
--tensor_model_parallel_size 1 \
--pipeline_model_parallel_size 1 \
--gpus_per_node 1 \
--model_type gpt \
--hparams_file /tmp/nemo_conversion_ckpt/megatron_gpt/version_0/hparams.yaml \
--convert_mlm
L2_Megatron_GPT_with_ResetLR_Pretraining_and_Resume_Training_TP2:
needs: [cicd-test-container-setup]
uses: ./.github/workflows/_test_template.yml
@@ -4319,7 +4434,18 @@ jobs:
--mbs 1 \
--model mistral \
--dist-opt
L2_NEMO_2_LoRA_MERGE:
needs: [cicd-test-container-setup]
uses: ./.github/workflows/_test_template.yml
if: contains(fromJSON(needs.cicd-test-container-setup.outputs.test_to_run), 'L2_NEMO_2_LoRA_MERGE') || needs.cicd-test-container-setup.outputs.all == 'true'
with:
RUNNER: self-hosted-azure
SCRIPT: |
python tests/collections/llm/peft/lora_merge.py \
--lora_checkpoint_path=/home/TestData/nemo2_ckpt/llama_lora_ci_checkpoint/ \
--output_path=/tmp/nemo2_lora_merge/${{ github.run_id }}
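For context, merging a LoRA adapter folds the low-rank update into the frozen base weights so the merged checkpoint needs no adapter at inference time. A conceptual sketch of that arithmetic (not the implementation in tests/collections/llm/peft/lora_merge.py; shapes and scaling are illustrative):

import torch

def merge_lora(base_weight, lora_a, lora_b, alpha, rank):
    # W_merged = W + (alpha / rank) * (B @ A)
    return base_weight + (alpha / rank) * (lora_b @ lora_a)

W = torch.randn(512, 512)   # frozen base weight
A = torch.randn(8, 512)     # LoRA A: rank x in_features
B = torch.randn(512, 8)     # LoRA B: out_features x rank
W_merged = merge_lora(W, A, B, alpha=16, rank=8)
print(W_merged.shape)       # torch.Size([512, 512])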
L2_NeMo_2_NeMo_Mcore_Mixtral_bitexact:
needs: [cicd-test-container-setup]
@@ -4421,6 +4547,7 @@ jobs:
- L2_RAG_Pipeline_Generating
- L2_Megatron_GPT_Pretraining_and_Resume_Training_TP2
- L2_Megatron_GPT_Skip_Train
- L2_Megatron_LM_To_NeMo_Conversion
- L2_Megatron_GPT_with_Rope_Pretraining_and_Resume_Training_TP2
- L2_Megatron_GPT_with_ResetLR_Pretraining_and_Resume_Training_TP2
- L2_Megatron_GPT_with_Drop_Optimizer_States_TP2
@@ -4482,6 +4609,7 @@ jobs:
- L2_NeMo_2_Mixtral_LoRA_TP2PP1_MBS1
- L2_NeMo_2_Mistral_LoRA_TP1PP1_MBS1
- L2_NeMo_2_Mistral_LoRA_TP2PP1_MBS1
- L2_NEMO_2_LoRA_MERGE
- L2_NeMo_2_Mixtral_Pretraining
- L2_PTQ_Llama2_FP8
- L2_Community_LLM_Checkpoints_tests_Llama3
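The new L2_Megatron_LM_To_NeMo_Conversion job above trains a small GPT with Megatron-LM, produces a matching NeMo hparams.yaml via a skip-train run, and converts the torch_dist checkpoint to a .nemo file. A minimal sketch of loading the converted file afterwards (the trainer setup here is an assumption, not part of this commit):

from pytorch_lightning import Trainer
from nemo.collections.nlp.models.language_modeling.megatron_gpt_model import MegatronGPTModel
from nemo.collections.nlp.parts.nlp_overrides import NLPDDPStrategy

trainer = Trainer(devices=1, accelerator="gpu", strategy=NLPDDPStrategy())
model = MegatronGPTModel.restore_from("/tmp/mlm_to_nemo_test.nemo", trainer=trainer)
print(model.cfg.num_layers, model.cfg.hidden_size)  # 4 and 384 per the test config above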
3 changes: 3 additions & 0 deletions examples/asr/transcribe_speech.py
@@ -276,6 +276,9 @@ def main(cfg: TranscriptionConfig) -> Union[TranscriptionConfig, List[Hypothesis
# we will adjust this flag if the model does not support it
compute_langs = cfg.compute_langs

if cfg.timestamps:
cfg.return_hypotheses = True

# Check whether model and decoder type match
if isinstance(asr_model, EncDecCTCModel):
if cfg.decoder_type and cfg.decoder_type != 'ctc':
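The guard above forces return_hypotheses=True whenever timestamps are requested, because timestamp information is carried on the returned Hypothesis objects rather than on plain-text output. A rough sketch of the same behavior through the Python transcribe API (the model name and attribute access are illustrative assumptions):

import nemo.collections.asr as nemo_asr

asr_model = nemo_asr.models.ASRModel.from_pretrained("stt_en_fastconformer_ctc_large")
hyps = asr_model.transcribe(["sample.wav"], timestamps=True)  # Hypothesis objects, not plain strings
print(hyps[0].text)      # best transcription
print(hyps[0].timestep)  # word/segment/char timestamp info attached to the hypothesis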
10 changes: 9 additions & 1 deletion examples/asr/transcribe_speech_parallel.py
@@ -163,6 +163,14 @@ def main(cfg: ParallelTranscriptionConfig):
cfg.predict_ds.return_sample_id = True
cfg.predict_ds = match_train_config(predict_ds=cfg.predict_ds, train_ds=model.cfg.train_ds)

if cfg.predict_ds.use_lhotse:
OmegaConf.set_struct(cfg.predict_ds, False)
cfg.trainer.use_distributed_sampler = False
cfg.predict_ds.force_finite = True
cfg.predict_ds.force_map_dataset = True
cfg.predict_ds.do_transcribe = True
OmegaConf.set_struct(cfg.predict_ds, True)

if isinstance(model, EncDecMultiTaskModel):
cfg.trainer.use_distributed_sampler = False
OmegaConf.set_struct(cfg.predict_ds, False)
@@ -172,7 +180,7 @@ def main(cfg: ParallelTranscriptionConfig):

trainer = ptl.Trainer(**cfg.trainer)

-if isinstance(model, EncDecMultiTaskModel):
+if cfg.predict_ds.use_lhotse:
OmegaConf.set_struct(cfg.predict_ds, False)
cfg.predict_ds.global_rank = trainer.global_rank
cfg.predict_ds.world_size = trainer.world_size
23 changes: 12 additions & 11 deletions examples/llm/sft/hf.py
@@ -19,7 +19,7 @@

from nemo import lightning as nl
from nemo.collections import llm
-from nemo.lightning.pytorch.accelerate.transformer_engine import is_te_accelerated, te_accelerate
+from nemo.lightning.pytorch.accelerate.transformer_engine import is_te_accelerated
from nemo.lightning.pytorch.callbacks import ModelCallback


@@ -75,16 +75,17 @@ def squad(tokenizer) -> pl.LightningDataModule:
grad_clip = None
use_dist_samp = False

model = llm.HfAutoModelForCausalLM(args.model)
tokenizer = model.tokenizer
model_accelerator = None
if args.model_accelerator == "te":
from functools import partial
from nemo.lightning.pytorch.accelerate.transformer_engine import te_accelerate

callbacks = []
if args.model_accelerator:
if args.model_accelerator == "te":
model_transform = ModelCallback(
on_train_start=lambda model: te_accelerate(model, fp8_autocast=args.fp8_autocast)
)
callbacks.append(model_transform)
model_accelerator = partial(te_accelerate, fp8_autocast=args.fp8_autocast)

from nemo.lightning.pytorch.accelerate.transformer_engine import te_accelerate

model = llm.HfAutoModelForCausalLM(model_name=args.model, model_accelerator=model_accelerator)
tokenizer = model.tokenizer

llm.api.finetune(
model=model,
@@ -100,7 +101,7 @@ def squad(tokenizer) -> pl.LightningDataModule:
accumulate_grad_batches=10,
gradient_clip_val=grad_clip,
use_distributed_sampler=use_dist_samp,
-callbacks=callbacks,
+callbacks=[],
logger=wandb,
),
optim=fdl.build(llm.adam.pytorch_adam_with_flat_lr(lr=1e-5)),
9 changes: 8 additions & 1 deletion examples/nlp/language_modeling/megatron_ckpt_to_nemo.py
@@ -112,6 +112,11 @@ def get_args():
choices=['32-true', '16-mixed', 'bf16-mixed'],
help="Precision value for the trainer that matches with precision of the ckpt",
)
parser.add_argument(
"--convert_mlm",
action="store_true",
help="Use this flag to convert megatron-lm checkpoints.",
)

args = parser.parse_args()
return args
@@ -195,7 +200,9 @@ def convert(local_rank, rank, world_size, args):
)

if args.model_type == 'gpt':
-model = MegatronGPTModel.load_from_checkpoint(checkpoint_path, hparams_file=args.hparams_file, trainer=trainer)
+model = MegatronGPTModel.load_from_checkpoint(
+    checkpoint_path, hparams_file=args.hparams_file, trainer=trainer, load_mlm=args.convert_mlm
+)
elif args.model_type == 'sft':
model = MegatronGPTSFTModel.load_from_checkpoint(
checkpoint_path, hparams_file=args.hparams_file, trainer=trainer
8 changes: 7 additions & 1 deletion nemo/collections/asr/data/audio_to_text_dataset.py
@@ -867,10 +867,16 @@ def write_on_batch_end(
sample = sample_id
if isinstance(sample, lhotse.cut.MixedCut):
sample = sample.first_non_padding_cut
-item["audio_filepath"] = sample.recording.sources[0].source
+if sample.recording.sources[0].source != '':
+    item["audio_filepath"] = sample.recording.sources[0].source
+else:
+    item["audio_filepath"] = sample.id
item["offset"] = sample.start
item["duration"] = sample.duration
item["text"] = sample.supervisions[0].text
item["text"] = sample.supervisions[0].text or ''
if hasattr(sample, 'shard_id'):
item["shard_id"] = sample.shard_id
item["pred_text"] = transcribed_text
self.outf.write(json.dumps(item) + "\n")
self.samples_num += 1
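With these changes, a manifest line emitted by this prediction writer looks roughly like the following (values are illustrative; audio_filepath falls back to the cut id when the recording has no source path, and shard_id appears only for sharded inputs):

{"audio_filepath": "audio/utt_001.wav", "offset": 0.0, "duration": 3.2, "text": "hello world", "shard_id": 4, "pred_text": "hello world"}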
7 changes: 5 additions & 2 deletions nemo/collections/asr/data/audio_to_text_lhotse.py
@@ -43,17 +43,18 @@ def output_types(self) -> Optional[Dict[str, NeuralType]]:
'sample_id': NeuralType(tuple('B'), LengthsType(), optional=True),
}

-def __init__(self, tokenizer):
+def __init__(self, tokenizer, return_cuts=False):
super().__init__()
self.tokenizer = TokenizerWrapper(tokenizer)
self.load_audio = AudioSamples(fault_tolerant=True)
self.return_cuts = return_cuts

def __getitem__(self, cuts) -> Tuple[torch.Tensor, ...]:
audio, audio_lens, cuts = self.load_audio(cuts)
tokens = [
torch.cat(
[
-torch.as_tensor(s.tokens if hasattr(s, "tokens") else self.tokenizer(s.text, s.language))
+torch.as_tensor(s.tokens if hasattr(s, "tokens") else self.tokenizer(s.text or "", s.language))
for s in c.supervisions
],
dim=0,
@@ -62,6 +63,8 @@ def __getitem__(self, cuts) -> Tuple[torch.Tensor, ...]:
]
token_lens = torch.tensor([t.size(0) for t in tokens], dtype=torch.long)
tokens = collate_vectors(tokens, padding_value=0)
if self.return_cuts:
return audio, audio_lens, tokens, token_lens, cuts.drop_in_memory_data()
return audio, audio_lens, tokens, token_lens


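A minimal sketch of the new return_cuts flag (the tokenizer and CutSet below are placeholders; in NeMo this dataset is normally constructed by get_lhotse_dataloader_from_config rather than indexed directly):

from nemo.collections.asr.data.audio_to_text_lhotse import LhotseSpeechToTextBpeDataset

dataset = LhotseSpeechToTextBpeDataset(tokenizer=my_bpe_tokenizer, return_cuts=True)
audio, audio_lens, tokens, token_lens, cuts = dataset[my_cuts]  # 5-tuple when return_cuts=True
# `cuts` retains per-utterance metadata (source path, offset, duration) that the
# prediction writer above uses to build manifest entries during transcription.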
11 changes: 11 additions & 0 deletions nemo/collections/asr/models/configs/asr_models_config.py
@@ -41,6 +41,17 @@ class ASRDatasetConfig(nemo.core.classes.dataset.DatasetConfig):
shard_manifests: bool = False
shuffle_n: int = 0

# lhotse support
use_lhotse: bool = False
tarred_random_access: bool = False
use_bucketing: bool = False
batch_duration: Optional[int] = None
quadratic_duration: Optional[int] = None
bucket_batch_size: Optional[int] = None
bucket_duration_bins: Optional[list] = None
num_buckets: Optional[int] = 0
pin_memory: bool = False

# Optional
int_values: Optional[int] = None
augmentor: Optional[Dict[str, Any]] = None
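The new fields let a dataclass-configured dataset opt into the Lhotse dataloader; a hedged example of how they might be set (field values are illustrative, not defaults from this commit):

from nemo.collections.asr.models.configs.asr_models_config import ASRDatasetConfig

cfg = ASRDatasetConfig(
    manifest_filepath="train_manifest.json",
    sample_rate=16000,
    use_lhotse=True,
    use_bucketing=True,
    num_buckets=30,
    batch_duration=600,     # total seconds of audio per batch when bucketing by duration
    quadratic_duration=20,
    pin_memory=True,
)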
12 changes: 9 additions & 3 deletions nemo/collections/asr/models/ctc_bpe_models.py
@@ -97,9 +97,15 @@ def _setup_dataloader_from_config(self, config: Optional[Dict]):
if config.get("use_lhotse"):
return get_lhotse_dataloader_from_config(
config,
-global_rank=self.global_rank,
-world_size=self.world_size,
-dataset=LhotseSpeechToTextBpeDataset(tokenizer=self.tokenizer),
+# During transcription, the model is initially loaded on the CPU.
+# To ensure the correct global_rank and world_size are set,
+# these values must be passed from the configuration.
+global_rank=self.global_rank if not config.get("do_transcribe", False) else config.get("global_rank"),
+world_size=self.world_size if not config.get("do_transcribe", False) else config.get("world_size"),
+dataset=LhotseSpeechToTextBpeDataset(
+    tokenizer=self.tokenizer,
+    return_cuts=config.get("do_transcribe", False),
+),
tokenizer=self.tokenizer,
)

12 changes: 9 additions & 3 deletions nemo/collections/asr/models/ctc_models.py
@@ -160,6 +160,7 @@ def transcribe(
A list of transcriptions (or raw log probabilities if logprobs is True) in the same order as
paths2audio_files
"""
timestamps = timestamps or (override_config.timestamps if override_config is not None else None)
if timestamps is not None:
# else retain the decoder state (users can set it using change_decoding_strategy)
if timestamps or (override_config is not None and override_config.timestamps):
@@ -308,8 +309,11 @@ def _setup_dataloader_from_config(self, config: Optional[Dict]):
if config.get("use_lhotse"):
return get_lhotse_dataloader_from_config(
config,
-global_rank=self.global_rank,
-world_size=self.world_size,
+# During transcription, the model is initially loaded on the CPU.
+# To ensure the correct global_rank and world_size are set,
+# these values must be passed from the configuration.
+global_rank=self.global_rank if not config.get("do_transcribe", False) else config.get("global_rank"),
+world_size=self.world_size if not config.get("do_transcribe", False) else config.get("world_size"),
dataset=LhotseSpeechToTextBpeDataset(
tokenizer=make_parser(
labels=config.get('labels', None),
@@ -318,6 +322,7 @@ def _setup_dataloader_from_config(self, config: Optional[Dict]):
blank_id=config.get('blank_index', -1),
do_normalize=config.get('normalize_transcripts', False),
),
return_cuts=config.get("do_transcribe", False),
),
)

@@ -613,7 +618,8 @@ def predict_step(self, batch, batch_idx, dataloader_idx=0):
return_hypotheses=False,
)

-sample_id = sample_id.cpu().detach().numpy()
+if isinstance(sample_id, torch.Tensor):
+    sample_id = sample_id.cpu().detach().numpy()
return list(zip(sample_id, transcribed_texts))

def validation_pass(self, batch, batch_idx, dataloader_idx=0):
