Skip to content

Commit

Permalink
[XPU][Alloc] Use upstream interface to specialize Allocation analysis
Browse files Browse the repository at this point in the history
Defie custom scratch memory size getter to specialize Allocation analysis.

Signed-off-by: victor-eds <[email protected]>
  • Loading branch information
victor-eds committed Nov 19, 2024
1 parent 31bfb67 commit 3f682b1
Show file tree
Hide file tree
Showing 11 changed files with 44 additions and 639 deletions.
9 changes: 1 addition & 8 deletions include/triton/Analysis/Allocation.h
Original file line number Diff line number Diff line change
Expand Up @@ -249,9 +249,6 @@ class Allocation {
size_t sharedMemorySize = 0;
};

template <>
void Allocation::run<triton::AllocationAnalysis>(FuncAllocMapT &funcAllocMap);

/// Static analysis that computes the allocation of shared memory buffers
/// of the entire call graph.
/// The allocation is performed in a post-order walk of the call graph.
Expand All @@ -271,11 +268,10 @@ class ModuleAllocation : public CallGraph<Allocation> {
[](CallOpInterface callOp, FunctionOpInterface funcOp) {},
// Post-order node walk callback
[&](FunctionOpInterface funcOp) {
auto [iter, inserted] = res.funcMap.try_emplace(funcOp, funcOp);
auto [iter, inserted] = funcMap.try_emplace(funcOp, funcOp);
if (inserted)
iter->second.run(funcMap, scratchSizeGetter);
});
return res;
}

size_t getSharedMemorySize() {
Expand All @@ -300,9 +296,6 @@ class ModuleAllocation : public CallGraph<Allocation> {
}

private:
explicit ModuleAllocation(ModuleOp moduleOp)
: CallGraph<Allocation>(moduleOp) {}

FuncOffsetMapT sharedMemoryValue;
};

Expand Down
2 changes: 1 addition & 1 deletion lib/Conversion/TritonGPUToLLVM/AllocateSharedMemory.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ struct AllocateSharedMemory
void runOnOperation() override {
ModuleOp mod = getOperation();
MLIRContext *ctx = &getContext();
ModuleAllocation allocation = ModuleAllocation::get(mod);
ModuleAllocation allocation(mod);

mod.walk([&](FunctionOpInterface funcOp) {
funcOp.walk([&](Operation *op) {
Expand Down
3 changes: 1 addition & 2 deletions python/src/passes.cc
Original file line number Diff line number Diff line change
Expand Up @@ -17,8 +17,7 @@ namespace py = pybind11;

void init_triton_analysis(py::module &&m) {
py::class_<mlir::ModuleAllocation>(m, "allocation", py::module_local())
.def(py::init(
&mlir::ModuleAllocation::get<mlir::triton::AllocationAnalysis>));
.def(py::init<mlir::ModuleOp>());
py::class_<mlir::ModuleMembarAnalysis>(m, "membar", py::module_local())
.def(py::init<mlir::ModuleAllocation *>())
.def("run", &mlir::ModuleMembarAnalysis::run);
Expand Down
2 changes: 1 addition & 1 deletion test/lib/Analysis/TestMembar.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ struct TestMembarPass
Operation *operation = getOperation();
ModuleOp moduleOp = cast<ModuleOp>(operation);
// Print all ops after membar pass
ModuleAllocation allocation = ModuleAllocation::get(moduleOp);
ModuleAllocation allocation(moduleOp);
ModuleMembarAnalysis membarPass(&allocation,
mlir::triton::NVIDIA::canSkipBarSync);
membarPass.run();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -231,7 +231,7 @@ class OptimizeAMDLDSUsage
LDSLimit = targetInfo.getSharedMemorySize();
}

ModuleAllocation allocAnalysis = ModuleAllocation::get(mod);
ModuleAllocation allocAnalysis(mod);
if (allocAnalysis.getSharedMemorySize() <= LDSLimit)
return;

Expand Down
2 changes: 1 addition & 1 deletion third_party/amd/lib/TritonAMDGPUToLLVM/TritonGPUToLLVM.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -107,7 +107,7 @@ struct ConvertTritonAMDGPUToLLVM
}

// Allocate shared memory and set barrier
ModuleAllocation allocation = ModuleAllocation::get(mod);
ModuleAllocation allocation(mod);
ModuleMembarAnalysis membarPass(&allocation);
membarPass.run();

Expand Down
11 changes: 3 additions & 8 deletions third_party/intel/include/Analysis/Allocation.h
Original file line number Diff line number Diff line change
Expand Up @@ -3,13 +3,8 @@

#include "triton/Analysis/Allocation.h"

namespace mlir {
namespace triton::intel {
class AllocationAnalysis;
} // namespace triton::intel
template <>
void Allocation::run<triton::intel::AllocationAnalysis>(
FuncAllocMapT &funcAllocMap);
} // namespace mlir
namespace mlir::triton::intel {
unsigned allocationAnalysisScratchSizeFn(Operation *op);
} // namespace mlir::triton::intel

#endif
Loading

0 comments on commit 3f682b1

Please sign in to comment.