Skip to content

Commit

Permalink
#sdy Swap XLA Shardy passes to use StableHLO instead of MHLO as much …
Browse files Browse the repository at this point in the history
…as possible.

Note that the test case `func @import_sharding_group_with_unused_result` in `sdy_round_trip_import_pipeline.mlir` has been moved to `mhlo_import_pipeline.mlir` since a `xla.sdy.ShardingGroup` custom call with an empty tuple result becomes a custom call with no results after tuple flattening. So this is the relevant pipeline for the test case.

PiperOrigin-RevId: 701034667
  • Loading branch information
bartchr808 authored and Google-ML-Automation committed Nov 29, 2024
1 parent aa9ba7d commit dffcb9e
Show file tree
Hide file tree
Showing 35 changed files with 560 additions and 554 deletions.
6 changes: 4 additions & 2 deletions xla/service/spmd/shardy/mhlo_round_trip/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,7 @@ cc_library(
"@llvm-project//mlir:Support",
"@llvm-project//mlir:TransformUtils",
"@shardy//shardy/dialect/sdy/ir:dialect",
"@stablehlo//:stablehlo_ops",
],
)

Expand All @@ -70,13 +71,13 @@ cc_library(
"//xla/service/spmd/shardy:constants",
"@com_google_absl//absl/log:check",
"@llvm-project//llvm:Support",
"@llvm-project//mlir:FuncDialect",
"@llvm-project//mlir:IR",
"@llvm-project//mlir:InliningUtils",
"@llvm-project//mlir:Pass",
"@llvm-project//mlir:Support",
"@llvm-project//mlir:TransformUtils",
"@shardy//shardy/dialect/sdy/ir:dialect",
"@stablehlo//:stablehlo_ops",
],
)

Expand All @@ -88,6 +89,7 @@ cc_library(
":export_ops",
":export_shardings",
":shard_map_export",
"//xla/mlir_hlo:mhlo_passes",
"//xla/service/spmd/shardy/round_trip_common:export_named_computations",
"@llvm-project//mlir:Pass",
"@llvm-project//mlir:Support",
Expand Down Expand Up @@ -147,7 +149,6 @@ cc_library(
hdrs = ["shard_map_import.h"],
deps = [
"//xla:xla_data_proto_cc",
"//xla/mlir_hlo",
"//xla/mlir_hlo:mhlo_passes",
"//xla/service/spmd/shardy:constants",
"@com_google_absl//absl/algorithm:container",
Expand All @@ -163,5 +164,6 @@ cc_library(
"@llvm-project//mlir:Support",
"@llvm-project//mlir:TransformUtils",
"@shardy//shardy/dialect/sdy/ir:dialect",
"@stablehlo//:stablehlo_ops",
],
)
9 changes: 6 additions & 3 deletions xla/service/spmd/shardy/mhlo_round_trip/export_ops.cc
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ limitations under the License.
#include "llvm/ADT/SmallVector.h"
#include "llvm/ADT/StringRef.h"
#include "llvm/Support/ErrorHandling.h"
#include "llvm/Support/LogicalResult.h"
#include "mlir/IR/AffineMap.h"
#include "mlir/IR/Attributes.h"
#include "mlir/IR/Builders.h"
Expand All @@ -45,6 +46,7 @@ limitations under the License.
#include "shardy/dialect/sdy/ir/constants.h"
#include "shardy/dialect/sdy/ir/dialect.h"
#include "shardy/dialect/sdy/ir/utils.h"
#include "stablehlo/dialect/StablehloOps.h"
#include "xla/mlir_hlo/mhlo/IR/hlo_ops.h"
#include "xla/service/spmd/shardy/constants.h"
#include "xla/sharding_op_util.h"
Expand All @@ -54,6 +56,7 @@ namespace sdy {

namespace {

namespace stablehlo = ::mlir::stablehlo;
namespace mhlo = ::mlir::mhlo;

using ::mlir::ConversionPatternRewriter;
Expand All @@ -73,7 +76,7 @@ using ::mlir::sdy::ShardingConstraintOp;
using ::mlir::sdy::TensorShardingAttr;
using ::mlir::sdy::TensorShardingPerValueAttr;

// Converts `sdy::ConstantOp` to `mhlo::ConstantOp`.
// Converts `sdy::ConstantOp` to `stablehlo::ConstantOp`.
class ConstantPattern : public OpConversionPattern<ConstantOp> {
using OpConversionPattern::OpConversionPattern;

Expand All @@ -82,7 +85,7 @@ class ConstantPattern : public OpConversionPattern<ConstantOp> {
ConversionPatternRewriter& rewriter) const override {
// We use the generic op builder so that unregistered attributes will be
// added to the new op.
rewriter.replaceOpWithNewOp<mhlo::ConstantOp>(
rewriter.replaceOpWithNewOp<stablehlo::ConstantOp>(
op, op->getResultTypes(), adaptor.getOperands(), op->getAttrs());
return success();
}
Expand Down Expand Up @@ -134,7 +137,7 @@ class ExportOpsPass
// ShardingConstraintOp should be replaced by ReshardOp before this pass.
// Hence, we add ShardingConstraintOp as an illegal op.
target.addIllegalOp<ConstantOp, ReshardOp, ShardingConstraintOp>();
target.addLegalOp<mhlo::ConstantOp, mhlo::CopyOp>();
target.addLegalOp<stablehlo::ConstantOp, mhlo::CopyOp>();
mlir::RewritePatternSet patterns(&context);
// After converting `sdy.constant` into `mhlo.constant`, the constants
// should not be deduped via folding. Fortunately, folding only happens in
Expand Down
2 changes: 2 additions & 0 deletions xla/service/spmd/shardy/mhlo_round_trip/mhlo_export.cc
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ limitations under the License.
#include "mlir/Pass/PassManager.h"
#include "mlir/Pass/PassRegistry.h"
#include "mlir/Support/LLVM.h"
#include "xla/mlir_hlo/mhlo/transforms/passes.h"
#include "xla/service/spmd/shardy/mhlo_round_trip/export_ops.h"
#include "xla/service/spmd/shardy/mhlo_round_trip/export_shardings.h"
#include "xla/service/spmd/shardy/mhlo_round_trip/shard_map_export.h"
Expand All @@ -36,6 +37,7 @@ void addMhloExportPipeline(mlir::OpPassManager& pm) {
pm.addPass(createMhloRoundTripShardMapExportPass());
pm.addPass(createExportNamedComputationsPass());
pm.addPass(createExportMhloShardingsPass());
pm.addPass(mlir::mhlo::createStablehloLegalizeToHloPass());
}

void registerMhloExportPipeline() {
Expand Down
4 changes: 2 additions & 2 deletions xla/service/spmd/shardy/mhlo_round_trip/mhlo_import.cc
Original file line number Diff line number Diff line change
Expand Up @@ -658,8 +658,8 @@ void addMhloImportPipeline(mlir::OpPassManager& pm,
void registerMhloImportPipeline() {
mlir::PassPipelineRegistration<> importPipeline(
"xla-sdy-mhlo-import-pipeline",
"Run passes to import an mhlo module with `mhlo.shardings` into the SDY "
"(Shardy) dialect.",
"Run passes to import a StableHLO module with `mhlo.shardings` into the "
"SDY (Shardy) dialect.",
std::bind(addMhloImportPipeline, std::placeholders::_1, ArrayRef<bool>(),
ArrayRef<bool>()));
}
Expand Down
3 changes: 2 additions & 1 deletion xla/service/spmd/shardy/mhlo_round_trip/shard_map_export.cc
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,7 @@ limitations under the License.
#include "shardy/dialect/sdy/ir/constants.h"
#include "shardy/dialect/sdy/ir/dialect.h"
#include "shardy/dialect/sdy/ir/utils.h"
#include "stablehlo/dialect/StablehloOps.h"
#include "xla/hlo/ir/hlo_sharding.h"
#include "xla/mlir_hlo/mhlo/IR/hlo_ops.h"
#include "xla/mlir_hlo/mhlo/transforms/passes.h"
Expand All @@ -73,7 +74,7 @@ using ::mlir::StringAttr;
using ::mlir::StringRef;
using ::mlir::Value;
using ::mlir::mhlo::CopyOp;
using ::mlir::mhlo::CustomCallOp;
using ::mlir::stablehlo::CustomCallOp;

namespace sdy = ::mlir::sdy;
using sdy::kShardingAttr;
Expand Down
4 changes: 2 additions & 2 deletions xla/service/spmd/shardy/mhlo_round_trip/shard_map_import.cc
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,7 @@ limitations under the License.
#include "shardy/dialect/sdy/ir/constants.h"
#include "shardy/dialect/sdy/ir/dialect.h"
#include "shardy/dialect/sdy/ir/utils.h"
#include "xla/mlir_hlo/mhlo/IR/hlo_ops.h"
#include "stablehlo/dialect/StablehloOps.h"
#include "xla/mlir_hlo/mhlo/transforms/passes.h"
#include "xla/service/spmd/shardy/constants.h"
#include "xla/xla_data.pb.h"
Expand All @@ -73,7 +73,7 @@ using ::mlir::StringRef;
using ::mlir::Value;
using ::mlir::func::CallOp;
using ::mlir::func::FuncOp;
using ::mlir::mhlo::CustomCallOp;
using ::mlir::stablehlo::CustomCallOp;

namespace sdy = ::mlir::sdy;
using sdy::AxisRefAttr;
Expand Down
4 changes: 2 additions & 2 deletions xla/service/spmd/shardy/round_trip_common/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,6 @@ cc_library(
hdrs = ["import_sdy_custom_calls.h"],
deps = [
"//xla:sharding_op_util",
"//xla/mlir_hlo",
"//xla/service/spmd/shardy:constants",
"//xla/service/spmd/shardy:utils",
"@com_google_absl//absl/log:check",
Expand All @@ -29,6 +28,7 @@ cc_library(
"@llvm-project//mlir:Support",
"@llvm-project//mlir:TransformUtils",
"@shardy//shardy/dialect/sdy/ir:dialect",
"@stablehlo//:stablehlo_ops",
],
)

Expand Down Expand Up @@ -86,14 +86,14 @@ cc_library(
srcs = ["open_while_free_vars_sharding.cc"],
hdrs = ["open_while_free_vars_sharding.h"],
deps = [
"//xla/mlir_hlo",
"@llvm-project//llvm:Support",
"@llvm-project//mlir:FuncDialect",
"@llvm-project//mlir:IR",
"@llvm-project//mlir:Pass",
"@llvm-project//mlir:Support",
"@llvm-project//mlir:TransformUtils",
"@shardy//shardy/dialect/sdy/ir:dialect",
"@stablehlo//:stablehlo_ops",
],
)

Expand Down
4 changes: 2 additions & 2 deletions xla/service/spmd/shardy/round_trip_common/import_constants.h
Original file line number Diff line number Diff line change
Expand Up @@ -23,8 +23,8 @@ limitations under the License.
namespace xla {
namespace sdy {

// Creates a pass that converts an `mhlo.constant` (which is foldable) into an
// `sdy.constant` (which isn't foldable).
// Creates a pass that converts a `stablehlo.constant` (which is foldable) into
// an `sdy.constant` (which isn't foldable).
std::unique_ptr<mlir::Pass> createImportConstantsPass();

// Register the xla-sdy-import-constants pass.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ limitations under the License.
#include "mlir/Transforms/DialectConversion.h"
#include "shardy/dialect/sdy/ir/dialect.h"
#include "shardy/dialect/sdy/ir/utils.h"
#include "xla/mlir_hlo/mhlo/IR/hlo_ops.h"
#include "stablehlo/dialect/StablehloOps.h"
#include "xla/service/spmd/shardy/constants.h"
#include "xla/service/spmd/shardy/utils.h"
#include "xla/sharding_op_util.h"
Expand All @@ -47,11 +47,11 @@ namespace {

using ::mlir::IntegerAttr;
using ::mlir::StringRef;
using ::mlir::mhlo::CustomCallOp;
using ::mlir::sdy::ShardingConstraintOp;
using ::mlir::sdy::ShardingGroupOp;
using ::mlir::sdy::TensorShardingAttr;
using ::mlir::mhlo::CustomCallOpAdaptor;
using ::mlir::stablehlo::CustomCallOp;
using ::mlir::stablehlo::CustomCallOpAdaptor;

mlir::LogicalResult rewriteShardingCustomCall(
CustomCallOp op, CustomCallOpAdaptor adaptor,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ limitations under the License.
#include "mlir/Transforms/RegionUtils.h"
#include "shardy/dialect/sdy/ir/dialect.h"
#include "shardy/dialect/sdy/ir/utils.h"
#include "xla/mlir_hlo/mhlo/IR/hlo_ops.h"
#include "stablehlo/dialect/StablehloOps.h"

namespace xla {
namespace sdy {
Expand All @@ -49,7 +49,7 @@ class OpenWhileFreeVarsShardingPass
FuncOp funcOp = getOperation();
mlir::IRRewriter rewriter(funcOp);

funcOp.walk([&](mlir::mhlo::WhileOp op) {
funcOp.walk([&](mlir::stablehlo::WhileOp op) {
llvm::SetVector<mlir::Value> freeVars;
mlir::getUsedValuesDefinedAbove(op->getRegions(), freeVars);
rewriter.setInsertionPoint(op);
Expand Down
15 changes: 9 additions & 6 deletions xla/service/spmd/shardy/round_trip_common/pipeline_passes.cc
Original file line number Diff line number Diff line change
Expand Up @@ -30,25 +30,28 @@ namespace sdy {
using ::mlir::func::FuncOp;

void addCommonPreImportPasses(mlir::OpPassManager& pm) {
pm.addPass(mlir::createSymbolDCEPass());
// TODO(b/333505182): remove when partitioning is done in SDY.
// We call prepare-for-export pass before SDY propagation, so that all IR
// changes happen before shardings are added to operations, to ensure the
// correct shardings are added and that they are not lost by this pass.
pm.addNestedPass<FuncOp>(mlir::mhlo::createPrepareForExportPass());

// We import `mhlo.constant` ops to `sdy.constant` ops so that constants
// We import `stablehlo.constant` ops to `sdy.constant` ops so that constants
// aren't folded in greedy pattern rewriters, which would lift them outside of
// nested regions (this undoes `WhileLoopConstantSinking` HLO pass).
// Therefore, this pass needs to be applied after any mhlo pass that expects
// `mhlo.constant`, and before any pass that has a greedy pattern rewriter.
// Therefore, this pass needs to be applied after any stablehlo pass that
// expects `stablehlo.constant`, and before any pass that has a greedy pattern
// rewriter.
pm.addNestedPass<FuncOp>(createImportConstantsPass());

pm.addNestedPass<FuncOp>(mlir::mhlo::createFlattenTuplePass());
// We need to canonicalize redundant mhlo::GetTupleElementOp and
// mhlo::GetTupleOp. We also need to canonicalize mhlo::WhileOp before
// `createOpenWhileFreeVarsShardingPass`.
pm.addPass(mlir::createCanonicalizerPass());
// Shardy is currently operating on stablehlo, since this is what JAX
// emits. Long term shardy will be fully dialect agnostic, and both mhlo
// and stablehlo can register their ops for sdy propagation.
pm.addPass(mlir::mhlo::createHloLegalizeToStablehloPass());
pm.addPass(mlir::createSymbolDCEPass());
}

void addCommonPostImportPasses(mlir::OpPassManager& pm) {
Expand Down
9 changes: 4 additions & 5 deletions xla/service/spmd/shardy/sdy_round_trip/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,6 @@ cc_library(
srcs = ["export_shardy_attrs.cc"],
hdrs = ["export_shardy_attrs.h"],
deps = [
"//xla/mlir_hlo",
"//xla/service/spmd/shardy:constants",
"//xla/service/spmd/shardy:utils",
"@llvm-project//llvm:Support",
Expand All @@ -30,6 +29,7 @@ cc_library(
"@llvm-project//mlir:Support",
"@llvm-project//mlir:TransformUtils",
"@shardy//shardy/dialect/sdy/ir:dialect",
"@stablehlo//:stablehlo_ops",
],
)

Expand All @@ -38,7 +38,6 @@ cc_library(
srcs = ["export_ops.cc"],
hdrs = ["export_ops.h"],
deps = [
"//xla/mlir_hlo",
"//xla/service/spmd/shardy:constants",
"//xla/service/spmd/shardy:utils",
"@llvm-project//llvm:Support",
Expand All @@ -47,6 +46,7 @@ cc_library(
"@llvm-project//mlir:Support",
"@llvm-project//mlir:TransformUtils",
"@shardy//shardy/dialect/sdy/ir:dialect",
"@stablehlo//:stablehlo_ops",
],
)

Expand All @@ -55,8 +55,6 @@ cc_library(
srcs = ["import_shardy_attrs.cc"],
hdrs = ["import_shardy_attrs.h"],
deps = [
"//xla/mlir_hlo",
"//xla/mlir_hlo:mhlo_passes",
"//xla/service/spmd/shardy:constants",
"//xla/service/spmd/shardy:utils",
"@llvm-project//llvm:Support",
Expand All @@ -68,6 +66,7 @@ cc_library(
"@llvm-project//mlir:Support",
"@llvm-project//mlir:TransformUtils",
"@shardy//shardy/dialect/sdy/ir:dialect",
"@stablehlo//:stablehlo_ops",
],
)

Expand All @@ -94,7 +93,6 @@ cc_library(
srcs = ["shard_map_import.cc"],
hdrs = ["shard_map_import.h"],
deps = [
"//xla/mlir_hlo",
"//xla/service/spmd/shardy:constants",
"//xla/service/spmd/shardy:utils",
"@com_google_absl//absl/log:check",
Expand All @@ -106,6 +104,7 @@ cc_library(
"@llvm-project//mlir:Support",
"@llvm-project//mlir:TransformUtils",
"@shardy//shardy/dialect/sdy/ir:dialect",
"@stablehlo//:stablehlo_ops",
],
)

Expand Down
Loading

0 comments on commit dffcb9e

Please sign in to comment.