Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

gen #775

Open
wants to merge 9 commits into
base: main
Choose a base branch
from
Open

gen #775

Show file tree
Hide file tree
Changes from 6 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 9 additions & 0 deletions src/autotrain/app/static/scripts/listeners.js
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ document.addEventListener('DOMContentLoaded', function () {
const uploadDataTabContent = document.getElementById("upload-data-tab-content");
const hubDataTabContent = document.getElementById("hub-data-tab-content");
const uploadDataTabs = document.getElementById("upload-data-tabs");
const genDataTabContent = document.getElementById("gen-data-tab-content");

const jsonCheckbox = document.getElementById('show-json-parameters');
const jsonParametersDiv = document.getElementById('json-parameters');
Expand Down Expand Up @@ -54,12 +55,20 @@ document.addEventListener('DOMContentLoaded', function () {
if (dataSource.value === "hub") {
uploadDataTabContent.style.display = "none";
uploadDataTabs.style.display = "none";
genDataTabContent.style.display = "none";
hubDataTabContent.style.display = "block";
} else if (dataSource.value === "local") {
uploadDataTabContent.style.display = "block";
uploadDataTabs.style.display = "block";
genDataTabContent.style.display = "none";
hubDataTabContent.style.display = "none";
}
else if (dataSource.value === "gen") {
uploadDataTabContent.style.display = "none";
uploadDataTabs.style.display = "none";
hubDataTabContent.style.display = "none";
genDataTabContent.style.display = "block";
}
}

async function fetchParams() {
Expand Down
27 changes: 26 additions & 1 deletion src/autotrain/app/templates/index.html
Original file line number Diff line number Diff line change
Expand Up @@ -420,9 +420,11 @@
Source
</p>
<select id="dataset_source" name="dataset_source"
class="mt-1 block w-full border border-gray-300 dark:border-gray-600 px-3 py-2 bg-white dark:bg-gray-700 rounded-md shadow-sm focus:outline-none focus:ring-indigo-500 focus:border-indigo-500">
class="mt-1 block w-full border border-gray-300 dark:border-gray-600 px-3 py-2 bg-white dark:bg-gray-700 rounded-md shadow-sm focus:outline-none focus:ring-indigo-500 focus:border-indigo-500"
style="position: relative; z-index: 10;">
<option value="local">Local</option>
<option value="hub">Hugging Face Hub</option>
<option value="gen">Generate</option>
</select>
</div>
</div>
Expand Down Expand Up @@ -524,7 +526,30 @@
</div>
</div>
</div>
<div id="gen-data-tab-content" class="w-full px-4">
<div class="columns-1 mb-2 mt-2">
<label for="gen_prompt"
class="text-sm font-medium text-gray-700 dark:text-gray-300">Prompt Text
</label>
<textarea name="gen_prompt" id="gen_prompt"
class="mt-1 block w-full border border-gray-300 dark:border-gray-600 px-3 py-2 bg-white dark:bg-gray-700 rounded-md shadow-sm focus:outline-none focus:ring-indigo-500 focus:border-indigo-500"></textarea>
</div>
<div class="columns-2 mb-2 mt-2">
<label for="gen_model"
class="text-sm font-medium text-gray-700 dark:text-gray-300">Model Name
</label>
<input type="text" name="gen_model" id="gen_model"
class="mt-1 block w-full border border-gray-300 dark:border-gray-600 px-3 py-2 bg-white dark:bg-gray-700 rounded-md shadow-sm focus:outline-none focus:ring-indigo-500 focus:border-indigo-500">
<label for="gen_samples"
class="text-sm font-medium text-gray-700 dark:text-gray-300">Number of Samples
</label>
<input type="number" name="gen_samples" id="gen_samples" value="100"
class="mt-1 block w-full border border-gray-300 dark:border-gray-600 px-3 py-2 bg-white dark:bg-gray-700 rounded-md shadow-sm focus:outline-none focus:ring-indigo-500 focus:border-indigo-500">
</div>
</div>
</div>


<div class="items-center justify-center h-24">
<div class="w-full px-4">
<p class="text-xl text-gray-800 dark:text-gray-200 mb-2 mt-2">
Expand Down
2 changes: 2 additions & 0 deletions src/autotrain/cli/autotrain.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
from autotrain.cli.run_app import RunAutoTrainAppCommand
from autotrain.cli.run_dreambooth import RunAutoTrainDreamboothCommand
from autotrain.cli.run_extractive_qa import RunAutoTrainExtractiveQACommand
from autotrain.cli.run_gen import RunAutoTrainGenCommand
from autotrain.cli.run_image_classification import RunAutoTrainImageClassificationCommand
from autotrain.cli.run_image_regression import RunAutoTrainImageRegressionCommand
from autotrain.cli.run_llm import RunAutoTrainLLMCommand
Expand Down Expand Up @@ -49,6 +50,7 @@ def main():
RunAutoTrainSentenceTransformersCommand.register_subcommand(commands_parser)
RunAutoTrainImageRegressionCommand.register_subcommand(commands_parser)
RunAutoTrainExtractiveQACommand.register_subcommand(commands_parser)
RunAutoTrainGenCommand.register_subcommand(commands_parser)

args = parser.parse_args()

Expand Down
56 changes: 56 additions & 0 deletions src/autotrain/cli/run_gen.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,56 @@
from argparse import ArgumentParser

from autotrain import logger
from autotrain.cli.utils import get_field_info
from autotrain.datagen.gen import AutoTrainGen
from autotrain.datagen.params import AutoTrainGenParams

from . import BaseAutoTrainCommand


def run_autotrain_gen_command(args):
return RunAutoTrainGenCommand(args)


class RunAutoTrainGenCommand(BaseAutoTrainCommand):
@staticmethod
def register_subcommand(parser: ArgumentParser):
arg_list = get_field_info(AutoTrainGenParams)
run_autotrain_gen_parser = parser.add_parser("gen", description="✨ AutoTrain Gen")
for arg in arg_list:
names = [arg["arg"]] + arg.get("alias", [])
if "action" in arg:
run_autotrain_gen_parser.add_argument(
*names,
dest=arg["arg"].replace("--", "").replace("-", "_"),
help=arg["help"],
required=arg.get("required", False),
action=arg.get("action"),
default=arg.get("default"),
)
else:
run_autotrain_gen_parser.add_argument(
*names,
dest=arg["arg"].replace("--", "").replace("-", "_"),
help=arg["help"],
required=arg.get("required", False),
type=arg.get("type"),
default=arg.get("default"),
choices=arg.get("choices"),
)
run_autotrain_gen_parser.set_defaults(func=run_autotrain_gen_command)

def __init__(self, args):
self.args = args

store_true_arg_names = [
"push_to_hub",
]
for arg_name in store_true_arg_names:
if getattr(self.args, arg_name) is None:
setattr(self.args, arg_name, False)

def run(self):
logger.info("Running AutoTrain Gen 🚀")
params = AutoTrainGenParams(**vars(self.args))
AutoTrainGen(params).run()
Empty file.
115 changes: 115 additions & 0 deletions src/autotrain/datagen/clients.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,115 @@
import json
import time
from dataclasses import dataclass
from typing import Optional

import outlines
import torch
import transformers
from huggingface_hub import InferenceClient

from autotrain import logger


@dataclass
class _TransformersClient:
model_name: str

def __post_init__(self):
self.pipeline = transformers.pipeline(
"text-generation",
model=self.model_name,
model_kwargs={"torch_dtype": torch.bfloat16},
device_map="auto",
)

def chat_completion(self, messages, max_tokens, stream, seed, response_format):
outputs = self.pipeline(
messages,
max_new_tokens=max_tokens,
seed=seed,
response_format=response_format,
stream=stream,
)
return outputs[0]["generated_text"][-1]["content"]


@dataclass
class TransformersClient:
model_name: str

def __post_init__(self):
self.pipeline = outlines.models.transformers(
self.model_name,
# device_map="auto",
model_kwargs={"torch_dtype": torch.bfloat16},
)

def chat_completion(self, messages, max_tokens, stream, seed, response_format):
# dump response_format dict to json
response_format = json.dumps(response_format)
generator = outlines.generate.json(self.pipeline, response_format)
outputs = generator(
messages,
max_tokens=max_tokens,
seed=seed,
)
print(outputs)
return outputs[0]["generated_text"][-1]["content"]


@dataclass
class Client:
name: str
model_name: Optional[str] = None
api_key: Optional[str] = None

def __post_init__(self):
if self.name == "hf-inference-api":
if self.model_name is None:
raise ValueError("Model name is required for Huggingface")
self.client = InferenceClient
elif self.name == "transformers":
if self.model_name is None:
raise ValueError("Model name is required for Transformers")
self.client = TransformersClient
else:
raise ValueError("Client not supported")

def __str__(self):
return f"Client: {self.name}"

def __repr__(self):
return f"Client: {self.name}"

def _huggingface(self):
if self.api_key:
return self.client(self.model_name, token=self.api_key)
return self.client(self.model_name)

def _transformers(self):
return self.client(self.model_name)

def chat_completion(self, messages, max_tokens=500, seed=42, response_format=None, retries=3, delay=5):
if self.name == "hf-inference-api":
_client = self._huggingface()
elif self.name == "transformers":
_client = self._transformers()
else:
raise ValueError("Client not supported")
for attempt in range(retries):
try:
message = _client.chat_completion(
messages=messages,
max_tokens=max_tokens,
stream=False,
seed=seed,
response_format=response_format,
)
return message
except Exception as e:
logger.error(f"Attempt {attempt + 1} failed: {e}")
if attempt < retries - 1:
time.sleep(delay)
else:
return None
21 changes: 21 additions & 0 deletions src/autotrain/datagen/gen.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
from dataclasses import dataclass

from autotrain import logger
from autotrain.datagen.params import AutoTrainGenParams


@dataclass
class AutoTrainGen:
params: AutoTrainGenParams

def __post_init__(self):
logger.info(self.params)
if self.params.task in ("text-classification", "seq2seq"):
from autotrain.datagen.text import TextDataGenerator

self.gen = TextDataGenerator(self.params)
else:
raise NotImplementedError

def run(self):
self.gen.run()
74 changes: 74 additions & 0 deletions src/autotrain/datagen/params.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,74 @@
import os
from typing import Optional

from pydantic import BaseModel, Field

from autotrain import logger


class BaseGenParams(BaseModel):
"""
Base class for all AutoTrain gen parameters.
"""

class Config:
protected_namespaces = ()

def save(self, output_dir):
"""
Save parameters to a json file.
"""
os.makedirs(output_dir, exist_ok=True)
path = os.path.join(output_dir, "gen_params.json")
# save formatted json
with open(path, "w", encoding="utf-8") as f:
f.write(self.model_dump_json(indent=4))
Copy link
Preview

Copilot AI Nov 11, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The method 'model_dump_json' does not exist in Pydantic's BaseModel. It should be 'self.json(indent=4)'.

Suggested change
f.write(self.model_dump_json(indent=4))
f.write(self.json(indent=4))

Copilot is powered by AI, so mistakes are possible. Review output carefully before use.

Positive Feedback
Negative Feedback

Provide additional feedback

Please help us improve GitHub Copilot by sharing more details about this comment.

Please select one or more of the options

def __str__(self):
"""
String representation of the parameters.
"""
data = self.model_dump()
data["token"] = "*****" if data.get("token") else None
return str(data)

def __init__(self, **data):
"""
Initialize the parameters, check for unused/extra parameters and warn the user.
"""
super().__init__(**data)

if len(self.project_name) > 0:
if not self.project_name.replace("-", "").isalnum():
raise ValueError("project_name must be alphanumeric but can contain hyphens")

if len(self.project_name) > 50:
raise ValueError("project_name cannot be more than 50 characters")

defaults = set(self.model_fields.keys())
supplied = set(data.keys())
not_supplied = defaults - supplied
if not_supplied:
logger.warning(f"Parameters not supplied by user and set to default: {', '.join(not_supplied)}")
unused = supplied - set(self.model_fields)
if unused:
logger.warning(f"Parameters supplied but not used: {', '.join(unused)}")


class AutoTrainGenParams(BaseGenParams):
gen_model: str = Field("meta-llama/Meta-Llama-3.1-8B-Instruct", title="The model to be used for generation.")
project_name: str = Field("autotrain-datagen", title="Name of the project.")
prompt: str = Field(None, title="Prompt to be used for text generation.")
task: str = Field(None, title="Task type, e.g., text-classification, summarization.")
token: Optional[str] = Field(None, title="Authentication token for accessing the model.")
training_config: Optional[str] = Field(None, title="Path to the training configuration file.")
valid_size: Optional[float] = Field(0.2, title="Validation set size as a fraction of the total dataset.")
username: Optional[str] = Field(None, title="Username of the person running the training.")
push_to_hub: Optional[bool] = Field(True, title="Whether to push the model to Hugging Face Hub.")
backend: Optional[str] = Field("huggingface", title="Backend to be used, e.g., huggingface, local.")
api: Optional[str] = Field(None, title="API endpoint to be used.")
api_key: Optional[str] = Field(None, title="API key for authentication.")
sample: Optional[str] = Field(None, title="Sample dataset for generation.")
min_samples: Optional[int] = Field(200, title="Minimum number of samples required for training.")
# text specific
min_words: Optional[int] = Field(25, title="Minimum number of words in the generated text.")
Loading
Loading