diff --git a/litgpt/finetune/lora.py b/litgpt/finetune/lora.py
index d142385011..aa9cad85e3 100644
--- a/litgpt/finetune/lora.py
+++ b/litgpt/finetune/lora.py
@@ -20,6 +20,7 @@
 from litgpt.generate.base import generate
 from litgpt.lora import GPT, Block, Config, lora_filter, mark_only_lora_as_trainable
 from litgpt.prompts import save_prompt_style
+from litgpt.scripts.merge_lora import merge_lora
 from litgpt.tokenizer import Tokenizer
 from litgpt.utils import (
     CLI,
@@ -169,6 +170,8 @@ def main(fabric: L.Fabric, devices: int, seed: int, config: Config, data: LitDat
     copy_config_files(checkpoint_dir, save_path.parent)
     save_hyperparameters(setup, save_path.parent)
     save_prompt_style(data.prompt_style, save_path.parent)
+    fabric.barrier()
+    merge_lora(checkpoint_dir=save_path.parent)
 
 
 def fit(