From 6b6f4a2ee6362ac6dcf72e05f5e2b46d9a13d2bd Mon Sep 17 00:00:00 2001 From: peterbell10 Date: Wed, 20 Nov 2024 23:16:08 +0000 Subject: [PATCH 01/13] [CI] Fix ccache cache restoration to improve build times (#5202) This improves a warm-cache macOS build from ~25 mins to 2 mins. --- .github/workflows/integration-tests.yml | 155 ++++++++++++++------- .github/workflows/integration-tests.yml.in | 81 +++++++---- 2 files changed, 155 insertions(+), 81 deletions(-) diff --git a/.github/workflows/integration-tests.yml b/.github/workflows/integration-tests.yml index cfba6d7225..0070899e97 100644 --- a/.github/workflows/integration-tests.yml +++ b/.github/workflows/integration-tests.yml @@ -21,10 +21,12 @@ concurrency: cancel-in-progress: ${{ github.ref != 'refs/heads/main' }} permissions: read-all env: + TRITON_BUILD_WITH_CCACHE: "true" TRITON_BUILD_WITH_CLANG_LLD: "TRUE" TRITON_USE_ASSERT_ENABLED_LLVM: "TRUE" TRITON_DISABLE_LINE_INFO: 1 PROTON_SKIP_PC_SAMPLING_TEST: 1 + CCACHE_COMPRESS: "true" jobs: Runner-Preparation: runs-on: ubuntu-latest @@ -154,6 +156,8 @@ jobs: strategy: matrix: runner: ${{fromJson(needs.Runner-Preparation.outputs.matrix-CUDA)}} + env: + RUNNER_TYPE: ${{ matrix.runner[0] }} steps: - name: Checkout uses: actions/checkout@v4 @@ -199,22 +203,30 @@ jobs: # "restore" step. This is to prevent the caches from accumulating stale # files over time. name: Restore cache of ccache and Triton compilation artifacts - if: github.event_name != 'push' + id: restore-build-cache + if: github.ref != 'refs/heads/main' uses: actions/cache/restore@v4 with: path: | ~/.triton/cache - ~/.cache/ccache + ~/.ccache # Restore the most recent cache entry. - restore-keys: triton-artifacts-${{ runner.os }}-${{ runner.arch }}-${{ runner.name }}-llvm-${{ steps.cache-key.outputs.llvm }}- + restore-keys: | + triton-artifacts-${{ runner.os }}-${{ runner.arch }}-${{ env.RUNNER_TYPE }}-llvm-${{ steps.cache-key.outputs.llvm }}- + triton-artifacts-${{ runner.os }}-${{ runner.arch }}-${{ env.RUNNER_TYPE }}- # We expect this cache key never to hit and for us to fall back # unconditionally to the restore-key, so it doesn't actually matter # what we put here (so long as it doesn't hit an existing key). - key: triton-artifacts-${{ runner.os }}-${{ runner.arch }}-${{ runner.name }}-llvm-${{ steps.cache-key.outputs.llvm }}-${{ steps.cache-key.outputs.datetime }} - - name: Inspect cache directory + key: triton-artifacts-${{ runner.os }}-${{ runner.arch }}-${{ env.RUNNER_TYPE }}-llvm-${{ steps.cache-key.outputs.llvm }}-${{ steps.cache-key.outputs.datetime }} + - name: Inspect cache directories run: | mkdir -p ~/.triton ls -alh ~/.triton + du -sh ~/.triton/** + + mkdir -p ~/.ccache + ls -alh ~/.ccache + du -sh ~/.ccache - name: Update PATH run: | echo "$HOME/.local/bin" >> $GITHUB_PATH @@ -224,12 +236,14 @@ jobs: python3 -m pip install cython setuptools wheel cmake==3.24 ninja pytest-forked pytest-xdist lit - name: Install Triton env: - TRITON_BUILD_WITH_CCACHE: "true" CUDA_HOME: "/usr/local/cuda" run: | echo "PATH is '$PATH'" cd python - python3 -m pip install '.[tests]' + ccache --zero-stats + python3 -m pip install -v '.[tests]' + - name: CCache Stats + run: ccache --print-stats - name: Run lit tests run: | cd python @@ -278,6 +292,15 @@ jobs: cd third_party/proton/test python3 -m pytest -s . cd .. 
+ - name: Inspect cache directories + run: | + mkdir -p ~/.triton + ls -alh ~/.triton + du -sh ~/.triton/** + + mkdir -p ~/.ccache + ls -alh ~/.ccache + du -sh ~/.ccache - # If we're on branch `main`, save the ccache Triton compilation artifacts # to the cache so they can be used by other (non-main) CI runs. # @@ -287,22 +310,17 @@ jobs: if: github.ref == 'refs/heads/main' uses: actions/cache/save@v4 with: - path: ~/.triton/cache ~/.cache/ccache - key: triton-artifacts-${{ runner.os }}-${{ runner.arch }}-${{ runner.name }}-llvm-${{ steps.cache-key.outputs.llvm }}-${{ steps.cache-key.outputs.datetime }} - - name: Inspect cache directories - run: | - mkdir -p ~/.triton - ls -alh ~/.triton - du -sh ~/.triton/** - - mkdir -p ~/.cache/ccache - ls -alh ~/.cache/ccache - du -sh ~/.cache/ccache + path: | + ~/.triton/cache + ~/.ccache + key: ${{ steps.restore-build-cache.outputs.cache-primary-key }} Integration-Tests-AMD: needs: Runner-Preparation if: needs.Runner-Preparation.outputs.matrix-HIP != '' runs-on: ${{ matrix.runner }} timeout-minutes: 30 + env: + RUNNER_TYPE: ${{ matrix.runner[1] }} strategy: matrix: runner: ${{fromJson(needs.Runner-Preparation.outputs.matrix-HIP)}} @@ -355,22 +373,30 @@ jobs: # "restore" step. This is to prevent the caches from accumulating stale # files over time. name: Restore cache of ccache and Triton compilation artifacts - if: github.event_name != 'push' + id: restore-build-cache + if: github.ref != 'refs/heads/main' uses: actions/cache/restore@v4 with: path: | ~/.triton/cache - ~/.cache/ccache + ~/.ccache # Restore the most recent cache entry. - restore-keys: triton-artifacts-${{ runner.os }}-${{ runner.arch }}-${{ runner.name }}-llvm-${{ steps.cache-key.outputs.llvm }}- + restore-keys: | + triton-artifacts-${{ runner.os }}-${{ runner.arch }}-${{ env.RUNNER_TYPE }}-llvm-${{ steps.cache-key.outputs.llvm }}- + triton-artifacts-${{ runner.os }}-${{ runner.arch }}-${{ env.RUNNER_TYPE }}- # We expect this cache key never to hit and for us to fall back # unconditionally to the restore-key, so it doesn't actually matter # what we put here (so long as it doesn't hit an existing key). 
- key: triton-artifacts-${{ runner.os }}-${{ runner.arch }}-${{ runner.name }}-llvm-${{ steps.cache-key.outputs.llvm }}-${{ steps.cache-key.outputs.datetime }} - - name: Inspect cache directory + key: triton-artifacts-${{ runner.os }}-${{ runner.arch }}-${{ env.RUNNER_TYPE }}-llvm-${{ steps.cache-key.outputs.llvm }}-${{ steps.cache-key.outputs.datetime }} + - name: Inspect cache directories run: | mkdir -p ~/.triton ls -alh ~/.triton + du -sh ~/.triton/** + + mkdir -p ~/.ccache + ls -alh ~/.ccache + du -sh ~/.ccache - name: Update PATH run: | echo "/opt/rocm/llvm/bin" >> $GITHUB_PATH @@ -378,17 +404,24 @@ jobs: run: | python3 -m pip install --upgrade pip python3 -m pip install lit + - name: Install apt dependencies + run: | + apt update + apt install ccache - name: Install Triton id: amd-install-triton run: | echo "PATH is '$PATH'" pip uninstall -y triton cd python + ccache --zero-stats pip install -v -e '.[tests]' - name: Clean up after an unsuccessful build if: ${{ !success() && steps.amd-install-triton.outcome != 'success' }} run: | rm -rf ~/.triton + - name: CCache Stats + run: ccache --print-stats - name: Run lit tests run: | cd python @@ -431,6 +464,15 @@ jobs: cd python cd "build/$(ls build | grep -i cmake)" ctest -j32 + - name: Inspect cache directories + run: | + mkdir -p ~/.triton + ls -alh ~/.triton + du -sh ~/.triton/** + + mkdir -p ~/.ccache + ls -alh ~/.ccache + du -sh ~/.ccache - # If we're on branch `main`, save the ccache Triton compilation artifacts # to the cache so they can be used by other (non-main) CI runs. # @@ -440,17 +482,10 @@ jobs: if: github.ref == 'refs/heads/main' uses: actions/cache/save@v4 with: - path: ~/.triton/cache ~/.cache/ccache - key: triton-artifacts-${{ runner.os }}-${{ runner.arch }}-${{ runner.name }}-llvm-${{ steps.cache-key.outputs.llvm }}-${{ steps.cache-key.outputs.datetime }} - - name: Inspect cache directories - run: | - mkdir -p ~/.triton - ls -alh ~/.triton - du -sh ~/.triton/** - - mkdir -p ~/.cache/ccache - ls -alh ~/.cache/ccache - du -sh ~/.cache/ccache + path: | + ~/.triton/cache + ~/.ccache + key: ${{ steps.restore-build-cache.outputs.cache-primary-key }} - name: Clean up caches run: | rm -rf ~/.triton/cache @@ -462,6 +497,8 @@ jobs: strategy: matrix: runner: ${{fromJson(needs.Runner-Preparation.outputs.matrix-MACOS)}} + env: + RUNNER_TYPE: ${{ matrix.runner[0] }} steps: - name: Checkout uses: actions/checkout@v4 @@ -470,7 +507,7 @@ jobs: - name: Install brew dependencies run: | brew update - brew install ccache llvm@19 lld + brew install ccache llvm@19 lld coreutils - name: Compute cache keys id: cache-key run: | @@ -511,22 +548,30 @@ jobs: # "restore" step. This is to prevent the caches from accumulating stale # files over time. name: Restore cache of ccache and Triton compilation artifacts - if: github.event_name != 'push' + id: restore-build-cache + if: github.ref != 'refs/heads/main' uses: actions/cache/restore@v4 with: path: | ~/.triton/cache - ~/.cache/ccache + ~/.ccache # Restore the most recent cache entry. 
- restore-keys: triton-artifacts-${{ runner.os }}-${{ runner.arch }}-${{ runner.name }}-llvm-${{ steps.cache-key.outputs.llvm }}- + restore-keys: | + triton-artifacts-${{ runner.os }}-${{ runner.arch }}-${{ env.RUNNER_TYPE }}-llvm-${{ steps.cache-key.outputs.llvm }}- + triton-artifacts-${{ runner.os }}-${{ runner.arch }}-${{ env.RUNNER_TYPE }}- # We expect this cache key never to hit and for us to fall back # unconditionally to the restore-key, so it doesn't actually matter # what we put here (so long as it doesn't hit an existing key). - key: triton-artifacts-${{ runner.os }}-${{ runner.arch }}-${{ runner.name }}-llvm-${{ steps.cache-key.outputs.llvm }}-${{ steps.cache-key.outputs.datetime }} - - name: Inspect cache directory + key: triton-artifacts-${{ runner.os }}-${{ runner.arch }}-${{ env.RUNNER_TYPE }}-llvm-${{ steps.cache-key.outputs.llvm }}-${{ steps.cache-key.outputs.datetime }} + - name: Inspect cache directories run: | mkdir -p ~/.triton ls -alh ~/.triton + du -sh ~/.triton/** + + mkdir -p ~/.ccache + ls -alh ~/.ccache + du -sh ~/.ccache - name: Update PATH run: | echo "$HOME/.local/bin" >> $GITHUB_PATH @@ -539,7 +584,6 @@ jobs: python3 -m pip install cython setuptools wheel cmake==3.24 ninja pytest-xdist lit pybind11 - name: Install Triton env: - TRITON_BUILD_WITH_CCACHE: "true" TRITON_BUILD_WITH_O1: "true" # macos-latest has 3 vcpus and 7GB DRAM, to save memory we limit the number of jobs to 3 # https://docs.github.com/en/actions/using-github-hosted-runners/about-github-hosted-runners/about-github-hosted-runners#standard-github-hosted-runners-for-public-repositories @@ -548,7 +592,19 @@ jobs: source ~/.venv/bin/activate echo "PATH is '$PATH'" cd python - python3 -m pip install --no-build-isolation . + ccache --zero-stats + python3 -m pip install -v --no-build-isolation . + - name: CCache Stats + run: ccache --print-stats + - name: Inspect cache directories + run: | + mkdir -p ~/.triton + ls -alh ~/.triton + du -sh ~/.triton/** + + mkdir -p ~/.ccache + ls -alh ~/.ccache + du -sh ~/.ccache - # If we're on branch `main`, save the ccache Triton compilation artifacts # to the cache so they can be used by other (non-main) CI runs. # @@ -558,14 +614,7 @@ jobs: if: github.ref == 'refs/heads/main' uses: actions/cache/save@v4 with: - path: ~/.triton/cache ~/.cache/ccache - key: triton-artifacts-${{ runner.os }}-${{ runner.arch }}-${{ runner.name }}-llvm-${{ steps.cache-key.outputs.llvm }}-${{ steps.cache-key.outputs.datetime }} - - name: Inspect cache directories - run: | - mkdir -p ~/.triton - ls -alh ~/.triton - du -sh ~/.triton/** - - mkdir -p ~/.cache/ccache - ls -alh ~/.cache/ccache - du -sh ~/.cache/ccache + path: | + ~/.triton/cache + ~/.ccache + key: ${{ steps.restore-build-cache.outputs.cache-primary-key }} diff --git a/.github/workflows/integration-tests.yml.in b/.github/workflows/integration-tests.yml.in index 7da4aa0793..4404b2aa60 100644 --- a/.github/workflows/integration-tests.yml.in +++ b/.github/workflows/integration-tests.yml.in @@ -23,10 +23,12 @@ concurrency: permissions: read-all env: + TRITON_BUILD_WITH_CCACHE: "true" TRITON_BUILD_WITH_CLANG_LLD: "TRUE" TRITON_USE_ASSERT_ENABLED_LLVM: "TRUE" TRITON_DISABLE_LINE_INFO: 1 PROTON_SKIP_PC_SAMPLING_TEST: 1 + CCACHE_COMPRESS: "true" jobs: Runner-Preparation: @@ -174,6 +176,9 @@ jobs: matrix: runner: ${{fromJson(needs.Runner-Preparation.outputs.matrix-CUDA)}} + env: + RUNNER_TYPE: ${{ matrix.runner[0] }} + steps: - name: Checkout uses: actions/checkout@v4 @@ -225,24 +230,32 @@ jobs: # files over time. 
- &restore-build-artifacts-step name: Restore cache of ccache and Triton compilation artifacts - if: github.event_name != 'push' + id: restore-build-cache + if: github.ref != 'refs/heads/main' uses: actions/cache/restore@v4 with: path: | ~/.triton/cache - ~/.cache/ccache + ~/.ccache # Restore the most recent cache entry. - restore-keys: triton-artifacts-${{ runner.os }}-${{ runner.arch }}-${{ runner.name }}-llvm-${{ steps.cache-key.outputs.llvm }}- + restore-keys: | + triton-artifacts-${{ runner.os }}-${{ runner.arch }}-${{ env.RUNNER_TYPE }}-llvm-${{ steps.cache-key.outputs.llvm }}- + triton-artifacts-${{ runner.os }}-${{ runner.arch }}-${{ env.RUNNER_TYPE }}- # We expect this cache key never to hit and for us to fall back # unconditionally to the restore-key, so it doesn't actually matter # what we put here (so long as it doesn't hit an existing key). - key: triton-artifacts-${{ runner.os }}-${{ runner.arch }}-${{ runner.name }}-llvm-${{ steps.cache-key.outputs.llvm }}-${{ steps.cache-key.outputs.datetime }} + key: triton-artifacts-${{ runner.os }}-${{ runner.arch }}-${{ env.RUNNER_TYPE }}-llvm-${{ steps.cache-key.outputs.llvm }}-${{ steps.cache-key.outputs.datetime }} - - &inspect-cache-directory-step - name: Inspect cache directory + - &inspect-cache-directories-step + name: Inspect cache directories run: | mkdir -p ~/.triton ls -alh ~/.triton + du -sh ~/.triton/** + + mkdir -p ~/.ccache + ls -alh ~/.ccache + du -sh ~/.ccache - name: Update PATH run: | @@ -255,12 +268,16 @@ jobs: - name: Install Triton env: - TRITON_BUILD_WITH_CCACHE: "true" CUDA_HOME: "/usr/local/cuda" run: | echo "PATH is '$PATH'" cd python - python3 -m pip install '.[tests]' + ccache --zero-stats + python3 -m pip install -v '.[tests]' + + - &print-ccache-stats + name: CCache Stats + run: ccache --print-stats - &run-lit-tests-step name: Run lit tests @@ -319,6 +336,8 @@ jobs: python3 -m pytest -s . cd .. + - *inspect-cache-directories-step + # If we're on branch `main`, save the ccache Triton compilation artifacts # to the cache so they can be used by other (non-main) CI runs. 
#
@@ -329,19 +348,10 @@ jobs:
         if: github.ref == 'refs/heads/main'
         uses: actions/cache/save@v4
         with:
-          path: ~/.triton/cache ~/.cache/ccache
-          key: triton-artifacts-${{ runner.os }}-${{ runner.arch }}-${{ runner.name }}-llvm-${{ steps.cache-key.outputs.llvm }}-${{ steps.cache-key.outputs.datetime }}
-
-      - &inspect-cache-directories-step
-        name: Inspect cache directories
-        run: |
-          mkdir -p ~/.triton
-          ls -alh ~/.triton
-          du -sh ~/.triton/**
-
-          mkdir -p ~/.cache/ccache
-          ls -alh ~/.cache/ccache
-          du -sh ~/.cache/ccache
+          path: |
+            ~/.triton/cache
+            ~/.ccache
+          key: ${{ steps.restore-build-cache.outputs.cache-primary-key }}

   Integration-Tests-AMD:
     needs: Runner-Preparation
@@ -350,6 +360,9 @@
     runs-on: ${{ matrix.runner }}
     timeout-minutes: 30

+    env:
+      RUNNER_TYPE: ${{ matrix.runner[1] }}
+
     strategy:
       matrix:
         runner: ${{fromJson(needs.Runner-Preparation.outputs.matrix-HIP)}}
@@ -369,7 +382,7 @@
       - *compute-cache-keys-step
       - *cache-build-dependencies-step
       - *restore-build-artifacts-step
-      - *inspect-cache-directory-step
+      - *inspect-cache-directories-step

       - name: Update PATH
         run: |
@@ -380,12 +393,18 @@
           python3 -m pip install --upgrade pip
           python3 -m pip install lit

+      - name: Install apt dependencies
+        run: |
+          apt update
+          apt install ccache
+
       - name: Install Triton
         id: amd-install-triton
         run: |
           echo "PATH is '$PATH'"
           pip uninstall -y triton
           cd python
+          ccache --zero-stats
           pip install -v -e '.[tests]'

       - name: Clean up after an unsuccessful build
@@ -393,6 +412,7 @@
         run: |
           rm -rf ~/.triton

+      - *print-ccache-stats
       - *run-lit-tests-step

       - name: Run python tests on HIP
@@ -423,8 +443,8 @@
       - *run-proton-tests-step
       - *run-cpp-unittests-step

-      - *save-build-artifacts-step
       - *inspect-cache-directories-step
+      - *save-build-artifacts-step

       - name: Clean up caches
         run: |
@@ -438,6 +458,10 @@
     strategy:
       matrix:
        runner: ${{fromJson(needs.Runner-Preparation.outputs.matrix-MACOS)}}
+
+    env:
+      RUNNER_TYPE: ${{ matrix.runner[0] }}
+
     steps:
       - name: Checkout
         uses: actions/checkout@v4
@@ -446,12 +470,12 @@
       - name: Install brew dependencies
         run: |
           brew update
-          brew install ccache llvm@19 lld
+          brew install ccache llvm@19 lld coreutils

       - *compute-cache-keys-step
       - *cache-build-dependencies-step
       - *restore-build-artifacts-step
-      - *inspect-cache-directory-step
+      - *inspect-cache-directories-step

       - name: Update PATH
         run: |
@@ -465,7 +489,6 @@
           python3 -m pip install cython setuptools wheel cmake==3.24 ninja pytest-xdist lit pybind11
       - name: Install Triton
         env:
-          TRITON_BUILD_WITH_CCACHE: "true"
           TRITON_BUILD_WITH_O1: "true"
           # macos-latest has 3 vcpus and 7GB DRAM, to save memory we limit the number of jobs to 3
           # https://docs.github.com/en/actions/using-github-hosted-runners/about-github-hosted-runners/about-github-hosted-runners#standard-github-hosted-runners-for-public-repositories
@@ -474,7 +497,9 @@
           source ~/.venv/bin/activate
           echo "PATH is '$PATH'"
           cd python
-          python3 -m pip install --no-build-isolation .
+          ccache --zero-stats
+          python3 -m pip install -v --no-build-isolation .

-      - *save-build-artifacts-step
+      - *print-ccache-stats
+      - *inspect-cache-directories-step
+      - *save-build-artifacts-step

From 9c7a8c6a489801bc17f106b73bcf08ea54cae03c Mon Sep 17 00:00:00 2001
From: peterbell10
Date: Thu, 21 Nov 2024 03:36:30 +0000
Subject: [PATCH 02/13] [CI] Fix `du` failing if cache restore fails (#5206)

Follow-up to #5202. It's currently failing with the error

```
du: /Users/runner/.triton/**: No such file or directory
Error: Process completed with exit code 1.
``` which happens because even though the `.triton` directory exists, it is empty. This instead uses du on `.triton` with a depth of 1. --- .github/workflows/integration-tests.yml | 36 ++++++++-------------- .github/workflows/integration-tests.yml.in | 6 ++-- 2 files changed, 14 insertions(+), 28 deletions(-) diff --git a/.github/workflows/integration-tests.yml b/.github/workflows/integration-tests.yml index 0070899e97..011fbf824f 100644 --- a/.github/workflows/integration-tests.yml +++ b/.github/workflows/integration-tests.yml @@ -221,12 +221,10 @@ jobs: - name: Inspect cache directories run: | mkdir -p ~/.triton - ls -alh ~/.triton - du -sh ~/.triton/** + du -h -d 1 ~/.triton mkdir -p ~/.ccache - ls -alh ~/.ccache - du -sh ~/.ccache + du -h -d 1 ~/.ccache - name: Update PATH run: | echo "$HOME/.local/bin" >> $GITHUB_PATH @@ -295,12 +293,10 @@ jobs: - name: Inspect cache directories run: | mkdir -p ~/.triton - ls -alh ~/.triton - du -sh ~/.triton/** + du -h -d 1 ~/.triton mkdir -p ~/.ccache - ls -alh ~/.ccache - du -sh ~/.ccache + du -h -d 1 ~/.ccache - # If we're on branch `main`, save the ccache Triton compilation artifacts # to the cache so they can be used by other (non-main) CI runs. # @@ -391,12 +387,10 @@ jobs: - name: Inspect cache directories run: | mkdir -p ~/.triton - ls -alh ~/.triton - du -sh ~/.triton/** + du -h -d 1 ~/.triton mkdir -p ~/.ccache - ls -alh ~/.ccache - du -sh ~/.ccache + du -h -d 1 ~/.ccache - name: Update PATH run: | echo "/opt/rocm/llvm/bin" >> $GITHUB_PATH @@ -467,12 +461,10 @@ jobs: - name: Inspect cache directories run: | mkdir -p ~/.triton - ls -alh ~/.triton - du -sh ~/.triton/** + du -h -d 1 ~/.triton mkdir -p ~/.ccache - ls -alh ~/.ccache - du -sh ~/.ccache + du -h -d 1 ~/.ccache - # If we're on branch `main`, save the ccache Triton compilation artifacts # to the cache so they can be used by other (non-main) CI runs. # @@ -566,12 +558,10 @@ jobs: - name: Inspect cache directories run: | mkdir -p ~/.triton - ls -alh ~/.triton - du -sh ~/.triton/** + du -h -d 1 ~/.triton mkdir -p ~/.ccache - ls -alh ~/.ccache - du -sh ~/.ccache + du -h -d 1 ~/.ccache - name: Update PATH run: | echo "$HOME/.local/bin" >> $GITHUB_PATH @@ -599,12 +589,10 @@ jobs: - name: Inspect cache directories run: | mkdir -p ~/.triton - ls -alh ~/.triton - du -sh ~/.triton/** + du -h -d 1 ~/.triton mkdir -p ~/.ccache - ls -alh ~/.ccache - du -sh ~/.ccache + du -h -d 1 ~/.ccache - # If we're on branch `main`, save the ccache Triton compilation artifacts # to the cache so they can be used by other (non-main) CI runs. 
#
diff --git a/.github/workflows/integration-tests.yml.in b/.github/workflows/integration-tests.yml.in
index 4404b2aa60..02e0b04653 100644
--- a/.github/workflows/integration-tests.yml.in
+++ b/.github/workflows/integration-tests.yml.in
@@ -250,12 +250,10 @@ jobs:
         name: Inspect cache directories
         run: |
           mkdir -p ~/.triton
-          ls -alh ~/.triton
-          du -sh ~/.triton/**
+          du -h -d 1 ~/.triton

           mkdir -p ~/.ccache
-          ls -alh ~/.ccache
-          du -sh ~/.ccache
+          du -h -d 1 ~/.ccache

       - name: Update PATH
         run: |

From d5ba6acb33bd5b382e946b1ddf0b3c45b73554ff Mon Sep 17 00:00:00 2001
From: Keren Zhou
Date: Wed, 20 Nov 2024 20:01:16 -0800
Subject: [PATCH 03/13] [BACKEND][LAYOUT] Use LL for AMDMfma related layout
 conversions (#5210)

---
 .../TritonGPUToLLVM/ConvertLayoutOpToLLVM.cpp | 22 +++++++++----------
 test/Conversion/amd/mfma-shortcut.mlir        |  2 ++
 2 files changed, 13 insertions(+), 11 deletions(-)

diff --git a/lib/Conversion/TritonGPUToLLVM/ConvertLayoutOpToLLVM.cpp b/lib/Conversion/TritonGPUToLLVM/ConvertLayoutOpToLLVM.cpp
index e48cfca441..62499d8208 100644
--- a/lib/Conversion/TritonGPUToLLVM/ConvertLayoutOpToLLVM.cpp
+++ b/lib/Conversion/TritonGPUToLLVM/ConvertLayoutOpToLLVM.cpp
@@ -374,24 +374,24 @@ struct ConvertLayoutOpUsingLinearLayoutsConversion
   // TODO (Keren): Currently, we handle general mma/blocked/slice/dot(ampere)
   // -> mma/blocked/slice/dot(ampere) conversions. The following tasks must be
   // completed before we can remove the layoutIsOK check:
-  // 1. Support for AMD's MFMA and WMMA
+  // 1. Support for AMD's WMMA
   std::function<bool(Attribute)> layoutIsOK = [&](Attribute layout) {
-    if (auto nvidiaMma = dyn_cast<NvidiaMmaEncodingAttr>(layout)) {
-      if (useLegacyMMAConversion) {
-        return false;
-      }
-      return true;
+    if (isa<NvidiaMmaEncodingAttr, AMDMfmaEncodingAttr>(layout)) {
+      return !useLegacyMMAConversion;
     }
     if (auto dotOperand = dyn_cast<DotOperandEncodingAttr>(layout)) {
-      if (auto nvidiaMma =
-              dyn_cast<NvidiaMmaEncodingAttr>(dotOperand.getParent())) {
-        if (useLegacyMMAConversion) {
-          return false;
-        }
+      auto parent = dotOperand.getParent();
+      if (isa<NvidiaMmaEncodingAttr, AMDMfmaEncodingAttr>(parent) &&
+          useLegacyMMAConversion) {
+        return false;
+      }
+      if (auto nvidiaMma = dyn_cast<NvidiaMmaEncodingAttr>(parent)) {
         if (nvidiaMma.isAmpere()) {
           return true;
         }
       }
+      if (isa<AMDMfmaEncodingAttr>(parent)) {
+        return true;
+      }
       return false;
     }
     if (isa<BlockedEncodingAttr>(layout)) {
diff --git a/test/Conversion/amd/mfma-shortcut.mlir b/test/Conversion/amd/mfma-shortcut.mlir
index 83c9e535d8..a2c8f48718 100644
--- a/test/Conversion/amd/mfma-shortcut.mlir
+++ b/test/Conversion/amd/mfma-shortcut.mlir
@@ -7,6 +7,7 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 :
   tt.func public @shortcut_mfma16(%arg0: tensor<16x16xf16, #mfma>) {
     // CHECK-NOT: store
     // CHECK-NOT: load
+    // CHECK: llvm.return
     %0 = triton_gpu.convert_layout %arg0 : tensor<16x16xf16, #mfma> -> tensor<16x16xf16, #dotop>
     tt.return
   }
@@ -21,6 +22,7 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 :
   tt.func public @no_shortcut_mfma16(%arg0: tensor<16x16xf16, #mfma>) {
     // CHECK: store
     // CHECK: load
+    // CHECK: llvm.return
     %0 = triton_gpu.convert_layout %arg0 : tensor<16x16xf16, #mfma> -> tensor<16x16xf16, #dotop>
     tt.return
   }

From cef2671e1bca3869e2c95baabd34fae8eb29bb5a Mon Sep 17 00:00:00 2001
From: peterbell10
Date: Thu, 21 Nov 2024 15:59:30 +0000
Subject: [PATCH 04/13] [BUILD] Add option to limit number of parallel link
 jobs (#5212)

---
 CMakeLists.txt  | 7 +++++++
 python/setup.py | 1 +
 2 files changed, 8 insertions(+)

diff --git a/CMakeLists.txt b/CMakeLists.txt
index 6ee4774860..56564c3896 100644
--- a/CMakeLists.txt
+++ b/CMakeLists.txt
@@ -44,6 +44,13 @@ if(TRITON_BUILD_WITH_CCACHE)
   endif()
 endif()

+set(TRITON_PARALLEL_LINK_JOBS "" CACHE STRING
+  "Define the maximum number of concurrent link jobs (Ninja only).")
+if (TRITON_PARALLEL_LINK_JOBS)
+  set_property(GLOBAL APPEND PROPERTY JOB_POOLS link_job_pool=${TRITON_PARALLEL_LINK_JOBS})
+  set(CMAKE_JOB_POOL_LINK link_job_pool)
+endif()
+
 # Ensure Python3 vars are set correctly
 # used conditionally in this file and by lit tests

diff --git a/python/setup.py b/python/setup.py
index 86c4013b07..607670fc5e 100644
--- a/python/setup.py
+++ b/python/setup.py
@@ -465,6 +465,7 @@ def build_extension(self, ext):
             "TRITON_BUILD_PROTON",
             "TRITON_BUILD_TUTORIALS",
             "TRITON_BUILD_WITH_CCACHE",
+            "TRITON_PARALLEL_LINK_JOBS",
         ]
         cmake_args += [f"-D{option}={os.getenv(option)}" for option in passthrough_args if option in os.environ]

From 66012fcb0e796511762c2de062b6a86bcddf8aac Mon Sep 17 00:00:00 2001
From: peterbell10
Date: Thu, 21 Nov 2024 16:12:38 +0000
Subject: [PATCH 05/13] [CI] Fix cache not saving (#5213)

#### Commits in this PR

1. [CI] Fix cache not saving

   Re-using the output of the cache restore step was recommended by the
   `actions/cache` docs, but it doesn't work here because we actually start
   from a clean cache when we run save, so there is no output available to
   read. The annoyance of testing in the PR while main is a different
   environment.

2. Bump macOS timeout

---
 .github/workflows/integration-tests.yml    | 8 ++++----
 .github/workflows/integration-tests.yml.in | 4 ++--
 2 files changed, 6 insertions(+), 6 deletions(-)

diff --git a/.github/workflows/integration-tests.yml b/.github/workflows/integration-tests.yml
index 011fbf824f..55863099d4 100644
--- a/.github/workflows/integration-tests.yml
+++ b/.github/workflows/integration-tests.yml
@@ -309,7 +309,7 @@ jobs:
         path: |
           ~/.triton/cache
           ~/.ccache
-        key: ${{ steps.restore-build-cache.outputs.cache-primary-key }}
+        key: triton-artifacts-${{ runner.os }}-${{ runner.arch }}-${{ env.RUNNER_TYPE }}-llvm-${{ steps.cache-key.outputs.llvm }}-${{ steps.cache-key.outputs.datetime }}
   Integration-Tests-AMD:
     needs: Runner-Preparation
     if: needs.Runner-Preparation.outputs.matrix-HIP != ''
@@ -477,7 +477,7 @@
         path: |
           ~/.triton/cache
           ~/.ccache
-        key: ${{ steps.restore-build-cache.outputs.cache-primary-key }}
+        key: triton-artifacts-${{ runner.os }}-${{ runner.arch }}-${{ env.RUNNER_TYPE }}-llvm-${{ steps.cache-key.outputs.llvm }}-${{ steps.cache-key.outputs.datetime }}
     - name: Clean up caches
       run: |
         rm -rf ~/.triton/cache
@@ -485,7 +485,7 @@
     needs: Runner-Preparation
     if: needs.Runner-Preparation.outputs.matrix-MACOS != ''
     runs-on: ${{ matrix.runner }}
-    timeout-minutes: 30
+    timeout-minutes: 40
     strategy:
       matrix:
         runner: ${{fromJson(needs.Runner-Preparation.outputs.matrix-MACOS)}}
@@ -605,4 +605,4 @@
         path: |
          ~/.triton/cache
           ~/.ccache
-        key: ${{ steps.restore-build-cache.outputs.cache-primary-key }}
+        key: triton-artifacts-${{ runner.os }}-${{ runner.arch }}-${{ env.RUNNER_TYPE }}-llvm-${{ steps.cache-key.outputs.llvm }}-${{ steps.cache-key.outputs.datetime }}
diff --git a/.github/workflows/integration-tests.yml.in b/.github/workflows/integration-tests.yml.in
index 02e0b04653..d4917816a4 100644
--- a/.github/workflows/integration-tests.yml.in
+++ b/.github/workflows/integration-tests.yml.in
@@ -349,7 +349,7 @@
         path: |
           ~/.triton/cache
           ~/.ccache
-        key: ${{ steps.restore-build-cache.outputs.cache-primary-key }}
+        key: triton-artifacts-${{ runner.os }}-${{ runner.arch }}-${{ env.RUNNER_TYPE }}-llvm-${{ steps.cache-key.outputs.llvm }}-${{ steps.cache-key.outputs.datetime }}

   Integration-Tests-AMD:
needs: Runner-Preparation @@ -452,7 +452,7 @@ jobs: needs: Runner-Preparation if: needs.Runner-Preparation.outputs.matrix-MACOS != '' runs-on: ${{ matrix.runner }} - timeout-minutes: 30 + timeout-minutes: 40 strategy: matrix: runner: ${{fromJson(needs.Runner-Preparation.outputs.matrix-MACOS)}} From 29e18cf36060327259c277c5d0f9c1150f2ce3e4 Mon Sep 17 00:00:00 2001 From: Anatoly Myachev Date: Thu, 21 Nov 2024 17:17:23 +0100 Subject: [PATCH 06/13] Revert "Add test_scan_layouts to skiplist (#2663)" (#2671) Closes #2662 This reverts commit https://github.com/intel/intel-xpu-backend-for-triton/commit/4382295d96f180fd5ff2636fb71f2aacf244111a. This issue is still relevant for LTS. ~Blocked by https://github.com/intel/intel-xpu-backend-for-triton/pull/2657~ --------- Signed-off-by: Anatoly Myachev --- scripts/skiplist/a770/language.txt | 2 -- scripts/skiplist/conda/language.txt | 2 -- scripts/skiplist/default/language.txt | 2 -- scripts/skiplist/mtl/language.txt | 2 -- scripts/skiplist/xe2/language.txt | 2 -- 5 files changed, 10 deletions(-) diff --git a/scripts/skiplist/a770/language.txt b/scripts/skiplist/a770/language.txt index e833b924bd..7e3e8d62fc 100644 --- a/scripts/skiplist/a770/language.txt +++ b/scripts/skiplist/a770/language.txt @@ -1,7 +1,5 @@ # https://github.com/intel/intel-xpu-backend-for-triton/issues/1434 test/unit/language/test_core.py::test_precise_math[1-tl.math.sqrt_rn(x)-tl.math.sqrt(x.to(tl.float64)).to(tl.float32)] -# https://github.com/intel/intel-xpu-backend-for-triton/issues/2662 -test/unit/language/test_core.py::test_scan_layouts[True-1-src_layout10-64-32] # https://github.com/intel/intel-xpu-backend-for-triton/issues/2703 test/unit/language/test_core.py::test_chained_reductions[in_shape0-perm0-red_dims0] test/unit/language/test_core.py::test_dot3d[1-1-32-32-32-32-32-float16-float16] diff --git a/scripts/skiplist/conda/language.txt b/scripts/skiplist/conda/language.txt index 41035163ff..1f2dcf0d10 100644 --- a/scripts/skiplist/conda/language.txt +++ b/scripts/skiplist/conda/language.txt @@ -115,8 +115,6 @@ test/unit/language/test_core.py::test_dot_max_num_imprecise_acc[64-float8e4b15-1 test/unit/language/test_core.py::test_dot_max_num_imprecise_acc[128-float8e5-128-256-128-128-256-256] # https://github.com/intel/intel-xpu-backend-for-triton/issues/1434 test/unit/language/test_core.py::test_precise_math[1-tl.math.sqrt_rn(x)-tl.math.sqrt(x.to(tl.float64)).to(tl.float32)] -# https://github.com/intel/intel-xpu-backend-for-triton/issues/2662 -test/unit/language/test_core.py::test_scan_layouts[True-1-src_layout10-64-32] # https://github.com/intel/intel-xpu-backend-for-triton/issues/2703 test/unit/language/test_core.py::test_chained_reductions[in_shape0-perm0-red_dims0] test/unit/language/test_core.py::test_dot3d[1-1-32-32-32-32-32-float16-float16] diff --git a/scripts/skiplist/default/language.txt b/scripts/skiplist/default/language.txt index fb018c5e0f..a891b802b5 100644 --- a/scripts/skiplist/default/language.txt +++ b/scripts/skiplist/default/language.txt @@ -1,6 +1,4 @@ # https://github.com/intel/intel-xpu-backend-for-triton/issues/1434 test/unit/language/test_core.py::test_precise_math[1-tl.math.sqrt_rn(x)-tl.math.sqrt(x.to(tl.float64)).to(tl.float32)] -# https://github.com/intel/intel-xpu-backend-for-triton/issues/2662 -test/unit/language/test_core.py::test_scan_layouts[True-1-src_layout10-64-32] # https://github.com/intel/intel-xpu-backend-for-triton/issues/2703 test/unit/language/test_core.py::test_chained_reductions[in_shape0-perm0-red_dims0] diff --git 
a/scripts/skiplist/mtl/language.txt b/scripts/skiplist/mtl/language.txt
index df2e44aae4..69530824f3 100644
--- a/scripts/skiplist/mtl/language.txt
+++ b/scripts/skiplist/mtl/language.txt
@@ -1,7 +1,5 @@
 # https://github.com/intel/intel-xpu-backend-for-triton/issues/1434
 test/unit/language/test_core.py::test_precise_math[1-tl.math.sqrt_rn(x)-tl.math.sqrt(x.to(tl.float64)).to(tl.float32)]
-# https://github.com/intel/intel-xpu-backend-for-triton/issues/2662
-test/unit/language/test_core.py::test_scan_layouts[True-1-src_layout10-64-32]
 # https://github.com/intel/intel-xpu-backend-for-triton/issues/2703
 test/unit/language/test_core.py::test_chained_reductions[in_shape0-perm0-red_dims0]
 test/unit/language/test_core.py::test_dot3d[1-1-32-32-32-32-32-float16-float16]
diff --git a/scripts/skiplist/xe2/language.txt b/scripts/skiplist/xe2/language.txt
index fb018c5e0f..a891b802b5 100644
--- a/scripts/skiplist/xe2/language.txt
+++ b/scripts/skiplist/xe2/language.txt
@@ -1,6 +1,4 @@
 # https://github.com/intel/intel-xpu-backend-for-triton/issues/1434
 test/unit/language/test_core.py::test_precise_math[1-tl.math.sqrt_rn(x)-tl.math.sqrt(x.to(tl.float64)).to(tl.float32)]
-# https://github.com/intel/intel-xpu-backend-for-triton/issues/2662
-test/unit/language/test_core.py::test_scan_layouts[True-1-src_layout10-64-32]
 # https://github.com/intel/intel-xpu-backend-for-triton/issues/2703
 test/unit/language/test_core.py::test_chained_reductions[in_shape0-perm0-red_dims0]

From de1f346aa6737fa2e3e6a8a64dae118fcfab9995 Mon Sep 17 00:00:00 2001
From: Mario Lezcano Casado <3291265+lezcano@users.noreply.github.com>
Date: Thu, 21 Nov 2024 16:43:00 +0000
Subject: [PATCH 07/13] [LAYOUTS] Implement IR support for LinearLayouts
 (#5170)

We also exercise this in scale_dot, where we enable support for warps of
arbitrary shape (before we just allowed `[num_warps, 1]`).

With this infra in place, it should be rather easy to move from the legacy
layouts to using LLs to represent all of our layouts.

Something I'm concerned about is the amount of recomputation that happens
when calling methods like `getSizePerThread` and the like, where we keep
recomputing the result. There might be an optimisation opportunity here
where we cache the result of all these functions.

We choose the IR representation of an LL via its canonical form + a
`repOrder` for several reasons:
- It's generally more compact
- It's easier to CSE, so it's easier to see when two layouts are in fact
  the same.
- A technical reason: the `toLinearLayout` function returns a tensor with
  dimensions `dim0, ..., dim<rank-1>`; in other words, it "forgets" the
  repetition order. Without the repetition order, we cannot recover the
  tile size of the argument. In particular, we cannot recover
  `getSizePerThread`. There is an argument to be made about whether
  `getSizePerThread` is useful on its own, or whether `getElemsPerThread`
  is the real useful abstraction here, but for now, we keep both for BC.
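To make the chosen representation concrete, here is a minimal sketch of
building such a layout programmatically and wrapping it in the new attribute.
This is illustrative only and not code from this PR: it uses the existing
`LinearLayout::identity1D` and `operator*` APIs from `LinearLayout.h`, and
`ctx` stands in for a valid `MLIRContext *`:

```cpp
// Build a 1-D surjective layout over 256 elements of dim0:
// 2 elements per register, 32 lanes, 4 warps, 1 block.
auto S = [&](llvm::StringRef s) { return mlir::StringAttr::get(ctx, s); };
mlir::triton::LinearLayout ll =
    mlir::triton::LinearLayout::identity1D(2, S("register"), S("dim0")) *
    mlir::triton::LinearLayout::identity1D(32, S("lane"), S("dim0")) *
    mlir::triton::LinearLayout::identity1D(4, S("warp"), S("dim0")) *
    mlir::triton::LinearLayout::identity1D(1, S("block"), S("dim0"));
auto attr = mlir::triton::gpu::LinearEncodingAttr::get(ctx, std::move(ll));
// The attribute prints in the shape-agnostic canonical form, roughly:
// #triton_gpu.linear<{register = [[1]], lane = [[2], [4], [8], [16], [32]],
//                     warp = [[64], [128]], block = []}>
```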
---
 include/triton/Dialect/TritonGPU/IR/Dialect.h |  65 +++
 .../Dialect/TritonGPU/IR/TritonGPUAttrDefs.td |  30 +-
 include/triton/Tools/LinearLayout.h           |   3 +
 .../TritonGPUToLLVM/ConvertLayoutOpToLLVM.cpp |   3 +
 .../TritonGPUToLLVM/MemoryOpToLLVM.cpp        |   5 +-
 lib/Dialect/TritonGPU/IR/Dialect.cpp          | 456 ++++++++++++++++++
 .../TritonGPU/IR/LinearLayoutConversions.cpp  | 143 +-----
 lib/Dialect/TritonGPU/IR/Ops.cpp              |  10 +-
 .../TritonGPU/Transforms/AccelerateMatmul.cpp |  48 +-
 .../Transforms/RemoveLayoutConversions.cpp    |   9 +-
 lib/Tools/LinearLayout.cpp                    |  60 +--
 test/TritonGPU/accelerate-matmul.mlir         |   6 +-
 test/TritonGPU/ops.mlir                       |  14 +
 unittest/Dialect/TritonGPU/DialectTest.cpp    | 130 ++++-
 14 files changed, 785 insertions(+), 197 deletions(-)

diff --git a/include/triton/Dialect/TritonGPU/IR/Dialect.h b/include/triton/Dialect/TritonGPU/IR/Dialect.h
index f2715043d7..e0865e12af 100644
--- a/include/triton/Dialect/TritonGPU/IR/Dialect.h
+++ b/include/triton/Dialect/TritonGPU/IR/Dialect.h
@@ -149,6 +149,71 @@ triton::gpu::BlockedEncodingAttr
 getDefaultBlockedEncoding(MLIRContext *context, ArrayRef<int64_t> shape,
                           int numWarps, int threadsPerWarp, int numCTAs);

+// For each output dimension d, ensure that the layout's output size (i.e., its
+// codomain) does not exceed shape[d]. Do this without changing the size of the
+// layout's inputs (i.e., leave its domain unchanged).
+//
+// This function is invariant to the order of the layout's input and output
+// dimensions.
+//
+// We achieve this by setting the largest value in each output dimension d to 0
+// because bases that map to a location larger than shape[d]
+// effectively duplicate along that dimension. For example, consider a layout
+// with an output dimension size of 32, and we call ensureLayoutNotLargerThan to
+// shrink the output dimension size to 8:
+//
+// L(register=1) = 8
+// L(register=2) = 4
+// L(register=4) = 1
+// L(lane=1) = 2
+// L(lane=2) = 16
+//
+// In the first step, we shrink the output dimension size to 16 by setting
+// L(lane=2) to 0:
+//
+// L(register=1) = 8
+// L(register=2) = 4
+// L(register=4) = 1
+// L(lane=1) = 2
+// L(lane=2) = 0
+//
+// This means that lane=2 has the same data as lane=0.
+//
+// Now the output dimension of this layout has a size of 16, which is still
+// larger than 8. We find the current largest value in the output dimension,
+// which is L(register=1) = 8, and we set L(register=1) to 0:
+//
+// L(register=1) = 0
+// L(register=2) = 4
+// L(register=4) = 1
+// L(lane=1) = 2
+// L(lane=2) = 0
+//
+// Now the output dimension of this layout has a size of 8, which is the desired
+// size. Note that this method works only because the bases are powers of two,
+// which is the case for DistributedLayouts. If broadcastRegisters is false, we
+// remove any register that's larger than the desired shape. In the example
+// above we would have
+// L(register=1) = 4
+// L(register=2) = 1
+// L(lane=1) = 2
+// L(lane=2) = 0
+LinearLayout
+ensureLayoutNotLargerThan(const LinearLayout &layout,
+                          const llvm::SmallDenseMap<StringAttr, int64_t> &shape,
+                          bool broadcastRegisters = true);
+
+// For each out-dim d, ensure the layout's out-size (i.e. its codomain) is no
+// smaller than shape[d]. Do this by increasing the size of the layout's inputs
+// along its most-minor dimension ("register" for register layouts, "offset" for
+// shared layouts).
+//
+// This function is invariant to the order of the layout's input dimensions, but
+// it cares about the order of the output dims, which should be minor-to-major.
+LinearLayout ensureLayoutNotSmallerThan(
+    const LinearLayout &layout,
+    const llvm::SmallDenseMap<StringAttr, int64_t> &shape);
+
 // Dump information about which threads/registers contain each of the tensor
 // elements.
 void dumpLayout(RankedTensorType tensorType);
diff --git a/include/triton/Dialect/TritonGPU/IR/TritonGPUAttrDefs.td b/include/triton/Dialect/TritonGPU/IR/TritonGPUAttrDefs.td
index e6be2f8332..26ff9f7e3a 100644
--- a/include/triton/Dialect/TritonGPU/IR/TritonGPUAttrDefs.td
+++ b/include/triton/Dialect/TritonGPU/IR/TritonGPUAttrDefs.td
@@ -56,7 +56,6 @@ Right now, Triton implements two main classes of layouts: shared, and distribute
   code extraBaseClassDeclaration = [{
     unsigned getTotalElemsPerThread(ArrayRef<int64_t> shape, Type eltTy) const;
     SmallVector<unsigned> getElemsPerThread(ArrayRef<int64_t> shape, Type eltTy) const;
-    ::mlir::LogicalResult verifyLayoutForArg(::mlir::Operation* op, unsigned argNo) const;
   }];
 }

@@ -147,7 +146,6 @@ addition, if there's only one CTA per CGA, then Triton canonicalizes CTAOrder to
   let genVerifyDecl = 1;
   let skipDefaultBuilders = 1;
 }
-
 //===----------------------------------------------------------------------===//
 // Shared Layout Encoding
 //===----------------------------------------------------------------------===//
@@ -565,6 +563,34 @@ L(T) = [ {0,8} , {1,9} , {2,10}, {3,11}, {0,8} , {1, 9} , {2, 10}, {3, 11},
   }];
 }

+//===----------------------------------------------------------------------===//
+// Linear Layout Encoding
+//===----------------------------------------------------------------------===//
+
+def LinearEncodingAttr : DistributedEncoding<"LinearEncoding", "linear_encoding"> {
+  let mnemonic = "linear";
+
+  let description = [{
+    See the docs in LinearLayout.h for the definition of linear layouts.
+  }];
+
+  let parameters = (ins "LinearLayout":$linearLayout);
+
+  let extraClassDeclaration = extraDistributedDeclaration # [{
+    SmallVector<unsigned> getContigPerThread() const;
+    SmallVector<unsigned> getOrder() const;
+  }];
+
+  let genVerifyDecl = 1;
+  // Example of assembly format:
+  // <{register = [[0, 1], [8, 0], [0, 8], [64, 0]],
+  //   lane = [[0, 2], [0, 4], [1, 0], [2, 0], [4, 0]],
+  //   warp = [[16, 0], [32, 0]],
+  //   block = []}>
+  let hasCustomAssemblyFormat = 1;
+}
+
+
 //===----------------------------------------------------------------------===//
 // Blocked Layout Encoding
 //===----------------------------------------------------------------------===//
diff --git a/include/triton/Tools/LinearLayout.h b/include/triton/Tools/LinearLayout.h
index 47e3fca79b..cfc4c0d13b 100644
--- a/include/triton/Tools/LinearLayout.h
+++ b/include/triton/Tools/LinearLayout.h
@@ -9,6 +9,7 @@
 #include <vector>

 #include "mlir/IR/BuiltinAttributes.h"
+#include "llvm/ADT/Hashing.h"
 #include "llvm/ADT/MapVector.h"
 #include "llvm/ADT/STLExtras.h"
 #include "llvm/ADT/SetVector.h"
@@ -432,6 +433,7 @@ class LinearLayout {
   // (e.g. by reshaping) then the order doesn't really affect anything.
   auto getInDimNames() const { return llvm::make_first_range(bases); }
   auto getOutDimNames() const { return llvm::make_first_range(outDims); }
+  auto getOutDimSizes() const { return llvm::make_second_range(outDims); }

   // Gets the position that this outDim occupies in getOutDimNames(). Asserts
   // if the dim is not present.
@@ -693,6 +695,7 @@ class LinearLayout {
     return !(lhs == rhs);
   }
   bool equalIgnoringOutDimSizes(const LinearLayout &other) const;
+  friend size_t hash_value(const LinearLayout &layout);

 private:
   // Factory function that gracefully fails rather than asserts if the layout is
diff --git a/lib/Conversion/TritonGPUToLLVM/ConvertLayoutOpToLLVM.cpp b/lib/Conversion/TritonGPUToLLVM/ConvertLayoutOpToLLVM.cpp
index 62499d8208..aab97c7dd2 100644
--- a/lib/Conversion/TritonGPUToLLVM/ConvertLayoutOpToLLVM.cpp
+++ b/lib/Conversion/TritonGPUToLLVM/ConvertLayoutOpToLLVM.cpp
@@ -397,6 +397,9 @@ struct ConvertLayoutOpUsingLinearLayoutsConversion
     if (isa<BlockedEncodingAttr>(layout)) {
       return true;
     }
+    if (isa<LinearEncodingAttr>(layout)) {
+      return true;
+    }
     if (auto slice = dyn_cast<SliceEncodingAttr>(layout)) {
       return layoutIsOK(slice.getParent());
     }
diff --git a/lib/Conversion/TritonGPUToLLVM/MemoryOpToLLVM.cpp b/lib/Conversion/TritonGPUToLLVM/MemoryOpToLLVM.cpp
index 4cea14f095..b090670d95 100644
--- a/lib/Conversion/TritonGPUToLLVM/MemoryOpToLLVM.cpp
+++ b/lib/Conversion/TritonGPUToLLVM/MemoryOpToLLVM.cpp
@@ -165,8 +165,8 @@ struct LocalLoadOpConversion : public ConvertOpToLLVMPattern<LocalLoadOp> {
     Attribute srcLayout = srcTy.getEncoding();
     Attribute dstLayout = dstTy.getEncoding();
     if (isa<SharedEncodingAttr>(srcLayout) &&
-        (isa<BlockedEncodingAttr, MmaEncodingTrait, SliceEncodingAttr>(
-             dstLayout) ||
+        (isa<BlockedEncodingAttr, MmaEncodingTrait, SliceEncodingAttr,
+             LinearEncodingAttr>(dstLayout) ||
          isSupportedDotOpLayout(dstTy))) {
       return lowerSharedToDistributed(op, adaptor, getTypeConverter(),
                                       rewriter);
@@ -206,7 +206,6 @@ struct LocalLoadOpConversion : public ConvertOpToLLVMPattern<LocalLoadOp> {
     auto dstTy = op.getResult().getType();
     auto dstShape = dstTy.getShape();
     auto srcSharedLayout = cast<SharedEncodingAttr>(srcTy.getEncoding());
-    auto dstLayout = dstTy.getEncoding();
     assert((dstShape.size() <= 2 || isSupportedDotOpLayout(dstTy)) &&
            "Unexpected rank of ConvertLayout(shared->distributed)");
diff --git a/lib/Dialect/TritonGPU/IR/Dialect.cpp b/lib/Dialect/TritonGPU/IR/Dialect.cpp
index c785271808..5ae07c3378 100644
--- a/lib/Dialect/TritonGPU/IR/Dialect.cpp
+++ b/lib/Dialect/TritonGPU/IR/Dialect.cpp
@@ -318,6 +318,9 @@ SmallVector<unsigned> getOrder(Attribute layout) {
   if (auto sharedLayout = mlir::dyn_cast<SharedEncodingAttr>(layout)) {
     return llvm::to_vector(sharedLayout.getOrder());
   }
+  if (auto linearLayout = mlir::dyn_cast<LinearEncodingAttr>(layout)) {
+    return linearLayout.getOrder();
+  }

   llvm::report_fatal_error("Unimplemented usage of getOrder");
   return {};
@@ -541,6 +544,102 @@ getDefaultBlockedEncoding(MLIRContext *context, ArrayRef<int64_t> shape,
   return encoding;
 }

+LinearLayout
+ensureLayoutNotLargerThan(const LinearLayout &layout,
+                          const llvm::SmallDenseMap<StringAttr, int64_t> &shape,
+                          bool broadcastRegisters) {
+  assert(shape.size() == layout.getNumOutDims());
+  if (shape.empty()) {
+    return layout;
+  }
+  MLIRContext *ctx = shape.begin()->first.getContext();
+
+  auto bases = layout.getBases();
+
+  auto kRegister = StringAttr::get(ctx, "register");
+  std::set<int32_t> broadcastedDims;
+
+  for (auto outDim : llvm::enumerate(layout.getOutDimNames())) {
+    auto outDimName = outDim.value();
+    int32_t actualSize = layout.getOutDimSize(outDimName);
+    int32_t desiredSize = shape.lookup(outDimName);
+    if (actualSize <= desiredSize) {
+      continue;
+    }
+    assert(actualSize % desiredSize == 0);
+    // <inDimName, basisIdx, outValue>
+    std::vector<std::tuple<StringAttr, size_t, int32_t>> sortedBases;
+    for (auto [inDimName, basis] : bases) {
+      for (size_t basisIdx = 0; basisIdx < basis.size(); basisIdx++) {
+        auto outValue = basis[basisIdx][outDim.index()];
+        if (outValue == 0) {
+          continue;
+        }
+        assert(llvm::isPowerOf2_32(outValue));
+        sortedBases.emplace_back(inDimName, basisIdx, outValue);
+      }
+    }
+    // From the largest basis to the smallest.
+    llvm::sort(sortedBases,
+               [](auto a, auto b) { return std::get<2>(a) > std::get<2>(b); });
+    for (auto [inDimName, basisIdx, outValue] : sortedBases) {
+      if (actualSize <= desiredSize) {
+        break;
+      }
+      if (!broadcastRegisters && inDimName == kRegister) {
+        broadcastedDims.insert(basisIdx);
+      } else {
+        bases[inDimName][basisIdx][outDim.index()] = 0;
+      }
+      actualSize >>= 1;
+    }
+  }
+  if (!broadcastRegisters) {
+    // Remove broadcasted registers
+    std::vector<std::vector<int32_t>> newBasesRegister;
+    for (auto [idx, basis] : llvm::enumerate(bases[kRegister])) {
+      // Remove if it's broadcasted
+      if (broadcastedDims.find(idx) == broadcastedDims.end()) {
+        newBasesRegister.push_back(std::move(basis));
+      }
+    }
+    bases[kRegister] = std::move(newBasesRegister);
+  }
+
+  return LinearLayout(std::move(bases),
+                      llvm::to_vector(layout.getOutDimNames()));
+}
+
+// For each out-dim d, ensure the layout's out-size (i.e. its codomain) is no
+// smaller than shape[d]. Do this by increasing the size of the layout's inputs
+// along its most-minor dimension ("register" for register layouts, "offset" for
+// shared layouts).
+//
+// This function is invariant to the order of the layout's input dimensions, but
+// it cares about the order of the output dims, which should be minor-to-major.
+LinearLayout ensureLayoutNotSmallerThan(
+    const LinearLayout &layout,
+    const llvm::SmallDenseMap<StringAttr, int64_t> &shape) {
+  assert(shape.size() == layout.getNumOutDims());
+  if (shape.empty()) {
+    return layout;
+  }
+
+  MLIRContext *ctx = shape.begin()->first.getContext();
+  StringAttr kDim = *layout.getInDimNames().begin();
+  assert(kDim == "register" || kDim == "offset");
+
+  LinearLayout ret = layout;
+  for (StringAttr outDimName : layout.getOutDimNames()) {
+    int32_t actualSize = layout.getOutDimSize(outDimName);
+    int32_t desiredSize = shape.lookup(outDimName);
+    assert(actualSize > desiredSize || desiredSize % actualSize == 0);
+    ret *= LinearLayout::identity1D(desiredSize / actualSize, kDim, outDimName);
+    assert(ret.getOutDimSize(outDimName) >= desiredSize);
+  }
+  return ret;
+}
+
 } // namespace gpu
 } // namespace triton
 } // namespace mlir
@@ -1197,6 +1296,360 @@ void BlockedEncodingAttr::print(mlir::AsmPrinter &printer) const {
   printer << "}>";
 }

+// FIXME Can we take the LinearLayout by const&?
+LogicalResult
+LinearEncodingAttr::verify(function_ref<InFlightDiagnostic()> emitError,
+                           LinearLayout linearLayout) {
+  // Example of LinearEncodingAttr
+  // <{register = [[0, 1], [8, 0], [0, 8], [64, 0]],
+  //   lane = [[0, 2], [0, 4], [1, 0], [2, 0], [4, 0]],
+  //   warp = [[16, 0], [32, 0]],
+  //   block = []}>
+  // The input dims must be {register, lane, warp, block}
+  // The output dims of the linear layout should be dim0..dim[rank-1]
+
+  static const auto expectedInDims =
+      SmallVector<std::string>({"register", "lane", "warp", "block"});
+  for (const auto &[i, dims] : llvm::enumerate(
+           llvm::zip(linearLayout.getInDimNames(), expectedInDims))) {
+    const auto &[dim, expectedDimStr] = dims;
+    if (dim.str() != expectedDimStr) {
+      return emitError() << "Expected input dimension " << i << " to be '"
+                         << expectedDimStr << "'. Got " << dim;
+    }
+  }
+
+  // outDims are ['dim0', 'dim1', ...]
+  for (auto [i, dim] : llvm::enumerate(linearLayout.getOutDimNames())) {
+    if (dim.str() != ("dim" + llvm::Twine(i)).str()) {
+      return emitError()
+             << "Expected output dimensions to be ['dim0', 'dim1', ...]. Got "
+             << dim << " at position " << i;
+    }
+  }
+
+  const auto &bases = linearLayout.getBases();
+  auto nonZero = [](auto val) { return val != 0; };
+  for (const auto &dimBases : llvm::make_second_range(bases)) {
+    if (!llvm::all_of(dimBases, [&](const auto &basis) {
+          return std::count_if(basis.begin(), basis.end(), nonZero) <= 1;
+        })) {
+      return emitError()
+             << "In a distributed layout, each base must move in at most one "
+                "dimension.";
+    }
+  }
+
+  return success();
+}
+
+void LinearEncodingAttr::print(mlir::AsmPrinter &printer) const {
+  // We don't use the default implementation as it's a bit too verbose
+  // This prints in the following format that is shape agnostic, in the sense
+  // that we don't print explicitly the outShape of the LL
+  // We always assume LLs to be surjective
+  // <{register = [[0, 1], [8, 0], [0, 8], [64, 0]],
+  //   lane = [[0, 2], [0, 4], [1, 0], [2, 0], [4, 0]],
+  //   warp = [[16, 0], [32, 0]],
+  //   block = []}>
+  auto ll = getLinearLayout();
+  printer << "<{" << join(ll.getBases(), ", ", [](const auto &base) {
+    return base.first.str() + " = " + "[" +
+           join(base.second, ", ",
+                [](const std::vector<int32_t> &vec) {
+                  return "[" + join(vec, ", ") + "]";
+                }) +
+           "]";
+  }) << "}>";
+}
+
+Attribute LinearEncodingAttr::parse(AsmParser &parser, Type type) {
+  if (parser.parseLess().failed())
+    return {};
+
+  DictionaryAttr dict;
+  if (parser.parseAttribute(dict).failed())
+    return {};
+
+  if (parser.parseGreater().failed())
+    return {};
+
+  LinearLayout::BasesT bases;
+
+  // Parse the basis names in order (the order is relevant)
+  std::vector<std::string> inDimNames = {"register", "lane", "warp", "block"};
+
+  for (const auto &inDimNameStr : inDimNames) {
+    auto inDimName = StringAttr::get(parser.getContext(), inDimNameStr);
+    Attribute value = dict.get(inDimName);
+
+    // Expecting an array of arrays
+    auto arrayOfArraysAttr = mlir::dyn_cast<ArrayAttr>(value);
+    if (!arrayOfArraysAttr) {
+      parser.emitError(parser.getCurrentLocation(),
+                       "Expected array of arrays for basis of '")
+          << inDimName.getValue() << "'";
+      return {};
+    }
+
+    std::vector<std::vector<int32_t>> inDimBases;
+    for (Attribute arrayAttr : arrayOfArraysAttr) {
+      auto intArrayAttr = mlir::dyn_cast<ArrayAttr>(arrayAttr);
+      if (!intArrayAttr) {
+        parser.emitError(parser.getCurrentLocation(),
+                         "Expected array of integers in basis for '")
+            << inDimName.getValue() << "'";
+        return {};
+      }
+      std::vector<int32_t> basis;
+      for (Attribute intAttr : intArrayAttr) {
+        auto intValueAttr = mlir::dyn_cast<IntegerAttr>(intAttr);
+        if (!intValueAttr) {
+          parser.emitError(parser.getCurrentLocation(),
+                           "Expected integer in basis for '")
+              << inDimName.getValue() << "'";
+          return {};
+        }
+        basis.push_back(intValueAttr.getInt());
+      }
+      inDimBases.push_back(std::move(basis));
+    }
+    bases[inDimName] = std::move(inDimBases);
+  }
+  size_t rank = 0;
+  for (const auto &basesDim : llvm::make_second_range(bases)) {
+    if (!basesDim.empty()) {
+      rank = basesDim[0].size();
+      break;
+    }
+  }
+
+  // To implement this we'd need to serialise the rank as well.
+  // We can do this if we ever need it
+  if (rank == 0) {
+    parser.emitError(parser.getCurrentLocation(), "Empty Layout not supported");
+    return {};
+  }
+
+  // Generate standard outDimNames (dim0, dim1, ...)
+  SmallVector<StringAttr> outDimNames;
+  for (int i = 0; i < rank; ++i) {
+    outDimNames.push_back(
+        StringAttr::get(parser.getContext(), "dim" + llvm::Twine(i)));
+  }
+
+  // Create LinearLayout
+  LinearLayout linearLayout(std::move(bases), std::move(outDimNames));
+
+  // Create and return the LinearEncodingAttr
+  return parser.getChecked<LinearEncodingAttr>(parser.getContext(),
+                                               std::move(linearLayout));
+}
+
+SmallVector<unsigned> basesPerDim(const LinearLayout::BasesT &namedBases,
+                                  StringAttr dimName, size_t rank,
+                                  bool skipBroadcast = true) {
+  const auto &bases = namedBases.find(dimName)->second;
+
+  if (bases.empty()) {
+    return SmallVector<unsigned>(rank, 1);
+  }
+
+  SmallVector<unsigned> ret(rank, 1);
+  auto nonZero = [](auto val) { return val != 0; };
+  int nonZeroIdx = -1;
+  for (const auto &basis : bases) {
+    auto it = std::find_if(basis.begin(), basis.end(), nonZero);
+    // Bases can have one or zero non-zero elements
+    // Skip a basis if it's broadcasting (all zeros)
+    // e.g. warps for DotOperandEncodingAttr (see ampereDotToLinearLayout)
+    if (it != basis.end()) {
+      nonZeroIdx = it - basis.begin();
+      ret[nonZeroIdx] *= 2;
+    } else if (!skipBroadcast) {
+      // If we've seen a non-zero basis, we double the size of the previous dim
+      // This is just needed to count the CTAsPerCGA
+      assert(nonZeroIdx != -1);
+      ret[nonZeroIdx] *= 2;
+    }
+  }
+  return ret;
+}
+
+SmallVector<unsigned> basesPerDim(const LinearLayout &ll, StringAttr dimName,
+                                  bool skipBroadcast = true) {
+  auto shapeIter = ll.getOutDimSizes();
+  auto rank = std::distance(shapeIter.begin(), shapeIter.end());
+  return basesPerDim(ll.getBases(), dimName, rank, skipBroadcast);
+}
+
+SmallVector<unsigned> orderPerDim(const LinearLayout &ll, StringAttr dimName,
+                                  ArrayRef<unsigned> defaultOrder) {
+  const auto &bases = ll.getBases().find(dimName)->second;
+  llvm::SetVector<unsigned> order;
+  auto nonZero = [](auto val) { return val != 0; };
+  for (const auto &basis : bases) {
+    // Bases can have one or zero non-zero elements
+    // Skip a basis if it's broadcasting (all zeros)
+    // e.g. warps for DotOperandEncodingAttr (see ampereDotToLinearLayout)
+    auto it = std::find_if(basis.begin(), basis.end(), nonZero);
+    if (it != basis.end()) {
+      auto i = it - basis.begin();
+      order.insert(i);
+    }
+  }
+  // If any dim is missing, we add them in the defaultOrder
+  for (auto i : defaultOrder) {
+    order.insert(i);
+  }
+  return SmallVector<unsigned>(order.begin(), order.end());
+}
+
+// [Note. Divergence of methods wrt. legacy layouts]
+// For smaller shapes where the CTATile is larger than the output
+// tensor, some methods return different values than the legacy layouts. I think
+// this is benign, though. An example: what is the vector of `warpsPerCTA` if
+// all the warps hold the same data? I think it should be [1, 1], even if we
+// have 4 warps. But perhaps for this we have to add some masking in some
+// places... We'll see
+SmallVector<unsigned> LinearEncodingAttr::getRepOrder() const {
+  // This is not correct, but:
+  // - It happens to agree in most places with the legacy layout
+  // - getRepOrder does not make sense for LinearEncodingAttr as it already has
+  //   the same shape as the tensor that uses it
+  return getOrder();
+}
+SmallVector<unsigned> LinearEncodingAttr::getCTAsPerCGA() const {
+  // CTAs are split into an identity part (SplitNum) and a broadcast part
+  return basesPerDim(getLinearLayout(), StringAttr::get(getContext(), "block"),
+                     /*skipBroadcast=*/false);
+}
+SmallVector<unsigned> LinearEncodingAttr::getCTAOrder() const {
+  return orderPerDim(getLinearLayout(), StringAttr::get(getContext(), "block"),
+                     getOrder());
+}
+SmallVector<unsigned> LinearEncodingAttr::getCTASplitNum() const {
+  return basesPerDim(getLinearLayout(), StringAttr::get(getContext(), "block"));
+}
+SmallVector<unsigned> LinearEncodingAttr::getWarpsPerCTA() const {
+  return basesPerDim(getLinearLayout(), StringAttr::get(getContext(), "warp"));
+}
+SmallVector<unsigned> LinearEncodingAttr::getWarpOrder() const {
+  return orderPerDim(getLinearLayout(), StringAttr::get(getContext(), "warp"),
+                     getOrder());
+}
+SmallVector<unsigned> LinearEncodingAttr::getThreadsPerWarp() const {
+  return basesPerDim(getLinearLayout(), StringAttr::get(getContext(), "lane"));
+}
+SmallVector<unsigned> LinearEncodingAttr::getThreadOrder() const {
+  return orderPerDim(getLinearLayout(), StringAttr::get(getContext(), "lane"),
+                     getOrder());
+}
+SmallVector<unsigned> LinearEncodingAttr::getSizePerThread() const {
+  auto rank = getRepOrder().size();
+  auto ll = getLinearLayout();
+  auto ctx = getContext();
+  auto kRegister = StringAttr::get(ctx, "register");
+
+  // We canonicalize on the spot, as if we use CGAs the regs are not in
+  // canonical form. The order is [reg, lane, warp, rep, block], so we first
+  // remove the blocks
+  llvm::SmallVector<unsigned> ctaShape;
+  for (auto [shape, cgaNum] :
+       llvm::zip(ll.getOutDimSizes(), getCTASplitNum())) {
+    ctaShape.push_back(shape / cgaNum);
+  }
+  LinearLayout::BasesT bases = ll.getBases();
+
+  llvm::SetVector<unsigned> reverseRepOrder;
+  auto nonZero = [](auto val) { return val != 0; };
+  auto &registers = bases[StringAttr::get(ctx, "register")];
+  while (!registers.empty()) {
+    auto &basis = registers.back();
+    auto it = std::find_if(basis.begin(), basis.end(), nonZero);
+    // If there's broadcasting (base == zeros) there are no more reps
+    if (it == basis.end()) {
+      break;
+    }
+    auto dim = it - basis.begin();
+    reverseRepOrder.insert(dim);
+    // As soon as we stop finding reps, we stop
+    if (dim != reverseRepOrder.back() || 2 * basis[dim] != ctaShape[dim]) {
+      break;
+    }
+    ctaShape[dim] /= 2;
+    registers.pop_back();
+  }
+  return basesPerDim(bases, kRegister, rank);
+}
+
+SmallVector<unsigned> LinearEncodingAttr::getOrder() const {
+  auto rank = getLinearLayout().getNumOutDims();
+  SmallVector<unsigned> order(rank);
+  // Choose [rank-1, rank-2, ... 0] as the default order in case
+  // there are dims that do not move in the register
+  // This order is as good as any really
+  std::iota(order.rbegin(), order.rend(), 0);
+
+  return orderPerDim(getLinearLayout(),
+                     StringAttr::get(getContext(), "register"), order);
+}
+
+std::optional<LinearLayout>
+LinearEncodingAttr::toLinearLayout(ArrayRef<int64_t> shape) const {
+  auto ll = getLinearLayout();
+  auto canonicalDims = llvm::to_vector(ll.getOutDimNames());
+  llvm::SmallDenseMap<StringAttr, int64_t> namedShape;
+  llvm::SmallVector<StringAttr> permutedDims;
+  for (auto dim : getRepOrder()) {
+    permutedDims.push_back(canonicalDims[dim]);
+    namedShape[canonicalDims[dim]] = shape[dim];
+  }
+  ll = ll.transposeOuts(permutedDims);
+  ll = ensureLayoutNotSmallerThan(ll, namedShape);
+  ll = ensureLayoutNotLargerThan(ll, namedShape, /*broadcastRegisters=*/false);
+  ll = ll.transposeOuts(canonicalDims);
+  return ll;
+}
+
+SmallVector<unsigned>
+LinearEncodingAttr::getElemsPerThread(ArrayRef<int64_t> shape, Type) const {
+  // We can relax this assert by calling toLinearLayout rather than
+  // getLinearLayout
+  SmallVector<int32_t> shapeVec(shape.begin(), shape.end());
+  assert(shapeVec == llvm::to_vector(getLinearLayout().getOutDimSizes()));
+  auto ll = getLinearLayout();
+  return basesPerDim(ll, StringAttr::get(getContext(), "register"));
+}
+
+SmallVector<unsigned> LinearEncodingAttr::getContigPerThread() const {
+  auto ll = getLinearLayout();
+  const auto &regs =
+      ll.getBases().find(StringAttr::get(getContext(), "register"))->second;
+  auto order = getOrder();
+  auto rank = order.size();
+
+  SmallVector<unsigned> contig(rank, 1);
+  auto regIt = regs.begin();
+  for (unsigned dim : order) {
+    std::vector<int32_t> basis(rank, 0);
+    basis[dim] = 1;
+
+    while (regIt != regs.end() && *regIt == basis) {
+      contig[dim] *= 2;
+      basis[dim] *= 2;
+      ++regIt;
+    }
+  }
+  return contig;
+}
+
+unsigned LinearEncodingAttr::getTotalElemsPerThread(ArrayRef<int64_t> shape,
+                                                    Type eltTy) const {
+  return product(getElemsPerThread(shape, eltTy));
+}
+
 //===----------------------------------------------------------------------===//
 // MMA encoding
 //===----------------------------------------------------------------------===//
@@ -1987,6 +2440,9 @@ class TritonGPUOpAsmInterface : public OpAsmDialectInterface {
     } else if (auto blockedAttr = mlir::dyn_cast<BlockedEncodingAttr>(attr)) {
       os << "blocked";
       return AliasResult::FinalAlias;
+    } else if (auto linearAttr = mlir::dyn_cast<LinearEncodingAttr>(attr)) {
+      os << "linear";
+      return AliasResult::FinalAlias;
     } /* else if (auto sliceAttr = dyn_cast<SliceEncodingAttr>(attr)) {
       os << "slice";
       return AliasResult::FinalAlias;
diff --git a/lib/Dialect/TritonGPU/IR/LinearLayoutConversions.cpp b/lib/Dialect/TritonGPU/IR/LinearLayoutConversions.cpp
index 214707f7ea..aee7da8a75 100644
--- a/lib/Dialect/TritonGPU/IR/LinearLayoutConversions.cpp
+++ b/lib/Dialect/TritonGPU/IR/LinearLayoutConversions.cpp
@@ -52,25 +52,6 @@ SmallVector<StringAttr> permuteDimNames(const SmallVector<StringAttr> &names,
   return ret;
 }

-void assertIsRegisterLayout(const LinearLayout &layout) {
-  assert(layout.getNumInDims() > 0);
-  MLIRContext *ctx = layout.getInDimNames().begin()->getContext();
-  StringAttr kRegister = S("register");
-  StringAttr kLane = S("lane");
-  StringAttr kWarp = S("warp");
-  StringAttr kBlock = S("block");
-
-  const auto &ins = layout.getInDimNames();
-  assert(llvm::SmallVector<StringAttr>(ins.begin(), ins.end()) ==
-         llvm::SmallVector<StringAttr>({kRegister, kLane, kWarp, kBlock}));
-
-  const auto &outs = layout.getOutDimNames();
-  const auto &expectedOuts = standardOutDimNames(ctx, layout.getNumOutDims());
-  assert(llvm::SmallDenseSet<StringAttr>(outs.begin(), outs.end()) ==
-         llvm::SmallDenseSet<StringAttr>(expectedOuts.begin(),
-                                         expectedOuts.end()));
-}
-
 // Returns a 1D -> ND layout into [dim0, dim1, ...] that's equivalent to
 // creating a 1D -> 1D mapping of size product(shape) and then reshaping to
 // permute(shape, order).
@@ -121,124 +102,6 @@ LinearLayout makeCgaLayout(CTALayoutAttr layout) {
   return ret.transposeOuts(outDimNames);
 }

-// For each output dimension d, ensure that the layout's output size (i.e., its
-// codomain) does not exceed shape[d]. Do this without changing the size of the
-// layout's inputs (i.e., leave its domain unchanged).
-//
-// This function is invariant to the order of the layout's input and output
-// dimensions.
-//
-// We achieve this by setting the largest value in each output dimension d to 0
-// because bases that map to a location larger than shape[d]
-// effectively duplicate along that dimension. For example, consider a layout
-// with an output dimension size of 32, and we call ensureLayoutNotLargerThan to
-// shrink the output dimension size to 8:
-//
-// L(register=1) = 8
-// L(register=2) = 4
-// L(register=4) = 1
-// L(lane=1) = 2
-// L(lane=2) = 16
-//
-// In the first step, we shrink the output dimension size to 16 by setting
-// L(lane=2) to 0:
-//
-// L(register=1) = 8
-// L(register=2) = 4
-// L(register=4) = 1
-// L(lane=1) = 2
-// L(lane=2) = 0
-//
-// This means that lane=2 has the same data as lane=0.
-//
-// Now the output dimension of this layout has a size of 16, which is still
-// larger than 8. We find the current largest value in the output dimension,
-// which is L(register=1) = 8, and we set L(register=1) to 0:
-//
-// L(register=1) = 0
-// L(register=2) = 4
-// L(register=4) = 1
-// L(lane=1) = 2
-// L(lane=2) = 0
-//
-// Now the output dimension of this layout has a size of 8, which is the desired
-// size. Note that this method works only because the bases are powers of two.
-// It is unclear what to do when they are not.
-LinearLayout ensureLayoutNotLargerThan(
-    const LinearLayout &layout,
-    const llvm::SmallDenseMap<StringAttr, int64_t> &shape) {
-  assert(shape.size() == layout.getNumOutDims());
-  if (shape.empty()) {
-    return layout;
-  }
-  MLIRContext *ctx = shape.begin()->first.getContext();
-
-  auto bases = layout.getBases();
-  for (auto outDim : llvm::enumerate(layout.getOutDimNames())) {
-    auto outDimName = outDim.value();
-    int32_t actualSize = layout.getOutDimSize(outDimName);
-    int32_t desiredSize = shape.lookup(outDimName);
-    if (actualSize <= desiredSize) {
-      continue;
-    }
-    assert(actualSize % desiredSize == 0);
-    // <inDimName, basisIdx, outValue>
-    std::vector<std::tuple<StringAttr, size_t, int32_t>> sortedBases;
-    for (auto [inDimName, basis] : bases) {
-      for (size_t basisIdx = 0; basisIdx < basis.size(); basisIdx++) {
-        auto outValue = basis[basisIdx][outDim.index()];
-        if (outValue == 0) {
-          continue;
-        }
-        assert(llvm::isPowerOf2_32(outValue));
-        sortedBases.emplace_back(inDimName, basisIdx, outValue);
-      }
-    }
-    // From the largest basis to the smallest.
-    llvm::sort(sortedBases,
-               [](auto a, auto b) { return std::get<2>(a) > std::get<2>(b); });
-    for (auto [inDimName, basisIdx, outValue] : sortedBases) {
-      if (actualSize <= desiredSize) {
-        break;
-      }
-      bases[inDimName][basisIdx][outDim.index()] = 0;
-      actualSize >>= 1;
-    }
-  }
-  return LinearLayout(std::move(bases),
-                      llvm::to_vector(layout.getOutDimNames()));
-}
-
-// For each out-dim d, ensure the layout's out-size (i.e. its codomain) is no
-// smaller than shape[d]. Do this by increasing the size of the layout's inputs
-// along its most-minor dimension ("register" for register layouts, "offset" for
-// shared layouts).
-// -// This function is invariant to the order of the layout's input dimensions, but -// it cares about the order of the output dims, which should be minor-to-major. -LinearLayout ensureLayoutNotSmallerThan( - const LinearLayout &layout, - const llvm::SmallDenseMap &shape) { - assert(shape.size() == layout.getNumOutDims()); - if (shape.empty()) { - return layout; - } - - MLIRContext *ctx = shape.begin()->first.getContext(); - StringAttr kDim = *layout.getInDimNames().begin(); - assert(kDim == "register" || kDim == "offset"); - - LinearLayout ret = layout; - for (StringAttr outDimName : layout.getOutDimNames()) { - int32_t actualSize = layout.getOutDimSize(outDimName); - int32_t desiredSize = shape.lookup(outDimName); - assert(actualSize > desiredSize || desiredSize % actualSize == 0); - ret *= LinearLayout::identity1D(desiredSize / actualSize, kDim, outDimName); - assert(ret.getOutDimSize(outDimName) >= desiredSize); - } - return ret; -} - // Combines the layout of a CTA (input dims [register, lane, warp]) with the // layout of a CGA (i.e. a block), and ensures that the resulting layout has the // given shape. @@ -928,10 +791,10 @@ SliceEncodingAttr::toLinearLayout(ArrayRef shape) const { std::optional toLinearLayout(ArrayRef shape, Attribute layout, std::optional elemBitWidth /*= std::nullopt*/) { + // Layouts are distributed or shared if (auto distributed = dyn_cast(layout)) { return distributed.toLinearLayout(shape); - } - if (auto shared = dyn_cast(layout)) { + } else if (auto shared = dyn_cast(layout)) { if (shared.getHasLeadingOffset()) { assert(elemBitWidth.has_value()); return sharedToLinearLayoutLeadingOffset(shape, shared, *elemBitWidth); @@ -940,7 +803,7 @@ toLinearLayout(ArrayRef shape, Attribute layout, } } - // TODO(jlebar): Other layouts + // Third party layouts return std::nullopt; } diff --git a/lib/Dialect/TritonGPU/IR/Ops.cpp b/lib/Dialect/TritonGPU/IR/Ops.cpp index 233883964f..068965468e 100644 --- a/lib/Dialect/TritonGPU/IR/Ops.cpp +++ b/lib/Dialect/TritonGPU/IR/Ops.cpp @@ -53,15 +53,17 @@ LogicalResult UpcastMXFPOp::verify() { if (!dotEncoding) { return emitOpError("Expected a DotOperandEncodingAttr for values"); } - auto blockedScale = dyn_cast(layoutScale); - if (!blockedScale) { - return emitOpError("Expected a BlockOperandEncoding for scales"); + if (!isa(layoutScale)) { + return emitOpError( + "Expected a BlockOperandEncoding or LinearOperandEncoding " + "for scales"); } if (isa(dotEncoding.getParent())) { // Necessary to keep all of the scales of a given block of values in the // same warp - auto threadsPerWarp = blockedScale.getThreadsPerWarp(); + auto threadsPerWarp = + cast(layoutScale).getThreadsPerWarp(); if (threadsPerWarp != ArrayRef({16, 2})) { return emitOpError("Expected threads per warp to be {16, 2}"); } diff --git a/lib/Dialect/TritonGPU/Transforms/AccelerateMatmul.cpp b/lib/Dialect/TritonGPU/Transforms/AccelerateMatmul.cpp index 4dccc85da3..c07f314087 100644 --- a/lib/Dialect/TritonGPU/Transforms/AccelerateMatmul.cpp +++ b/lib/Dialect/TritonGPU/Transforms/AccelerateMatmul.cpp @@ -400,6 +400,7 @@ class DecomposeScaledBlocked "NYI: lhs supports fp4 or fp8"); assert(bType == ScaleDotElemType::E4M3 || bType == ScaleDotElemType::E5M2 || bType == ScaleDotElemType::BF16 && "NYI: rhs supports fp8 and bf16"); + bool isFp4 = aType == ScaleDotElemType::E2M1; auto mmaEnc = getMMAEncoding(rewriter, scaledDotOp); auto versionMajor = mmaEnc.getVersionMajor(); @@ -418,7 +419,7 @@ class DecomposeScaledBlocked // types auto aKWidth = mmaEnc.isHopper() ? 
2 : 8; auto bKWidth = mmaEnc.isHopper() ? 2 : 8; - if (aType == ScaleDotElemType::E2M1) { + if (isFp4) { // Load 2x4-bit elements per thread aKWidth /= 2; } @@ -438,9 +439,43 @@ class DecomposeScaledBlocked // Necessary choice to leave all the scales of the tile in that given warp auto threadsPerWarp = SmallVector{instrShapeM, warpSize / instrShapeM}; - auto newScaleEncoding = triton::gpu::BlockedEncodingAttr::get( + + assert(versionMajor == 2 && + "NYI: MMAv3. Need to rethink the scale layout otherwise"); + + // Copy the bases + + Attribute newScaleEncoding = triton::gpu::BlockedEncodingAttr::get( ctx, {1, 1}, threadsPerWarp, newAEncoding.getWarpsPerCTA(), newAEncoding.getCTAOrder(), mmaEnc.getCTALayout()); + + auto dotBroadcastsWarpLevel = mmaEnc.getWarpsPerCTA()[1] != 1; + if (dotBroadcastsWarpLevel) { + // If mma has warpsPerCTA == {2, 2}, then newAEncoding has + // warpsPerCTA == {2, 1}. In this case, we need to broadcast the warps + // on the second dimension as per + // A: 0 1 | 0 1 + // - - | - - + // 2 3 | 2 3 + // This broadcasting is not representable by standard blocked encodings, + // so we need to use linear layouts. + // This broadcasting is implemented in ampereDotToLinearLayout + auto blocked = cast(newScaleEncoding); + auto blockedLL = *blocked.toLinearLayout(a.getType().getShape()); + LinearLayout::BasesT scaleBases = blockedLL.getBases(); + auto nBases = llvm::Log2_32(mmaEnc.getWarpsPerCTA()[1]); + auto &warps = scaleBases[StringAttr::get(ctx, "warp")]; + // Prepend the vector of zeros to the warpBases + warps.insert(warps.begin(), nBases, std::vector(rank, 0)); + auto outDims = llvm::to_vector(blockedLL.getOutDimNames()); + auto newLL = LinearLayout(scaleBases, outDims); + auto llEncoding = LinearEncodingAttr::get(ctx, std::move(newLL)); + // Adjust the shape of the layout to match the scale operand + auto scaleShape = scale.getType().getShape(); + newScaleEncoding = + LinearEncodingAttr::get(ctx, *llEncoding.toLinearLayout(scaleShape)); + } + a = createArg(rewriter, a, 0, aType, newAEncoding, scale, newScaleEncoding); // Upcast B operand @@ -543,7 +578,8 @@ class DecomposeScaledBlocked auto dotOp = rewriter.create( scaledDotOp.getLoc(), scaledDotOp.getType(), a, b, scaledDotOp.getC()); - // FIXME Waiting on the following comment to be fixed: + // Waiting for https://github.com/triton-lang/triton/pull/5003 to land + // cf. 
// https://github.com/triton-lang/triton/pull/5003#issuecomment-2445091746 // int versionMajor = getMMAVersionSafe(computeCapability, dotOp); int versionMajor = 2; @@ -559,10 +595,8 @@ class DecomposeScaledBlocked versionMajor, retShapePerCTA, dotOp.getA().getType().getElementType(), numWarps); - // FIXME Waiting on supporting LLs on convert_layout - // auto warpsPerCTA = getWarpsPerTile(dotOp, retShapePerCTA, versionMajor, - // numWarps, instrShape); - SmallVector warpsPerCTA = {(unsigned)numWarps, 1}; + auto warpsPerCTA = getWarpsPerTile(dotOp, retShapePerCTA, versionMajor, + numWarps, instrShape); return NvidiaMmaEncodingAttr::get(ctx, versionMajor, versionMinor, warpsPerCTA, CTALayout, instrShape); } diff --git a/lib/Dialect/TritonGPU/Transforms/RemoveLayoutConversions.cpp b/lib/Dialect/TritonGPU/Transforms/RemoveLayoutConversions.cpp index bc4049dc30..70f5219111 100644 --- a/lib/Dialect/TritonGPU/Transforms/RemoveLayoutConversions.cpp +++ b/lib/Dialect/TritonGPU/Transforms/RemoveLayoutConversions.cpp @@ -970,7 +970,9 @@ void LayoutRematerialization::backwardRematerialization( // we don't handle conversions to DotOperandEncodingAttr // this is a heuristic to accommodate fused attention RankedTensorType targetType = convertOp.getType(); - if (isa(targetType.getEncoding())) + // We stop the rematerialization of linear layouts as we have to be a bit more + // careful with the heuristics for both correctness and perf + if (isa(targetType.getEncoding())) return; Value oldV = convertOp->getOperand(0); LDBG("check backward remat with source " << oldV << " encoding " @@ -1012,8 +1014,11 @@ void LayoutRematerialization::hoistConvertOnTopOfExtOrBroadcast( ConvertLayoutOp convertOp) { // we don't handle conversions to DotOperandEncodingAttr // this is a heuristics to accommodate fused attention + // We stop the rematerialization of linear layouts as we have to be a bit more + // careful with the heuristics for both correctness and perf RankedTensorType targetType = convertOp.getType(); - if (mlir::isa(targetType.getEncoding())) + if (mlir::isa( + targetType.getEncoding())) return; auto isExtOrBroadcastOp = [](Operation *op) { diff --git a/lib/Tools/LinearLayout.cpp b/lib/Tools/LinearLayout.cpp index 4319d1f086..3a81231ac8 100644 --- a/lib/Tools/LinearLayout.cpp +++ b/lib/Tools/LinearLayout.cpp @@ -212,42 +212,6 @@ void assertCommonDimsSameOrder(T &&aDims, U &&bDims) { "\nb: " + triton::join(bDims, ", ")); } } - -void eraseEmptyInOutDims(BasesT &bases, - llvm::MapVector &outDims) { - // Erase empty out-dims. - SmallVector emptyOutDims; - for (auto [i, outDim] : llvm::enumerate( - llvm::to_vector_of(llvm::make_first_range(outDims)))) { - if (outDims[outDim] == 1) { - emptyOutDims.push_back(i); - outDims.erase(outDim); - } - } - if (outDims.empty()) { - bases.clear(); - return; - } - - for (auto &[inDim, inDimBases] : bases) { - for (auto &basis : inDimBases) { - // Erase the basis elements corresponding to the empty out-dims. - for (int i : llvm::reverse(emptyOutDims)) { - basis.erase(basis.begin() + i); - } - } - } - - // Erase empty in-dims. - // TODO: This needs a test-case. 
- for (StringAttr inDim : - llvm::to_vector_of(llvm::make_first_range(bases))) { - if (bases[inDim].empty()) { - bases.erase(inDim); - } - } -} - } // anonymous namespace /*static*/ std::optional @@ -989,6 +953,30 @@ LinearLayout::getFreeVariableMasks() const { return ret; } +size_t hash_value(const LinearLayout &layout) { + size_t seed = 0; + + // Hash the bases + for (const auto &base : layout.getBases()) { + // Hash the input dimension name + seed = llvm::hash_combine(seed, base.first); + + // Hash the vectors in bases + for (const auto &vec : base.second) { + for (int32_t val : vec) { + seed = llvm::hash_combine(seed, val); + } + } + } + + // Hash the output dimensions and their sizes + for (const auto &outDim : layout.getOutDimNames()) { + seed = llvm::hash_combine(seed, outDim, layout.getOutDimSize(outDim)); + } + // Don't hash the surjective flag as it's a cached property + return seed; +} + bool operator==(LinearLayout lhs, LinearLayout rhs) { if (!lhs.equalIgnoringOutDimSizes(rhs)) return false; diff --git a/test/TritonGPU/accelerate-matmul.mlir b/test/TritonGPU/accelerate-matmul.mlir index 703d379bf5..648a29c34f 100644 --- a/test/TritonGPU/accelerate-matmul.mlir +++ b/test/TritonGPU/accelerate-matmul.mlir @@ -168,14 +168,16 @@ module attributes {"triton_gpu.target" = "cuda:90", "triton_gpu.num-ctas" = 1 : #blocked = #triton_gpu.blocked<{sizePerThread = [1, 1], threadsPerWarp = [1, 32], warpsPerCTA = [1, 4], order = [1, 0]}> #blocked1 = #triton_gpu.blocked<{sizePerThread = [1, 1], threadsPerWarp = [16, 2], warpsPerCTA = [4, 1], order = [1, 0]}> #blocked2 = #triton_gpu.blocked<{sizePerThread = [1, 1], threadsPerWarp = [1, 32], warpsPerCTA = [2, 2], order = [1, 0]}> +// CHECK: #[[LINEAR:.+]] = #triton_gpu.linear<{{.*}}> module attributes {"triton_gpu.target" = "cuda:90", "triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32, "triton_gpu.threads-per-warp" = 32 : i32} { - // CHECK-LABEL: dot_scaled + // CHECK: dot_scaled tt.func @dot_scaled( %a: tensor<128x32xi8, #blocked2>, %scale: tensor<128x2xi8, #blocked1>, %b: tensor<64x128xbf16, #blocked>) -> tensor<128x128xf32, #blocked> { - // CHECK: triton_gpu.upcast_mxfp + // CHECK: triton_gpu.convert_layout {{.*}} : tensor<128x2xi8, #blocked1> -> tensor<128x2xi8, #[[LINEAR]]> + // CHECK: triton_gpu.upcast_mxfp {{.*}}, {{.*}} fp_type = e2m1 : tensor<128x32xi8, #triton_gpu.dot_op<{{.*}}>>, tensor<128x2xi8, #[[LINEAR]]> -> tensor<128x64xbf16, #triton_gpu.dot_op<{{.*}}>> // CHECK: tt.dot %cst = arith.constant dense<0.000000e+00> : tensor<128x128xf32, #blocked> %result = tt.dot_scaled %a scale %scale, %b, %cst lhs = e2m1 rhs = bf16 : tensor<128x32xi8, #blocked2>, tensor<128x2xi8, #blocked1> * tensor<64x128xbf16, #blocked> -> tensor<128x128xf32, #blocked> diff --git a/test/TritonGPU/ops.mlir b/test/TritonGPU/ops.mlir index 9184a53120..70c1a315e7 100644 --- a/test/TritonGPU/ops.mlir +++ b/test/TritonGPU/ops.mlir @@ -33,3 +33,17 @@ module attributes {"triton_gpu.target" = "cuda:0", "triton_gpu.num-ctas" = 1 : i tt.return } } +// ----- + +#blocked= #triton_gpu.blocked<{sizePerThread = [1, 1], threadsPerWarp = [8, 4], warpsPerCTA = [4, 1], order = [1, 0]}> +// CHECK: #[[$LINEAR:.*]] = #triton_gpu.linear<{{.*}}> + +module attributes {"triton_gpu.target" = "cuda:0", "triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32, "triton_gpu.threads-per-warp" = 32 : i32} { + // CHECK-LABEL: @blocked_to_linear + tt.func @blocked_to_linear(%input: tensor<32x4xi8, #blocked>) { + // The layout is the basic layout generated by 
DecomposeScaledBlocked + %output = triton_gpu.convert_layout %input {allocation.offset = 0 : i32} : tensor<32x4xi8, #blocked> -> tensor<32x4xi8, #triton_gpu.linear<{register = [], lane = [[0, 1], [1, 0], [2, 0], [4, 0], [8, 0]], warp = [[0, 0], [16, 0]], block = []}>> + // CHECK: %{{.+}} = triton_gpu.convert_layout %{{.+}} : tensor<32x4xi8, #blocked> -> tensor<32x4xi8, #[[$LINEAR]]> + tt.return + } +} diff --git a/unittest/Dialect/TritonGPU/DialectTest.cpp b/unittest/Dialect/TritonGPU/DialectTest.cpp index c27c63335e..779bc1b788 100644 --- a/unittest/Dialect/TritonGPU/DialectTest.cpp +++ b/unittest/Dialect/TritonGPU/DialectTest.cpp @@ -620,7 +620,135 @@ TEST_F(AMDMfmaLayoutTest, mfma_dot_op) { ASSERT_THAT(tdot3dOp1.getWarpOrder(), tmfma3d.getWarpOrder()); } -} // anonymous namespace +class LinearEncodingTest : public ::testing::Test { +public: + LinearEncodingTest() { ctx.getOrLoadDialect(); } + +protected: + MLIRContext ctx; +}; + +TEST_F(LinearEncodingTest, DistributedEncodingToLinearEncoding) { + // Define a tensor shape + auto rank = 2; + SmallVector> shapes = {{64, 128}, {256, 1024}}; + SmallVector> orders = {{0, 1}, {1, 0}}; + SmallVector ctaLayouts = { + triton::gpu::CTALayoutAttr::getDefault(&ctx, rank), + triton::gpu::CTALayoutAttr::get(&ctx, {4, 2}, {2, 2}, {1, 0}), + }; + SmallVector distributedEncodings; + + // Create BlockedEncodingAttr and SliceEncodingAttr + { + SmallVector sizePerThread = {4, 4}; + SmallVector threadsPerWarp = {4, 8}; + SmallVector warpsPerCTA = {2, 2}; + + for (auto ctaLayout : ctaLayouts) { + for (const auto &order : orders) { + auto blockedEncoding = triton::gpu::BlockedEncodingAttr::get( + &ctx, sizePerThread, threadsPerWarp, warpsPerCTA, order, ctaLayout); + distributedEncodings.push_back(blockedEncoding); + distributedEncodings.push_back( + triton::gpu::SliceEncodingAttr::get(&ctx, 0, blockedEncoding)); + } + } + } + + // Create an MMAv2 and DotOperandEncodingAttr (MMAv3 doesn't support linear + // layouts yet) + { + unsigned versionMajor = 2; + unsigned versionMinor = 0; + SmallVector warpsPerCTA{4, 2}; + SmallVector instrShape{16, 8}; // Instruction shape (M, N) + auto mma = triton::gpu::NvidiaMmaEncodingAttr::get( + &ctx, versionMajor, versionMinor, warpsPerCTA, ctaLayouts[0], + instrShape); + distributedEncodings.push_back(mma); + // Create an opIdx=0 and opIdx=1 encoding + for (unsigned opIdx = 0; opIdx < 2; ++opIdx) { + distributedEncodings.push_back( + triton::gpu::DotOperandEncodingAttr::get(&ctx, opIdx, mma, 2)); + } + } + + for (const auto &distributedEncoding : distributedEncodings) { + for (auto shape : shapes) { + if (auto sliceEncoding = + dyn_cast(distributedEncoding)) { + shape.erase(shape.begin() + sliceEncoding.getDim()); + } + + // Create LinearEncodingAttr from the LinearLayout + auto linearLayout = *distributedEncoding.toLinearLayout(shape); + auto linearEncoding = + triton::gpu::LinearEncodingAttr::get(&ctx, linearLayout); + + // Test that the canonical form of the LinearLayout is indeed canonical + // by expanding it to the original shape + auto expandedLL = linearEncoding.toLinearLayout(shape); + ASSERT_EQ(linearLayout, expandedLL); + + // Test that methods of DistributedEncoding return the same values + Type eltTy = FloatType::getF32(&ctx); + + ASSERT_EQ(getOrder(distributedEncoding), linearEncoding.getRepOrder()); + ASSERT_EQ(cast(distributedEncoding) + .getTotalElemsPerThread(shape, eltTy), + linearEncoding.getTotalElemsPerThread(shape, eltTy)); + ASSERT_EQ(cast(distributedEncoding) + .getElemsPerThread(shape, eltTy), + 
linearEncoding.getElemsPerThread(shape, eltTy)); + ASSERT_EQ(distributedEncoding.getRepOrder(), + linearEncoding.getRepOrder()); + ASSERT_EQ(distributedEncoding.getContigPerThread(), + linearEncoding.getContigPerThread()); + // DotOperandEncodingAttr::getWarpOrder() is not defined + if (!isa(distributedEncoding)) { + ASSERT_EQ(distributedEncoding.getWarpOrder(), + linearEncoding.getWarpOrder()); + } + ASSERT_EQ(distributedEncoding.getThreadOrder(), + linearEncoding.getThreadOrder()); + // For slice these do not equal the total number of lines / warps + // See [Note. Divergence of methods wrt. legacy layouts] + if (!isa(distributedEncoding)) { + ASSERT_EQ(distributedEncoding.getWarpsPerCTA(), + linearEncoding.getWarpsPerCTA()); + ASSERT_EQ(distributedEncoding.getThreadsPerWarp(), + linearEncoding.getThreadsPerWarp()); + } + // Canonicalisation for opIdx=0 takes just a [2 x 2] subtile as it takes + // the second repetition along K as the second tile. + if (!isa(distributedEncoding)) { + // FIXME: This happens to be correct for SliceLayout because of the hack + // in SliceEncodingAttr::toLinearLayout(). We should remove the hack + // and the skips in the getWarpsPerCTA() and getThreadsPerWarp() + ASSERT_EQ(distributedEncoding.getSizePerThread(), + linearEncoding.getSizePerThread()); + } + + // block level + // SliceEncoding is not well-defined for CGAs + if (!isa(distributedEncoding)) { + ASSERT_EQ(distributedEncoding.getCTASplitNum(), + linearEncoding.getCTASplitNum()); + ASSERT_EQ(distributedEncoding.getCTAsPerCGA(), + linearEncoding.getCTAsPerCGA()); + // If we are not using CGAs, the order is meaningless + auto useCGA = distributedEncoding.getCTAsPerCGA() != + SmallVector(rank, 1); + if (useCGA) { + ASSERT_EQ(distributedEncoding.getCTAOrder(), + linearEncoding.getCTAOrder()); + } + } + } + } +} +} // namespace } // namespace mlir::triton::gpu int main(int argc, char *argv[]) { From ad28e6ca62c5fbc355aacade42c91078ec18ba9b Mon Sep 17 00:00:00 2001 From: peterbell10 Date: Thu, 21 Nov 2024 16:55:03 +0000 Subject: [PATCH 08/13] [CI] Run tests when CI is manually triggered (#5216) Currently you can manually call a workflow dispatch, but it won't actually run the tests because the variable enable_integration isn't set. 
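
With this change, a manually dispatched run actually executes the tests.
For reference, a dispatch can be triggered from the GitHub CLI (a sketch;
the branch name here is just an example):

```sh
gh workflow run integration-tests.yml --ref my-feature-branch
```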
--- .github/workflows/integration-tests.yml | 5 +++++ .github/workflows/integration-tests.yml.in | 6 ++++++ 2 files changed, 11 insertions(+) diff --git a/.github/workflows/integration-tests.yml b/.github/workflows/integration-tests.yml index 55863099d4..2922da501e 100644 --- a/.github/workflows/integration-tests.yml +++ b/.github/workflows/integration-tests.yml @@ -41,6 +41,11 @@ jobs: if: github.event_name == 'pull_request' run: | echo "enable_integration=true" >> $GITHUB_ENV + - name: Decide manual trigger integration test enablement + # Always enable integration tests when manually triggered + if: github.event_name == 'workflow_dispatch' + run: | + echo "enable_integration=true" >> $GITHUB_ENV - name: Checkout post-submit commits if: github.event_name == 'push' uses: actions/checkout@v4 diff --git a/.github/workflows/integration-tests.yml.in b/.github/workflows/integration-tests.yml.in index d4917816a4..7de7264272 100644 --- a/.github/workflows/integration-tests.yml.in +++ b/.github/workflows/integration-tests.yml.in @@ -45,6 +45,12 @@ jobs: run: | echo "enable_integration=true" >> $GITHUB_ENV + - name: Decide manual trigger integration test enablement + # Always enable integration tests when manually triggered + if: github.event_name == 'workflow_dispatch' + run: | + echo "enable_integration=true" >> $GITHUB_ENV + - name: Checkout post-submit commits if: github.event_name == 'push' uses: actions/checkout@v4 From 5bcbdc32641be0b0b76f5e834b02097c034cb02c Mon Sep 17 00:00:00 2001 From: Whitney Tsang Date: Thu, 21 Nov 2024 14:16:06 -0500 Subject: [PATCH 09/13] [DOC] Add a guide to update SYCL device library (#2788) Signed-off-by: Whitney Tsang --- docs/update_sycl_libdevice.md | 90 +++++++++++++++++++++++++++++++++++ 1 file changed, 90 insertions(+) create mode 100644 docs/update_sycl_libdevice.md diff --git a/docs/update_sycl_libdevice.md b/docs/update_sycl_libdevice.md new file mode 100644 index 0000000000..307b0c887a --- /dev/null +++ b/docs/update_sycl_libdevice.md @@ -0,0 +1,90 @@ +# Guide to Update SYCL Device Library + +This guide will walk you through the steps to update the SYCL device library using the Intel DPC++ compiler. + +## Step 1: Display Commands used during Compilation Process +1. Open a terminal. +2. Run the following command to compile a C++ file: +```sh +dpcpp -save-temps -#x t.cpp +``` +Replace t.cpp with any C++ file of your choice. This command will display the commands used during the compilation process. + +## Step 2: Locate the llvm-link Command +From the output of the previous command, find the llvm-link command line. 
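+A quick way to find it in the verbose output is to filter for the tool
+name (a convenience sketch; this assumes the driver prints the commands
+to stderr, as clang-style drivers typically do):
+```sh
+dpcpp -save-temps -#x t.cpp 2>&1 | grep llvm-link
+```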
It should look similar to the following example:
+```sh
+"/opt/intel/oneapi/compiler/2025.0/bin/compiler/llvm-link" \
+  -only-needed \
+  t-sycl-spir64-unknown-unknown-b331ea.bc \
+  /opt/intel/oneapi/compiler/2025.0/bin/compiler/../../lib/libsycl-crt.bc \
+  /opt/intel/oneapi/compiler/2025.0/bin/compiler/../../lib/libsycl-complex.bc \
+  /opt/intel/oneapi/compiler/2025.0/bin/compiler/../../lib/libsycl-complex-fp64.bc \
+  /opt/intel/oneapi/compiler/2025.0/bin/compiler/../../lib/libsycl-cmath.bc \
+  /opt/intel/oneapi/compiler/2025.0/bin/compiler/../../lib/libsycl-cmath-fp64.bc \
+  /opt/intel/oneapi/compiler/2025.0/bin/compiler/../../lib/libsycl-imf.bc \
+  /opt/intel/oneapi/compiler/2025.0/bin/compiler/../../lib/libsycl-imf-fp64.bc \
+  /opt/intel/oneapi/compiler/2025.0/bin/compiler/../../lib/libsycl-imf-bf16.bc \
+  /opt/intel/oneapi/compiler/2025.0/bin/compiler/../../lib/libsycl-fallback-cassert.bc \
+  /opt/intel/oneapi/compiler/2025.0/bin/compiler/../../lib/libsycl-fallback-cstring.bc \
+  /opt/intel/oneapi/compiler/2025.0/bin/compiler/../../lib/libsycl-fallback-complex.bc \
+  /opt/intel/oneapi/compiler/2025.0/bin/compiler/../../lib/libsycl-fallback-complex-fp64.bc \
+  /opt/intel/oneapi/compiler/2025.0/bin/compiler/../../lib/libsycl-fallback-cmath.bc \
+  /opt/intel/oneapi/compiler/2025.0/bin/compiler/../../lib/libsycl-fallback-cmath-fp64.bc \
+  /opt/intel/oneapi/compiler/2025.0/bin/compiler/../../lib/libsycl-fallback-imf.bc \
+  /opt/intel/oneapi/compiler/2025.0/bin/compiler/../../lib/libsycl-fallback-imf-fp64.bc \
+  /opt/intel/oneapi/compiler/2025.0/bin/compiler/../../lib/libsycl-fallback-imf-bf16.bc \
+  /opt/intel/oneapi/compiler/2025.0/bin/compiler/../../lib/libsycl-itt-user-wrappers.bc \
+  /opt/intel/oneapi/compiler/2025.0/bin/compiler/../../lib/libsycl-itt-compiler-wrappers.bc \
+  /opt/intel/oneapi/compiler/2025.0/bin/compiler/../../lib/libsycl-itt-stubs.bc \
+  -o \
+  t-sycl-spir64-unknown-unknown-d81f68.bc \
+  --suppress-warnings
+```
+
+## Step 3: Modify the llvm-link Command
+Remove the `-only-needed` option and the intermediate file `t-sycl-spir64-unknown-unknown-b331ea.bc` from the command line.
+Then modify the output file name to `libsycl-spir64-unknown-unknown.bc`.
+Dropping `-only-needed` is the key change: with the flag, llvm-link keeps only the symbols referenced by the sample file, whereas here we want the complete device library.
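+These edits can also be scripted; a minimal sketch, assuming the command
+was saved to a hypothetical file `cmd.txt` (the hash suffixes in the
+temporary file names will differ from run to run):
+```sh
+# Rename the output file, then drop the -only-needed flag and the
+# intermediate bitcode file; every removed line ends in '\', so the
+# line continuations stay valid.
+sed 's/t-sycl-spir64-unknown-unknown-d81f68\.bc/libsycl-spir64-unknown-unknown.bc/' cmd.txt \
+  | grep -v -e '-only-needed' -e 't-sycl-spir64-unknown-unknown-b331ea\.bc'
+```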
+The modified command should look like this:
+```sh
+"/opt/intel/oneapi/compiler/2025.0/bin/compiler/llvm-link" \
+  /opt/intel/oneapi/compiler/2025.0/bin/compiler/../../lib/libsycl-crt.bc \
+  /opt/intel/oneapi/compiler/2025.0/bin/compiler/../../lib/libsycl-complex.bc \
+  /opt/intel/oneapi/compiler/2025.0/bin/compiler/../../lib/libsycl-complex-fp64.bc \
+  /opt/intel/oneapi/compiler/2025.0/bin/compiler/../../lib/libsycl-cmath.bc \
+  /opt/intel/oneapi/compiler/2025.0/bin/compiler/../../lib/libsycl-cmath-fp64.bc \
+  /opt/intel/oneapi/compiler/2025.0/bin/compiler/../../lib/libsycl-imf.bc \
+  /opt/intel/oneapi/compiler/2025.0/bin/compiler/../../lib/libsycl-imf-fp64.bc \
+  /opt/intel/oneapi/compiler/2025.0/bin/compiler/../../lib/libsycl-imf-bf16.bc \
+  /opt/intel/oneapi/compiler/2025.0/bin/compiler/../../lib/libsycl-fallback-cassert.bc \
+  /opt/intel/oneapi/compiler/2025.0/bin/compiler/../../lib/libsycl-fallback-cstring.bc \
+  /opt/intel/oneapi/compiler/2025.0/bin/compiler/../../lib/libsycl-fallback-complex.bc \
+  /opt/intel/oneapi/compiler/2025.0/bin/compiler/../../lib/libsycl-fallback-complex-fp64.bc \
+  /opt/intel/oneapi/compiler/2025.0/bin/compiler/../../lib/libsycl-fallback-cmath.bc \
+  /opt/intel/oneapi/compiler/2025.0/bin/compiler/../../lib/libsycl-fallback-cmath-fp64.bc \
+  /opt/intel/oneapi/compiler/2025.0/bin/compiler/../../lib/libsycl-fallback-imf.bc \
+  /opt/intel/oneapi/compiler/2025.0/bin/compiler/../../lib/libsycl-fallback-imf-fp64.bc \
+  /opt/intel/oneapi/compiler/2025.0/bin/compiler/../../lib/libsycl-fallback-imf-bf16.bc \
+  /opt/intel/oneapi/compiler/2025.0/bin/compiler/../../lib/libsycl-itt-user-wrappers.bc \
+  /opt/intel/oneapi/compiler/2025.0/bin/compiler/../../lib/libsycl-itt-compiler-wrappers.bc \
+  /opt/intel/oneapi/compiler/2025.0/bin/compiler/../../lib/libsycl-itt-stubs.bc \
+  -o \
+  libsycl-spir64-unknown-unknown.bc \
+  --suppress-warnings
+```
+
+## Step 4: Execute the Modified Command
+Copy the modified llvm-link command, then paste and run it in the terminal.
+
+## Step 5: Check for Manual Changes
+Check the log of the existing device library to see what manual changes need to be made:
+```sh
+git log third_party/intel/backend/lib/libsycl-spir64-unknown-unknown.bc
+```
+Look for any specific changes mentioned in the commit messages. For example, from commit 0dd37fc92c46f35c6ced34801e51058b6b89ea47, you need to change one of the module metadata values from 4 to 3.
+
+## Step 6: Apply Manual Changes
+Use `llvm-dis` to disassemble the bitcode library; then, based on the information from the git log, apply the necessary manual changes to the updated device library.
+Reassemble the modified LLVM IR device library using `llvm-as`.
+
+By following these steps, you will have successfully updated the SYCL device library and applied any necessary manual changes.

From e9db1862b80633eaa4f8a61366fec16248eb2cb5 Mon Sep 17 00:00:00 2001
From: Yuanwei Fang
Date: Thu, 21 Nov 2024 12:01:45 -0800
Subject: [PATCH 10/13] [PROTON] Introduce the Proton dialect as a third-party
 plugin for intra-kernel perf tooling (#5119)

This PR introduces the `Proton Dialect` to enable intra-kernel profiling
and tooling for Triton. As a third-party dialect, it provides the
building blocks to create third-party perf tools (e.g., profilers,
analysis, modeling) for Triton compiler developers in a compiler-centric
way, such as an intra-kernel latency profiler to understand software
pipelining, warp specialization, and CTA fine-grained orchestration
(e.g., cuda core, tensor core, TMA).
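
As a taste of the new op (the snippet below is taken from the examples
in `ProtonOps.td` in this patch), a region of interest is bracketed by a
start/end pair of records:

```mlir
proton.record() {isStart = true, regionId = 4 : i32}
// ... code to be profiled ...
proton.record() {isStart = false, regionId = 4 : i32}
```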
Future developments would integrate this dialect with the existing Proton backend profiling infrastructure to make it a powerful and general perf tool utility. As a first step, this PR adds some basic boilerplate code and mechanics, and the `proton.record` op for the `Proton Dialect`. --------- Co-authored-by: Yuanwei Fang Co-authored-by: Keren Zhou --- CMakeLists.txt | 4 ++ bin/RegisterTritonDialects.h | 18 ++--- test/Proton/ops.mlir | 15 +++++ third_party/proton/dialect/CMakeLists.txt | 7 ++ .../proton/dialect/include/CMakeLists.txt | 1 + .../dialect/include/Dialect/CMakeLists.txt | 1 + .../include/Dialect/Proton/CMakeLists.txt | 1 + .../include/Dialect/Proton/IR/CMakeLists.txt | 18 +++++ .../include/Dialect/Proton/IR/Dialect.h | 23 +++++++ .../Dialect/Proton/IR/ProtonAttrDefs.td | 12 ++++ .../Dialect/Proton/IR/ProtonDialect.td | 18 +++++ .../include/Dialect/Proton/IR/ProtonOps.td | 65 +++++++++++++++++++ third_party/proton/dialect/lib/CMakeLists.txt | 1 + .../proton/dialect/lib/Dialect/CMakeLists.txt | 1 + .../dialect/lib/Dialect/Proton/CMakeLists.txt | 1 + .../lib/Dialect/Proton/IR/CMakeLists.txt | 13 ++++ .../dialect/lib/Dialect/Proton/IR/Dialect.cpp | 25 +++++++ .../dialect/lib/Dialect/Proton/IR/Ops.cpp | 33 ++++++++++ third_party/proton/dialect/triton_proton.cc | 20 ++++++ 19 files changed, 269 insertions(+), 8 deletions(-) create mode 100644 test/Proton/ops.mlir create mode 100644 third_party/proton/dialect/CMakeLists.txt create mode 100644 third_party/proton/dialect/include/CMakeLists.txt create mode 100644 third_party/proton/dialect/include/Dialect/CMakeLists.txt create mode 100644 third_party/proton/dialect/include/Dialect/Proton/CMakeLists.txt create mode 100644 third_party/proton/dialect/include/Dialect/Proton/IR/CMakeLists.txt create mode 100644 third_party/proton/dialect/include/Dialect/Proton/IR/Dialect.h create mode 100644 third_party/proton/dialect/include/Dialect/Proton/IR/ProtonAttrDefs.td create mode 100644 third_party/proton/dialect/include/Dialect/Proton/IR/ProtonDialect.td create mode 100644 third_party/proton/dialect/include/Dialect/Proton/IR/ProtonOps.td create mode 100644 third_party/proton/dialect/lib/CMakeLists.txt create mode 100644 third_party/proton/dialect/lib/Dialect/CMakeLists.txt create mode 100644 third_party/proton/dialect/lib/Dialect/Proton/CMakeLists.txt create mode 100644 third_party/proton/dialect/lib/Dialect/Proton/IR/CMakeLists.txt create mode 100644 third_party/proton/dialect/lib/Dialect/Proton/IR/Dialect.cpp create mode 100644 third_party/proton/dialect/lib/Dialect/Proton/IR/Ops.cpp create mode 100644 third_party/proton/dialect/triton_proton.cc diff --git a/CMakeLists.txt b/CMakeLists.txt index 56564c3896..c5aa40499e 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -206,6 +206,9 @@ if(TRITON_BUILD_PYTHON_MODULE) if (TRITON_BUILD_PROTON) add_subdirectory(third_party/proton) endif() + # We always build proton dialect + list(APPEND TRITON_PLUGIN_NAMES "proton") + add_subdirectory(third_party/proton/dialect) get_property(triton_libs GLOBAL PROPERTY TRITON_LIBS) get_property(triton_plugins GLOBAL PROPERTY TRITON_PLUGINS) @@ -311,6 +314,7 @@ if(NOT TRITON_BUILD_PYTHON_MODULE) foreach(CODEGEN_BACKEND ${TRITON_CODEGEN_BACKENDS}) add_subdirectory(third_party/${CODEGEN_BACKEND}) endforeach() + add_subdirectory(third_party/proton/dialect) endif() add_subdirectory(third_party/f2reduce) diff --git a/bin/RegisterTritonDialects.h b/bin/RegisterTritonDialects.h index c69e46792c..71d75b35db 100644 --- a/bin/RegisterTritonDialects.h +++ 
b/bin/RegisterTritonDialects.h @@ -2,6 +2,7 @@ #include "amd/include/Dialect/TritonAMDGPU/IR/Dialect.h" #include "amd/include/TritonAMDGPUTransforms/Passes.h" #include "third_party/nvidia/include/Dialect/NVGPU/IR/Dialect.h" +#include "third_party/proton/dialect/include/Dialect/Proton/IR/Dialect.h" #include "triton/Dialect/Triton/IR/Dialect.h" #include "triton/Dialect/TritonGPU/IR/Dialect.h" #include "triton/Dialect/TritonNvidiaGPU/IR/Dialect.h" @@ -68,12 +69,13 @@ inline void registerTritonDialects(mlir::DialectRegistry ®istry) { mlir::triton::registerTritonAMDGPULowerInstructionSchedHints(); // TODO: register Triton & TritonGPU passes - registry.insert(); + registry + .insert(); } diff --git a/test/Proton/ops.mlir b/test/Proton/ops.mlir new file mode 100644 index 0000000000..22a17e3f0f --- /dev/null +++ b/test/Proton/ops.mlir @@ -0,0 +1,15 @@ +// RUN: triton-opt --split-input-file %s -cse -canonicalize | FileCheck %s + +module { + // CHECK-LABEL: proton_record + tt.func @proton_record() { + // CHECK: proton.record() {isStart = true, regionId = 1 : i32} + // CHECK-NEXT: proton.record() {isStart = false, regionId = 1 : i32} + // CHECK-NEXT: tt.return + proton.record() {isStart = true, regionId = 1 : i32} + proton.record() {isStart = false, regionId = 1 : i32} + tt.return + } +} // end module + +// ----- diff --git a/third_party/proton/dialect/CMakeLists.txt b/third_party/proton/dialect/CMakeLists.txt new file mode 100644 index 0000000000..c7b5413a0e --- /dev/null +++ b/third_party/proton/dialect/CMakeLists.txt @@ -0,0 +1,7 @@ +include_directories(${CMAKE_CURRENT_SOURCE_DIR}/include) +include_directories(${CMAKE_CURRENT_BINARY_DIR}/include) +add_subdirectory(include) +add_subdirectory(lib) +if(TRITON_BUILD_PYTHON_MODULE) + add_triton_plugin(TritonProton ${CMAKE_CURRENT_SOURCE_DIR}/triton_proton.cc LINK_LIBS ProtonIR) +endif() diff --git a/third_party/proton/dialect/include/CMakeLists.txt b/third_party/proton/dialect/include/CMakeLists.txt new file mode 100644 index 0000000000..0ca0f41c5a --- /dev/null +++ b/third_party/proton/dialect/include/CMakeLists.txt @@ -0,0 +1 @@ +add_subdirectory(Dialect) diff --git a/third_party/proton/dialect/include/Dialect/CMakeLists.txt b/third_party/proton/dialect/include/Dialect/CMakeLists.txt new file mode 100644 index 0000000000..f18c30ba1a --- /dev/null +++ b/third_party/proton/dialect/include/Dialect/CMakeLists.txt @@ -0,0 +1 @@ +add_subdirectory(Proton) diff --git a/third_party/proton/dialect/include/Dialect/Proton/CMakeLists.txt b/third_party/proton/dialect/include/Dialect/Proton/CMakeLists.txt new file mode 100644 index 0000000000..f33061b2d8 --- /dev/null +++ b/third_party/proton/dialect/include/Dialect/Proton/CMakeLists.txt @@ -0,0 +1 @@ +add_subdirectory(IR) diff --git a/third_party/proton/dialect/include/Dialect/Proton/IR/CMakeLists.txt b/third_party/proton/dialect/include/Dialect/Proton/IR/CMakeLists.txt new file mode 100644 index 0000000000..4645b0ebcd --- /dev/null +++ b/third_party/proton/dialect/include/Dialect/Proton/IR/CMakeLists.txt @@ -0,0 +1,18 @@ +set(MLIR_BINARY_DIR ${CMAKE_BINARY_DIR}) + +set(LLVM_TARGET_DEFINITIONS ProtonOps.td) +mlir_tablegen(Dialect.h.inc -gen-dialect-decls -dialect=proton) +mlir_tablegen(Dialect.cpp.inc -gen-dialect-defs -dialect=proton) +mlir_tablegen(OpsConversions.inc -gen-llvmir-conversions) +mlir_tablegen(Ops.h.inc -gen-op-decls) +mlir_tablegen(Ops.cpp.inc -gen-op-defs) +mlir_tablegen(OpsEnums.h.inc -gen-enum-decls) +mlir_tablegen(OpsEnums.cpp.inc -gen-enum-defs) +add_mlir_doc(ProtonDialect ProtonDialect 
dialects/ -gen-dialect-doc) +add_mlir_doc(ProtonOps ProtonOps dialects/ -gen-op-doc) +add_public_tablegen_target(ProtonTableGen) + +set(LLVM_TARGET_DEFINITIONS ProtonAttrDefs.td) +mlir_tablegen(ProtonAttrDefs.h.inc -gen-attrdef-decls) +mlir_tablegen(ProtonAttrDefs.cpp.inc -gen-attrdef-defs) +add_public_tablegen_target(ProtonAttrDefsIncGen) diff --git a/third_party/proton/dialect/include/Dialect/Proton/IR/Dialect.h b/third_party/proton/dialect/include/Dialect/Proton/IR/Dialect.h new file mode 100644 index 0000000000..680a205f08 --- /dev/null +++ b/third_party/proton/dialect/include/Dialect/Proton/IR/Dialect.h @@ -0,0 +1,23 @@ +#ifndef TRITON_DIALECT_PROTON_IR_DIALECT_H_ +#define TRITON_DIALECT_PROTON_IR_DIALECT_H_ + +#include "mlir/Dialect/LLVMIR/LLVMDialect.h" +#include "mlir/IR/BuiltinOps.h" +#include "mlir/IR/Dialect.h" +#include "mlir/IR/PatternMatch.h" +#include "proton/dialect/include/Dialect/Proton/IR/Dialect.h.inc" +#include "proton/dialect/include/Dialect/Proton/IR/OpsEnums.h.inc" + +#define GET_ATTRDEF_CLASSES +#include "proton/dialect/include/Dialect/Proton/IR/ProtonAttrDefs.h.inc" + +#define GET_OP_CLASSES +#include "proton/dialect/include/Dialect/Proton/IR/Ops.h.inc" + +namespace mlir { +namespace triton { +namespace proton {} // namespace proton +} // namespace triton +} // namespace mlir + +#endif // TRITON_DIALECT_PROTON_IR_DIALECT_H_ diff --git a/third_party/proton/dialect/include/Dialect/Proton/IR/ProtonAttrDefs.td b/third_party/proton/dialect/include/Dialect/Proton/IR/ProtonAttrDefs.td new file mode 100644 index 0000000000..d469fbb35f --- /dev/null +++ b/third_party/proton/dialect/include/Dialect/Proton/IR/ProtonAttrDefs.td @@ -0,0 +1,12 @@ +#ifndef PROTON_ATTRDEFS +#define PROTON_ATTRDEFS + +include "mlir/IR/AttrTypeBase.td" +include "ProtonDialect.td" + +class Proton_Attr traits = [], + string baseCppClass = "::mlir::Attribute"> + : AttrDef { +} + +#endif // PROTON_ATTRDEFS diff --git a/third_party/proton/dialect/include/Dialect/Proton/IR/ProtonDialect.td b/third_party/proton/dialect/include/Dialect/Proton/IR/ProtonDialect.td new file mode 100644 index 0000000000..245f2e09a2 --- /dev/null +++ b/third_party/proton/dialect/include/Dialect/Proton/IR/ProtonDialect.td @@ -0,0 +1,18 @@ +#ifndef PROTON_DIALECT +#define PROTON_DIALECT + +include "mlir/IR/OpBase.td" + +def Proton_Dialect : Dialect { + let name = "proton"; + let cppNamespace = "::mlir::triton::proton"; + + let description = [{ + Proton Dialect provides core ops for building third-party compiler-based + performance profiling and analysis tools. + }]; + + let dependentDialects = []; +} + +#endif diff --git a/third_party/proton/dialect/include/Dialect/Proton/IR/ProtonOps.td b/third_party/proton/dialect/include/Dialect/Proton/IR/ProtonOps.td new file mode 100644 index 0000000000..d18a48d5d1 --- /dev/null +++ b/third_party/proton/dialect/include/Dialect/Proton/IR/ProtonOps.td @@ -0,0 +1,65 @@ +#ifndef PROTON_OPS +#define PROTON_OPS + +include "mlir/IR/OpBase.td" +include "mlir/IR/EnumAttr.td" +include "triton/Dialect/Triton/IR/TritonTypes.td" +include "mlir/Dialect/LLVMIR/LLVMOpBase.td" +include "mlir/Interfaces/InferTypeOpInterface.td" +include "mlir/Interfaces/SideEffectInterfaces.td" +include "triton/Dialect/Triton/IR/TritonInterfaces.td" +include "ProtonDialect.td" +include "ProtonAttrDefs.td" + +class TT_Proton_Op traits = []> : + Op { +} + +// Proton profiling metric. 
+def MetricAttr : I32EnumAttr< + "Metric", "", + [ + I32EnumAttrCase<"CYCLE", 0, "cycle">, + ]> { + let cppNamespace = "::mlir::triton::proton"; +} + +// Proton profiling granularity. +def GranularityAttr : I32EnumAttr< + "Granularity", "", + [ + I32EnumAttrCase<"WARPGROUP", 0, "warpgroup">, + I32EnumAttrCase<"WARP", 1, "warp">, + ]> { + let cppNamespace = "::mlir::triton::proton"; +} + +def TT_RecordOp : TT_Proton_Op<"record", [DeclareOpInterfaceMethods]> { + let summary = "Record a GPU hardware event"; + + let description = [{ + The operator records GPU events from performance counters. + Currently only cycle counter is supported. + + Example: + + ```mlir + proton.record() {isStart = true, regionId = 4 : i32} + ... + proton.record() {isStart = false, regionId = 4 : i32} + ... + proton.record() {isStart = true, regionId = 1 : i32, granularity = 1 : i32} + ... + proton.record() {isStart = false, regionId = 1 : i32, granularity = 1 : i32} + ``` + }]; + let arguments = ( + ins BoolAttr: $isStart, + ConfinedAttr:$regionId, + DefaultValuedAttr:$metric, + DefaultValuedAttr:$granularity + ); + let assemblyFormat = " `(` operands `)` attr-dict"; +} + +#endif // PROTON_OPS diff --git a/third_party/proton/dialect/lib/CMakeLists.txt b/third_party/proton/dialect/lib/CMakeLists.txt new file mode 100644 index 0000000000..0ca0f41c5a --- /dev/null +++ b/third_party/proton/dialect/lib/CMakeLists.txt @@ -0,0 +1 @@ +add_subdirectory(Dialect) diff --git a/third_party/proton/dialect/lib/Dialect/CMakeLists.txt b/third_party/proton/dialect/lib/Dialect/CMakeLists.txt new file mode 100644 index 0000000000..f18c30ba1a --- /dev/null +++ b/third_party/proton/dialect/lib/Dialect/CMakeLists.txt @@ -0,0 +1 @@ +add_subdirectory(Proton) diff --git a/third_party/proton/dialect/lib/Dialect/Proton/CMakeLists.txt b/third_party/proton/dialect/lib/Dialect/Proton/CMakeLists.txt new file mode 100644 index 0000000000..f33061b2d8 --- /dev/null +++ b/third_party/proton/dialect/lib/Dialect/Proton/CMakeLists.txt @@ -0,0 +1 @@ +add_subdirectory(IR) diff --git a/third_party/proton/dialect/lib/Dialect/Proton/IR/CMakeLists.txt b/third_party/proton/dialect/lib/Dialect/Proton/IR/CMakeLists.txt new file mode 100644 index 0000000000..5eea5cb3cf --- /dev/null +++ b/third_party/proton/dialect/lib/Dialect/Proton/IR/CMakeLists.txt @@ -0,0 +1,13 @@ +add_triton_library(ProtonIR + Dialect.cpp + Ops.cpp + + DEPENDS + ProtonTableGen + ProtonAttrDefsIncGen + + LINK_LIBS PUBLIC + MLIRLLVMDialect + TritonIR + TritonGPUIR +) diff --git a/third_party/proton/dialect/lib/Dialect/Proton/IR/Dialect.cpp b/third_party/proton/dialect/lib/Dialect/Proton/IR/Dialect.cpp new file mode 100644 index 0000000000..60c2852654 --- /dev/null +++ b/third_party/proton/dialect/lib/Dialect/Proton/IR/Dialect.cpp @@ -0,0 +1,25 @@ +#include "mlir/IR/DialectImplementation.h" +#include "mlir/IR/OpImplementation.h" + +// clang-format off +#include "Dialect/Proton/IR/Dialect.h" +#include "Dialect/Proton/IR/Dialect.cpp.inc" +// clang-format on + +using namespace mlir; +using namespace mlir::triton::proton; + +void mlir::triton::proton::ProtonDialect::initialize() { + addAttributes< +#define GET_ATTRDEF_LIST +#include "Dialect/Proton/IR/ProtonAttrDefs.cpp.inc" + >(); + + addOperations< +#define GET_OP_LIST +#include "Dialect/Proton/IR/Ops.cpp.inc" + >(); +} + +#define GET_ATTRDEF_CLASSES +#include "Dialect/Proton/IR/ProtonAttrDefs.cpp.inc" diff --git a/third_party/proton/dialect/lib/Dialect/Proton/IR/Ops.cpp b/third_party/proton/dialect/lib/Dialect/Proton/IR/Ops.cpp new file mode 100644 
index 0000000000..1a0799aea1 --- /dev/null +++ b/third_party/proton/dialect/lib/Dialect/Proton/IR/Ops.cpp @@ -0,0 +1,33 @@ +#include "Dialect/Proton/IR/Dialect.h" +#include "mlir/IR/Builders.h" +#include "mlir/IR/BuiltinAttributes.h" +#include "mlir/IR/BuiltinTypes.h" +#include "mlir/IR/OperationSupport.h" +#include "mlir/Interfaces/FunctionImplementation.h" +#include "mlir/Interfaces/FunctionInterfaces.h" +#include "mlir/Support/LLVM.h" +#include "triton/Dialect/Triton/IR/Dialect.h" +#include "triton/Dialect/Triton/IR/Types.h" +#include "triton/Dialect/Triton/IR/Utility.h" + +#define GET_OP_CLASSES +#include "Dialect/Proton/IR/Ops.cpp.inc" +#include "Dialect/Proton/IR/OpsEnums.cpp.inc" + +namespace mlir { +namespace triton { +namespace proton { + +// -- RecordOp -- +void RecordOp::getEffects( + SmallVectorImpl> + &effects) { + effects.emplace_back(MemoryEffects::Write::get(), + SideEffects::DefaultResource::get()); + effects.emplace_back(MemoryEffects::Read::get(), + SideEffects::DefaultResource::get()); +} + +} // namespace proton +} // namespace triton +} // namespace mlir diff --git a/third_party/proton/dialect/triton_proton.cc b/third_party/proton/dialect/triton_proton.cc new file mode 100644 index 0000000000..8046539794 --- /dev/null +++ b/third_party/proton/dialect/triton_proton.cc @@ -0,0 +1,20 @@ +#include "Dialect/Proton/IR/Dialect.h" +#include "mlir/Pass/PassManager.h" +#include "passes.h" +#include +#include +#include + +namespace py = pybind11; + +void init_triton_proton(py::module &&m) { + auto passes = m.def_submodule("passes"); + + // load dialects + m.def("load_dialects", [](mlir::MLIRContext &context) { + mlir::DialectRegistry registry; + registry.insert(); + context.appendDialectRegistry(registry); + context.loadAllAvailableDialects(); + }); +} From 7b5daa47fd38226587c1e6be35d6e6e371074735 Mon Sep 17 00:00:00 2001 From: Whitney Tsang Date: Thu, 21 Nov 2024 21:42:49 +0000 Subject: [PATCH 11/13] Revert "[LAYOUTS] Implement IR support for LinearLayouts (#5170)" This reverts commit de1f346aa6737fa2e3e6a8a64dae118fcfab9995. --- include/triton/Dialect/TritonGPU/IR/Dialect.h | 65 --- .../Dialect/TritonGPU/IR/TritonGPUAttrDefs.td | 30 +- include/triton/Tools/LinearLayout.h | 3 - .../TritonGPUToLLVM/ConvertLayoutOpToLLVM.cpp | 3 - .../TritonGPUToLLVM/MemoryOpToLLVM.cpp | 5 +- lib/Dialect/TritonGPU/IR/Dialect.cpp | 456 ------------------ .../TritonGPU/IR/LinearLayoutConversions.cpp | 143 +++++- lib/Dialect/TritonGPU/IR/Ops.cpp | 10 +- .../TritonGPU/Transforms/AccelerateMatmul.cpp | 48 +- .../Transforms/RemoveLayoutConversions.cpp | 9 +- lib/Tools/LinearLayout.cpp | 60 ++- test/TritonGPU/accelerate-matmul.mlir | 6 +- test/TritonGPU/ops.mlir | 14 - unittest/Dialect/TritonGPU/DialectTest.cpp | 130 +---- 14 files changed, 197 insertions(+), 785 deletions(-) diff --git a/include/triton/Dialect/TritonGPU/IR/Dialect.h b/include/triton/Dialect/TritonGPU/IR/Dialect.h index 7d5d801d57..a9b49448c1 100644 --- a/include/triton/Dialect/TritonGPU/IR/Dialect.h +++ b/include/triton/Dialect/TritonGPU/IR/Dialect.h @@ -151,71 +151,6 @@ triton::gpu::BlockedEncodingAttr getDefaultBlockedEncoding(MLIRContext *context, ArrayRef shape, int numWarps, int threadsPerWarp, int numCTAs); -// For each output dimension d, ensure that the layout's output size (i.e., its -// codomain) does not exceed shape[d]. Do this without changing the size of the -// layout's inputs (i.e., leave its domain unchanged). -// -// This function is invariant to the order of the layout's input and output -// dimensions. 
-// -// We achieve this by setting the largest value in each output dimension d to 0 -// because bases that map to a location larger than shape[d] -// effectively duplicate along that dimension. For example, consider a layout -// with an output dimension size of 32, and we call ensureLayoutNotLargerThan to -// shrink the output dimension size to 8: -// -// L(register=1) = 8 -// L(register=2) = 4 -// L(register=4) = 1 -// L(lane=1) = 2 -// L(lane=2) = 16 -// -// In the first step, we shrink the output dimension size to 16 by setting -// L(lane=2) to 0: -// -// L(register=1) = 8 -// L(register=2) = 4 -// L(register=4) = 1 -// L(lane=1) = 2 -// L(lane=2) = 0 -// -// This means that lane=2 has the same data as lane=0. -// -// Now the output dimension of this layout has a size of 16, which is still -// larger than 8. We find the current largest value in the output dimension, -// which is L(register=1) = 8, and we set L(register=1) to 0: -// -// L(register=1) = 0 -// L(register=2) = 4 -// L(register=4) = 1 -// L(lane=1) = 2 -// L(lane=2) = 0 -// -// Now the output dimension of this layout has a size of 8, which is the desired -// size. Note that this method works only because the bases are powers of two, -// which is the case for DistributedLayouts If broadcastRegisters is false, we -// remove any register that's larger than the desired shape. In the example -// above we would have -// L(register=1) = 4 -// L(register=2) = 1 -// L(lane=1) = 2 -// L(lane=2) = 0 -LinearLayout -ensureLayoutNotLargerThan(const LinearLayout &layout, - const llvm::SmallDenseMap &shape, - bool broadcastRegisters = true); - -// For each out-dim d, ensure the layout's out-size (i.e. its codomain) is no -// smaller than shape[d]. Do this by increasing the size of the layout's inputs -// along its most-minor dimension ("register" for register layouts, "offset" for -// shared layouts). -// -// This function is invariant to the order of the layout's input dimensions, but -// it cares about the order of the output dims, which should be minor-to-major. -LinearLayout ensureLayoutNotSmallerThan( - const LinearLayout &layout, - const llvm::SmallDenseMap &shape); - // Dump information about which threads/registers contain each of the tensor // elements. 
void dumpLayout(RankedTensorType tensorType); diff --git a/include/triton/Dialect/TritonGPU/IR/TritonGPUAttrDefs.td b/include/triton/Dialect/TritonGPU/IR/TritonGPUAttrDefs.td index a5f2ed8dc7..93723e2282 100644 --- a/include/triton/Dialect/TritonGPU/IR/TritonGPUAttrDefs.td +++ b/include/triton/Dialect/TritonGPU/IR/TritonGPUAttrDefs.td @@ -56,6 +56,7 @@ Right now, Triton implements two main classes of layouts: shared, and distribute code extraBaseClassDeclaration = [{ unsigned getTotalElemsPerThread(ArrayRef shape, Type eltTy) const; SmallVector getElemsPerThread(ArrayRef shape, Type eltTy) const; + ::mlir::LogicalResult verifyLayoutForArg(::mlir::Operation* op, unsigned argNo) const; }]; } @@ -146,6 +147,7 @@ addition, if there's only one CTA per CGA, then Triton canonicalizes CTAOrder to let genVerifyDecl = 1; let skipDefaultBuilders = 1; } + //===----------------------------------------------------------------------===// // Shared Layout Encoding //===----------------------------------------------------------------------===// @@ -569,34 +571,6 @@ L(T) = [ {0,8} , {1,9} , {2,10}, {3,11}, {0,8} , {1, 9} , {2, 10}, {3, 11}, }]; } -//===----------------------------------------------------------------------===// -// Linear Layout Encoding -//===----------------------------------------------------------------------===// - -def LinearEncodingAttr : DistributedEncoding<"LinearEncoding", "linear_encoding"> { - let mnemonic = "linear"; - - let description = [{ - See the docs in LinearLayout.h for the definition of linear layouts. - }]; - - let parameters = (ins "LinearLayout":$linearLayout); - - let extraClassDeclaration = extraDistributedDeclaration # [{ - SmallVector getContigPerThread() const; - SmallVector getOrder() const; - }]; - - let genVerifyDecl = 1; - // Example of assembly format: - // <{register = [[0, 1], [8, 0], [0, 8], [64, 0]], - // lane = [[0, 2], [0, 4], [1, 0], [2, 0], [4, 0]], - // warp = [[16, 0], [32, 0]], - // block = []}> - let hasCustomAssemblyFormat = 1; -} - - //===----------------------------------------------------------------------===// // Blocked Layout Encoding //===----------------------------------------------------------------------===// diff --git a/include/triton/Tools/LinearLayout.h b/include/triton/Tools/LinearLayout.h index cfc4c0d13b..47e3fca79b 100644 --- a/include/triton/Tools/LinearLayout.h +++ b/include/triton/Tools/LinearLayout.h @@ -9,7 +9,6 @@ #include #include "mlir/IR/BuiltinAttributes.h" -#include "llvm/ADT/Hashing.h" #include "llvm/ADT/MapVector.h" #include "llvm/ADT/STLExtras.h" #include "llvm/ADT/SetVector.h" @@ -433,7 +432,6 @@ class LinearLayout { // (e.g. by reshaping) then the order doesn't really affect anything. auto getInDimNames() const { return llvm::make_first_range(bases); } auto getOutDimNames() const { return llvm::make_first_range(outDims); } - auto getOutDimSizes() const { return llvm::make_second_range(outDims); } // Gets the position that this outDim occupies in getOutDimNames(). Asserts // if the dim is not present. 
@@ -695,7 +693,6 @@ class LinearLayout { return !(lhs == rhs); } bool equalIgnoringOutDimSizes(const LinearLayout &other) const; - friend size_t hash_value(const LinearLayout &layout); private: // Factory function that gracefully fails rather than asserts if the layout is diff --git a/lib/Conversion/TritonGPUToLLVM/ConvertLayoutOpToLLVM.cpp b/lib/Conversion/TritonGPUToLLVM/ConvertLayoutOpToLLVM.cpp index c6dc34ed62..2d06980809 100644 --- a/lib/Conversion/TritonGPUToLLVM/ConvertLayoutOpToLLVM.cpp +++ b/lib/Conversion/TritonGPUToLLVM/ConvertLayoutOpToLLVM.cpp @@ -397,9 +397,6 @@ struct ConvertLayoutOpUsingLinearLayoutsConversion if (isa(layout)) { return true; } - if (isa(layout)) { - return true; - } if (auto slice = dyn_cast(layout)) { return layoutIsOK(slice.getParent()); } diff --git a/lib/Conversion/TritonGPUToLLVM/MemoryOpToLLVM.cpp b/lib/Conversion/TritonGPUToLLVM/MemoryOpToLLVM.cpp index b090670d95..4cea14f095 100644 --- a/lib/Conversion/TritonGPUToLLVM/MemoryOpToLLVM.cpp +++ b/lib/Conversion/TritonGPUToLLVM/MemoryOpToLLVM.cpp @@ -165,8 +165,8 @@ struct LocalLoadOpConversion : public ConvertOpToLLVMPattern { Attribute srcLayout = srcTy.getEncoding(); Attribute dstLayout = dstTy.getEncoding(); if (isa(srcLayout) && - (isa(dstLayout) || + (isa( + dstLayout) || isSupportedDotOpLayout(dstTy))) { return lowerSharedToDistributed(op, adaptor, getTypeConverter(), rewriter); @@ -206,6 +206,7 @@ struct LocalLoadOpConversion : public ConvertOpToLLVMPattern { auto dstTy = op.getResult().getType(); auto dstShape = dstTy.getShape(); auto srcSharedLayout = cast(srcTy.getEncoding()); + auto dstLayout = dstTy.getEncoding(); assert((dstShape.size() <= 2 || isSupportedDotOpLayout(dstTy)) && "Unexpected rank of ConvertLayout(shared->distributed)"); diff --git a/lib/Dialect/TritonGPU/IR/Dialect.cpp b/lib/Dialect/TritonGPU/IR/Dialect.cpp index 07397063c5..7cde755873 100644 --- a/lib/Dialect/TritonGPU/IR/Dialect.cpp +++ b/lib/Dialect/TritonGPU/IR/Dialect.cpp @@ -318,9 +318,6 @@ SmallVector getOrder(Attribute layout) { if (auto sharedLayout = mlir::dyn_cast(layout)) { return llvm::to_vector(sharedLayout.getOrder()); } - if (auto linearLayout = mlir::dyn_cast(layout)) { - return linearLayout.getOrder(); - } llvm::report_fatal_error("Unimplemented usage of getOrder"); return {}; @@ -544,102 +541,6 @@ getDefaultBlockedEncoding(MLIRContext *context, ArrayRef shape, return encoding; } -LinearLayout -ensureLayoutNotLargerThan(const LinearLayout &layout, - const llvm::SmallDenseMap &shape, - bool broadcastRegisters) { - assert(shape.size() == layout.getNumOutDims()); - if (shape.empty()) { - return layout; - } - MLIRContext *ctx = shape.begin()->first.getContext(); - - auto bases = layout.getBases(); - - auto kRegister = StringAttr::get(ctx, "register"); - std::set broadcastedDims; - - for (auto outDim : llvm::enumerate(layout.getOutDimNames())) { - auto outDimName = outDim.value(); - int32_t actualSize = layout.getOutDimSize(outDimName); - int32_t desiredSize = shape.lookup(outDimName); - if (actualSize <= desiredSize) { - continue; - } - assert(actualSize % desiredSize == 0); - // - std::vector> sortedBases; - for (auto [inDimName, basis] : bases) { - for (size_t basisIdx = 0; basisIdx < basis.size(); basisIdx++) { - auto outValue = basis[basisIdx][outDim.index()]; - if (outValue == 0) { - continue; - } - assert(llvm::isPowerOf2_32(outValue)); - sortedBases.emplace_back(inDimName, basisIdx, outValue); - } - } - // From the largest basis to the smallest. 
- llvm::sort(sortedBases, - [](auto a, auto b) { return std::get<2>(a) > std::get<2>(b); }); - for (auto [inDimName, basisIdx, outValue] : sortedBases) { - if (actualSize <= desiredSize) { - break; - } - if (!broadcastRegisters && inDimName == kRegister) { - broadcastedDims.insert(basisIdx); - } else { - bases[inDimName][basisIdx][outDim.index()] = 0; - } - actualSize >>= 1; - } - } - if (!broadcastRegisters) { - // Remove broadcasted registers - std::vector> newBasesRegister; - for (auto [idx, basis] : llvm::enumerate(bases[kRegister])) { - // Remove if it's broadcasted - if (broadcastedDims.find(idx) == broadcastedDims.end()) { - newBasesRegister.push_back(std::move(basis)); - } - } - bases[kRegister] = std::move(newBasesRegister); - } - - return LinearLayout(std::move(bases), - llvm::to_vector(layout.getOutDimNames())); -} - -// For each out-dim d, ensure the layout's out-size (i.e. its codomain) is no -// smaller than shape[d]. Do this by increasing the size of the layout's inputs -// along its most-minor dimension ("register" for register layouts, "offset" for -// shared layouts). -// -// This function is invariant to the order of the layout's input dimensions, but -// it cares about the order of the output dims, which should be minor-to-major. -LinearLayout ensureLayoutNotSmallerThan( - const LinearLayout &layout, - const llvm::SmallDenseMap &shape) { - assert(shape.size() == layout.getNumOutDims()); - if (shape.empty()) { - return layout; - } - - MLIRContext *ctx = shape.begin()->first.getContext(); - StringAttr kDim = *layout.getInDimNames().begin(); - assert(kDim == "register" || kDim == "offset"); - - LinearLayout ret = layout; - for (StringAttr outDimName : layout.getOutDimNames()) { - int32_t actualSize = layout.getOutDimSize(outDimName); - int32_t desiredSize = shape.lookup(outDimName); - assert(actualSize > desiredSize || desiredSize % actualSize == 0); - ret *= LinearLayout::identity1D(desiredSize / actualSize, kDim, outDimName); - assert(ret.getOutDimSize(outDimName) >= desiredSize); - } - return ret; -} - } // namespace gpu } // namespace triton } // namespace mlir @@ -1354,360 +1255,6 @@ void BlockedEncodingAttr::print(mlir::AsmPrinter &printer) const { printer << "}>"; } -// FIXME Can we take the LinearLayout by const&? -LogicalResult -LinearEncodingAttr::verify(function_ref emitError, - LinearLayout linearLayout) { - // Example of LinearEncodingAttr - // <{register = [[0, 1], [8, 0], [0, 8], [64, 0]], - // lane = [[0, 2], [0, 4], [1, 0], [2, 0], [4, 0]], - // warp = [[16, 0], [32, 0]], - // block = []}> - // The input dims must be {register, lane, warp, block} - // The output dims of the linear layout should be dim0..dim[rank-1] - - static const auto expectedInDims = - SmallVector({"register", "lane", "warp", "block"}); - for (const auto &[i, dims] : llvm::enumerate( - llvm::zip(linearLayout.getInDimNames(), expectedInDims))) { - const auto &[dim, expectedDimStr] = dims; - if (dim.str() != expectedDimStr) { - return emitError() << "Expected input dimension " << i << " to be '" - << expectedDimStr << "'. Got " << dim; - } - } - - // outDims are ['dim0', 'dim1', ...] - for (auto [i, dim] : llvm::enumerate(linearLayout.getOutDimNames())) { - if (dim.str() != ("dim" + llvm::Twine(i)).str()) { - return emitError() - << "Expected output dimensions to be ['dim0', 'dim1', ...]. 
Got " - << dim << " at position " << i; - } - } - - const auto &bases = linearLayout.getBases(); - auto nonZero = [](auto val) { return val != 0; }; - for (const auto &dimBases : llvm::make_second_range(bases)) { - if (!llvm::all_of(dimBases, [&](const auto &basis) { - return std::count_if(basis.begin(), basis.end(), nonZero) <= 1; - })) { - return emitError() - << "In a distributed layout, each base must move in at most one " - "dimension."; - } - } - - return success(); -} - -void LinearEncodingAttr::print(mlir::AsmPrinter &printer) const { - // We don't use the default implementation as it's a bit too verbose - // This prints in the following format that is shape agnostic, in the sense - // that we don't print explicitly the outShape of the LL - // We always assume LLs to be surjective - // <{register = [[0, 1], [8, 0], [0, 8], [64, 0]], - // lane = [[0, 2], [0, 4], [1, 0], [2, 0], [4, 0]], - // warp = [[16, 0], [32, 0]], - // block = []}> - auto ll = getLinearLayout(); - printer << "<{" << join(ll.getBases(), ", ", [](const auto &base) { - return base.first.str() + " = " + "[" + - join(base.second, ", ", - [](const std::vector &vec) { - return "[" + join(vec, ", ") + "]"; - }) + - "]"; - }) << "}>"; -} - -Attribute LinearEncodingAttr::parse(AsmParser &parser, Type type) { - if (parser.parseLess().failed()) - return {}; - - DictionaryAttr dict; - if (parser.parseAttribute(dict).failed()) - return {}; - - if (parser.parseGreater().failed()) - return {}; - - LinearLayout::BasesT bases; - - // Parse the basis names in order (the order is relevant) - std::vector inDimNames = {"register", "lane", "warp", "block"}; - - for (const auto &inDimNameStr : inDimNames) { - auto inDimName = StringAttr::get(parser.getContext(), inDimNameStr); - Attribute value = dict.get(inDimName); - - // Expecting an array of arrays - auto arrayOfArraysAttr = mlir::dyn_cast(value); - if (!arrayOfArraysAttr) { - parser.emitError(parser.getCurrentLocation(), - "Expected array of arrays for basis of '") - << inDimName.getValue() << "'"; - return {}; - } - - std::vector> inDimBases; - for (Attribute arrayAttr : arrayOfArraysAttr) { - auto intArrayAttr = mlir::dyn_cast(arrayAttr); - if (!intArrayAttr) { - parser.emitError(parser.getCurrentLocation(), - "Expected array of integers in basis for '") - << inDimName.getValue() << "'"; - return {}; - } - std::vector basis; - for (Attribute intAttr : intArrayAttr) { - auto intValueAttr = mlir::dyn_cast(intAttr); - if (!intValueAttr) { - parser.emitError(parser.getCurrentLocation(), - "Expected integer in basis for '") - << inDimName.getValue() << "'"; - return {}; - } - basis.push_back(intValueAttr.getInt()); - } - inDimBases.push_back(std::move(basis)); - } - bases[inDimName] = std::move(inDimBases); - } - size_t rank = 0; - for (const auto &basesDim : llvm::make_second_range(bases)) { - if (!basesDim.empty()) { - rank = basesDim[0].size(); - break; - } - } - - // To implement this we'd need to serialise the rank as well. - // We can do this if we ever need it - if (rank == 0) { - parser.emitError(parser.getCurrentLocation(), "Empty Layout not supported"); - return {}; - } - - // Generate standared outDimNames (dim0, dim1, ...) 
- SmallVector outDimNames; - for (int i = 0; i < rank; ++i) { - outDimNames.push_back( - StringAttr::get(parser.getContext(), "dim" + llvm::Twine(i))); - } - - // Create LinearLayout - LinearLayout linearLayout(std::move(bases), std::move(outDimNames)); - - // Create and return the LinearEncodingAttr - return parser.getChecked(parser.getContext(), - std::move(linearLayout)); -} - -SmallVector basesPerDim(const LinearLayout::BasesT &namedBases, - StringAttr dimName, size_t rank, - bool skipBroadcast = true) { - const auto &bases = namedBases.find(dimName)->second; - - if (bases.empty()) { - return SmallVector(rank, 1); - } - - SmallVector ret(rank, 1); - auto nonZero = [](auto val) { return val != 0; }; - int nonZeroIdx = -1; - for (const auto &basis : bases) { - auto it = std::find_if(basis.begin(), basis.end(), nonZero); - // Bases can have one or zero non-zero elements - // Skip a basis if it's broadcasting (all zeros) - // e.g. warps for DotOperandEncodingAttr (see ampereDotToLinearLayout) - if (it != basis.end()) { - nonZeroIdx = it - basis.begin(); - ret[nonZeroIdx] *= 2; - } else if (!skipBroadcast) { - // If we've seen a non-zero basis, we double the size of the previous dim - // This is just needed to count the CTAsPerCGA - assert(nonZeroIdx != -1); - ret[nonZeroIdx] *= 2; - } - } - return ret; -} - -SmallVector basesPerDim(const LinearLayout &ll, StringAttr dimName, - bool skipBroadcast = true) { - auto shapeIter = ll.getOutDimSizes(); - auto rank = std::distance(shapeIter.begin(), shapeIter.end()); - return basesPerDim(ll.getBases(), dimName, rank, skipBroadcast); -} - -SmallVector orderPerDim(const LinearLayout &ll, StringAttr dimName, - ArrayRef defaultOrder) { - const auto &bases = ll.getBases().find(dimName)->second; - llvm::SetVector order; - auto nonZero = [](auto val) { return val != 0; }; - for (const auto &basis : bases) { - // Bases can have one or zero non-zero elements - // Skip a basis if it's broadcasting (all zeros) - // e.g. warps for DotOperandEncodingAttr (see ampereDotToLinearLayout) - auto it = std::find_if(basis.begin(), basis.end(), nonZero); - if (it != basis.end()) { - auto i = it - basis.begin(); - order.insert(i); - } - } - // If any dim is missing, we add them in the defaultOrder - for (auto i : defaultOrder) { - order.insert(i); - } - return SmallVector(order.begin(), order.end()); -} - -// [Note. Divergence of methods wrt. legacy layouts] -// For smaller shapes where the CTATile is larger than the output -// tensor, some methods return different values than the legacy layouts. I think -// this is benign tho. An example: what is the the vector of `warpsPerCTA` if -// all the warps hold the same data? I think it should be [1, 1], even if we -// have 4 warps. But perhaps for this we have to add some masking in some -// places... 
We'll see -SmallVector LinearEncodingAttr::getRepOrder() const { - // This is not correct, but: - // - It happens to agree in most places with the legacy layout - // - getRepOrder does not make sense for LinearEncodingAttr as it already has - // the same shape as the tensor that uses it - return getOrder(); -} -SmallVector LinearEncodingAttr::getCTAsPerCGA() const { - // CTAs are split into an identity part (SplitNum) and a broadcast part - return basesPerDim(getLinearLayout(), StringAttr::get(getContext(), "block"), - /*skipBroadcast=*/false); -} -SmallVector LinearEncodingAttr::getCTAOrder() const { - return orderPerDim(getLinearLayout(), StringAttr::get(getContext(), "block"), - getOrder()); -} -SmallVector LinearEncodingAttr::getCTASplitNum() const { - return basesPerDim(getLinearLayout(), StringAttr::get(getContext(), "block")); -} -SmallVector LinearEncodingAttr::getWarpsPerCTA() const { - return basesPerDim(getLinearLayout(), StringAttr::get(getContext(), "warp")); -} -SmallVector LinearEncodingAttr::getWarpOrder() const { - return orderPerDim(getLinearLayout(), StringAttr::get(getContext(), "warp"), - getOrder()); -} -SmallVector LinearEncodingAttr::getThreadsPerWarp() const { - return basesPerDim(getLinearLayout(), StringAttr::get(getContext(), "lane")); -} -SmallVector LinearEncodingAttr::getThreadOrder() const { - return orderPerDim(getLinearLayout(), StringAttr::get(getContext(), "lane"), - getOrder()); -} -SmallVector LinearEncodingAttr::getSizePerThread() const { - auto rank = getRepOrder().size(); - auto ll = getLinearLayout(); - auto ctx = getContext(); - auto kRegister = StringAttr::get(ctx, "register"); - - // We canonicalize on the spot, as if we use CGAs the regs are not in - // canonical form The order is [reg, lane, warp, rep, block], so we first - // remove the blocks - llvm::SmallVector ctaShape; - for (auto [shape, cgaNum] : - llvm::zip(ll.getOutDimSizes(), getCTASplitNum())) { - ctaShape.push_back(shape / cgaNum); - } - LinearLayout::BasesT bases = ll.getBases(); - - llvm::SetVector reverseRepOrder; - auto nonZero = [](auto val) { return val != 0; }; - auto ®isters = bases[StringAttr::get(ctx, "register")]; - while (!registers.empty()) { - auto &basis = registers.back(); - auto it = std::find_if(basis.begin(), basis.end(), nonZero); - // If there's broadcasting (base == zeros) there are no more reps - if (it == basis.end()) { - break; - } - auto dim = it - basis.begin(); - reverseRepOrder.insert(dim); - // As soon as we stop finding reps, we stop - if (dim != reverseRepOrder.back() || 2 * basis[dim] != ctaShape[dim]) { - break; - } - ctaShape[dim] /= 2; - registers.pop_back(); - } - return basesPerDim(bases, kRegister, rank); -} - -SmallVector LinearEncodingAttr::getOrder() const { - auto rank = getLinearLayout().getNumOutDims(); - SmallVector order(rank); - // Choose [rank-1, rank-2, ... 
0] as the default order in case - // there are dims that do not move in the register - // This order is as good as any really - std::iota(order.rbegin(), order.rend(), 0); - - return orderPerDim(getLinearLayout(), - StringAttr::get(getContext(), "register"), order); -} - -std::optional -LinearEncodingAttr::toLinearLayout(ArrayRef shape) const { - auto ll = getLinearLayout(); - auto canonicalDims = llvm::to_vector(ll.getOutDimNames()); - llvm::SmallDenseMap namedShape; - llvm::SmallVector permutedDims; - for (auto dim : getRepOrder()) { - permutedDims.push_back(canonicalDims[dim]); - namedShape[canonicalDims[dim]] = shape[dim]; - } - ll = ll.transposeOuts(permutedDims); - ll = ensureLayoutNotSmallerThan(ll, namedShape); - ll = ensureLayoutNotLargerThan(ll, namedShape, /*broadcastRegisters=*/false); - ll = ll.transposeOuts(canonicalDims); - return ll; -} - -SmallVector -LinearEncodingAttr::getElemsPerThread(ArrayRef shape, Type) const { - // We can relax this assert by calling toLinearLayout rather than - // getLinearLayout - SmallVector shapeVec(shape.begin(), shape.end()); - assert(shapeVec == llvm::to_vector(getLinearLayout().getOutDimSizes())); - auto ll = getLinearLayout(); - return basesPerDim(ll, StringAttr::get(getContext(), "register")); -} - -// Start of Selection -SmallVector LinearEncodingAttr::getContigPerThread() const { - auto ll = getLinearLayout(); - const auto ®s = - ll.getBases().find(StringAttr::get(getContext(), "register"))->second; - auto order = getOrder(); - auto rank = order.size(); - - SmallVector contig(rank, 1); - auto regIt = regs.begin(); - for (unsigned dim : order) { - std::vector basis(rank, 0); - basis[dim] = 1; - - while (regIt != regs.end() && *regIt == basis) { - contig[dim] *= 2; - basis[dim] *= 2; - ++regIt; - } - } - return contig; -} - -unsigned LinearEncodingAttr::getTotalElemsPerThread(ArrayRef shape, - Type eltTy) const { - return product(getElemsPerThread(shape, eltTy)); -} - //===----------------------------------------------------------------------===// // MMA encoding //===----------------------------------------------------------------------===// @@ -2584,9 +2131,6 @@ class TritonGPUOpAsmInterface : public OpAsmDialectInterface { } else if (auto blockedAttr = mlir::dyn_cast(attr)) { os << "blocked"; return AliasResult::FinalAlias; - } else if (auto linearAttr = mlir::dyn_cast(attr)) { - os << "linear"; - return AliasResult::FinalAlias; } else if (auto warpAttr = mlir::dyn_cast(attr)) { os << "warp"; return AliasResult::FinalAlias; diff --git a/lib/Dialect/TritonGPU/IR/LinearLayoutConversions.cpp b/lib/Dialect/TritonGPU/IR/LinearLayoutConversions.cpp index 20280a6ddb..5c840d1dcb 100644 --- a/lib/Dialect/TritonGPU/IR/LinearLayoutConversions.cpp +++ b/lib/Dialect/TritonGPU/IR/LinearLayoutConversions.cpp @@ -53,6 +53,25 @@ SmallVector permuteDimNames(const SmallVector &names, return ret; } +void assertIsRegisterLayout(const LinearLayout &layout) { + assert(layout.getNumInDims() > 0); + MLIRContext *ctx = layout.getInDimNames().begin()->getContext(); + StringAttr kRegister = S("register"); + StringAttr kLane = S("lane"); + StringAttr kWarp = S("warp"); + StringAttr kBlock = S("block"); + + const auto &ins = layout.getInDimNames(); + assert(llvm::SmallVector(ins.begin(), ins.end()) == + llvm::SmallVector({kRegister, kLane, kWarp, kBlock})); + + const auto &outs = layout.getOutDimNames(); + const auto &expectedOuts = standardOutDimNames(ctx, layout.getNumOutDims()); + assert(llvm::SmallDenseSet(outs.begin(), outs.end()) == + 
llvm::SmallDenseSet(expectedOuts.begin(), + expectedOuts.end())); +} + // Returns a 1D -> ND layout into [dim0, dim1, ...] that's equivalent to // creating a 1D -> 1D mapping of size product(shape) and then reshaping to // permute(shape, order). @@ -103,6 +122,124 @@ LinearLayout makeCgaLayout(CTALayoutAttr layout) { return ret.transposeOuts(outDimNames); } +// For each output dimension d, ensure that the layout's output size (i.e., its +// codomain) does not exceed shape[d]. Do this without changing the size of the +// layout's inputs (i.e., leave its domain unchanged). +// +// This function is invariant to the order of the layout's input and output +// dimensions. +// +// We achieve this by setting the largest value in each output dimension d to 0 +// because bases that map to a location larger than shape[d] +// effectively duplicate along that dimension. For example, consider a layout +// with an output dimension size of 32, and we call ensureLayoutNotLargerThan to +// shrink the output dimension size to 8: +// +// L(register=1) = 8 +// L(register=2) = 4 +// L(register=4) = 1 +// L(lane=1) = 2 +// L(lane=2) = 16 +// +// In the first step, we shrink the output dimension size to 16 by setting +// L(lane=2) to 0: +// +// L(register=1) = 8 +// L(register=2) = 4 +// L(register=4) = 1 +// L(lane=1) = 2 +// L(lane=2) = 0 +// +// This means that lane=2 has the same data as lane=0. +// +// Now the output dimension of this layout has a size of 16, which is still +// larger than 8. We find the current largest value in the output dimension, +// which is L(register=1) = 8, and we set L(register=1) to 0: +// +// L(register=1) = 0 +// L(register=2) = 4 +// L(register=4) = 1 +// L(lane=1) = 2 +// L(lane=2) = 0 +// +// Now the output dimension of this layout has a size of 8, which is the desired +// size. Note that this method works only because the bases are powers of two. +// It is unclear what to do when they are not. +LinearLayout ensureLayoutNotLargerThan( + const LinearLayout &layout, + const llvm::SmallDenseMap &shape) { + assert(shape.size() == layout.getNumOutDims()); + if (shape.empty()) { + return layout; + } + MLIRContext *ctx = shape.begin()->first.getContext(); + + auto bases = layout.getBases(); + for (auto outDim : llvm::enumerate(layout.getOutDimNames())) { + auto outDimName = outDim.value(); + int32_t actualSize = layout.getOutDimSize(outDimName); + int32_t desiredSize = shape.lookup(outDimName); + if (actualSize <= desiredSize) { + continue; + } + assert(actualSize % desiredSize == 0); + // + std::vector> sortedBases; + for (auto [inDimName, basis] : bases) { + for (size_t basisIdx = 0; basisIdx < basis.size(); basisIdx++) { + auto outValue = basis[basisIdx][outDim.index()]; + if (outValue == 0) { + continue; + } + assert(llvm::isPowerOf2_32(outValue)); + sortedBases.emplace_back(inDimName, basisIdx, outValue); + } + } + // From the largest basis to the smallest. + llvm::sort(sortedBases, + [](auto a, auto b) { return std::get<2>(a) > std::get<2>(b); }); + for (auto [inDimName, basisIdx, outValue] : sortedBases) { + if (actualSize <= desiredSize) { + break; + } + bases[inDimName][basisIdx][outDim.index()] = 0; + actualSize >>= 1; + } + } + return LinearLayout(std::move(bases), + llvm::to_vector(layout.getOutDimNames())); +} + +// For each out-dim d, ensure the layout's out-size (i.e. its codomain) is no +// smaller than shape[d]. Do this by increasing the size of the layout's inputs +// along its most-minor dimension ("register" for register layouts, "offset" for +// shared layouts). 
+// +// This function is invariant to the order of the layout's input dimensions, but +// it cares about the order of the output dims, which should be minor-to-major. +LinearLayout ensureLayoutNotSmallerThan( + const LinearLayout &layout, + const llvm::SmallDenseMap &shape) { + assert(shape.size() == layout.getNumOutDims()); + if (shape.empty()) { + return layout; + } + + MLIRContext *ctx = shape.begin()->first.getContext(); + StringAttr kDim = *layout.getInDimNames().begin(); + assert(kDim == "register" || kDim == "offset"); + + LinearLayout ret = layout; + for (StringAttr outDimName : layout.getOutDimNames()) { + int32_t actualSize = layout.getOutDimSize(outDimName); + int32_t desiredSize = shape.lookup(outDimName); + assert(actualSize > desiredSize || desiredSize % actualSize == 0); + ret *= LinearLayout::identity1D(desiredSize / actualSize, kDim, outDimName); + assert(ret.getOutDimSize(outDimName) >= desiredSize); + } + return ret; +} + // Combines the layout of a CTA (input dims [register, lane, warp]) with the // layout of a CGA (i.e. a block), and ensures that the resulting layout has the // given shape. @@ -795,10 +932,10 @@ SliceEncodingAttr::toLinearLayout(ArrayRef shape) const { std::optional toLinearLayout(ArrayRef shape, Attribute layout, std::optional elemBitWidth /*= std::nullopt*/) { - // Layouts are distributed or shared if (auto distributed = dyn_cast(layout)) { return distributed.toLinearLayout(shape); - } else if (auto shared = dyn_cast(layout)) { + } + if (auto shared = dyn_cast(layout)) { if (shared.getHasLeadingOffset()) { assert(elemBitWidth.has_value()); return sharedToLinearLayoutLeadingOffset(shape, shared, *elemBitWidth); @@ -807,7 +944,7 @@ toLinearLayout(ArrayRef shape, Attribute layout, } } - // Third party layouts + // TODO(jlebar): Other layouts return std::nullopt; } diff --git a/lib/Dialect/TritonGPU/IR/Ops.cpp b/lib/Dialect/TritonGPU/IR/Ops.cpp index 068965468e..233883964f 100644 --- a/lib/Dialect/TritonGPU/IR/Ops.cpp +++ b/lib/Dialect/TritonGPU/IR/Ops.cpp @@ -53,17 +53,15 @@ LogicalResult UpcastMXFPOp::verify() { if (!dotEncoding) { return emitOpError("Expected a DotOperandEncodingAttr for values"); } - if (!isa(layoutScale)) { - return emitOpError( - "Expected a BlockOperandEncoding or LinearOperandEncoding " - "for scales"); + auto blockedScale = dyn_cast(layoutScale); + if (!blockedScale) { + return emitOpError("Expected a BlockOperandEncoding for scales"); } if (isa(dotEncoding.getParent())) { // Necessary to keep all of the scales of a given block of values in the // same warp - auto threadsPerWarp = - cast(layoutScale).getThreadsPerWarp(); + auto threadsPerWarp = blockedScale.getThreadsPerWarp(); if (threadsPerWarp != ArrayRef({16, 2})) { return emitOpError("Expected threads per warp to be {16, 2}"); } diff --git a/lib/Dialect/TritonGPU/Transforms/AccelerateMatmul.cpp b/lib/Dialect/TritonGPU/Transforms/AccelerateMatmul.cpp index c07f314087..4dccc85da3 100644 --- a/lib/Dialect/TritonGPU/Transforms/AccelerateMatmul.cpp +++ b/lib/Dialect/TritonGPU/Transforms/AccelerateMatmul.cpp @@ -400,7 +400,6 @@ class DecomposeScaledBlocked "NYI: lhs supports fp4 or fp8"); assert(bType == ScaleDotElemType::E4M3 || bType == ScaleDotElemType::E5M2 || bType == ScaleDotElemType::BF16 && "NYI: rhs supports fp8 and bf16"); - bool isFp4 = aType == ScaleDotElemType::E2M1; auto mmaEnc = getMMAEncoding(rewriter, scaledDotOp); auto versionMajor = mmaEnc.getVersionMajor(); @@ -419,7 +418,7 @@ class DecomposeScaledBlocked // types auto aKWidth = mmaEnc.isHopper() ? 
2 : 8; auto bKWidth = mmaEnc.isHopper() ? 2 : 8; - if (isFp4) { + if (aType == ScaleDotElemType::E2M1) { // Load 2x4-bit elements per thread aKWidth /= 2; } @@ -439,43 +438,9 @@ class DecomposeScaledBlocked // Necessary choice to leave all the scales of the tile in that given warp auto threadsPerWarp = SmallVector{instrShapeM, warpSize / instrShapeM}; - - assert(versionMajor == 2 && - "NYI: MMAv3. Need to rethink the scale layout otherwise"); - - // Copy the bases - - Attribute newScaleEncoding = triton::gpu::BlockedEncodingAttr::get( + auto newScaleEncoding = triton::gpu::BlockedEncodingAttr::get( ctx, {1, 1}, threadsPerWarp, newAEncoding.getWarpsPerCTA(), newAEncoding.getCTAOrder(), mmaEnc.getCTALayout()); - - auto dotBroadcastsWarpLevel = mmaEnc.getWarpsPerCTA()[1] != 1; - if (dotBroadcastsWarpLevel) { - // If mma has warpsPerCTA == {2, 2}, then newAEncoding has - // warpsPerCTA == {2, 1}. In this case, we need to broadcast the warps - // on the second dimension as per - // A: 0 1 | 0 1 - // - - | - - - // 2 3 | 2 3 - // This broadcasting is not representable by standard blocked encodings, - // so we need to use linear layouts. - // This broadcasting is implemented in ampereDotToLinearLayout - auto blocked = cast(newScaleEncoding); - auto blockedLL = *blocked.toLinearLayout(a.getType().getShape()); - LinearLayout::BasesT scaleBases = blockedLL.getBases(); - auto nBases = llvm::Log2_32(mmaEnc.getWarpsPerCTA()[1]); - auto &warps = scaleBases[StringAttr::get(ctx, "warp")]; - // Prepend the vector of zeros to the warpBases - warps.insert(warps.begin(), nBases, std::vector(rank, 0)); - auto outDims = llvm::to_vector(blockedLL.getOutDimNames()); - auto newLL = LinearLayout(scaleBases, outDims); - auto llEncoding = LinearEncodingAttr::get(ctx, std::move(newLL)); - // Adjust the shape of the layout to match the scale operand - auto scaleShape = scale.getType().getShape(); - newScaleEncoding = - LinearEncodingAttr::get(ctx, *llEncoding.toLinearLayout(scaleShape)); - } - a = createArg(rewriter, a, 0, aType, newAEncoding, scale, newScaleEncoding); // Upcast B operand @@ -578,8 +543,7 @@ class DecomposeScaledBlocked auto dotOp = rewriter.create( scaledDotOp.getLoc(), scaledDotOp.getType(), a, b, scaledDotOp.getC()); - // Waiting for https://github.com/triton-lang/triton/pull/5003 to land - // cf. 
+ // FIXME Waiting on the following comment to be fixed: // https://github.com/triton-lang/triton/pull/5003#issuecomment-2445091746 // int versionMajor = getMMAVersionSafe(computeCapability, dotOp); int versionMajor = 2; @@ -595,8 +559,10 @@ class DecomposeScaledBlocked versionMajor, retShapePerCTA, dotOp.getA().getType().getElementType(), numWarps); - auto warpsPerCTA = getWarpsPerTile(dotOp, retShapePerCTA, versionMajor, - numWarps, instrShape); + // FIXME Waiting on supporting LLs on convert_layout + // auto warpsPerCTA = getWarpsPerTile(dotOp, retShapePerCTA, versionMajor, + // numWarps, instrShape); + SmallVector warpsPerCTA = {(unsigned)numWarps, 1}; return NvidiaMmaEncodingAttr::get(ctx, versionMajor, versionMinor, warpsPerCTA, CTALayout, instrShape); } diff --git a/lib/Dialect/TritonGPU/Transforms/RemoveLayoutConversions.cpp b/lib/Dialect/TritonGPU/Transforms/RemoveLayoutConversions.cpp index 70f5219111..bc4049dc30 100644 --- a/lib/Dialect/TritonGPU/Transforms/RemoveLayoutConversions.cpp +++ b/lib/Dialect/TritonGPU/Transforms/RemoveLayoutConversions.cpp @@ -970,9 +970,7 @@ void LayoutRematerialization::backwardRematerialization( // we don't handle conversions to DotOperandEncodingAttr // this is a heuristic to accommodate fused attention RankedTensorType targetType = convertOp.getType(); - // We stop the rematerialization of linear layouts as we have to be a bit more - // careful with the heuristics for both correctness and perf - if (isa(targetType.getEncoding())) + if (isa(targetType.getEncoding())) return; Value oldV = convertOp->getOperand(0); LDBG("check backward remat with source " << oldV << " encoding " @@ -1014,11 +1012,8 @@ void LayoutRematerialization::hoistConvertOnTopOfExtOrBroadcast( ConvertLayoutOp convertOp) { // we don't handle conversions to DotOperandEncodingAttr // this is a heuristics to accommodate fused attention - // We stop the rematerialization of linear layouts as we have to be a bit more - // careful with the heuristics for both correctness and perf RankedTensorType targetType = convertOp.getType(); - if (mlir::isa( - targetType.getEncoding())) + if (mlir::isa(targetType.getEncoding())) return; auto isExtOrBroadcastOp = [](Operation *op) { diff --git a/lib/Tools/LinearLayout.cpp b/lib/Tools/LinearLayout.cpp index 3a81231ac8..4319d1f086 100644 --- a/lib/Tools/LinearLayout.cpp +++ b/lib/Tools/LinearLayout.cpp @@ -212,6 +212,42 @@ void assertCommonDimsSameOrder(T &&aDims, U &&bDims) { "\nb: " + triton::join(bDims, ", ")); } } + +void eraseEmptyInOutDims(BasesT &bases, + llvm::MapVector &outDims) { + // Erase empty out-dims. + SmallVector emptyOutDims; + for (auto [i, outDim] : llvm::enumerate( + llvm::to_vector_of(llvm::make_first_range(outDims)))) { + if (outDims[outDim] == 1) { + emptyOutDims.push_back(i); + outDims.erase(outDim); + } + } + if (outDims.empty()) { + bases.clear(); + return; + } + + for (auto &[inDim, inDimBases] : bases) { + for (auto &basis : inDimBases) { + // Erase the basis elements corresponding to the empty out-dims. + for (int i : llvm::reverse(emptyOutDims)) { + basis.erase(basis.begin() + i); + } + } + } + + // Erase empty in-dims. + // TODO: This needs a test-case. 
+ for (StringAttr inDim : + llvm::to_vector_of(llvm::make_first_range(bases))) { + if (bases[inDim].empty()) { + bases.erase(inDim); + } + } +} + } // anonymous namespace /*static*/ std::optional @@ -953,30 +989,6 @@ LinearLayout::getFreeVariableMasks() const { return ret; } -size_t hash_value(const LinearLayout &layout) { - size_t seed = 0; - - // Hash the bases - for (const auto &base : layout.getBases()) { - // Hash the input dimension name - seed = llvm::hash_combine(seed, base.first); - - // Hash the vectors in bases - for (const auto &vec : base.second) { - for (int32_t val : vec) { - seed = llvm::hash_combine(seed, val); - } - } - } - - // Hash the output dimensions and their sizes - for (const auto &outDim : layout.getOutDimNames()) { - seed = llvm::hash_combine(seed, outDim, layout.getOutDimSize(outDim)); - } - // Don't hash the surjective flag as it's a cached property - return seed; -} - bool operator==(LinearLayout lhs, LinearLayout rhs) { if (!lhs.equalIgnoringOutDimSizes(rhs)) return false; diff --git a/test/TritonGPU/accelerate-matmul.mlir b/test/TritonGPU/accelerate-matmul.mlir index 648a29c34f..703d379bf5 100644 --- a/test/TritonGPU/accelerate-matmul.mlir +++ b/test/TritonGPU/accelerate-matmul.mlir @@ -168,16 +168,14 @@ module attributes {"triton_gpu.target" = "cuda:90", "triton_gpu.num-ctas" = 1 : #blocked = #triton_gpu.blocked<{sizePerThread = [1, 1], threadsPerWarp = [1, 32], warpsPerCTA = [1, 4], order = [1, 0]}> #blocked1 = #triton_gpu.blocked<{sizePerThread = [1, 1], threadsPerWarp = [16, 2], warpsPerCTA = [4, 1], order = [1, 0]}> #blocked2 = #triton_gpu.blocked<{sizePerThread = [1, 1], threadsPerWarp = [1, 32], warpsPerCTA = [2, 2], order = [1, 0]}> -// CHECK: #[[LINEAR:.+]] = #triton_gpu.linear<{{.*}}> module attributes {"triton_gpu.target" = "cuda:90", "triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32, "triton_gpu.threads-per-warp" = 32 : i32} { - // CHECK: dot_scaled + // CHECK-LABEL: dot_scaled tt.func @dot_scaled( %a: tensor<128x32xi8, #blocked2>, %scale: tensor<128x2xi8, #blocked1>, %b: tensor<64x128xbf16, #blocked>) -> tensor<128x128xf32, #blocked> { - // CHECK: triton_gpu.convert_layout {{.*}} : tensor<128x2xi8, #blocked1> -> tensor<128x2xi8, #[[LINEAR]]> - // CHECK: triton_gpu.upcast_mxfp {{.*}}, {{.*}} fp_type = e2m1 : tensor<128x32xi8, #triton_gpu.dot_op<{{.*}}>>, tensor<128x2xi8, #[[LINEAR]]> -> tensor<128x64xbf16, #triton_gpu.dot_op<{{.*}}>> + // CHECK: triton_gpu.upcast_mxfp // CHECK: tt.dot %cst = arith.constant dense<0.000000e+00> : tensor<128x128xf32, #blocked> %result = tt.dot_scaled %a scale %scale, %b, %cst lhs = e2m1 rhs = bf16 : tensor<128x32xi8, #blocked2>, tensor<128x2xi8, #blocked1> * tensor<64x128xbf16, #blocked> -> tensor<128x128xf32, #blocked> diff --git a/test/TritonGPU/ops.mlir b/test/TritonGPU/ops.mlir index 70c1a315e7..9184a53120 100644 --- a/test/TritonGPU/ops.mlir +++ b/test/TritonGPU/ops.mlir @@ -33,17 +33,3 @@ module attributes {"triton_gpu.target" = "cuda:0", "triton_gpu.num-ctas" = 1 : i tt.return } } -// ----- - -#blocked= #triton_gpu.blocked<{sizePerThread = [1, 1], threadsPerWarp = [8, 4], warpsPerCTA = [4, 1], order = [1, 0]}> -// CHECK: #[[$LINEAR:.*]] = #triton_gpu.linear<{{.*}}> - -module attributes {"triton_gpu.target" = "cuda:0", "triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32, "triton_gpu.threads-per-warp" = 32 : i32} { - // CHECK-LABEL: @blocked_to_linear - tt.func @blocked_to_linear(%input: tensor<32x4xi8, #blocked>) { - // The layout is the basic layout generated by 
DecomposeScaledBlocked - %output = triton_gpu.convert_layout %input {allocation.offset = 0 : i32} : tensor<32x4xi8, #blocked> -> tensor<32x4xi8, #triton_gpu.linear<{register = [], lane = [[0, 1], [1, 0], [2, 0], [4, 0], [8, 0]], warp = [[0, 0], [16, 0]], block = []}>> - // CHECK: %{{.+}} = triton_gpu.convert_layout %{{.+}} : tensor<32x4xi8, #blocked> -> tensor<32x4xi8, #[[$LINEAR]]> - tt.return - } -} diff --git a/unittest/Dialect/TritonGPU/DialectTest.cpp b/unittest/Dialect/TritonGPU/DialectTest.cpp index 779bc1b788..c27c63335e 100644 --- a/unittest/Dialect/TritonGPU/DialectTest.cpp +++ b/unittest/Dialect/TritonGPU/DialectTest.cpp @@ -620,135 +620,7 @@ TEST_F(AMDMfmaLayoutTest, mfma_dot_op) { ASSERT_THAT(tdot3dOp1.getWarpOrder(), tmfma3d.getWarpOrder()); } -class LinearEncodingTest : public ::testing::Test { -public: - LinearEncodingTest() { ctx.getOrLoadDialect(); } - -protected: - MLIRContext ctx; -}; - -TEST_F(LinearEncodingTest, DistributedEncodingToLinearEncoding) { - // Define a tensor shape - auto rank = 2; - SmallVector> shapes = {{64, 128}, {256, 1024}}; - SmallVector> orders = {{0, 1}, {1, 0}}; - SmallVector ctaLayouts = { - triton::gpu::CTALayoutAttr::getDefault(&ctx, rank), - triton::gpu::CTALayoutAttr::get(&ctx, {4, 2}, {2, 2}, {1, 0}), - }; - SmallVector distributedEncodings; - - // Create BlockedEncodingAttr and SliceEncodingAttr - { - SmallVector sizePerThread = {4, 4}; - SmallVector threadsPerWarp = {4, 8}; - SmallVector warpsPerCTA = {2, 2}; - - for (auto ctaLayout : ctaLayouts) { - for (const auto &order : orders) { - auto blockedEncoding = triton::gpu::BlockedEncodingAttr::get( - &ctx, sizePerThread, threadsPerWarp, warpsPerCTA, order, ctaLayout); - distributedEncodings.push_back(blockedEncoding); - distributedEncodings.push_back( - triton::gpu::SliceEncodingAttr::get(&ctx, 0, blockedEncoding)); - } - } - } - - // Create an MMAv2 and DotOperandEncodingAttr (MMAv3 doesn't support linear - // layouts yet) - { - unsigned versionMajor = 2; - unsigned versionMinor = 0; - SmallVector warpsPerCTA{4, 2}; - SmallVector instrShape{16, 8}; // Instruction shape (M, N) - auto mma = triton::gpu::NvidiaMmaEncodingAttr::get( - &ctx, versionMajor, versionMinor, warpsPerCTA, ctaLayouts[0], - instrShape); - distributedEncodings.push_back(mma); - // Create an opIdx=0 and opIdx=1 encoding - for (unsigned opIdx = 0; opIdx < 2; ++opIdx) { - distributedEncodings.push_back( - triton::gpu::DotOperandEncodingAttr::get(&ctx, opIdx, mma, 2)); - } - } - - for (const auto &distributedEncoding : distributedEncodings) { - for (auto shape : shapes) { - if (auto sliceEncoding = - dyn_cast(distributedEncoding)) { - shape.erase(shape.begin() + sliceEncoding.getDim()); - } - - // Create LinearEncodingAttr from the LinearLayout - auto linearLayout = *distributedEncoding.toLinearLayout(shape); - auto linearEncoding = - triton::gpu::LinearEncodingAttr::get(&ctx, linearLayout); - - // Test that the canonical form of the LinearLayout is indeed canonical - // by expanding it to the original shape - auto expandedLL = linearEncoding.toLinearLayout(shape); - ASSERT_EQ(linearLayout, expandedLL); - - // Test that methods of DistributedEncoding return the same values - Type eltTy = FloatType::getF32(&ctx); - - ASSERT_EQ(getOrder(distributedEncoding), linearEncoding.getRepOrder()); - ASSERT_EQ(cast(distributedEncoding) - .getTotalElemsPerThread(shape, eltTy), - linearEncoding.getTotalElemsPerThread(shape, eltTy)); - ASSERT_EQ(cast(distributedEncoding) - .getElemsPerThread(shape, eltTy), - 
linearEncoding.getElemsPerThread(shape, eltTy)); - ASSERT_EQ(distributedEncoding.getRepOrder(), - linearEncoding.getRepOrder()); - ASSERT_EQ(distributedEncoding.getContigPerThread(), - linearEncoding.getContigPerThread()); - // DotOperandEncodingAttr::getWarpOrder() is not defined - if (!isa(distributedEncoding)) { - ASSERT_EQ(distributedEncoding.getWarpOrder(), - linearEncoding.getWarpOrder()); - } - ASSERT_EQ(distributedEncoding.getThreadOrder(), - linearEncoding.getThreadOrder()); - // For slice these do not equal the total number of lines / warps - // See [Note. Divergence of methods wrt. legacy layouts] - if (!isa(distributedEncoding)) { - ASSERT_EQ(distributedEncoding.getWarpsPerCTA(), - linearEncoding.getWarpsPerCTA()); - ASSERT_EQ(distributedEncoding.getThreadsPerWarp(), - linearEncoding.getThreadsPerWarp()); - } - // Canonicalisation for opIdx=0 takes just a [2 x 2] subtile as it takes - // the second repetition along K as the second tile. - if (!isa(distributedEncoding)) { - // FIXME: This happens to be correct for SliceLayout because of the hack - // in SliceEncodingAttr::toLinearLayout(). We should remove the hack - // and the skips in the getWarpsPerCTA() and getThreadsPerWarp() - ASSERT_EQ(distributedEncoding.getSizePerThread(), - linearEncoding.getSizePerThread()); - } - - // block level - // SliceEncoding is not well-defined for CGAs - if (!isa(distributedEncoding)) { - ASSERT_EQ(distributedEncoding.getCTASplitNum(), - linearEncoding.getCTASplitNum()); - ASSERT_EQ(distributedEncoding.getCTAsPerCGA(), - linearEncoding.getCTAsPerCGA()); - // If we are not using CGAs, the order is meaningless - auto useCGA = distributedEncoding.getCTAsPerCGA() != - SmallVector(rank, 1); - if (useCGA) { - ASSERT_EQ(distributedEncoding.getCTAOrder(), - linearEncoding.getCTAOrder()); - } - } - } - } -} -} // namespace +} // anonymous namespace } // namespace mlir::triton::gpu int main(int argc, char *argv[]) { From 325fefa0e97e1a93911ae41e267e0aba5af99764 Mon Sep 17 00:00:00 2001 From: Pavel Chekin Date: Thu, 21 Nov 2024 16:55:03 -0800 Subject: [PATCH 12/13] Build PyTorch with USE_STATIC_MKL=1 (#2792) Build PyTorch with static MKL to minimize the number of runtime dependencies. Fixes #2791. 
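For illustration, the change is equivalent to exporting USE_STATIC_MKL=1 before
invoking the wheel build. A minimal Python sketch, not part of the patch; the
`pytorch` checkout path is an assumption:

    # Sketch: run the PyTorch wheel build with MKL linked statically.
    # USE_STATIC_MKL is consumed by PyTorch's CMake configuration step.
    import os
    import subprocess

    env = dict(os.environ, USE_STATIC_MKL="1")
    subprocess.check_call(
        ["python", "setup.py", "bdist_wheel"],
        cwd="pytorch",  # assumed location of the PyTorch source tree
        env=env,
    )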
---
 .github/actions/setup-pytorch/action.yml | 4 ++--
 scripts/compile-pytorch-ipex.sh          | 2 +-
 scripts/install-pytorch.sh               | 2 +-
 3 files changed, 4 insertions(+), 4 deletions(-)

diff --git a/.github/actions/setup-pytorch/action.yml b/.github/actions/setup-pytorch/action.yml
index b6cc894722..88f482f78a 100644
--- a/.github/actions/setup-pytorch/action.yml
+++ b/.github/actions/setup-pytorch/action.yml
@@ -83,7 +83,7 @@ runs:
       uses: ./.github/actions/load
       env:
         # Increase this value to reset cache
-        CACHE_NUMBER: 12
+        CACHE_NUMBER: 14
       with:
         path: pytorch
         key: pytorch-$PYTORCH_CACHE_KEY-$CACHE_NUMBER
@@ -120,7 +120,7 @@ runs:
         cd pytorch
         pip install wheel
         pip install -r requirements.txt
-        python setup.py bdist_wheel
+        USE_STATIC_MKL=1 python setup.py bdist_wheel

     - name: Install PyTorch (built from source)
       if: ${{ inputs.mode == 'source' }}
diff --git a/scripts/compile-pytorch-ipex.sh b/scripts/compile-pytorch-ipex.sh
index 7c5a41f6c8..d753abe113 100755
--- a/scripts/compile-pytorch-ipex.sh
+++ b/scripts/compile-pytorch-ipex.sh
@@ -117,7 +117,7 @@ if [[ $BUILD_PYTORCH = true ]]; then
     echo "****** Building $PYTORCH_PROJ ******"
     pip install -r requirements.txt
     pip install cmake ninja "numpy<2.0"
-    python setup.py bdist_wheel
+    USE_STATIC_MKL=1 python setup.py bdist_wheel

     echo "****** Installing PyTorch ******"
     pip install dist/*.whl
diff --git a/scripts/install-pytorch.sh b/scripts/install-pytorch.sh
index a8a0f2b83a..74b3ac5158 100755
--- a/scripts/install-pytorch.sh
+++ b/scripts/install-pytorch.sh
@@ -155,7 +155,7 @@ $SCRIPTS_DIR/patch-pytorch.sh
 echo "****** Building $PYTORCH_PROJ ******"
 pip install -r requirements.txt
 pip install cmake ninja
-python setup.py bdist_wheel
+USE_STATIC_MKL=1 python setup.py bdist_wheel

 echo "****** Installing PyTorch ******"
 pip install dist/*.whl

From 816d7efe1dcfdf0810c91fd324a578781b89b574 Mon Sep 17 00:00:00 2001
From: chengjunlu
Date: Fri, 22 Nov 2024 10:08:52 +0800
Subject: [PATCH 13/13] [XPU][TritonIntelGPU] Fix issue in converting the
 DotOp A layout to the LinearLayout of DPAS. (#2766)

The code was using the old definition of the layout for DPAS operand A
when converting to LinearLayout.

Update the code to the new layout, which is supported by the OCL
interface.

Enable the DotOp A to LinearLayout conversion.
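As context for the updated FileCheck expectations below: the new lowering
replaces udiv/urem/add offset arithmetic with and/icmp/select/xor chains, i.e.
each lane-id bit selects either that bit's basis vector or zero, and the
selected bases are XOR-ed together. A minimal Python sketch of that
evaluation, assuming the standard linear-layout semantics; the bases shown
are illustrative, not the exact DPAS operand-A bases from this patch:

    # Sketch: map an input index (e.g. a lane id) to a rank-2 offset by
    # XOR-ing in the basis vector of every set bit, mirroring the generated
    # llvm.and / llvm.icmp / llvm.select / llvm.xor sequence.
    def apply_linear_layout(bases, index):
        rank = len(bases[0]) if bases else 0
        out = [0] * rank
        for bit, basis in enumerate(bases):
            if (index >> bit) & 1:
                out = [o ^ b for o, b in zip(out, basis)]
        return out

    # Illustrative lane bases spreading 16 lanes along dim1: lane 5 has
    # bits 0 and 2 set, so it maps to [0, 1 ^ 4] = [0, 5].
    print(apply_linear_layout([[0, 1], [0, 2], [0, 4], [0, 8]], 5))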
--------- Co-authored-by: Julian Oppermann --- bin/CMakeLists.txt | 1 + bin/triton-tensor-layout.cpp | 12 +- test/Conversion/intel/dot_layout_offset.mlir | 312 +++++++------- ...ritonintelgpu-convert-layout-shortcut.mlir | 380 ++++++++++++------ .../tritonintlgpu-nested-layout.mlir | 137 +++---- .../IR/LinearLayoutConversions.cpp | 59 +-- .../TritonGPU/DPAStoLinearLayoutTest.cpp | 52 ++- 7 files changed, 560 insertions(+), 393 deletions(-) diff --git a/bin/CMakeLists.txt b/bin/CMakeLists.txt index b32e533b64..aa1293bd49 100644 --- a/bin/CMakeLists.txt +++ b/bin/CMakeLists.txt @@ -102,6 +102,7 @@ add_llvm_executable(triton-tensor-layout triton-tensor-layout.cpp PARTIAL_SOURCE target_link_libraries(triton-tensor-layout PRIVATE TritonGPUIR TritonNvidiaGPUIR + TritonIntelGPUIR ${triton_libs} ${conversion_libs} ${dialect_libs} diff --git a/bin/triton-tensor-layout.cpp b/bin/triton-tensor-layout.cpp index 4087ac1350..b330cfb5aa 100644 --- a/bin/triton-tensor-layout.cpp +++ b/bin/triton-tensor-layout.cpp @@ -80,17 +80,9 @@ static cl::opt TensorStr( //===--------------------------------------------------------------------===// LogicalResult layoutPrint(RankedTensorType tensorType, raw_ostream &os) { - StringRef dialectName = tensorType.getEncoding().getDialect().getNamespace(); - // Dispatch to the corresponding dialect helper function to print the layout. - if (dialectName == "triton_gpu") { - os << triton::gpu::getLayoutStr(tensorType, UseHWPointOfView); - return success(); - } - - llvm::errs() << "Unsupported tensor layout attribute: " - << tensorType.getEncoding() << "\n"; - return failure(); + os << triton::gpu::getLayoutStr(tensorType, UseHWPointOfView); + return success(); } LogicalResult printLayoutFromFile(MLIRContext *context, StringRef filename, diff --git a/test/Conversion/intel/dot_layout_offset.mlir b/test/Conversion/intel/dot_layout_offset.mlir index 92129848d0..09615f4252 100644 --- a/test/Conversion/intel/dot_layout_offset.mlir +++ b/test/Conversion/intel/dot_layout_offset.mlir @@ -12,317 +12,307 @@ module attributes {"triton_gpu.num-warps" = 4 : i32, "triton_gpu.threads-per-war // CHECK: %[[THREAD_ID_I64:.*]] = llvm.call spir_funccc @_Z12get_local_idj // CHECK: %[[THREAD_ID_I32:.*]] = llvm.trunc %[[THREAD_ID_I64]] : i64 to i32 // CHECK: %[[VAL_145:.*]] = llvm.mlir.constant(16 : i32) : i32 + // CHECK: %[[LANE_ID:.*]] = llvm.urem %[[THREAD_ID_I32]], %[[VAL_145]] : i32 // CHECK: %[[WARP_ID:.*]] = llvm.udiv %[[THREAD_ID_I32]], %[[VAL_145]] : i32 - // CHECK: %[[VAL_147:.*]] = llvm.mlir.constant(16 : i32) : i32 - // CHECK: %[[LANE_ID:.*]] = llvm.urem %[[THREAD_ID_I32]], %[[VAL_147]] : i32 + // CHECK-COUNT-3: %[[CST_0:.*]] = llvm.mlir.constant(0 : i32) : i32 // CHECK: %[[VAL_149:.*]] = llvm.mlir.constant(1 : i32) : i32 - // CHECK: %[[WARP_ID_N:.*]] = llvm.urem %[[WARP_ID]], %[[VAL_149]] : i32 - // CHECK: %[[VAL_151:.*]] = llvm.udiv %[[WARP_ID]], %[[VAL_149]] : i32 + // CHECK: %[[VAL_150:.*]] = llvm.and %[[LANE_ID]], %[[VAL_149]] : i32 + // CHECK: %[[VAL_151:.*]] = llvm.icmp "eq" %[[VAL_150]], %[[CST_0]] : i32 // CHECK: %[[VAL_152:.*]] = llvm.mlir.constant(1 : i32) : i32 - // CHECK: %[[WARP_ID_M:.*]] = llvm.urem %[[VAL_151]], %[[VAL_152]] : i32 - // CHECK: %[[VAL_154:.*]] = llvm.udiv %[[VAL_151]], %[[VAL_152]] : i32 + // CHECK: %[[VAL_153:.*]] = llvm.select %[[VAL_151]], %[[CST_0]], %[[VAL_152]] : i1, i32 + // CHECK: %[[VAL_154:.*]] = llvm.xor %[[CST_0]], %[[VAL_153]] : i32 // CHECK: %[[VAL_155:.*]] = llvm.mlir.constant(2 : i32) : i32 - // CHECK: %[[ROUNDED_WARP_ID_M:.*]] = llvm.urem 
%[[WARP_ID_M]], %[[VAL_155]] : i32 - // CHECK: %[[warpShape_M:.*]] = llvm.mlir.constant(16 : i32) : i32 - // CHECK: %[[warpOffset:.*]] = llvm.mul %[[ROUNDED_WARP_ID_M]], %[[warpShape_M]] : i32 - // CHECK: %[[VAL_159:.*]] = llvm.mlir.constant(16 : i32) : i32 - // CHECK: %[[laneRowIndex:.*]] = llvm.udiv %[[LANE_ID]], %[[VAL_159]] : i32 - // CHECK: %[[VAL_161:.*]] = llvm.mlir.constant(16 : i32) : i32 - // CHECK: %[[VAL_162:.*]] = llvm.urem %[[LANE_ID]], %[[VAL_161]] : i32 - // CHECK: %[[VAL_163:.*]] = llvm.mlir.constant(1 : i32) : i32 - // CHECK: %[[multiDimBase_N:.*]] = llvm.mul %[[VAL_162]], %[[VAL_163]] : i32 - // CHECK: %[[multiDimBase_M:.*]] = llvm.add %[[laneRowIndex]], %[[warpOffset]] : i32 - // CHECK: %[[VAL_166:.*]] = llvm.mlir.constant(0 : i32) : i32 - // CHECK: %[[VAL_167:.*]] = llvm.mlir.constant(1 : i32) : i32 - // CHECK: %[[VAL_168:.*]] = llvm.urem %[[VAL_166]], %[[VAL_167]] : i32 - // CHECK: %[[VAL_169:.*]] = llvm.udiv %[[VAL_166]], %[[VAL_167]] : i32 - // CHECK: %[[VAL_170:.*]] = llvm.mlir.constant(1 : i32) : i32 - // CHECK: %[[VAL_171:.*]] = llvm.urem %[[VAL_169]], %[[VAL_170]] : i32 - // CHECK: %[[VAL_172:.*]] = llvm.udiv %[[VAL_169]], %[[VAL_170]] : i32 - // CHECK: %[[VAL_173:.*]] = llvm.mlir.constant(1 : i32) : i32 - // CHECK: %[[VAL_174:.*]] = llvm.urem %[[VAL_171]], %[[VAL_173]] : i32 - // CHECK: %[[VAL_175:.*]] = llvm.mlir.constant(1 : i32) : i32 - // CHECK: %[[VAL_176:.*]] = llvm.urem %[[VAL_168]], %[[VAL_175]] : i32 - // CHECK: %[[VAL_177:.*]] = llvm.mlir.constant(32 : i32) : i32 - // CHECK: %[[CTAOffset_M:.*]] = llvm.mul %[[VAL_174]], %[[VAL_177]] : i32 - // CHECK: %[[VAL_179:.*]] = llvm.mlir.constant(32 : i32) : i32 - // CHECK: %[[CTAOffset_N:.*]] = llvm.mul %[[VAL_176]], %[[VAL_179]] : i32 - // CHECK: %[[VAL_181:.*]] = llvm.add %[[multiDimBase_M]], %[[CTAOffset_M]] : i32 - // CHECK: %[[VAL_182:.*]] = llvm.add %[[multiDimBase_N]], %[[CTAOffset_N]] : i32 + // CHECK: %[[VAL_156:.*]] = llvm.and %[[LANE_ID]], %[[VAL_155]] : i32 + // CHECK: %[[VAL_157:.*]] = llvm.icmp "eq" %[[VAL_156]], %[[CST_0]] : i32 + // CHECK: %[[VAL_158:.*]] = llvm.mlir.constant(2 : i32) : i32 + // CHECK: %[[VAL_159:.*]] = llvm.select %[[VAL_157]], %[[CST_0]], %[[VAL_158]] : i1, i32 + // CHECK: %[[VAL_160:.*]] = llvm.xor %[[VAL_154]], %[[VAL_159]] : i32 + // CHECK: %[[VAL_161:.*]] = llvm.mlir.constant(4 : i32) : i32 + // CHECK: %[[VAL_162:.*]] = llvm.and %[[LANE_ID]], %[[VAL_161]] : i32 + // CHECK: %[[VAL_163:.*]] = llvm.icmp "eq" %[[VAL_162]], %[[CST_0]] : i32 + // CHECK: %[[VAL_164:.*]] = llvm.mlir.constant(4 : i32) : i32 + // CHECK: %[[VAL_165:.*]] = llvm.select %[[VAL_163]], %[[CST_0]], %[[VAL_164]] : i1, i32 + // CHECK: %[[VAL_182:.*]] = llvm.xor %[[VAL_160]], %[[VAL_165]] : i32 + // CHECK: %[[VAL_167:.*]] = llvm.mlir.constant(8 : i32) : i32 + // CHECK: %[[VAL_168:.*]] = llvm.and %[[LANE_ID]], %[[VAL_167]] : i32 + // CHECK: %[[VAL_169:.*]] = llvm.icmp "eq" %[[VAL_168]], %[[CST_0]] : i32 + // CHECK: %[[VAL_170:.*]] = llvm.mlir.constant(8 : i32) : i32 + // CHECK: %[[VAL_171:.*]] = llvm.select %[[VAL_169]], %[[CST_0]], %[[VAL_170]] : i1, i32 + // CHECK: %[[VAL_181:.*]] = llvm.xor %[[VAL_182]], %[[VAL_171]] : i32 // COM: There are total [4, 2] repetitions of tensor shape [32, 32] per warp. // COM: The repetitions are clustered as [2, 1] for A operand. The repetitions orders are [0, 0], [1, 0], [0, 1], [1, 1], [2, 0], [3, 0], [2, 1], [3, 1] // COM: Offsets of rep [0, 0]. 
// CHECK: %[[VAL_183:.*]] = llvm.mlir.constant(0 : i32) : i32 - // CHECK: %[[VAL_184:.*]] = llvm.add %[[VAL_181]], %[[VAL_183]] : i32 + // CHECK: %[[VAL_184:.*]] = llvm.xor %[[CST_0]], %[[VAL_183]] : i32 // CHECK: %[[VAL_185:.*]] = llvm.mlir.constant(0 : i32) : i32 - // CHECK: %[[VAL_186:.*]] = llvm.add %[[VAL_182]], %[[VAL_185]] : i32 + // CHECK: %[[VAL_186:.*]] = llvm.xor %[[VAL_181]], %[[VAL_185]] : i32 // CHECK: %[[VAL_187:.*]] = llvm.mlir.constant(1 : i32) : i32 - // CHECK: %[[VAL_188:.*]] = llvm.add %[[VAL_181]], %[[VAL_187]] : i32 + // CHECK: %[[VAL_188:.*]] = llvm.xor %[[CST_0]], %[[VAL_187]] : i32 // CHECK: %[[VAL_189:.*]] = llvm.mlir.constant(0 : i32) : i32 - // CHECK: %[[VAL_190:.*]] = llvm.add %[[VAL_182]], %[[VAL_189]] : i32 + // CHECK: %[[VAL_190:.*]] = llvm.xor %[[VAL_181]], %[[VAL_189]] : i32 // CHECK: %[[VAL_191:.*]] = llvm.mlir.constant(2 : i32) : i32 - // CHECK: %[[VAL_192:.*]] = llvm.add %[[VAL_181]], %[[VAL_191]] : i32 + // CHECK: %[[VAL_192:.*]] = llvm.xor %[[CST_0]], %[[VAL_191]] : i32 // CHECK: %[[VAL_193:.*]] = llvm.mlir.constant(0 : i32) : i32 - // CHECK: %[[VAL_194:.*]] = llvm.add %[[VAL_182]], %[[VAL_193]] : i32 + // CHECK: %[[VAL_194:.*]] = llvm.xor %[[VAL_181]], %[[VAL_193]] : i32 // CHECK: %[[VAL_195:.*]] = llvm.mlir.constant(3 : i32) : i32 - // CHECK: %[[VAL_196:.*]] = llvm.add %[[VAL_181]], %[[VAL_195]] : i32 + // CHECK: %[[VAL_196:.*]] = llvm.xor %[[CST_0]], %[[VAL_195]] : i32 // CHECK: %[[VAL_197:.*]] = llvm.mlir.constant(0 : i32) : i32 - // CHECK: %[[VAL_198:.*]] = llvm.add %[[VAL_182]], %[[VAL_197]] : i32 + // CHECK: %[[VAL_198:.*]] = llvm.xor %[[VAL_181]], %[[VAL_197]] : i32 // CHECK: %[[VAL_199:.*]] = llvm.mlir.constant(4 : i32) : i32 - // CHECK: %[[VAL_200:.*]] = llvm.add %[[VAL_181]], %[[VAL_199]] : i32 + // CHECK: %[[VAL_200:.*]] = llvm.xor %[[CST_0]], %[[VAL_199]] : i32 // CHECK: %[[VAL_201:.*]] = llvm.mlir.constant(0 : i32) : i32 - // CHECK: %[[VAL_202:.*]] = llvm.add %[[VAL_182]], %[[VAL_201]] : i32 + // CHECK: %[[VAL_202:.*]] = llvm.xor %[[VAL_181]], %[[VAL_201]] : i32 // CHECK: %[[VAL_203:.*]] = llvm.mlir.constant(5 : i32) : i32 - // CHECK: %[[VAL_204:.*]] = llvm.add %[[VAL_181]], %[[VAL_203]] : i32 + // CHECK: %[[VAL_204:.*]] = llvm.xor %[[CST_0]], %[[VAL_203]] : i32 // CHECK: %[[VAL_205:.*]] = llvm.mlir.constant(0 : i32) : i32 - // CHECK: %[[VAL_206:.*]] = llvm.add %[[VAL_182]], %[[VAL_205]] : i32 + // CHECK: %[[VAL_206:.*]] = llvm.xor %[[VAL_181]], %[[VAL_205]] : i32 // CHECK: %[[VAL_207:.*]] = llvm.mlir.constant(6 : i32) : i32 - // CHECK: %[[VAL_208:.*]] = llvm.add %[[VAL_181]], %[[VAL_207]] : i32 + // CHECK: %[[VAL_208:.*]] = llvm.xor %[[CST_0]], %[[VAL_207]] : i32 // CHECK: %[[VAL_209:.*]] = llvm.mlir.constant(0 : i32) : i32 - // CHECK: %[[VAL_210:.*]] = llvm.add %[[VAL_182]], %[[VAL_209]] : i32 + // CHECK: %[[VAL_210:.*]] = llvm.xor %[[VAL_181]], %[[VAL_209]] : i32 // CHECK: %[[VAL_211:.*]] = llvm.mlir.constant(7 : i32) : i32 - // CHECK: %[[VAL_212:.*]] = llvm.add %[[VAL_181]], %[[VAL_211]] : i32 + // CHECK: %[[VAL_212:.*]] = llvm.xor %[[CST_0]], %[[VAL_211]] : i32 // CHECK: %[[VAL_213:.*]] = llvm.mlir.constant(0 : i32) : i32 - // CHECK: %[[VAL_214:.*]] = llvm.add %[[VAL_182]], %[[VAL_213]] : i32 + // CHECK: %[[VAL_214:.*]] = llvm.xor %[[VAL_181]], %[[VAL_213]] : i32 // COM: Offsets of rep [1, 0]. 
// CHECK: %[[VAL_215:.*]] = llvm.mlir.constant(8 : i32) : i32 - // CHECK: %[[VAL_216:.*]] = llvm.add %[[VAL_181]], %[[VAL_215]] : i32 + // CHECK: %[[VAL_216:.*]] = llvm.xor %[[CST_0]], %[[VAL_215]] : i32 // CHECK: %[[VAL_217:.*]] = llvm.mlir.constant(0 : i32) : i32 - // CHECK: %[[VAL_218:.*]] = llvm.add %[[VAL_182]], %[[VAL_217]] : i32 + // CHECK: %[[VAL_218:.*]] = llvm.xor %[[VAL_181]], %[[VAL_217]] : i32 // CHECK: %[[VAL_219:.*]] = llvm.mlir.constant(9 : i32) : i32 - // CHECK: %[[VAL_220:.*]] = llvm.add %[[VAL_181]], %[[VAL_219]] : i32 + // CHECK: %[[VAL_220:.*]] = llvm.xor %[[CST_0]], %[[VAL_219]] : i32 // CHECK: %[[VAL_221:.*]] = llvm.mlir.constant(0 : i32) : i32 - // CHECK: %[[VAL_222:.*]] = llvm.add %[[VAL_182]], %[[VAL_221]] : i32 + // CHECK: %[[VAL_222:.*]] = llvm.xor %[[VAL_181]], %[[VAL_221]] : i32 // CHECK: %[[VAL_223:.*]] = llvm.mlir.constant(10 : i32) : i32 - // CHECK: %[[VAL_224:.*]] = llvm.add %[[VAL_181]], %[[VAL_223]] : i32 + // CHECK: %[[VAL_224:.*]] = llvm.xor %[[CST_0]], %[[VAL_223]] : i32 // CHECK: %[[VAL_225:.*]] = llvm.mlir.constant(0 : i32) : i32 - // CHECK: %[[VAL_226:.*]] = llvm.add %[[VAL_182]], %[[VAL_225]] : i32 + // CHECK: %[[VAL_226:.*]] = llvm.xor %[[VAL_181]], %[[VAL_225]] : i32 // CHECK: %[[VAL_227:.*]] = llvm.mlir.constant(11 : i32) : i32 - // CHECK: %[[VAL_228:.*]] = llvm.add %[[VAL_181]], %[[VAL_227]] : i32 + // CHECK: %[[VAL_228:.*]] = llvm.xor %[[CST_0]], %[[VAL_227]] : i32 // CHECK: %[[VAL_229:.*]] = llvm.mlir.constant(0 : i32) : i32 - // CHECK: %[[VAL_230:.*]] = llvm.add %[[VAL_182]], %[[VAL_229]] : i32 + // CHECK: %[[VAL_230:.*]] = llvm.xor %[[VAL_181]], %[[VAL_229]] : i32 // CHECK: %[[VAL_231:.*]] = llvm.mlir.constant(12 : i32) : i32 - // CHECK: %[[VAL_232:.*]] = llvm.add %[[VAL_181]], %[[VAL_231]] : i32 + // CHECK: %[[VAL_232:.*]] = llvm.xor %[[CST_0]], %[[VAL_231]] : i32 // CHECK: %[[VAL_233:.*]] = llvm.mlir.constant(0 : i32) : i32 - // CHECK: %[[VAL_234:.*]] = llvm.add %[[VAL_182]], %[[VAL_233]] : i32 + // CHECK: %[[VAL_234:.*]] = llvm.xor %[[VAL_181]], %[[VAL_233]] : i32 // CHECK: %[[VAL_235:.*]] = llvm.mlir.constant(13 : i32) : i32 - // CHECK: %[[VAL_236:.*]] = llvm.add %[[VAL_181]], %[[VAL_235]] : i32 + // CHECK: %[[VAL_236:.*]] = llvm.xor %[[CST_0]], %[[VAL_235]] : i32 // CHECK: %[[VAL_237:.*]] = llvm.mlir.constant(0 : i32) : i32 - // CHECK: %[[VAL_238:.*]] = llvm.add %[[VAL_182]], %[[VAL_237]] : i32 + // CHECK: %[[VAL_238:.*]] = llvm.xor %[[VAL_181]], %[[VAL_237]] : i32 // CHECK: %[[VAL_239:.*]] = llvm.mlir.constant(14 : i32) : i32 - // CHECK: %[[VAL_240:.*]] = llvm.add %[[VAL_181]], %[[VAL_239]] : i32 + // CHECK: %[[VAL_240:.*]] = llvm.xor %[[CST_0]], %[[VAL_239]] : i32 // CHECK: %[[VAL_241:.*]] = llvm.mlir.constant(0 : i32) : i32 - // CHECK: %[[VAL_242:.*]] = llvm.add %[[VAL_182]], %[[VAL_241]] : i32 + // CHECK: %[[VAL_242:.*]] = llvm.xor %[[VAL_181]], %[[VAL_241]] : i32 // CHECK: %[[VAL_243:.*]] = llvm.mlir.constant(15 : i32) : i32 - // CHECK: %[[VAL_244:.*]] = llvm.add %[[VAL_181]], %[[VAL_243]] : i32 + // CHECK: %[[VAL_244:.*]] = llvm.xor %[[CST_0]], %[[VAL_243]] : i32 // CHECK: %[[VAL_245:.*]] = llvm.mlir.constant(0 : i32) : i32 - // CHECK: %[[VAL_246:.*]] = llvm.add %[[VAL_182]], %[[VAL_245]] : i32 + // CHECK: %[[VAL_246:.*]] = llvm.xor %[[VAL_181]], %[[VAL_245]] : i32 // COM: Offsets of rep [0, 1]. 
// CHECK: %[[VAL_247:.*]] = llvm.mlir.constant(0 : i32) : i32 - // CHECK: %[[VAL_248:.*]] = llvm.add %[[VAL_181]], %[[VAL_247]] : i32 + // CHECK: %[[VAL_248:.*]] = llvm.xor %[[CST_0]], %[[VAL_247]] : i32 // CHECK: %[[VAL_249:.*]] = llvm.mlir.constant(16 : i32) : i32 - // CHECK: %[[VAL_250:.*]] = llvm.add %[[VAL_182]], %[[VAL_249]] : i32 + // CHECK: %[[VAL_250:.*]] = llvm.xor %[[VAL_181]], %[[VAL_249]] : i32 // CHECK: %[[VAL_251:.*]] = llvm.mlir.constant(1 : i32) : i32 - // CHECK: %[[VAL_252:.*]] = llvm.add %[[VAL_181]], %[[VAL_251]] : i32 + // CHECK: %[[VAL_252:.*]] = llvm.xor %[[CST_0]], %[[VAL_251]] : i32 // CHECK: %[[VAL_253:.*]] = llvm.mlir.constant(16 : i32) : i32 - // CHECK: %[[VAL_254:.*]] = llvm.add %[[VAL_182]], %[[VAL_253]] : i32 + // CHECK: %[[VAL_254:.*]] = llvm.xor %[[VAL_181]], %[[VAL_253]] : i32 // CHECK: %[[VAL_255:.*]] = llvm.mlir.constant(2 : i32) : i32 - // CHECK: %[[VAL_256:.*]] = llvm.add %[[VAL_181]], %[[VAL_255]] : i32 + // CHECK: %[[VAL_256:.*]] = llvm.xor %[[CST_0]], %[[VAL_255]] : i32 // CHECK: %[[VAL_257:.*]] = llvm.mlir.constant(16 : i32) : i32 - // CHECK: %[[VAL_258:.*]] = llvm.add %[[VAL_182]], %[[VAL_257]] : i32 + // CHECK: %[[VAL_258:.*]] = llvm.xor %[[VAL_181]], %[[VAL_257]] : i32 // CHECK: %[[VAL_259:.*]] = llvm.mlir.constant(3 : i32) : i32 - // CHECK: %[[VAL_260:.*]] = llvm.add %[[VAL_181]], %[[VAL_259]] : i32 + // CHECK: %[[VAL_260:.*]] = llvm.xor %[[CST_0]], %[[VAL_259]] : i32 // CHECK: %[[VAL_261:.*]] = llvm.mlir.constant(16 : i32) : i32 - // CHECK: %[[VAL_262:.*]] = llvm.add %[[VAL_182]], %[[VAL_261]] : i32 + // CHECK: %[[VAL_262:.*]] = llvm.xor %[[VAL_181]], %[[VAL_261]] : i32 // CHECK: %[[VAL_263:.*]] = llvm.mlir.constant(4 : i32) : i32 - // CHECK: %[[VAL_264:.*]] = llvm.add %[[VAL_181]], %[[VAL_263]] : i32 + // CHECK: %[[VAL_264:.*]] = llvm.xor %[[CST_0]], %[[VAL_263]] : i32 // CHECK: %[[VAL_265:.*]] = llvm.mlir.constant(16 : i32) : i32 - // CHECK: %[[VAL_266:.*]] = llvm.add %[[VAL_182]], %[[VAL_265]] : i32 + // CHECK: %[[VAL_266:.*]] = llvm.xor %[[VAL_181]], %[[VAL_265]] : i32 // CHECK: %[[VAL_267:.*]] = llvm.mlir.constant(5 : i32) : i32 - // CHECK: %[[VAL_268:.*]] = llvm.add %[[VAL_181]], %[[VAL_267]] : i32 + // CHECK: %[[VAL_268:.*]] = llvm.xor %[[CST_0]], %[[VAL_267]] : i32 // CHECK: %[[VAL_269:.*]] = llvm.mlir.constant(16 : i32) : i32 - // CHECK: %[[VAL_270:.*]] = llvm.add %[[VAL_182]], %[[VAL_269]] : i32 + // CHECK: %[[VAL_270:.*]] = llvm.xor %[[VAL_181]], %[[VAL_269]] : i32 // CHECK: %[[VAL_271:.*]] = llvm.mlir.constant(6 : i32) : i32 - // CHECK: %[[VAL_272:.*]] = llvm.add %[[VAL_181]], %[[VAL_271]] : i32 + // CHECK: %[[VAL_272:.*]] = llvm.xor %[[CST_0]], %[[VAL_271]] : i32 // CHECK: %[[VAL_273:.*]] = llvm.mlir.constant(16 : i32) : i32 - // CHECK: %[[VAL_274:.*]] = llvm.add %[[VAL_182]], %[[VAL_273]] : i32 + // CHECK: %[[VAL_274:.*]] = llvm.xor %[[VAL_181]], %[[VAL_273]] : i32 // CHECK: %[[VAL_275:.*]] = llvm.mlir.constant(7 : i32) : i32 - // CHECK: %[[VAL_276:.*]] = llvm.add %[[VAL_181]], %[[VAL_275]] : i32 + // CHECK: %[[VAL_276:.*]] = llvm.xor %[[CST_0]], %[[VAL_275]] : i32 // CHECK: %[[VAL_277:.*]] = llvm.mlir.constant(16 : i32) : i32 - // CHECK: %[[VAL_278:.*]] = llvm.add %[[VAL_182]], %[[VAL_277]] : i32 + // CHECK: %[[VAL_278:.*]] = llvm.xor %[[VAL_181]], %[[VAL_277]] : i32 // COM: Offsets of rep [1, 1]. 
// CHECK: %[[VAL_279:.*]] = llvm.mlir.constant(8 : i32) : i32 - // CHECK: %[[VAL_280:.*]] = llvm.add %[[VAL_181]], %[[VAL_279]] : i32 + // CHECK: %[[VAL_280:.*]] = llvm.xor %[[CST_0]], %[[VAL_279]] : i32 // CHECK: %[[VAL_281:.*]] = llvm.mlir.constant(16 : i32) : i32 - // CHECK: %[[VAL_282:.*]] = llvm.add %[[VAL_182]], %[[VAL_281]] : i32 + // CHECK: %[[VAL_282:.*]] = llvm.xor %[[VAL_181]], %[[VAL_281]] : i32 // CHECK: %[[VAL_283:.*]] = llvm.mlir.constant(9 : i32) : i32 - // CHECK: %[[VAL_284:.*]] = llvm.add %[[VAL_181]], %[[VAL_283]] : i32 + // CHECK: %[[VAL_284:.*]] = llvm.xor %[[CST_0]], %[[VAL_283]] : i32 // CHECK: %[[VAL_285:.*]] = llvm.mlir.constant(16 : i32) : i32 - // CHECK: %[[VAL_286:.*]] = llvm.add %[[VAL_182]], %[[VAL_285]] : i32 + // CHECK: %[[VAL_286:.*]] = llvm.xor %[[VAL_181]], %[[VAL_285]] : i32 // CHECK: %[[VAL_287:.*]] = llvm.mlir.constant(10 : i32) : i32 - // CHECK: %[[VAL_288:.*]] = llvm.add %[[VAL_181]], %[[VAL_287]] : i32 + // CHECK: %[[VAL_288:.*]] = llvm.xor %[[CST_0]], %[[VAL_287]] : i32 // CHECK: %[[VAL_289:.*]] = llvm.mlir.constant(16 : i32) : i32 - // CHECK: %[[VAL_290:.*]] = llvm.add %[[VAL_182]], %[[VAL_289]] : i32 + // CHECK: %[[VAL_290:.*]] = llvm.xor %[[VAL_181]], %[[VAL_289]] : i32 // CHECK: %[[VAL_291:.*]] = llvm.mlir.constant(11 : i32) : i32 - // CHECK: %[[VAL_292:.*]] = llvm.add %[[VAL_181]], %[[VAL_291]] : i32 + // CHECK: %[[VAL_292:.*]] = llvm.xor %[[CST_0]], %[[VAL_291]] : i32 // CHECK: %[[VAL_293:.*]] = llvm.mlir.constant(16 : i32) : i32 - // CHECK: %[[VAL_294:.*]] = llvm.add %[[VAL_182]], %[[VAL_293]] : i32 + // CHECK: %[[VAL_294:.*]] = llvm.xor %[[VAL_181]], %[[VAL_293]] : i32 // CHECK: %[[VAL_295:.*]] = llvm.mlir.constant(12 : i32) : i32 - // CHECK: %[[VAL_296:.*]] = llvm.add %[[VAL_181]], %[[VAL_295]] : i32 + // CHECK: %[[VAL_296:.*]] = llvm.xor %[[CST_0]], %[[VAL_295]] : i32 // CHECK: %[[VAL_297:.*]] = llvm.mlir.constant(16 : i32) : i32 - // CHECK: %[[VAL_298:.*]] = llvm.add %[[VAL_182]], %[[VAL_297]] : i32 + // CHECK: %[[VAL_298:.*]] = llvm.xor %[[VAL_181]], %[[VAL_297]] : i32 // CHECK: %[[VAL_299:.*]] = llvm.mlir.constant(13 : i32) : i32 - // CHECK: %[[VAL_300:.*]] = llvm.add %[[VAL_181]], %[[VAL_299]] : i32 + // CHECK: %[[VAL_300:.*]] = llvm.xor %[[CST_0]], %[[VAL_299]] : i32 // CHECK: %[[VAL_301:.*]] = llvm.mlir.constant(16 : i32) : i32 - // CHECK: %[[VAL_302:.*]] = llvm.add %[[VAL_182]], %[[VAL_301]] : i32 + // CHECK: %[[VAL_302:.*]] = llvm.xor %[[VAL_181]], %[[VAL_301]] : i32 // CHECK: %[[VAL_303:.*]] = llvm.mlir.constant(14 : i32) : i32 - // CHECK: %[[VAL_304:.*]] = llvm.add %[[VAL_181]], %[[VAL_303]] : i32 + // CHECK: %[[VAL_304:.*]] = llvm.xor %[[CST_0]], %[[VAL_303]] : i32 // CHECK: %[[VAL_305:.*]] = llvm.mlir.constant(16 : i32) : i32 - // CHECK: %[[VAL_306:.*]] = llvm.add %[[VAL_182]], %[[VAL_305]] : i32 + // CHECK: %[[VAL_306:.*]] = llvm.xor %[[VAL_181]], %[[VAL_305]] : i32 // CHECK: %[[VAL_307:.*]] = llvm.mlir.constant(15 : i32) : i32 - // CHECK: %[[VAL_308:.*]] = llvm.add %[[VAL_181]], %[[VAL_307]] : i32 + // CHECK: %[[VAL_308:.*]] = llvm.xor %[[CST_0]], %[[VAL_307]] : i32 // CHECK: %[[VAL_309:.*]] = llvm.mlir.constant(16 : i32) : i32 - // CHECK: %[[VAL_310:.*]] = llvm.add %[[VAL_182]], %[[VAL_309]] : i32 + // CHECK: %[[VAL_310:.*]] = llvm.xor %[[VAL_181]], %[[VAL_309]] : i32 // COM: Offsets of rep [2, 0]. 
// CHECK: %[[VAL_311:.*]] = llvm.mlir.constant(16 : i32) : i32 - // CHECK: %[[VAL_312:.*]] = llvm.add %[[VAL_181]], %[[VAL_311]] : i32 + // CHECK: %[[VAL_312:.*]] = llvm.xor %[[CST_0]], %[[VAL_311]] : i32 // CHECK: %[[VAL_313:.*]] = llvm.mlir.constant(0 : i32) : i32 - // CHECK: %[[VAL_314:.*]] = llvm.add %[[VAL_182]], %[[VAL_313]] : i32 + // CHECK: %[[VAL_314:.*]] = llvm.xor %[[VAL_181]], %[[VAL_313]] : i32 // CHECK: %[[VAL_315:.*]] = llvm.mlir.constant(17 : i32) : i32 - // CHECK: %[[VAL_316:.*]] = llvm.add %[[VAL_181]], %[[VAL_315]] : i32 + // CHECK: %[[VAL_316:.*]] = llvm.xor %[[CST_0]], %[[VAL_315]] : i32 // CHECK: %[[VAL_317:.*]] = llvm.mlir.constant(0 : i32) : i32 - // CHECK: %[[VAL_318:.*]] = llvm.add %[[VAL_182]], %[[VAL_317]] : i32 + // CHECK: %[[VAL_318:.*]] = llvm.xor %[[VAL_181]], %[[VAL_317]] : i32 // CHECK: %[[VAL_319:.*]] = llvm.mlir.constant(18 : i32) : i32 - // CHECK: %[[VAL_320:.*]] = llvm.add %[[VAL_181]], %[[VAL_319]] : i32 + // CHECK: %[[VAL_320:.*]] = llvm.xor %[[CST_0]], %[[VAL_319]] : i32 // CHECK: %[[VAL_321:.*]] = llvm.mlir.constant(0 : i32) : i32 - // CHECK: %[[VAL_322:.*]] = llvm.add %[[VAL_182]], %[[VAL_321]] : i32 + // CHECK: %[[VAL_322:.*]] = llvm.xor %[[VAL_181]], %[[VAL_321]] : i32 // CHECK: %[[VAL_323:.*]] = llvm.mlir.constant(19 : i32) : i32 - // CHECK: %[[VAL_324:.*]] = llvm.add %[[VAL_181]], %[[VAL_323]] : i32 + // CHECK: %[[VAL_324:.*]] = llvm.xor %[[CST_0]], %[[VAL_323]] : i32 // CHECK: %[[VAL_325:.*]] = llvm.mlir.constant(0 : i32) : i32 - // CHECK: %[[VAL_326:.*]] = llvm.add %[[VAL_182]], %[[VAL_325]] : i32 + // CHECK: %[[VAL_326:.*]] = llvm.xor %[[VAL_181]], %[[VAL_325]] : i32 // CHECK: %[[VAL_327:.*]] = llvm.mlir.constant(20 : i32) : i32 - // CHECK: %[[VAL_328:.*]] = llvm.add %[[VAL_181]], %[[VAL_327]] : i32 + // CHECK: %[[VAL_328:.*]] = llvm.xor %[[CST_0]], %[[VAL_327]] : i32 // CHECK: %[[VAL_329:.*]] = llvm.mlir.constant(0 : i32) : i32 - // CHECK: %[[VAL_330:.*]] = llvm.add %[[VAL_182]], %[[VAL_329]] : i32 + // CHECK: %[[VAL_330:.*]] = llvm.xor %[[VAL_181]], %[[VAL_329]] : i32 // CHECK: %[[VAL_331:.*]] = llvm.mlir.constant(21 : i32) : i32 - // CHECK: %[[VAL_332:.*]] = llvm.add %[[VAL_181]], %[[VAL_331]] : i32 + // CHECK: %[[VAL_332:.*]] = llvm.xor %[[CST_0]], %[[VAL_331]] : i32 // CHECK: %[[VAL_333:.*]] = llvm.mlir.constant(0 : i32) : i32 - // CHECK: %[[VAL_334:.*]] = llvm.add %[[VAL_182]], %[[VAL_333]] : i32 + // CHECK: %[[VAL_334:.*]] = llvm.xor %[[VAL_181]], %[[VAL_333]] : i32 // CHECK: %[[VAL_335:.*]] = llvm.mlir.constant(22 : i32) : i32 - // CHECK: %[[VAL_336:.*]] = llvm.add %[[VAL_181]], %[[VAL_335]] : i32 + // CHECK: %[[VAL_336:.*]] = llvm.xor %[[CST_0]], %[[VAL_335]] : i32 // CHECK: %[[VAL_337:.*]] = llvm.mlir.constant(0 : i32) : i32 - // CHECK: %[[VAL_338:.*]] = llvm.add %[[VAL_182]], %[[VAL_337]] : i32 + // CHECK: %[[VAL_338:.*]] = llvm.xor %[[VAL_181]], %[[VAL_337]] : i32 // CHECK: %[[VAL_339:.*]] = llvm.mlir.constant(23 : i32) : i32 - // CHECK: %[[VAL_340:.*]] = llvm.add %[[VAL_181]], %[[VAL_339]] : i32 + // CHECK: %[[VAL_340:.*]] = llvm.xor %[[CST_0]], %[[VAL_339]] : i32 // CHECK: %[[VAL_341:.*]] = llvm.mlir.constant(0 : i32) : i32 - // CHECK: %[[VAL_342:.*]] = llvm.add %[[VAL_182]], %[[VAL_341]] : i32 + // CHECK: %[[VAL_342:.*]] = llvm.xor %[[VAL_181]], %[[VAL_341]] : i32 // COM: Offsets of rep [3, 0]. 
// CHECK: %[[VAL_343:.*]] = llvm.mlir.constant(24 : i32) : i32 - // CHECK: %[[VAL_344:.*]] = llvm.add %[[VAL_181]], %[[VAL_343]] : i32 + // CHECK: %[[VAL_344:.*]] = llvm.xor %[[CST_0]], %[[VAL_343]] : i32 // CHECK: %[[VAL_345:.*]] = llvm.mlir.constant(0 : i32) : i32 - // CHECK: %[[VAL_346:.*]] = llvm.add %[[VAL_182]], %[[VAL_345]] : i32 + // CHECK: %[[VAL_346:.*]] = llvm.xor %[[VAL_181]], %[[VAL_345]] : i32 // CHECK: %[[VAL_347:.*]] = llvm.mlir.constant(25 : i32) : i32 - // CHECK: %[[VAL_348:.*]] = llvm.add %[[VAL_181]], %[[VAL_347]] : i32 + // CHECK: %[[VAL_348:.*]] = llvm.xor %[[CST_0]], %[[VAL_347]] : i32 // CHECK: %[[VAL_349:.*]] = llvm.mlir.constant(0 : i32) : i32 - // CHECK: %[[VAL_350:.*]] = llvm.add %[[VAL_182]], %[[VAL_349]] : i32 + // CHECK: %[[VAL_350:.*]] = llvm.xor %[[VAL_181]], %[[VAL_349]] : i32 // CHECK: %[[VAL_351:.*]] = llvm.mlir.constant(26 : i32) : i32 - // CHECK: %[[VAL_352:.*]] = llvm.add %[[VAL_181]], %[[VAL_351]] : i32 + // CHECK: %[[VAL_352:.*]] = llvm.xor %[[CST_0]], %[[VAL_351]] : i32 // CHECK: %[[VAL_353:.*]] = llvm.mlir.constant(0 : i32) : i32 - // CHECK: %[[VAL_354:.*]] = llvm.add %[[VAL_182]], %[[VAL_353]] : i32 + // CHECK: %[[VAL_354:.*]] = llvm.xor %[[VAL_181]], %[[VAL_353]] : i32 // CHECK: %[[VAL_355:.*]] = llvm.mlir.constant(27 : i32) : i32 - // CHECK: %[[VAL_356:.*]] = llvm.add %[[VAL_181]], %[[VAL_355]] : i32 + // CHECK: %[[VAL_356:.*]] = llvm.xor %[[CST_0]], %[[VAL_355]] : i32 // CHECK: %[[VAL_357:.*]] = llvm.mlir.constant(0 : i32) : i32 - // CHECK: %[[VAL_358:.*]] = llvm.add %[[VAL_182]], %[[VAL_357]] : i32 + // CHECK: %[[VAL_358:.*]] = llvm.xor %[[VAL_181]], %[[VAL_357]] : i32 // CHECK: %[[VAL_359:.*]] = llvm.mlir.constant(28 : i32) : i32 - // CHECK: %[[VAL_360:.*]] = llvm.add %[[VAL_181]], %[[VAL_359]] : i32 + // CHECK: %[[VAL_360:.*]] = llvm.xor %[[CST_0]], %[[VAL_359]] : i32 // CHECK: %[[VAL_361:.*]] = llvm.mlir.constant(0 : i32) : i32 - // CHECK: %[[VAL_362:.*]] = llvm.add %[[VAL_182]], %[[VAL_361]] : i32 + // CHECK: %[[VAL_362:.*]] = llvm.xor %[[VAL_181]], %[[VAL_361]] : i32 // CHECK: %[[VAL_363:.*]] = llvm.mlir.constant(29 : i32) : i32 - // CHECK: %[[VAL_364:.*]] = llvm.add %[[VAL_181]], %[[VAL_363]] : i32 + // CHECK: %[[VAL_364:.*]] = llvm.xor %[[CST_0]], %[[VAL_363]] : i32 // CHECK: %[[VAL_365:.*]] = llvm.mlir.constant(0 : i32) : i32 - // CHECK: %[[VAL_366:.*]] = llvm.add %[[VAL_182]], %[[VAL_365]] : i32 + // CHECK: %[[VAL_366:.*]] = llvm.xor %[[VAL_181]], %[[VAL_365]] : i32 // CHECK: %[[VAL_367:.*]] = llvm.mlir.constant(30 : i32) : i32 - // CHECK: %[[VAL_368:.*]] = llvm.add %[[VAL_181]], %[[VAL_367]] : i32 + // CHECK: %[[VAL_368:.*]] = llvm.xor %[[CST_0]], %[[VAL_367]] : i32 // CHECK: %[[VAL_369:.*]] = llvm.mlir.constant(0 : i32) : i32 - // CHECK: %[[VAL_370:.*]] = llvm.add %[[VAL_182]], %[[VAL_369]] : i32 + // CHECK: %[[VAL_370:.*]] = llvm.xor %[[VAL_181]], %[[VAL_369]] : i32 // CHECK: %[[VAL_371:.*]] = llvm.mlir.constant(31 : i32) : i32 - // CHECK: %[[VAL_372:.*]] = llvm.add %[[VAL_181]], %[[VAL_371]] : i32 + // CHECK: %[[VAL_372:.*]] = llvm.xor %[[CST_0]], %[[VAL_371]] : i32 // CHECK: %[[VAL_373:.*]] = llvm.mlir.constant(0 : i32) : i32 - // CHECK: %[[VAL_374:.*]] = llvm.add %[[VAL_182]], %[[VAL_373]] : i32 + // CHECK: %[[VAL_374:.*]] = llvm.xor %[[VAL_181]], %[[VAL_373]] : i32 // COM: Offsets of rep [2, 1]. 
// CHECK: %[[VAL_375:.*]] = llvm.mlir.constant(16 : i32) : i32 - // CHECK: %[[VAL_376:.*]] = llvm.add %[[VAL_181]], %[[VAL_375]] : i32 + // CHECK: %[[VAL_376:.*]] = llvm.xor %[[CST_0]], %[[VAL_375]] : i32 // CHECK: %[[VAL_377:.*]] = llvm.mlir.constant(16 : i32) : i32 - // CHECK: %[[VAL_378:.*]] = llvm.add %[[VAL_182]], %[[VAL_377]] : i32 + // CHECK: %[[VAL_378:.*]] = llvm.xor %[[VAL_181]], %[[VAL_377]] : i32 // CHECK: %[[VAL_379:.*]] = llvm.mlir.constant(17 : i32) : i32 - // CHECK: %[[VAL_380:.*]] = llvm.add %[[VAL_181]], %[[VAL_379]] : i32 + // CHECK: %[[VAL_380:.*]] = llvm.xor %[[CST_0]], %[[VAL_379]] : i32 // CHECK: %[[VAL_381:.*]] = llvm.mlir.constant(16 : i32) : i32 - // CHECK: %[[VAL_382:.*]] = llvm.add %[[VAL_182]], %[[VAL_381]] : i32 + // CHECK: %[[VAL_382:.*]] = llvm.xor %[[VAL_181]], %[[VAL_381]] : i32 // CHECK: %[[VAL_383:.*]] = llvm.mlir.constant(18 : i32) : i32 - // CHECK: %[[VAL_384:.*]] = llvm.add %[[VAL_181]], %[[VAL_383]] : i32 + // CHECK: %[[VAL_384:.*]] = llvm.xor %[[CST_0]], %[[VAL_383]] : i32 // CHECK: %[[VAL_385:.*]] = llvm.mlir.constant(16 : i32) : i32 - // CHECK: %[[VAL_386:.*]] = llvm.add %[[VAL_182]], %[[VAL_385]] : i32 + // CHECK: %[[VAL_386:.*]] = llvm.xor %[[VAL_181]], %[[VAL_385]] : i32 // CHECK: %[[VAL_387:.*]] = llvm.mlir.constant(19 : i32) : i32 - // CHECK: %[[VAL_388:.*]] = llvm.add %[[VAL_181]], %[[VAL_387]] : i32 + // CHECK: %[[VAL_388:.*]] = llvm.xor %[[CST_0]], %[[VAL_387]] : i32 // CHECK: %[[VAL_389:.*]] = llvm.mlir.constant(16 : i32) : i32 - // CHECK: %[[VAL_390:.*]] = llvm.add %[[VAL_182]], %[[VAL_389]] : i32 + // CHECK: %[[VAL_390:.*]] = llvm.xor %[[VAL_181]], %[[VAL_389]] : i32 // CHECK: %[[VAL_391:.*]] = llvm.mlir.constant(20 : i32) : i32 - // CHECK: %[[VAL_392:.*]] = llvm.add %[[VAL_181]], %[[VAL_391]] : i32 + // CHECK: %[[VAL_392:.*]] = llvm.xor %[[CST_0]], %[[VAL_391]] : i32 // CHECK: %[[VAL_393:.*]] = llvm.mlir.constant(16 : i32) : i32 - // CHECK: %[[VAL_394:.*]] = llvm.add %[[VAL_182]], %[[VAL_393]] : i32 + // CHECK: %[[VAL_394:.*]] = llvm.xor %[[VAL_181]], %[[VAL_393]] : i32 // CHECK: %[[VAL_395:.*]] = llvm.mlir.constant(21 : i32) : i32 - // CHECK: %[[VAL_396:.*]] = llvm.add %[[VAL_181]], %[[VAL_395]] : i32 + // CHECK: %[[VAL_396:.*]] = llvm.xor %[[CST_0]], %[[VAL_395]] : i32 // CHECK: %[[VAL_397:.*]] = llvm.mlir.constant(16 : i32) : i32 - // CHECK: %[[VAL_398:.*]] = llvm.add %[[VAL_182]], %[[VAL_397]] : i32 + // CHECK: %[[VAL_398:.*]] = llvm.xor %[[VAL_181]], %[[VAL_397]] : i32 // CHECK: %[[VAL_399:.*]] = llvm.mlir.constant(22 : i32) : i32 - // CHECK: %[[VAL_400:.*]] = llvm.add %[[VAL_181]], %[[VAL_399]] : i32 + // CHECK: %[[VAL_400:.*]] = llvm.xor %[[CST_0]], %[[VAL_399]] : i32 // CHECK: %[[VAL_401:.*]] = llvm.mlir.constant(16 : i32) : i32 - // CHECK: %[[VAL_402:.*]] = llvm.add %[[VAL_182]], %[[VAL_401]] : i32 + // CHECK: %[[VAL_402:.*]] = llvm.xor %[[VAL_181]], %[[VAL_401]] : i32 // CHECK: %[[VAL_403:.*]] = llvm.mlir.constant(23 : i32) : i32 - // CHECK: %[[VAL_404:.*]] = llvm.add %[[VAL_181]], %[[VAL_403]] : i32 + // CHECK: %[[VAL_404:.*]] = llvm.xor %[[CST_0]], %[[VAL_403]] : i32 // CHECK: %[[VAL_405:.*]] = llvm.mlir.constant(16 : i32) : i32 - // CHECK: %[[VAL_406:.*]] = llvm.add %[[VAL_182]], %[[VAL_405]] : i32 + // CHECK: %[[VAL_406:.*]] = llvm.xor %[[VAL_181]], %[[VAL_405]] : i32 // COM: Offsets of rep [2, 2]. 
// CHECK: %[[VAL_407:.*]] = llvm.mlir.constant(24 : i32) : i32 - // CHECK: %[[VAL_408:.*]] = llvm.add %[[VAL_181]], %[[VAL_407]] : i32 + // CHECK: %[[VAL_408:.*]] = llvm.xor %[[CST_0]], %[[VAL_407]] : i32 // CHECK: %[[VAL_409:.*]] = llvm.mlir.constant(16 : i32) : i32 - // CHECK: %[[VAL_410:.*]] = llvm.add %[[VAL_182]], %[[VAL_409]] : i32 + // CHECK: %[[VAL_410:.*]] = llvm.xor %[[VAL_181]], %[[VAL_409]] : i32 // CHECK: %[[VAL_411:.*]] = llvm.mlir.constant(25 : i32) : i32 - // CHECK: %[[VAL_412:.*]] = llvm.add %[[VAL_181]], %[[VAL_411]] : i32 + // CHECK: %[[VAL_412:.*]] = llvm.xor %[[CST_0]], %[[VAL_411]] : i32 // CHECK: %[[VAL_413:.*]] = llvm.mlir.constant(16 : i32) : i32 - // CHECK: %[[VAL_414:.*]] = llvm.add %[[VAL_182]], %[[VAL_413]] : i32 + // CHECK: %[[VAL_414:.*]] = llvm.xor %[[VAL_181]], %[[VAL_413]] : i32 // CHECK: %[[VAL_415:.*]] = llvm.mlir.constant(26 : i32) : i32 - // CHECK: %[[VAL_416:.*]] = llvm.add %[[VAL_181]], %[[VAL_415]] : i32 + // CHECK: %[[VAL_416:.*]] = llvm.xor %[[CST_0]], %[[VAL_415]] : i32 // CHECK: %[[VAL_417:.*]] = llvm.mlir.constant(16 : i32) : i32 - // CHECK: %[[VAL_418:.*]] = llvm.add %[[VAL_182]], %[[VAL_417]] : i32 + // CHECK: %[[VAL_418:.*]] = llvm.xor %[[VAL_181]], %[[VAL_417]] : i32 // CHECK: %[[VAL_419:.*]] = llvm.mlir.constant(27 : i32) : i32 - // CHECK: %[[VAL_420:.*]] = llvm.add %[[VAL_181]], %[[VAL_419]] : i32 + // CHECK: %[[VAL_420:.*]] = llvm.xor %[[CST_0]], %[[VAL_419]] : i32 // CHECK: %[[VAL_421:.*]] = llvm.mlir.constant(16 : i32) : i32 - // CHECK: %[[VAL_422:.*]] = llvm.add %[[VAL_182]], %[[VAL_421]] : i32 + // CHECK: %[[VAL_422:.*]] = llvm.xor %[[VAL_181]], %[[VAL_421]] : i32 // CHECK: %[[VAL_423:.*]] = llvm.mlir.constant(28 : i32) : i32 - // CHECK: %[[VAL_424:.*]] = llvm.add %[[VAL_181]], %[[VAL_423]] : i32 + // CHECK: %[[VAL_424:.*]] = llvm.xor %[[CST_0]], %[[VAL_423]] : i32 // CHECK: %[[VAL_425:.*]] = llvm.mlir.constant(16 : i32) : i32 - // CHECK: %[[VAL_426:.*]] = llvm.add %[[VAL_182]], %[[VAL_425]] : i32 + // CHECK: %[[VAL_426:.*]] = llvm.xor %[[VAL_181]], %[[VAL_425]] : i32 // CHECK: %[[VAL_427:.*]] = llvm.mlir.constant(29 : i32) : i32 - // CHECK: %[[VAL_428:.*]] = llvm.add %[[VAL_181]], %[[VAL_427]] : i32 + // CHECK: %[[VAL_428:.*]] = llvm.xor %[[CST_0]], %[[VAL_427]] : i32 // CHECK: %[[VAL_429:.*]] = llvm.mlir.constant(16 : i32) : i32 - // CHECK: %[[VAL_430:.*]] = llvm.add %[[VAL_182]], %[[VAL_429]] : i32 + // CHECK: %[[VAL_430:.*]] = llvm.xor %[[VAL_181]], %[[VAL_429]] : i32 // CHECK: %[[VAL_431:.*]] = llvm.mlir.constant(30 : i32) : i32 - // CHECK: %[[VAL_432:.*]] = llvm.add %[[VAL_181]], %[[VAL_431]] : i32 + // CHECK: %[[VAL_432:.*]] = llvm.xor %[[CST_0]], %[[VAL_431]] : i32 // CHECK: %[[VAL_433:.*]] = llvm.mlir.constant(16 : i32) : i32 - // CHECK: %[[VAL_434:.*]] = llvm.add %[[VAL_182]], %[[VAL_433]] : i32 + // CHECK: %[[VAL_434:.*]] = llvm.xor %[[VAL_181]], %[[VAL_433]] : i32 // CHECK: %[[VAL_435:.*]] = llvm.mlir.constant(31 : i32) : i32 - // CHECK: %[[VAL_436:.*]] = llvm.add %[[VAL_181]], %[[VAL_435]] : i32 + // CHECK: %[[VAL_436:.*]] = llvm.xor %[[CST_0]], %[[VAL_435]] : i32 // CHECK: %[[VAL_437:.*]] = llvm.mlir.constant(16 : i32) : i32 - // CHECK: %[[VAL_438:.*]] = llvm.add %[[VAL_182]], %[[VAL_437]] : i32 + // CHECK: %[[VAL_438:.*]] = llvm.xor %[[VAL_181]], %[[VAL_437]] : i32 tt.print " x: " {hex = false, isSigned = array} : %cst : tensor<32x32xf16, #dot_operand_a> tt.return } diff --git a/test/TritonIntelGPU/tritonintelgpu-convert-layout-shortcut.mlir b/test/TritonIntelGPU/tritonintelgpu-convert-layout-shortcut.mlir index 
7bfff4fc36..48c9850418 100644 --- a/test/TritonIntelGPU/tritonintelgpu-convert-layout-shortcut.mlir +++ b/test/TritonIntelGPU/tritonintelgpu-convert-layout-shortcut.mlir @@ -6,46 +6,8 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 32 // CHECK-SAME: %[[VAL_0:.*]]: !llvm.struct<({{.*}})>) attributes {triton_gen.intel_reqd_sub_group_size = [16 : i32], triton_gen.max_work_group_size = [512 : i32, 1 : i32, 1 : i32]} { tt.func public @convert_dpas_to_dot_rep_cluster_1_2(%arg: tensor<1024x32xf16, #dpas>) { // COM: The repetitions order of dot layout and dpas layout are same when the GEMM tiling is clustered as repCluster [1, 2]. - // CHECK: %[[VAL_81:.*]] = llvm.mlir.constant(7 : i32) : i32 - // CHECK: %[[REP_0:.*]] = llvm.insertelement {{.*}}, {{.*}}{{\[}}%[[VAL_81]] : i32] : vector<8xf16> - // CHECK: %[[VAL_98:.*]] = llvm.mlir.constant(7 : i32) : i32 - // CHECK: %[[REP_1:.*]] = llvm.insertelement {{.*}}, {{.*}}{{\[}}%[[VAL_98]] : i32] : vector<8xf16> - // CHECK: %[[VAL_115:.*]] = llvm.mlir.constant(7 : i32) : i32 - // CHECK: %[[REP_2:.*]] = llvm.insertelement {{.*}}, {{.*}}{{\[}}%[[VAL_115]] : i32] : vector<8xf16> - // CHECK: %[[VAL_132:.*]] = llvm.mlir.constant(7 : i32) : i32 - // CHECK: %[[REP_3:.*]] = llvm.insertelement {{.*}}, {{.*}}{{\[}}%[[VAL_132]] : i32] : vector<8xf16> - // CHECK: %[[VAL_149:.*]] = llvm.mlir.constant(7 : i32) : i32 - // CHECK: %[[REP_4:.*]] = llvm.insertelement {{.*}}, {{.*}}{{\[}}%[[VAL_149]] : i32] : vector<8xf16> - // CHECK: %[[VAL_166:.*]] = llvm.mlir.constant(7 : i32) : i32 - // CHECK: %[[REP_5:.*]] = llvm.insertelement {{.*}}, {{.*}}{{\[}}%[[VAL_166]] : i32] : vector<8xf16> - // CHECK: %[[VAL_183:.*]] = llvm.mlir.constant(7 : i32) : i32 - // CHECK: %[[REP_6:.*]] = llvm.insertelement {{.*}}, {{.*}}{{\[}}%[[VAL_183]] : i32] : vector<8xf16> - // CHECK: %[[VAL_200:.*]] = llvm.mlir.constant(7 : i32) : i32 - // CHECK: %[[REP_7:.*]] = llvm.insertelement {{.*}}, {{.*}}{{\[}}%[[VAL_200]] : i32] : vector<8xf16> - // CHECK: %[[VAL_216:.*]] = llvm.mlir.constant(7 : i32) : i32 - // CHECK: %[[VAL_217:.*]] = llvm.extractelement %[[REP_0]]{{\[}}%[[VAL_216]] : i32] : vector<8xf16> - // CHECK: %[[VAL_232:.*]] = llvm.mlir.constant(7 : i32) : i32 - // CHECK: %[[VAL_233:.*]] = llvm.extractelement %[[REP_1]]{{\[}}%[[VAL_232]] : i32] : vector<8xf16> - // CHECK: %[[VAL_248:.*]] = llvm.mlir.constant(7 : i32) : i32 - // CHECK: %[[VAL_249:.*]] = llvm.extractelement %[[REP_2]]{{\[}}%[[VAL_248]] : i32] : vector<8xf16> - // CHECK: %[[VAL_264:.*]] = llvm.mlir.constant(7 : i32) : i32 - // CHECK: %[[VAL_265:.*]] = llvm.extractelement %[[REP_3]]{{\[}}%[[VAL_264]] : i32] : vector<8xf16> - // CHECK: %[[VAL_280:.*]] = llvm.mlir.constant(7 : i32) : i32 - // CHECK: %[[VAL_281:.*]] = llvm.extractelement %[[REP_4]]{{\[}}%[[VAL_280]] : i32] : vector<8xf16> - // CHECK: %[[VAL_296:.*]] = llvm.mlir.constant(7 : i32) : i32 - // CHECK: %[[VAL_297:.*]] = llvm.extractelement %[[REP_5]]{{\[}}%[[VAL_296]] : i32] : vector<8xf16> - // CHECK: %[[VAL_312:.*]] = llvm.mlir.constant(7 : i32) : i32 - // CHECK: %[[VAL_313:.*]] = llvm.extractelement %[[REP_6]]{{\[}}%[[VAL_312]] : i32] : vector<8xf16> - // CHECK: %[[VAL_328:.*]] = llvm.mlir.constant(7 : i32) : i32 - // CHECK: %[[VAL_329:.*]] = llvm.extractelement %[[REP_7]]{{\[}}%[[VAL_328]] : i32] : vector<8xf16> - // CHECK: %[[VAL_338:.*]] = llvm.insertvalue %[[VAL_217]], {{.*}}[7] - // CHECK: %[[VAL_346:.*]] = llvm.insertvalue %[[VAL_233]], {{.*}}[15] - // CHECK: %[[VAL_354:.*]] = llvm.insertvalue %[[VAL_249]], {{.*}}[23] - // CHECK: 
%[[VAL_362:.*]] = llvm.insertvalue %[[VAL_265]], {{.*}}[31] - // CHECK: %[[VAL_370:.*]] = llvm.insertvalue %[[VAL_281]], {{.*}}[39] - // CHECK: %[[VAL_378:.*]] = llvm.insertvalue %[[VAL_297]], {{.*}}[47] - // CHECK: %[[VAL_386:.*]] = llvm.insertvalue %[[VAL_313]], {{.*}}[55] - // CHECK: %[[VAL_394:.*]] = llvm.insertvalue %[[VAL_329]], {{.*}}[63] + // CHECK-NOT: llvm.insertvalue + // CHECK-NOT: llvm.extractvalue %108 = triton_gpu.convert_layout %arg : tensor<1024x32xf16, #dpas> -> tensor<1024x32xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #dpas, kWidth = 2}>> tt.return } @@ -62,46 +24,135 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 32 // COM: - 0, 1, 2, 3, 4, 5, 6, 7. // COM: The repetitions order of dot layout when the GEMM tiling is clustered as repCluster [2, 2]: // COM: - 0, 2, 1, 3, 4, 6, 5, 7. - // CHECK: %[[VAL_81:.*]] = llvm.mlir.constant(7 : i32) : i32 - // CHECK: %[[REP_0:.*]] = llvm.insertelement {{.*}}, {{.*}}{{\[}}%[[VAL_81]] : i32] : vector<8xf16> - // CHECK: %[[VAL_98:.*]] = llvm.mlir.constant(7 : i32) : i32 - // CHECK: %[[REP_1:.*]] = llvm.insertelement {{.*}}, {{.*}}{{\[}}%[[VAL_98]] : i32] : vector<8xf16> - // CHECK: %[[VAL_115:.*]] = llvm.mlir.constant(7 : i32) : i32 - // CHECK: %[[REP_2:.*]] = llvm.insertelement {{.*}}, {{.*}}{{\[}}%[[VAL_115]] : i32] : vector<8xf16> - // CHECK: %[[VAL_132:.*]] = llvm.mlir.constant(7 : i32) : i32 - // CHECK: %[[REP_3:.*]] = llvm.insertelement {{.*}}, {{.*}}{{\[}}%[[VAL_132]] : i32] : vector<8xf16> - // CHECK: %[[VAL_149:.*]] = llvm.mlir.constant(7 : i32) : i32 - // CHECK: %[[REP_4:.*]] = llvm.insertelement {{.*}}, {{.*}}{{\[}}%[[VAL_149]] : i32] : vector<8xf16> - // CHECK: %[[VAL_166:.*]] = llvm.mlir.constant(7 : i32) : i32 - // CHECK: %[[REP_5:.*]] = llvm.insertelement {{.*}}, {{.*}}{{\[}}%[[VAL_166]] : i32] : vector<8xf16> - // CHECK: %[[VAL_183:.*]] = llvm.mlir.constant(7 : i32) : i32 - // CHECK: %[[REP_6:.*]] = llvm.insertelement {{.*}}, {{.*}}{{\[}}%[[VAL_183]] : i32] : vector<8xf16> - // CHECK: %[[VAL_200:.*]] = llvm.mlir.constant(7 : i32) : i32 - // CHECK: %[[REP_7:.*]] = llvm.insertelement {{.*}}, {{.*}}{{\[}}%[[VAL_200]] : i32] : vector<8xf16> - // CHECK: %[[VAL_216:.*]] = llvm.mlir.constant(7 : i32) : i32 - // CHECK: %[[VAL_217:.*]] = llvm.extractelement %[[REP_0]]{{\[}}%[[VAL_216]] : i32] : vector<8xf16> - // CHECK: %[[VAL_232:.*]] = llvm.mlir.constant(7 : i32) : i32 - // CHECK: %[[VAL_233:.*]] = llvm.extractelement %[[REP_2]]{{\[}}%[[VAL_232]] : i32] : vector<8xf16> - // CHECK: %[[VAL_248:.*]] = llvm.mlir.constant(7 : i32) : i32 - // CHECK: %[[VAL_249:.*]] = llvm.extractelement %[[REP_1]]{{\[}}%[[VAL_248]] : i32] : vector<8xf16> - // CHECK: %[[VAL_264:.*]] = llvm.mlir.constant(7 : i32) : i32 - // CHECK: %[[VAL_265:.*]] = llvm.extractelement %[[REP_3]]{{\[}}%[[VAL_264]] : i32] : vector<8xf16> - // CHECK: %[[VAL_280:.*]] = llvm.mlir.constant(7 : i32) : i32 - // CHECK: %[[VAL_281:.*]] = llvm.extractelement %[[REP_4]]{{\[}}%[[VAL_280]] : i32] : vector<8xf16> - // CHECK: %[[VAL_296:.*]] = llvm.mlir.constant(7 : i32) : i32 - // CHECK: %[[VAL_297:.*]] = llvm.extractelement %[[REP_6]]{{\[}}%[[VAL_296]] : i32] : vector<8xf16> - // CHECK: %[[VAL_312:.*]] = llvm.mlir.constant(7 : i32) : i32 - // CHECK: %[[VAL_313:.*]] = llvm.extractelement %[[REP_5]]{{\[}}%[[VAL_312]] : i32] : vector<8xf16> - // CHECK: %[[VAL_328:.*]] = llvm.mlir.constant(7 : i32) : i32 - // CHECK: %[[VAL_329:.*]] = llvm.extractelement %[[REP_7]]{{\[}}%[[VAL_328]] : i32] : vector<8xf16> - // CHECK: %[[VAL_338:.*]] = llvm.insertvalue
%[[VAL_217]], {{.*}}[7] - // CHECK: %[[VAL_346:.*]] = llvm.insertvalue %[[VAL_233]], {{.*}}[15] - // CHECK: %[[VAL_354:.*]] = llvm.insertvalue %[[VAL_249]], {{.*}}[23] - // CHECK: %[[VAL_362:.*]] = llvm.insertvalue %[[VAL_265]], {{.*}}[31] - // CHECK: %[[VAL_370:.*]] = llvm.insertvalue %[[VAL_281]], {{.*}}[39] - // CHECK: %[[VAL_378:.*]] = llvm.insertvalue %[[VAL_297]], {{.*}}[47] - // CHECK: %[[VAL_386:.*]] = llvm.insertvalue %[[VAL_313]], {{.*}}[55] - // CHECK: %[[VAL_394:.*]] = llvm.insertvalue %[[VAL_329]], {{.*}}[63] + // CHECK: %[[VAL_1:.*]] = llvm.extractvalue %[[VAL_0]][0] + // CHECK: %[[VAL_2:.*]] = llvm.extractvalue %[[VAL_0]][1] + // CHECK: %[[VAL_3:.*]] = llvm.extractvalue %[[VAL_0]][2] + // CHECK: %[[VAL_4:.*]] = llvm.extractvalue %[[VAL_0]][3] + // CHECK: %[[VAL_5:.*]] = llvm.extractvalue %[[VAL_0]][4] + // CHECK: %[[VAL_6:.*]] = llvm.extractvalue %[[VAL_0]][5] + // CHECK: %[[VAL_7:.*]] = llvm.extractvalue %[[VAL_0]][6] + // CHECK: %[[VAL_8:.*]] = llvm.extractvalue %[[VAL_0]][7] + // CHECK: %[[VAL_9:.*]] = llvm.extractvalue %[[VAL_0]][8] + // CHECK: %[[VAL_10:.*]] = llvm.extractvalue %[[VAL_0]][9] + // CHECK: %[[VAL_11:.*]] = llvm.extractvalue %[[VAL_0]][10] + // CHECK: %[[VAL_12:.*]] = llvm.extractvalue %[[VAL_0]][11] + // CHECK: %[[VAL_13:.*]] = llvm.extractvalue %[[VAL_0]][12] + // CHECK: %[[VAL_14:.*]] = llvm.extractvalue %[[VAL_0]][13] + // CHECK: %[[VAL_15:.*]] = llvm.extractvalue %[[VAL_0]][14] + // CHECK: %[[VAL_16:.*]] = llvm.extractvalue %[[VAL_0]][15] + // CHECK: %[[VAL_17:.*]] = llvm.extractvalue %[[VAL_0]][16] + // CHECK: %[[VAL_18:.*]] = llvm.extractvalue %[[VAL_0]][17] + // CHECK: %[[VAL_19:.*]] = llvm.extractvalue %[[VAL_0]][18] + // CHECK: %[[VAL_20:.*]] = llvm.extractvalue %[[VAL_0]][19] + // CHECK: %[[VAL_21:.*]] = llvm.extractvalue %[[VAL_0]][20] + // CHECK: %[[VAL_22:.*]] = llvm.extractvalue %[[VAL_0]][21] + // CHECK: %[[VAL_23:.*]] = llvm.extractvalue %[[VAL_0]][22] + // CHECK: %[[VAL_24:.*]] = llvm.extractvalue %[[VAL_0]][23] + // CHECK: %[[VAL_25:.*]] = llvm.extractvalue %[[VAL_0]][24] + // CHECK: %[[VAL_26:.*]] = llvm.extractvalue %[[VAL_0]][25] + // CHECK: %[[VAL_27:.*]] = llvm.extractvalue %[[VAL_0]][26] + // CHECK: %[[VAL_28:.*]] = llvm.extractvalue %[[VAL_0]][27] + // CHECK: %[[VAL_29:.*]] = llvm.extractvalue %[[VAL_0]][28] + // CHECK: %[[VAL_30:.*]] = llvm.extractvalue %[[VAL_0]][29] + // CHECK: %[[VAL_31:.*]] = llvm.extractvalue %[[VAL_0]][30] + // CHECK: %[[VAL_32:.*]] = llvm.extractvalue %[[VAL_0]][31] + // CHECK: %[[VAL_33:.*]] = llvm.extractvalue %[[VAL_0]][32] + // CHECK: %[[VAL_34:.*]] = llvm.extractvalue %[[VAL_0]][33] + // CHECK: %[[VAL_35:.*]] = llvm.extractvalue %[[VAL_0]][34] + // CHECK: %[[VAL_36:.*]] = llvm.extractvalue %[[VAL_0]][35] + // CHECK: %[[VAL_37:.*]] = llvm.extractvalue %[[VAL_0]][36] + // CHECK: %[[VAL_38:.*]] = llvm.extractvalue %[[VAL_0]][37] + // CHECK: %[[VAL_39:.*]] = llvm.extractvalue %[[VAL_0]][38] + // CHECK: %[[VAL_40:.*]] = llvm.extractvalue %[[VAL_0]][39] + // CHECK: %[[VAL_41:.*]] = llvm.extractvalue %[[VAL_0]][40] + // CHECK: %[[VAL_42:.*]] = llvm.extractvalue %[[VAL_0]][41] + // CHECK: %[[VAL_43:.*]] = llvm.extractvalue %[[VAL_0]][42] + // CHECK: %[[VAL_44:.*]] = llvm.extractvalue %[[VAL_0]][43] + // CHECK: %[[VAL_45:.*]] = llvm.extractvalue %[[VAL_0]][44] + // CHECK: %[[VAL_46:.*]] = llvm.extractvalue %[[VAL_0]][45] + // CHECK: %[[VAL_47:.*]] = llvm.extractvalue %[[VAL_0]][46] + // CHECK: %[[VAL_48:.*]] = llvm.extractvalue %[[VAL_0]][47] + // CHECK: %[[VAL_49:.*]] = llvm.extractvalue %[[VAL_0]][48] + // 
CHECK: %[[VAL_50:.*]] = llvm.extractvalue %[[VAL_0]][49] + // CHECK: %[[VAL_51:.*]] = llvm.extractvalue %[[VAL_0]][50] + // CHECK: %[[VAL_52:.*]] = llvm.extractvalue %[[VAL_0]][51] + // CHECK: %[[VAL_53:.*]] = llvm.extractvalue %[[VAL_0]][52] + // CHECK: %[[VAL_54:.*]] = llvm.extractvalue %[[VAL_0]][53] + // CHECK: %[[VAL_55:.*]] = llvm.extractvalue %[[VAL_0]][54] + // CHECK: %[[VAL_56:.*]] = llvm.extractvalue %[[VAL_0]][55] + // CHECK: %[[VAL_57:.*]] = llvm.extractvalue %[[VAL_0]][56] + // CHECK: %[[VAL_58:.*]] = llvm.extractvalue %[[VAL_0]][57] + // CHECK: %[[VAL_59:.*]] = llvm.extractvalue %[[VAL_0]][58] + // CHECK: %[[VAL_60:.*]] = llvm.extractvalue %[[VAL_0]][59] + // CHECK: %[[VAL_61:.*]] = llvm.extractvalue %[[VAL_0]][60] + // CHECK: %[[VAL_62:.*]] = llvm.extractvalue %[[VAL_0]][61] + // CHECK: %[[VAL_63:.*]] = llvm.extractvalue %[[VAL_0]][62] + // CHECK: %[[VAL_64:.*]] = llvm.extractvalue %[[VAL_0]][63] + // CHECK: %[[VAL_65:.*]] = llvm.mlir.undef + // CHECK: %[[VAL_66:.*]] = llvm.insertvalue %[[VAL_1]], %[[VAL_65]][0] + // CHECK: %[[VAL_67:.*]] = llvm.insertvalue %[[VAL_2]], %[[VAL_66]][1] + // CHECK: %[[VAL_68:.*]] = llvm.insertvalue %[[VAL_3]], %[[VAL_67]][2] + // CHECK: %[[VAL_69:.*]] = llvm.insertvalue %[[VAL_4]], %[[VAL_68]][3] + // CHECK: %[[VAL_70:.*]] = llvm.insertvalue %[[VAL_5]], %[[VAL_69]][4] + // CHECK: %[[VAL_71:.*]] = llvm.insertvalue %[[VAL_6]], %[[VAL_70]][5] + // CHECK: %[[VAL_72:.*]] = llvm.insertvalue %[[VAL_7]], %[[VAL_71]][6] + // CHECK: %[[VAL_73:.*]] = llvm.insertvalue %[[VAL_8]], %[[VAL_72]][7] + // CHECK: %[[VAL_74:.*]] = llvm.insertvalue %[[VAL_17]], %[[VAL_73]][8] + // CHECK: %[[VAL_75:.*]] = llvm.insertvalue %[[VAL_18]], %[[VAL_74]][9] + // CHECK: %[[VAL_76:.*]] = llvm.insertvalue %[[VAL_19]], %[[VAL_75]][10] + // CHECK: %[[VAL_77:.*]] = llvm.insertvalue %[[VAL_20]], %[[VAL_76]][11] + // CHECK: %[[VAL_78:.*]] = llvm.insertvalue %[[VAL_21]], %[[VAL_77]][12] + // CHECK: %[[VAL_79:.*]] = llvm.insertvalue %[[VAL_22]], %[[VAL_78]][13] + // CHECK: %[[VAL_80:.*]] = llvm.insertvalue %[[VAL_23]], %[[VAL_79]][14] + // CHECK: %[[VAL_81:.*]] = llvm.insertvalue %[[VAL_24]], %[[VAL_80]][15] + // CHECK: %[[VAL_82:.*]] = llvm.insertvalue %[[VAL_9]], %[[VAL_81]][16] + // CHECK: %[[VAL_83:.*]] = llvm.insertvalue %[[VAL_10]], %[[VAL_82]][17] + // CHECK: %[[VAL_84:.*]] = llvm.insertvalue %[[VAL_11]], %[[VAL_83]][18] + // CHECK: %[[VAL_85:.*]] = llvm.insertvalue %[[VAL_12]], %[[VAL_84]][19] + // CHECK: %[[VAL_86:.*]] = llvm.insertvalue %[[VAL_13]], %[[VAL_85]][20] + // CHECK: %[[VAL_87:.*]] = llvm.insertvalue %[[VAL_14]], %[[VAL_86]][21] + // CHECK: %[[VAL_88:.*]] = llvm.insertvalue %[[VAL_15]], %[[VAL_87]][22] + // CHECK: %[[VAL_89:.*]] = llvm.insertvalue %[[VAL_16]], %[[VAL_88]][23] + // CHECK: %[[VAL_90:.*]] = llvm.insertvalue %[[VAL_25]], %[[VAL_89]][24] + // CHECK: %[[VAL_91:.*]] = llvm.insertvalue %[[VAL_26]], %[[VAL_90]][25] + // CHECK: %[[VAL_92:.*]] = llvm.insertvalue %[[VAL_27]], %[[VAL_91]][26] + // CHECK: %[[VAL_93:.*]] = llvm.insertvalue %[[VAL_28]], %[[VAL_92]][27] + // CHECK: %[[VAL_94:.*]] = llvm.insertvalue %[[VAL_29]], %[[VAL_93]][28] + // CHECK: %[[VAL_95:.*]] = llvm.insertvalue %[[VAL_30]], %[[VAL_94]][29] + // CHECK: %[[VAL_96:.*]] = llvm.insertvalue %[[VAL_31]], %[[VAL_95]][30] + // CHECK: %[[VAL_97:.*]] = llvm.insertvalue %[[VAL_32]], %[[VAL_96]][31] + // CHECK: %[[VAL_98:.*]] = llvm.insertvalue %[[VAL_33]], %[[VAL_97]][32] + // CHECK: %[[VAL_99:.*]] = llvm.insertvalue %[[VAL_34]], %[[VAL_98]][33] + // CHECK: %[[VAL_100:.*]] = llvm.insertvalue 
%[[VAL_35]], %[[VAL_99]][34] + // CHECK: %[[VAL_101:.*]] = llvm.insertvalue %[[VAL_36]], %[[VAL_100]][35] + // CHECK: %[[VAL_102:.*]] = llvm.insertvalue %[[VAL_37]], %[[VAL_101]][36] + // CHECK: %[[VAL_103:.*]] = llvm.insertvalue %[[VAL_38]], %[[VAL_102]][37] + // CHECK: %[[VAL_104:.*]] = llvm.insertvalue %[[VAL_39]], %[[VAL_103]][38] + // CHECK: %[[VAL_105:.*]] = llvm.insertvalue %[[VAL_40]], %[[VAL_104]][39] + // CHECK: %[[VAL_106:.*]] = llvm.insertvalue %[[VAL_49]], %[[VAL_105]][40] + // CHECK: %[[VAL_107:.*]] = llvm.insertvalue %[[VAL_50]], %[[VAL_106]][41] + // CHECK: %[[VAL_108:.*]] = llvm.insertvalue %[[VAL_51]], %[[VAL_107]][42] + // CHECK: %[[VAL_109:.*]] = llvm.insertvalue %[[VAL_52]], %[[VAL_108]][43] + // CHECK: %[[VAL_110:.*]] = llvm.insertvalue %[[VAL_53]], %[[VAL_109]][44] + // CHECK: %[[VAL_111:.*]] = llvm.insertvalue %[[VAL_54]], %[[VAL_110]][45] + // CHECK: %[[VAL_112:.*]] = llvm.insertvalue %[[VAL_55]], %[[VAL_111]][46] + // CHECK: %[[VAL_113:.*]] = llvm.insertvalue %[[VAL_56]], %[[VAL_112]][47] + // CHECK: %[[VAL_114:.*]] = llvm.insertvalue %[[VAL_41]], %[[VAL_113]][48] + // CHECK: %[[VAL_115:.*]] = llvm.insertvalue %[[VAL_42]], %[[VAL_114]][49] + // CHECK: %[[VAL_116:.*]] = llvm.insertvalue %[[VAL_43]], %[[VAL_115]][50] + // CHECK: %[[VAL_117:.*]] = llvm.insertvalue %[[VAL_44]], %[[VAL_116]][51] + // CHECK: %[[VAL_118:.*]] = llvm.insertvalue %[[VAL_45]], %[[VAL_117]][52] + // CHECK: %[[VAL_119:.*]] = llvm.insertvalue %[[VAL_46]], %[[VAL_118]][53] + // CHECK: %[[VAL_120:.*]] = llvm.insertvalue %[[VAL_47]], %[[VAL_119]][54] + // CHECK: %[[VAL_121:.*]] = llvm.insertvalue %[[VAL_48]], %[[VAL_120]][55] + // CHECK: %[[VAL_122:.*]] = llvm.insertvalue %[[VAL_57]], %[[VAL_121]][56] + // CHECK: %[[VAL_123:.*]] = llvm.insertvalue %[[VAL_58]], %[[VAL_122]][57] + // CHECK: %[[VAL_124:.*]] = llvm.insertvalue %[[VAL_59]], %[[VAL_123]][58] + // CHECK: %[[VAL_125:.*]] = llvm.insertvalue %[[VAL_60]], %[[VAL_124]][59] + // CHECK: %[[VAL_126:.*]] = llvm.insertvalue %[[VAL_61]], %[[VAL_125]][60] + // CHECK: %[[VAL_127:.*]] = llvm.insertvalue %[[VAL_62]], %[[VAL_126]][61] + // CHECK: %[[VAL_128:.*]] = llvm.insertvalue %[[VAL_63]], %[[VAL_127]][62] + // CHECK: %[[VAL_129:.*]] = llvm.insertvalue %[[VAL_64]], %[[VAL_128]][63] %108 = triton_gpu.convert_layout %arg : tensor<1024x32xf16, #dpas> -> tensor<1024x32xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #dpas, kWidth = 2}>> tt.return } @@ -118,46 +169,135 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 32 // COM: - 0, 1, 2, 3, 4, 5, 6, 7. // COM: The repetitions order of dot layout when the GEMM tiling is clustered as repCluster [4, 2]: // COM: - 0, 2, 4, 6, 1, 3, 5, 7. 
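The orders quoted in the COM lines above can be derived mechanically: the DPAS layout numbers the repetitions row-major over the repetition grid, while the dot layout walks each repCluster block column-first, which produces the interleaved sequence. Below is a minimal standalone C++ sketch of that walk; the function name dotRepOrder and the repM/repN grid parameters are illustrative assumptions for this example, not the backend's actual API.

#include <cstdio>
#include <vector>

// Returns the order in which the dot layout visits repetitions that the
// DPAS layout numbers row-major over a repM x repN grid. Within each
// repCluster block (clusterM x clusterN) the dot layout walks columns first.
static std::vector<int> dotRepOrder(int repM, int repN, int clusterM,
                                    int clusterN) {
  std::vector<int> order;
  for (int bm = 0; bm < repM; bm += clusterM)   // cluster blocks along M
    for (int bn = 0; bn < repN; bn += clusterN) // cluster blocks along N
      for (int c = 0; c < clusterN; ++c)        // column-major inside a block
        for (int r = 0; r < clusterM; ++r)
          order.push_back((bm + r) * repN + (bn + c));
  return order;
}

int main() {
  for (int idx : dotRepOrder(4, 2, 4, 2)) // prints 0 2 4 6 1 3 5 7
    std::printf("%d ", idx);
  std::printf("\n");
  for (int idx : dotRepOrder(4, 2, 2, 2)) // prints 0 2 1 3 4 6 5 7
    std::printf("%d ", idx);
  std::printf("\n");
  return 0;
}

With repCluster [1, 2] the same walk degenerates to the identity order 0..7, consistent with the earlier test where the conversion is expected to elide all extractvalue/insertvalue traffic.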
- // CHECK: %[[VAL_81:.*]] = llvm.mlir.constant(7 : i32) : i32 - // CHECK: %[[REP_0:.*]] = llvm.insertelement {{.*}}, {{.*}}{{\[}}%[[VAL_81]] : i32] : vector<8xf16> - // CHECK: %[[VAL_98:.*]] = llvm.mlir.constant(7 : i32) : i32 - // CHECK: %[[REP_1:.*]] = llvm.insertelement {{.*}}, {{.*}}{{\[}}%[[VAL_98]] : i32] : vector<8xf16> - // CHECK: %[[VAL_115:.*]] = llvm.mlir.constant(7 : i32) : i32 - // CHECK: %[[REP_2:.*]] = llvm.insertelement {{.*}}, {{.*}}{{\[}}%[[VAL_115]] : i32] : vector<8xf16> - // CHECK: %[[VAL_132:.*]] = llvm.mlir.constant(7 : i32) : i32 - // CHECK: %[[REP_3:.*]] = llvm.insertelement {{.*}}, {{.*}}{{\[}}%[[VAL_132]] : i32] : vector<8xf16> - // CHECK: %[[VAL_149:.*]] = llvm.mlir.constant(7 : i32) : i32 - // CHECK: %[[REP_4:.*]] = llvm.insertelement {{.*}}, {{.*}}{{\[}}%[[VAL_149]] : i32] : vector<8xf16> - // CHECK: %[[VAL_166:.*]] = llvm.mlir.constant(7 : i32) : i32 - // CHECK: %[[REP_5:.*]] = llvm.insertelement {{.*}}, {{.*}}{{\[}}%[[VAL_166]] : i32] : vector<8xf16> - // CHECK: %[[VAL_183:.*]] = llvm.mlir.constant(7 : i32) : i32 - // CHECK: %[[REP_6:.*]] = llvm.insertelement {{.*}}, {{.*}}{{\[}}%[[VAL_183]] : i32] : vector<8xf16> - // CHECK: %[[VAL_200:.*]] = llvm.mlir.constant(7 : i32) : i32 - // CHECK: %[[REP_7:.*]] = llvm.insertelement {{.*}}, {{.*}}{{\[}}%[[VAL_200]] : i32] : vector<8xf16> - // CHECK: %[[VAL_216:.*]] = llvm.mlir.constant(7 : i32) : i32 - // CHECK: %[[VAL_217:.*]] = llvm.extractelement %[[REP_0]]{{\[}}%[[VAL_216]] : i32] : vector<8xf16> - // CHECK: %[[VAL_232:.*]] = llvm.mlir.constant(7 : i32) : i32 - // CHECK: %[[VAL_233:.*]] = llvm.extractelement %[[REP_2]]{{\[}}%[[VAL_232]] : i32] : vector<8xf16> - // CHECK: %[[VAL_248:.*]] = llvm.mlir.constant(7 : i32) : i32 - // CHECK: %[[VAL_249:.*]] = llvm.extractelement %[[REP_4]]{{\[}}%[[VAL_248]] : i32] : vector<8xf16> - // CHECK: %[[VAL_264:.*]] = llvm.mlir.constant(7 : i32) : i32 - // CHECK: %[[VAL_265:.*]] = llvm.extractelement %[[REP_6]]{{\[}}%[[VAL_264]] : i32] : vector<8xf16> - // CHECK: %[[VAL_280:.*]] = llvm.mlir.constant(7 : i32) : i32 - // CHECK: %[[VAL_281:.*]] = llvm.extractelement %[[REP_1]]{{\[}}%[[VAL_280]] : i32] : vector<8xf16> - // CHECK: %[[VAL_296:.*]] = llvm.mlir.constant(7 : i32) : i32 - // CHECK: %[[VAL_297:.*]] = llvm.extractelement %[[REP_3]]{{\[}}%[[VAL_296]] : i32] : vector<8xf16> - // CHECK: %[[VAL_312:.*]] = llvm.mlir.constant(7 : i32) : i32 - // CHECK: %[[VAL_313:.*]] = llvm.extractelement %[[REP_5]]{{\[}}%[[VAL_312]] : i32] : vector<8xf16> - // CHECK: %[[VAL_328:.*]] = llvm.mlir.constant(7 : i32) : i32 - // CHECK: %[[VAL_329:.*]] = llvm.extractelement %[[REP_7]]{{\[}}%[[VAL_328]] : i32] : vector<8xf16> - // CHECK: %[[VAL_338:.*]] = llvm.insertvalue %[[VAL_217]], {{.*}}[7] - // CHECK: %[[VAL_346:.*]] = llvm.insertvalue %[[VAL_233]], {{.*}}[15] - // CHECK: %[[VAL_354:.*]] = llvm.insertvalue %[[VAL_249]], {{.*}}[23] - // CHECK: %[[VAL_362:.*]] = llvm.insertvalue %[[VAL_265]], {{.*}}[31] - // CHECK: %[[VAL_370:.*]] = llvm.insertvalue %[[VAL_281]], {{.*}}[39] - // CHECK: %[[VAL_378:.*]] = llvm.insertvalue %[[VAL_297]], {{.*}}[47] - // CHECK: %[[VAL_386:.*]] = llvm.insertvalue %[[VAL_313]], {{.*}}[55] - // CHECK: %[[VAL_394:.*]] = llvm.insertvalue %[[VAL_329]], {{.*}}[63] + // CHECK: %[[VAL_1:.*]] = llvm.extractvalue %[[VAL_0]][0] + // CHECK: %[[VAL_2:.*]] = llvm.extractvalue %[[VAL_0]][1] + // CHECK: %[[VAL_3:.*]] = llvm.extractvalue %[[VAL_0]][2] + // CHECK: %[[VAL_4:.*]] = llvm.extractvalue %[[VAL_0]][3] + // CHECK: %[[VAL_5:.*]] = llvm.extractvalue %[[VAL_0]][4] + // CHECK: 
%[[VAL_6:.*]] = llvm.extractvalue %[[VAL_0]][5] + // CHECK: %[[VAL_7:.*]] = llvm.extractvalue %[[VAL_0]][6] + // CHECK: %[[VAL_8:.*]] = llvm.extractvalue %[[VAL_0]][7] + // CHECK: %[[VAL_9:.*]] = llvm.extractvalue %[[VAL_0]][8] + // CHECK: %[[VAL_10:.*]] = llvm.extractvalue %[[VAL_0]][9] + // CHECK: %[[VAL_11:.*]] = llvm.extractvalue %[[VAL_0]][10] + // CHECK: %[[VAL_12:.*]] = llvm.extractvalue %[[VAL_0]][11] + // CHECK: %[[VAL_13:.*]] = llvm.extractvalue %[[VAL_0]][12] + // CHECK: %[[VAL_14:.*]] = llvm.extractvalue %[[VAL_0]][13] + // CHECK: %[[VAL_15:.*]] = llvm.extractvalue %[[VAL_0]][14] + // CHECK: %[[VAL_16:.*]] = llvm.extractvalue %[[VAL_0]][15] + // CHECK: %[[VAL_17:.*]] = llvm.extractvalue %[[VAL_0]][16] + // CHECK: %[[VAL_18:.*]] = llvm.extractvalue %[[VAL_0]][17] + // CHECK: %[[VAL_19:.*]] = llvm.extractvalue %[[VAL_0]][18] + // CHECK: %[[VAL_20:.*]] = llvm.extractvalue %[[VAL_0]][19] + // CHECK: %[[VAL_21:.*]] = llvm.extractvalue %[[VAL_0]][20] + // CHECK: %[[VAL_22:.*]] = llvm.extractvalue %[[VAL_0]][21] + // CHECK: %[[VAL_23:.*]] = llvm.extractvalue %[[VAL_0]][22] + // CHECK: %[[VAL_24:.*]] = llvm.extractvalue %[[VAL_0]][23] + // CHECK: %[[VAL_25:.*]] = llvm.extractvalue %[[VAL_0]][24] + // CHECK: %[[VAL_26:.*]] = llvm.extractvalue %[[VAL_0]][25] + // CHECK: %[[VAL_27:.*]] = llvm.extractvalue %[[VAL_0]][26] + // CHECK: %[[VAL_28:.*]] = llvm.extractvalue %[[VAL_0]][27] + // CHECK: %[[VAL_29:.*]] = llvm.extractvalue %[[VAL_0]][28] + // CHECK: %[[VAL_30:.*]] = llvm.extractvalue %[[VAL_0]][29] + // CHECK: %[[VAL_31:.*]] = llvm.extractvalue %[[VAL_0]][30] + // CHECK: %[[VAL_32:.*]] = llvm.extractvalue %[[VAL_0]][31] + // CHECK: %[[VAL_33:.*]] = llvm.extractvalue %[[VAL_0]][32] + // CHECK: %[[VAL_34:.*]] = llvm.extractvalue %[[VAL_0]][33] + // CHECK: %[[VAL_35:.*]] = llvm.extractvalue %[[VAL_0]][34] + // CHECK: %[[VAL_36:.*]] = llvm.extractvalue %[[VAL_0]][35] + // CHECK: %[[VAL_37:.*]] = llvm.extractvalue %[[VAL_0]][36] + // CHECK: %[[VAL_38:.*]] = llvm.extractvalue %[[VAL_0]][37] + // CHECK: %[[VAL_39:.*]] = llvm.extractvalue %[[VAL_0]][38] + // CHECK: %[[VAL_40:.*]] = llvm.extractvalue %[[VAL_0]][39] + // CHECK: %[[VAL_41:.*]] = llvm.extractvalue %[[VAL_0]][40] + // CHECK: %[[VAL_42:.*]] = llvm.extractvalue %[[VAL_0]][41] + // CHECK: %[[VAL_43:.*]] = llvm.extractvalue %[[VAL_0]][42] + // CHECK: %[[VAL_44:.*]] = llvm.extractvalue %[[VAL_0]][43] + // CHECK: %[[VAL_45:.*]] = llvm.extractvalue %[[VAL_0]][44] + // CHECK: %[[VAL_46:.*]] = llvm.extractvalue %[[VAL_0]][45] + // CHECK: %[[VAL_47:.*]] = llvm.extractvalue %[[VAL_0]][46] + // CHECK: %[[VAL_48:.*]] = llvm.extractvalue %[[VAL_0]][47] + // CHECK: %[[VAL_49:.*]] = llvm.extractvalue %[[VAL_0]][48] + // CHECK: %[[VAL_50:.*]] = llvm.extractvalue %[[VAL_0]][49] + // CHECK: %[[VAL_51:.*]] = llvm.extractvalue %[[VAL_0]][50] + // CHECK: %[[VAL_52:.*]] = llvm.extractvalue %[[VAL_0]][51] + // CHECK: %[[VAL_53:.*]] = llvm.extractvalue %[[VAL_0]][52] + // CHECK: %[[VAL_54:.*]] = llvm.extractvalue %[[VAL_0]][53] + // CHECK: %[[VAL_55:.*]] = llvm.extractvalue %[[VAL_0]][54] + // CHECK: %[[VAL_56:.*]] = llvm.extractvalue %[[VAL_0]][55] + // CHECK: %[[VAL_57:.*]] = llvm.extractvalue %[[VAL_0]][56] + // CHECK: %[[VAL_58:.*]] = llvm.extractvalue %[[VAL_0]][57] + // CHECK: %[[VAL_59:.*]] = llvm.extractvalue %[[VAL_0]][58] + // CHECK: %[[VAL_60:.*]] = llvm.extractvalue %[[VAL_0]][59] + // CHECK: %[[VAL_61:.*]] = llvm.extractvalue %[[VAL_0]][60] + // CHECK: %[[VAL_62:.*]] = llvm.extractvalue %[[VAL_0]][61] + // CHECK: %[[VAL_63:.*]] = 
llvm.extractvalue %[[VAL_0]][62] + // CHECK: %[[VAL_64:.*]] = llvm.extractvalue %[[VAL_0]][63] + // CHECK: %[[VAL_65:.*]] = llvm.mlir.undef + // CHECK: %[[VAL_66:.*]] = llvm.insertvalue %[[VAL_1]], %[[VAL_65]][0] + // CHECK: %[[VAL_67:.*]] = llvm.insertvalue %[[VAL_2]], %[[VAL_66]][1] + // CHECK: %[[VAL_68:.*]] = llvm.insertvalue %[[VAL_3]], %[[VAL_67]][2] + // CHECK: %[[VAL_69:.*]] = llvm.insertvalue %[[VAL_4]], %[[VAL_68]][3] + // CHECK: %[[VAL_70:.*]] = llvm.insertvalue %[[VAL_5]], %[[VAL_69]][4] + // CHECK: %[[VAL_71:.*]] = llvm.insertvalue %[[VAL_6]], %[[VAL_70]][5] + // CHECK: %[[VAL_72:.*]] = llvm.insertvalue %[[VAL_7]], %[[VAL_71]][6] + // CHECK: %[[VAL_73:.*]] = llvm.insertvalue %[[VAL_8]], %[[VAL_72]][7] + // CHECK: %[[VAL_74:.*]] = llvm.insertvalue %[[VAL_17]], %[[VAL_73]][8] + // CHECK: %[[VAL_75:.*]] = llvm.insertvalue %[[VAL_18]], %[[VAL_74]][9] + // CHECK: %[[VAL_76:.*]] = llvm.insertvalue %[[VAL_19]], %[[VAL_75]][10] + // CHECK: %[[VAL_77:.*]] = llvm.insertvalue %[[VAL_20]], %[[VAL_76]][11] + // CHECK: %[[VAL_78:.*]] = llvm.insertvalue %[[VAL_21]], %[[VAL_77]][12] + // CHECK: %[[VAL_79:.*]] = llvm.insertvalue %[[VAL_22]], %[[VAL_78]][13] + // CHECK: %[[VAL_80:.*]] = llvm.insertvalue %[[VAL_23]], %[[VAL_79]][14] + // CHECK: %[[VAL_81:.*]] = llvm.insertvalue %[[VAL_24]], %[[VAL_80]][15] + // CHECK: %[[VAL_82:.*]] = llvm.insertvalue %[[VAL_33]], %[[VAL_81]][16] + // CHECK: %[[VAL_83:.*]] = llvm.insertvalue %[[VAL_34]], %[[VAL_82]][17] + // CHECK: %[[VAL_84:.*]] = llvm.insertvalue %[[VAL_35]], %[[VAL_83]][18] + // CHECK: %[[VAL_85:.*]] = llvm.insertvalue %[[VAL_36]], %[[VAL_84]][19] + // CHECK: %[[VAL_86:.*]] = llvm.insertvalue %[[VAL_37]], %[[VAL_85]][20] + // CHECK: %[[VAL_87:.*]] = llvm.insertvalue %[[VAL_38]], %[[VAL_86]][21] + // CHECK: %[[VAL_88:.*]] = llvm.insertvalue %[[VAL_39]], %[[VAL_87]][22] + // CHECK: %[[VAL_89:.*]] = llvm.insertvalue %[[VAL_40]], %[[VAL_88]][23] + // CHECK: %[[VAL_90:.*]] = llvm.insertvalue %[[VAL_49]], %[[VAL_89]][24] + // CHECK: %[[VAL_91:.*]] = llvm.insertvalue %[[VAL_50]], %[[VAL_90]][25] + // CHECK: %[[VAL_92:.*]] = llvm.insertvalue %[[VAL_51]], %[[VAL_91]][26] + // CHECK: %[[VAL_93:.*]] = llvm.insertvalue %[[VAL_52]], %[[VAL_92]][27] + // CHECK: %[[VAL_94:.*]] = llvm.insertvalue %[[VAL_53]], %[[VAL_93]][28] + // CHECK: %[[VAL_95:.*]] = llvm.insertvalue %[[VAL_54]], %[[VAL_94]][29] + // CHECK: %[[VAL_96:.*]] = llvm.insertvalue %[[VAL_55]], %[[VAL_95]][30] + // CHECK: %[[VAL_97:.*]] = llvm.insertvalue %[[VAL_56]], %[[VAL_96]][31] + // CHECK: %[[VAL_98:.*]] = llvm.insertvalue %[[VAL_9]], %[[VAL_97]][32] + // CHECK: %[[VAL_99:.*]] = llvm.insertvalue %[[VAL_10]], %[[VAL_98]][33] + // CHECK: %[[VAL_100:.*]] = llvm.insertvalue %[[VAL_11]], %[[VAL_99]][34] + // CHECK: %[[VAL_101:.*]] = llvm.insertvalue %[[VAL_12]], %[[VAL_100]][35] + // CHECK: %[[VAL_102:.*]] = llvm.insertvalue %[[VAL_13]], %[[VAL_101]][36] + // CHECK: %[[VAL_103:.*]] = llvm.insertvalue %[[VAL_14]], %[[VAL_102]][37] + // CHECK: %[[VAL_104:.*]] = llvm.insertvalue %[[VAL_15]], %[[VAL_103]][38] + // CHECK: %[[VAL_105:.*]] = llvm.insertvalue %[[VAL_16]], %[[VAL_104]][39] + // CHECK: %[[VAL_106:.*]] = llvm.insertvalue %[[VAL_25]], %[[VAL_105]][40] + // CHECK: %[[VAL_107:.*]] = llvm.insertvalue %[[VAL_26]], %[[VAL_106]][41] + // CHECK: %[[VAL_108:.*]] = llvm.insertvalue %[[VAL_27]], %[[VAL_107]][42] + // CHECK: %[[VAL_109:.*]] = llvm.insertvalue %[[VAL_28]], %[[VAL_108]][43] + // CHECK: %[[VAL_110:.*]] = llvm.insertvalue %[[VAL_29]], %[[VAL_109]][44] + // CHECK: %[[VAL_111:.*]] = 
llvm.insertvalue %[[VAL_30]], %[[VAL_110]][45] + // CHECK: %[[VAL_112:.*]] = llvm.insertvalue %[[VAL_31]], %[[VAL_111]][46] + // CHECK: %[[VAL_113:.*]] = llvm.insertvalue %[[VAL_32]], %[[VAL_112]][47] + // CHECK: %[[VAL_114:.*]] = llvm.insertvalue %[[VAL_41]], %[[VAL_113]][48] + // CHECK: %[[VAL_115:.*]] = llvm.insertvalue %[[VAL_42]], %[[VAL_114]][49] + // CHECK: %[[VAL_116:.*]] = llvm.insertvalue %[[VAL_43]], %[[VAL_115]][50] + // CHECK: %[[VAL_117:.*]] = llvm.insertvalue %[[VAL_44]], %[[VAL_116]][51] + // CHECK: %[[VAL_118:.*]] = llvm.insertvalue %[[VAL_45]], %[[VAL_117]][52] + // CHECK: %[[VAL_119:.*]] = llvm.insertvalue %[[VAL_46]], %[[VAL_118]][53] + // CHECK: %[[VAL_120:.*]] = llvm.insertvalue %[[VAL_47]], %[[VAL_119]][54] + // CHECK: %[[VAL_121:.*]] = llvm.insertvalue %[[VAL_48]], %[[VAL_120]][55] + // CHECK: %[[VAL_122:.*]] = llvm.insertvalue %[[VAL_57]], %[[VAL_121]][56] + // CHECK: %[[VAL_123:.*]] = llvm.insertvalue %[[VAL_58]], %[[VAL_122]][57] + // CHECK: %[[VAL_124:.*]] = llvm.insertvalue %[[VAL_59]], %[[VAL_123]][58] + // CHECK: %[[VAL_125:.*]] = llvm.insertvalue %[[VAL_60]], %[[VAL_124]][59] + // CHECK: %[[VAL_126:.*]] = llvm.insertvalue %[[VAL_61]], %[[VAL_125]][60] + // CHECK: %[[VAL_127:.*]] = llvm.insertvalue %[[VAL_62]], %[[VAL_126]][61] + // CHECK: %[[VAL_128:.*]] = llvm.insertvalue %[[VAL_63]], %[[VAL_127]][62] + // CHECK: %[[VAL_129:.*]] = llvm.insertvalue %[[VAL_64]], %[[VAL_128]][63] %108 = triton_gpu.convert_layout %arg : tensor<1024x32xf16, #dpas> -> tensor<1024x32xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #dpas, kWidth = 2}>> tt.return } diff --git a/test/TritonIntelGPU/tritonintlgpu-nested-layout.mlir b/test/TritonIntelGPU/tritonintlgpu-nested-layout.mlir index 2bb504d76f..1ecb0a5a2c 100644 --- a/test/TritonIntelGPU/tritonintlgpu-nested-layout.mlir +++ b/test/TritonIntelGPU/tritonintlgpu-nested-layout.mlir @@ -69,14 +69,13 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : // CHECK-DAG: %[[CST_0:.*]] = llvm.mlir.constant(0 : i32) : i32 // CHECK-DAG: %[[CST_1:.*]] = llvm.mlir.constant(1 : i32) : i32 // CHECK-DAG: %[[CST_2:.*]] = llvm.mlir.constant(2 : i32) : i32 - // CHECK-DAG: %[[CST_4:.*]] = llvm.mlir.constant(4 : i32) : i32 - // CHECK-DAG: %[[CST_8:.*]] = llvm.mlir.constant(8 : i32) : i32 - // CHECK-DAG: %[[CST_16:.*]] = llvm.mlir.constant(16 : i32) : i32 - // CHECK-DAG: %[[CST_32:.*]] = llvm.mlir.constant(32 : i32) : i32 // CHECK-DAG: %[[CST_3:.*]] = llvm.mlir.constant(3 : i32) : i32 + // CHECK-DAG: %[[CST_4:.*]] = llvm.mlir.constant(4 : i32) : i32 // CHECK-DAG: %[[CST_5:.*]] = llvm.mlir.constant(5 : i32) : i32 // CHECK-DAG: %[[CST_6:.*]] = llvm.mlir.constant(6 : i32) : i32 // CHECK-DAG: %[[CST_7:.*]] = llvm.mlir.constant(7 : i32) : i32 + // CHECK-DAG: %[[CST_8:.*]] = llvm.mlir.constant(8 : i32) : i32 + // CHECK-DAG: %[[CST_16:.*]] = llvm.mlir.constant(16 : i32) : i32 // CHECK-DAG: %[[CST_17:.*]] = llvm.mlir.constant(17 : i32) : i32 // CHECK-DAG: %[[CST_18:.*]] = llvm.mlir.constant(18 : i32) : i32 // CHECK-DAG: %[[CST_19:.*]] = llvm.mlir.constant(19 : i32) : i32 @@ -86,43 +85,46 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : // CHECK-DAG: %[[CST_23:.*]] = llvm.mlir.constant(23 : i32) : i32 // CHECK: %[[THREAD_ID:.*]] = llvm.call spir_funccc @_Z12get_local_idj(%[[CST_0]]) // CHECK: %[[THREAD_ID_32:.*]] = llvm.trunc %[[THREAD_ID]] : i64 to i32 - // CHECK: %[[WARP_ID:.*]] = llvm.udiv %[[THREAD_ID_32]], %[[CST_16]] : i32 // CHECK: %[[LANE_ID:.*]] = llvm.urem %[[THREAD_ID_32]], %[[CST_16]] 
: i32 - // CHECK: %[[VAL_29:.*]] = llvm.udiv %[[WARP_ID]], %[[CST_2]] : i32 - // CHECK: %[[WARP_ID_X:.*]] = llvm.urem %[[VAL_29]], %[[CST_2]] : i32 - // CHECK: %[[ROUNDED_WARP_ID_X:.*]] = llvm.urem %[[WARP_ID_X]], %[[CST_4]] : i32 - // CHECK: %[[WARP_OFFSET:.*]] = llvm.mul %[[ROUNDED_WARP_ID_X]], %[[CST_8]] : i32 - // CHECK: %[[LANE_ID_X:.*]] = llvm.udiv %[[LANE_ID]], %[[CST_16]] : i32 - // CHECK: %[[LANE_ID_Y:.*]] = llvm.urem %[[LANE_ID]], %[[CST_16]] : i32 - // CHECK: %[[OFFSET_Y:.*]] = llvm.mul %[[LANE_ID_Y]], %[[CST_2]] : i32 - // CHECK: %[[OFFSET_x:.*]] = llvm.add %[[LANE_ID_X]], %[[WARP_OFFSET]] : i32 - // CHECK: %[[VAL_37:.*]] = llvm.urem %[[CST_0]], %[[CST_1]] : i32 - // CHECK: %[[VAL_38:.*]] = llvm.udiv %[[CST_0]], %[[CST_1]] : i32 - // CHECK: %[[VAL_39:.*]] = llvm.urem %[[VAL_38]], %[[CST_1]] : i32 - // CHECK: %[[VAL_40:.*]] = llvm.urem %[[VAL_39]], %[[CST_1]] : i32 - // CHECK: %[[VAL_41:.*]] = llvm.urem %[[VAL_37]], %[[CST_1]] : i32 - // CHECK: %[[CTA_OFFSET_X:.*]] = llvm.mul %[[VAL_40]], %[[CST_32]] : i32 - // CHECK: %[[CTA_OFFSET_Y:.*]] = llvm.mul %[[VAL_41]], %[[CST_32]] : i32 - // CHECK: %[[VAL_44:.*]] = llvm.add %[[OFFSET_x]], %[[CTA_OFFSET_X]] : i32 - // CHECK: %[[VAL_45:.*]] = llvm.add %[[OFFSET_Y]], %[[CTA_OFFSET_Y]] : i32 - // CHECK: %[[OFFSET_X_0:.*]] = llvm.add %[[VAL_44]], %[[CST_0]] : i32 - // CHECK: %[[OFFSET_Y_0:.*]] = llvm.add %[[VAL_45]], %[[CST_0]] : i32 - // CHECK: %[[OFFSET_Y_1:.*]] = llvm.add %[[VAL_45]], %[[CST_1]] : i32 - // CHECK: %[[OFFSET_X_1:.*]] = llvm.add %[[VAL_44]], %[[CST_1]] : i32 - // CHECK: %[[OFFSET_X_2:.*]] = llvm.add %[[VAL_44]], %[[CST_2]] : i32 - // CHECK: %[[OFFSET_X_3:.*]] = llvm.add %[[VAL_44]], %[[CST_3]] : i32 - // CHECK: %[[OFFSET_X_4:.*]] = llvm.add %[[VAL_44]], %[[CST_4]] : i32 - // CHECK: %[[OFFSET_X_5:.*]] = llvm.add %[[VAL_44]], %[[CST_5]] : i32 - // CHECK: %[[OFFSET_X_6:.*]] = llvm.add %[[VAL_44]], %[[CST_6]] : i32 - // CHECK: %[[OFFSET_X_7:.*]] = llvm.add %[[VAL_44]], %[[CST_7]] : i32 - // CHECK: %[[OFFSET_X_8:.*]] = llvm.add %[[VAL_44]], %[[CST_16]] : i32 - // CHECK: %[[OFFSET_X_9:.*]] = llvm.add %[[VAL_44]], %[[CST_17]] : i32 - // CHECK: %[[OFFSET_X_10:.*]] = llvm.add %[[VAL_44]], %[[CST_18]] : i32 - // CHECK: %[[OFFSET_X_11:.*]] = llvm.add %[[VAL_44]], %[[CST_19]] : i32 - // CHECK: %[[OFFSET_X_12:.*]] = llvm.add %[[VAL_44]], %[[CST_20]] : i32 - // CHECK: %[[OFFSET_X_13:.*]] = llvm.add %[[VAL_44]], %[[CST_21]] : i32 - // CHECK: %[[OFFSET_X_14:.*]] = llvm.add %[[VAL_44]], %[[CST_22]] : i32 - // CHECK: %[[OFFSET_X_15:.*]] = llvm.add %[[VAL_44]], %[[CST_23]] : i32 + // CHECK: %[[WARP_ID:.*]] = llvm.udiv %[[THREAD_ID_32]], %[[CST_16]] : i32 + // CHECK: %[[VAL_27:.*]] = llvm.and %[[LANE_ID]], %[[CST_1]] : i32 + // CHECK: %[[VAL_28:.*]] = llvm.icmp "eq" %[[VAL_27]], %[[CST_0]] : i32 + // CHECK: %[[VAL_29:.*]] = llvm.select %[[VAL_28]], %[[CST_0]], %[[CST_2]] : i1, i32 + // CHECK: %[[VAL_30:.*]] = llvm.xor %[[CST_0]], %[[VAL_29]] : i32 + // CHECK: %[[VAL_31:.*]] = llvm.and %[[LANE_ID]], %[[CST_2]] : i32 + // CHECK: %[[VAL_32:.*]] = llvm.icmp "eq" %[[VAL_31]], %[[CST_0]] : i32 + // CHECK: %[[VAL_33:.*]] = llvm.select %[[VAL_32]], %[[CST_0]], %[[CST_4]] : i1, i32 + // CHECK: %[[VAL_34:.*]] = llvm.xor %[[VAL_30]], %[[VAL_33]] : i32 + // CHECK: %[[VAL_35:.*]] = llvm.and %[[LANE_ID]], %[[CST_4]] : i32 + // CHECK: %[[VAL_36:.*]] = llvm.icmp "eq" %[[VAL_35]], %[[CST_0]] : i32 + // CHECK: %[[VAL_37:.*]] = llvm.select %[[VAL_36]], %[[CST_0]], %[[CST_8]] : i1, i32 + // CHECK: %[[VAL_38:.*]] = llvm.xor %[[VAL_34]], %[[VAL_37]] : i32 + // 
CHECK: %[[VAL_39:.*]] = llvm.and %[[LANE_ID]], %[[CST_8]] : i32 + // CHECK: %[[VAL_40:.*]] = llvm.icmp "eq" %[[VAL_39]], %[[CST_0]] : i32 + // CHECK: %[[VAL_41:.*]] = llvm.select %[[VAL_40]], %[[CST_0]], %[[CST_16]] : i1, i32 + // CHECK: %[[VAL_42:.*]] = llvm.xor %[[VAL_38]], %[[VAL_41]] : i32 + // CHECK: %[[VAL_43:.*]] = llvm.and %[[WARP_ID]], %[[CST_2]] : i32 + // CHECK: %[[VAL_44:.*]] = llvm.icmp "eq" %[[VAL_43]], %[[CST_0]] : i32 + // CHECK: %[[VAL_45:.*]] = llvm.select %[[VAL_44]], %[[CST_0]], %[[CST_8]] : i1, i32 + // CHECK: %[[VAL_46:.*]] = llvm.xor %[[CST_0]], %[[VAL_45]] : i32 + // CHECK: %[[OFFSET_X_0:.*]] = llvm.xor %[[VAL_46]], %[[CST_0]] : i32 + // CHECK: %[[OFFSET_Y_0:.*]] = llvm.xor %[[VAL_42]], %[[CST_0]] : i32 + // CHECK: %[[OFFSET_Y_1:.*]] = llvm.xor %[[VAL_42]], %[[CST_1]] : i32 + // CHECK: %[[OFFSET_X_1:.*]] = llvm.xor %[[VAL_46]], %[[CST_1]] : i32 + // CHECK: %[[OFFSET_X_2:.*]] = llvm.xor %[[VAL_46]], %[[CST_2]] : i32 + // CHECK: %[[OFFSET_X_3:.*]] = llvm.xor %[[VAL_46]], %[[CST_3]] : i32 + // CHECK: %[[OFFSET_X_4:.*]] = llvm.xor %[[VAL_46]], %[[CST_4]] : i32 + // CHECK: %[[OFFSET_X_5:.*]] = llvm.xor %[[VAL_46]], %[[CST_5]] : i32 + // CHECK: %[[OFFSET_X_6:.*]] = llvm.xor %[[VAL_46]], %[[CST_6]] : i32 + // CHECK: %[[OFFSET_X_7:.*]] = llvm.xor %[[VAL_46]], %[[CST_7]] : i32 + // CHECK: %[[OFFSET_X_8:.*]] = llvm.xor %[[VAL_46]], %[[CST_16]] : i32 + // CHECK: %[[OFFSET_X_9:.*]] = llvm.xor %[[VAL_46]], %[[CST_17]] : i32 + // CHECK: %[[OFFSET_X_10:.*]] = llvm.xor %[[VAL_46]], %[[CST_18]] : i32 + // CHECK: %[[OFFSET_X_11:.*]] = llvm.xor %[[VAL_46]], %[[CST_19]] : i32 + // CHECK: %[[OFFSET_X_12:.*]] = llvm.xor %[[VAL_46]], %[[CST_20]] : i32 + // CHECK: %[[OFFSET_X_13:.*]] = llvm.xor %[[VAL_46]], %[[CST_21]] : i32 + // CHECK: %[[OFFSET_X_14:.*]] = llvm.xor %[[VAL_46]], %[[CST_22]] : i32 + // CHECK: %[[OFFSET_X_15:.*]] = llvm.xor %[[VAL_46]], %[[CST_23]] : i32 // CHECK: llvm.call @_Z18__spirv_ocl_printf({{.*}}, {{.*}}, {{.*}}, {{.*}}, %[[OFFSET_X_0]], %[[OFFSET_Y_0]], {{.*}}, {{.*}}) // CHECK: llvm.call @_Z18__spirv_ocl_printf({{.*}}, {{.*}}, {{.*}}, {{.*}}, %[[OFFSET_X_0]], %[[OFFSET_Y_1]], {{.*}}, {{.*}}) // CHECK: llvm.call @_Z18__spirv_ocl_printf({{.*}}, {{.*}}, {{.*}}, {{.*}}, %[[OFFSET_X_1]], %[[OFFSET_Y_0]], {{.*}}, {{.*}}) @@ -172,14 +174,13 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : // CHECK-DAG: %[[CST_0:.*]] = llvm.mlir.constant(0 : i32) : i32 // CHECK-DAG: %[[CST_1:.*]] = llvm.mlir.constant(1 : i32) : i32 // CHECK-DAG: %[[CST_2:.*]] = llvm.mlir.constant(2 : i32) : i32 - // CHECK-DAG: %[[CST_4:.*]] = llvm.mlir.constant(4 : i32) : i32 - // CHECK-DAG: %[[CST_8:.*]] = llvm.mlir.constant(8 : i32) : i32 - // CHECK-DAG: %[[CST_16:.*]] = llvm.mlir.constant(16 : i32) : i32 - // CHECK-DAG: %[[CST_32:.*]] = llvm.mlir.constant(32 : i32) : i32 // CHECK-DAG: %[[CST_3:.*]] = llvm.mlir.constant(3 : i32) : i32 + // CHECK-DAG: %[[CST_4:.*]] = llvm.mlir.constant(4 : i32) : i32 // CHECK-DAG: %[[CST_5:.*]] = llvm.mlir.constant(5 : i32) : i32 // CHECK-DAG: %[[CST_6:.*]] = llvm.mlir.constant(6 : i32) : i32 // CHECK-DAG: %[[CST_7:.*]] = llvm.mlir.constant(7 : i32) : i32 + // CHECK-DAG: %[[CST_8:.*]] = llvm.mlir.constant(8 : i32) : i32 + // CHECK-DAG: %[[CST_16:.*]] = llvm.mlir.constant(16 : i32) : i32 // CHECK-DAG: %[[CST_17:.*]] = llvm.mlir.constant(17 : i32) : i32 // CHECK-DAG: %[[CST_18:.*]] = llvm.mlir.constant(18 : i32) : i32 // CHECK-DAG: %[[CST_19:.*]] = llvm.mlir.constant(19 : i32) : i32 @@ -190,34 +191,26 @@ module attributes 
{"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : // CHECK: %[[THREADS_ID:.*]] = llvm.call spir_funccc @_Z12get_local_idj(%[[CST_0]]) // CHECK: %[[THREADS_ID_32:.*]] = llvm.trunc %[[THREADS_ID]] : i64 to i32 // CHECK: %[[WARP_ID:.*]] = llvm.udiv %[[THREADS_ID_32]], %[[CST_16]] : i32 - // CHECK: %[[LANE_ID:.*]] = llvm.urem %[[THREADS_ID_32]], %[[CST_16]] : i32 - // CHECK: %[[VAL_29:.*]] = llvm.udiv %[[WARP_ID]], %[[CST_2]] : i32 - // CHECK: %[[WARP_ID_X:.*]] = llvm.urem %[[VAL_29]], %[[CST_2]] : i32 - // CHECK: %[[ROUNDED_WARP_ID_X:.*]] = llvm.urem %[[WARP_ID_X]], %[[CST_4]] : i32 - // CHECK: %[[WARP_OFFSET_X:.*]] = llvm.mul %[[ROUNDED_WARP_ID_X]], %[[CST_8]] : i32 - // CHECK: %[[LANE_OFFSET_X:.*]] = llvm.udiv %[[LANE_ID]], %[[CST_16]] : i32 - // CHECK: %[[OFFSET_X:.*]] = llvm.add %[[LANE_OFFSET_X]], %[[WARP_OFFSET_X]] : i32 - // CHECK: %[[VAL_35:.*]] = llvm.udiv %[[CST_0]], %[[CST_1]] : i32 - // CHECK: %[[VAL_36:.*]] = llvm.urem %[[VAL_35]], %[[CST_1]] : i32 - // CHECK: %[[VAL_37:.*]] = llvm.urem %[[VAL_36]], %[[CST_1]] : i32 - // CHECK: %[[CTA_OFFSET_X:.*]] = llvm.mul %[[VAL_37]], %[[CST_32]] : i32 - // CHECK: %[[VAL_39:.*]] = llvm.add %[[OFFSET_X]], %[[CTA_OFFSET_X]] : i32 - // CHECK: %[[OFFSET_X_0:.*]] = llvm.add %[[VAL_39]], %[[CST_0]] : i32 - // CHECK: %[[OFFSET_X_1:.*]] = llvm.add %[[VAL_39]], %[[CST_1]] : i32 - // CHECK: %[[OFFSET_X_2:.*]] = llvm.add %[[VAL_39]], %[[CST_2]] : i32 - // CHECK: %[[OFFSET_X_3:.*]] = llvm.add %[[VAL_39]], %[[CST_3]] : i32 - // CHECK: %[[OFFSET_X_4:.*]] = llvm.add %[[VAL_39]], %[[CST_4]] : i32 - // CHECK: %[[OFFSET_X_5:.*]] = llvm.add %[[VAL_39]], %[[CST_5]] : i32 - // CHECK: %[[OFFSET_X_6:.*]] = llvm.add %[[VAL_39]], %[[CST_6]] : i32 - // CHECK: %[[OFFSET_X_7:.*]] = llvm.add %[[VAL_39]], %[[CST_7]] : i32 - // CHECK: %[[OFFSET_X_8:.*]] = llvm.add %[[VAL_39]], %[[CST_16]] : i32 - // CHECK: %[[OFFSET_X_9:.*]] = llvm.add %[[VAL_39]], %[[CST_17]] : i32 - // CHECK: %[[OFFSET_X_10:.*]] = llvm.add %[[VAL_39]], %[[CST_18]] : i32 - // CHECK: %[[OFFSET_X_11:.*]] = llvm.add %[[VAL_39]], %[[CST_19]] : i32 - // CHECK: %[[OFFSET_X_12:.*]] = llvm.add %[[VAL_39]], %[[CST_20]] : i32 - // CHECK: %[[OFFSET_X_13:.*]] = llvm.add %[[VAL_39]], %[[CST_21]] : i32 - // CHECK: %[[OFFSET_X_14:.*]] = llvm.add %[[VAL_39]], %[[CST_22]] : i32 - // CHECK: %[[OFFSET_X_15:.*]] = llvm.add %[[VAL_39]], %[[CST_23]] : i32 + // CHECK: %[[VAL_26:.*]] = llvm.and %[[WARP_ID]], %[[CST_2]] : i32 + // CHECK: %[[VAL_27:.*]] = llvm.icmp "eq" %[[VAL_26]], %[[CST_0]] : i32 + // CHECK: %[[VAL_28:.*]] = llvm.select %[[VAL_27]], %[[CST_0]], %[[CST_8]] : i1, i32 + // CHECK: %[[VAL_29:.*]] = llvm.xor %[[CST_0]], %[[VAL_28]] : i32 + // CHECK: %[[OFFSET_X_0:.*]] = llvm.xor %[[VAL_29]], %[[CST_0]] : i32 + // CHECK: %[[OFFSET_X_1:.*]] = llvm.xor %[[VAL_29]], %[[CST_1]] : i32 + // CHECK: %[[OFFSET_X_2:.*]] = llvm.xor %[[VAL_29]], %[[CST_2]] : i32 + // CHECK: %[[OFFSET_X_3:.*]] = llvm.xor %[[VAL_29]], %[[CST_3]] : i32 + // CHECK: %[[OFFSET_X_4:.*]] = llvm.xor %[[VAL_29]], %[[CST_4]] : i32 + // CHECK: %[[OFFSET_X_5:.*]] = llvm.xor %[[VAL_29]], %[[CST_5]] : i32 + // CHECK: %[[OFFSET_X_6:.*]] = llvm.xor %[[VAL_29]], %[[CST_6]] : i32 + // CHECK: %[[OFFSET_X_7:.*]] = llvm.xor %[[VAL_29]], %[[CST_7]] : i32 + // CHECK: %[[OFFSET_X_8:.*]] = llvm.xor %[[VAL_29]], %[[CST_16]] : i32 + // CHECK: %[[OFFSET_X_9:.*]] = llvm.xor %[[VAL_29]], %[[CST_17]] : i32 + // CHECK: %[[OFFSET_X_10:.*]] = llvm.xor %[[VAL_29]], %[[CST_18]] : i32 + // CHECK: %[[OFFSET_X_11:.*]] = llvm.xor %[[VAL_29]], %[[CST_19]] : i32 + // CHECK: 
%[[OFFSET_X_12:.*]] = llvm.xor %[[VAL_29]], %[[CST_20]] : i32 + // CHECK: %[[OFFSET_X_13:.*]] = llvm.xor %[[VAL_29]], %[[CST_21]] : i32 + // CHECK: %[[OFFSET_X_14:.*]] = llvm.xor %[[VAL_29]], %[[CST_22]] : i32 + // CHECK: %[[OFFSET_X_15:.*]] = llvm.xor %[[VAL_29]], %[[CST_23]] : i32 // CHECK: %[[VAL_56:.*]] = llvm.call @_Z18__spirv_ocl_printf({{.*}}, {{.*}}, {{.*}}, {{.*}}, %[[OFFSET_X_0]], {{.*}}, {{.*}}) // CHECK: %[[VAL_57:.*]] = llvm.call @_Z18__spirv_ocl_printf({{.*}}, {{.*}}, {{.*}}, {{.*}}, %[[OFFSET_X_1]], {{.*}}, {{.*}}) // CHECK: %[[VAL_58:.*]] = llvm.call @_Z18__spirv_ocl_printf({{.*}}, {{.*}}, {{.*}}, {{.*}}, %[[OFFSET_X_2]], {{.*}}, {{.*}}) diff --git a/third_party/intel/lib/Dialect/TritonIntelGPU/IR/LinearLayoutConversions.cpp b/third_party/intel/lib/Dialect/TritonIntelGPU/IR/LinearLayoutConversions.cpp index 4ee77e934d..6b902003fb 100644 --- a/third_party/intel/lib/Dialect/TritonIntelGPU/IR/LinearLayoutConversions.cpp +++ b/third_party/intel/lib/Dialect/TritonIntelGPU/IR/LinearLayoutConversions.cpp @@ -341,38 +341,44 @@ LinearLayout combineCtaCgaWithShape(LinearLayout ctaLayout, } // anonymous namespace +// clang-format off // The layout example repeat_count=8, systolic_depth=8, // execution_size=16 and operands_per_chan=2 for warp size 32. // For A operand: -// systolic depth = 8 -//<-----------------------------------------------------> -// opsPerChan=2 -//<---------> -// t0 ... t0 t1 ... t1 ~ t6 ... t6 t7 ... t7 ^ -// t8 ... t8 t9 ... t9 ~ t14 ... t14 t15 ... t15 | -// t16 ... t16 t17 ... t17 ~ t22 ... t22 t23 ... t23 | -// t24 ... t24 t25 ... t25 ~ t30 ... t30 t31 ... t31 | repeat count <= 8 -// t0 ... t0 t1 ... t1 ~ t6 ... t6 t7 ... t7 | -// t8 ... t8 t9 ... t9 ~ t14 ... t14 t15 ... t15 | -// t16 ... t16 t17 ... t17 ~ t22 ... t22 t23 ... t23 | -// t24 ... t24 t25 ... t25 ~ t30 ... t30 t31 ... t31 v +// K = 16 (K = systolic depth * opsPerChan) +// <----------------------------------------------------------------------------> +// t0 t1 t2 t3 t4 t5 t6 t7 t8 t9 t10 t11 t12 t13 t14 t15 ^ +// t16 t17 t18 t19 t20 t21 t22 t23 t24 t25 t26 t27 t28 t29 t30 t31 | +// t0 t1 t2 t3 t4 t5 t6 t7 t8 t9 t10 t11 t12 t13 t14 t15 | +// t16 t17 t18 t19 t20 t21 t22 t23 t24 t25 t26 t27 t28 t29 t30 t31 | +// t0 t1 t2 t3 t4 t5 t6 t7 t8 t9 t10 t11 t12 t13 t14 t15 | M = 8 (repeat count) +// t16 t17 t18 t19 t20 t21 t22 t23 t24 t25 t26 t27 t28 t29 t30 t31 | +// t0 t1 t2 t3 t4 t5 t6 t7 t8 t9 t10 t11 t12 t13 t14 t15 | +// t16 t17 t18 t19 t20 t21 t22 t23 t24 t25 t26 t27 t28 t29 t30 t31 v // In this case, the LinearLayout bases are: -// Register: {{0,1}, {4,0}} -// Lane: {{0,2}, {0,4}, {0,8}, {1,0}, {2,0}} +// Register: {{2,0}, {4,0}} +// Lane: {{0,1}, {0,2}, {0,4}, {0,8}, {1,0}} +// clang-format on std::vector> DPASRegBasesA(int opsPerChannel, int repeatCount, int threadsPerWarp, int systolicDepth) { - int rowPerWarp = threadsPerWarp / systolicDepth; - int warpRepeats = repeatCount / rowPerWarp; std::vector> regBases; - for (int opc = 1; opc < opsPerChannel; opc *= 2) { + // pack the value to i16 for scalar bit width <=16. + assert((opsPerChannel == 4 || opsPerChannel == 2 || opsPerChannel == 1) && + "invalid opsPerChannel number."); + int packedOpsPerLane = opsPerChannel == 4 ? 
2 : 1; + int packedColNum = (systolicDepth * opsPerChannel) / packedOpsPerLane; + int rowsPerWarp = mlir::ceil(threadsPerWarp, packedColNum); + int warpRepeats = repeatCount / rowsPerWarp; + + for (int opc = 1; opc < packedOpsPerLane; opc *= 2) { regBases.push_back({0, opc}); } for (int warp = 1; warp < warpRepeats; warp *= 2) { - regBases.push_back({warp * rowPerWarp, 0}); + regBases.push_back({warp * rowsPerWarp, 0}); } return regBases; @@ -382,11 +388,17 @@ std::vector> DPASLaneBasesA(int opsPerChannel, int threadsPerWarp, int systolicDepth) { std::vector> laneBases; - for (int tid = 1; tid < systolicDepth; tid *= 2) { - laneBases.push_back({0, opsPerChannel * tid}); + // pack the value to i16 for scalar bit width <=16. + assert((opsPerChannel == 4 || opsPerChannel == 2 || opsPerChannel == 1) && + "invalid opsPerChannel number."); + int packedOpsPerLane = opsPerChannel == 4 ? 2 : 1; + int packedColNum = (systolicDepth * opsPerChannel) / packedOpsPerLane; + + for (int tid = 1; tid < packedColNum; tid *= 2) { + laneBases.push_back({0, packedOpsPerLane * tid}); } - for (int tid = systolicDepth; tid < threadsPerWarp; tid *= 2) { - laneBases.push_back({tid / systolicDepth, 0}); + for (int tid = packedColNum; tid < threadsPerWarp; tid *= 2) { + laneBases.push_back({tid / packedColNum, 0}); } return laneBases; @@ -602,8 +614,7 @@ std::optional dotOperandDpasToLinearLayout(DotOperandEncodingAttr dotDpasLayout, ArrayRef shape) { auto dpasLayout = cast(dotDpasLayout.getParent()); - if (dotDpasLayout.getOpIdx() == 0) - return std::nullopt; + return DPAStoLinearLayout(shape, dpasLayout, dotDpasLayout.getOpIdx()); } diff --git a/third_party/intel/unittest/Dialect/TritonGPU/DPAStoLinearLayoutTest.cpp b/third_party/intel/unittest/Dialect/TritonGPU/DPAStoLinearLayoutTest.cpp index 6d42c9948a..d4f6d0b821 100644 --- a/third_party/intel/unittest/Dialect/TritonGPU/DPAStoLinearLayoutTest.cpp +++ b/third_party/intel/unittest/Dialect/TritonGPU/DPAStoLinearLayoutTest.cpp @@ -59,17 +59,47 @@ TEST_F(DPAStoLinearLayoutTest, DPAS_perInst) { }, {S("dim0"), S("dim1")})); // Test Operand A (opIdx=0) + EXPECT_EQ( + DPAStoLinearLayout({8, 32}, dpas({1, 1}, 8, 8, 16, 4, {1, 1}, 32), 0), + LinearLayout( + { + {S("register"), {{0, 1}, {2, 0}, {4, 0}}}, + {S("lane"), {{0, 2}, {0, 4}, {0, 8}, {0, 16}, {1, 0}}}, + {S("warp"), {}}, + {S("block"), {}}, + }, + {S("dim0"), S("dim1")})); EXPECT_EQ( DPAStoLinearLayout({8, 16}, dpas({1, 1}, 8, 8, 16, 2, {1, 1}, 32), 0), LinearLayout( { - {S("register"), {{0, 1}, {4, 0}}}, - {S("lane"), {{0, 2}, {0, 4}, {0, 8}, {1, 0}, {2, 0}}}, + {S("register"), {{2, 0}, {4, 0}}}, + {S("lane"), {{0, 1}, {0, 2}, {0, 4}, {0, 8}, {1, 0}}}, + {S("warp"), {}}, + {S("block"), {}}, + }, + {S("dim0"), S("dim1")})); + EXPECT_EQ( + DPAStoLinearLayout({8, 8}, dpas({1, 1}, 8, 8, 16, 1, {1, 1}, 32), 0), + LinearLayout( + { + {S("register"), {{4, 0}}}, + {S("lane"), {{0, 1}, {0, 2}, {0, 4}, {1, 0}, {2, 0}}}, {S("warp"), {}}, {S("block"), {}}, }, {S("dim0"), S("dim1")})); // Test Operand B (opIdx=1) + EXPECT_EQ( + DPAStoLinearLayout({32, 16}, dpas({1, 1}, 8, 8, 16, 4, {1, 1}, 32), 1), + LinearLayout( + { + {S("register"), {{1, 0}, {2, 0}, {8, 0}, {16, 0}}}, + {S("lane"), {{0, 1}, {0, 2}, {0, 4}, {0, 8}, {4, 0}}}, + {S("warp"), {}}, + {S("block"), {}}, + }, + {S("dim0"), S("dim1")})); EXPECT_EQ( DPAStoLinearLayout({16, 16}, dpas({1, 1}, 8, 8, 16, 2, {1, 1}, 32), 1), LinearLayout( @@ -80,6 +110,16 @@ TEST_F(DPAStoLinearLayoutTest, DPAS_perInst) { {S("block"), {}}, }, {S("dim0"), S("dim1")})); + EXPECT_EQ( + 
DPAStoLinearLayout({8, 16}, dpas({1, 1}, 8, 8, 16, 1, {1, 1}, 32), 1), + LinearLayout( + { + {S("register"), {{2, 0}, {4, 0}}}, + {S("lane"), {{0, 1}, {0, 2}, {0, 4}, {0, 8}, {1, 0}}}, + {S("warp"), {}}, + {S("block"), {}}, + }, + {S("dim0"), S("dim1")})); } TEST_F(DPAStoLinearLayoutTest, DPAS_withRepCluster) { @@ -98,8 +138,8 @@ TEST_F(DPAStoLinearLayoutTest, DPAS_withRepCluster) { DPAStoLinearLayout({32, 16}, dpas({1, 1}, 8, 8, 16, 2, {4, 2}, 32), 0), LinearLayout( { - {S("register"), {{0, 1}, {4, 0}, {8, 0}, {16, 0}}}, - {S("lane"), {{0, 2}, {0, 4}, {0, 8}, {1, 0}, {2, 0}}}, + {S("register"), {{2, 0}, {4, 0}, {8, 0}, {16, 0}}}, + {S("lane"), {{0, 1}, {0, 2}, {0, 4}, {0, 8}, {1, 0}}}, {S("warp"), {}}, {S("block"), {}}, }, @@ -154,8 +194,8 @@ TEST_F(DPAStoLinearLayoutTest, DPAS_withWarpOperandA) { LinearLayout( { {S("register"), - {{0, 1}, {4, 0}, {8, 0}, {16, 0}, {0, 16}, {0, 32}}}, - {S("lane"), {{0, 2}, {0, 4}, {0, 8}, {1, 0}, {2, 0}}}, + {{2, 0}, {4, 0}, {8, 0}, {16, 0}, {0, 16}, {0, 32}}}, + {S("lane"), {{0, 1}, {0, 2}, {0, 4}, {0, 8}, {1, 0}}}, {S("warp"), {{0, 0}, {32, 0}}}, {S("block"), {}}, },