add support for attention in match-target-size #1540

Merged on Jul 15, 2024 (18 commits). Changes shown from 13 commits.
20 changes: 17 additions & 3 deletions include/triton/Dialect/TritonGPU/IR/TritonGPUAttrDefs.td
@@ -1276,7 +1276,9 @@ elements along the K dim, or they use all elements of the tensor along the K dim
ins
"unsigned":$opIdx,
"Attribute":$parent,
DefaultValuedParameter<"unsigned", "0">:$kWidth
DefaultValuedParameter<"unsigned", "0">:$kWidth,
// intel specific
DefaultValuedParameter<"bool", "false">:$transpose
);

let builders = [
@@ -1286,15 +1288,27 @@ elements along the K dim, or they use all elements of the tensor along the K dim
"Type":$eltTy), [{
NvidiaMmaEncodingAttr parentAttr = mlir::dyn_cast<NvidiaMmaEncodingAttr>(parent);
if (!parentAttr || !parentAttr.isAmpere())
return $_get(context, opIdx, parent, 0);
return $_get(context, opIdx, parent, 0, false);
unsigned bitwidth = eltTy.getIntOrFloatBitWidth();
unsigned MMAv2kWidth = 32 / bitwidth;
return $_get(context, opIdx, parent, MMAv2kWidth);
return $_get(context, opIdx, parent, MMAv2kWidth, false);
}]>,
AttrBuilder<(ins "unsigned":$opIdx,
"Attribute":$parent,
"unsigned":$kWidth), [{
return $_get(context, opIdx, parent, kWidth, false);
}]>,
AttrBuilder<(ins "unsigned":$opIdx,
"Attribute":$parent,
"unsigned":$kWidth,
"bool":$transpose), [{
return $_get(context, opIdx, parent, kWidth, transpose);
@Dewei-Wang-sh (Contributor Author, Jul 9, 2024): Added above so that existing upstream code in other .h/.cpp files does not need to change.

@Dewei-Wang-sh (Contributor Author): Created a new issue to track the upcoming refactor: #1621
}]>
];

let assemblyFormat = "`<` `{` struct(params) `}` `>`";
let genVerifyDecl = 1;
let skipDefaultBuilders = 1;
let extraClassDeclaration = extraDistributedDeclaration # [{
SmallVector<unsigned> getContigPerThread() {
return getSizePerThread();
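For illustration, a sketch of how the new parameter surfaces in the textual IR, distilled from the tests later in this PR (the `#warp` layout is the one defined there; because `transpose` is a default-valued parameter, the struct printer elides it when it is `false`):

```mlir
#warp = #triton_intel_gpu.warp<{sizePerThread = [16, 64], threadsPerWarp = [1, 1], order = [1, 0]}>
// Default form: kWidth = 0 and transpose = false are elided.
#dot_b = #triton_gpu.dot_op<{opIdx = 1, parent = #warp}>
// Intel-specific form: the B operand is marked as transposed.
#dot_b_t = #triton_gpu.dot_op<{opIdx = 1, parent = #warp, transpose = 1}>
```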
2 changes: 1 addition & 1 deletion lib/Dialect/TritonGPU/IR/Dialect.cpp
@@ -1018,7 +1018,7 @@ SmallVector<unsigned> DotOperandEncodingAttr::getShapePerCTATile(

LogicalResult DotOperandEncodingAttr::verify(
::llvm::function_ref<::mlir::InFlightDiagnostic()> emitError,
unsigned opIdx, Attribute parent, unsigned kWidth) {
unsigned opIdx, Attribute parent, unsigned kWidth, bool transpose) {
if (opIdx != 0 && opIdx != 1) {
return emitError()
<< "triton_gpu.dot_op opIdx paramenter can be 0 or 1, got: "
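As a sketch of what this verifier rejects (hypothetical IR, not taken from this PR's tests), an out-of-range opIdx still fails with the existing diagnostic, independently of the new transpose parameter:

```mlir
// Hypothetical: opIdx may only be 0 or 1, so this attribute fails verification
// with "triton_gpu.dot_op opIdx parameter can be 0 or 1, got: 2".
#bad = #triton_gpu.dot_op<{opIdx = 2, parent = #warp, transpose = 1}>
```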
134 changes: 106 additions & 28 deletions test/TritonIntelGPU/match-target-size.mlir
@@ -116,34 +116,6 @@ tt.func public @simplify_scf_for(%arg0: tensor<16x8xf16>, %arg1: tensor<16x8xf16

// -----

// COM: Test SCF canonicalization: ensure loop is not modified when the result is not used by just 'extract' operations.
tt.func public @simplify_scf_for(%arg0: tensor<16x8xf16>, %arg1: tensor<16x8xf16>, %arg2: !tt.ptr<f16, 1>,
%arg3: i64, %arg4: i64, %arg5: i64, %arg6: i32, %arg7: i32) {
// CHECK-LABEL: @simplify_scf_for
// CHECK: [[GLUE:%.*]] = triton_intel_gpu.glue
// CHECK-NEXT: [[RES:%.*]] = scf.for {{.*}} iter_args([[INIT1:%.*]] = [[GLUE]]) -> (tensor<16x16xf16>) : i32 {
// CHECK: scf.yield {{.*}} : tensor<16x16xf16>
// CHECK-NEXT: }
// CHECK-NEXT: [[PTR:%.*]] = tt.make_tensor_ptr %arg2
// CHECK-NEXT: tt.store [[PTR]], [[RES]]
%lb = arith.constant 0 : i32
%ub = arith.constant 32 : i32
%st = arith.constant 1 : i32
%c1_i64 = arith.constant 1 : i64
%glue = triton_intel_gpu.glue %arg0, %arg1 : (tensor<16x8xf16>, tensor<16x8xf16>) -> tensor<16x16xf16>
%res = scf.for %iv = %lb to %ub step %st iter_args(%arg = %glue) -> (tensor<16x16xf16>) : i32 {
%e1 = triton_intel_gpu.extract %arg[0] : tensor<16x16xf16> -> tensor<16x8xf16>
%e2 = triton_intel_gpu.extract %arg[1] : tensor<16x16xf16> -> tensor<16x8xf16>
%g1 = triton_intel_gpu.glue %e1, %e2 : (tensor<16x8xf16>, tensor<16x8xf16>) -> tensor<16x16xf16>
scf.yield %g1 : tensor<16x16xf16>
}
%ptr = tt.make_tensor_ptr %arg2, [%arg3, %arg4], [%arg5, %c1_i64], [%arg6, %arg7] {order = array<i32: 1, 0>} : <tensor<16x16xf16>>
tt.store %ptr, %res {boundaryCheck = array<i32: 0, 1>, cache = 1 : i32, evict = 1 : i32} : !tt.ptr<tensor<16x16xf16>>
tt.return
}

// -----

// COM: Test SCF canonicalization: ensure loop is not modified if any user of a 'glue' init value is not an 'extract' operation.
tt.func public @simplify_scf_for(%arg0: tensor<16x8xf16>, %arg1: tensor<16x8xf16>, %arg2: !tt.ptr<f16, 1>,
%arg3: i64, %arg4: i64, %arg5: i64, %arg6: i32, %arg7: i32) {
@@ -279,3 +251,109 @@ tt.func public @matmul_kernel_with_block_pointers_tf32(%arg0: !tt.ptr<f32> {tt.d
tt.store %tptr_c, %35#0 {boundaryCheck = array<i32: 0, 1>} : !tt.ptr<tensor<8x32xf32, #warp>>
tt.return
}

// -----

// COM: Test Attention Related Ops
// CHECK-LABEL: @attn_fwd
#warp = #triton_intel_gpu.warp<{sizePerThread = [16, 64], threadsPerWarp = [1, 1], order = [1, 0]}>
tt.func public @attn_fwd(%arg0: !tt.ptr<f16>, %arg1: !tt.ptr<f16>, %arg2: !tt.ptr<f16>, %arg3: f32, %arg4: !tt.ptr<f32>, %arg5: !tt.ptr<f32>) {
%c16_i32 = arith.constant 16 : i32
%c128_i32 = arith.constant 128 : i32
%c1024_i64 = arith.constant 1024 : i64
%c64_i64 = arith.constant 64 : i64
%c1_i64 = arith.constant 1 : i64
%c0_i32 = arith.constant 0 : i32
%cst = arith.constant 1.44269502 : f32
%c3145728_i64 = arith.constant 3145728 : i64
%c65536_i64 = arith.constant 65536 : i64
%cst_0 = arith.constant dense<0.000000e+00> : tensor<16x64xf32, #warp>
%c64_i32 = arith.constant 64 : i32
%c1024_i32 = arith.constant 1024 : i32
%cst_1 = arith.constant dense<0xFF800000> : tensor<16xf32, #triton_gpu.slice<{dim = 1, parent = #warp}>>
%cst_2 = arith.constant dense<1.000000e+00> : tensor<16xf32, #triton_gpu.slice<{dim = 1, parent = #warp}>>
%0 = gpu.subgroup_id : index
%1 = arith.index_cast %0 : index to i32
%2 = tt.get_program_id z : i32
%3 = tt.get_program_id x : i32
%4 = tt.get_program_id y : i32
%5 = arith.extsi %3 : i32 to i64
%6 = arith.muli %5, %c3145728_i64 : i64
%7 = arith.extsi %4 : i32 to i64
%8 = arith.muli %7, %c65536_i64 : i64
%9 = arith.addi %6, %8 : i64
%10 = tt.addptr %arg0, %9 : !tt.ptr<f16>, i64
%11 = arith.muli %2, %c128_i32 : i32
%12 = arith.muli %1, %c16_i32 : i32
%13 = arith.addi %12, %11 : i32
%14 = tt.make_tensor_ptr %10, [%c1024_i64, %c64_i64], [%c64_i64, %c1_i64], [%13, %c0_i32] {order = array<i32: 1, 0>} : <tensor<16x64xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #warp}>>>
%15 = tt.addptr %arg2, %9 : !tt.ptr<f16>, i64
%16 = tt.make_tensor_ptr %15, [%c1024_i64, %c64_i64], [%c64_i64, %c1_i64], [%c0_i32, %c0_i32] {order = array<i32: 1, 0>} : <tensor<64x64xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #warp}>>>
%17 = tt.addptr %arg1, %9 : !tt.ptr<f16>, i64
%18 = tt.make_tensor_ptr %17, [%c64_i64, %c1024_i64], [%c1_i64, %c64_i64], [%c0_i32, %c0_i32] {order = array<i32: 0, 1>} : <tensor<64x64xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #warp, transpose = 1}>>>
%19 = tt.addptr %arg5, %9 : !tt.ptr<f32>, i64
%20 = tt.make_tensor_ptr %19, [%c1024_i64, %c64_i64], [%c64_i64, %c1_i64], [%13, %c0_i32] {order = array<i32: 1, 0>} : <tensor<16x64xf32, #warp>>
%21 = arith.mulf %arg3, %cst : f32
%22 = tt.load %14 : !tt.ptr<tensor<16x64xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #warp}>>>
// CHECK: tt.splat {{.*}} : f32 -> tensor<16xf32, #triton_gpu.slice<{dim = 1, parent = #warp}>>
// CHECK-COUNT-8: tt.splat {{.*}} : f32 -> tensor<8x16xf32>
%23 = tt.splat %21 : f32 -> tensor<16xf32, #triton_gpu.slice<{dim = 1, parent = #warp}>>
%24 = tt.splat %21 : f32 -> tensor<16x64xf32, #warp>
// CHECK: scf.for
%25:5 = scf.for %arg6 = %c0_i32 to %c1024_i32 step %c64_i32 iter_args(%arg7 = %cst_2, %arg8 = %cst_0, %arg9 = %cst_1, %arg10 = %16, %arg11 = %18) -> (tensor<16xf32, #triton_gpu.slice<{dim = 1, parent = #warp}>>, tensor<16x64xf32, #warp>, tensor<16xf32, #triton_gpu.slice<{dim = 1, parent = #warp}>>, !tt.ptr<tensor<64x64xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #warp}>>>, !tt.ptr<tensor<64x64xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #warp, transpose = 1}>>>) : i32 {
// CHECK-COUNT-16: tt.load {{.*}} {DotIdx = 1 : i32} : !tt.ptr<tensor<16x16xf16>>
// CHECK-COUNT-32: tt.dot {{.*}} : tensor<8x16xf16> * tensor<16x16xf16> -> tensor<8x16xf32>
%29 = tt.load %arg11 : !tt.ptr<tensor<64x64xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #warp, transpose = 1}>>>
%30 = tt.dot %22, %29, %cst_0, inputPrecision = tf32 : tensor<16x64xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #warp}>> * tensor<64x64xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #warp, transpose = 1}>> -> tensor<16x64xf32, #warp>

// CHECK-COUNT-2: arith.maxnumf {{.*}} : tensor<8x16xf32>
// CHECK: [[MAX:%.*]] = arith.maxnumf {{.*}} : tensor<8x16xf32>
// CHECK-NEXT: [[EXTRACT0:%.*]] = triton_intel_gpu.extract [[MAX]][0] : tensor<8x16xf32> -> tensor<16xf32>
// CHECK-NEXT: "tt.reduce"([[EXTRACT0]]) <{axis = 0 : i32}> ({
// CHECK: }) : (tensor<16xf32>) -> f32
%31 = "tt.reduce"(%30) <{axis = 1 : i32}> ({
^bb0(%arg12: f32, %arg13: f32):
%53 = arith.maxnumf %arg12, %arg13 : f32
tt.reduce.return %53 : f32
}) : (tensor<16x64xf32, #warp>) -> tensor<16xf32, #triton_gpu.slice<{dim = 1, parent = #warp}>>
%32 = arith.mulf %31, %23 : tensor<16xf32, #triton_gpu.slice<{dim = 1, parent = #warp}>>
%33 = arith.maxnumf %arg9, %32 : tensor<16xf32, #triton_gpu.slice<{dim = 1, parent = #warp}>>
%34 = arith.mulf %30, %24 : tensor<16x64xf32, #warp>

// CHECK: tt.expand_dims {{.*}} {axis = 1 : i32} : tensor<16xf32
// CHECK: triton_intel_gpu.broadcast {{.*}} -> tensor<16x16xf32>
%35 = tt.expand_dims %33 {axis = 1 : i32} : tensor<16xf32, #triton_gpu.slice<{dim = 1, parent = #warp}>> -> tensor<16x1xf32, #warp>
%36 = tt.broadcast %35 : tensor<16x1xf32, #warp> -> tensor<16x64xf32, #warp>
%37 = arith.subf %34, %36 : tensor<16x64xf32, #warp>
%38 = math.exp2 %37 : tensor<16x64xf32, #warp>
%39 = "tt.reduce"(%38) <{axis = 1 : i32}> ({
^bb0(%arg12: f32, %arg13: f32):
%53 = arith.addf %arg12, %arg13 : f32
tt.reduce.return %53 : f32
}) : (tensor<16x64xf32, #warp>) -> tensor<16xf32, #triton_gpu.slice<{dim = 1, parent = #warp}>>
%40 = arith.subf %arg9, %33 : tensor<16xf32, #triton_gpu.slice<{dim = 1, parent = #warp}>>
%41 = math.exp2 %40 : tensor<16xf32, #triton_gpu.slice<{dim = 1, parent = #warp}>>
%42 = arith.mulf %arg7, %41 : tensor<16xf32, #triton_gpu.slice<{dim = 1, parent = #warp}>>
%43 = arith.addf %42, %39 : tensor<16xf32, #triton_gpu.slice<{dim = 1, parent = #warp}>>
%44 = tt.expand_dims %41 {axis = 1 : i32} : tensor<16xf32, #triton_gpu.slice<{dim = 1, parent = #warp}>> -> tensor<16x1xf32, #warp>
%45 = tt.broadcast %44 : tensor<16x1xf32, #warp> -> tensor<16x64xf32, #warp>
%46 = arith.mulf %arg8, %45 : tensor<16x64xf32, #warp>
%47 = tt.load %arg10 : !tt.ptr<tensor<64x64xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #warp}>>>
%48 = arith.truncf %38 : tensor<16x64xf32, #warp> to tensor<16x64xf16, #warp>
%49 = triton_gpu.convert_layout %48 : tensor<16x64xf16, #warp> -> tensor<16x64xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #warp}>>

// CHECK-COUNT-32: tt.dot {{.*}} : tensor<8x16xf16> * tensor<16x16xf16> -> tensor<8x16xf32>
// CHECK-COUNT-4: tt.advance {{.*}} : <tensor<32x32xf16>>
// CHECK-COUNT-16: tt.advance {{.*}} : <tensor<16x16xf16>>
// CHECK: scf.yield
%50 = tt.dot %49, %47, %46, inputPrecision = tf32 : tensor<16x64xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #warp}>> * tensor<64x64xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #warp}>> -> tensor<16x64xf32, #warp>
%51 = tt.advance %arg10, [%c64_i32, %c0_i32] : <tensor<64x64xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #warp}>>>
%52 = tt.advance %arg11, [%c0_i32, %c64_i32] : <tensor<64x64xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #warp, transpose = 1}>>>
scf.yield %43, %50, %33, %51, %52 : tensor<16xf32, #triton_gpu.slice<{dim = 1, parent = #warp}>>, tensor<16x64xf32, #warp>, tensor<16xf32, #triton_gpu.slice<{dim = 1, parent = #warp}>>, !tt.ptr<tensor<64x64xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #warp}>>>, !tt.ptr<tensor<64x64xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #warp, transpose = 1}>>>
} {triton_gpu.workload = 4 : i32, tt.divisibility_arg1 = dense<64> : tensor<1xi32>}
%26 = tt.expand_dims %25#0 {axis = 1 : i32} : tensor<16xf32, #triton_gpu.slice<{dim = 1, parent = #warp}>> -> tensor<16x1xf32, #warp>
%27 = tt.broadcast %26 : tensor<16x1xf32, #warp> -> tensor<16x64xf32, #warp>
%28 = arith.divf %25#1, %27 : tensor<16x64xf32, #warp>
tt.store %20, %28 : !tt.ptr<tensor<16x64xf32, #warp>>
tt.return
}
58 changes: 6 additions & 52 deletions test/TritonIntelGPU/tritonintelgpu-invalid.mlir
@@ -1,53 +1,23 @@
// RUN: triton-opt -split-input-file -verify-diagnostics %s

tt.func @triton_intel_gpu.glue(%tensor1 : tensor<16x8xf16>, %tensor2 : tensor<8xf16>) {
// expected-error @+1 {{'triton_intel_gpu.glue' op operands and result must have the same rank}}
triton_intel_gpu.glue %tensor1, %tensor2 : (tensor<16x8xf16>, tensor<8xf16>) -> tensor<24x8xf16>
tt.return
}

// -----

tt.func @triton_intel_gpu.glue(%ptr1 : !tt.ptr<tensor<16x8xf16>>, %ptr2 : !tt.ptr<tensor<16xf16>>) {
// expected-error @+1 {{'triton_intel_gpu.glue' op operands and result must have the same rank}}
triton_intel_gpu.glue %ptr1, %ptr2 : (!tt.ptr<tensor<16x8xf16>>, !tt.ptr<tensor<16xf16>>) -> !tt.ptr<tensor<16x24xf16>>
tt.return
}

// -----

tt.func @triton_intel_gpu.glue(%tensor1 : tensor<16x8xf16>, %tensor2 : tensor<16x8xf32>) {
// expected-error @+1 {{'triton_intel_gpu.glue' op operands and result element type must match}}
triton_intel_gpu.glue %tensor1, %tensor2 : (tensor<16x8xf16>, tensor<16x8xf32>) -> tensor<16x16xf16>
// COM: Ensure that tensors with different shapes cannot be glued.
tt.func @triton_intel_gpu.glue(%tensor1 : tensor<16xf16>, %tensor2 : tensor<16x8xf32>) {
// expected-error @+1 {{'triton_intel_gpu.glue' op operands must have the same type}}
triton_intel_gpu.glue %tensor1, %tensor2 : (tensor<16xf16>, tensor<16x8xf32>) -> tensor<16x16xf16>
tt.return
}

// -----

// COM: Ensure that tensors with different element types cannot be glued.
tt.func @triton_intel_gpu.glue(%ptr1 : !tt.ptr<tensor<16x8xf16>>, %ptr2 : !tt.ptr<tensor<16x8xf32>>) {
// expected-error @+1 {{'triton_intel_gpu.glue' op operands and result element type must match}}
// expected-error @+1 {{'triton_intel_gpu.glue' op operands must have the same type}}
triton_intel_gpu.glue %ptr1, %ptr2 : (!tt.ptr<tensor<16x8xf16>>, !tt.ptr<tensor<16x8xf32>>) -> !tt.ptr<tensor<16x16xf16>>
tt.return
}

// -----

tt.func @triton_intel_gpu.glue(%tensor1 : tensor<16x8xf16>, %tensor2 : tensor<8x8xf16>) {
// expected-error @+1 {{'triton_intel_gpu.glue' op operands must have the same shape}}
triton_intel_gpu.glue %tensor1, %tensor2 : (tensor<16x8xf16>, tensor<8x8xf16>) -> tensor<24x8xf16>
tt.return
}

// -----

tt.func @triton_intel_gpu.glue(%ptr1 : !tt.ptr<tensor<16x8xf16>>, %ptr2 : !tt.ptr<tensor<16x16xf16>>) {
// expected-error @+1 {{'triton_intel_gpu.glue' op operands must have the same shape}}
triton_intel_gpu.glue %ptr1, %ptr2 : (!tt.ptr<tensor<16x8xf16>>, !tt.ptr<tensor<16x16xf16>>) -> !tt.ptr<tensor<16x24xf16>>
tt.return
}

// -----

tt.func @triton_intel_gpu.glue(%tensor1 : tensor<16x8xf16>, %tensor2 : tensor<16x8xf16>) {
// expected-error @+1 {{'triton_intel_gpu.glue' op operands cannot exceed result size along any dimension}}
triton_intel_gpu.glue %tensor1, %tensor2 : (tensor<16x8xf16>, tensor<16x8xf16>) -> tensor<8x16xf16>
@@ -88,22 +58,6 @@ tt.func @triton_intel_gpu.glue(%tensor1 : tensor<16x8xf16>, %tensor2 : tensor<16

// -----

tt.func @triton_intel_gpu.extract(%tensor : tensor<16x16xf16>) {
// expected-error @+1 {{'triton_intel_gpu.extract' op operand and result must have the same rank}}
triton_intel_gpu.extract %tensor[0] : tensor<16x16xf16> -> tensor<4xf16>
tt.return
}

// -----

tt.func @triton_intel_gpu.extract(%ptr : !tt.ptr<tensor<16x16xf16>>) {
// expected-error @+1 {{'triton_intel_gpu.extract' op operand and result must have the same rank}}
triton_intel_gpu.extract %ptr[0] : !tt.ptr<tensor<16x16xf16>> -> !tt.ptr<tensor<4xf16>>
tt.return
}

// -----

tt.func @triton_intel_gpu.extract(%tensor : tensor<16x16xf16>) {
// expected-error @+1 {{'triton_intel_gpu.extract' op operand and result element type must match}}
triton_intel_gpu.extract %tensor[0] : tensor<16x16xf16> -> tensor<4x4xf32>
@@ -42,10 +42,8 @@ def TTIG_GlueOp : TTIG_Op<"glue", [Pure]> {
The `glue` operation glues its input operands to form the result tensor (ptr
to tensor) shape. Input operands are first concatenated along the first
(leftmost) dimension until the result shape along that dimension is reached,
then along the next dimension, and so on. The result tensor and all input
operands must have the same rank and element type. Furthermore all the input
tensors must have the same shape. Concatenation of the input operands must
yield the exact tensor shape of the result.
then along the next dimension, and so on. Input operands must have the same type.
Concatenation of the input operands must yield the exact tensor shape of the result.

Examples:
```mlir
@@ -57,9 +55,12 @@ def TTIG_GlueOp : TTIG_Op<"glue", [Pure]> {
%res3 = triton_intel_gpu.glue %p1, %p2, %p3
: (ptr<tensor<16x8xf16>>, ptr<tensor<16x8xf16>>, ptr<tensor<16x8xf16>>)
-> ptr<tensor<16x24xf16>>
%res4 = triton_intel_gpu.glue %f1, %f2 : (f16, f16) -> tensor<2xf16>
%res5 = triton_intel_gpu.glue %t1, %t2 : (tensor<8xf16>, tensor<8xf16>) -> tensor<2x8xf16>
%res6 = triton_intel_gpu.glue %t1, %t2 : (tensor<8xf16>, tensor<8xf16>) -> tensor<16xf16>
```
}];
let arguments = (ins Variadic<TT_TensorOrTensorPtr>:$operands);
let arguments = (ins Variadic<TT_Type>:$operands);
let results = (outs TT_TensorOrTensorPtr:$res);
let assemblyFormat = [{
operands attr-dict `:` functional-type(operands, results)
@@ -90,7 +91,7 @@ def TTIG_ExtractOp : TTIG_Op<"extract", [Pure]> {
column-major order
}];
let arguments = (ins TT_TensorOrTensorPtr:$base, I32Attr:$index);
let results = (outs TT_TensorOrTensorPtr:$res);
let results = (outs TT_Type:$res);
let assemblyFormat = [{
$base `[` $index `]` attr-dict `:` type($base) `->` type($res)
}];
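Since the result is relaxed from TT_TensorOrTensorPtr to TT_Type, a rank-changing extract such as the following presumably becomes legal; it is exactly the IR that the deleted invalid-IR test above used to reject (a hedged sketch, not a test from this PR):

```mlir
// Previously rejected with 'operand and result must have the same rank';
// presumably accepted now that the rank restriction has been lifted.
%res = triton_intel_gpu.extract %tensor[0] : tensor<16x16xf16> -> tensor<4xf16>
```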
@@ -116,4 +117,20 @@ def TTIG_PrefetchOp : TTIG_Op<"prefetch"> {
}];
}

// Same as tt.broadcast, except that SameOperandsAndResultEncoding is not required.
def TTIG_BroadcastOp : TTIG_Op<"broadcast", [Pure, SameOperandsAndResultElementType]> {
let summary = "broadcast a tensor";
let description = [{
For a given tensor, broadcast changes one or more dimensions with size 1
to a new size, e.g. tensor<1x32x1xf32> -> tensor<2x32x4xf32>. You cannot
change the size of a non-1 dimension.
}];

let arguments = (ins TT_Tensor:$src);
let results = (outs TT_Tensor:$result);
let assemblyFormat = [{
$src attr-dict `:` type($src) `->` type($result)
}];
}
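A hedged usage sketch for the new op, mirroring the semantics described above (shapes are illustrative; only size-1 dimensions may be expanded):

```mlir
// Broadcast a 1x16 tile to 8x16. Operand and result encodings may differ,
// which is why SameOperandsAndResultEncoding is intentionally not required.
%dst = triton_intel_gpu.broadcast %src : tensor<1x16xf32> -> tensor<8x16xf32>
```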

#endif