consolidate peft and sft scripts
Signed-off-by: Chen Cui <[email protected]>
cuichenx committed Sep 26, 2023
1 parent 8e238fa commit dc0fe10
Showing 4 changed files with 19 additions and 5 deletions.
@@ -114,10 +114,15 @@ def main(cfg) -> None:
logging.info(f"\n{OmegaConf.to_yaml(cfg)}")
trainer = MegatronLMPPTrainerBuilder(cfg).create_trainer()

-model_cfg = MegatronGPTSFTModel.merge_inference_cfg(cfg.model.peft.restore_from_path, cfg)
+if cfg.model.peft.restore_from_path:
+    model_cfg = MegatronGPTSFTModel.merge_inference_cfg(cfg.model.peft.restore_from_path, cfg)
+else:
+    model_cfg = MegatronGPTSFTModel.merge_inference_cfg(cfg.model.restore_from_path, cfg)

model = MegatronGPTSFTModel.restore_from(cfg.model.restore_from_path, model_cfg, trainer=trainer)

-model.load_adapters(cfg.model.peft.restore_from_path)
+if cfg.model.peft.restore_from_path:
+    model.load_adapters(cfg.model.peft.restore_from_path)

model.freeze()
logging.info(f"Freezing parameters for PEFT eval:\n{model.summarize()}")
@@ -44,9 +44,10 @@
"model.data.validation_ds.file_names=[PATH TO VALIDATION JSONL FILE]",
"model.data.validation_ds.names=[NAME FOR METRIC LOGGING]",
model.restore_from_path="PATH TO BASE GPT MODEL .nemo FILE"
model.peft.peft_scheme='lora' # lora, ptuning, adapter, ia3, or none for full finetuning
name="NAME OF TRAINING RUN"
exp_manager.exp_dir="DIR TO SAVE CHECKPOINTS and .nemo FILE",
-Please see lora_tutorial.md for a step-by-step guide.
+Please see lora.ipynb for a step-by-step guide.
"""


@@ -67,8 +68,11 @@ def main(cfg) -> None:
    # This is not the same as resume training because optimizer states are not restored.
    logging.info("PEFT Weights will be loaded from", cfg.model.peft.restore_from_path)
    model.load_adapters(cfg.model.peft.restore_from_path, peft_cfg_cls(model_cfg))
-else:
+elif peft_cfg_cls is not None:
    logging.info("Adding adapter weights to the model for PEFT")
    model.add_adapter(peft_cfg_cls(model_cfg))
+else:
+    logging.info(f"Running full finetuning since no peft scheme is given.\n{model.summarize()}")

trainer.fit(model)
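
Taken together, a hedged sketch of the consolidated training flow. The PEFT_CONFIG_MAP lookup and the leading if condition sit outside the hunk shown above and are inferred from context, so treat those exact names as assumptions.

# Assumed lookup (not shown in this hunk): the configured scheme name is mapped
# to a config class via the dict in peft_config.py; 'none' or null maps to None.
peft_cfg_cls = PEFT_CONFIG_MAP[cfg.model.peft.peft_scheme]

if cfg.model.peft.restore_from_path:
    # Continue PEFT training from previously saved adapter weights
    # (optimizer states are not restored, so this is not a full resume).
    model.load_adapters(cfg.model.peft.restore_from_path, peft_cfg_cls(model_cfg))
elif peft_cfg_cls is not None:
    # Fresh PEFT run: attach newly initialized adapter modules to the base model.
    model.add_adapter(peft_cfg_cls(model_cfg))
else:
    # peft_scheme was 'none' (or omitted): every base-model weight stays trainable.
    logging.info(f"Running full finetuning since no peft scheme is given.\n{model.summarize()}")

trainer.fit(model)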

@@ -52,8 +52,11 @@ def main(cfg) -> None:
    # This is not the same as resume training because optimizer states are not restored.
    logging.info("PEFT Weights will be loaded from", cfg.model.peft.restore_from_path)
    model.load_adapters(cfg.model.peft.restore_from_path, peft_cfg_cls(model_cfg))
-else:
+elif peft_cfg_cls is not None:
    logging.info("Adding adapter weights to the model for PEFT")
    model.add_adapter(peft_cfg_cls(model_cfg))
+else:
+    logging.info(f"Running full finetuning since no peft scheme is given.\n{model.summarize()}")

trainer.fit(model)

2 changes: 2 additions & 0 deletions nemo/collections/nlp/parts/peft_config.py
@@ -185,4 +185,6 @@ def __init__(self, cfg):
"ia3": IA3PEFTConfig,
"ptuning": PtuningPEFTConfig,
"lora": LoraPEFTConfig,
'none': None,
None: None,
}
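
The two new entries let a missing or 'none' peft_scheme resolve to None instead of raising a KeyError, which is what routes the tuning scripts into their new full-finetuning branch. A small illustrative lookup follows; the dict name PEFT_CONFIG_MAP is an assumption based on this file and is not shown verbatim in the diff.

# Illustrative only; assumes the dict above is exported as PEFT_CONFIG_MAP.
from nemo.collections.nlp.parts.peft_config import PEFT_CONFIG_MAP

for scheme in ("lora", "none", None):
    peft_cfg_cls = PEFT_CONFIG_MAP[scheme]
    # LoraPEFTConfig for 'lora'; None for 'none' and for a null scheme
    print(scheme, "->", peft_cfg_cls)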
