Skip to content

Commit

Permalink
Merge commit '89c0b0abdfac05804b2cbcfb393c5efdb368b70b'
Browse files Browse the repository at this point in the history
  • Loading branch information
whitneywhtsang committed Dec 9, 2024
2 parents 7d0818a + 89c0b0a commit 541ff05
Show file tree
Hide file tree
Showing 7 changed files with 116 additions and 32 deletions.
32 changes: 16 additions & 16 deletions lib/Instrumentation/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -22,21 +22,21 @@ foreach( plugin ${GPU_INSTRUMENTATION_PASSES} )
LLVMTransformUtils
"$<$<PLATFORM_ID:Darwin>:-undefined dynamic_lookup>"
)
# CMAKE_LIBRARY_OUTPUT_DIRECTORY is only set during the Python
# build. It is empty if building directly from the root
# CMakeLists.txt file. Therefore if not building from Python just
# use the default CMake shared lib path otherwise this causes a hard
# build error
if(DEFINED CMAKE_LIBRARY_OUTPUT_DIRECTORY)
set_target_properties(${plugin} PROPERTIES
LIBRARY_OUTPUT_DIRECTORY
"${CMAKE_LIBRARY_OUTPUT_DIRECTORY}/../instrumentation")
endif(DEFINED CMAKE_LIBRARY_OUTPUT_DIRECTORY)
# CMAKE_LIBRARY_OUTPUT_DIRECTORY is only set during the Python
# build. It is empty if building directly from the root
# CMakeLists.txt file. Therefore if not building from Python just
# use the default CMake shared lib path otherwise this causes a hard
# build error
if(DEFINED CMAKE_LIBRARY_OUTPUT_DIRECTORY)
set_target_properties(${plugin} PROPERTIES
LIBRARY_OUTPUT_DIRECTORY
"${CMAKE_LIBRARY_OUTPUT_DIRECTORY}/../instrumentation")
endif(DEFINED CMAKE_LIBRARY_OUTPUT_DIRECTORY)

# This is set to -fvisibility=hidden in the top level CMake file
# which causes the llvmGetPassPluginInfo symbol to be hidden and
# an "entry point not found" error. Reset it just for this target
if(NOT MSVC)
target_compile_options(${plugin} PRIVATE -fvisibility=default -fno-rtti)
endif()
# This is set to -fvisibility=hidden in the top level CMake file
# which causes the llvmGetPassPluginInfo symbol to be hidden and
# an "entry point not found" error. Reset it just for this target
if(NOT MSVC)
target_compile_options(${plugin} PRIVATE -fvisibility=default -fno-rtti)
endif()
endforeach()
2 changes: 2 additions & 0 deletions python/triton/language/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -1647,8 +1647,10 @@ def dot(input, other, acc=None, input_precision=None, allow_tf32=None, max_num_i
def dot_scaled(lhs, lhs_scale, lhs_format, rhs, rhs_scale, rhs_format, acc=None, out_dtype=float32, _builder=None):
"""
Returns the matrix product of two blocks in microscaling format.
lhs and rhs use microscaling formats described here:
https://www.opencompute.org/documents/ocp-microscaling-formats-mx-v1-0-spec-final-pdf
:param lhs: The first tensor to be multiplied.
:type lhs: 2D tensor representing fp4, fp8 or bf16 elements. Fp4 elements are packed into uint8 inputs with the first element in lower bits. Fp8 are stored as uint8 or the corresponding fp8 type.
:param lhs_scale: Scale factor for lhs tensor.
Expand Down
33 changes: 33 additions & 0 deletions test/TritonGPU/amd/amd-conditional-barrier.mlir
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
// RUN: triton-opt %s --convert-triton-amdgpu-to-llvm='arch=gfx942' | FileCheck %s

module attributes {"ttg.compute-capability" = 0 : i32, "ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 8 : i32, "ttg.threads-per-warp" = 64 : i32} {
tt.func @conditional_barrier() {
// CHECK-LABEL: llvm.func @conditional_barrier

// CHECK: %[[CMP0:.+]] = llvm.icmp "ne" %3, %1 : i32
// CHECK: %[[CMP1:.+]] = llvm.icmp "eq" %3, %1 : i32
// CHECK: llvm.cond_br %[[CMP0]], ^bb1, ^bb2
// CHECK: ^bb1:
// CHECK: rocdl.s.barrier
// CHECK: llvm.br ^bb2
// CHECK: ^bb2:
// CHECK: llvm.add
// CHECK: llvm.cond_br %[[CMP1]], ^bb3, ^bb4
// CHECK: ^bb3:
// CHECK: rocdl.s.barrier
// CHECK: llvm.br ^bb4
// CHECK: ^bb4:
// CHECK: llvm.return

%c256_i32 = arith.constant 256 : i32
%c0_i32 = arith.constant 0 : i32
%0 = rocdl.workitem.id.x : i32
%1 = arith.divsi %0, %c256_i32 : i32
%2 = arith.cmpi ne, %1, %c0_i32 : i32
%3 = arith.cmpi eq, %1, %c0_i32 : i32
amdgpu.cond_barrier %2
%4 = arith.addi %0, %c256_i32 : i32
amdgpu.cond_barrier %3
tt.return
}
}
32 changes: 16 additions & 16 deletions test/lib/Instrumentation/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -20,21 +20,21 @@ foreach( plugin ${GPU_INSTRUMENTATION_PASSES} )
LLVMCore
"$<$<PLATFORM_ID:Darwin>:-undefined dynamic_lookup>"
)
# CMAKE_LIBRARY_OUTPUT_DIRECTORY is only set during the Python
# build. It is empty if building directly from the root
# CMakeLists.txt file. Therefore if not building from Python just
# use the default CMake shared lib path otherwise this causes a hard
# build error
if(DEFINED CMAKE_LIBRARY_OUTPUT_DIRECTORY)
set_target_properties(${plugin} PROPERTIES
LIBRARY_OUTPUT_DIRECTORY
"${CMAKE_LIBRARY_OUTPUT_DIRECTORY}/../instrumentation")
endif(DEFINED CMAKE_LIBRARY_OUTPUT_DIRECTORY)
# CMAKE_LIBRARY_OUTPUT_DIRECTORY is only set during the Python
# build. It is empty if building directly from the root
# CMakeLists.txt file. Therefore if not building from Python just
# use the default CMake shared lib path otherwise this causes a hard
# build error
if(DEFINED CMAKE_LIBRARY_OUTPUT_DIRECTORY)
set_target_properties(${plugin} PROPERTIES
LIBRARY_OUTPUT_DIRECTORY
"${CMAKE_LIBRARY_OUTPUT_DIRECTORY}/../instrumentation")
endif(DEFINED CMAKE_LIBRARY_OUTPUT_DIRECTORY)

# This is set to -fvisibility=hidden in the top level CMake file
# which causes the llvmGetPassPluginInfo symbol to be hidden and
# an "entry point not found" error. Reset it just for this target
if(NOT MSVC)
target_compile_options(${plugin} PRIVATE -fvisibility=default)
endif()
# This is set to -fvisibility=hidden in the top level CMake file
# which causes the llvmGetPassPluginInfo symbol to be hidden and
# an "entry point not found" error. Reset it just for this target
if(NOT MSVC)
target_compile_options(${plugin} PRIVATE -fvisibility=default)
endif()
endforeach()
17 changes: 17 additions & 0 deletions third_party/amd/include/Dialect/TritonAMDGPU/IR/TritonAMDGPUOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -152,6 +152,23 @@ def InstructionSchedHint : TT_AMDGPU_Op<"instruction_sched_hint", []> {
let assemblyFormat = [{ attr-dict }];
}

def CondBarrierOp : TT_AMDGPU_Op<"cond_barrier">,
Arguments<(ins I1:$pred)> {
let summary = "Conditionally set barriers to synchronize partial threads in a block";

let description = [{
condBarrierOp sets barrier instruction only when the given argument is true.
This provides a way to synchronize partial threads in a block, deliberately
diverges the execution sequences. However, user should guarantee all threads
converge at the end by calling condBarrierOp(true) with the remaining threads.
Conceptually, this is similar to having an execution barrier inside an if statement.
This op allows us to avoid blocking the whole block when suitable to help scheduling.
NB. This doesn't set any memory fence.
}];

let assemblyFormat = "$pred attr-dict";
}

//
// AMD Buffer operations.
//
Expand Down
29 changes: 29 additions & 0 deletions third_party/amd/lib/TritonAMDGPUToLLVM/SPMDOpToLLVM.cpp
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
#include "Dialect/TritonAMDGPU/IR/Dialect.h"
#include "PatternTritonGPUOpToLLVM.h"
#include "Utility.h"
#include "mlir/Dialect/LLVMIR/ROCDLDialect.h"

using namespace mlir;

Expand All @@ -25,10 +27,37 @@ struct GetNumProgramsOpConversion
}
};

struct CondBarrierOpConversion
: public ConvertOpToLLVMPattern<triton::amdgpu::CondBarrierOp> {
using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern;

LogicalResult
matchAndRewrite(triton::amdgpu::CondBarrierOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
Location loc = op->getLoc();
Block *currentBlock = rewriter.getInsertionBlock();
Block *afterCondBarBlock =
rewriter.splitBlock(currentBlock, rewriter.getInsertionPoint());
Block *trueBlock = rewriter.createBlock(afterCondBarBlock);
rewriter.setInsertionPointToEnd(currentBlock);
rewriter.create<LLVM::CondBrOp>(loc, adaptor.getPred(), trueBlock,
afterCondBarBlock);

// conditional barrier
rewriter.setInsertionPointToStart(trueBlock);
rewriter.create<ROCDL::SBarrierOp>(loc);
rewriter.create<LLVM::BrOp>(loc, afterCondBarBlock);

rewriter.eraseOp(op);
return success();
}
};

} // namespace

void mlir::triton::AMD::populateSPMDOpToLLVMPattern(
LLVMTypeConverter &typeConverter, RewritePatternSet &patterns,
PatternBenefit benefit) {
patterns.add<GetNumProgramsOpConversion>(typeConverter, benefit);
patterns.add<CondBarrierOpConversion>(typeConverter, benefit);
}
Original file line number Diff line number Diff line change
Expand Up @@ -516,6 +516,9 @@ LogicalResult convertDot(const LLVMTypeConverter *typeConverter,

const auto &mmaInstructions =
isTuring ? mmaInstrPtxTuring : mmaInstrPtxAmpere;
if (mmaInstructions.find(mmaType) == mmaInstructions.end()) {
return emitError(loc, "Unsupported MMA instruction for the given mma type");
}
auto rank = dTensorTy.getRank();
auto elemsPerThread = triton::gpu::getElemsPerThread(dTensorTy);
auto batchOffset =
Expand Down

0 comments on commit 541ff05

Please sign in to comment.