Skip to content

Commit

Permalink
Merge commit 'fdac59428cd08d7d7438f330db7224de454c6d52'
Browse files Browse the repository at this point in the history
  • Loading branch information
whitneywhtsang committed Oct 8, 2024
2 parents 4c9df48 + fdac594 commit be3b9ad
Show file tree
Hide file tree
Showing 19 changed files with 161 additions and 83 deletions.
2 changes: 1 addition & 1 deletion cmake/llvm-hash.txt
Original file line number Diff line number Diff line change
@@ -1 +1 @@
df0864e761107b07e38f5503e0cbee0cebb4c5e8
61f8a7f618901797ee8663389a29722f29216a96
8 changes: 7 additions & 1 deletion include/triton/Conversion/TritonGPUToLLVM/Utility.h
Original file line number Diff line number Diff line change
Expand Up @@ -101,7 +101,7 @@ using namespace mlir::triton;
#define barrier() rewriter.create<mlir::gpu::BarrierOp>(loc)
#define undef(...) rewriter.create<LLVM::UndefOp>(loc, __VA_ARGS__)
#define null(...) rewriter.create<LLVM::ZeroOp>(loc, __VA_ARGS__)
#define call(...) rewriter.create<LLVM::CallOp>(loc, __VA_ARGS__)
#define call(...) LLVM::createLLVMCallOp(rewriter, loc, __VA_ARGS__)

// Types
#define int_ty(width) rewriter.getIntegerType(width)
Expand Down Expand Up @@ -228,6 +228,12 @@ Value createIndexConstant(OpBuilder &builder, Location loc,
Value createLLVMIntegerConstant(OpBuilder &builder, Location loc, short width,
int64_t value);

// Creates an LLVM::CallOp that calls `funcOp` with operands `args` at `loc`.
// After creation the op's properties are overwritten so that the op-bundle
// sizes are an empty i32 array and the operand segment sizes are
// {args.size(), 0}.
LLVM::CallOp createLLVMCallOp(OpBuilder &builder, Location loc,
LLVMFuncOp funcOp, ValueRange args);
// Creates an LLVM::CallIntrinsicOp with result types `types` and operands
// `args` at `loc`, and sets its intrinsic name to `intrinsic`. As with
// createLLVMCallOp, the op-bundle sizes are set to an empty i32 array and the
// operand segment sizes to {args.size(), 0}.
LLVM::CallIntrinsicOp
createLLVMIntrinsicCallOp(OpBuilder &builder, Location loc, StringRef intrinsic,
TypeRange types, ValueRange args);

// Is v an integer or floating-point scalar constant equal to 0?
bool isConstantZero(Value v);

Expand Down
4 changes: 4 additions & 0 deletions lib/Conversion/TritonGPUToLLVM/ControlFlowOpToLLVM.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -109,6 +109,10 @@ struct CallOpConversion : public ConvertOpToLLVMPattern<triton::CallOp> {
auto newCallOp = rewriter.create<LLVM::CallOp>(
callOp.getLoc(), packedResult ? TypeRange(packedResult) : TypeRange(),
promotedOperands, callOp->getAttrs());
newCallOp.getProperties().setOpBundleSizes(
rewriter.getDenseI32ArrayAttr({}));
newCallOp.getProperties().setOperandSegmentSizes(
{static_cast<int>(promotedOperands.size()), 0});
return newCallOp;
}

Expand Down
4 changes: 2 additions & 2 deletions lib/Conversion/TritonGPUToLLVM/ElementwiseOpToLLVM.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -299,7 +299,7 @@ struct MulhiUIOpConversion
LLVM::LLVMFuncOp funcOp =
appendOrGetExternFuncOp(rewriter, op, funcName, funcType);
return {
rewriter.create<LLVM::CallOp>(loc, funcOp, operands[0]).getResult()};
LLVM::createLLVMCallOp(rewriter, loc, funcOp, operands[0]).getResult()};
}

protected:
Expand Down Expand Up @@ -327,7 +327,7 @@ struct ExternElementwiseOpConversion
LLVM::LLVMFuncOp funcOp = appendOrGetExternFuncOp(
rewriter, op, funcName, funcType, op.getLibname(), op.getLibpath());
return {
rewriter.create<LLVM::CallOp>(loc, funcOp, operands[0]).getResult()};
LLVM::createLLVMCallOp(rewriter, loc, funcOp, operands[0]).getResult()};
}
};

Expand Down
21 changes: 19 additions & 2 deletions lib/Conversion/TritonGPUToLLVM/Utility.cpp
Original file line number Diff line number Diff line change
@@ -1,8 +1,7 @@
#include "triton/Conversion/TritonGPUToLLVM/Utility.h"
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
#include "mlir/Dialect/LLVMIR/NVVMDialect.h"
#include "mlir/IR/Attributes.h"
#include "triton/Conversion/TritonGPUToLLVM/TargetInfoBase.h"
#include "triton/Conversion/TritonGPUToLLVM/TypeConverter.h"
#include "triton/Dialect/TritonGPU/IR/Attributes.h"
#include "triton/Dialect/TritonGPU/IR/LinearLayoutConversions.h"
#include "llvm/ADT/STLExtras.h"
Expand Down Expand Up @@ -518,6 +517,24 @@ Value createLLVMIntegerConstant(OpBuilder &builder, Location loc, short width,
builder.getIntegerAttr(ty, value));
}

// Builds an LLVM::CallOp invoking `funcOp` with `args`, then explicitly fills
// in the properties that describe the operand layout: an empty op-bundle size
// list and operand segments of {number of call arguments, 0}.
// NOTE(review): presumably needed because the op now carries an operand-bundle
// segment that the (funcOp, args) builder leaves unset -- confirm against the
// pinned LLVM revision.
LLVM::CallOp createLLVMCallOp(OpBuilder &builder, Location loc,
                              LLVMFuncOp funcOp, ValueRange args) {
  LLVM::CallOp callOp = builder.create<LLVM::CallOp>(loc, funcOp, args);

  auto &props = callOp.getProperties();
  // No operand bundles are attached to this call.
  props.setOpBundleSizes(builder.getDenseI32ArrayAttr({}));
  // All operands belong to the first segment; the second segment is empty.
  props.setOperandSegmentSizes({static_cast<int>(args.size()), 0});

  return callOp;
}

// Builds an LLVM::CallIntrinsicOp producing `types` from `args`, then fills in
// its properties: the intrinsic name (from `intrinsic`), an empty op-bundle
// size list, and operand segments of {number of arguments, 0}.
LLVM::CallIntrinsicOp
createLLVMIntrinsicCallOp(OpBuilder &builder, Location loc, StringRef intrinsic,
                          TypeRange types, ValueRange args) {
  LLVM::CallIntrinsicOp intrinsicOp =
      builder.create<LLVM::CallIntrinsicOp>(loc, types, args);

  auto &props = intrinsicOp.getProperties();
  // The generic (types, args) builder does not set the callee name.
  props.setIntrin(builder.getStringAttr(intrinsic));
  // No operand bundles are attached to this call.
  props.setOpBundleSizes(builder.getDenseI32ArrayAttr({}));
  // All operands belong to the first segment; the second segment is empty.
  props.setOperandSegmentSizes({static_cast<int>(args.size()), 0});

  return intrinsicOp;
}

bool isConstantZero(Value v) {
if (auto constantOp = v.getDefiningOp<arith::ConstantOp>()) {
if (auto attr = dyn_cast<IntegerAttr>(constantOp.getValue())) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -336,10 +336,10 @@ module attributes {"triton_gpu.num-warps" = 4 : i32, "triton_gpu.threads-per-war
// CHECK-LABEL: llvm.func spir_kernelcc @test(
// CHECK-SAME: %[[VAL_0:.*]]: !llvm.ptr<3>,
// CHECK-SAME: %[[VAL_1:.*]]: vector<16xf32>) -> vector<16xf32>
// CHECK: %[[VAL_2:.*]] = llvm.call spir_funccc @_Z16get_sub_group_idv() {{{.*}}} : () -> i32
// CHECK: %[[VAL_3:.*]] = llvm.sext %[[VAL_2]] : i32 to i64
// CHECK: %[[VAL_4:.*]] = llvm.call spir_funccc @_Z22get_sub_group_local_idv() {{{.*}}} : () -> i32
// CHECK: %[[VAL_5:.*]] = llvm.sext %[[VAL_4]] : i32 to i64
// CHECK: %[[VAL_2:.*]] = llvm.call spir_funccc @_Z16get_sub_group_id() {{{.*}}} : () -> i32
// CHECK: %[[VAL_3:.*]] = llvm.zext %[[VAL_2]] : i32 to i64
// CHECK: %[[VAL_4:.*]] = llvm.call spir_funccc @_Z22get_sub_group_local_id() {{{.*}}} : () -> i32
// CHECK: %[[VAL_5:.*]] = llvm.zext %[[VAL_4]] : i32 to i64
// CHECK: %[[VAL_6:.*]] = llvm.mlir.constant(16 : i64) : i64
// CHECK: %[[VAL_7:.*]] = llvm.mlir.constant(256 : i64) : i64
// CHECK: %[[VAL_8:.*]] = llvm.mul %[[VAL_7]], %[[VAL_3]] : i64
Expand Down
8 changes: 4 additions & 4 deletions test/TritonIntelGPU/blockptr_load.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -70,8 +70,8 @@ module attributes {"triton_gpu.num-warps" = 8 : i32, "triton_gpu.threads-per-war
// CHECK: %[[VAL_13:.*]] = llvm.insertvalue %[[VAL_3]], %[[VAL_12]][4] : !llvm.struct<(i32, i32, i64, i64, i64, i64, ptr<1>)>
// CHECK: %[[VAL_14:.*]] = llvm.insertvalue %[[VAL_7]], %[[VAL_13]][5] : !llvm.struct<(i32, i32, i64, i64, i64, i64, ptr<1>)>
// CHECK: %[[BLOCK_POINTER:.*]] = llvm.insertvalue %[[VAL_0]], %[[VAL_14]][6] : !llvm.struct<(i32, i32, i64, i64, i64, i64, ptr<1>)>
// CHECK: %[[SUB_GROUP_ID_RAW:.*]] = llvm.call spir_funccc @_Z16get_sub_group_idv()
// CHECK: %[[SUB_GROUP_ID_EXT:.*]] = llvm.sext %[[SUB_GROUP_ID_RAW]] : i32 to i64
// CHECK: %[[SUB_GROUP_ID_RAW:.*]] = llvm.call spir_funccc @_Z16get_sub_group_id()
// CHECK: %[[SUB_GROUP_ID_EXT:.*]] = llvm.zext %[[SUB_GROUP_ID_RAW]] : i32 to i64
// CHECK: %[[SUB_GROUP_ID:.*]] = llvm.trunc %[[SUB_GROUP_ID_EXT]] : i64 to i32
// CHECK: %[[VAL_17:.*]] = llvm.mlir.constant(1 : i32) : i32
// CHECK: %[[SUB_GROUP_ID_N:.*]] = llvm.urem %[[SUB_GROUP_ID]], %[[VAL_17]] : i32
Expand Down Expand Up @@ -142,8 +142,8 @@ module attributes {"triton_gpu.num-warps" = 8 : i32, "triton_gpu.threads-per-war
// CHECK: %[[VAL_12:.*]] = llvm.insertvalue %[[VAL_3]], %[[VAL_11]][4] : !llvm.struct<(i32, i32, i64, i64, i64, i64, ptr<1>)>
// CHECK: %[[VAL_13:.*]] = llvm.insertvalue %[[VAL_6]], %[[VAL_12]][5] : !llvm.struct<(i32, i32, i64, i64, i64, i64, ptr<1>)>
// CHECK: %[[BLOCK_POINTER:.*]] = llvm.insertvalue %[[VAL_0]], %[[VAL_13]][6] : !llvm.struct<(i32, i32, i64, i64, i64, i64, ptr<1>)>
// CHECK: %[[SUB_GROUP_ID_RAW:.*]] = llvm.call spir_funccc @_Z16get_sub_group_idv()
// CHECK: %[[SUB_GROUP_ID_EXT:.*]] = llvm.sext %[[SUB_GROUP_ID_RAW]] : i32 to i64
// CHECK: %[[SUB_GROUP_ID_RAW:.*]] = llvm.call spir_funccc @_Z16get_sub_group_id()
// CHECK: %[[SUB_GROUP_ID_EXT:.*]] = llvm.zext %[[SUB_GROUP_ID_RAW]] : i32 to i64
// CHECK: %[[SUB_GROUP_ID:.*]] = llvm.trunc %[[SUB_GROUP_ID_EXT]] : i64 to i32
// CHECK: %[[VAL_16:.*]] = llvm.mlir.constant(1 : i32) : i32
// CHECK: %[[VAL_17:.*]] = llvm.urem %[[SUB_GROUP_ID]], %[[VAL_16]] : i32
Expand Down
6 changes: 3 additions & 3 deletions test/TritonIntelGPU/blockptr_store.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ module attributes {"triton_gpu.num-warps" = 8 : i32, "triton_gpu.threads-per-war
%12 = arith.truncf %11#0 : tensor<64x64xf32, #dpas> to tensor<64x64xf16, #dpas>
%13 = tt.make_tensor_ptr %arg2, [%arg3, %arg5], [%arg6, %c1_i64], [%c0_i32, %c0_i32] {order = array<i32: 1, 0>} : <tensor<64x64xf16, #dpas>>
// The next two lines are used to start checking the constants related to the BlockStore.
// CHECK-COUNT-3: llvm.call spir_funccc @_Z16get_sub_group_idv
// CHECK-COUNT-3: llvm.call spir_funccc @_Z16get_sub_group_id
// CHECK-COUNT-39: llvm.extractvalue
// Next constant must be equal to warpsPerCTA[0]
// CHECK: %[[CST_4:.*]] = llvm.mlir.constant(4 : i32) : i32
Expand Down Expand Up @@ -83,8 +83,8 @@ module attributes {"triton_gpu.num-warps" = 8 : i32, "triton_gpu.threads-per-war
// CHECK: %[[VAL_80:.*]] = llvm.insertvalue %[[CST_1]], %[[VAL_79]][5] : !llvm.struct<(i32, i32, i64, i64, i64, i64, ptr<1>)>
// CHECK: %[[BLOCK_PTR:.*]] = llvm.insertvalue %[[base]], %[[VAL_80]][6] : !llvm.struct<(i32, i32, i64, i64, i64, i64, ptr<1>)>
// CHECK: %[[CST_2:.*]] = llvm.mlir.constant(2 : i32) : i32
// CHECK: %[[SUB_GROUP_ID_RAW:.*]] = llvm.call spir_funccc @_Z16get_sub_group_idv()
// CHECK: %[[SUB_GROUP_ID_EXT:.*]] = llvm.sext %[[SUB_GROUP_ID_RAW]] : i32 to i64
// CHECK: %[[SUB_GROUP_ID_RAW:.*]] = llvm.call spir_funccc @_Z16get_sub_group_id()
// CHECK: %[[SUB_GROUP_ID_EXT:.*]] = llvm.zext %[[SUB_GROUP_ID_RAW]] : i32 to i64
// CHECK: %[[SUB_GROUP_ID:.*]] = llvm.trunc %[[SUB_GROUP_ID_EXT]] : i64 to i32
// CHECK: %[[CST_1:.*]] = llvm.mlir.constant(1 : i32) : i32
// CHECK: %[[SUB_GROUP_ID_N:.*]] = llvm.urem %[[SUB_GROUP_ID]], %[[CST_1]] : i32
Expand Down
4 changes: 2 additions & 2 deletions test/TritonIntelGPU/prefetch-to-llvm.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,8 @@ module attributes {"triton_gpu.num-warps" = 8 : i32, "triton_gpu.threads-per-war
%c1_i64 = arith.constant 1 : i64

// CHECK: %[[ROW_MAJOR_BLOCK_PTR:.*]] = llvm.insertvalue %arg0, {{.*}}[6] : !llvm.struct<(i32, i32, i64, i64, i64, i64, ptr<1>)>
// CHECK: %[[VAL_17:.*]] = llvm.call spir_funccc @_Z16get_sub_group_idv()
// CHECK: %[[VAL_18:.*]] = llvm.sext %[[VAL_17]] : i32 to i64
// CHECK: %[[VAL_17:.*]] = llvm.call spir_funccc @_Z16get_sub_group_id()
// CHECK: %[[VAL_18:.*]] = llvm.zext %[[VAL_17]] : i32 to i64
// CHECK: %[[VAL_19:.*]] = llvm.trunc %[[VAL_18]] : i64 to i32
// CHECK: %[[VAL_20:.*]] = llvm.mlir.constant(1 : i32) : i32
// CHECK: %[[VAL_21:.*]] = llvm.urem %[[VAL_19]], %[[VAL_20]] : i32
Expand Down
9 changes: 5 additions & 4 deletions third_party/amd/lib/TritonAMDGPUToLLVM/BuiltinFuncToLLVM.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
#include "triton/Conversion/TritonGPUToLLVM/Utility.h"

namespace mlir {
namespace triton {
Expand Down Expand Up @@ -187,11 +188,11 @@ class CallOpConversion : public mlir::RewritePattern {
rewriter.create<LLVM::FPToSIOp>(loc, returnType, op->getResult(0));
} else if (calleeName == "__triton_hip_fast_fdividef") {
assert(operands.size() == 2);
auto name = StringAttr::get(callOp.getContext(), "llvm.amdgcn.rcp.f32");
LLVM::FastmathFlagsAttr defaultFlags{};
auto rcpOp = rewriter.create<LLVM::CallIntrinsicOp>(
loc, returnType, name, operands[1], defaultFlags);
const char *intrinsic = "llvm.amdgcn.rcp.f32";
auto rcpOp = LLVM::createLLVMIntrinsicCallOp(rewriter, loc, intrinsic,
returnType, operands[1]);

LLVM::FastmathFlagsAttr defaultFlags{};
replacementOp = rewriter.create<LLVM::FMulOp>(
loc, returnType, operands[0], rcpOp->getResult(0), defaultFlags);
}
Expand Down
7 changes: 3 additions & 4 deletions third_party/amd/lib/TritonAMDGPUToLLVM/DotOpToLLVM/WMMA.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
#include "../PatternTritonGPUOpToLLVM.h"
#include "Utility.h"
#include "mlir/Dialect/LLVMIR/ROCDLDialect.h"
#include "triton/Conversion/TritonGPUToLLVM/Utility.h"

namespace mlir::triton::AMD {
namespace {
Expand Down Expand Up @@ -219,10 +220,8 @@ Value generateWMMAIntrinsic(ConversionPatternRewriter &rewriter, Location loc,
if (32 / dElType.getIntOrFloatBitWidth() > 1 || dElType.isInteger(32)) {
operands.push_back(int_val(1, false));
}
auto wmmaIntrinsic = rewriter.create<mlir::LLVM::CallIntrinsicOp>(
loc, TypeRange{valC.getType()}, StringAttr::get(loc.getContext(), name),
operands, defaultFlags);

auto wmmaIntrinsic = LLVM::createLLVMIntrinsicCallOp(
rewriter, loc, name, valC.getType(), operands);
return wmmaIntrinsic.getResult(0);
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1243,7 +1243,7 @@ struct ExpOpConversionApprox
LLVM::LLVMFuncOp funcOp =
appendOrGetExternFuncOp(rewriter, op, funcName, funcType);

return {rewriter.create<LLVM::CallOp>(loc, funcOp, prod).getResult()};
return {LLVM::createLLVMCallOp(rewriter, loc, funcOp, prod).getResult()};
}
};

Expand Down Expand Up @@ -1276,7 +1276,7 @@ struct Exp2OpConversion
appendOrGetExternFuncOp(rewriter, op, funcName, funcType);

return {
rewriter.create<LLVM::CallOp>(loc, funcOp, operands[0]).getResult()};
LLVM::createLLVMCallOp(rewriter, loc, funcOp, operands[0]).getResult()};
}

private:
Expand Down
20 changes: 9 additions & 11 deletions third_party/amd/lib/TritonAMDGPUToLLVM/SchedInstructions.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ void createSchedGroupBarrier(PatternRewriter &rewriter, Location loc,
InstructionKindMask maskValue, int sizeValue,
int groupIdValue) {
MLIRContext *ctx = rewriter.getContext();
auto intrinsicName = str_attr("llvm.amdgcn.sched.group.barrier");
const char *intrinsicName = "llvm.amdgcn.sched.group.barrier";

Value mask =
LLVM::createConstantI32(loc, rewriter, static_cast<int32_t>(maskValue));
Expand All @@ -47,36 +47,34 @@ void createSchedGroupBarrier(PatternRewriter &rewriter, Location loc,
Value groupId = LLVM::createConstantI32(loc, rewriter,
static_cast<int32_t>(groupIdValue));

LLVM::FastmathFlagsAttr defaultFlags{};
rewriter.create<LLVM::CallIntrinsicOp>(loc, TypeRange{}, intrinsicName,
ValueRange{mask, size, groupId},
defaultFlags);
LLVM::createLLVMIntrinsicCallOp(rewriter, loc, intrinsicName, TypeRange{},
ValueRange{mask, size, groupId});
}

// Insert intrinsic that controls the types of instructions that may be
// allowed to cross the intrinsic during instruction scheduling
Operation *createSchedBarrier(PatternRewriter &rewriter, Location loc,
int64_t maskValue) {
MLIRContext *ctx = rewriter.getContext();
auto intrinsicName = str_attr("llvm.amdgcn.sched.barrier");
const char *intrinsicName = "llvm.amdgcn.sched.barrier";
LLVM::FastmathFlagsAttr defaultFlags{};

Value mask =
LLVM::createConstantI32(loc, rewriter, static_cast<int32_t>(maskValue));
return rewriter.create<LLVM::CallIntrinsicOp>(loc, TypeRange{}, intrinsicName,
ValueRange{mask}, defaultFlags);
return LLVM::createLLVMIntrinsicCallOp(rewriter, loc, intrinsicName,
TypeRange{}, ValueRange{mask});
}

// Insert an experimental intrinsic for instruction group level parallelism.
// The intrinsic takes a value that specifies the strategy.
Operation *createIglpOpt(PatternRewriter &rewriter, Location loc, int value) {
MLIRContext *ctx = rewriter.getContext();
auto intrinsicName = str_attr("llvm.amdgcn.iglp.opt");
const char *intrinsicName = "llvm.amdgcn.iglp.opt";
LLVM::FastmathFlagsAttr defaultFlags{};
Value iglpValue =
LLVM::createConstantI32(loc, rewriter, static_cast<int32_t>(value));
return rewriter.create<LLVM::CallIntrinsicOp>(
loc, TypeRange{}, intrinsicName, ValueRange{iglpValue}, defaultFlags);
return LLVM::createLLVMIntrinsicCallOp(rewriter, loc, intrinsicName,
TypeRange{}, ValueRange{iglpValue});
}

struct InstructionSchedHintsRewriter
Expand Down
9 changes: 3 additions & 6 deletions third_party/amd/lib/TritonAMDGPUToLLVM/TargetInfo.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -69,12 +69,9 @@ Value TargetInfo::getClusterCTAId(RewriterBase &rewriter, Location loc) const {

Value TargetInfo::ballot(RewriterBase &rewriter, Location loc, Type type,
Value cmp) const {
auto stringAttr = rewriter.getStringAttr("llvm.amdgcn.ballot");
SmallVector<Value> operands = {cmp};
Value asmResult =
rewriter.create<LLVM::CallIntrinsicOp>(loc, type, stringAttr, operands)
->getResult(0);
return asmResult;
return LLVM::createLLVMIntrinsicCallOp(rewriter, loc, "llvm.amdgcn.ballot",
type, cmp)
->getResult(0);
}

void TargetInfo::storeDShared(RewriterBase &rewriter, Location loc, Value ptr,
Expand Down
10 changes: 4 additions & 6 deletions third_party/amd/lib/TritonAMDGPUToLLVM/Utility.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -231,11 +231,9 @@ Value llLoad(RewriterBase &rewriter, Location loc, Value ptr, Type elemTy,
auto funcName = mangleFunc(getLoadNameRaw(cm), funcType);
LLVM::LLVMFuncOp funcOp =
appendOrGetExternFuncOp(rewriter, parent, funcName, funcType);
auto loadVal =
rewriter
.create<LLVM::CallOp>(loc, funcOp, ValueRange({ptr, pred, falseVal}))
.getResult();
return loadVal;
return LLVM::createLLVMCallOp(rewriter, loc, funcOp,
ValueRange({ptr, pred, falseVal}))
.getResult();
}

void llStore(RewriterBase &rewriter, Location loc, Value ptr, Value val,
Expand Down Expand Up @@ -276,7 +274,7 @@ void llStore(RewriterBase &rewriter, Location loc, Value ptr, Value val,
auto funcName = mangleFunc(getStoreNameRaw(cm), funcType);
LLVM::LLVMFuncOp funcOp =
appendOrGetExternFuncOp(rewriter, parent, funcName, funcType);
rewriter.create<LLVM::CallOp>(loc, funcOp, ValueRange({ptr, val, pred}));
LLVM::createLLVMCallOp(rewriter, loc, funcOp, ValueRange({ptr, val, pred}));
}

} // namespace mlir::LLVM::AMD
Loading

0 comments on commit be3b9ad

Please sign in to comment.