Skip to content

Commit

Permalink
py
Browse files Browse the repository at this point in the history
  • Loading branch information
abhishekkrthakur committed Oct 16, 2024
1 parent c806f01 commit a571185
Show file tree
Hide file tree
Showing 19 changed files with 597 additions and 479 deletions.
36 changes: 36 additions & 0 deletions notebooks/python_example.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,36 @@
import os

from autotrain.params import LLMTrainingParams
from autotrain.project import AutoTrainProject


params = LLMTrainingParams(
model="meta-llama/Llama-3.2-1B-Instruct",
data_path="HuggingFaceH4/no_robots",
chat_template="tokenizer",
text_column="messages",
train_split="train",
trainer="sft",
epochs=3,
batch_size=1,
lr=1e-5,
peft=True,
quantization="int4",
target_modules="all-linear",
padding="right",
optimizer="paged_adamw_8bit",
scheduler="cosine",
gradient_accumulation=8,
mixed_precision="bf16",
merge_adapter=True,
project_name="autotrain-llama32-1b-finetune",
log="tensorboard",
push_to_hub=False,
username=os.environ.get("HF_USERNAME"),
token=os.environ.get("HF_TOKEN"),
)


backend = "local"
project = AutoTrainProject(params=params, backend=backend, process=True)
project.create()
1 change: 1 addition & 0 deletions setup.cfg
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ max-line-length = 119
per-file-ignores =
# imported but unused
__init__.py: F401, E402
src/autotrain/params.py: F401
exclude =
.git,
.venv,
Expand Down
5 changes: 3 additions & 2 deletions src/autotrain/cli/run_dreambooth.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,8 @@

from autotrain import logger
from autotrain.cli import BaseAutoTrainCommand
from autotrain.cli.utils import common_args, dreambooth_munge_data
from autotrain.cli.utils import common_args
from autotrain.process import AutoTrainDataProcessor
from autotrain.project import AutoTrainProject
from autotrain.trainers.dreambooth.params import DreamBoothTrainingParams
from autotrain.trainers.dreambooth.utils import VALID_IMAGE_EXTENSIONS, XL_MODELS
Expand Down Expand Up @@ -387,7 +388,7 @@ def __init__(self, args):
def run(self):
logger.info("Running DreamBooth Training")
params = DreamBoothTrainingParams(**vars(self.args))
params = dreambooth_munge_data(params, local=self.args.backend.startswith("local"))
params = AutoTrainDataProcessor(params, local=self.args.backend.startswith("local"))
project = AutoTrainProject(params=params, backend=self.args.backend)
job_id = project.create()
logger.info(f"Job ID: {job_id}")
5 changes: 3 additions & 2 deletions src/autotrain/cli/run_extractive_qa.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,8 @@
from argparse import ArgumentParser

from autotrain import logger
from autotrain.cli.utils import ext_qa_munge_data, get_field_info
from autotrain.cli.utils import get_field_info
from autotrain.process import AutoTrainDataProcessor
from autotrain.project import AutoTrainProject
from autotrain.trainers.extractive_question_answering.params import ExtractiveQuestionAnsweringParams

Expand Down Expand Up @@ -100,7 +101,7 @@ def run(self):
logger.info("Running Extractive Question Answering")
if self.args.train:
params = ExtractiveQuestionAnsweringParams(**vars(self.args))
params = ext_qa_munge_data(params, local=self.args.backend.startswith("local"))
params = AutoTrainDataProcessor(params, local=self.args.backend.startswith("local"))
project = AutoTrainProject(params=params, backend=self.args.backend)
job_id = project.create()
logger.info(f"Job ID: {job_id}")
5 changes: 3 additions & 2 deletions src/autotrain/cli/run_image_classification.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,8 @@
from argparse import ArgumentParser

from autotrain import logger
from autotrain.cli.utils import get_field_info, img_clf_munge_data
from autotrain.cli.utils import get_field_info
from autotrain.process import AutoTrainDataProcessor
from autotrain.project import AutoTrainProject
from autotrain.trainers.image_classification.params import ImageClassificationParams

Expand Down Expand Up @@ -108,7 +109,7 @@ def run(self):
logger.info("Running Image Classification")
if self.args.train:
params = ImageClassificationParams(**vars(self.args))
params = img_clf_munge_data(params, local=self.args.backend.startswith("local"))
params = AutoTrainDataProcessor(params, local=self.args.backend.startswith("local"))
project = AutoTrainProject(params=params, backend=self.args.backend)
job_id = project.create()
logger.info(f"Job ID: {job_id}")
5 changes: 3 additions & 2 deletions src/autotrain/cli/run_image_regression.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,8 @@
from argparse import ArgumentParser

from autotrain import logger
from autotrain.cli.utils import get_field_info, img_reg_munge_data
from autotrain.cli.utils import get_field_info
from autotrain.process import AutoTrainDataProcessor
from autotrain.project import AutoTrainProject
from autotrain.trainers.image_regression.params import ImageRegressionParams

Expand Down Expand Up @@ -108,7 +109,7 @@ def run(self):
logger.info("Running Image Regression")
if self.args.train:
params = ImageRegressionParams(**vars(self.args))
params = img_reg_munge_data(params, local=self.args.backend.startswith("local"))
params = AutoTrainDataProcessor(params, local=self.args.backend.startswith("local"))
project = AutoTrainProject(params=params, backend=self.args.backend)
job_id = project.create()
logger.info(f"Job ID: {job_id}")
5 changes: 3 additions & 2 deletions src/autotrain/cli/run_llm.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,8 @@
from argparse import ArgumentParser

from autotrain import logger
from autotrain.cli.utils import get_field_info, llm_munge_data
from autotrain.cli.utils import get_field_info
from autotrain.process import AutoTrainDataProcessor
from autotrain.project import AutoTrainProject
from autotrain.trainers.clm.params import LLMTrainingParams

Expand Down Expand Up @@ -136,7 +137,7 @@ def run(self):
logger.info("Running LLM")
if self.args.train:
params = LLMTrainingParams(**vars(self.args))
params = llm_munge_data(params, local=self.args.backend.startswith("local"))
params = AutoTrainDataProcessor(params, local=self.args.backend.startswith("local"))
project = AutoTrainProject(params=params, backend=self.args.backend)
job_id = project.create()
logger.info(f"Job ID: {job_id}")
5 changes: 3 additions & 2 deletions src/autotrain/cli/run_object_detection.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,8 @@
from argparse import ArgumentParser

from autotrain import logger
from autotrain.cli.utils import get_field_info, img_obj_detect_munge_data
from autotrain.cli.utils import get_field_info
from autotrain.process import AutoTrainDataProcessor
from autotrain.project import AutoTrainProject
from autotrain.trainers.object_detection.params import ObjectDetectionParams

Expand Down Expand Up @@ -108,7 +109,7 @@ def run(self):
logger.info("Running Object Detection")
if self.args.train:
params = ObjectDetectionParams(**vars(self.args))
params = img_obj_detect_munge_data(params, local=self.args.backend.startswith("local"))
params = AutoTrainDataProcessor(params, local=self.args.backend.startswith("local"))
project = AutoTrainProject(params=params, backend=self.args.backend)
job_id = project.create()
logger.info(f"Job ID: {job_id}")
5 changes: 3 additions & 2 deletions src/autotrain/cli/run_sent_tranformers.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,8 @@
from argparse import ArgumentParser

from autotrain import logger
from autotrain.cli.utils import get_field_info, sent_transformers_munge_data
from autotrain.cli.utils import get_field_info
from autotrain.process import AutoTrainDataProcessor
from autotrain.project import AutoTrainProject
from autotrain.trainers.sent_transformers.params import SentenceTransformersParams

Expand Down Expand Up @@ -108,7 +109,7 @@ def run(self):
logger.info("Running Sentence Transformers...")
if self.args.train:
params = SentenceTransformersParams(**vars(self.args))
params = sent_transformers_munge_data(params, local=self.args.backend.startswith("local"))
params = AutoTrainDataProcessor(params, local=self.args.backend.startswith("local"))
project = AutoTrainProject(params=params, backend=self.args.backend)
job_id = project.create()
logger.info(f"Job ID: {job_id}")
5 changes: 3 additions & 2 deletions src/autotrain/cli/run_seq2seq.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,8 @@
from argparse import ArgumentParser

from autotrain import logger
from autotrain.cli.utils import get_field_info, seq2seq_munge_data
from autotrain.cli.utils import get_field_info
from autotrain.process import AutoTrainDataProcessor
from autotrain.project import AutoTrainProject
from autotrain.trainers.seq2seq.params import Seq2SeqParams

Expand Down Expand Up @@ -92,7 +93,7 @@ def run(self):
logger.info("Running Seq2Seq Classification")
if self.args.train:
params = Seq2SeqParams(**vars(self.args))
params = seq2seq_munge_data(params, local=self.args.backend.startswith("local"))
params = AutoTrainDataProcessor(params, local=self.args.backend.startswith("local"))
project = AutoTrainProject(params=params, backend=self.args.backend)
job_id = project.create()
logger.info(f"Job ID: {job_id}")
5 changes: 3 additions & 2 deletions src/autotrain/cli/run_tabular.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,8 @@
from argparse import ArgumentParser

from autotrain import logger
from autotrain.cli.utils import get_field_info, tabular_munge_data
from autotrain.cli.utils import get_field_info
from autotrain.process import AutoTrainDataProcessor
from autotrain.project import AutoTrainProject
from autotrain.trainers.tabular.params import TabularParams

Expand Down Expand Up @@ -101,7 +102,7 @@ def run(self):
logger.info("Running Tabular Training")
if self.args.train:
params = TabularParams(**vars(self.args))
params = tabular_munge_data(params, local=self.args.backend.startswith("local"))
params = AutoTrainDataProcessor(params, local=self.args.backend.startswith("local"))
project = AutoTrainProject(params=params, backend=self.args.backend)
job_id = project.create()
logger.info(f"Job ID: {job_id}")
5 changes: 3 additions & 2 deletions src/autotrain/cli/run_text_classification.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,8 @@
from argparse import ArgumentParser

from autotrain import logger
from autotrain.cli.utils import get_field_info, text_clf_munge_data
from autotrain.cli.utils import get_field_info
from autotrain.process import AutoTrainDataProcessor
from autotrain.project import AutoTrainProject
from autotrain.trainers.text_classification.params import TextClassificationParams

Expand Down Expand Up @@ -101,7 +102,7 @@ def run(self):
logger.info("Running Text Classification")
if self.args.train:
params = TextClassificationParams(**vars(self.args))
params = text_clf_munge_data(params, local=self.args.backend.startswith("local"))
params = AutoTrainDataProcessor(params, local=self.args.backend.startswith("local"))
project = AutoTrainProject(params=params, backend=self.args.backend)
job_id = project.create()
logger.info(f"Job ID: {job_id}")
5 changes: 3 additions & 2 deletions src/autotrain/cli/run_text_regression.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,8 @@
from argparse import ArgumentParser

from autotrain import logger
from autotrain.cli.utils import get_field_info, text_reg_munge_data
from autotrain.cli.utils import get_field_info
from autotrain.process import AutoTrainDataProcessor
from autotrain.project import AutoTrainProject
from autotrain.trainers.text_regression.params import TextRegressionParams

Expand Down Expand Up @@ -101,7 +102,7 @@ def run(self):
logger.info("Running Text Regression")
if self.args.train:
params = TextRegressionParams(**vars(self.args))
params = text_reg_munge_data(params, local=self.args.backend.startswith("local"))
params = AutoTrainDataProcessor(params, local=self.args.backend.startswith("local"))
project = AutoTrainProject(params=params, backend=self.args.backend)
job_id = project.create()
logger.info(f"Job ID: {job_id}")
5 changes: 3 additions & 2 deletions src/autotrain/cli/run_token_classification.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,8 @@
from argparse import ArgumentParser

from autotrain import logger
from autotrain.cli.utils import get_field_info, token_clf_munge_data
from autotrain.cli.utils import get_field_info
from autotrain.process import AutoTrainDataProcessor
from autotrain.project import AutoTrainProject
from autotrain.trainers.token_classification.params import TokenClassificationParams

Expand Down Expand Up @@ -101,7 +102,7 @@ def run(self):
logger.info("Running Token Classification")
if self.args.train:
params = TokenClassificationParams(**vars(self.args))
params = token_clf_munge_data(params, local=self.args.backend.startswith("local"))
params = AutoTrainDataProcessor(params, local=self.args.backend.startswith("local"))
project = AutoTrainProject(params=params, backend=self.args.backend)
job_id = project.create()
logger.info(f"Job ID: {job_id}")
5 changes: 3 additions & 2 deletions src/autotrain/cli/run_vlm.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,8 @@
from argparse import ArgumentParser

from autotrain import logger
from autotrain.cli.utils import get_field_info, vlm_munge_data
from autotrain.cli.utils import get_field_info
from autotrain.process import AutoTrainDataProcessor
from autotrain.project import AutoTrainProject
from autotrain.trainers.vlm.params import VLMTrainingParams

Expand Down Expand Up @@ -106,7 +107,7 @@ def run(self):
logger.info("Running Image Regression")
if self.args.train:
params = VLMTrainingParams(**vars(self.args))
params = vlm_munge_data(params, local=self.args.backend.startswith("local"))
params = AutoTrainDataProcessor(params, local=self.args.backend.startswith("local"))
project = AutoTrainProject(params=params, backend=self.args.backend)
job_id = project.create()
logger.info(f"Job ID: {job_id}")
Loading

0 comments on commit a571185

Please sign in to comment.