diff --git a/xla/service/gpu/fusions/transforms/lower_to_llvm.cc b/xla/service/gpu/fusions/transforms/lower_to_llvm.cc index 5c84987333eaf..43042756352f2 100644 --- a/xla/service/gpu/fusions/transforms/lower_to_llvm.cc +++ b/xla/service/gpu/fusions/transforms/lower_to_llvm.cc @@ -68,8 +68,6 @@ class LowerToLLVMPass : public impl::LowerToLLVMPassBase { mlir::arith::populateArithExpandOpsPatterns(patterns); mlir::arith::populateArithToLLVMConversionPatterns(type_converter, patterns); - mlir::populateMathToLLVMConversionPatterns(type_converter, patterns, - /* approximateLog1p */ false); if (!this->is_amd_gpu_) { mlir::populateGpuToNVVMConversionPatterns(type_converter, patterns); } else { @@ -89,12 +87,23 @@ class LowerToLLVMPass : public impl::LowerToLLVMPassBase { mlir::configureGpuToROCDLConversionLegality(target); } target.addIllegalDialect(); + mlir::complex::ComplexDialect>(); target.addLegalOp(); - if (failed( - applyFullConversion(getOperation(), target, std::move(patterns)))) { + if (failed(applyPartialConversion(getOperation(), target, + std::move(patterns)))) { + signalPassFailure(); + return; + } + + // Cleanup any leftover math ops not handled NVVM or ROCDL lowering + mlir::RewritePatternSet mathPatterns(&getContext()); + mlir::populateMathToLLVMConversionPatterns(type_converter, mathPatterns, + /* approximateLog1p */ false); + target.addIllegalDialect(); + + if (failed(applyFullConversion(getOperation(), target, + std::move(mathPatterns)))) { signalPassFailure(); } }