From c5e3900c52ac2f331f17d76b7d012f4833cc90ca Mon Sep 17 00:00:00 2001 From: Mark Saroufim Date: Mon, 11 Nov 2024 18:02:26 -0800 Subject: [PATCH] slash command all --- discord-bot.py | 188 +++++++++++++++++++++++++++---------------------- 1 file changed, 104 insertions(+), 84 deletions(-) diff --git a/discord-bot.py b/discord-bot.py index 29ca610..00f397e 100644 --- a/discord-bot.py +++ b/discord-bot.py @@ -5,6 +5,7 @@ from datetime import datetime, timezone import requests import discord +from discord import app_commands import asyncio import logging import zipfile @@ -23,6 +24,36 @@ load_dotenv() logger.info("Environment variables loaded") +# Validate environment variables +if not os.getenv('DISCORD_TOKEN'): + logger.error("DISCORD_TOKEN not found in environment variables") + raise ValueError("DISCORD_TOKEN not found") +if not os.getenv('GITHUB_TOKEN'): + logger.error("GITHUB_TOKEN not found in environment variables") + raise ValueError("GITHUB_TOKEN not found") +if not os.getenv('GITHUB_REPO'): + logger.error("GITHUB_REPO not found in environment variables") + raise ValueError("GITHUB_REPO not found") + +logger.info(f"Using GitHub repo: {os.getenv('GITHUB_REPO')}") + +class ClusterBot(discord.Client): + def __init__(self): + # Initialize with the same intents as before + intents = discord.Intents.default() + intents.message_content = True + super().__init__(intents=intents) + + # Create a command tree for slash commands + self.tree = app_commands.CommandTree(self) + + async def setup_hook(self): + # This is called when the bot starts up + await self.tree.sync() + logger.info("Slash commands synced") + +client = ClusterBot() + def get_github_branch_name(): """ Runs a git command to determine the remote branch name, to be used in the GitHub Workflow @@ -40,25 +71,6 @@ def get_github_branch_name(): logging.warning("Could not determine remote branch, falling back to 'main'") return 'main' -# Validate environment variables -if not os.getenv('DISCORD_TOKEN'): - logger.error("DISCORD_TOKEN not found in environment variables") - raise ValueError("DISCORD_TOKEN not found") -if not os.getenv('GITHUB_TOKEN'): - logger.error("GITHUB_TOKEN not found in environment variables") - raise ValueError("GITHUB_TOKEN not found") -if not os.getenv('GITHUB_REPO'): - logger.error("GITHUB_REPO not found in environment variables") - raise ValueError("GITHUB_REPO not found") - -logger.info(f"Using GitHub repo: {os.getenv('GITHUB_REPO')}") - -# Bot setup with minimal intents -intents = discord.Intents.default() -intents.message_content = True -client = discord.Client(intents=intents) - - async def trigger_github_action(script_content): """ Triggers the GitHub action with custom train.py contents @@ -164,7 +176,7 @@ async def on_ready(): logger.info(f'Logged in as {client.user}') for guild in client.guilds: try: - if globals().get('args') and args.debug: # TODO: Fix Do this properly, maybe subclass `discord.Client` for better argument passing + if globals().get('args') and args.debug: await guild.me.edit(nick="Cluster Bot (Staging)") else: await guild.me.edit(nick="Cluster Bot") @@ -172,72 +184,80 @@ async def on_ready(): except Exception as e: logger.warning(f'Failed to update nickname in guild {guild.name}: {e}') -@client.event -async def on_message(message): - # Ignore messages from the bot itself - if message.author == client.user: - return +@client.tree.command(name="train", description="Start a training job with a train.py file") +async def train_command(interaction: discord.Interaction, script: discord.Attachment): + """ + Slash command to handle training job requests + """ + try: + # Verify the attachment + if script.filename != "train.py": + await interaction.response.send_message( + "Please provide a file named 'train.py' as the attachment.", + ephemeral=True + ) + return - # Check if the bot is mentioned and there's an attachment - if client.user in message.mentions: - logger.info(f"Bot mentioned in message with {len(message.attachments)} attachments") - if message.attachments: - for attachment in message.attachments: - logger.info(f"Processing attachment: {attachment.filename}") - if attachment.filename == "train.py": - # Create a thread directly from the original message - thread = await message.create_thread( - name=f"Training Job - {datetime.now().strftime('%Y-%m-%d %H:%M')}", - auto_archive_duration=1440 # Archive after 24 hours of inactivity - ) - - # Send initial message in the thread - await thread.send("Found train.py! Starting training process...") - - try: - # Download the file content - logger.info("Downloading train.py content") - script_content = await attachment.read() - script_content = script_content.decode('utf-8') - logger.info("Successfully read train.py content") - - # Trigger GitHub Action - run_id = await trigger_github_action(script_content) - - if run_id: - logger.info(f"Successfully triggered workflow with run ID: {run_id}") - await thread.send(f"GitHub Action triggered successfully! Run ID: {run_id}\nMonitoring progress...") - - # Monitor the workflow - status, logs, url = await check_workflow_status(run_id, thread) - - # Send results back to Discord thread - await thread.send(f"Training completed with status: {status}") - - # Split logs if they're too long for Discord's message limit - if len(logs) > 1900: - chunks = [logs[i:i+1900] for i in range(0, len(logs), 1900)] - for i, chunk in enumerate(chunks): - await thread.send(f"```\nLogs (part {i+1}/{len(chunks)}):\n{chunk}\n```") - else: - await thread.send(f"```\nLogs:\n{logs}\n```") - - if url: - await thread.send(f"View the full run at: {url}") - else: - logger.error("Failed to trigger GitHub Action") - await thread.send("Failed to trigger GitHub Action. Please check the configuration.") - - except Exception as e: - logger.error(f"Error processing request: {str(e)}", exc_info=True) - await thread.send(f"Error processing request: {str(e)}") - - break + # Create a thread + thread = await interaction.channel.create_thread( + name=f"Training Job - {datetime.now().strftime('%Y-%m-%d %H:%M')}", + auto_archive_duration=1440 # Archive after 24 hours of inactivity + ) + + # Acknowledge the command + await interaction.response.send_message( + f"Training job initiated! Following the progress in thread: {thread.jump_url}", + ephemeral=True + ) - if not any(att.filename == "train.py" for att in message.attachments): - await message.reply("Please attach a file named 'train.py' to your message.") + # Send initial message in thread + await thread.send("Starting training process...") + + try: + # Download the file content + script_content = await script.read() + script_content = script_content.decode('utf-8') + logger.info("Successfully read train.py content") + + # Trigger GitHub Action + run_id = await trigger_github_action(script_content) + + if run_id: + logger.info(f"Successfully triggered workflow with run ID: {run_id}") + await thread.send(f"GitHub Action triggered successfully! Run ID: {run_id}\nMonitoring progress...") + + # Monitor the workflow + status, logs, url = await check_workflow_status(run_id, thread) + + # Send results back to Discord thread + await thread.send(f"Training completed with status: {status}") + + # Split logs if they're too long for Discord's message limit + if len(logs) > 1900: + chunks = [logs[i:i+1900] for i in range(0, len(logs), 1900)] + for i, chunk in enumerate(chunks): + await thread.send(f"```\nLogs (part {i+1}/{len(chunks)}):\n{chunk}\n```") + else: + await thread.send(f"```\nLogs:\n{logs}\n```") + + if url: + await thread.send(f"View the full run at: {url}") + else: + logger.error("Failed to trigger GitHub Action") + await thread.send("Failed to trigger GitHub Action. Please check the configuration.") + + except Exception as e: + logger.error(f"Error processing request: {str(e)}", exc_info=True) + await thread.send(f"Error processing request: {str(e)}") + + except Exception as e: + logger.error(f"Error in train command: {str(e)}", exc_info=True) + if not interaction.response.is_done(): + await interaction.response.send_message( + f"An error occurred while processing your request: {str(e)}", + ephemeral=True + ) -# Run the bot if __name__ == "__main__": parser = argparse.ArgumentParser(description='Run the Discord Cluster Bot') parser.add_argument('--debug', action='store_true', help='Run in debug/staging mode') @@ -253,4 +273,4 @@ async def on_message(message): else: token = os.getenv('DISCORD_TOKEN') - client.run(token) + client.run(token) \ No newline at end of file