From a9dbbe6d68dfb62da878fa735684bdd201f29fec Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?=E6=9D=A8=E5=9F=B9=E6=96=87=20=28Yang=20Peiwen=29?= <915505626@qq.com>
Date: Sat, 1 Apr 2023 23:12:08 +0800
Subject: [PATCH] Fix chatglm (#13)

* Update chatglm2.py

* Update app.py

* parse_codeblock
---
 app.py                                 |  4 ++--
 predictors/base.py                     | 14 ++++++++++++++
 predictors/{chatglm.py => chatglm2.py} |  9 +++++----
 3 files changed, 21 insertions(+), 6 deletions(-)
 rename predictors/{chatglm.py => chatglm2.py} (92%)

diff --git a/app.py b/app.py
index 4ab3f21..1800fe9 100644
--- a/app.py
+++ b/app.py
@@ -14,7 +14,7 @@
 # model_name = 'BelleGroup/BELLE-LLAMA-7B-2M-gptq'
 
 if 'chatglm' in model_name.lower():
-    from predictors.chatglm import ChatGLM
+    from predictors.chatglm2 import ChatGLM
     predictor = ChatGLM(model_name)
 elif 'gptq' in model_name.lower():
     from predictors.llama_gptq import LLaMaGPTQ
@@ -26,7 +26,7 @@
     from predictors.debug import Debug
     predictor = Debug(model_name)
 else:
-    from predictors.chatglm import ChatGLM
+    from predictors.chatglm2 import ChatGLM
     predictor = ChatGLM(model_name)
 
 
diff --git a/predictors/base.py b/predictors/base.py
index 2cd4ea0..9ce8c45 100644
--- a/predictors/base.py
+++ b/predictors/base.py
@@ -1,6 +1,20 @@
 from abc import ABC, abstractmethod
 
 
+def parse_codeblock(text):
+    lines = text.split("\n")
+    for i, line in enumerate(lines):
+        if "```" in line:
+            if line != "```":
+                lines[i] = f'<pre><code class="{lines[i][3:]}">'
+            else:
+                lines[i] = '</code></pre>'
+        else:
+            if i > 0:
+                lines[i] = "<br/>" + line.replace("<", "&lt;").replace(">", "&gt;")
+    return "".join(lines)
+
+
 class BasePredictor(ABC):
 
     @abstractmethod
diff --git a/predictors/chatglm.py b/predictors/chatglm2.py
similarity index 92%
rename from predictors/chatglm.py
rename to predictors/chatglm2.py
index fa6b90a..a1d6012 100644
--- a/predictors/chatglm.py
+++ b/predictors/chatglm2.py
@@ -5,7 +5,8 @@
 from transformers import AutoModel, AutoTokenizer
 from transformers import LogitsProcessor, LogitsProcessorList
 
-from predictors.base import BasePredictor
+from predictors.base import BasePredictor, parse_codeblock
+from chatglm.modeling_chatglm import ChatGLMForConditionalGeneration
 
 
 class InvalidScoreLogitsProcessor(LogitsProcessor):
@@ -27,7 +28,7 @@ def __init__(self, model_name):
         self.tokenizer = AutoTokenizer.from_pretrained(
             model_name, trust_remote_code=True, resume_download=True)
         if 'int4' not in model_name:
-            model = AutoModel.from_pretrained(
+            model = ChatGLMForConditionalGeneration.from_pretrained(
                 model_name,
                 trust_remote_code=True,
                 resume_download=True,
@@ -36,7 +37,7 @@ def __init__(self, model_name):
                 device_map={'': self.device}
             )
         else:
-            model = AutoModel.from_pretrained(
+            model = ChatGLMForConditionalGeneration.from_pretrained(
                 model_name,
                 trust_remote_code=True,
                 resume_download=True
@@ -105,4 +106,4 @@ def stream_chat_continue(self,
             outputs = outputs.tolist()[0][input_length:]
             response = tokenizer.decode(outputs)
             response = model.process_response(response)
-            yield response
+            yield parse_codeblock(response)