Skip to content

Commit

Permalink
[XLA:GPU][Emitters] Use DeviceDescription in lower_tensors.cc.
Browse files Browse the repository at this point in the history
I tried to initialize se::DeviceDescription in the LowerTensorsPass constructor, but TableGen does not like it. I will try to fix it in a follow-up.

PiperOrigin-RevId: 701701647
  • Loading branch information
pifon2a authored and Google-ML-Automation committed Dec 3, 2024
1 parent e9947dd commit e00e35d
Show file tree
Hide file tree
Showing 6 changed files with 53 additions and 41 deletions.
8 changes: 3 additions & 5 deletions xla/service/gpu/fusions/mlir/mlir_fusion_emitter.cc
Original file line number Diff line number Diff line change
Expand Up @@ -616,12 +616,8 @@ void AddLoopTransformationPasses(mlir::OpPassManager& pm,

void AddLoweringPasses(mlir::OpPassManager& pm,
const se::DeviceDescription& device) {
bool is_amd = std::holds_alternative<se::RocmComputeCapability>(
device.gpu_compute_capability());
pm.addNestedPass<FuncOp>(CreateConvertPureCallOpsPass());
pm.addPass(CreateLowerTensorsPass(
is_amd, is_amd ? device.rocm_compute_capability().gcn_arch_name()
: device.cuda_compute_capability().ToString()));
pm.addPass(CreateLowerTensorsPass(device.ToGpuProto().SerializeAsString()));
pm.addPass(mlir::createConvertComplexToStandardPass());
pm.addPass(CreateMergePointersToSameSlicePass());

Expand Down Expand Up @@ -649,6 +645,8 @@ void AddLoweringPasses(mlir::OpPassManager& pm,
pm.addPass(CreateExpandFloatOpsPass());
pm.addPass(mlir::createLowerAffinePass());
pm.addPass(mlir::createConvertSCFToCFPass());
bool is_amd = std::holds_alternative<se::RocmComputeCapability>(
device.gpu_compute_capability());
pm.addPass(CreateLowerToLLVMPass(is_amd));
pm.addPass(mlir::createReconcileUnrealizedCastsPass());
}
Expand Down
1 change: 1 addition & 0 deletions xla/service/gpu/fusions/transforms/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -110,5 +110,6 @@ cc_library(
"@llvm-project//mlir:VectorDialect",
"@llvm-project//mlir:VectorToLLVM",
"@llvm-project//mlir:VectorTransforms",
"@tsl//tsl/platform:protobuf",
],
)
64 changes: 39 additions & 25 deletions xla/service/gpu/fusions/transforms/lower_tensors.cc
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ limitations under the License.
#include <string>
#include <tuple>
#include <utility>
#include <variant>
#include <vector>

#include "absl/algorithm/container.h"
Expand All @@ -29,6 +30,7 @@ limitations under the License.
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/StringRef.h"
#include "llvm/Support/LogicalResult.h"
#include "llvm/Support/raw_ostream.h"
#include "mlir/Conversion/LLVMCommon/TypeConverter.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Complex/IR/Complex.h"
Expand Down Expand Up @@ -61,6 +63,7 @@ limitations under the License.
#include "mlir/Support/LLVM.h"
#include "mlir/Support/LogicalResult.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
#include "third_party/protobuf/text_format.h"
#include "xla/service/gpu/fusions/ir/xla_gpu_ops.h"
#include "xla/stream_executor/device_description.h"
#include "xla/util.h"
Expand Down Expand Up @@ -98,6 +101,11 @@ namespace arith = ::mlir::arith;
namespace scf = ::mlir::scf;
namespace ml = ::mlir::LLVM;

bool IsAMD(const se::DeviceDescription& device_description) {
return std::holds_alternative<se::RocmComputeCapability>(
device_description.gpu_compute_capability());
}

Value GetDestinationBuffer(Value dest) {
while (dest.getDefiningOp()) {
int result_number = mlir::cast<mlir::OpResult>(dest).getResultNumber();
Expand Down Expand Up @@ -653,11 +661,10 @@ Value CreateBitcast(mlir::ImplicitLocOpBuilder& b, Value value, Type ty) {

class RewriteAtomicRMW : public mlir::OpRewritePattern<AtomicRMWOp> {
public:
RewriteAtomicRMW(mlir::MLIRContext* context, bool is_amd,
const std::string& gpu_arch)
RewriteAtomicRMW(mlir::MLIRContext* context,
const se::DeviceDescription* device_description)
: mlir::OpRewritePattern<AtomicRMWOp>(context),
is_amd_(is_amd),
gpu_arch_(gpu_arch) {}
device_description_(device_description) {}

LogicalResult matchAndRewrite(
AtomicRMWOp op, mlir::PatternRewriter& rewriter) const override {
Expand Down Expand Up @@ -749,7 +756,8 @@ class RewriteAtomicRMW : public mlir::OpRewritePattern<AtomicRMWOp> {
ml::AtomicBinOp atomic_bin_op = modifier_parameters->second;

Location loc = op.getLoc();
llvm::StringRef sync_scope = is_amd_ ? "agent" : "";
bool is_amd = IsAMD(*device_description_);
llvm::StringRef sync_scope = is_amd ? "agent" : "";
mlir::ImplicitLocOpBuilder b(loc, rewriter);
Value addr = CreateGep(op.getInput(), op.getIndices(), b);

Expand All @@ -774,10 +782,14 @@ class RewriteAtomicRMW : public mlir::OpRewritePattern<AtomicRMWOp> {
}
case ml::AtomicBinOp::fadd: {
// TODO(b/336367154): Introduce an atomic_rmw op with the binOp attr.
return is_amd_ ? emitAMDAtomicFAdd(loc, modifier_arg, addr, sync_scope,
gpu_arch_, rewriter)
: emitNVidiaAtomicFAdd(loc, modifier_arg, addr,
sync_scope, gpu_arch_, rewriter);
return is_amd ? emitAMDAtomicFAdd(
loc, modifier_arg, addr, sync_scope,
device_description_->rocm_compute_capability(),
rewriter)
: emitNVidiaAtomicFAdd(
loc, modifier_arg, addr, sync_scope,
device_description_->cuda_compute_capability(),
rewriter);
}
case ml::AtomicBinOp::fmax: {
return rewriteAtomicFMaxAsIntAtomics(loc, modifier_arg, addr,
Expand All @@ -789,11 +801,10 @@ class RewriteAtomicRMW : public mlir::OpRewritePattern<AtomicRMWOp> {
return success();
}

LogicalResult emitNVidiaAtomicFAdd(Location loc, Value modifier_arg,
Value addr, llvm::StringRef sync_scope,
llvm::StringRef cuda_arch,
OpBuilder& b) const {
se::CudaComputeCapability cuda_compute_capability(cuda_arch.str());
LogicalResult emitNVidiaAtomicFAdd(
Location loc, Value modifier_arg, Value addr, llvm::StringRef sync_scope,
const se::CudaComputeCapability& cuda_compute_capability,
OpBuilder& b) const {
Type element_type = modifier_arg.getType();
// "atom.add.f64 requires sm_60 or higher."
// https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#parallel-synchronization-and-communication-instructions-atom
Expand All @@ -815,11 +826,10 @@ class RewriteAtomicRMW : public mlir::OpRewritePattern<AtomicRMWOp> {
return success();
}

LogicalResult emitAMDAtomicFAdd(Location loc, Value modifier_arg, Value addr,
llvm::StringRef sync_scope,
llvm::StringRef gcn_arch,
OpBuilder& b) const {
se::RocmComputeCapability rocm_compute_capability(gcn_arch.str());
LogicalResult emitAMDAtomicFAdd(
Location loc, Value modifier_arg, Value addr, llvm::StringRef sync_scope,
const se::RocmComputeCapability& rocm_compute_capability,
OpBuilder& b) const {
Type element_type = modifier_arg.getType();
bool is_supported_f16_atomic =
element_type.isF16() &&
Expand Down Expand Up @@ -1048,8 +1058,7 @@ class RewriteAtomicRMW : public mlir::OpRewritePattern<AtomicRMWOp> {
});
}

bool is_amd_;
std::string gpu_arch_;
const se::DeviceDescription* device_description_;
};

template <typename FType>
Expand Down Expand Up @@ -1158,9 +1167,15 @@ class LowerTensorsPass : public impl::LowerTensorsPassBase<LowerTensorsPass> {
: LowerTensorsPassBase(options) {}

void runOnOperation() override {
se::GpuDeviceInfoProto device_info;
CHECK(tsl::protobuf::TextFormat::ParseFromString(gpu_device_info_,
&device_info));
se::DeviceDescription device_description(device_info);

MLIRContext* mlir_context = &getContext();
mlir::RewritePatternSet tensor_patterns(mlir_context);
tensor_patterns.add<RewriteAtomicRMW>(mlir_context, is_amd_gpu_, gpu_arch_);

tensor_patterns.add<RewriteAtomicRMW>(mlir_context, &device_description);
tensor_patterns
.add<RewriteAllocateShared, RewriterExpm1Op, RewriteNonScalarConstants,
RewriteSyncThreads, RewriteTensorExtract, RewriteTransferRead,
Expand Down Expand Up @@ -1216,10 +1231,9 @@ class LowerTensorsPass : public impl::LowerTensorsPassBase<LowerTensorsPass> {
} // namespace

std::unique_ptr<::mlir::Pass> CreateLowerTensorsPass(
bool is_amd_gpu, const std::string& gpu_arch) {
const std::string& gpu_device_info) {
LowerTensorsPassOptions options;
options.is_amd_gpu_ = is_amd_gpu;
options.gpu_arch_ = gpu_arch;
options.gpu_device_info_ = gpu_device_info;
return std::make_unique<LowerTensorsPass>(options);
}

Expand Down
3 changes: 2 additions & 1 deletion xla/service/gpu/fusions/transforms/passes.h
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,8 @@ std::unique_ptr<mlir::Pass> CreateEraseDeadFunctionsPass();
std::unique_ptr<mlir::Pass> CreateExpandFloatOpsPass();
std::unique_ptr<mlir::Pass> CreateFlattenTensorsPass();
std::unique_ptr<mlir::Pass> CreateLowerTensorsPass(
bool is_amd_gpu = false, const std::string& gpu_arch = "6.0");
const std::string& gpu_device_info =
"cuda_compute_capability { major: 6 }");
std::unique_ptr<mlir::Pass> CreateLowerToLLVMPass(bool use_rocdl);
std::unique_ptr<mlir::Pass> CreateLowerXlaGpuToScfPass(int64_t warp_size = 32);
std::unique_ptr<mlir::Pass> CreateLowerXlaGpuLoopsToScfPass();
Expand Down
6 changes: 2 additions & 4 deletions xla/service/gpu/fusions/transforms/passes.td
Original file line number Diff line number Diff line change
Expand Up @@ -83,10 +83,8 @@ def LowerTensorsPass : Pass<"xla-gpu-lower-tensors", "mlir::ModuleOp"> {
"xla::gpu::XlaGpuDialect",
];
let options = [
Option<"is_amd_gpu_", "is_amd_gpu", "bool", /*default=*/"false",
"True if AMD GPU.">,
Option<"gpu_arch_", "gpu_arch", "std::string", /*default=*/"",
"CUDA or ROCm compute capability.">,
Option<"gpu_device_info_", "gpu_device_info", "std::string", /*default=*/"",
"Serialized stream_executor::GPUDeviceInfo proto.">,
];
let constructor = "CreateLowerTensorsPass()";
}
Expand Down
12 changes: 6 additions & 6 deletions xla/service/gpu/fusions/transforms/tests/lower_tensors.mlir
Original file line number Diff line number Diff line change
@@ -1,25 +1,25 @@
// RUN: mlir_fusions_opt %s --allow-unregistered-dialect -split-input-file \
// RUN: -xla-gpu-lower-tensors="is_amd_gpu=false gpu_arch=6.0" \
// RUN: -xla-gpu-lower-tensors="gpu_device_info='cuda_compute_capability {major: 6}'" \
// RUN: | FileCheck %s

// RUN: mlir_fusions_opt %s --allow-unregistered-dialect -split-input-file \
// RUN: -xla-gpu-lower-tensors="is_amd_gpu=false gpu_arch=7.0" \
// RUN: -xla-gpu-lower-tensors="gpu_device_info='cuda_compute_capability {major: 7}'" \
// RUN: | FileCheck %s --check-prefix=CHECK-VOLTA

// RUN: mlir_fusions_opt %s --allow-unregistered-dialect -split-input-file \
// RUN: -xla-gpu-lower-tensors="is_amd_gpu=false gpu_arch=8.0" \
// RUN: -xla-gpu-lower-tensors="gpu_device_info='cuda_compute_capability {major: 8}'" \
// RUN: | FileCheck %s --check-prefix=CHECK-AMPERE

// RUN: mlir_fusions_opt %s --allow-unregistered-dialect -split-input-file \
// RUN: -xla-gpu-lower-tensors="is_amd_gpu=false gpu_arch=9.0" \
// RUN: -xla-gpu-lower-tensors="gpu_device_info='cuda_compute_capability {major: 9}'" \
// RUN: | FileCheck %s --check-prefix=CHECK-HOPPER

// RUN: mlir_fusions_opt %s --allow-unregistered-dialect -split-input-file \
// RUN: -xla-gpu-lower-tensors="is_amd_gpu=true gpu_arch=gfx908:sramecc+:xnack" \
// RUN: -xla-gpu-lower-tensors="gpu_device_info='rocm_compute_capability {gcn_arch_name: \"gfx908:sramecc+:xnack\"}'" \
// RUN: | FileCheck %s --check-prefix=CHECK-GFX908-MI100

// RUN: mlir_fusions_opt %s --allow-unregistered-dialect -split-input-file \
// RUN: -xla-gpu-lower-tensors="is_amd_gpu=true gpu_arch=gfx90a:sramecc+:xnack" \
// RUN: -xla-gpu-lower-tensors="gpu_device_info='rocm_compute_capability {gcn_arch_name: \"gfx90a:sramecc+:xnack\"}'" \
// RUN: | FileCheck %s --check-prefix=CHECK-GFX90A-MI200

module attributes {dlti.dl_spec = #dlti.dl_spec<#dlti.dl_entry<index, 32 : i32>>} {
Expand Down

0 comments on commit e00e35d

Please sign in to comment.