From efa224c4963f673f47bd3f3652780132b66f6918 Mon Sep 17 00:00:00 2001 From: Ben Horowitz Date: Sat, 16 Nov 2024 11:59:05 -0800 Subject: [PATCH] Add smoke test for discord-bot.py --- README.md | 2 + discord-bot.py | 5 +- tests/discord-bot-smoke-test.py | 149 ++++++++++++++++++++++++++++++++ 3 files changed, 154 insertions(+), 2 deletions(-) create mode 100644 tests/discord-bot-smoke-test.py diff --git a/README.md b/README.md index 7cc3e5a..512737f 100644 --- a/README.md +++ b/README.md @@ -21,6 +21,8 @@ Every triggered job is containerized so we don't have to worry too much about se Instead of testing on GPU MODE directly we can leverage a staging environment called "Discord Cluster Staging". If you need access to this server please ping "Seraphim", however, you can also test the bot on your own server by following the instructions below. +The smoke test script in `tests/discord-bot-smoke-test.py` should be run to verify basic functionality of the cluster bot. For usage information, run with `python tests/discord-bot-smoke-test.py -h`. + ### How to add the bot to a personal server For testing purposes, bot can be run on a personal server as well. Follow the steps [here](https://discordjs.guide/preparations/setting-up-a-bot-application.html#creating-your-bot) and [here](https://discordjs.guide/preparations/adding-your-bot-to-servers.html#bot-invite-links) to create a bot application and then add it to your server. diff --git a/discord-bot.py b/discord-bot.py index 022646d..6345daf 100644 --- a/discord-bot.py +++ b/discord-bot.py @@ -95,7 +95,7 @@ async def trigger_github_action(script_content): logger.error(f"Error in trigger_github_action: {str(e)}", exc_info=True) return None -async def download_artifact(run_id): +async def download_artifact(run_id, thread): """ Downloads the training log artifact from the workflow run """ @@ -131,6 +131,7 @@ async def download_artifact(run_id): if log_file: with z.open(log_file) as f: logs = f.read().decode('utf-8') + await thread.send(f"Located training.log.") else: logs = "training.log file not found in artifact" @@ -161,7 +162,7 @@ async def check_workflow_status(run_id, thread): if run.status == "completed": logger.info("Workflow completed, downloading artifacts") - logs = await download_artifact(run_id) + logs = await download_artifact(run_id, thread) return run.conclusion, logs, run.html_url await thread.send(f"Workflow still running... Status: {run.status}\nLive view: {run.html_url}") diff --git a/tests/discord-bot-smoke-test.py b/tests/discord-bot-smoke-test.py new file mode 100644 index 0000000..c3e9d58 --- /dev/null +++ b/tests/discord-bot-smoke-test.py @@ -0,0 +1,149 @@ +import discord +import logging +import os +import argparse +import asyncio + +# Set up logging +logging.basicConfig( + level=logging.INFO, + format='%(asctime)s - %(levelname)s - %(message)s', + datefmt='%Y-%m-%d %H:%M:%S' +) +logger = logging.getLogger(__name__) + +parser = argparse.ArgumentParser( + description='Smoke Test for the Discord Cluster Bot', + epilog=f""" + This script can be used after deployment, or during development, to quickly + verify that basic functionality of the cluster bot is working correctly. + It should be run before further testing or usage of the bot. + + Example usage: + python {os.path.basename(__file__)} https://discord.com/channels/123/456/789 + The URL is the message link for some message that triggered the cluster bot. + To find this URL: click the 3 dots (...) to the right of the message, + then click 'Copy Message Link'.""", + formatter_class=argparse.RawTextHelpFormatter) +parser.add_argument('message_url', type=str, help='Discord message URL to test') +args = parser.parse_args() + +message_id = int(args.message_url.split('/')[-1]) + +# Client setup with minimal intents +intents = discord.Intents.default() +intents.message_content = True +client = discord.Client(intents=intents) + +# Event that signals when async client tests are done +client_tests_done = asyncio.Event() + +# Flag set to true if the thread tests pass +thread_tests_passed = False + +async def test_thread_messages(client, message_id): + """ + Test messages from a Discord thread identified by a message ID. + + Args: + client (discord.Client): the Discord client + message_id (int): the ID of the message under which to find thread messsages + + Side effect: + - Sets thread_tests_passed to True if all happy path messages are found + """ + + global thread_tests_passed + + required_strings = [ + "Found train.py! Starting training process...", + "GitHub Action triggered successfully! Run ID:", + "Training completed with status: success", + "Located training.log", + "View the full run at:", + ] + + message_contents = [] + thread_found = False + + # Iterate through guilds to find the thread by message ID + for guild in client.guilds: + try: + # Search for the thread using the message ID + for channel in guild.text_channels: + try: + message = await channel.fetch_message(message_id) + if message.thread: + thread_found = True + thread = message.thread + logger.info(f"Found thread: {thread.name}.") + + # Fetch messages from the thread + message_contents = [ + msg.content async for msg in thread.history(limit=None) + ] + break + + except discord.NotFound: + continue + except discord.Forbidden: + logger.warning(f"Bot does not have permission to access {channel.name}") + continue + + except Exception as e: + logger.error(f"Error fetching thread: {e}", exc_info=True) + + if thread_found: + # Already found the thread, so no need to continue to iterate + # through guilds + break + + if message_contents: + all_strings_found = all( + any(req_str in contents for contents in message_contents) + for req_str in required_strings + ) + + if all_strings_found: + thread_tests_passed = True + else: + logger.warning("Thread not found!") + + if thread_tests_passed: + logger.info('All required strings were found in the thread.') + else: + logger.warning('Some required string was not found in the thread!') + logger.info('Thread contents were: ') + logger.info('\n'.join(f'\t{contents}' for contents in message_contents)) + +@client.event +async def on_ready(): + await test_thread_messages(client, message_id) + + # We could add additional tests that use the client here if needed. + + client_tests_done.set() + await client.close() + +if __name__ == '__main__': + logger.info("Running smoke tests...") + + token = os.getenv('DISCORD_TOKEN') + if not token: + logger.error('DISCORD_TOKEN environment variable not set.') + exit(1) + + client.run(token) + + async def await_client_tests(): + await client_tests_done.wait() + + asyncio.run(await_client_tests()) + + if not thread_tests_passed: + # If other tests are needed, add `... and other_test_passed` above. + logger.warning("One or more tests failed!") + exit(1) + else: + logger.info('All tests passed!') + exit(0) \ No newline at end of file