From c65bcf7fca6d4b82f64244f6cfc499e31b226fb3 Mon Sep 17 00:00:00 2001
From: Abhishek Thakur
Date: Tue, 10 Sep 2024 15:20:50 +0200
Subject: [PATCH 1/7] gen1
---
src/autotrain/datagen/__init__.py | 0
src/autotrain/datagen/clients.py | 56 +++++++++
src/autotrain/datagen/generator.py | 15 +++
src/autotrain/datagen/text_classification.py | 125 +++++++++++++++++++
src/autotrain/datagen/utils.py | 0
5 files changed, 196 insertions(+)
create mode 100644 src/autotrain/datagen/__init__.py
create mode 100644 src/autotrain/datagen/clients.py
create mode 100644 src/autotrain/datagen/generator.py
create mode 100644 src/autotrain/datagen/text_classification.py
create mode 100644 src/autotrain/datagen/utils.py
diff --git a/src/autotrain/datagen/__init__.py b/src/autotrain/datagen/__init__.py
new file mode 100644
index 0000000000..e69de29bb2
diff --git a/src/autotrain/datagen/clients.py b/src/autotrain/datagen/clients.py
new file mode 100644
index 0000000000..2bfd23774d
--- /dev/null
+++ b/src/autotrain/datagen/clients.py
@@ -0,0 +1,56 @@
+from huggingface_hub import InferenceClient
+from dataclasses import dataclass
+from typing import Optional
+
+"""
+from huggingface_hub import InferenceClient
+
+client = InferenceClient(
+ "meta-llama/Meta-Llama-3.1-8B-Instruct",
+ token="hf_xxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxx",
+)
+
+for message in client.chat_completion(
+ messages=[{"role": "user", "content": "What is the capital of France?"}],
+ max_tokens=500,
+ stream=True,
+):
+ print(message.choices[0].delta.content, end="")
+"""
+
+
+@dataclass
+class Client:
+ name: str
+ model_name: Optional[str] = None
+ token: Optional[str] = None
+
+ def __post_init__(self):
+ if self.name == "huggingface":
+ if self.model_name is None:
+ raise ValueError("Model name is required for Huggingface")
+ self.client = InferenceClient
+ else:
+ raise ValueError("Client not supported")
+
+ def __str__(self):
+ return f"Client: {self.name}"
+
+ def __repr__(self):
+ return f"Client: {self.name}"
+
+ def _huggingface(self):
+ if self.token:
+ return self.client(self.model_name, token=self.token)
+ return self.client(self.model_name)
+
+ def chat_completion(self, messages, max_tokens=500, seed=42, response_format=None):
+ _client = self._huggingface()
+ message = _client.chat_completion(
+ messages=messages,
+ max_tokens=max_tokens,
+ stream=False,
+ seed=seed,
+ response_format=response_format,
+ )
+ return message
diff --git a/src/autotrain/datagen/generator.py b/src/autotrain/datagen/generator.py
new file mode 100644
index 0000000000..26515a096c
--- /dev/null
+++ b/src/autotrain/datagen/generator.py
@@ -0,0 +1,15 @@
+from pydantic import BaseModel
+
+
+class BaseDataGenerator(BaseModel):
+ def load_data(self):
+ raise NotImplementedError
+
+ def preprocess_data(self):
+ raise NotImplementedError
+
+ def pre_generate_data(self):
+ raise NotImplementedError
+
+ def generate_data(self):
+ raise NotImplementedError
diff --git a/src/autotrain/datagen/text_classification.py b/src/autotrain/datagen/text_classification.py
new file mode 100644
index 0000000000..d7f16f13dc
--- /dev/null
+++ b/src/autotrain/datagen/text_classification.py
@@ -0,0 +1,125 @@
+from autotrain.datagen.generator import BaseDataGenerator
+from autotrain.datagen.clients import Client
+import json
+from autotrain import logger
+import re
+import ijson
+import random
+
+SYSTEM_PROMPT = """
+You are an AI bot that generates data for text classification tasks.
+You do not repeat the question asked by user. You do not generate code.
+Only thing you generate is text data in the specified format.
+The user provides a problem statement and you generate the data.
+For text classification task, the user provides different classes.
+If the user has not provided the classes, generate the classes as well but limit the number of classes to 10.
+"""
+
+DATA_PROMPT = """
+The dataset for text classification is in JSON format.
+Each line should be a JSON object with the following keys: text and target.
+Make sure each text sample has atleast 25 words.
+The target must always be a string.
+Don't write what you are doing. Just generate the data.
+Each line of the output consists of a dictionary with two keys: text and target and nothing else.
+"""
+
+
+def fix_invalid_json(json_string):
+ # Escape backslashes that are not already escaped
+ json_string = re.sub(r'(?
Date: Wed, 11 Sep 2024 16:06:01 +0200
Subject: [PATCH 2/7] gen
---
src/autotrain/cli/autotrain.py | 2 +
src/autotrain/cli/run_gen.py | 56 ++++++
src/autotrain/datagen/clients.py | 51 +++---
src/autotrain/datagen/gen.py | 21 +++
src/autotrain/datagen/generator.py | 15 --
src/autotrain/datagen/params.py | 23 +++
src/autotrain/datagen/text.py | 169 +++++++++++++++++++
src/autotrain/datagen/text_classification.py | 125 --------------
src/autotrain/datagen/utils.py | 32 ++++
9 files changed, 326 insertions(+), 168 deletions(-)
create mode 100644 src/autotrain/cli/run_gen.py
create mode 100644 src/autotrain/datagen/gen.py
create mode 100644 src/autotrain/datagen/params.py
create mode 100644 src/autotrain/datagen/text.py
delete mode 100644 src/autotrain/datagen/text_classification.py
diff --git a/src/autotrain/cli/autotrain.py b/src/autotrain/cli/autotrain.py
index ce8c869e46..0ed6d76093 100644
--- a/src/autotrain/cli/autotrain.py
+++ b/src/autotrain/cli/autotrain.py
@@ -5,6 +5,7 @@
from autotrain.cli.run_app import RunAutoTrainAppCommand
from autotrain.cli.run_dreambooth import RunAutoTrainDreamboothCommand
from autotrain.cli.run_extractive_qa import RunAutoTrainExtractiveQACommand
+from autotrain.cli.run_gen import RunAutoTrainGenCommand
from autotrain.cli.run_image_classification import RunAutoTrainImageClassificationCommand
from autotrain.cli.run_image_regression import RunAutoTrainImageRegressionCommand
from autotrain.cli.run_llm import RunAutoTrainLLMCommand
@@ -49,6 +50,7 @@ def main():
RunAutoTrainSentenceTransformersCommand.register_subcommand(commands_parser)
RunAutoTrainImageRegressionCommand.register_subcommand(commands_parser)
RunAutoTrainExtractiveQACommand.register_subcommand(commands_parser)
+ RunAutoTrainGenCommand.register_subcommand(commands_parser)
args = parser.parse_args()
diff --git a/src/autotrain/cli/run_gen.py b/src/autotrain/cli/run_gen.py
new file mode 100644
index 0000000000..388a4927e9
--- /dev/null
+++ b/src/autotrain/cli/run_gen.py
@@ -0,0 +1,56 @@
+from argparse import ArgumentParser
+
+from autotrain import logger
+from autotrain.cli.utils import get_field_info
+from autotrain.datagen.gen import AutoTrainGen
+from autotrain.datagen.params import AutoTrainGenParams
+
+from . import BaseAutoTrainCommand
+
+
+def run_autotrain_gen_command(args):
+    """Build the command object for ``autotrain gen`` from the parsed CLI *args*.
+
+    Registered as the ``func`` default of the ``gen`` sub-parser.
+    """
+    command = RunAutoTrainGenCommand(args)
+    return command
+
+
+class RunAutoTrainGenCommand(BaseAutoTrainCommand):
+    """CLI command behind ``autotrain gen``.
+
+    ``register_subcommand`` wires the sub-parser, ``__init__`` normalises the
+    parsed flags and ``run`` hands them to ``AutoTrainGen``.
+    """
+
+    @staticmethod
+    def register_subcommand(parser: ArgumentParser):
+        """Add the ``gen`` sub-parser with one option per ``AutoTrainGenParams`` field.
+
+        Each entry returned by ``get_field_info`` supplies the flag name, optional
+        aliases, help text and either an ``action`` or a ``type``/``choices`` pair.
+        """
+        field_specs = get_field_info(AutoTrainGenParams)
+        gen_parser = parser.add_parser("gen", description="✨ AutoTrain Gen")
+        for spec in field_specs:
+            flags = [spec["arg"]] + spec.get("alias", [])
+            # Options shared by both kinds of argument.
+            options = {
+                "dest": spec["arg"].replace("--", "").replace("-", "_"),
+                "help": spec["help"],
+                "required": spec.get("required", False),
+                "default": spec.get("default"),
+            }
+            if "action" in spec:
+                # Action-style arguments take neither ``type`` nor ``choices``.
+                options["action"] = spec.get("action")
+            else:
+                options["type"] = spec.get("type")
+                options["choices"] = spec.get("choices")
+            gen_parser.add_argument(*flags, **options)
+        gen_parser.set_defaults(func=run_autotrain_gen_command)
+
+    def __init__(self, args):
+        """Store *args* and turn unset (``None``) boolean flags into ``False``."""
+        self.args = args
+
+        for flag_name in ("push_to_hub",):
+            if getattr(self.args, flag_name) is None:
+                setattr(self.args, flag_name, False)
+
+    def run(self):
+        """Build ``AutoTrainGenParams`` from the parsed args and run the generator."""
+        logger.info("Running AutoTrain Gen 🚀")
+        gen_params = AutoTrainGenParams(**vars(self.args))
+        AutoTrainGen(gen_params).run()
diff --git a/src/autotrain/datagen/clients.py b/src/autotrain/datagen/clients.py
index 2bfd23774d..326cf9ce2a 100644
--- a/src/autotrain/datagen/clients.py
+++ b/src/autotrain/datagen/clients.py
@@ -1,29 +1,16 @@
-from huggingface_hub import InferenceClient
from dataclasses import dataclass
from typing import Optional
-"""
from huggingface_hub import InferenceClient
-
-client = InferenceClient(
- "meta-llama/Meta-Llama-3.1-8B-Instruct",
- token="hf_xxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxx",
-)
-
-for message in client.chat_completion(
- messages=[{"role": "user", "content": "What is the capital of France?"}],
- max_tokens=500,
- stream=True,
-):
- print(message.choices[0].delta.content, end="")
-"""
+from autotrain import logger
+import time
@dataclass
class Client:
name: str
model_name: Optional[str] = None
- token: Optional[str] = None
+ api_key: Optional[str] = None
def __post_init__(self):
if self.name == "huggingface":
@@ -40,17 +27,25 @@ def __repr__(self):
return f"Client: {self.name}"
def _huggingface(self):
- if self.token:
- return self.client(self.model_name, token=self.token)
+ if self.api_key:
+ return self.client(self.model_name, token=self.api_key)
return self.client(self.model_name)
- def chat_completion(self, messages, max_tokens=500, seed=42, response_format=None):
- _client = self._huggingface()
- message = _client.chat_completion(
- messages=messages,
- max_tokens=max_tokens,
- stream=False,
- seed=seed,
- response_format=response_format,
- )
- return message
+ def chat_completion(self, messages, max_tokens=500, seed=42, response_format=None, retries=3, delay=5):
+ for attempt in range(retries):
+ try:
+ _client = self._huggingface()
+ message = _client.chat_completion(
+ messages=messages,
+ max_tokens=max_tokens,
+ stream=False,
+ seed=seed,
+ response_format=response_format,
+ )
+ return message
+ except Exception as e:
+ logger.error(f"Attempt {attempt + 1} failed: {e}")
+ if attempt < retries - 1:
+ time.sleep(delay)
+ else:
+ return None
diff --git a/src/autotrain/datagen/gen.py b/src/autotrain/datagen/gen.py
new file mode 100644
index 0000000000..586f196683
--- /dev/null
+++ b/src/autotrain/datagen/gen.py
@@ -0,0 +1,21 @@
+from dataclasses import dataclass
+
+from autotrain import logger
+from autotrain.datagen.params import AutoTrainGenParams
+
+
+@dataclass
+class AutoTrainGen:
+    """Select the data generator matching ``params.task`` and run it.
+
+    Only ``text-classification`` and ``seq2seq`` are handled; any other task
+    raises ``NotImplementedError`` at construction time.
+    """
+
+    params: AutoTrainGenParams
+
+    def __post_init__(self):
+        logger.info(self.params)
+        if self.params.task not in ("text-classification", "seq2seq"):
+            raise NotImplementedError
+
+        # Imported only after a supported text task has been selected.
+        from autotrain.datagen.text import TextDataGenerator
+
+        self.gen = TextDataGenerator(self.params)
+
+    def run(self):
+        """Delegate to the selected generator's ``run``."""
+        self.gen.run()
diff --git a/src/autotrain/datagen/generator.py b/src/autotrain/datagen/generator.py
index 26515a096c..e69de29bb2 100644
--- a/src/autotrain/datagen/generator.py
+++ b/src/autotrain/datagen/generator.py
@@ -1,15 +0,0 @@
-from pydantic import BaseModel
-
-
-class BaseDataGenerator(BaseModel):
- def load_data(self):
- raise NotImplementedError
-
- def preprocess_data(self):
- raise NotImplementedError
-
- def pre_generate_data(self):
- raise NotImplementedError
-
- def generate_data(self):
- raise NotImplementedError
diff --git a/src/autotrain/datagen/params.py b/src/autotrain/datagen/params.py
new file mode 100644
index 0000000000..a43b879413
--- /dev/null
+++ b/src/autotrain/datagen/params.py
@@ -0,0 +1,23 @@
+from typing import Optional
+
+from pydantic import Field
+
+from autotrain.trainers.common import AutoTrainParams
+
+
+class AutoTrainGenParams(AutoTrainParams):
+ gen_model: str = Field("meta-llama/Meta-Llama-3.1-8B-Instruct", title="The model to be used for generation.")
+ project_name: str = Field("autotrain-datagen", title="Name of the project.")
+ prompt: str = Field(None, title="Prompt to be used for text generation.")
+ task: str = Field(None, title="Task type, e.g., text-classification, summarization.")
+ token: Optional[str] = Field(None, title="Authentication token for accessing the model.")
+ training_config: Optional[str] = Field(None, title="Path to the training configuration file.")
+ valid_size: Optional[float] = Field(0.2, title="Validation set size as a fraction of the total dataset.")
+ username: Optional[str] = Field(None, title="Username of the person running the training.")
+ push_to_hub: Optional[bool] = Field(True, title="Whether to push the model to Hugging Face Hub.")
+ backend: Optional[str] = Field("huggingface", title="Backend to be used, e.g., huggingface, local.")
+ api: Optional[str] = Field(None, title="API endpoint to be used.")
+ api_key: Optional[str] = Field(None, title="API key for authentication.")
+ min_samples: Optional[int] = Field(200, title="Minimum number of samples required for training.")
+ # text specific
+ min_words: Optional[int] = Field(25, title="Minimum number of words in the generated text.")
diff --git a/src/autotrain/datagen/text.py b/src/autotrain/datagen/text.py
new file mode 100644
index 0000000000..bda2c4e941
--- /dev/null
+++ b/src/autotrain/datagen/text.py
@@ -0,0 +1,169 @@
+import hashlib
+import random
+import time
+
+import ijson
+
+from autotrain import logger
+from autotrain.datagen import utils
+from autotrain.datagen.clients import Client
+from autotrain.datagen.params import AutoTrainGenParams
+
+
+TEXT_CLASSIFICATION_SYSTEM_PROMPT = """
+You are an AI bot that generates data for text classification tasks.
+You do not repeat the question asked by user. You do not generate code.
+Only thing you generate is text data in the specified format.
+The user provides a problem statement and you generate the data.
+For text classification task, the user provides different classes.
+If the user has not provided the classes, generate the classes as well but limit the number of classes to 10.
+"""
+
+TEXT_CLASSIFICATION_DATA_PROMPT = """
+The dataset for text classification is in JSON format.
+Each line should be a JSON object with the following keys: text and target.
+Make sure each text sample has atleast {min_words} words.
+The target must always be a string.
+Don't write what you are doing. Just generate the data.
+Each line of the output consists of a dictionary with two keys: text and target and nothing else.
+"""
+
+SEQ2SEQ_SYSTEM_PROMPT = """
+You are an AI bot that generates data for sequence-to-sequence tasks.
+You do not repeat the question asked by user. You do not generate code.
+Only thing you generate is text data in the specified format.
+The user provides a problem statement and you generate the data.
+For sequence-to-sequence task, the user provides the input and output format.
+If the user has not provided the input and output format, generate the format as well.
+"""
+
+SEQ2SEQ_DATA_PROMPT = """
+The dataset for sequence-to-sequence is in JSON format.
+Each line should be a JSON object with the following keys: text and target.
+Make sure each text sample has atleast {min_words} words.
+Both text and target sentences must always be a string.
+Don't write what you are doing. Just generate the data.
+Each line of the output consists of a dictionary with two keys: text and target and nothing else.
+"""
+
+
+class TextDataGenerator:
+ def __init__(self, params: AutoTrainGenParams):
+ self.params = params
+ if self.params.task == "text-classification":
+ self.system_prompt = TEXT_CLASSIFICATION_SYSTEM_PROMPT
+ self.data_prompt = TEXT_CLASSIFICATION_DATA_PROMPT
+ elif self.params.task == "seq2seq":
+ self.system_prompt = SEQ2SEQ_SYSTEM_PROMPT
+ self.data_prompt = SEQ2SEQ_DATA_PROMPT
+ else:
+ raise NotImplementedError
+
+ self.data_prompt = self.data_prompt.format(min_words=self.params.min_words)
+
+ def run(self):
+ ask = self.system_prompt + self.data_prompt
+ formatted_message = [{"role": "system", "content": ask}]
+ formatted_message.append({"role": "user", "content": self.params.prompt})
+ logger.info("Prompting the model. Using prompt:")
+ logger.info(formatted_message)
+
+ client = Client(self.params.backend, model_name=self.params.gen_model, api_key=self.params.api_key)
+ clean_result = []
+
+ if self.params.task in ["text-classification", "seq2seq"]:
+ response_format = {
+ "type": "json",
+ "value": {
+ "properties": {
+ "data": {
+ "type": "array",
+ "maxItems": 10,
+ "minItems": 1,
+ "items": {
+ "type": "array",
+ "properties": {
+ "text": {"type": "string"},
+ "target": {"type": "string"},
+ },
+ "required": ["text", "target"],
+ },
+ }
+ },
+ "required": ["data"],
+ },
+ }
+ else:
+ raise NotImplementedError
+
+ counter = 0
+ while True:
+ current_time = time.time()
+ random_number = random.randint(0, 1000000)
+ seed_input = f"{current_time}-{counter}-{random_number}"
+ random_seed = int(hashlib.sha256(seed_input.encode("utf-8")).hexdigest(), 16) % (10**8)
+ message = client.chat_completion(
+ messages=formatted_message,
+ max_tokens=4096,
+ seed=random_seed,
+ response_format=response_format,
+ )
+
+ if message is None:
+ logger.warning("Failed to generate data. Retrying...")
+ continue
+
+ result = message.choices[0].message.content
+
+ items = []
+ parser = ijson.parse(result)
+
+ current_item = None
+ current_key = None
+
+ try:
+ for prefix, event, value in parser:
+ if event == "start_map":
+ current_item = {}
+ elif event == "map_key":
+ current_key = value
+ elif event == "string" and current_key:
+ current_item[current_key] = value
+ elif event == "end_map" and current_item:
+ items.append(current_item)
+ current_item = None
+ except ijson.common.IncompleteJSONError:
+ # Handle incomplete JSON data
+ logger.warning("Incomplete JSON encountered. Returning parsed data.")
+
+ clean_result.append(items)
+ counter += 1
+ num_items_collected = len([item for sublist in clean_result for item in sublist])
+ logger.info(f"Collected {num_items_collected} items.")
+ if num_items_collected >= self.params.min_samples:
+ break
+
+ # flatten the list
+ clean_result = [item for sublist in clean_result for item in sublist]
+
+ valid_data = None
+ if self.params.valid_size != 0:
+ valid_size = int(self.params.valid_size * len(clean_result))
+ random.shuffle(clean_result)
+ valid_data = clean_result[:valid_size]
+ train_data = clean_result[valid_size:]
+
+ logger.info(f"Train data size: {len(train_data)}")
+ logger.info(f"Valid data size: {len(valid_data)}")
+
+ hf_dataset = utils.convert_text_dataset_to_hf(train_data, valid_data)
+ hf_dataset.save_to_disk(self.params.project_name)
+
+ if self.params.push_to_hub:
+ logger.info("Pushing the data to Hugging Face Hub.")
+ utils.push_data_to_hub(
+ dataset=hf_dataset,
+ dataset_name=self.params.project_name,
+ username=self.params.username,
+ token=self.params.token,
+ )
diff --git a/src/autotrain/datagen/text_classification.py b/src/autotrain/datagen/text_classification.py
deleted file mode 100644
index d7f16f13dc..0000000000
--- a/src/autotrain/datagen/text_classification.py
+++ /dev/null
@@ -1,125 +0,0 @@
-from autotrain.datagen.generator import BaseDataGenerator
-from autotrain.datagen.clients import Client
-import json
-from autotrain import logger
-import re
-import ijson
-import random
-
-SYSTEM_PROMPT = """
-You are an AI bot that generates data for text classification tasks.
-You do not repeat the question asked by user. You do not generate code.
-Only thing you generate is text data in the specified format.
-The user provides a problem statement and you generate the data.
-For text classification task, the user provides different classes.
-If the user has not provided the classes, generate the classes as well but limit the number of classes to 10.
-"""
-
-DATA_PROMPT = """
-The dataset for text classification is in JSON format.
-Each line should be a JSON object with the following keys: text and target.
-Make sure each text sample has atleast 25 words.
-The target must always be a string.
-Don't write what you are doing. Just generate the data.
-Each line of the output consists of a dictionary with two keys: text and target and nothing else.
-"""
-
-
-def fix_invalid_json(json_string):
- # Escape backslashes that are not already escaped
- json_string = re.sub(r'(? Dataset:
+ dataset = Dataset.from_list(train)
+ ddict = {"train": dataset}
+ if valid is not None:
+ valid_dataset = Dataset.from_list(valid)
+ ddict["validation"] = valid_dataset
+ dataset = DatasetDict(ddict)
+ return dataset
+
+
+def push_data_to_hub(dataset: Dataset, dataset_name: str, username: str, token: Optional[str] = None) -> str:
+ if username is None:
+ raise ValueError("Username is required for pushing data to Hugging Face Hub.")
+ if token is None:
+ raise ValueError("Token is required for pushing data to Hugging Face Hub.")
+ repo_id = f"{username}/{dataset_name}"
+ dataset.push_to_hub(repo_id, token=token, private=True)
+ metadata = {
+ "tags": [
+ "autotrain",
+ "gen",
+ "synthetic",
+ ]
+ }
+ metadata_update(repo_id, metadata, token=token, repo_type="dataset", overwrite=True)
+ return repo_id
From 06190a73b8cefff2551575a2950cfb4be6ff002c Mon Sep 17 00:00:00 2001
From: Abhishek Thakur
Date: Thu, 12 Sep 2024 10:11:20 +0200
Subject: [PATCH 3/7] gen
---
src/autotrain/datagen/clients.py | 3 +-
src/autotrain/datagen/params.py | 57 ++++++++++++++++++++++++++++--
src/autotrain/datagen/text.py | 11 +++---
src/autotrain/datagen/utils.py | 59 ++++++++++++++++++++++++++++----
4 files changed, 113 insertions(+), 17 deletions(-)
diff --git a/src/autotrain/datagen/clients.py b/src/autotrain/datagen/clients.py
index 326cf9ce2a..a7c2c066ff 100644
--- a/src/autotrain/datagen/clients.py
+++ b/src/autotrain/datagen/clients.py
@@ -1,9 +1,10 @@
+import time
from dataclasses import dataclass
from typing import Optional
from huggingface_hub import InferenceClient
+
from autotrain import logger
-import time
@dataclass
diff --git a/src/autotrain/datagen/params.py b/src/autotrain/datagen/params.py
index a43b879413..96a669a967 100644
--- a/src/autotrain/datagen/params.py
+++ b/src/autotrain/datagen/params.py
@@ -1,11 +1,61 @@
+import os
from typing import Optional
-from pydantic import Field
+from pydantic import BaseModel, Field
-from autotrain.trainers.common import AutoTrainParams
+from autotrain import logger
-class AutoTrainGenParams(AutoTrainParams):
+class BaseGenParams(BaseModel):
+    """
+    Base class for all AutoTrain gen parameters.
+
+    Subclasses must declare a ``project_name`` field: ``__init__`` validates it.
+    """
+
+    class Config:
+        protected_namespaces = ()
+
+    def save(self, output_dir):
+        """
+        Save parameters to ``<output_dir>/gen_params.json`` as indented JSON.
+
+        The directory is created when missing. The file holds the unmodified
+        ``model_dump_json`` output; nothing is masked here.
+        """
+        os.makedirs(output_dir, exist_ok=True)
+        path = os.path.join(output_dir, "gen_params.json")
+        # save formatted json
+        with open(path, "w", encoding="utf-8") as f:
+            f.write(self.model_dump_json(indent=4))
+
+    def __str__(self):
+        """
+        String representation of the parameters with credentials masked.
+
+        This is what ends up in the logs when the params object is logged, so
+        both ``token`` and ``api_key`` are replaced by ``*****`` when set.
+        """
+        data = self.model_dump()
+        data["token"] = "*****" if data.get("token") else None
+        # api_key is a credential as well and must not be printed in clear.
+        if "api_key" in data:
+            data["api_key"] = "*****" if data["api_key"] else None
+        return str(data)
+
+    def __init__(self, **data):
+        """
+        Initialize the parameters, validate ``project_name`` and warn about
+        parameters left at their default or supplied but unknown.
+
+        Raises:
+            ValueError: if ``project_name`` contains anything other than
+                alphanumerics and hyphens, or is longer than 50 characters.
+        """
+        super().__init__(**data)
+
+        if len(self.project_name) > 0:
+            if not self.project_name.replace("-", "").isalnum():
+                raise ValueError("project_name must be alphanumeric but can contain hyphens")
+
+            if len(self.project_name) > 50:
+                raise ValueError("project_name cannot be more than 50 characters")
+
+        # Read the declared fields from the class, once.
+        declared = set(type(self).model_fields)
+        supplied = set(data)
+        not_supplied = declared - supplied
+        if not_supplied:
+            # sorted() keeps the message stable from run to run.
+            logger.warning(
+                f"Parameters not supplied by user and set to default: {', '.join(sorted(not_supplied))}"
+            )
+        unused = supplied - declared
+        if unused:
+            logger.warning(f"Parameters supplied but not used: {', '.join(sorted(unused))}")
+
+
+class AutoTrainGenParams(BaseGenParams):
gen_model: str = Field("meta-llama/Meta-Llama-3.1-8B-Instruct", title="The model to be used for generation.")
project_name: str = Field("autotrain-datagen", title="Name of the project.")
prompt: str = Field(None, title="Prompt to be used for text generation.")
@@ -18,6 +68,7 @@ class AutoTrainGenParams(AutoTrainParams):
backend: Optional[str] = Field("huggingface", title="Backend to be used, e.g., huggingface, local.")
api: Optional[str] = Field(None, title="API endpoint to be used.")
api_key: Optional[str] = Field(None, title="API key for authentication.")
+ sample: Optional[str] = Field(None, title="Sample dataset for generation.")
min_samples: Optional[int] = Field(200, title="Minimum number of samples required for training.")
# text specific
min_words: Optional[int] = Field(25, title="Minimum number of words in the generated text.")
diff --git a/src/autotrain/datagen/text.py b/src/autotrain/datagen/text.py
index bda2c4e941..a50c75c381 100644
--- a/src/autotrain/datagen/text.py
+++ b/src/autotrain/datagen/text.py
@@ -59,6 +59,8 @@ def __init__(self, params: AutoTrainGenParams):
else:
raise NotImplementedError
+ self.params.save(output_dir=self.params.project_name)
+
self.data_prompt = self.data_prompt.format(min_words=self.params.min_words)
def run(self):
@@ -161,9 +163,6 @@ def run(self):
if self.params.push_to_hub:
logger.info("Pushing the data to Hugging Face Hub.")
- utils.push_data_to_hub(
- dataset=hf_dataset,
- dataset_name=self.params.project_name,
- username=self.params.username,
- token=self.params.token,
- )
+ utils.push_data_to_hub(params=self.params, dataset=hf_dataset)
+
+ utils.train(params=self.params)
diff --git a/src/autotrain/datagen/utils.py b/src/autotrain/datagen/utils.py
index 353999cfd7..fb5e703c3d 100644
--- a/src/autotrain/datagen/utils.py
+++ b/src/autotrain/datagen/utils.py
@@ -1,7 +1,12 @@
+import json
+import os
+import subprocess
from typing import Dict, List, Optional
from datasets import Dataset, DatasetDict
-from huggingface_hub import metadata_update
+from huggingface_hub import HfApi, metadata_update
+
+from autotrain import logger
def convert_text_dataset_to_hf(train: List[Dict[str, str]], valid: Optional[List[Dict[str, str]]] = None) -> Dataset:
@@ -14,13 +19,40 @@ def convert_text_dataset_to_hf(train: List[Dict[str, str]], valid: Optional[List
return dataset
-def push_data_to_hub(dataset: Dataset, dataset_name: str, username: str, token: Optional[str] = None) -> str:
- if username is None:
+def push_data_to_hub(params, dataset) -> str:
+ if params.username is None:
raise ValueError("Username is required for pushing data to Hugging Face Hub.")
- if token is None:
+ if params.token is None:
raise ValueError("Token is required for pushing data to Hugging Face Hub.")
- repo_id = f"{username}/{dataset_name}"
- dataset.push_to_hub(repo_id, token=token, private=True)
+
+ repo_id = f"{params.username}/{params.project_name}"
+ dataset.push_to_hub(repo_id, token=params.token, private=True)
+
+ if os.path.exists(f"{params.project_name}/gen_params.json"):
+ gen_params = json.load(open(f"{params.project_name}/gen_params.json"))
+ if "token" in gen_params:
+ gen_params.pop("token")
+
+ if "api" in gen_params:
+ gen_params.pop("api")
+
+ if "api_key" in gen_params:
+ gen_params.pop("api_key")
+
+ json.dump(
+ gen_params,
+ open(f"{params.project_name}/gen_params.json", "w"),
+ indent=4,
+ )
+
+ api = HfApi(token=params.token)
+ if os.path.exists(f"{params.project_name}/gen_params.json"):
+ api.upload_file(
+ path_or_fileobj=f"{params.project_name}/gen_params.json",
+ repo_id=f"{params.username}/{params.project_name}",
+ repo_type="dataset",
+ path_in_repo="gen_params.json",
+ )
metadata = {
"tags": [
"autotrain",
@@ -28,5 +60,18 @@ def push_data_to_hub(dataset: Dataset, dataset_name: str, username: str, token:
"synthetic",
]
}
- metadata_update(repo_id, metadata, token=token, repo_type="dataset", overwrite=True)
+ metadata_update(repo_id, metadata, token=params.token, repo_type="dataset", overwrite=True)
return repo_id
+
+
+def train(params):
+    """
+    Run ``autotrain --config <training_config>`` as a child process and wait for it.
+
+    Args:
+        params: object exposing ``training_config`` (path to the training
+            configuration, or ``None`` to skip training).
+
+    Returns:
+        The PID of the finished child process, or ``None`` when no training
+        configuration was provided.
+    """
+    if params.training_config is None:
+        logger.info("No training configuration provided. Skipping training...")
+        return None
+    # Build argv as a list. Iterating over the command *string* would split it
+    # into single characters and Popen would try to execute "a".
+    cmd = ["autotrain", "--config", str(params.training_config)]
+    logger.info(f"Running AutoTrain: {' '.join(cmd)}")
+    env = os.environ.copy()
+    process = subprocess.Popen(cmd, env=env)
+    process.wait()
+    return process.pid
From 7f553dd4139d0fa5355ed9becd3d4fdc4cffbeaf Mon Sep 17 00:00:00 2001
From: Abhishek Thakur
Date: Sat, 14 Sep 2024 08:34:04 +0200
Subject: [PATCH 4/7] update
---
src/autotrain/app/static/scripts/listeners.js | 9 +++
src/autotrain/app/templates/index.html | 27 +++++++-
src/autotrain/datagen/clients.py | 67 ++++++++++++++++++-
src/autotrain/datagen/text.py | 5 +-
src/autotrain/datagen/utils.py | 20 +++++-
.../trainers/text_classification/__main__.py | 35 ++++++++++
.../trainers/text_classification/params.py | 5 ++
7 files changed, 161 insertions(+), 7 deletions(-)
diff --git a/src/autotrain/app/static/scripts/listeners.js b/src/autotrain/app/static/scripts/listeners.js
index 9638342aea..59b4cfa0fe 100644
--- a/src/autotrain/app/static/scripts/listeners.js
+++ b/src/autotrain/app/static/scripts/listeners.js
@@ -3,6 +3,7 @@ document.addEventListener('DOMContentLoaded', function () {
const uploadDataTabContent = document.getElementById("upload-data-tab-content");
const hubDataTabContent = document.getElementById("hub-data-tab-content");
const uploadDataTabs = document.getElementById("upload-data-tabs");
+ const genDataTabContent = document.getElementById("gen-data-tab-content");
const jsonCheckbox = document.getElementById('show-json-parameters');
const jsonParametersDiv = document.getElementById('json-parameters');
@@ -54,12 +55,20 @@ document.addEventListener('DOMContentLoaded', function () {
if (dataSource.value === "hub") {
uploadDataTabContent.style.display = "none";
uploadDataTabs.style.display = "none";
+ genDataTabContent.style.display = "none";
hubDataTabContent.style.display = "block";
} else if (dataSource.value === "local") {
uploadDataTabContent.style.display = "block";
uploadDataTabs.style.display = "block";
+ genDataTabContent.style.display = "none";
hubDataTabContent.style.display = "none";
}
+ else if (dataSource.value === "gen") {
+ uploadDataTabContent.style.display = "none";
+ uploadDataTabs.style.display = "none";
+ hubDataTabContent.style.display = "none";
+ genDataTabContent.style.display = "block";
+ }
}
async function fetchParams() {
diff --git a/src/autotrain/app/templates/index.html b/src/autotrain/app/templates/index.html
index 00501202cf..76fccd6961 100644
--- a/src/autotrain/app/templates/index.html
+++ b/src/autotrain/app/templates/index.html
@@ -420,9 +420,11 @@
Source
@@ -524,7 +526,30 @@
+
+
+
diff --git a/src/autotrain/datagen/clients.py b/src/autotrain/datagen/clients.py
index a7c2c066ff..5583f46bee 100644
--- a/src/autotrain/datagen/clients.py
+++ b/src/autotrain/datagen/clients.py
@@ -5,6 +5,57 @@
from huggingface_hub import InferenceClient
from autotrain import logger
+import transformers
+import torch
+import outlines
+import json
+
+
+@dataclass
+class _TransformersClient:
+    """Local chat client that wraps a ``transformers`` text-generation pipeline."""
+
+    model_name: str
+
+    def __post_init__(self):
+        # The pipeline is built eagerly with bfloat16 weights and device_map="auto".
+        self.pipeline = transformers.pipeline(
+            "text-generation",
+            model=self.model_name,
+            model_kwargs={"torch_dtype": torch.bfloat16},
+            device_map="auto",
+        )
+
+    def chat_completion(self, messages, max_tokens, stream, seed, response_format):
+        """Run *messages* through the pipeline and return the last message's content.
+
+        NOTE(review): ``seed``, ``response_format`` and ``stream`` are forwarded
+        to the pipeline as-is; confirm the pipeline accepts these keywords.
+        """
+        call_kwargs = {
+            "max_new_tokens": max_tokens,
+            "seed": seed,
+            "response_format": response_format,
+            "stream": stream,
+        }
+        results = self.pipeline(messages, **call_kwargs)
+        last_message = results[0]["generated_text"][-1]
+        return last_message["content"]
+
+
+@dataclass
+class TransformersClient:
+    """Local chat client that generates schema-constrained JSON through ``outlines``."""
+
+    model_name: str
+
+    def __post_init__(self):
+        self.pipeline = outlines.models.transformers(
+            self.model_name,
+            # device_map="auto",
+            model_kwargs={"torch_dtype": torch.bfloat16},
+        )
+
+    def chat_completion(self, messages, max_tokens, stream, seed, response_format):
+        """Generate a completion for *messages* constrained by *response_format*.
+
+        ``stream`` is accepted for signature parity with the other clients but
+        is not used here.
+        """
+        # The schema dict is passed to outlines as a JSON string.
+        response_format = json.dumps(response_format)
+        generator = outlines.generate.json(self.pipeline, response_format)
+        outputs = generator(
+            messages,
+            max_tokens=max_tokens,
+            seed=seed,
+        )
+        # Log through the module logger instead of printing to stdout.
+        logger.debug(outputs)
+        # NOTE(review): this indexing mirrors the transformers pipeline output
+        # shape; confirm the outlines generator returns the same structure.
+        return outputs[0]["generated_text"][-1]["content"]
@dataclass
@@ -14,10 +65,14 @@ class Client:
api_key: Optional[str] = None
def __post_init__(self):
- if self.name == "huggingface":
+ if self.name == "hf-inference-api":
if self.model_name is None:
raise ValueError("Model name is required for Huggingface")
self.client = InferenceClient
+ elif self.name == "transformers":
+ if self.model_name is None:
+ raise ValueError("Model name is required for Transformers")
+ self.client = TransformersClient
else:
raise ValueError("Client not supported")
@@ -32,10 +87,18 @@ def _huggingface(self):
return self.client(self.model_name, token=self.api_key)
return self.client(self.model_name)
+ def _transformers(self):
+ return self.client(self.model_name)
+
def chat_completion(self, messages, max_tokens=500, seed=42, response_format=None, retries=3, delay=5):
+ if self.name == "hf-inference-api":
+ _client = self._huggingface()
+ elif self.name == "transformers":
+ _client = self._transformers()
+ else:
+ raise ValueError("Client not supported")
for attempt in range(retries):
try:
- _client = self._huggingface()
message = _client.chat_completion(
messages=messages,
max_tokens=max_tokens,
diff --git a/src/autotrain/datagen/text.py b/src/autotrain/datagen/text.py
index a50c75c381..77277e5502 100644
--- a/src/autotrain/datagen/text.py
+++ b/src/autotrain/datagen/text.py
@@ -9,6 +9,7 @@
from autotrain.datagen.clients import Client
from autotrain.datagen.params import AutoTrainGenParams
+import os
TEXT_CLASSIFICATION_SYSTEM_PROMPT = """
You are an AI bot that generates data for text classification tasks.
@@ -158,8 +159,8 @@ def run(self):
logger.info(f"Train data size: {len(train_data)}")
logger.info(f"Valid data size: {len(valid_data)}")
- hf_dataset = utils.convert_text_dataset_to_hf(train_data, valid_data)
- hf_dataset.save_to_disk(self.params.project_name)
+ hf_dataset = utils.convert_text_dataset_to_hf(self.params.task, train_data, valid_data)
+ hf_dataset.save_to_disk(os.path.join(self.params.project_name, "autotrain-data"))
if self.params.push_to_hub:
logger.info("Pushing the data to Hugging Face Hub.")
diff --git a/src/autotrain/datagen/utils.py b/src/autotrain/datagen/utils.py
index fb5e703c3d..3a43c0d840 100644
--- a/src/autotrain/datagen/utils.py
+++ b/src/autotrain/datagen/utils.py
@@ -3,17 +3,33 @@
import subprocess
from typing import Dict, List, Optional
-from datasets import Dataset, DatasetDict
+from datasets import Dataset, DatasetDict, ClassLabel
from huggingface_hub import HfApi, metadata_update
from autotrain import logger
-def convert_text_dataset_to_hf(train: List[Dict[str, str]], valid: Optional[List[Dict[str, str]]] = None) -> Dataset:
+def convert_text_dataset_to_hf(
+ task, train: List[Dict[str, str]], valid: Optional[List[Dict[str, str]]] = None
+) -> Dataset:
+ if task == "text-classification":
+ for item in train:
+ item["target"] = item["target"].lower().strip()
+ label_names = list(set([item["target"] for item in train]))
+ logger.info(f"Label names: {label_names}")
+
dataset = Dataset.from_list(train)
+
+ if task == "text-classification":
+ dataset = dataset.cast_column("target", ClassLabel(names=label_names))
+
ddict = {"train": dataset}
if valid is not None:
valid_dataset = Dataset.from_list(valid)
+ if task == "text-classification":
+ for item in valid:
+ item["target"] = item["target"].lower().strip()
+ valid_dataset = valid_dataset.cast_column("target", ClassLabel(names=label_names))
ddict["validation"] = valid_dataset
dataset = DatasetDict(ddict)
return dataset
diff --git a/src/autotrain/trainers/text_classification/__main__.py b/src/autotrain/trainers/text_classification/__main__.py
index 5ce3918e87..f5a5aa1637 100644
--- a/src/autotrain/trainers/text_classification/__main__.py
+++ b/src/autotrain/trainers/text_classification/__main__.py
@@ -28,6 +28,8 @@
from autotrain.trainers.text_classification import utils
from autotrain.trainers.text_classification.dataset import TextClassificationDataset
from autotrain.trainers.text_classification.params import TextClassificationParams
+from autotrain.datagen.params import AutoTrainGenParams
+from autotrain.datagen.gen import AutoTrainGen
def parse_args():
@@ -42,6 +44,39 @@ def train(config):
if isinstance(config, dict):
config = TextClassificationParams(**config)
+ if config.gen_prompt is not None:
+ if config.gen_model is None:
+ raise ValueError("Generation model must be provided when using generation prompt")
+ gen_params = config.gen_params.split(",") if config.gen_params is not None else []
+ gen_params = {p.split("=")[0]: p.split("=")[1] for p in gen_params}
+ gen_config = AutoTrainGenParams(
+ gen_model=config.gen_model,
+ project_name=config.project_name,
+ prompt=config.gen_prompt,
+ task="text-classification",
+ token=config.token,
+ username=config.username,
+ push_to_hub=config.push_to_hub,
+ backend=gen_params.get("backend", "transformers"),
+ api=gen_params.get("api", None),
+ api_key=gen_params.get("api_key", None),
+ min_samples=config.gen_samples,
+ )
+ try:
+ AutoTrainGen(gen_config).run()
+ except Exception as e:
+ logger.error(f"Error generating data: {e}")
+ return
+ if config.push_to_hub:
+ config.data_path = f"{config.username}/{config.project_name}"
+ else:
+ config.data_path = f"{config.project_name}/autotrain-data"
+
+ config.train_split = "train"
+ config.valid_split = "validation"
+ config.target_column = "target"
+ config.text_column = "text"
+
train_data = None
valid_data = None
# check if config.train_split.csv exists in config.data_path
diff --git a/src/autotrain/trainers/text_classification/params.py b/src/autotrain/trainers/text_classification/params.py
index c38f1d6c72..79fd4b8e94 100644
--- a/src/autotrain/trainers/text_classification/params.py
+++ b/src/autotrain/trainers/text_classification/params.py
@@ -35,3 +35,8 @@ class TextClassificationParams(AutoTrainParams):
log: str = Field("none", title="Logging using experiment tracking")
early_stopping_patience: int = Field(5, title="Early stopping patience")
early_stopping_threshold: float = Field(0.01, title="Early stopping threshold")
+ # generation parameters
+ gen_prompt: Optional[str] = Field(None, title="Generation prompt")
+ gen_model: Optional[str] = Field(None, title="Generation model")
+ gen_samples: Optional[int] = Field(100, title="Generation samples")
+ gen_params: Optional[str] = Field(None, title="Generation parameters, format: key1=value1,key2=value2")
From 6a0cedc8d302aea5e81724773b0f1559ea669238 Mon Sep 17 00:00:00 2001
From: abhishekkrthakur
Date: Wed, 25 Sep 2024 11:54:13 +0200
Subject: [PATCH 5/7] gen
---
src/autotrain/datagen/clients.py | 8 ++++----
src/autotrain/datagen/generator.py | 0
src/autotrain/datagen/text.py | 2 +-
src/autotrain/datagen/utils.py | 2 +-
src/autotrain/trainers/text_classification/__main__.py | 4 ++--
5 files changed, 8 insertions(+), 8 deletions(-)
delete mode 100644 src/autotrain/datagen/generator.py
diff --git a/src/autotrain/datagen/clients.py b/src/autotrain/datagen/clients.py
index 5583f46bee..a9d86792e7 100644
--- a/src/autotrain/datagen/clients.py
+++ b/src/autotrain/datagen/clients.py
@@ -1,14 +1,14 @@
+import json
import time
from dataclasses import dataclass
from typing import Optional
+import outlines
+import torch
+import transformers
from huggingface_hub import InferenceClient
from autotrain import logger
-import transformers
-import torch
-import outlines
-import json
@dataclass
diff --git a/src/autotrain/datagen/generator.py b/src/autotrain/datagen/generator.py
deleted file mode 100644
index e69de29bb2..0000000000
diff --git a/src/autotrain/datagen/text.py b/src/autotrain/datagen/text.py
index 77277e5502..bf2a0404e9 100644
--- a/src/autotrain/datagen/text.py
+++ b/src/autotrain/datagen/text.py
@@ -1,4 +1,5 @@
import hashlib
+import os
import random
import time
@@ -9,7 +10,6 @@
from autotrain.datagen.clients import Client
from autotrain.datagen.params import AutoTrainGenParams
-import os
TEXT_CLASSIFICATION_SYSTEM_PROMPT = """
You are an AI bot that generates data for text classification tasks.
diff --git a/src/autotrain/datagen/utils.py b/src/autotrain/datagen/utils.py
index 3a43c0d840..34499aedd2 100644
--- a/src/autotrain/datagen/utils.py
+++ b/src/autotrain/datagen/utils.py
@@ -3,7 +3,7 @@
import subprocess
from typing import Dict, List, Optional
-from datasets import Dataset, DatasetDict, ClassLabel
+from datasets import ClassLabel, Dataset, DatasetDict
from huggingface_hub import HfApi, metadata_update
from autotrain import logger
diff --git a/src/autotrain/trainers/text_classification/__main__.py b/src/autotrain/trainers/text_classification/__main__.py
index f5a5aa1637..bfc51b78ab 100644
--- a/src/autotrain/trainers/text_classification/__main__.py
+++ b/src/autotrain/trainers/text_classification/__main__.py
@@ -15,6 +15,8 @@
from transformers.trainer_callback import PrinterCallback
from autotrain import logger
+from autotrain.datagen.gen import AutoTrainGen
+from autotrain.datagen.params import AutoTrainGenParams
from autotrain.trainers.common import (
ALLOW_REMOTE_CODE,
LossLoggingCallback,
@@ -28,8 +30,6 @@
from autotrain.trainers.text_classification import utils
from autotrain.trainers.text_classification.dataset import TextClassificationDataset
from autotrain.trainers.text_classification.params import TextClassificationParams
-from autotrain.datagen.params import AutoTrainGenParams
-from autotrain.datagen.gen import AutoTrainGen
def parse_args():
From a7dade5109fc9f966e9edd5b074b226afee9cd3b Mon Sep 17 00:00:00 2001
From: Sara Han <127759186+sdiazlor@users.noreply.github.com>
Date: Mon, 9 Dec 2024 14:23:09 +0100
Subject: [PATCH 6/7] Add distilabel gen (#819)
---
requirements.txt | 3 +
src/autotrain/datagen/text.py | 246 +++++++++++++++++++++++-----------
2 files changed, 173 insertions(+), 76 deletions(-)
diff --git a/requirements.txt b/requirements.txt
index 915c384ad7..0fcf0c98c6 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -24,12 +24,15 @@ invisible-watermark==0.2.0
packaging==24.0
cryptography==42.0.5
nvitop==1.3.2
+llama-cpp-python==0.2.90
+outlines==0.0.46
# latest versions
tensorboard==2.16.2
peft==0.12.0
trl==0.10.1
tiktoken==0.6.0
transformers==4.44.2
+distilabel==1.4.1
accelerate==0.34.1
diff --git a/src/autotrain/datagen/text.py b/src/autotrain/datagen/text.py
index bf2a0404e9..45285f1025 100644
--- a/src/autotrain/datagen/text.py
+++ b/src/autotrain/datagen/text.py
@@ -4,6 +4,9 @@
import time
import ijson
+from pydantic import BaseModel
+from distilabel.llms import LlamaCppLLM
+from distilabel.steps.tasks import TextGeneration
from autotrain import logger
from autotrain.datagen import utils
@@ -23,7 +26,7 @@
TEXT_CLASSIFICATION_DATA_PROMPT = """
The dataset for text classification is in JSON format.
Each line should be a JSON object with the following keys: text and target.
-Make sure each text sample has atleast {min_words} words.
+Make sure each text sample has at least {min_words} words.
The target must always be a string.
Don't write what you are doing. Just generate the data.
Each line of the output consists of a dictionary with two keys: text and target and nothing else.
@@ -41,7 +44,7 @@
SEQ2SEQ_DATA_PROMPT = """
The dataset for sequence-to-sequence is in JSON format.
Each line should be a JSON object with the following keys: text and target.
-Make sure each text sample has atleast {min_words} words.
+Make sure each text sample has at least {min_words} words.
Both text and target sentences must always be a string.
Don't write what you are doing. Just generate the data.
Each line of the output consists of a dictionary with two keys: text and target and nothing else.
@@ -71,83 +74,170 @@ def run(self):
logger.info("Prompting the model. Using prompt:")
logger.info(formatted_message)
- client = Client(self.params.backend, model_name=self.params.gen_model, api_key=self.params.api_key)
- clean_result = []
-
- if self.params.task in ["text-classification", "seq2seq"]:
- response_format = {
- "type": "json",
- "value": {
- "properties": {
- "data": {
- "type": "array",
- "maxItems": 10,
- "minItems": 1,
- "items": {
+ if self.params.backend == "local":
+
+ if self.params.task in ["text-classification", "seq2seq"]:
+
+ class Data(BaseModel):
+ text: str
+ target: str
+
+ else:
+ raise NotImplementedError
+
+ def get_generator(seed: int):
+ generator = TextGeneration(
+ llm=LlamaCppLLM(
+ model_path=self.params.gen_model,
+ n_gpu_layers=-1,
+ n_ctx=1024,
+ seed=seed,
+ structured_output={"format": "json", "schema": Data},
+ ),
+ num_generations=10,
+ system_prompt=ask,
+ )
+ generator.load()
+ return generator
+
+ counter = 0
+ clean_result = []
+ while True:
+ current_time = time.time()
+ random_number = random.randint(0, 1000000)
+ seed_input = f"{current_time}-{counter}-{random_number}"
+ random_seed = int(
+ hashlib.sha256(seed_input.encode("utf-8")).hexdigest(), 16
+ ) % (10**8)
+ generator = get_generator(seed=random_seed)
+ results = list(
+ generator.process(
+ [
+ {
+ "instruction": self.params.prompt,
+ }
+ ]
+ )
+ )
+
+ for result in results[0]:
+ items = []
+ parser = ijson.parse(result.get("generation", ""))
+
+ current_item = None
+ current_key = None
+
+ try:
+ for prefix, event, value in parser:
+ if event == "start_map":
+ current_item = {}
+ elif event == "map_key":
+ current_key = value
+ elif event == "string" and current_key:
+ current_item[current_key] = value
+ elif event == "end_map" and current_item:
+ items.append(current_item)
+ current_item = None
+ except ijson.common.IncompleteJSONError:
+ logger.warning(
+ "Incomplete JSON encountered. Returning parsed data."
+ )
+
+ clean_result.append(items[0])
+ counter += 1
+ num_items_collected = len(clean_result)
+ logger.info(f"Collected {num_items_collected} items.")
+ if num_items_collected >= 10:
+ break
+
+ else:
+ client = Client(
+ self.params.backend,
+ model_name=self.params.gen_model,
+ api_key=self.params.api_key,
+ )
+ clean_result = []
+
+ if self.params.task in ["text-classification", "seq2seq"]:
+ response_format = {
+ "type": "json",
+ "value": {
+ "properties": {
+ "data": {
"type": "array",
- "properties": {
- "text": {"type": "string"},
- "target": {"type": "string"},
+ "maxItems": 10,
+ "minItems": 1,
+ "items": {
+ "type": "array",
+ "properties": {
+ "text": {"type": "string"},
+ "target": {"type": "string"},
+ },
+ "required": ["text", "target"],
},
- "required": ["text", "target"],
- },
- }
+ }
+ },
+ "required": ["data"],
},
- "required": ["data"],
- },
- }
- else:
- raise NotImplementedError
+ }
+ else:
+ raise NotImplementedError
- counter = 0
- while True:
- current_time = time.time()
- random_number = random.randint(0, 1000000)
- seed_input = f"{current_time}-{counter}-{random_number}"
- random_seed = int(hashlib.sha256(seed_input.encode("utf-8")).hexdigest(), 16) % (10**8)
- message = client.chat_completion(
- messages=formatted_message,
- max_tokens=4096,
- seed=random_seed,
- response_format=response_format,
- )
+ counter = 0
+ while True:
+ current_time = time.time()
+ random_number = random.randint(0, 1000000)
+ seed_input = f"{current_time}-{counter}-{random_number}"
+ random_seed = int(
+ hashlib.sha256(seed_input.encode("utf-8")).hexdigest(), 16
+ ) % (10**8)
+ message = client.chat_completion(
+ messages=formatted_message,
+ max_tokens=4096,
+ seed=random_seed,
+ response_format=response_format,
+ )
+
+ if message is None:
+ logger.warning("Failed to generate data. Retrying...")
+ continue
+
+ result = message.choices[0].message.content
+
+ items = []
+ parser = ijson.parse(result)
+
+ current_item = None
+ current_key = None
+
+ try:
+ for prefix, event, value in parser:
+ if event == "start_map":
+ current_item = {}
+ elif event == "map_key":
+ current_key = value
+ elif event == "string" and current_key:
+ current_item[current_key] = value
+ elif event == "end_map" and current_item:
+ items.append(current_item)
+ current_item = None
+ except ijson.common.IncompleteJSONError:
+ # Handle incomplete JSON data
+ logger.warning(
+ "Incomplete JSON encountered. Returning parsed data."
+ )
+
+ clean_result.append(items)
+ counter += 1
+ num_items_collected = len(
+ [item for sublist in clean_result for item in sublist]
+ )
+ logger.info(f"Collected {num_items_collected} items.")
+ if num_items_collected >= self.params.min_samples:
+ break
- if message is None:
- logger.warning("Failed to generate data. Retrying...")
- continue
-
- result = message.choices[0].message.content
-
- items = []
- parser = ijson.parse(result)
-
- current_item = None
- current_key = None
-
- try:
- for prefix, event, value in parser:
- if event == "start_map":
- current_item = {}
- elif event == "map_key":
- current_key = value
- elif event == "string" and current_key:
- current_item[current_key] = value
- elif event == "end_map" and current_item:
- items.append(current_item)
- current_item = None
- except ijson.common.IncompleteJSONError:
- # Handle incomplete JSON data
- logger.warning("Incomplete JSON encountered. Returning parsed data.")
-
- clean_result.append(items)
- counter += 1
- num_items_collected = len([item for sublist in clean_result for item in sublist])
- logger.info(f"Collected {num_items_collected} items.")
- if num_items_collected >= self.params.min_samples:
- break
-
- # flatten the list
- clean_result = [item for sublist in clean_result for item in sublist]
+ # flatten the list
+ clean_result = [item for sublist in clean_result for item in sublist]
valid_data = None
if self.params.valid_size != 0:
@@ -159,8 +249,12 @@ def run(self):
logger.info(f"Train data size: {len(train_data)}")
logger.info(f"Valid data size: {len(valid_data)}")
- hf_dataset = utils.convert_text_dataset_to_hf(self.params.task, train_data, valid_data)
- hf_dataset.save_to_disk(os.path.join(self.params.project_name, "autotrain-data"))
+ hf_dataset = utils.convert_text_dataset_to_hf(
+ self.params.task, train_data, valid_data
+ )
+ hf_dataset.save_to_disk(
+ os.path.join(self.params.project_name, "autotrain-data")
+ )
if self.params.push_to_hub:
logger.info("Pushing the data to Hugging Face Hub.")
From 515e4e52b81a2cf33529bee8bb4e418548593a8e Mon Sep 17 00:00:00 2001
From: abhishekkrthakur
Date: Mon, 9 Dec 2024 14:27:16 +0100
Subject: [PATCH 7/7] fix style
---
src/autotrain/datagen/text.py | 30 ++++++++----------------------
1 file changed, 8 insertions(+), 22 deletions(-)
diff --git a/src/autotrain/datagen/text.py b/src/autotrain/datagen/text.py
index 45285f1025..f94d0f5d34 100644
--- a/src/autotrain/datagen/text.py
+++ b/src/autotrain/datagen/text.py
@@ -4,9 +4,9 @@
import time
import ijson
-from pydantic import BaseModel
from distilabel.llms import LlamaCppLLM
from distilabel.steps.tasks import TextGeneration
+from pydantic import BaseModel
from autotrain import logger
from autotrain.datagen import utils
@@ -106,9 +106,7 @@ def get_generator(seed: int):
current_time = time.time()
random_number = random.randint(0, 1000000)
seed_input = f"{current_time}-{counter}-{random_number}"
- random_seed = int(
- hashlib.sha256(seed_input.encode("utf-8")).hexdigest(), 16
- ) % (10**8)
+ random_seed = int(hashlib.sha256(seed_input.encode("utf-8")).hexdigest(), 16) % (10**8)
generator = get_generator(seed=random_seed)
results = list(
generator.process(
@@ -139,9 +137,7 @@ def get_generator(seed: int):
items.append(current_item)
current_item = None
except ijson.common.IncompleteJSONError:
- logger.warning(
- "Incomplete JSON encountered. Returning parsed data."
- )
+ logger.warning("Incomplete JSON encountered. Returning parsed data.")
clean_result.append(items[0])
counter += 1
@@ -188,9 +184,7 @@ def get_generator(seed: int):
current_time = time.time()
random_number = random.randint(0, 1000000)
seed_input = f"{current_time}-{counter}-{random_number}"
- random_seed = int(
- hashlib.sha256(seed_input.encode("utf-8")).hexdigest(), 16
- ) % (10**8)
+ random_seed = int(hashlib.sha256(seed_input.encode("utf-8")).hexdigest(), 16) % (10**8)
message = client.chat_completion(
messages=formatted_message,
max_tokens=4096,
@@ -223,15 +217,11 @@ def get_generator(seed: int):
current_item = None
except ijson.common.IncompleteJSONError:
# Handle incomplete JSON data
- logger.warning(
- "Incomplete JSON encountered. Returning parsed data."
- )
+ logger.warning("Incomplete JSON encountered. Returning parsed data.")
clean_result.append(items)
counter += 1
- num_items_collected = len(
- [item for sublist in clean_result for item in sublist]
- )
+ num_items_collected = len([item for sublist in clean_result for item in sublist])
logger.info(f"Collected {num_items_collected} items.")
if num_items_collected >= self.params.min_samples:
break
@@ -249,12 +239,8 @@ def get_generator(seed: int):
logger.info(f"Train data size: {len(train_data)}")
logger.info(f"Valid data size: {len(valid_data)}")
- hf_dataset = utils.convert_text_dataset_to_hf(
- self.params.task, train_data, valid_data
- )
- hf_dataset.save_to_disk(
- os.path.join(self.params.project_name, "autotrain-data")
- )
+ hf_dataset = utils.convert_text_dataset_to_hf(self.params.task, train_data, valid_data)
+ hf_dataset.save_to_disk(os.path.join(self.params.project_name, "autotrain-data"))
if self.params.push_to_hub:
logger.info("Pushing the data to Hugging Face Hub.")