diff --git a/third_party/llvm/workspace.bzl b/third_party/llvm/workspace.bzl
index c761fc3e792f3..1a486ac8fde70 100644
--- a/third_party/llvm/workspace.bzl
+++ b/third_party/llvm/workspace.bzl
@@ -4,8 +4,8 @@ load("//third_party:repo.bzl", "tf_http_archive")
 
 def repo(name):
     """Imports LLVM."""
-    LLVM_COMMIT = "bd92e46204331b9af296f53abb708317e72ab7a8"
-    LLVM_SHA256 = "60f71fc5b237e10729edbed8cbe23b7081dabe254fbcb1ea82db8789cb7eaecf"
+    LLVM_COMMIT = "1d6ab189be031bf723abf35f772fbd5d4c86c612"
+    LLVM_SHA256 = "94efdc753920c1c4065c0f253c98c4b2b049495803e9089667ba397a29550323"
 
     tf_http_archive(
         name = name,
diff --git a/third_party/shardy/temporary.patch b/third_party/shardy/temporary.patch
index 770e3979e6c01..ce801b71cf576 100644
--- a/third_party/shardy/temporary.patch
+++ b/third_party/shardy/temporary.patch
@@ -1,15 +1,545 @@
+diff --git a/shardy/dialect/sdy/ir/BUILD b/shardy/dialect/sdy/ir/BUILD
+index 14d3681..1fdb305 100644
+--- a/shardy/dialect/sdy/ir/BUILD
++++ b/shardy/dialect/sdy/ir/BUILD
+@@ -180,7 +180,6 @@ cc_library(
+         "@llvm-project//mlir:IR",
+         "@llvm-project//mlir:InferTypeOpInterface",
+         "@llvm-project//mlir:InliningUtils",
+-        "@llvm-project//mlir:ShapeDialect",
+         "@llvm-project//mlir:SideEffectInterfaces",
+         "@llvm-project//mlir:Support",
+         "@stablehlo//:stablehlo_assembly_format",
+diff --git a/shardy/dialect/sdy/ir/attrs.td b/shardy/dialect/sdy/ir/attrs.td
+index b1971f0..7635037 100644
+--- a/shardy/dialect/sdy/ir/attrs.td
++++ b/shardy/dialect/sdy/ir/attrs.td
+@@ -603,16 +603,7 @@ def Sdy_TensorSharding : AttrDef {
+     // given mesh and sharding. Assumes that the sharding is valid w.r.t. the
+     // mesh and tensor type.
+     RankedTensorType getLocalTensorType(RankedTensorType globalTensorType,
+-                                        MeshAttr mesh) const;
+-
+-    // Gets the global tensor type from a local RankedTensorType w.r.t. the
+-    // given mesh and sharding. Assumes that the sharding is valid w.r.t. the
+-    // mesh and tensor type.
+-    //
+-    // NOTE: this doesn't take into account padding. Each dimension of
+-    // `localTensorType` will be a multiple of the global tensor type returned.
+-    RankedTensorType getGlobalTensorType(RankedTensorType localTensorType,
+-                                         MeshAttr mesh) const;
++                                        MeshAttr mesh);
+   }];
+ }
+ 
+diff --git a/shardy/dialect/sdy/ir/dialect.cc b/shardy/dialect/sdy/ir/dialect.cc
+index ba02cc4..749796f 100644
+--- a/shardy/dialect/sdy/ir/dialect.cc
++++ b/shardy/dialect/sdy/ir/dialect.cc
+@@ -704,42 +704,24 @@ TensorShardingAttr TensorShardingAttr::getFullyOpenLike(
+ }
+ 
+ RankedTensorType TensorShardingAttr::getLocalTensorType(
+-    RankedTensorType globalTensorType, MeshAttr mesh) const {
+-  assert(globalTensorType.hasStaticShape());
++    RankedTensorType globalTensorType, MeshAttr mesh) {
+   if (getDimShardings().empty()) {
+     return globalTensorType;
+   }
+   SmallVector<int64_t> localShape;
+   localShape.reserve(globalTensorType.getRank());
+ 
+-  for (auto [globalDimSize, dimSharding] :
+-       llvm::zip_equal(globalTensorType.getShape(), getDimShardings())) {
++  for (auto [dim, dimSharding] : llvm::enumerate(getDimShardings())) {
+     int64_t shardSize = dimSharding.getShardedSize(mesh);
+-    // We allow non divisible sharding.
+-    int64_t localSize = (globalDimSize + shardSize - 1) / shardSize;
++    int64_t dimSize = globalTensorType.getDimSize(dim);
++    // We allow non divisible sharding
++    int64_t localSize = (dimSize + shardSize - 1) / shardSize;
+     localShape.push_back(localSize);
+   }
+   return RankedTensorType::get(ArrayRef(localShape),
+                                globalTensorType.getElementType());
+ }
+ 
+-RankedTensorType TensorShardingAttr::getGlobalTensorType(
+-    RankedTensorType localTensorType, MeshAttr mesh) const {
+-  assert(localTensorType.hasStaticShape());
+-  if (getDimShardings().empty()) {
+-    return localTensorType;
+-  }
+-  SmallVector<int64_t> globalShape;
+-  globalShape.reserve(localTensorType.getRank());
+-
+-  for (auto [localDimSize, dimSharding] :
+-       llvm::zip_equal(localTensorType.getShape(), getDimShardings())) {
+-    globalShape.push_back(dimSharding.getShardedSize(mesh) * localDimSize);
+-  }
+-  return RankedTensorType::get(ArrayRef(globalShape),
+-                               localTensorType.getElementType());
+-}
+-
+ //===----------------------------------------------------------------------===//
+ // TensorShardingPerValueAttr
+ //===----------------------------------------------------------------------===//
+@@ -766,12 +748,98 @@ TensorShardingPerValueAttr TensorShardingPerValueAttr::replaceValueSharding(
+   return TensorShardingPerValueAttr::get(getContext(), shardings);
+ }
+ 
++//===----------------------------------------------------------------------===//
++// ConstantOp
++//===----------------------------------------------------------------------===//
++
++LogicalResult ConstantOp::inferReturnTypes(
++    MLIRContext*, std::optional<Location> location, ValueRange operands,
++    DictionaryAttr attributes, OpaqueProperties properties, RegionRange,
++    SmallVectorImpl<Type>& inferredReturnTypes) {
++  ConstantOpAdaptor adaptor(operands, attributes, properties);
++  return hlo::inferConstantOp(location, adaptor.getValue(),
++                              inferredReturnTypes);
++}
++
++bool ConstantOp::isCompatibleReturnTypes(TypeRange l, TypeRange r) {
++  return stablehlo::ConstantOp::isCompatibleReturnTypes(l, r);
++}
++
+ //===----------------------------------------------------------------------===//
+ // ManualComputationOp
+ //===----------------------------------------------------------------------===//
+ 
+ namespace {
+ 
++// Callback that removes free (non-manual) axes from a
++// `dimSharding` in a `ManualComputationOp` at `firstFreeAxisIndex`.
++//
++// Some use cases are removing all axes up to `firstFreeAxisIndex` or removing
++// all axes from `firstFreeAxisIndex`. This needs to happen on many different
++// `DimShardingAttr`s in the `in_shardings` and `out_shardings` of a
++// `ManualComputationOp`.
++using ManualComputationShardingEraserFn = std::function<DimensionShardingAttr(DimensionShardingAttr, int64_t)>;
++
++// Calls a dimension sharding erasing callback on the first free axis in
++// a dimension. This uses the invariant that shardings are prefixed with any
++// manual axes.
++TensorShardingAttr eraseAxesFromManualComputationSharding(
++    TensorShardingAttr outerManualSharding, ArrayRef<StringAttr> manualAxes,
++    ManualComputationShardingEraserFn shardingEraser) {
++  SmallVector<DimensionShardingAttr> newDimShardings;
++  newDimShardings.reserve(outerManualSharding.getRank());
++  for (DimensionShardingAttr dimSharding :
++       outerManualSharding.getDimShardings()) {
++    ArrayRef<AxisRefAttr> dimAxes = dimSharding.getAxes();
++    // Axes in the range [0, firstFreeAxis) are manual axes, and
++    // [firstFreeAxis, dimAxes.size()) are free axes.
++    llvm::ArrayRef<AxisRefAttr>::const_iterator firstFreeAxisIt =
++        llvm::partition_point(dimAxes, [&manualAxes](AxisRefAttr axis) {
++          return llvm::is_contained(manualAxes, axis.getName());
++        });
++    newDimShardings.push_back(
++        shardingEraser(dimSharding, firstFreeAxisIt - dimAxes.begin()));
++  }
++  // Grab any replicated axes that are not manual axes. Can't use
++  // `partition_point` as there is no defined order for replicated axes.
++  SmallVector<AxisRefAttr> newReplicatedAxes;
++  llvm::copy_if(outerManualSharding.getReplicatedAxes(),
++                std::back_inserter(newReplicatedAxes), [&](AxisRefAttr axis) {
++                  return !llvm::is_contained(manualAxes, axis.getName());
++                });
++  return TensorShardingAttr::get(outerManualSharding.getContext(),
++                                 outerManualSharding.getMeshOrRef(),
++                                 newDimShardings, newReplicatedAxes);
++}
++
++// Removes free axes from the sharding.
++//
++// Guaranteed by verification that all in/out shardings in a
++// `ManualComputationOp` are prefixed with the manual axes. So this removes the
++// suffix of free axes (if any exist) from each dim sharding.
++TensorShardingAttr eraseFreeAxes(TensorShardingAttr outerManualSharding,
++                                 ArrayRef<StringAttr> manualAxes) {
++  return eraseAxesFromManualComputationSharding(
++      outerManualSharding, manualAxes,
++      std::mem_fn(&DimensionShardingAttr::takeFrontShardingAxes));
++}
++
++// Removes manual axes from the sharding.
++//
++// Guaranteed by verification that all in/out shardings in a
++// `ManualComputationOp` are prefixed with the manual axes. So this removes the
++// prefix of manual axes (if any exist) from each dim sharding.
++TensorShardingAttr eraseManualAxes(TensorShardingAttr outerManualSharding,
++                                   ArrayRef<StringAttr> manualAxes) {
++  if (manualAxes.empty()) {
++    return outerManualSharding;
++  }
++  return eraseAxesFromManualComputationSharding(
++      outerManualSharding, manualAxes,
++      std::mem_fn(&DimensionShardingAttr::dropFrontShardingAxes));
++}
++
+ // Re-adds any manual axes after the new sharding is determined across the
+ // `ManualComputationOp` barrier.
+ //
+@@ -951,35 +1019,6 @@ Value ManualComputationOp::getEdgeOwnerFromSource(OpOperand& source) {
+   return sdy::getEdgeOwnerFromSource(source, *this);
+ }
+ 
+-//===----------------------------------------------------------------------===//
+-// ShardingGroupOp
+-//===----------------------------------------------------------------------===//
+-
+-LogicalResult ShardingGroupOp::inferReturnTypes(MLIRContext*,
+-                                                std::optional<Location>,
+-                                                ValueRange, DictionaryAttr,
+-                                                OpaqueProperties, RegionRange,
+-                                                SmallVectorImpl<Type>&) {
+-  return success();
+-}
+-
+-//===----------------------------------------------------------------------===//
+-// ConstantOp
+-//===----------------------------------------------------------------------===//
+-
+-bool ConstantOp::isCompatibleReturnTypes(TypeRange l, TypeRange r) {
+-  return stablehlo::ConstantOp::isCompatibleReturnTypes(l, r);
+-}
+-
+-LogicalResult ConstantOp::inferReturnTypes(
+-    MLIRContext* context, std::optional<Location> location, ValueRange operands,
+-    DictionaryAttr attributes, OpaqueProperties properties, RegionRange regions,
+-    SmallVectorImpl<Type>& inferredReturnTypes) {
+-  ConstantOpAdaptor adaptor(operands, attributes, properties);
+-  inferredReturnTypes.push_back(adaptor.getValue().getType());
+-  return success();
+-}
+-
+ //===----------------------------------------------------------------------===//
+ // DataFlowEdgeOp
+ //===----------------------------------------------------------------------===//
+@@ -1064,16 +1103,6 @@ Value NamedComputationOp::getEdgeOwnerFromSource(OpOperand& source) {
+   return sdy::getEdgeOwnerFromSource(source, *this);
+ }
+ 
+-LogicalResult NamedComputationOp::inferReturnTypes(
+-    MLIRContext*, std::optional<Location> location, ValueRange operands,
+-    DictionaryAttr attributes, OpaqueProperties properties, RegionRange regions,
+-    SmallVectorImpl<Type>& inferredReturnTypes) {
+-  NamedComputationOpAdaptor adaptor(operands, attributes, properties, regions);
+-  llvm::copy(getBodyTerminatorOperands(adaptor).getTypes(),
+-             std::back_inserter(inferredReturnTypes));
+-  return success();
+-}
+-
+ }  // namespace sdy
+ }  // namespace mlir
+ 
+diff --git a/shardy/dialect/sdy/ir/ops.td b/shardy/dialect/sdy/ir/ops.td
+index d62dc85..b11b83f 100644
+--- a/shardy/dialect/sdy/ir/ops.td
++++ b/shardy/dialect/sdy/ir/ops.td
+@@ -48,7 +48,7 @@ def Sdy_MeshOp : Sdy_Op<"mesh", [Symbol, HasParent<"ModuleOp">]> {
+ }
+ 
+ def Sdy_ShardingConstraintOp : Sdy_Op<"sharding_constraint",
+-    [Elementwise, SameOperandsAndResultType, InferTypeOpInterface]> {
++    [Elementwise, SameOperandsAndResultType]> {
+   let summary = "Constrains a tensor to the specified sharding";
+   let description = [{
+     Attaches a sharding to an intermediate tensor (e.g. the result of a matmul)
+@@ -78,7 +78,7 @@ def Sdy_ShardingConstraintOp : Sdy_Op<"sharding_constraint",
+ }
+ 
+ def Sdy_ReshardOp : Sdy_Op<"reshard",
+-    [Pure, Elementwise, SameOperandsAndResultType, InferTypeOpInterface]> {
++    [Pure, Elementwise, SameOperandsAndResultType]> {
+   let summary = "Reshards a tensor to a different sharding";
+   let description = [{
+     Reshards the input tensor with the specified sharding, which is different
+@@ -124,10 +124,10 @@ def Sdy_ReturnOp : Sdy_Op<"return", [Pure, Terminator]> {
+ def Sdy_ManualComputationOp : Sdy_Op<"manual_computation",
+     [RecursiveMemoryEffects, SingleBlockImplicitTerminator<"ReturnOp">,
+      IsolatedFromAbove, DeclareOpInterfaceMethods<
+-       ShardableDataFlowOpInterface,
+-       /*methodOverrides=*/["getOpResultEdgeOwnerShardings",
+-                            "setOpResultEdgeOwnerShardings",
+-                            "transformTargetSharding"]>]> {
++        ShardableDataFlowOpInterface,
++        /*methodOverrides=*/["getOpResultEdgeOwnerShardings",
++                             "setOpResultEdgeOwnerShardings",
++                             "transformTargetSharding"]>]> {
+   let summary = "Multi-device parallelism operation with manual collectives";
+   let description = [{
+     Jump into a region written in terms of per-device local code with explicit
+@@ -189,7 +189,7 @@ def Sdy_ManualComputationOp : Sdy_Op<"manual_computation",
+ def Sdy_ShardingGroupOp : Sdy_Op<"sharding_group",
+     // Op is non-pure since it modifies the internal representation of the
+     // sharding group.
+-    [DeclareOpInterfaceMethods<InferTypeOpInterface>]>{
++    []>{
+   let summary = "Constrains tensors in the group to have the same sharding.";
+   let description = [{
+     This op provides an interface to assign tensors to sharding groups (
+@@ -340,8 +340,7 @@ def Sdy_DataFlowEdgeOp : Sdy_Op<"data_flow_edge",
+ //===----------------------------------------------------------------------===//
+ 
+ def Sdy_PropagationBarrierOp : Sdy_Op<"propagation_barrier",
+-    [Pure, Elementwise, SameOperandsAndResultType,
+-     DeclareOpInterfaceMethods<InferTypeOpInterface>]> {
++    [Pure, Elementwise, SameOperandsAndResultType]> {
+   let summary = "Propagation barrier operation";
+ 
+   let description = [{
+@@ -374,7 +373,6 @@ def Sdy_PropagationBarrierOp : Sdy_Op<"propagation_barrier",
+ def Sdy_NamedComputationOp : Sdy_Op<"named_computation",
+     [RecursiveMemoryEffects, SingleBlockImplicitTerminator<"ReturnOp">,
+      RecursivelySpeculatable, IsolatedFromAbove,
+-     DeclareOpInterfaceMethods<InferTypeOpInterface>,
+      DeclareOpInterfaceMethods<
+        ShardableDataFlowOpInterface,
+        /*methodOverrides=*/["getOpResultEdgeOwnerShardings",
+@@ -423,8 +421,7 @@ def Sdy_NamedComputationOp : Sdy_Op<"named_computation",
+ 
+ 
+ def Sdy_AllGatherOp : Sdy_Op<"all_gather",
+-    [SameOperandsAndResultType,
+-     DeclareOpInterfaceMethods<InferTypeOpInterface>]> {
++    [SameOperandsAndResultType]> {
+   let summary = "Gathers chunks of a tensor along axes";
+   let description = [{
+     Gathers chunks of a tensor along axes specified in `gatheringAxes`.
+diff --git a/shardy/dialect/sdy/ir/test/manual_computation_parse_print.mlir b/shardy/dialect/sdy/ir/test/manual_computation_parse_print.mlir
+index 289e4ef..1cff970 100644
+--- a/shardy/dialect/sdy/ir/test/manual_computation_parse_print.mlir
++++ b/shardy/dialect/sdy/ir/test/manual_computation_parse_print.mlir
+@@ -172,25 +172,3 @@ func.func @replicated_axes_free_before_manual(%arg0: tensor<16x32xf32>) -> tenso
+   } : (tensor<16x32xf32>) -> tensor<16x32xf32>
+   func.return %0: tensor<16x32xf32>
+ }
+-
+-// CHECK-LABEL: func @manual_computation_dynamic_shapes
+-func.func @manual_computation_dynamic_shapes(%arg0: tensor<16x32xf32>) -> tensor {
+-  // CHECK-NEXT: %[[ADD:.*]] = stablehlo.add %arg0, %arg0 : (tensor<16x32xf32>, tensor<16x32xf32>) -> tensor
+-  // CHECK-NEXT: %[[MC:.*]] = sdy.manual_computation(%[[ADD]])
+-  // CHECK-SAME:     in_shardings=[<@meshA, [{"a", ?}, {?}]>]
+-  // CHECK-SAME:     out_shardings=[<@meshA, [{"a", ?}, {?}]>]
+-  // CHECK-SAME:     manual_axes={"a"} (%arg1: tensor) {
+-  // CHECK-NEXT:   %[[INNER_ADD:.*]] = stablehlo.add %arg1, %arg1 : tensor
+-  // CHECK-NEXT:   sdy.return %[[INNER_ADD]] : tensor
+-  // CHECK-NEXT: } : (tensor) -> tensor
+-  // CHECK-NEXT: return %[[MC]] : tensor
+-  %0 = stablehlo.add %arg0, %arg0 : (tensor<16x32xf32>, tensor<16x32xf32>) -> tensor
+-  %1 = sdy.manual_computation(%0)
+-      in_shardings=[<@meshA, [{"a", ?}, {?}]>]
+-      out_shardings=[<@meshA, [{"a", ?}, {?}]>]
+-      manual_axes={"a"} (%arg1: tensor) {
+-    %2 = stablehlo.add %arg1, %arg1 : tensor
+-    sdy.return %2 : tensor
+-  } : (tensor) -> tensor
+-  return %1: tensor
+-}
+diff --git a/shardy/dialect/sdy/ir/test/manual_computation_verification.mlir b/shardy/dialect/sdy/ir/test/manual_computation_verification.mlir
+index 139e1f2..ce341be 100644
+--- a/shardy/dialect/sdy/ir/test/manual_computation_verification.mlir
++++ b/shardy/dialect/sdy/ir/test/manual_computation_verification.mlir
+@@ -274,29 +274,3 @@ func.func @free_axes_before_manual_dim_sharding(%arg0: tensor<16x32xf32>) -> ten
+   } : (tensor<16x32xf32>) -> tensor<16x16xf32>
+   func.return %0: tensor<16x16xf32>
+ }
+-
+-// -----
+-
+-sdy.mesh @mesh = <["a"=2]>
+-
+-func.func @global_dynamic_local_static_shape(%arg0: tensor) -> tensor<16x32xf32> {
+-  // expected-error @+1 {{op operand shape, corresponding sharding, and region operand shape at index 0 must match. Expected local shape 'tensor', actual local shape 'tensor<16x32xf32>'}}
+-  %0 = sdy.manual_computation(%arg0) in_shardings=[<@mesh, [{"a"}, {}]>] out_shardings=[<@mesh, [{"a"}, {}]>] manual_axes={"a"} (%arg1: tensor<16x32xf32>) {
+-    %1 = stablehlo.add %arg1, %arg1 : tensor<16x32xf32>
+-    sdy.return %1 : tensor<16x32xf32>
+-  } : (tensor) -> tensor<16x32xf32>
+-  func.return %0: tensor<16x32xf32>
+-}
+-
+-// -----
+-
+-sdy.mesh @mesh = <["a"=2]>
+-
+-func.func @correct_dynamic_dim_static_dim_mismatch(%arg0: tensor) -> tensor {
+-  // expected-error @+1 {{op result shape, corresponding sharding, and region result shape at index 0 must match. Expected local shape 'tensor', actual local shape 'tensor'}}
+-  %0 = sdy.manual_computation(%arg0) in_shardings=[<@mesh, [{"a"}, {}]>] out_shardings=[<@mesh, [{"a"}, {}]>] manual_axes={"a"} (%arg1: tensor) {
+-    %1 = stablehlo.add %arg1, %arg1 : tensor
+-    sdy.return %1 : tensor
+-  } : (tensor) -> tensor
+-  func.return %0: tensor
+-}
+diff --git a/shardy/dialect/sdy/ir/utils.cc b/shardy/dialect/sdy/ir/utils.cc
+index 90486b2..575fc62 100644
+--- a/shardy/dialect/sdy/ir/utils.cc
++++ b/shardy/dialect/sdy/ir/utils.cc
+@@ -17,8 +17,6 @@ limitations under the License.
+ 
+ #include 
+ #include 
+-#include 
+-#include 
+ #include 
+ #include 
+ 
+@@ -428,68 +426,5 @@ SmallVector<TensorShardingAttr> getOpenShardingsWithShardingAtIndex(
+   return shardings;
+ }
+ 
+-namespace {
+-
+-// Callback that removes free (non-manual) axes from a
+-// `dimSharding` in a `ManualComputationOp` at `firstFreeAxisIndex`.
+-//
+-// Some use cases are removing all axes up to `firstFreeAxisIndex` or removing
+-// all axes from `firstFreeAxisIndex`. This needs to happen on many different
+-// `DimShardingAttr`s in the `in_shardings` and `out_shardings` of a
+-// `ManualComputationOp`.
+-using ManualComputationShardingEraserFn = std::function<DimensionShardingAttr(DimensionShardingAttr, int64_t)>;
+-
+-// Calls a dimension sharding erasing callback on the first free axis in
+-// a dimension. This uses the invariant that shardings are prefixed with any
+-// manual axes.
+-TensorShardingAttr eraseAxesFromManualComputationSharding(
+-    TensorShardingAttr outerManualSharding, ArrayRef<StringAttr> manualAxes,
+-    ManualComputationShardingEraserFn shardingEraser) {
+-  SmallVector<DimensionShardingAttr> newDimShardings;
+-  newDimShardings.reserve(outerManualSharding.getRank());
+-  for (DimensionShardingAttr dimSharding :
+-       outerManualSharding.getDimShardings()) {
+-    ArrayRef<AxisRefAttr> dimAxes = dimSharding.getAxes();
+-    // Axes in the range [0, firstFreeAxis) are manual axes, and
+-    // [firstFreeAxis, dimAxes.size()) are free axes.
+-    llvm::ArrayRef<AxisRefAttr>::const_iterator firstFreeAxisIt =
+-        llvm::partition_point(dimAxes, [&manualAxes](AxisRefAttr axis) {
+-          return llvm::is_contained(manualAxes, axis.getName());
+-        });
+-    newDimShardings.push_back(
+-        shardingEraser(dimSharding, firstFreeAxisIt - dimAxes.begin()));
+-  }
+-  // Grab any replicated axes that are not manual axes. Can't use
+-  // `partition_point` as there is no defined order for replicated axes.
+-  SmallVector<AxisRefAttr> newReplicatedAxes;
+-  llvm::copy_if(outerManualSharding.getReplicatedAxes(),
+-                std::back_inserter(newReplicatedAxes), [&](AxisRefAttr axis) {
+-                  return !llvm::is_contained(manualAxes, axis.getName());
+-                });
+-  return TensorShardingAttr::get(outerManualSharding.getContext(),
+-                                 outerManualSharding.getMeshOrRef(),
+-                                 newDimShardings, newReplicatedAxes);
+-}
+-
+-}  // namespace
+-
+-TensorShardingAttr eraseManualAxes(TensorShardingAttr outerManualSharding,
+-                                   ArrayRef<StringAttr> manualAxes) {
+-  if (manualAxes.empty()) {
+-    return outerManualSharding;
+-  }
+-  return eraseAxesFromManualComputationSharding(
+-      outerManualSharding, manualAxes,
+-      std::mem_fn(&DimensionShardingAttr::dropFrontShardingAxes));
+-}
+-
+-TensorShardingAttr eraseFreeAxes(TensorShardingAttr outerManualSharding,
+-                                 ArrayRef<StringAttr> manualAxes) {
+-  return eraseAxesFromManualComputationSharding(
+-      outerManualSharding, manualAxes,
+-      std::mem_fn(&DimensionShardingAttr::takeFrontShardingAxes));
+-}
+-
+ }  // namespace sdy
+ }  // namespace mlir
+diff --git a/shardy/dialect/sdy/ir/utils.h b/shardy/dialect/sdy/ir/utils.h
+index d69e924..ee3b7b5 100644
+--- a/shardy/dialect/sdy/ir/utils.h
++++ b/shardy/dialect/sdy/ir/utils.h
+@@ -393,22 +393,6 @@ SmallVector<TensorShardingAttr> getOpenShardingsWithShardingAtIndex(
+     MLIRContext* context, TypeRange types, int64_t index,
+     TensorShardingAttr sharding);
+ 
+-// Removes manual axes from the sharding.
+-//
+-// Guaranteed by verification that all in/out shardings in a
+-// `ManualComputationOp` are prefixed with the manual axes. So this removes the
+-// prefix of manual axes (if any exist) from each dim sharding.
+-TensorShardingAttr eraseManualAxes(TensorShardingAttr outerManualSharding,
+-                                   ArrayRef<StringAttr> manualAxes);
+-
+-// Removes free axes from the sharding.
+-//
+-// Guaranteed by verification that all in/out shardings in a
+-// `ManualComputationOp` are prefixed with the manual axes. So this removes the
+-// suffix of free axes (if any exist) from each dim sharding.
+-TensorShardingAttr eraseFreeAxes(TensorShardingAttr outerManualSharding,
+-                                 ArrayRef<StringAttr> manualAxes);
+-
+ }  // namespace sdy
+ }  // namespace mlir
+ 
+diff --git a/shardy/dialect/sdy/ir/verifiers.cc b/shardy/dialect/sdy/ir/verifiers.cc
+index c2fe56f..b04df3c 100644
+--- a/shardy/dialect/sdy/ir/verifiers.cc
++++ b/shardy/dialect/sdy/ir/verifiers.cc
+@@ -763,6 +763,8 @@ LogicalResult verifyManualComputationValue(
+   for (auto [valueIndex, valueEntry] : llvm::enumerate(llvm::zip_equal(
+            globalTypes, localTypes, shardingPerValueAttr.getShardings()))) {
+     auto [globalType, localType, sharding] = valueEntry;
++    auto globalRankedType = cast<RankedTensorType>(globalType);
++    auto localRankedType = cast<RankedTensorType>(localType);
+ 
+     // 5. Verify the manual axes come before any free axes in each dim sharding.
+     for (auto [dim, dimSharding] :
+@@ -780,26 +782,19 @@ LogicalResult verifyManualComputationValue(
+       }
+     }
+ 
+-
+     // 6. Verify the global shape and local shapes of the op regions
+     // arguments/results match.
+     SmallVector<int64_t> newDimSizes;
+-    auto globalRankedType = mlir::cast<RankedTensorType>(globalType);
+     for (auto [dimensionSize, dimSharding] : llvm::zip_equal(
+              globalRankedType.getShape(), sharding.getDimShardings())) {
+-      if (dimensionSize == ShapedType::kDynamic) {
+-        newDimSizes.push_back(ShapedType::kDynamic);
+-      } else {
+-        // Safe to call `getMesh` because the sharding was already verified.
+-        newDimSizes.push_back(
+-            dimensionSize /
+-            accumulatedManualAxesSize(op, dimSharding.getAxes(), manualAxesSet,
+-                                      sharding.getMesh(symbolTable)));
+-      }
++      // Safe to call `getMesh` because the sharding was already verified.
++      newDimSizes.push_back(dimensionSize / accumulatedManualAxesSize(
++                                                op, dimSharding.getAxes(),
++                                                manualAxesSet,
++                                                sharding.getMesh(symbolTable)));
+     }
+     auto expectedLocalRankedType =
+         RankedTensorType::get(newDimSizes, globalRankedType.getElementType());
+-    auto localRankedType = mlir::cast<RankedTensorType>(localType);
+     if (expectedLocalRankedType != localRankedType) {
+       return op->emitOpError(valueKindStr)
+              << " shape, corresponding sharding, and region " << valueKindStr
 diff --git a/third_party/llvm/workspace.bzl b/third_party/llvm/workspace.bzl
-index 6d5c0d5..c761fc3 100644
+index c761fc3..1a486ac 100644
 --- a/third_party/llvm/workspace.bzl
 +++ b/third_party/llvm/workspace.bzl
 @@ -4,8 +4,8 @@ load("//third_party:repo.bzl", "tf_http_archive")
  
  def repo(name):
      """Imports LLVM."""
--    LLVM_COMMIT = "1b03747ed85cd4a6573b728674e88f4bd3fa844d"
--    LLVM_SHA256 = "58d57df317c6485b543e2a02ab9d2c1c6148b8f9bc0860741dd558b70de1a787"
-+    LLVM_COMMIT = "bd92e46204331b9af296f53abb708317e72ab7a8"
-+    LLVM_SHA256 = "60f71fc5b237e10729edbed8cbe23b7081dabe254fbcb1ea82db8789cb7eaecf"
+-    LLVM_COMMIT = "bd92e46204331b9af296f53abb708317e72ab7a8"
+-    LLVM_SHA256 = "60f71fc5b237e10729edbed8cbe23b7081dabe254fbcb1ea82db8789cb7eaecf"
++    LLVM_COMMIT = "1d6ab189be031bf723abf35f772fbd5d4c86c612"
++    LLVM_SHA256 = "94efdc753920c1c4065c0f253c98c4b2b049495803e9089667ba397a29550323"
  
      tf_http_archive(
          name = name,
diff --git a/third_party/shardy/workspace.bzl b/third_party/shardy/workspace.bzl
index f7c9d8818720d..ab86727b2515d 100644
--- a/third_party/shardy/workspace.bzl
+++ b/third_party/shardy/workspace.bzl
@@ -3,8 +3,8 @@
 load("//third_party:repo.bzl", "tf_http_archive", "tf_mirror_urls")
 
 def repo():
-    SHARDY_COMMIT = "3a8c4d799ffa4ee6d4f99fc48a33ef0f69af30e4"
-    SHARDY_SHA256 = "2aafcd972128cdaedda0550ab9e4ff87980df62b9b7dea6f4d32878d41e1a946"
+    SHARDY_COMMIT = "be720bd179bf4596a7efc656e17775dfa991ac67"
+    SHARDY_SHA256 = "bb7f600fe99cd64f05eab7d1a6bdafdff69a07a6b317883744d7c19c8ac06dfd"
 
     tf_http_archive(
         name = "shardy",
diff --git a/third_party/tsl/third_party/llvm/workspace.bzl b/third_party/tsl/third_party/llvm/workspace.bzl
index c761fc3e792f3..1a486ac8fde70 100644
--- a/third_party/tsl/third_party/llvm/workspace.bzl
+++ b/third_party/tsl/third_party/llvm/workspace.bzl
@@ -4,8 +4,8 @@ load("//third_party:repo.bzl", "tf_http_archive")
 
 def repo(name):
     """Imports LLVM."""
-    LLVM_COMMIT = "bd92e46204331b9af296f53abb708317e72ab7a8"
-    LLVM_SHA256 = "60f71fc5b237e10729edbed8cbe23b7081dabe254fbcb1ea82db8789cb7eaecf"
+    LLVM_COMMIT = "1d6ab189be031bf723abf35f772fbd5d4c86c612"
+    LLVM_SHA256 = "94efdc753920c1c4065c0f253c98c4b2b049495803e9089667ba397a29550323"
 
     tf_http_archive(
         name = name,