diff --git a/src/autotrain/app/static/scripts/listeners.js b/src/autotrain/app/static/scripts/listeners.js index 9638342aea..59b4cfa0fe 100644 --- a/src/autotrain/app/static/scripts/listeners.js +++ b/src/autotrain/app/static/scripts/listeners.js @@ -3,6 +3,7 @@ document.addEventListener('DOMContentLoaded', function () { const uploadDataTabContent = document.getElementById("upload-data-tab-content"); const hubDataTabContent = document.getElementById("hub-data-tab-content"); const uploadDataTabs = document.getElementById("upload-data-tabs"); + const genDataTabContent = document.getElementById("gen-data-tab-content"); const jsonCheckbox = document.getElementById('show-json-parameters'); const jsonParametersDiv = document.getElementById('json-parameters'); @@ -54,12 +55,20 @@ document.addEventListener('DOMContentLoaded', function () { if (dataSource.value === "hub") { uploadDataTabContent.style.display = "none"; uploadDataTabs.style.display = "none"; + genDataTabContent.style.display = "none"; hubDataTabContent.style.display = "block"; } else if (dataSource.value === "local") { uploadDataTabContent.style.display = "block"; uploadDataTabs.style.display = "block"; + genDataTabContent.style.display = "none"; hubDataTabContent.style.display = "none"; } + else if (dataSource.value === "gen") { + uploadDataTabContent.style.display = "none"; + uploadDataTabs.style.display = "none"; + hubDataTabContent.style.display = "none"; + genDataTabContent.style.display = "block"; + } } async function fetchParams() { diff --git a/src/autotrain/app/templates/index.html b/src/autotrain/app/templates/index.html index 00501202cf..76fccd6961 100644 --- a/src/autotrain/app/templates/index.html +++ b/src/autotrain/app/templates/index.html @@ -420,9 +420,11 @@ Source

@@ -524,7 +526,30 @@ +
+
+ + +
+
+ + + + +
+
+ +

diff --git a/src/autotrain/datagen/clients.py b/src/autotrain/datagen/clients.py index a7c2c066ff..5583f46bee 100644 --- a/src/autotrain/datagen/clients.py +++ b/src/autotrain/datagen/clients.py @@ -5,6 +5,57 @@ from huggingface_hub import InferenceClient from autotrain import logger +import transformers +import torch +import outlines +import json + + +@dataclass +class _TransformersClient: + model_name: str + + def __post_init__(self): + self.pipeline = transformers.pipeline( + "text-generation", + model=self.model_name, + model_kwargs={"torch_dtype": torch.bfloat16}, + device_map="auto", + ) + + def chat_completion(self, messages, max_tokens, stream, seed, response_format): + outputs = self.pipeline( + messages, + max_new_tokens=max_tokens, + seed=seed, + response_format=response_format, + stream=stream, + ) + return outputs[0]["generated_text"][-1]["content"] + + +@dataclass +class TransformersClient: + model_name: str + + def __post_init__(self): + self.pipeline = outlines.models.transformers( + self.model_name, + # device_map="auto", + model_kwargs={"torch_dtype": torch.bfloat16}, + ) + + def chat_completion(self, messages, max_tokens, stream, seed, response_format): + # dump response_format dict to json + response_format = json.dumps(response_format) + generator = outlines.generate.json(self.pipeline, response_format) + outputs = generator( + messages, + max_tokens=max_tokens, + seed=seed, + ) + print(outputs) + return outputs[0]["generated_text"][-1]["content"] @dataclass @@ -14,10 +65,14 @@ class Client: api_key: Optional[str] = None def __post_init__(self): - if self.name == "huggingface": + if self.name == "hf-inference-api": if self.model_name is None: raise ValueError("Model name is required for Huggingface") self.client = InferenceClient + elif self.name == "transformers": + if self.model_name is None: + raise ValueError("Model name is required for Transformers") + self.client = TransformersClient else: raise ValueError("Client not supported") @@ -32,10 +87,18 @@ def _huggingface(self): return self.client(self.model_name, token=self.api_key) return self.client(self.model_name) + def _transformers(self): + return self.client(self.model_name) + def chat_completion(self, messages, max_tokens=500, seed=42, response_format=None, retries=3, delay=5): + if self.name == "hf-inference-api": + _client = self._huggingface() + elif self.name == "transformers": + _client = self._transformers() + else: + raise ValueError("Client not supported") for attempt in range(retries): try: - _client = self._huggingface() message = _client.chat_completion( messages=messages, max_tokens=max_tokens, diff --git a/src/autotrain/datagen/text.py b/src/autotrain/datagen/text.py index a50c75c381..77277e5502 100644 --- a/src/autotrain/datagen/text.py +++ b/src/autotrain/datagen/text.py @@ -9,6 +9,7 @@ from autotrain.datagen.clients import Client from autotrain.datagen.params import AutoTrainGenParams +import os TEXT_CLASSIFICATION_SYSTEM_PROMPT = """ You are an AI bot that generates data for text classification tasks. @@ -158,8 +159,8 @@ def run(self): logger.info(f"Train data size: {len(train_data)}") logger.info(f"Valid data size: {len(valid_data)}") - hf_dataset = utils.convert_text_dataset_to_hf(train_data, valid_data) - hf_dataset.save_to_disk(self.params.project_name) + hf_dataset = utils.convert_text_dataset_to_hf(self.params.task, train_data, valid_data) + hf_dataset.save_to_disk(os.path.join(self.params.project_name, "autotrain-data")) if self.params.push_to_hub: logger.info("Pushing the data to Hugging Face Hub.") diff --git a/src/autotrain/datagen/utils.py b/src/autotrain/datagen/utils.py index fb5e703c3d..3a43c0d840 100644 --- a/src/autotrain/datagen/utils.py +++ b/src/autotrain/datagen/utils.py @@ -3,17 +3,33 @@ import subprocess from typing import Dict, List, Optional -from datasets import Dataset, DatasetDict +from datasets import Dataset, DatasetDict, ClassLabel from huggingface_hub import HfApi, metadata_update from autotrain import logger -def convert_text_dataset_to_hf(train: List[Dict[str, str]], valid: Optional[List[Dict[str, str]]] = None) -> Dataset: +def convert_text_dataset_to_hf( + task, train: List[Dict[str, str]], valid: Optional[List[Dict[str, str]]] = None +) -> Dataset: + if task == "text-classification": + for item in train: + item["target"] = item["target"].lower().strip() + label_names = list(set([item["target"] for item in train])) + logger.info(f"Label names: {label_names}") + dataset = Dataset.from_list(train) + + if task == "text-classification": + dataset = dataset.cast_column("target", ClassLabel(names=label_names)) + ddict = {"train": dataset} if valid is not None: valid_dataset = Dataset.from_list(valid) + if task == "text-classification": + for item in valid: + item["target"] = item["target"].lower().strip() + valid_dataset = valid_dataset.cast_column("target", ClassLabel(names=label_names)) ddict["validation"] = valid_dataset dataset = DatasetDict(ddict) return dataset diff --git a/src/autotrain/trainers/text_classification/__main__.py b/src/autotrain/trainers/text_classification/__main__.py index 5ce3918e87..f5a5aa1637 100644 --- a/src/autotrain/trainers/text_classification/__main__.py +++ b/src/autotrain/trainers/text_classification/__main__.py @@ -28,6 +28,8 @@ from autotrain.trainers.text_classification import utils from autotrain.trainers.text_classification.dataset import TextClassificationDataset from autotrain.trainers.text_classification.params import TextClassificationParams +from autotrain.datagen.params import AutoTrainGenParams +from autotrain.datagen.gen import AutoTrainGen def parse_args(): @@ -42,6 +44,39 @@ def train(config): if isinstance(config, dict): config = TextClassificationParams(**config) + if config.gen_prompt is not None: + if config.gen_model is None: + raise ValueError("Generation model must be provided when using generation prompt") + gen_params = config.gen_params.split(",") if config.gen_params is not None else [] + gen_params = {p.split("=")[0]: p.split("=")[1] for p in gen_params} + gen_config = AutoTrainGenParams( + gen_model=config.gen_model, + project_name=config.project_name, + prompt=config.gen_prompt, + task="text-classification", + token=config.token, + username=config.username, + push_to_hub=config.push_to_hub, + backend=gen_params.get("backend", "transformers"), + api=gen_params.get("api", None), + api_key=gen_params.get("api_key", None), + min_samples=config.gen_samples, + ) + try: + AutoTrainGen(gen_config).run() + except Exception as e: + logger.error(f"Error generating data: {e}") + return + if config.push_to_hub: + config.data_path = f"{config.username}/{config.project_name}" + else: + config.data_path = f"{config.project_name}/autotrain-data" + + config.train_split = "train" + config.valid_split = "validation" + config.target_column = "target" + config.text_column = "text" + train_data = None valid_data = None # check if config.train_split.csv exists in config.data_path diff --git a/src/autotrain/trainers/text_classification/params.py b/src/autotrain/trainers/text_classification/params.py index c38f1d6c72..79fd4b8e94 100644 --- a/src/autotrain/trainers/text_classification/params.py +++ b/src/autotrain/trainers/text_classification/params.py @@ -35,3 +35,8 @@ class TextClassificationParams(AutoTrainParams): log: str = Field("none", title="Logging using experiment tracking") early_stopping_patience: int = Field(5, title="Early stopping patience") early_stopping_threshold: float = Field(0.01, title="Early stopping threshold") + # generation parameters + gen_prompt: Optional[str] = Field(None, title="Generation prompt") + gen_model: Optional[str] = Field(None, title="Generation model") + gen_samples: Optional[int] = Field(100, title="Generation samples") + gen_params: Optional[str] = Field(None, title="Generation parameters, format: key1=value1,key2=value2")