Deprecate fast prefetch implementation (#2280)
The optimization is implemented in IGC, and it has already been disabled in
Triton for one Agama driver release, so this PR removes the Triton-side
implementation.

Signed-off-by: Whitney Tsang <[email protected]>
whitneywhtsang authored Sep 20, 2024
1 parent 9621c64 commit d76446a
Showing 5 changed files with 18 additions and 41 deletions.
1 change: 0 additions & 1 deletion include/triton/Tools/Sys/GetEnv.hpp
@@ -33,7 +33,6 @@ inline const std::set<std::string> CACHE_INVALIDATING_ENV_VARS = {
"TRITON_INTEL_ADVANCED_PATH",
"TRITON_INTEL_DO_NOT_SINK_INSTR_ACROSS_RGN",
"TRITON_INTEL_ENABLE_ADDRESS_PAYLOAD_OPT",
"TRITON_INTEL_ENABLE_FAST_PREFETCH",
"TRITON_INTEL_ENABLE_FIRST_LOAD_TO_SLM",
"TRITON_INTEL_ENABLE_INSTR_SCHED",
"TRITON_INTEL_ENABLE_POST_PROCESS_LLIR",
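For context on why the variable had to be listed here at all: entries in CACHE_INVALIDATING_ENV_VARS are folded into Triton's kernel cache key, so a variable that no longer influences code generation can simply be dropped from the set. A minimal sketch of that pattern follows; the helper name cacheKeySuffix and the "<unset>" placeholder are made up for illustration, this is not the actual Triton cache code.

#include <cstdlib>
#include <set>
#include <string>

// Fold every cache-invalidating environment variable into a string that can
// be hashed into the compilation cache key. Any variable that can change the
// generated code must appear here, otherwise a stale cache entry could be
// reused with the wrong setting.
std::string cacheKeySuffix(const std::set<std::string> &vars) {
  std::string suffix;
  for (const std::string &name : vars) {
    const char *value = std::getenv(name.c_str());
    suffix += name + "=" + (value ? value : "<unset>") + ";";
  }
  return suffix;
}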
13 changes: 0 additions & 13 deletions test/TritonIntelGPU/prefetch-to-llvm.mlir
@@ -1,5 +1,4 @@
// RUN: triton-opt %s -split-input-file --intel-allocate-shared-memory --convert-triton-intel-gpu-to-llvm | FileCheck %s --implicit-check-not=llvm.inline_asm
- // RUN: TRITON_INTEL_ENABLE_FAST_PREFETCH=1 triton-opt %s -split-input-file --intel-allocate-shared-memory --convert-triton-intel-gpu-to-llvm | FileCheck %s --implicit-check-not=llvm.inline_asm --check-prefix=FAST

// CHECK-DAG: llvm.func spir_funccc @_Z45intel_sub_group_2d_block_prefetch_16b_4r16x2cPU3AS1viiiDv2_i(!llvm.ptr<1> {llvm.nonnull}, i32, i32, i32, vector<2xi32>) attributes {memory_effects = #llvm.memory_effects<other = none, argMem = read, inaccessibleMem = none>, no_unwind}
module attributes {"triton_gpu.num-warps" = 8 : i32, "triton_gpu.threads-per-warp" = 16 : i32} {
@@ -8,12 +7,6 @@ module attributes {"triton_gpu.num-warps" = 8 : i32, "triton_gpu.threads-per-war
%c0_i32 = arith.constant 0 : i32
%c1_i64 = arith.constant 1 : i64

- // FAST: [[C16:%.*]] = llvm.mlir.constant(16 : i32) : i32
- // FAST: [[C16:%.*]] = llvm.mlir.constant(16 : i32) : i32
- // FAST: [[C32:%.*]] = llvm.mlir.constant(32 : i32) : i32
- // FAST: [[C4:%.*]] = llvm.mlir.constant(4 : i32) : i32
- // FAST: [[C1:%.*]] = llvm.mlir.constant(1 : i32) : i32
- // FAST: llvm.call spir_funccc @llvm.genx.GenISA.LSC2DBlockPrefetch.isVoid({{.*}}, {{.*}}, {{.*}}, {{.*}}, {{.*}}, {{.*}}, [[C16]], [[C32]], [[C4]], [[C1]], {{.*}}, {{.*}}, {{.*}}) {{.*}} : (i64, i32, i32, i32, i32, i32, i32, i32, i32, i32, i1, i1, i32) -> ()
// CHECK: %[[ROW_MAJOR_BLOCK_PTR:.*]] = llvm.insertvalue %arg0, {{.*}}[6] : !llvm.struct<(i32, i32, i64, i64, i64, i64, ptr<1>)>
// CHECK: %[[VAL_17:.*]] = llvm.call spir_funccc @_Z16get_sub_group_idv()
// CHECK: %[[VAL_18:.*]] = llvm.sext %[[VAL_17]] : i32 to i64
@@ -49,12 +42,6 @@ module attributes {"triton_gpu.num-warps" = 8 : i32, "triton_gpu.threads-per-war

// COM: The memory layout is same for the column major memory and row major memory. The prefetch function should be the same.

- // FAST: [[C16:%.*]] = llvm.mlir.constant(16 : i32) : i32
- // FAST: [[C16:%.*]] = llvm.mlir.constant(16 : i32) : i32
- // FAST: [[C32:%.*]] = llvm.mlir.constant(32 : i32) : i32
- // FAST: [[C4:%.*]] = llvm.mlir.constant(4 : i32) : i32
- // FAST: [[C1:%.*]] = llvm.mlir.constant(1 : i32) : i32
- // FAST: llvm.call spir_funccc @llvm.genx.GenISA.LSC2DBlockPrefetch.isVoid({{.*}}, {{.*}}, {{.*}}, {{.*}}, {{.*}}, {{.*}}, [[C16]], [[C32]], [[C4]], [[C1]], {{.*}}, {{.*}}, {{.*}}) {{.*}} : (i64, i32, i32, i32, i32, i32, i32, i32, i32, i32, i1, i1, i32) -> ()
// CHECK: %[[COLUMN_MAJOR_BLOCK_PTR:.*]] = llvm.insertvalue %arg1, {{.*}}[6] : !llvm.struct<(i32, i32, i64, i64, i64, i64, ptr<1>)>
// CHECK: %[[COLUMN_MAJOR_OFFSET_Y:.*]] = llvm.extractvalue %[[COLUMN_MAJOR_BLOCK_PTR]][0] : !llvm.struct<(i32, i32, i64, i64, i64, i64, ptr<1>)>
// CHECK: %[[COLUMN_MAJOR_OFFSET_X:.*]] = llvm.extractvalue %[[COLUMN_MAJOR_BLOCK_PTR]][1] : !llvm.struct<(i32, i32, i64, i64, i64, i64, ptr<1>)>
7 changes: 2 additions & 5 deletions third_party/intel/lib/Dialect/TritonGEN/IR/TritonGENOps.cpp
@@ -421,18 +421,15 @@ LogicalResult TritonGEN::Matrix2DBlockPrefetchOp::verify() {
if (verifyMatrixInput(*this).failed())
return failure();

- const bool enableFastPrefetch =
-     tools::getBoolEnv("TRITON_INTEL_ENABLE_FAST_PREFETCH");
uint32_t tileWidth = getTileWidth();
switch (getElemSizeInBits()) {
case 8:
- if (tileWidth != 16 && tileWidth != 32 &&
-     !(enableFastPrefetch && tileWidth == 64))
+ if (tileWidth != 16 && tileWidth != 32)
return emitOpError("tile_width for 8 bit elements should be equal to "
"16 or 32");
break;
case 16:
- if (tileWidth != 16 && !(enableFastPrefetch && tileWidth == 32))
+ if (tileWidth != 16)
return emitOpError("tile_width for 16 bit elements should be equal "
"to 16");
break;
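With the environment-variable escape hatch removed, the verifier reduces to a fixed table of legal tile widths per element size. The following is a standalone sketch of the tightened check, simplified from the verify() method above; the function name isValidPrefetchTileWidth is illustrative, and the real code reports failures through emitOpError() rather than returning a bool.

#include <cstdint>

bool isValidPrefetchTileWidth(uint32_t elemSizeInBits, uint32_t tileWidth) {
  switch (elemSizeInBits) {
  case 8:
    return tileWidth == 16 || tileWidth == 32; // 64 is no longer accepted
  case 16:
    return tileWidth == 16; // 32 is no longer accepted
  default:
    return true; // other element sizes are validated by checks not shown here
  }
}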
4 changes: 0 additions & 4 deletions third_party/intel/lib/TritonGENToLLVM/TritonGENToLLVMPass.cpp
@@ -1106,10 +1106,6 @@ struct TritonMatrix2DBlockPrefetchLowering
ConversionPatternRewriter &rewriter) const override {
// TODO: Remove GenISA lowering after PoC productization is completed.
bool useGenISA = tools::getBoolEnv("TRITONGEN_FORCE_GENISA");
- if (tools::getBoolEnv("TRITON_INTEL_ENABLE_FAST_PREFETCH") &&
-     ((op.getElemSizeInBits() == 8 && op.getTileWidth() == 64) ||
-      (op.getElemSizeInBits() == 16 && op.getTileWidth() == 32)))
-   useGenISA = true;
if (useGenISA) {
rewriter.replaceOp(op, createGenISA2DBlockPrefetch(op, rewriter));
return success();
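After this change the GenISA intrinsic path can no longer be selected by a prefetch shape; it is taken only when TRITONGEN_FORCE_GENISA is set. A self-contained sketch of the remaining decision is below; the function name useGenISALowering and the env-var parsing rules are assumptions for illustration, since the pass itself goes through the tools::getBoolEnv helper.

#include <cstdlib>
#include <string>

bool useGenISALowering() {
  // Assumed semantics: treat an unset, empty, "0", or "false" value as off.
  const char *value = std::getenv("TRITONGEN_FORCE_GENISA");
  if (!value)
    return false;
  std::string v(value);
  return !v.empty() && v != "0" && v != "false";
}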
34 changes: 16 additions & 18 deletions third_party/intel/lib/TritonIntelGPUToLLVM/LoadStoreOpToLLVM.cpp
@@ -276,25 +276,23 @@ struct PrefetchOpConversion
unsigned tileWidthInElem = shapePerWarp[1];
unsigned tileHeightInElem = shapePerWarp[0];
unsigned vBlocks = 1;
- if (!tools::getBoolEnv("TRITON_INTEL_ENABLE_FAST_PREFETCH")) {
-   switch (elemSizeInBits) {
-   case 8:
-     if (tileWidthInElem == 64) {
-       // OCL interface supports 8b_?r32x2c for 64 bytes per row of 8 bits
-       // element.
-       vBlocks = 2;
-       tileWidthInElem = 32;
-     }
-     break;
-   case 16:
-     if (tileWidthInElem == 32) {
-       // OCL interface supports 16b_?r16x2c for 64 bytes per row of 16 bits
-       // element.
-       vBlocks = 2;
-       tileWidthInElem = 16;
-     }
-     break;
+ switch (elemSizeInBits) {
+ case 8:
+   if (tileWidthInElem == 64) {
+     // OCL interface supports 8b_?r32x2c for 64 bytes per row of 8 bits
+     // element.
+     vBlocks = 2;
+     tileWidthInElem = 32;
+   }
+   break;
+ case 16:
+   if (tileWidthInElem == 32) {
+     // OCL interface supports 16b_?r16x2c for 64 bytes per row of 16 bits
+     // element.
+     vBlocks = 2;
+     tileWidthInElem = 16;
+   }
+   break;
}

Value warpId = rewriter.create<arith::IndexCastOp>(
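The comments in this hunk describe the one shape rewrite that survives: a 64-byte row (64 x 8-bit or 32 x 16-bit elements) is prefetched as two adjacent blocks of half the width, matching the 8b_?r32x2c and 16b_?r16x2c OpenCL block-prefetch shapes. A standalone sketch of that mapping is shown below, reusing the hunk's variable names for clarity; the struct and function names are illustrative, not the actual pass code.

struct PrefetchShape {
  unsigned tileWidthInElem;
  unsigned vBlocks;
};

// Map a requested prefetch width to a shape the OpenCL 2D block prefetch
// builtins accept, splitting a 64-byte row into two adjacent blocks.
PrefetchShape normalizePrefetchWidth(unsigned elemSizeInBits,
                                     unsigned tileWidthInElem) {
  unsigned vBlocks = 1;
  if (elemSizeInBits == 8 && tileWidthInElem == 64) {
    vBlocks = 2;
    tileWidthInElem = 32;
  } else if (elemSizeInBits == 16 && tileWidthInElem == 32) {
    vBlocks = 2;
    tileWidthInElem = 16;
  }
  return {tileWidthInElem, vBlocks};
}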
