Adapt slm specialization #2735

Merged 2 commits on Nov 19, 2024
29 changes: 15 additions & 14 deletions include/triton/Analysis/Allocation.h
@@ -18,6 +18,12 @@ namespace mlir {
namespace triton {
class AllocationAnalysis;

/// Callback to allow backends to specify target-specific scratch sizes for
/// some operations.
using AllocationAnalysisScratchSizeFn = std::function<unsigned(Operation *)>;

unsigned defaultAllocationAnalysisScratchSizeFn(Operation *op);

// To convert a tensor from one layout to another, we need to allocate a
// temporary buffer (i.e., scratch buffer) in shared memory. The conversion may
// require multiple iterations, with each iteration involving multiple
@@ -141,7 +147,8 @@ class Allocation {
explicit Allocation(Operation *operation) : operation(operation) {}

/// Runs allocation analysis on the given top-level operation.
template <typename AllocationAnalysis> void run(FuncAllocMapT &funcAllocMap);
void run(FuncAllocMapT &funcAllocMap,
triton::AllocationAnalysisScratchSizeFn scratchSizeGetter);

/// Returns the operation this analysis was constructed from.
Operation *getOperation() const { return operation; }
@@ -242,9 +249,6 @@ class Allocation {
size_t sharedMemorySize = 0;
};

template <>
void Allocation::run<triton::AllocationAnalysis>(FuncAllocMapT &funcAllocMap);

/// Static analysis that computes the allocation of shared memory buffers
/// of the entire call graph.
/// The allocation is performed in a post-order walk of the call graph.
@@ -255,19 +259,19 @@ class ModuleAllocation : public CallGraph<Allocation> {
public:
using FuncOffsetMapT = DenseMap<FunctionOpInterface, Value>;

template <typename AllocationAnalysis = triton::AllocationAnalysis>
static ModuleAllocation get(ModuleOp moduleOp) {
ModuleAllocation res(moduleOp);
res.walk<WalkOrder::PreOrder, WalkOrder::PostOrder>(
ModuleAllocation(ModuleOp moduleOp,
triton::AllocationAnalysisScratchSizeFn scratchSizeGetter =
triton::defaultAllocationAnalysisScratchSizeFn)
: CallGraph<Allocation>(moduleOp) {
walk<WalkOrder::PreOrder, WalkOrder::PostOrder>(
// Pre-order edge walk callback
[](CallOpInterface callOp, FunctionOpInterface funcOp) {},
// Post-order node walk callback
[&](FunctionOpInterface funcOp) {
auto [iter, inserted] = res.funcMap.try_emplace(funcOp, funcOp);
auto [iter, inserted] = funcMap.try_emplace(funcOp, funcOp);
if (inserted)
iter->second.template run<AllocationAnalysis>(res.funcMap);
iter->second.run(funcMap, scratchSizeGetter);
});
return res;
}

size_t getSharedMemorySize() {
@@ -292,9 +296,6 @@ class ModuleAllocation : public CallGraph<Allocation> {
}

private:
explicit ModuleAllocation(ModuleOp moduleOp)
: CallGraph<Allocation>(moduleOp) {}

FuncOffsetMapT sharedMemoryValue;
};

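The new ModuleAllocation constructor above takes an optional scratch-size callback, so a backend can adjust per-op scratch requirements without the template specialization that was removed. A minimal sketch of how such a callback might be wired in, assuming a hypothetical myBackendScratchSizeFn and using the "tt.histogram" op name purely for illustration:

#include "triton/Analysis/Allocation.h"
#include "llvm/Support/raw_ostream.h"

using namespace mlir;

// Hypothetical backend callback: override the scratch size for one op kind and
// defer to the upstream default for everything else.
static unsigned myBackendScratchSizeFn(Operation *op) {
  if (op->getName().getStringRef() == "tt.histogram")
    return 4096; // assumed backend-specific requirement, illustration only
  return triton::defaultAllocationAnalysisScratchSizeFn(op);
}

static void analyzeModule(ModuleOp mod) {
  // The second argument defaults to defaultAllocationAnalysisScratchSizeFn.
  ModuleAllocation allocation(mod, myBackendScratchSizeFn);
  llvm::outs() << "shared memory: " << allocation.getSharedMemorySize()
               << " bytes\n";
}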
140 changes: 71 additions & 69 deletions lib/Analysis/Allocation.cpp
@@ -118,13 +118,70 @@ ScratchConfig getScratchConfigForCvt(RankedTensorType srcTy,
return scratchConfig;
}

unsigned defaultAllocationAnalysisScratchSizeFn(Operation *op) {
if (auto reduceOp = dyn_cast<ReduceOp>(op)) {
ReduceOpHelper helper(reduceOp);
return helper.getScratchSizeInBytes();
}
if (auto scanOp = dyn_cast<ScanOp>(op)) {
ScanLoweringHelper helper(scanOp);
return helper.getScratchSizeInBytes();
}
if (auto histogram = dyn_cast<HistogramOp>(op)) {
auto dstTy = histogram.getType();
int threadsPerWarp = gpu::TritonGPUDialect::getThreadsPerWarp(
op->getParentOfType<ModuleOp>());
return std::max<int>(dstTy.getNumElements(), threadsPerWarp) *
std::max<int>(8, dstTy.getElementTypeBitWidth()) / 8;
}
if (auto cvtLayout = dyn_cast<gpu::ConvertLayoutOp>(op)) {
auto srcTy = cvtLayout.getSrc().getType();
auto dstTy = cvtLayout.getType();
auto srcEncoding = srcTy.getEncoding();
auto dstEncoding = dstTy.getEncoding();
if (mlir::isa<gpu::SharedEncodingAttr>(srcEncoding) ||
mlir::isa<gpu::SharedEncodingAttr>(dstEncoding)) {
// Conversions from/to shared memory do not need scratch memory.
return 0;
}
// ConvertLayoutOp with both input/output non-shared_layout
// TODO: Besides of implementing ConvertLayoutOp via shared memory, it's
// also possible to realize it with other approaches in restricted
// conditions, such as warp-shuffle
auto scratchConfig = getScratchConfigForCvt(srcTy, dstTy);
auto elems = getNumScratchElements(scratchConfig.paddedRepShape);
return isa<PointerType>(srcTy.getElementType())
? elems * kPtrBitWidth / 8
: elems * std::max<int>(8, srcTy.getElementTypeBitWidth()) / 8;
}
if (isa<AtomicRMWOp, AtomicCASOp>(op)) {
auto value = op->getOperand(0);
// only scalar requires scratch memory
// make it explicit for readability
if (dyn_cast<RankedTensorType>(value.getType())) {
return 0;
}
auto smemShape = getRepShapeForAtomic(op->getResult(0));
auto elems = getNumScratchElements(smemShape);
auto elemTy = cast<PointerType>(value.getType()).getPointeeType();
assert(!isa<PointerType>(elemTy) && "unexpected pointer type");
return elems * std::max<int>(8, elemTy.getIntOrFloatBitWidth()) / 8;
}
if (auto createTensormap = dyn_cast<ExperimentalTensormapCreateOp>(op)) {
constexpr int32_t kTMASize = 128;
return kTMASize;
}
return 0;
}

class AllocationAnalysis {
public:
AllocationAnalysis(Operation *operation,
Allocation::FuncAllocMapT *funcAllocMap,
Allocation *allocation)
Allocation *allocation,
AllocationAnalysisScratchSizeFn scratchSizeGetter)
: operation(operation), funcAllocMap(funcAllocMap),
allocation(allocation) {
allocation(allocation), scratchSizeGetter(scratchSizeGetter) {
run();
}

@@ -177,77 +234,19 @@ class AllocationAnalysis {

/// Initializes temporary shared memory for a given operation.
void getScratchValueSize(Operation *op) {
const size_t scratchAlignment = 128;
if (auto reduceOp = dyn_cast<ReduceOp>(op)) {
ReduceOpHelper helper(reduceOp);
unsigned bytes = helper.getScratchSizeInBytes();
maybeAddScratchBuffer<BufferT::BufferKind::Scratch>(op, bytes,
scratchAlignment);
} else if (auto scanOp = dyn_cast<ScanOp>(op)) {
ScanLoweringHelper helper(scanOp);
unsigned bytes = helper.getScratchSizeInBytes();
maybeAddScratchBuffer<BufferT::BufferKind::Scratch>(op, bytes,
scratchAlignment);
} else if (auto histogram = dyn_cast<HistogramOp>(op)) {
auto dstTy = histogram.getType();
int threadsPerWarp = gpu::TritonGPUDialect::getThreadsPerWarp(
op->getParentOfType<ModuleOp>());
auto bytes = std::max<int>(dstTy.getNumElements(), threadsPerWarp) *
std::max<int>(8, dstTy.getElementTypeBitWidth()) / 8;
maybeAddScratchBuffer<BufferT::BufferKind::Scratch>(op, bytes,
scratchAlignment);
} else if (auto cvtLayout = dyn_cast<gpu::ConvertLayoutOp>(op)) {
auto srcTy = cvtLayout.getSrc().getType();
auto dstTy = cvtLayout.getType();
auto srcEncoding = srcTy.getEncoding();
auto dstEncoding = dstTy.getEncoding();
if (mlir::isa<gpu::SharedEncodingAttr>(srcEncoding) ||
mlir::isa<gpu::SharedEncodingAttr>(dstEncoding)) {
// Conversions from/to shared memory do not need scratch memory.
return;
}
// ConvertLayoutOp with both input/output non-shared_layout
// TODO: Besides of implementing ConvertLayoutOp via shared memory, it's
// also possible to realize it with other approaches in restricted
// conditions, such as warp-shuffle
auto scratchConfig = getScratchConfigForCvt(srcTy, dstTy);
auto elems = getNumScratchElements(scratchConfig.paddedRepShape);
auto bytes =
isa<PointerType>(srcTy.getElementType())
? elems * kPtrBitWidth / 8
: elems * std::max<int>(8, srcTy.getElementTypeBitWidth()) / 8;
maybeAddScratchBuffer<BufferT::BufferKind::Scratch>(op, bytes,
scratchAlignment);
} else if (isa<AtomicRMWOp, AtomicCASOp>(op)) {
auto value = op->getOperand(0);
// only scalar requires scratch memory
// make it explicit for readability
if (dyn_cast<RankedTensorType>(value.getType())) {
// nothing to do
} else {
auto smemShape = getRepShapeForAtomic(op->getResult(0));
auto elems = getNumScratchElements(smemShape);
auto elemTy = cast<PointerType>(value.getType()).getPointeeType();
assert(!isa<PointerType>(elemTy) && "unexpected pointer type");
auto bytes =
elems * std::max<int>(8, elemTy.getIntOrFloatBitWidth()) / 8;
maybeAddScratchBuffer<BufferT::BufferKind::Scratch>(op, bytes,
scratchAlignment);
}
} else if (auto callOp = dyn_cast<CallOpInterface>(op)) {
constexpr size_t scratchAlignment = 128;
if (auto callOp = dyn_cast<CallOpInterface>(op)) {
auto callable = callOp.resolveCallable();
auto funcOp = dyn_cast<FunctionOpInterface>(callable);
auto *funcAlloc = &(*funcAllocMap)[funcOp];
auto bytes = funcAlloc->getSharedMemorySize();
maybeAddScratchBuffer<BufferT::BufferKind::Virtual>(op, bytes,
scratchAlignment);
} else if (auto createTensormap =
dyn_cast<ExperimentalTensormapCreateOp>(op)) {
constexpr int32_t kTMASize = 128;
constexpr int32_t kTMAAlign = 128;
maybeAddScratchBuffer<BufferT::BufferKind::Scratch>(op, kTMASize,
kTMAAlign);
return;
}
unsigned bytes = scratchSizeGetter(op);
maybeAddScratchBuffer<BufferT::BufferKind::Scratch>(op, bytes,
scratchAlignment);
}

void getValueAlias(Value value, SharedMemoryAliasAnalysis &analysis) {
@@ -547,13 +546,16 @@ class AllocationAnalysis {
Allocation::FuncAllocMapT *funcAllocMap;
Allocation *allocation;
BufferRangeMapT bufferRange;
AllocationAnalysisScratchSizeFn scratchSizeGetter;
};

} // namespace triton

template <>
void Allocation::run<triton::AllocationAnalysis>(FuncAllocMapT &funcAllocMap) {
triton::AllocationAnalysis(getOperation(), &funcAllocMap, this);
void Allocation::run(
FuncAllocMapT &funcAllocMap,
triton::AllocationAnalysisScratchSizeFn scratchSizeGetter) {
triton::AllocationAnalysis(getOperation(), &funcAllocMap, this,
scratchSizeGetter);
}

std::map<Operation *, SmallVector<Allocation::BufferId>>
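For reference, the hoisted defaultAllocationAnalysisScratchSizeFn covers the same op set as the old member logic; only CallOpInterface, which allocates a virtual buffer sized from the callee, is still handled directly in getScratchValueSize. As a rough worked example of the histogram branch, a 64-element i32 result with 32 threads per warp would get max(64, 32) * max(8, 32) / 8 = 256 bytes. A small sketch, purely for illustration, that queries the default directly instead of going through Allocation::run:

#include "mlir/IR/BuiltinOps.h"
#include "triton/Analysis/Allocation.h"
#include "llvm/Support/raw_ostream.h"

// Walk a module and print the default per-op scratch estimate. Real clients
// construct a ModuleAllocation instead, which also assigns buffer offsets.
static void dumpScratchSizes(mlir::ModuleOp mod) {
  mod.walk([](mlir::Operation *op) {
    unsigned bytes = mlir::triton::defaultAllocationAnalysisScratchSizeFn(op);
    if (bytes > 0)
      llvm::outs() << op->getName() << ": " << bytes << " scratch bytes\n";
  });
}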
2 changes: 1 addition & 1 deletion lib/Conversion/TritonGPUToLLVM/AllocateSharedMemory.cpp
@@ -23,7 +23,7 @@ struct AllocateSharedMemory
void runOnOperation() override {
ModuleOp mod = getOperation();
MLIRContext *ctx = &getContext();
ModuleAllocation allocation = ModuleAllocation::get(mod);
ModuleAllocation allocation(mod);

mod.walk([&](FunctionOpInterface funcOp) {
funcOp.walk([&](Operation *op) {
3 changes: 1 addition & 2 deletions python/src/passes.cc
@@ -17,8 +17,7 @@ namespace py = pybind11;

void init_triton_analysis(py::module &&m) {
py::class_<mlir::ModuleAllocation>(m, "allocation", py::module_local())
.def(py::init(
&mlir::ModuleAllocation::get<mlir::triton::AllocationAnalysis>));
.def(py::init<mlir::ModuleOp>());
py::class_<mlir::ModuleMembarAnalysis>(m, "membar", py::module_local())
.def(py::init<mlir::ModuleAllocation *>())
.def("run", &mlir::ModuleMembarAnalysis::run);
7 changes: 7 additions & 0 deletions test/Analysis/test-allocation.mlir
@@ -1,4 +1,11 @@
// RUN: triton-opt %s -split-input-file --mlir-disable-threading -test-print-allocation 2>&1 | FileCheck %s
// RUN: triton-opt %s -split-input-file --mlir-disable-threading -test-print-allocation="get-scratch-size-function=ValidConstant" 2>&1 | FileCheck %s --check-prefix=CHECK-128

// Check there are no lines with a size different to 128 and we have at least a line with size 128.

// CHECK-128-NOT: scratch offset = {{.*}}, size = {{^(128)}}
// CHECK-128: scratch offset = {{.*}}, size = 128
// CHECK-128-NOT: scratch offset = {{.*}}, size = {{^(128)}}

#AL = #triton_gpu.blocked<{sizePerThread = [1, 4], threadsPerWarp = [4, 8], warpsPerCTA = [4, 1], order = [1, 0]}>
#sliceAd0 = #triton_gpu.slice<{dim = 0, parent = #AL}>
32 changes: 31 additions & 1 deletion test/lib/Analysis/TestAllocation.cpp
@@ -5,21 +5,42 @@ using namespace mlir;

namespace {

unsigned getScratchSize128(Operation *) { return 128; }

enum class GetScratchSizeFunction {
None,
ValidConstant,
};

struct TestAllocationPass
: public PassWrapper<TestAllocationPass, OperationPass<ModuleOp>> {

MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestAllocationPass);

TestAllocationPass() = default;
TestAllocationPass(const TestAllocationPass &other)
: PassWrapper<TestAllocationPass, OperationPass<ModuleOp>>(other) {}

StringRef getArgument() const final { return "test-print-allocation"; }
StringRef getDescription() const final {
return "print the result of the allocation pass";
}

ModuleAllocation getModuleAllocation() {
switch (getScratchSizeFunction) {
case GetScratchSizeFunction::None:
return {getOperation()};
case GetScratchSizeFunction::ValidConstant:
return {getOperation(), getScratchSize128};
}
llvm_unreachable("Unhandled case");
}

void runOnOperation() override {
auto &os = llvm::errs();
ModuleOp moduleOp = getOperation();
// Convert to std::string can remove quotes from opName
ModuleAllocation moduleAllocation = ModuleAllocation::get(moduleOp);
ModuleAllocation moduleAllocation = getModuleAllocation();
moduleOp.walk([&](triton::FuncOp funcOp) {
auto opName = SymbolTable::getSymbolName(funcOp).getValue().str();
os << opName << "\n";
@@ -48,6 +69,15 @@ struct TestAllocationPass
os << "size = " << allocation->getSharedMemorySize() << "\n";
});
}

Option<GetScratchSizeFunction> getScratchSizeFunction{
*this, "get-scratch-size-function",
llvm::cl::desc("Custom scratch size function to use"),
llvm::cl::init(GetScratchSizeFunction::None),
llvm::cl::values(
clEnumValN(GetScratchSizeFunction::None, "None", "None (default)"),
clEnumValN(GetScratchSizeFunction::ValidConstant, "ValidConstant",
"ValidConstant"))};
};

} // namespace
2 changes: 1 addition & 1 deletion test/lib/Analysis/TestMembar.cpp
@@ -25,7 +25,7 @@ struct TestMembarPass
Operation *operation = getOperation();
ModuleOp moduleOp = cast<ModuleOp>(operation);
// Print all ops after membar pass
ModuleAllocation allocation = ModuleAllocation::get(moduleOp);
ModuleAllocation allocation(moduleOp);
ModuleMembarAnalysis membarPass(&allocation,
mlir::triton::NVIDIA::canSkipBarSync);
membarPass.run();
@@ -231,7 +231,7 @@ class OptimizeAMDLDSUsage
LDSLimit = targetInfo.getSharedMemorySize();
}

ModuleAllocation allocAnalysis = ModuleAllocation::get(mod);
ModuleAllocation allocAnalysis(mod);
if (allocAnalysis.getSharedMemorySize() <= LDSLimit)
return;

2 changes: 1 addition & 1 deletion third_party/amd/lib/TritonAMDGPUToLLVM/TritonGPUToLLVM.cpp
@@ -107,7 +107,7 @@ struct ConvertTritonAMDGPUToLLVM
}

// Allocate shared memory and set barrier
ModuleAllocation allocation = ModuleAllocation::get(mod);
ModuleAllocation allocation(mod);
ModuleMembarAnalysis membarPass(&allocation);
membarPass.run();

11 changes: 3 additions & 8 deletions third_party/intel/include/Analysis/Allocation.h
@@ -3,13 +3,8 @@

#include "triton/Analysis/Allocation.h"

namespace mlir {
namespace triton::intel {
class AllocationAnalysis;
} // namespace triton::intel
template <>
void Allocation::run<triton::intel::AllocationAnalysis>(
FuncAllocMapT &funcAllocMap);
} // namespace mlir
namespace mlir::triton::intel {
unsigned allocationAnalysisScratchSizeFn(Operation *op);
} // namespace mlir::triton::intel

#endif
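Downstream, the Intel backend can now pass this callback to the shared analysis instead of specializing Allocation::run. A minimal sketch of the assumed usage (not part of this diff):

#include "triton/Analysis/Allocation.h"

namespace mlir::triton::intel {
unsigned allocationAnalysisScratchSizeFn(Operation *op); // declared above
} // namespace mlir::triton::intel

// Hypothetical Intel lowering step: build the allocation analysis with the
// backend-specific scratch-size function.
static void runIntelAllocation(mlir::ModuleOp mod) {
  mlir::ModuleAllocation allocation(
      mod, mlir::triton::intel::allocationAnalysisScratchSizeFn);
  size_t slmBytes = allocation.getSharedMemorySize();
  (void)slmBytes; // shared local memory budget checks would go here
}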