[IFRT] Add layout_mode attribute to IFRT Array type.
PiperOrigin-RevId: 702491722
ICGog authored and Google-ML-Automation committed Dec 3, 2024
1 parent 05f004e commit cc5c357
Showing 15 changed files with 327 additions and 65 deletions.
2 changes: 2 additions & 0 deletions xla/python/ifrt/ir/BUILD
@@ -154,7 +154,9 @@ cc_library(
":ifrt_interfaces_inc_gen",
":ifrt_ops_inc_gen",
":sharding_param",
"//xla/pjrt:layout_mode",
"//xla/python/ifrt",
"@com_google_absl//absl/log:check",
"@com_google_absl//absl/status:statusor",
"@com_google_absl//absl/strings",
"@llvm-project//llvm:Support",
2 changes: 2 additions & 0 deletions xla/python/ifrt/ir/constants.h
@@ -64,6 +64,8 @@ inline constexpr llvm::StringLiteral kHloShardingAttrName = "mhlo.sharding";
// Name of StringAttr used to store memory kind.
inline constexpr llvm::StringLiteral kHloMemoryKindAttrName =
"mhlo.memory_kind";
// Name of StringAttr used to store layout mode.
inline constexpr llvm::StringLiteral kHloLayoutAttrName = "mhlo.layout_mode";

inline constexpr llvm::StringLiteral kIfrtModuleTypeAttrName =
"ifrt.module_type";
143 changes: 141 additions & 2 deletions xla/python/ifrt/ir/ifrt_dialect.cc
@@ -16,7 +16,9 @@ limitations under the License.
#include "xla/python/ifrt/ir/ifrt_dialect.h"

#include <cstdint>
#include <optional>

#include "absl/log/check.h"
#include "absl/status/statusor.h"
#include "llvm/ADT/ArrayRef.h"
#include "llvm/ADT/STLFunctionalExtras.h"
@@ -35,6 +37,7 @@ limitations under the License.
#include "mlir/IR/OpImplementation.h"
#include "mlir/Support/LLVM.h"
#include "mlir/Support/LogicalResult.h"
#include "xla/pjrt/layout_mode.h"
#include "xla/python/ifrt/ir/constants.h"
#include "xla/python/ifrt/ir/ifrt_interfaces.h"
#include "xla/python/ifrt/ir/ifrt_ops.h"
@@ -201,8 +204,16 @@ llvm::ArrayRef<int> IfrtArrayType::getDevices() const {
mlir::LogicalResult IfrtArrayType::verify(
llvm::function_ref<mlir::InFlightDiagnostic()> emitError,
mlir::RankedTensorType shape, IfrtShardingAttrInterface sharding_attr,
IfrtDevicesAttr devices, mlir::StringAttr memory_kind) {
return sharding_attr.CanApplyTo(emitError, shape, devices.getIds());
IfrtDevicesAttr devices_attr, mlir::StringAttr memory_kind_attr,
mlir::StringAttr layout_attr) {
if (layout_attr) {
auto layout_mode = xla::LayoutMode::FromString(layout_attr.str());
if (!layout_mode.ok()) {
return emitError() << "Invalid layout mode: "
<< layout_mode.status().message();
}
}
return sharding_attr.CanApplyTo(emitError, shape, devices_attr.getIds());
}
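`verify` delegates the actual string validation to `xla::LayoutMode::FromString`. A minimal sketch of the accepted spellings, assuming the behavior of `xla/pjrt/layout_mode.h`; the concrete `{1,0}` minor-to-major form is an assumption, not something this commit exercises:

```cpp
#include "absl/log/check.h"
#include "xla/pjrt/layout_mode.h"

void LayoutModeStringExamples() {
  // "default": use the device/backend default layout.
  CHECK_OK(xla::LayoutMode::FromString("default"));
  // "auto": defer the layout choice to the compiler.
  CHECK_OK(xla::LayoutMode::FromString("auto"));
  // Any other parsable string is a concrete user-specified layout.
  CHECK_OK(xla::LayoutMode::FromString("{1,0}"));
  // Unparsable strings fail, so IfrtArrayType::verify emits an error.
  CHECK(!xla::LayoutMode::FromString("not-a-layout").ok());
}
```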

xla::ifrt::MemoryKind IfrtArrayType::MemoryKind() const {
@@ -211,6 +222,134 @@ xla::ifrt::MemoryKind IfrtArrayType::MemoryKind() const {
: xla::ifrt::MemoryKind(getMemoryKindAttr().str());
};

std::optional<xla::LayoutMode> IfrtArrayType::LayoutMode() const {
if (auto layout_attr = getLayoutAttr()) {
auto layout_mode = xla::LayoutMode::FromString(layout_attr.str());
CHECK_OK(layout_mode) << "Invalid layout mode: " << layout_attr.str();
return *layout_mode;
}
return std::nullopt;
}
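Callers can branch on the returned mode; a brief usage sketch (not part of the commit, assuming the `Mode` enum that `xla::LayoutMode` defines in `xla/pjrt/layout_mode.h`):

```cpp
#include <optional>

#include "xla/pjrt/layout_mode.h"
#include "xla/python/ifrt/ir/ifrt_dialect.h"

// Hypothetical predicate: true only when the array explicitly requests a
// compiler-chosen ("auto") layout. An unset attribute means the layout is
// inherited, so it returns false.
bool NeedsCompilerChosenLayout(xla::ifrt::IfrtArrayType array) {
  std::optional<xla::LayoutMode> mode = array.LayoutMode();
  return mode.has_value() && mode->mode == xla::LayoutMode::Mode::kAuto;
}
```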

void IfrtArrayType::print(mlir::AsmPrinter& odsPrinter) const {
mlir::Builder odsBuilder(getContext());
odsPrinter << "<";
odsPrinter.printStrippedAttrOrType(getShape());
odsPrinter << ", ";
odsPrinter.printStrippedAttrOrType(getShardingAttr());
odsPrinter << ", ";
odsPrinter.printStrippedAttrOrType(getDevicesAttr());
if (getMemoryKindAttr()) {
odsPrinter << ", memory_kind = ";
odsPrinter.printStrippedAttrOrType(getMemoryKindAttr());
}
if (getLayoutAttr()) {
odsPrinter << ", layout = ";
odsPrinter.printStrippedAttrOrType(getLayoutAttr());
}
odsPrinter << ">";
}

mlir::FailureOr<mlir::StringAttr> parseMemoryKindAttr(
mlir::AsmParser& odsParser) {
if (mlir::failed(odsParser.parseOptionalKeyword("memory_kind")))
return mlir::failure();
if (mlir::failed(odsParser.parseEqual())) return mlir::failure();
auto memory_kind_attr_or =
mlir::FieldParser<mlir::StringAttr>::parse(odsParser);
if (mlir::failed(memory_kind_attr_or)) {
odsParser.emitError(
odsParser.getCurrentLocation(),
"failed to parse Ifrt_ArrayType parameter 'memory_kind_attr' which "
"is to be a `mlir::StringAttr`");
return mlir::failure();
}
return memory_kind_attr_or;
}

mlir::FailureOr<mlir::StringAttr> parseLayoutAttr(mlir::AsmParser& odsParser) {
if (mlir::failed(odsParser.parseOptionalKeyword("layout")))
return mlir::failure();
if (mlir::failed(odsParser.parseEqual())) return mlir::failure();
auto layout_attr_or = mlir::FieldParser<mlir::StringAttr>::parse(odsParser);
if (mlir::failed(layout_attr_or)) {
odsParser.emitError(
odsParser.getCurrentLocation(),
"failed to parse Ifrt_ArrayType parameter 'layout_attr' which is to be "
"a `mlir::StringAttr`");
return mlir::failure();
}
return layout_attr_or;
}

mlir::Type IfrtArrayType::parse(mlir::AsmParser& odsParser) {
mlir::Builder odsBuilder(odsParser.getContext());

if (mlir::failed(odsParser.parseLess())) return {};

auto shape_or = mlir::FieldParser<mlir::RankedTensorType>::parse(odsParser);
if (mlir::failed(shape_or)) {
odsParser.emitError(odsParser.getCurrentLocation(),
"failed to parse Ifrt_ArrayType parameter 'shape' "
"which is to be a `mlir::RankedTensorType`");
return {};
}

if (mlir::failed(odsParser.parseComma())) return {};

auto sharding_attr_or =
mlir::FieldParser<IfrtShardingAttrInterface>::parse(odsParser);
if (mlir::failed(sharding_attr_or)) {
odsParser.emitError(
odsParser.getCurrentLocation(),
"failed to parse Ifrt_ArrayType parameter 'sharding_attr' which is to "
"be a `IfrtShardingAttrInterface`");
return {};
}

if (mlir::failed(odsParser.parseComma())) return {};

auto devices_attr_or = mlir::FieldParser<IfrtDevicesAttr>::parse(odsParser);
if (mlir::failed(devices_attr_or)) {
odsParser.emitError(
odsParser.getCurrentLocation(),
"failed to parse Ifrt_ArrayType parameter 'devices_attr' which is to "
"be a `IfrtDevicesAttr`");
return {};
}

mlir::FailureOr<mlir::StringAttr> memory_kind_attr_or;
mlir::FailureOr<mlir::StringAttr> layout_attr_or;
if (mlir::succeeded(odsParser.parseOptionalComma())) {
memory_kind_attr_or = parseMemoryKindAttr(odsParser);
if (mlir::failed(memory_kind_attr_or)) {
layout_attr_or = parseLayoutAttr(odsParser);
if (mlir::failed(layout_attr_or)) {
odsParser.emitError(
odsParser.getCurrentLocation(),
"failed to parse Ifrt_ArrayType optional attributes");
return {};
}
}
if (mlir::succeeded(odsParser.parseOptionalComma())) {
layout_attr_or = parseLayoutAttr(odsParser);
if (mlir::failed(layout_attr_or)) {
odsParser.emitError(
odsParser.getCurrentLocation(),
"failed to parse Ifrt_ArrayType `layout` attributes");
return {};
}
}
}

if (mlir::failed(odsParser.parseGreater())) return {};

return odsParser.getChecked<IfrtArrayType>(
odsParser.getCurrentLocation(), odsParser.getContext(), *shape_or,
*sharding_attr_or, *devices_attr_or,
memory_kind_attr_or.value_or(nullptr), layout_attr_or.value_or(nullptr));
}
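Taken together, the custom printer and parser accept the optional parameters in a fixed order, `memory_kind` before `layout`. A hedged round-trip sketch; the `#ifrt.sharding_param` spelling is assumed from the dialect's existing syntax rather than defined by this commit:

```cpp
#include "mlir/AsmParser/AsmParser.h"
#include "mlir/IR/MLIRContext.h"
#include "xla/python/ifrt/ir/ifrt_dialect.h"

void ParseArrayTypeExamples() {
  mlir::MLIRContext context;
  context.loadDialect<xla::ifrt::IfrtDialect>();
  // Neither optional attribute: prints and parses as before this commit.
  mlir::Type plain = mlir::parseType(
      "!ifrt.array<tensor<2x2xi32>, #ifrt.sharding_param<1x1 to [0] on 2>, "
      "[0, 1]>",
      &context);
  // Both optional attributes; `memory_kind` must precede `layout`.
  mlir::Type with_layout = mlir::parseType(
      "!ifrt.array<tensor<2x2xi32>, #ifrt.sharding_param<1x1 to [0] on 2>, "
      "[0, 1], memory_kind = \"device\", layout = \"auto\">",
      &context);
  (void)plain;
  (void)with_layout;
}
```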

//===----------------------------------------------------------------------===//
// IfrtDevicesAttr
//===----------------------------------------------------------------------===//
1 change: 1 addition & 0 deletions xla/python/ifrt/ir/ifrt_dialect.h
@@ -17,6 +17,7 @@ limitations under the License.
#define XLA_PYTHON_IFRT_IR_IFRT_DIALECT_H_

#include "mlir/IR/Dialect.h"
#include "xla/pjrt/layout_mode.h"
#include "xla/python/ifrt/ir/ifrt_interfaces.h"
#include "xla/python/ifrt/ir/sharding_param.h"
#include "xla/python/ifrt/memory.h"
38 changes: 27 additions & 11 deletions xla/python/ifrt/ir/ifrt_dialect.td
@@ -162,46 +162,62 @@ def Ifrt_IoAliasesAttr : TypedArrayAttrBase<

def Ifrt_ArrayType : TypeDef<Ifrt_Dialect, "IfrtArray"> {
let mnemonic = "array";
let summary = "An Ifrt array sharded on a set of devices.";
let summary = [{
An Ifrt array sharded on a set of devices.

Both memory_kind and layout are optional. When memory_kind is not specified,
the device default memory kind is used. The layout has the following
semantics:
- attribute is not set: the layout is inherited from the input arrays.
- set to "default": the device default layout is used.
- set to "auto": the layout will be determined by the compiler. It is only
allowed on arrays that are inputs of an `ifrt.Call`,
`ifrt.CallLoadedExecutable`, or `ifrt.LoadedExecutableOp` op.
- set to any other value: the value is treated as a user-specified custom
layout.
}];

let parameters = (ins
"::mlir::RankedTensorType":$shape,
"::xla::ifrt::IfrtShardingAttrInterface":$sharding_attr,
Ifrt_DevicesAttr:$devices_attr,
OptionalParameter<"::mlir::StringAttr">:$memory_kind_attr);
OptionalParameter<"::mlir::StringAttr">:$memory_kind_attr,
OptionalParameter<"::mlir::StringAttr">:$layout_attr);

let builders = [
// Constructs an array with unspecified sharding.
TypeBuilder<(ins
"::mlir::RankedTensorType":$shape,
"::xla::ifrt::IfrtDevicesAttr":$devices_attr,
"::mlir::StringAttr":$memory_kind_attr), [{
"::mlir::StringAttr":$memory_kind_attr,
"::mlir::StringAttr":$layout_attr), [{
return Base::get(
$_ctxt, shape, ::xla::ifrt::IfrtUnspecifiedShardingAttr::get($_ctxt),
devices_attr, memory_kind_attr);
devices_attr, memory_kind_attr, layout_attr);
}]>,
TypeBuilder<(ins
"::mlir::RankedTensorType":$shape,
"::xla::ifrt::IfrtShardingParamAttr":$sharding_attr,
"::xla::ifrt::IfrtDevicesAttr":$devices_attr,
"::mlir::StringAttr":$memory_kind_attr), [{
"::mlir::StringAttr":$memory_kind_attr,
"::mlir::StringAttr":$layout_attr), [{
return Base::get(
$_ctxt, shape, sharding_attr, devices_attr, memory_kind_attr);
$_ctxt, shape, sharding_attr, devices_attr, memory_kind_attr,
layout_attr);
}]>
];

let assemblyFormat = [{
`<` $shape`,` $sharding_attr `,` $devices_attr
(`,` `memory_kind` `=` $memory_kind_attr^)? `>`
}];
let hasCustomAssemblyFormat = 1;

let genVerifyDecl = 1;

let extraClassDeclaration = [{
// Get logical device ids from `devices_attr`.
::llvm::ArrayRef<int> getDevices() const;
// Get the memory kind from `memory_kind`.
// Get the memory kind from `memory_kind_attr`.
::xla::ifrt::MemoryKind MemoryKind() const;
// Get the layout mode from `layout_attr`. Returns std::nullopt if the
// layout attribute is not set.
std::optional<::xla::LayoutMode> LayoutMode() const;
}];
}
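On the C++ side, the generated `IfrtArrayType::get` gains the extra optional parameter. A hedged construction sketch; the exact signature is inferred from the ODS parameter list above, not spelled out in this commit:

```cpp
#include "mlir/IR/BuiltinAttributes.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/MLIRContext.h"
#include "xla/python/ifrt/ir/ifrt_dialect.h"

xla::ifrt::IfrtArrayType MakeAutoLayoutArray(
    mlir::MLIRContext* context, mlir::RankedTensorType shape,
    xla::ifrt::IfrtShardingAttrInterface sharding,
    xla::ifrt::IfrtDevicesAttr devices) {
  // nullptr memory_kind_attr falls back to the device default memory kind;
  // layout "auto" defers the layout choice to the compiler.
  return xla::ifrt::IfrtArrayType::get(
      context, shape, sharding, devices,
      /*memory_kind_attr=*/nullptr,
      /*layout_attr=*/mlir::StringAttr::get(context, "auto"));
}
```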

44 changes: 41 additions & 3 deletions xla/python/ifrt/ir/ifrt_ops.cc
@@ -287,6 +287,14 @@ mlir::LogicalResult VerifyIoAliasesAndDonations(
return mlir::success();
}

bool IsAutoLayout(mlir::Type type) {
auto array = llvm::cast_or_null<IfrtArrayType>(type);
if (array && array.getLayoutAttr()) {
return array.getLayoutAttr().str() == "auto";
}
return false;
}

} // namespace

mlir::LogicalResult ReshardOp::verify() {
@@ -299,9 +307,15 @@ }
}
for (const auto [idx, pair] :
llvm::enumerate(llvm::zip(getInputs(), getOutputs()))) {
if (mlir::failed(VerifySameGlobalShape(
*this, absl::StrCat("input #", idx), std::get<0>(pair),
absl::StrCat("output #", idx), std::get<1>(pair)))) {
auto input = std::get<0>(pair);
auto output = std::get<1>(pair);
if (IsAutoLayout(input.getType()) || IsAutoLayout(output.getType())) {
return emitOpError()
<< "does not allow input or output arrays with `auto` layout";
}
if (mlir::failed(VerifySameGlobalShape(*this, absl::StrCat("input #", idx),
input, absl::StrCat("output #", idx),
output))) {
return mlir::failure();
}
}
@@ -317,6 +331,9 @@ mlir::LogicalResult AssembleOp::verify() {
<< "requires every input to be a single device array. Actual: "
<< input.getType();
}
if (IsAutoLayout(array)) {
return emitOpError() << "does not allow input arrays with `auto` layout";
}
input_devices.push_back(array.getDevices()[0]);
}
const llvm::ArrayRef<int> output_devices = getOutput().getType().getDevices();
@@ -325,6 +342,9 @@
return emitOpError() << "requires the same input/output device list. Input "
<< input_devices << " vs Output " << output_devices;
}
if (IsAutoLayout(getOutput().getType())) {
return emitOpError() << "does not allow output arrays with `auto` layout";
}
return mlir::success();
}

@@ -337,6 +357,9 @@ mlir::LogicalResult DisassembleOp::verify() {
<< "requires every output to be a single device array. Actual: "
<< output.getType();
}
if (IsAutoLayout(array)) {
return emitOpError() << "does not allow output arrays with `auto` layout";
}
output_devices.push_back(array.getDevices()[0]);
}
const llvm::ArrayRef<int> input_devices = getInput().getType().getDevices();
@@ -345,6 +368,9 @@
return emitOpError() << "requires the same input/output device list. Input "
<< input_devices << " vs Output " << output_devices;
}
if (IsAutoLayout(getInput().getType())) {
return emitOpError() << "does not allow input array with `auto` layout";
}
return mlir::success();
}

@@ -380,6 +406,9 @@ mlir::LogicalResult CopyArraysOp::verify() {
"memory kind, but input #"
<< idx << " has a different memory kind";
}
if (IsAutoLayout(input_array)) {
return emitOpError() << "does not allow input arrays with `auto` layout";
}
const auto output_array =
llvm::cast<IfrtArrayType>(std::get<1>(pair).getType());
if (dst_devices != output_array.getDevicesAttr()) {
@@ -392,6 +421,9 @@ mlir::LogicalResult CopyArraysOp::verify() {
"memory kind, but output #"
<< idx << " has a different memory kind";
}
if (IsAutoLayout(output_array)) {
return emitOpError() << "does not allow output arrays with `auto` layout";
}
if (input_array.getShape() != output_array.getShape()) {
return emitOpError() << "requires input #" << idx << " and output #"
<< idx << " to have the same shape and dtype";
@@ -434,6 +466,9 @@ mlir::LogicalResult RemapArraysOp::verify() {
return emitOpError()
<< "requires every input and output array to have the same dtype.";
}
if (IsAutoLayout(array)) {
return emitOpError() << "does not allow input arrays with `auto` layout";
}
auto input_per_shard_shape =
array.getShardingAttr().LocalShapeFromGlobalShape(
array.getShape().getShape());
@@ -466,6 +501,9 @@ mlir::LogicalResult RemapArraysOp::verify() {
return emitOpError() << "requires every input and output array to have "
"the same dtype.";
}
if (IsAutoLayout(array)) {
return emitOpError() << "does not allow outputs with `auto` layout";
}
auto output_per_shard_shape =
array.getShardingAttr().LocalShapeFromGlobalShape(
array.getShape().getShape());