Skip to content

Commit

Permalink
add xpu option to enable advanced path
Browse files Browse the repository at this point in the history
  • Loading branch information
Dewei-Wang-sh committed Nov 18, 2024
1 parent e8b34a0 commit c427499
Show file tree
Hide file tree
Showing 6 changed files with 22 additions and 16 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -214,6 +214,7 @@ def forward(q, k, v, causal, sm_scale):
num_warps=num_warps, #
num_stages=num_stages, #
grf_mode='large', #
advanced_path=True,
)
return o

Expand Down
5 changes: 3 additions & 2 deletions third_party/intel/backend/compiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,7 @@ class XPUOptions:
backend_name: str = 'intel'
sanitize_overflow: bool = False
generate_native_code: bool = False
advanced_path: bool = False

def __post_init__(self):
default_libdir = Path(__file__).parent / 'lib'
Expand Down Expand Up @@ -233,7 +234,7 @@ def make_ttgir(mod, metadata, opt, properties):
pm.enable_debug()

if (properties["has_subgroup_2d_block_io"] and properties["has_subgroup_matrix_multiply_accumulate"]
and os.getenv("TRITON_INTEL_ADVANCED_PATH", "0") == "1"):
and (os.getenv("TRITON_INTEL_ADVANCED_PATH", "0") == "1" or opt.advanced_path)):
return XPUBackend.AdvancedPath.make_ttgir(mod, metadata, opt)

passes.ttir.add_convert_to_ttgpuir(pm, "xpu", opt.num_warps, opt.threads_per_warp, opt.num_ctas)
Expand Down Expand Up @@ -291,7 +292,7 @@ def make_llir(src, metadata, options):
# being used, e.g., convert_layout.
if os.getenv("TRITON_INTEL_REDUCE_TRANSPOSE", "0") != "1":
intel.passes.ttgpuir.add_allocate_shared_memory(pm)
intel.passes.ttgpuir.add_to_llvmir(pm)
intel.passes.ttgpuir.add_to_llvmir(pm, options.advanced_path)
intel.set_fast_math(mod)
passes.convert.add_arith_to_llvmir(pm)
passes.common.add_canonicalizer(pm)
Expand Down
5 changes: 5 additions & 0 deletions third_party/intel/include/TritonIntelGPUToLLVM/Passes.td
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,11 @@ def ConvertTritonIntelGPUToLLVM
"mlir::triton::TritonDialect",
"mlir::triton::gpu::TritonGPUDialect",
"mlir::triton::TritonGEN::TritonGENDialect"];
// Pass options. `advancedPath` is a bool option exposed under the argument
// name `advanced_path`; it defaults to false.
let options = [
Option<"advancedPath", "advanced_path",
"bool", /*default*/"false",
"enable advanced path">,
];
}

#endif // TRITONINTELGPU_CONVERSION_PASSES
10 changes: 2 additions & 8 deletions third_party/intel/lib/TritonIntelGPUToLLVM/PipelineManager.h
Original file line number Diff line number Diff line change
Expand Up @@ -180,14 +180,8 @@ struct AddSPIRVEnvPattern : public mlir::OpRewritePattern<ModuleOp> {
/// block pointers or not.
class TritonGPUToLLVMPipelineManager {
public:
TritonGPUToLLVMPipelineManager(ModuleOp &mod, MLIRContext *ctx)
: mod(mod), ctx(ctx),
isAdvancedPathEnabled(
mod->hasAttr(gpu::intel::TritonIntelGPUDialect::
getSupportSG2DBlockAttrName()) &&
mod->hasAttr(
gpu::intel::TritonIntelGPUDialect::getSupportDPASAttrName()) &&
mlir::triton::tools::getBoolEnv("TRITON_INTEL_ADVANCED_PATH")) {}
/// Builds the pipeline manager for module `mod` in context `ctx`.
/// `advanced` is stored unchanged in `isAdvancedPathEnabled`; this constructor
/// does not itself inspect module attributes or environment variables, so the
/// caller decides whether the advanced path applies.
TritonGPUToLLVMPipelineManager(ModuleOp &mod, MLIRContext *ctx, bool advanced)
: mod(mod), ctx(ctx), isAdvancedPathEnabled(advanced) {}

/// FIXME: remove once the block ptr conversion path is capable of handling
/// shared memory.
Expand Down
13 changes: 9 additions & 4 deletions third_party/intel/lib/TritonIntelGPUToLLVM/TritonGPUToLLVM.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,10 @@ struct ConvertTritonGPUToLLVM
: public triton::gpu::intel::impl::ConvertTritonIntelGPUToLLVMBase<
ConvertTritonGPUToLLVM> {
using ConvertTritonIntelGPUToLLVMBase::ConvertTritonIntelGPUToLLVMBase;
/// Default construction leaves every pass option at its declared default.
ConvertTritonGPUToLLVM() = default;
/// Constructs the pass with its `advancedPath` member preset to the given
/// value.
/// NOTE(review): `advancedPath` is presumably the option member generated for
/// this pass's `advanced_path` option — confirm against the generated base
/// class. Consider marking this single-argument constructor `explicit` to
/// rule out implicit bool-to-pass conversions.
ConvertTritonGPUToLLVM(bool advancedPath) {
// `this->` disambiguates the member from the same-named parameter.
this->advancedPath = advancedPath;
}

void getDependentDialects(DialectRegistry &registry) const override {
registry.insert<LLVM::LLVMDialect, TritonGEN::TritonGENDialect,
Expand All @@ -78,15 +82,16 @@ struct ConvertTritonGPUToLLVM
MLIRContext *context = &getContext();
ModuleOp mod = getOperation();

mlir::triton::intel::TritonGPUToLLVMPipelineManager pipelineManager(
mod, context);
mlir::LowerToLLVMOptions option(context);
bool isAdvancedPathEnabled =
mod->hasAttr(triton::gpu::intel::TritonIntelGPUDialect::
getSupportSG2DBlockAttrName()) &&
mod->hasAttr(triton::gpu::intel::TritonIntelGPUDialect::
getSupportDPASAttrName()) &&
mlir::triton::tools::getBoolEnv("TRITON_INTEL_ADVANCED_PATH");
(mlir::triton::tools::getBoolEnv("TRITON_INTEL_ADVANCED_PATH") ||
advancedPath);
mlir::triton::intel::TritonGPUToLLVMPipelineManager pipelineManager(
mod, context, isAdvancedPathEnabled);
mlir::LowerToLLVMOptions option(context);
mlir::triton::intel::TargetInfo targetInfo;
TritonIntelGPUToLLVMTypeConverter typeConverter(context, option, targetInfo,
isAdvancedPathEnabled);
Expand Down
4 changes: 2 additions & 2 deletions third_party/intel/triton_xpu.cc
Original file line number Diff line number Diff line change
Expand Up @@ -67,8 +67,8 @@ void init_triton_intel_passes_ttir(py::module &&m) {
}

void init_triton_intel_passes_ttgpuir(py::module &&m) {
ADD_PASS_WRAPPER_0("add_to_llvmir",
gpu::intel::createConvertTritonIntelGPUToLLVM);
ADD_PASS_WRAPPER_OPT_1("add_to_llvmir",
gpu::intel::createConvertTritonIntelGPUToLLVM, bool);
ADD_PASS_WRAPPER_0("add_accelerate_matmul",
gpu::intel::createTritonIntelGPUAccelerateMatmul);
ADD_PASS_WRAPPER_0("add_decompose_unsupported_conversions",
Expand Down

0 comments on commit c427499

Please sign in to comment.