From 0fccf308be6814c667b13f343796921a63f2ecb9 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E6=9D=A8=E5=9F=B9=E6=96=87?= <915505626@qq.com> Date: Sun, 23 Apr 2023 14:25:02 +0800 Subject: [PATCH 01/10] Update app.py --- app.py | 24 +++++++++++++++++++----- 1 file changed, 19 insertions(+), 5 deletions(-) diff --git a/app.py b/app.py index f03c3b6..877097a 100644 --- a/app.py +++ b/app.py @@ -28,14 +28,17 @@ predictor = ChatGLM(model_name) -def revise(history, latest_message): +def revise(history, latest_message, last_state): history[-1] = (history[-1][0], latest_message) + last_state[0] = history + last_state[2] = latest_message return history, '' def revoke(history): if len(history) >= 1: history.pop() + last_state[0] = history return history @@ -43,6 +46,13 @@ def interrupt(allow_generate): allow_generate[0] = False +def regenerate(last_state, max_length, top_p, temperature, allow_generate): + history, query, continue_message = last_state + for x in predictor.predict_continue(query, continue_message, max_length, top_p, + temperature, allow_generate, history, last_state): + yield x + + # 搭建 UI 界面 with gr.Blocks(css=""".message { width: inherit !important; @@ -79,20 +89,24 @@ def interrupt(allow_generate): revise_btn = gr.Button("修订") revoke_btn = gr.Button("撤回") interrupt_btn = gr.Button("终止生成") + regenerate_btn = gr.Button("重新生成") history = gr.State([]) allow_generate = gr.State([True]) blank_input = gr.State("") + last_state = gr.State([[], '', '']) # history, query, continue_message generate_button.click( predictor.predict_continue, - inputs=[query, blank_input, max_length, top_p, temperature, allow_generate, history], + inputs=[query, blank_input, max_length, top_p, temperature, allow_generate, history, last_state], outputs=[chatbot, query]) - revise_btn.click(revise, inputs=[history, revise_message], outputs=[chatbot, revise_message]) - revoke_btn.click(revoke, inputs=[history], outputs=[chatbot]) + revise_btn.click(revise, inputs=[history, revise_message, last_state], outputs=[chatbot, revise_message]) + revoke_btn.click(revoke, inputs=[history, last_state], outputs=[chatbot]) continue_btn.click( predictor.predict_continue, - inputs=[query, continue_message, max_length, top_p, temperature, allow_generate, history], + inputs=[query, continue_message, max_length, top_p, temperature, allow_generate, history, last_state], outputs=[chatbot, query, continue_message]) interrupt_btn.click(interrupt, inputs=[allow_generate]) + regenerate_btn.click(regenerate, inputs=[last_state, max_length, top_p, temperature, allow_generate], + outputs=[chatbot, query, continue_message]) demo.queue(concurrency_count=4).launch(server_name='0.0.0.0', server_port=7860, share=False, inbrowser=False) demo.close() From b3e4414c4c4402fe15f3d0ab690484ae5d4e5204 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E6=9D=A8=E5=9F=B9=E6=96=87?= <915505626@qq.com> Date: Sun, 23 Apr 2023 14:25:06 +0800 Subject: [PATCH 02/10] Update base.py --- predictors/base.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/predictors/base.py b/predictors/base.py index 9ce8c45..56948cf 100644 --- a/predictors/base.py +++ b/predictors/base.py @@ -1,3 +1,4 @@ +import copy from abc import ABC, abstractmethod @@ -27,8 +28,11 @@ def stream_chat_continue(self, *args, **kwargs): raise NotImplementedError def predict_continue(self, query, latest_message, max_length, top_p, - temperature, allow_generate, history, *args, + temperature, allow_generate, history, last_state, *args, **kwargs): + last_state[0] = copy.deepcopy(history) + last_state[1] = query + last_state[2] = latest_message if history is None: history = [] allow_generate[0] = True From 5ef44668750c316e3acc3da552e6625f1709c36d Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E6=9D=A8=E5=9F=B9=E6=96=87?= <915505626@qq.com> Date: Sun, 23 Apr 2023 14:28:19 +0800 Subject: [PATCH 03/10] Update app.py --- app.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/app.py b/app.py index 877097a..715061b 100644 --- a/app.py +++ b/app.py @@ -35,7 +35,7 @@ def revise(history, latest_message, last_state): return history, '' -def revoke(history): +def revoke(history, last_state): if len(history) >= 1: history.pop() last_state[0] = history From 2d0c373f4539f64880e5cc6e878b6196308aa861 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E6=9D=A8=E5=9F=B9=E6=96=87?= <915505626@qq.com> Date: Sun, 23 Apr 2023 14:29:15 +0800 Subject: [PATCH 04/10] Update app.py --- app.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/app.py b/app.py index 715061b..4003e16 100644 --- a/app.py +++ b/app.py @@ -71,7 +71,7 @@ def regenerate(last_state, max_length, top_p, temperature, allow_generate): """) with gr.Row(): with gr.Column(scale=4): - chatbot = gr.Chatbot(elem_id="chat-box", show_label=False).style(height=800) + chatbot = gr.Chatbot(elem_id="chat-box", show_label=False).style(height=900) with gr.Column(scale=1): with gr.Row(): max_length = gr.Slider(32, 4096, value=2048, step=1.0, label="Maximum length", interactive=True) @@ -105,8 +105,9 @@ def regenerate(last_state, max_length, top_p, temperature, allow_generate): predictor.predict_continue, inputs=[query, continue_message, max_length, top_p, temperature, allow_generate, history, last_state], outputs=[chatbot, query, continue_message]) - interrupt_btn.click(interrupt, inputs=[allow_generate]) regenerate_btn.click(regenerate, inputs=[last_state, max_length, top_p, temperature, allow_generate], outputs=[chatbot, query, continue_message]) + interrupt_btn.click(interrupt, inputs=[allow_generate]) + demo.queue(concurrency_count=4).launch(server_name='0.0.0.0', server_port=7860, share=False, inbrowser=False) demo.close() From 9a87e392909aeb0932df715189e7bd5a57f1e91e Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E6=9D=A8=E5=9F=B9=E6=96=87?= <915505626@qq.com> Date: Sun, 23 Apr 2023 15:09:08 +0800 Subject: [PATCH 05/10] Update app.py --- app.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/app.py b/app.py index 4003e16..c5deeb1 100644 --- a/app.py +++ b/app.py @@ -31,7 +31,8 @@ def revise(history, latest_message, last_state): history[-1] = (history[-1][0], latest_message) last_state[0] = history - last_state[2] = latest_message + last_state[1] = '' + last_state[2] = '' return history, '' @@ -39,6 +40,8 @@ def revoke(history, last_state): if len(history) >= 1: history.pop() last_state[0] = history + last_state[1] = '' + last_state[2] = '' return history From 1888a2200c4de41f632f1e94db1dac9ce2809ee4 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E6=9D=A8=E5=9F=B9=E6=96=87?= <915505626@qq.com> Date: Sun, 23 Apr 2023 15:13:55 +0800 Subject: [PATCH 06/10] Update app.py --- app.py | 5 +---- 1 file changed, 1 insertion(+), 4 deletions(-) diff --git a/app.py b/app.py index c5deeb1..4f8038d 100644 --- a/app.py +++ b/app.py @@ -28,11 +28,8 @@ predictor = ChatGLM(model_name) -def revise(history, latest_message, last_state): +def revise(history, latest_message): history[-1] = (history[-1][0], latest_message) - last_state[0] = history - last_state[1] = '' - last_state[2] = '' return history, '' From 7b8426a56e23ff7521adfdb9caad304c69191526 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E6=9D=A8=E5=9F=B9=E6=96=87?= <915505626@qq.com> Date: Sun, 23 Apr 2023 15:23:54 +0800 Subject: [PATCH 07/10] Update app.py --- app.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/app.py b/app.py index 4f8038d..134fccf 100644 --- a/app.py +++ b/app.py @@ -48,6 +48,8 @@ def interrupt(allow_generate): def regenerate(last_state, max_length, top_p, temperature, allow_generate): history, query, continue_message = last_state + if len(query) == 0: + raise gr.Error("Please input a query first.") for x in predictor.predict_continue(query, continue_message, max_length, top_p, temperature, allow_generate, history, last_state): yield x From 288c39d8aa80475d151693845e7a6c2459e2e13b Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E6=9D=A8=E5=9F=B9=E6=96=87?= <915505626@qq.com> Date: Sun, 23 Apr 2023 15:28:12 +0800 Subject: [PATCH 08/10] Update app.py --- app.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/app.py b/app.py index 134fccf..5d3e9cd 100644 --- a/app.py +++ b/app.py @@ -49,7 +49,7 @@ def interrupt(allow_generate): def regenerate(last_state, max_length, top_p, temperature, allow_generate): history, query, continue_message = last_state if len(query) == 0: - raise gr.Error("Please input a query first.") + print("Please input a query first.") for x in predictor.predict_continue(query, continue_message, max_length, top_p, temperature, allow_generate, history, last_state): yield x From 9b0b99fa332719cc62b2e69ea40205fc0c5fc0c9 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E6=9D=A8=E5=9F=B9=E6=96=87?= <915505626@qq.com> Date: Sun, 23 Apr 2023 15:29:39 +0800 Subject: [PATCH 09/10] Update app.py --- app.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/app.py b/app.py index 5d3e9cd..f945a08 100644 --- a/app.py +++ b/app.py @@ -50,6 +50,7 @@ def regenerate(last_state, max_length, top_p, temperature, allow_generate): history, query, continue_message = last_state if len(query) == 0: print("Please input a query first.") + return for x in predictor.predict_continue(query, continue_message, max_length, top_p, temperature, allow_generate, history, last_state): yield x @@ -90,8 +91,8 @@ def regenerate(last_state, max_length, top_p, temperature, allow_generate): show_label=False, placeholder="Revise message", lines=2).style(container=False) revise_btn = gr.Button("修订") revoke_btn = gr.Button("撤回") - interrupt_btn = gr.Button("终止生成") regenerate_btn = gr.Button("重新生成") + interrupt_btn = gr.Button("终止生成") history = gr.State([]) allow_generate = gr.State([True]) @@ -101,7 +102,7 @@ def regenerate(last_state, max_length, top_p, temperature, allow_generate): predictor.predict_continue, inputs=[query, blank_input, max_length, top_p, temperature, allow_generate, history, last_state], outputs=[chatbot, query]) - revise_btn.click(revise, inputs=[history, revise_message, last_state], outputs=[chatbot, revise_message]) + revise_btn.click(revise, inputs=[history, revise_message], outputs=[chatbot, revise_message]) revoke_btn.click(revoke, inputs=[history, last_state], outputs=[chatbot]) continue_btn.click( predictor.predict_continue, From c398eeaad0226a03f9249f0226a45d5f6c31cf37 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E6=9D=A8=E5=9F=B9=E6=96=87?= <915505626@qq.com> Date: Sun, 23 Apr 2023 15:32:49 +0800 Subject: [PATCH 10/10] Update app.py --- app.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/app.py b/app.py index f945a08..7a4c788 100644 --- a/app.py +++ b/app.py @@ -74,7 +74,7 @@ def regenerate(last_state, max_length, top_p, temperature, allow_generate): """) with gr.Row(): with gr.Column(scale=4): - chatbot = gr.Chatbot(elem_id="chat-box", show_label=False).style(height=900) + chatbot = gr.Chatbot(elem_id="chat-box", show_label=False).style(height=850) with gr.Column(scale=1): with gr.Row(): max_length = gr.Slider(32, 4096, value=2048, step=1.0, label="Maximum length", interactive=True)