Skip to content

Commit

Permalink
[VnniTransform] Visit blocks on post order (#971)
Browse files Browse the repository at this point in the history
Visit blocks on post order, such that values used inside a loop but defined outside the loop can be handled correctly.
  • Loading branch information
chencha3 authored Nov 26, 2024
1 parent fd7fdc2 commit 668ddfa
Show file tree
Hide file tree
Showing 2 changed files with 101 additions and 59 deletions.
128 changes: 70 additions & 58 deletions lib/Transforms/VnniTransformation.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -348,6 +348,9 @@ static void applyVnniTransformOnResults(mlir::OpBuilder &builder,
// the op, and whether it is safe to apply vnni transform on operands too.
static void updateUnknownOp(mlir::OpBuilder &builder, mlir::Operation &op,
LayoutAnalysis &analysis) {
// Ignore ops that has packed attribute, since they are inserted by the pass.
if (op.hasAttr("packed"))
return;
applyVnniTransformOnResults(builder, &op, analysis);
}

Expand Down Expand Up @@ -469,82 +472,84 @@ static void updateExtractStrideSliceOp(mlir::OpBuilder &builder,
}
}

static void handleBranchOpInterface(mlir::OpBuilder &builder,
mlir::Block &block,
mlir::RegionBranchOpInterface branch,
mlir::TypeRange argsTypes) {
builder.setInsertionPointToStart(&block);
// handle terminal ops, e.g., scf.Yield. Update
// the types of its successor inputs if successor
// operands needs vnni format.
static void handleBranchTerminatorOpInterface(
mlir::OpBuilder &builder,
mlir::RegionBranchTerminatorOpInterface terminator,
LayoutAnalysis &analysis) {

if (!mlir::isa<mlir::RegionBranchOpInterface>(terminator->getParentOp()))
return;

mlir::Operation *op = branch.getOperation();
llvm::SmallVector<mlir::RegionSuccessor> successors;
llvm::SmallVector<mlir::Attribute> operands(op->getNumOperands(), nullptr);
branch.getEntrySuccessorRegions(operands, successors);
llvm::SmallVector<mlir::Attribute> operands(terminator->getNumOperands(),
nullptr);
terminator.getSuccessorRegions(operands, successors);

for (mlir::RegionSuccessor &successor : successors) {
if (block.getParent() != successor.getSuccessor())
if (!successor.isParent())
continue;

mlir::OperandRange operands = branch.getEntrySuccessorOperands(successor);
mlir::OperandRange operands = terminator.getSuccessorOperands(successor);
mlir::ValueRange inputs = successor.getSuccessorInputs();
for (auto [arg, input] : llvm::zip(operands, inputs)) {
auto idx = mlir::cast<mlir::BlockArgument>(input).getArgNumber();
mlir::Type dstType = argsTypes[idx];
if (dstType == arg.getType()) {
input.setType(dstType);
continue;
} else {
auto cast = mlir::cast<mlir::TypedValue<mlir::VectorType>>(arg);
auto &&[newArg, root] = applyVnniTransform(builder, cast);
arg.replaceAllUsesExcept(newArg, root);
for (auto [arg, inp] : llvm::zip(operands, inputs)) {
if (analysis.getLayout(arg)) {
auto vecTy = mlir::cast<mlir::VectorType>(arg.getType());
auto packedTy = getPackedType(vecTy);
inp.setType(packedTy);
}
}
}
}

auto terminator = mlir::cast<mlir::RegionBranchTerminatorOpInterface>(
block.getTerminator());
mlir::SmallVector<mlir::Attribute> operandAttributes(
terminator->getNumOperands(), nullptr);

successors.clear();
terminator.getSuccessorRegions(operandAttributes, successors);
// handle REgionBranchOps, e.g., scf.for. Update the
// region argument types, if the argument needs to be
// in vnni format, but the initArg is not, a vnni
// transform is applied on the initArg.
static void handleBranchOpInterface(mlir::OpBuilder &builder,
mlir::RegionBranchOpInterface branch,
LayoutAnalysis &analysis) {
mlir::Operation *op = branch.getOperation();
llvm::SmallVector<mlir::RegionSuccessor> successors;
llvm::SmallVector<mlir::Attribute> operands(op->getNumOperands(), nullptr);
branch.getEntrySuccessorRegions(operands, successors);

for (const mlir::RegionSuccessor &successor : successors) {
if (!successor.isParent())
for (mlir::RegionSuccessor &successor : successors) {
if (successor.isParent())
continue;

mlir::OperandRange operands = branch.getEntrySuccessorOperands(successor);
mlir::ValueRange inputs = successor.getSuccessorInputs();
mlir::OperandRange operands = terminator.getSuccessorOperands(successor);
for (auto [operand, input] : llvm::zip(operands, inputs)) {
input.setType(operand.getType());

for (auto [arg, input] : llvm::zip(operands, inputs)) {
if (analysis.getLayout(input)) {
auto vecTy = mlir::cast<mlir::VectorType>(input.getType());
auto packedTy = getPackedType(vecTy);
input.setType(packedTy);
if (!analysis.getLayout(arg)) {
builder.setInsertionPointAfterValue(arg);
auto cast = mlir::cast<mlir::TypedValue<mlir::VectorType>>(arg);
auto &&[newArg, root] = applyVnniTransform(builder, cast);
arg.replaceAllUsesExcept(newArg, root);
}
}
}
}
}

static void updateBlockTypes(mlir::OpBuilder &builder, mlir::Block &block,
LayoutAnalysis &analysis) {
if (auto iface = mlir::dyn_cast_if_present<mlir::RegionBranchOpInterface>(
block.getParentOp())) {
llvm::SmallVector<mlir::Type> types;
for (auto arg : block.getArguments()) {
auto argTy = arg.getType();
if (!analysis.getLayout(arg)) {
types.push_back(argTy);
} else {
auto vecTy = mlir::cast<mlir::VectorType>(argTy);
auto packedTy = getPackedType(vecTy);
types.push_back(packedTy);
if (!mlir::isa<mlir::RegionBranchOpInterface>(block.getParentOp())) {
builder.setInsertionPointToStart(&block);
for (auto &&arg : block.getArguments()) {
if (analysis.getLayout(arg)) {
auto cast = mlir::cast<mlir::TypedValue<mlir::VectorType>>(arg);
auto &&[newArg, root] = applyVnniTransform(builder, cast);
arg.replaceAllUsesExcept(newArg, root);
}
}
return handleBranchOpInterface(builder, block, iface, types);
}

builder.setInsertionPointToStart(&block);
for (auto &&arg : block.getArguments()) {
if (analysis.getLayout(arg)) {
auto cast = mlir::cast<mlir::TypedValue<mlir::VectorType>>(arg);
auto &&[newArg, root] = applyVnniTransform(builder, cast);
arg.replaceAllUsesExcept(newArg, root);
}
}
}

Expand All @@ -561,14 +566,21 @@ struct VnniTransformationPass final

mlir::OpBuilder builder(&getContext());
llvm::SmallVector<mlir::Type> operands;
op->walk<mlir::WalkOrder::PreOrder>([&](mlir::Block *block) {
// process ops in post-order so that the layout info is
// used before being destroyed.
op->walk([&](mlir::Block *block) {
// Iterate block ops in reverse so op is updated before it's operands.
for (mlir::Operation &op : llvm::reverse(block->getOperations())) {
// Ignore shape casts as they are generated by the conversion itself.
// Ignore RegionBranchOpInterface as it handled in `updateBlockTypes`.
if (mlir::isa<mlir::vector::ShapeCastOp, mlir::RegionBranchOpInterface,
mlir::RegionBranchTerminatorOpInterface>(op))
if (auto terminator =
mlir::dyn_cast<mlir::RegionBranchTerminatorOpInterface>(op)) {
handleBranchTerminatorOpInterface(builder, terminator, analysis);
continue;
}

if (auto iface = mlir::dyn_cast<mlir::RegionBranchOpInterface>(op)) {
handleBranchOpInterface(builder, iface, analysis);
continue;
}

if (auto dpas = mlir::dyn_cast<mlir::xegpu::DpasOp>(op)) {
updateDpasOp(builder, dpas, analysis);
Expand Down
32 changes: 31 additions & 1 deletion test/Transforms/VnniTransform/unit-tests.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -374,4 +374,34 @@ func.func @test(%arg1 : !xegpu.tensor_desc<8x16xi16>, %arg2 : !xegpu.tensor_desc
%1 = arith.bitcast %b : vector<16x16xi16> to vector<16x16xf16>
%2 = xegpu.dpas %0, %1 : vector<8x16xf16>, vector<16x16xf16> -> vector<8x16xf32>
return %2 : vector<8x16xf32>
}
}

// -----

//CHECK-LABEL: @test
// CHECK-SAME: (%[[arg0:.*]]: !xegpu.tensor_desc<8x16xf16>, %[[arg1:.*]]: !xegpu.tensor_desc<16x16xf16>, %[[arg2:.*]]: vector<16x16xf16>, %[[arg3:.*]]: i1) -> vector<8x16xf32> {
func.func @test(%arg1 : !xegpu.tensor_desc<8x16xf16>, %arg2 : !xegpu.tensor_desc<16x16xf16>, %arg3 : vector<16x16xf16>, %arg4 : i1) -> vector<8x16xf32> {
//CHECK: %[[r0:.*]] = vector.shape_cast %[[arg2]] {packed} : vector<16x16xf16> to vector<256xf16>
//CHECK: %[[r1:.*]] = vector.shuffle %[[r0]], %[[r0]] [0, 16, 1, 17, 2, 18, 3, 19, 4, 20, 5, 21, 6, 22, 7, 23, 8, 24, 9, 25, 10, 26, 11, 27, 12, 28, 13, 29, 14, 30, 15, 31, 32, 48, 33, 49, 34, 50, 35, 51, 36, 52, 37, 53, 38, 54, 39, 55, 40, 56, 41, 57, 42, 58, 43, 59, 44, 60, 45, 61, 46, 62, 47, 63, 64, 80, 65, 81, 66, 82, 67, 83, 68, 84, 69, 85, 70, 86, 71, 87, 72, 88, 73, 89, 74, 90, 75, 91, 76, 92, 77, 93, 78, 94, 79, 95, 96, 112, 97, 113, 98, 114, 99, 115, 100, 116, 101, 117, 102, 118, 103, 119, 104, 120, 105, 121, 106, 122, 107, 123, 108, 124, 109, 125, 110, 126, 111, 127, 128, 144, 129, 145, 130, 146, 131, 147, 132, 148, 133, 149, 134, 150, 135, 151, 136, 152, 137, 153, 138, 154, 139, 155, 140, 156, 141, 157, 142, 158, 143, 159, 160, 176, 161, 177, 162, 178, 163, 179, 164, 180, 165, 181, 166, 182, 167, 183, 168, 184, 169, 185, 170, 186, 171, 187, 172, 188, 173, 189, 174, 190, 175, 191, 192, 208, 193, 209, 194, 210, 195, 211, 196, 212, 197, 213, 198, 214, 199, 215, 200, 216, 201, 217, 202, 218, 203, 219, 204, 220, 205, 221, 206, 222, 207, 223, 224, 240, 225, 241, 226, 242, 227, 243, 228, 244, 229, 245, 230, 246, 231, 247, 232, 248, 233, 249, 234, 250, 235, 251, 236, 252, 237, 253, 238, 254, 239, 255] {packed} : vector<256xf16>, vector<256xf16>
//CHECK: %[[r2:.*]] = vector.shape_cast %[[r1]] {packed} : vector<256xf16> to vector<8x16x2xf16>
//CHECK: %[[r3:.*]] = xegpu.load_nd %[[arg0]] : !xegpu.tensor_desc<8x16xf16> -> vector<8x16xf16>
%0 = xegpu.load_nd %arg1 : !xegpu.tensor_desc<8x16xf16> -> vector<8x16xf16>
//CHECK: %[[r4:.*]] = scf.if %[[arg3]] -> (vector<8x16xf32>)
%1 = scf.if %arg4 -> (vector<8x16xf32>) {
//CHECK: %[[r5:.*]] = xegpu.load_nd %[[arg1]] <{packed}> : !xegpu.tensor_desc<16x16xf16> -> vector<8x16x2xf16>
%2 = xegpu.load_nd %arg2 : !xegpu.tensor_desc<16x16xf16> -> vector<16x16xf16>
//CHECK: %[[r6:.*]] = arith.addf %[[r5]], %[[r2]] : vector<8x16x2xf16>
%3 = arith.addf %2, %arg3 : vector<16x16xf16>
//CHECK: %[[r7:.*]] = xegpu.dpas %[[r3]], %[[r6]] : vector<8x16xf16>, vector<8x16x2xf16> -> vector<8x16xf32>
%4 = xegpu.dpas %0, %3 : vector<8x16xf16>, vector<16x16xf16> -> vector<8x16xf32>
//CHECK: scf.yield %[[r7]] : vector<8x16xf32>
scf.yield %4 : vector<8x16xf32>
} else {
//CHECK: %[[r5:.*]] = xegpu.dpas %[[r3]], %[[r2]] : vector<8x16xf16>, vector<8x16x2xf16> -> vector<8x16xf32>
%5 = xegpu.dpas %0, %arg3 : vector<8x16xf16>, vector<16x16xf16> -> vector<8x16xf32>
//CHECK: scf.yield %[[r5]] : vector<8x16xf32>
scf.yield %5 : vector<8x16xf32>
}
//CHECK: return %[[r4]] : vector<8x16xf32>
return %1 : vector<8x16xf32>
}

0 comments on commit 668ddfa

Please sign in to comment.