Skip to content

Commit

Permalink
Merge commit 'c76b342a2d704b6552c1224a4e7706bb85a4b888'
Browse files Browse the repository at this point in the history
  • Loading branch information
whitneywhtsang committed Nov 19, 2024
2 parents b3ca988 + c76b342 commit d254e2b
Show file tree
Hide file tree
Showing 12 changed files with 108 additions and 31 deletions.
7 changes: 4 additions & 3 deletions include/triton/Analysis/AxisInfo.h
Original file line number Diff line number Diff line change
Expand Up @@ -27,11 +27,12 @@ class AxisInfo {
public:
// Default constructor: delegates to the three-argument constructor with empty
// contiguity, divisibility and constancy lists.
AxisInfo() : AxisInfo({}, {}, {}) {}

AxisInfo(DimVectorT contiguity, DimVectorT divisibility, DimVectorT constancy)
AxisInfo(ArrayRef<int64_t> contiguity, ArrayRef<int64_t> divisibility,
ArrayRef<int64_t> constancy)
: AxisInfo(contiguity, divisibility, constancy, std::nullopt) {}

AxisInfo(DimVectorT contiguity, DimVectorT divisibility, DimVectorT constancy,
std::optional<int64_t> constantValue)
AxisInfo(ArrayRef<int64_t> contiguity, ArrayRef<int64_t> divisibility,
ArrayRef<int64_t> constancy, std::optional<int64_t> constantValue)
: contiguity(contiguity), divisibility(divisibility),
constancy(constancy), constantValue(constantValue) {
assert(divisibility.size() == contiguity.size());
Expand Down
5 changes: 5 additions & 0 deletions include/triton/Dialect/TritonGPU/IR/TritonGPUAttrDefs.td
Original file line number Diff line number Diff line change
Expand Up @@ -778,6 +778,11 @@ def MmaEncodingTrait : AttrInterface<"MmaEncodingTrait"> {
(ins "int":$opIdx,
"int":$kWidth)>,

InterfaceMethod<"Get the order of reps (tiles of this layout that tile the whole tensor). The fastest-changing axis first",
"SmallVector<unsigned>",
"getRepOrderForOperand",
(ins "int":$opIdx)>,

InterfaceMethod<"Return element sizes per thread for dot operands.", "SmallVector<unsigned>",
"getElemsPerThreadForOperands", (ins "ArrayRef<int64_t>":$tensorShape,
"Type":$eltTy,
Expand Down
8 changes: 4 additions & 4 deletions lib/Analysis/Allocation.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -84,8 +84,8 @@ ScratchConfig getScratchConfigForCvt(RankedTensorType srcTy,

assert(cvtNeedsSharedMemory(srcTy, dstTy));

auto inOrd = gpu::getOrder(srcLayout);
auto outOrd = gpu::getOrder(dstLayout);
const auto &inOrd = gpu::getOrder(srcLayout);
const auto &outOrd = gpu::getOrder(dstLayout);
scratchConfig.order = outOrd;

unsigned srcContigPerThread =
Expand Down Expand Up @@ -303,7 +303,7 @@ class AllocationAnalysis {
/// arguments are involved.
void resolveAliasBufferLiveness(
function_ref<Interval<size_t>(Value value)> getLiveness) {
for (auto aliasBufferIter : allocation->getAliasBuffer()) {
for (const auto &aliasBufferIter : allocation->getAliasBuffer()) {
auto value = aliasBufferIter.first;
auto buffers = aliasBufferIter.second;
auto range = getLiveness(value);
Expand Down Expand Up @@ -443,7 +443,7 @@ class AllocationAnalysis {
std::find_if(xBuffers.begin(), xBuffers.end(), [&](auto *buffer) {
auto xRange = bufferRange[buffer];
bool res = xRange.intersects(range);
for (auto val : tripleMap)
for (const auto &val : tripleMap)
res = res &&
!val.second.intersects(xRange); // only one buffer intersect
return res;
Expand Down
8 changes: 5 additions & 3 deletions lib/Analysis/AxisInfo.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1084,9 +1084,11 @@ LogicalResult AxisInfoAnalysis::visitOperation(

void AxisInfoAnalysis::visitForOpInductionVar(
scf::ForOp op, ArrayRef<dataflow::Lattice<AxisInfo> *> argLattices) {
ProgramPoint programPoint(op);
auto lb = getLatticeElementFor(&programPoint, op.getLowerBound())->getValue();
auto step = getLatticeElementFor(&programPoint, op.getStep())->getValue();
ProgramPoint *programPoint = getProgramPointAfter(op);
const auto &lb =
getLatticeElementFor(programPoint, op.getLowerBound())->getValue();
const auto &step =
getLatticeElementFor(programPoint, op.getStep())->getValue();

AxisInfo::DimVectorT knownContiguity(1, 1);
AxisInfo::DimVectorT knownDivisibility(1, 1);
Expand Down
21 changes: 18 additions & 3 deletions lib/Dialect/TritonGPU/IR/Dialect.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1702,7 +1702,14 @@ AMDMfmaEncodingAttr::getInstrShapeForOperand(int kWidth, int opIdx) const {
}

// Returns the rep order of this layout: the result of getMatrixOrder for the
// layout's rank with rowMajor = true.
//
// The former unconditional "NYI" report_fatal_error is removed: it aborted on
// every call and made the implementation below unreachable.
SmallVector<unsigned> AMDMfmaEncodingAttr::getRepOrder() const {
  // The rank is the number of entries in warpsPerCTA.
  auto rank = getWarpsPerCTA().size();
  return getMatrixOrder(rank, /*rowMajor*/ true);
}

// Returns the rep order for dot operand `opIdx`: the dot-operand order for the
// layout's rank (number of warpsPerCTA entries), requested with kMajor = true.
SmallVector<unsigned>
AMDMfmaEncodingAttr::getRepOrderForOperand(int opIdx) const {
  return getOrderForDotOperand(opIdx, getWarpsPerCTA().size(),
                               /*kMajor*/ true);
}

SmallVector<int64_t>
Expand Down Expand Up @@ -1789,8 +1796,16 @@ AMDWmmaEncodingAttr::getShapePerCTATile(ArrayRef<int64_t> tensorShape) const {
return shapePerCTATile;
}
// Returns the rep order of this layout: the result of getMatrixOrder for the
// layout's rank with rowMajor = true.
//
// The former unconditional "NYI" report_fatal_error is removed: it aborted on
// every call and made the implementation below unreachable.
SmallVector<unsigned> AMDWmmaEncodingAttr::getRepOrder() const {
  // The rank is the number of entries in warpsPerCTA.
  auto rank = getWarpsPerCTA().size();
  return getMatrixOrder(rank, /*rowMajor*/ true);
}

// Returns the rep order for dot operand `opIdx`: the dot-operand order for the
// layout's rank (number of warpsPerCTA entries), requested with kMajor = true.
SmallVector<unsigned>
AMDWmmaEncodingAttr::getRepOrderForOperand(int opIdx) const {
  const auto numDims = getWarpsPerCTA().size();
  return getOrderForDotOperand(opIdx, numDims, /*kMajor*/ true);
}

// Returns a fresh SmallVector holding the CTAsPerCGA values reported by this
// encoding's CTA layout.
SmallVector<unsigned> AMDWmmaEncodingAttr::getCTAsPerCGA() const {
  const auto ctasPerCGA = getCTALayout().getCTAsPerCGA();
  return SmallVector<unsigned>(ctasPerCGA);
}
Expand Down Expand Up @@ -2060,7 +2075,7 @@ NvidiaMmaEncodingAttr::getSizePerThreadForOperand(int kWidth, int opIdx) const {
// DotOperand Encoding
//===----------------------------------------------------------------------===//
SmallVector<unsigned> DotOperandEncodingAttr::getRepOrder() const {
if (auto mma = mlir::dyn_cast<NvidiaMmaEncodingAttr>(getParent())) {
if (auto mma = mlir::dyn_cast<MmaEncodingTrait>(getParent())) {
return mma.getRepOrderForOperand(getOpIdx());
}
llvm::report_fatal_error(
Expand Down
2 changes: 2 additions & 0 deletions lib/Dialect/TritonGPU/Transforms/OptimizeAccumulatorInit.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,8 @@ std::optional<std::pair<Operation *, int>> findZeroInitOp(Value accUse,
return std::nullopt;
}
if (auto selOp = dyn_cast<arith::SelectOp>(defOp)) {
if (!selOp.getCondition().getType().isInteger(1))
return std::nullopt;
if (isConstantZeroTensor(selOp.getTrueValue()) ||
isConstantZeroTensor(selOp.getFalseValue())) {
return std::make_pair(selOp, 0);
Expand Down
27 changes: 14 additions & 13 deletions python/test/unit/language/test_core.py
Original file line number Diff line number Diff line change
Expand Up @@ -5630,7 +5630,7 @@ def matmul_kernel( #
stride_cm, stride_cn, #
BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr, BLOCK_SIZE_K: tl.constexpr, #
low_precision_acc: tl.constexpr, #
num_pipeline_stages: tl.constexpr = 3 #
num_stages: tl.constexpr = 3 #
):
pid = tl.program_id(axis=0)
num_pid_m = tl.cdiv(M, BLOCK_SIZE_M)
Expand All @@ -5642,7 +5642,7 @@ def matmul_kernel( #
a_ptrs = a_ptr + (offs_am[:, None] * stride_am + offs_k[None, :] * stride_ak)
b_ptrs = b_ptr + (offs_k[:, None] * stride_bk + offs_bn[None, :] * stride_bn)
accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
for k in tl.range(0, tl.cdiv(K, BLOCK_SIZE_K), num_stages=num_pipeline_stages):
for k in tl.range(0, tl.cdiv(K, BLOCK_SIZE_K), num_stages=num_stages):
a = tl.load(a_ptrs)
b = tl.load(b_ptrs)
accumulator = tl.dot(a, b, acc=accumulator, max_num_imprecise_acc=low_precision_acc)
Expand Down Expand Up @@ -5681,7 +5681,7 @@ def test_dot_max_num_imprecise_acc(M, N, K, BLOCK_M, BLOCK_N, BLOCK_K, in_type_s
max_num_impressive_acc = low_precision_acc if low_precision_acc <= BLOCK_K else None
h = matmul_kernel[grid](a, b, C, M, N, K, a.stride(0), a.stride(1), b.stride(0), b.stride(1), C.stride(0),
C.stride(1), BLOCK_M, BLOCK_N, BLOCK_K, max_num_impressive_acc, num_warps=num_warps,
num_pipeline_stages=num_stages)
num_stages=num_stages)
torch_a = torch.from_numpy(A).to(device=device)
th_a = f8_to_f16(torch_a, in_type_str)
torch_b = torch.from_numpy(B).to(device=device)
Expand Down Expand Up @@ -5873,7 +5873,7 @@ def test_tl_range(device):
pgm = matmul_kernel[
1,
](a, b, c, M, N, K, a.stride(0), a.stride(1), b.stride(0), b.stride(1), c.stride(0), c.stride(1), BLOCK_M, BLOCK_N,
BLOCK_K, 0, num_pipeline_stages=5)
BLOCK_K, 0, num_stages=5)
ref_out = torch.matmul(a, b).to(torch.float32)
if is_interpreter():
# GPU invokes tensor core for float16 matmul, which is not supported in interpreter.
Expand All @@ -5899,8 +5899,8 @@ def maxnreg_noinline2(X):
tl.store(X, 0)


@pytest.mark.interpreter
def test_maxnreg(device):
assert not is_interpreter(), "this test won't work with the interpreter"
if not is_cuda():
pytest.xfail('maxnreg only works on CUDA')

Expand All @@ -5914,14 +5914,15 @@ def kernel(X):
X = torch.empty(1, dtype=torch.int32, device=device)
k = kernel[(1, )](X, maxnreg=42)

# Ensure that .maxnreg is set on the kernel function (marked with .entry)
# and not on either of the noinline functions (marked with .func).
try:
assert re.search(r'\.visible \.entry [^{;]*\.maxnreg 42', k.asm["ptx"])
assert not re.search(r'\.visible \.func [^{;]*\.maxnreg', k.asm["ptx"])
except AssertionError:
print("Failing ptx:\n", k.asm["ptx"])
raise
if not is_interpreter():
# Ensure that .maxnreg is set on the kernel function (marked with .entry)
# and not on either of the noinline functions (marked with .func).
try:
assert re.search(r'\.visible \.entry [^{;]*\.maxnreg 42', k.asm["ptx"])
assert not re.search(r'\.visible \.func [^{;]*\.maxnreg', k.asm["ptx"])
except AssertionError:
print("Failing ptx:\n", k.asm["ptx"])
raise


@pytest.mark.interpreter
Expand Down
10 changes: 5 additions & 5 deletions python/triton/runtime/interpreter.py
Original file line number Diff line number Diff line change
Expand Up @@ -1034,9 +1034,6 @@ def _implicit_cvt(arg):

interpreter_builder = InterpreterBuilder()

# These keywords are not supported by the interpreter
RESERVED_KWS = ["num_warps", "num_stages", "num_ctas", "enable_fp_fusion", "grid", "maxnreg"]


class GridExecutor:

Expand Down Expand Up @@ -1077,10 +1074,13 @@ def _restore_args_dev(self, args_dev, args_hst, kwargs, kwargs_hst):
kwarg_dev.data.copy_(kwarg_hst.to(kwarg_dev.device).data)

def __call__(self, *args_dev, **kwargs):
# removes reserved keywords from kwargs
kwargs = {k: v for k, v in kwargs.items() if k not in RESERVED_KWS}
if kwargs.pop("warmup", False):
return
# Drop every keyword argument that the kernel function does not declare
# (e.g. reserved launch options it has no parameter for).
# Triton doesn't support keyword-only, variable positional, or variable keyword arguments,
# so it is safe to inspect only the positional-or-keyword arguments (i.e., argspec.args).
argspec = inspect.getfullargspec(self.fn)
kwargs = {k: v for k, v in kwargs.items() if k in argspec.args}
# copy arguments to the host
args_hst, kwargs_hst = self._init_args_hst(args_dev, kwargs)
# remaps core language functions to interpreted ones
Expand Down
16 changes: 16 additions & 0 deletions test/TritonGPU/accumulator-init.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -348,4 +348,20 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 :
}
tt.return %17 : tensor<128x16xf32, #mma1>
}

// If the condition is a tensor, skip the optimization.
// CHECK-LABEL: @negative_sel_tensor
// CHECK-NOT: %[[ACC_NEXT:.+]] = triton_nvidia_gpu.warp_group_dot {{.*}}, {{.*}}, {{.*}}, {{.*}} : !tt.memdesc
tt.func @negative_sel_tensor(%A: !tt.memdesc<128x64xf16, #shared, #triton_gpu.shared_memory>, %B: !tt.memdesc<64x16xf16, #shared1, #triton_gpu.shared_memory>, %cnd: tensor<128x16xi1, #mma1>) -> tensor<128x16xf32, #mma1> {
%c0_i32 = arith.constant 0 : i32
%cst_2 = arith.constant dense<0.000000e+00> : tensor<128x16xf32, #mma1>
%c1_i32 = arith.constant 1 : i32
%c8_i32 = arith.constant 8 : i32
%17 = scf.for %arg3 = %c0_i32 to %c8_i32 step %c1_i32 iter_args(%arg4 = %cst_2) -> (tensor<128x16xf32, #mma1>) : i32 {
%acc_ = arith.select %cnd, %cst_2, %arg4 : tensor<128x16xi1, #mma1>, tensor<128x16xf32, #mma1>
%acc = triton_nvidia_gpu.warp_group_dot %A, %B, %acc_ : !tt.memdesc<128x64xf16, #shared, #triton_gpu.shared_memory> * !tt.memdesc<64x16xf16, #shared1, #triton_gpu.shared_memory> -> tensor<128x16xf32, #mma1>
scf.yield %acc: tensor<128x16xf32, #mma1>
}
tt.return %17 : tensor<128x16xf32, #mma1>
}
}
29 changes: 29 additions & 0 deletions test/TritonGPU/coalesce.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -131,3 +131,32 @@ module attributes {"triton_gpu.num-warps" = 4 : i32, "triton_gpu.threads-per-war
tt.return
}
}

// -----

// COM: Reproducer for issue #5122
// CHECK-LABEL: @test_5122
module {
tt.func public @test_5122(%arg0: i32) attributes {noinline = false} {
%c1_i32 = arith.constant 1 : i32
%0 = arith.cmpi sgt, %arg0, %c1_i32 : i32
scf.if %0 {
%1 = scf.if %0 -> (i32) {
scf.yield %c1_i32 : i32
} else {
scf.yield %c1_i32 : i32
}
%2 = arith.cmpi sgt, %1, %c1_i32 : i32
%3 = scf.if %2 -> (i32) {
scf.yield %c1_i32 : i32
} else {
scf.yield %c1_i32 : i32
}
%4 = scf.for %arg1 = %1 to %1 step %c1_i32 iter_args(%arg2 = %3) -> (i32) : i32 {
%5 = arith.addi %arg2, %c1_i32 : i32
scf.yield %5 : i32
}
}
tt.return
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -85,6 +85,7 @@ along the row (resp. col) dimension.
SmallVector<unsigned> getSizePerThreadForOperand(int kWidth,unsigned opIdx) const;
SmallVector<unsigned> getElemsPerThreadForOperands(ArrayRef<int64_t> shape, Type eltTy, unsigned opIdx) const;
SmallVector<unsigned> getShapePerCTATileForOperand(ArrayRef<int64_t> shape, int kWidth, int opIdx) const;
SmallVector<unsigned> getRepOrderForOperand(int opIdx) const;
unsigned getTotalElemsPerThreadForOperand(ArrayRef<int64_t> shape, Type eltTy, int kWidth, int opIdx) const;

bool supportReduction() const {
Expand Down
5 changes: 5 additions & 0 deletions third_party/intel/lib/Dialect/TritonIntelGPU/IR/Dialect.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -163,6 +163,11 @@ SmallVector<unsigned> DpasEncodingAttr::getRepOrder() const {
llvm::report_fatal_error("NYI. DpasEncodingAttr::getRepOrder");
}

// Returns the rep order for dot operand `opIdx`: the dot-operand order for the
// layout's rank (number of warpsPerCTA entries), requested with kMajor = true.
SmallVector<unsigned> DpasEncodingAttr::getRepOrderForOperand(int opIdx) const {
  return getOrderForDotOperand(opIdx, getWarpsPerCTA().size(),
                               /*kMajor*/ true);
}

SmallVector<unsigned>
DpasEncodingAttr::getElemsPerThread(ArrayRef<int64_t> shape, Type eltTy) const {
size_t rank = shape.size();
Expand Down

0 comments on commit d254e2b

Please sign in to comment.