Skip to content

Commit

Permalink
Add loading state
Browse files Browse the repository at this point in the history
  • Loading branch information
kaarthik108 committed Feb 7, 2024
1 parent 3b269d3 commit 31e4eee
Show file tree
Hide file tree
Showing 2 changed files with 24 additions and 21 deletions.
3 changes: 2 additions & 1 deletion main.py
Original file line number Diff line number Diff line change
Expand Up @@ -135,9 +135,10 @@ def execute_sql(query, conn, retries=2):
and st.session_state["messages"][-1]["role"] != "assistant"
):
user_input_content = st.session_state["messages"][-1]["content"]
# print(f"User input content is: {user_input_content}")

if isinstance(user_input_content, str):
callback_handler.start_loading_message()

result = chain.invoke(
{
"question": user_input_content,
Expand Down
42 changes: 22 additions & 20 deletions utils/snowchat_ui.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,18 +84,38 @@ def message_func(text, is_user=False, is_df=False):

class StreamlitUICallbackHandler(BaseCallbackHandler):
def __init__(self):
    """Stream LLM tokens into a single Streamlit placeholder.

    State is reset between runs by `on_llm_end`.
    """
    self.token_buffer = []  # tokens accumulated for the current run
    self.placeholder = st.empty()  # one slot, re-rendered on every token
    self.has_streaming_started = False
    self.has_streaming_ended = False

def start_loading_message(self):
    """Render a provisional 'Thinking...' bubble before any tokens arrive."""
    self.placeholder.markdown(
        self._get_bot_message_container("Thinking..."),
        unsafe_allow_html=True,
    )

def on_llm_new_token(self, token, run_id, parent_run_id=None, **kwargs):
    """Append a streamed token and re-render the full message so far."""
    # Setting the flag unconditionally is equivalent to the guarded form:
    # it only ever transitions False -> True during a run.
    self.has_streaming_started = True
    self.token_buffer.append(token)
    self.placeholder.markdown(
        self._get_bot_message_container("".join(self.token_buffer)),
        unsafe_allow_html=True,
    )

def on_llm_end(self, response, run_id, parent_run_id=None, **kwargs):
    """Mark streaming as finished and drop the accumulated tokens."""
    self.has_streaming_started = False
    self.has_streaming_ended = True
    self.token_buffer = []

def _get_bot_message_container(self, text):
"""Generate the bot's message container style for the given text."""
avatar_url = "https://avataaars.io/?avatarStyle=Transparent&topType=WinterHat2&accessoriesType=Kurt&hatColor=Blue01&facialHairType=MoustacheMagnum&facialHairColor=Blonde&clotheType=Overall&clotheColor=Gray01&eyeType=WinkWacky&eyebrowType=SadConcernedNatural&mouthType=Sad&skinColor=Light"
message_alignment = "flex-start"
message_bg_color = "#71797E"
avatar_class = "bot-avatar"
formatted_text = format_message(text)
formatted_text = format_message(
text
) # Ensure this handles "Thinking..." appropriately.
container_content = f"""
<div style="display: flex; align-items: center; margin-bottom: 10px; justify-content: {message_alignment};">
<img src="{avatar_url}" class="{avatar_class}" alt="avatar" style="width: 50px; height: 50px;" />
Expand All @@ -105,17 +125,6 @@ def _get_bot_message_container(self, text):
"""
return container_content

def on_llm_new_token(self, token, run_id, parent_run_id=None, **kwargs):
    """Accumulate the streamed token and refresh the Streamlit placeholder
    with the complete message received so far."""
    self.token_buffer.append(token)
    rendered = self._get_bot_message_container("".join(self.token_buffer))
    self.placeholder.markdown(rendered, unsafe_allow_html=True)

def display_dataframe(self, df):
"""
Display the dataframe in Streamlit UI within the chat container.
Expand All @@ -134,12 +143,5 @@ def display_dataframe(self, df):
)
st.write(df)

def on_llm_end(self, response, run_id, parent_run_id=None, **kwargs):
    """Flag that streaming has finished and clear the token buffer."""
    self.has_streaming_ended = True
    self.token_buffer = []

def __call__(self, *args, **kwargs):
    """Intentional no-op: invoking the handler directly has no effect."""

0 comments on commit 31e4eee

Please sign in to comment.