Commit

clean up code, fix rebase, add lit test
Dewei-Wang-sh committed Jul 4, 2024
1 parent 1801d6b commit 4f8d1bd
Showing 4 changed files with 155 additions and 158 deletions.
50 changes: 5 additions & 45 deletions lib/Dialect/TritonGPU/IR/Dialect.cpp
@@ -1062,18 +1062,15 @@ LogicalResult DotOperandEncodingAttr::verify(
}

if (auto parentAttr = mlir::dyn_cast<intel::WarpEncodingAttr>(parent)) {
if (kWidth != 0)
return emitError()
<< "triton_gpu.dot_op kWidth parameter is not supported "
"when the parent is a warp layout";
return success();
}

if (auto parentAttr = mlir::dyn_cast<BlockedEncodingAttr>(parent)) {
if (kWidth != 0)
return emitError()
<< "triton_gpu.dot_op kWidth parameter is not supported "
"when the parent is a blocked layout";
// intel: parent can be blocked layout
// if (kWidth != 0)
// return emitError()
// << "triton_gpu.dot_op kWidth parameter is not supported "
// "when the parent is a blocked layout";
return success();
}
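
With the kWidth restriction dropped for warp and blocked parents, dot-operand attributes like the ones exercised by the new lit test now pass verification. A minimal illustrative sketch (the #warp alias mirrors the definition in match-target-size.mlir below):

#warp = #triton_intel_gpu.warp<{sizePerThread = [16, 64], threadsPerWarp = [1, 1], order = [1, 0]}>
#dot_a = #triton_gpu.dot_op<{opIdx = 0, parent = #warp, kWidth = 1}>
#dot_b = #triton_gpu.dot_op<{opIdx = 1, parent = #warp, kWidth = 1}>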

@@ -2120,43 +2117,6 @@ SmallVector<unsigned> DotOperandEncodingAttr::getSizePerThread() const {
}
}

Attribute DotOperandEncodingAttr::parse(AsmParser &parser, Type type) {
if (parser.parseLess().failed())
return {};
NamedAttrList attrs;
if (parser.parseOptionalAttrDict(attrs).failed())
return {};
if (parser.parseGreater().failed())
return {};
unsigned opIdx = mlir::cast<IntegerAttr>(attrs.get("opIdx")).getInt();
Attribute parent = attrs.get("parent");
auto mmaParent = mlir::dyn_cast<NvidiaMmaEncodingAttr>(parent);
unsigned kWidth = 0;
Attribute _kWidth = attrs.get("kWidth");
if (_kWidth) {
// if (!mmaParent || mmaParent.isVolta()) {
// auto loc = parser.getNameLoc();
// parser.emitError(loc, "kWidth only supported for MMAv2+ parent");
// return Attribute();
// }
kWidth = mlir::cast<IntegerAttr>(_kWidth).getInt();
}
if (mlir::isa<AMDWmmaEncodingAttr>(parent)) {
kWidth = AMDWmmaEncodingAttr::getMNKDimPerWMMAInstr()[2];
}
return parser.getChecked<DotOperandEncodingAttr>(parser.getContext(), opIdx,
parent, kWidth);
}

void DotOperandEncodingAttr::print(mlir::AsmPrinter &printer) const {
auto mmaParent = mlir::dyn_cast<NvidiaMmaEncodingAttr>(getParent());
printer << "<{"
<< "opIdx = " << getOpIdx() << ", parent = " << getParent();
if (mmaParent && mmaParent.isAmpere())
printer << ", kWidth = " << getKWidth();
printer << "}>";
}
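
For context, the removed hand-written printer emitted kWidth only when the parent was an Ampere MMA encoding. The Intel lit test below relies on kWidth also round-tripping with a warp-layout parent, e.g. (illustrative, mirroring match-target-size.mlir):

#triton_gpu.dot_op<{opIdx = 0, parent = #warp, kWidth = 1}>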

//===----------------------------------------------------------------------===//
// ASM Interface (i.e.: alias)
//===----------------------------------------------------------------------===//
155 changes: 101 additions & 54 deletions test/TritonIntelGPU/match-target-size.mlir
@@ -116,60 +116,6 @@ tt.func public @simplify_scf_for(%arg0: tensor<16x8xf16>, %arg1: tensor<16x8xf16

// -----

// COM: Test SCF canonicalization: ensure loop is not modified when the result is not used by just 'extract' operations.
tt.func public @simplify_scf_for(%arg0: tensor<16x8xf16>, %arg1: tensor<16x8xf16>, %arg2: !tt.ptr<f16, 1>,
%arg3: i64, %arg4: i64, %arg5: i64, %arg6: i32, %arg7: i32) {
// CHECK-LABEL: @simplify_scf_for
// CHECK: [[GLUE:%.*]] = triton_intel_gpu.glue
// CHECK-NEXT: [[RES:%.*]] = scf.for {{.*}} iter_args([[INIT1:%.*]] = [[GLUE]]) -> (tensor<16x16xf16>) : i32 {
// CHECK: scf.yield {{.*}} : tensor<16x16xf16>
// CHECK-NEXT: }
// CHECK-NEXT: [[PTR:%.*]] = tt.make_tensor_ptr %arg2
// CHECK-NEXT: tt.store [[PTR]], [[RES]]
%lb = arith.constant 0 : i32
%ub = arith.constant 32 : i32
%st = arith.constant 1 : i32
%c1_i64 = arith.constant 1 : i64
%glue = triton_intel_gpu.glue %arg0, %arg1 : (tensor<16x8xf16>, tensor<16x8xf16>) -> tensor<16x16xf16>
%res = scf.for %iv = %lb to %ub step %st iter_args(%arg = %glue) -> (tensor<16x16xf16>) : i32 {
%e1 = triton_intel_gpu.extract %arg[0] : tensor<16x16xf16> -> tensor<16x8xf16>
%e2 = triton_intel_gpu.extract %arg[1] : tensor<16x16xf16> -> tensor<16x8xf16>
%g1 = triton_intel_gpu.glue %e1, %e2 : (tensor<16x8xf16>, tensor<16x8xf16>) -> tensor<16x16xf16>
scf.yield %g1 : tensor<16x16xf16>
}
%ptr = tt.make_tensor_ptr %arg2, [%arg3, %arg4], [%arg5, %c1_i64], [%arg6, %arg7] {order = array<i32: 1, 0>} : <tensor<16x16xf16>>
tt.store %ptr, %res {boundaryCheck = array<i32: 0, 1>, cache = 1 : i32, evict = 1 : i32} : !tt.ptr<tensor<16x16xf16>>
tt.return
}

// -----

// COM: Test SCF canonicalization: ensure loop is not modified if any user of a 'glue' init value is not an 'extract' operation.
tt.func public @simplify_scf_for(%arg0: tensor<16x8xf16>, %arg1: tensor<16x8xf16>, %arg2: !tt.ptr<f16, 1>,
%arg3: i64, %arg4: i64, %arg5: i64, %arg6: i32, %arg7: i32) {
// CHECK-LABEL: @simplify_scf_for
// CHECK: [[GLUE:%.*]] = triton_intel_gpu.glue
// CHECK-NEXT: [[RES:%.*]] = scf.for {{.*}} iter_args([[INIT1:%.*]] = [[GLUE]]) -> (tensor<16x16xf16>) : i32 {
// CHECK: scf.yield {{.*}} : tensor<16x16xf16>
// CHECK-NEXT: }
%lb = arith.constant 0 : i32
%ub = arith.constant 32 : i32
%st = arith.constant 1 : i32
%c1_i64 = arith.constant 1 : i64
%glue = triton_intel_gpu.glue %arg0, %arg1 : (tensor<16x8xf16>, tensor<16x8xf16>) -> tensor<16x16xf16>
%res = scf.for %iv = %lb to %ub step %st iter_args(%arg = %glue) -> (tensor<16x16xf16>) : i32 {
%e1 = triton_intel_gpu.extract %arg[0] : tensor<16x16xf16> -> tensor<16x8xf16>
%e2 = triton_intel_gpu.extract %arg[1] : tensor<16x16xf16> -> tensor<16x8xf16>
%g1 = triton_intel_gpu.glue %e1, %e2 : (tensor<16x8xf16>, tensor<16x8xf16>) -> tensor<16x16xf16>
%ptr = tt.make_tensor_ptr %arg2, [%arg3, %arg4], [%arg5, %c1_i64], [%arg6, %arg7] {order = array<i32: 1, 0>} : <tensor<16x16xf16>>
tt.store %ptr, %arg {boundaryCheck = array<i32: 0, 1>, cache = 1 : i32, evict = 1 : i32} : !tt.ptr<tensor<16x16xf16>>
scf.yield %g1 : tensor<16x16xf16>
}
tt.return
}

// -----

// COM: Test transformation for int8 datatype

// CHECK-LABEL: @matmul_kernel_with_block_pointers_int8
@@ -279,3 +225,104 @@ tt.func public @matmul_kernel_with_block_pointers_tf32(%arg0: !tt.ptr<f32> {tt.d
tt.store %tptr_c, %35#0 {boundaryCheck = array<i32: 0, 1>} : !tt.ptr<tensor<8x32xf32, #warp>>
tt.return
}

// -----

// COM: Test Attention Related Ops

#warp = #triton_intel_gpu.warp<{sizePerThread = [16, 64], threadsPerWarp = [1, 1], order = [1, 0]}>
module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 8 : i32, "triton_gpu.threads-per-warp" = 1 : i32} {
// CHECK-LABEL: @_attn_fwd
tt.func public @_attn_fwd(%arg0: !tt.ptr<f16> {tt.divisibility = 16 : i32}, %arg1: !tt.ptr<f16> {tt.divisibility = 16 : i32}, %arg2: !tt.ptr<f16> {tt.divisibility = 16 : i32}, %arg3: f32, %arg4: !tt.ptr<f32> {tt.divisibility = 16 : i32}, %arg5: !tt.ptr<f32> {tt.divisibility = 16 : i32}) attributes {noinline = false} {
%c16_i32 = arith.constant 16 : i32
%c128_i32 = arith.constant 128 : i32
%c1024_i64 = arith.constant 1024 : i64
%c64_i64 = arith.constant 64 : i64
%c1_i64 = arith.constant 1 : i64
%c0_i32 = arith.constant 0 : i32
%cst = arith.constant 1.44269502 : f32
%c3145728_i64 = arith.constant 3145728 : i64
%c65536_i64 = arith.constant 65536 : i64
%cst_0 = arith.constant dense<0.000000e+00> : tensor<16x64xf32, #warp>
%c64_i32 = arith.constant 64 : i32
%c1024_i32 = arith.constant 1024 : i32
%cst_1 = arith.constant dense<0xFF800000> : tensor<16xf32, #triton_gpu.slice<{dim = 1, parent = #warp}>>
%cst_2 = arith.constant dense<1.000000e+00> : tensor<16xf32, #triton_gpu.slice<{dim = 1, parent = #warp}>>
%0 = gpu.subgroup_id : index
%1 = arith.index_cast %0 : index to i32
%2 = tt.get_program_id z : i32
%3 = tt.get_program_id x : i32
%4 = tt.get_program_id y : i32
%5 = arith.extsi %3 : i32 to i64
%6 = arith.muli %5, %c3145728_i64 : i64
%7 = arith.extsi %4 : i32 to i64
%8 = arith.muli %7, %c65536_i64 : i64
%9 = arith.addi %6, %8 : i64
%10 = tt.addptr %arg0, %9 : !tt.ptr<f16>, i64
%11 = arith.muli %2, %c128_i32 : i32
%12 = arith.muli %1, %c16_i32 : i32
%13 = arith.addi %12, %11 : i32
%14 = tt.make_tensor_ptr %10, [%c1024_i64, %c64_i64], [%c64_i64, %c1_i64], [%13, %c0_i32] {order = array<i32: 1, 0>} : <tensor<16x64xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #warp, kWidth=1}>>>
%15 = tt.addptr %arg2, %9 : !tt.ptr<f16>, i64
%16 = tt.make_tensor_ptr %15, [%c1024_i64, %c64_i64], [%c64_i64, %c1_i64], [%c0_i32, %c0_i32] {order = array<i32: 1, 0>} : <tensor<64x64xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #warp}>>>
%17 = tt.addptr %arg1, %9 : !tt.ptr<f16>, i64
%18 = tt.make_tensor_ptr %17, [%c64_i64, %c1024_i64], [%c1_i64, %c64_i64], [%c0_i32, %c0_i32] {order = array<i32: 0, 1>} : <tensor<64x64xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #warp, kWidth=1}>>>
%19 = tt.addptr %arg5, %9 : !tt.ptr<f32>, i64
%20 = tt.make_tensor_ptr %19, [%c1024_i64, %c64_i64], [%c64_i64, %c1_i64], [%13, %c0_i32] {order = array<i32: 1, 0>} : <tensor<16x64xf32, #warp>>
%21 = arith.mulf %arg3, %cst : f32
%22 = tt.load %14 : !tt.ptr<tensor<16x64xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #warp, kWidth=1}>>>
// CHECK: tt.splat {{.*}} : f32 -> tensor<16xf32
// CHECK-COUNT-8: tt.splat {{.*}} : f32 -> tensor<8x16xf32>
%23 = tt.splat %21 : f32 -> tensor<16xf32, #triton_gpu.slice<{dim = 1, parent = #warp}>>
%24 = tt.splat %21 : f32 -> tensor<16x64xf32, #warp>
%25:5 = scf.for %arg6 = %c0_i32 to %c1024_i32 step %c64_i32 iter_args(%arg7 = %cst_2, %arg8 = %cst_0, %arg9 = %cst_1, %arg10 = %16, %arg11 = %18) -> (tensor<16xf32, #triton_gpu.slice<{dim = 1, parent = #warp}>>, tensor<16x64xf32, #warp>, tensor<16xf32, #triton_gpu.slice<{dim = 1, parent = #warp}>>, !tt.ptr<tensor<64x64xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #warp}>>>, !tt.ptr<tensor<64x64xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #warp, kWidth=1}>>>) : i32 {
// CHECK-COUNT-16: tt.load {{.*}} {DotIdx = 1 : i32} : !tt.ptr<tensor<16x16xf16>>
%29 = tt.load %arg11 : !tt.ptr<tensor<64x64xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #warp, kWidth=1}>>>
%30 = tt.dot %22, %29, %cst_0, inputPrecision = tf32 : tensor<16x64xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #warp, kWidth=1}>> * tensor<64x64xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #warp, kWidth=1}>> -> tensor<16x64xf32, #warp>
// CHECK: arith.maxnumf {{.*}} : tensor<8x16xf32>
// CHECK: arith.maxnumf {{.*}} : tensor<8x16xf32>
// CHECK: [[MAX:%.*]] = arith.maxnumf {{.*}} : tensor<8x16xf32>
// CHECK-NEXT: [[EXTRACT0:%.*]] = triton_intel_gpu.extract [[MAX]][0] : tensor<8x16xf32> -> tensor<16xf32>
// CHECK-NEXT: = "tt.reduce"([[EXTRACT0]]) <{axis = 0 : i32}> ({
// CHECK: (tensor<16xf32>) -> f32
%31 = "tt.reduce"(%30) <{axis = 1 : i32}> ({
^bb0(%arg12: f32, %arg13: f32):
%53 = arith.maxnumf %arg12, %arg13 : f32
tt.reduce.return %53 : f32
}) : (tensor<16x64xf32, #warp>) -> tensor<16xf32, #triton_gpu.slice<{dim = 1, parent = #warp}>>
%32 = arith.mulf %31, %23 : tensor<16xf32, #triton_gpu.slice<{dim = 1, parent = #warp}>>
%33 = arith.maxnumf %arg9, %32 : tensor<16xf32, #triton_gpu.slice<{dim = 1, parent = #warp}>>
%34 = arith.mulf %30, %24 : tensor<16x64xf32, #warp>
// CHECK: tt.expand_dims {{.*}} {axis = 1 : i32} : tensor<16xf32
// CHECK: tt.broadcast {{.*}} -> tensor<16x16xf32>
%35 = tt.expand_dims %33 {axis = 1 : i32} : tensor<16xf32, #triton_gpu.slice<{dim = 1, parent = #warp}>> -> tensor<16x1xf32, #warp>
%36 = tt.broadcast %35 : tensor<16x1xf32, #warp> -> tensor<16x64xf32, #warp>
%37 = arith.subf %34, %36 : tensor<16x64xf32, #warp>
%38 = math.exp2 %37 : tensor<16x64xf32, #warp>
%39 = "tt.reduce"(%38) <{axis = 1 : i32}> ({
^bb0(%arg12: f32, %arg13: f32):
%53 = arith.addf %arg12, %arg13 : f32
tt.reduce.return %53 : f32
}) : (tensor<16x64xf32, #warp>) -> tensor<16xf32, #triton_gpu.slice<{dim = 1, parent = #warp}>>
%40 = arith.subf %arg9, %33 : tensor<16xf32, #triton_gpu.slice<{dim = 1, parent = #warp}>>
%41 = math.exp2 %40 : tensor<16xf32, #triton_gpu.slice<{dim = 1, parent = #warp}>>
%42 = arith.mulf %arg7, %41 : tensor<16xf32, #triton_gpu.slice<{dim = 1, parent = #warp}>>
%43 = arith.addf %42, %39 : tensor<16xf32, #triton_gpu.slice<{dim = 1, parent = #warp}>>
%44 = tt.expand_dims %41 {axis = 1 : i32} : tensor<16xf32, #triton_gpu.slice<{dim = 1, parent = #warp}>> -> tensor<16x1xf32, #warp>
%45 = tt.broadcast %44 : tensor<16x1xf32, #warp> -> tensor<16x64xf32, #warp>
%46 = arith.mulf %arg8, %45 : tensor<16x64xf32, #warp>
%47 = tt.load %arg10 : !tt.ptr<tensor<64x64xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #warp}>>>
%48 = arith.truncf %38 : tensor<16x64xf32, #warp> to tensor<16x64xf16, #warp>
%49 = triton_gpu.convert_layout %48 : tensor<16x64xf16, #warp> -> tensor<16x64xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #warp}>>
%50 = tt.dot %49, %47, %46, inputPrecision = tf32 : tensor<16x64xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #warp}>> * tensor<64x64xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #warp}>> -> tensor<16x64xf32, #warp>
%51 = tt.advance %arg10, [%c64_i32, %c0_i32] : <tensor<64x64xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #warp}>>>
%52 = tt.advance %arg11, [%c0_i32, %c64_i32] : <tensor<64x64xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #warp, kWidth=1}>>>
scf.yield %43, %50, %33, %51, %52 : tensor<16xf32, #triton_gpu.slice<{dim = 1, parent = #warp}>>, tensor<16x64xf32, #warp>, tensor<16xf32, #triton_gpu.slice<{dim = 1, parent = #warp}>>, !tt.ptr<tensor<64x64xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #warp}>>>, !tt.ptr<tensor<64x64xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #warp, kWidth=1}>>>
} {triton_gpu.workload = 4 : i32, tt.divisibility_arg1 = dense<64> : tensor<1xi32>}
%26 = tt.expand_dims %25#0 {axis = 1 : i32} : tensor<16xf32, #triton_gpu.slice<{dim = 1, parent = #warp}>> -> tensor<16x1xf32, #warp>
%27 = tt.broadcast %26 : tensor<16x1xf32, #warp> -> tensor<16x64xf32, #warp>
%28 = arith.divf %25#1, %27 : tensor<16x64xf32, #warp>
tt.store %20, %28 : !tt.ptr<tensor<16x64xf32, #warp>>
tt.return
}
}
23 changes: 13 additions & 10 deletions third_party/intel/lib/Dialect/TritonIntelGPU/IR/Ops.cpp
@@ -59,16 +59,13 @@ static SmallVector<int64_t> getShape(Type type) {
/// Return the element type of an input tensor (or ptr to tensor).
static Type getElementType(Type type) {
return TypeSwitch<Type, Type>(type)
.Case<RankedTensorType>([](auto ty) { return ty.getElementType(); })
.Case<ShapedType>([](auto ty) { return ty.getElementType(); })
.Case<triton::PointerType>([](auto ty) {
assert(isa<RankedTensorType>(ty.getPointeeType()) &&
"Expecting ptr to tensor");
return cast<RankedTensorType>(ty.getPointeeType()).getElementType();
})
.Default([](auto) {
llvm_unreachable("Unexpected type");
return Type();
});
.Default([](auto ty) { return ty; });
}
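
In effect, the helper now matches any ShapedType and, instead of hitting llvm_unreachable, returns non-shaped types unchanged through the Default case. A minimal sketch of the expected behavior under that reading (illustrative only; assumes an MLIRContext *ctx and the usual MLIR headers):

// Tensors (and other ShapedTypes) yield their element type; scalars pass through.
mlir::Builder b(ctx);
mlir::Type f16 = b.getF16Type();
mlir::Type tensorTy = mlir::RankedTensorType::get({16, 8}, f16);
assert(getElementType(tensorTy) == f16); // ShapedType -> element type
assert(getElementType(f16) == f16);      // non-shaped type falls through unchanged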

/// Return the size of the specified dimension of an input tensor (or ptr to
@@ -90,10 +87,15 @@ static unsigned getDimSize(Type type, unsigned dim) {
namespace mlir::triton::gpu::intel {

LogicalResult GlueOp::verify() {
/*
SmallVector<Type> inputTypes;
for (auto input : getOperands())
inputTypes.push_back(input.getType());
Type inputType = inputTypes[0];

if (!llvm::all_of(inputTypes, [&](Type type) { return type == inputType; }))
return emitOpError("operands type should be the same");
else
return success();

Type resultType = getRes().getType();
unsigned resultRank = getRank(resultType);
@@ -124,7 +126,6 @@ LogicalResult GlueOp::verify() {
return emitOpError(
"operands cannot exceed result size along any dimension");

auto inputType = inputTypes[0];
for (unsigned i = 0; i < resultRank; ++i) {
unsigned resultSize = getDimSize(resultType, i);
unsigned inputSize = getDimSize(inputType, i);
@@ -140,15 +141,18 @@

if (inputTypes.size() * numInputElems != numResultElems)
return emitOpError("glued operands do not exactly cover the result shape");
*/
return success();
}

LogicalResult ExtractOp::verify() {
/*
Type resultType = getRes().getType();
Type operandType = getBase().getType();

if (getElementType(resultType) == getElementType(operandType))
return success();
else
return emitOpError("operand and reslut should have the same element type");

unsigned resultRank = getRank(resultType);
unsigned operandRank = getRank(operandType);
if (operandRank != resultRank)
@@ -181,7 +185,6 @@ LogicalResult ExtractOp::verify() {
unsigned index = getIndex();
if (index >= numTiles)
return emitOpError("index must be less than ") << numTiles;
*/
return success();
}
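
For reference, the operations whose structural checks are disabled above appear in the lit tests in this form (adapted from the glue/extract usage in match-target-size.mlir; %a and %b are placeholder values):

%glue = triton_intel_gpu.glue %a, %b : (tensor<16x8xf16>, tensor<16x8xf16>) -> tensor<16x16xf16>
%e0 = triton_intel_gpu.extract %glue[0] : tensor<16x16xf16> -> tensor<16x8xf16>

The commented-out verifiers enforced, among other things, that all glue operands share the same type and that extract's result element type matches its operand; the checks remain in the source as comments rather than being deleted.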
