From 135023c5ca47e6f38e57c883ef41cb0895811dec Mon Sep 17 00:00:00 2001
From: eharper
Date: Sat, 30 Sep 2023 14:07:57 -0600
Subject: [PATCH] propagate config

Signed-off-by: eharper
---
 .../nlp/language_modeling/tuning/megatron_gpt_peft_tuning.py | 3 +++
 1 file changed, 3 insertions(+)

diff --git a/examples/nlp/language_modeling/tuning/megatron_gpt_peft_tuning.py b/examples/nlp/language_modeling/tuning/megatron_gpt_peft_tuning.py
index 2c9293b2600e..48c1727b4af9 100644
--- a/examples/nlp/language_modeling/tuning/megatron_gpt_peft_tuning.py
+++ b/examples/nlp/language_modeling/tuning/megatron_gpt_peft_tuning.py
@@ -105,6 +105,9 @@ def _modify_config(gpt_cfg, cfg, add_cfg_to_tree=False):
         gpt_cfg.peft = cfg.model.peft
         peft_cls = _get_peft_scheme(cfg.model)
         gpt_cfg.target = f"{peft_cls.__module__}.{peft_cls.__name__}"
+        gpt_cfg.tensor_model_parallel_size = cfg.model.get('tensor_model_parallel_size', 1)
+        gpt_cfg.pipeline_model_parallel_size = cfg.model.get('pipeline_model_parallel_size', 1)
+        gpt_cfg.pipeline_model_parallel_split_rank = cfg.model.get('pipeline_model_parallel_split_rank', 0)
         # This is needed when modifying a hparam file directly to load `.ckpt` files.
         # This is not needed to modify the cfg in `.nemo` files.
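
Context for the added lines: they use OmegaConf's dict-style get(), which returns the supplied
default when a key is absent from the training config, so configs that never set the parallelism
options fall back to single-GPU values (TP=1, PP=1, split rank 0). A minimal sketch of that
fallback behavior follows; the config contents are hypothetical and only illustrate the get()
semantics the patch relies on.

    from omegaconf import OmegaConf

    # Hypothetical training config: only tensor parallelism is set explicitly.
    cfg = OmegaConf.create({"model": {"tensor_model_parallel_size": 2}})

    # Mirrors the patched lines in _modify_config: absent keys fall back to
    # their defaults instead of raising, so older configs keep working.
    tp = cfg.model.get("tensor_model_parallel_size", 1)             # 2, read from cfg
    pp = cfg.model.get("pipeline_model_parallel_size", 1)           # 1, key absent -> default
    split = cfg.model.get("pipeline_model_parallel_split_rank", 0)  # 0, key absent -> default

    print(tp, pp, split)  # 2 1 0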