From 0cfefc496e0c6bb888a34e058ec87894792d827d Mon Sep 17 00:00:00 2001 From: Abhishek Thakur Date: Fri, 8 Dec 2023 18:19:18 +0100 Subject: [PATCH] fix --- src/autotrain/trainers/clm/__main__.py | 18 ++--------- src/autotrain/trainers/common.py | 31 +++++++++++++++++++ src/autotrain/trainers/dreambooth/__main__.py | 20 ++---------- src/autotrain/trainers/generic/__main__.py | 21 ++----------- .../trainers/image_classification/__main__.py | 4 +++ src/autotrain/trainers/seq2seq/__main__.py | 18 ++--------- src/autotrain/trainers/tabular/__main__.py | 18 ++--------- .../trainers/text_classification/__main__.py | 18 ++--------- src/autotrain/utils.py | 2 +- 9 files changed, 49 insertions(+), 101 deletions(-) diff --git a/src/autotrain/trainers/clm/__main__.py b/src/autotrain/trainers/clm/__main__.py index 045771e031..d553b28114 100644 --- a/src/autotrain/trainers/clm/__main__.py +++ b/src/autotrain/trainers/clm/__main__.py @@ -27,6 +27,7 @@ from autotrain.trainers.clm import utils from autotrain.trainers.clm.callbacks import LoadBestPeftModelCallback, SavePeftModelCallback from autotrain.trainers.clm.params import LLMTrainingParams +from autotrain.trainers.common import pause_space from autotrain.utils import monitor @@ -494,22 +495,7 @@ def train(config): ) if PartialState().process_index == 0: - if "SPACE_ID" in os.environ: - # shut down the space - logger.info("Pausing space...") - api = HfApi(token=config.token) - api.pause_space(repo_id=os.environ["SPACE_ID"]) - success_message = f"Your training run was successful! [Check out your trained model here](https://huggingface.co/{config.username}/{config.project_name})" - api.create_discussion( - repo_id=os.environ["SPACE_ID"], - title="Your training has finished successfully ✅", - description=success_message, - repo_type="space", - ) - if "ENDPOINT_ID" in os.environ: - # shut down the endpoint - logger.info("Pausing endpoint...") - utils.pause_endpoint(config) + pause_space(config) if __name__ == "__main__": diff --git a/src/autotrain/trainers/common.py b/src/autotrain/trainers/common.py index 251608bc3d..00ef0090c6 100644 --- a/src/autotrain/trainers/common.py +++ b/src/autotrain/trainers/common.py @@ -3,11 +3,42 @@ """ import os +import requests +from huggingface_hub import HfApi from pydantic import BaseModel from autotrain import logger +def pause_endpoint(params): + endpoint_id = os.environ["ENDPOINT_ID"] + username = endpoint_id.split("/")[0] + project_name = endpoint_id.split("/")[1] + api_url = f"https://api.endpoints.huggingface.cloud/v2/endpoint/{username}/{project_name}/pause" + headers = {"Authorization": f"Bearer {params.token}"} + r = requests.post(api_url, headers=headers, timeout=120) + return r.json() + + +def pause_space(params): + if "SPACE_ID" in os.environ: + # shut down the space + logger.info("Pausing space...") + api = HfApi(token=params.token) + success_message = f"Your training run was successful! [Check out your trained model here](https://huggingface.co/{params.repo_id})" + api.create_discussion( + repo_id=os.environ["SPACE_ID"], + title="Your training has finished successfully ✅", + description=success_message, + repo_type="space", + ) + api.pause_space(repo_id=os.environ["SPACE_ID"]) + if "ENDPOINT_ID" in os.environ: + # shut down the endpoint + logger.info("Pausing endpoint...") + pause_endpoint(params) + + class AutoTrainParams(BaseModel): """ Base class for all AutoTrain parameters. diff --git a/src/autotrain/trainers/dreambooth/__main__.py b/src/autotrain/trainers/dreambooth/__main__.py index 1e1a5c6fc0..6f700c619e 100644 --- a/src/autotrain/trainers/dreambooth/__main__.py +++ b/src/autotrain/trainers/dreambooth/__main__.py @@ -15,9 +15,9 @@ SlicedAttnAddedKVProcessor, ) from diffusers.models.lora import LoRALinearLayer -from huggingface_hub import HfApi, snapshot_download +from huggingface_hub import snapshot_download -from autotrain import logger +from autotrain.trainers.common import pause_space from autotrain.trainers.dreambooth import utils from autotrain.trainers.dreambooth.datasets import DreamBoothDataset, collate_fn from autotrain.trainers.dreambooth.params import DreamBoothTrainingParams @@ -328,21 +328,7 @@ def load_model_hook(models, input_dir): if config.push_to_hub: trainer.push_to_hub() - if "SPACE_ID" in os.environ: - # remove config.image_path directory if it exists - if os.path.exists(config.image_path): - os.system(f"rm -rf {config.image_path}") - # shut down the space - logger.info("Pausing space...") - api = HfApi(token=config.token) - api.pause_space(repo_id=os.environ["SPACE_ID"]) - success_message = f"Your training run was successful! [Check out your trained model here](https://huggingface.co/{config.username}/{config.project_name})" - api.create_discussion( - repo_id=os.environ["SPACE_ID"], - title="Your training has finished successfully ✅", - description=success_message, - repo_type="space", - ) + pause_space(config) if __name__ == "__main__": diff --git a/src/autotrain/trainers/generic/__main__.py b/src/autotrain/trainers/generic/__main__.py index dae8845888..29cd726527 100644 --- a/src/autotrain/trainers/generic/__main__.py +++ b/src/autotrain/trainers/generic/__main__.py @@ -1,10 +1,8 @@ import argparse import json -import os - -from huggingface_hub import HfApi from autotrain import logger +from autotrain.trainers.common import pause_space from autotrain.trainers.generic import utils from autotrain.trainers.generic.params import GenericParams from autotrain.utils import monitor @@ -37,22 +35,7 @@ def run(config): logger.info("Running command...") utils.run_command(config) - if "SPACE_ID" in os.environ: - # shut down the space - logger.info("Pausing space...") - api = HfApi(token=config.token) - api.pause_space(repo_id=os.environ["SPACE_ID"]) - success_message = f"Your training run was successful! [Check out your trained model here](https://huggingface.co/{config.username}/{config.project_name})" - api.create_discussion( - repo_id=os.environ["SPACE_ID"], - title="Your training has finished successfully ✅", - description=success_message, - repo_type="space", - ) - if "ENDPOINT_ID" in os.environ: - # shut down the endpoint - logger.info("Pausing endpoint...") - utils.pause_endpoint(config) + pause_space(config) if __name__ == "__main__": diff --git a/src/autotrain/trainers/image_classification/__main__.py b/src/autotrain/trainers/image_classification/__main__.py index 48b66123df..302d8bc8ca 100644 --- a/src/autotrain/trainers/image_classification/__main__.py +++ b/src/autotrain/trainers/image_classification/__main__.py @@ -14,6 +14,7 @@ ) from autotrain import logger +from autotrain.trainers.common import pause_space from autotrain.trainers.image_classification import utils from autotrain.trainers.image_classification.params import ImageClassificationParams @@ -163,6 +164,9 @@ def train(config): api.create_repo(repo_id=config.repo_id, repo_type="model") api.upload_folder(folder_path=config.project_name, repo_id=config.repo_id, repo_type="model") + if PartialState().process_index == 0: + pause_space(config) + if __name__ == "__main__": args = parse_args() diff --git a/src/autotrain/trainers/seq2seq/__main__.py b/src/autotrain/trainers/seq2seq/__main__.py index af3fddc02e..bd28b8cf98 100644 --- a/src/autotrain/trainers/seq2seq/__main__.py +++ b/src/autotrain/trainers/seq2seq/__main__.py @@ -23,6 +23,7 @@ ) from autotrain import logger +from autotrain.trainers.common import pause_space from autotrain.trainers.seq2seq import utils from autotrain.trainers.seq2seq.dataset import Seq2SeqDataset from autotrain.trainers.seq2seq.params import Seq2SeqParams @@ -238,22 +239,7 @@ def train(config): ) if PartialState().process_index == 0: - if "SPACE_ID" in os.environ: - # shut down the space - logger.info("Pausing space...") - api = HfApi(token=config.token) - api.pause_space(repo_id=os.environ["SPACE_ID"]) - success_message = f"Your training run was successful! [Check out your trained model here](https://huggingface.co/{config.username}/{config.project_name})" - api.create_discussion( - repo_id=os.environ["SPACE_ID"], - title="Your training has finished successfully ✅", - description=success_message, - repo_type="space", - ) - if "ENDPOINT_ID" in os.environ: - # shut down the endpoint - logger.info("Pausing endpoint...") - utils.pause_endpoint(config) + pause_space(config) if __name__ == "__main__": diff --git a/src/autotrain/trainers/tabular/__main__.py b/src/autotrain/trainers/tabular/__main__.py index c9f849d48b..85afb051be 100644 --- a/src/autotrain/trainers/tabular/__main__.py +++ b/src/autotrain/trainers/tabular/__main__.py @@ -13,6 +13,7 @@ from sklearn.compose import ColumnTransformer from autotrain import logger +from autotrain.trainers.common import pause_space from autotrain.trainers.tabular import utils from autotrain.trainers.tabular.params import TabularParams from autotrain.utils import monitor @@ -329,22 +330,7 @@ def train(config): api.create_repo(repo_id=config.repo_id, repo_type="model", private=True) api.upload_folder(folder_path=config.project_name, repo_id=config.repo_id, repo_type="model") - if "SPACE_ID" in os.environ: - # shut down the space - logger.info("Pausing space...") - api = HfApi(token=config.token) - api.pause_space(repo_id=os.environ["SPACE_ID"]) - success_message = f"Your training run was successful! [Check out your trained model here](https://huggingface.co/{config.username}/{config.project_name})" - api.create_discussion( - repo_id=os.environ["SPACE_ID"], - title="Your training has finished successfully ✅", - description=success_message, - repo_type="space", - ) - if "ENDPOINT_ID" in os.environ: - # shut down the endpoint - logger.info("Pausing endpoint...") - utils.pause_endpoint(config) + pause_space(config) if __name__ == "__main__": diff --git a/src/autotrain/trainers/text_classification/__main__.py b/src/autotrain/trainers/text_classification/__main__.py index 932bfd5c6e..f9adf7ea1f 100644 --- a/src/autotrain/trainers/text_classification/__main__.py +++ b/src/autotrain/trainers/text_classification/__main__.py @@ -16,6 +16,7 @@ ) from autotrain import logger +from autotrain.trainers.common import pause_space from autotrain.trainers.text_classification import utils from autotrain.trainers.text_classification.dataset import TextClassificationDataset from autotrain.trainers.text_classification.params import TextClassificationParams @@ -199,22 +200,7 @@ def train(config): ) if PartialState().process_index == 0: - if "SPACE_ID" in os.environ: - # shut down the space - logger.info("Pausing space...") - api = HfApi(token=config.token) - api.pause_space(repo_id=os.environ["SPACE_ID"]) - success_message = f"Your training run was successful! [Check out your trained model here](https://huggingface.co/{config.username}/{config.project_name})" - api.create_discussion( - repo_id=os.environ["SPACE_ID"], - title="Your training has finished successfully ✅", - description=success_message, - repo_type="space", - ) - if "ENDPOINT_ID" in os.environ: - # shut down the endpoint - logger.info("Pausing endpoint...") - utils.pause_endpoint(config) + pause_space(config) if __name__ == "__main__": diff --git a/src/autotrain/utils.py b/src/autotrain/utils.py index e3187f4e6e..1cc6870d0f 100644 --- a/src/autotrain/utils.py +++ b/src/autotrain/utils.py @@ -286,12 +286,12 @@ def wrapper(*args, **kwargs): # shut down the space logger.info("Pausing space...") api = HfApi(token=os.environ["HF_TOKEN"]) - api.pause_space(repo_id=os.environ["SPACE_ID"]) api.create_discussion( repo_id=os.environ["SPACE_ID"], title="Your training has failed ❌", description=error_message, repo_type="space", ) + api.pause_space(repo_id=os.environ["SPACE_ID"]) return wrapper