From 7c269406ff3960729030dd88a51479f9f225b12f Mon Sep 17 00:00:00 2001 From: Abhishek Thakur Date: Thu, 14 Sep 2023 17:11:13 +0200 Subject: [PATCH] lazy loading, fixes, new release --- src/autotrain/__init__.py | 2 +- src/autotrain/cli/run_image_classification.py | 5 +++-- src/autotrain/cli/run_llm.py | 10 ++++++---- src/autotrain/cli/run_tabular.py | 5 +++-- src/autotrain/cli/run_text_classification.py | 5 +++-- 5 files changed, 16 insertions(+), 11 deletions(-) diff --git a/src/autotrain/__init__.py b/src/autotrain/__init__.py index a5735166e6..dcb85745b8 100644 --- a/src/autotrain/__init__.py +++ b/src/autotrain/__init__.py @@ -30,4 +30,4 @@ warnings.filterwarnings("ignore", category=UserWarning, module="tensorflow") -__version__ = "0.6.32.dev0" +__version__ = "0.6.33.dev0" diff --git a/src/autotrain/cli/run_image_classification.py b/src/autotrain/cli/run_image_classification.py index d6ade0b223..199bd4423f 100644 --- a/src/autotrain/cli/run_image_classification.py +++ b/src/autotrain/cli/run_image_classification.py @@ -5,8 +5,6 @@ import torch from autotrain import logger -from autotrain.trainers.image_classification.__main__ import train as train_image_classification -from autotrain.trainers.image_classification.params import ImageClassificationParams from . import BaseAutoTrainCommand @@ -262,6 +260,9 @@ def __init__(self, args): self.num_gpus = torch.cuda.device_count() def run(self): + from autotrain.trainers.image_classification.__main__ import train as train_image_classification + from autotrain.trainers.image_classification.params import ImageClassificationParams + logger.info("Running Text Classification") if self.args.train: params = ImageClassificationParams( diff --git a/src/autotrain/cli/run_llm.py b/src/autotrain/cli/run_llm.py index 94ae8639bd..71cc89c796 100644 --- a/src/autotrain/cli/run_llm.py +++ b/src/autotrain/cli/run_llm.py @@ -6,10 +6,6 @@ import torch from autotrain import logger -from autotrain.backend import EndpointsRunner, SpaceRunner -from autotrain.infer.text_generation import TextGenerationInference -from autotrain.trainers.clm.__main__ import train as train_llm -from autotrain.trainers.clm.params import LLMTrainingParams from . import BaseAutoTrainCommand @@ -392,6 +388,8 @@ def __init__(self, args): raise ValueError("Token must be specified for spaces backend") if self.args.inference: + from autotrain.infer.text_generation import TextGenerationInference + tgi = TextGenerationInference( self.args.project_name, use_int4=self.args.use_int4, use_int8=self.args.use_int8 ) @@ -407,6 +405,10 @@ def __init__(self, args): self.num_gpus = torch.cuda.device_count() def run(self): + from autotrain.backend import EndpointsRunner, SpaceRunner + from autotrain.trainers.clm.__main__ import train as train_llm + from autotrain.trainers.clm.params import LLMTrainingParams + logger.info("Running LLM") logger.info(f"Params: {self.args}") if self.args.train: diff --git a/src/autotrain/cli/run_tabular.py b/src/autotrain/cli/run_tabular.py index 11cfb89b18..e724d29495 100644 --- a/src/autotrain/cli/run_tabular.py +++ b/src/autotrain/cli/run_tabular.py @@ -6,8 +6,6 @@ from autotrain import logger from autotrain.backend import EndpointsRunner, SpaceRunner -from autotrain.trainers.tabular.__main__ import train as train_tabular -from autotrain.trainers.tabular.params import TabularParams from . import BaseAutoTrainCommand @@ -229,6 +227,9 @@ def __init__(self, args): self.args.target_columns = [k.strip() for k in self.args.target_columns.split(",")] def run(self): + from autotrain.trainers.tabular.__main__ import train as train_tabular + from autotrain.trainers.tabular.params import TabularParams + logger.info("Running Tabular Training...") if self.args.train: params = TabularParams( diff --git a/src/autotrain/cli/run_text_classification.py b/src/autotrain/cli/run_text_classification.py index c437d5f90e..cdb92d6dbb 100644 --- a/src/autotrain/cli/run_text_classification.py +++ b/src/autotrain/cli/run_text_classification.py @@ -7,8 +7,6 @@ from autotrain import logger from autotrain.backend import EndpointsRunner, SpaceRunner -from autotrain.trainers.text_classification.__main__ import train as train_text_classification -from autotrain.trainers.text_classification.params import TextClassificationParams from . import BaseAutoTrainCommand @@ -294,6 +292,9 @@ def __init__(self, args): self.args.token = os.environ.get("HF_TOKEN", None) def run(self): + from autotrain.trainers.text_classification.__main__ import train as train_text_classification + from autotrain.trainers.text_classification.params import TextClassificationParams + logger.info("Running Text Classification") if self.args.train: params = TextClassificationParams(