Skip to content

Commit

Permalink
PR #19655: [ROCm] Make MLIR Math dialect lowering more deterministic
Browse files Browse the repository at this point in the history
Imported from GitHub PR #19655

First apply patterns from GpuToROCDLConversionPatterns then do the cleanup with MathToLLVMConversionPatterns
Copybara import of the project:

--
377d5a1 by Dragan Mladjenovic <[email protected]>:

[ROCm] Make MLIR Math dialect lowering more deterministic

First apply patterns from GpuToROCDLConversionPatterns then
do the cleanup with MathToLLVMConversionPatterns

Merging this change closes #19655

COPYBARA_INTEGRATE_REVIEW=#19655 from ROCm:ci_mlir_math_fix 377d5a1
PiperOrigin-RevId: 701941679
  • Loading branch information
draganmladjenovic authored and Google-ML-Automation committed Dec 2, 2024
1 parent 795fe1f commit 3c376bf
Showing 1 changed file with 15 additions and 6 deletions.
21 changes: 15 additions & 6 deletions xla/service/gpu/fusions/transforms/lower_to_llvm.cc
Original file line number Diff line number Diff line change
Expand Up @@ -68,8 +68,6 @@ class LowerToLLVMPass : public impl::LowerToLLVMPassBase<LowerToLLVMPass> {
mlir::arith::populateArithExpandOpsPatterns(patterns);
mlir::arith::populateArithToLLVMConversionPatterns(type_converter,
patterns);
mlir::populateMathToLLVMConversionPatterns(type_converter, patterns,
/* approximateLog1p */ false);
if (!this->is_amd_gpu_) {
mlir::populateGpuToNVVMConversionPatterns(type_converter, patterns);
} else {
Expand All @@ -89,12 +87,23 @@ class LowerToLLVMPass : public impl::LowerToLLVMPassBase<LowerToLLVMPass> {
mlir::configureGpuToROCDLConversionLegality(target);
}
target.addIllegalDialect<mlir::arith::ArithDialect, mlir::func::FuncDialect,
mlir::complex::ComplexDialect,
mlir::math::MathDialect>();
mlir::complex::ComplexDialect>();
target.addLegalOp<mlir::ModuleOp>();

if (failed(
applyFullConversion(getOperation(), target, std::move(patterns)))) {
if (failed(applyPartialConversion(getOperation(), target,
std::move(patterns)))) {
signalPassFailure();
return;
}

// Cleanup any leftover math ops not handled NVVM or ROCDL lowering
mlir::RewritePatternSet mathPatterns(&getContext());
mlir::populateMathToLLVMConversionPatterns(type_converter, mathPatterns,
/* approximateLog1p */ false);
target.addIllegalDialect<mlir::math::MathDialect>();

if (failed(applyFullConversion(getOperation(), target,
std::move(mathPatterns)))) {
signalPassFailure();
}
}
Expand Down

0 comments on commit 3c376bf

Please sign in to comment.