Skip to content

Commit

Permalink
vision
Browse files Browse the repository at this point in the history
  • Loading branch information
OlegGoless committed Apr 2, 2024
1 parent 299fb57 commit 3b8b0b2
Show file tree
Hide file tree
Showing 6 changed files with 380 additions and 45 deletions.
196 changes: 189 additions & 7 deletions bot/bot.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@
import database
import openai_utils

import base64

# setup
db = database.Database()
Expand Down Expand Up @@ -177,6 +178,168 @@ async def retry_handle(update: Update, context: CallbackContext):

await message_handle(update, context, message=last_dialog_message["user"], use_new_dialog_timeout=False)

async def _vision_message_handle_fn(
update: Update, context: CallbackContext, use_new_dialog_timeout: bool = True
):
logger.info('_vision_message_handle_fn')
user_id = update.message.from_user.id
current_model = db.get_user_attribute(user_id, "current_model")

if current_model != "gpt-4-vision-preview":
await update.message.reply_text(
"🥲 Images processing is only available for <b>gpt-4-vision-preview</b> model. Please change your settings in /settings",
parse_mode=ParseMode.HTML,
)
return

chat_mode = db.get_user_attribute(user_id, "current_chat_mode")

# new dialog timeout
if use_new_dialog_timeout:
if (datetime.now() - db.get_user_attribute(user_id, "last_interaction")).seconds > config.new_dialog_timeout and len(db.get_dialog_messages(user_id)) > 0:
db.start_new_dialog(user_id)
await update.message.reply_text(f"Starting new dialog due to timeout (<b>{config.chat_modes[chat_mode]['name']}</b> mode) ✅", parse_mode=ParseMode.HTML)
db.set_user_attribute(user_id, "last_interaction", datetime.now())

buf = None
if update.message.effective_attachment:
photo = update.message.effective_attachment[-1]
photo_file = await context.bot.get_file(photo.file_id)

# store file in memory, not on disk
buf = io.BytesIO()
await photo_file.download_to_memory(buf)
buf.name = "image.jpg" # file extension is required
buf.seek(0) # move cursor to the beginning of the buffer

# in case of CancelledError
n_input_tokens, n_output_tokens = 0, 0

try:
# send placeholder message to user
placeholder_message = await update.message.reply_text("...")
message = update.message.caption or update.message.text

# send typing action
await update.message.chat.send_action(action="typing")

if message is None or len(message) == 0:
await update.message.reply_text(
"🥲 You sent <b>empty message</b>. Please, try again!",
parse_mode=ParseMode.HTML,
)
return

dialog_messages = db.get_dialog_messages(user_id, dialog_id=None)
parse_mode = {"html": ParseMode.HTML, "markdown": ParseMode.MARKDOWN}[
config.chat_modes[chat_mode]["parse_mode"]
]

chatgpt_instance = openai_utils.ChatGPT(model=current_model)
if config.enable_message_streaming:
gen = chatgpt_instance.send_vision_message_stream(
message,
dialog_messages=dialog_messages,
image_buffer=buf,
chat_mode=chat_mode,
)
else:
(
answer,
(n_input_tokens, n_output_tokens),
n_first_dialog_messages_removed,
) = await chatgpt_instance.send_vision_message(
message,
dialog_messages=dialog_messages,
image_buffer=buf,
chat_mode=chat_mode,
)

async def fake_gen():
yield "finished", answer, (
n_input_tokens,
n_output_tokens,
), n_first_dialog_messages_removed

gen = fake_gen()

prev_answer = ""
async for gen_item in gen:
(
status,
answer,
(n_input_tokens, n_output_tokens),
n_first_dialog_messages_removed,
) = gen_item
answer = current_model + " " + answer
answer = answer[:4096] # telegram message limit

# update only when 100 new symbols are ready
if abs(len(answer) - len(prev_answer)) < 100 and status != "finished":
continue

try:
await context.bot.edit_message_text(
answer,
chat_id=placeholder_message.chat_id,
message_id=placeholder_message.message_id,
parse_mode=parse_mode,
)
except telegram.error.BadRequest as e:
if str(e).startswith("Message is not modified"):
continue
else:
await context.bot.edit_message_text(
answer,
chat_id=placeholder_message.chat_id,
message_id=placeholder_message.message_id,
)

await asyncio.sleep(0.01) # wait a bit to avoid flooding

prev_answer = answer

# update user data
if buf is not None:
base_image = base64.b64encode(buf.getvalue()).decode("utf-8")
new_dialog_message = {"user": [
{
"type": "text",
"text": message,
},
{
"type": "image",
"image": base_image,
}
]
, "bot": answer, "date": datetime.now()}
else:
new_dialog_message = {"user": [{"type": "text", "text": message}], "bot": answer, "date": datetime.now()}

db.set_dialog_messages(
user_id,
db.get_dialog_messages(user_id, dialog_id=None) + [new_dialog_message],
dialog_id=None
)

db.update_n_used_tokens(user_id, current_model, n_input_tokens, n_output_tokens)

except asyncio.CancelledError:
# note: intermediate token updates only work when enable_message_streaming=True (config.yml)
db.update_n_used_tokens(user_id, current_model, n_input_tokens, n_output_tokens)
raise

except Exception as e:
error_text = f"Something went wrong during completion. Reason: {e}"
logger.error(error_text)
await update.message.reply_text(error_text)
return

async def unsupport_message_handle(update: Update, context: CallbackContext, message=None):
error_text = f"I don't know how to read files or videos. Send the picture in normal mode (Quick Mode)."
logger.error(error_text)
await update.message.reply_text(error_text)
return

async def message_handle(update: Update, context: CallbackContext, message=None, use_new_dialog_timeout=True):
# check if bot was mentioned (for group chats)
Expand Down Expand Up @@ -204,6 +367,8 @@ async def message_handle(update: Update, context: CallbackContext, message=None,
await generate_image_handle(update, context, message=message)
return

current_model = db.get_user_attribute(user_id, "current_model")

async def message_handle_fn():
# new dialog timeout
if use_new_dialog_timeout:
Expand All @@ -214,7 +379,6 @@ async def message_handle_fn():

# in case of CancelledError
n_input_tokens, n_output_tokens = 0, 0
current_model = db.get_user_attribute(user_id, "current_model")

try:
# send placeholder message to user
Expand Down Expand Up @@ -249,11 +413,12 @@ async def fake_gen():
gen = fake_gen()

prev_answer = ""

async for gen_item in gen:
status, answer, (n_input_tokens, n_output_tokens), n_first_dialog_messages_removed = gen_item

answer = current_model + " " + answer
answer = answer[:4096] # telegram message limit

# update only when 100 new symbols are ready
if abs(len(answer) - len(prev_answer)) < 100 and status != "finished":
continue
Expand All @@ -267,11 +432,12 @@ async def fake_gen():
await context.bot.edit_message_text(answer, chat_id=placeholder_message.chat_id, message_id=placeholder_message.message_id)

await asyncio.sleep(0.01) # wait a bit to avoid flooding

prev_answer = answer

# update user data
new_dialog_message = {"user": _message, "bot": answer, "date": datetime.now()}

db.set_dialog_messages(
user_id,
db.get_dialog_messages(user_id, dialog_id=None) + [new_dialog_message],
Expand Down Expand Up @@ -300,7 +466,19 @@ async def fake_gen():
await update.message.reply_text(text, parse_mode=ParseMode.HTML)

async with user_semaphores[user_id]:
task = asyncio.create_task(message_handle_fn())
if current_model == "gpt-4-vision-preview" or update.message.photo is not None and len(update.message.photo) > 0:
logger.error('gpt-4-vision-preview')
if current_model != "gpt-4-vision-preview":
current_model = "gpt-4-vision-preview"
db.set_user_attribute(user_id, "current_model", "gpt-4-vision-preview")
task = asyncio.create_task(
_vision_message_handle_fn(update, context, use_new_dialog_timeout=use_new_dialog_timeout)
)
else:
task = asyncio.create_task(
message_handle_fn()
)

user_tasks[user_id] = task

try:
Expand Down Expand Up @@ -392,6 +570,7 @@ async def new_dialog_handle(update: Update, context: CallbackContext):

user_id = update.message.from_user.id
db.set_user_attribute(user_id, "last_interaction", datetime.now())
db.set_user_attribute(user_id, "current_model", "gpt-3.5-turbo")

db.start_new_dialog(user_id)
await update.message.reply_text("Starting new dialog ✅")
Expand Down Expand Up @@ -672,6 +851,9 @@ def run_bot() -> None:
application.add_handler(CommandHandler("help_group_chat", help_group_chat_handle, filters=user_filter))

application.add_handler(MessageHandler(filters.TEXT & ~filters.COMMAND & user_filter, message_handle))
application.add_handler(MessageHandler(filters.PHOTO & ~filters.COMMAND & user_filter, message_handle))
application.add_handler(MessageHandler(filters.VIDEO & ~filters.COMMAND & user_filter, unsupport_message_handle))
application.add_handler(MessageHandler(filters.Document.ALL & ~filters.COMMAND & user_filter, unsupport_message_handle))
application.add_handler(CommandHandler("retry", retry_handle, filters=user_filter))
application.add_handler(CommandHandler("new", new_dialog_handle, filters=user_filter))
application.add_handler(CommandHandler("cancel", cancel_handle, filters=user_filter))
Expand All @@ -694,4 +876,4 @@ def run_bot() -> None:


if __name__ == "__main__":
run_bot()
run_bot()
Loading

0 comments on commit 3b8b0b2

Please sign in to comment.