diff --git a/include/triton/Dialect/TritonGPU/IR/TritonGPUAttrDefs.td b/include/triton/Dialect/TritonGPU/IR/TritonGPUAttrDefs.td index 04ba951963..84ffa054dd 100644 --- a/include/triton/Dialect/TritonGPU/IR/TritonGPUAttrDefs.td +++ b/include/triton/Dialect/TritonGPU/IR/TritonGPUAttrDefs.td @@ -1270,13 +1270,18 @@ The parent field is the layout of d. kWidth defines number of consecutive elements stored by one thread along k dimension. Some layouts do not use this parameter, either because they have a fixed number of elements along the K dim, or they use all elements of the tensor along the K dim. + +`isTransposed` indicates the result tensor is transposed so that it can be loaded with +a transpose flag. This is used in the case of chained dot (E.g. Flash-Attention kernel). }]; let parameters = ( ins "unsigned":$opIdx, "Attribute":$parent, - DefaultValuedParameter<"unsigned", "0">:$kWidth + DefaultValuedParameter<"unsigned", "0">:$kWidth, + // intel specific + DefaultValuedParameter<"bool", "false">:$isTransposed ); let builders = [ @@ -1286,15 +1291,27 @@ elements along the K dim, or they use all elements of the tensor along the K dim "Type":$eltTy), [{ NvidiaMmaEncodingAttr parentAttr = mlir::dyn_cast(parent); if (!parentAttr || !parentAttr.isAmpere()) - return $_get(context, opIdx, parent, 0); + return $_get(context, opIdx, parent, 0, false); unsigned bitwidth = eltTy.getIntOrFloatBitWidth(); unsigned MMAv2kWidth = 32 / bitwidth; - return $_get(context, opIdx, parent, MMAv2kWidth); + return $_get(context, opIdx, parent, MMAv2kWidth, false); + }]>, + AttrBuilder<(ins "unsigned":$opIdx, + "Attribute":$parent, + "unsigned":$kWidth), [{ + return $_get(context, opIdx, parent, kWidth, false); + }]>, + AttrBuilder<(ins "unsigned":$opIdx, + "Attribute":$parent, + "unsigned":$kWidth, + "bool":$isTransposed), [{ + return $_get(context, opIdx, parent, kWidth, isTransposed); }]> ]; let assemblyFormat = "`<` `{` struct(params) `}` `>`"; let genVerifyDecl = 1; + let skipDefaultBuilders = 1; let extraClassDeclaration = extraDistributedDeclaration # [{ SmallVector getContigPerThread() { return getSizePerThread(); diff --git a/lib/Dialect/TritonGPU/IR/Dialect.cpp b/lib/Dialect/TritonGPU/IR/Dialect.cpp index 584221f337..c5acde63da 100644 --- a/lib/Dialect/TritonGPU/IR/Dialect.cpp +++ b/lib/Dialect/TritonGPU/IR/Dialect.cpp @@ -1018,7 +1018,7 @@ SmallVector DotOperandEncodingAttr::getShapePerCTATile( LogicalResult DotOperandEncodingAttr::verify( ::llvm::function_ref<::mlir::InFlightDiagnostic()> emitError, - unsigned opIdx, Attribute parent, unsigned kWidth) { + unsigned opIdx, Attribute parent, unsigned kWidth, bool isTransposed) { if (opIdx != 0 && opIdx != 1) { return emitError() << "triton_gpu.dot_op opIdx paramenter can be 0 or 1, got: " diff --git a/test/TritonIntelGPU/match-target-size.mlir b/test/TritonIntelGPU/match-target-size.mlir index 6c910f3328..dff3ef0042 100644 --- a/test/TritonIntelGPU/match-target-size.mlir +++ b/test/TritonIntelGPU/match-target-size.mlir @@ -116,34 +116,6 @@ tt.func public @simplify_scf_for(%arg0: tensor<16x8xf16>, %arg1: tensor<16x8xf16 // ----- -// COM: Test SCF canonicalization: ensure loop is not modified when the result is not used by just 'extract' operations. -tt.func public @simplify_scf_for(%arg0: tensor<16x8xf16>, %arg1: tensor<16x8xf16>, %arg2: !tt.ptr, - %arg3: i64, %arg4: i64, %arg5: i64, %arg6: i32, %arg7: i32) { - // CHECK-LABEL: @simplify_scf_for - // CHECK: [[GLUE:%.*]] = triton_intel_gpu.glue - // CHECK-NEXT: [[RES:%.*]] = scf.for {{.*}} iter_args([[INIT1:%.*]] = [[GLUE]]) -> (tensor<16x16xf16>) : i32 { - // CHECK: scf.yield {{.*}} : tensor<16x16xf16> - // CHECK-NEXT: } - // CHECK-NEXT: [[PTR:%.*]] = tt.make_tensor_ptr %arg2 - // CHECK-NEXT: tt.store [[PTR]], [[RES]] - %lb = arith.constant 0 : i32 - %ub = arith.constant 32 : i32 - %st = arith.constant 1 : i32 - %c1_i64 = arith.constant 1 : i64 - %glue = triton_intel_gpu.glue %arg0, %arg1 : (tensor<16x8xf16>, tensor<16x8xf16>) -> tensor<16x16xf16> - %res = scf.for %iv = %lb to %ub step %st iter_args(%arg = %glue) -> (tensor<16x16xf16>) : i32 { - %e1 = triton_intel_gpu.extract %arg[0] : tensor<16x16xf16> -> tensor<16x8xf16> - %e2 = triton_intel_gpu.extract %arg[1] : tensor<16x16xf16> -> tensor<16x8xf16> - %g1 = triton_intel_gpu.glue %e1, %e2 : (tensor<16x8xf16>, tensor<16x8xf16>) -> tensor<16x16xf16> - scf.yield %g1 : tensor<16x16xf16> - } - %ptr = tt.make_tensor_ptr %arg2, [%arg3, %arg4], [%arg5, %c1_i64], [%arg6, %arg7] {order = array} : > - tt.store %ptr, %res {boundaryCheck = array, cache = 1 : i32, evict = 1 : i32} : !tt.ptr> - tt.return -} - -// ----- - // COM: Test SCF canonicalization: ensure loop is not modified if any user of a 'glue' init value is not an 'extract' operation. tt.func public @simplify_scf_for(%arg0: tensor<16x8xf16>, %arg1: tensor<16x8xf16>, %arg2: !tt.ptr, %arg3: i64, %arg4: i64, %arg5: i64, %arg6: i32, %arg7: i32) { @@ -279,3 +251,111 @@ tt.func public @matmul_kernel_with_block_pointers_tf32(%arg0: !tt.ptr {tt.d tt.store %tptr_c, %35#0 {boundaryCheck = array} : !tt.ptr> tt.return } + +// ----- + +// COM: Test Attention Related Ops +#warp = #triton_intel_gpu.warp<{sizePerThread = [16, 64], threadsPerWarp = [1, 1], order = [1, 0]}> +module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 8 : i32, "triton_gpu.threads-per-warp" = 1 : i32} { +// CHECK-LABEL: @attn_fwd +tt.func public @attn_fwd(%arg0: !tt.ptr, %arg1: !tt.ptr, %arg2: !tt.ptr, %arg3: f32, %arg4: !tt.ptr, %arg5: !tt.ptr) { + %c16_i32 = arith.constant 16 : i32 + %c128_i32 = arith.constant 128 : i32 + %c1024_i64 = arith.constant 1024 : i64 + %c64_i64 = arith.constant 64 : i64 + %c1_i64 = arith.constant 1 : i64 + %c0_i32 = arith.constant 0 : i32 + %cst = arith.constant 1.44269502 : f32 + %c3145728_i64 = arith.constant 3145728 : i64 + %c65536_i64 = arith.constant 65536 : i64 + %cst_0 = arith.constant dense<0.000000e+00> : tensor<16x64xf32, #warp> + %c64_i32 = arith.constant 64 : i32 + %c1024_i32 = arith.constant 1024 : i32 + %cst_1 = arith.constant dense<0xFF800000> : tensor<16xf32, #triton_gpu.slice<{dim = 1, parent = #warp}>> + %cst_2 = arith.constant dense<1.000000e+00> : tensor<16xf32, #triton_gpu.slice<{dim = 1, parent = #warp}>> + %0 = gpu.subgroup_id : index + %1 = arith.index_cast %0 : index to i32 + %2 = tt.get_program_id z : i32 + %3 = tt.get_program_id x : i32 + %4 = tt.get_program_id y : i32 + %5 = arith.extsi %3 : i32 to i64 + %6 = arith.muli %5, %c3145728_i64 : i64 + %7 = arith.extsi %4 : i32 to i64 + %8 = arith.muli %7, %c65536_i64 : i64 + %9 = arith.addi %6, %8 : i64 + %10 = tt.addptr %arg0, %9 : !tt.ptr, i64 + %11 = arith.muli %2, %c128_i32 : i32 + %12 = arith.muli %1, %c16_i32 : i32 + %13 = arith.addi %12, %11 : i32 + %14 = tt.make_tensor_ptr %10, [%c1024_i64, %c64_i64], [%c64_i64, %c1_i64], [%13, %c0_i32] {order = array} : >> + %15 = tt.addptr %arg2, %9 : !tt.ptr, i64 + %16 = tt.make_tensor_ptr %15, [%c1024_i64, %c64_i64], [%c64_i64, %c1_i64], [%c0_i32, %c0_i32] {order = array} : >> + %17 = tt.addptr %arg1, %9 : !tt.ptr, i64 + %18 = tt.make_tensor_ptr %17, [%c64_i64, %c1024_i64], [%c1_i64, %c64_i64], [%c0_i32, %c0_i32] {order = array} : >> + %19 = tt.addptr %arg5, %9 : !tt.ptr, i64 + %20 = tt.make_tensor_ptr %19, [%c1024_i64, %c64_i64], [%c64_i64, %c1_i64], [%13, %c0_i32] {order = array} : > + %21 = arith.mulf %arg3, %cst : f32 + %22 = tt.load %14 : !tt.ptr>> + // CHECK: tt.splat {{.*}} : f32 -> tensor<16xf32, #triton_gpu.slice<{dim = 1, parent = #warp}>> + // CHECK-COUNT-8: tt.splat {{.*}} : f32 -> tensor<8x16xf32> + %23 = tt.splat %21 : f32 -> tensor<16xf32, #triton_gpu.slice<{dim = 1, parent = #warp}>> + %24 = tt.splat %21 : f32 -> tensor<16x64xf32, #warp> + // CHECK: 30 = scf.for + %25:5 = scf.for %arg6 = %c0_i32 to %c1024_i32 step %c64_i32 iter_args(%arg7 = %cst_2, %arg8 = %cst_0, %arg9 = %cst_1, %arg10 = %16, %arg11 = %18) -> (tensor<16xf32, #triton_gpu.slice<{dim = 1, parent = #warp}>>, tensor<16x64xf32, #warp>, tensor<16xf32, #triton_gpu.slice<{dim = 1, parent = #warp}>>, !tt.ptr>>, !tt.ptr>>) : i32 { + // CHECK-COUNT-16: tt.load {{.*}} {DotIdx = 1 : i32} : !tt.ptr> + // CHECK-COUNT-32: tt.dot {{.*}} : tensor<8x16xf16> * tensor<16x16xf16> -> tensor<8x16xf32> + %29 = tt.load %arg11 : !tt.ptr>> + %30 = tt.dot %22, %29, %cst_0, inputPrecision = tf32 : tensor<16x64xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #warp}>> * tensor<64x64xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #warp, isTransposed = 1}>> -> tensor<16x64xf32, #warp> + + // CHECK-COUNT-2: arith.maxnumf {{.*}} : tensor<8x16xf32> + // CHECK: [[MAX:%.*]] = arith.maxnumf {{.*}} : tensor<8x16xf32> + // CHECK-NEXT: [[EXTRACT0:%.*]] = triton_intel_gpu.extract [[MAX]][0] : tensor<8x16xf32> -> tensor<16xf32> + // CHECK-NEXT: "tt.reduce"([[EXTRACT0]]) <{axis = 0 : i32}> ({ + // CHECK: }) : (tensor<16xf32>) -> f32 + %31 = "tt.reduce"(%30) <{axis = 1 : i32}> ({ + ^bb0(%arg12: f32, %arg13: f32): + %53 = arith.maxnumf %arg12, %arg13 : f32 + tt.reduce.return %53 : f32 + }) : (tensor<16x64xf32, #warp>) -> tensor<16xf32, #triton_gpu.slice<{dim = 1, parent = #warp}>> + %32 = arith.mulf %31, %23 : tensor<16xf32, #triton_gpu.slice<{dim = 1, parent = #warp}>> + %33 = arith.maxnumf %arg9, %32 : tensor<16xf32, #triton_gpu.slice<{dim = 1, parent = #warp}>> + %34 = arith.mulf %30, %24 : tensor<16x64xf32, #warp> + + // CHECK: tt.expand_dims {{.*}} {axis = 1 : i32} : tensor<16xf32 + // CHECK: triton_intel_gpu.broadcast {{.*}} -> tensor<16x16xf32> + %35 = tt.expand_dims %33 {axis = 1 : i32} : tensor<16xf32, #triton_gpu.slice<{dim = 1, parent = #warp}>> -> tensor<16x1xf32, #warp> + %36 = tt.broadcast %35 : tensor<16x1xf32, #warp> -> tensor<16x64xf32, #warp> + %37 = arith.subf %34, %36 : tensor<16x64xf32, #warp> + %38 = math.exp2 %37 : tensor<16x64xf32, #warp> + %39 = "tt.reduce"(%38) <{axis = 1 : i32}> ({ + ^bb0(%arg12: f32, %arg13: f32): + %53 = arith.addf %arg12, %arg13 : f32 + tt.reduce.return %53 : f32 + }) : (tensor<16x64xf32, #warp>) -> tensor<16xf32, #triton_gpu.slice<{dim = 1, parent = #warp}>> + %40 = arith.subf %arg9, %33 : tensor<16xf32, #triton_gpu.slice<{dim = 1, parent = #warp}>> + %41 = math.exp2 %40 : tensor<16xf32, #triton_gpu.slice<{dim = 1, parent = #warp}>> + %42 = arith.mulf %arg7, %41 : tensor<16xf32, #triton_gpu.slice<{dim = 1, parent = #warp}>> + %43 = arith.addf %42, %39 : tensor<16xf32, #triton_gpu.slice<{dim = 1, parent = #warp}>> + %44 = tt.expand_dims %41 {axis = 1 : i32} : tensor<16xf32, #triton_gpu.slice<{dim = 1, parent = #warp}>> -> tensor<16x1xf32, #warp> + %45 = tt.broadcast %44 : tensor<16x1xf32, #warp> -> tensor<16x64xf32, #warp> + %46 = arith.mulf %arg8, %45 : tensor<16x64xf32, #warp> + %47 = tt.load %arg10 : !tt.ptr>> + %48 = arith.truncf %38 : tensor<16x64xf32, #warp> to tensor<16x64xf16, #warp> + %49 = triton_gpu.convert_layout %48 : tensor<16x64xf16, #warp> -> tensor<16x64xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #warp}>> + + // CHECK-COUNT-32: tt.dot {{.*}} : tensor<8x16xf16> * tensor<16x16xf16> -> tensor<8x16xf32> + // CHECK-COUNT-4: tt.advance {{.*}} : > + // CHECK-COUNT-16: tt.advance {{.*}} : > + // CHECK: scf.yield + %50 = tt.dot %49, %47, %46, inputPrecision = tf32 : tensor<16x64xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #warp}>> * tensor<64x64xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #warp}>> -> tensor<16x64xf32, #warp> + %51 = tt.advance %arg10, [%c64_i32, %c0_i32] : >> + %52 = tt.advance %arg11, [%c0_i32, %c64_i32] : >> + scf.yield %43, %50, %33, %51, %52 : tensor<16xf32, #triton_gpu.slice<{dim = 1, parent = #warp}>>, tensor<16x64xf32, #warp>, tensor<16xf32, #triton_gpu.slice<{dim = 1, parent = #warp}>>, !tt.ptr>>, !tt.ptr>> + } {triton_gpu.workload = 4 : i32, tt.divisibility_arg1 = dense<64> : tensor<1xi32>} + %26 = tt.expand_dims %25#0 {axis = 1 : i32} : tensor<16xf32, #triton_gpu.slice<{dim = 1, parent = #warp}>> -> tensor<16x1xf32, #warp> + %27 = tt.broadcast %26 : tensor<16x1xf32, #warp> -> tensor<16x64xf32, #warp> + %28 = arith.divf %25#1, %27 : tensor<16x64xf32, #warp> + tt.store %20, %28 : !tt.ptr> + tt.return +} +} diff --git a/test/TritonIntelGPU/tritonintelgpu-invalid.mlir b/test/TritonIntelGPU/tritonintelgpu-invalid.mlir index 674773ce20..22cba75bee 100644 --- a/test/TritonIntelGPU/tritonintelgpu-invalid.mlir +++ b/test/TritonIntelGPU/tritonintelgpu-invalid.mlir @@ -1,53 +1,23 @@ // RUN: triton-opt -split-input-file -verify-diagnostics %s -tt.func @triton_intel_gpu.glue(%tensor1 : tensor<16x8xf16>, %tensor2 : tensor<8xf16>) { - // expected-error @+1 {{'triton_intel_gpu.glue' op operands and result must have the same rank}} - triton_intel_gpu.glue %tensor1, %tensor2 : (tensor<16x8xf16>, tensor<8xf16>) -> tensor<24x8xf16> - tt.return -} - -// ----- - -tt.func @triton_intel_gpu.glue(%ptr1 : !tt.ptr>, %ptr2 : !tt.ptr>) { - // expected-error @+1 {{'triton_intel_gpu.glue' op operands and result must have the same rank}} - triton_intel_gpu.glue %ptr1, %ptr2 : (!tt.ptr>, !tt.ptr>) -> !tt.ptr> - tt.return -} - -// ----- - -tt.func @triton_intel_gpu.glue(%tensor1 : tensor<16x8xf16>, %tensor2 : tensor<16x8xf32>) { - // expected-error @+1 {{'triton_intel_gpu.glue' op operands and result element type must match}} - triton_intel_gpu.glue %tensor1, %tensor2 : (tensor<16x8xf16>, tensor<16x8xf32>) -> tensor<16x16xf16> +// COM: Ensure that tensors with different shape cannot be glued. +tt.func @triton_intel_gpu.glue(%tensor1 : tensor<16xf16>, %tensor2 : tensor<16x8xf32>) { + // expected-error @+1 {{'triton_intel_gpu.glue' op operands must have the same type}} + triton_intel_gpu.glue %tensor1, %tensor2 : (tensor<16xf16>, tensor<16x8xf32>) -> tensor<16x16xf16> tt.return } // ----- +// COM: Ensure that tensors with the different element types cannot be glued. tt.func @triton_intel_gpu.glue(%ptr1 : !tt.ptr>, %ptr2 : !tt.ptr>) { - // expected-error @+1 {{'triton_intel_gpu.glue' op operands and result element type must match}} + // expected-error @+1 {{'triton_intel_gpu.glue' op operands must have the same type}} triton_intel_gpu.glue %ptr1, %ptr2 : (!tt.ptr>, !tt.ptr>) -> !tt.ptr> tt.return } // ----- -tt.func @triton_intel_gpu.glue(%tensor1 : tensor<16x8xf16>, %tensor2 : tensor<8x8xf16>) { - // expected-error @+1 {{'triton_intel_gpu.glue' op operands must have the same shape}} - triton_intel_gpu.glue %tensor1, %tensor2 : (tensor<16x8xf16>, tensor<8x8xf16>) -> tensor<24x8xf16> - tt.return -} - -// ----- - -tt.func @triton_intel_gpu.glue(%ptr1 : !tt.ptr>, %ptr2 : !tt.ptr>) { - // expected-error @+1 {{'triton_intel_gpu.glue' op operands must have the same shape}} - triton_intel_gpu.glue %ptr1, %ptr2 : (!tt.ptr>, !tt.ptr>) -> !tt.ptr> - tt.return -} - -// ----- - tt.func @triton_intel_gpu.glue(%tensor1 : tensor<16x8xf16>, %tensor2 : tensor<16x8xf16>) { // expected-error @+1 {{'triton_intel_gpu.glue' op operands cannot exceed result size along any dimension}} triton_intel_gpu.glue %tensor1, %tensor2 : (tensor<16x8xf16>, tensor<16x8xf16>) -> tensor<8x16xf16> @@ -88,22 +58,6 @@ tt.func @triton_intel_gpu.glue(%tensor1 : tensor<16x8xf16>, %tensor2 : tensor<16 // ----- -tt.func @triton_intel_gpu.extract(%tensor : tensor<16x16xf16>) { - // expected-error @+1 {{'triton_intel_gpu.extract' op operand and result must have the same rank}} - triton_intel_gpu.extract %tensor[0] : tensor<16x16xf16> -> tensor<4xf16> - tt.return -} - -// ----- - -tt.func @triton_intel_gpu.extract(%ptr : !tt.ptr>) { - // expected-error @+1 {{'triton_intel_gpu.extract' op operand and result must have the same rank}} - triton_intel_gpu.extract %ptr[0] : !tt.ptr> -> !tt.ptr> - tt.return -} - -// ----- - tt.func @triton_intel_gpu.extract(%tensor : tensor<16x16xf16>) { // expected-error @+1 {{'triton_intel_gpu.extract' op operand and result element type must match}} triton_intel_gpu.extract %tensor[0] : tensor<16x16xf16> -> tensor<4x4xf32> diff --git a/third_party/intel/include/Dialect/TritonIntelGPU/IR/TritonIntelGPUOps.td b/third_party/intel/include/Dialect/TritonIntelGPU/IR/TritonIntelGPUOps.td index 77773d4bb4..262704fdee 100644 --- a/third_party/intel/include/Dialect/TritonIntelGPU/IR/TritonIntelGPUOps.td +++ b/third_party/intel/include/Dialect/TritonIntelGPU/IR/TritonIntelGPUOps.td @@ -42,10 +42,8 @@ def TTIG_GlueOp : TTIG_Op<"glue", [Pure]> { The `glue` operation glues its input operands to form the result tensor (ptr to tensor) shape. Input operands are first concatenated along the first (leftmost) dimension until the result shape along that dimension is reached, - then along the next dimension, and so on. The result tensor and all input - operands must have the same rank and element type. Furthermore all the input - tensors must have the same shape. Concatenation of the input operands must - yield the exact tensor shape of the result. + then along the next dimension, and so on. Input operands must have the same type. + Concatenation of the input operands must yield the exact tensor shape of the result. Examples: ```mlir @@ -57,9 +55,12 @@ def TTIG_GlueOp : TTIG_Op<"glue", [Pure]> { %res3 = triton_intel_gpu.glue %p1, %p2, %p3 : (ptr>, ptr>, ptr>) -> ptr> + %res4 = triton_intel_gpu.glue %f1, %f2 : (f16, f16) -> tensor<2xf16> + %res5 = triton_intel_gpu.glue %t1, %t2 : (tensor<8xf16>, tensor<8xf16>) -> tensor<2x8xf16> + %res6 = triton_intel_gpu.glue %t1, %t2 : (tensor<8xf16>, tensor<8xf16>) -> tensor<16xf16> ``` }]; - let arguments = (ins Variadic:$operands); + let arguments = (ins Variadic:$operands); let results = (outs TT_TensorOrTensorPtr:$res); let assemblyFormat = [{ operands attr-dict `:` functional-type(operands, results) @@ -90,7 +91,7 @@ def TTIG_ExtractOp : TTIG_Op<"extract", [Pure]> { column-major order }]; let arguments = (ins TT_TensorOrTensorPtr:$base, I32Attr:$index); - let results = (outs TT_TensorOrTensorPtr:$res); + let results = (outs TT_Type:$res); let assemblyFormat = [{ $base `[` $index `]` attr-dict `:` type($base) `->` type($res) }]; @@ -116,4 +117,20 @@ def TTIG_PrefetchOp : TTIG_Op<"prefetch"> { }]; } +// same as tt.broadcast except that we don't require SameOperandsAndResultEncoding +def TTIG_BroadcastOp : TTIG_Op<"broadcast", [Pure, SameOperandsAndResultElementType]> { + let summary = "broadcast a tensor"; + let description = [{ + For a given tensor, broadcast changes one or more dimensions with size 1 + to a new size, e.g. tensor<1x32x1xf32> -> tensor<2x32x4xf32>. You cannot + change the size of a non-1 dimension. + }]; + + let arguments = (ins TT_Tensor:$src); + let results = (outs TT_Tensor:$result); + let assemblyFormat = [{ + $src attr-dict `:` type($src) `->` type($result) + }]; +} + #endif diff --git a/third_party/intel/lib/Dialect/TritonIntelGPU/IR/Ops.cpp b/third_party/intel/lib/Dialect/TritonIntelGPU/IR/Ops.cpp index d4856c03c3..fbbb1ad0f2 100644 --- a/third_party/intel/lib/Dialect/TritonIntelGPU/IR/Ops.cpp +++ b/third_party/intel/lib/Dialect/TritonIntelGPU/IR/Ops.cpp @@ -59,16 +59,13 @@ static SmallVector getShape(Type type) { /// Return the element type of an input tensor (or ptr to tensor). static Type getElementType(Type type) { return TypeSwitch(type) - .Case([](auto ty) { return ty.getElementType(); }) + .Case([](auto ty) { return ty.getElementType(); }) .Case([](auto ty) { assert(isa(ty.getPointeeType()) && "Expecting ptr to tensor"); return cast(ty.getPointeeType()).getElementType(); }) - .Default([](auto) { - llvm_unreachable("Unexpected type"); - return Type(); - }); + .Default([](auto ty) { return ty; }); } /// Return the size of the specified dimension of an input tensor (or ptr to @@ -91,26 +88,18 @@ namespace mlir::triton::gpu::intel { LogicalResult GlueOp::verify() { SmallVector inputTypes; - for (auto input : getOperands()) + for (Value input : getOperands()) inputTypes.push_back(input.getType()); + Type inputType = inputTypes.front(); Type resultType = getRes().getType(); - unsigned resultRank = getRank(resultType); - if (llvm::any_of(inputTypes, - [&](Type type) { return getRank(type) != resultRank; })) - return emitOpError("operands and result must have the same rank"); + if (!llvm::all_of(inputTypes, [&](Type type) { return type == inputType; })) + return emitOpError("operands must have the same type"); - Type resultElementType = getElementType(resultType); - if (llvm::any_of(inputTypes, [&](Type type) { - return getElementType(type) != resultElementType; - })) - return emitOpError("operands and result element type must match"); - - SmallVector inputShape = getShape(inputTypes[0]); - if (llvm::any_of(inputTypes, - [&](Type type) { return getShape(type) != inputShape; })) - return emitOpError("operands must have the same shape"); + if (!isTensorOrTensorPointerType(inputType)) + return success(); + unsigned resultRank = getRank(resultType); if (llvm::any_of(inputTypes, [&](Type type) { for (unsigned i = 0; i < resultRank; ++i) { unsigned resultSize = getDimSize(resultType, i); @@ -123,7 +112,6 @@ LogicalResult GlueOp::verify() { return emitOpError( "operands cannot exceed result size along any dimension"); - auto inputType = inputTypes[0]; for (unsigned i = 0; i < resultRank; ++i) { unsigned resultSize = getDimSize(resultType, i); unsigned inputSize = getDimSize(inputType, i); @@ -133,13 +121,13 @@ LogicalResult GlueOp::verify() { // Verify that the composition of the input operands covers the output tensor // shape. + SmallVector inputShape = getShape(inputTypes[0]); SmallVector resultShape = getShape(resultType); unsigned numResultElems = product(resultShape); unsigned numInputElems = product(inputShape); if (inputTypes.size() * numInputElems != numResultElems) return emitOpError("glued operands do not exactly cover the result shape"); - return success(); } @@ -147,16 +135,22 @@ LogicalResult ExtractOp::verify() { Type resultType = getRes().getType(); Type operandType = getBase().getType(); - unsigned resultRank = getRank(resultType); - unsigned operandRank = getRank(operandType); - if (operandRank != resultRank) - return emitOpError("operand and result must have the same rank"); - Type resultElemType = getElementType(resultType); Type operandElemType = getElementType(operandType); if (operandElemType != resultElemType) return emitOpError("operand and result element type must match"); + if (!isTensorOrTensorPointerType(operandType)) + return success(); + + unsigned resultRank = getRank(resultType); + unsigned operandRank = getRank(operandType); + if (operandRank != resultRank) + return success(); + + /// FIXME: the check below works for tensors with same rank, try to simplify + /// it later. + // ensure the input can be partitioned by the requested result. SmallVector resultShape = getShape(resultType); SmallVector operandShape = getShape(operandType); @@ -179,7 +173,6 @@ LogicalResult ExtractOp::verify() { unsigned index = getIndex(); if (index >= numTiles) return emitOpError("index must be less than ") << numTiles; - return success(); } diff --git a/third_party/intel/lib/TritonIntelGPUTransforms/MatchTargetSize.cpp b/third_party/intel/lib/TritonIntelGPUTransforms/MatchTargetSize.cpp index e835574298..d6503a1e3c 100644 --- a/third_party/intel/lib/TritonIntelGPUTransforms/MatchTargetSize.cpp +++ b/third_party/intel/lib/TritonIntelGPUTransforms/MatchTargetSize.cpp @@ -146,6 +146,11 @@ class MatchTargetSizePass dotAttrs.insert(resultType.getEncoding()); }); + auto hasSliceAttr = [](Type type) { + auto tType = dyn_cast(type); + return tType && isa(tType.getEncoding()); + }; + // Split operations to match the target architecture native shapes. m.walk([&](Operation *op) { SmallVector types(op->getOperandTypes().begin(), @@ -158,13 +163,18 @@ class MatchTargetSizePass return WalkResult::advance(); if (isa(op)) return WalkResult::advance(); + if (auto expand = dyn_cast(op)) + return WalkResult::advance(); - LLVM_DEBUG({ - llvm::dbgs() << "Processing operation: " << *op << "\n"; - llvm::dbgs() << "Module before transformation:\n" << m << "\n\n"; - }); - - if (auto cstOp = dyn_cast(op)) { + // FIXME: hack it for now + if (auto convert = dyn_cast(op)) + convert.getResult().replaceAllUsesWith(convert.getSrc()); + else if (auto reduce = dyn_cast(op)) + transformReduceOp(reduce); + else if (op->getNumResults() == 1 && + hasSliceAttr(op->getResultTypes()[0])) + return WalkResult::advance(); + else if (auto cstOp = dyn_cast(op)) { recordRootSubSize(cstOp.getResult().getType()); transformArithConstantOp(cstOp); } else if (auto ptrOp = dyn_cast(op)) { @@ -172,17 +182,17 @@ class MatchTargetSizePass transformMakeTensorPtrOp(ptrOp); } else if (auto dot = dyn_cast(op)) transformDotOp(dot); + else if (auto bc = dyn_cast(op)) + transformBroadcastOp(bc); else transformGenericOp(op); - LLVM_DEBUG({ - llvm::dbgs() << "Module after transformation:\n" << m << "\n\n"; - }); - return WalkResult::advance(); }); LLVM_DEBUG(llvm::dbgs() << "Canonicalizing...\n"); + LLVM_DEBUG(llvm::dbgs() << "Module before canonicalization:\n" + << m << "\n\n"); canonicalize(); LLVM_DEBUG(llvm::dbgs() << "Module after canonicalization:\n" << m << "\n\n"); @@ -212,11 +222,15 @@ class MatchTargetSizePass SmallVector getSubOpSize(RankedTensorType type) const; std::tuple, Type, SmallVector> getSubTypeAndShape(Type type) const; + Value getSubVal(Operation *op, Value val, ArrayRef srcOffset, + ArrayRef dstSize); /// Transformations for specific operations. void transformMakeTensorPtrOp(tt::MakeTensorPtrOp op); void transformArithConstantOp(arith::ConstantOp op); void transformDotOp(tt::DotOp dot); + void transformReduceOp(tt::ReduceOp op); + void transformBroadcastOp(tt::BroadcastOp op); /// Generic transformation. void transformGenericOp(Operation *op); @@ -403,17 +417,6 @@ class ScfPattern : public OpRewritePattern { } } - // Bail out if any user of a loop result is not an 'extract' operation - // (otherwise we would have to materialize a 'glue' operation after the loop - // is replaced, which complicates things). - for (OpResult result : forOp->getResults()) { - if (llvm::any_of(result.getUsers(), [](Operation *user) { - return !isa(user); - })) { - return false; - } - } - return true; } }; @@ -491,12 +494,14 @@ MatchTargetSizePass::getSubOpSize(RankedTensorType type) const { Attribute layout = type.getEncoding(); assert(layout && "Expecting a valid layout"); - // Dot operation. + // Dot related operation. + const TargetArchNativeSizes::DotShape &dotShape = + nativeSizes.getDotShape(type.getElementTypeBitWidth()); if (dotAttrs.count(layout)) { - const TargetArchNativeSizes::DotShape &dotShape = - nativeSizes.getDotShape(type.getElementTypeBitWidth()); - SmallVector nativeDotSize{dotShape.m, dotShape.n}; - return nativeDotSize; + return {dotShape.m, dotShape.n}; + } else if (auto dotAttr = dyn_cast(layout)) { + if (dotAttr.getIsTransposed() == 1 && dotAttr.getOpIdx() == 1) + return {dotShape.k, dotShape.n}; } // Load/Store operations. @@ -516,7 +521,7 @@ MatchTargetSizePass::getSubOpSize(RankedTensorType type) const { // 32 = 2 * 16(subgroupSize) which is for large load/store // max 2d block prefetch width is 16 for 32-bit datatype subSize[1] = std::min(sizeInBits == 32 ? 16L : 32L, shape[1]); - // FIXME: From gfxspec, max 2d block load height is 32 + // max 2d block load height is 32 subSize[0] = std::min(32L, shape[0]); } else if (auto dotLayout = dyn_cast(layout)) { const TargetArchNativeSizes::BlockMemShape &memShape = @@ -544,6 +549,7 @@ MatchTargetSizePass::getSubOpSize(RankedTensorType type) const { return subSize; } +/// FIXME: add a map for look up /// return [shape, subType, subSize] for a tensor (or pointer to tensor) std::tuple, Type, SmallVector> MatchTargetSizePass::getSubTypeAndShape(Type type) const { @@ -567,6 +573,91 @@ MatchTargetSizePass::getSubTypeAndShape(Type type) const { return {{0}, type, {0}}; } +Value MatchTargetSizePass::getSubVal(Operation *op, Value val, + ArrayRef srcOffset, + ArrayRef dstSize) { + OpBuilder b(op); + Location loc = op->getLoc(); + auto elmTy = cast(val.getType()).getElementType(); + auto [shape, subType, subSize] = getSubTypeAndShape(val.getType()); + unsigned srcIdx = (srcOffset[1] / subSize[1]) * (shape[0] / subSize[0]) + + srcOffset[0] / subSize[0]; + Value subSrcVal = b.create(loc, subType, val, srcIdx); + assert(dstSize[0] <= subSize[0] && "add more support"); + unsigned dstIdx = + ((srcOffset[1] % subSize[1]) / dstSize[1]) * (subSize[0] / dstSize[0]) + + (srcOffset[0] % subSize[0]) / dstSize[0]; + auto dstType = dstSize[0] == 1 ? RankedTensorType::get(dstSize[1], elmTy) + : RankedTensorType::get(dstSize, elmTy); + Value dstVal = b.create(loc, dstType, subSrcVal, dstIdx); + return dstVal; +} + +void MatchTargetSizePass::transformReduceOp(tt::ReduceOp op) { + Location loc = op.getLoc(); + OpBuilder b(op); + assert(op.getSrcs().size() == 1 && "only support one src"); + Value src = op.getSrcs().front(); + auto srcTy = cast(src.getType()); + unsigned dims = srcTy.getShape().size(); + unsigned axis = op.getAxis(); + assert(axis == dims - 1 && "only support last axis"); + assert(dims <= 2 && "only support 1D/2D tensor"); + int64_t outer = dims == 2 ? srcTy.getShape()[0] : 1; + + SmallVector glueVals; + unsigned step = 8; // FIXME: fixed to 8 for now. + for (unsigned i = 0; i < outer; i += step) { + SmallVector subVals; + // FIXME: 16 is the supported IR reduce length + for (unsigned j = 0; j < srcTy.getShape()[axis]; j += 16) { + Value subVal = getSubVal(op, src, {i, j}, {step, 16}); + subVals.push_back(subVal); + } + auto subType = RankedTensorType::get({step, 16}, srcTy.getElementType()); + auto combine = op.getCombineOp().front().getOperations().begin(); + StringAttr id = combine->getName().getIdentifier(); + Value acc; + switch (subVals.size()) { + case 1: + acc = subVals[0]; + break; + case 2: { + Operation *acc01 = b.create(loc, id, {subVals[0], subVals[1]}, subType); + acc = acc01->getResult(0); + break; + } + case 4: { + Operation *acc01 = b.create(loc, id, {subVals[0], subVals[1]}, subType); + Operation *acc23 = b.create(loc, id, {subVals[2], subVals[3]}, subType); + Operation *accOp = b.create( + loc, id, {acc01->getResult(0), acc23->getResult(0)}, subType); + acc = accOp->getResult(0); + break; + } + default: + assert(false && "add more reduce size support"); + } + + SmallVector subOps; + for (unsigned j = 0; j < step; j++) { + auto subType = RankedTensorType::get(16, srcTy.getElementType()); + Value subAcc = b.create(loc, subType, acc, j); + auto subRed = b.create(loc, subAcc, 0); + Region &subRegion = subRed.getCombineOp(); + b.cloneRegionBefore(op.getCombineOp(), subRegion, subRegion.end()); + subOps.push_back(subRed.getResult()[0]); + } + glueVals.append(subOps); + } + + auto glue = b.create(loc, op.getResultTypes()[0], glueVals); + op->replaceAllUsesWith(glue->getResults()); + op->erase(); + + return; +} + void MatchTargetSizePass::transformMakeTensorPtrOp(tt::MakeTensorPtrOp op) { Type resultType = op.getResult().getType(); auto [shape, subType, subSize] = getSubTypeAndShape(resultType); @@ -693,6 +784,39 @@ void MatchTargetSizePass::transformDotOp(tt::DotOp dot) { dot->erase(); } +void MatchTargetSizePass::transformBroadcastOp(tt::BroadcastOp op) { + OpBuilder b(op); + Location loc = op->getLoc(); + RankedTensorType resType = op.getResult().getType(); + auto [shape, subType, subSize] = getSubTypeAndShape(resType); + auto tType = cast(subType); + RankedTensorType srcType = op.getSrc().getType(); + unsigned srcDim0 = srcType.getShape()[0]; + unsigned dstDim0 = tType.getShape()[0]; + Operation *glue; + if (srcDim0 == dstDim0) { + Value newOp = b.create(loc, tType, op.getSrc()); + unsigned num = resType.getShape()[1] / tType.getShape()[1]; + SmallVector ops(num, newOp); + glue = b.create(loc, resType, ops); + } else { + assert(srcDim0 == 2 * dstDim0 && "add more support"); + auto newTy = RankedTensorType::get({srcDim0, tType.getShape()[1]}, + tType.getElementType()); + auto newOp = b.create(loc, newTy, op.getSrc()); + auto extract0 = b.create(loc, tType, newOp, 0); + auto extract1 = b.create(loc, tType, newOp, 1); + SmallVector ops{extract0, extract1, extract0, extract1, + extract0, extract1, extract0, extract1}; + glue = b.create(loc, resType, ops); + } + + op->replaceAllUsesWith(glue->getResults()); + op->erase(); + + return; +} + void MatchTargetSizePass::transformGenericOp(Operation *op) { unsigned numResults = op->getResults().size(); unsigned dotIdx = 2; @@ -743,7 +867,8 @@ void MatchTargetSizePass::transformGenericOp(Operation *op) { } return operand; }); - Operation *subOp; + + Operation *subOp = nullptr; if (numResults == 0) subOp = b.create(loc, op->getName().getIdentifier(), newOperands, {}, op->getAttrs()); @@ -762,9 +887,9 @@ void MatchTargetSizePass::transformGenericOp(Operation *op) { } } - if (numResults == 1) { + if (numResults == 1) op->replaceAllUsesWith(b.create(loc, type, subOps)); - } + op->erase(); }