From 1cb2f48f9c0534ece45041a6cacd668e3b509614 Mon Sep 17 00:00:00 2001
From: Abhishek Thakur
Date: Mon, 26 Aug 2024 11:25:07 +0200
Subject: [PATCH] clear memory before merging adapters

---
 configs/llm_finetuning/gpt2_sft.yml | 3 ++-
 src/autotrain/trainers/clm/utils.py | 4 ++++
 2 files changed, 6 insertions(+), 1 deletion(-)

diff --git a/configs/llm_finetuning/gpt2_sft.yml b/configs/llm_finetuning/gpt2_sft.yml
index 44160ce58b..50350948ed 100644
--- a/configs/llm_finetuning/gpt2_sft.yml
+++ b/configs/llm_finetuning/gpt2_sft.yml
@@ -24,8 +24,9 @@ params:
   scheduler: linear
   gradient_accumulation: 4
   mixed_precision: fp16
+  merge_adapter: true
 
 hub:
   username: ${HF_USERNAME}
   token: ${HF_TOKEN}
-  push_to_hub: true
\ No newline at end of file
+  push_to_hub: false
\ No newline at end of file

diff --git a/src/autotrain/trainers/clm/utils.py b/src/autotrain/trainers/clm/utils.py
index d2a09ee6ef..4a49da4da3 100644
--- a/src/autotrain/trainers/clm/utils.py
+++ b/src/autotrain/trainers/clm/utils.py
@@ -1,4 +1,5 @@
 import ast
+import gc
 import os
 from enum import Enum
 from itertools import chain
@@ -295,6 +296,9 @@ def post_training_steps(config, trainer):
         f.write(model_card)
 
     if config.peft and config.merge_adapter:
+        del trainer
+        gc.collect()
+        torch.cuda.empty_cache()
         logger.info("Merging adapter weights...")
         try:
             merge_adapter(