Commit
Merge commit 'aa3ac0a146def686877685b4fb8897db64789c7a'
whitneywhtsang committed Jul 22, 2024
2 parents 4c37cc3 + aa3ac0a commit 0f48a0d
Showing 30 changed files with 717 additions and 55 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/integration-tests.yml
@@ -106,7 +106,7 @@ jobs:
run: |
if [ x"${{ github.repository }}" == x"triton-lang/triton" ]; then
echo '::set-output name=matrix-CUDA::[["a100-runner-set"], ["h100-runner-set"]]'
-          echo '::set-output name=matrix-HIP::[["self-hosted", "gfx90a"], ["self-hosted", "gfx942"]]'
+          echo '::set-output name=matrix-HIP::[["self-hosted", "gfx90a"]]'
echo '::set-output name=matrix-MACOS::[["macos-latest"]]'
else
echo '::set-output name=matrix-CUDA::["ubuntu-latest"]'
2 changes: 1 addition & 1 deletion .github/workflows/integration-tests.yml.in
@@ -115,7 +115,7 @@ jobs:
run: |
if [ x"${{ github.repository }}" == x"triton-lang/triton" ]; then
echo '::set-output name=matrix-CUDA::[["a100-runner-set"], ["h100-runner-set"]]'
-          echo '::set-output name=matrix-HIP::[["self-hosted", "gfx90a"], ["self-hosted", "gfx942"]]'
+          echo '::set-output name=matrix-HIP::[["self-hosted", "gfx90a"]]'
echo '::set-output name=matrix-MACOS::[["macos-latest"]]'
else
echo '::set-output name=matrix-CUDA::["ubuntu-latest"]'
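Note: the same runner-matrix change appears twice because integration-tests.yml appears to be generated from the integration-tests.yml.in template; the change drops the gfx942 self-hosted runner from the HIP matrix. As an aside, `::set-output` is deprecated in current GitHub Actions; a modern equivalent (illustrative, not part of this commit) writes to the step's output file instead:

    echo 'matrix-HIP=[["self-hosted", "gfx90a"]]' >> "$GITHUB_OUTPUT"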
4 changes: 4 additions & 0 deletions CMakeLists.txt
@@ -47,6 +47,10 @@ if(NOT WIN32)
find_library(TERMINFO_LIBRARY tinfo)
endif()

+if(TRITON_BUILD_UT)
+  include(AddTritonUnitTest)
+endif()
+
# Compiler flags
include_directories(${CMAKE_CURRENT_SOURCE_DIR}/include)
set(CMAKE_CXX_FLAGS "${CMAKE_C_FLAGS} -D__STDC_FORMAT_MACROS -fPIC -std=gnu++17")
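The AddTritonUnitTest module (new file below) is only pulled in when the unit-test option is enabled. A typical configure-and-test sequence would be (illustrative; only the TRITON_BUILD_UT flag comes from this diff):

    cmake -B build -DTRITON_BUILD_UT=ON .
    cmake --build build
    ctest --test-dir build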
1 change: 1 addition & 0 deletions bin/RegisterTritonDialects.h
@@ -68,6 +68,7 @@ inline void registerTritonDialects(mlir::DialectRegistry &registry) {
mlir::triton::registerConvertTritonAMDGPUToLLVM();
mlir::triton::registerConvertBuiltinFuncToLLVM();
mlir::triton::registerDecomposeUnsupportedAMDConversions();
+  mlir::triton::registerOptimizeAMDLDSUsage();

// TritonAMDGPUTransforms passes
mlir::registerTritonAMDGPUAccelerateMatmul();
39 changes: 39 additions & 0 deletions cmake/AddTritonUnitTest.cmake
@@ -0,0 +1,39 @@
include(${PROJECT_SOURCE_DIR}/unittest/googletest.cmake)

include(GoogleTest)
enable_testing()

function(add_triton_ut)
set(options)
set(oneValueArgs NAME)
set(multiValueArgs SRCS LIBS DEFS)
cmake_parse_arguments(_ "${options}" "${oneValueArgs}" "${multiValueArgs}" ${ARGN})

get_property(dialect_libs GLOBAL PROPERTY MLIR_DIALECT_LIBS)
get_property(conversion_libs GLOBAL PROPERTY MLIR_CONVERSION_LIBS)
get_property(triton_libs GLOBAL PROPERTY TRITON_LIBS)

add_test(NAME ${__NAME}
COMMAND ${__NAME})
add_executable(
${__NAME}
${__SRCS})
target_link_libraries(
${__NAME}
PRIVATE
GTest::gtest_main
${triton_libs}
${dialect_libs}
${conversion_libs}
gmock
${__LIBS})

target_compile_options(${__NAME} PRIVATE -fno-rtti)

target_compile_definitions(${__NAME} PRIVATE ${__DEFS})

# Without the TEST_DISCOVERY_TIMEOUT, the tests randomly time out on my mac
# laptop. I think the issue may be that the very first time you run a program
# it's a bit slow.
gtest_discover_tests(${__NAME} PROPERTIES TEST_DISCOVERY_TIMEOUT 60)
endfunction()
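
A hypothetical call site for add_triton_ut (all names invented for illustration; cmake_parse_arguments with the "_" prefix exposes the keyword arguments as __NAME, __SRCS, __LIBS, and __DEFS inside the function):

    add_triton_ut(
      NAME TestOptimizeLDSUsage          # test/executable name
      SRCS OptimizeLDSUsageTest.cpp      # gtest sources
      LIBS TritonAMDGPUToLLVM            # extra libraries beyond the MLIR/Triton defaults
      DEFS SOME_TEST_DEFINE=1            # extra compile definitions
    )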
2 changes: 1 addition & 1 deletion cmake/llvm-hash.txt
@@ -1 +1 @@
-4e0a0eae58f7a6998866719f7eb970096a2a52e9
+4713bd4ccc0c0d568f92916e7851d993291742c0
9 changes: 9 additions & 0 deletions include/triton/Analysis/Allocation.h
@@ -21,7 +21,13 @@ class AllocationAnalysis;
SmallVector<unsigned>
getScratchConfigForCvtLayout(triton::gpu::ConvertLayoutOp op, unsigned &inVec,
unsigned &outVec);
+SmallVector<unsigned> getScratchConfigForCvtLayout(RankedTensorType srcType,
+                                                   RankedTensorType dstType,
+                                                   unsigned &inVec,
+                                                   unsigned &outVec);
SmallVector<unsigned> getRepShapeForCvtLayout(triton::gpu::ConvertLayoutOp op);
+SmallVector<unsigned> getRepShapeForCvtLayout(RankedTensorType srcTy,
+                                              RankedTensorType dstTy);

} // namespace triton

@@ -135,6 +141,9 @@ class Allocation {
/// Returns the size of total shared memory allocated
size_t getSharedMemorySize() const { return sharedMemorySize; }

+  /// Returns a mapping from each operation to the list of LDS buffers live across it
+  std::map<Operation *, SmallVector<BufferId>> getLiveBuffers();
+
private:
/// A class that represents a shared memory buffer
struct BufferT {
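A sketch of how a client might use the new mapping to compute the peak live LDS footprint; the getAllocatedSize accessor is assumed from the existing Allocation API, and the surrounding pass boilerplate is omitted:

    // Sum the sizes of all buffers live across each operation, then take the
    // maximum over the function.
    size_t peakLDS = 0;
    for (auto &[op, buffers] : allocation.getLiveBuffers()) {
      size_t liveHere = 0;
      for (Allocation::BufferId id : buffers)
        liveHere += allocation.getAllocatedSize(id);
      peakLDS = std::max(peakLDS, liveHere);
    }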
41 changes: 38 additions & 3 deletions lib/Analysis/Allocation.cpp
@@ -60,6 +60,11 @@ getCvtOrder(Attribute srcLayout, Attribute dstLayout) {
SmallVector<unsigned> getRepShapeForCvtLayout(triton::gpu::ConvertLayoutOp op) {
auto srcTy = op.getSrc().getType();
auto dstTy = op.getType();
+  return getRepShapeForCvtLayout(srcTy, dstTy);
+}
+
+SmallVector<unsigned> getRepShapeForCvtLayout(RankedTensorType srcTy,
+                                              RankedTensorType dstTy) {
Attribute srcLayout = srcTy.getEncoding();
Attribute dstLayout = dstTy.getEncoding();

@@ -92,12 +97,19 @@ SmallVector<unsigned> getRepShapeForCvtLayout(triton::gpu::ConvertLayoutOp op) {
SmallVector<unsigned>
getScratchConfigForCvtLayout(triton::gpu::ConvertLayoutOp op, unsigned &inVec,
unsigned &outVec) {
-  auto repShape = getRepShapeForCvtLayout(op);
+  auto srcTy = op.getSrc().getType();
+  auto dstTy = op.getType();
+  return getScratchConfigForCvtLayout(srcTy, dstTy, inVec, outVec);
+}
+
+SmallVector<unsigned> getScratchConfigForCvtLayout(RankedTensorType srcTy,
+                                                   RankedTensorType dstTy,
+                                                   unsigned &inVec,
+                                                   unsigned &outVec) {
+  auto repShape = getRepShapeForCvtLayout(srcTy, dstTy);
if (repShape.empty())
return repShape;
auto rank = repShape.size();
-  auto srcTy = op.getSrc().getType();
-  auto dstTy = op.getType();
Attribute srcLayout = srcTy.getEncoding();
Attribute dstLayout = dstTy.getEncoding();

@@ -627,4 +639,27 @@ void Allocation::run(FuncAllocMapT &funcAllocMap) {
triton::AllocationAnalysis(getOperation(), &funcAllocMap, this);
}

+std::map<Operation *, SmallVector<Allocation::BufferId>>
+Allocation::getLiveBuffers() {
+  std::map<Operation *, SmallVector<BufferId>> liveBuffers;
+
+  Operation *rootOperation = getOperation();
+  mlir::Liveness liveness(rootOperation);
+  auto analyzeOperation = [&](Operation *op) -> void {
+    auto scratchBuffer = getBufferId(op);
+    if (scratchBuffer != InvalidBufferId)
+      liveBuffers[op].push_back(scratchBuffer);
+    for (auto result : op->getOpResults()) {
+      auto bufferId = getBufferId(result);
+      if (bufferId == Allocation::InvalidBufferId)
+        continue;
+      auto liveOperations = liveness.resolveLiveness(result);
+      for (auto depOp : liveOperations)
+        liveBuffers[depOp].push_back(bufferId);
+    }
+  };
+  rootOperation->walk(analyzeOperation);
+  return liveBuffers;
+}
+
} // namespace mlir
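
The type-based overloads let a caller estimate scratch requirements for a prospective layout conversion before materializing any op, which is what the LDS-usage pass added in this commit relies on. A minimal sketch (srcTy and dstTy are assumed RankedTensorType values carrying GPU layout encodings):

    unsigned inVec = 0, outVec = 0;
    SmallVector<unsigned> scratchShape =
        mlir::triton::getScratchConfigForCvtLayout(srcTy, dstTy, inVec, outVec);
    // An empty result indicates no scratch buffer is required
    // (see the repShape.empty() early-out above).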
1 change: 0 additions & 1 deletion python/setup.py
@@ -579,7 +579,6 @@ def get_entry_points():

def get_install_requires():
install_requires = [
"filelock",
"packaging", # used by third_party/intel/backend/compiler.py
] # yapf: disable
return install_requires
7 changes: 7 additions & 0 deletions python/src/ir.cc
@@ -455,6 +455,13 @@ void init_triton_ir(py::module &&m) {
return py::none();
return py::str(ret.getValue().str());
})
.def("get_bool_attr",
[](Operation &self, const std::string &name) -> py::object {
auto ret = self.getAttrOfType<BoolAttr>(name);
if (!ret)
return py::none();
return py::bool_(ret.getValue());
})
.def("get_flat_symbol_ref_attr",
[](Operation &self, const std::string &name) -> py::object {
auto ret = self.getAttrOfType<FlatSymbolRefAttr>(name);
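On the Python side the new binding mirrors the existing get_str_attr: it returns None when the attribute is missing. Illustrative use (the operation handle and attribute name are hypothetical):

    noinline = op.get_bool_attr("noinline")
    if noinline is None:
        pass  # attribute not present on this operation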
9 changes: 9 additions & 0 deletions python/test/unit/language/test_core.py
@@ -3825,7 +3825,16 @@ def _kernel(dst, src, CACHE: tl.constexpr):
tl.store(dst + offsets, x)

pgm = _kernel[(1, )](dst, src, CACHE=cache)
+
if not is_cuda():
+        if is_hip():
+            amdgcn = pgm.asm['amdgcn']
+            cache_modifier_str = 'nt' if 'gfx94' in get_arch() else 'glc'
+            global_load_line = [line for line in amdgcn.splitlines() if "global_load" in line]
+            if cache == '':
+                assert cache_modifier_str not in global_load_line[0]
+            if cache == '.cg':
+                assert cache_modifier_str in global_load_line[0]
return

ptx = pgm.asm['ptx']
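For reference, the matched AMDGCN line would look roughly like `global_load_dwordx4 v[0:3], v[4:5], off nt` on gfx94x targets versus `... off glc` on older targets such as gfx90a; the exact instruction and register operands here are illustrative, and only the nt/glc modifiers are asserted by the test.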
89 changes: 89 additions & 0 deletions test/TritonGPU/amd/optimize-lds-usage.mlir
@@ -0,0 +1,89 @@
// RUN: triton-opt %s -split-input-file -optimize-amd-lds-usage=target-arch=gfx90a | FileCheck %s
// RUN: triton-opt %s -split-input-file -optimize-amd-lds-usage=target-arch=gfx90a -optimize-amd-lds-usage=lds-limit=32768 | FileCheck %s --check-prefix=CHECK-32KLIMIT

// Check that the optimization detects LDS overflow and decomposes the layout conversion so the kernel fits into LDS
// CHECK-LABEL: alloc_convert_load
// CHECK-32KLIMIT-LABEL: alloc_convert_load
// CHECK: %0 = triton_gpu.local_alloc %arg0 : {{.*}}#blocked{{.*}}#shared
// CHECK: %1 = triton_gpu.convert_layout %arg1 : {{.*}}#blocked{{.*}}#blocked1
// CHECK: %2 = triton_gpu.convert_layout %1 : {{.*}}#blocked1{{.*}}#mma
// CHECK: %3 = triton_gpu.local_load %0 : {{.*}}#shared{{.*}}#triton_gpu.dot_op<{opIdx = 0, parent = #mma, kWidth = 4}>>
#blocked = #triton_gpu.blocked<{sizePerThread = [8, 1], threadsPerWarp = [16, 4], warpsPerCTA = [1, 8], order = [0, 1]}>
#mma = #triton_gpu.amd_mfma<{versionMajor = 2, versionMinor = 0, warpsPerCTA = [1, 8], instrShape = [32, 32], isTransposed = false}>
#shared = #triton_gpu.shared<{vec = 4, perPhase = 1, maxPhase = 16, order = [0, 1], hasLeadingOffset = false}>
module attributes {"triton_gpu.num-warps" = 8 : i32, "triton_gpu.threads-per-warp" = 64 : i32} {
tt.func public @alloc_convert_load(%arg0: tensor<128x128xf16, #blocked>, %arg1: tensor<128x128xf32, #blocked>) attributes {noinline = false} {
%1 = triton_gpu.local_alloc %arg0 : (tensor<128x128xf16, #blocked>) -> !tt.memdesc<128x128xf16, #shared, #triton_gpu.shared_memory>
%2 = triton_gpu.convert_layout %arg1 : tensor<128x128xf32, #blocked> -> tensor<128x128xf32, #mma>
%3 = triton_gpu.local_load %1 : !tt.memdesc<128x128xf16, #shared, #triton_gpu.shared_memory> -> tensor<128x128xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma, kWidth = 4}>>
tt.return
}
}

// -----

// Check that the optimization detects LDS overflow and decomposes the layout conversion so the kernel fits into LDS
// when the scratch buffer is relatively small
// CHECK-LABEL: alloc_convert_small_load
// CHECK-32KLIMIT-LABEL: alloc_convert_small_load
// CHECK: %0 = triton_gpu.local_alloc %arg0 : {{.*}}#blocked{{.*}}#shared
// CHECK: %1 = triton_gpu.convert_layout %arg1 : {{.*}}#blocked{{.*}}#blocked1
// CHECK: %2 = triton_gpu.convert_layout %1 : {{.*}}#blocked1{{.*}}#mma
// CHECK: %3 = triton_gpu.local_load %0 : {{.*}}#shared{{.*}}#triton_gpu.dot_op<{opIdx = 0, parent = #mma, kWidth = 4}>>
#blocked = #triton_gpu.blocked<{sizePerThread = [8, 1], threadsPerWarp = [16, 4], warpsPerCTA = [1, 8], order = [0, 1]}>
#mma = #triton_gpu.amd_mfma<{versionMajor = 2, versionMinor = 0, warpsPerCTA = [1, 8], instrShape = [32, 32], isTransposed = false}>
#shared = #triton_gpu.shared<{vec = 4, perPhase = 1, maxPhase = 16, order = [0, 1], hasLeadingOffset = false}>
module attributes {"triton_gpu.num-warps" = 8 : i32, "triton_gpu.threads-per-warp" = 64 : i32} {
tt.func public @alloc_convert_small_load(%arg0: tensor<128x128xf16, #blocked>, %arg1: tensor<128x128xf16, #blocked>) attributes {noinline = false} {
%1 = triton_gpu.local_alloc %arg0 : (tensor<128x128xf16, #blocked>) -> !tt.memdesc<128x128xf16, #shared, #triton_gpu.shared_memory>
%2 = triton_gpu.convert_layout %arg1 : tensor<128x128xf16, #blocked> -> tensor<128x128xf16, #mma>
%3 = triton_gpu.local_load %1 : !tt.memdesc<128x128xf16, #shared, #triton_gpu.shared_memory> -> tensor<128x128xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma, kWidth = 4}>>
tt.return
}
}

// -----

// Check that the optimization works with 3D tensors
// when the scratch buffer is relatively small
// CHECK-LABEL: alloc_convert_3d_load
// CHECK-32KLIMIT-LABEL: alloc_convert_3d_load
// CHECK: %0 = triton_gpu.local_alloc %arg0 : {{.*}}#blocked{{.*}}#shared
// CHECK: %1 = triton_gpu.convert_layout %arg1 : {{.*}}#blocked{{.*}}#mma
// CHECK: %2 = triton_gpu.convert_layout %1 : {{.*}}#mma{{.*}}#mma1
// CHECK: %3 = triton_gpu.local_load %0 : {{.*}}#shared{{.*}}#triton_gpu.dot_op<{opIdx = 0, parent = #mma1, kWidth = 4}>>
#blocked = #triton_gpu.blocked<{sizePerThread = [1, 8, 1], threadsPerWarp = [1, 16, 4], warpsPerCTA = [1, 1, 8], order = [0, 1, 2]}>
#mma = #triton_gpu.amd_mfma<{versionMajor = 2, versionMinor = 0, warpsPerCTA = [1, 1, 8], instrShape = [32, 32], isTransposed = false}>
#shared = #triton_gpu.shared<{vec = 4, perPhase = 1, maxPhase = 16, order = [0, 1, 2], hasLeadingOffset = false}>
module attributes {"triton_gpu.num-warps" = 8 : i32, "triton_gpu.threads-per-warp" = 64 : i32} {
tt.func public @alloc_convert_3d_load(%arg0: tensor<1x128x128xf16, #blocked>, %arg1: tensor<1x128x128xf16, #blocked>) attributes {noinline = false} {
%1 = triton_gpu.local_alloc %arg0 : (tensor<1x128x128xf16, #blocked>) -> !tt.memdesc<1x128x128xf16, #shared, #triton_gpu.shared_memory>
%2 = triton_gpu.convert_layout %arg1 : tensor<1x128x128xf16, #blocked> -> tensor<1x128x128xf16, #mma>
%3 = triton_gpu.local_load %1 : !tt.memdesc<1x128x128xf16, #shared, #triton_gpu.shared_memory> -> tensor<1x128x128xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma, kWidth = 4}>>
tt.return
}
}

// -----

// Check that the optimization triggers with a custom LDS limit and does not trigger with the default one
// CHECK-LABEL: alloc_convert_32k_limit
// CHECK: %0 = triton_gpu.local_alloc %arg0 : {{.*}}#blocked{{.*}}#shared
// CHECK: %1 = triton_gpu.convert_layout %arg1 : {{.*}}#blocked{{.*}}#mma
// CHECK: %2 = triton_gpu.local_load %0 : {{.*}}#shared{{.*}}#triton_gpu.dot_op<{opIdx = 0, parent = #mma, kWidth = 4}>>
// CHECK-32KLIMIT-LABEL: alloc_convert_32k_limit
// CHECK-32KLIMIT: %0 = triton_gpu.local_alloc %arg0 : {{.*}}#blocked{{.*}}#shared
// CHECK-32KLIMIT: %1 = triton_gpu.convert_layout %arg1 : {{.*}}#blocked{{.*}}#blocked1
// CHECK-32KLIMIT: %2 = triton_gpu.convert_layout %1 : {{.*}}#blocked1{{.*}}#mma
// CHECK-32KLIMIT: %3 = triton_gpu.local_load %0 : {{.*}}#shared{{.*}}#triton_gpu.dot_op<{opIdx = 0, parent = #mma, kWidth = 4}>>
#blocked = #triton_gpu.blocked<{sizePerThread = [4, 1], threadsPerWarp = [16, 4], warpsPerCTA = [1, 8], order = [0, 1]}>
#mma = #triton_gpu.amd_mfma<{versionMajor = 2, versionMinor = 0, warpsPerCTA = [1, 8], instrShape = [32, 32], isTransposed = false}>
#shared = #triton_gpu.shared<{vec = 4, perPhase = 1, maxPhase = 16, order = [0, 1], hasLeadingOffset = false}>
module attributes {"triton_gpu.num-warps" = 8 : i32, "triton_gpu.threads-per-warp" = 64 : i32} {
tt.func public @alloc_convert_32k_limit(%arg0: tensor<64x128xf16, #blocked>, %arg1: tensor<64x128xf16, #blocked>) attributes {noinline = false} {
%1 = triton_gpu.local_alloc %arg0 : (tensor<64x128xf16, #blocked>) -> !tt.memdesc<64x128xf16, #shared, #triton_gpu.shared_memory>
%2 = triton_gpu.convert_layout %arg1 : tensor<64x128xf16, #blocked> -> tensor<64x128xf16, #mma>
%3 = triton_gpu.local_load %1 : !tt.memdesc<64x128xf16, #shared, #triton_gpu.shared_memory> -> tensor<64x128xf16, #triton_gpu.dot_op<{opIdx = 0, kWidth = 4, parent = #mma}>>
tt.return
}
}
3 changes: 3 additions & 0 deletions third_party/amd/CMakeLists.txt
@@ -5,3 +5,6 @@ add_subdirectory(lib)
if(TRITON_BUILD_PYTHON_MODULE)
add_triton_plugin(TritonAMD ${CMAKE_CURRENT_SOURCE_DIR}/python/triton_amd.cc LINK_LIBS TritonAMDGPUToLLVM TritonAMDGPUTransforms)
endif()
+if(TRITON_BUILD_UT)
+  add_subdirectory(unittest)
+endif()
7 changes: 7 additions & 0 deletions third_party/amd/backend/compiler.py
@@ -169,6 +169,13 @@ def make_llir(src, metadata, options):
pm = ir.pass_manager(mod.context)
pm.enable_debug()
amd.passes.ttgpuir.add_decompose_unsupported_conversions(pm, options.arch)
+    # custom_lds_size is an experimental parameter that defines the amount of
+    # LDS available to one thread block, measured in bytes.
+    #
+    # If custom_lds_size is 0, the pass assumes the whole LDS is available to
+    # one thread block; the LDS size is then determined by the provided arch name.
+    custom_lds_size = 0
+    amd.passes.ttgpuir.add_optimize_lds_usage(pm, options.arch, custom_lds_size)
passes.convert.add_scf_to_cf(pm)
passes.convert.add_index_to_llvmir(pm)

7 changes: 7 additions & 0 deletions third_party/amd/include/TritonAMDGPUToLLVM/Passes.h
@@ -22,6 +22,13 @@ namespace AMD {
std::unique_ptr<OperationPass<ModuleOp>>
createDecomposeUnsupportedConversionsPass(StringRef targetArch);

+/// @brief Creates a pass that keeps LDS consumption within the specified limit.
+/// @param arch target architecture name, for example "gfx940"
+/// @param customLDSLimit LDS size available to one thread block; a zero value
+///        tells the pass that the whole LDS of the device is available
+/// @return created pass
+std::unique_ptr<OperationPass<ModuleOp>>
+createOptimizeLDSUsagePass(StringRef arch, int32_t customLDSLimit = 0);
} // namespace AMD

std::unique_ptr<OperationPass<ModuleOp>>
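Wiring the pass into a pipeline would look roughly like this (PassManager setup assumed; the arch string and the 32 KB budget are example values):

    pm.addPass(mlir::triton::AMD::createOptimizeLDSUsagePass("gfx90a"));
    // or, with an explicit LDS budget:
    pm.addPass(mlir::triton::AMD::createOptimizeLDSUsagePass("gfx90a",
                                                             /*customLDSLimit=*/32768));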
12 changes: 12 additions & 0 deletions third_party/amd/include/TritonAMDGPUToLLVM/Passes.td
@@ -13,6 +13,18 @@ def DecomposeUnsupportedAMDConversions : Pass<"decompose-unsupported-amd-convers
];
}

+def OptimizeAMDLDSUsage : Pass<"optimize-amd-lds-usage", "mlir::ModuleOp"> {
+  let summary = "Minimize LDS usage";
+  let constructor = "mlir::triton::AMD::createOptimizeLDSUsagePass(\"\")";
+
+  let options = [
+    Option<"targetArch", "target-arch", "std::string", /*default*/"",
+           "gfx target device architecture, e.g., gfx942">,
+    Option<"customLDSLimit", "lds-limit", "int", /*default*/"0",
+           "custom limit on LDS consumption; if not provided, the maximum LDS size is used">,
+  ];
+}
+
def ConvertTritonAMDGPUToLLVM : Pass<"convert-triton-amdgpu-to-llvm", "mlir::ModuleOp"> {
let summary = "Convert TritonGPU to LLVM";
let constructor = "mlir::triton::createConvertTritonAMDGPUToLLVMPass(\"\", /*ftz=*/true)";
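These options surface on the triton-opt command line exactly as exercised by the RUN lines of the new test above:

    triton-opt in.mlir -split-input-file -optimize-amd-lds-usage=target-arch=gfx90a
    triton-opt in.mlir -split-input-file -optimize-amd-lds-usage=target-arch=gfx90a -optimize-amd-lds-usage=lds-limit=32768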
19 changes: 15 additions & 4 deletions third_party/amd/lib/TritonAMDGPUToLLVM/BuiltinFuncToLLVM.cpp
@@ -25,8 +25,10 @@
matchAndRewrite(mlir::Operation *op,
mlir::PatternRewriter &rewriter) const override {
auto callOp = cast<LLVM::CallOp>(op);
-    if (isPredicatedLoad(callOp)) {
-      return convertPredicatedLoad(callOp, rewriter);
+    if (isPredicatedLoadNT(callOp)) {
+      return convertPredicatedLoad(callOp, rewriter, /*nt=*/true);
+    } else if (isPredicatedLoad(callOp)) {
+      return convertPredicatedLoad(callOp, rewriter, /*nt=*/false);
} else if (isPredicatedStore(callOp)) {
return convertPredicatedStore(callOp, rewriter);
} else if (isWrappedLLVMIntrinsic(callOp)) {
@@ -42,6 +44,11 @@ class CallOpConversion : public mlir::RewritePattern {
llvm::StringRef::npos;
}

+  bool isPredicatedLoadNT(LLVM::CallOp callOp) const {
+    return callOp.getCallee().value().find(
+               mlir::LLVM::AMD::Predicated_Load_NT) != llvm::StringRef::npos;
+  }
+
bool isPredicatedStore(LLVM::CallOp callOp) const {
return callOp.getCallee().value().find(mlir::LLVM::AMD::Predicated_Store) !=
llvm::StringRef::npos;
@@ -80,7 +87,8 @@
}

LogicalResult convertPredicatedLoad(LLVM::CallOp callOp,
-                                      mlir::PatternRewriter &rewriter) const {
+                                      mlir::PatternRewriter &rewriter,
+                                      bool nt) const {
auto operands = callOp.getOperands();
auto result = callOp.getResult();

@@ -100,7 +108,10 @@
rewriter.setInsertionPointToEnd(currentBlock);
rewriter.create<LLVM::CondBrOp>(loc, pred, trueBlock, falseBlock);
rewriter.setInsertionPointToStart(trueBlock);
-    auto loadOp = rewriter.create<LLVM::LoadOp>(loc, elemTy, ptr);
+    auto loadOp = nt ? rewriter.create<LLVM::LoadOp>(
+                           loc, elemTy, ptr, /*alignment=*/0,
+                           /*isVolatile=*/false, /*isNonTemporal=*/true)
+                     : rewriter.create<LLVM::LoadOp>(loc, elemTy, ptr);
rewriter.create<LLVM::BrOp>(loc, loadOp->getResult(0), afterLoad);
rewriter.setInsertionPointToStart(falseBlock);
rewriter.create<LLVM::BrOp>(loc, falseVal, afterLoad);
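Setting isNonTemporal on the LoadOp builder attaches LLVM's !nontemporal metadata to the lowered load, which is what ultimately selects the nt cache modifier checked in the test_core.py change above. The emitted LLVM IR would look roughly like this (types and address space are illustrative):

    %v = load float, ptr addrspace(1) %p, align 4, !nontemporal !0
    !0 = !{i32 1}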
[diff truncated: remaining changed files not shown]
