Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

add support for attention in match-target-size #1540

Merged
merged 18 commits into from
Jul 15, 2024
Merged
Show file tree
Hide file tree
Changes from 9 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
20 changes: 17 additions & 3 deletions include/triton/Dialect/TritonGPU/IR/TritonGPUAttrDefs.td
etiotto marked this conversation as resolved.
Show resolved Hide resolved
Original file line number Diff line number Diff line change
Expand Up @@ -1276,7 +1276,9 @@ elements along the K dim, or they use all elements of the tensor along the K dim
ins
"unsigned":$opIdx,
"Attribute":$parent,
DefaultValuedParameter<"unsigned", "0">:$kWidth
DefaultValuedParameter<"unsigned", "0">:$kWidth,
// intel specific
DefaultValuedParameter<"bool", "false">:$transpose
Dewei-Wang-sh marked this conversation as resolved.
Show resolved Hide resolved
);

let builders = [
Expand All @@ -1286,15 +1288,27 @@ elements along the K dim, or they use all elements of the tensor along the K dim
"Type":$eltTy), [{
NvidiaMmaEncodingAttr parentAttr = mlir::dyn_cast<NvidiaMmaEncodingAttr>(parent);
if (!parentAttr || !parentAttr.isAmpere())
return $_get(context, opIdx, parent, 0);
return $_get(context, opIdx, parent, 0, false);
unsigned bitwidth = eltTy.getIntOrFloatBitWidth();
unsigned MMAv2kWidth = 32 / bitwidth;
return $_get(context, opIdx, parent, MMAv2kWidth);
return $_get(context, opIdx, parent, MMAv2kWidth, false);
}]>,
AttrBuilder<(ins "unsigned":$opIdx,
"Attribute":$parent,
"unsigned":$kWidth), [{
return $_get(context, opIdx, parent, kWidth, false);
}]>,
AttrBuilder<(ins "unsigned":$opIdx,
"Attribute":$parent,
"unsigned":$kWidth,
"bool":$transpose), [{
return $_get(context, opIdx, parent, kWidth, transpose);
Copy link
Contributor Author

@Dewei-Wang-sh Dewei-Wang-sh Jul 9, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

add the builder above so that existing upstream code in other .h/.cpp files does not need to change.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

created a new issue to track the upcoming refactor.
#1621

}]>
];

let assemblyFormat = "`<` `{` struct(params) `}` `>`";
let genVerifyDecl = 1;
let skipDefaultBuilders = 1;
let extraClassDeclaration = extraDistributedDeclaration # [{
SmallVector<unsigned> getContigPerThread() {
return getSizePerThread();
Expand Down
2 changes: 1 addition & 1 deletion lib/Dialect/TritonGPU/IR/Dialect.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1018,7 +1018,7 @@ SmallVector<unsigned> DotOperandEncodingAttr::getShapePerCTATile(

LogicalResult DotOperandEncodingAttr::verify(
::llvm::function_ref<::mlir::InFlightDiagnostic()> emitError,
unsigned opIdx, Attribute parent, unsigned kWidth) {
unsigned opIdx, Attribute parent, unsigned kWidth, bool transpose) {
etiotto marked this conversation as resolved.
Show resolved Hide resolved
if (opIdx != 0 && opIdx != 1) {
return emitError()
<< "triton_gpu.dot_op opIdx parameter can be 0 or 1, got: "
Expand Down
130 changes: 102 additions & 28 deletions test/TritonIntelGPU/match-target-size.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -116,34 +116,6 @@ tt.func public @simplify_scf_for(%arg0: tensor<16x8xf16>, %arg1: tensor<16x8xf16

// -----

// COM: Test SCF canonicalization: ensure loop is not modified when the result is not used by just 'extract' operations.
etiotto marked this conversation as resolved.
Show resolved Hide resolved
tt.func public @simplify_scf_for(%arg0: tensor<16x8xf16>, %arg1: tensor<16x8xf16>, %arg2: !tt.ptr<f16, 1>,
%arg3: i64, %arg4: i64, %arg5: i64, %arg6: i32, %arg7: i32) {
// CHECK-LABEL: @simplify_scf_for
// CHECK: [[GLUE:%.*]] = triton_intel_gpu.glue
// CHECK-NEXT: [[RES:%.*]] = scf.for {{.*}} iter_args([[INIT1:%.*]] = [[GLUE]]) -> (tensor<16x16xf16>) : i32 {
// CHECK: scf.yield {{.*}} : tensor<16x16xf16>
// CHECK-NEXT: }
// CHECK-NEXT: [[PTR:%.*]] = tt.make_tensor_ptr %arg2
// CHECK-NEXT: tt.store [[PTR]], [[RES]]
%lb = arith.constant 0 : i32
%ub = arith.constant 32 : i32
%st = arith.constant 1 : i32
%c1_i64 = arith.constant 1 : i64
%glue = triton_intel_gpu.glue %arg0, %arg1 : (tensor<16x8xf16>, tensor<16x8xf16>) -> tensor<16x16xf16>
%res = scf.for %iv = %lb to %ub step %st iter_args(%arg = %glue) -> (tensor<16x16xf16>) : i32 {
%e1 = triton_intel_gpu.extract %arg[0] : tensor<16x16xf16> -> tensor<16x8xf16>
%e2 = triton_intel_gpu.extract %arg[1] : tensor<16x16xf16> -> tensor<16x8xf16>
%g1 = triton_intel_gpu.glue %e1, %e2 : (tensor<16x8xf16>, tensor<16x8xf16>) -> tensor<16x16xf16>
scf.yield %g1 : tensor<16x16xf16>
}
%ptr = tt.make_tensor_ptr %arg2, [%arg3, %arg4], [%arg5, %c1_i64], [%arg6, %arg7] {order = array<i32: 1, 0>} : <tensor<16x16xf16>>
tt.store %ptr, %res {boundaryCheck = array<i32: 0, 1>, cache = 1 : i32, evict = 1 : i32} : !tt.ptr<tensor<16x16xf16>>
tt.return
}

// -----

// COM: Test SCF canonicalization: ensure loop is not modified if any user of a 'glue' init value is not an 'extract' operation.
tt.func public @simplify_scf_for(%arg0: tensor<16x8xf16>, %arg1: tensor<16x8xf16>, %arg2: !tt.ptr<f16, 1>,
%arg3: i64, %arg4: i64, %arg5: i64, %arg6: i32, %arg7: i32) {
Expand Down Expand Up @@ -279,3 +251,105 @@ tt.func public @matmul_kernel_with_block_pointers_tf32(%arg0: !tt.ptr<f32> {tt.d
tt.store %tptr_c, %35#0 {boundaryCheck = array<i32: 0, 1>} : !tt.ptr<tensor<8x32xf32, #warp>>
tt.return
}

// -----

// COM: Test Attention Related Ops

#warp = #triton_intel_gpu.warp<{sizePerThread = [16, 64], threadsPerWarp = [1, 1], order = [1, 0]}>
module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 8 : i32, "triton_gpu.threads-per-warp" = 1 : i32} {
// CHECK-LABEL: @_attn_fwd
tt.func public @_attn_fwd(%arg0: !tt.ptr<f16> {tt.divisibility = 16 : i32}, %arg1: !tt.ptr<f16> {tt.divisibility = 16 : i32}, %arg2: !tt.ptr<f16> {tt.divisibility = 16 : i32}, %arg3: f32, %arg4: !tt.ptr<f32> {tt.divisibility = 16 : i32}, %arg5: !tt.ptr<f32> {tt.divisibility = 16 : i32}) attributes {noinline = false} {
%c16_i32 = arith.constant 16 : i32
%c128_i32 = arith.constant 128 : i32
%c1024_i64 = arith.constant 1024 : i64
%c64_i64 = arith.constant 64 : i64
%c1_i64 = arith.constant 1 : i64
%c0_i32 = arith.constant 0 : i32
%cst = arith.constant 1.44269502 : f32
%c3145728_i64 = arith.constant 3145728 : i64
%c65536_i64 = arith.constant 65536 : i64
%cst_0 = arith.constant dense<0.000000e+00> : tensor<16x64xf32, #warp>
%c64_i32 = arith.constant 64 : i32
%c1024_i32 = arith.constant 1024 : i32
%cst_1 = arith.constant dense<0xFF800000> : tensor<16xf32, #triton_gpu.slice<{dim = 1, parent = #warp}>>
%cst_2 = arith.constant dense<1.000000e+00> : tensor<16xf32, #triton_gpu.slice<{dim = 1, parent = #warp}>>
%0 = gpu.subgroup_id : index
%1 = arith.index_cast %0 : index to i32
%2 = tt.get_program_id z : i32
%3 = tt.get_program_id x : i32
%4 = tt.get_program_id y : i32
%5 = arith.extsi %3 : i32 to i64
%6 = arith.muli %5, %c3145728_i64 : i64
%7 = arith.extsi %4 : i32 to i64
%8 = arith.muli %7, %c65536_i64 : i64
%9 = arith.addi %6, %8 : i64
%10 = tt.addptr %arg0, %9 : !tt.ptr<f16>, i64
%11 = arith.muli %2, %c128_i32 : i32
%12 = arith.muli %1, %c16_i32 : i32
%13 = arith.addi %12, %11 : i32
%14 = tt.make_tensor_ptr %10, [%c1024_i64, %c64_i64], [%c64_i64, %c1_i64], [%13, %c0_i32] {order = array<i32: 1, 0>} : <tensor<16x64xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #warp}>>>
%15 = tt.addptr %arg2, %9 : !tt.ptr<f16>, i64
%16 = tt.make_tensor_ptr %15, [%c1024_i64, %c64_i64], [%c64_i64, %c1_i64], [%c0_i32, %c0_i32] {order = array<i32: 1, 0>} : <tensor<64x64xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #warp}>>>
%17 = tt.addptr %arg1, %9 : !tt.ptr<f16>, i64
%18 = tt.make_tensor_ptr %17, [%c64_i64, %c1024_i64], [%c1_i64, %c64_i64], [%c0_i32, %c0_i32] {order = array<i32: 0, 1>} : <tensor<64x64xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #warp, transpose = 1}>>>
%19 = tt.addptr %arg5, %9 : !tt.ptr<f32>, i64
%20 = tt.make_tensor_ptr %19, [%c1024_i64, %c64_i64], [%c64_i64, %c1_i64], [%13, %c0_i32] {order = array<i32: 1, 0>} : <tensor<16x64xf32, #warp>>
%21 = arith.mulf %arg3, %cst : f32
%22 = tt.load %14 : !tt.ptr<tensor<16x64xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #warp}>>>
// CHECK: tt.splat {{.*}} : f32 -> tensor<16xf32
// CHECK-COUNT-8: tt.splat {{.*}} : f32 -> tensor<8x16xf32>
%23 = tt.splat %21 : f32 -> tensor<16xf32, #triton_gpu.slice<{dim = 1, parent = #warp}>>
%24 = tt.splat %21 : f32 -> tensor<16x64xf32, #warp>
// CHECK: 30 = scf.for
%25:5 = scf.for %arg6 = %c0_i32 to %c1024_i32 step %c64_i32 iter_args(%arg7 = %cst_2, %arg8 = %cst_0, %arg9 = %cst_1, %arg10 = %16, %arg11 = %18) -> (tensor<16xf32, #triton_gpu.slice<{dim = 1, parent = #warp}>>, tensor<16x64xf32, #warp>, tensor<16xf32, #triton_gpu.slice<{dim = 1, parent = #warp}>>, !tt.ptr<tensor<64x64xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #warp}>>>, !tt.ptr<tensor<64x64xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #warp, transpose = 1}>>>) : i32 {
// CHECK-COUNT-16: tt.load {{.*}} {DotIdx = 1 : i32} : !tt.ptr<tensor<16x16xf16>>
%29 = tt.load %arg11 : !tt.ptr<tensor<64x64xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #warp, transpose = 1}>>>
%30 = tt.dot %22, %29, %cst_0, inputPrecision = tf32 : tensor<16x64xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #warp}>> * tensor<64x64xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #warp, transpose = 1}>> -> tensor<16x64xf32, #warp>
etiotto marked this conversation as resolved.
Show resolved Hide resolved
// CHECK: arith.maxnumf {{.*}} : tensor<8x16xf32>
// CHECK: arith.maxnumf {{.*}} : tensor<8x16xf32>
// CHECK: [[MAX:%.*]] = arith.maxnumf {{.*}} : tensor<8x16xf32>
// CHECK-NEXT: [[EXTRACT0:%.*]] = triton_intel_gpu.extract [[MAX]][0] : tensor<8x16xf32> -> tensor<16xf32>
// CHECK-NEXT: = "tt.reduce"([[EXTRACT0]]) <{axis = 0 : i32}> ({
// CHECK: (tensor<16xf32>) -> f32
%31 = "tt.reduce"(%30) <{axis = 1 : i32}> ({
^bb0(%arg12: f32, %arg13: f32):
%53 = arith.maxnumf %arg12, %arg13 : f32
tt.reduce.return %53 : f32
}) : (tensor<16x64xf32, #warp>) -> tensor<16xf32, #triton_gpu.slice<{dim = 1, parent = #warp}>>
%32 = arith.mulf %31, %23 : tensor<16xf32, #triton_gpu.slice<{dim = 1, parent = #warp}>>
%33 = arith.maxnumf %arg9, %32 : tensor<16xf32, #triton_gpu.slice<{dim = 1, parent = #warp}>>
%34 = arith.mulf %30, %24 : tensor<16x64xf32, #warp>
// CHECK: tt.expand_dims {{.*}} {axis = 1 : i32} : tensor<16xf32
// CHECK: triton_intel_gpu.broadcast {{.*}} -> tensor<16x16xf32>
%35 = tt.expand_dims %33 {axis = 1 : i32} : tensor<16xf32, #triton_gpu.slice<{dim = 1, parent = #warp}>> -> tensor<16x1xf32, #warp>
%36 = tt.broadcast %35 : tensor<16x1xf32, #warp> -> tensor<16x64xf32, #warp>
%37 = arith.subf %34, %36 : tensor<16x64xf32, #warp>
%38 = math.exp2 %37 : tensor<16x64xf32, #warp>
%39 = "tt.reduce"(%38) <{axis = 1 : i32}> ({
^bb0(%arg12: f32, %arg13: f32):
%53 = arith.addf %arg12, %arg13 : f32
tt.reduce.return %53 : f32
}) : (tensor<16x64xf32, #warp>) -> tensor<16xf32, #triton_gpu.slice<{dim = 1, parent = #warp}>>
%40 = arith.subf %arg9, %33 : tensor<16xf32, #triton_gpu.slice<{dim = 1, parent = #warp}>>
%41 = math.exp2 %40 : tensor<16xf32, #triton_gpu.slice<{dim = 1, parent = #warp}>>
%42 = arith.mulf %arg7, %41 : tensor<16xf32, #triton_gpu.slice<{dim = 1, parent = #warp}>>
%43 = arith.addf %42, %39 : tensor<16xf32, #triton_gpu.slice<{dim = 1, parent = #warp}>>
%44 = tt.expand_dims %41 {axis = 1 : i32} : tensor<16xf32, #triton_gpu.slice<{dim = 1, parent = #warp}>> -> tensor<16x1xf32, #warp>
%45 = tt.broadcast %44 : tensor<16x1xf32, #warp> -> tensor<16x64xf32, #warp>
%46 = arith.mulf %arg8, %45 : tensor<16x64xf32, #warp>
%47 = tt.load %arg10 : !tt.ptr<tensor<64x64xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #warp}>>>
%48 = arith.truncf %38 : tensor<16x64xf32, #warp> to tensor<16x64xf16, #warp>
%49 = triton_gpu.convert_layout %48 : tensor<16x64xf16, #warp> -> tensor<16x64xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #warp}>>
%50 = tt.dot %49, %47, %46, inputPrecision = tf32 : tensor<16x64xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #warp}>> * tensor<64x64xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #warp}>> -> tensor<16x64xf32, #warp>
etiotto marked this conversation as resolved.
Show resolved Hide resolved
%51 = tt.advance %arg10, [%c64_i32, %c0_i32] : <tensor<64x64xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #warp}>>>
etiotto marked this conversation as resolved.
Show resolved Hide resolved
%52 = tt.advance %arg11, [%c0_i32, %c64_i32] : <tensor<64x64xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #warp, transpose = 1}>>>
scf.yield %43, %50, %33, %51, %52 : tensor<16xf32, #triton_gpu.slice<{dim = 1, parent = #warp}>>, tensor<16x64xf32, #warp>, tensor<16xf32, #triton_gpu.slice<{dim = 1, parent = #warp}>>, !tt.ptr<tensor<64x64xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #warp}>>>, !tt.ptr<tensor<64x64xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #warp, transpose = 1}>>>
} {triton_gpu.workload = 4 : i32, tt.divisibility_arg1 = dense<64> : tensor<1xi32>}
%26 = tt.expand_dims %25#0 {axis = 1 : i32} : tensor<16xf32, #triton_gpu.slice<{dim = 1, parent = #warp}>> -> tensor<16x1xf32, #warp>
%27 = tt.broadcast %26 : tensor<16x1xf32, #warp> -> tensor<16x64xf32, #warp>
%28 = arith.divf %25#1, %27 : tensor<16x64xf32, #warp>
tt.store %20, %28 : !tt.ptr<tensor<16x64xf32, #warp>>
tt.return
}
}
52 changes: 2 additions & 50 deletions test/TritonIntelGPU/tritonintelgpu-invalid.mlir
Original file line number Diff line number Diff line change
@@ -1,53 +1,21 @@
// RUN: triton-opt -split-input-file -verify-diagnostics %s

tt.func @triton_intel_gpu.glue(%tensor1 : tensor<16x8xf16>, %tensor2 : tensor<8xf16>) {
// expected-error @+1 {{'triton_intel_gpu.glue' op operands and result must have the same rank}}
triton_intel_gpu.glue %tensor1, %tensor2 : (tensor<16x8xf16>, tensor<8xf16>) -> tensor<24x8xf16>
tt.return
}

// -----

tt.func @triton_intel_gpu.glue(%ptr1 : !tt.ptr<tensor<16x8xf16>>, %ptr2 : !tt.ptr<tensor<16xf16>>) {
// expected-error @+1 {{'triton_intel_gpu.glue' op operands and result must have the same rank}}
triton_intel_gpu.glue %ptr1, %ptr2 : (!tt.ptr<tensor<16x8xf16>>, !tt.ptr<tensor<16xf16>>) -> !tt.ptr<tensor<16x24xf16>>
tt.return
}

// -----

tt.func @triton_intel_gpu.glue(%tensor1 : tensor<16x8xf16>, %tensor2 : tensor<16x8xf32>) {
// expected-error @+1 {{'triton_intel_gpu.glue' op operands and result element type must match}}
// expected-error @+1 {{'triton_intel_gpu.glue' op operands must have the same type}}
triton_intel_gpu.glue %tensor1, %tensor2 : (tensor<16x8xf16>, tensor<16x8xf32>) -> tensor<16x16xf16>
tt.return
}

// -----

tt.func @triton_intel_gpu.glue(%ptr1 : !tt.ptr<tensor<16x8xf16>>, %ptr2 : !tt.ptr<tensor<16x8xf32>>) {
// expected-error @+1 {{'triton_intel_gpu.glue' op operands and result element type must match}}
// expected-error @+1 {{'triton_intel_gpu.glue' op operands must have the same type}}
etiotto marked this conversation as resolved.
Show resolved Hide resolved
triton_intel_gpu.glue %ptr1, %ptr2 : (!tt.ptr<tensor<16x8xf16>>, !tt.ptr<tensor<16x8xf32>>) -> !tt.ptr<tensor<16x16xf16>>
tt.return
}

// -----

tt.func @triton_intel_gpu.glue(%tensor1 : tensor<16x8xf16>, %tensor2 : tensor<8x8xf16>) {
// expected-error @+1 {{'triton_intel_gpu.glue' op operands must have the same shape}}
triton_intel_gpu.glue %tensor1, %tensor2 : (tensor<16x8xf16>, tensor<8x8xf16>) -> tensor<24x8xf16>
tt.return
}

// -----

tt.func @triton_intel_gpu.glue(%ptr1 : !tt.ptr<tensor<16x8xf16>>, %ptr2 : !tt.ptr<tensor<16x16xf16>>) {
// expected-error @+1 {{'triton_intel_gpu.glue' op operands must have the same shape}}
triton_intel_gpu.glue %ptr1, %ptr2 : (!tt.ptr<tensor<16x8xf16>>, !tt.ptr<tensor<16x16xf16>>) -> !tt.ptr<tensor<16x24xf16>>
tt.return
}

// -----

tt.func @triton_intel_gpu.glue(%tensor1 : tensor<16x8xf16>, %tensor2 : tensor<16x8xf16>) {
// expected-error @+1 {{'triton_intel_gpu.glue' op operands cannot exceed result size along any dimension}}
triton_intel_gpu.glue %tensor1, %tensor2 : (tensor<16x8xf16>, tensor<16x8xf16>) -> tensor<8x16xf16>
Expand Down Expand Up @@ -88,22 +56,6 @@ tt.func @triton_intel_gpu.glue(%tensor1 : tensor<16x8xf16>, %tensor2 : tensor<16

// -----

tt.func @triton_intel_gpu.extract(%tensor : tensor<16x16xf16>) {
// expected-error @+1 {{'triton_intel_gpu.extract' op operand and result must have the same rank}}
triton_intel_gpu.extract %tensor[0] : tensor<16x16xf16> -> tensor<4xf16>
Dewei-Wang-sh marked this conversation as resolved.
Show resolved Hide resolved
tt.return
}

// -----

tt.func @triton_intel_gpu.extract(%ptr : !tt.ptr<tensor<16x16xf16>>) {
// expected-error @+1 {{'triton_intel_gpu.extract' op operand and result must have the same rank}}
triton_intel_gpu.extract %ptr[0] : !tt.ptr<tensor<16x16xf16>> -> !tt.ptr<tensor<4xf16>>
tt.return
}

// -----

tt.func @triton_intel_gpu.extract(%tensor : tensor<16x16xf16>) {
// expected-error @+1 {{'triton_intel_gpu.extract' op operand and result element type must match}}
triton_intel_gpu.extract %tensor[0] : tensor<16x16xf16> -> tensor<4x4xf32>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,7 @@ def TTIG_GlueOp : TTIG_Op<"glue", [Pure]> {
-> ptr<tensor<16x24xf16>>
```
}];
let arguments = (ins Variadic<TT_TensorOrTensorPtr>:$operands);
etiotto marked this conversation as resolved.
Show resolved Hide resolved
let arguments = (ins Variadic<TT_Type>:$operands);
let results = (outs TT_TensorOrTensorPtr:$res);
let assemblyFormat = [{
operands attr-dict `:` functional-type(operands, results)
Expand Down Expand Up @@ -90,7 +90,7 @@ def TTIG_ExtractOp : TTIG_Op<"extract", [Pure]> {
column-major order
}];
let arguments = (ins TT_TensorOrTensorPtr:$base, I32Attr:$index);
let results = (outs TT_TensorOrTensorPtr:$res);
let results = (outs TT_Type:$res);
let assemblyFormat = [{
$base `[` $index `]` attr-dict `:` type($base) `->` type($res)
}];
Expand All @@ -116,4 +116,22 @@ def TTIG_PrefetchOp : TTIG_Op<"prefetch"> {
}];
}

// same as tt.broadcast except that we don't require SameOperandsAndResultEncoding
def TTIG_BroadcastOp : TTIG_Op<"broadcast", [Pure,
etiotto marked this conversation as resolved.
Show resolved Hide resolved
SameOperandsAndResultElementType]> {
let summary = "broadcast a tensor";

let description = [{
For a given tensor, broadcast changes one or more dimensions with size 1
to a new size, e.g. tensor<1x32x1xf32> -> tensor<2x32x4xf32>. You cannot
change the size of a non-1 dimension.
}];

let arguments = (ins TT_Tensor:$src);

let results = (outs TT_Tensor:$result);

let assemblyFormat = "$src attr-dict `:` type($src) `->` type($result)";
etiotto marked this conversation as resolved.
Show resolved Hide resolved
}

#endif
Loading