From ad9b2b47f5442173bb6d82619fcb019142f56c17 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?apolin=C3=A1rio?= Date: Fri, 8 Dec 2023 17:24:29 +0000 Subject: [PATCH] Notify users after training (either success or failure) (#379) Co-authored-by: multimodalart --- src/autotrain/trainers/clm/__main__.py | 12 ++----- src/autotrain/trainers/common.py | 31 +++++++++++++++++++ src/autotrain/trainers/dreambooth/__main__.py | 13 ++------ src/autotrain/trainers/generic/__main__.py | 15 ++------- .../trainers/image_classification/__main__.py | 4 +++ src/autotrain/trainers/seq2seq/__main__.py | 12 ++----- src/autotrain/trainers/tabular/__main__.py | 12 ++----- .../trainers/text_classification/__main__.py | 18 +++++------ src/autotrain/utils.py | 10 ++++-- 9 files changed, 61 insertions(+), 66 deletions(-) diff --git a/src/autotrain/trainers/clm/__main__.py b/src/autotrain/trainers/clm/__main__.py index 850c5ce6fa..d553b28114 100644 --- a/src/autotrain/trainers/clm/__main__.py +++ b/src/autotrain/trainers/clm/__main__.py @@ -27,6 +27,7 @@ from autotrain.trainers.clm import utils from autotrain.trainers.clm.callbacks import LoadBestPeftModelCallback, SavePeftModelCallback from autotrain.trainers.clm.params import LLMTrainingParams +from autotrain.trainers.common import pause_space from autotrain.utils import monitor @@ -494,16 +495,7 @@ def train(config): ) if PartialState().process_index == 0: - if "SPACE_ID" in os.environ: - # shut down the space - logger.info("Pausing space...") - api = HfApi(token=config.token) - api.pause_space(repo_id=os.environ["SPACE_ID"]) - - if "ENDPOINT_ID" in os.environ: - # shut down the endpoint - logger.info("Pausing endpoint...") - utils.pause_endpoint(config) + pause_space(config) if __name__ == "__main__": diff --git a/src/autotrain/trainers/common.py b/src/autotrain/trainers/common.py index 251608bc3d..00ef0090c6 100644 --- a/src/autotrain/trainers/common.py +++ b/src/autotrain/trainers/common.py @@ -3,11 +3,42 @@ """ import os +import requests +from huggingface_hub import HfApi from pydantic import BaseModel from autotrain import logger +def pause_endpoint(params): + endpoint_id = os.environ["ENDPOINT_ID"] + username = endpoint_id.split("/")[0] + project_name = endpoint_id.split("/")[1] + api_url = f"https://api.endpoints.huggingface.cloud/v2/endpoint/{username}/{project_name}/pause" + headers = {"Authorization": f"Bearer {params.token}"} + r = requests.post(api_url, headers=headers, timeout=120) + return r.json() + + +def pause_space(params): + if "SPACE_ID" in os.environ: + # shut down the space + logger.info("Pausing space...") + api = HfApi(token=params.token) + success_message = f"Your training run was successful! [Check out your trained model here](https://huggingface.co/{params.repo_id})" + api.create_discussion( + repo_id=os.environ["SPACE_ID"], + title="Your training has finished successfully ✅", + description=success_message, + repo_type="space", + ) + api.pause_space(repo_id=os.environ["SPACE_ID"]) + if "ENDPOINT_ID" in os.environ: + # shut down the endpoint + logger.info("Pausing endpoint...") + pause_endpoint(params) + + class AutoTrainParams(BaseModel): """ Base class for all AutoTrain parameters. diff --git a/src/autotrain/trainers/dreambooth/__main__.py b/src/autotrain/trainers/dreambooth/__main__.py index e81baa91df..6f700c619e 100644 --- a/src/autotrain/trainers/dreambooth/__main__.py +++ b/src/autotrain/trainers/dreambooth/__main__.py @@ -15,9 +15,9 @@ SlicedAttnAddedKVProcessor, ) from diffusers.models.lora import LoRALinearLayer -from huggingface_hub import HfApi, snapshot_download +from huggingface_hub import snapshot_download -from autotrain import logger +from autotrain.trainers.common import pause_space from autotrain.trainers.dreambooth import utils from autotrain.trainers.dreambooth.datasets import DreamBoothDataset, collate_fn from autotrain.trainers.dreambooth.params import DreamBoothTrainingParams @@ -328,14 +328,7 @@ def load_model_hook(models, input_dir): if config.push_to_hub: trainer.push_to_hub() - if "SPACE_ID" in os.environ: - # remove config.image_path directory if it exists - if os.path.exists(config.image_path): - os.system(f"rm -rf {config.image_path}") - # shut down the space - logger.info("Pausing space...") - api = HfApi(token=config.token) - api.pause_space(repo_id=os.environ["SPACE_ID"]) + pause_space(config) if __name__ == "__main__": diff --git a/src/autotrain/trainers/generic/__main__.py b/src/autotrain/trainers/generic/__main__.py index 58f530195f..29cd726527 100644 --- a/src/autotrain/trainers/generic/__main__.py +++ b/src/autotrain/trainers/generic/__main__.py @@ -1,10 +1,8 @@ import argparse import json -import os - -from huggingface_hub import HfApi from autotrain import logger +from autotrain.trainers.common import pause_space from autotrain.trainers.generic import utils from autotrain.trainers.generic.params import GenericParams from autotrain.utils import monitor @@ -37,16 +35,7 @@ def run(config): logger.info("Running command...") utils.run_command(config) - if "SPACE_ID" in os.environ: - # shut down the space - logger.info("Pausing space...") - api = HfApi(token=config.token) - api.pause_space(repo_id=os.environ["SPACE_ID"]) - - if "ENDPOINT_ID" in os.environ: - # shut down the endpoint - logger.info("Pausing endpoint...") - utils.pause_endpoint(config) + pause_space(config) if __name__ == "__main__": diff --git a/src/autotrain/trainers/image_classification/__main__.py b/src/autotrain/trainers/image_classification/__main__.py index 48b66123df..302d8bc8ca 100644 --- a/src/autotrain/trainers/image_classification/__main__.py +++ b/src/autotrain/trainers/image_classification/__main__.py @@ -14,6 +14,7 @@ ) from autotrain import logger +from autotrain.trainers.common import pause_space from autotrain.trainers.image_classification import utils from autotrain.trainers.image_classification.params import ImageClassificationParams @@ -163,6 +164,9 @@ def train(config): api.create_repo(repo_id=config.repo_id, repo_type="model") api.upload_folder(folder_path=config.project_name, repo_id=config.repo_id, repo_type="model") + if PartialState().process_index == 0: + pause_space(config) + if __name__ == "__main__": args = parse_args() diff --git a/src/autotrain/trainers/seq2seq/__main__.py b/src/autotrain/trainers/seq2seq/__main__.py index 4cc02564af..bd28b8cf98 100644 --- a/src/autotrain/trainers/seq2seq/__main__.py +++ b/src/autotrain/trainers/seq2seq/__main__.py @@ -23,6 +23,7 @@ ) from autotrain import logger +from autotrain.trainers.common import pause_space from autotrain.trainers.seq2seq import utils from autotrain.trainers.seq2seq.dataset import Seq2SeqDataset from autotrain.trainers.seq2seq.params import Seq2SeqParams @@ -238,16 +239,7 @@ def train(config): ) if PartialState().process_index == 0: - if "SPACE_ID" in os.environ: - # shut down the space - logger.info("Pausing space...") - api = HfApi(token=config.token) - api.pause_space(repo_id=os.environ["SPACE_ID"]) - - if "ENDPOINT_ID" in os.environ: - # shut down the endpoint - logger.info("Pausing endpoint...") - utils.pause_endpoint(config) + pause_space(config) if __name__ == "__main__": diff --git a/src/autotrain/trainers/tabular/__main__.py b/src/autotrain/trainers/tabular/__main__.py index 08b1c7bcf3..85afb051be 100644 --- a/src/autotrain/trainers/tabular/__main__.py +++ b/src/autotrain/trainers/tabular/__main__.py @@ -13,6 +13,7 @@ from sklearn.compose import ColumnTransformer from autotrain import logger +from autotrain.trainers.common import pause_space from autotrain.trainers.tabular import utils from autotrain.trainers.tabular.params import TabularParams from autotrain.utils import monitor @@ -329,16 +330,7 @@ def train(config): api.create_repo(repo_id=config.repo_id, repo_type="model", private=True) api.upload_folder(folder_path=config.project_name, repo_id=config.repo_id, repo_type="model") - if "SPACE_ID" in os.environ: - # shut down the space - logger.info("Pausing space...") - api = HfApi(token=config.token) - api.pause_space(repo_id=os.environ["SPACE_ID"]) - - if "ENDPOINT_ID" in os.environ: - # shut down the endpoint - logger.info("Pausing endpoint...") - utils.pause_endpoint(config) + pause_space(config) if __name__ == "__main__": diff --git a/src/autotrain/trainers/text_classification/__main__.py b/src/autotrain/trainers/text_classification/__main__.py index 886ac57046..f9adf7ea1f 100644 --- a/src/autotrain/trainers/text_classification/__main__.py +++ b/src/autotrain/trainers/text_classification/__main__.py @@ -16,6 +16,7 @@ ) from autotrain import logger +from autotrain.trainers.common import pause_space from autotrain.trainers.text_classification import utils from autotrain.trainers.text_classification.dataset import TextClassificationDataset from autotrain.trainers.text_classification.params import TextClassificationParams @@ -192,19 +193,14 @@ def train(config): logger.info("Pushing model to hub...") api = HfApi(token=config.token) api.create_repo(repo_id=config.repo_id, repo_type="model", private=True) - api.upload_folder(folder_path=config.project_name, repo_id=config.repo_id, repo_type="model") + api.upload_folder( + folder_path=config.project_name, + repo_id=config.repo_id, + repo_type="model", + ) if PartialState().process_index == 0: - if "SPACE_ID" in os.environ: - # shut down the space - logger.info("Pausing space...") - api = HfApi(token=config.token) - api.pause_space(repo_id=os.environ["SPACE_ID"]) - - if "ENDPOINT_ID" in os.environ: - # shut down the endpoint - logger.info("Pausing endpoint...") - utils.pause_endpoint(config) + pause_space(config) if __name__ == "__main__": diff --git a/src/autotrain/utils.py b/src/autotrain/utils.py index a3c267acff..1cc6870d0f 100644 --- a/src/autotrain/utils.py +++ b/src/autotrain/utils.py @@ -280,12 +280,18 @@ def wrapper(*args, **kwargs): return func(*args, **kwargs) except Exception: if PartialState().process_index == 0: - logger.error(f"{func.__name__} has failed due to an exception:") - logger.error(traceback.format_exc()) + error_message = f"""{func.__name__} has failed due to an exception: {traceback.format_exc()}""" + logger.error(error_message) if "SPACE_ID" in os.environ: # shut down the space logger.info("Pausing space...") api = HfApi(token=os.environ["HF_TOKEN"]) + api.create_discussion( + repo_id=os.environ["SPACE_ID"], + title="Your training has failed ❌", + description=error_message, + repo_type="space", + ) api.pause_space(repo_id=os.environ["SPACE_ID"]) return wrapper