Skip to content

Commit

Permalink
Generate 2D prefetch with array length = 1 (#1775)
Browse files Browse the repository at this point in the history
A GEMM performance improvement of up to 5% is observed when using prefetch with
an array length equal to 1, e.g., changing from `16b_4r16x2c` to
`16b_4r32x1c`. An IGC ticket has been created to track the optimization. In the
meantime, this PR allows generation of `16b_?r32x1c` and `8b_?r64x1c`
when the environment variable `TRITON_INTEL_ENABLE_FAST_PREFETCH` is set.
Note: OpenCL C doesn't support `16b_?r32x1c` and `8b_?r64x1c`, so these shapes are lowered through the GenISA 2D block prefetch path instead.

Signed-off-by: Whittney Tsang <[email protected]>
  • Loading branch information
whitneywhtsang authored Aug 5, 2024
1 parent ef59f97 commit a71e7be
Show file tree
Hide file tree
Showing 6 changed files with 45 additions and 19 deletions.
1 change: 1 addition & 0 deletions include/triton/Tools/Sys/GetEnv.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@ inline const std::set<std::string> CACHE_INVALIDATING_ENV_VARS = {
"NVPTX_ENABLE_DUMP",
"TRITON_INTEL_ADVANCED_PATH",
"TRITON_INTEL_ENABLE_ADDRESS_PAYLOAD_OPT",
"TRITON_INTEL_ENABLE_FAST_PREFETCH"
// clang-format on
};

Expand Down
2 changes: 2 additions & 0 deletions scripts/test-triton.sh
Original file line number Diff line number Diff line change
Expand Up @@ -261,7 +261,9 @@ run_benchmark_gemm() {
if [ ! -d "${BENCHMARK_TEST_DIR}" ]; then
echo "Not found '${BENCHMARK_TEST_DIR}'." ; exit 5
fi
cd $TRITON_PROJ/benchmarks; python setup.py install
TRITON_INTEL_ADVANCED_PATH=0 \
TRITON_INTEL_ENABLE_FAST_PREFETCH=1 \
TRITON_INTEL_ENABLE_ADDRESS_PAYLOAD_OPT=1 \
IGC_VISAOptions=" -TotalGRFNum 256 -enableBCR -nolocalra -printregusage -DPASTokenReduction -enableHalfLSC -abiver 2" \
IGC_DisableLoopUnroll=1 \
Expand Down
13 changes: 13 additions & 0 deletions test/TritonIntelGPU/prefetch-to-llvm.mlir
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
// RUN: triton-opt %s --intel-allocate-shared-memory --convert-triton-intel-gpu-to-llvm | FileCheck %s --implicit-check-not=llvm.inline_asm
// RUN: TRITON_INTEL_ENABLE_FAST_PREFETCH=1 triton-opt %s --intel-allocate-shared-memory --convert-triton-intel-gpu-to-llvm | FileCheck %s --implicit-check-not=llvm.inline_asm --check-prefix=FAST

// CHECK-DAG: llvm.func spir_funccc @_Z38intel_sub_group_f16_f16_matrix_mad_k16Dv8_sDv8_iDv8_f(vector<8xi16>, vector<8xi32>, vector<8xf32>) -> vector<8xf32> attributes {passthrough = ["convergent"]}
// CHECK-DAG: llvm.func spir_funccc @_Z41intel_sub_group_2d_block_read_16b_8r16x1cPU3AS1viiiDv2_iPt(!llvm.ptr<1> {llvm.nonnull, llvm.readonly}, i32, i32, i32, vector<2xi32>, !llvm.ptr {llvm.nonnull, llvm.writeonly}) attributes {passthrough = ["nounwind"]}
Expand All @@ -12,6 +13,18 @@
module attributes {"triton_gpu.num-warps" = 8 : i32, "triton_gpu.threads-per-warp" = 16 : i32} {
tt.func public @matmul_with_prefetch(%arg0: !tt.ptr<f16>, %arg1: !tt.ptr<f16>, %arg2: i64, %arg3: i64, %arg4: i64, %arg5: i64, %arg6: i64) {
// CHECK-LABEL: @matmul_with_prefetch
// FAST: [[C16:%.*]] = llvm.mlir.constant(16 : i32) : i32
// FAST: [[C16:%.*]] = llvm.mlir.constant(16 : i32) : i32
// FAST: [[C32:%.*]] = llvm.mlir.constant(32 : i32) : i32
// FAST: [[C4:%.*]] = llvm.mlir.constant(4 : i32) : i32
// FAST: [[C1:%.*]] = llvm.mlir.constant(1 : i32) : i32
// FAST: llvm.call spir_funccc @llvm.genx.GenISA.LSC2DBlockPrefetch.isVoid({{.*}}, {{.*}}, {{.*}}, {{.*}}, {{.*}}, {{.*}}, [[C16]], [[C32]], [[C4]], [[C1]], {{.*}}, {{.*}}, {{.*}}) {{.*}} : (i64, i32, i32, i32, i32, i32, i32, i32, i32, i32, i1, i1, i32) -> ()
// FAST: [[C16:%.*]] = llvm.mlir.constant(16 : i32) : i32
// FAST: [[C16:%.*]] = llvm.mlir.constant(16 : i32) : i32
// FAST: [[C32:%.*]] = llvm.mlir.constant(32 : i32) : i32
// FAST: [[C2:%.*]] = llvm.mlir.constant(2 : i32) : i32
// FAST: [[C1:%.*]] = llvm.mlir.constant(1 : i32) : i32
// FAST: llvm.call spir_funccc @llvm.genx.GenISA.LSC2DBlockPrefetch.isVoid({{.*}}, {{.*}}, {{.*}}, {{.*}}, {{.*}}, {{.*}}, [[C16]], [[C32]], [[C2]], [[C1]], {{.*}}, {{.*}}, {{.*}}) {{.*}} : (i64, i32, i32, i32, i32, i32, i32, i32, i32, i32, i1, i1, i32) -> ()
// CHECK: llvm.call spir_funccc @_Z45intel_sub_group_2d_block_prefetch_16b_4r16x2cPU3AS1viiiDv2_i({{.*}}) {{.*}} : (!llvm.ptr<1>, i32, i32, i32, vector<2xi32>) -> ()
// CHECK: llvm.call spir_funccc @_Z45intel_sub_group_2d_block_prefetch_16b_2r16x2cPU3AS1viiiDv2_i({{.*}}) {{.*}} : (!llvm.ptr<1>, i32, i32, i32, vector<2xi32>) -> ()
// CHECK: llvm.call spir_funccc @_Z41intel_sub_group_2d_block_read_16b_8r16x1cPU3AS1viiiDv2_iPt({{.*}}) {{.*}} : (!llvm.ptr<1>, i32, i32, i32, vector<2xi32>, !llvm.ptr) -> ()
Expand Down
8 changes: 6 additions & 2 deletions third_party/intel/lib/Dialect/TritonGEN/IR/TritonGENOps.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
#include "intel/include/Dialect/TritonGEN/IR/TritonGENDialect.h"
#include "mlir/Dialect/SPIRV/IR/TargetAndABI.h"
#include "mlir/Dialect/Utils/StaticValueUtils.h"
#include "triton/Tools/Sys/GetEnv.hpp"
#include "llvm/ADT/STLExtras.h"
#include <cstdint>

Expand Down Expand Up @@ -268,15 +269,18 @@ LogicalResult TritonGEN::Matrix2DBlockPrefetchOp::verify() {
if (verifyMatrixInput(*this).failed())
return failure();

const bool enableFastPrefetch =
tools::getBoolEnv("TRITON_INTEL_ENABLE_FAST_PREFETCH");
uint32_t tileWidth = getTileWidth();
switch (getElemSizeInBits()) {
case 8:
if (tileWidth != 16 && tileWidth != 32)
if (tileWidth != 16 && tileWidth != 32 &&
!(enableFastPrefetch && tileWidth == 64))
return emitOpError("tile_width for 8 bit elements should be equal to "
"16 or 32");
break;
case 16:
if (tileWidth != 16)
if (tileWidth != 16 && !(enableFastPrefetch && tileWidth == 32))
return emitOpError("tile_width for 16 bit elements should be equal "
"to 16");
break;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1333,7 +1333,11 @@ struct TritonMatrix2DBlockPrefetchLowering
ConversionPatternRewriter &rewriter) const override {
// TODO: Remove GenISA lowering after PoC productization is completed.
char *env = std::getenv("TRITONGEN_FORCE_GENISA");
const bool useGenISA = env ? (bool)std::atoi(env) : false;
bool useGenISA = env ? (bool)std::atoi(env) : false;
if (tools::getBoolEnv("TRITON_INTEL_ENABLE_FAST_PREFETCH") &&
((op.getElemSizeInBits() == 8 && op.getTileWidth() == 64) ||
(op.getElemSizeInBits() == 16 && op.getTileWidth() == 32)))
useGenISA = true;
if (useGenISA) {
rewriter.replaceOp(op, createGenISA2DBlockPrefetch(op, rewriter));
return success();
Expand Down
34 changes: 18 additions & 16 deletions third_party/intel/lib/TritonIntelGPUToLLVM/LoadStoreOpToLLVM.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -251,23 +251,25 @@ struct PrefetchOpConversion
unsigned tileWidthInElem = shapePerWarp[1];
unsigned tileHeightInElem = shapePerWarp[0];
unsigned vBlocks = 1;
switch (elemSizeInBits) {
case 8:
if (tileWidthInElem == 64) {
// OCL interface supports 8b_?r32x2c for 64 bytes per row of 8 bits
// element.
vBlocks = 2;
tileWidthInElem = 32;
}
break;
case 16:
if (tileWidthInElem == 32) {
// OCL interface supports 16b_?r16x2c for 64 bytes per row of 16 bits
// element.
vBlocks = 2;
tileWidthInElem = 16;
if (!tools::getBoolEnv("TRITON_INTEL_ENABLE_FAST_PREFETCH")) {
switch (elemSizeInBits) {
case 8:
if (tileWidthInElem == 64) {
// OCL interface supports 8b_?r32x2c for 64 bytes per row of 8 bits
// element.
vBlocks = 2;
tileWidthInElem = 32;
}
break;
case 16:
if (tileWidthInElem == 32) {
// OCL interface supports 16b_?r16x2c for 64 bytes per row of 16 bits
// element.
vBlocks = 2;
tileWidthInElem = 16;
}
break;
}
break;
}

Value warpId = rewriter.create<arith::IndexCastOp>(
Expand Down

0 comments on commit a71e7be

Please sign in to comment.