Skip to content
This repository has been archived by the owner on Oct 12, 2023. It is now read-only.

Commit

Permalink
release v0.1.2
Browse files Browse the repository at this point in the history
  • Loading branch information
hiyouga committed Jul 17, 2023
1 parent 7e6a5eb commit 2578d15
Show file tree
Hide file tree
Showing 14 changed files with 214 additions and 144 deletions.
2 changes: 1 addition & 1 deletion src/glmtuner/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,4 +4,4 @@
from glmtuner.webui import create_ui


__version__ = "0.1.1"
__version__ = "0.1.2"
10 changes: 4 additions & 6 deletions src/glmtuner/extras/constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,11 +6,9 @@

LAYERNORM_NAMES = ["layernorm"]

METHODS = ["full", "freeze", "p_tuning", "lora"]

SUPPORTED_MODELS = {
"ChatGLM-6B": {
"hf_path": "THUDM/chatglm-6b"
},
"ChatGLM2-6B": {
"hf_path": "THUDM/chatglm2-6b"
}
"ChatGLM-6B": "THUDM/chatglm-6b",
"ChatGLM2-6B": "THUDM/chatglm2-6b"
}
4 changes: 3 additions & 1 deletion src/glmtuner/extras/ploting.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import os
import math
import json
import matplotlib.pyplot as plt
from typing import List, Optional
Expand All @@ -10,12 +11,13 @@
logger = get_logger(__name__)


def smooth(scalars: List[float], weight: Optional[float] = 0.9) -> List[float]:
def smooth(scalars: List[float]) -> List[float]:
r"""
EMA implementation according to TensorBoard.
"""
last = scalars[0]
smoothed = list()
weight = 1.8 * (1 / (1 + math.exp(-0.05 * len(scalars))) - 0.5) # a sigmoid function
for next_val in scalars:
smoothed_val = last * weight + (1 - weight) * next_val
smoothed.append(smoothed_val)
Expand Down
36 changes: 17 additions & 19 deletions src/glmtuner/webui/chat.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,11 +2,11 @@
from typing import List, Tuple

from glmtuner.chat.stream_chat import ChatModel
from glmtuner.extras.constants import SUPPORTED_MODELS
from glmtuner.extras.misc import torch_gc
from glmtuner.hparams import GeneratingArguments
from glmtuner.tuner import get_infer_args
from glmtuner.webui.common import get_save_dir
from glmtuner.webui.common import get_model_path, get_save_dir
from glmtuner.webui.locales import ALERTS


class WebChatModel(ChatModel):
Expand All @@ -17,32 +17,30 @@ def __init__(self):
self.generating_args = GeneratingArguments()

def load_model(
self, lang: str, model_name: str, model_path: str, checkpoints: list,
self, lang: str, model_name: str, checkpoints: list,
finetuning_type: str, quantization_bit: str
):
if self.model is not None:
yield "You have loaded a model, please unload it first."
yield ALERTS["err_exists"][lang]
return

if not model_name:
yield "Please select a model."
yield ALERTS["err_no_model"][lang]
return

if model_path:
if not os.path.isdir(model_path):
return None, "Cannot find model directory in local disk.", None, None
model_name_or_path = model_path
elif model_name in SUPPORTED_MODELS: # TODO: use list in gr.State
model_name_or_path = SUPPORTED_MODELS[model_name]["hf_path"]
else:
return None, "Invalid model.", None, None
model_name_or_path = get_model_path(model_name)
if not model_name_or_path:
yield ALERTS["err_no_path"][lang]
return

if checkpoints:
checkpoint_dir = ",".join([os.path.join(get_save_dir(model_name), checkpoint) for checkpoint in checkpoints])
checkpoint_dir = ",".join(
[os.path.join(get_save_dir(model_name), finetuning_type, checkpoint) for checkpoint in checkpoints]
)
else:
checkpoint_dir = None

yield "Loading model..."
yield ALERTS["info_loading"][lang]
args = dict(
model_name_or_path=model_name_or_path,
finetuning_type=finetuning_type,
Expand All @@ -51,14 +49,14 @@ def load_model(
)
super().__init__(*get_infer_args(args))

yield "Model loaded, now you can chat with your model."
yield ALERTS["info_loaded"][lang]

def unload_model(self, lang: str):
    """Release the loaded model and tokenizer, yielding localized status messages.

    Args:
        lang: UI language key used to pick the localized alert strings.

    Yields:
        The "unloading" message first, then the "unloaded" message once GPU
        memory has been reclaimed.
    """
    yield ALERTS["info_unloading"][lang]
    # Drop both references so the weights become garbage-collectable.
    self.model, self.tokenizer = None, None
    torch_gc()
    yield ALERTS["info_unloaded"][lang]

def predict(
self,
Expand Down
50 changes: 25 additions & 25 deletions src/glmtuner/webui/common.py
Original file line number Diff line number Diff line change
@@ -1,54 +1,54 @@
import json
import os
from typing import Any, Dict, List, Optional, Union
from typing import Any, Dict, Optional

import gradio as gr
from peft.utils import WEIGHTS_NAME as PEFT_WEIGHTS_NAME
from transformers.trainer import WEIGHTS_NAME, WEIGHTS_INDEX_NAME

from glmtuner.extras.constants import SUPPORTED_MODELS


DEFAULT_CACHE_DIR = "cache"
DEFAULT_DATA_DIR = "data"
DEFAULT_SAVE_DIR = "saves"
USER_CONFIG = "user.config"
DATA_CONFIG = "dataset_info.json"
METHODS = ["full", "freeze", "p_tuning", "lora"]


def get_config_path():
return os.path.join(DEFAULT_CACHE_DIR, USER_CONFIG)
def get_save_dir(model_name: str) -> str:
    """Return the checkpoint save directory for *model_name*.

    Only the final path component of the name is kept, so hub-style ids
    such as ``"THUDM/chatglm-6b"`` map to ``<DEFAULT_SAVE_DIR>/chatglm-6b``.
    """
    model_folder = os.path.basename(model_name)  # same as os.path.split(...)[-1]
    return os.path.join(DEFAULT_SAVE_DIR, model_folder)


def load_config() -> Dict[str, str]:
if not os.path.exists(get_config_path()):
return {}
def get_config_path() -> str:
    """Return the path of the user-config file inside the cache directory.

    Annotated ``str`` (not ``os.PathLike``) because ``os.path.join`` returns
    a plain string and ``str`` does not implement the ``os.PathLike`` protocol.
    """
    return os.path.join(DEFAULT_CACHE_DIR, USER_CONFIG)


with open(get_config_path(), "r", encoding="utf-8") as f:
try:
user_config = json.load(f)
return user_config
except:
return {}
def load_config() -> Dict[str, Any]:
    """Load the persisted user config from disk.

    Returns:
        The parsed JSON mapping; when the file is missing or unparsable,
        a fresh default config with an empty ``last_model`` and ``path_dict``.
    """
    try:
        with open(get_config_path(), "r", encoding="utf-8") as f:
            return json.load(f)
    # Narrowed from a bare `except:` so KeyboardInterrupt/SystemExit are not
    # swallowed; OSError covers a missing/unreadable file, JSONDecodeError a
    # corrupt one.
    except (OSError, json.JSONDecodeError):
        return {"last_model": "", "path_dict": {}}


def save_config(model_name: str, model_path: str) -> None:
    """Persist the last-used model name and its local path to the user config.

    Args:
        model_name: Display name of the model (becomes ``last_model``).
        model_path: Local directory or hub path recorded under ``path_dict``.
    """
    os.makedirs(DEFAULT_CACHE_DIR, exist_ok=True)
    user_config = load_config()
    user_config["last_model"] = model_name
    # setdefault guards against a hand-edited config file that parses as JSON
    # but lacks the "path_dict" key (load_config only injects it on failure).
    user_config.setdefault("path_dict", {})[model_name] = model_path
    with open(get_config_path(), "w", encoding="utf-8") as f:
        json.dump(user_config, f, indent=2, ensure_ascii=False)


def list_models(model_list: list) -> dict: # TODO: unused
return gr.update(value="", choices=[name for name, _ in model_list])
def get_model_path(model_name: str) -> str:
    """Resolve *model_name* to a usable path.

    Preference order: the user-configured local path, then the built-in
    hub path from SUPPORTED_MODELS, then the empty string.
    """
    path_dict = load_config()["path_dict"]
    if model_name in path_dict:
        return path_dict[model_name]
    return SUPPORTED_MODELS.get(model_name, "")


def list_checkpoints(model_name: str) -> dict:
def list_checkpoint(model_name: str, finetuning_type: str) -> Dict[str, Any]:
checkpoints = []
save_dir = get_save_dir(model_name)
save_dir = os.path.join(get_save_dir(model_name), finetuning_type)
if save_dir and os.path.isdir(save_dir):
for checkpoint in os.listdir(save_dir):
if (
Expand All @@ -70,6 +70,6 @@ def load_dataset_info(dataset_dir: str) -> Dict[str, Any]:
return {}


def list_dataset(dataset_dir: Optional[str] = None) -> Dict[str, Any]:
    """Build a dropdown update listing the datasets found in *dataset_dir*.

    Falls back to DEFAULT_DATA_DIR when no directory is given; the current
    selection is cleared (``value=[]``) on every refresh.
    """
    target_dir = DEFAULT_DATA_DIR if dataset_dir is None else dataset_dir
    dataset_info = load_dataset_info(target_dir)
    return gr.update(value=[], choices=list(dataset_info.keys()))
6 changes: 3 additions & 3 deletions src/glmtuner/webui/components/data.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,16 +4,16 @@
from typing import Tuple


def create_preview_box() -> Tuple[Block, Component, Component, Component]:
    """Create the hidden dataset-preview modal.

    Returns:
        The modal container, the sample-count number field, the JSON sample
        viewer, and the close button (returned so callers can localize it).
    """
    # NOTE(review): nesting reconstructed from a diff paste with stripped
    # indentation — assumes all widgets live inside the Box and the click
    # handler is registered afterwards; verify against the original file.
    with gr.Box(visible=False, elem_classes="modal-box") as preview_box:
        with gr.Row():
            preview_count = gr.Number(interactive=False)

        with gr.Row():
            preview_samples = gr.JSON(interactive=False)

        close_btn = gr.Button()

    close_btn.click(lambda: gr.update(visible=False), outputs=[preview_box])

    return preview_box, preview_count, preview_samples, close_btn
17 changes: 8 additions & 9 deletions src/glmtuner/webui/components/eval.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,22 +2,21 @@
import gradio as gr
from gradio.components import Component

from glmtuner.webui.common import list_datasets, DEFAULT_DATA_DIR, METHODS
from glmtuner.webui.common import list_dataset, DEFAULT_DATA_DIR
from glmtuner.webui.components.data import create_preview_box
from glmtuner.webui.runner import Runner
from glmtuner.webui.utils import can_preview, get_preview


def create_eval_tab(top_elems: Dict[str, Component], runner: Runner) -> Dict[str, Component]:
with gr.Row():
finetuning_type = gr.Dropdown(value="lora", choices=METHODS, interactive=True, scale=1)
dataset_dir = gr.Textbox(value=DEFAULT_DATA_DIR, interactive=True, scale=1)
dataset = gr.Dropdown(choices=list_datasets(), multiselect=True, interactive=True, scale=4)
dataset_dir = gr.Textbox(value=DEFAULT_DATA_DIR, interactive=True, scale=2)
dataset = gr.Dropdown(multiselect=True, interactive=True, scale=4)
preview_btn = gr.Button(interactive=False, scale=1)

preview_box, preview_count, preview_samples = create_preview_box()
preview_box, preview_count, preview_samples, close_btn = create_preview_box()

dataset_dir.change(list_datasets, [dataset_dir], [dataset])
dataset_dir.change(list_dataset, [dataset_dir], [dataset])
dataset.change(can_preview, [dataset_dir, dataset], [preview_btn])
preview_btn.click(get_preview, [dataset_dir, dataset], [preview_count, preview_samples, preview_box])

Expand All @@ -36,20 +35,20 @@ def create_eval_tab(top_elems: Dict[str, Component], runner: Runner) -> Dict[str
start_btn.click(
runner.run_eval,
[
top_elems["lang"], top_elems["model_name"], top_elems["model_path"], top_elems["checkpoints"],
finetuning_type, dataset, dataset_dir, max_samples, batch_size, quantization_bit, predict
top_elems["lang"], top_elems["model_name"], top_elems["checkpoints"], top_elems["finetuning_type"],
dataset, dataset_dir, max_samples, batch_size, quantization_bit, predict
],
[output_box]
)
stop_btn.click(runner.set_abort, queue=False)

return dict(
finetuning_type=finetuning_type,
dataset_dir=dataset_dir,
dataset=dataset,
preview_btn=preview_btn,
preview_count=preview_count,
preview_samples=preview_samples,
close_btn=close_btn,
max_samples=max_samples,
batch_size=batch_size,
quantization_bit=quantization_bit,
Expand Down
21 changes: 8 additions & 13 deletions src/glmtuner/webui/components/infer.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,45 +4,40 @@
from gradio.components import Component

from glmtuner.webui.chat import WebChatModel
from glmtuner.webui.common import METHODS
from glmtuner.webui.components.chatbot import create_chat_box


def create_infer_tab(top_elems: Dict[str, Component]) -> Dict[str, Component]:
with gr.Row():
load_btn = gr.Button()
unload_btn = gr.Button()
quantization_bit = gr.Dropdown([8, 4])

info_box = gr.Markdown()

chat_model = WebChatModel()
chat_box, chatbot, history, chat_elems = create_chat_box(chat_model)

with gr.Row():
finetuning_type = gr.Dropdown(value="lora", choices=METHODS, interactive=True, scale=1)
quantization_bit = gr.Dropdown([8, 4], label="Quantization bit", info="Quantize model to 4/8-bit mode.")

with gr.Row():
load_btn = gr.Button()
unload_btn = gr.Button("Unload model")

load_btn.click(
chat_model.load_model,
[
top_elems["lang"], top_elems["model_name"], top_elems["model_path"], top_elems["checkpoints"],
finetuning_type, quantization_bit
top_elems["lang"], top_elems["model_name"], top_elems["checkpoints"], top_elems["finetuning_type"],
quantization_bit
],
[info_box]
).then(
lambda: gr.update(visible=(chat_model.model is not None)), outputs=[chat_box]
)

unload_btn.click(
chat_model.unload_model, outputs=[info_box]
chat_model.unload_model, [top_elems["lang"]], [info_box]
).then(
lambda: ([], []), outputs=[chatbot, history]
).then(
lambda: gr.update(visible=(chat_model.model is not None)), outputs=[chat_box]
)

return dict(
finetuning_type=finetuning_type,
quantization_bit=quantization_bit,
info_box=info_box,
load_btn=load_btn,
Expand Down
Loading

0 comments on commit 2578d15

Please sign in to comment.