From c65bcf7fca6d4b82f64244f6cfc499e31b226fb3 Mon Sep 17 00:00:00 2001 From: Abhishek Thakur Date: Tue, 10 Sep 2024 15:20:50 +0200 Subject: [PATCH 1/7] gen1 --- src/autotrain/datagen/__init__.py | 0 src/autotrain/datagen/clients.py | 56 +++++++++ src/autotrain/datagen/generator.py | 15 +++ src/autotrain/datagen/text_classification.py | 125 +++++++++++++++++++ src/autotrain/datagen/utils.py | 0 5 files changed, 196 insertions(+) create mode 100644 src/autotrain/datagen/__init__.py create mode 100644 src/autotrain/datagen/clients.py create mode 100644 src/autotrain/datagen/generator.py create mode 100644 src/autotrain/datagen/text_classification.py create mode 100644 src/autotrain/datagen/utils.py diff --git a/src/autotrain/datagen/__init__.py b/src/autotrain/datagen/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/src/autotrain/datagen/clients.py b/src/autotrain/datagen/clients.py new file mode 100644 index 0000000000..2bfd23774d --- /dev/null +++ b/src/autotrain/datagen/clients.py @@ -0,0 +1,56 @@ +from huggingface_hub import InferenceClient +from dataclasses import dataclass +from typing import Optional + +""" +from huggingface_hub import InferenceClient + +client = InferenceClient( + "meta-llama/Meta-Llama-3.1-8B-Instruct", + token="hf_xxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxx", +) + +for message in client.chat_completion( + messages=[{"role": "user", "content": "What is the capital of France?"}], + max_tokens=500, + stream=True, +): + print(message.choices[0].delta.content, end="") +""" + + +@dataclass +class Client: + name: str + model_name: Optional[str] = None + token: Optional[str] = None + + def __post_init__(self): + if self.name == "huggingface": + if self.model_name is None: + raise ValueError("Model name is required for Huggingface") + self.client = InferenceClient + else: + raise ValueError("Client not supported") + + def __str__(self): + return f"Client: {self.name}" + + def __repr__(self): + return f"Client: {self.name}" + + def _huggingface(self): + if self.token: + return self.client(self.model_name, token=self.token) + return self.client(self.model_name) + + def chat_completion(self, messages, max_tokens=500, seed=42, response_format=None): + _client = self._huggingface() + message = _client.chat_completion( + messages=messages, + max_tokens=max_tokens, + stream=False, + seed=seed, + response_format=response_format, + ) + return message diff --git a/src/autotrain/datagen/generator.py b/src/autotrain/datagen/generator.py new file mode 100644 index 0000000000..26515a096c --- /dev/null +++ b/src/autotrain/datagen/generator.py @@ -0,0 +1,15 @@ +from pydantic import BaseModel + + +class BaseDataGenerator(BaseModel): + def load_data(self): + raise NotImplementedError + + def preprocess_data(self): + raise NotImplementedError + + def pre_generate_data(self): + raise NotImplementedError + + def generate_data(self): + raise NotImplementedError diff --git a/src/autotrain/datagen/text_classification.py b/src/autotrain/datagen/text_classification.py new file mode 100644 index 0000000000..d7f16f13dc --- /dev/null +++ b/src/autotrain/datagen/text_classification.py @@ -0,0 +1,125 @@ +from autotrain.datagen.generator import BaseDataGenerator +from autotrain.datagen.clients import Client +import json +from autotrain import logger +import re +import ijson +import random + +SYSTEM_PROMPT = """ +You are an AI bot that generates data for text classification tasks. +You do not repeat the question asked by user. You do not generate code. +Only thing you generate is text data in the specified format. +The user provides a problem statement and you generate the data. +For text classification task, the user provides different classes. +If the user has not provided the classes, generate the classes as well but limit the number of classes to 10. +""" + +DATA_PROMPT = """ +The dataset for text classification is in JSON format. +Each line should be a JSON object with the following keys: text and target. +Make sure each text sample has atleast 25 words. +The target must always be a string. +Don't write what you are doing. Just generate the data. +Each line of the output consists of a dictionary with two keys: text and target and nothing else. +""" + + +def fix_invalid_json(json_string): + # Escape backslashes that are not already escaped + json_string = re.sub(r'(? Date: Wed, 11 Sep 2024 16:06:01 +0200 Subject: [PATCH 2/7] gen --- src/autotrain/cli/autotrain.py | 2 + src/autotrain/cli/run_gen.py | 56 ++++++ src/autotrain/datagen/clients.py | 51 +++--- src/autotrain/datagen/gen.py | 21 +++ src/autotrain/datagen/generator.py | 15 -- src/autotrain/datagen/params.py | 23 +++ src/autotrain/datagen/text.py | 169 +++++++++++++++++++ src/autotrain/datagen/text_classification.py | 125 -------------- src/autotrain/datagen/utils.py | 32 ++++ 9 files changed, 326 insertions(+), 168 deletions(-) create mode 100644 src/autotrain/cli/run_gen.py create mode 100644 src/autotrain/datagen/gen.py create mode 100644 src/autotrain/datagen/params.py create mode 100644 src/autotrain/datagen/text.py delete mode 100644 src/autotrain/datagen/text_classification.py diff --git a/src/autotrain/cli/autotrain.py b/src/autotrain/cli/autotrain.py index ce8c869e46..0ed6d76093 100644 --- a/src/autotrain/cli/autotrain.py +++ b/src/autotrain/cli/autotrain.py @@ -5,6 +5,7 @@ from autotrain.cli.run_app import RunAutoTrainAppCommand from autotrain.cli.run_dreambooth import RunAutoTrainDreamboothCommand from autotrain.cli.run_extractive_qa import RunAutoTrainExtractiveQACommand +from autotrain.cli.run_gen import RunAutoTrainGenCommand from autotrain.cli.run_image_classification import RunAutoTrainImageClassificationCommand from autotrain.cli.run_image_regression import RunAutoTrainImageRegressionCommand from autotrain.cli.run_llm import RunAutoTrainLLMCommand @@ -49,6 +50,7 @@ def main(): RunAutoTrainSentenceTransformersCommand.register_subcommand(commands_parser) RunAutoTrainImageRegressionCommand.register_subcommand(commands_parser) RunAutoTrainExtractiveQACommand.register_subcommand(commands_parser) + RunAutoTrainGenCommand.register_subcommand(commands_parser) args = parser.parse_args() diff --git a/src/autotrain/cli/run_gen.py b/src/autotrain/cli/run_gen.py new file mode 100644 index 0000000000..388a4927e9 --- /dev/null +++ b/src/autotrain/cli/run_gen.py @@ -0,0 +1,56 @@ +from argparse import ArgumentParser + +from autotrain import logger +from autotrain.cli.utils import get_field_info +from autotrain.datagen.gen import AutoTrainGen +from autotrain.datagen.params import AutoTrainGenParams + +from . import BaseAutoTrainCommand + + +def run_autotrain_gen_command(args): + return RunAutoTrainGenCommand(args) + + +class RunAutoTrainGenCommand(BaseAutoTrainCommand): + @staticmethod + def register_subcommand(parser: ArgumentParser): + arg_list = get_field_info(AutoTrainGenParams) + run_autotrain_gen_parser = parser.add_parser("gen", description="✨ AutoTrain Gen") + for arg in arg_list: + names = [arg["arg"]] + arg.get("alias", []) + if "action" in arg: + run_autotrain_gen_parser.add_argument( + *names, + dest=arg["arg"].replace("--", "").replace("-", "_"), + help=arg["help"], + required=arg.get("required", False), + action=arg.get("action"), + default=arg.get("default"), + ) + else: + run_autotrain_gen_parser.add_argument( + *names, + dest=arg["arg"].replace("--", "").replace("-", "_"), + help=arg["help"], + required=arg.get("required", False), + type=arg.get("type"), + default=arg.get("default"), + choices=arg.get("choices"), + ) + run_autotrain_gen_parser.set_defaults(func=run_autotrain_gen_command) + + def __init__(self, args): + self.args = args + + store_true_arg_names = [ + "push_to_hub", + ] + for arg_name in store_true_arg_names: + if getattr(self.args, arg_name) is None: + setattr(self.args, arg_name, False) + + def run(self): + logger.info("Running AutoTrain Gen 🚀") + params = AutoTrainGenParams(**vars(self.args)) + AutoTrainGen(params).run() diff --git a/src/autotrain/datagen/clients.py b/src/autotrain/datagen/clients.py index 2bfd23774d..326cf9ce2a 100644 --- a/src/autotrain/datagen/clients.py +++ b/src/autotrain/datagen/clients.py @@ -1,29 +1,16 @@ -from huggingface_hub import InferenceClient from dataclasses import dataclass from typing import Optional -""" from huggingface_hub import InferenceClient - -client = InferenceClient( - "meta-llama/Meta-Llama-3.1-8B-Instruct", - token="hf_xxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxx", -) - -for message in client.chat_completion( - messages=[{"role": "user", "content": "What is the capital of France?"}], - max_tokens=500, - stream=True, -): - print(message.choices[0].delta.content, end="") -""" +from autotrain import logger +import time @dataclass class Client: name: str model_name: Optional[str] = None - token: Optional[str] = None + api_key: Optional[str] = None def __post_init__(self): if self.name == "huggingface": @@ -40,17 +27,25 @@ def __repr__(self): return f"Client: {self.name}" def _huggingface(self): - if self.token: - return self.client(self.model_name, token=self.token) + if self.api_key: + return self.client(self.model_name, token=self.api_key) return self.client(self.model_name) - def chat_completion(self, messages, max_tokens=500, seed=42, response_format=None): - _client = self._huggingface() - message = _client.chat_completion( - messages=messages, - max_tokens=max_tokens, - stream=False, - seed=seed, - response_format=response_format, - ) - return message + def chat_completion(self, messages, max_tokens=500, seed=42, response_format=None, retries=3, delay=5): + for attempt in range(retries): + try: + _client = self._huggingface() + message = _client.chat_completion( + messages=messages, + max_tokens=max_tokens, + stream=False, + seed=seed, + response_format=response_format, + ) + return message + except Exception as e: + logger.error(f"Attempt {attempt + 1} failed: {e}") + if attempt < retries - 1: + time.sleep(delay) + else: + return None diff --git a/src/autotrain/datagen/gen.py b/src/autotrain/datagen/gen.py new file mode 100644 index 0000000000..586f196683 --- /dev/null +++ b/src/autotrain/datagen/gen.py @@ -0,0 +1,21 @@ +from dataclasses import dataclass + +from autotrain import logger +from autotrain.datagen.params import AutoTrainGenParams + + +@dataclass +class AutoTrainGen: + params: AutoTrainGenParams + + def __post_init__(self): + logger.info(self.params) + if self.params.task in ("text-classification", "seq2seq"): + from autotrain.datagen.text import TextDataGenerator + + self.gen = TextDataGenerator(self.params) + else: + raise NotImplementedError + + def run(self): + self.gen.run() diff --git a/src/autotrain/datagen/generator.py b/src/autotrain/datagen/generator.py index 26515a096c..e69de29bb2 100644 --- a/src/autotrain/datagen/generator.py +++ b/src/autotrain/datagen/generator.py @@ -1,15 +0,0 @@ -from pydantic import BaseModel - - -class BaseDataGenerator(BaseModel): - def load_data(self): - raise NotImplementedError - - def preprocess_data(self): - raise NotImplementedError - - def pre_generate_data(self): - raise NotImplementedError - - def generate_data(self): - raise NotImplementedError diff --git a/src/autotrain/datagen/params.py b/src/autotrain/datagen/params.py new file mode 100644 index 0000000000..a43b879413 --- /dev/null +++ b/src/autotrain/datagen/params.py @@ -0,0 +1,23 @@ +from typing import Optional + +from pydantic import Field + +from autotrain.trainers.common import AutoTrainParams + + +class AutoTrainGenParams(AutoTrainParams): + gen_model: str = Field("meta-llama/Meta-Llama-3.1-8B-Instruct", title="The model to be used for generation.") + project_name: str = Field("autotrain-datagen", title="Name of the project.") + prompt: str = Field(None, title="Prompt to be used for text generation.") + task: str = Field(None, title="Task type, e.g., text-classification, summarization.") + token: Optional[str] = Field(None, title="Authentication token for accessing the model.") + training_config: Optional[str] = Field(None, title="Path to the training configuration file.") + valid_size: Optional[float] = Field(0.2, title="Validation set size as a fraction of the total dataset.") + username: Optional[str] = Field(None, title="Username of the person running the training.") + push_to_hub: Optional[bool] = Field(True, title="Whether to push the model to Hugging Face Hub.") + backend: Optional[str] = Field("huggingface", title="Backend to be used, e.g., huggingface, local.") + api: Optional[str] = Field(None, title="API endpoint to be used.") + api_key: Optional[str] = Field(None, title="API key for authentication.") + min_samples: Optional[int] = Field(200, title="Minimum number of samples required for training.") + # text specific + min_words: Optional[int] = Field(25, title="Minimum number of words in the generated text.") diff --git a/src/autotrain/datagen/text.py b/src/autotrain/datagen/text.py new file mode 100644 index 0000000000..bda2c4e941 --- /dev/null +++ b/src/autotrain/datagen/text.py @@ -0,0 +1,169 @@ +import hashlib +import random +import time + +import ijson + +from autotrain import logger +from autotrain.datagen import utils +from autotrain.datagen.clients import Client +from autotrain.datagen.params import AutoTrainGenParams + + +TEXT_CLASSIFICATION_SYSTEM_PROMPT = """ +You are an AI bot that generates data for text classification tasks. +You do not repeat the question asked by user. You do not generate code. +Only thing you generate is text data in the specified format. +The user provides a problem statement and you generate the data. +For text classification task, the user provides different classes. +If the user has not provided the classes, generate the classes as well but limit the number of classes to 10. +""" + +TEXT_CLASSIFICATION_DATA_PROMPT = """ +The dataset for text classification is in JSON format. +Each line should be a JSON object with the following keys: text and target. +Make sure each text sample has atleast {min_words} words. +The target must always be a string. +Don't write what you are doing. Just generate the data. +Each line of the output consists of a dictionary with two keys: text and target and nothing else. +""" + +SEQ2SEQ_SYSTEM_PROMPT = """ +You are an AI bot that generates data for sequence-to-sequence tasks. +You do not repeat the question asked by user. You do not generate code. +Only thing you generate is text data in the specified format. +The user provides a problem statement and you generate the data. +For sequence-to-sequence task, the user provides the input and output format. +If the user has not provided the input and output format, generate the format as well. +""" + +SEQ2SEQ_DATA_PROMPT = """ +The dataset for sequence-to-sequence is in JSON format. +Each line should be a JSON object with the following keys: text and target. +Make sure each text sample has atleast {min_words} words. +Both text and target sentences must always be a string. +Don't write what you are doing. Just generate the data. +Each line of the output consists of a dictionary with two keys: text and target and nothing else. +""" + + +class TextDataGenerator: + def __init__(self, params: AutoTrainGenParams): + self.params = params + if self.params.task == "text-classification": + self.system_prompt = TEXT_CLASSIFICATION_SYSTEM_PROMPT + self.data_prompt = TEXT_CLASSIFICATION_DATA_PROMPT + elif self.params.task == "seq2seq": + self.system_prompt = SEQ2SEQ_SYSTEM_PROMPT + self.data_prompt = SEQ2SEQ_DATA_PROMPT + else: + raise NotImplementedError + + self.data_prompt = self.data_prompt.format(min_words=self.params.min_words) + + def run(self): + ask = self.system_prompt + self.data_prompt + formatted_message = [{"role": "system", "content": ask}] + formatted_message.append({"role": "user", "content": self.params.prompt}) + logger.info("Prompting the model. Using prompt:") + logger.info(formatted_message) + + client = Client(self.params.backend, model_name=self.params.gen_model, api_key=self.params.api_key) + clean_result = [] + + if self.params.task in ["text-classification", "seq2seq"]: + response_format = { + "type": "json", + "value": { + "properties": { + "data": { + "type": "array", + "maxItems": 10, + "minItems": 1, + "items": { + "type": "array", + "properties": { + "text": {"type": "string"}, + "target": {"type": "string"}, + }, + "required": ["text", "target"], + }, + } + }, + "required": ["data"], + }, + } + else: + raise NotImplementedError + + counter = 0 + while True: + current_time = time.time() + random_number = random.randint(0, 1000000) + seed_input = f"{current_time}-{counter}-{random_number}" + random_seed = int(hashlib.sha256(seed_input.encode("utf-8")).hexdigest(), 16) % (10**8) + message = client.chat_completion( + messages=formatted_message, + max_tokens=4096, + seed=random_seed, + response_format=response_format, + ) + + if message is None: + logger.warning("Failed to generate data. Retrying...") + continue + + result = message.choices[0].message.content + + items = [] + parser = ijson.parse(result) + + current_item = None + current_key = None + + try: + for prefix, event, value in parser: + if event == "start_map": + current_item = {} + elif event == "map_key": + current_key = value + elif event == "string" and current_key: + current_item[current_key] = value + elif event == "end_map" and current_item: + items.append(current_item) + current_item = None + except ijson.common.IncompleteJSONError: + # Handle incomplete JSON data + logger.warning("Incomplete JSON encountered. Returning parsed data.") + + clean_result.append(items) + counter += 1 + num_items_collected = len([item for sublist in clean_result for item in sublist]) + logger.info(f"Collected {num_items_collected} items.") + if num_items_collected >= self.params.min_samples: + break + + # flatten the list + clean_result = [item for sublist in clean_result for item in sublist] + + valid_data = None + if self.params.valid_size != 0: + valid_size = int(self.params.valid_size * len(clean_result)) + random.shuffle(clean_result) + valid_data = clean_result[:valid_size] + train_data = clean_result[valid_size:] + + logger.info(f"Train data size: {len(train_data)}") + logger.info(f"Valid data size: {len(valid_data)}") + + hf_dataset = utils.convert_text_dataset_to_hf(train_data, valid_data) + hf_dataset.save_to_disk(self.params.project_name) + + if self.params.push_to_hub: + logger.info("Pushing the data to Hugging Face Hub.") + utils.push_data_to_hub( + dataset=hf_dataset, + dataset_name=self.params.project_name, + username=self.params.username, + token=self.params.token, + ) diff --git a/src/autotrain/datagen/text_classification.py b/src/autotrain/datagen/text_classification.py deleted file mode 100644 index d7f16f13dc..0000000000 --- a/src/autotrain/datagen/text_classification.py +++ /dev/null @@ -1,125 +0,0 @@ -from autotrain.datagen.generator import BaseDataGenerator -from autotrain.datagen.clients import Client -import json -from autotrain import logger -import re -import ijson -import random - -SYSTEM_PROMPT = """ -You are an AI bot that generates data for text classification tasks. -You do not repeat the question asked by user. You do not generate code. -Only thing you generate is text data in the specified format. -The user provides a problem statement and you generate the data. -For text classification task, the user provides different classes. -If the user has not provided the classes, generate the classes as well but limit the number of classes to 10. -""" - -DATA_PROMPT = """ -The dataset for text classification is in JSON format. -Each line should be a JSON object with the following keys: text and target. -Make sure each text sample has atleast 25 words. -The target must always be a string. -Don't write what you are doing. Just generate the data. -Each line of the output consists of a dictionary with two keys: text and target and nothing else. -""" - - -def fix_invalid_json(json_string): - # Escape backslashes that are not already escaped - json_string = re.sub(r'(? Dataset: + dataset = Dataset.from_list(train) + ddict = {"train": dataset} + if valid is not None: + valid_dataset = Dataset.from_list(valid) + ddict["validation"] = valid_dataset + dataset = DatasetDict(ddict) + return dataset + + +def push_data_to_hub(dataset: Dataset, dataset_name: str, username: str, token: Optional[str] = None) -> str: + if username is None: + raise ValueError("Username is required for pushing data to Hugging Face Hub.") + if token is None: + raise ValueError("Token is required for pushing data to Hugging Face Hub.") + repo_id = f"{username}/{dataset_name}" + dataset.push_to_hub(repo_id, token=token, private=True) + metadata = { + "tags": [ + "autotrain", + "gen", + "synthetic", + ] + } + metadata_update(repo_id, metadata, token=token, repo_type="dataset", overwrite=True) + return repo_id From 06190a73b8cefff2551575a2950cfb4be6ff002c Mon Sep 17 00:00:00 2001 From: Abhishek Thakur Date: Thu, 12 Sep 2024 10:11:20 +0200 Subject: [PATCH 3/7] gen --- src/autotrain/datagen/clients.py | 3 +- src/autotrain/datagen/params.py | 57 ++++++++++++++++++++++++++++-- src/autotrain/datagen/text.py | 11 +++--- src/autotrain/datagen/utils.py | 59 ++++++++++++++++++++++++++++---- 4 files changed, 113 insertions(+), 17 deletions(-) diff --git a/src/autotrain/datagen/clients.py b/src/autotrain/datagen/clients.py index 326cf9ce2a..a7c2c066ff 100644 --- a/src/autotrain/datagen/clients.py +++ b/src/autotrain/datagen/clients.py @@ -1,9 +1,10 @@ +import time from dataclasses import dataclass from typing import Optional from huggingface_hub import InferenceClient + from autotrain import logger -import time @dataclass diff --git a/src/autotrain/datagen/params.py b/src/autotrain/datagen/params.py index a43b879413..96a669a967 100644 --- a/src/autotrain/datagen/params.py +++ b/src/autotrain/datagen/params.py @@ -1,11 +1,61 @@ +import os from typing import Optional -from pydantic import Field +from pydantic import BaseModel, Field -from autotrain.trainers.common import AutoTrainParams +from autotrain import logger -class AutoTrainGenParams(AutoTrainParams): +class BaseGenParams(BaseModel): + """ + Base class for all AutoTrain gen parameters. + """ + + class Config: + protected_namespaces = () + + def save(self, output_dir): + """ + Save parameters to a json file. + """ + os.makedirs(output_dir, exist_ok=True) + path = os.path.join(output_dir, "gen_params.json") + # save formatted json + with open(path, "w", encoding="utf-8") as f: + f.write(self.model_dump_json(indent=4)) + + def __str__(self): + """ + String representation of the parameters. + """ + data = self.model_dump() + data["token"] = "*****" if data.get("token") else None + return str(data) + + def __init__(self, **data): + """ + Initialize the parameters, check for unused/extra parameters and warn the user. + """ + super().__init__(**data) + + if len(self.project_name) > 0: + if not self.project_name.replace("-", "").isalnum(): + raise ValueError("project_name must be alphanumeric but can contain hyphens") + + if len(self.project_name) > 50: + raise ValueError("project_name cannot be more than 50 characters") + + defaults = set(self.model_fields.keys()) + supplied = set(data.keys()) + not_supplied = defaults - supplied + if not_supplied: + logger.warning(f"Parameters not supplied by user and set to default: {', '.join(not_supplied)}") + unused = supplied - set(self.model_fields) + if unused: + logger.warning(f"Parameters supplied but not used: {', '.join(unused)}") + + +class AutoTrainGenParams(BaseGenParams): gen_model: str = Field("meta-llama/Meta-Llama-3.1-8B-Instruct", title="The model to be used for generation.") project_name: str = Field("autotrain-datagen", title="Name of the project.") prompt: str = Field(None, title="Prompt to be used for text generation.") @@ -18,6 +68,7 @@ class AutoTrainGenParams(AutoTrainParams): backend: Optional[str] = Field("huggingface", title="Backend to be used, e.g., huggingface, local.") api: Optional[str] = Field(None, title="API endpoint to be used.") api_key: Optional[str] = Field(None, title="API key for authentication.") + sample: Optional[str] = Field(None, title="Sample dataset for generation.") min_samples: Optional[int] = Field(200, title="Minimum number of samples required for training.") # text specific min_words: Optional[int] = Field(25, title="Minimum number of words in the generated text.") diff --git a/src/autotrain/datagen/text.py b/src/autotrain/datagen/text.py index bda2c4e941..a50c75c381 100644 --- a/src/autotrain/datagen/text.py +++ b/src/autotrain/datagen/text.py @@ -59,6 +59,8 @@ def __init__(self, params: AutoTrainGenParams): else: raise NotImplementedError + self.params.save(output_dir=self.params.project_name) + self.data_prompt = self.data_prompt.format(min_words=self.params.min_words) def run(self): @@ -161,9 +163,6 @@ def run(self): if self.params.push_to_hub: logger.info("Pushing the data to Hugging Face Hub.") - utils.push_data_to_hub( - dataset=hf_dataset, - dataset_name=self.params.project_name, - username=self.params.username, - token=self.params.token, - ) + utils.push_data_to_hub(params=self.params, dataset=hf_dataset) + + utils.train(params=self.params) diff --git a/src/autotrain/datagen/utils.py b/src/autotrain/datagen/utils.py index 353999cfd7..fb5e703c3d 100644 --- a/src/autotrain/datagen/utils.py +++ b/src/autotrain/datagen/utils.py @@ -1,7 +1,12 @@ +import json +import os +import subprocess from typing import Dict, List, Optional from datasets import Dataset, DatasetDict -from huggingface_hub import metadata_update +from huggingface_hub import HfApi, metadata_update + +from autotrain import logger def convert_text_dataset_to_hf(train: List[Dict[str, str]], valid: Optional[List[Dict[str, str]]] = None) -> Dataset: @@ -14,13 +19,40 @@ def convert_text_dataset_to_hf(train: List[Dict[str, str]], valid: Optional[List return dataset -def push_data_to_hub(dataset: Dataset, dataset_name: str, username: str, token: Optional[str] = None) -> str: - if username is None: +def push_data_to_hub(params, dataset) -> str: + if params.username is None: raise ValueError("Username is required for pushing data to Hugging Face Hub.") - if token is None: + if params.token is None: raise ValueError("Token is required for pushing data to Hugging Face Hub.") - repo_id = f"{username}/{dataset_name}" - dataset.push_to_hub(repo_id, token=token, private=True) + + repo_id = f"{params.username}/{params.project_name}" + dataset.push_to_hub(repo_id, token=params.token, private=True) + + if os.path.exists(f"{params.project_name}/gen_params.json"): + gen_params = json.load(open(f"{params.project_name}/gen_params.json")) + if "token" in gen_params: + gen_params.pop("token") + + if "api" in gen_params: + gen_params.pop("api") + + if "api_key" in gen_params: + gen_params.pop("api_key") + + json.dump( + gen_params, + open(f"{params.project_name}/gen_params.json", "w"), + indent=4, + ) + + api = HfApi(token=params.token) + if os.path.exists(f"{params.project_name}/gen_params.json"): + api.upload_file( + path_or_fileobj=f"{params.project_name}/gen_params.json", + repo_id=f"{params.username}/{params.project_name}", + repo_type="dataset", + path_in_repo="gen_params.json", + ) metadata = { "tags": [ "autotrain", @@ -28,5 +60,18 @@ def push_data_to_hub(dataset: Dataset, dataset_name: str, username: str, token: "synthetic", ] } - metadata_update(repo_id, metadata, token=token, repo_type="dataset", overwrite=True) + metadata_update(repo_id, metadata, token=params.token, repo_type="dataset", overwrite=True) return repo_id + + +def train(params): + if params.training_config is None: + logger.info("No training configuration provided. Skipping training...") + return + cmd = f"autotrain --config {params.training_config}" + logger.info(f"Running AutoTrain: {cmd}") + cmd = [str(c) for c in cmd] + env = os.environ.copy() + process = subprocess.Popen(cmd, env=env) + process.wait() + return process.pid From 7f553dd4139d0fa5355ed9becd3d4fdc4cffbeaf Mon Sep 17 00:00:00 2001 From: Abhishek Thakur Date: Sat, 14 Sep 2024 08:34:04 +0200 Subject: [PATCH 4/7] update --- src/autotrain/app/static/scripts/listeners.js | 9 +++ src/autotrain/app/templates/index.html | 27 +++++++- src/autotrain/datagen/clients.py | 67 ++++++++++++++++++- src/autotrain/datagen/text.py | 5 +- src/autotrain/datagen/utils.py | 20 +++++- .../trainers/text_classification/__main__.py | 35 ++++++++++ .../trainers/text_classification/params.py | 5 ++ 7 files changed, 161 insertions(+), 7 deletions(-) diff --git a/src/autotrain/app/static/scripts/listeners.js b/src/autotrain/app/static/scripts/listeners.js index 9638342aea..59b4cfa0fe 100644 --- a/src/autotrain/app/static/scripts/listeners.js +++ b/src/autotrain/app/static/scripts/listeners.js @@ -3,6 +3,7 @@ document.addEventListener('DOMContentLoaded', function () { const uploadDataTabContent = document.getElementById("upload-data-tab-content"); const hubDataTabContent = document.getElementById("hub-data-tab-content"); const uploadDataTabs = document.getElementById("upload-data-tabs"); + const genDataTabContent = document.getElementById("gen-data-tab-content"); const jsonCheckbox = document.getElementById('show-json-parameters'); const jsonParametersDiv = document.getElementById('json-parameters'); @@ -54,12 +55,20 @@ document.addEventListener('DOMContentLoaded', function () { if (dataSource.value === "hub") { uploadDataTabContent.style.display = "none"; uploadDataTabs.style.display = "none"; + genDataTabContent.style.display = "none"; hubDataTabContent.style.display = "block"; } else if (dataSource.value === "local") { uploadDataTabContent.style.display = "block"; uploadDataTabs.style.display = "block"; + genDataTabContent.style.display = "none"; hubDataTabContent.style.display = "none"; } + else if (dataSource.value === "gen") { + uploadDataTabContent.style.display = "none"; + uploadDataTabs.style.display = "none"; + hubDataTabContent.style.display = "none"; + genDataTabContent.style.display = "block"; + } } async function fetchParams() { diff --git a/src/autotrain/app/templates/index.html b/src/autotrain/app/templates/index.html index 00501202cf..76fccd6961 100644 --- a/src/autotrain/app/templates/index.html +++ b/src/autotrain/app/templates/index.html @@ -420,9 +420,11 @@ Source

@@ -524,7 +526,30 @@ +
+
+ + +
+
+ + + + +
+
+ +

diff --git a/src/autotrain/datagen/clients.py b/src/autotrain/datagen/clients.py index a7c2c066ff..5583f46bee 100644 --- a/src/autotrain/datagen/clients.py +++ b/src/autotrain/datagen/clients.py @@ -5,6 +5,57 @@ from huggingface_hub import InferenceClient from autotrain import logger +import transformers +import torch +import outlines +import json + + +@dataclass +class _TransformersClient: + model_name: str + + def __post_init__(self): + self.pipeline = transformers.pipeline( + "text-generation", + model=self.model_name, + model_kwargs={"torch_dtype": torch.bfloat16}, + device_map="auto", + ) + + def chat_completion(self, messages, max_tokens, stream, seed, response_format): + outputs = self.pipeline( + messages, + max_new_tokens=max_tokens, + seed=seed, + response_format=response_format, + stream=stream, + ) + return outputs[0]["generated_text"][-1]["content"] + + +@dataclass +class TransformersClient: + model_name: str + + def __post_init__(self): + self.pipeline = outlines.models.transformers( + self.model_name, + # device_map="auto", + model_kwargs={"torch_dtype": torch.bfloat16}, + ) + + def chat_completion(self, messages, max_tokens, stream, seed, response_format): + # dump response_format dict to json + response_format = json.dumps(response_format) + generator = outlines.generate.json(self.pipeline, response_format) + outputs = generator( + messages, + max_tokens=max_tokens, + seed=seed, + ) + print(outputs) + return outputs[0]["generated_text"][-1]["content"] @dataclass @@ -14,10 +65,14 @@ class Client: api_key: Optional[str] = None def __post_init__(self): - if self.name == "huggingface": + if self.name == "hf-inference-api": if self.model_name is None: raise ValueError("Model name is required for Huggingface") self.client = InferenceClient + elif self.name == "transformers": + if self.model_name is None: + raise ValueError("Model name is required for Transformers") + self.client = TransformersClient else: raise ValueError("Client not supported") @@ -32,10 +87,18 @@ def _huggingface(self): return self.client(self.model_name, token=self.api_key) return self.client(self.model_name) + def _transformers(self): + return self.client(self.model_name) + def chat_completion(self, messages, max_tokens=500, seed=42, response_format=None, retries=3, delay=5): + if self.name == "hf-inference-api": + _client = self._huggingface() + elif self.name == "transformers": + _client = self._transformers() + else: + raise ValueError("Client not supported") for attempt in range(retries): try: - _client = self._huggingface() message = _client.chat_completion( messages=messages, max_tokens=max_tokens, diff --git a/src/autotrain/datagen/text.py b/src/autotrain/datagen/text.py index a50c75c381..77277e5502 100644 --- a/src/autotrain/datagen/text.py +++ b/src/autotrain/datagen/text.py @@ -9,6 +9,7 @@ from autotrain.datagen.clients import Client from autotrain.datagen.params import AutoTrainGenParams +import os TEXT_CLASSIFICATION_SYSTEM_PROMPT = """ You are an AI bot that generates data for text classification tasks. @@ -158,8 +159,8 @@ def run(self): logger.info(f"Train data size: {len(train_data)}") logger.info(f"Valid data size: {len(valid_data)}") - hf_dataset = utils.convert_text_dataset_to_hf(train_data, valid_data) - hf_dataset.save_to_disk(self.params.project_name) + hf_dataset = utils.convert_text_dataset_to_hf(self.params.task, train_data, valid_data) + hf_dataset.save_to_disk(os.path.join(self.params.project_name, "autotrain-data")) if self.params.push_to_hub: logger.info("Pushing the data to Hugging Face Hub.") diff --git a/src/autotrain/datagen/utils.py b/src/autotrain/datagen/utils.py index fb5e703c3d..3a43c0d840 100644 --- a/src/autotrain/datagen/utils.py +++ b/src/autotrain/datagen/utils.py @@ -3,17 +3,33 @@ import subprocess from typing import Dict, List, Optional -from datasets import Dataset, DatasetDict +from datasets import Dataset, DatasetDict, ClassLabel from huggingface_hub import HfApi, metadata_update from autotrain import logger -def convert_text_dataset_to_hf(train: List[Dict[str, str]], valid: Optional[List[Dict[str, str]]] = None) -> Dataset: +def convert_text_dataset_to_hf( + task, train: List[Dict[str, str]], valid: Optional[List[Dict[str, str]]] = None +) -> Dataset: + if task == "text-classification": + for item in train: + item["target"] = item["target"].lower().strip() + label_names = list(set([item["target"] for item in train])) + logger.info(f"Label names: {label_names}") + dataset = Dataset.from_list(train) + + if task == "text-classification": + dataset = dataset.cast_column("target", ClassLabel(names=label_names)) + ddict = {"train": dataset} if valid is not None: valid_dataset = Dataset.from_list(valid) + if task == "text-classification": + for item in valid: + item["target"] = item["target"].lower().strip() + valid_dataset = valid_dataset.cast_column("target", ClassLabel(names=label_names)) ddict["validation"] = valid_dataset dataset = DatasetDict(ddict) return dataset diff --git a/src/autotrain/trainers/text_classification/__main__.py b/src/autotrain/trainers/text_classification/__main__.py index 5ce3918e87..f5a5aa1637 100644 --- a/src/autotrain/trainers/text_classification/__main__.py +++ b/src/autotrain/trainers/text_classification/__main__.py @@ -28,6 +28,8 @@ from autotrain.trainers.text_classification import utils from autotrain.trainers.text_classification.dataset import TextClassificationDataset from autotrain.trainers.text_classification.params import TextClassificationParams +from autotrain.datagen.params import AutoTrainGenParams +from autotrain.datagen.gen import AutoTrainGen def parse_args(): @@ -42,6 +44,39 @@ def train(config): if isinstance(config, dict): config = TextClassificationParams(**config) + if config.gen_prompt is not None: + if config.gen_model is None: + raise ValueError("Generation model must be provided when using generation prompt") + gen_params = config.gen_params.split(",") if config.gen_params is not None else [] + gen_params = {p.split("=")[0]: p.split("=")[1] for p in gen_params} + gen_config = AutoTrainGenParams( + gen_model=config.gen_model, + project_name=config.project_name, + prompt=config.gen_prompt, + task="text-classification", + token=config.token, + username=config.username, + push_to_hub=config.push_to_hub, + backend=gen_params.get("backend", "transformers"), + api=gen_params.get("api", None), + api_key=gen_params.get("api_key", None), + min_samples=config.gen_samples, + ) + try: + AutoTrainGen(gen_config).run() + except Exception as e: + logger.error(f"Error generating data: {e}") + return + if config.push_to_hub: + config.data_path = f"{config.username}/{config.project_name}" + else: + config.data_path = f"{config.project_name}/autotrain-data" + + config.train_split = "train" + config.valid_split = "validation" + config.target_column = "target" + config.text_column = "text" + train_data = None valid_data = None # check if config.train_split.csv exists in config.data_path diff --git a/src/autotrain/trainers/text_classification/params.py b/src/autotrain/trainers/text_classification/params.py index c38f1d6c72..79fd4b8e94 100644 --- a/src/autotrain/trainers/text_classification/params.py +++ b/src/autotrain/trainers/text_classification/params.py @@ -35,3 +35,8 @@ class TextClassificationParams(AutoTrainParams): log: str = Field("none", title="Logging using experiment tracking") early_stopping_patience: int = Field(5, title="Early stopping patience") early_stopping_threshold: float = Field(0.01, title="Early stopping threshold") + # generation parameters + gen_prompt: Optional[str] = Field(None, title="Generation prompt") + gen_model: Optional[str] = Field(None, title="Generation model") + gen_samples: Optional[int] = Field(100, title="Generation samples") + gen_params: Optional[str] = Field(None, title="Generation parameters, format: key1=value1,key2=value2") From 6a0cedc8d302aea5e81724773b0f1559ea669238 Mon Sep 17 00:00:00 2001 From: abhishekkrthakur Date: Wed, 25 Sep 2024 11:54:13 +0200 Subject: [PATCH 5/7] gen --- src/autotrain/datagen/clients.py | 8 ++++---- src/autotrain/datagen/generator.py | 0 src/autotrain/datagen/text.py | 2 +- src/autotrain/datagen/utils.py | 2 +- src/autotrain/trainers/text_classification/__main__.py | 4 ++-- 5 files changed, 8 insertions(+), 8 deletions(-) delete mode 100644 src/autotrain/datagen/generator.py diff --git a/src/autotrain/datagen/clients.py b/src/autotrain/datagen/clients.py index 5583f46bee..a9d86792e7 100644 --- a/src/autotrain/datagen/clients.py +++ b/src/autotrain/datagen/clients.py @@ -1,14 +1,14 @@ +import json import time from dataclasses import dataclass from typing import Optional +import outlines +import torch +import transformers from huggingface_hub import InferenceClient from autotrain import logger -import transformers -import torch -import outlines -import json @dataclass diff --git a/src/autotrain/datagen/generator.py b/src/autotrain/datagen/generator.py deleted file mode 100644 index e69de29bb2..0000000000 diff --git a/src/autotrain/datagen/text.py b/src/autotrain/datagen/text.py index 77277e5502..bf2a0404e9 100644 --- a/src/autotrain/datagen/text.py +++ b/src/autotrain/datagen/text.py @@ -1,4 +1,5 @@ import hashlib +import os import random import time @@ -9,7 +10,6 @@ from autotrain.datagen.clients import Client from autotrain.datagen.params import AutoTrainGenParams -import os TEXT_CLASSIFICATION_SYSTEM_PROMPT = """ You are an AI bot that generates data for text classification tasks. diff --git a/src/autotrain/datagen/utils.py b/src/autotrain/datagen/utils.py index 3a43c0d840..34499aedd2 100644 --- a/src/autotrain/datagen/utils.py +++ b/src/autotrain/datagen/utils.py @@ -3,7 +3,7 @@ import subprocess from typing import Dict, List, Optional -from datasets import Dataset, DatasetDict, ClassLabel +from datasets import ClassLabel, Dataset, DatasetDict from huggingface_hub import HfApi, metadata_update from autotrain import logger diff --git a/src/autotrain/trainers/text_classification/__main__.py b/src/autotrain/trainers/text_classification/__main__.py index f5a5aa1637..bfc51b78ab 100644 --- a/src/autotrain/trainers/text_classification/__main__.py +++ b/src/autotrain/trainers/text_classification/__main__.py @@ -15,6 +15,8 @@ from transformers.trainer_callback import PrinterCallback from autotrain import logger +from autotrain.datagen.gen import AutoTrainGen +from autotrain.datagen.params import AutoTrainGenParams from autotrain.trainers.common import ( ALLOW_REMOTE_CODE, LossLoggingCallback, @@ -28,8 +30,6 @@ from autotrain.trainers.text_classification import utils from autotrain.trainers.text_classification.dataset import TextClassificationDataset from autotrain.trainers.text_classification.params import TextClassificationParams -from autotrain.datagen.params import AutoTrainGenParams -from autotrain.datagen.gen import AutoTrainGen def parse_args(): From a7dade5109fc9f966e9edd5b074b226afee9cd3b Mon Sep 17 00:00:00 2001 From: Sara Han <127759186+sdiazlor@users.noreply.github.com> Date: Mon, 9 Dec 2024 14:23:09 +0100 Subject: [PATCH 6/7] Add distilabel gen (#819) --- requirements.txt | 3 + src/autotrain/datagen/text.py | 246 +++++++++++++++++++++++----------- 2 files changed, 173 insertions(+), 76 deletions(-) diff --git a/requirements.txt b/requirements.txt index 915c384ad7..0fcf0c98c6 100644 --- a/requirements.txt +++ b/requirements.txt @@ -24,12 +24,15 @@ invisible-watermark==0.2.0 packaging==24.0 cryptography==42.0.5 nvitop==1.3.2 +llama-cpp-python==0.2.90 +outlines==0.0.46 # latest versions tensorboard==2.16.2 peft==0.12.0 trl==0.10.1 tiktoken==0.6.0 transformers==4.44.2 +distilabel==1.4.1 accelerate==0.34.1 diff --git a/src/autotrain/datagen/text.py b/src/autotrain/datagen/text.py index bf2a0404e9..45285f1025 100644 --- a/src/autotrain/datagen/text.py +++ b/src/autotrain/datagen/text.py @@ -4,6 +4,9 @@ import time import ijson +from pydantic import BaseModel +from distilabel.llms import LlamaCppLLM +from distilabel.steps.tasks import TextGeneration from autotrain import logger from autotrain.datagen import utils @@ -23,7 +26,7 @@ TEXT_CLASSIFICATION_DATA_PROMPT = """ The dataset for text classification is in JSON format. Each line should be a JSON object with the following keys: text and target. -Make sure each text sample has atleast {min_words} words. +Make sure each text sample has at least {min_words} words. The target must always be a string. Don't write what you are doing. Just generate the data. Each line of the output consists of a dictionary with two keys: text and target and nothing else. @@ -41,7 +44,7 @@ SEQ2SEQ_DATA_PROMPT = """ The dataset for sequence-to-sequence is in JSON format. Each line should be a JSON object with the following keys: text and target. -Make sure each text sample has atleast {min_words} words. +Make sure each text sample has at least {min_words} words. Both text and target sentences must always be a string. Don't write what you are doing. Just generate the data. Each line of the output consists of a dictionary with two keys: text and target and nothing else. @@ -71,83 +74,170 @@ def run(self): logger.info("Prompting the model. Using prompt:") logger.info(formatted_message) - client = Client(self.params.backend, model_name=self.params.gen_model, api_key=self.params.api_key) - clean_result = [] - - if self.params.task in ["text-classification", "seq2seq"]: - response_format = { - "type": "json", - "value": { - "properties": { - "data": { - "type": "array", - "maxItems": 10, - "minItems": 1, - "items": { + if self.params.backend == "local": + + if self.params.task in ["text-classification", "seq2seq"]: + + class Data(BaseModel): + text: str + target: str + + else: + raise NotImplementedError + + def get_generator(seed: int): + generator = TextGeneration( + llm=LlamaCppLLM( + model_path=self.params.gen_model, + n_gpu_layers=-1, + n_ctx=1024, + seed=seed, + structured_output={"format": "json", "schema": Data}, + ), + num_generations=10, + system_prompt=ask, + ) + generator.load() + return generator + + counter = 0 + clean_result = [] + while True: + current_time = time.time() + random_number = random.randint(0, 1000000) + seed_input = f"{current_time}-{counter}-{random_number}" + random_seed = int( + hashlib.sha256(seed_input.encode("utf-8")).hexdigest(), 16 + ) % (10**8) + generator = get_generator(seed=random_seed) + results = list( + generator.process( + [ + { + "instruction": self.params.prompt, + } + ] + ) + ) + + for result in results[0]: + items = [] + parser = ijson.parse(result.get("generation", "")) + + current_item = None + current_key = None + + try: + for prefix, event, value in parser: + if event == "start_map": + current_item = {} + elif event == "map_key": + current_key = value + elif event == "string" and current_key: + current_item[current_key] = value + elif event == "end_map" and current_item: + items.append(current_item) + current_item = None + except ijson.common.IncompleteJSONError: + logger.warning( + "Incomplete JSON encountered. Returning parsed data." + ) + + clean_result.append(items[0]) + counter += 1 + num_items_collected = len(clean_result) + logger.info(f"Collected {num_items_collected} items.") + if num_items_collected >= 10: + break + + else: + client = Client( + self.params.backend, + model_name=self.params.gen_model, + api_key=self.params.api_key, + ) + clean_result = [] + + if self.params.task in ["text-classification", "seq2seq"]: + response_format = { + "type": "json", + "value": { + "properties": { + "data": { "type": "array", - "properties": { - "text": {"type": "string"}, - "target": {"type": "string"}, + "maxItems": 10, + "minItems": 1, + "items": { + "type": "array", + "properties": { + "text": {"type": "string"}, + "target": {"type": "string"}, + }, + "required": ["text", "target"], }, - "required": ["text", "target"], - }, - } + } + }, + "required": ["data"], }, - "required": ["data"], - }, - } - else: - raise NotImplementedError + } + else: + raise NotImplementedError - counter = 0 - while True: - current_time = time.time() - random_number = random.randint(0, 1000000) - seed_input = f"{current_time}-{counter}-{random_number}" - random_seed = int(hashlib.sha256(seed_input.encode("utf-8")).hexdigest(), 16) % (10**8) - message = client.chat_completion( - messages=formatted_message, - max_tokens=4096, - seed=random_seed, - response_format=response_format, - ) + counter = 0 + while True: + current_time = time.time() + random_number = random.randint(0, 1000000) + seed_input = f"{current_time}-{counter}-{random_number}" + random_seed = int( + hashlib.sha256(seed_input.encode("utf-8")).hexdigest(), 16 + ) % (10**8) + message = client.chat_completion( + messages=formatted_message, + max_tokens=4096, + seed=random_seed, + response_format=response_format, + ) + + if message is None: + logger.warning("Failed to generate data. Retrying...") + continue + + result = message.choices[0].message.content + + items = [] + parser = ijson.parse(result) + + current_item = None + current_key = None + + try: + for prefix, event, value in parser: + if event == "start_map": + current_item = {} + elif event == "map_key": + current_key = value + elif event == "string" and current_key: + current_item[current_key] = value + elif event == "end_map" and current_item: + items.append(current_item) + current_item = None + except ijson.common.IncompleteJSONError: + # Handle incomplete JSON data + logger.warning( + "Incomplete JSON encountered. Returning parsed data." + ) + + clean_result.append(items) + counter += 1 + num_items_collected = len( + [item for sublist in clean_result for item in sublist] + ) + logger.info(f"Collected {num_items_collected} items.") + if num_items_collected >= self.params.min_samples: + break - if message is None: - logger.warning("Failed to generate data. Retrying...") - continue - - result = message.choices[0].message.content - - items = [] - parser = ijson.parse(result) - - current_item = None - current_key = None - - try: - for prefix, event, value in parser: - if event == "start_map": - current_item = {} - elif event == "map_key": - current_key = value - elif event == "string" and current_key: - current_item[current_key] = value - elif event == "end_map" and current_item: - items.append(current_item) - current_item = None - except ijson.common.IncompleteJSONError: - # Handle incomplete JSON data - logger.warning("Incomplete JSON encountered. Returning parsed data.") - - clean_result.append(items) - counter += 1 - num_items_collected = len([item for sublist in clean_result for item in sublist]) - logger.info(f"Collected {num_items_collected} items.") - if num_items_collected >= self.params.min_samples: - break - - # flatten the list - clean_result = [item for sublist in clean_result for item in sublist] + # flatten the list + clean_result = [item for sublist in clean_result for item in sublist] valid_data = None if self.params.valid_size != 0: @@ -159,8 +249,12 @@ def run(self): logger.info(f"Train data size: {len(train_data)}") logger.info(f"Valid data size: {len(valid_data)}") - hf_dataset = utils.convert_text_dataset_to_hf(self.params.task, train_data, valid_data) - hf_dataset.save_to_disk(os.path.join(self.params.project_name, "autotrain-data")) + hf_dataset = utils.convert_text_dataset_to_hf( + self.params.task, train_data, valid_data + ) + hf_dataset.save_to_disk( + os.path.join(self.params.project_name, "autotrain-data") + ) if self.params.push_to_hub: logger.info("Pushing the data to Hugging Face Hub.") From 515e4e52b81a2cf33529bee8bb4e418548593a8e Mon Sep 17 00:00:00 2001 From: abhishekkrthakur Date: Mon, 9 Dec 2024 14:27:16 +0100 Subject: [PATCH 7/7] fix style --- src/autotrain/datagen/text.py | 30 ++++++++---------------------- 1 file changed, 8 insertions(+), 22 deletions(-) diff --git a/src/autotrain/datagen/text.py b/src/autotrain/datagen/text.py index 45285f1025..f94d0f5d34 100644 --- a/src/autotrain/datagen/text.py +++ b/src/autotrain/datagen/text.py @@ -4,9 +4,9 @@ import time import ijson -from pydantic import BaseModel from distilabel.llms import LlamaCppLLM from distilabel.steps.tasks import TextGeneration +from pydantic import BaseModel from autotrain import logger from autotrain.datagen import utils @@ -106,9 +106,7 @@ def get_generator(seed: int): current_time = time.time() random_number = random.randint(0, 1000000) seed_input = f"{current_time}-{counter}-{random_number}" - random_seed = int( - hashlib.sha256(seed_input.encode("utf-8")).hexdigest(), 16 - ) % (10**8) + random_seed = int(hashlib.sha256(seed_input.encode("utf-8")).hexdigest(), 16) % (10**8) generator = get_generator(seed=random_seed) results = list( generator.process( @@ -139,9 +137,7 @@ def get_generator(seed: int): items.append(current_item) current_item = None except ijson.common.IncompleteJSONError: - logger.warning( - "Incomplete JSON encountered. Returning parsed data." - ) + logger.warning("Incomplete JSON encountered. Returning parsed data.") clean_result.append(items[0]) counter += 1 @@ -188,9 +184,7 @@ def get_generator(seed: int): current_time = time.time() random_number = random.randint(0, 1000000) seed_input = f"{current_time}-{counter}-{random_number}" - random_seed = int( - hashlib.sha256(seed_input.encode("utf-8")).hexdigest(), 16 - ) % (10**8) + random_seed = int(hashlib.sha256(seed_input.encode("utf-8")).hexdigest(), 16) % (10**8) message = client.chat_completion( messages=formatted_message, max_tokens=4096, @@ -223,15 +217,11 @@ def get_generator(seed: int): current_item = None except ijson.common.IncompleteJSONError: # Handle incomplete JSON data - logger.warning( - "Incomplete JSON encountered. Returning parsed data." - ) + logger.warning("Incomplete JSON encountered. Returning parsed data.") clean_result.append(items) counter += 1 - num_items_collected = len( - [item for sublist in clean_result for item in sublist] - ) + num_items_collected = len([item for sublist in clean_result for item in sublist]) logger.info(f"Collected {num_items_collected} items.") if num_items_collected >= self.params.min_samples: break @@ -249,12 +239,8 @@ def get_generator(seed: int): logger.info(f"Train data size: {len(train_data)}") logger.info(f"Valid data size: {len(valid_data)}") - hf_dataset = utils.convert_text_dataset_to_hf( - self.params.task, train_data, valid_data - ) - hf_dataset.save_to_disk( - os.path.join(self.params.project_name, "autotrain-data") - ) + hf_dataset = utils.convert_text_dataset_to_hf(self.params.task, train_data, valid_data) + hf_dataset.save_to_disk(os.path.join(self.params.project_name, "autotrain-data")) if self.params.push_to_hub: logger.info("Pushing the data to Hugging Face Hub.")