From 9bebbff1886fb5616b7241968cafe853c289f4a4 Mon Sep 17 00:00:00 2001 From: Mark Saroufim Date: Mon, 4 Nov 2024 17:38:37 -0800 Subject: [PATCH] Add Discord scheduler --- .github/workflows/train_workflow.yml | 20 +-- README.md | 12 +- discord-bot.py | 219 +++++++++++++++++++++++++++ requirements.txt | 3 +- 4 files changed, 241 insertions(+), 13 deletions(-) create mode 100644 discord-bot.py diff --git a/.github/workflows/train_workflow.yml b/.github/workflows/train_workflow.yml index 16b54ec..1b7d44b 100644 --- a/.github/workflows/train_workflow.yml +++ b/.github/workflows/train_workflow.yml @@ -1,26 +1,26 @@ name: Training Workflow on: workflow_dispatch: + inputs: + script_content: + description: 'Content of train.py' + required: true + type: string # Explicitly specify the type jobs: train: runs-on: ubuntu-latest steps: - - uses: actions/checkout@v2 - - - name: Set up Python - uses: actions/setup-python@v2 - with: - python-version: '3.x' - - name: Install dependencies run: | pip install -r ci-requirements.txt - - - name: Run training + + - name: Create and run training script run: | + echo "${{ inputs.script_content }}" > train.py + cat train.py # Debug: print the content python train.py > training.log 2>&1 - + - name: Upload logs uses: actions/upload-artifact@v3 if: always() # Upload logs whether the job succeeds or fails diff --git a/README.md b/README.md index 1a68019..a346ba3 100644 --- a/README.md +++ b/README.md @@ -2,7 +2,7 @@ This is the code for the Discord bot we'll be using to queue jobs to a cluster of GPUs that our generous sponsors have provided. 
-The key idea is that we're using Github Actions as a job scheduling engine and primarily making the Discord bot interact with the cluster via issuing Github Actions and and monitoring their status +The key idea is that we're using Github Actions as a job scheduling engine and primarily making the Discord bot interact with the cluster via issuing Github Actions and monitoring their status. While we're focused on having a nice user experience on discord.gg/gpumode, we're happy to accept PRs that make it easier for other Discord communities to hook up GPUs. ## How to run the bot locally @@ -20,10 +20,18 @@ Every triggered job is containerized so we don't have to worry too much about se Instead of testing on GPU MODE directly we can leverage a staging environment called "Discord Cluster Staging". If you need access to this server please ping "Seraphim" +The bot needs to be invited using an oauth2 token and needs the `Message Content Intent` permission + +The bot also needs permissions to read and write messages, which is easy to set up if you click on https://discord.com/api/oauth2/authorize?client_id=1303135152091697183&permissions=68608&scope=bot%20applications.commands + ### How to add a new GPU to the cluster Github has some nice instructions here https://docs.github.com/en/actions/hosting-your-own-runners/managing-self-hosted-runners/adding-self-hosted-runners but essentially the whole thing works by running a script on some GPU people own. ### Future work * Maybe we shouldn't use Github Action and can roll our own thing? -* Make registering new GPUs simpler \ No newline at end of file +* Make registering new GPUs simpler + +## Acknowledgements +* Luca Antiga did something very similar for the NeurIPS LLM efficiency competition, it was great! 
+* Midjourney was a similar inspiration in terms of UX \ No newline at end of file diff --git a/discord-bot.py b/discord-bot.py new file mode 100644 index 0000000..3c6df2c --- /dev/null +++ b/discord-bot.py @@ -0,0 +1,219 @@ +from dotenv import load_dotenv +from github import Github +import os +import time +from datetime import datetime, timezone +import requests +import discord +import asyncio +import logging + +# Set up logging +logging.basicConfig( + level=logging.INFO, + format='%(asctime)s - %(levelname)s - %(message)s', + datefmt='%Y-%m-%d %H:%M:%S' +) +logger = logging.getLogger(__name__) + +# Load environment variables +load_dotenv() +logger.info("Environment variables loaded") + +# Validate environment variables +if not os.getenv('DISCORD_TOKEN'): + logger.error("DISCORD_TOKEN not found in environment variables") + raise ValueError("DISCORD_TOKEN not found") +if not os.getenv('GITHUB_TOKEN'): + logger.error("GITHUB_TOKEN not found in environment variables") + raise ValueError("GITHUB_TOKEN not found") +if not os.getenv('GITHUB_REPO'): + logger.error("GITHUB_REPO not found in environment variables") + raise ValueError("GITHUB_REPO not found") + +logger.info(f"Using GitHub repo: {os.getenv('GITHUB_REPO')}") + +# Bot setup with minimal intents +intents = discord.Intents.default() +intents.message_content = True +client = discord.Client(intents=intents) + +async def trigger_github_action(script_content): + """ + Triggers the GitHub action with custom train.py contents + """ + logger.info("Attempting to trigger GitHub action") + gh = Github(os.getenv('GITHUB_TOKEN')) + repo = gh.get_repo(os.getenv('GITHUB_REPO')) + + try: + # Record the time before triggering + trigger_time = datetime.now(timezone.utc) + + # Log workflow attempt + logger.info(f"Looking for workflow 'train_workflow.yml' in repo {os.getenv('GITHUB_REPO')}") + + # Trigger the workflow with the script content + workflow = repo.get_workflow("train_workflow.yml") + logger.info("Found workflow, 
attempting to dispatch") + + success = workflow.create_dispatch("main", {'script_content': script_content}) + logger.info(f"Workflow dispatch result: {success}") + + if success: + # Wait a moment for the run to be created + await asyncio.sleep(2) + + # Get runs created after our trigger time + runs = list(workflow.get_runs()) + logger.info(f"Found {len(runs)} total runs") + + for run in runs: + logger.info(f"Checking run {run.id} created at {run.created_at}") + if run.created_at.replace(tzinfo=timezone.utc) > trigger_time: + logger.info(f"Found matching run with ID: {run.id}") + return run.id + + logger.warning("No matching runs found after trigger") + return None + + except Exception as e: + logger.error(f"Error in trigger_github_action: {str(e)}", exc_info=True) + return None + +async def download_artifact(run_id): + """ + Downloads the training log artifact from the workflow run + """ + logger.info(f"Attempting to download artifacts for run {run_id}") + gh = Github(os.getenv('GITHUB_TOKEN')) + repo = gh.get_repo(os.getenv('GITHUB_REPO')) + + try: + # Get the specific run + run = repo.get_workflow_run(run_id) + + # Get artifacts from the run + artifacts = run.get_artifacts() + logger.info(f"Found {artifacts.totalCount} artifacts") + + for artifact in artifacts: + logger.info(f"Found artifact: {artifact.name}") + if artifact.name == 'training-logs': + # Download the artifact + url = artifact.archive_download_url + headers = {'Authorization': f'token {os.getenv("GITHUB_TOKEN")}'} + response = requests.get(url, headers=headers) + + if response.status_code == 200: + logger.info("Successfully downloaded artifact") + with open('training.log.zip', 'wb') as f: + f.write(response.content) + + # Read the log file from the zip + with zipfile.ZipFile('training.log.zip') as z: + with z.open('training.log') as f: + logs = f.read().decode('utf-8') + + # Clean up the zip file + os.remove('training.log.zip') + return logs + else: + logger.error(f"Failed to download artifact. 
Status code: {response.status_code}") + + logger.warning("No training-logs artifact found") + return "No training logs found in artifacts" + except Exception as e: + logger.error(f"Error in download_artifact: {str(e)}", exc_info=True) + return f"Error downloading artifacts: {str(e)}" + +async def check_workflow_status(run_id, message): + """ + Monitors the GitHub Action workflow status and updates Discord + """ + logger.info(f"Starting to monitor workflow status for run {run_id}") + gh = Github(os.getenv('GITHUB_TOKEN')) + repo = gh.get_repo(os.getenv('GITHUB_REPO')) + + while True: + try: + run = repo.get_workflow_run(run_id) + logger.info(f"Current status: {run.status}") + + if run.status == "completed": + logger.info("Workflow completed, downloading artifacts") + logs = await download_artifact(run_id) + return run.conclusion, logs, run.html_url + + await message.channel.send(f"Workflow still running... Status: {run.status}\nLive view: {run.html_url}") + await asyncio.sleep(30) + except Exception as e: + logger.error(f"Error in check_workflow_status: {str(e)}", exc_info=True) + return "error", str(e), None + +@client.event +async def on_ready(): + logger.info(f'Logged in as {client.user}') + +@client.event +async def on_message(message): + # Ignore messages from the bot itself + if message.author == client.user: + return + + # Check if the bot is mentioned and there's an attachment + if client.user in message.mentions: + logger.info(f"Bot mentioned in message with {len(message.attachments)} attachments") + if message.attachments: + for attachment in message.attachments: + logger.info(f"Processing attachment: {attachment.filename}") + if attachment.filename == "train.py": + await message.channel.send("Found train.py! 
Starting training process...") + + try: + # Download the file content + logger.info("Downloading train.py content") + script_content = await attachment.read() + script_content = script_content.decode('utf-8') + logger.info("Successfully read train.py content") + + # Trigger GitHub Action + run_id = await trigger_github_action(script_content) + + if run_id: + logger.info(f"Successfully triggered workflow with run ID: {run_id}") + await message.channel.send(f"GitHub Action triggered successfully! Run ID: {run_id}\nMonitoring progress...") + + # Monitor the workflow + status, logs, url = await check_workflow_status(run_id, message) + + # Send results back to Discord + await message.channel.send(f"Training completed with status: {status}") + + # Split logs if they're too long for Discord's message limit + if len(logs) > 1900: + chunks = [logs[i:i+1900] for i in range(0, len(logs), 1900)] + for i, chunk in enumerate(chunks): + await message.channel.send(f"```\nLogs (part {i+1}/{len(chunks)}):\n{chunk}\n```") + else: + await message.channel.send(f"```\nLogs:\n{logs}\n```") + + if url: + await message.channel.send(f"View the full run at: {url}") + else: + logger.error("Failed to trigger GitHub Action") + await message.channel.send("Failed to trigger GitHub Action. 
Please check the configuration.") + + except Exception as e: + logger.error(f"Error processing request: {str(e)}", exc_info=True) + await message.channel.send(f"Error processing request: {str(e)}") + + break + + if not any(att.filename == "train.py" for att in message.attachments): + await message.channel.send("Please attach a file named 'train.py' to your message.") + +# Run the bot +if __name__ == "__main__": + logger.info("Starting bot...") + client.run(os.getenv('DISCORD_TOKEN')) \ No newline at end of file diff --git a/requirements.txt b/requirements.txt index d8bd49e..944b75c 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,5 +1,6 @@ PyGithub aiohttp -discord +discord.py +audioop-lts # discord.py imports using * syntax python-dotenv requests \ No newline at end of file