Skip to content

Commit

Permalink
[FA][ScheduleLoad] Fix bug exposed by causal=true (#2433)
Browse files Browse the repository at this point in the history
Cannot move ops that are used by other ops in another region.

Signed-off-by: Whitney Tsang <[email protected]>
  • Loading branch information
whitneywhtsang authored Oct 8, 2024
1 parent c2570b7 commit 2202ca7
Show file tree
Hide file tree
Showing 2 changed files with 32 additions and 1 deletion.
24 changes: 24 additions & 0 deletions test/TritonIntelGPU/schedule-load.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -303,3 +303,27 @@ module attributes {"triton_gpu.num-warps" = 32 : i32, "triton_gpu.threads-per-wa
tt.return
}
}

// -----

// COM: Regression test: %common is loaded once at function scope and is used
// COM: inside the bodies of two separate scf.for loops. The load must stay
// COM: ahead of both loops instead of being moved next to a dot in one of them.
tt.func public @test(%arg0: !tt.ptr<tensor<16x16xf16>>, %arg1: !tt.ptr<tensor<8x32xf16>>) {
// COM: Lower bound, upper bound and step shared by both loops below.
%lb = arith.constant 0 : i32
%ub = tt.get_program_id x : i32
%st = arith.constant 32 : i32
// COM: Zero tensor passed as the third operand of both tt.dot ops.
%zero = arith.constant dense<0.000000e+00> : tensor<8x16xf32>
// COM: Defined outside the loops; each loop body extracts from this value.
%common = tt.load %arg1 {DotIdx = 0 : i32} : !tt.ptr<tensor<8x32xf16>>
// COM: Check %common is not moved in the loop.
// COM: The load of %arg1 has to be matched first, followed by both loops.
// CHECK: tt.load %arg1
// CHECK-COUNT-2: scf.for
scf.for %iv0 = %lb to %ub step %st : i32 {
// COM: First user of %common (through %extract1), inside the first loop body.
%load1 = tt.load %arg0 {DotIdx = 1 : i32} : !tt.ptr<tensor<16x16xf16>>
%extract1 = triton_intel_gpu.extract %common[0] : tensor<8x32xf16> -> tensor<8x16xf16>
%dot1 = tt.dot %extract1, %load1, %zero, inputPrecision = tf32 {"schedule-group" = 0 : i32} : tensor<8x16xf16> * tensor<16x16xf16> -> tensor<8x16xf32>
}
scf.for %iv1 = %lb to %ub step %st : i32 {
// COM: Second user of %common (through %extract2), inside the second loop
// COM: body, i.e. a different region from the first user.
%load2 = tt.load %arg0 {DotIdx = 1 : i32} : !tt.ptr<tensor<16x16xf16>>
%extract2 = triton_intel_gpu.extract %common[0] : tensor<8x32xf16> -> tensor<8x16xf16>
%dot2 = tt.dot %extract2, %load2, %zero, inputPrecision = tf32 {"schedule-group" = 0 : i32} : tensor<8x16xf16> * tensor<16x16xf16> -> tensor<8x16xf32>
}
tt.return
}
Original file line number Diff line number Diff line change
Expand Up @@ -71,8 +71,15 @@ class ScheduleLoadPass
for (SmallVector<tt::DotOp> &dots : dotsGroup) {
SmallVector<Value> notVisited = getNotVisitedUses(dots);
for (Value val : notVisited) {
if (Operation *op = val.getDefiningOp())
if (Operation *op = val.getDefiningOp()) {
// Cannot move an op that is used by other ops in another region.
Region *rgn = dots.begin()->getOperation()->getParentRegion();
if (any_of(val.getUsers(), [&](Operation *user) {
return user->getParentRegion() != rgn;
}))
continue;
op->moveBefore(dots.begin()->getOperation());
}
}
}
});
Expand Down

0 comments on commit 2202ca7

Please sign in to comment.