Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Notify users after training (either success or failure) #379

Merged
merged 5 commits into from
Dec 8, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 2 additions & 10 deletions src/autotrain/trainers/clm/__main__.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@
from autotrain.trainers.clm import utils
from autotrain.trainers.clm.callbacks import LoadBestPeftModelCallback, SavePeftModelCallback
from autotrain.trainers.clm.params import LLMTrainingParams
from autotrain.trainers.common import pause_space
from autotrain.utils import monitor


Expand Down Expand Up @@ -494,16 +495,7 @@ def train(config):
)

if PartialState().process_index == 0:
if "SPACE_ID" in os.environ:
# shut down the space
logger.info("Pausing space...")
api = HfApi(token=config.token)
api.pause_space(repo_id=os.environ["SPACE_ID"])

if "ENDPOINT_ID" in os.environ:
# shut down the endpoint
logger.info("Pausing endpoint...")
utils.pause_endpoint(config)
pause_space(config)


if __name__ == "__main__":
Expand Down
31 changes: 31 additions & 0 deletions src/autotrain/trainers/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,11 +3,42 @@
"""
import os

import requests
from huggingface_hub import HfApi
from pydantic import BaseModel

from autotrain import logger


def pause_endpoint(params):
    """
    Send a pause request for the endpoint named by the ``ENDPOINT_ID`` env var.

    ``ENDPOINT_ID`` is read as ``<username>/<project_name>``; the two parts are
    substituted into the ``api.endpoints.huggingface.cloud`` pause URL.

    Args:
        params: object exposing a ``token`` attribute, sent as the bearer token.

    Returns:
        The JSON-decoded body of the HTTP response.

    Raises:
        KeyError: if ``ENDPOINT_ID`` is not set in the environment.
        IndexError: if ``ENDPOINT_ID`` does not contain a ``/``.
    """
    id_parts = os.environ["ENDPOINT_ID"].split("/")
    username, project_name = id_parts[0], id_parts[1]
    response = requests.post(
        f"https://api.endpoints.huggingface.cloud/v2/endpoint/{username}/{project_name}/pause",
        headers={"Authorization": f"Bearer {params.token}"},
        timeout=120,
    )
    return response.json()


def pause_space(params):
    """
    Announce a successful training run and pause the compute that hosted it.

    If ``SPACE_ID`` is set, a discussion announcing success is opened on that
    Space and the Space is then paused. If ``ENDPOINT_ID`` is set, the endpoint
    is paused via ``pause_endpoint``. With neither variable set this is a no-op.

    The discussion is best-effort: any error while building or posting it is
    logged and swallowed so that the Space is still paused afterwards.

    Args:
        params: training parameters; ``token`` is used for authentication and
            ``repo_id`` (when present) is used to link to the trained model.

    Returns:
        None
    """
    if "SPACE_ID" in os.environ:
        space_id = os.environ["SPACE_ID"]
        # shut down the space
        logger.info("Pausing space...")
        api = HfApi(token=params.token)
        # A failed notification (hub error, or params without ``repo_id``)
        # must not prevent the pause below, otherwise the Space keeps running.
        try:
            success_message = f"Your training run was successful! [Check out your trained model here](https://huggingface.co/{params.repo_id})"
            api.create_discussion(
                repo_id=space_id,
                title="Your training has finished successfully ✅",
                description=success_message,
                repo_type="space",
            )
        except Exception as exc:
            logger.warning(f"Failed to create discussion: {exc}")
        api.pause_space(repo_id=space_id)
    if "ENDPOINT_ID" in os.environ:
        # shut down the endpoint
        logger.info("Pausing endpoint...")
        pause_endpoint(params)


class AutoTrainParams(BaseModel):
"""
Base class for all AutoTrain parameters.
Expand Down
13 changes: 3 additions & 10 deletions src/autotrain/trainers/dreambooth/__main__.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,9 +15,9 @@
SlicedAttnAddedKVProcessor,
)
from diffusers.models.lora import LoRALinearLayer
from huggingface_hub import HfApi, snapshot_download
from huggingface_hub import snapshot_download

from autotrain import logger
from autotrain.trainers.common import pause_space
from autotrain.trainers.dreambooth import utils
from autotrain.trainers.dreambooth.datasets import DreamBoothDataset, collate_fn
from autotrain.trainers.dreambooth.params import DreamBoothTrainingParams
Expand Down Expand Up @@ -328,14 +328,7 @@ def load_model_hook(models, input_dir):
if config.push_to_hub:
trainer.push_to_hub()

if "SPACE_ID" in os.environ:
# remove config.image_path directory if it exists
if os.path.exists(config.image_path):
os.system(f"rm -rf {config.image_path}")
# shut down the space
logger.info("Pausing space...")
api = HfApi(token=config.token)
api.pause_space(repo_id=os.environ["SPACE_ID"])
pause_space(config)


if __name__ == "__main__":
Expand Down
15 changes: 2 additions & 13 deletions src/autotrain/trainers/generic/__main__.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,8 @@
import argparse
import json
import os

from huggingface_hub import HfApi

from autotrain import logger
from autotrain.trainers.common import pause_space
from autotrain.trainers.generic import utils
from autotrain.trainers.generic.params import GenericParams
from autotrain.utils import monitor
Expand Down Expand Up @@ -37,16 +35,7 @@ def run(config):
logger.info("Running command...")
utils.run_command(config)

if "SPACE_ID" in os.environ:
# shut down the space
logger.info("Pausing space...")
api = HfApi(token=config.token)
api.pause_space(repo_id=os.environ["SPACE_ID"])

if "ENDPOINT_ID" in os.environ:
# shut down the endpoint
logger.info("Pausing endpoint...")
utils.pause_endpoint(config)
pause_space(config)


if __name__ == "__main__":
Expand Down
4 changes: 4 additions & 0 deletions src/autotrain/trainers/image_classification/__main__.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
)

from autotrain import logger
from autotrain.trainers.common import pause_space
from autotrain.trainers.image_classification import utils
from autotrain.trainers.image_classification.params import ImageClassificationParams

Expand Down Expand Up @@ -163,6 +164,9 @@ def train(config):
api.create_repo(repo_id=config.repo_id, repo_type="model")
api.upload_folder(folder_path=config.project_name, repo_id=config.repo_id, repo_type="model")

if PartialState().process_index == 0:
pause_space(config)


if __name__ == "__main__":
args = parse_args()
Expand Down
12 changes: 2 additions & 10 deletions src/autotrain/trainers/seq2seq/__main__.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
)

from autotrain import logger
from autotrain.trainers.common import pause_space
from autotrain.trainers.seq2seq import utils
from autotrain.trainers.seq2seq.dataset import Seq2SeqDataset
from autotrain.trainers.seq2seq.params import Seq2SeqParams
Expand Down Expand Up @@ -238,16 +239,7 @@ def train(config):
)

if PartialState().process_index == 0:
if "SPACE_ID" in os.environ:
# shut down the space
logger.info("Pausing space...")
api = HfApi(token=config.token)
api.pause_space(repo_id=os.environ["SPACE_ID"])

if "ENDPOINT_ID" in os.environ:
# shut down the endpoint
logger.info("Pausing endpoint...")
utils.pause_endpoint(config)
pause_space(config)


if __name__ == "__main__":
Expand Down
12 changes: 2 additions & 10 deletions src/autotrain/trainers/tabular/__main__.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
from sklearn.compose import ColumnTransformer

from autotrain import logger
from autotrain.trainers.common import pause_space
from autotrain.trainers.tabular import utils
from autotrain.trainers.tabular.params import TabularParams
from autotrain.utils import monitor
Expand Down Expand Up @@ -329,16 +330,7 @@ def train(config):
api.create_repo(repo_id=config.repo_id, repo_type="model", private=True)
api.upload_folder(folder_path=config.project_name, repo_id=config.repo_id, repo_type="model")

if "SPACE_ID" in os.environ:
# shut down the space
logger.info("Pausing space...")
api = HfApi(token=config.token)
api.pause_space(repo_id=os.environ["SPACE_ID"])

if "ENDPOINT_ID" in os.environ:
# shut down the endpoint
logger.info("Pausing endpoint...")
utils.pause_endpoint(config)
pause_space(config)


if __name__ == "__main__":
Expand Down
18 changes: 7 additions & 11 deletions src/autotrain/trainers/text_classification/__main__.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
)

from autotrain import logger
from autotrain.trainers.common import pause_space
from autotrain.trainers.text_classification import utils
from autotrain.trainers.text_classification.dataset import TextClassificationDataset
from autotrain.trainers.text_classification.params import TextClassificationParams
Expand Down Expand Up @@ -192,19 +193,14 @@ def train(config):
logger.info("Pushing model to hub...")
api = HfApi(token=config.token)
api.create_repo(repo_id=config.repo_id, repo_type="model", private=True)
api.upload_folder(folder_path=config.project_name, repo_id=config.repo_id, repo_type="model")
api.upload_folder(
folder_path=config.project_name,
repo_id=config.repo_id,
repo_type="model",
)

if PartialState().process_index == 0:
if "SPACE_ID" in os.environ:
# shut down the space
logger.info("Pausing space...")
api = HfApi(token=config.token)
api.pause_space(repo_id=os.environ["SPACE_ID"])

if "ENDPOINT_ID" in os.environ:
# shut down the endpoint
logger.info("Pausing endpoint...")
utils.pause_endpoint(config)
pause_space(config)


if __name__ == "__main__":
Expand Down
10 changes: 8 additions & 2 deletions src/autotrain/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -280,12 +280,18 @@ def wrapper(*args, **kwargs):
return func(*args, **kwargs)
except Exception:
if PartialState().process_index == 0:
logger.error(f"{func.__name__} has failed due to an exception:")
logger.error(traceback.format_exc())
error_message = f"""{func.__name__} has failed due to an exception: {traceback.format_exc()}"""
logger.error(error_message)
if "SPACE_ID" in os.environ:
# shut down the space
logger.info("Pausing space...")
api = HfApi(token=os.environ["HF_TOKEN"])
api.create_discussion(
repo_id=os.environ["SPACE_ID"],
title="Your training has failed ❌",
description=error_message,
repo_type="space",
)
api.pause_space(repo_id=os.environ["SPACE_ID"])

return wrapper
Loading