* rename all scripts
  Signed-off-by: Chen Cui <[email protected]>
* [pre-commit.ci] auto fixes from pre-commit.com hooks
  for more information, see https://pre-commit.ci
* fix CodeQL error
  Signed-off-by: Chen Cui <[email protected]>
* rename finetune_generate to just generate
  Signed-off-by: Chen Cui <[email protected]>

---------

Signed-off-by: Chen Cui <[email protected]>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
1 parent e7e007b, commit d656f22
Showing 16 changed files with 379 additions and 37 deletions.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
81 changes: 81 additions & 0 deletions
examples/nlp/language_modeling/tuning/megatron_gpt_finetuning.py
@@ -0,0 +1,81 @@
# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import torch.multiprocessing as mp
from omegaconf.omegaconf import OmegaConf

from nemo.collections.nlp.models.language_modeling.megatron_gpt_sft_model import MegatronGPTSFTModel
from nemo.collections.nlp.parts.megatron_trainer_builder import MegatronLMPPTrainerBuilder
from nemo.collections.nlp.parts.peft_config import PEFT_CONFIG_MAP
from nemo.core.config import hydra_runner
from nemo.utils import logging
from nemo.utils.exp_manager import exp_manager

mp.set_start_method("spawn", force=True)

""" | ||
This is the script to finetuning a GPT Model with any PEFT method. | ||
A base GPT Model is required as a starting point. This script will then insert | ||
Adapters into each Transformer layer and will train/update only these adapters | ||
during training. The base GPT Model weights will remain frozen. | ||
During training this script will only save the newly trained Adapter weights | ||
in checkpoints. At the end of training a .nemo file of Adapter weights will | ||
be saved. | ||
Usage: | ||
Assuming the base model is a 125m GPT Model, with TP=1, PP=1: | ||
a. run a training run for a base gpt nemo file: | ||
python megatron_gpt_finetuning.py \ | ||
"model.data.train_ds.file_names=[PATH TO TRAINING JSONL FILE]", | ||
"model.data.train_ds.concat_sampling_probabilities=[SAMPLING VAL]", | ||
"model.data.validation_ds.file_names=[PATH TO VALIDATION JSONL FILE]", | ||
"model.data.validation_ds.names=[NAME FOR METRIC LOGGING]", | ||
model.restore_from_path="PATH TO BASE GPT MODEL .nemo FILE" | ||
model.peft.peft_scheme='lora' # lora, ptuning, adapter, ia3, or none for full fineutning | ||
name="NAME OF TRAINING RUN" | ||
exp_manager.exp_dir="DIR TO SAVE CHECKPOINTS and .nemo FILE", | ||
Please see lora.ipynb for a step-by-step guide. | ||
""" | ||


@hydra_runner(config_path="conf", config_name="megatron_gpt_finetuning_config")
def main(cfg) -> None:
    logging.info("\n\n************** Experiment configuration ***********")
    logging.info(f'\n{OmegaConf.to_yaml(cfg)}')

    trainer = MegatronLMPPTrainerBuilder(cfg).create_trainer()
    exp_manager(trainer, cfg.exp_manager)

    model_cfg = MegatronGPTSFTModel.merge_cfg_with(cfg.model.restore_from_path, cfg)
    model = MegatronGPTSFTModel.restore_from(cfg.model.restore_from_path, model_cfg, trainer=trainer)
    peft_cfg_cls = PEFT_CONFIG_MAP[cfg.model.peft.peft_scheme]

    if cfg.model.peft.restore_from_path is not None:
        # Initialize PEFT weights from a checkpoint instead of randomly.
        # This is not the same as resuming training, because optimizer states are not restored.
logging.info("PEFT Weights will be loaded from", cfg.model.peft.restore_from_path) | ||
        model.load_adapters(cfg.model.peft.restore_from_path, peft_cfg_cls(model_cfg))
    elif peft_cfg_cls is not None:
        logging.info("Adding adapter weights to the model for PEFT")
        model.add_adapter(peft_cfg_cls(model_cfg))
    else:
        logging.info(f"Running full finetuning since no peft scheme is given.\n{model.summarize()}")

    trainer.fit(model)


if __name__ == '__main__':
    main()
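For readers new to the adapter approach the docstring above describes, here is a minimal, self-contained PyTorch sketch of the idea: the base weights are frozen and only a small adapter module receives gradient updates, so only the adapter needs to be checkpointed. This is a conceptual illustration, not NeMo's implementation; the layer size and the LoRAAdapter class are made up for the example.

import torch
import torch.nn as nn

# A stand-in for one frozen Transformer block of the base model.
base_layer = nn.Linear(512, 512)
for p in base_layer.parameters():
    p.requires_grad = False  # base weights stay frozen, as the docstring describes


class LoRAAdapter(nn.Module):
    """A small LoRA-style adapter: a low-rank bottleneck added to the layer output."""

    def __init__(self, dim: int, rank: int = 8):
        super().__init__()
        self.down = nn.Linear(dim, rank, bias=False)
        self.up = nn.Linear(rank, dim, bias=False)
        nn.init.zeros_(self.up.weight)  # start as a no-op so training begins from the base model

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.up(self.down(x))


adapter = LoRAAdapter(512)

# Only adapter parameters go to the optimizer, so only they are updated and
# only they would need to be saved in a checkpoint.
optimizer = torch.optim.AdamW(adapter.parameters(), lr=1e-4)

x = torch.randn(4, 512)
target = torch.randn(4, 512)
output = base_layer(x) + adapter(x)  # frozen path plus trainable low-rank update
loss = nn.functional.mse_loss(output, target)
loss.backward()
optimizer.step()

In the script above, the analogous steps are restoring the frozen base model from its .nemo file and calling model.add_adapter(peft_cfg_cls(model_cfg)), after which only the adapter weights are trained and saved.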
143 changes: 143 additions & 0 deletions
examples/nlp/language_modeling/tuning/megatron_gpt_generate.py
@@ -0,0 +1,143 @@
# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


import asyncio
import threading
from functools import partial

import torch
import torch.multiprocessing as mp
from omegaconf.omegaconf import OmegaConf

from nemo.collections.nlp.models.language_modeling.megatron_gpt_sft_model import MegatronGPTSFTModel
from nemo.collections.nlp.modules.common.text_generation_server import MegatronServer
from nemo.collections.nlp.modules.common.text_generation_utils import generate
from nemo.collections.nlp.parts.megatron_trainer_builder import MegatronLMPPTrainerBuilder
from nemo.core.config import hydra_runner
from nemo.utils import logging

try:
    from megatron.core import parallel_state

    HAVE_MEGATRON_CORE = True
except (ImportError, ModuleNotFoundError):
    HAVE_MEGATRON_CORE = False

mp.set_start_method("spawn", force=True)
""" | ||
This is the script to run inference with a PEFT model or an SFT Model. | ||
If you want to evaluate an SFT .nemo file: | ||
python examples/nlp/language_modeling/tuning/megatron_gpt_generate.py \ | ||
model.restore_from_path=<path_to_sft_nemo_file> \ | ||
model.peft.restore_from_path=null \ | ||
trainer.devices=1 model.data.test_ds.file_names=\[<path_to_test_jsonl_file1>, <path_to_test_jsonl_file2>] \ | ||
model.data.test_ds.names=\['name_for_test_file1', 'name_for_test_file2'] \ # this is not the filename just some identifier | ||
model.data.test_ds.global_batch_size=4 \ # or some other value | ||
model.data.test_ds.micro_batch_size=4 \ | ||
model.data.test_ds.tokens_to_generate=30 \ | ||
inference.greedy=True \ | ||
inference.outfile_path=\'<path_to_jsonl_output_file>' | ||
If you want to evaluate a PEFT Model, you should provide a base GPT model and a PEFT model .nemo file | ||
python examples/nlp/language_modeling/tuning/megatron_gpt_generate.py \ | ||
model.restore_from_path=<path_to_sft_nemo_file> \ | ||
model.peft.restore_from_path=<path_to_peft_nemo_file> \ # this will be created if you use `megatron_gpt_finetuning.py` | ||
trainer.devices=1 model.data.test_ds.file_names=\[<path_to_test_jsonl_file1>, <path_to_test_jsonl_file2>] \ | ||
model.data.test_ds.names=\['name_for_test_file1', 'name_for_test_file2'] \ # this is not the filename just some identifier | ||
model.data.test_ds.global_batch_size=4 \ # or some other value | ||
model.data.test_ds.micro_batch_size=4 \ | ||
model.data.test_ds.tokens_to_generate=30 \ | ||
inference.greedy=True \ | ||
inference.outfile_path=\'<path_to_jsonl_output_file>' | ||
""" | ||


def use_inference_server(cfg, model, trainer):
    if not HAVE_MEGATRON_CORE:
        raise ValueError('Megatron-core needs to be installed to use this feature!')

    from nemo.collections.nlp.modules.common.megatron_web_server import get_chatbot_demo, get_demo

    trainer.test(model, dataloaders=None)

    if parallel_state.is_pipeline_first_stage() and parallel_state.get_tensor_model_parallel_rank() == 0:
        if cfg.web_server:
            if cfg.chat:
                defaults = {
                    'user': cfg.chatbot_config.user,
                    'assistant': cfg.chatbot_config.assistant,
                    'system': cfg.chatbot_config.system,
                }
                web_ui = partial(
                    get_chatbot_demo,
                    defaults=defaults,
                    value=cfg.chatbot_config.value,
                    attributes=cfg.chatbot_config.attributes,
                )
            else:
                web_ui = get_demo
            loop = asyncio.new_event_loop()
            thread = threading.Thread(
                target=web_ui, daemon=True, args=(cfg.share, cfg.username, cfg.password, cfg.port, cfg.web_port, loop),
            )
            thread.start()
        server = MegatronServer(model.cuda())
        server.run("0.0.0.0", port=cfg.port)

    while True:
        choice = torch.cuda.LongTensor(1)
        torch.distributed.broadcast(choice, 0)
        if choice[0].item() == 0:
            generate(model.cuda())


@hydra_runner(config_path="conf", config_name="megatron_gpt_generate_config")
def main(cfg) -> None:
    logging.info("\n\n************** Experiment configuration ***********")
    logging.info(f"\n{OmegaConf.to_yaml(cfg)}")
    trainer = MegatronLMPPTrainerBuilder(cfg).create_trainer()

    if cfg.model.peft.restore_from_path:
        model_cfg = MegatronGPTSFTModel.merge_inference_cfg(cfg.model.peft.restore_from_path, cfg)
    else:
        model_cfg = MegatronGPTSFTModel.merge_inference_cfg(cfg.model.restore_from_path, cfg)

    model = MegatronGPTSFTModel.restore_from(cfg.model.restore_from_path, model_cfg, trainer=trainer)

    if cfg.model.peft.restore_from_path:
        model.load_adapters(cfg.model.peft.restore_from_path)

    model.freeze()
    logging.info(f"Freezing parameters for PEFT eval:\n{model.summarize()}")

    if not cfg.model.get('use_flash_attention', False):
        cfg.inference.compute_attention_mask = True
    config = OmegaConf.to_container(cfg.inference, resolve=True)
    model.set_inference_config(config)

    if not cfg.server:
        trainer.test(model)
    else:
        use_inference_server(cfg, model, trainer)


if __name__ == "__main__":
    main()
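The while-True loop at the end of use_inference_server relies on a coordination pattern that is easy to miss: rank 0 runs the text-generation server, and every other rank blocks on a torch.distributed.broadcast that tells it when to join the collective generate call. Below is a simplified sketch of that pattern in plain torch.distributed; the run_generation helper and the GENERATE/SHUTDOWN signal values are illustrative placeholders rather than NeMo APIs, and the gloo backend is chosen only to keep the sketch CPU-runnable.

import torch
import torch.distributed as dist

GENERATE = 0  # value broadcast by rank 0 when a generation request arrives
SHUTDOWN = 1  # any other value makes the worker ranks leave the loop


def run_generation(rank: int) -> None:
    # Placeholder for the collective generation call that all ranks must enter together.
    print(f"rank {rank} entering generation")


def main() -> None:
    dist.init_process_group(backend="gloo")  # NeMo runs on GPUs; gloo keeps this sketch CPU-runnable
    rank = dist.get_rank()
    signal = torch.zeros(1, dtype=torch.long)

    if rank == 0:
        # Rank 0 would normally serve HTTP requests; here it issues one request and then shuts down.
        for value in (GENERATE, SHUTDOWN):
            signal[0] = value
            dist.broadcast(signal, src=0)  # wake up every other rank
            if value == GENERATE:
                run_generation(rank)
    else:
        # Worker ranks idle until rank 0 broadcasts, mirroring the script's while-loop.
        while True:
            dist.broadcast(signal, src=0)
            if signal[0].item() == GENERATE:
                run_generation(rank)
            else:
                break

    dist.destroy_process_group()


if __name__ == "__main__":
    main()

Launched with something like `torchrun --nproc_per_node=2 sketch.py`, rank 0 plays the role of the MegatronServer process while the remaining ranks mirror the broadcast-driven loop in the script above.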