Skip to content

Commit

Permalink
Regenerate (#26)
Browse files Browse the repository at this point in the history
* Update app.py

* Update base.py

* Update app.py

* Update app.py

* Update app.py

* Update app.py

* Update app.py

* Update app.py

* Update app.py

* Update app.py
  • Loading branch information
ypwhs authored Apr 23, 2023
1 parent e08de4e commit 5098f85
Show file tree
Hide file tree
Showing 2 changed files with 28 additions and 6 deletions.
28 changes: 23 additions & 5 deletions app.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,16 +33,29 @@ def revise(history, latest_message):
return history, ''


def revoke(history):
def revoke(history, last_state):
if len(history) >= 1:
history.pop()
last_state[0] = history
last_state[1] = ''
last_state[2] = ''
return history


def interrupt(allow_generate):
allow_generate[0] = False


def regenerate(last_state, max_length, top_p, temperature, allow_generate):
history, query, continue_message = last_state
if len(query) == 0:
print("Please input a query first.")
return
for x in predictor.predict_continue(query, continue_message, max_length, top_p,
temperature, allow_generate, history, last_state):
yield x


# 搭建 UI 界面
with gr.Blocks(css=""".message {
width: inherit !important;
Expand All @@ -61,7 +74,7 @@ def interrupt(allow_generate):
""")
with gr.Row():
with gr.Column(scale=4):
chatbot = gr.Chatbot(elem_id="chat-box", show_label=False).style(height=800)
chatbot = gr.Chatbot(elem_id="chat-box", show_label=False).style(height=850)
with gr.Column(scale=1):
with gr.Row():
max_length = gr.Slider(32, 4096, value=2048, step=1.0, label="Maximum length", interactive=True)
Expand All @@ -78,21 +91,26 @@ def interrupt(allow_generate):
show_label=False, placeholder="Revise message", lines=2).style(container=False)
revise_btn = gr.Button("修订")
revoke_btn = gr.Button("撤回")
regenerate_btn = gr.Button("重新生成")
interrupt_btn = gr.Button("终止生成")

history = gr.State([])
allow_generate = gr.State([True])
blank_input = gr.State("")
last_state = gr.State([[], '', '']) # history, query, continue_message
generate_button.click(
predictor.predict_continue,
inputs=[query, blank_input, max_length, top_p, temperature, allow_generate, history],
inputs=[query, blank_input, max_length, top_p, temperature, allow_generate, history, last_state],
outputs=[chatbot, query])
revise_btn.click(revise, inputs=[history, revise_message], outputs=[chatbot, revise_message])
revoke_btn.click(revoke, inputs=[history], outputs=[chatbot])
revoke_btn.click(revoke, inputs=[history, last_state], outputs=[chatbot])
continue_btn.click(
predictor.predict_continue,
inputs=[query, continue_message, max_length, top_p, temperature, allow_generate, history],
inputs=[query, continue_message, max_length, top_p, temperature, allow_generate, history, last_state],
outputs=[chatbot, query, continue_message])
regenerate_btn.click(regenerate, inputs=[last_state, max_length, top_p, temperature, allow_generate],
outputs=[chatbot, query, continue_message])
interrupt_btn.click(interrupt, inputs=[allow_generate])

demo.queue(concurrency_count=4).launch(server_name='0.0.0.0', server_port=7860, share=False, inbrowser=False)
demo.close()
6 changes: 5 additions & 1 deletion predictors/base.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import copy
from abc import ABC, abstractmethod


Expand Down Expand Up @@ -27,8 +28,11 @@ def stream_chat_continue(self, *args, **kwargs):
raise NotImplementedError

def predict_continue(self, query, latest_message, max_length, top_p,
temperature, allow_generate, history, *args,
temperature, allow_generate, history, last_state, *args,
**kwargs):
last_state[0] = copy.deepcopy(history)
last_state[1] = query
last_state[2] = latest_message
if history is None:
history = []
allow_generate[0] = True
Expand Down

0 comments on commit 5098f85

Please sign in to comment.