diff --git a/.github/workflows/train_workflow.yml b/.github/workflows/train_workflow.yml index 53d55c2..5659198 100644 --- a/.github/workflows/train_workflow.yml +++ b/.github/workflows/train_workflow.yml @@ -1,30 +1,35 @@ -name: Training Workflow +name: AMD PyTorch Job + on: workflow_dispatch: inputs: script_content: description: 'Content of train.py' - required: true - type: string # Explicitly specify the type + required: false + type: string jobs: train: - runs-on: ubuntu-latest + runs-on: [amdgpu-mi250-x86-64] + steps: - - name: Install dependencies - run: | - pip install numpy - # pip install torch - need to find a way to cache this otherwise it will take a long time to install + - name: Setup Python + uses: actions/setup-python@v4 + with: + python-version: '3.10' # or your preferred version + + - name: Install dependencies + run: | + pip install numpy + pip install torch --index-url https://download.pytorch.org/whl/rocm6.2 + + - name: Create and run training script + run: | + python train.py > training.log 2>&1 - - name: Create and run training script - run: | - echo "${{ inputs.script_content }}" > train.py - cat train.py # Debug: print the content - python train.py > training.log 2>&1 - - - name: Upload logs - uses: actions/upload-artifact@v3 - if: always() # Upload logs whether the job succeeds or fails - with: - name: training-logs - path: training.log \ No newline at end of file + - name: Upload logs + uses: actions/upload-artifact@v3 + if: always() # Upload logs whether the job succeeds or fails + with: + name: training-logs + path: training.log \ No newline at end of file diff --git a/train.py b/train.py index cd867ab..d019af1 100644 --- a/train.py +++ b/train.py @@ -1,8 +1,12 @@ -import numpy +import torch -a = numpy.array([1, 2, 3]) -b = numpy.array([4, 5, 6]) +a = torch.Tensor([1, 2, 3, 4, 5]).to('cuda') +b= torch.Tensor([1, 2, 3, 4, 5]).to('cuda') -c = a + b +if torch.cuda.is_available(): + gpu_name = torch.cuda.get_device_name(0) + print(f"GPU Name: {gpu_name}") +else: + print("No GPU available") -print(c) \ No newline at end of file +print(a + b) \ No newline at end of file