Skip to content

Commit

Permalink
fix
Browse files Browse the repository at this point in the history
  • Loading branch information
abhishekkrthakur committed Dec 8, 2023
1 parent 95dcb3d commit 0cfefc4
Show file tree
Hide file tree
Showing 9 changed files with 49 additions and 101 deletions.
18 changes: 2 additions & 16 deletions src/autotrain/trainers/clm/__main__.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@
from autotrain.trainers.clm import utils
from autotrain.trainers.clm.callbacks import LoadBestPeftModelCallback, SavePeftModelCallback
from autotrain.trainers.clm.params import LLMTrainingParams
from autotrain.trainers.common import pause_space
from autotrain.utils import monitor


Expand Down Expand Up @@ -494,22 +495,7 @@ def train(config):
)

if PartialState().process_index == 0:
if "SPACE_ID" in os.environ:
# shut down the space
logger.info("Pausing space...")
api = HfApi(token=config.token)
api.pause_space(repo_id=os.environ["SPACE_ID"])
success_message = f"Your training run was successful! [Check out your trained model here](https://huggingface.co/{config.username}/{config.project_name})"
api.create_discussion(
repo_id=os.environ["SPACE_ID"],
title="Your training has finished successfully ✅",
description=success_message,
repo_type="space",
)
if "ENDPOINT_ID" in os.environ:
# shut down the endpoint
logger.info("Pausing endpoint...")
utils.pause_endpoint(config)
pause_space(config)


if __name__ == "__main__":
Expand Down
31 changes: 31 additions & 0 deletions src/autotrain/trainers/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,11 +3,42 @@
"""
import os

import requests
from huggingface_hub import HfApi
from pydantic import BaseModel

from autotrain import logger


def pause_endpoint(params):
endpoint_id = os.environ["ENDPOINT_ID"]
username = endpoint_id.split("/")[0]
project_name = endpoint_id.split("/")[1]
api_url = f"https://api.endpoints.huggingface.cloud/v2/endpoint/{username}/{project_name}/pause"
headers = {"Authorization": f"Bearer {params.token}"}
r = requests.post(api_url, headers=headers, timeout=120)
return r.json()


def pause_space(params):
if "SPACE_ID" in os.environ:
# shut down the space
logger.info("Pausing space...")
api = HfApi(token=params.token)
success_message = f"Your training run was successful! [Check out your trained model here](https://huggingface.co/{params.repo_id})"
api.create_discussion(
repo_id=os.environ["SPACE_ID"],
title="Your training has finished successfully ✅",
description=success_message,
repo_type="space",
)
api.pause_space(repo_id=os.environ["SPACE_ID"])
if "ENDPOINT_ID" in os.environ:
# shut down the endpoint
logger.info("Pausing endpoint...")
pause_endpoint(params)


class AutoTrainParams(BaseModel):
"""
Base class for all AutoTrain parameters.
Expand Down
20 changes: 3 additions & 17 deletions src/autotrain/trainers/dreambooth/__main__.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,9 +15,9 @@
SlicedAttnAddedKVProcessor,
)
from diffusers.models.lora import LoRALinearLayer
from huggingface_hub import HfApi, snapshot_download
from huggingface_hub import snapshot_download

from autotrain import logger
from autotrain.trainers.common import pause_space
from autotrain.trainers.dreambooth import utils
from autotrain.trainers.dreambooth.datasets import DreamBoothDataset, collate_fn
from autotrain.trainers.dreambooth.params import DreamBoothTrainingParams
Expand Down Expand Up @@ -328,21 +328,7 @@ def load_model_hook(models, input_dir):
if config.push_to_hub:
trainer.push_to_hub()

if "SPACE_ID" in os.environ:
# remove config.image_path directory if it exists
if os.path.exists(config.image_path):
os.system(f"rm -rf {config.image_path}")
# shut down the space
logger.info("Pausing space...")
api = HfApi(token=config.token)
api.pause_space(repo_id=os.environ["SPACE_ID"])
success_message = f"Your training run was successful! [Check out your trained model here](https://huggingface.co/{config.username}/{config.project_name})"
api.create_discussion(
repo_id=os.environ["SPACE_ID"],
title="Your training has finished successfully ✅",
description=success_message,
repo_type="space",
)
pause_space(config)


if __name__ == "__main__":
Expand Down
21 changes: 2 additions & 19 deletions src/autotrain/trainers/generic/__main__.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,8 @@
import argparse
import json
import os

from huggingface_hub import HfApi

from autotrain import logger
from autotrain.trainers.common import pause_space
from autotrain.trainers.generic import utils
from autotrain.trainers.generic.params import GenericParams
from autotrain.utils import monitor
Expand Down Expand Up @@ -37,22 +35,7 @@ def run(config):
logger.info("Running command...")
utils.run_command(config)

if "SPACE_ID" in os.environ:
# shut down the space
logger.info("Pausing space...")
api = HfApi(token=config.token)
api.pause_space(repo_id=os.environ["SPACE_ID"])
success_message = f"Your training run was successful! [Check out your trained model here](https://huggingface.co/{config.username}/{config.project_name})"
api.create_discussion(
repo_id=os.environ["SPACE_ID"],
title="Your training has finished successfully ✅",
description=success_message,
repo_type="space",
)
if "ENDPOINT_ID" in os.environ:
# shut down the endpoint
logger.info("Pausing endpoint...")
utils.pause_endpoint(config)
pause_space(config)


if __name__ == "__main__":
Expand Down
4 changes: 4 additions & 0 deletions src/autotrain/trainers/image_classification/__main__.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
)

from autotrain import logger
from autotrain.trainers.common import pause_space
from autotrain.trainers.image_classification import utils
from autotrain.trainers.image_classification.params import ImageClassificationParams

Expand Down Expand Up @@ -163,6 +164,9 @@ def train(config):
api.create_repo(repo_id=config.repo_id, repo_type="model")
api.upload_folder(folder_path=config.project_name, repo_id=config.repo_id, repo_type="model")

if PartialState().process_index == 0:
pause_space(config)


if __name__ == "__main__":
args = parse_args()
Expand Down
18 changes: 2 additions & 16 deletions src/autotrain/trainers/seq2seq/__main__.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
)

from autotrain import logger
from autotrain.trainers.common import pause_space
from autotrain.trainers.seq2seq import utils
from autotrain.trainers.seq2seq.dataset import Seq2SeqDataset
from autotrain.trainers.seq2seq.params import Seq2SeqParams
Expand Down Expand Up @@ -238,22 +239,7 @@ def train(config):
)

if PartialState().process_index == 0:
if "SPACE_ID" in os.environ:
# shut down the space
logger.info("Pausing space...")
api = HfApi(token=config.token)
api.pause_space(repo_id=os.environ["SPACE_ID"])
success_message = f"Your training run was successful! [Check out your trained model here](https://huggingface.co/{config.username}/{config.project_name})"
api.create_discussion(
repo_id=os.environ["SPACE_ID"],
title="Your training has finished successfully ✅",
description=success_message,
repo_type="space",
)
if "ENDPOINT_ID" in os.environ:
# shut down the endpoint
logger.info("Pausing endpoint...")
utils.pause_endpoint(config)
pause_space(config)


if __name__ == "__main__":
Expand Down
18 changes: 2 additions & 16 deletions src/autotrain/trainers/tabular/__main__.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
from sklearn.compose import ColumnTransformer

from autotrain import logger
from autotrain.trainers.common import pause_space
from autotrain.trainers.tabular import utils
from autotrain.trainers.tabular.params import TabularParams
from autotrain.utils import monitor
Expand Down Expand Up @@ -329,22 +330,7 @@ def train(config):
api.create_repo(repo_id=config.repo_id, repo_type="model", private=True)
api.upload_folder(folder_path=config.project_name, repo_id=config.repo_id, repo_type="model")

if "SPACE_ID" in os.environ:
# shut down the space
logger.info("Pausing space...")
api = HfApi(token=config.token)
api.pause_space(repo_id=os.environ["SPACE_ID"])
success_message = f"Your training run was successful! [Check out your trained model here](https://huggingface.co/{config.username}/{config.project_name})"
api.create_discussion(
repo_id=os.environ["SPACE_ID"],
title="Your training has finished successfully ✅",
description=success_message,
repo_type="space",
)
if "ENDPOINT_ID" in os.environ:
# shut down the endpoint
logger.info("Pausing endpoint...")
utils.pause_endpoint(config)
pause_space(config)


if __name__ == "__main__":
Expand Down
18 changes: 2 additions & 16 deletions src/autotrain/trainers/text_classification/__main__.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
)

from autotrain import logger
from autotrain.trainers.common import pause_space
from autotrain.trainers.text_classification import utils
from autotrain.trainers.text_classification.dataset import TextClassificationDataset
from autotrain.trainers.text_classification.params import TextClassificationParams
Expand Down Expand Up @@ -199,22 +200,7 @@ def train(config):
)

if PartialState().process_index == 0:
if "SPACE_ID" in os.environ:
# shut down the space
logger.info("Pausing space...")
api = HfApi(token=config.token)
api.pause_space(repo_id=os.environ["SPACE_ID"])
success_message = f"Your training run was successful! [Check out your trained model here](https://huggingface.co/{config.username}/{config.project_name})"
api.create_discussion(
repo_id=os.environ["SPACE_ID"],
title="Your training has finished successfully ✅",
description=success_message,
repo_type="space",
)
if "ENDPOINT_ID" in os.environ:
# shut down the endpoint
logger.info("Pausing endpoint...")
utils.pause_endpoint(config)
pause_space(config)


if __name__ == "__main__":
Expand Down
2 changes: 1 addition & 1 deletion src/autotrain/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -286,12 +286,12 @@ def wrapper(*args, **kwargs):
# shut down the space
logger.info("Pausing space...")
api = HfApi(token=os.environ["HF_TOKEN"])
api.pause_space(repo_id=os.environ["SPACE_ID"])
api.create_discussion(
repo_id=os.environ["SPACE_ID"],
title="Your training has failed ❌",
description=error_message,
repo_type="space",
)
api.pause_space(repo_id=os.environ["SPACE_ID"])

return wrapper

0 comments on commit 0cfefc4

Please sign in to comment.