From e5cc8768db623a9e51dcd4cab260e4f8c9f318cb Mon Sep 17 00:00:00 2001 From: hiyouga Date: Tue, 18 Jul 2023 20:40:10 +0800 Subject: [PATCH] add dev set in web UI --- src/glmtuner/webui/components/eval.py | 4 ++-- src/glmtuner/webui/components/sft.py | 7 +++++-- src/glmtuner/webui/locales.py | 10 ++++++++++ src/glmtuner/webui/runner.py | 8 ++++++++ 4 files changed, 25 insertions(+), 4 deletions(-) diff --git a/src/glmtuner/webui/components/eval.py b/src/glmtuner/webui/components/eval.py index 7fe16e5..879c6dd 100644 --- a/src/glmtuner/webui/components/eval.py +++ b/src/glmtuner/webui/components/eval.py @@ -21,8 +21,8 @@ def create_eval_tab(top_elems: Dict[str, Component], runner: Runner) -> Dict[str preview_btn.click(get_preview, [dataset_dir, dataset], [preview_count, preview_samples, preview_box]) with gr.Row(): - max_source_length = gr.Slider(value=512, minimum=3, maximum=4096, step=1) - max_target_length = gr.Slider(value=512, minimum=3, maximum=4096, step=1) + max_source_length = gr.Slider(value=512, minimum=4, maximum=4096, step=1) + max_target_length = gr.Slider(value=512, minimum=4, maximum=4096, step=1) max_samples = gr.Textbox(value="100000") batch_size = gr.Slider(value=8, minimum=1, maximum=512, step=1) predict = gr.Checkbox(value=True) diff --git a/src/glmtuner/webui/components/sft.py b/src/glmtuner/webui/components/sft.py index a5d83bb..2757194 100644 --- a/src/glmtuner/webui/components/sft.py +++ b/src/glmtuner/webui/components/sft.py @@ -23,8 +23,8 @@ def create_sft_tab(top_elems: Dict[str, Component], runner: Runner) -> Dict[str, preview_btn.click(get_preview, [dataset_dir, dataset], [preview_count, preview_samples, preview_box]) with gr.Row(): - max_source_length = gr.Slider(value=512, minimum=3, maximum=4096, step=1) - max_target_length = gr.Slider(value=512, minimum=3, maximum=4096, step=1) + max_source_length = gr.Slider(value=512, minimum=4, maximum=4096, step=1) + max_target_length = gr.Slider(value=512, minimum=4, maximum=4096, step=1) learning_rate = gr.Textbox(value="5e-5") num_train_epochs = gr.Textbox(value="3.0") max_samples = gr.Textbox(value="100000") @@ -35,6 +35,7 @@ def create_sft_tab(top_elems: Dict[str, Component], runner: Runner) -> Dict[str, lr_scheduler_type = gr.Dropdown( value="cosine", choices=[scheduler.value for scheduler in SchedulerType] ) + dev_ratio = gr.Slider(value=0, minimum=0, maximum=1, step=0.001) fp16 = gr.Checkbox(value=True) with gr.Row(): @@ -72,6 +73,7 @@ def create_sft_tab(top_elems: Dict[str, Component], runner: Runner) -> Dict[str, batch_size, gradient_accumulation_steps, lr_scheduler_type, + dev_ratio, fp16, logging_steps, save_steps, @@ -100,6 +102,7 @@ def create_sft_tab(top_elems: Dict[str, Component], runner: Runner) -> Dict[str, batch_size=batch_size, gradient_accumulation_steps=gradient_accumulation_steps, lr_scheduler_type=lr_scheduler_type, + dev_ratio=dev_ratio, fp16=fp16, logging_steps=logging_steps, save_steps=save_steps, diff --git a/src/glmtuner/webui/locales.py b/src/glmtuner/webui/locales.py index cf41496..6800f57 100644 --- a/src/glmtuner/webui/locales.py +++ b/src/glmtuner/webui/locales.py @@ -209,6 +209,16 @@ "info": "采用的学习率调节器名称。" } }, + "dev_ratio": { + "en": { + "label": "Dev ratio", + "info": "Proportion of data in the dev set." + }, + "zh": { + "label": "验证集比例", + "info": "验证集占全部样本的百分比。" + } + }, "fp16": { "en": { "label": "fp16", diff --git a/src/glmtuner/webui/runner.py b/src/glmtuner/webui/runner.py index e355f3d..97dd63d 100644 --- a/src/glmtuner/webui/runner.py +++ b/src/glmtuner/webui/runner.py @@ -75,6 +75,7 @@ def run_train( batch_size: int, gradient_accumulation_steps: int, lr_scheduler_type: str, + dev_ratio: float, fp16: bool, logging_steps: int, save_steps: int, @@ -115,6 +116,13 @@ def run_train( save_steps=save_steps, output_dir=os.path.join(get_save_dir(model_name), finetuning_type, output_dir) ) + + if dev_ratio > 1e-6: + args["dev_ratio"] = dev_ratio + args["evaluation_strategy"] = "steps" + args["eval_steps"] = save_steps + args["load_best_model_at_end"] = True + model_args, data_args, training_args, finetuning_args, _ = get_train_args(args) run_args = dict(