diff --git a/xla/service/spmd/shardy/mhlo_round_trip/BUILD b/xla/service/spmd/shardy/mhlo_round_trip/BUILD
index 388d0d29b0ad5a..de91a1f5a33eef 100644
--- a/xla/service/spmd/shardy/mhlo_round_trip/BUILD
+++ b/xla/service/spmd/shardy/mhlo_round_trip/BUILD
@@ -54,6 +54,7 @@ cc_library(
         "@llvm-project//mlir:Support",
         "@llvm-project//mlir:TransformUtils",
         "@shardy//shardy/dialect/sdy/ir:dialect",
+        "@stablehlo//:stablehlo_ops",
    ],
)

@@ -70,13 +71,13 @@ cc_library(
         "//xla/service/spmd/shardy:constants",
         "@com_google_absl//absl/log:check",
         "@llvm-project//llvm:Support",
-        "@llvm-project//mlir:FuncDialect",
         "@llvm-project//mlir:IR",
         "@llvm-project//mlir:InliningUtils",
         "@llvm-project//mlir:Pass",
         "@llvm-project//mlir:Support",
         "@llvm-project//mlir:TransformUtils",
         "@shardy//shardy/dialect/sdy/ir:dialect",
+        "@stablehlo//:stablehlo_ops",
    ],
)

@@ -88,6 +89,7 @@ cc_library(
        ":export_ops",
        ":export_shardings",
        ":shard_map_export",
+        "//xla/mlir_hlo:mhlo_passes",
        "//xla/service/spmd/shardy/round_trip_common:export_named_computations",
        "@llvm-project//mlir:Pass",
        "@llvm-project//mlir:Support",

@@ -147,7 +149,6 @@ cc_library(
    hdrs = ["shard_map_import.h"],
    deps = [
        "//xla:xla_data_proto_cc",
-        "//xla/mlir_hlo",
        "//xla/mlir_hlo:mhlo_passes",
        "//xla/service/spmd/shardy:constants",
        "@com_google_absl//absl/algorithm:container",
@@ -163,5 +164,6 @@ cc_library(
        "@llvm-project//mlir:Support",
        "@llvm-project//mlir:TransformUtils",
        "@shardy//shardy/dialect/sdy/ir:dialect",
+        "@stablehlo//:stablehlo_ops",
    ],
)
diff --git a/xla/service/spmd/shardy/mhlo_round_trip/export_ops.cc b/xla/service/spmd/shardy/mhlo_round_trip/export_ops.cc
index fbc7beca1bf085..bc93d37128c31a 100644
--- a/xla/service/spmd/shardy/mhlo_round_trip/export_ops.cc
+++ b/xla/service/spmd/shardy/mhlo_round_trip/export_ops.cc
@@ -23,6 +23,7 @@ limitations under the License.
 #include "llvm/ADT/SmallVector.h"
 #include "llvm/ADT/StringRef.h"
 #include "llvm/Support/ErrorHandling.h"
+#include "llvm/Support/LogicalResult.h"
 #include "mlir/IR/AffineMap.h"
 #include "mlir/IR/Attributes.h"
 #include "mlir/IR/Builders.h"
@@ -45,6 +46,7 @@ limitations under the License.
 #include "shardy/dialect/sdy/ir/constants.h"
 #include "shardy/dialect/sdy/ir/dialect.h"
 #include "shardy/dialect/sdy/ir/utils.h"
+#include "stablehlo/dialect/StablehloOps.h"
 #include "xla/mlir_hlo/mhlo/IR/hlo_ops.h"
 #include "xla/service/spmd/shardy/constants.h"
 #include "xla/sharding_op_util.h"
@@ -54,6 +56,7 @@ namespace sdy {

 namespace {

+namespace stablehlo = ::mlir::stablehlo;
 namespace mhlo = ::mlir::mhlo;

 using ::mlir::ConversionPatternRewriter;
@@ -73,7 +76,7 @@ using ::mlir::sdy::ShardingConstraintOp;
 using ::mlir::sdy::TensorShardingAttr;
 using ::mlir::sdy::TensorShardingPerValueAttr;

-// Converts `sdy::ConstantOp` to `mhlo::ConstantOp`.
+// Converts `sdy::ConstantOp` to `stablehlo::ConstantOp`.
 class ConstantPattern : public OpConversionPattern<ConstantOp> {
   using OpConversionPattern::OpConversionPattern;

@@ -82,7 +85,7 @@ class ConstantPattern : public OpConversionPattern<ConstantOp> {
       ConversionPatternRewriter& rewriter) const override {
     // We use the generic op builder so that unregistered attributes will be
     // added to the new op.
-    rewriter.replaceOpWithNewOp<mhlo::ConstantOp>(
+    rewriter.replaceOpWithNewOp<stablehlo::ConstantOp>(
         op, op->getResultTypes(), adaptor.getOperands(), op->getAttrs());
     return success();
   }
@@ -134,7 +137,7 @@ class ExportOpsPass
     // ShardingConstraintOp should be replaced by ReshardOp before this pass.
     // Hence, we add ShardingConstraintOp as an illegal op.
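    // (In a dialect conversion, an op marked illegal must be rewritten away
    // before the pass finishes, so any `sdy.sharding_constraint` that
    // survives to this point makes the pass fail rather than being silently
    // exported.)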
    target.addIllegalOp<ShardingConstraintOp>();
-    target.addLegalOp<mhlo::ConstantOp>();
+    target.addLegalOp<stablehlo::ConstantOp>();
     mlir::RewritePatternSet patterns(&context);
     // After converting `sdy.constant` into `mhlo.constant`, the constants
     // should not be deduped via folding. Fortunately, folding only happens in
diff --git a/xla/service/spmd/shardy/mhlo_round_trip/mhlo_export.cc b/xla/service/spmd/shardy/mhlo_round_trip/mhlo_export.cc
index 36aee9a64f266b..67f79119ebda6b 100644
--- a/xla/service/spmd/shardy/mhlo_round_trip/mhlo_export.cc
+++ b/xla/service/spmd/shardy/mhlo_round_trip/mhlo_export.cc
@@ -20,6 +20,7 @@ limitations under the License.
 #include "mlir/Pass/PassManager.h"
 #include "mlir/Pass/PassRegistry.h"
 #include "mlir/Support/LLVM.h"
+#include "xla/mlir_hlo/mhlo/transforms/passes.h"
 #include "xla/service/spmd/shardy/mhlo_round_trip/export_ops.h"
 #include "xla/service/spmd/shardy/mhlo_round_trip/export_shardings.h"
 #include "xla/service/spmd/shardy/mhlo_round_trip/shard_map_export.h"
@@ -36,6 +37,7 @@ void addMhloExportPipeline(mlir::OpPassManager& pm) {
   pm.addPass(createMhloRoundTripShardMapExportPass());
   pm.addPass(createExportNamedComputationsPass());
   pm.addPass(createExportMhloShardingsPass());
+  pm.addPass(mlir::mhlo::createStablehloLegalizeToHloPass());
 }

 void registerMhloExportPipeline() {
diff --git a/xla/service/spmd/shardy/mhlo_round_trip/mhlo_import.cc b/xla/service/spmd/shardy/mhlo_round_trip/mhlo_import.cc
index 1f0cff4c61a75c..8091fac253130d 100644
--- a/xla/service/spmd/shardy/mhlo_round_trip/mhlo_import.cc
+++ b/xla/service/spmd/shardy/mhlo_round_trip/mhlo_import.cc
@@ -658,8 +658,8 @@ void addMhloImportPipeline(mlir::OpPassManager& pm,
 void registerMhloImportPipeline() {
   mlir::PassPipelineRegistration<> importPipeline(
       "xla-sdy-mhlo-import-pipeline",
-      "Run passes to import an mhlo module with `mhlo.shardings` into the SDY "
-      "(Shardy) dialect.",
+      "Run passes to import a StableHLO module with `mhlo.shardings` into the "
+      "SDY (Shardy) dialect.",
       std::bind(addMhloImportPipeline, std::placeholders::_1, ArrayRef(),
                 ArrayRef()));
 }
diff --git a/xla/service/spmd/shardy/mhlo_round_trip/shard_map_export.cc b/xla/service/spmd/shardy/mhlo_round_trip/shard_map_export.cc
index e70720f4e8aa1d..73f48c698f9939 100644
--- a/xla/service/spmd/shardy/mhlo_round_trip/shard_map_export.cc
+++ b/xla/service/spmd/shardy/mhlo_round_trip/shard_map_export.cc
@@ -51,6 +51,7 @@ limitations under the License.
 #include "shardy/dialect/sdy/ir/constants.h"
 #include "shardy/dialect/sdy/ir/dialect.h"
 #include "shardy/dialect/sdy/ir/utils.h"
+#include "stablehlo/dialect/StablehloOps.h"
 #include "xla/hlo/ir/hlo_sharding.h"
 #include "xla/mlir_hlo/mhlo/IR/hlo_ops.h"
 #include "xla/mlir_hlo/mhlo/transforms/passes.h"
@@ -73,7 +74,7 @@ using ::mlir::StringAttr;
 using ::mlir::StringRef;
 using ::mlir::Value;
 using ::mlir::mhlo::CopyOp;
-using ::mlir::mhlo::CustomCallOp;
+using ::mlir::stablehlo::CustomCallOp;

 namespace sdy = ::mlir::sdy;
 using sdy::kShardingAttr;
diff --git a/xla/service/spmd/shardy/mhlo_round_trip/shard_map_import.cc b/xla/service/spmd/shardy/mhlo_round_trip/shard_map_import.cc
index a8098832a71d5a..d12f194e023f46 100644
--- a/xla/service/spmd/shardy/mhlo_round_trip/shard_map_import.cc
+++ b/xla/service/spmd/shardy/mhlo_round_trip/shard_map_import.cc
@@ -53,7 +53,7 @@ limitations under the License.
#include "shardy/dialect/sdy/ir/constants.h" #include "shardy/dialect/sdy/ir/dialect.h" #include "shardy/dialect/sdy/ir/utils.h" -#include "xla/mlir_hlo/mhlo/IR/hlo_ops.h" +#include "stablehlo/dialect/StablehloOps.h" #include "xla/mlir_hlo/mhlo/transforms/passes.h" #include "xla/service/spmd/shardy/constants.h" #include "xla/xla_data.pb.h" @@ -73,7 +73,7 @@ using ::mlir::StringRef; using ::mlir::Value; using ::mlir::func::CallOp; using ::mlir::func::FuncOp; -using ::mlir::mhlo::CustomCallOp; +using ::mlir::stablehlo::CustomCallOp; namespace sdy = ::mlir::sdy; using sdy::AxisRefAttr; diff --git a/xla/service/spmd/shardy/round_trip_common/BUILD b/xla/service/spmd/shardy/round_trip_common/BUILD index af119242aa3437..f4dbc544630d56 100644 --- a/xla/service/spmd/shardy/round_trip_common/BUILD +++ b/xla/service/spmd/shardy/round_trip_common/BUILD @@ -19,7 +19,6 @@ cc_library( hdrs = ["import_sdy_custom_calls.h"], deps = [ "//xla:sharding_op_util", - "//xla/mlir_hlo", "//xla/service/spmd/shardy:constants", "//xla/service/spmd/shardy:utils", "@com_google_absl//absl/log:check", @@ -29,6 +28,7 @@ cc_library( "@llvm-project//mlir:Support", "@llvm-project//mlir:TransformUtils", "@shardy//shardy/dialect/sdy/ir:dialect", + "@stablehlo//:stablehlo_ops", ], ) @@ -86,7 +86,6 @@ cc_library( srcs = ["open_while_free_vars_sharding.cc"], hdrs = ["open_while_free_vars_sharding.h"], deps = [ - "//xla/mlir_hlo", "@llvm-project//llvm:Support", "@llvm-project//mlir:FuncDialect", "@llvm-project//mlir:IR", @@ -94,6 +93,7 @@ cc_library( "@llvm-project//mlir:Support", "@llvm-project//mlir:TransformUtils", "@shardy//shardy/dialect/sdy/ir:dialect", + "@stablehlo//:stablehlo_ops", ], ) diff --git a/xla/service/spmd/shardy/round_trip_common/import_constants.h b/xla/service/spmd/shardy/round_trip_common/import_constants.h index 3de4603894bb9b..a83869ca3e93b0 100644 --- a/xla/service/spmd/shardy/round_trip_common/import_constants.h +++ b/xla/service/spmd/shardy/round_trip_common/import_constants.h @@ -23,8 +23,8 @@ limitations under the License. namespace xla { namespace sdy { -// Creates a pass that converts an `mhlo.constant` (which is foldable) into an -// `sdy.constant` (which isn't foldable). +// Creates a pass that converts a `stablehlo.constant` (which is foldable) into +// an `sdy.constant` (which isn't foldable). std::unique_ptr createImportConstantsPass(); // Register the xla-sdy-import-constants pass. diff --git a/xla/service/spmd/shardy/round_trip_common/import_sdy_custom_calls.cc b/xla/service/spmd/shardy/round_trip_common/import_sdy_custom_calls.cc index 8172a217e30a91..4a36c2ba3b1583 100644 --- a/xla/service/spmd/shardy/round_trip_common/import_sdy_custom_calls.cc +++ b/xla/service/spmd/shardy/round_trip_common/import_sdy_custom_calls.cc @@ -35,7 +35,7 @@ limitations under the License. 
#include "mlir/Transforms/DialectConversion.h" #include "shardy/dialect/sdy/ir/dialect.h" #include "shardy/dialect/sdy/ir/utils.h" -#include "xla/mlir_hlo/mhlo/IR/hlo_ops.h" +#include "stablehlo/dialect/StablehloOps.h" #include "xla/service/spmd/shardy/constants.h" #include "xla/service/spmd/shardy/utils.h" #include "xla/sharding_op_util.h" @@ -47,11 +47,11 @@ namespace { using ::mlir::IntegerAttr; using ::mlir::StringRef; -using ::mlir::mhlo::CustomCallOp; using ::mlir::sdy::ShardingConstraintOp; using ::mlir::sdy::ShardingGroupOp; using ::mlir::sdy::TensorShardingAttr; -using ::mlir::mhlo::CustomCallOpAdaptor; +using ::mlir::stablehlo::CustomCallOp; +using ::mlir::stablehlo::CustomCallOpAdaptor; mlir::LogicalResult rewriteShardingCustomCall( CustomCallOp op, CustomCallOpAdaptor adaptor, diff --git a/xla/service/spmd/shardy/round_trip_common/open_while_free_vars_sharding.cc b/xla/service/spmd/shardy/round_trip_common/open_while_free_vars_sharding.cc index 603b270eefa46f..6fe201ccb4fb4d 100644 --- a/xla/service/spmd/shardy/round_trip_common/open_while_free_vars_sharding.cc +++ b/xla/service/spmd/shardy/round_trip_common/open_while_free_vars_sharding.cc @@ -28,7 +28,7 @@ limitations under the License. #include "mlir/Transforms/RegionUtils.h" #include "shardy/dialect/sdy/ir/dialect.h" #include "shardy/dialect/sdy/ir/utils.h" -#include "xla/mlir_hlo/mhlo/IR/hlo_ops.h" +#include "stablehlo/dialect/StablehloOps.h" namespace xla { namespace sdy { @@ -49,7 +49,7 @@ class OpenWhileFreeVarsShardingPass FuncOp funcOp = getOperation(); mlir::IRRewriter rewriter(funcOp); - funcOp.walk([&](mlir::mhlo::WhileOp op) { + funcOp.walk([&](mlir::stablehlo::WhileOp op) { llvm::SetVector freeVars; mlir::getUsedValuesDefinedAbove(op->getRegions(), freeVars); rewriter.setInsertionPoint(op); diff --git a/xla/service/spmd/shardy/round_trip_common/pipeline_passes.cc b/xla/service/spmd/shardy/round_trip_common/pipeline_passes.cc index 68592c1918a3e3..b07d59575f88de 100644 --- a/xla/service/spmd/shardy/round_trip_common/pipeline_passes.cc +++ b/xla/service/spmd/shardy/round_trip_common/pipeline_passes.cc @@ -30,25 +30,28 @@ namespace sdy { using ::mlir::func::FuncOp; void addCommonPreImportPasses(mlir::OpPassManager& pm) { - pm.addPass(mlir::createSymbolDCEPass()); // TODO(b/333505182): remove when partitioning is done in SDY. // We call prepare-for-export pass before SDY propagation, so that all IR // changes happen before shardings are added to operations, to ensure the // correct shardings are added and that they are not lost by this pass. pm.addNestedPass(mlir::mhlo::createPrepareForExportPass()); - - // We import `mhlo.constant` ops to `sdy.constant` ops so that constants + // We import `stablehlo.constant` ops to `sdy.constant` ops so that constants // aren't folded in greedy pattern rewriters, which would lift them outside of // nested regions (this undoes `WhileLoopConstantSinking` HLO pass). - // Therefore, this pass needs to be applied after any mhlo pass that expects - // `mhlo.constant`, and before any pass that has a greedy pattern rewriter. + // Therefore, this pass needs to be applied after any stablehlo pass that + // expects `stablehlo.constant`, and before any pass that has a greedy pattern + // rewriter. pm.addNestedPass(createImportConstantsPass()); - pm.addNestedPass(mlir::mhlo::createFlattenTuplePass()); // We need to canonicalize redundant mhlo::GetTupleElementOp and // mhlo::GetTupleOp. We also need to canonicalize mhlo::WhileOp before // `createOpenWhileFreeVarsShardingPass`. 
pm.addPass(mlir::createCanonicalizerPass()); + // Shardy is currently operating on stablehlo, since this is what JAX + // emits. Long term shardy will be fully dialect agnostic, and both mhlo + // and stablehlo can register their ops for sdy propagation. + pm.addPass(mlir::mhlo::createHloLegalizeToStablehloPass()); + pm.addPass(mlir::createSymbolDCEPass()); } void addCommonPostImportPasses(mlir::OpPassManager& pm) { diff --git a/xla/service/spmd/shardy/sdy_round_trip/BUILD b/xla/service/spmd/shardy/sdy_round_trip/BUILD index 25c3928386d1df..d6f318ed80cd42 100644 --- a/xla/service/spmd/shardy/sdy_round_trip/BUILD +++ b/xla/service/spmd/shardy/sdy_round_trip/BUILD @@ -20,7 +20,6 @@ cc_library( srcs = ["export_shardy_attrs.cc"], hdrs = ["export_shardy_attrs.h"], deps = [ - "//xla/mlir_hlo", "//xla/service/spmd/shardy:constants", "//xla/service/spmd/shardy:utils", "@llvm-project//llvm:Support", @@ -30,6 +29,7 @@ cc_library( "@llvm-project//mlir:Support", "@llvm-project//mlir:TransformUtils", "@shardy//shardy/dialect/sdy/ir:dialect", + "@stablehlo//:stablehlo_ops", ], ) @@ -38,7 +38,6 @@ cc_library( srcs = ["export_ops.cc"], hdrs = ["export_ops.h"], deps = [ - "//xla/mlir_hlo", "//xla/service/spmd/shardy:constants", "//xla/service/spmd/shardy:utils", "@llvm-project//llvm:Support", @@ -47,6 +46,7 @@ cc_library( "@llvm-project//mlir:Support", "@llvm-project//mlir:TransformUtils", "@shardy//shardy/dialect/sdy/ir:dialect", + "@stablehlo//:stablehlo_ops", ], ) @@ -55,8 +55,6 @@ cc_library( srcs = ["import_shardy_attrs.cc"], hdrs = ["import_shardy_attrs.h"], deps = [ - "//xla/mlir_hlo", - "//xla/mlir_hlo:mhlo_passes", "//xla/service/spmd/shardy:constants", "//xla/service/spmd/shardy:utils", "@llvm-project//llvm:Support", @@ -68,6 +66,7 @@ cc_library( "@llvm-project//mlir:Support", "@llvm-project//mlir:TransformUtils", "@shardy//shardy/dialect/sdy/ir:dialect", + "@stablehlo//:stablehlo_ops", ], ) @@ -94,7 +93,6 @@ cc_library( srcs = ["shard_map_import.cc"], hdrs = ["shard_map_import.h"], deps = [ - "//xla/mlir_hlo", "//xla/service/spmd/shardy:constants", "//xla/service/spmd/shardy:utils", "@com_google_absl//absl/log:check", @@ -106,6 +104,7 @@ cc_library( "@llvm-project//mlir:Support", "@llvm-project//mlir:TransformUtils", "@shardy//shardy/dialect/sdy/ir:dialect", + "@stablehlo//:stablehlo_ops", ], ) diff --git a/xla/service/spmd/shardy/sdy_round_trip/export_ops.cc b/xla/service/spmd/shardy/sdy_round_trip/export_ops.cc index 67c4bc63b86802..0af87ed18371c3 100644 --- a/xla/service/spmd/shardy/sdy_round_trip/export_ops.cc +++ b/xla/service/spmd/shardy/sdy_round_trip/export_ops.cc @@ -40,11 +40,11 @@ limitations under the License. #include "mlir/Transforms/DialectConversion.h" #include "shardy/dialect/sdy/ir/constants.h" #include "shardy/dialect/sdy/ir/dialect.h" -#include "xla/mlir_hlo/mhlo/IR/hlo_ops.h" +#include "stablehlo/dialect/StablehloOps.h" #include "xla/service/spmd/shardy/constants.h" #include "xla/service/spmd/shardy/utils.h" -namespace mhlo = ::mlir::mhlo; +namespace stablehlo = ::mlir::stablehlo; namespace xla { namespace sdy { @@ -67,7 +67,7 @@ using ::mlir::sdy::ShardingGroupOp; using ::mlir::sdy::TensorShardingAttr; using ::mlir::sdy::TensorShardingPerValueAttr; -// Converts `sdy::ConstantOp` to `mhlo::ConstantOp`. +// Converts `sdy::ConstantOp` to `stablehlo::ConstantOp`. 
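// For example, a sketch of the rewrite performed by the pattern below:
//   %0 = sdy.constant dense<1> : tensor<4xi32>
// becomes
//   %0 = stablehlo.constant dense<1> : tensor<4xi32>
// with all attributes carried over verbatim by the generic op builder.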
class ConstantPattern : public OpConversionPattern<ConstantOp> {
  using OpConversionPattern::OpConversionPattern;

@@ -76,7 +76,7 @@ class ConstantPattern : public OpConversionPattern<ConstantOp> {
       ConversionPatternRewriter& rewriter) const override {
     // We use the generic op builder so that unregistered attributes will be
     // added to the new op.
-    rewriter.replaceOpWithNewOp<mhlo::ConstantOp>(
+    rewriter.replaceOpWithNewOp<stablehlo::ConstantOp>(
         op, op->getResultTypes(), adaptor.getOperands(), op->getAttrs());
     return success();
   }
@@ -93,7 +93,7 @@ class ShardingConstraintPattern
       ConversionPatternRewriter& rewriter) const override {
     TensorShardingAttr sharding = op.getSharding();

-    auto customCallOp = rewriter.replaceOpWithNewOp<mhlo::CustomCallOp>(
+    auto customCallOp = rewriter.replaceOpWithNewOp<stablehlo::CustomCallOp>(
         op, op.getType(), adaptor.getInput());

     customCallOp.setCallTargetName(kShardingCustomCallTargetName);
@@ -117,7 +117,7 @@ class ShardingGroupPattern : public OpConversionPattern<ShardingGroupOp> {
   LogicalResult matchAndRewrite(
       ShardingGroupOp op, OpAdaptor adaptor,
       ConversionPatternRewriter& rewriter) const override {
-    auto customCallOp = rewriter.replaceOpWithNewOp<mhlo::CustomCallOp>(
+    auto customCallOp = rewriter.replaceOpWithNewOp<stablehlo::CustomCallOp>(
         op, op->getResultTypes(), adaptor.getInput());

     customCallOp.setCallTargetName(kShardingGroupCustomCallTargetName);
@@ -137,7 +137,7 @@ class SdyRoundTripExportOpsPass
     mlir::MLIRContext& context = getContext();
     mlir::ConversionTarget target(context);
     target.addIllegalOp<ConstantOp, ShardingConstraintOp, ShardingGroupOp>();
-    target.addLegalOp<mhlo::ConstantOp, mhlo::CustomCallOp>();
+    target.addLegalOp<stablehlo::ConstantOp, stablehlo::CustomCallOp>();
     mlir::RewritePatternSet patterns(&context);
     patterns
         .add<ConstantPattern, ShardingConstraintPattern, ShardingGroupPattern>(
             &context);
diff --git a/xla/service/spmd/shardy/sdy_round_trip/export_shardy_attrs.cc b/xla/service/spmd/shardy/sdy_round_trip/export_shardy_attrs.cc
index 8474d3efb0e6e2..f2ae7ee6a221fc 100644
--- a/xla/service/spmd/shardy/sdy_round_trip/export_shardy_attrs.cc
+++ b/xla/service/spmd/shardy/sdy_round_trip/export_shardy_attrs.cc
@@ -43,7 +43,7 @@ limitations under the License.
 #include "shardy/dialect/sdy/ir/constants.h"
 #include "shardy/dialect/sdy/ir/dialect.h"
 #include "shardy/dialect/sdy/ir/utils.h"
-#include "xla/mlir_hlo/mhlo/IR/hlo_ops.h"
+#include "stablehlo/dialect/StablehloOps.h"
 #include "xla/service/spmd/shardy/constants.h"
 #include "xla/service/spmd/shardy/utils.h"

@@ -66,7 +66,7 @@ using ::mlir::StringRef;
 using ::mlir::Value;
 using ::mlir::func::FuncOp;

-using ::mlir::mhlo::CustomCallOp;
+using ::mlir::stablehlo::CustomCallOp;
 using ::mlir::sdy::kShardingAttr;
 using ::mlir::sdy::kShardingRuleAttr;

@@ -177,7 +177,7 @@ class SdyRoundTripExportShardyAttrsPass
   }

   void getDependentDialects(mlir::DialectRegistry& registry) const final {
-    registry.insert<mlir::mhlo::MhloDialect>();
+    registry.insert<mlir::stablehlo::StablehloDialect>();
   }
 };
diff --git a/xla/service/spmd/shardy/sdy_round_trip/import_shardy_attrs.cc b/xla/service/spmd/shardy/sdy_round_trip/import_shardy_attrs.cc
index 26f3539163b15f..a9a7f3003fb562 100644
--- a/xla/service/spmd/shardy/sdy_round_trip/import_shardy_attrs.cc
+++ b/xla/service/spmd/shardy/sdy_round_trip/import_shardy_attrs.cc
@@ -45,8 +45,7 @@
#include "mlir/Transforms/DialectConversion.h" #include "shardy/dialect/sdy/ir/constants.h" #include "shardy/dialect/sdy/ir/dialect.h" -#include "xla/mlir_hlo/mhlo/IR/hlo_ops.h" -#include "xla/mlir_hlo/mhlo/transforms/passes.h" +#include "stablehlo/dialect/StablehloOps.h" #include "xla/service/spmd/shardy/constants.h" #include "xla/service/spmd/shardy/utils.h" @@ -66,7 +65,7 @@ using ::mlir::StringRef; using ::mlir::SymbolTable; using ::mlir::func::FuncOp; -using ::mlir::mhlo::CustomCallOp; +using ::mlir::stablehlo::CustomCallOp; using ::mlir::sdy::kShardingAttr; using ::mlir::sdy::kShardingRuleAttr; diff --git a/xla/service/spmd/shardy/sdy_round_trip/shard_map_import.cc b/xla/service/spmd/shardy/sdy_round_trip/shard_map_import.cc index 27b1090af56819..1d547c7842a7b1 100644 --- a/xla/service/spmd/shardy/sdy_round_trip/shard_map_import.cc +++ b/xla/service/spmd/shardy/sdy_round_trip/shard_map_import.cc @@ -44,7 +44,7 @@ limitations under the License. #include "mlir/Transforms/DialectConversion.h" #include "shardy/dialect/sdy/ir/dialect.h" #include "shardy/dialect/sdy/ir/utils.h" -#include "xla/mlir_hlo/mhlo/IR/hlo_ops.h" +#include "stablehlo/dialect/StablehloOps.h" #include "xla/service/spmd/shardy/constants.h" #include "xla/service/spmd/shardy/utils.h" @@ -60,7 +60,7 @@ using ::mlir::StringRef; using ::mlir::SymbolTable; using ::mlir::func::CallOp; using ::mlir::func::FuncOp; -using ::mlir::mhlo::CustomCallOp; +using ::mlir::stablehlo::CustomCallOp; namespace sdy = ::mlir::sdy; diff --git a/xla/service/spmd/shardy/sdy_round_trip/test_utils/BUILD b/xla/service/spmd/shardy/sdy_round_trip/test_utils/BUILD index 637b43ea2a834d..207205bd537a12 100644 --- a/xla/service/spmd/shardy/sdy_round_trip/test_utils/BUILD +++ b/xla/service/spmd/shardy/sdy_round_trip/test_utils/BUILD @@ -49,6 +49,7 @@ cc_library( hdrs = ["testing_pipeline.h"], deps = [ ":mhlo_to_hlo_to_mhlo", + "//xla/mlir_hlo:mhlo_passes", "//xla/service/spmd/shardy/sdy_round_trip:pipelines", "@llvm-project//mlir:Pass", ], diff --git a/xla/service/spmd/shardy/sdy_round_trip/test_utils/testing_pipeline.cc b/xla/service/spmd/shardy/sdy_round_trip/test_utils/testing_pipeline.cc index b4e25bafa8c872..984186cb626c2d 100644 --- a/xla/service/spmd/shardy/sdy_round_trip/test_utils/testing_pipeline.cc +++ b/xla/service/spmd/shardy/sdy_round_trip/test_utils/testing_pipeline.cc @@ -17,6 +17,7 @@ limitations under the License. #include "mlir/Pass/PassManager.h" #include "mlir/Pass/PassRegistry.h" +#include "xla/mlir_hlo/mhlo/transforms/passes.h" #include "xla/service/spmd/shardy/sdy_round_trip/pipelines.h" #include "xla/service/spmd/shardy/sdy_round_trip/test_utils/mhlo_to_hlo_to_mhlo.h" @@ -30,6 +31,7 @@ void registerSdyRoundTripTestingPipeline() { "MHLO, then import back to Shardy", [](mlir::OpPassManager& pm) { addSdyRoundTripExportPipeline(pm); + pm.addPass(mlir::mhlo::createStablehloLegalizeToHloPass()); pm.addPass(createSdyRoundTripMhloToHloToMhloPass()); addSdyRoundTripImportPipeline(pm); }); diff --git a/xla/service/spmd/shardy/shardy_xla_pass.cc b/xla/service/spmd/shardy/shardy_xla_pass.cc index bf31d6ca678e16..9535558efa801f 100644 --- a/xla/service/spmd/shardy/shardy_xla_pass.cc +++ b/xla/service/spmd/shardy/shardy_xla_pass.cc @@ -24,7 +24,6 @@ limitations under the License. 
#include #include -#include "mhlo/transforms/passes.h" #include "absl/algorithm/container.h" #include "absl/container/flat_hash_map.h" #include "absl/container/flat_hash_set.h" @@ -379,17 +378,12 @@ absl::StatusOr ShardyXLA::Run( useTupleArgs); if (runSdyShardingPropagation) { - // Shardy is currently operating on stablehlo, since this is what JAX - // emits. Long term shardy will be fully dialect agnostic, and both mhlo - // and stablehlo can register their ops for sdy propagation. - pm.addPass(mlir::mhlo::createHloLegalizeToStablehloPass()); // NOTE: if we are using auto-spmd, we will use conservative propagation // since the TOAST cost model cannot account for split axes or padding. mlir::sdy::PropagationOptions options; options.dumpDirectory = shardyDir; options.conservativePropagation = hloModule->use_auto_spmd_partitioning(); mlir::sdy::addPropagationPipeline(pm, options); - pm.addPass(mlir::mhlo::createStablehloLegalizeToHloPass()); } addMhloExportPipeline(pm); pm.addPass(mlir::sdy::createSaveModuleOpPass(shardyDir, diff --git a/xla/service/spmd/shardy/test/import_backend_func_calls.mlir b/xla/service/spmd/shardy/test/import_backend_func_calls.mlir index 35c4d62e8d099d..9ab41e20ce0a19 100644 --- a/xla/service/spmd/shardy/test/import_backend_func_calls.mlir +++ b/xla/service/spmd/shardy/test/import_backend_func_calls.mlir @@ -5,41 +5,41 @@ sdy.mesh @mesh = #sdy.mesh<["x"=2, "y"=2]> // CHECK-LABEL: func @no_out_shardings func.func @no_out_shardings(%arg0: tensor<8x2xi32> {sdy.sharding = #sdy.sharding<@mesh, [{"x"}, {"y"}]>}) -> (tensor<8x2xi32> {sdy.sharding = #sdy.sharding<@mesh, [{"x"}, {"y"}]>}) { // CHECK-NEXT: %[[NC:.*]] = sdy.named_computation<"foo">(%arg0) (%arg1: tensor<8x2xi32>) { - // CHECK-NEXT: %[[MULT:.*]] = mhlo.multiply %arg1, %arg1 {mhlo.frontend_attributes = {_xla_compute_type = "host"}} : tensor<8x2xi32> + // CHECK-NEXT: %[[MULT:.*]] = stablehlo.multiply %arg1, %arg1 {mhlo.frontend_attributes = {_xla_compute_type = "host"}} : tensor<8x2xi32> // CHECK-NEXT: sdy.return %[[MULT]] : tensor<8x2xi32> // CHECK-NEXT: } {mhlo.frontend_attributes = {backend_config = "{\22flag_configs\22:[],\22scoped_memory_configs\22:[],\22device_type\22:\22DEVICE_TYPE_HOST\22,\22used_scoped_memory_configs\22:[]}"}, // CHECK-SAME: random_attr = "random_value"} // CHECK-SAME: (tensor<8x2xi32>) -> tensor<8x2xi32> - // CHECK-NEXT: %[[MOVE_TO_HOST:.*]] = mhlo.custom_call @MoveToHost(%[[NC]]) {backend_config = ""} : (tensor<8x2xi32>) -> tensor<8x2xi32> + // CHECK-NEXT: %[[MOVE_TO_HOST:.*]] = stablehlo.custom_call @MoveToHost(%[[NC]]) {backend_config = ""} : (tensor<8x2xi32>) -> tensor<8x2xi32> // CHECK-NEXT: return %[[MOVE_TO_HOST]] : tensor<8x2xi32> %0 = call @foo(%arg0) {random_attr = "random_value", mhlo.frontend_attributes = {backend_config = "{\22flag_configs\22:[],\22scoped_memory_configs\22:[],\22device_type\22:\22DEVICE_TYPE_HOST\22,\22used_scoped_memory_configs\22:[]}"}} : (tensor<8x2xi32>) -> tensor<8x2xi32> - %1 = mhlo.custom_call @MoveToHost(%0) {backend_config = ""} : (tensor<8x2xi32>) -> tensor<8x2xi32> + %1 = stablehlo.custom_call @MoveToHost(%0) {backend_config = ""} : (tensor<8x2xi32>) -> tensor<8x2xi32> return %1 : tensor<8x2xi32> } func.func private @foo(%arg0: tensor<8x2xi32>) -> tensor<8x2xi32> { - %0 = mhlo.multiply %arg0, %arg0 {mhlo.frontend_attributes = {_xla_compute_type = "host"}} : tensor<8x2xi32> + %0 = stablehlo.multiply %arg0, %arg0 {mhlo.frontend_attributes = {_xla_compute_type = "host"}} : tensor<8x2xi32> return %0 : tensor<8x2xi32> } // CHECK-LABEL: func 
@out_shardings func.func @out_shardings(%arg0: tensor<8x2xi32> {sdy.sharding = #sdy.sharding<@mesh, [{"x"}, {"y"}]>}) -> (tensor<8x2xi32> {sdy.sharding = #sdy.sharding<@mesh, [{"x"}, {"y"}]>}) { // CHECK-NEXT: %[[NC:.*]] = sdy.named_computation<"bar">(%arg0) out_shardings=[<@mesh, [{"x"}, {"y"}]>] (%arg1: tensor<8x2xi32>) { - // CHECK-NEXT: %[[MULT:.*]] = mhlo.multiply %arg1, %arg1 {mhlo.frontend_attributes = {_xla_compute_type = "host"}} : tensor<8x2xi32> + // CHECK-NEXT: %[[MULT:.*]] = stablehlo.multiply %arg1, %arg1 {mhlo.frontend_attributes = {_xla_compute_type = "host"}} : tensor<8x2xi32> // CHECK-NEXT: sdy.return %[[MULT]] : tensor<8x2xi32> // CHECK-NEXT: } {mhlo.frontend_attributes = {backend_config = "{\22flag_configs\22:[],\22scoped_memory_configs\22:[],\22device_type\22:\22DEVICE_TYPE_HOST\22,\22used_scoped_memory_configs\22:[]}"}, // CHECK-SAME: random_attr = "random_value"} // CHECK-SAME: (tensor<8x2xi32>) -> tensor<8x2xi32> - // CHECK-NEXT: %[[MOVE_TO_HOST:.*]] = mhlo.custom_call @MoveToHost(%[[NC]]) {backend_config = ""} : (tensor<8x2xi32>) -> tensor<8x2xi32> + // CHECK-NEXT: %[[MOVE_TO_HOST:.*]] = stablehlo.custom_call @MoveToHost(%[[NC]]) {backend_config = ""} : (tensor<8x2xi32>) -> tensor<8x2xi32> // CHECK-NEXT: return %[[MOVE_TO_HOST]] : tensor<8x2xi32> %0 = call @bar(%arg0) {sdy.sharding = #sdy.sharding_per_value<[<@mesh, [{"x"}, {"y"}]>]>, random_attr = "random_value", mhlo.frontend_attributes = {backend_config = "{\22flag_configs\22:[],\22scoped_memory_configs\22:[],\22device_type\22:\22DEVICE_TYPE_HOST\22,\22used_scoped_memory_configs\22:[]}"}} : (tensor<8x2xi32>) -> tensor<8x2xi32> - %1 = mhlo.custom_call @MoveToHost(%0) {backend_config = ""} : (tensor<8x2xi32>) -> tensor<8x2xi32> + %1 = stablehlo.custom_call @MoveToHost(%0) {backend_config = ""} : (tensor<8x2xi32>) -> tensor<8x2xi32> return %1 : tensor<8x2xi32> } // NOTE: we ignore any arg/result shardings on the function. 
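// (In the test above, the out_shardings of the imported named computation
// come from the `sdy.sharding` attribute on the `call @bar` site itself, not
// from the callee-level argument/result shardings below, which are dropped.)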
func.func private @bar(%arg0: tensor<8x2xi32> {sdy.sharding = #sdy.sharding<@mesh, [{"x"}, {}]>}) -> (tensor<8x2xi32> {sdy.sharding = #sdy.sharding<@mesh, [{}, {"y"}]>}) {
-  %0 = mhlo.multiply %arg0, %arg0 {mhlo.frontend_attributes = {_xla_compute_type = "host"}} : tensor<8x2xi32>
+  %0 = stablehlo.multiply %arg0, %arg0 {mhlo.frontend_attributes = {_xla_compute_type = "host"}} : tensor<8x2xi32>
   return %0 : tensor<8x2xi32>
 }

@@ -53,6 +53,6 @@ func.func @no_backend_config(%arg0: tensor<8x2xi32> {sdy.sharding = #sdy.shardin
 }

 func.func private @baz(%arg0: tensor<8x2xi32>) -> tensor<8x2xi32> {
-  %0 = mhlo.multiply %arg0, %arg0 : tensor<8x2xi32>
+  %0 = stablehlo.multiply %arg0, %arg0 : tensor<8x2xi32>
   return %0 : tensor<8x2xi32>
 }
diff --git a/xla/service/spmd/shardy/test/import_shardings.mlir b/xla/service/spmd/shardy/test/import_shardings.mlir
index 9cc62dd41959b7..cabca9b4aaa5d9 100644
--- a/xla/service/spmd/shardy/test/import_shardings.mlir
+++ b/xla/service/spmd/shardy/test/import_shardings.mlir
@@ -10,8 +10,8 @@ func.func @non_trivial_common_mesh(%arg0: tensor<8x8xf32> {mhlo.sharding = "{devices=[4,8]<=[8,4]T(1,0)}"},
                                    %arg1: tensor<8x8xf32> {mhlo.sharding = "{devices=[1,2,16]<=[32] last_tile_dim_replicate}"},
                                    %arg2: tensor<8x16xf32> {mhlo.sharding = "{devices=[4,4,2]<=[2,16]T(1,0) last_tile_dim_replicate}"}) -> tensor<8x16xf32> {
-  %0 = mhlo.add %arg0, %arg1 : tensor<8x8xf32>
-  %1 = "mhlo.dot" (%0, %arg2) : (tensor<8x8xf32>, tensor<8x16xf32>) -> tensor<8x16xf32>
+  %0 = stablehlo.add %arg0, %arg1 : tensor<8x8xf32>
+  %1 = "stablehlo.dot" (%0, %arg2) : (tensor<8x8xf32>, tensor<8x16xf32>) -> tensor<8x16xf32>
   return %1 : tensor<8x16xf32>
 }

@@ -24,10 +24,10 @@ func.func @multiple_shardings(%arg0: tensor<8x8xf32> {mhlo.sharding = "{devices=
                              %arg1: tensor<8x8xf32> {mhlo.sharding = "{devices=[1,8,4]<=[2,4,4]T(0,2,1) last_tile_dim_replicate}"},
                              %arg2: tensor<8x16xf32> {mhlo.sharding = "{devices=[1,4,8]<=[2,4,4]T(1,0,2) last_tile_dim_replicate}"})
    -> (tensor<8x16xf32> {mhlo.sharding = "{devices=[8,4]<=[32]}"}) {
-  // CHECK-NEXT: mhlo.add
+  // CHECK-NEXT: stablehlo.add
  // CHECK-SAME{LITERAL}: {sdy.sharding = #sdy.sharding_per_value<[<@mesh, [{"axis_1", "axis_0"}, {}]>]>}
-  %0 = mhlo.add %arg0, %arg1 {mhlo.sharding = "{devices=[8,1,4]<=[2,4,4]T(1,0,2) last_tile_dim_replicate}"} : tensor<8x8xf32>
-  %1 = "mhlo.dot" (%0, %arg2) : (tensor<8x8xf32>, tensor<8x16xf32>) -> tensor<8x16xf32>
+  %0 = stablehlo.add %arg0, %arg1 {mhlo.sharding = "{devices=[8,1,4]<=[2,4,4]T(1,0,2) last_tile_dim_replicate}"} : tensor<8x8xf32>
+  %1 = "stablehlo.dot" (%0, %arg2) : (tensor<8x8xf32>, tensor<8x16xf32>) -> tensor<8x16xf32>
   return %1 : tensor<8x16xf32>
 }

@@ -41,7 +41,7 @@ func.func @multiple_shardings(%arg0: tensor<8x8xf32> {mhlo.sharding = "{devices=
 // CHECK-SAME:      -> tensor<32x16xf32> {
 func.func @single_axis(%arg0: tensor<32x8xf32> {mhlo.sharding = "{devices=[16,1]<=[16]}"},
                        %arg1: tensor<8x16xf32>) -> tensor<32x16xf32> {
-  %0 = "mhlo.dot" (%arg0, %arg1) : (tensor<32x8xf32>, tensor<8x16xf32>) -> tensor<32x16xf32>
+  %0 = "stablehlo.dot" (%arg0, %arg1) : (tensor<32x8xf32>, tensor<8x16xf32>) -> tensor<32x16xf32>
   return %0 : tensor<32x16xf32>
 }

@@ -51,16 +51,16 @@ func.func @single_axis(%arg0: tensor<32x8xf32> {mhlo.sharding = "{devices=[16,1]
 // CHECK-LABEL: func @multi_result_op
 func.func @multi_result_op(%arg0: tensor<4x64x8xf32>, %arg1: tensor<4x64x8xf32>) -> (tensor<4x8xf32>, tensor<4x8xf32>) {
-  %0 = mhlo.constant dense<0.000000e+00> : tensor<f32>
-// CHECK: mhlo.reduce
+  %0 = stablehlo.constant dense<0.000000e+00> : tensor<f32>
+// CHECK: stablehlo.reduce
// CHECK-SAME{LITERAL}: {sdy.sharding = #sdy.sharding_per_value<[<@mesh, [{}, {"axis_1"}]>, <@mesh, [{"axis_1"}, {}]>]>}
-  %1:2 = mhlo.reduce(%arg0 init: %0), (%arg1 init: %0) across dimensions = [1]
+  %1:2 = stablehlo.reduce(%arg0 init: %0), (%arg1 init: %0) across dimensions = [1]
    {mhlo.sharding = "{{devices=[1,4,8]<=[8,4]T(1,0) last_tile_dim_replicate}, {devices=[4,1,8]<=[8,4]T(1,0) last_tile_dim_replicate}}"} :
    (tensor<4x64x8xf32>, tensor<4x64x8xf32>, tensor<f32>, tensor<f32>) -> (tensor<4x8xf32>, tensor<4x8xf32>)
    reducer(%arg2: tensor<f32>, %arg4: tensor<f32>) (%arg3: tensor<f32>, %arg5: tensor<f32>) {
-    %2 = mhlo.add %arg2, %arg4 : tensor<f32>
-    %3 = mhlo.add %arg3, %arg5 : tensor<f32>
-    mhlo.return %2, %3 : tensor<f32>, tensor<f32>
+    %2 = stablehlo.add %arg2, %arg4 : tensor<f32>
+    %3 = stablehlo.add %arg3, %arg5 : tensor<f32>
+    stablehlo.return %2, %3 : tensor<f32>, tensor<f32>
   }
   return %1#0, %1#1 : tensor<4x8xf32>, tensor<4x8xf32>
 }

@@ -77,8 +77,8 @@ func.func @multi_result_op(%arg0: tensor<4x64x8xf32>, %arg1: tensor<4x64x8xf32>)
 func.func @fully_replicated(%arg0: tensor<8x8xf32> {mhlo.sharding = "{devices=[4,1,8]<=[8,4]T(1,0) last_tile_dim_replicate}"},
                             %arg1: tensor<8x8xf32> {mhlo.sharding = "{replicated}"},
                             %arg2: tensor<8x16xf32>) -> tensor<8x16xf32> {
-  %0 = mhlo.add %arg0, %arg1 : tensor<8x8xf32>
-  %1 = "mhlo.dot" (%0, %arg2) : (tensor<8x8xf32>, tensor<8x16xf32>) -> tensor<8x16xf32>
+  %0 = stablehlo.add %arg0, %arg1 : tensor<8x8xf32>
+  %1 = "stablehlo.dot" (%0, %arg2) : (tensor<8x8xf32>, tensor<8x16xf32>) -> tensor<8x16xf32>
   return %1 : tensor<8x16xf32>
 }

@@ -92,7 +92,7 @@ func.func @fully_replicated(%arg0: tensor<8x8xf32> {mhlo.sharding = "{devices=[4
 // CHECK-SAME:      -> tensor<6x35xf32> {
 func.func @prime_number(%arg0: tensor<6x35xf32> {mhlo.sharding = "{devices=[6,35]<=[7,10,3]T(2,1,0)}"},
                         %arg1: tensor<6x35xf32> {mhlo.sharding = "{replicated}"}) -> tensor<6x35xf32> {
-  %0 = mhlo.add %arg0, %arg1 : tensor<6x35xf32>
+  %0 = stablehlo.add %arg0, %arg1 : tensor<6x35xf32>
   return %0 : tensor<6x35xf32>
 }

@@ -106,7 +106,7 @@ func.func @prime_number(%arg0: tensor<6x35xf32> {mhlo.sharding = "{devices=[6,35
 // CHECK-SAME:      -> tensor<231x550x42x42xf32> {
 func.func @prime_number_2(%arg0: tensor<231x550x42x42xf32> {mhlo.sharding = "{devices=[33,10,1,7]<=[2,3,5,7,11]T(1,4,2,0,3)}"},
                           %arg1: tensor<231x550x42x42xf32> {mhlo.sharding = "{devices=[7,55,6,1]<=[2,3,5,7,11]T(3,2,4,1,0)}"}) -> tensor<231x550x42x42xf32> {
-  %0 = mhlo.add %arg0, %arg1 : tensor<231x550x42x42xf32>
+  %0 = stablehlo.add %arg0, %arg1 : tensor<231x550x42x42xf32>
   return %0 : tensor<231x550x42x42xf32>
 }

@@ -120,7 +120,7 @@ func.func @prime_number_2(%arg0: tensor<231x550x42x42xf32> {mhlo.sharding = "{de
 // CHECK-SAME:      -> tensor<8x8xf32> {
 func.func @unknown_sharding(%arg0: tensor<8x8xf32> {mhlo.sharding = "{devices=[4,1,8]<=[8,4]T(1,0) last_tile_dim_replicate}"},
                             %arg1: tensor<8x8xf32> {mhlo.sharding = "{unknown}"}) -> tensor<8x8xf32> {
-  %0 = mhlo.add %arg0, %arg1 : tensor<8x8xf32>
+  %0 = stablehlo.add %arg0, %arg1 : tensor<8x8xf32>
   return %0 : tensor<8x8xf32>
 }

@@ -133,7 +133,7 @@ func.func @unknown_sharding(%arg0: tensor<8x8xf32> {mhlo.sharding = "{devices=[4
 // CHECK-SAME:  %arg0: tensor<8x8xf32> {sdy.sharding = #sdy.sharding<@maximal_mesh_0, [{}, {}]>}
 func.func
@one_maximal_mesh(%arg0: tensor<8x8xf32> {mhlo.sharding = "{maximal de // CHECK-SAME: %arg1: tensor<8x8xf32> {sdy.sharding = #sdy.sharding<@maximal_mesh_0, [{}, {}]>}) func.func @two_maximal_shardings_should_be_sorted(%arg0: tensor<8x8xf32> {mhlo.sharding = "{maximal device=4}"}, %arg1: tensor<8x8xf32> {mhlo.sharding = "{maximal device=0}"}) -> tensor<8x8xf32> { - %0 = mhlo.add %arg0, %arg1 : tensor<8x8xf32> + %0 = stablehlo.add %arg0, %arg1 : tensor<8x8xf32> return %0 : tensor<8x8xf32> } @@ -159,7 +159,7 @@ func.func @two_maximal_shardings_should_be_sorted(%arg0: tensor<8x8xf32> {mhlo.s // CHECK-SAME: %arg1: tensor<8x8xf32> {sdy.sharding = #sdy.sharding<@maximal_mesh_0, [{}, {}]>}) func.func @duplicate_maximal_sharding_should_be_deduped(%arg0: tensor<8x8xf32> {mhlo.sharding = "{maximal device=0}"}, %arg1: tensor<8x8xf32> {mhlo.sharding = "{maximal device=0}"}) -> tensor<8x8xf32> { - %0 = mhlo.add %arg0, %arg1 : tensor<8x8xf32> + %0 = stablehlo.add %arg0, %arg1 : tensor<8x8xf32> return %0 : tensor<8x8xf32> } @@ -174,8 +174,8 @@ func.func @duplicate_maximal_sharding_should_be_deduped(%arg0: tensor<8x8xf32> { func.func @two_meshes(%arg0: tensor<8x8xf32> {mhlo.sharding = "{devices=[4,1,8]<=[8,4]T(1,0) last_tile_dim_replicate}"}, %arg1: tensor<8x8xf32> {mhlo.sharding = "{maximal device=0}"}, %arg2: tensor<8x16xf32>) -> tensor<8x16xf32> { - %0 = mhlo.add %arg0, %arg1 : tensor<8x8xf32> - %1 = "mhlo.dot" (%0, %arg2) : (tensor<8x8xf32>, tensor<8x16xf32>) -> tensor<8x16xf32> + %0 = stablehlo.add %arg0, %arg1 : tensor<8x8xf32> + %1 = "stablehlo.dot" (%0, %arg2) : (tensor<8x8xf32>, tensor<8x16xf32>) -> tensor<8x16xf32> return %1 : tensor<8x16xf32> } @@ -189,11 +189,11 @@ func.func @two_meshes(%arg0: tensor<8x8xf32> {mhlo.sharding = "{devices=[4,1,8]< // CHECK-SAME: -> tensor<8x8xf32> { func.func @maximal_sharding_on_op(%arg0: tensor<8x8xf32> {mhlo.sharding = "{devices=[4,1,8]<=[8,4]T(1,0) last_tile_dim_replicate}"}, %arg1: tensor<8x8xf32>) -> tensor<8x8xf32> { -// CHECK-NEXT: %[[ADD:.*]] = mhlo.add %arg0, %arg1 +// CHECK-NEXT: %[[ADD:.*]] = stablehlo.add %arg0, %arg1 // CHECK-SAME{LITERAL}: {sdy.sharding = #sdy.sharding_per_value<[<@maximal_mesh_4, [{}, {}]>]>} -// CHECK-NEXT: %[[MULTIPLY:.*]] = mhlo.multiply %[[ADD]], %[[ADD]] +// CHECK-NEXT: %[[MULTIPLY:.*]] = stablehlo.multiply %[[ADD]], %[[ADD]] // CHECK-SAME{LITERAL}: {sdy.sharding = #sdy.sharding_per_value<[<@maximal_mesh_0, [{}, {}]>]>} - %0 = mhlo.add %arg0, %arg1 {mhlo.sharding = "{maximal device=4}"} : tensor<8x8xf32> - %1 = mhlo.multiply %0, %0 {mhlo.sharding = "{maximal device=0}"} : tensor<8x8xf32> + %0 = stablehlo.add %arg0, %arg1 {mhlo.sharding = "{maximal device=4}"} : tensor<8x8xf32> + %1 = stablehlo.multiply %0, %0 {mhlo.sharding = "{maximal device=0}"} : tensor<8x8xf32> return %1 : tensor<8x8xf32> } diff --git a/xla/service/spmd/shardy/test/mhlo_export_pipeline.mlir b/xla/service/spmd/shardy/test/mhlo_export_pipeline.mlir index 1a5f443f4ec472..2a680a3674b0b3 100644 --- a/xla/service/spmd/shardy/test/mhlo_export_pipeline.mlir +++ b/xla/service/spmd/shardy/test/mhlo_export_pipeline.mlir @@ -21,8 +21,8 @@ sdy.mesh @empty_mesh_1 = <[]> func.func @non_trivial_common_mesh(%arg0: tensor<8x8xf32> {sdy.sharding = #sdy.sharding<@mesh_0, [{"axis_2"}, {"axis_0", "axis_1"}]>}, %arg1: tensor<8x8xf32> {sdy.sharding = #sdy.sharding<@mesh_0, [{}, {"axis_0"}]>}, %arg2: tensor<8x16xf32> {sdy.sharding = #sdy.sharding<@mesh_0, [{"axis_1"}, {"axis_2"}]>}) -> tensor<8x16xf32> { - %0 = mhlo.add %arg0, %arg1 : tensor<8x8xf32> - %1 = "mhlo.dot" (%0, %arg2) 
: (tensor<8x8xf32>, tensor<8x16xf32>) -> tensor<8x16xf32>
+  %0 = stablehlo.add %arg0, %arg1 : tensor<8x8xf32>
+  %1 = stablehlo.dot %0, %arg2 : (tensor<8x8xf32>, tensor<8x16xf32>) -> tensor<8x16xf32>
   return %1 : tensor<8x16xf32>
 }

@@ -37,8 +37,8 @@ func.func @multiple_shardings(%arg0: tensor<8x8xf32> {sdy.sharding = #sdy.shardi
    -> (tensor<8x16xf32> {sdy.sharding = #sdy.sharding<@mesh_0, [{"axis_0", "axis_1"}, {"axis_2"}]>}) {
  // CHECK-NEXT: mhlo.add
  // CHECK-SAME{LITERAL}: {mhlo.sharding = "{devices=[8,1,4]<=[2,4,4]T(1,0,2) last_tile_dim_replicate}"}
-  %0 = mhlo.add %arg0, %arg1 {sdy.sharding = #sdy.sharding_per_value<[<@mesh_0, [{"axis_1","axis_0"}, {}]>]>} : tensor<8x8xf32>
-  %1 = "mhlo.dot" (%0, %arg2) : (tensor<8x8xf32>, tensor<8x16xf32>) -> tensor<8x16xf32>
+  %0 = stablehlo.add %arg0, %arg1 {sdy.sharding = #sdy.sharding_per_value<[<@mesh_0, [{"axis_1","axis_0"}, {}]>]>} : tensor<8x8xf32>
+  %1 = stablehlo.dot %0, %arg2 : (tensor<8x8xf32>, tensor<8x16xf32>) -> tensor<8x16xf32>
   return %1 : tensor<8x16xf32>
 }

@@ -48,22 +48,22 @@ func.func @multiple_shardings(%arg0: tensor<8x8xf32> {sdy.sharding = #sdy.shardi
 // CHECK-SAME:      -> tensor<32x16xf32> {
 func.func @single_axis(%arg0: tensor<32x8xf32> {sdy.sharding = #sdy.sharding<@mesh_1, [{"axis_0"}, {}]>},
                        %arg1: tensor<8x16xf32>) -> tensor<32x16xf32> {
-  %0 = "mhlo.dot" (%arg0, %arg1) : (tensor<32x8xf32>, tensor<8x16xf32>) -> tensor<32x16xf32>
+  %0 = stablehlo.dot %arg0, %arg1 : (tensor<32x8xf32>, tensor<8x16xf32>) -> tensor<32x16xf32>
   return %0 : tensor<32x16xf32>
 }

 // CHECK-LABEL: func @multi_result_op
 func.func @multi_result_op(%arg0: tensor<4x64x8xf32>, %arg1: tensor<4x64x8xf32>) -> (tensor<4x8xf32>, tensor<4x8xf32>) {
-  %0 = mhlo.constant dense<0.000000e+00> : tensor<f32>
+  %0 = stablehlo.constant dense<0.000000e+00> : tensor<f32>
// CHECK: mhlo.reduce
// CHECK-SAME{LITERAL}: {mhlo.sharding = "{{devices=[1,4,8]<=[8,4]T(1,0) last_tile_dim_replicate}, {devices=[4,1,8]<=[8,4]T(1,0) last_tile_dim_replicate}}"}
-  %1:2 = mhlo.reduce(%arg0 init: %0), (%arg1 init: %0) across dimensions = [1]
+  %1:2 = stablehlo.reduce(%arg0 init: %0), (%arg1 init: %0) across dimensions = [1]
    {sdy.sharding = #sdy.sharding_per_value<[<@mesh_2, [{}, {"y"}]>, <@mesh_2, [{"y"}, {}]>]>} :
    (tensor<4x64x8xf32>, tensor<4x64x8xf32>, tensor<f32>, tensor<f32>) -> (tensor<4x8xf32>, tensor<4x8xf32>)
    reducer(%arg2: tensor<f32>, %arg4: tensor<f32>) (%arg3: tensor<f32>, %arg5: tensor<f32>) {
-    %2 = mhlo.add %arg2, %arg4 : tensor<f32>
-    %3 = mhlo.add %arg3, %arg5 : tensor<f32>
-    mhlo.return %2, %3 : tensor<f32>, tensor<f32>
+    %2 = stablehlo.add %arg2, %arg4 : tensor<f32>
+    %3 = stablehlo.add %arg3, %arg5 : tensor<f32>
+    stablehlo.return %2, %3 : tensor<f32>, tensor<f32>
   }
   return %1#0, %1#1 : tensor<4x8xf32>, tensor<4x8xf32>
 }

@@ -76,8 +76,8 @@ func.func @multi_result_op(%arg0: tensor<4x64x8xf32>, %arg1: tensor<4x64x8xf32>)
 func.func @fully_replicated(%arg0: tensor<8x8xf32> {sdy.sharding = #sdy.sharding<@mesh_2, [{"y"}, {}]>},
                             %arg1: tensor<8x8xf32> {sdy.sharding = #sdy.sharding<@mesh_2, [{}, {}]>},
                             %arg2: tensor<8x16xf32>) -> tensor<8x16xf32> {
-  %0 = mhlo.add %arg0, %arg1 : tensor<8x8xf32>
-  %1 = "mhlo.dot" (%0, %arg2) : (tensor<8x8xf32>, tensor<8x16xf32>) -> tensor<8x16xf32>
+  %0 = stablehlo.add %arg0, %arg1 : tensor<8x8xf32>
+  %1 = stablehlo.dot %0, %arg2 : (tensor<8x8xf32>, tensor<8x16xf32>) -> tensor<8x16xf32>
   return %1 : tensor<8x16xf32>
 }

@@ -87,9 +87,9 @@ func.func @fully_replicated(%arg0: tensor<8x8xf32> {sdy.sharding = #sdy.sharding
 // CHECK-SAME:      -> tensor<8x16xf32> {
 func.func @split_axes(%arg0: tensor<8x8xf32> {sdy.sharding = #sdy.sharding<@mesh_2,
[{"y"}, {"x":(2)2}]>}, %arg1: tensor<8x16xf32> {sdy.sharding = #sdy.sharding<@mesh_2, [{"x":(1)2}, {"x":(2)4}]>}) -> tensor<8x16xf32> { -// CHECK-NEXT: "mhlo.dot" +// CHECK-NEXT: mhlo.dot // CHECK-SAME{LITERAL}: {mhlo.sharding = "{devices=[4,1,8]<=[2,2,2,4]T(0,2,1,3) last_tile_dim_replicate}"} - %1 = "mhlo.dot" (%arg0, %arg1) {sdy.sharding = #sdy.sharding_per_value<[<@mesh_2, [{"x":(1)2, "x":(4)2}, {}]>]>} : (tensor<8x8xf32>, tensor<8x16xf32>) -> tensor<8x16xf32> + %1 = stablehlo.dot %arg0, %arg1 {sdy.sharding = #sdy.sharding_per_value<[<@mesh_2, [{"x":(1)2, "x":(4)2}, {}]>]>} : (tensor<8x8xf32>, tensor<8x16xf32>) -> tensor<8x16xf32> return %1 : tensor<8x16xf32> } @@ -129,22 +129,22 @@ func.func @reshard_fully_open_partially_open(%arg0: tensor<8x8xf32>) -> tensor<8 // CHECK-SAME: %arg1: tensor<16x32xf32> {mhlo.sharding = "{devices=[2,1,8]<=[2,2,4]T(1,0,2) last_tile_dim_replicate}"}) // CHECK-SAME: -> (tensor<8x32xf32> {mhlo.sharding = "{devices=[2,1,8]<=[16] last_tile_dim_replicate}"}) { func.func @sharding_in_manual_computation_body(%arg0: tensor<8x16xf32> {sdy.sharding = #sdy.sharding<@mesh_3, [{"a", ?}, {"b", ?}]>}, %arg1: tensor<16x32xf32> {sdy.sharding = #sdy.sharding<@mesh_3, [{"b", ?}, {?}]>}) -> (tensor<8x32xf32> {sdy.sharding = #sdy.sharding<@mesh_3, [{"a"}, {}]>}) { -// CHECK-NEXT: %0 = mhlo.copy %arg0 {mhlo.sharding = "{devices=[2,2,4]<=[2,2,4]T(1,0,2) last_tile_dim_replicate}"} : tensor<8x16xf32> -// CHECK-NEXT: %1 = mhlo.custom_call @SPMDFullToShardShape(%0) {mhlo.sharding = "{devices=[1,1,4,4]<=[16] last_tile_dims={manual, replicated}}"} : (tensor<8x16xf32>) -> tensor<4x8xf32> -// CHECK-NEXT: %2 = mhlo.copy %arg1 {mhlo.sharding = "{devices=[2,1,8]<=[2,2,4]T(1,0,2) last_tile_dim_replicate}"} : tensor<16x32xf32> -// CHECK-NEXT: %3 = mhlo.custom_call @SPMDFullToShardShape(%2) {mhlo.sharding = "{devices=[1,1,4,4]<=[16] last_tile_dims={manual, replicated}}"} : (tensor<16x32xf32>) -> tensor<8x32xf32> -// CHECK-NEXT: %4 = mhlo.copy %1 {mhlo.sharding = "{devices=[1,2,4,2]<=[8,2]T(1,0) last_tile_dims={manual, replicated}}"} : tensor<4x8xf32> -// CHECK-NEXT: %5 = mhlo.add %4, %4 {mhlo.sharding = "{devices=[2,1,4,2]<=[4,2,2]T(1,0,2) last_tile_dims={manual, replicated}}"} : tensor<4x8xf32> -// CHECK-NEXT: %6 = "mhlo.dot"(%5, %3) {mhlo.sharding = "{devices=[2,2,4]<=[4,4]T(1,0) last_tile_dims={manual}}"} : (tensor<4x8xf32>, tensor<8x32xf32>) -> tensor<4x32xf32> -// CHECK-NEXT: %7 = mhlo.sine %6 {mhlo.sharding = "{devices=[1,1,4,4]<=[16] last_tile_dims={manual, replicated}}"} : tensor<4x32xf32> -// CHECK-NEXT: %8 = mhlo.copy %7 {mhlo.sharding = "{devices=[1,1,4,4]<=[16] last_tile_dims={manual, replicated}}"} : tensor<4x32xf32> -// CHECK-NEXT: %9 = mhlo.custom_call @SPMDShardToFullShape(%8) {mhlo.sharding = "{devices=[2,1,8]<=[16] last_tile_dim_replicate}"} : (tensor<4x32xf32>) -> tensor<8x32xf32> -// CHECK-NEXT: return %9 : tensor<8x32xf32> +// CHECK-NEXT: %[[COPY_0:.*]] = mhlo.copy %arg0 {mhlo.sharding = "{devices=[2,2,4]<=[2,2,4]T(1,0,2) last_tile_dim_replicate}"} : tensor<8x16xf32> +// CHECK-NEXT: %[[FULL_TO_SHARD_0:.*]] = mhlo.custom_call @SPMDFullToShardShape(%[[COPY_0]]) {mhlo.sharding = "{devices=[1,1,4,4]<=[16] last_tile_dims={manual, replicated}}"} : (tensor<8x16xf32>) -> tensor<4x8xf32> +// CHECK-NEXT: %[[COPY_1:.*]] = mhlo.copy %arg1 {mhlo.sharding = "{devices=[2,1,8]<=[2,2,4]T(1,0,2) last_tile_dim_replicate}"} : tensor<16x32xf32> +// CHECK-NEXT: %[[FULL_TO_SHARD_1:.*]] = mhlo.custom_call @SPMDFullToShardShape(%[[COPY_1]]) {mhlo.sharding = "{devices=[1,1,4,4]<=[16] 
last_tile_dims={manual, replicated}}"} : (tensor<16x32xf32>) -> tensor<8x32xf32> +// CHECK-NEXT: %[[RESHARD:.*]] = mhlo.copy %[[FULL_TO_SHARD_0]] {mhlo.sharding = "{devices=[1,2,4,2]<=[8,2]T(1,0) last_tile_dims={manual, replicated}}"} : tensor<4x8xf32> +// CHECK-NEXT: %[[ADD:.*]] = mhlo.add %[[RESHARD]], %[[RESHARD]] {mhlo.sharding = "{devices=[2,1,4,2]<=[4,2,2]T(1,0,2) last_tile_dims={manual, replicated}}"} : tensor<4x8xf32> +// CHECK-NEXT: %[[DOT:.*]] = "mhlo.dot"(%[[ADD]], %[[FULL_TO_SHARD_1]]) {mhlo.sharding = "{devices=[2,2,4]<=[4,4]T(1,0) last_tile_dims={manual}}"} : (tensor<4x8xf32>, tensor<8x32xf32>) -> tensor<4x32xf32> +// CHECK-NEXT: %[[SINE:.*]] = mhlo.sine %[[DOT]] {mhlo.sharding = "{devices=[1,1,4,4]<=[16] last_tile_dims={manual, replicated}}"} : tensor<4x32xf32> +// CHECK-NEXT: %[[COPY_2:.*]] = mhlo.copy %[[SINE]] {mhlo.sharding = "{devices=[1,1,4,4]<=[16] last_tile_dims={manual, replicated}}"} : tensor<4x32xf32> +// CHECK-NEXT: %[[SHARD_TO_FULL:.*]] = mhlo.custom_call @SPMDShardToFullShape(%[[COPY_2]]) {mhlo.sharding = "{devices=[2,1,8]<=[16] last_tile_dim_replicate}"} : (tensor<4x32xf32>) -> tensor<8x32xf32> +// CHECK-NEXT: return %[[SHARD_TO_FULL]] : tensor<8x32xf32> %0 = sdy.manual_computation(%arg0, %arg1) in_shardings=[<@mesh_3, [{"b"}, {"a"}]>, <@mesh_3, [{"b"}, {}], replicated={"a"}>] out_shardings=[<@mesh_3, [{"a"}, {}], replicated={"b"}>] manual_axes={"a", "b"} (%arg2: tensor<4x8xf32>, %arg3: tensor<8x32xf32>) { %1 = sdy.reshard %arg2 <@mesh_3, [{}, {"d"}]> : tensor<4x8xf32> - %2 = mhlo.add %1, %1 {sdy.sharding = #sdy.sharding_per_value<[<@mesh_3, [{"c"}, {}]>]>} : tensor<4x8xf32> - %3 = "mhlo.dot"(%2, %arg3) {sdy.sharding = #sdy.sharding_per_value<[<@mesh_3, [{"c"}, {"d"}]>]>} : (tensor<4x8xf32>, tensor<8x32xf32>) -> tensor<4x32xf32> - %4 = mhlo.sine %3 : tensor<4x32xf32> + %2 = stablehlo.add %1, %1 {sdy.sharding = #sdy.sharding_per_value<[<@mesh_3, [{"c"}, {}]>]>} : tensor<4x8xf32> + %3 = stablehlo.dot %2, %arg3 {sdy.sharding = #sdy.sharding_per_value<[<@mesh_3, [{"c"}, {"d"}]>]>} : (tensor<4x8xf32>, tensor<8x32xf32>) -> tensor<4x32xf32> + %4 = stablehlo.sine %3 : tensor<4x32xf32> sdy.return %4 : tensor<4x32xf32> } : (tensor<8x16xf32>, tensor<16x32xf32>) -> tensor<8x32xf32> return %0 : tensor<8x32xf32> @@ -153,18 +153,18 @@ func.func @sharding_in_manual_computation_body(%arg0: tensor<8x16xf32> {sdy.shar // CHECK-LABEL: func @mesh_with_device_id_should_be_converted_to_maximal_sharding(%arg0: tensor<8x8xf32> {mhlo.sharding = "{maximal device=0}"}, %arg1: tensor<8x8xf32>) func.func @mesh_with_device_id_should_be_converted_to_maximal_sharding(%arg0: tensor<8x8xf32> {sdy.sharding = #sdy.sharding<@maximal_mesh_0, [{}, {}]>}, %arg1: tensor<8x8xf32>) -> tensor<8x8xf32> { // CHECK: %[[ADD:.*]] = mhlo.add %arg0, %arg1 - %0 = mhlo.add %arg0, %arg1 : tensor<8x8xf32> + %0 = stablehlo.add %arg0, %arg1 : tensor<8x8xf32> // CHECK: %[[ADD_WITH_SHARDING:.*]] = mhlo.add %[[ADD]], %[[ADD]] {mhlo.sharding = "{maximal device=1}"} - %1 = mhlo.add %0, %0 {sdy.sharding = #sdy.sharding_per_value<[<@maximal_mesh_1, [{}, {}]>]>} : tensor<8x8xf32> + %1 = stablehlo.add %0, %0 {sdy.sharding = #sdy.sharding_per_value<[<@maximal_mesh_1, [{}, {}]>]>} : tensor<8x8xf32> return %1 : tensor<8x8xf32> } // CHECK-LABEL: func @mesh_empty_should_be_converted_to_replicated_sharding(%arg0: tensor<8x8xf32> {mhlo.sharding = "{replicated}"}, %arg1: tensor<8x8xf32>) func.func @mesh_empty_should_be_converted_to_replicated_sharding(%arg0: tensor<8x8xf32> {sdy.sharding = #sdy.sharding<@empty_mesh_0, [{}, {}]>}, 
%arg1: tensor<8x8xf32>) -> tensor<8x8xf32> { // CHECK: %[[ADD:.*]] = mhlo.add %arg0, %arg1 - %0 = mhlo.add %arg0, %arg1 : tensor<8x8xf32> + %0 = stablehlo.add %arg0, %arg1 : tensor<8x8xf32> // CHECK: %[[ADD_WITH_SHARDING:.*]] = mhlo.add %[[ADD]], %[[ADD]] {mhlo.sharding = "{replicated}"} - %1 = mhlo.add %0, %0 {sdy.sharding = #sdy.sharding_per_value<[<@empty_mesh_1, [{}, {}]>]>} : tensor<8x8xf32> + %1 = stablehlo.add %0, %0 {sdy.sharding = #sdy.sharding_per_value<[<@empty_mesh_1, [{}, {}]>]>} : tensor<8x8xf32> return %1 : tensor<8x8xf32> } @@ -178,8 +178,8 @@ func.func @multiple_shardings_with_device_list(%arg0: tensor<8x8xf32> {sdy.shard %arg2: tensor<8x16xf32> {sdy.sharding = #sdy.sharding<@mesh_4, [{}, {"axis_1"}]>}) -> tensor<8x16xf32> { // CHECK-NEXT: mhlo.add // CHECK-SAME{LITERAL}: {mhlo.sharding = "{devices=[4,1,2]0,2,1,3,4,6,5,7 last_tile_dim_replicate}"} - %0 = mhlo.add %arg0, %arg1 {sdy.sharding = #sdy.sharding_per_value<[<@mesh_4, [{"axis_1","axis_0"}, {}]>]>} : tensor<8x8xf32> - %1 = "mhlo.dot" (%0, %arg2) : (tensor<8x8xf32>, tensor<8x16xf32>) -> tensor<8x16xf32> + %0 = stablehlo.add %arg0, %arg1 {sdy.sharding = #sdy.sharding_per_value<[<@mesh_4, [{"axis_1","axis_0"}, {}]>]>} : tensor<8x8xf32> + %1 = stablehlo.dot %0, %arg2 : (tensor<8x8xf32>, tensor<8x16xf32>) -> tensor<8x16xf32> return %1 : tensor<8x16xf32> } @@ -220,7 +220,7 @@ func.func @free_axis_inside_in_out_shardings_manual_computation( in_shardings=[<@mesh_5, [{"i", ?}, {?}], replicated={"j"}>] out_shardings=[<@mesh_5, [{"i", ?}, {?}], replicated={"j"}>] manual_axes={"j"} (%arg1: tensor<4x8xf32>) { - %1 = mhlo.multiply %arg1, %arg1 {sdy.sharding = #sdy.sharding_per_value<[<@mesh_5, [{"i"}, {}]>]>} : tensor<4x8xf32> + %1 = stablehlo.multiply %arg1, %arg1 {sdy.sharding = #sdy.sharding_per_value<[<@mesh_5, [{"i"}, {}]>]>} : tensor<4x8xf32> %2 = sdy.reshard %1 <@mesh_5, [{"i"}, {}]> : tensor<4x8xf32> sdy.return %2 : tensor<4x8xf32> } : (tensor<4x8xf32>) -> tensor<4x8xf32> @@ -230,5 +230,5 @@ func.func @free_axis_inside_in_out_shardings_manual_computation( // CHECK-LABEL: func private @foo // CHECK-SAME: %arg0: tensor<4x2xi32> {mhlo.sharding = "{devices=[4,1,8]<=[8,4]T(1,0) last_tile_dim_replicate}"} // CHECK-SAME: -> (tensor<4x2xi32> {mhlo.sharding = "{devices=[4,1,8]<=[8,4]T(1,0) last_tile_dim_replicate}"}) { -// CHECK-NEXT: %[[MULT:.*]] = stablehlo.multiply %arg0, %arg0 {mhlo.sharding = "{devices=[4,1,8]<=[8,4]T(1,0) last_tile_dims={manual}}"} : tensor<4x2xi32> +// CHECK-NEXT: %[[MULT:.*]] = mhlo.multiply %arg0, %arg0 {mhlo.sharding = "{devices=[4,1,8]<=[8,4]T(1,0) last_tile_dims={manual}}"} : tensor<4x2xi32> // CHECK-NEXT: return %[[MULT]] : tensor<4x2xi32> diff --git a/xla/service/spmd/shardy/test/mhlo_import_pipeline.mlir b/xla/service/spmd/shardy/test/mhlo_import_pipeline.mlir index 55ccddd9645d5e..31c66ebb7e5a33 100644 --- a/xla/service/spmd/shardy/test/mhlo_import_pipeline.mlir +++ b/xla/service/spmd/shardy/test/mhlo_import_pipeline.mlir @@ -32,7 +32,7 @@ func.func @manual(%arg0: tensor<8x8xf32> {mhlo.sharding = "{replicated}"}, // CHECK-SAME: in_shardings=[<@mesh, [{"axis_0", "axis_1"}, {}]>, <@mesh, [{"axis_0"}, {}]>] // CHECK-SAME: out_shardings=[<@mesh, [{"axis_0", "axis_1"}, {}]>] // CHECK-SAME: manual_axes={"axis_0", "axis_1"} (%arg2: tensor<1x8xf32>, %arg3: tensor<1x8xf32>) { - // CHECK-LABEL: mhlo.add + // CHECK-LABEL: stablehlo.add // CHECK-LABEL: sdy.return %0 = mhlo.custom_call @Sharding(%arg0) {mhlo.sharding = "{devices=[8,1]<=[8]}"} : (tensor<8x8xf32>) -> tensor<8x8xf32> %1 = mhlo.custom_call 
@SPMDFullToShardShape(%0) {mhlo.sharding = "{manual}"} : (tensor<8x8xf32>) -> tensor<1x8xf32> @@ -63,14 +63,14 @@ func.func @while_with_free_variables( // CHECK-NEXT: %[[C1:.*]] = sdy.constant dense<1> // CHECK-NEXT: %[[C32:.*]] = sdy.constant {sdy.sharding = #sdy.sharding_per_value<[<@mesh, []>]>} dense<32> // CHECK-NEXT: %[[SC:.*]] = sdy.sharding_constraint %arg1 <@mesh, [{?}, {?}]> - // CHECK-NEXT: %[[WHILE:.*]]:2 = mhlo.while(%iterArg = %arg0, %iterArg_0 = %[[C0]]) + // CHECK-NEXT: %[[WHILE:.*]]:2 = stablehlo.while(%iterArg = %arg0, %iterArg_0 = %[[C0]]) // CHECK-NEXT: cond { - // CHECK-NEXT: %[[COND:.*]] = mhlo.compare LT, %iterArg_0, %[[C32]] - // CHECK-NEXT: mhlo.return %[[COND]] + // CHECK-NEXT: %[[COND:.*]] = stablehlo.compare LT, %iterArg_0, %[[C32]] + // CHECK-NEXT: stablehlo.return %[[COND]] // CHECK-NEXT: } do { - // CHECK-NEXT: %[[ADD_0:.*]] = mhlo.add %iterArg_0, %[[C1]] - // CHECK-NEXT: %[[ADD_1:.*]] = mhlo.add %iterArg, %[[SC]] - // CHECK-NEXT: mhlo.return %[[ADD_1]], %[[ADD_0]] + // CHECK-NEXT: %[[ADD_0:.*]] = stablehlo.add %iterArg_0, %[[C1]] + // CHECK-NEXT: %[[ADD_1:.*]] = stablehlo.add %iterArg, %[[SC]] + // CHECK-NEXT: stablehlo.return %[[ADD_1]], %[[ADD_0]] // CHECK-NEXT: } // CHECK-NEXT: return %[[WHILE]]#0 %0 = mhlo.constant dense<0> : tensor @@ -93,16 +93,16 @@ func.func @while_with_free_variables( // CHECK-LABEL: func @while_with_sinked_constants func.func @while_with_sinked_constants(%arg0: tensor<32x96xf32>) -> tensor<32x96xf32> { // CHECK-NEXT: %[[C0:.*]] = sdy.constant dense<0> - // CHECK-NEXT: %[[WHILE:.*]]:2 = mhlo.while(%iterArg = %arg0, %iterArg_0 = %[[C0]]) + // CHECK-NEXT: %[[WHILE:.*]]:2 = stablehlo.while(%iterArg = %arg0, %iterArg_0 = %[[C0]]) // CHECK-NEXT: cond { // CHECK-NEXT: %[[C32:.*]] = sdy.constant dense<32> - // CHECK-NEXT: %[[COND:.*]] = mhlo.compare LT, %iterArg_0, %[[C32]] - // CHECK-NEXT: mhlo.return %[[COND]] + // CHECK-NEXT: %[[COND:.*]] = stablehlo.compare LT, %iterArg_0, %[[C32]] + // CHECK-NEXT: stablehlo.return %[[COND]] // CHECK-NEXT: } do { // CHECK-NEXT: %[[C1:.*]] = sdy.constant dense<1> - // CHECK-NEXT: %[[ADD_0:.*]] = mhlo.add %iterArg_0, %[[C1]] - // CHECK-NEXT: %[[ADD_1:.*]] = mhlo.add %iterArg, %iterArg - // CHECK-NEXT: mhlo.return %[[ADD_1]], %[[ADD_0]] + // CHECK-NEXT: %[[ADD_0:.*]] = stablehlo.add %iterArg_0, %[[C1]] + // CHECK-NEXT: %[[ADD_1:.*]] = stablehlo.add %iterArg, %iterArg + // CHECK-NEXT: stablehlo.return %[[ADD_1]], %[[ADD_0]] // CHECK-NEXT: } // CHECK-NEXT: return %[[WHILE]]#0 %0 = mhlo.constant dense<0> : tensor @@ -124,7 +124,7 @@ func.func @while_with_sinked_constants(%arg0: tensor<32x96xf32>) -> tensor<32x96 // CHECK-LABEL: func @custom_call_with_tuple_operand_result func.func @custom_call_with_tuple_operand_result(%arg0: tensor<8x8xf32>, %arg1: tensor<4x8xf32>, %arg2: tensor<8x16xf32>) -> tensor<8x8xf32> { - // CHECK-NEXT: %[[FOO:.*]]:3 = mhlo.custom_call @foo(%arg0, %arg1, %arg2) : + // CHECK-NEXT: %[[FOO:.*]]:3 = stablehlo.custom_call @foo(%arg0, %arg1, %arg2) : // CHECK-SAME: (tensor<8x8xf32>, tensor<4x8xf32>, tensor<8x16xf32>) // CHECK-SAME: -> (tensor<8x8xf32>, tensor<4x8xf32>, tensor<8x16xf32>) // CHECK-NEXT: return %[[FOO]]#0 @@ -133,3 +133,13 @@ func.func @custom_call_with_tuple_operand_result(%arg0: tensor<8x8xf32>, %arg1: %2 = mhlo.get_tuple_element %1[0] : (!tuple) -> tensor<8x8xf32> return %2 : tensor<8x8xf32> } + +// ----- + +// CHECK-LABEL: func @import_sharding_group_with_unused_result +// CHECK-SAME: %arg0: tensor<8x8xf32>) -> tensor<8x8xf32> { +func.func 
@import_sharding_group_with_unused_result(%arg0: tensor<8x8xf32>) -> tensor<8x8xf32> { + // CHECK sdy.sharding_group %arg0 group_id = 21: tensor<8x8xf32> + %0 = mhlo.custom_call @xla.sdy.ShardingGroup(%arg0) {has_side_effect = true, mhlo.frontend_attributes = {xla.sdy.sharding_group_id = "21 : i64"}} : (tensor<8x8xf32>) -> tuple<> + return %arg0 : tensor<8x8xf32> +} diff --git a/xla/service/spmd/shardy/test/mhlo_round_trip_shard_map_export.mlir b/xla/service/spmd/shardy/test/mhlo_round_trip_shard_map_export.mlir index 859d067123e635..9e094e6eb7e344 100644 --- a/xla/service/spmd/shardy/test/mhlo_round_trip_shard_map_export.mlir +++ b/xla/service/spmd/shardy/test/mhlo_round_trip_shard_map_export.mlir @@ -6,22 +6,22 @@ sdy.mesh @mesh_1 = <["a"=2, "b"=2, "c"=2, "d"=2]> // CHECK-LABEL: func @single_manual_comp func.func @single_manual_comp(%arg0: tensor<8x16xf32> {sdy.sharding = #sdy.sharding<@mesh_0, [{"a", ?}, {"b", ?}]>}, %arg1: tensor<16x32xf32> {sdy.sharding = #sdy.sharding<@mesh_0, [{"b", ?}, {?}]>}) -> (tensor<8x32xf32> {sdy.sharding = #sdy.sharding<@mesh_0, [{"a"}, {}]>}) { // CHECK-NEXT: %0 = mhlo.copy %arg0 {mhlo.sharding = "{devices=[4,2]<=[8]}"} : tensor<8x16xf32> - // CHECK-NEXT: %1 = mhlo.custom_call @SPMDFullToShardShape(%0) {mhlo.sharding = "{manual}"} : (tensor<8x16xf32>) -> tensor<2x8xf32> + // CHECK-NEXT: %1 = stablehlo.custom_call @SPMDFullToShardShape(%0) {mhlo.sharding = "{manual}"} : (tensor<8x16xf32>) -> tensor<2x8xf32> // CHECK-NEXT: %2 = mhlo.copy %arg1 {mhlo.sharding = "{devices=[2,1,4]<=[4,2]T(1,0) last_tile_dim_replicate}"} : tensor<16x32xf32> - // CHECK-NEXT: %3 = mhlo.custom_call @SPMDFullToShardShape(%2) {mhlo.sharding = "{manual}"} : (tensor<16x32xf32>) -> tensor<8x32xf32> - // CHECK-NEXT: %4 = mhlo.add %1, %1 {mhlo.sharding = "{manual}"} : tensor<2x8xf32> - // CHECK-NEXT: %5 = "mhlo.dot"(%4, %3) {mhlo.sharding = "{manual}"} : (tensor<2x8xf32>, tensor<8x32xf32>) -> tensor<2x32xf32> - // CHECK-NEXT: %6 = "mhlo.all_reduce"(%5) + // CHECK-NEXT: %3 = stablehlo.custom_call @SPMDFullToShardShape(%2) {mhlo.sharding = "{manual}"} : (tensor<16x32xf32>) -> tensor<8x32xf32> + // CHECK-NEXT: %4 = stablehlo.add %1, %1 {mhlo.sharding = "{manual}"} : tensor<2x8xf32> + // CHECK-NEXT: %5 = stablehlo.dot %4, %3 {mhlo.sharding = "{manual}"} : (tensor<2x8xf32>, tensor<8x32xf32>) -> tensor<2x32xf32> + // CHECK-NEXT: %6 = "stablehlo.all_reduce"(%5) // CHECK: %7 = mhlo.copy %6 {mhlo.sharding = "{manual}"} : tensor<2x32xf32> - // CHECK-NEXT: %8 = mhlo.custom_call @SPMDShardToFullShape(%7) {mhlo.sharding = "{devices=[4,1,2]<=[8] last_tile_dim_replicate}"} : (tensor<2x32xf32>) -> tensor<8x32xf32> + // CHECK-NEXT: %8 = stablehlo.custom_call @SPMDShardToFullShape(%7) {mhlo.sharding = "{devices=[4,1,2]<=[8] last_tile_dim_replicate}"} : (tensor<2x32xf32>) -> tensor<8x32xf32> // CHECK-NEXT: return %8 : tensor<8x32xf32> %0 = sdy.manual_computation(%arg0, %arg1) in_shardings=[<@mesh_0, [{"a"}, {"b"}]>, <@mesh_0, [{"b"}, {}], replicated={"a"}>] out_shardings=[<@mesh_0, [{"a"}, {}], replicated={"b"}>] manual_axes={"a", "b"} (%arg2: tensor<2x8xf32>, %arg3: tensor<8x32xf32>) { - %1 = mhlo.add %arg2, %arg2 : tensor<2x8xf32> - %2 = "mhlo.dot"(%1, %arg3) : (tensor<2x8xf32>, tensor<8x32xf32>) -> tensor<2x32xf32> - %3 = "mhlo.all_reduce"(%2) <{channel_handle = #mhlo.channel_handle, replica_groups = dense<[[0, 1], [2, 3], [4, 5], [6, 7]]> : tensor<4x2xi64>, use_global_device_ids}> ({ + %1 = stablehlo.add %arg2, %arg2 : tensor<2x8xf32> + %2 = stablehlo.dot %1, %arg3 : (tensor<2x8xf32>, tensor<8x32xf32>) 
-> tensor<2x32xf32> + %3 = "stablehlo.all_reduce"(%2) <{channel_handle = #stablehlo.channel_handle, replica_groups = dense<[[0, 1], [2, 3], [4, 5], [6, 7]]> : tensor<4x2xi64>, use_global_device_ids}> ({ ^bb0(%arg4: tensor<f32>, %arg5: tensor<f32>): - %4 = mhlo.add %arg4, %arg5 : tensor<f32> - mhlo.return %4 : tensor<f32> + %4 = stablehlo.add %arg4, %arg5 : tensor<f32> + stablehlo.return %4 : tensor<f32> }) : (tensor<2x32xf32>) -> tensor<2x32xf32> sdy.return %3 : tensor<2x32xf32> } : (tensor<8x16xf32>, tensor<16x32xf32>) -> tensor<8x32xf32> @@ -32,13 +32,13 @@ func.func @single_manual_comp(%arg0: tensor<8x16xf32> {sdy.sharding = #sdy.shard func.func @manual_comp_using_another(%arg0: tensor<8x8xf32> {sdy.sharding = #sdy.sharding<@mesh_0, [{"a"}, {}]>}) -> (tensor<8x8xf32> {sdy.sharding = #sdy.sharding<@mesh_0, [{}, {"b"}]>}) { // CHECK-NEXT: %0 = mhlo.copy %arg0 {mhlo.sharding = "{devices=[4,1,2]<=[8] last_tile_dim_replicate}"} : tensor<8x8xf32> - // CHECK-NEXT: %1 = mhlo.custom_call @SPMDFullToShardShape(%0) {mhlo.sharding = "{devices=[1,1,4,2]<=[8] last_tile_dims={manual, replicated}}"} : (tensor<8x8xf32>) -> tensor<2x8xf32> + // CHECK-NEXT: %1 = stablehlo.custom_call @SPMDFullToShardShape(%0) {mhlo.sharding = "{devices=[1,1,4,2]<=[8] last_tile_dims={manual, replicated}}"} : (tensor<8x8xf32>) -> tensor<2x8xf32> // CHECK-NEXT: %2 = mhlo.copy %1 {mhlo.sharding = "{devices=[1,1,4,2]<=[8] last_tile_dims={manual, replicated}}"} : tensor<2x8xf32> - // CHECK-NEXT: %3 = mhlo.custom_call @SPMDShardToFullShape(%2) {mhlo.sharding = "{devices=[4,1,2]<=[8] last_tile_dim_replicate}"} : (tensor<2x8xf32>) -> tensor<8x8xf32> + // CHECK-NEXT: %3 = stablehlo.custom_call @SPMDShardToFullShape(%2) {mhlo.sharding = "{devices=[4,1,2]<=[8] last_tile_dim_replicate}"} : (tensor<2x8xf32>) -> tensor<8x8xf32> // CHECK-NEXT: %4 = mhlo.copy %3 {mhlo.sharding = "{devices=[1,2,4]<=[4,2]T(1,0) last_tile_dim_replicate}"} : tensor<8x8xf32> - // CHECK-NEXT: %5 = mhlo.custom_call @SPMDFullToShardShape(%4) {mhlo.sharding = "{devices=[1,1,2,4]<=[4,2]T(1,0) last_tile_dims={manual, replicated}}"} : (tensor<8x8xf32>) -> tensor<8x4xf32> + // CHECK-NEXT: %5 = stablehlo.custom_call @SPMDFullToShardShape(%4) {mhlo.sharding = "{devices=[1,1,2,4]<=[4,2]T(1,0) last_tile_dims={manual, replicated}}"} : (tensor<8x8xf32>) -> tensor<8x4xf32> // CHECK-NEXT: %6 = mhlo.copy %5 {mhlo.sharding = "{devices=[1,1,2,4]<=[4,2]T(1,0) last_tile_dims={manual, replicated}}"} : tensor<8x4xf32> - // CHECK-NEXT: %7 = mhlo.custom_call @SPMDShardToFullShape(%6) {mhlo.sharding = "{devices=[1,2,4]<=[4,2]T(1,0) last_tile_dim_replicate}"} : (tensor<8x4xf32>) -> tensor<8x8xf32> + // CHECK-NEXT: %7 = stablehlo.custom_call @SPMDShardToFullShape(%6) {mhlo.sharding = "{devices=[1,2,4]<=[4,2]T(1,0) last_tile_dim_replicate}"} : (tensor<8x4xf32>) -> tensor<8x8xf32> // CHECK-NEXT: return %7 : tensor<8x8xf32> %0 = sdy.manual_computation(%arg0) in_shardings=[<@mesh_0, [{"a"}, {}]>] out_shardings=[<@mesh_0, [{"a"}, {}]>] manual_axes={"a"} (%arg1: tensor<2x8xf32>) { sdy.return %arg1 : tensor<2x8xf32> @@ -53,17 +53,17 @@ func.func @manual_comp_using_another(%arg0: tensor<8x8xf32> {sdy.sharding = #sdy // CHECK-LABEL: func @sharding_in_manual_computation_body func.func @sharding_in_manual_computation_body(%arg0: tensor<8x16xf32> {sdy.sharding = #sdy.sharding<@mesh_1, [{"a", ?}, {"b", ?}]>}, %arg1: tensor<16x32xf32> {sdy.sharding = #sdy.sharding<@mesh_1, [{"b", ?}, {?}]>}) -> (tensor<8x32xf32> {sdy.sharding = #sdy.sharding<@mesh_1, [{"a"}, {}]>}) { // CHECK-NEXT: %0 = mhlo.copy %arg0 {mhlo.sharding =
"{devices=[2,2,4]<=[16] last_tile_dim_replicate}"} : tensor<8x16xf32> - // CHECK-NEXT: %1 = mhlo.custom_call @SPMDFullToShardShape(%0) {mhlo.sharding = "{devices=[1,1,4,4]<=[16] last_tile_dims={manual, replicated}}"} : (tensor<8x16xf32>) -> tensor<4x8xf32> + // CHECK-NEXT: %1 = stablehlo.custom_call @SPMDFullToShardShape(%0) {mhlo.sharding = "{devices=[1,1,4,4]<=[16] last_tile_dims={manual, replicated}}"} : (tensor<8x16xf32>) -> tensor<4x8xf32> // CHECK-NEXT: %2 = mhlo.copy %arg1 {mhlo.sharding = "{devices=[2,1,8]<=[2,2,4]T(1,0,2) last_tile_dim_replicate}"} : tensor<16x32xf32> - // CHECK-NEXT: %3 = mhlo.custom_call @SPMDFullToShardShape(%2) {mhlo.sharding = "{devices=[1,1,4,4]<=[16] last_tile_dims={manual, replicated}}"} : (tensor<16x32xf32>) -> tensor<8x32xf32> - // CHECK-NEXT: %4 = mhlo.add %1, %1 {mhlo.sharding = "{devices=[2,1,4,2]<=[4,2,2]T(1,0,2) last_tile_dims={manual, replicated}}"} : tensor<4x8xf32> - // CHECK-NEXT: %5 = "mhlo.dot"(%4, %3) {mhlo.sharding = "{devices=[2,2,4]<=[4,2,2]T(2,1,0) last_tile_dims={manual}}"} : (tensor<4x8xf32>, tensor<8x32xf32>) -> tensor<4x32xf32> + // CHECK-NEXT: %3 = stablehlo.custom_call @SPMDFullToShardShape(%2) {mhlo.sharding = "{devices=[1,1,4,4]<=[16] last_tile_dims={manual, replicated}}"} : (tensor<16x32xf32>) -> tensor<8x32xf32> + // CHECK-NEXT: %4 = stablehlo.add %1, %1 {mhlo.sharding = "{devices=[2,1,4,2]<=[4,2,2]T(1,0,2) last_tile_dims={manual, replicated}}"} : tensor<4x8xf32> + // CHECK-NEXT: %5 = stablehlo.dot %4, %3 {mhlo.sharding = "{devices=[2,2,4]<=[4,2,2]T(2,1,0) last_tile_dims={manual}}"} : (tensor<4x8xf32>, tensor<8x32xf32>) -> tensor<4x32xf32> // CHECK-NEXT: %6 = mhlo.copy %5 {mhlo.sharding = "{devices=[1,1,4,4]<=[16] last_tile_dims={manual, replicated}}"} : tensor<4x32xf32> - // CHECK-NEXT: %7 = mhlo.custom_call @SPMDShardToFullShape(%6) {mhlo.sharding = "{devices=[2,1,8]<=[16] last_tile_dim_replicate}"} : (tensor<4x32xf32>) -> tensor<8x32xf32> + // CHECK-NEXT: %7 = stablehlo.custom_call @SPMDShardToFullShape(%6) {mhlo.sharding = "{devices=[2,1,8]<=[16] last_tile_dim_replicate}"} : (tensor<4x32xf32>) -> tensor<8x32xf32> // CHECK-NEXT: return %7 : tensor<8x32xf32> %0 = sdy.manual_computation(%arg0, %arg1) in_shardings=[<@mesh_1, [{"a"}, {"b"}]>, <@mesh_1, [{"b"}, {}], replicated={"a"}>] out_shardings=[<@mesh_1, [{"a"}, {}], replicated={"b"}>] manual_axes={"a", "b"} (%arg2: tensor<4x8xf32>, %arg3: tensor<8x32xf32>) { - %1 = mhlo.add %arg2, %arg2 {sdy.sharding = #sdy.sharding_per_value<[<@mesh_1, [{"c"}, {}]>]>} : tensor<4x8xf32> - %2 = "mhlo.dot"(%1, %arg3) {sdy.sharding = #sdy.sharding_per_value<[<@mesh_1, [{"d"}, {"c"}]>]>} : (tensor<4x8xf32>, tensor<8x32xf32>) -> tensor<4x32xf32> + %1 = stablehlo.add %arg2, %arg2 {sdy.sharding = #sdy.sharding_per_value<[<@mesh_1, [{"c"}, {}]>]>} : tensor<4x8xf32> + %2 = stablehlo.dot %1, %arg3 {sdy.sharding = #sdy.sharding_per_value<[<@mesh_1, [{"d"}, {"c"}]>]>} : (tensor<4x8xf32>, tensor<8x32xf32>) -> tensor<4x32xf32> sdy.return %2 : tensor<4x32xf32> } : (tensor<8x16xf32>, tensor<16x32xf32>) -> tensor<8x32xf32> return %0 : tensor<8x32xf32> @@ -71,14 +71,14 @@ func.func @sharding_in_manual_computation_body(%arg0: tensor<8x16xf32> {sdy.shar // CHECK-LABEL: func @call_op_with_no_operands_or_results func.func @call_op_with_no_operands_or_results() { - // CHECK-LABEL: %0 = mhlo.constant + // CHECK-LABEL: %cst = stablehlo.constant // CHECK-NOT: sdy.sharding // CHECK-NOT: mhlo.sharding - // CHECK-NEXT: %1 = mhlo.add %0, %0 {sdy.sharding = #sdy.sharding_per_value<[<@mesh_1, [{"a"}, {}]>]>} : 
tensor<2x2xf32> + // CHECK-NEXT: %0 = stablehlo.add %cst, %cst {sdy.sharding = #sdy.sharding_per_value<[<@mesh_1, [{"a"}, {}]>]>} : tensor<2x2xf32> // CHECK-NEXT: return sdy.manual_computation() in_shardings=[] out_shardings=[] manual_axes={} () { - %0 = mhlo.constant dense<[[0.0, 1.0], [2.0, 3.0]]> : tensor<2x2xf32> - %1 = mhlo.add %0, %0 {sdy.sharding = #sdy.sharding_per_value<[<@mesh_1, [{"a"}, {}]>]>} : tensor<2x2xf32> + %0 = stablehlo.constant dense<[[0.0, 1.0], [2.0, 3.0]]> : tensor<2x2xf32> + %1 = stablehlo.add %0, %0 {sdy.sharding = #sdy.sharding_per_value<[<@mesh_1, [{"a"}, {}]>]>} : tensor<2x2xf32> sdy.return } : () -> () return @@ -87,18 +87,18 @@ func.func @call_op_with_no_operands_or_results() { // CHECK-LABEL: func @nested_shmaps func.func @nested_shmaps(%arg0: tensor<4x8xf32> {sdy.sharding = #sdy.sharding<@mesh_1, [{"a"}, {"b"}]>}) -> (tensor<4x8xf32> {sdy.sharding = #sdy.sharding<@mesh_1, [{"a", ?}, {?}]>}) { // CHECK-NEXT: %[[COPY_OPERAND_OUTER:.*]] = mhlo.copy %arg0 {mhlo.sharding = "{devices=[2,1,8]<=[16] last_tile_dim_replicate}"} : tensor<4x8xf32> - // CHECK-NEXT: %[[FULL_TO_SHARD_OUTER:.*]] = mhlo.custom_call @SPMDFullToShardShape(%[[COPY_OPERAND_OUTER]]) {mhlo.sharding = "{devices=[1,1,2,8]<=[16] last_tile_dims={manual, replicated}}"} : (tensor<4x8xf32>) -> tensor<2x8xf32> + // CHECK-NEXT: %[[FULL_TO_SHARD_OUTER:.*]] = stablehlo.custom_call @SPMDFullToShardShape(%[[COPY_OPERAND_OUTER]]) {mhlo.sharding = "{devices=[1,1,2,8]<=[16] last_tile_dims={manual, replicated}}"} : (tensor<4x8xf32>) -> tensor<2x8xf32> // CHECK-NEXT: %[[COPY_OPERAND_INNER:.*]] = mhlo.copy %[[FULL_TO_SHARD_OUTER]] {mhlo.sharding = "{devices=[1,2,2,4]<=[2,2,4]T(1,0,2) last_tile_dims={manual, replicated}}"} : tensor<2x8xf32> - // CHECK-NEXT: %[[FULL_TO_SHARD_INNER:.*]] = mhlo.custom_call @SPMDFullToShardShape(%[[COPY_OPERAND_INNER]]) {mhlo.sharding = "{devices=[1,1,4,4]<=[2,2,4]T(1,0,2) last_tile_dims={manual, replicated}}"} : (tensor<2x8xf32>) -> tensor<2x4xf32> - // CHECK-NEXT: %[[MULT:.*]] = mhlo.multiply %[[FULL_TO_SHARD_INNER]], %[[FULL_TO_SHARD_INNER]] {mhlo.sharding = "{devices=[1,1,4,4]<=[2,2,4]T(1,0,2) last_tile_dims={manual, replicated}}"} : tensor<2x4xf32> + // CHECK-NEXT: %[[FULL_TO_SHARD_INNER:.*]] = stablehlo.custom_call @SPMDFullToShardShape(%[[COPY_OPERAND_INNER]]) {mhlo.sharding = "{devices=[1,1,4,4]<=[2,2,4]T(1,0,2) last_tile_dims={manual, replicated}}"} : (tensor<2x8xf32>) -> tensor<2x4xf32> + // CHECK-NEXT: %[[MULT:.*]] = stablehlo.multiply %[[FULL_TO_SHARD_INNER]], %[[FULL_TO_SHARD_INNER]] {mhlo.sharding = "{devices=[1,1,4,4]<=[2,2,4]T(1,0,2) last_tile_dims={manual, replicated}}"} : tensor<2x4xf32> // CHECK-NEXT: %[[COPY_RESULT_INNER:.*]] = mhlo.copy %[[MULT]] {mhlo.sharding = "{devices=[1,1,4,4]<=[2,2,4]T(1,0,2) last_tile_dims={manual, replicated}}"} : tensor<2x4xf32> - // CHECK-NEXT: %[[SHARD_TO_FULL_INNER:.*]] = mhlo.custom_call @SPMDShardToFullShape(%[[COPY_RESULT_INNER]]) {mhlo.sharding = "{devices=[1,2,2,4]<=[2,2,4]T(1,0,2) last_tile_dims={manual, replicated}}"} : (tensor<2x4xf32>) -> tensor<2x8xf32> + // CHECK-NEXT: %[[SHARD_TO_FULL_INNER:.*]] = stablehlo.custom_call @SPMDShardToFullShape(%[[COPY_RESULT_INNER]]) {mhlo.sharding = "{devices=[1,2,2,4]<=[2,2,4]T(1,0,2) last_tile_dims={manual, replicated}}"} : (tensor<2x4xf32>) -> tensor<2x8xf32> // CHECK-NEXT: %[[COPY_RESULT_OUTER:.*]] = mhlo.copy %[[SHARD_TO_FULL_INNER]] {mhlo.sharding = "{devices=[1,1,2,8]<=[16] last_tile_dims={manual, replicated}}"} : tensor<2x8xf32> - // CHECK-NEXT: %[[SHARD_TO_FULL_OUTER:.*]] = 
mhlo.custom_call @SPMDShardToFullShape(%[[COPY_RESULT_OUTER]]) {mhlo.sharding = "{devices=[2,1,8]<=[16] last_tile_dim_replicate}"} : (tensor<2x8xf32>) -> tensor<4x8xf32> + // CHECK-NEXT: %[[SHARD_TO_FULL_OUTER:.*]] = stablehlo.custom_call @SPMDShardToFullShape(%[[COPY_RESULT_OUTER]]) {mhlo.sharding = "{devices=[2,1,8]<=[16] last_tile_dim_replicate}"} : (tensor<2x8xf32>) -> tensor<4x8xf32> // CHECK-NEXT: return %[[SHARD_TO_FULL_OUTER]] : tensor<4x8xf32> %0 = sdy.manual_computation(%arg0) in_shardings=[<@mesh_1, [{"a"}, {}]>] out_shardings=[<@mesh_1, [{"a"}, {}]>] manual_axes={"a"} (%arg1: tensor<2x8xf32>) { %1 = sdy.manual_computation(%arg1) in_shardings=[<@mesh_1, [{}, {"b"}]>] out_shardings=[<@mesh_1, [{}, {"b"}]>] manual_axes={"b"} (%arg2: tensor<2x4xf32>) { - %2 = mhlo.multiply %arg2, %arg2 : tensor<2x4xf32> + %2 = stablehlo.multiply %arg2, %arg2 : tensor<2x4xf32> sdy.return %2 : tensor<2x4xf32> } : (tensor<2x8xf32>) -> tensor<2x8xf32> sdy.return %1 : tensor<2x8xf32> @@ -109,26 +109,26 @@ func.func @nested_shmaps(%arg0: tensor<4x8xf32> {sdy.sharding = #sdy.sharding<@m // CHECK-LABEL: func @nested_shmaps_extra_op func.func @nested_shmaps_extra_op(%arg0: tensor<4x8xf32> {sdy.sharding = #sdy.sharding<@mesh_1, [{"a"}, {"b"}]>}) -> (tensor<4x8xf32> {sdy.sharding = #sdy.sharding<@mesh_1, [{"a", ?}, {?}]>}) { // CHECK-NEXT: %[[COPY_OPERAND_OUTER:.*]] = mhlo.copy %arg0 {mhlo.sharding = "{devices=[2,1,8]<=[16] last_tile_dim_replicate}"} : tensor<4x8xf32> - // CHECK-NEXT: %[[FULL_TO_SHARD_OUTER:.*]] = mhlo.custom_call @SPMDFullToShardShape(%[[COPY_OPERAND_OUTER]]) {mhlo.sharding = "{devices=[1,1,2,8]<=[16] last_tile_dims={manual, replicated}}"} : (tensor<4x8xf32>) -> tensor<2x8xf32> + // CHECK-NEXT: %[[FULL_TO_SHARD_OUTER:.*]] = stablehlo.custom_call @SPMDFullToShardShape(%[[COPY_OPERAND_OUTER]]) {mhlo.sharding = "{devices=[1,1,2,8]<=[16] last_tile_dims={manual, replicated}}"} : (tensor<4x8xf32>) -> tensor<2x8xf32> // CHECK-NEXT: %[[COPY_OPERAND_INNER:.*]] = mhlo.copy %[[FULL_TO_SHARD_OUTER]] {mhlo.sharding = "{devices=[1,2,2,4]<=[2,2,4]T(1,0,2) last_tile_dims={manual, replicated}}"} : tensor<2x8xf32> - // CHECK-NEXT: %[[FULL_TO_SHARD_INNER:.*]] = mhlo.custom_call @SPMDFullToShardShape(%[[COPY_OPERAND_INNER]]) {mhlo.sharding = "{devices=[1,1,4,4]<=[2,2,4]T(1,0,2) last_tile_dims={manual, replicated}}"} : (tensor<2x8xf32>) -> tensor<2x4xf32> - // CHECK-NEXT: %[[MULT:.*]] = mhlo.multiply %[[FULL_TO_SHARD_INNER]], %[[FULL_TO_SHARD_INNER]] {mhlo.sharding = "{devices=[1,1,4,4]<=[2,2,4]T(1,0,2) last_tile_dims={manual, replicated}}"} : tensor<2x4xf32> - // CHECK-NEXT: %[[ADD:.*]] = mhlo.add %[[MULT]], %[[MULT]] {mhlo.sharding = "{devices=[2,1,4,2]<=[2,2,2,2]T(2,1,0,3) last_tile_dims={manual, replicated}}"} : tensor<2x4xf32> - // CHECK-NEXT: %[[SUB:.*]] = mhlo.subtract %[[ADD]], %[[ADD]] {mhlo.sharding = "{devices=[4,1,4]<=[2,2,4]T(2,1,0) last_tile_dims={manual}}"} : tensor<2x4xf32> + // CHECK-NEXT: %[[FULL_TO_SHARD_INNER:.*]] = stablehlo.custom_call @SPMDFullToShardShape(%[[COPY_OPERAND_INNER]]) {mhlo.sharding = "{devices=[1,1,4,4]<=[2,2,4]T(1,0,2) last_tile_dims={manual, replicated}}"} : (tensor<2x8xf32>) -> tensor<2x4xf32> + // CHECK-NEXT: %[[MULT:.*]] = stablehlo.multiply %[[FULL_TO_SHARD_INNER]], %[[FULL_TO_SHARD_INNER]] {mhlo.sharding = "{devices=[1,1,4,4]<=[2,2,4]T(1,0,2) last_tile_dims={manual, replicated}}"} : tensor<2x4xf32> + // CHECK-NEXT: %[[ADD:.*]] = stablehlo.add %[[MULT]], %[[MULT]] {mhlo.sharding = "{devices=[2,1,4,2]<=[2,2,2,2]T(2,1,0,3) last_tile_dims={manual, replicated}}"} : 
tensor<2x4xf32> + // CHECK-NEXT: %[[SUB:.*]] = stablehlo.subtract %[[ADD]], %[[ADD]] {mhlo.sharding = "{devices=[4,1,4]<=[2,2,4]T(2,1,0) last_tile_dims={manual}}"} : tensor<2x4xf32> // CHECK-NEXT: %[[COPY_RESULT_INNER:.*]] = mhlo.copy %[[SUB]] {mhlo.sharding = "{devices=[1,1,4,4]<=[2,2,4]T(1,0,2) last_tile_dims={manual, replicated}}"} : tensor<2x4xf32> - // CHECK-NEXT: %[[SHARD_TO_FULL_INNER:.*]] = mhlo.custom_call @SPMDShardToFullShape(%[[COPY_RESULT_INNER]]) {mhlo.sharding = "{devices=[1,2,2,4]<=[2,2,4]T(1,0,2) last_tile_dims={manual, replicated}}"} : (tensor<2x4xf32>) -> tensor<2x8xf32> - // CHECK-NEXT: %[[ADD:.*]] = mhlo.add %[[SHARD_TO_FULL_INNER]], %[[SHARD_TO_FULL_INNER]] {mhlo.sharding = "{devices=[1,1,2,8]<=[16] last_tile_dims={manual, replicated}}"} : tensor<2x8xf32> + // CHECK-NEXT: %[[SHARD_TO_FULL_INNER:.*]] = stablehlo.custom_call @SPMDShardToFullShape(%[[COPY_RESULT_INNER]]) {mhlo.sharding = "{devices=[1,2,2,4]<=[2,2,4]T(1,0,2) last_tile_dims={manual, replicated}}"} : (tensor<2x4xf32>) -> tensor<2x8xf32> + // CHECK-NEXT: %[[ADD:.*]] = stablehlo.add %[[SHARD_TO_FULL_INNER]], %[[SHARD_TO_FULL_INNER]] {mhlo.sharding = "{devices=[1,1,2,8]<=[16] last_tile_dims={manual, replicated}}"} : tensor<2x8xf32> // CHECK-NEXT: %[[COPY_RESULT_OUTER:.*]] = mhlo.copy %[[ADD]] {mhlo.sharding = "{devices=[1,1,2,8]<=[16] last_tile_dims={manual, replicated}}"} : tensor<2x8xf32> - // CHECK-NEXT: %[[SHARD_TO_FULL_OUTER:.*]] = mhlo.custom_call @SPMDShardToFullShape(%[[COPY_RESULT_OUTER]]) {mhlo.sharding = "{devices=[2,1,8]<=[16] last_tile_dim_replicate}"} : (tensor<2x8xf32>) -> tensor<4x8xf32> + // CHECK-NEXT: %[[SHARD_TO_FULL_OUTER:.*]] = stablehlo.custom_call @SPMDShardToFullShape(%[[COPY_RESULT_OUTER]]) {mhlo.sharding = "{devices=[2,1,8]<=[16] last_tile_dim_replicate}"} : (tensor<2x8xf32>) -> tensor<4x8xf32> // CHECK-NEXT: return %[[SHARD_TO_FULL_OUTER]] : tensor<4x8xf32> %0 = sdy.manual_computation(%arg0) in_shardings=[<@mesh_1, [{"a"}, {}]>] out_shardings=[<@mesh_1, [{"a"}, {}]>] manual_axes={"a"} (%arg1: tensor<2x8xf32>) { %1 = sdy.manual_computation(%arg1) in_shardings=[<@mesh_1, [{}, {"b"}]>] out_shardings=[<@mesh_1, [{}, {"b"}]>] manual_axes={"b"} (%arg2: tensor<2x4xf32>) { - %2 = mhlo.multiply %arg2, %arg2 : tensor<2x4xf32> - %3 = mhlo.add %2, %2 {sdy.sharding = #sdy.sharding_per_value<[<@mesh_1, [{"c"}, {}]>]>} : tensor<2x4xf32> - %4 = mhlo.subtract %3, %3 {sdy.sharding = #sdy.sharding_per_value<[<@mesh_1, [{"c", "d"}, {}]>]>} : tensor<2x4xf32> + %2 = stablehlo.multiply %arg2, %arg2 : tensor<2x4xf32> + %3 = stablehlo.add %2, %2 {sdy.sharding = #sdy.sharding_per_value<[<@mesh_1, [{"c"}, {}]>]>} : tensor<2x4xf32> + %4 = stablehlo.subtract %3, %3 {sdy.sharding = #sdy.sharding_per_value<[<@mesh_1, [{"c", "d"}, {}]>]>} : tensor<2x4xf32> sdy.return %4 : tensor<2x4xf32> } : (tensor<2x8xf32>) -> tensor<2x8xf32> - %5 = mhlo.add %1, %1 : tensor<2x8xf32> + %5 = stablehlo.add %1, %1 : tensor<2x8xf32> sdy.return %5 : tensor<2x8xf32> } : (tensor<4x8xf32>) -> tensor<4x8xf32> return %0 : tensor<4x8xf32> @@ -137,22 +137,22 @@ func.func @nested_shmaps_extra_op(%arg0: tensor<4x8xf32> {sdy.sharding = #sdy.sh // CHECK-LABEL: func @multiple_manual_computation_uses func.func @multiple_manual_computation_uses(%arg0: tensor<2x4x8xi32> {sdy.sharding = #sdy.sharding<@mesh_0, [{}, {}, {"a"}]>}, %arg1: tensor<32x16x8xi32> {sdy.sharding = #sdy.sharding<@mesh_0, [{}, {}, {"a"}]>}) -> (tensor<131x4x8xi32> {sdy.sharding = #sdy.sharding<@mesh_0, [{?}, {?}, {"a"}]>}) { // CHECK-NEXT: %[[COPY_OPERAND_0:.*]] = mhlo.copy %arg0 
{mhlo.sharding = "{devices=[1,1,4,2]<=[8] last_tile_dim_replicate}"} : tensor<2x4x8xi32> - // CHECK-NEXT: %[[FULL_TO_SHARD_0:.*]] = mhlo.custom_call @SPMDFullToShardShape(%[[COPY_OPERAND_0]]) {mhlo.sharding = "{devices=[1,1,1,4,2]<=[8] last_tile_dims={manual, replicated}}"} : (tensor<2x4x8xi32>) -> tensor<2x4x2xi32> + // CHECK-NEXT: %[[FULL_TO_SHARD_0:.*]] = stablehlo.custom_call @SPMDFullToShardShape(%[[COPY_OPERAND_0]]) {mhlo.sharding = "{devices=[1,1,1,4,2]<=[8] last_tile_dims={manual, replicated}}"} : (tensor<2x4x8xi32>) -> tensor<2x4x2xi32> // CHECK-NEXT: %[[CUSTOM_CALL:.*]] = stablehlo.custom_call @sdy_testonly(%[[FULL_TO_SHARD_0]]) {mhlo.sharding = "{devices=[1,1,1,4,2]<=[8] last_tile_dims={manual, replicated}}"} : (tensor<2x4x2xi32>) -> tensor<3x4x2xi32> // CHECK-NEXT: %[[COPY_RESULT_0:.*]] = mhlo.copy %[[CUSTOM_CALL]] {mhlo.sharding = "{devices=[1,1,1,4,2]<=[8] last_tile_dims={manual, replicated}}"} : tensor<3x4x2xi32> - // CHECK-NEXT: %[[SHARD_TO_FULL_0:.*]] = mhlo.custom_call @SPMDShardToFullShape(%[[COPY_RESULT_0]]) {mhlo.sharding = "{devices=[1,1,4,2]<=[8] last_tile_dim_replicate}"} : (tensor<3x4x2xi32>) -> tensor<3x4x8xi32> + // CHECK-NEXT: %[[SHARD_TO_FULL_0:.*]] = stablehlo.custom_call @SPMDShardToFullShape(%[[COPY_RESULT_0]]) {mhlo.sharding = "{devices=[1,1,4,2]<=[8] last_tile_dim_replicate}"} : (tensor<3x4x2xi32>) -> tensor<3x4x8xi32> // CHECK-NEXT: %[[COPY_OPERAND_1:.*]] = mhlo.copy %arg1 {mhlo.sharding = "{devices=[1,1,4,2]<=[8] last_tile_dim_replicate}"} : tensor<32x16x8xi32> - // CHECK-NEXT: %[[FULL_TO_SHARD_1:.*]] = mhlo.custom_call @SPMDFullToShardShape(%[[COPY_OPERAND_1]]) {mhlo.sharding = "{devices=[1,1,1,4,2]<=[8] last_tile_dims={manual, replicated}}"} : (tensor<32x16x8xi32>) -> tensor<32x16x2xi32> + // CHECK-NEXT: %[[FULL_TO_SHARD_1:.*]] = stablehlo.custom_call @SPMDFullToShardShape(%[[COPY_OPERAND_1]]) {mhlo.sharding = "{devices=[1,1,1,4,2]<=[8] last_tile_dims={manual, replicated}}"} : (tensor<32x16x8xi32>) -> tensor<32x16x2xi32> // CHECK-NEXT: %[[RESHAPE:.*]] = stablehlo.reshape %[[FULL_TO_SHARD_1]] {mhlo.sharding = "{devices=[1,1,1,4,2]<=[8] last_tile_dims={manual, replicated}}"} : (tensor<32x16x2xi32>) -> tensor<128x4x2xi32> // CHECK-NEXT: %[[COPY_RESULT_1:.*]] = mhlo.copy %[[RESHAPE]] {mhlo.sharding = "{devices=[1,1,1,4,2]<=[8] last_tile_dims={manual, replicated}}"} : tensor<128x4x2xi32> - // CHECK-NEXT: %[[SHARD_TO_FULL_1:.*]] = mhlo.custom_call @SPMDShardToFullShape(%[[COPY_RESULT_1]]) {mhlo.sharding = "{devices=[1,1,4,2]<=[8] last_tile_dim_replicate}"} : (tensor<128x4x2xi32>) -> tensor<128x4x8xi32> + // CHECK-NEXT: %[[SHARD_TO_FULL_1:.*]] = stablehlo.custom_call @SPMDShardToFullShape(%[[COPY_RESULT_1]]) {mhlo.sharding = "{devices=[1,1,4,2]<=[8] last_tile_dim_replicate}"} : (tensor<128x4x2xi32>) -> tensor<128x4x8xi32> // CHECK-NEXT: %[[COPY_OPERAND_2:.*]] = mhlo.copy %[[SHARD_TO_FULL_0]] {mhlo.sharding = "{devices=[1,1,4,2]<=[8] last_tile_dim_replicate}"} : tensor<3x4x8xi32> - // CHECK-NEXT: %[[FULL_TO_SHARD_2:.*]] = mhlo.custom_call @SPMDFullToShardShape(%[[COPY_OPERAND_2]]) {mhlo.sharding = "{devices=[1,1,1,4,2]<=[8] last_tile_dims={manual, replicated}}"} : (tensor<3x4x8xi32>) -> tensor<3x4x2xi32> + // CHECK-NEXT: %[[FULL_TO_SHARD_2:.*]] = stablehlo.custom_call @SPMDFullToShardShape(%[[COPY_OPERAND_2]]) {mhlo.sharding = "{devices=[1,1,1,4,2]<=[8] last_tile_dims={manual, replicated}}"} : (tensor<3x4x8xi32>) -> tensor<3x4x2xi32> // CHECK-NEXT: %[[COPY_OPERAND_3:.*]] = mhlo.copy %[[SHARD_TO_FULL_1]] {mhlo.sharding = "{devices=[1,1,4,2]<=[8] 
last_tile_dim_replicate}"} : tensor<128x4x8xi32> - // CHECK-NEXT: %[[FULL_TO_SHARD_3:.*]] = mhlo.custom_call @SPMDFullToShardShape(%[[COPY_OPERAND_3]]) {mhlo.sharding = "{devices=[1,1,1,4,2]<=[8] last_tile_dims={manual, replicated}}"} : (tensor<128x4x8xi32>) -> tensor<128x4x2xi32> + // CHECK-NEXT: %[[FULL_TO_SHARD_3:.*]] = stablehlo.custom_call @SPMDFullToShardShape(%[[COPY_OPERAND_3]]) {mhlo.sharding = "{devices=[1,1,1,4,2]<=[8] last_tile_dims={manual, replicated}}"} : (tensor<128x4x8xi32>) -> tensor<128x4x2xi32> // CHECK-NEXT: %[[CONCAT:.*]] = stablehlo.concatenate %[[FULL_TO_SHARD_3]], %[[FULL_TO_SHARD_2]], dim = 0 {mhlo.sharding = "{devices=[1,1,1,4,2]<=[8] last_tile_dims={manual, replicated}}"} : (tensor<128x4x2xi32>, tensor<3x4x2xi32>) -> tensor<131x4x2xi32> // CHECK-NEXT: %[[COPY_RESULT_2:.*]] = mhlo.copy %[[CONCAT]] {mhlo.sharding = "{devices=[1,1,1,4,2]<=[8] last_tile_dims={manual, replicated}}"} : tensor<131x4x2xi32> - // CHECK-NEXT: %[[SHARD_TO_FULL_2:.*]] = mhlo.custom_call @SPMDShardToFullShape(%[[COPY_RESULT_2]]) {mhlo.sharding = "{devices=[1,1,4,2]<=[8] last_tile_dim_replicate}"} : (tensor<131x4x2xi32>) -> tensor<131x4x8xi32> + // CHECK-NEXT: %[[SHARD_TO_FULL_2:.*]] = stablehlo.custom_call @SPMDShardToFullShape(%[[COPY_RESULT_2]]) {mhlo.sharding = "{devices=[1,1,4,2]<=[8] last_tile_dim_replicate}"} : (tensor<131x4x2xi32>) -> tensor<131x4x8xi32> // CHECK-NEXT: return %[[SHARD_TO_FULL_2]] : tensor<131x4x8xi32> %1 = sdy.manual_computation(%arg0) in_shardings=[<@mesh_0, [{}, {}, {"a"}]>] out_shardings=[<@mesh_0, [{}, {}, {"a"}]>] manual_axes={"a"} (%arg2: tensor<2x4x2xi32>) { %4 = stablehlo.custom_call @sdy_testonly(%arg2) : (tensor<2x4x2xi32>) -> tensor<3x4x2xi32> diff --git a/xla/service/spmd/shardy/test/mhlo_round_trip_shard_map_import.mlir b/xla/service/spmd/shardy/test/mhlo_round_trip_shard_map_import.mlir index 12641b0d746476..a62c58cc7a9e96 100644 --- a/xla/service/spmd/shardy/test/mhlo_round_trip_shard_map_import.mlir +++ b/xla/service/spmd/shardy/test/mhlo_round_trip_shard_map_import.mlir @@ -24,11 +24,11 @@ func.func public @call_op_with_one_operand_and_no_results(%arg0: tensor<4xf32>) // CHECK-SAME{LITERAL}: in_shardings=[<@mesh_0, [{}], replicated={"a"}>] out_shardings=[] manual_axes={"a"} (%arg1: tensor<4xf32>) { // CHECK-NEXT: sdy.return // CHECK-NEXT: } : (tensor<4xf32>) -> () - // CHECK-NEXT: %0 = mhlo.add %arg0, %arg0 : tensor<4xf32> - %0 = mhlo.custom_call @Sharding(%arg0) {sdy.sharding = #sdy.sharding_per_value<[<@mesh_0, [{}], replicated={"a"}>]>} : (tensor<4xf32>) -> tensor<4xf32> - %1 = mhlo.custom_call @SPMDFullToShardShape(%0) : (tensor<4xf32>) -> tensor<4xf32> + // CHECK-NEXT: %0 = stablehlo.add %arg0, %arg0 : tensor<4xf32> + %0 = stablehlo.custom_call @Sharding(%arg0) {sdy.sharding = #sdy.sharding_per_value<[<@mesh_0, [{}], replicated={"a"}>]>} : (tensor<4xf32>) -> tensor<4xf32> + %1 = stablehlo.custom_call @SPMDFullToShardShape(%0) : (tensor<4xf32>) -> tensor<4xf32> call @shmap_body_one_argument_empty_body(%1) : (tensor<4xf32>) -> () - %2 = mhlo.add %arg0, %arg0 : tensor<4xf32> + %2 = stablehlo.add %arg0, %arg0 : tensor<4xf32> return %2 : tensor<4xf32> } // CHECK-NOT: func.func private @shmap_body_one_argument_empty_body @@ -40,18 +40,18 @@ func.func private @shmap_body_one_argument_empty_body(%arg0: tensor<4xf32>) -> ( func.func public @call_op_with_no_operands_and_one_result() -> tensor<4xf32> { // CHECK: %0 = sdy.manual_computation() // CHECK-SAME{LITERAL}: in_shardings=[] out_shardings=[<@mesh_0, [{}], replicated={"a"}>] manual_axes={"a"} () { - // 
CHECK-LABEL: %1 = mhlo.constant - // CHECK-NEXT: sdy.return %1 : tensor<4xf32> + // CHECK-LABEL: %cst = stablehlo.constant + // CHECK-NEXT: sdy.return %cst : tensor<4xf32> // CHECK-NEXT: } : () -> tensor<4xf32> // CHECK-NEXT: return %0 : tensor<4xf32> %0 = call @shmap_body_no_arg() : () -> (tensor<4xf32>) - %1 = mhlo.custom_call @Sharding(%0) : (tensor<4xf32>) -> tensor<4xf32> - %2 = mhlo.custom_call @SPMDShardToFullShape(%1) {sdy.sharding = #sdy.sharding_per_value<[<@mesh_0, [{}], replicated={"a"}>]>} : (tensor<4xf32>) -> tensor<4xf32> + %1 = stablehlo.custom_call @Sharding(%0) : (tensor<4xf32>) -> tensor<4xf32> + %2 = stablehlo.custom_call @SPMDShardToFullShape(%1) {sdy.sharding = #sdy.sharding_per_value<[<@mesh_0, [{}], replicated={"a"}>]>} : (tensor<4xf32>) -> tensor<4xf32> return %2 : tensor<4xf32> } // CHECK-NOT: func.func private @shmap_body_no_arg() func.func private @shmap_body_no_arg() -> tensor<4xf32> { - %0 = mhlo.constant dense <[0.0, 1.0, 2.0, 3.0]> : tensor<4xf32> + %0 = stablehlo.constant dense <[0.0, 1.0, 2.0, 3.0]> : tensor<4xf32> return %0 : tensor<4xf32> } @@ -59,20 +59,20 @@ func.func private @shmap_body_no_arg() -> tensor<4xf32> { func.func public @call_op_with_shamp_body_in_middle(%arg0: tensor<16x32xf32>) -> tensor<16x32xf32> { // CHECK: %0 = sdy.manual_computation(%arg0) // CHECK-SAME{LITERAL}: in_shardings=[<@mesh_0, [{"a"}, {}]>] out_shardings=[<@mesh_0, [{"a"}, {}]>] manual_axes={"a"} (%arg1: tensor<4x32xf32>) { - // CHECK-NEXT: %1 = mhlo.add %arg1, %arg1 : tensor<4x32xf32> + // CHECK-NEXT: %1 = stablehlo.add %arg1, %arg1 : tensor<4x32xf32> // CHECK-NEXT: sdy.return %1 : tensor<4x32xf32> // CHECK-NEXT: } : (tensor<16x32xf32>) -> tensor<16x32xf32> // CHECK-NEXT: return %0 : tensor<16x32xf32> - %0 = mhlo.custom_call @Sharding(%arg0) {sdy.sharding = #sdy.sharding_per_value<[<@mesh_0, [{"a"}, {}]>]>} : (tensor<16x32xf32>) -> tensor<16x32xf32> - %1 = mhlo.custom_call @SPMDFullToShardShape(%0) : (tensor<16x32xf32>) -> tensor<4x32xf32> + %0 = stablehlo.custom_call @Sharding(%arg0) {sdy.sharding = #sdy.sharding_per_value<[<@mesh_0, [{"a"}, {}]>]>} : (tensor<16x32xf32>) -> tensor<16x32xf32> + %1 = stablehlo.custom_call @SPMDFullToShardShape(%0) : (tensor<16x32xf32>) -> tensor<4x32xf32> %2 = call @prefix_shmap_body_suffix(%1) : (tensor<4x32xf32>) -> tensor<4x32xf32> - %3 = mhlo.custom_call @Sharding(%2) : (tensor<4x32xf32>) -> tensor<4x32xf32> - %4 = mhlo.custom_call @SPMDShardToFullShape(%3) {sdy.sharding = #sdy.sharding_per_value<[<@mesh_0, [{"a"}, {}]>]>} : (tensor<4x32xf32>) -> tensor<16x32xf32> + %3 = stablehlo.custom_call @Sharding(%2) : (tensor<4x32xf32>) -> tensor<4x32xf32> + %4 = stablehlo.custom_call @SPMDShardToFullShape(%3) {sdy.sharding = #sdy.sharding_per_value<[<@mesh_0, [{"a"}, {}]>]>} : (tensor<4x32xf32>) -> tensor<16x32xf32> return %4 : tensor<16x32xf32> } // CHECK-NOT: func.func private @shmap_body func.func private @prefix_shmap_body_suffix(%arg0: tensor<4x32xf32>) -> (tensor<4x32xf32>) { - %0 = mhlo.add %arg0, %arg0 : tensor<4x32xf32> + %0 = stablehlo.add %arg0, %arg0 : tensor<4x32xf32> return %0 : tensor<4x32xf32> } @@ -80,20 +80,20 @@ func.func private @prefix_shmap_body_suffix(%arg0: tensor<4x32xf32>) -> (tensor< func.func public @shard_map_single_sharded_input_output_dim_0(%arg0: tensor<16x32xf32>) -> tensor<16x32xf32> { // CHECK: %0 = sdy.manual_computation(%arg0) // CHECK-SAME{LITERAL}: in_shardings=[<@mesh_0, [{"a"}, {}]>] out_shardings=[<@mesh_0, [{"a"}, {}]>] manual_axes={"a"} (%arg1: tensor<4x32xf32>) { - // CHECK-NEXT: %1 = mhlo.add %arg1, 
%arg1 : tensor<4x32xf32> + // CHECK-NEXT: %1 = stablehlo.add %arg1, %arg1 : tensor<4x32xf32> // CHECK-NEXT: sdy.return %1 : tensor<4x32xf32> // CHECK-NEXT: } : (tensor<16x32xf32>) -> tensor<16x32xf32> // CHECK-NEXT: return %0 : tensor<16x32xf32> - %0 = mhlo.custom_call @Sharding(%arg0) {sdy.sharding = #sdy.sharding_per_value<[<@mesh_0, [{"a"}, {}]>]>} : (tensor<16x32xf32>) -> tensor<16x32xf32> - %1 = mhlo.custom_call @SPMDFullToShardShape(%0) : (tensor<16x32xf32>) -> tensor<4x32xf32> + %0 = stablehlo.custom_call @Sharding(%arg0) {sdy.sharding = #sdy.sharding_per_value<[<@mesh_0, [{"a"}, {}]>]>} : (tensor<16x32xf32>) -> tensor<16x32xf32> + %1 = stablehlo.custom_call @SPMDFullToShardShape(%0) : (tensor<16x32xf32>) -> tensor<4x32xf32> %2 = call @shmap_body(%1) : (tensor<4x32xf32>) -> tensor<4x32xf32> - %3 = mhlo.custom_call @Sharding(%2) : (tensor<4x32xf32>) -> tensor<4x32xf32> - %4 = mhlo.custom_call @SPMDShardToFullShape(%3) {sdy.sharding = #sdy.sharding_per_value<[<@mesh_0, [{"a"}, {}]>]>} : (tensor<4x32xf32>) -> tensor<16x32xf32> + %3 = stablehlo.custom_call @Sharding(%2) : (tensor<4x32xf32>) -> tensor<4x32xf32> + %4 = stablehlo.custom_call @SPMDShardToFullShape(%3) {sdy.sharding = #sdy.sharding_per_value<[<@mesh_0, [{"a"}, {}]>]>} : (tensor<4x32xf32>) -> tensor<16x32xf32> return %4 : tensor<16x32xf32> } // CHECK-NOT: func.func private @shmap_body func.func private @shmap_body(%arg0: tensor<4x32xf32>) -> (tensor<4x32xf32>) { - %0 = mhlo.add %arg0, %arg0 : tensor<4x32xf32> + %0 = stablehlo.add %arg0, %arg0 : tensor<4x32xf32> return %0 : tensor<4x32xf32> } @@ -101,20 +101,20 @@ func.func private @shmap_body(%arg0: tensor<4x32xf32>) -> (tensor<4x32xf32>) { func.func public @shard_map_single_sharded_input_output_dim_1(%arg0: tensor<16x32xf32>) -> tensor<16x32xf32> { // CHECK: %0 = sdy.manual_computation(%arg0) // CHECK-SAME{LITERAL}: in_shardings=[<@mesh_1, [{}, {"a"}]>] out_shardings=[<@mesh_1, [{}, {"a"}]>] manual_axes={"a"} (%arg1: tensor<16x8xf32>) { - // CHECK-NEXT: %1 = mhlo.add %arg1, %arg1 : tensor<16x8xf32> + // CHECK-NEXT: %1 = stablehlo.add %arg1, %arg1 : tensor<16x8xf32> // CHECK-NEXT: sdy.return %1 : tensor<16x8xf32> // CHECK-NEXT: } : (tensor<16x32xf32>) -> tensor<16x32xf32> // CHECK-NEXT: return %0 : tensor<16x32xf32> - %0 = mhlo.custom_call @Sharding(%arg0) {sdy.sharding = #sdy.sharding_per_value<[<@mesh_1, [{}, {"a"}]>]>} : (tensor<16x32xf32>) -> tensor<16x32xf32> - %1 = mhlo.custom_call @SPMDFullToShardShape(%0) : (tensor<16x32xf32>) -> tensor<16x8xf32> + %0 = stablehlo.custom_call @Sharding(%arg0) {sdy.sharding = #sdy.sharding_per_value<[<@mesh_1, [{}, {"a"}]>]>} : (tensor<16x32xf32>) -> tensor<16x32xf32> + %1 = stablehlo.custom_call @SPMDFullToShardShape(%0) : (tensor<16x32xf32>) -> tensor<16x8xf32> %2 = call @shmap_body_0(%1) : (tensor<16x8xf32>) -> tensor<16x8xf32> - %3 = mhlo.custom_call @Sharding(%2) : (tensor<16x8xf32>) -> tensor<16x8xf32> - %4 = mhlo.custom_call @SPMDShardToFullShape(%3) {sdy.sharding = #sdy.sharding_per_value<[<@mesh_1, [{}, {"a"}]>]>} : (tensor<16x8xf32>) -> tensor<16x32xf32> + %3 = stablehlo.custom_call @Sharding(%2) : (tensor<16x8xf32>) -> tensor<16x8xf32> + %4 = stablehlo.custom_call @SPMDShardToFullShape(%3) {sdy.sharding = #sdy.sharding_per_value<[<@mesh_1, [{}, {"a"}]>]>} : (tensor<16x8xf32>) -> tensor<16x32xf32> return %4 : tensor<16x32xf32> } // CHECK-NOT: func.func private @shmap_body_0 func.func private @shmap_body_0(%arg0: tensor<16x8xf32>) -> (tensor<16x8xf32>) { - %0 = mhlo.add %arg0, %arg0 : tensor<16x8xf32> + %0 = stablehlo.add 
%arg0, %arg0 : tensor<16x8xf32> return %0 : tensor<16x8xf32> } @@ -122,20 +122,20 @@ func.func private @shmap_body_0(%arg0: tensor<16x8xf32>) -> (tensor<16x8xf32>) { func.func public @shard_map_single_replicated_input_sharded_output(%arg0: tensor<16x32xf32>) -> tensor<16x256xf32> { // CHECK: %0 = sdy.manual_computation(%arg0) // CHECK-SAME{LITERAL}: in_shardings=[<@mesh_1, [{}, {}], replicated={"a", "b"}>] out_shardings=[<@mesh_1, [{}, {"a", "b"}]>] manual_axes={"a", "b"} (%arg1: tensor<16x32xf32>) { - // CHECK-NEXT: %1 = mhlo.add %arg1, %arg1 : tensor<16x32xf32> + // CHECK-NEXT: %1 = stablehlo.add %arg1, %arg1 : tensor<16x32xf32> // CHECK-NEXT: sdy.return %1 : tensor<16x32xf32> // CHECK-NEXT: } : (tensor<16x32xf32>) -> tensor<16x256xf32> // CHECK-NEXT: return %0 : tensor<16x256xf32> - %0 = mhlo.custom_call @Sharding(%arg0) {sdy.sharding = #sdy.sharding_per_value<[<@mesh_1, [{}, {}], replicated={"a", "b"}>]>} : (tensor<16x32xf32>) -> tensor<16x32xf32> - %1 = mhlo.custom_call @SPMDFullToShardShape(%0) : (tensor<16x32xf32>) -> tensor<16x32xf32> + %0 = stablehlo.custom_call @Sharding(%arg0) {sdy.sharding = #sdy.sharding_per_value<[<@mesh_1, [{}, {}], replicated={"a", "b"}>]>} : (tensor<16x32xf32>) -> tensor<16x32xf32> + %1 = stablehlo.custom_call @SPMDFullToShardShape(%0) : (tensor<16x32xf32>) -> tensor<16x32xf32> %2 = call @shmap_body_1(%1) : (tensor<16x32xf32>) -> tensor<16x32xf32> - %3 = mhlo.custom_call @Sharding(%2) : (tensor<16x32xf32>) -> tensor<16x32xf32> - %4 = mhlo.custom_call @SPMDShardToFullShape(%3) {sdy.sharding = #sdy.sharding_per_value<[<@mesh_1, [{}, {"a", "b"}]>]>} : (tensor<16x32xf32>) -> tensor<16x256xf32> + %3 = stablehlo.custom_call @Sharding(%2) : (tensor<16x32xf32>) -> tensor<16x32xf32> + %4 = stablehlo.custom_call @SPMDShardToFullShape(%3) {sdy.sharding = #sdy.sharding_per_value<[<@mesh_1, [{}, {"a", "b"}]>]>} : (tensor<16x32xf32>) -> tensor<16x256xf32> return %4 : tensor<16x256xf32> } // CHECK-NOT func.func private @shmap_body_1 func.func private @shmap_body_1(%arg0: tensor<16x32xf32>) -> (tensor<16x32xf32>) { - %0 = mhlo.add %arg0, %arg0 : tensor<16x32xf32> + %0 = stablehlo.add %arg0, %arg0 : tensor<16x32xf32> return %0 : tensor<16x32xf32> } @@ -143,51 +143,51 @@ func.func private @shmap_body_1(%arg0: tensor<16x32xf32>) -> (tensor<16x32xf32>) func.func public @shard_map_contracting_dim_matmul_all_reduce(%arg0: tensor<8x16xf32>, %arg1: tensor<16x32xf32>) -> tensor<8x32xf32> { // CHECK: %0 = sdy.manual_computation(%arg0, %arg1) // CHECK-SAME{LITERAL}: in_shardings=[<@mesh_1, [{"a"}, {"b"}]>, <@mesh_1, [{"b"}, {}], replicated={"a"}>] out_shardings=[<@mesh_1, [{"a"}, {}], replicated={"b"}>] manual_axes={"a", "b"} (%arg2: tensor<2x8xf32>, %arg3: tensor<8x32xf32>) { - // CHECK-NEXT: %1 = "mhlo.dot_general"(%arg2, %arg3) <{dot_dimension_numbers = #mhlo.dot<lhs_contracting_dimensions = [1], rhs_contracting_dimensions = [0]>, precision_config = [#mhlo<precision DEFAULT>, #mhlo<precision DEFAULT>]}> : (tensor<2x8xf32>, tensor<8x32xf32>) -> tensor<2x32xf32> - // CHECK-NEXT: %2 = "mhlo.all_reduce"(%1) <{ - // CHECK-SAME{LITERAL}: channel_handle = #mhlo.channel_handle, replica_groups = dense<[[0, 1], [2, 3], [4, 5], [6, 7]]> : tensor<4x2xi64>, use_global_device_ids + // CHECK-NEXT: %1 = stablehlo.dot_general %arg2, %arg3, contracting_dims = [1] x [0] : (tensor<2x8xf32>, tensor<8x32xf32>) -> tensor<2x32xf32> + // CHECK-NEXT: %2 = "stablehlo.all_reduce"(%1) <{ + // CHECK-SAME{LITERAL}: channel_handle = #stablehlo.channel_handle, replica_groups = dense<[[0, 1], [2, 3], [4, 5], [6, 7]]> : tensor<4x2xi64>, use_global_device_ids // CHECK-SAME: }> ({ // CHECK-NEXT: ^bb0(%arg4: tensor<f32>,
%arg5: tensor<f32>): - // CHECK-NEXT: %3 = mhlo.add %arg4, %arg5 : tensor<f32> - // CHECK-NEXT: mhlo.return %3 : tensor<f32> + // CHECK-NEXT: %3 = stablehlo.add %arg4, %arg5 : tensor<f32> + // CHECK-NEXT: stablehlo.return %3 : tensor<f32> // CHECK-NEXT: }) : (tensor<2x32xf32>) -> tensor<2x32xf32> // CHECK-NEXT: sdy.return %2 : tensor<2x32xf32> // CHECK-NEXT: } : (tensor<8x16xf32>, tensor<16x32xf32>) -> tensor<8x32xf32> // CHECK-NEXT: return %0 : tensor<8x32xf32> - %0 = mhlo.custom_call @Sharding(%arg0) {sdy.sharding = #sdy.sharding_per_value<[<@mesh_1, [{"a"}, {"b"}]>]>} : (tensor<8x16xf32>) -> tensor<8x16xf32> - %1 = mhlo.custom_call @SPMDFullToShardShape(%0) : (tensor<8x16xf32>) -> tensor<2x8xf32> - %2 = mhlo.custom_call @Sharding(%arg1) {sdy.sharding = #sdy.sharding_per_value<[<@mesh_1, [{"b"}, {}], replicated={"a"}>]>} : (tensor<16x32xf32>) -> tensor<16x32xf32> - %3 = mhlo.custom_call @SPMDFullToShardShape(%2) : (tensor<16x32xf32>) -> tensor<8x32xf32> + %0 = stablehlo.custom_call @Sharding(%arg0) {sdy.sharding = #sdy.sharding_per_value<[<@mesh_1, [{"a"}, {"b"}]>]>} : (tensor<8x16xf32>) -> tensor<8x16xf32> + %1 = stablehlo.custom_call @SPMDFullToShardShape(%0) : (tensor<8x16xf32>) -> tensor<2x8xf32> + %2 = stablehlo.custom_call @Sharding(%arg1) {sdy.sharding = #sdy.sharding_per_value<[<@mesh_1, [{"b"}, {}], replicated={"a"}>]>} : (tensor<16x32xf32>) -> tensor<16x32xf32> + %3 = stablehlo.custom_call @SPMDFullToShardShape(%2) : (tensor<16x32xf32>) -> tensor<8x32xf32> %4 = call @shmap_body_2(%1, %3) : (tensor<2x8xf32>, tensor<8x32xf32>) -> tensor<2x32xf32> - %5 = mhlo.custom_call @Sharding(%4) : (tensor<2x32xf32>) -> tensor<2x32xf32> - %6 = mhlo.custom_call @SPMDShardToFullShape(%5) {sdy.sharding = #sdy.sharding_per_value<[<@mesh_1, [{"a"}, {}], replicated={"b"}>]>}: (tensor<2x32xf32>) -> tensor<8x32xf32> + %5 = stablehlo.custom_call @Sharding(%4) : (tensor<2x32xf32>) -> tensor<2x32xf32> + %6 = stablehlo.custom_call @SPMDShardToFullShape(%5) {sdy.sharding = #sdy.sharding_per_value<[<@mesh_1, [{"a"}, {}], replicated={"b"}>]>}: (tensor<2x32xf32>) -> tensor<8x32xf32> return %6 : tensor<8x32xf32> } // CHECK-NOT: func.func private @shmap_body_2 func.func private @shmap_body_2(%arg0: tensor<2x8xf32>, %arg1: tensor<8x32xf32>) -> (tensor<2x32xf32>) { - %0 = "mhlo.dot_general"(%arg0, %arg1) <{dot_dimension_numbers = #mhlo.dot<lhs_contracting_dimensions = [1], rhs_contracting_dimensions = [0]>, precision_config = [#mhlo<precision DEFAULT>, #mhlo<precision DEFAULT>]}> : (tensor<2x8xf32>, tensor<8x32xf32>) -> tensor<2x32xf32> - %1 = "mhlo.all_reduce"(%0) ({ + %0 = stablehlo.dot_general %arg0, %arg1, contracting_dims = [1] x [0] : (tensor<2x8xf32>, tensor<8x32xf32>) -> tensor<2x32xf32> + %1 = "stablehlo.all_reduce"(%0) ({ ^bb0(%arg2: tensor<f32>, %arg3: tensor<f32>): - %2 = mhlo.add %arg2, %arg3 : tensor<f32> - mhlo.return %2 : tensor<f32> + %2 = stablehlo.add %arg2, %arg3 : tensor<f32> + stablehlo.return %2 : tensor<f32> - }) {channel_handle = #mhlo.channel_handle, replica_groups = dense<[[0, 1], [2, 3], [4, 5], [6, 7]]> : tensor<4x2xi64>, use_global_device_ids} : (tensor<2x32xf32>) -> tensor<2x32xf32> + }) {channel_handle = #stablehlo.channel_handle, replica_groups = dense<[[0, 1], [2, 3], [4, 5], [6, 7]]> : tensor<4x2xi64>, use_global_device_ids} : (tensor<2x32xf32>) -> tensor<2x32xf32> return %1 : tensor<2x32xf32> } // CHECK-LABEL: func.func public @shard_map_wrong_callee_name func.func public @shard_map_wrong_callee_name(%arg0: tensor<16x32xf32>) -> tensor<16x32xf32> { - %0 = mhlo.custom_call @Sharding(%arg0) {sdy.sharding = #sdy.sharding_per_value<[<@mesh_0, [{"a"}, {}]>]>} : (tensor<16x32xf32>) -> tensor<16x32xf32> - %1 = mhlo.custom_call
@SPMDFullToShardShape(%0) : (tensor<16x32xf32>) -> tensor<4x32xf32> + %0 = stablehlo.custom_call @Sharding(%arg0) {sdy.sharding = #sdy.sharding_per_value<[<@mesh_0, [{"a"}, {}]>]>} : (tensor<16x32xf32>) -> tensor<16x32xf32> + %1 = stablehlo.custom_call @SPMDFullToShardShape(%0) : (tensor<16x32xf32>) -> tensor<4x32xf32> // CHECK: call @shmap_head // CHECK-NOT: sdy.manual_computation %2 = call @shmap_head(%1) : (tensor<4x32xf32>) -> tensor<4x32xf32> - %3 = mhlo.custom_call @Sharding(%2) : (tensor<4x32xf32>) -> tensor<4x32xf32> - %4 = mhlo.custom_call @SPMDShardToFullShape(%3) {sdy.sharding = #sdy.sharding_per_value<[<@mesh_0, [{"a"}, {}]>]>} : (tensor<4x32xf32>) -> tensor<16x32xf32> + %3 = stablehlo.custom_call @Sharding(%2) : (tensor<4x32xf32>) -> tensor<4x32xf32> + %4 = stablehlo.custom_call @SPMDShardToFullShape(%3) {sdy.sharding = #sdy.sharding_per_value<[<@mesh_0, [{"a"}, {}]>]>} : (tensor<4x32xf32>) -> tensor<16x32xf32> return %4 : tensor<16x32xf32> } // CHECK-LABEL: func.func private @shmap_head func.func private @shmap_head(%arg0: tensor<4x32xf32>) -> (tensor<4x32xf32>) { - %0 = mhlo.add %arg0, %arg0 : tensor<4x32xf32> + %0 = stablehlo.add %arg0, %arg0 : tensor<4x32xf32> return %0 : tensor<4x32xf32> } @@ -197,16 +197,16 @@ func.func public @shard_map_multiple_results(%arg0: tensor<16x32xf32>) -> tensor // CHECK-SAME{LITERAL}: in_shardings=[<@mesh_1, [{}, {}], replicated={"a", "b"}>] out_shardings=[<@mesh_1, [{"a", "b"}, {}]>, <@mesh_1, [{"b", "a"}, {}]>] manual_axes={"a", "b"} (%arg1: tensor<16x32xf32>) { // CHECK-NEXT: sdy.return %arg1, %arg1 : tensor<16x32xf32>, tensor<16x32xf32> // CHECK-NEXT: } : (tensor<16x32xf32>) -> (tensor<128x32xf32>, tensor<128x32xf32>) - // CHECK-NEXT: %[[ADD:.*]] = mhlo.add %[[SHARD_MAP]]#0, %[[SHARD_MAP]]#1 : tensor<128x32xf32> + // CHECK-NEXT: %[[ADD:.*]] = stablehlo.add %[[SHARD_MAP]]#0, %[[SHARD_MAP]]#1 : tensor<128x32xf32> // CHECK-NEXT: return %[[ADD]] : tensor<128x32xf32> - %0 = mhlo.custom_call @Sharding(%arg0) {sdy.sharding = #sdy.sharding_per_value<[<@mesh_1, [{}, {}], replicated={"a", "b"}>]>} : (tensor<16x32xf32>) -> tensor<16x32xf32> - %1 = mhlo.custom_call @SPMDFullToShardShape(%0) : (tensor<16x32xf32>) -> tensor<16x32xf32> + %0 = stablehlo.custom_call @Sharding(%arg0) {sdy.sharding = #sdy.sharding_per_value<[<@mesh_1, [{}, {}], replicated={"a", "b"}>]>} : (tensor<16x32xf32>) -> tensor<16x32xf32> + %1 = stablehlo.custom_call @SPMDFullToShardShape(%0) : (tensor<16x32xf32>) -> tensor<16x32xf32> %2:2 = call @shmap_body_4(%1) : (tensor<16x32xf32>) -> (tensor<16x32xf32>, tensor<16x32xf32>) - %3 = mhlo.custom_call @Sharding(%2#0) : (tensor<16x32xf32>) -> tensor<16x32xf32> - %4 = mhlo.custom_call @SPMDShardToFullShape(%3) {sdy.sharding = #sdy.sharding_per_value<[<@mesh_1, [{"a", "b"}, {}]>]>} : (tensor<16x32xf32>) -> tensor<128x32xf32> - %5 = mhlo.custom_call @Sharding(%2#1) : (tensor<16x32xf32>) -> tensor<16x32xf32> - %6 = mhlo.custom_call @SPMDShardToFullShape(%5) {sdy.sharding = #sdy.sharding_per_value<[<@mesh_1, [{"b", "a"}, {}]>]>} : (tensor<16x32xf32>) -> tensor<128x32xf32> - %7 = mhlo.add %4, %6 : tensor<128x32xf32> + %3 = stablehlo.custom_call @Sharding(%2#0) : (tensor<16x32xf32>) -> tensor<16x32xf32> + %4 = stablehlo.custom_call @SPMDShardToFullShape(%3) {sdy.sharding = #sdy.sharding_per_value<[<@mesh_1, [{"a", "b"}, {}]>]>} : (tensor<16x32xf32>) -> tensor<128x32xf32> + %5 = stablehlo.custom_call @Sharding(%2#1) : (tensor<16x32xf32>) -> tensor<16x32xf32> + %6 = stablehlo.custom_call @SPMDShardToFullShape(%5) {sdy.sharding = 
#sdy.sharding_per_value<[<@mesh_1, [{"b", "a"}, {}]>]>} : (tensor<16x32xf32>) -> tensor<128x32xf32> + %7 = stablehlo.add %4, %6 : tensor<128x32xf32> return %7 : tensor<128x32xf32> } // CHECK-NOT: func.func private @shmap_body_4 @@ -218,46 +218,46 @@ func.func private @shmap_body_4(%arg0: tensor<16x32xf32>) -> (tensor<16x32xf32>, func.func public @shard_map_multiple_call_ops(%arg0: tensor<16x32xf32>) -> (tensor<16x32xf32>, tensor<16x32xf32>, tensor<16x32xf32>) { // CHECK-NEXT: %[[SHARD_MAP_0:.*]] = sdy.manual_computation(%arg0) // CHECK-SAME{LITERAL}: in_shardings=[<@mesh_0, [{"a"}, {}]>] out_shardings=[<@mesh_0, [{"a"}, {}]>] manual_axes={"a"} (%arg1: tensor<4x32xf32>) { - // CHECK-NEXT: %[[ADD_0:.*]] = mhlo.add %arg1, %arg1 + // CHECK-NEXT: %[[ADD_0:.*]] = stablehlo.add %arg1, %arg1 // CHECK-NEXT: sdy.return %[[ADD_0]] // CHECK-NEXT: } : (tensor<16x32xf32>) -> tensor<16x32xf32> // CHECK-NEXT: %[[SHARD_MAP_1:.*]] = sdy.manual_computation(%arg0) // CHECK-SAME{LITERAL}: in_shardings=[<@mesh_1, [{}, {"a"}]>] out_shardings=[<@mesh_1, [{}, {"a"}]>] manual_axes={"a"} (%arg1: tensor<16x8xf32>) { - // CHECK-NEXT: %[[MUL:.*]] = mhlo.multiply %arg1, %arg1 + // CHECK-NEXT: %[[MUL:.*]] = stablehlo.multiply %arg1, %arg1 // CHECK-NEXT: sdy.return %[[MUL]] // CHECK-NEXT: } : (tensor<16x32xf32>) -> tensor<16x32xf32> // CHECK-NEXT: %[[SHARD_MAP_2:.*]] = sdy.manual_computation(%arg0) // CHECK-SAME{LITERAL}: in_shardings=[<@mesh_0, [{"a"}, {}]>] out_shardings=[<@mesh_0, [{"a"}, {}]>] manual_axes={"a"} (%arg1: tensor<4x32xf32>) { - // CHECK-NEXT: %[[ADD_1:.*]] = mhlo.add %arg1, %arg1 + // CHECK-NEXT: %[[ADD_1:.*]] = stablehlo.add %arg1, %arg1 // CHECK-NEXT: sdy.return %[[ADD_1]] // CHECK-NEXT: } : (tensor<16x32xf32>) -> tensor<16x32xf32> // CHECK-NEXT: return %[[SHARD_MAP_0]], %[[SHARD_MAP_1]], %[[SHARD_MAP_2]] - %0 = mhlo.custom_call @Sharding(%arg0) {sdy.sharding = #sdy.sharding_per_value<[<@mesh_0, [{"a"}, {}]>]>} : (tensor<16x32xf32>) -> tensor<16x32xf32> - %1 = mhlo.custom_call @SPMDFullToShardShape(%0) : (tensor<16x32xf32>) -> tensor<4x32xf32> + %0 = stablehlo.custom_call @Sharding(%arg0) {sdy.sharding = #sdy.sharding_per_value<[<@mesh_0, [{"a"}, {}]>]>} : (tensor<16x32xf32>) -> tensor<16x32xf32> + %1 = stablehlo.custom_call @SPMDFullToShardShape(%0) : (tensor<16x32xf32>) -> tensor<4x32xf32> %2 = call @shmap_body_5(%1) : (tensor<4x32xf32>) -> tensor<4x32xf32> - %3 = mhlo.custom_call @Sharding(%2) : (tensor<4x32xf32>) -> tensor<4x32xf32> - %4 = mhlo.custom_call @SPMDShardToFullShape(%3) {sdy.sharding = #sdy.sharding_per_value<[<@mesh_0, [{"a"}, {}]>]>} : (tensor<4x32xf32>) -> tensor<16x32xf32> + %3 = stablehlo.custom_call @Sharding(%2) : (tensor<4x32xf32>) -> tensor<4x32xf32> + %4 = stablehlo.custom_call @SPMDShardToFullShape(%3) {sdy.sharding = #sdy.sharding_per_value<[<@mesh_0, [{"a"}, {}]>]>} : (tensor<4x32xf32>) -> tensor<16x32xf32> - %5 = mhlo.custom_call @Sharding(%arg0) {sdy.sharding = #sdy.sharding_per_value<[<@mesh_1, [{}, {"a"}]>]>} : (tensor<16x32xf32>) -> tensor<16x32xf32> - %6 = mhlo.custom_call @SPMDFullToShardShape(%5) : (tensor<16x32xf32>) -> tensor<16x8xf32> + %5 = stablehlo.custom_call @Sharding(%arg0) {sdy.sharding = #sdy.sharding_per_value<[<@mesh_1, [{}, {"a"}]>]>} : (tensor<16x32xf32>) -> tensor<16x32xf32> + %6 = stablehlo.custom_call @SPMDFullToShardShape(%5) : (tensor<16x32xf32>) -> tensor<16x8xf32> %7 = call @shmap_body_6(%6) : (tensor<16x8xf32>) -> tensor<16x8xf32> - %8 = mhlo.custom_call @Sharding(%7) : (tensor<16x8xf32>) -> tensor<16x8xf32> - %9 = mhlo.custom_call 
@SPMDShardToFullShape(%8) {sdy.sharding = #sdy.sharding_per_value<[<@mesh_1, [{}, {"a"}]>]>} : (tensor<16x8xf32>) -> tensor<16x32xf32> + %8 = stablehlo.custom_call @Sharding(%7) : (tensor<16x8xf32>) -> tensor<16x8xf32> + %9 = stablehlo.custom_call @SPMDShardToFullShape(%8) {sdy.sharding = #sdy.sharding_per_value<[<@mesh_1, [{}, {"a"}]>]>} : (tensor<16x8xf32>) -> tensor<16x32xf32> %10 = call @shmap_body_5(%1) : (tensor<4x32xf32>) -> tensor<4x32xf32> - %11 = mhlo.custom_call @Sharding(%10) : (tensor<4x32xf32>) -> tensor<4x32xf32> - %12 = mhlo.custom_call @SPMDShardToFullShape(%11) {sdy.sharding = #sdy.sharding_per_value<[<@mesh_0, [{"a"}, {}]>]>} : (tensor<4x32xf32>) -> tensor<16x32xf32> + %11 = stablehlo.custom_call @Sharding(%10) : (tensor<4x32xf32>) -> tensor<4x32xf32> + %12 = stablehlo.custom_call @SPMDShardToFullShape(%11) {sdy.sharding = #sdy.sharding_per_value<[<@mesh_0, [{"a"}, {}]>]>} : (tensor<4x32xf32>) -> tensor<16x32xf32> return %4, %9, %12 : tensor<16x32xf32>, tensor<16x32xf32>, tensor<16x32xf32> } // CHECK-NOT: func.func private @shmap_body func.func private @shmap_body_5(%arg0: tensor<4x32xf32>) -> (tensor<4x32xf32>) { - %0 = mhlo.add %arg0, %arg0 : tensor<4x32xf32> + %0 = stablehlo.add %arg0, %arg0 : tensor<4x32xf32> return %0 : tensor<4x32xf32> } // CHECK-NOT: func.func private @shmap_body func.func private @shmap_body_6(%arg0: tensor<16x8xf32>) -> (tensor<16x8xf32>) { - %0 = mhlo.multiply %arg0, %arg0 : tensor<16x8xf32> + %0 = stablehlo.multiply %arg0, %arg0 : tensor<16x8xf32> return %0 : tensor<16x8xf32> } @@ -265,42 +265,42 @@ func.func private @shmap_body_6(%arg0: tensor<16x8xf32>) -> (tensor<16x8xf32>) { func.func public @sharding_with_missing_manual_axes(%arg0: tensor<16x16xf32>) -> tensor<32x4xf32> { // CHECK: %0 = sdy.manual_computation(%arg0) // CHECK-SAME{LITERAL}: in_shardings=[<@mesh_2, [{"b"}, {"a"}]>] out_shardings=[<@mesh_2, [{"a"}, {}], replicated={"c"}>] manual_axes={"a", "b", "c"} (%arg1: tensor<8x4xf32>) { - // CHECK-NEXT: %1 = mhlo.add %arg1, %arg1 : tensor<8x4xf32> + // CHECK-NEXT: %1 = stablehlo.add %arg1, %arg1 : tensor<8x4xf32> // CHECK-NEXT: sdy.return %1 : tensor<8x4xf32> // CHECK-NEXT: } : (tensor<16x16xf32>) -> tensor<32x4xf32> // CHECK-NEXT: return %0 : tensor<32x4xf32> - %0 = mhlo.custom_call @Sharding(%arg0) {sdy.sharding = #sdy.sharding_per_value<[<@mesh_2, [{"b"}, {"a"}]>]>} : (tensor<16x16xf32>) -> tensor<16x16xf32> - %1 = mhlo.custom_call @SPMDFullToShardShape(%0) : (tensor<16x16xf32>) -> tensor<8x4xf32> + %0 = stablehlo.custom_call @Sharding(%arg0) {sdy.sharding = #sdy.sharding_per_value<[<@mesh_2, [{"b"}, {"a"}]>]>} : (tensor<16x16xf32>) -> tensor<16x16xf32> + %1 = stablehlo.custom_call @SPMDFullToShardShape(%0) : (tensor<16x16xf32>) -> tensor<8x4xf32> %2 = call @shmap_body_7(%1) : (tensor<8x4xf32>) -> tensor<8x4xf32> - %3 = mhlo.custom_call @Sharding(%2) : (tensor<8x4xf32>) -> tensor<8x4xf32> - %4 = mhlo.custom_call @SPMDShardToFullShape(%3) {sdy.sharding = #sdy.sharding_per_value<[<@mesh_2, [{"a"}, {}], replicated={"c"}>]>} : (tensor<8x4xf32>) -> tensor<32x4xf32> + %3 = stablehlo.custom_call @Sharding(%2) : (tensor<8x4xf32>) -> tensor<8x4xf32> + %4 = stablehlo.custom_call @SPMDShardToFullShape(%3) {sdy.sharding = #sdy.sharding_per_value<[<@mesh_2, [{"a"}, {}], replicated={"c"}>]>} : (tensor<8x4xf32>) -> tensor<32x4xf32> return %4 : tensor<32x4xf32> } // CHECK-NOT: func.func private @shmap_body_5 func.func private @shmap_body_7(%arg0: tensor<8x4xf32>) -> (tensor<8x4xf32>) { - %0 = mhlo.add %arg0, %arg0 : tensor<8x4xf32> + %0 = 
stablehlo.add %arg0, %arg0 : tensor<8x4xf32> return %0 : tensor<8x4xf32> } // CHECK-LABEL: func.func public @shard_map_sharding_custom_call_other_uses func.func public @shard_map_sharding_custom_call_other_uses(%arg0: tensor<16x32xf32>) -> (tensor<16x32xf32>, tensor<16x32xf32>) { - // CHECk-NEXT: %0 = mhlo.custom_call @Sharding(%arg0) {sdy.sharding = #sdy.sharding_per_value<[<@mesh_0, [{"a"}, {}]>]>} + // CHECK-NEXT: %0 = stablehlo.custom_call @Sharding(%arg0) {sdy.sharding = #sdy.sharding_per_value<[<@mesh_0, [{"a"}, {}]>]>} // CHECK: %1 = sdy.manual_computation(%arg0) // CHECK-SAME{LITERAL}: in_shardings=[<@mesh_0, [{"a"}, {}]>] out_shardings=[<@mesh_0, [{"a"}, {}]>] manual_axes={"a"} (%arg1: tensor<4x32xf32>) { - // CHECK-NEXT: %2 = mhlo.add %arg1, %arg1 : tensor<4x32xf32> + // CHECK-NEXT: %2 = stablehlo.add %arg1, %arg1 : tensor<4x32xf32> // CHECK-NEXT: sdy.return %2 : tensor<4x32xf32> // CHECK-NEXT: } : (tensor<16x32xf32>) -> tensor<16x32xf32> // CHECK-NEXT: return %1, %0 - %0 = mhlo.custom_call @Sharding(%arg0) {sdy.sharding = #sdy.sharding_per_value<[<@mesh_0, [{"a"}, {}]>]>} : (tensor<16x32xf32>) -> tensor<16x32xf32> - %1 = mhlo.custom_call @SPMDFullToShardShape(%0) : (tensor<16x32xf32>) -> tensor<4x32xf32> + %0 = stablehlo.custom_call @Sharding(%arg0) {sdy.sharding = #sdy.sharding_per_value<[<@mesh_0, [{"a"}, {}]>]>} : (tensor<16x32xf32>) -> tensor<16x32xf32> + %1 = stablehlo.custom_call @SPMDFullToShardShape(%0) : (tensor<16x32xf32>) -> tensor<4x32xf32> %2 = call @shmap_body_8(%1) : (tensor<4x32xf32>) -> tensor<4x32xf32> - %3 = mhlo.custom_call @Sharding(%2) : (tensor<4x32xf32>) -> tensor<4x32xf32> - %4 = mhlo.custom_call @SPMDShardToFullShape(%3) {sdy.sharding = #sdy.sharding_per_value<[<@mesh_0, [{"a"}, {}]>]>} : (tensor<4x32xf32>) -> tensor<16x32xf32> + %3 = stablehlo.custom_call @Sharding(%2) : (tensor<4x32xf32>) -> tensor<4x32xf32> + %4 = stablehlo.custom_call @SPMDShardToFullShape(%3) {sdy.sharding = #sdy.sharding_per_value<[<@mesh_0, [{"a"}, {}]>]>} : (tensor<4x32xf32>) -> tensor<16x32xf32> return %4, %0 : tensor<16x32xf32>, tensor<16x32xf32> } // CHECK-NOT: func.func private @shmap_body func.func private @shmap_body_8(%arg0: tensor<4x32xf32>) -> (tensor<4x32xf32>) { - %0 = mhlo.add %arg0, %arg0 : tensor<4x32xf32> + %0 = stablehlo.add %arg0, %arg0 : tensor<4x32xf32> return %0 : tensor<4x32xf32> } @@ -308,22 +308,22 @@ func.func private @shmap_body_8(%arg0: tensor<4x32xf32>) -> (tensor<4x32xf32>) { func.func public @shard_map_unused_results(%arg0: tensor<16x32xf32>) -> tensor<128x32xf32> { // CHECK: %[[SHARD_MAP:.*]] = sdy.manual_computation(%arg0) // CHECK-SAME{LITERAL}: in_shardings=[<@mesh_1, [{}, {}], replicated={"a", "b"}>] out_shardings=[<@mesh_1, [{"b", "a"}, {}]>] manual_axes={"a", "b"} (%arg1: tensor<16x32xf32>) { - // CHECK-NEXT: %[[ADD:.*]] = mhlo.add %arg1, %arg1 - // CHECK-NEXT: %[[MUL:.*]] = mhlo.multiply %[[ADD]], %[[ADD]] + // CHECK-NEXT: %[[ADD:.*]] = stablehlo.add %arg1, %arg1 + // CHECK-NEXT: %[[MUL:.*]] = stablehlo.multiply %[[ADD]], %[[ADD]] // CHECK-NEXT: sdy.return %[[ADD]] // CHECK-NEXT: } : (tensor<16x32xf32>) -> tensor<128x32xf32> // CHECK-NEXT: return %[[SHARD_MAP]] - %0 = mhlo.custom_call @Sharding(%arg0) {sdy.sharding = #sdy.sharding_per_value<[<@mesh_1, [{}, {}], replicated={"a",
"b"}>]>} : (tensor<16x32xf32>) -> tensor<16x32xf32> + %1 = stablehlo.custom_call @SPMDFullToShardShape(%0) : (tensor<16x32xf32>) -> tensor<16x32xf32> %2:3 = call @shmap_body_9(%1) : (tensor<16x32xf32>) -> (tensor<16x32xf32>, tensor<16x32xf32>, tensor<16x32xf32>) - %3 = mhlo.custom_call @Sharding(%2#1) : (tensor<16x32xf32>) -> tensor<16x32xf32> - %4 = mhlo.custom_call @SPMDShardToFullShape(%3) {sdy.sharding = #sdy.sharding_per_value<[<@mesh_1, [{"b", "a"}, {}]>]>} : (tensor<16x32xf32>) -> tensor<128x32xf32> + %3 = stablehlo.custom_call @Sharding(%2#1) : (tensor<16x32xf32>) -> tensor<16x32xf32> + %4 = stablehlo.custom_call @SPMDShardToFullShape(%3) {sdy.sharding = #sdy.sharding_per_value<[<@mesh_1, [{"b", "a"}, {}]>]>} : (tensor<16x32xf32>) -> tensor<128x32xf32> return %4 : tensor<128x32xf32> } // CHECK-NOT: func.func private @shmap_body_9 func.func private @shmap_body_9(%arg0: tensor<16x32xf32>) -> (tensor<16x32xf32>, tensor<16x32xf32>, tensor<16x32xf32>) { - %0 = mhlo.add %arg0, %arg0 : tensor<16x32xf32> - %1 = mhlo.multiply %0, %0 : tensor<16x32xf32> + %0 = stablehlo.add %arg0, %arg0 : tensor<16x32xf32> + %1 = stablehlo.multiply %0, %0 : tensor<16x32xf32> return %0, %0, %1 : tensor<16x32xf32>, tensor<16x32xf32>, tensor<16x32xf32> } @@ -331,32 +331,32 @@ func.func private @shmap_body_9(%arg0: tensor<16x32xf32>) -> (tensor<16x32xf32>, func.func public @shard_map_multiple_call_ops_unused_result_in_one(%arg0: tensor<16x32xf32>) -> (tensor<16x32xf32>, tensor<16x32xf32>, tensor<4x128xf32>) { // CHECK-NEXT: %[[SHARD_MAP_0:.*]] = sdy.manual_computation(%arg0) // CHECK-SAME{LITERAL}: in_shardings=[<@mesh_0, [{"a"}, {}]>] out_shardings=[<@mesh_0, [{"a"}, {}]>] manual_axes={"a"} (%arg1: tensor<4x32xf32>) { - // CHECK-NEXT: %[[ADD_0:.*]] = mhlo.add %arg1, %arg1 + // CHECK-NEXT: %[[ADD_0:.*]] = stablehlo.add %arg1, %arg1 // CHECK-NEXT: sdy.return %[[ADD_0]] // CHECK-NEXT: } : (tensor<16x32xf32>) -> tensor<16x32xf32> // CHECK-NEXT: %[[SHARD_MAP_1:.*]]:2 = sdy.manual_computation(%arg0) // CHECK-SAME{LITERAL}: in_shardings=[<@mesh_0, [{"a"}, {}]>] out_shardings=[<@mesh_0, [{"a"}, {}]>, <@mesh_0, [{}, {"a"}]>] manual_axes={"a"} (%arg1: tensor<4x32xf32>) { - // CHECK-NEXT: %[[ADD_1:.*]] = mhlo.add %arg1, %arg1 + // CHECK-NEXT: %[[ADD_1:.*]] = stablehlo.add %arg1, %arg1 // CHECK-NEXT: sdy.return %[[ADD_1]], %[[ADD_1]] // CHECK-NEXT: } : (tensor<16x32xf32>) -> (tensor<16x32xf32>, tensor<4x128xf32>) // CHECK-NEXT: return %[[SHARD_MAP_0]], %[[SHARD_MAP_1]]#0, %[[SHARD_MAP_1]]#1 - %0 = mhlo.custom_call @Sharding(%arg0) {sdy.sharding = #sdy.sharding_per_value<[<@mesh_0, [{"a"}, {}]>]>} : (tensor<16x32xf32>) -> tensor<16x32xf32> - %1 = mhlo.custom_call @SPMDFullToShardShape(%0) : (tensor<16x32xf32>) -> tensor<4x32xf32> + %0 = stablehlo.custom_call @Sharding(%arg0) {sdy.sharding = #sdy.sharding_per_value<[<@mesh_0, [{"a"}, {}]>]>} : (tensor<16x32xf32>) -> tensor<16x32xf32> + %1 = stablehlo.custom_call @SPMDFullToShardShape(%0) : (tensor<16x32xf32>) -> tensor<4x32xf32> %2:2 = call @shmap_body_10(%1) : (tensor<4x32xf32>) -> (tensor<4x32xf32>, tensor<4x32xf32>) - %3 = mhlo.custom_call @Sharding(%2#0) : (tensor<4x32xf32>) -> tensor<4x32xf32> - %4 = mhlo.custom_call @SPMDShardToFullShape(%3) {sdy.sharding = #sdy.sharding_per_value<[<@mesh_0, [{"a"}, {}]>]>} : (tensor<4x32xf32>) -> tensor<16x32xf32> + %3 = stablehlo.custom_call @Sharding(%2#0) : (tensor<4x32xf32>) -> tensor<4x32xf32> + %4 = stablehlo.custom_call @SPMDShardToFullShape(%3) {sdy.sharding = #sdy.sharding_per_value<[<@mesh_0, [{"a"}, {}]>]>} : 
(tensor<4x32xf32>) -> tensor<16x32xf32> %5:2 = call @shmap_body_10(%1) : (tensor<4x32xf32>) -> (tensor<4x32xf32>, tensor<4x32xf32>) - %6 = mhlo.custom_call @Sharding(%5#0) : (tensor<4x32xf32>) -> tensor<4x32xf32> - %7 = mhlo.custom_call @SPMDShardToFullShape(%6) {sdy.sharding = #sdy.sharding_per_value<[<@mesh_0, [{"a"}, {}]>]>} : (tensor<4x32xf32>) -> tensor<16x32xf32> - %8 = mhlo.custom_call @Sharding(%5#1) : (tensor<4x32xf32>) -> tensor<4x32xf32> - %9 = mhlo.custom_call @SPMDShardToFullShape(%8) {sdy.sharding = #sdy.sharding_per_value<[<@mesh_0, [{}, {"a"}]>]>} : (tensor<4x32xf32>) -> tensor<4x128xf32> + %6 = stablehlo.custom_call @Sharding(%5#0) : (tensor<4x32xf32>) -> tensor<4x32xf32> + %7 = stablehlo.custom_call @SPMDShardToFullShape(%6) {sdy.sharding = #sdy.sharding_per_value<[<@mesh_0, [{"a"}, {}]>]>} : (tensor<4x32xf32>) -> tensor<16x32xf32> + %8 = stablehlo.custom_call @Sharding(%5#1) : (tensor<4x32xf32>) -> tensor<4x32xf32> + %9 = stablehlo.custom_call @SPMDShardToFullShape(%8) {sdy.sharding = #sdy.sharding_per_value<[<@mesh_0, [{}, {"a"}]>]>} : (tensor<4x32xf32>) -> tensor<4x128xf32> return %4, %7, %9 : tensor<16x32xf32>, tensor<16x32xf32>, tensor<4x128xf32> } // CHECK-NOT: func.func private @shmap_body func.func private @shmap_body_10(%arg0: tensor<4x32xf32>) -> (tensor<4x32xf32>, tensor<4x32xf32>) { - %0 = mhlo.add %arg0, %arg0 : tensor<4x32xf32> + %0 = stablehlo.add %arg0, %arg0 : tensor<4x32xf32> return %0, %0 : tensor<4x32xf32>, tensor<4x32xf32> } @@ -364,19 +364,19 @@ func.func private @shmap_body_10(%arg0: tensor<4x32xf32>) -> (tensor<4x32xf32>, func.func public @shard_map_duplicate_operand(%arg0: tensor<16x32xf32>) -> tensor<16x32xf32> { // CHECK: %0 = sdy.manual_computation(%arg0, %arg0) // CHECK-SAME{LITERAL}: in_shardings=[<@mesh_0, [{"a"}, {}]>, <@mesh_0, [{"a"}, {}]>] out_shardings=[<@mesh_0, [{"a"}, {}]>] manual_axes={"a"} (%arg1: tensor<4x32xf32>, %arg2: tensor<4x32xf32>) { - // CHECK-NEXT: %1 = mhlo.add %arg1, %arg2 : tensor<4x32xf32> + // CHECK-NEXT: %1 = stablehlo.add %arg1, %arg2 : tensor<4x32xf32> // CHECK-NEXT: sdy.return %1 : tensor<4x32xf32> // CHECK-NEXT: } : (tensor<16x32xf32>, tensor<16x32xf32>) -> tensor<16x32xf32> // CHECK-NEXT: return %0 : tensor<16x32xf32> - %0 = mhlo.custom_call @Sharding(%arg0) {sdy.sharding = #sdy.sharding_per_value<[<@mesh_0, [{"a"}, {}]>]>} : (tensor<16x32xf32>) -> tensor<16x32xf32> - %1 = mhlo.custom_call @SPMDFullToShardShape(%0) : (tensor<16x32xf32>) -> tensor<4x32xf32> + %0 = stablehlo.custom_call @Sharding(%arg0) {sdy.sharding = #sdy.sharding_per_value<[<@mesh_0, [{"a"}, {}]>]>} : (tensor<16x32xf32>) -> tensor<16x32xf32> + %1 = stablehlo.custom_call @SPMDFullToShardShape(%0) : (tensor<16x32xf32>) -> tensor<4x32xf32> %2 = call @shmap_body_11(%1, %1) : (tensor<4x32xf32>, tensor<4x32xf32>) -> tensor<4x32xf32> - %3 = mhlo.custom_call @Sharding(%2) : (tensor<4x32xf32>) -> tensor<4x32xf32> - %4 = mhlo.custom_call @SPMDShardToFullShape(%3) {sdy.sharding = #sdy.sharding_per_value<[<@mesh_0, [{"a"}, {}]>]>} : (tensor<4x32xf32>) -> tensor<16x32xf32> + %3 = stablehlo.custom_call @Sharding(%2) : (tensor<4x32xf32>) -> tensor<4x32xf32> + %4 = stablehlo.custom_call @SPMDShardToFullShape(%3) {sdy.sharding = #sdy.sharding_per_value<[<@mesh_0, [{"a"}, {}]>]>} : (tensor<4x32xf32>) -> tensor<16x32xf32> return %4 : tensor<16x32xf32> } // CHECK-NOT: func.func private @shmap_body func.func private @shmap_body_11(%arg0: tensor<4x32xf32>, %arg1: tensor<4x32xf32>) -> (tensor<4x32xf32>) { - %0 = mhlo.add %arg0, %arg1 : tensor<4x32xf32> + %0 = 
stablehlo.add %arg0, %arg1 : tensor<4x32xf32> return %0 : tensor<4x32xf32> } diff --git a/xla/service/spmd/shardy/test/mhlo_round_trip_shard_map_import_failure.mlir b/xla/service/spmd/shardy/test/mhlo_round_trip_shard_map_import_failure.mlir index 51b1a4e49f7a9e..e41b9b7fe3e0a1 100644 --- a/xla/service/spmd/shardy/test/mhlo_round_trip_shard_map_import_failure.mlir +++ b/xla/service/spmd/shardy/test/mhlo_round_trip_shard_map_import_failure.mlir @@ -4,16 +4,16 @@ sdy.mesh @mesh_1 = <["a"=4, "b"=2]> sdy.mesh @mesh_2 = <["a"=4, "b"=2, "c"=3]> func.func public @multiple_meshes(%arg0: tensor<16x16xf32>) -> tensor<32x4xf32> { - %0 = mhlo.custom_call @Sharding(%arg0) {sdy.sharding = #sdy.sharding_per_value<[<@mesh_1, [{"b"}, {"a"}]>]>} : (tensor<16x16xf32>) -> tensor<16x16xf32> - %1 = mhlo.custom_call @SPMDFullToShardShape(%0) : (tensor<16x16xf32>) -> tensor<8x4xf32> + %0 = stablehlo.custom_call @Sharding(%arg0) {sdy.sharding = #sdy.sharding_per_value<[<@mesh_1, [{"b"}, {"a"}]>]>} : (tensor<16x16xf32>) -> tensor<16x16xf32> + %1 = stablehlo.custom_call @SPMDFullToShardShape(%0) : (tensor<16x16xf32>) -> tensor<8x4xf32> // expected-error @+1 {{Multiple meshes in a single manual computation.}} %2 = call @shmap_body_0(%1) : (tensor<8x4xf32>) -> tensor<8x4xf32> - %3 = mhlo.custom_call @Sharding(%2) : (tensor<8x4xf32>) -> tensor<8x4xf32> - %4 = mhlo.custom_call @SPMDShardToFullShape(%3) {sdy.sharding = #sdy.sharding_per_value<[<@mesh_2, [{"a"}, {}], replicated={"c"}>]>} : (tensor<8x4xf32>) -> tensor<32x4xf32> + %3 = stablehlo.custom_call @Sharding(%2) : (tensor<8x4xf32>) -> tensor<8x4xf32> + %4 = stablehlo.custom_call @SPMDShardToFullShape(%3) {sdy.sharding = #sdy.sharding_per_value<[<@mesh_2, [{"a"}, {}], replicated={"c"}>]>} : (tensor<8x4xf32>) -> tensor<32x4xf32> return %4 : tensor<32x4xf32> } func.func private @shmap_body_0(%arg0: tensor<8x4xf32>) -> (tensor<8x4xf32>) { - %0 = mhlo.add %arg0, %arg0 : tensor<8x4xf32> + %0 = stablehlo.add %arg0, %arg0 : tensor<8x4xf32> return %0 : tensor<8x4xf32> } @@ -24,12 +24,12 @@ sdy.mesh @mesh_0 = <["a"=4]> func.func public @pattern_mismatch(%arg0: tensor<16x32xf32>) -> tensor<16x32xf32> { // expected-error @+1 {{expecting CustomCallOp as operand}} %0 = call @shmap_body_1(%arg0) : (tensor<16x32xf32>) -> tensor<16x32xf32> - %1 = mhlo.custom_call @Sharding(%0) : (tensor<16x32xf32>) -> tensor<16x32xf32> - %2 = mhlo.custom_call @SPMDShardToFullShape(%1) {sdy.sharding = #sdy.sharding_per_value<[<@mesh_0, [{}, {}]>]>} : (tensor<16x32xf32>) -> tensor<16x32xf32> + %1 = stablehlo.custom_call @Sharding(%0) : (tensor<16x32xf32>) -> tensor<16x32xf32> + %2 = stablehlo.custom_call @SPMDShardToFullShape(%1) {sdy.sharding = #sdy.sharding_per_value<[<@mesh_0, [{}, {}]>]>} : (tensor<16x32xf32>) -> tensor<16x32xf32> return %2 : tensor<16x32xf32> } func.func private @shmap_body_1(%arg0: tensor<16x32xf32>) -> (tensor<16x32xf32>) { - %0 = mhlo.add %arg0, %arg0 : tensor<16x32xf32> + %0 = stablehlo.add %arg0, %arg0 : tensor<16x32xf32> return %0 : tensor<16x32xf32> } @@ -38,15 +38,15 @@ func.func private @shmap_body_1(%arg0: tensor<16x32xf32>) -> (tensor<16x32xf32>) sdy.mesh @mesh_0 = <["a"=4]> func.func public @pattern_mismatch(%arg0: tensor<16x32xf32>) -> tensor<16x32xf32> { - %0 = mhlo.custom_call @Sharding(%arg0) {sdy.sharding = #sdy.sharding_per_value<[<@mesh_0, [{}, {}]>]>} : (tensor<16x32xf32>) -> tensor<16x32xf32> + %0 = stablehlo.custom_call @Sharding(%arg0) {sdy.sharding = #sdy.sharding_per_value<[<@mesh_0, [{}, {}]>]>} : (tensor<16x32xf32>) -> tensor<16x32xf32> // 
expected-error @+1 {{expecting SPMDFullToShardShape custom call as operand}} %1 = call @shmap_body_1(%0) : (tensor<16x32xf32>) -> tensor<16x32xf32> - %2 = mhlo.custom_call @Sharding(%1) : (tensor<16x32xf32>) -> tensor<16x32xf32> - %3 = mhlo.custom_call @SPMDShardToFullShape(%2) {sdy.sharding = #sdy.sharding_per_value<[<@mesh_0, [{}, {}]>]>} : (tensor<16x32xf32>) -> tensor<16x32xf32> + %2 = stablehlo.custom_call @Sharding(%1) : (tensor<16x32xf32>) -> tensor<16x32xf32> + %3 = stablehlo.custom_call @SPMDShardToFullShape(%2) {sdy.sharding = #sdy.sharding_per_value<[<@mesh_0, [{}, {}]>]>} : (tensor<16x32xf32>) -> tensor<16x32xf32> return %3 : tensor<16x32xf32> } func.func private @shmap_body_1(%arg0: tensor<16x32xf32>) -> (tensor<16x32xf32>) { - %0 = mhlo.add %arg0, %arg0 : tensor<16x32xf32> + %0 = stablehlo.add %arg0, %arg0 : tensor<16x32xf32> return %0 : tensor<16x32xf32> } @@ -55,15 +55,15 @@ func.func private @shmap_body_1(%arg0: tensor<16x32xf32>) -> (tensor<16x32xf32>) sdy.mesh @mesh_0 = <["a"=4]> func.func public @pattern_mismatch(%arg0: tensor<16x32xf32>) -> tensor<16x32xf32> { - %0 = mhlo.custom_call @SPMDFullToShardShape(%arg0) : (tensor<16x32xf32>) -> tensor<16x32xf32> + %0 = stablehlo.custom_call @SPMDFullToShardShape(%arg0) : (tensor<16x32xf32>) -> tensor<16x32xf32> // expected-error @+1 {{expecting CustomCallOp as operand of SPMDFullToShardShape}} %1 = call @shmap_body(%0) : (tensor<16x32xf32>) -> tensor<16x32xf32> - %2 = mhlo.custom_call @Sharding(%1) : (tensor<16x32xf32>) -> tensor<16x32xf32> - %3 = mhlo.custom_call @SPMDShardToFullShape(%2) {sdy.sharding = #sdy.sharding_per_value<[<@mesh_0, [{}, {}]>]>} : (tensor<16x32xf32>) -> tensor<16x32xf32> + %2 = stablehlo.custom_call @Sharding(%1) : (tensor<16x32xf32>) -> tensor<16x32xf32> + %3 = stablehlo.custom_call @SPMDShardToFullShape(%2) {sdy.sharding = #sdy.sharding_per_value<[<@mesh_0, [{}, {}]>]>} : (tensor<16x32xf32>) -> tensor<16x32xf32> return %3 : tensor<16x32xf32> } func.func private @shmap_body(%arg0: tensor<16x32xf32>) -> (tensor<16x32xf32>) { - %0 = mhlo.add %arg0, %arg0 : tensor<16x32xf32> + %0 = stablehlo.add %arg0, %arg0 : tensor<16x32xf32> return %0 : tensor<16x32xf32> } @@ -72,16 +72,16 @@ func.func private @shmap_body(%arg0: tensor<16x32xf32>) -> (tensor<16x32xf32>) { sdy.mesh @mesh_0 = <["a"=4]> func.func public @pattern_mismatch(%arg0: tensor<16x32xf32>) -> tensor<16x32xf32> { - %0 = mhlo.custom_call @SPMDFullToShardShape(%arg0) {sdy.sharding = #sdy.sharding_per_value<[<@mesh_0, [{}, {}]>]>} : (tensor<16x32xf32>) -> tensor<16x32xf32> - %1 = mhlo.custom_call @SPMDFullToShardShape(%0) : (tensor<16x32xf32>) -> tensor<16x32xf32> + %0 = stablehlo.custom_call @SPMDFullToShardShape(%arg0) {sdy.sharding = #sdy.sharding_per_value<[<@mesh_0, [{}, {}]>]>} : (tensor<16x32xf32>) -> tensor<16x32xf32> + %1 = stablehlo.custom_call @SPMDFullToShardShape(%0) : (tensor<16x32xf32>) -> tensor<16x32xf32> // expected-error @+1 {{expecting Sharding CustomCallOp as operand of SPMDFullToShardShape}} %2 = call @shmap_body(%1) : (tensor<16x32xf32>) -> tensor<16x32xf32> - %3 = mhlo.custom_call @Sharding(%2) : (tensor<16x32xf32>) -> tensor<16x32xf32> - %4 = mhlo.custom_call @SPMDShardToFullShape(%3) {sdy.sharding = #sdy.sharding_per_value<[<@mesh_0, [{}, {}]>]>} : (tensor<16x32xf32>) -> tensor<16x32xf32> + %3 = stablehlo.custom_call @Sharding(%2) : (tensor<16x32xf32>) -> tensor<16x32xf32> + %4 = stablehlo.custom_call @SPMDShardToFullShape(%3) {sdy.sharding = #sdy.sharding_per_value<[<@mesh_0, [{}, {}]>]>} : (tensor<16x32xf32>) -> 
tensor<16x32xf32> return %4 : tensor<16x32xf32> } func.func private @shmap_body(%arg0: tensor<16x32xf32>) -> (tensor<16x32xf32>) { - %0 = mhlo.add %arg0, %arg0 : tensor<16x32xf32> + %0 = stablehlo.add %arg0, %arg0 : tensor<16x32xf32> return %0 : tensor<16x32xf32> } @@ -90,16 +90,16 @@ func.func private @shmap_body(%arg0: tensor<16x32xf32>) -> (tensor<16x32xf32>) { sdy.mesh @mesh_0 = <["a"=4]> func.func public @pattern_mismatch(%arg0: tensor<16x32xf32>) -> tensor<16x32xf32> { - %0 = mhlo.custom_call @Sharding(%arg0) {sdy.sharding = #sdy.sharding_per_value<[<@mesh_0, [{}, {}]>]>} : (tensor<16x32xf32>) -> tensor<16x32xf32> - %1 = mhlo.custom_call @SPMDFullToShardShape(%0) : (tensor<16x32xf32>) -> tensor<16x32xf32> + %0 = stablehlo.custom_call @Sharding(%arg0) {sdy.sharding = #sdy.sharding_per_value<[<@mesh_0, [{}, {}]>]>} : (tensor<16x32xf32>) -> tensor<16x32xf32> + %1 = stablehlo.custom_call @SPMDFullToShardShape(%0) : (tensor<16x32xf32>) -> tensor<16x32xf32> // expected-error @+1 {{expecting each result of shmap_body to have one or no uses}} %2 = call @shmap_body(%1) : (tensor<16x32xf32>) -> tensor<16x32xf32> - %3 = mhlo.custom_call @Sharding(%2) : (tensor<16x32xf32>) -> tensor<16x32xf32> - mhlo.custom_call @SPMDShardToFullShape(%2) {sdy.sharding = #sdy.sharding_per_value<[<@mesh_0, [{}, {}]>]>} : (tensor<16x32xf32>) -> tensor<16x32xf32> + %3 = stablehlo.custom_call @Sharding(%2) : (tensor<16x32xf32>) -> tensor<16x32xf32> + stablehlo.custom_call @SPMDShardToFullShape(%2) {sdy.sharding = #sdy.sharding_per_value<[<@mesh_0, [{}, {}]>]>} : (tensor<16x32xf32>) -> tensor<16x32xf32> return %3 : tensor<16x32xf32> } func.func private @shmap_body(%arg0: tensor<16x32xf32>) -> (tensor<16x32xf32>) { - %0 = mhlo.add %arg0, %arg0 : tensor<16x32xf32> + %0 = stablehlo.add %arg0, %arg0 : tensor<16x32xf32> return %0 : tensor<16x32xf32> } @@ -108,16 +108,16 @@ func.func private @shmap_body(%arg0: tensor<16x32xf32>) -> (tensor<16x32xf32>) { sdy.mesh @mesh_0 = <["a"=4]> func.func public @pattern_mismatch(%arg0: tensor<16x32xf32>) -> (tensor<16x32xf32>, tensor<16x32xf32>) { - %0 = mhlo.custom_call @Sharding(%arg0) {sdy.sharding = #sdy.sharding_per_value<[<@mesh_0, [{}, {}]>]>} : (tensor<16x32xf32>) -> tensor<16x32xf32> - %1 = mhlo.custom_call @SPMDFullToShardShape(%0) : (tensor<16x32xf32>) -> tensor<16x32xf32> + %0 = stablehlo.custom_call @Sharding(%arg0) {sdy.sharding = #sdy.sharding_per_value<[<@mesh_0, [{}, {}]>]>} : (tensor<16x32xf32>) -> tensor<16x32xf32> + %1 = stablehlo.custom_call @SPMDFullToShardShape(%0) : (tensor<16x32xf32>) -> tensor<16x32xf32> // expected-error @+1 {{expecting Sharding CustomCallOp user of the result to have one use}} %2 = call @shmap_body(%1) : (tensor<16x32xf32>) -> tensor<16x32xf32> - %3 = mhlo.custom_call @Sharding(%2) : (tensor<16x32xf32>) -> tensor<16x32xf32> - %4 = mhlo.custom_call @SPMDShardToFullShape(%3) {sdy.sharding = #sdy.sharding_per_value<[<@mesh_0, [{}, {}]>]>} : (tensor<16x32xf32>) -> tensor<16x32xf32> + %3 = stablehlo.custom_call @Sharding(%2) : (tensor<16x32xf32>) -> tensor<16x32xf32> + %4 = stablehlo.custom_call @SPMDShardToFullShape(%3) {sdy.sharding = #sdy.sharding_per_value<[<@mesh_0, [{}, {}]>]>} : (tensor<16x32xf32>) -> tensor<16x32xf32> return %4, %3 : tensor<16x32xf32>, tensor<16x32xf32> } func.func private @shmap_body(%arg0: tensor<16x32xf32>) -> (tensor<16x32xf32>) { - %0 = mhlo.add %arg0, %arg0 : tensor<16x32xf32> + %0 = stablehlo.add %arg0, %arg0 : tensor<16x32xf32> return %0 : tensor<16x32xf32> } @@ -126,14 +126,14 @@ func.func private 
@shmap_body(%arg0: tensor<16x32xf32>) -> (tensor<16x32xf32>) { sdy.mesh @mesh_0 = <["a"=4]> func.func public @pattern_mismatch(%arg0: tensor<16x32xf32>) -> tensor<16x32xf32> { - %0 = mhlo.custom_call @Sharding(%arg0) {sdy.sharding = #sdy.sharding_per_value<[<@mesh_0, [{}, {}]>]>} : (tensor<16x32xf32>) -> tensor<16x32xf32> - %1 = mhlo.custom_call @SPMDFullToShardShape(%0) : (tensor<16x32xf32>) -> tensor<16x32xf32> + %0 = stablehlo.custom_call @Sharding(%arg0) {sdy.sharding = #sdy.sharding_per_value<[<@mesh_0, [{}, {}]>]>} : (tensor<16x32xf32>) -> tensor<16x32xf32> + %1 = stablehlo.custom_call @SPMDFullToShardShape(%0) : (tensor<16x32xf32>) -> tensor<16x32xf32> // expected-error @+1 {{expecting CustomCallOp as the use of the result of the CallOp}} %2 = call @shmap_body(%1) : (tensor<16x32xf32>) -> tensor<16x32xf32> return %2 : tensor<16x32xf32> } func.func private @shmap_body(%arg0: tensor<16x32xf32>) -> (tensor<16x32xf32>) { - %0 = mhlo.add %arg0, %arg0 : tensor<16x32xf32> + %0 = stablehlo.add %arg0, %arg0 : tensor<16x32xf32> return %0 : tensor<16x32xf32> } @@ -142,16 +142,16 @@ func.func private @shmap_body(%arg0: tensor<16x32xf32>) -> (tensor<16x32xf32>) { sdy.mesh @mesh_0 = <["a"=4]> func.func public @pattern_mismatch(%arg0: tensor<16x32xf32>) -> tensor<16x32xf32> { - %0 = mhlo.custom_call @Sharding(%arg0) {sdy.sharding = #sdy.sharding_per_value<[<@mesh_0, [{}, {}]>]>} : (tensor<16x32xf32>) -> tensor<16x32xf32> - %1 = mhlo.custom_call @SPMDFullToShardShape(%0) : (tensor<16x32xf32>) -> tensor<16x32xf32> + %0 = stablehlo.custom_call @Sharding(%arg0) {sdy.sharding = #sdy.sharding_per_value<[<@mesh_0, [{}, {}]>]>} : (tensor<16x32xf32>) -> tensor<16x32xf32> + %1 = stablehlo.custom_call @SPMDFullToShardShape(%0) : (tensor<16x32xf32>) -> tensor<16x32xf32> // expected-error @+1 {{expecting Sharding CustomCallOp as the use of the result of the CallOp}} %2 = call @shmap_body(%1) : (tensor<16x32xf32>) -> tensor<16x32xf32> - %3 = mhlo.custom_call @SPMDShardToFullShape(%2) : (tensor<16x32xf32>) -> tensor<16x32xf32> - %4 = mhlo.custom_call @SPMDShardToFullShape(%3) {sdy.sharding = #sdy.sharding_per_value<[<@mesh_0, [{}, {}]>]>} : (tensor<16x32xf32>) -> tensor<16x32xf32> + %3 = stablehlo.custom_call @SPMDShardToFullShape(%2) : (tensor<16x32xf32>) -> tensor<16x32xf32> + %4 = stablehlo.custom_call @SPMDShardToFullShape(%3) {sdy.sharding = #sdy.sharding_per_value<[<@mesh_0, [{}, {}]>]>} : (tensor<16x32xf32>) -> tensor<16x32xf32> return %4 : tensor<16x32xf32> } func.func private @shmap_body(%arg0: tensor<16x32xf32>) -> (tensor<16x32xf32>) { - %0 = mhlo.add %arg0, %arg0 : tensor<16x32xf32> + %0 = stablehlo.add %arg0, %arg0 : tensor<16x32xf32> return %0 : tensor<16x32xf32> } @@ -160,15 +160,15 @@ func.func private @shmap_body(%arg0: tensor<16x32xf32>) -> (tensor<16x32xf32>) { sdy.mesh @mesh_0 = <["a"=4]> func.func public @pattern_mismatch(%arg0: tensor<16x32xf32>) -> tensor<16x32xf32> { - %0 = mhlo.custom_call @Sharding(%arg0) {sdy.sharding = #sdy.sharding_per_value<[<@mesh_0, [{}, {}]>]>} : (tensor<16x32xf32>) -> tensor<16x32xf32> - %1 = mhlo.custom_call @SPMDFullToShardShape(%0) : (tensor<16x32xf32>) -> tensor<16x32xf32> + %0 = stablehlo.custom_call @Sharding(%arg0) {sdy.sharding = #sdy.sharding_per_value<[<@mesh_0, [{}, {}]>]>} : (tensor<16x32xf32>) -> tensor<16x32xf32> + %1 = stablehlo.custom_call @SPMDFullToShardShape(%0) : (tensor<16x32xf32>) -> tensor<16x32xf32> // expected-error @+1 {{expecting CustomCallOp as the use of Sharding CustomCallOp}} %2 = call @shmap_body(%1) : (tensor<16x32xf32>) -> 
tensor<16x32xf32> - %3 = mhlo.custom_call @Sharding(%2) : (tensor<16x32xf32>) -> tensor<16x32xf32> + %3 = stablehlo.custom_call @Sharding(%2) : (tensor<16x32xf32>) -> tensor<16x32xf32> return %3 : tensor<16x32xf32> } func.func private @shmap_body(%arg0: tensor<16x32xf32>) -> (tensor<16x32xf32>) { - %0 = mhlo.add %arg0, %arg0 : tensor<16x32xf32> + %0 = stablehlo.add %arg0, %arg0 : tensor<16x32xf32> return %0 : tensor<16x32xf32> } @@ -177,15 +177,15 @@ func.func private @shmap_body(%arg0: tensor<16x32xf32>) -> (tensor<16x32xf32>) { sdy.mesh @mesh_0 = <["a"=4]> func.func public @pattern_mismatch(%arg0: tensor<16x32xf32>) -> tensor<16x32xf32> { - %0 = mhlo.custom_call @Sharding(%arg0) {sdy.sharding = #sdy.sharding_per_value<[<@mesh_0, [{}, {}]>]>} : (tensor<16x32xf32>) -> tensor<16x32xf32> - %1 = mhlo.custom_call @SPMDFullToShardShape(%0) : (tensor<16x32xf32>) -> tensor<16x32xf32> + %0 = stablehlo.custom_call @Sharding(%arg0) {sdy.sharding = #sdy.sharding_per_value<[<@mesh_0, [{}, {}]>]>} : (tensor<16x32xf32>) -> tensor<16x32xf32> + %1 = stablehlo.custom_call @SPMDFullToShardShape(%0) : (tensor<16x32xf32>) -> tensor<16x32xf32> // expected-error @+1 {{expecting SPMDShardToFullShape CustomCallOp as the use of Sharding CustomCallOp}} %2 = call @shmap_body(%1) : (tensor<16x32xf32>) -> tensor<16x32xf32> - %3 = mhlo.custom_call @Sharding(%2) : (tensor<16x32xf32>) -> tensor<16x32xf32> - %4 = mhlo.custom_call @Sharding(%3) {sdy.sharding = #sdy.sharding_per_value<[<@mesh_0, [{}, {}]>]>} : (tensor<16x32xf32>) -> tensor<16x32xf32> + %3 = stablehlo.custom_call @Sharding(%2) : (tensor<16x32xf32>) -> tensor<16x32xf32> + %4 = stablehlo.custom_call @Sharding(%3) {sdy.sharding = #sdy.sharding_per_value<[<@mesh_0, [{}, {}]>]>} : (tensor<16x32xf32>) -> tensor<16x32xf32> return %4 : tensor<16x32xf32> } func.func private @shmap_body(%arg0: tensor<16x32xf32>) -> (tensor<16x32xf32>) { - %0 = mhlo.add %arg0, %arg0 : tensor<16x32xf32> + %0 = stablehlo.add %arg0, %arg0 : tensor<16x32xf32> return %0 : tensor<16x32xf32> } diff --git a/xla/service/spmd/shardy/test/open_while_free_vars_sharding.mlir b/xla/service/spmd/shardy/test/open_while_free_vars_sharding.mlir index ca2fd01b7b28d8..fe13f45d4e09a4 100644 --- a/xla/service/spmd/shardy/test/open_while_free_vars_sharding.mlir +++ b/xla/service/spmd/shardy/test/open_while_free_vars_sharding.mlir @@ -9,38 +9,38 @@ func.func @while_with_free_variables( %arg1: tensor<32x96xf32> {sdy.sharding = #sdy.sharding<@mesh1, [{"a"}, {}]>}, %arg2: tensor<32x96xf32>) -> (tensor<32x96xf32>, tensor<32x96xf32>) { - // CHECK-NEXT: %[[C0:.*]] = mhlo.constant dense<0> - // CHECK-NEXT: %[[C1:.*]] = mhlo.constant dense<1> - // CHECK-NEXT: %[[C32:.*]] = mhlo.constant {sdy.sharding = #sdy.sharding_per_value<[<@mesh1, []>]>} dense<32> - // CHECK-NEXT: %[[ADD_0:.*]] = mhlo.add %arg1, %arg1 {sdy.sharding = #sdy.sharding_per_value<[<@mesh2, [{}, {"b"}]>]>} + // CHECK-NEXT: %[[C0:.*]] = stablehlo.constant dense<0> + // CHECK-NEXT: %[[C1:.*]] = stablehlo.constant dense<1> + // CHECK-NEXT: %[[C32:.*]] = stablehlo.constant {sdy.sharding = #sdy.sharding_per_value<[<@mesh1, []>]>} dense<32> + // CHECK-NEXT: %[[ADD_0:.*]] = stablehlo.add %arg1, %arg1 {sdy.sharding = #sdy.sharding_per_value<[<@mesh2, [{}, {"b"}]>]>} // CHECK-NEXT: %[[SC_0:.*]] = sdy.sharding_constraint %arg1 <@mesh1, [{?}, {?}]> // CHECK-NEXT: %[[SC_1:.*]] = sdy.sharding_constraint %[[ADD_0]] <@mesh2, [{?}, {?}]> - // CHECK-NEXT: %[[WHILE:.*]]:2 = mhlo.while(%iterArg = %arg0, %iterArg_0 = %[[C0]]) + // CHECK-NEXT: %[[WHILE:.*]]:2 = 
stablehlo.while(%iterArg = %arg0, %iterArg_2 = %[[C0]]) // CHECK-NEXT: cond { - // CHECK-NEXT: %[[COND:.*]] = mhlo.compare LT, %iterArg_0, %[[C32]] - // CHECK-NEXT: mhlo.return %[[COND]] + // CHECK-NEXT: %[[COND:.*]] = stablehlo.compare LT, %iterArg_2, %[[C32]] + // CHECK-NEXT: stablehlo.return %[[COND]] // CHECK-NEXT: } do { - // CHECK-NEXT: %[[ADD_1:.*]] = mhlo.add %iterArg_0, %[[C1]] - // CHECK-NEXT: %[[ADD_2:.*]] = mhlo.add %iterArg, %[[SC_0]] - // CHECK-NEXT: %[[ADD_3:.*]] = mhlo.add %[[ADD_2]], %arg2 - // CHECK-NEXT: %[[ADD_4:.*]] = mhlo.add %[[ADD_3]], %[[SC_1]] - // CHECK-NEXT: mhlo.return %[[ADD_4]], %[[ADD_1]] + // CHECK-NEXT: %[[ADD_1:.*]] = stablehlo.add %iterArg_2, %[[C1]] + // CHECK-NEXT: %[[ADD_2:.*]] = stablehlo.add %iterArg, %[[SC_0]] + // CHECK-NEXT: %[[ADD_3:.*]] = stablehlo.add %[[ADD_2]], %arg2 + // CHECK-NEXT: %[[ADD_4:.*]] = stablehlo.add %[[ADD_3]], %[[SC_1]] + // CHECK-NEXT: stablehlo.return %[[ADD_4]], %[[ADD_1]] // CHECK-NEXT: } // CHECK-NEXT: return %[[ADD_0]], %[[WHILE]]#0 - %0 = mhlo.constant dense<0> : tensor<i32> - %1 = mhlo.constant dense<1> : tensor<i32> - %2 = mhlo.constant {sdy.sharding = #sdy.sharding_per_value<[<@mesh1, []>]>} dense<32> : tensor<i32> - %3 = mhlo.add %arg1, %arg1 {sdy.sharding = #sdy.sharding_per_value<[<@mesh2, [{}, {"b"}]>]>} : tensor<32x96xf32> - %4:2 = mhlo.while(%iterArg = %arg0, %iterArg_0 = %0) : tensor<32x96xf32>, tensor<i32> + %0 = stablehlo.constant dense<0> : tensor<i32> + %1 = stablehlo.constant dense<1> : tensor<i32> + %2 = stablehlo.constant {sdy.sharding = #sdy.sharding_per_value<[<@mesh1, []>]>} dense<32> : tensor<i32> + %3 = stablehlo.add %arg1, %arg1 {sdy.sharding = #sdy.sharding_per_value<[<@mesh2, [{}, {"b"}]>]>} : tensor<32x96xf32> + %4:2 = stablehlo.while(%iterArg = %arg0, %iterArg_2 = %0) : tensor<32x96xf32>, tensor<i32> cond { - %5 = mhlo.compare LT, %iterArg_0, %2 : (tensor<i32>, tensor<i32>) -> tensor<i1> - mhlo.return %5 : tensor<i1> + %5 = stablehlo.compare LT, %iterArg_2, %2 : (tensor<i32>, tensor<i32>) -> tensor<i1> + stablehlo.return %5 : tensor<i1> } do { - %5 = mhlo.add %iterArg_0, %1 : tensor<i32> - %6 = mhlo.add %iterArg, %arg1 : tensor<32x96xf32> - %7 = mhlo.add %6, %arg2 : tensor<32x96xf32> - %8 = mhlo.add %7, %3 : tensor<32x96xf32> - mhlo.return %8, %5 : tensor<32x96xf32>, tensor<i32> + %5 = stablehlo.add %iterArg_2, %1 : tensor<i32> + %6 = stablehlo.add %iterArg, %arg1 : tensor<32x96xf32> + %7 = stablehlo.add %6, %arg2 : tensor<32x96xf32> + %8 = stablehlo.add %7, %3 : tensor<32x96xf32> + stablehlo.return %8, %5 : tensor<32x96xf32>, tensor<i32> } return %3, %4#0 : tensor<32x96xf32>, tensor<32x96xf32> } @@ -50,44 +50,44 @@ func.func @free_var_used_in_multiple_while_ops( %arg0: tensor<32x96xf32>, %arg1: tensor<32x96xf32> {sdy.sharding = #sdy.sharding<@mesh1, [{"a"}, {}]>}) -> tensor<32x96xf32> { - // CHECK-NEXT: %[[C0:.*]] = mhlo.constant dense<0> - // CHECK-NEXT: %[[C32:.*]] = mhlo.constant {sdy.sharding = #sdy.sharding_per_value<[<@mesh1, []>]>} dense<32> + // CHECK-NEXT: %[[C0:.*]] = stablehlo.constant dense<0> + // CHECK-NEXT: %[[C32:.*]] = stablehlo.constant {sdy.sharding = #sdy.sharding_per_value<[<@mesh1, []>]>} dense<32> // CHECK-NEXT: %[[SC_0:.*]] = sdy.sharding_constraint %arg1 <@mesh1, [{?}, {?}]> - // CHECK-NEXT: %[[WHILE_0:.*]]:2 = mhlo.while(%iterArg = %arg0, %iterArg_0 = %[[C0]]) + // CHECK-NEXT: %[[WHILE_0:.*]]:2 = stablehlo.while(%iterArg = %arg0, %iterArg_1 = %[[C0]]) // CHECK-NEXT: cond { - // CHECK-NEXT: %[[COND:.*]] = mhlo.compare LT, %iterArg_0, %[[C32]] - // CHECK-NEXT: mhlo.return %[[COND]] + // CHECK-NEXT: %[[COND:.*]] = stablehlo.compare LT, %iterArg_1, %[[C32]] + //
CHECK-NEXT: stablehlo.return %[[COND]] // CHECK-NEXT: } do { - // CHECK-NEXT: %[[ADD_0:.*]] = mhlo.add %iterArg, %[[SC_0]] - // CHECK-NEXT: mhlo.return %[[ADD_0]], %iterArg_0 + // CHECK-NEXT: %[[ADD_0:.*]] = stablehlo.add %iterArg, %[[SC_0]] + // CHECK-NEXT: stablehlo.return %[[ADD_0]], %iterArg_1 // CHECK-NEXT: } // CHECK-NEXT: %[[SC_1:.*]] = sdy.sharding_constraint %arg1 <@mesh1, [{?}, {?}]> - // CHECK-NEXT: %[[WHILE_1:.*]]:2 = mhlo.while(%iterArg = %[[WHILE_0]]#0, %iterArg_0 = %[[C0]]) + // CHECK-NEXT: %[[WHILE_1:.*]]:2 = stablehlo.while(%iterArg = %[[WHILE_0]]#0, %iterArg_1 = %[[C0]]) // CHECK-NEXT: cond { - // CHECK-NEXT: %[[COND:.*]] = mhlo.compare LT, %iterArg_0, %[[C32]] - // CHECK-NEXT: mhlo.return %[[COND]] + // CHECK-NEXT: %[[COND:.*]] = stablehlo.compare LT, %iterArg_1, %[[C32]] + // CHECK-NEXT: stablehlo.return %[[COND]] // CHECK-NEXT: } do { - // CHECK-NEXT: %[[ADD_1:.*]] = mhlo.add %iterArg, %[[SC_1]] - // CHECK-NEXT: mhlo.return %[[ADD_1]], %iterArg_0 + // CHECK-NEXT: %[[ADD_1:.*]] = stablehlo.add %iterArg, %[[SC_1]] + // CHECK-NEXT: stablehlo.return %[[ADD_1]], %iterArg_1 // CHECK-NEXT: } // CHECK-NEXT: return %[[WHILE_1]]#0 - %0 = mhlo.constant dense<0> : tensor<i32> - %1 = mhlo.constant {sdy.sharding = #sdy.sharding_per_value<[<@mesh1, []>]>} dense<32> : tensor<i32> - %2:2 = mhlo.while(%iterArg = %arg0, %iterArg_0 = %0) : tensor<32x96xf32>, tensor<i32> + %0 = stablehlo.constant dense<0> : tensor<i32> + %1 = stablehlo.constant {sdy.sharding = #sdy.sharding_per_value<[<@mesh1, []>]>} dense<32> : tensor<i32> + %2:2 = stablehlo.while(%iterArg = %arg0, %iterArg_1 = %0) : tensor<32x96xf32>, tensor<i32> cond { - %4 = mhlo.compare LT, %iterArg_0, %1 : (tensor<i32>, tensor<i32>) -> tensor<i1> - mhlo.return %4 : tensor<i1> + %4 = stablehlo.compare LT, %iterArg_1, %1 : (tensor<i32>, tensor<i32>) -> tensor<i1> + stablehlo.return %4 : tensor<i1> } do { - %4 = mhlo.add %iterArg, %arg1 : tensor<32x96xf32> - mhlo.return %4, %iterArg_0 : tensor<32x96xf32>, tensor<i32> + %4 = stablehlo.add %iterArg, %arg1 : tensor<32x96xf32> + stablehlo.return %4, %iterArg_1 : tensor<32x96xf32>, tensor<i32> } - %3:2 = mhlo.while(%iterArg = %2#0, %iterArg_0 = %0) : tensor<32x96xf32>, tensor<i32> + %3:2 = stablehlo.while(%iterArg = %2#0, %iterArg_1 = %0) : tensor<32x96xf32>, tensor<i32> cond { - %4 = mhlo.compare LT, %iterArg_0, %1 : (tensor<i32>, tensor<i32>) -> tensor<i1> - mhlo.return %4 : tensor<i1> + %4 = stablehlo.compare LT, %iterArg_1, %1 : (tensor<i32>, tensor<i32>) -> tensor<i1> + stablehlo.return %4 : tensor<i1> } do { - %4 = mhlo.add %iterArg, %arg1 : tensor<32x96xf32> - mhlo.return %4, %iterArg_0 : tensor<32x96xf32>, tensor<i32> + %4 = stablehlo.add %iterArg, %arg1 : tensor<32x96xf32> + stablehlo.return %4, %iterArg_1 : tensor<32x96xf32>, tensor<i32> } return %3#0 : tensor<32x96xf32> } diff --git a/xla/service/spmd/shardy/test/round_trip_pipeline.mlir b/xla/service/spmd/shardy/test/round_trip_pipeline.mlir index ae6e1640c50a04..fc08c8fb7472e5 100644 --- a/xla/service/spmd/shardy/test/round_trip_pipeline.mlir +++ b/xla/service/spmd/shardy/test/round_trip_pipeline.mlir @@ -13,8 +13,8 @@ // CHECK-SAME: %arg0: tensor<8x16xf32>) func.func @main( %arg0: tensor<8x16xf32>) -> (tensor<8x16xf32>) { - %0 = mhlo.add %arg0, %arg0 : tensor<8x16xf32> - %1 = mhlo.add %0, %0 : tensor<8x16xf32> + %0 = stablehlo.add %arg0, %arg0 : tensor<8x16xf32> + %1 = stablehlo.add %0, %0 : tensor<8x16xf32> return %1 : tensor<8x16xf32> } @@ -33,8 +33,8 @@ func.func @main( // CHECK: %arg0: tensor<8x16xf32> {sdy.sharding = #sdy.sharding<@mesh, [{"a", ?}, {"b"}p4]>}) %arg0: tensor<8x16xf32> {sdy.sharding = #sdy.sharding<@mesh, [{"a", ?}, {"b"}p4]>} ) ->
(tensor<8x16xf32>) { - %0 = mhlo.add %arg0, %arg0 : tensor<8x16xf32> - %1 = mhlo.add %0, %0 : tensor<8x16xf32> + %0 = stablehlo.add %arg0, %arg0 : tensor<8x16xf32> + %1 = stablehlo.add %0, %0 : tensor<8x16xf32> return %1 : tensor<8x16xf32> } @@ -57,8 +57,8 @@ func.func @main( %arg0: tensor<8x16xf32> // CHECK-SAME: -> (tensor<8x16xf32> {sdy.sharding = #sdy.sharding<@mesh, [{"a", ?}, {"b"}p4]>}) { ) -> (tensor<8x16xf32> {sdy.sharding = #sdy.sharding<@mesh, [{"a", ?}, {"b"}p4]>}) { - // CHECK: mhlo.add %arg0, %arg0 : tensor<8x16xf32> - %0 = mhlo.add %arg0, %arg0 : tensor<8x16xf32> + // CHECK: stablehlo.add %arg0, %arg0 : tensor<8x16xf32> + %0 = stablehlo.add %arg0, %arg0 : tensor<8x16xf32> return %0 : tensor<8x16xf32> } @@ -123,10 +123,10 @@ sdy.mesh @mesh = <["a"=2, "b"=2, "c"=2]> func.func @main( %arg0: tensor<8x8xf32> {sdy.sharding = #sdy.sharding<@mesh, [{"a"}, {"b"}p4]>}, %arg1: tensor<8x8xf32>, %arg2: tensor<8x8xf32>) -> tensor<8x8xf32> { - // CHECK: %[[ADD:.*]] = mhlo.add %arg0, %arg1 : tensor<8x8xf32> + // CHECK: %[[ADD:.*]] = stablehlo.add %arg0, %arg1 : tensor<8x8xf32> // CHECK-NEXT: %[[WSC:.*]] = sdy.sharding_constraint %0 <@mesh, [{}, {"c", ?}p1]> : tensor<8x8xf32> // CHECK-NEXT: return %[[WSC]] : tensor<8x8xf32> - %0 = mhlo.add %arg0, %arg1 : tensor<8x8xf32> + %0 = stablehlo.add %arg0, %arg1 : tensor<8x8xf32> %1 = sdy.sharding_constraint %0 <@mesh, [{}, {"c", ?}p1]> : tensor<8x8xf32> return %1 : tensor<8x8xf32> } @@ -168,10 +168,10 @@ sdy.mesh @mesh_2 = <["x"=8, "y"=4]> func.func @main( // CHECK: %arg0: tensor<8x16xf32>) -> (tensor<8x16xf32> {sdy.sharding = #sdy.sharding<@mesh_2, [{"x", ?}, {"y"}p4]>}, tensor<8x16xf32> {sdy.sharding = #sdy.sharding<@mesh_2, [{?}, {"y"}p4]>}, tensor<8x16xf32> {sdy.sharding = #sdy.sharding<@mesh_2, [{"x"}, {"y"}p1]>}) { %arg0: tensor<8x16xf32>) -> (tensor<8x16xf32> {sdy.sharding = #sdy.sharding<@mesh_2, [{"x", ?}, {"y"}p4]>}, tensor<8x16xf32> {sdy.sharding = #sdy.sharding<@mesh_2, [{?}, {"y"}p4]>}, tensor<8x16xf32> {sdy.sharding = #sdy.sharding<@mesh_2, [{"x"}, {"y"}p1]>}) { - // CHECK-NEXT: %[[ADD:.*]] = mhlo.add %arg0, %arg0 : tensor<8x16xf32> - %0 = mhlo.add %arg0, %arg0 : tensor<8x16xf32> - // CHECK-NEXT: %[[CUSTOM_CALL:.*]]:2 = mhlo.custom_call @sdy_testonly(%arg0) {backend_config = "", xla_shape = "(f32[8,16]{1,0}, f32[8,16]{1,0})"} : (tensor<8x16xf32>) -> (tensor<8x16xf32>, tensor<8x16xf32>) - %1:2 = mhlo.custom_call @sdy_testonly(%arg0) : (tensor<8x16xf32>) -> (tensor<8x16xf32>, tensor<8x16xf32>) + // CHECK-NEXT: %[[ADD:.*]] = stablehlo.add %arg0, %arg0 : tensor<8x16xf32> + %0 = stablehlo.add %arg0, %arg0 : tensor<8x16xf32> + // CHECK-NEXT: %[[CUSTOM_CALL:.*]]:2 = stablehlo.custom_call @sdy_testonly(%arg0) {backend_config = "", xla_shape = "(f32[8,16]{1,0}, f32[8,16]{1,0})"} : (tensor<8x16xf32>) -> (tensor<8x16xf32>, tensor<8x16xf32>) + %1:2 = stablehlo.custom_call @sdy_testonly(%arg0) : (tensor<8x16xf32>) -> (tensor<8x16xf32>, tensor<8x16xf32>) // CHECK-NEXT: return %[[ADD]], %[[CUSTOM_CALL]]#0, %[[CUSTOM_CALL]]#1 return %0, %1#0, %1#1 : tensor<8x16xf32>, tensor<8x16xf32>, tensor<8x16xf32> } @@ -186,33 +186,33 @@ sdy.mesh @mesh = <["x"=2]> // CHECK-LABEL: func @main func.func @main( %arg0: tensor<32x96xf32>, - %arg1: tensor<32x96xf32> {mhlo.frontend_attributes = {xla.sdy.sharding = "#sdy.sharding<@mesh, [{}, {}]>"}}) + %arg1: tensor<32x96xf32> {sdy.sharding = #sdy.sharding<@mesh, [{}, {}]>}) -> tensor<32x96xf32> { // CHECK-NEXT: %[[C0:.*]] = sdy.constant dense<0> // CHECK-NEXT: %[[C32:.*]] = sdy.constant dense<32> // CHECK-NEXT: %[[SC:.*]] 
= sdy.sharding_constraint %arg1 <@mesh, [{?}, {?}]> - // CHECK-NEXT: %[[WHILE:.*]]:2 = mhlo.while(%iterArg = %arg0, %iterArg_0 = %[[C0]]) + // CHECK-NEXT: %[[WHILE:.*]]:2 = stablehlo.while(%iterArg = %arg0, %iterArg_0 = %[[C0]]) // CHECK-NEXT: cond { - // CHECK-NEXT: %[[COND:.*]] = mhlo.compare LT, %iterArg_0, %[[C32]] - // CHECK-NEXT: mhlo.return %[[COND]] + // CHECK-NEXT: %[[COND:.*]] = stablehlo.compare LT, %iterArg_0, %[[C32]] + // CHECK-NEXT: stablehlo.return %[[COND]] // CHECK-NEXT: } do { // CHECK-DAG: %[[C1:.*]] = sdy.constant dense<1> - // CHECK-DAG: %[[ADD_0:.*]] = mhlo.add %iterArg_0, %[[C1]] - // CHECK-DAG: %[[ADD_1:.*]] = mhlo.add %iterArg, %[[SC]] - // CHECK-NEXT: mhlo.return %[[ADD_1]], %[[ADD_0]] + // CHECK-DAG: %[[ADD_0:.*]] = stablehlo.add %iterArg_0, %[[C1]] + // CHECK-DAG: %[[ADD_1:.*]] = stablehlo.add %iterArg, %[[SC]] + // CHECK-NEXT: stablehlo.return %[[ADD_1]], %[[ADD_0]] // CHECK-NEXT: } // CHECK-NEXT: return %[[WHILE]]#0 %0 = sdy.constant dense<0> : tensor<i32> %1 = sdy.constant dense<32> : tensor<i32> - %2:2 = mhlo.while(%iterArg = %arg0, %iterArg_0 = %0) : tensor<32x96xf32>, tensor<i32> + %2:2 = stablehlo.while(%iterArg = %arg0, %iterArg_0 = %0) : tensor<32x96xf32>, tensor<i32> cond { - %3 = mhlo.compare LT, %iterArg_0, %1 : (tensor<i32>, tensor<i32>) -> tensor<i1> - mhlo.return %3 : tensor<i1> + %3 = stablehlo.compare LT, %iterArg_0, %1 : (tensor<i32>, tensor<i32>) -> tensor<i1> + stablehlo.return %3 : tensor<i1> } do { %3 = sdy.constant dense<1> : tensor<i32> - %4 = mhlo.add %iterArg_0, %3 : tensor<i32> - %5 = mhlo.add %iterArg, %arg1 : tensor<32x96xf32> - mhlo.return %5, %4 : tensor<32x96xf32>, tensor<i32> + %4 = stablehlo.add %iterArg_0, %3 : tensor<i32> + %5 = stablehlo.add %iterArg, %arg1 : tensor<32x96xf32> + stablehlo.return %5, %4 : tensor<32x96xf32>, tensor<i32> } return %2#0 : tensor<32x96xf32> } diff --git a/xla/service/spmd/shardy/test/round_trip_pipeline_manual_computation.mlir b/xla/service/spmd/shardy/test/round_trip_pipeline_manual_computation.mlir index 57ac4d8a32b59b..9f9a57fa7a0c3f 100644 --- a/xla/service/spmd/shardy/test/round_trip_pipeline_manual_computation.mlir +++ b/xla/service/spmd/shardy/test/round_trip_pipeline_manual_computation.mlir @@ -16,12 +16,12 @@ func.func @main(%arg0: tensor<16x32xf32>) -> tensor<128x32xf32> { // CHECK-SAME{LITERAL}: in_shardings=[<@mesh_1, [{}, {}], replicated={"a", "b"}>] out_shardings=[<@mesh_1, [{"a", "b"}, {}]>, <@mesh_1, [{"b", "a"}, {}]>] manual_axes={"a", "b"} (%arg1: tensor<16x32xf32>) { // CHECK-NEXT: sdy.return %arg1, %arg1 : tensor<16x32xf32>, tensor<16x32xf32> // CHECK-NEXT: } : (tensor<16x32xf32>) -> (tensor<128x32xf32>, tensor<128x32xf32>) - // CHECK-NEXT: %[[ADD:.*]] = mhlo.add %[[SHARD_MAP]]#0, %[[SHARD_MAP]]#1 : tensor<128x32xf32> + // CHECK-NEXT: %[[ADD:.*]] = stablehlo.add %[[SHARD_MAP]]#0, %[[SHARD_MAP]]#1 : tensor<128x32xf32> // CHECK-NEXT: return %[[ADD]] : tensor<128x32xf32> - %0 = mhlo.custom_call @xla.sdy.GlobalToLocalShape(%arg0) : (tensor<16x32xf32>) -> tensor<16x32xf32> + %0 = stablehlo.custom_call @xla.sdy.GlobalToLocalShape(%arg0) : (tensor<16x32xf32>) -> tensor<16x32xf32> %1:2 = call @xla.sdy.manual_computation_body(%0) {mhlo.frontend_attributes = {xla.sdy.in_shardings = "#sdy.sharding_per_value<[<@mesh_1, [{}, {}], replicated={\\\22a\\\22, \\\22b\\\22}>]>", xla.sdy.manual_axes = "#sdy<manual_axes{\\\22a\\\22, \\\22b\\\22}>", xla.sdy.out_shardings = "#sdy.sharding_per_value<[<@mesh_1, [{\\\22a\\\22, \\\22b\\\22}, {}]>, <@mesh_1, [{\\\22b\\\22, \\\22a\\\22}, {}]>]>"}} : (tensor<16x32xf32>) -> (tensor<16x32xf32>, tensor<16x32xf32>) - %2:2 = mhlo.custom_call @xla.sdy.LocalToGlobalShape(%1#0, %1#1)
: (tensor<16x32xf32>, tensor<16x32xf32>) -> (tensor<128x32xf32>, tensor<128x32xf32>) - %3 = mhlo.add %2#0, %2#1 : tensor<128x32xf32> + %2:2 = stablehlo.custom_call @xla.sdy.LocalToGlobalShape(%1#0, %1#1) : (tensor<16x32xf32>, tensor<16x32xf32>) -> (tensor<128x32xf32>, tensor<128x32xf32>) + %3 = stablehlo.add %2#0, %2#1 : tensor<128x32xf32> return %3 : tensor<128x32xf32> } // CHECK-NOT: func.func private @xla.sdy.manual_computation_body diff --git a/xla/service/spmd/shardy/test/sdy_round_trip_export_inline_round_trip.mlir b/xla/service/spmd/shardy/test/sdy_round_trip_export_inline_round_trip.mlir index d0ed401a2a4299..17b6681d2b5c77 100644 --- a/xla/service/spmd/shardy/test/sdy_round_trip_export_inline_round_trip.mlir +++ b/xla/service/spmd/shardy/test/sdy_round_trip_export_inline_round_trip.mlir @@ -13,19 +13,19 @@ sdy.mesh @mesh = <["a"=2, "b"=2, "c"=2]> // CHECK-SAME: -> (tensor<8x16xf32> {sdy.sharding = #sdy.sharding<@mesh, [{"c"}, {}]>}) func.func @main(%arg0: tensor<8x16xf32> {sdy.sharding = #sdy.sharding<@mesh, [{"a"}, {}]>}) -> (tensor<8x16xf32> {sdy.sharding = #sdy.sharding<@mesh, [{"c"}, {}]>}) { - // CHECK-NEXT: %[[ADD_0:.*]] = mhlo.add %arg0, %arg0 : tensor<8x16xf32> - // CHECK-NEXT: %[[MUL:.*]] = mhlo.multiply %[[ADD_0]], %[[ADD_0]] : tensor<8x16xf32> - // CHECK-NEXT: %[[ADD_1:.*]] = mhlo.add %[[MUL]], %[[MUL]] : tensor<8x16xf32> + // CHECK-NEXT: %[[ADD_0:.*]] = stablehlo.add %arg0, %arg0 : tensor<8x16xf32> + // CHECK-NEXT: %[[MUL:.*]] = stablehlo.multiply %[[ADD_0]], %[[ADD_0]] : tensor<8x16xf32> + // CHECK-NEXT: %[[ADD_1:.*]] = stablehlo.add %[[MUL]], %[[MUL]] : tensor<8x16xf32> // CHECK-NEXT: return %[[ADD_1]] : tensor<8x16xf32> - %0 = mhlo.add %arg0, %arg0 : tensor<8x16xf32> + %0 = stablehlo.add %arg0, %arg0 : tensor<8x16xf32> %1 = func.call @nested_func(%0) : (tensor<8x16xf32>) -> (tensor<8x16xf32>) - %2 = mhlo.add %1, %1 : tensor<8x16xf32> + %2 = stablehlo.add %1, %1 : tensor<8x16xf32> return %2 : tensor<8x16xf32> } // CHECK-NOT: func @nested_func func.func @nested_func(%arg0: tensor<8x16xf32> {sdy.sharding = #sdy.sharding<@mesh, [{}, {"b"}]>}) -> (tensor<8x16xf32> {sdy.sharding = #sdy.sharding<@mesh, [{}, {"b"}]>}) { - %0 = mhlo.multiply %arg0, %arg0 : tensor<8x16xf32> + %0 = stablehlo.multiply %arg0, %arg0 : tensor<8x16xf32> return %0 : tensor<8x16xf32> } diff --git a/xla/service/spmd/shardy/test/sdy_round_trip_export_pipeline.mlir b/xla/service/spmd/shardy/test/sdy_round_trip_export_pipeline.mlir index 2f65739b180d69..5ecaa6091c52f7 100644 --- a/xla/service/spmd/shardy/test/sdy_round_trip_export_pipeline.mlir +++ b/xla/service/spmd/shardy/test/sdy_round_trip_export_pipeline.mlir @@ -20,25 +20,25 @@ sdy.mesh @mesh_2 = <["x"=8, "y"=4]> func.func @multiple_shardings(%arg0: tensor<8x8xf32> {sdy.sharding = #sdy.sharding<@mesh_0, [{"axis_2"}, {"axis_0", "axis_1"}]>}, %arg1: tensor<8x8xf32> {sdy.sharding = #sdy.sharding<@mesh_0, [{}, {"axis_0", "axis_2"}]>}, %arg2: tensor<8x16xf32> {sdy.sharding = #sdy.sharding<@mesh_0, [{}, {"axis_1"}]>}) -> tensor<8x16xf32> { -// CHECK-NEXT: mhlo.add +// CHECK-NEXT: stablehlo.add // CHECK-SAME: {mhlo.frontend_attributes = {xla.sdy.sharding = "#sdy.sharding_per_value<[<@mesh_0, [{\\\22axis_1\\\22, \\\22axis_0\\\22}, {}]>]>"}, mhlo.sharding = - %0 = mhlo.add %arg0, %arg1 {sdy.sharding = #sdy.sharding_per_value<[<@mesh_0, [{"axis_1","axis_0"}, {}]>]>} : tensor<8x8xf32> - %1 = "mhlo.dot" (%0, %arg2) : (tensor<8x8xf32>, tensor<8x16xf32>) -> tensor<8x16xf32> + %0 = stablehlo.add %arg0, %arg1 {sdy.sharding = #sdy.sharding_per_value<[<@mesh_0, 
[{"axis_1","axis_0"}, {}]>]>} : tensor<8x8xf32> + %1 = stablehlo.dot %0, %arg2 : (tensor<8x8xf32>, tensor<8x16xf32>) -> tensor<8x16xf32> return %1 : tensor<8x16xf32> } // CHECK-LABEL: func @multi_result_op func.func @multi_result_op(%arg0: tensor<4x64x8xf32>, %arg1: tensor<4x64x8xf32>) -> (tensor<4x8xf32>, tensor<4x8xf32>) { - %0 = mhlo.constant dense<0.000000e+00> : tensor -// CHECK: mhlo.reduce + %0 = stablehlo.constant dense<0.000000e+00> : tensor +// CHECK: stablehlo.reduce // CHECK-SAME: {mhlo.frontend_attributes = {xla.sdy.sharding = "#sdy.sharding_per_value<[<@mesh_2, [{}, {\\\22y\\\22}]>, <@mesh_2, [{\\\22y\\\22}, {}]>]>"}, mhlo.sharding = - %1:2 = mhlo.reduce(%arg0 init: %0), (%arg1 init: %0) across dimensions = [1] + %1:2 = stablehlo.reduce(%arg0 init: %0), (%arg1 init: %0) across dimensions = [1] {sdy.sharding = #sdy.sharding_per_value<[<@mesh_2, [{}, {"y"}]>, <@mesh_2, [{"y"}, {}]>]>} : (tensor<4x64x8xf32>, tensor<4x64x8xf32>, tensor, tensor) -> (tensor<4x8xf32>, tensor<4x8xf32>) reducer(%arg2: tensor, %arg4: tensor) (%arg3: tensor, %arg5: tensor) { - %2 = mhlo.add %arg2, %arg4 : tensor - %3 = mhlo.add %arg3, %arg5 : tensor - mhlo.return %2, %3 : tensor, tensor + %2 = stablehlo.add %arg2, %arg4 : tensor + %3 = stablehlo.add %arg3, %arg5 : tensor + stablehlo.return %2, %3 : tensor, tensor } return %1#0, %1#1 : tensor<4x8xf32>, tensor<4x8xf32> } @@ -49,9 +49,9 @@ func.func @multi_result_op(%arg0: tensor<4x64x8xf32>, %arg1: tensor<4x64x8xf32>) // CHECK-SAME: -> tensor<8x16xf32> { func.func @split_axes(%arg0: tensor<8x8xf32> {sdy.sharding = #sdy.sharding<@mesh_2, [{"y"}, {"x":(2)2}]>}, %arg1: tensor<8x16xf32> {sdy.sharding = #sdy.sharding<@mesh_2, [{"x":(1)2}, {"x":(2)4}]>}) -> tensor<8x16xf32> { -// CHECK-NEXT: "mhlo.dot" +// CHECK-NEXT: stablehlo.dot // CHECK-SAME: {mhlo.frontend_attributes = {xla.sdy.sharding = "#sdy.sharding_per_value<[<@mesh_2, [{\\\22x\\\22:(1)2, \\\22x\\\22:(4)2}, {}]>]>"}, mhlo.sharding = - %1 = "mhlo.dot" (%arg0, %arg1) {sdy.sharding = #sdy.sharding_per_value<[<@mesh_2, [{"x":(1)2, "x":(4)2}, {}]>]>} : (tensor<8x8xf32>, tensor<8x16xf32>) -> tensor<8x16xf32> + %1 = stablehlo.dot %arg0, %arg1 {sdy.sharding = #sdy.sharding_per_value<[<@mesh_2, [{"x":(1)2, "x":(4)2}, {}]>]>} : (tensor<8x8xf32>, tensor<8x16xf32>) -> tensor<8x16xf32> return %1 : tensor<8x16xf32> } @@ -60,7 +60,7 @@ func.func @func_result_sharding_returning_func_arg( // CHECK: %arg0: tensor<8x16xf32>) -> (tensor<8x16xf32> {mhlo.sharding = %arg0: tensor<8x16xf32> ) -> (tensor<8x16xf32> {sdy.sharding = #sdy.sharding<@mesh_2, [{"x", ?}, {"y"}p4]>}) { - // CHECK: %[[CUSTOM_CALL:.*]] = mhlo.custom_call @xla.sdy.FuncResultSharding(%arg0) {has_side_effect = true, mhlo.frontend_attributes = {xla.sdy.sharding = "#sdy.sharding_per_value<[<@mesh_2, [{\\\22x\\\22, ?}, {\\\22y\\\22}p4]>]>"}} : (tensor<8x16xf32>) -> tensor<8x16xf32> + // CHECK: %[[CUSTOM_CALL:.*]] = stablehlo.custom_call @xla.sdy.FuncResultSharding(%arg0) {has_side_effect = true, mhlo.frontend_attributes = {xla.sdy.sharding = "#sdy.sharding_per_value<[<@mesh_2, [{\\\22x\\\22, ?}, {\\\22y\\\22}p4]>]>"}} : (tensor<8x16xf32>) -> tensor<8x16xf32> // CHECK-NEXT: return %[[CUSTOM_CALL]] : tensor<8x16xf32> return %arg0 : tensor<8x16xf32> } @@ -75,22 +75,22 @@ func.func @func_result_sharding_returning_op_value(%arg0: tensor<8x16xf32>) tensor<8x16xf32> {sdy.sharding = #sdy.sharding<@mesh_2, [{?}, {"y"}p4]>}, tensor<8x16xf32> {sdy.sharding = #sdy.sharding<@mesh_2, [{"x"}, {"y"}p1]>}, tensor<8x16xf32> {sdy.sharding = #sdy.sharding<@mesh_2, [{}, {}]>}) { 
- // CHECK-NEXT: %[[ADD:.*]] = mhlo.add %arg0, %arg0 : tensor<8x16xf32> - // CHECK-NEXT: %[[TEST_ONLY:.*]]:2 = mhlo.custom_call @sdy_testonly(%arg0) {mhlo.frontend_attributes = {xla.sdy.sharding = "#sdy.sharding_per_value<[<@mesh_2, [{\\\22x\\\22, \\\22y\\\22}, {}]>, <@mesh_2, [{\\\22y\\\22, \\\22x\\\22}, {}]>]>"}, mhlo.sharding = - // CHECK-NEXT: %[[ADD_RESULT_SHARDING_0:.*]] = mhlo.custom_call @xla.sdy.FuncResultSharding(%[[ADD]]) {has_side_effect = true, mhlo.frontend_attributes = {xla.sdy.sharding = "#sdy.sharding_per_value<[<@mesh_2, [{\\\22x\\\22, ?}, {\\\22y\\\22}p4]>]>"}} : (tensor<8x16xf32>) -> tensor<8x16xf32> - // CHECK-NEXT: %[[TEST_ONLY_RES_SHARDING_0:.*]] = mhlo.custom_call @xla.sdy.FuncResultSharding(%[[TEST_ONLY]]#0) {has_side_effect = true, mhlo.frontend_attributes = {xla.sdy.sharding = "#sdy.sharding_per_value<[<@mesh_2, [{?}, {\\\22y\\\22}p4]>]>"}} : (tensor<8x16xf32>) -> tensor<8x16xf32> - // CHECK-NEXT: %[[TEST_ONLY_RES_SHARDING_1:.*]] = mhlo.custom_call @xla.sdy.FuncResultSharding(%[[TEST_ONLY]]#1) {has_side_effect = true, mhlo.frontend_attributes = {xla.sdy.sharding = "#sdy.sharding_per_value<[<@mesh_2, [{\\\22x\\\22}, {\\\22y\\\22}p1]>]>"}} : (tensor<8x16xf32>) -> tensor<8x16xf32> - // CHECK-NEXT: %[[ADD_RESULT_SHARDING_1:.*]] = mhlo.custom_call @xla.sdy.FuncResultSharding(%[[ADD]]) {has_side_effect = true, mhlo.frontend_attributes = {xla.sdy.sharding = "#sdy.sharding_per_value<[<@mesh_2, [{}, {}]>]>"}} : (tensor<8x16xf32>) -> tensor<8x16xf32> + // CHECK-NEXT: %[[ADD:.*]] = stablehlo.add %arg0, %arg0 : tensor<8x16xf32> + // CHECK-NEXT: %[[TEST_ONLY:.*]]:2 = stablehlo.custom_call @sdy_testonly(%arg0) {mhlo.frontend_attributes = {xla.sdy.sharding = "#sdy.sharding_per_value<[<@mesh_2, [{\\\22x\\\22, \\\22y\\\22}, {}]>, <@mesh_2, [{\\\22y\\\22, \\\22x\\\22}, {}]>]>"}, mhlo.sharding = + // CHECK-NEXT: %[[ADD_RESULT_SHARDING_0:.*]] = stablehlo.custom_call @xla.sdy.FuncResultSharding(%[[ADD]]) {has_side_effect = true, mhlo.frontend_attributes = {xla.sdy.sharding = "#sdy.sharding_per_value<[<@mesh_2, [{\\\22x\\\22, ?}, {\\\22y\\\22}p4]>]>"}} : (tensor<8x16xf32>) -> tensor<8x16xf32> + // CHECK-NEXT: %[[TEST_ONLY_RES_SHARDING_0:.*]] = stablehlo.custom_call @xla.sdy.FuncResultSharding(%[[TEST_ONLY]]#0) {has_side_effect = true, mhlo.frontend_attributes = {xla.sdy.sharding = "#sdy.sharding_per_value<[<@mesh_2, [{?}, {\\\22y\\\22}p4]>]>"}} : (tensor<8x16xf32>) -> tensor<8x16xf32> + // CHECK-NEXT: %[[TEST_ONLY_RES_SHARDING_1:.*]] = stablehlo.custom_call @xla.sdy.FuncResultSharding(%[[TEST_ONLY]]#1) {has_side_effect = true, mhlo.frontend_attributes = {xla.sdy.sharding = "#sdy.sharding_per_value<[<@mesh_2, [{\\\22x\\\22}, {\\\22y\\\22}p1]>]>"}} : (tensor<8x16xf32>) -> tensor<8x16xf32> + // CHECK-NEXT: %[[ADD_RESULT_SHARDING_1:.*]] = stablehlo.custom_call @xla.sdy.FuncResultSharding(%[[ADD]]) {has_side_effect = true, mhlo.frontend_attributes = {xla.sdy.sharding = "#sdy.sharding_per_value<[<@mesh_2, [{}, {}]>]>"}} : (tensor<8x16xf32>) -> tensor<8x16xf32> // CHECK-NEXT: return %[[ADD_RESULT_SHARDING_0]], %[[TEST_ONLY_RES_SHARDING_0]], %[[TEST_ONLY_RES_SHARDING_1]], %[[ADD_RESULT_SHARDING_1]] - %0 = mhlo.add %arg0, %arg0 : tensor<8x16xf32> - %1:2 = mhlo.custom_call @sdy_testonly(%arg0) {sdy.sharding = #sdy.sharding_per_value<[<@mesh_2, [{"x","y"}, {}]>, <@mesh_2, [{"y","x"}, {}]>]>} : (tensor<8x16xf32>) -> (tensor<8x16xf32>, tensor<8x16xf32>) + %0 = stablehlo.add %arg0, %arg0 : tensor<8x16xf32> + %1:2 = stablehlo.custom_call @sdy_testonly(%arg0) {sdy.sharding = 
#sdy.sharding_per_value<[<@mesh_2, [{"x","y"}, {}]>, <@mesh_2, [{"y","x"}, {}]>]>} : (tensor<8x16xf32>) -> (tensor<8x16xf32>, tensor<8x16xf32>) return %0, %1#0, %1#1, %0 : tensor<8x16xf32>, tensor<8x16xf32>, tensor<8x16xf32>, tensor<8x16xf32> } // CHECK-LABEL: func @sharding_constraint // CHECK-SAME: %arg0: tensor<8x8xf32>) -> tensor<8x8xf32> { func.func @sharding_constraint(%arg0: tensor<8x8xf32>) -> tensor<8x8xf32> { - // CHECK: mhlo.custom_call @Sharding(%arg0) {mhlo.frontend_attributes = {xla.sdy.sharding = "#sdy.sharding_per_value<[<@mesh_2, [{\\\22x\\\22, ?}, {?}]>]>"}, mhlo.sharding = + // CHECK: stablehlo.custom_call @Sharding(%arg0) {mhlo.frontend_attributes = {xla.sdy.sharding = "#sdy.sharding_per_value<[<@mesh_2, [{\\\22x\\\22, ?}, {?}]>]>"}, mhlo.sharding = %0 = sdy.sharding_constraint %arg0 <@mesh_2, [{"x", ?}, {?}]> : tensor<8x8xf32> return %0 : tensor<8x8xf32> } @@ -98,14 +98,14 @@ func.func @sharding_constraint(%arg0: tensor<8x8xf32>) -> tensor<8x8xf32> { // CHECK-LABEL: func @export_sharding_group // CHECK-SAME: %arg0: tensor<8x8xf32>) -> tensor<8x8xf32> { func.func @export_sharding_group(%arg0: tensor<8x8xf32>) -> tensor<8x8xf32> { - // CHECK: mhlo.custom_call @xla.sdy.ShardingGroup(%arg0) {has_side_effect = true, mhlo.frontend_attributes = {xla.sdy.sharding_group_id = "12 : i64"}} + // CHECK: stablehlo.custom_call @xla.sdy.ShardingGroup(%arg0) {has_side_effect = true, mhlo.frontend_attributes = {xla.sdy.sharding_group_id = "12 : i64"}} sdy.sharding_group %arg0 group_id = 12: tensor<8x8xf32> return %arg0 : tensor<8x8xf32> } // CHECK-LABEL: func @constant func.func @constant() -> tensor<i32> { - // CHECK-NEXT: %[[CONST:.*]] = mhlo.constant dense<0> + // CHECK-NEXT: %[[CONST:.*]] = stablehlo.constant dense<0> // CHECK-NEXT: return %[[CONST]] %0 = sdy.constant dense<0> : tensor<i32> return %0 : tensor<i32> } @@ -119,9 +119,9 @@ func.func @constant() -> tensor<i32> { func.func @inlined_mesh( %arg0: tensor<32xi32> {sdy.sharding = #sdy.sharding<mesh<["a"=2, "b"=2]>, [{"a"}]>} ) -> (tensor<32xi32> {sdy.sharding = #sdy.sharding<mesh<[], device_ids=[5]>, [{}]>}) { - // CHECK-NEXT: %[[SHARDING:.*]] = mhlo.custom_call @Sharding(%arg0) + // CHECK-NEXT: %[[SHARDING:.*]] = stablehlo.custom_call @Sharding(%arg0) // CHECK-SAME: mhlo.frontend_attributes = {xla.sdy.sharding = "#sdy.sharding_per_value<[<mesh<[\\\22c\\\22=4]>, [{\\\22c\\\22}]>]>"}, mhlo.sharding = "{devices=[4]<=[4]}"} - // CHECK-NEXT: %[[RESULT_SHARDING:.*]] = mhlo.custom_call @xla.sdy.FuncResultSharding(%[[SHARDING]]) + // CHECK-NEXT: %[[RESULT_SHARDING:.*]] = stablehlo.custom_call @xla.sdy.FuncResultSharding(%[[SHARDING]]) // CHECK-SAME: mhlo.frontend_attributes = {xla.sdy.sharding = "#sdy.sharding_per_value<[<mesh<[], device_ids=[5]>, [{}]>]>"} // CHECK-NEXT: return %[[RESULT_SHARDING]] %0 = sdy.sharding_constraint %arg0 <mesh<["c"=4]>, [{"c"}]> : tensor<32xi32> @@ -160,10 +160,10 @@ func.func @non_sdy_module(%arg0: tensor<8x8xf32> {mhlo.sharding = "{devices=[4,8 %arg1: tensor<8x8xf32> {mhlo.sharding = "{devices=[1,2,16]<=[32] last_tile_dim_replicate}"}, %arg2: tensor<8x16xf32> {mhlo.sharding = "{devices=[4,4,2]<=[2,16]T(1,0) last_tile_dim_replicate}"}) -> (tensor<8x16xf32> {mhlo.sharding = "{devices=[8,4]<=[32]}"}) { - // CHECK-NEXT: mhlo.add %arg0, %arg1 {mhlo.sharding = "{devices=[4,8]<=[8,4]T(1,0)}"} + // CHECK-NEXT: stablehlo.add %arg0, %arg1 {mhlo.sharding = "{devices=[4,8]<=[8,4]T(1,0)}"} // CHECK-NOT: xla.sdy.sharding // CHECK-NOT: xla.sdy.sharding_rule - %0 = mhlo.add %arg0, %arg1 {mhlo.sharding = "{devices=[4,8]<=[8,4]T(1,0)}"} : tensor<8x8xf32> - %1 = "mhlo.dot" (%0, %arg2) : (tensor<8x8xf32>, tensor<8x16xf32>) -> tensor<8x16xf32> + %0
= stablehlo.add %arg0, %arg1 {mhlo.sharding = "{devices=[4,8]<=[8,4]T(1,0)}"} : tensor<8x8xf32> + %1 = stablehlo.dot %0, %arg2 : (tensor<8x8xf32>, tensor<8x16xf32>) -> tensor<8x16xf32> return %1 : tensor<8x16xf32> } diff --git a/xla/service/spmd/shardy/test/sdy_round_trip_import_pipeline.mlir b/xla/service/spmd/shardy/test/sdy_round_trip_import_pipeline.mlir index c077f5a9aa49c8..ad7ef021f4e2cf 100644 --- a/xla/service/spmd/shardy/test/sdy_round_trip_import_pipeline.mlir +++ b/xla/service/spmd/shardy/test/sdy_round_trip_import_pipeline.mlir @@ -1,4 +1,4 @@ -// RUN: sdy_opt %s --split-input-file -xla-sdy-round-trip-import-pipeline 2>&1 | FileCheck %s +// RUN: sdy_opt %s --split-input-file -xla-sdy-import-constants -xla-sdy-round-trip-import-pipeline 2>&1 | FileCheck %s // CHECK-LABEL: module @multiple_func_result_shardings module @multiple_func_result_shardings attributes {mhlo.frontend_attributes = {xla.sdy.meshes = @@ -25,11 +25,11 @@ module @multiple_func_result_shardings attributes {mhlo.frontend_attributes = {x %arg1: tensor<32xi32> {mhlo.frontend_attributes = {xla.sdy.sharding = "#sdy.sharding<@mesh, [{\\\22a\\\22}p1]>"}}, %arg2: tensor<32xi32> {mhlo.frontend_attributes = {xla.sdy.sharding = "#sdy.sharding<@mesh, [{\\\22c\\\22}p0]>"}} ) -> (tensor<32xi32>, tensor<32xi32>, tensor<32xi32>, tensor<32xi32>, tensor<32xi32>, tensor<32xi32>) { - %0 = mhlo.custom_call @xla.sdy.FuncResultSharding(%arg0) {mhlo.frontend_attributes = {xla.sdy.sharding = "#sdy.sharding_per_value<[<@mesh, [{\\\22a\\\22}p0]>]>"}} : (tensor<32xi32>) -> tensor<32xi32> - %1 = mhlo.custom_call @xla.sdy.FuncResultSharding(%arg1) {mhlo.frontend_attributes = {xla.sdy.sharding = "#sdy.sharding_per_value<[<@mesh, [{\\\22b\\\22}p2]>]>"}} : (tensor<32xi32>) -> tensor<32xi32> - %2 = mhlo.custom_call @xla.sdy.FuncResultSharding(%arg0) {mhlo.frontend_attributes = {xla.sdy.sharding = "#sdy.sharding_per_value<[<@mesh, [{\\\22a\\\22}p1]>]>"}} : (tensor<32xi32>) -> tensor<32xi32> - %3 = mhlo.custom_call @xla.sdy.FuncResultSharding(%arg1) {mhlo.frontend_attributes = {xla.sdy.sharding = "#sdy.sharding_per_value<[<@mesh, [{\\\22c\\\22}p0]>]>"}} : (tensor<32xi32>) -> tensor<32xi32> - %4 = mhlo.custom_call @xla.sdy.FuncResultSharding(%arg2) {mhlo.frontend_attributes = {xla.sdy.sharding = "#sdy.sharding_per_value<[<@mesh, [{\\\22a\\\22}p3]>]>"}} : (tensor<32xi32>) -> tensor<32xi32> + %0 = stablehlo.custom_call @xla.sdy.FuncResultSharding(%arg0) {mhlo.frontend_attributes = {xla.sdy.sharding = "#sdy.sharding_per_value<[<@mesh, [{\\\22a\\\22}p0]>]>"}} : (tensor<32xi32>) -> tensor<32xi32> + %1 = stablehlo.custom_call @xla.sdy.FuncResultSharding(%arg1) {mhlo.frontend_attributes = {xla.sdy.sharding = "#sdy.sharding_per_value<[<@mesh, [{\\\22b\\\22}p2]>]>"}} : (tensor<32xi32>) -> tensor<32xi32> + %2 = stablehlo.custom_call @xla.sdy.FuncResultSharding(%arg0) {mhlo.frontend_attributes = {xla.sdy.sharding = "#sdy.sharding_per_value<[<@mesh, [{\\\22a\\\22}p1]>]>"}} : (tensor<32xi32>) -> tensor<32xi32> + %3 = stablehlo.custom_call @xla.sdy.FuncResultSharding(%arg1) {mhlo.frontend_attributes = {xla.sdy.sharding = "#sdy.sharding_per_value<[<@mesh, [{\\\22c\\\22}p0]>]>"}} : (tensor<32xi32>) -> tensor<32xi32> + %4 = stablehlo.custom_call @xla.sdy.FuncResultSharding(%arg2) {mhlo.frontend_attributes = {xla.sdy.sharding = "#sdy.sharding_per_value<[<@mesh, [{\\\22a\\\22}p3]>]>"}} : (tensor<32xi32>) -> tensor<32xi32> return %0, %1, %2, %3, %1, %4 : tensor<32xi32>, tensor<32xi32>, tensor<32xi32>, tensor<32xi32>, tensor<32xi32>, tensor<32xi32> } @@ -39,16 +39,16 
@@ module @multiple_func_result_shardings attributes {mhlo.frontend_attributes = {x // CHECK-SAME: ) -> ( // CHECK-SAME: tensor<32xi32> {sdy.sharding = #sdy.sharding<@mesh, [{"b"}p2]>}, // CHECK-SAME: tensor<32xi32>) { - // CHECK-NEXT: %[[ADD:.*]] = mhlo.add %arg0, %arg1 + // CHECK-NEXT: %[[ADD:.*]] = stablehlo.add %arg0, %arg1 // CHECK-NEXT: return %arg0, %[[ADD]] // CHECK-NEXT: } func.func @func_result_shardings_used_by_other_ops( %arg0: tensor<32xi32>, %arg1: tensor<32xi32> ) -> (tensor<32xi32>, tensor<32xi32>) { - %0 = mhlo.custom_call @xla.sdy.FuncResultSharding(%arg0) {mhlo.frontend_attributes = {xla.sdy.sharding = "#sdy.sharding_per_value<[<@mesh, [{\\\22a\\\22}p0]>]>"}} : (tensor<32xi32>) -> tensor<32xi32> - %1 = mhlo.custom_call @xla.sdy.FuncResultSharding(%0) {mhlo.frontend_attributes = {xla.sdy.sharding = "#sdy.sharding_per_value<[<@mesh, [{\\\22b\\\22}p2]>]>"}} : (tensor<32xi32>) -> tensor<32xi32> - %2 = mhlo.custom_call @xla.sdy.FuncResultSharding(%arg1) {mhlo.frontend_attributes = {xla.sdy.sharding = "#sdy.sharding_per_value<[<@mesh, [{\\\22a\\\22}p3]>]>"}} : (tensor<32xi32>) -> tensor<32xi32> - %3 = mhlo.add %1, %2 : tensor<32xi32> + %0 = stablehlo.custom_call @xla.sdy.FuncResultSharding(%arg0) {mhlo.frontend_attributes = {xla.sdy.sharding = "#sdy.sharding_per_value<[<@mesh, [{\\\22a\\\22}p0]>]>"}} : (tensor<32xi32>) -> tensor<32xi32> + %1 = stablehlo.custom_call @xla.sdy.FuncResultSharding(%0) {mhlo.frontend_attributes = {xla.sdy.sharding = "#sdy.sharding_per_value<[<@mesh, [{\\\22b\\\22}p2]>]>"}} : (tensor<32xi32>) -> tensor<32xi32> + %2 = stablehlo.custom_call @xla.sdy.FuncResultSharding(%arg1) {mhlo.frontend_attributes = {xla.sdy.sharding = "#sdy.sharding_per_value<[<@mesh, [{\\\22a\\\22}p3]>]>"}} : (tensor<32xi32>) -> tensor<32xi32> + %3 = stablehlo.add %1, %2 : tensor<32xi32> return %1, %3 : tensor<32xi32>, tensor<32xi32> } @@ -61,27 +61,27 @@ module @multiple_func_result_shardings attributes {mhlo.frontend_attributes = {x // CHECK-NEXT: %[[C1:.*]] = sdy.constant dense<1> // CHECK-NEXT: %[[C32:.*]] = sdy.constant dense<32> // CHECK-NEXT: %[[SC:.*]] = sdy.sharding_constraint %arg1 <@mesh, [{?}, {?}]> - // CHECK-NEXT: %[[WHILE:.*]]:2 = mhlo.while(%iterArg = %arg0, %iterArg_0 = %[[C0]]) + // CHECK-NEXT: %[[WHILE:.*]]:2 = stablehlo.while(%iterArg = %arg0, %iterArg_0 = %[[C0]]) // CHECK-NEXT: cond { - // CHECK-NEXT: %[[COND:.*]] = mhlo.compare LT, %iterArg_0, %[[C32]] - // CHECK-NEXT: mhlo.return %[[COND]] + // CHECK-NEXT: %[[COND:.*]] = stablehlo.compare LT, %iterArg_0, %[[C32]] + // CHECK-NEXT: stablehlo.return %[[COND]] // CHECK-NEXT: } do { - // CHECK-NEXT: %[[ADD_0:.*]] = mhlo.add %iterArg_0, %[[C1]] - // CHECK-NEXT: %[[ADD_1:.*]] = mhlo.add %iterArg, %[[SC]] - // CHECK-NEXT: mhlo.return %[[ADD_1]], %[[ADD_0]] + // CHECK-NEXT: %[[ADD_0:.*]] = stablehlo.add %iterArg_0, %[[C1]] + // CHECK-NEXT: %[[ADD_1:.*]] = stablehlo.add %iterArg, %[[SC]] + // CHECK-NEXT: stablehlo.return %[[ADD_1]], %[[ADD_0]] // CHECK-NEXT: } // CHECK-NEXT: return %[[WHILE]]#0 %0 = mhlo.constant dense<0> : tensor<i32> %1 = mhlo.constant dense<1> : tensor<i32> %2 = mhlo.constant dense<32> : tensor<i32> - %3:2 = mhlo.while(%iterArg = %arg0, %iterArg_0 = %0) : tensor<32x96xf32>, tensor<i32> + %3:2 = stablehlo.while(%iterArg = %arg0, %iterArg_0 = %0) : tensor<32x96xf32>, tensor<i32> cond { - %4 = mhlo.compare LT, %iterArg_0, %2 : (tensor<i32>, tensor<i32>) -> tensor<i1> - mhlo.return %4 : tensor<i1> + %4 = stablehlo.compare LT, %iterArg_0, %2 : (tensor<i32>, tensor<i32>) -> tensor<i1> + stablehlo.return %4 : tensor<i1> } do { - %4 = mhlo.add %iterArg_0, %1 :
tensor<i32> - %5 = mhlo.add %iterArg, %arg1 : tensor<32x96xf32> - mhlo.return %5, %4 : tensor<32x96xf32>, tensor<i32> + %4 = stablehlo.add %iterArg_0, %1 : tensor<i32> + %5 = stablehlo.add %iterArg, %arg1 : tensor<32x96xf32> + stablehlo.return %5, %4 : tensor<32x96xf32>, tensor<i32> } return %3#0 : tensor<32x96xf32> } @@ -89,29 +89,29 @@ module @multiple_func_result_shardings attributes {mhlo.frontend_attributes = {x // CHECK-LABEL: func @while_with_sinked_constants func.func @while_with_sinked_constants(%arg0: tensor<32x96xf32>) -> tensor<32x96xf32> { // CHECK-NEXT: %[[C0:.*]] = sdy.constant dense<0> - // CHECK-NEXT: %[[WHILE:.*]]:2 = mhlo.while(%iterArg = %arg0, %iterArg_0 = %[[C0]]) + // CHECK-NEXT: %[[WHILE:.*]]:2 = stablehlo.while(%iterArg = %arg0, %iterArg_0 = %[[C0]]) // CHECK-NEXT: cond { // CHECK-NEXT: %[[C32:.*]] = sdy.constant dense<32> - // CHECK-NEXT: %[[COND:.*]] = mhlo.compare LT, %iterArg_0, %[[C32]] - // CHECK-NEXT: mhlo.return %[[COND]] + // CHECK-NEXT: %[[COND:.*]] = stablehlo.compare LT, %iterArg_0, %[[C32]] + // CHECK-NEXT: stablehlo.return %[[COND]] // CHECK-NEXT: } do { // CHECK-NEXT: %[[C1:.*]] = sdy.constant dense<1> - // CHECK-NEXT: %[[ADD_0:.*]] = mhlo.add %iterArg_0, %[[C1]] - // CHECK-NEXT: %[[ADD_1:.*]] = mhlo.add %iterArg, %iterArg - // CHECK-NEXT: mhlo.return %[[ADD_1]], %[[ADD_0]] + // CHECK-NEXT: %[[ADD_0:.*]] = stablehlo.add %iterArg_0, %[[C1]] + // CHECK-NEXT: %[[ADD_1:.*]] = stablehlo.add %iterArg, %iterArg + // CHECK-NEXT: stablehlo.return %[[ADD_1]], %[[ADD_0]] // CHECK-NEXT: } // CHECK-NEXT: return %[[WHILE]]#0 %0 = mhlo.constant dense<0> : tensor<i32> - %1:2 = mhlo.while(%iterArg = %arg0, %iterArg_0 = %0) : tensor<32x96xf32>, tensor<i32> + %1:2 = stablehlo.while(%iterArg = %arg0, %iterArg_0 = %0) : tensor<32x96xf32>, tensor<i32> cond { %2 = mhlo.constant dense<32> : tensor<i32> - %3 = mhlo.compare LT, %iterArg_0, %2 : (tensor<i32>, tensor<i32>) -> tensor<i1> - mhlo.return %3 : tensor<i1> + %3 = stablehlo.compare LT, %iterArg_0, %2 : (tensor<i32>, tensor<i32>) -> tensor<i1> + stablehlo.return %3 : tensor<i1> } do { %2 = mhlo.constant dense<1> : tensor<i32> - %3 = mhlo.add %iterArg_0, %2 : tensor<i32> - %4 = mhlo.add %iterArg, %iterArg : tensor<32x96xf32> - mhlo.return %4, %3 : tensor<32x96xf32>, tensor<i32> + %3 = stablehlo.add %iterArg_0, %2 : tensor<i32> + %4 = stablehlo.add %iterArg, %iterArg : tensor<32x96xf32> + stablehlo.return %4, %3 : tensor<32x96xf32>, tensor<i32> } return %1#0 : tensor<32x96xf32> } @@ -122,14 +122,14 @@ module @multiple_func_result_shardings attributes {mhlo.frontend_attributes = {x func.func @discard_shardings_on_unknown_ops( %arg0: tensor<32xi32> {mhlo.frontend_attributes = {xla.sdy.sharding = "#sdy.sharding<@mesh, [{\\\22a\\\22}p0]>"}} ) -> tensor<32xi32> { - // CHECK-NEXT: %[[ADD:.*]] = mhlo.add %arg0, %arg0 : tensor<32xi32> + // CHECK-NEXT: %[[ADD:.*]] = stablehlo.add %arg0, %arg0 : tensor<32xi32> // CHECK-NEXT: %[[SHARDING:.*]] = sdy.sharding_constraint %[[ADD]] <@mesh, [{"a"}p2]> : tensor<32xi32> - // CHECK-NEXT: %[[UNKNOWN:.*]] = mhlo.custom_call @UnknownCustomCall(%[[SHARDING]]) : (tensor<32xi32>) -> tensor<32xi32> + // CHECK-NEXT: %[[UNKNOWN:.*]] = stablehlo.custom_call @UnknownCustomCall(%[[SHARDING]]) : (tensor<32xi32>) -> tensor<32xi32> // CHECK-NEXT: return %[[UNKNOWN]] - %0 = mhlo.add %arg0, %arg0 {mhlo.frontend_attributes = {xla.sdy.sharding = "#sdy.sharding_per_value<[<@mesh, [{\\\22a\\\22}p1]>]>"}} : tensor<32xi32> - %1 = mhlo.custom_call @Sharding(%0) {mhlo.frontend_attributes = {xla.sdy.sharding = "#sdy.sharding_per_value<[<@mesh, [{\\\22a\\\22}p2]>]>"}} : (tensor<32xi32>) -> tensor<32xi32> - %2 =
mhlo.custom_call @UnknownCustomCall(%1) {mhlo.frontend_attributes = {xla.sdy.sharding = "#sdy.sharding_per_value<[<@mesh, [{\\\22a\\\22}p3]>]>"}} : (tensor<32xi32>) -> tensor<32xi32> - %3 = mhlo.custom_call @xla.sdy.FuncResultSharding(%2) {mhlo.frontend_attributes = {xla.sdy.sharding = "#sdy.sharding_per_value<[<@mesh, [{\\\22a\\\22}p4]>]>"}} : (tensor<32xi32>) -> tensor<32xi32> + %0 = stablehlo.add %arg0, %arg0 {mhlo.frontend_attributes = {xla.sdy.sharding = "#sdy.sharding_per_value<[<@mesh, [{\\\22a\\\22}p1]>]>"}} : tensor<32xi32> + %1 = stablehlo.custom_call @Sharding(%0) {mhlo.frontend_attributes = {xla.sdy.sharding = "#sdy.sharding_per_value<[<@mesh, [{\\\22a\\\22}p2]>]>"}} : (tensor<32xi32>) -> tensor<32xi32> + %2 = stablehlo.custom_call @UnknownCustomCall(%1) {mhlo.frontend_attributes = {xla.sdy.sharding = "#sdy.sharding_per_value<[<@mesh, [{\\\22a\\\22}p3]>]>"}} : (tensor<32xi32>) -> tensor<32xi32> + %3 = stablehlo.custom_call @xla.sdy.FuncResultSharding(%2) {mhlo.frontend_attributes = {xla.sdy.sharding = "#sdy.sharding_per_value<[<@mesh, [{\\\22a\\\22}p4]>]>"}} : (tensor<32xi32>) -> tensor<32xi32> return %3 : tensor<32xi32> } @@ -141,8 +141,8 @@ module @multiple_func_result_shardings attributes {mhlo.frontend_attributes = {x ) -> tensor<32xi32> { // CHECK-NEXT: %[[SHARDING:.*]] = sdy.sharding_constraint %arg0 <mesh<["c"=4]>, [{"c"}]> : tensor<32xi32> // CHECK-NEXT: return %[[SHARDING]] - %0 = mhlo.custom_call @Sharding(%arg0) {mhlo.frontend_attributes = {xla.sdy.sharding = "#sdy.sharding_per_value<[<mesh<[\\\22c\\\22=4]>, [{\\\22c\\\22}]>]>"}} : (tensor<32xi32>) -> tensor<32xi32> - %1 = mhlo.custom_call @xla.sdy.FuncResultSharding(%0) {mhlo.frontend_attributes = {xla.sdy.sharding = "#sdy.sharding_per_value<[<mesh<[], device_ids=[5]>, [{}]>]>"}} : (tensor<32xi32>) -> tensor<32xi32> + %0 = stablehlo.custom_call @Sharding(%arg0) {mhlo.frontend_attributes = {xla.sdy.sharding = "#sdy.sharding_per_value<[<mesh<[\\\22c\\\22=4]>, [{\\\22c\\\22}]>]>"}} : (tensor<32xi32>) -> tensor<32xi32> + %1 = stablehlo.custom_call @xla.sdy.FuncResultSharding(%0) {mhlo.frontend_attributes = {xla.sdy.sharding = "#sdy.sharding_per_value<[<mesh<[], device_ids=[5]>, [{}]>]>"}} : (tensor<32xi32>) -> tensor<32xi32> return %1 : tensor<32xi32> } @@ -159,16 +159,16 @@ module @multiple_func_result_shardings attributes {mhlo.frontend_attributes = {x %arg2: tensor<32xi32> {mhlo.frontend_attributes = {xla.sdy.sharding = "#sdy.sharding<@mesh2, [{\\\22c\\\22, \\\22b\\\22, ?}p0]>"}} ) -> (tensor<32xi32>, tensor<32xi32>) { // CHECK-NEXT: %[[SC1:.*]] = sdy.sharding_constraint %arg0 <@mesh2, [{"b", ?}]> - // CHECK-NEXT: %[[ADD:.*]] = mhlo.add %[[SC1]], %[[SC1]] + // CHECK-NEXT: %[[ADD:.*]] = stablehlo.add %[[SC1]], %[[SC1]] // CHECK-NOT: sdy.sharding // CHECK-NEXT: %[[SC2:.*]] = sdy.sharding_constraint %arg1 <@mesh2, [{}]> // CHECK-NEXT: return %[[ADD]], %[[SC2]] // CHECK-NEXT: } - %0 = mhlo.custom_call @Sharding(%arg0) {mhlo.frontend_attributes = {xla.sdy.sharding = "#sdy.sharding_per_value<[<@mesh2, [{\\\22a\\\22, \\\22b\\\22, ?}]>]>"}} : (tensor<32xi32>) -> tensor<32xi32> - %1 = mhlo.add %0, %0 : tensor<32xi32> - %2 = mhlo.custom_call @Sharding(%arg1) {mhlo.frontend_attributes = {xla.sdy.sharding = "#sdy.sharding_per_value<[<@mesh2, [{\\\22c\\\22, \\\22a\\\22}]>]>"}} : (tensor<32xi32>) -> tensor<32xi32> - %3 = mhlo.custom_call @xla.sdy.FuncResultSharding(%1) {mhlo.frontend_attributes = {xla.sdy.sharding = "#sdy.sharding_per_value<[<@mesh2, [{\\\22a\\\22}]>]>"}} : (tensor<32xi32>) -> tensor<32xi32> - %4 = mhlo.custom_call @xla.sdy.FuncResultSharding(%2) {mhlo.frontend_attributes = {xla.sdy.sharding =
"#sdy.sharding_per_value<[<@mesh2, [{\\\22b\\\22}]>]>"}} : (tensor<32xi32>) -> tensor<32xi32> + %0 = stablehlo.custom_call @Sharding(%arg0) {mhlo.frontend_attributes = {xla.sdy.sharding = "#sdy.sharding_per_value<[<@mesh2, [{\\\22a\\\22, \\\22b\\\22, ?}]>]>"}} : (tensor<32xi32>) -> tensor<32xi32> + %1 = stablehlo.add %0, %0 : tensor<32xi32> + %2 = stablehlo.custom_call @Sharding(%arg1) {mhlo.frontend_attributes = {xla.sdy.sharding = "#sdy.sharding_per_value<[<@mesh2, [{\\\22c\\\22, \\\22a\\\22}]>]>"}} : (tensor<32xi32>) -> tensor<32xi32> + %3 = stablehlo.custom_call @xla.sdy.FuncResultSharding(%1) {mhlo.frontend_attributes = {xla.sdy.sharding = "#sdy.sharding_per_value<[<@mesh2, [{\\\22a\\\22}]>]>"}} : (tensor<32xi32>) -> tensor<32xi32> + %4 = stablehlo.custom_call @xla.sdy.FuncResultSharding(%2) {mhlo.frontend_attributes = {xla.sdy.sharding = "#sdy.sharding_per_value<[<@mesh2, [{\\\22b\\\22}]>]>"}} : (tensor<32xi32>) -> tensor<32xi32> return %3, %4 : tensor<32xi32>, tensor<32xi32> } @@ -180,19 +180,19 @@ module @multiple_func_result_shardings attributes {mhlo.frontend_attributes = {x // CHECK-SAME{LITERAL}: out_shardings=[<@mesh2, [{}, {"b"}]>] // CHECK-SAME{LITERAL}: manual_axes={"b"} // CHECK-SAME: (%arg2: tensor<16x8xf32>, %arg3: tensor<16x8xf32>) { - // CHECK-NEXT: %[[ADD:.*]] = mhlo.add %arg2, %arg3 + // CHECK-NEXT: %[[ADD:.*]] = stablehlo.add %arg2, %arg3 // CHECK-NEXT: sdy.return %[[ADD]] // CHECK-NEXT: } : (tensor<16x32xf32>, tensor<16x32xf32>) -> tensor<16x32xf32> // CHECK-NEXT: return %[[MAN_COMP]] - %0:2 = mhlo.custom_call @xla.sdy.GlobalToLocalShape(%arg0, %arg1) : (tensor<16x32xf32>, tensor<16x32xf32>) -> (tensor<16x8xf32>, tensor<16x8xf32>) + %0:2 = stablehlo.custom_call @xla.sdy.GlobalToLocalShape(%arg0, %arg1) : (tensor<16x32xf32>, tensor<16x32xf32>) -> (tensor<16x8xf32>, tensor<16x8xf32>) %1 = call @xla.sdy.manual_computation_body(%0#0, %0#1) {mhlo.frontend_attributes = {xla.sdy.in_shardings = "#sdy.sharding_per_value<[<@mesh2, [{\\\22a\\\22}, {\\\22b\\\22}]>, <@mesh2, [{}, {\\\22b\\\22}], replicated={\\\22a\\\22}>]>", xla.sdy.manual_axes = "#sdy", xla.sdy.out_shardings = "#sdy.sharding_per_value<[<@mesh2, [{}, {\\\22b\\\22, \\\22a\\\22}]>]>"}} : (tensor<16x8xf32>, tensor<16x8xf32>) -> tensor<16x8xf32> - %2 = mhlo.custom_call @xla.sdy.LocalToGlobalShape(%1) : (tensor<16x8xf32>) -> tensor<16x32xf32> + %2 = stablehlo.custom_call @xla.sdy.LocalToGlobalShape(%1) : (tensor<16x8xf32>) -> tensor<16x32xf32> return %2 : tensor<16x32xf32> } // CHECK-NOT: func @xla.sdy.manual_computation_body( func.func @xla.sdy.manual_computation_body(%arg0: tensor<16x8xf32>, %arg1: tensor<16x8xf32>) -> tensor<16x8xf32> { - %0 = mhlo.add %arg0, %arg1 : tensor<16x8xf32> + %0 = stablehlo.add %arg0, %arg1 : tensor<16x8xf32> return %0 : tensor<16x8xf32> } } @@ -238,16 +238,6 @@ module @no_meshes_attr_module { // CHECK-SAME: %arg0: tensor<8x8xf32>) -> tensor<8x8xf32> { func.func @import_sharding_group(%arg0: tensor<8x8xf32>) -> tensor<8x8xf32> { // CHECK sdy.sharding_group %arg0 group_id = 21: tensor<8x8xf32> - mhlo.custom_call @xla.sdy.ShardingGroup(%arg0) {has_side_effect = true, mhlo.frontend_attributes = {xla.sdy.sharding_group_id = "21 : i64"}} : (tensor<8x8xf32>) -> () - return %arg0 : tensor<8x8xf32> -} - -// ----- - -// CHECK-LABEL: func @import_sharding_group_with_unused_result -// CHECK-SAME: %arg0: tensor<8x8xf32>) -> tensor<8x8xf32> { -func.func @import_sharding_group_with_unused_result(%arg0: tensor<8x8xf32>) -> tensor<8x8xf32> { - // CHECK sdy.sharding_group %arg0 group_id = 21: 
-  %0 = mhlo.custom_call @xla.sdy.ShardingGroup(%arg0) {has_side_effect = true, mhlo.frontend_attributes = {xla.sdy.sharding_group_id = "21 : i64"}} : (tensor<8x8xf32>) -> tuple<>
+  stablehlo.custom_call @xla.sdy.ShardingGroup(%arg0) {has_side_effect = true, mhlo.frontend_attributes = {xla.sdy.sharding_group_id = "21 : i64"}} : (tensor<8x8xf32>) -> ()
   return %arg0 : tensor<8x8xf32>
 }
diff --git a/xla/service/spmd/shardy/test/sdy_round_trip_shard_map_import.mlir b/xla/service/spmd/shardy/test/sdy_round_trip_shard_map_import.mlir
index 38ea2ee887d1af..4c6f37372d2273 100644
--- a/xla/service/spmd/shardy/test/sdy_round_trip_shard_map_import.mlir
+++ b/xla/service/spmd/shardy/test/sdy_round_trip_shard_map_import.mlir
@@ -20,9 +20,9 @@ func.func @single_manual_comp(%arg0: tensor<8x16xf32>, %arg1: tensor<16x32xf32>)
   // CHECK-NEXT:   sdy.return %[[REDUCE]] : tensor<2x32xf32>
   // CHECK-NEXT: } : (tensor<8x16xf32>, tensor<16x32xf32>) -> tensor<8x32xf32>
   // CHECK-NEXT: return %[[MAN_COMP]] : tensor<8x32xf32>
-  %0:2 = mhlo.custom_call @xla.sdy.GlobalToLocalShape(%arg0, %arg1) : (tensor<8x16xf32>, tensor<16x32xf32>) -> (tensor<2x8xf32>, tensor<8x32xf32>)
+  %0:2 = stablehlo.custom_call @xla.sdy.GlobalToLocalShape(%arg0, %arg1) : (tensor<8x16xf32>, tensor<16x32xf32>) -> (tensor<2x8xf32>, tensor<8x32xf32>)
   %1 = call @xla.sdy.manual_computation_body(%0#0, %0#1) {mhlo.frontend_attributes = {xla.sdy.in_shardings = "#sdy.sharding_per_value<[<@mesh_0, [{\\\22a\\\22}, {\\\22b\\\22}]>, <@mesh_0, [{\\\22b\\\22}, {}], replicated={\\\22a\\\22}>]>", xla.sdy.manual_axes = "#sdy<manual_axes{\\\22a\\\22, \\\22b\\\22}>", xla.sdy.out_shardings = "#sdy.sharding_per_value<[<@mesh_0, [{\\\22a\\\22}, {}], replicated={\\\22b\\\22}>]>"}} : (tensor<2x8xf32>, tensor<8x32xf32>) -> tensor<2x32xf32>
-  %2 = mhlo.custom_call @xla.sdy.LocalToGlobalShape(%1) : (tensor<2x32xf32>) -> tensor<8x32xf32>
+  %2 = stablehlo.custom_call @xla.sdy.LocalToGlobalShape(%1) : (tensor<2x32xf32>) -> tensor<8x32xf32>
   return %2 : tensor<8x32xf32>
 }
@@ -36,9 +36,9 @@ func.func @single_manual_comp_name_is_not_prefix_nor_suffix(%arg0: tensor<8x8xf3
   // CHECK-NEXT:   sdy.return %arg1 : tensor<2x8xf32>
   // CHECK-NEXT: } : (tensor<8x8xf32>) -> tensor<8x8xf32>
   // CHECK-NEXT: return %[[MAN_COMP]] : tensor<8x8xf32>
-  %0 = mhlo.custom_call @xla.sdy.GlobalToLocalShape(%arg0) : (tensor<8x8xf32>) -> tensor<2x8xf32>
+  %0 = stablehlo.custom_call @xla.sdy.GlobalToLocalShape(%arg0) : (tensor<8x8xf32>) -> tensor<2x8xf32>
   %1 = call @my_model.___call__.fwd.xla.sdy.manual_computation_body_14.1234(%0) {mhlo.frontend_attributes = {xla.sdy.in_shardings = "#sdy.sharding_per_value<[<@mesh_0, [{\\\22a\\\22}, {}]>]>", xla.sdy.manual_axes = "#sdy<manual_axes{\\\22a\\\22}>", xla.sdy.out_shardings = "#sdy.sharding_per_value<[<@mesh_0, [{\\\22a\\\22}, {}]>]>"}} : (tensor<2x8xf32>) -> tensor<2x8xf32>
-  %2 = mhlo.custom_call @xla.sdy.LocalToGlobalShape(%1) : (tensor<2x8xf32>) -> tensor<8x8xf32>
+  %2 = stablehlo.custom_call @xla.sdy.LocalToGlobalShape(%1) : (tensor<2x8xf32>) -> tensor<8x8xf32>
   return %2 : tensor<8x8xf32>
 }
@@ -60,20 +60,20 @@ func.func @manual_comp_using_another(%arg0: tensor<8x8xf32>) -> tensor<8x8xf32>
   // CHECK-NEXT:   sdy.return %arg1 : tensor<8x4xf32>
   // CHECK-NEXT: } : (tensor<8x8xf32>) -> tensor<8x8xf32>
   // CHECK-NEXT: return %[[MAN_COMP_1]] : tensor<8x8xf32>
-  %0 = mhlo.custom_call @xla.sdy.GlobalToLocalShape(%arg0) : (tensor<8x8xf32>) -> tensor<2x8xf32>
+  %0 = stablehlo.custom_call @xla.sdy.GlobalToLocalShape(%arg0) : (tensor<8x8xf32>) -> tensor<2x8xf32>
   %1 = call @xla.sdy.manual_computation_body_0(%0) {mhlo.frontend_attributes = {xla.sdy.in_shardings = "#sdy.sharding_per_value<[<@mesh_0, [{\\\22a\\\22}, {}]>]>", xla.sdy.manual_axes = "#sdy<manual_axes{\\\22a\\\22}>", xla.sdy.out_shardings = "#sdy.sharding_per_value<[<@mesh_0, [{\\\22a\\\22}, {}]>]>"}} : (tensor<2x8xf32>) -> tensor<2x8xf32>
-  %2 = mhlo.custom_call @xla.sdy.LocalToGlobalShape(%1) : (tensor<2x8xf32>) -> tensor<8x8xf32>
-  %3 = mhlo.custom_call @xla.sdy.GlobalToLocalShape(%2) : (tensor<8x8xf32>) -> tensor<8x4xf32>
+  %2 = stablehlo.custom_call @xla.sdy.LocalToGlobalShape(%1) : (tensor<2x8xf32>) -> tensor<8x8xf32>
+  %3 = stablehlo.custom_call @xla.sdy.GlobalToLocalShape(%2) : (tensor<8x8xf32>) -> tensor<8x4xf32>
   %4 = call @xla.sdy.manual_computation_body_1(%3) {mhlo.frontend_attributes = {xla.sdy.in_shardings = "#sdy.sharding_per_value<[<@mesh_0, [{}, {\\\22b\\\22}]>]>", xla.sdy.manual_axes = "#sdy<manual_axes{\\\22b\\\22}>", xla.sdy.out_shardings = "#sdy.sharding_per_value<[<@mesh_0, [{}, {\\\22b\\\22}]>]>"}} : (tensor<8x4xf32>) -> tensor<8x4xf32>
-  %5 = mhlo.custom_call @xla.sdy.LocalToGlobalShape(%4) : (tensor<8x4xf32>) -> tensor<8x8xf32>
+  %5 = stablehlo.custom_call @xla.sdy.LocalToGlobalShape(%4) : (tensor<8x4xf32>) -> tensor<8x8xf32>
   return %5 : tensor<8x8xf32>
 }
 
 // CHECK-NOT: func @xla.sdy.manual_computation_body_3(
 func.func @xla.sdy.manual_computation_body_3(%arg0: tensor<2x8xf32>) -> tensor<2x8xf32> {
-  %0 = mhlo.custom_call @xla.sdy.GlobalToLocalShape(%arg0) : (tensor<2x8xf32>) -> tensor<2x4xf32>
+  %0 = stablehlo.custom_call @xla.sdy.GlobalToLocalShape(%arg0) : (tensor<2x8xf32>) -> tensor<2x4xf32>
   %1 = call @xla.sdy.manual_computation_body_2(%0) {mhlo.frontend_attributes = {xla.sdy.in_shardings = "#sdy.sharding_per_value<[<@mesh_1, [{}, {\\\22b\\\22}]>]>", xla.sdy.manual_axes = "#sdy<manual_axes{\\\22b\\\22}>", xla.sdy.out_shardings = "#sdy.sharding_per_value<[<@mesh_1, [{}, {\\\22b\\\22}]>]>"}} : (tensor<2x4xf32>) -> tensor<2x4xf32>
-  %2 = mhlo.custom_call @xla.sdy.LocalToGlobalShape(%1) : (tensor<2x4xf32>) -> tensor<2x8xf32>
+  %2 = stablehlo.custom_call @xla.sdy.LocalToGlobalShape(%1) : (tensor<2x4xf32>) -> tensor<2x8xf32>
   return %2 : tensor<2x8xf32>
 }
@@ -101,9 +101,9 @@ func.func @nested_shmaps(%arg0: tensor<4x8xf32>) -> tensor<4x8xf32> {
   // CHECK-NEXT:     sdy.return %[[MAN_COMP_1]] : tensor<2x8xf32>
   // CHECK-NEXT: } : (tensor<4x8xf32>) -> tensor<4x8xf32>
   // CHECK-NEXT: return %[[MAN_COMP_0]] : tensor<4x8xf32>
-  %0 = mhlo.custom_call @xla.sdy.GlobalToLocalShape(%arg0) : (tensor<4x8xf32>) -> tensor<2x8xf32>
+  %0 = stablehlo.custom_call @xla.sdy.GlobalToLocalShape(%arg0) : (tensor<4x8xf32>) -> tensor<2x8xf32>
   %1 = call @xla.sdy.manual_computation_body_3(%0) {mhlo.frontend_attributes = {xla.sdy.in_shardings = "#sdy.sharding_per_value<[<@mesh_1, [{\\\22a\\\22}, {}]>]>", xla.sdy.manual_axes = "#sdy<manual_axes{\\\22a\\\22}>", xla.sdy.out_shardings = "#sdy.sharding_per_value<[<@mesh_1, [{\\\22a\\\22}, {}]>]>"}} : (tensor<2x8xf32>) -> tensor<2x8xf32>
-  %2 = mhlo.custom_call @xla.sdy.LocalToGlobalShape(%1) : (tensor<2x8xf32>) -> tensor<4x8xf32>
+  %2 = stablehlo.custom_call @xla.sdy.LocalToGlobalShape(%1) : (tensor<2x8xf32>) -> tensor<4x8xf32>
   return %2 : tensor<4x8xf32>
 }
@@ -126,9 +126,9 @@ func.func @nested_shmaps_extra_op(%arg0: tensor<4x8xf32>) -> tensor<4x8xf32> {
   // CHECK-NEXT:     sdy.return %[[ADD]] : tensor<2x8xf32>
   // CHECK-NEXT: } : (tensor<4x8xf32>) -> tensor<4x8xf32>
   // CHECK-NEXT: return %[[MAN_COMP_0]] : tensor<4x8xf32>
-  %0 = mhlo.custom_call @xla.sdy.GlobalToLocalShape(%arg0) : (tensor<4x8xf32>) -> tensor<2x8xf32>
+  %0 = stablehlo.custom_call @xla.sdy.GlobalToLocalShape(%arg0) : (tensor<4x8xf32>) -> tensor<2x8xf32>
   %1 = call @xla.sdy.manual_computation_body_5(%0) {mhlo.frontend_attributes = {xla.sdy.in_shardings = "#sdy.sharding_per_value<[<@mesh_1, [{\\\22a\\\22}, {}]>]>", xla.sdy.manual_axes = "#sdy<manual_axes{\\\22a\\\22}>", xla.sdy.out_shardings = "#sdy.sharding_per_value<[<@mesh_1, [{\\\22a\\\22}, {}]>]>"}} : (tensor<2x8xf32>) -> tensor<2x8xf32>
-  %2 = mhlo.custom_call @xla.sdy.LocalToGlobalShape(%1) : (tensor<2x8xf32>) -> tensor<4x8xf32>
+  %2 = stablehlo.custom_call @xla.sdy.LocalToGlobalShape(%1) : (tensor<2x8xf32>) -> tensor<4x8xf32>
   return %2 : tensor<4x8xf32>
 }
@@ -144,7 +144,7 @@ func.func @manual_computation_no_inputs() -> tensor<4xi64> {
   // CHECK-NEXT: } : () -> tensor<4xi64>
   // CHECK-NEXT: return %[[SHMAP]] : tensor<4xi64>
   %0 = call @xla.sdy.manual_computation_body_6() {mhlo.frontend_attributes = {xla.sdy.in_shardings = "#sdy.sharding_per_value<[]>", xla.sdy.manual_axes = "#sdy<manual_axes{\\\22b\\\22}>", xla.sdy.out_shardings = "#sdy.sharding_per_value<[<@mesh_0, [{\\\22b\\\22}]>]>"}} : () -> tensor<2xi64>
-  %1 = mhlo.custom_call @xla.sdy.LocalToGlobalShape(%0) : (tensor<2xi64>) -> tensor<4xi64>
+  %1 = stablehlo.custom_call @xla.sdy.LocalToGlobalShape(%0) : (tensor<2xi64>) -> tensor<4xi64>
   return %1 : tensor<4xi64>
 }
@@ -155,11 +155,11 @@ func.func @manual_computation_no_outputs(%arg0: tensor<4xi64>) {
   // CHECK-SAME{LITERAL}:     out_shardings=[]
   // CHECK-SAME{LITERAL}:     manual_axes={"b"}
   // CHECK-SAME{LITERAL}:     (%arg1: tensor<2xi64>) {
-  // CHECK-NEXT:   mhlo.custom_call @sdy_testonly(%arg1) : (tensor<2xi64>) -> ()
+  // CHECK-NEXT:   stablehlo.custom_call @sdy_testonly(%arg1) : (tensor<2xi64>) -> ()
   // CHECK-NEXT:   sdy.return
   // CHECK-NEXT: } : (tensor<4xi64>) -> ()
   // CHECK-NEXT: return
-  %0 = mhlo.custom_call @xla.sdy.GlobalToLocalShape(%arg0) : (tensor<4xi64>) -> tensor<2xi64>
+  %0 = stablehlo.custom_call @xla.sdy.GlobalToLocalShape(%arg0) : (tensor<4xi64>) -> tensor<2xi64>
   call @xla.sdy.manual_computation_body_7(%0) {mhlo.frontend_attributes = {xla.sdy.in_shardings = "#sdy.sharding_per_value<[<@mesh_0, [{\\\22b\\\22}]>]>", xla.sdy.manual_axes = "#sdy<manual_axes{\\\22b\\\22}>", xla.sdy.out_shardings = "#sdy.sharding_per_value<[]>"}} : (tensor<2xi64>) -> ()
   return
 }
@@ -198,9 +198,9 @@ func.func @xla.sdy.manual_computation_body_4(%arg0: tensor<2x4xf32>) -> tensor<2
 // CHECK-NOT: func @xla.sdy.manual_computation_body_5(
 func.func @xla.sdy.manual_computation_body_5(%arg0: tensor<2x8xf32>) -> tensor<2x8xf32> {
-  %0 = mhlo.custom_call @xla.sdy.GlobalToLocalShape(%arg0) : (tensor<2x8xf32>) -> tensor<2x4xf32>
+  %0 = stablehlo.custom_call @xla.sdy.GlobalToLocalShape(%arg0) : (tensor<2x8xf32>) -> tensor<2x4xf32>
   %1 = call @xla.sdy.manual_computation_body_4(%0) {mhlo.frontend_attributes = {xla.sdy.in_shardings = "#sdy.sharding_per_value<[<@mesh_1, [{}, {\\\22b\\\22}]>]>", xla.sdy.manual_axes = "#sdy<manual_axes{\\\22b\\\22}>", xla.sdy.out_shardings = "#sdy.sharding_per_value<[<@mesh_1, [{}, {\\\22b\\\22}]>]>"}} : (tensor<2x4xf32>) -> tensor<2x4xf32>
-  %2 = mhlo.custom_call @xla.sdy.LocalToGlobalShape(%1) : (tensor<2x4xf32>) -> tensor<2x8xf32>
+  %2 = stablehlo.custom_call @xla.sdy.LocalToGlobalShape(%1) : (tensor<2x4xf32>) -> tensor<2x8xf32>
   %3 = stablehlo.add %2, %2 : tensor<2x8xf32>
   return %3 : tensor<2x8xf32>
 }
@@ -213,6 +213,6 @@ func.func @xla.sdy.manual_computation_body_6() -> tensor<2xi64> {
 // CHECK-NOT: func @xla.sdy.manual_computation_body_7(
 func.func @xla.sdy.manual_computation_body_7(%arg0: tensor<2xi64>) {
-  mhlo.custom_call @sdy_testonly(%arg0) : (tensor<2xi64>) -> ()
+  stablehlo.custom_call @sdy_testonly(%arg0) : (tensor<2xi64>) -> ()
   return
 }
diff --git a/xla/service/spmd/shardy/test/sdy_round_trip_shard_map_import_failure.mlir b/xla/service/spmd/shardy/test/sdy_round_trip_shard_map_import_failure.mlir
index 18bb1e698b3692..bd013563fc5dbe 100644
--- a/xla/service/spmd/shardy/test/sdy_round_trip_shard_map_import_failure.mlir
+++ b/xla/service/spmd/shardy/test/sdy_round_trip_shard_map_import_failure.mlir
@@ -3,14 +3,14 @@
 sdy.mesh @mesh = <["a"=2]>
 
 func.func @using_same_body_func(%arg0: tensor<8x8xf32>) -> tensor<8x8xf32> {
-  %0 = mhlo.custom_call @xla.sdy.GlobalToLocalShape(%arg0) : (tensor<8x8xf32>) -> (tensor<2x8xf32>)
+  %0 = stablehlo.custom_call @xla.sdy.GlobalToLocalShape(%arg0) : (tensor<8x8xf32>) -> (tensor<2x8xf32>)
   %1 = call @xla.sdy.manual_computation_body(%0) {mhlo.frontend_attributes = {xla.sdy.in_shardings = "#sdy.sharding_per_value<[<@mesh_0, [{\\\22a\\\22}, {\\\22b\\\22}]>]>", xla.sdy.manual_axes = "#sdy<manual_axes{\\\22a\\\22, \\\22b\\\22}>", xla.sdy.out_shardings = "#sdy.sharding_per_value<[<@mesh_0, [{\\\22a\\\22}, {}], replicated={\\\22b\\\22}>]>"}} : (tensor<2x8xf32>) -> (tensor<2x8xf32>)
-  %2 = mhlo.custom_call @xla.sdy.LocalToGlobalShape(%1) : (tensor<2x8xf32>) -> (tensor<8x8xf32>)
-  %3 = mhlo.custom_call @xla.sdy.GlobalToLocalShape(%2) : (tensor<8x8xf32>) -> (tensor<2x8xf32>)
+  %2 = stablehlo.custom_call @xla.sdy.LocalToGlobalShape(%1) : (tensor<2x8xf32>) -> (tensor<8x8xf32>)
+  %3 = stablehlo.custom_call @xla.sdy.GlobalToLocalShape(%2) : (tensor<8x8xf32>) -> (tensor<2x8xf32>)
   // expected-error @+2 {{'func.call' op expected a unique FuncOp per @xla.sdy.manual_computation_body call}}
   // expected-error @+1 {{failed to legalize operation 'func.call'}}
   %4 = call @xla.sdy.manual_computation_body(%3) {mhlo.frontend_attributes = {xla.sdy.in_shardings = "#sdy.sharding_per_value<[<@mesh_0, [{\\\22a\\\22}, {\\\22b\\\22}]>]>", xla.sdy.manual_axes = "#sdy<manual_axes{\\\22a\\\22, \\\22b\\\22}>", xla.sdy.out_shardings = "#sdy.sharding_per_value<[<@mesh_0, [{\\\22a\\\22}, {}], replicated={\\\22b\\\22}>]>"}} : (tensor<2x8xf32>) -> (tensor<2x8xf32>)
-  %5 = mhlo.custom_call @xla.sdy.LocalToGlobalShape(%4) : (tensor<2x8xf32>) -> (tensor<8x8xf32>)
+  %5 = stablehlo.custom_call @xla.sdy.LocalToGlobalShape(%4) : (tensor<2x8xf32>) -> (tensor<8x8xf32>)
   return %5 : tensor<8x8xf32>
 }
diff --git a/xla/service/spmd/shardy/test/sdy_round_trip_sharding_group_import_failure.mlir b/xla/service/spmd/shardy/test/sdy_round_trip_sharding_group_import_failure.mlir
index 9657697639513b..e8550407db59a6 100644
--- a/xla/service/spmd/shardy/test/sdy_round_trip_sharding_group_import_failure.mlir
+++ b/xla/service/spmd/shardy/test/sdy_round_trip_sharding_group_import_failure.mlir
@@ -1,18 +1,18 @@
 // RUN: sdy_opt %s -xla-sdy-import-sdy-custom-calls -split-input-file -verify-diagnostics
 
 func.func @sharding_group_import_failure_if_no_group_id(%arg0: tensor<16x16xf32>) -> tensor<16x16xf32> {
-  // expected-error @+2 {{failed to legalize operation 'mhlo.custom_call' that was explicitly marked illegal}}
+  // expected-error @+2 {{failed to legalize operation 'stablehlo.custom_call' that was explicitly marked illegal}}
   // expected-error @+1 {{expected CustomCallOp with a sharding group id.}}
-  mhlo.custom_call @xla.sdy.ShardingGroup(%arg0) {has_side_effect = true, mhlo.frontend_attributes = {}} : (tensor<16x16xf32>) -> ()
+  stablehlo.custom_call @xla.sdy.ShardingGroup(%arg0) {has_side_effect = true, mhlo.frontend_attributes = {}} : (tensor<16x16xf32>) -> ()
   return %arg0 : tensor<16x16xf32>
 }
 
 // -----
 
 func.func @sharding_group_import_with_used_result(%arg0: tensor<8x8xf32>) -> tuple<tuple<>> {
-  // expected-error @+2 {{failed to legalize operation 'mhlo.custom_call' that was explicitly marked illegal}}
+  // expected-error @+2 {{failed to legalize operation 'stablehlo.custom_call' that was explicitly marked illegal}}
   // expected-error @+1 {{xla.sdy.ShardingGroup CustomCallOp should have no uses.}}
-  %0 = mhlo.custom_call @xla.sdy.ShardingGroup(%arg0) {has_side_effect = true, mhlo.frontend_attributes = {xla.sdy.sharding_group_id = "21 : i64"}} : (tensor<8x8xf32>) -> tuple<>
-  %1 = "mhlo.tuple"(%0) : (tuple<>) -> tuple<tuple<>>
+  %0 = stablehlo.custom_call @xla.sdy.ShardingGroup(%arg0) {has_side_effect = true, mhlo.frontend_attributes = {xla.sdy.sharding_group_id = "21 : i64"}} : (tensor<8x8xf32>) -> tuple<>
+  %1 = "stablehlo.tuple"(%0) : (tuple<>) -> tuple<tuple<>>
   return %1 : tuple<tuple<>>
 }