Skip to content

Commit

Permalink
Merge pull request #37 from rlywtf/fix_session_handling
Browse files Browse the repository at this point in the history
Fix session handling
  • Loading branch information
johndotpub authored Feb 3, 2024
2 parents 1c03fa0 + bc22735 commit 5194b4d
Show file tree
Hide file tree
Showing 2 changed files with 51 additions and 20 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/flake8-pytest.yml
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ jobs:
# stop the build if there are Python syntax errors or undefined names
# flake8 . --count --select=E9,F63,F7,F82 --show-source --statistics
# exit-zero treats all errors as warnings. The GitHub editor is 127 chars wide
flake8 . --count --max-complexity=20 --statistics --max-line-length=99
flake8 . --count --max-complexity=10 --statistics --max-line-length=99
- name: Test with pytest
run: |
pytest
69 changes: 50 additions & 19 deletions bot.py
Original file line number Diff line number Diff line change
Expand Up @@ -105,26 +105,31 @@ async def check_rate_limit(
user: discord.User,
logger: logging.Logger = None
) -> bool:
"""
Check if a user has exceeded the rate limit for sending messages.
"""
if logger is None:
logger = logging.getLogger(__name__)

try:
if logger is None:
logger = logging.getLogger(__name__)
"""
Check if a user has exceeded the rate limit for sending messages.
"""
current_time = time.time()
last_command_timestamp = last_command_timestamps.get(user.id, 0)
last_command_count_user = last_command_count.get(user.id, 0)

if current_time - last_command_timestamp > RATE_LIMIT_PER:
last_command_timestamps[user.id] = current_time
last_command_count[user.id] = 1
logger.info(f"Rate limit passed for user: {user}")
return True

if last_command_count_user < RATE_LIMIT:
last_command_count[user.id] += 1
logger.info(f"Rate limit passed for user: {user}")
return True

logger.info(f"Rate limit exceeded for user: {user}")
return False

except Exception as e:
logger.error(f"Error checking rate limit: {e}")
raise
Expand Down Expand Up @@ -161,16 +166,19 @@ async def process_input_message(
# Log the current conversation history
# logger.info(f"Current conversation history: {conversation}")

response = client.chat.completions.create(
model=GPT_MODEL,
messages=[
{"role": "system", "content": SYSTEM_MESSAGE},
*conversation_summary,
{"role": "user", "content": input_message}
],
max_tokens=max_tokens,
temperature=0.7
)
def call_openai_api():
return client.chat.completions.create(
model=GPT_MODEL,
messages=[
{"role": "system", "content": SYSTEM_MESSAGE},
*conversation_summary,
{"role": "user", "content": input_message}
],
max_tokens=max_tokens,
temperature=0.7
)

response = await asyncio.to_thread(call_openai_api)

try:
# Extracting the response content from the new API response format
Expand Down Expand Up @@ -212,8 +220,8 @@ async def process_input_message(
return "An error occurred while processing the message."


# Execute the argparse code only when the file is run directly
if __name__ == "__main__":
# Executes the argparse code only when the file is run directly
if __name__ == "__main__": # noqa: C901 (ignore complexity in main function)
# Parse command-line arguments
args = parse_arguments()

Expand Down Expand Up @@ -311,6 +319,27 @@ async def on_ready():
status=discord.Status(BOT_PRESENCE)
)

@bot.event
async def on_disconnect():
"""
Event handler for when the bot disconnects from the Discord server.
"""
logger.info('Bot has disconnected')

@bot.event
async def on_resumed():
"""
Event handler for when the bot resumes its session.
"""
logger.info('Bot has resumed session')

@bot.event
async def on_shard_ready(shard_id):
"""
Event handler for when a shard is ready.
"""
logger.info(f'Shard {shard_id} is ready')

@bot.event
async def on_message(message):
"""
Expand Down Expand Up @@ -341,8 +370,9 @@ async def process_dm_message(message):

if not await check_rate_limit(message.author):
await message.channel.send(
"Command on cooldown. Please wait before using it again."
f"{message.author.mention} Exceeded the Rate Limit! Please slow down!"
)
logger.warning(f"Rate Limit Exceed by DM from {message.author}")
return

conversation_summary = get_conversation_summary(
Expand All @@ -367,8 +397,9 @@ async def process_channel_message(message):

if not await check_rate_limit(message.author):
await message.channel.send(
"Command on cooldown. Please wait before using it again."
f"{message.author.mention} Exceeded the Rate Limit! Please slow down!"
)
logger.warning(f"Rate Limit Exceeded in {message.channel} by {message.author}")
return

conversation_summary = get_conversation_summary(
Expand Down

0 comments on commit 5194b4d

Please sign in to comment.