-
Notifications
You must be signed in to change notification settings - Fork 2.6k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
PTQ example for NeMo 2.0 #10642
Merged
Merged
PTQ example for NeMo 2.0 #10642
Changes from 10 commits
Commits
Show all changes
32 commits
Select commit
Hold shift + click to select a range
f02f36c
initial commit
b17fe3c
create Quantizer for NeMo 2.0
6676ef6
refactor
9df6ac0
Call quantize on an unwrapped mcore model
30f5a97
Apply isort and black reformatting
Laplasjan107 afd3129
Merge branch 'NVIDIA:main' into nemo2_ptq_example
Laplasjan107 aa142f2
Merge branch 'NVIDIA:main' into nemo2_ptq_example
Laplasjan107 f7235f6
Merge branch 'NVIDIA:main' into nemo2_ptq_example
Laplasjan107 d228ef7
Add tests, adjust unwrapping
0ea0e0c
Apply isort and black reformatting
Laplasjan107 aeed59f
fix export
f3e6b30
Merge branch 'main' into nemo2_ptq_example
Laplasjan107 ab5beb3
Apply isort and black reformatting
Laplasjan107 332a6dc
Apply isort and black reformatting
artbataev c5a2a4d
Fix output_path argument for HF import
Laplasjan107 4c56058
Merge branch 'main' into nemo2_ptq_example
Laplasjan107 5300716
Merge branch 'main' into nemo2_ptq_example
Laplasjan107 20db14b
fix fabric ckpt loading
dc493c2
Apply isort and black reformatting
Laplasjan107 23fc9c2
code review suggestions
40ebf78
Apply isort and black reformatting
Laplasjan107 93d7d66
remove unused import
cf99e13
use cnn dataset in github ci
c3e6296
applied code review
b8a530c
code review changes
371b422
Apply isort and black reformatting
Laplasjan107 d7e54c2
simplify interface for data iterator
6fbab58
Apply isort and black reformatting
Laplasjan107 0daac30
Merge branch 'NVIDIA:main' into nemo2_ptq_example
Laplasjan107 5626ec3
(partial) PP fix
2925392
Apply isort and black reformatting
Laplasjan107 007f7bb
Merge branch 'NVIDIA:main' into nemo2_ptq_example
Laplasjan107 File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,140 @@ | ||
# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. | ||
# | ||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
|
||
import argparse | ||
import sys | ||
|
||
import torch | ||
from tqdm import tqdm | ||
|
||
from nemo.collections.llm.quantization import Quantizer, get_calib_data_iter | ||
|
||
|
||
def get_args(): | ||
parser = argparse.ArgumentParser( | ||
formatter_class=argparse.ArgumentDefaultsHelpFormatter, | ||
description="NeMo PTQ argument parser", | ||
) | ||
parser.add_argument("-nc", "--nemo_checkpoint", type=str, help="Source NeMo 2.0 checkpoint") | ||
parser.add_argument("--decoder_type", type=str, help="Decoder type for TensorRT-Model-Optimizer") | ||
parser.add_argument("-ctp", "--calib_tp", type=int, default=1) | ||
parser.add_argument("-cpp", "--calib_pp", type=int, default=1) | ||
parser.add_argument("-tps", "--tensor_parallelism_size", type=int, default=1) | ||
parser.add_argument("-pps", "--pipeline_parallelism_size", type=int, default=1) | ||
parser.add_argument('-out', '--output_path', type=str, help='Path for the exported engine') | ||
parser.add_argument( | ||
'-algo', | ||
'--algorithm', | ||
type=str, | ||
default="no_quant", | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Let's maybe use "fp8" by default? |
||
choices=["no_quant", "int8", "int8_sq", "fp8", "int4_awq", "w4a8_awq", "int4"], | ||
help='TensorRT-Model-Optimizer quantization algorithm', | ||
) | ||
parser.add_argument( | ||
'-awq_bs', '--awq_block_size', type=int, default=128, help='Block size for AWQ quantization algorithms' | ||
) | ||
parser.add_argument('--sq_alpha', type=float, default=0.5, help='Smooth-Quant alpha parameter') | ||
parser.add_argument('--enable_kv_cache', type=bool, help='Enables KV-cache quantization') | ||
Laplasjan107 marked this conversation as resolved.
Show resolved
Hide resolved
|
||
parser.add_argument( | ||
'-dt', '--dtype', default="bf16", choices=["16", "bf16"], help='Default precision for non-quantized layers' | ||
) | ||
parser.add_argument('-bs', '--batch_size', default=64, type=int, help='Calibration batch size') | ||
parser.add_argument('-sl', '--seq_len', default=128, type=int, help='Length of the tokenized text') | ||
parser.add_argument( | ||
'-calib_size', '--calibration_dataset_size', default=512, type=int, help='Size of calibration dataset' | ||
) | ||
parser.add_argument( | ||
'-calib_ds', | ||
'--calibration_dataset', | ||
default="cnn_dailymail", | ||
choices=["wikitext", "cnn_dailymail"], | ||
Laplasjan107 marked this conversation as resolved.
Show resolved
Hide resolved
|
||
type=str, | ||
help='Calibration dataset to be used', | ||
) | ||
|
||
return parser.parse_args(sys.argv[1:]) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. No need to pass |
||
|
||
|
||
def get_quantizer_config(args): | ||
if args.output_path is None: | ||
args.output_path = ( | ||
f"./qnemo_{args.algorithm}_tp{args.tensor_parallelism_size}_pp{args.pipeline_parallelism_size}" | ||
) | ||
|
||
quantization_config = { | ||
"algorithm": None if args.algorithm == "no_quant" else args.algorithm, | ||
"awq_block_size": args.awq_block_size, | ||
"sq_alpha": args.sq_alpha, | ||
"enable_kv_cache": args.enable_kv_cache, | ||
} | ||
export_config = { | ||
"path": args.output_path, | ||
"decoder_type": args.decoder_type, | ||
"inference_tensor_parallel": args.tensor_parallelism_size, | ||
"inference_pipeline_parallel": args.pipeline_parallelism_size, | ||
"dtype": args.dtype, | ||
} | ||
|
||
return quantization_config, export_config | ||
|
||
|
||
def create_data_iterator_getter(model, dataset, seq_len, batch_size, calibration_size): | ||
def _iterator(): | ||
CHARACTERS_PER_TOKEN = 4 | ||
|
||
dataloader = get_calib_data_iter( | ||
data=dataset, | ||
max_sequence_length=CHARACTERS_PER_TOKEN * seq_len, | ||
batch_size=batch_size, | ||
calib_size=calibration_size, | ||
) | ||
for batch in dataloader: | ||
batch = [model.tokenizer.text_to_ids(text)[:seq_len] for text in batch] | ||
batch = [ids + (seq_len - len(ids)) * [model.tokenizer.eos] for ids in batch] | ||
yield torch.tensor(batch, device=model.device) | ||
|
||
def _iterator_getter(): | ||
dataloader = _iterator() | ||
dataloader = [data for data in dataloader] | ||
return iter(tqdm(dataloader)) | ||
|
||
return _iterator_getter | ||
|
||
|
||
def main(): | ||
params = get_args() | ||
quantization_config, export_config = get_quantizer_config(params) | ||
quantizer = Quantizer(quantization_config, export_config) | ||
model = quantizer.load_quantizable_model(params.nemo_checkpoint, params.calib_tp, params.calib_pp) | ||
|
||
get_dataloader = create_data_iterator_getter( | ||
model, | ||
dataset=params.calibration_dataset, | ||
seq_len=params.seq_len, | ||
batch_size=params.batch_size, | ||
calibration_size=params.calibration_dataset_size, | ||
) | ||
|
||
forward_loop = quantizer.create_megatron_forward_loop( | ||
get_dataloader, | ||
num_batches=params.calibration_dataset_size // params.batch_size, | ||
seq_length=params.seq_len, | ||
micro_batch_size=params.batch_size, | ||
) | ||
|
||
model = quantizer.quantize(model, forward_loop) | ||
quantizer.export(model) | ||
|
||
|
||
if __name__ == '__main__': | ||
main() |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,15 @@ | ||
# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. | ||
# | ||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
|
||
from .quantizer import Quantizer, get_calib_data_iter |
Oops, something went wrong.
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Just a friendly reminder to make sure it doesn't get lost once this PR is ready: this job name also needs to be added to
CICD_Nemo_Test
like the othersThere was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Added