
Commit

Merge branch 'main' of https://github.com/NVIDIA/NeMo into aot/qwen-recipe
suiyoubi committed Oct 25, 2024
2 parents 5077fef + 83eea56 commit 2beabc4
Showing 42 changed files with 1,152 additions and 88 deletions.
16 changes: 16 additions & 0 deletions .github/workflows/cicd-main.yml
@@ -4308,6 +4308,21 @@ jobs:
SCRIPT: |
bash tests/collections/llm/bitexact/mixtral/run.sh
L2_NeMo_2_PTQ_Llama2_FP8:
needs: [cicd-test-container-setup]
uses: ./.github/workflows/_test_template.yml
if: contains(fromJSON(needs.cicd-test-container-setup.outputs.test_to_run), 'L2_NeMo_2_PTQ_Llama2_FP8') || needs.cicd-test-container-setup.outputs.all == 'true'
with:
RUNNER: self-hosted-azure
SCRIPT: |
python tests/collections/llm/test_hf_import.py --hf_model /home/TestData/nlp/megatron_llama/llama-ci-hf --output_path /tmp/nemo2_ckpt
python scripts/llm/ptq.py -nc /tmp/nemo2_ckpt -algo fp8 -out /tmp/nemo2_ptq_engine
AFTER_SCRIPT: |
rm -rf /tmp/nemo2_ckpt
rm -rf /tmp/nemo2_ptq_engine
Nemo_CICD_Test:
needs:
- pre-flight
@@ -4455,6 +4470,7 @@ jobs:
- L2_Speech_Transcription_Canary_Transcribe_Audio_Dir
- L2_Megatron_GPT_Reranker
- L2_NeMo_2_NeMo_Mcore_Mixtral_bitexact
- L2_NeMo_2_PTQ_Llama2_FP8
if: always()
runs-on: ubuntu-latest
steps:
4 changes: 4 additions & 0 deletions docs/source/asr/ssl/api.rst
@@ -4,6 +4,10 @@ NeMo SSL collection API

Model Classes
-------------
.. autoclass:: nemo.collections.asr.models.EncDecDenoiseMaskedTokenPredModel
:show-inheritance:
:members:

.. autoclass:: nemo.collections.asr.models.SpeechEncDecSelfSupervisedModel
:show-inheritance:
:members:
4 changes: 4 additions & 0 deletions docs/source/asr/ssl/intro.rst
@@ -19,6 +19,10 @@ encoder module of neural ASR models. Here too, majority of SSL effort is focused
While it is common for the AM to be the focus of SSL in ASR, SSL can also be used to improve other parts of
ASR models (e.g., the predictor module in transducer-based ASR models).

In NeMo, we provide two types of SSL models, `Wav2Vec-BERT <https://arxiv.org/abs/2108.06209>`_ and `NEST <https://arxiv.org/abs/2408.13106>`_.
The training scripts for both models can be found in `examples/asr/speech_pretraining <https://github.com/NVIDIA/NeMo/tree/main/examples/asr/speech_pretraining>`_.
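For orientation, here is a minimal sketch of restoring one of these SSL models from a local checkpoint; the checkpoint filename is a placeholder, and using ``restore_from`` follows the standard NeMo model API rather than anything added in this commit::

    # Minimal sketch, assuming a NEST checkpoint saved as a .nemo file;
    # the path is a placeholder, not part of this commit.
    from nemo.collections.asr.models import EncDecDenoiseMaskedTokenPredModel

    ssl_model = EncDecDenoiseMaskedTokenPredModel.restore_from(restore_path="nest_pretrained.nemo")
    print(ssl_model)  # inspect the restored encoder/decoder modules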


The full documentation tree is as follows:

.. toctree::
4 changes: 2 additions & 2 deletions examples/asr/conf/ssl/nest/nest_fast-conformer.yaml
@@ -28,8 +28,8 @@ model:
mask_position: pre_conv # position to apply masking, before or after conv subsampling, choices in ['pre_conv', 'post_conv']

train_ds:
manifest_filepath: ???
noise_manifest: null
manifest_filepath: ??? # path to training manifest, can be a string or list of strings
noise_manifest: ??? # the manifest for noise data, can be a string or list of strings
sample_rate: ${model.sample_rate}
batch_size: 8 # you may increase batch_size if your memory allows
shuffle: true
8 changes: 6 additions & 2 deletions examples/asr/run_helper.py
@@ -82,6 +82,7 @@ def check_missing_values(cfg):
check_missing_values(result)
return result


def check_config_mount_paths(script_config, cluster_config):
# recursively walk all values of the script_config, checking whether each value is a path-like string and, if so, whether that path is a mounted path
# if it is not, raise an error
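# A minimal sketch of the recursive walk described above; the helper name and the
# mounted-prefix list are illustrative assumptions, not the actual implementation.
def _check_paths_are_mounted(value, mounts):
    if isinstance(value, dict):
        for v in value.values():
            _check_paths_are_mounted(v, mounts)
    elif isinstance(value, (list, tuple)):
        for v in value:
            _check_paths_are_mounted(v, mounts)
    elif isinstance(value, str) and value.startswith("/"):
        # path-like string: require it to live under one of the mounted prefixes
        if not any(value.startswith(m) for m in mounts):
            raise ValueError(f"Path {value} is not under a mounted directory: {mounts}")
# Example (hypothetical mounts): _check_paths_are_mounted(script_config, ["/data", "/results"])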
@@ -154,7 +155,9 @@ def main(cluster_cfg):
if 'exp_manager' in merged_config and 'name' in merged_config['exp_manager']:
exp_name = merged_config['exp_manager']['name']
else:
raise ValueError("Experiment name not provided in the run config file (`exp_name`)) or the cluster config (inside exp_manager.name)")
raise ValueError(
"Experiment name not provided in the run config file (`exp_name`)) or the cluster config (inside exp_manager.name)"
)

with run.Experiment(exp_name) as exp:
cmd = get_execution_script(cluster_script_path, "config.yaml")
@@ -166,7 +169,8 @@ def main(cluster_cfg):
num_nodes = cluster_cfg.get('num_nodes', merged_config['trainer'].get('num_nodes', 1))
cluster_cfg = OmegaConf.to_object(cluster_cfg)

run_utils.add_task(exp,
run_utils.add_task(
exp,
cmd=cmd,
task_name=job_name,
cluster_config=cluster_cfg,
8 changes: 8 additions & 0 deletions examples/asr/speech_pretraining/README.md
@@ -5,3 +5,11 @@ This directory contains example scripts for training self-supervised speech models.
There are two main types of supported self-supervised learning methods:
- [Wav2vec-BERT](https://arxiv.org/abs/2108.06209): `speech_pre_training.py`
- [NEST](https://arxiv.org/abs/2408.13106): `masked_token_pred_pretrain.py`
- For downstream tasks that use NEST as a multi-layer feature extractor, please refer to `./downstream/speech_classification_mfa_train.py`


For their usage, please refer to the corresponding example YAML configs:
- Wav2vec-BERT: `examples/asr/conf/ssl/fastconformer/fast-conformer.yaml`
- NEST: `examples/asr/conf/ssl/nest/nest_fast-conformer.yaml`


2 changes: 2 additions & 0 deletions examples/asr/speech_pretraining/masked_token_pred_pretrain.py
@@ -28,7 +28,9 @@
python masked_token_pred_pretrain.py \
# (Optional: --config-path=<path to dir of configs> --config-name=<name of config without .yaml>) \
model.train_ds.manifest_filepath=<path to train manifest> \
model.train_ds.noise_manifest=<path to noise manifest> \
model.validation_ds.manifest_filepath=<path to val/test manifest> \
model.validation_ds.noise_manifest=<path to noise manifest> \
trainer.devices=-1 \
trainer.accelerator="gpu" \
strategy="ddp" \
9 changes: 7 additions & 2 deletions nemo/collections/asr/models/ssl_models.py
@@ -996,7 +996,12 @@ def training_step(self, batch: ssl_dataset.AudioNoiseBatch, batch_idx: int):
return {'loss': loss_value, 'log': tensorboard_logs}

def inference_pass(
self, batch: ssl_dataset.AudioNoiseBatch, batch_idx: int, dataloader_idx: int = 0, mode: str = 'val'
self,
batch: ssl_dataset.AudioNoiseBatch,
batch_idx: int,
dataloader_idx: int = 0,
mode: str = 'val',
apply_mask: bool = True,
):
log_probs, encoded_len, masks, tokens = self.forward(
input_signal=batch.audio,
@@ -1005,7 +1010,7 @@ def inference_pass(
noise_signal_length=batch.noise_len,
noisy_input_signal=batch.noisy_audio,
noisy_input_signal_length=batch.noisy_audio_len,
apply_mask=True,
apply_mask=apply_mask,
)

loss_value = self.loss(masks=masks, decoder_outputs=log_probs, targets=tokens, decoder_lengths=encoded_len)
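# Hedged usage sketch for the new `apply_mask` flag: run the same inference pass
# without masking, e.g. to inspect unmasked features during validation. The
# `model` and `batch` objects are assumed to exist; they are not built in this diff.
def validate_without_masking(model, batch):
    """Hypothetical helper: run the SSL inference pass with masking disabled."""
    return model.inference_pass(batch, batch_idx=0, dataloader_idx=0, mode='val', apply_mask=False)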
3 changes: 3 additions & 0 deletions nemo/collections/llm/__init__.py
@@ -73,6 +73,7 @@
MistralConfig7B,
MistralModel,
MistralNeMoConfig12B,
MixtralConfig,
MixtralConfig8x3B,
MixtralConfig8x7B,
MixtralConfig8x22B,
@@ -104,6 +105,7 @@
gpt_data_step,
gpt_forward_step,
)
from nemo.collections.llm.quantization import Quantizer, get_calib_data_iter
from nemo.collections.llm.t5.model import T5Config, T5Model, t5_data_step, t5_forward_step

__all__ = [
@@ -120,6 +122,7 @@
"MistralConfig7B",
"MistralNeMoConfig12B",
"MistralModel",
"MixtralConfig",
"MixtralConfig8x3B",
"MixtralConfig8x7B",
"MixtralConfig8x22B",
8 changes: 6 additions & 2 deletions nemo/collections/llm/gpt/data/mock.py
@@ -56,9 +56,13 @@ def __init__(
self.persistent_workers = persistent_workers
self.create_attention_mask = create_attention_mask or not HAVE_TE

from nemo.collections.nlp.modules.common.tokenizer_utils import get_nmt_tokenizer
if tokenizer is None:
from nemo.collections.nlp.modules.common.tokenizer_utils import get_nmt_tokenizer

self.tokenizer = get_nmt_tokenizer("megatron", "GPT2BPETokenizer")
else:
self.tokenizer = tokenizer

self.tokenizer = tokenizer or get_nmt_tokenizer("megatron", "GPT2BPETokenizer")
self.data_sampler = MegatronDataSampler(
seq_len=self.seq_length,
micro_batch_size=micro_batch_size,
2 changes: 2 additions & 0 deletions nemo/collections/llm/gpt/model/__init__.py
@@ -56,6 +56,7 @@
)
from nemo.collections.llm.gpt.model.mistral import MistralConfig7B, MistralModel, MistralNeMoConfig12B
from nemo.collections.llm.gpt.model.mixtral import (
MixtralConfig,
MixtralConfig8x3B,
MixtralConfig8x7B,
MixtralConfig8x22B,
Expand Down Expand Up @@ -105,6 +106,7 @@
"MixtralConfig8x3B",
"MixtralConfig8x7B",
"MixtralConfig8x22B",
"MixtralConfig",
"MixtralModel",
"Starcoder2Config",
"Starcoder2Model",
2 changes: 1 addition & 1 deletion nemo/collections/llm/gpt/model/ssm.py
@@ -139,7 +139,7 @@ def to(self, dtype):

source = ModelState(source)
target = self.init()
trainer = self.nemo_setup(target, ckpt_async_save=False)
trainer = self.nemo_setup(target)
source.to(self.config.params_dtype)
target.to(self.config.params_dtype)
self.convert_state(source, target)
25 changes: 25 additions & 0 deletions nemo/collections/llm/quantization/__init__.py
@@ -0,0 +1,25 @@
# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from .quantizer import ExportConfig, QuantizationConfig, Quantizer, create_data_iterator_getter, get_calib_data_iter
from .utils import load_with_modelopt_layer_spec

__all__ = [
"Quantizer",
"QuantizationConfig",
"ExportConfig",
"get_calib_data_iter",
"load_with_modelopt_layer_spec",
"create_data_iterator_getter",
]
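# Rough usage sketch mirroring the FP8 CI job added in .github/workflows/cicd-main.yml.
# Apart from the names exported in __all__ above, the constructor arguments and
# method names below are assumptions, not taken from this commit:
#
#     from nemo.collections.llm.quantization import (
#         ExportConfig, QuantizationConfig, Quantizer, load_with_modelopt_layer_spec,
#     )
#     model = load_with_modelopt_layer_spec("/tmp/nemo2_ckpt")           # checkpoint path from the CI job
#     quantizer = Quantizer(QuantizationConfig(algorithm="fp8"),         # assumed kwargs
#                           ExportConfig(path="/tmp/nemo2_ptq_engine"))
#     model = quantizer.quantize(model)                                  # assumed method name
#     quantizer.export(model)                                            # assumed method name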