Commit aa7a897

Merge commit '1f8966b53a3ba5c68294c551250438cca54c771f'

whitneywhtsang committed Dec 17, 2024
2 parents 182fb7f + 1f8966b
Showing 10 changed files with 691 additions and 99 deletions.
15 changes: 12 additions & 3 deletions bin/triton-tensor-layout.cpp
@@ -80,9 +80,18 @@ static cl::opt<std::string> TensorStr(
 //===--------------------------------------------------------------------===//
 
 LogicalResult layoutPrint(RankedTensorType tensorType, raw_ostream &os) {
-  // Dispatch to the corresponding dialect helper function to print the layout.
-  os << triton::gpu::getLayoutStr(tensorType, UseHWPointOfView);
-  return success();
+  // DistributedEncodingTrait and SharedEncodingAttr implement the
+  // toLinearLayout interface.
+  mlir::Attribute layout = tensorType.getEncoding();
+  if (isa<mlir::triton::gpu::DistributedEncodingTrait,
+          mlir::triton::gpu::SharedEncodingAttr>(layout)) {
+    os << triton::gpu::getLayoutStr(tensorType, UseHWPointOfView);
+    return success();
+  }
+
+  llvm::errs() << "Unsupported tensor layout attribute: "
+               << tensorType.getEncoding() << "\n";
+  return failure();
 }
 
 LogicalResult printLayoutFromFile(MLIRContext *context, StringRef filename,
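With this change, an encoding that implements neither DistributedEncodingTrait nor SharedEncodingAttr makes layoutPrint emit a diagnostic and return failure() instead of being handed to the dispatch helper unchecked. A minimal caller sketch (hypothetical driver code; parseTensorType is an assumed helper, not part of the tool):

    // Sketch only: print the layout of a parsed tensor type, handling rejection.
    mlir::RankedTensorType type = parseTensorType(context, tensorStr); // assumed
    if (failed(layoutPrint(type, llvm::outs())))
      return 1; // "Unsupported tensor layout attribute: ..." went to stderr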
6 changes: 4 additions & 2 deletions include/triton/Dialect/TritonGPU/Transforms/Utility.h
@@ -161,9 +161,11 @@ Operation *cloneWithInferType(mlir::OpBuilder &rewriter, Operation *op,
 // Get backward slice of tensor values starting from the root node along with
 // encoding propagation.
 LogicalResult getConvertBackwardSlice(
-    Value root, SetVector<Value> &slice, Attribute rootEncoding,
+    OpOperand &root, SetVector<Value> &slice, Attribute rootEncoding,
     DenseMap<Value, Attribute> &layout,
-    std::function<bool(Operation *)> stopPropagation = nullptr);
+    std::function<bool(Operation *)> stopPropagation = nullptr,
+    std::function<Value(OpOperand &, Attribute)> getExistingConversion =
+        nullptr);
 
 // Populate pattern to remove dead cycles in ForOp.
 void populateForOpDeadArgumentElimination(RewritePatternSet &patterns);
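The new getExistingConversion hook lets a caller short-circuit the backward slice at a use for which a suitable conversion already exists. A minimal caller sketch, modeled on the RemoveLayoutConversions.cpp changes below; rematMapping, domInfo, and targetEncoding stand in for the caller's state and are assumptions here:

    // Sketch: hand the slice walk a per-use lookup for cached conversions.
    auto getExistingConversion = [&](OpOperand &use, Attribute encoding) -> Value {
      // Assumed cache of (value, encoding) -> previously materialized value.
      Value cached = rematMapping.lookup({use.get(), encoding});
      if (cached && domInfo.properlyDominates(cached, use.getOwner()))
        return cached; // safe to reuse: computed before this use
      return Value();
    };

    SetVector<Value> slice;
    DenseMap<Value, Attribute> layout;
    LogicalResult result = getConvertBackwardSlice(
        convertOp.getSrcMutable(), slice, targetEncoding, layout,
        /*stopPropagation=*/nullptr, getExistingConversion);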
89 changes: 65 additions & 24 deletions lib/Dialect/TritonGPU/Transforms/RemoveLayoutConversions.cpp
@@ -3,6 +3,7 @@
 #include "mlir/Analysis/SliceAnalysis.h"
 #include "mlir/Dialect/SCF/IR/SCF.h"
 #include "mlir/IR/BuiltinAttributes.h"
+#include "mlir/IR/Dominance.h"
 #include "mlir/IR/IRMapping.h"
 #include "mlir/IR/Matchers.h"
 #include "mlir/IR/PatternMatch.h"
@@ -116,17 +117,15 @@ class LayoutPropagation {
 class LayoutRematerialization {
 public:
   LayoutRematerialization(FuncOp F) : funcOp(F) {}
+
   // Map the original value to the remat'ed one.
   void addRematValue(Value old, Attribute encoding, Value newV);
-  bool hasRematValue(Value value, Attribute encoding) {
-    return rematMapping.contains({value, encoding});
-  }
-  // Return the remat'ed value in the given encoding.
-  Value getRematValue(Value value, Attribute encoding) {
-    auto it = rematMapping.find({value, encoding});
-    assert(it != rematMapping.end());
-    return it->second;
+  // Get the remat'ed value in the given encoding, if one already exists and
+  // is different from the layout conversion root.
+  Value getRematValue(Value value, Attribute encoding) const {
+    return rematMapping.lookup({value, encoding});
   }
+
   void cleanup();
   void backwardRematerialization();
   void backwardRematerialization(ConvertLayoutOp convertOp);
@@ -137,6 +136,11 @@ class LayoutRematerialization {
   void rewriteSlice(SetVector<Value> &slice, DenseMap<Value, Attribute> &layout,
                     ConvertLayoutOp convertOp);
 
+  LogicalResult getRematerializableSlice(
+      OpOperand &root, Attribute rootEncoding, SetVector<Value> &slice,
+      DenseMap<Value, Attribute> &layout,
+      std::function<bool(Operation *)> stopPropagation = nullptr);
+
 private:
   void updateRematMapping(SmallVector<std::tuple<Value, Value>> &values);
   // Existing tuples of (value, layout) that need to be updated when recreating
@@ -148,6 +152,7 @@
   // DenseMap<std::pair<Operation*, Attribute>, Operation*>
   SetVector<Operation *> opToDelete;
   FuncOp funcOp;
+  DominanceInfo domInfo;
 };
 
 void LayoutRematerialization::addRematValue(Value old, Attribute encoding,
@@ -778,8 +783,8 @@ void LayoutRematerialization::rewriteSlice(SetVector<Value> &slice,
     auto layoutIt = layout.find(v);
     assert(layoutIt != layout.end());
     // If we already have a remat value for this value, use it.
-    if (hasRematValue(v, layoutIt->second)) {
-      mapping.map(v, getRematValue(v, layoutIt->second));
+    if (Value remat = getRematValue(v, layoutIt->second)) {
+      mapping.map(v, remat);
       valuesWithExistingRemat.insert(v);
       continue;
     }
@@ -940,12 +945,36 @@ void LayoutRematerialization::rewriteSlice(SetVector<Value> &slice,
   rewriteSlice(slice, layout, convertOp, mapping);
 }
 
-LogicalResult getRematerializableSlice(
-    Value root, Attribute rootEncoding, SetVector<Value> &slice,
+LogicalResult LayoutRematerialization::getRematerializableSlice(
+    OpOperand &root, Attribute rootEncoding, SetVector<Value> &slice,
     DenseMap<Value, Attribute> &layout,
-    std::function<bool(Operation *)> stopPropagation = nullptr) {
-  LogicalResult result = getConvertBackwardSlice(root, slice, rootEncoding,
-                                                 layout, stopPropagation);
+    std::function<bool(Operation *)> stopPropagation) {
+  // Allow re-using existing conversions for a value. Check dominance of any
+  // re-usable materializations against the root value. This is sufficient
+  // because the conversions are processed in post-order.
+  auto getExistingConversion = [&](OpOperand &value, Attribute encoding) {
+    Value remat = getRematValue(value.get(), encoding);
+    if (!remat)
+      return Value();
+    // `value` can be replaced with an existing rematerialization if it
+    // dominates the current use of value.
+    Operation *user = value.getOwner();
+    if (domInfo.properlyDominates(remat, user)) {
+      return remat;
+    }
+    // Alternatively, if the current use can be sunk below the existing
+    // rematerialization, then it is okay to use as well. E.g. the current use
+    // is a conversion that will be folded away when its result is
+    // rematerialized.
+    if (isa<ConvertLayoutOp>(user) && remat.getDefiningOp() &&
+        domInfo.properlyDominates(user, remat.getDefiningOp())) {
+      return remat;
+    }
+    return Value();
+  };
+  LogicalResult result =
+      getConvertBackwardSlice(root, slice, rootEncoding, layout,
+                              stopPropagation, getExistingConversion);
   if (result.failed() || slice.empty())
     return failure();
 
@@ -966,6 +995,12 @@ void LayoutRematerialization::backwardRematerialization() {
       [&](ConvertLayoutOp convertOp) { convertOps.push_back(convertOp); });
   for (ConvertLayoutOp convertOp : convertOps) {
     backwardRematerialization(convertOp);
+    if (!opToDelete.contains(convertOp)) {
+      // If the conversion didn't get removed, consider it for re-use in future
+      // backward slices.
+      addRematValue(convertOp.getSrc(), convertOp.getType().getEncoding(),
+                    convertOp.getResult());
+    }
   }
 }
 
@@ -976,6 +1011,12 @@ void LayoutRematerialization::hoistConvertOnTopOfExtOrBroadcast() {
       [&](ConvertLayoutOp convertOp) { convertOps.push_back(convertOp); });
   for (ConvertLayoutOp convertOp : convertOps) {
     hoistConvertOnTopOfExtOrBroadcast(convertOp);
+    if (!opToDelete.contains(convertOp)) {
+      // If the conversion didn't get removed, consider it for re-use in future
+      // backward slices.
+      addRematValue(convertOp.getSrc(), convertOp.getType().getEncoding(),
+                    convertOp.getResult());
+    }
   }
 }
 
@@ -988,14 +1029,14 @@ void LayoutRematerialization::backwardRematerialization(
   // careful with the heuristics for both correctness and perf
   if (isa<DotOperandEncodingAttr, LinearEncodingAttr>(targetType.getEncoding()))
     return;
-  Value oldV = convertOp->getOperand(0);
+  Value oldV = convertOp.getSrc();
   LDBG("check backward remat with source " << oldV << " encoding "
                                            << targetType.getEncoding());
   // Check to see if there are existing remat'ed values for the pair of oldValue
-  // and encoding.
-  if (hasRematValue(oldV, targetType.getEncoding())) {
+  // and encoding. Make sure it dominates the current conversion.
+  Value newV = getRematValue(oldV, targetType.getEncoding());
+  if (newV && domInfo.properlyDominates(newV, convertOp)) {
     // Replace it with the remat'ed value.
-    Value newV = getRematValue(oldV, targetType.getEncoding());
     convertOp.replaceAllUsesWith(newV);
     opToDelete.insert(convertOp);
     LDBG("found remat'ed value" << newV);
@@ -1007,7 +1048,7 @@ void LayoutRematerialization::backwardRematerialization(
   SetVector<Value> slice;
   DenseMap<Value, Attribute> layout;
   LogicalResult result = getRematerializableSlice(
-      convertOp.getSrc(), targetType.getEncoding(), slice, layout);
+      convertOp.getSrcMutable(), targetType.getEncoding(), slice, layout);
   if (result.failed()) {
     LDBG(" getRematerializableSlice failed");
     return;
@@ -1050,9 +1091,9 @@ void LayoutRematerialization::hoistConvertOnTopOfExtOrBroadcast(
   // 1. Take a backward slice of all the tensor dependencies.
   SetVector<Value> slice;
   DenseMap<Value, Attribute> layout;
-  LogicalResult result =
-      getRematerializableSlice(convertOp.getSrc(), targetType.getEncoding(),
-                               slice, layout, isExtOrBroadcastOp);
+  LogicalResult result = getRematerializableSlice(
+      convertOp.getSrcMutable(), targetType.getEncoding(), slice, layout,
+      isExtOrBroadcastOp);
   if (result.failed())
     return;
 
@@ -1070,7 +1111,7 @@ void LayoutRematerialization::hoistConvertOnTopOfExtOrBroadcast(
     if (!srcEncoding)
       return;
     LogicalResult result = getRematerializableSlice(
-        op->getOperand(0), srcEncoding, tempSlice, tempLayout);
+        op->getOpOperand(0), srcEncoding, tempSlice, tempLayout);
     // If we can rematerialize the rest of the ext slice we can ignore this
     // ext as it won't need a convert.
     if (result.succeeded()) {
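Taken together, the reuse logic added in this file reduces to two legality checks, shown here as a standalone sketch (the function name is illustrative, not from the patch):

    // Can an existing rematerialization `remat` replace the value read by `use`?
    static bool canReuseRemat(mlir::DominanceInfo &domInfo, mlir::Value remat,
                              mlir::OpOperand &use) {
      mlir::Operation *user = use.getOwner();
      // (1) The cached value is computed before the use and visible to it.
      if (domInfo.properlyDominates(remat, user))
        return true;
      // (2) The use is itself a layout conversion that folds away once its
      //     result is rematerialized, so it can conceptually sink below `remat`.
      return mlir::isa<mlir::triton::gpu::ConvertLayoutOp>(user) &&
             remat.getDefiningOp() &&
             domInfo.properlyDominates(user, remat.getDefiningOp());
    }

Concretely: when two conversions lower the same source into the same encoding and the first one survives backwardRematerialization, addRematValue records its result; the backward slice of the second conversion then finds that result via check (1) and reuses it rather than duplicating the sliced computation.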
63 changes: 40 additions & 23 deletions lib/Dialect/TritonGPU/Transforms/Utility.cpp
@@ -770,44 +770,60 @@ static bool isFreeConvert(Operation *op) {
                            convertOp.getType());
 }
 
-LogicalResult
-getConvertBackwardSlice(Value root, SetVector<Value> &slice,
-                        Attribute rootEncoding,
-                        DenseMap<Value, Attribute> &layout,
-                        std::function<bool(Operation *)> stopPropagation) {
-  DenseSet<std::pair<Value, Attribute>> seen;
-  SmallVector<std::pair<Value, Attribute>> queue;
-
-  auto enqueue = [&](Value operand, Attribute encoding) {
-    auto x = std::make_pair(operand, encoding);
+LogicalResult getConvertBackwardSlice(
+    OpOperand &root, SetVector<Value> &slice, Attribute rootEncoding,
+    DenseMap<Value, Attribute> &layout,
+    std::function<bool(Operation *)> stopPropagation,
+    std::function<Value(OpOperand &, Attribute)> getExistingConversion) {
+  DenseSet<std::pair<OpOperand *, Attribute>> seen;
+  SmallVector<std::pair<OpOperand *, Attribute>> queue;
+
+  auto enqueue = [&](OpOperand &operand, Attribute encoding) {
+    auto x = std::make_pair(&operand, encoding);
     if (!seen.insert(x).second) {
       return; // Already enqueued, skip
     }
     queue.push_back(x);
   };
   enqueue(root, rootEncoding);
 
+  auto updateLayout = [&](Value value, Attribute encoding) {
+    assert((isa<RankedTensorType>(value.getType())));
+    slice.insert(value);
+    if (layout.find(value) != layout.end()) {
+      if (layout[value] != encoding)
+        return failure();
+    }
+    layout[value] = encoding;
+    return success();
+  };
+
   while (!queue.empty()) {
-    auto [currentValue, encoding] = queue.back();
+    auto [currentValueUse, encoding] = queue.back();
+    Value currentValue = currentValueUse->get();
     queue.pop_back();
     if (!isa<RankedTensorType>(currentValue.getType()))
      continue;
     // Skip propagating through for op results for now.
     // TODO: enable this based on needs.
     if (currentValue.getDefiningOp<scf::ForOp>())
       return failure();
-    slice.insert(currentValue);
-    if (layout.find(currentValue) != layout.end()) {
-      if (layout[currentValue] != encoding)
-        return failure();
-    }
-    layout[currentValue] = encoding;
+    if (failed(updateLayout(currentValue, encoding)))
+      return failure();
+
+    Value existing;
+    if (getExistingConversion &&
+        (existing = getExistingConversion(*currentValueUse, encoding))) {
+      if (failed(updateLayout(existing, encoding)))
+        return failure();
+      currentValue = existing;
+    }
+
     if (auto ifOp = currentValue.getDefiningOp<scf::IfOp>()) {
       unsigned argIdx = mlir::cast<OpResult>(currentValue).getResultNumber();
 
-      auto thenValue = ifOp.thenYield().getOperand(argIdx);
-      auto elseValue = ifOp.elseYield().getOperand(argIdx);
+      OpOperand &thenValue = ifOp.thenYield()->getOpOperand(argIdx);
+      OpOperand &elseValue = ifOp.elseYield()->getOpOperand(argIdx);
 
       enqueue(thenValue, encoding);
       enqueue(elseValue, encoding);
@@ -819,10 +835,11 @@ getConvertBackwardSlice(Value root, SetVector<Value> &slice,
     for (Value result : definingOp->getResults()) {
       if (result == currentValue || !isa<RankedTensorType>(result.getType()))
         continue;
-      enqueue(result, encoding);
+      if (failed(updateLayout(result, encoding)))
+        return failure();
     }
     if (isFreeConvert(definingOp)) {
-      enqueue(definingOp->getOperand(0), encoding);
+      enqueue(definingOp->getOpOperand(0), encoding);
       continue;
     }
     if (canFoldIntoConversion(definingOp, encoding))
@@ -837,10 +854,10 @@ getConvertBackwardSlice(Value root, SetVector<Value> &slice,
       auto srcEncoding = inferSrcEncoding(gather, encoding);
       if (!srcEncoding)
         return failure();
-      enqueue(gather.getIndices(), srcEncoding);
+      enqueue(gather.getIndicesMutable(), srcEncoding);
       continue;
     }
-    for (auto [i, operand] : llvm::enumerate(definingOp->getOperands())) {
+    for (auto [i, operand] : llvm::enumerate(definingOp->getOpOperands())) {
       auto srcEncoding = inferSrcEncoding(definingOp, encoding);
       if (!srcEncoding)
         return failure();
@@ -853,9 +870,9 @@ getConvertBackwardSlice(Value root, SetVector<Value> &slice,
     Operation *parentOp = block->getParentOp();
     if (auto forOp = dyn_cast<scf::ForOp>(parentOp)) {
       OpOperand *initOperand = forOp.getTiedLoopInit(blockArg);
-      Value yieldOperand = forOp.getBody()->getTerminator()->getOperand(
+      OpOperand &yieldOperand = forOp.getBody()->getTerminator()->getOpOperand(
           blockArg.getArgNumber() - forOp.getNumInductionVars());
-      enqueue(initOperand->get(), encoding);
+      enqueue(*initOperand, encoding);
       enqueue(yieldOperand, encoding);
       continue;
     }
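In skeleton form, the rewritten traversal looks roughly like this (a simplified sketch of the loop above, with the scf.if/scf.for handling, deduplication, and encoding inference elided):

    // The worklist is keyed by uses (OpOperand *) rather than values, which is
    // what lets getExistingConversion make a per-use dominance decision.
    SmallVector<std::pair<OpOperand *, Attribute>> queue{{&root, rootEncoding}};
    while (!queue.empty()) {
      auto [use, encoding] = queue.pop_back_val();
      Value currentValue = use->get();
      // Record the value in the slice; fails on conflicting encoding demands.
      if (failed(updateLayout(currentValue, encoding)))
        return failure();
      if (getExistingConversion) {
        if (Value existing = getExistingConversion(*use, encoding)) {
          if (failed(updateLayout(existing, encoding)))
            return failure();
          currentValue = existing; // keep walking from the reused conversion
        }
      }
      // ...infer a source encoding for currentValue's defining op or block
      // argument and enqueue the corresponding OpOperands...
    }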