diff --git a/.github/workflows/docker.yaml b/.github/workflows/docker.yaml index 2c4b2adc..df915ad2 100644 --- a/.github/workflows/docker.yaml +++ b/.github/workflows/docker.yaml @@ -5,6 +5,7 @@ on: - cron: '0 15 * * *' pull_request: paths: + - .github/workflows/docker.yaml - docker/*.dockerfile workflow_dispatch: inputs: @@ -12,7 +13,6 @@ on: description: "PyTorch nightly version" required: false env: - WITH_PUSH: "true" CONDA_ENV: "tritonbench" DOCKER_IMAGE: "ghcr.io/pytorch-labs/tritonbench:latest" SETUP_SCRIPT: "/workspace/setup_instance.sh" @@ -20,14 +20,14 @@ env: jobs: build-push-docker: if: ${{ github.repository_owner == 'pytorch-labs' }} - runs-on: ubuntu-latest + runs-on: 32-core-ubuntu steps: - name: Checkout uses: actions/checkout@v3 with: path: tritonbench - name: Login to GitHub Container Registry - if: ${{ env.WITH_PUSH == 'true' }} + if: github.event_name != 'pull_request' uses: docker/login-action@v2 with: registry: ghcr.io @@ -38,9 +38,9 @@ jobs: set -x export NIGHTLY_DATE="${{ github.event.inputs.nightly_date }}" cd tritonbench/docker - full_ref="${{ github.ref }}" - prefix="refs/heads/" - branch_name=${full_ref#$prefix} + # branch name is github.head_ref when triggered by pull_request + # and it is github.ref_name when triggered by workflow_dispatch + branch_name=${{ github.head_ref || github.ref_name }} docker build . --build-arg TRITONBENCH_BRANCH="${branch_name}" --build-arg FORCE_DATE="${NIGHTLY_DATE}" \ -f tritonbench-nightly.dockerfile -t ghcr.io/pytorch-labs/tritonbench:latest # Extract pytorch version from the docker @@ -48,7 +48,7 @@ jobs: export DOCKER_TAG=$(awk '{match($0, /dev[0-9]+/, arr); print arr[0]}' <<< "${PYTORCH_VERSION}") docker tag ghcr.io/pytorch-labs/tritonbench:latest ghcr.io/pytorch-labs/tritonbench:${DOCKER_TAG} - name: Push docker to remote - if: ${{ env.WITH_PUSH == 'true' }} + if: github.event_name != 'pull_request' run: | # Extract pytorch version from the docker PYTORCH_VERSION=$(docker run -e SETUP_SCRIPT="${SETUP_SCRIPT}" ghcr.io/pytorch-labs/tritonbench:latest bash -c '. "${SETUP_SCRIPT}"; python -c "import torch; print(torch.__version__)"') diff --git a/install.py b/install.py index 816b3b6a..7351a438 100644 --- a/install.py +++ b/install.py @@ -115,11 +115,7 @@ def install_tk(): if args.fbgemm or args.all: logger.info("[tritonbench] installing FBGEMM...") install_fbgemm() - # TODO: for some reason, fa2 compile will break docker build - if args.fa2: - logger.info("[tritonbench] installing fa2...") - install_fa2() - if args.fa2_compile: + if args.fa2 or args.all: logger.info("[tritonbench] installing fa2 from source...") install_fa2(compile=True) if args.fa3 or args.all: