diff --git a/bin/triton-tensor-layout.cpp b/bin/triton-tensor-layout.cpp index 49a07681ed..033f5e6880 100644 --- a/bin/triton-tensor-layout.cpp +++ b/bin/triton-tensor-layout.cpp @@ -80,9 +80,18 @@ static cl::opt TensorStr( //===--------------------------------------------------------------------===// LogicalResult layoutPrint(RankedTensorType tensorType, raw_ostream &os) { - // Dispatch to the corresponding dialect helper function to print the layout. - os << triton::gpu::getLayoutStr(tensorType, UseHWPointOfView); - return success(); + // DistributedEncodingTrait and SharedEncodingAttr implements the + // toLinearLayout interface. + mlir::Attribute layout = tensorType.getEncoding(); + if (isa(layout)) { + os << triton::gpu::getLayoutStr(tensorType, UseHWPointOfView); + return success(); + } + + llvm::errs() << "Unsupported tensor layout attribute: " + << tensorType.getEncoding() << "\n"; + return failure(); } LogicalResult printLayoutFromFile(MLIRContext *context, StringRef filename, diff --git a/include/triton/Dialect/TritonGPU/Transforms/Utility.h b/include/triton/Dialect/TritonGPU/Transforms/Utility.h index 726c2bc588..10749a0574 100644 --- a/include/triton/Dialect/TritonGPU/Transforms/Utility.h +++ b/include/triton/Dialect/TritonGPU/Transforms/Utility.h @@ -161,9 +161,11 @@ Operation *cloneWithInferType(mlir::OpBuilder &rewriter, Operation *op, // Get backward slice of tensor values starting from the root node along with // encoding propagation. LogicalResult getConvertBackwardSlice( - Value root, SetVector &slice, Attribute rootEncoding, + OpOperand &root, SetVector &slice, Attribute rootEncoding, DenseMap &layout, - std::function stopPropagation = nullptr); + std::function stopPropagation = nullptr, + std::function getExistingConversion = + nullptr); // Populate pattern to remove dead cycles in ForOp. void populateForOpDeadArgumentElimination(RewritePatternSet &patterns); diff --git a/lib/Dialect/TritonGPU/Transforms/RemoveLayoutConversions.cpp b/lib/Dialect/TritonGPU/Transforms/RemoveLayoutConversions.cpp index eefe212d73..a56b6f7977 100644 --- a/lib/Dialect/TritonGPU/Transforms/RemoveLayoutConversions.cpp +++ b/lib/Dialect/TritonGPU/Transforms/RemoveLayoutConversions.cpp @@ -3,6 +3,7 @@ #include "mlir/Analysis/SliceAnalysis.h" #include "mlir/Dialect/SCF/IR/SCF.h" #include "mlir/IR/BuiltinAttributes.h" +#include "mlir/IR/Dominance.h" #include "mlir/IR/IRMapping.h" #include "mlir/IR/Matchers.h" #include "mlir/IR/PatternMatch.h" @@ -116,17 +117,15 @@ class LayoutPropagation { class LayoutRematerialization { public: LayoutRematerialization(FuncOp F) : funcOp(F) {} + // Map the original value to the remat'ed one. void addRematValue(Value old, Attribute encoding, Value newV); - bool hasRematValue(Value value, Attribute encoding) { - return rematMapping.contains({value, encoding}); - } - // Return the remat'ed value in the given encoding. - Value getRematValue(Value value, Attribute encoding) { - auto it = rematMapping.find({value, encoding}); - assert(it != rematMapping.end()); - return it->second; + // Get the remat'ed value in the given encoding, if one already exists and + // is different then the layout conversion root. 
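+  // Returns a null Value when no remat'ed value has been recorded for this
+  // (value, encoding) pair.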
+ Value getRematValue(Value value, Attribute encoding) const { + return rematMapping.lookup({value, encoding}); } + void cleanup(); void backwardRematerialization(); void backwardRematerialization(ConvertLayoutOp convertOp); @@ -137,6 +136,11 @@ class LayoutRematerialization { void rewriteSlice(SetVector &slice, DenseMap &layout, ConvertLayoutOp convertOp); + LogicalResult getRematerializableSlice( + OpOperand &root, Attribute rootEncoding, SetVector &slice, + DenseMap &layout, + std::function stopPropagation = nullptr); + private: void updateRematMapping(SmallVector> &values); // Existing tuples of (value, layout) that needs to be updated when recreating @@ -148,6 +152,7 @@ class LayoutRematerialization { // DenseMap, Operation*> SetVector opToDelete; FuncOp funcOp; + DominanceInfo domInfo; }; void LayoutRematerialization::addRematValue(Value old, Attribute encoding, @@ -778,8 +783,8 @@ void LayoutRematerialization::rewriteSlice(SetVector &slice, auto layoutIt = layout.find(v); assert(layoutIt != layout.end()); // If we already have a remat value for this value, use it. - if (hasRematValue(v, layoutIt->second)) { - mapping.map(v, getRematValue(v, layoutIt->second)); + if (Value remat = getRematValue(v, layoutIt->second)) { + mapping.map(v, remat); valuesWithExistingRemat.insert(v); continue; } @@ -940,12 +945,36 @@ void LayoutRematerialization::rewriteSlice(SetVector &slice, rewriteSlice(slice, layout, convertOp, mapping); } -LogicalResult getRematerializableSlice( - Value root, Attribute rootEncoding, SetVector &slice, +LogicalResult LayoutRematerialization::getRematerializableSlice( + OpOperand &root, Attribute rootEncoding, SetVector &slice, DenseMap &layout, - std::function stopPropagation = nullptr) { - LogicalResult result = getConvertBackwardSlice(root, slice, rootEncoding, - layout, stopPropagation); + std::function stopPropagation) { + // Allow re-using existing conversions for a value. Check dominance of any + // re-usable materializations against the root value. This is sufficient + // because the conversions are processed in post-order. + auto getExistingConversion = [&](OpOperand &value, Attribute encoding) { + Value remat = getRematValue(value.get(), encoding); + if (!remat) + return Value(); + // `value` can be replaced with an existing rematerialization if it + // dominates the current use of value. + Operation *user = value.getOwner(); + if (domInfo.properlyDominates(remat, user)) { + return remat; + } + // Alternatively, if the current use can be sunk below the existing + // rematerialization, then it is okay to use as well. E.g. the current use + // is a conversion that will be folded away when its result is + // rematerialized. + if (isa(user) && remat.getDefiningOp() && + domInfo.properlyDominates(user, remat.getDefiningOp())) { + return remat; + } + return Value(); + }; + LogicalResult result = + getConvertBackwardSlice(root, slice, rootEncoding, layout, + stopPropagation, getExistingConversion); if (result.failed() || slice.empty()) return failure(); @@ -966,6 +995,12 @@ void LayoutRematerialization::backwardRematerialization() { [&](ConvertLayoutOp convertOp) { convertOps.push_back(convertOp); }); for (ConvertLayoutOp convertOp : convertOps) { backwardRematerialization(convertOp); + if (!opToDelete.contains(convertOp)) { + // If the conversion didn't get removed, consider it for re-use in future + // backward slices. 
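+      // Registering (source, encoding) -> result here lets later conversions
+      // of the same source reuse this convert's result, subject to the
+      // dominance checks performed when building backward slices.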
+ addRematValue(convertOp.getSrc(), convertOp.getType().getEncoding(), + convertOp.getResult()); + } } } @@ -976,6 +1011,12 @@ void LayoutRematerialization::hoistConvertOnTopOfExtOrBroadcast() { [&](ConvertLayoutOp convertOp) { convertOps.push_back(convertOp); }); for (ConvertLayoutOp convertOp : convertOps) { hoistConvertOnTopOfExtOrBroadcast(convertOp); + if (!opToDelete.contains(convertOp)) { + // If the conversion didn't get removed, consider it for re-use in future + // backward slices. + addRematValue(convertOp.getSrc(), convertOp.getType().getEncoding(), + convertOp.getResult()); + } } } @@ -988,14 +1029,14 @@ void LayoutRematerialization::backwardRematerialization( // careful with the heuristics for both correctness and perf if (isa(targetType.getEncoding())) return; - Value oldV = convertOp->getOperand(0); + Value oldV = convertOp.getSrc(); LDBG("check backward remat with source " << oldV << " encoding " << targetType.getEncoding()); // Check to see if there are existing remat'ed values for the pair of oldValue - // and encoding. - if (hasRematValue(oldV, targetType.getEncoding())) { + // and encoding. Make sure it dominates the current conversion. + Value newV = getRematValue(oldV, targetType.getEncoding()); + if (newV && domInfo.properlyDominates(newV, convertOp)) { // Replace it with the remat'ed value. - Value newV = getRematValue(oldV, targetType.getEncoding()); convertOp.replaceAllUsesWith(newV); opToDelete.insert(convertOp); LDBG("found remat'ed value" << newV); @@ -1007,7 +1048,7 @@ void LayoutRematerialization::backwardRematerialization( SetVector slice; DenseMap layout; LogicalResult result = getRematerializableSlice( - convertOp.getSrc(), targetType.getEncoding(), slice, layout); + convertOp.getSrcMutable(), targetType.getEncoding(), slice, layout); if (result.failed()) { LDBG(" getRematerializableSlice failed"); return; @@ -1050,9 +1091,9 @@ void LayoutRematerialization::hoistConvertOnTopOfExtOrBroadcast( // 1. Take a backward slice of all the tensor dependencies. SetVector slice; DenseMap layout; - LogicalResult result = - getRematerializableSlice(convertOp.getSrc(), targetType.getEncoding(), - slice, layout, isExtOrBroadcastOp); + LogicalResult result = getRematerializableSlice( + convertOp.getSrcMutable(), targetType.getEncoding(), slice, layout, + isExtOrBroadcastOp); if (result.failed()) return; @@ -1070,7 +1111,7 @@ void LayoutRematerialization::hoistConvertOnTopOfExtOrBroadcast( if (!srcEncoding) return; LogicalResult result = getRematerializableSlice( - op->getOperand(0), srcEncoding, tempSlice, tempLayout); + op->getOpOperand(0), srcEncoding, tempSlice, tempLayout); // If we can rematerialize the rest of the ext slice we can ignore this // ext as it won't need a convert. 
if (result.succeeded()) { diff --git a/lib/Dialect/TritonGPU/Transforms/Utility.cpp b/lib/Dialect/TritonGPU/Transforms/Utility.cpp index 799d5f5c91..20ac0954ad 100644 --- a/lib/Dialect/TritonGPU/Transforms/Utility.cpp +++ b/lib/Dialect/TritonGPU/Transforms/Utility.cpp @@ -770,16 +770,16 @@ static bool isFreeConvert(Operation *op) { convertOp.getType()); } -LogicalResult -getConvertBackwardSlice(Value root, SetVector &slice, - Attribute rootEncoding, - DenseMap &layout, - std::function stopPropagation) { - DenseSet> seen; - SmallVector> queue; - - auto enqueue = [&](Value operand, Attribute encoding) { - auto x = std::make_pair(operand, encoding); +LogicalResult getConvertBackwardSlice( + OpOperand &root, SetVector &slice, Attribute rootEncoding, + DenseMap &layout, + std::function stopPropagation, + std::function getExistingConversion) { + DenseSet> seen; + SmallVector> queue; + + auto enqueue = [&](OpOperand &operand, Attribute encoding) { + auto x = std::make_pair(&operand, encoding); if (!seen.insert(x).second) { return; // Already enqueued, skip } @@ -787,8 +787,20 @@ getConvertBackwardSlice(Value root, SetVector &slice, }; enqueue(root, rootEncoding); + auto updateLayout = [&](Value value, Attribute encoding) { + assert((isa(value.getType()))); + slice.insert(value); + if (layout.find(value) != layout.end()) { + if (layout[value] != encoding) + return failure(); + } + layout[value] = encoding; + return success(); + }; + while (!queue.empty()) { - auto [currentValue, encoding] = queue.back(); + auto [currentValueUse, encoding] = queue.back(); + Value currentValue = currentValueUse->get(); queue.pop_back(); if (!isa(currentValue.getType())) continue; @@ -796,18 +808,22 @@ getConvertBackwardSlice(Value root, SetVector &slice, // TODO: enable this based on needs. 
if (currentValue.getDefiningOp()) return failure(); - slice.insert(currentValue); - if (layout.find(currentValue) != layout.end()) { - if (layout[currentValue] != encoding) + if (failed(updateLayout(currentValue, encoding))) + return failure(); + + Value existing; + if (getExistingConversion && + (existing = getExistingConversion(*currentValueUse, encoding))) { + if (failed(updateLayout(existing, encoding))) return failure(); + currentValue = existing; } - layout[currentValue] = encoding; if (auto ifOp = currentValue.getDefiningOp()) { unsigned argIdx = mlir::cast(currentValue).getResultNumber(); - auto thenValue = ifOp.thenYield().getOperand(argIdx); - auto elseValue = ifOp.elseYield().getOperand(argIdx); + OpOperand &thenValue = ifOp.thenYield()->getOpOperand(argIdx); + OpOperand &elseValue = ifOp.elseYield()->getOpOperand(argIdx); enqueue(thenValue, encoding); enqueue(elseValue, encoding); @@ -819,10 +835,11 @@ getConvertBackwardSlice(Value root, SetVector &slice, for (Value result : definingOp->getResults()) { if (result == currentValue || !isa(result.getType())) continue; - enqueue(result, encoding); + if (failed(updateLayout(result, encoding))) + return failure(); } if (isFreeConvert(definingOp)) { - enqueue(definingOp->getOperand(0), encoding); + enqueue(definingOp->getOpOperand(0), encoding); continue; } if (canFoldIntoConversion(definingOp, encoding)) @@ -837,10 +854,10 @@ getConvertBackwardSlice(Value root, SetVector &slice, auto srcEncoding = inferSrcEncoding(gather, encoding); if (!srcEncoding) return failure(); - enqueue(gather.getIndices(), srcEncoding); + enqueue(gather.getIndicesMutable(), srcEncoding); continue; } - for (auto [i, operand] : llvm::enumerate(definingOp->getOperands())) { + for (auto [i, operand] : llvm::enumerate(definingOp->getOpOperands())) { auto srcEncoding = inferSrcEncoding(definingOp, encoding); if (!srcEncoding) return failure(); @@ -853,9 +870,9 @@ getConvertBackwardSlice(Value root, SetVector &slice, Operation *parentOp = block->getParentOp(); if (auto forOp = dyn_cast(parentOp)) { OpOperand *initOperand = forOp.getTiedLoopInit(blockArg); - Value yieldOperand = forOp.getBody()->getTerminator()->getOperand( + OpOperand &yieldOperand = forOp.getBody()->getTerminator()->getOpOperand( blockArg.getArgNumber() - forOp.getNumInductionVars()); - enqueue(initOperand->get(), encoding); + enqueue(*initOperand, encoding); enqueue(yieldOperand, encoding); continue; } diff --git a/python/triton/compiler/code_generator.py b/python/triton/compiler/code_generator.py index 1c39d778ec..e42f9f44e4 100644 --- a/python/triton/compiler/code_generator.py +++ b/python/triton/compiler/code_generator.py @@ -4,17 +4,19 @@ import warnings import os import textwrap +from types import ModuleType from typing import Any, Callable, Dict, Optional, Tuple, Type, Union + from .. 
import language from .._C.libtriton import ir -from ..language import constexpr, tensor, str_to_ty +from ..language import constexpr, semantic, str_to_ty, tensor from ..language.core import _unwrap_if_constexpr, nv_tma_desc_type, _value from ..runtime.jit import _normalize_ty, get_jit_fn_file_line # ideally we wouldn't need any runtime component from ..runtime import JITFunction +from .._utils import list_list_flatten, list_list_unflatten + from .errors import (CompilationError, CompileTimeAssertionFailure, UnsupportedLanguageConstruct) -from types import ModuleType -from triton._utils import list_list_flatten, list_list_unflatten def mangle_ty(ty): @@ -364,12 +366,12 @@ def visit_Return(self, node): self.builder.ret([]) ret_ty = language.void elif isinstance(ret_value, tuple): - ret_values = [language.semantic.to_tensor(v, self.builder) for v in ret_value] + ret_values = [semantic.to_tensor(v, self.builder) for v in ret_value] ret_types = [v.type for v in ret_values] self.builder.ret([v.handle for v in ret_values]) ret_ty = tuple(ret_types) else: - ret = language.semantic.to_tensor(ret_value, self.builder) + ret = semantic.to_tensor(ret_value, self.builder) self.builder.ret([ret.handle]) ret_ty = ret.type @@ -507,7 +509,7 @@ def visit_Assign(self, node): if value is not None and \ not _is_triton_value(value) and \ not isinstance(value, native_nontensor_types): - value = language.semantic.to_tensor(value, self.builder) + value = semantic.to_tensor(value, self.builder) self.set_value(name, value) def visit_AugAssign(self, node): @@ -728,14 +730,14 @@ def visit_IfExp(self, node): then_block = self.builder.create_block() self.builder.set_insertion_point_to_start(then_block) - then_val = language.semantic.to_tensor(self.visit(node.body), self.builder) + then_val = semantic.to_tensor(self.visit(node.body), self.builder) then_block = self.builder.get_insertion_block() else_block = self.builder.create_block() self.builder.set_insertion_point_to_start(else_block) # do not need to reset lscope since # ternary expressions cannot define new variables - else_val = language.semantic.to_tensor(self.visit(node.orelse), self.builder) + else_val = semantic.to_tensor(self.visit(node.orelse), self.builder) else_block = self.builder.get_insertion_block() self._set_insertion_point_and_loc(ip, last_loc) @@ -954,14 +956,14 @@ def visit_For(self, node): step = constexpr(-step.value) negative_step = True lb, ub = ub, lb - lb = language.semantic.to_tensor(lb, self.builder) - ub = language.semantic.to_tensor(ub, self.builder) - step = language.semantic.to_tensor(step, self.builder) + lb = semantic.to_tensor(lb, self.builder) + ub = semantic.to_tensor(ub, self.builder) + step = semantic.to_tensor(step, self.builder) # induction variable type if not lb.dtype.is_int() or not ub.dtype.is_int() or not step.dtype.is_int(): raise TypeError(f"For loop bounds and step must all be ints, are ({lb.dtype}, {ub.dtype}, {step.dtype})") - iv_type = language.semantic.integer_promote_impl(lb.dtype, ub.dtype) - iv_type = language.semantic.integer_promote_impl(iv_type, step.dtype) + iv_type = semantic.integer_promote_impl(lb.dtype, ub.dtype) + iv_type = semantic.integer_promote_impl(iv_type, step.dtype) iv_ir_type = iv_type.to_ir(self.builder) iv_is_signed = iv_type.int_signedness == language.core.dtype.SIGNEDNESS.SIGNED # lb/ub/step might be constexpr, we need to cast them to tensor @@ -1032,7 +1034,7 @@ def visit_For(self, node): if name in liveins: local = self.local_defs[name] if isinstance(local, constexpr): - local = 
language.semantic.to_tensor(local, self.builder) + local = semantic.to_tensor(local, self.builder) yields.append(local) # create YieldOp @@ -1178,7 +1180,7 @@ def visit_BoolOp(self, node: ast.BoolOp): def visit_Attribute(self, node): lhs = self.visit(node.value) if _is_triton_tensor(lhs) and node.attr == "T": - return language.semantic.permute(lhs, (1, 0), builder=self.builder) + return semantic.permute(lhs, (1, 0), builder=self.builder) return getattr(lhs, node.attr) def visit_Expr(self, node): diff --git a/test/CMakeLists.txt b/test/CMakeLists.txt index 4462966d08..cf3406a890 100644 --- a/test/CMakeLists.txt +++ b/test/CMakeLists.txt @@ -15,6 +15,7 @@ set(TRITON_TEST_DEPENDS triton-opt triton-tensor-layout triton-translate + triton-llvm-opt ) set(FILECHECK_PATH "${LLVM_LIBRARY_DIR}/../bin/FileCheck") diff --git a/test/TritonGPU/amd/amd-convert-buffer-ops.mlir b/test/TritonGPU/amd/amd-convert-buffer-ops.mlir index 18922b15aa..26ddbad4e0 100644 --- a/test/TritonGPU/amd/amd-convert-buffer-ops.mlir +++ b/test/TritonGPU/amd/amd-convert-buffer-ops.mlir @@ -122,3 +122,365 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32} { tt.return %10 : tensor<1024xf32, #blocked> } } + +// ----- + +#blocked = #ttg.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [4], order = [0]}> +module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32} { + // CHECK-LABEL: assume_eq_non_neg + tt.func @assume_eq_non_neg(%arg0: !tt.ptr {tt.divisibility = 16 : i32, tt.pointer_range = 32 : i32}, %arg1: !tt.ptr {tt.divisibility = 16 : i32, tt.pointer_range = 32 : i32}, %arg2: i32) { + %c10_i32 = arith.constant 10 : i32 + %0 = arith.cmpi eq, %arg2, %c10_i32 : i32 + llvm.intr.assume %0 : i1 + // CHECK: %[[range:.*]] = tt.make_range + %1 = tt.make_range {end = 16 : i32, start = 0 : i32} : tensor<16xi32, #blocked> + // CHECK: %[[ptr:.*]] = tt.addptr %arg0, %arg2 + %2 = tt.addptr %arg0, %arg2: !tt.ptr, i32 + %3 = tt.splat %2 : !tt.ptr -> tensor<16x!tt.ptr, #blocked> + %4 = tt.addptr %3, %1 : tensor<16x!tt.ptr, #blocked>, tensor<16xi32, #blocked> + %5 = tt.splat %arg1 : !tt.ptr -> tensor<16x!tt.ptr, #blocked> + %6 = tt.addptr %5, %1 : tensor<16x!tt.ptr, #blocked>, tensor<16xi32, #blocked> + // CHECK: %[[loaded:.*]] = amdgpu.buffer_load %arg1[%1] + %7 = tt.load %6 : tensor<16x!tt.ptr, #blocked> + // CHECK: amdgpu.buffer_store %[[loaded]], %[[ptr]][%[[range]]] + tt.store %4, %7 : tensor<16x!tt.ptr, #blocked> + tt.return + } +} + +// ----- + +#blocked = #ttg.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [4], order = [0]}> +module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32} { + // CHECK-LABEL: assume_nonneg_less + tt.func @assume_nonneg_less(%arg0: !tt.ptr {tt.divisibility = 16 : i32, tt.pointer_range = 32 : i32}, %arg1: !tt.ptr {tt.divisibility = 16 : i32, tt.pointer_range = 32 : i32}, %arg2: i32) { + %c10_i32 = arith.constant 5 : i32 + %0 = arith.cmpi slt, %c10_i32, %arg2 : i32 + llvm.intr.assume %0 : i1 + // CHECK: %[[range:.*]] = tt.make_range + %1 = tt.make_range {end = 16 : i32, start = 0 : i32} : tensor<16xi32, #blocked> + // CHECK: %[[ptr:.*]] = tt.addptr %arg0, %arg2 + %2 = tt.addptr %arg0, %arg2: !tt.ptr, i32 + %3 = tt.splat %2 : !tt.ptr -> tensor<16x!tt.ptr, #blocked> + %4 = tt.addptr %3, %1 : tensor<16x!tt.ptr, #blocked>, tensor<16xi32, #blocked> + %5 = tt.splat %arg1 : !tt.ptr -> tensor<16x!tt.ptr, #blocked> + %6 = tt.addptr %5, %1 : tensor<16x!tt.ptr, #blocked>, tensor<16xi32, #blocked> + // CHECK: %[[loaded:.*]] = 
amdgpu.buffer_load %arg1[%1] + %7 = tt.load %6 : tensor<16x!tt.ptr, #blocked> + // CHECK: amdgpu.buffer_store %[[loaded]], %[[ptr]][%[[range]]] + tt.store %4, %7 : tensor<16x!tt.ptr, #blocked> + tt.return + } +} + +// ----- + +#blocked = #ttg.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [4], order = [0]}> +module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32} { + // CHECK-LABEL: assume_cmp_non_const + tt.func @assume_cmp_non_const(%arg0: !tt.ptr {tt.divisibility = 16 : i32, tt.pointer_range = 32 : i32}, %arg1: !tt.ptr {tt.divisibility = 16 : i32, tt.pointer_range = 32 : i32}, %arg2: i32, %arg3 : i32, %arg4 : i32, %arg5 : i32, %arg6 : i32) { + %0 = arith.cmpi sgt, %arg2, %arg3 : i32 + llvm.intr.assume %0 : i1 + %1 = arith.subi %arg2, %arg3 : i32 + %2 = arith.cmpi sge, %1, %arg4 : i32 + llvm.intr.assume %2 : i1 + %3 = arith.subi %1, %arg4 : i32 + %4 = arith.cmpi slt, %3, %arg5 : i32 + llvm.intr.assume %4 : i1 + %5 = arith.subi %arg5, %3 : i32 + %6 = arith.cmpi sle, %5, %arg6 : i32 + llvm.intr.assume %6 : i1 + %7 = arith.subi %arg6, %5 : i32 + %8 = arith.minsi %1, %3 : i32 + %9 = arith.minsi %8, %5 : i32 + %10 = arith.minsi %9, %7 : i32 + // CHECK: %[[range:.*]] = tt.make_range + %11 = tt.make_range {end = 16 : i32, start = 0 : i32} : tensor<16xi32, #blocked> + %12 = tt.splat %10 : i32 -> tensor<16xi32, #blocked> + // CHECK: %[[offsets:.*]] = arith.addi + %offsets = arith.addi %11, %12 : tensor<16xi32, #blocked> + %13 = tt.splat %arg0 : !tt.ptr -> tensor<16x!tt.ptr, #blocked> + %14 = tt.addptr %13, %11 : tensor<16x!tt.ptr, #blocked>, tensor<16xi32, #blocked> + %15 = tt.splat %arg1 : !tt.ptr -> tensor<16x!tt.ptr, #blocked> + %16 = tt.addptr %15, %offsets : tensor<16x!tt.ptr, #blocked>, tensor<16xi32, #blocked> + // CHECK: %[[loaded:.*]] = amdgpu.buffer_load %arg1[%[[offsets]]] + %17 = tt.load %16 : tensor<16x!tt.ptr, #blocked> + // CHECK: amdgpu.buffer_store %[[loaded]], %arg0[%[[range]]] + tt.store %14, %17 : tensor<16x!tt.ptr, #blocked> + tt.return + } +} + +// ----- + +#blocked = #ttg.blocked<{sizePerThread = [2, 2], threadsPerWarp = [32, 1], warpsPerCTA = [4, 1], order = [0, 1]}> +#blockedtrans = #ttg.blocked<{sizePerThread = [2, 2], threadsPerWarp = [1, 32], warpsPerCTA = [1, 4], order = [1, 0]}> +#blocked1 = #ttg.slice<{dim=0, parent=#blocked}> +#blocked2 = #ttg.blocked<{sizePerThread = [2], threadsPerWarp = [32], warpsPerCTA = [4], order = [0]}> +module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32} { + // CHECK-LABEL: unary_triton_ops_transitive_nonneg + tt.func @unary_triton_ops_transitive_nonneg(%arg0: !tt.ptr {tt.divisibility = 16 : i32, tt.pointer_range = 32 : i32}, %arg1: !tt.ptr {tt.divisibility = 16 : i32, tt.pointer_range = 32 : i32}) { + %c10_i32 = arith.constant 5 : i32 + %0 = tt.make_range {end = 16 : i32, start = 0 : i32} : tensor<16xi32, #blocked1> + %1 = tt.expand_dims %0 {axis = 0 : i32} : tensor<16xi32, #blocked1> -> tensor<1x16xi32, #blocked> + %2 = tt.reshape %1 allow_reorder : tensor<1x16xi32, #blocked> -> tensor<8x2xi32, #blocked> + %3 = tt.reshape %1 allow_reorder : tensor<1x16xi32, #blocked> -> tensor<2x8xi32, #blocked> + %4 = tt.trans %3 {order = array} : tensor<2x8xi32, #blocked> -> tensor<8x2xi32, #blockedtrans> + %5 = ttg.convert_layout %4 : tensor<8x2xi32, #blockedtrans> -> tensor<8x2xi32, #blocked> + %6 = arith.addi %5, %2 : tensor<8x2xi32, #blocked> + %7 = tt.make_range {end = 10 : i32, start = 2 : i32} : tensor<8xi32, #blocked2> + %8 = ttg.convert_layout %7 : tensor<8xi32, #blocked2> -> 
tensor<8xi32, #blocked1> + %9 = tt.expand_dims %8 {axis = 0 : i32} : tensor<8xi32, #blocked1> -> tensor<1x8xi32, #blocked> + %10 = tt.broadcast %9 : tensor<1x8xi32, #blocked> -> tensor<2x8xi32, #blocked> + %11 = tt.reshape %10 allow_reorder : tensor<2x8xi32, #blocked> -> tensor<8x2xi32, #blocked> + %12 = tt.splat %c10_i32 : i32 -> tensor<8x2xi32, #blocked> + %13 = arith.addi %11, %12 : tensor<8x2xi32, #blocked> + %14 = arith.minsi %13, %5 : tensor<8x2xi32, #blocked> + // CHECK: %[[lhs:.*]], %[[rhs:.*]] = tt.split + %15, %16 = tt.split %11: tensor<8x2xi32, #blocked> -> tensor<8xi32, #blocked2> + %17 = tt.splat %arg0 : !tt.ptr -> tensor<8x!tt.ptr, #blocked2> + %18 = tt.addptr %17, %15 : tensor<8x!tt.ptr, #blocked2>, tensor<8xi32, #blocked2> + // CHECK: %[[loaded:.*]] = amdgpu.buffer_load %arg0[%[[lhs]]] + %19 = tt.load %18 : tensor<8x!tt.ptr, #blocked2> + %20 = tt.addptr %17, %16 : tensor<8x!tt.ptr, #blocked2>, tensor<8xi32, #blocked2> + // CHECK: %[[loaded2:.*]] = amdgpu.buffer_load %arg0[%[[rhs]]] + %21 = tt.load %20 : tensor<8x!tt.ptr, #blocked2> + // CHECK: %[[added:.*]] = arith.addf %[[loaded]], %[[loaded2]] + %22 = arith.addf %19, %21 : tensor<8xbf16, #blocked2> + %23 = tt.splat %arg1 : !tt.ptr -> tensor<8x!tt.ptr, #blocked2> + %24 = tt.addptr %23, %7 : tensor<8x!tt.ptr, #blocked2>, tensor<8xi32, #blocked2> + // CHECK: amdgpu.buffer_store %[[added]], %arg1[%{{.*}}] + tt.store %24, %22 : tensor<8x!tt.ptr, #blocked2> + tt.return + } +} + +// ----- + + +#blocked = #ttg.blocked<{sizePerThread = [2, 2], threadsPerWarp = [32, 1], warpsPerCTA = [4, 1], order = [1, 0]}> +#blocked1 = #ttg.blocked<{sizePerThread = [2], threadsPerWarp = [32], warpsPerCTA = [4], order = [0]}> +module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32} { + // CHECK-LABEL: join_cat_transitive_nonneg + tt.func @join_cat_transitive_nonneg(%arg0: !tt.ptr {tt.divisibility = 16 : i32, tt.pointer_range = 32 : i32}, %arg1: !tt.ptr {tt.divisibility = 16 : i32, tt.pointer_range = 32 : i32}) { + %0 = tt.make_range {end = 8 : i32, start = 0 : i32} : tensor<8xi32, #blocked1> + %1 = tt.make_range {end = 10 : i32, start = 2 : i32} : tensor<8xi32, #blocked1> + %2 = tt.join %0, %1 : tensor<8xi32, #blocked1> -> tensor<8x2xi32, #blocked> + %3 = tt.make_range {end = 4 : i32, start = 0 : i32} : tensor<4xi32, #blocked1> + %4 = tt.make_range {end = 8 : i32, start = 4 : i32} : tensor<4xi32, #blocked1> + %5 = tt.join %3, %4 : tensor<4xi32, #blocked1> -> tensor<4x2xi32, #blocked> + %6 = tt.cat %5, %5 : tensor<4x2xi32, #blocked> -> tensor<8x2xi32, #blocked> + %7 = arith.addi %2, %6 : tensor<8x2xi32, #blocked> + %zeros = arith.constant dense<0> : tensor<8x1xi32, #blocked> + %ones = arith.constant dense<1> : tensor<8x1xi32, #blocked> + %8 = tt.gather %7[%zeros] {axis = 1 : i32} : (tensor<8x2xi32, #blocked>, tensor<8x1xi32, #blocked>) -> tensor<8x1xi32, #blocked> + %9 = tt.gather %7[%ones] {axis = 1 : i32} : (tensor<8x2xi32, #blocked>, tensor<8x1xi32, #blocked>) -> tensor<8x1xi32, #blocked> + %10 = arith.addi %8, %9 : tensor<8x1xi32, #blocked> + %11 = tt.reshape %10 allow_reorder : tensor<8x1xi32, #blocked> -> tensor<8xi32, #blocked1> + %12 = tt.splat %arg0 : !tt.ptr -> tensor<8x!tt.ptr, #blocked1> + %14 = tt.addptr %12, %11 : tensor<8x!tt.ptr, #blocked1>, tensor<8xi32, #blocked1> + // CHECK: %[[loaded:.*]] = amdgpu.buffer_load %arg0[%{{.*}}] + %15 = tt.load %14 : tensor<8x!tt.ptr, #blocked1> + %16 = tt.splat %arg1 : !tt.ptr -> tensor<8x!tt.ptr, #blocked1> + %17 = tt.addptr %16, %0 : tensor<8x!tt.ptr, #blocked1>, tensor<8xi32, 
#blocked1> + // CHECK: amdgpu.buffer_store %[[loaded]], %arg1[%{{.*}}] + tt.store %17, %15 : tensor<8x!tt.ptr, #blocked1> + tt.return + } +} + +// ----- + +#blocked = #ttg.blocked<{sizePerThread = [2], threadsPerWarp = [32], warpsPerCTA = [4], order = [0]}> +module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32} { + // CHECK-LABEL: histo_nonneg + tt.func @histo_nonneg(%arg0: !tt.ptr {tt.divisibility = 16 : i32, tt.pointer_range = 32 : i32}, %arg1: !tt.ptr {tt.divisibility = 16 : i32, tt.pointer_range = 32 : i32}, %arg2 : tensor<256xi32, #blocked>) { + /// Purposely specify %arg2 so that we can't statically determine the input + /// data is nonneg. + // CHECK: tt.histogram + %0 = tt.histogram %arg2 : tensor<256xi32, #blocked> -> tensor<8xi32, #blocked> + %1 = tt.make_range {end = 8 : i32, start = 0 : i32} : tensor<8xi32, #blocked> + %2 = tt.splat %arg0 : !tt.ptr -> tensor<8x!tt.ptr, #blocked> + %3 = tt.addptr %2, %0 : tensor<8x!tt.ptr, #blocked>, tensor<8xi32, #blocked> + // CHECK: %[[loaded:.*]] = amdgpu.buffer_load %arg0[%{{.*}}] + %4 = tt.load %3 : tensor<8x!tt.ptr, #blocked> + %5 = tt.splat %arg1 : !tt.ptr -> tensor<8x!tt.ptr, #blocked> + %6 = tt.addptr %5, %1 : tensor<8x!tt.ptr, #blocked>, tensor<8xi32, #blocked> + // CHECK: amdgpu.buffer_store %[[loaded]], %arg1[%{{.*}}] + tt.store %6, %4 : tensor<8x!tt.ptr, #blocked> + tt.return + } +} + +// ----- + +#blocked = #ttg.blocked<{sizePerThread = [2], threadsPerWarp = [32], warpsPerCTA = [4], order = [0]}> +module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32} { + // CHECK-LABEL: get_num_prog_nonneg + tt.func @get_num_prog_nonneg(%arg0: !tt.ptr {tt.divisibility = 16 : i32, tt.pointer_range = 32 : i32}, %arg1: !tt.ptr {tt.divisibility = 16 : i32, tt.pointer_range = 32 : i32}, %arg2 : i32) { + %0 = tt.get_num_programs x : i32 + %1 = tt.get_num_programs y : i32 + %2 = tt.get_num_programs z : i32 + %3 = arith.minsi %0, %1 : i32 + %4 = arith.minsi %2, %3 : i32 + %5 = arith.maxsi %arg2, %4 : i32 + %6 = tt.splat %5 : i32 -> tensor<8xi32, #blocked> + %7 = tt.make_range {end = 8 : i32, start = 0 : i32} : tensor<8xi32, #blocked> + %8 = arith.addi %6, %7 : tensor<8xi32, #blocked> + %9 = tt.splat %arg0 : !tt.ptr -> tensor<8x!tt.ptr, #blocked> + %10 = tt.addptr %9, %8 : tensor<8x!tt.ptr, #blocked>, tensor<8xi32, #blocked> + // CHECK: %[[loaded:.*]] = amdgpu.buffer_load %arg0[%{{.*}}] + %11 = tt.load %10 : tensor<8x!tt.ptr, #blocked> + %12 = tt.splat %arg1 : !tt.ptr -> tensor<8x!tt.ptr, #blocked> + %13 = tt.addptr %12, %7 : tensor<8x!tt.ptr, #blocked>, tensor<8xi32, #blocked> + // CHECK: amdgpu.buffer_store %[[loaded]], %arg1[%{{.*}}] + tt.store %13, %11 : tensor<8x!tt.ptr, #blocked> + tt.return + } +} + +// ----- + +#blocked = #ttg.blocked<{sizePerThread = [2], threadsPerWarp = [32], warpsPerCTA = [4], order = [0]}> +module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32} { + // CHECK-LABEL: unsigned_ops + tt.func @unsigned_ops(%arg0: !tt.ptr {tt.divisibility = 16 : i32, tt.pointer_range = 32 : i32}, %arg1: !tt.ptr {tt.divisibility = 16 : i32, tt.pointer_range = 32 : i32}, %arg2 : i32, %arg3 : i32, %arg4 : f32, %arg5 : index) { + %c5_i32 = arith.constant 5 : i32 + %0 = arith.ceildivui %arg2, %c5_i32 : i32 + %1 = arith.divui %arg3, %c5_i32 : i32 + %2 = arith.fptoui %arg4 : f32 to i32 + %3 = arith.index_castui %arg5 : index to i32 + %4 = arith.maxui %arg2, %arg3 : i32 + %5 = arith.minui %arg2, %arg3 : i32 + %6 = arith.remui %arg2, %c5_i32 : i32 + %7 = arith.shrui %arg3, %c5_i32 : i32 + %8 = arith.addi 
%0, %1 : i32 + %9 = arith.addi %2, %3 : i32 + %10 = arith.addi %4, %5 : i32 + %11 = arith.addi %6, %7 : i32 + %12 = arith.addi %8, %9 : i32 + %13 = arith.addi %10, %11 : i32 + %14 = arith.addi %12, %13 : i32 + %15 = tt.splat %14 : i32 -> tensor<8xi32, #blocked> + %16 = tt.make_range {end = 8 : i32, start = 0 : i32} : tensor<8xi32, #blocked> + %17 = arith.addi %15, %16 : tensor<8xi32, #blocked> + %18 = tt.splat %arg0 : !tt.ptr -> tensor<8x!tt.ptr, #blocked> + %19 = tt.addptr %18, %17 : tensor<8x!tt.ptr, #blocked>, tensor<8xi32, #blocked> + // CHECK: %[[loaded:.*]] = amdgpu.buffer_load %arg0[%{{.*}}] + %20 = tt.load %19 : tensor<8x!tt.ptr, #blocked> + %21 = tt.splat %arg1 : !tt.ptr -> tensor<8x!tt.ptr, #blocked> + %22 = tt.addptr %21, %16 : tensor<8x!tt.ptr, #blocked>, tensor<8xi32, #blocked> + // CHECK: amdgpu.buffer_store %[[loaded]], %arg1[%{{.*}}] + tt.store %22, %20 : tensor<8x!tt.ptr, #blocked> + tt.return + } +} + +// ----- + +#blocked = #ttg.blocked<{sizePerThread = [2], threadsPerWarp = [32], warpsPerCTA = [4], order = [0]}> +module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32} { + // CHECK-LABEL: extui_nonneg + tt.func @extui_nonneg(%arg0: !tt.ptr {tt.divisibility = 16 : i32, tt.pointer_range = 32 : i32}, %arg1: !tt.ptr {tt.divisibility = 16 : i32, tt.pointer_range = 32 : i32}, %arg2 : i32) { + %0 = arith.extui %arg2 : i32 to i64 + %1 = tt.splat %0 : i64 -> tensor<8xi64, #blocked> + %2 = tt.make_range {end = 8 : i32, start = 0 : i32} : tensor<8xi32, #blocked> + %3 = arith.extui %2 : tensor<8xi32, #blocked> to tensor<8xi64, #blocked> + %4 = arith.addi %1, %3 : tensor<8xi64, #blocked> + %5 = arith.trunci %4 : tensor<8xi64, #blocked> to tensor<8xi32, #blocked> + %6 = tt.splat %arg0 : !tt.ptr -> tensor<8x!tt.ptr, #blocked> + %7 = tt.addptr %6, %5 : tensor<8x!tt.ptr, #blocked>, tensor<8xi32, #blocked> + // CHECK: %[[loaded:.*]] = amdgpu.buffer_load %arg0[%{{.*}}] + %8 = tt.load %7: tensor<8x!tt.ptr, #blocked> + %9 = tt.splat %arg1 : !tt.ptr -> tensor<8x!tt.ptr, #blocked> + %10 = tt.addptr %9, %2 : tensor<8x!tt.ptr, #blocked>, tensor<8xi32, #blocked> + // CHECK: amdgpu.buffer_store %[[loaded]], %arg1[%{{.*}}] + tt.store %10, %8 : tensor<8x!tt.ptr, #blocked> + tt.return + } +} + +// ----- + +#blocked = #ttg.blocked<{sizePerThread = [2], threadsPerWarp = [32], warpsPerCTA = [4], order = [0]}> +module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32} { + // CHECK-LABEL: traverse_if + tt.func @traverse_if(%arg0: !tt.ptr {tt.divisibility = 16 : i32, tt.pointer_range = 32 : i32}, %arg1: !tt.ptr {tt.divisibility = 16 : i32, tt.pointer_range = 32 : i32}, %arg2 : i32, %arg3 : i32) { + %c0_i32 = arith.constant 0 : i32 + %c2_i32 = arith.constant 2 : i32 + %c5_i32 = arith.constant 7 : i32 + %c7_i32 = arith.constant 5 : i32 + %0 = arith.extui %arg2 : i32 to i64 + %1 = arith.remui %arg2, %c2_i32 : i32 + %2 = arith.cmpi eq, %1, %c0_i32 : i32 + %3 = scf.if %2 -> tensor<8xi64, #blocked> { + %20 = tt.make_range {end = 8 : i32, start = 0 : i32} : tensor<8xi32, #blocked> + %21 = arith.extui %20 : tensor<8xi32, #blocked> to tensor<8xi64, #blocked> + %22 = tt.splat %arg3 : i32 -> tensor<8xi32, #blocked> + %23 = arith.extui %22 : tensor<8xi32, #blocked> to tensor<8xi64, #blocked> + %24 = arith.addi %21, %23 : tensor<8xi64, #blocked> + scf.yield %24 : tensor<8xi64, #blocked> + } else { + %30 = tt.make_range {end = 16 : i32, start = 8 : i32} : tensor<8xi32, #blocked> + %31 = arith.extui %30 : tensor<8xi32, #blocked> to tensor<8xi64, #blocked> + %32 = tt.splat %0 : i64 -> 
tensor<8xi64, #blocked> + %33 = arith.addi %31, %32 : tensor<8xi64, #blocked> + scf.yield %33 : tensor<8xi64, #blocked> + } + %4 = arith.trunci %3 : tensor<8xi64, #blocked> to tensor<8xi32, #blocked> + %5 = tt.splat %arg0 : !tt.ptr -> tensor<8x!tt.ptr, #blocked> + %6 = tt.addptr %5, %4 : tensor<8x!tt.ptr, #blocked>, tensor<8xi32, #blocked> + // CHECK: %[[loaded:.*]] = amdgpu.buffer_load %arg0[%{{.*}}] + %7 = tt.load %6: tensor<8x!tt.ptr, #blocked> + %8 = tt.make_range {end = 8 : i32, start = 0 : i32} : tensor<8xi32, #blocked> + %9 = tt.splat %arg1 : !tt.ptr -> tensor<8x!tt.ptr, #blocked> + %10 = tt.addptr %9, %8 : tensor<8x!tt.ptr, #blocked>, tensor<8xi32, #blocked> + // CHECK: amdgpu.buffer_store %[[loaded]], %arg1[%{{.*}}] + tt.store %10, %7 : tensor<8x!tt.ptr, #blocked> + tt.return + } +} + +// ----- + +#blocked = #ttg.blocked<{sizePerThread = [2], threadsPerWarp = [32], warpsPerCTA = [4], order = [0]}> +module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32} { + // CHECK-LABEL: traverse_if + tt.func @traverse_if(%arg0: !tt.ptr {tt.divisibility = 16 : i32, tt.pointer_range = 32 : i32}, %arg1: !tt.ptr {tt.divisibility = 16 : i32, tt.pointer_range = 32 : i32}, %arg2 : i32, %arg3 : i32) { + %c0_i32 = arith.constant 0 : i32 + %c2_i32 = arith.constant 2 : i32 + %c5_i32 = arith.constant 7 : i32 + %c7_i32 = arith.constant 5 : i32 + %zeros = arith.constant dense<0> : tensor<8xi32, #blocked> + %0 = arith.extui %arg2 : i32 to i64 + %1 = arith.remui %arg2, %c2_i32 : i32 + %2 = arith.cmpi eq, %1, %c0_i32 : i32 + %3, %4 = scf.if %2 -> (tensor<8xi64, #blocked>, tensor<8xi32, #blocked>) { + %20 = tt.make_range {end = 8 : i32, start = 0 : i32} : tensor<8xi32, #blocked> + %21 = arith.extui %20 : tensor<8xi32, #blocked> to tensor<8xi64, #blocked> + %22 = tt.splat %arg3 : i32 -> tensor<8xi32, #blocked> + %23 = arith.extui %22 : tensor<8xi32, #blocked> to tensor<8xi64, #blocked> + %24 = arith.addi %21, %23 : tensor<8xi64, #blocked> + %25 = tt.make_range {end = 9 : i32, start = 1 : i32} : tensor<8xi32, #blocked> + scf.yield %24, %25 : tensor<8xi64, #blocked>, tensor<8xi32, #blocked> + } else { + %30 = tt.make_range {end = 16 : i32, start = 8 : i32} : tensor<8xi32, #blocked> + %31 = arith.extui %30 : tensor<8xi32, #blocked> to tensor<8xi64, #blocked> + %32 = tt.splat %0 : i64 -> tensor<8xi64, #blocked> + %33 = arith.addi %31, %32 : tensor<8xi64, #blocked> + scf.yield %33, %zeros : tensor<8xi64, #blocked>, tensor<8xi32, #blocked> + } + %5 = arith.trunci %3 : tensor<8xi64, #blocked> to tensor<8xi32, #blocked> + %6 = arith.addi %4, %5 : tensor<8xi32, #blocked> + %7 = tt.splat %arg0 : !tt.ptr -> tensor<8x!tt.ptr, #blocked> + %8 = tt.addptr %7, %6 : tensor<8x!tt.ptr, #blocked>, tensor<8xi32, #blocked> + // CHECK: %[[loaded:.*]] = amdgpu.buffer_load %arg0[%{{.*}}] + %9 = tt.load %8: tensor<8x!tt.ptr, #blocked> + %10 = tt.make_range {end = 8 : i32, start = 0 : i32} : tensor<8xi32, #blocked> + %11 = tt.splat %arg1 : !tt.ptr -> tensor<8x!tt.ptr, #blocked> + %12 = tt.addptr %11, %10 : tensor<8x!tt.ptr, #blocked>, tensor<8xi32, #blocked> + // CHECK: amdgpu.buffer_store %[[loaded]], %arg1[%{{.*}}] + tt.store %12, %9 : tensor<8x!tt.ptr, #blocked> + tt.return + } +} diff --git a/test/TritonGPU/combine.mlir b/test/TritonGPU/combine.mlir index bc2270adb4..cd45d1ee05 100644 --- a/test/TritonGPU/combine.mlir +++ b/test/TritonGPU/combine.mlir @@ -1,4 +1,4 @@ -// RUN: triton-opt %s -split-input-file -tritongpu-remove-layout-conversions 2>&1 | FileCheck %s +// RUN: triton-opt %s -split-input-file 
-allow-unregistered-dialect -tritongpu-remove-layout-conversions 2>&1 | FileCheck %s #layout0 = #ttg.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [4], order = [0]}> #layout1 = #ttg.blocked<{sizePerThread = [4], threadsPerWarp = [32], warpsPerCTA = [4], order = [0]}> @@ -2427,8 +2427,9 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, "ttg.thr %2 = tt.splat %arg0 : !tt.ptr -> tensor<64x64x!tt.ptr, #blocked2> %3 = tt.splat %arg1 : !tt.ptr -> tensor<128x64x!tt.ptr, #blocked> %4 = tt.splat %arg1 : !tt.ptr -> tensor<128x64x!tt.ptr, #blocked> - // CHECK: %[[F:.+]]:3 = scf.for {{.*}} -> (tensor<128xf32, #ttg.slice<{dim = 1, parent = #blocked}>>, tensor<128xf32, #ttg.slice<{dim = 1, parent = #blocked}>>, tensor<128xf32, #ttg.slice<{dim = 1, parent = #mma}>>) - // CHECK: scf.yield {{.*}} : tensor<128xf32, #ttg.slice<{dim = 1, parent = #blocked}>>, tensor<128xf32, #ttg.slice<{dim = 1, parent = #blocked}>>, tensor<128xf32, #ttg.slice<{dim = 1, parent = #mma}>> + // CHECK: %[[F:.+]]:3 = scf.for {{.*}} -> (tensor<128xf32, #ttg.slice<{dim = 1, parent = #blocked}>>, tensor<128xf32, #ttg.slice<{dim = 1, parent = #blocked}>>, tensor<128xf32, #ttg.slice<{dim = 1, parent = #blocked}>>) + // CHECK-COUNT-4: convert_layout + // CHECK: scf.yield {{.*}} : tensor<128xf32, #ttg.slice<{dim = 1, parent = #blocked}>>, tensor<128xf32, #ttg.slice<{dim = 1, parent = #blocked}>>, tensor<128xf32, #ttg.slice<{dim = 1, parent = #blocked}>> // CHECK: } // CHECK: tt.return %[[F]]#0, %[[F]]#1 : tensor<128xf32, #ttg.slice<{dim = 1, parent = #blocked}>>, tensor<128xf32, #ttg.slice<{dim = 1, parent = #blocked}>> %5:3 = scf.for %arg2 = %c0_i32 to %c2048_i32 step %c64_i32 iter_args(%arg3 = %cst_2, %arg4 = %cst, %arg5 = %cst_0) -> (tensor<128x64xf32, #mma>, tensor<128xf32, #ttg.slice<{dim = 1, parent = #blocked}>>, tensor<128xf32, #ttg.slice<{dim = 1, parent = #blocked}>>) : i32 { @@ -2772,3 +2773,59 @@ tt.func @do_not_remat(%arg0: tensor<64x64xf32, #blocked1>) -> tensor<1x64xf32, # } } + +// ----- + +#blocked = #ttg.blocked<{sizePerThread = [1, 2], threadsPerWarp = [1, 32], warpsPerCTA = [4, 1], order = [1, 0]}> +#blocked1 = #ttg.blocked<{sizePerThread = [2, 1], threadsPerWarp = [32, 1], warpsPerCTA = [1, 4], order = [0, 1]}> + +module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.target = "cuda:90", "ttg.threads-per-warp" = 32 : i32} { + +// CHECK-LABEL: reuse_layout_conversion +tt.func @reuse_layout_conversion(%arg0: tensor<64x64xf32, #blocked>) -> (tensor<64x64xf32, #blocked>, tensor<64x64xf32, #blocked>) { + // CHECK-NEXT: %cst = arith.constant {{.*}} tensor<64x64xf32, #blocked> + %cst = arith.constant dense<2.000000e+00> : tensor<64x64xf32, #blocked1> + // CHECK-NEXT: [[TRANS:%.*]] = tt.trans %arg0 {{.*}} tensor<64x64xf32, #blocked> -> tensor<64x64xf32, #blocked1> + %0 = tt.trans %arg0 {order = array} : tensor<64x64xf32, #blocked> -> tensor<64x64xf32, #blocked1> + // CHECK-NEXT: [[CVT:%.*]] = ttg.convert_layout [[TRANS]] : tensor<64x64xf32, #blocked1> -> tensor<64x64xf32, #blocked> + %1 = ttg.convert_layout %0 : tensor<64x64xf32, #blocked1> -> tensor<64x64xf32, #blocked> + // CHECK-NEXT: [[RESULT:%.*]] = arith.mulf [[CVT]], %cst : tensor<64x64xf32, #blocked> + %2 = arith.mulf %0, %cst : tensor<64x64xf32, #blocked1> + %3 = ttg.convert_layout %2 : tensor<64x64xf32, #blocked1> -> tensor<64x64xf32, #blocked> + // CHECK-NEXT: return [[CVT]], [[RESULT]] + tt.return %1, %3 : tensor<64x64xf32, #blocked>, tensor<64x64xf32, #blocked> +} + +// CHECK-LABEL: 
respect_dominance +tt.func @respect_dominance(%arg0: tensor<64x64xf32, #blocked>) -> (tensor<64x64xf32, #blocked>, tensor<64x64xf32, #blocked>) { + %cst = arith.constant dense<2.000000e+00> : tensor<64x64xf32, #blocked1> + + // CHECK-COUNT-2: convert_layout + %0 = tt.trans %arg0 {order = array} : tensor<64x64xf32, #blocked> -> tensor<64x64xf32, #blocked1> + + %2 = arith.mulf %0, %cst : tensor<64x64xf32, #blocked1> + %1 = ttg.convert_layout %0 : tensor<64x64xf32, #blocked1> -> tensor<64x64xf32, #blocked> + %3 = ttg.convert_layout %2 : tensor<64x64xf32, #blocked1> -> tensor<64x64xf32, #blocked> + tt.return %1, %3 : tensor<64x64xf32, #blocked>, tensor<64x64xf32, #blocked> +} + +// CHECK-LABEL: remat_across_regions +tt.func @remat_across_regions(%arg0: i1, %arg1: tensor<8x8xf32, #blocked>) { + // CHECK-NEXT: scf.if + scf.if %arg0 { + // CHECK-NEXT: convert_layout + %0 = ttg.convert_layout %arg1 : tensor<8x8xf32, #blocked> -> tensor<8x8xf32, #blocked1> + "test.keep"(%0) : (tensor<8x8xf32, #blocked1>) -> () + // CHECK: else + } else { + %0 = "test.dummy"() : () -> i32 + // CHECK: convert_layout + %1 = ttg.convert_layout %arg1 : tensor<8x8xf32, #blocked> -> tensor<8x8xf32, #blocked1> + "test.keep"(%1) : (tensor<8x8xf32, #blocked1>) -> () + // CHECK: } + } + // CHECK-NEXT: return + tt.return +} + +} diff --git a/third_party/amd/lib/TritonAMDGPUToLLVM/UpcastMXFPToLLVM.cpp b/third_party/amd/lib/TritonAMDGPUToLLVM/UpcastMXFPToLLVM.cpp index 4126f4cc4a..86ddbbd195 100644 --- a/third_party/amd/lib/TritonAMDGPUToLLVM/UpcastMXFPToLLVM.cpp +++ b/third_party/amd/lib/TritonAMDGPUToLLVM/UpcastMXFPToLLVM.cpp @@ -19,6 +19,25 @@ using namespace mlir::triton::gpu; namespace { +// Scales the given bf16 v using the given scale factor without relying on bf16 +// multiplication. +// +// In gfx9 architectures, we don't have bf16 VALU ops. So instead this function +// handles v * scale multiplication using fp32 VALU ops. LLVM backend can do it +// for us, just with unnecessary overheads. +Value mxfpScaleBf16ViaF32(RewriterBase &rewriter, Location loc, Value v, + Value scale) { + Value c16 = i32_val(16); + Value vF32 = bitcast(shl(zext(i32_ty, bitcast(v, i16_ty)), c16), f32_ty); + Value scaleF32 = bitcast(shl(zext(i32_ty, scale), i32_val(23)), f32_ty); + Value mulF32 = fmul(vF32, scaleF32); + Value mulI16 = trunc(i16_ty, lshr(bitcast(mulF32, i32_ty), c16)); + // Account for NaN in the scale as per the mxfp specification. 
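+  // Shifting 0xff into the f32 exponent field above yields +inf rather than
+  // NaN, so the NaN case must be selected explicitly.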
+ Value scaleIsNan = icmp_eq(scale, i8_val(0xff)); + Value nanBf16 = bitcast(i16_val(0x7fff), bf16_ty); + return select(scaleIsNan, nanBf16, bitcast(mulI16, bf16_ty)); +}; + class UpcastMXFPOpPattern : public ConvertOpToLLVMPattern { private: const TargetInfoBase &targetInfo; @@ -98,7 +117,7 @@ class UpcastMXFPOpPattern : public ConvertOpToLLVMPattern { for (int j = 0; j < 32; ++j) { int index = 32 * i + j; xVals[index] = - LLVM::mxfpScaleBf16(rewriter, loc, xVals[index], si[j / 16]); + mxfpScaleBf16ViaF32(rewriter, loc, xVals[index], si[j / 16]); } } } else { @@ -121,7 +140,7 @@ class UpcastMXFPOpPattern : public ConvertOpToLLVMPattern { for (int j = 0; j < 32; ++j) { int index = 32 * i + j; xVals[index] = - LLVM::mxfpScaleBf16(rewriter, loc, xVals[index], si[j / 8]); + mxfpScaleBf16ViaF32(rewriter, loc, xVals[index], si[j / 8]); } } } diff --git a/third_party/amd/lib/TritonAMDGPUTransforms/ConvertToBufferOps.cpp b/third_party/amd/lib/TritonAMDGPUTransforms/ConvertToBufferOps.cpp index e66a2feb57..fdc1e37b71 100644 --- a/third_party/amd/lib/TritonAMDGPUTransforms/ConvertToBufferOps.cpp +++ b/third_party/amd/lib/TritonAMDGPUTransforms/ConvertToBufferOps.cpp @@ -32,49 +32,92 @@ namespace ttg = mlir::triton::gpu; namespace tt = mlir::triton; namespace { -bool verifyNonNegativeByAssumption(Value expr, - const DenseSet &assumptions) { +template +bool verifyNonSmallerByAssumption(Value expr, + const DenseSet &assumptions, + F matchesOther) { for (Value assume : assumptions) { - LDBG("Assumption:" << assume); if (auto cmpOp = assume.getDefiningOp()) { - bool isGreaterThan = (cmpOp.getPredicate() == arith::CmpIPredicate::sge || - cmpOp.getPredicate() == arith::CmpIPredicate::sgt); - APInt cst; - if (isGreaterThan && (cmpOp.getLhs() == expr) && - matchPattern(cmpOp.getRhs(), m_ConstantInt(&cst))) { - return cst.isNonNegative(); + switch (cmpOp.getPredicate()) { + case arith::CmpIPredicate::eq: + case arith::CmpIPredicate::sge: + case arith::CmpIPredicate::sgt: { + if (cmpOp.getLhs() == expr && matchesOther(cmpOp.getRhs())) { + LDBG(" " << expr << " non-neg by assumption " << cmpOp); + return true; + } + break; + } + case arith::CmpIPredicate::sle: + case arith::CmpIPredicate::slt: { + if (cmpOp.getRhs() == expr && matchesOther(cmpOp.getLhs())) { + LDBG(" " << expr << " non-neg by assumption " << cmpOp); + return true; + } + break; + } + default: + break; } } } return false; } +bool verifyNonNegativeByAssumption(Value expr, + const DenseSet &assumptions) { + return verifyNonSmallerByAssumption(expr, assumptions, [](auto otherExpr) { + APInt cst; + return matchPattern(otherExpr, m_ConstantInt(&cst)) && cst.isNonNegative(); + }); +} + +bool verifyNonSmallerByAssumption(Value expr, + const DenseSet &assumptions, + Value other) { + return verifyNonSmallerByAssumption( + expr, assumptions, [&](auto otherAssum) { return otherAssum == other; }); +} + bool verifyNonNegativeExpr(Value expr, const DenseSet &assumptions) { + LDBG("Determing if non-negative: " << expr); // Check if the expression is contained in any assumption if (verifyNonNegativeByAssumption(expr, assumptions)) { - LDBG("Non negative by assumption"); return true; } // Recurse if the operation is defined Operation *op = expr.getDefiningOp(); - if (!op) + if (!op) { + LDBG(" No defining op, assuming possibly negative"); return false; + } bool nonNegative = llvm::TypeSwitch(expr.getDefiningOp()) - .Case([&](auto broadcastOp) { - return verifyNonNegativeExpr(broadcastOp.getSrc(), assumptions); + // Various unary triton ops that don't change the 
sign of the operand + .Case([&](auto unaryOp) { + return verifyNonNegativeExpr(unaryOp.getOperand(), assumptions); }) - .Case([&](auto expandOp) { - return verifyNonNegativeExpr(expandOp.getSrc(), assumptions); + .Case([&](auto gatherOp) { + return verifyNonNegativeExpr(gatherOp.getSrc(), assumptions); }) - .Case([&](auto splatOp) { - return verifyNonNegativeExpr(splatOp.getSrc(), assumptions); + // Joining two non-negative tensors is still non-negative + .Case([&](auto joinOp) { + return verifyNonNegativeExpr(joinOp.getLhs(), assumptions) && + verifyNonNegativeExpr(joinOp.getRhs(), assumptions); }) + // Returns a tensor representing histogram: historgrams only contain + // buckets of non-negative values. + .Case([&](auto) { return true; }) .Case([&](auto makeRangeOp) { - return makeRangeOp.getStart() >= 0 && makeRangeOp.getEnd() >= 0; + // See the warning in TritonOps.td: getStart/getEnd return unsigned, + // so we need to look through get*Attr. + return makeRangeOp.getStartAttr().getInt() >= 0 && + makeRangeOp.getEndAttr().getInt() >= 0; }) .Case( [&](auto constIntOp) { return constIntOp.value() >= 0; }) @@ -85,12 +128,14 @@ bool verifyNonNegativeExpr(Value expr, const DenseSet &assumptions) { return constVal.getSplatValue().isNonNegative(); return false; }) - .Case([&](auto pidOp) { return true; }) + .Case([&](auto) { + // These are defined as signless, but are actually unsigned + return true; + }) .Case([&](auto maxOp) { // max(a,b) >= 0 iff a>=0 || b>=0 - bool nnLhs = verifyNonNegativeExpr(maxOp.getLhs(), assumptions); - bool nnRhs = verifyNonNegativeExpr(maxOp.getRhs(), assumptions); - return nnLhs || nnRhs; + return verifyNonNegativeExpr(maxOp.getLhs(), assumptions) || + verifyNonNegativeExpr(maxOp.getRhs(), assumptions); }) .Case([&](auto remsiOp) { // a % b >= 0 iff a>=0 @@ -100,18 +145,52 @@ bool verifyNonNegativeExpr(Value expr, const DenseSet &assumptions) { // a = OP b >= 0 iff b >= 0 return verifyNonNegativeExpr(unaryOp->getOperand(0), assumptions); }) + // Casting from arbitrary data does *not* guarantee the offset is in + // range (even if pointer, or the data is non-negative when + // interpreted as the src's type). + .Case( + [&](auto) { return false; }) + .Case( + // These OPs also return unsigned values. + // TODO: We can also sniff whether a Value is unsigned by looking + // for whether or not it's used as an argument to one of + // these OPs. + [&](auto uOp) { return true; }) .Case( // Generally speaking, a OP b >= 0 iff a >= 0 && b >= 0 when // OP != sub [&](Operation *binOp) { - bool nnLhs = - verifyNonNegativeExpr(binOp->getOperand(0), assumptions); - bool nnRhs = - verifyNonNegativeExpr(binOp->getOperand(1), assumptions); - return nnLhs && nnRhs; + return verifyNonNegativeExpr(binOp->getOperand(0), + assumptions) && + verifyNonNegativeExpr(binOp->getOperand(1), assumptions); }) + // TODO: more scf + .Case([&](auto ifOp) { + auto results = ifOp.getResults(); + auto it = std::find(results.begin(), results.end(), expr); + assert(it != results.end() && "expr should be the result of ifOp"); + auto resultIdx = it - results.begin(); + + // If we're here then we must have both then/else regions + // (each with 1 block) and each region must terminate with an + // `scf.yield` expression. 
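+        // (An scf.if that produces results is verified to have both regions
+        // terminated by scf.yield, so these accesses are safe.)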
+ auto thenYield = cast(ifOp.thenYield()); + auto elseYield = cast(ifOp.elseYield()); + return verifyNonNegativeExpr(thenYield->getOperand(resultIdx), + assumptions) && + verifyNonNegativeExpr(elseYield->getOperand(resultIdx), + assumptions); + }) + .Case([&](auto op) { + // If a user annotates tl.assume(a >= b) then we know a - b >= 0 + return verifyNonSmallerByAssumption(op.getLhs(), assumptions, + op.getRhs()); + }) .Default([&](Operation *op) { // Conservatively assume that the expression is negative + LDBG(" Unhandled op, cannot assume non-negative"); return false; }); return nonNegative; @@ -260,6 +339,9 @@ class TritonAMDGPUConvertToBufferOpsPass assumptions.insert(op->getOperand(0)); }); LDBG("Number of assumptions found: " << assumptions.size()); + for (Value assume : assumptions) { + LDBG("Assumption:" << assume); + } patterns.add(context, assumptions); patterns.add(context, assumptions);
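For reference, a minimal standalone sketch (not part of the patch) of the bf16-scaling trick that mxfpScaleBf16ViaF32 above emits as LLVM IR: widen bf16 to f32 by shifting its bits into the upper half, build the f32 scale by placing the e8m0 exponent byte into the f32 exponent field, multiply in f32, truncate back to bf16, and force a NaN result when the scale byte is the e8m0 NaN encoding (0xff). The function name and the use of C++20 std::bit_cast are illustrative assumptions, not Triton APIs.

#include <bit>
#include <cstdint>

// Scalar model of the lowering above; bf16 values are carried as their raw
// 16-bit patterns.
uint16_t scaleBf16ViaF32(uint16_t v, uint8_t scale) {
  // A bf16 number is exactly the upper 16 bits of the equivalent f32.
  float vF32 = std::bit_cast<float>(static_cast<uint32_t>(v) << 16);
  // Placing the e8m0 byte into the f32 exponent field gives 2^(scale - 127)
  // for non-zero exponents.
  float scaleF32 = std::bit_cast<float>(static_cast<uint32_t>(scale) << 23);
  // Multiply in f32 and truncate (no rounding), mirroring the emitted IR.
  uint16_t mul =
      static_cast<uint16_t>(std::bit_cast<uint32_t>(vF32 * scaleF32) >> 16);
  // 0xff is NaN in e8m0; the shift above would give +inf, so select a
  // canonical bf16 NaN instead.
  return scale == 0xff ? static_cast<uint16_t>(0x7fff) : mul;
}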