diff --git a/xla/service/gpu/fusions/mlir/mlir_fusion_emitter.cc b/xla/service/gpu/fusions/mlir/mlir_fusion_emitter.cc
index 0acf4ba29eedb9..a7a2f522cd417b 100644
--- a/xla/service/gpu/fusions/mlir/mlir_fusion_emitter.cc
+++ b/xla/service/gpu/fusions/mlir/mlir_fusion_emitter.cc
@@ -616,12 +616,8 @@ void AddLoopTransformationPasses(mlir::OpPassManager& pm,
 
 void AddLoweringPasses(mlir::OpPassManager& pm,
                        const se::DeviceDescription& device) {
-  bool is_amd = std::holds_alternative<se::RocmComputeCapability>(
-      device.gpu_compute_capability());
   pm.addNestedPass<mlir::func::FuncOp>(CreateConvertPureCallOpsPass());
-  pm.addPass(CreateLowerTensorsPass(
-      is_amd, is_amd ? device.rocm_compute_capability().gcn_arch_name()
-                     : device.cuda_compute_capability().ToString()));
+  pm.addPass(CreateLowerTensorsPass(device.ToGpuProto().DebugString()));
   pm.addPass(mlir::createConvertComplexToStandardPass());
   pm.addPass(CreateMergePointersToSameSlicePass());
 
@@ -649,6 +645,8 @@ void AddLoweringPasses(mlir::OpPassManager& pm,
   pm.addPass(CreateExpandFloatOpsPass());
   pm.addPass(mlir::createLowerAffinePass());
   pm.addPass(mlir::createConvertSCFToCFPass());
+  bool is_amd = std::holds_alternative<se::RocmComputeCapability>(
+      device.gpu_compute_capability());
   pm.addPass(CreateLowerToLLVMPass(is_amd));
   pm.addPass(mlir::createReconcileUnrealizedCastsPass());
 }
diff --git a/xla/service/gpu/fusions/transforms/BUILD b/xla/service/gpu/fusions/transforms/BUILD
index d039f0fb4abba8..bd5bdef91c0036 100644
--- a/xla/service/gpu/fusions/transforms/BUILD
+++ b/xla/service/gpu/fusions/transforms/BUILD
@@ -110,5 +110,6 @@ cc_library(
         "@llvm-project//mlir:VectorDialect",
         "@llvm-project//mlir:VectorToLLVM",
         "@llvm-project//mlir:VectorTransforms",
+        "@tsl//tsl/platform:protobuf",
    ],
)
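Note on the new plumbing: instead of an (is_amd, gpu_arch) pair, the emitter now hands LowerTensorsPass a single opaque string holding the whole device description, and the pass reconstructs a se::DeviceDescription from it. A minimal sketch of the intended round trip, using only names that appear in this patch (the text-proto encoding is inferred from the TextFormat::ParseFromString call in lower_tensors.cc below):

    // Emitter side: encode the device description as a text-format proto.
    std::string gpu_device_info = device.ToGpuProto().DebugString();
    pm.addPass(CreateLowerTensorsPass(gpu_device_info));

    // Pass side: decode the string back into an equivalent description.
    se::GpuDeviceInfoProto proto;
    CHECK(tsl::protobuf::TextFormat::ParseFromString(gpu_device_info, &proto));
    se::DeviceDescription device_description(proto);

Because the option travels as text, .mlir tests can exercise the pass with hand-written proto literals; see the lower_tensors.mlir changes at the end of this patch.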
#include "mlir/Support/LLVM.h" #include "mlir/Support/LogicalResult.h" #include "mlir/Transforms/GreedyPatternRewriteDriver.h" +#include "third_party/protobuf/text_format.h" #include "xla/service/gpu/fusions/ir/xla_gpu_ops.h" #include "xla/stream_executor/device_description.h" #include "xla/util.h" @@ -98,6 +101,11 @@ namespace arith = ::mlir::arith; namespace scf = ::mlir::scf; namespace ml = ::mlir::LLVM; +bool IsAMD(const se::DeviceDescription& device_description) { + return std::holds_alternative( + device_description.gpu_compute_capability()); +} + Value GetDestinationBuffer(Value dest) { while (dest.getDefiningOp()) { int result_number = mlir::cast(dest).getResultNumber(); @@ -653,11 +661,10 @@ Value CreateBitcast(mlir::ImplicitLocOpBuilder& b, Value value, Type ty) { class RewriteAtomicRMW : public mlir::OpRewritePattern { public: - RewriteAtomicRMW(mlir::MLIRContext* context, bool is_amd, - const std::string& gpu_arch) + RewriteAtomicRMW(mlir::MLIRContext* context, + const se::DeviceDescription* device_description) : mlir::OpRewritePattern(context), - is_amd_(is_amd), - gpu_arch_(gpu_arch) {} + device_description_(device_description) {} LogicalResult matchAndRewrite( AtomicRMWOp op, mlir::PatternRewriter& rewriter) const override { @@ -749,7 +756,8 @@ class RewriteAtomicRMW : public mlir::OpRewritePattern { ml::AtomicBinOp atomic_bin_op = modifier_parameters->second; Location loc = op.getLoc(); - llvm::StringRef sync_scope = is_amd_ ? "agent" : ""; + bool is_amd = IsAMD(*device_description_); + llvm::StringRef sync_scope = is_amd ? "agent" : ""; mlir::ImplicitLocOpBuilder b(loc, rewriter); Value addr = CreateGep(op.getInput(), op.getIndices(), b); @@ -774,10 +782,14 @@ class RewriteAtomicRMW : public mlir::OpRewritePattern { } case ml::AtomicBinOp::fadd: { // TODO(b/336367154): Introduce an atomic_rmw op with the binOp attr. - return is_amd_ ? emitAMDAtomicFAdd(loc, modifier_arg, addr, sync_scope, - gpu_arch_, rewriter) - : emitNVidiaAtomicFAdd(loc, modifier_arg, addr, - sync_scope, gpu_arch_, rewriter); + return is_amd ? emitAMDAtomicFAdd( + loc, modifier_arg, addr, sync_scope, + device_description_->rocm_compute_capability(), + rewriter) + : emitNVidiaAtomicFAdd( + loc, modifier_arg, addr, sync_scope, + device_description_->cuda_compute_capability(), + rewriter); } case ml::AtomicBinOp::fmax: { return rewriteAtomicFMaxAsIntAtomics(loc, modifier_arg, addr, @@ -789,11 +801,10 @@ class RewriteAtomicRMW : public mlir::OpRewritePattern { return success(); } - LogicalResult emitNVidiaAtomicFAdd(Location loc, Value modifier_arg, - Value addr, llvm::StringRef sync_scope, - llvm::StringRef cuda_arch, - OpBuilder& b) const { - se::CudaComputeCapability cuda_compute_capability(cuda_arch.str()); + LogicalResult emitNVidiaAtomicFAdd( + Location loc, Value modifier_arg, Value addr, llvm::StringRef sync_scope, + const se::CudaComputeCapability& cuda_compute_capability, + OpBuilder& b) const { Type element_type = modifier_arg.getType(); // "atom.add.f64 requires sm_60 or higher." 
diff --git a/xla/service/gpu/fusions/transforms/passes.h b/xla/service/gpu/fusions/transforms/passes.h
index 251337ab1a080e..a5ad979abf59fd 100644
--- a/xla/service/gpu/fusions/transforms/passes.h
+++ b/xla/service/gpu/fusions/transforms/passes.h
@@ -46,7 +46,8 @@ std::unique_ptr<mlir::Pass> CreateEraseDeadFunctionsPass();
 std::unique_ptr<mlir::Pass> CreateExpandFloatOpsPass();
 std::unique_ptr<mlir::Pass> CreateFlattenTensorsPass();
 std::unique_ptr<mlir::Pass> CreateLowerTensorsPass(
-    bool is_amd_gpu = false, const std::string& gpu_arch = "6.0");
+    const std::string& gpu_device_info =
+        "cuda_compute_capability { major: 6 }");
 std::unique_ptr<mlir::Pass> CreateLowerToLLVMPass(bool use_rocdl);
 std::unique_ptr<mlir::Pass> CreateLowerXlaGpuToScfPass(int64_t warp_size = 32);
 std::unique_ptr<mlir::Pass> CreateLowerXlaGpuLoopsToScfPass();
diff --git a/xla/service/gpu/fusions/transforms/passes.td b/xla/service/gpu/fusions/transforms/passes.td
index 4d08552c0d2681..2a5ce9411f0954 100644
--- a/xla/service/gpu/fusions/transforms/passes.td
+++ b/xla/service/gpu/fusions/transforms/passes.td
@@ -83,10 +83,8 @@ def LowerTensorsPass : Pass<"xla-gpu-lower-tensors", "mlir::ModuleOp"> {
     "xla::gpu::XlaGpuDialect",
   ];
   let options = [
-    Option<"is_amd_gpu_", "is_amd_gpu", "bool", /*default=*/"false",
-           "True if AMD GPU.">,
-    Option<"gpu_arch_", "gpu_arch", "std::string", /*default=*/"",
-           "CUDA or ROCm compute capability.">,
+    Option<"gpu_device_info_", "gpu_device_info", "std::string", /*default=*/"",
+           "Text-format stream_executor::GpuDeviceInfoProto.">,
   ];
   let constructor = "CreateLowerTensorsPass()";
 }
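For C++ callers, the single text-proto option replaces the two-flag signature. A sketch of direct construction, reusing the proto shapes that the tests below exercise (the gfx90a arch string is one of theirs):

    // CUDA device, compute capability 8.0:
    pm.addPass(CreateLowerTensorsPass("cuda_compute_capability { major: 8 }"));

    // ROCm device, MI200-class:
    pm.addPass(CreateLowerTensorsPass(
        R"(rocm_compute_capability { gcn_arch_name: "gfx90a:sramecc+:xnack" })"));

The new default, "cuda_compute_capability { major: 6 }", parses to CudaComputeCapability(6, 0) and so preserves the behavior of the old gpu_arch = "6.0" default.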
diff --git a/xla/service/gpu/fusions/transforms/tests/lower_tensors.mlir b/xla/service/gpu/fusions/transforms/tests/lower_tensors.mlir
index 69377549340b73..2b6aa72f52f09c 100644
--- a/xla/service/gpu/fusions/transforms/tests/lower_tensors.mlir
+++ b/xla/service/gpu/fusions/transforms/tests/lower_tensors.mlir
@@ -1,25 +1,25 @@
 // RUN: mlir_fusions_opt %s --allow-unregistered-dialect -split-input-file \
-// RUN: -xla-gpu-lower-tensors="is_amd_gpu=false gpu_arch=6.0" \
+// RUN: -xla-gpu-lower-tensors="gpu_device_info='cuda_compute_capability {major: 6}'" \
 // RUN: | FileCheck %s
 
 // RUN: mlir_fusions_opt %s --allow-unregistered-dialect -split-input-file \
-// RUN: -xla-gpu-lower-tensors="is_amd_gpu=false gpu_arch=7.0" \
+// RUN: -xla-gpu-lower-tensors="gpu_device_info='cuda_compute_capability {major: 7}'" \
 // RUN: | FileCheck %s --check-prefix=CHECK-VOLTA
 
 // RUN: mlir_fusions_opt %s --allow-unregistered-dialect -split-input-file \
-// RUN: -xla-gpu-lower-tensors="is_amd_gpu=false gpu_arch=8.0" \
+// RUN: -xla-gpu-lower-tensors="gpu_device_info='cuda_compute_capability {major: 8}'" \
 // RUN: | FileCheck %s --check-prefix=CHECK-AMPERE
 
 // RUN: mlir_fusions_opt %s --allow-unregistered-dialect -split-input-file \
-// RUN: -xla-gpu-lower-tensors="is_amd_gpu=false gpu_arch=9.0" \
+// RUN: -xla-gpu-lower-tensors="gpu_device_info='cuda_compute_capability {major: 9}'" \
 // RUN: | FileCheck %s --check-prefix=CHECK-HOPPER
 
 // RUN: mlir_fusions_opt %s --allow-unregistered-dialect -split-input-file \
-// RUN: -xla-gpu-lower-tensors="is_amd_gpu=true gpu_arch=gfx908:sramecc+:xnack" \
+// RUN: -xla-gpu-lower-tensors="gpu_device_info='rocm_compute_capability {gcn_arch_name: \"gfx908:sramecc+:xnack\"}'" \
 // RUN: | FileCheck %s --check-prefix=CHECK-GFX908-MI100
 
 // RUN: mlir_fusions_opt %s --allow-unregistered-dialect -split-input-file \
-// RUN: -xla-gpu-lower-tensors="is_amd_gpu=true gpu_arch=gfx90a:sramecc+:xnack" \
+// RUN: -xla-gpu-lower-tensors="gpu_device_info='rocm_compute_capability {gcn_arch_name: \"gfx90a:sramecc+:xnack\"}'" \
 // RUN: | FileCheck %s --check-prefix=CHECK-GFX90A-MI200
 
 module attributes {dlti.dl_spec = #dlti.dl_spec<#dlti.dl_entry<index, 32 : i32>>} {
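A note on the updated RUN lines: the text proto is nested inside two quoting layers. The outer double quotes keep the pass-option list together for the lit shell, the single quotes keep the proto as one gpu_device_info value for the MLIR pass-option parser, and the escaped double quotes delimit the gcn_arch_name string literal inside the proto, e.g.:

    -xla-gpu-lower-tensors="gpu_device_info='rocm_compute_capability {gcn_arch_name: \"gfx908:sramecc+:xnack\"}'"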