Skip to content

Commit

Permalink
#sdy Support empty meshes.
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 666387965
  • Loading branch information
bartchr808 authored and copybara-github committed Aug 22, 2024
1 parent 1096ee4 commit 0621825
Show file tree
Hide file tree
Showing 2 changed files with 17 additions and 8 deletions.
14 changes: 6 additions & 8 deletions xla/service/spmd/shardy/mhlo_round_trip/export_shardings.cc
Original file line number Diff line number Diff line change
Expand Up @@ -236,17 +236,15 @@ HloSharding convertToHloSharding(
ArrayRef<AxisRefAttr> manualAxes) {
MeshAttr mesh = getMeshAttr(sdySharding);

// Convert to maximal sharding if the mesh only contains the device id.
// If there are no axes, convert to:
// - maximal sharding if the mesh has a device id
// - else replicated sharding
if (mesh.getAxes().empty()) {
CHECK_EQ(mesh.getDeviceIds().size(), 1);
return HloSharding::AssignDevice(mesh.getDeviceIds().front());
return mesh.getDeviceIds().empty()
? HloSharding::Replicate()
: HloSharding::AssignDevice(mesh.getDeviceIds().front());
}

// TODO(b/326025166):
// Handle empty mesh.
// Handle arbitrary device id list.
CHECK(mesh.getDeviceIds().empty());

SmallVector<int64_t> tileAssignmentDims(sdySharding.getRank(), 1);
llvm::SmallDenseMap<AxisRefAttr, int64_t> axisRefToShardedPos;
SmallVector<OpSharding::Type> types;
Expand Down
11 changes: 11 additions & 0 deletions xla/service/spmd/shardy/test/mhlo_export_pipeline.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,8 @@ sdy.mesh @mesh_2 = <["x"=8, "y"=4]>
sdy.mesh @mesh_3 = <["a"=2, "b"=2, "c"=2, "d"=2]>
sdy.mesh @maximal_mesh_0 = <device_ids=[0]>
sdy.mesh @maximal_mesh_1 = <device_ids=[1]>
sdy.mesh @empty_mesh_0 = <>
sdy.mesh @empty_mesh_1 = <>

// CHECK-NOT: sdy.mesh

Expand Down Expand Up @@ -153,3 +155,12 @@ func.func @mesh_with_device_id_should_be_converted_to_maximal_sharding(%arg0: te
%1 = mhlo.add %0, %0 {sdy.sharding = #sdy.sharding_per_value<[<@maximal_mesh_1, [{}, {}]>]>} : tensor<8x8xf32>
return %1 : tensor<8x8xf32>
}

// Test: shardings that reference a mesh declared as `<>` (no axes and no
// device ids; see @empty_mesh_0 / @empty_mesh_1 above) are expected to be
// exported as `mhlo.sharding = "{replicated}"`. %arg1 has no sharding on
// input and is expected to have none after export.
// CHECK-LABEL: func @mesh_empty_should_be_converted_to_replicated_sharding(%arg0: tensor<8x8xf32> {mhlo.sharding = "{replicated}"}, %arg1: tensor<8x8xf32>)
func.func @mesh_empty_should_be_converted_to_replicated_sharding(%arg0: tensor<8x8xf32> {sdy.sharding = #sdy.sharding<@empty_mesh_0, [{}, {}]>}, %arg1: tensor<8x8xf32>) -> tensor<8x8xf32> {
// This add has no sharding attribute; the expected output has none either.
// CHECK: %[[ADD:.*]] = mhlo.add %arg0, %arg1
%0 = mhlo.add %arg0, %arg1 : tensor<8x8xf32>
// This add is sharded through the op-result form (sharding_per_value) using
// the second empty mesh; the expected output is again "{replicated}".
// CHECK: %[[ADD_WITH_SHARDING:.*]] = mhlo.add %[[ADD]], %[[ADD]] {mhlo.sharding = "{replicated}"}
%1 = mhlo.add %0, %0 {sdy.sharding = #sdy.sharding_per_value<[<@empty_mesh_1, [{}, {}]>]>} : tensor<8x8xf32>
return %1 : tensor<8x8xf32>
}

0 comments on commit 0621825

Please sign in to comment.