diff --git a/requirements.txt b/requirements.txt
index de727743ec..308c52eb11 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -21,6 +21,8 @@ einops==0.8.0
 packaging==24.1
 cryptography==43.0.3
 nvitop==1.3.2
+llama-cpp-python==0.2.90
+outlines==0.0.46
 # latest versions
 tensorboard==2.16.2
 peft==0.13.2
@@ -30,6 +32,7 @@ transformers==4.46.2
 accelerate==1.1.1
 bitsandbytes==0.44.1
 # extras
+distilabel==1.4.1
 rouge_score==0.1.2
 py7zr==0.22.0
 fastapi==0.115.4
diff --git a/src/autotrain/app/static/scripts/listeners.js b/src/autotrain/app/static/scripts/listeners.js
index 9638342aea..59b4cfa0fe 100644
--- a/src/autotrain/app/static/scripts/listeners.js
+++ b/src/autotrain/app/static/scripts/listeners.js
@@ -3,6 +3,7 @@ document.addEventListener('DOMContentLoaded', function () {
     const uploadDataTabContent = document.getElementById("upload-data-tab-content");
     const hubDataTabContent = document.getElementById("hub-data-tab-content");
     const uploadDataTabs = document.getElementById("upload-data-tabs");
+    const genDataTabContent = document.getElementById("gen-data-tab-content");
    const jsonCheckbox = document.getElementById('show-json-parameters');
     const jsonParametersDiv = document.getElementById('json-parameters');
@@ -54,12 +55,20 @@
         if (dataSource.value === "hub") {
             uploadDataTabContent.style.display = "none";
             uploadDataTabs.style.display = "none";
+            genDataTabContent.style.display = "none";
             hubDataTabContent.style.display = "block";
         } else if (dataSource.value === "local") {
             uploadDataTabContent.style.display = "block";
             uploadDataTabs.style.display = "block";
+            genDataTabContent.style.display = "none";
             hubDataTabContent.style.display = "none";
         }
+        else if (dataSource.value === "gen") {
+            uploadDataTabContent.style.display = "none";
+            uploadDataTabs.style.display = "none";
+            hubDataTabContent.style.display = "none";
+            genDataTabContent.style.display = "block";
+        }
     }

     async function fetchParams() {
diff --git a/src/autotrain/app/templates/index.html b/src/autotrain/app/templates/index.html
index 0ee5226c9d..7d3a79f2ee 100644
--- a/src/autotrain/app/templates/index.html
+++ b/src/autotrain/app/templates/index.html
@@ -415,9 +415,11 @@ Source
@@ -519,7 +521,30 @@
diff --git a/src/autotrain/cli/autotrain.py b/src/autotrain/cli/autotrain.py
index fcd85b9828..30699de942 100644
--- a/src/autotrain/cli/autotrain.py
+++ b/src/autotrain/cli/autotrain.py
@@ -4,6 +4,7 @@
 from autotrain.cli.run_api import RunAutoTrainAPICommand
 from autotrain.cli.run_app import RunAutoTrainAppCommand
 from autotrain.cli.run_extractive_qa import RunAutoTrainExtractiveQACommand
+from autotrain.cli.run_gen import RunAutoTrainGenCommand
 from autotrain.cli.run_image_classification import RunAutoTrainImageClassificationCommand
 from autotrain.cli.run_image_regression import RunAutoTrainImageRegressionCommand
 from autotrain.cli.run_llm import RunAutoTrainLLMCommand
@@ -47,6 +48,7 @@ def main():
     RunAutoTrainSentenceTransformersCommand.register_subcommand(commands_parser)
     RunAutoTrainImageRegressionCommand.register_subcommand(commands_parser)
     RunAutoTrainExtractiveQACommand.register_subcommand(commands_parser)
+    RunAutoTrainGenCommand.register_subcommand(commands_parser)

     args = parser.parse_args()
diff --git a/src/autotrain/cli/run_gen.py b/src/autotrain/cli/run_gen.py
new file mode 100644
index 0000000000..388a4927e9
--- /dev/null
+++ b/src/autotrain/cli/run_gen.py
@@ -0,0 +1,56 @@
+from argparse import ArgumentParser
+
+from autotrain import logger
+from autotrain.cli.utils import get_field_info
+from autotrain.datagen.gen import AutoTrainGen
+from autotrain.datagen.params import AutoTrainGenParams
+
+from . import BaseAutoTrainCommand
+
+
+def run_autotrain_gen_command(args):
+    return RunAutoTrainGenCommand(args)
+
+
+class RunAutoTrainGenCommand(BaseAutoTrainCommand):
+    @staticmethod
+    def register_subcommand(parser: ArgumentParser):
+        arg_list = get_field_info(AutoTrainGenParams)
+        run_autotrain_gen_parser = parser.add_parser("gen", description="✨ AutoTrain Gen")
+        for arg in arg_list:
+            names = [arg["arg"]] + arg.get("alias", [])
+            if "action" in arg:
+                run_autotrain_gen_parser.add_argument(
+                    *names,
+                    dest=arg["arg"].replace("--", "").replace("-", "_"),
+                    help=arg["help"],
+                    required=arg.get("required", False),
+                    action=arg.get("action"),
+                    default=arg.get("default"),
+                )
+            else:
+                run_autotrain_gen_parser.add_argument(
+                    *names,
+                    dest=arg["arg"].replace("--", "").replace("-", "_"),
+                    help=arg["help"],
+                    required=arg.get("required", False),
+                    type=arg.get("type"),
+                    default=arg.get("default"),
+                    choices=arg.get("choices"),
+                )
+        run_autotrain_gen_parser.set_defaults(func=run_autotrain_gen_command)
+
+    def __init__(self, args):
+        self.args = args
+
+        store_true_arg_names = [
+            "push_to_hub",
+        ]
+        for arg_name in store_true_arg_names:
+            if getattr(self.args, arg_name) is None:
+                setattr(self.args, arg_name, False)
+
+    def run(self):
+        logger.info("Running AutoTrain Gen 🚀")
+        params = AutoTrainGenParams(**vars(self.args))
+        AutoTrainGen(params).run()
diff --git a/src/autotrain/datagen/__init__.py b/src/autotrain/datagen/__init__.py
new file mode 100644
index 0000000000..e69de29bb2
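For reference, the new `gen` subcommand is a thin wrapper over `AutoTrainGen`. A minimal sketch of the programmatic equivalent of `autotrain gen`, with a placeholder prompt, project name, and token:

```python
# Minimal sketch: programmatic equivalent of the `autotrain gen` CLI subcommand.
# The prompt, project name, and token are illustrative placeholders.
from autotrain.datagen.gen import AutoTrainGen
from autotrain.datagen.params import AutoTrainGenParams

params = AutoTrainGenParams(
    gen_model="meta-llama/Meta-Llama-3.1-8B-Instruct",
    project_name="autotrain-datagen",
    prompt="Generate movie reviews labeled positive or negative.",
    task="text-classification",
    backend="hf-inference-api",
    token="hf_...",        # placeholder Hugging Face token
    push_to_hub=False,     # keep the generated dataset local
)
AutoTrainGen(params).run()  # generates the data, then optionally trains
```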
diff --git a/src/autotrain/datagen/clients.py b/src/autotrain/datagen/clients.py
new file mode 100644
index 0000000000..a9d86792e7
--- /dev/null
+++ b/src/autotrain/datagen/clients.py
@@ -0,0 +1,115 @@
+import json
+import time
+from dataclasses import dataclass
+from typing import Optional
+
+import outlines
+import torch
+import transformers
+from huggingface_hub import InferenceClient
+
+from autotrain import logger
+
+
+@dataclass
+class _TransformersClient:
+    model_name: str
+
+    def __post_init__(self):
+        self.pipeline = transformers.pipeline(
+            "text-generation",
+            model=self.model_name,
+            model_kwargs={"torch_dtype": torch.bfloat16},
+            device_map="auto",
+        )
+
+    def chat_completion(self, messages, max_tokens, stream, seed, response_format):
+        # `stream`, `seed`, and `response_format` are accepted only to mirror the
+        # InferenceClient signature; the transformers pipeline does not support them.
+        outputs = self.pipeline(
+            messages,
+            max_new_tokens=max_tokens,
+        )
+        return outputs[0]["generated_text"][-1]["content"]
+
+
+@dataclass
+class TransformersClient:
+    model_name: str
+
+    def __post_init__(self):
+        self.pipeline = outlines.models.transformers(
+            self.model_name,
+            # device_map="auto",
+            model_kwargs={"torch_dtype": torch.bfloat16},
+        )
+
+    def chat_completion(self, messages, max_tokens, stream, seed, response_format):
+        # outlines expects the JSON schema as a string
+        response_format = json.dumps(response_format)
+        generator = outlines.generate.json(self.pipeline, response_format)
+        outputs = generator(
+            messages,
+            max_tokens=max_tokens,
+            seed=seed,
+        )
+        return outputs[0]["generated_text"][-1]["content"]
+
+
+@dataclass
+class Client:
+    name: str
+    model_name: Optional[str] = None
+    api_key: Optional[str] = None
+
+    def __post_init__(self):
+        if self.name == "hf-inference-api":
+            if self.model_name is None:
+                raise ValueError("Model name is required for the Hugging Face Inference API")
+            self.client = InferenceClient
+        elif self.name == "transformers":
+            if self.model_name is None:
+                raise ValueError("Model name is required for Transformers")
+            self.client = TransformersClient
+        else:
+            raise ValueError("Client not supported")
+
+    def __str__(self):
+        return f"Client: {self.name}"
+
+    def __repr__(self):
+        return f"Client: {self.name}"
+
+    def _huggingface(self):
+        if self.api_key:
+            return self.client(self.model_name, token=self.api_key)
+        return self.client(self.model_name)
+
+    def _transformers(self):
+        return self.client(self.model_name)
+
+    def chat_completion(self, messages, max_tokens=500, seed=42, response_format=None, retries=3, delay=5):
+        if self.name == "hf-inference-api":
+            _client = self._huggingface()
+        elif self.name == "transformers":
+            _client = self._transformers()
+        else:
+            raise ValueError("Client not supported")
+        for attempt in range(retries):
+            try:
+                message = _client.chat_completion(
+                    messages=messages,
+                    max_tokens=max_tokens,
+                    stream=False,
+                    seed=seed,
+                    response_format=response_format,
+                )
+                return message
+            except Exception as e:
+                logger.error(f"Attempt {attempt + 1} failed: {e}")
+                if attempt < retries - 1:
+                    time.sleep(delay)
+                else:
+                    return None
diff --git a/src/autotrain/datagen/gen.py b/src/autotrain/datagen/gen.py
new file mode 100644
index 0000000000..586f196683
--- /dev/null
+++ b/src/autotrain/datagen/gen.py
@@ -0,0 +1,21 @@
+from dataclasses import dataclass
+
+from autotrain import logger
+from autotrain.datagen.params import AutoTrainGenParams
+
+
+@dataclass
+class AutoTrainGen:
+    params: AutoTrainGenParams
+
+    def __post_init__(self):
+        logger.info(self.params)
+        if self.params.task in ("text-classification", "seq2seq"):
+            from autotrain.datagen.text import TextDataGenerator
+
+            self.gen = TextDataGenerator(self.params)
+        else:
+            raise NotImplementedError
+
+    def run(self):
+        self.gen.run()
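A minimal usage sketch for the `Client` wrapper with the `hf-inference-api` backend; the model name, token, and schema are illustrative, and the grammar dict mirrors the `{"type": "json", "value": ...}` shape used in `text.py` below:

```python
# Minimal sketch: constrained JSON generation through the client wrapper.
# Model name, token, and schema are placeholders.
from autotrain.datagen.clients import Client

client = Client(
    "hf-inference-api",
    model_name="meta-llama/Meta-Llama-3.1-8B-Instruct",
    api_key="hf_...",  # placeholder token
)
schema = {
    "type": "json",
    "value": {
        "properties": {"text": {"type": "string"}, "target": {"type": "string"}},
        "required": ["text", "target"],
    },
}
message = client.chat_completion(
    messages=[{"role": "user", "content": "One short movie review with a sentiment label."}],
    max_tokens=256,
    response_format=schema,
)
# chat_completion retries up to 3 times and returns None if every attempt fails
if message is not None:
    print(message.choices[0].message.content)
```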
+ """ + + class Config: + protected_namespaces = () + + def save(self, output_dir): + """ + Save parameters to a json file. + """ + os.makedirs(output_dir, exist_ok=True) + path = os.path.join(output_dir, "gen_params.json") + # save formatted json + with open(path, "w", encoding="utf-8") as f: + f.write(self.model_dump_json(indent=4)) + + def __str__(self): + """ + String representation of the parameters. + """ + data = self.model_dump() + data["token"] = "*****" if data.get("token") else None + return str(data) + + def __init__(self, **data): + """ + Initialize the parameters, check for unused/extra parameters and warn the user. + """ + super().__init__(**data) + + if len(self.project_name) > 0: + if not self.project_name.replace("-", "").isalnum(): + raise ValueError("project_name must be alphanumeric but can contain hyphens") + + if len(self.project_name) > 50: + raise ValueError("project_name cannot be more than 50 characters") + + defaults = set(self.model_fields.keys()) + supplied = set(data.keys()) + not_supplied = defaults - supplied + if not_supplied: + logger.warning(f"Parameters not supplied by user and set to default: {', '.join(not_supplied)}") + unused = supplied - set(self.model_fields) + if unused: + logger.warning(f"Parameters supplied but not used: {', '.join(unused)}") + + +class AutoTrainGenParams(BaseGenParams): + gen_model: str = Field("meta-llama/Meta-Llama-3.1-8B-Instruct", title="The model to be used for generation.") + project_name: str = Field("autotrain-datagen", title="Name of the project.") + prompt: str = Field(None, title="Prompt to be used for text generation.") + task: str = Field(None, title="Task type, e.g., text-classification, summarization.") + token: Optional[str] = Field(None, title="Authentication token for accessing the model.") + training_config: Optional[str] = Field(None, title="Path to the training configuration file.") + valid_size: Optional[float] = Field(0.2, title="Validation set size as a fraction of the total dataset.") + username: Optional[str] = Field(None, title="Username of the person running the training.") + push_to_hub: Optional[bool] = Field(True, title="Whether to push the model to Hugging Face Hub.") + backend: Optional[str] = Field("huggingface", title="Backend to be used, e.g., huggingface, local.") + api: Optional[str] = Field(None, title="API endpoint to be used.") + api_key: Optional[str] = Field(None, title="API key for authentication.") + sample: Optional[str] = Field(None, title="Sample dataset for generation.") + min_samples: Optional[int] = Field(200, title="Minimum number of samples required for training.") + # text specific + min_words: Optional[int] = Field(25, title="Minimum number of words in the generated text.") diff --git a/src/autotrain/datagen/text.py b/src/autotrain/datagen/text.py new file mode 100644 index 0000000000..f94d0f5d34 --- /dev/null +++ b/src/autotrain/datagen/text.py @@ -0,0 +1,249 @@ +import hashlib +import os +import random +import time + +import ijson +from distilabel.llms import LlamaCppLLM +from distilabel.steps.tasks import TextGeneration +from pydantic import BaseModel + +from autotrain import logger +from autotrain.datagen import utils +from autotrain.datagen.clients import Client +from autotrain.datagen.params import AutoTrainGenParams + + +TEXT_CLASSIFICATION_SYSTEM_PROMPT = """ +You are an AI bot that generates data for text classification tasks. +You do not repeat the question asked by user. You do not generate code. +Only thing you generate is text data in the specified format. 
diff --git a/src/autotrain/datagen/text.py b/src/autotrain/datagen/text.py
new file mode 100644
index 0000000000..f94d0f5d34
--- /dev/null
+++ b/src/autotrain/datagen/text.py
@@ -0,0 +1,249 @@
+import hashlib
+import os
+import random
+import time
+
+import ijson
+from distilabel.llms import LlamaCppLLM
+from distilabel.steps.tasks import TextGeneration
+from pydantic import BaseModel
+
+from autotrain import logger
+from autotrain.datagen import utils
+from autotrain.datagen.clients import Client
+from autotrain.datagen.params import AutoTrainGenParams
+
+
+TEXT_CLASSIFICATION_SYSTEM_PROMPT = """
+You are an AI bot that generates data for text classification tasks.
+You do not repeat the question asked by the user. You do not generate code.
+The only thing you generate is text data in the specified format.
+The user provides a problem statement and you generate the data.
+For the text classification task, the user provides the different classes.
+If the user has not provided the classes, generate the classes as well, but limit the number of classes to 10.
+"""
+
+TEXT_CLASSIFICATION_DATA_PROMPT = """
+The dataset for text classification is in JSON format.
+Each line should be a JSON object with the following keys: text and target.
+Make sure each text sample has at least {min_words} words.
+The target must always be a string.
+Don't write what you are doing. Just generate the data.
+Each line of the output consists of a dictionary with two keys: text and target, and nothing else.
+"""
+
+SEQ2SEQ_SYSTEM_PROMPT = """
+You are an AI bot that generates data for sequence-to-sequence tasks.
+You do not repeat the question asked by the user. You do not generate code.
+The only thing you generate is text data in the specified format.
+The user provides a problem statement and you generate the data.
+For the sequence-to-sequence task, the user provides the input and output format.
+If the user has not provided the input and output format, generate the format as well.
+"""
+
+SEQ2SEQ_DATA_PROMPT = """
+The dataset for sequence-to-sequence is in JSON format.
+Each line should be a JSON object with the following keys: text and target.
+Make sure each text sample has at least {min_words} words.
+Both text and target must always be strings.
+Don't write what you are doing. Just generate the data.
+Each line of the output consists of a dictionary with two keys: text and target, and nothing else.
+"""
+
+
+class TextDataGenerator:
+    def __init__(self, params: AutoTrainGenParams):
+        self.params = params
+        if self.params.task == "text-classification":
+            self.system_prompt = TEXT_CLASSIFICATION_SYSTEM_PROMPT
+            self.data_prompt = TEXT_CLASSIFICATION_DATA_PROMPT
+        elif self.params.task == "seq2seq":
+            self.system_prompt = SEQ2SEQ_SYSTEM_PROMPT
+            self.data_prompt = SEQ2SEQ_DATA_PROMPT
+        else:
+            raise NotImplementedError
+
+        self.params.save(output_dir=self.params.project_name)
+
+        self.data_prompt = self.data_prompt.format(min_words=self.params.min_words)
+
+    def run(self):
+        ask = self.system_prompt + self.data_prompt
+        formatted_message = [{"role": "system", "content": ask}]
+        formatted_message.append({"role": "user", "content": self.params.prompt})
+        logger.info("Prompting the model. Using prompt:")
Using prompt:") + logger.info(formatted_message) + + if self.params.backend == "local": + + if self.params.task in ["text-classification", "seq2seq"]: + + class Data(BaseModel): + text: str + target: str + + else: + raise NotImplementedError + + def get_generator(seed: int): + generator = TextGeneration( + llm=LlamaCppLLM( + model_path=self.params.gen_model, + n_gpu_layers=-1, + n_ctx=1024, + seed=seed, + structured_output={"format": "json", "schema": Data}, + ), + num_generations=10, + system_prompt=ask, + ) + generator.load() + return generator + + counter = 0 + clean_result = [] + while True: + current_time = time.time() + random_number = random.randint(0, 1000000) + seed_input = f"{current_time}-{counter}-{random_number}" + random_seed = int(hashlib.sha256(seed_input.encode("utf-8")).hexdigest(), 16) % (10**8) + generator = get_generator(seed=random_seed) + results = list( + generator.process( + [ + { + "instruction": self.params.prompt, + } + ] + ) + ) + + for result in results[0]: + items = [] + parser = ijson.parse(result.get("generation", "")) + + current_item = None + current_key = None + + try: + for prefix, event, value in parser: + if event == "start_map": + current_item = {} + elif event == "map_key": + current_key = value + elif event == "string" and current_key: + current_item[current_key] = value + elif event == "end_map" and current_item: + items.append(current_item) + current_item = None + except ijson.common.IncompleteJSONError: + logger.warning("Incomplete JSON encountered. Returning parsed data.") + + clean_result.append(items[0]) + counter += 1 + num_items_collected = len(clean_result) + logger.info(f"Collected {num_items_collected} items.") + if num_items_collected >= 10: + break + + else: + client = Client( + self.params.backend, + model_name=self.params.gen_model, + api_key=self.params.api_key, + ) + clean_result = [] + + if self.params.task in ["text-classification", "seq2seq"]: + response_format = { + "type": "json", + "value": { + "properties": { + "data": { + "type": "array", + "maxItems": 10, + "minItems": 1, + "items": { + "type": "array", + "properties": { + "text": {"type": "string"}, + "target": {"type": "string"}, + }, + "required": ["text", "target"], + }, + } + }, + "required": ["data"], + }, + } + else: + raise NotImplementedError + + counter = 0 + while True: + current_time = time.time() + random_number = random.randint(0, 1000000) + seed_input = f"{current_time}-{counter}-{random_number}" + random_seed = int(hashlib.sha256(seed_input.encode("utf-8")).hexdigest(), 16) % (10**8) + message = client.chat_completion( + messages=formatted_message, + max_tokens=4096, + seed=random_seed, + response_format=response_format, + ) + + if message is None: + logger.warning("Failed to generate data. Retrying...") + continue + + result = message.choices[0].message.content + + items = [] + parser = ijson.parse(result) + + current_item = None + current_key = None + + try: + for prefix, event, value in parser: + if event == "start_map": + current_item = {} + elif event == "map_key": + current_key = value + elif event == "string" and current_key: + current_item[current_key] = value + elif event == "end_map" and current_item: + items.append(current_item) + current_item = None + except ijson.common.IncompleteJSONError: + # Handle incomplete JSON data + logger.warning("Incomplete JSON encountered. 
Returning parsed data.") + + clean_result.append(items) + counter += 1 + num_items_collected = len([item for sublist in clean_result for item in sublist]) + logger.info(f"Collected {num_items_collected} items.") + if num_items_collected >= self.params.min_samples: + break + + # flatten the list + clean_result = [item for sublist in clean_result for item in sublist] + + valid_data = None + if self.params.valid_size != 0: + valid_size = int(self.params.valid_size * len(clean_result)) + random.shuffle(clean_result) + valid_data = clean_result[:valid_size] + train_data = clean_result[valid_size:] + + logger.info(f"Train data size: {len(train_data)}") + logger.info(f"Valid data size: {len(valid_data)}") + + hf_dataset = utils.convert_text_dataset_to_hf(self.params.task, train_data, valid_data) + hf_dataset.save_to_disk(os.path.join(self.params.project_name, "autotrain-data")) + + if self.params.push_to_hub: + logger.info("Pushing the data to Hugging Face Hub.") + utils.push_data_to_hub(params=self.params, dataset=hf_dataset) + + utils.train(params=self.params) diff --git a/src/autotrain/datagen/utils.py b/src/autotrain/datagen/utils.py new file mode 100644 index 0000000000..34499aedd2 --- /dev/null +++ b/src/autotrain/datagen/utils.py @@ -0,0 +1,93 @@ +import json +import os +import subprocess +from typing import Dict, List, Optional + +from datasets import ClassLabel, Dataset, DatasetDict +from huggingface_hub import HfApi, metadata_update + +from autotrain import logger + + +def convert_text_dataset_to_hf( + task, train: List[Dict[str, str]], valid: Optional[List[Dict[str, str]]] = None +) -> Dataset: + if task == "text-classification": + for item in train: + item["target"] = item["target"].lower().strip() + label_names = list(set([item["target"] for item in train])) + logger.info(f"Label names: {label_names}") + + dataset = Dataset.from_list(train) + + if task == "text-classification": + dataset = dataset.cast_column("target", ClassLabel(names=label_names)) + + ddict = {"train": dataset} + if valid is not None: + valid_dataset = Dataset.from_list(valid) + if task == "text-classification": + for item in valid: + item["target"] = item["target"].lower().strip() + valid_dataset = valid_dataset.cast_column("target", ClassLabel(names=label_names)) + ddict["validation"] = valid_dataset + dataset = DatasetDict(ddict) + return dataset + + +def push_data_to_hub(params, dataset) -> str: + if params.username is None: + raise ValueError("Username is required for pushing data to Hugging Face Hub.") + if params.token is None: + raise ValueError("Token is required for pushing data to Hugging Face Hub.") + + repo_id = f"{params.username}/{params.project_name}" + dataset.push_to_hub(repo_id, token=params.token, private=True) + + if os.path.exists(f"{params.project_name}/gen_params.json"): + gen_params = json.load(open(f"{params.project_name}/gen_params.json")) + if "token" in gen_params: + gen_params.pop("token") + + if "api" in gen_params: + gen_params.pop("api") + + if "api_key" in gen_params: + gen_params.pop("api_key") + + json.dump( + gen_params, + open(f"{params.project_name}/gen_params.json", "w"), + indent=4, + ) + + api = HfApi(token=params.token) + if os.path.exists(f"{params.project_name}/gen_params.json"): + api.upload_file( + path_or_fileobj=f"{params.project_name}/gen_params.json", + repo_id=f"{params.username}/{params.project_name}", + repo_type="dataset", + path_in_repo="gen_params.json", + ) + metadata = { + "tags": [ + "autotrain", + "gen", + "synthetic", + ] + } + metadata_update(repo_id, 
diff --git a/src/autotrain/datagen/utils.py b/src/autotrain/datagen/utils.py
new file mode 100644
index 0000000000..34499aedd2
--- /dev/null
+++ b/src/autotrain/datagen/utils.py
@@ -0,0 +1,93 @@
+import json
+import os
+import subprocess
+from typing import Dict, List, Optional
+
+from datasets import ClassLabel, Dataset, DatasetDict
+from huggingface_hub import HfApi, metadata_update
+
+from autotrain import logger
+
+
+def convert_text_dataset_to_hf(
+    task, train: List[Dict[str, str]], valid: Optional[List[Dict[str, str]]] = None
+) -> Dataset:
+    if task == "text-classification":
+        for item in train:
+            item["target"] = item["target"].lower().strip()
+        label_names = list(set([item["target"] for item in train]))
+        logger.info(f"Label names: {label_names}")
+
+    dataset = Dataset.from_list(train)
+
+    if task == "text-classification":
+        dataset = dataset.cast_column("target", ClassLabel(names=label_names))
+
+    ddict = {"train": dataset}
+    if valid is not None:
+        if task == "text-classification":
+            # normalize targets before building the dataset so the cast sees them
+            for item in valid:
+                item["target"] = item["target"].lower().strip()
+        valid_dataset = Dataset.from_list(valid)
+        if task == "text-classification":
+            valid_dataset = valid_dataset.cast_column("target", ClassLabel(names=label_names))
+        ddict["validation"] = valid_dataset
+    dataset = DatasetDict(ddict)
+    return dataset
+
+
+def push_data_to_hub(params, dataset) -> str:
+    if params.username is None:
+        raise ValueError("Username is required for pushing data to Hugging Face Hub.")
+    if params.token is None:
+        raise ValueError("Token is required for pushing data to Hugging Face Hub.")
+
+    repo_id = f"{params.username}/{params.project_name}"
+    dataset.push_to_hub(repo_id, token=params.token, private=True)
+
+    gen_params_path = f"{params.project_name}/gen_params.json"
+    if os.path.exists(gen_params_path):
+        with open(gen_params_path) as f:
+            gen_params = json.load(f)
+        # strip credentials before uploading the config
+        gen_params.pop("token", None)
+        gen_params.pop("api", None)
+        gen_params.pop("api_key", None)
+        with open(gen_params_path, "w") as f:
+            json.dump(gen_params, f, indent=4)
+
+    api = HfApi(token=params.token)
+    if os.path.exists(gen_params_path):
+        api.upload_file(
+            path_or_fileobj=gen_params_path,
+            repo_id=repo_id,
+            repo_type="dataset",
+            path_in_repo="gen_params.json",
+        )
+    metadata = {
+        "tags": [
+            "autotrain",
+            "gen",
+            "synthetic",
+        ]
+    }
+    metadata_update(repo_id, metadata, token=params.token, repo_type="dataset", overwrite=True)
+    return repo_id
+
+
+def train(params):
+    if params.training_config is None:
+        logger.info("No training configuration provided. Skipping training...")
+        return
+    cmd = f"autotrain --config {params.training_config}"
+    logger.info(f"Running AutoTrain: {cmd}")
+    # split into argv tokens; iterating over the string would yield single characters
+    cmd = cmd.split()
+    env = os.environ.copy()
+    process = subprocess.Popen(cmd, env=env)
+    process.wait()
+    return process.pid
diff --git a/src/autotrain/trainers/text_classification/__main__.py b/src/autotrain/trainers/text_classification/__main__.py
index 5b2cf67af5..cd603e9474 100644
--- a/src/autotrain/trainers/text_classification/__main__.py
+++ b/src/autotrain/trainers/text_classification/__main__.py
@@ -15,6 +15,8 @@
 from transformers.trainer_callback import PrinterCallback

 from autotrain import logger
+from autotrain.datagen.gen import AutoTrainGen
+from autotrain.datagen.params import AutoTrainGenParams
 from autotrain.trainers.common import (
     ALLOW_REMOTE_CODE,
     LossLoggingCallback,
@@ -42,6 +44,39 @@ def train(config):
     if isinstance(config, dict):
         config = TextClassificationParams(**config)

+    if config.gen_prompt is not None:
+        if config.gen_model is None:
+            raise ValueError("Generation model must be provided when using a generation prompt")
+        gen_params = config.gen_params.split(",") if config.gen_params is not None else []
+        gen_params = {p.split("=", 1)[0]: p.split("=", 1)[1] for p in gen_params}
+        gen_config = AutoTrainGenParams(
+            gen_model=config.gen_model,
+            project_name=config.project_name,
+            prompt=config.gen_prompt,
+            task="text-classification",
+            token=config.token,
+            username=config.username,
+            push_to_hub=config.push_to_hub,
+            backend=gen_params.get("backend", "transformers"),
+            api=gen_params.get("api", None),
+            api_key=gen_params.get("api_key", None),
+            min_samples=config.gen_samples,
+        )
+        try:
+            AutoTrainGen(gen_config).run()
+        except Exception as e:
+            logger.error(f"Error generating data: {e}")
+            return
+        if config.push_to_hub:
+            config.data_path = f"{config.username}/{config.project_name}"
+        else:
+            config.data_path = f"{config.project_name}/autotrain-data"
+
+        config.train_split = "train"
+        config.valid_split = "validation"
+        config.target_column = "target"
+        config.text_column = "text"
+
     train_data = None
     valid_data = None
     # check if config.train_split.csv exists in config.data_path
diff --git a/src/autotrain/trainers/text_classification/params.py b/src/autotrain/trainers/text_classification/params.py
index b03758adad..84204853af 100644
--- a/src/autotrain/trainers/text_classification/params.py
+++ b/src/autotrain/trainers/text_classification/params.py
@@ -70,3 +70,8 @@ class TextClassificationParams(AutoTrainParams):
     log: str = Field("none", title="Logging using experiment tracking")
     early_stopping_patience: int = Field(5, title="Early stopping patience")
     early_stopping_threshold: float = Field(0.01, title="Early stopping threshold")
+    # generation parameters
+    gen_prompt: Optional[str] = Field(None, title="Generation prompt")
+    gen_model: Optional[str] = Field(None, title="Generation model")
+    gen_samples: Optional[int] = Field(100, title="Generation samples")
+    gen_params: Optional[str] = Field(None, title="Generation parameters, format: key1=value1,key2=value2")
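A hedged end-to-end sketch of the new generation path in the text-classification trainer. It assumes the remaining `TextClassificationParams` fields can keep their defaults; all names, tokens, and models are placeholders, and `gen_params` follows the `key1=value1,key2=value2` convention parsed above:

```python
# Minimal sketch: triggering data generation from a training config.
# Assumes default values suffice for the fields not set here; token,
# username, and model names are placeholders.
from autotrain.trainers.text_classification.__main__ import train
from autotrain.trainers.text_classification.params import TextClassificationParams

config = TextClassificationParams(
    model="bert-base-uncased",
    project_name="my-synthetic-clf",
    username="my-user",
    token="hf_...",
    push_to_hub=False,
    # gen_prompt switches on the datagen path; data_path is filled in afterwards
    gen_prompt="Classify news headlines as business, sports, or tech.",
    gen_model="meta-llama/Meta-Llama-3.1-8B-Instruct",
    gen_samples=200,
    gen_params="backend=hf-inference-api,api_key=hf_...",
)
train(config)  # generates the dataset first, then falls through to normal training
```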