diff --git a/include/triton/Dialect/Triton/IR/TritonOps.td b/include/triton/Dialect/Triton/IR/TritonOps.td index 4d71bf3456..2c3a1bf714 100644 --- a/include/triton/Dialect/Triton/IR/TritonOps.td +++ b/include/triton/Dialect/Triton/IR/TritonOps.td @@ -891,7 +891,7 @@ def TT_AssertOp : TT_Op<"assert", [MemoryEffects<[MemWrite]>]> { `tt.assert` takes a condition tensor and a message string. If the condition is false, the message is printed, and the program is aborted. }]; - let arguments = (ins AnyTypeOf<[I1, I1Tensor]>:$condition, StrAttr:$message); + let arguments = (ins TT_Tensor:$condition, StrAttr:$message); let assemblyFormat = "$condition `,` $message attr-dict `:` type($condition)"; } diff --git a/lib/Conversion/TritonGPUToLLVM/AssertOpToLLVM.cpp b/lib/Conversion/TritonGPUToLLVM/AssertOpToLLVM.cpp index 508eb25f99..20558c440a 100644 --- a/lib/Conversion/TritonGPUToLLVM/AssertOpToLLVM.cpp +++ b/lib/Conversion/TritonGPUToLLVM/AssertOpToLLVM.cpp @@ -35,14 +35,6 @@ struct AssertOpConversion : public ConvertOpToLLVMPattern<triton::AssertOp> { } } llAssert(op, condition, adaptor.getMessage(), rewriter); - if (isa<RankedTensorType>(op.getCondition().getType())) { - // Add a barrier to avoid a race condition in case an assert is followed - // by an op that may trap if the assert condition is true. Since the - // tensor in those two operations may have different layout we need to - // make sure all the threads are done executing the assert before going to - // the next op. 
- barrier(); - } rewriter.eraseOp(op); return success(); } diff --git a/python/triton/language/semantic.py b/python/triton/language/semantic.py index 76a94b5478..6e12280737 100644 --- a/python/triton/language/semantic.py +++ b/python/triton/language/semantic.py @@ -1729,6 +1729,10 @@ def device_print(prefix: str, args: List[tl.tensor], hex: bool, builder: ir.buil def device_assert(cond: tl.tensor, msg: str, builder: ir.builder) -> tl.tensor: if not builder.options.debug: return + cond_ty = cond.type + if not cond_ty.is_block(): + cond_ty = tl.block_type(cond_ty.scalar, (1, )) + cond = tl.tensor(builder.create_splat(cond.handle, (1, )), cond_ty) return tl.tensor(builder.create_assert(cond.handle, msg), tl.void) diff --git a/test/Conversion/tritongpu_to_llvm.mlir b/test/Conversion/tritongpu_to_llvm.mlir index 0a72941416..66bba6e060 100644 --- a/test/Conversion/tritongpu_to_llvm.mlir +++ b/test/Conversion/tritongpu_to_llvm.mlir @@ -1827,8 +1827,6 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : // CHECK-DAG: llvm.mlir.global internal constant @assertFunc_0("unknown\00") {addr_space = 0 : i32} // CHECK-DAG: llvm.mlir.global internal constant @assertFile_0("inner_call\00") {addr_space = 0 : i32} // CHECK-DAG: llvm.mlir.global internal constant @assertMessage_0("assert text\00") {addr_space = 0 : i32} -// CHECK: llvm.call @__assertfail -// CHECK: nvvm.barrier0 module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32, triton_gpu.target = "cuda:90", "triton_gpu.threads-per-warp" = 32 : i32} { tt.func public @add_kernel(%arg0: tensor<1xi1, #blocked>) { tt.assert %arg0, "assert text" : tensor<1xi1, #blocked> loc(#loc5)