Skip to content

Commit

Permalink
lazy loading, fixes, new release
Browse files Browse the repository at this point in the history
  • Loading branch information
abhishekkrthakur committed Sep 14, 2023
1 parent e2b0cf1 commit 7c26940
Show file tree
Hide file tree
Showing 5 changed files with 16 additions and 11 deletions.
2 changes: 1 addition & 1 deletion src/autotrain/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,4 +30,4 @@
warnings.filterwarnings("ignore", category=UserWarning, module="tensorflow")


__version__ = "0.6.32.dev0"
__version__ = "0.6.33.dev0"
5 changes: 3 additions & 2 deletions src/autotrain/cli/run_image_classification.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,6 @@
import torch

from autotrain import logger
from autotrain.trainers.image_classification.__main__ import train as train_image_classification
from autotrain.trainers.image_classification.params import ImageClassificationParams

from . import BaseAutoTrainCommand

Expand Down Expand Up @@ -262,6 +260,9 @@ def __init__(self, args):
self.num_gpus = torch.cuda.device_count()

def run(self):
from autotrain.trainers.image_classification.__main__ import train as train_image_classification
from autotrain.trainers.image_classification.params import ImageClassificationParams

logger.info("Running Text Classification")
if self.args.train:
params = ImageClassificationParams(
Expand Down
10 changes: 6 additions & 4 deletions src/autotrain/cli/run_llm.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,10 +6,6 @@
import torch

from autotrain import logger
from autotrain.backend import EndpointsRunner, SpaceRunner
from autotrain.infer.text_generation import TextGenerationInference
from autotrain.trainers.clm.__main__ import train as train_llm
from autotrain.trainers.clm.params import LLMTrainingParams

from . import BaseAutoTrainCommand

Expand Down Expand Up @@ -392,6 +388,8 @@ def __init__(self, args):
raise ValueError("Token must be specified for spaces backend")

if self.args.inference:
from autotrain.infer.text_generation import TextGenerationInference

tgi = TextGenerationInference(
self.args.project_name, use_int4=self.args.use_int4, use_int8=self.args.use_int8
)
Expand All @@ -407,6 +405,10 @@ def __init__(self, args):
self.num_gpus = torch.cuda.device_count()

def run(self):
from autotrain.backend import EndpointsRunner, SpaceRunner
from autotrain.trainers.clm.__main__ import train as train_llm
from autotrain.trainers.clm.params import LLMTrainingParams

logger.info("Running LLM")
logger.info(f"Params: {self.args}")
if self.args.train:
Expand Down
5 changes: 3 additions & 2 deletions src/autotrain/cli/run_tabular.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,6 @@

from autotrain import logger
from autotrain.backend import EndpointsRunner, SpaceRunner
from autotrain.trainers.tabular.__main__ import train as train_tabular
from autotrain.trainers.tabular.params import TabularParams

from . import BaseAutoTrainCommand

Expand Down Expand Up @@ -229,6 +227,9 @@ def __init__(self, args):
self.args.target_columns = [k.strip() for k in self.args.target_columns.split(",")]

def run(self):
from autotrain.trainers.tabular.__main__ import train as train_tabular
from autotrain.trainers.tabular.params import TabularParams

logger.info("Running Tabular Training...")
if self.args.train:
params = TabularParams(
Expand Down
5 changes: 3 additions & 2 deletions src/autotrain/cli/run_text_classification.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,8 +7,6 @@

from autotrain import logger
from autotrain.backend import EndpointsRunner, SpaceRunner
from autotrain.trainers.text_classification.__main__ import train as train_text_classification
from autotrain.trainers.text_classification.params import TextClassificationParams

from . import BaseAutoTrainCommand

Expand Down Expand Up @@ -294,6 +292,9 @@ def __init__(self, args):
self.args.token = os.environ.get("HF_TOKEN", None)

def run(self):
from autotrain.trainers.text_classification.__main__ import train as train_text_classification
from autotrain.trainers.text_classification.params import TextClassificationParams

logger.info("Running Text Classification")
if self.args.train:
params = TextClassificationParams(
Expand Down

0 comments on commit 7c26940

Please sign in to comment.