From 5098f8562ca22dd85b800a6f499ab5534f1f037b Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E6=9D=A8=E5=9F=B9=E6=96=87=20=28Yang=20Peiwen=29?= <915505626@qq.com> Date: Sun, 23 Apr 2023 15:58:30 +0800 Subject: [PATCH] Regenerate (#26) * Update app.py * Update base.py * Update app.py * Update app.py * Update app.py * Update app.py * Update app.py * Update app.py * Update app.py * Update app.py --- app.py | 28 +++++++++++++++++++++++----- predictors/base.py | 6 +++++- 2 files changed, 28 insertions(+), 6 deletions(-) diff --git a/app.py b/app.py index f03c3b6..7a4c788 100644 --- a/app.py +++ b/app.py @@ -33,9 +33,12 @@ def revise(history, latest_message): return history, '' -def revoke(history): +def revoke(history, last_state): if len(history) >= 1: history.pop() + last_state[0] = history + last_state[1] = '' + last_state[2] = '' return history @@ -43,6 +46,16 @@ def interrupt(allow_generate): allow_generate[0] = False +def regenerate(last_state, max_length, top_p, temperature, allow_generate): + history, query, continue_message = last_state + if len(query) == 0: + print("Please input a query first.") + return + for x in predictor.predict_continue(query, continue_message, max_length, top_p, + temperature, allow_generate, history, last_state): + yield x + + # 搭建 UI 界面 with gr.Blocks(css=""".message { width: inherit !important; @@ -61,7 +74,7 @@ def interrupt(allow_generate): """) with gr.Row(): with gr.Column(scale=4): - chatbot = gr.Chatbot(elem_id="chat-box", show_label=False).style(height=800) + chatbot = gr.Chatbot(elem_id="chat-box", show_label=False).style(height=850) with gr.Column(scale=1): with gr.Row(): max_length = gr.Slider(32, 4096, value=2048, step=1.0, label="Maximum length", interactive=True) @@ -78,21 +91,26 @@ def interrupt(allow_generate): show_label=False, placeholder="Revise message", lines=2).style(container=False) revise_btn = gr.Button("修订") revoke_btn = gr.Button("撤回") + regenerate_btn = gr.Button("重新生成") interrupt_btn = gr.Button("终止生成") history = gr.State([]) allow_generate = gr.State([True]) blank_input = gr.State("") + last_state = gr.State([[], '', '']) # history, query, continue_message generate_button.click( predictor.predict_continue, - inputs=[query, blank_input, max_length, top_p, temperature, allow_generate, history], + inputs=[query, blank_input, max_length, top_p, temperature, allow_generate, history, last_state], outputs=[chatbot, query]) revise_btn.click(revise, inputs=[history, revise_message], outputs=[chatbot, revise_message]) - revoke_btn.click(revoke, inputs=[history], outputs=[chatbot]) + revoke_btn.click(revoke, inputs=[history, last_state], outputs=[chatbot]) continue_btn.click( predictor.predict_continue, - inputs=[query, continue_message, max_length, top_p, temperature, allow_generate, history], + inputs=[query, continue_message, max_length, top_p, temperature, allow_generate, history, last_state], outputs=[chatbot, query, continue_message]) + regenerate_btn.click(regenerate, inputs=[last_state, max_length, top_p, temperature, allow_generate], + outputs=[chatbot, query, continue_message]) interrupt_btn.click(interrupt, inputs=[allow_generate]) + demo.queue(concurrency_count=4).launch(server_name='0.0.0.0', server_port=7860, share=False, inbrowser=False) demo.close() diff --git a/predictors/base.py b/predictors/base.py index 9ce8c45..56948cf 100644 --- a/predictors/base.py +++ b/predictors/base.py @@ -1,3 +1,4 @@ +import copy from abc import ABC, abstractmethod @@ -27,8 +28,11 @@ def stream_chat_continue(self, *args, **kwargs): raise NotImplementedError def predict_continue(self, query, latest_message, max_length, top_p, - temperature, allow_generate, history, *args, + temperature, allow_generate, history, last_state, *args, **kwargs): + last_state[0] = copy.deepcopy(history) + last_state[1] = query + last_state[2] = latest_message if history is None: history = [] allow_generate[0] = True