Skip to content

Commit

Permalink
make style
Browse files Browse the repository at this point in the history
  • Loading branch information
multimodalart committed Dec 8, 2023
1 parent f06236b commit 43d3565
Show file tree
Hide file tree
Showing 7 changed files with 55 additions and 17 deletions.
8 changes: 7 additions & 1 deletion src/autotrain/trainers/clm/__main__.py
Original file line number Diff line number Diff line change
Expand Up @@ -499,7 +499,13 @@ def train(config):
logger.info("Pausing space...")
api = HfApi(token=config.token)
api.pause_space(repo_id=os.environ["SPACE_ID"])

success_message = f"Your training run was successful! [Check out your trained model here](https://huggingface.co/{config.username}/{config.project-name})"
api.create_discussion(
repo_id=os.environ["SPACE_ID"],
title="Your training has finished successfully ✅",
description=success_message,
repo_type="space",
)
if "ENDPOINT_ID" in os.environ:
# shut down the endpoint
logger.info("Pausing endpoint...")
Expand Down
7 changes: 7 additions & 0 deletions src/autotrain/trainers/dreambooth/__main__.py
Original file line number Diff line number Diff line change
Expand Up @@ -336,6 +336,13 @@ def load_model_hook(models, input_dir):
logger.info("Pausing space...")
api = HfApi(token=config.token)
api.pause_space(repo_id=os.environ["SPACE_ID"])
success_message = f"Your training run was successful! [Check out your trained model here](https://huggingface.co/{config.username}/{config.project-name})"
api.create_discussion(
repo_id=os.environ["SPACE_ID"],
title="Your training has finished successfully ✅",
description=success_message,
repo_type="space",
)


if __name__ == "__main__":
Expand Down
13 changes: 7 additions & 6 deletions src/autotrain/trainers/generic/__main__.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,12 +42,13 @@ def run(config):
logger.info("Pausing space...")
api = HfApi(token=config.token)
api.pause_space(repo_id=os.environ["SPACE_ID"])
success_message = f"Your training run was successfull! [Check out your trained model here](https://huggingface.co/{config.username}/{config.project-name})"
api.create_discussion(repo_id=os.environ['SPACE_ID'],
title="Your training has finished successfully ✅",
description=success_message,
repo_type="space"
)
success_message = f"Your training run was succesful! [Check out your trained model here](https://huggingface.co/{config.username}/{config.project-name})"
api.create_discussion(
repo_id=os.environ["SPACE_ID"],
title="Your training has finished successfully ✅",
description=success_message,
repo_type="space",
)
if "ENDPOINT_ID" in os.environ:
# shut down the endpoint
logger.info("Pausing endpoint...")
Expand Down
8 changes: 7 additions & 1 deletion src/autotrain/trainers/seq2seq/__main__.py
Original file line number Diff line number Diff line change
Expand Up @@ -243,7 +243,13 @@ def train(config):
logger.info("Pausing space...")
api = HfApi(token=config.token)
api.pause_space(repo_id=os.environ["SPACE_ID"])

success_message = f"Your training run was successful! [Check out your trained model here](https://huggingface.co/{config.username}/{config.project-name})"
api.create_discussion(
repo_id=os.environ["SPACE_ID"],
title="Your training has finished successfully ✅",
description=success_message,
repo_type="space",
)
if "ENDPOINT_ID" in os.environ:
# shut down the endpoint
logger.info("Pausing endpoint...")
Expand Down
8 changes: 7 additions & 1 deletion src/autotrain/trainers/tabular/__main__.py
Original file line number Diff line number Diff line change
Expand Up @@ -334,7 +334,13 @@ def train(config):
logger.info("Pausing space...")
api = HfApi(token=config.token)
api.pause_space(repo_id=os.environ["SPACE_ID"])

success_message = f"Your training run was successful! [Check out your trained model here](https://huggingface.co/{config.username}/{config.project-name})"
api.create_discussion(
repo_id=os.environ["SPACE_ID"],
title="Your training has finished successfully ✅",
description=success_message,
repo_type="space",
)
if "ENDPOINT_ID" in os.environ:
# shut down the endpoint
logger.info("Pausing endpoint...")
Expand Down
14 changes: 12 additions & 2 deletions src/autotrain/trainers/text_classification/__main__.py
Original file line number Diff line number Diff line change
Expand Up @@ -192,15 +192,25 @@ def train(config):
logger.info("Pushing model to hub...")
api = HfApi(token=config.token)
api.create_repo(repo_id=config.repo_id, repo_type="model", private=True)
api.upload_folder(folder_path=config.project_name, repo_id=config.repo_id, repo_type="model")
api.upload_folder(
folder_path=config.project_name,
repo_id=config.repo_id,
repo_type="model",
)

if PartialState().process_index == 0:
if "SPACE_ID" in os.environ:
# shut down the space
logger.info("Pausing space...")
api = HfApi(token=config.token)
api.pause_space(repo_id=os.environ["SPACE_ID"])

success_message = f"Your training run was successful! [Check out your trained model here](https://huggingface.co/{config.username}/{config.project-name})"
api.create_discussion(
repo_id=os.environ["SPACE_ID"],
title="Your training has finished successfully ✅",
description=success_message,
repo_type="space",
)
if "ENDPOINT_ID" in os.environ:
# shut down the endpoint
logger.info("Pausing endpoint...")
Expand Down
14 changes: 8 additions & 6 deletions src/autotrain/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -280,16 +280,18 @@ def wrapper(*args, **kwargs):
return func(*args, **kwargs)
except Exception:
if PartialState().process_index == 0:
error_message = f'''{func.__name__} has failed due to an exception: {traceback.format_exc()}'''
error_message = f"""{func.__name__} has failed due to an exception: {traceback.format_exc()}"""
logger.error(error_message)
if "SPACE_ID" in os.environ:
# shut down the space
logger.info("Pausing space...")
api = HfApi(token=os.environ["HF_TOKEN"])
api.pause_space(repo_id=os.environ["SPACE_ID"])
api.create_discussion(repo_id=os.environ['SPACE_ID'],
title="Your training has failed ❌",
description=error_message,
repo_type="space"
)
api.create_discussion(
repo_id=os.environ["SPACE_ID"],
title="Your training has failed ❌",
description=error_message,
repo_type="space",
)

return wrapper

0 comments on commit 43d3565

Please sign in to comment.