From 3c376bfa7ae3ae12c70cdc84942d858ccfb35321 Mon Sep 17 00:00:00 2001 From: Dragan Mladjenovic Date: Mon, 2 Dec 2024 06:07:07 -0800 Subject: [PATCH] PR #19655: [ROCm] Make MLIR Math dialect lowering more deterministic Imported from GitHub PR https://github.com/openxla/xla/pull/19655 First apply patterns from GpuToROCDLConversionPatterns then do the cleanup with MathToLLVMConversionPatterns Copybara import of the project: -- 377d5a1f1a624196eef3a241c65c388ba886e5ef by Dragan Mladjenovic : [ROCm] Make MLIR Math dialect lowering more deterministic First apply patterns from GpuToROCDLConversionPatterns then do the cleanup with MathToLLVMConversionPatterns Merging this change closes #19655 COPYBARA_INTEGRATE_REVIEW=https://github.com/openxla/xla/pull/19655 from ROCm:ci_mlir_math_fix 377d5a1f1a624196eef3a241c65c388ba886e5ef PiperOrigin-RevId: 701941679 --- .../gpu/fusions/transforms/lower_to_llvm.cc | 21 +++++++++++++------ 1 file changed, 15 insertions(+), 6 deletions(-) diff --git a/xla/service/gpu/fusions/transforms/lower_to_llvm.cc b/xla/service/gpu/fusions/transforms/lower_to_llvm.cc index 5c84987333eaf..43042756352f2 100644 --- a/xla/service/gpu/fusions/transforms/lower_to_llvm.cc +++ b/xla/service/gpu/fusions/transforms/lower_to_llvm.cc @@ -68,8 +68,6 @@ class LowerToLLVMPass : public impl::LowerToLLVMPassBase { mlir::arith::populateArithExpandOpsPatterns(patterns); mlir::arith::populateArithToLLVMConversionPatterns(type_converter, patterns); - mlir::populateMathToLLVMConversionPatterns(type_converter, patterns, - /* approximateLog1p */ false); if (!this->is_amd_gpu_) { mlir::populateGpuToNVVMConversionPatterns(type_converter, patterns); } else { @@ -89,12 +87,23 @@ class LowerToLLVMPass : public impl::LowerToLLVMPassBase { mlir::configureGpuToROCDLConversionLegality(target); } target.addIllegalDialect(); + mlir::complex::ComplexDialect>(); target.addLegalOp(); - if (failed( - applyFullConversion(getOperation(), target, std::move(patterns)))) { + if (failed(applyPartialConversion(getOperation(), target, + std::move(patterns)))) { + signalPassFailure(); + return; + } + + // Cleanup any leftover math ops not handled NVVM or ROCDL lowering + mlir::RewritePatternSet mathPatterns(&getContext()); + mlir::populateMathToLLVMConversionPatterns(type_converter, mathPatterns, + /* approximateLog1p */ false); + target.addIllegalDialect(); + + if (failed(applyFullConversion(getOperation(), target, + std::move(mathPatterns)))) { signalPassFailure(); } }