Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

initial commit - torch cache #2

Closed
wants to merge 4 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
46 changes: 38 additions & 8 deletions .github/workflows/train_workflow.yml
Original file line number Diff line number Diff line change
Expand Up @@ -3,25 +3,55 @@ on:
workflow_dispatch:
inputs:
script_content:
description: 'Content of train.py'
description: 'Content of torch/train.py'
required: true
type: string # Explicitly specify the type
type: string

jobs:
train:
runs-on: ubuntu-latest
permissions:
contents: write
actions: write

steps:
- name: Check out the repository (required by caching step)
uses: actions/checkout@v3

- name: Set up Python
uses: actions/setup-python@v4
with:
python-version: '3.13'

- name: Cache pip dependencies
uses: actions/cache@v3
with:
path: ~/.cache/pip
key: ${{ runner.os }}-pip-${{ hashFiles('requirements.txt') }}
restore-keys: |
${{ runner.os }}-pip-

- name: List pip cache before install (debug)
run: |
echo "Listing pip cache before installing dependencies:"
ls -lh ~/.cache/pip || echo "No pip cache found."

- name: Install dependencies
run: |
pip install numpy
# pip install torch - need to find a way to cache this otherwise it will take a long time to install
python -m pip install --upgrade pip
pip install torch numpy

- name: List pip cache after install (debug)
run: |
echo "Listing pip cache after installing dependencies:"
ls -lh ~/.cache/pip

- name: Create and run training script
run: |
echo "${{ inputs.script_content }}" > train.py
cat train.py # Debug: print the content
python train.py > training.log 2>&1
echo "${{ inputs.script_content }}" > torch/train.py
cat torch/train.py # Debug: print the content
python torch/train.py > training.log 2>&1

- name: Upload logs
uses: actions/upload-artifact@v3
if: always() # Upload logs whether the job succeeds or fails
Expand Down
153 changes: 89 additions & 64 deletions bot.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,107 +4,132 @@
import time
from datetime import datetime, timezone
import requests
import logging
import traceback
import zipfile

# Timestamped, level-tagged log lines for the whole module.
logging.basicConfig(
    format='%(asctime)s - %(levelname)s - %(message)s',
    datefmt='%Y-%m-%d %H:%M:%S',
    level=logging.INFO,
)
logger = logging.getLogger(__name__)

# Load environment variables
load_dotenv()
logger.info("Environment variables loaded")

# Fail fast on the first required setting that is missing or empty; the
# order matches the order in which the settings are reported.
for _required_var in ('DISCORD_TOKEN', 'GITHUB_TOKEN', 'GITHUB_REPO'):
    if os.getenv(_required_var):
        continue
    logger.error(f"{_required_var} not found in environment variables")
    raise ValueError(f"{_required_var} not found")

logger.info(f"Using GitHub repo: {os.getenv('GITHUB_REPO')}")

def trigger_github_action(script_content=None):
    """
    Triggers the GitHub action and returns the latest run ID

    Dispatches ``train_workflow.yml`` on the ``main`` ref exactly once and
    then polls the workflow's runs until one created at or after the
    dispatch shows up.

    Args:
        script_content: Optional text forwarded as the workflow's
            ``script_content`` dispatch input. When omitted the dispatch is
            sent without inputs (the previous behaviour).
            NOTE(review): the workflow declares this input as required, so
            callers probably want to pass it - confirm.

    Returns:
        The ID of the first matching run, or ``None`` when the dispatch is
        rejected, no new run appears in time, or any error is raised.
    """
    gh = Github(os.getenv('GITHUB_TOKEN'))
    repo = gh.get_repo(os.getenv('GITHUB_REPO'))

    max_attempts = 5   # how many times to look for the new run
    poll_seconds = 2   # pause before each look-up

    try:
        # Record the time before triggering. Run timestamps from the API
        # carry whole seconds only, so microseconds are dropped here;
        # otherwise a run created in the same second as the dispatch would
        # compare as "older" and be skipped.
        trigger_time = datetime.now(timezone.utc).replace(microsecond=0)

        workflow = repo.get_workflow("train_workflow.yml")
        logger.info(f"Workflow found: {workflow}")

        # Dispatch exactly once.
        if script_content is None:
            response = workflow.create_dispatch("main")
        else:
            response = workflow.create_dispatch(
                "main", inputs={"script_content": script_content}
            )
        logger.info(f"Workflow dispatch response: {response}")

        if not response:
            logger.warning("Workflow dispatch failed, check permissions.")
            return None

        # The run is created asynchronously, so look for it a few times
        # instead of relying on one fixed delay.
        for attempt in range(1, max_attempts + 1):
            time.sleep(poll_seconds)

            # Iterate the paginated result lazily; stop at the first match.
            for run in workflow.get_runs():
                logger.info(f"Checking workflow run with ID: {run.id}, created at: {run.created_at}")
                if run.created_at.replace(tzinfo=timezone.utc) >= trigger_time:
                    logger.info(f"Found matching workflow run with ID: {run.id}")
                    return run.id

            logger.info(f"No new workflow run visible yet (attempt {attempt}/{max_attempts})")

        logger.warning("Workflow was dispatched but no new run was found.")
        return None
    except Exception as e:
        # logger.exception records the traceback at ERROR level, so it is
        # not filtered out by the INFO log level.
        logger.exception(f"Error triggering GitHub Action: {str(e)}")
        return None

def download_artifact(run_id):
    """
    Downloads the training log artifact from the workflow run

    Looks for an artifact named ``training-logs`` on the given run,
    downloads its zip archive and returns the text of the ``training.log``
    member.

    Args:
        run_id: ID of the workflow run whose artifacts are inspected.

    Returns:
        The decoded log text on success; otherwise a human-readable message
        describing why no log could be returned. This function does not
        raise: every error is logged and reported through the return value.
    """
    gh = Github(os.getenv('GITHUB_TOKEN'))
    repo = gh.get_repo(os.getenv('GITHUB_REPO'))
    zip_path = 'training.log.zip'

    try:
        run = repo.get_workflow_run(run_id)
        logger.info(f"Fetching artifacts for run ID: {run_id}")

        artifacts = run.get_artifacts()
        logger.info(f"Found {artifacts.totalCount} artifacts")

        for artifact in artifacts:
            logger.info(f"Found artifact: {artifact.name}")
            if artifact.name != 'training-logs':
                continue

            url = artifact.archive_download_url
            headers = {'Authorization': f'token {os.getenv("GITHUB_TOKEN")}'}
            # Without a timeout a stalled connection would block forever.
            response = requests.get(url, headers=headers, timeout=60)

            if response.status_code != 200:
                logger.error(f"Failed to download artifact. Status code: {response.status_code}")
                # Report the real cause instead of falling through to the
                # "no artifact" message below.
                return f"Failed to download training logs (HTTP {response.status_code})"

            logger.info("Successfully downloaded artifact")
            try:
                with open(zip_path, 'wb') as f:
                    f.write(response.content)

                with zipfile.ZipFile(zip_path) as z:
                    with z.open('training.log') as f:
                        # Training output may contain arbitrary bytes; do not
                        # lose the whole log over one undecodable sequence.
                        return f.read().decode('utf-8', errors='replace')
            finally:
                # Clean up the zip file even when extraction fails.
                if os.path.exists(zip_path):
                    os.remove(zip_path)

        logger.warning("No training-logs artifact found")
        return "No training logs found in artifacts"
    except Exception as e:
        # logger.exception keeps the traceback visible at the INFO log level.
        logger.exception(f"Error downloading artifact: {str(e)}")
        return f"Error downloading artifacts: {str(e)}"

def check_workflow_status(run_id, poll_interval=30):
    """
    Monitors the GitHub Action workflow status

    Polls the run until its status is ``"completed"`` and then fetches the
    training logs through ``download_artifact``.

    Args:
        run_id: ID of the workflow run to monitor.
        poll_interval: Seconds to wait between status checks (default 30,
            the previously hard-coded value).

    Returns:
        A ``(conclusion, logs, html_url)`` tuple once the run completes, or
        ``("error", message, None)`` if a status check raises.
    """
    gh = Github(os.getenv('GITHUB_TOKEN'))
    repo = gh.get_repo(os.getenv('GITHUB_REPO'))

    while True:
        try:
            run = repo.get_workflow_run(run_id)
            logger.info(f"Current status of run ID {run_id}: {run.status}")

            if run.status == "completed":
                logger.info("Workflow completed, downloading artifacts")
                logs = download_artifact(run_id)
                return run.conclusion, logs, run.html_url

            logger.info(f"Workflow still running... Status: {run.status}")
            logger.info(f"Live view: {run.html_url}")
            time.sleep(poll_interval)
        except Exception as e:
            # logger.exception records the traceback at ERROR level, so it is
            # not filtered out by the INFO log level.
            logger.exception(f"Error checking workflow status: {str(e)}")
            return "error", str(e), None

# Script entry point: dispatch the workflow, wait for it to finish and
# report the outcome. Every message goes through the logger exactly once.
if __name__ == "__main__":
    run_id = trigger_github_action()

    if run_id:
        logger.info(f"GitHub Action triggered successfully! Run ID: {run_id}")
        logger.info("Monitoring progress...")

        # Monitor the workflow
        status, logs, url = check_workflow_status(run_id)

        logger.info(f"\nWorkflow completed with status: {status}")
        logger.info(f"\nTraining Logs:\n{logs}")
        logger.info(f"\nView the full run at: {url}")
    else:
        logger.error("Failed to trigger GitHub Action. Please check your configuration.")
41 changes: 41 additions & 0 deletions torch/train.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,41 @@
import torch
import torch.nn as nn
import torch.optim as optim

class NN(nn.Module):
    """
    Minimal two-layer perceptron used as a smoke-test model.

    Maps 10 input features through a 32-unit hidden layer with ReLU
    activation to a single output value.
    """

    def __init__(self):
        super().__init__()
        # Layer attribute names are part of the state_dict keys; keep them.
        self.fully_connected_layer = nn.Linear(10, 32)
        self.fully_connected_layer2 = nn.Linear(32, 1)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply hidden layer, ReLU, then the output layer to ``x``."""
        hidden = self.fully_connected_layer(x)
        activated = torch.relu(hidden)
        return self.fully_connected_layer2(activated)

def train(epochs: int = 5, batch_size: int = 32):
    """
    Train the test network on random data and print the loss per epoch.

    Fits ``NN`` to 1000 randomly generated (input, target) pairs with
    mean-squared-error loss and SGD (lr 0.01, momentum 0.9).

    Args:
        epochs: Number of passes over the training data.
        batch_size: Number of samples per optimisation step (default 32,
            the previously hard-coded value).

    Raises:
        ValueError: If ``batch_size`` is smaller than 1.
    """
    if batch_size < 1:
        raise ValueError("batch_size must be at least 1")

    x_train = torch.randn(1000, 10)
    y_train = torch.randn(1000, 1)

    model = NN()
    costfunc = nn.MSELoss()
    optimizer = optim.SGD(model.parameters(), lr=1e-2, momentum=9e-1)

    # The loop bound is the ``epochs`` parameter; the original referenced an
    # undefined ``num_epochs`` and raised NameError on the first call.
    for epoch in range(epochs):
        for i in range(0, len(x_train), batch_size):
            inputs = x_train[i:i + batch_size]
            targets = y_train[i:i + batch_size]

            outputs = model(inputs)
            loss = costfunc(outputs, targets)

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

        # Reports the loss of the last mini-batch of the epoch.
        print(f"epoch {epoch+1} out of {epochs}, loss: {loss.item()}")

# Run a training session with the default settings when executed as a script.
if __name__ == "__main__":
    train()
Loading