Skip to content

Commit

Permalink
propagate config (#7589)
Browse files Browse the repository at this point in the history
Signed-off-by: eharper <[email protected]>
  • Loading branch information
ericharper authored Oct 2, 2023
1 parent 8ce1ac5 commit d64edf9
Showing 1 changed file with 3 additions and 0 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -105,6 +105,9 @@ def _modify_config(gpt_cfg, cfg, add_cfg_to_tree=False):
gpt_cfg.peft = cfg.model.peft
peft_cls = _get_peft_scheme(cfg.model)
gpt_cfg.target = f"{peft_cls.__module__}.{peft_cls.__name__}"
gpt_cfg.tensor_model_parallel_size = cfg.model.get('tensor_model_parallel_size', 1)
gpt_cfg.pipeline_model_parallel_size = cfg.model.get('pipeline_model_parallel_size', 1)
gpt_cfg.pipeline_model_parallel_split_rank = cfg.model.get('pipeline_model_parallel_split_rank', 0)

# This is needed when modifying a hparam file directly to load `.ckpt` files.
# This is not needed to modify the cfg in `.nemo` files.
Expand Down

0 comments on commit d64edf9

Please sign in to comment.