Commit 963ba2b

Merge commit '390e27f4813799c242d4e6b2f8d79eda3b51cd92'

whitneywhtsang committed Dec 5, 2024
2 parents 5c0e236 + 390e27f

Showing 7 changed files with 114 additions and 20 deletions.
6 changes: 6 additions & 0 deletions include/triton/Conversion/TritonGPUToLLVM/Utility.h
@@ -1123,6 +1123,12 @@ emitBaseIndexForLayout(Location loc, RewriterBase &rewriter,
return idx;
}

// Emit code to compute the (blockId, warpId, laneId) for the current thread.
std::tuple</*blockId=*/Value, /*warpId=*/Value, /*laneId=*/Value>
emitHardwareTuple(Location loc, RewriterBase &rewriter,
const TargetInfoBase &target, bool withCTAOffset,
unsigned threadsPerWarp);

// Emit indices calculation within each ConversionPattern, and returns a
// [elemsPerThread X rank] index matrix.
//
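The emitHardwareTuple helper declared above centralizes the threadId-to-(blockId, warpId, laneId) decomposition that several lowering patterns previously computed inline. A minimal caller sketch, assuming a ConversionPattern context where loc, rewriter, and targetInfo are in scope and the warp size is 32 (illustrative only, not part of this change):

    // Hypothetical caller inside a ConversionPattern.
    auto [blockId, warpId, laneId] =
        emitHardwareTuple(loc, rewriter, targetInfo,
                          /*withCTAOffset=*/false, /*threadsPerWarp=*/32);
    // laneId  == threadId % threadsPerWarp
    // warpId  == threadId / threadsPerWarp
    // blockId == cluster CTA id, or i32 0 when withCTAOffset is false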
35 changes: 21 additions & 14 deletions lib/Conversion/TritonGPUToLLVM/Utility.cpp
@@ -99,6 +99,20 @@ applyLinearLayout(Location loc, RewriterBase &rewriter,
return outIndices;
}

std::tuple<Value, Value, Value> emitHardwareTuple(Location loc,
RewriterBase &rewriter,
const TargetInfoBase &target,
bool withCTAOffset,
unsigned threadsPerWarpCst) {
Value threadId = getThreadId(rewriter, loc);
Value threadsPerWarp = i32_val(threadsPerWarpCst);
Value laneId = urem(threadId, threadsPerWarp);
Value warpId = udiv(threadId, threadsPerWarp);
Value blockId =
withCTAOffset ? target.getClusterCTAId(rewriter, loc) : i32_val(0);
return {blockId, warpId, laneId};
}

SmallVector<SmallVector<Value>>
emitIndices(Location loc, RewriterBase &rewriter, const TargetInfoBase &target,
Attribute layout, RankedTensorType type, bool withCTAOffset) {
@@ -116,12 +130,8 @@ emitIndices(Location loc, RewriterBase &rewriter, const TargetInfoBase &target,
StringAttr kWarp = str_attr("warp");
StringAttr kBlock = str_attr("block");

Value threadId = getThreadId(rewriter, loc);
Value threadsPerWarp = i32_val(ll->getInDimSize(kLane));
Value laneId = urem(threadId, threadsPerWarp);
Value warpId = udiv(threadId, threadsPerWarp);
Value blockId =
withCTAOffset ? target.getClusterCTAId(rewriter, loc) : i32_val(0);
auto [blockId, warpId, laneId] = emitHardwareTuple(
loc, rewriter, target, withCTAOffset, ll->getInDimSize(kLane));
unsigned rank = shape.size();
SmallVector<SmallVector<Value>> ret;
// Linear layout function is split in two parts below:
@@ -214,10 +224,9 @@ bool emitTransferBetweenRegistersAndShared(
std::min(regToSharedLayout->getNumConsecutiveInOut(),
maxVecElems.value_or(std::numeric_limits<int>::max()));

Value threadId = getThreadId(rewriter, loc);
Value threadsPerWarp = i32_val(regToSharedLayout->getInDimSize(kLane));
Value laneId = urem(threadId, threadsPerWarp);
Value warpId = udiv(threadId, threadsPerWarp);
auto [blockId, warpId, laneId] =
emitHardwareTuple(loc, rewriter, target, /*withCTAOffset=*/false,
regToSharedLayout->getInDimSize(kLane));

int numElems = regToSharedLayout->getInDimSize(kRegister);
auto vecTy = vec_ty(elemLlvmTy, vecElems);
@@ -625,10 +634,8 @@ SmallVector<Value> getMultiDimOffset(Attribute layout, Location loc,
auto instrShape = mmaLayout.getInstrShape();
SmallVector<Value> mmaColIdx(2);
SmallVector<Value> mmaRowIdx(2);
Value threadId = getThreadId(rewriter, loc);
Value warpSize = i32_val(32);
Value laneId = urem(threadId, warpSize);
Value warpId = udiv(threadId, warpSize);
auto [blockId, warpId, laneId] = emitHardwareTuple(
loc, rewriter, targetInfo, /*withCTAOffset=*/false, 32);
// TODO: fix the bug in MMAEncodingAttr document
SmallVector<Value> multiDimWarpId(2);
auto warpsPerCTA = mmaLayout.getWarpsPerCTA();
3 changes: 0 additions & 3 deletions lib/Dialect/TritonGPU/IR/LinearLayoutConversions.cpp
@@ -525,10 +525,7 @@ AMDWmmaEncodingAttr::toLinearLayout(ArrayRef<int64_t> shape) const {
std::optional<LinearLayout>
BlockedEncodingAttr::toLinearLayout(ArrayRef<int64_t> shape) const {
assert(shape.size() == getOrder().size());

int rank = shape.size();
MLIRContext *ctx = getContext();
SmallVector<StringAttr> outDimNames = standardOutDimNames(ctx, rank);

const auto &order = getOrder();
LinearLayout ctaLayout =
17 changes: 17 additions & 0 deletions test/Conversion/tritongpu_to_llvm.mlir
@@ -2028,3 +2028,20 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.shar
}

}

// -----

#linear = #ttg.linear<{register = [], lane = [[0, 1], [1, 0], [2, 0], [4, 0], [8, 0]], warp = [[0, 0], [16, 0]], block = []}>
#mma = #ttg.nvidia_mma<{versionMajor = 2, versionMinor = 0, warpsPerCTA = [2, 2], instrShape = [16, 8]}>
module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32} {

tt.func @upcast_mxfp(%arg0: tensor<32x32xi8, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 4}>>, %arg1: tensor<32x2xi8, #linear>) {
// CHECK-LABEL: upcast_mxfp
// CHECK-COUNT-4: llvm.inline_asm
// CHECK-COUNT-2: nvvm.shfl.sync
// CHECK-COUNT-32: llvm.fmul
%0 = ttg.upcast_mxfp %arg0, %arg1 fp_type = e2m1 : tensor<32x32xi8, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 4}>>, tensor<32x2xi8, #linear> -> tensor<32x64xbf16, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 8}>>
tt.return
}

}
2 changes: 1 addition & 1 deletion third_party/amd/backend/include/hsa/amd_hsa_elf.h
@@ -136,7 +136,7 @@ enum : unsigned {
EF_AMDGPU_MACH_AMDGCN_GFX942 = 0x04c,
EF_AMDGPU_MACH_AMDGCN_RESERVED_0X4D = 0x04d,
EF_AMDGPU_MACH_AMDGCN_GFX1201 = 0x04e,
EF_AMDGPU_MACH_AMDGCN_GFX950 = 0x04f,
EF_AMDGPU_MACH_AMDGCN_RESERVED_0X4F = 0x04f,
EF_AMDGPU_MACH_AMDGCN_RESERVED_0X50 = 0x050,
EF_AMDGPU_MACH_AMDGCN_GFX9_GENERIC = 0x051,
EF_AMDGPU_MACH_AMDGCN_GFX10_1_GENERIC = 0x052,
1 change: 0 additions & 1 deletion third_party/amd/lib/TritonAMDGPUToLLVM/TargetUtils.cpp
@@ -11,7 +11,6 @@ ISAFamily deduceISAFamily(llvm::StringRef arch) {

// CDNA ISA cases
switch (kind) {
case llvm::AMDGPU::GK_GFX950:
case llvm::AMDGPU::GK_GFX942:
case llvm::AMDGPU::GK_GFX941:
case llvm::AMDGPU::GK_GFX940:
@@ -6,6 +6,7 @@

#include "PatternTritonGPUOpToLLVM.h"

#include "TritonNVIDIAGPUToLLVM/PTXAsmFormat.h"
#include "mlir/IR/Value.h"
#include "mlir/IR/ValueRange.h"
#include "mlir/Transforms/DialectConversion.h"
@@ -19,6 +20,73 @@ using namespace mlir;
using namespace mlir::triton;
using namespace mlir::triton::gpu;

// Convert 8 fp4 elements packed into a 32-bit reg into 8 bf16 elements packed
// into 4 32-bit regs.
static constexpr const char *ptxAsm =
"{\n"
".reg .b32 a<14>;\n"
"and.b32 a0, $4, -2004318072;\n\t"
"shr.u32 a1, a0, 3;\n\t"
"and.b32 a2, $4, 2004318071;\n\t"
"shr.u32 a3, a2, 16;\n\t"
"shr.u32 a4, a0, 19;\n\t"
"prmt.b32 a5, -1065353216, -1065336832, a2;\n\t"
"prmt.b32 a6, -1065353216, -1065336832, a3;\n\t"
"prmt.b32 a7, 1061109504, 1077952576, a2;\n\t"
"prmt.b32 a8, 1061109504, 1077952576, a3;\n\t"
"prmt.b32 a9, 32768, 0, a1;\n\t"
"prmt.b32 a10, 32768, 0, a4;\n\t"
"or.b32 a11, a7, a9;\n\t"
"or.b32 a12, a8, a10;\n\t"
"prmt.b32 $0, a5, a11, 20800;\n\t"
"prmt.b32 $1, a5, a11, 29538;\n\t"
"prmt.b32 $2, a6, a12, 20800;\n\t"
"prmt.b32 $3, a6, a12, 29538;\n\t"
"}";

static Value createInlineAsmUpcast(Location loc, RewriterBase &rewriter,
Type retType, Value packedVec) {
PTXBuilder builder;
SmallVector<PTXBuilder::Operand *> operands;
for (int i = 0; i < 4; i++) {
operands.push_back(builder.newOperand("=r"));
}
operands.push_back(builder.newOperand(packedVec, "r"));
auto &ptxOp = *builder.create(ptxAsm);
ptxOp(operands, /*onlyAttachMLIRArgs=*/true);
Value result = builder.launch(rewriter, loc, retType, false);
return result;
}

static SmallVector<Value> convertMxfp4x2ToBf16x2PTX(RewriterBase &rewriter,
Location loc,
ArrayRef<Value> values) {
SmallVector<Value> results;
MLIRContext *ctx = rewriter.getContext();
assert(values.size() % 4 == 0);
for (int i = 0; i < values.size(); i += 4) {
Value v0 = values[i];
Value v1 = values[i + 1];
Value v2 = values[i + 2];
Value v3 = values[i + 3];
Value packedVec = undef(vec_ty(i8_ty, 4));
packedVec = insert_element(packedVec, v0, i32_val(0));
packedVec = insert_element(packedVec, v1, i32_val(1));
packedVec = insert_element(packedVec, v2, i32_val(2));
packedVec = insert_element(packedVec, v3, i32_val(3));
SmallVector<Type> rets(4, i32_ty);
Type retType = struct_ty(rets);
Value ret = createInlineAsmUpcast(loc, rewriter, retType, packedVec);
for (int i = 0; i < 4; i++) {
Value extractI32 = extract_val(ret, i);
Value vecbf16 = bitcast(extractI32, vec_ty(bf16_ty, 2));
results.push_back(extract_element(vecbf16, i32_val(0)));
results.push_back(extract_element(vecbf16, i32_val(1)));
}
}
return results;
}
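As a usage sketch (rewriter, loc, and xVals are assumed to be in scope; the real call site appears in the pattern below): every 4 input i8 values are packed into one 32-bit register, run through one inline-asm call, and expanded into 8 bf16 values, so the output vector is always twice the length of the input.

    // Illustrative only. If xVals holds 16 packed i8 values, the result holds
    // 32 bf16 values: each i8 carries two e2m1 values, and every 4 bytes go
    // through one inline-asm call that returns four 32-bit registers.
    xVals = convertMxfp4x2ToBf16x2PTX(rewriter, loc, xVals);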

namespace {
class UpcastMXFPOpPattern : public ConvertOpToLLVMPattern<UpcastMXFPOp> {
private:
@@ -53,7 +121,7 @@ class UpcastMXFPOpPattern : public ConvertOpToLLVMPattern<UpcastMXFPOp> {
cast<DotOperandEncodingAttr>(op.getType().getEncoding()).getKWidth();

if (fpType == ScaleDotElemType::E2M1)
xVals = LLVM::convertMxfp4x2ToBf16x2(rewriter, loc, xVals);
xVals = convertMxfp4x2ToBf16x2PTX(rewriter, loc, xVals);

// Each thread owns elements of 4 mxfp vectors so we need 4 scales
// Since we go from a threadShape of 8x4 to 16x2, we let c = tid / 4 * 2
