Skip to content

Commit

Permalink
Browse files Browse the repository at this point in the history
  • Loading branch information
msaroufim committed Nov 19, 2024
2 parents 06d9885 + de57346 commit f15f840
Show file tree
Hide file tree
Showing 2 changed files with 155 additions and 0 deletions.
2 changes: 2 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,8 @@ Every triggered job is containerized so we don't have to worry too much about se

Instead of testing on GPU MODE directly we can leverage a staging environment called "Discord Cluster Staging". If you need access to this server please ping "Seraphim", however, you can also test the bot on your own server by following the instructions below.

The smoke test script in `tests/discord-bot-smoke-test.py` should be run to verify basic functionality of the cluster bot. For usage information, run with `python tests/discord-bot-smoke-test.py -h`. Run it against your own server (see below for instructions).

### How to add the bot to a personal server

For testing purposes, bot can be run on a personal server as well. Follow the steps [here](https://discordjs.guide/preparations/setting-up-a-bot-application.html#creating-your-bot) and [here](https://discordjs.guide/preparations/adding-your-bot-to-servers.html#bot-invite-links) to create a bot application and then add it to your server.
Expand Down
153 changes: 153 additions & 0 deletions tests/discord-bot-smoke-test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,153 @@
from dotenv import load_dotenv
import discord
import logging
import os
import argparse
import asyncio
import re
import sys

# Set up logging
logging.basicConfig(
level=logging.INFO,
format='%(asctime)s - %(levelname)s - %(message)s',
datefmt='%Y-%m-%d %H:%M:%S'
)
logger = logging.getLogger(__name__)

# Load environment variables
load_dotenv()
logger.info("Environment variables loaded")

parser = argparse.ArgumentParser(
description='Smoke Test for the Discord Cluster Bot',
epilog=f"""
This script can be used after deployment, or during development, to quickly
verify that basic functionality of the cluster bot is working correctly.
It should be run before further testing or usage of the bot.
Example usage:
python {os.path.basename(__file__)} https://discord.com/channels/123/456/789
The URL is the message link for some message that triggered the cluster bot.
To find this URL: click the 3 dots (...) to the right of the message,
then click 'Copy Message Link'.""",
formatter_class=argparse.RawTextHelpFormatter)
parser.add_argument('message_url', type=str, help='Discord message URL to test')
args = parser.parse_args()

message_id = int(args.message_url.split('/')[-1])

# Client setup with minimal intents
intents = discord.Intents.default()
intents.message_content = True
client = discord.Client(intents=intents)

# Event that signals when async client tests are done
client_tests_done = asyncio.Event()

# Flag set to true if the thread tests pass
thread_tests_passed = False

async def verify_thread_messages():
"""
Test messages from a Discord thread identified by a message ID.
Side effect:
Sets thread_tests_passed to True if all happy path messages are found
"""

global thread_tests_passed

required_strings = [
"Found train.py! Starting training process",
"GitHub Action triggered successfully! Run ID:",
"Training completed with status: success",
".*```\nLogs.*:",
"View the full run at:",
]

message_contents = []
thread_found = False

# Iterate through guilds to find the thread by message ID
for guild in client.guilds:
try:
# Search for the thread using the message ID
for channel in guild.text_channels:
try:
message = await channel.fetch_message(message_id)
if message.thread:
thread_found = True
thread = message.thread
logger.info(f"Found thread: {thread.name}.")

# Fetch messages from the thread
message_contents = [
msg.content async for msg in thread.history(limit=None)
]
break

except discord.NotFound:
continue
except discord.Forbidden:
logger.warning(f"Bot does not have permission to access {channel.name}")
continue

except Exception as e:
logger.error(f"Error fetching thread: {e}", exc_info=True)

if thread_found:
# Already found the thread, so no need to continue to iterate
# through guilds
break

if message_contents:
all_strings_found = all(
any(re.match(req_str, contents, re.DOTALL) != None for contents in message_contents)
for req_str in required_strings
)

if all_strings_found:
thread_tests_passed = True
else:
logger.warning("Thread not found!")

if thread_tests_passed:
logger.info('All required strings were found in the thread.')
else:
logger.warning('Some required string was not found in the thread!')
logger.info('Thread contents were: ')
logger.info('\n'.join(f'\t{contents}' for contents in message_contents))

@client.event
async def on_ready():
await verify_thread_messages()

# We could add additional tests that use the client here if needed.

client_tests_done.set()
await client.close()

if __name__ == '__main__':
logger.info("Running smoke tests...")

token = os.getenv('DISCORD_TOKEN')
if not token:
logger.error('DISCORD_TOKEN environment variable not set.')
exit(1)

client.run(token)

async def await_client_tests():
await client_tests_done.wait()

asyncio.run(await_client_tests())

if not thread_tests_passed:
# If other tests are needed, add them above
# if (not thread_test_passed) or (not other_test_passed):
logger.warning("One or more tests failed!")
sys.exit(1)
else:
logger.info('All tests passed!')
sys.exit(0)

0 comments on commit f15f840

Please sign in to comment.