Skip to content

Commit

Permalink
Fix test_chained_reductions (#2821)
Browse files Browse the repository at this point in the history
Fixes #2703

---------

Signed-off-by: Tiotto, Ettore <[email protected]>
Co-authored-by: Whitney Tsang <[email protected]>
  • Loading branch information
etiotto and whitneywhtsang authored Nov 26, 2024
1 parent 8589959 commit 3c09bfe
Show file tree
Hide file tree
Showing 8 changed files with 56 additions and 17 deletions.
5 changes: 5 additions & 0 deletions python/test/unit/language/test_core.py
Original file line number Diff line number Diff line change
Expand Up @@ -6109,6 +6109,11 @@ def sanitize_cumsum_kernel(Z, X, BLOCK: tl.constexpr):
((8, 2, 32, 4, 16), [4, 0, 1, 3, 2], [0, 2, 0]),
])
def test_chained_reductions(in_shape, perm, red_dims, device):
if is_xpu() and in_shape == (4, 32, 32, 4, 2):
# check maximum shared memory
if triton.runtime.driver.active.utils.get_device_properties(
triton.runtime.driver.active.get_current_device())["max_shared_mem"] <= 163840:
pytest.xfail("XPU: Not enough shared memory")

@triton.jit
def kernel(In, Out, #
Expand Down
2 changes: 0 additions & 2 deletions scripts/skiplist/a770/language.txt
Original file line number Diff line number Diff line change
@@ -1,5 +1,3 @@
# https://github.com/intel/intel-xpu-backend-for-triton/issues/2703
test/unit/language/test_core.py::test_chained_reductions[in_shape0-perm0-red_dims0]
test/unit/language/test_core.py::test_dot3d[1-1-32-32-32-32-32-float16-float16]
test/unit/language/test_core.py::test_dot3d[1-1-32-32-32-32-32-float16-float32]
test/unit/language/test_core.py::test_dot3d[1-1-32-32-32-32-32-float32-float32]
Expand Down
2 changes: 0 additions & 2 deletions scripts/skiplist/conda/language.txt
Original file line number Diff line number Diff line change
Expand Up @@ -113,8 +113,6 @@ test/unit/language/test_core.py::test_dot_max_num_imprecise_acc[64-float8e5-128-
test/unit/language/test_core.py::test_dot_max_num_imprecise_acc[64-float8e4nv-128-256-128-128-256-256]
test/unit/language/test_core.py::test_dot_max_num_imprecise_acc[64-float8e4b15-128-256-128-128-256-256]
test/unit/language/test_core.py::test_dot_max_num_imprecise_acc[128-float8e5-128-256-128-128-256-256]
# https://github.com/intel/intel-xpu-backend-for-triton/issues/2703
test/unit/language/test_core.py::test_chained_reductions[in_shape0-perm0-red_dims0]
test/unit/language/test_core.py::test_dot3d[1-1-32-32-32-32-32-float16-float16]
test/unit/language/test_core.py::test_dot3d[1-1-32-32-32-32-32-float16-float32]
test/unit/language/test_core.py::test_dot3d[1-1-32-32-32-32-32-float32-float32]
Expand Down
2 changes: 0 additions & 2 deletions scripts/skiplist/default/language.txt
Original file line number Diff line number Diff line change
@@ -1,2 +0,0 @@
# https://github.com/intel/intel-xpu-backend-for-triton/issues/2703
test/unit/language/test_core.py::test_chained_reductions[in_shape0-perm0-red_dims0]
2 changes: 0 additions & 2 deletions scripts/skiplist/lts/language.txt
Original file line number Diff line number Diff line change
Expand Up @@ -113,8 +113,6 @@ test/unit/language/test_core.py::test_dot_max_num_imprecise_acc[64-float8e5-128-
test/unit/language/test_core.py::test_dot_max_num_imprecise_acc[64-float8e4nv-128-256-128-128-256-256]
test/unit/language/test_core.py::test_dot_max_num_imprecise_acc[64-float8e4b15-128-256-128-128-256-256]
test/unit/language/test_core.py::test_dot_max_num_imprecise_acc[128-float8e5-128-256-128-128-256-256]
# https://github.com/intel/intel-xpu-backend-for-triton/issues/2703
test/unit/language/test_core.py::test_chained_reductions[in_shape0-perm0-red_dims0]
test/unit/language/test_core.py::test_dot3d[1-1-32-32-32-32-32-float16-float16]
test/unit/language/test_core.py::test_dot3d[1-1-32-32-32-32-32-float16-float32]
test/unit/language/test_core.py::test_dot3d[1-1-32-32-32-32-32-float32-float32]
Expand Down
2 changes: 0 additions & 2 deletions scripts/skiplist/mtl/language.txt
Original file line number Diff line number Diff line change
@@ -1,5 +1,3 @@
# https://github.com/intel/intel-xpu-backend-for-triton/issues/2703
test/unit/language/test_core.py::test_chained_reductions[in_shape0-perm0-red_dims0]
test/unit/language/test_core.py::test_dot3d[1-1-32-32-32-32-32-float16-float16]
test/unit/language/test_core.py::test_dot3d[1-1-32-32-32-32-32-float16-float32]
test/unit/language/test_core.py::test_dot3d[1-1-32-32-32-32-32-float32-float32]
Expand Down
2 changes: 0 additions & 2 deletions scripts/skiplist/xe2/language.txt
Original file line number Diff line number Diff line change
@@ -1,2 +0,0 @@
# https://github.com/intel/intel-xpu-backend-for-triton/issues/2703
test/unit/language/test_core.py::test_chained_reductions[in_shape0-perm0-red_dims0]
56 changes: 51 additions & 5 deletions third_party/intel/lib/TritonIntelGPUToLLVM/ReduceOpToLLVM.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -220,6 +220,46 @@ struct ReduceOpConversion
rewriter.replaceOp(op, results);
}

// For slice layouts some lane ids are duplicated on multiple lanes, so we
// need to handle the delinearization of laneId in a special way: delinearize
// in the full-rank parent layout, then drop the sliced-away dimensions.
// TODO: generalize this part of the logic to work on any kind of linear
// layout uniformly.
//
// Returns the multi-dimensional lane id corresponding to the linear `laneId`
// for the source layout of the reduction described by `helper`.
SmallVector<Value>
getMultiDimLaneId(ReduceOpHelper &helper, Value &laneId, Location &loc,
ConversionPatternRewriter &rewriter) const {
auto srcLayout = helper.getSrcLayout();
auto srcShape = helper.getSrcShape();
auto order = triton::gpu::getThreadOrder(srcLayout);
SmallVector<Value> multiDimLaneId;

if (auto sliceLayout = mlir::dyn_cast<SliceEncodingAttr>(srcLayout)) {
// Walk through (possibly nested) slice layouts down to the first
// non-slice parent, recording the dimension each slice removed.
// `dims` is ordered outermost slice first.
auto parentLayout = sliceLayout.getParent();
SmallVector<unsigned> dims = {sliceLayout.getDim()};
while (auto parentSliceLayout =
mlir::dyn_cast<SliceEncodingAttr>(parentLayout)) {
dims.push_back(parentSliceLayout.getDim());
parentLayout = parentSliceLayout.getParent();
}

// Delinearize in the full-rank parent layout, then erase the sliced
// dimensions. Erasing in reverse collection order (innermost parent's
// dim first) keeps each stored index valid as the vector shrinks.
auto parentThreadsPerWarps = triton::gpu::getThreadsPerWarp(parentLayout);
auto parentOrder = triton::gpu::getThreadOrder(parentLayout);
multiDimLaneId = delinearize(rewriter, loc, laneId, parentThreadsPerWarps,
parentOrder);
for (unsigned dim : llvm::reverse(dims)) {
multiDimLaneId.erase(multiDimLaneId.begin() + dim);
}
} else {
// Non-slice layout: delinearize directly, restricting the reduction
// axis to the number of threads that hold unique data on that axis.
SmallVector<unsigned> threadsPerWarps =
triton::gpu::getThreadsPerWarp(srcLayout);
threadsPerWarps[helper.getAxis()] =
triton::gpu::getThreadsPerWarpWithUniqueData(
srcLayout, srcShape)[helper.getAxis()];
multiDimLaneId =
delinearize(rewriter, loc, laneId, threadsPerWarps, order);
}
return multiDimLaneId;
}

SmallVector<Value>
getMultiDimWarpId(ReduceOpHelper &helper, Value &warpId, Location &loc,
ConversionPatternRewriter &rewriter) const {
Expand All @@ -233,11 +273,20 @@ struct ReduceOpConversion
// a way to properly delinearize warpId in the slice case
if (auto sliceLayout = mlir::dyn_cast<SliceEncodingAttr>(srcLayout)) {
auto parentLayout = sliceLayout.getParent();
SmallVector<unsigned> dims = {sliceLayout.getDim()};
while (auto parentSliceLayout =
mlir::dyn_cast<SliceEncodingAttr>(parentLayout)) {
dims.push_back(parentSliceLayout.getDim());
parentLayout = parentSliceLayout.getParent();
}

auto parentWarpsPerCTA = triton::gpu::getWarpsPerCTA(parentLayout);
auto parentOrder = triton::gpu::getWarpOrder(parentLayout);
multiDimWarpId =
delinearize(rewriter, loc, warpId, parentWarpsPerCTA, parentOrder);
multiDimWarpId.erase(multiDimWarpId.begin() + sliceLayout.getDim());
for (unsigned dim : llvm::reverse(dims)) {
multiDimWarpId.erase(multiDimWarpId.begin() + dim);
}
} else {
SmallVector<unsigned> warpsPerCTA =
triton::gpu::getWarpsPerCTA(srcLayout);
Expand Down Expand Up @@ -265,11 +314,8 @@ struct ReduceOpConversion
unsigned axis = op.getAxis();
auto smemShape = helper.getScratchRepShape();

auto threadsPerWarp =
triton::gpu::getThreadsPerWarpWithUniqueData(srcLayout, srcShape);
auto order = getThreadOrder(srcLayout);
SmallVector<Value> multiDimLaneId =
delinearize(rewriter, loc, laneId, threadsPerWarp, order);
getMultiDimLaneId(helper, laneId, loc, rewriter);
Value laneIdAxis = multiDimLaneId[axis];
Value zero = i32_val(0);
Value laneZero = icmp_eq(laneIdAxis, zero);
Expand Down

0 comments on commit 3c09bfe

Please sign in to comment.