diff --git a/.github/workflows/train_workflow.yml b/.github/workflows/train_workflow.yml index c42c554..5255892 100644 --- a/.github/workflows/train_workflow.yml +++ b/.github/workflows/train_workflow.yml @@ -24,13 +24,6 @@ jobs: pip install numpy # Add other Python dependencies as needed - - name: Install CUDA dependencies - if: inputs.script_type == 'cu' - run: | - sudo apt-get update - sudo apt-get install -y nvidia-cuda-toolkit - nvcc --version - - name: Debug Step if: inputs.script_type == 'cu' run: | @@ -40,6 +33,14 @@ jobs: EOL cat train.${{ inputs.script_type }} + - name: Install CUDA dependencies + if: inputs.script_type == 'cu' + run: | + sudo apt-get update + sudo apt-get install -y nvidia-cuda-toolkit + nvcc --version + + - name: Create training script if : inputs.script_type == 'py' # TODO: remove later run: |