Skip to content

Commit

Permalink
slash command all
Browse files Browse the repository at this point in the history
  • Loading branch information
msaroufim committed Nov 12, 2024
1 parent 8a88a1c commit c5e3900
Showing 1 changed file with 104 additions and 84 deletions.
188 changes: 104 additions & 84 deletions discord-bot.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
from datetime import datetime, timezone
import requests
import discord
from discord import app_commands
import asyncio
import logging
import zipfile
Expand All @@ -23,6 +24,36 @@
load_dotenv()
logger.info("Environment variables loaded")

# Validate environment variables
if not os.getenv('DISCORD_TOKEN'):
logger.error("DISCORD_TOKEN not found in environment variables")
raise ValueError("DISCORD_TOKEN not found")
if not os.getenv('GITHUB_TOKEN'):
logger.error("GITHUB_TOKEN not found in environment variables")
raise ValueError("GITHUB_TOKEN not found")
if not os.getenv('GITHUB_REPO'):
logger.error("GITHUB_REPO not found in environment variables")
raise ValueError("GITHUB_REPO not found")

logger.info(f"Using GitHub repo: {os.getenv('GITHUB_REPO')}")

class ClusterBot(discord.Client):
def __init__(self):
# Initialize with the same intents as before
intents = discord.Intents.default()
intents.message_content = True
super().__init__(intents=intents)

# Create a command tree for slash commands
self.tree = app_commands.CommandTree(self)

async def setup_hook(self):
# This is called when the bot starts up
await self.tree.sync()
logger.info("Slash commands synced")

client = ClusterBot()

def get_github_branch_name():
"""
Runs a git command to determine the remote branch name, to be used in the GitHub Workflow
Expand All @@ -40,25 +71,6 @@ def get_github_branch_name():
logging.warning("Could not determine remote branch, falling back to 'main'")
return 'main'

# Validate environment variables
if not os.getenv('DISCORD_TOKEN'):
logger.error("DISCORD_TOKEN not found in environment variables")
raise ValueError("DISCORD_TOKEN not found")
if not os.getenv('GITHUB_TOKEN'):
logger.error("GITHUB_TOKEN not found in environment variables")
raise ValueError("GITHUB_TOKEN not found")
if not os.getenv('GITHUB_REPO'):
logger.error("GITHUB_REPO not found in environment variables")
raise ValueError("GITHUB_REPO not found")

logger.info(f"Using GitHub repo: {os.getenv('GITHUB_REPO')}")

# Bot setup with minimal intents
intents = discord.Intents.default()
intents.message_content = True
client = discord.Client(intents=intents)


async def trigger_github_action(script_content):
"""
Triggers the GitHub action with custom train.py contents
Expand Down Expand Up @@ -164,80 +176,88 @@ async def on_ready():
logger.info(f'Logged in as {client.user}')
for guild in client.guilds:
try:
if globals().get('args') and args.debug: # TODO: Fix Do this properly, maybe subclass `discord.Client` for better argument passing
if globals().get('args') and args.debug:
await guild.me.edit(nick="Cluster Bot (Staging)")
else:
await guild.me.edit(nick="Cluster Bot")
logger.info(f'Updated nickname in guild: {guild.name}')
except Exception as e:
logger.warning(f'Failed to update nickname in guild {guild.name}: {e}')

@client.event
async def on_message(message):
# Ignore messages from the bot itself
if message.author == client.user:
return
@client.tree.command(name="train", description="Start a training job with a train.py file")
async def train_command(interaction: discord.Interaction, script: discord.Attachment):
"""
Slash command to handle training job requests
"""
try:
# Verify the attachment
if script.filename != "train.py":
await interaction.response.send_message(
"Please provide a file named 'train.py' as the attachment.",
ephemeral=True
)
return

# Check if the bot is mentioned and there's an attachment
if client.user in message.mentions:
logger.info(f"Bot mentioned in message with {len(message.attachments)} attachments")
if message.attachments:
for attachment in message.attachments:
logger.info(f"Processing attachment: {attachment.filename}")
if attachment.filename == "train.py":
# Create a thread directly from the original message
thread = await message.create_thread(
name=f"Training Job - {datetime.now().strftime('%Y-%m-%d %H:%M')}",
auto_archive_duration=1440 # Archive after 24 hours of inactivity
)

# Send initial message in the thread
await thread.send("Found train.py! Starting training process...")

try:
# Download the file content
logger.info("Downloading train.py content")
script_content = await attachment.read()
script_content = script_content.decode('utf-8')
logger.info("Successfully read train.py content")

# Trigger GitHub Action
run_id = await trigger_github_action(script_content)

if run_id:
logger.info(f"Successfully triggered workflow with run ID: {run_id}")
await thread.send(f"GitHub Action triggered successfully! Run ID: {run_id}\nMonitoring progress...")

# Monitor the workflow
status, logs, url = await check_workflow_status(run_id, thread)

# Send results back to Discord thread
await thread.send(f"Training completed with status: {status}")

# Split logs if they're too long for Discord's message limit
if len(logs) > 1900:
chunks = [logs[i:i+1900] for i in range(0, len(logs), 1900)]
for i, chunk in enumerate(chunks):
await thread.send(f"```\nLogs (part {i+1}/{len(chunks)}):\n{chunk}\n```")
else:
await thread.send(f"```\nLogs:\n{logs}\n```")

if url:
await thread.send(f"View the full run at: {url}")
else:
logger.error("Failed to trigger GitHub Action")
await thread.send("Failed to trigger GitHub Action. Please check the configuration.")

except Exception as e:
logger.error(f"Error processing request: {str(e)}", exc_info=True)
await thread.send(f"Error processing request: {str(e)}")

break
# Create a thread
thread = await interaction.channel.create_thread(
name=f"Training Job - {datetime.now().strftime('%Y-%m-%d %H:%M')}",
auto_archive_duration=1440 # Archive after 24 hours of inactivity
)

# Acknowledge the command
await interaction.response.send_message(
f"Training job initiated! Following the progress in thread: {thread.jump_url}",
ephemeral=True
)

if not any(att.filename == "train.py" for att in message.attachments):
await message.reply("Please attach a file named 'train.py' to your message.")
# Send initial message in thread
await thread.send("Starting training process...")

try:
# Download the file content
script_content = await script.read()
script_content = script_content.decode('utf-8')
logger.info("Successfully read train.py content")

# Trigger GitHub Action
run_id = await trigger_github_action(script_content)

if run_id:
logger.info(f"Successfully triggered workflow with run ID: {run_id}")
await thread.send(f"GitHub Action triggered successfully! Run ID: {run_id}\nMonitoring progress...")

# Monitor the workflow
status, logs, url = await check_workflow_status(run_id, thread)

# Send results back to Discord thread
await thread.send(f"Training completed with status: {status}")

# Split logs if they're too long for Discord's message limit
if len(logs) > 1900:
chunks = [logs[i:i+1900] for i in range(0, len(logs), 1900)]
for i, chunk in enumerate(chunks):
await thread.send(f"```\nLogs (part {i+1}/{len(chunks)}):\n{chunk}\n```")
else:
await thread.send(f"```\nLogs:\n{logs}\n```")

if url:
await thread.send(f"View the full run at: {url}")
else:
logger.error("Failed to trigger GitHub Action")
await thread.send("Failed to trigger GitHub Action. Please check the configuration.")

except Exception as e:
logger.error(f"Error processing request: {str(e)}", exc_info=True)
await thread.send(f"Error processing request: {str(e)}")

except Exception as e:
logger.error(f"Error in train command: {str(e)}", exc_info=True)
if not interaction.response.is_done():
await interaction.response.send_message(
f"An error occurred while processing your request: {str(e)}",
ephemeral=True
)

# Run the bot
if __name__ == "__main__":
parser = argparse.ArgumentParser(description='Run the Discord Cluster Bot')
parser.add_argument('--debug', action='store_true', help='Run in debug/staging mode')
Expand All @@ -253,4 +273,4 @@ async def on_message(message):
else:
token = os.getenv('DISCORD_TOKEN')

client.run(token)
client.run(token)

0 comments on commit c5e3900

Please sign in to comment.