diff --git a/pyproject.toml b/pyproject.toml
index 1a5d5b4..64567d8 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -5,7 +5,6 @@ authors = [
     { name = "speed1313", email = "speedtry13@icloud.com" }
 ]
 dependencies = [
-    "click>=8.1.7",
     "vllm>=0.6.1",
     "datasets>=3.0.0",
     "wandb>=0.18.0",
diff --git a/src/text2dataset/main.py b/src/text2dataset/main.py
index 53e9b12..423aeaa 100644
--- a/src/text2dataset/main.py
+++ b/src/text2dataset/main.py
@@ -1,6 +1,5 @@
 import datasets
 from datasets import load_dataset
-import click
 import os
 from datasets import Dataset
 import wandb
@@ -13,6 +12,115 @@ from text2dataset.utils import State
 from text2dataset.reader import create_dataset
 import yaml
+import argparse
+from dataclasses import dataclass
+
+
+@dataclass(frozen=True)
+class Args:
+    model_id: str
+    batch_size: int
+    tensor_parallel_size: int
+    pipeline_parallel_size: int
+    gpu_id: int
+    input_path: str
+    source_column: str
+    target_column: str
+    push_to_hub: bool
+    push_to_hub_path: str
+    output_dir: str
+    output_format: str
+    number_sample_per_shard: int
+    resume_from_checkpoint: bool
+    use_wandb: bool
+    wandb_project: str
+    wandb_run_name: str
+    prompt_template_path: str
+    temperature: float
+    top_p: float
+    max_tokens: int
+    target_lang: str
+    keep_columns: str | None
+    split: str
+
+
+def parse_args() -> Args:
+    parser = argparse.ArgumentParser(description="Argument parser for model inference")
+    parser.add_argument(
+        "--model_id",
+        type=str,
+        default="llm-jp/llm-jp-3-3.7b-instruct",
+        help="Model name. e.g., llm-jp/llm-jp-3-3.7b-instruct. Specify 'gpt-4o-mini-2024-07-18' for the OpenAI API or 'deepl' for the DeepL API.",
+    )
+    parser.add_argument(
+        "--batch_size", type=int, default=1024, help="Batch size for vLLM inference."
+    )
+    parser.add_argument("--tensor_parallel_size", type=int, default=1)
+    parser.add_argument("--pipeline_parallel_size", type=int, default=1)
+    parser.add_argument("--gpu_id", type=int, default=0)
+    parser.add_argument(
+        "--input_path",
+        type=str,
+        default="data/english_quotes.json",
+        help="Local file path or Hugging Face dataset name.",
+    )
+    parser.add_argument(
+        "--source_column",
+        type=str,
+        default="txt",
+        help="Existing column name in the dataset to be prompted.",
+    )
+    parser.add_argument(
+        "--target_column",
+        type=str,
+        default="txt_ja",
+        help="New column name in the dataset to store the generated text.",
+    )
+    parser.add_argument("--push_to_hub", action="store_true")
+    parser.add_argument("--push_to_hub_path", type=str, default="speed/english_quotes")
+    parser.add_argument("--output_dir", type=str, default="data/english_quotes_ja")
+    parser.add_argument("--output_format", type=str, default="json")
+    parser.add_argument("--number_sample_per_shard", type=int, default=1000)
+    parser.add_argument(
+        "--resume_from_checkpoint",
+        action="store_true",
+        default=False,
+        help="Resume from the last checkpoint.",
+    )
+    parser.add_argument("--use_wandb", action="store_true")
+    parser.add_argument("--wandb_project", type=str, default="text2dataset")
+    parser.add_argument("--wandb_run_name", type=str, default="")
+    parser.add_argument(
+        "--prompt_template_path",
+        type=str,
+        default="config/prompt.yaml",
+        help="Path to the prompt template.",
+    )
+    parser.add_argument("--temperature", type=float, default=0.8)
+    parser.add_argument("--top_p", type=float, default=0.95)
+    parser.add_argument("--max_tokens", type=int, default=200)
+    parser.add_argument(
+        "--target_lang",
+        type=str,
+        default="ja",
+        help="Target language for translation; used only with the DeepL API.",
+    )
+    parser.add_argument(
+        "--keep_columns",
+        type=str,
+        default=None,
+        help="Comma-separated list of columns to keep in the output dataset, e.g., 'txt'. If None, all columns are kept; target_column is always kept.",
+    )
+    parser.add_argument(
+        "--split",
+        type=str,
+        default="train",
+        help="Dataset split to use, e.g., 'train', 'validation', 'test'.",
+    )
+
+    args = parser.parse_args()
+    return Args(**vars(args))
+

 logger = logging.getLogger(__name__)
 logging.basicConfig(
@@ -22,101 +130,13 @@ )


-@click.command()
-@click.option(
-    "--model_id",
-    type=str,
-    default="llm-jp/llm-jp-3-3.7b-instruct",
-    help="Model name. e.g. llm-jp/llm-jp-3-3.7b-instruct. If you want to use OpenAI API, specify the model name like 'gpt-4o-mini-2024-07-18'. If you want to use DeepL API, specify 'deepl'.",
-)
-@click.option(
-    "--batch_size", type=int, default=1024, help="Batch size for vLLM inference."
-)
-@click.option("--tensor_parallel_size", type=int, default=1)
-@click.option("--pipeline_parallel_size", type=int, default=1)
-@click.option("--gpu_id", type=int, default=0)
-@click.option(
-    "--input_path",
-    type=str,
-    default="data/english_quotes.json",
-    help="Local file path or Hugging Face dataset name.",
-)
-@click.option(
-    "--source_column",
-    type=str,
-    default="txt",
-    help="Existing column name in the dataset to be prompted.",
-)
-@click.option(
-    "--target_column",
-    type=str,
-    default="txt_ja",
-    help="New column name in the dataset to store the generated text.",
-)
-@click.option("--push_to_hub", type=bool, default=False)
-@click.option("--push_to_hub_path", type=str, default="speed/english_quotes")
-@click.option("--output_dir", type=str, default="data/english_quotes_ja")
-@click.option("--output_format", type=str, default="json")
-@click.option("--number_sample_per_shard", type=int, default=1000)
-@click.option(
-    "--resume_from_checkpoint",
-    type=bool,
-    default=False,
-    help="Resume from the last checkpoint.",
-)
-@click.option("--use_wandb", type=bool, default=False)
-@click.option("--wandb_project", type=str, default="text2dataset")
-@click.option("--wandb_run_name", type=str, default="")
-@click.option(
-    "--prompt_template_path",
-    type=str,
-    default="config/prompt.yaml",
-    help="Path to the prompt template.",
-)
-@click.option("--temperature", type=float, default=0.8)
-@click.option("--top_p", type=float, default=0.95)
-@click.option("--max_tokens", type=int, default=200)
-@click.option(
-    "--target_lang",
-    type=str,
-    default="ja",
-    help="Target language for translation. This is used only for DeepL API.",
-)
-@click.option(
-    "--keep_columns",
-    type=str,
-    default="txt",
-    help="Columns to keep in the output dataset. Specify the column names separated by comma.",
-)
-def main(
-    model_id: str,
-    batch_size: int,
-    output_dir: str,
-    tensor_parallel_size: int,
-    pipeline_parallel_size: int,
-    gpu_id: int,
-    source_column: str,
-    target_column: str,
-    input_path: str,
-    push_to_hub: bool,
-    push_to_hub_path: str,
-    output_format: str,
-    number_sample_per_shard: int,
-    resume_from_checkpoint: bool,
-    use_wandb: bool,
-    wandb_project: str,
-    wandb_run_name: str,
-    prompt_template_path: str,
-    temperature: float,
-    top_p: float,
-    max_tokens: int,
-    target_lang: str,
-    keep_columns: str,
-):
+def main():
+    args = parse_args()
+    # Text in the source_column of the dataset will be translated into Japanese by default.
     state = State(0, 0, 0)
-    if resume_from_checkpoint:
-        state_path = os.path.join(output_dir, "state.jsonl")
+    if args.resume_from_checkpoint:
+        state_path = os.path.join(args.output_dir, "state.jsonl")
         if os.path.exists(state_path):
             with open(state_path, "r") as f:
                 state = State(**json.load(f), total_processed_examples=0)
@@ -129,53 +149,60 @@ def main(
     logger.info("Start translation")

-    os.environ["CUDA_VISIBLE_DEVICES"] = str(gpu_id)
+    os.environ["CUDA_VISIBLE_DEVICES"] = str(args.gpu_id)

-    os.makedirs(output_dir, exist_ok=True)
-    state_path = os.path.join(output_dir, "state.jsonl")
-    ds = create_dataset(input_path, state)
+    os.makedirs(args.output_dir, exist_ok=True)
+    state_path = os.path.join(args.output_dir, "state.jsonl")
+    ds = create_dataset(args.input_path, state, args.split)

     # keep only the specified columns
-    ds = ds.select_columns(keep_columns.split(","))
+    if args.keep_columns is not None:
+        ds = ds.select_columns(args.keep_columns.split(","))

     # batch dataloader
-    data_loader = ds.batch(batch_size=batch_size)
+    data_loader = ds.batch(batch_size=args.batch_size)

-    if use_wandb:
-        config_parameters = dict(locals())
-        config_parameters.pop("use_wandb")
-        wandb.init(project=wandb_project, name=wandb_run_name, config=config_parameters)
+    if args.use_wandb:
+        config_parameters = dict(vars(args))
+        config_parameters.pop("use_wandb")
+        wandb.init(
+            project=args.wandb_project,
+            name=args.wandb_run_name,
+            config=config_parameters,
+        )

-    with open(prompt_template_path) as f:
+    with open(args.prompt_template_path) as f:
         data = yaml.safe_load(f)
         template = data["prompt"]

-    if model_id == "deepl":
-        translator = DeeplTranslator(target_lang)
-    elif model_id in ["gpt-4o-mini-2024-07-18", "gpt-4o-2024-07-18"]:
+    if args.model_id == "deepl":
+        translator = DeeplTranslator(args.target_lang)
+    elif args.model_id in ["gpt-4o-mini-2024-07-18", "gpt-4o-2024-07-18"]:
         translator = OpenAIAPITranslator(
-            model_id, template, temperature, top_p, max_tokens
+            args.model_id, template, args.temperature, args.top_p, args.max_tokens
         )
     else:
         translator = Translator(
-            model_id,
-            tensor_parallel_size,
-            pipeline_parallel_size,
+            args.model_id,
+            args.tensor_parallel_size,
+            args.pipeline_parallel_size,
             template,
-            temperature,
-            top_p,
-            max_tokens,
+            args.temperature,
+            args.top_p,
+            args.max_tokens,
         )

     dataset_buffer = Dataset.from_dict({})
     for examples in data_loader:
         start_time = time.time()
-        text_list = examples[source_column]
+        text_list = examples[args.source_column]
         translated = translator.translate(text_list)

         # store to buffer
         dataset_buffer = datasets.concatenate_datasets(
             [
                 dataset_buffer,
-                datasets.Dataset.from_dict({**examples, target_column: translated}),
+                datasets.Dataset.from_dict(
+                    {**examples, args.target_column: translated}
+                ),
             ]
         )
         state.total_processed_examples += len(text_list)
@@ -184,22 +211,25 @@ def main(
         # write shards to output_dir if the buffer is full
         # e.g number_sample_per_shard = 100, len(dataset_buffer) = 1024
         # 1024 // 100 = 10 shards will be written to output_dir
-        if len(dataset_buffer) >= number_sample_per_shard:
-            for i in range(len(dataset_buffer) // number_sample_per_shard):
+        if len(dataset_buffer) >= args.number_sample_per_shard:
+            for i in range(len(dataset_buffer) // args.number_sample_per_shard):
                 shard_dict = dataset_buffer[
-                    i * number_sample_per_shard : (i + 1) * number_sample_per_shard
+                    i * args.number_sample_per_shard : (i + 1)
+                    * args.number_sample_per_shard
                 ]
                 shard_ds = Dataset.from_dict(shard_dict)
-                state = write_shard(shard_ds, output_dir, output_format, state)
+                state = write_shard(
+                    shard_ds, args.output_dir, args.output_format, state
+                )
                 state.current_shard_id += 1
                 state.save_state(state_path)

             dataset_buffer = Dataset.from_dict(
                 dataset_buffer[
                     len(dataset_buffer)
-                    // number_sample_per_shard
-                    * number_sample_per_shard :
+                    // args.number_sample_per_shard
+                    * args.number_sample_per_shard :
                 ]
             )

@@ -214,25 +244,25 @@ def main(
     # write the remaining examples
     if len(dataset_buffer) > 0:
-        state = write_shard(dataset_buffer, output_dir, output_format, state)
+        state = write_shard(dataset_buffer, args.output_dir, args.output_format, state)
         state.save_state(state_path)

-    if push_to_hub:
-        if output_format == "jsonl" or output_format == "json":
+    if args.push_to_hub:
+        if args.output_format == "jsonl" or args.output_format == "json":
             # jsonl without state.jsonl
-            files = os.listdir(output_dir)
+            files = os.listdir(args.output_dir)
             if "state.jsonl" in files:
                 files.remove("state.jsonl")
             # Sort files by shard id to keep the order.
             files.sort(key=lambda x: int(x.split(".")[0]))
             translated_ds = load_dataset(
-                "json", data_files=[os.path.join(output_dir, f) for f in files]
+                "json", data_files=[os.path.join(args.output_dir, f) for f in files]
             )
-        elif output_format == "parquet":
+        elif args.output_format == "parquet":
             translated_ds = load_dataset(
-                "parquet", data_files=os.path.join(output_dir, "*.parquet")
+                "parquet", data_files=os.path.join(args.output_dir, "*.parquet")
             )
-        translated_ds.push_to_hub(push_to_hub_path, private=True)
+        translated_ds.push_to_hub(args.push_to_hub_path, private=True)


 if __name__ == "__main__":
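
A note on the boolean options above: argparse's type=bool simply calls bool() on the raw argument string, so even --push_to_hub False would parse as True (any non-empty string is truthy). That is why the new options use action="store_true" instead. The sketch below shows the same parse-args-into-frozen-dataclass round-trip in miniature; MiniArgs and parse_mini_args are hypothetical stand-ins for the real Args and parse_args in main.py.

import argparse
from dataclasses import dataclass


@dataclass(frozen=True)
class MiniArgs:
    batch_size: int
    push_to_hub: bool


def parse_mini_args(argv: list[str]) -> MiniArgs:
    parser = argparse.ArgumentParser()
    parser.add_argument("--batch_size", type=int, default=1024)
    # store_true: the flag is False unless --push_to_hub appears on the
    # command line; type=bool would instead turn "--push_to_hub False" into True.
    parser.add_argument("--push_to_hub", action="store_true")
    # vars() turns the Namespace into a dict keyed by option name,
    # which maps 1:1 onto the dataclass fields.
    return MiniArgs(**vars(parser.parse_args(argv)))


assert parse_mini_args([]).push_to_hub is False
assert parse_mini_args(["--push_to_hub"]).push_to_hub is True
assert bool("False") is True  # the type=bool pitfall in one line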
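One behavioural detail worth calling out: in the click version, dict(locals()) captured the individual option variables and pop("use_wandb") removed the switch before logging the config. After the refactor use_wandb is no longer a local variable, so that pop would raise KeyError; copying vars(args) keeps the old intent (log every option except the wandb switch) while leaving the frozen dataclass untouched. A quick check, again with a hypothetical cut-down dataclass:

from dataclasses import dataclass


@dataclass(frozen=True)
class WandbArgs:
    use_wandb: bool
    wandb_project: str


args = WandbArgs(use_wandb=True, wandb_project="text2dataset")

# dict(...) copies the instance __dict__, so pop() below does not
# mutate the (frozen) dataclass through the vars() view.
config_parameters = dict(vars(args))
config_parameters.pop("use_wandb")
assert config_parameters == {"wandb_project": "text2dataset"}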
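Finally, the shard-writing arithmetic is easy to sanity-check in isolation: with number_sample_per_shard = 100 and 1024 buffered examples, 1024 // 100 = 10 full shards are written and the remaining 24 examples stay in the buffer for the next batch. A sketch with a plain list standing in for dataset_buffer (no Dataset or write_shard involved):

number_sample_per_shard = 100
buffer = list(range(1024))  # stand-in for dataset_buffer

shards = []
for i in range(len(buffer) // number_sample_per_shard):
    # same slicing as the loop in main(): shard i covers [i*n, (i+1)*n)
    shards.append(
        buffer[i * number_sample_per_shard : (i + 1) * number_sample_per_shard]
    )

# keep only the tail that did not fill a whole shard
buffer = buffer[len(buffer) // number_sample_per_shard * number_sample_per_shard :]

assert len(shards) == 10
assert all(len(s) == 100 for s in shards)
assert len(buffer) == 24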