Skip to content

Commit

Permalink
Fix session handling
Browse files Browse the repository at this point in the history
We need to handle session disconnects, resuming, and changes in
shard states.

This fixes some of the errors we were seeing with:
 - Shard ID None WebSocket closed
 - Shard ID None heartbeat blocked for more than

This also makes a new `call_openai_api` function definition and
turns the response into an async call on a thread.

The rate limit function was cleaned up while I was here.

The Flake8 complexity was reduced to its default of 10 and the main
function has this check disabled now.
  • Loading branch information
johndotpub committed Feb 3, 2024
1 parent 1c03fa0 commit bc22735
Show file tree
Hide file tree
Showing 2 changed files with 51 additions and 20 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/flake8-pytest.yml
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ jobs:
# stop the build if there are Python syntax errors or undefined names
# flake8 . --count --select=E9,F63,F7,F82 --show-source --statistics
# exit-zero treats all errors as warnings. The GitHub editor is 127 chars wide
flake8 . --count --max-complexity=20 --statistics --max-line-length=99
flake8 . --count --max-complexity=10 --statistics --max-line-length=99
- name: Test with pytest
run: |
pytest
69 changes: 50 additions & 19 deletions bot.py
Original file line number Diff line number Diff line change
Expand Up @@ -105,26 +105,31 @@ async def check_rate_limit(
user: discord.User,
logger: logging.Logger = None
) -> bool:
"""
Check if a user has exceeded the rate limit for sending messages.
"""
if logger is None:
logger = logging.getLogger(__name__)

try:
if logger is None:
logger = logging.getLogger(__name__)
"""
Check if a user has exceeded the rate limit for sending messages.
"""
current_time = time.time()
last_command_timestamp = last_command_timestamps.get(user.id, 0)
last_command_count_user = last_command_count.get(user.id, 0)

if current_time - last_command_timestamp > RATE_LIMIT_PER:
last_command_timestamps[user.id] = current_time
last_command_count[user.id] = 1
logger.info(f"Rate limit passed for user: {user}")
return True

if last_command_count_user < RATE_LIMIT:
last_command_count[user.id] += 1
logger.info(f"Rate limit passed for user: {user}")
return True

logger.info(f"Rate limit exceeded for user: {user}")
return False

except Exception as e:
logger.error(f"Error checking rate limit: {e}")
raise
Expand Down Expand Up @@ -161,16 +166,19 @@ async def process_input_message(
# Log the current conversation history
# logger.info(f"Current conversation history: {conversation}")

response = client.chat.completions.create(
model=GPT_MODEL,
messages=[
{"role": "system", "content": SYSTEM_MESSAGE},
*conversation_summary,
{"role": "user", "content": input_message}
],
max_tokens=max_tokens,
temperature=0.7
)
def call_openai_api():
return client.chat.completions.create(
model=GPT_MODEL,
messages=[
{"role": "system", "content": SYSTEM_MESSAGE},
*conversation_summary,
{"role": "user", "content": input_message}
],
max_tokens=max_tokens,
temperature=0.7
)

response = await asyncio.to_thread(call_openai_api)

try:
# Extracting the response content from the new API response format
Expand Down Expand Up @@ -212,8 +220,8 @@ async def process_input_message(
return "An error occurred while processing the message."


# Execute the argparse code only when the file is run directly
if __name__ == "__main__":
# Executes the argparse code only when the file is run directly
if __name__ == "__main__": # noqa: C901 (ignore complexity in main function)
# Parse command-line arguments
args = parse_arguments()

Expand Down Expand Up @@ -311,6 +319,27 @@ async def on_ready():
status=discord.Status(BOT_PRESENCE)
)

@bot.event
async def on_disconnect():
"""
Event handler for when the bot disconnects from the Discord server.
"""
logger.info('Bot has disconnected')

@bot.event
async def on_resumed():
"""
Event handler for when the bot resumes its session.
"""
logger.info('Bot has resumed session')

@bot.event
async def on_shard_ready(shard_id):
"""
Event handler for when a shard is ready.
"""
logger.info(f'Shard {shard_id} is ready')

@bot.event
async def on_message(message):
"""
Expand Down Expand Up @@ -341,8 +370,9 @@ async def process_dm_message(message):

if not await check_rate_limit(message.author):
await message.channel.send(
"Command on cooldown. Please wait before using it again."
f"{message.author.mention} Exceeded the Rate Limit! Please slow down!"
)
logger.warning(f"Rate Limit Exceed by DM from {message.author}")
return

conversation_summary = get_conversation_summary(
Expand All @@ -367,8 +397,9 @@ async def process_channel_message(message):

if not await check_rate_limit(message.author):
await message.channel.send(
"Command on cooldown. Please wait before using it again."
f"{message.author.mention} Exceeded the Rate Limit! Please slow down!"
)
logger.warning(f"Rate Limit Exceeded in {message.channel} by {message.author}")
return

conversation_summary = get_conversation_summary(
Expand Down

0 comments on commit bc22735

Please sign in to comment.