Initial support for tritongpu.upcast_mxfp lowering (#2951)
This PR adds initial codegen support for the new upcast_mxfp operation.
Currently, codegen works only when the source operand has a blocked layout.
This is a temporary limitation (we will also want to support the dot layout
for that operand).

Also, a failing test in test_core.py is skipped in this PR; we will address
that problem in a separate PR.
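
As a rough sketch of how the two paths are gated (the environment-variable name and the getBoolEnv call are taken from the diff below; the wrapper function itself is illustrative and not part of this change):

#include "triton/Tools/Sys/GetEnv.hpp"

// Illustrative only: with the flag unset (the default), upcast_mxfp values may
// carry a blocked layout and the result type is inferred from that blocked
// encoding; with the flag set, the verifier keeps the upstream requirement of
// a dot-operand layout, which is the path this PR does not yet lower.
static bool upcastMxfpRequiresDotOperandLayout() {
  return mlir::triton::tools::getBoolEnv("TRITON_INTEL_UPCASTMXFP_DOTOP_ENCODING");
}

Leaving the variable unset exercises the blocked-layout codegen added here.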

---------

Co-authored-by: Tiotto, Ettore <[email protected]>
LiyangLingIntel and etiotto authored Dec 7, 2024
1 parent e538f26 commit c878427
Showing 13 changed files with 1,651 additions and 103 deletions.
3 changes: 2 additions & 1 deletion include/triton/Tools/Sys/GetEnv.hpp
@@ -38,7 +38,8 @@ inline const std::set<std::string> CACHE_INVALIDATING_ENV_VARS = {
"TRITON_INTEL_ENABLE_FIRST_LOAD_TO_SLM",
"TRITON_INTEL_ENABLE_INSTR_SCHED",
"TRITON_INTEL_ENABLE_POST_PROCESS_LLIR",
"TRITON_INTEL_REDUCE_TRANSPOSE"
"TRITON_INTEL_REDUCE_TRANSPOSE",
"TRITON_INTEL_UPCASTMXFP_DOTOP_ENCODING"
// clang-format on
};

74 changes: 46 additions & 28 deletions lib/Dialect/TritonGPU/IR/Ops.cpp
@@ -4,6 +4,7 @@
#include "triton/Dialect/Triton/IR/Utility.h"
#include "triton/Dialect/TritonGPU/IR/Attributes.h"
#include "triton/Dialect/TritonGPU/IR/Dialect.h"
#include "triton/Tools/Sys/GetEnv.hpp"

#define GET_OP_CLASSES
#include "triton/Dialect/TritonGPU/IR/Ops.cpp.inc"
@@ -50,15 +51,21 @@ LogicalResult UpcastMXFPOp::verify() {
    return success();
  }

  /// TODO: Temporarily disabled this check to allow for the blocked encoding.
  /// Enable once we have the dot op encoding UpcastMXFPOp lowering.
  auto dotEncoding = dyn_cast<DotOperandEncodingAttr>(layoutX);
  if (!dotEncoding) {
  if (mlir::triton::tools::getBoolEnv(
          "TRITON_INTEL_UPCASTMXFP_DOTOP_ENCODING") &&
      !dotEncoding) {
    return emitOpError("Expected a DotOperandEncodingAttr for values");
  }
  if (!isa<BlockedEncodingAttr, LinearEncodingAttr>(layoutScale)) {
    return emitOpError(
        "Expected a BlockOperandEncoding or LinearOperandEncoding "
        "for scales");
  }
  if (!dotEncoding)
    return success();

  if (isa<NvidiaMmaEncodingAttr>(dotEncoding.getParent())) {
    // Necessary to keep all of the scales of a given block of values in the
@@ -114,34 +121,45 @@ LogicalResult UpcastMXFPOp::inferReturnTypes(
      newShape.back() *= 2;
      retTy = RankedTensorType::get(xShape, FloatType::getBF16(ctx));
    } else {
      auto oldEncoding = cast<DotOperandEncodingAttr>(encoding);

      const int opIdx = oldEncoding.getOpIdx();
      const bool hasBatch = xShape.size() == 3;
      const int kIdx = (opIdx == 0 ? 1 : 0) + hasBatch;
      newShape[kIdx] *= 2;
      Type elemType = FloatType::getBF16(ctx);

      // Note: For Intel the dot operands layout's kWidth parameter must match
      // the parent's DPAS layout opsPerChannel so we need to materialize a new
      // DPAS layout.
      Attribute newVEncoding;
      if (auto dpasEncoding =
              dyn_cast<intel::DpasEncodingAttr>(oldEncoding.getParent())) {
        auto newDpasEncoding = intel::DpasEncodingAttr::get(
            ctx, dpasEncoding.getRepeatCount(), dpasEncoding.getSystolicDepth(),
            dpasEncoding.getExecutionSize(),
            intel::DpasEncodingAttr::getOpsPerChannel(elemType),
            dpasEncoding.getWarpsPerCTA(), dpasEncoding.getRepCluster(),
            dpasEncoding.getSubGroupSize());
        newVEncoding = DotOperandEncodingAttr::get(
            ctx, opIdx, newDpasEncoding, newDpasEncoding.getOpsPerChannel());
      } else {
        // Figure out the K dimension for the input A/B, given that the return
        // type is upcasted A/B type so we need to update the proper dim size.
        newVEncoding = DotOperandEncodingAttr::get(ctx, oldEncoding.getOpIdx(),
                                                   oldEncoding.getParent(),
                                                   oldEncoding.getKWidth() * 2);
      Attribute newVEncoding = nullptr;
      if (auto oldEncoding = dyn_cast<DotOperandEncodingAttr>(encoding)) {
        const int opIdx = oldEncoding.getOpIdx();
        const bool hasBatch = xShape.size() == 3;
        const int kIdx = (opIdx == 0 ? 1 : 0) + hasBatch;
        newShape[kIdx] *= 2;

        // Note: For Intel the dot operands layout's kWidth parameter must match
        // the parent's DPAS layout opsPerChannel so we need to materialize a
        // new DPAS layout.
        if (auto dpasEncoding =
                dyn_cast<intel::DpasEncodingAttr>(oldEncoding.getParent())) {
          auto newDpasEncoding = intel::DpasEncodingAttr::get(
              ctx, dpasEncoding.getRepeatCount(),
              dpasEncoding.getSystolicDepth(), dpasEncoding.getExecutionSize(),
              intel::DpasEncodingAttr::getOpsPerChannel(elemType),
              dpasEncoding.getWarpsPerCTA(), dpasEncoding.getRepCluster(),
              dpasEncoding.getSubGroupSize());
          newVEncoding = DotOperandEncodingAttr::get(
              ctx, opIdx, newDpasEncoding, newDpasEncoding.getOpsPerChannel());
        } else {
          // Figure out the K dimension for the input A/B, given that the return
          // type is upcasted A/B type so we need to update the proper dim size.
          newVEncoding = DotOperandEncodingAttr::get(
              ctx, oldEncoding.getOpIdx(), oldEncoding.getParent(),
              oldEncoding.getKWidth() * 2);
        }
      } else if (auto oldEncoding = dyn_cast<BlockedEncodingAttr>(encoding)) {
        // TODO: Temporary code, remove once upcast_mxfp support dot encoding.
        assert(!tools::getBoolEnv("TRITON_INTEL_UPCASTMXFP_DOTOP_ENCODING"));
        SmallVector<unsigned> sizePerThread = oldEncoding.getSizePerThread();
        int opIdx = sizePerThread.back() == 1 ? 1 : 0;
        sizePerThread[!opIdx] *= 2;
        newShape[!opIdx] *= 2;
        newVEncoding = BlockedEncodingAttr::get(
            ctx, sizePerThread, oldEncoding.getThreadsPerWarp(),
            oldEncoding.getWarpsPerCTA(), oldEncoding.getCTAOrder(),
            oldEncoding.getCTALayout());
      }
      retTy = RankedTensorType::get(newShape, elemType, newVEncoding);
    }
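To make the new blocked-encoding branch of inferReturnTypes easier to follow, here is a small standalone illustration of its shape arithmetic (the concrete shape and sizePerThread values are invented for the example; only the doubling logic mirrors the diff):

#include <array>
#include <cassert>
#include <cstdint>

int main() {
  // Hypothetical packed-fp4 operand A: logical shape 128x32 (two fp4 values
  // per byte along K), blocked layout with sizePerThread = {1, 4}.
  std::array<std::int64_t, 2> newShape = {128, 32};
  std::array<unsigned, 2> sizePerThread = {1, 4};

  // Mirrors the branch above: the operand index is inferred from which
  // dimension keeps one element per thread, and the other (K) dimension is
  // doubled because each packed fp4 byte upcasts to two bf16 values.
  int opIdx = sizePerThread.back() == 1 ? 1 : 0;
  sizePerThread[!opIdx] *= 2;
  newShape[!opIdx] *= 2;

  assert(opIdx == 0);                                 // treated as operand A
  assert(newShape[1] == 64 && sizePerThread[1] == 8); // K dimension doubled
  return 0;
}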
5 changes: 4 additions & 1 deletion python/test/unit/language/test_core.py
@@ -3441,7 +3441,10 @@ def test_scaled_dot(M, N, K, col_a, col_b, rhs_scale, normal_type, mxfp_type, nu
        if mma == 16 and K == 64:
            pytest.skip(f"K == {K} too small for mfma {mma} in scaled_dot")
    if is_xpu():
        pytest.skip("scaled_dot isn't supported on XPU")
        if M == 128 and N == 128 and K == 64 and not col_a and not col_b and rhs_scale and normal_type == "e4m3" and mxfp_type == "bf16":
            pytest.skip(
                f"FIXME: {M}x{N}x{K} col_a={col_a} col_b={col_b} rhs_scale={rhs_scale} normal_type={normal_type} mxfp_type={mxfp_type}"
            )

    @triton.jit
    def dot_scale_kernel(a_base, stride_a0, stride_a1, a_scale, b_base, stride_b0, stride_b1, b_scale, out,