Make isExpensiveLoadOrStore consider block pointer loads and stores (#2570)

The `isExpensiveLoadOrStore` function
(third_party/intel/lib/TritonIntelGPUTransforms/Utility.cpp) fails to
consider block pointers and consequently always returns `false` for
load (and store) operations that use a block pointer.
In turn, this causes the `RemoveLayoutConversion` pass to never consider
loads using block pointers as `anchor` operations.

This PR changes `isExpensiveLoadOrStore` so that block pointer loads are
properly recognized. The `RemoveLayoutConversion` pass is then able to
consider those loads as anchor operations and preserve their layout.
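
The Utility.cpp hunk itself is not rendered on this page; as a rough sketch
only (the early-return placement and the surrounding heuristics are
assumptions, not the verbatim change, and the `tt` namespace alias comes from
the surrounding code), the fix amounts to recognizing tensor-pointer operands
up front:

```cpp
// Sketch: a load/store through a block pointer (created by
// tt.make_tensor_ptr) moves a whole 2D tile, so report it as expensive;
// RemoveLayoutConversion then keeps such ops as layout anchors.
static bool isExpensiveLoadOrStore(Operation *op) {
  Value ptr = op->getOperand(0); // the ptr operand of tt.load / tt.store
  if (tt::isTensorPointerType(ptr.getType()))
    return true;
  // ... existing cost heuristics for ordinary tensor-of-pointer operands ...
  return false;
}
```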

Because `RemoveLayoutConversion` is invoked at several points in the
optimization pipeline, the change in
third_party/intel/lib/TritonIntelGPUTransforms/Utility.cpp alone causes
performance degradation in a couple of GEMM-like benchmarks,
specifically when operand A of `tl.dot` is transposed and when the input
of `tl.dot` is first fed into an exponential.

These two performance degradations have been fixed by enhancing the
`MaterializeBlockPointer` and `MatmulLoopPipeline` optimizations so
that they can retrieve the dot layout of block pointer loads
transitively from their users (in those benchmarks the blocked layout of
block pointer loads is transitively converted to a dot layout).
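
The `MaterializeBlockPointer` half of that enhancement is the new
`getDotLayout` helper shown in the diff below; condensed, the transitive
lookup has roughly this shape (a sketch of the logic, reusing the `tt`/`ttg`
namespace aliases from that file, not the verbatim helper):

```cpp
// Sketch: if the load result does not already carry a dot layout, look
// through its users; accept a layout only when every user is a
// triton_gpu.convert_layout op and all users convert to the same layout.
static std::optional<ttg::DotOperandEncodingAttr>
lookThroughConverts(tt::LoadOp loadOp) {
  Attribute layout; // the layout all convert users agree on, if any
  for (Operation *user : loadOp->getUsers()) {
    auto cvt = dyn_cast<ttg::ConvertLayoutOp>(user);
    if (!cvt)
      return std::nullopt; // a non-convert user: give up
    Attribute enc = cvt.getType().getEncoding();
    if (layout && enc != layout)
      return std::nullopt; // users disagree on the target layout
    layout = enc;
  }
  // layout is null when there were no users; accept only a dot layout.
  if (auto dotLayout = dyn_cast_or_null<ttg::DotOperandEncodingAttr>(layout))
    return dotLayout;
  return std::nullopt;
}
```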

---------

Signed-off-by: Tiotto, Ettore <[email protected]>
etiotto authored Oct 31, 2024
1 parent 55702d9 commit 1dbef57
Showing 5 changed files with 102 additions and 49 deletions.
21 changes: 10 additions & 11 deletions test/TritonIntelGPU/backward_combine_dpas_dot_layout.mlir
@@ -47,16 +47,16 @@ module attributes {"triton_gpu.num-warps" = 4 : i32, "triton_gpu.threads-per-war
// CHECK: %[[VAL_40:.*]] = tt.make_tensor_ptr %{{.*}}, {{\[}}%{{.*}}, %{{.*}}], {{\[}}%{{.*}}, %{{.*}}], {{\[}}%{{.*}}, %{{.*}}] {order = array<i32: 1, 0>} : <tensor<32x256xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #[[DPAS]], kWidth = 2}>>>
%22 = tt.make_tensor_ptr %arg1, [%16, %20], [%21, %c1_i64], [%c0_i32, %19] {order = array<i32: 1, 0>} : <tensor<32x256xf16, #blocked1>>
// CHECK: %[[VAL_41:.*]]:3 = scf.for %{{.*}} = %{{.*}} to %{{.*}} step %{{.*}} iter_args(%{{.*}} = %{{.*}}, %{{.*}} = %[[VAL_36]], %{{.*}} = %[[VAL_40]]) -> (tensor<64x256xf32, #[[DPAS]]>, !tt.ptr<tensor<64x32xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #[[DPAS]], kWidth = 2}>>>, !tt.ptr<tensor<32x256xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #[[DPAS]], kWidth = 2}>>>) : i32 {
-// CHECK: %[[VAL_46:.*]] = tt.load %{{.*}} {boundaryCheck = array<i32: 0, 1>} : !tt.ptr<tensor<64x32xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #[[DPAS]], kWidth = 2}>>>
-// CHECK: %[[VAL_47:.*]] = tt.load %{{.*}} {boundaryCheck = array<i32: 0, 1>} : !tt.ptr<tensor<32x256xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #[[DPAS]], kWidth = 2}>>>
+// CHECK: %[[VAL_46:.*]] = tt.load %{{.*}} {boundaryCheck = array<i32: 0, 1>, triton_intel_gpu.block_io = "row_major"} : !tt.ptr<tensor<64x32xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #[[DPAS]], kWidth = 2}>>>
+// CHECK: %[[VAL_47:.*]] = tt.load %{{.*}} {boundaryCheck = array<i32: 0, 1>, triton_intel_gpu.block_io = "row_major"} : !tt.ptr<tensor<32x256xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #[[DPAS]], kWidth = 2}>>>
// CHECK-NOT: triton_gpu.convert_layout
// CHECK-NEXT: %[[VAL_48:.*]] = tt.dot %[[VAL_46]], %[[VAL_47]], %{{.*}}, inputPrecision = tf32 : tensor<64x32xf16, #{{.*}}<{opIdx = 0, parent = #[[DPAS]], kWidth = 2}>> * tensor<32x256xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #[[DPAS]], kWidth = 2}>> -> tensor<64x256xf32, #[[DPAS]]>
// CHECK: %[[VAL_49:.*]] = tt.advance %{{.*}}, {{\[}}%{{.*}}, %{{.*}}] : <tensor<64x32xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #[[DPAS]], kWidth = 2}>>>
// CHECK: %[[VAL_50:.*]] = tt.advance %{{.*}}, {{\[}}%{{.*}}, %{{.*}}] : <tensor<32x256xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #[[DPAS]], kWidth = 2}>>>
// CHECK: scf.yield %{{.*}}, %{{.*}}, %{{.*}} : tensor<64x256xf32, #[[DPAS]]>, !tt.ptr<tensor<64x32xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #[[DPAS]], kWidth = 2}>>>, !tt.ptr<tensor<32x256xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #[[DPAS]], kWidth = 2}>>>
%23:3 = scf.for %arg9 = %c0_i32 to %arg5 step %c32_i32 iter_args(%arg10 = %cst, %arg11 = %18, %arg12 = %22) -> (tensor<64x256xf32, #dpas>, !tt.ptr<tensor<64x32xf16, #blocked>>, !tt.ptr<tensor<32x256xf16, #blocked1>>) : i32 {
-%28 = tt.load %arg11 {boundaryCheck = array<i32: 0, 1>} : !tt.ptr<tensor<64x32xf16, #blocked>>
-%29 = tt.load %arg12 {boundaryCheck = array<i32: 0, 1>} : !tt.ptr<tensor<32x256xf16, #blocked1>>
+%28 = tt.load %arg11 {boundaryCheck = array<i32: 0, 1>, triton_intel_gpu.block_io = "row_major"} : !tt.ptr<tensor<64x32xf16, #blocked>>
+%29 = tt.load %arg12 {boundaryCheck = array<i32: 0, 1>, triton_intel_gpu.block_io = "row_major"} : !tt.ptr<tensor<32x256xf16, #blocked1>>
%30 = triton_gpu.convert_layout %28 : tensor<64x32xf16, #blocked> -> tensor<64x32xf16, #dot0>
%31 = triton_gpu.convert_layout %29 : tensor<32x256xf16, #blocked1> -> tensor<32x256xf16, #dot1>
%32 = tt.dot %30, %31, %arg10, inputPrecision = tf32 : tensor<64x32xf16, #dot0> * tensor<32x256xf16, #dot1> -> tensor<64x256xf32, #dpas>
@@ -130,7 +130,6 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 :
scf.yield %32, %33, %34 : tensor<64x256xf32, #dpas>, !tt.ptr<tensor<64x32xf16, #blocked>>, !tt.ptr<tensor<32x256xf16, #blocked1>>
}
%24 = arith.truncf %23#0 : tensor<64x256xf32, #dpas> to tensor<64x256xf16, #dpas>
-// CHECK-NOT: triton_gpu.convert_layout
%25 = triton_gpu.convert_layout %24 : tensor<64x256xf16, #dpas> -> tensor<64x256xf16, #blocked1>
%26 = arith.extsi %arg8 : i32 to i64
// CHECK: tt.make_tensor_ptr {{.*}}, {{\[}}{{.*}}, {{.*}}], {{\[}}{{.*}}, {{.*}}], {{\[}}{{.*}}, {{.*}}] {order = array<i32: 1, 0>} : <tensor<64x256xf16, #[[DPAS]]>>
@@ -147,6 +146,7 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 :
// COM: Checks that DPAS encoding has been forwarded to the store op
// COM: The `tt.make_tensor_ptr` has multiple users (the storeOp + another OP)
// COM: The initial `tt.make_tensor_ptr` with non-DPAS encoding must be kept.
+// CHECK: #[[BLOCKED:.+]] = #triton_gpu.blocked<{sizePerThread = [1, 1], threadsPerWarp = [1, 16], warpsPerCTA = [1, 4], order = [1, 0]}>
// CHECK: #[[DPAS:.+]] = #triton_intel_gpu.dpas<{repeatCount = 8, systolicDepth = 8, executionSize = 16, opsPerChan = 2, threadsPerWarp = 16, warpsPerCTA = [1, 4], repCluster = [1, 1], A = [8, 16], B = [16, 16], C = [8, 16]}>
#blocked = #triton_gpu.blocked<{sizePerThread = [1, 1], threadsPerWarp = [1, 16], warpsPerCTA = [2, 2], order = [1, 0]}>
#blocked1 = #triton_gpu.blocked<{sizePerThread = [1, 1], threadsPerWarp = [1, 16], warpsPerCTA = [1, 4], order = [1, 0]}>
Expand Down Expand Up @@ -188,8 +188,8 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 :
%21 = arith.extsi %arg7 : i32 to i64
%22 = tt.make_tensor_ptr %arg1, [%16, %20], [%21, %c1_i64], [%c0_i32, %19] {order = array<i32: 1, 0>} : <tensor<32x256xf16, #blocked1>>
%23:3 = scf.for %arg9 = %c0_i32 to %arg5 step %c32_i32 iter_args(%arg10 = %cst, %arg11 = %18, %arg12 = %22) -> (tensor<64x256xf32, #dpas>, !tt.ptr<tensor<64x32xf16, #blocked>>, !tt.ptr<tensor<32x256xf16, #blocked1>>) : i32 {
-%28 = tt.load %arg11 {boundaryCheck = array<i32: 0, 1>} : !tt.ptr<tensor<64x32xf16, #blocked>>
-%29 = tt.load %arg12 {boundaryCheck = array<i32: 0, 1>} : !tt.ptr<tensor<32x256xf16, #blocked1>>
+%28 = tt.load %arg11 {boundaryCheck = array<i32: 0, 1>, triton_intel_gpu.block_io = "row_major"} : !tt.ptr<tensor<64x32xf16, #blocked>>
+%29 = tt.load %arg12 {boundaryCheck = array<i32: 0, 1>, triton_intel_gpu.block_io = "row_major"} : !tt.ptr<tensor<32x256xf16, #blocked1>>
%30 = triton_gpu.convert_layout %28 : tensor<64x32xf16, #blocked> -> tensor<64x32xf16, #dot0>
%31 = triton_gpu.convert_layout %29 : tensor<32x256xf16, #blocked1> -> tensor<32x256xf16, #dot1>
%32 = tt.dot %30, %31, %arg10, inputPrecision = tf32 : tensor<64x32xf16, #dot0> * tensor<32x256xf16, #dot1> -> tensor<64x256xf32, #dpas>
@@ -198,11 +198,10 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 :
scf.yield %32, %33, %34 : tensor<64x256xf32, #dpas>, !tt.ptr<tensor<64x32xf16, #blocked>>, !tt.ptr<tensor<32x256xf16, #blocked1>>
}
%24 = arith.truncf %23#0 : tensor<64x256xf32, #dpas> to tensor<64x256xf16, #dpas>
-// CHECK-NOT: triton_gpu.convert_layout
%25 = triton_gpu.convert_layout %24 : tensor<64x256xf16, #dpas> -> tensor<64x256xf16, #blocked1>
%26 = arith.extsi %arg8 : i32 to i64
// CHECK: tt.make_tensor_ptr {{.*}}, {{\[}}{{.*}}, {{.*}}], {{\[}}{{.*}}, {{.*}}], {{\[}}{{.*}}, {{.*}}] {order = array<i32: 1, 0>} : <tensor<64x256xf16, #[[DPAS]]>>
-// CHECK: tt.make_tensor_ptr {{.*}}, {{\[}}{{.*}}, {{.*}}], {{\[}}{{.*}}, {{.*}}], {{\[}}{{.*}}, {{.*}}] {order = array<i32: 1, 0>} : <tensor<64x256xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #[[DPAS]], kWidth = 2}>>>
+// CHECK: tt.make_tensor_ptr {{.*}}, {{\[}}{{.*}}, {{.*}}], {{\[}}{{.*}}, {{.*}}], {{\[}}{{.*}}, {{.*}}] {order = array<i32: 1, 0>} : <tensor<64x256xf16, #[[BLOCKED]]>>
%27 = tt.make_tensor_ptr %arg2, [%15, %20], [%26, %c1_i64], [%14, %19] {order = array<i32: 1, 0>} : <tensor<64x256xf16, #blocked1>>
// CHECK: tt.store {{.*}}, {{.*}} {boundaryCheck = array<i32: 0, 1>} : !tt.ptr<tensor<64x256xf16, #[[DPAS]]>>
tt.store %27, %25 {boundaryCheck = array<i32: 0, 1>} : !tt.ptr<tensor<64x256xf16, #blocked1>>
@@ -243,8 +242,8 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 :
%18 = tt.make_tensor_ptr %arg0, [%c0_i64, %c0_i64], [%c0_i64, %c1_i64], [%c0_i32, %c0_i32] {order = array<i32: 1, 0>} : <tensor<64x32xf16, #blocked>>
%22 = tt.make_tensor_ptr %arg1, [%c0_i64, %c0_i64], [%c0_i64, %c1_i64], [%c0_i32, %c0_i32] {order = array<i32: 1, 0>} : <tensor<32x256xf16, #blocked1>>
%23:3 = scf.for %arg9 = %c0_i32 to %arg5 step %c32_i32 iter_args(%arg10 = %cst, %arg11 = %18, %arg12 = %22) -> (tensor<64x256xf32, #blocked1>, !tt.ptr<tensor<64x32xf16, #blocked>>, !tt.ptr<tensor<32x256xf16, #blocked1>>) : i32 {
-%28 = tt.load %arg11 {boundaryCheck = array<i32: 0, 1>} : !tt.ptr<tensor<64x32xf16, #blocked>>
-%29 = tt.load %arg12 {boundaryCheck = array<i32: 0, 1>} : !tt.ptr<tensor<32x256xf16, #blocked1>>
+%28 = tt.load %arg11 {boundaryCheck = array<i32: 0, 1>, triton_intel_gpu.block_io = "row_major" } : !tt.ptr<tensor<64x32xf16, #blocked>>
+%29 = tt.load %arg12 {boundaryCheck = array<i32: 0, 1>, triton_intel_gpu.block_io = "row_major"} : !tt.ptr<tensor<32x256xf16, #blocked1>>
%36 = triton_gpu.convert_layout %arg10 : tensor<64x256xf32, #blocked1> -> tensor<64x256xf32, #dpas>
%30 = triton_gpu.convert_layout %28 : tensor<64x32xf16, #blocked> -> tensor<64x32xf16, #dot0>
%31 = triton_gpu.convert_layout %29 : tensor<32x256xf16, #blocked1> -> tensor<32x256xf16, #dot1>
14 changes: 6 additions & 8 deletions test/TritonIntelGPU/combine.mlir
@@ -2324,31 +2324,29 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 32
%cst_1 = arith.constant dense<0.000000e+00> : tensor<256x256xf32, #blocked2>
%0 = tt.get_program_id x : i32
%1 = tt.get_program_id y : i32
-// CHECK: %[[VAL_0:.*]] = tt.make_tensor_ptr {{.*}} : <tensor<256x32xbf16, #triton_gpu.dot_op<{opIdx = 0, parent = #[[$DPAS]], kWidth = 2}>>>
-// CHECK: %[[VAL_1:.*]] = tt.make_tensor_ptr {{.*}} : <tensor<32x256xbf16, #triton_gpu.dot_op<{opIdx = 1, parent = #[[$DPAS]], kWidth = 2}>>>
+// CHECK: %[[VAL_0:.*]] = tt.make_tensor_ptr {{.*}} : <tensor<256x32xbf16, {{.*}}>>
+// CHECK: %[[VAL_1:.*]] = tt.make_tensor_ptr {{.*}} : <tensor<32x256xbf16, {{.*}}>>
%12 = tt.make_tensor_ptr %arg0, [%c4096_i64, %c4096_i64], [%c4096_i64, %c1_i64], [%0, %1] {order = array<i32: 1, 0>} : <tensor<256x32xbf16, #blocked3>>
%14 = tt.make_tensor_ptr %arg1, [%c4096_i64, %c4096_i64], [%c4096_i64, %c1_i64], [%0, %1] {order = array<i32: 1, 0>} : <tensor<32x256xbf16, #blocked2>>
-// CHECK: %[[VAL_2:.*]]:3 = scf.for {{.*}} -> (tensor<256x256xf32, #[[$DPAS]]>, !tt.ptr<tensor<256x32xbf16, #triton_gpu.dot_op<{opIdx = 0, parent = #[[$DPAS]], kWidth = 2}>>>, !tt.ptr<tensor<32x256xbf16, #triton_gpu.dot_op<{opIdx = 1, parent = #[[$DPAS]], kWidth = 2}>>>) : i32 {
+// CHECK: %[[VAL_2:.*]]:3 = scf.for {{.*}} -> (tensor<256x256xf32, #[[$DPAS]]>, !tt.ptr<tensor<256x32xbf16, {{.*}}>>, !tt.ptr<tensor<32x256xbf16, {{.*}}>>) : i32 {
%15:3 = scf.for %arg3 = %c0_i32 to %c4096_i32 step %c128_i32 iter_args(%arg4 = %cst_1, %arg5 = %12, %arg6 = %14) -> (tensor<256x256xf32, #blocked2>, !tt.ptr<tensor<256x32xbf16, #blocked3>>, !tt.ptr<tensor<32x256xbf16, #blocked2>>) : i32 {
%47 = tt.load %arg5 : !tt.ptr<tensor<256x32xbf16, #blocked3>>
%48 = tt.load %arg6 : !tt.ptr<tensor<32x256xbf16, #blocked2>>
-// CHEKC-NOT: triton_gpu.convert_layout
%49 = triton_gpu.convert_layout %arg4 : tensor<256x256xf32, #blocked2> -> tensor<256x256xf32, #mma>
%50 = triton_gpu.convert_layout %47 : tensor<256x32xbf16, #blocked3> -> tensor<256x32xbf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma, kWidth = 2}>>
%51 = triton_gpu.convert_layout %48 : tensor<32x256xbf16, #blocked2> -> tensor<32x256xbf16, #triton_gpu.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>>
%52 = tt.dot %50, %51, %49, inputPrecision = tf32 : tensor<256x32xbf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma, kWidth = 2}>> * tensor<32x256xbf16, #triton_gpu.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>> -> tensor<256x256xf32, #mma>
%53 = triton_gpu.convert_layout %52 : tensor<256x256xf32, #mma> -> tensor<256x256xf32, #blocked2>
-// CHECK: %[[VAL_3:.*]] = tt.advance {{.*}} : <tensor<256x32xbf16, #triton_gpu.dot_op<{opIdx = 0, parent = #[[$DPAS]], kWidth = 2}>>>
-// CHECK: %[[VAL_4:.*]] = tt.advance {{.*}} : <tensor<32x256xbf16, #triton_gpu.dot_op<{opIdx = 1, parent = #[[$DPAS]], kWidth = 2}>>>
-// CHECK: scf.yield {{.*}} : tensor<256x256xf32, #[[$DPAS]]>, !tt.ptr<tensor<256x32xbf16, #triton_gpu.dot_op<{opIdx = 0, parent = #[[$DPAS]], kWidth = 2}>>>, !tt.ptr<tensor<32x256xbf16, #triton_gpu.dot_op<{opIdx = 1, parent = #[[$DPAS]], kWidth = 2}>>>
+// CHECK: %[[VAL_3:.*]] = tt.advance {{.*}} : <tensor<256x32xbf16, {{.*}}>>
+// CHECK: %[[VAL_4:.*]] = tt.advance {{.*}} : <tensor<32x256xbf16, {{.*}}>>
+// CHECK: scf.yield {{.*}} : tensor<256x256xf32, #[[$DPAS]]>, !tt.ptr<tensor<256x32xbf16, {{.*}}>>, !tt.ptr<tensor<32x256xbf16, {{.*}}>>
%54 = tt.advance %arg5, [%c0_i32, %c128_i32] : <tensor<256x32xbf16, #blocked3>>
%55 = tt.advance %arg6, [%c128_i32, %c0_i32] : <tensor<32x256xbf16, #blocked2>>
scf.yield %53, %54, %55 : tensor<256x256xf32, #blocked2>, !tt.ptr<tensor<256x32xbf16, #blocked3>>, !tt.ptr<tensor<32x256xbf16, #blocked2>>
}
%16 = tt.make_range {end = 256 : i32, start = 0 : i32} : tensor<256xi32, #blocked>
%32 = tt.splat %arg2 : !tt.ptr<f32> -> tensor<256x256x!tt.ptr<f32>, #blocked2>
%38 = arith.cmpi slt, %16, %cst : tensor<256xi32, #blocked>
-// CHEKC-NOT: triton_gpu.convert_layout
%39 = triton_gpu.convert_layout %38 : tensor<256xi1, #blocked> -> tensor<256xi1, #triton_gpu.slice<{dim = 0, parent = #blocked4}>>
%40 = tt.expand_dims %39 {axis = 0 : i32} : tensor<256xi1, #triton_gpu.slice<{dim = 0, parent = #blocked4}>> -> tensor<1x256xi1, #blocked4>
%41 = triton_gpu.convert_layout %40 : tensor<1x256xi1, #blocked4> -> tensor<1x256xi1, #blocked2>
third_party/intel/lib/TritonIntelGPUTransforms/MaterializeBlockPointer.cpp
@@ -4,14 +4,17 @@
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/IR/Visitors.h"
#include "triton/Analysis/Utility.h"
#include "llvm/Support/Casting.h"
#include "llvm/Support/Debug.h"
+#include <optional>

#define DEBUG_TYPE "tritonintelgpu-materialize-block-pointer"
#define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE "]: ")
#define LDBG(X) LLVM_DEBUG(DBGS() << X << "\n")

using namespace mlir;
namespace tt = mlir::triton;
+namespace ttg = mlir::triton::gpu;
namespace ttgi = mlir::triton::gpu::intel;

namespace mlir::triton::gpu::intel {
@@ -37,7 +40,7 @@ struct TritonIntelGPUMaterializeBlockPointerPass
return;

MLIRContext *context = &getContext();
-mod.walk([context](tt::LoadOp loadOp) {
+mod.walk([context, this](tt::LoadOp loadOp) {
LDBG("Considering op: " << loadOp);

Value ptr = loadOp.getPtr();
@@ -51,7 +54,6 @@ struct TritonIntelGPUMaterializeBlockPointerPass
LDBG("Found make tensor ptr op: " << makeTensorPtrOp);
auto ptrType = cast<tt::PointerType>(makeTensorPtrOp.getType());
auto tensorType = cast<RankedTensorType>(ptrType.getPointeeType());
-auto dotLayout = ttgi::getDotEncoding(tensorType);

Operation::operand_range shape = makeTensorPtrOp.getShape();
unsigned rank = shape.size();
@@ -100,11 +102,13 @@
return;

const bool isRowMajor = fastChangeDim == rank - 1;
+std::optional<ttg::DotOperandEncodingAttr> dotLayout =
+    getDotLayout(loadOp);
if (dotLayout) {
-// Check if the load is being used in a dot layout, and if so is this
-// the first op and is it a transposed row major matrix. If so, skip
-// the block ptr attribute as performance is worse than if we remove
-// the tensor pointer
+// Check if the load is being used by a tt.dot operation, and if so is
+// this the first operand and is it a transposed row major matrix. If
+// so, skip the block ptr attribute as performance is worse than if we
+// remove the tensor pointer.
LDBG("dotLayout: " << *dotLayout);
const unsigned opIdx = dotLayout->getOpIdx();
auto dotOrder = dotLayout->getThreadOrder();
@@ -122,6 +126,52 @@
}
});
}

+private:
+  // Return the load layout if it is a dot layout. If it is not, check if the
+  // load result is converted to a dot layout. If so, return the dot layout,
+  // otherwise return nullopt.
+  std::optional<ttg::DotOperandEncodingAttr>
+  getDotLayout(tt::LoadOp loadOp) const {
+    Value ptr = loadOp.getPtr();
+    if (!tt::isTensorPointerType(ptr.getType()))
+      return std::nullopt;
+
+    RankedTensorType tensorType = ttgi::getRankedTensorType(ptr.getType());
+    if (!tensorType)
+      return std::nullopt;
+
+    auto dotLayout = ttgi::getDotEncoding(tensorType);
+    if (dotLayout)
+      return dotLayout;
+
+    auto allUsersAreConvertOps = [](Operation::user_range users) {
+      return llvm::all_of(users, [](Operation *user) {
+        return isa<ttg::ConvertLayoutOp>(user);
+      });
+    };
+
+    auto allUserHaveIdenticalLayout = [](Operation::user_range users) {
+      Attribute firstUserLayout =
+          cast<ttg::ConvertLayoutOp>(*users.begin()).getType().getEncoding();
+      return llvm::all_of(users, [&firstUserLayout](Operation *user) {
+        return firstUserLayout ==
+               cast<ttg::ConvertLayoutOp>(user).getType().getEncoding();
+      });
+    };
+
+    Operation::user_range users = loadOp->getUsers();
+    if (!users.empty() && allUsersAreConvertOps(users) &&
+        allUserHaveIdenticalLayout(users)) {
+      Attribute firstUserLayout =
+          cast<ttg::ConvertLayoutOp>(*users.begin()).getType().getEncoding();
+      if (isa<ttg::DotOperandEncodingAttr>(firstUserLayout))
+        return dyn_cast<ttg::DotOperandEncodingAttr>(firstUserLayout);
+      return std::nullopt;
+    }
+
+    return std::nullopt;
+  }
};

} // anonymous namespace
