diff --git a/include/triton/Tools/Sys/GetEnv.hpp b/include/triton/Tools/Sys/GetEnv.hpp index 92f9adf23b..a0c4b08ae2 100644 --- a/include/triton/Tools/Sys/GetEnv.hpp +++ b/include/triton/Tools/Sys/GetEnv.hpp @@ -33,7 +33,6 @@ inline const std::set CACHE_INVALIDATING_ENV_VARS = { "TRITON_INTEL_ADVANCED_PATH", "TRITON_INTEL_DO_NOT_SINK_INSTR_ACROSS_RGN", "TRITON_INTEL_ENABLE_ADDRESS_PAYLOAD_OPT", - "TRITON_INTEL_ENABLE_FAST_PREFETCH", "TRITON_INTEL_ENABLE_FIRST_LOAD_TO_SLM", "TRITON_INTEL_ENABLE_INSTR_SCHED", "TRITON_INTEL_ENABLE_POST_PROCESS_LLIR", diff --git a/test/TritonIntelGPU/prefetch-to-llvm.mlir b/test/TritonIntelGPU/prefetch-to-llvm.mlir index 263659d088..a5709ec017 100644 --- a/test/TritonIntelGPU/prefetch-to-llvm.mlir +++ b/test/TritonIntelGPU/prefetch-to-llvm.mlir @@ -1,5 +1,4 @@ // RUN: triton-opt %s -split-input-file --intel-allocate-shared-memory --convert-triton-intel-gpu-to-llvm | FileCheck %s --implicit-check-not=llvm.inline_asm -// RUN: TRITON_INTEL_ENABLE_FAST_PREFETCH=1 triton-opt %s -split-input-file --intel-allocate-shared-memory --convert-triton-intel-gpu-to-llvm | FileCheck %s --implicit-check-not=llvm.inline_asm --check-prefix=FAST // CHECK-DAG: llvm.func spir_funccc @_Z45intel_sub_group_2d_block_prefetch_16b_4r16x2cPU3AS1viiiDv2_i(!llvm.ptr<1> {llvm.nonnull}, i32, i32, i32, vector<2xi32>) attributes {memory_effects = #llvm.memory_effects, no_unwind} module attributes {"triton_gpu.num-warps" = 8 : i32, "triton_gpu.threads-per-warp" = 16 : i32} { @@ -8,12 +7,6 @@ module attributes {"triton_gpu.num-warps" = 8 : i32, "triton_gpu.threads-per-war %c0_i32 = arith.constant 0 : i32 %c1_i64 = arith.constant 1 : i64 - // FAST: [[C16:%.*]] = llvm.mlir.constant(16 : i32) : i32 - // FAST: [[C16:%.*]] = llvm.mlir.constant(16 : i32) : i32 - // FAST: [[C32:%.*]] = llvm.mlir.constant(32 : i32) : i32 - // FAST: [[C4:%.*]] = llvm.mlir.constant(4 : i32) : i32 - // FAST: [[C1:%.*]] = llvm.mlir.constant(1 : i32) : i32 - // FAST: llvm.call spir_funccc @llvm.genx.GenISA.LSC2DBlockPrefetch.isVoid({{.*}}, {{.*}}, {{.*}}, {{.*}}, {{.*}}, {{.*}}, [[C16]], [[C32]], [[C4]], [[C1]], {{.*}}, {{.*}}, {{.*}}) {{.*}} : (i64, i32, i32, i32, i32, i32, i32, i32, i32, i32, i1, i1, i32) -> () // CHECK: %[[ROW_MAJOR_BLOCK_PTR:.*]] = llvm.insertvalue %arg0, {{.*}}[6] : !llvm.struct<(i32, i32, i64, i64, i64, i64, ptr<1>)> // CHECK: %[[VAL_17:.*]] = llvm.call spir_funccc @_Z16get_sub_group_idv() // CHECK: %[[VAL_18:.*]] = llvm.sext %[[VAL_17]] : i32 to i64 @@ -49,12 +42,6 @@ module attributes {"triton_gpu.num-warps" = 8 : i32, "triton_gpu.threads-per-war // COM: The memory layout is same for the column major memory and row major memory. The prefetch function should be the same. - // FAST: [[C16:%.*]] = llvm.mlir.constant(16 : i32) : i32 - // FAST: [[C16:%.*]] = llvm.mlir.constant(16 : i32) : i32 - // FAST: [[C32:%.*]] = llvm.mlir.constant(32 : i32) : i32 - // FAST: [[C4:%.*]] = llvm.mlir.constant(4 : i32) : i32 - // FAST: [[C1:%.*]] = llvm.mlir.constant(1 : i32) : i32 - // FAST: llvm.call spir_funccc @llvm.genx.GenISA.LSC2DBlockPrefetch.isVoid({{.*}}, {{.*}}, {{.*}}, {{.*}}, {{.*}}, {{.*}}, [[C16]], [[C32]], [[C4]], [[C1]], {{.*}}, {{.*}}, {{.*}}) {{.*}} : (i64, i32, i32, i32, i32, i32, i32, i32, i32, i32, i1, i1, i32) -> () // CHECK: %[[COLUMN_MAJOR_BLOCK_PTR:.*]] = llvm.insertvalue %arg1, {{.*}}[6] : !llvm.struct<(i32, i32, i64, i64, i64, i64, ptr<1>)> // CHECK: %[[COLUMN_MAJOR_OFFSET_Y:.*]] = llvm.extractvalue %[[COLUMN_MAJOR_BLOCK_PTR]][0] : !llvm.struct<(i32, i32, i64, i64, i64, i64, ptr<1>)> // CHECK: %[[COLUMN_MAJOR_OFFSET_X:.*]] = llvm.extractvalue %[[COLUMN_MAJOR_BLOCK_PTR]][1] : !llvm.struct<(i32, i32, i64, i64, i64, i64, ptr<1>)> diff --git a/third_party/intel/lib/Dialect/TritonGEN/IR/TritonGENOps.cpp b/third_party/intel/lib/Dialect/TritonGEN/IR/TritonGENOps.cpp index c1ad8011fb..2c83b6910c 100644 --- a/third_party/intel/lib/Dialect/TritonGEN/IR/TritonGENOps.cpp +++ b/third_party/intel/lib/Dialect/TritonGEN/IR/TritonGENOps.cpp @@ -421,18 +421,15 @@ LogicalResult TritonGEN::Matrix2DBlockPrefetchOp::verify() { if (verifyMatrixInput(*this).failed()) return failure(); - const bool enableFastPrefetch = - tools::getBoolEnv("TRITON_INTEL_ENABLE_FAST_PREFETCH"); uint32_t tileWidth = getTileWidth(); switch (getElemSizeInBits()) { case 8: - if (tileWidth != 16 && tileWidth != 32 && - !(enableFastPrefetch && tileWidth == 64)) + if (tileWidth != 16 && tileWidth != 32) return emitOpError("tile_width for 8 bit elements should be equal to " "16 or 32"); break; case 16: - if (tileWidth != 16 && !(enableFastPrefetch && tileWidth == 32)) + if (tileWidth != 16) return emitOpError("tile_width for 16 bit elements should be equal " "to 16"); break; diff --git a/third_party/intel/lib/TritonGENToLLVM/TritonGENToLLVMPass.cpp b/third_party/intel/lib/TritonGENToLLVM/TritonGENToLLVMPass.cpp index f0e831363c..2316c22704 100644 --- a/third_party/intel/lib/TritonGENToLLVM/TritonGENToLLVMPass.cpp +++ b/third_party/intel/lib/TritonGENToLLVM/TritonGENToLLVMPass.cpp @@ -1106,10 +1106,6 @@ struct TritonMatrix2DBlockPrefetchLowering ConversionPatternRewriter &rewriter) const override { // TODO: Remove GenISA lowering after PoC productization is completed. bool useGenISA = tools::getBoolEnv("TRITONGEN_FORCE_GENISA"); - if (tools::getBoolEnv("TRITON_INTEL_ENABLE_FAST_PREFETCH") && - ((op.getElemSizeInBits() == 8 && op.getTileWidth() == 64) || - (op.getElemSizeInBits() == 16 && op.getTileWidth() == 32))) - useGenISA = true; if (useGenISA) { rewriter.replaceOp(op, createGenISA2DBlockPrefetch(op, rewriter)); return success(); diff --git a/third_party/intel/lib/TritonIntelGPUToLLVM/LoadStoreOpToLLVM.cpp b/third_party/intel/lib/TritonIntelGPUToLLVM/LoadStoreOpToLLVM.cpp index a43d10115c..30284f1bb3 100644 --- a/third_party/intel/lib/TritonIntelGPUToLLVM/LoadStoreOpToLLVM.cpp +++ b/third_party/intel/lib/TritonIntelGPUToLLVM/LoadStoreOpToLLVM.cpp @@ -276,25 +276,23 @@ struct PrefetchOpConversion unsigned tileWidthInElem = shapePerWarp[1]; unsigned tileHeightInElem = shapePerWarp[0]; unsigned vBlocks = 1; - if (!tools::getBoolEnv("TRITON_INTEL_ENABLE_FAST_PREFETCH")) { - switch (elemSizeInBits) { - case 8: - if (tileWidthInElem == 64) { - // OCL interface supports 8b_?r32x2c for 64 bytes per row of 8 bits - // element. - vBlocks = 2; - tileWidthInElem = 32; - } - break; - case 16: - if (tileWidthInElem == 32) { - // OCL interface supports 16b_?r16x2c for 64 bytes per row of 16 bits - // element. - vBlocks = 2; - tileWidthInElem = 16; - } - break; + switch (elemSizeInBits) { + case 8: + if (tileWidthInElem == 64) { + // OCL interface supports 8b_?r32x2c for 64 bytes per row of 8 bits + // element. + vBlocks = 2; + tileWidthInElem = 32; } + break; + case 16: + if (tileWidthInElem == 32) { + // OCL interface supports 16b_?r16x2c for 64 bytes per row of 16 bits + // element. + vBlocks = 2; + tileWidthInElem = 16; + } + break; } Value warpId = rewriter.create(