Skip to content

Commit

Permalink
Feat: support cuda on modal (#35)
Browse files Browse the repository at this point in the history
  • Loading branch information
S1ro1 authored Nov 23, 2024
1 parent 593b1d6 commit cda9465
Show file tree
Hide file tree
Showing 2 changed files with 59 additions and 5 deletions.
15 changes: 11 additions & 4 deletions src/discord-cluster-manager/cogs/modal_cog.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,9 +30,9 @@ async def run_modal(
script: discord.Attachment,
gpu_type: app_commands.Choice[str],
):
if not script.filename.endswith(".py"):
if not script.filename.endswith(".py") and not script.filename.endswith(".cu"):
await interaction.response.send_message(
"Please provide a Python (.py) file"
"Please provide a Python (.py) or CUDA (.cu) file"
)
return

Expand All @@ -55,12 +55,19 @@ async def run_modal(
async def trigger_modal_run(self, script_content: str, filename: str) -> str:
logger.info("Attempting to trigger Modal run")

from modal_runner import modal_app, run_script
from modal_runner import modal_app

try:
print(f"Running {filename} with Modal")
with modal.enable_output():
with modal_app.run():
result = run_script.remote(script_content)
if filename.endswith(".py"):
from modal_runner import run_script

result = run_script.remote(script_content)
elif filename.endswith(".cu"):
from modal_runner import run_cuda_script
result = run_cuda_script.remote(script_content)
return result
except Exception as e:
logger.error(f"Error in trigger_modal_run: {str(e)}", exc_info=True)
Expand Down
49 changes: 48 additions & 1 deletion src/discord-cluster-manager/modal_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,6 @@ def timeout_handler(signum, frame):
signal.signal(signal.SIGALRM, original_handler)



@modal_app.function(
gpu="T4", image=Image.debian_slim(python_version="3.10").pip_install(["torch"])
)
Expand Down Expand Up @@ -66,3 +65,51 @@ def run_script(script_content: str, timeout_seconds: int = 300) -> str:
return f"Error executing script: {str(e)}"
finally:
sys.stdout = sys.__stdout__


@modal_app.function(
    gpu="T4",
    image=Image.from_registry(
        "nvidia/cuda:12.6.0-devel-ubuntu24.04", add_python="3.11"
    ),
)
def run_cuda_script(script_content: str, timeout_seconds: int = 600) -> str:
    """Compile a CUDA source string with ``nvcc`` and run the resulting binary.

    Args:
        script_content: Full text of the ``.cu`` source to compile.
        timeout_seconds: Budget handed to the ``timeout`` context manager;
            it wraps writing the file, compiling, and running the binary.

    Returns:
        The program's stdout when it exits with status 0. Otherwise a
        human-readable message prefixed with ``Compilation Error:``,
        ``Runtime Error``, ``Timeout Error:`` or ``Error:``. Exceptions are
        converted to strings rather than propagated.
    """
    import os
    import subprocess
    import sys
    from io import StringIO

    # stdout is swapped for an in-memory buffer for the duration of the call.
    # The buffer is never read here; the redirect is kept so that the
    # function's observable side effects stay the same as before.
    output = StringIO()
    sys.stdout = output

    try:
        # NOTE(review): `timeout` / `TimeoutException` are defined elsewhere
        # in this module; presumably SIGALRM-based -- confirm.
        with timeout(timeout_seconds):
            with open("script.cu", "w") as f:
                f.write(script_content)

            # Compile the CUDA code
            compile_process = subprocess.run(
                ["nvcc", "script.cu", "-o", "script.out"],
                capture_output=True,
                text=True,
            )

            if compile_process.returncode != 0:
                return f"Compilation Error:\n{compile_process.stderr}"

            run_process = subprocess.run(
                ["./script.out"], capture_output=True, text=True
            )

            # A failing binary used to be indistinguishable from a successful
            # one because only stdout was returned. Surface the exit status
            # and stderr so the user sees why the program failed, and keep
            # any partial stdout for context.
            if run_process.returncode != 0:
                return (
                    f"Runtime Error (exit code {run_process.returncode}):\n"
                    f"{run_process.stderr}\n{run_process.stdout}"
                )

            return run_process.stdout

    except TimeoutException as e:
        return f"Timeout Error: {str(e)}"
    except Exception as e:
        return f"Error: {str(e)}"
    finally:
        # Remove build artifacts; either may be missing if an earlier step
        # failed before creating it.
        if os.path.exists("script.cu"):
            os.remove("script.cu")
        if os.path.exists("script.out"):
            os.remove("script.out")
        sys.stdout = sys.__stdout__

0 comments on commit cda9465

Please sign in to comment.