
Commit

update
abhishekkrthakur committed Sep 14, 2024
1 parent 0cf7b52 commit 7f553dd
Showing 7 changed files with 161 additions and 7 deletions.
9 changes: 9 additions & 0 deletions src/autotrain/app/static/scripts/listeners.js
@@ -3,6 +3,7 @@ document.addEventListener('DOMContentLoaded', function () {
    const uploadDataTabContent = document.getElementById("upload-data-tab-content");
    const hubDataTabContent = document.getElementById("hub-data-tab-content");
    const uploadDataTabs = document.getElementById("upload-data-tabs");
    const genDataTabContent = document.getElementById("gen-data-tab-content");

    const jsonCheckbox = document.getElementById('show-json-parameters');
    const jsonParametersDiv = document.getElementById('json-parameters');
@@ -54,12 +55,20 @@ document.addEventListener('DOMContentLoaded', function () {
        if (dataSource.value === "hub") {
            uploadDataTabContent.style.display = "none";
            uploadDataTabs.style.display = "none";
            genDataTabContent.style.display = "none";
            hubDataTabContent.style.display = "block";
        } else if (dataSource.value === "local") {
            uploadDataTabContent.style.display = "block";
            uploadDataTabs.style.display = "block";
            genDataTabContent.style.display = "none";
            hubDataTabContent.style.display = "none";
        }
        else if (dataSource.value === "gen") {
            uploadDataTabContent.style.display = "none";
            uploadDataTabs.style.display = "none";
            hubDataTabContent.style.display = "none";
            genDataTabContent.style.display = "block";
        }
    }

    async function fetchParams() {
27 changes: 26 additions & 1 deletion src/autotrain/app/templates/index.html
@@ -420,9 +420,11 @@
            Source
        </p>
        <select id="dataset_source" name="dataset_source"
            class="mt-1 block w-full border border-gray-300 dark:border-gray-600 px-3 py-2 bg-white dark:bg-gray-700 rounded-md shadow-sm focus:outline-none focus:ring-indigo-500 focus:border-indigo-500">
            class="mt-1 block w-full border border-gray-300 dark:border-gray-600 px-3 py-2 bg-white dark:bg-gray-700 rounded-md shadow-sm focus:outline-none focus:ring-indigo-500 focus:border-indigo-500"
            style="position: relative; z-index: 10;">
            <option value="local">Local</option>
            <option value="hub">Hugging Face Hub</option>
            <option value="gen">Generate</option>
        </select>
    </div>
</div>
@@ -524,7 +526,30 @@
            </div>
        </div>
    </div>
    <div id="gen-data-tab-content" class="w-full px-4">
        <div class="columns-1 mb-2 mt-2">
            <label for="gen_prompt"
                class="text-sm font-medium text-gray-700 dark:text-gray-300">Prompt Text
            </label>
            <textarea name="gen_prompt" id="gen_prompt"
                class="mt-1 block w-full border border-gray-300 dark:border-gray-600 px-3 py-2 bg-white dark:bg-gray-700 rounded-md shadow-sm focus:outline-none focus:ring-indigo-500 focus:border-indigo-500"></textarea>
        </div>
        <div class="columns-2 mb-2 mt-2">
            <label for="gen_model"
                class="text-sm font-medium text-gray-700 dark:text-gray-300">Model Name
            </label>
            <input type="text" name="gen_model" id="gen_model"
                class="mt-1 block w-full border border-gray-300 dark:border-gray-600 px-3 py-2 bg-white dark:bg-gray-700 rounded-md shadow-sm focus:outline-none focus:ring-indigo-500 focus:border-indigo-500">
            <label for="gen_samples"
                class="text-sm font-medium text-gray-700 dark:text-gray-300">Number of Samples
            </label>
            <input type="number" name="gen_samples" id="gen_samples" value="100"
                class="mt-1 block w-full border border-gray-300 dark:border-gray-600 px-3 py-2 bg-white dark:bg-gray-700 rounded-md shadow-sm focus:outline-none focus:ring-indigo-500 focus:border-indigo-500">
        </div>
    </div>
</div>


<div class="items-center justify-center h-24">
    <div class="w-full px-4">
        <p class="text-xl text-gray-800 dark:text-gray-200 mb-2 mt-2">
67 changes: 65 additions & 2 deletions src/autotrain/datagen/clients.py
@@ -5,6 +5,57 @@
from huggingface_hub import InferenceClient

from autotrain import logger
import transformers
import torch
import outlines
import json


@dataclass
class _TransformersClient:
    model_name: str

    def __post_init__(self):
        self.pipeline = transformers.pipeline(
            "text-generation",
            model=self.model_name,
            model_kwargs={"torch_dtype": torch.bfloat16},
            device_map="auto",
        )

    def chat_completion(self, messages, max_tokens, stream, seed, response_format):
        outputs = self.pipeline(
            messages,
            max_new_tokens=max_tokens,
            seed=seed,
            response_format=response_format,
            stream=stream,
        )
        return outputs[0]["generated_text"][-1]["content"]


@dataclass
class TransformersClient:
    model_name: str

    def __post_init__(self):
        self.pipeline = outlines.models.transformers(
            self.model_name,
            # device_map="auto",
            model_kwargs={"torch_dtype": torch.bfloat16},
        )

    def chat_completion(self, messages, max_tokens, stream, seed, response_format):
        # dump response_format dict to json
        response_format = json.dumps(response_format)
        generator = outlines.generate.json(self.pipeline, response_format)
        outputs = generator(
            messages,
            max_tokens=max_tokens,
            seed=seed,
        )
        print(outputs)
        return outputs[0]["generated_text"][-1]["content"]


@dataclass
@@ -14,10 +65,14 @@ class Client:
    api_key: Optional[str] = None

    def __post_init__(self):
        if self.name == "huggingface":
        if self.name == "hf-inference-api":
            if self.model_name is None:
                raise ValueError("Model name is required for Huggingface")
            self.client = InferenceClient
        elif self.name == "transformers":
            if self.model_name is None:
                raise ValueError("Model name is required for Transformers")
            self.client = TransformersClient
        else:
            raise ValueError("Client not supported")

@@ -32,10 +87,18 @@ def _huggingface(self):
            return self.client(self.model_name, token=self.api_key)
        return self.client(self.model_name)

    def _transformers(self):
        return self.client(self.model_name)

    def chat_completion(self, messages, max_tokens=500, seed=42, response_format=None, retries=3, delay=5):
        if self.name == "hf-inference-api":
            _client = self._huggingface()
        elif self.name == "transformers":
            _client = self._transformers()
        else:
            raise ValueError("Client not supported")
        for attempt in range(retries):
            try:
                _client = self._huggingface()
                message = _client.chat_completion(
                    messages=messages,
                    max_tokens=max_tokens,
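Note: the snippet below is a hedged usage sketch of the new local "transformers" backend, not part of this commit. It assumes the Client dataclass also declares a name field (hidden in the collapsed context above), and the model checkpoint is only an example stand-in.

from autotrain.datagen.clients import Client

# "transformers" routes chat_completion through TransformersClient, which uses
# outlines for schema-constrained generation; any chat-tuned checkpoint should do.
client = Client(name="transformers", model_name="Qwen/Qwen2.5-0.5B-Instruct")

# response_format is a JSON-schema-like dict; TransformersClient json.dumps()
# it and hands the string to outlines.generate.json.
schema = {
    "type": "object",
    "properties": {"text": {"type": "string"}, "target": {"type": "string"}},
    "required": ["text", "target"],
}

messages = [
    {"role": "system", "content": "You generate labelled examples for sentiment classification."},
    {"role": "user", "content": "Produce one example."},
]

sample = client.chat_completion(messages=messages, max_tokens=256, seed=42, response_format=schema)
print(sample)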
5 changes: 3 additions & 2 deletions src/autotrain/datagen/text.py
@@ -9,6 +9,7 @@
from autotrain.datagen.clients import Client
from autotrain.datagen.params import AutoTrainGenParams

import os

TEXT_CLASSIFICATION_SYSTEM_PROMPT = """
You are an AI bot that generates data for text classification tasks.
@@ -158,8 +159,8 @@ def run(self):
        logger.info(f"Train data size: {len(train_data)}")
        logger.info(f"Valid data size: {len(valid_data)}")

        hf_dataset = utils.convert_text_dataset_to_hf(train_data, valid_data)
        hf_dataset.save_to_disk(self.params.project_name)
        hf_dataset = utils.convert_text_dataset_to_hf(self.params.task, train_data, valid_data)
        hf_dataset.save_to_disk(os.path.join(self.params.project_name, "autotrain-data"))

        if self.params.push_to_hub:
            logger.info("Pushing the data to Hugging Face Hub.")
20 changes: 18 additions & 2 deletions src/autotrain/datagen/utils.py
@@ -3,17 +3,33 @@
import subprocess
from typing import Dict, List, Optional

from datasets import Dataset, DatasetDict
from datasets import Dataset, DatasetDict, ClassLabel
from huggingface_hub import HfApi, metadata_update

from autotrain import logger


def convert_text_dataset_to_hf(train: List[Dict[str, str]], valid: Optional[List[Dict[str, str]]] = None) -> Dataset:
def convert_text_dataset_to_hf(
    task, train: List[Dict[str, str]], valid: Optional[List[Dict[str, str]]] = None
) -> Dataset:
    if task == "text-classification":
        for item in train:
            item["target"] = item["target"].lower().strip()
        label_names = list(set([item["target"] for item in train]))
        logger.info(f"Label names: {label_names}")

    dataset = Dataset.from_list(train)

    if task == "text-classification":
        dataset = dataset.cast_column("target", ClassLabel(names=label_names))

    ddict = {"train": dataset}
    if valid is not None:
        valid_dataset = Dataset.from_list(valid)
        if task == "text-classification":
            for item in valid:
                item["target"] = item["target"].lower().strip()
            valid_dataset = valid_dataset.cast_column("target", ClassLabel(names=label_names))
        ddict["validation"] = valid_dataset
    dataset = DatasetDict(ddict)
    return dataset
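For reference, a minimal self-contained sketch of the ClassLabel cast that convert_text_dataset_to_hf now performs for text classification; the rows and label names below are made up for illustration.

from datasets import ClassLabel, Dataset

rows = [
    {"text": "great movie", "target": "positive"},
    {"text": "terrible plot", "target": "negative"},
]
# mirror the function above: normalise the string labels, then collect the names
for row in rows:
    row["target"] = row["target"].lower().strip()
label_names = sorted({row["target"] for row in rows})

# casting a string column to ClassLabel maps each label string to an integer class id
ds = Dataset.from_list(rows).cast_column("target", ClassLabel(names=label_names))
print(ds.features["target"])  # ClassLabel with names ['negative', 'positive']
print(ds[0]["target"])        # 1, i.e. "positive"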
35 changes: 35 additions & 0 deletions src/autotrain/trainers/text_classification/__main__.py
@@ -28,6 +28,8 @@
from autotrain.trainers.text_classification import utils
from autotrain.trainers.text_classification.dataset import TextClassificationDataset
from autotrain.trainers.text_classification.params import TextClassificationParams
from autotrain.datagen.params import AutoTrainGenParams
from autotrain.datagen.gen import AutoTrainGen


def parse_args():
@@ -42,6 +44,39 @@ def train(config):
    if isinstance(config, dict):
        config = TextClassificationParams(**config)

    if config.gen_prompt is not None:
        if config.gen_model is None:
            raise ValueError("Generation model must be provided when using generation prompt")
        gen_params = config.gen_params.split(",") if config.gen_params is not None else []
        gen_params = {p.split("=")[0]: p.split("=")[1] for p in gen_params}
        gen_config = AutoTrainGenParams(
            gen_model=config.gen_model,
            project_name=config.project_name,
            prompt=config.gen_prompt,
            task="text-classification",
            token=config.token,
            username=config.username,
            push_to_hub=config.push_to_hub,
            backend=gen_params.get("backend", "transformers"),
            api=gen_params.get("api", None),
            api_key=gen_params.get("api_key", None),
            min_samples=config.gen_samples,
        )
        try:
            AutoTrainGen(gen_config).run()
        except Exception as e:
            logger.error(f"Error generating data: {e}")
            return
        if config.push_to_hub:
            config.data_path = f"{config.username}/{config.project_name}"
        else:
            config.data_path = f"{config.project_name}/autotrain-data"

        config.train_split = "train"
        config.valid_split = "validation"
        config.target_column = "target"
        config.text_column = "text"

    train_data = None
    valid_data = None
    # check if config.train_split.csv exists in config.data_path
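A small sketch of how the comma-separated gen_params string is parsed by the dict comprehension above; the value shown is hypothetical, and the format cannot hold values that themselves contain "=" or ",".

raw = "backend=transformers,api_key=hf_xxx"  # hypothetical value of config.gen_params
pairs = raw.split(",") if raw else []
gen_params = {p.split("=")[0]: p.split("=")[1] for p in pairs}
assert gen_params == {"backend": "transformers", "api_key": "hf_xxx"}
# keys read by train(): backend (default "transformers"), api, api_key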
5 changes: 5 additions & 0 deletions src/autotrain/trainers/text_classification/params.py
@@ -35,3 +35,8 @@ class TextClassificationParams(AutoTrainParams):
    log: str = Field("none", title="Logging using experiment tracking")
    early_stopping_patience: int = Field(5, title="Early stopping patience")
    early_stopping_threshold: float = Field(0.01, title="Early stopping threshold")
    # generation parameters
    gen_prompt: Optional[str] = Field(None, title="Generation prompt")
    gen_model: Optional[str] = Field(None, title="Generation model")
    gen_samples: Optional[int] = Field(100, title="Generation samples")
    gen_params: Optional[str] = Field(None, title="Generation parameters, format: key1=value1,key2=value2")
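A hedged example of driving the new generation path purely through these fields; the field names come from the diff above, while the prompt, project name, and model checkpoint are placeholders, and the remaining training fields are assumed to keep their defaults or be supplied as usual.

from autotrain.trainers.text_classification.params import TextClassificationParams

params = TextClassificationParams(
    project_name="sentiment-demo",  # hypothetical project name
    gen_prompt="Generate short movie reviews labelled positive or negative.",
    gen_model="Qwen/Qwen2.5-0.5B-Instruct",  # any chat-tuned model; example only
    gen_samples=200,
    gen_params="backend=transformers",
)
# train() in __main__.py sees gen_prompt is set, synthesises the dataset first,
# then points data_path at "<project_name>/autotrain-data" before training.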
