Skip to content

Commit

Permalink
Revert "Revert "[LAYOUTS] Implement IR support for LinearLayouts (#51…
Browse files Browse the repository at this point in the history
…70)""

This reverts commit 7b5daa4.
  • Loading branch information
whitneywhtsang committed Nov 22, 2024
1 parent b5a791e commit 5bbce9e
Show file tree
Hide file tree
Showing 14 changed files with 785 additions and 197 deletions.
65 changes: 65 additions & 0 deletions include/triton/Dialect/TritonGPU/IR/Dialect.h
Original file line number Diff line number Diff line change
Expand Up @@ -149,6 +149,71 @@ triton::gpu::BlockedEncodingAttr
getDefaultBlockedEncoding(MLIRContext *context, ArrayRef<int64_t> shape,
int numWarps, int threadsPerWarp, int numCTAs);

// For each output dimension d, ensure that the layout's output size (i.e., its
// codomain) does not exceed shape[d]. Do this without changing the size of the
// layout's inputs (i.e., leave its domain unchanged).
//
// This function is invariant to the order of the layout's input and output
// dimensions.
//
// We achieve this by setting the largest value in each output dimension d to 0
// because bases that map to a location larger than shape[d]
// effectively duplicate along that dimension. For example, consider a layout
// with an output dimension size of 32, and we call ensureLayoutNotLargerThan to
// shrink the output dimension size to 8:
//
//   L(register=1) = 8
//   L(register=2) = 4
//   L(register=4) = 1
//   L(lane=1) = 2
//   L(lane=2) = 16
//
// In the first step, we shrink the output dimension size to 16 by setting
// L(lane=2) to 0:
//
//   L(register=1) = 8
//   L(register=2) = 4
//   L(register=4) = 1
//   L(lane=1) = 2
//   L(lane=2) = 0
//
// This means that lane=2 has the same data as lane=0.
//
// Now the output dimension of this layout has a size of 16, which is still
// larger than 8. We find the current largest value in the output dimension,
// which is L(register=1) = 8, and we set L(register=1) to 0:
//
//   L(register=1) = 0
//   L(register=2) = 4
//   L(register=4) = 1
//   L(lane=1) = 2
//   L(lane=2) = 0
//
// Now the output dimension of this layout has a size of 8, which is the desired
// size. Note that this method works only because the bases are powers of two,
// which is the case for DistributedLayouts.
//
// If broadcastRegisters is false, we instead remove any register basis that
// maps to a location larger than the desired shape (rather than zeroing it,
// which would broadcast). In the example above we would have
//   L(register=1) = 4
//   L(register=2) = 1
//   L(lane=1) = 2
//   L(lane=2) = 0
LinearLayout
ensureLayoutNotLargerThan(const LinearLayout &layout,
                          const llvm::SmallDenseMap<StringAttr, int64_t> &shape,
                          bool broadcastRegisters = true);

// For each out-dim d, ensure the layout's out-size (i.e. its codomain) is no
// smaller than shape[d]. Do this by increasing the size of the layout's inputs
// along its most-minor dimension ("register" for register layouts, "offset" for
// shared layouts), leaving all other input dimensions untouched.
//
// This function is invariant to the order of the layout's input dimensions, but
// it cares about the order of the output dims, which should be minor-to-major.
LinearLayout ensureLayoutNotSmallerThan(
    const LinearLayout &layout,
    const llvm::SmallDenseMap<StringAttr, int64_t> &shape);

// Dump information about which threads/registers contain each of the tensor
// elements.
void dumpLayout(RankedTensorType tensorType);
Expand Down
30 changes: 28 additions & 2 deletions include/triton/Dialect/TritonGPU/IR/TritonGPUAttrDefs.td
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,6 @@ Right now, Triton implements two main classes of layouts: shared, and distribute
code extraBaseClassDeclaration = [{
unsigned getTotalElemsPerThread(ArrayRef<int64_t> shape, Type eltTy) const;
SmallVector<unsigned> getElemsPerThread(ArrayRef<int64_t> shape, Type eltTy) const;
::mlir::LogicalResult verifyLayoutForArg(::mlir::Operation* op, unsigned argNo) const;
}];
}

Expand Down Expand Up @@ -147,7 +146,6 @@ addition, if there's only one CTA per CGA, then Triton canonicalizes CTAOrder to
let genVerifyDecl = 1;
let skipDefaultBuilders = 1;
}

//===----------------------------------------------------------------------===//
// Shared Layout Encoding
//===----------------------------------------------------------------------===//
Expand Down Expand Up @@ -565,6 +563,34 @@ L(T) = [ {0,8} , {1,9} , {2,10}, {3,11}, {0,8} , {1, 9} , {2, 10}, {3, 11},
}];
}

//===----------------------------------------------------------------------===//
// Linear Layout Encoding
//===----------------------------------------------------------------------===//

// Distributed encoding whose thread/element mapping is given directly by a
// LinearLayout value rather than by structured parameters (sizePerThread,
// order, etc.) as in BlockedEncodingAttr.
def LinearEncodingAttr : DistributedEncoding<"LinearEncoding", "linear_encoding"> {
let mnemonic = "linear";

let description = [{
See the docs in LinearLayout.h for the definition of linear layouts.
}];

// The layout is stored directly as a C++ LinearLayout object.
let parameters = (ins "LinearLayout":$linearLayout);

let extraClassDeclaration = extraDistributedDeclaration # [{
SmallVector<unsigned> getContigPerThread() const;
SmallVector<unsigned> getOrder() const;
}];

// Emit a C++ verify() declaration so the attribute's invariants can be
// checked at construction time.
let genVerifyDecl = 1;
// Parsing/printing is hand-written (hasCustomAssemblyFormat = 1).
// Example of assembly format:
// <{register = [[0, 1], [8, 0], [0, 8], [64, 0]],
// lane = [[0, 2], [0, 4], [1, 0], [2, 0], [4, 0]],
// warp = [[16, 0], [32, 0]],
// block = []}>
let hasCustomAssemblyFormat = 1;
}


//===----------------------------------------------------------------------===//
// Blocked Layout Encoding
//===----------------------------------------------------------------------===//
Expand Down
3 changes: 3 additions & 0 deletions include/triton/Tools/LinearLayout.h
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
#include <vector>

#include "mlir/IR/BuiltinAttributes.h"
#include "llvm/ADT/Hashing.h"
#include "llvm/ADT/MapVector.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/SetVector.h"
Expand Down Expand Up @@ -432,6 +433,7 @@ class LinearLayout {
// (e.g. by reshaping) then the order doesn't really affect anything.
auto getInDimNames() const { return llvm::make_first_range(bases); }
auto getOutDimNames() const { return llvm::make_first_range(outDims); }
auto getOutDimSizes() const { return llvm::make_second_range(outDims); }

// Gets the position that this outDim occupies in getOutDimNames(). Asserts
// if the dim is not present.
Expand Down Expand Up @@ -693,6 +695,7 @@ class LinearLayout {
return !(lhs == rhs);
}
bool equalIgnoringOutDimSizes(const LinearLayout &other) const;
friend size_t hash_value(const LinearLayout &layout);

private:
// Factory function that gracefully fails rather than asserts if the layout is
Expand Down
3 changes: 3 additions & 0 deletions lib/Conversion/TritonGPUToLLVM/ConvertLayoutOpToLLVM.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -397,6 +397,9 @@ struct ConvertLayoutOpUsingLinearLayoutsConversion
if (isa<BlockedEncodingAttr>(layout)) {
return true;
}
if (isa<LinearEncodingAttr>(layout)) {
return true;
}
if (auto slice = dyn_cast<SliceEncodingAttr>(layout)) {
return layoutIsOK(slice.getParent());
}
Expand Down
5 changes: 2 additions & 3 deletions lib/Conversion/TritonGPUToLLVM/MemoryOpToLLVM.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -165,8 +165,8 @@ struct LocalLoadOpConversion : public ConvertOpToLLVMPattern<LocalLoadOp> {
Attribute srcLayout = srcTy.getEncoding();
Attribute dstLayout = dstTy.getEncoding();
if (isa<SharedEncodingAttr>(srcLayout) &&
(isa<BlockedEncodingAttr, MmaEncodingTrait, SliceEncodingAttr>(
dstLayout) ||
(isa<BlockedEncodingAttr, MmaEncodingTrait, SliceEncodingAttr,
LinearEncodingAttr>(dstLayout) ||
isSupportedDotOpLayout(dstTy))) {
return lowerSharedToDistributed(op, adaptor, getTypeConverter(),
rewriter);
Expand Down Expand Up @@ -206,7 +206,6 @@ struct LocalLoadOpConversion : public ConvertOpToLLVMPattern<LocalLoadOp> {
auto dstTy = op.getResult().getType();
auto dstShape = dstTy.getShape();
auto srcSharedLayout = cast<SharedEncodingAttr>(srcTy.getEncoding());
auto dstLayout = dstTy.getEncoding();
assert((dstShape.size() <= 2 || isSupportedDotOpLayout(dstTy)) &&
"Unexpected rank of ConvertLayout(shared->distributed)");

Expand Down
Loading

0 comments on commit 5bbce9e

Please sign in to comment.