Skip to content

Commit

Permalink
Revert "Add back barrier after asserts (#5043)"
Browse files Browse the repository at this point in the history
This reverts commit 92a4fad.
  • Loading branch information
whitneywhtsang committed Nov 6, 2024
1 parent 021857e commit 555d666
Show file tree
Hide file tree
Showing 4 changed files with 5 additions and 11 deletions.
2 changes: 1 addition & 1 deletion include/triton/Dialect/Triton/IR/TritonOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -891,7 +891,7 @@ def TT_AssertOp : TT_Op<"assert", [MemoryEffects<[MemWrite<GlobalMemory>]>]> {
`tt.assert` takes a condition tensor and a message string.
If the condition is false, the message is printed, and the program is aborted.
}];
let arguments = (ins AnyTypeOf<[I1, I1Tensor]>:$condition, StrAttr:$message);
let arguments = (ins TT_Tensor:$condition, StrAttr:$message);
let assemblyFormat = "$condition `,` $message attr-dict `:` type($condition)";
}

Expand Down
8 changes: 0 additions & 8 deletions lib/Conversion/TritonGPUToLLVM/AssertOpToLLVM.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -35,14 +35,6 @@ struct AssertOpConversion : public ConvertOpToLLVMPattern<triton::AssertOp> {
}
}
llAssert(op, condition, adaptor.getMessage(), rewriter);
if (isa<RankedTensorType>(op.getCondition().getType())) {
// Add a barrier to avoid a race condition in case an assert is followed
// by an op that may trap if the assert condition is true. Since the
// tensor in those two operations may have different layout we need to
// make sure all the threads are done executing the assert before going to
// the next op.
barrier();
}
rewriter.eraseOp(op);
return success();
}
Expand Down
4 changes: 4 additions & 0 deletions python/triton/language/semantic.py
Original file line number Diff line number Diff line change
Expand Up @@ -1729,6 +1729,10 @@ def device_print(prefix: str, args: List[tl.tensor], hex: bool, builder: ir.buil
def device_assert(cond: tl.tensor, msg: str, builder: ir.builder) -> tl.tensor:
    """Emit a device-side assertion op for `cond` with message `msg`.

    Compiled out entirely (returns None) unless the builder was created with
    debug options enabled. A scalar condition is splatted into a shape-(1,)
    block tensor first, since the assert op expects a tensor condition.
    """
    # Assertions are a no-op in non-debug builds.
    if not builder.options.debug:
        return
    if not cond.type.is_block():
        # Promote the scalar condition to a 1-element block tensor.
        block_ty = tl.block_type(cond.type.scalar, (1, ))
        cond = tl.tensor(builder.create_splat(cond.handle, (1, )), block_ty)
    return tl.tensor(builder.create_assert(cond.handle, msg), tl.void)


Expand Down
2 changes: 0 additions & 2 deletions test/Conversion/tritongpu_to_llvm.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -1827,8 +1827,6 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 :
// CHECK-DAG: llvm.mlir.global internal constant @assertFunc_0("unknown\00") {addr_space = 0 : i32}
// CHECK-DAG: llvm.mlir.global internal constant @assertFile_0("inner_call\00") {addr_space = 0 : i32}
// CHECK-DAG: llvm.mlir.global internal constant @assertMessage_0("assert text\00") {addr_space = 0 : i32}
// CHECK: llvm.call @__assertfail
// CHECK: nvvm.barrier0
module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32, triton_gpu.target = "cuda:90", "triton_gpu.threads-per-warp" = 32 : i32} {
tt.func public @add_kernel(%arg0: tensor<1xi1, #blocked>) {
tt.assert %arg0, "assert text" : tensor<1xi1, #blocked> loc(#loc5)
Expand Down

0 comments on commit 555d666

Please sign in to comment.