
add support for attention in match-target-size #1540

Merged · 18 commits · Jul 15, 2024
23 changes: 20 additions & 3 deletions include/triton/Dialect/TritonGPU/IR/TritonGPUAttrDefs.td
@@ -1270,13 +1270,18 @@ The parent field is the layout of d.
kWidth defines the number of consecutive elements stored by one thread along the k dimension.
Some layouts do not use this parameter, either because they have a fixed number of
elements along the K dim, or they use all elements of the tensor along the K dim.

`isTransposed` indicates that the result tensor is transposed so that it can be loaded with
a transpose flag. This is used for chained dot operations (e.g., a Flash Attention kernel).
}];

let parameters = (
ins
"unsigned":$opIdx,
"Attribute":$parent,
DefaultValuedParameter<"unsigned", "0">:$kWidth
DefaultValuedParameter<"unsigned", "0">:$kWidth,
// Intel-specific
DefaultValuedParameter<"bool", "false">:$isTransposed
);

let builders = [
@@ -1286,15 +1291,27 @@ elements along the K dim, or they use all elements of the tensor along the K dim
"Type":$eltTy), [{
NvidiaMmaEncodingAttr parentAttr = mlir::dyn_cast<NvidiaMmaEncodingAttr>(parent);
if (!parentAttr || !parentAttr.isAmpere())
return $_get(context, opIdx, parent, 0);
return $_get(context, opIdx, parent, 0, false);
unsigned bitwidth = eltTy.getIntOrFloatBitWidth();
unsigned MMAv2kWidth = 32 / bitwidth;
return $_get(context, opIdx, parent, MMAv2kWidth);
return $_get(context, opIdx, parent, MMAv2kWidth, false);
}]>,
AttrBuilder<(ins "unsigned":$opIdx,
"Attribute":$parent,
"unsigned":$kWidth), [{
return $_get(context, opIdx, parent, kWidth, false);
}]>,
AttrBuilder<(ins "unsigned":$opIdx,
"Attribute":$parent,
"unsigned":$kWidth,
"bool":$isTransposed), [{
return $_get(context, opIdx, parent, kWidth, isTransposed);
}]>
];

let assemblyFormat = "`<` `{` struct(params) `}` `>`";
let genVerifyDecl = 1;
let skipDefaultBuilders = 1;
let extraClassDeclaration = extraDistributedDeclaration # [{
SmallVector<unsigned> getContigPerThread() {
return getSizePerThread();
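For illustration, the new parameter shows up in a dot-operand layout as follows (a minimal sketch modeled on the attention test later in this PR; `#warp` is the Intel warp layout defined there):

```mlir
// B operand of the second (chained) dot in an attention kernel. The flag
// marks the tile as transposed so it can be loaded with a transpose flag;
// it defaults to false, leaving existing layouts unchanged.
#dot_b = #triton_gpu.dot_op<{opIdx = 1, parent = #warp, isTransposed = 1}>
```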
2 changes: 1 addition & 1 deletion lib/Dialect/TritonGPU/IR/Dialect.cpp
@@ -1018,7 +1018,7 @@ SmallVector<unsigned> DotOperandEncodingAttr::getShapePerCTATile(

LogicalResult DotOperandEncodingAttr::verify(
::llvm::function_ref<::mlir::InFlightDiagnostic()> emitError,
unsigned opIdx, Attribute parent, unsigned kWidth) {
unsigned opIdx, Attribute parent, unsigned kWidth, bool isTransposed) {
if (opIdx != 0 && opIdx != 1) {
return emitError()
<< "triton_gpu.dot_op opIdx paramenter can be 0 or 1, got: "
136 changes: 108 additions & 28 deletions test/TritonIntelGPU/match-target-size.mlir
@@ -116,34 +116,6 @@ tt.func public @simplify_scf_for(%arg0: tensor<16x8xf16>, %arg1: tensor<16x8xf16

// -----

// COM: Test SCF canonicalization: ensure loop is not modified when the result is not used by just 'extract' operations.
tt.func public @simplify_scf_for(%arg0: tensor<16x8xf16>, %arg1: tensor<16x8xf16>, %arg2: !tt.ptr<f16, 1>,
%arg3: i64, %arg4: i64, %arg5: i64, %arg6: i32, %arg7: i32) {
// CHECK-LABEL: @simplify_scf_for
// CHECK: [[GLUE:%.*]] = triton_intel_gpu.glue
// CHECK-NEXT: [[RES:%.*]] = scf.for {{.*}} iter_args([[INIT1:%.*]] = [[GLUE]]) -> (tensor<16x16xf16>) : i32 {
// CHECK: scf.yield {{.*}} : tensor<16x16xf16>
// CHECK-NEXT: }
// CHECK-NEXT: [[PTR:%.*]] = tt.make_tensor_ptr %arg2
// CHECK-NEXT: tt.store [[PTR]], [[RES]]
%lb = arith.constant 0 : i32
%ub = arith.constant 32 : i32
%st = arith.constant 1 : i32
%c1_i64 = arith.constant 1 : i64
%glue = triton_intel_gpu.glue %arg0, %arg1 : (tensor<16x8xf16>, tensor<16x8xf16>) -> tensor<16x16xf16>
%res = scf.for %iv = %lb to %ub step %st iter_args(%arg = %glue) -> (tensor<16x16xf16>) : i32 {
%e1 = triton_intel_gpu.extract %arg[0] : tensor<16x16xf16> -> tensor<16x8xf16>
%e2 = triton_intel_gpu.extract %arg[1] : tensor<16x16xf16> -> tensor<16x8xf16>
%g1 = triton_intel_gpu.glue %e1, %e2 : (tensor<16x8xf16>, tensor<16x8xf16>) -> tensor<16x16xf16>
scf.yield %g1 : tensor<16x16xf16>
}
%ptr = tt.make_tensor_ptr %arg2, [%arg3, %arg4], [%arg5, %c1_i64], [%arg6, %arg7] {order = array<i32: 1, 0>} : <tensor<16x16xf16>>
tt.store %ptr, %res {boundaryCheck = array<i32: 0, 1>, cache = 1 : i32, evict = 1 : i32} : !tt.ptr<tensor<16x16xf16>>
tt.return
}

// -----

// COM: Test SCF canonicalization: ensure loop is not modified if any user of a 'glue' init value is not an 'extract' operation.
tt.func public @simplify_scf_for(%arg0: tensor<16x8xf16>, %arg1: tensor<16x8xf16>, %arg2: !tt.ptr<f16, 1>,
%arg3: i64, %arg4: i64, %arg5: i64, %arg6: i32, %arg7: i32) {
@@ -279,3 +251,111 @@ tt.func public @matmul_kernel_with_block_pointers_tf32(%arg0: !tt.ptr<f32> {tt.d
tt.store %tptr_c, %35#0 {boundaryCheck = array<i32: 0, 1>} : !tt.ptr<tensor<8x32xf32, #warp>>
tt.return
}

// -----

// COM: Test attention-related ops.
#warp = #triton_intel_gpu.warp<{sizePerThread = [16, 64], threadsPerWarp = [1, 1], order = [1, 0]}>
module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 8 : i32, "triton_gpu.threads-per-warp" = 1 : i32} {
// CHECK-LABEL: @attn_fwd
tt.func public @attn_fwd(%arg0: !tt.ptr<f16>, %arg1: !tt.ptr<f16>, %arg2: !tt.ptr<f16>, %arg3: f32, %arg4: !tt.ptr<f32>, %arg5: !tt.ptr<f32>) {
%c16_i32 = arith.constant 16 : i32
%c128_i32 = arith.constant 128 : i32
%c1024_i64 = arith.constant 1024 : i64
%c64_i64 = arith.constant 64 : i64
%c1_i64 = arith.constant 1 : i64
%c0_i32 = arith.constant 0 : i32
%cst = arith.constant 1.44269502 : f32
%c3145728_i64 = arith.constant 3145728 : i64
%c65536_i64 = arith.constant 65536 : i64
%cst_0 = arith.constant dense<0.000000e+00> : tensor<16x64xf32, #warp>
%c64_i32 = arith.constant 64 : i32
%c1024_i32 = arith.constant 1024 : i32
%cst_1 = arith.constant dense<0xFF800000> : tensor<16xf32, #triton_gpu.slice<{dim = 1, parent = #warp}>>
%cst_2 = arith.constant dense<1.000000e+00> : tensor<16xf32, #triton_gpu.slice<{dim = 1, parent = #warp}>>
%0 = gpu.subgroup_id : index
%1 = arith.index_cast %0 : index to i32
%2 = tt.get_program_id z : i32
%3 = tt.get_program_id x : i32
%4 = tt.get_program_id y : i32
%5 = arith.extsi %3 : i32 to i64
%6 = arith.muli %5, %c3145728_i64 : i64
%7 = arith.extsi %4 : i32 to i64
%8 = arith.muli %7, %c65536_i64 : i64
%9 = arith.addi %6, %8 : i64
%10 = tt.addptr %arg0, %9 : !tt.ptr<f16>, i64
%11 = arith.muli %2, %c128_i32 : i32
%12 = arith.muli %1, %c16_i32 : i32
%13 = arith.addi %12, %11 : i32
%14 = tt.make_tensor_ptr %10, [%c1024_i64, %c64_i64], [%c64_i64, %c1_i64], [%13, %c0_i32] {order = array<i32: 1, 0>} : <tensor<16x64xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #warp}>>>
%15 = tt.addptr %arg2, %9 : !tt.ptr<f16>, i64
%16 = tt.make_tensor_ptr %15, [%c1024_i64, %c64_i64], [%c64_i64, %c1_i64], [%c0_i32, %c0_i32] {order = array<i32: 1, 0>} : <tensor<64x64xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #warp}>>>
%17 = tt.addptr %arg1, %9 : !tt.ptr<f16>, i64
%18 = tt.make_tensor_ptr %17, [%c64_i64, %c1024_i64], [%c1_i64, %c64_i64], [%c0_i32, %c0_i32] {order = array<i32: 0, 1>} : <tensor<64x64xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #warp, isTransposed = 1}>>>
%19 = tt.addptr %arg5, %9 : !tt.ptr<f32>, i64
%20 = tt.make_tensor_ptr %19, [%c1024_i64, %c64_i64], [%c64_i64, %c1_i64], [%13, %c0_i32] {order = array<i32: 1, 0>} : <tensor<16x64xf32, #warp>>
%21 = arith.mulf %arg3, %cst : f32
%22 = tt.load %14 : !tt.ptr<tensor<16x64xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #warp}>>>
// CHECK: tt.splat {{.*}} : f32 -> tensor<16xf32, #triton_gpu.slice<{dim = 1, parent = #warp}>>
// CHECK-COUNT-8: tt.splat {{.*}} : f32 -> tensor<8x16xf32>
%23 = tt.splat %21 : f32 -> tensor<16xf32, #triton_gpu.slice<{dim = 1, parent = #warp}>>
%24 = tt.splat %21 : f32 -> tensor<16x64xf32, #warp>
// CHECK: 30 = scf.for
%25:5 = scf.for %arg6 = %c0_i32 to %c1024_i32 step %c64_i32 iter_args(%arg7 = %cst_2, %arg8 = %cst_0, %arg9 = %cst_1, %arg10 = %16, %arg11 = %18) -> (tensor<16xf32, #triton_gpu.slice<{dim = 1, parent = #warp}>>, tensor<16x64xf32, #warp>, tensor<16xf32, #triton_gpu.slice<{dim = 1, parent = #warp}>>, !tt.ptr<tensor<64x64xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #warp}>>>, !tt.ptr<tensor<64x64xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #warp, isTransposed = 1}>>>) : i32 {
// CHECK-COUNT-16: tt.load {{.*}} {DotIdx = 1 : i32} : !tt.ptr<tensor<16x16xf16>>
// CHECK-COUNT-32: tt.dot {{.*}} : tensor<8x16xf16> * tensor<16x16xf16> -> tensor<8x16xf32>
%29 = tt.load %arg11 : !tt.ptr<tensor<64x64xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #warp, isTransposed = 1}>>>
%30 = tt.dot %22, %29, %cst_0, inputPrecision = tf32 : tensor<16x64xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #warp}>> * tensor<64x64xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #warp, isTransposed = 1}>> -> tensor<16x64xf32, #warp>

// CHECK-COUNT-2: arith.maxnumf {{.*}} : tensor<8x16xf32>
// CHECK: [[MAX:%.*]] = arith.maxnumf {{.*}} : tensor<8x16xf32>
// CHECK-NEXT: [[EXTRACT0:%.*]] = triton_intel_gpu.extract [[MAX]][0] : tensor<8x16xf32> -> tensor<16xf32>
// CHECK-NEXT: "tt.reduce"([[EXTRACT0]]) <{axis = 0 : i32}> ({
// CHECK: }) : (tensor<16xf32>) -> f32
%31 = "tt.reduce"(%30) <{axis = 1 : i32}> ({
^bb0(%arg12: f32, %arg13: f32):
%53 = arith.maxnumf %arg12, %arg13 : f32
tt.reduce.return %53 : f32
}) : (tensor<16x64xf32, #warp>) -> tensor<16xf32, #triton_gpu.slice<{dim = 1, parent = #warp}>>
%32 = arith.mulf %31, %23 : tensor<16xf32, #triton_gpu.slice<{dim = 1, parent = #warp}>>
%33 = arith.maxnumf %arg9, %32 : tensor<16xf32, #triton_gpu.slice<{dim = 1, parent = #warp}>>
%34 = arith.mulf %30, %24 : tensor<16x64xf32, #warp>

// CHECK: tt.expand_dims {{.*}} {axis = 1 : i32} : tensor<16xf32
// CHECK: triton_intel_gpu.broadcast {{.*}} -> tensor<16x16xf32>
%35 = tt.expand_dims %33 {axis = 1 : i32} : tensor<16xf32, #triton_gpu.slice<{dim = 1, parent = #warp}>> -> tensor<16x1xf32, #warp>
%36 = tt.broadcast %35 : tensor<16x1xf32, #warp> -> tensor<16x64xf32, #warp>
%37 = arith.subf %34, %36 : tensor<16x64xf32, #warp>
%38 = math.exp2 %37 : tensor<16x64xf32, #warp>
%39 = "tt.reduce"(%38) <{axis = 1 : i32}> ({
^bb0(%arg12: f32, %arg13: f32):
%53 = arith.addf %arg12, %arg13 : f32
tt.reduce.return %53 : f32
}) : (tensor<16x64xf32, #warp>) -> tensor<16xf32, #triton_gpu.slice<{dim = 1, parent = #warp}>>
%40 = arith.subf %arg9, %33 : tensor<16xf32, #triton_gpu.slice<{dim = 1, parent = #warp}>>
%41 = math.exp2 %40 : tensor<16xf32, #triton_gpu.slice<{dim = 1, parent = #warp}>>
%42 = arith.mulf %arg7, %41 : tensor<16xf32, #triton_gpu.slice<{dim = 1, parent = #warp}>>
%43 = arith.addf %42, %39 : tensor<16xf32, #triton_gpu.slice<{dim = 1, parent = #warp}>>
%44 = tt.expand_dims %41 {axis = 1 : i32} : tensor<16xf32, #triton_gpu.slice<{dim = 1, parent = #warp}>> -> tensor<16x1xf32, #warp>
%45 = tt.broadcast %44 : tensor<16x1xf32, #warp> -> tensor<16x64xf32, #warp>
%46 = arith.mulf %arg8, %45 : tensor<16x64xf32, #warp>
%47 = tt.load %arg10 : !tt.ptr<tensor<64x64xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #warp}>>>
%48 = arith.truncf %38 : tensor<16x64xf32, #warp> to tensor<16x64xf16, #warp>
%49 = triton_gpu.convert_layout %48 : tensor<16x64xf16, #warp> -> tensor<16x64xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #warp}>>

// CHECK-COUNT-32: tt.dot {{.*}} : tensor<8x16xf16> * tensor<16x16xf16> -> tensor<8x16xf32>
// CHECK-COUNT-4: tt.advance {{.*}} : <tensor<32x32xf16>>
// CHECK-COUNT-16: tt.advance {{.*}} : <tensor<16x16xf16>>
// CHECK: scf.yield
%50 = tt.dot %49, %47, %46, inputPrecision = tf32 : tensor<16x64xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #warp}>> * tensor<64x64xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #warp}>> -> tensor<16x64xf32, #warp>
%51 = tt.advance %arg10, [%c64_i32, %c0_i32] : <tensor<64x64xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #warp}>>>
%52 = tt.advance %arg11, [%c0_i32, %c64_i32] : <tensor<64x64xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #warp, isTransposed = 1}>>>
scf.yield %43, %50, %33, %51, %52 : tensor<16xf32, #triton_gpu.slice<{dim = 1, parent = #warp}>>, tensor<16x64xf32, #warp>, tensor<16xf32, #triton_gpu.slice<{dim = 1, parent = #warp}>>, !tt.ptr<tensor<64x64xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #warp}>>>, !tt.ptr<tensor<64x64xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #warp, isTransposed = 1}>>>
} {triton_gpu.workload = 4 : i32, tt.divisibility_arg1 = dense<64> : tensor<1xi32>}
%26 = tt.expand_dims %25#0 {axis = 1 : i32} : tensor<16xf32, #triton_gpu.slice<{dim = 1, parent = #warp}>> -> tensor<16x1xf32, #warp>
%27 = tt.broadcast %26 : tensor<16x1xf32, #warp> -> tensor<16x64xf32, #warp>
%28 = arith.divf %25#1, %27 : tensor<16x64xf32, #warp>
tt.store %20, %28 : !tt.ptr<tensor<16x64xf32, #warp>>
tt.return
}
}
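The CHECK counts above follow from the target tiling: with a native sub-dot shape of tensor<8x16> * tensor<16x16>, each 16x64 * 64x64 dot decomposes into (16/8) * (64/16) * (64/16) = 2 * 4 * 4 = 32 sub-dots, and each 64x64 B tile is loaded as (64/16) * (64/16) = 16 tiles of 16x16, matching the CHECK-COUNT-32 and CHECK-COUNT-16 lines.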
58 changes: 6 additions & 52 deletions test/TritonIntelGPU/tritonintelgpu-invalid.mlir
@@ -1,53 +1,23 @@
// RUN: triton-opt -split-input-file -verify-diagnostics %s

tt.func @triton_intel_gpu.glue(%tensor1 : tensor<16x8xf16>, %tensor2 : tensor<8xf16>) {
// expected-error @+1 {{'triton_intel_gpu.glue' op operands and result must have the same rank}}
triton_intel_gpu.glue %tensor1, %tensor2 : (tensor<16x8xf16>, tensor<8xf16>) -> tensor<24x8xf16>
tt.return
}

// -----

tt.func @triton_intel_gpu.glue(%ptr1 : !tt.ptr<tensor<16x8xf16>>, %ptr2 : !tt.ptr<tensor<16xf16>>) {
// expected-error @+1 {{'triton_intel_gpu.glue' op operands and result must have the same rank}}
triton_intel_gpu.glue %ptr1, %ptr2 : (!tt.ptr<tensor<16x8xf16>>, !tt.ptr<tensor<16xf16>>) -> !tt.ptr<tensor<16x24xf16>>
tt.return
}

// -----

tt.func @triton_intel_gpu.glue(%tensor1 : tensor<16x8xf16>, %tensor2 : tensor<16x8xf32>) {
// expected-error @+1 {{'triton_intel_gpu.glue' op operands and result element type must match}}
triton_intel_gpu.glue %tensor1, %tensor2 : (tensor<16x8xf16>, tensor<16x8xf32>) -> tensor<16x16xf16>
// COM: Ensure that tensors with different shape cannot be glued.
tt.func @triton_intel_gpu.glue(%tensor1 : tensor<16xf16>, %tensor2 : tensor<16x8xf32>) {
// expected-error @+1 {{'triton_intel_gpu.glue' op operands must have the same type}}
triton_intel_gpu.glue %tensor1, %tensor2 : (tensor<16xf16>, tensor<16x8xf32>) -> tensor<16x16xf16>
tt.return
}

// -----

// COM: Ensure that tensors with different element types cannot be glued.
tt.func @triton_intel_gpu.glue(%ptr1 : !tt.ptr<tensor<16x8xf16>>, %ptr2 : !tt.ptr<tensor<16x8xf32>>) {
// expected-error @+1 {{'triton_intel_gpu.glue' op operands and result element type must match}}
// expected-error @+1 {{'triton_intel_gpu.glue' op operands must have the same type}}
triton_intel_gpu.glue %ptr1, %ptr2 : (!tt.ptr<tensor<16x8xf16>>, !tt.ptr<tensor<16x8xf32>>) -> !tt.ptr<tensor<16x16xf16>>
tt.return
}

// -----

tt.func @triton_intel_gpu.glue(%tensor1 : tensor<16x8xf16>, %tensor2 : tensor<8x8xf16>) {
// expected-error @+1 {{'triton_intel_gpu.glue' op operands must have the same shape}}
triton_intel_gpu.glue %tensor1, %tensor2 : (tensor<16x8xf16>, tensor<8x8xf16>) -> tensor<24x8xf16>
tt.return
}

// -----

tt.func @triton_intel_gpu.glue(%ptr1 : !tt.ptr<tensor<16x8xf16>>, %ptr2 : !tt.ptr<tensor<16x16xf16>>) {
// expected-error @+1 {{'triton_intel_gpu.glue' op operands must have the same shape}}
triton_intel_gpu.glue %ptr1, %ptr2 : (!tt.ptr<tensor<16x8xf16>>, !tt.ptr<tensor<16x16xf16>>) -> !tt.ptr<tensor<16x24xf16>>
tt.return
}

// -----

tt.func @triton_intel_gpu.glue(%tensor1 : tensor<16x8xf16>, %tensor2 : tensor<16x8xf16>) {
// expected-error @+1 {{'triton_intel_gpu.glue' op operands cannot exceed result size along any dimension}}
triton_intel_gpu.glue %tensor1, %tensor2 : (tensor<16x8xf16>, tensor<16x8xf16>) -> tensor<8x16xf16>
@@ -88,22 +58,6 @@ tt.func @triton_intel_gpu.glue(%tensor1 : tensor<16x8xf16>, %tensor2 : tensor<16

// -----

tt.func @triton_intel_gpu.extract(%tensor : tensor<16x16xf16>) {
// expected-error @+1 {{'triton_intel_gpu.extract' op operand and result must have the same rank}}
triton_intel_gpu.extract %tensor[0] : tensor<16x16xf16> -> tensor<4xf16>
tt.return
}

// -----

tt.func @triton_intel_gpu.extract(%ptr : !tt.ptr<tensor<16x16xf16>>) {
// expected-error @+1 {{'triton_intel_gpu.extract' op operand and result must have the same rank}}
triton_intel_gpu.extract %ptr[0] : !tt.ptr<tensor<16x16xf16>> -> !tt.ptr<tensor<4xf16>>
tt.return
}

// -----

tt.func @triton_intel_gpu.extract(%tensor : tensor<16x16xf16>) {
// expected-error @+1 {{'triton_intel_gpu.extract' op operand and result element type must match}}
triton_intel_gpu.extract %tensor[0] : tensor<16x16xf16> -> tensor<4x4xf32>
@@ -42,10 +42,8 @@ def TTIG_GlueOp : TTIG_Op<"glue", [Pure]> {
The `glue` operation glues its input operands to form the result tensor (ptr
to tensor) shape. Input operands are first concatenated along the first
(leftmost) dimension until the result shape along that dimension is reached,
then along the next dimension, and so on. The result tensor and all input
operands must have the same rank and element type. Furthermore all the input
tensors must have the same shape. Concatenation of the input operands must
yield the exact tensor shape of the result.
then along the next dimension, and so on. Input operands must have the same type.
Concatenation of the input operands must yield the exact tensor shape of the result.

Examples:
```mlir
@@ -57,9 +55,12 @@ def TTIG_GlueOp : TTIG_Op<"glue", [Pure]> {
%res3 = triton_intel_gpu.glue %p1, %p2, %p3
: (ptr<tensor<16x8xf16>>, ptr<tensor<16x8xf16>>, ptr<tensor<16x8xf16>>)
-> ptr<tensor<16x24xf16>>
%res4 = triton_intel_gpu.glue %f1, %f2 : (f16, f16) -> tensor<2xf16>
%res5 = triton_intel_gpu.glue %t1, %t2 : (tensor<8xf16>, tensor<8xf16>) -> tensor<2x8xf16>
%res6 = triton_intel_gpu.glue %t1, %t2 : (tensor<8xf16>, tensor<8xf16>) -> tensor<16xf16>
```
}];
let arguments = (ins Variadic<TT_TensorOrTensorPtr>:$operands);
let arguments = (ins Variadic<TT_Type>:$operands);
let results = (outs TT_TensorOrTensorPtr:$res);
let assemblyFormat = [{
operands attr-dict `:` functional-type(operands, results)
@@ -90,7 +91,7 @@ def TTIG_ExtractOp : TTIG_Op<"extract", [Pure]> {
column-major order
}];
let arguments = (ins TT_TensorOrTensorPtr:$base, I32Attr:$index);
let results = (outs TT_TensorOrTensorPtr:$res);
let results = (outs TT_Type:$res);
let assemblyFormat = [{
$base `[` $index `]` attr-dict `:` type($base) `->` type($res)
}];
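Relaxing the result type from `TT_TensorOrTensorPtr` to `TT_Type` legalizes rank-reducing extracts, which the attention test above matches and which the removed rank-check diagnostics used to reject. A minimal sketch (`%acc` is a placeholder value):

```mlir
// Extract the first 16-element row of an 8x16 tile as a rank-1 tensor; the
// former "operand and result must have the same rank" error no longer fires.
%row = triton_intel_gpu.extract %acc[0] : tensor<8x16xf32> -> tensor<16xf32>
```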
@@ -116,4 +117,20 @@ def TTIG_PrefetchOp : TTIG_Op<"prefetch"> {
}];
}

// Same as tt.broadcast, except that SameOperandsAndResultEncoding is not required.
def TTIG_BroadcastOp : TTIG_Op<"broadcast", [Pure, SameOperandsAndResultElementType]> {
let summary = "broadcast a tensor";
let description = [{
For a given tensor, broadcast changes one or more dimensions with size 1
to a new size, e.g. tensor<1x32x1xf32> -> tensor<2x32x4xf32>. You cannot
change the size of a non-1 dimension.
}];

let arguments = (ins TT_Tensor:$src);
let results = (outs TT_Tensor:$result);
let assemblyFormat = [{
$src attr-dict `:` type($src) `->` type($result)
}];
}

#endif
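A usage sketch for the new op (hypothetical SSA names; the source shape follows the expand_dims/broadcast pattern in the attention test):

```mlir
// Broadcast a 16x1 column to a 16x16 tile. Unlike tt.broadcast, the operand
// and result encodings are not required to match.
%tile = triton_intel_gpu.broadcast %col : tensor<16x1xf32> -> tensor<16x16xf32>
```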