-
Notifications
You must be signed in to change notification settings - Fork 3
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Showing
2 changed files
with
152 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,150 @@ | ||
import discord | ||
import logging | ||
import os | ||
import argparse | ||
import asyncio | ||
import re | ||
|
||
# Set up logging | ||
logging.basicConfig( | ||
level=logging.INFO, | ||
format='%(asctime)s - %(levelname)s - %(message)s', | ||
datefmt='%Y-%m-%d %H:%M:%S' | ||
) | ||
logger = logging.getLogger(__name__) | ||
|
||
parser = argparse.ArgumentParser( | ||
description='Smoke Test for the Discord Cluster Bot', | ||
epilog=f""" | ||
This script can be used after deployment, or during development, to quickly | ||
verify that basic functionality of the cluster bot is working correctly. | ||
It should be run before further testing or usage of the bot. | ||
Example usage: | ||
python {os.path.basename(__file__)} https://discord.com/channels/123/456/789 | ||
The URL is the message link for some message that triggered the cluster bot. | ||
To find this URL: click the 3 dots (...) to the right of the message, | ||
then click 'Copy Message Link'.""", | ||
formatter_class=argparse.RawTextHelpFormatter) | ||
parser.add_argument('message_url', type=str, help='Discord message URL to test') | ||
args = parser.parse_args() | ||
|
||
message_id = int(args.message_url.split('/')[-1]) | ||
|
||
# Client setup with minimal intents | ||
intents = discord.Intents.default() | ||
intents.message_content = True | ||
client = discord.Client(intents=intents) | ||
|
||
# Event that signals when async client tests are done | ||
client_tests_done = asyncio.Event() | ||
|
||
# Flag set to true if the thread tests pass | ||
thread_tests_passed = False | ||
|
||
async def test_thread_messages(client, message_id): | ||
""" | ||
Test messages from a Discord thread identified by a message ID. | ||
Args: | ||
client (discord.Client): the Discord client | ||
message_id (int): the ID of the message under which to find thread messsages | ||
Side effect: | ||
- Sets thread_tests_passed to True if all happy path messages are found | ||
""" | ||
|
||
global thread_tests_passed | ||
|
||
required_strings = [ | ||
"Found train.py! Starting training process", | ||
"GitHub Action triggered successfully! Run ID:", | ||
"Training completed with status: success", | ||
".*\nLogs.*:", | ||
"View the full run at:", | ||
] | ||
|
||
message_contents = [] | ||
thread_found = False | ||
|
||
# Iterate through guilds to find the thread by message ID | ||
for guild in client.guilds: | ||
try: | ||
# Search for the thread using the message ID | ||
for channel in guild.text_channels: | ||
try: | ||
message = await channel.fetch_message(message_id) | ||
if message.thread: | ||
thread_found = True | ||
thread = message.thread | ||
logger.info(f"Found thread: {thread.name}.") | ||
|
||
# Fetch messages from the thread | ||
message_contents = [ | ||
msg.content async for msg in thread.history(limit=None) | ||
] | ||
break | ||
|
||
except discord.NotFound: | ||
continue | ||
except discord.Forbidden: | ||
logger.warning(f"Bot does not have permission to access {channel.name}") | ||
continue | ||
|
||
except Exception as e: | ||
logger.error(f"Error fetching thread: {e}", exc_info=True) | ||
|
||
if thread_found: | ||
# Already found the thread, so no need to continue to iterate | ||
# through guilds | ||
break | ||
|
||
if message_contents: | ||
all_strings_found = all( | ||
any(re.match(req_str, contents, re.DOTALL) != None for contents in message_contents) | ||
for req_str in required_strings | ||
) | ||
|
||
if all_strings_found: | ||
thread_tests_passed = True | ||
else: | ||
logger.warning("Thread not found!") | ||
|
||
if thread_tests_passed: | ||
logger.info('All required strings were found in the thread.') | ||
else: | ||
logger.warning('Some required string was not found in the thread!') | ||
logger.info('Thread contents were: ') | ||
logger.info('\n'.join(f'\t{contents}' for contents in message_contents)) | ||
|
||
@client.event | ||
async def on_ready(): | ||
await test_thread_messages(client, message_id) | ||
|
||
# We could add additional tests that use the client here if needed. | ||
|
||
client_tests_done.set() | ||
await client.close() | ||
|
||
if __name__ == '__main__': | ||
logger.info("Running smoke tests...") | ||
|
||
token = os.getenv('DISCORD_TOKEN') | ||
if not token: | ||
logger.error('DISCORD_TOKEN environment variable not set.') | ||
exit(1) | ||
|
||
client.run(token) | ||
|
||
async def await_client_tests(): | ||
await client_tests_done.wait() | ||
|
||
asyncio.run(await_client_tests()) | ||
|
||
if not thread_tests_passed: | ||
# If other tests are needed, add `... and other_test_passed` above. | ||
logger.warning("One or more tests failed!") | ||
exit(1) | ||
else: | ||
logger.info('All tests passed!') | ||
exit(0) |