diff --git a/third_party/intel/lib/TritonIntelGPUTransforms/OptimizeReductionLocality.cpp b/third_party/intel/lib/TritonIntelGPUTransforms/OptimizeReductionLocality.cpp index ccc564ea66..3bc0b7f0ab 100644 --- a/third_party/intel/lib/TritonIntelGPUTransforms/OptimizeReductionLocality.cpp +++ b/third_party/intel/lib/TritonIntelGPUTransforms/OptimizeReductionLocality.cpp @@ -153,13 +153,15 @@ struct DpasOperandPattern final : OpRewritePattern { // Check this is has `triton_intel_gpu.dpas` encoding. Value operand = operands.front(); auto type = cast(operand.getType()); + // Only support reduction after 2D-dot for now. + if (type.getRank() != 2) + return failure(); auto encoding = llvm::dyn_cast_or_null(type.getEncoding()); if (!encoding) return failure(); // Axis 1 will lead to within-warp reduction. - assert(type.getRank() == 2 && "Expecting 2D tensor"); if (op.getAxis() != preferredReductionAxis) return failure();