Skip to content

Commit

Permalink
[OptRed] Do not assert on reductions after 3D dot ops (#2590)
Browse files Browse the repository at this point in the history
Triton supports 3D dot operations and #2518 enables using them with
`DpasEncodingAttr` encoding. Avoid assertions when this pass finds
reductions of 3D tensors.

Signed-off-by: Victor Perez <[email protected]>
  • Loading branch information
victor-eds authored Oct 30, 2024
1 parent efce869 commit e438919
Showing 1 changed file with 3 additions and 1 deletion.
Original file line number Diff line number Diff line change
Expand Up @@ -153,13 +153,15 @@ struct DpasOperandPattern final : OpRewritePattern<ReduceOp> {
// Check this is has `triton_intel_gpu.dpas` encoding.
Value operand = operands.front();
auto type = cast<RankedTensorType>(operand.getType());
// Only support reduction after 2D-dot for now.
if (type.getRank() != 2)
return failure();
auto encoding =
llvm::dyn_cast_or_null<DpasEncodingAttr>(type.getEncoding());
if (!encoding)
return failure();

// Axis 1 will lead to within-warp reduction.
assert(type.getRank() == 2 && "Expecting 2D tensor");
if (op.getAxis() != preferredReductionAxis)
return failure();

Expand Down

0 comments on commit e438919

Please sign in to comment.