From 2578d15bc65b483ddfaf056560d6f7a8bffd2700 Mon Sep 17 00:00:00 2001
From: hiyouga
Date: Mon, 17 Jul 2023 23:14:16 +0800
Subject: [PATCH] release v0.1.2

---
 src/glmtuner/__init__.py               |  2 +-
 src/glmtuner/extras/constants.py       | 10 ++--
 src/glmtuner/extras/ploting.py         |  4 +-
 src/glmtuner/webui/chat.py             | 36 ++++++-------
 src/glmtuner/webui/common.py           | 50 ++++++++---------
 src/glmtuner/webui/components/data.py  |  6 +--
 src/glmtuner/webui/components/eval.py  | 17 +++---
 src/glmtuner/webui/components/infer.py | 21 +++-----
 src/glmtuner/webui/components/sft.py   | 23 ++++----
 src/glmtuner/webui/components/top.py   | 22 ++++----
 src/glmtuner/webui/locales.py          | 74 +++++++++++++++++++++++---
 src/glmtuner/webui/manager.py          | 21 +++++++-
 src/glmtuner/webui/runner.py           | 61 +++++++++++----------
 src/glmtuner/webui/utils.py            | 11 ++--
 14 files changed, 214 insertions(+), 144 deletions(-)

diff --git a/src/glmtuner/__init__.py b/src/glmtuner/__init__.py
index de0171e..7a3d7d6 100644
--- a/src/glmtuner/__init__.py
+++ b/src/glmtuner/__init__.py
@@ -4,4 +4,4 @@
 from glmtuner.webui import create_ui


-__version__ = "0.1.1"
+__version__ = "0.1.2"
diff --git a/src/glmtuner/extras/constants.py b/src/glmtuner/extras/constants.py
index 71d1de1..96f2214 100644
--- a/src/glmtuner/extras/constants.py
+++ b/src/glmtuner/extras/constants.py
@@ -6,11 +6,9 @@
 LAYERNORM_NAMES = ["layernorm"]

+METHODS = ["full", "freeze", "p_tuning", "lora"]
+
 SUPPORTED_MODELS = {
-    "ChatGLM-6B": {
-        "hf_path": "THUDM/chatglm-6b"
-    },
-    "ChatGLM2-6B": {
-        "hf_path": "THUDM/chatglm2-6b"
-    }
+    "ChatGLM-6B": "THUDM/chatglm-6b",
+    "ChatGLM2-6B": "THUDM/chatglm2-6b"
 }
diff --git a/src/glmtuner/extras/ploting.py b/src/glmtuner/extras/ploting.py
index 78a71ad..880e491 100644
--- a/src/glmtuner/extras/ploting.py
+++ b/src/glmtuner/extras/ploting.py
@@ -1,4 +1,5 @@
 import os
+import math
 import json
 import matplotlib.pyplot as plt
 from typing import List, Optional
@@ -10,12 +11,13 @@
 logger = get_logger(__name__)


-def smooth(scalars: List[float], weight: Optional[float] = 0.9) -> List[float]:
+def smooth(scalars: List[float]) -> List[float]:
     r"""
     EMA implementation according to TensorBoard.
     """
     last = scalars[0]
     smoothed = list()
+    weight = 1.8 * (1 / (1 + math.exp(-0.05 * len(scalars))) - 0.5) # a sigmoid function
     for next_val in scalars:
         smoothed_val = last * weight + (1 - weight) * next_val
         smoothed.append(smoothed_val)
diff --git a/src/glmtuner/webui/chat.py b/src/glmtuner/webui/chat.py
index 3b0605c..570b860 100644
--- a/src/glmtuner/webui/chat.py
+++ b/src/glmtuner/webui/chat.py
@@ -2,11 +2,11 @@
 from typing import List, Tuple

 from glmtuner.chat.stream_chat import ChatModel
-from glmtuner.extras.constants import SUPPORTED_MODELS
 from glmtuner.extras.misc import torch_gc
 from glmtuner.hparams import GeneratingArguments
 from glmtuner.tuner import get_infer_args
-from glmtuner.webui.common import get_save_dir
+from glmtuner.webui.common import get_model_path, get_save_dir
+from glmtuner.webui.locales import ALERTS


 class WebChatModel(ChatModel):
@@ -17,32 +17,30 @@ def __init__(self):
         self.generating_args = GeneratingArguments()

     def load_model(
-        self, lang: str, model_name: str, model_path: str, checkpoints: list,
+        self, lang: str, model_name: str, checkpoints: list,
         finetuning_type: str, quantization_bit: str
     ):
         if self.model is not None:
-            yield "You have loaded a model, please unload it first."
+            yield ALERTS["err_exists"][lang]
             return

         if not model_name:
-            yield "Please select a model."
+            yield ALERTS["err_no_model"][lang]
             return

-        if model_path:
-            if not os.path.isdir(model_path):
-                return None, "Cannot find model directory in local disk.", None, None
-            model_name_or_path = model_path
-        elif model_name in SUPPORTED_MODELS: # TODO: use list in gr.State
-            model_name_or_path = SUPPORTED_MODELS[model_name]["hf_path"]
-        else:
-            return None, "Invalid model.", None, None
+        model_name_or_path = get_model_path(model_name)
+        if not model_name_or_path:
+            yield ALERTS["err_no_path"][lang]
+            return

         if checkpoints:
-            checkpoint_dir = ",".join([os.path.join(get_save_dir(model_name), checkpoint) for checkpoint in checkpoints])
+            checkpoint_dir = ",".join(
+                [os.path.join(get_save_dir(model_name), finetuning_type, checkpoint) for checkpoint in checkpoints]
+            )
         else:
             checkpoint_dir = None

-        yield "Loading model..."
+        yield ALERTS["info_loading"][lang]
         args = dict(
             model_name_or_path=model_name_or_path,
             finetuning_type=finetuning_type,
@@ -51,14 +49,14 @@
         )
         super().__init__(*get_infer_args(args))

-        yield "Model loaded, now you can chat with your model."
+        yield ALERTS["info_loaded"][lang]

-    def unload_model(self):
-        yield "Unloading model..."
+    def unload_model(self, lang: str):
+        yield ALERTS["info_unloading"][lang]
         self.model = None
         self.tokenizer = None
         torch_gc()
-        yield "Model unloaded, please load a model first."
+        yield ALERTS["info_unloaded"][lang]

     def predict(
         self,
diff --git a/src/glmtuner/webui/common.py b/src/glmtuner/webui/common.py
index 9729c6a..6a9d867 100644
--- a/src/glmtuner/webui/common.py
+++ b/src/glmtuner/webui/common.py
@@ -1,54 +1,54 @@
 import json
 import os
-from typing import Any, Dict, List, Optional, Union
+from typing import Any, Dict, Optional

 import gradio as gr
 from peft.utils import WEIGHTS_NAME as PEFT_WEIGHTS_NAME
 from transformers.trainer import WEIGHTS_NAME, WEIGHTS_INDEX_NAME

+from glmtuner.extras.constants import SUPPORTED_MODELS
+
 DEFAULT_CACHE_DIR = "cache"
 DEFAULT_DATA_DIR = "data"
 DEFAULT_SAVE_DIR = "saves"
 USER_CONFIG = "user.config"
 DATA_CONFIG = "dataset_info.json"

-METHODS = ["full", "freeze", "p_tuning", "lora"]
-

-def get_config_path():
-    return os.path.join(DEFAULT_CACHE_DIR, USER_CONFIG)
+def get_save_dir(model_name: str) -> str:
+    return os.path.join(DEFAULT_SAVE_DIR, os.path.split(model_name)[-1])


-def load_config() -> Dict[str, str]:
-    if not os.path.exists(get_config_path()):
-        return {}
+def get_config_path() -> os.PathLike:
+    return os.path.join(DEFAULT_CACHE_DIR, USER_CONFIG)
+

-    with open(get_config_path(), "r", encoding="utf-8") as f:
-        try:
-            user_config = json.load(f)
-            return user_config
-        except:
-            return {}
+def load_config() -> Dict[str, Any]:
+    try:
+        with open(get_config_path(), "r", encoding="utf-8") as f:
+            return json.load(f)
+    except:
+        return {"last_model": "", "path_dict": {}}


 def save_config(model_name: str, model_path: str) -> None:
     os.makedirs(DEFAULT_CACHE_DIR, exist_ok=True)
-    user_config = dict(model_name=model_name, model_path=model_path)
+    user_config = load_config()
+    user_config["last_model"] = model_name
+    user_config["path_dict"][model_name] = model_path
     with open(get_config_path(), "w", encoding="utf-8") as f:
-        json.dump(user_config, f, ensure_ascii=False)
-
-
-def get_save_dir(model_name: str) -> str:
-    return os.path.join(DEFAULT_SAVE_DIR, os.path.split(model_name)[-1])
+        json.dump(user_config, f, indent=2, ensure_ascii=False)


-def list_models(model_list: list) -> dict: # TODO: unused
-    return gr.update(value="", choices=[name for name, _ in model_list])
+def get_model_path(model_name: str) -> str:
+    user_config = load_config()
+    return user_config["path_dict"].get(model_name, SUPPORTED_MODELS.get(model_name, ""))


-def list_checkpoints(model_name: str) -> dict:
+def list_checkpoint(model_name: str, finetuning_type: str) -> Dict[str, Any]:
     checkpoints = []
-    save_dir = get_save_dir(model_name)
+    save_dir = os.path.join(get_save_dir(model_name), finetuning_type)
     if save_dir and os.path.isdir(save_dir):
         for checkpoint in os.listdir(save_dir):
             if (
@@ -70,6 +70,6 @@ def load_dataset_info(dataset_dir: str) -> Dict[str, Any]:
         return {}


-def list_datasets(dataset_dir: Optional[str] = None) -> Union[dict, List[str]]: # TODO: remove union
+def list_dataset(dataset_dir: Optional[str] = None) -> Dict[str, Any]:
     dataset_info = load_dataset_info(dataset_dir if dataset_dir is not None else DEFAULT_DATA_DIR)
-    return gr.update(value=[], choices=list(dataset_info.keys())) if dataset_dir is not None else list(dataset_info.keys())
+    return gr.update(value=[], choices=list(dataset_info.keys()))
diff --git a/src/glmtuner/webui/components/data.py b/src/glmtuner/webui/components/data.py
index d338b09..4445f39 100644
--- a/src/glmtuner/webui/components/data.py
+++ b/src/glmtuner/webui/components/data.py
@@ -4,7 +4,7 @@
 from typing import Tuple


-def create_preview_box() -> Tuple[Block, Component, Component]:
+def create_preview_box() -> Tuple[Block, Component, Component, Component]:
     with gr.Box(visible=False, elem_classes="modal-box") as preview_box:
         with gr.Row():
             preview_count = gr.Number(interactive=False)
@@ -12,8 +12,8 @@ def create_preview_box() -> Tuple[Block, Component, Component]:
         with gr.Row():
             preview_samples = gr.JSON(interactive=False)

-        close_btn = gr.Button("Close")
+        close_btn = gr.Button()

         close_btn.click(lambda: gr.update(visible=False), outputs=[preview_box])

-    return preview_box, preview_count, preview_samples
+    return preview_box, preview_count, preview_samples, close_btn
diff --git a/src/glmtuner/webui/components/eval.py b/src/glmtuner/webui/components/eval.py
index b78ab41..3067d6d 100644
--- a/src/glmtuner/webui/components/eval.py
+++ b/src/glmtuner/webui/components/eval.py
@@ -2,7 +2,7 @@
 import gradio as gr
 from gradio.components import Component

-from glmtuner.webui.common import list_datasets, DEFAULT_DATA_DIR, METHODS
+from glmtuner.webui.common import list_dataset, DEFAULT_DATA_DIR
 from glmtuner.webui.components.data import create_preview_box
 from glmtuner.webui.runner import Runner
 from glmtuner.webui.utils import can_preview, get_preview
@@ -10,14 +10,13 @@

 def create_eval_tab(top_elems: Dict[str, Component], runner: Runner) -> Dict[str, Component]:
     with gr.Row():
-        finetuning_type = gr.Dropdown(value="lora", choices=METHODS, interactive=True, scale=1)
-        dataset_dir = gr.Textbox(value=DEFAULT_DATA_DIR, interactive=True, scale=1)
-        dataset = gr.Dropdown(choices=list_datasets(), multiselect=True, interactive=True, scale=4)
+        dataset_dir = gr.Textbox(value=DEFAULT_DATA_DIR, interactive=True, scale=2)
+        dataset = gr.Dropdown(multiselect=True, interactive=True, scale=4)
         preview_btn = gr.Button(interactive=False, scale=1)

-    preview_box, preview_count, preview_samples = create_preview_box()
+    preview_box, preview_count, preview_samples, close_btn = create_preview_box()

-    dataset_dir.change(list_datasets, [dataset_dir], [dataset])
+    dataset_dir.change(list_dataset, [dataset_dir], [dataset])
     dataset.change(can_preview, [dataset_dir, dataset], [preview_btn])
     preview_btn.click(get_preview, [dataset_dir, dataset], [preview_count, preview_samples, preview_box])

@@ -36,20 +35,20 @@ def create_eval_tab(top_elems: Dict[str, Component], runner: Runner) -> Dict[str
     start_btn.click(
         runner.run_eval,
         [
-            top_elems["lang"], top_elems["model_name"], top_elems["model_path"], top_elems["checkpoints"],
-            finetuning_type, dataset, dataset_dir, max_samples, batch_size, quantization_bit, predict
+            top_elems["lang"], top_elems["model_name"], top_elems["checkpoints"], top_elems["finetuning_type"],
+            dataset, dataset_dir, max_samples, batch_size, quantization_bit, predict
         ],
         [output_box]
     )

     stop_btn.click(runner.set_abort, queue=False)

     return dict(
-        finetuning_type=finetuning_type,
         dataset_dir=dataset_dir,
         dataset=dataset,
         preview_btn=preview_btn,
         preview_count=preview_count,
         preview_samples=preview_samples,
+        close_btn=close_btn,
         max_samples=max_samples,
         batch_size=batch_size,
         quantization_bit=quantization_bit,
diff --git a/src/glmtuner/webui/components/infer.py b/src/glmtuner/webui/components/infer.py
index 6791350..bcbddf4 100644
--- a/src/glmtuner/webui/components/infer.py
+++ b/src/glmtuner/webui/components/infer.py
@@ -4,29 +4,25 @@
 from gradio.components import Component

 from glmtuner.webui.chat import WebChatModel
-from glmtuner.webui.common import METHODS
 from glmtuner.webui.components.chatbot import create_chat_box


 def create_infer_tab(top_elems: Dict[str, Component]) -> Dict[str, Component]:
+    with gr.Row():
+        load_btn = gr.Button()
+        unload_btn = gr.Button()
+        quantization_bit = gr.Dropdown([8, 4])
+
     info_box = gr.Markdown()

     chat_model = WebChatModel()
     chat_box, chatbot, history, chat_elems = create_chat_box(chat_model)

-    with gr.Row():
-        finetuning_type = gr.Dropdown(value="lora", choices=METHODS, interactive=True, scale=1)
-        quantization_bit = gr.Dropdown([8, 4], label="Quantization bit", info="Quantize model to 4/8-bit mode.")
-
-    with gr.Row():
-        load_btn = gr.Button()
-        unload_btn = gr.Button("Unload model")
-
     load_btn.click(
         chat_model.load_model,
         [
-            top_elems["lang"], top_elems["model_name"], top_elems["model_path"], top_elems["checkpoints"],
-            finetuning_type, quantization_bit
+            top_elems["lang"], top_elems["model_name"], top_elems["checkpoints"], top_elems["finetuning_type"],
+            quantization_bit
         ],
         [info_box]
     ).then(
@@ -34,7 +30,7 @@ def create_infer_tab(top_elems: Dict[str, Component]) -> Dict[str, Component]:
     )

     unload_btn.click(
-        chat_model.unload_model, outputs=[info_box]
+        chat_model.unload_model, [top_elems["lang"]], [info_box]
     ).then(
         lambda: ([], []), outputs=[chatbot, history]
     ).then(
@@ -42,7 +38,6 @@ def create_infer_tab(top_elems: Dict[str, Component]) -> Dict[str, Component]:
     )

     return dict(
-        finetuning_type=finetuning_type,
         quantization_bit=quantization_bit,
         info_box=info_box,
         load_btn=load_btn,
diff --git a/src/glmtuner/webui/components/sft.py b/src/glmtuner/webui/components/sft.py
index fd9bbb7..fde09a2 100644
--- a/src/glmtuner/webui/components/sft.py
+++ b/src/glmtuner/webui/components/sft.py
@@ -4,22 +4,21 @@
 import gradio as gr
 from gradio.components import Component

-from glmtuner.webui.common import list_datasets, DEFAULT_DATA_DIR, METHODS
+from glmtuner.webui.common import list_dataset, DEFAULT_DATA_DIR
 from glmtuner.webui.components.data import create_preview_box
 from glmtuner.webui.runner import Runner
-from glmtuner.webui.utils import can_preview, get_preview, get_time, gen_plot
+from glmtuner.webui.utils import can_preview, get_preview, gen_plot


 def create_sft_tab(top_elems: Dict[str, Component], runner: Runner) -> Dict[str, Component]:
     with gr.Row():
-        finetuning_type = gr.Dropdown(value="lora", choices=METHODS, interactive=True, scale=1)
         dataset_dir = gr.Textbox(value=DEFAULT_DATA_DIR, interactive=True, scale=1)
-        dataset = gr.Dropdown(choices=list_datasets(), multiselect=True, interactive=True, scale=4)
+        dataset = gr.Dropdown(multiselect=True, interactive=True, scale=4)
         preview_btn = gr.Button(interactive=False, scale=1)

-    preview_box, preview_count, preview_samples = create_preview_box()
+    preview_box, preview_count, preview_samples, close_btn = create_preview_box()

-    dataset_dir.change(list_datasets, [dataset_dir], [dataset])
+    dataset_dir.change(list_dataset, [dataset_dir], [dataset])
     dataset.change(can_preview, [dataset_dir, dataset], [preview_btn])
     preview_btn.click(get_preview, [dataset_dir, dataset], [preview_count, preview_samples, preview_box])

@@ -47,7 +46,7 @@ def create_sft_tab(top_elems: Dict[str, Component], runner: Runner) -> Dict[str,

     with gr.Row():
         with gr.Column(scale=4):
-            output_dir = gr.Textbox(value=get_time(), interactive=True)
+            output_dir = gr.Textbox(interactive=True)
             output_box = gr.Markdown()

         with gr.Column(scale=1):
@@ -56,8 +55,8 @@ def create_sft_tab(top_elems: Dict[str, Component], runner: Runner) -> Dict[str,
     start_btn.click(
         runner.run_train,
         [
-            top_elems["lang"], top_elems["model_name"], top_elems["model_path"], top_elems["checkpoints"],
-            finetuning_type, dataset, dataset_dir, learning_rate, num_train_epochs, max_samples,
+            top_elems["lang"], top_elems["model_name"], top_elems["checkpoints"], top_elems["finetuning_type"],
+            dataset, dataset_dir, learning_rate, num_train_epochs, max_samples,
             fp16, quantization_bit, batch_size, gradient_accumulation_steps,
             lr_scheduler_type, logging_steps, save_steps, output_dir
         ],
@@ -65,15 +64,17 @@ def create_sft_tab(top_elems: Dict[str, Component], runner: Runner) -> Dict[str,
     )

     stop_btn.click(runner.set_abort, queue=False)

-    output_box.change(gen_plot, [top_elems["model_name"], output_dir], loss_viewer, queue=False)
+    output_box.change(
+        gen_plot, [top_elems["model_name"], top_elems["finetuning_type"], output_dir], loss_viewer, queue=False
+    )

     return dict(
-        finetuning_type=finetuning_type,
         dataset_dir=dataset_dir,
         dataset=dataset,
         preview_btn=preview_btn,
         preview_count=preview_count,
         preview_samples=preview_samples,
+        close_btn=close_btn,
         learning_rate=learning_rate,
         num_train_epochs=num_train_epochs,
         max_samples=max_samples,
diff --git a/src/glmtuner/webui/components/top.py b/src/glmtuner/webui/components/top.py
index 66ff22a..ea99fff 100644
--- a/src/glmtuner/webui/components/top.py
+++ b/src/glmtuner/webui/components/top.py
@@ -3,35 +3,37 @@
 import gradio as gr
 from gradio.components import Component

-from glmtuner.extras.constants import SUPPORTED_MODELS
-from glmtuner.webui.common import list_checkpoints, load_config, save_config
+from glmtuner.extras.constants import METHODS, SUPPORTED_MODELS
+from glmtuner.webui.common import list_checkpoint, get_model_path, save_config


 def create_top() -> Dict[str, Component]:
-    user_config = load_config() # TODO: use model list
     available_models = list(SUPPORTED_MODELS.keys()) + ["Custom"]

     with gr.Row():
         lang = gr.Dropdown(choices=["en", "zh"], value="en", interactive=True, scale=1)
-        model_name = gr.Dropdown(choices=available_models, value=user_config.get("model_name", None), scale=2)
-        model_path = gr.Textbox(value=user_config.get("model_path", None), scale=2)
+        model_name = gr.Dropdown(choices=available_models, scale=3)
+        model_path = gr.Textbox(scale=3)

     with gr.Row():
+        finetuning_type = gr.Dropdown(value="lora", choices=METHODS, interactive=True, scale=1)
         checkpoints = gr.Dropdown(multiselect=True, interactive=True, scale=5)
         refresh_btn = gr.Button(scale=1)

     model_name.change(
-        list_checkpoints, [model_name], [checkpoints]
-    ).then( # TODO: save model list
-        lambda: gr.update(value=""), outputs=[model_path]
-    )
+        list_checkpoint, [model_name, finetuning_type], [checkpoints]
+    ).then(
+        get_model_path, [model_name], [model_path]
+    ) # no need to save the config here, the model_path.change handler below saves it

     model_path.change(save_config, [model_name, model_path])

-    refresh_btn.click(list_checkpoints, [model_name], [checkpoints])
+    finetuning_type.change(list_checkpoint, [model_name, finetuning_type], [checkpoints])
+    refresh_btn.click(list_checkpoint, [model_name, finetuning_type], [checkpoints])

     return dict(
         lang=lang,
         model_name=model_name,
         model_path=model_path,
+        finetuning_type=finetuning_type,
         checkpoints=checkpoints,
         refresh_btn=refresh_btn
     )
diff --git a/src/glmtuner/webui/locales.py b/src/glmtuner/webui/locales.py
index 8b6ec1a..9758bab 100644
--- a/src/glmtuner/webui/locales.py
+++ b/src/glmtuner/webui/locales.py
@@ -17,12 +17,12 @@
     },
     "model_path": {
         "en": {
-            "label": "Local path (Optional)",
-            "info": "Absolute path of the model directory."
+            "label": "Model path",
+            "info": "Path to pretrained model or model identifier from Hugging Face."
         },
         "zh": {
-            "label": "模型路径(可选项)",
-            "info": "模型文件夹在磁盘上的绝对路径。"
+            "label": "模型路径",
+            "info": "本地模型的文件路径或 Hugging Face 的模型标识符。"
         }
     },
     "checkpoints": {
@@ -83,6 +83,14 @@
             "label": "样例"
         }
     },
+    "close_btn": {
+        "en": {
+            "value": "Close"
+        },
+        "zh": {
+            "value": "关闭"
+        }
+    },
     "max_samples": {
         "en": {
             "label": "Max samples",
@@ -131,10 +139,10 @@
     },
     "output_box": {
         "en": {
-            "value": "Ready"
+            "value": "Ready."
         },
         "zh": {
-            "value": "准备就绪"
+            "value": "准备就绪。"
         }
     },
     "finetuning_type": {
@@ -222,7 +230,7 @@
         },
         "zh": {
             "label": "断点名称",
-            "info": "保存模型断点的文件夹名称"
+            "info": "保存模型断点的文件夹名称。"
         }
     },
     "loss_viewer": {
@@ -314,3 +322,55 @@
         }
     }
 }
+
+
+ALERTS = {
+    "err_conflict": {
+        "en": "A process is already running, please abort it first.",
+        "zh": "任务已存在,请先中断训练。"
+    },
+    "err_exists": {
+        "en": "You have loaded a model, please unload it first.",
+        "zh": "模型已存在,请先卸载模型。"
+    },
+    "err_no_model": {
+        "en": "Please select a model.",
+        "zh": "请选择模型。"
+    },
+    "err_no_path": {
+        "en": "Model not found.",
+        "zh": "模型未找到。"
+    },
+    "err_no_dataset": {
+        "en": "Please choose a dataset.",
+        "zh": "请选择数据集。"
+    },
+    "info_aborting": {
+        "en": "Aborted, waiting for the thread to terminate...",
+        "zh": "训练中断,正在等待线程结束……"
+    },
+    "info_aborted": {
+        "en": "Ready.",
+        "zh": "准备就绪。"
+    },
+    "info_finished": {
+        "en": "Finished.",
+        "zh": "训练完毕。"
+    },
+    "info_loading": {
+        "en": "Loading model...",
+        "zh": "加载中……"
+    },
+    "info_unloading": {
+        "en": "Unloading model...",
+        "zh": "卸载中……"
+    },
+    "info_loaded": {
+        "en": "Model loaded, now you can chat with your model!",
+        "zh": "模型已加载,可以开始聊天了!"
+    },
+    "info_unloaded": {
+        "en": "Model unloaded.",
+        "zh": "模型已卸载。"
+    }
+}
diff --git a/src/glmtuner/webui/manager.py b/src/glmtuner/webui/manager.py
index 38fd6e3..5cda4c8 100644
--- a/src/glmtuner/webui/manager.py
+++ b/src/glmtuner/webui/manager.py
@@ -1,8 +1,10 @@
 import gradio as gr
-from typing import Dict, List
+from typing import Any, Dict, List
 from gradio.components import Component

+from glmtuner.webui.common import get_model_path, list_dataset, load_config
 from glmtuner.webui.locales import LOCALES
+from glmtuner.webui.utils import get_time


 class Manager:
@@ -10,9 +12,24 @@ class Manager:
     def __init__(self, elem_list: List[Dict[str, Component]]):
         self.elem_list = elem_list

+    def gen_refresh(self) -> Dict[str, Any]:
+        refresh_dict = {
+            "dataset": {"choices": list_dataset()["choices"]},
+            "output_dir": {"value": get_time()}
+        }
+        user_config = load_config()
+        if user_config["last_model"]:
+            refresh_dict["model_name"] = {"value": user_config["last_model"]}
+            refresh_dict["model_path"] = {"value": get_model_path(user_config["last_model"])}
+
+        return refresh_dict
+
     def gen_label(self, lang: str) -> Dict[Component, dict]:
         update_dict = {}
+        refresh_dict = self.gen_refresh()
+
         for elems in self.elem_list:
             for name, component in elems.items():
-                update_dict[component] = gr.update(**LOCALES[name][lang])
+                update_dict[component] = gr.update(**LOCALES[name][lang], **refresh_dict.get(name, {}))
+
         return update_dict
diff --git a/src/glmtuner/webui/runner.py b/src/glmtuner/webui/runner.py
index 1c5b603..80e3a45 100644
--- a/src/glmtuner/webui/runner.py
+++ b/src/glmtuner/webui/runner.py
@@ -6,11 +6,11 @@
 from typing import Optional, Tuple

 from glmtuner.extras.callbacks import LogCallback
-from glmtuner.extras.constants import SUPPORTED_MODELS
 from glmtuner.extras.logging import LoggerHandler
 from glmtuner.extras.misc import torch_gc
 from glmtuner.tuner import get_train_args, run_sft
-from glmtuner.webui.common import get_save_dir
+from glmtuner.webui.common import get_model_path, get_save_dir
+from glmtuner.webui.locales import ALERTS
 from glmtuner.webui.utils import format_info, get_eval_results

@@ -24,24 +24,19 @@ def set_abort(self):
         self.aborted = True
         self.running = False

-    def initialize(self, model_name: str, model_path: str, dataset: list) -> Tuple[str, str, LoggerHandler, LogCallback]:
+    def initialize(self, lang: str, model_name: str, dataset: list) -> Tuple[str, str, LoggerHandler, LogCallback]:
         if self.running:
-            return None, "A process is in running, please abort it firstly.", None, None
+            return None, ALERTS["err_conflict"][lang], None, None

         if not model_name:
-            return None, "Please select a model.", None, None
-
-        if model_path:
-            if not os.path.isdir(model_path):
-                return None, "Cannot find model directory in local disk.", None, None
-            model_name_or_path = model_path
-        elif model_name in SUPPORTED_MODELS: # TODO: use list in gr.State
-            model_name_or_path = SUPPORTED_MODELS[model_name]["hf_path"]
-        else:
-            return None, "Invalid model.", None, None
+            return None, ALERTS["err_no_model"][lang], None, None
+
+        model_name_or_path = get_model_path(model_name)
+        if not model_name_or_path:
+            return None, ALERTS["err_no_path"][lang], None, None

         if len(dataset) == 0:
-            return None, "Please choose datasets.", None, None
+            return None, ALERTS["err_no_dataset"][lang], None, None

         self.aborted = False
         self.running = True
@@ -54,27 +49,29 @@ def initialize(self, model_name: str, model_path: str, dataset: list) -> Tuple[s
         return model_name_or_path, "", logger_handler, trainer_callback

-    def finalize(self, finish_info: Optional[str] = None) -> str:
+    def finalize(self, lang: str, finish_info: Optional[str] = None) -> str:
         self.running = False
         torch_gc()
         if self.aborted:
-            return "Ready"
+            return ALERTS["info_aborted"][lang]
         else:
-            return finish_info if finish_info is not None else "Finished"
+            return finish_info if finish_info is not None else ALERTS["info_finished"][lang]

     def run_train(
-        self, lang, model_name, model_path, checkpoints, finetuning_type,
+        self, lang, model_name, checkpoints, finetuning_type,
         dataset, dataset_dir, learning_rate, num_train_epochs, max_samples,
         fp16, quantization_bit, batch_size, gradient_accumulation_steps,
         lr_scheduler_type, logging_steps, save_steps, output_dir
     ):
-        model_name_or_path, error, logger_handler, trainer_callback = self.initialize(model_name, model_path, dataset)
+        model_name_or_path, error, logger_handler, trainer_callback = self.initialize(lang, model_name, dataset)
         if error:
             yield error
             return

         if checkpoints:
-            checkpoint_dir = ",".join([os.path.join(get_save_dir(model_name), checkpoint) for checkpoint in checkpoints])
+            checkpoint_dir = ",".join(
+                [os.path.join(get_save_dir(model_name), finetuning_type, checkpoint) for checkpoint in checkpoints]
+            )
         else:
             checkpoint_dir = None
@@ -85,7 +82,7 @@ def run_train(
             dataset=",".join(dataset),
             dataset_dir=dataset_dir,
             max_samples=int(max_samples),
-            output_dir=os.path.join(get_save_dir(model_name), output_dir),
+            output_dir=os.path.join(get_save_dir(model_name), finetuning_type, output_dir),
             checkpoint_dir=checkpoint_dir,
             overwrite_cache=True,
             per_device_train_batch_size=batch_size,
@@ -113,27 +110,29 @@ def run_train(
         while thread.is_alive():
             time.sleep(1)
             if self.aborted:
-                yield "Aborted, wait for terminating..."
+                yield ALERTS["info_aborting"][lang]
             else:
                 yield format_info(logger_handler.log, trainer_callback.tracker)

-        yield self.finalize()
+        yield self.finalize(lang)

     def run_eval(
-        self, lang, model_name, model_path, checkpoints, finetuning_type,
+        self, lang, model_name, checkpoints, finetuning_type,
         dataset, dataset_dir, max_samples, batch_size, quantization_bit, predict
     ):
-        model_name_or_path, error, logger_handler, trainer_callback = self.initialize(model_name, model_path, dataset)
+        model_name_or_path, error, logger_handler, trainer_callback = self.initialize(lang, model_name, dataset)
         if error:
             yield error
             return

         if checkpoints:
-            checkpoint_dir = ",".join([os.path.join(get_save_dir(model_name), checkpoint) for checkpoint in checkpoints])
-            output_dir = os.path.join(get_save_dir(model_name), "eval_" + "_".join(checkpoints))
+            checkpoint_dir = ",".join(
+                [os.path.join(get_save_dir(model_name), finetuning_type, checkpoint) for checkpoint in checkpoints]
+            )
+            output_dir = os.path.join(get_save_dir(model_name), finetuning_type, "eval_" + "_".join(checkpoints))
         else:
             checkpoint_dir = None
-            output_dir = os.path.join(get_save_dir(model_name), "eval_base")
+            output_dir = os.path.join(get_save_dir(model_name), finetuning_type, "eval_base")

         args = dict(
             model_name_or_path=model_name_or_path,
@@ -169,8 +168,8 @@ def run_eval(
         while thread.is_alive():
             time.sleep(1)
             if self.aborted:
-                yield "Aborted, wait for terminating..."
+                yield ALERTS["info_aborting"][lang]
             else:
                 yield format_info(logger_handler.log, trainer_callback.tracker)

-        yield self.finalize(get_eval_results(os.path.join(output_dir, "eval_results.json")))
+        yield self.finalize(lang, get_eval_results(os.path.join(output_dir, "eval_results.json")))
diff --git a/src/glmtuner/webui/utils.py b/src/glmtuner/webui/utils.py
index 21b643c..7d4042c 100644
--- a/src/glmtuner/webui/utils.py
+++ b/src/glmtuner/webui/utils.py
@@ -1,11 +1,10 @@
-import json
 import os
-from datetime import datetime
-from typing import Tuple
-
+import json
 import gradio as gr
 import matplotlib.figure
 import matplotlib.pyplot as plt
+from typing import Tuple
+from datetime import datetime

 from glmtuner.extras.ploting import smooth
 from glmtuner.webui.common import get_save_dir, DATA_CONFIG
@@ -52,8 +51,8 @@ def get_eval_results(path: os.PathLike) -> str:
     return "```json\n{}\n```\n".format(result)


-def gen_plot(base_model: str, output_dir: str) -> matplotlib.figure.Figure:
-    log_file = os.path.join(get_save_dir(base_model), output_dir, "trainer_log.jsonl")
+def gen_plot(base_model: str, finetuning_type: str, output_dir: str) -> matplotlib.figure.Figure:
+    log_file = os.path.join(get_save_dir(base_model), finetuning_type, output_dir, "trainer_log.jsonl")
     if not os.path.isfile(log_file):
         return None
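--

Notes on the changes above. The code below is illustrative; anything flagged as hypothetical is not part of the patch.

extras/ploting.py replaces the fixed EMA weight of 0.9 with one derived from the series length: weight = 1.8 * (sigmoid(0.05 * n) - 0.5) for n points, which is about 0.22 at n = 10, about 0.89 at n = 100, and approaches 0.9 from below as n grows, so short loss curves are barely smoothed while long runs keep close to the old behavior. A standalone sketch of the patched function (assuming the loop still updates `last = smoothed_val`, as in the original):

    import math
    from typing import List

    def smooth(scalars: List[float]) -> List[float]:
        # EMA smoothing as in TensorBoard; the weight is a sigmoid of the
        # series length, rising from ~0 toward 0.9.
        last = scalars[0]
        smoothed = []
        weight = 1.8 * (1 / (1 + math.exp(-0.05 * len(scalars))) - 0.5)
        for next_val in scalars:
            smoothed_val = last * weight + (1 - weight) * next_val
            smoothed.append(smoothed_val)
            last = smoothed_val
        return smoothed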
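webui/common.py changes the persisted cache/user.config from a single name/path pair to {"last_model": ..., "path_dict": {model_name: model_path}}, and the new get_model_path() resolves a model name through path_dict first, then SUPPORTED_MODELS, then the empty string. A hypothetical session (the local path is made up):

    from glmtuner.webui.common import get_model_path, load_config, save_config

    # remember a custom local checkout for a known model
    save_config("ChatGLM2-6B", "/data/models/chatglm2-6b")  # hypothetical path

    # path_dict entries take precedence over the default hub id
    assert get_model_path("ChatGLM2-6B") == "/data/models/chatglm2-6b"

    # unknown names resolve to "", which callers surface as err_no_path
    assert get_model_path("Unknown-Model") == ""

    print(load_config()["last_model"])  # -> "ChatGLM2-6B"

One caveat worth checking: a user.config written by v0.1.1 (which stored model_name and model_path keys) still parses as valid JSON, so load_config() does not fall back to the new defaults and the path_dict lookups look like they would raise KeyError; deleting the stale cache file avoids this.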
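Output paths gain a finetuning_type level throughout: WebChatModel.load_model(), Runner.run_train(), Runner.run_eval(), list_checkpoint() and gen_plot() all use saves/<model_name>/<finetuning_type>/<output_dir> now. For example ("test-run" is a made-up output_dir):

    import os
    from glmtuner.webui.common import get_save_dir

    # v0.1.1 layout: saves/ChatGLM2-6B/test-run
    # v0.1.2 layout: saves/ChatGLM2-6B/lora/test-run
    run_dir = os.path.join(get_save_dir("ChatGLM2-6B"), "lora", "test-run")
    log_file = os.path.join(run_dir, "trainer_log.jsonl")  # read by gen_plot()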
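webui/manager.py makes gen_label() merge the static locale strings with the runtime state from gen_refresh() in a single gr.update() per component, so switching the language also restores the last used model, re-lists the datasets, and stamps a fresh output_dir. The merge is plain keyword-argument unpacking (values below are illustrative):

    import gradio as gr

    locale = {"label": "Model path"}          # from LOCALES[name][lang]
    refresh = {"value": "THUDM/chatglm2-6b"}  # from refresh_dict.get(name, {})
    update = gr.update(**locale, **refresh)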